提交 33b87902 编写于 作者: H Hui Zhang

refactor avg_model; fix set_value not support start==end

上级 9d5eb740
...@@ -94,9 +94,19 @@ def pad_sequence(sequences: List[paddle.Tensor], ...@@ -94,9 +94,19 @@ def pad_sequence(sequences: List[paddle.Tensor],
length = tensor.shape[0] length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor # use index notation to prevent duplicate references to the tensor
if batch_first: if batch_first:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# out_tensor[i, :length, ...] = tensor
if length != 0:
out_tensor[i, :length, ...] = tensor out_tensor[i, :length, ...] = tensor
else: else:
out_tensor[i, length, ...] = tensor
else:
# TODO (Hui Zhang): set_value op not supprot `end==start`
# out_tensor[:length, i, ...] = tensor
if length != 0:
out_tensor[:length, i, ...] = tensor out_tensor[:length, i, ...] = tensor
else:
out_tensor[length, i, ...] = tensor
return out_tensor return out_tensor
......
...@@ -27,8 +27,9 @@ def main(args): ...@@ -27,8 +27,9 @@ def main(args):
val_scores = [] val_scores = []
beat_val_scores = [] beat_val_scores = []
selected_epochs = [] selected_epochs = []
if args.val_best:
jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json') jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
for y in jsons: for y in jsons:
with open(y, 'r') as f: with open(y, 'r') as f:
dic_json = json.load(f) dic_json = json.load(f)
...@@ -36,24 +37,23 @@ def main(args): ...@@ -36,24 +37,23 @@ def main(args):
epoch = dic_json['epoch'] epoch = dic_json['epoch']
if epoch >= args.min_epoch and epoch <= args.max_epoch: if epoch >= args.min_epoch and epoch <= args.max_epoch:
val_scores.append((epoch, loss)) val_scores.append((epoch, loss))
val_scores = np.array(val_scores) val_scores = np.array(val_scores)
if args.val_best:
sort_idx = np.argsort(val_scores[:, 1]) sort_idx = np.argsort(val_scores[:, 1])
sorted_val_scores = val_scores[sort_idx] sorted_val_scores = val_scores[sort_idx]
path_list = [ else:
args.ckpt_dir + '/{}.pdparams'.format(int(epoch)) sorted_val_scores = val_scores
for epoch in sorted_val_scores[:args.num, 0]
]
beat_val_scores = sorted_val_scores[:args.num, 1] beat_val_scores = sorted_val_scores[:args.num, 1]
selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64) selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
print("best val scores = " + str(beat_val_scores)) print("selected val scores = " + str(beat_val_scores))
print("selected epochs = " + str(selected_epochs)) print("selected epochs = " + str(selected_epochs))
else:
path_list = glob.glob(f'{args.ckpt_dir}/[!avg][!final]*.pdparams')
path_list = sorted(path_list, key=os.path.getmtime)
path_list = path_list[-args.num:]
path_list = [
args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
for epoch in sorted_val_scores[:args.num, 0]
]
print(path_list) print(path_list)
avg = None avg = None
...@@ -78,10 +78,11 @@ def main(args): ...@@ -78,10 +78,11 @@ def main(args):
meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json' meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
with open(meta_path, 'w') as f: with open(meta_path, 'w') as f:
data = json.dumps({ data = json.dumps({
"mode": 'val_best' if args.val_best else 'latest',
"avg_ckpt": args.dst_model, "avg_ckpt": args.dst_model,
"ckpt": path_list, "ckpt": path_list,
"epoch": selected_epochs, "epoch": selected_epochs.tolist(),
"val_loss": beat_val_scores, "val_loss": beat_val_scores.tolist(),
}) })
f.write(data + "\n") f.write(data + "\n")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册