Update .gitignore to exclude virtual environments and cache files, revise the TODO list, rename the evaluation script, add training and evaluation scripts, and add a dataset factory and evaluation utilities

parent d6b4ec79ad
commit f2f921113e
.gitignore (vendored, 4 lines changed)

@@ -1,2 +1,2 @@
-**/.venv/
-**/__pycache__/
+**/.venv/*
+**/__pycache__/*
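A note on the pattern change (the motivation is not stated in the commit, so this is an assumption): a trailing "/" ignores the directory itself, and git then never descends into it, while "/*" ignores the directory's contents but still traverses the directory. Only the "/*" form allows a later negation to re-include a specific file, e.g.:

    **/.venv/*
    !**/.venv/keep-me.txt    # hypothetical re-include; impossible with **/.venv/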
src/datasets_library/factory.py (new file, 30 lines)

@@ -0,0 +1,30 @@
from torch.utils.data import Dataset
from typing import Literal


def get_dataset(
    script_args, base_path="/home/zyy/research/accelerate/dataset"
) -> dict[Literal["train", "test", "generation"], Dataset]:
    dataset: dict[Literal["train", "test", "generation"], Dataset] = {}
    if script_args.dataset_name == "OCR_VQA_200K":
        import os.path as osp

        from .OCRVQADataset import OCRVQADataset, OCRVQADatasetForGeneration

        dataset = {
            "train": OCRVQADataset(
                osp.join(base_path, "OCR-VQA-200K/images"),
                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
                split="train",
            ),
            "test": OCRVQADataset(
                osp.join(base_path, "OCR-VQA-200K/images"),
                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
                split="test",
            ),
            "generation": OCRVQADatasetForGeneration(
                osp.join(base_path, "OCR-VQA-200K/images"),
                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
                split="test",
            ),
        }
    return dataset
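A minimal usage sketch of the factory (the SimpleNamespace stand-in is hypothetical; in the scripts below, script_args comes from TrlParser):

    from types import SimpleNamespace

    from datasets_library.factory import get_dataset

    # Hypothetical stand-in for the TRL-parsed script arguments.
    args = SimpleNamespace(dataset_name="OCR_VQA_200K")
    splits = get_dataset(args)
    print({k: len(v) for k, v in splits.items()})  # sizes of train/test/generation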
src/evaluate.py (new file, 155 lines)

@@ -0,0 +1,155 @@
import torch
from datasets_library.factory import get_dataset
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
)

from trl import (
    ModelConfig,
    SFTScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from peft import get_peft_model

from utils.trainer import ContinualTrainer


if __name__ == "__main__":
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    script_args: SFTScriptArguments = script_args
    training_args: SFTConfig = training_args
    model_args: ModelConfig = model_args
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        padding_side="left",
    )

    # Load the fine-tuned checkpoint from the training output directory.
    model = AutoModelForVision2Seq.from_pretrained(
        training_args.output_dir,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )

    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        from collatefn_library.qwen2 import (
            collate_fn_for_train,
            collate_fn_for_evaluate,
        )
        from functools import partial

        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)

    ################
    # Dataset
    ################

    dataset = get_dataset(script_args)

    # peft_config = get_peft_config(model_args)
    # model = get_peft_model(model, peft_config)
    # Print only on selected ranks (here rank 1).
    if torch.distributed.get_rank() in [1]:
        print(model)

    # _________________________________________________________
    model.train()
    import copy

    # Zero-epoch trainer: trainer.train() performs no optimizer steps here,
    # but it still wraps the model and initializes the accelerator state.
    training_args_init = copy.copy(training_args)
    training_args_init.do_train = False
    training_args_init.do_eval = False
    training_args_init.do_predict = False
    training_args_init.num_train_epochs = 0
    trainer = SFTTrainer(
        model=model,
        args=training_args_init,
        data_collator=collate_fn_for_train,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]
            if training_args.eval_strategy != "no"
            else None
        ),
    )
    trainer.train()

    model.eval()
    accelerator = trainer.accelerator

    from torch.utils.data import DataLoader

    val_dataloader = DataLoader(
        dataset["generation"],
        batch_size=3,
        collate_fn=collate_fn_for_evaluate,
    )
    val_dataloader = accelerator.prepare(val_dataloader)
    from utils.evaluate_tool import evaluate_rouge

    evaluate_rouge(model, val_dataloader, processor)

    model.train()
    trainer = ContinualTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn_for_train,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]
            if training_args.eval_strategy != "no"
            else None
        ),
        accelerator=accelerator,
    )
    trainer.train()

    trainer.save_model(training_args.output_dir)

    # Free cached GPU memory.
    torch.cuda.empty_cache()

    # load_model
    from transformers import AutoModelForVision2Seq
    model = AutoModelForVision2Seq.from_pretrained(training_args.output_dir)
    model = accelerator.prepare(model)

    model.eval()
    accelerator = trainer.accelerator

    from torch.utils.data import DataLoader

    val_dataloader = DataLoader(
        dataset["generation"],
        batch_size=3,
        collate_fn=collate_fn_for_evaluate,
    )
    val_dataloader = accelerator.prepare(val_dataloader)
    from utils.evaluate_tool import evaluate_rouge

    evaluate_rouge(model, val_dataloader, processor)
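One detail worth flagging: the processor is created with padding_side="left". For decoder-only generation this matters, because the completion slicing in evaluate_rouge (out_ids[len(in_ids):]) assumes every prompt occupies the first len(in_ids) positions of the output. A toy illustration of that assumption (hypothetical token ids, not from the commit):

    # With left padding, prompt tokens end at the same index in every row,
    # so stripping the first len(in_ids) ids isolates the generated tail.
    in_ids = [0, 0, 11, 12]            # pad, pad, prompt...
    out_ids = [0, 0, 11, 12, 99, 100]  # generate() returns prompt + completion
    completion = out_ids[len(in_ids):]  # -> [99, 100]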
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluation.py \
+accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml evaluate.py \
     --dataset_name OCR_VQA_200K \
     --use_peft \
     --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
@@ -1,153 +0,0 @@
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    LlavaForConditionalGeneration,
)

from trl import (
    ModelConfig,
    SFTScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


if __name__ == "__main__":
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    script_args: SFTScriptArguments
    training_args: SFTConfig
    model_args: ModelConfig
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    ################
    # Model, Tokenizer & Processor
    ################
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        padding_side="left",
    )

    model = AutoModelForVision2Seq.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )

    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        from collate_fn_library.qwen2 import collate_fn_for_train
        from functools import partial

        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)

    ################
    # Dataset
    ################
    base_path = "/home/zyy/research/accelerate/dataset"
    if script_args.dataset_name == "OCR_VQA_200K":
        import os.path as osp
        from datasets_library.OCRVQADataset import OCRVQADataset

        dataset = {
            "train": OCRVQADataset(
                osp.join(base_path, "OCR-VQA-200K/images"),
                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
                split="train",
            ),
            "test": OCRVQADataset(
                osp.join(base_path, "OCR-VQA-200K/images"),
                osp.join(base_path, "OCR-VQA-200K/dataset.json"),
                split="test",
            ),
        }

    else:
        dataset = load_dataset(script_args.dataset_name, name=script_args.config)

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn_for_train,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]
            if training_args.eval_strategy != "no"
            else None
        ),
        tokenizer=processor.tokenizer,
        # processing_class=processor.tokenizer,
        peft_config=get_peft_config(model_args),
    )
    trainer.train()

    model = trainer.model
    # Move the model onto the GPU.
    accelerator = trainer.accelerator
    # model = accelerator.prepare(model)

    from datasets_library.OCRVQADataset import OCRVQADatasetForGeneration
    from collate_fn_library.qwen2 import collate_fn_for_evaluate

    dataset = OCRVQADatasetForGeneration(
        vis_root="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/images",
        ann_path="/home/zyy/research/accelerate/dataset/OCR-VQA-200K/dataset.json",
        split="train",
    )
    examples = [dataset[i] for i in range(3)]
    # print(collate_fn_for_evaluate(examples, processor))

    from torch.utils.data import DataLoader

    val_dataloader = DataLoader(
        dataset,
        batch_size=3,
        collate_fn=lambda x: collate_fn_for_evaluate(x, processor),
    )
    val_dataloader = accelerator.prepare(val_dataloader)

    import evaluate

    glue = evaluate.load("rouge")

    for batch in val_dataloader:
        completion = model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            pixel_values=batch["pixel_values"],
            image_grid_thw=batch["image_grid_thw"],
            max_length=1000,
        )
        target = batch["answers_ids"]
        generated_text = [
            out_ids[len(in_ids) :]
            for out_ids, in_ids in zip(completion, batch["input_ids"])
        ]
        generated_text = processor.tokenizer.batch_decode(
            generated_text, skip_special_tokens=True
        )
        target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True)
        glue.add_batch(predictions=generated_text, references=target_text)

    print(glue.compute())
@@ -99,8 +99,6 @@ if __name__ == "__main__":
             if training_args.eval_strategy != "no"
             else None
         ),
-        tokenizer=processor.tokenizer,
-        # processing_class=processor.tokenizer,
         peft_config=get_peft_config(model_args),
     )
 
@@ -1,6 +1,6 @@
 ## TODO:
 
 [2024.12.31]
 
-- [ ] Train on the dataset multiple times
-- [ ] Clean up the evaluate code
+- [X] Train on the dataset multiple times
+- [X] Clean up the evaluate code
src/train.py (new file, 155 lines)

@@ -0,0 +1,155 @@
import torch
from datasets_library.factory import get_dataset
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
)

from trl import (
    ModelConfig,
    SFTScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from peft import get_peft_model

from utils.trainer import ContinualTrainer


if __name__ == "__main__":
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    script_args: SFTScriptArguments
    training_args: SFTConfig
    model_args: ModelConfig
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=model_args.trust_remote_code,
        padding_side="left",
    )

    model = AutoModelForVision2Seq.from_pretrained(
        training_args.output_dir,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )

    if model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
        from collatefn_library.qwen2 import (
            collate_fn_for_train,
            collate_fn_for_evaluate,
        )
        from functools import partial

        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)

    ################
    # Dataset
    ################

    dataset = get_dataset(script_args)

    peft_config = get_peft_config(model_args)
    model = get_peft_model(model, peft_config)
    # Print only on selected ranks (here rank 1).
    if torch.distributed.get_rank() in [1]:
        model.print_trainable_parameters()

    # _________________________________________________________
    model.train()
    import copy

    # Zero-epoch trainer: trainer.train() performs no optimizer steps here,
    # but it still wraps the model and initializes the accelerator state.
    training_args_init = copy.copy(training_args)
    training_args_init.do_train = False
    training_args_init.do_eval = False
    training_args_init.do_predict = False
    training_args_init.num_train_epochs = 0
    trainer = SFTTrainer(
        model=model,
        args=training_args_init,
        data_collator=collate_fn_for_train,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]
            if training_args.eval_strategy != "no"
            else None
        ),
    )
    trainer.train()

    model.eval()
    accelerator = trainer.accelerator

    from torch.utils.data import DataLoader

    val_dataloader = DataLoader(
        dataset["generation"],
        batch_size=3,
        collate_fn=collate_fn_for_evaluate,
    )
    val_dataloader = accelerator.prepare(val_dataloader)
    from utils.evaluate_tool import evaluate_rouge

    evaluate_rouge(model, val_dataloader, processor)

    model.train()
    trainer = ContinualTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn_for_train,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            dataset[script_args.dataset_test_split]
            if training_args.eval_strategy != "no"
            else None
        ),
        accelerator=accelerator,
    )
    trainer.train()

    trainer.save_model(training_args.output_dir)

    # Free cached GPU memory.
    torch.cuda.empty_cache()

    # load_model
    from transformers import AutoModelForVision2Seq
    model = AutoModelForVision2Seq.from_pretrained(training_args.output_dir)
    model = accelerator.prepare(model)

    model.eval()
    accelerator = trainer.accelerator

    from torch.utils.data import DataLoader

    val_dataloader = DataLoader(
        dataset["generation"],
        batch_size=3,
        collate_fn=collate_fn_for_evaluate,
    )
    val_dataloader = accelerator.prepare(val_dataloader)
    from utils.evaluate_tool import evaluate_rouge

    evaluate_rouge(model, val_dataloader, processor)
src/train.sh (new executable file, 15 lines)

@@ -0,0 +1,15 @@
#!/bin/bash

accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml train.py \
    --dataset_name OCR_VQA_200K \
    --use_peft \
    --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
    --lora_target_modules q_proj v_proj \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --max_seq_length 1024 \
    --output_dir checkpoint/sft-llava-1.5-7b-hf \
    --bf16 \
    --torch_dtype bfloat16
    # --eval_strategy epoch \
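For reference, a sketch of roughly what --use_peft with --lora_target_modules q_proj v_proj resolves to inside train.py via get_peft_config (the rank and alpha values here are illustrative assumptions, not taken from the commit):

    from peft import LoraConfig, get_peft_model

    # Illustrative values; TRL's ModelConfig supplies its own defaults.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()  # only the LoRA adapters are trainable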
src/utils/evaluate_tool.py (new file, 26 lines)

@@ -0,0 +1,26 @@
import evaluate


def evaluate_rouge(model, val_dataloader, processor):
    glue = evaluate.load("rouge")

    for batch in val_dataloader:
        completion = model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            pixel_values=batch["pixel_values"],
            image_grid_thw=batch["image_grid_thw"],
            max_length=1000,
        )
        target = batch["answers_ids"]
        # Strip the prompt: generate() returns prompt + completion ids.
        generated_text = [
            out_ids[len(in_ids) :]
            for out_ids, in_ids in zip(completion, batch["input_ids"])
        ]
        generated_text = processor.tokenizer.batch_decode(
            generated_text, skip_special_tokens=True
        )
        target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True)
        glue.add_batch(predictions=generated_text, references=target_text)

    print(glue.compute())
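One caveat, flagged as an assumption rather than something the commit addresses: under accelerate, each process only sees its shard of the dataloader, so glue.compute() above prints a per-rank score. A distributed variant using the evaluate library's documented multi-process mode might look roughly like this (hypothetical helper, not part of the commit):

    import evaluate


    def evaluate_rouge_distributed(model, val_dataloader, processor, accelerator):
        # evaluate's distributed mode: every rank adds its shard; compute() on
        # the main process merges them, and other ranks receive None.
        rouge = evaluate.load(
            "rouge",
            num_process=accelerator.num_processes,
            process_id=accelerator.process_index,
        )
        for batch in val_dataloader:
            completion = model.generate(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                pixel_values=batch["pixel_values"],
                image_grid_thw=batch["image_grid_thw"],
                max_length=1000,
            )
            preds = processor.tokenizer.batch_decode(
                [o[len(i):] for o, i in zip(completion, batch["input_ids"])],
                skip_special_tokens=True,
            )
            refs = processor.tokenizer.batch_decode(
                batch["answers_ids"], skip_special_tokens=True
            )
            rouge.add_batch(predictions=preds, references=refs)
        result = rouge.compute()
        if accelerator.is_main_process:
            print(result)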
src/utils/trainer.py (new file, 24 lines)

@@ -0,0 +1,24 @@
# _________________________________________________________

from trl import SFTTrainer


class ContinualTrainer(SFTTrainer):
    def __init__(
        self, model, args, data_collator, train_dataset, eval_dataset, accelerator
    ):
        self.accelerator = accelerator
        super().__init__(model, args, data_collator, train_dataset, eval_dataset)

    def create_accelerator_and_postprocess(self):
        # Reuse an existing Accelerator instead of letting Trainer build a new
        # one; only derive the deepspeed/fsdp flags from its state.
        if self.accelerator is not None:
            self.is_deepspeed_enabled = (
                getattr(self.accelerator.state, "deepspeed_plugin", None)
                is not None
            )
            self.is_fsdp_enabled = (
                getattr(self.accelerator.state, "fsdp_plugin", None) is not None
            )
            return
        else:
            super().create_accelerator_and_postprocess()
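A usage sketch of the intended pattern, as far as the scripts above show it (the task loop is hypothetical; evaluate.py and train.py currently run a single task):

    # Hypothetical continual-learning loop built on ContinualTrainer: the same
    # accelerator (and thus the same wrapped distributed state) is reused
    # across sequentially arriving tasks instead of being rebuilt each time.
    for task_dataset in [dataset]:  # e.g. a sequence of datasets, one per task
        trainer = ContinualTrainer(
            model=model,
            args=training_args,
            data_collator=collate_fn_for_train,
            train_dataset=task_dataset["train"],
            eval_dataset=None,
            accelerator=accelerator,  # carried over from the previous trainer
        )
        trainer.train()
        trainer.save_model(training_args.output_dir)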