From 95ca886ca2905fa4d539479cb8a8f6f5d21baf35 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 19 Oct 2022 14:15:56 +0800 Subject: [PATCH] move the logic of mkldnn layout in GetKernelTypeForVar from ActivationOp to base class (#47104) --- paddle/fluid/framework/operator.cc | 13 +++++++++++++ paddle/fluid/operators/activation_op.cc | 21 --------------------- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index f77711ecff..5b272e30ab 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2660,6 +2660,19 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, const OpKernelType& expected_kernel_type) const { +#ifdef PADDLE_WITH_MKLDNN + // When the op is first oneDNN op (there was some non oneDNN op + // previously) + // then we also need to rotate shape NHWC -> NCWH + if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) && + (tensor.layout() != phi::DataLayout::kMKLDNN) && + paddle::platform::MKLDNNDeviceContext::tls() + .get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), + phi::DataLayout::kNHWC); + } +#endif return OpKernelType( expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 42ee893a3d..dfc5be8de0 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -110,27 +110,6 @@ class ActivationOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "X"); } - - framework::OpKernelType GetKernelTypeForVar( - const std::string& var_name, - const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { -#ifdef PADDLE_WITH_MKLDNN - // When activation is first oneDNN op (there was some non oneDNN op - // previously) - // then we also need to rotate shape NHWC -> NCWH - if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) && - (tensor.layout() != phi::DataLayout::kMKLDNN) && - paddle::platform::MKLDNNDeviceContext::tls() - .get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) { - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), - phi::DataLayout::kNHWC); - } -#endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); - } }; class ActivationOpInferVarType -- GitLab