diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc index e1260f62ddb6499abf1794af386045bf0565c4b3..9a43edf40ef443b370b679522cc04fcaf722e032 100644 --- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc @@ -17,7 +17,7 @@ #include #include "paddle/fluid/framework/op_version_registry.h" - +#include "paddle/fluid/string/pretty_log.h" namespace paddle { namespace framework { class Scope; @@ -335,6 +335,9 @@ void FCGRUFusePass::ApplyImpl(ir::Graph* graph) const { graph, name_scope_, param_scope(), true /*with_fc_bias*/); AddStatis(fusion_count); + + string::PrettyLogDetail("--- fused %d pairs of fc gru patterns", + fusion_count); } } // namespace ir diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index 35704f1f3309e1a91b18d7a2c30ee7dda3b57e51..2e6ce1a0f73818a7f104bbef13220b58b72bd72f 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -16,6 +16,7 @@ #include #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/string/pretty_log.h" namespace paddle { namespace framework { @@ -348,6 +349,9 @@ void FCLstmFusePass::ApplyImpl(ir::Graph* graph) const { BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/); AddStatis(fusion_count); + + string::PrettyLogDetail("--- fused %d pairs of fc lstm patterns", + fusion_count); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc index 6fcea6a66cc5d18729d000131a87687ea7c9ed6c..b6a8de263aa2afb3934226e925961be1592f38dd 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc @@ -97,6 +97,19 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetInput("WeightX", {inputs[2]}); op->SetInput("WeightH", {inputs[3]}); op->SetOutput("Hidden", {outputs[0]}); + op->SetAttr("mkldnn_data_type", mkldnn_data_type); + op->SetAttr("Scale_data", 1.0f); + op->SetAttr("Shift_data", 0.0f); + op->SetAttr("Weight_scale", std::vector{1.0f}); + } else if (type == "fusion_lstm") { + op->SetInput("X", {inputs[0]}); + op->SetInput("Bias", {inputs[1]}); + op->SetInput("WeightX", {inputs[2]}); + op->SetInput("WeightH", {inputs[3]}); + + op->SetOutput("Hidden", {outputs[0]}); + op->SetOutput("Cell", {outputs[1]}); + op->SetAttr("mkldnn_data_type", mkldnn_data_type); op->SetAttr("Scale_data", 1.0f); op->SetAttr("Shift_data", 0.0f); @@ -418,6 +431,25 @@ ProgramDesc BuildProgramDescFusionGru() { return prog; } +static const std::initializer_list variable_names_fusion_lstm = { + "x", "wx", "wh", "b", "h", "c"}; + +// (x, wx, wh, b)->Fusion_lstm_1->h +ProgramDesc BuildProgramDescFusionLSTM() { + ProgramDesc prog; + for (auto& v : variable_names_transpose) { + auto* var = prog.MutableBlock(0)->Var(v); + if (v.find("wx") == 0 || v.find("wh") || v.find("b")) { + var->SetPersistable(true); + } + } + + SetOp(&prog, "fusion_lstm", "Fusion_lstm_1", {"x", "wx", "wh", "b"}, + {"h", "c"}, true, "int8"); + + return prog; +} + void MainTestFusionGru(const ProgramDesc& prog, int gru_count, int quant_count, int dequant_count, int added_nodes_count, float scale, float shift) { @@ -470,6 +502,59 @@ TEST(CpuQuantizePass, fusion_gru) { dequant_count, added_nodes_count, 2. * 127, 128.); } +void MainTestFusionLSTM(const ProgramDesc& prog, int expect_lstm_count, + int quant_count, int dequant_count, + int added_nodes_count, float scale, float shift) { + std::unique_ptr graph(new ir::Graph(prog)); + int original_nodes_num, current_nodes_num; + PreparePass(&graph, prog, variable_names_fusion_lstm, &original_nodes_num, + ¤t_nodes_num); + + int quantize_nodes_count = 0; + int dequantize_nodes_count = 0; + int lstm_nodes_count = 0; + for (auto* node : graph->Nodes()) { + if (node->IsOp()) { + auto* op = node->Op(); + if (op->Type() == "fusion_lstm") { + lstm_nodes_count++; + + auto op_name = BOOST_GET_CONST(std::string, op->GetAttr("name")); + EXPECT_EQ(BOOST_GET_CONST(float, op->GetAttr("Scale_data")), scale) + << "Scale_data for node '" + op_name + "'."; + EXPECT_EQ(BOOST_GET_CONST(float, op->GetAttr("Shift_data")), shift) + << "Shift_data for node '" + op_name + "'."; + EXPECT_EQ(BOOST_GET_CONST(std::vector, + op->GetAttr("Scale_weights"))[0], + scale) + << "Scale_weights for node '" + op_name + "'."; + EXPECT_EQ(BOOST_GET_CONST(bool, op->GetAttr("force_fp32_output")), true) + << "force_fp32_output for node '" + op_name + "'."; + } else if (op->Type() == "quantize") { + quantize_nodes_count++; + } else if (op->Type() == "dequantize") { + dequantize_nodes_count++; + } + } + } + EXPECT_EQ(lstm_nodes_count, expect_lstm_count); + EXPECT_EQ(quantize_nodes_count, quant_count); + EXPECT_EQ(dequantize_nodes_count, dequant_count); + EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num); +} + +TEST(CpuQuantizePass, fusion_lstm) { + // (x, wx, wh, b)->Fusion_lstm->h + int expect_lstm_count = 1; + int expect_quant_count = 1; + int dequant_count = 0; + // 1 Quant + 1 IN + 0 DeQuant + 0 OUT + int added_nodes_count = 1 + 1 + 0 + 0; + MainTestFusionLSTM(BuildProgramDescFusionLSTM(), expect_lstm_count, + expect_quant_count, dequant_count, added_nodes_count, + 2. * 127, 128.); +} + const std::vector churn_out_vars(ProgramDesc* prog, const std::string& prefix, int number) { diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index 9a8835ba4c8043aa63650d50718cf869d9edfb01..112623d23a65f2cd6e2747e24f3fb72c9d9b5cf3 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -71,6 +71,7 @@ class Quant2Int8MkldnnPass(object): self._relu_ops = ['relu', 'relu6'] self._matmul_ops = ['matmul'] self._gru_ops = ['fusion_gru', 'multi_gru'] + self._lstm_ops = ['fusion_lstm'] self._weight_thresholds = {} # Collect the Input and Output sclaes from Fake quant models self._var_quant_scales = {} @@ -535,10 +536,38 @@ class Quant2Int8MkldnnPass(object): self._var_quant_scales[wx_var_name] = (use_unsigned_int, lod_tensor) + def _compute_single_lstm_weight_scales(wx_var_name, wh_var_name): + wx = np.array(self._load_param(self._scope, wx_var_name)) + wh = np.array(self._load_param(self._scope, wh_var_name)) + + lstm_weights_scale = 1.0 / np.max( + np.abs(np.concatenate( + [wx[:, :], wh[:, :]], axis=0)), axis=0) + lstm_weights_scale = lstm_weights_scale.astype('float') + + return self._convert_scale2tensor(lstm_weights_scale) + + def _compute_lstm_weight_scales(wx_name, wh_name): + for op in graph.all_op_nodes(): + if op.op().type() in self._lstm_ops: + assert len(op.input(wx_name)) == len( + op.input(wh_name) + ), 'Mismatch in number of weights inputs ({} for WeightX vs. {} for WeightH).'.format( + len(op.input(wx_name)), len(op.input(wh_name))) + for i, wx_var_name in enumerate(op.input(wx_name)): + wh_var_name = op.input(wh_name)[i] + use_unsigned_int = False + lod_tensor = _compute_single_lstm_weight_scales( + wx_var_name, wh_var_name) + self._var_quant_scales[wx_var_name] = (use_unsigned_int, + lod_tensor) + _compute_var_scales(self._conv_ops, "Filter", axis=1) _compute_var_scales(self._fc_ops, "W", axis=0) _compute_var_scales(self._gru_ops, "WeightH", axis=0) + _compute_var_scales(self._lstm_ops, "WeightH", axis=0) _compute_gru_weight_scales("WeightX", "WeightH") + _compute_lstm_weight_scales("WeightX", "WeightH") return graph def _find_avg_pooling_ids(self, graph): diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index 4ea1c04f89d30797278132d802c7c54445f2a02c..e55db665052cec0176e4e070f2bcb06190fabde7 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -265,6 +265,12 @@ if(LINUX AND WITH_MKLDNN) download_quant_model(${QUANT2_GRU_MODEL_DIR} ${QUANT2_GRU_MODEL_ARCHIVE} cf207f8076dcfb8b74d8b6bdddf9090c) set(QUANT2_GRU_OPS_TO_QUANTIZE "multi_gru") + # Quant2 LSTM + set(QUANT2_LSTM_MODEL_ARCHIVE "lstm_quant.tar.gz") + set(QUANT2_LSTM_MODEL_DIR "${QUANT_INSTALL_DIR}/lstm_quant_test") + download_quant_model(${QUANT2_LSTM_MODEL_DIR} ${QUANT2_LSTM_MODEL_ARCHIVE} 40a693803b12ee9e251258f32559abcb) + set(QUANT2_LSTM_OPS_TO_QUANTIZE "fusion_lstm") + ### Save FP32 model or INT8 model from Quant model set(QUANT2_INT8_RESNET50_SAVE_PATH "${QUANT_INSTALL_DIR}/ResNet50_quant2_int8") @@ -276,6 +282,9 @@ if(LINUX AND WITH_MKLDNN) set(QUANT2_INT8_GRU_SAVE_PATH "${QUANT_INSTALL_DIR}/GRU_quant2_int8") save_quant_nlp_model_test(save_quant2_model_gru ${QUANT2_GRU_MODEL_DIR}/GRU_quant_acc ${QUANT2_INT8_GRU_SAVE_PATH} ${QUANT2_GRU_OPS_TO_QUANTIZE}) + set(QUANT2_INT8_LSTM_SAVE_PATH "${QUANT_INSTALL_DIR}/lstm_quant2_int8") + save_quant_nlp_model_test(save_quant2_model_lstm ${QUANT2_LSTM_MODEL_DIR}/lstm_quant ${QUANT2_INT8_LSTM_SAVE_PATH} ${QUANT2_LSTM_OPS_TO_QUANTIZE}) + # Convert Quant2 model to dot and pdf files set(QUANT2_INT8_ERNIE_DOT_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8_dot_file") convert_model2dot_test(convert_model2dot_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QUANT2_INT8_ERNIE_DOT_SAVE_PATH} "Ernie_quant2_int8")