Add an install script and requirements file, rename the evaluation script, update the training script to load the model by name, delete the temporary evaluation file, and complete the overall training and evaluation framework
This commit is contained in:
parent f336496d8e
commit aef0f6834e
install.sh (new file, 3 lines)
@@ -0,0 +1,3 @@
+#!/bin/bash
+uv venv --python 3.11.7
+pip install -U torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
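A note on install.sh: `uv venv` creates the environment but does not activate it, so the subsequent `pip install` only lands in that venv if the shell activates it first (or if `uv pip install` is used instead). A minimal post-install sanity check, assuming the venv is active:

# sanity_check.py: a minimal sketch, assuming the venv created by
# install.sh has been activated so `python` resolves to it.
import torch

print(torch.__version__)          # expected: 2.5.1+cu124
print(torch.version.cuda)         # expected: 12.4
print(torch.cuda.is_available())  # True on a host with a CUDA 12.4 driver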
requirements.txt (new file, 15 lines)
@@ -0,0 +1,15 @@
+accelerate==1.2.1
+deepspeed==0.16.2
+evaluate==0.4.3
+networkx==3.2.1
+ninja==1.11.1.3
+numpy==1.26.3
+packaging==24.2
+pandas==2.2.3
+peft==0.14.0
+pillow==10.2.0
+torch==2.5.1+cu124
+torchaudio==2.5.1+cu124
+torchvision==0.20.1+cu124
+transformers==4.47.1
+trl==0.13.0
(deleted file, 156 lines; per the commit message, the temporary evaluation script)
@@ -1,156 +0,0 @@
-import torch
-from datasets_library.factory import get_dataset
-from transformers import (
-    AutoModelForVision2Seq,
-    AutoProcessor,
-)
-
-from trl import (
-    ModelConfig,
-    SFTScriptArguments,
-    SFTConfig,
-    SFTTrainer,
-    TrlParser,
-    get_kbit_device_map,
-    get_peft_config,
-    get_quantization_config,
-)
-from peft import get_peft_model
-
-from utils.trainer import ContinualTrainer
-
-
-if __name__ == "__main__":
-    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
-    script_args, training_args, model_args = parser.parse_args_and_config()
-    script_args: SFTScriptArguments = script_args
-    training_args: SFTConfig = training_args
-    model_args: ModelConfig = model_args
-    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
-    training_args.remove_unused_columns = False
-    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
-
-    torch_dtype = (
-        model_args.torch_dtype
-        if model_args.torch_dtype in ["auto", None]
-        else getattr(torch, model_args.torch_dtype)
-    )
-    quantization_config = get_quantization_config(model_args)
-    model_kwargs = dict(
-        revision=model_args.model_revision,
-        attn_implementation=model_args.attn_implementation,
-        torch_dtype=torch_dtype,
-        device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config,
-    )
-    processor = AutoProcessor.from_pretrained(
-        model_args.model_name_or_path,
-        trust_remote_code=model_args.trust_remote_code,
-        padding_side="left",
-    )
-
-    model = AutoModelForVision2Seq.from_pretrained(
-        training_args.output_dir,
-        trust_remote_code=model_args.trust_remote_code,
-        **model_kwargs,
-    )
-
-    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
-        from collatefn_library.qwen2 import (
-            collate_fn_for_train,
-            collate_fn_for_evaluate,
-        )
-        from functools import partial
-
-        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
-        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
-
-    ################
-    # Dataset
-    ################
-
-    dataset = get_dataset(script_args)
-
-    # peft_config = get_peft_config(model_args)
-    # model = get_peft_model(model, peft_config)
-    # print only on selected ranks (here rank 1)
-    if torch.distributed.get_rank() in [1]:
-        print(model)
-
-    # _________________________________________________________
-    model.train()
-    import copy
-
-    training_args_init = copy.copy(training_args)
-    training_args_init.do_train = False
-    training_args_init.do_eval = False
-    training_args_init.do_predict = False
-    training_args_init.num_train_epochs = 0
-    trainer = SFTTrainer(
-        model=model,
-        args=training_args_init,
-        data_collator=collate_fn_for_train,
-        train_dataset=dataset[script_args.dataset_train_split],
-        eval_dataset=(
-            dataset[script_args.dataset_test_split]
-            if training_args.eval_strategy != "no"
-            else None
-        ),
-    )
-    trainer.train()
-
-    model.eval()
-    accelerator = trainer.accelerator
-
-    from torch.utils.data import DataLoader
-
-    val_dataloader = DataLoader(
-        dataset["generation"],
-        batch_size=3,
-        collate_fn=collate_fn_for_evaluate,
-    )
-    val_dataloader = accelerator.prepare(val_dataloader)
-    from utils.evaluate_tool import evaluate_rouge
-
-    evaluate_rouge(model, val_dataloader, processor)
-
-    model.train()
-    trainer = ContinualTrainer(
-        model=model,
-        args=training_args,
-        data_collator=collate_fn_for_train,
-        train_dataset=dataset[script_args.dataset_train_split],
-        eval_dataset=(
-            dataset[script_args.dataset_test_split]
-            if training_args.eval_strategy != "no"
-            else None
-        ),
-        accelerator=accelerator,
-    )
-    trainer.train()
-
-    trainer.save_model(training_args.output_dir)
-
-    # clear the CUDA cache
-    torch.cuda.empty_cache()
-
-    # load_model
-    from transformers import AutoModelForVision2Seq
-
-    model = AutoModelForVision2Seq.from_pretrained(training_args.output_dir)
-    model = accelerator.prepare(model)
-
-    model.eval()
-    accelerator = trainer.accelerator
-
-    from torch.utils.data import DataLoader
-
-    val_dataloader = DataLoader(
-        dataset["generation"],
-        batch_size=3,
-        collate_fn=collate_fn_for_evaluate,
-    )
-    val_dataloader = accelerator.prepare(val_dataloader)
-    from utils.evaluate_tool import evaluate_rouge
-
-    evaluate_rouge(model, val_dataloader, processor)
src/evaluation.py (new file, 92 lines)
@@ -0,0 +1,92 @@
+import torch
+from datasets_library.factory import get_dataset
+from transformers import AutoModelForVision2Seq, AutoProcessor, TrainingArguments
+
+from trl import (
+    ModelConfig,
+    TrlParser,
+    get_kbit_device_map,
+    get_quantization_config,
+)
+
+from utils.args import ContinualScriptArguments
+
+
+if __name__ == "__main__":
+    parser = TrlParser((ContinualScriptArguments, TrainingArguments, ModelConfig))
+    script_args, training_args, model_args = parser.parse_args_and_config()
+    # for type hint
+    if 0 == 1:
+        script_args = ContinualScriptArguments()
+        training_args = TrainingArguments()
+        model_args = ModelConfig()
+    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
+    training_args.remove_unused_columns = False
+    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
+
+    torch_dtype = (
+        model_args.torch_dtype
+        if model_args.torch_dtype in ["auto", None]
+        else getattr(torch, model_args.torch_dtype)
+    )
+    quantization_config = get_quantization_config(model_args)
+    model_kwargs = dict(
+        revision=model_args.model_revision,
+        attn_implementation=model_args.attn_implementation,
+        torch_dtype=torch_dtype,
+        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        quantization_config=quantization_config,
+    )
+    processor = AutoProcessor.from_pretrained(
+        model_args.model_name_or_path,
+        trust_remote_code=model_args.trust_remote_code,
+        padding_side="left",
+    )
+
+    model = AutoModelForVision2Seq.from_pretrained(
+        training_args.output_dir,
+        trust_remote_code=model_args.trust_remote_code,
+        **model_kwargs,
+    )
+
+    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
+        from collatefn_library.qwen2 import (
+            collate_fn_for_train,
+            collate_fn_for_evaluate,
+        )
+        from functools import partial
+
+        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
+        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
+
+    # peft_config = get_peft_config(model_args)
+    # model = get_peft_model(model, peft_config)
+
+    ################
+    # Dataset
+    ################
+
+    from utils.accelerator import create_accelerator_and_postprocess
+
+    accelerator = create_accelerator_and_postprocess(training_args)
+
+    if accelerator.is_local_main_process:
+        print(model)
+
+    for dataset_name in script_args.dataset_name:
+        dataset = get_dataset(dataset_name)
+        model = accelerator.prepare_model(model, evaluation_mode=True)
+
+        model.eval()
+
+        from torch.utils.data import DataLoader
+
+        val_dataloader = DataLoader(
+            dataset[script_args.dataset_generation_split],
+            batch_size=3,
+            collate_fn=collate_fn_for_evaluate,
+        )
+        val_dataloader = accelerator.prepare_data_loader(val_dataloader)
+        from utils.evaluate_tool import evaluate_rouge
+
+        evaluate_rouge(model, val_dataloader, processor)
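`utils.evaluate_tool.evaluate_rouge` itself is not part of this diff, so its internals are an assumption. A minimal sketch of what such a helper could look like on top of the pinned `evaluate==0.4.3` package (the `answers` key and `max_new_tokens=64` are illustrative guesses, not the repo's actual code):

# Hypothetical sketch of utils/evaluate_tool.evaluate_rouge, not the repo's code.
import evaluate
import torch


def evaluate_rouge(model, val_dataloader, processor):
    metric = evaluate.load("rouge")  # requires the rouge_score package
    for batch in val_dataloader:
        # assumed: the evaluate collator keeps the reference answers in the batch
        answers = batch.pop("answers")
        with torch.no_grad():
            generated = model.generate(**batch, max_new_tokens=64)
        # with left padding, the prompt occupies the first input_ids.shape[1]
        # positions; keep only the newly generated tail
        completions = generated[:, batch["input_ids"].shape[1]:]
        predictions = processor.batch_decode(completions, skip_special_tokens=True)
        metric.add_batch(predictions=predictions, references=answers)
    # per-process scores; gathering across ranks is omitted for brevity
    print(metric.compute())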
(modified: evaluation launch script)
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluate_1.py \
+accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluation.py \
     --dataset_name OCR_VQA_200K \
     --use_peft \
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
@@ -8,7 +8,6 @@ accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluate
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 2 \
     --gradient_accumulation_steps 8 \
-    --max_seq_length 1024 \
     --output_dir checkpoint/sft-llava-1.5-7b-hf \
     --bf16 \
     --torch_dtype bfloat16
(modified: training script)
@@ -47,7 +47,7 @@ if __name__ == "__main__":
     )
 
     model = AutoModelForVision2Seq.from_pretrained(
-        training_args.output_dir,
+        model_args.model_name_or_path,
         trust_remote_code=model_args.trust_remote_code,
         **model_kwargs,
     )
 
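For context, the one-line change above means the training script now initializes from the checkpoint named on the command line rather than from its own previous output directory. A standalone equivalent, assuming the Qwen2-VL checkpoint used in the launch scripts:

# A sketch of the standalone equivalent; the real script also passes the full
# model_kwargs built above (dtype, quantization config, device_map).
from transformers import AutoModelForVision2Seq

model = AutoModelForVision2Seq.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",  # i.e. model_args.model_name_or_path
    torch_dtype="auto",
)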
(modified: training launch script)
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml train.py \
-    --dataset_name OCR_VQA_200K \
+    --dataset_name OCR_VQA_200K OCR_VQA_200K OCR_VQA_200K \
     --use_peft \
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
     --lora_target_modules q_proj v_proj \
@@ -11,4 +11,3 @@ accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml train.py
     --output_dir checkpoint/sft-llava-1.5-7b-hf \
     --bf16 \
     --torch_dtype bfloat16
-# --eval_strategy epoch \
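Passing `OCR_VQA_200K` three times presumably simulates a three-task continual-learning stream over the same dataset, and `evaluation.py` above iterates over `script_args.dataset_name` as a list. `utils/args.py` is not shown in this diff, so the following shape of `ContinualScriptArguments` is inferred from usage, not the actual definition:

# Hypothetical sketch of utils/args.py, inferred from usage in evaluation.py.
from dataclasses import dataclass, field


@dataclass
class ContinualScriptArguments:
    # list-valued so `--dataset_name A B C` yields one task per entry
    dataset_name: list[str] = field(default_factory=list)
    dataset_train_split: str = "train"
    dataset_test_split: str = "test"
    dataset_generation_split: str = "generation"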