/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

// Defines `<func_name>GradOpArgumentMapping`, which returns the signature
// `<op_name>_grad` with {"X", "Out@GRAD"}, {attrs} and {"X@GRAD"}.
// `attrs` is pasted verbatim between braces, so it may be left empty, and
// several attribute names can be passed as one macro argument by joining them
// with the `comma` helper macro.
#define DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(func_name, op_name, attrs) \
  KernelSignature func_name##GradOpArgumentMapping(               \
      const ArgumentMappingContext& ctx) {                        \
    return KernelSignature(                                       \
        op_name "_grad", {"X", "Out@GRAD"}, {attrs}, {"X@GRAD"}); \
  }

// Defines `<func_name>GradOpArgumentMapping`, which returns the signature
// `<op_name>_grad` with {"Out", "Out@GRAD"}, {attrs} and {"X@GRAD"}.
// Unlike the DEPX variant, the first list names "Out" rather than "X".
// `attrs` is pasted verbatim between braces and may be empty.
#define DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(func_name, op_name, attrs) \
  KernelSignature func_name##GradOpArgumentMapping(                 \
      const ArgumentMappingContext& ctx) {                          \
    return KernelSignature(                                         \
        op_name "_grad", {"Out", "Out@GRAD"}, {attrs}, {"X@GRAD"}); \
  }

// Defines `<func_name>GradOpArgumentMapping`, which returns the signature
// `<op_name>_grad` with {"Out@GRAD"}, {attrs} and {"X@GRAD"}; neither "X" nor
// "Out" is listed. `attrs` is pasted verbatim between braces and may be empty.
#define DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(func_name, op_name, attrs) \
  KernelSignature func_name##GradOpArgumentMapping(                \
      const ArgumentMappingContext& ctx) {                         \
    return KernelSignature(                                        \
        op_name "_grad", {"Out@GRAD"}, {attrs}, {"X@GRAD"});       \
  }

// Expands to a literal comma. Lets several attribute names travel through the
// single `attrs` parameter of the DEFINE_ACT_GRAD_*_OP_ARGMAP macros, e.g.
// `"t_min" comma "t_max"`, without being split into separate macro arguments.
#define comma ,

// Grad mappings generated with the DEPX macro ({"X", "Out@GRAD"} signature).
// An empty third argument means the grad kernel takes no attributes.
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Square, "square", );  // NOLINT

DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardTanh, "hard_tanh", "t_min" comma "t_max");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(LeakyRelu, "leaky_relu", "alpha");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(ThresholdedRelu,
                               "thresholded_relu",
                               "threshold");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(SoftShrink, "soft_shrink", "lambda");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Mish, "mish", "threshold");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(TanhShrink, "tanh_shrink", );  // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Softsign, "softsign", );       // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log, "log", );                 // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Celu, "celu", "alpha");        // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardSwish,
                               "hard_swish",
                               "threshold" comma "scale" comma
                               "offset");                // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Swish, "swish", "beta");  // NOLINT

DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(STanh,
                               "stanh",
                               "scale_a" comma "scale_b");  // NOLINT

DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Softplus,
                               "softplus",
                               "beta" comma "threshold");  // NOLINT

// Grad mappings generated with the DEPOUT macro ({"Out", "Out@GRAD"}
// signature).
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu, "relu", );               // NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Sigmoid, "sigmoid", );         // NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Sqrt, "sqrt", );               // NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Rsqrt, "rsqrt", );             // NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu6, "relu6", "threshold");  // NOLINT

// Chooses the kernel for the `sqrt` op from the kind of its "X" input:
// "sqrt" when "X" is a dense tensor, "sqrt_sr" otherwise. Both variants use
// the same {"X"} -> {"Out"} signature with no attributes.
KernelSignature SqrtActiOpArgumentMapping(const ArgumentMappingContext& ctx) {
  const char* kernel_name = ctx.IsDenseTensorInput("X") ? "sqrt" : "sqrt_sr";
  return KernelSignature(kernel_name, {"X"}, {}, {"Out"});
}

// Chooses the kernel for the `square` op from the kind of its "X" input:
// "square" when "X" is a dense tensor, "square_sr" otherwise. Both variants
// use the same {"X"} -> {"Out"} signature with no attributes.
KernelSignature SquareActiOpArgumentMapping(const ArgumentMappingContext& ctx) {
  const char* kernel_name =
      ctx.IsDenseTensorInput("X") ? "square" : "square_sr";
  return KernelSignature(kernel_name, {"X"}, {}, {"Out"});
}
// Signature of the "relu_double_grad" kernel: reads "Out" and "DDX", takes no
// attributes, and produces "DDOut". `ctx` is not consulted.
KernelSignature ReluDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("relu_double_grad",
                         /*inputs=*/{"Out", "DDX"},
                         /*attrs=*/{},
                         /*outputs=*/{"DDOut"});
}

