提交 33b87902 编写于 作者: H Hui Zhang

refactor avg_model; fix set_value not support start==end

上级 9d5eb740
......@@ -94,9 +94,19 @@ def pad_sequence(sequences: List[paddle.Tensor],
length = tensor.shape[0]
# use index notation to prevent duplicate references to the tensor
if batch_first:
out_tensor[i, :length, ...] = tensor
# TODO (Hui Zhang): set_value op not supprot `end==start`
# out_tensor[i, :length, ...] = tensor
if length != 0:
out_tensor[i, :length, ...] = tensor
else:
out_tensor[i, length, ...] = tensor
else:
out_tensor[:length, i, ...] = tensor
# TODO (Hui Zhang): set_value op not supprot `end==start`
# out_tensor[:length, i, ...] = tensor
if length != 0:
out_tensor[:length, i, ...] = tensor
else:
out_tensor[length, i, ...] = tensor
return out_tensor
......
......@@ -27,33 +27,33 @@ def main(args):
val_scores = []
beat_val_scores = []
selected_epochs = []
if args.val_best:
jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
for y in jsons:
with open(y, 'r') as f:
dic_json = json.load(f)
loss = dic_json['val_loss']
epoch = dic_json['epoch']
if epoch >= args.min_epoch and epoch <= args.max_epoch:
val_scores.append((epoch, loss))
val_scores = np.array(val_scores)
jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
jsons = sorted(jsons, key=os.path.getmtime, reverse=True)
for y in jsons:
with open(y, 'r') as f:
dic_json = json.load(f)
loss = dic_json['val_loss']
epoch = dic_json['epoch']
if epoch >= args.min_epoch and epoch <= args.max_epoch:
val_scores.append((epoch, loss))
val_scores = np.array(val_scores)
if args.val_best:
sort_idx = np.argsort(val_scores[:, 1])
sorted_val_scores = val_scores[sort_idx]
path_list = [
args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
for epoch in sorted_val_scores[:args.num, 0]
]
beat_val_scores = sorted_val_scores[:args.num, 1]
selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
print("best val scores = " + str(beat_val_scores))
print("selected epochs = " + str(selected_epochs))
else:
path_list = glob.glob(f'{args.ckpt_dir}/[!avg][!final]*.pdparams')
path_list = sorted(path_list, key=os.path.getmtime)
path_list = path_list[-args.num:]
sorted_val_scores = val_scores
beat_val_scores = sorted_val_scores[:args.num, 1]
selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
print("selected val scores = " + str(beat_val_scores))
print("selected epochs = " + str(selected_epochs))
path_list = [
args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
for epoch in sorted_val_scores[:args.num, 0]
]
print(path_list)
avg = None
......@@ -78,10 +78,11 @@ def main(args):
meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
with open(meta_path, 'w') as f:
data = json.dumps({
"mode": 'val_best' if args.val_best else 'latest',
"avg_ckpt": args.dst_model,
"ckpt": path_list,
"epoch": selected_epochs,
"val_loss": beat_val_scores,
"epoch": selected_epochs.tolist(),
"val_loss": beat_val_scores.tolist(),
})
f.write(data + "\n")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册