fix: repair dataset registration and loading logic, clean up code formatting for consistency
This commit is contained in:
parent 1d2e7b9dcd
commit 56e46f0e0c
@@ -25,8 +25,10 @@ class RefCOCODataset(Dataset):
         self.vis_processor = vis_processor
         self.text_processor = text_processor
         self.split = split
 
     def load_data(self):
         from .format import dataset_dir
 
         ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore
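
For readers unfamiliar with the `datasets` API used in `load_data`: `load_dataset` returns a `DatasetDict` keyed by split name, which is why indexing with `ds[self.split]` works. A minimal standalone sketch (the repo id and `cache_dir` usage come from the diff; the rest is illustrative):

```python
from datasets import load_dataset

# Downloads on first use, then reuses the cache under cache_dir.
ds = load_dataset("lmms-lab/RefCOCO", cache_dir="./data")
print(list(ds.keys()))   # available splits, e.g. ["val", "test", ...]
val = ds["val"]          # a datasets.Dataset for one split
print(len(val))          # row count
print(val[0])            # first example as a dict
```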
@@ -118,12 +120,14 @@ def test_RefCOCO():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
 
 
 dataset = {
     "train": RefCOCODataset(split="val"),
     "test": RefCOCODataset(split="test"),
     "generation": RefCOCODatasetForGeneration(split="test"),
-}
+}
 from .factory import register_dataset
 
 register_dataset(
     dataset_name="refcoco",
     dataset=dataset,
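
For context on what these module-level blocks accomplish: each dataset file builds a split dict and registers it under a name, which `get_dataset` later resolves. A minimal sketch of the factory implied by the names in this diff (only `register_dataset`, `get_dataset`, and `MAPPING_NAME_TO_DATASET` appear in the commit; the bodies are assumptions):

```python
# Hedged sketch of the .factory module implied by this diff.
from typing import Literal
from torch.utils.data import Dataset

MAPPING_NAME_TO_DATASET: dict[str, dict] = {}

def register_dataset(dataset_name: str, dataset: dict, **kwargs) -> None:
    # Each dataset module calls this at import time with its split dict.
    MAPPING_NAME_TO_DATASET[dataset_name] = dataset

def get_dataset(
    dataset_name: str, base_path=None
) -> dict[Literal["train", "test", "generation"], Dataset]:
    return MAPPING_NAME_TO_DATASET[dataset_name]
```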
@@ -28,6 +28,7 @@ class RefCOCOplusDataset(Dataset):
 
     def load_data(self):
         from .format import dataset_dir
 
         ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore
@@ -119,12 +120,14 @@ def test_RefCOCOplus():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
 
 
 dataset = {
     "train": RefCOCOplusDataset(split="val"),
     "test": RefCOCOplusDataset(split="testA"),
     "generation": RefCOCOplusDatasetForGeneration(split="testA"),
-}
+}
 from .factory import register_dataset
 
 register_dataset(
     dataset_name="refcocoplus",
     dataset=dataset,
@@ -28,6 +28,7 @@ class RefCOCOgDataset(Dataset):
 
     def load_data(self):
         from .format import dataset_dir
 
         ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore
@@ -119,12 +120,14 @@ def test_RefCOCOg():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
 
 
 dataset = {
     "train": RefCOCOgDataset(split="val"),
     "test": RefCOCOgDataset(split="test"),
     "generation": RefCOCOgDatasetForGeneration(split="test"),
-}
+}
 from .factory import register_dataset
 
 register_dataset(
     dataset_name="refcocog",
     dataset=dataset,
@@ -22,7 +22,8 @@ class ScienceQADataset(Dataset):
 
     def load_data(self):
         from .format import dataset_dir
-        ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir)  # type: ignore
+
+        ds = load_dataset("derek-thomas/ScienceQA", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore
 
     def __len__(self):
@@ -119,12 +120,14 @@ def test_scienceQA():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0
 
 
 dataset = {
     "train": ScienceQADataset(split="train"),
     "test": ScienceQADataset(split="test"),
     "generation": ScienceQADatasetForGeneration(split="test"),
 }
 from .factory import register_dataset
 
 register_dataset(
     dataset_name="scienceqa",
     dataset=dataset,
@@ -1,5 +1,5 @@
 from torch.utils.data import Dataset
-from typing import Literal,List
+from typing import Literal, List
 from pathlib import Path
 from dataset_library.format import dataset_dir
@@ -58,7 +58,6 @@ from .VizWizDataset import (
 )
 
 
 def get_dataset(
     dataset_name: str, base_path=dataset_dir
 ) -> dict[Literal["train", "test", "generation"], Dataset]:
@@ -4,6 +4,7 @@ from .factory import get_dataset, MAPPING_NAME_TO_DATASET
 # Get all registered dataset names for parameterization
 dataset_names = list(MAPPING_NAME_TO_DATASET.keys())
 
 
 @pytest.mark.parametrize("dataset_name", dataset_names)
 def test_registered_datasets(dataset_name):
     dataset = get_dataset(dataset_name)
@@ -1,4 +1,6 @@
 from PIL import Image
 
 
 def size_processor(image: Image.Image):
     width, height = image.size
     if width > 500 or height > 500:
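
The hunk truncates `size_processor` here. A plausible completion, assuming the intent is to cap the longer side at 500 px while preserving aspect ratio (purely illustrative; the real body may differ):

```python
from PIL import Image

def size_processor(image: Image.Image) -> Image.Image:
    """Downscale so neither side exceeds 500 px, keeping aspect ratio."""
    width, height = image.size
    if width > 500 or height > 500:
        scale = 500 / max(width, height)  # assumed resize policy
        image = image.resize(
            (int(width * scale), int(height * scale)), Image.BICUBIC
        )
    return image
```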
@@ -1,10 +1,10 @@
 class RegularizationMethod:
     """RegularizationMethod implements regularization strategies.
 
     RegularizationMethod is a callable.
     The method `update` is called to update the loss, typically at the end
     of an experience.
     """
 
     def pre_adapt(self, agent, exp):
         pass  # implementation may be empty if adapt is not needed
@@ -13,3 +13,7 @@ class RegularizationMethod:
 
     def __call__(self, *args, **kwargs):
         raise NotImplementedError()
 
 
 from .ewc import EWC
 from .lwf import LWF
@@ -1,6 +1,7 @@
 from . import RegularizationMethod
 import torch
 
 
 class EWC(RegularizationMethod):
     """Learning Without Forgetting.
@@ -24,14 +25,19 @@ class EWC(RegularizationMethod):
         Knowledge distillation uses only units corresponding
         to old classes.
         """
-    def adapt(self, output,model, **kwargs):
+
+    def adapt(self, output, model, **kwargs):
         ewc_loss = 0
         for n, p in model.named_parameters():
             if p.requires_grad:
                 dev = p.device
-                l = self.EWC_lambda * self.fisher[n].to(dev) * (p.data - self.optpar[n].to(dev)).pow(2)
+                l = (
+                    self.EWC_lambda
+                    * self.fisher[n].to(dev)
+                    * (p.data - self.optpar[n].to(dev)).pow(2)
+                )
                 ewc_loss += l.sum()
-        output['loss'] += ewc_loss
+        output["loss"] += ewc_loss
         return output
 
     def init_epoch(self, model):
@@ -42,6 +48,7 @@ class EWC(RegularizationMethod):
             if p.requires_grad:
                 fisher[n] = torch.zeros(p.data.shape)
                 optpar[n] = p.clone().cpu().data
 
     def update_fisher(self, model):
         """Update the fisher information for the given question id."""
         for n, p in model.module.base_model.model.named_parameters():
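
Two notes on the EWC hunks. First, the docstring ("Learning Without Forgetting") is a copy-paste from LWF; the class implements Elastic Weight Consolidation, penalizing drift as lambda * sum_i F_i * (theta_i - theta_star_i)^2, which is exactly what `adapt` computes from `fisher` and `optpar`. Second, the diff never shows how `update_fisher` fills `fisher`; a common estimate accumulates squared gradients over a pass of the old task, sketched here as an assumption rather than this repo's actual code:

```python
# Hedged sketch: one common way to estimate the diagonal Fisher
# information. The names fisher/optpar match the diff; the accumulation
# logic itself is an assumption.
import torch

def estimate_fisher(model, dataloader):
    fisher = {
        n: torch.zeros_like(p, device="cpu")
        for n, p in model.named_parameters() if p.requires_grad
    }
    model.eval()
    for batch in dataloader:
        model.zero_grad()
        loss = model(**batch)["loss"]  # assumes a dict-style model output
        loss.backward()
        for n, p in model.named_parameters():
            if p.requires_grad and p.grad is not None:
                fisher[n] += p.grad.detach().pow(2).cpu()
    for n in fisher:
        fisher[n] /= max(len(dataloader), 1)  # average over batches
    return fisher
```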
@@ -1,6 +1,7 @@
 from . import RegularizationMethod
 import torch
 
 
 class LWF(RegularizationMethod):
     """Learning Without Forgetting.
@@ -23,6 +24,7 @@ class LWF(RegularizationMethod):
         Knowledge distillation uses only units corresponding
         to old classes.
         """
 
     def adapt(self, output, **kwargs):
         def modified_kl_div(old, new):
             return -torch.mean(torch.sum(old * torch.log(new), 1))
@@ -37,16 +39,27 @@
 
         previous_keys = self.previous_logits.keys()
 
-        for index, question_id in enumerate(iterable=kwargs['question_ids']):
+        for index, question_id in enumerate(iterable=kwargs["question_ids"]):
             if question_id in previous_keys:
                 previous_logits = self.previous_logits[question_id]
-                current_logits = output['logits'][index]
+                current_logits = output["logits"][index]
                 short_index = min(len(previous_logits), len(current_logits))
                 previous_logits = previous_logits[:short_index]
                 current_logits = current_logits[:short_index]
-                lwf_loss.append(modified_kl_div(old=smooth(logits=soft(previous_logits).to(current_logits.device), temp=2, dim=1),new=smooth(logits=soft(current_logits), temp=2, dim=1)))
+                lwf_loss.append(
+                    modified_kl_div(
+                        old=smooth(
+                            logits=soft(previous_logits).to(current_logits.device),
+                            temp=2,
+                            dim=1,
+                        ),
+                        new=smooth(logits=soft(current_logits), temp=2, dim=1),
+                    )
+                )
         if len(lwf_loss) > 0:
-            output['loss'] += self.LWF_lambda * torch.stack(tensors=lwf_loss, dim=0).sum(dim=0)
+            output["loss"] += self.LWF_lambda * torch.stack(
+                tensors=lwf_loss, dim=0
+            ).sum(dim=0)
         return output
 
     def update_previous_logits(self, question_id, logits):
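
`soft` and `smooth` are called above but not defined in this hunk. Presumably `soft` is a softmax and `smooth` a temperature transform in the Learning without Forgetting style; a hedged guess at their shape (names from the diff, bodies assumed):

```python
import torch

# Assumed definitions, consistent with how adapt() calls them.
soft = torch.nn.Softmax(dim=-1)

def smooth(logits, temp, dim):
    # Raise probabilities to 1/temp and renormalize (LwF-style smoothing).
    logits = logits.pow(1.0 / temp)
    return logits / logits.sum(dim=dim, keepdim=True)
```

With `temp=2` this flattens both distributions before the modified KL term, so small old-task probabilities still contribute gradient, which is the standard LwF recipe.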
src/train.py
@@ -15,6 +15,8 @@ from trl import (
 from utils.trainer import ContinualTrainer
 from utils.args import ContinualScriptArguments, ContinualModelConfig
 import logging
+from typing import TYPE_CHECKING
 
 
 logging.basicConfig(level=logging.INFO)
@@ -24,15 +26,16 @@ if __name__ == "__main__":
         (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
     )
     script_args, training_args, model_args = parser.parse_args_and_config()
-    # for type hint
-    if 0 == 1:
-        script_args = ContinualScriptArguments()
-        training_args = TrainingArguments()
-        model_args = ContinualModelConfig()
     training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
     training_args.remove_unused_columns = False
     training_args.dataset_kwargs = {"skip_prepare_dataset": True}
 
+    # for type hint
+    if 1 == 0:
+        script_args = ContinualScriptArguments
+        training_args = TrainingArguments
+        model_args = ContinualModelConfig
 
     from model_library.factory import get_model
 
     model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
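
Side note: the `if 1 == 0:` block is a runtime-dead branch that nudges static analyzers into narrowing the types returned by `parser.parse_args_and_config()`. Since this commit also imports `TYPE_CHECKING`, the more idiomatic form would be something like the following sketch (not necessarily what the repo does):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so no
    # impossible condition like `if 1 == 0:` is needed.
    script_args: ContinualScriptArguments
    training_args: TrainingArguments
    model_args: ContinualModelConfig
```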
@@ -68,6 +71,8 @@ if __name__ == "__main__":
     else:
         peft_config = None
 
+    from peft import get_peft_model
+
     if accelerator.is_local_main_process:
         print(model)
@@ -31,4 +31,3 @@ class ContiunalRegularizationArguments:
     # LWF
     lwf_lambda: float = 0.0
     lwf_enable: bool = False
@@ -15,6 +15,7 @@ from transformers import (
     TrainingArguments,
 )
 from .args import ContiunalRegularizationArguments
 from peft_library.regularizations import EWC, LWF
 
 
 class ContinualTrainer(Trainer):
@@ -26,7 +27,7 @@ class ContinualTrainer(Trainer):
         train_dataset,
         eval_dataset,
         accelerator,
-        regularization_args:ContiunalRegularizationArguments=None,
+        regularization_args: ContiunalRegularizationArguments = None,
     ):
         self.accelerator = accelerator
         super().__init__(model, args, data_collator, train_dataset, eval_dataset)
@@ -38,7 +39,6 @@ class ContinualTrainer(Trainer):
         if regularization_args.lwf_enable:
             self.lwf_lambda = regularization_args.lwf_lambda
 
-
     def create_accelerator_and_postprocess(self):
 
         if self.accelerator is not None:
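
For orientation: the trainer presumably applies these regularizers inside its loss computation. A hedged sketch of how `ContinualTrainer` could wire EWC/LWF into `compute_loss` (overriding `compute_loss` is a standard `transformers.Trainer` extension point; the body here is assumed, not taken from the diff):

```python
from transformers import Trainer

class ContinualTrainerSketch(Trainer):
    # Illustrative only: one way the adapt() hooks from this commit
    # could be applied to the training loss.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        if getattr(self, "ewc", None) is not None:
            outputs = self.ewc.adapt(outputs, model=model)
        if getattr(self, "lwf", None) is not None:
            outputs = self.lwf.adapt(
                outputs, question_ids=inputs.get("question_ids")
            )
        loss = outputs["loss"]
        return (loss, outputs) if return_outputs else loss
```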