diff --git a/src/dataset_library/OCRVQADataset.py b/src/dataset_library/OCRVQADataset.py index 415c3fc..58043c9 100644 --- a/src/dataset_library/OCRVQADataset.py +++ b/src/dataset_library/OCRVQADataset.py @@ -66,6 +66,13 @@ class OCRVQADataset(Dataset): new_width = int(width * ratio) new_height = int(height * ratio) image = image.resize((new_width, new_height), Image.Resampling.BILINEAR) + + if width < 28 or height < 28: + min_size = min(width, height) + ratio = 28 / min_size + 1 + new_width = int(width * ratio) + new_height = int(height * ratio) + image = image.resize((new_width, new_height), Image.Resampling.BILINEAR) question = sample["question"] answer = sample["answer"]