diff --git a/src/.gitignore b/src/.gitignore
new file mode 100644
index 0000000..c8ca89a
--- /dev/null
+++ b/src/.gitignore
@@ -0,0 +1 @@
+checkpoint/*
\ No newline at end of file
diff --git a/src/accelerate_configs/deepspeed_zero1.yaml b/src/accelerate_configs/deepspeed_zero1.yaml
new file mode 100644
index 0000000..d5b5f78
--- /dev/null
+++ b/src/accelerate_configs/deepspeed_zero1.yaml
@@ -0,0 +1,20 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  deepspeed_multinode_launcher: standard
+  gradient_accumulation_steps: 1
+  zero3_init_flag: false
+  zero_stage: 1
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'bf16'
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/accelerate_configs/deepspeed_zero2.yaml b/src/accelerate_configs/deepspeed_zero2.yaml
new file mode 100644
index 0000000..239b14a
--- /dev/null
+++ b/src/accelerate_configs/deepspeed_zero2.yaml
@@ -0,0 +1,21 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  deepspeed_multinode_launcher: standard
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: false
+  zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'bf16'
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/accelerate_configs/deepspeed_zero3.yaml b/src/accelerate_configs/deepspeed_zero3.yaml
new file mode 100644
index 0000000..b5a1201
--- /dev/null
+++ b/src/accelerate_configs/deepspeed_zero3.yaml
@@ -0,0 +1,22 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+deepspeed_config:
+  deepspeed_multinode_launcher: standard
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: true
+  zero3_save_16bit_model: true
+  zero_stage: 3
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/accelerate_configs/fsdp_qlora.yaml b/src/accelerate_configs/fsdp_qlora.yaml
new file mode 100644
index 0000000..93b3541
--- /dev/null
+++ b/src/accelerate_configs/fsdp_qlora.yaml
@@ -0,0 +1,25 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+downcast_bf16: 'no'
+fsdp_config:
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_backward_prefetch: BACKWARD_PRE
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_forward_prefetch: false
+  fsdp_offload_params: true
+  fsdp_sharding_strategy: FULL_SHARD
+  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_sync_module_states: true
+  fsdp_use_orig_params: false
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'bf16'
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/src/accelerate_configs/multi_gpu.yaml b/src/accelerate_configs/multi_gpu.yaml
new file mode 100644
index 0000000..15dad9b
--- /dev/null
+++ b/src/accelerate_configs/multi_gpu.yaml
@@ -0,0 +1,16 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: MULTI_GPU
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'bf16'
+num_machines: 1
+num_processes: 8
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/accelerate_configs/single_gpu.yaml b/src/accelerate_configs/single_gpu.yaml
new file mode 100644
index 0000000..ebd00a0
--- /dev/null
+++ b/src/accelerate_configs/single_gpu.yaml
@@ -0,0 +1,16 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: "NO"
+downcast_bf16: 'no'
+gpu_ids: all
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'bf16'
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
diff --git a/src/datasets_library/OCRVQADataset.py b/src/datasets_library/OCRVQADataset.py
new file mode 100644
index 0000000..62ddfc0
--- /dev/null
+++ b/src/datasets_library/OCRVQADataset.py
@@ -0,0 +1,86 @@
+from PIL import Image
+from torch.utils.data import Dataset
+import json
+import os
+
+
+class OCRVQADataset(Dataset):
+    def __init__(
+        self, vis_root, ann_path, vis_processor=None, text_processor=None, split="train"
+    ):
+        """
+        vis_root (string): Root directory of images (e.g. coco/images/)
+        ann_path (string): Path to the annotation file
+        """
+        self.vis_root = vis_root
+
+        self.vis_processor = vis_processor
+        self.text_processor = text_processor
+        if split == "train":
+            self.data = self.create_data(ann_path, split=1)[:200]
+        elif split == "test":
+            self.data = self.create_data(ann_path, split=3)[:200]
+
+        # self.instruction_pool = [
+        #     "[vqa] {}",
+        #     "[vqa] Based on the image, respond to this question with a short answer: {}",
+        # ]
+
+    def create_data(self, ann_path, split=1):
+        processed_data = []
+        with open(ann_path, "r") as f:
+            data = json.load(f)
+        for k in data.keys():
+            if data[k]["split"] != split:
+                continue  # 1 for training, 2 for validation, 3 for test
+            ext = os.path.splitext(data[k]["imageURL"])[1]
+            imageFile = k + ext
+            assert len(data[k]["questions"]) == len(data[k]["answers"])
+            for q, a in zip(data[k]["questions"], data[k]["answers"]):
+                if os.path.exists(os.path.join(self.vis_root, imageFile)):
+                    processed_data.append(
+                        {
+                            "question": q,
+                            "answer": a,
+                            "image_path": imageFile,
+                            "image_id": k,
+                            "title": data[k]["title"],
+                            "genre": data[k]["genre"],
+                        }
+                    )
+        return processed_data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, index):
+        sample = self.data[index]
+        image = Image.open(os.path.join(self.vis_root, sample["image_path"])).convert(
+            "RGB"
+        )
+        question = sample["question"]
+        answer = sample["answer"]
+        if self.vis_processor is not None:
+            image = self.vis_processor(image)
+        if self.text_processor is not None:
+            question = self.text_processor(question)
+            answer = self.text_processor(answer)
+
+        chat = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {
+                        "type": "text",
+                        "text": f"[vqa] Based on the image, respond to this question with a short answer: {question}",
+                    },
+                ],
+            },
+            {"role": "assistant", "content": answer},
+        ]
+        return {
+            "image": image,
+            "chat": chat,
+            "image_id": sample["image_id"],
+        }
diff --git a/src/datasets_library/__init__.py b/src/datasets_library/__init__.py
new file mode 100644
index 0000000..e69de29
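For reference, a minimal sketch of how one OCRVQADataset sample feeds the chat template that the collator in sft_vlm.py (below) applies. The local paths are placeholders for a downloaded OCR-VQA-200K copy, and the processor name matches the model used in run.sh:

from transformers import AutoProcessor
from datasets_library.OCRVQADataset import OCRVQADataset

# Placeholder paths; point these at a local OCR-VQA-200K download.
dataset = OCRVQADataset(
    vis_root="dataset/OCR-VQA-200K/images",
    ann_path="dataset/OCR-VQA-200K/dataset.json",
    split="train",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

sample = dataset[0]  # {"image": PIL image, "chat": [...], "image_id": ...}
# This is the same call collate_fn_qwenv2 makes before tokenization.
text = processor.apply_chat_template(
    sample["chat"], tokenize=False, add_generation_prompt=True
)
print(text)  # Qwen2-VL prompt with <|im_start|>/<|im_end|> role markers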
diff --git a/src/run.sh b/src/run.sh
new file mode 100644
index 0000000..242937f
--- /dev/null
+++ b/src/run.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml sft_vlm.py \
+    --dataset_name OCR_VQA_200K \
+    --use_peft \
+    --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
+    --lora_target_modules q_proj v_proj \
+    --per_device_train_batch_size 1 \
+    --gradient_accumulation_steps 8 \
+    --output_dir sft-llava-1.5-7b-hf \
+    --bf16 \
+    --torch_dtype bfloat16
diff --git a/src/sft_vlm.py b/src/sft_vlm.py
new file mode 100644
index 0000000..4810c43
--- /dev/null
+++ b/src/sft_vlm.py
@@ -0,0 +1,176 @@
+import torch
+from datasets import load_dataset
+from transformers import (
+    AutoModelForVision2Seq,
+    AutoProcessor,
+    LlavaForConditionalGeneration,
+)
+
+from trl import (
+    ModelConfig,
+    SFTScriptArguments,
+    SFTConfig,
+    SFTTrainer,
+    TrlParser,
+    get_kbit_device_map,
+    get_peft_config,
+    get_quantization_config,
+)
+
+
+if __name__ == "__main__":
+    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
+    script_args, training_args, model_args = parser.parse_args_and_config()
+    script_args: SFTScriptArguments
+    training_args: SFTConfig
+    model_args: ModelConfig
+    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
+    training_args.remove_unused_columns = False
+    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
+
+    ################
+    # Model, Tokenizer & Processor
+    ################
+    torch_dtype = (
+        model_args.torch_dtype
+        if model_args.torch_dtype in ["auto", None]
+        else getattr(torch, model_args.torch_dtype)
+    )
+    quantization_config = get_quantization_config(model_args)
+    model_kwargs = dict(
+        revision=model_args.model_revision,
+        attn_implementation=model_args.attn_implementation,
+        torch_dtype=torch_dtype,
+        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        quantization_config=quantization_config,
+    )
+    processor = AutoProcessor.from_pretrained(
+        model_args.model_name_or_path,
+        trust_remote_code=model_args.trust_remote_code,
+        padding_side="right",
+    )
+
+    model = AutoModelForVision2Seq.from_pretrained(
+        model_args.model_name_or_path,
+        trust_remote_code=model_args.trust_remote_code,
+        **model_kwargs,
+    )
+
+    ################
+    # Create a data collator to encode text and image pairs
+    ################
+    def collate_fn_qwenv2(examples):
+        # Get the texts and images, and apply the chat template
+        texts = [
+            processor.apply_chat_template(
+                example["chat"], tokenize=False, add_generation_prompt=True
+            )
+            for example in examples
+        ]
+        # print(texts)
+        images = [example["image"] for example in examples]
+        if isinstance(model, LlavaForConditionalGeneration):
+            # LLaVA 1.5 does not support multiple images
+            images = [image[0] for image in images]
+
+        # Tokenize the texts and process the images
+        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+
+        # The labels are the input_ids, and we mask the padding tokens in the loss computation
+        labels = batch["input_ids"].clone()
+        labels[labels == processor.tokenizer.pad_token_id] = -100
+        # Ignore the image token index in the loss computation (model specific)
+        # Mask the <|im_start|>system ... <|im_end|>\n spans
+        im_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_start|>")
+        im_end_token_id = processor.tokenizer.convert_tokens_to_ids("<|im_end|>")
+        system_token_id = processor.tokenizer.convert_tokens_to_ids("system")
+        user_token_id = processor.tokenizer.convert_tokens_to_ids("user")
+        assistant_token_id = processor.tokenizer.convert_tokens_to_ids("assistant")
+        # enter_token_id = processor.tokenizer.convert_tokens_to_ids("\n")
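+        # What the loop below does: walk each label sequence and mask the prompt.
+        # Every token of a system or user turn, from its <|im_start|> through its
+        # <|im_end|>, is set to -100 and ignored by the loss (this also covers the
+        # image padding tokens, which sit inside the user turn). For an assistant
+        # turn only the "<|im_start|>assistant" header and the newline after it are
+        # masked; the answer tokens and the closing <|im_end|> stay as targets.
+        # (The newline token that follows each turn's <|im_end|> is left unmasked.)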
+        # print(im_start_token_id, im_end_token_id, system_token_id, user_token_id, assistant_token_id, enter_token_id, processor.tokenizer.pad_token_id)
+        # 151644 151645 8948 872 77091 None 151643
+
+        for i, label in enumerate(labels):
+            now_index = 0
+            while now_index < len(label):
+                if label[now_index] == im_start_token_id:
+                    label[now_index] = -100
+                    now_index += 1
+                    if (
+                        label[now_index] == system_token_id
+                        or label[now_index] == user_token_id
+                    ):
+                        while label[now_index] != im_end_token_id:
+                            label[now_index] = -100
+                            now_index += 1
+                        label[now_index] = -100
+                    elif label[now_index] == assistant_token_id:
+                        label[now_index] = -100
+                        label[now_index + 1] = -100
+                        now_index += 2
+                        while (
+                            now_index < len(label)
+                            and label[now_index] != im_end_token_id
+                        ):
+                            now_index += 1
+                now_index += 1
+        # print(labels[0], batch["input_ids"][0])
+
+        # image_token_id = processor.tokenizer.convert_tokens_to_ids(
+        #     processor.image_token
+        # )
+        # labels[labels == image_token_id] = -100
+        batch["labels"] = labels
+
+        return batch
+
+    ################
+    # Dataset
+    ################
+    base_path = "/home/zyy/research/accelerate/dataset"
+    if script_args.dataset_name == "OCR_VQA_200K":
+        import os.path as osp
+        from datasets_library.OCRVQADataset import OCRVQADataset
+
+        dataset = {
+            "train": OCRVQADataset(
+                osp.join(base_path, "OCR-VQA-200K/images"),
+                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
+                split="train",
+            ),
+            "test": OCRVQADataset(
+                osp.join(base_path, "OCR-VQA-200K/images"),
+                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
+                split="test",
+            ),
+        }
+
+    else:
+        dataset = load_dataset(script_args.dataset_name, name=script_args.config)
+
+    ################
+    # Training
+    ################
+    trainer = SFTTrainer(
+        model=model,
+        args=training_args,
+        data_collator=collate_fn_qwenv2,
+        train_dataset=dataset[script_args.dataset_train_split],
+        eval_dataset=(
+            dataset[script_args.dataset_test_split]
+            if training_args.eval_strategy != "no"
+            else None
+        ),
+        tokenizer=processor.tokenizer,
+        # processing_class=processor.tokenizer,
+        peft_config=get_peft_config(model_args),
+    )
+
+    trainer.train()
+
+    # Save and push to hub
+    trainer.save_model(training_args.output_dir)
+    if training_args.push_to_hub:
+        trainer.push_to_hub(dataset_name=script_args.dataset_name)
+        if trainer.accelerator.is_main_process:
+            processor.push_to_hub(training_args.hub_model_id)
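A quick sanity check for the collator's label masking (a sketch meant to run inside sft_vlm.py once `dataset`, `processor`, and `collate_fn_qwenv2` are defined; it assumes the OCR_VQA_200K branch, so `dataset["train"]` is an OCRVQADataset with at least two samples):

batch = collate_fn_qwenv2([dataset["train"][0], dataset["train"][1]])
for ids, labels in zip(batch["input_ids"], batch["labels"]):
    # Only the supervised positions survive the -100 masking.
    supervised = ids[labels != -100]
    print(processor.tokenizer.decode(supervised))
    # Expected: roughly "{answer}<|im_end|>" plus a few stray newline tokens,
    # with no system prompt, question text, or image padding tokens.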