feat: add ScienceQA dataset support, update the dataset factory and evaluation logic, and adjust batch sizes

YunyaoZhou 2025-01-19 00:14:03 +08:00
parent 7b9349091e
commit bcb0494f52
Signed by: shujakuin
GPG Key ID: 418C3CA28E350CCF
10 changed files with 174 additions and 18 deletions

View File

@ -0,0 +1,120 @@
from .format import (
Conversation,
ConverstationAudio,
ConverstationImage,
ConverstationText,
DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset
class ScienceQADataset(Dataset):
    def __init__(self, vis_processor=None, text_processor=None, split="train"):
        """
        vis_processor (callable, optional): applied to each sample's image
        text_processor (callable, optional): applied to each sample's question
        split (string): ScienceQA split to load, e.g. "train" or "test"
        """
        self.vis_processor = vis_processor
        self.text_processor = text_processor
ds = load_dataset("derek-thomas/ScienceQA")
self.data = ds[split]
def __len__(self):
return len(self.data)
def __getitem__(self, index):
sample = self.data[index]
        # Each sample contains: image (PIL.Image or None), question (str), choices (list[str]),
        # answer (index into choices), plus metadata such as hint, task, grade, subject,
        # topic, category, skill, lecture, and solution.
images = sample["image"]
question = sample["question"]
choices = sample["choices"]
task = sample["task"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"[{task}] '{question}' choose from '{choices}'",
),
],
),
Conversation(
role="assistant",
content=[ConverstationText(type="text", text=choices[answer])],
),
]
return DatasetOutput(
images=[images],
chat=chat,
original=sample,
)
class ScienceQADatasetForGeneration(ScienceQADataset):
def __getitem__(self, index):
sample = self.data[index]
        # Same sample structure as in ScienceQADataset.__getitem__ above.
images = sample["image"]
question = sample["question"]
choices = sample["choices"]
task = sample["task"]
answer = sample["answer"]
if self.vis_processor is not None:
images = self.vis_processor(images)
if self.text_processor is not None:
question = self.text_processor(question)
chat = [
Conversation(
role="user",
content=[
ConverstationImage(type="image", image_url=""),
ConverstationText(
type="text",
text=f"[{task}] '{question}' choose from '{choices}'",
),
],
),
]
return DatasetOutput(
images=[images],
chat=chat,
answer=choices[answer],
original=sample,
)
def test_scienceQA():
dataset = ScienceQADataset(
split="train",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = ScienceQADatasetForGeneration(
split="train",
)
print(dataset[3])
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
if __name__ == "__main__":
test_scienceQA()
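A quick inspection sketch (not part of the commit) showing what the generation variant returns; it assumes the derek-thomas/ScienceQA download succeeds and relies on the dict-style access the tests above already use:

ds = ScienceQADatasetForGeneration(split="test")
sample = ds[0]
print(len(ds))                        # number of test questions
print(len(sample["chat"]))            # 1: only the user turn is included
print(sample["answer"])               # ground-truth choice text, kept out of the chat
print(sample["original"]["choices"])  # the raw HF fields remain available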

View File

@ -78,4 +78,13 @@ def get_dataset(
),
}
if dataset_name == "scienceqa":
from .ScienceQADataset import ScienceQADataset, ScienceQADatasetForGeneration
dataset = {
"train": ScienceQADataset(split="train"),
"test": ScienceQADataset(split="test"),
"generation": ScienceQADatasetForGeneration(split="test"),
}
return dataset

View File

@ -35,3 +35,12 @@ def test_textvqa():
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0
def test_scienceqa():
dataset = get_dataset("scienceqa")
assert len(dataset["train"]) > 0
assert len(dataset["train"][0]["chat"]) > 0
assert len(dataset["test"]) > 0
assert len(dataset["test"][0]["chat"]) > 0

View File

@ -99,9 +99,11 @@ if __name__ == "__main__":
from torch.utils.data import DataLoader
    # ScienceQA batches mix image and text-only samples, so fall back to batch size 1 there
    bs = 3 if dataset_name not in ["scienceqa"] else 1
val_dataloader = DataLoader(
dataset[script_args.dataset_generation_split],
batch_size=3,
        batch_size=bs,
collate_fn=collate_fn_for_evaluate,
)
val_dataloader = accelerator.prepare_data_loader(val_dataloader)
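Note that collate_fn_for_evaluate (see the collate changes further down) takes the processor as its second argument; if evaluation.py does not already bind it, functools.partial is one way to do so. A sketch under that assumption, reusing the processor the script is assumed to have created already:

from functools import partial

from torch.utils.data import DataLoader

val_dataloader = DataLoader(
    dataset[script_args.dataset_generation_split],
    batch_size=bs,
    collate_fn=partial(collate_fn_for_evaluate, processor=processor),
)
val_dataloader = accelerator.prepare_data_loader(val_dataloader)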

View File

@ -1,7 +1,7 @@
#!/bin/bash
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml evaluation.py \
--dataset_name CHEM \
--dataset_name scienceqa \
--use_peft \
--peft_type MMOELORA \
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \

View File

@ -7,14 +7,14 @@ def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
processor.apply_chat_template(example["chat"], tokenize=False)
for example in examples
]
audios = [example["audio"][0] for example in examples]
audios = [example["audios"][0] for example in examples]
    # Tokenize the texts and process the audio clips
batch = processor(
text=texts,
audios=audios,
return_tensors="pt",
padding=True,
sampling_rate=examples[0]["audio"][1],
sampling_rate=examples[0]["audios"][1],
)
# The labels are the input_ids, and we mask the padding tokens in the loss computation
@ -69,10 +69,16 @@ def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
for example in examples
]
# print(texts)
audios = [example["audio"] for example in examples]
audios = [example["audios"] for example in examples]
    # Tokenize the texts and process the audio clips
batch = processor(text=texts, audios=audios, return_tensors="pt", padding=True)
batch = processor(
text=texts,
audios=audios,
return_tensors="pt",
padding=True,
sampling_rate=examples[0]["audios"][1],
)
answers = [example["answer"] for example in examples]
answers = processor(text=answers, return_tensors="pt", padding=True)
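Both audio collate functions now read example["audios"] as a (waveform, sampling_rate) pair and forward the rate to the processor. A sketch of the expected sample layout, calling the train collate directly; the model id, the dummy 16 kHz waveform, and the plain-dict chat format are assumptions for illustration only:

import numpy as np
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B-Instruct")

# A fake sample shaped the way collate_fn_for_train expects after this change.
example = {
    "chat": [
        {"role": "user", "content": [
            {"type": "audio", "audio_url": ""},
            {"type": "text", "text": "What is being said in this clip?"},
        ]},
        {"role": "assistant", "content": [{"type": "text", "text": "a short greeting"}]},
    ],
    "audios": (np.zeros(16000, dtype=np.float32), 16000),  # (waveform, sampling_rate)
}

batch = collate_fn_for_train([example], processor)
print(batch["input_ids"].shape)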

View File

@ -10,7 +10,11 @@ def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProces
for example in examples
]
# print(texts)
images = [example["images"] for example in examples]
images = [
example["images"] for example in examples if example["images"][0] is not None
]
images = images if len(images) > 0 else None
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
@ -66,7 +70,10 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
for example in examples
]
# print(texts)
images = [example["images"] for example in examples]
images = [
example["images"] for example in examples if example["images"][0] is not None
]
images = images if len(images) > 0 else None
# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
@ -74,7 +81,6 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
answers = [example["answer"] for example in examples]
answers = processor(text=answers, return_tensors="pt", padding=True)
batch["answers_ids"] = answers["input_ids"]
batch["answers_mask"] = answers["attention_mask"]
batch["original_data"] = [example["original"] for example in examples]
# input_ids torch.Size([3, 370])
# attention_mask torch.Size([3, 370])

View File

@ -12,5 +12,11 @@
[2025.01.03]
- [ ] Handle the quantization logic
- [ ] Go through the original moelora code carefully; it is far too rough 😡
- [ ] Trainer post-processing takes a long time for unknown reasons
- [X] Go through the original moelora code carefully; it is far too rough 😡
- [X] Trainer post-processing takes a long time for unknown reasons
[2025.01.19]
- [ ] Bring in more datasets
- [ ] For mixed-modality data the batch size can only be 1, which is far too slow; the model code may need adjusting (and even that might not help); see the sampler sketch after this list
- [ ] Add EWC and LWF
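One way the mixed-modality item could be worked around on the data side rather than in the model is to group samples by modality so every batch is homogeneous. A rough sketch of such a batch sampler, not code from this repo; the images == [None] check mirrors the collate change below, and the grouping strategy itself is an assumption:

import random

from torch.utils.data import Sampler


class ModalityGroupedBatchSampler(Sampler):
    """Yield index batches whose samples all share one modality (image vs. text-only)."""

    def __init__(self, dataset, batch_size, seed=0):
        self.batch_size = batch_size
        self.rng = random.Random(seed)
        with_image, text_only = [], []
        for idx in range(len(dataset)):
            # Assumption: text-only samples expose images == [None], as the collate
            # functions expect. Checking the raw HF column instead would avoid
            # materialising every sample here.
            if dataset[idx]["images"][0] is not None:
                with_image.append(idx)
            else:
                text_only.append(idx)
        self.groups = [with_image, text_only]

    def __iter__(self):
        batches = []
        for group in self.groups:
            self.rng.shuffle(group)
            batches += [
                group[i : i + self.batch_size]
                for i in range(0, len(group), self.batch_size)
            ]
        self.rng.shuffle(batches)
        yield from batches

    def __len__(self):
        return sum(
            (len(g) + self.batch_size - 1) // self.batch_size for g in self.groups
        )

# Usage sketch: DataLoader(dataset["train"], batch_sampler=ModalityGroupedBatchSampler(dataset["train"], 3), collate_fn=...)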

View File

@ -1,7 +1,7 @@
#!/bin/bash
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
--dataset_name gigaspeech \
--dataset_name scienceqa \
--use_peft \
--peft_type LORA \
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \

View File

@ -47,15 +47,13 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
bar = tqdm(total=len(val_dataloader))
for batch in val_dataloader:
target = batch.pop("answers_ids")
        original = batch.pop("original_data")
answers = []
completion = model.generate(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
pixel_values=batch["pixel_values"],
image_grid_thw=batch["image_grid_thw"],
**batch,
max_length=1000,
)
target = batch["answers_ids"]
generated_text = [
out_ids[len(in_ids) :]
for out_ids, in_ids in zip(completion, batch["input_ids"])
@ -69,7 +67,7 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
{
"generated": generated_text[i],
"target": target_text[i],
"original": batch["original_data"][i],
"original": str(origianl[i]),
}
)
import json