feat✨: Add MOELORA support, improve the training and evaluation scripts, fix typos, and improve code readability
commit d686cbc254
parent b84ebb03c7
.vscode/settings.json (vendored, 2 changes)
@@ -15,5 +15,5 @@
     "python.analysis.typeCheckingMode": "basic",
     "python.analysis.userFileIndexingLimit": 10000,
     "python.analysis.usePullDiagnostics": false,
-    "python.analysis.importFormat": "relative"
+    "python.analysis.importFormat": "relative",
 }
@@ -19,99 +19,39 @@ from trl import (
     get_quantization_config,
 )
 
-from utils.args import ContinualScriptArguments, ContinualModelConfig
+from utils.args import (
+    ContinualScriptArguments,
+    ContinualModelConfig,
+    ContinualRegularizationArguments,
+)
+from typing import TYPE_CHECKING
 
 
 if __name__ == "__main__":
     parser = TrlParser(
-        (ContinualScriptArguments, TrainingArguments, ContinualModelConfig)
+        (
+            ContinualScriptArguments,
+            TrainingArguments,
+            ContinualModelConfig,
+            ContinualRegularizationArguments,
+        )
     )
-    script_args, training_args, model_args = parser.parse_args_and_config()
+    script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
     # for type hint
-    if 0 == 1:
+    if TYPE_CHECKING:
         script_args = ContinualScriptArguments()
         training_args = TrainingArguments()
-        model_args = ModelConfig()
+        model_args = ContinualModelConfig()
+        reg_args = ContinualRegularizationArguments()
 
     training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
     training_args.remove_unused_columns = False
-    training_args.dataset_kwargs = {"skip_prepare_dataset": True}
 
     from model_library.factory import get_model
 
-    if model_args.model_name_or_path == "Qwen/Qwen2.5-VL-3B-Instruct":
-        torch_dtype = (
-            model_args.torch_dtype
-            if model_args.torch_dtype in ["auto", None]
-            else getattr(torch, model_args.torch_dtype)
-        )
-        quantization_config = get_quantization_config(model_args)
-        model_kwargs = dict(
-            attn_implementation=model_args.attn_implementation,
-            torch_dtype=torch_dtype,
-            quantization_config=quantization_config,
-        )
-        from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
-
-        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-            training_args.output_dir,
-            **model_kwargs,
-        )
-
-        processor = Qwen2_5_VLProcessor.from_pretrained(
-            model_args.model_name_or_path,
-            trust_remote_code=model_args.trust_remote_code,
-            padding_side="left",
-        )
-
-        from model_library.qwen2vl import collate_fn_for_train, collate_fn_for_evaluate
-        from functools import partial
-
-        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
-        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
-
-    elif model_args.model_name_or_path == "Qwen/Qwen2-VL-7B-Instruct":
-        torch_dtype = (
-            model_args.torch_dtype
-            if model_args.torch_dtype in ["auto", None]
-            else getattr(torch, model_args.torch_dtype)
-        )
-        quantization_config = get_quantization_config(model_args)
-        model_kwargs = dict(
-            attn_implementation=model_args.attn_implementation,
-            torch_dtype=torch_dtype,
-            quantization_config=quantization_config,
-        )
-        from transformers import (
-            Qwen2VLProcessor,
-            Qwen2VLForConditionalGeneration,
-            AutoModelForVision2Seq,
-            AutoModel,
-        )
-        from peft.peft_model import PeftModelForCausalLM
-        from model_library.qwen2vl import Qwen2VLForConditionalGeneration_modified
-
-        model = Qwen2VLForConditionalGeneration.from_pretrained(
-            training_args.output_dir,
-            **model_kwargs,
-        )
-
-        # from peft_library import get_peft_model
-
-        processor = Qwen2VLProcessor.from_pretrained(
-            model_args.model_name_or_path,
-            trust_remote_code=model_args.trust_remote_code,
-            padding_side="left",
-        )
-        from model_library.qwen2vl import (
-            collate_fn_for_train,
-            collate_fn_for_evaluate,
-        )
-        from functools import partial
-
-        collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
-        collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
+    model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
+        model_args=model_args, training_args=training_args
+    )
 
     ################
     # Dataset
     ################
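The `if 0 == 1:` guard above becomes the standard `typing.TYPE_CHECKING` idiom: the block never executes at runtime, but static checkers treat the assignments as real, so the loosely typed tuple from `parse_args_and_config()` gets concrete types. A minimal, self-contained sketch of the idiom (the dataclass here is a hypothetical placeholder, not the repo's real `ContinualScriptArguments`):

    from dataclasses import dataclass
    from typing import TYPE_CHECKING


    @dataclass
    class ScriptArgs:  # hypothetical stand-in for ContinualScriptArguments
        dataset_name: str = "textvqa"


    def parse_args():
        # Stand-in for TrlParser.parse_args_and_config(), whose return type is loose.
        return (ScriptArgs(),)


    (script_args,) = parse_args()
    if TYPE_CHECKING:
        # Never executed (TYPE_CHECKING is False at runtime), but checkers such as
        # Pyright now treat script_args as a ScriptArgs instance.
        script_args = ScriptArgs()
    print(script_args.dataset_name)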
@@ -139,6 +79,6 @@ if __name__ == "__main__":
         collate_fn=collate_fn_for_evaluate,
     )
     val_dataloader = accelerator.prepare_data_loader(val_dataloader)
-    from utils.evaluate_tool import evaluate_rouge, evalute_save
+    from utils.evaluate_tool import evaluate_rouge, evaluate_save
 
-    evalute_save(model, val_dataloader, processor, accelerator)
+    evaluate_save(model, val_dataloader, processor, accelerator)
@@ -5,9 +5,12 @@ from trl import (
     get_quantization_config,
 )
 from utils.args import ContinualModelConfig
+from transformers import TrainingArguments
 
 
-def get_model(model_args: ContinualModelConfig):
+def get_model(
+    model_args: ContinualModelConfig, training_args: TrainingArguments = None
+):
     torch_dtype = (
         model_args.torch_dtype
         if model_args.torch_dtype in ["auto", None]
@@ -26,12 +29,20 @@ def get_model(model_args: ContinualModelConfig):
         from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
 
         # from .qwen2vl import Qwen2VLForConditionalGeneration_modified
-        model = Qwen2VLForConditionalGeneration.from_pretrained(
-            model_args.model_name_or_path,
-            trust_remote_code=model_args.trust_remote_code,
-            **model_kwargs,
-        )
+        if training_args is not None:
+            model = Qwen2VLForConditionalGeneration.from_pretrained(
+                training_args.output_dir,
+                trust_remote_code=model_args.trust_remote_code,
+                **model_kwargs,
+            )
+
+        else:
+            model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_args.model_name_or_path,
+                trust_remote_code=model_args.trust_remote_code,
+                **model_kwargs,
+            )
 
         processor = Qwen2VLProcessor.from_pretrained(
             model_args.model_name_or_path,
             trust_remote_code=model_args.trust_remote_code,
@@ -49,11 +60,18 @@ def get_model(model_args: ContinualModelConfig):
     if model_args.model_name_or_path == "Qwen/Qwen2-Audio-7B-Instruct":
         from transformers import Qwen2AudioProcessor, Qwen2AudioForConditionalGeneration
 
-        model = Qwen2AudioForConditionalGeneration.from_pretrained(
-            model_args.model_name_or_path,
-            trust_remote_code=model_args.trust_remote_code,
-            **model_kwargs,
-        )
+        if training_args is not None:
+            model = Qwen2AudioForConditionalGeneration.from_pretrained(
+                training_args.output_dir,
+                trust_remote_code=model_args.trust_remote_code,
+                **model_kwargs,
+            )
+        else:
+            model = Qwen2AudioForConditionalGeneration.from_pretrained(
+                model_args.model_name_or_path,
+                trust_remote_code=model_args.trust_remote_code,
+                **model_kwargs,
+            )
         processor = Qwen2AudioProcessor.from_pretrained(
             model_args.model_name_or_path,
             trust_remote_code=model_args.trust_remote_code,
@@ -71,11 +89,18 @@ def get_model(model_args: ContinualModelConfig):
     if model_args.model_name_or_path == "Qwen/Qwen2.5-VL-3B-Instruct":
         from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
 
-        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-            model_args.model_name_or_path,
-            trust_remote_code=model_args.trust_remote_code,
-            **model_kwargs,
-        )
+        if training_args is not None:
+            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                training_args.output_dir,
+                trust_remote_code=model_args.trust_remote_code,
+                **model_kwargs,
+            )
+        else:
+            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                model_args.model_name_or_path,
+                trust_remote_code=model_args.trust_remote_code,
+                **model_kwargs,
+            )
 
         processor = Qwen2_5_VLProcessor.from_pretrained(
             model_args.model_name_or_path,
@@ -88,18 +113,25 @@ def get_model(model_args: ContinualModelConfig):
 
         collate_fn_for_train = partial(collate_fn_for_train, processor=processor)
         collate_fn_for_evaluate = partial(collate_fn_for_evaluate, processor=processor)
 
     if model_args.model_name_or_path == "Qwen/Qwen2.5-Omni-3B":
         from transformers.models.qwen2_5_omni import (
             Qwen2_5OmniThinkerForConditionalGeneration,
-            Qwen2_5OmniProcessor
+            Qwen2_5OmniProcessor,
         )
 
-        model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
-            model_args.model_name_or_path,
-            trust_remote_code=model_args.trust_remote_code,
-            **model_kwargs,
-        )
+        if training_args is not None:
+            model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
+                training_args.output_dir,
+                **model_kwargs,
+            )
+
+        else:
+            model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
+                model_args.model_name_or_path,
+                trust_remote_code=model_args.trust_remote_code,
+                **model_kwargs,
+            )
 
         processor = Qwen2_5OmniProcessor.from_pretrained(
             model_args.model_name_or_path,
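Every branch above follows the same pattern: when `training_args` is supplied, `get_model` loads weights from `training_args.output_dir` (the checkpoint a previous run saved) rather than from the hub id, which is what lets the evaluation script call `get_model(model_args=..., training_args=training_args)`. A condensed sketch of that shared logic (hedged; the real branches also build `model_kwargs`, processors, and collators):

    # Sketch of the load-path choice repeated in each branch of get_model().
    # model_cls stands in for the branch's class, e.g. Qwen2_5_VLForConditionalGeneration.
    def load_weights(model_cls, model_args, training_args=None, **model_kwargs):
        if training_args is not None:
            # Evaluation path: reload the fine-tuned checkpoint from the output dir.
            source = training_args.output_dir
        else:
            # Training path: start from the pretrained checkpoint on the hub.
            source = model_args.model_name_or_path
        return model_cls.from_pretrained(
            source,
            trust_remote_code=model_args.trust_remote_code,
            **model_kwargs,
        )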
@@ -1 +1 @@
-Subproject commit 317d957cc101c4cb064066a1b228526a55f6e927
+Subproject commit f58e3bd57f3f6cf2f713edaac4b8a54ecafe8e20
src/scripts/eval_omni.sh (new executable file, 15 lines)
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml evaluation.py \
+    --dataset_name textvqa \
+    --use_peft \
+    --peft_type MOELORA \
+    --model_name_or_path Qwen/Qwen2.5-Omni-3B \
+    --lora_target_modules .*model\.layers.*proj\|.*merger.*0\|.*merger.*1 \
+    --per_device_train_batch_size 3 \
+    --per_device_eval_batch_size 2 \
+    --gradient_accumulation_steps 2 \
+    --output_dir ./checkpoint/qwen2_5omni_moelora/ \
+    --bf16 \
+    --torch_dtype bfloat16
+# --eval_strategy epoch \
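`--lora_target_modules` here is a single regular expression (the backslashes keep the shell from treating `|` as a pipe); PEFT matches a string-valued `target_modules` as a regex against each submodule's qualified name. A hedged sketch of the equivalent config in stock PEFT (MOELORA itself is a custom `peft_type` supplied by this repo, not stock PEFT; the rank below is an assumption):

    from peft import LoraConfig

    # A string target_modules is treated as a regex and matched against each
    # submodule's qualified name, e.g. "model.layers.0.self_attn.q_proj".
    config = LoraConfig(
        r=8,  # assumed rank; the script does not set one
        target_modules=r".*model\.layers.*proj|.*merger.*0|.*merger.*1",
    )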
@@ -18,7 +18,8 @@ accelerate launch --config_file configs/accelerate_configs/deepspeed_zero1.yaml
     --lr_scheduler_type cosine \
     --bf16 \
     --torch_dtype bfloat16 \
-    --logging_steps 10 \
+    --logging_steps 100 \
     --gradient_checkpointing \
     --weight_decay 0.1 \
-# --resume_from_checkpoint /root/autodl-tmp/zhouyunyao/projects/CL-LMM/src/checkpoint/qwen2_alllinear/checkpoint-1000
+    --eval_strategy steps \
+# --resume_from_checkpoint /root/autodl-tmp/zhouyunyao/projects/CL-LMM/src/checkpoint/qwen2_5omni_moelora/checkpoint-1500
src/test_evalutae.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+import evaluate
+
+# Bleu_1, Bleu_2, Bleu_3, Bleu_4, METEOR, ROUGE_L, and CIDEr
+example = {
+    "generated": "The cat sat on the mat.",
+    "target": "The cat is sitting on the mat.",
+    "original": "The cat is sitting on the mat.",
+}
+evaluate_bleu = evaluate.load("bleu")
+evaluate_rouge = evaluate.load("rouge")
+evaluate_meteor = evaluate.load("meteor")
+
+evaluate_bleu.add_batch(
+    predictions=[example["generated"]],
+    references=[[example["target"]]],
+)
+evaluate_rouge.add_batch(
+    predictions=[example["generated"]],
+    references=[[example["target"]]],
+)
+evaluate_meteor.add_batch(
+    predictions=[example["generated"]],
+    references=[[example["target"]]],
+)
+
+bleu = evaluate_bleu.compute()
+rouge = evaluate_rouge.compute()
+meteor = evaluate_meteor.compute()
+
+comprehensive_results = sum(bleu['precisions']) + rouge['rougeL'] + meteor['meteor']
+print("Comprehensive Results:", comprehensive_results/6)
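The composite score in this test file averages six numbers: the four BLEU n-gram precisions returned in `bleu["precisions"]` (the `evaluate` library's default max_order is 4), plus ROUGE-L and METEOR:

    comprehensive = (BLEU_1 + BLEU_2 + BLEU_3 + BLEU_4 + ROUGE_L + METEOR) / 6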
src/train.py (21 changes)
@@ -16,7 +16,7 @@ from utils.trainer import ContinualTrainer
 from utils.args import (
     ContinualScriptArguments,
     ContinualModelConfig,
-    ContiunalRegularizationArguments,
+    ContinualRegularizationArguments,
 )
 import logging
 from typing import TYPE_CHECKING
@@ -31,8 +31,8 @@ if __name__ == "__main__":
             ContinualScriptArguments,
             TrainingArguments,
             ContinualModelConfig,
-            ContiunalRegularizationArguments,
+            ContinualRegularizationArguments,
         ) # type: ignore
     )
     script_args, training_args, model_args, reg_args = parser.parse_args_and_config()
     # for type hint
@@ -40,7 +40,7 @@ if __name__ == "__main__":
         script_args = ContinualScriptArguments()
         training_args = TrainingArguments()
         model_args = ContinualModelConfig()
-        reg_args = ContiunalRegularizationArguments()
+        reg_args = ContinualRegularizationArguments()
 
     training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
     training_args.remove_unused_columns = False
@@ -48,7 +48,7 @@ if __name__ == "__main__":
     from model_library.factory import get_model
 
     model, processor, collate_fn_for_train, collate_fn_for_evaluate = get_model(
-        model_args
+        model_args=model_args
     )
     ################
     # Dataset
@@ -100,9 +100,9 @@ if __name__ == "__main__":
         model=model,
         args=training_args,
         data_collator=collate_fn_for_train,
         train_dataset=dataset[script_args.dataset_train_split], # type: ignore
         eval_dataset=(
             dataset[script_args.dataset_test_split] # type: ignore
             if training_args.eval_strategy != "no"
             else None
         ),
@@ -113,7 +113,8 @@ if __name__ == "__main__":
 
     if accelerator.is_local_main_process:
         print("Saving model")
-    trainer.save_model(training_args.output_dir)
+    # trainer.save_model(training_args.output_dir)
+    model.save_pretrained(training_args.output_dir)
 
     if accelerator.is_local_main_process:
         print("Model saved")
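Swapping `trainer.save_model` for `model.save_pretrained` presumably narrows what gets written: with a PEFT-wrapped model, `save_pretrained` emits only the adapter config and weights, not a full trainer checkpoint. A hedged sketch of the save/reload pair (the reload classes and paths are assumptions, not code from this commit):

    # Assuming `model` is a peft.PeftModel: save_pretrained writes
    # adapter_config.json plus the adapter weights, not the frozen base model.
    model.save_pretrained(training_args.output_dir)

    # Hypothetical reload later, e.g. for evaluation:
    # from peft import PeftModel
    # base = AutoModelForVision2Seq.from_pretrained(model_args.model_name_or_path)
    # model = PeftModel.from_pretrained(base, training_args.output_dir)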
@@ -131,6 +132,6 @@ if __name__ == "__main__":
     # )
     # val_dataloader = accelerator.prepare(val_dataloader)
 
-    # from utils.evaluate_tool import evaluate_rouge
+    # from utils.evaluate_tool import evaluate_save
 
-    # evaluate_rouge(model, val_dataloader, processor)
+    # evaluate_save(model, val_dataloader, processor, accelerator)
@@ -21,7 +21,7 @@ class ContinualModelConfig(ModelConfig):
 
 
 @dataclass
-class ContiunalRegularizationArguments:
+class ContinualRegularizationArguments:
     """Regularization arguments for continual learning."""
 
     # EWC
@@ -1,5 +1,6 @@
 import evaluate
 from accelerate import Accelerator
+from typing import TYPE_CHECKING
 
 
 def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator = None):
@@ -7,10 +8,7 @@ def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator = None):
 
     for batch in val_dataloader:
         completion = model.generate(
-            input_ids=batch["input_ids"],
-            attention_mask=batch["attention_mask"],
-            pixel_values=batch["pixel_values"],
-            image_grid_thw=batch["image_grid_thw"],
+            **batch,
             max_length=1000,
         )
         target = batch["answers_ids"]
@@ -27,7 +25,7 @@ def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator = None):
     print(glue.compute())
 
 
-def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = None):
+def evaluate_save(model, val_dataloader, processor, accelerator: Accelerator = None):
     import os
 
     mtime = 0
@@ -53,6 +51,7 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = None):
         answers = []
         completion = model.generate(
             **batch,
+            # max_new_tokens=30,
             max_length=1000,
         )
         generated_text = [
@@ -63,20 +62,17 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = None):
             generated_text, skip_special_tokens=True
         )
         target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True)
-        for i in range(len(generated_text)):
-            answers.append(
-                {
-                    "generated": generated_text[i],
-                    "target": target_text[i],
-                    "original": str(origianl[i]),
-                }
-            )
         import json
 
         world_size = accelerator.process_index
 
-        with open(f"results/{mtime}/answers_{world_size}.jsonl", "a") as f:
-            for answer in answers:
+        for i in range(len(generated_text)):
+            answer = {
+                "generated": generated_text[i],
+                "target": target_text[i],
+                "original": str(origianl[i]),
+            }
+            with open(f"results/{mtime}/answers_{world_size}.jsonl", "a") as f:
                 f.write(json.dumps(answer) + "\n")
 
         if accelerator.is_local_main_process:
@@ -97,3 +93,71 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = None):
             # delete file
             for file in files:
                 os.remove(f"results/{mtime}/{file}")
+
+
+def evaluate_from_jsonl_directory(directory_path):
+    """
+    Read every jsonl file in the given directory and compute aggregate evaluation results.
+
+    Args:
+        directory_path: path of the directory containing the jsonl files
+
+    Returns:
+        dict: the individual metrics plus the comprehensive score
+    """
+    import os
+    import json
+
+    # Initialize the evaluators
+    evaluate_bleu = evaluate.load("bleu")
+    evaluate_rouge = evaluate.load("rouge")
+    evaluate_meteor = evaluate.load("meteor")
+
+    # Read every jsonl file in the directory
+    all_data = []
+    for file in os.listdir(directory_path):
+        if file.endswith(".jsonl"):
+            file_path = os.path.join(directory_path, file)
+            with open(file_path, "r", encoding="utf-8") as f:
+                for line in f:
+                    line = line.strip()
+                    if line:
+                        data = json.loads(line)
+                        all_data.append(data)
+
+    if not all_data:
+        print(f"No valid jsonl data found in directory {directory_path}")
+        return None
+
+    # Prepare the data
+    predictions = [item["generated"] for item in all_data]
+    references = [[item["target"]] for item in all_data]
+
+    # Add the data in one batch
+    evaluate_bleu.add_batch(predictions=predictions, references=references)
+    evaluate_rouge.add_batch(predictions=predictions, references=references)
+    evaluate_meteor.add_batch(predictions=predictions, references=references)
+
+    # Compute the results
+    bleu = evaluate_bleu.compute()
+    rouge = evaluate_rouge.compute()
+    meteor = evaluate_meteor.compute()
+
+    # Compute the comprehensive score
+    comprehensive_score = (sum(bleu["precisions"]) + rouge["rougeL"] + meteor["meteor"]) / 6
+
+    results = {
+        "bleu": bleu,
+        "rouge": rouge,
+        "meteor": meteor,
+        "comprehensive_score": comprehensive_score,
+        "total_samples": len(all_data),
+    }
+
+    print(f"Evaluation finished; processed {len(all_data)} samples")
+    print(f"BLEU: {bleu}")
+    print(f"ROUGE: {rouge}")
+    print(f"METEOR: {meteor}")
+    print(f"Comprehensive score: {comprehensive_score}")
+
+    return results
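A minimal usage sketch for the new helper, pointed at the per-process `answers_*.jsonl` files that `evaluate_save` writes (the timestamped directory name below is hypothetical):

    results = evaluate_from_jsonl_directory("results/1718000000")
    if results is not None:
        print(results["comprehensive_score"], results["total_samples"])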
@@ -5,7 +5,7 @@ from transformers.trainer import *
 from transformers import (
     TrainingArguments,
 )
-from .args import ContiunalRegularizationArguments
+from .args import ContinualRegularizationArguments
 from peft_library.regularizations import EWC, LWF
 from torch.nn import CrossEntropyLoss
 
@@ -41,7 +41,7 @@ class ContinualTrainer(Trainer):
         train_dataset,
         eval_dataset,
         accelerator,
-        reg_args: ContiunalRegularizationArguments = None,
+        reg_args: ContinualRegularizationArguments = None,
     ):
         self.accelerator = accelerator
         super().__init__(
@@ -155,4 +155,3 @@ class ContinualTrainer(Trainer):
             self.optimizer = smp.DistributedOptimizer(self.optimizer)
 
         return self.optimizer
-