cl-lmm/src/model_library/qwen2vl/collate_fn.py

92 lines
3.8 KiB
Python

from transformers import Qwen2VLProcessor
from dataset_library.format import DatasetOutput
import torch
def collate_fn_for_train(
    examples: "list[DatasetOutput]", processor: "Qwen2VLProcessor"
):
    """Collate chat examples into a Qwen2-VL training batch with LM labels.

    Each example's ``"chat"`` is rendered (untokenized) with the processor's
    chat template, the non-empty ``"images"`` entries are gathered, and the
    whole batch is tokenized and padded in a single processor call.

    ``labels`` is a copy of ``input_ids`` in which positions that should not
    contribute to the loss are set to -100:

    * padding tokens (``processor.tokenizer.pad_token_id``);
    * Qwen2-VL's vision special tokens (``<|vision_start|>``,
      ``<|vision_end|>``, ``<|image_pad|>``, ``<|video_pad|>``). They mark
      where visual content sits in the sequence; they are not text the model
      should learn to generate.

    Args:
        examples: Dataset items. Each has a ``"chat"`` entry accepted by
            ``processor.apply_chat_template`` and an ``"images"`` list. A
            text-only item has ``[None]`` (or an empty list) there.
        processor: The Qwen2-VL processor used for templating and tokenizing.

    Returns:
        The processor's batch output with an added ``"labels"`` tensor.
    """
    texts = [
        processor.apply_chat_template(example["chat"], tokenize=False)
        for example in examples
    ]
    # Text-only examples contribute no images. An empty list is treated the
    # same as ``[None]`` instead of raising IndexError.
    images = [
        example["images"]
        for example in examples
        if example["images"] and example["images"][0] is not None
    ]
    batch = processor(
        text=texts, images=images or None, return_tensors="pt", padding=True
    )

    labels = batch["input_ids"].clone()
    tokenizer = processor.tokenizer
    ignored_ids = {
        tokenizer.convert_tokens_to_ids(token)
        for token in (
            "<|vision_start|>",
            "<|vision_end|>",
            "<|image_pad|>",
            "<|video_pad|>",
        )
    }
    ignored_ids.add(tokenizer.pad_token_id)
    for token_id in ignored_ids:
        # A tokenizer without one of these tokens may map it to None; skip it.
        if token_id is not None:
            labels[labels == token_id] = -100

    # TODO: an earlier draft also masked the "<|im_start|>system ... <|im_end|>"
    # and user turns so that only assistant replies are trained on. It is not
    # enabled, so the loss currently covers every non-padding, non-vision token.
    batch["labels"] = labels
    return batch
def collate_fn_for_evaluate(examples, processor: "Qwen2VLProcessor"):
    """Collate chat examples into a Qwen2-VL generation/evaluation batch.

    Each example's ``"chat"`` is rendered with the chat template plus the
    generation prompt, the non-empty ``"images"`` entries are gathered, and
    prompts and images are tokenized/padded in one processor call. The
    reference answers are tokenized separately, as text only.

    Args:
        examples: Dataset items with ``"chat"``, ``"images"`` (``[None]`` or
            an empty list for text-only items), ``"answer"`` and
            ``"original"`` entries.
        processor: The Qwen2-VL processor used for templating and tokenizing.

    Returns:
        The processor's prompt batch plus two extra keys:
        ``"answers_ids"`` (the padded token ids of the answers) and
        ``"original_data"`` (the list of ``example["original"]`` values).
        The answers' attention mask is not returned. Example shapes seen for
        a 3-example image batch: input_ids / attention_mask (3, 370),
        pixel_values (3888, 1176), image_grid_thw (3, 3), answers_ids (3, 10).
    """
    # NOTE(review): batched generation with a decoder-only model normally
    # needs left padding; confirm processor.tokenizer.padding_side for this
    # processor.
    texts = [
        processor.apply_chat_template(
            example["chat"], tokenize=False, add_generation_prompt=True
        )
        for example in examples
    ]
    # Text-only examples contribute no images. An empty list is treated the
    # same as ``[None]`` instead of raising IndexError.
    images = [
        example["images"]
        for example in examples
        if example["images"] and example["images"][0] is not None
    ]
    batch = processor(
        text=texts, images=images or None, return_tensors="pt", padding=True
    )
    answers = processor(
        text=[example["answer"] for example in examples],
        return_tensors="pt",
        padding=True,
    )
    batch["answers_ids"] = answers["input_ids"]
    batch["original_data"] = [example["original"] for example in examples]
    return batch