feat✨: add submodule support, update dataset processing logic, and improve the training and evaluation scripts
parent 44b06ea5db
commit 68c3e053fb
.gitmodules (vendored, new file, +6)
@@ -0,0 +1,6 @@
+[submodule "src/transformers_repo"]
+    path = src/transformers_repo
+    url = git@github.com:Shujakuinkuraudo/transformers.git
+[submodule "src/peft_repo"]
+    path = src/peft_repo
+    url = git@github.com:Shujakuinkuraudo/peft.git
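Both submodules point at SSH URLs, so a fresh clone needs GitHub SSH access plus `git submodule update --init --recursive` (or `git clone --recurse-submodules`) before the sys.path inserts below can resolve the vendored sources.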
src/.vscode/settings.json (vendored, new file, +6)
@@ -0,0 +1,6 @@
+{
+    "python.analysis.extraPaths": [
+        "transformers_repo/src/",
+        "peft_repo/src/"
+    ]
+}
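This mirrors the runtime sys.path changes for the editor: python.analysis.extraPaths points the Pylance/Pyright language server at the vendored submodule sources, so transformers and peft imports resolve to the forks rather than to any installed site-packages copies.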
@@ -25,9 +25,9 @@ class CHEMDataset(Dataset):
 
     def _vis_processor(self, image: Image.Image):
         width, height = image.size
-        if width > 800 or height > 800:
+        if width > 600 or height > 600:
             max_size = max(width, height)
-            ratio = 800 / max_size
+            ratio = 600 / max_size
             new_width = int(width * ratio)
             new_height = int(height * ratio)
             image = image.resize((new_width, new_height), Image.Resampling.BILINEAR)
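The cap on the longest image side drops from 800 px to 600 px; aspect ratio is preserved, and images already within the cap pass through untouched. A minimal standalone sketch of the new behavior (the helper name and module-level constant are illustrative, not part of the commit):

from PIL import Image

MAX_SIDE = 600  # this commit lowers the cap from 800

def downscale(image: Image.Image) -> Image.Image:
    # Resize only when the longest side exceeds the cap, keeping aspect ratio.
    width, height = image.size
    if width > MAX_SIDE or height > MAX_SIDE:
        ratio = MAX_SIDE / max(width, height)
        image = image.resize(
            (int(width * ratio), int(height * ratio)),
            Image.Resampling.BILINEAR,
        )
    return image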
@@ -69,7 +69,7 @@ class CHEMDataset(Dataset):
         return processed_data
 
     def __len__(self):
-        return len(self.data)
+        return len(self.data) // 60
 
     def __getitem__(self, index):
         sample = self.data[index]
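With the floor division in place, the dataset reports only 1/60 of its true length, so a training epoch visits just that prefix of the data (for 6,000 samples, len(dataset) == 100). __getitem__ is untouched, so the remaining samples are simply never requested by the dataloader. This reads like a temporary debugging shortcut and is worth reverting before a full run.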
@@ -1,3 +1,7 @@
+import sys
+sys.path.insert(0, "./transformers_repo/src/")
+sys.path.insert(0, "./peft_repo/src/")
+
 import torch
 from dataset_library.factory import get_dataset
 from transformers import (
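Both inserts use paths relative to the working directory, so they only resolve when the script is launched from src/. A hedged sketch of a more robust variant that anchors on the file's own location and fails loudly if the submodules were never checked out (the loop and error message are illustrative, not part of the commit):

import os
import sys

_HERE = os.path.dirname(os.path.abspath(__file__))

for _repo in ("transformers_repo", "peft_repo"):
    _src = os.path.join(_HERE, _repo, "src")
    if not os.path.isdir(_src):
        raise RuntimeError(f"{_src} is missing; run `git submodule update --init` first")
    # Prepend so the vendored fork shadows any installed copy.
    sys.path.insert(0, _src)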
@@ -9,7 +9,7 @@ accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 2 \
     --gradient_accumulation_steps 8 \
-    --output_dir checkpoint/qwen2/ \
+    --output_dir checkpoint/qwen2mmoe/ \
     --bf16 \
     --torch_dtype bfloat16
 # --eval_strategy epoch \
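The only functional change in this script is the output directory, which keeps evaluation pointed at the same checkpoint/qwen2mmoe/ tree the updated training script now writes to.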
@@ -5,6 +5,8 @@ from trl import (
     get_quantization_config,
 )
 from utils.args import ContinualModelConfig
+import transformers
+print(transformers.__version__)
 
 
 def get_model(model_args: ContinualModelConfig):
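The import-plus-print runs once at import time and shows which transformers version is active. Printing the module path as well would confirm whether the vendored fork or an installed copy won the sys.path race; a one-line variant (not in the commit):

import transformers
print(transformers.__version__, transformers.__file__)  # version and the copy actually imported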
src/peft_repo (submodule, new, +1)
@@ -0,0 +1 @@
+Subproject commit de88c703065fdd3a05521da1054c0463d16ea33c
src/train.py (12 lines changed)
@@ -1,5 +1,8 @@
-from dataset_library.factory import get_dataset
+import sys
+sys.path.insert(0, "./transformers_repo/src/")
+sys.path.insert(0, "./peft_repo/src/")
+from dataset_library.factory import get_dataset
 
 
 from transformers import (
     TrainingArguments,
@@ -8,8 +11,6 @@ from transformers import (
 from trl import (
     TrlParser,
 )
-from peft_library import get_peft_model
-
 from utils.trainer import ContinualTrainer
 from utils.args import ContinualScriptArguments, ContinualModelConfig
 import logging
@@ -49,10 +50,9 @@ if __name__ == "__main__":
 
         peft_config = MMOELoraConfig(target_modules=model_args.lora_target_modules)
 
-        model = get_peft_model(model, peft_config)
-
+        # model = get_peft_model(model, peft_config)
+        model.add_adapter(peft_config)
 
         if accelerator.is_local_main_process:
             model.print_trainable_parameters()
     elif model_args.peft_type == "LORA":
         from peft.tuners.lora import LoraConfig
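Attaching the adapter via model.add_adapter(peft_config) replaces the get_peft_model() wrapper, mutating the existing model in place instead of returning a wrapped PeftModel; the commented-out call is presumably kept for easy rollback. Since this targets the forked peft submodule pinned above, the exact add_adapter behavior may differ from upstream peft.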
@@ -3,13 +3,13 @@
 accelerate launch --config_file configs/accelerate_configs/deepspeed_zero2.yaml train.py \
     --dataset_name CHEM \
     --use_peft \
-    --peft_type LORA \
+    --peft_type MMOELORA \
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
     --lora_target_modules q_proj v_proj \
     --per_device_train_batch_size 1 \
     --per_device_eval_batch_size 2 \
     --gradient_accumulation_steps 4 \
-    --output_dir checkpoint/qwen2/ \
+    --output_dir checkpoint/qwen2mmoe/ \
     --bf16 \
     --torch_dtype bfloat16 \
     --logging_steps 30
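Switching --peft_type to MMOELORA exercises the MMOELoraConfig branch shown in train.py above, and the new --output_dir keeps MMOE checkpoints separate from earlier LoRA runs in checkpoint/qwen2/.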
src/transformers_repo (submodule, new, +1)
@@ -0,0 +1 @@
+Subproject commit 6bc0fbcfa7acb6ac4937e7456a76c2f7975fefec