diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index f77711ecffb438d2a9c312ee20c7b24ab0beaa64..5b272e30ab3420b31b1dac15af630f153ed37be5 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2660,6 +2660,19 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, const OpKernelType& expected_kernel_type) const { +#ifdef PADDLE_WITH_MKLDNN + // When the op is first oneDNN op (there was some non oneDNN op + // previously) + // then we also need to rotate shape NHWC -> NCWH + if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) && + (tensor.layout() != phi::DataLayout::kMKLDNN) && + paddle::platform::MKLDNNDeviceContext::tls() + .get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), + phi::DataLayout::kNHWC); + } +#endif return OpKernelType( expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 42ee893a3d47acea9aa74cad311ce3f9add3a709..dfc5be8de08bc46c334187b60f2b1cf74fccb66f 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -110,27 +110,6 @@ class ActivationOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { return GetKernelType(ctx, *this, "X"); } - - framework::OpKernelType GetKernelTypeForVar( - const std::string& var_name, - const phi::DenseTensor& tensor, - const framework::OpKernelType& expected_kernel_type) const override { -#ifdef PADDLE_WITH_MKLDNN - // When activation is first oneDNN op (there was some non oneDNN op - // previously) - // then we also need to rotate shape NHWC -> NCWH - if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) && - (tensor.layout() != phi::DataLayout::kMKLDNN) && - paddle::platform::MKLDNNDeviceContext::tls() - .get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) { - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), - phi::DataLayout::kNHWC); - } -#endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); - } }; class ActivationOpInferVarType