未验证 提交 95ca886c 编写于 作者: Z zyfncg 提交者: GitHub

move the logic of mkldnn layout in GetKernelTypeForVar from ActivationOp to base class (#47104)

上级 85489d39
...@@ -2660,6 +2660,19 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar( ...@@ -2660,6 +2660,19 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
const std::string& var_name, const std::string& var_name,
const phi::DenseTensor& tensor, const phi::DenseTensor& tensor,
const OpKernelType& expected_kernel_type) const { const OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// When the op is first oneDNN op (there was some non oneDNN op
// previously)
// then we also need to rotate shape NHWC -> NCWH
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN) &&
paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(),
phi::DataLayout::kNHWC);
}
#endif
return OpKernelType( return OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout()); expected_kernel_type.data_type_, tensor.place(), tensor.layout());
} }
......
...@@ -110,27 +110,6 @@ class ActivationOp : public framework::OperatorWithKernel { ...@@ -110,27 +110,6 @@ class ActivationOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X"); return GetKernelType(ctx, *this, "X");
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
// When activation is first oneDNN op (there was some non oneDNN op
// previously)
// then we also need to rotate shape NHWC -> NCWH
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN) &&
paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(),
phi::DataLayout::kNHWC);
}
#endif
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
}; };
class ActivationOpInferVarType class ActivationOpInferVarType
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册