未验证 提交 95ca886c 编写于 作者: Z zyfncg 提交者: GitHub

move the logic of mkldnn layout in GetKernelTypeForVar from ActivationOp to base class (#47104)

上级 85489d39
......@@ -2660,6 +2660,19 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// When the op is first oneDNN op (there was some non oneDNN op
// previously)
// then we also need to rotate shape NHWC -> NCWH
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN) &&
paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(),
phi::DataLayout::kNHWC);
}
#endif
return OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
......
......@@ -110,27 +110,6 @@ class ActivationOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X");
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
// When activation is first oneDNN op (there was some non oneDNN op
// previously)
// then we also need to rotate shape NHWC -> NCWH
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN) &&
paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(),
phi::DataLayout::kNHWC);
}
#endif
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
};
class ActivationOpInferVarType
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册