diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 043ffb01fc0af6873c3d3a17e641d77f5e75b965..979115eee0dbe157dbcf2293d914cc250b35d22e 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/mkldnn_activation_op.h" namespace paddle { namespace operators { @@ -25,11 +26,6 @@ class ActivationOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareLoD("X", /*->*/ "Out"); } - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return ActivationHelper().GetKernelType(ctx, *this); - } }; class ActivationOpGrad : public framework::OperatorWithKernel { @@ -39,11 +35,6 @@ class ActivationOpGrad : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); } - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return ActivationHelper().GetKernelType(ctx, *this); - } }; class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { @@ -546,11 +537,11 @@ REGISTER_OP(logsigmoid, ops::ActivationOp, ops::LogSigmoidOpMaker, REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad, ops::ActivationOpGrad); -REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad, - ops::ActivationOpGrad); +REGISTER_OP(relu, ops::ActivationWithMKLDNNOp, ops::ReluOpMaker, relu_grad, + ops::ActivationWithMKLDNNOpGrad); -REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad, - ops::ActivationOpGrad); +REGISTER_OP(tanh, ops::ActivationWithMKLDNNOp, ops::TanhOpMaker, tanh_grad, + ops::ActivationWithMKLDNNOpGrad); REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker, tanh_shrink_grad, ops::ActivationOpGrad); @@ -558,11 +549,11 @@ REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker, REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker, softshrink_grad, ops::ActivationOpGrad); -REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad, - ops::ActivationOpGrad); +REGISTER_OP(sqrt, ops::ActivationWithMKLDNNOp, ops::SqrtOpMaker, sqrt_grad, + ops::ActivationWithMKLDNNOpGrad); -REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad, - ops::ActivationOpGrad); +REGISTER_OP(abs, ops::ActivationWithMKLDNNOp, ops::AbsOpMaker, abs_grad, + ops::ActivationWithMKLDNNOpGrad); REGISTER_OP(ceil, ops::ActivationOp, ops::CeilOpMaker, ceil_grad, ops::ActivationOpGrad); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index e607a5554f4f67e3451713c405f129a9fc75a1b0..4c575b4a7b551be2d1288f7fec0a2821fc10c40d 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -24,25 +24,6 @@ limitations under the License. */ namespace paddle { namespace operators { -class ActivationHelper { - public: - framework::OpKernelType GetKernelType( - const framework::ExecutionContext& ctx, - const framework::OperatorWithKernel& oper) const { - framework::LibraryType library{framework::LibraryType::kPlain}; -#ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library = framework::LibraryType::kMKLDNN; - } -#endif - framework::DataLayout layout = framework::DataLayout::kAnyLayout; - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.GetPlace(), layout, library); - } -}; - template class ActivationKernel : public framework::OpKernel { diff --git a/paddle/fluid/operators/mkldnn_activation_op.h b/paddle/fluid/operators/mkldnn_activation_op.h index 976e362911e16a6e382ebd827c62c314bef8a410..083d03ebe610521c5a4beb7b977a8179700bcf40 100644 --- a/paddle/fluid/operators/mkldnn_activation_op.h +++ b/paddle/fluid/operators/mkldnn_activation_op.h @@ -60,5 +60,52 @@ class MKLDNNActivationGradKernel } }; +namespace { +framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx, + const framework::OperatorWithKernel& oper) { + framework::LibraryType library{framework::LibraryType::kPlain}; +#ifdef PADDLE_WITH_MKLDNN + if (library == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + } +#endif + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.GetPlace(), layout, library); +} +} // anonymous namespace + +class ActivationWithMKLDNNOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return GetKernelType(ctx, *this); + } +}; + +class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return GetKernelType(ctx, *this); + } +}; + } // namespace operators } // namespace paddle