feat✨: 优化评估工具,添加结果合并和清理功能,改进文件处理逻辑
This commit is contained in:
parent
1b7fea800e
commit
44b06ea5db
@ -29,6 +29,7 @@ def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator =
|
||||
|
||||
def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = None):
|
||||
import os
|
||||
|
||||
mtime = 0
|
||||
for root, dirs, files in os.walk("."):
|
||||
for file in files:
|
||||
@ -40,7 +41,6 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
|
||||
if not os.path.exists(f"results/{mtime}"):
|
||||
os.makedirs(f"results/{mtime}")
|
||||
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
@ -81,3 +81,19 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No
|
||||
f.write(json.dumps(answer) + "\n")
|
||||
if accelerator.is_local_main_process:
|
||||
bar.update(1)
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_local_main_process:
|
||||
bar.close()
|
||||
# merge file
|
||||
answers = []
|
||||
files = [
|
||||
file for file in os.listdir(f"results/{mtime}") if file.endswith(".jsonl")
|
||||
]
|
||||
for file in files:
|
||||
with open(f"results/{mtime}/{file}", "r") as f:
|
||||
answers.extend(f.readlines())
|
||||
with open(f"results/{mtime}/answers.jsonl", "w") as f:
|
||||
f.writelines(answers)
|
||||
# delete file
|
||||
for file in files:
|
||||
os.remove(f"results/{mtime}/{file}")
|
||||
|
Loading…
Reference in New Issue
Block a user