提交 61cb4f2f 编写于 作者: D dzhwinter

"fix ci"

上级 425a1e76
......@@ -13,16 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include <string>
#include "paddle/fluid/operators/mkldnn_activation_op.h"
namespace paddle {
namespace operators {
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker : public framework::OpProtoAndCheckerMaker { \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \
OP_NAME##OpMaker(OpProto *proto, OpAttrChecker *op_checker) \
: framework::OpProtoAndCheckerMaker(proto, op_checker) { \
: ::paddle::framework::OpProtoAndCheckerMaker(proto, op_checker) { \
AddInput("X", "Input of " #OP_NAME "operator"); \
AddOutput("Out", "Output of" #OP_NAME "operator"); \
AddAttr<bool>("use_mkldnn", \
......@@ -30,26 +32,28 @@ namespace operators {
.SetDefault(false); \
AddComment(#OP_COMMENT); \
} \
}
};
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \
class OP_NAME##GradMaker : public framework::SingleGradOpDescMaker { \
class OP_NAME##GradMaker \
: public ::paddle::framework::SingleGradOpDescMaker { \
public: \
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
\
protected: \
std::unique_ptr<framework::OpDesc> Apply() const override { \
auto *op = new framework::OpDesc(); \
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \
auto *op = new ::paddle::framework::OpDesc(); \
op->SetType(#KERNEL_TYPE "_grad"); \
op->SetInput("Out", Output("Out")); \
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); \
op->SetInput(::paddle::framework::GradVarName("Out"), \
OutputGrad("Out")); \
\
op->SetAttrMap(Attrs()); \
\
op->SetOutput(framework::GradVarName("X"), InputGrad("X")); \
return std::unique_ptr<framework::OpDesc>(op); \
op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X")); \
return std::unique_ptr<::paddle::framework::OpDesc>(op); \
} \
}
};
class ActivationOp : public framework::OperatorWithKernel {
public:
......@@ -449,70 +453,67 @@ REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);
// NOTE(*) only gradient can be inplaced need to register its gradient maker,
// To tell the executor which input variable is used. By default, every Input
// variable
// is used in gradient operator.
// The operator name written in lowercase intentionally.
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid);
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
#define REGISTER_INPLACE_ACTIVATION_OP(act_type, op_name) \
REGISTER_OPERATOR(act_type, ops::ActivationOp, ops::op_name##OpMaker, \
ops::op_name##GradMaker); \
REGISTER_OPERATOR(act_type##grad, ops::ActivationOpGrad)
#define REGISTER_ACTIVATION_OP(act_type, op_name) \
REGISTER_OP(act_type, ops::ActivationOp, ops::op_name##OpMaker, \
act_type##_grad, ops::ActivationOpGrad);
void DummyFunctor() {}
#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
__macro(sigmoid, Sigmoid); \
__macro(relu, Relu); \
__macro(exp, Exp); \
__macro(tanh, Tanh); \
__macro(ceil, Ceil); \
__macro(floor, Floor); \
__macro(sqrt, Sqrt); \
__macro(soft_relu, SoftRelu); \
__macro(relu6, Relu6); \
__macro(reciprocal, Reciprocal); \
__macro(hard_sigmoid, HardSigmoid);
__macro(Sigmoid, sigmoid); \
__macro(Relu, relu); \
__macro(Exp, exp); \
__macro(Tanh, tanh); \
__macro(Ceil, ceil); \
__macro(Floor, floor); \
__macro(Sqrt, sqrt); \
__macro(SoftRelu, soft_relu); \
__macro(Relu6, relu6); \
__macro(Reciprocal, reciprocal); \
__macro(HardSigmoid, hard_sigmoid);
#define FOR_EACH_OP_FUNCTOR(__macro) \
__macro(logsigmoid, LogSigmoid); \
__macro(softshrink, SoftShrink); \
__macro(abs, Abs); \
__macro(cos, Cos); \
__macro(sin, Sin); \
__macro(round, Round); \
__macro(log, Log); \
__macro(square, Square); \
__macro(brelu, BRelu); \
__macro(pow, Pow); \
__macro(stanh, STanh); \
__macro(softplus, Softplus); \
__macro(softsign, Softsign); \
__macro(leaky_relu, LeakyRelu); \
__macro(tanh_shrink, TanhShrink); \
__macro(elu, ELU); \
__macro(hard_shrink, HardShrink); \
__macro(swish, Swish); \
__macro(thresholded_relu, ThresholdedRelu);
__macro(LogSigmoid, logsigmoid); \
__macro(SoftShrink, softshrink); \
__macro(Abs, abs); \
__macro(Cos, cos); \
__macro(Sin, sin); \
__macro(Round, round); \
__macro(Log, log); \
__macro(Square, square); \
__macro(BRelu, brelu); \
__macro(Pow, pow); \
__macro(STanh, stanh); \
__macro(Softplus, softplus); \
__macro(Softsign, softsign); \
__macro(LeakyRelu, leaky_relu); \
__macro(TanhShrink, tanh_shrink); \
__macro(ELU, elu); \
__macro(HardShrink, hard_shrink); \
__macro(Swish, swish); \
__macro(ThresholdedRelu, thresholded_relu);
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::OP_NAME##GradMaker); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OP(KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker, \
KERNEL_TYPE##_grad, ops::ActivationOpGrad);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL( \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册