feat: add submodule support, update dataset handling logic, and improve the training and evaluation scripts

Author: YunyaoZhou, 2025-01-15 21:56:20 +08:00
Parent: 44b06ea5db
Commit: 68c3e053fb
Signed by: shujakuin (GPG Key ID: 418C3CA28E350CCF)

10 changed files with 32 additions and 12 deletions

.gitmodules (vendored, new file, +6)

@@ -0,0 +1,6 @@
+[submodule "src/transformers_repo"]
+    path = src/transformers_repo
+    url = git@github.com:Shujakuinkuraudo/transformers.git
+[submodule "src/peft_repo"]
+    path = src/peft_repo
+    url = git@github.com:Shujakuinkuraudo/peft.git
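
Both submodules point at personal forks fetched over SSH, so a GitHub SSH key is needed to clone them, and they must be initialized after cloning before the sys.path tweaks further down can resolve anything. A minimal sketch, wrapping the standard git invocation in Python to match the rest of the repo's tooling:

import subprocess

# Fetch the pinned submodule checkouts (equivalent to running
# `git submodule update --init --recursive` in the repo root).
subprocess.run(
    ["git", "submodule", "update", "--init", "--recursive"],
    check=True,
)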

src/.vscode/settings.json (vendored, new file, +6)

@@ -0,0 +1,6 @@
+{
+    "python.analysis.extraPaths": [
+        "transformers_repo/src/",
+        "peft_repo/src/"
+    ]
+}
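
Note: python.analysis.extraPaths only affects import resolution in the editor (Pylance); it changes nothing at runtime, which is why the training and evaluation scripts below prepend the same directories to sys.path.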


@@ -25,9 +25,9 @@ class CHEMDataset(Dataset):
     def _vis_processor(self, image: Image.Image):
         width, height = image.size
-        if width > 800 or height > 800:
+        if width > 600 or height > 600:
             max_size = max(width, height)
-            ratio = 800 / max_size
+            ratio = 600 / max_size
             new_width = int(width * ratio)
             new_height = int(height * ratio)
             image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)

@@ -69,7 +69,7 @@ class CHEMDataset(Dataset):
         return processed_data

     def __len__(self):
-        return len(self.data)
+        return len(self.data) // 60

     def __getitem__(self, index):
         sample = self.data[index]
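
Two behavioral changes land in this file: the longest image side is now capped at 600px rather than 800px, and __len__ reports only a sixtieth of the dataset (integer division), which looks like a temporary subsample for faster iteration and is worth flagging before a full training run. A standalone sketch of the new resizing rule, assuming the context lines above are representative (the helper name is mine):

from PIL import Image

def cap_longest_side(image: Image.Image, max_side: int = 600) -> Image.Image:
    # Downscale so the longest side is at most max_side, preserving the
    # aspect ratio; images already within bounds pass through unchanged.
    width, height = image.size
    if width > max_side or height > max_side:
        ratio = max_side / max(width, height)
        image = image.resize(
            (int(width * ratio), int(height * ratio)),
            Image.Resampling.BILINEAR,
        )
    return image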


@@ -1,3 +1,7 @@
+import sys
+sys.path.insert(0, "./transformers_repo/src/")
+sys.path.insert(0, "./peft_repo/src/")
 import torch
 from dataset_library.factory import get_dataset
 from transformers import (
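
Inserting at index 0 puts the vendored forks ahead of any pip-installed transformers or peft, so the forks shadow the site-packages copies for every subsequent import. A quick hedged check that the shadowing actually took effect (paths assume the script is launched from src/):

import sys

sys.path.insert(0, "./transformers_repo/src/")
sys.path.insert(0, "./peft_repo/src/")

import transformers

# Should print a path inside transformers_repo/src/ rather than
# site-packages; __file__ is more conclusive than the __version__
# print added in the model utils below.
print(transformers.__file__)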


@@ -9,7 +9,7 @@ accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 2 \
     --gradient_accumulation_steps 8 \
-    --output_dir checkpoint/qwen2/ \
+    --output_dir checkpoint/qwen2mmoe/ \
     --bf16 \
     --torch_dtype bfloat16
 # --eval_strategy epoch \


@@ -5,6 +5,8 @@ from trl import (
     get_quantization_config,
 )
 from utils.args import ContinualModelConfig
+import transformers
+print(transformers.__version__)

 def get_model(model_args: ContinualModelConfig):

src/peft_repo (new submodule, +1)

@@ -0,0 +1 @@
+Subproject commit de88c703065fdd3a05521da1054c0463d16ea33c
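
The gitlink pins the fork to one exact commit; checking out anything else inside src/peft_repo silently diverges from what the superproject records. A small sketch to verify the working tree matches the pin:

import subprocess

# Compare the submodule's checked-out HEAD against the commit recorded above.
head = subprocess.run(
    ["git", "-C", "src/peft_repo", "rev-parse", "HEAD"],
    capture_output=True, text=True, check=True,
).stdout.strip()
assert head == "de88c703065fdd3a05521da1054c0463d16ea33c", head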


@@ -1,5 +1,8 @@
+import sys
+sys.path.insert(0, "./transformers_repo/src/")
+sys.path.insert(0, "./peft_repo/src/")
 from dataset_library.factory import get_dataset
 from transformers import (
     TrainingArguments,

@@ -8,8 +11,6 @@ from transformers import (
 from trl import (
     TrlParser,
 )
-from peft_library import get_peft_model
 from utils.trainer import ContinualTrainer
 from utils.args import ContinualScriptArguments, ContinualModelConfig
 import logging

@@ -49,10 +50,9 @@ if __name__ == "__main__":
         peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
-        model = get_peft_model(model, peft_config)
+        # model = get_peft_model(model, peft_config)
+        model.add_adapter(peft_config)
+        if accelerator.is_local_main_process:
+            model.print_trainable_parameters()

     elif model_args.peft_type == "LORA":
         from peft.tuners.lora import LoraConfig
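
The MMOELORA branch now attaches the adapter with model.add_adapter(peft_config) instead of wrapping the model via get_peft_model, so the object keeps its original class; print_trainable_parameters here presumably comes from the forked transformers/peft integration in the submodules. A fork-independent sketch that reports the same numbers (the helper name is mine):

def count_trainable_parameters(model) -> str:
    # Mirrors what peft's print_trainable_parameters reports: parameters
    # with requires_grad=True versus the total parameter count.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return (
        f"trainable params: {trainable:,} || all params: {total:,} "
        f"|| trainable%: {100 * trainable / total:.4f}"
    )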


@@ -3,13 +3,13 @@
 accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
     --dataset_name CHEM \
     --use_peft \
-    --peft_type LORA \
+    --peft_type MMOELORA \
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
     --lora_target_modules q_proj v_proj \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 2 \
     --gradient_accumulation_steps 4 \
-    --output_dir checkpoint/qwen2/ \
+    --output_dir checkpoint/qwen2mmoe/ \
     --bf16 \
     --torch_dtype bfloat16 \
     --logging_steps 30
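
With --peft_type MMOELORA this script now drives the MMOELoraConfig branch shown in the train.py hunk above, and the renamed output directory keeps those checkpoints separate from the earlier LoRA runs in checkpoint/qwen2/.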

src/transformers_repo (new submodule, +1)

@@ -0,0 +1 @@
+Subproject commit 6bc0fbcfa7acb6ac4937e7456a76c2f7975fefec