From a3e96ed41129dfe5d05c4e7a9638b625821e670e Mon Sep 17 00:00:00 2001 From: chenxuyi Date: Mon, 4 Nov 2019 17:44:42 +0800 Subject: [PATCH] 1.6compat: PaddlTensor --- ernie/infer_classifyer.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/ernie/infer_classifyer.py b/ernie/infer_classifyer.py index ae6114d..20ac683 100644 --- a/ernie/infer_classifyer.py +++ b/ernie/infer_classifyer.py @@ -173,17 +173,13 @@ def main(args): def array2tensor(ndarray): """ convert numpy array to PaddleTensor""" assert isinstance(ndarray, np.ndarray), "input type must be np.ndarray" - tensor = PaddleTensor() - tensor.name = "data" - tensor.shape = ndarray.shape - if "float" in str(ndarray.dtype): - tensor.dtype = PaddleDType.FLOAT32 - elif "int" in str(ndarray.dtype): - tensor.dtype = PaddleDType.INT64 + if ndarray.dtype == np.float32: + dtype = PaddleDType.FLOAT32 + elif ndarray.dtype == np.int64: + dtype = PaddleDType.INT64 else: - raise ValueError("{} type ndarray is unsupported".format(tensor.dtype)) - - tensor.data = PaddleBuf(ndarray.flatten().tolist()) + raise ValueError("{} type ndarray is unsupported".format(ndarray.dtype)) + tensor = PaddleTensor(data=ndarray, name="data") return tensor if __name__ == '__main__': -- GitLab