未验证 提交 cb764459 编写于 作者: W whs 提交者: GitHub

Fix atention model. (#2059)

上级 a9c4cbb9
......@@ -339,7 +339,7 @@ def attention_infer(images, num_classes, use_cudnn=True):
return ids
def attention_eval(data_shape, num_classes):
def attention_eval(data_shape, num_classes, use_cudnn=True):
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label_in = fluid.layers.data(
name='label_in', shape=[1], dtype='int32', lod_level=1)
......@@ -349,7 +349,7 @@ def attention_eval(data_shape, num_classes):
label_in = fluid.layers.cast(x=label_in, dtype='int64')
gru_backward, encoded_vector, encoded_proj = encoder_net(
images, is_test=True)
images, is_test=True, use_cudnn=use_cudnn)
backward_first = fluid.layers.sequence_pool(
input=gru_backward, pool_type='first')
......
......@@ -213,12 +213,12 @@ def ctc_train_net(args, data_shape, num_classes):
return sum_cost, error_evaluator, inference_program, model_average
def ctc_infer(images, num_classes, use_cudnn):
def ctc_infer(images, num_classes, use_cudnn=True):
fc_out = encoder_net(images, num_classes, is_test=True, use_cudnn=use_cudnn)
return fluid.layers.ctc_greedy_decoder(input=fc_out, blank=num_classes)
def ctc_eval(data_shape, num_classes, use_cudnn):
def ctc_eval(data_shape, num_classes, use_cudnn=True):
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int32', lod_level=1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册