From 690767edb91db8fe63d956d50ca96254c5300533 Mon Sep 17 00:00:00 2001 From: Xinyu Chen Date: Thu, 6 Apr 2023 11:08:35 +0800 Subject: [PATCH] [oneDNN]disable interpolate operators by default (#52462) --- .../ir/mkldnn/mkldnn_placement_pass.cc | 20 +++++++++++++++++-- .../inference/api/paddle_pass_builder.cc | 1 - 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc index 8d8d7feacfb..c966c46e06b 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_placement_pass.cc @@ -69,8 +69,24 @@ inline bool FoundPhiOneDNNKernelWithCorrectDataType( bool MKLDNNPlacementPass::IsSupport(const Node* op) const { if (FoundOneDNNKernelWithCorrectDataType(op) || - FoundPhiOneDNNKernelWithCorrectDataType(op)) - return true; + FoundPhiOneDNNKernelWithCorrectDataType(op)) { + // For interpolate ops, there's a little difference between Paddle and + // DNNL. + // If run DNNL interpolate ops, manual set AnalysisConfig and apply + // the corresponding pass. + const std::vector not_default_op_types = {"bilinear_interp", + "nearest_interp", + "trilinear_interp", + "bicubic_interp", + "linear_interp", + "bilinear_interp_v2", + "linear_interp_v2"}; + bool is_interpolate_op = + std::find(not_default_op_types.begin(), + not_default_op_types.end(), + op->Op()->Type()) != not_default_op_types.end(); + return !is_interpolate_op; + } return false; } diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 43d452820e0..8b1399515ed 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -363,7 +363,6 @@ void CpuPassStrategy::EnableMKLDNN() { "conv_transpose_eltwiseadd_bn_fuse_pass", // "conv_bias_mkldnn_fuse_pass", // "conv_transpose_bias_mkldnn_fuse_pass", - "interpolate_mkldnn_pass", // TODO(baoachun): Need to support 5-dimensional input. // "conv3d_bias_mkldnn_fuse_pass", // "conv_elementwise_add_mkldnn_fuse_pass", -- GitLab