提交 337b7ebe 编写于 作者: Y Yu Yang

Unify Activation functions and simplify register code

上级 184768e0
...@@ -195,111 +195,54 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -195,111 +195,54 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad, REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::SigmoidFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
sigmoid_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SigmoidGradFunctor<float>>);
REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad, REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
exp,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::ExpFunctor>);
REGISTER_OP_CPU_KERNEL(exp_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace,
float, ops::ExpGradFunctor>);
REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad, REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(relu,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::ReluFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::ReluGradFunctor<float>>);
REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad, REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
tanh,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::TanhFunctor>);
REGISTER_OP_CPU_KERNEL(
tanh_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::TanhGradFunctor<float>>);
REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad, REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
sqrt,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::SqrtFunctor>);
REGISTER_OP_CPU_KERNEL(
sqrt_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SqrtGradFunctor<float>>);
REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad, REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
abs,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::AbsFunctor>);
REGISTER_OP_CPU_KERNEL(abs_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace,
float, ops::AbsGradFunctor>);
REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker, REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker,
reciprocal_grad, ops::ActivationOpGrad); reciprocal_grad, ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(reciprocal,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::ReciprocalFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
reciprocal_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::ReciprocalGradFunctor<float>>);
REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad, REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
log,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::LogFunctor>);
REGISTER_OP_CPU_KERNEL(
log_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::LogGradFunctor<float>>);
REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad, REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::SquareFunctor>);
REGISTER_OP_CPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SquareGradFunctor<float>>);
REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker<float>, brelu_grad, REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker<float>, brelu_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(brelu,
ops::BReluKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(brelu_grad,
ops::BReluGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>, REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>,
soft_relu_grad, ops::ActivationOpGrad); soft_relu_grad, ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(soft_relu,
ops::SoftReluKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
soft_relu_grad, ops::SoftReluGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker<float>, pow_grad, REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker<float>, pow_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(pow, ops::PowKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(pow_grad,
ops::PowGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad, REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(stanh,
ops::STanhKernel<paddle::platform::CPUPlace, float>); #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL(stanh_grad, REGISTER_OP_CPU_KERNEL( \
ops::STanhGradKernel<paddle::platform::CPUPlace, float>); act_type, \
paddle::operators::ActivationKernel<paddle::platform::CPUPlace, \
paddle::operators::functor<float>>); \
REGISTER_OP_CPU_KERNEL(act_type##_grad, \
paddle::operators::ActivationGradKernel< \
paddle::platform::CPUPlace, \
paddle::operators::grad_functor<float>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
...@@ -15,86 +15,14 @@ ...@@ -15,86 +15,14 @@
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/activation_op.h" #include "paddle/operators/activation_op.h"
namespace ops = paddle::operators; #define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_GPU_KERNEL( \
REGISTER_OP_GPU_KERNEL(sigmoid, act_type, \
ops::ActivationKernel<paddle::platform::GPUPlace, float, paddle::operators::ActivationKernel<paddle::platform::GPUPlace, \
ops::SigmoidFunctor<float>>); paddle::operators::functor<float>>); \
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(act_type##_grad, \
sigmoid_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float, paddle::operators::ActivationGradKernel< \
ops::SigmoidGradFunctor<float>>); paddle::platform::GPUPlace, \
paddle::operators::grad_functor<float>>);
REGISTER_OP_GPU_KERNEL(
exp, FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_GPU_KERNEL);
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::ExpFunctor>);
REGISTER_OP_GPU_KERNEL(exp_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace,
float, ops::ExpGradFunctor>);
REGISTER_OP_GPU_KERNEL(relu,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::ReluFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::ReluGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
tanh,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::TanhFunctor>);
REGISTER_OP_GPU_KERNEL(
tanh_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::TanhGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
sqrt,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::SqrtFunctor>);
REGISTER_OP_GPU_KERNEL(
sqrt_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SqrtGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
abs,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::AbsFunctor>);
REGISTER_OP_GPU_KERNEL(abs_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace,
float, ops::AbsGradFunctor>);
REGISTER_OP_GPU_KERNEL(reciprocal,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::ReciprocalFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
reciprocal_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::ReciprocalGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
log,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::LogFunctor>);
REGISTER_OP_GPU_KERNEL(
log_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::LogGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::SquareFunctor>);
REGISTER_OP_GPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SquareGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(brelu,
ops::BReluKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(brelu_grad,
ops::BReluGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(soft_relu,
ops::SoftReluKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
soft_relu_grad, ops::SoftReluGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pow, ops::PowKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pow_grad,
ops::PowGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(stanh,
ops::STanhKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(stanh_grad,
ops::STanhGradKernel<paddle::platform::GPUPlace, float>);
...@@ -19,9 +19,12 @@ ...@@ -19,9 +19,12 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T, typename Functor> template <typename Place, typename Functor>
class ActivationKernel : public framework::OpKernel<T> { class ActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public: public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y"); auto* Y = context.Output<framework::Tensor>("Y");
...@@ -31,13 +34,20 @@ class ActivationKernel : public framework::OpKernel<T> { ...@@ -31,13 +34,20 @@ class ActivationKernel : public framework::OpKernel<T> {
auto y = framework::EigenVector<T>::Flatten(*Y); auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
Functor functor; Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(place, x, y); functor(place, x, y);
} }
}; };
template <typename Place, typename T, typename Functor> template <typename Place, typename Functor>
class ActivationGradKernel : public framework::OpKernel<T> { class ActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public: public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y"); auto* Y = context.Input<framework::Tensor>("Y");
...@@ -51,303 +61,301 @@ class ActivationGradKernel : public framework::OpKernel<T> { ...@@ -51,303 +61,301 @@ class ActivationGradKernel : public framework::OpKernel<T> {
auto dx = framework::EigenVector<T>::Flatten(*dX); auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
Functor functor; Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(place, x, y, dy, dx); functor(place, x, y, dy, dx);
} }
}; };
template <typename T>
struct BaseActivationFunctor {
using ELEMENT_TYPE = T;
using AttrPair = std::vector<std::pair<const char*, float*>>;
AttrPair GetAttrs() { return AttrPair(); }
};
// sigmoid(x) = 1 / (1 + exp(-x)) // sigmoid(x) = 1 / (1 + exp(-x))
template <typename T> template <typename T>
struct SigmoidFunctor { struct SigmoidFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp()); y.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
} }
}; };
template <typename T> template <typename T>
struct SigmoidGradFunctor { struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * y * (static_cast<T>(1) - y); dx.device(d) = dy * y * (static_cast<T>(1) - y);
} }
}; };
// exp(x) = e^x // exp(x) = e^x
struct ExpFunctor { template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = x.exp(); y.device(d) = x.exp();
} }
}; };
struct ExpGradFunctor { template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * y; dx.device(d) = dy * y;
} }
}; };
// relu(x) = max(x, 0) // relu(x) = max(x, 0)
template <typename T> template <typename T>
struct ReluFunctor { struct ReluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(static_cast<T>(0)); y.device(d) = x.cwiseMax(static_cast<T>(0));
} }
}; };
template <typename T> template <typename T>
struct ReluGradFunctor { struct ReluGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>(); dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
} }
}; };
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
struct TanhFunctor { template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = x.tanh(); y.device(d) = x.tanh();
} }
}; };
template <typename T> template <typename T>
struct TanhGradFunctor { struct TanhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (static_cast<T>(1) - y * y); dx.device(d) = dy * (static_cast<T>(1) - y * y);
} }
}; };
// sqrt(x) = x^(1/2) // sqrt(x) = x^(1/2)
struct SqrtFunctor { template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = x.sqrt(); y.device(d) = x.sqrt();
} }
}; };
template <typename T> template <typename T>
struct SqrtGradFunctor { struct SqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
const Y y_conj = Eigen::numext::conj(y); const Y y_conj = Eigen::numext::conj(y);
dx.device(d) = static_cast<T>(0.5) * dy / y_conj; dx.device(d) = static_cast<T>(0.5) * dy / y_conj;
} }
}; };
// abs(x) = |x| // abs(x) = |x|
struct AbsFunctor { template <typename T>
struct AbsFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = x.abs(); y.device(d) = x.abs();
} }
}; };
struct AbsGradFunctor { template <typename T>
struct AbsGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * x.sign(); dx.device(d) = dy * x.sign();
} }
}; };
// reciprocal(x) = 1 / x // reciprocal(x) = 1 / x
template <typename T> template <typename T>
struct ReciprocalFunctor { struct ReciprocalFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = static_cast<T>(1) / x; y.device(d) = static_cast<T>(1) / x;
} }
}; };
template <typename T> template <typename T>
struct ReciprocalGradFunctor { struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * static_cast<T>(-1) * y * y; dx.device(d) = dy * static_cast<T>(-1) * y * y;
} }
}; };
// log(x) = natural logarithm of x // log(x) = natural logarithm of x
struct LogFunctor { template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = x.log(); y.device(d) = x.log();
} }
}; };
template <typename T> template <typename T>
struct LogGradFunctor { struct LogGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (static_cast<T>(1) / x); dx.device(d) = dy * (static_cast<T>(1) / x);
} }
}; };
// square(x) = x^2 // square(x) = x^2
struct SquareFunctor { template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = x.square(); y.device(d) = x.square();
} }
}; };
template <typename T> template <typename T>
struct SquareGradFunctor { struct SquareGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * static_cast<T>(2) * x; dx.device(d) = dy * static_cast<T>(2) * x;
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class BReluKernel : public framework::OpKernel<T> { struct BReluFunctor : public BaseActivationFunctor<T> {
public: float t_min;
void Compute(const framework::ExecutionContext& context) const override { float t_max;
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y"); // NOTE: Explicit hides the `BaseActivationFunctor<T>::GetAttrs`
auto t_min = static_cast<T>(context.Attr<AttrType>("t_min")); // not polymorphism for speed.
auto t_max = static_cast<T>(context.Attr<AttrType>("t_max")); typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
Y->mutable_data<T>(context.GetPlace()); return {{"t_min", &t_min}, {"t_max", &t_max}};
}
auto x = framework::EigenVector<T>::Flatten(*X); template <typename Device, typename X, typename Y>
auto y = framework::EigenVector<T>::Flatten(*Y); void operator()(Device d, X x, Y y) const {
auto place = context.GetEigenDevice<Place>(); y.device(d) = x.cwiseMax(t_min).cwiseMin(t_max);
y.device(place) = x.cwiseMax(t_min).cwiseMin(t_max);
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class BReluGradKernel : public framework::OpKernel<T> { struct BReluGradFunctor : public BaseActivationFunctor<T> {
public: float t_min;
void Compute(const framework::ExecutionContext& context) const override { float t_max;
auto* X = context.Input<framework::Tensor>("X"); typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y")); return {{"t_min", &t_min}, {"t_max", &t_max}};
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X")); }
auto t_min = static_cast<T>(context.Attr<AttrType>("t_min")); template <typename Device, typename X, typename Y, typename dY, typename dX>
auto t_max = static_cast<T>(context.Attr<AttrType>("t_max")); void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dX->mutable_data<T>(context.GetPlace()); dx.device(d) = dy * ((x > t_min) * (x < t_max)).template cast<T>();
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
dx.device(place) = dy * ((x > t_min) * (x < t_max)).template cast<T>();
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class SoftReluKernel : public framework::OpKernel<T> { struct SoftReluFunctor : public BaseActivationFunctor<T> {
public: float threshold;
void Compute(const framework::ExecutionContext& context) const override { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
auto* X = context.Input<framework::Tensor>("X"); return {{"threshold", &threshold}};
auto* Y = context.Output<framework::Tensor>("Y"); }
auto threshold = static_cast<T>(context.Attr<AttrType>("threshold"));
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X); template <typename Device, typename X, typename Y>
auto y = framework::EigenVector<T>::Flatten(*Y); void operator()(Device d, X x, Y y) const {
auto place = context.GetEigenDevice<Place>(); auto temp = x.cwiseMax(-threshold).cwiseMin(threshold);
auto temp = x.cwiseMax(-threshold).cwiseMin(threshold).eval(); y.device(d) = (static_cast<T>(1) + temp.exp()).log();
y.device(place) = (static_cast<T>(1) + temp.exp()).log();
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class SoftReluGradKernel : public framework::OpKernel<T> { struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
public: float threshold;
void Compute(const framework::ExecutionContext& context) const override { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
auto* X = context.Input<framework::Tensor>("X"); return {{"threshold", &threshold}};
auto* Y = context.Input<framework::Tensor>("Y"); }
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y")); template <typename Device, typename X, typename Y, typename dY, typename dX>
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X")); void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto threshold = static_cast<T>(context.Attr<AttrType>("threshold"));
dX->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
auto temp = ((x > -threshold) * (x < threshold)).template cast<T>().eval(); auto temp = ((x > -threshold) * (x < threshold)).template cast<T>().eval();
dx.device(place) = dy * (static_cast<T>(1) - (-y).exp()) * temp; dx.device(d) = dy * (static_cast<T>(1) - (-y).exp()) * temp;
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class PowKernel : public framework::OpKernel<T> { struct PowFunctor : public BaseActivationFunctor<T> {
public: float factor;
void Compute(const framework::ExecutionContext& context) const override { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
auto* X = context.Input<framework::Tensor>("X"); return {{"factor", &factor}};
auto* Y = context.Output<framework::Tensor>("Y"); }
auto factor = static_cast<T>(context.Attr<AttrType>("factor")); template <typename Device, typename X, typename Y>
Y->mutable_data<T>(context.GetPlace()); void operator()(Device d, X x, Y y) const {
y.device(d) = x.pow(factor);
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
y.device(place) = x.pow(factor);
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class PowGradKernel : public framework::OpKernel<T> { struct PowGradFunctor : public BaseActivationFunctor<T> {
public: float factor;
void Compute(const framework::ExecutionContext& context) const override { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
auto* X = context.Input<framework::Tensor>("X"); return {{"factor", &factor}};
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y")); }
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X")); template <typename Device, typename X, typename Y, typename dY, typename dX>
auto factor = static_cast<T>(context.Attr<AttrType>("factor")); void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dX->mutable_data<T>(context.GetPlace()); dx.device(d) = dy * factor * x.pow(factor - static_cast<T>(1));
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
dx.device(place) = dy * factor * x.pow(factor - static_cast<T>(1));
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class STanhKernel : public framework::OpKernel<T> { struct STanhFunctor : public BaseActivationFunctor<T> {
public: float scale_a;
void Compute(const framework::ExecutionContext& context) const override { float scale_b;
auto* X = context.Input<framework::Tensor>("X"); typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
auto* Y = context.Output<framework::Tensor>("Y"); return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
auto scale_a = static_cast<T>(context.Attr<AttrType>("scale_a")); }
auto scale_b = static_cast<T>(context.Attr<AttrType>("scale_b"));
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X); template <typename Device, typename X, typename Y>
auto y = framework::EigenVector<T>::Flatten(*Y); void operator()(Device d, X x, Y y) const {
auto place = context.GetEigenDevice<Place>(); y.device(d) = scale_b * (scale_a * x).tanh();
y.device(place) = scale_b * (scale_a * x).tanh();
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class STanhGradKernel : public framework::OpKernel<T> { struct STanhGradFunctor : public BaseActivationFunctor<T> {
public: float scale_a;
void Compute(const framework::ExecutionContext& context) const override { float scale_b;
auto* X = context.Input<framework::Tensor>("X"); typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y")); return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X")); }
auto scale_a = static_cast<T>(context.Attr<AttrType>("scale_a"));
auto scale_b = static_cast<T>(context.Attr<AttrType>("scale_b"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp = (scale_a * x).tanh() * (scale_a * x).tanh(); auto temp = (scale_a * x).tanh() * (scale_a * x).tanh();
dx.device(place) = dy * scale_a * scale_b * (static_cast<T>(1) - temp); dx.device(d) = dy * scale_a * scale_b * (static_cast<T>(1) - temp);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
__macro(square, SquareFunctor, SquareGradFunctor); \
__macro(brelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册