141 lines
4.6 KiB
Python
141 lines
4.6 KiB
Python
from functools import partial

import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    TrainingArguments,
)
from trl import (
    TrlParser,
    get_kbit_device_map,
    # get_peft_config,
    get_quantization_config,
)

from dataset_library.factory import get_dataset
from peft_library import get_peft_model
from utils.args import ContinualScriptArguments, ContinualModelConfig
from utils.trainer import ContinualTrainer
|
|
|
|
|
|
# Continual fine-tuning entry point: parse CLI/config arguments, load the model
# and processor, optionally wrap the model with a PEFT adapter, then for each
# dataset in `script_args.dataset_name` (in order) train, save, and run a ROUGE
# evaluation on that dataset's generation split.
if __name__ == "__main__":
    parser = TrlParser(
        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
    )
    # Pre-declare the concrete types so editors and type checkers can resolve
    # attribute access on the parsed argument objects.
    script_args: ContinualScriptArguments
    training_args: TrainingArguments
    model_args: ContinualModelConfig
    script_args, training_args, model_args = parser.parse_args_and_config()

    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    # Keep every dataset column; batching is done by the custom collate
    # functions below (presumably from the raw fields — confirm).
    training_args.remove_unused_columns = False
    # NOTE(review): `dataset_kwargs` is not a TrainingArguments field; it only
    # has an effect if ContinualTrainer reads it — confirm.
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    # "auto"/None are passed through as-is; any other string names a torch dtype
    # attribute (e.g. "bfloat16" -> torch.bfloat16).
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        # A k-bit device map is only needed when the model is quantized.
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        padding_side="left",
    )

    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        from model_library.qwen2vl import (
            Qwen2VLForConditionalGeneration_modified,
            collate_fn_for_evaluate,
            collate_fn_for_train,
        )

        model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            **model_kwargs,
        )
        # Bind the processor so each collator can be called with just the batch.
        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
    else:
        # Without this, `model` and the collate functions would stay unbound and
        # the script would fail later with a confusing NameError.
        raise ValueError(
            f"Unsupported model_name_or_path {model_args.model_name_or_path!r}; "
            "only 'Qwen/Qwen2-VL-7B-Instruct' is currently supported."
        )

    from utils.accelerator import create_accelerator_and_postprocess

    accelerator = create_accelerator_and_postprocess(training_args)

    if model_args.peft_type == "MMOELORA":
        from peft_library.tuners import MMOELoraConfig

        peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
        model = get_peft_model(model, peft_config)
    elif model_args.peft_type == "LORA":
        from peft.tuners.lora import LoraConfig

        peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
        model = get_peft_model(model, peft_config)
        if accelerator.is_local_main_process:
            model.print_trainable_parameters()
    else:
        # Any other peft_type: the loaded model is trained directly, no adapter.
        peft_config = None

    # Imported once up front so a missing module fails before any training runs.
    from utils.evaluate_tool import evaluate_rouge

    for dataset_name in script_args.dataset_name:
        dataset = get_dataset(dataset_name)
        model.train()

        trainer = ContinualTrainer(
            model=model,
            args=training_args,
            data_collator=collate_fn_for_train,
            train_dataset=dataset[script_args.dataset_train_split],
            eval_dataset=(
                dataset[script_args.dataset_test_split]
                if training_args.eval_strategy != "no"
                else None
            ),
            accelerator=accelerator,
        )
        trainer.train()

        if accelerator.is_local_main_process:
            print("Saving model")
        # NOTE(review): every dataset saves into the same output_dir, so each
        # save overwrites the previous task's checkpoint — confirm intended.
        trainer.save_model(training_args.output_dir)
        if accelerator.is_local_main_process:
            print("Model saved")
        # Synchronize all processes before moving on to evaluation.
        accelerator.wait_for_everyone()

        model.eval()
        # From here on (evaluation and the next task's trainer) use the
        # trainer's accelerator.
        accelerator = trainer.accelerator

        val_dataloader = DataLoader(
            dataset[script_args.dataset_generation_split],
            batch_size=3,
            collate_fn=collate_fn_for_evaluate,
        )
        val_dataloader = accelerator.prepare(val_dataloader)
        evaluate_rouge(model, val_dataloader, processor)