cl-lmm/src/train.py

import sys
sys.path.insert(0, "transformers_repo/src/")
sys.path.insert(0, "peft_repo/src/")
from dataset_library.factory import get_dataset
from transformers import (
    TrainingArguments,
)
from trl import (
    TrlParser,
)
from utils.trainer import ContinualTrainer
from utils.args import (
    ContinualScriptArguments,
    ContinualModelConfig,
    ContinualRegularizationArguments,
)
import logging
from typing import TYPE_CHECKING

logging.basicConfig(level=logging.INFO)
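# Example launch (hypothetical paths; adjust to your setup). TrlParser reads
# CLI flags and/or a YAML config via parse_args_and_config:
#   accelerate launch src/train.py --config configs/example_lora.yaml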
if __name__ == "__main__":
parser = TrlParser(
(
ContinualScriptArguments,
TrainingArguments,
ContinualModelConfig,
ContinualRegularizationArguments,
) # type: ignore
)
script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
# for type hint
if TYPE_CHECKING:
script_args = ContinualScriptArguments()
training_args = TrainingArguments()
model_args = ContinualModelConfig()
reg_args = ContinualRegularizationArguments()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
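    # Build the backbone model together with its processor and the train/eval
    # collate functions from the local model factory.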
    from model_library.factory import get_model

    model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
        model_args=model_args
    )
    ################
    # Accelerator
    ################
    from utils.accelerator import create_accelerator_and_postprocess

    accelerator = create_accelerator_and_postprocess(training_args)
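    # Attach a parameter-efficient adapter chosen by --peft_type. MMOELoraConfig
    # and MOELoraConfig are assumed to come from the vendored peft_repo fork
    # rather than upstream peft.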
if model_args.peft_type == "MMOELORA":
from peft.tuners import MMOELoraConfig
peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
model.add_adapter(peft_config)
elif model_args.peft_type == "MOELORA":
from peft.tuners import MOELoraConfig
peft_config = MOELoraConfig(target_modules=model_args.lora_target_modules)
model.add_adapter(peft_config)
elif model_args.peft_type == "LORA":
from peft.tuners.lora import LoraConfig
peft_config = LoraConfig(
target_modules=model_args.lora_target_modules,
r=model_args.lora_r,
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
)
model.add_adapter(peft_config)
elif model_args.peft_type == "OLORA":
from peft.tuners import LoraConfig
peft_config = LoraConfig(
target_modules=model_args.lora_target_modules,
r=model_args.lora_r,
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
init_lora_weights="olora"
)
model.add_adapter(peft_config)
else:
peft_config = None
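    # No peft_type given: no adapter is attached and the base model is trained as-is.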
    from peft import get_peft_model

    if accelerator.is_local_main_process:
        print(model)
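    # Continual-learning loop: train on each dataset in sequence, rebuilding the
    # trainer per task while the same model (and its adapter) carries over.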
    for dataset_name in script_args.dataset_name:
        dataset = get_dataset(dataset_name)
        model.train()

        trainer = ContinualTrainer(
            model=model,
            args=training_args,
            data_collator=collate_fn_for_train,
            train_dataset=dataset[script_args.dataset_train_split],  # type: ignore
            eval_dataset=(
                dataset[script_args.dataset_test_split]  # type: ignore
                if training_args.eval_strategy != "no"
                else None
            ),
            accelerator=accelerator,
            reg_args=reg_args,
        )
        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
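        # Save the current weights before moving on to the next task.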
        if accelerator.is_local_main_process:
            print("Saving model")
        # trainer.save_model(training_args.output_dir)
        model.save_pretrained(training_args.output_dir)
        if accelerator.is_local_main_process:
            print("Model saved")
        # Synchronize all accelerator processes before moving on.
        accelerator.wait_for_everyone()
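        # Generation-based evaluation is currently disabled; the block below
        # sketches how it was wired: a DataLoader over the generation split,
        # prepared by the accelerator and scored with utils.evaluate_tool.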
        # model.eval()
        # from torch.utils.data import DataLoader
        # val_dataloader = DataLoader(
        #     dataset[script_args.dataset_generation_split],
        #     batch_size=3,
        #     collate_fn=collate_fn_for_evaluate,
        # )
        # val_dataloader = accelerator.prepare(val_dataloader)
        # from utils.evaluate_tool import evaluate_save
        # evaluate_save(model, val_dataloader, processor, accelerator)