Update the PEFT library to support the MMOELORA type, adapt the training script to the new configuration, improve dataset handling, add adapter injection, and extend the PeftType enum
parent 2cd1bb4993
commit 2062f90e5d
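For orientation, the new pieces are meant to be wired together as in the updated src/train.py and launch script below: build an MMOELoraConfig for the target projection layers, then inject the adapter directly into the base model with inject_adapter_in_model instead of wrapping it with get_peft_model. A minimal sketch, assuming the default MMOELoraConfig settings suffice; TinyBlock is a hypothetical stand-in for the Qwen2-VL model that train.py actually loads.

import torch.nn as nn

from peft_library import inject_adapter_in_model  # re-exported by this commit
from peft_library.tuners import MMOELoraConfig


class TinyBlock(nn.Module):
    # Hypothetical stand-in for a transformer block; only q_proj/v_proj matter here,
    # matching --lora_target_modules q_proj v_proj in the updated launch script.
    def __init__(self, dim: int = 16):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)


model = TinyBlock()
peft_config = MMOELoraConfig(target_modules=["q_proj", "v_proj"])

# As in the new MMOELORA branch of train.py: randomly initialized MMOELoraLinear
# layers replace the targeted nn.Linear modules in place.
model = inject_adapter_in_model(peft_config, model)
print(model)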
@@ -1,4 +1,5 @@
 from transformers import Qwen2VLProcessor
+import torch
 
 
 def collate_fn_for_train(examples, processor: Qwen2VLProcessor):
@@ -51,6 +52,7 @@ def collate_fn_for_train(examples, processor: Qwen2VLProcessor):
             now_index += 1
             now_index += 1
     batch["labels"] = labels
+    # batch["task_id"] = torch.tensor([0] * len(labels), dtype=torch.long)
 
     return batch
 
@@ -17,9 +17,9 @@ class OCRVQADataset(Dataset):
         self.vis_processor = vis_processor
         self.text_processor = text_processor
         if split == "train":
-            self.data = self.create_data(ann_path, split=1)[:1]
+            self.data = self.create_data(ann_path, split=1)[:200]
         elif split == "test":
-            self.data = self.create_data(ann_path, split=3)[:1]
+            self.data = self.create_data(ann_path, split=3)[:200]
 
         # self.instruction_pool = [
         #     "[vqa] {}",
@@ -1 +1 @@
-from .mapping import get_peft_config, get_peft_model
+from .mapping import get_peft_config, get_peft_model, inject_adapter_in_model
@@ -23,7 +23,7 @@ from peft.tuners.xlora.model import XLoraModel
 
 from peft.config import PeftConfig
 from peft.mixed_model import PeftMixedModel
-from peft.peft_model import (
+from .peft_model import (
     PeftModel,
     PeftModelForCausalLM,
     PeftModelForFeatureExtraction,
@@ -280,9 +280,9 @@ def inject_adapter_in_model(
     # By instantiating a peft model we are injecting randomly initialized LoRA layers into the model's modules.
     peft_model = tuner_cls(
         model,
-        peft_config,
+        {adapter_name: peft_config},
         adapter_name=adapter_name,
         low_cpu_mem_usage=low_cpu_mem_usage,
     )
+    print("ok")
     return peft_model.model
3053 src/peft_library/peft_model.py (new file)
File diff suppressed because it is too large.
@@ -12,7 +12,9 @@ from transformers.pytorch_utils import Conv1D
 
 from peft_library.utils.peft_types import PeftType
 
-from peft_library.utils.constants import TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING
+from peft_library.utils.constants import (
+    TRANSFORMERS_MODELS_TO_MMOELORA_TARGET_MODULES_MAPPING,
+)
 from peft.utils.other import _freeze_adapter, _get_submodules, transpose
 
 
@@ -23,7 +25,6 @@ from peft.tuners.lora import (
 )
 
 
-
 def is_bnb_available():
     return importlib.util.find_spec("bitsandbytes") is not None
 
@@ -47,12 +48,17 @@ class MMOELoraModel(LoraModel):
     Create MMOELoRA (MMOE based LoRA) model from a pretrained transformers model.
     """
 
-    def __init__(self, model, config, adapter_name):
+    def __init__(self, model, config, adapter_name, **kwargs):
+        # LoraModel.__init__(self, model, config, adapter_name, **kwargs)
         nn.Module.__init__(self)
         self.model = model
         self.forward = self.model.forward
         self.peft_config = config
-        self.add_adapter(adapter_name, self.peft_config[adapter_name])
+        # self.add_adapter(adapter_name, self.peft_config[adapter_name])
 
+        import sys; print(__file__, sys._getframe().f_lineno)
+        self.add_adapter(adapter_name, config=self.peft_config[adapter_name])
+        import sys; print(__file__, sys._getframe().f_lineno)
 
     def add_adapter(self, adapter_name, config=None):
         if config is not None:  # get the lora config
@@ -64,14 +70,35 @@ class MMOELoraModel(LoraModel):
         config = self._prepare_mmoelora_config(config, model_config)  # load config
         self.peft_config[adapter_name] = config  # subsititue the original config
         self._find_and_replace(adapter_name)
 
 
         if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none":
             raise ValueError(
                 "MMOELoraModel supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters."
             )
+        print(self.peft_config)
+        self.mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
 
-        self._mark_only_adapters_as_trainable(self.model)
         if self.peft_config[adapter_name].inference_mode:
             _freeze_adapter(self.model, adapter_name)
 
+    def mark_only_lora_as_trainable(self, model: nn.Module, bias: str = "none") -> None:
+        """Only activate the LoRA layer as trainable"""
+        for n, p in model.named_parameters():
+            if "lora_" not in n:
+                p.requires_grad = False
+        if bias == "none":
+            return
+        elif bias == "all":
+            for n, p in model.named_parameters():
+                if "bias" in n:
+                    p.requires_grad = True
+        elif bias == "lora_only":
+            for m in model.modules():
+                if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None:
+                    m.bias.requires_grad = True
+        else:
+            raise NotImplementedError
+
     def _find_and_replace(self, adapter_name):
         """Replace the target `Linear` module with LoRA layer (Linear+LoRA)"""
@@ -106,8 +133,10 @@ class MMOELoraModel(LoraModel):
                 for target_key in lora_config.target_modules
             )
             if target_module_found:
+
                 if not is_target_modules_in_base_model:
                     is_target_modules_in_base_model = True
+
                 parent, target, target_name = _get_submodules(self.model, key)
                 bias = target.bias is not None
                 if isinstance(target, MMOELoraLayer):
@@ -122,6 +151,8 @@ class MMOELoraModel(LoraModel):
                 if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
                     raise NotImplementedError
                 else:
+                    # debug print
+
                     if isinstance(target, torch.nn.Linear):
                         in_features, out_features = (
                             target.in_features,
@@ -154,11 +185,18 @@ class MMOELoraModel(LoraModel):
                             f"Target module {target} is not supported. "
                             f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
                         )
 
                     new_module = MMOELoraLinear(
-                        adapter_name, in_features, out_features, bias=bias, **kwargs
+                        adapter_name,
+                        in_features,
+                        out_features,
+                        bias=bias,
+                        base_layer=target,
+                        **kwargs,
                     )
 
                 self._replace_module(parent, target_name, new_module, target)
 
         if not is_target_modules_in_base_model:
             raise ValueError(
                 f"Target modules {lora_config.target_modules} not found in the base model. "
@@ -192,9 +230,16 @@ class MMOELoraModel(LoraModel):
 
 class MMOELoraLayer(LoraLayer):
 
-    def __init__(self, in_features: int, out_features: int, expert_num: int):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        expert_num: int,
+        base_layer: nn.Linear = None,
+    ):
+        super().__init__(base_layer=base_layer)
 
-        super().__init__(in_features, out_features)
+        self.in_features, self.out_features = in_features, out_features
         self.expert_num = expert_num
 
     def update_layer(
@@ -235,7 +280,7 @@ class MMOELoraLayer(LoraLayer):
                 nn.init.zeros_(self.lora_B[adapter_name].loraB[i].mlp.weight)
 
 
-class MMOELoraLinear(nn.Linear, MMOELoraLayer):
+class MMOELoraLinear(nn.Module, MMOELoraLayer):
     # Lora implemented in a dense layer
     # nn.Linear is the pretrained weights in LLM, MMOELoraLayer is the designed trainable Lora
     def __init__(
@@ -243,25 +288,28 @@ class MMOELoraLinear(nn.Linear, MMOELoraLayer):
         adapter_name: str,
         in_features: int,
         out_features: int,
+        base_layer: nn.Linear = None,
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
         fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
         **kwargs,
     ):
+        nn.Module.__init__(self)
         init_lora_weights = kwargs.pop("init_lora_weights", True)
-        self.expert_num = kwargs.pop("expert_num", True)
         self.task_num = kwargs.pop("task_num", True)
         self.te_dim = kwargs.pop("task_embedding_dim", True)
 
-        nn.Linear.__init__(self, in_features, out_features, **kwargs)
         MMOELoraLayer.__init__(
             self,
             in_features=in_features,
             out_features=out_features,
-            expert_num=self.expert_num,
+            expert_num=kwargs.pop("expert_num", 2),
+            base_layer=base_layer,
         )
 
+        # nn.Linear.__init__(self, in_features, out_features, **kwargs)
 
         # init the Gate network
         self.lora_task_embedding = nn.ModuleDict({})
         self.lora_gate = nn.ModuleDict({})
@@ -279,100 +327,90 @@ class MMOELoraLinear(nn.Linear, MMOELoraLayer):
         if fan_in_fan_out:
             self.weight.data = self.weight.data.T
 
-        nn.Linear.reset_parameters(self)
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
-        self.active_adapter = adapter_name
+        self._active_adapter = adapter_name
 
     def merge(self, task_id):
-        if self.active_adapter not in self.lora_A.keys():
+        if self._active_adapter not in self.lora_A.keys():
             return
         if self.merged:
             warnings.warn("Already merged. Nothing to do.")
             return
-        if self.r[self.active_adapter] > 0:
-            expert_weight = self.lora_gate[self.active_adapter](
-                self.lora_task_embedding[self.active_adapter](task_id)
+        if self.r[self._active_adapter] > 0:
+            expert_weight = self.lora_gate[self._active_adapter](
+                self.lora_task_embedding[self._active_adapter](task_id)
             )
             for i in range(self.expert_num):
-                lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight
-                lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight
-                self.weight.data += (
+                lora_A_weights = self.lora_A[self._active_adapter].loraA[i].mlp.weight
+                lora_B_weights = self.lora_B[self._active_adapter].loraB[i].mlp.weight
+                self.base_layer.weight.data += (
                     transpose(
                         lora_B_weights @ lora_A_weights,
                         self.fan_in_fan_out,
                     )
-                    * self.scaling[self.active_adapter]
+                    * self.scaling[self._active_adapter]
                     * expert_weight[..., i]
                 )
             self.merged = True
 
     def unmerge(self, task_id):
-        if self.active_adapter not in self.lora_A.keys():
+        if self._active_adapter not in self.lora_A.keys():
             return
         if not self.merged:
             warnings.warn("Already unmerged. Nothing to do.")
             return
-        if self.r[self.active_adapter] > 0:
-            expert_weight = self.lora_gate[self.active_adapter](
-                self.lora_task_embedding[self.active_adapter](task_id)
+        if self.r[self._active_adapter] > 0:
+            expert_weight = self.lora_gate[self._active_adapter](
+                self.lora_task_embedding[self._active_adapter](task_id)
             )
             for i in range(self.expert_num):
-                lora_A_weights = self.lora_A[self.active_adapter].loraA[i].mlp.weight
-                lora_B_weights = self.lora_B[self.active_adapter].loraB[i].mlp.weight
-                self.weight.data -= (
+                lora_A_weights = self.lora_A[self._active_adapter].loraA[i].mlp.weight
+                lora_B_weights = self.lora_B[self._active_adapter].loraB[i].mlp.weight
+                self.base_layer.weight.data -= (
                     transpose(
                         lora_B_weights @ lora_A_weights,
                         self.fan_in_fan_out,
                     )
-                    * self.scaling[self.active_adapter]
+                    * self.scaling[self._active_adapter]
                     * expert_weight[..., i]
                 )
             self.merged = False
 
     def forward(self, x: torch.Tensor, **kwargs):
-        task_id = kwargs["task_id"]
+        # task_id = kwargs["task_id"]
+        for k,v in kwargs.items():
+            print(k, v.shape)
+        task_id = torch.tensor([0] * len(x), dtype=torch.long).to(x.device)
         previous_dtype = x.dtype
 
         if (
-            self.active_adapter not in self.lora_A.keys()
+            self._active_adapter not in self.lora_A.keys()
         ):  # No adapter, directly use linear
-            return F.linear(
-                x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
-            )
+            return self.base_layer(x)
         if self.disable_adapters:  # No adapter
             if (
-                self.r[self.active_adapter] > 0 and self.merged
+                self.r[self._active_adapter] > 0 and self.merged
             ):  # merge the adapter to linear
                 self.unmerge(task_id)
-            result = F.linear(
-                x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
-            )
+            self.base_layer.weight.data = self.weight.data
         elif (
-            self.r[self.active_adapter] > 0 and not self.merged
+            self.r[self._active_adapter] > 0 and not self.merged
         ):  # general lora process
-            result = F.linear(
-                x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
-            )
+            result = self.base_layer(x)
 
-            x = x.to(self.lora_A[self.active_adapter].loraA[0].weight.dtype)
+            x = x.to(self.lora_A[self._active_adapter].loraA[0].weight.dtype)
 
-            expert_weight = self.lora_gate[self.active_adapter](
-                self.lora_task_embedding[self.active_adapter](task_id)
+            expert_weight = self.lora_gate[self._active_adapter](
+                self.lora_task_embedding[self._active_adapter](task_id)
             )
             for i in range(self.expert_num):
-                result += (  # lora process
-                    self.lora_B[self.active_adapter].loraB[i](
-                        self.lora_A[self.active_adapter].loraA[i](
-                            self.lora_dropout[self.active_adapter](x)
-                        ),
-                    )
-                    * self.scaling[self.active_adapter]
-                    * expert_weight[..., i].unsqueeze(-1).unsqueeze(0)
+                result += (
+                    self.lora_B[self._active_adapter].loraB[i](self.lora_A[self._active_adapter].loraA[i](self.lora_dropout[self._active_adapter](x)))
+                    * self.scaling[self._active_adapter]
+                    * expert_weight[..., i].view(-1, 1, 1)
                 )
         else:
-            result = F.linear(
-                x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias
-            )
+            result = self.base_layer(x)
 
         result = result.to(previous_dtype)
 
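As a reading aid for the MMOELoraLinear.forward rewrite above, the following standalone sketch reproduces the expert-gated LoRA combination it implements, using plain nn.Linear experts. The softmax gate, embedding sizes, and expert count are illustrative assumptions, not the repository's actual defaults; in the real layer the gate input is a task embedding and task_id now defaults to zeros, as in the diff.

import torch
import torch.nn as nn

# Illustrative sizes; the real values come from MMOELoraConfig (r, expert_num, task_embedding_dim, ...).
in_features, out_features, r, expert_num, task_num, te_dim = 16, 16, 4, 2, 3, 8

base_layer = nn.Linear(in_features, out_features)  # stands in for the frozen pretrained weight
lora_A = nn.ModuleList(nn.Linear(in_features, r, bias=False) for _ in range(expert_num))
lora_B = nn.ModuleList(nn.Linear(r, out_features, bias=False) for _ in range(expert_num))
task_embedding = nn.Embedding(task_num, te_dim)
gate = nn.Sequential(nn.Linear(te_dim, expert_num), nn.Softmax(dim=-1))  # assumed gate shape
scaling = 1.0

x = torch.randn(4, 10, in_features)              # (batch, seq, hidden)
task_id = torch.zeros(len(x), dtype=torch.long)   # forward() now defaults every sample to task 0

result = base_layer(x)
expert_weight = gate(task_embedding(task_id))      # (batch, expert_num)
for i in range(expert_num):
    # each expert contributes B_i(A_i(x)), weighted per sample by its gate probability
    result = result + lora_B[i](lora_A[i](x)) * scaling * expert_weight[..., i].view(-1, 1, 1)

print(result.shape)  # torch.Size([4, 10, 16])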
@@ -26,13 +26,29 @@ from .other import CONFIG_NAME
 
 class PeftType(str, enum.Enum):
     PROMPT_TUNING = "PROMPT_TUNING"
+    MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
     P_TUNING = "P_TUNING"
     PREFIX_TUNING = "PREFIX_TUNING"
     LORA = "LORA"
     ADALORA = "ADALORA"
+    BOFT = "BOFT"
     ADAPTION_PROMPT = "ADAPTION_PROMPT"
+    IA3 = "IA3"
+    LOHA = "LOHA"
+    LOKR = "LOKR"
+    OFT = "OFT"
+    POLY = "POLY"
+    LN_TUNING = "LN_TUNING"
+    VERA = "VERA"
+    FOURIERFT = "FOURIERFT"
+    XLORA = "XLORA"
+    HRA = "HRA"
+    VBLORA = "VBLORA"
+    CPT = "CPT"
+    BONE = "BONE"
 
     MMOELORAS = "MMOELORAS"
+    MMOELORA = "MMOELORA"
 
 class TaskType(str, enum.Enum):
     SEQ_CLS = "SEQ_CLS"
31 src/train.py
@@ -1,6 +1,6 @@
 import torch
 from dataset_library.factory import get_dataset
-from transformers import AutoModelForVision2Seq, AutoProcessor, TrainingArguments
+from transformers import AutoModelForVision2Seq, AutoProcessor, TrainingArguments, Qwen2VLForConditionalGeneration
 
 from trl import (
     ModelConfig,
@@ -9,7 +9,7 @@ from trl import (
     # get_peft_config,
     get_quantization_config,
 )
-from peft_library import get_peft_model, get_peft_config
+from peft_library import get_peft_model, get_peft_config, inject_adapter_in_model
 
 from utils.trainer import ContinualTrainer
 from utils.args import ContinualScriptArguments, ContinualModelConfig
@@ -30,16 +30,6 @@ if __name__ == "__main__":
     training_args.dataset_kwargs = {"skip_prepare_dataset": True}
 
     # peft_config = get_peft_config(dict(**vars(model_args)))
-    if model_args.peft_type == "MMOELora":
-        from peft_library.tuners import MMOELoraConfig
-
-        peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
-    elif model_args.peft_type == "LORA":
-        from peft.tuners.lora import LoraConfig
-
-        peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
-    else:
-        peft_config = None
 
     torch_dtype = (
         model_args.torch_dtype
@@ -84,11 +74,26 @@ if __name__ == "__main__":
 
     accelerator = create_accelerator_and_postprocess(training_args)
 
-    if peft_config is not None:
+    if model_args.peft_type == "MMOELORA":
+        from peft_library.tuners import MMOELoraConfig
+
+        peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
+
+        # model = get_peft_model(model, peft_config)
+        model = inject_adapter_in_model(peft_config, model)
+        print(model)
+    elif model_args.peft_type == "LORA":
+        from peft.tuners.lora import LoraConfig
+
+        peft_config = LoraConfig(target_modules=model_args.lora_target_modules)
+
         model = get_peft_model(model, peft_config)
         if accelerator.is_local_main_process:
             model.print_trainable_parameters()
 
+    else:
+        peft_config = None
+
     for dataset_name in script_args.dataset_name:
         dataset = get_dataset(dataset_name)
         model.train()
@@ -3,7 +3,7 @@
 accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
     --dataset_name OCR_VQA_200K OCR_VQA_200K OCR_VQA_200K \
     --use_peft \
-    --peft_type LORA \
+    --peft_type MMOELORA \
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
     --lora_target_modules q_proj v_proj \
    --per_device_train_batch_size 1 \