feat: 更新评估脚本以支持新的持续学习模型配置,修正路径并增强训练过程中的模型保存逻辑

This commit is contained in:
YunyaoZhou 2025-01-04 00:48:31 +08:00
parent ce206d213c
commit 9a06d6a237
Signed by: shujakuin
GPG Key ID: 418C3CA28E350CCF
4 changed files with 12 additions and 4 deletions

View File

@ -9,11 +9,11 @@ from trl import (
get_quantization_config,
)
from utils.args import ContinualScriptArguments
from utils.args import ContinualScriptArguments, ContinualModelConfig
if __name__ == "__main__":
parser = TrlParser((ContinualScriptArguments, TrainingArguments, ModelConfig))
parser = TrlParser((ContinualScriptArguments, TrainingArguments, ContinualModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
# for type hint
if 0 == 1:

View File

@ -1,8 +1,9 @@
#!/bin/bash
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluation.py \
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml evaluation.py \
--dataset_name OCR_VQA_200K \
--use_peft \
--peft_type MMOELORA \
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
--lora_target_modules q_proj v_proj \
--per_device_train_batch_size 1 \

View File

@ -11,5 +11,6 @@
[2025.01.03]
- [ ] 处理量化逻辑(懒得调了)
- [ ] 处理量化逻辑
- [ ] 严查moelora的原始代码太粗糙了😡
- [ ] 未知原因trainer后处理时间长

View File

@ -114,7 +114,13 @@ if __name__ == "__main__":
)
trainer.train()
if accelerator.is_local_main_process:
print("Saving model")
trainer.save_model(training_args.output_dir)
if accelerator.is_local_main_process:
print("Model saved")
# 同步 accelerator
accelerator.wait_for_everyone()
model.eval()
accelerator = trainer.accelerator