feat✨: add submodule support, update dataset processing logic, and optimize the training and evaluation scripts
commit 68c3e053fb (parent 44b06ea5db)
.gitmodules (vendored, new file, +6)
@@ -0,0 +1,6 @@
+[submodule "src/transformers_repo"]
+	path = src/transformers_repo
+	url = git@github.com:Shujakuinkuraudo/transformers.git
+[submodule "src/peft_repo"]
+	path = src/peft_repo
+	url = git@github.com:Shujakuinkuraudo/peft.git
src/.vscode/settings.json (vendored, new file, +6)
@@ -0,0 +1,6 @@
+{
+    "python.analysis.extraPaths": [
+        "transformers_repo/src/",
+        "peft_repo/src/"
+    ]
+}
@@ -25,9 +25,9 @@ class CHEMDataset(Dataset):
 
     def _vis_processor(self, image: Image.Image):
         width, height = image.size
-        if width > 800 or height > 800:
+        if width > 600 or height > 600:
             max_size = max(width, height)
-            ratio = 800 / max_size
+            ratio = 600 / max_size
             new_width = int(width * ratio)
             new_height = int(height * ratio)
             image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
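This lowers the cap on the longer image edge from 800 px to 600 px, which reduces per-image vision tokens and memory. A standalone sketch of the same resize rule (reconstructed from the context lines above; only the logic shown in the hunk is taken as given, the function name here is illustrative):

```python
from PIL import Image

def downscale_to_max_edge(image: Image.Image, max_edge: int = 600) -> Image.Image:
    # Resize only when either edge exceeds the cap, preserving aspect ratio.
    width, height = image.size
    if width > max_edge or height > max_edge:
        ratio = max_edge / max(width, height)
        image = image.resize(
            (int(width * ratio), int(height * ratio)),
            Image.Resampling.BILINEAR,
        )
    return image

# A 1200x900 input becomes 600x450; a 500x400 input passes through unchanged.
```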
@@ -69,7 +69,7 @@ class CHEMDataset(Dataset):
         return processed_data
 
     def __len__(self):
-        return len(self.data)
+        return len(self.data) // 60
 
     def __getitem__(self, index):
         sample = self.data[index]
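Because a DataLoader only draws indices in range(len(dataset)), returning len(self.data) // 60 silently restricts training to the first 1/60th of the samples; the rest become unreachable. This reads like a temporary smoke-test knob. A hypothetical, more explicit alternative (not part of this commit) that samples across the whole set instead of only the leading block:

```python
from torch.utils.data import Subset

def smoke_subset(dataset, stride: int = 60):
    """Keep every `stride`-th sample rather than only the first len//stride."""
    return Subset(dataset, range(0, len(dataset), stride))
```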
@@ -1,3 +1,7 @@
+import sys
+sys.path.insert(0, "./transformers_repo/src/")
+sys.path.insert(0, "./peft_repo/src/")
+
 import torch
 from dataset_library.factory import get_dataset
 from transformers import (
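Inserting the vendored checkouts at position 0 makes them shadow any pip-installed transformers or peft in the same environment. A quick sanity check, a sketch assuming the script runs with src/ as the working directory (the relative paths require it):

```python
import sys

# Must run before the first `import transformers` in the process;
# position 0 makes the vendored sources win over site-packages.
sys.path.insert(0, "./transformers_repo/src/")
sys.path.insert(0, "./peft_repo/src/")

import transformers
import peft

# Both paths should point inside the submodules, not site-packages.
print(transformers.__file__, transformers.__version__)
print(peft.__file__)
```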
@@ -9,7 +9,7 @@ accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 2 \
     --gradient_accumulation_steps 8 \
-    --output_dir checkpoint/qwen2/ \
+    --output_dir checkpoint/qwen2mmoe/ \
     --bf16 \
     --torch_dtype bfloat16
 # --eval_strategy epoch \
@@ -5,6 +5,8 @@ from trl import (
     get_quantization_config,
 )
 from utils.args import ContinualModelConfig
+import transformers
+print(transformers.__version__)
 
 
 def get_model(model_args: ContinualModelConfig):
src/peft_repo (new submodule, +1)
@@ -0,0 +1 @@
+Subproject commit de88c703065fdd3a05521da1054c0463d16ea33c
src/train.py (12 changes)
@@ -1,5 +1,8 @@
+import sys
+sys.path.insert(0, "./transformers_repo/src/")
+sys.path.insert(0, "./peft_repo/src/")
 from dataset_library.factory import get_dataset
 
 from transformers import (
     TrainingArguments,
@@ -8,8 +11,6 @@ from transformers import (
 from trl import (
     TrlParser,
 )
-from peft_library import get_peft_model
-
 from utils.trainer import ContinualTrainer
 from utils.args import ContinualScriptArguments, ContinualModelConfig
 import logging
@@ -49,10 +50,9 @@ if __name__ == "__main__":
 
         peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
 
-        model = get_peft_model(model, peft_config)
+        # model = get_peft_model(model, peft_config)
+        model.add_adapter(peft_config)
 
-        if accelerator.is_local_main_process:
-            model.print_trainable_parameters()
-
     elif model_args.peft_type == "LORA":
         from peft.tuners.lora import LoraConfig
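The attachment style changes here: get_peft_model wraps the model in a PeftModel, while add_adapter (the transformers PEFT integration) injects the adapter into the existing model in place and keeps its original class, which also explains dropping print_trainable_parameters, a PeftModel method. A minimal sketch of both styles with stock peft; MMOELoraConfig comes from the vendored fork, so plain LoraConfig and a small text model stand in here:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

config = LoraConfig(target_modules=["q_proj", "v_proj"])

# Old path: wrap into a PeftModel (which has print_trainable_parameters).
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
wrapped = get_peft_model(model, config)
wrapped.print_trainable_parameters()

# New path: attach in place via transformers' PEFT integration; the
# object stays an AutoModelForCausalLM, not a PeftModel.
model2 = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
model2.add_adapter(config)
```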
@@ -3,13 +3,13 @@
 accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
     --dataset_name CHEM \
     --use_peft \
-    --peft_type LORA \
+    --peft_type MMOELORA \
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
     --lora_target_modules q_proj v_proj \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 2 \
     --gradient_accumulation_steps 4 \
-    --output_dir checkpoint/qwen2/ \
+    --output_dir checkpoint/qwen2mmoe/ \
     --bf16 \
     --torch_dtype bfloat16 \
     --logging_steps 30
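For reference, the effective optimization batch size under these flags scales with the number of processes; a worked line, with the GPU count as an assumption since it lives in the accelerate/deepspeed config rather than this script:

```python
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
num_processes = 8  # assumption: set by the accelerate config, not shown here

effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_processes
print(effective_batch)  # 32
```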
src/transformers_repo (new submodule, +1)
@@ -0,0 +1 @@
+Subproject commit 6bc0fbcfa7acb6ac4937e7456a76c2f7975fefec