diff --git a/src/collate_fn_library/__init__.py b/src/collate_fn_library/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/collate_fn_library/qwen2.py b/src/collate_fn_library/qwen2.py
new file mode 100644
index 0000000..ea9b7db
--- /dev/null
+++ b/src/collate_fn_library/qwen2.py
@@ -0,0 +1,125 @@
+from transformers import Qwen2VLProcessor
+
+
+def collate_fn_for_train(examples, processor: Qwen2VLProcessor):
+    # Get the texts and images, and apply the chat template
+    texts = [
+        processor.apply_chat_template(example["chat"], tokenize=False)
+        for example in examples
+    ]
+    # print(texts)
+    images = [example["image"] for example in examples]
+
+    # Tokenize the texts and process the images
+    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+    # The labels are the input_ids, and we mask the padding tokens in the loss computation
+    labels = batch["input_ids"].clone()
+    labels[labels == processor.tokenizer.pad_token_id] = -100
+    # Mask the prompt tokens as well (model specific): every <|im_start|>system/user ... <|im_end|> span is set
+    # to -100; for assistant turns only the "<|im_start|>assistant\n" header is masked, so the reply stays the target
+    im_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
+    im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
+    system_token_id = processor.tokenizer.convert_tokens_to_ids("system")
+    user_token_id = processor.tokenizer.convert_tokens_to_ids("user")
+    assistant_token_id = processor.tokenizer.convert_tokens_to_ids("assistant")
+    # enter_token_id = processor.tokenizer.convert_tokens_to_ids("\n")
+    # print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
+    # 151644 151645 8948 872 77091 None 151643
+
+    for i, label in enumerate(labels):
+        now_index = 0
+        while now_index < len(label):
+            if label[now_index] == im_start_token_id:
+                label[now_index] = -100
+                now_index += 1
+                if (
+                    label[now_index] == system_token_id
+                    or label[now_index] == user_token_id
+                ):
+                    while label[now_index] != im_end_token_id:
+                        label[now_index] = -100
+                        now_index += 1
+                    label[now_index] = -100
+                elif label[now_index] == assistant_token_id:
+                    label[now_index] = -100
+                    label[now_index + 1] = -100
+                    now_index += 2
+                    while (
+                        now_index < len(label) and label[now_index] != im_end_token_id
+                    ):
+                        now_index += 1
+            now_index += 1
+    batch["labels"] = labels
+
+    return batch
+
+
+def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
+    # Get the texts and images, and apply the chat template with the generation prompt
+    texts = [
+        processor.apply_chat_template(
+            example["chat"], tokenize=False, add_generation_prompt=True
+        )
+        for example in examples
+    ]
+    # print(texts)
+    images = [example["image"] for example in examples]
+
+    # Tokenize the texts and process the images
+    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+    answers = [example["answer"] for example in examples]
+    answers = processor(text=answers, return_tensors="pt", padding=True)
+    batch["answers_ids"] = answers["input_ids"]
+    batch["answers_mask"] = answers["attention_mask"]
+    # input_ids torch.Size([3, 370])
+    # attention_mask torch.Size([3, 370])
+    # pixel_values torch.Size([3888, 1176])
+    # image_grid_thw torch.Size([3, 3])
+    # answers_ids torch.Size([3, 10])
+    # answers_mask torch.Size([3, 10])
+    return batch
+
+
+if __name__ == "__main__":
+    from transformers import Qwen2VLProcessor
+    from datasets_library.OCRVQADataset import OCRVQADatasetForGeneration
+
+    processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+    dataset = OCRVQADatasetForGeneration(
+        vis_root="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images",
+        ann_path="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json",
+        split="train",
+    )
+    examples = [dataset[i] for i in range(3)]
+    # print(collate_fn_for_evaluate(examples, processor))
+    from transformers import AutoModelForCausalLM
+
+    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
+    from accelerate import Accelerator
+    from torch.utils.data import DataLoader
+
+    val_dataloader = DataLoader(
+        dataset,
+        batch_size=3,
+        collate_fn=lambda x: collate_fn_for_evaluate(x, processor),
+    )
+    accelerator = Accelerator()
+    model = accelerator.prepare(model)
+    val_dataloader = accelerator.prepare(val_dataloader)
+
+    for batch in val_dataloader:
+        completion = model.generate(
+            input_ids=batch["input_ids"],
+            attention_mask=batch["attention_mask"],
+            pixel_values=batch["pixel_values"],
+            image_grid_thw=batch["image_grid_thw"],
+            max_length=100,
+        )
+        target = batch["answers_ids"]
+        generated_text = [out_ids[len(in_ids) :] for out_ids, in_ids in zip(completion, batch["input_ids"])]
+        generated_text = processor.tokenizer.batch_decode(generated_text)
+        target_text = processor.tokenizer.batch_decode(target)
+        print(generated_text, target_text)
+
diff --git a/src/datasets_library/OCRVQADataset.py b/src/datasets_library/OCRVQADataset.py
index 62ddfc0..627a931 100644
--- a/src/datasets_library/OCRVQADataset.py
+++ b/src/datasets_library/OCRVQADataset.py
@@ -5,6 +5,88 @@ import os
 
 
 class OCRVQADataset(Dataset):
