fix: fix dataset registration and loading logic, clean up code formatting for consistency

YunyaoZhou 2025-05-20 13:33:17 +08:00
parent 1d2e7b9dcd
commit 56e46f0e0c
17 changed files with 94 additions and 51 deletions

View File

@@ -30,7 +30,7 @@ class OCRVQADataset(Dataset):
self.text_processor = text_processor
self.split = split
self.ann_path = ann_path
def load_data(self):
if self.split == "train":
self.data = self.create_data(Path(self.ann_path, "dataset.json"), split=1)
@@ -165,4 +165,4 @@ dataset = {
),
}
register_dataset("ocrvqa200k", dataset, tag=["image", "text"])
register_dataset("ocrvqa200k", dataset, tag=["image", "text"])

View File

@@ -25,8 +25,10 @@ class RefCOCODataset(Dataset):
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split
def load_data(self):
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir) # type: ignore
self.data = ds[self.split] # type: ignore
@@ -118,12 +120,14 @@ def test_RefCOCO():
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = {
"train": RefCOCODataset(split="val"),
"test": RefCOCODataset(split="test"),
"generation": RefCOCODatasetForGeneration(split="test"),
}
"train": RefCOCODataset(split="val"),
"test": RefCOCODataset(split="test"),
"generation": RefCOCODatasetForGeneration(split="test"),
}
from .factory import register_dataset
register_dataset(
dataset_name="refcoco",
dataset=dataset,

View File

@@ -25,9 +25,10 @@ class RefCOCOplusDataset(Dataset):
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split
def load_data(self):
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir) # type: ignore
self.data = ds[self.split] # type: ignore
@@ -119,12 +120,14 @@ def test_RefCOCOplus():
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = {
"train": RefCOCOplusDataset(split="val"),
"test": RefCOCOplusDataset(split="testA"),
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
}
"train": RefCOCOplusDataset(split="val"),
"test": RefCOCOplusDataset(split="testA"),
"generation": RefCOCOplusDatasetForGeneration(split="testA"),
}
from .factory import register_dataset
register_dataset(
dataset_name="refcocoplus",
dataset=dataset,

View File

@@ -25,9 +25,10 @@ class RefCOCOgDataset(Dataset):
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split
def load_data(self):
from .format import dataset_dir
ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir) # type: ignore
self.data = ds[self.split] # type: ignore
@@ -119,12 +120,14 @@ def test_RefCOCOg():
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = {
"train": RefCOCOgDataset(split="val"),
"test": RefCOCOgDataset(split="test"),
"generation": RefCOCOgDatasetForGeneration(split="test"),
}
"train": RefCOCOgDataset(split="val"),
"test": RefCOCOgDataset(split="test"),
"generation": RefCOCOgDatasetForGeneration(split="test"),
}
from .factory import register_dataset
register_dataset(
dataset_name="refcocog",
dataset=dataset,
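
RefCOCO, RefCOCOplus, and RefCOCOg all receive the same refactor: the HuggingFace download moves out of the constructor into a dedicated load_data method. A minimal sketch of the shape these classes appear to share; the chat-formatting __getitem__ is not visible in this diff and is omitted, and the constructor call to load_data is an assumption:

```python
from datasets import load_dataset
from torch.utils.data import Dataset

class RefCOCOLikeDataset(Dataset):
    """Sketch of the deferred-loading pattern shared by the RefCOCO* classes."""

    def __init__(self, vis_processor=None, text_processor=None, split="val"):
        self.vis_processor = vis_processor
        self.text_processor = text_processor
        self.split = split
        self.load_data()  # assumption: loading still happens on construction

    def load_data(self):
        # the repository also passes cache_dir=dataset_dir from .format
        ds = load_dataset("lmms-lab/RefCOCO")  # type: ignore
        self.data = ds[self.split]  # type: ignore

    def __len__(self):
        return len(self.data)
```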

View File

@@ -19,10 +19,11 @@ class ScienceQADataset(Dataset):
self.vis_processor = vis_processor
self.text_processor = text_processor
self.split = split
def load_data(self):
from .format import dataset_dir
ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir) # type: ignore
ds = load_dataset("derek-thomas/ScienceQA", cache_dir=dataset_dir) # type: ignore
self.data = ds[self.split] # type: ignore
def __len__(self):
@@ -119,12 +120,14 @@ def test_scienceQA():
assert len(dataset) > 0
assert len(dataset[0]["chat"]) > 0
dataset = {
"train": ScienceQADataset(split="train"),
"test": ScienceQADataset(split="test"),
"generation": ScienceQADatasetForGeneration(split="test"),
}
from .factory import register_dataset
register_dataset(
dataset_name="scienceqa",
dataset=dataset,

View File

@@ -28,7 +28,7 @@ class TextVQADataset(Dataset):
self.split = split
self.ann_path = ann_path
self.vis_root = vis_root
def load_data(self):
if self.split == "train":
self.data = self.create_data(

View File

@@ -30,7 +30,7 @@ class VizWizDataset(Dataset):
)
self.text_processor = text_processor
self.split = split
def load_data(self):
if self.split == "train":
self.data = self.create_data(Path(self.ann_path, "train.json"))

View File

@@ -1,5 +1,5 @@
from torch.utils.data import Dataset
-from typing import Literal,List
+from typing import Literal, List
from pathlib import Path
from dataset_library.format import dataset_dir
@@ -58,7 +58,6 @@ from .VizWizDataset import (
)
def get_dataset(
dataset_name: str, base_path=dataset_dir
) -> dict[Literal["train", "test", "generation"], Dataset]:

View File

@@ -32,4 +32,4 @@ class DatasetOutput(TypedDict):
answer: Optional[str]
original: Any
images: Optional[list[Image.Image]]
audios: Optional[list[Tuple[np.ndarray, int]]]
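
Only the tail of DatasetOutput is visible here. Together with the "chat" key the tests below assert on, the TypedDict plausibly looks like this sketch; any fields above the visible hunk are unknown and omitted, and "chat" is an inference, not shown in the hunk:

```python
from typing import Any, Optional, Tuple, TypedDict

import numpy as np
from PIL import Image

class DatasetOutput(TypedDict):
    chat: list[dict[str, Any]]  # inferred from the tests below
    answer: Optional[str]
    original: Any
    images: Optional[list[Image.Image]]
    audios: Optional[list[Tuple[np.ndarray, int]]]
```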

View File

@@ -4,18 +4,19 @@ from .factory import get_dataset, MAPPING_NAME_TO_DATASET
# Get all registered dataset names for parameterization
dataset_names = list(MAPPING_NAME_TO_DATASET.keys())
@pytest.mark.parametrize("dataset_name", dataset_names)
def test_registered_datasets(dataset_name):
dataset = get_dataset(dataset_name)
# Test train split
assert "train" in dataset, f"Train split not found in {dataset_name}"
assert len(dataset["train"]) > 0, f"Train split is empty for {dataset_name}" # type: ignore
assert "chat" in dataset["train"][0], f"'chat' key not found in first train sample of {dataset_name}" # type: ignore
assert len(dataset["train"][0]["chat"]) > 0, f"'chat' is empty in first train sample of {dataset_name}" # type: ignore
assert len(dataset["train"]) > 0, f"Train split is empty for {dataset_name}" # type: ignore
assert "chat" in dataset["train"][0], f"'chat' key not found in first train sample of {dataset_name}" # type: ignore
assert len(dataset["train"][0]["chat"]) > 0, f"'chat' is empty in first train sample of {dataset_name}" # type: ignore
# Test test split
assert "test" in dataset, f"Test split not found in {dataset_name}"
assert len(dataset["test"]) > 0, f"Test split is empty for {dataset_name}" # type: ignore
assert "chat" in dataset["test"][0], f"'chat' key not found in first test sample of {dataset_name}" # type: ignore
assert len(dataset["test"][0]["chat"]) > 0, f"'chat' is empty in first test sample of {dataset_name}" # type: ignore
assert len(dataset["test"]) > 0, f"Test split is empty for {dataset_name}" # type: ignore
assert "chat" in dataset["test"][0], f"'chat' key not found in first test sample of {dataset_name}" # type: ignore
assert len(dataset["test"][0]["chat"]) > 0, f"'chat' is empty in first test sample of {dataset_name}" # type: ignore

View File

@@ -1,4 +1,6 @@
from PIL import Image
def size_processor(image: Image.Image):
width, height = image.size
if width > 500 or height > 500:
@@ -15,4 +17,4 @@ def size_processor(image: Image.Image):
new_height = int(height * ratio)
image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
return image
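
Only the head and tail of size_processor are visible, but together they pin down the behavior: cap the longer side at 500 pixels while preserving aspect ratio. A sketch filling in the hidden middle; the exact ratio computation is an assumption consistent with both visible fragments:

```python
from PIL import Image

def size_processor(image: Image.Image) -> Image.Image:
    # Downscale so that neither side exceeds 500 px.
    width, height = image.size
    if width > 500 or height > 500:
        ratio = 500 / max(width, height)  # assumed; hidden in the diff
        new_width = int(width * ratio)
        new_height = int(height * ratio)
        image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
    return image
```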

View File

@@ -1,10 +1,10 @@
class RegularizationMethod:
"""RegularizationMethod implement regularization strategies.
RegularizationMethod is a callable.
The method `update` is called to update the loss, typically at the end
of an experience.
"""
def pre_adapt(self, agent, exp):
pass # implementation may be empty if adapt is not needed
@@ -12,4 +12,8 @@ class RegularizationMethod:
pass # implementation may be empty if adapt is not needed
def __call__(self, *args, **kwargs):
raise NotImplementedError()
+from .ewc import EWC
+from .lwf import LWF
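
RegularizationMethod defines the plug-in contract for the two classes re-exported here: pre_adapt runs before an experience, adapt rewrites a loss dict, and __call__ is deliberately left unimplemented. A hedged sketch of how trainer code might drive such methods (the real call sites live in ContinualTrainer, which is not fully visible in this diff):

```python
def apply_regularization(output: dict, model, methods: list, **kwargs) -> dict:
    # Illustrative driver only. Each adapt() adds its penalty into
    # output["loss"] and returns the dict; EWC reads model parameters
    # directly, LWF additionally expects question_ids in kwargs.
    for method in methods:  # e.g. [EWC(...), LWF(...)]
        output = method.adapt(output, model=model, **kwargs)
    return output
```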

View File

@@ -1,6 +1,7 @@
from . import RegularizationMethod
import torch
class EWC(RegularizationMethod):
"""Learning Without Forgetting.
@@ -24,16 +25,21 @@ class EWC(RegularizationMethod):
Knowledge distillation uses only units corresponding
to old classes.
"""
-def adapt(self, output,model, **kwargs):
+def adapt(self, output, model, **kwargs):
ewc_loss = 0
for n, p in model.named_parameters():
if p.requires_grad:
dev = p.device
-l = self.EWC_lambda * self.fisher[n].to(dev) * (p.data - self.optpar[n].to(dev)).pow(2)
+l = (
+    self.EWC_lambda
+    * self.fisher[n].to(dev)
+    * (p.data - self.optpar[n].to(dev)).pow(2)
+)
ewc_loss += l.sum()
-output['loss'] += ewc_loss
+output["loss"] += ewc_loss
return output
def init_epoch(self, model):
"""Update the previous logits for the given question id."""
optpar = {}
@@ -42,10 +48,11 @@ class EWC(RegularizationMethod):
if p.requires_grad:
fisher[n] = torch.zeros(p.data.shape)
optpar[n] = p.clone().cpu().data
def update_fisher(self, model):
"""Update the fisher information for the given question id."""
for n, p in model.module.base_model.model.named_parameters():
if p.requires_grad:
fisher = self.fisher[n]
fisher += p.grad.data.pow(2).cpu()
self.fisher[n] = fisher
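
Despite the docstring, which still says "Learning Without Forgetting" (apparently copied from LWF), this class implements Elastic Weight Consolidation: a quadratic pull toward the previous task's parameters, weighted per parameter by the diagonal Fisher information, i.e. lambda * F[n] * (theta[n] - theta_old[n])^2. A standalone restatement of the reformatted loop above:

```python
import torch

def ewc_penalty(model: torch.nn.Module, fisher: dict, optpar: dict, ewc_lambda: float):
    # fisher holds per-parameter Fisher estimates, optpar the parameter
    # snapshot taken after the previous task (see init_epoch/update_fisher).
    penalty = torch.tensor(0.0)
    for n, p in model.named_parameters():
        if p.requires_grad:
            dev = p.device
            # note: the source uses p.data, which detaches the penalty from
            # the autograd graph; using p keeps it differentiable (an
            # assumption about the intended behavior)
            penalty = penalty.to(dev) + (
                ewc_lambda * fisher[n].to(dev) * (p - optpar[n].to(dev)).pow(2)
            ).sum()
    return penalty
```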

View File

@@ -1,6 +1,7 @@
from . import RegularizationMethod
import torch
class LWF(RegularizationMethod):
"""Learning Without Forgetting.
@@ -23,6 +24,7 @@ class LWF(RegularizationMethod):
Knowledge distillation uses only units corresponding
to old classes.
"""
def adapt(self, output, **kwargs):
def modified_kl_div(old, new):
return -torch.mean(torch.sum(old * torch.log(new), 1))
@@ -37,18 +39,29 @@ class LWF(RegularizationMethod):
previous_keys = self.previous_logits.keys()
-for index, question_id in enumerate(iterable=kwargs['question_ids']):
+for index, question_id in enumerate(iterable=kwargs["question_ids"]):
if question_id in previous_keys:
previous_logits = self.previous_logits[question_id]
-current_logits = output['logits'][index]
+current_logits = output["logits"][index]
short_index = min(len(previous_logits), len(current_logits))
previous_logits = previous_logits[:short_index]
current_logits = current_logits[:short_index]
-lwf_loss.append(modified_kl_div(old=smooth(logits=soft(previous_logits).to(current_logits.device), temp=2, dim=1),new=smooth(logits=soft(current_logits), temp=2, dim=1)))
+lwf_loss.append(
+    modified_kl_div(
+        old=smooth(
+            logits=soft(previous_logits).to(current_logits.device),
+            temp=2,
+            dim=1,
+        ),
+        new=smooth(logits=soft(current_logits), temp=2, dim=1),
+    )
+)
if len(lwf_loss) > 0:
-output['loss'] += self.LWF_lambda * torch.stack(tensors=lwf_loss, dim=0).sum(dim=0)
+output["loss"] += self.LWF_lambda * torch.stack(
+    tensors=lwf_loss, dim=0
+).sum(dim=0)
return output
def update_previous_logits(self, question_id, logits):
"""Update the previous logits for the given question id."""
self.previous_logits[question_id] = logits
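
The reformatted block above assembles, for every remembered question id, a distillation term between the stored logits and the current logits, both passed through soft and smooth helpers that are defined above the visible hunk. Under the common reading of those names, softmax followed by temperature sharpening with renormalization, the term looks like this sketch; the helper definitions are assumptions:

```python
import torch
import torch.nn.functional as F

def lwf_distillation(previous_logits: torch.Tensor,
                     current_logits: torch.Tensor,
                     temp: float = 2.0) -> torch.Tensor:
    def smooth(probs: torch.Tensor, temp: float, dim: int) -> torch.Tensor:
        # assumed form: temperature sharpening with renormalization
        powered = probs.pow(1.0 / temp)
        return powered / powered.sum(dim=dim, keepdim=True)

    old = smooth(F.softmax(previous_logits, dim=-1), temp=temp, dim=1)
    new = smooth(F.softmax(current_logits, dim=-1), temp=temp, dim=1)
    # modified_kl_div from the source: cross-entropy of new under old
    return -torch.mean(torch.sum(old * torch.log(new), dim=1))
```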

View File

@@ -15,6 +15,8 @@ from trl import (
from utils.trainer import ContinualTrainer
from utils.args import ContinualScriptArguments, ContinualModelConfig
import logging
+from typing import TYPE_CHECKING
logging.basicConfig(level=logging.INFO)
@@ -24,15 +26,16 @@ if __name__ == "__main__":
(ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
)
script_args, training_args, model_args = parser.parse_args_and_config()
+# for type hint
+if 0 == 1:
+    script_args = ContinualScriptArguments()
+    training_args = TrainingArguments()
+    model_args = ContinualModelConfig()
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
-# for type hint
-if 1 == 0:
-    script_args = ContinualScriptArguments
-    training_args = TrainingArguments
-    model_args = ContinualModelConfig
from model_library.factory import get_model
model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
@@ -68,6 +71,8 @@ if __name__ == "__main__":
else:
peft_config = None
+from peft import get_peft_model
if accelerator.is_local_main_process:
print(model)
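
The dead "if 0 == 1:" block above exists only so editors can infer concrete types for the tuple returned by parse_args_and_config; it never runs. Given the typing import added in this commit, a more conventional spelling is typing.cast, sketched here with a stand-in class rather than the repository's actual argument types:

```python
from dataclasses import dataclass
from typing import cast

@dataclass
class ScriptArgs:  # stand-in for ContinualScriptArguments
    dataset_name: str = ""

def parse_args() -> tuple:  # stand-in for HfArgumentParser.parse_args_and_config()
    return (ScriptArgs(),)

(script_args,) = parse_args()
script_args = cast(ScriptArgs, script_args)  # no-op at runtime, precise for the IDE
```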

View File

@@ -31,4 +31,3 @@ class ContiunalRegularizationArguments:
# LWF
lwf_lambda: float = 0.0
lwf_enable: bool = False
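
Only the LWF tail of ContiunalRegularizationArguments is visible (the misspelling is the repository's own identifier). Together with the lwf_enable read in ContinualTrainer below, the whole container plausibly looks like the following sketch; the EWC field names are inferred by symmetry and the @dataclass decorator is an assumption:

```python
from dataclasses import dataclass

@dataclass
class ContiunalRegularizationArguments:
    # EWC (inferred by symmetry with the visible LWF fields)
    ewc_lambda: float = 0.0
    ewc_enable: bool = False
    # LWF (visible in the hunk above)
    lwf_lambda: float = 0.0
    lwf_enable: bool = False
```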

View File

@@ -15,6 +15,7 @@ from transformers import (
TrainingArguments,
)
from .args import ContiunalRegularizationArguments
+from peft_library.regularizations import EWC, LWF
class ContinualTrainer(Trainer):
@@ -26,7 +27,7 @@ class ContinualTrainer(Trainer):
train_dataset,
eval_dataset,
accelerator,
-regularization_args:ContiunalRegularizationArguments=None,
+regularization_args: ContiunalRegularizationArguments = None,
):
self.accelerator = accelerator
super().__init__(model, args, data_collator, train_dataset, eval_dataset)
@@ -38,7 +39,6 @@ class ContinualTrainer(Trainer):
if regularization_args.lwf_enable:
self.lwf_lambda = regularization_args.lwf_lambda
def create_accelerator_and_postprocess(self):
if self.accelerator is not None:
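
The final hunk shows ContinualTrainer storing an injected accelerator before calling super().__init__, then checking it inside create_accelerator_and_postprocess, the Trainer hook that normally constructs its own Accelerator. The likely intent, sketched with an assumption about the unseen else branch:

```python
from transformers import Trainer

class ContinualTrainerSketch(Trainer):
    # self.accelerator must be set before super().__init__, because Trainer's
    # constructor calls create_accelerator_and_postprocess internally.
    def __init__(self, *args, accelerator=None, **kwargs):
        self.accelerator = accelerator
        super().__init__(*args, **kwargs)

    def create_accelerator_and_postprocess(self):
        if self.accelerator is not None:
            return  # reuse the externally created Accelerator
        super().create_accelerator_and_postprocess()  # assumed fallback
```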