提交 a3e96ed4 编写于 作者: C chenxuyi 提交者: Meiyim

1.6compat: PaddlTensor

上级 4d0c99f7
...@@ -173,17 +173,13 @@ def main(args): ...@@ -173,17 +173,13 @@ def main(args):
def array2tensor(ndarray): def array2tensor(ndarray):
""" convert numpy array to PaddleTensor""" """ convert numpy array to PaddleTensor"""
assert isinstance(ndarray, np.ndarray), "input type must be np.ndarray" assert isinstance(ndarray, np.ndarray), "input type must be np.ndarray"
tensor = PaddleTensor() if ndarray.dtype == np.float32:
tensor.name = "data" dtype = PaddleDType.FLOAT32
tensor.shape = ndarray.shape elif ndarray.dtype == np.int64:
if "float" in str(ndarray.dtype): dtype = PaddleDType.INT64
tensor.dtype = PaddleDType.FLOAT32
elif "int" in str(ndarray.dtype):
tensor.dtype = PaddleDType.INT64
else: else:
raise ValueError("{} type ndarray is unsupported".format(tensor.dtype)) raise ValueError("{} type ndarray is unsupported".format(ndarray.dtype))
tensor = PaddleTensor(data=ndarray, name="data")
tensor.data = PaddleBuf(ndarray.flatten().tolist())
return tensor return tensor
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册