// Signature of the "sigmoid_double_grad" kernel: reads "Out", "DOut" and
// "DDX", takes no attributes, and produces "DOutNew" and "DDOut".
// `ctx` is not consulted.
KernelSignature SigmoidDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "sigmoid_double_grad", {"Out", "DOut", "DDX"}, {}, {"DOutNew", "DDOut"});
}

// Signature of the "sigmoid_triple_grad" kernel: reads "Out", "DOut", "DDX",
// "D_DOut_New" and "D_DDOut", takes no attributes, and produces "D_OutNew",
// "D_DOut" and "D_DDx". `ctx` is not consulted.
KernelSignature SigmoidTripleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("sigmoid_triple_grad",
                         {"Out", "DOut", "DDX", "D_DOut_New", "D_DDOut"},
                         {},
                         {"D_OutNew", "D_DOut", "D_DDx"});
}

// Signature of the "leaky_relu_double_grad" kernel: reads "X" and "DDX", takes
// the "alpha" attribute, and produces "DDOut". `ctx` is not consulted.
KernelSignature LeakyReluDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("leaky_relu_double_grad",
                         /*inputs=*/{"X", "DDX"},
                         /*attrs=*/{"alpha"},
                         /*outputs=*/{"DDOut"});
}

// Signature of the forward "leaky_relu" kernel: reads "X", takes the "alpha"
// attribute, and produces "Out". `ctx` is not consulted.
KernelSignature LeakyReluOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("leaky_relu",
                         /*inputs=*/{"X"},
                         /*attrs=*/{"alpha"},
                         /*outputs=*/{"Out"});
}

// Signature of the forward "elu" kernel: reads "X", takes the "alpha"
// attribute, and produces "Out". `ctx` is not consulted.
KernelSignature EluOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("elu",
                         /*inputs=*/{"X"},
                         /*attrs=*/{"alpha"},
                         /*outputs=*/{"Out"});
}

// Signature of the "elu_grad" kernel: reads "X", "Out" and "Out@GRAD", takes
// the "alpha" attribute, and produces "X@GRAD". Written by hand rather than
// through the DEFINE_ACT_GRAD_* macros because it lists both "X" and "Out".
KernelSignature EluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "elu_grad", {"X", "Out", "Out@GRAD"}, {"alpha"}, {"X@GRAD"});
}

// Signature of the "elu_double_grad" kernel: reads "X", "DOut" and "DDX",
// takes the "alpha" attribute, and produces "DX" and "DDOut".
// `ctx` is not consulted.
KernelSignature EluDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("elu_double_grad",
                         /*inputs=*/{"X", "DOut", "DDX"},
                         /*attrs=*/{"alpha"},
                         /*outputs=*/{"DX", "DDOut"});
}

// Signature of the "log_double_grad" kernel: reads "X", "DOut" and "DDX",
// takes no attributes, and produces "DX" and "DDOut". `ctx` is not consulted.
KernelSignature LogDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("log_double_grad",
                         /*inputs=*/{"X", "DOut", "DDX"},
                         /*attrs=*/{},
                         /*outputs=*/{"DX", "DDOut"});
}

// Signature of the "sqrt_double_grad" kernel: reads "Out", "DX" and "DDX",
// takes no attributes, and produces "DOut" and "DDOut".
// `ctx` is not consulted.
KernelSignature SqrtDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("sqrt_double_grad",
                         /*inputs=*/{"Out", "DX", "DDX"},
                         /*attrs=*/{},
                         /*outputs=*/{"DOut", "DDOut"});
}

// Signature of the "rsqrt_double_grad" kernel: reads "Out", "DX" and "DDX",
// takes no attributes, and produces "DOut" and "DDOut".
// `ctx` is not consulted.
KernelSignature RsqrtDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("rsqrt_double_grad",
                         /*inputs=*/{"Out", "DX", "DDX"},
                         /*attrs=*/{},
                         /*outputs=*/{"DOut", "DDOut"});
}

// Signature of the "celu_double_grad" kernel: reads "X", "DOut" and "DDX",
// takes the "alpha" attribute, and produces "DX" and "DDOut".
// `ctx` is not consulted.
KernelSignature CeluDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("celu_double_grad",
                         /*inputs=*/{"X", "DOut", "DDX"},
                         /*attrs=*/{"alpha"},
                         /*outputs=*/{"DX", "DDOut"});
}

