提交 ee6446a3 编写于 作者: H Hui Zhang

avg model dump val loss mean

上级 a4e27da6
......@@ -25,8 +25,8 @@ def main(args):
paddle.set_device('cpu')
val_scores = []
beat_val_scores = []
selected_epochs = []
beat_val_scores = None
selected_epochs = None
jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
......@@ -80,9 +80,10 @@ def main(args):
data = json.dumps({
"mode": 'val_best' if args.val_best else 'latest',
"avg_ckpt": args.dst_model,
"ckpt": path_list,
"epoch": selected_epochs.tolist(),
"val_loss": beat_val_scores.tolist(),
"val_loss_mean": np.mean(beat_val_scores),
"ckpts": path_list,
"epochs": selected_epochs.tolist(),
"val_losses": beat_val_scores.tolist(),
})
f.write(data + "\n")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册