import sys

# Prefer the vendored transformers/peft sources over any installed packages
sys.path.insert(0, "transformers_repo/src/")
sys.path.insert(0, "peft_repo/src/")

from dataset_library.factory import get_dataset
from transformers import (
    TrainingArguments,
)
from trl import (
    TrlParser,
)
from utils.trainer import ContinualTrainer
from utils.args import (
    ContinualScriptArguments,
    ContinualModelConfig,
    ContinualRegularizationArguments,
)
import logging
from typing import TYPE_CHECKING

logging.basicConfig(level=logging.INFO)

if __name__ == "__main__":
    parser = TrlParser(
        (
            ContinualScriptArguments,
            TrainingArguments,
            ContinualModelConfig,
            ContinualRegularizationArguments,
        )  # type: ignore
    )
    script_args, training_args, model_args, reg_args = parser.parse_args_and_config()

    # For static type checking only; this branch is never executed at runtime
    if TYPE_CHECKING:
        script_args = ContinualScriptArguments()
        training_args = TrainingArguments()
        model_args = ContinualModelConfig()
        reg_args = ContinualRegularizationArguments()

    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False

    from model_library.factory import get_model

    model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
        model_args=model_args
    )

    ################
    # Dataset
    ################
    from utils.accelerator import create_accelerator_and_postprocess

    accelerator = create_accelerator_and_postprocess(training_args)

    # Attach the requested PEFT adapter to the base model
    if model_args.peft_type == "MMOELORA":
        from peft.tuners import MMOELoraConfig

        peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
        model.add_adapter(peft_config)
    elif model_args.peft_type == "MOELORA":
        from peft.tuners import MOELoraConfig

        peft_config = MOELoraConfig(target_modules=model_args.lora_target_modules)
        model.add_adapter(peft_config)
    elif model_args.peft_type == "LORA":
        from peft.tuners.lora import LoraConfig

        peft_config = LoraConfig(
            target_modules=model_args.lora_target_modules,
            r=model_args.lora_r,
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
        )
        model.add_adapter(peft_config)
    elif model_args.peft_type == "OLORA":
        from peft.tuners import LoraConfig

        peft_config = LoraConfig(
            target_modules=model_args.lora_target_modules,
            r=model_args.lora_r,
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
            init_lora_weights="olora",
        )
        model.add_adapter(peft_config)
    else:
        peft_config = None

    from peft import get_peft_model

    if accelerator.is_local_main_process:
        print(model)

    # Continual learning: train on each dataset sequentially, saving after each task
    for dataset_name in script_args.dataset_name:
        dataset = get_dataset(dataset_name)
        model.train()

        trainer = ContinualTrainer(
            model=model,
            args=training_args,
            data_collator=collate_fn_for_train,
            train_dataset=dataset[script_args.dataset_train_split],  # type: ignore
            eval_dataset=(
                dataset[script_args.dataset_test_split]  # type: ignore
                if training_args.eval_strategy != "no"
                else None
            ),
            accelerator=accelerator,
            reg_args=reg_args,
        )
        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)

        if accelerator.is_local_main_process:
            print("Saving model")
        # trainer.save_model(training_args.output_dir)
        model.save_pretrained(training_args.output_dir)
        if accelerator.is_local_main_process:
            print("Model saved")

        # Synchronize all processes before moving on to the next dataset
        accelerator.wait_for_everyone()

        # model.eval()
        # from torch.utils.data import DataLoader
        # val_dataloader = DataLoader(
        #     dataset[script_args.dataset_generation_split],
        #     batch_size=3,
        #     collate_fn=collate_fn_for_evaluate,
        # )
        # val_dataloader = accelerator.prepare(val_dataloader)
        # from utils.evaluate_tool import evaluate_save
        # evaluate_save(model, val_dataloader, processor, accelerator)