From a76d0dd488934d2f8859234f1d8d0e6eccfd5c30 Mon Sep 17 00:00:00 2001 From: Krzysztof Binias Date: Mon, 14 May 2018 12:18:48 +0200 Subject: [PATCH] MKL-DNN activations improvements --- paddle/fluid/operators/activation_op.cc | 58 +++++++++++-------- paddle/fluid/operators/mkldnn_activation_op.h | 47 --------------- 2 files changed, 34 insertions(+), 71 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 6f7a965bcf3..dd71c66a75a 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -41,7 +41,7 @@ namespace operators { \ protected: \ std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \ - auto *op = new ::paddle::framework::OpDesc(); \ + auto* op = new ::paddle::framework::OpDesc(); \ op->SetType(#KERNEL_TYPE "_grad"); \ op->SetInput("Out", Output("Out")); \ op->SetInput(::paddle::framework::GradVarName("Out"), \ @@ -54,23 +54,50 @@ namespace operators { } \ } +framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, + const framework::OperatorWithKernel& oper, + const std::string& name) { + framework::LibraryType library{framework::LibraryType::kPlain}; +#ifdef PADDLE_WITH_MKLDNN + auto it = oper.Attrs().find("use_mkldnn"); + if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() && + platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + } +#endif + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + return framework::OpKernelType( + framework::ToDataType(ctx.Input(name)->type()), + ctx.GetPlace(), layout, library); +} + class ActivationOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareLoD("X", /*->*/ "Out"); } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return GetKernelType(ctx, *this, "X"); + } }; class ActivationOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return GetKernelType(ctx, *this, "Out"); + } }; __attribute__((unused)) constexpr char SigmoidDoc[] = R"DOC( @@ -458,22 +485,21 @@ namespace ops = paddle::operators; #define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \ __macro(Sigmoid, sigmoid); \ + __macro(Relu, relu); \ __macro(Exp, exp); \ + __macro(Tanh, tanh); \ __macro(Ceil, ceil); \ __macro(Floor, floor); \ + __macro(Sqrt, sqrt); \ __macro(SoftRelu, soft_relu); \ __macro(Relu6, relu6); \ __macro(Reciprocal, reciprocal); \ __macro(HardSigmoid, hard_sigmoid); -#define FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(__macro) \ - __macro(Relu, relu); \ - __macro(Tanh, tanh); \ - __macro(Sqrt, sqrt); - #define FOR_EACH_OP_FUNCTOR(__macro) \ __macro(LogSigmoid, logsigmoid); \ __macro(SoftShrink, softshrink); \ + __macro(Abs, abs); \ __macro(Cos, cos); \ __macro(Sin, sin); \ __macro(Round, round); \ @@ -491,32 +517,18 @@ namespace ops = paddle::operators; __macro(Swish, swish); \ __macro(ThresholdedRelu, thresholded_relu); -#define FOR_EACH_MKLDNN_OP_FUNCTOR(__macro) __macro(Abs, abs); - #define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ ::paddle::operators::OP_NAME##OpMaker, \ ::paddle::operators::OP_NAME##GradMaker); \ REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad) -#define REGISTER_INPLACE_ACTIVATION_MKLDNN_OP(OP_NAME, KERNEL_TYPE) \ - REGISTER_OPERATOR(KERNEL_TYPE, ops::ActivationWithMKLDNNOp, \ - ::paddle::operators::OP_NAME##OpMaker, \ - ::paddle::operators::OP_NAME##GradMaker); \ - REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad) - #define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ ::paddle::operators::OP_NAME##OpMaker, \ ::paddle::framework::DefaultGradOpDescMaker); \ REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad) -#define REGISTER_ACTIVATION_MKLDNN_OP(OP_NAME, KERNEL_TYPE) \ - REGISTER_OPERATOR(KERNEL_TYPE, ops::ActivationWithMKLDNNOp, \ - ::paddle::operators::OP_NAME##OpMaker, \ - ::paddle::framework::DefaultGradOpDescMaker); \ - REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad) - #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \ REGISTER_OP_CPU_KERNEL( \ act_type, ops::ActivationKernel>); FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP); -FOR_EACH_MKLDNN_OP_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_OP); FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP); -FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_MKLDNN_OP); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL); diff --git a/paddle/fluid/operators/mkldnn_activation_op.h b/paddle/fluid/operators/mkldnn_activation_op.h index de8daed1706..85664623d73 100644 --- a/paddle/fluid/operators/mkldnn_activation_op.h +++ b/paddle/fluid/operators/mkldnn_activation_op.h @@ -62,52 +62,5 @@ class MKLDNNActivationGradKernel } }; -namespace { // NOLINT -framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, - const framework::OperatorWithKernel& oper, - const std::string& name) { - framework::LibraryType library{framework::LibraryType::kPlain}; -#ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library = framework::LibraryType::kMKLDNN; - } -#endif - framework::DataLayout layout = framework::DataLayout::kAnyLayout; - return framework::OpKernelType( - framework::ToDataType(ctx.Input(name)->type()), - ctx.GetPlace(), layout, library); -} -} // anonymous namespace - -class ActivationWithMKLDNNOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - ctx->ShareLoD("X", /*->*/ "Out"); - } - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return GetKernelType(ctx, *this, "X"); - } -}; - -class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); - } - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return GetKernelType(ctx, *this, "Out"); - } -}; - } // namespace operators } // namespace paddle -- GitLab