From 2e59769612f6c9680cdfe89db6c56adc18693739 Mon Sep 17 00:00:00 2001 From: furnace <34057289+windstamp@users.noreply.github.com> Date: Fri, 11 Sep 2020 09:02:04 +0800 Subject: [PATCH] add empty op (c++, python, unit test) (#26659) --- paddle/fluid/operators/empty_op.cc | 132 +++++++++ paddle/fluid/operators/empty_op.cu.cc | 26 ++ paddle/fluid/operators/empty_op.h | 45 +++ paddle/fluid/operators/fill_constant_op.h | 24 +- paddle/fluid/operators/gaussian_random_op.cc | 3 +- paddle/fluid/operators/gaussian_random_op.cu | 3 +- .../mkldnn/gaussian_random_mkldnn_op.cc | 3 +- paddle/fluid/operators/utils.h | 21 ++ python/paddle/__init__.py | 1 + .../fluid/tests/unittests/test_empty_op.py | 270 ++++++++++++++++++ python/paddle/tensor/__init__.py | 1 + python/paddle/tensor/creation.py | 88 ++++++ 12 files changed, 588 insertions(+), 29 deletions(-) create mode 100644 paddle/fluid/operators/empty_op.cc create mode 100644 paddle/fluid/operators/empty_op.cu.cc create mode 100644 paddle/fluid/operators/empty_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_empty_op.py diff --git a/paddle/fluid/operators/empty_op.cc b/paddle/fluid/operators/empty_op.cc new file mode 100644 index 00000000000..f539e2e6f6d --- /dev/null +++ b/paddle/fluid/operators/empty_op.cc @@ -0,0 +1,132 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/empty_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class EmptyOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("ShapeTensor", + "(Tensor), optional). The shape of the output." + "It has a higher priority than Attr(shape).") + .AsDispensable(); + AddInput("ShapeTensorList", + "(vector>, optional). The shape of the output. " + "It has a higher priority than Attr(shape)." + "The shape of the element in vector must be [1].") + .AsDuplicable() + .AsDispensable(); + AddAttr>("shape", + "(vector) The shape of the output") + .SetDefault({}); + AddAttr("dtype", "The data type of output tensor, Default is float") + .SetDefault(framework::proto::VarType::FP32); + AddOutput("Out", "(Tensor) The output tensor."); + AddComment(R"DOC(empty operator +Returns a tensor filled with uninitialized data. The shape of the tensor is +defined by the variable argument shape. + + +The type of the tensor is specify by `dtype`. +)DOC"); + } +}; + +class EmptyOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* context) const override { + OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "empty"); + + if (context->HasInput("ShapeTensor")) { + auto dims = context->GetInputDim("ShapeTensor"); + int num_ele = 1; + for (int i = 0; i < dims.size(); ++i) { + num_ele *= dims[i]; + } + + context->SetOutputDim("Out", framework::make_ddim({num_ele})); + } else if (context->HasInputs("ShapeTensorList")) { + std::vector out_dims; + auto dims_list = context->GetInputsDim("ShapeTensorList"); + for (size_t i = 0; i < dims_list.size(); ++i) { + auto& dims = dims_list[i]; + PADDLE_ENFORCE_EQ( + dims, framework::make_ddim({1}), + "ShapeError: The shape of Tensor in list must be [1]. " + "But received the shape " + "is [%s]", + dims); + + out_dims.push_back(dims[0]); + } + + context->SetOutputDim("Out", framework::make_ddim(out_dims)); + } else { + auto& shape = context->Attrs().Get>("shape"); + context->SetOutputDim("Out", framework::make_ddim(shape)); + } + } + + protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const override { + if (var_name == "ShapeTensor" || var_name == "ShapeTensorList") { + return expected_kernel_type; + } else { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), tensor.layout()); + } + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& context) const override { + return framework::OpKernelType( + framework::proto::VarType::Type(context.Attr("dtype")), + context.GetPlace()); + } +}; + +class EmptyOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext* context) const override { + auto data_type = static_cast( + BOOST_GET_CONST(int, context->GetAttr("dtype"))); + context->SetOutputDataType("Out", data_type); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR( + empty, ops::EmptyOp, ops::EmptyOpMaker, ops::EmptyOpVarTypeInference, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL(empty, ops::EmptyKernel, + ops::EmptyKernel, + ops::EmptyKernel, + ops::EmptyKernel, + ops::EmptyKernel, + ops::EmptyKernel); diff --git a/paddle/fluid/operators/empty_op.cu.cc b/paddle/fluid/operators/empty_op.cu.cc new file mode 100644 index 00000000000..22799e507ae --- /dev/null +++ b/paddle/fluid/operators/empty_op.cu.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/empty_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + empty, ops::EmptyKernel, + ops::EmptyKernel, + ops::EmptyKernel, + ops::EmptyKernel, + ops::EmptyKernel, + ops::EmptyKernel); diff --git a/paddle/fluid/operators/empty_op.h b/paddle/fluid/operators/empty_op.h new file mode 100644 index 00000000000..9c913776838 --- /dev/null +++ b/paddle/fluid/operators/empty_op.h @@ -0,0 +1,45 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/utils.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class EmptyKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto dtype = static_cast( + context.Attr("dtype")); + + Tensor *out_tensor = context.Output("Out"); + + auto shape = GetShape(context); + out_tensor->Resize(shape); + + out_tensor->mutable_data(context.GetPlace(), dtype); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fill_constant_op.h b/paddle/fluid/operators/fill_constant_op.h index 74939da08b3..6fea8fe98bf 100644 --- a/paddle/fluid/operators/fill_constant_op.h +++ b/paddle/fluid/operators/fill_constant_op.h @@ -27,27 +27,6 @@ namespace operators { using Tensor = framework::Tensor; -inline framework::DDim GetShape(const framework::ExecutionContext &ctx, - std::string op_type) { - // 1. shape is a Tensor - if (ctx.HasInput("ShapeTensor")) { - auto *shape_tensor = ctx.Input("ShapeTensor"); - auto vec_shape = GetDataFromTensor(shape_tensor); - return framework::make_ddim(vec_shape); - } - - // 2. shape is a list/tuple containing Tensor - auto shape_tensor_list = ctx.MultiInput("ShapeTensorList"); - if (shape_tensor_list.size() > 0) { - auto vec_shape = GetDataFromTensorList(shape_tensor_list); - return framework::make_ddim(vec_shape); - } - - // 3. shape is a list/tuple without containing Tensor - auto vec_shape = ctx.Attr>("shape"); - return framework::make_ddim(vec_shape); -} - template class FillConstantKernel : public framework::OpKernel { public: @@ -93,8 +72,7 @@ class FillConstantKernel : public framework::OpKernel { } value = tensor_data[0]; } - const std::string op_type = "fill_constant"; - auto shape = GetShape(ctx, op_type); + auto shape = GetShape(ctx); if (out_var->IsType()) { tensor = out_var->GetMutable(); diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc index 4f128463375..17a71c67b8a 100644 --- a/paddle/fluid/operators/gaussian_random_op.cc +++ b/paddle/fluid/operators/gaussian_random_op.cc @@ -34,8 +34,7 @@ class CPUGaussianRandomKernel : public framework::OpKernel { auto* tensor = context.Output("Out"); std::normal_distribution dist(mean, std); - const std::string op_type = "gaussian_random"; - auto shape = GetShape(context, op_type); + auto shape = GetShape(context); tensor->Resize(shape); int64_t size = tensor->numel(); T* data = tensor->mutable_data(context.GetPlace()); diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu index 69c8b600406..7a0c93eb1b2 100644 --- a/paddle/fluid/operators/gaussian_random_op.cu +++ b/paddle/fluid/operators/gaussian_random_op.cu @@ -58,8 +58,7 @@ class GPUGaussianRandomKernel : public framework::OpKernel { T mean = static_cast(context.Attr("mean")); T std = static_cast(context.Attr("std")); thrust::counting_iterator index_sequence_begin(0); - const std::string op_type = "gaussian_random"; - auto shape = GetShape(context, op_type); + auto shape = GetShape(context); tensor->Resize(shape); T* data = tensor->mutable_data(context.GetPlace()); diff --git a/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc index 98200caca8c..51fa5ad021a 100644 --- a/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/gaussian_random_mkldnn_op.cc @@ -30,8 +30,7 @@ class GaussianMKLDNNKernel : public paddle::framework::OpKernel { float std = context.Attr("std"); auto* tensor = context.Output("Out"); - const std::string op_type = "gaussian_random"; - auto shape = GetShape(context, op_type); + auto shape = GetShape(context); tensor->Resize(shape); T* data = tensor->mutable_data(context.GetPlace()); int64_t size = tensor->numel(); diff --git a/paddle/fluid/operators/utils.h b/paddle/fluid/operators/utils.h index e53981a5365..aec995304a7 100644 --- a/paddle/fluid/operators/utils.h +++ b/paddle/fluid/operators/utils.h @@ -81,5 +81,26 @@ inline std::vector GetDataFromTensorList( } return vec_new_data; } + +inline framework::DDim GetShape(const framework::ExecutionContext& ctx) { + // 1. shape is a Tensor + if (ctx.HasInput("ShapeTensor")) { + auto* shape_tensor = ctx.Input("ShapeTensor"); + auto vec_shape = GetDataFromTensor(shape_tensor); + return framework::make_ddim(vec_shape); + } + + // 2. shape is a list/tuple containing Tensor + auto shape_tensor_list = ctx.MultiInput("ShapeTensorList"); + if (shape_tensor_list.size() > 0) { + auto vec_shape = GetDataFromTensorList(shape_tensor_list); + return framework::make_ddim(vec_shape); + } + + // 3. shape is a list/tuple without containing Tensor + auto vec_shape = ctx.Attr>("shape"); + return framework::make_ddim(vec_shape); +} + } // namespace operators } // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index d5793eb424a..ed0b415d0bf 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -75,6 +75,7 @@ from .tensor.creation import full_like #DEFINE_ALIAS from .tensor.creation import triu #DEFINE_ALIAS from .tensor.creation import tril #DEFINE_ALIAS from .tensor.creation import meshgrid #DEFINE_ALIAS +from .tensor.creation import empty #DEFINE_ALIAS from .tensor.linalg import matmul #DEFINE_ALIAS from .tensor.linalg import dot #DEFINE_ALIAS # from .tensor.linalg import einsum #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_empty_op.py b/python/paddle/fluid/tests/unittests/test_empty_op.py new file mode 100644 index 00000000000..e8b1f836fca --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_empty_op.py @@ -0,0 +1,270 @@ +#Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +from op_test import OpTest +from paddle.fluid import Program, program_guard +from paddle.fluid.framework import convert_np_dtype_to_dtype_ + + +# Situation 1: Attr(shape) is a list(without tensor) +class TestEmptyOp(OpTest): + def setUp(self): + self.op_type = "empty" + self.init_config() + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + data_type = outs[0].dtype + if data_type in ['float32', 'float64', 'int32', 'int64']: + max_value = np.nanmax(outs[0]) + min_value = np.nanmin(outs[0]) + + always_full_zero = max_value == 0.0 and min_value == 0.0 + always_non_full_zero = max_value > min_value + self.assertTrue(always_full_zero or always_non_full_zero, + 'always_full_zero or always_non_full_zero.') + elif data_type in ['bool']: + total_num = outs[0].size + true_num = np.sum(outs[0] == True) + false_num = np.sum(outs[0] == False) + self.assertTrue(total_num == true_num + false_num, + 'The value should always be True or False.') + else: + self.assertTrue(False, 'invalid data type') + + def init_config(self): + shape = [500, 3] + dtype = 'float32' + dtype_inner = convert_np_dtype_to_dtype_(dtype) + self.attrs = {'shape': shape, 'dtype': dtype_inner} + self.inputs = {} + self.outputs = {'Out': np.zeros(shape).astype(dtype)} + + +class TestEmptyOp2(TestEmptyOp): + def init_config(self): + shape = [500, 3] + dtype = 'float64' + dtype_inner = convert_np_dtype_to_dtype_(dtype) + self.attrs = {'shape': shape, 'dtype': dtype_inner} + self.inputs = {} + self.outputs = {'Out': np.zeros(shape).astype(dtype)} + + +class TestEmptyOp3(TestEmptyOp): + def init_config(self): + shape = [500, 3] + dtype = 'int32' + dtype_inner = convert_np_dtype_to_dtype_(dtype) + self.attrs = {'shape': shape, 'dtype': dtype_inner} + self.inputs = {} + self.outputs = {'Out': np.zeros(shape).astype(dtype)} + + +class TestEmptyOp4(TestEmptyOp): + def init_config(self): + shape = [500, 3] + dtype = 'int64' + dtype_inner = convert_np_dtype_to_dtype_(dtype) + self.attrs = {'shape': shape, 'dtype': dtype_inner} + self.inputs = {} + self.outputs = {'Out': np.zeros(shape).astype(dtype)} + + +class TestEmptyOp5(TestEmptyOp): + def init_config(self): + shape = [500, 3] + dtype = 'bool' + dtype_inner = convert_np_dtype_to_dtype_(dtype) + self.attrs = {'shape': shape, 'dtype': dtype_inner} + self.inputs = {} + self.outputs = {'Out': np.zeros(shape).astype(dtype)} + + +# Situation 2: shape is a tensor +class TestEmptyOp_ShapeTensor(OpTest): + def setUp(self): + self.op_type = "empty" + self.init_config() + + def init_config(self): + self.shape = [500, 3] + dtype = 'float32' + dtype_inner = convert_np_dtype_to_dtype_(dtype) + self.attrs = {'dtype': dtype_inner} + self.inputs = {"ShapeTensor": np.array(self.shape).astype("int32")} + self.outputs = {'Out': np.zeros(self.shape).astype(dtype)} + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + data_type = outs[0].dtype + if data_type in ['float32', 'float64', 'int32', 'int64']: + max_value = np.nanmax(outs[0]) + min_value = np.nanmin(outs[0]) + + always_full_zero = max_value == 0.0 and min_value == 0.0 + always_non_full_zero = max_value > min_value + self.assertTrue(always_full_zero or always_non_full_zero, + 'always_full_zero or always_non_full_zero.') + elif data_type in ['bool']: + total_num = outs[0].size + true_num = np.sum(outs[0] == True) + false_num = np.sum(outs[0] == False) + self.assertTrue(total_num == true_num + false_num, + 'The value should always be True or False.') + else: + self.assertTrue(False, 'invalid data type') + + +# Situation 3: Attr(shape) is a list(with tensor) +class TestEmptyOp_ShapeTensorList(OpTest): + def setUp(self): + self.op_type = "empty" + self.init_config() + + def init_config(self): + self.shape = [123, 92] + self.infer_shape = [-1, 92] + + dtype = 'float32' + dtype_inner = convert_np_dtype_to_dtype_(dtype) + + shape_tensor_list = [] + for index, ele in enumerate(self.shape): + shape_tensor_list.append(("x" + str(index), np.ones( + (1)).astype('int32') * ele)) + + self.inputs = {"ShapeTensorList": shape_tensor_list} + self.attrs = {'shape': self.infer_shape, 'dtype': dtype_inner} + self.outputs = {'Out': np.zeros(self.shape).astype(dtype)} + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + data_type = outs[0].dtype + if data_type in ['float32', 'float64', 'int32', 'int64']: + max_value = np.nanmax(outs[0]) + min_value = np.nanmin(outs[0]) + + always_full_zero = max_value == 0.0 and min_value == 0.0 + always_non_full_zero = max_value > min_value + self.assertTrue(always_full_zero or always_non_full_zero, + 'always_full_zero or always_non_full_zero.') + elif data_type in ['bool']: + total_num = outs[0].size + true_num = np.sum(outs[0] == True) + false_num = np.sum(outs[0] == False) + self.assertTrue(total_num == true_num + false_num, + 'The value should always be True or False.') + else: + self.assertTrue(False, 'invalid data type') + + +class TestEmptyAPI(unittest.TestCase): + def __check_out__(self, out, dtype='float32'): + max_value = np.nanmax(np.array(out)) + min_value = np.nanmin(np.array(out)) + always_non_full_zero = max_value > min_value + always_full_zero = max_value == 0.0 and min_value == 0.0 + self.assertTrue(always_full_zero or always_non_full_zero, + 'always_full_zero or always_non_full_zero.') + + def test_dygraph_api_out(self): + paddle.disable_static() + shape = [200, 3] + out = paddle.empty(shape=shape) + self.__check_out__(out) + paddle.enable_static() + + def test_dygraph_api_out_2(self): + paddle.disable_static() + shape_data = np.array([200, 3]).astype('int32') + shape = paddle.to_tensor(shape_data) + out = paddle.empty(shape=shape) + self.__check_out__(out) + paddle.enable_static() + + def test_dygraph_api_out_3(self): + paddle.disable_static() + shape_data = np.array([200, 3]).astype('int64') + shape = paddle.to_tensor(shape_data) + out = paddle.empty(shape=shape) + self.__check_out__(out) + paddle.enable_static() + + def test_dygraph_api_attr(self): + paddle.disable_static() + shape = [200, 3] + dtype = 'float64' + out = paddle.empty(shape=shape, dtype=dtype) + self.__check_out__(out, dtype) + paddle.enable_static() + + def test_static_graph(self): + dtype = 'float64' + + positive_2_int32 = fluid.layers.fill_constant([1], "int32", 3) + positive_2_int64 = fluid.layers.fill_constant([1], "int64", 3) + + shape_tensor_int32 = fluid.data( + name="shape_tensor_int32", shape=[2], dtype="int32") + shape_tensor_int64 = fluid.data( + name="shape_tensor_int64", shape=[2], dtype="int64") + + out_1 = paddle.empty(shape=[200, 3], dtype=dtype) + out_2 = paddle.empty(shape=shape_tensor_int32, dtype=dtype) + out_3 = paddle.empty(shape=shape_tensor_int64, dtype=dtype) + out_4 = paddle.empty(shape=[200, positive_2_int32], dtype=dtype) + out_5 = paddle.empty(shape=[200, positive_2_int64], dtype=dtype) + + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + res_1, res_2, res_3, res_4, res_5 = exe.run( + fluid.default_main_program(), + feed={ + "shape_tensor_int32": np.array([200, 3]).astype("int32"), + "shape_tensor_int64": np.array([200, 3]).astype("int64"), + }, + fetch_list=[out_1, out_2, out_3, out_4, out_5]) + + self.__check_out__(res_1, dtype) + self.__check_out__(res_2, dtype) + self.__check_out__(res_3, dtype) + self.__check_out__(res_4, dtype) + self.__check_out__(res_5, dtype) + + +class TestEmptyError(unittest.TestCase): + def test_attr(self): + def test_dtype(): + shape = [200, 3] + dtype = 'uint8' + result = paddle.empty(shape=shape, dtype=dtype) + + self.assertRaises(TypeError, test_dtype) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 0fed32a1676..8bb584be236 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -40,6 +40,7 @@ from .creation import full_like #DEFINE_ALIAS from .creation import triu #DEFINE_ALIAS from .creation import tril #DEFINE_ALIAS from .creation import meshgrid #DEFINE_ALIAS +from .creation import empty #DEFINE_ALIAS from .io import save #DEFINE_ALIAS from .io import load #DEFINE_ALIAS from .linalg import matmul #DEFINE_ALIAS diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 9eece1240d7..b75e2a8851f 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -48,6 +48,7 @@ __all__ = [ 'eye', 'full', 'full_like', + 'empty', 'triu', 'tril', 'meshgrid' @@ -981,3 +982,90 @@ def diag(x, offset=0, padding_value=0, name=None): out.stop_gradient = True return out + + +def empty(shape, dtype=None, name=None): + """ + This Op returns a Tensor with uninitialized data which size is same as ``shape``. + + Args: + shape(list|tuple|Tensor): Shape of the Tensor to be created. + The data type of dimension of shape is ``int32`` or ``int64`` . If ``shape`` is a list or tuple, + the elements of it should be integers or Tensors with shape [1]. + If ``shape`` is an Tensor, it should be an 1-D Tensor. + dtype(np.dtype|str, optional): Data type of the output Tensor + which can be bool, float16, float32, float64, int32, int64, if dytpe is `None`, the data + type of created Tensor use global default dtype (see ``get_default_dtype`` + for details). + name(str, optional): The default value is None. Normally there is no need for user to set this + property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: Tensor which is created according to ``shape`` and ``dtype``, and is uninitialized. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() # Now we are in imperative mode + paddle.set_device("cpu") # and use cpu device + + # example 1: argument ``shape`` is a list which doesn't contain Tensor. + data1 = paddle.empty(shape=[2,3], dtype='float32') + #[[4.3612203e+27 1.8176809e+31 1.3555911e-19] # uninitialized + # [1.1699684e-19 1.3563156e-19 3.6408321e-11]] # uninitialized + + # example 2: argument ``shape`` is a Tensor, the data type must be int64 or int32. + shape_data = np.array([2, 3]).astype('int32') + shape = paddle.to_tensor(shape_data) + data2 = paddle.empty(shape=shape, dtype='float32') + #[[1.7192326e-37 4.8125365e-38 1.9866003e-36] # uninitialized + # [1.3284029e-40 7.1117408e-37 2.5353012e+30]] # uninitialized + + # example 3: argument ``shape`` is a list which contains Tensor. + dim2_data = np.array([3]).astype('int32') + dim2 = paddle.to_tensor(dim2_data) + data3 = paddle.empty(shape=[2, dim2], dtype='float32') + #[[1.1024214e+24 7.0379409e+22 6.5737699e-34] # uninitialized + # [7.5563101e+31 7.7130405e+31 2.8020654e+20]] # uninitialized + """ + + if dtype is None: + dtype = paddle.get_default_dtype() + + dtype = convert_dtype(dtype) + + if in_dygraph_mode(): + shape = utils.convert_shape_to_list(shape) + out = core.ops.empty('shape', shape, 'dtype', + convert_np_dtype_to_dtype_(dtype)) + out.stop_gradient = True + return out + + helper = LayerHelper("empty", **locals()) + inputs = {} + + check_dtype(dtype, 'dtype', + ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'], + 'empty') + check_type(shape, 'shape', (Variable, list, tuple), 'empty') + + if isinstance(shape, Variable): + check_dtype(shape.dtype, 'shape', ['int32', 'int64'], 'empty') + + attrs = {} + utils.get_shape_tensor_inputs( + inputs=inputs, attrs=attrs, shape=shape, op_type='empty') + + out = helper.create_variable_for_type_inference(dtype=dtype) + attrs['dtype'] = convert_np_dtype_to_dtype_(dtype) + helper.append_op( + type='empty', + inputs=inputs, + outputs={'Out': [out]}, + attrs=attrs, + stop_gradient=True) + out.stop_gradient = True + return out -- GitLab