diff --git a/paddle/fluid/operators/fill_any_op.cc b/paddle/fluid/operators/fill_any_op.cc
index 853ebbdd9e57cb2f769ac874ea4d241e7d56d164..1af302d1fc032d60ac9e7d3685a022f419459f0e 100644
--- a/paddle/fluid/operators/fill_any_op.cc
+++ b/paddle/fluid/operators/fill_any_op.cc
@@ -12,7 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/fill_any_op.h"
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/kernels/funcs/math_function.h"

 namespace paddle {
 namespace operators {
@@ -34,30 +38,11 @@ class FillAnyOpMaker : public framework::OpProtoAndCheckerMaker {
 class FillAnyOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *context) const override {
-    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "FillAny");
-    OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "FillAny");
-    auto x_dims = context->GetInputDim("X");
-    context->SetOutputDim("Out", x_dims);
-  }
 };

 class FillAnyGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   "Out@GRAD",
-                   "mul");
-    auto x_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-    auto x_grad_name = framework::GradVarName("X");
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, x_dims);
-    }
-  }
 };

 template <typename T>
@@ -82,31 +67,22 @@ DECLARE_INPLACE_OP_INFERER(FillAnyGradInplaceInferer,
 }  // namespace paddle

 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(fill_any,
+                            FillInferShapeFunctor,
+                            PD_INFER_META(phi::UnchangedInferMeta));
+DECLARE_INFER_SHAPE_FUNCTOR(fill_any_grad,
+                            FillAnyInferShapeFunctor,
+                            PD_INFER_META(phi::UnchangedInferMeta));
+
 REGISTER_OPERATOR(fill_any,
                   ops::FillAnyOp,
                   ops::FillAnyOpMaker,
                   ops::FillAnyGradOpMaker<paddle::framework::OpDesc>,
                   ops::FillAnyGradOpMaker<paddle::imperative::OpBase>,
-                  ops::FillAnyOpInplaceInferer);
+                  ops::FillAnyOpInplaceInferer,
+                  FillInferShapeFunctor);

 REGISTER_OPERATOR(fill_any_grad,
                   ops::FillAnyGradOp,
-                  ops::FillAnyGradInplaceInferer);
-
-REGISTER_OP_CPU_KERNEL(
-    fill_any,
-    ops::FillAnyKernel<phi::CPUContext, float>,
-    ops::FillAnyKernel<phi::CPUContext, double>,
-    ops::FillAnyKernel<phi::CPUContext, int64_t>,
-    ops::FillAnyKernel<phi::CPUContext, int>,
-    ops::FillAnyKernel<phi::CPUContext, paddle::platform::float16>,
-    ops::FillAnyKernel<phi::CPUContext, bool>);
-
-REGISTER_OP_CPU_KERNEL(
-    fill_any_grad,
-    ops::FillAnyGradKernel<phi::CPUContext, float>,
-    ops::FillAnyGradKernel<phi::CPUContext, double>,
-    ops::FillAnyGradKernel<phi::CPUContext, int64_t>,
-    ops::FillAnyGradKernel<phi::CPUContext, int>,
-    ops::FillAnyGradKernel<phi::CPUContext, paddle::platform::float16>,
-    ops::FillAnyGradKernel<phi::CPUContext, bool>);
+                  ops::FillAnyGradInplaceInferer,
+                  FillAnyInferShapeFunctor);
diff --git a/paddle/fluid/operators/fill_any_op.cu.cc b/paddle/fluid/operators/fill_any_op.cu.cc
deleted file mode 100644
index 2a561e6d3500e62bca32e599ca2b7254d31c4392..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/fill_any_op.cu.cc
+++ /dev/null
@@ -1,34 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/fill_any_op.h"
-namespace ops = paddle::operators;
-
-REGISTER_OP_CUDA_KERNEL(
-    fill_any,
-    ops::FillAnyKernel<phi::GPUContext, float>,
-    ops::FillAnyKernel<phi::GPUContext, double>,
-    ops::FillAnyKernel<phi::GPUContext, int64_t>,
-    ops::FillAnyKernel<phi::GPUContext, int>,
-    ops::FillAnyKernel<phi::GPUContext, paddle::platform::float16>,
-    ops::FillAnyKernel<phi::GPUContext, bool>);
-
-REGISTER_OP_CUDA_KERNEL(
-    fill_any_grad,
-    ops::FillAnyGradKernel<phi::GPUContext, float>,
-    ops::FillAnyGradKernel<phi::GPUContext, double>,
-    ops::FillAnyGradKernel<phi::GPUContext, int64_t>,
-    ops::FillAnyGradKernel<phi::GPUContext, int>,
-    ops::FillAnyGradKernel<phi::GPUContext, paddle::platform::float16>,
-    ops::FillAnyGradKernel<phi::GPUContext, bool>);
diff --git a/paddle/fluid/operators/fill_any_op.h b/paddle/fluid/operators/fill_any_op.h
deleted file mode 100644
index 4f59d4f6ec65944ff5006ccbdd28826f911c0e27..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/fill_any_op.h
+++ /dev/null
@@ -1,67 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-#pragma once
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/kernels/funcs/math_function.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename DeviceContext, typename T>
-class FillAnyKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
-    auto *out = ctx.Output<framework::Tensor>("Out");
-    auto floatvar = ctx.template Attr<float>("value_float");
-    auto intvar = ctx.template Attr<int>("value_int");
-    auto isfloat = ((typeid(float) == typeid(T)) ||
-                    (typeid(double) == typeid(T) ||
-                     typeid(paddle::platform::float16) == typeid(T)));
-
-    T fill_var = static_cast<T>(floatvar);
-    if (!isfloat) {
-      fill_var = static_cast<T>(intvar);
-    }
-
-    PADDLE_ENFORCE_EQ(
-        std::isnan(static_cast<double>(fill_var)),
-        false,
-        platform::errors::InvalidArgument("fill value should not be NaN,"
-                                          " but received NaN"));
-
-    out->mutable_data<T>(ctx.GetPlace());
-    auto &dev_ctx = ctx.template device_context<DeviceContext>();
-    phi::funcs::SetConstant<DeviceContext, T> functor;
-    functor(reinterpret_cast<const DeviceContext &>(dev_ctx),
-            out,
-            static_cast<T>(fill_var));
-  }
-};
-
-template <typename DeviceContext, typename T>
-class FillAnyGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const paddle::framework::ExecutionContext &ctx) const override {
-    auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-    if (dx) {
-      dx->mutable_data<T>(ctx.GetPlace());
-      auto &dev_ctx = ctx.template device_context<DeviceContext>();
-      phi::funcs::SetConstant<DeviceContext, T> functor;
-      functor(reinterpret_cast<const DeviceContext &>(dev_ctx), dx, T(0));
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml
index 307edbefd03ca9a2717732553271feab327d2d28..753e3dc6762d314cfda454a84df3e0f6b2f3d6f3 100755
--- a/paddle/phi/api/yaml/legacy_api.yaml
+++ b/paddle/phi/api/yaml/legacy_api.yaml
@@ -864,6 +864,17 @@
     data_type : dtype
     backend : place

+- api : fill
+  args : (Tensor x, Scalar value)
+  output : Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : fill
+  inplace : (x -> out)
+  backward: fill_grad
+
 - api : fill_diagonal
   args : (Tensor x, float value, int offset, bool wrap)
   output : Tensor(out)
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 2a5d8bff70cd9427baf2984ce414af38942c3d2b..363edb430ffcd260339bc7cb8ee2d31c19214ee1 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -811,7 +811,7 @@
   infer_meta :
     func : UnchangedInferMeta
   invoke : zeros_like(out_grad, DataType::UNDEFINED, {})
- 
+
 - backward_api : fill_diagonal_grad
   forward : fill_diagonal (Tensor x, float value, int offset, bool wrap) -> Tensor(out)
   args : (Tensor out_grad, float value, int offset, bool wrap)
@@ -831,6 +831,17 @@
     func : fill_diagonal_tensor_grad
   inplace : (out_grad -> x_grad)

+- backward_api : fill_grad
+  forward : fill (Tensor x, Scalar value) -> Tensor(out)
+  args : (Tensor out_grad, Scalar value)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [out_grad]
+  kernel :
+    func : fill_grad
+  inplace : (out_grad -> x_grad)
+
 - backward_api : flatten_grad
   forward : flatten(Tensor x, int start_axis, int stop_axis) -> Tensor(out), Tensor(xshape)
   args : (Tensor xshape, Tensor out_grad)
diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h
index ae3b8924ece696de7ad0c48faf8144d64d014264..f4cca91a562da09aef92e65c45b500db07ea7468 100644
--- a/paddle/phi/core/compat/op_utils.h
+++ b/paddle/phi/core/compat/op_utils.h
@@ -51,6 +51,7 @@ const std::unordered_set<std::string> deprecated_op_names({"diag",
                                                            "squeeze_grad",
                                                            "isfinite",
                                                            "matmul",
+                                                           "fill",
                                                            "matmul_grad",
                                                            "matmul_grad_grad",
                                                            "max",
diff --git a/paddle/phi/kernels/cpu/fill_grad_kernel.cc b/paddle/phi/kernels/cpu/fill_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ee676773762ca5987aceb1b007ab2196de792d59
--- /dev/null
+++ b/paddle/phi/kernels/cpu/fill_grad_kernel.cc
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/impl/fill_grad_kernel_impl.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(fill_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::FillGradKernel,
+                   float,
+                   double,
+                   int64_t,
+                   int,
+                   paddle::platform::float16,
+                   bool) {}
diff --git a/paddle/phi/kernels/cpu/fill_kernel.cc b/paddle/phi/kernels/cpu/fill_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ee8dac7f6770c40b2192d18016b8bc2582dd6d33
--- /dev/null
+++ b/paddle/phi/kernels/cpu/fill_kernel.cc
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/impl/fill_kernel_impl.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(fill,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::FillKernel,
+                   float,
+                   double,
+                   int64_t,
+                   int,
+                   paddle::platform::float16,
+                   bool) {}
diff --git a/paddle/phi/kernels/fill_grad_kernel.h b/paddle/phi/kernels/fill_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..8e43d996489cbae84e15e4e6379a37b524954981
--- /dev/null
+++ b/paddle/phi/kernels/fill_grad_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void FillGradKernel(const Context& dev_ctx,
+                    const DenseTensor& out_grad,
+                    const Scalar& value,
+                    DenseTensor* in_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/fill_kernel.h b/paddle/phi/kernels/fill_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..9af3f465303b3636dc8f3ca12f0179375027de91
--- /dev/null
+++ b/paddle/phi/kernels/fill_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void FillKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const Scalar& value,
+                DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/gpu/fill_grad_kernel.cu b/paddle/phi/kernels/gpu/fill_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..32559ba95dfbca8abdaf0539182879e3953effca
--- /dev/null
+++ b/paddle/phi/kernels/gpu/fill_grad_kernel.cu
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/fill_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/fill_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(fill_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::FillGradKernel,
+                   float,
+                   double,
+                   int64_t,
+                   int,
+                   paddle::platform::float16,
+                   bool) {}
diff --git a/paddle/phi/kernels/gpu/fill_kernel.cu b/paddle/phi/kernels/gpu/fill_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..141e47b8cb109bd3be55311c997f88ec117a6e3c
--- /dev/null
+++ b/paddle/phi/kernels/gpu/fill_kernel.cu
@@ -0,0 +1,30 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/fill_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/fill_kernel_impl.h"
+
+PD_REGISTER_KERNEL(fill,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::FillKernel,
+                   float,
+                   double,
+                   int64_t,
+                   int,
+                   paddle::platform::float16,
+                   bool) {}
diff --git a/paddle/phi/kernels/impl/fill_grad_kernel_impl.h b/paddle/phi/kernels/impl/fill_grad_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..dffb81fbea4e3614306e28081255ae7f23daec04
--- /dev/null
+++ b/paddle/phi/kernels/impl/fill_grad_kernel_impl.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/fill_grad_kernel.h"
+
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void FillGradKernel(const Context& dev_ctx,
+                    const DenseTensor& out_grad,
+                    const Scalar& value,
+                    DenseTensor* in_grad) {
+  if (in_grad) {
+    dev_ctx.template Alloc<T>(in_grad);
+
+    phi::funcs::SetConstant<Context, T> functor;
+    functor(dev_ctx, in_grad, T(0));
+  }
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/fill_kernel_impl.h b/paddle/phi/kernels/impl/fill_kernel_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..7d10ea42bd6b65e2d2c374560ff8d6ae0e31ea75
--- /dev/null
+++ b/paddle/phi/kernels/impl/fill_kernel_impl.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/fill_kernel.h"
+
+#include "paddle/phi/common/scalar.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void FillKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const Scalar& value,
+                DenseTensor* out) {
+  T fill_var = value.to<T>();
+
+  PADDLE_ENFORCE_EQ(std::isnan(static_cast<double>(fill_var)),
+                    false,
+                    phi::errors::InvalidArgument("fill value should not be NaN,"
+                                                 " but received NaN"));
+
+  dev_ctx.template Alloc<T>(out);
+
+  phi::funcs::SetConstant<Context, T> functor;
+  functor(dev_ctx, out, fill_var);
+}
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/fill_sig.cc b/paddle/phi/ops/compat/fill_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2af8fcbea49ca569bca434c6e100d29ed81dccdb
--- /dev/null
+++ b/paddle/phi/ops/compat/fill_sig.cc
@@ -0,0 +1,33 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+KernelSignature FillOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("fill", {"X"}, {"value_float"}, {"Out"});
+}
+
+KernelSignature FillGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "fill_grad", {"Out@GRAD"}, {"value_float"}, {"X@GRAD"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_BASE_KERNEL_NAME(fill_any, fill);
+PD_REGISTER_BASE_KERNEL_NAME(fill_any_grad, fill_grad);
+
+PD_REGISTER_ARG_MAPPING_FN(fill_any, phi::FillOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(fill_any_grad, phi::FillGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/test_fill_any_op.py b/python/paddle/fluid/tests/unittests/test_fill_any_op.py
index 1262c28edda8485ea1996cacb7a5520ed71f5f24..ad7fd26a0a08813003262b3d158260759875b16c 100644
--- a/python/paddle/fluid/tests/unittests/test_fill_any_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fill_any_op.py
@@ -16,9 +16,11 @@ from __future__ import print_function

 import paddle
 import paddle.fluid.core as core
+from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode
 import unittest
 import numpy as np
 from op_test import OpTest
+from paddle.tensor.manipulation import fill_


 class TestFillAnyOp(OpTest):
@@ -75,5 +77,41 @@ class TestFillAnyOpvalue2(TestFillAnyOp):
         self.value = 11111.1111


+class TestFillAnyInplace(unittest.TestCase):
+
+    def test_fill_any_version(self):
+        with paddle.fluid.dygraph.guard():
+            var = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32))
+            self.assertEqual(var.inplace_version, 0)
+
+            var.fill_(0)
+            self.assertEqual(var.inplace_version, 1)
+
+            var.fill_(0)
+            self.assertEqual(var.inplace_version, 2)
+
+            var.fill_(0)
+            self.assertEqual(var.inplace_version, 3)
+
+    def test_fill_any_equal(self):
+        with paddle.fluid.dygraph.guard():
+            tensor = paddle.to_tensor(
+                np.random.random((20, 30)).astype(np.float32))
+            target = tensor.numpy()
+            target[...] = 1
+
+            tensor.fill_(1)
+            self.assertEqual((tensor.numpy() == target).all().item(), True)
+
+    def test_backward(self):
+        with paddle.fluid.dygraph.guard():
+            x = paddle.full([10, 10], -1., dtype='float32')
+            x.stop_gradient = False
+            y = 2 * x
+            y.fill_(1)
+            y.backward()
+            self.assertTrue(np.array_equal(x.grad.numpy(), np.zeros([10, 10])))
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index e0e34abff365eddec718ec098d77e81be5668a0a..0b1a22865d506ae37079d0ff469348a51322e8af 100755
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -777,8 +777,11 @@
         raise TypeError(
             "The type of 'value' must be int or float, but received %s." %
             (type(value)))
-    return _C_ops.fill_any_(x, "value_float", float(value), "value_int",
-                            int(value))
+    if in_dygraph_mode():
+        return _C_ops.final_state_fill_(x, value)
+    else:
+        return _C_ops.fill_any_(x, "value_float", float(value), "value_int",
+                                int(value))


 @dygraph_only
@@ -806,7 +809,10 @@
             print(tensor.tolist())   #[0, 0, 0, 0, 0]

     """
-    return _C_ops.fill_any_(x, "value_float", 0., "value_int", int(0))
+    if in_dygraph_mode():
+        return _C_ops.final_state_fill_(x, 0.)
+    else:
+        return _C_ops.fill_any_(x, "value_float", 0., "value_int", int(0))


 @dygraph_only