From 0b71bfc617201db4f0a6259dd54467a21e62958b Mon Sep 17 00:00:00 2001 From: YunyaoZhou Date: Tue, 7 Jan 2025 13:43:56 +0800 Subject: [PATCH 1/2] =?UTF-8?q?feat=E2=9C=A8:=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=BA=93=EF=BC=8C=E6=9B=B4=E6=96=B0=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E9=9B=86=E5=A4=84=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=AF=BC=E5=85=A5=E8=B7=AF=E5=BE=84=EF=BC=8C?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=96=B0=E7=9A=84=E5=8C=96=E5=AD=A6=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E9=9B=86=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/dataset_library/chem.py | 181 +++++++++ src/evaluation.py | 2 +- .../__init__.py | 0 src/model_library/qwen2vl/__init__.py | 4 + .../qwen2vl/collate_fn.py} | 48 +-- src/model_library/qwen2vl/model.py | 370 ++++++++++++++++++ src/train.py | 16 +- src/utils/args.py | 1 - 8 files changed, 567 insertions(+), 55 deletions(-) create mode 100644 src/dataset_library/chem.py rename src/{collatefn_library => model_library}/__init__.py (100%) create mode 100644 src/model_library/qwen2vl/__init__.py rename src/{collatefn_library/qwen2.py => model_library/qwen2vl/collate_fn.py} (65%) create mode 100644 src/model_library/qwen2vl/model.py diff --git a/src/dataset_library/chem.py b/src/dataset_library/chem.py new file mode 100644 index 0000000..a1999cb --- /dev/null +++ b/src/dataset_library/chem.py @@ -0,0 +1,181 @@ +from PIL import Image +from torch.utils.data import Dataset +import json +import os + + +class ChemDataseet(Dataset): + def __init__( + self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train" + ): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + if split == "train": + self.data = self.create_data(ann_path, split=1)[:200] + elif split == "test": + self.data = self.create_data(ann_path, split=3)[:200] + + def create_data(self, ann_path, split=1): + processed_data = [] + with open(ann_path, "r") as f: + data = json.load(f) + for k in data.keys(): + if data[k]["split"] != split: + continue # 1 for training, 2 for validation, 3 for test + ext = os.path.splitext(data[k]["imageURL"])[1] + imageFile = k + ext + assert len(data[k]["questions"]) == len(data[k]["answers"]) + for q, a in zip(data[k]["questions"], data[k]["answers"]): + if os.path.exists(os.path.join(self.vis_root, imageFile)): + processed_data.append( + { + "question": q, + "answer": a, + "image_path": imageFile, + "image_id": k, + "title": data[k]["title"], + "genre": data[k]["genre"], + } + ) + return processed_data + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + sample = self.data[index] + image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert( + "RGB" + ) + question = sample["question"] + answer = sample["answer"] + if self.vis_processor is not None: + image = self.vis_processor(image) + if self.text_processor is not None: + question = self.text_processor(question) + answer = self.text_processor(answer) + + chat = [ + { + "role": "user", + "content": [ + {"type": "image"}, + { + "type": "text", + "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}", + }, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": answer}]}, + ] + return { + "image": image, + "chat": chat, + "image_id": sample["image_id"], + } + + +class OCRVQADatasetForGeneration(Dataset): + def __init__( + self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train" + ): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + if split == "train": + self.data = self.create_data(ann_path, split=1)[:200] + elif split == "test": + self.data = self.create_data(ann_path, split=3)[:200] + + # self.instruction_pool = [ + # "[vqa] {}", + # "[vqa] Based on the image, respond to this question with a short answer: {}", + # ] + + def create_data(self, ann_path, split=1): + processed_data = [] + with open(ann_path, "r") as f: + data = json.load(f) + for k in data.keys(): + if data[k]["split"] != split: + continue # 1 for training, 2 for validation, 3 for test + ext = os.path.splitext(data[k]["imageURL"])[1] + imageFile = k + ext + assert len(data[k]["questions"]) == len(data[k]["answers"]) + for q, a in zip(data[k]["questions"], data[k]["answers"]): + if os.path.exists(os.path.join(self.vis_root, imageFile)): + processed_data.append( + { + "question": q, + "answer": a, + "image_path": imageFile, + "image_id": k, + "title": data[k]["title"], + "genre": data[k]["genre"], + } + ) + return processed_data + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + sample = self.data[index] + image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert( + "RGB" + ) + question = sample["question"] + answer = sample["answer"] + if self.vis_processor is not None: + image = self.vis_processor(image) + if self.text_processor is not None: + question = self.text_processor(question) + answer = self.text_processor(answer) + + chat = [ + { + "role": "user", + "content": [ + {"type": "image"}, + { + "type": "text", + "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}", + }, + ], + } + # {"role": "assistant", "content": answer}, + ] + return { + "image": image, + "chat": chat, + "answer": answer, + "image_id": sample["image_id"], + } + +if __name__ == "__main__": + dataset = ChemDataseet( + "/home/zyy/research/accelerate/dataset/chem/images/", + "/home/zyy/research/accelerate/dataset/chem/qwen_data/conversations_loc_train.jsonl", + split="train", + ) + print(len(dataset)) + print(dataset[0]) + dataset = OCRVQADatasetForGeneration( + "/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images", + "/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json", + split="train", + ) + print(len(dataset)) + print(dataset[0]) + pass \ No newline at end of file diff --git a/src/evaluation.py b/src/evaluation.py index be5b01e..48071ca 100644 --- a/src/evaluation.py +++ b/src/evaluation.py @@ -51,7 +51,7 @@ if __name__ == "__main__": print(model) if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct": - from collatefn_library.qwen2 import ( + from model_library.qwen2 import ( collate_fn_for_train, collate_fn_for_evaluate, ) diff --git a/src/collatefn_library/__init__.py b/src/model_library/__init__.py similarity index 100% rename from src/collatefn_library/__init__.py rename to src/model_library/__init__.py diff --git a/src/model_library/qwen2vl/__init__.py b/src/model_library/qwen2vl/__init__.py new file mode 100644 index 0000000..3d2bfcb --- /dev/null +++ b/src/model_library/qwen2vl/__init__.py @@ -0,0 +1,4 @@ +from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train +from .model import Qwen2VLForConditionalGeneration_modified + +__all__ = ["collate_fn_for_train", "collate_fn_for_evaluate", "Qwen2VLForConditionalGeneration_modified"] \ No newline at end of file diff --git a/src/collatefn_library/qwen2.py b/src/model_library/qwen2vl/collate_fn.py similarity index 65% rename from src/collatefn_library/qwen2.py rename to src/model_library/qwen2vl/collate_fn.py index 6ce09a7..09ee258 100644 --- a/src/collatefn_library/qwen2.py +++ b/src/model_library/qwen2vl/collate_fn.py @@ -10,7 +10,6 @@ def collate_fn_for_train(examples, processor: Qwen2VLProcessor): ] # print(texts) images = [example["image"] for example in examples] - # Tokenize the texts and process the images batch = processor(text=texts, images=images, return_tensors="pt", padding=True) @@ -52,7 +51,7 @@ def collate_fn_for_train(examples, processor: Qwen2VLProcessor): now_index += 1 now_index += 1 batch["labels"] = labels - # batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long) + batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long) return batch @@ -80,48 +79,5 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor): # pixel_values torch.Size([3888, 1176]) # image_grid_thw torch.Size([3, 3]) # answers_ids torch.Size([3, 10]) - # answers_mask torch.Size([3, 10]) + # answers_mask torch.Size([3, 10]) return batch - - -if __name__ == "__main__": - from transformers import Qwen2VLProcessor - from dataset_library.OCRVQADataset import OCRVQADatasetForGeneration - - processor = Qwen2VLProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") - dataset = OCRVQADatasetForGeneration( - vis_root="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images", - ann_path="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json", - split="train", - ) - examples = [dataset[i] for i in range(3)] - # print(collate_fn_for_evaluate(examples, processor)) - from transformers import AutoModelForCausalLM - - model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") - from accelerate import Accelerator - from torch.utils.data import DataLoader - - val_dataloader = DataLoader( - dataset, - batch_size=3, - collate_fn=lambda x: collate_fn_for_evaluate(x, processor), - ) - accelerator = Accelerator() - model = accelerator.prepare(model) - val_dataloader = accelerator.prepare(val_dataloader) - - for batch in val_dataloader: - completion = model.generate( - input_ids=batch["input_ids"], - attention_mask=batch["attention_mask"], - pixel_values=batch["pixel_values"], - image_grid_thw=batch["image_grid_thw"], - max_length=100, - ) - target = batch["answers_ids"] - generated_text = [out_ids[len(in_ids) :] for out_ids, in_ids in zip(completion, batch["input_ids"])] - generated_text = processor.tokenizer.batch_decode(generated_text) - target_text = processor.tokenizer.batch_decode(target) - print(generated_text, target_text) - diff --git a/src/model_library/qwen2vl/model.py b/src/model_library/qwen2vl/model.py new file mode 100644 index 0000000..5dfe83e --- /dev/null +++ b/src/model_library/qwen2vl/model.py @@ -0,0 +1,370 @@ +# from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLModel, Qwen2VLForConditionalGeneration, logger, DynamicCache, Qwen2VLDecoderLayer, Qwen2VLConfig, Qwen2VLAttention, +from transformers.models.qwen2_vl.modeling_qwen2_vl import * +from transformers.modeling_outputs import BaseModelOutputWithPast +import torch +from typing import Optional, List, Union, Tuple +from torch.nn import CrossEntropyLoss + +class Qwen2VLAttention_modified(Qwen2VLAttention): + +QWEN2_VL_ATTENTION_CLASSES = { + "eager": Qwen2VLAttention, + "flash_attention_2": Qwen2VLFlashAttention2, + "sdpa": Qwen2VLSdpaAttention, +} + + +class Qwen2VLDecoderLayer_modified(Qwen2VLDecoderLayer): + def __init__(self, config: Qwen2VLConfig, layer_idx: int): + super().__init__(config, layer_idx) + + +class Qwen2VLModel_modified(Qwen2VLModel): + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList( + [Qwen2VLDecoderLayer_modified(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand( + 3, inputs_embeds.shape[0], -1 + ) + elif position_ids.dim() == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + **kwargs, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class Qwen2VLForConditionalGeneration_modified(Qwen2VLForConditionalGeneration): + def __init__(self, config): + super().__init__(config) + self.model = Qwen2VLModel_modified(config) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration + + >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.get_dtype()) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if ( + position_ids is None + and input_ids is not None + and (attention_mask is None or attention_mask.ndim == 2) + ): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + cache_position is not None and cache_position[0] == 0 + ) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + cache_position[0] + self.rope_deltas + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) diff --git a/src/train.py b/src/train.py index f2d3c2f..062d4de 100644 --- a/src/train.py +++ b/src/train.py @@ -53,14 +53,16 @@ if __name__ == "__main__": padding_side="left", ) - model = AutoModelForVision2Seq.from_pretrained( - model_args.model_name_or_path, - trust_remote_code=model_args.trust_remote_code, - **model_kwargs, - ) - if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct": - from collatefn_library.qwen2 import ( + # from transformers import Qwen2VLForConditionalGeneration + from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified + + model = Qwen2VLForConditionalGeneration_modified.from_pretrained( + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + from model_library.qwen2vl import ( collate_fn_for_train, collate_fn_for_evaluate, ) diff --git a/src/utils/args.py b/src/utils/args.py index 55668cb..6c66992 100644 --- a/src/utils/args.py +++ b/src/utils/args.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, field from typing import Optional from trl import ScriptArguments, ModelConfig -from transformers import TrainingArguments @dataclass From a1bb0f7c8c005dd7f9e911b2ac6c44be2a0d3aac Mon Sep 17 00:00:00 2001 From: YunyaoZhou Date: Tue, 7 Jan 2025 15:07:08 +0800 Subject: [PATCH 2/2] =?UTF-8?q?feat=E2=9C=A8:=20=E4=BD=BF=E5=BE=97MOELORA?= =?UTF-8?q?=E6=94=AF=E6=8C=81task=5Fid=E4=BB=A5=E5=8F=8A=E5=85=B6=E4=BB=96?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E7=9A=84=E4=BC=A0=E9=80=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/model_library/qwen2vl/model.py | 330 +++++++++++++++++++++- src/peft_library/tuners/mmoelora/layer.py | 7 +- 2 files changed, 332 insertions(+), 5 deletions(-) diff --git a/src/model_library/qwen2vl/model.py b/src/model_library/qwen2vl/model.py index 5dfe83e..10313a4 100644 --- a/src/model_library/qwen2vl/model.py +++ b/src/model_library/qwen2vl/model.py @@ -1,29 +1,346 @@ # from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLModel, Qwen2VLForConditionalGeneration, logger, DynamicCache, Qwen2VLDecoderLayer, Qwen2VLConfig, Qwen2VLAttention, from transformers.models.qwen2_vl.modeling_qwen2_vl import * +from transformers.cache_utils import DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast import torch from typing import Optional, List, Union, Tuple from torch.nn import CrossEntropyLoss +from torch import nn +from torch.nn import functional as F +from torch import Tensor + + +class LinearLayer(nn.Linear): + def forward(self, input: Tensor, **kwargs) -> Tensor: + return F.linear(input, self.weight, self.bias) + + class Qwen2VLAttention_modified(Qwen2VLAttention): + def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.q_proj = LinearLayer( + self.hidden_size, self.num_heads * self.head_dim, bias=True + ) + self.k_proj = LinearLayer( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True + ) + self.v_proj = LinearLayer( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True + ) + self.o_proj = LinearLayer( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where( + torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights + ) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Qwen2VLSdpaAttention_modified(Qwen2VLAttention_modified): + """ + Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Qwen2Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states, **kwargs) + key_states = self.k_proj(hidden_states, **kwargs) + value_states = self.v_proj(hidden_states, **kwargs) + + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + QWEN2_VL_ATTENTION_CLASSES = { "eager": Qwen2VLAttention, "flash_attention_2": Qwen2VLFlashAttention2, - "sdpa": Qwen2VLSdpaAttention, + "sdpa": Qwen2VLSdpaAttention_modified, } class Qwen2VLDecoderLayer_modified(Qwen2VLDecoderLayer): def __init__(self, config: Qwen2VLConfig, layer_idx: int): super().__init__(config, layer_idx) + self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config._attn_implementation]( + config, layer_idx + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs class Qwen2VLModel_modified(Qwen2VLModel): def __init__(self, config): super().__init__(config) self.layers = nn.ModuleList( - [Qwen2VLDecoderLayer_modified(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + [ + Qwen2VLDecoderLayer_modified(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] ) def forward( @@ -174,7 +491,16 @@ class Qwen2VLModel_modified(Qwen2VLModel): class Qwen2VLForConditionalGeneration_modified(Qwen2VLForConditionalGeneration): def __init__(self, config): super().__init__(config) + self.visual = Qwen2VisionTransformerPretrainedModel._from_config( + config.vision_config + ) self.model = Qwen2VLModel_modified(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + + # Initialize weights and apply final processing + self.post_init() def forward( self, diff --git a/src/peft_library/tuners/mmoelora/layer.py b/src/peft_library/tuners/mmoelora/layer.py index 2a2bd60..29e7d21 100644 --- a/src/peft_library/tuners/mmoelora/layer.py +++ b/src/peft_library/tuners/mmoelora/layer.py @@ -151,9 +151,10 @@ class MMOELoraLinear(nn.Module, MMOELoraLayer): def forward(self, x: torch.Tensor, *args, **kwargs): self._check_forward_args(x, *args, **kwargs) adapter_names = kwargs.pop("adapter_names", None) - task_id = kwargs.pop( - "task_id", torch.tensor([0] * len(x), dtype=torch.long).to(x.device) - ) + # task_id = kwargs.pop( + # "task_id", torch.tensor([0] * len(x), dtype=torch.long).to(x.device) + # ) + task_id = kwargs.pop("task_id", torch.tensor([0] * len(x), dtype=torch.long)) previous_dtype = x.dtype if self.disable_adapters: # No adapter