提交 f2a42bd3 编写于 作者: Hui Zhang

more avg and test info (write model-averaging metadata to `<dst_model>.avg.json` and test metrics to `<checkpoint>.err`)

上级 c693bb08
...@@ -125,6 +125,7 @@ if not hasattr(paddle, 'cat'): ...@@ -125,6 +125,7 @@ if not hasattr(paddle, 'cat'):
def item(x: paddle.Tensor): def item(x: paddle.Tensor):
return x.numpy().item() return x.numpy().item()
if not hasattr(paddle.Tensor, 'item'): if not hasattr(paddle.Tensor, 'item'):
logger.warn( logger.warn(
"override item of paddle.Tensor if exists or register, remove this when fixed!" "override item of paddle.Tensor if exists or register, remove this when fixed!"
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Contains U2 model.""" """Contains U2 model."""
import json
import os
import sys import sys
import time import time
from collections import defaultdict from collections import defaultdict
...@@ -439,6 +441,31 @@ class U2Tester(U2Trainer): ...@@ -439,6 +441,31 @@ class U2Tester(U2Trainer):
error_rate_type, num_ins, num_ins, errors_sum / len_refs) error_rate_type, num_ins, num_ins, errors_sum / len_refs)
logger.info(msg) logger.info(msg)
# test meta results
err_meta_path = os.path.splitext(self.args.checkpoint_path)[0] + '.err'
err_type_str = "{}".format(error_rate_type)
with open(err_meta_path, 'w') as f:
data = json.dumps({
"epoch":
self.epoch,
"step":
self.iteration,
"rtf":
rtf,
error_rate_type:
errors_sum / len_refs,
"dataset_hour": (num_frames * stride_ms) / 1000.0 / 3600.0,
"process_hour":
num_time / 1000.0 / 3600.0,
"num_examples":
num_ins,
"err_sum":
errors_sum,
"ref_len":
len_refs,
})
f.write(data + '\n')
def run_test(self): def run_test(self):
self.resume_or_scratch() self.resume_or_scratch()
try: try:
......
...@@ -7,7 +7,7 @@ fi ...@@ -7,7 +7,7 @@ fi
ckpt_path=${1} ckpt_path=${1}
average_num=${2} average_num=${2}
decode_checkpoint=${ckpt_path}/avg_${average_num}.pt decode_checkpoint=${ckpt_path}/avg_${average_num}.pdparams
python3 -u ${MAIN_ROOT}/utils/avg_model.py \ python3 -u ${MAIN_ROOT}/utils/avg_model.py \
--dst_model ${decode_checkpoint} \ --dst_model ${decode_checkpoint} \
...@@ -21,4 +21,4 @@ if [ $? -ne 0 ]; then ...@@ -21,4 +21,4 @@ if [ $? -ne 0 ]; then
fi fi
exit 0 exit 0
\ No newline at end of file
...@@ -21,14 +21,15 @@ import paddle ...@@ -21,14 +21,15 @@ import paddle
def main(args): def main(args):
checkpoints = []
val_scores = [] val_scores = []
beat_val_scores = []
selected_epochs = []
if args.val_best: if args.val_best:
jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json') jsons = glob.glob(f'{args.ckpt_dir}/[!train]*.json')
for y in jsons: for y in jsons:
dic_json = json.load(y) with open(y, 'r') as f:
loss = dic_json['valid_loss'] dic_json = json.load(f)
loss = dic_json['val_loss']
epoch = dic_json['epoch'] epoch = dic_json['epoch']
if epoch >= args.min_epoch and epoch <= args.max_epoch: if epoch >= args.min_epoch and epoch <= args.max_epoch:
val_scores.append((epoch, loss)) val_scores.append((epoch, loss))
...@@ -40,9 +41,11 @@ def main(args): ...@@ -40,9 +41,11 @@ def main(args):
args.ckpt_dir + '/{}.pdparams'.format(int(epoch)) args.ckpt_dir + '/{}.pdparams'.format(int(epoch))
for epoch in sorted_val_scores[:args.num, 0] for epoch in sorted_val_scores[:args.num, 0]
] ]
print("best val scores = " + str(sorted_val_scores[:args.num, 1]))
print("selected epochs = " + str(sorted_val_scores[:args.num, 0].astype( beat_val_scores = sorted_val_scores[:args.num, 1]
np.int64))) selected_epochs = sorted_val_scores[:args.num, 0].astype(np.int64)
print("best val scores = " + str(beat_val_scores))
print("selected epochs = " + str(selected_epochs))
else: else:
path_list = glob.glob(f'{args.ckpt_dir}/[!avg][!final]*.pdparams') path_list = glob.glob(f'{args.ckpt_dir}/[!avg][!final]*.pdparams')
path_list = sorted(path_list, key=os.path.getmtime) path_list = sorted(path_list, key=os.path.getmtime)
...@@ -64,11 +67,21 @@ def main(args): ...@@ -64,11 +67,21 @@ def main(args):
# average # average
for k in avg.keys(): for k in avg.keys():
if avg[k] is not None: if avg[k] is not None:
avg[k] = paddle.divide(avg[k], num) avg[k] /= num
paddle.save(avg, args.dst_model) paddle.save(avg, args.dst_model)
print(f'Saving to {args.dst_model}') print(f'Saving to {args.dst_model}')
meta_path = os.path.splitext(args.dst_model)[0] + '.avg.json'
with open(meta_path, 'w') as f:
data = json.dumps({
"avg_ckpt": args.dst_model,
"ckpt": path_list,
"epoch": selected_epochs.tolist(),
"val_loss": beat_val_scores.tolist(),
})
f.write(data + "\n")
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='average model') parser = argparse.ArgumentParser(description='average model')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册