+    def __init__(
+        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
+    ):
+        """
+        vis_root (string): Root directory of images (e.g. coco/images/)
+        ann_root (string): directory to store the annotation file
+        """
+        self.vis_root = vis_root
+
+        self.vis_processor = vis_processor
+        self.text_processor = text_processor
+        if split == "train":
+            self.data = self.create_data(ann_path, split=1)[:1]
+        elif split == "test":
+            self.data = self.create_data(ann_path, split=3)[:1]
+
+        # self.instruction_pool = [
+        #     "[vqa] {}",
+        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
+        # ]
+
+    def create_data(self, ann_path, split=1):
+        processed_data = []
+        with open(ann_path, "r") as f:
+            data = json.load(f)
+        for k in data.keys():
+            if data[k]["split"] != split:
+                continue  # 1 for training, 2 for validation, 3 for test
+            ext = os.path.splitext(data[k]["imageURL"])[1]
+            imageFile = k + ext
+            assert len(data[k]["questions"]) == len(data[k]["answers"])
+            for q, a in zip(data[k]["questions"], data[k]["answers"]):
+                if os.path.exists(os.path.join(self.vis_root, imageFile)):
+                    processed_data.append(
+                        {
+                            "question": q,
+                            "answer": a,
+                            "image_path": imageFile,
+                            "image_id": k,
+                            "title": data[k]["title"],
+                            "genre": data[k]["genre"],
+                        }
+                    )
+        return processed_data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        sample = self.data[index]
+        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
+            "RGB"
+        )
+        question = sample["question"]
+        answer = sample["answer"]
+        if self.vis_processor is not None:
+            image = self.vis_processor(image)
+        if self.text_processor is not None:
+            question = self.text_processor(question)
+            answer = self.text_processor(answer)
+
+        chat = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {
+                        "type": "text",
+                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
+                    },
+                ],
+            },
+            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
+        ]
+        return {
+            "image": image,
+            "chat": chat,
+            "image_id": sample["image_id"],
+        }
+
+
+class OCRVQADatasetForGeneration(Dataset):
     def __init__(
         self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
     ):
         """
         vis_root (string): Root directory of images (e.g. coco/images/)
         ann_root (string): directory to store the annotation file
         """
@@ -73,14 +155,15 @@ class OCRVQADataset(Dataset):
                     {"type": "image"},
                     {
                         "type": "text",
-                        "content": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
+                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                     },
                 ],
-            },
-            {"role": "assistant", "content": answer},
+            }
+            # {"role": "assistant", "content": answer},
         ]
         return {
             "image": image,
             "chat": chat,
+            "answer": answer,
             "image_id": sample["image_id"],
         }
