From 7e914386bb889b3ea6ccc9dfdbb9964ec83b8f2d Mon Sep 17 00:00:00 2001
From: Paulina Gacek
Date: Wed, 9 Nov 2022 09:11:44 +0100
Subject: [PATCH] Enable fc passes (#45704)

* Analysis API interface for disabling fc passes

* Unit tests corrected

* Python API added

* test runs only when PADDLE_WITH_MKLDNN

* Fc op changed to relu in matmul_op_test

* Disable fc passes in tests where acc drops

* code formatting

* Unit test for analysisConf added

* Unit test gpu added

* fc passes disabled when iterations=0 in gru test

* style

* passes disabled when fp32 in gru test

* fc passes disabled in lstm test

* Import from inference, not fluid in doc
---
 paddle/fluid/inference/api/analysis_config.cc | 16 ++++++++
 .../api/analysis_predictor_tester.cc          | 39 +++++++++++++++++++
 .../inference/api/paddle_analysis_config.h    | 14 +++++++
 .../inference/api/paddle_pass_builder.cc      | 32 ++++++++++++++-
 .../fluid/inference/api/paddle_pass_builder.h | 14 +++++++
 .../tests/api/analyzer_bert_tester.cc         |  3 --
 .../tests/api/analyzer_dam_tester.cc          |  4 --
 .../tests/api/analyzer_ernie_int8_tester.cc   |  1 +
 .../tests/api/analyzer_ernie_tester.cc        |  1 +
 .../analyzer_image_classification_tester.cc   | 10 ++---
 .../analyzer_lexical_analysis_gru_tester.cc   |  8 +++-
 .../api/analyzer_seq_pool1_tester_helper.h    |  2 -
 .../analyzer_transformer_compare_tester.cc    |  2 -
 .../analyzer_transformer_profile_tester.cc    |  2 -
 .../tests/api/analyzer_vis_tester.cc          |  4 --
 .../tests/api/analyzer_vit_ocr_tester.cc      |  5 ---
 paddle/fluid/pybind/inference_api.cc          | 16 ++++++++
 .../slim/tests/quant2_int8_lstm_model.py      |  1 +
 .../test_mkldnn_matmul_op_output_fuse_pass.py |  4 +-
 19 files changed, 144 insertions(+), 34 deletions(-)

diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 319a3ea018d..00d667776ee 100755
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -615,6 +615,16 @@ void AnalysisConfig::EnableMkldnnBfloat16() {
   Update();
 }
 
+void AnalysisConfig::DisableMkldnnFcPasses() {
+#ifdef PADDLE_WITH_MKLDNN
+  disable_mkldnn_fc_passes_ = true;
+#else
+  LOG(ERROR) << "Please compile with MKLDNN first to use DisableMkldnnFcPasses";
+  disable_mkldnn_fc_passes_ = false;
+#endif
+  Update();
+}
+
 void AnalysisConfig::EnableMkldnnInt8(
     const std::unordered_set<std::string> &op_list) {
 #ifdef PADDLE_WITH_MKLDNN
@@ -892,6 +902,12 @@ void AnalysisConfig::Update() {
 #endif
   }
 
+  if (disable_mkldnn_fc_passes_) {
+#ifdef PADDLE_WITH_MKLDNN
+    pass_builder()->DisableMkldnnFcPasses();
+#endif
+  }
+
 #ifdef PADDLE_WITH_MKLDNN
   // Do not optimize when mkldnn is on
   if (enable_memory_optim_ && !use_mkldnn_) {
diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc
index 5cba8f06ab9..c75f1e4a569 100644
--- a/paddle/fluid/inference/api/analysis_predictor_tester.cc
+++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc
@@ -343,6 +343,45 @@ TEST(AnalysisPredictor, bf16_pass_strategy) {
   passStrategy.EnableMkldnnBfloat16();
 }
 
+TEST(AnalysisPredictor, mkldnn_fc_pass_strategy) {
+  std::vector<std::string> passes;
+  PassStrategy passStrategy(passes);
+  passStrategy.DisableMkldnnFcPasses();
+  ASSERT_EQ(passes.size(), (size_t)0);
+}
+
+#ifdef PADDLE_WITH_MKLDNN
+TEST(AnalysisPredictor, mkldnn_fc_passes_cpu_pass_strategy) {
+  CpuPassStrategy cpuPassStrategy;
+  cpuPassStrategy.EnableMKLDNN();
+  const std::vector<std::string> fc_passes_to_erase(
+      {"fc_mkldnn_pass",
+       "fc_act_mkldnn_fuse_pass",
+       "fc_elementwise_add_mkldnn_fuse_pass"});
+  for (const auto& pass : fc_passes_to_erase) {
+    ASSERT_NE(cpuPassStrategy.GetPassIndex(pass), (size_t)-1);
+  }
+  cpuPassStrategy.DisableMkldnnFcPasses();
+  for (const auto& pass : fc_passes_to_erase) {
+    ASSERT_EQ(cpuPassStrategy.GetPassIndex(pass), (size_t)-1);
+  }
+}
+#endif
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+TEST(AnalysisPredictor, mkldnn_fc_passes_gpu_pass_strategy) {
+  AnalysisConfig config;
+  config.EnableUseGpu(100, 0);
+  config.EnableMKLDNN();
+  config.DisableMkldnnFcPasses();
+#ifdef PADDLE_WITH_MKLDNN
+  ASSERT_TRUE(config.mkldnn_fc_passes_disabled());
+#else
+  ASSERT_FALSE(config.mkldnn_fc_passes_disabled());
+#endif
+}
+#endif
+
 #ifdef PADDLE_WITH_XPU
 TEST(AnalysisPredictor, set_xpu_device_id) {
   AnalysisConfig config;
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index 99e8ddac048..579321fd17b 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -820,6 +820,18 @@ struct PD_INFER_DECL AnalysisConfig {
   ///
   void EnableMkldnnBfloat16();
 
+  ///
+  /// \brief Turn off MKLDNN fc passes.
+  ///
+  void DisableMkldnnFcPasses();
+
+  ///
+  /// \brief A boolean state telling whether to disable the MKLDNN Fc passes.
+  ///
+  /// \return bool Whether to disable the MKLDNN Fc passes.
+  ///
+  bool mkldnn_fc_passes_disabled() const { return disable_mkldnn_fc_passes_; }
+
   ///
   /// \brief A boolean state telling whether to use the MKLDNN Bfloat16.
   ///
@@ -1137,6 +1149,8 @@ struct PD_INFER_DECL AnalysisConfig {
       "slice",
       "split"};
 
+  bool disable_mkldnn_fc_passes_{false};
+
   // ipu related.
   bool use_ipu_{false};
   int ipu_device_num_{1};
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 4c6ab76dbac..be5c4d1d88d 100755
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -265,6 +265,10 @@ void GpuPassStrategy::EnableMkldnnInt8() {
   LOG(ERROR) << "GPU not support MKL-DNN int8";
 }
 
+void GpuPassStrategy::DisableMkldnnFcPasses() {
+  LOG(ERROR) << "GPU not support MKL-DNN fc";
+}
+
 CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
   // NOTE the large fusions should be located in the front, so that they will
   // not be damaged by smaller ones.
@@ -333,8 +337,8 @@ void CpuPassStrategy::EnableMKLDNN() {
              "matmul_elementwise_add_mkldnn_fuse_pass",  //
              "matmul_activation_mkldnn_fuse_pass",       //
              // Disabled due to topology-dependent speed-up
-             // "fc_mkldnn_pass",
-             // "fc_act_mkldnn_fuse_pass",
+             "fc_mkldnn_pass",
+             "fc_act_mkldnn_fuse_pass",
              "fc_elementwise_add_mkldnn_fuse_pass",   //
              "batch_norm_act_fuse_pass",              //
              "softplus_activation_mkldnn_fuse_pass",  //
@@ -454,6 +458,30 @@ void CpuPassStrategy::EnableMkldnnInt8() {
 #endif
 }
 
+void CpuPassStrategy::DisableMkldnnFcPasses() {
+#ifdef PADDLE_WITH_MKLDNN
+  if (!disable_mkldnn_fc_passes_) {
+    EraseFcMkldnnPasses();
+  }
+  disable_mkldnn_fc_passes_ = true;
+#else
+  disable_mkldnn_fc_passes_ = false;
+#endif
+}
+
+void CpuPassStrategy::EraseFcMkldnnPasses() {
+  std::vector<std::string> fc_passes_to_erase(
+      {"fc_mkldnn_pass",
+       "fc_act_mkldnn_fuse_pass",
+       "fc_elementwise_add_mkldnn_fuse_pass"});
+  for (const auto &pass : fc_passes_to_erase) {
+    int idx = GetPassIndex(pass);
+    if (idx != -1) {
+      passes_.erase(std::begin(passes_) + idx);
+    }
+  }
+}
+
 IpuPassStrategy::IpuPassStrategy() : PassStrategy({}) {
   passes_.assign({"inference_process_pass"});
 }
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h
index cd973827853..1b81098470a 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.h
+++ b/paddle/fluid/inference/api/paddle_pass_builder.h
@@ -152,6 +152,9 @@ class PD_INFER_DECL PassStrategy : public PaddlePassBuilder {
   /// \brief Enable MKLDNN int8.
   virtual void EnableMkldnnInt8() {}
 
+  /// \brief Disable MKLDNN fc passes.
+  virtual void DisableMkldnnFcPasses() {}
+
   /// \brief Check if we are using gpu.
   /// \return A bool variable implying whether we are in gpu mode.
   bool use_gpu() const { return use_gpu_; }
@@ -205,6 +208,7 @@ class PD_INFER_DECL CpuPassStrategy : public PassStrategy {
     use_mkldnn_quantizer_ = other.use_mkldnn_quantizer_;
     use_mkldnn_bfloat16_ = other.use_mkldnn_bfloat16_;
     use_mkldnn_int8_ = other.use_mkldnn_int8_;
+    disable_mkldnn_fc_passes_ = other.disable_mkldnn_fc_passes_;
   }
   /// \brief Default destructor.
   virtual ~CpuPassStrategy() = default;
@@ -224,11 +228,18 @@ class PD_INFER_DECL CpuPassStrategy : public PassStrategy {
   /// \brief Enable MKLDNN int8.
   void EnableMkldnnInt8() override;
 
+  /// \brief Disable MKLDNN fc passes.
+  void DisableMkldnnFcPasses() override;
+
 protected:
+  /// \brief Erase MKLDNN fc passes.
+  void EraseFcMkldnnPasses();
+
   /// \cond Protected
   bool use_mkldnn_quantizer_{false};
   bool use_mkldnn_bfloat16_{false};
   bool use_mkldnn_int8_{false};
+  bool disable_mkldnn_fc_passes_{false};
   /// \endcond
 };
 
@@ -263,6 +274,9 @@ class PD_INFER_DECL GpuPassStrategy : public PassStrategy {
   /// \brief Not supported in GPU mode yet.
   void EnableMkldnnInt8() override;
 
+  /// \brief Disable MKLDNN fc passes.
+  void DisableMkldnnFcPasses() override;
+
   /// \brief Default destructor.
   virtual ~GpuPassStrategy() = default;
diff --git a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
index db1f2953c74..e7462786c40 100644
--- a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
@@ -218,9 +218,6 @@ AnalysisConfig SetConfig(bool use_mkldnn, bool use_bfloat16) {
 
   if (use_mkldnn) {
     config.EnableMKLDNN();
-    config.pass_builder()->AppendPass("fc_mkldnn_pass");
-    config.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
-    config.pass_builder()->AppendPass("fc_elementwise_add_mkldnn_fuse_pass");
   }
 
   if (use_bfloat16) config.EnableMkldnnBfloat16();
diff --git a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
index 1b8ed41e386..36a2dfcb715 100644
--- a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
@@ -214,8 +214,6 @@ void profile(bool use_mkldnn = false) {
     std::unordered_set<std::string> op_list = {
         "softmax", "elementwise_add", "relu", "fc"};
     cfg.SetMKLDNNOp(op_list);
-    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> outputs;
@@ -276,8 +274,6 @@ void compare(bool use_mkldnn = false) {
     std::unordered_set<std::string> op_list = {
         "softmax", "elementwise_add", "relu"};
     cfg.SetMKLDNNOp(op_list);
-    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
diff --git a/paddle/fluid/inference/tests/api/analyzer_ernie_int8_tester.cc b/paddle/fluid/inference/tests/api/analyzer_ernie_int8_tester.cc
index 26283dc34dc..43f1f8a1163 100644
--- a/paddle/fluid/inference/tests/api/analyzer_ernie_int8_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_ernie_int8_tester.cc
@@ -29,6 +29,7 @@ void SetInt8Config(AnalysisConfig *cfg,
                    std::vector<paddle::PaddleTensor> data) {
   cfg->SetModel(FLAGS_infer_model);
   cfg->EnableMKLDNN();
+  cfg->DisableMkldnnFcPasses();  // fc passes caused loss in accuracy
   cfg->EnableMkldnnQuantizer();
   auto pass_builder = cfg->pass_builder();
   pass_builder->DeletePass("constant_folding_pass");
diff --git a/paddle/fluid/inference/tests/api/analyzer_ernie_tester.cc b/paddle/fluid/inference/tests/api/analyzer_ernie_tester.cc
index 1efbe7cecdd..79a6c840ea3 100644
--- a/paddle/fluid/inference/tests/api/analyzer_ernie_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_ernie_tester.cc
@@ -92,6 +92,7 @@ void compare(bool use_mkldnn = false) {
   AnalysisConfig cfg;
   SetConfig(&cfg, use_mkldnn, false);
 
+  cfg.DisableMkldnnFcPasses();  // fc passes caused loss in accuracy
   auto pass_builder = cfg.pass_builder();
   pass_builder->DeletePass("constant_folding_pass");
   CompareNativeAndAnalysis(
diff --git a/paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc
index dc8921ef731..e25c78bd287 100644
--- a/paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_image_classification_tester.cc
@@ -51,9 +51,8 @@ void profile(bool use_mkldnn = false) {
 
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    if (!FLAGS_disable_mkldnn_fc) {
-      cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-      cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
+    if (FLAGS_disable_mkldnn_fc) {
+      cfg.DisableMkldnnFcPasses();
     }
   }
   std::vector<std::vector<PaddleTensor>> outputs;
@@ -88,9 +87,8 @@ void compare(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    if (!FLAGS_disable_mkldnn_fc) {
-      cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-      cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
+    if (FLAGS_disable_mkldnn_fc) {
+      cfg.DisableMkldnnFcPasses();
     }
   }
 
diff --git a/paddle/fluid/inference/tests/api/analyzer_lexical_analysis_gru_tester.cc b/paddle/fluid/inference/tests/api/analyzer_lexical_analysis_gru_tester.cc
index 89bfc6812a4..d8cd551dbd0 100644
--- a/paddle/fluid/inference/tests/api/analyzer_lexical_analysis_gru_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_lexical_analysis_gru_tester.cc
@@ -262,15 +262,19 @@ TEST(Analyzer_lexical_test, Analyzer_lexical_analysis) {
     if (FLAGS_enable_bf16) {
       analysis_cfg.EnableMkldnnBfloat16();
     } else if (FLAGS_enable_int8) {
-      if (FLAGS_fuse_multi_gru)
+      if (FLAGS_fuse_multi_gru) {
         analysis_cfg.pass_builder()->AppendPass("multi_gru_fuse_pass");
-
+      }
       std::shared_ptr<std::vector<PaddleTensor>> warmup_data =
           WarmupData(input_slots_all);
       analysis_cfg.EnableMkldnnQuantizer();
       analysis_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data);
       analysis_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(
           FLAGS_batch_size);
+    } else {
+      // if fp32 => disable mkldnn fc passes
+      // when passes are enabled dnnl error occurs for iterations==0
+      analysis_cfg.DisableMkldnnFcPasses();
     }
     std::vector<double> acc_analysis(3);
     acc_analysis = Lexical_Test(input_slots_all, &outputs, &analysis_cfg, true);
diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester_helper.h b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester_helper.h
index 85923658285..8386ac7445a 100644
--- a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester_helper.h
+++ b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester_helper.h
@@ -169,8 +169,6 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
   }
   if (use_mkldnn) {
     cfg->EnableMKLDNN();
-    cfg->pass_builder()->AppendPass("fc_mkldnn_pass");
-    cfg->pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
   // Enable seqpool_concat_fuse_pass, disabled by default since it takes much
   // time
diff --git a/paddle/fluid/inference/tests/api/analyzer_transformer_compare_tester.cc b/paddle/fluid/inference/tests/api/analyzer_transformer_compare_tester.cc
index 65306fd42ed..1d511309177 100644
--- a/paddle/fluid/inference/tests/api/analyzer_transformer_compare_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_transformer_compare_tester.cc
@@ -24,8 +24,6 @@ void compare(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
diff --git a/paddle/fluid/inference/tests/api/analyzer_transformer_profile_tester.cc b/paddle/fluid/inference/tests/api/analyzer_transformer_profile_tester.cc
index 0e3a23895f0..9cbba30f9d0 100644
--- a/paddle/fluid/inference/tests/api/analyzer_transformer_profile_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_transformer_profile_tester.cc
@@ -25,8 +25,6 @@ void profile(bool use_mkldnn = false) {
   std::vector<std::vector<PaddleTensor>> outputs;
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
diff --git a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
index 0ee37c81a48..0581eb614a4 100644
--- a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
@@ -88,8 +88,6 @@ void profile(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
   // cfg.pass_builder()->TurnOnDebug();
   std::vector<std::vector<PaddleTensor>> outputs;
@@ -142,8 +140,6 @@ void compare(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
diff --git a/paddle/fluid/inference/tests/api/analyzer_vit_ocr_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vit_ocr_tester.cc
index 8c7ed7ffa29..3870fde8b53 100644
--- a/paddle/fluid/inference/tests/api/analyzer_vit_ocr_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_vit_ocr_tester.cc
@@ -72,11 +72,6 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
   if (use_mkldnn) {
     cfg->EnableMKLDNN();
     cfg->SwitchIrOptim();
-
-    size_t insertingIndex = cfg->pass_builder()->GetPassIndex(
-        "fc_elementwise_add_mkldnn_fuse_pass");
-    cfg->pass_builder()->InsertPass(insertingIndex, "fc_act_mkldnn_fuse_pass");
-    cfg->pass_builder()->InsertPass(insertingIndex, "fc_mkldnn_pass");
   }
 }
 
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 60f1bfd9216..e076beb1c87 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -809,6 +809,22 @@ void BindAnalysisConfig(py::module *m) {
            py::arg("mkldnn_int8_enabled_op_types") =
                std::unordered_set<std::string>({}))
       .def("mkldnn_int8_enabled", &AnalysisConfig::mkldnn_int8_enabled)
+      .def("disable_mkldnn_fc_passes",
+           &AnalysisConfig::DisableMkldnnFcPasses,
+           R"DOC(
+           Disable Mkldnn FC
+           Args:
+               None.
+           Returns:
+               None.
+           Examples:
+               .. code-block:: python
+                   from paddle.inference import Config
+
+                   config = Config("")
+                   config.enable_mkldnn()
+                   config.disable_mkldnn_fc_passes()
+           )DOC")
 #endif
       .def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp)
       .def("set_model_buffer", &AnalysisConfig::SetModelBuffer)
diff --git a/python/paddle/fluid/contrib/slim/tests/quant2_int8_lstm_model.py b/python/paddle/fluid/contrib/slim/tests/quant2_int8_lstm_model.py
index 71bac0208e4..96cb22dc2e5 100644
--- a/python/paddle/fluid/contrib/slim/tests/quant2_int8_lstm_model.py
+++ b/python/paddle/fluid/contrib/slim/tests/quant2_int8_lstm_model.py
@@ -112,6 +112,7 @@ class TestLstmModelPTQ(unittest.TestCase):
         config.switch_use_feed_fetch_ops(True)
         config.switch_ir_optim(True)
         config.enable_mkldnn()
+        config.disable_mkldnn_fc_passes()  # fc passes caused dnnl error
        config.set_mkldnn_cache_capacity(mkldnn_cache_capacity)
         if enable_ptq:
             # This pass to work properly, must be added before fc_fuse_pass
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_op_output_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_op_output_fuse_pass.py
index 064847dd7a0..1991f3592fc 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_op_output_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_op_output_fuse_pass.py
@@ -40,7 +40,7 @@ class TestMKLDNNMatmulFuseOp(InferencePassTest):
             out = fluid.layers.reshape(
                 out, [0, 0, self.shape_y[0] * self.shape_y[2]]
             )
-            out = fluid.layers.fc(out, size=1)
+            out = fluid.layers.relu(out)
             return out
 
     def setUp(self):
@@ -109,7 +109,7 @@ class TestMKLDNNMatmulOpNotFusedBreakPattern(TestMKLDNNMatmulFuseOp):
             out = fluid.layers.reshape(
                 out, [0, 0, self.shape_y[0] * self.shape_y[2]]
            )
-            out = fluid.layers.fc(out, size=1)
+            out = fluid.layers.relu(out)
             return out
--
GitLab
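
A minimal usage sketch of the Python API this patch exposes (an illustration, not part of the patch itself): since enable_mkldnn() now turns on fc_mkldnn_pass and its fuse passes by default, disable_mkldnn_fc_passes() is the opt-out for models whose accuracy drops. The model directory and input shape below are hypothetical placeholders; on builds compiled without MKLDNN the call only logs an error and leaves the pass list unchanged.

    import numpy as np
    from paddle.inference import Config, create_predictor

    config = Config("./mobilenet_v2_model")  # hypothetical exported model dir
    config.enable_mkldnn()             # fc passes now run as part of this
    config.disable_mkldnn_fc_passes()  # opt out if fc fusion hurts accuracy

    predictor = create_predictor(config)
    name = predictor.get_input_names()[0]
    handle = predictor.get_input_handle(name)
    handle.copy_from_cpu(np.random.rand(1, 3, 224, 224).astype("float32"))
    predictor.run()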