Refactor code structure for improved readability and maintainability
This commit is contained in: parent 3fe2c85f6b · commit 0bc1034f35
15  .vscode/settings.json (vendored)
@@ -1,14 +1,19 @@
 {
     "python.analysis.extraPaths": [
-        "src/transformers_repo/src/",
-        "src/peft_repo/src/"
+        "./src/peft_repo/src/",
+        "./src/transformers_repo/src/",
     ],
     "python.analysis.exclude": [
         "dataset/**/*",
     ],
     "python.languageServer": "Default",
     "python.terminal.activateEnvInCurrentTerminal": true,
-    "python.analysis.include": [
-        "src/**/*"
-    ]
+    // "python.analysis.include": [
+    //     "src/**/*"
+    // ],
+    "python.analysis.languageServerMode": "default",
+    "python.analysis.typeCheckingMode": "basic",
+    "python.analysis.userFileIndexingLimit": -1,
+    "python.analysis.usePullDiagnostics": false,
+    "python.analysis.importFormat": "relative"
 }
@@ -2,6 +2,7 @@
 dependencies = [
     "absl-py>=2.1.0",
     "accelerate==1.2.1",
+    "calflops>=0.3.2",
     "datasets==3.2.0",
     "deepspeed==0.16.2",
     "evaluate==0.4.3",
@@ -20,9 +21,9 @@ dependencies = [
     "safetensors>=0.5.2",
     "setuptools>=70.0.0",
     "soundfile>=0.13.0",
-    "torch==2.5.1+cu124",
-    "torchaudio==2.5.1+cu124",
-    "torchvision==0.20.1+cu124",
+    "torch==2.6.0",
+    "torchaudio==2.6.0",
+    "torchvision==0.21.0",
     "transformers==4.48.0",
     "trl==0.13.0",
     "wandb>=0.19.4",
@@ -11,7 +11,7 @@ machine_rank: 0
 main_training_function: main
 mixed_precision: 'bf16'
 num_machines: 1
-num_processes: 4
+num_processes: 3
 rdzv_backend: static
 same_network: true
 tpu_env: []
@@ -89,4 +89,28 @@ def get_model(model_args: ContinualModelConfig):
         collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
         collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
 
+    if model_args.model_name_or_path == "Qwen/Qwen2.5-Omni-3B":
+        from transformers.models.qwen2_5_omni import (
+            Qwen2_5OmniThinkerForConditionalGeneration,
+            Qwen2_5OmniProcessor
+        )
+
+        model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
+            model_args.model_name_or_path,
+            trust_remote_code=model_args.trust_remote_code,
+            **model_kwargs,
+        )
+
+        processor = Qwen2_5OmniProcessor.from_pretrained(
+            model_args.model_name_or_path,
+            trust_remote_code=model_args.trust_remote_code,
+            padding_side="left",
+        )
+
+        from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
+        from functools import partial
+
+        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
+        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
+
     return model, processor, collate_fn_for_train, collate_fn_for_evaluate
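Note: to make the new branch concrete, here is a minimal, hypothetical usage sketch. The ContinualModelConfig field names (model_name_or_path, trust_remote_code) come from how this hunk reads them; the constructor call itself is an assumption, and model_kwargs is built earlier in get_model, outside this diff.

from utils.args import ContinualModelConfig
from model_library.factory import get_model

# Assumed constructor; only fields this diff actually reads are passed.
model_args = ContinualModelConfig(
    model_name_or_path="Qwen/Qwen2.5-Omni-3B",
    trust_remote_code=True,
)

# Returns the Omni "thinker" model, a left-padding processor, and the two
# Qwen2-VL collators pre-bound to that processor via functools.partial.
model, processor, collate_train, collate_eval = get_model(model_args)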
@@ -1,7 +1,8 @@
 from transformers import Qwen2AudioProcessor
+from dataset_library.format import Conversation
 
 
-def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
+def collate_fn_for_train(examples: list[Conversation], processor: Qwen2AudioProcessor):
     # Get the texts and images, and apply the chat template
     texts = [
         processor.apply_chat_template(example["chat"], tokenize=False)
@@ -60,7 +61,7 @@ def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
     return batch
 
 
-def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
+def collate_fn_for_evaluate(examples: list[Conversation], processor: Qwen2AudioProcessor):
     # Get the texts and images, and apply the chat template
     texts = [
         processor.apply_chat_template(
@@ -91,3 +92,4 @@ def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
     # answers_ids torch.Size([3, 10])
     # answers_mask torch.Size([3, 10])
     return batch
+
@@ -1,9 +1,22 @@
-from transformers import Qwen2VLProcessor
+# from transformers import Qwen2VLProcessor
+import sys
+
+sys.path.insert(0, "transformers_repo/src/")
+sys.path.insert(0, "peft_repo/src/")
+import transformers
+import peft
 from dataset_library.format import DatasetOutput
 
 import torch
 
+from typing import TYPE_CHECKING
 
-def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
+if TYPE_CHECKING:
+    from transformers import Qwen2VLProcessor
+    from transformers import Qwen2_5_VLProcessor
+
+
+def collate_fn_for_train(examples: list[DatasetOutput], processor: "Qwen2VLProcessor"):
     # Get the texts and images, and apply the chat template
     texts = [
         processor.apply_chat_template(example["chat"], tokenize=False)
@@ -32,36 +45,36 @@ def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProces
     # print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
     # 151644 151645 8948 872 77091 None 151643
 
-    # for i, label in enumerate(labels):
-    #     now_index = 0
-    #     while now_index < len(label):
-    #         if label[now_index] == im_start_token_id:
-    #             label[now_index] = -100
-    #             now_index += 1
-    #             if (
-    #                 label[now_index] == system_token_id
-    #                 or label[now_index] == user_token_id
-    #             ):
-    #                 while label[now_index] != im_end_token_id:
-    #                     label[now_index] = -100
-    #                     now_index += 1
-    #                 label[now_index] = -100
-    #             elif label[now_index] == assistant_token_id:
-    #                 label[now_index] = -100
-    #                 label[now_index + 1] = -100
-    #                 now_index += 2
-    #                 while (
-    #                     now_index < len(label) and label[now_index] != im_end_token_id
-    #                 ):
-    #                     now_index += 1
-    #         now_index += 1
+    for i, label in enumerate(labels):
+        now_index = 0
+        while now_index < len(label):
+            if label[now_index] == im_start_token_id:
+                label[now_index] = -100
+                now_index += 1
+                if (
+                    label[now_index] == system_token_id
+                    or label[now_index] == user_token_id
+                ):
+                    while label[now_index] != im_end_token_id:
+                        label[now_index] = -100
+                        now_index += 1
+                    label[now_index] = -100
+                elif label[now_index] == assistant_token_id:
+                    label[now_index] = -100
+                    label[now_index + 1] = -100
+                    now_index += 2
+                    while (
+                        now_index < len(label) and label[now_index] != im_end_token_id
+                    ):
+                        now_index += 1
+            now_index += 1
     batch["labels"] = labels
     # batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)
 
     return batch
 
 
-def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
+def collate_fn_for_evaluate(examples, processor: "Qwen2VLProcessor"):
     # Get the texts and images, and apply the chat template
     texts = [
         processor.apply_chat_template(
@@ -89,3 +102,68 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
     # answers_ids torch.Size([3, 10])
     # answers_mask torch.Size([3, 10])
     return batch
+
+
+if __name__ == "__main__":
+    from transformers import AutoProcessor
+
+    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
+    from PIL import Image
+
+    # Randomly generate an image
+    import numpy as np
+
+    random_image = Image.fromarray(
+        np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
+    )
+    example = {
+        "chat": [
+            # {"role": "user", "content": "What is the capital of France?"},
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "image_url": ""},  # images are passed separately below
+                    {"type": "image", "image_url": ""},
+                    {"type": "image", "image_url": ""},
+                    {"type": "text", "text": "What is the capital of France?"},
+                ],
+            },
+            {"role": "assistant", "content": "The capital of France is Paris."},
+        ],
+        "images": [
+            random_image,
+            random_image,
+            random_image,
+        ],
+    }
+    batch = collate_fn_for_train([example], processor)
+    # print(batch)
+    for k, v in batch.items():
+        if isinstance(v, torch.Tensor):
+            print(f"{k}: {v.shape}")
+        else:
+            print(f"{k}: {v}")
+    # input_ids: torch.Size([1, 101])
+    # attention_mask: torch.Size([1, 101])
+    # pixel_values: torch.Size([256, 1176])
+    # image_grid_thw: torch.Size([1, 3])
+    # labels: torch.Size([1, 101])
+
+    # Load model directly
+    from transformers.models.qwen2_5_omni import (
+        Qwen2_5OmniThinkerForConditionalGeneration,
+        Qwen2_5OmniProcessor,
+    )
+
+    model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
+        "Qwen/Qwen2.5-Omni-3B", torch_dtype="auto", device_map="auto"
+    )
+    processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-3B")
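Note: the loop un-commented in the middle hunk implements assistant-only supervision. Every token belonging to a system or user turn, plus each assistant role header, is overwritten with -100 so that CrossEntropyLoss ignores it; only assistant reply tokens remain as training targets. Below is a standalone sketch of the same walk, using the token ids printed in the diff (151644 = <|im_start|>, 151645 = <|im_end|>, 8948 = system, 872 = user, 77091 = assistant). This is an illustration of the idea, not the repo's code.

# Standalone sketch of the label-masking logic in collate_fn_for_train.
IM_START, IM_END = 151644, 151645
SYSTEM, USER, ASSISTANT = 8948, 872, 77091

def mask_non_assistant(label: list[int]) -> list[int]:
    i = 0
    while i < len(label):
        if label[i] == IM_START:
            label[i] = -100
            i += 1
            if label[i] in (SYSTEM, USER):
                # Mask the whole system/user turn, including its <|im_end|>.
                while label[i] != IM_END:
                    label[i] = -100
                    i += 1
                label[i] = -100
            elif label[i] == ASSISTANT:
                # Mask only the role header (tag + following newline token);
                # keep the reply tokens as supervision targets.
                label[i] = -100
                label[i + 1] = -100
                i += 2
                while i < len(label) and label[i] != IM_END:
                    i += 1
        i += 1
    return label

# Toy sequence: user turn (fully masked), then assistant reply (kept).
# The 0 stands in for the newline token after the assistant role tag.
seq = [IM_START, USER, 1, 2, IM_END, IM_START, ASSISTANT, 0, 5, 6, IM_END]
print(mask_non_assistant(seq))
# [-100, -100, -100, -100, -100, -100, -100, -100, 5, 6, 151645]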
35  src/test.py (new file)
@@ -0,0 +1,35 @@
+# import sys
+
+# sys.path.insert(0, "transformers_repo/src/")
+# sys.path.insert(0, "peft_repo/src/")
+
+# from calflops import calculate_flops_hf
+
+# batch_size, max_seq_length = 1, 128
+# model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
+
+# flops, macs, params = calculate_flops_hf(model_name=model_name, input_shape=(batch_size, max_seq_length), access_token="hf_***")  # token redacted
+# print("%s FLOPs:%s MACs:%s Params:%s \n" % (model_name, flops, macs, params))
+
+# Transformers Model, such as bert.
+from calflops import calculate_flops
+from transformers import AutoModel
+from transformers import AutoTokenizer
+
+batch_size, max_seq_length = 1, 128
+# Load model directly
+from transformers import AutoProcessor, AutoModelForImageTextToText
+
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
+model = AutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
+
+flops, macs, params = calculate_flops(
+    model=model,
+    input_shape=(batch_size, max_seq_length),
+    transformer_tokenizer=processor.tokenizer,  # was `tokenizer`, which is never defined
+)
+print(
+    "Bert(hfl/chinese-roberta-wwm-ext) FLOPs:%s MACs:%s Params:%s \n"
+    % (flops, macs, params)
+)
+# Bert(hfl/chinese-roberta-wwm-ext) FLOPs:67.1 GFLOPS MACs:33.52 GMACs Params:102.27 M
16  src/train.py
@@ -13,7 +13,7 @@ from trl import (
     TrlParser,
 )
 from utils.trainer import ContinualTrainer
-from utils.args import ContinualScriptArguments, ContinualModelConfig
+from utils.args import ContinualScriptArguments, ContinualModelConfig, ContiunalRegularizationArguments
 import logging
 from typing import TYPE_CHECKING
 
@@ -23,18 +23,19 @@ logging.basicConfig(level=logging.INFO)
 
 if __name__ == "__main__":
     parser = TrlParser(
-        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
+        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig, ContiunalRegularizationArguments)
     )
-    script_args, training_args, model_args = parser.parse_args_and_config()
-    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
-    training_args.remove_unused_columns = False
-    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
+    script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
 
     # for type hint
     if 1 == 0:
         script_args = ContinualScriptArguments
         training_args = TrainingArguments
         model_args = ContinualModelConfig
+        reg_args = ContiunalRegularizationArguments
+
+    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
+    training_args.remove_unused_columns = False
+    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
 
     from model_library.factory import get_model
 
@@ -91,6 +92,7 @@ if __name__ == "__main__":
             else None
         ),
         accelerator=accelerator,
+        reg_args=reg_args,
     )
     trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
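Note: the ContiunalRegularizationArguments dataclass itself (class name spelled "Contiunal" in the code) is not shown in this diff. Judging from the fields ContinualTrainer reads in utils/trainer.py below, its minimal shape would presumably be something like:

from dataclasses import dataclass

# Assumed reconstruction; only the fields referenced by ContinualTrainer
# (ewc_enable / ewc_lambda / lwf_enable / lwf_lambda) are listed.
@dataclass
class ContiunalRegularizationArguments:
    ewc_enable: bool = False   # Elastic Weight Consolidation
    ewc_lambda: float = 0.0
    lwf_enable: bool = False   # Learning without Forgetting
    lwf_lambda: float = 0.0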
@@ -1,18 +1,19 @@
 #!/bin/bash
 
 accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
-    --dataset_name refcoco \
+    --dataset_name textvqa \
     --use_peft \
     --peft_type LORA \
-    --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
+    --model_name_or_path Qwen/Qwen2.5-Omni-3B \
     --lora_target_modules .\*proj.\*\|.\*fc.\*\|.\*mlp\.0\|.\*mlp\.2 \
     --lora_r 8 \
     --lora_alpha 32 \
-    --per_device_train_batch_size 1 \
+    --per_device_train_batch_size 3 \
     --per_device_eval_batch_size 1 \
     --gradient_accumulation_steps 4 \
+    --num_train_epochs 1 \
     --output_dir checkpoint/qwen2_alllinear/ \
-    --learning_rate 1e-4 \
+    --learning_rate 5e-5 \
     --bf16 \
     --torch_dtype bfloat16 \
     --logging_steps 30 \
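Note: combined with the accelerate config change above (num_processes: 4 -> 3), and assuming this launch uses that updated config, the effective global batch size becomes

    per_device_train_batch_size (3) x num_processes (3) x gradient_accumulation_steps (4) = 36 samples per optimizer step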
@@ -1 +1 @@
-Subproject commit 684f12be1c8f26c46b1eebad50ce21ce6e3378b3
+Subproject commit c8a4ee5b9daf9865b372a483fd04a984f0b265dc
@@ -7,6 +7,29 @@ from transformers import (
 )
 from .args import ContiunalRegularizationArguments
 from peft_library.regularizations import EWC, LWF
+from torch.nn import CrossEntropyLoss
+
+
+def ce_loss_func(outputs, labels, num_items_in_batch=None, **kwargs):
+    logits = outputs.logits
+    device = logits.device
+    # Shift so that tokens < n predict n
+    shift_logits = logits[..., :-1, :]
+    shift_labels = labels[..., 1:].to(device)
+    # Save memory
+    masks = shift_labels != -100
+    shift_logits = shift_logits[masks]
+    shift_labels = shift_labels[masks]
+    # Flatten the tokens
+    loss_fct = CrossEntropyLoss(reduction="none")
+    loss = loss_fct(shift_logits, shift_labels)
+
+    if num_items_in_batch is None:
+        loss = loss.mean()
+    else:
+        # compat transformers>=4.46
+        loss = loss.sum() / num_items_in_batch
+    return loss
+
 
 class ContinualTrainer(Trainer):
@@ -18,17 +41,30 @@ class ContinualTrainer(Trainer):
         train_dataset,
         eval_dataset,
         accelerator,
-        regularization_args: ContiunalRegularizationArguments = None,
+        reg_args: ContiunalRegularizationArguments = None,
     ):
         self.accelerator = accelerator
-        super().__init__(model, args, data_collator, train_dataset, eval_dataset)
+        super().__init__(
+            model,
+            args,
+            data_collator,
+            train_dataset,
+            eval_dataset,
+            # compute_loss_func=ce_loss_func,
+        )
 
-        # if regularization_args.ewc_enable:
-        #     self.ewc_lambda = regularization_args.ewc_lambda
-        #     # fisher = t
-
-        # if regularization_args.lwf_enable:
-        #     self.lwf_lambda = regularization_args.lwf_lambda
+        if reg_args.ewc_enable:
+            self.ewc_lambda = reg_args.ewc_lambda
+            from peft_library.regularizations.ewc import EWC
+
+            self.EWC = EWC()
+            # fisher = t
+
+        if reg_args.lwf_enable:
+            self.lwf_lambda = reg_args.lwf_lambda
+            from peft_library.regularizations.lwf import LWF
+
+            self.LWF = LWF()
 
     def create_accelerator_and_postprocess(self):
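Note on ce_loss_func (wired in only as a commented-out compute_loss_func for now): summing per-token losses and dividing by num_items_in_batch, as transformers>=4.46 expects, keeps the loss scale correct under gradient accumulation, whereas taking a mean per micro-batch over-weights micro-batches with few unmasked tokens. A toy illustration:

import torch

per_token_loss = torch.tensor([1.0, 2.0, 3.0, 4.0])  # 4 target tokens total

# Per-micro-batch means averaged over 2 accumulation steps: biased when the
# micro-batches contain different numbers of unmasked tokens.
mb1, mb2 = per_token_loss[:1], per_token_loss[1:]
print((mb1.mean() + mb2.mean()) / 2)   # tensor(2.)

# Sum divided by the total token count reproduces the true mean over all 4.
print(per_token_loss.sum() / 4)        # tensor(2.5000)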