Unverified commit d1c7b386, authored by Ainavo, committed by GitHub

support auto generate for prelu (#51913)

* support auto generate for prelu

* add input parameters in op_compat

* del attrs; add kernel data_type

* add PreluGradInferMeta
Parent 6f8ab1fa
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle {
namespace operators {

class PReluOp : public framework::OperatorWithKernel {
 public:
  PReluOp(const std::string &type,
          const framework::VariableNameMap &inputs,
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type =
        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return phi::KernelKey(input_data_type, ctx.GetPlace());
  }
};
class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The input tensor of prelu operator.");
    AddInput("Alpha", "The alpha weight of prelu operator.");
    AddOutput("Out", "The output tensor of prelu operator.");
    AddComment(R"DOC(
PRelu Operator.

The equation is:

$$
f(x) =
\begin{cases}
\alpha * x, \quad \text{if} \ x < 0 \\
x, \qquad \text{if} \ x >= 0
\end{cases}
$$

The input `X` can carry the LoD (Level of Detail) information,
or not. The output shares the LoD information with the input `X`.

There are three modes:
  all: all elements share the same weight
  channel: elements in a channel share the same weight
  element: each element has its own weight
)DOC");
    AddAttr<std::string>("mode", "The mode for inputs to share weights.")
        .SetDefault("all");
    AddAttr<std::string>("data_format",
                         "Data format that specifies the layout of input")
        .SetDefault("NCHW");
  }
};
// The operator to calculate gradients of a prelu operator.
class PReluGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "prelu");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
                   "Input",
                   "Out@GRAD",
                   "prelu");
    auto x_grad_name = framework::GradVarName("X");
    auto alpha_grad_name = framework::GradVarName("Alpha");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
    }
    if (ctx->HasOutput(alpha_grad_name)) {
      ctx->SetOutputDim(alpha_grad_name, ctx->GetInputDim("Alpha"));
    }
  }

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto input_data_type =
        framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return phi::KernelKey(input_data_type, ctx.GetPlace());
  }
};
template <typename T>
class PReluGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("prelu_grad");
    op->SetInput("X", this->Input("X"));
    op->SetInput("Alpha", this->Input("Alpha"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetOutput(framework::GradVarName("Alpha"), this->InputGrad("Alpha"));
    op->SetAttrMap(this->Attrs());
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(prelu,
                            PReluInferShapeFunctor,
                            PD_INFER_META(phi::PReluInferMeta));
REGISTER_OPERATOR(prelu,
                  ops::PReluOp,
                  ops::PReluOpMaker,
                  ops::PReluGradOpMaker<paddle::framework::OpDesc>,
                  ops::PReluGradOpMaker<paddle::imperative::OpBase>,
                  PReluInferShapeFunctor);
REGISTER_OPERATOR(prelu_grad, ops::PReluGradOp);
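For reference, the element-wise math this operator pair encodes (and that the YAML entries below now generate code for) can be written as a minimal standalone sketch. The scalar formulation, names, and example values are illustrative assumptions; Paddle's real kernels vectorize this and broadcast Alpha according to `mode`:

#include <cstdio>

// Forward, per the DOC equation: f(x) = x if x >= 0, alpha * x otherwise.
float prelu_forward(float x, float alpha) {
  return x >= 0.0f ? x : alpha * x;
}

// Backward, matching what prelu_grad produces for one element:
//   dx      = x >= 0 ? dout : alpha * dout
//   dalpha += x >= 0 ? 0    : x * dout   (summed over all elements that
//                                         share this alpha, per `mode`)
void prelu_backward(float x, float alpha, float dout, float *dx,
                    float *dalpha) {
  *dx = (x >= 0.0f) ? dout : alpha * dout;
  *dalpha += (x >= 0.0f) ? 0.0f : x * dout;
}

int main() {
  float dx = 0.0f;
  float dalpha = 0.0f;
  prelu_backward(-2.0f, 0.25f, 1.0f, &dx, &dalpha);
  // Prints: out=-0.500000 dx=0.250000 dalpha=-2.000000
  std::printf("out=%f dx=%f dalpha=%f\n",
              prelu_forward(-2.0f, 0.25f), dx, dalpha);
  return 0;
}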
@@ -1170,6 +1170,17 @@
    func : pow_triple_grad
    data_type : x

- backward_op : prelu_grad
  forward : prelu(Tensor x, Tensor alpha, str data_format="NCHW", str mode="all") -> Tensor(out)
  args : (Tensor x, Tensor alpha, Tensor out_grad, str data_format, str mode)
  output : Tensor(x_grad), Tensor(alpha_grad)
  infer_meta :
    func : PreluGradInferMeta
    param : [x, alpha]
  kernel :
    func : prelu_grad
    data_type : x

- backward_op : put_along_axis_grad
  forward : put_along_axis (Tensor arr, Tensor indices, Tensor value, int axis, str reduce = "assign") -> Tensor(out)
  args : (Tensor arr, Tensor indices, Tensor out_grad, int axis, str reduce)
@@ -866,16 +866,6 @@
    func : pool3d_grad
    param : [x, out, out_grad, kernel_size, strides, paddings, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm]

- backward_op : prelu_grad
  forward : prelu(Tensor x, Tensor alpha, str data_format, str mode) -> Tensor(out)
  args : (Tensor x, Tensor alpha, Tensor out_grad, str data_format, str mode)
  output : Tensor(x_grad), Tensor(alpha_grad)
  infer_meta :
    func : GeneralBinaryGradInferMeta
    param : [x, alpha]
  kernel :
    func : prelu_grad

- backward_op : prod_grad
  forward : prod (Tensor x, IntArray dims, bool keep_dim, bool reduce_all) -> Tensor(out)
  args : (Tensor x, Tensor out, Tensor out_grad, IntArray dims, bool keep_dim, bool reduce_all)
......@@ -1188,15 +1188,6 @@
param : [x, kernel_size, strides, paddings, ceil_mode, exclusive, data_format, pooling_type, global_pooling, adaptive, padding_algorithm]
backward : pool3d_grad
- op : prelu
args : (Tensor x, Tensor alpha, str data_format, str mode)
output : Tensor(out)
infer_meta :
func : PReluInferMeta
kernel :
func : prelu
backward : prelu_grad
- op : prior_box
args : (Tensor input, Tensor image, float[] min_sizes, float[] aspect_ratios, float[] variances, float[] max_sizes = {}, bool flip=true, bool clip=true, float step_w=0.0, float step_h=0.0, float offset=0.5, bool min_max_aspect_ratios_order=false)
output : Tensor(out), Tensor(var)
@@ -1447,6 +1447,10 @@
- op : prelu
  backward : prelu_grad
  inputs :
    {x : X, alpha : Alpha}
  outputs :
    out : Out
  extra :
    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
@@ -1186,6 +1186,16 @@
    data_type : x
  backward : pow_grad

- op : prelu
  args : (Tensor x, Tensor alpha, str data_format="NCHW", str mode="all")
  output : Tensor(out)
  infer_meta :
    func : PReluInferMeta
  kernel :
    func : prelu
    data_type : x
  backward : prelu_grad

- op : put_along_axis
  args : (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign")
  output : Tensor(out)
@@ -885,6 +885,18 @@ void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad,
  x_grad->set_dtype(out_grad.dtype());
}

void PreluGradInferMeta(const MetaTensor& x,
                        const MetaTensor& y,
                        MetaTensor* dx,
                        MetaTensor* dy) {
  // Each gradient mirrors the dims of its forward input: dx matches x,
  // and dy matches y (the Alpha weight).
  if (dx) {
    dx->share_dims(x);
  }
  if (dy) {
    dy->share_dims(y);
  }
}

void PsroiPoolGradInferMeta(const MetaTensor& x,
                            const MetaTensor& rois,
                            const MetaTensor& rois_num,
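This shape propagation is all the static graph needs for prelu_grad. As a worked example (shapes assumed here for channel mode, NCHW): if x has dims [2, 3, 4, 5] and alpha has dims [1, 3, 1, 1], then x_grad is inferred as [2, 3, 4, 5] and alpha_grad as [1, 3, 1, 1]. The `data_type : x` line in backward.yaml separately selects the kernel by x's dtype.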
@@ -344,6 +344,11 @@ void PixelUnshuffleGradInferMeta(const MetaTensor& out_grad,
                                 const std::string& data_format,
                                 MetaTensor* x_grad);

void PreluGradInferMeta(const MetaTensor& x,
                        const MetaTensor& y,
                        MetaTensor* dx,
                        MetaTensor* dy);

void OverlapAddGradInferMeta(const MetaTensor& x,
                             const MetaTensor& out_grad,
                             int hop_length,
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {

KernelSignature PReluOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "prelu", {"X", "Alpha"}, {"data_format", "mode"}, {"Out"});
}

KernelSignature PReluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("prelu_grad",
                         {"X", "Alpha", "Out@GRAD"},
                         {"data_format", "mode"},
                         {"X@GRAD", "Alpha@GRAD"});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(prelu, phi::PReluOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(prelu_grad, phi::PReluGradOpArgumentMapping);
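With prelu now generated from ops.yaml, this hand-written mapping is retired; its input, attribute, and output names line up positionally with the phi kernel's parameters. A rough sketch of that correspondence (the declaration below is an assumption for illustration, not copied from Paddle's headers):

// Assumed parameter order of the phi kernel that the mapping
// {"X", "Alpha"} / {"data_format", "mode"} / {"Out"} lined up with.
template <typename T, typename Context>
void PReluKernel(const Context &dev_ctx,
                 const DenseTensor &x,            // "X"
                 const DenseTensor &alpha,        // "Alpha"
                 const std::string &data_format,  // "data_format"
                 const std::string &mode,         // "mode"
                 DenseTensor *out);               // "Out"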