提交 b8fae8a6 编写于 作者: Y Yibing Liu

Revise the comments for arguments

上级 48c1845f
......@@ -29,22 +29,23 @@ def parse_args():
'--mean_var',
type=str,
default='data/global_mean_var_search26kHr',
help='mean var path')
help="The path for feature's global mean and variance. "
"(default: %(default)s)")
parser.add_argument(
'--infer_feature_lst',
type=str,
default='data/infer_feature.lst',
help='feature list path for inference.')
help='The feature list path for inference. (default: %(default)s)')
parser.add_argument(
'--infer_label_lst',
type=str,
default='data/infer_label.lst',
help='label list path for inference.')
help='The label list path for inference. (default: %(default)s)')
parser.add_argument(
'--model_save_path',
type=str,
default='./checkpoints/deep_asr.pass_0.model/',
help='directory to save model.')
help='The directory for saving model. (default: %(default)s)')
args = parser.parse_args()
return args
......@@ -64,7 +65,7 @@ def split_infer_result(infer_seq, lod):
def infer(args):
""" Get one batch of feature data and predicts labels for each sample.
""" Gets one batch of feature data and predicts labels for each sample.
"""
if not os.path.exists(args.model_save_path):
......
......@@ -72,33 +72,34 @@ def parse_args():
'--mean_var',
type=str,
default='data/global_mean_var_search26kHr',
help='mean var path')
help="The path for feature's global mean and variance. "
"(default: %(default)s)")
parser.add_argument(
'--train_feature_lst',
type=str,
default='data/feature.lst',
help='feature list path for training.')
help='The feature list path for training. (default: %(default)s)')
parser.add_argument(
'--train_label_lst',
type=str,
default='data/label.lst',
help='label list path for training.')
help='The label list path for training. (default: %(default)s)')
parser.add_argument(
'--val_feature_lst',
type=str,
default='data/val_feature.lst',
help='feature list path for validation.')
help='The feature list path for validation. (default: %(default)s)')
parser.add_argument(
'--val_label_lst',
type=str,
default='data/val_label.lst',
help='label list path for validation.')
help='The label list path for validation. (default: %(default)s)')
parser.add_argument(
'--model_save_dir',
type=str,
default='./checkpoints',
help='directory to save model. Do not save model if set to '
'.')
help="The directory for saving model. Do not save model if set to "
"''. (default: %(default)s)")
args = parser.parse_args()
return args
......@@ -114,8 +115,6 @@ def train(args):
"""train in loop.
"""
# prediction, avg_cost, accuracy = stacked_lstmp_model(args.hidden_dim,
# args.proj_dim, args.stacked_num, class_num=1749, args.parallel)
prediction, avg_cost, accuracy = stacked_lstmp_model(
hidden_dim=args.hidden_dim,
proj_dim=args.proj_dim,
......@@ -206,7 +205,7 @@ def train(args):
sys.stdout.flush()
# run test
val_cost, val_acc = test(exe)
# save model
# save model
if args.model_save_dir != '':
model_path = os.path.join(
args.model_save_dir, "deep_asr.pass_" + str(pass_id) + ".model")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册