import torch
from functools import partial

from torch.utils.data import DataLoader
from dataset_library.factory import get_dataset
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    TrainingArguments,
)
from trl import (
    TrlParser,
    get_kbit_device_map,
    # get_peft_config,
    get_quantization_config,
)
from peft_library import get_peft_model

from utils.trainer import ContinualTrainer
from utils.args import ContinualScriptArguments, ContinualModelConfig
from utils.accelerator import create_accelerator_and_postprocess
from utils.evaluate_tool import evaluate_rouge


if __name__ == "__main__":
    parser = TrlParser(
        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
    )
    script_args, training_args, model_args = parser.parse_args_and_config()

    # Never executed at runtime; these assignments only help IDEs and type
    # checkers infer the concrete types returned by parse_args_and_config().
    if 0 == 1:
        script_args = ContinualScriptArguments()
        training_args = TrainingArguments()
        model_args = ContinualModelConfig()

    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    # peft_config = get_peft_config(dict(**vars(model_args)))

    ################
    # Model & processor
    ################
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        padding_side="left",
    )

    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        # from transformers import Qwen2VLForConditionalGeneration
        from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified

        model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            **model_kwargs,
        )
        from model_library.qwen2vl import (
            collate_fn_for_train,
            collate_fn_for_evaluate,
        )

        # Bind the processor so the collators can be passed directly to the
        # trainer and the evaluation DataLoader.
        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
    else:
        # Fail fast: only the modified Qwen2-VL model and its collators are
        # wired up in this script.
        raise NotImplementedError(
            f"Unsupported model: {model_args.model_name_or_path}"
        )

    ################
    # Accelerator & PEFT
    ################
    accelerator = create_accelerator_and_postprocess(training_args)

    if model_args.peft_type == "MMOELORA":
        from peft_library.tuners import MMOELoraConfig

        peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
        model = get_peft_model(model, peft_config)
        # model = inject_adapter_in_model(peft_config, model)
    elif model_args.peft_type == "LORA":
        from peft.tuners.lora import LoraConfig

        peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
        model = get_peft_model(model, peft_config)

        if accelerator.is_local_main_process:
            model.print_trainable_parameters()
    else:
        peft_config = None

    ################
    # Continual learning loop: train, save, then evaluate on each dataset
    ################
    for dataset_name in script_args.dataset_name:
        dataset = get_dataset(dataset_name)

        model.train()
        trainer = ContinualTrainer(
            model=model,
            args=training_args,
            data_collator=collate_fn_for_train,
            train_dataset=dataset[script_args.dataset_train_split],
            eval_dataset=(
                dataset[script_args.dataset_test_split]
                if training_args.eval_strategy != "no"
                else None
            ),
            accelerator=accelerator,
        )
        trainer.train()

        if accelerator.is_local_main_process:
            print("Saving model")
        trainer.save_model(training_args.output_dir)
        if accelerator.is_local_main_process:
            print("Model saved")
        # Synchronize all processes before switching to evaluation.
        accelerator.wait_for_everyone()

        model.eval()
        # Reuse the accelerator owned by the trainer for generation/evaluation.
        accelerator = trainer.accelerator

        val_dataloader = DataLoader(
            dataset[script_args.dataset_generation_split],
            batch_size=3,
            collate_fn=collate_fn_for_evaluate,
        )
        val_dataloader = accelerator.prepare(val_dataloader)

        evaluate_rouge(model, val_dataloader, processor)
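# ------------------------------------------------------------------------------
# Example invocation (illustrative sketch only). TrlParser exposes the fields of
# the three dataclasses as CLI flags and can also read a YAML config; the script
# file name, dataset names, and flag values below are assumptions for
# illustration, not part of this repository.
#
#   accelerate launch train.py \
#       --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
#       --dataset_name <task_1> <task_2> \
#       --peft_type LORA \
#       --lora_target_modules q_proj v_proj \
#       --per_device_train_batch_size 1 \
#       --gradient_checkpointing \
#       --output_dir checkpoint/qwen2vl_lora
# ------------------------------------------------------------------------------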