fix: repair dataset registration and loading logic, clean up code formatting for consistency

parent 1d2e7b9dcd
commit 56e46f0e0c
@@ -25,8 +25,10 @@ class RefCOCODataset(Dataset):
         self.vis_processor = vis_processor
         self.text_processor = text_processor
         self.split = split
+
     def load_data(self):
         from .format import dataset_dir
+
         ds = load_dataset("lmms-lab/RefCOCO", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore

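For context: `load_dataset` from the Hugging Face `datasets` library returns a `DatasetDict` keyed by split name, which is why `load_data` can select a split with `ds[self.split]`. A minimal sketch of the pattern (`lmms-lab/RefCOCO` is the dataset named in the diff; the cache path is illustrative):

```python
# Sketch: load_dataset yields a DatasetDict; indexing by split name yields
# a Dataset whose rows are plain dicts.
from datasets import load_dataset

ds = load_dataset("lmms-lab/RefCOCO", cache_dir="./data")  # DatasetDict
print(ds.keys())       # available splits, e.g. dict_keys(['val', 'test', ...])
sample = ds["val"][0]  # one example as a dict of fields
print(sample.keys())
```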
@@ -118,12 +120,14 @@ def test_RefCOCO():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0

+
 dataset = {
     "train": RefCOCODataset(split="val"),
     "test": RefCOCODataset(split="test"),
     "generation": RefCOCODatasetForGeneration(split="test"),
 }
 from .factory import register_dataset
+
 register_dataset(
     dataset_name="refcoco",
     dataset=dataset,
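This registration call pairs with the `register_dataset`/`MAPPING_NAME_TO_DATASET` factory module touched later in this commit. The diff never shows the factory's internals, so the following is only a hedged sketch of the registry such a call implies:

```python
# Hypothetical registry sketch; the real .factory module may differ.
from torch.utils.data import Dataset

MAPPING_NAME_TO_DATASET: dict[str, dict[str, Dataset]] = {}


def register_dataset(dataset_name: str, dataset: dict[str, Dataset], **kwargs) -> None:
    """Record a split-keyed mapping (train/test/generation) under a name."""
    MAPPING_NAME_TO_DATASET[dataset_name] = dataset


def get_dataset(dataset_name: str, base_path=None) -> dict[str, Dataset]:
    """Look up the mapping registered above."""
    return MAPPING_NAME_TO_DATASET[dataset_name]
```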
@@ -28,6 +28,7 @@ class RefCOCOplusDataset(Dataset):

     def load_data(self):
         from .format import dataset_dir
+
         ds = load_dataset("lmms-lab/RefCOCOplus", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore

@@ -119,12 +120,14 @@ def test_RefCOCOplus():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0

+
 dataset = {
     "train": RefCOCOplusDataset(split="val"),
     "test": RefCOCOplusDataset(split="testA"),
     "generation": RefCOCOplusDatasetForGeneration(split="testA"),
 }
 from .factory import register_dataset
+
 register_dataset(
     dataset_name="refcocoplus",
     dataset=dataset,
@@ -28,6 +28,7 @@ class RefCOCOgDataset(Dataset):

     def load_data(self):
         from .format import dataset_dir
+
         ds = load_dataset("lmms-lab/RefCOCOg", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore

@@ -119,12 +120,14 @@ def test_RefCOCOg():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0

+
 dataset = {
     "train": RefCOCOgDataset(split="val"),
     "test": RefCOCOgDataset(split="test"),
     "generation": RefCOCOgDatasetForGeneration(split="test"),
 }
 from .factory import register_dataset
+
 register_dataset(
     dataset_name="refcocog",
     dataset=dataset,
@@ -22,7 +22,8 @@ class ScienceQADataset(Dataset):

     def load_data(self):
         from .format import dataset_dir
-        ds = load_dataset("derek-thomas/ScienceQA",cache_dir=dataset_dir) # type: ignore
+
+        ds = load_dataset("derek-thomas/ScienceQA", cache_dir=dataset_dir)  # type: ignore
         self.data = ds[self.split]  # type: ignore

     def __len__(self):
@@ -119,12 +120,14 @@ def test_scienceQA():
     assert len(dataset) > 0
     assert len(dataset[0]["chat"]) > 0

+
 dataset = {
     "train": ScienceQADataset(split="train"),
     "test": ScienceQADataset(split="test"),
     "generation": ScienceQADatasetForGeneration(split="test"),
 }
 from .factory import register_dataset
+
 register_dataset(
     dataset_name="scienceqa",
     dataset=dataset,
@@ -1,5 +1,5 @@
 from torch.utils.data import Dataset
-from typing import Literal,List
+from typing import Literal, List
 from pathlib import Path
 from dataset_library.format import dataset_dir

@@ -58,7 +58,6 @@ from .VizWizDataset import (
 )

-

 def get_dataset(
     dataset_name: str, base_path=dataset_dir
 ) -> dict[Literal["train", "test", "generation"], Dataset]:
@@ -4,6 +4,7 @@ from .factory import get_dataset, MAPPING_NAME_TO_DATASET
 # Get all registered dataset names for parameterization
 dataset_names = list(MAPPING_NAME_TO_DATASET.keys())

+
 @pytest.mark.parametrize("dataset_name", dataset_names)
 def test_registered_datasets(dataset_name):
     dataset = get_dataset(dataset_name)
@@ -1,4 +1,6 @@
 from PIL import Image
+
+
 def size_processor(image: Image.Image):
     width, height = image.size
     if width > 500 or height > 500:
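The hunk cuts off at the size guard, so the body of `size_processor` is not shown in this diff. A hedged completion of what such a processor plausibly does (cap the long side at 500 px while preserving aspect ratio; the resampling filter is an assumption):

```python
from PIL import Image


def size_processor(image: Image.Image) -> Image.Image:
    width, height = image.size
    if width > 500 or height > 500:
        scale = 500 / max(width, height)  # assumed: shrink the long side to 500 px
        image = image.resize(
            (int(width * scale), int(height * scale)),
            Image.Resampling.LANCZOS,
        )
    return image
```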
@@ -1,10 +1,10 @@

 class RegularizationMethod:
     """RegularizationMethod implement regularization strategies.
     RegularizationMethod is a callable.
     The method `update` is called to update the loss, typically at the end
     of an experience.
     """

     def pre_adapt(self, agent, exp):
         pass  # implementation may be empty if adapt is not needed
@@ -13,3 +13,7 @@ class RegularizationMethod:

     def __call__(self, *args, **kwargs):
         raise NotImplementedError()
+
+
+from .ewc import EWC
+from .lwf import LWF
@@ -1,6 +1,7 @@
 from . import RegularizationMethod
 import torch

+
 class EWC(RegularizationMethod):
     """Learning Without Forgetting.

@@ -24,14 +25,19 @@ class EWC(RegularizationMethod):
     Knowledge distillation uses only units corresponding
     to old classes.
     """
-    def adapt(self, output,model, **kwargs):
+
+    def adapt(self, output, model, **kwargs):
         ewc_loss = 0
         for n, p in model.named_parameters():
             if p.requires_grad:
                 dev = p.device
-                l = self.EWC_lambda * self.fisher[n].to(dev) * (p.data - self.optpar[n].to(dev)).pow(2)
+                l = (
+                    self.EWC_lambda
+                    * self.fisher[n].to(dev)
+                    * (p.data - self.optpar[n].to(dev)).pow(2)
+                )
                 ewc_loss += l.sum()
-        output['loss'] += ewc_loss
+        output["loss"] += ewc_loss
         return output

     def init_epoch(self, model):
@@ -42,6 +48,7 @@ class EWC(RegularizationMethod):
             if p.requires_grad:
                 fisher[n] = torch.zeros(p.data.shape)
                 optpar[n] = p.clone().cpu().data
+
     def update_fisher(self, model):
         """Update the fisher information for the given question id."""
         for n, p in model.module.base_model.model.named_parameters():
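The reformatted penalty above is the standard EWC term: for each parameter, lambda * F * (theta - theta*)^2, where F is a diagonal Fisher information estimate and theta* the parameter values saved after the previous task. A self-contained illustration with random stand-ins for `fisher` and `optpar` (note this sketch uses `p` rather than `p.data`, so the penalty stays differentiable):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
EWC_lambda = 0.1
# Stand-ins for the per-parameter state EWC keeps: diagonal Fisher estimates
# and the previous task's optimal parameters (perturbed so the penalty is
# nonzero in this demo).
fisher = {n: torch.rand_like(p) for n, p in model.named_parameters()}
optpar = {n: p.detach().clone() + 0.05 for n, p in model.named_parameters()}

ewc_loss = torch.zeros(())
for n, p in model.named_parameters():
    if p.requires_grad:
        # lambda * F_i * (theta_i - theta*_i)^2, summed over all entries
        ewc_loss = ewc_loss + (EWC_lambda * fisher[n] * (p - optpar[n]).pow(2)).sum()

ewc_loss.backward()  # gradients pull parameters back toward optpar
print(float(ewc_loss))
```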
@@ -1,6 +1,7 @@
 from . import RegularizationMethod
 import torch

+
 class LWF(RegularizationMethod):
     """Learning Without Forgetting.

@@ -23,6 +24,7 @@ class LWF(RegularizationMethod):
     Knowledge distillation uses only units corresponding
     to old classes.
     """
+
     def adapt(self, output, **kwargs):
         def modified_kl_div(old, new):
             return -torch.mean(torch.sum(old * torch.log(new), 1))
@@ -37,16 +39,27 @@ class LWF(RegularizationMethod):

         previous_keys = self.previous_logits.keys()

-        for index, question_id in enumerate(iterable=kwargs['question_ids']):
+        for index, question_id in enumerate(iterable=kwargs["question_ids"]):
             if question_id in previous_keys:
                 previous_logits = self.previous_logits[question_id]
-                current_logits = output['logits'][index]
+                current_logits = output["logits"][index]
                 short_index = min(len(previous_logits), len(current_logits))
                 previous_logits = previous_logits[:short_index]
                 current_logits = current_logits[:short_index]
-                lwf_loss.append(modified_kl_div(old=smooth(logits=soft(previous_logits).to(current_logits.device), temp=2, dim=1),new=smooth(logits=soft(current_logits), temp=2, dim=1)))
+                lwf_loss.append(
+                    modified_kl_div(
+                        old=smooth(
+                            logits=soft(previous_logits).to(current_logits.device),
+                            temp=2,
+                            dim=1,
+                        ),
+                        new=smooth(logits=soft(current_logits), temp=2, dim=1),
+                    )
+                )
         if len(lwf_loss) > 0:
-            output['loss'] += self.LWF_lambda * torch.stack(tensors=lwf_loss, dim=0).sum(dim=0)
+            output["loss"] += self.LWF_lambda * torch.stack(
+                tensors=lwf_loss, dim=0
+            ).sum(dim=0)
         return output

     def update_previous_logits(self, question_id, logits):
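The distillation term assembled above compares cached logits from the previous task against the current model's logits. The helpers `soft` and `smooth` are not shown in this diff, so the sketch below assumes plausible definitions (softmax followed by temperature smoothing, as in common LWF implementations such as Avalanche's):

```python
import torch


def soft(logits):
    # assumed: softmax over the class/vocab dimension
    return torch.nn.functional.softmax(logits, dim=-1)


def smooth(logits, temp, dim):
    # assumed: temperature smoothing, probs ** (1/temp) renormalized
    probs = logits.pow(1 / temp)
    return probs / probs.sum(dim=dim, keepdim=True)


def modified_kl_div(old, new):
    # cross-entropy-style distillation term, exactly as in the diff
    return -torch.mean(torch.sum(old * torch.log(new), 1))


previous_logits = torch.randn(5, 10)  # cached logits from the old task
current_logits = torch.randn(5, 10)   # logits from the current model

loss = modified_kl_div(
    old=smooth(logits=soft(previous_logits), temp=2, dim=1),
    new=smooth(logits=soft(current_logits), temp=2, dim=1),
)
print(float(loss))
```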
src/train.py (15 changes)

@@ -15,6 +15,8 @@ from trl import (
 from utils.trainer import ContinualTrainer
 from utils.args import ContinualScriptArguments, ContinualModelConfig
 import logging
+from typing import TYPE_CHECKING
+

 logging.basicConfig(level=logging.INFO)

@@ -24,15 +26,16 @@ if __name__ == "__main__":
         (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
     )
     script_args, training_args, model_args = parser.parse_args_and_config()
-    # for type hint
-    if 0 == 1:
-        script_args = ContinualScriptArguments()
-        training_args = TrainingArguments()
-        model_args = ContinualModelConfig()
     training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
     training_args.remove_unused_columns = False
     training_args.dataset_kwargs = {"skip_prepare_dataset": True}

+    # for type hint
+    if 1 == 0:
+        script_args = ContinualScriptArguments
+        training_args = TrainingArguments
+        model_args = ContinualModelConfig
+
     from model_library.factory import get_model

     model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
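The new `TYPE_CHECKING` import suggests the standard pattern for type-only imports, even though the hunk above keeps the `if 1 == 0:` narrowing trick (note the rewritten block now assigns the classes themselves rather than instances, so a checker would infer `type[ContinualScriptArguments]` instead of an instance type). A sketch of the conventional alternative; the `run` function here is hypothetical:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; skipped at runtime, so these imports
    # add no startup cost and cannot create import cycles.
    from utils.args import ContinualScriptArguments


def run(script_args: "ContinualScriptArguments") -> None:  # string annotation
    ...
```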
@@ -68,6 +71,8 @@ if __name__ == "__main__":
     else:
         peft_config = None

+    from peft import get_peft_model
+
     if accelerator.is_local_main_process:
         print(model)

@@ -31,4 +31,3 @@ class ContiunalRegularizationArguments:
     # LWF
     lwf_lambda: float = 0.0
     lwf_enable: bool = False
-
@@ -15,6 +15,7 @@ from transformers import (
     TrainingArguments,
 )
 from .args import ContiunalRegularizationArguments
+from peft_library.regularizations import EWC, LWF


 class ContinualTrainer(Trainer):
@@ -26,7 +27,7 @@ class ContinualTrainer(Trainer):
         train_dataset,
         eval_dataset,
         accelerator,
-        regularization_args:ContiunalRegularizationArguments=None,
+        regularization_args: ContiunalRegularizationArguments = None,
     ):
         self.accelerator = accelerator
         super().__init__(model, args, data_collator, train_dataset, eval_dataset)
@@ -38,7 +39,6 @@ class ContinualTrainer(Trainer):
         if regularization_args.lwf_enable:
             self.lwf_lambda = regularization_args.lwf_lambda

-
     def create_accelerator_and_postprocess(self):

         if self.accelerator is not None:
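Putting the pieces of this commit together, the updated `ContinualTrainer` signature would be driven roughly as below. This is a usage sketch: the argument values are illustrative, only the parameter and class names come from the diff, and `ContiunalRegularizationArguments` keeps the repository's own spelling:

```python
from utils.trainer import ContinualTrainer
from utils.args import ContiunalRegularizationArguments

regularization_args = ContiunalRegularizationArguments(
    lwf_enable=True,  # turn on the LWF distillation term
    lwf_lambda=0.1,   # weight of the penalty added to output["loss"]
)

trainer = ContinualTrainer(
    model=model,                         # illustrative: from get_model(...)
    args=training_args,
    data_collator=collate_fn_for_train,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    accelerator=accelerator,
    regularization_args=regularization_args,
)
trainer.train()
```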