From 915341e3de4a98dd8296f0648c43727773324f1d Mon Sep 17 00:00:00 2001
From: wawltor
Date: Sat, 4 Apr 2020 12:29:06 +0800
Subject: [PATCH] Add the zeros, ones, ones_like, zeros_like for api 2.0,
 test=develop (#23471)

Update the new api ops of creation ops to the api 2.0
---
 paddle/fluid/operators/fill_any_like_op.cc    |  47 +++-
 paddle/fluid/operators/fill_any_like_op.cu    |   1 +
 python/paddle/__init__.py                     |  10 +-
 python/paddle/common_ops_import.py            |  13 +-
 .../tests/unittests/test_fill_any_like_op.py  | 116 +++++++++
 .../tests/unittests/test_fill_constant_op.py  |  92 ++++++++
 python/paddle/tensor/creation.py              | 220 +++++++++++++++++-
 7 files changed, 482 insertions(+), 17 deletions(-)

diff --git a/paddle/fluid/operators/fill_any_like_op.cc b/paddle/fluid/operators/fill_any_like_op.cc
index 43a71c91726..613caca374f 100644
--- a/paddle/fluid/operators/fill_any_like_op.cc
+++ b/paddle/fluid/operators/fill_any_like_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/fill_any_like_op.h"
+#include <string>
 
 namespace paddle {
 namespace operators {
@@ -29,6 +30,25 @@ class FillAnyLikeOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
+    const auto &data_type = ctx.Attr<int>("dtype");
+    if (data_type >= 0) {
+      kt.data_type_ = static_cast<framework::proto::VarType::Type>(data_type);
+    }
+    return kt;
+  }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string &var_name, const framework::Tensor &tensor,
+      const framework::OpKernelType &expected_kernel_type) const override {
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   expected_kernel_type.place_,
+                                   tensor.layout());
+  }
 };
 
 class FillAnyLikeOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -37,6 +57,10 @@ class FillAnyLikeOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The input of fill-zeros-like op.");
     AddOutput("Out", "The variable will be filled up with specified value.");
     AddAttr<float>("value", "The filled value").SetDefault(0.0);
+    AddAttr<int>("dtype",
+                 "Output tensor data type. Default value is -1, "
+                 "in which case the output dtype follows the input dtype.")
+        .SetDefault(-1);
     AddComment(R"DOC(
 FillAnyLike Operator.
 
@@ -47,18 +71,37 @@ The output will have the same shape and dtype as the input.
   }
 };
 
+class FillAnyLikeVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    auto out_var_name = ctx->Output("Out").front();
+    auto var_data_type = static_cast<framework::proto::VarType::Type>(
+        boost::get<int>(ctx->GetAttr("dtype")));
+    if (var_data_type < 0) {
+      const auto &input_var_name = ctx->Input("X").front();
+      ctx->SetDataType(out_var_name, ctx->GetDataType(input_var_name));
+    } else {
+      ctx->SetDataType(out_var_name, var_data_type);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(fill_any_like, ops::FillAnyLikeOp,
-                             ops::FillAnyLikeOpMaker);
+REGISTER_OPERATOR(
+    fill_any_like, ops::FillAnyLikeOp, ops::FillAnyLikeOpMaker,
+    ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
+    ops::FillAnyLikeVarTypeInference)
 
 REGISTER_OP_CPU_KERNEL(
     fill_any_like,
     ops::FillAnyLikeKernel,
     ops::FillAnyLikeKernel,
     ops::FillAnyLikeKernel,
+    ops::FillAnyLikeKernel,
     ops::FillAnyLikeKernel,
     ops::FillAnyLikeKernel);
diff --git a/paddle/fluid/operators/fill_any_like_op.cu b/paddle/fluid/operators/fill_any_like_op.cu
index 26b215d1e7f..1d8c8ace60a 100644
--- a/paddle/fluid/operators/fill_any_like_op.cu
+++ b/paddle/fluid/operators/fill_any_like_op.cu
@@ -22,6 +22,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::FillAnyLikeKernel,
     ops::FillAnyLikeKernel,
     ops::FillAnyLikeKernel,
+    ops::FillAnyLikeKernel,
     ops::FillAnyLikeKernel,
     ops::FillAnyLikeKernel);
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 24db112564b..d2e8ae581a0 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -42,14 +42,14 @@ import paddle.nn
 # from .tensor.creation import crop_.tensor  #DEFINE_ALIAS
 # from .tensor.creation import diag  #DEFINE_ALIAS
 # from .tensor.creation import eye  #DEFINE_ALIAS
-# from .tensor.creation import fill_constant  #DEFINE_ALIAS
+from .tensor.creation import fill_constant  #DEFINE_ALIAS
 # from .tensor.creation import get_.tensor_from_selected_rows  #DEFINE_ALIAS
 from .tensor.creation import linspace  #DEFINE_ALIAS
-# from .tensor.creation import ones  #DEFINE_ALIAS
-# from .tensor.creation import ones_like  #DEFINE_ALIAS
+from .tensor.creation import ones  #DEFINE_ALIAS
+from .tensor.creation import ones_like  #DEFINE_ALIAS
 # from .tensor.creation import range  #DEFINE_ALIAS
-# from .tensor.creation import zeros  #DEFINE_ALIAS
-# from .tensor.creation import zeros_like  #DEFINE_ALIAS
+from .tensor.creation import zeros  #DEFINE_ALIAS
+from .tensor.creation import zeros_like  #DEFINE_ALIAS
 # from .tensor.creation import arrange  #DEFINE_ALIAS
 # from .tensor.creation import eye  #DEFINE_ALIAS
 from .tensor.creation import full  #DEFINE_ALIAS
diff --git a/python/paddle/common_ops_import.py b/python/paddle/common_ops_import.py
index 477ff2fe4e0..c40883db679 100644
--- a/python/paddle/common_ops_import.py
+++ b/python/paddle/common_ops_import.py
@@ -11,17 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from six.moves import reduce
 from paddle.fluid.layer_helper import LayerHelper
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator
-from paddle.fluid.framework import Variable, device_guard
+from paddle.fluid.framework import device_guard, default_main_program, dygraph_only, _dygraph_tracer
+from paddle.fluid.framework import OpProtoHolder, Variable
 from paddle.fluid.initializer import Constant
 from paddle.fluid.core import VarDesc
-from paddle.fluid import core
-from paddle.fluid.data_feeder import check_type, check_dtype, convert_dtype
-from paddle.fluid.layers import utils
-from paddle.fluid.layers import fill_constant
+from paddle.fluid import core, dygraph_utils
+from paddle.fluid.data_feeder import check_type, check_dtype, check_variable_and_dtype, convert_dtype
+from paddle.fluid.layers import fill_constant, utils, scale
+from paddle.fluid.layers.layer_function_generator import templatedoc
+import paddle.fluid as fluid
 import numpy
 import warnings
diff --git a/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py b/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py
index 044ee1a7ead..68eb5d47938 100644
--- a/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py
@@ -14,6 +14,8 @@
 
 from __future__ import print_function
 
+import paddle
+import paddle.fluid as fluid
 import paddle.fluid.core as core
 import paddle.compat as cpt
 import unittest
@@ -59,6 +61,23 @@ class TestFillAnyLikeOpValue3(TestFillAnyLikeOp):
         self.value = 1e-100
 
 
+class TestFillAnyLikeOpType(TestFillAnyLikeOp):
+    def setUp(self):
+        self.op_type = "fill_any_like"
+        self.dtype = np.int32
+        self.value = 0.0
+        self.init()
+        self.inputs = {'X': np.random.random((219, 232)).astype(self.dtype)}
+        self.attrs = {
+            'value': self.value,
+            'dtype': int(core.VarDesc.VarType.FP32)
+        }
+        self.outputs = {
+            'Out':
+            self.value * np.ones_like(self.inputs["X"]).astype(np.float32)
+        }
+
+
 class TestFillAnyLikeOpOverflow(TestFillAnyLikeOp):
     def init(self):
         self.value = 1e100
@@ -77,5 +96,102 @@ class TestFillAnyLikeOpFloat16(TestFillAnyLikeOp):
         self.dtype = np.float16
 
 
+class ApiOnesLikeTest(unittest.TestCase):
+    def test_out(self):
+        with fluid.program_guard(fluid.Program()):
+            data = fluid.data(shape=[10], dtype="float64", name="data")
+            ones = paddle.ones_like(data, device="cpu")
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            result, = exe.run(feed={"data": np.random.rand(10)},
+                              fetch_list=[ones])
+            expected_result = np.ones(10, dtype="float64")
+            self.assertEqual((result == expected_result).all(), True)
+
+        with fluid.program_guard(fluid.Program()):
+            data = fluid.data(shape=[10], dtype="float64", name="data")
+            ones = paddle.ones_like(data, device="cpu", dtype="float32")
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            result, = exe.run(feed={"data": np.random.rand(10)},
+                              fetch_list=[ones])
+            expected_result = np.ones(10, dtype="float32")
+            self.assertEqual((result == expected_result).all(), True)
+
+        with fluid.program_guard(fluid.Program()):
+            data = fluid.data(shape=[10], dtype="float64", name="data")
+            ones = paddle.ones_like(data)
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            result, = exe.run(feed={"data": np.random.rand(10)},
+                              fetch_list=[ones])
+            expected_result = np.ones(10, dtype="float32")
+            self.assertEqual((result == expected_result).all(), True)
+
+
+class ApiZerosLikeTest(unittest.TestCase):
+    def test_out(self):
+        with fluid.program_guard(fluid.Program()):
+            data = fluid.data(shape=[10], dtype="float64", name="data")
+            zeros = paddle.zeros_like(data, device="cpu")
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            result, = exe.run(feed={"data": np.random.rand(10)},
+                              fetch_list=[zeros])
+            expected_result = np.zeros(10, dtype="float64")
+            self.assertEqual((result == expected_result).all(), True)
+
+        with fluid.program_guard(fluid.Program()):
+            data = fluid.data(shape=[10], dtype="float64", name="data")
+            zeros = paddle.zeros_like(data, device="cpu", dtype="float32")
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            result, = exe.run(feed={"data": np.random.rand(10)},
+                              fetch_list=[zeros])
+            expected_result = np.zeros(10, dtype="float32")
+            self.assertEqual((result == expected_result).all(), True)
+
+        with fluid.program_guard(fluid.Program()):
+            data = fluid.data(shape=[10], dtype="float64", name="data")
+            zeros = paddle.zeros_like(data)
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            result, = exe.run(feed={"data": np.random.rand(10)},
+                              fetch_list=[zeros])
+            expected_result = np.zeros(10, dtype="float32")
+            self.assertEqual((result == expected_result).all(), True)
+
+
+class TestOnesZerosError(unittest.TestCase):
+    def test_errors(self):
+        def test_device_error1():
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data = fluid.data(name="data", shape=[10], dtype="float32")
+                paddle.ones_like(data, device="opu")
+
+        self.assertRaises(ValueError, test_device_error1)
+
+        def test_device_error2():
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data = fluid.data(name="data", shape=[10], dtype="float32")
+                paddle.ones_like(data, dtype="float")
+
+        self.assertRaises(ValueError, test_device_error2)
+
+        def test_device_error3():
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data = fluid.data(name="data", shape=[10], dtype="float32")
+                paddle.zeros_like(data, device="opu")
+
+        self.assertRaises(ValueError, test_device_error3)
+
+        def test_device_error4():
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data = fluid.data(name="data", shape=[10], dtype="float32")
+                paddle.zeros_like(data, dtype="float")
+
+        self.assertRaises(ValueError, test_device_error4)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
index a9bfa14cc1a..e6a6df6bdac 100644
--- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
@@ -18,6 +18,7 @@ import unittest
 import numpy as np
 from op_test import OpTest
 
+import paddle
 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
 import paddle.fluid as fluid
@@ -81,6 +82,28 @@ class TestFillConstantOp4(OpTest):
         self.check_output()
 
 
+class TestFillConstantOp5(unittest.TestCase):
+    def test_errors(self):
+        with fluid.program_guard(fluid.Program()):
+            data = fluid.data(name="X", shape=[1], dtype="float32")
+            out = paddle.zeros(shape=[1], out=data, dtype="float32")
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            result = exe.run(feed={"X": np.array(
+                [0.1], dtype="float32")},
+                             fetch_list=[data, out])
+            self.assertEqual(result[0], result[1])
+        with fluid.program_guard(fluid.Program()):
+            data = fluid.data(name="X", shape=[1], dtype="float32")
+            out = paddle.ones(shape=[1], out=data, dtype="float32")
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            result = exe.run(feed={"X": np.array(
+                [0.1], dtype="float32")},
+                             fetch_list=[data, out])
+            self.assertEqual(result[0], result[1])
+
+
 class TestFillConstantOpWithSelectedRows(unittest.TestCase):
     def check_with_place(self, place):
         scope = core.Scope()
@@ -303,5 +326,74 @@ class TestFillConstantOpError(unittest.TestCase):
             self.assertRaises(TypeError, test_shape_tensor_list_dtype)
 
 
+class ApiZerosTest(unittest.TestCase):
+    def test_out(self):
+        with fluid.program_guard(fluid.Program()):
+            zeros = paddle.zeros(shape=[10], dtype="float64")
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            result, = exe.run(fetch_list=[zeros])
+            expected_result = np.zeros(10, dtype="float64")
+            self.assertEqual((result == expected_result).all(), True)
+
+        with fluid.program_guard(fluid.Program()):
+            zeros = paddle.zeros(shape=[10], dtype="int64")
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            result, = exe.run(fetch_list=[zeros])
+            expected_result = np.zeros(10, dtype="int64")
+            self.assertEqual((result == expected_result).all(), True)
+
+        with fluid.program_guard(fluid.Program()):
+            zeros = paddle.zeros(shape=[10], dtype="int64", device="cpu")
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            result, = exe.run(fetch_list=[zeros])
+            expected_result = np.zeros(10, dtype="int64")
+            self.assertEqual((result == expected_result).all(), True)
+
+
+class ApiOnesTest(unittest.TestCase):
+    def test_out(self):
+        with fluid.program_guard(fluid.Program()):
+            ones = paddle.ones(shape=[10], dtype="float64")
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            result, = exe.run(fetch_list=[ones])
+            expected_result = np.ones(10, dtype="float64")
+            self.assertEqual((result == expected_result).all(), True)
+
+        with fluid.program_guard(fluid.Program()):
+            ones = paddle.ones(shape=[10], dtype="int64")
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            result, = exe.run(fetch_list=[ones])
+            expected_result = np.ones(10, dtype="int64")
+            self.assertEqual((result == expected_result).all(), True)
+
+        with fluid.program_guard(fluid.Program()):
+            ones = paddle.ones(shape=[10], dtype="int64", device="cpu")
+            place = fluid.CPUPlace()
+            exe = fluid.Executor(place)
+            result, = exe.run(fetch_list=[ones])
+            expected_result = np.ones(10, dtype="int64")
+            self.assertEqual((result == expected_result).all(), True)
+
+
+class ApiOnesZerosError(unittest.TestCase):
+    def test_errors(self):
+        def test_error1():
+            with fluid.program_guard(fluid.Program()):
+                ones = paddle.ones(shape=10, dtype="int64", device="opu")
+
+        self.assertRaises(ValueError, test_error1)
+
+        def test_error2():
+            with fluid.program_guard(fluid.Program()):
+                ones = paddle.ones(shape=10, dtype="int64", device="opu")
+
+        self.assertRaises(ValueError, test_error2)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py
index aeadb034da2..c97de6901f9 100644
--- a/python/paddle/tensor/creation.py
+++ b/python/paddle/tensor/creation.py
@@ -25,11 +25,11 @@ __all__ = [
     #            'fill_constant',
     #            'get_tensor_from_selected_rows',
     'linspace',
-    #            'ones',
-    #            'ones_like',
+    'ones',
+    'ones_like',
     #            'range',
-    #            'zeros',
-    #            'zeros_like',
+    'zeros',
+    'zeros_like',
     #            'arrange',
     #            'eye',
     'full',
@@ -126,6 +126,218 @@ def linspace(start, stop, num, dtype, out=None, device=None, name=None):
     return out
 
 
+def ones(shape, dtype=None, out=None, device=None):
+    """
+    The OP creates a tensor of specified :attr:`shape` and :attr:`dtype`, and fills it with 1.
+
+    Args:
+        shape(tuple|list): Shape of the output tensor.
+        dtype(np.dtype|core.VarDesc.VarType|str): Data type of the output tensor; it supports
+            bool, float16, float32, float64, int32 and int64.
+        out(Variable, optional): Optional output which can be any created
+            Variable that meets the requirements to store the result of the operation.
+            If out is None, a new Variable will be created to store the result.
+        device(str, optional): Which device to run the operator. The :attr:`device` must be
+            None, 'cpu' or 'gpu'. If :attr:`device` is None, the device set in the paddle
+            program will be chosen. Default value is None.
+
+    Returns:
+        Variable: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            data = paddle.ones(shape=[3, 2], dtype='float32')  # [[1., 1.], [1., 1.], [1., 1.]]
+            data = paddle.ones(shape=[2, 2], dtype='float32', device='cpu')  # [[1., 1.], [1., 1.]]
+    """
+    check_dtype(dtype, 'create data type',
+                ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
+                'ones')
+
+    if device is not None:
+        if device not in ['cpu', 'gpu']:
+            raise ValueError(
+                "The value of 'device' in ones_op must be cpu or gpu, but received %s."
+                % (device))
+        with fluid.device_guard(device):
+            return fill_constant(value=1.0, shape=shape, dtype=dtype, out=out)
+    return fill_constant(value=1.0, shape=shape, dtype=dtype, out=out)
+
+
+def ones_like(input, dtype=None, device=None, name=None):
+    """
+    This function creates a tensor of ones with the same shape and dtype
+    as `input`.
+
+    Args:
+        input(Variable): The input tensor which specifies shape and dtype. The dtype of input can be
+            float32, float64, int32, int64.
+        dtype(np.dtype|core.VarDesc.VarType|str, optional): The data type can be set to bool, float16,
+            float32, float64, int32 or int64. The default value is None, in which case the dtype is
+            the same as the input.
+        device(str, optional): Which device to run the operator. The :attr:`device` must be
+            None, 'cpu' or 'gpu'. If :attr:`device` is None, the device set in the paddle
+            program will be chosen. Default value is None.
+        name(str, optional): The name of the output variable; normally there is no need for the user
+            to set this property. Default value is None, in which case the framework sets the name of
+            the output variable.
+
+    Returns:
+        out(Variable): The tensor variable storing the output.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.fluid as fluid
+
+            x = fluid.layers.data(name='x', dtype='float32', shape=[3], append_batch_size=False)
+            data = paddle.ones_like(x)  # data=[1.0, 1.0, 1.0]
+            data1 = paddle.ones_like(input=x, device="gpu")  # data1=[1.0, 1.0, 1.0]
+
+    """
+
+    helper = LayerHelper("ones_like", **locals())
+
+    attrs = {"value": 1.0}
+    var_dtype = None
+    if dtype is not None:
+        check_dtype(
+            dtype, 'create data type',
+            ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
+            'ones_like')
+        var_dtype = convert_np_dtype_to_dtype_(dtype)
+        attrs["dtype"] = var_dtype
+    else:
+        var_dtype = input.dtype
+
+    out = helper.create_variable_for_type_inference(dtype=var_dtype)
+
+    if device is not None:
+        if device not in ['cpu', 'gpu']:
+            raise ValueError(
+                "The value of 'device' in ones_like_op must be cpu or gpu, but received %s."
+                % (device))
+        with fluid.device_guard(device):
+            helper.append_op(
+                type='fill_any_like',
+                inputs={'X': [input]},
+                attrs=attrs,
+                outputs={'Out': [out]})
+            return out
+    helper.append_op(
+        type='fill_any_like',
+        inputs={'X': [input]},
+        attrs=attrs,
+        outputs={'Out': [out]})
+    out.stop_gradient = True
+    return out
+
+
+def zeros(shape, dtype, out=None, device=None):
+    """
+    The OP creates a tensor of specified :attr:`shape` and :attr:`dtype`, and fills it with 0.
+
+    Args:
+        shape(tuple|list): Shape of the output tensor.
+        dtype(np.dtype|core.VarDesc.VarType|str): Data type of the output tensor; it supports
+            bool, float16, float32, float64, int32 and int64.
+        out(Variable, optional): Optional output which can be any created
+            Variable that meets the requirements to store the result of the operation.
+            If out is None, a new Variable will be created to store the result.
+        device(str, optional): Which device to run the operator. The :attr:`device` must be
+            None, 'cpu' or 'gpu'. If :attr:`device` is None, the device set in the paddle
+            program will be chosen. Default value is None.
+
+    Returns:
+        Variable: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 0.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            data = paddle.zeros(shape=[3, 2], dtype='float32')  # [[0., 0.], [0., 0.], [0., 0.]]
+            data = paddle.zeros(shape=[2, 2], dtype='float32', device='cpu')  # [[0., 0.], [0., 0.]]
+    """
+    check_dtype(dtype, 'create data type',
+                ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
+                'zeros')
+    if device is not None:
+        if device not in ['cpu', 'gpu']:
+            raise ValueError(
+                "The value of 'device' in zeros_op must be cpu or gpu, but received %s."
+                % (device))
+        with fluid.device_guard(device):
+            return fill_constant(value=0.0, shape=shape, dtype=dtype, out=out)
+
+    return fill_constant(value=0.0, shape=shape, dtype=dtype, out=out)
+
+
+def zeros_like(input, dtype=None, device=None, name=None):
+    """
+    This function creates a tensor of zeros with the same shape and dtype
+    as `input`.
+
+    Args:
+        input(Variable): The input tensor which specifies shape and dtype. The dtype of input can be
+            bool, float32, float64, int32, int64.
+        dtype(np.dtype|core.VarDesc.VarType|str, optional): The data type can be set to bool, float32,
+            float64, int32 or int64. The default value is None, in which case the dtype is the same
+            as the input.
+        device(str, optional): Which device to run the operator. The :attr:`device` must be
+            None, 'cpu' or 'gpu'. If :attr:`device` is None, the device set in the paddle
+            program will be chosen. Default value is None.
+        name(str, optional): The name of the output variable; normally there is no need for the user
+            to set this property. Default value is None, in which case the framework sets the name of
+            the output variable.
+
+    Returns:
+        out(Variable): The tensor variable storing the output.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import paddle.fluid as fluid
+
+            x = fluid.layers.data(name='x', dtype='float32', shape=[3], append_batch_size=False)
+            data = paddle.zeros_like(x)  # data=[0.0, 0.0, 0.0]
+            data1 = paddle.zeros_like(input=x, device="gpu")  # data1=[0.0, 0.0, 0.0]
+
+    """
+
+    helper = LayerHelper("zeros_like", **locals())
+
+    attrs = {"value": 0.0}
+    var_dtype = None
+    if dtype is not None:
+        check_dtype(dtype, 'create data type',
+                    ['bool', 'float32', 'float64', 'int32', 'int64'],
+                    'zeros_like')
+        var_dtype = convert_np_dtype_to_dtype_(dtype)
+        attrs["dtype"] = var_dtype
+    else:
+        var_dtype = input.dtype
+
+    out = helper.create_variable_for_type_inference(dtype=var_dtype)
+
+    if device is not None:
+        if device not in ['cpu', 'gpu']:
+            raise ValueError(
+                "The value of 'device' in zeros_like_op must be cpu or gpu, but received %s."
+                % (device))
+        with fluid.device_guard(device):
+            helper.append_op(
+                type='fill_any_like',
+                inputs={'X': [input]},
+                attrs=attrs,
+                outputs={'Out': [out]})
+            return out
+    helper.append_op(
+        type='fill_any_like',
+        inputs={'X': [input]},
+        attrs=attrs,
+        outputs={'Out': [out]})
+    out.stop_gradient = True
+    return out
+
+
 def full(shape,
          fill_value,
          out=None,
--
GitLab
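
A minimal usage sketch of the creation APIs this patch exposes, mirroring the static-graph patterns used in the unit tests above. It assumes a Paddle build that already contains this change (API 2.0 alpha style, fluid Executor); the variable names are illustrative only.

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    # Build a small static program that exercises paddle.ones / zeros /
    # ones_like / zeros_like, following the unit tests in this patch.
    with fluid.program_guard(fluid.Program()):
        ones = paddle.ones(shape=[2, 2], dtype='float32')                  # filled with 1.0
        zeros = paddle.zeros(shape=[2, 2], dtype='float32', device='cpu')  # filled with 0.0

        data = fluid.data(name='data', shape=[3], dtype='float64')
        like_ones = paddle.ones_like(data)                     # dtype follows `data`
        like_zeros = paddle.zeros_like(data, dtype='float32')  # dtype overridden

        exe = fluid.Executor(fluid.CPUPlace())
        results = exe.run(feed={'data': np.zeros(3, dtype='float64')},
                          fetch_list=[ones, zeros, like_ones, like_zeros])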