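"""Evaluation helpers: in-loop ROUGE scoring, multi-process generation dumps
to per-rank JSONL shards, and offline aggregate scoring with BLEU, ROUGE,
and METEOR."""
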
import json
import os

import evaluate
from accelerate import Accelerator
from tqdm import tqdm


def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator = None):
    """Generate completions for val_dataloader and print ROUGE scores.

    Under a multi-process launch each process scores only its own shard of
    the dataloader, so the printed result is per-process.
    """
    rouge_metric = evaluate.load("rouge")

    for batch in val_dataloader:
        # Pop the reference ids (and any raw-sample payload) first so they
        # are not forwarded to generate(), which rejects unknown kwargs.
        target = batch.pop("answers_ids")
        batch.pop("original_data", None)
        completion = model.generate(
            **batch,
            max_length=1000,
        )
        # Decoder-only generation returns prompt + completion; strip the
        # prompt tokens so only the newly generated text is scored.
        generated_text = [
            out_ids[len(in_ids) :]
            for out_ids, in_ids in zip(completion, batch["input_ids"])
        ]
        generated_text = processor.tokenizer.batch_decode(
            generated_text, skip_special_tokens=True
        )
        target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True)
        rouge_metric.add_batch(predictions=generated_text, references=target_text)

    print(rouge_metric.compute())
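
# Example usage (a sketch, not from the original source; assumes model and
# val_dataloader were prepared with accelerator.prepare and processor is the
# matching HF processor/tokenizer):
#
#   accelerator = Accelerator()
#   evaluate_rouge(model, val_dataloader, processor, accelerator)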


def evaluate_save(model, val_dataloader, processor, accelerator: Accelerator = None):
    # Use the newest file-modification time under the working tree as a cheap
    # run identifier; every process walks the same tree and so derives the
    # same results directory name.
    mtime = 0
    for root, dirs, files in os.walk("."):
        for file in files:
            file_mtime = os.path.getmtime(os.path.join(root, file))
            if file_mtime > mtime:
                mtime = file_mtime

    if accelerator.is_local_main_process:
        os.makedirs(f"results/{mtime}", exist_ok=True)
    # Ensure the directory exists before any process starts writing into it.
    accelerator.wait_for_everyone()

    if accelerator.is_local_main_process:
        bar = tqdm(total=len(val_dataloader))

    for batch in val_dataloader:
        # Pop fields that model.generate() must not receive as kwargs.
        target = batch.pop("answers_ids")
        original = batch.pop("original_data")
        completion = model.generate(
            **batch,
            # max_new_tokens=30,
            max_length=1000,
        )
        # Strip the prompt tokens so only the newly generated text is kept.
        generated_text = [
            out_ids[len(in_ids) :]
            for out_ids, in_ids in zip(completion, batch["input_ids"])
        ]
        generated_text = processor.tokenizer.batch_decode(
            generated_text, skip_special_tokens=True
        )
        target_text = processor.tokenizer.batch_decode(target, skip_special_tokens=True)

        # Each process appends to its own rank-suffixed shard so concurrent
        # writers never touch the same file.
        rank = accelerator.process_index
        for i in range(len(generated_text)):
            answer = {
                "generated": generated_text[i],
                "target": target_text[i],
                "original": str(original[i]),
            }
            with open(f"results/{mtime}/answers_{rank}.jsonl", "a") as f:
                f.write(json.dumps(answer) + "\n")

        if accelerator.is_local_main_process:
            bar.update(1)

    accelerator.wait_for_everyone()
    if accelerator.is_local_main_process:
        bar.close()
        # Merge the per-rank shards into a single answers.jsonl ...
        answers = []
        files = [
            file for file in os.listdir(f"results/{mtime}") if file.endswith(".jsonl")
        ]
        for file in files:
            with open(f"results/{mtime}/{file}", "r") as f:
                answers.extend(f.readlines())
        with open(f"results/{mtime}/answers.jsonl", "w") as f:
            f.writelines(answers)
        # ... then delete the now-redundant shards.
        for file in files:
            os.remove(f"results/{mtime}/{file}")
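
# The merge above leaves results/<mtime>/answers.jsonl with one JSON object
# per line, carrying the keys "generated", "target", and "original".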


def evaluate_from_jsonl_directory(directory_path):
    """Read all jsonl files from a directory and compute aggregate metrics.

    Args:
        directory_path: Path to the directory containing the jsonl files.

    Returns:
        dict: Per-metric scores plus a combined score, or None if no valid
        data was found.
    """
    # Initialize the evaluators; the *_metric names avoid shadowing the
    # module-level evaluate_rouge function.
    bleu_metric = evaluate.load("bleu")
    rouge_metric = evaluate.load("rouge")
    meteor_metric = evaluate.load("meteor")

    # Read all jsonl files in the directory.
    all_data = []
    for file in os.listdir(directory_path):
        if file.endswith(".jsonl"):
            file_path = os.path.join(directory_path, file)
            with open(file_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        all_data.append(json.loads(line))

    if not all_data:
        print(f"No valid jsonl data found in directory {directory_path}")
        return None

    # BLEU expects a list of reference strings per prediction, hence the
    # extra nesting on references.
    predictions = [item["generated"] for item in all_data]
    references = [[item["target"]] for item in all_data]

    bleu_metric.add_batch(predictions=predictions, references=references)
    rouge_metric.add_batch(predictions=predictions, references=references)
    meteor_metric.add_batch(predictions=predictions, references=references)

    bleu = bleu_metric.compute()
    rouge = rouge_metric.compute()
    meteor = meteor_metric.compute()

    # Combined score: the mean of six terms, BLEU's four n-gram precisions
    # plus ROUGE-L and METEOR.
    comprehensive_score = (
        sum(bleu["precisions"]) + rouge["rougeL"] + meteor["meteor"]
    ) / 6

    results = {
        "bleu": bleu,
        "rouge": rouge,
        "meteor": meteor,
        "comprehensive_score": comprehensive_score,
        "total_samples": len(all_data),
    }

    print(f"Evaluation finished: processed {len(all_data)} samples")
    print(f"BLEU scores: {bleu}")
    print(f"ROUGE scores: {rouge}")
    print(f"METEOR scores: {meteor}")
    print(f"Combined score: {comprehensive_score}")

    return results
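

if __name__ == "__main__":
    # Offline-scoring sketch (not part of the original file): point this at a
    # results/<mtime> directory written by evaluate_save. The path below is a
    # hypothetical placeholder.
    evaluate_from_jsonl_directory("results/1700000000.0")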