Unverified commit 7e914386, authored by Paulina Gacek, committed by GitHub

Enable fc passes (#45704)

* Analysis API interface for disabling fc passes

* Unit tests corrected

* Python API added

* Test runs only when PADDLE_WITH_MKLDNN is defined

* Fc op changed to relu in matmul_op_test

* Disable fc passes in tests where accuracy drops

* Code formatting

* Unit test for AnalysisConfig added

* GPU unit test added

* Fc passes disabled when iterations=0 in gru test

* Style fixes

* Passes disabled for fp32 in gru test

* Fc passes disabled in lstm test

* Import from inference, not fluid in doc
Parent c919f6f7
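
For orientation, here is a minimal, hedged sketch of how the Python API added by this commit is meant to be used (the model path is a placeholder, and a Paddle build with MKLDNN support is assumed):

    # Sketch only: exercises the disable_mkldnn_fc_passes() binding added below.
    # "my_model_dir" is a hypothetical path, not part of this commit.
    from paddle.inference import Config

    config = Config("my_model_dir")    # directory of an exported inference model
    config.enable_mkldnn()             # turn on oneDNN (MKL-DNN) kernels
    config.disable_mkldnn_fc_passes()  # opt out of the fc passes enabled by default here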
@@ -615,6 +615,16 @@ void AnalysisConfig::EnableMkldnnBfloat16() {
   Update();
 }
 
+void AnalysisConfig::DisableMkldnnFcPasses() {
+#ifdef PADDLE_WITH_MKLDNN
+  disable_mkldnn_fc_passes_ = true;
+#else
+  LOG(ERROR) << "Please compile with MKLDNN first to use DisableMkldnnFcPasses";
+  disable_mkldnn_fc_passes_ = false;
+#endif
+  Update();
+}
+
 void AnalysisConfig::EnableMkldnnInt8(
     const std::unordered_set<std::string> &op_list) {
 #ifdef PADDLE_WITH_MKLDNN
@@ -892,6 +902,12 @@ void AnalysisConfig::Update() {
 #endif
   }
 
+  if (disable_mkldnn_fc_passes_) {
+#ifdef PADDLE_WITH_MKLDNN
+    pass_builder()->DisableMkldnnFcPasses();
+#endif
+  }
+
 #ifdef PADDLE_WITH_MKLDNN
   // Do not optimize when mkldnn is on
   if (enable_memory_optim_ && !use_mkldnn_) {
...
@@ -343,6 +343,45 @@ TEST(AnalysisPredictor, bf16_pass_strategy) {
   passStrategy.EnableMkldnnBfloat16();
 }
 
+TEST(AnalysisPredictor, mkldnn_fc_pass_strategy) {
+  std::vector<std::string> passes;
+  PassStrategy passStrategy(passes);
+  passStrategy.DisableMkldnnFcPasses();
+  ASSERT_EQ(passes.size(), (size_t)0);
+}
+
+#ifdef PADDLE_WITH_MKLDNN
+TEST(AnalysisPredictor, mkldnn_fc_passes_cpu_pass_strategy) {
+  CpuPassStrategy cpuPassStrategy;
+  cpuPassStrategy.EnableMKLDNN();
+  const std::vector<std::string> fc_passes_to_erase(
+      {"fc_mkldnn_pass",
+       "fc_act_mkldnn_fuse_pass",
+       "fc_elementwise_add_mkldnn_fuse_pass"});
+  for (const auto& pass : fc_passes_to_erase) {
+    ASSERT_NE(cpuPassStrategy.GetPassIndex(pass), (size_t)-1);
+  }
+  cpuPassStrategy.DisableMkldnnFcPasses();
+  for (const auto& pass : fc_passes_to_erase) {
+    ASSERT_EQ(cpuPassStrategy.GetPassIndex(pass), (size_t)-1);
+  }
+}
+#endif
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+TEST(AnalysisPredictor, mkldnn_fc_passes_gpu_pass_strategy) {
+  AnalysisConfig config;
+  config.EnableUseGpu(100, 0);
+  config.EnableMKLDNN();
+  config.DisableMkldnnFcPasses();
+#ifdef PADDLE_WITH_MKLDNN
+  ASSERT_TRUE(config.mkldnn_fc_passes_disabled());
+#else
+  ASSERT_FALSE(config.mkldnn_fc_passes_disabled());
+#endif
+}
+#endif
+
 #ifdef PADDLE_WITH_XPU
 TEST(AnalysisPredictor, set_xpu_device_id) {
   AnalysisConfig config;
...
@@ -820,6 +820,18 @@ struct PD_INFER_DECL AnalysisConfig {
   ///
   void EnableMkldnnBfloat16();
 
+  ///
+  /// \brief Turn off MKLDNN fc passes.
+  ///
+  void DisableMkldnnFcPasses();
+
+  ///
+  /// \brief A boolean state telling whether to disable the MKLDNN Fc passes.
+  ///
+  /// \return bool Whether to disable the MKLDNN Fc passes.
+  ///
+  bool mkldnn_fc_passes_disabled() const { return disable_mkldnn_fc_passes_; }
+
   ///
   /// \brief A boolean state telling whether to use the MKLDNN Bfloat16.
   ///
@@ -1137,6 +1149,8 @@ struct PD_INFER_DECL AnalysisConfig {
       "slice",
       "split"};
 
+  bool disable_mkldnn_fc_passes_{false};
+
   // ipu related.
   bool use_ipu_{false};
   int ipu_device_num_{1};
...
@@ -265,6 +265,10 @@ void GpuPassStrategy::EnableMkldnnInt8() {
   LOG(ERROR) << "GPU not support MKL-DNN int8";
 }
 
+void GpuPassStrategy::DisableMkldnnFcPasses() {
+  LOG(ERROR) << "GPU not support MKL-DNN fc";
+}
+
 CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
   // NOTE the large fusions should be located in the front, so that they will
   // not be damaged by smaller ones.
@@ -333,8 +337,8 @@ void CpuPassStrategy::EnableMKLDNN() {
          "matmul_elementwise_add_mkldnn_fuse_pass",  //
          "matmul_activation_mkldnn_fuse_pass",       //
          // Disabled due to topology-dependent speed-up
-         // "fc_mkldnn_pass",
-         // "fc_act_mkldnn_fuse_pass",
+         "fc_mkldnn_pass",
+         "fc_act_mkldnn_fuse_pass",
          "fc_elementwise_add_mkldnn_fuse_pass",      //
          "batch_norm_act_fuse_pass",                 //
          "softplus_activation_mkldnn_fuse_pass",     //
@@ -454,6 +458,30 @@ void CpuPassStrategy::EnableMkldnnInt8() {
 #endif
 }
 
+void CpuPassStrategy::DisableMkldnnFcPasses() {
+#ifdef PADDLE_WITH_MKLDNN
+  if (!disable_mkldnn_fc_passes_) {
+    EraseFcMkldnnPasses();
+  }
+  disable_mkldnn_fc_passes_ = true;
+#else
+  disable_mkldnn_fc_passes_ = false;
+#endif
+}
+
+void CpuPassStrategy::EraseFcMkldnnPasses() {
+  std::vector<std::string> fc_passes_to_erase(
+      {"fc_mkldnn_pass",
+       "fc_act_mkldnn_fuse_pass",
+       "fc_elementwise_add_mkldnn_fuse_pass"});
+  for (const auto &pass : fc_passes_to_erase) {
+    int idx = GetPassIndex(pass);
+    if (idx != -1) {
+      passes_.erase(std::begin(passes_) + idx);
+    }
+  }
+}
+
 IpuPassStrategy::IpuPassStrategy() : PassStrategy({}) {
   passes_.assign({"inference_process_pass"});
 }
...
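Aside: the effect of EraseFcMkldnnPasses above is simply "remove these three passes by name while preserving the order of the rest". A hedged restatement in plain Python (the `passes` list is illustrative, standing in for the C++ passes_ vector; it is not a Paddle API):

    # Illustrative restatement of CpuPassStrategy::EraseFcMkldnnPasses.
    FC_PASSES = (
        "fc_mkldnn_pass",
        "fc_act_mkldnn_fuse_pass",
        "fc_elementwise_add_mkldnn_fuse_pass",
    )

    def erase_fc_mkldnn_passes(passes):
        # Keep every pass that is not one of the fc-related oneDNN passes.
        return [p for p in passes if p not in FC_PASSES]

    # e.g. erase_fc_mkldnn_passes(["a", "fc_mkldnn_pass", "b"]) -> ["a", "b"]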
@@ -152,6 +152,9 @@ class PD_INFER_DECL PassStrategy : public PaddlePassBuilder {
   /// \brief Enable MKLDNN int8.
   virtual void EnableMkldnnInt8() {}
 
+  /// \brief Disable MKLDNN fc passes.
+  virtual void DisableMkldnnFcPasses() {}
+
   /// \brief Check if we are using gpu.
   /// \return A bool variable implying whether we are in gpu mode.
   bool use_gpu() const { return use_gpu_; }
@@ -205,6 +208,7 @@ class PD_INFER_DECL CpuPassStrategy : public PassStrategy {
     use_mkldnn_quantizer_ = other.use_mkldnn_quantizer_;
     use_mkldnn_bfloat16_ = other.use_mkldnn_bfloat16_;
     use_mkldnn_int8_ = other.use_mkldnn_int8_;
+    disable_mkldnn_fc_passes_ = other.disable_mkldnn_fc_passes_;
   }
 
   /// \brief Default destructor.
   virtual ~CpuPassStrategy() = default;
@@ -224,11 +228,18 @@ class PD_INFER_DECL CpuPassStrategy : public PassStrategy {
   /// \brief Enable MKLDNN int8.
   void EnableMkldnnInt8() override;
 
+  /// \brief Disable MKLDNN fc passes.
+  void DisableMkldnnFcPasses() override;
+
  protected:
+  /// \brief Erase MKLDNN fc passes.
+  void EraseFcMkldnnPasses();
+
   /// \cond Protected
   bool use_mkldnn_quantizer_{false};
   bool use_mkldnn_bfloat16_{false};
   bool use_mkldnn_int8_{false};
+  bool disable_mkldnn_fc_passes_{false};
   /// \endcond
 };
@@ -263,6 +274,9 @@ class PD_INFER_DECL GpuPassStrategy : public PassStrategy {
   /// \brief Not supported in GPU mode yet.
   void EnableMkldnnInt8() override;
 
+  /// \brief Disable MKLDNN fc passes.
+  void DisableMkldnnFcPasses() override;
+
   /// \brief Default destructor.
   virtual ~GpuPassStrategy() = default;
...
@@ -218,9 +218,6 @@ AnalysisConfig SetConfig(bool use_mkldnn, bool use_bfloat16) {
   if (use_mkldnn) {
     config.EnableMKLDNN();
-    config.pass_builder()->AppendPass("fc_mkldnn_pass");
-    config.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
-    config.pass_builder()->AppendPass("fc_elementwise_add_mkldnn_fuse_pass");
   }
 
   if (use_bfloat16) config.EnableMkldnnBfloat16();
...
@@ -214,8 +214,6 @@ void profile(bool use_mkldnn = false) {
     std::unordered_set<std::string> op_list = {
         "softmax", "elementwise_add", "relu", "fc"};
     cfg.SetMKLDNNOp(op_list);
-    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> outputs;
@@ -276,8 +274,6 @@ void compare(bool use_mkldnn = false) {
     std::unordered_set<std::string> op_list = {
         "softmax", "elementwise_add", "relu"};
     cfg.SetMKLDNNOp(op_list);
-    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
...
@@ -29,6 +29,7 @@ void SetInt8Config(AnalysisConfig *cfg,
                    std::vector<paddle::PaddleTensor> data) {
   cfg->SetModel(FLAGS_infer_model);
   cfg->EnableMKLDNN();
+  cfg->DisableMkldnnFcPasses();  // fc passes caused loss in accuracy
   cfg->EnableMkldnnQuantizer();
   auto pass_builder = cfg->pass_builder();
   pass_builder->DeletePass("constant_folding_pass");
...
@@ -92,6 +92,7 @@ void compare(bool use_mkldnn = false) {
   AnalysisConfig cfg;
   SetConfig(&cfg, use_mkldnn, false);
+  cfg.DisableMkldnnFcPasses();  // fc passes caused loss in accuracy
   auto pass_builder = cfg.pass_builder();
   pass_builder->DeletePass("constant_folding_pass");
   CompareNativeAndAnalysis(
...
@@ -51,9 +51,8 @@ void profile(bool use_mkldnn = false) {
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    if (!FLAGS_disable_mkldnn_fc) {
-      cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-      cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
+    if (FLAGS_disable_mkldnn_fc) {
+      cfg.DisableMkldnnFcPasses();
     }
   }
 
   std::vector<std::vector<PaddleTensor>> outputs;
@@ -88,9 +87,8 @@ void compare(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    if (!FLAGS_disable_mkldnn_fc) {
-      cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-      cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
+    if (FLAGS_disable_mkldnn_fc) {
+      cfg.DisableMkldnnFcPasses();
     }
   }
...
@@ -262,15 +262,19 @@ TEST(Analyzer_lexical_test, Analyzer_lexical_analysis) {
     if (FLAGS_enable_bf16) {
       analysis_cfg.EnableMkldnnBfloat16();
     } else if (FLAGS_enable_int8) {
-      if (FLAGS_fuse_multi_gru)
+      if (FLAGS_fuse_multi_gru) {
         analysis_cfg.pass_builder()->AppendPass("multi_gru_fuse_pass");
+      }
       std::shared_ptr<std::vector<PaddleTensor>> warmup_data =
           WarmupData(input_slots_all);
       analysis_cfg.EnableMkldnnQuantizer();
       analysis_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data);
       analysis_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(
           FLAGS_batch_size);
+    } else {
+      // if fp32 => disable mkldnn fc passes
+      // when passes are enabled dnnl error occurs for iterations==0
+      analysis_cfg.DisableMkldnnFcPasses();
     }
 
     std::vector<double> acc_analysis(3);
     acc_analysis = Lexical_Test(input_slots_all, &outputs, &analysis_cfg, true);
...
@@ -169,8 +169,6 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
   }
   if (use_mkldnn) {
     cfg->EnableMKLDNN();
-    cfg->pass_builder()->AppendPass("fc_mkldnn_pass");
-    cfg->pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
   // Enable seqpool_concat_fuse_pass, disabled by default since it takes much
   // time
...
@@ -24,8 +24,6 @@ void compare(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
...
@@ -25,8 +25,6 @@ void profile(bool use_mkldnn = false) {
   std::vector<std::vector<PaddleTensor>> outputs;
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
...
@@ -88,8 +88,6 @@ void profile(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
   // cfg.pass_builder()->TurnOnDebug();
   std::vector<std::vector<PaddleTensor>> outputs;
@@ -142,8 +140,6 @@ void compare(bool use_mkldnn = false) {
   SetConfig(&cfg);
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
-    cfg.pass_builder()->AppendPass("fc_mkldnn_pass");
-    cfg.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
   }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
...
@@ -72,11 +72,6 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
   if (use_mkldnn) {
     cfg->EnableMKLDNN();
     cfg->SwitchIrOptim();
-    size_t insertingIndex = cfg->pass_builder()->GetPassIndex(
-        "fc_elementwise_add_mkldnn_fuse_pass");
-    cfg->pass_builder()->InsertPass(insertingIndex, "fc_act_mkldnn_fuse_pass");
-    cfg->pass_builder()->InsertPass(insertingIndex, "fc_mkldnn_pass");
   }
 }
...
@@ -809,6 +809,22 @@ void BindAnalysisConfig(py::module *m) {
            py::arg("mkldnn_int8_enabled_op_types") =
                std::unordered_set<std::string>({}))
       .def("mkldnn_int8_enabled", &AnalysisConfig::mkldnn_int8_enabled)
+      .def("disable_mkldnn_fc_passes",
+           &AnalysisConfig::DisableMkldnnFcPasses,
+           R"DOC(
+           Disable Mkldnn FC
+           Args:
+               None.
+           Returns:
+               None.
+           Examples:
+               .. code-block:: python
+                   from paddle.inference import Config
+
+                   config = Config("")
+                   config.enable_mkldnn()
+                   config.disable_mkldnn_fc_passes()
+           )DOC")
 #endif
       .def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp)
       .def("set_model_buffer", &AnalysisConfig::SetModelBuffer)
...
@@ -112,6 +112,7 @@ class TestLstmModelPTQ(unittest.TestCase):
         config.switch_use_feed_fetch_ops(True)
         config.switch_ir_optim(True)
         config.enable_mkldnn()
+        config.disable_mkldnn_fc_passes()  # fc passes caused dnnl error
         config.set_mkldnn_cache_capacity(mkldnn_cache_capacity)
         if enable_ptq:
             # This pass to work properly, must be added before fc_fuse_pass
...
@@ -40,7 +40,7 @@ class TestMKLDNNMatmulFuseOp(InferencePassTest):
             out = fluid.layers.reshape(
                 out, [0, 0, self.shape_y[0] * self.shape_y[2]]
             )
-            out = fluid.layers.fc(out, size=1)
+            out = fluid.layers.relu(out)
             return out
 
     def setUp(self):
@@ -109,7 +109,7 @@ class TestMKLDNNMatmulOpNotFusedBreakPattern(TestMKLDNNMatmulFuseOp):
             out = fluid.layers.reshape(
                 out, [0, 0, self.shape_y[0] * self.shape_y[2]]
             )
-            out = fluid.layers.fc(out, size=1)
+            out = fluid.layers.relu(out)
             return out
...