diff --git a/requirements.txt b/requirements.txt index c09c604..aea51d2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,5 +6,5 @@ pillow==10.2.0 torch==2.5.1+cu124 torchaudio==2.5.1+cu124 torchvision==0.20.1+cu124 -transformers==4.47.1 +transformers==4.46.1 trl==0.13.0 diff --git a/src/dataset_library/OCRVQADataset.py b/src/dataset_library/OCRVQADataset.py index 137f78a..415c3fc 100644 --- a/src/dataset_library/OCRVQADataset.py +++ b/src/dataset_library/OCRVQADataset.py @@ -17,9 +17,9 @@ class OCRVQADataset(Dataset): self.vis_processor = vis_processor self.text_processor = text_processor if split == "train": - self.data = self.create_data(ann_path, split=1)[:200] + self.data = self.create_data(ann_path, split=1) elif split == "test": - self.data = self.create_data(ann_path, split=3)[:200] + self.data = self.create_data(ann_path, split=3) # self.instruction_pool = [ # "[vqa] {}", @@ -55,9 +55,18 @@ class OCRVQADataset(Dataset): def __getitem__(self, index): sample = self.data[index] - image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert( - "RGB" - ) + image: Image.Image = Image.open( + os.path.join(self.vis_root, sample["image_path"]) + ).convert("RGB") + # resize image + width, height = image.size + if width > 500 or height > 500: + max_size = max(width, height) + ratio = 500 / max_size + new_width = int(width * ratio) + new_height = int(height * ratio) + image = image.resize((new_width, new_height), Image.Resampling.BILINEAR) + question = sample["question"] answer = sample["answer"] if self.vis_processor is not None: