添加PEFT库的初始化文件,更新数据集导入路径,修改训练脚本以支持新的PEFT类型和配置,新增持续学习模型配置类,添加PEFT类型枚举,更新评估和训练逻辑以适应新结构
This commit is contained in:
@@ -0,0 +1,169 @@
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
class OCRVQADataset(Dataset):
|
||||
def __init__(
|
||||
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
|
||||
):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
self.vis_root = vis_root
|
||||
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
if split == "train":
|
||||
self.data = self.create_data(ann_path, split=1)[:1]
|
||||
elif split == "test":
|
||||
self.data = self.create_data(ann_path, split=3)[:1]
|
||||
|
||||
# self.instruction_pool = [
|
||||
# "[vqa] {}",
|
||||
# "[vqa] Based on the image, respond to this question with a short answer: {}",
|
||||
# ]
|
||||
|
||||
def create_data(self, ann_path, split=1):
|
||||
processed_data = []
|
||||
with open(ann_path, "r") as f:
|
||||
data = json.load(f)
|
||||
for k in data.keys():
|
||||
if data[k]["split"] != split:
|
||||
continue # 1 for training, 2 for validation, 3 for test
|
||||
ext = os.path.splitext(data[k]["imageURL"])[1]
|
||||
imageFile = k + ext
|
||||
assert len(data[k]["questions"]) == len(data[k]["answers"])
|
||||
for q, a in zip(data[k]["questions"], data[k]["answers"]):
|
||||
if os.path.exists(os.path.join(self.vis_root, imageFile)):
|
||||
processed_data.append(
|
||||
{
|
||||
"question": q,
|
||||
"answer": a,
|
||||
"image_path": imageFile,
|
||||
"image_id": k,
|
||||
"title": data[k]["title"],
|
||||
"genre": data[k]["genre"],
|
||||
}
|
||||
)
|
||||
return processed_data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
|
||||
"RGB"
|
||||
)
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
if self.vis_processor is not None:
|
||||
image = self.vis_processor(image)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
answer = self.text_processor(answer)
|
||||
|
||||
chat = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": answer}]},
|
||||
]
|
||||
return {
|
||||
"image": image,
|
||||
"chat": chat,
|
||||
"image_id": sample["image_id"],
|
||||
}
|
||||
|
||||
|
||||
class OCRVQADatasetForGeneration(Dataset):
|
||||
def __init__(
|
||||
self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
|
||||
):
|
||||
"""
|
||||
vis_root (string): Root directory of images (e.g. coco/images/)
|
||||
ann_root (string): directory to store the annotation file
|
||||
"""
|
||||
self.vis_root = vis_root
|
||||
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
if split == "train":
|
||||
self.data = self.create_data(ann_path, split=1)[:200]
|
||||
elif split == "test":
|
||||
self.data = self.create_data(ann_path, split=3)[:200]
|
||||
|
||||
# self.instruction_pool = [
|
||||
# "[vqa] {}",
|
||||
# "[vqa] Based on the image, respond to this question with a short answer: {}",
|
||||
# ]
|
||||
|
||||
def create_data(self, ann_path, split=1):
|
||||
processed_data = []
|
||||
with open(ann_path, "r") as f:
|
||||
data = json.load(f)
|
||||
for k in data.keys():
|
||||
if data[k]["split"] != split:
|
||||
continue # 1 for training, 2 for validation, 3 for test
|
||||
ext = os.path.splitext(data[k]["imageURL"])[1]
|
||||
imageFile = k + ext
|
||||
assert len(data[k]["questions"]) == len(data[k]["answers"])
|
||||
for q, a in zip(data[k]["questions"], data[k]["answers"]):
|
||||
if os.path.exists(os.path.join(self.vis_root, imageFile)):
|
||||
processed_data.append(
|
||||
{
|
||||
"question": q,
|
||||
"answer": a,
|
||||
"image_path": imageFile,
|
||||
"image_id": k,
|
||||
"title": data[k]["title"],
|
||||
"genre": data[k]["genre"],
|
||||
}
|
||||
)
|
||||
return processed_data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
|
||||
"RGB"
|
||||
)
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
if self.vis_processor is not None:
|
||||
image = self.vis_processor(image)
|
||||
if self.text_processor is not None:
|
||||
question = self.text_processor(question)
|
||||
answer = self.text_processor(answer)
|
||||
|
||||
chat = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
|
||||
},
|
||||
],
|
||||
}
|
||||
# {"role": "assistant", "content": answer},
|
||||
]
|
||||
return {
|
||||
"image": image,
|
||||
"chat": chat,
|
||||
"answer": answer,
|
||||
"image_id": sample["image_id"],
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
from torch.utils.data import Dataset
|
||||
from typing import Literal
|
||||
|
||||
|
||||
def get_dataset(
|
||||
dataset_name, base_path="/home/zyy/research/accelerate/dataset"
|
||||
) -> dict[Literal["train", "test", "generation"], Dataset]:
|
||||
dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
|
||||
if dataset_name == "OCR_VQA_200K":
|
||||
import os.path as osp
|
||||
from .OCRVQADataset import OCRVQADataset, OCRVQADatasetForGeneration
|
||||
|
||||
dataset = {
|
||||
"train": OCRVQADataset(
|
||||
osp.join(base_path, "OCR-VQA-200K/images"),
|
||||
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
|
||||
split="train",
|
||||
),
|
||||
"test": OCRVQADataset(
|
||||
osp.join(base_path, "OCR-VQA-200K/images"),
|
||||
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
|
||||
split="test",
|
||||
),
|
||||
"generation": OCRVQADatasetForGeneration(
|
||||
osp.join(base_path, "OCR-VQA-200K/images"),
|
||||
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
|
||||
split="test",
|
||||
),
|
||||
}
|
||||
return dataset
|
||||
Reference in New Issue
Block a user