feat✨: Add ScienceQA dataset support, update the dataset factory and evaluation logic, and adjust the batch size
commit bcb0494f52 (parent 7b9349091e)

src/dataset_library/ScienceQADataset.py (new file, 120 lines)
@@ -0,0 +1,120 @@
from .format import (
    Conversation,
    ConverstationAudio,
    ConverstationImage,
    ConverstationText,
    DatasetOutput,
)
from torch.utils.data import Dataset
from datasets import load_dataset


class ScienceQADataset(Dataset):
    def __init__(self, audio_processor=None, text_processor=None, split="train"):
        """
        audio_processor: optional processor applied to each sample's image
        text_processor: optional processor applied to the question text
        """

        self.vis_processor = audio_processor
        self.text_processor = text_processor
        ds = load_dataset("derek-thomas/ScienceQA")
        self.data = ds[split]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        # print(sample)
        # {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=750x429 at 0x71B9ACD6EF50>, 'question': 'Which of these states is farthest north?', 'choices': ['West Virginia', 'Louisiana', 'Arizona', 'Oklahoma'], 'answer': 0, 'hint': '', 'task': 'closed choice', 'grade': 'grade2', 'subject': 'social science', 'topic': 'geography', 'category': 'Geography', 'skill': 'Read a map: cardinal directions', 'lecture': 'Maps have four cardinal directions, or main directions. Those directions are north, south, east, and west.\nA compass rose is a set of arrows that point to the cardinal directions. A compass rose usually shows only the first letter of each cardinal direction.\nThe north arrow points to the North Pole. On most maps, north is at the top of the map.', 'solution': 'To find the answer, look at the compass rose. Look at which way the north arrow is pointing. West Virginia is farthest north.'}
        images = sample["image"]
        question = sample["question"]
        choices = sample["choices"]
        task = sample["task"]
        answer = sample["answer"]

        if self.vis_processor is not None:
            images = self.vis_processor(images)
        if self.text_processor is not None:
            question = self.text_processor(question)

        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[{task}] '{question}' choose from '{choices}'",
                    ),
                ],
            ),
            Conversation(
                role="assistant",
                content=[ConverstationText(type="text", text=choices[answer])],
            ),
        ]

        return DatasetOutput(
            images=[images],
            chat=chat,
            original=sample,
        )


class ScienceQADatasetForGeneration(ScienceQADataset):

    def __getitem__(self, index):
        sample = self.data[index]
        # print(sample)
        # {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=750x429 at 0x71B9ACD6EF50>, 'question': 'Which of these states is farthest north?', 'choices': ['West Virginia', 'Louisiana', 'Arizona', 'Oklahoma'], 'answer': 0, 'hint': '', 'task': 'closed choice', 'grade': 'grade2', 'subject': 'social science', 'topic': 'geography', 'category': 'Geography', 'skill': 'Read a map: cardinal directions', 'lecture': 'Maps have four cardinal directions, or main directions. Those directions are north, south, east, and west.\nA compass rose is a set of arrows that point to the cardinal directions. A compass rose usually shows only the first letter of each cardinal direction.\nThe north arrow points to the North Pole. On most maps, north is at the top of the map.', 'solution': 'To find the answer, look at the compass rose. Look at which way the north arrow is pointing. West Virginia is farthest north.'}
        images = sample["image"]
        question = sample["question"]
        choices = sample["choices"]
        task = sample["task"]
        answer = sample["answer"]

        if self.vis_processor is not None:
            images = self.vis_processor(images)
        if self.text_processor is not None:
            question = self.text_processor(question)

        chat = [
            Conversation(
                role="user",
                content=[
                    ConverstationImage(type="image", image_url=""),
                    ConverstationText(
                        type="text",
                        text=f"[{task}] '{question}' choose from '{choices}'",
                    ),
                ],
            ),
        ]

        return DatasetOutput(
            images=[images],
            chat=chat,
            answer=choices[answer],
            original=sample,
        )


def test_scienceQA():
    dataset = ScienceQADataset(
        split="train",
    )
    print(dataset[3])
    assert len(dataset) > 0
    assert len(dataset[0]["chat"]) > 0
    dataset = ScienceQADatasetForGeneration(
        split="train",
    )
    print(dataset[3])
    assert len(dataset) > 0
    assert len(dataset[0]["chat"]) > 0


if __name__ == "__main__":
    test_scienceQA()
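A minimal usage sketch of the two dataset classes above (a sketch only: the import path is assumed from the file location, and the dict-style access follows what test_scienceQA asserts):

    from src.dataset_library.ScienceQADataset import (  # path assumed from the repo layout
        ScienceQADataset,
        ScienceQADatasetForGeneration,
    )

    # training-style sample: chat holds both the user prompt and the assistant answer
    train_set = ScienceQADataset(split="train")
    sample = train_set[0]
    print(len(sample["chat"]))      # 2 turns (user + assistant)

    # generation-style sample: chat stops at the user turn, the gold choice is kept separately
    gen_set = ScienceQADatasetForGeneration(split="test")
    gen_sample = gen_set[0]
    print(len(gen_sample["chat"]))  # 1 turn (user only)
    print(gen_sample["answer"])     # e.g. "West Virginia"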
@@ -78,4 +78,13 @@ def get_dataset(
         ),
     }
 
+    if dataset_name == "scienceqa":
+        from .ScienceQADataset import ScienceQADataset, ScienceQADatasetForGeneration
+
+        dataset = {
+            "train": ScienceQADataset(split="train"),
+            "test": ScienceQADataset(split="test"),
+            "generation": ScienceQADatasetForGeneration(split="test"),
+        }
+
     return dataset
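A short sketch of consuming the new factory entry (the module that defines get_dataset is not named in this diff, so the import below is hypothetical; the keys and classes are the ones shown in the hunk):

    from src.dataset_library import get_dataset  # hypothetical import path

    dataset = get_dataset("scienceqa")
    print(len(dataset["train"]), len(dataset["test"]))
    gen_sample = dataset["generation"][0]  # "generation" wraps the test split
    print(gen_sample["answer"])            # gold choice string, kept out of the chat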
@@ -35,3 +35,12 @@ def test_textvqa():
 
     assert len(dataset["test"]) > 0
     assert len(dataset["test"][0]["chat"]) > 0
+
+
+def test_scienceqa():
+    dataset = get_dataset("scienceqa")
+    assert len(dataset["train"]) > 0
+    assert len(dataset["train"][0]["chat"]) > 0
+
+    assert len(dataset["test"]) > 0
+    assert len(dataset["test"][0]["chat"]) > 0
@@ -99,9 +99,11 @@ if __name__ == "__main__":
 
     from torch.utils.data import DataLoader
 
+    bs = 3 if dataset_name not in ["scienceqa"] else 1
+
     val_dataloader = DataLoader(
         dataset[script_args.dataset_generation_split],
-        batch_size=3,
+        batch_size=1,
         collate_fn=collate_fn_for_evaluate,
     )
     val_dataloader = accelerator.prepare_data_loader(val_dataloader)
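The hunk above defines a per-dataset bs but the committed DataLoader still hard-codes batch_size=1; a sketch of how the two could be wired together, assuming that is the intent (all names come from the hunk itself):

    bs = 3 if dataset_name not in ["scienceqa"] else 1  # mixed-modality sets capped at 1
    val_dataloader = DataLoader(
        dataset[script_args.dataset_generation_split],
        batch_size=bs,
        collate_fn=collate_fn_for_evaluate,
    )
    val_dataloader = accelerator.prepare_data_loader(val_dataloader)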
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml evaluation.py \
-    --dataset_name CHEM \
+    --dataset_name scienceqa \
     --use_peft \
     --peft_type MMOELORA \
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
@@ -7,14 +7,14 @@ def collate_fn_for_train(examples, processor: Qwen2AudioProcessor):
         processor.apply_chat_template(example["chat"], tokenize=False)
         for example in examples
     ]
-    audios = [example["audio"][0] for example in examples]
+    audios = [example["audios"][0] for example in examples]
     # Tokenize the texts and process the images
     batch = processor(
         text=texts,
         audios=audios,
         return_tensors="pt",
         padding=True,
-        sampling_rate=examples[0]["audio"][1],
+        sampling_rate=examples[0]["audios"][1],
     )
 
     # The labels are the input_ids, and we mask the padding tokens in the loss computation
@@ -69,10 +69,16 @@ def collate_fn_for_evaluate(examples, processor: Qwen2AudioProcessor):
         for example in examples
     ]
     # print(texts)
-    audios = [example["audio"] for example in examples]
+    audios = [example["audios"] for example in examples]
 
     # Tokenize the texts and process the images
-    batch = processor(text=texts, audios=audios, return_tensors="pt", padding=True)
+    batch = processor(
+        text=texts,
+        audios=audios,
+        return_tensors="pt",
+        padding=True,
+        sampling_rate=examples[0]["audios"][1],
+    )
 
     answers = [example["answer"] for example in examples]
     answers = processor(text=answers, return_tensors="pt", padding=True)
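Both audio collators now index example["audios"] twice: position 0 for the waveform and position 1 for the sampling rate. A minimal sketch of the record layout this assumes (the dummy values are illustrative; only the indexing mirrors the hunks):

    import numpy as np

    # hypothetical audio example record, shaped the way the collators read it
    example = {"audios": (np.zeros(16000, dtype=np.float32), 16000)}
    waveform = example["audios"][0]       # raw waveform handed to the processor
    sampling_rate = example["audios"][1]  # forwarded via sampling_rate=...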
@@ -10,7 +10,11 @@ def collate_fn_for_train(examples: list[DatasetOutput], processor: Qwen2VLProcessor):
         for example in examples
     ]
     # print(texts)
-    images = [example["images"] for example in examples]
+    images = [
+        example["images"] for example in examples if example["images"][0] is not None
+    ]
+    images = images if len(images) > 0 else None
+
     # Tokenize the texts and process the images
     batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
@@ -66,7 +70,10 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
         for example in examples
     ]
     # print(texts)
-    images = [example["images"] for example in examples]
+    images = [
+        example["images"] for example in examples if example["images"][0] is not None
+    ]
+    images = images if len(images) > 0 else None
 
     # Tokenize the texts and process the images
     batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
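Both Qwen2-VL collators now drop text-only examples from the image list and fall back to images=None, presumably to handle ScienceQA rows whose image field is None. A small sketch of the filter on a mixed batch (the records are made up; the two filter lines mirror the hunks):

    # hypothetical mixed batch: one sample with an image, one text-only sample
    examples = [
        {"images": ["<PIL.Image>"]},
        {"images": [None]},
    ]
    images = [ex["images"] for ex in examples if ex["images"][0] is not None]
    images = images if len(images) > 0 else None  # the processor expects None, not []
    print(images)  # [['<PIL.Image>']]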
@@ -74,7 +81,6 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
    answers = [example["answer"] for example in examples]
    answers = processor(text=answers, return_tensors="pt", padding=True)
    batch["answers_ids"] = answers["input_ids"]
    batch["answers_mask"] = answers["attention_mask"]
    batch["original_data"] = [example["original"] for example in examples]
    # input_ids torch.Size([3, 370])
    # attention_mask torch.Size([3, 370])
src/todo.md (10 lines changed)
@@ -12,5 +12,11 @@
 [2025.01.03]
 
 - [ ] Handle the quantization logic
-- [ ] Thoroughly review the original moelora code, it is far too rough 😡
-- [ ] Trainer post-processing takes a long time for unknown reasons
+- [X] Thoroughly review the original moelora code, it is far too rough 😡
+- [X] Trainer post-processing takes a long time for unknown reasons
+
+[2025.01.19]
+
+- [ ] Bring in more datasets
+- [ ] For mixed-modality data the batch size can only be 1, which is too slow; the model code needs adjusting (and may not even help)
+- [ ] Introduce EWC and LWF
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
-    --dataset_name gigaspeech \
+    --dataset_name scienceqa \
     --use_peft \
     --peft_type LORA \
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
@@ -47,15 +47,13 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = None):
     bar = tqdm(total=len(val_dataloader))
 
     for batch in val_dataloader:
+        target = batch.pop("answers_ids")
+        origianl = batch.pop("original_data")
         answers = []
         completion = model.generate(
-            input_ids=batch["input_ids"],
-            attention_mask=batch["attention_mask"],
-            pixel_values=batch["pixel_values"],
-            image_grid_thw=batch["image_grid_thw"],
+            **batch,
             max_length=1000,
         )
-        target = batch["answers_ids"]
         generated_text = [
             out_ids[len(in_ids) :]
             for out_ids, in_ids in zip(completion, batch["input_ids"])
@@ -69,7 +67,7 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = None):
                 {
                     "generated": generated_text[i],
                     "target": target_text[i],
-                    "original": batch["original_data"][i],
+                    "original": str(origianl[i]),
                 }
             )
         import json
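The per-sample result record now stores the original dataset row as str(origianl[i]); a plausible reason is that the raw row (PIL image included) is not JSON-serializable, while its string form is. A tiny sketch of that with made-up values (only the field names come from the hunk):

    import json

    record = {
        "generated": "West Virginia",
        "target": "West Virginia",
        "original": str({"image": "<PIL.Image>", "question": "Which of these states is farthest north?"}),
    }
    print(json.dumps(record))  # serializes cleanly because every value is a plain string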