Merge branch 'bugfix/ocr_dataset' into develop

This commit is contained in:
YunyaoZhou 2025-01-07 13:58:26 +08:00
commit 10f532618e
Signed by: shujakuin
GPG Key ID: 418C3CA28E350CCF
2 changed files with 15 additions and 6 deletions

View File

@ -6,5 +6,5 @@ pillow==10.2.0
torch==2.5.1+cu124 torch==2.5.1+cu124
torchaudio==2.5.1+cu124 torchaudio==2.5.1+cu124
torchvision==0.20.1+cu124 torchvision==0.20.1+cu124
transformers==4.47.1 transformers==4.46.1
trl==0.13.0 trl==0.13.0

View File

@ -17,9 +17,9 @@ class OCRVQADataset(Dataset):
self.vis_processor = vis_processor self.vis_processor = vis_processor
self.text_processor = text_processor self.text_processor = text_processor
if split == "train": if split == "train":
self.data = self.create_data(ann_path, split=1)[:200] self.data = self.create_data(ann_path, split=1)
elif split == "test": elif split == "test":
self.data = self.create_data(ann_path, split=3)[:200] self.data = self.create_data(ann_path, split=3)
# self.instruction_pool = [ # self.instruction_pool = [
# "[vqa] {}", # "[vqa] {}",
@ -55,9 +55,18 @@ class OCRVQADataset(Dataset):
def __getitem__(self, index): def __getitem__(self, index):
sample = self.data[index] sample = self.data[index]
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert( image: Image.Image = Image.open(
"RGB" os.path.join(self.vis_root, sample["image_path"])
) ).convert("RGB")
# resize image
width, height = image.size
if width > 500 or height > 500:
max_size = max(width, height)
ratio = 500 / max_size
new_width = int(width * ratio)
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
question = sample["question"] question = sample["question"]
answer = sample["answer"] answer = sample["answer"]
if self.vis_processor is not None: if self.vis_processor is not None: