# (stripped duplicated extraction metadata: "111 lines, 3.4 KiB, Python")
import sys

# Prepend the vendored checkouts of `transformers` and `peft` so they shadow
# any pip-installed versions.  This MUST happen before the imports below,
# which is why this file deviates from the usual "all imports at the top"
# PEP 8 layout.
sys.path.insert(0, "./transformers_repo/src/")
sys.path.insert(0, "./peft_repo/src/")

import torch

from dataset_library.factory import get_dataset
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    TrainingArguments,
    modeling_utils,
)
from trl import (
    ModelConfig,
    TrlParser,
    get_kbit_device_map,
    get_quantization_config,
)

from utils.args import ContinualScriptArguments, ContinualModelConfig
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the three argument groups (script / training / model) from the
    # command line and any referenced config file.
    parser = TrlParser(
        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
    )
    script_args, training_args, model_args = parser.parse_args_and_config()

    # Dead branch kept only so static analyzers / IDEs infer the concrete
    # types of the three parsed namespaces; it never executes at runtime.
    if 0 == 1:
        script_args = ContinualScriptArguments()
        training_args = TrainingArguments()
        model_args = ModelConfig()

    # Non-reentrant checkpointing is required when the model is wrapped by
    # Accelerate/DDP; keep all columns and skip TRL's dataset preparation
    # because the custom collators below build the model inputs themselves.
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    from model_library.factory import get_model

    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        # Resolve the requested dtype: "auto"/None pass through unchanged,
        # anything else is looked up as a torch attribute (e.g. "bfloat16").
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        quantization_config = get_quantization_config(model_args)
        model_kwargs = dict(
            attn_implementation=model_args.attn_implementation,
            torch_dtype=torch_dtype,
            quantization_config=quantization_config,
        )
        from transformers import (
            Qwen2VLProcessor,
            Qwen2VLForConditionalGeneration,
            AutoModelForVision2Seq,
            AutoModel,
        )
        from peft.peft_model import PeftModelForCausalLM
        from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified

        # NOTE(review): weights load from training_args.output_dir (i.e. a
        # previously trained checkpoint), while the processor below loads from
        # model_name_or_path — confirm this split is intentional.
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            training_args.output_dir,
            **model_kwargs,
        )

        # from peft_library import get_peft_model

        processor = Qwen2VLProcessor.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            padding_side="left",  # left padding for batched generation
        )
        from model_library.qwen2vl import (
            collate_fn_for_train,
            collate_fn_for_evaluate,
        )
        from functools import partial

        # Bind the processor once so the DataLoader only has to pass batches.
        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
        collate_fn_for_evaluate = partial(
            collate_fn_for_evaluate, processor=processor
        )

    ################
    # Dataset
    ################

    from utils.accelerator import create_accelerator_and_postprocess

    accelerator = create_accelerator_and_postprocess(training_args)

    if accelerator.is_local_main_process:
        print(model)

    # Evaluate the (single) checkpoint on every requested dataset in turn.
    for dataset_name in script_args.dataset_name:
        dataset = get_dataset(dataset_name)
        # NOTE(review): prepare_model runs once per dataset iteration, so the
        # model gets re-wrapped on every pass — confirm this is intended
        # (hoisting it above the loop looks equivalent).
        model = accelerator.prepare_model(model, evaluation_mode=True)

        model.eval()

        from torch.utils.data import DataLoader

        # batch_size is hard-coded; presumably tuned to fit GPU memory for
        # the 7B checkpoint — TODO consider making it a script argument.
        val_dataloader = DataLoader(
            dataset[script_args.dataset_generation_split],
            batch_size=3,
            collate_fn=collate_fn_for_evaluate,
        )
        val_dataloader = accelerator.prepare_data_loader(val_dataloader)
        # `evalute_save` (sic) is the project-side name; it runs generation
        # over the dataloader and persists the results.
        from utils.evaluate_tool import evaluate_rouge, evalute_save

        evalute_save(model, val_dataloader, processor, accelerator)
|