# cl-lmm/src/evaluation.py

import sys
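
# Resolve imports against the bundled transformers/ and peft/ checkouts
# before any pip-installed versions.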
sys.path.insert(0, "./transformers_repo/src/")
sys.path.insert(0, "./peft_repo/src/")

import torch
from dataset_library.factory import get_dataset
from transformers import TrainingArguments
from trl import TrlParser, get_quantization_config
from utils.args import ContinualScriptArguments, ContinualModelConfig

if __name__ == "__main__":
    parser = TrlParser(
        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
    )
    script_args, training_args, model_args = parser.parse_args_and_config()
    # Dummy assignments so static type checkers see the concrete types; the
    # `0 == 1` guard never runs at runtime.
    if 0 == 1:
        script_args = ContinualScriptArguments()
        training_args = TrainingArguments()
        model_args = ContinualModelConfig()
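
    # Trainer/collator overrides: non-reentrant gradient checkpointing, keep
    # every dataset column, and skip TRL's dataset preparation (the collators
    # below do all preprocessing themselves).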
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    if model_args.model_name_or_path == "Qwen/Qwen2.5-VL-3B-Instruct":
        # Map a dtype string such as "bfloat16" to the torch dtype; "auto" and
        # None are passed through to from_pretrained unchanged.
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
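        # get_quantization_config returns a BitsAndBytesConfig when 4-/8-bit
        # loading is requested on the CLI, otherwise None.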
        quantization_config = get_quantization_config(model_args)
        model_kwargs = dict(
            attn_implementation=model_args.attn_implementation,
            torch_dtype=torch_dtype,
            quantization_config=quantization_config,
        )

        from transformers import (
            Qwen2_5_VLForConditionalGeneration,
            Qwen2_5_VLProcessor,
        )
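
        # Fine-tuned weights are loaded from output_dir, while the processor
        # comes from the base checkpoint; padding_side="left" is required for
        # batched generation with decoder-only models.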
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            training_args.output_dir,
            **model_kwargs,
        )
        processor = Qwen2_5_VLProcessor.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            padding_side="left",
        )

        from functools import partial

        from model_library.qwen2vl import (
            collate_fn_for_evaluate,
            collate_fn_for_train,
        )

        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)

    elif model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        quantization_config = get_quantization_config(model_args)
        model_kwargs = dict(
            attn_implementation=model_args.attn_implementation,
            torch_dtype=torch_dtype,
            quantization_config=quantization_config,
        )

        from transformers import (
            Qwen2VLForConditionalGeneration,
            Qwen2VLProcessor,
        )
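
        # As above: fine-tuned weights from output_dir, processor from the
        # base checkpoint with left padding for batched generation.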
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            training_args.output_dir,
            **model_kwargs,
        )
        processor = Qwen2VLProcessor.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            padding_side="left",
        )

        from functools import partial

        from model_library.qwen2vl import (
            collate_fn_for_evaluate,
            collate_fn_for_train,
        )

        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
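
    else:
        # Fail fast on unrecognized checkpoints instead of raising a NameError
        # when `model`, `processor`, or the collators are used below.
        raise ValueError(
            f"Unsupported model_name_or_path: {model_args.model_name_or_path!r}"
        )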

    ################
    # Dataset
    ################
    from utils.accelerator import create_accelerator_and_postprocess

    accelerator = create_accelerator_and_postprocess(training_args)
    accelerator.print(model)

    # Wrap the model once for distributed inference; evaluation_mode=True
    # skips optimizer/gradient bookkeeping. Preparing inside the loop would
    # re-wrap the model for every dataset.
    model = accelerator.prepare_model(model, evaluation_mode=True)
    model.eval()

    from torch.utils.data import DataLoader

    for dataset_name in script_args.dataset_name:
        dataset = get_dataset(dataset_name)

        # scienceqa is evaluated one sample at a time; everything else in
        # batches of 3.
        bs = 3 if dataset_name not in ["scienceqa"] else 1
        accelerator.print("Batch size:", bs)
        val_dataloader = DataLoader(
            dataset[script_args.dataset_generation_split],
            batch_size=bs,
            collate_fn=collate_fn_for_evaluate,
        )
        val_dataloader = accelerator.prepare_data_loader(val_dataloader)

        from utils.evaluate_tool import evalute_save

        # NB: "evalute_save" (sic) is the spelling defined in utils.evaluate_tool.
        evalute_save(model, val_dataloader, processor, accelerator)
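
    # A hypothetical launch command; the dataset name, output_dir, and dtype
    # below are illustrative, not prescribed by this repo:
    #
    #   accelerate launch src/evaluation.py \
    #       --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
    #       --dataset_name scienceqa \
    #       --output_dir checkpoint/qwen2vl/ \
    #       --torch_dtype bfloat16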