From 889bdde3a6e7515cb07a4b00531fccc0ee31bc2a Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Wed, 20 Jul 2022 16:52:18 +0800 Subject: [PATCH] [Phi] migrate exponential kernel to phi (#44376) * [Phi] migrate exponential kernel to phi * fix comment * fix CI --- paddle/fluid/operators/exponential_op.cc | 86 +++---------------- paddle/fluid/operators/exponential_op.cu | 48 ----------- paddle/fluid/operators/exponential_op.h | 42 --------- .../yaml/generator/wrapped_infermeta_gen.py | 3 +- paddle/phi/api/yaml/legacy_api.yaml | 11 +++ paddle/phi/api/yaml/legacy_backward.yaml | 9 ++ paddle/phi/kernels/cpu/exponential_kernel.cc | 45 ++++++++++ paddle/phi/kernels/exponential_kernel.h | 27 ++++++ paddle/phi/kernels/gpu/exponential_kernel.cu | 36 ++++++++ paddle/phi/ops/compat/exponential_sig.cc | 26 ++++++ .../tests/unittests/test_exponential_op.py | 12 ++- python/paddle/tensor/random.py | 4 +- 12 files changed, 181 insertions(+), 168 deletions(-) delete mode 100644 paddle/fluid/operators/exponential_op.cu delete mode 100644 paddle/fluid/operators/exponential_op.h create mode 100644 paddle/phi/kernels/cpu/exponential_kernel.cc create mode 100644 paddle/phi/kernels/exponential_kernel.h create mode 100644 paddle/phi/kernels/gpu/exponential_kernel.cu create mode 100644 paddle/phi/ops/compat/exponential_sig.cc diff --git a/paddle/fluid/operators/exponential_op.cc b/paddle/fluid/operators/exponential_op.cc index 5a75063fba..26e06e50a7 100644 --- a/paddle/fluid/operators/exponential_op.cc +++ b/paddle/fluid/operators/exponential_op.cc @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/exponential_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace operators { @@ -21,13 +23,6 @@ class ExponentialOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExponentialOp"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExponentialOp"); - auto dim = ctx->GetInputDim("X"); - ctx->SetOutputDim("Out", dim); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { @@ -51,52 +46,6 @@ exponential distribution. } }; -class ExponentialOpInferVarType - : public framework::PassInDtypeAndVarTypeToOutput { - protected: - std::unordered_map &GetInputOutputWithSameType() - const override { - static std::unordered_map m{{"X", /*->*/ "Out"}}; - return m; - } -}; - -template -class ExponentialKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *out = ctx.Output("Out"); - T *out_data = out->mutable_data(ctx.GetPlace()); - - T lambda = static_cast(ctx.Attr("lambda")); - int64_t size = out->numel(); - - auto gen = framework::DefaultCPUGenerator(); - auto engine = gen->GetCPUEngine(); - - std::uniform_real_distribution uniform(0.0, 1.0); - phi::funcs::exponential_transform trans(lambda); - for (int64_t i = 0; i < size; ++i) { - out_data[i] = trans(uniform(*engine)); - } - } -}; - -class ExponentialGradOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext *ctx) const override { - OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), - "Input", - "Out_Grad", - "ExponentialGradOp"); - - auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out")); - ctx->SetOutputDim(framework::GradVarName("X"), dout_dim); - } -}; - template class ExponentialGradOpMaker : public framework::SingleGradOpMaker { public: @@ -104,10 +53,10 @@ class ExponentialGradOpMaker : public framework::SingleGradOpMaker { protected: void Apply(GradOpPtr retv) const override { - retv->SetType("exponential_grad"); - retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); - retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - retv->SetAttrMap(this->Attrs()); + retv->SetType("fill_any_like"); + retv->SetInput("X", this->OutputGrad("Out")); + retv->SetAttr("value", 0.0f); + retv->SetOutput("Out", this->InputGrad("X")); } }; @@ -118,24 +67,15 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; DECLARE_INPLACE_OP_INFERER(ExponentialInferer, {"X", "Out"}); -DECLARE_INPLACE_OP_INFERER(ExponentialGradInferer, - {paddle::framework::GradVarName("Out"), - paddle::framework::GradVarName("X")}); + +DECLARE_INFER_SHAPE_FUNCTOR(exponential, + ExponentialInfershapeFunctor, + PD_INFER_META(phi::UnchangedInferMeta)); REGISTER_OPERATOR(exponential, ops::ExponentialOp, ops::ExponentialOpMaker, - ops::ExponentialOpInferVarType, ops::ExponentialGradOpMaker, ops::ExponentialGradOpMaker, - ExponentialInferer); -REGISTER_OPERATOR(exponential_grad, - ops::ExponentialGradOp, - ExponentialGradInferer); - -REGISTER_OP_CPU_KERNEL(exponential, - ops::ExponentialKernel, - ops::ExponentialKernel); -REGISTER_OP_CPU_KERNEL(exponential_grad, - ops::ExponentialGradKernel, - ops::ExponentialGradKernel); + ExponentialInferer, + ExponentialInfershapeFunctor); diff --git a/paddle/fluid/operators/exponential_op.cu b/paddle/fluid/operators/exponential_op.cu deleted file mode 100644 index 58d6fa674b..0000000000 --- a/paddle/fluid/operators/exponential_op.cu +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/exponential_op.h" - -namespace paddle { -namespace operators { - -template -class ExponentialKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - framework::Tensor* out = ctx.Output("Out"); - auto& dev_cxt = ctx.template device_context(); - T lambda = static_cast(ctx.Attr("lambda")); - - phi::funcs::uniform_distribution dist; - phi::funcs::exponential_transform trans(lambda); - phi::funcs::distribution_and_transform(dev_cxt, out, dist, trans); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -REGISTER_OP_CUDA_KERNEL( - exponential, - ops::ExponentialKernel, - ops::ExponentialKernel); -REGISTER_OP_CUDA_KERNEL( - exponential_grad, - ops::ExponentialGradKernel, - ops::ExponentialGradKernel); diff --git a/paddle/fluid/operators/exponential_op.h b/paddle/fluid/operators/exponential_op.h deleted file mode 100644 index 7ded174a9f..0000000000 --- a/paddle/fluid/operators/exponential_op.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/phi/kernels/funcs/distribution_helper.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace operators { - -template -class ExponentialKernel; - -template -class ExponentialGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* dx = ctx.Output(framework::GradVarName("X")); - dx->mutable_data(ctx.GetPlace()); - phi::funcs::SetConstant functor; - auto& dev_ctx = ctx.template device_context(); - functor(dev_ctx, dx, static_cast(0)); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/api/yaml/generator/wrapped_infermeta_gen.py b/paddle/phi/api/yaml/generator/wrapped_infermeta_gen.py index 99da6ce3d9..dfa6a7f93c 100644 --- a/paddle/phi/api/yaml/generator/wrapped_infermeta_gen.py +++ b/paddle/phi/api/yaml/generator/wrapped_infermeta_gen.py @@ -46,7 +46,8 @@ PD_REGISTER_INFER_META_FN({api.kernel['func'][0]}, phi::{api.infer_meta['func']} 'const paddle::optional&': 'const MetaTensor&' } - wrapped_infermeta_name = get_wrapped_infermeta_name(api.api) + wrapped_infermeta_name = get_wrapped_infermeta_name( + api.kernel['func'][0]) args = [] for input_name in api.inputs['names']: if input_name in kernel_params: diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index ed08fe48ee..f60309985a 100644 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -689,6 +689,17 @@ func : expm1 backward : expm1_grad +- api : exponential_ + args : (Tensor x, float lambda) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : exponential + inplace : (x -> out) + backward : exponential__grad + - api : eye args : (int64_t num_rows, int64_t num_columns, DataType dtype=DataType::FLOAT32, Place place={}) output : Tensor(out) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 91464ac769..6df4883145 100644 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -720,6 +720,15 @@ func : expm1_grad inplace : (out_grad -> x_grad) +- backward_api : exponential__grad + forward : exponential_ (Tensor x, float lambda) -> Tensor(out) + args : (Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + invoke : zeros_like(out_grad, DataType::UNDEFINED, {}) + inplace : (out_grad -> x_grad) + - backward_api : flatten_grad forward : flatten(Tensor x, int start_axis, int stop_axis) -> Tensor(out), Tensor(xshape) args : (Tensor xshape, Tensor out_grad) diff --git a/paddle/phi/kernels/cpu/exponential_kernel.cc b/paddle/phi/kernels/cpu/exponential_kernel.cc new file mode 100644 index 0000000000..a4a07fc7a6 --- /dev/null +++ b/paddle/phi/kernels/cpu/exponential_kernel.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/exponential_kernel.h" + +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/generator.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/distribution_helper.h" + +namespace phi { + +template +void ExponentialKernel(const Context& dev_ctx, + const DenseTensor& x, + float lambda, + DenseTensor* out) { + T* out_data = dev_ctx.template Alloc(out); + auto engine = dev_ctx.GetGenerator()->GetCPUEngine(); + + std::uniform_real_distribution uniform(0.0, 1.0); + phi::funcs::exponential_transform trans(lambda); + + for (int64_t i = 0; i < out->numel(); ++i) { + out_data[i] = trans(uniform(*engine)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL( + exponential, CPU, ALL_LAYOUT, phi::ExponentialKernel, float, double) {} diff --git a/paddle/phi/kernels/exponential_kernel.h b/paddle/phi/kernels/exponential_kernel.h new file mode 100644 index 0000000000..736baacca4 --- /dev/null +++ b/paddle/phi/kernels/exponential_kernel.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void ExponentialKernel(const Context &dev_ctx, + const DenseTensor &x, + float lambda, + DenseTensor *out); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/exponential_kernel.cu b/paddle/phi/kernels/gpu/exponential_kernel.cu new file mode 100644 index 0000000000..fc1730dde6 --- /dev/null +++ b/paddle/phi/kernels/gpu/exponential_kernel.cu @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/exponential_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/distribution_helper.h" + +namespace phi { + +template +void ExponentialKernel(const Context &dev_ctx, + const DenseTensor &x, + float lambda, + DenseTensor *out) { + phi::funcs::uniform_distribution dist; + phi::funcs::exponential_transform trans(lambda); + phi::funcs::distribution_and_transform(dev_ctx, out, dist, trans); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + exponential, GPU, ALL_LAYOUT, phi::ExponentialKernel, float, double) {} diff --git a/paddle/phi/ops/compat/exponential_sig.cc b/paddle/phi/ops/compat/exponential_sig.cc new file mode 100644 index 0000000000..2d70a4200a --- /dev/null +++ b/paddle/phi/ops/compat/exponential_sig.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature ExponentialOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("exponential", {"X"}, {"lambda"}, {"Out"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(exponential, phi::ExponentialOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_exponential_op.py b/python/paddle/fluid/tests/unittests/test_exponential_op.py index 57c4fb02d8..72b4d89904 100644 --- a/python/paddle/fluid/tests/unittests/test_exponential_op.py +++ b/python/paddle/fluid/tests/unittests/test_exponential_op.py @@ -18,13 +18,13 @@ import numpy as np from op_test import OpTest import os -paddle.enable_static() paddle.seed(100) class TestExponentialOp1(OpTest): def setUp(self): + paddle.enable_static() self.op_type = "exponential" self.config() @@ -87,8 +87,14 @@ class TestExponentialAPI(unittest.TestCase): def test_dygraph(self): paddle.disable_static() x = paddle.full([10, 10], -1., dtype='float32') - x.exponential_(0.5) - self.assertTrue(np.min(x.numpy()) >= 0) + x.stop_gradient = False + y = 2 * x + y.exponential_(0.5) + print(y) + self.assertTrue(np.min(y.numpy()) >= 0) + + y.backward() + self.assertTrue(np.array_equal(x.grad.numpy(), np.zeros([10, 10]))) paddle.enable_static() def test_fixed_random_number(self): diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 990b20a267..e25366df75 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -1052,7 +1052,9 @@ def exponential_(x, lam=1.0, name=None): # [0.72520673, 0.45208144, 0.30234432]] """ - if paddle.in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_exponential_(x, lam) + elif paddle.in_dynamic_mode(): return _C_ops.exponential_(x, "lambda", lam) check_variable_and_dtype(x, "x", ["float32", "float64"], "exponential") -- GitLab