提交 7b472152 编写于 作者: B breezedeus

use `mod.predict` directly

上级 2055e6f7
......@@ -252,9 +252,7 @@ class CnOcr(object):
batch_size = len(img_list)
img_list, img_widths = self._pad_arrays(img_list)
sample = SimpleBatch(data_names=['data'], data=[mx.nd.array(img_list)])
prob = self._predict(sample)
prob = self._predict(mx.nd.array(img_list))
# [seq_len, batch_size, num_classes]
prob = np.reshape(prob, (-1, batch_size, prob.shape[1]))
......@@ -310,12 +308,9 @@ class CnOcr(object):
return padded_img_list, img_widths
def _predict(self, sample):
mod = self._mod
mod.forward(sample)
prob = mod.get_outputs()[0]
prob = self._mod.predict(sample)
mx.nd.waitall()
prob = prob.asnumpy()
return prob
return prob.asnumpy()
def _gen_line_pred_chars(self, line_prob, img_width, max_img_width):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册