Refactor code structure for improved readability and maintainability
This commit is contained in:
+2
-1
@@ -1 +1,2 @@
|
||||
checkpoint/*
|
||||
checkpoint/*
|
||||
wandb/*
|
||||
@@ -2,7 +2,7 @@ compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_multinode_launcher: standard
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_accumulation_steps: 4
|
||||
zero3_init_flag: false
|
||||
zero_stage: 1
|
||||
distributed_type: DEEPSPEED
|
||||
@@ -11,7 +11,7 @@ machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'bf16'
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
num_processes: 4
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
|
||||
@@ -12,7 +12,7 @@ machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'bf16'
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
num_processes: 4
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
|
||||
+33
-1
@@ -38,12 +38,44 @@ if __name__ == "__main__":
|
||||
|
||||
from model_library.factory import get_model
|
||||
|
||||
if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
|
||||
if model_args.model_name_or_path == "Qwen/Qwen2.5-VL-3B-Instruct":
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
training_args.output_dir,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
padding_side="left",
|
||||
)
|
||||
|
||||
from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
|
||||
from functools import partial
|
||||
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
elif model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
|
||||
@@ -68,4 +68,25 @@ def get_model(model_args: ContinualModelConfig):
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
if model_args.model_name_or_path == "Qwen/Qwen2.5-VL-3B-Instruct":
|
||||
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
processor = Qwen2_5_VLProcessor.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
padding_side="left",
|
||||
)
|
||||
|
||||
from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
|
||||
from functools import partial
|
||||
|
||||
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
|
||||
collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
|
||||
|
||||
return model, processor, collate_fn_for_train, collate_fn_for_evaluate
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from .collate_fn import collate_fn_for_evaluate, collate_fn_for_train
|
||||
from .model import Qwen2VLForConditionalGeneration_modified
|
||||
|
||||
# from .model import Qwen2VLForConditionalGeneration_modified
|
||||
|
||||
__all__ = [
|
||||
"collate_fn_for_train",
|
||||
|
||||
@@ -73,7 +73,6 @@ from peft.tuners import (
|
||||
from .tuners import MMOELoraModel, MMOELoraConfig
|
||||
from peft.tuners.tuners_utils import BaseTuner
|
||||
from peft.utils import _prepare_prompt_learning_config
|
||||
from peft.utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -46,7 +46,7 @@ from transformers.modeling_outputs import (
|
||||
)
|
||||
from transformers.utils import PushToHubMixin
|
||||
|
||||
from peft.utils.constants import DUMMY_MODEL_CONFIG, PEFT_TYPE_TO_PREFIX_MAPPING
|
||||
from peft.utils.constants import DUMMY_MODEL_CONFIG
|
||||
|
||||
from peft import __version__
|
||||
from peft.config import PeftConfig
|
||||
|
||||
+1
-1
Submodule src/peft_repo updated: 65c3c43cd1...83111347f3
+5
-5
@@ -1,15 +1,15 @@
|
||||
#!/bin/bash
|
||||
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
|
||||
--dataset_name chem \
|
||||
accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml train.py \
|
||||
--dataset_name refcoco \
|
||||
--use_peft \
|
||||
--peft_type LORA \
|
||||
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
|
||||
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--lora_target_modules .\*proj.\*\|.\*fc.\*\|.\*mlp\.0\|.\*mlp\.2 \
|
||||
--lora_r 8 \
|
||||
--lora_alpha 32 \
|
||||
--per_device_train_batch_size 2 \
|
||||
--per_device_eval_batch_size 2 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--per_device_eval_batch_size 1 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--output_dir checkpoint/qwen2_alllinear/ \
|
||||
--learning_rate 1e-4 \
|
||||
|
||||
+1
-1
Submodule src/transformers_repo updated: 7961d291b3...684f12be1c
+5
-14
@@ -1,15 +1,6 @@
|
||||
# _________________________________________________________
|
||||
|
||||
|
||||
from transformers.trainer import (
|
||||
Trainer,
|
||||
_is_peft_model,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||
tpu_spmd_dataloader,
|
||||
logger,
|
||||
has_length,
|
||||
sys,
|
||||
)
|
||||
from transformers.trainer import *
|
||||
from transformers import (
|
||||
TrainingArguments,
|
||||
@@ -32,12 +23,12 @@ class ContinualTrainer(Trainer):
|
||||
self.accelerator = accelerator
|
||||
super().__init__(model, args, data_collator, train_dataset, eval_dataset)
|
||||
|
||||
if regularization_args.ewc_enable:
|
||||
self.ewc_lambda = regularization_args.ewc_lambda
|
||||
# fisher = t
|
||||
# if regularization_args.ewc_enable:
|
||||
# self.ewc_lambda = regularization_args.ewc_lambda
|
||||
# # fisher = t
|
||||
|
||||
if regularization_args.lwf_enable:
|
||||
self.lwf_lambda = regularization_args.lwf_lambda
|
||||
# if regularization_args.lwf_enable:
|
||||
# self.lwf_lambda = regularization_args.lwf_lambda
|
||||
|
||||
def create_accelerator_and_postprocess(self):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user