更新.gitignore以排除虚拟环境和缓存文件,修改TODO列表,重命名评估脚本,添加训练和评估脚本,新增数据集工厂和评估工具类

This commit is contained in:
YunyaoZhou 2024-12-31 17:53:16 +00:00
parent d6b4ec79ad
commit f2f921113e
14 changed files with 411 additions and 161 deletions

4
.gitignore vendored
View File

@@ -1,2 +1,2 @@
**/.venv/
**/__pycache__/
**/.venv/*
**/__pycache__/*

View File

@@ -0,0 +1,30 @@
from torch.utils.data import Dataset
from typing import Literal
def get_dataset(
    script_args, base_path="/home/zyy/research/accelerate/dataset"
) -> dict[Literal["train", "test", "generation"], Dataset]:
    """Build the train/test/generation splits for the requested dataset.

    Args:
        script_args: parsed script arguments; only ``dataset_name`` is read.
        base_path: root directory containing the raw dataset folders.

    Returns:
        Mapping with ``"train"``, ``"test"`` and ``"generation"`` datasets.

    Raises:
        ValueError: if ``script_args.dataset_name`` is not recognised.
    """
    dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
    if script_args.dataset_name == "OCR_VQA_200K":
        import os.path as osp

        from .OCRVQADataset import OCRVQADataset, OCRVQADatasetForGeneration

        # Both dataset classes read the same image folder and annotation file.
        images_dir = osp.join(base_path, "OCR-VQA-200K/images")
        ann_path = osp.join(base_path, "OCR-VQA-200K/dataset.json")
        dataset = {
            "train": OCRVQADataset(images_dir, ann_path, split="train"),
            "test": OCRVQADataset(images_dir, ann_path, split="test"),
            # The generation split reuses the test annotations but yields
            # answer-free prompts suitable for free-form decoding.
            "generation": OCRVQADatasetForGeneration(
                images_dir, ann_path, split="test"
            ),
        }
    else:
        # Fail fast instead of silently returning an empty mapping, which
        # previously surfaced later as an opaque KeyError in the trainer.
        raise ValueError(f"Unknown dataset: {script_args.dataset_name}")
    return dataset

155
src/evaluate.py Normal file
View File

@@ -0,0 +1,155 @@
import torch
from datasets_library.factory import get_dataset
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
)
from trl import (
    ModelConfig,
    SFTScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from peft import get_peft_model
from utils.trainer import ContinualTrainer

if __name__ == "__main__":
    # Parse command-line / config-file options into the three TRL dataclasses.
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    script_args: SFTScriptArguments = script_args
    training_args: SFTConfig = training_args
    model_args: ModelConfig = model_args
    # Non-reentrant checkpointing is required once the model is wrapped by
    # DeepSpeed/DDP; keep raw multimodal columns and skip TRL's text-only prep.
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
    # Resolve the requested dtype ("auto" and None pass through unchanged).
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        padding_side="left",
    )
    # NOTE(review): weights are loaded from training_args.output_dir (a
    # previously saved checkpoint), not model_name_or_path — confirm intended.
    model = AutoModelForVision2Seq.from_pretrained(
        training_args.output_dir,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )
    # NOTE(review): the collate functions are only defined inside this branch;
    # any other model_name_or_path raises NameError further down.
    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        from collatefn_library.qwen2 import (
            collate_fn_for_train,
            collate_fn_for_evaluate,
        )
        from functools import partial

        # Bind the processor once so the dataloaders see 1-arg collate fns.
        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)

    ################
    # Dataset
    ################
    dataset = get_dataset(script_args)
    # peft_config = get_peft_config(model_args)
    # model = get_peft_model(model, peft_config)
    # Print the model structure on rank 1 only, to avoid duplicated output.
    if torch.distributed.get_rank() in [1]:
        print(model)
    # _________________________________________________________
    model.train()
    import copy

    # Zero-epoch trainer: builds the accelerate/distributed machinery without
    # actually training, so the checkpoint can be evaluated first.
    training_args_init = copy.copy(training_args)
    training_args_init.do_train = False
    training_args_init.do_eval = False
    training_args_init.do_predict = False
    training_args_init.num_train_epochs = 0
    trainer = SFTTrainer(
        model=model,
        args=training_args_init,
        data_collator=collate_fn_for_train,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]
            if training_args.eval_strategy != "no"
            else None
        ),
    )
    trainer.train()
    # Pre-training evaluation pass over the generation split.
    model.eval()
    accelerator = trainer.accelerator
    from torch.utils.data import DataLoader

    val_dataloader = DataLoader(
        dataset["generation"],
        batch_size=3,
        collate_fn=collate_fn_for_evaluate,
    )
    val_dataloader = accelerator.prepare(val_dataloader)
    from utils.evaluate_tool import evaluate_rouge

    evaluate_rouge(model, val_dataloader, processor)
    # Continual-training phase, reusing the accelerator built above.
    model.train()
    trainer = ContinualTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn_for_train,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]
            if training_args.eval_strategy != "no"
            else None
        ),
        accelerator=accelerator,
    )
    trainer.train()
    trainer.save_model(training_args.output_dir)
    # Free cached GPU memory before reloading the saved checkpoint.
    torch.cuda.empty_cache()
    # Reload the freshly saved checkpoint for a post-training evaluation pass.
    from transformers import AutoModelForVision2Seq

    model = AutoModelForVision2Seq.from_pretrained(training_args.output_dir)
    model = accelerator.prepare(model)
    model.eval()
    accelerator = trainer.accelerator
    from torch.utils.data import DataLoader

    val_dataloader = DataLoader(
        dataset["generation"],
        batch_size=3,
        collate_fn=collate_fn_for_evaluate,
    )
    val_dataloader = accelerator.prepare(val_dataloader)
    from utils.evaluate_tool import evaluate_rouge

    evaluate_rouge(model, val_dataloader, processor)

View File

@@ -1,6 +1,6 @@
#!/bin/bash
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluation.py \
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluate.py \
--dataset_name OCR_VQA_200K \
--use_peft \
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \

View File

@ -1,153 +0,0 @@
import torch
from datasets import load_dataset
from transformers import (
AutoModelForVision2Seq,
AutoProcessor,
LlavaForConditionalGeneration,
)
from trl import (
ModelConfig,
SFTScriptArguments,
SFTConfig,
SFTTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
if __name__ == "__main__":
parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
script_args: SFTScriptArguments
training_args: SFTConfig
model_args: ModelConfig
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
################
# Model, Tokenizer & Processor
################
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
processor = AutoProcessor.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
padding_side="left",
)
model = AutoModelForVision2Seq.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=model_args.trust_remote_code,
**model_kwargs,
)
if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
from collate_fn_library.qwen2 import collate_fn_for_train
from functools import partial
collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
################
# Dataset
################
base_path = "/home/zyy/research/accelerate/dataset"
if script_args.dataset_name == "OCR_VQA_200K":
import os.path as osp
from datasets_library.OCRVQADataset import OCRVQADataset
dataset = {
"train": OCRVQADataset(
osp.join(base_path, "OCR-VQA-200K/images"),
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
split="train",
),
"test": OCRVQADataset(
osp.join(base_path, "OCR-VQA-200K/images"),
osp.join(base_path, "OCR-VQA-200K/dataset.json"),
split="test",
),
}
else:
dataset = load_dataset(script_args.dataset_name, name=script_args.config)
################
# Training
################
trainer = SFTTrainer(
model=model,
args=training_args,
data_collator=collate_fn_for_train,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=(
dataset[script_args.dataset_test_split]
if training_args.eval_strategy != "no"
else None
),
tokenizer=processor.tokenizer,
# processing_class=processor.tokenizer,
peft_config=get_peft_config(model_args),
)
trainer.train()
model = trainer.model
# model进gpu
accelerator = trainer.accelerator
# model = accelerator.prepare(model)
from datasets_library.OCRVQADataset import OCRVQADatasetForGeneration
from collate_fn_library.qwen2 import collate_fn_for_evaluate
dataset = OCRVQADatasetForGeneration(
vis_root="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images",
ann_path="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json",
split="train",
)
examples = [dataset[i] for i in range(3)]
# print(collate_fn_for_evaluate(examples, processor))
from torch.utils.data import DataLoader
val_dataloader = DataLoader(
dataset,
batch_size=3,
collate_fn=lambda x: collate_fn_for_evaluate(x, processor),
)
val_dataloader = accelerator.prepare(val_dataloader)
import evaluate
glue = evaluate.load("rouge")
for batch in val_dataloader:
completion = model.generate(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
pixel_values=batch["pixel_values"],
image_grid_thw=batch["image_grid_thw"],
max_length=1000,
)
target = batch["answers_ids"]
generated_text = [
out_ids[len(in_ids) :]
for out_ids, in_ids in zip(completion, batch["input_ids"])
]
generated_text = processor.tokenizer.batch_decode(generated_text, skip_special_tokens=True)
target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True)
glue.add_batch(predictions=generated_text, references=target_text)
print(glue.compute())

View File

@@ -99,8 +99,6 @@ if __name__ == "__main__":
if training_args.eval_strategy != "no"
else None
),
tokenizer=processor.tokenizer,
# processing_class=processor.tokenizer,
peft_config=get_peft_config(model_args),
)

View File

@@ -1,6 +1,6 @@
## TODO:
[2024.12.31]
[2024.12.31]
- [ ] 采用数据集多次训练
- [ ] 整理evaluate的代码
- [X] 采用数据集多次训练
- [X] 整理evaluate的代码

155
src/train.py Normal file
View File

@@ -0,0 +1,155 @@
import torch
from datasets_library.factory import get_dataset
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
)
from trl import (
    ModelConfig,
    SFTScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from peft import get_peft_model
from utils.trainer import ContinualTrainer

if __name__ == "__main__":
    # Parse command-line / config-file options into the three TRL dataclasses.
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    script_args: SFTScriptArguments
    training_args: SFTConfig
    model_args: ModelConfig
    # Non-reentrant checkpointing is required once the model is wrapped by
    # DeepSpeed/DDP; keep raw multimodal columns and skip TRL's text-only prep.
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
    # Resolve the requested dtype ("auto" and None pass through unchanged).
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        padding_side="left",
    )
    # NOTE(review): for a fresh training run the base weights would normally
    # come from model_args.model_name_or_path; loading training_args.output_dir
    # only works when a checkpoint already exists there — confirm intended.
    model = AutoModelForVision2Seq.from_pretrained(
        training_args.output_dir,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )
    # NOTE(review): the collate functions are only defined inside this branch;
    # any other model_name_or_path raises NameError further down.
    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        from collatefn_library.qwen2 import (
            collate_fn_for_train,
            collate_fn_for_evaluate,
        )
        from functools import partial

        # Bind the processor once so the dataloaders see 1-arg collate fns.
        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)

    ################
    # Dataset
    ################
    dataset = get_dataset(script_args)
    peft_config = get_peft_config(model_args)
    if peft_config is not None:
        # get_peft_config returns None when --use_peft is absent; calling
        # get_peft_model(model, None) would raise, so only wrap when configured.
        model = get_peft_model(model, peft_config)
    # Show the trainable-parameter summary on rank 1 only (the method is added
    # by PEFT, hence the hasattr guard for non-PEFT runs).
    # BUGFIX: the method must be called — the original printed the bound
    # method object itself instead of the summary.
    if torch.distributed.get_rank() in [1] and hasattr(
        model, "print_trainable_parameters"
    ):
        model.print_trainable_parameters()
    # _________________________________________________________
    model.train()
    import copy

    # Zero-epoch trainer: builds the accelerate/distributed machinery without
    # actually training, so the model can be evaluated before fine-tuning.
    training_args_init = copy.copy(training_args)
    training_args_init.do_train = False
    training_args_init.do_eval = False
    training_args_init.do_predict = False
    training_args_init.num_train_epochs = 0
    trainer = SFTTrainer(
        model=model,
        args=training_args_init,
        data_collator=collate_fn_for_train,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]
            if training_args.eval_strategy != "no"
            else None
        ),
    )
    trainer.train()
    # Pre-training evaluation pass over the generation split.
    model.eval()
    accelerator = trainer.accelerator
    from torch.utils.data import DataLoader

    val_dataloader = DataLoader(
        dataset["generation"],
        batch_size=3,
        collate_fn=collate_fn_for_evaluate,
    )
    val_dataloader = accelerator.prepare(val_dataloader)
    from utils.evaluate_tool import evaluate_rouge

    evaluate_rouge(model, val_dataloader, processor)
    # Continual-training phase, reusing the accelerator built above.
    model.train()
    trainer = ContinualTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn_for_train,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]
            if training_args.eval_strategy != "no"
            else None
        ),
        accelerator=accelerator,
    )
    trainer.train()
    trainer.save_model(training_args.output_dir)
    # Free cached GPU memory before reloading the saved checkpoint.
    torch.cuda.empty_cache()
    # Reload the freshly saved checkpoint for a post-training evaluation pass.
    from transformers import AutoModelForVision2Seq

    model = AutoModelForVision2Seq.from_pretrained(training_args.output_dir)
    model = accelerator.prepare(model)
    model.eval()
    accelerator = trainer.accelerator
    from torch.utils.data import DataLoader

    val_dataloader = DataLoader(
        dataset["generation"],
        batch_size=3,
        collate_fn=collate_fn_for_evaluate,
    )
    val_dataloader = accelerator.prepare(val_dataloader)
    from utils.evaluate_tool import evaluate_rouge

    evaluate_rouge(model, val_dataloader, processor)

15
src/train.sh Executable file
View File

@@ -0,0 +1,15 @@
#!/bin/bash
# Fine-tune Qwen2-VL-7B-Instruct on OCR-VQA-200K with LoRA (q_proj/v_proj)
# via accelerate + DeepSpeed ZeRO-2. Checkpoints are written to --output_dir.
# NOTE(review): --output_dir still says "llava-1.5-7b-hf" although the model
# is Qwen2-VL — presumably copied from an older script; confirm the path.
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml train.py \
    --dataset_name OCR_VQA_200K \
    --use_peft \
    --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
    --lora_target_modules q_proj v_proj \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --max_seq_length 1024 \
    --output_dir checkpoint/sft-llava-1.5-7b-hf \
    --bf16 \
    --torch_dtype bfloat16
    # --eval_strategy epoch \

View File

@@ -0,0 +1,26 @@
import evaluate
def evaluate_rouge(model, val_dataloader, processor):
    """Generate answers for every batch and report corpus-level ROUGE scores.

    Args:
        model: a vision2seq model already placed on the correct device(s).
        val_dataloader: yields batches containing ``input_ids``,
            ``attention_mask``, ``pixel_values``, ``image_grid_thw`` and the
            reference ``answers_ids``.
        processor: HF processor whose tokenizer decodes token ids to text.

    Returns:
        The dict of ROUGE scores computed by the ``evaluate`` library.
    """
    # Renamed from the misleading ``glue`` — this loads the ROUGE metric.
    metric = evaluate.load("rouge")
    for batch in val_dataloader:
        completion = model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            pixel_values=batch["pixel_values"],
            image_grid_thw=batch["image_grid_thw"],
            max_length=1000,
        )
        target = batch["answers_ids"]
        # generate() returns prompt + continuation; strip the prompt tokens so
        # only the newly generated answer is scored.
        generated = [
            out_ids[len(in_ids) :]
            for out_ids, in_ids in zip(completion, batch["input_ids"])
        ]
        generated_text = processor.tokenizer.batch_decode(
            generated, skip_special_tokens=True
        )
        target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True)
        metric.add_batch(predictions=generated_text, references=target_text)
    # Return the scores as well as printing them so callers can log/aggregate
    # programmatically (previously the result was only printed).
    result = metric.compute()
    print(result)
    return result

24
src/utils/trainer.py Normal file
View File

@@ -0,0 +1,24 @@
# _________________________________________________________
from trl import SFTTrainer
class ContinualTrainer(SFTTrainer):
    """SFTTrainer variant that can reuse an externally created Accelerator.

    Supplying an existing ``accelerator`` lets several sequential training
    phases share one distributed state instead of re-initialising it per
    trainer instance.
    """

    def __init__(
        self, model, args, data_collator, train_dataset, eval_dataset, accelerator
    ):
        # Must be stored before super().__init__, which invokes
        # create_accelerator_and_postprocess() during construction.
        self.accelerator = accelerator
        super().__init__(model, args, data_collator, train_dataset, eval_dataset)

    def create_accelerator_and_postprocess(self):
        # No shared accelerator provided: defer to the default behaviour.
        if self.accelerator is None:
            super().create_accelerator_and_postprocess()
            return
        # Reuse the supplied accelerator; mirror only the plugin flags the
        # base Trainer would otherwise derive while building its own.
        state = self.accelerator.state
        self.is_deepspeed_enabled = getattr(state, "deepspeed_plugin", None) is not None
        self.is_fsdp_enabled = getattr(state, "fsdp_plugin", None) is not None