未验证 提交 301a0317 编写于 作者: W whs 提交者: GitHub

Fix data layer of ocr model. (#3852)

上级 e50325d4
......@@ -232,9 +232,9 @@ def ctc_infer(images, num_classes, use_cudnn=True):
def ctc_eval(data_shape, num_classes, use_cudnn=True):
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int32', lod_level=1)
images = fluid.data(name='pixel', shape=[None]+data_shape, dtype='float32')
label = fluid.data(
name='label', shape=[None, 1], dtype='int32', lod_level=1)
fc_out = encoder_net(images, num_classes, is_test=True, use_cudnn=use_cudnn)
decoded_out = fluid.layers.ctc_greedy_decoder(
input=fc_out, blank=num_classes)
......
......@@ -54,7 +54,7 @@ def inference(args):
num_classes = data_reader.num_classes()
data_shape = data_reader.data_shape()
# define network
images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
images = fluid.data(name='pixel', shape=[None] + data_shape, dtype='float32')
ids = infer(images, num_classes, use_cudnn=True if args.use_gpu else False)
# data reader
infer_reader = data_reader.inference(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册