Patch summary (text before the first "diff --git" header is ignored by git apply):
* src/evaluation.py: parse the model arguments with ContinualModelConfig instead of ModelConfig.
* src/evaluation.sh: read the accelerate config from configs/accelerate_configs/ and pass --peft_type MMOELORA.
* src/todo.md: reword one TODO entry and add a new one.
* src/train.py: print progress around trainer.save_model() on the local main process, then
  wait for all processes. The added lines go through trainer.accelerator because the only
  binding of `accelerator` visible in this hunk comes after them.

diff --git a/src/evaluation.py b/src/evaluation.py
index d28ea68..be5b01e 100644
--- a/src/evaluation.py
+++ b/src/evaluation.py
@@ -9,11 +9,11 @@ from trl import (
     get_quantization_config,
 )
 
-from utils.args import ContinualScriptArguments
+from utils.args import ContinualScriptArguments, ContinualModelConfig
 
 
 if __name__ == "__main__":
-    parser = TrlParser((ContinualScriptArguments, TrainingArguments, ModelConfig))
+    parser = TrlParser((ContinualScriptArguments, TrainingArguments, ContinualModelConfig))
     script_args, training_args, model_args = parser.parse_args_and_config()
     # for type hint
     if 0 == 1:
diff --git a/src/evaluation.sh b/src/evaluation.sh
index e50bf72..a440e82 100755
--- a/src/evaluation.sh
+++ b/src/evaluation.sh
@@ -1,8 +1,9 @@
 #!/bin/bash
 
-accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluation.py \
+accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml evaluation.py \
     --dataset_name OCR_VQA_200K \
     --use_peft \
+    --peft_type MMOELORA \
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
     --lora_target_modules q_proj v_proj \
     --per_device_train_batch_size 1 \
diff --git a/src/todo.md b/src/todo.md
index 4538288..8e7bf4a 100644
--- a/src/todo.md
+++ b/src/todo.md
@@ -11,5 +11,6 @@
 
 [2025.01.03]
 
-- [ ] 处理量化逻辑(懒得调了)
+- [ ] 处理量化逻辑
 - [ ] 严查moelora的原始代码,太粗糙了😡
+- [ ] 未知原因trainer后处理时间长
diff --git a/src/train.py b/src/train.py
index d9a276c..f2d3c2f 100644
--- a/src/train.py
+++ b/src/train.py
@@ -114,7 +114,13 @@ if __name__ == "__main__":
     )
 
     trainer.train()
+    if trainer.accelerator.is_local_main_process:
+        print("Saving model")
     trainer.save_model(training_args.output_dir)
+    if trainer.accelerator.is_local_main_process:
+        print("Model saved")
+    # Wait for every process to reach this point before switching to evaluation.
+    trainer.accelerator.wait_for_everyone()
 
     model.eval()
     accelerator = trainer.accelerator