Commit c18ebc30 authored by Q qijun

remove macros

Parent 4e173527
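
This commit replaces the FILL_ACTIVATION_OP / FILL_ACTIVATION_GRAD_OP macros and the per-op DEFINE_ACTIVATION_KERNEL macros with a single ActivationOp / ActivationOpGrad pair plus ActivationKernel / ActivationGradKernel templates parameterized by a small element-wise functor (Sigmoid, SigmoidGrad, Exp, ExpGrad); relu stays commented out because its functor does not yet compile. A minimal, self-contained sketch of the functor-as-template-parameter pattern (plain C++ standing in for the Eigen-based kernels in the diff below; SigmoidFn, ExpFn, and Activate are illustrative names, not part of this commit):

#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative stand-ins for the ops::Sigmoid / ops::Exp functors below.
struct SigmoidFn {
  float operator()(float x) const { return 1.f / (1.f + std::exp(-x)); }
};
struct ExpFn {
  float operator()(float x) const { return std::exp(x); }
};

// One kernel template replaces N macro-generated kernel classes: the
// activation is injected as a template parameter, exactly as
// ActivationKernel does with its Functor parameter in activation_op.h.
template <typename Functor>
std::vector<float> Activate(const std::vector<float>& xs) {
  Functor f;
  std::vector<float> ys(xs.size());
  for (size_t i = 0; i < xs.size(); ++i) ys[i] = f(xs[i]);
  return ys;
}

int main() {
  std::vector<float> xs = {-1.f, 0.f, 1.f};
  auto sig = Activate<SigmoidFn>(xs);  // same call site, different policy
  auto ex = Activate<ExpFn>(xs);
  for (size_t i = 0; i < xs.size(); ++i)
    std::printf("sigmoid(%g)=%g  exp(%g)=%g\n", xs[i], sig[i], xs[i], ex[i]);
  return 0;
}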
@@ -14,33 +14,55 @@
#include "paddle/operators/activation_op.h"
#define FILL_ACTIVATION_OP \
public: \
using framework::OperatorWithKernel::OperatorWithKernel; \
\
protected: \
void InferShape(const framework::InferShapeContext &ctx) const override { \
ctx.Output<framework::Tensor>("Y")->Resize( \
ctx.Input<framework::Tensor>("X")->dims()); \
}
#define FILL_ACTIVATION_GRAD_OP \
public: \
using framework::OperatorWithKernel::OperatorWithKernel; \
\
protected: \
void InferShape(const framework::InferShapeContext &ctx) const override { \
ctx.Output<framework::Tensor>(framework::GradVarName("X")) \
->Resize(ctx.Input<framework::Tensor>("Y")->dims()); \
}
// #define FILL_ACTIVATION_OP \
// public: \
// using framework::OperatorWithKernel::OperatorWithKernel; \
// \
// protected: \
// void InferShape(const framework::InferShapeContext &ctx) const override { \
// ctx.Output<framework::Tensor>("Y")->Resize( \
// ctx.Input<framework::Tensor>("X")->dims()); \
// }
// #define FILL_ACTIVATION_GRAD_OP \
// public: \
// using framework::OperatorWithKernel::OperatorWithKernel; \
// \
// protected: \
// void InferShape(const framework::InferShapeContext &ctx) const override { \
// ctx.Output<framework::Tensor>(framework::GradVarName("X")) \
// ->Resize(ctx.Input<framework::Tensor>("Y")->dims()); \
// }
namespace paddle {
namespace operators {
class SigmoidOp : public framework::OperatorWithKernel {
FILL_ACTIVATION_OP
class ActivationOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::Tensor>("Y")->Resize(
ctx.Input<framework::Tensor>("X")->dims());
}
};
class ActivationOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::Tensor>(framework::GradVarName("X"))
->Resize(ctx.Input<framework::Tensor>("Y")->dims());
}
};
// class SigmoidOp : public framework::OperatorWithKernel {
// FILL_ACTIVATION_OP
// };
class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SigmoidOpMaker(framework::OpProto *proto,
@@ -52,13 +74,13 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class SigmoidOpGrad : public framework::OperatorWithKernel {
FILL_ACTIVATION_GRAD_OP
};
// class SigmoidOpGrad : public framework::OperatorWithKernel {
// FILL_ACTIVATION_GRAD_OP
// };
class ExpOp : public framework::OperatorWithKernel {
FILL_ACTIVATION_OP
};
// class ExpOp : public framework::OperatorWithKernel {
// FILL_ACTIVATION_OP
// };
class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
@@ -70,13 +92,13 @@ class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class ExpOpGrad : public framework::OperatorWithKernel {
FILL_ACTIVATION_GRAD_OP
};
// class ExpOpGrad : public framework::OperatorWithKernel {
// FILL_ACTIVATION_GRAD_OP
// };
class ReluOp : public framework::OperatorWithKernel {
FILL_ACTIVATION_OP
};
// class ReluOp : public framework::OperatorWithKernel {
// FILL_ACTIVATION_OP
// };
class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
@@ -88,28 +110,36 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class ReluOpGrad : public framework::OperatorWithKernel {
FILL_ACTIVATION_GRAD_OP
};
// class ReluOpGrad : public framework::OperatorWithKernel {
// FILL_ACTIVATION_GRAD_OP
// };
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad,
            ops::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid,
                       ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(exp, ops::ExpOp, ops::ExpOpMaker, exp_grad, ops::ExpOpGrad);
REGISTER_OP_CPU_KERNEL(exp, ops::ExpKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(exp_grad,
                       ops::ExpGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(relu, ops::ReluOp, ops::ReluOpMaker, relu_grad, ops::ReluOpGrad);
REGISTER_OP_CPU_KERNEL(relu,
                       ops::ReluKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(relu_grad,
                       ops::ReluGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad,
            ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
    sigmoid,
    ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::Sigmoid>);
REGISTER_OP_CPU_KERNEL(sigmoid_grad,
                       ops::ActivationGradKernel<paddle::platform::CPUPlace,
                                                 float, ops::SigmoidGrad>);
REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
            ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
    exp, ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::Exp>);
REGISTER_OP_CPU_KERNEL(
    exp_grad,
    ops::ActivationGradKernel<paddle::platform::CPUPlace, float, ops::ExpGrad>);
// REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
//             ops::ActivationOpGrad);
// REGISTER_OP_CPU_KERNEL(relu,
//                        ops::ReluKernel<paddle::platform::CPUPlace, float,
//                        ops::Relu>);
// REGISTER_OP_CPU_KERNEL(relu_grad,
//                        ops::ReluGradKernel<paddle::platform::CPUPlace, float,
//                        ops::ReluGrad>);
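
Note that the disabled relu registrations above still reference the pre-refactor ReluKernel name. A hedged sketch of what re-enabling relu would presumably look like under the new scheme, assuming a compiling ops::Relu / ops::ReluGrad functor pair is added (one possible functor fix is sketched after activation_op.h below):

// Hypothetical, not part of this commit: relu under the new scheme.
REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
            ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
    relu, ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::Relu>);
REGISTER_OP_CPU_KERNEL(
    relu_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
                                         ops::ReluGrad>);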
@@ -15,57 +15,135 @@
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/activation_functor.h"
#define ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) ACTIVATION_NAME##Kernel
#define DEFINE_ACTIVATION_KERNEL(ACTIVATION_NAME) \
template <typename Place, typename T> \
class ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) : public framework::OpKernel { \
public: \
void Compute(const framework::ExecutionContext& context) const override { \
auto* X = context.Input<framework::Tensor>("X"); \
auto* Y = context.Output<framework::Tensor>("Y"); \
Y->mutable_data<T>(context.GetPlace()); \
math::ACTIVATION_NAME<Place, T> functor; \
auto* device_context = context.device_context(); \
functor(*device_context, *X, Y); \
} \
};
#define DEFINE_ACTIVATION_GRAD_KERNEL(ACTIVATION_GRAD_NAME) \
template <typename Place, typename T> \
class ACTIVATION_KERNEL_NAME(ACTIVATION_GRAD_NAME) \
: public framework::OpKernel { \
public: \
void Compute(const framework::ExecutionContext& context) const override { \
auto* X = context.Input<framework::Tensor>("X"); \
auto* Y = context.Input<framework::Tensor>("Y"); \
auto* dY = \
context.Input<framework::Tensor>(framework::GradVarName("Y")); \
auto* dX = \
context.Output<framework::Tensor>(framework::GradVarName("X")); \
dX->mutable_data<T>(context.GetPlace()); \
math::ACTIVATION_GRAD_NAME<Place, T> functor; \
auto* device_context = context.device_context(); \
functor(*device_context, *X, *Y, *dY, dX); \
} \
};
// #include "paddle/operators/math/activation_functor.h"
// #define ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) ACTIVATION_NAME##Kernel
// #define DEFINE_ACTIVATION_KERNEL(ACTIVATION_NAME) \
// template <typename Place, typename T> \
// class ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) : public framework::OpKernel { \
// public: \
// void Compute(const framework::ExecutionContext& context) const override { \
// auto* X = context.Input<framework::Tensor>("X"); \
// auto* Y = context.Output<framework::Tensor>("Y"); \
// Y->mutable_data<T>(context.GetPlace()); \
// math::ACTIVATION_NAME<Place, T> functor; \
// auto* device_context = context.device_context(); \
// functor(*device_context, *X, Y); \
// } \
// };
// #define DEFINE_ACTIVATION_GRAD_KERNEL(ACTIVATION_GRAD_NAME) \
// template <typename Place, typename T> \
// class ACTIVATION_KERNEL_NAME(ACTIVATION_GRAD_NAME) \
// : public framework::OpKernel { \
// public: \
// void Compute(const framework::ExecutionContext& context) const override { \
// auto* X = context.Input<framework::Tensor>("X"); \
// auto* Y = context.Input<framework::Tensor>("Y"); \
// auto* dY = \
// context.Input<framework::Tensor>(framework::GradVarName("Y")); \
// auto* dX = \
// context.Output<framework::Tensor>(framework::GradVarName("X")); \
// dX->mutable_data<T>(context.GetPlace()); \
// math::ACTIVATION_GRAD_NAME<Place, T> functor; \
// auto* device_context = context.device_context(); \
// functor(*device_context, *X, *Y, *dY, dX); \
// } \
// };
namespace paddle {
namespace operators {
DEFINE_ACTIVATION_KERNEL(Sigmoid);
template <typename Place, typename T, typename Functor>
class ActivationKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
Functor functor;
functor(place, x, y);
}
};
template <typename Place, typename T, typename Functor>
class ActivationGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
Functor functor;
functor(place, x, y, dy, dx);
}
};
struct Sigmoid {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = 1. / (1. + (-x).exp());
}
};
struct SigmoidGrad {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) {
    // y = 1 / (1 + exp(-x)), so dy/dx = y * (1 - y).
    dx.device(d) = dy * y * (1. - y);
  }
};
struct Exp {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.exp();
}
};
struct ExpGrad {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) {
    // y = exp(x), so dL/dx = dy * exp(x) = dy * y; assigning y alone
    // would drop the incoming gradient dy.
    dx.device(d) = dy * y;
  }
};
// template <typename Device, typename X, typename Y>
// struct Relu {
// void operator()(Device d, X x, Y y) {
// y.device(d) = x.cwiseMax(static_cast<T>(0));
// }
// };
// template <typename Device, typename X, typename Y, typename dY, typename dX>
// struct ReluGrad {
// void operator()(Device d, X x, Y y, dY dy, dX dx) {
// dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
// }
// };
// DEFINE_ACTIVATION_KERNEL(Sigmoid);
DEFINE_ACTIVATION_GRAD_KERNEL(SigmoidGrad);
// DEFINE_ACTIVATION_GRAD_KERNEL(SigmoidGrad);
DEFINE_ACTIVATION_KERNEL(Exp);
// DEFINE_ACTIVATION_KERNEL(Exp);
DEFINE_ACTIVATION_GRAD_KERNEL(ExpGrad);
// DEFINE_ACTIVATION_GRAD_KERNEL(ExpGrad);
DEFINE_ACTIVATION_KERNEL(Relu);
// DEFINE_ACTIVATION_KERNEL(Relu);
DEFINE_ACTIVATION_GRAD_KERNEL(ReluGrad);
// DEFINE_ACTIVATION_GRAD_KERNEL(ReluGrad);
} // namespace operators
} // namespace paddle
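
The commented-out Relu functors in this header refer to a scalar type T that their templates never declare, which is presumably why relu is disabled throughout this commit. A minimal sketch of one way to make them compile (an assumption, not part of this commit; these would live alongside the other functors inside paddle::operators): deduce the scalar type from the Eigen expression instead.

// Sketch, not part of this commit: Relu functors with a deduced scalar type.
struct Relu {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) {
    using T = typename X::Scalar;  // scalar type of the Eigen expression
    y.device(d) = x.cwiseMax(static_cast<T>(0));
  }
};

struct ReluGrad {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) {
    using T = typename X::Scalar;
    // The gradient passes through only where the input was positive.
    dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
  }
};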
@@ -56,7 +56,7 @@ USE_OP(sum);
USE_OP(reshape);
USE_OP(sigmoid);
USE_OP(exp);
USE_OP(relu);
// USE_OP(relu);
namespace paddle {
namespace framework {