From 44b06ea5dbc523ea26892c9b19d1e660b241fb70 Mon Sep 17 00:00:00 2001 From: YunyaoZhou Date: Wed, 15 Jan 2025 04:03:53 +0800 Subject: [PATCH] =?UTF-8?q?feat=E2=9C=A8:=20=E4=BC=98=E5=8C=96=E8=AF=84?= =?UTF-8?q?=E4=BC=B0=E5=B7=A5=E5=85=B7=EF=BC=8C=E6=B7=BB=E5=8A=A0=E7=BB=93?= =?UTF-8?q?=E6=9E=9C=E5=90=88=E5=B9=B6=E5=92=8C=E6=B8=85=E7=90=86=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=EF=BC=8C=E6=94=B9=E8=BF=9B=E6=96=87=E4=BB=B6=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/utils/evaluate_tool.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/utils/evaluate_tool.py b/src/utils/evaluate_tool.py index 494c8d5..01624d4 100644 --- a/src/utils/evaluate_tool.py +++ b/src/utils/evaluate_tool.py @@ -29,18 +29,18 @@ def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator = def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = None): import os + mtime = 0 for root, dirs, files in os.walk("."): for file in files: time = os.path.getmtime(os.path.join(root, file)) if time > mtime: mtime = time - + # 获取目录最后修改时间 if not os.path.exists(f"results/{mtime}"): os.makedirs(f"results/{mtime}") - from tqdm import tqdm if accelerator.is_local_main_process: @@ -81,3 +81,19 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No f.write(json.dumps(answer) + "\n") if accelerator.is_local_main_process: bar.update(1) + accelerator.wait_for_everyone() + if accelerator.is_local_main_process: + bar.close() + # merge file + answers = [] + files = [ + file for file in os.listdir(f"results/{mtime}") if file.endswith(".jsonl") + ] + for file in files: + with open(f"results/{mtime}/{file}", "r") as f: + answers.extend(f.readlines()) + with open(f"results/{mtime}/answers.jsonl", "w") as f: + f.writelines(answers) + # delete file + for file in files: + os.remove(f"results/{mtime}/{file}")