提交 c225d7a2 编写于 作者: Y Yibing Liu

Use real test data for fit_a_line model to get valid result

上级 74dce390
...@@ -110,14 +110,23 @@ def infer(use_cuda, save_dirname=None): ...@@ -110,14 +110,23 @@ def infer(use_cuda, save_dirname=None):
# The input's dimension should be 2-D and the second dim is 13 # The input's dimension should be 2-D and the second dim is 13
# The input data should be >= 0 # The input data should be >= 0
batch_size = 10 batch_size = 10
tensor_x = numpy.random.uniform(0, 10,
[batch_size, 13]).astype("float32") test_reader = paddle.batch(
paddle.dataset.uci_housing.test(), batch_size=batch_size)
test_data = test_reader().next()
test_feat = numpy.array(
[data[0] for data in test_data]).astype("float32")
test_label = numpy.array(
[data[1] for data in test_data]).astype("float32")
assert feed_target_names[0] == 'x' assert feed_target_names[0] == 'x'
results = exe.run(inference_program, results = exe.run(inference_program,
feed={feed_target_names[0]: tensor_x}, feed={feed_target_names[0]: numpy.array(test_feat)},
fetch_list=fetch_targets) fetch_list=fetch_targets)
print("infer shape: ", results[0].shape) print("infer shape: ", results[0].shape)
print("infer results: ", results[0]) print("infer results: ", results[0])
print("ground truth: ", test_label)
def main(use_cuda, is_local=True): def main(use_cuda, is_local=True):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册