Unverified commit 975f99ab, authored by YuanRisheng, committed by GitHub


[Phi]Move Relu/Cos/Sin/Tan/Acos/Asin/Atan/Sinh/Cosh/Asinh/Acosh/Atanh kernels in Activation to Phi (#40175)

* move activation op

* adjust code format

* fix compile bugs

* fix ci bugs

* code format adjust

* code format adjust2

* activate ci status

* modify according to comment
Parent f1fe2ad4
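For readers unfamiliar with the Phi migration: "moving a kernel to Phi" means the computation is re-implemented against Phi's Context/DenseTensor API and registered through Phi's kernel registry, while the fluid operator definition stays behind and dispatches to the Phi kernel by name. A minimal sketch of what the moved relu kernel roughly looks like on the Phi side; the file layout, macro name and type list are assumptions based on Phi conventions of this period, not the literal code added by this PR:

// Hypothetical sketch (assumed header, e.g. paddle/phi/kernels/activation_kernel.h):
#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void ReluKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);

}  // namespace phi

// Assumed CPU registration (GPU registration is analogous):
// PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {}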
@@ -478,7 +478,7 @@ function(op_library TARGET)
   if (${pybind_flag} EQUAL 0)
     # NOTE(*): activation use macro to regist the kernels, set use_op manually.
     if(${TARGET} STREQUAL "activation")
-      file(APPEND ${pybind_file} "USE_OP(relu);\n")
+      file(APPEND ${pybind_file} "USE_OP_ITSELF(relu);\n")
     elseif(${TARGET} STREQUAL "fake_dequantize")
       file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
     elseif(${TARGET} STREQUAL "fake_quantize")
......
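Note on the substitution above (and the identical ones in the test files below): in fluid, USE_OP(op) references both the operator definition and its fluid kernel registrations, while USE_OP_ITSELF(op) references only the operator definition. Since relu's kernels are no longer registered through the fluid registry after this change, only the operator itself can (and needs to) be referenced. Simplified, assumed relationship between the macros (the real definitions live in paddle/fluid/framework/op_registry.h):

// Assumed simplification, not the literal macro bodies:
// #define USE_OP(op_type)   \
//   USE_OP_ITSELF(op_type); \
//   USE_OP_KERNEL(op_type)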
@@ -27,7 +27,7 @@ USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
 USE_OP(leaky_relu);
 USE_OP_DEVICE_KERNEL(leaky_relu, MKLDNN);
 USE_OP(gelu);
-USE_OP(relu);
+USE_OP_ITSELF(relu);
 USE_OP(tanh);
 USE_OP_DEVICE_KERNEL(tanh, MKLDNN);
......
@@ -675,7 +675,7 @@ TEST(BuildCinnPassTest, NoNeedBufferInput) {
 USE_PASS(build_cinn_pass);
 USE_OP(mul);
-USE_OP(relu);
+USE_OP_ITSELF(relu);
 USE_OP_ITSELF(elementwise_add);
-USE_OP(relu_grad);
+USE_OP_ITSELF(relu_grad);
 USE_OP_ITSELF(elementwise_add_grad);
@@ -301,5 +301,5 @@ TEST(CinnCompilerTest, Compile) {
 USE_PASS(build_cinn_pass);
 USE_PASS(graph_viz_pass);
 USE_OP(mul);
-USE_OP(relu);
+USE_OP_ITSELF(relu);
 USE_OP_ITSELF(elementwise_add);
@@ -226,7 +226,7 @@ TEST(test_prepare_op, test_prepare_data_cpu_mkldnn) {
 }  // namespace paddle
 USE_OP_ITSELF(split);
-USE_OP(relu);
+USE_OP_ITSELF(relu);
 #ifdef PADDLE_WITH_MKLDNN
 USE_OP_DEVICE_KERNEL(relu, MKLDNN);
 #endif
@@ -52,7 +52,7 @@ TEST(Relu6OpConverter, main) { test_activation("relu6"); }
 }  // namespace inference
 }  // namespace paddle
-USE_OP(relu);
+USE_OP_ITSELF(relu);
 USE_OP(sigmoid);
 USE_OP(tanh);
 USE_OP(relu6);
@@ -132,7 +132,9 @@ struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx)
       : CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_RELU) {}
-  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
 };
 template <typename T>
@@ -146,7 +148,9 @@ struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
       : CudnnActivationGradFunctor<T>(ctx, 6.0,
                                       GPUDNN_ACTIVATION_CLIPPED_RELU) {}
-  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
 };
 template <typename T>
@@ -159,7 +163,9 @@ struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx)
       : CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_SIGMOID) {}
-  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
 };
 template <typename T>
@@ -172,7 +178,9 @@ struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor<T> {
   explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx)
       : CudnnActivationGradFunctor<T>(ctx, 0.0, GPUDNN_ACTIVATION_TANH) {}
-  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
 };
 template <typename Functor>
@@ -197,7 +205,8 @@ class CudnnActivationGradKernel
  public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& context) const override {
-    static_assert(Functor::FwdDeps() == kDepOut, "Forward deps must be Out.");
+    static_assert(Functor::FwdDeps() == ActBwdOpFwdDeps::kDepOut,
+                  "Forward deps must be Out.");
    const framework::Tensor *X, *Out, *dOut;
    X = Out = dOut = nullptr;
......
@@ -34,7 +34,8 @@ using paddle::framework::Tensor;
 template <typename GradFunctor>
 static constexpr bool CanInplaceAct() {
-  return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
+  return GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kDepOut ||
+         GradFunctor::FwdDeps() == ActBwdOpFwdDeps::kNoDeps;
 }
 #define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
@@ -921,7 +922,8 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
+    if (static_cast<int>(kDepValue) &
+        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
       if (ctx->HasOutput("DX")) {
         ctx->ShareDim("X", "DX");
         ctx->ShareLoD("X", "DX");
@@ -931,7 +933,8 @@ class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
         ctx->ShareLoD("X", "DDOut");
       }
     }
-    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
+    if (static_cast<int>(kDepValue) &
+        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
       if (ctx->HasOutput("DOut")) {
         ctx->ShareDim("Out", "DOut");
         ctx->ShareLoD("Out", "DOut");
@@ -960,13 +963,15 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
+    if (static_cast<int>(kDepValue) &
+        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
       if (ctx->HasOutput("DDOut")) {
         ctx->ShareDim("X", "DDOut");
         ctx->ShareLoD("X", "DDOut");
       }
     }
-    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
+    if (static_cast<int>(kDepValue) &
+        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
       if (ctx->HasOutput("DDOut")) {
         ctx->ShareDim("Out", "DDOut");
         ctx->ShareLoD("Out", "DDOut");
@@ -987,7 +992,8 @@ class ActivationOpTripleGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
   void InferShape(framework::InferShapeContext* ctx) const override {
-    if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
+    if (static_cast<int>(kDepValue) &
+        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
       if (ctx->HasOutput("DX")) {
         ctx->ShareDim("X", "DX");
         ctx->ShareLoD("X", "DX");
@@ -997,7 +1003,8 @@ class ActivationOpTripleGrad : public framework::OperatorWithKernel {
         ctx->ShareLoD("X", "DDOut");
       }
     }
-    if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
+    if (static_cast<int>(kDepValue) &
+        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
       if (ctx->HasOutput("D_DOut")) {
         ctx->ShareDim("Out", "D_DOut");
         ctx->ShareLoD("Out", "D_DOut");
@@ -1464,6 +1471,18 @@ namespace plat = paddle::platform;
 FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
 FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
+REGISTER_ACTIVATION_OP(cos, Cos, CosFunctor, CosGradFunctor)
+REGISTER_ACTIVATION_OP(tan, Tan, TanFunctor, TanGradFunctor);
+REGISTER_ACTIVATION_OP(acos, Acos, AcosFunctor, AcosGradFunctor);
+REGISTER_ACTIVATION_OP(sin, Sin, SinFunctor, SinGradFunctor);
+REGISTER_ACTIVATION_OP(asin, Asin, AsinFunctor, AsinGradFunctor);
+REGISTER_ACTIVATION_OP(atan, Atan, AtanFunctor, AtanGradFunctor);
+REGISTER_ACTIVATION_OP(sinh, Sinh, SinhFunctor, SinhGradFunctor);
+REGISTER_ACTIVATION_OP(cosh, Cosh, CoshFunctor, CoshGradFunctor);
+REGISTER_ACTIVATION_OP(asinh, Asinh, AsinhFunctor, AsinhGradFunctor);
+REGISTER_ACTIVATION_OP(acosh, Acosh, AcoshFunctor, AcoshGradFunctor);
+REGISTER_ACTIVATION_OP(atanh, Atanh, AtanhFunctor, AtanhGradFunctor);
 /* ==========================   sigmoid register  =============================
  */
 // 1. Register Sigmoid Operator
@@ -1584,16 +1603,6 @@ REGISTER_OPERATOR(
     ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>,
     ops::ActivationDoubleGradOpInplaceInferer);
-REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluCPUFunctor, ReluGradFunctor);
-REGISTER_OP_CPU_KERNEL(
-    relu_grad_grad,
-    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
-                                    ops::ReluGradGradFunctor<float>>,
-    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
-                                    ops::ReluGradGradFunctor<double>>,
-    ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
-                                    ops::ReluGradGradFunctor<plat::float16>>);
 /* ========================================================================== */
 /* ========================    leaky relu  register  ======================= */
......
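Two things happen in the registration hunks above: the eleven trigonometric/hyperbolic ops (cos, tan, acos, sin, asin, atan, sinh, cosh, asinh, acosh, atanh) are now registered individually with REGISTER_ACTIVATION_OP, so their operator definitions and grad-op makers stay in fluid while they no longer receive fluid CPU kernels; and relu's fluid CPU kernel plus the relu_grad_grad registration are dropped because those kernels are now registered on the Phi side. Illustrative only, the Phi-side registration for one of these ops would look roughly like the following (macro and kernel names are assumptions, not shown in this diff):

// Assumed Phi-side counterpart for the cos kernel:
// PD_REGISTER_KERNEL(cos, CPU, ALL_LAYOUT, phi::CosKernel, float, double) {}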
@@ -35,16 +35,14 @@ limitations under the License. */
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #endif
+#include "paddle/phi/kernels/funcs/activation_functor.h"
 namespace paddle {
 namespace operators {
 using framework::To32BitIndex;
-enum ActBwdOpFwdDeps {
-  kNoDeps = 0x00,   // Do not need any forward input/output
-  kDepX = 0x01,     // Only need forward input X
-  kDepOut = 0x02,   // Only need forward output Out
-};
+using ActBwdOpFwdDeps = phi::funcs::ActBwdOpFwdDeps;
 /* The following operator can be used to process SelectedRows, because the
  * output of those operator for zero is zero too.
@@ -89,7 +87,8 @@ inline void ExtractActivationGradTensor(
   auto x_grad_var = context.OutputVar(framework::GradVarName("X"));
   const framework::Variable* out_var = nullptr;
-  if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
+  if (static_cast<int>(kDepValue) &
+      static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
     out_var = context.InputVar("Out");
     PADDLE_ENFORCE_NOT_NULL(
         out_var, platform::errors::NotFound(
@@ -139,7 +138,7 @@ inline void ExtractActivationGradTensor(
                              "Output(Out), variable name = %s",
                              context.OutputName(framework::GradVarName("X"))));
-  if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
+  if (static_cast<int>(kDepValue) & static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
     auto x_var = context.InputVar("X");
     PADDLE_ENFORCE_NOT_NULL(x_var, platform::errors::NotFound(
                                        "Cannot get the tensor from the "
@@ -248,6 +247,24 @@ struct SigmoidFunctor : public BaseActivationFunctor<T> {
   }
 };
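The dependency enum that used to be defined in this header now comes from Phi. Judging from how it is used after this change (explicit static_casts and qualified kDepX/kDepOut access), phi::funcs::ActBwdOpFwdDeps is presumably a scoped enum carrying the same three values; a sketch of the assumed definition in paddle/phi/kernels/funcs/activation_functor.h:

// Assumed definition; values mirror the removed fluid enum.
enum class ActBwdOpFwdDeps {
  kNoDeps = 0x00,  // Do not need any forward input/output
  kDepX = 0x01,    // Only need forward input X
  kDepOut = 0x02,  // Only need forward output Out
};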
#define USE_PHI_FUNCTOR(name) \
template <typename T> \
using name##Functor = phi::funcs::name##Functor<T>; \
template <typename T> \
using name##GradFunctor = phi::funcs::name##GradFunctor<T>;
USE_PHI_FUNCTOR(Cos)
USE_PHI_FUNCTOR(Tan)
USE_PHI_FUNCTOR(Acos)
USE_PHI_FUNCTOR(Sin)
USE_PHI_FUNCTOR(Asin)
USE_PHI_FUNCTOR(Atan)
USE_PHI_FUNCTOR(Sinh)
USE_PHI_FUNCTOR(Cosh)
USE_PHI_FUNCTOR(Asinh)
USE_PHI_FUNCTOR(Acosh)
USE_PHI_FUNCTOR(Atanh)
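For one concrete instance, USE_PHI_FUNCTOR(Cos) expands to the two aliases below, so existing fluid code that refers to CosFunctor/CosGradFunctor keeps compiling while the implementation now lives in phi::funcs:

// Expansion of USE_PHI_FUNCTOR(Cos):
template <typename T>
using CosFunctor = phi::funcs::CosFunctor<T>;
template <typename T>
using CosGradFunctor = phi::funcs::CosGradFunctor<T>;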
template <typename T> template <typename T>
struct SigmoidGradFunctor : public BaseActivationFunctor<T> { struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut, template <typename Device, typename X, typename Out, typename dOut,
...@@ -256,7 +273,9 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> { ...@@ -256,7 +273,9 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * out * (static_cast<T>(1) - out); dx.device(d) = dout * out * (static_cast<T>(1) - out);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
/* /*
...@@ -293,7 +312,9 @@ struct SigmoidGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -293,7 +312,9 @@ struct SigmoidGradGradFunctor : public BaseActivationFunctor<T> {
ddout.device(*d) = (static_cast<T>(1) - out) * out * ddx; ddout.device(*d) = (static_cast<T>(1) - out) * out * ddx;
} }
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
/* /*
...@@ -351,7 +372,9 @@ struct SigmoidTripleGradFunctor : public BaseActivationFunctor<T> { ...@@ -351,7 +372,9 @@ struct SigmoidTripleGradFunctor : public BaseActivationFunctor<T> {
(static_cast<T>(1) - static_cast<T>(2) * out) * dout * d_dOutNew; (static_cast<T>(1) - static_cast<T>(2) * out) * dout * d_dOutNew;
} }
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
// silu(x) = x / (1 + exp(-x)) // silu(x) = x / (1 + exp(-x))
...@@ -376,7 +399,7 @@ struct SiluGradFunctor : public BaseActivationFunctor<T> { ...@@ -376,7 +399,7 @@ struct SiluGradFunctor : public BaseActivationFunctor<T> {
(static_cast<T>(1) + (temp2 / temp1))); (static_cast<T>(1) + (temp2 / temp1)));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// Originally: logsigmoid(x) = -log (1 + exp(-x)) // Originally: logsigmoid(x) = -log (1 + exp(-x))
...@@ -414,7 +437,7 @@ struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> { ...@@ -414,7 +437,7 @@ struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp())); dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// exp(x) = e^x // exp(x) = e^x
...@@ -434,7 +457,9 @@ struct ExpGradFunctor : public BaseActivationFunctor<T> { ...@@ -434,7 +457,9 @@ struct ExpGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * out; dx.device(d) = dout * out;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
// expm1(x) = e^x - 1 // expm1(x) = e^x - 1
...@@ -454,38 +479,23 @@ struct Expm1GradFunctor : public BaseActivationFunctor<T> { ...@@ -454,38 +479,23 @@ struct Expm1GradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * out + dout; dx.device(d) = dout * out + dout;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
 // relu(x) = max(x, 0)
-template <typename T>
-struct ReluCPUFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Out>
-  void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) {
-      return v > static_cast<T>(0) ? v : static_cast<T>(0);
-    });
-  }
-};
-template <typename T>
-struct ReluCUDAFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Out>
-  void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.cwiseMax(static_cast<T>(0));
-  }
-};
-template <typename T>
-struct ReluGradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Out, typename dOut,
-            typename dX>
-  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    dx.device(d) = dout * (out > static_cast<T>(0)).template cast<T>();
-  }
-  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
-};
+template <typename T>
+using ReluCPUFunctor = phi::funcs::ReluCPUFunctor<T>;
+template <typename T>
+using ReluGradFunctor = phi::funcs::ReluGradFunctor<T>;
+template <typename T>
+using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor<T>;
+template <typename T>
+using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor<T>;
 // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T> template <typename T>
...@@ -504,7 +514,9 @@ struct TanhGradFunctor : public BaseActivationFunctor<T> { ...@@ -504,7 +514,9 @@ struct TanhGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * (static_cast<T>(1) - out * out); dx.device(d) = dout * (static_cast<T>(1) - out * out);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -534,7 +546,9 @@ struct TanhGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -534,7 +546,9 @@ struct TanhGradGradFunctor : public BaseActivationFunctor<T> {
ddout.device(*d) = (static_cast<T>(1) - out * out) * ddx; ddout.device(*d) = (static_cast<T>(1) - out * out) * ddx;
} }
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
/* /*
Out Out
...@@ -589,7 +603,9 @@ struct TanhTripleGradFunctor : public BaseActivationFunctor<T> { ...@@ -589,7 +603,9 @@ struct TanhTripleGradFunctor : public BaseActivationFunctor<T> {
static_cast<T>(2) * out * dout * d_dOutNew; static_cast<T>(2) * out * dout * d_dOutNew;
} }
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
// tanhshrink(x) = x - tanh(x) // tanhshrink(x) = x - tanh(x)
...@@ -610,7 +626,7 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> { ...@@ -610,7 +626,7 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * (x.tanh() * x.tanh()); dx.device(d) = dout * (x.tanh() * x.tanh());
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// tanhshrink(x) = x - tanh(x) // tanhshrink(x) = x - tanh(x)
...@@ -646,7 +662,7 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> { ...@@ -646,7 +662,7 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * (temp1 || temp2).template cast<T>(); dx.device(d) = dout * (temp1 || temp2).template cast<T>();
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0 // softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
...@@ -682,7 +698,7 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> { ...@@ -682,7 +698,7 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * (temp1 + temp2).template cast<T>(); dx.device(d) = dout * (temp1 + temp2).template cast<T>();
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// sqrt(x) = x^(1/2) // sqrt(x) = x^(1/2)
...@@ -702,7 +718,9 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> { ...@@ -702,7 +718,9 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = static_cast<T>(0.5) * dout / out; dx.device(d) = static_cast<T>(0.5) * dout / out;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
// rsqrt(x) = x^(-1/2) // rsqrt(x) = x^(-1/2)
...@@ -722,7 +740,9 @@ struct RsqrtGradFunctor : public BaseActivationFunctor<T> { ...@@ -722,7 +740,9 @@ struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = static_cast<T>(-0.5) * dout * out * out * out; dx.device(d) = static_cast<T>(-0.5) * dout * out * out * out;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
// ceil(x) = ceiling(x) // ceil(x) = ceiling(x)
...@@ -742,7 +762,9 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> { ...@@ -742,7 +762,9 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = static_cast<T>(0) * out; dx.device(d) = static_cast<T>(0) * out;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kNoDeps;
}
}; };
// floor(x) = flooring(x) // floor(x) = flooring(x)
...@@ -754,373 +776,6 @@ struct FloorFunctor : public BaseActivationFunctor<T> { ...@@ -754,373 +776,6 @@ struct FloorFunctor : public BaseActivationFunctor<T> {
} }
}; };
template <typename T>
struct Sine {
HOSTDEVICE T operator()(const T& val) const { return sin(val); }
};
template <>
struct Sine<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(sin(static_cast<float>(val)));
}
};
template <typename T>
struct Cosine {
HOSTDEVICE T operator()(const T& val) const { return cos(val); }
};
template <>
struct Cosine<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(cos(static_cast<float>(val)));
}
};
// cosine'(x) = -sin(x)
template <typename T>
struct CosGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = -dout * x.unaryExpr(Sine<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// cosine(x) = cos(x)
template <typename T>
struct CosFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Cosine<T>());
}
};
// sine'(x) = cos(x)
template <typename T>
struct SinGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.unaryExpr(Cosine<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// sine(x) = sin(x)
template <typename T>
struct SinFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Sine<T>());
}
};
template <typename T>
struct Tangent {
HOSTDEVICE T operator()(const T& val) const { return tan(val); }
};
template <>
struct Tangent<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(tan(static_cast<float>(val)));
}
};
// Tangent'(x) = -Tangent(x)
template <typename T>
struct TanGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout / x.unaryExpr(Cosine<T>()).square();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Tangent(x) = tan(x)
template <typename T>
struct TanFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Tangent<T>());
}
};
template <typename T>
struct Sinh {
HOSTDEVICE T operator()(const T& val) const { return sinh(val); }
};
template <>
struct Sinh<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(sinhf(static_cast<float>(val)));
}
};
template <typename T>
struct Cosh {
HOSTDEVICE T operator()(const T& val) const { return cosh(val); }
};
template <>
struct Cosh<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(coshf(static_cast<float>(val)));
}
};
// sinh(x) = sinh(x)
template <typename T>
struct SinhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Sinh<T>());
}
};
// cosh(x) = cosh(x)
template <typename T>
struct CoshFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Cosh<T>());
}
};
// sinh'(x) = cosh(x)
template <typename T>
struct SinhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.unaryExpr(Cosh<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// cosh'(x) = sinh(x)
template <typename T>
struct CoshGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.unaryExpr(Sinh<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Acos {
HOSTDEVICE T operator()(const T& val) const { return acos(val); }
};
template <>
struct Acos<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(acos(static_cast<float>(val)));
}
};
// Acos(x) = acos(x)
template <typename T>
struct AcosFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Acos<T>());
}
};
// acos'(x) = -1/sqrt(1-x^2)
template <typename T>
struct AcosGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
-dout * static_cast<T>(1) / (static_cast<T>(1) - x.square()).sqrt();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Asin {
HOSTDEVICE T operator()(const T& val) const { return asin(val); }
};
template <>
struct Asin<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(asin(static_cast<float>(val)));
}
};
// Asin(x) = asin(x)
template <typename T>
struct AsinFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Asin<T>());
}
};
// asin'(x) = 1/sqrt(1-x^2)
template <typename T>
struct AsinGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
dout * static_cast<T>(1) / (static_cast<T>(1) - x.square()).sqrt();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Atan {
HOSTDEVICE T operator()(const T& val) const { return atan(val); }
};
template <>
struct Atan<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(atan(static_cast<float>(val)));
}
};
// Atan(x) = atan(x)
template <typename T>
struct AtanFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Atan<T>());
}
};
// atan'(x) = 1 / (1 + x^2)
template <typename T>
struct AtanGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(1) / (static_cast<T>(1) + x.square());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Acosh {
HOSTDEVICE T operator()(const T& val) const { return acosh(val); }
};
template <>
struct Acosh<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(acosh(static_cast<float>(val)));
}
};
// Acosh(x) = acosh(x)
template <typename T>
struct AcoshFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Acosh<T>());
}
};
// acosh'(x) = 1/sqrt(x^2 - 1)
template <typename T>
struct AcoshGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
dout * static_cast<T>(1) / (x * x - static_cast<T>(1)).sqrt();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Asinh {
HOSTDEVICE T operator()(const T& val) const { return asinh(val); }
};
template <>
struct Asinh<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(asinh(static_cast<float>(val)));
}
};
// Asinh(x) = asinh(x)
template <typename T>
struct AsinhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Asinh<T>());
}
};
// asinh'(x) = 1/sqrt(x^2 + 1)
template <typename T>
struct AsinhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
dout * static_cast<T>(1) / (x.square() + static_cast<T>(1)).sqrt();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Atanh {
HOSTDEVICE T operator()(const T& val) const { return atanh(val); }
};
template <>
struct Atanh<platform::float16> {
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
return platform::float16(atanh(static_cast<float>(val)));
}
};
// Atanh(x) = atanh(x)
template <typename T>
struct AtanhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Atanh<T>());
}
};
// atanh'(x) = 1/(1 - x^2)
template <typename T>
struct AtanhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(1) / (static_cast<T>(1) - x.square());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
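The several hundred removed lines above (the Sine/Cosine/Tangent/Sinh/Cosh/Acos/Asin/Atan/Acosh/Asinh/Atanh float16 helpers and the corresponding forward/grad functors) are not lost: the USE_PHI_FUNCTOR aliases earlier in this header point the same names at phi::funcs, so these Eigen-based implementations presumably moved to paddle/phi/kernels/funcs/activation_functor.h largely as-is. A sketch of the assumed shape on the Phi side, using Sin as the example (namespace placement is an assumption; the body mirrors the removed fluid functor):

// Assumed phi::funcs counterpart of the removed SinFunctor:
namespace phi {
namespace funcs {
template <typename T>
struct SinFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    out.device(d) = x.unaryExpr(Sine<T>());  // Sine<T> helper also moved to phi::funcs
  }
};
}  // namespace funcs
}  // namespace phi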
// round(x) = [x] // round(x) = [x]
template <typename T> template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> { struct RoundFunctor : public BaseActivationFunctor<T> {
...@@ -1147,7 +802,9 @@ struct ReciprocalGradFunctor : public BaseActivationFunctor<T> { ...@@ -1147,7 +802,9 @@ struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * static_cast<T>(-1) * out * out; dx.device(d) = dout * static_cast<T>(-1) * out * out;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
// log(x) = natural logarithm of x // log(x) = natural logarithm of x
...@@ -1167,7 +824,7 @@ struct LogGradFunctor : public BaseActivationFunctor<T> { ...@@ -1167,7 +824,7 @@ struct LogGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * (static_cast<T>(1) / x); dx.device(d) = dout * (static_cast<T>(1) / x);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// log2(x) = logarithm to the base 2 of the elements of x // log2(x) = logarithm to the base 2 of the elements of x
...@@ -1188,7 +845,7 @@ struct Log2GradFunctor : public BaseActivationFunctor<T> { ...@@ -1188,7 +845,7 @@ struct Log2GradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * static_cast<T>(1) / (x * static_cast<T>(log(2))); dx.device(d) = dout * static_cast<T>(1) / (x * static_cast<T>(log(2)));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// log10(x) = logarithm to the base 10 of the elements of x // log10(x) = logarithm to the base 10 of the elements of x
...@@ -1209,7 +866,7 @@ struct Log10GradFunctor : public BaseActivationFunctor<T> { ...@@ -1209,7 +866,7 @@ struct Log10GradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * static_cast<T>(1) / (x * static_cast<T>(log(10))); dx.device(d) = dout * static_cast<T>(1) / (x * static_cast<T>(log(10)));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// log1p(x) = natural logarithm of x+1 // log1p(x) = natural logarithm of x+1
...@@ -1229,7 +886,7 @@ struct Log1pGradFunctor : public BaseActivationFunctor<T> { ...@@ -1229,7 +886,7 @@ struct Log1pGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * (static_cast<T>(1) / (x + static_cast<T>(1))); dx.device(d) = dout * (static_cast<T>(1) / (x + static_cast<T>(1)));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// square(x) = x^2 // square(x) = x^2
...@@ -1249,7 +906,7 @@ struct SquareGradFunctor : public BaseActivationFunctor<T> { ...@@ -1249,7 +906,7 @@ struct SquareGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * static_cast<T>(2) * x; dx.device(d) = dout * static_cast<T>(2) * x;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -1285,7 +942,7 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -1285,7 +942,7 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
.template cast<T>(); .template cast<T>();
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// relu6(x) = min(max(0, x), 6) // relu6(x) = min(max(0, x), 6)
...@@ -1319,7 +976,9 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> { ...@@ -1319,7 +976,9 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
.template cast<T>(); .template cast<T>();
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
// HardSwish = min(max(0, x+3), 6) * x / 6 // HardSwish = min(max(0, x+3), 6) * x / 6
...@@ -1364,7 +1023,7 @@ struct HardSwishGradFunctor : public BaseActivationFunctor<T> { ...@@ -1364,7 +1023,7 @@ struct HardSwishGradFunctor : public BaseActivationFunctor<T> {
static_cast<T>(1) * (static_cast<T>(1) - tmp)); static_cast<T>(1) * (static_cast<T>(1) - tmp));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// For numerical stability, using the following formula instead of softplus(x) = // For numerical stability, using the following formula instead of softplus(x) =
...@@ -1409,7 +1068,7 @@ struct SoftplusGradFunctor : public BaseActivationFunctor<T> { ...@@ -1409,7 +1068,7 @@ struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
.select(dout, dout / (static_cast<T>(1) + (-x_beta).exp())); .select(dout, dout / (static_cast<T>(1) + (-x_beta).exp()));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// mish(x) = x * tanh(softplus(x)) // mish(x) = x * tanh(softplus(x))
...@@ -1449,7 +1108,7 @@ struct MishGradFunctor : public BaseActivationFunctor<T> { ...@@ -1449,7 +1108,7 @@ struct MishGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * (tsp + x * (static_cast<T>(1) - tsp * tsp) * gsp); dx.device(d) = dout * (tsp + x * (static_cast<T>(1) - tsp * tsp) * gsp);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// softsign(x) = x / (1 + |x|) // softsign(x) = x / (1 + |x|)
...@@ -1472,7 +1131,7 @@ struct SoftsignGradFunctor : public BaseActivationFunctor<T> { ...@@ -1472,7 +1131,7 @@ struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square()); dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -1504,7 +1163,9 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -1504,7 +1163,9 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp; dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -1539,7 +1200,7 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -1539,7 +1200,7 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * (temp1 + temp2).template cast<T>(); dx.device(d) = dout * (temp1 + temp2).template cast<T>();
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -1573,7 +1234,7 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> { ...@@ -1573,7 +1234,7 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
.select(dout, dout * (out + static_cast<T>(alpha))); .select(dout, dout * (out + static_cast<T>(alpha)));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -1592,7 +1253,7 @@ struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> { ...@@ -1592,7 +1253,7 @@ struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
.select(dout, dout * static_cast<T>(alpha) * x.exp()); .select(dout, dout * static_cast<T>(alpha) * x.exp());
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -1672,7 +1333,7 @@ struct CELUGradFunctor : public BaseActivationFunctor<T> { ...@@ -1672,7 +1333,7 @@ struct CELUGradFunctor : public BaseActivationFunctor<T> {
dout * (x / static_cast<T>(alpha)).exp() * temp_a_neg * temp_x_neg; dout * (x / static_cast<T>(alpha)).exp() * temp_a_neg * temp_x_neg;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198 // FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
...@@ -1701,7 +1362,7 @@ struct PowGradFunctor : public BaseActivationFunctor<T> { ...@@ -1701,7 +1362,7 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
x.pow(static_cast<T>(factor) - static_cast<T>(1)); x.pow(static_cast<T>(factor) - static_cast<T>(1));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -1766,7 +1427,7 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> { ...@@ -1766,7 +1427,7 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * a * b * (static_cast<T>(1) - temp); dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -1797,7 +1458,7 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -1797,7 +1458,7 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * (x > th).template cast<T>(); dx.device(d) = dout * (x > th).template cast<T>();
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -1832,7 +1493,9 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> { ...@@ -1832,7 +1493,9 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
static_cast<T>(slope); static_cast<T>(slope);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -1865,7 +1528,7 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> { ...@@ -1865,7 +1528,7 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * ((static_cast<T>(beta) * out) + temp2); dx.device(d) = dout * ((static_cast<T>(beta) * out) + temp2);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
/* /*
...@@ -1902,7 +1565,7 @@ inline void ExtractActivationDoubleGradTensor( ...@@ -1902,7 +1565,7 @@ inline void ExtractActivationDoubleGradTensor(
"Cannot get the tensor from the Variable Output, variable name = %s", "Cannot get the tensor from the Variable Output, variable name = %s",
ctx.OutputName("DDX"))); ctx.OutputName("DDX")));
if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) { if (static_cast<int>(kDepValue) & static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
auto x_var = ctx.InputVar("X"); auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
x_var, platform::errors::NotFound( x_var, platform::errors::NotFound(
...@@ -1925,7 +1588,8 @@ inline void ExtractActivationDoubleGradTensor( ...@@ -1925,7 +1588,8 @@ inline void ExtractActivationDoubleGradTensor(
VLOG(10) << "Inplace activation of Op: " << ctx.Type(); VLOG(10) << "Inplace activation of Op: " << ctx.Type();
*X = *ddX; *X = *ddX;
} }
if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) { if (static_cast<int>(kDepValue) &
static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
auto out_var = ctx.InputVar("Out"); auto out_var = ctx.InputVar("Out");
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
out_var, out_var,
...@@ -2000,28 +1664,7 @@ struct AbsGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -2000,28 +1664,7 @@ struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
ddout.device(*d) = ddx * x.sign(); ddout.device(*d) = ddx * x.sign();
} }
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct ReluGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* X,
const framework::Tensor* Out, const framework::Tensor* ddX,
framework::Tensor* ddOut, framework::Tensor* dOut,
framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "ReluGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "ReluGradGrad"));
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ReluGradGrad"));
ddout.device(*d) = ddx * (out > static_cast<T>(0)).template cast<T>();
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
}; };
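ReluGradGradFunctor is removed here for the same reason as the other relu functors: it is aliased to the Phi implementation earlier in this header, so the double-grad computation now lives in phi::funcs as well.

// Alias added earlier in this header (quoted from this diff):
// template <typename T>
// using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor<T>;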
template <typename T> template <typename T>
...@@ -2050,7 +1693,7 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -2050,7 +1693,7 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
.template cast<T>(); .template cast<T>();
} }
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -2088,7 +1731,7 @@ struct ELUGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -2088,7 +1731,7 @@ struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
.template cast<T>(); .template cast<T>();
} }
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -2127,7 +1770,7 @@ struct CELUGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -2127,7 +1770,7 @@ struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
.template cast<T>(); .template cast<T>();
} }
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -2156,7 +1799,9 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -2156,7 +1799,9 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
ddout.device(*d) = ddx * static_cast<T>(0.5) / out; ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
} }
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -2185,7 +1830,9 @@ struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -2185,7 +1830,9 @@ struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
ddout.device(*d) = ddx * static_cast<T>(-0.5) * out * out * out; ddout.device(*d) = ddx * static_cast<T>(-0.5) * out * out * out;
} }
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -2214,7 +1861,7 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -2214,7 +1861,7 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
ddout.device(*d) = ddx * static_cast<T>(2) * x; ddout.device(*d) = ddx * static_cast<T>(2) * x;
} }
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// TODO(dengkaipeng): double gradient calculation for Square/Sqrt need // TODO(dengkaipeng): double gradient calculation for Square/Sqrt need
...@@ -2840,7 +2487,7 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -2840,7 +2487,7 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
} }
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
} // namespace operators } // namespace operators
@@ -2849,20 +2496,9 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
 #define FOR_EACH_ACTIVATION_OP(__macro)                                       \
   __macro(silu, Silu, SiluFunctor, SiluGradFunctor);                          \
   __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);  \
-  __macro(atan, Atan, AtanFunctor, AtanGradFunctor);                          \
   __macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor);  \
   __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor);                          \
   __macro(floor, Floor, FloorFunctor, ZeroGradFunctor);                       \
-  __macro(cos, Cos, CosFunctor, CosGradFunctor);                              \
-  __macro(tan, Tan, TanFunctor, TanGradFunctor);                              \
-  __macro(acos, Acos, AcosFunctor, AcosGradFunctor);                          \
-  __macro(sin, Sin, SinFunctor, SinGradFunctor);                              \
-  __macro(asin, Asin, AsinFunctor, AsinGradFunctor);                          \
-  __macro(sinh, Sinh, SinhFunctor, SinhGradFunctor);                          \
-  __macro(cosh, Cosh, CoshFunctor, CoshGradFunctor);                          \
-  __macro(asinh, Asinh, AsinhFunctor, AsinhGradFunctor);                      \
-  __macro(acosh, Acosh, AcoshFunctor, AcoshGradFunctor);                      \
-  __macro(atanh, Atanh, AtanhFunctor, AtanhGradFunctor);                      \
   __macro(round, Round, RoundFunctor, ZeroGradFunctor);                       \
   __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor);  \
   __macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor);                      \
......
...@@ -18,28 +18,6 @@ limitations under the License. */ ...@@ -18,28 +18,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
// relu(x) = max(x, 0)
__device__ __forceinline__ T operator()(const T x) const {
return x > zero ? x : zero;
}
};
template <typename T>
struct CudaReluGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
// dx = dout * (out > 0)
__device__ __forceinline__ T operator()(const T dout, const T out) const {
return out > zero ? dout : zero;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
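CudaReluFunctor and CudaReluGradFunctor disappear from the fluid GPU header as well; after this PR the GPU relu forward/backward computation is provided by Phi's GPU kernels rather than by this file's functor-plus-elementwise-kernel path. The Phi-side names and location are not shown in this diff; assuming they mirror the removed code, the forward functor would look roughly like:

// Assumed Phi-side GPU functor (illustrative; actual name/location not shown here):
template <typename T>
struct CudaReluFunctor {
  T zero = static_cast<T>(0.0f);
  // relu(x) = max(x, 0)
  __device__ __forceinline__ T operator()(const T x) const {
    return x > zero ? x : zero;
  }
};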
template <typename T> template <typename T>
struct CudaLeakyReluFunctor : public BaseActivationFunctor<T> { struct CudaLeakyReluFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f); T zero = static_cast<T>(0.0f);
...@@ -69,7 +47,7 @@ struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -69,7 +47,7 @@ struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> {
return x > zero ? dout : static_cast<T>(alpha) * dout; return x > zero ? dout : static_cast<T>(alpha) * dout;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -93,7 +71,9 @@ struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> { ...@@ -93,7 +71,9 @@ struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
return dout * out * (one - out); return dout * out * (one - out);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -122,7 +102,7 @@ struct CudaSiluGradFunctor : public BaseActivationFunctor<T> { ...@@ -122,7 +102,7 @@ struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
return static_cast<T>(dout * (temp * (one + x * (one - temp)))); return static_cast<T>(dout * (temp * (one + x * (one - temp))));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -159,30 +139,7 @@ struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> { ...@@ -159,30 +139,7 @@ struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
return static_cast<T>(dout * (temp2 / (exp(-temp1) + temp2))); return static_cast<T>(dout * (temp2 / (exp(-temp1) + temp2)));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaAtanFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// atan(x) = atan(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(atan(x));
}
};
template <typename T>
struct CudaAtanGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// dx = dout / (1 + x^2)
__device__ __forceinline__ T operator()(const T dout, const T x) const {
return dout / (one + x * x);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
}; };
template <typename T> template <typename T>
...@@ -219,7 +176,7 @@ struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> { ...@@ -219,7 +176,7 @@ struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> {
return (x >= -l && x <= l) ? zero : dout; return (x >= -l && x <= l) ? zero : dout;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -262,191 +219,9 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor<T> { ...@@ -262,191 +219,9 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
return static_cast<T>(0.0f); return static_cast<T>(0.0f);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; } static constexpr ActBwdOpFwdDeps FwdDeps() {
}; return ActBwdOpFwdDeps::kNoDeps;
template <typename T>
struct CudaCosFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// cos(x) = cos(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(cos(x));
}
};
template <typename T>
struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// dx = dout * (-sin(x))
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(-dout * sin(x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaSinFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// sin(x) = sin(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(sin(x));
}
};
template <typename T>
struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// dx = dout * cos(x)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * cos(x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaTanFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// tan(x) = tan(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(tan(x));
}
};
template <typename T>
struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// dx = dout / cos(x)^2
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout / (cos(x) * cos(x)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaAsinFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// asin(x) = asin(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(asin(x));
}
};
template <typename T>
struct CudaAsinGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// dx = dout / sqrt(1 - x^2)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout / sqrt(one - x * x));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaAcosFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// acos(x) = acos(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(acos(x));
}
};
template <typename T>
struct CudaAcosGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// dx = -dout / sqrt(1 - x^2)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(-dout / sqrt(one - x * x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaCoshFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// cosh(x) = cosh(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(cosh(x));
}
};
template <typename T>
struct CudaCoshGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// dx = dout * sinh(x)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * sinh(x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaSinhFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// sinh(x) = sinh(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(sinh(x));
}
};
template <typename T>
struct CudaSinhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// dx = dout * cosh(x)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * cosh(x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
}; };
template <typename T> template <typename T>
...@@ -469,88 +244,11 @@ struct CudaTanhGradFunctor : public BaseActivationFunctor<T> { ...@@ -469,88 +244,11 @@ struct CudaTanhGradFunctor : public BaseActivationFunctor<T> {
return dout * (one - out * out); return dout * (one - out * out);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
}; return ActBwdOpFwdDeps::kDepOut;
template <typename T>
struct CudaAcoshFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// Acosh(x) = acosh(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(acosh(x));
} }
}; };
template <typename T>
struct CudaAcoshGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// dx = dout * 1 / sqrt(x^2 - 1)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * one / sqrt(x * x - one));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaAsinhFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// Asinh(x) = asinh(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(asinh(x));
}
};
template <typename T>
struct CudaAsinhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// dx = dout * 1/sqrt(x^2 + 1)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * one / sqrt(x * x + one));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaAtanhFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// Atanh(x) = atanh(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(atanh(x));
}
};
template <typename T>
struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// dx = dout * 1/(1- x^2)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * one / (one - x * x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T> template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> { struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f); T one = static_cast<T>(1.0f);
...@@ -566,7 +264,9 @@ struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> { ...@@ -566,7 +264,9 @@ struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
return -dout * out * out; return -dout * out * out;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -587,7 +287,9 @@ struct CudaExpGradFunctor : public BaseActivationFunctor<T> { ...@@ -587,7 +287,9 @@ struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
return dout * out; return dout * out;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -608,7 +310,9 @@ struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> { ...@@ -608,7 +310,9 @@ struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
return dout * out + dout; return dout * out + dout;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -629,7 +333,7 @@ struct CudaLogGradFunctor : public BaseActivationFunctor<T> { ...@@ -629,7 +333,7 @@ struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
return dout / x; return dout / x;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -647,7 +351,7 @@ struct CudaSquareGradFunctor : public BaseActivationFunctor<T> { ...@@ -647,7 +351,7 @@ struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
return dout * two * x; return dout * two * x;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -670,7 +374,9 @@ struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> { ...@@ -670,7 +374,9 @@ struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
return one_half * dout / out; return one_half * dout / out;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -693,7 +399,9 @@ struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> { ...@@ -693,7 +399,9 @@ struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
return minus_one_half * dout * out * out * out; return minus_one_half * dout * out * out * out;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -717,7 +425,7 @@ struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> { ...@@ -717,7 +425,7 @@ struct CudaLog1pGradFunctor : public BaseActivationFunctor<T> {
return dout / (one + x); return dout / (one + x);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -741,7 +449,7 @@ struct CudaLog2GradFunctor : public BaseActivationFunctor<T> { ...@@ -741,7 +449,7 @@ struct CudaLog2GradFunctor : public BaseActivationFunctor<T> {
return dout / (x * log_two); return dout / (x * log_two);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -765,7 +473,7 @@ struct CudaLog10GradFunctor : public BaseActivationFunctor<T> { ...@@ -765,7 +473,7 @@ struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
return dout / (x * log_ten); return dout / (x * log_ten);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -804,7 +512,7 @@ struct CudaBReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -804,7 +512,7 @@ struct CudaBReluGradFunctor : public BaseActivationFunctor<T> {
return (x > t_min_cast && x < t_max_cast) ? dout : zero; return (x > t_min_cast && x < t_max_cast) ? dout : zero;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -849,7 +557,9 @@ struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -849,7 +557,9 @@ struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
: static_cast<T>(0.0f); : static_cast<T>(0.0f);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -893,7 +603,7 @@ struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> { ...@@ -893,7 +603,7 @@ struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
return static_cast<T>(dout * a * b * (one - temp * temp)); return static_cast<T>(dout * a * b * (one - temp * temp));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -939,7 +649,7 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> { ...@@ -939,7 +649,7 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
return x_beta > t ? arg_dout : static_cast<T>(dout / (one + exp(-x_beta))); return x_beta > t ? arg_dout : static_cast<T>(dout / (one + exp(-x_beta)));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -962,7 +672,7 @@ struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> { ...@@ -962,7 +672,7 @@ struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
return dout / (temp * temp); return dout / (temp * temp);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -996,7 +706,9 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> { ...@@ -996,7 +706,9 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
return (out > zero && out < t) ? dout : zero; return (out > zero && out < t) ? dout : zero;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -1022,7 +734,7 @@ struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> { ...@@ -1022,7 +734,7 @@ struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> {
return static_cast<T>(dout * tanh(x) * tanh(x)); return static_cast<T>(dout * tanh(x) * tanh(x));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -1056,7 +768,7 @@ struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> { ...@@ -1056,7 +768,7 @@ struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> {
return (x > -t && x < t) ? zero : dout; return (x > -t && x < t) ? zero : dout;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -1097,7 +809,9 @@ struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> { ...@@ -1097,7 +809,9 @@ struct CudaHardSigmoidGradFunctor : public BaseActivationFunctor<T> {
return (out > zero && out < one) ? dout * static_cast<T>(slope) : zero; return (out > zero && out < one) ? dout * static_cast<T>(slope) : zero;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -1141,7 +855,7 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor<T> { ...@@ -1141,7 +855,7 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
return static_cast<T>(dout * (temp2 + temp3)); return static_cast<T>(dout * (temp2 + temp3));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -1190,7 +904,7 @@ struct CudaMishGradFunctor : public BaseActivationFunctor<T> { ...@@ -1190,7 +904,7 @@ struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
return static_cast<T>(dout * (tsp + x * (one - tsp * tsp) * gsp)); return static_cast<T>(dout * (tsp + x * (one - tsp * tsp) * gsp));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -1222,7 +936,7 @@ struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -1222,7 +936,7 @@ struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
return x > static_cast<T>(threshold) ? dout : zero; return x > static_cast<T>(threshold) ? dout : zero;
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -1274,7 +988,7 @@ struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> { ...@@ -1274,7 +988,7 @@ struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
return dout * (temp1 * temp2 * (two * x + o) / s + one - temp2); return dout * (temp1 * temp2 * (two * x + o) / s + one - temp2);
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T> template <typename T>
...@@ -1320,7 +1034,9 @@ struct CudaELUGradFunctor : public BaseActivationFunctor<T> { ...@@ -1320,7 +1034,9 @@ struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
return static_cast<T>(dout * (out_pos + out_neg * (out + a))); return static_cast<T>(dout * (out_pos + out_neg * (out + a)));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
}; };
template <typename T> template <typename T>
...@@ -1347,7 +1063,7 @@ struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> { ...@@ -1347,7 +1063,7 @@ struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
return static_cast<T>(dout * (x_pos + x_neg * (out + a))); return static_cast<T>(dout * (x_pos + x_neg * (out + a)));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
...@@ -1429,7 +1145,7 @@ struct CudaCELUGradFunctor : public BaseActivationFunctor<T> { ...@@ -1429,7 +1145,7 @@ struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
temp_a_neg * temp_x_pos + exp(x / a) * temp_a_neg * temp_x_neg)); temp_a_neg * temp_x_pos + exp(x / a) * temp_a_neg * temp_x_neg));
} }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename DeviceContext, typename Functor> template <typename DeviceContext, typename Functor>
...@@ -1477,13 +1193,14 @@ class ActivationGradCudaKernel ...@@ -1477,13 +1193,14 @@ class ActivationGradCudaKernel
std::vector<const framework::Tensor*> ins = {d_out}; std::vector<const framework::Tensor*> ins = {d_out};
std::vector<framework::Tensor*> outs = {d_x}; std::vector<framework::Tensor*> outs = {d_x};
if (static_cast<int>(Functor::FwdDeps()) == static_cast<int>(kDepOut)) { if (static_cast<int>(Functor::FwdDeps()) ==
static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
// Only need forward output Out // Only need forward output Out
ins.push_back(out); ins.push_back(out);
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins, paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
&outs, functor); &outs, functor);
} else if (static_cast<int>(Functor::FwdDeps()) == } else if (static_cast<int>(Functor::FwdDeps()) ==
static_cast<int>(kDepX)) { static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
// Only need forward input X // Only need forward input X
ins.push_back(x); ins.push_back(x);
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins, paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
...@@ -1602,50 +1319,6 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -1602,50 +1319,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::CELUGradGradFunctor<plat::float16>>); ops::CELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */ /* ========================================================================== */
/* =========================== relu register ============================ */
#ifdef PADDLE_WITH_HIP
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor,
CudaReluGradFunctor);
REGISTER_OP_CUDA_KERNEL(
relu_grad_grad,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradGradFunctor<float>>,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::ReluGradGradFunctor<plat::float16>>);
#else
REGISTER_OP_CUDA_KERNEL(
relu, ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaReluFunctor<float>>,
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaReluFunctor<double>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaReluFunctor<plat::float16>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaReluFunctor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
relu_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaReluGradFunctor<float>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaReluGradFunctor<double>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaReluGradFunctor<plat::float16>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaReluGradFunctor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
relu_grad_grad,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradGradFunctor<float>>,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::ReluGradGradFunctor<plat::float16>>,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::ReluGradGradFunctor<plat::bfloat16>>);
#endif
/* ========================================================================== */
/* =========================== sigmoid register ============================ /* =========================== sigmoid register ============================
*/ */
REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor, REGISTER_ACTIVATION_CUDA_KERNEL(sigmoid, Sigmoid, CudaSigmoidFunctor,
...@@ -1838,21 +1511,10 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -1838,21 +1511,10 @@ REGISTER_OP_CUDA_KERNEL(
__macro(silu, Silu, CudaSiluFunctor, CudaSiluGradFunctor); \ __macro(silu, Silu, CudaSiluFunctor, CudaSiluGradFunctor); \
__macro(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor, \ __macro(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor, \
CudaLogSigmoidGradFunctor); \ CudaLogSigmoidGradFunctor); \
__macro(atan, Atan, CudaAtanFunctor, CudaAtanGradFunctor); \
__macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \ __macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \
CudaSoftShrinkGradFunctor); \ CudaSoftShrinkGradFunctor); \
__macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor); \ __macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor); \
__macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor); \ __macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor); \
__macro(cos, Cos, CudaCosFunctor, CudaCosGradFunctor); \
__macro(tan, Tan, CudaTanFunctor, CudaTanGradFunctor); \
__macro(acos, Acos, CudaAcosFunctor, CudaAcosGradFunctor); \
__macro(sin, Sin, CudaSinFunctor, CudaSinGradFunctor); \
__macro(asin, Asin, CudaAsinFunctor, CudaAsinGradFunctor); \
__macro(sinh, Sinh, CudaSinhFunctor, CudaSinhGradFunctor); \
__macro(cosh, Cosh, CudaCoshFunctor, CudaCoshGradFunctor); \
__macro(asinh, Asinh, CudaAsinhFunctor, CudaAsinhGradFunctor); \
__macro(acosh, Acosh, CudaAcoshFunctor, CudaAcoshGradFunctor); \
__macro(atanh, Atanh, CudaAtanhFunctor, CudaAtanhGradFunctor); \
__macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); \ __macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); \
__macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \ __macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \
CudaReciprocalGradFunctor); \ CudaReciprocalGradFunctor); \
......
...@@ -29,7 +29,7 @@ USE_OP_ITSELF(elementwise_add); ...@@ -29,7 +29,7 @@ USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP(elementwise_mul); USE_OP(elementwise_mul);
USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN); USE_OP_DEVICE_KERNEL(elementwise_mul, MKLDNN);
USE_OP(relu); USE_OP_ITSELF(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN); USE_OP_DEVICE_KERNEL(relu, MKLDNN);
USE_OP_ITSELF(softmax); USE_OP_ITSELF(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN); USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP(relu); USE_OP_ITSELF(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN); USE_OP_DEVICE_KERNEL(relu, MKLDNN);
USE_OP_ITSELF(softmax); USE_OP_ITSELF(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN); USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
USE_OP(pool2d); USE_OP(pool2d);
USE_OP_DEVICE_KERNEL(pool2d, MKLDNN); USE_OP_DEVICE_KERNEL(pool2d, MKLDNN);
USE_OP(relu); USE_OP_ITSELF(relu);
USE_OP_DEVICE_KERNEL(relu, MKLDNN); USE_OP_DEVICE_KERNEL(relu, MKLDNN);
USE_OP_ITSELF(transpose); USE_OP_ITSELF(transpose);
USE_OP_DEVICE_KERNEL(transpose, MKLDNN); USE_OP_DEVICE_KERNEL(transpose, MKLDNN);
......
...@@ -22,7 +22,7 @@ limitations under the License. */ ...@@ -22,7 +22,7 @@ limitations under the License. */
namespace fw = paddle::framework; namespace fw = paddle::framework;
namespace plat = paddle::platform; namespace plat = paddle::platform;
USE_OP(relu); USE_OP_ITSELF(relu);
USE_OP_DEVICE_KERNEL(relu, MLU); USE_OP_DEVICE_KERNEL(relu, MLU);
// relu // relu
......
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/operators/common_infer_shape_functions.h" #include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
USE_OP(relu); USE_OP_ITSELF(relu);
USE_OP_ITSELF(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP_ITSELF(softmax); USE_OP_ITSELF(softmax);
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi {
#define DECLARE_ACTIVATION_GRAD_KERNEL_DepX(name) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& dout, \
DenseTensor* dx);
#define DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(name) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& out, \
const DenseTensor& dout, \
DenseTensor* dx);
template <typename T, typename Context>
void ReluDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
DenseTensor* ddout);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cos);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Tan);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acos);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Sin);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asin);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atan);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Sinh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cosh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asinh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acosh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atanh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Relu);
} // namespace phi
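For reference, the DepX declaration macro above expands mechanically into a plain kernel prototype; e.g. for Cos (the DepOut variant only swaps the x argument for out):

// Expansion of DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cos):
template <typename T, typename Context>
void CosGradKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& dout,
                   DenseTensor* dx);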
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi {
#define DECLARE_ACTIVATION_KERNEL(name) \
template <typename T, typename Context> \
void name##Kernel( \
const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
DECLARE_ACTIVATION_KERNEL(Cos)
DECLARE_ACTIVATION_KERNEL(Tan)
DECLARE_ACTIVATION_KERNEL(Acos)
DECLARE_ACTIVATION_KERNEL(Sin)
DECLARE_ACTIVATION_KERNEL(Asin)
DECLARE_ACTIVATION_KERNEL(Atan)
DECLARE_ACTIVATION_KERNEL(Sinh)
DECLARE_ACTIVATION_KERNEL(Cosh)
DECLARE_ACTIVATION_KERNEL(Asinh)
DECLARE_ACTIVATION_KERNEL(Acosh)
DECLARE_ACTIVATION_KERNEL(Atanh)
DECLARE_ACTIVATION_KERNEL(Relu)
} // namespace phi
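Similarly, the forward declaration macro expands to a plain kernel signature; e.g. DECLARE_ACTIVATION_KERNEL(Relu) yields:

template <typename T, typename Context>
void ReluKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);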
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/activation_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/activation_grad_impl.h"
namespace phi {
#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(name, functor_class) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& dout, \
DenseTensor* dx) { \
functor_class functor; \
ActivationGradImpl<T, Context, functor_class>( \
dev_ctx, &x, nullptr, &dout, dx, functor); \
}
#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(name, functor_class) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& out, \
const DenseTensor& dout, \
DenseTensor* dx) { \
functor_class functor; \
ActivationGradImpl<T, Context, functor_class>( \
dev_ctx, nullptr, &out, &dout, dx, functor); \
}
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, funcs::CosGradFunctor<T>);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, funcs::TanGradFunctor<T>);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, funcs::AcosGradFunctor<T>);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, funcs::SinGradFunctor<T>);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, funcs::AsinGradFunctor<T>);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, funcs::AtanGradFunctor<T>);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, funcs::SinhGradFunctor<T>);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, funcs::CoshGradFunctor<T>);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, funcs::AsinhGradFunctor<T>);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, funcs::AcoshGradFunctor<T>);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::AtanhGradFunctor<T>);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::ReluGradFunctor<T>);
} // namespace phi
PD_REGISTER_KERNEL(
cos_grad, CPU, ALL_LAYOUT, phi::CosGradKernel, float, double) {}
PD_REGISTER_KERNEL(
tan_grad, CPU, ALL_LAYOUT, phi::TanGradKernel, float, double) {}
PD_REGISTER_KERNEL(
acos_grad, CPU, ALL_LAYOUT, phi::AcosGradKernel, float, double) {}
PD_REGISTER_KERNEL(
sin_grad, CPU, ALL_LAYOUT, phi::SinGradKernel, float, double) {}
PD_REGISTER_KERNEL(
asin_grad, CPU, ALL_LAYOUT, phi::AsinGradKernel, float, double) {}
PD_REGISTER_KERNEL(
atan_grad, CPU, ALL_LAYOUT, phi::AtanGradKernel, float, double) {}
PD_REGISTER_KERNEL(
sinh_grad, CPU, ALL_LAYOUT, phi::SinhGradKernel, float, double) {}
PD_REGISTER_KERNEL(
cosh_grad, CPU, ALL_LAYOUT, phi::CoshGradKernel, float, double) {}
PD_REGISTER_KERNEL(
asinh_grad, CPU, ALL_LAYOUT, phi::AsinhGradKernel, float, double) {}
PD_REGISTER_KERNEL(
acosh_grad, CPU, ALL_LAYOUT, phi::AcoshGradKernel, float, double) {}
PD_REGISTER_KERNEL(
atanh_grad, CPU, ALL_LAYOUT, phi::AtanhGradKernel, float, double) {}
PD_REGISTER_KERNEL(
relu_grad, CPU, ALL_LAYOUT, phi::ReluGradKernel, float, double) {}
PD_REGISTER_KERNEL(relu_double_grad,
CPU,
ALL_LAYOUT,
phi::ReluDoubleGradKernel,
float,
double,
phi::dtype::float16) {}
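For reference, the DepOut definition macro above expands mechanically; e.g. for Relu it produces a kernel that passes the forward output (and a null x) to ActivationGradImpl:

// Expansion of DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::ReluGradFunctor<T>):
template <typename T, typename Context>
void ReluGradKernel(const Context& dev_ctx,
                    const DenseTensor& out,
                    const DenseTensor& dout,
                    DenseTensor* dx) {
  funcs::ReluGradFunctor<T> functor;
  ActivationGradImpl<T, Context, funcs::ReluGradFunctor<T>>(
      dev_ctx, nullptr, &out, &dout, dx, functor);
}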
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/activation_impl.h"
namespace phi {
#define DEFINE_CPU_ACTIVATION_KERNEL(name, functor_class) \
template <typename T, typename Context> \
void name##Kernel( \
const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
functor_class functor; \
ActivationImpl<T, Context, functor_class>(dev_ctx, x, out, functor); \
}
DEFINE_CPU_ACTIVATION_KERNEL(Sin, funcs::SinFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Cos, funcs::CosFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Tan, funcs::TanFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Asin, funcs::AsinFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Atan, funcs::AtanFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Acos, funcs::AcosFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Sinh, funcs::SinhFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Cosh, funcs::CoshFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Asinh, funcs::AsinhFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Acosh, funcs::AcoshFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Atanh, funcs::AtanhFunctor<T>)
DEFINE_CPU_ACTIVATION_KERNEL(Relu, funcs::ReluCPUFunctor<T>)
} // namespace phi
PD_REGISTER_KERNEL(sin, CPU, ALL_LAYOUT, phi::SinKernel, float, double) {}
PD_REGISTER_KERNEL(cos, CPU, ALL_LAYOUT, phi::CosKernel, float, double) {}
PD_REGISTER_KERNEL(tan, CPU, ALL_LAYOUT, phi::TanKernel, float, double) {}
PD_REGISTER_KERNEL(acos, CPU, ALL_LAYOUT, phi::AcosKernel, float, double) {}
PD_REGISTER_KERNEL(asin, CPU, ALL_LAYOUT, phi::AsinKernel, float, double) {}
PD_REGISTER_KERNEL(atan, CPU, ALL_LAYOUT, phi::AtanKernel, float, double) {}
PD_REGISTER_KERNEL(sinh, CPU, ALL_LAYOUT, phi::SinhKernel, float, double) {}
PD_REGISTER_KERNEL(cosh, CPU, ALL_LAYOUT, phi::CoshKernel, float, double) {}
PD_REGISTER_KERNEL(asinh, CPU, ALL_LAYOUT, phi::AsinhKernel, float, double) {}
PD_REGISTER_KERNEL(acosh, CPU, ALL_LAYOUT, phi::AcoshKernel, float, double) {}
PD_REGISTER_KERNEL(atanh, CPU, ALL_LAYOUT, phi::AtanhKernel, float, double) {}
PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {}
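The forward definition macro follows the same pattern; e.g. DEFINE_CPU_ACTIVATION_KERNEL(Sin, funcs::SinFunctor<T>) expands to:

template <typename T, typename Context>
void SinKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
  funcs::SinFunctor<T> functor;
  ActivationImpl<T, Context, funcs::SinFunctor<T>>(dev_ctx, x, out, functor);
}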
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include <cmath>
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include <type_traits>
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
namespace phi {
namespace funcs {
enum ActBwdOpFwdDeps {
kNoDeps = 0x00, // Do not need any forward input/output
kDepX = 0x01, // Only need forward input X
kDepOut = 0x02, // Only need forward output Out
};
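The flag returned by each functor's FwdDeps() tells the launching kernel which forward tensors the backward pass needs; a minimal sketch of that check, mirroring the ActivationGradCudaKernel dispatch shown earlier in this diff:

// Sketch: choose backward inputs based on the functor's declared dependency.
if (static_cast<int>(Functor::FwdDeps()) ==
    static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
  ins.push_back(out);  // only the forward output Out is needed
} else if (static_cast<int>(Functor::FwdDeps()) ==
           static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
  ins.push_back(x);    // only the forward input X is needed
}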
template <typename T>
struct BaseActivationFunctor {
using ELEMENT_TYPE = T;
using AttrPair = std::vector<std::pair<const char*, float*>>;
AttrPair GetAttrs() { return AttrPair(); }
};
template <typename T>
struct Sine {
HOSTDEVICE T operator()(const T& val) const { return sin(val); }
};
template <>
struct Sine<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(sin(static_cast<float>(val)));
}
};
template <typename T>
struct Cosine {
HOSTDEVICE T operator()(const T& val) const { return cos(val); }
};
template <>
struct Cosine<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(cos(static_cast<float>(val)));
}
};
// sine'(x) = cos(x)
template <typename T>
struct SinGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.unaryExpr(Cosine<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// sine(x) = sin(x)
template <typename T>
struct SinFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Sine<T>());
}
};
// cosine'(x) = -sin(x)
template <typename T>
struct CosGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = -dout * x.unaryExpr(Sine<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// cosine(x) = cos(x)
template <typename T>
struct CosFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Cosine<T>());
}
};
template <typename T>
struct Tangent {
HOSTDEVICE T operator()(const T& val) const { return tan(val); }
};
template <>
struct Tangent<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(tan(static_cast<float>(val)));
}
};
// Tangent'(x) = 1 / Cosine(x)^2
template <typename T>
struct TanGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout / x.unaryExpr(Cosine<T>()).square();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Tangent(x) = tan(x)
template <typename T>
struct TanFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Tangent<T>());
}
};
template <typename T>
struct Sinh {
HOSTDEVICE T operator()(const T& val) const { return sinh(val); }
};
template <>
struct Sinh<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(sinhf(static_cast<float>(val)));
}
};
template <typename T>
struct Cosh {
HOSTDEVICE T operator()(const T& val) const { return cosh(val); }
};
template <>
struct Cosh<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(coshf(static_cast<float>(val)));
}
};
// sinh(x) = sinh(x)
template <typename T>
struct SinhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Sinh<T>());
}
};
// cosh(x) = cosh(x)
template <typename T>
struct CoshFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Cosh<T>());
}
};
// sinh'(x) = cosh(x)
template <typename T>
struct SinhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.unaryExpr(Cosh<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// cosh'(x) = sinh(x)
template <typename T>
struct CoshGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.unaryExpr(Sinh<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Acos {
HOSTDEVICE T operator()(const T& val) const { return acos(val); }
};
template <>
struct Acos<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(acos(static_cast<float>(val)));
}
};
// Acos(x) = acos(x)
template <typename T>
struct AcosFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Acos<T>());
}
};
// acos'(x) = -1/sqrt(1-x^2)
template <typename T>
struct AcosGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
-dout * static_cast<T>(1) / (static_cast<T>(1) - x.square()).sqrt();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Asin {
HOSTDEVICE T operator()(const T& val) const { return asin(val); }
};
template <>
struct Asin<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(asin(static_cast<float>(val)));
}
};
// Asin(x) = asin(x)
template <typename T>
struct AsinFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Asin<T>());
}
};
// asin'(x) = 1/sqrt(1-x^2)
template <typename T>
struct AsinGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
dout * static_cast<T>(1) / (static_cast<T>(1) - x.square()).sqrt();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Atan {
HOSTDEVICE T operator()(const T& val) const { return atan(val); }
};
template <>
struct Atan<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(atan(static_cast<float>(val)));
}
};
// Atan(x) = atan(x)
template <typename T>
struct AtanFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Atan<T>());
}
};
// atan'(x) = 1 / (1 + x^2)
template <typename T>
struct AtanGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(1) / (static_cast<T>(1) + x.square());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Acosh {
HOSTDEVICE T operator()(const T& val) const { return acosh(val); }
};
template <>
struct Acosh<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(acosh(static_cast<float>(val)));
}
};
// Acosh(x) = acosh(x)
template <typename T>
struct AcoshFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Acosh<T>());
}
};
// acosh'(x) = 1/sqrt(x^2 - 1)
template <typename T>
struct AcoshGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
dout * static_cast<T>(1) / (x * x - static_cast<T>(1)).sqrt();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Asinh {
HOSTDEVICE T operator()(const T& val) const { return asinh(val); }
};
template <>
struct Asinh<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(asinh(static_cast<float>(val)));
}
};
// Asinh(x) = asinh(x)
template <typename T>
struct AsinhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Asinh<T>());
}
};
// asinh'(x) = 1/sqrt(x^2 + 1)
template <typename T>
struct AsinhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
dout * static_cast<T>(1) / (x.square() + static_cast<T>(1)).sqrt();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct Atanh {
HOSTDEVICE T operator()(const T& val) const { return atanh(val); }
};
template <>
struct Atanh<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const {
return dtype::float16(atanh(static_cast<float>(val)));
}
};
// Atanh(x) = atanh(x)
template <typename T>
struct AtanhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr(Atanh<T>());
}
};
// atanh'(x) = 1/(1 - x^2)
template <typename T>
struct AtanhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(1) / (static_cast<T>(1) - x.square());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// relu(x) = max(x, 0)
template <typename T>
struct ReluCPUFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.unaryExpr([] HOSTDEVICE(T v) {
return v > static_cast<T>(0) ? v : static_cast<T>(0);
});
}
};
template <typename T>
struct ReluCUDAFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.cwiseMax(static_cast<T>(0));
}
};
template <typename T>
struct ReluGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (out > static_cast<T>(0)).template cast<T>();
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct ReluGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev,
const DenseTensor* X,
const DenseTensor* Out,
const DenseTensor* ddX,
DenseTensor* ddOut,
DenseTensor* dOut,
DenseTensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "ReluGradGrad"));
auto out = EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "ReluGradGrad"));
if (ddOut) {
auto ddout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ReluGradGrad"));
ddout.device(*d) = ddx * (out > static_cast<T>(0)).template cast<T>();
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
#if defined(__NVCC__) || defined(__HIPCC__)
template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
// relu(x) = max(x, 0)
__device__ __forceinline__ T operator()(const T x) const {
return x > zero ? x : zero;
}
};
template <typename T>
struct CudaReluGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
// dx = dout * (out > 0)
__device__ __forceinline__ T operator()(const T dout, const T out) const {
return out > zero ? dout : zero;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudaCosFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// cos(x) = cos(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(cos(x));
}
};
template <typename T>
struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// dx = dout * (-sin(x))
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(-dout * sin(x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSinFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// sin(x) = sin(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(sin(x));
}
};
template <typename T>
struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// dx = dout * cos(x)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * cos(x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaTanFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// tan(x) = tan(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(tan(x));
}
};
template <typename T>
struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// dx = dout / cos(x)^2
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout / (cos(x) * cos(x)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaAsinFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// asin(x) = asin(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(asin(x));
}
};
template <typename T>
struct CudaAsinGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// dx = dout / sqrt(1 - x^2)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout / sqrt(one - x * x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaAcosFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// acos(x) = acos(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(acos(x));
}
};
template <typename T>
struct CudaAcosGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// dx = -dout / sqrt(1 - x^2)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(-dout / sqrt(one - x * x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaCoshFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// cosh(x) = cosh(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(cosh(x));
}
};
template <typename T>
struct CudaCoshGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// dx = dout * sinh(x)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * sinh(x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSinhFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// sinh(x) = sinh(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(sinh(x));
}
};
template <typename T>
struct CudaSinhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// dx = dout * cosh(x)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * cosh(x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaAcoshFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  // acosh(x) = acosh(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(acosh(x));
}
};
template <typename T>
struct CudaAcoshGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// dx = dout * 1 / sqrt(x^2 - 1)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * one / sqrt(x * x - one));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaAsinhFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  // asinh(x) = asinh(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(asinh(x));
}
};
template <typename T>
struct CudaAsinhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
// dx = dout * 1/sqrt(x^2 + 1)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * one / sqrt(x * x + one));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaAtanhFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  // atanh(x) = atanh(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(atanh(x));
}
};
template <typename T>
struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
  // dx = dout * 1 / (1 - x^2)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(dout * one / (one - x * x));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaAtanFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// atan(x) = atan(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(atan(x));
}
};
template <typename T>
struct CudaAtanGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// dx = dout / (1 + x^2)
__device__ __forceinline__ T operator()(const T dout, const T x) const {
return dout / (one + x * x);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
#endif
} // namespace funcs
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/activation_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/activation_grad_impl.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
namespace phi {
template <typename T, typename Context, typename Functor>
void ActivationGradGPUImpl(const Context& dev_ctx,
const DenseTensor* x,
const DenseTensor* out,
const DenseTensor* d_out,
DenseTensor* d_x,
const Functor& functor) {
if (static_cast<int>(Functor::FwdDeps()) &
static_cast<int>(funcs::ActBwdOpFwdDeps::kDepOut)) {
PADDLE_ENFORCE_NOT_NULL(
out, errors::NotFound("The input DenseTensor Out can not be nullptr"));
}
PADDLE_ENFORCE_NOT_NULL(
d_out, errors::NotFound("The input DenseTensor dOut can not be nullptr"));
PADDLE_ENFORCE_NOT_NULL(
d_x, errors::NotFound("The output DenseTensor dX can not be nullptr"));
if (!out) {
out = d_out; // fake out
}
if (static_cast<int>(Functor::FwdDeps()) &
static_cast<int>(funcs::ActBwdOpFwdDeps::kDepX)) {
PADDLE_ENFORCE_NOT_NULL(
x, errors::NotFound("The input DenseTensor X can not be nullptr"));
} else {
VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name();
x = d_x;
}
dev_ctx.template Alloc<T>(d_x);
std::vector<const DenseTensor*> ins = {d_out};
std::vector<DenseTensor*> outs = {d_x};
if (static_cast<int>(Functor::FwdDeps()) ==
static_cast<int>(funcs::ActBwdOpFwdDeps::kDepOut)) {
// Only need forward output Out
ins.push_back(out);
funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
} else if (static_cast<int>(Functor::FwdDeps()) ==
static_cast<int>(funcs::ActBwdOpFwdDeps::kDepX)) {
// Only need forward input X
ins.push_back(x);
funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
} else {
funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
}
}
#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(name, functor_class) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& dout, \
DenseTensor* dx) { \
functor_class functor; \
ActivationGradGPUImpl<T, Context, functor_class>( \
dev_ctx, &x, nullptr, &dout, dx, functor); \
}
#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(name, functor_class) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& out, \
const DenseTensor& dout, \
DenseTensor* dx) { \
functor_class functor; \
ActivationGradGPUImpl<T, Context, functor_class>( \
dev_ctx, nullptr, &out, &dout, dx, functor); \
}
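// For reference, an invocation such as
//   DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, funcs::CudaCosGradFunctor<T>);
// expands (up to formatting) to a kernel that forwards only X and passes
// nullptr for Out:
//
//   template <typename T, typename Context>
//   void CosGradKernel(const Context& dev_ctx,
//                      const DenseTensor& x,
//                      const DenseTensor& dout,
//                      DenseTensor* dx) {
//     funcs::CudaCosGradFunctor<T> functor;
//     ActivationGradGPUImpl<T, Context, funcs::CudaCosGradFunctor<T>>(
//         dev_ctx, &x, nullptr, &dout, dx, functor);
//   }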
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::CudaReluGradFunctor<T>);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, funcs::CudaCosGradFunctor<T>);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, funcs::CudaTanGradFunctor<T>);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, funcs::CudaAcosGradFunctor<T>);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, funcs::CudaSinGradFunctor<T>);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, funcs::CudaAsinGradFunctor<T>);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, funcs::CudaAtanGradFunctor<T>);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, funcs::CudaSinhGradFunctor<T>);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, funcs::CudaCoshGradFunctor<T>);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, funcs::CudaAsinhGradFunctor<T>);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, funcs::CudaAcoshGradFunctor<T>);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::CudaAtanhGradFunctor<T>);
} // namespace phi
PD_REGISTER_KERNEL(cos_grad,
GPU,
ALL_LAYOUT,
phi::CosGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(tan_grad,
GPU,
ALL_LAYOUT,
phi::TanGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(acos_grad,
GPU,
ALL_LAYOUT,
phi::AcosGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(sin_grad,
GPU,
ALL_LAYOUT,
phi::SinGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(asin_grad,
GPU,
ALL_LAYOUT,
phi::AsinGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(atan_grad,
GPU,
ALL_LAYOUT,
phi::AtanGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(sinh_grad,
GPU,
ALL_LAYOUT,
phi::SinhGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(cosh_grad,
GPU,
ALL_LAYOUT,
phi::CoshGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(asinh_grad,
GPU,
ALL_LAYOUT,
phi::AsinhGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(acosh_grad,
GPU,
ALL_LAYOUT,
phi::AcoshGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(atanh_grad,
GPU,
ALL_LAYOUT,
phi::AtanhGradKernel,
float,
double,
phi::dtype::float16) {}
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(relu_grad,
GPU,
ALL_LAYOUT,
phi::ReluGradKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(relu_double_grad,
GPU,
ALL_LAYOUT,
phi::ReluDoubleGradKernel,
float,
double,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(relu_grad,
GPU,
ALL_LAYOUT,
phi::ReluGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(relu_double_grad,
GPU,
ALL_LAYOUT,
phi::ReluDoubleGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#endif
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/impl/activation_grad_impl.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
namespace phi {
template <typename T, typename Context, typename Functor>
void ActivationGPUImpl(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out,
const Functor& functor) {
PADDLE_ENFORCE_NOT_NULL(out,
errors::NotFound("Output Out should not be nullptr"));
dev_ctx.template Alloc<T>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
}
#define DEFINE_GPU_ACTIVATION_KERNEL(name, functor_class) \
template <typename T, typename Context> \
void name##Kernel( \
const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
functor_class functor; \
ActivationGPUImpl<T, Context, functor_class>(dev_ctx, x, out, functor); \
}
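// For reference, DEFINE_GPU_ACTIVATION_KERNEL(Cos, funcs::CudaCosFunctor<T>)
// expands (up to formatting) to:
//
//   template <typename T, typename Context>
//   void CosKernel(
//       const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
//     funcs::CudaCosFunctor<T> functor;
//     ActivationGPUImpl<T, Context, funcs::CudaCosFunctor<T>>(
//         dev_ctx, x, out, functor);
//   }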
DEFINE_GPU_ACTIVATION_KERNEL(Cos, funcs::CudaCosFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Tan, funcs::CudaTanFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Acos, funcs::CudaAcosFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Sin, funcs::CudaSinFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Asin, funcs::CudaAsinFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Atan, funcs::CudaAtanFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Sinh, funcs::CudaSinhFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Cosh, funcs::CudaCoshFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Asinh, funcs::CudaAsinhFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Acosh, funcs::CudaAcoshFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Atanh, funcs::CudaAtanhFunctor<T>)
DEFINE_GPU_ACTIVATION_KERNEL(Relu, funcs::CudaReluFunctor<T>)
} // namespace phi
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(relu,
GPU,
ALL_LAYOUT,
phi::ReluKernel,
float,
double,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(relu,
GPU,
ALL_LAYOUT,
phi::ReluKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#endif
PD_REGISTER_KERNEL(
sin, GPU, ALL_LAYOUT, phi::SinKernel, float, double, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
cos, GPU, ALL_LAYOUT, phi::CosKernel, float, double, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
tan, GPU, ALL_LAYOUT, phi::TanKernel, float, double, phi::dtype::float16) {}
PD_REGISTER_KERNEL(acos,
GPU,
ALL_LAYOUT,
phi::AcosKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(asin,
GPU,
ALL_LAYOUT,
phi::AsinKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(atan,
GPU,
ALL_LAYOUT,
phi::AtanKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(sinh,
GPU,
ALL_LAYOUT,
phi::SinhKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(cosh,
GPU,
ALL_LAYOUT,
phi::CoshKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(asinh,
GPU,
ALL_LAYOUT,
phi::AsinhKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(acosh,
GPU,
ALL_LAYOUT,
phi::AcoshKernel,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(atanh,
GPU,
ALL_LAYOUT,
phi::AtanhKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/fluid/platform/device_context.h"
namespace phi {
template <typename T, typename Context, typename Functor>
void ActivationGradImpl(const Context& dev_ctx,
const DenseTensor* X,
const DenseTensor* Out,
const DenseTensor* dOut,
DenseTensor* dX,
const Functor& functor) {
if (static_cast<int>(Functor::FwdDeps()) &
static_cast<int>(funcs::ActBwdOpFwdDeps::kDepOut)) {
PADDLE_ENFORCE_NOT_NULL(
Out, errors::NotFound("The input DenseTensor Out can not be nullptr"));
}
PADDLE_ENFORCE_NOT_NULL(
dOut, errors::NotFound("The input DenseTensor dOut can not be nullptr"));
PADDLE_ENFORCE_NOT_NULL(
dX, errors::NotFound("The output DenseTensor dX can not be nullptr"));
if (!Out) {
Out = dOut; // fake out
}
if (static_cast<int>(Functor::FwdDeps()) &
static_cast<int>(funcs::ActBwdOpFwdDeps::kDepX)) {
PADDLE_ENFORCE_NOT_NULL(
X, errors::NotFound("The input DenseTensor X can not be nullptr"));
} else {
VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name();
X = dX;
}
dev_ctx.template Alloc<T>(dX);
auto dout = phi::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad"));
auto out = phi::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad"));
auto dx = phi::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad"));
auto x = phi::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad"));
auto* place = dev_ctx.eigen_device();
// use 32bit index to speed up computation
bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace());
if (use_32bit_index && is_gpu_place) {
functor(*place,
To32BitIndex(x),
To32BitIndex(out),
To32BitIndex(dout),
To32BitIndex(dx));
} else {
functor(*place, x, out, dout, dx);
}
}
template <typename T, typename Context, typename Functor>
void ActivationDoubleGradImpl(const Context& dev_ctx,
const DenseTensor* X,
const DenseTensor* Out,
const DenseTensor* ddX,
DenseTensor* dX,
DenseTensor* dOut,
DenseTensor* ddOut,
const Functor& functor) {
if (static_cast<int>(Functor::FwdDeps()) &
static_cast<int>(funcs::ActBwdOpFwdDeps::kDepX)) {
PADDLE_ENFORCE_NOT_NULL(
X, errors::NotFound("The input DenseTensor X can not be nullptr"));
} else {
VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name();
X = ddX;
}
if (static_cast<int>(Functor::FwdDeps()) &
static_cast<int>(funcs::ActBwdOpFwdDeps::kDepOut)) {
PADDLE_ENFORCE_NOT_NULL(
Out, errors::NotFound("The input DenseTensor Out can not be nullptr"));
} else {
VLOG(10) << "Inplace activation of Op Functor: " << typeid(Functor).name();
Out = ddX;
}
if (ddOut) {
dev_ctx.template Alloc<T>(ddOut);
}
if (dOut) {
dev_ctx.template Alloc<T>(dOut);
}
if (dX) {
dX->Resize(Out->dims());
dev_ctx.template Alloc<T>(dX);
}
functor(dev_ctx, X, Out, ddX, ddOut, dOut, dX);
}
template <typename T, typename Context>
void ReluDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
DenseTensor* ddout) {
funcs::ReluGradGradFunctor<T> relu_double_grad_functor;
ActivationDoubleGradImpl<T, Context, funcs::ReluGradGradFunctor<T>>(
dev_ctx,
nullptr,
&out,
&ddx,
nullptr,
nullptr,
ddout,
relu_double_grad_functor);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/fluid/platform/device_context.h"
namespace phi {
#define ToString(x) #x
template <typename T, typename Context, typename Functor>
void ActivationImpl(const Context& dev_ctx,
const DenseTensor& X,
DenseTensor* Out,
const Functor& functor) {
PADDLE_ENFORCE_NOT_NULL(Out,
errors::NotFound("Output Out should not be nullptr"));
dev_ctx.template Alloc<T>(Out);
auto x = phi::EigenVector<T>::Flatten(
GET_DATA_SAFELY(&X, "Input", "X", "Activation"));
auto out = phi::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "Activation"));
auto* place = dev_ctx.eigen_device();
// use 32bit index to speed up computation
bool use_32bit_index = out.size() < Eigen::NumTraits<int>::highest();
bool is_gpu_place = paddle::platform::is_gpu_place(dev_ctx.GetPlace());
if (use_32bit_index && is_gpu_place) {
functor(*place, To32BitIndex(x), To32BitIndex(out));
} else {
functor(*place, x, out);
}
}
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
#define DefineActGradDepXOpArgMap(func_name, op_name) \
KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \
return KernelSignature( \
op_name "_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")}); \
}
#define DefineActGradDepOutOpArgMap(func_name, op_name) \
KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \
return KernelSignature( \
op_name "_grad", {"Out", GradVarName("Out")}, {}, {GradVarName("X")}); \
}
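// For reference, DefineActGradDepXOpArgMap(Cos, "cos") expands (up to
// formatting) to the mapping below, which routes the fluid cos_grad op's
// "X" and "Out@GRAD" inputs into the phi "cos_grad" kernel and its "X@GRAD"
// output:
//
//   KernelSignature CosGradOpArgumentMapping(
//       const ArgumentMappingContext& ctx) {
//     return KernelSignature(
//         "cos_grad", {"X", GradVarName("Out")}, {}, {GradVarName("X")});
//   }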
KernelSignature ReluDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("relu_double_grad", {"Out", "DDX"}, {}, {"DDOut"});
}
DefineActGradDepXOpArgMap(Cos, "cos");
DefineActGradDepXOpArgMap(Tan, "tan");
DefineActGradDepXOpArgMap(Acos, "acos");
DefineActGradDepXOpArgMap(Sin, "sin");
DefineActGradDepXOpArgMap(Asin, "asin");
DefineActGradDepXOpArgMap(Atan, "atan");
DefineActGradDepXOpArgMap(Sinh, "sinh");
DefineActGradDepXOpArgMap(Cosh, "cosh");
DefineActGradDepXOpArgMap(Asinh, "asinh");
DefineActGradDepXOpArgMap(Acosh, "acosh");
DefineActGradDepXOpArgMap(Atanh, "atanh");
DefineActGradDepOutOpArgMap(Relu, "relu");
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(relu_grad_grad, relu_double_grad);
PD_REGISTER_ARG_MAPPING_FN(cos_grad, phi::CosGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tan_grad, phi::TanGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(acos_grad, phi::AcosGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sin_grad, phi::SinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(asin_grad, phi::AsinGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(atan_grad, phi::AtanGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sinh_grad, phi::SinhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(cosh_grad, phi::CoshGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(asinh_grad, phi::AsinhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(acosh_grad, phi::AcoshGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(atanh_grad, phi::AtanhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(relu_grad, phi::ReluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(relu_grad_grad,
phi::ReluDoubleGradOpArgumentMapping);