diff --git a/src/evaluation.py b/src/evaluation.py
new file mode 100644
index 0000000..375257d
--- /dev/null
+++ b/src/evaluation.py
@@ -0,0 +1,153 @@
+import torch
+from datasets import load_dataset
+from transformers import (
+    AutoModelForVision2Seq,
+    AutoProcessor,
+    LlavaForConditionalGeneration,
+)
+
+from trl import (
+    ModelConfig,
+    SFTScriptArguments,
+    SFTConfig,
+    SFTTrainer,
+    TrlParser,
+    get_kbit_device_map,
+    get_peft_config,
+    get_quantization_config,
+)
+
+
+if __name__ == "__main__":
+    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
+    script_args, training_args, model_args = parser.parse_args_and_config()
+    script_args: SFTScriptArguments
+    training_args: SFTConfig
+    model_args: ModelConfig
+    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
+    training_args.remove_unused_columns = False
+    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
+
+    ################
+    # Model, Tokenizer & Processor
+    ################
+    torch_dtype = (
+        model_args.torch_dtype
+        if model_args.torch_dtype in ["auto", None]
+        else getattr(torch, model_args.torch_dtype)
+    )
+    quantization_config = get_quantization_config(model_args)
+    model_kwargs = dict(
+        revision=model_args.model_revision,
+        attn_implementation=model_args.attn_implementation,
+        torch_dtype=torch_dtype,
+        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        quantization_config=quantization_config,
+    )
+    processor = AutoProcessor.from_pretrained(
+        model_args.model_name_or_path,
+        trust_remote_code=model_args.trust_remote_code,
+        padding_side="left",
+    )
+
+    model = AutoModelForVision2Seq.from_pretrained(
+        model_args.model_name_or_path,
+        trust_remote_code=model_args.trust_remote_code,
+        **model_kwargs,
+    )
+
+    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
+        from collate_fn_library.qwen2 import collate_fn_for_train
+        from functools import partial
+
+        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
+
+    ################
+    # Dataset
+    ################
+    base_path = "/home/zyy/research/accelerate/dataset"
+    if script_args.dataset_name == "OCR_VQA_200K":
+        import os.path as osp
+        from datasets_library.OCRVQADataset import OCRVQADataset
+
+        dataset = {
+            "train": OCRVQADataset(
+                osp.join(base_path, "OCR-VQA-200K/images"),
+                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
+                split="train",
+            ),
+            "test": OCRVQADataset(
+                osp.join(base_path, "OCR-VQA-200K/images"),
+                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
+                split="test",
+            ),
+        }
+
+    else:
+        dataset = load_dataset(script_args.dataset_name, name=script_args.config)
+
+    ################
+    # Training
+    ################
+    trainer = SFTTrainer(
+        model=model,
+        args=training_args,
+        data_collator=collate_fn_for_train,
+        train_dataset=dataset[script_args.dataset_train_split],
+        eval_dataset=(
+            dataset[script_args.dataset_test_split]
+            if training_args.eval_strategy != "no"
+            else None
+        ),
+        tokenizer=processor.tokenizer,
+        # processing_class=processor.tokenizer,
+        peft_config=get_peft_config(model_args),
+    )
+    trainer.train()
+
+    model = trainer.model
+    # move the model onto the GPU
+    accelerator = trainer.accelerator
+    # model = accelerator.prepare(model)
+
+    from datasets_library.OCRVQADataset import OCRVQADatasetForGeneration
+    from collate_fn_library.qwen2 import collate_fn_for_evaluate
+
+    dataset = OCRVQADatasetForGeneration(
+        vis_root="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images",
+        ann_path="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json",
+        split="train",
+    )
+    examples = [dataset[i] for i in range(3)]
+    # print(collate_fn_for_evaluate(examples, processor))
+
+    from torch.utils.data import DataLoader
+
+    val_dataloader = DataLoader(
+        dataset,
+        batch_size=3,
+        collate_fn=lambda x: collate_fn_for_evaluate(x, processor),
+    )
+    val_dataloader = accelerator.prepare(val_dataloader)
+
+    import evaluate
+    rouge = evaluate.load("rouge")
+
+    for batch in val_dataloader:
+        completion = model.generate(
+            input_ids=batch["input_ids"],
+            attention_mask=batch["attention_mask"],
+            pixel_values=batch["pixel_values"],
+            image_grid_thw=batch["image_grid_thw"],
+            max_length=1000,
+        )
+        target = batch["answers_ids"]
+        generated_text = [
+            out_ids[len(in_ids) :]
+            for out_ids, in_ids in zip(completion, batch["input_ids"])
+        ]
+        generated_text = processor.tokenizer.batch_decode(generated_text, skip_special_tokens=True)
+        target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True)
+        rouge.add_batch(predictions=generated_text, references=target_text)
+
+    print(rouge.compute())
diff --git a/src/evaluation.sh b/src/evaluation.sh
new file mode 100755
index 0000000..87f8290
--- /dev/null
+++ b/src/evaluation.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluation.py \
+    --dataset_name OCR_VQA_200K \
+    --use_peft \
+    --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
+    --lora_target_modules q_proj v_proj \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 2 \
+    --gradient_accumulation_steps 8 \
+    --max_seq_length 1024 \
+    --output_dir checkpoint/sft-llava-1.5-7b-hf \
+    --bf16 \
+    --torch_dtype bfloat16
+# --eval_strategy epoch \
diff --git a/src/evaluations_library/__init__.py b/src/evaluations_library/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/run.sh b/src/run.sh
old mode 100644
new mode 100755
index 242937f..d46514d
--- a/src/run.sh
+++ b/src/run.sh
@@ -6,7 +6,11 @@ accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml sft_vlm.
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
     --lora_target_modules q_proj v_proj \
     --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 2 \
     --gradient_accumulation_steps 8 \
