From 1b7fea800e770471c725dac5a607996b04cec000 Mon Sep 17 00:00:00 2001
From: YunyaoZhou
Date: Wed, 15 Jan 2025 03:48:58 +0800
Subject: [PATCH] feat✨: Update dependencies, rename the dataset to CHEM,
 optimize the training and evaluation scripts, and add original-data support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 README.md                               |  44 ++
 pyproject.toml                          |   4 +
 src/dataset_library/CHEM.py             |   2 +
 src/dataset_library/factory.py          |   6 +-
 src/evaluation.py                       |  71 +--
 src/evaluation.sh                       |   4 +-
 src/model_library/factory.py            |  71 +--
 src/model_library/qwen2vl/collate_fn.py |   3 +-
 src/train.py                            |  20 +-
 src/train.sh                            |  11 +-
 src/utils/evaluate_tool.py              |  61 ++-
 src/utils/trainer.py                    | 672 +++++++++++++++++++++++-
 uv.lock                                 |  56 ++
 13 files changed, 918 insertions(+), 107 deletions(-)

diff --git a/README.md b/README.md
index 188457c..60796eb 100644
--- a/README.md
+++ b/README.md
@@ -6,3 +6,47 @@
 uv sync
 uv sync --extra compile
 ```
+
+## Scripts
+
+```bash
+uv run -- ./train.sh
+uv run -- ./evaluation.sh
+```
+
+## Recommended Structure
+
+```bash
+.
+├── install.sh
+├── LICENSE
+├── pyproject.toml
+├── README.md
+├── rsync.sh
+├── src
+│   ├── configs
+│   ├── dataset_library
+│   ├── evaluation.py
+│   ├── evaluation.sh
+│   ├── model_library
+│   ├── peft_library
+│   ├── todo.md
+│   ├── train.py
+│   ├── train.sh
+│   └── utils
+├── dataset
+│   ├── chem
+│   │   ├── conversations_loc_train.jsonl
+│   │   ├── conversations_loc_val.jsonl
+│   │   └── images
+│   ├── OCR-VQA-200K
+│   │   ├── dataset.json
+│   │   ├── images
+│   │   ├── LICENCE.txt
+│   │   └── loadDataset.py
+│   └── TextCaps
+│       ├── TextCaps_0.1_train.json
+│       ├── train_val_images.zip
+│       └── wget-log
+└── uv.lock
+```
diff --git a/pyproject.toml b/pyproject.toml
index 9751420..131d5c0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,15 +1,19 @@
 [project]
 dependencies = [
+    "absl-py>=2.1.0",
     "accelerate==1.2.1",
     "datasets==3.2.0",
     "deepspeed==0.16.2",
     "evaluate==0.4.3",
     "librosa>=0.10.2.post1",
     "markupsafe==2.1.5",
+    "nltk>=3.9.1",
     "numba>=0.60.0",
     "peft==0.14.0",
     "pip==24.3.1",
     "requests==2.32.3",
+    "rouge-score>=0.1.2",
+    "safetensors>=0.5.2",
     "setuptools>=70.0.0",
     "soundfile>=0.13.0",
     "torch==2.5.1+cu124",
diff --git a/src/dataset_library/CHEM.py b/src/dataset_library/CHEM.py
index 732165f..27a8849 100644
--- a/src/dataset_library/CHEM.py
+++ b/src/dataset_library/CHEM.py
@@ -63,6 +63,7 @@ class CHEMDataset(Dataset):
                     "answer": answer,
                     "image_path": image_file,
                     "system": system,
+                    "original": data,
                 }
             )
         return processed_data
@@ -150,6 +151,7 @@ class CHEMDatasetForGeneration(CHEMDataset):
             "image": image,
             "chat": chat,
             "answer": answer,
+            "original": sample["original"],
         }
diff --git a/src/dataset_library/factory.py b/src/dataset_library/factory.py
index 21b30e2..579449a 100644
--- a/src/dataset_library/factory.py
+++ b/src/dataset_library/factory.py
@@ -34,17 +34,17 @@ def get_dataset(
         dataset = {
             "train": CHEMDataset(
                 osp.join(base_path, "chem/images"),
-                osp.join(base_path, "chem/qwen_data"),
+                osp.join(base_path, "chem"),
                 split="train",
             ),
             "test": CHEMDataset(
                 osp.join(base_path, "chem/images"),
-                osp.join(base_path, "chem/qwen_data"),
+                osp.join(base_path, "chem"),
                 split="test",
            ),
             "generation": CHEMDatasetForGeneration(
                 osp.join(base_path,
"chem/images"), - osp.join(base_path, "chem/qwen_data"), + osp.join(base_path, "chem"), split="test", ), } diff --git a/src/evaluation.py b/src/evaluation.py index 1b8fd04..1a1e41f 100644 --- a/src/evaluation.py +++ b/src/evaluation.py @@ -1,6 +1,11 @@ import torch from dataset_library.factory import get_dataset -from transformers import AutoModelForVision2Seq, AutoProcessor, TrainingArguments +from transformers import ( + AutoModelForVision2Seq, + AutoProcessor, + TrainingArguments, + modeling_utils, +) from trl import ( ModelConfig, @@ -26,32 +31,41 @@ if __name__ == "__main__": training_args.remove_unused_columns = False training_args.dataset_kwargs = {"skip_prepare_dataset": True} - torch_dtype = ( - model_args.torch_dtype - if model_args.torch_dtype in ["auto", None] - else getattr(torch, model_args.torch_dtype) - ) - quantization_config = get_quantization_config(model_args) - model_kwargs = dict( - revision=model_args.model_revision, - attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, - device_map=get_kbit_device_map() if quantization_config is not None else None, - quantization_config=quantization_config, - ) - processor = AutoProcessor.from_pretrained( - model_args.model_name_or_path, - trust_remote_code=model_args.trust_remote_code, - padding_side="left", - ) - - model = AutoModelForVision2Seq.from_pretrained( - training_args.output_dir, - trust_remote_code=model_args.trust_remote_code, - **model_kwargs, - ) + from model_library.factory import get_model if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct": + torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + quantization_config=quantization_config, + ) + from transformers import ( + Qwen2VLProcessor, + Qwen2VLForConditionalGeneration, + AutoModelForVision2Seq, + AutoModel, + ) + from peft.peft_model import PeftModelForCausalLM + from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified + + model = Qwen2VLForConditionalGeneration.from_pretrained( + training_args.output_dir, + **model_kwargs, + ) + + # from peft_library import get_peft_model + + processor = Qwen2VLProcessor.from_pretrained( + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code, + padding_side="left", + ) from model_library.qwen2vl import ( collate_fn_for_train, collate_fn_for_evaluate, @@ -61,9 +75,6 @@ if __name__ == "__main__": collate_fn_for_train = partial(collate_fn_for_train, processor=processor) collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor) - # peft_config = get_peft_config(model_args) - # model = get_peft_model(model, peft_config) - ################ # Dataset ################ @@ -89,6 +100,6 @@ if __name__ == "__main__": collate_fn=collate_fn_for_evaluate, ) val_dataloader = accelerator.prepare_data_loader(val_dataloader) - from utils.evaluate_tool import evaluate_rouge + from utils.evaluate_tool import evaluate_rouge, evalute_save - evaluate_rouge(model, val_dataloader, processor) + evalute_save(model, val_dataloader, processor, accelerator) diff --git a/src/evaluation.sh b/src/evaluation.sh index a440e82..0e0655f 100755 --- a/src/evaluation.sh +++ b/src/evaluation.sh @@ -1,7 +1,7 @@ #!/bin/bash accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml evaluation.py \ - 
--dataset_name OCR_VQA_200K \ + --dataset_name CHEM \ --use_peft \ --peft_type MMOELORA \ --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ @@ -9,7 +9,7 @@ accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml --per_device_train_batch_size 1 \ --per_device_eval_batch_size 2 \ --gradient_accumulation_steps 8 \ - --output_dir checkpoint/sft-llava-1.5-7b-hf \ + --output_dir checkpoint/qwen2/ \ --bf16 \ --torch_dtype bfloat16 # --eval_strategy epoch \ diff --git a/src/model_library/factory.py b/src/model_library/factory.py index 8aee3df..74ed698 100644 --- a/src/model_library/factory.py +++ b/src/model_library/factory.py @@ -4,29 +4,29 @@ from trl import ( # get_peft_config, get_quantization_config, ) +from utils.args import ContinualModelConfig -def get_model(model_args): +def get_model(model_args: ContinualModelConfig): + torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + device_map=(get_kbit_device_map() if quantization_config is not None else None), + quantization_config=quantization_config, + ) + if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct": - torch_dtype = ( - model_args.torch_dtype - if model_args.torch_dtype in ["auto", None] - else getattr(torch, model_args.torch_dtype) - ) - quantization_config = get_quantization_config(model_args) - model_kwargs = dict( - revision=model_args.model_revision, - attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, - device_map=( - get_kbit_device_map() if quantization_config is not None else None - ), - quantization_config=quantization_config, - ) - from transformers import Qwen2VLProcessor + from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified - model = Qwen2VLForConditionalGeneration_modified.from_pretrained( + model = Qwen2VLForConditionalGeneration.from_pretrained( model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs, @@ -36,7 +36,6 @@ def get_model(model_args): trust_remote_code=model_args.trust_remote_code, padding_side="left", ) - print(model) from model_library.qwen2vl import ( collate_fn_for_train, collate_fn_for_evaluate, @@ -47,21 +46,6 @@ def get_model(model_args): collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor) if model_args.model_name_or_path == "Qwen/Qwen2-Audio-7B-Instruct": - torch_dtype = ( - model_args.torch_dtype - if model_args.torch_dtype in ["auto", None] - else getattr(torch, model_args.torch_dtype) - ) - quantization_config = get_quantization_config(model_args) - model_kwargs = dict( - revision=model_args.model_revision, - attn_implementation=model_args.attn_implementation, - torch_dtype=torch_dtype, - device_map=( - get_kbit_device_map() if quantization_config is not None else None - ), - quantization_config=quantization_config, - ) from transformers import Qwen2AudioProcessor, Qwen2AudioForConditionalGeneration model = Qwen2AudioForConditionalGeneration.from_pretrained( @@ -74,7 +58,6 @@ def get_model(model_args): trust_remote_code=model_args.trust_remote_code, padding_side="left", ) - print(model) from model_library.qwen2audio import ( collate_fn_for_train, collate_fn_for_evaluate, @@ -84,22 +67,4 @@ def 
get_model(model_args): collate_fn_for_train = partial(collate_fn_for_train, processor=processor) collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor) - if model_args.model_name_or_path == "VITA-MLLM/VITA-1.5": - # from transformers import - # from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified - - model = Qwen2VLForConditionalGeneration_modified.from_pretrained( - model_args.model_name_or_path, - trust_remote_code=model_args.trust_remote_code, - **model_kwargs, - ) - print(model) - from model_library.qwen2vl import ( - collate_fn_for_train, - collate_fn_for_evaluate, - ) - from functools import partial - - collate_fn_for_train = partial(collate_fn_for_train, processor=processor) - collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor) return model, processor, collate_fn_for_train, collate_fn_for_evaluate diff --git a/src/model_library/qwen2vl/collate_fn.py b/src/model_library/qwen2vl/collate_fn.py index 09ee258..9446dee 100644 --- a/src/model_library/qwen2vl/collate_fn.py +++ b/src/model_library/qwen2vl/collate_fn.py @@ -51,7 +51,7 @@ def collate_fn_for_train(examples, processor: Qwen2VLProcessor): now_index += 1 now_index += 1 batch["labels"] = labels - batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long) + # batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long) return batch @@ -74,6 +74,7 @@ def collate_fn_for_evaluate(examples, processor: Qwen2VLProcessor): answers = processor(text=answers, return_tensors="pt", padding=True) batch["answers_ids"] = answers["input_ids"] batch["answers_mask"] = answers["attention_mask"] + batch["original_data"] = [example["original"] for example in examples] # input_ids torch.Size([3, 370]) # attention_mask torch.Size([3, 370]) # pixel_values torch.Size([3888, 1176]) diff --git a/src/train.py b/src/train.py index 491cef5..7f68bfb 100644 --- a/src/train.py +++ b/src/train.py @@ -1,4 +1,6 @@ from dataset_library.factory import get_dataset + + from transformers import ( TrainingArguments, ) @@ -10,6 +12,9 @@ from peft_library import get_peft_model from utils.trainer import ContinualTrainer from utils.args import ContinualScriptArguments, ContinualModelConfig +import logging + +logging.basicConfig(level=logging.INFO) if __name__ == "__main__": @@ -31,7 +36,6 @@ if __name__ == "__main__": model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model( model_args ) - ################ # Dataset ################ @@ -52,16 +56,20 @@ if __name__ == "__main__": elif model_args.peft_type == "LORA": from peft.tuners.lora import LoraConfig - peft_config = LoraConfig(target_modules=model_args.lora_target_modules, r=2) + peft_config = LoraConfig(target_modules=model_args.lora_target_modules) - model = get_peft_model(model, peft_config) + # model = get_peft_model(model, peft_config) + model.add_adapter(peft_config) - if accelerator.is_local_main_process: - model.print_trainable_parameters() + # if accelerator.is_local_main_process: + # model.print_trainable_parameters() else: peft_config = None + if accelerator.is_local_main_process: + print(model) + for dataset_name in script_args.dataset_name: dataset = get_dataset(dataset_name) model.train() @@ -83,6 +91,7 @@ if __name__ == "__main__": if accelerator.is_local_main_process: print("Saving model") trainer.save_model(training_args.output_dir) + if accelerator.is_local_main_process: print("Model saved") # 同步 accelerator @@ -98,6 +107,7 @@ if __name__ == "__main__": collate_fn=collate_fn_for_evaluate, ) 
val_dataloader = accelerator.prepare(val_dataloader) + from utils.evaluate_tool import evaluate_rouge evaluate_rouge(model, val_dataloader, processor) diff --git a/src/train.sh b/src/train.sh index 7374f79..50b26f7 100755 --- a/src/train.sh +++ b/src/train.sh @@ -1,14 +1,15 @@ #!/bin/bash accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \ - --dataset_name gigaspeech \ + --dataset_name CHEM \ --use_peft \ --peft_type LORA \ - --model_name_or_path Qwen/Qwen2-Audio-7B-Instruct \ + --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ --lora_target_modules q_proj v_proj \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 2 \ - --gradient_accumulation_steps 16 \ - --output_dir checkpoint/sft-llava-1.5-7b-hf \ + --gradient_accumulation_steps 4 \ + --output_dir checkpoint/qwen2/ \ --bf16 \ - --torch_dtype bfloat16 + --torch_dtype bfloat16 \ + --logging_steps 30 diff --git a/src/utils/evaluate_tool.py b/src/utils/evaluate_tool.py index 18ab1f7..494c8d5 100644 --- a/src/utils/evaluate_tool.py +++ b/src/utils/evaluate_tool.py @@ -1,7 +1,8 @@ import evaluate +from accelerate import Accelerator -def evaluate_rouge(model, val_dataloader, processor): +def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator = None): glue = evaluate.load("rouge") for batch in val_dataloader: @@ -23,4 +24,60 @@ def evaluate_rouge(model, val_dataloader, processor): target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True) glue.add_batch(predictions=generated_text, references=target_text) - print(glue.compute()) \ No newline at end of file + print(glue.compute()) + + +def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = None): + import os + mtime = 0 + for root, dirs, files in os.walk("."): + for file in files: + time = os.path.getmtime(os.path.join(root, file)) + if time > mtime: + mtime = time + + # 获取目录最后修改时间 + if not os.path.exists(f"results/{mtime}"): + os.makedirs(f"results/{mtime}") + + + from tqdm import tqdm + + if accelerator.is_local_main_process: + bar = tqdm(total=len(val_dataloader)) + + for batch in val_dataloader: + answers = [] + completion = model.generate( + input_ids=batch["input_ids"], + attention_mask=batch["attention_mask"], + pixel_values=batch["pixel_values"], + image_grid_thw=batch["image_grid_thw"], + max_length=1000, + ) + target = batch["answers_ids"] + generated_text = [ + out_ids[len(in_ids) :] + for out_ids, in_ids in zip(completion, batch["input_ids"]) + ] + generated_text = processor.tokenizer.batch_decode( + generated_text, skip_special_tokens=True + ) + target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True) + for i in range(len(generated_text)): + answers.append( + { + "generated": generated_text[i], + "target": target_text[i], + "original": batch["original_data"][i], + } + ) + import json + + world_size = accelerator.process_index + + with open(f"results/{mtime}/answers_{world_size}.jsonl", "a") as f: + for answer in answers: + f.write(json.dumps(answer) + "\n") + if accelerator.is_local_main_process: + bar.update(1) diff --git a/src/utils/trainer.py b/src/utils/trainer.py index 0152005..cd997a7 100644 --- a/src/utils/trainer.py +++ b/src/utils/trainer.py @@ -1,11 +1,7 @@ # _________________________________________________________ -from transformers import Trainer -from transformers.trainer import ( - Trainer, - _is_peft_model, - MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, -) + +from transformers.trainer import * class ContinualTrainer(Trainer): @@ -78,3 
+74,667 @@ class ContinualTrainer(Trainer): loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] return (loss, outputs) if return_outputs else loss + + def _inner_training_loop( + self, + batch_size=None, + args=None, + resume_from_checkpoint=None, + trial=None, + ignore_keys_for_eval=None, + ): + self.accelerator.free_memory() + self._train_batch_size = batch_size + if self.args.auto_find_batch_size: + if self.state.train_batch_size != self._train_batch_size: + from accelerate.utils import release_memory + + (self.model_wrapped,) = release_memory(self.model_wrapped) + self.model_wrapped = self.model + + # Check for DeepSpeed *after* the intial pass and modify the config + if self.is_deepspeed_enabled: + # Temporarily unset `self.args.train_batch_size` + original_bs = self.args.per_device_train_batch_size + self.args.per_device_train_batch_size = ( + self._train_batch_size // max(1, self.args.n_gpu) + ) + self.propagate_args_to_deepspeed(True) + self.args.per_device_train_batch_size = original_bs + self.state.train_batch_size = self._train_batch_size + logger.debug( + f"Currently training with a batch size of: {self._train_batch_size}" + ) + # Data loader and number of training steps + train_dataloader = self.get_train_dataloader() + if self.is_fsdp_xla_v2_enabled: + train_dataloader = tpu_spmd_dataloader(train_dataloader) + + # Setting up training control variables: + # number of training epochs: num_train_epochs + # number of training steps per epoch: num_update_steps_per_epoch + # total number of training steps to execute: max_steps + total_train_batch_size = ( + self._train_batch_size * args.gradient_accumulation_steps * args.world_size + ) + + len_dataloader = None + num_train_tokens = None + if has_length(train_dataloader): + len_dataloader = len(train_dataloader) + num_update_steps_per_epoch = ( + len_dataloader // args.gradient_accumulation_steps + ) + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + num_examples = self.num_examples(train_dataloader) + if args.max_steps > 0: + max_steps = args.max_steps + num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( + args.max_steps % num_update_steps_per_epoch > 0 + ) + # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's + # the best we can do. + num_train_samples = args.max_steps * total_train_batch_size + if args.include_tokens_per_second: + num_train_tokens = ( + self.num_tokens(train_dataloader, args.max_steps) + * args.gradient_accumulation_steps + ) + else: + max_steps = math.ceil( + args.num_train_epochs * num_update_steps_per_epoch + ) + num_train_epochs = math.ceil(args.num_train_epochs) + num_train_samples = ( + self.num_examples(train_dataloader) * args.num_train_epochs + ) + if args.include_tokens_per_second: + num_train_tokens = ( + self.num_tokens(train_dataloader) * args.num_train_epochs + ) + elif ( + args.max_steps > 0 + ): # Rely on max_steps when dataloader does not have a working size + max_steps = args.max_steps + # Setting a very large number of epochs so we go as many times as necessary over the iterator. 
+ num_train_epochs = sys.maxsize + num_update_steps_per_epoch = max_steps + num_examples = total_train_batch_size * args.max_steps + num_train_samples = args.max_steps * total_train_batch_size + if args.include_tokens_per_second: + num_train_tokens = ( + self.num_tokens(train_dataloader, args.max_steps) + * args.gradient_accumulation_steps + ) + else: + raise ValueError( + "args.max_steps must be set to a positive value if dataloader does not have a length, was" + f" {args.max_steps}" + ) + + if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: + if self.args.n_gpu > 1: + # nn.DataParallel(model) replicates the model, creating new variables and module + # references registered here no longer work on other gpus, breaking the module + raise ValueError( + "Currently --debug underflow_overflow is not supported under DP. Please use DDP" + " (torchrun or torch.distributed.launch (deprecated))." + ) + else: + debug_overflow = DebugUnderflowOverflow(self.model) # noqa + + delay_optimizer_creation = ( + is_sagemaker_mp_enabled() + or self.is_fsdp_xla_enabled + or self.is_fsdp_enabled + ) + + # We need to reset the scheduler, as its parameters may be different on subsequent calls + if self._created_lr_scheduler: + self.lr_scheduler = None + self._created_lr_scheduler = False + + if self.is_deepspeed_enabled: + self.optimizer, self.lr_scheduler = deepspeed_init( + self, num_training_steps=max_steps + ) + + if not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + self.state = TrainerState( + stateful_callbacks=[ + cb + for cb in self.callback_handler.callbacks + [self.control] + if isinstance(cb, ExportableState) + ] + ) + self.state.is_hyper_param_search = trial is not None + self.state.train_batch_size = self._train_batch_size + + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs + ) + + model = self._wrap_model(self.model_wrapped) + + # as the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases such as + # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + use_accelerator_prepare = True if model is self.model else False + + if use_accelerator_prepare and self.is_fsdp_enabled: + # In case of auto_find_batch_size=True + # Remove FSDP wrapping from sub-models. 
+ self.model = unwrap_model(self.model, recursive=True) + + if delay_optimizer_creation: + if use_accelerator_prepare: + # configure fsdp plugin for qlora if any + self._fsdp_qlora_plugin_updates() + if self.accelerator.mixed_precision != "fp8": + self.model = self.accelerator.prepare(self.model) + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # prepare using `accelerator` prepare + if use_accelerator_prepare: + self.model.train() + if hasattr(self.lr_scheduler, "step"): + if self.use_apex: + model = self.accelerator.prepare(self.model) + else: + model, self.optimizer = self.accelerator.prepare( + self.model, self.optimizer + ) + else: + # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. + model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + # In this case we are in DDP + LOMO, which should be supported + self.optimizer = self.accelerator.prepare(self.optimizer) + + if self.is_fsdp_enabled: + self.model = self.model_wrapped = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # ckpt loading + if resume_from_checkpoint is not None: + if self.is_deepspeed_enabled: + deepspeed_load_checkpoint( + self.model_wrapped, + resume_from_checkpoint, + load_module_strict=not _is_peft_model(self.model), + ) + elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped) + + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + + # important: at this point: + # self.model is the Transformers Model + # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), + # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc. + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info( + f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}" + ) + if self.args.per_device_train_batch_size != self._train_batch_size: + logger.info( + f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}" + ) + logger.info( + f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size:,}" + ) + logger.info( + f" Gradient Accumulation steps = {args.gradient_accumulation_steps}" + ) + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info( + f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}" + ) + + self.state.epoch = 0 + start_time = time.time() + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + steps_trained_progress_bar = None + + # Check if continuing training from a checkpoint + if resume_from_checkpoint is not None and os.path.isfile( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ): + self.state = TrainerState.load_from_json( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ) + self.compare_trainer_and_checkpoint_args(self.args, self.state) + self._load_callback_state() + epochs_trained = int(self.state.global_step // num_update_steps_per_epoch) + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % ( + num_update_steps_per_epoch + ) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + logger.info( + " Continuing training from checkpoint, will skip to saved global_step" + ) + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info( + f" Continuing training from global step {self.state.global_step}" + ) + if not args.ignore_data_skip: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." + ) + + # Update the references + self.callback_handler.model = self.model + self.callback_handler.optimizer = self.optimizer + self.callback_handler.lr_scheduler = self.lr_scheduler + self.callback_handler.train_dataloader = train_dataloader + if self.hp_name is not None and self._trial is not None: + # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial + # parameter to Train when using DDP. + self.state.trial_name = self.hp_name(self._trial) + if trial is not None: + assignments = ( + trial.assignments + if self.hp_search_backend == HPSearchBackend.SIGOPT + else trial + ) + self.state.trial_params = hp_params(assignments) + else: + self.state.trial_params = None + # This should be the same if the state has been saved but in case the training arguments changed, it's safer + # to set this after the load. + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + + # tr_loss is a tensor to avoid synchronization of TPUs through .item() + tr_loss = torch.tensor(0.0).to(args.device) + # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses + self._total_loss_scalar = 0.0 + self._globalstep_last_logged = self.state.global_step + model.zero_grad() + grad_norm: Optional[float] = None + self.control = self.callback_handler.on_train_begin( + args, self.state, self.control + ) + + if args.eval_on_start: + self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True) + + for epoch in range(epochs_trained, num_train_epochs): + epoch_dataloader = train_dataloader + if hasattr(epoch_dataloader, "set_epoch"): + epoch_dataloader.set_epoch(epoch) + + # Reset the past mems state at the beginning of each epoch if necessary. 
+ if args.past_index >= 0: + self._past = None + + steps_in_epoch = ( + len(epoch_dataloader) + if len_dataloader is not None + else args.max_steps * args.gradient_accumulation_steps + ) + self.control = self.callback_handler.on_epoch_begin( + args, self.state, self.control + ) + + if ( + epoch == epochs_trained + and resume_from_checkpoint is not None + and steps_trained_in_current_epoch == 0 + ): + self._load_rng_state(resume_from_checkpoint) + + rng_to_sync = False + steps_skipped = 0 + if steps_trained_in_current_epoch > 0: + epoch_dataloader = skip_first_batches( + epoch_dataloader, steps_trained_in_current_epoch + ) + steps_skipped = steps_trained_in_current_epoch + steps_trained_in_current_epoch = 0 + rng_to_sync = True + + step = -1 + epoch_iterator = iter(epoch_dataloader) + # We chunkify the epoch iterator into gradient accumulation steps `n` batches + remainder = num_examples % args.gradient_accumulation_steps + if remainder == 0: + remainder = args.gradient_accumulation_steps + update_step = -1 + total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1 + for _ in range(total_updates): + update_step += 1 + num_batches = ( + args.gradient_accumulation_steps + if update_step != (total_updates - 1) + else remainder + ) + batch_samples, num_items_in_batch = self.get_batch_samples( + epoch_iterator, num_batches + ) + for i, inputs in enumerate(batch_samples): + step += 1 + do_sync_step = ( + step + 1 + ) % args.gradient_accumulation_steps == 0 or ( + step + 1 + ) == steps_in_epoch + # Since we perform prefetching, we need to manually set sync_gradients + if not do_sync_step: + self.accelerator.gradient_state._set_sync_gradients(False) + else: + self.accelerator.gradient_state._set_sync_gradients(True) + + if self.args.include_num_input_tokens_seen: + main_input_name = getattr( + self.model, "main_input_name", "input_ids" + ) + if main_input_name not in inputs: + logger.warning( + "Tried to track the number of tokens seen, however the current model is " + "not configured properly to know what item is the input. To fix this, add " + "a `main_input_name` attribute to the model class you are using." 
+ ) + else: + input_tokens = inputs[main_input_name].numel() + input_tokens = torch.tensor( + input_tokens, device=self.args.device, dtype=torch.int64 + ) + self.state.num_input_tokens_seen += ( + self.accelerator.gather(input_tokens).sum().cpu().item() + ) + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + continue + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin( + args, self.state, self.control + ) + + # We explicitly want to avoid relying on `accelerator.accumulate` for generation training + context = ( + functools.partial(self.accelerator.no_sync, model=model) + if i != len(batch_samples) - 1 + and self.accelerator.distributed_type + != DistributedType.DEEPSPEED + else contextlib.nullcontext + ) + with context(): + tr_loss_step = self.training_step( + model, inputs, num_items_in_batch + ) + + if ( + args.logging_nan_inf_filter + and not is_torch_xla_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + tr_loss = tr_loss + tr_loss / ( + 1 + self.state.global_step - self._globalstep_last_logged + ) + else: + if tr_loss.device != tr_loss_step.device: + raise ValueError( + f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" + ) + tr_loss = tr_loss + tr_loss_step + + self.current_flos += float(self.floating_point_ops(inputs)) + + if do_sync_step: + # Since we perform prefetching, we need to manually set sync_gradients to True + self.accelerator.gradient_state._set_sync_gradients(True) + + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0: + # deepspeed does its own clipping + + if is_sagemaker_mp_enabled() and args.fp16: + _grad_norm = self.optimizer.clip_master_grads( + args.max_grad_norm + ) + elif self.use_apex: + # Revert to normal clipping otherwise, handling Apex or full precision + _grad_norm = nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer), + args.max_grad_norm, + ) + else: + _grad_norm = self.accelerator.clip_grad_norm_( + model.parameters(), + args.max_grad_norm, + ) + + if ( + is_accelerate_available() + and self.accelerator.distributed_type + == DistributedType.DEEPSPEED + ): + grad_norm = model.get_global_grad_norm() + # In some cases the grad norm may not return a float + if hasattr(grad_norm, "item"): + grad_norm = grad_norm.item() + else: + grad_norm = _grad_norm + + self.control = self.callback_handler.on_pre_optimizer_step( + args, self.state, self.control + ) + + self.optimizer.step() + + self.control = self.callback_handler.on_optimizer_step( + args, self.state, self.control + ) + + optimizer_was_run = ( + not self.accelerator.optimizer_step_was_skipped + ) + if optimizer_was_run: + # Delay optimizer scheduling until metrics are generated + if not isinstance( + self.lr_scheduler, + torch.optim.lr_scheduler.ReduceLROnPlateau, + ): + self.lr_scheduler.step() + + model.zero_grad() + self.state.global_step += 1 + self.state.epoch = ( + epoch + 
(step + 1 + steps_skipped) / steps_in_epoch + ) + self.control = self.callback_handler.on_step_end( + args, self.state, self.control + ) + self._maybe_log_save_evaluate( + tr_loss, + grad_norm, + model, + trial, + epoch, + ignore_keys_for_eval, + start_time, + ) + else: + self.control = self.callback_handler.on_substep_end( + args, self.state, self.control + ) + + # PyTorch/XLA relies on the data loader to insert the mark_step for + # each step. Since we are breaking the loop early, we need to manually + # insert the mark_step here. + if ( + self.control.should_epoch_stop + or self.control.should_training_stop + ): + if is_torch_xla_available(): + xm.mark_step() + break + # We also need to break out of the nested loop + if self.control.should_epoch_stop or self.control.should_training_stop: + if is_torch_xla_available(): + xm.mark_step() + break + if step < 0: + logger.warning( + "There seems not to be a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end( + args, self.state, self.control + ) + self._maybe_log_save_evaluate( + tr_loss, + grad_norm, + model, + trial, + epoch, + ignore_keys_for_eval, + start_time, + ) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + if is_torch_xla_available(): + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + else: + logger.warning( + "You enabled PyTorch/XLA debug metrics but you don't have a TPU " + "configured. Check your training configuration if this is unexpected." + ) + if self.control.should_training_stop: + break + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + + logger.info( + "\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n" + ) + if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + # Wait for everyone to get here so we are sure the model has been saved by process 0. + if is_torch_xla_available(): + xm.rendezvous("load_best_model_at_end") + elif args.parallel_mode == ParallelMode.DISTRIBUTED: + dist.barrier() + elif is_sagemaker_mp_enabled(): + smp.barrier() + + self._load_best_model() + + # add remaining tr_loss + self._total_loss_scalar += tr_loss.item() + effective_global_step = max( + self.state.global_step, 0.001 + ) # Avoid ZeroDivisionError + train_loss = self._total_loss_scalar / effective_global_step + + metrics = speed_metrics( + "train", + start_time, + num_samples=num_train_samples, + num_steps=self.state.max_steps, + num_tokens=num_train_tokens, + ) + self.store_flos() + metrics["total_flos"] = self.state.total_flos + metrics["train_loss"] = train_loss + + self.is_in_train = False + + self._memory_tracker.stop_and_update_metrics(metrics) + + self.log(metrics) + + run_dir = self._get_output_dir(trial) + checkpoints_sorted = self._sorted_checkpoints( + use_mtime=False, output_dir=run_dir + ) + + # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. 
+ if ( + self.args.should_save + and self.state.best_model_checkpoint is not None + and self.args.save_total_limit == 1 + ): + for checkpoint in checkpoints_sorted: + if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): + logger.info( + f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit" + ) + shutil.rmtree(checkpoint, ignore_errors=True) + + self.control = self.callback_handler.on_train_end( + args, self.state, self.control + ) + + # Wait for the checkpoint to be uploaded. + self._finish_current_push() + + # After training we make sure to retrieve back the original forward pass method + # for the embedding layer by removing the forward post hook. + if self.neftune_noise_alpha is not None: + self._deactivate_neftune(self.model) + + return TrainOutput(self.state.global_step, train_loss, metrics) diff --git a/uv.lock b/uv.lock index 7a78591..3631262 100644 --- a/uv.lock +++ b/uv.lock @@ -5,6 +5,15 @@ resolution-markers = [ "python_full_version < '3.12'", ] +[[package]] +name = "absl-py" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7a/8f/fc001b92ecc467cc32ab38398bd0bfb45df46e7523bf33c2ad22a505f06e/absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff", size = 118055 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/ad/e0d3c824784ff121c03cc031f944bc7e139a8f1870ffd2845cc2dd76f6c4/absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308", size = 133706 }, +] + [[package]] name = "accelerate" version = "1.2.1" @@ -240,16 +249,20 @@ name = "cl-lmm" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "absl-py" }, { name = "accelerate" }, { name = "datasets" }, { name = "deepspeed" }, { name = "evaluate" }, { name = "librosa" }, { name = "markupsafe" }, + { name = "nltk" }, { name = "numba" }, { name = "peft" }, { name = "pip" }, { name = "requests" }, + { name = "rouge-score" }, + { name = "safetensors" }, { name = "setuptools" }, { name = "soundfile" }, { name = "torch" }, @@ -267,6 +280,7 @@ compile = [ [package.metadata] requires-dist = [ + { name = "absl-py", specifier = ">=2.1.0" }, { name = "accelerate", specifier = "==1.2.1" }, { name = "datasets", specifier = "==3.2.0" }, { name = "deepspeed", specifier = "==0.16.2" }, @@ -274,10 +288,13 @@ requires-dist = [ { name = "flash-attn", marker = "extra == 'compile'", specifier = ">=2.7.2.post1" }, { name = "librosa", specifier = ">=0.10.2.post1" }, { name = "markupsafe", specifier = "==2.1.5", index = "https://download.pytorch.org/whl/cu124" }, + { name = "nltk", specifier = ">=3.9.1" }, { name = "numba", specifier = ">=0.60.0" }, { name = "peft", specifier = "==0.14.0" }, { name = "pip", specifier = "==24.3.1" }, { name = "requests", specifier = "==2.32.3", index = "https://pypi.org/simple" }, + { name = "rouge-score", specifier = ">=0.1.2" }, + { name = "safetensors", specifier = ">=0.5.2" }, { name = "setuptools", specifier = ">=70.0.0" }, { name = "soundfile", specifier = ">=0.13.0" }, { name = "torch", specifier = "==2.5.1+cu124", index = "https://download.pytorch.org/whl/cu124" }, @@ -288,6 +305,18 @@ requires-dist = [ { name = "wheel", specifier = ">=0.45.1" }, ] +[[package]] +name = "click" +version = "8.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -781,6 +810,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/9d/0cc1e82849070ff3cbee69f326cb48a839407bcd15d8844443c30a5e7509/ninja-1.11.1.3-py3-none-win_arm64.whl", hash = "sha256:17978ad611d8ead578d83637f5ae80c2261b033db0b493a7ce94f88623f29e1b", size = 270571 }, ] +[[package]] +name = "nltk" +version = "3.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "joblib" }, + { name = "regex" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3c/87/db8be88ad32c2d042420b6fd9ffd4a149f9a0d7f0e86b3f543be2eeeedd2/nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868", size = 2904691 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/66/7d9e26593edda06e8cb531874633f7c2372279c3b0f46235539fe546df8b/nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1", size = 1505442 }, +] + [[package]] name = "numba" version = "0.60.0" @@ -1441,6 +1485,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424 }, ] +[[package]] +name = "rouge-score" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "absl-py" }, + { name = "nltk" }, + { name = "numpy" }, + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e2/c5/9136736c37022a6ad27fea38f3111eb8f02fe75d067f9a985cc358653102/rouge_score-0.1.2.tar.gz", hash = "sha256:c7d4da2683e68c9abf0135ef915d63a46643666f848e558a1b9f7ead17ff0f04", size = 17400 } + [[package]] name = "safetensors" version = "0.5.2"