// Signature of the "square_double_grad" kernel: reads "X", "DOut" and "DDX",
// takes no attributes, and produces "DX" and "DDOut". `ctx` is not consulted.
KernelSignature SquareDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("square_double_grad",
                         /*inputs=*/{"X", "DOut", "DDX"},
                         /*attrs=*/{},
                         /*outputs=*/{"DX", "DDOut"});
}

// Signature of the forward "pow" kernel: reads "X" and produces "Out".
// The single attribute slot names "FactorTensor" when the op has an input of
// that name, and the plain "factor" attribute otherwise.
KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) {
  const char* factor_source =
      ctx.HasInput("FactorTensor") ? "FactorTensor" : "factor";
  return KernelSignature("pow", {"X"}, {factor_source}, {"Out"});
}

// Signature of the "pow_grad" kernel: reads "X" and "Out@GRAD" and produces
// "X@GRAD". The single attribute slot names "FactorTensor" when the op has an
// input of that name, and the plain "factor" attribute otherwise.
KernelSignature PowGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  if (ctx.HasInput("FactorTensor")) {
    return KernelSignature(
        "pow_grad", {"X", "Out@GRAD"}, {"FactorTensor"}, {"X@GRAD"});
  } else {
    return KernelSignature(
        "pow_grad", {"X", "Out@GRAD"}, {"factor"}, {"X@GRAD"});
  }
}

}  // namespace phi

// Base kernel name registrations. Each entry pairs an op name (first
// argument) with the kernel name used by the signatures in this file (second
// argument).
// NOTE(review): exact lookup semantics are defined by
// PD_REGISTER_BASE_KERNEL_NAME in op_utils.h -- confirm there.
PD_REGISTER_BASE_KERNEL_NAME(relu_grad_grad, relu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(leaky_relu_grad_grad, leaky_relu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(softshrink, soft_shrink);
PD_REGISTER_BASE_KERNEL_NAME(softshrink_grad, soft_shrink_grad);
PD_REGISTER_BASE_KERNEL_NAME(elu_grad_grad, elu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(sigmoid_grad_grad, sigmoid_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(log_grad_grad, log_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(sqrt_grad_grad, sqrt_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(rsqrt_grad_grad, rsqrt_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(celu_grad_grad, celu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(square_grad_grad, square_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(brelu, hard_tanh);
PD_REGISTER_BASE_KERNEL_NAME(brelu_grad, hard_tanh_grad);
198 199

// Argument mapping registrations: each entry binds an op name (first
// argument) to the phi:: mapping function that builds its KernelSignature.
// Note that some op names differ from the kernel names returned by the bound
// function (e.g. brelu_grad -> "hard_tanh_grad",
// softshrink_grad -> "soft_shrink_grad").
PD_REGISTER_ARG_MAPPING_FN(relu_grad, phi::ReluGradOpArgumentMapping);

PD_REGISTER_ARG_MAPPING_FN(square_grad, phi::SquareGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sqrt_grad, phi::SqrtGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sqrt_grad_grad,
                           phi::SqrtDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(rsqrt_grad, phi::RsqrtGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(rsqrt_grad_grad,
                           phi::RsqrtDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(stanh_grad, phi::STanhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(softplus_grad, phi::SoftplusGradOpArgumentMapping);

PD_REGISTER_ARG_MAPPING_FN(relu_grad_grad,
                           phi::ReluDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(brelu_grad, phi::HardTanhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(leaky_relu, phi::LeakyReluOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(leaky_relu_grad,
                           phi::LeakyReluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(leaky_relu_grad_grad,
                           phi::LeakyReluDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(thresholded_relu_grad,
                           phi::ThresholdedReluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(relu6_grad, phi::Relu6GradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(softshrink_grad,
                           phi::SoftShrinkGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tanh_shrink_grad,
                           phi::TanhShrinkGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elu, phi::EluOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elu_grad, phi::EluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elu_grad_grad, phi::EluDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(softsign_grad, phi::SoftsignGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sigmoid_grad, phi::SigmoidGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sigmoid_grad_grad,
                           phi::SigmoidDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sigmoid_triple_grad,
                           phi::SigmoidTripleGradOpArgumentMapping);

PD_REGISTER_ARG_MAPPING_FN(log_grad, phi::LogGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(log_grad_grad, phi::LogDoubleGradOpArgumentMapping);

PD_REGISTER_ARG_MAPPING_FN(sqrt, phi::SqrtActiOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(square, phi::SquareActiOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(hard_swish_grad,
                           phi::HardSwishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(swish_grad, phi::SwishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow_grad, phi::PowGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow, phi::PowOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(celu_grad, phi::CeluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(celu_grad_grad,
                           phi::CeluDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(square_grad_grad,
                           phi::SquareDoubleGradOpArgumentMapping);