feat✨: Add CHEM dataset support and optimize image-processing logic
This commit is contained in:
parent 8d6e5d5416
commit b766c21c9b
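Summary: this commit adds a CHEM dataset (new src/dataset_library/CHEM.py with CHEMDataset and CHEMDatasetForGeneration), registers it in get_dataset, moves the OCRVQADataset resizing logic into a reusable default _vis_processor, and points the training script at the CHEM dataset. Below is a minimal usage sketch of the new classes, inferred only from the __main__ block and the get_dataset branch in this diff; the paths are illustrative and src/ is assumed to be on PYTHONPATH.

from dataset_library.CHEM import CHEMDataset, CHEMDatasetForGeneration

# Both classes take an image directory and the directory holding the
# conversations_loc_{train,val}.jsonl annotation files.
train_ds = CHEMDataset(
    "dataset/chem/images",     # vis_root (illustrative path)
    "dataset/chem/qwen_data",  # ann_path (illustrative path)
    split="train",
)
sample = train_ds[0]
print(len(train_ds))
print(sample["chat"])          # system / user / assistant turns for SFT

gen_ds = CHEMDatasetForGeneration(
    "dataset/chem/images",
    "dataset/chem/qwen_data",
    split="test",
)
print(gen_ds[0]["answer"])     # reference answer kept out of the chat for generation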
@@ -8,3 +8,4 @@ torchaudio==2.5.1+cu124
 torchvision==0.20.1+cu124
 transformers==4.46.1
 trl==0.13.0
+pillow==9.5.0
src/dataset_library/CHEM.py (new file, 171 lines)
@@ -0,0 +1,171 @@
from PIL import Image
from torch.utils.data import Dataset
import json
import os


class CHEMDataset(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root

        self.vis_processor = (
            vis_processor if vis_processor is not None else self._vis_processor
        )
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(ann_path, split="train")
        elif split == "test":
            self.data = self.create_data(ann_path, split="test")

    def _vis_processor(self, image: Image.Image):
        width, height = image.size
        if width > 800 or height > 800:
            max_size = max(width, height)
            ratio = 800 / max_size
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)

        if width < 28 or height < 28:
            min_size = min(width, height)
            ratio = 28 / min_size + 1
            new_width = int(width * ratio)
            new_height = int(height * ratio)
            image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)

        return image

    def create_data(self, ann_path, split=1):
        import os.path as osp

        if split == "train":
            json_path = osp.join(ann_path, "conversations_loc_train.jsonl")
        elif split == "test":
            json_path = osp.join(ann_path, "conversations_loc_val.jsonl")
        processed_data = []
        with open(json_path, "r") as f:
            for line in f:
                data = json.loads(line)
                image_file = osp.join(self.vis_root, data["images"][0].split("/")[-1])
                conversation = data["messages"]
                system = conversation[0]["content"]
                query = conversation[1]["content"]
                answer = conversation[2]["content"]
                processed_data.append(
                    {
                        "question": query,
                        "answer": answer,
                        "image_path": image_file,
                        "system": system,
                    }
                )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        chat = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": sample["system"],
                    },
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] {question.replace('<image>','')}",
                    },
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
        return {
            "image": image,
            "chat": chat,
        }


class CHEMDatasetForGeneration(CHEMDataset):
    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        chat = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": sample["system"],
                    },
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] {question.replace('<image>','')}",
                    },
                ],
            },
        ]
        return {
            "image": image,
            "chat": chat,
            "answer": answer,
        }


if __name__ == "__main__":
    dataset = CHEMDataset(
        "/home/zyy/research/accelerate/dataset/chem/images",
        "/home/zyy/research/accelerate/dataset/chem/qwen_data",
        split="train",
    )
    print(len(dataset))
    print(dataset[0])
    dataset = CHEMDatasetForGeneration(
        "/home/zyy/research/accelerate/dataset/chem/images",
        "/home/zyy/research/accelerate/dataset/chem/qwen_data",
        split="train",
    )
    print(len(dataset))
    print(dataset[0])
    pass
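A note on the annotation format create_data above expects: each line of conversations_loc_train.jsonl / conversations_loc_val.jsonl is a JSON object with an "images" list and a three-turn "messages" list (system, user, assistant). The exact schema is not part of this diff; the sketch below is a hypothetical line constructed to match only the fields the parser reads.

import json

# Hypothetical JSONL record; field values are invented for illustration.
line = json.dumps({
    "images": ["chem/images/mol_000123.png"],
    "messages": [
        {"role": "system", "content": "You are a chemistry assistant."},
        {"role": "user", "content": "<image>Where is the carboxyl group located?"},
        {"role": "assistant", "content": "At the end of the carbon chain."},
    ],
})

data = json.loads(line)
print(data["images"][0].split("/")[-1])  # -> mol_000123.png, joined onto vis_root
print(data["messages"][1]["content"])    # user query; '<image>' is stripped in __getitem__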
@@ -14,7 +14,9 @@ class OCRVQADataset(Dataset):
         """
         self.vis_root = vis_root

-        self.vis_processor = vis_processor
+        self.vis_processor = (
+            vis_processor if vis_processor is not None else self._vis_processor
+        )
         self.text_processor = text_processor
         if split == "train":
             self.data = self.create_data(ann_path, split=1)
@@ -53,12 +55,7 @@ class OCRVQADataset(Dataset):
     def __len__(self):
         return len(self.data)

-    def __getitem__(self, index):
-        sample = self.data[index]
-        image: Image.Image = Image.open(
-            os.path.join(self.vis_root, sample["image_path"])
-        ).convert("RGB")
-        # resize image
+    def _vis_processor(self, image: Image.Image):
         width, height = image.size
         if width > 500 or height > 500:
             max_size = max(width, height)
@@ -66,7 +63,7 @@ class OCRVQADataset(Dataset):
             new_width = int(width * ratio)
             new_height = int(height * ratio)
             image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)

         if width < 28 or height < 28:
             min_size = min(width, height)
             ratio = 28 / min_size + 1
@@ -74,6 +71,15 @@ class OCRVQADataset(Dataset):
             new_height = int(height * ratio)
             image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)

+        return image
+
+    def __getitem__(self, index):
+        sample = self.data[index]
+        image: Image.Image = Image.open(
+            os.path.join(self.vis_root, sample["image_path"])
+        ).convert("RGB")
+        # resize image
+
         question = sample["question"]
         answer = sample["answer"]
         if self.vis_processor is not None:
@@ -98,64 +104,17 @@ class OCRVQADataset(Dataset):
         return {
             "image": image,
             "chat": chat,
-            "image_id": sample["image_id"],
         }


 class OCRVQADatasetForGeneration(Dataset):
-    def __init__(
-        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
-    ):
-        """
-        vis_root (string): Root directory of images (e.g. coco/images/)
-        ann_root (string): directory to store the annotation file
-        """
-        self.vis_root = vis_root
-
-        self.vis_processor = vis_processor
-        self.text_processor = text_processor
-        if split == "train":
-            self.data = self.create_data(ann_path, split=1)[:200]
-        elif split == "test":
-            self.data = self.create_data(ann_path, split=3)[:200]
-
-        # self.instruction_pool = [
-        #     "[vqa] {}",
-        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
-        # ]
-
-    def create_data(self, ann_path, split=1):
-        processed_data = []
-        with open(ann_path, "r") as f:
-            data = json.load(f)
-            for k in data.keys():
-                if data[k]["split"] != split:
-                    continue  # 1 for training, 2 for validation, 3 for test
-                ext = os.path.splitext(data[k]["imageURL"])[1]
-                imageFile = k + ext
-                assert len(data[k]["questions"]) == len(data[k]["answers"])
-                for q, a in zip(data[k]["questions"], data[k]["answers"]):
-                    if os.path.exists(os.path.join(self.vis_root, imageFile)):
-                        processed_data.append(
-                            {
-                                "question": q,
-                                "answer": a,
-                                "image_path": imageFile,
-                                "image_id": k,
-                                "title": data[k]["title"],
-                                "genre": data[k]["genre"],
-                            }
-                        )
-        return processed_data
-
-    def __len__(self):
-        return len(self.data)
-
     def __getitem__(self, index):
         sample = self.data[index]
         image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
             "RGB"
         )
+        # resize image
         question = sample["question"]
         answer = sample["answer"]
         if self.vis_processor is not None:
@@ -181,5 +140,4 @@ class OCRVQADatasetForGeneration(Dataset):
             "image": image,
             "chat": chat,
             "answer": answer,
-            "image_id": sample["image_id"],
         }
@@ -1,181 +0,0 @@
from PIL import Image
from torch.utils.data import Dataset
import json
import os


class ChemDataseet(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(ann_path, split=1)[:200]
        elif split == "test":
            self.data = self.create_data(ann_path, split=3)[:200]

    def create_data(self, ann_path, split=1):
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
            for k in data.keys():
                if data[k]["split"] != split:
                    continue  # 1 for training, 2 for validation, 3 for test
                ext = os.path.splitext(data[k]["imageURL"])[1]
                imageFile = k + ext
                assert len(data[k]["questions"]) == len(data[k]["answers"])
                for q, a in zip(data[k]["questions"], data[k]["answers"]):
                    if os.path.exists(os.path.join(self.vis_root, imageFile)):
                        processed_data.append(
                            {
                                "question": q,
                                "answer": a,
                                "image_path": imageFile,
                                "image_id": k,
                                "title": data[k]["title"],
                                "genre": data[k]["genre"],
                            }
                        )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    },
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
        return {
            "image": image,
            "chat": chat,
            "image_id": sample["image_id"],
        }


class OCRVQADatasetForGeneration(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        """
        self.vis_root = vis_root

        self.vis_processor = vis_processor
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(ann_path, split=1)[:200]
        elif split == "test":
            self.data = self.create_data(ann_path, split=3)[:200]

        # self.instruction_pool = [
        #     "[vqa] {}",
        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
        # ]

    def create_data(self, ann_path, split=1):
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
            for k in data.keys():
                if data[k]["split"] != split:
                    continue  # 1 for training, 2 for validation, 3 for test
                ext = os.path.splitext(data[k]["imageURL"])[1]
                imageFile = k + ext
                assert len(data[k]["questions"]) == len(data[k]["answers"])
                for q, a in zip(data[k]["questions"], data[k]["answers"]):
                    if os.path.exists(os.path.join(self.vis_root, imageFile)):
                        processed_data.append(
                            {
                                "question": q,
                                "answer": a,
                                "image_path": imageFile,
                                "image_id": k,
                                "title": data[k]["title"],
                                "genre": data[k]["genre"],
                            }
                        )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)

        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    },
                ],
            }
            # {"role": "assistant", "content": answer},
        ]
        return {
            "image": image,
            "chat": chat,
            "answer": answer,
            "image_id": sample["image_id"],
        }


if __name__ == "__main__":
    dataset = ChemDataseet(
        "/home/zyy/research/accelerate/dataset/chem/images/",
        "/home/zyy/research/accelerate/dataset/chem/qwen_data/conversations_loc_train.jsonl",
        split="train",
    )
    print(len(dataset))
    print(dataset[0])
    dataset = OCRVQADatasetForGeneration(
        "/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images",
        "/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json",
        split="train",
    )
    print(len(dataset))
    print(dataset[0])
    pass
@@ -27,4 +27,25 @@ def get_dataset(
             split="test",
         ),
     }
+    if dataset_name == "CHEM":
+        import os.path as osp
+        from .CHEM import CHEMDataset, CHEMDatasetForGeneration
+
+        dataset = {
+            "train": CHEMDataset(
+                osp.join(base_path, "chem/images"),
+                osp.join(base_path, "chem/qwen_data"),
+                split="train",
+            ),
+            "test": CHEMDataset(
+                osp.join(base_path, "chem/images"),
+                osp.join(base_path, "chem/qwen_data"),
+                split="test",
+            ),
+            "generation": CHEMDatasetForGeneration(
+                osp.join(base_path, "chem/images"),
+                osp.join(base_path, "chem/qwen_data"),
+                split="test",
+            ),
+        }
     return dataset
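With the branch above, callers can presumably request the new dataset by name. A hedged sketch of that call, assuming the factory keeps the get_dataset(dataset_name, base_path) signature implied by the hunk (the module path below is an assumption, not shown in this diff):

from dataset_library.factory import get_dataset  # assumed module path

splits = get_dataset("CHEM", base_path="dataset")
train_set = splits["train"]        # CHEMDataset with chat-formatted samples
gen_set = splits["generation"]     # CHEMDatasetForGeneration, keeps "answer" separately
print(len(train_set), len(gen_set))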
@@ -13,7 +13,9 @@ from utils.args import ContinualScriptArguments, ContinualModelConfig


 if __name__ == "__main__":
-    parser = TrlParser((ContinualScriptArguments, TrainingArguments, ContinualModelConfig))
+    parser = TrlParser(
+        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
+    )
     script_args, training_args, model_args = parser.parse_args_and_config()
     # for type hint
     if 0 == 1:
@@ -48,10 +50,9 @@ if __name__ == "__main__":
         trust_remote_code=model_args.trust_remote_code,
         **model_kwargs,
     )
-    print(model)

     if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
-        from model_library.qwen2 import (
+        from model_library.qwen2vl import (
             collate_fn_for_train,
             collate_fn_for_evaluate,
         )
@@ -1,7 +1,7 @@
 #!/bin/bash

 accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
-    --dataset_name OCR_VQA_200K OCR_VQA_200K OCR_VQA_200K \
+    --dataset_name CHEM \
     --use_peft \
     --peft_type MMOELORA \
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \