Merge branch 'bugfix/ocr_dataset' into develop
This commit is contained in:
commit
10f532618e
@@ -6,5 +6,5 @@ pillow==10.2.0
|
||||
torch==2.5.1+cu124
|
||||
torchaudio==2.5.1+cu124
|
||||
torchvision==0.20.1+cu124
|
||||
transformers==4.47.1
|
||||
transformers==4.46.1
|
||||
trl==0.13.0
|
||||
|
@@ -17,9 +17,9 @@ class OCRVQADataset(Dataset):
|
||||
self.vis_processor = vis_processor
|
||||
self.text_processor = text_processor
|
||||
if split == "train":
|
||||
self.data = self.create_data(ann_path, split=1)[:200]
|
||||
self.data = self.create_data(ann_path, split=1)
|
||||
elif split == "test":
|
||||
self.data = self.create_data(ann_path, split=3)[:200]
|
||||
self.data = self.create_data(ann_path, split=3)
|
||||
|
||||
# self.instruction_pool = [
|
||||
# "[vqa] {}",
|
||||
@@ -55,9 +55,18 @@ class OCRVQADataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.data[index]
|
||||
image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
|
||||
"RGB"
|
||||
)
|
||||
image: Image.Image = Image.open(
|
||||
os.path.join(self.vis_root, sample["image_path"])
|
||||
).convert("RGB")
|
||||
# resize image
|
||||
width, height = image.size
|
||||
if width > 500 or height > 500:
|
||||
max_size = max(width, height)
|
||||
ratio = 500 / max_size
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
|
||||
|
||||
question = sample["question"]
|
||||
answer = sample["answer"]
|
||||
if self.vis_processor is not None:
|
||||
|
Loading…
Reference in New Issue
Block a user