feat: 更新评估脚本以支持新的持续学习模型配置,修正路径并增强训练过程中的模型保存逻辑

This commit is contained in:
YunyaoZhou 2025-01-04 00:48:31 +08:00
parent ce206d213c
commit 9a06d6a237
Signed by: shujakuin
GPG Key ID: 418C3CA28E350CCF
4 changed files with 12 additions and 4 deletions

View File

@@ -9,11 +9,11 @@ from trl import (
get_quantization_config, get_quantization_config,
) )
from utils.args import ContinualScriptArguments from utils.args import ContinualScriptArguments, ContinualModelConfig
if __name__ == "__main__": if __name__ == "__main__":
parser = TrlParser((ContinualScriptArguments, TrainingArguments, ModelConfig)) parser = TrlParser((ContinualScriptArguments, TrainingArguments, ContinualModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config() script_args, training_args, model_args = parser.parse_args_and_config()
# for type hint # for type hint
if 0 == 1: if 0 == 1:

View File

@@ -1,8 +1,9 @@
#!/bin/bash #!/bin/bash
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluation.py \ accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml evaluation.py \
--dataset_name OCR_VQA_200K \ --dataset_name OCR_VQA_200K \
--use_peft \ --use_peft \
--peft_type MMOELORA \
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
--lora_target_modules q_proj v_proj \ --lora_target_modules q_proj v_proj \
--per_device_train_batch_size 1 \ --per_device_train_batch_size 1 \

View File

@@ -11,5 +11,6 @@
[2025.01.03] [2025.01.03]
- [ ] 处理量化逻辑(懒得调了) - [ ] 处理量化逻辑
- [ ] 严查moelora的原始代码太粗糙了😡 - [ ] 严查moelora的原始代码太粗糙了😡
- [ ] 未知原因trainer后处理时间长

View File

@@ -114,7 +114,13 @@ if __name__ == "__main__":
) )
trainer.train() trainer.train()
if accelerator.is_local_main_process:
print("Saving model")
trainer.save_model(training_args.output_dir) trainer.save_model(training_args.output_dir)
if accelerator.is_local_main_process:
print("Model saved")
# 同步 accelerator
accelerator.wait_for_everyone()
model.eval() model.eval()
accelerator = trainer.accelerator accelerator = trainer.accelerator