From bcb0494f52783e4c03de9d755669a38d1781f211 Mon Sep 17 00:00:00 2001 From: YunyaoZhou Date: Sun, 19 Jan 2025 00:14:03 +0800 Subject: [PATCH] =?UTF-8?q?feat=E2=9C=A8:=20=E6=B7=BB=E5=8A=A0ScienceQA?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E9=9B=86=E6=94=AF=E6=8C=81=EF=BC=8C=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=E6=95=B0=E6=8D=AE=E9=9B=86=E5=B7=A5=E5=8E=82=E5=92=8C?= =?UTF-8?q?=E8=AF=84=E4=BC=B0=E9=80=BB=E8=BE=91=EF=BC=8C=E8=B0=83=E6=95=B4?= =?UTF-8?q?=E6=89=B9=E5=A4=84=E7=90=86=E5=A4=A7=E5=B0=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/dataset_library/ScienceQADataset.py | 120 +++++++++++++++++++++ src/dataset_library/factory.py | 9 ++ src/dataset_library/test_dataset.py | 9 ++ src/evaluation.py | 4 +- src/evaluation.sh | 2 +- src/model_library/qwen2audio/collate_fn.py | 14 ++- src/model_library/qwen2vl/collate_fn.py | 12 ++- src/todo.md | 10 +- src/train.sh | 2 +- src/utils/evaluate_tool.py | 10 +- 10 files changed, 174 insertions(+), 18 deletions(-) create mode 100644 src/dataset_library/ScienceQADataset.py diff --git a/src/dataset_library/ScienceQADataset.py b/src/dataset_library/ScienceQADataset.py new file mode 100644 index 0000000..77f92ec --- /dev/null +++ b/src/dataset_library/ScienceQADataset.py @@ -0,0 +1,120 @@ +from .format import ( + Conversation, + ConverstationAudio, + ConverstationImage, + ConverstationText, + DatasetOutput, +) +from torch.utils.data import Dataset +from datasets import load_dataset + + +class ScienceQADataset(Dataset): + def __init__(self, audio_processor=None, text_processor=None, split="train"): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + + self.vis_processor = audio_processor + self.text_processor = text_processor + ds = load_dataset("derek-thomas/ScienceQA") + self.data = ds[split] + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + sample = self.data[index] + # print(sample) + # {'image': , 'question': 'Which of these states is farthest north?', 'choices': ['West Virginia', 'Louisiana', 'Arizona', 'Oklahoma'], 'answer': 0, 'hint': '', 'task': 'closed choice', 'grade': 'grade2', 'subject': 'social science', 'topic': 'geography', 'category': 'Geography', 'skill': 'Read a map: cardinal directions', 'lecture': 'Maps have four cardinal directions, or main directions. Those directions are north, south, east, and west.\nA compass rose is a set of arrows that point to the cardinal directions. A compass rose usually shows only the first letter of each cardinal direction.\nThe north arrow points to the North Pole. On most maps, north is at the top of the map.', 'solution': 'To find the answer, look at the compass rose. Look at which way the north arrow is pointing. West Virginia is farthest north.'} + images = sample["image"] + question = sample["question"] + choices = sample["choices"] + task = sample["task"] + answer = sample["answer"] + + if self.vis_processor is not None: + images = self.vis_processor(images) + if self.text_processor is not None: + question = self.text_processor(question) + + chat = [ + Conversation( + role="user", + content=[ + ConverstationImage(type="image", image_url=""), + ConverstationText( + type="text", + text=f"[{task}] '{question}' choose from '{choices}'", + ), + ], + ), + Conversation( + role="assistant", + content=[ConverstationText(type="text", text=choices[answer])], + ), + ] + + return DatasetOutput( + images=[images], + chat=chat, + original=sample, + ) + + +class ScienceQADatasetForGeneration(ScienceQADataset): + + def __getitem__(self, index): + sample = self.data[index] + # print(sample) + # {'image': , 'question': 'Which of these states is farthest north?', 'choices': ['West Virginia', 'Louisiana', 'Arizona', 'Oklahoma'], 'answer': 0, 'hint': '', 'task': 'closed choice', 'grade': 'grade2', 'subject': 'social science', 'topic': 'geography', 'category': 'Geography', 'skill': 'Read a map: cardinal directions', 'lecture': 'Maps have four cardinal directions, or main directions. Those directions are north, south, east, and west.\nA compass rose is a set of arrows that point to the cardinal directions. A compass rose usually shows only the first letter of each cardinal direction.\nThe north arrow points to the North Pole. On most maps, north is at the top of the map.', 'solution': 'To find the answer, look at the compass rose. Look at which way the north arrow is pointing. West Virginia is farthest north.'} + images = sample["image"] + question = sample["question"] + choices = sample["choices"] + task = sample["task"] + answer = sample["answer"] + + if self.vis_processor is not None: + images = self.vis_processor(images) + if self.text_processor is not None: + question = self.text_processor(question) + + chat = [ + Conversation( + role="user", + content=[ + ConverstationImage(type="image", image_url=""), + ConverstationText( + type="text", + text=f"[{task}] '{question}' choose from '{choices}'", + ), + ], + ), + ] + + return DatasetOutput( + images=[images], + chat=chat, + answer=choices[answer], + original=sample, + ) + + +def test_scienceQA(): + dataset = ScienceQADataset( + split="train", + ) + print(dataset[3]) + assert len(dataset) > 0 + assert len(dataset[0]["chat"]) > 0 + dataset = ScienceQADatasetForGeneration( + split="train", + ) + print(dataset[3]) + assert len(dataset) > 0 + assert len(dataset[0]["chat"]) > 0 + + +if __name__ == "__main__": + test_scienceQA() diff --git a/src/dataset_library/factory.py b/src/dataset_library/factory.py index 338a239..e29e6fa 100644 --- a/src/dataset_library/factory.py +++ b/src/dataset_library/factory.py @@ -78,4 +78,13 @@ def get_dataset( ), } + if dataset_name == "scienceqa": + from .ScienceQADataset import ScienceQADataset, ScienceQADatasetForGeneration + + dataset = { + "train": ScienceQADataset(split="train"), + "test": ScienceQADataset(split="test"), + "generation": ScienceQADatasetForGeneration(split="test"), + } + return dataset diff --git a/src/dataset_library/test_dataset.py b/src/dataset_library/test_dataset.py index 94d432e..516abc7 100644 --- a/src/dataset_library/test_dataset.py +++ b/src/dataset_library/test_dataset.py @@ -35,3 +35,12 @@ def test_textvqa(): assert len(dataset["test"]) > 0 assert len(dataset["test"][0]["chat"]) > 0 + + +def test_scienceqa(): + dataset = get_dataset("scienceqa") + assert len(dataset["train"]) > 0 + assert len(dataset["train"][0]["chat"]) > 0 + + assert len(dataset["test"]) > 0 + assert len(dataset["test"][0]["chat"]) > 0 diff --git a/src/evaluation.py b/src/evaluation.py index 5898e51..ae36486 100644 --- a/src/evaluation.py +++ b/src/evaluation.py @@ -99,9 +99,11 @@ if __name__ == "__main__": from torch.utils.data import DataLoader + bs = 3 if dataset_name not in ["scienceqa"] else 1 + val_dataloader = DataLoader( dataset[script_args.dataset_generation_split], - batch_size=3, + batch_size=1, collate_fn=collate_fn_for_evaluate, ) val_dataloader = accelerator.prepare_data_loader(val_dataloader) diff --git a/src/evaluation.sh b/src/evaluation.sh index 598c97b..6669ee8 100755 --- a/src/evaluation.sh +++ b/src/evaluation.sh @@ -1,7 +1,7 @@ #!/bin/bash accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml evaluation.py \ - --dataset_name CHEM \ + --dataset_name scienceqa \ --use_peft \ --peft_type MMOELORA \ --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ diff --git a/src/model_library/qwen2audio/collate_fn.py b/src/model_library/qwen2audio/collate_fn.py index b0cb927..d08e7dd 100644 --- a/src/model_library/qwen2audio/collate_fn.py +++ b/src/model_library/qwen2audio/collate_fn.py @@ -7,14 +7,14 @@ def collate_fn_for_train(examples, processor: Qwen2AudioProcessor): processor.apply_chat_template(example["chat"], tokenize=False) for example in examples ] - audios = [example["audio"][0] for example in examples] + audios = [example["audios"][0] for example in examples] # Tokenize the texts and process the images batch = processor( text=texts, audios=audios, return_tensors="pt", padding=True, - sampling_rate=examples[0]["audio"][1], + sampling_rate=examples[0]["audios"][1], ) # The labels are the input_ids, and we mask the padding tokens in the loss computation @@ -69,10 +69,16 @@ def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor): for example in examples ] # print(texts) - audios = [example["audio"] for example in examples] + audios = [example["audios"] for example in examples] # Tokenize the texts and process the images - batch = processor(text=texts, audios=audios, return_tensors="pt", padding=True) + batch = processor( + text=texts, + audios=audios, + return_tensors="pt", + padding=True, + sampling_rate=examples[0]["audios"][1], + ) answers = [example["answer"] for example in examples] answers = processor(text=answers, return_tensors="pt", padding=True) diff --git a/src/model_library/qwen2vl/collate_fn.py b/src/model_library/qwen2vl/collate_fn.py index 649e544..9485d54 100644 --- a/src/model_library/qwen2vl/collate_fn.py +++ b/src/model_library/qwen2vl/collate_fn.py @@ -10,7 +10,11 @@ def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProces for example in examples ] # print(texts) - images = [example["images"] for example in examples] + images = [ + example["images"] for example in examples if example["images"][0] is not None + ] + images = images if len(images) > 0 else None + # Tokenize the texts and process the images batch = processor(text=texts, images=images, return_tensors="pt", padding=True) @@ -66,7 +70,10 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor): for example in examples ] # print(texts) - images = [example["images"] for example in examples] + images = [ + example["images"] for example in examples if example["images"][0] is not None + ] + images = images if len(images) > 0 else None # Tokenize the texts and process the images batch = processor(text=texts, images=images, return_tensors="pt", padding=True) @@ -74,7 +81,6 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor): answers = [example["answer"] for example in examples] answers = processor(text=answers, return_tensors="pt", padding=True) batch["answers_ids"] = answers["input_ids"] - batch["answers_mask"] = answers["attention_mask"] batch["original_data"] = [example["original"] for example in examples] # input_ids torch.Size([3, 370]) # attention_mask torch.Size([3, 370]) diff --git a/src/todo.md b/src/todo.md index 8e7bf4a..5a05c0f 100644 --- a/src/todo.md +++ b/src/todo.md @@ -12,5 +12,11 @@ [2025.01.03] - [ ] 处理量化逻辑 -- [ ] 严查moelora的原始代码,太粗糙了😡 -- [ ] 未知原因trainer后处理时间长 +- [X] 严查moelora的原始代码,太粗糙了😡 +- [X] 未知原因trainer后处理时间长 + +[2025.01.19] + +- [ ] 多个数据集引入 +- [ ] 对于混合模态数据 batchsize只能为1 性能太低 要调整模型代码(也不一定有用) +- [ ] 引入EWC和LWF diff --git a/src/train.sh b/src/train.sh index e660fc4..50a66cb 100755 --- a/src/train.sh +++ b/src/train.sh @@ -1,7 +1,7 @@ #!/bin/bash accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \ - --dataset_name gigaspeech \ + --dataset_name scienceqa \ --use_peft \ --peft_type LORA \ --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ diff --git a/src/utils/evaluate_tool.py b/src/utils/evaluate_tool.py index 01624d4..29e1330 100644 --- a/src/utils/evaluate_tool.py +++ b/src/utils/evaluate_tool.py @@ -47,15 +47,13 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No bar = tqdm(total=len(val_dataloader)) for batch in val_dataloader: + target = batch.pop("answers_ids") + origianl = batch.pop("original_data") answers = [] completion = model.generate( - input_ids=batch["input_ids"], - attention_mask=batch["attention_mask"], - pixel_values=batch["pixel_values"], - image_grid_thw=batch["image_grid_thw"], + **batch, max_length=1000, ) - target = batch["answers_ids"] generated_text = [ out_ids[len(in_ids) :] for out_ids, in_ids in zip(completion, batch["input_ids"]) @@ -69,7 +67,7 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No { "generated": generated_text[i], "target": target_text[i], - "original": batch["original_data"][i], + "original": str(origianl[i]), } ) import json