提交 73722097 编写于 作者: Y Yibing Liu

Tiny revision in infer

上级 8c937cbd
...@@ -42,10 +42,11 @@ def parse_args(): ...@@ -42,10 +42,11 @@ def parse_args():
default='data/infer_label.lst', default='data/infer_label.lst',
help='The label list path for inference. (default: %(default)s)') help='The label list path for inference. (default: %(default)s)')
parser.add_argument( parser.add_argument(
'--model_save_path', '--infer_model_path',
type=str, type=str,
default='./checkpoints/deep_asr.pass_0.model/', default='./infer_models/deep_asr.pass_0.infer.model/',
help='The directory for saving model. (default: %(default)s)') help='The directory for loading inference model. '
'(default: %(default)s)')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -68,15 +69,15 @@ def infer(args): ...@@ -68,15 +69,15 @@ def infer(args):
""" Gets one batch of feature data and predicts labels for each sample. """ Gets one batch of feature data and predicts labels for each sample.
""" """
if not os.path.exists(args.model_save_path): if not os.path.exists(args.infer_model_path):
raise IOError("Invalid model path!") raise IOError("Invalid inference model path!")
place = fluid.CUDAPlace(0) if args.device == 'GPU' else fluid.CPUPlace() place = fluid.CUDAPlace(0) if args.device == 'GPU' else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
# load model # load model
[infer_program, feed_dict, [infer_program, feed_dict,
fetch_targets] = fluid.io.load_inference_model(args.model_save_path, exe) fetch_targets] = fluid.io.load_inference_model(args.infer_model_path, exe)
ltrans = [ ltrans = [
trans_add_delta.TransAddDelta(2, 2), trans_add_delta.TransAddDelta(2, 2),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册