85 lines
3.6 KiB
Python
85 lines
3.6 KiB
Python
from transformers import Qwen2VLProcessor
|
|
import torch
|
|
|
|
|
|
def _loss_mask_for_chat(
    ids, im_start_id, im_end_id, assistant_id, prompt_role_ids, newline_id=None
):
    """Return one flag per token: True where the token must be excluded from the loss.

    ``ids`` is one tokenized ChatML sequence (a list of ints). Rules applied:

    * every ``<|im_start|>`` token is masked;
    * for a turn whose role token is in ``prompt_role_ids`` (system/user), the
      role token, the turn's content and the closing ``<|im_end|>`` are masked,
      plus the token right after ``<|im_end|>`` when it equals ``newline_id``;
    * for an assistant turn, the role token and the single token after it (the
      newline ending the header) are masked, while the reply and its closing
      ``<|im_end|>`` stay as training targets;
    * a turn with any other role only has its ``<|im_start|>`` masked.

    A turn that runs off the end of ``ids`` (e.g. truncation) is handled by
    stopping at the end instead of raising ``IndexError``.
    """
    n = len(ids)
    masked = [False] * n
    pos = 0
    while pos < n:
        if ids[pos] != im_start_id:
            pos += 1
            continue
        masked[pos] = True  # <|im_start|> is never a target
        pos += 1
        if pos >= n:
            break
        role = ids[pos]
        if role in prompt_role_ids:
            # Mask "<role> ... <|im_end|>" for prompt-side turns.
            while pos < n and ids[pos] != im_end_id:
                masked[pos] = True
                pos += 1
            if pos < n:
                masked[pos] = True  # closing <|im_end|>
                pos += 1
                # The template emits "<|im_end|>\n"; mask that newline too.
                if newline_id is not None and pos < n and ids[pos] == newline_id:
                    masked[pos] = True
                    pos += 1
        elif role == assistant_id:
            masked[pos] = True  # the "assistant" role token
            if pos + 1 < n:
                masked[pos + 1] = True  # token after the role name (header newline)
            pos += 2
            # Reply tokens and the closing <|im_end|> remain targets, so the
            # model learns both the answer and when to stop.
            while pos < n and ids[pos] != im_end_id:
                pos += 1
            pos += 1  # step past the (unmasked) closing <|im_end|>
        # Any other role: leave the rest of the turn untouched.
    return masked


def collate_fn_for_train(examples, processor: Qwen2VLProcessor):
    """Collate chat examples into a Qwen2-VL training batch with assistant-only labels.

    Args:
        examples: sequence of dicts, each providing ``"chat"`` (the message list
            passed to ``processor.apply_chat_template``) and ``"image"`` (passed
            to the processor as that example's image input).
        processor: the Qwen2-VL processor used for templating, tokenization and
            image preprocessing.

    Returns:
        The processor's output batch with an added ``"labels"`` tensor: a copy
        of ``input_ids`` where padding and all non-assistant tokens (system/user
        turns, turn headers) are set to -100 so the loss ignores them.
    """
    texts = [
        processor.apply_chat_template(example["chat"], tokenize=False)
        for example in examples
    ]
    images = [example["image"] for example in examples]
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    tokenizer = processor.tokenizer
    labels = batch["input_ids"].clone()
    labels[labels == tokenizer.pad_token_id] = -100

    im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    system_id = tokenizer.convert_tokens_to_ids("system")
    user_id = tokenizer.convert_tokens_to_ids("user")
    assistant_id = tokenizer.convert_tokens_to_ids("assistant")
    # "\n" is not a standalone vocabulary entry (convert_tokens_to_ids returns
    # None for it), so obtain its id by encoding; skip it if not a single token.
    newline_ids = tokenizer.encode("\n", add_special_tokens=False)
    newline_id = newline_ids[0] if len(newline_ids) == 1 else None
    prompt_role_ids = {system_id, user_id}

    # Image placeholder tokens get no separate handling: they are masked only
    # because they sit inside user turns.
    # NOTE(review): confirm the chat template never puts images elsewhere.
    for row, ids in enumerate(batch["input_ids"].tolist()):
        row_mask = _loss_mask_for_chat(
            ids, im_start_id, im_end_id, assistant_id, prompt_role_ids, newline_id
        )
        labels[row, torch.tensor(row_mask, dtype=torch.bool)] = -100

    batch["labels"] = labels
    return batch
|
|
|
|
|
|
def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
    """Build an evaluation batch of generation prompts plus tokenized reference answers.

    Each example must provide ``"chat"`` (message list for the chat template),
    ``"image"`` (image input for the processor), ``"answer"`` (reference text)
    and ``"original"`` (arbitrary record passed through untouched).

    The prompts are rendered with ``add_generation_prompt=True``, so each one
    ends with an open assistant header for the model to continue.

    Returns the processor batch extended with:
        ``answers_ids`` / ``answers_mask``: the padded ``input_ids`` /
            ``attention_mask`` of the reference answers;
        ``original_data``: list of each example's ``"original"`` value.

    Example shapes previously observed for a batch of 3: input_ids and
    attention_mask (3, 370), pixel_values (3888, 1176), image_grid_thw (3, 3),
    answers_ids and answers_mask (3, 10).
    """
    prompts = []
    for sample in examples:
        prompts.append(
            processor.apply_chat_template(
                sample["chat"], tokenize=False, add_generation_prompt=True
            )
        )
    pictures = [sample["image"] for sample in examples]

    batch = processor(text=prompts, images=pictures, return_tensors="pt", padding=True)

    reference_texts = [sample["answer"] for sample in examples]
    encoded_refs = processor(text=reference_texts, return_tensors="pt", padding=True)
    # Copy the reference encodings into the batch under evaluation-specific names.
    for target_key, source_key in (
        ("answers_ids", "input_ids"),
        ("answers_mask", "attention_mask"),
    ):
        batch[target_key] = encoded_refs[source_key]

    batch["original_data"] = [sample["original"] for sample in examples]
    return batch
|