151 lines
4.3 KiB
Python
151 lines
4.3 KiB
Python
import sys
|
|
|
|
sys.path.insert(0, "transformers_repo/src/")
|
|
sys.path.insert(0, "peft_repo/src/")
|
|
|
|
from dataset_library.factory import get_dataset
|
|
|
|
from transformers import (
|
|
TrainingArguments,
|
|
)
|
|
|
|
from trl import (
|
|
TrlParser,
|
|
)
|
|
from utils.trainer import ContinualTrainer
|
|
from utils.args import (
|
|
ContinualScriptArguments,
|
|
ContinualModelConfig,
|
|
ContinualRegularizationArguments,
|
|
)
|
|
import logging
|
|
from typing import TYPE_CHECKING
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO)


def _build_peft_config(model_args: "ContinualModelConfig"):
    """Return the PEFT adapter config selected by ``model_args.peft_type``.

    Recognised values are ``"MMOELORA"``, ``"MOELORA"``, ``"LORA"`` and
    ``"OLORA"`` (case-sensitive). Any other value returns ``None``, meaning
    no adapter is attached to the model.

    The MoE variants only receive ``lora_target_modules``. The plain LoRA
    variants also receive ``lora_r``, ``lora_alpha`` and ``lora_dropout``.

    Config classes are imported lazily so that only the requested variant is
    imported. ``MMOELoraConfig`` / ``MOELoraConfig`` are presumably provided
    by the vendored ``peft_repo`` placed on ``sys.path`` (not upstream peft);
    TODO confirm.
    """
    peft_type = model_args.peft_type

    if peft_type == "MMOELORA":
        from peft.tuners import MMOELoraConfig

        return MMOELoraConfig(target_modules=model_args.lora_target_modules)

    if peft_type == "MOELORA":
        from peft.tuners import MOELoraConfig

        return MOELoraConfig(target_modules=model_args.lora_target_modules)

    if peft_type in ("LORA", "OLORA"):
        from peft.tuners.lora import LoraConfig

        lora_kwargs = dict(
            target_modules=model_args.lora_target_modules,
            r=model_args.lora_r,
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
        )
        if peft_type == "OLORA":
            # Same LoRA config, but with PEFT's "olora" weight initialisation.
            lora_kwargs["init_lora_weights"] = "olora"
        return LoraConfig(**lora_kwargs)

    return None


if __name__ == "__main__":
    # Script entry point: parse arguments, build the model (plus optional PEFT
    # adapter), then train sequentially on each dataset in
    # script_args.dataset_name, saving the model to output_dir after each task.
    parser = TrlParser(
        (
            ContinualScriptArguments,
            TrainingArguments,
            ContinualModelConfig,
            ContinualRegularizationArguments,
        )  # type: ignore
    )
    script_args, training_args, model_args, reg_args = parser.parse_args_and_config()

    # Re-bind with concrete types for static type checkers only; this branch
    # never executes at runtime.
    if TYPE_CHECKING:
        script_args = ContinualScriptArguments()
        training_args = TrainingArguments()
        model_args = ContinualModelConfig()
        reg_args = ContinualRegularizationArguments()

    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    # Keep every dataset column for the custom collate functions returned by
    # get_model (presumably they read columns the model's forward() does not
    # declare; TODO confirm).
    training_args.remove_unused_columns = False

    from model_library.factory import get_model

    model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
        model_args=model_args
    )

    from utils.accelerator import create_accelerator_and_postprocess

    accelerator = create_accelerator_and_postprocess(training_args)

    peft_config = _build_peft_config(model_args)
    if peft_config is not None:
        model.add_adapter(peft_config)

    if accelerator.is_local_main_process:
        print(model)

    ################
    # Continual training: one task per dataset, in the order given.
    ################
    for task_index, dataset_name in enumerate(script_args.dataset_name):
        dataset = get_dataset(dataset_name)
        model.train()

        trainer = ContinualTrainer(
            model=model,
            args=training_args,
            data_collator=collate_fn_for_train,
            train_dataset=dataset[script_args.dataset_train_split],  # type: ignore
            eval_dataset=(
                dataset[script_args.dataset_test_split]  # type: ignore
                if training_args.eval_strategy != "no"
                else None
            ),
            accelerator=accelerator,
            reg_args=reg_args,
        )

        # Resuming only makes sense for the first task. Passing the checkpoint
        # to later tasks would reload its weights and step counter and discard
        # the progress made on the preceding task(s).
        resume_from_checkpoint = (
            training_args.resume_from_checkpoint if task_index == 0 else None
        )
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

        # Let every rank finish training first, then have a single process
        # write the files so ranks don't race on the same output directory.
        # NOTE(review): for sharded setups (FSDP / DeepSpeed ZeRO-3)
        # trainer.save_model(training_args.output_dir) is the distributed-aware
        # alternative; it was deliberately commented out in the original.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            print("Saving model")
            model.save_pretrained(training_args.output_dir)
            print("Model saved")
        # Synchronize ranks again so none starts the next task (or exits)
        # before the save has completed.
        accelerator.wait_for_everyone()

        # model.eval()

        # from torch.utils.data import DataLoader

        # val_dataloader = DataLoader(
        # dataset[script_args.dataset_generation_split],
        # batch_size=3,
        # collate_fn=collate_fn_for_evaluate,
        # )
        # val_dataloader = accelerator.prepare(val_dataloader)

        # from utils.evaluate_tool import evaluate_save

        # evaluate_save(model, val_dataloader, processor, accelerator)
|