feat✨: 更新评估脚本以支持新的持续学习模型配置,修正路径并增强训练过程中的模型保存逻辑
This commit is contained in:
parent
ce206d213c
commit
9a06d6a237
@ -9,11 +9,11 @@ from trl import (
|
|||||||
get_quantization_config,
|
get_quantization_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
from utils.args import ContinualScriptArguments
|
from utils.args import ContinualScriptArguments, ContinualModelConfig
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = TrlParser((ContinualScriptArguments, TrainingArguments, ModelConfig))
|
parser = TrlParser((ContinualScriptArguments, TrainingArguments, ContinualModelConfig))
|
||||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||||
# for type hint
|
# for type hint
|
||||||
if 0 == 1:
|
if 0 == 1:
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluation.py \
|
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml evaluation.py \
|
||||||
--dataset_name OCR_VQA_200K \
|
--dataset_name OCR_VQA_200K \
|
||||||
--use_peft \
|
--use_peft \
|
||||||
|
--peft_type MMOELORA \
|
||||||
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
|
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
|
||||||
--lora_target_modules q_proj v_proj \
|
--lora_target_modules q_proj v_proj \
|
||||||
--per_device_train_batch_size 1 \
|
--per_device_train_batch_size 1 \
|
||||||
|
@ -11,5 +11,6 @@
|
|||||||
|
|
||||||
[2025.01.03]
|
[2025.01.03]
|
||||||
|
|
||||||
- [ ] 处理量化逻辑(懒得调了)
|
- [ ] 处理量化逻辑
|
||||||
- [ ] 严查moelora的原始代码,太粗糙了😡
|
- [ ] 严查moelora的原始代码,太粗糙了😡
|
||||||
|
- [ ] 未知原因trainer后处理时间长
|
||||||
|
@ -114,7 +114,13 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
if accelerator.is_local_main_process:
|
||||||
|
print("Saving model")
|
||||||
trainer.save_model(training_args.output_dir)
|
trainer.save_model(training_args.output_dir)
|
||||||
|
if accelerator.is_local_main_process:
|
||||||
|
print("Model saved")
|
||||||
|
# 同步 accelerator
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
accelerator = trainer.accelerator
|
accelerator = trainer.accelerator
|
||||||
|
Loading…
Reference in New Issue
Block a user