from transformers import Qwen2VLProcessor
from dataset_library.format import DatasetOutput
import torch


def _collect_images(examples: list[DatasetOutput]):
    """Gather the per-example image lists to pass to the processor.

    An example is skipped when its ``"images"`` entry is empty/``None`` or its
    first element is ``None``. Returns ``None`` (rather than an empty list)
    when no example in the batch has images, so the processor is called
    without images.
    """
    images = [
        example["images"]
        for example in examples
        # Truthiness check first: an empty list or None would otherwise make
        # the [0] lookup raise.
        if example["images"] and example["images"][0] is not None
    ]
    return images or None


def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
    """Build a training batch (model inputs plus ``labels``) from chat examples.

    Each example's ``"chat"`` is rendered with the processor's chat template
    (no generation prompt) and tokenized together with the batch's images,
    padded to a common length.

    Args:
        examples: Dataset items providing ``"chat"`` and ``"images"``.
        processor: Processor used for chat templating, tokenization and image
            preprocessing.

    Returns:
        The processor output with an added ``labels`` tensor.
    """
    texts = [
        processor.apply_chat_template(example["chat"], tokenize=False)
        for example in examples
    ]
    images = _collect_images(examples)

    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # Labels are a copy of input_ids with padding positions set to -100 so the
    # loss ignores them. NOTE: padding is the ONLY thing masked here; system /
    # user prompt tokens and any image placeholder tokens in input_ids are left
    # in the labels and therefore contribute to the loss.
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels
    return batch


def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
    """Build an evaluation batch: generation prompts plus tokenized answers.

    Each example's ``"chat"`` is rendered with ``add_generation_prompt=True``
    and tokenized together with the batch's images. The reference answers are
    tokenized separately, and the raw ``"original"`` records are passed through.

    Example shapes seen for a 3-sample batch::

        input_ids       (3, 370)
        attention_mask  (3, 370)
        pixel_values    (3888, 1176)
        image_grid_thw  (3, 3)
        answers_ids     (3, 10)

    Args:
        examples: Dataset items providing ``"chat"``, ``"images"``,
            ``"answer"`` and ``"original"``.
        processor: Processor used for chat templating, tokenization and image
            preprocessing.

    Returns:
        The processor output for the prompts, extended with ``answers_ids``
        (padded answer token ids) and ``original_data`` (list of the examples'
        ``"original"`` values).
    """
    texts = [
        processor.apply_chat_template(
            example["chat"], tokenize=False, add_generation_prompt=True
        )
        for example in examples
    ]
    images = _collect_images(examples)

    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # Tokenize the reference answers on their own; only the ids are kept
    # (their attention mask is not added to the batch).
    answers = processor(
        text=[example["answer"] for example in examples],
        return_tensors="pt",
        padding=True,
    )
    batch["answers_ids"] = answers["input_ids"]
    batch["original_data"] = [example["original"] for example in examples]
    return batch