提交 0e9b75f9 编写于 作者: M MRXLT 提交者: GitHub

Merge pull request #478 from wangxicoding/fix_result_dtype

fix result numpy dtype
......@@ -273,7 +273,7 @@ class Client(object):
if self.fetch_names_to_type_[name] == int_type:
result_map[name] = result_batch.get_int64_by_name(mi, name)
shape = result_batch.get_shape(mi, name)
result_map[name] = np.array(result_map[name])
result_map[name] = np.array(result_map[name], dtype='int64')
result_map[name].shape = shape
if name in self.lod_tensor_set:
result_map["{}.lod".format(
......@@ -281,7 +281,8 @@ class Client(object):
elif self.fetch_names_to_type_[name] == float_type:
result_map[name] = result_batch.get_float_by_name(mi, name)
shape = result_batch.get_shape(mi, name)
result_map[name] = np.array(result_map[name])
result_map[name] = np.array(
result_map[name], dtype='float32')
result_map[name].shape = shape
if name in self.lod_tensor_set:
result_map["{}.lod".format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册