From d8bd436fc16497e1f29de2b1f4c2d6f59abb80de Mon Sep 17 00:00:00 2001 From: Krzysztof Binias Date: Wed, 21 Mar 2018 15:48:26 +0100 Subject: [PATCH] Fixed tests --- paddle/fluid/operators/activation_op.cc | 27 ++++------- paddle/fluid/operators/activation_op.h | 19 -------- paddle/fluid/operators/mkldnn_activation_op.h | 47 +++++++++++++++++++ 3 files changed, 56 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 043ffb01f..979115eee 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/mkldnn_activation_op.h" namespace paddle { namespace operators { @@ -25,11 +26,6 @@ class ActivationOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareLoD("X", /*->*/ "Out"); } - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return ActivationHelper().GetKernelType(ctx, *this); - } }; class ActivationOpGrad : public framework::OperatorWithKernel { @@ -39,11 +35,6 @@ class ActivationOpGrad : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); } - - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return ActivationHelper().GetKernelType(ctx, *this); - } }; class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { @@ -546,11 +537,11 @@ REGISTER_OP(logsigmoid, ops::ActivationOp, ops::LogSigmoidOpMaker, REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad, ops::ActivationOpGrad); -REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad, - ops::ActivationOpGrad); +REGISTER_OP(relu, ops::ActivationWithMKLDNNOp, ops::ReluOpMaker, relu_grad, + ops::ActivationWithMKLDNNOpGrad); -REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad, - ops::ActivationOpGrad); +REGISTER_OP(tanh, ops::ActivationWithMKLDNNOp, ops::TanhOpMaker, tanh_grad, + ops::ActivationWithMKLDNNOpGrad); REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker, tanh_shrink_grad, ops::ActivationOpGrad); @@ -558,11 +549,11 @@ REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker, REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker, softshrink_grad, ops::ActivationOpGrad); -REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad, - ops::ActivationOpGrad); +REGISTER_OP(sqrt, ops::ActivationWithMKLDNNOp, ops::SqrtOpMaker, sqrt_grad, + ops::ActivationWithMKLDNNOpGrad); -REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad, - ops::ActivationOpGrad); +REGISTER_OP(abs, ops::ActivationWithMKLDNNOp, ops::AbsOpMaker, abs_grad, + ops::ActivationWithMKLDNNOpGrad); REGISTER_OP(ceil, ops::ActivationOp, ops::CeilOpMaker, ceil_grad, ops::ActivationOpGrad); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index e607a5554..4c575b4a7 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -24,25 +24,6 @@ limitations under the License. */ namespace paddle { namespace operators { -class ActivationHelper { - public: - framework::OpKernelType GetKernelType( - const framework::ExecutionContext& ctx, - const framework::OperatorWithKernel& oper) const { - framework::LibraryType library{framework::LibraryType::kPlain}; -#ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library = framework::LibraryType::kMKLDNN; - } -#endif - framework::DataLayout layout = framework::DataLayout::kAnyLayout; - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.GetPlace(), layout, library); - } -}; - template class ActivationKernel : public framework::OpKernel { diff --git a/paddle/fluid/operators/mkldnn_activation_op.h b/paddle/fluid/operators/mkldnn_activation_op.h index 976e36291..083d03ebe 100644 --- a/paddle/fluid/operators/mkldnn_activation_op.h +++ b/paddle/fluid/operators/mkldnn_activation_op.h @@ -60,5 +60,52 @@ class MKLDNNActivationGradKernel } }; +namespace { +framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx, + const framework::OperatorWithKernel& oper) { + framework::LibraryType library{framework::LibraryType::kPlain}; +#ifdef PADDLE_WITH_MKLDNN + if (library == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + } +#endif + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.GetPlace(), layout, library); +} +} // anonymous namespace + +class ActivationWithMKLDNNOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return GetKernelType(ctx, *this); + } +}; + +class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return GetKernelType(ctx, *this); + } +}; + } // namespace operators } // namespace paddle -- GitLab