From 76a9c3033967da667626cd6d4ed6866e9a57bfe0 Mon Sep 17 00:00:00 2001
From: YunyaoZhou
Date: Tue, 7 Jan 2025 13:58:09 +0800
Subject: [PATCH] =?UTF-8?q?fix=F0=9F=90=9B:=20=E6=9B=B4=E6=96=B0requiremen?=
 =?UTF-8?q?ts.txt=E4=B8=AD=E7=9A=84transformers=E7=89=88=E6=9C=AC=EF=BC=8C?=
 =?UTF-8?q?=E7=A7=BB=E9=99=A4OCRVQADataset=E6=95=B0=E6=8D=AE=E9=9B=86?=
 =?UTF-8?q?=E7=9A=84=E6=A0=B7=E6=9C=AC=E9=99=90=E5=88=B6=E5=B9=B6=E4=BC=98?=
 =?UTF-8?q?=E5=8C=96=E5=9B=BE=E5=83=8F=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 requirements.txt                     |  2 +-
 src/dataset_library/OCRVQADataset.py | 19 ++++++++++++++-----
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index c09c604..aea51d2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,5 +6,5 @@ pillow==10.2.0
 torch==2.5.1+cu124
 torchaudio==2.5.1+cu124
 torchvision==0.20.1+cu124
-transformers==4.47.1
+transformers==4.46.1
 trl==0.13.0
diff --git a/src/dataset_library/OCRVQADataset.py b/src/dataset_library/OCRVQADataset.py
index 137f78a..415c3fc 100644
--- a/src/dataset_library/OCRVQADataset.py
+++ b/src/dataset_library/OCRVQADataset.py
@@ -17,9 +17,9 @@ class OCRVQADataset(Dataset):
         self.vis_processor = vis_processor
         self.text_processor = text_processor
         if split == "train":
-            self.data = self.create_data(ann_path, split=1)[:200]
+            self.data = self.create_data(ann_path, split=1)
         elif split == "test":
-            self.data = self.create_data(ann_path, split=3)[:200]
+            self.data = self.create_data(ann_path, split=3)
 
         # self.instruction_pool = [
         #     "[vqa] {}",
@@ -55,9 +55,18 @@ class OCRVQADataset(Dataset):
 
     def __getitem__(self, index):
         sample = self.data[index]
-        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
-            "RGB"
-        )
+        image: Image.Image = Image.open(
+            os.path.join(self.vis_root, sample["image_path"])
+        ).convert("RGB")
+        # resize image
+        width, height = image.size
+        if width > 500 or height > 500:
+            max_size = max(width, height)
+            ratio = 500 / max_size
+            new_width = int(width * ratio)
+            new_height = int(height * ratio)
+            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
+
         question = sample["question"]
         answer = sample["answer"]
         if self.vis_processor is not None: