Unverified commit bafa890a · Author: zyfncg · Committed by: GitHub

Support generating static code of high order grad op by yaml (#47511)

* support generating static code of high order grad op by yaml

* polish code
Parent 77395619
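In place of the hand-written TanhDoubleGradMaker/TanhTripleGradMaker classes, operator registrations, and argument mappings removed below, the tanh high-order grad ops are now described in the op yaml files and turned into static operator code by the generator script and Jinja templates touched later in this diff. A rough sketch of the idea (illustrative only; the file path, helper names, and flow below are hypothetical, not the actual generator API):

```python
# Illustrative sketch only: the path, function and variable names are
# hypothetical; the real generator renders Jinja templates into C++.
import yaml


def load_backward_ops(path):
    # The backward yaml is a list of dicts, one per "- backward_op : ..." entry.
    with open(path) as f:
        return {item['backward_op']: item for item in yaml.safe_load(f)}


ops = load_backward_ops('backward.yaml')  # hypothetical local copy

# Follow the 'backward' field to collect the high-order chain declared in the
# yaml entries below: tanh_grad -> tanh_double_grad -> tanh_triple_grad.
chain, name = [], 'tanh_grad'
while name in ops:
    chain.append(name)
    name = ops[name].get('backward')
print(chain)  # ['tanh_grad', 'tanh_double_grad', 'tanh_triple_grad']
```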
......@@ -161,13 +161,6 @@ $$out = \max(x, 0)$$
)DOC";
UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.
$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
)DOC";
UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.
......@@ -529,7 +522,6 @@ It is recommended to use the defaults for this activation.
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
......@@ -699,54 +691,6 @@ class SigmoidTripleGradMaker
  }
};

template <typename T>
class TanhDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("tanh_grad_grad");
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    op->SetAttrMap(this->Attrs());
    // output: ddy
    op->SetOutput("DOutNew", this->InputGrad("Out"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
  }
};

template <typename T>
class TanhTripleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("tanh_triple_grad");
    // Out, DDX, DOut, D_DDOut, D_DOut_New // input
    // D_OutNew, D_DOut, D_DDx // output
    // input1: Out
    op->SetInput("Out", this->Input("Out"));
    // input2: ddx
    op->SetInput("DDX", this->Input("DDX"));
    // input3: dout
    op->SetInput("DOut", this->Input("DOut"));
    // input4: d_ddout
    op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
    // input5: d_dout_new
    op->SetInput("D_DOut_New", this->OutputGrad("DOutNew"));
    op->SetAttrMap(this->Attrs());
    // output: d_dOut, d_OutNew, d_ddx
    op->SetOutput("D_OutNew", this->InputGrad("Out"));
    op->SetOutput("D_DOut", this->InputGrad("DOut"));
    op->SetOutput("D_DDx", this->InputGrad("DDX"));
  }
};
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
template <typename T>
......@@ -1103,38 +1047,6 @@ REGISTER_OPERATOR(sigmoid_triple_grad,
/* ========================================================================== */
/* ========================== tanh register ============================= */
REGISTER_OPERATOR(
    tanh,
    ops::ActivationOp,
    ops::TanhOpMaker,
    ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::TanhGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::TanhGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    std::conditional<ops::CanInplaceAct<ops::TanhGradFunctor<float>>(),
                     ops::ActFwdInplaceInferer,
                     void>::type);
REGISTER_OPERATOR(tanh_grad,
                  ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInferer,
                  ops::TanhDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::TanhDoubleGradMaker<paddle::imperative::OpBase>)
REGISTER_OPERATOR(
    tanh_grad_grad,
    ops::ActivationOpDoubleGrad<ops::TanhGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer,
    ops::TanhTripleGradMaker<paddle::framework::OpDesc>,
    ops::TanhTripleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    tanh_triple_grad,
    ops::ActivationOpTripleGrad<ops::TanhTripleGradFunctor<float>::FwdDeps()>,
    ops::ActivationTripleGradOpInplaceInferer);

/* ========================================================================== */

/* ========================== relu register ============================= */
REGISTER_OPERATOR(
    relu,
......
......@@ -521,6 +521,41 @@
    func : tan_grad
  inplace : (out_grad -> x_grad)

- backward_op : tanh_double_grad
  forward : tanh_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x)
  args : (Tensor out, Tensor grad_out, Tensor grad_x_grad)
  output : Tensor(out_grad), Tensor(grad_out_grad)
  infer_meta :
    func : GeneralBinaryGradInferMeta
    param : [out, out]
  kernel :
    func : tanh_double_grad
  backward : tanh_triple_grad
  inplace : (grad_x_grad -> grad_out_grad)

- backward_op : tanh_grad
  forward : tanh (Tensor x) -> Tensor(out)
  args : (Tensor out, Tensor out_grad)
  output : Tensor(x_grad)
  infer_meta :
    func : UnchangedInferMeta
    param : [out]
  kernel :
    func : tanh_grad
  backward : tanh_double_grad
  inplace : (out_grad -> x_grad)

- backward_op : tanh_triple_grad
  forward : tanh_double_grad (Tensor out, Tensor grad_out_forward, Tensor grad_x_grad_forward) -> Tensor(grad_out_new), Tensor(grad_out_grad)
  args : (Tensor out, Tensor grad_out_forward, Tensor grad_x_grad_forward, Tensor grad_out_new_grad, Tensor grad_out_grad_grad)
  output : Tensor(out_grad), Tensor(grad_out_forward_grad), Tensor(grad_x_grad_forward_grad)
  infer_meta :
    func : GeneralTernaryGradInferMeta
    param : [out, out, grad_x_grad_forward]
  kernel :
    func : tanh_triple_grad
  inplace : (grad_x_grad_forward -> grad_out_forward_grad)

- backward_op : trace_grad
  forward : trace (Tensor x, int offset, int axis1, int axis2) -> Tensor(out)
  args : (Tensor x, Tensor out_grad, int offset, int axis1, int axis2)
......
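For orientation, the input/output wiring of these entries follows the usual chain-rule bookkeeping for out = tanh(x). The formulas below are standard calculus, restating what the tanh_grad and tanh_double_grad kernels are expected to compute; they are not spelled out anywhere in this diff:

```latex
% out = tanh(x)  \Rightarrow  \partial out / \partial x = 1 - out^{2}
\text{tanh\_grad:}\qquad grad\_x = grad\_out\,(1 - out^{2})
% tanh_double_grad receives grad_x_grad (the gradient w.r.t. grad_x) and
% returns gradients w.r.t. its two forward inputs, grad_out and out:
\text{tanh\_double\_grad:}\qquad
  grad\_out\_grad = grad\_x\_grad\,(1 - out^{2}),\qquad
  out\_grad = -2\,out \cdot grad\_out \cdot grad\_x\_grad
```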
......@@ -86,12 +86,30 @@ def replace_compat_name(api_op_map, forward_api_dict, backward_api_dict):
        if api_name != op_name:
            forward_api_item['op_name'] = op_name
        if 'backward' in api_args and has_backward:
            bw_api_name, bw_op_name = get_api_and_op_name(
                api_args['backward'].split(',')[0]
            )
            backward_op_list = api_args['backward'].split(',')
            bw_api_name, bw_op_name = get_api_and_op_name(backward_op_list[0])
            forward_api_item['backward'] = bw_op_name
            backward_api_item['op_name'] = bw_op_name

            # for double grad
            if len(backward_op_list) > 1:
                double_grad_api_name, double_grad_op_name = get_api_and_op_name(
                    backward_op_list[1]
                )
                double_grad_item = backward_api_dict[double_grad_api_name]
                backward_api_item['backward'] = double_grad_op_name
                double_grad_item['op_name'] = double_grad_op_name

                # for triple grad
                if len(backward_op_list) > 2:
                    (
                        triple_grad_api_name,
                        triple_grad_op_name,
                    ) = get_api_and_op_name(backward_op_list[2])
                    triple_grad_item = backward_api_dict[triple_grad_api_name]
                    double_grad_item['backward'] = triple_grad_op_name
                    triple_grad_item['op_name'] = triple_grad_op_name

        key_set = ['inputs', 'attrs', 'outputs']
        args_map = {}
        for key in key_set:
......
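The compat entry for tanh (see the op_compat.yaml hunk further down) now lists the whole chain `tanh_grad, tanh_double_grad (tanh_grad_grad), tanh_triple_grad`, and the code above walks that list to rename each backward op and to link `backward_api_item['backward']` and `double_grad_item['backward']`. A minimal sketch of how such an entry splits into (new api name, legacy op name) pairs, assuming `get_api_and_op_name` behaves as its call sites suggest (its body is not shown in this hunk, so the implementation below is an assumption):

```python
def get_api_and_op_name(op_item):
    # Assumed behaviour, inferred from entries like
    # "tanh_double_grad (tanh_grad_grad)": the name outside the parentheses is
    # the new-style api name, the one inside is the legacy op name.
    names = op_item.split('(')
    if len(names) == 1:
        return names[0].strip(), names[0].strip()
    return names[0].strip(), names[1].split(')')[0].strip()


backward = 'tanh_grad, tanh_double_grad (tanh_grad_grad), tanh_triple_grad'
for item in backward.split(','):
    print(get_api_and_op_name(item))
# ('tanh_grad', 'tanh_grad')
# ('tanh_double_grad', 'tanh_grad_grad')
# ('tanh_triple_grad', 'tanh_triple_grad')
```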
......@@ -389,7 +389,7 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
forward_output_orig_names)}});
{% endfor %}
grad_op->SetAttrMap(this->Attrs());
{% for attr in api["attrs"] %}
{% set attr_name = attr["name"] %}
{% if attr_name in forward_attr_names %}
......@@ -456,15 +456,15 @@ class {{name | to_pascal_case}}OpMaker : public framework::SingleGradOpMaker<T>
input_orig_names, output_orig_names) %}{# inline #}
{% if name in input_names %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name)]%}
Input("{{name_in_forward_orig}}")
Input({{name_in_forward_orig | to_opmaker_name}})
{%- elif name in output_names %}
{% set name_in_forward_orig = output_orig_names[output_names.index(name)]%}
Output("{{name}}")
Output({{name | to_opmaker_name}})
{%- elif name.endswith("_grad") %}{# output grad#}
{% set name_in_forward = name[:-5] %}
{% if name_in_forward in output_names %}
{% set name_in_forward_orig = output_orig_names[output_names.index(name_in_forward)] %}
OutputGrad("{{name_in_forward_orig}}")
OutputGrad({{name_in_forward_orig | to_opmaker_name}})
{%- endif %}
{%- endif %}
{%- endmacro %}
......@@ -474,11 +474,11 @@ OutputGrad("{{name_in_forward_orig}}")
{% if name[:-5] in input_names %}
{% set name_in_forward = name[:-5] %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad("{{name[:-5]}}")
InputGrad({{name_in_forward_orig | to_opmaker_name}})
{%- elif (name | to_input_name) in input_names %}
{% set name_in_forward = name | to_input_name %}
{% set name_in_forward_orig = input_orig_names[input_names.index(name_in_forward)]%}
InputGrad("{{name | to_input_name}}")
InputGrad({{name | to_input_name | to_opmaker_name}})
{%- endif %}
{%- endmacro %}
......
......@@ -2112,30 +2112,6 @@
  kernel :
    func : take_along_axis_grad

- backward_op : tanh_double_grad
  forward : tanh_grad (Tensor out, Tensor grad_out) -> Tensor(grad_x)
  args : (Tensor out, Tensor grad_out, Tensor grad_x_grad)
  output : Tensor(out_grad), Tensor(grad_out_grad)
  infer_meta :
    func : GeneralBinaryGradInferMeta
    param : [out, out]
  kernel :
    func : tanh_double_grad
  backward : tanh_triple_grad
  inplace : (grad_x_grad -> grad_out_grad)

- backward_op : tanh_grad
  forward : tanh (Tensor x) -> Tensor(out)
  args : (Tensor out, Tensor out_grad)
  output : Tensor(x_grad)
  infer_meta :
    func : UnchangedInferMeta
    param : [out]
  kernel :
    func : tanh_grad
  backward : tanh_double_grad
  inplace : (out_grad -> x_grad)

- backward_op : tanh_shrink_grad
  forward : tanh_shrink (Tensor x) -> Tensor(out)
  args : (Tensor x, Tensor out_grad)
......@@ -2147,17 +2123,6 @@
    func : tanh_shrink_grad
  inplace : (out_grad -> x_grad)

- backward_op : tanh_triple_grad
  forward : tanh_double_grad (Tensor out, Tensor grad_out_forward, Tensor grad_x_grad_forward) -> Tensor(grad_out_new), Tensor(grad_out_grad)
  args : (Tensor out, Tensor grad_out_forward, Tensor grad_x_grad_forward, Tensor grad_out_new_grad, Tensor grad_out_grad_grad)
  output : Tensor(out_grad), Tensor(grad_out_forward_grad), Tensor(grad_x_grad_forward_grad)
  infer_meta :
    func : GeneralTernaryGradInferMeta
    param : [out, out, grad_x_grad_forward]
  kernel :
    func : tanh_triple_grad
  inplace : (grad_x_grad_forward -> grad_out_forward_grad)

- backward_op : temporal_shift_grad
  forward : temporal_shift(Tensor x, int seg_num, float shift_ratio, str data_format_str) -> Tensor(out)
  args : (Tensor out_grad, int seg_num, float shift_ratio, str data_format_str)
......
......@@ -2394,16 +2394,6 @@
    data_type : arr
  backward : take_along_axis_grad

- op : tanh
  args : (Tensor x)
  output : Tensor(out)
  infer_meta :
    func : UnchangedInferMeta
  kernel :
    func : tanh
  inplace : (x -> out)
  backward : tanh_grad

- op : tanh_shrink
  args : (Tensor x)
  output : Tensor
......
......@@ -5,6 +5,10 @@
- op : abs
  backward : abs_grad
  inputs :
    x : X
  outputs :
    out : Out
  extra :
    attrs : [bool use_mkldnn = false]
......@@ -889,7 +893,11 @@
    attrs : [bool use_mkldnn = false, bool use_cudnn = false]

- op : tanh
  backward : tanh_grad
  backward : tanh_grad, tanh_double_grad (tanh_grad_grad), tanh_triple_grad
  inputs :
    x : X
  outputs :
    out : Out
  extra :
    attrs : [bool use_mkldnn = false, bool use_cudnn = false]
......
......@@ -461,6 +461,16 @@
    func : tan
  backward : tan_grad

- op : tanh
  args : (Tensor x)
  output : Tensor(out)
  infer_meta :
    func : UnchangedInferMeta
  kernel :
    func : tanh
  inplace : (x -> out)
  backward : tanh_grad

- op : trace
  args : (Tensor x, int offset = 0, int axis1 = 0, int axis2 = 1)
  output : Tensor
......
......@@ -67,7 +67,6 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Softplus,
"beta" comma "threshold"); // NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu, "relu", ); // NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Tanh, "tanh", ); // NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Sigmoid, "sigmoid", ); // NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Sqrt, "sqrt", ); // NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Rsqrt, "rsqrt", ); // NOLINT
......@@ -94,20 +93,6 @@ KernelSignature ReluDoubleGradOpArgumentMapping(
  return KernelSignature("relu_double_grad", {"Out", "DDX"}, {}, {"DDOut"});
}

KernelSignature TanhDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "tanh_double_grad", {"Out", "DOut", "DDX"}, {}, {"DOutNew", "DDOut"});
}

KernelSignature TanhTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("tanh_triple_grad",
                         {"Out", "DOut", "DDX", "D_DOut_New", "D_DDOut"},
                         {},
                         {"D_OutNew", "D_DOut", "D_DDx"});
}

KernelSignature SigmoidDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
......@@ -198,7 +183,6 @@ KernelSignature PowGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(relu_grad_grad, relu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(tanh_grad_grad, tanh_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(leaky_relu_grad_grad, leaky_relu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(softshrink, soft_shrink);
PD_REGISTER_BASE_KERNEL_NAME(softshrink_grad, soft_shrink_grad);
......@@ -227,11 +211,6 @@ PD_REGISTER_ARG_MAPPING_FN(softplus_grad, phi::SoftplusGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(relu_grad_grad,
                           phi::ReluDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tanh_grad, phi::TanhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tanh_grad_grad,
                           phi::TanhDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tanh_triple_grad,
                           phi::TanhTripleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(brelu_grad, phi::HardTanhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(leaky_relu, phi::LeakyReluOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(leaky_relu_grad,
......
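With the tanh grad ops declared in the backward yaml, these hand-written KernelSignature mappings become redundant: the generated code can derive the phi kernel arguments from the yaml `args` and `output` lists. Reading the removed `TanhTripleGradOpArgumentMapping` against the new `tanh_triple_grad` entry, the legacy fluid names and the yaml names appear to line up positionally; the mapping below is an inference from the two listings, not something stated in the diff:

```python
# Positional correspondence between the removed fluid-style input names and
# the yaml-style args of tanh_triple_grad (inferred, for orientation only).
legacy_inputs = ["Out", "DOut", "DDX", "D_DOut_New", "D_DDOut"]
yaml_args = [
    "out",
    "grad_out_forward",
    "grad_x_grad_forward",
    "grad_out_new_grad",
    "grad_out_grad_grad",
]
for legacy, new in zip(legacy_inputs, yaml_args):
    print(f"{legacy:>10} -> {new}")
```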