From 9a06d6a237b35cea167fd8c686d70df94d527fd5 Mon Sep 17 00:00:00 2001 From: YunyaoZhou Date: Sat, 4 Jan 2025 00:48:31 +0800 Subject: [PATCH] =?UTF-8?q?feat=E2=9C=A8:=20=E6=9B=B4=E6=96=B0=E8=AF=84?= =?UTF-8?q?=E4=BC=B0=E8=84=9A=E6=9C=AC=E4=BB=A5=E6=94=AF=E6=8C=81=E6=96=B0?= =?UTF-8?q?=E7=9A=84=E6=8C=81=E7=BB=AD=E5=AD=A6=E4=B9=A0=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=EF=BC=8C=E4=BF=AE=E6=AD=A3=E8=B7=AF=E5=BE=84?= =?UTF-8?q?=E5=B9=B6=E5=A2=9E=E5=BC=BA=E8=AE=AD=E7=BB=83=E8=BF=87=E7=A8=8B?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E6=A8=A1=E5=9E=8B=E4=BF=9D=E5=AD=98=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/evaluation.py | 4 ++-- src/evaluation.sh | 3 ++- src/todo.md | 3 ++- src/train.py | 6 ++++++ 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/evaluation.py b/src/evaluation.py index d28ea68..be5b01e 100644 --- a/src/evaluation.py +++ b/src/evaluation.py @@ -9,11 +9,11 @@ from trl import ( get_quantization_config, ) -from utils.args import ContinualScriptArguments +from utils.args import ContinualScriptArguments, ContinualModelConfig if __name__ == "__main__": - parser = TrlParser((ContinualScriptArguments, TrainingArguments, ModelConfig)) + parser = TrlParser((ContinualScriptArguments, TrainingArguments, ContinualModelConfig)) script_args, training_args, model_args = parser.parse_args_and_config() # for type hint if 0 == 1: diff --git a/src/evaluation.sh b/src/evaluation.sh index e50bf72..a440e82 100755 --- a/src/evaluation.sh +++ b/src/evaluation.sh @@ -1,8 +1,9 @@ #!/bin/bash -accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluation.py \ +accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml evaluation.py \ --dataset_name OCR_VQA_200K \ --use_peft \ + --peft_type MMOELORA \ --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ --lora_target_modules q_proj v_proj \ --per_device_train_batch_size 1 \ diff --git a/src/todo.md b/src/todo.md index 4538288..8e7bf4a 100644 --- a/src/todo.md +++ b/src/todo.md @@ -11,5 +11,6 @@ [2025.01.03] -- [ ] 处理量化逻辑(懒得调了) +- [ ] 处理量化逻辑 - [ ] 严查moelora的原始代码,太粗糙了😡 +- [ ] 未知原因trainer后处理时间长 diff --git a/src/train.py b/src/train.py index d9a276c..f2d3c2f 100644 --- a/src/train.py +++ b/src/train.py @@ -114,7 +114,13 @@ if __name__ == "__main__": ) trainer.train() + if accelerator.is_local_main_process: + print("Saving model") trainer.save_model(training_args.output_dir) + if accelerator.is_local_main_process: + print("Model saved") + # 同步 accelerator + accelerator.wait_for_everyone() model.eval() accelerator = trainer.accelerator