# cl-lmm/src/train.py


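"""Continual-learning training entry point.

Parses TRL-style arguments, loads an (optionally quantized) vision-language
backbone and its processor, attaches a PEFT adapter, and then trains and
evaluates the model on a sequence of datasets with ContinualTrainer.
"""
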
import torch
from dataset_library.factory import get_dataset
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    TrainingArguments,
)
from trl import (
    TrlParser,
    get_kbit_device_map,
    # get_peft_config,
    get_quantization_config,
)
from peft_library import get_peft_model
from utils.trainer import ContinualTrainer
from utils.args import ContinualScriptArguments, ContinualModelConfig


if __name__ == "__main__":
    parser = TrlParser(
        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
    )
    script_args, training_args, model_args = parser.parse_args_and_config()

    # Dummy assignments for type hints only; this branch never executes.
    if 0 == 1:
        script_args = ContinualScriptArguments()
        training_args = TrainingArguments()
        model_args = ContinualModelConfig()

    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
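
    ################
    # Model
    ################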
    # peft_config = get_peft_config(dict(**vars(model_args)))
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        padding_side="left",
    )
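
    # Model loading: only Qwen/Qwen2-VL-7B-Instruct is wired up here, using a
    # patched model class and matching collate functions from model_library.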
    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        # from transformers import Qwen2VLForConditionalGeneration
        from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified

        model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            **model_kwargs,
        )

        from model_library.qwen2vl import (
            collate_fn_for_train,
            collate_fn_for_evaluate,
        )
        from functools import partial

        # Bind the shared processor so the collators can be passed directly to
        # the trainer and dataloaders.
        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
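    else:
        # Only the Qwen2-VL branch above defines `model` and the collators; fail
        # fast for any other checkpoint instead of raising a NameError later.
        raise NotImplementedError(
            f"Model {model_args.model_name_or_path} is not supported by this script."
        )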

    ################
    # Dataset
    ################
    from utils.accelerator import create_accelerator_and_postprocess

    accelerator = create_accelerator_and_postprocess(training_args)
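
    # Wrap the base model with the requested parameter-efficient adapter
    # (MMOE-LoRA from peft_library, or standard LoRA from peft).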
    if model_args.peft_type == "MMOELORA":
        from peft_library.tuners import MMOELoraConfig

        peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
        model = get_peft_model(model, peft_config)
        # model = inject_adapter_in_model(peft_config, model)
    elif model_args.peft_type == "LORA":
        from peft.tuners.lora import LoraConfig

        peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
        model = get_peft_model(model, peft_config)

        if accelerator.is_local_main_process:
            model.print_trainable_parameters()
    else:
        peft_config = None
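
    # Continual-learning loop: train on each dataset (task) in sequence, then
    # run generation-based evaluation on that task's generation split.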
    for dataset_name in script_args.dataset_name:
        dataset = get_dataset(dataset_name)
        model.train()
        trainer = ContinualTrainer(
            model=model,
            args=training_args,
            data_collator=collate_fn_for_train,
            train_dataset=dataset[script_args.dataset_train_split],
            eval_dataset=(
                dataset[script_args.dataset_test_split]
                if training_args.eval_strategy != "no"
                else None
            ),
            accelerator=accelerator,
        )
        trainer.train()

        if accelerator.is_local_main_process:
            print("Saving model")
        trainer.save_model(training_args.output_dir)
        if accelerator.is_local_main_process:
            print("Model saved")

        # synchronize all processes before switching to evaluation
        accelerator.wait_for_everyone()

        model.eval()
        accelerator = trainer.accelerator

        from torch.utils.data import DataLoader

        val_dataloader = DataLoader(
            dataset[script_args.dataset_generation_split],
            batch_size=3,
            collate_fn=collate_fn_for_evaluate,
        )
        val_dataloader = accelerator.prepare(val_dataloader)

        from utils.evaluate_tool import evaluate_rouge

        evaluate_rouge(model, val_dataloader, processor)
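
# Example launch (illustrative only; dataset names are placeholders, and the exact
# CLI flags are generated from the ContinualScriptArguments / TrainingArguments /
# ContinualModelConfig dataclasses that TrlParser parses):
#   accelerate launch src/train.py \
#       --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
#       --dataset_name <task_1> <task_2> \
#       --output_dir checkpoints/qwen2vl_continual \
#       --peft_type LORA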