import torch
from transformers import Qwen2VLProcessor

from dataset_library.format import DatasetOutput


def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
    # Get the texts and images, and apply the chat template
    texts = [
        processor.apply_chat_template(example["chat"], tokenize=False)
        for example in examples
    ]
    # Keep only examples that actually carry images; image-less examples hold a
    # None placeholder in "images"
    images = [
        example["images"] for example in examples if example["images"][0] is not None
    ]
    # Fall back to images=None so the processor builds a text-only batch
    images = images if len(images) > 0 else None

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
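    # For Qwen2-VL the processor output typically holds input_ids,
    # attention_mask, pixel_values, and image_grid_thw (see the shape notes in
    # collate_fn_for_evaluate below)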

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100

    # Ignore the image token index in the loss computation (model specific)
    # Mask the <|im_start|>system ... <|im_end|>\n spans
    im_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
    im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
    system_token_id = processor.tokenizer.convert_tokens_to_ids("system")
    user_token_id = processor.tokenizer.convert_tokens_to_ids("user")
    assistant_token_id = processor.tokenizer.convert_tokens_to_ids("assistant")
    # enter_token_id = processor.tokenizer.convert_tokens_to_ids("\n")
    # Observed ids: im_start=151644, im_end=151645, system=8948, user=872,
    # assistant=77091, "\n"=None, pad=151643
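
    # Qwen2-VL prompts follow the ChatML layout, roughly:
    #   <|im_start|>system\n...<|im_end|>\n
    #   <|im_start|>user\n...<|im_end|>\n
    #   <|im_start|>assistant\n...<|im_end|>\n
    # The disabled loop below would mask the system/user turns (and the
    # "<|im_start|>assistant\n" header) so the loss covers only assistant text.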

    # for i, label in enumerate(labels):
    #     now_index = 0
    #     while now_index < len(label):
    #         if label[now_index] == im_start_token_id:
    #             label[now_index] = -100
    #             now_index += 1
    #             if (
    #                 label[now_index] == system_token_id
    #                 or label[now_index] == user_token_id
    #             ):
    #                 while label[now_index] != im_end_token_id:
    #                     label[now_index] = -100
    #                     now_index += 1
    #                 label[now_index] = -100
    #             elif label[now_index] == assistant_token_id:
    #                 label[now_index] = -100
    #                 label[now_index + 1] = -100
    #                 now_index += 2
    #                 while (
    #                     now_index < len(label) and label[now_index] != im_end_token_id
    #                 ):
    #                     now_index += 1
    #         now_index += 1

    batch["labels"] = labels
    # batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)

    return batch


def collate_fn_for_evaluate(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
    # Get the texts and images, and apply the chat template, adding the
    # generation prompt so the model is cued to produce the assistant answer
    texts = [
        processor.apply_chat_template(
            example["chat"], tokenize=False, add_generation_prompt=True
        )
        for example in examples
    ]
    images = [
        example["images"] for example in examples if example["images"][0] is not None
    ]
    images = images if len(images) > 0 else None

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # Tokenize the reference answers so downstream evaluation can compare them
    # against the generated ids
    answers = [example["answer"] for example in examples]
    answers = processor(text=answers, return_tensors="pt", padding=True)
    batch["answers_ids"] = answers["input_ids"]
    batch["answers_mask"] = answers["attention_mask"]
    batch["original_data"] = [example["original"] for example in examples]
    # Example shapes for a batch of 3:
    #   input_ids       torch.Size([3, 370])
    #   attention_mask  torch.Size([3, 370])
    #   pixel_values    torch.Size([3888, 1176])
    #   image_grid_thw  torch.Size([3, 3])
    #   answers_ids     torch.Size([3, 10])
    #   answers_mask    torch.Size([3, 10])
    return batch
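

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the library API): the dummy
# example layout and the checkpoint name below are assumptions; the collators
# above only require the "chat"/"images" (train) and additionally the
# "answer"/"original" (evaluate) keys.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from functools import partial

    from torch.utils.data import DataLoader

    processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

    # One text-only example in the expected layout; "images" holds a None
    # placeholder, so the collator builds a text-only batch
    dummy = [
        {
            "chat": [
                {"role": "user", "content": [{"type": "text", "text": "Hi"}]},
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": "Hello!"}],
                },
            ],
            "images": [None],
        }
    ]
    loader = DataLoader(
        dummy,
        batch_size=1,
        collate_fn=partial(collate_fn_for_train, processor=processor),
    )
    batch = next(iter(loader))
    print({k: tuple(v.shape) for k, v in batch.items()})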