fix: repair dataset registration and loading logic, clean up code formatting, ensure consistency

YunyaoZhou 2025-05-20 13:33:17 +08:00
parent 1d2e7b9dcd
commit 56e46f0e0c
17 changed files with 94 additions and 51 deletions

View File

@@ -25,8 +25,10 @@ class RefCOCODataset(Dataset):
         self.vis_processor = vis_processor
         self.text_processor = text_processor
         self.split = split
+
     def load_data(self):
         from .format import dataset_dir
+
         ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore
@@ -118,12 +120,14 @@ def test_RefCOCO():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
+
+
 dataset = {
     "train": RefCOCODataset(split="val"),
     "test": RefCOCODataset(split="test"),
     "generation": RefCOCODatasetForGeneration(split="test"),
 }
 from .factory import register_dataset

 register_dataset(
     dataset_name="refcoco",
     dataset=dataset,

View File

@@ -28,6 +28,7 @@ class RefCOCOplusDataset(Dataset):
     def load_data(self):
         from .format import dataset_dir
+
         ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore
@@ -119,12 +120,14 @@ def test_RefCOCOplus():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
+
+
 dataset = {
     "train": RefCOCOplusDataset(split="val"),
     "test": RefCOCOplusDataset(split="testA"),
     "generation": RefCOCOplusDatasetForGeneration(split="testA"),
 }
 from .factory import register_dataset

 register_dataset(
     dataset_name="refcocoplus",
     dataset=dataset,

View File

@@ -28,6 +28,7 @@ class RefCOCOgDataset(Dataset):
     def load_data(self):
         from .format import dataset_dir
+
         ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore
@@ -119,12 +120,14 @@ def test_RefCOCOg():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
+
+
 dataset = {
     "train": RefCOCOgDataset(split="val"),
     "test": RefCOCOgDataset(split="test"),
     "generation": RefCOCOgDatasetForGeneration(split="test"),
 }
 from .factory import register_dataset

 register_dataset(
     dataset_name="refcocog",
     dataset=dataset,

View File

@@ -22,7 +22,8 @@ class ScienceQADataset(Dataset):
     def load_data(self):
         from .format import dataset_dir
-        ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir)  # type: ignore
+
+        ds = load_dataset("derek-thomas/ScienceQA", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore

     def __len__(self):
@@ -119,12 +120,14 @@ def test_scienceQA():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
+
+
 dataset = {
     "train": ScienceQADataset(split="train"),
     "test": ScienceQADataset(split="test"),
     "generation": ScienceQADatasetForGeneration(split="test"),
 }
 from .factory import register_dataset

 register_dataset(
     dataset_name="scienceqa",
     dataset=dataset,

View File

@@ -1,5 +1,5 @@
 from torch.utils.data import Dataset
-from typing import Literal,List
+from typing import Literal, List
 from pathlib import Path
 from dataset_library.format import dataset_dir
@@ -58,7 +58,6 @@ from .VizWizDataset import (
 )

-
 def get_dataset(
     dataset_name: str, base_path=dataset_dir
 ) -> dict[Literal["train", "test", "generation"], Dataset]:
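Given that registry, a plausible body for `get_dataset` is a straight lookup (a sketch under that assumption; the real implementation may also use `base_path` to relocate the dataset cache):

    def get_dataset(
        dataset_name: str, base_path=dataset_dir
    ) -> dict[Literal["train", "test", "generation"], Dataset]:
        # Assumed behavior: resolve the split dict registered above.
        if dataset_name not in MAPPING_NAME_TO_DATASET:
            raise KeyError(f"Unknown dataset: {dataset_name!r}")
        return MAPPING_NAME_TO_DATASET[dataset_name]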

View File

@@ -4,6 +4,7 @@ from .factory import get_dataset, MAPPING_NAME_TO_DATASET
 # Get all registered dataset names for parameterization
 dataset_names = list(MAPPING_NAME_TO_DATASET.keys())
+

 @pytest.mark.parametrize("dataset_name", dataset_names)
 def test_registered_datasets(dataset_name):
     dataset = get_dataset(dataset_name)

View File

@@ -1,4 +1,6 @@
 from PIL import Image
+
+
 def size_processor(image: Image.Image):
     width, height = image.size
     if width > 500 or height > 500:
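The hunk ends mid-function; a hedged completion of `size_processor` that caps the longer side at 500 px (the resampling filter and aspect-ratio handling are assumptions, not part of this commit):

    from PIL import Image


    def size_processor(image: Image.Image) -> Image.Image:
        width, height = image.size
        if width > 500 or height > 500:
            # Assumed: shrink so the longer side becomes 500 px.
            scale = 500 / max(width, height)
            image = image.resize(
                (int(width * scale), int(height * scale)),
                Image.Resampling.LANCZOS,
            )
        return image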

View File

@@ -1,10 +1,10 @@
 class RegularizationMethod:
     """RegularizationMethod implement regularization strategies.
     RegularizationMethod is a callable.
     The method `update` is called to update the loss, typically at the end
     of an experience.
     """

     def pre_adapt(self, agent, exp):
         pass  # implementation may be empty if adapt is not needed
@@ -13,3 +13,7 @@ class RegularizationMethod:
     def __call__(self, *args, **kwargs):
         raise NotImplementedError()
+
+
+from .ewc import EWC
+from .lwf import LWF
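With EWC and LWF now re-exported from the package, a new strategy only has to follow the same callable contract. A hypothetical minimal subclass (the class `L2Reg` and its `lam` parameter are illustrative; the `adapt(output, model, ...)` shape mirrors the EWC code below):

    import torch

    from . import RegularizationMethod


    class L2Reg(RegularizationMethod):
        """Illustrative L2 penalty following the EWC/LWF adapt contract."""

        def __init__(self, lam: float = 0.1):
            self.lam = lam

        def adapt(self, output, model, **kwargs):
            # Add a plain L2 penalty over trainable parameters to the loss.
            penalty = sum(
                p.pow(2).sum() for p in model.parameters() if p.requires_grad
            )
            output["loss"] += self.lam * penalty
            return output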

View File

@@ -1,6 +1,7 @@
 from . import RegularizationMethod
 import torch
+

 class EWC(RegularizationMethod):
     """Learning Without Forgetting.
@@ -24,14 +25,19 @@ class EWC(RegularizationMethod):
     Knowledge distillation uses only units corresponding
     to old classes.
     """
+
-    def adapt(self, output,model, **kwargs):
+    def adapt(self, output, model, **kwargs):
         ewc_loss = 0
         for n, p in model.named_parameters():
             if p.requires_grad:
                 dev = p.device
-                l = self.EWC_lambda * self.fisher[n].to(dev) * (p.data - self.optpar[n].to(dev)).pow(2)
+                l = (
+                    self.EWC_lambda
+                    * self.fisher[n].to(dev)
+                    * (p.data - self.optpar[n].to(dev)).pow(2)
+                )
                 ewc_loss += l.sum()
-        output['loss'] += ewc_loss
+        output["loss"] += ewc_loss
         return output

     def init_epoch(self, model):
@@ -42,6 +48,7 @@ class EWC(RegularizationMethod):
             if p.requires_grad:
                 fisher[n] = torch.zeros(p.data.shape)
                 optpar[n] = p.clone().cpu().data
+
     def update_fisher(self, model):
         """Update the fisher information for the given question id."""
         for n, p in model.module.base_model.model.named_parameters():
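For reference, the penalty reflowed above is the standard EWC term, λ · Σᵢ Fᵢ (θᵢ − θ*ᵢ)², where F is the Fisher information and θ* the parameters after the previous task. The hunk stops at the start of `update_fisher`; a hedged sketch of how it typically accumulates Fisher information from squared gradients (the accumulate-after-backward pattern is an assumption; only the iteration over `model.module.base_model.model.named_parameters()` appears in the diff):

    def update_fisher(self, model):
        # Assumed pattern: Fisher ≈ E[grad²], gathered after backward()
        # on the previous task's data, stored on CPU alongside optpar.
        for n, p in model.module.base_model.model.named_parameters():
            if p.requires_grad and p.grad is not None:
                self.fisher[n] += p.grad.data.clone().pow(2).cpu()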

View File

@@ -1,6 +1,7 @@
 from . import RegularizationMethod
 import torch
+

 class LWF(RegularizationMethod):
     """Learning Without Forgetting.
@@ -23,6 +24,7 @@ class LWF(RegularizationMethod):
     Knowledge distillation uses only units corresponding
     to old classes.
     """
+
     def adapt(self, output, **kwargs):
         def modified_kl_div(old, new):
             return -torch.mean(torch.sum(old * torch.log(new), 1))
@@ -37,16 +39,27 @@ class LWF(RegularizationMethod):
         previous_keys = self.previous_logits.keys()
-        for index, question_id in enumerate(iterable=kwargs['question_ids']):
+        for index, question_id in enumerate(iterable=kwargs["question_ids"]):
             if question_id in previous_keys:
                 previous_logits = self.previous_logits[question_id]
-                current_logits = output['logits'][index]
+                current_logits = output["logits"][index]
                 short_index = min(len(previous_logits), len(current_logits))
                 previous_logits = previous_logits[:short_index]
                 current_logits = current_logits[:short_index]
-                lwf_loss.append(modified_kl_div(old=smooth(logits=soft(previous_logits).to(current_logits.device), temp=2, dim=1),new=smooth(logits=soft(current_logits), temp=2, dim=1)))
+                lwf_loss.append(
+                    modified_kl_div(
+                        old=smooth(
+                            logits=soft(previous_logits).to(current_logits.device),
+                            temp=2,
+                            dim=1,
+                        ),
+                        new=smooth(logits=soft(current_logits), temp=2, dim=1),
+                    )
+                )
         if len(lwf_loss) > 0:
-            output['loss'] += self.LWF_lambda * torch.stack(tensors=lwf_loss, dim=0).sum(dim=0)
+            output["loss"] += self.LWF_lambda * torch.stack(
+                tensors=lwf_loss, dim=0
+            ).sum(dim=0)
         return output

     def update_previous_logits(self, question_id, logits):
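The distillation term calls `soft` and `smooth` helpers that this hunk does not show. A reconstruction consistent with the call sites — `soft(logits)` followed by `smooth(logits=..., temp=2, dim=1)` — is a softmax plus the temperature renormalization used in Learning without Forgetting (both bodies are assumptions):

    import torch


    def soft(logits):
        # Assumed: plain softmax over the class dimension.
        return torch.nn.functional.softmax(logits, dim=-1)


    def smooth(logits, temp, dim):
        # Assumed: raise probabilities to 1/temp and renormalize,
        # the LwF temperature trick.
        logits = logits.pow(1 / temp)
        return logits / logits.sum(dim=dim, keepdim=True)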

View File

@@ -15,6 +15,8 @@ from trl import (
 from utils.trainer import ContinualTrainer
 from utils.args import ContinualScriptArguments, ContinualModelConfig
 import logging
+from typing import TYPE_CHECKING
+

 logging.basicConfig(level=logging.INFO)
@@ -24,15 +26,16 @@ if __name__ == "__main__":
         (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
     )
     script_args, training_args, model_args = parser.parse_args_and_config()
+    # for type hint
+    if 0 == 1:
+        script_args = ContinualScriptArguments()
+        training_args = TrainingArguments()
+        model_args = ContinualModelConfig()
+
     training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
     training_args.remove_unused_columns = False
     training_args.dataset_kwargs = {"skip_prepare_dataset": True}
-    # for type hint
-    if 1 == 0:
-        script_args = ContinualScriptArguments
-        training_args = TrainingArguments
-        model_args = ContinualModelConfig

     from model_library.factory import get_model
     model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
@@ -68,6 +71,8 @@ if __name__ == "__main__":
     else:
         peft_config = None

+    from peft import get_peft_model
+
     if accelerator.is_local_main_process:
         print(model)
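The relocated `if 0 == 1:` block exists only so the IDE can infer concrete types for the tuple returned by `parse_args_and_config()`. Since `TYPE_CHECKING` is now imported, the same effect is available without dead runtime code (a sketch of the alternative; whether the project switches to it is not part of this commit):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only the type checker sees these annotations; nothing runs.
        script_args: ContinualScriptArguments
        training_args: TrainingArguments
        model_args: ContinualModelConfig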

View File

@@ -31,4 +31,3 @@ class ContiunalRegularizationArguments:
     # LWF
     lwf_lambda: float = 0.0
     lwf_enable: bool = False
-
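For context, the visible fields suggest the full dataclass looks roughly like this (the `@dataclass` decorator and the mirrored EWC fields are assumptions, inferred from the `# LWF` comment here and the `EWC, LWF` import added in `trainer.py` below; the class name's spelling is kept as it appears in the code):

    from dataclasses import dataclass


    @dataclass
    class ContiunalRegularizationArguments:
        # EWC
        ewc_lambda: float = 0.0
        ewc_enable: bool = False

        # LWF
        lwf_lambda: float = 0.0
        lwf_enable: bool = False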

View File

@@ -15,6 +15,7 @@ from transformers import (
     TrainingArguments,
 )
 from .args import ContiunalRegularizationArguments
+from peft_library.regularizations import EWC, LWF


 class ContinualTrainer(Trainer):
@@ -26,7 +27,7 @@ class ContinualTrainer(Trainer):
         train_dataset,
         eval_dataset,
         accelerator,
-        regularization_args:ContiunalRegularizationArguments=None,
+        regularization_args: ContiunalRegularizationArguments = None,
     ):
         self.accelerator = accelerator
         super().__init__(model, args, data_collator, train_dataset, eval_dataset)
@@ -38,7 +39,6 @@ class ContinualTrainer(Trainer):
         if regularization_args.lwf_enable:
             self.lwf_lambda = regularization_args.lwf_lambda
-
     def create_accelerator_and_postprocess(self):
         if self.accelerator is not None:
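Taken together, the new signature means the trainer is constructed with the regularization switches alongside the usual Trainer arguments. A hypothetical call site (argument values are illustrative; only the parameter names shown in the hunks above are taken from the code):

    trainer = ContinualTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn_for_train,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        accelerator=accelerator,
        regularization_args=regularization_args,
    )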