From b766c21c9be09f98f94c8c0a8818fd33b9ac2083 Mon Sep 17 00:00:00 2001
From: YunyaoZhou
Date: Fri, 10 Jan 2025 00:18:50 +0800
Subject: [PATCH] feat✨: Add CHEM dataset support and optimize image processing logic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 requirements.txt                     |   1 +
 src/dataset_library/CHEM.py          | 171 +++++++++++++++++++++++++
 src/dataset_library/OCRVQADataset.py |  72 +++--------
 src/dataset_library/chem.py          | 181 ---------------------------
 src/dataset_library/factory.py       |  21 ++++
 src/evaluation.py                    |   7 +-
 src/train.sh                         |   2 +-
 7 files changed, 213 insertions(+), 242 deletions(-)
 create mode 100644 src/dataset_library/CHEM.py
 delete mode 100644 src/dataset_library/chem.py

diff --git a/requirements.txt b/requirements.txt
index aea51d2..f3a9cf3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ torchaudio==2.5.1+cu124
 torchvision==0.20.1+cu124
 transformers==4.46.1
 trl==0.13.0
+pillow==9.5.0
diff --git a/src/dataset_library/CHEM.py b/src/dataset_library/CHEM.py
new file mode 100644
index 0000000..732165f
--- /dev/null
+++ b/src/dataset_library/CHEM.py
@@ -0,0 +1,171 @@
+from PIL import Image
+from torch.utils.data import Dataset
+import json
+import os
+
+
+class CHEMDataset(Dataset):
+    def __init__(
+        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
+    ):
+        """
+        vis_root (string): Root directory of images (e.g. coco/images/)
+        ann_root (string): directory to store the annotation file
+        """
+        self.vis_root = vis_root
+
+        self.vis_processor = (
+            vis_processor if vis_processor is not None else self._vis_processor
+        )
+        self.text_processor = text_processor
+        if split == "train":
+            self.data = self.create_data(ann_path, split="train")
+        elif split == "test":
+            self.data = self.create_data(ann_path, split="test")
+
+    def _vis_processor(self, image: Image.Image):
+        width, height = image.size
+        if width > 800 or height > 800:
+            max_size = max(width, height)
+            ratio = 800 / max_size
+            new_width = int(width * ratio)
+            new_height = int(height * ratio)
+            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
+
+        if width < 28 or height < 28:
+            min_size = min(width, height)
+            ratio = 28 / min_size + 1
+            new_width = int(width * ratio)
+            new_height = int(height * ratio)
+            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
+
+        return image
+
+    def create_data(self, ann_path, split=1):
+        import os.path as osp
+
+        if split == "train":
+            json_path = osp.join(ann_path, "conversations_loc_train.jsonl")
+        elif split == "test":
+            json_path = osp.join(ann_path, "conversations_loc_val.jsonl")
+        processed_data = []
+        with open(json_path, "r") as f:
+            for line in f:
+                data = json.loads(line)
+                image_file = osp.join(self.vis_root, data["images"][0].split("/")[-1])
+                conversation = data["messages"]
+                system = conversation[0]["content"]
+                query = conversation[1]["content"]
+                answer = conversation[2]["content"]
+                processed_data.append(
+                    {
+                        "question": query,
+                        "answer": answer,
+                        "image_path": image_file,
+                        "system": system,
+                    }
+                )
+        return processed_data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        sample = self.data[index]
+        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
+            "RGB"
+        )
+        question = sample["question"]
+        answer = sample["answer"]
+        if self.vis_processor is not None:
+            image = self.vis_processor(image)
+        if self.text_processor is not None:
+            question = self.text_processor(question)
+            answer = self.text_processor(answer)
+
+        chat = [
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": sample["system"],
+                    },
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {
+                        "type": "text",
+                        "text": f"[vqa] {question.replace('','')}",
+                    },
+                ],
+            },
+            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
+        ]
+        return {
+            "image": image,
+            "chat": chat,
+        }
+
+
+class CHEMDatasetForGeneration(CHEMDataset):
+    def __getitem__(self, index):
+        sample = self.data[index]
+        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
+            "RGB"
+        )
+        question = sample["question"]
+        answer = sample["answer"]
+        if self.vis_processor is not None:
+            image = self.vis_processor(image)
+        if self.text_processor is not None:
+            question = self.text_processor(question)
+            answer = self.text_processor(answer)
+
+        chat = [
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": sample["system"],
+                    },
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {
+                        "type": "text",
+                        "text": f"[vqa] {question.replace('','')}",
+                    },
+                ],
+            },
+        ]
+        return {
+            "image": image,
+            "chat": chat,
+            "answer": answer,
+        }
+
+
+if __name__ == "__main__":
+    dataset = CHEMDataset(
+        "/home/zyy/research/accelerate/dataset/chem/images",
+        "/home/zyy/research/accelerate/dataset/chem/qwen_data",
+        split="train",
+    )
+    print(len(dataset))
+    print(dataset[0])
+    dataset = CHEMDatasetForGeneration(
+        "/home/zyy/research/accelerate/dataset/chem/images",
+        "/home/zyy/research/accelerate/dataset/chem/qwen_data",
+        split="train",
+    )
+    print(len(dataset))
+    print(dataset[0])
+    pass
diff --git a/src/dataset_library/OCRVQADataset.py b/src/dataset_library/OCRVQADataset.py
index 58043c9..dedb432 100644
--- a/src/dataset_library/OCRVQADataset.py
+++ b/src/dataset_library/OCRVQADataset.py
@@ -14,7 +14,9 @@ class OCRVQADataset(Dataset):
         """
         self.vis_root = vis_root
 
-        self.vis_processor = vis_processor
+        self.vis_processor = (
+            vis_processor if vis_processor is not None else self._vis_processor
+        )
         self.text_processor = text_processor
         if split == "train":
             self.data = self.create_data(ann_path, split=1)
@@ -53,12 +55,7 @@ class OCRVQADataset(Dataset):
     def __len__(self):
         return len(self.data)
 
-    def __getitem__(self, index):
-        sample = self.data[index]
-        image: Image.Image = Image.open(
-            os.path.join(self.vis_root, sample["image_path"])
-        ).convert("RGB")
-        # resize image
+    def _vis_processor(self, image: Image.Image):
         width, height = image.size
         if width > 500 or height > 500:
             max_size = max(width, height)
@@ -66,7 +63,7 @@ class OCRVQADataset(Dataset):
             new_width = int(width * ratio)
             new_height = int(height * ratio)
             image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
-
+
         if width < 28 or height < 28:
             min_size = min(width, height)
             ratio = 28 / min_size + 1
@@ -74,6 +71,15 @@ class OCRVQADataset(Dataset):
             new_height = int(height * ratio)
             image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
 
+        return image
+
+    def __getitem__(self, index):
+        sample = self.data[index]
+        image: Image.Image = Image.open(
+            os.path.join(self.vis_root, sample["image_path"])
+        ).convert("RGB")
+        # resize image
+
         question = sample["question"]
         answer = sample["answer"]
         if self.vis_processor is not None:
@@ -98,64 +104,17 @@ class OCRVQADataset(Dataset):
         return {
             "image": image,
             "chat": chat,
-            "image_id": sample["image_id"],
         }
 
 
 class OCRVQADatasetForGeneration(Dataset):
-    def __init__(
-        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
-    ):
-        """
-        vis_root (string): Root directory of images (e.g. coco/images/)
-        ann_root (string): directory to store the annotation file
-        """
-        self.vis_root = vis_root
-
-        self.vis_processor = vis_processor
-        self.text_processor = text_processor
-        if split == "train":
-            self.data = self.create_data(ann_path, split=1)[:200]
-        elif split == "test":
-            self.data = self.create_data(ann_path, split=3)[:200]
-
-        # self.instruction_pool = [
-        #     "[vqa] {}",
-        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
-        # ]
-
-    def create_data(self, ann_path, split=1):
-        processed_data = []
-        with open(ann_path, "r") as f:
-            data = json.load(f)
-            for k in data.keys():
-                if data[k]["split"] != split:
-                    continue  # 1 for training, 2 for validation, 3 for test
-                ext = os.path.splitext(data[k]["imageURL"])[1]
-                imageFile = k + ext
-                assert len(data[k]["questions"]) == len(data[k]["answers"])
-                for q, a in zip(data[k]["questions"], data[k]["answers"]):
-                    if os.path.exists(os.path.join(self.vis_root, imageFile)):
-                        processed_data.append(
-                            {
-                                "question": q,
-                                "answer": a,
-                                "image_path": imageFile,
-                                "image_id": k,
-                                "title": data[k]["title"],
-                                "genre": data[k]["genre"],
-                            }
-                        )
-        return processed_data
-
-    def __len__(self):
-        return len(self.data)
 
     def __getitem__(self, index):
         sample = self.data[index]
         image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
             "RGB"
         )
+        # resize image
         question = sample["question"]
         answer = sample["answer"]
         if self.vis_processor is not None:
@@ -181,5 +140,4 @@ class OCRVQADatasetForGeneration(Dataset):
             "image": image,
             "chat": chat,
             "answer": answer,
-            "image_id": sample["image_id"],
         }
diff --git a/src/dataset_library/chem.py b/src/dataset_library/chem.py
deleted file mode 100644
index a1999cb..0000000
--- a/src/dataset_library/chem.py
+++ /dev/null
@@ -1,181 +0,0 @@
-from PIL import Image
-from torch.utils.data import Dataset
-import json
-import os
-
-
-class ChemDataseet(Dataset):
-    def __init__(
-        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
-    ):
-        """
-        vis_root (string): Root directory of images (e.g. coco/images/)
-        ann_root (string): directory to store the annotation file
-        """
-        self.vis_root = vis_root
-
-        self.vis_processor = vis_processor
-        self.text_processor = text_processor
-        if split == "train":
-            self.data = self.create_data(ann_path, split=1)[:200]
-        elif split == "test":
-            self.data = self.create_data(ann_path, split=3)[:200]
-
-    def create_data(self, ann_path, split=1):
-        processed_data = []
-        with open(ann_path, "r") as f:
-            data = json.load(f)
-            for k in data.keys():
-                if data[k]["split"] != split:
-                    continue  # 1 for training, 2 for validation, 3 for test
-                ext = os.path.splitext(data[k]["imageURL"])[1]
-                imageFile = k + ext
-                assert len(data[k]["questions"]) == len(data[k]["answers"])
-                for q, a in zip(data[k]["questions"], data[k]["answers"]):
-                    if os.path.exists(os.path.join(self.vis_root, imageFile)):
-                        processed_data.append(
-                            {
-                                "question": q,
-                                "answer": a,
-                                "image_path": imageFile,
-                                "image_id": k,
-                                "title": data[k]["title"],
-                                "genre": data[k]["genre"],
-                            }
-                        )
-        return processed_data
-
-    def __len__(self):
-        return len(self.data)
-
-    def __getitem__(self, index):
-        sample = self.data[index]
-        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
-            "RGB"
-        )
-        question = sample["question"]
-        answer = sample["answer"]
-        if self.vis_processor is not None:
-            image = self.vis_processor(image)
-        if self.text_processor is not None:
-            question = self.text_processor(question)
-            answer = self.text_processor(answer)
-
-        chat = [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "image"},
-                    {
-                        "type": "text",
-                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
-                    },
-                ],
-            },
-            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
-        ]
-        return {
-            "image": image,
-            "chat": chat,
-            "image_id": sample["image_id"],
-        }
-
-
-class OCRVQADatasetForGeneration(Dataset):
-    def __init__(
-        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
-    ):
-        """
-        vis_root (string): Root directory of images (e.g. coco/images/)
-        ann_root (string): directory to store the annotation file
-        """
-        self.vis_root = vis_root
-
-        self.vis_processor = vis_processor
-        self.text_processor = text_processor
-        if split == "train":
-            self.data = self.create_data(ann_path, split=1)[:200]
-        elif split == "test":
-            self.data = self.create_data(ann_path, split=3)[:200]
-
-        # self.instruction_pool = [
-        #     "[vqa] {}",
-        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
-        # ]
-
-    def create_data(self, ann_path, split=1):
-        processed_data = []
-        with open(ann_path, "r") as f:
-            data = json.load(f)
-            for k in data.keys():
-                if data[k]["split"] != split:
-                    continue  # 1 for training, 2 for validation, 3 for test
-                ext = os.path.splitext(data[k]["imageURL"])[1]
-                imageFile = k + ext
-                assert len(data[k]["questions"]) == len(data[k]["answers"])
-                for q, a in zip(data[k]["questions"], data[k]["answers"]):
-                    if os.path.exists(os.path.join(self.vis_root, imageFile)):
-                        processed_data.append(
-                            {
-                                "question": q,
-                                "answer": a,
-                                "image_path": imageFile,
-                                "image_id": k,
-                                "title": data[k]["title"],
-                                "genre": data[k]["genre"],
-                            }
-                        )
-        return processed_data
-
-    def __len__(self):
-        return len(self.data)
-
-    def __getitem__(self, index):
-        sample = self.data[index]
-        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
-            "RGB"
-        )
-        question = sample["question"]
-        answer = sample["answer"]
-        if self.vis_processor is not None:
-            image = self.vis_processor(image)
-        if self.text_processor is not None:
-            question = self.text_processor(question)
-            answer = self.text_processor(answer)
-
-        chat = [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "image"},
-                    {
-                        "type": "text",
-                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
-                    },
-                ],
-            }
-            # {"role": "assistant", "content": answer},
-        ]
-        return {
-            "image": image,
-            "chat": chat,
-            "answer": answer,
-            "image_id": sample["image_id"],
-        }
-
-if __name__ == "__main__":
-    dataset = ChemDataseet(
-        "/home/zyy/research/accelerate/dataset/chem/images/",
-        "/home/zyy/research/accelerate/dataset/chem/qwen_data/conversations_loc_train.jsonl",
-        split="train",
-    )
-    print(len(dataset))
-    print(dataset[0])
-    dataset = OCRVQADatasetForGeneration(
-        "/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images",
-        "/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json",
-        split="train",
-    )
-    print(len(dataset))
-    print(dataset[0])
-    pass
\ No newline at end of file
diff --git a/src/dataset_library/factory.py b/src/dataset_library/factory.py
index eb5f037..f86cd0f 100644
--- a/src/dataset_library/factory.py
+++ b/src/dataset_library/factory.py
@@ -27,4 +27,25 @@ def get_dataset(
                 split="test",
             ),
         }
+    if dataset_name == "CHEM":
+        import os.path as osp
+        from .CHEM import CHEMDataset, CHEMDatasetForGeneration
+
+        dataset = {
+            "train": CHEMDataset(
+                osp.join(base_path, "chem/images"),
+                osp.join(base_path, "chem/qwen_data"),
+                split="train",
+            ),
+            "test": CHEMDataset(
+                osp.join(base_path, "chem/images"),
+                osp.join(base_path, "chem/qwen_data"),
+                split="test",
+            ),
+            "generation": CHEMDatasetForGeneration(
+                osp.join(base_path, "chem/images"),
+                osp.join(base_path, "chem/qwen_data"),
+                split="test",
+            ),
+        }
     return dataset
diff --git a/src/evaluation.py b/src/evaluation.py
index 48071ca..1b8fd04 100644
--- a/src/evaluation.py
+++ b/src/evaluation.py
@@ -13,7 +13,9 @@ from utils.args import ContinualScriptArguments, ContinualModelConfig
 
 
 if __name__ == "__main__":
-    parser = TrlParser((ContinualScriptArguments, TrainingArguments, ContinualModelConfig))
+    parser = TrlParser(
+        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
+    )
     script_args, training_args, model_args = parser.parse_args_and_config()
     # for type hint
     if 0 == 1:
@@ -48,10 +50,9 @@ if __name__ == "__main__":
         trust_remote_code=model_args.trust_remote_code,
         **model_kwargs,
     )
-    print(model)
 
     if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
-        from model_library.qwen2 import (
+        from model_library.qwen2vl import (
             collate_fn_for_train,
             collate_fn_for_evaluate,
         )
diff --git a/src/train.sh b/src/train.sh
index 5016a3c..9899457 100755
--- a/src/train.sh
+++ b/src/train.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
-    --dataset_name OCR_VQA_200K OCR_VQA_200K OCR_VQA_200K \
+    --dataset_name CHEM \
     --use_peft \
     --peft_type MMOELORA \
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
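
Reviewer note (not part of the patch): a minimal sketch of how the new CHEM entry would be exercised through the dataset factory. It assumes get_dataset takes the dataset name plus a base_path argument (only "def get_dataset(" and the use of base_path are visible in the factory.py hunk), that src/ is on the import path, and that base_path contains the chem/images and chem/qwen_data directories wired up above; the example path is hypothetical.

# Sketch under assumed signature get_dataset(dataset_name, base_path);
# the dict keys "train" / "test" / "generation" come from factory.py above.
from dataset_library.factory import get_dataset

splits = get_dataset("CHEM", "/path/to/dataset")  # hypothetical base_path
train_set = splits["train"]        # CHEMDataset: system/user/assistant chat turns
gen_set = splits["generation"]     # CHEMDatasetForGeneration: keeps the raw answer

print(len(train_set))
print(train_set[0]["chat"])        # image placeholder plus "[vqa] ..." question
print(gen_set[0]["answer"])        # ground-truth answer used at evaluation time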