-    --output_dir sft-llava-1.5-7b-hf \
+    --max_seq_length 1024 \
+    --output_dir checkpoint/sft-llava-1.5-7b-hf \
     --bf16 \
-    --torch_dtype bfloat16
+    --torch_dtype bfloat16 \
+    --eval_strategy epoch \
+
diff --git a/src/sft_vlm.py b/src/sft_vlm.py
index 4810c43..ae36e00 100644
--- a/src/sft_vlm.py
+++ b/src/sft_vlm.py
@@ -56,73 +56,11 @@ if __name__ == "__main__":
         **model_kwargs,
     )
 
-    ################
-    # Create a data collator to encode text and image pairs
-    ################
-    def collate_fn_qwenv2(examples):
-        # Get the texts and images, and apply the chat template
-        texts = [
-            processor.apply_chat_template(
-                example["chat"], tokenize=False, add_generation_prompt=True
-            )
-            for example in examples
-        ]
-        # print(texts)
-        images = [example["image"] for example in examples]
-        if isinstance(model, LlavaForConditionalGeneration):
-            # LLava1.5 does not support multiple images
-            images = [image[0] for image in images]
+    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
+        from collate_fn_library.qwen2 import collate_fn_for_train
+        from functools import partial
 
-        # Tokenize the texts and process the images
-        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
-
-        # The labels are the input_ids, and we mask the padding tokens in the loss computation
-        labels = batch["input_ids"].clone()
-        labels[labels == processor.tokenizer.pad_token_id] = -100  #
-        # Ignore the image token index in the loss computation (model specific)
-        # 对<|im_start|>system *** <|im_end|>\n加掩码
-        im_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
-        im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
-        system_token_id = processor.tokenizer.convert_tokens_to_ids("system")
-        user_token_id = processor.tokenizer.convert_tokens_to_ids("user")
-        assistant_token_id = processor.tokenizer.convert_tokens_to_ids("assistant")
-        # enter_token_id = processor.tokenizer.convert_tokens_to_ids("\n")
-        # print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
-        # 151644 151645 8948 872 77091 None 151643
-
-        for i, label in enumerate(labels):
-            now_index = 0
-            while now_index < len(label):
-                if label[now_index] == im_start_token_id:
-                    label[now_index] = -100
-                    now_index += 1
-                    if (
-                        label[now_index] == system_token_id
-                        or label[now_index] == user_token_id
-                    ):
-                        while label[now_index] != im_end_token_id:
-                            label[now_index] = -100
-                            now_index += 1
-                        label[now_index] = -100
-                    elif label[now_index] == assistant_token_id:
-                        label[now_index] = -100
-                        label[now_index + 1] = -100
-                        now_index += 2
-                        while (
-                            now_index < len(label)
-                            and label[now_index] != im_end_token_id
-                        ):
-                            now_index += 1
-                now_index += 1
-        # print(labels[0], batch["input_ids"][0])
-
-        # image_token_id = processor.tokenizer.convert_tokens_to_ids(
-        #     processor.image_token
-        # )
-        # labels[labels == image_token_id] = -100
-        batch["labels"] = labels
-
-        return batch
+        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
 
     ################
     # Dataset
@@ -141,7 +79,7 @@ if __name__ == "__main__":
             "test": OCRVQADataset(
                 osp.join(base_path, "OCR-VQA-200K/images"),
                 osp.join(base_path, "OCR-VQA-200K/dataset.json"),
-                split="val",
+                split="test",
             ),
         }
 
@@ -154,7 +92,7 @@ if __name__ == "__main__":
     trainer = SFTTrainer(
         model=model,
         args=training_args,
-        data_collator=collate_fn_qwenv2,
+        data_collator=collate_fn_for_train,
         train_dataset=dataset[script_args.dataset_train_split],
         eval_dataset=(
             dataset[script_args.dataset_test_split]
@@ -168,6 +106,11 @@ if __name__ == "__main__":
 
     trainer.train()
 
+    model = trainer.model
+
+    # trainer.evaluate()
+    # run evaluation in parallel, generating completions
+
     # Save and push to hub
     trainer.save_model(training_args.output_dir)
     if training_args.push_to_hub:
diff --git a/src/todo.md b/src/todo.md
new file mode 100644
index 0000000..f99af50
--- /dev/null
+++ b/src/todo.md
@@ -0,0 +1,6 @@
+## TODO:
+
+[2024.12.31]
+
+- [ ] Train on the dataset for multiple passes
+- [ ] Clean up the evaluation code