diff --git a/src/utils/evaluate_tool.py b/src/utils/evaluate_tool.py index 494c8d5..01624d4 100644 --- a/src/utils/evaluate_tool.py +++ b/src/utils/evaluate_tool.py @@ -29,18 +29,18 @@ def evaluate_rouge(model, val_dataloader, processor, accelerator: Accelerator = def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = None): import os + mtime = 0 for root, dirs, files in os.walk("."): for file in files: time = os.path.getmtime(os.path.join(root, file)) if time > mtime: mtime = time - + # 获取目录最后修改时间 if not os.path.exists(f"results/{mtime}"): os.makedirs(f"results/{mtime}") - from tqdm import tqdm if accelerator.is_local_main_process: @@ -81,3 +81,19 @@ def evalute_save(model, val_dataloader, processor, accelerator: Accelerator = No f.write(json.dumps(answer) + "\n") if accelerator.is_local_main_process: bar.update(1) + accelerator.wait_for_everyone() + if accelerator.is_local_main_process: + bar.close() + # merge file + answers = [] + files = [ + file for file in os.listdir(f"results/{mtime}") if file.endswith(".jsonl") + ] + for file in files: + with open(f"results/{mtime}/{file}", "r") as f: + answers.extend(f.readlines()) + with open(f"results/{mtime}/answers.jsonl", "w") as f: + f.writelines(answers) + # delete file + for file in files: + os.remove(f"results/{mtime}/{file}")