Commit c18ebc30 authored by qijun

remove macros

Parent 4e173527
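For context: after this change, adding an activation op no longer goes through per-op macros; it takes a pair of Eigen functors plus registrations against the shared ActivationOp/ActivationOpGrad classes and the templated kernels introduced in the diff below. A minimal sketch of that pattern, using a hypothetical tanh op (Tanh, TanhGrad, and ops::TanhOpMaker are illustrative names, not part of this commit):

// Hypothetical example of the functor pattern this commit introduces.
// Tanh, TanhGrad, and TanhOpMaker are illustrative, not part of the commit.
struct Tanh {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) {
    y.device(d) = x.tanh();  // y = tanh(x), evaluated on the Eigen device
  }
};

struct TanhGrad {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) {
    dx.device(d) = dy * (1. - y * y);  // d tanh(x)/dx = 1 - tanh(x)^2
  }
};

// Registration reuses the shared op classes and templated kernels:
REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad,
            ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
    tanh, ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::Tanh>);
REGISTER_OP_CPU_KERNEL(
    tanh_grad,
    ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
                              ops::TanhGrad>);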
paddle/operators/activation_op.cc
@@ -14,33 +14,55 @@
 #include "paddle/operators/activation_op.h"
 
-#define FILL_ACTIVATION_OP                                                  \
- public:                                                                    \
-  using framework::OperatorWithKernel::OperatorWithKernel;                  \
-                                                                            \
- protected:                                                                 \
-  void InferShape(const framework::InferShapeContext &ctx) const override { \
-    ctx.Output<framework::Tensor>("Y")->Resize(                             \
-        ctx.Input<framework::Tensor>("X")->dims());                         \
-  }
+// #define FILL_ACTIVATION_OP \
+//  public: \
+//   using framework::OperatorWithKernel::OperatorWithKernel; \
+// \
+//  protected: \
+//   void InferShape(const framework::InferShapeContext &ctx) const override { \
+//     ctx.Output<framework::Tensor>("Y")->Resize( \
+//         ctx.Input<framework::Tensor>("X")->dims()); \
+//   }
 
-#define FILL_ACTIVATION_GRAD_OP                                             \
- public:                                                                    \
-  using framework::OperatorWithKernel::OperatorWithKernel;                  \
-                                                                            \
- protected:                                                                 \
-  void InferShape(const framework::InferShapeContext &ctx) const override { \
-    ctx.Output<framework::Tensor>(framework::GradVarName("X"))              \
-        ->Resize(ctx.Input<framework::Tensor>("Y")->dims());                \
-  }
+// #define FILL_ACTIVATION_GRAD_OP \
+//  public: \
+//   using framework::OperatorWithKernel::OperatorWithKernel; \
+// \
+//  protected: \
+//   void InferShape(const framework::InferShapeContext &ctx) const override { \
+//     ctx.Output<framework::Tensor>(framework::GradVarName("X")) \
+//         ->Resize(ctx.Input<framework::Tensor>("Y")->dims()); \
+//   }
 
 namespace paddle {
 namespace operators {
 
-class SigmoidOp : public framework::OperatorWithKernel {
-  FILL_ACTIVATION_OP
+class ActivationOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    ctx.Output<framework::Tensor>("Y")->Resize(
+        ctx.Input<framework::Tensor>("X")->dims());
+  }
 };
 
+class ActivationOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    ctx.Output<framework::Tensor>(framework::GradVarName("X"))
+        ->Resize(ctx.Input<framework::Tensor>("Y")->dims());
+  }
+};
+
+// class SigmoidOp : public framework::OperatorWithKernel {
+//   FILL_ACTIVATION_OP
+// };
+
 class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SigmoidOpMaker(framework::OpProto *proto,
@@ -52,13 +74,13 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-class SigmoidOpGrad : public framework::OperatorWithKernel {
-  FILL_ACTIVATION_GRAD_OP
-};
+// class SigmoidOpGrad : public framework::OperatorWithKernel {
+//   FILL_ACTIVATION_GRAD_OP
+// };
 
-class ExpOp : public framework::OperatorWithKernel {
-  FILL_ACTIVATION_OP
-};
+// class ExpOp : public framework::OperatorWithKernel {
+//   FILL_ACTIVATION_OP
+// };
 
 class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
@@ -70,13 +92,13 @@ class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-class ExpOpGrad : public framework::OperatorWithKernel {
-  FILL_ACTIVATION_GRAD_OP
-};
+// class ExpOpGrad : public framework::OperatorWithKernel {
+//   FILL_ACTIVATION_GRAD_OP
+// };
 
-class ReluOp : public framework::OperatorWithKernel {
-  FILL_ACTIVATION_OP
-};
+// class ReluOp : public framework::OperatorWithKernel {
+//   FILL_ACTIVATION_OP
+// };
 
 class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
@@ -88,28 +110,36 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-class ReluOpGrad : public framework::OperatorWithKernel {
-  FILL_ACTIVATION_GRAD_OP
-};
+// class ReluOpGrad : public framework::OperatorWithKernel {
+//   FILL_ACTIVATION_GRAD_OP
+// };
 
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad,
-            ops::SigmoidOpGrad);
-REGISTER_OP_CPU_KERNEL(sigmoid,
-                       ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad,
+            ops::ActivationOpGrad);
+REGISTER_OP_CPU_KERNEL(
+    sigmoid,
+    ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::Sigmoid>);
+REGISTER_OP_CPU_KERNEL(sigmoid_grad,
+                       ops::ActivationGradKernel<paddle::platform::CPUPlace,
+                                                 float, ops::SigmoidGrad>);
 
-REGISTER_OP(exp, ops::ExpOp, ops::ExpOpMaker, exp_grad, ops::ExpOpGrad);
-REGISTER_OP_CPU_KERNEL(exp, ops::ExpKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(exp_grad,
-                       ops::ExpGradKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
+            ops::ActivationOpGrad);
+REGISTER_OP_CPU_KERNEL(
+    exp, ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::Exp>);
+REGISTER_OP_CPU_KERNEL(
+    exp_grad,
+    ops::ActivationGradKernel<paddle::platform::CPUPlace, float, ops::ExpGrad>);
 
-REGISTER_OP(relu, ops::ReluOp, ops::ReluOpMaker, relu_grad, ops::ReluOpGrad);
-REGISTER_OP_CPU_KERNEL(relu,
-                       ops::ReluKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(relu_grad,
-                       ops::ReluGradKernel<paddle::platform::CPUPlace, float>);
+// REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
+//             ops::ActivationOpGrad);
+// REGISTER_OP_CPU_KERNEL(relu,
+//                        ops::ReluKernel<paddle::platform::CPUPlace, float,
+//                        ops::Relu>);
+// REGISTER_OP_CPU_KERNEL(relu_grad,
+//                        ops::ReluGradKernel<paddle::platform::CPUPlace, float,
+//                        ops::ReluGrad>);
paddle/operators/activation_op.h
@@ -15,57 +15,135 @@
 #pragma once
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/operators/math/activation_functor.h"
+// #include "paddle/operators/math/activation_functor.h"
 
-#define ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) ACTIVATION_NAME##Kernel
+// #define ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) ACTIVATION_NAME##Kernel
 
-#define DEFINE_ACTIVATION_KERNEL(ACTIVATION_NAME)                              \
-  template <typename Place, typename T>                                        \
-  class ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) : public framework::OpKernel { \
-   public:                                                                     \
-    void Compute(const framework::ExecutionContext& context) const override {  \
-      auto* X = context.Input<framework::Tensor>("X");                         \
-      auto* Y = context.Output<framework::Tensor>("Y");                        \
-      Y->mutable_data<T>(context.GetPlace());                                  \
-      math::ACTIVATION_NAME<Place, T> functor;                                 \
-      auto* device_context = context.device_context();                         \
-      functor(*device_context, *X, Y);                                         \
-    }                                                                          \
-  };
+// #define DEFINE_ACTIVATION_KERNEL(ACTIVATION_NAME) \
+//   template <typename Place, typename T> \
+//   class ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) : public framework::OpKernel { \
+//    public: \
+//     void Compute(const framework::ExecutionContext& context) const override { \
+//       auto* X = context.Input<framework::Tensor>("X"); \
+//       auto* Y = context.Output<framework::Tensor>("Y"); \
+//       Y->mutable_data<T>(context.GetPlace()); \
+//       math::ACTIVATION_NAME<Place, T> functor; \
+//       auto* device_context = context.device_context(); \
+//       functor(*device_context, *X, Y); \
+//     } \
+//   };
 
-#define DEFINE_ACTIVATION_GRAD_KERNEL(ACTIVATION_GRAD_NAME)                   \
-  template <typename Place, typename T>                                       \
-  class ACTIVATION_KERNEL_NAME(ACTIVATION_GRAD_NAME)                          \
-      : public framework::OpKernel {                                          \
-   public:                                                                    \
-    void Compute(const framework::ExecutionContext& context) const override { \
-      auto* X = context.Input<framework::Tensor>("X");                        \
-      auto* Y = context.Input<framework::Tensor>("Y");                        \
-      auto* dY =                                                              \
-          context.Input<framework::Tensor>(framework::GradVarName("Y"));      \
-      auto* dX =                                                              \
-          context.Output<framework::Tensor>(framework::GradVarName("X"));     \
-      dX->mutable_data<T>(context.GetPlace());                                 \
-      math::ACTIVATION_GRAD_NAME<Place, T> functor;                           \
-      auto* device_context = context.device_context();                        \
-      functor(*device_context, *X, *Y, *dY, dX);                              \
-    }                                                                         \
-  };
+// #define DEFINE_ACTIVATION_GRAD_KERNEL(ACTIVATION_GRAD_NAME) \
+//   template <typename Place, typename T> \
+//   class ACTIVATION_KERNEL_NAME(ACTIVATION_GRAD_NAME) \
+//       : public framework::OpKernel { \
+//    public: \
+//     void Compute(const framework::ExecutionContext& context) const override { \
+//       auto* X = context.Input<framework::Tensor>("X"); \
+//       auto* Y = context.Input<framework::Tensor>("Y"); \
+//       auto* dY = \
+//           context.Input<framework::Tensor>(framework::GradVarName("Y")); \
+//       auto* dX = \
+//           context.Output<framework::Tensor>(framework::GradVarName("X")); \
+//       dX->mutable_data<T>(context.GetPlace()); \
+//       math::ACTIVATION_GRAD_NAME<Place, T> functor; \
+//       auto* device_context = context.device_context(); \
+//       functor(*device_context, *X, *Y, *dY, dX); \
+//     } \
+//   };
 
 namespace paddle {
 namespace operators {
 
-DEFINE_ACTIVATION_KERNEL(Sigmoid);
-DEFINE_ACTIVATION_GRAD_KERNEL(SigmoidGrad);
-DEFINE_ACTIVATION_KERNEL(Exp);
-DEFINE_ACTIVATION_GRAD_KERNEL(ExpGrad);
-DEFINE_ACTIVATION_KERNEL(Relu);
-DEFINE_ACTIVATION_GRAD_KERNEL(ReluGrad);
+template <typename Place, typename T, typename Functor>
+class ActivationKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* X = context.Input<framework::Tensor>("X");
+    auto* Y = context.Output<framework::Tensor>("Y");
+    Y->mutable_data<T>(context.GetPlace());
+
+    auto x = framework::EigenVector<T>::Flatten(*X);
+    auto y = framework::EigenVector<T>::Flatten(*Y);
+    auto place = context.GetEigenDevice<Place>();
+    Functor functor;
+    functor(place, x, y);
+  }
+};
+
+template <typename Place, typename T, typename Functor>
+class ActivationGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* X = context.Input<framework::Tensor>("X");
+    auto* Y = context.Input<framework::Tensor>("Y");
+    auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
+    auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    dX->mutable_data<T>(context.GetPlace());
+
+    auto dy = framework::EigenVector<T>::Flatten(*dY);
+    auto x = framework::EigenVector<T>::Flatten(*X);
+    auto y = framework::EigenVector<T>::Flatten(*Y);
+    auto dx = framework::EigenVector<T>::Flatten(*dX);
+    auto place = context.GetEigenDevice<Place>();
+    Functor functor;
+    functor(place, x, y, dy, dx);
+  }
+};
+
+struct Sigmoid {
+  template <typename Device, typename X, typename Y>
+  void operator()(Device d, X x, Y y) {
+    y.device(d) = 1. / (1. + (-x).exp());
+  }
+};
+
+struct SigmoidGrad {
+  template <typename Device, typename X, typename Y, typename dY, typename dX>
+  void operator()(Device d, X x, Y y, dY dy, dX dx) {
+    dx.device(d) = dy * y * (1. - y);
+  }
+};
+
+struct Exp {
+  template <typename Device, typename X, typename Y>
+  void operator()(Device d, X x, Y y) {
+    y.device(d) = x.exp();
+  }
+};
+
+struct ExpGrad {
+  template <typename Device, typename X, typename Y, typename dY, typename dX>
+  void operator()(Device d, X x, Y y, dY dy, dX dx) {
+    dx.device(d) = y;
+  }
+};
+
+// template <typename Device, typename X, typename Y>
+// struct Relu {
+//   void operator()(Device d, X x, Y y) {
+//     y.device(d) = x.cwiseMax(static_cast<T>(0));
+//   }
+// };
+
+// template <typename Device, typename X, typename Y, typename dY, typename dX>
+// struct ReluGrad {
+//   void operator()(Device d, X x, Y y, dY dy, dX dx) {
+//     dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
+//   }
+// };
+
+// DEFINE_ACTIVATION_KERNEL(Sigmoid);
+// DEFINE_ACTIVATION_GRAD_KERNEL(SigmoidGrad);
+// DEFINE_ACTIVATION_KERNEL(Exp);
+// DEFINE_ACTIVATION_GRAD_KERNEL(ExpGrad);
+// DEFINE_ACTIVATION_KERNEL(Relu);
+// DEFINE_ACTIVATION_GRAD_KERNEL(ReluGrad);
 
 }  // namespace operators
 }  // namespace paddle
@@ -56,7 +56,7 @@ USE_OP(sum);
 USE_OP(reshape);
 USE_OP(sigmoid);
 USE_OP(exp);
-USE_OP(relu);
+// USE_OP(relu);
 
 namespace paddle {
 namespace framework {