diff --git a/ernie/infer_classifyer.py b/ernie/infer_classifyer.py index ae6114d3178afa9fcc1879d5f8e76ef370c43a88..20ac683f0bc95b7a44f7a22e4ccbb143808383fd 100644 --- a/ernie/infer_classifyer.py +++ b/ernie/infer_classifyer.py @@ -173,17 +173,13 @@ def main(args): def array2tensor(ndarray): """ convert numpy array to PaddleTensor""" assert isinstance(ndarray, np.ndarray), "input type must be np.ndarray" - tensor = PaddleTensor() - tensor.name = "data" - tensor.shape = ndarray.shape - if "float" in str(ndarray.dtype): - tensor.dtype = PaddleDType.FLOAT32 - elif "int" in str(ndarray.dtype): - tensor.dtype = PaddleDType.INT64 + if ndarray.dtype == np.float32: + dtype = PaddleDType.FLOAT32 + elif ndarray.dtype == np.int64: + dtype = PaddleDType.INT64 else: - raise ValueError("{} type ndarray is unsupported".format(tensor.dtype)) - - tensor.data = PaddleBuf(ndarray.flatten().tolist()) + raise ValueError("{} type ndarray is unsupported".format(ndarray.dtype)) + tensor = PaddleTensor(data=ndarray, name="data") return tensor if __name__ == '__main__':