提交 a1cd7eac 编写于 作者: Y Yibing Liu

Adapt to the async data reader

上级 0ba71d92
......@@ -12,7 +12,7 @@ import paddle.fluid as fluid
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.data_reader as reader
import data_utils.async_data_reader as reader
from data_utils.util import lodtensor_to_ndarray
from model_utils.model import stacked_lstmp_model
......@@ -127,8 +127,8 @@ def infer_from_ckpt(args):
label_t = fluid.LoDTensor()
# infer data reader
infer_data_reader = reader.DataReader(args.infer_feature_lst,
args.infer_label_lst)
infer_data_reader = reader.AsyncDataReader(args.infer_feature_lst,
args.infer_label_lst)
infer_data_reader.set_transformers(ltrans)
infer_costs, infer_accs = [], []
for batch_id, batch_data in enumerate(
......@@ -136,10 +136,12 @@ def infer_from_ckpt(args):
args.minimum_batch_size)):
# load_data
(features, labels, lod) = batch_data
feature_t.set(features, place)
feature_t.set_lod([lod])
label_t.set(labels, place)
label_t.set_lod([lod])
feature_t.set(features.ndarray, place)
feature_t.set_lod([lod.ndarray])
label_t.set(labels.ndarray, place)
label_t.set_lod([lod.ndarray])
infer_data_reader.recycle(features, labels, lod)
cost, acc = exe.run(infer_program,
feed={"feature": feature_t,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册