feat: refactor the model library, update the dataset-processing logic, optimize import paths, and add a new chemistry dataset class

YunyaoZhou 2025-01-07 13:43:56 +08:00
parent 5b09d27920
commit 0b71bfc617
Signed by: shujakuin
GPG Key ID: 418C3CA28E350CCF
8 changed files with 567 additions and 55 deletions

src/dataset_library/chem.py (new file, 181 lines added)

@@ -0,0 +1,181 @@
from PIL import Image
from torch.utils.data import Dataset
import json
import os


class ChemDataset(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): path to the annotation file
        """
        self.vis_root = vis_root
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(ann_path, split=1)[:200]
        elif split == "test":
            self.data = self.create_data(ann_path, split=3)[:200]

    def create_data(self, ann_path, split=1):
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
        for k in data.keys():
            if data[k]["split"] != split:
                continue  # 1 for training, 2 for validation, 3 for test
            ext = os.path.splitext(data[k]["imageURL"])[1]
            imageFile = k + ext
            assert len(data[k]["questions"]) == len(data[k]["answers"])
            for q, a in zip(data[k]["questions"], data[k]["answers"]):
                if os.path.exists(os.path.join(self.vis_root, imageFile)):
                    processed_data.append(
                        {
                            "question": q,
                            "answer": a,
                            "image_path": imageFile,
                            "image_id": k,
                            "title": data[k]["title"],
                            "genre": data[k]["genre"],
                        }
                    )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)
        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    },
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
        return {
            "image": image,
            "chat": chat,
            "image_id": sample["image_id"],
        }


class OCRVQADatasetForGeneration(Dataset):
    def __init__(
        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
    ):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_path (string): path to the annotation file
        """
        self.vis_root = vis_root
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        if split == "train":
            self.data = self.create_data(ann_path, split=1)[:200]
        elif split == "test":
            self.data = self.create_data(ann_path, split=3)[:200]
        # self.instruction_pool = [
        #     "[vqa] {}",
        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
        # ]

    def create_data(self, ann_path, split=1):
        processed_data = []
        with open(ann_path, "r") as f:
            data = json.load(f)
        for k in data.keys():
            if data[k]["split"] != split:
                continue  # 1 for training, 2 for validation, 3 for test
            ext = os.path.splitext(data[k]["imageURL"])[1]
            imageFile = k + ext
            assert len(data[k]["questions"]) == len(data[k]["answers"])
            for q, a in zip(data[k]["questions"], data[k]["answers"]):
                if os.path.exists(os.path.join(self.vis_root, imageFile)):
                    processed_data.append(
                        {
                            "question": q,
                            "answer": a,
                            "image_path": imageFile,
                            "image_id": k,
                            "title": data[k]["title"],
                            "genre": data[k]["genre"],
                        }
                    )
        return processed_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
            "RGB"
        )
        question = sample["question"]
        answer = sample["answer"]
        if self.vis_processor is not None:
            image = self.vis_processor(image)
        if self.text_processor is not None:
            question = self.text_processor(question)
            answer = self.text_processor(answer)
        chat = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
                    },
                ],
            }
            # {"role": "assistant", "content": answer},
        ]
        return {
            "image": image,
            "chat": chat,
            "answer": answer,
            "image_id": sample["image_id"],
        }


if __name__ == "__main__":
    dataset = ChemDataset(
        "/home/zyy/research/accelerate/dataset/chem/images/",
        "/home/zyy/research/accelerate/dataset/chem/qwen_data/conversations_loc_train.jsonl",
        split="train",
    )
    print(len(dataset))
    print(dataset[0])

    dataset = OCRVQADatasetForGeneration(
        "/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images",
        "/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json",
        split="train",
    )
    print(len(dataset))
    print(dataset[0])
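
For context, here is a minimal usage sketch (not part of this commit) showing how a sample returned by `ChemDataset` plugs into the Qwen2-VL processor. The checkpoint name is an assumption, and the import path mirrors how the other dataset modules in this repo are imported:

```python
from transformers import Qwen2VLProcessor

from dataset_library.chem import ChemDataset  # assumed import path for src/dataset_library/chem.py

processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
dataset = ChemDataset(
    "/home/zyy/research/accelerate/dataset/chem/images/",
    "/home/zyy/research/accelerate/dataset/chem/qwen_data/conversations_loc_train.jsonl",
    split="train",
)
sample = dataset[0]

# The "chat" field already follows the chat-template format the processor expects,
# so turning one sample into model-ready tensors takes two calls:
text = processor.apply_chat_template(sample["chat"], tokenize=False)
inputs = processor(text=[text], images=[sample["image"]], return_tensors="pt", padding=True)
print(inputs["input_ids"].shape, inputs["pixel_values"].shape)
```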

View File

@@ -51,7 +51,7 @@ if __name__ == "__main__":
    print(model)
    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        from collatefn_library.qwen2 import (
        from model_library.qwen2 import (
            collate_fn_for_train,
            collate_fn_for_evaluate,
        )

View File

@@ -0,0 +1,4 @@
from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
from .model import Qwen2VLForConditionalGeneration_modified
__all__ = ["collate_fn_for_train", "collate_fn_for_evaluate", "Qwen2VLForConditionalGeneration_modified"]

View File

@@ -10,7 +10,6 @@ def collate_fn_for_train(examples, processor: Qwen2VLProcessor):
    ]
    # print(texts)
    images = [example["image"] for example in examples]
    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
@@ -52,7 +51,7 @@ def collate_fn_for_train(examples, processor: Qwen2VLProcessor):
                now_index += 1
            now_index += 1
    batch["labels"] = labels
    # batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)
    batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)
    return batch
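
The functional change in this hunk is that the train collate function now always attaches a `task_id` entry to the batch. As a quick illustration (not from the diff; the label tensor below is a stand-in for the padded label matrix built above), the entry is a 1-D `LongTensor` with one id per sample, all zeros here because only a single task is registered:

```python
import torch

labels = torch.full((3, 12), -100)  # stand-in for the padded label matrix
task_id = torch.tensor([0] * len(labels), dtype=torch.long)

assert task_id.shape == (labels.shape[0],)  # one id per sample in the batch
print(task_id)  # tensor([0, 0, 0])
```

Presumably a multi-task setup would assign a different id per dataset so the trainer can route samples (or adapters) by task.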
@@ -82,46 +81,3 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor):
    # answers_ids torch.Size([3, 10])
    # answers_mask torch.Size([3, 10])
    return batch


if __name__ == "__main__":
    from transformers import Qwen2VLProcessor
    from dataset_library.OCRVQADataset import OCRVQADatasetForGeneration

    processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    dataset = OCRVQADatasetForGeneration(
        vis_root="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images",
        ann_path="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json",
        split="train",
    )
    examples = [dataset[i] for i in range(3)]
    # print(collate_fn_for_evaluate(examples, processor))

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

    from accelerate import Accelerator
    from torch.utils.data import DataLoader

    val_dataloader = DataLoader(
        dataset,
        batch_size=3,
        collate_fn=lambda x: collate_fn_for_evaluate(x, processor),
    )
    accelerator = Accelerator()
    model = accelerator.prepare(model)
    val_dataloader = accelerator.prepare(val_dataloader)

    for batch in val_dataloader:
        completion = model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            pixel_values=batch["pixel_values"],
            image_grid_thw=batch["image_grid_thw"],
            max_length=100,
        )
        target = batch["answers_ids"]
        generated_text = [
            out_ids[len(in_ids) :]
            for out_ids, in_ids in zip(completion, batch["input_ids"])
        ]
        generated_text = processor.tokenizer.batch_decode(generated_text)
        target_text = processor.tokenizer.batch_decode(target)
        print(generated_text, target_text)

View File

@@ -0,0 +1,370 @@
# from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLModel, Qwen2VLForConditionalGeneration, logger, DynamicCache, Qwen2VLDecoderLayer, Qwen2VLConfig, Qwen2VLAttention,
from transformers.models.qwen2_vl.modeling_qwen2_vl import *
from transformers.modeling_outputs import BaseModelOutputWithPast
import torch
from typing import Optional, List, Union, Tuple
from torch.nn import CrossEntropyLoss


class Qwen2VLAttention_modified(Qwen2VLAttention):
    QWEN2_VL_ATTENTION_CLASSES = {
        "eager": Qwen2VLAttention,
        "flash_attention_2": Qwen2VLFlashAttention2,
        "sdpa": Qwen2VLSdpaAttention,
    }


class Qwen2VLDecoderLayer_modified(Qwen2VLDecoderLayer):
    def __init__(self, config: Qwen2VLConfig, layer_idx: int):
        super().__init__(config, layer_idx)


class Qwen2VLModel_modified(Qwen2VLModel):
    def __init__(self, config):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [
                Qwen2VLDecoderLayer_modified(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device=inputs_embeds.device,
            )

        # the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(
                3, inputs_embeds.shape[0], -1
            )
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values,
            output_attentions,
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                    **kwargs,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class Qwen2VLForConditionalGeneration_modified(Qwen2VLForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2VLModel_modified(config)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

        >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

        >>> messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
        ```"""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.get_dtype())
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                image_mask = (
                    (input_ids == self.config.image_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )
                video_mask = (
                    (input_ids == self.config.video_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                video_embeds = video_embeds.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if (
            position_ids is None
            and input_ids is not None
            and (attention_mask is None or attention_mask.ndim == 2)
        ):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                cache_position is not None and cache_position[0] == 0
            ) or self.rope_deltas is None:
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids, image_grid_thw, video_grid_thw, attention_mask
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    cache_position[0] + self.rope_deltas
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Qwen2VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )
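
Because `Qwen2VLForConditionalGeneration_modified` only swaps in the modified decoder stack (and threads `**kwargs` through `forward`), it should behave as a drop-in replacement for the stock class. A hedged sanity-check sketch, not part of the commit; the package path is taken from the training-script hunk below and the sample image is the one from the docstring above:

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor

from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
)

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    }
]
image = Image.open(
    requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw
)
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[image], return_tensors="pt").to(model.device)

generate_ids = model.generate(**inputs, max_new_tokens=30)
print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])
```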

View File

@@ -53,14 +53,16 @@ if __name__ == "__main__":
        padding_side="left",
    )
    model = AutoModelForVision2Seq.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )
    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        from collatefn_library.qwen2 import (
        # from transformers import Qwen2VLForConditionalGeneration
        from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
        model = Qwen2VLForConditionalGeneration_modified.from_pretrained(
            model_args.model_name_or_path,
            trust_remote_code=model_args.trust_remote_code,
            **model_kwargs,
        )
        from model_library.qwen2vl import (
            collate_fn_for_train,
            collate_fn_for_evaluate,
        )

View File

@@ -1,7 +1,6 @@
from dataclasses import dataclass, field
from typing import Optional
from trl import ScriptArguments, ModelConfig
from transformers import TrainingArguments
@dataclass