From c11a104c780d5305f420313d669dbb05a7f8e48c Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Wed, 23 Sep 2020 05:24:05 +0200 Subject: [PATCH] Added support for quantization of fusion_gru --- .../framework/ir/graph_pattern_detector.cc | 23 ++++- .../framework/ir/graph_pattern_detector.h | 15 ++++ .../framework/ir/mkldnn/cpu_quantize_pass.cc | 79 ++++++++++++++++-- .../framework/ir/mkldnn/cpu_quantize_pass.h | 19 ++--- .../ir/mkldnn/cpu_quantize_pass_tester.cc | 83 ++++++++++++++++++- .../ir/mkldnn/cpu_quantize_squash_pass.cc | 10 ++- .../quantization/quant2_int8_mkldnn_pass.py | 36 ++++++++ .../fluid/contrib/slim/tests/CMakeLists.txt | 26 +++--- .../mkldnn/test_fusion_gru_int8_mkldnn_op.py | 20 +++-- tools/codestyle/clang_format.hook | 2 +- 10 files changed, 267 insertions(+), 46 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 9c1eaa99a3..4c66fbe777 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1882,9 +1882,9 @@ PDNode *patterns::MultipleQuantize::operator()() { PDNode *patterns::QuantizePlacement::operator()( const std::unordered_set &quantize_enabled_op_types) { std::unordered_set supported_op_types = - std::unordered_set({"concat", "conv2d", "elementwise_add", - "fc", "matmul", "pool2d", "prior_box", - "relu", "reshape2", "transpose2"}); + std::unordered_set( + {"concat", "conv2d", "elementwise_add", "fc", "matmul", "pool2d", + "prior_box", "relu", "reshape2", "transpose2", "fusion_gru"}); if (!quantize_enabled_op_types.empty()) { supported_op_types = quantize_enabled_op_types; } @@ -2281,6 +2281,23 @@ PDNode *patterns::MatmulTransposeReshapePattern::operator()() { return reshape_out; } +PDNode *patterns::FusionGru::operator()() { + auto op = pattern->NewNode(op_repr())->assert_is_op("fusion_gru"); + auto x = pattern->NewNode(x_repr())->AsInput()->assert_is_op_input( + "fusion_gru", "X"); + auto weight_h = pattern->NewNode(weight_h_repr()) + ->AsInput() + ->assert_is_op_input("fusion_gru", "WeightH"); + auto weight_x = pattern->NewNode(weight_x_repr()) + ->AsInput() + ->assert_is_op_input("fusion_gru", "WeightX"); + auto out = pattern->NewNode(out_repr()) + ->AsOutput() + ->assert_is_op_output("fusion_gru", "Hidden"); + op->LinksFrom({x, weight_h, weight_x}).LinksTo({out}); + return out; +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 053c1fe832..bc1669caea 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1302,6 +1302,21 @@ struct MatmulTransposeReshapePattern : public PatternBase { PATTERN_DECL_NODE(reshape_out_xshape); }; +// fusion_gru op +// Forward pass for fusion_gru. +// fusion_gru out is a result of the operator. +struct FusionGru : public PatternBase { + FusionGru(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "fusion_gru") {} + + PDNode* operator()(); + PATTERN_DECL_NODE(op); + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(weight_h); + PATTERN_DECL_NODE(weight_x); + PATTERN_DECL_NODE(out); +}; + } // namespace patterns // Link two ir::Nodes from each other. diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index aa0979b4be..e1fddf9ed2 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -64,8 +64,9 @@ enum { U8_MAX = 255, S8_MAX = 127 }; void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name, double scale_to_one, - bool is_unsigned, - std::string scale_attr_name) const { + bool is_input_unsigned, + std::string scale_attr_name, float shift, + std::string shift_attr_name) const { auto inputs = op->Op()->InputNames(); bool name_found = std::find(inputs.begin(), inputs.end(), input_name) != inputs.end(); @@ -73,7 +74,7 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input, platform::errors::InvalidArgument( "Var(%s) isn't the input of the %s operator.", input_name, op->Op()->Type())); - unsigned max = is_unsigned ? U8_MAX : S8_MAX; + unsigned max = is_input_unsigned ? U8_MAX : S8_MAX; float scale = scale_to_one * max; // Create quantize output variable @@ -87,7 +88,8 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input, q_desc.SetOutput("Output", std::vector({quantize_out_node->Name()})); q_desc.SetAttr("Scale", scale); - q_desc.SetAttr("is_negative_input", !is_unsigned); + q_desc.SetAttr("Shift", shift); + q_desc.SetAttr("is_negative_input", !is_input_unsigned); q_desc.SetAttr("output_format", Has("data_layout") ? Get("data_layout") : "NHWC"); @@ -104,11 +106,13 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input, IR_NODE_LINK_TO(quantize_out_node, op); if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale); + if (!shift_attr_name.empty()) op->Op()->SetAttr(shift_attr_name, shift); } void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name, - bool are_unsigned, - std::string scale_attr_name) const { + bool are_inputs_unsigned, + std::string scale_attr_name, float shift, + std::string shift_attr_name) const { auto inputs = op->inputs; auto output = op->outputs[0]; PADDLE_ENFORCE_GE(inputs.size(), 1, @@ -128,7 +132,7 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name, std::vector quantize_out_node_names(inputs.size()); double scale_out = GetScaleValueForNode(output); - unsigned max = are_unsigned ? U8_MAX : S8_MAX; + unsigned max = are_inputs_unsigned ? U8_MAX : S8_MAX; float scale = scale_out * max; for (size_t i = 0; i < inputs.size(); i++) { @@ -138,10 +142,11 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name, quantize_out_node_names[i] = quantize_out_nodes[i]->Name(); q_desc.SetAttr("Scale", scale); + q_desc.SetAttr("Shift", shift); q_desc.SetInput("Input", std::vector({inputs[i]->Name()})); q_desc.SetOutput("Output", std::vector({quantize_out_node_names[i]})); - q_desc.SetAttr("is_negative_input", !are_unsigned); + q_desc.SetAttr("is_negative_input", !are_inputs_unsigned); auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied. // link quantize op @@ -155,6 +160,7 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name, op->Op()->SetInput(input_name, quantize_out_node_names); if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale); + if (!shift_attr_name.empty()) op->Op()->SetAttr(shift_attr_name, shift); } void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output, @@ -783,6 +789,62 @@ void CPUQuantizePass::QuantizeElementwiseAdd(Graph* graph) const { quantize_elementwise_add_count); } +void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const { + GraphPatternDetector gpd; + patterns::FusionGru pattern{gpd.mutable_pattern(), name_scope_}; + pattern(); + + int quantize_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "Quantize fusion_gru op"; + GET_IR_NODE_FROM_SUBGRAPH(op, op, pattern); + + // skip if should not be quantized + if (!platform::HasOpINT8DataType(op->Op())) { + LogQuantizationDisabled(op); + return; + } + + GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern); + GET_IR_NODE_FROM_SUBGRAPH(weight_h, weight_h, pattern); + GET_IR_NODE_FROM_SUBGRAPH(weight_x, weight_x, pattern); + GET_IR_NODE_FROM_SUBGRAPH(out, out, pattern); + + if (!AreScalesPresentForNodes(op, {x, weight_h, weight_x})) { + LogCannotQuantizeOp(op); + return; + } + + bool is_x_unsigned{false}; + auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned); + + double input_x_shift{128.}; + if (is_x_unsigned) input_x_shift = 0.; + + QuantizeInput(g, op, x, "X", input_x_scale, is_x_unsigned, "Scale_data", + input_x_shift, "Shift_data"); + + auto weight_scale_tensor = GetScaleTensorForNode(weight_x); + EigenVectorArrayMap eigen_tensor{weight_scale_tensor.data(), + weight_scale_tensor.numel(), 1}; + eigen_tensor *= static_cast(S8_MAX); + std::vector scale_weights{ + weight_scale_tensor.data(), + weight_scale_tensor.data() + weight_scale_tensor.numel()}; + + op->Op()->SetAttr("Scale_weights", scale_weights); + // return fp32 data + op->Op()->SetAttr("force_fp32_output", true); + + ++quantize_count; + }; + gpd(graph, handler); + AddStatis(quantize_count); + + PrettyLogDetail("--- quantized %d fusion_gru ops", quantize_count); +} + void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Quantizing the graph."; PADDLE_ENFORCE_NOT_NULL( @@ -802,6 +864,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { QuantizeReshape(graph); QuantizeMatmul(graph); QuantizeElementwiseAdd(graph); + QuantizeFusionGru(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h index 21219e7dca..ff93bb824b 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h @@ -45,31 +45,26 @@ class CPUQuantizePass : public FusePassBase { void ApplyImpl(ir::Graph* graph) const override; void QuantizeConv(Graph* graph, bool with_residual_data = false) const; - void QuantizeFc(Graph* graph) const; - void QuantizePool(Graph* graph) const; - void QuantizeConcat(Graph* graph) const; - void QuantizePriorBox(Graph* graph) const; - void QuantizeTranspose(Graph* graph) const; - void QuantizeReshape(Graph* graph) const; - void QuantizeMatmul(Graph* graph) const; - void QuantizeElementwiseAdd(Graph* graph) const; + void QuantizeFusionGru(Graph* graph) const; void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name, - double scale_to_one, bool is_unsigned, - std::string scale_attr_name = "") const; + double scale_to_one, bool is_input_unsigned, + std::string scale_attr_name = "", float shift = 0.0, + std::string shift_attr_name = "") const; // quantize all inputs of given name with the same (minimum) scale void QuantizeInputs(Graph* g, Node* op, std::string input_name, - bool are_unsigned, - std::string scale_attr_name = "") const; + bool are_inputs_unsigned, + std::string scale_attr_name = "", float shift = 0.0, + std::string shift_attr_name = "") const; void DequantizeOutput(Graph* g, Node* op, Node* output, std::string output_name, double scale_to_one, diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc index a66e9f0e93..1fdceed19a 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h" #include +#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h" #include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/imperative/type_defs.h" @@ -91,6 +91,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetAttr("Scale_x", 1.0f); op->SetAttr("Scale_y", 1.0f); op->SetAttr("Scale_out", 1.0f); + } else if (type == "fusion_gru") { + op->SetInput("X", {inputs[0]}); + op->SetInput("Bias", {inputs[1]}); + op->SetInput("WeightX", {inputs[2]}); + op->SetInput("WeightH", {inputs[3]}); + op->SetOutput("Hidden", {outputs[0]}); + op->SetAttr("mkldnn_data_type", mkldnn_data_type); + op->SetAttr("Scale_data", 1.0f); + op->SetAttr("Shift_data", 0.0f); + op->SetAttr("Weight_scale", std::vector{1.0f}); } } @@ -389,6 +399,77 @@ TEST(CpuQuantizePass, transpose) { quant_count, dequant_count, added_nodes_count, 2.0f * 127); } +static const std::initializer_list variable_names_fusion_gru = { + "x", "wx", "wh", "b", "h"}; + +// x->Fusion_gru->h +ProgramDesc BuildProgramDescFusionGru() { + ProgramDesc prog; + for (auto& v : variable_names_transpose) { + auto* var = prog.MutableBlock(0)->Var(v); + if (v.find("wx") == 0 || v.find("wh") || v.find("b")) { + var->SetPersistable(true); + } + } + + SetOp(&prog, "fusion_gru", "Fusion_gru", {"x", "wx", "wh", "b"}, {"h"}, true, + "int8"); + + return prog; +} + +void MainTestFusionGru(const ProgramDesc& prog, int gru_count, int quant_count, + int dequant_count, int added_nodes_count, float scale, + float shift) { + std::unique_ptr graph(new ir::Graph(prog)); + int original_nodes_num, current_nodes_num; + PreparePass(&graph, prog, variable_names_fusion_gru, &original_nodes_num, + ¤t_nodes_num); + + int quantize_nodes_count = 0; + int dequantize_nodes_count = 0; + int gru_nodes_count = 0; + for (auto* node : graph->Nodes()) { + if (node->IsOp()) { + auto* op = node->Op(); + if (op->Type() == "fusion_gru") { + gru_nodes_count++; + + auto op_name = BOOST_GET_CONST(std::string, op->GetAttr("name")); + EXPECT_EQ(BOOST_GET_CONST(float, op->GetAttr("Scale_data")), scale) + << "Scale_data for node '" + op_name + "'."; + EXPECT_EQ(BOOST_GET_CONST(float, op->GetAttr("Shift_data")), shift) + << "Shift_data for node '" + op_name + "'."; + EXPECT_EQ(BOOST_GET_CONST(std::vector, + op->GetAttr("Scale_weights"))[0], + scale) + << "Scale_weights for node '" + op_name + "'."; + EXPECT_EQ(BOOST_GET_CONST(bool, op->GetAttr("force_fp32_output")), true) + << "force_fp32_output for node '" + op_name + "'."; + } else if (op->Type() == "quantize") { + quantize_nodes_count++; + } else if (op->Type() == "dequantize") { + dequantize_nodes_count++; + } + } + } + EXPECT_EQ(gru_nodes_count, gru_count); + EXPECT_EQ(quantize_nodes_count, quant_count); + EXPECT_EQ(dequantize_nodes_count, dequant_count); + EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num); +} + +TEST(CpuQuantizePass, fusion_gru) { + // x->Fusion_gru->h + int gru_count = 1; + int quant_count = 1; + int dequant_count = 0; + // 1 Quant + 1 IN + 0 DeQuant + 0 OUT + int added_nodes_count = 1 + 1 + 0 + 0; + MainTestFusionGru(BuildProgramDescFusionGru(), gru_count, quant_count, + dequant_count, added_nodes_count, 2. * 127, 128.); +} + static const std::initializer_list variable_names_reshape = { "a", "w1", "b", "c", "d", "e", "f"}; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index bc24c10d9d..a9bc46184b 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -73,6 +73,8 @@ void CPUQuantizeSquashPass::DequantQuantSquash( BOOST_GET_CONST(float, dequant_op->Op()->GetAttr("Scale")); float quant_scale = BOOST_GET_CONST(float, quant_op->Op()->GetAttr("Scale")); + float dequant_shift = dequant_op->Op()->GetAttrIfExists("Shift"); + float quant_shift = quant_op->Op()->GetAttrIfExists("Shift"); PADDLE_ENFORCE_NE( nodes_keep_counter->find(dequant_out), nodes_keep_counter->end(), platform::errors::NotFound("The dequant output node is not found.")); @@ -80,7 +82,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash( // check if dequantize op should be kept or removed, decrease the counter bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1; - if (dequant_scale == quant_scale) { + if (dequant_scale == quant_scale && dequant_shift == quant_shift) { // squash dequantize-quantize to nothing auto quant_out_var_name = quant_out->Name(); auto next_op_inputs = next_op_desc->InputNames(); @@ -107,7 +109,9 @@ void CPUQuantizeSquashPass::DequantQuantSquash( desc.SetInput("Input", std::vector({dequant_in->Name()})); desc.SetOutput("Output", std::vector({quant_out->Name()})); desc.SetAttr("Scale_in", dequant_scale); + desc.SetAttr("Shift_in", dequant_shift); desc.SetAttr("Scale_out", quant_scale); + desc.SetAttr("Shift_out", quant_shift); auto requant_op = g->CreateOpNode(&desc); @@ -290,6 +294,7 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { })); auto* first_quant_out = first_quant_op->outputs[0]; float scale = first_quant_op->Op()->GetAttrIfExists("Scale"); + float shift = first_quant_op->Op()->GetAttrIfExists("Shift"); PADDLE_ENFORCE_NE(scale, 0, platform::errors::InvalidArgument( @@ -299,7 +304,8 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const { auto quant_op = prev_out->outputs[iter]; if (quant_op->IsOp() && quant_op->Op()->Type() == "quantize" && quant_op->id() != first_quant_op->id() && - quant_op->Op()->GetAttrIfExists("Scale") == scale) { + quant_op->Op()->GetAttrIfExists("Scale") == scale && + quant_op->Op()->GetAttrIfExists("Shift") == shift) { auto quant_out = quant_op->outputs[0]; auto last_op = quant_out->outputs[0]; diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index dadc756c43..45df381b63 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -66,6 +66,7 @@ class Quant2Int8MkldnnPass(object): self._fc_ops = ['fc'] self._relu_ops = ['relu', 'relu6'] self._matmul_ops = ['matmul'] + self._gru_ops = ['fusion_gru'] self._weight_scales = {} # Collect the Input and Output sclaes from Fake quant models self._var_quant_scales = {} @@ -449,8 +450,43 @@ class Quant2Int8MkldnnPass(object): self._var_quant_scales[weight_var_name] = (use_unsigned_int, lod_tensor) + def _compute_gru_weight_scales(wx_name, wh_name): + for op in graph.all_op_nodes(): + if op.op().type() in self._gru_ops: + wx_var_name = op.input(wx_name)[0] + wh_var_name = op.input(wh_name)[0] + wx = np.array(self._load_param(self._scope, wx_var_name)) + wh = np.array(self._load_param(self._scope, wh_var_name)) + OC = wh.shape[0] + scale_ur = 1.0 / np.max(np.abs( + np.concatenate( + [ + wx[:, :2 * OC], wh.flatten()[:2 * OC * OC] + .reshape(OC, 2 * OC) + ], + axis=0)), + axis=0) + scale_o = 1.0 / np.max(np.abs( + np.concatenate( + [ + wx[:, 2 * OC:], wh.flatten()[2 * OC * OC:] + .reshape(OC, OC) + ], + axis=0)), + axis=0) + + gru_weights_scale = np.concatenate( + [scale_ur, scale_o]).astype('float') + + lod_tensor = self._convert_scale2tensor(gru_weights_scale) + use_unsigned_int = False + self._var_quant_scales[wx_var_name] = (use_unsigned_int, + lod_tensor) + _compute_var_scales(self._conv_ops, "Filter", axis=1) _compute_var_scales(self._fc_ops, "W", axis=0) + _compute_var_scales(self._gru_ops, "WeightH", axis=0) + _compute_gru_weight_scales("WeightX", "WeightH") return graph def _find_avg_pooling_ids(self, graph): diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index 6ac005060e..2d21d372e5 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -98,18 +98,16 @@ function(download_quant_model install_dir data_file) endif() endfunction() -function(save_quant_ic_model_test target quant_model_dir fp32_model_save_path int8_model_save_path) +function(save_quant_ic_model_test target quant_model_dir int8_model_save_path) py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py ARGS --quant_model_path ${quant_model_dir} - --fp32_model_save_path ${fp32_model_save_path} --int8_model_save_path ${int8_model_save_path} --debug) endfunction() -function(save_quant_nlp_model_test target quant_model_dir fp32_model_save_path int8_model_save_path ops_to_quantize) +function(save_quant_nlp_model_test target quant_model_dir int8_model_save_path ops_to_quantize) py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py ARGS --quant_model_path ${quant_model_dir} - --fp32_model_save_path ${fp32_model_save_path} --int8_model_save_path ${int8_model_save_path} --ops_to_quantize ${ops_to_quantize}) endfunction() @@ -227,8 +225,6 @@ if(LINUX AND WITH_MKLDNN) set(NLP_LABLES_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev") download_quant_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE}) - set(QUANT2_NLP_OPS_TO_QUANTIZE "fc,reshape2,transpose2,matmul,elementwise_add") - # Quant2 Ernie set(QUANT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz") set(QUANT2_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_quant2") @@ -236,17 +232,25 @@ if(LINUX AND WITH_MKLDNN) set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz") set(FP32_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_float") download_quant_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE}) - inference_quant2_int8_nlp_test(test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH} ${QUANT2_NLP_OPS_TO_QUANTIZE}) + set(QUANT2_ERNIE_OPS_TO_QUANTIZE "fc,reshape2,transpose2,matmul,elementwise_add") + inference_quant2_int8_nlp_test(test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH} ${QUANT2_ERNIE_OPS_TO_QUANTIZE}) + + # Quant2 GRU + set(QUANT2_GRU_MODEL_ARCHIVE "GRU_quant_acc.tar.gz") + set(QUANT2_GRU_MODEL_DIR "${QUANT_INSTALL_DIR}/GRU_quant2") + download_quant_model(${QUANT2_GRU_MODEL_DIR} ${QUANT2_GRU_MODEL_ARCHIVE}) + set(QUANT2_GRU_OPS_TO_QUANTIZE "fusion_gru") ### Save FP32 model or INT8 model from Quant model set(QUANT2_INT8_RESNET50_SAVE_PATH "${QUANT_INSTALL_DIR}/ResNet50_quant2_int8") - set(QUANT2_FP32_RESNET50_SAVE_PATH "${QUANT_INSTALL_DIR}/ResNet50_quant2_fp32") - save_quant_ic_model_test(save_quant2_model_resnet50 ${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QUANT2_FP32_RESNET50_SAVE_PATH} ${QUANT2_INT8_RESNET50_SAVE_PATH}) + save_quant_ic_model_test(save_quant2_model_resnet50 ${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QUANT2_INT8_RESNET50_SAVE_PATH}) set(QUANT2_INT8_ERNIE_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8") - set(QUANT2_FP32_ERNIE_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_fp32") - save_quant_nlp_model_test(save_quant2_model_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QUANT2_FP32_ERNIE_SAVE_PATH} ${QUANT2_INT8_ERNIE_SAVE_PATH} ${QUANT2_NLP_OPS_TO_QUANTIZE}) + save_quant_nlp_model_test(save_quant2_model_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QUANT2_INT8_ERNIE_SAVE_PATH} ${QUANT2_ERNIE_OPS_TO_QUANTIZE}) + + set(QUANT2_INT8_GRU_SAVE_PATH "${QUANT_INSTALL_DIR}/GRU_quant2_int8") + save_quant_nlp_model_test(save_quant2_model_gru ${QUANT2_GRU_MODEL_DIR}/GRU_quant_acc ${QUANT2_INT8_GRU_SAVE_PATH} ${QUANT2_GRU_OPS_TO_QUANTIZE}) # Convert Quant2 model to dot and pdf files set(QUANT2_INT8_ERNIE_DOT_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8_dot_file") diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_int8_mkldnn_op.py index ff4531f0e2..89343c9fae 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_int8_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_int8_mkldnn_op.py @@ -45,9 +45,10 @@ class TestFusionGRUINT8MKLDNNOp(OpTest): # Input data x_f32 = np.random.rand(T, self.IC).astype('float32') * 2 - 1 - scale_data = 63 - shift_data = 64 - x_u8 = (x_f32 * scale_data + shift_data).astype(np.uint8) + scale_data = 63.0 + shift_data = 64.0 + x_u8 = np.rint(x_f32 * scale_data + shift_data).astype(np.uint8) + # x_u8 = (x_f32 * scale_data + shift_data).astype(np.uint8) # WeightX/WeightH data wx = np.random.rand(self.IC, 3 * self.OC).astype('float32') * 2 - 1 @@ -58,22 +59,23 @@ class TestFusionGRUINT8MKLDNNOp(OpTest): # WeightX data shape in PP: [IC, 3 * OC] # WeightH data shape in PP: [OC, 2 * OC] + [OC, OC] # Scales shape in oneDNN: [3, OC] - scale_ur = 63 / np.max(np.abs( + s8_max = 127.0 + scale_ur = s8_max / np.max(np.abs( np.concatenate( [ wx[:, :2 * self.OC], wh.flatten()[:2 * self.OC * self.OC] .reshape(self.OC, 2 * self.OC) ], axis=0)), - axis=0) - scale_o = 63 / np.max(np.abs( + axis=0) + scale_o = s8_max / np.max(np.abs( np.concatenate( [ wx[:, 2 * self.OC:], wh.flatten()[2 * self.OC * self.OC:] .reshape(self.OC, self.OC) ], axis=0)), - axis=0) + axis=0) scale_weights = np.concatenate([scale_ur, scale_o]).astype('float') @@ -102,7 +104,9 @@ class TestFusionGRUINT8MKLDNNOp(OpTest): self.outputs = {'Hidden': (hidden_f32, self.lod)} else: self.error_margin = 1 - hidden_u8 = (hidden_f32 * scale_data + shift_data).astype(np.uint8) + hidden_u8 = np.rint(hidden_f32 * scale_data + shift_data).astype( + np.uint8) + # hidden_u8 = (hidden_f32 * scale_data + shift_data).astype(np.uint8) self.outputs = {'Hidden': (hidden_u8, self.lod)} self.attrs = { diff --git a/tools/codestyle/clang_format.hook b/tools/codestyle/clang_format.hook index 1d92821686..d646e52c43 100755 --- a/tools/codestyle/clang_format.hook +++ b/tools/codestyle/clang_format.hook @@ -1,7 +1,7 @@ #!/bin/bash set -e -readonly VERSION="3.8" +readonly VERSION="3.9" version=$(clang-format -version) -- GitLab