提交 b8fae8a6 编写于 作者: Y Yibing Liu

Revise the comments for arguments

上级 48c1845f
...@@ -29,22 +29,23 @@ def parse_args(): ...@@ -29,22 +29,23 @@ def parse_args():
'--mean_var', '--mean_var',
type=str, type=str,
default='data/global_mean_var_search26kHr', default='data/global_mean_var_search26kHr',
help='mean var path') help="The path for feature's global mean and variance. "
"(default: %(default)s)")
parser.add_argument( parser.add_argument(
'--infer_feature_lst', '--infer_feature_lst',
type=str, type=str,
default='data/infer_feature.lst', default='data/infer_feature.lst',
help='feature list path for inference.') help='The feature list path for inference. (default: %(default)s)')
parser.add_argument( parser.add_argument(
'--infer_label_lst', '--infer_label_lst',
type=str, type=str,
default='data/infer_label.lst', default='data/infer_label.lst',
help='label list path for inference.') help='The label list path for inference. (default: %(default)s)')
parser.add_argument( parser.add_argument(
'--model_save_path', '--model_save_path',
type=str, type=str,
default='./checkpoints/deep_asr.pass_0.model/', default='./checkpoints/deep_asr.pass_0.model/',
help='directory to save model.') help='The directory for saving model. (default: %(default)s)')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -64,7 +65,7 @@ def split_infer_result(infer_seq, lod): ...@@ -64,7 +65,7 @@ def split_infer_result(infer_seq, lod):
def infer(args): def infer(args):
""" Get one batch of feature data and predicts labels for each sample. """ Gets one batch of feature data and predicts labels for each sample.
""" """
if not os.path.exists(args.model_save_path): if not os.path.exists(args.model_save_path):
......
...@@ -72,33 +72,34 @@ def parse_args(): ...@@ -72,33 +72,34 @@ def parse_args():
'--mean_var', '--mean_var',
type=str, type=str,
default='data/global_mean_var_search26kHr', default='data/global_mean_var_search26kHr',
help='mean var path') help="The path for feature's global mean and variance. "
"(default: %(default)s)")
parser.add_argument( parser.add_argument(
'--train_feature_lst', '--train_feature_lst',
type=str, type=str,
default='data/feature.lst', default='data/feature.lst',
help='feature list path for training.') help='The feature list path for training. (default: %(default)s)')
parser.add_argument( parser.add_argument(
'--train_label_lst', '--train_label_lst',
type=str, type=str,
default='data/label.lst', default='data/label.lst',
help='label list path for training.') help='The label list path for training. (default: %(default)s)')
parser.add_argument( parser.add_argument(
'--val_feature_lst', '--val_feature_lst',
type=str, type=str,
default='data/val_feature.lst', default='data/val_feature.lst',
help='feature list path for validation.') help='The feature list path for validation. (default: %(default)s)')
parser.add_argument( parser.add_argument(
'--val_label_lst', '--val_label_lst',
type=str, type=str,
default='data/val_label.lst', default='data/val_label.lst',
help='label list path for validation.') help='The label list path for validation. (default: %(default)s)')
parser.add_argument( parser.add_argument(
'--model_save_dir', '--model_save_dir',
type=str, type=str,
default='./checkpoints', default='./checkpoints',
help='directory to save model. Do not save model if set to ' help="The directory for saving model. Do not save model if set to "
'.') "''. (default: %(default)s)")
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -114,8 +115,6 @@ def train(args): ...@@ -114,8 +115,6 @@ def train(args):
"""train in loop. """train in loop.
""" """
# prediction, avg_cost, accuracy = stacked_lstmp_model(args.hidden_dim,
# args.proj_dim, args.stacked_num, class_num=1749, args.parallel)
prediction, avg_cost, accuracy = stacked_lstmp_model( prediction, avg_cost, accuracy = stacked_lstmp_model(
hidden_dim=args.hidden_dim, hidden_dim=args.hidden_dim,
proj_dim=args.proj_dim, proj_dim=args.proj_dim,
...@@ -206,7 +205,7 @@ def train(args): ...@@ -206,7 +205,7 @@ def train(args):
sys.stdout.flush() sys.stdout.flush()
# run test # run test
val_cost, val_acc = test(exe) val_cost, val_acc = test(exe)
# save model # save model
if args.model_save_dir != '': if args.model_save_dir != '':
model_path = os.path.join( model_path = os.path.join(
args.model_save_dir, "deep_asr.pass_" + str(pass_id) + ".model") args.model_save_dir, "deep_asr.pass_" + str(pass_id) + ".model")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册