提交 337b7ebe 编写于 作者: Y Yu Yang

Unify Activation functions and simplify register code

上级 184768e0
......@@ -195,111 +195,54 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::SigmoidFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
sigmoid_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SigmoidGradFunctor<float>>);
REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
exp,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::ExpFunctor>);
REGISTER_OP_CPU_KERNEL(exp_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace,
float, ops::ExpGradFunctor>);
REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(relu,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::ReluFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::ReluGradFunctor<float>>);
REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
tanh,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::TanhFunctor>);
REGISTER_OP_CPU_KERNEL(
tanh_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::TanhGradFunctor<float>>);
REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
sqrt,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::SqrtFunctor>);
REGISTER_OP_CPU_KERNEL(
sqrt_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SqrtGradFunctor<float>>);
REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
abs,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::AbsFunctor>);
REGISTER_OP_CPU_KERNEL(abs_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace,
float, ops::AbsGradFunctor>);
REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker,
reciprocal_grad, ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(reciprocal,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::ReciprocalFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
reciprocal_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::ReciprocalGradFunctor<float>>);
REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
log,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::LogFunctor>);
REGISTER_OP_CPU_KERNEL(
log_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::LogGradFunctor<float>>);
REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::SquareFunctor>);
REGISTER_OP_CPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SquareGradFunctor<float>>);
REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker<float>, brelu_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(brelu,
ops::BReluKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(brelu_grad,
ops::BReluGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>,
soft_relu_grad, ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(soft_relu,
ops::SoftReluKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
soft_relu_grad, ops::SoftReluGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker<float>, pow_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(pow, ops::PowKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(pow_grad,
ops::PowGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(stanh,
ops::STanhKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(stanh_grad,
ops::STanhGradKernel<paddle::platform::CPUPlace, float>);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL( \
act_type, \
paddle::operators::ActivationKernel<paddle::platform::CPUPlace, \
paddle::operators::functor<float>>); \
REGISTER_OP_CPU_KERNEL(act_type##_grad, \
paddle::operators::ActivationGradKernel< \
paddle::platform::CPUPlace, \
paddle::operators::grad_functor<float>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
......@@ -15,86 +15,14 @@
#define EIGEN_USE_GPU
#include "paddle/operators/activation_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sigmoid,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::SigmoidFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
sigmoid_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SigmoidGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
exp,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::ExpFunctor>);
REGISTER_OP_GPU_KERNEL(exp_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace,
float, ops::ExpGradFunctor>);
REGISTER_OP_GPU_KERNEL(relu,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::ReluFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::ReluGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
tanh,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::TanhFunctor>);
REGISTER_OP_GPU_KERNEL(
tanh_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::TanhGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
sqrt,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::SqrtFunctor>);
REGISTER_OP_GPU_KERNEL(
sqrt_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SqrtGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
abs,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::AbsFunctor>);
REGISTER_OP_GPU_KERNEL(abs_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace,
float, ops::AbsGradFunctor>);
REGISTER_OP_GPU_KERNEL(reciprocal,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::ReciprocalFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
reciprocal_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::ReciprocalGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
log,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::LogFunctor>);
REGISTER_OP_GPU_KERNEL(
log_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::LogGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::SquareFunctor>);
REGISTER_OP_GPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SquareGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(brelu,
ops::BReluKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(brelu_grad,
ops::BReluGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(soft_relu,
ops::SoftReluKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
soft_relu_grad, ops::SoftReluGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pow, ops::PowKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pow_grad,
ops::PowGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(stanh,
ops::STanhKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(stanh_grad,
ops::STanhGradKernel<paddle::platform::GPUPlace, float>);
#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_GPU_KERNEL( \
act_type, \
paddle::operators::ActivationKernel<paddle::platform::GPUPlace, \
paddle::operators::functor<float>>); \
REGISTER_OP_GPU_KERNEL(act_type##_grad, \
paddle::operators::ActivationGradKernel< \
paddle::platform::GPUPlace, \
paddle::operators::grad_functor<float>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_GPU_KERNEL);
......@@ -19,9 +19,12 @@
namespace paddle {
namespace operators {
template <typename Place, typename T, typename Functor>
class ActivationKernel : public framework::OpKernel<T> {
template <typename Place, typename Functor>
class ActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
......@@ -31,13 +34,20 @@ class ActivationKernel : public framework::OpKernel<T> {
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(place, x, y);
}
};
template <typename Place, typename T, typename Functor>
class ActivationGradKernel : public framework::OpKernel<T> {
template <typename Place, typename Functor>
class ActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y");
......@@ -51,303 +61,301 @@ class ActivationGradKernel : public framework::OpKernel<T> {
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(place, x, y, dy, dx);
}
};
template <typename T>
struct BaseActivationFunctor {
using ELEMENT_TYPE = T;
using AttrPair = std::vector<std::pair<const char*, float*>>;
AttrPair GetAttrs() { return AttrPair(); }
};
// sigmoid(x) = 1 / (1 + exp(-x))
template <typename T>
struct SigmoidFunctor {
struct SigmoidFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
void operator()(Device d, X x, Y y) const {
y.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
}
};
template <typename T>
struct SigmoidGradFunctor {
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * y * (static_cast<T>(1) - y);
}
};
// exp(x) = e^x
struct ExpFunctor {
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
void operator()(Device d, X x, Y y) const {
y.device(d) = x.exp();
}
};
struct ExpGradFunctor {
template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * y;
}
};
// relu(x) = max(x, 0)
template <typename T>
struct ReluFunctor {
struct ReluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(static_cast<T>(0));
}
};
template <typename T>
struct ReluGradFunctor {
struct ReluGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
}
};
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
struct TanhFunctor {
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
void operator()(Device d, X x, Y y) const {
y.device(d) = x.tanh();
}
};
template <typename T>
struct TanhGradFunctor {
struct TanhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (static_cast<T>(1) - y * y);
}
};
// sqrt(x) = x^(1/2)
struct SqrtFunctor {
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
void operator()(Device d, X x, Y y) const {
y.device(d) = x.sqrt();
}
};
template <typename T>
struct SqrtGradFunctor {
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
const Y y_conj = Eigen::numext::conj(y);
dx.device(d) = static_cast<T>(0.5) * dy / y_conj;
}
};
// abs(x) = |x|
struct AbsFunctor {
template <typename T>
struct AbsFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
void operator()(Device d, X x, Y y) const {
y.device(d) = x.abs();
}
};
struct AbsGradFunctor {
template <typename T>
struct AbsGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * x.sign();
}
};
// reciprocal(x) = 1 / x
template <typename T>
struct ReciprocalFunctor {
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
void operator()(Device d, X x, Y y) const {
y.device(d) = static_cast<T>(1) / x;
}
};
template <typename T>
struct ReciprocalGradFunctor {
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * static_cast<T>(-1) * y * y;
}
};
// log(x) = natural logarithm of x
struct LogFunctor {
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
void operator()(Device d, X x, Y y) const {
y.device(d) = x.log();
}
};
template <typename T>
struct LogGradFunctor {
struct LogGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (static_cast<T>(1) / x);
}
};
// square(x) = x^2
struct SquareFunctor {
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
void operator()(Device d, X x, Y y) const {
y.device(d) = x.square();
}
};
template <typename T>
struct SquareGradFunctor {
struct SquareGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * static_cast<T>(2) * x;
}
};
template <typename Place, typename T, typename AttrType = T>
class BReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
auto t_min = static_cast<T>(context.Attr<AttrType>("t_min"));
auto t_max = static_cast<T>(context.Attr<AttrType>("t_max"));
Y->mutable_data<T>(context.GetPlace());
template <typename T>
struct BReluFunctor : public BaseActivationFunctor<T> {
float t_min;
float t_max;
// NOTE: Explicit hides the `BaseActivationFunctor<T>::GetAttrs`
// not polymorphism for speed.
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"t_min", &t_min}, {"t_max", &t_max}};
}
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
y.device(place) = x.cwiseMax(t_min).cwiseMin(t_max);
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(t_min).cwiseMin(t_max);
}
};
template <typename Place, typename T, typename AttrType = T>
class BReluGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto t_min = static_cast<T>(context.Attr<AttrType>("t_min"));
auto t_max = static_cast<T>(context.Attr<AttrType>("t_max"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
dx.device(place) = dy * ((x > t_min) * (x < t_max)).template cast<T>();
template <typename T>
struct BReluGradFunctor : public BaseActivationFunctor<T> {
float t_min;
float t_max;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"t_min", &t_min}, {"t_max", &t_max}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * ((x > t_min) * (x < t_max)).template cast<T>();
}
};
template <typename Place, typename T, typename AttrType = T>
class SoftReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
auto threshold = static_cast<T>(context.Attr<AttrType>("threshold"));
Y->mutable_data<T>(context.GetPlace());
template <typename T>
struct SoftReluFunctor : public BaseActivationFunctor<T> {
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
auto temp = x.cwiseMax(-threshold).cwiseMin(threshold).eval();
y.device(place) = (static_cast<T>(1) + temp.exp()).log();
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
auto temp = x.cwiseMax(-threshold).cwiseMin(threshold);
y.device(d) = (static_cast<T>(1) + temp.exp()).log();
}
};
template <typename Place, typename T, typename AttrType = T>
class SoftReluGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto threshold = static_cast<T>(context.Attr<AttrType>("threshold"));
dX->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
template <typename T>
struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp = ((x > -threshold) * (x < threshold)).template cast<T>().eval();
dx.device(place) = dy * (static_cast<T>(1) - (-y).exp()) * temp;
dx.device(d) = dy * (static_cast<T>(1) - (-y).exp()) * temp;
}
};
template <typename Place, typename T, typename AttrType = T>
class PowKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
auto factor = static_cast<T>(context.Attr<AttrType>("factor"));
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
y.device(place) = x.pow(factor);
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
float factor;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.pow(factor);
}
};
template <typename Place, typename T, typename AttrType = T>
class PowGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto factor = static_cast<T>(context.Attr<AttrType>("factor"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
dx.device(place) = dy * factor * x.pow(factor - static_cast<T>(1));
template <typename T>
struct PowGradFunctor : public BaseActivationFunctor<T> {
float factor;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * factor * x.pow(factor - static_cast<T>(1));
}
};
template <typename Place, typename T, typename AttrType = T>
class STanhKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
auto scale_a = static_cast<T>(context.Attr<AttrType>("scale_a"));
auto scale_b = static_cast<T>(context.Attr<AttrType>("scale_b"));
Y->mutable_data<T>(context.GetPlace());
template <typename T>
struct STanhFunctor : public BaseActivationFunctor<T> {
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
y.device(place) = scale_b * (scale_a * x).tanh();
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = scale_b * (scale_a * x).tanh();
}
};
template <typename Place, typename T, typename AttrType = T>
class STanhGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto scale_a = static_cast<T>(context.Attr<AttrType>("scale_a"));
auto scale_b = static_cast<T>(context.Attr<AttrType>("scale_b"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
template <typename T>
struct STanhGradFunctor : public BaseActivationFunctor<T> {
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp = (scale_a * x).tanh() * (scale_a * x).tanh();
dx.device(place) = dy * scale_a * scale_b * (static_cast<T>(1) - temp);
dx.device(d) = dy * scale_a * scale_b * (static_cast<T>(1) - temp);
}
};
} // namespace operators
} // namespace paddle
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
__macro(square, SquareFunctor, SquareGradFunctor); \
__macro(brelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册