From d1c7b386a34c8efa9af67f743ba3c96af75b48f2 Mon Sep 17 00:00:00 2001
From: Ainavo <57820731+Ainavo@users.noreply.github.com>
Date: Thu, 30 Mar 2023 11:37:03 +0800
Subject: [PATCH] support auto generate for prelu (#51913)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* support auto generate for prelu

* add input parameters in op_compat

* delete attrs; add kernel data_type

* add PreluGradInferMeta

---
 paddle/fluid/operators/prelu_op.cc       | 135 -----------------------
 paddle/phi/api/yaml/backward.yaml        |  11 ++
 paddle/phi/api/yaml/legacy_backward.yaml |  10 --
 paddle/phi/api/yaml/legacy_ops.yaml      |   9 --
 paddle/phi/api/yaml/op_compat.yaml       |   4 +
 paddle/phi/api/yaml/ops.yaml             |  10 ++
 paddle/phi/infermeta/backward.cc         |  12 ++
 paddle/phi/infermeta/backward.h          |   5 +
 paddle/phi/ops/compat/prelu_sig.cc       |  34 ------
 9 files changed, 42 insertions(+), 188 deletions(-)
 delete mode 100644 paddle/fluid/operators/prelu_op.cc
 delete mode 100644 paddle/phi/ops/compat/prelu_sig.cc

diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc
deleted file mode 100644
index 5100b4f8698..00000000000
--- a/paddle/fluid/operators/prelu_op.cc
+++ /dev/null
@@ -1,135 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-    http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <memory>
-#include <string>
-
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/phi/core/infermeta_utils.h"
-#include "paddle/phi/infermeta/binary.h"
-
-namespace paddle {
-namespace operators {
-
-class PReluOp : public framework::OperatorWithKernel {
- public:
-  PReluOp(const std::string &type,
-          const framework::VariableNameMap &inputs,
-          const framework::VariableNameMap &outputs,
-          const framework::AttributeMap &attrs)
-      : OperatorWithKernel(type, inputs, outputs, attrs) {}
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type =
-        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return phi::KernelKey(input_data_type, ctx.GetPlace());
-  }
-};
-
-class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X", "The input tensor of prelu operator.");
-    AddInput("Alpha", "The alpha weight of prelu operator.");
-    AddOutput("Out", "The output tensor of prelu operator.");
-    AddComment(R"DOC(
-PRelu Operator.
-The equation is:
-$$
-f(x) =
-\begin{cases}
-\alpha * x, \quad  \text{if} \ x < 0 \\
-x,          \qquad \text{if} \ x >= 0
-\end{cases}
-$$
-The input `X` can carry the LoD (Level of Details) information,
-or not. And the output shares the LoD information with input `X`.
-There are modes:
-  all: all elements share same weight
-  channel: elements in a channel share same weight
-  element: each element has a weight
-)DOC");
-    AddAttr<std::string>("mode", "The mode for inputs to share weights.")
-        .SetDefault("all");
-    AddAttr<std::string>("data_format",
-                         "Data format that specifies the layout of input")
-        .SetDefault("NCHW");
-  }
-};
-
-// The operator to calculate gradients of a prelu operator.
-class PReluGradOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "prelu");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input",
-                   "Out@GRAD",
-                   "prelu");
-
-    auto x_grad_name = framework::GradVarName("X");
-    auto alpha_grad_name = framework::GradVarName("Alpha");
-
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
-    }
-    if (ctx->HasOutput(alpha_grad_name)) {
-      ctx->SetOutputDim(alpha_grad_name, ctx->GetInputDim("Alpha"));
-    }
-  }
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type =
-        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    return phi::KernelKey(input_data_type, ctx.GetPlace());
-  }
-};
-
-template <typename T>
-class PReluGradOpMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> op) const override {
-    op->SetType("prelu_grad");
-    op->SetInput("X", this->Input("X"));
-    op->SetInput("Alpha", this->Input("Alpha"));
-    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    op->SetOutput(framework::GradVarName("Alpha"), this->InputGrad("Alpha"));
-    op->SetAttrMap(this->Attrs());
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-
-DECLARE_INFER_SHAPE_FUNCTOR(prelu,
-                            PReluInferShapeFunctor,
-                            PD_INFER_META(phi::PReluInferMeta));
-REGISTER_OPERATOR(prelu,
-                  ops::PReluOp,
-                  ops::PReluOpMaker,
-                  ops::PReluGradOpMaker<paddle::framework::OpDesc>,
-                  ops::PReluGradOpMaker<paddle::imperative::OpBase>,
-                  PReluInferShapeFunctor);
-REGISTER_OPERATOR(prelu_grad, ops::PReluGradOp);
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index b760840f9ec..68ef92c2626 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -1170,6 +1170,17 @@
     func : pow_triple_grad
     data_type : x
 
+- backward_op : prelu_grad
+  forward : prelu(Tensor x, Tensor alpha, str data_format="NCHW", str mode="all") -> Tensor(out)
+  args : (Tensor x, Tensor alpha, Tensor out_grad, str data_format, str mode)
+  output : Tensor(x_grad), Tensor(alpha_grad)
+  infer_meta :
+    func : PreluGradInferMeta
+    param: [x, alpha]
+  kernel :
+    func : prelu_grad
+    data_type : x
+
 - backward_op : put_along_axis_grad
   forward : put_along_axis (Tensor arr, Tensor indices, Tensor value, int axis, str reduce = "assign") -> Tensor(out)
   args : (Tensor arr, Tensor indices, Tensor out_grad, int axis, str reduce)
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index d6d6bbcc20d..035d301589d 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -866,16 +866,6 @@
     func : pool3d_grad
     param : [x, out, out_grad, kernel_size, strides, paddings, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm]
 
-- backward_op : prelu_grad
-  forward : prelu(Tensor x, Tensor alpha, str data_format, str mode) -> Tensor(out)
-  args : (Tensor x, Tensor alpha, Tensor out_grad, str data_format, str mode)
-  output : Tensor(x_grad), Tensor(alpha_grad)
-  infer_meta :
-    func : GeneralBinaryGradInferMeta
-    param: [x, alpha]
-  kernel :
-    func : prelu_grad
-
 - backward_op : prod_grad
   forward : prod (Tensor x, IntArray dims, bool keep_dim, bool reduce_all) -> Tensor(out)
   args : (Tensor x, Tensor out, Tensor out_grad, IntArray dims, bool keep_dim, bool reduce_all)
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 59c69272827..6cf0d1640fc 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -1188,15 +1188,6 @@
     param : [x, kernel_size, strides, paddings, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm]
   backward : pool3d_grad
 
-- op : prelu
-  args : (Tensor x, Tensor alpha, str data_format, str mode)
-  output : Tensor(out)
-  infer_meta :
-    func : PReluInferMeta
-  kernel :
-    func : prelu
-  backward : prelu_grad
-
 - op : prior_box
   args : (Tensor input, Tensor image, float[] min_sizes, float[] aspect_ratios, float[] variances, float[] max_sizes = {}, bool flip=true, bool clip=true, float step_w=0.0, float step_h=0.0, float offset=0.5, bool min_max_aspect_ratios_order=false)
   output : Tensor(out), Tensor(var)
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index 60f0449bd07..1111d14351a 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -1447,6 +1447,10 @@
 
 - op : prelu
   backward : prelu_grad
+  inputs :
+    { x : X, alpha : Alpha}
+  outputs :
+    out : Out
   extra :
     attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index ddfe80e535c..f80a0a770fc 100644
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -1186,6 +1186,16 @@
     data_type : x
   backward : pow_grad
 
+- op : prelu
+  args : (Tensor x, Tensor alpha, str data_format="NCHW", str mode="all")
+  output : Tensor(out)
+  infer_meta :
+    func : PReluInferMeta
+  kernel :
+    func : prelu
+    data_type : x
+  backward : prelu_grad
+
 - op : put_along_axis
   args : (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign")
   output : Tensor(out)
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 8acd927f473..6393cabc213 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -885,6 +885,18 @@ void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad,
   x_grad->set_dtype(out_grad.dtype());
 }
 
+void PreluGradInferMeta(const MetaTensor& x,
+                        const MetaTensor& y,
+                        MetaTensor* dx,
+                        MetaTensor* dy) {
+  if (dx) {
+    dx->share_dims(x);
+  }
+  if (dy) {
+    dy->share_dims(y);
+  }
+}
+
 void PsroiPoolGradInferMeta(const MetaTensor& x,
                             const MetaTensor& rois,
                             const MetaTensor& rois_num,
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index e65ba2085e6..c8608ebcd10 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -344,6 +344,11 @@ void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad,
                                  const std::string& data_format,
                                  MetaTensor* x_grad);
 
+void PreluGradInferMeta(const MetaTensor& x,
+                        const MetaTensor& y,
+                        MetaTensor* dx,
+                        MetaTensor* dy);
+
 void OverlapAddGradInferMeta(const MetaTensor& x,
                              const MetaTensor& out_grad,
                              int hop_length,
diff --git a/paddle/phi/ops/compat/prelu_sig.cc b/paddle/phi/ops/compat/prelu_sig.cc
deleted file mode 100644
index 6e25e1d9f75..00000000000
--- a/paddle/phi/ops/compat/prelu_sig.cc
+++ /dev/null
@@ -1,34 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/core/compat/op_utils.h"
-
-namespace phi {
-
-KernelSignature PReluOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  return KernelSignature(
-      "prelu", {"X", "Alpha"}, {"data_format", "mode"}, {"Out"});
-}
-
-KernelSignature PReluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  return KernelSignature("prelu_grad",
-                         {"X", "Alpha", "Out@GRAD"},
-                         {"data_format", "mode"},
-                         {"X@GRAD", "Alpha@GRAD"});
-}
-
-}  // namespace phi
-
-PD_REGISTER_ARG_MAPPING_FN(prelu, phi::PReluOpArgumentMapping);
-PD_REGISTER_ARG_MAPPING_FN(prelu_grad, phi::PReluGradOpArgumentMapping);
-- 
GitLab
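
P.S. (reviewer aside, not part of the patch): the OpMaker comment deleted above defines prelu as f(x) = alpha * x for x < 0 and f(x) = x for x >= 0, with three weight-sharing modes (all, channel, element). Below is a minimal NumPy sketch of those semantics for reference while reviewing the YAML signatures; the name prelu_ref and the NCHW layout assumed for "channel" mode are illustrative only, not Paddle's actual kernel.

    import numpy as np

    def prelu_ref(x, alpha, mode="all"):
        # Reference semantics of prelu: f(x) = alpha * x if x < 0 else x.
        if mode == "all":          # a single weight shared by every element
            a = alpha.reshape(1)
        elif mode == "channel":    # one weight per channel; NCHW layout assumed
            a = alpha.reshape(1, -1, 1, 1)
        elif mode == "element":    # one weight per element of x
            a = alpha.reshape(x.shape)
        else:
            raise ValueError("unknown mode: " + mode)
        return np.where(x < 0, a * x, x)

    # The gradients computed by prelu_grad follow from the same formula:
    #   x_grad     = out_grad * (1 where x >= 0, alpha elsewhere)
    #   alpha_grad = out_grad * (x where x < 0, 0 elsewhere), reduced to alpha's shape
    # which is why PreluGradInferMeta only needs to share dims with x and alpha.
    x = np.array([[-1.0, 2.0], [3.0, -4.0]])
    print(prelu_ref(x, np.array([0.25])))   # [[-0.25  2.  ] [ 3.   -1.  ]]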