From 56e46f0e0cb357b62d7361a0afd4d0aa6b65dfa8 Mon Sep 17 00:00:00 2001
From: YunyaoZhou
Date: Tue, 20 May 2025 13:33:17 +0800
Subject: [PATCH] fix: repair dataset registration and loading logic, clean up
 code formatting, and ensure consistency
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/dataset_library/OCRVQA200KDataset.py     |  4 ++--
 src/dataset_library/RefCOCODataset.py        | 12 ++++++----
 src/dataset_library/RefCOCOPlusDataset.py    | 13 ++++++-----
 src/dataset_library/RefCOCOgDataset.py       | 13 ++++++-----
 src/dataset_library/ScienceQADataset.py      |  7 ++++--
 src/dataset_library/TextVQADataset.py        |  2 +-
 src/dataset_library/VizWizDataset.py         |  2 +-
 src/dataset_library/factory.py               |  3 +--
 src/dataset_library/format.py                |  2 +-
 src/dataset_library/test_dataset.py          | 15 +++++++------
 src/dataset_library/vis_processor.py         |  4 +++-
 src/peft_library/regularizations/__init__.py |  8 +++++--
 src/peft_library/regularizations/ewc.py      | 17 ++++++++++-----
 src/peft_library/regularizations/lwf.py      | 23 +++++++++++++++-----
 src/train.py                                 | 15 ++++++++-----
 src/utils/args.py                            |  1 -
 src/utils/trainer.py                         |  4 ++--
 17 files changed, 94 insertions(+), 51 deletions(-)

diff --git a/src/dataset_library/OCRVQA200KDataset.py b/src/dataset_library/OCRVQA200KDataset.py
index 625b986..3a812d9 100644
--- a/src/dataset_library/OCRVQA200KDataset.py
+++ b/src/dataset_library/OCRVQA200KDataset.py
@@ -30,7 +30,7 @@ class OCRVQADataset(Dataset):
         self.text_processor = text_processor
         self.split = split
         self.ann_path = ann_path
-        
+
     def load_data(self):
         if self.split == "train":
             self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=1)
@@ -165,4 +165,4 @@ dataset = {
     ),
 }
 
-register_dataset("ocrvqa200k", dataset, tag=["image", "text"])
\ No newline at end of file
+register_dataset("ocrvqa200k", dataset, tag=["image", "text"])
diff --git a/src/dataset_library/RefCOCODataset.py b/src/dataset_library/RefCOCODataset.py
index a165ad5..b874c1a 100644
--- a/src/dataset_library/RefCOCODataset.py
+++ b/src/dataset_library/RefCOCODataset.py
@@ -25,8 +25,10 @@ class RefCOCODataset(Dataset):
         self.vis_processor = vis_processor
         self.text_processor = text_processor
         self.split = split
+
     def load_data(self):
         from .format import dataset_dir
+
         ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore
 
@@ -118,12 +120,14 @@ def test_RefCOCO():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
 
+
 dataset = {
-        "train": RefCOCODataset(split="val"),
-        "test": RefCOCODataset(split="test"),
-        "generation": RefCOCODatasetForGeneration(split="test"),
-    }
+    "train": RefCOCODataset(split="val"),
+    "test": RefCOCODataset(split="test"),
+    "generation": RefCOCODatasetForGeneration(split="test"),
+}
 from .factory import register_dataset
+
 register_dataset(
     dataset_name="refcoco",
     dataset=dataset,
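The RefCOCO change above settles on the registration idiom used by every dataset module in this patch: a module-level `dataset` dict with `train`, `test`, and `generation` splits handed to `register_dataset`. A minimal sketch of how such a registry can be laid out, assuming the factory keeps a plain name-to-splits mapping (names are taken from the patch; the body is an assumption, not the project's actual implementation in src/dataset_library/factory.py):

# Sketch only: register_dataset, get_dataset, and MAPPING_NAME_TO_DATASET are
# names that appear in this patch; the implementation below is hypothetical.
from torch.utils.data import Dataset

MAPPING_NAME_TO_DATASET: dict[str, dict[str, Dataset]] = {}

def register_dataset(dataset_name: str, dataset: dict[str, Dataset], tag=None) -> None:
    # One entry per dataset name; tags such as ["image", "text"] could be
    # used to index datasets by modality.
    MAPPING_NAME_TO_DATASET[dataset_name.lower()] = dataset

def get_dataset(dataset_name: str, base_path=None) -> dict[str, Dataset]:
    # Return the {"train": ..., "test": ..., "generation": ...} dict.
    return MAPPING_NAME_TO_DATASET[dataset_name.lower()]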
diff --git a/src/dataset_library/RefCOCOPlusDataset.py b/src/dataset_library/RefCOCOPlusDataset.py
index 95a6c86..97fdf53 100644
--- a/src/dataset_library/RefCOCOPlusDataset.py
+++ b/src/dataset_library/RefCOCOPlusDataset.py
@@ -25,9 +25,10 @@ class RefCOCOplusDataset(Dataset):
         self.vis_processor = vis_processor
         self.text_processor = text_processor
         self.split = split
-        
+
     def load_data(self):
         from .format import dataset_dir
+
         ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore
 
@@ -119,12 +120,14 @@ def test_RefCOCOplus():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
 
+
 dataset = {
-        "train": RefCOCOplusDataset(split="val"),
-        "test": RefCOCOplusDataset(split="testA"),
-        "generation": RefCOCOplusDatasetForGeneration(split="testA"),
-    }
+    "train": RefCOCOplusDataset(split="val"),
+    "test": RefCOCOplusDataset(split="testA"),
+    "generation": RefCOCOplusDatasetForGeneration(split="testA"),
+}
 from .factory import register_dataset
+
 register_dataset(
     dataset_name="refcocoplus",
     dataset=dataset,
diff --git a/src/dataset_library/RefCOCOgDataset.py b/src/dataset_library/RefCOCOgDataset.py
index 73fb52f..e262d66 100644
--- a/src/dataset_library/RefCOCOgDataset.py
+++ b/src/dataset_library/RefCOCOgDataset.py
@@ -25,9 +25,10 @@ class RefCOCOgDataset(Dataset):
         self.vis_processor = vis_processor
         self.text_processor = text_processor
         self.split = split
-        
+
     def load_data(self):
         from .format import dataset_dir
+
         ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore
 
@@ -119,12 +120,14 @@ def test_RefCOCOg():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
 
+
 dataset = {
-        "train": RefCOCOgDataset(split="val"),
-        "test": RefCOCOgDataset(split="test"),
-        "generation": RefCOCOgDatasetForGeneration(split="test"),
-    }
+    "train": RefCOCOgDataset(split="val"),
+    "test": RefCOCOgDataset(split="test"),
+    "generation": RefCOCOgDatasetForGeneration(split="test"),
+}
 from .factory import register_dataset
+
 register_dataset(
     dataset_name="refcocog",
     dataset=dataset,
diff --git a/src/dataset_library/ScienceQADataset.py b/src/dataset_library/ScienceQADataset.py
index d51f165..6bd1dc3 100644
--- a/src/dataset_library/ScienceQADataset.py
+++ b/src/dataset_library/ScienceQADataset.py
@@ -19,10 +19,11 @@ class ScienceQADataset(Dataset):
         self.vis_processor = vis_processor
         self.text_processor = text_processor
         self.split = split
-        
+
     def load_data(self):
         from .format import dataset_dir
-        ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir) # type: ignore
+
+        ds = load_dataset("derek-thomas/ScienceQA", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore
 
     def __len__(self):
@@ -119,12 +120,14 @@ def test_scienceQA():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
 
+
 dataset = {
     "train": ScienceQADataset(split="train"),
     "test": ScienceQADataset(split="test"),
     "generation": ScienceQADatasetForGeneration(split="test"),
 }
 from .factory import register_dataset
+
 register_dataset(
     dataset_name="scienceqa",
     dataset=dataset,
diff --git a/src/dataset_library/TextVQADataset.py b/src/dataset_library/TextVQADataset.py
index 62d0526..8108abd 100644
--- a/src/dataset_library/TextVQADataset.py
+++ b/src/dataset_library/TextVQADataset.py
@@ -28,7 +28,7 @@ class TextVQADataset(Dataset):
         self.split = split
         self.ann_path = ann_path
         self.vis_root = vis_root
-        
+
     def load_data(self):
         if self.split == "train":
             self.data = self.create_data(
diff --git a/src/dataset_library/VizWizDataset.py b/src/dataset_library/VizWizDataset.py
index 756734a..c4cf16b 100644
--- a/src/dataset_library/VizWizDataset.py
+++ b/src/dataset_library/VizWizDataset.py
@@ -30,7 +30,7 @@ class VizWizDataset(Dataset):
         )
         self.text_processor = text_processor
         self.split = split
-        
+
     def load_data(self):
         if self.split == "train":
             self.data = self.create_data(Path(self.ann_path, "train.json"))
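All five dataset files above share one shape: `__init__` only records configuration, and the expensive read happens later in `load_data()`. A hypothetical condensed example of that deferred-loading idiom (class name invented for illustration; the real classes also wire in vis/text processors):

# Hypothetical illustration of the lazy-loading pattern in this patch.
from datasets import load_dataset
from torch.utils.data import Dataset

class LazyVQADataset(Dataset):
    def __init__(self, split: str = "train"):
        self.split = split  # cheap: no I/O at construction time
        self.data = None

    def load_data(self):
        # Heavy work deferred until the dataset is actually needed.
        ds = load_dataset("lmms-lab/RefCOCO")
        self.data = ds[self.split]

    def __len__(self):
        return len(self.data)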
diff --git a/src/dataset_library/factory.py b/src/dataset_library/factory.py
index 490412b..3c71385 100644
--- a/src/dataset_library/factory.py
+++ b/src/dataset_library/factory.py
@@ -1,5 +1,5 @@
 from torch.utils.data import Dataset
-from typing import Literal,List
+from typing import Literal, List
 from pathlib import Path
 
 from dataset_library.format import dataset_dir
@@ -58,7 +58,6 @@ from .VizWizDataset import (
 )
 
 
-
 def get_dataset(
     dataset_name: str, base_path=dataset_dir
 ) -> dict[Literal["train", "test", "generation"], Dataset]:
diff --git a/src/dataset_library/format.py b/src/dataset_library/format.py
index 680478a..f73338e 100644
--- a/src/dataset_library/format.py
+++ b/src/dataset_library/format.py
@@ -32,4 +32,4 @@ class DatasetOutput(TypedDict):
     answer: Optional[str]
     original: Any
     images: Optional[list[Image.Image]]
-    audios: Optional[list[Tuple[np.ndarray, int]]]
\ No newline at end of file
+    audios: Optional[list[Tuple[np.ndarray, int]]]
diff --git a/src/dataset_library/test_dataset.py b/src/dataset_library/test_dataset.py
index 34a7156..c8afc29 100644
--- a/src/dataset_library/test_dataset.py
+++ b/src/dataset_library/test_dataset.py
@@ -4,18 +4,19 @@ from .factory import get_dataset, MAPPING_NAME_TO_DATASET
 # Get all registered dataset names for parameterization
 dataset_names = list(MAPPING_NAME_TO_DATASET.keys())
 
+
 @pytest.mark.parametrize("dataset_name", dataset_names)
 def test_registered_datasets(dataset_name):
     dataset = get_dataset(dataset_name)
-    
+
     # Test train split
     assert "train" in dataset, f"Train split not found in {dataset_name}"
-    assert len(dataset["train"]) > 0, f"Train split is empty for {dataset_name}" # type: ignore
-    assert "chat" in dataset["train"][0], f"'chat' key not found in first train sample of {dataset_name}" # type: ignore
-    assert len(dataset["train"][0]["chat"]) > 0, f"'chat' is empty in first train sample of {dataset_name}" # type: ignore
+    assert len(dataset["train"]) > 0, f"Train split is empty for {dataset_name}"  # type: ignore
+    assert "chat" in dataset["train"][0], f"'chat' key not found in first train sample of {dataset_name}"  # type: ignore
+    assert len(dataset["train"][0]["chat"]) > 0, f"'chat' is empty in first train sample of {dataset_name}"  # type: ignore
 
     # Test test split
     assert "test" in dataset, f"Test split not found in {dataset_name}"
-    assert len(dataset["test"]) > 0, f"Test split is empty for {dataset_name}" # type: ignore
-    assert "chat" in dataset["test"][0], f"'chat' key not found in first test sample of {dataset_name}" # type: ignore
-    assert len(dataset["test"][0]["chat"]) > 0, f"'chat' is empty in first test sample of {dataset_name}" # type: ignore
+    assert len(dataset["test"]) > 0, f"Test split is empty for {dataset_name}"  # type: ignore
+    assert "chat" in dataset["test"][0], f"'chat' key not found in first test sample of {dataset_name}"  # type: ignore
+    assert len(dataset["test"][0]["chat"]) > 0, f"'chat' is empty in first test sample of {dataset_name}"  # type: ignore
diff --git a/src/dataset_library/vis_processor.py b/src/dataset_library/vis_processor.py
index e452143..4e683bf 100644
--- a/src/dataset_library/vis_processor.py
+++ b/src/dataset_library/vis_processor.py
@@ -1,4 +1,6 @@
 from PIL import Image
+
+
 def size_processor(image: Image.Image):
     width, height = image.size
     if width > 500 or height > 500:
@@ -15,4 +17,4 @@ def size_processor(image: Image.Image):
             new_height = int(height * ratio)
             image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
 
-    return image
\ No newline at end of file
+    return image
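The vis_processor.py hunks are whitespace-only, but since `size_processor` sits in the image-preprocessing path, a usage sketch may help (values are hypothetical; the hunk shows a 500px threshold and a bilinear, aspect-preserving resize, while the exact ratio arithmetic is only partially visible):

# Hypothetical usage; the expected output assumes the ratio is computed
# against the longer side, which the partially shown hunk suggests.
from PIL import Image
from dataset_library.vis_processor import size_processor

img = Image.new("RGB", (1200, 800))
resized = size_processor(img)
print(resized.size)  # e.g. (500, 333): both sides at or below 500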
diff --git a/src/peft_library/regularizations/__init__.py b/src/peft_library/regularizations/__init__.py
index f1b5bba..bbfe6d6 100644
--- a/src/peft_library/regularizations/__init__.py
+++ b/src/peft_library/regularizations/__init__.py
@@ -1,10 +1,10 @@
-
 class RegularizationMethod:
     """RegularizationMethod implement regularization strategies.
 
     RegularizationMethod is a callable.
     The method `update` is called to update the loss, typically at the end
     of an experience.
     """
+
     def pre_adapt(self, agent, exp):
         pass  # implementation may be empty if adapt is not needed
 
@@ -12,4 +12,8 @@ class RegularizationMethod:
         pass  # implementation may be empty if adapt is not needed
 
     def __call__(self, *args, **kwargs):
-        raise NotImplementedError()
\ No newline at end of file
+        raise NotImplementedError()
+
+
+from .ewc import EWC
+from .lwf import LWF
diff --git a/src/peft_library/regularizations/ewc.py b/src/peft_library/regularizations/ewc.py
index 530fdd5..4cebb9a 100644
--- a/src/peft_library/regularizations/ewc.py
+++ b/src/peft_library/regularizations/ewc.py
@@ -1,6 +1,7 @@
 from . import RegularizationMethod
 import torch
 
+
 class EWC(RegularizationMethod):
     """Elastic Weight Consolidation.
@@ -24,16 +25,21 @@ class EWC(RegularizationMethod):
     Knowledge distillation uses only units corresponding to old classes.
     """
-    def adapt(self, output,model, **kwargs):
+
+    def adapt(self, output, model, **kwargs):
         ewc_loss = 0
         for n, p in model.named_parameters():
             if p.requires_grad:
                 dev = p.device
-                l = self.EWC_lambda * self.fisher[n].to(dev) * (p.data - self.optpar[n].to(dev)).pow(2)
+                l = (
+                    self.EWC_lambda
+                    * self.fisher[n].to(dev)
+                    * (p.data - self.optpar[n].to(dev)).pow(2)
+                )
                 ewc_loss += l.sum()
-        output['loss'] += ewc_loss
+        output["loss"] += ewc_loss
         return output
-    
+
     def init_epoch(self, model):
         """Update the previous logits for the given question id."""
         optpar = {}
@@ -42,10 +48,11 @@ class EWC(RegularizationMethod):
             if p.requires_grad:
                 fisher[n] = torch.zeros(p.data.shape)
                 optpar[n] = p.clone().cpu().data
+
     def update_fisher(self, model):
         """Update the fisher information for the given question id."""
         for n, p in model.module.base_model.model.named_parameters():
             if p.requires_grad:
                 fisher = self.fisher[n]
                 fisher += p.grad.data.pow(2).cpu()
-                self.fisher[n] = fisher
\ No newline at end of file
+                self.fisher[n] = fisher
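For orientation, `EWC.adapt` above accumulates the standard Elastic Weight Consolidation penalty, lambda * sum_i F_i * (theta_i - theta*_i)^2. The same computation restated as a standalone function that mirrors the hunk (`fisher` and `optpar` are the per-parameter buffers that `init_epoch` and `update_fisher` maintain):

import torch

def ewc_penalty(model, fisher, optpar, ewc_lambda):
    # Quadratic pull toward the old-task optimum, weighted by the diagonal
    # Fisher information; mirrors the adapt() hunk above.
    penalty = 0.0
    for n, p in model.named_parameters():
        if p.requires_grad:
            dev = p.device
            penalty = penalty + (
                ewc_lambda * fisher[n].to(dev) * (p - optpar[n].to(dev)).pow(2)
            ).sum()
    return penalty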
""" + def adapt(self, output, **kwargs): def modified_kl_div(old, new): return -torch.mean(torch.sum(old * torch.log(new), 1)) @@ -37,18 +39,29 @@ class LWF(RegularizationMethod): previous_keys = self.previous_logits.keys() - for index, question_id in enumerate(iterable=kwargs['question_ids']): + for index, question_id in enumerate(iterable=kwargs["question_ids"]): if question_id in previous_keys: previous_logits = self.previous_logits[question_id] - current_logits = output['logits'][index] + current_logits = output["logits"][index] short_index = min(len(previous_logits), len(current_logits)) previous_logits = previous_logits[:short_index] current_logits = current_logits[:short_index] - lwf_loss.append(modified_kl_div(old=smooth(logits=soft(previous_logits).to(current_logits.device), temp=2, dim=1),new=smooth(logits=soft(current_logits), temp=2, dim=1))) + lwf_loss.append( + modified_kl_div( + old=smooth( + logits=soft(previous_logits).to(current_logits.device), + temp=2, + dim=1, + ), + new=smooth(logits=soft(current_logits), temp=2, dim=1), + ) + ) if len(lwf_loss) > 0: - output['loss'] += self.LWF_lambda * torch.stack(tensors=lwf_loss, dim=0).sum(dim=0) + output["loss"] += self.LWF_lambda * torch.stack( + tensors=lwf_loss, dim=0 + ).sum(dim=0) return output - + def update_previous_logits(self, question_id, logits): """Update the previous logits for the given question id.""" self.previous_logits[question_id] = logits diff --git a/src/train.py b/src/train.py index 5499094..c6a3d19 100644 --- a/src/train.py +++ b/src/train.py @@ -15,6 +15,8 @@ from trl import ( from utils.trainer import ContinualTrainer from utils.args import ContinualScriptArguments, ContinualModelConfig import logging +from typing import TYPE_CHECKING + logging.basicConfig(level=logging.INFO) @@ -24,15 +26,16 @@ if __name__ == "__main__": (ContinualScriptArguments, TrainingArguments, ContinualModelConfig) ) script_args, training_args, model_args = parser.parse_args_and_config() - # for type hint - if 0 == 1: - script_args = ContinualScriptArguments() - training_args = TrainingArguments() - model_args = ContinualModelConfig() training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) training_args.remove_unused_columns = False training_args.dataset_kwargs = {"skip_prepare_dataset": True} + # for type hint + if 1 == 0: + script_args = ContinualScriptArguments + training_args = TrainingArguments + model_args = ContinualModelConfig + from model_library.factory import get_model model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model( @@ -68,6 +71,8 @@ if __name__ == "__main__": else: peft_config = None + from peft import get_peft_model + if accelerator.is_local_main_process: print(model) diff --git a/src/utils/args.py b/src/utils/args.py index d7a81dc..0180178 100644 --- a/src/utils/args.py +++ b/src/utils/args.py @@ -31,4 +31,3 @@ class ContiunalRegularizationArguments: # LWF lwf_lambda: float = 0.0 lwf_enable: bool = False - \ No newline at end of file diff --git a/src/utils/trainer.py b/src/utils/trainer.py index b389314..3b00f9a 100644 --- a/src/utils/trainer.py +++ b/src/utils/trainer.py @@ -15,6 +15,7 @@ from transformers import ( TrainingArguments, ) from .args import ContiunalRegularizationArguments +from peft_library.regularizations import EWC, LWF class ContinualTrainer(Trainer): @@ -26,7 +27,7 @@ class ContinualTrainer(Trainer): train_dataset, eval_dataset, accelerator, - regularization_args:ContiunalRegularizationArguments=None, + regularization_args: 
diff --git a/src/utils/trainer.py b/src/utils/trainer.py
index b389314..3b00f9a 100644
--- a/src/utils/trainer.py
+++ b/src/utils/trainer.py
@@ -15,6 +15,7 @@ from transformers import (
     TrainingArguments,
 )
 from .args import ContiunalRegularizationArguments
+from peft_library.regularizations import EWC, LWF
 
 
 class ContinualTrainer(Trainer):
@@ -26,7 +27,7 @@ class ContinualTrainer(Trainer):
         train_dataset,
         eval_dataset,
         accelerator,
-        regularization_args:ContiunalRegularizationArguments=None,
+        regularization_args: ContiunalRegularizationArguments = None,
     ):
         self.accelerator = accelerator
         super().__init__(model, args, data_collator, train_dataset, eval_dataset)
@@ -38,7 +39,6 @@ class ContinualTrainer(Trainer):
 
         if regularization_args.lwf_enable:
             self.lwf_lambda = regularization_args.lwf_lambda
-
 
     def create_accelerator_and_postprocess(self):
         if self.accelerator is not None:
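Finally, a hypothetical call site for the signature change above; only `ContinualTrainer`, its parameter names, and the two LWF fields of `ContiunalRegularizationArguments` come from this patch, everything else is a placeholder:

from utils.args import ContiunalRegularizationArguments
from utils.trainer import ContinualTrainer

# model, training_args, collate_fn_for_train, dataset, and accelerator are
# assumed to exist in the caller's scope (see src/train.py).
reg_args = ContiunalRegularizationArguments(lwf_lambda=0.1, lwf_enable=True)
trainer = ContinualTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn_for_train,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    accelerator=accelerator,
    regularization_args=reg_args,
)
trainer.train()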