diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index e59d45db8cc622f9eeb507ee7a01986bc5db1e86..e1b77a59911fbe06e0d829e36715be89ef1f656c 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2249,9 +2249,10 @@ PDNode *patterns::MultipleQuantize::operator()() { PDNode *patterns::QuantizePlacement::operator()( const std::unordered_set &quantize_enabled_op_types) { std::unordered_set supported_op_types = - std::unordered_set( - {"concat", "conv2d", "elementwise_add", "fc", "matmul", "pool2d", - "prior_box", "reshape2", "transpose2", "fusion_gru", "multi_gru"}); + std::unordered_set({"concat", "conv2d", "elementwise_add", + "fc", "matmul", "pool2d", "prior_box", + "reshape2", "transpose2", "fusion_gru", + "fusion_lstm", "multi_gru"}); if (!quantize_enabled_op_types.empty()) { supported_op_types = quantize_enabled_op_types; } @@ -2723,6 +2724,26 @@ PDNode *patterns::FusionGru::operator()() { return out; } +PDNode *patterns::FusionLSTM::operator()() { + auto op = pattern->NewNode(op_repr())->assert_is_op("fusion_lstm"); + auto x = pattern->NewNode(x_repr())->AsInput()->assert_is_op_input( + "fusion_lstm", "X"); + auto weight_h = pattern->NewNode(weight_h_repr()) + ->AsInput() + ->assert_is_op_input("fusion_lstm", "WeightH"); + auto weight_x = pattern->NewNode(weight_x_repr()) + ->AsInput() + ->assert_is_op_input("fusion_lstm", "WeightX"); + auto hidden = pattern->NewNode(hidden_repr()) + ->AsOutput() + ->assert_is_op_output("fusion_lstm", "Hidden"); + auto cell = pattern->NewNode(cell_repr()) + ->AsOutput() + ->assert_is_op_output("fusion_lstm", "Cell"); + op->LinksFrom({x, weight_h, weight_x}).LinksTo({hidden, cell}); + return hidden; +} + PDNode *patterns::TwoFusionGruConcat::operator()() { auto x = pattern->NewNode(x_repr())->AsInput()->assert_is_op_input( "fusion_gru", "X"); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 13f65859954d58ce446ab3b9de488833f6220dee..3cfaa4661ae68e0359245a841aa40caf00329aff 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1562,6 +1562,28 @@ struct FusionGru : public PatternBase { PATTERN_DECL_NODE(out); }; +// fusion_lstm op +// Forward pass for fusion_lstm. +// fusion_lstm out is a result of the operator. +struct FusionLSTM : public PatternBase { + FusionLSTM(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "fusion_lstm") {} + // TODO(lidanqing): Is it enough to detect fusion_lstm with these things + PDNode* operator()(); + + // declare op + PATTERN_DECL_NODE(op); + + // declate inputs + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(weight_h); + PATTERN_DECL_NODE(weight_x); + + // decalre outputs + PATTERN_DECL_NODE(hidden); + PATTERN_DECL_NODE(cell); +}; + // two concatenated fusion_gru ops // Forward pass for fusion of two concatenated fusion_gru ops. // concat_out is a result of the operator(). diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index 0abee33b2942ada95591c18110d79c0b755fe8ba..2bf8a3b64f0a78039d896c141aaad7e2f3e1fe2c 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -944,6 +944,64 @@ void CPUQuantizePass::QuantizeMultiGru(Graph* graph) const { PrettyLogDetail("--- quantized %d multi_gru ops", quantize_count); } +void CPUQuantizePass::QuantizeFusionLSTM(Graph* graph) const { + GraphPatternDetector gpd; + patterns::FusionLSTM pattern{gpd.mutable_pattern(), name_scope_}; + pattern(); + + int quantize_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "Quantize fusion_lstm op"; + GET_IR_NODE_FROM_SUBGRAPH(op, op, pattern); + + // skip if should not be quantized + if (!platform::HasOpINT8DataType(op->Op())) { + LogQuantizationDisabled(op); + return; + } + + GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern); + GET_IR_NODE_FROM_SUBGRAPH(weight_h, weight_h, pattern); + GET_IR_NODE_FROM_SUBGRAPH(weight_x, weight_x, pattern); + GET_IR_NODE_FROM_SUBGRAPH(hidden, hidden, pattern); + GET_IR_NODE_FROM_SUBGRAPH(cell, cell, pattern); + + // Starting from here there maybe issues + if (!AreScalesPresentForNodes({x, weight_x})) { + LogCannotQuantizeOp(op); + return; + } + + bool is_x_unsigned{false}; + auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned); + + double input_x_shift{128.}; + if (is_x_unsigned) input_x_shift = 0.; + + QuantizeInput(g, op, x, "X", input_x_scale, is_x_unsigned, "Scale_data", + input_x_shift, "Shift_data"); + + auto weight_scale_tensor = GetScaleTensorForNode(weight_x); + EigenVectorArrayMap eigen_tensor{weight_scale_tensor.data(), + weight_scale_tensor.numel()}; + eigen_tensor *= static_cast(S8_MAX); + std::vector scale_weights{ + weight_scale_tensor.data(), + weight_scale_tensor.data() + weight_scale_tensor.numel()}; + + op->Op()->SetAttr("Scale_weights", scale_weights); + // return fp32 data + op->Op()->SetAttr("force_fp32_output", true); + + ++quantize_count; + }; + gpd(graph, handler); + AddStatis(quantize_count); + + PrettyLogDetail("--- quantized %d fusion_lstm ops", quantize_count); +} + void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Quantizing the graph."; PADDLE_ENFORCE_NOT_NULL( @@ -965,6 +1023,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { QuantizeElementwiseAdd(graph); QuantizeFusionGru(graph); QuantizeMultiGru(graph); + QuantizeFusionLSTM(graph); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h index 896b31c154710cf25645e66e45248771fa31f2ba..18735633c0d69a231c2873ddb3ce73112ea5cdae 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h @@ -60,6 +60,7 @@ class CPUQuantizePass : public FusePassBase { void QuantizeElementwiseAdd(Graph* graph) const; void QuantizeFusionGru(Graph* graph) const; void QuantizeMultiGru(Graph* graph) const; + void QuantizeFusionLSTM(Graph* graph) const; void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name, double scale_to_one, bool is_input_unsigned, diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc index b3768dda24c07be4da4c011501dd9f940c5431a2..d0e47eea16f14076293152a5992566f92f4d8b35 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc @@ -61,8 +61,8 @@ static void check_tensor(const LoDTensor& tensor) { "Tensor dimension is empty.")); } -void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForGRUWeights( - const paddle::framework::OpDesc* op) { +void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForRNNWeights( + const paddle::framework::OpDesc* op, bool gru) { const auto& wx_names = op->Input("WeightX"); const auto& wh_names = op->Input("WeightH"); for (size_t i = 0; i < wx_names.size(); ++i) { @@ -74,14 +74,20 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForGRUWeights( check_var(wh_var, wh_name); LoDTensor* wx_tensor = wx_var->GetMutable(); LoDTensor* wh_tensor = wh_var->GetMutable(); - scales_[wx_name] = GetMaxChGRUScalingFactor(*wx_tensor, *wh_tensor); + if (gru) { + scales_[wx_name] = GetMaxChGRUScalingFactor(*wx_tensor, *wh_tensor); + } else { + scales_[wx_name] = GetMaxChLSTMScalingFactor(*wx_tensor, *wh_tensor); + } } } void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpInputs( const paddle::framework::OpDesc* op) { if (op->Type() == "fusion_gru" || op->Type() == "multi_gru") { - CalculateScalesForGRUWeights(op); + CalculateScalesForRNNWeights(op, true); + } else if (op->Type() == "fusion_lstm") { + CalculateScalesForRNNWeights(op, false); } for (auto const& input : op->Inputs()) { for (const auto& var_name : input.second) { @@ -464,6 +470,41 @@ AnalysisPredictor::MkldnnQuantizer::GetMaxChGRUScalingFactor( return std::make_pair(is_unsigned, scale_tensor); } +std::pair +AnalysisPredictor::MkldnnQuantizer::GetMaxChLSTMScalingFactor( + const LoDTensor& wx_tensor, const LoDTensor& wh_tensor) const { + check_tensor(wx_tensor); + check_tensor(wh_tensor); + + std::vector scale(wx_tensor.dims()[1]); + + for (int row_id = 0; row_id < wx_tensor.dims()[0]; row_id++) { + for (int col_id = 0; col_id < wx_tensor.dims()[1]; col_id++) { + int idx = (row_id * wx_tensor.dims()[1]) + col_id; + auto abs_value = std::abs(wx_tensor.data()[idx]); + if (row_id == 0) { + scale[col_id] = abs_value; + } else { + if (abs_value > scale[col_id]) scale[col_id] = abs_value; + } + } + } + for (int row_id = 0; row_id < wh_tensor.dims()[0]; row_id++) { + for (int col_id = 0; col_id < wh_tensor.dims()[1]; col_id++) { + int idx = (row_id * wh_tensor.dims()[1]) + col_id; + auto abs_value = std::abs(wh_tensor.data()[idx]); + if (abs_value > scale[col_id]) scale[col_id] = abs_value; + } + } + transform(scale.begin(), scale.end(), scale.begin(), + [](float& c) { return 1 / c; }); + LoDTensor scale_tensor = CreateScaleTensor(scale.size()); + auto* scale_ptr = scale_tensor.mutable_data(CPUPlace()); + std::copy(scale.begin(), scale.end(), scale_ptr); + bool is_unsigned = false; + return std::make_pair(is_unsigned, scale_tensor); +} + std::pair, float> AnalysisPredictor::MkldnnQuantizer::Histogram( const framework::LoDTensor& var_tensor, float min_val, float max_val, diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.h b/paddle/fluid/inference/api/mkldnn_quantizer.h index c41b9d08f676227ef8330880fc01b74a86aa05f1..5e7aa39de52bc74d424c53fc593452e56bd7e6ba 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer.h +++ b/paddle/fluid/inference/api/mkldnn_quantizer.h @@ -69,7 +69,8 @@ class AnalysisPredictor::MkldnnQuantizer { bool is_unsigned); void CalculateSingleGRUWeightsScale(const std::string& var_name, const framework::LoDTensor& var_tensor); - void CalculateScalesForGRUWeights(const paddle::framework::OpDesc* op); + void CalculateScalesForRNNWeights(const paddle::framework::OpDesc* op, + bool gru); void CalculateScalesForOpOutputs(const paddle::framework::OpDesc* op); void CalculateScalesForOpInputs(const paddle::framework::OpDesc* op); void PrepareArgument() const; @@ -91,6 +92,10 @@ class AnalysisPredictor::MkldnnQuantizer { const framework::LoDTensor& wx_tensor, const framework::LoDTensor& wh_tensor) const; + std::pair GetMaxChLSTMScalingFactor( + const framework::LoDTensor& wx_tensor, + const framework::LoDTensor& wh_tensor) const; + std::pair GetMaxScalingFactor( const framework::LoDTensor& var_tensor, bool is_unsigned) const; diff --git a/paddle/fluid/inference/api/mkldnn_quantizer_config.cc b/paddle/fluid/inference/api/mkldnn_quantizer_config.cc index 245bee57c98fc3a15db4a86f193e955b0009c40e..5a07cc7e240d5e760ea4464e5519fe2795fe767a 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer_config.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer_config.cc @@ -85,6 +85,25 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() { rules_["multi_gru"]["WeightH"] = ScaleAlgo::NONE; // separately rules_["multi_gru"]["Scale_weights"] = ScaleAlgo::NONE; rules_["multi_gru"]["Hidden"] = ScaleAlgo::KL; + + rules_["fusion_lstm"]["X"] = ScaleAlgo::KL; + rules_["fusion_lstm"]["H0"] = ScaleAlgo::NONE; + rules_["fusion_lstm"]["C0"] = ScaleAlgo::NONE; + rules_["fusion_lstm"]["Bias"] = ScaleAlgo::NONE; + rules_["fusion_lstm"]["WeightX"] = + ScaleAlgo::NONE; // Weights will be handled separately + rules_["fusion_lstm"]["WeightH"] = ScaleAlgo::NONE; + rules_["fusion_lstm"]["XX"] = ScaleAlgo::NONE; + rules_["fusion_lstm"]["Cell"] = ScaleAlgo::NONE; + rules_["fusion_lstm"]["BatchedInput"] = ScaleAlgo::NONE; + rules_["fusion_lstm"]["BatchedHidden"] = ScaleAlgo::NONE; + rules_["fusion_lstm"]["BatchedCell"] = ScaleAlgo::NONE; + rules_["fusion_lstm"]["BatchedGate"] = ScaleAlgo::NONE; + rules_["fusion_lstm"]["BatchedCellPreAct"] = ScaleAlgo::NONE; + rules_["fusion_lstm"]["ReorderedH0"] = ScaleAlgo::NONE; + rules_["fusion_lstm"]["ReorderedC0"] = ScaleAlgo::NONE; + rules_["fusion_lstm"]["CheckedCell"] = ScaleAlgo::NONE; + rules_["fusion_lstm"]["Hidden"] = ScaleAlgo::KL; } ScaleAlgo MkldnnQuantizerConfig::scale_algo( diff --git a/paddle/fluid/inference/api/mkldnn_quantizer_tester.cc b/paddle/fluid/inference/api/mkldnn_quantizer_tester.cc index 954a9806bec8c383db4a106091857da281dd8695..40e846dab649629ee16ce1f544ff943db7e72683 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer_tester.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer_tester.cc @@ -62,6 +62,12 @@ class MkldnnQuantizerTest : public testing::Test { return mkldnn_quantizer->GetMaxChGRUScalingFactor(wx_tensor, wh_tensor); } + std::pair GetMaxChLSTMScalingFactor( + const framework::LoDTensor& wx_tensor, + const framework::LoDTensor& wh_tensor) const { + return mkldnn_quantizer->GetMaxChLSTMScalingFactor(wx_tensor, wh_tensor); + } + protected: std::unique_ptr predictor; std::unique_ptr mkldnn_quantizer; @@ -297,4 +303,33 @@ TEST_F(MkldnnQuantizerTest, max_ch_gru_scaling_factor) { ASSERT_NEAR(lod_tensor.data()[i], scales[i], abs_error); } } + +TEST_F(MkldnnQuantizerTest, max_ch_lstm_scaling_factor) { + framework::LoDTensor wx_tensor, wh_tensor, lod_tensor; + + wx_tensor.Resize(framework::make_dim(wx.size(), wx[0].size())); + for (size_t i = 0; i < wx.size(); i++) + std::copy( + begin(wx[i]), end(wx[i]), + wx_tensor.mutable_data(platform::CPUPlace()) + i * wx[0].size()); + + wh_tensor.Resize(framework::make_dim(wh.size(), wh[0].size())); + for (size_t i = 0; i < wh.size(); i++) + std::copy( + begin(wh[i]), end(wh[i]), + wh_tensor.mutable_data(platform::CPUPlace()) + i * wh[0].size()); + + bool is_unsigned; + std::tie(is_unsigned, lod_tensor) = + GetMaxChLSTMScalingFactor(wx_tensor, wh_tensor); + + std::vector scales = {2.35381475, 1.10797026, 1.00151656, + 1.19001095, 1.09045166, 1.01785819}; + ASSERT_EQ(is_unsigned, false); + ASSERT_EQ(lod_tensor.numel(), static_cast(scales.size())); + for (int64_t i = 0; i < lod_tensor.numel(); i++) { + ASSERT_NEAR(lod_tensor.data()[i], scales[i], abs_error); + } +} + } // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index b2e3de63691c555b24eb6f1e1fb9ffcc35d400f9..704fbb2b95c8929fdb8c76072c804340b3c0fe08 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -192,7 +192,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { "seqpool_cvm_concat_fuse_pass", // // "embedding_fc_lstm_fuse_pass", // // TODO(wilber): fix correctness problem. - // "fc_lstm_fuse_pass", // + // "fc_lstm_fuse_pass", // "mul_lstm_fuse_pass", // "fc_gru_fuse_pass", // "mul_gru_fuse_pass", // diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc index 1adbd5cd9e7bc5f3fb5c3dc36868467e2f0b6e4b..a61a3de62f3978943ea9ee4e7c0785d11582dedd 100644 --- a/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc +++ b/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc @@ -78,12 +78,12 @@ class LSTMMKLDNNHandler auto bias_md = MKLDNNMemDesc({L, D, G, OC}, MKLDNNGetDataType(), MKLDNNMemoryFormat::ldgo); auto hidden_md = MKLDNNMemDesc({Ti, N, OC}, MKLDNNGetDataType(), - MKLDNNMemoryFormat::tnc); + MKLDNNMemoryFormat::any); auto h0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType(), - MKLDNNMemoryFormat::ldnc); + MKLDNNMemoryFormat::any); auto c0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType(), - MKLDNNMemoryFormat::ldnc); + MKLDNNMemoryFormat::any); // Create LSTM oneDNN primitive const auto direction = diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index febed599783417f713a17d924ed90a2483aa452a..329b96898e975cfdb8d5f5f7ec0090c358271927 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -43,6 +43,12 @@ function(download_quant_fp32_model install_dir data_file check_sum) endif() endfunction() +function(download_lstm_model install_dir data_file check_sum) + if (NOT EXISTS ${install_dir}/${data_file}) + inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/lstm ${data_file} ${check_sum}) + endif() +endfunction() + function(inference_quant_int8_image_classification_test target quant_model_dir dataset_path) py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant_int8_image_classification_comparison.py" ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} @@ -86,6 +92,20 @@ function(inference_quant2_int8_nlp_test target quant_model_dir fp32_model_dir da --ops_to_quantize ${ops_to_quantize}) endfunction() +function(inference_quant2_int8_lstm_model_test target fp32_model dataset_path) + py_test(${target} SRCS "${CMAKE_CURRENT_SOURCE_DIR}/quant2_int8_lstm_model.py" + ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} + OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI} + FLAGS_use_mkldnn=true + ARGS --fp32_model ${fp32_model} + --infer_data ${dataset_path} + --num_threads 4 + --mkldnn_cache_capacity 100 + --warmup_iter 100 + --warmup_batch_size 1 + --acc_diff_threshold 0.11) +endfunction() + function(download_quant_data install_dir data_file check_sum) if (NOT EXISTS ${install_dir}/${data_file}) inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8 ${data_file} ${check_sum}) @@ -260,6 +280,16 @@ if(LINUX AND WITH_MKLDNN) set(QUANT2_INT8_ERNIE_DOT_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8_dot_file") convert_model2dot_test(convert_model2dot_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QUANT2_INT8_ERNIE_DOT_SAVE_PATH} "Ernie_quant2_int8") + ### PTQ INT8 + + # PTQ int8 lstm model + set(LSTM_DATA_ARCHIVE "unittest_model_data/quant_lstm_input_data.tar.gz") + set(QUANT2_INT8_LSTM_SAVE_PATH "${QUANT_INSTALL_DIR}/lstm_quant2") + download_quant_data(${QUANT2_INT8_LSTM_SAVE_PATH} ${LSTM_DATA_ARCHIVE} add84c754e9b792fea1fbd728d134ab7) + set(QUANT2_FP32_LSTM_MODEL_ARCHIVE "lstm_fp32_model.tar.gz") + download_lstm_model(${QUANT2_INT8_LSTM_SAVE_PATH} ${QUANT2_FP32_LSTM_MODEL_ARCHIVE} eecd9f44d69a84acc1cf2235c4b8b743) + inference_quant2_int8_lstm_model_test(test_quant2_int8_lstm_mkldnn ${QUANT2_INT8_LSTM_SAVE_PATH}/lstm_fp32_model ${QUANT2_INT8_LSTM_SAVE_PATH}/quant_lstm_input_data) + endif() # Since the tests for Quant & INT8 comparison support only testing on Linux @@ -323,4 +353,5 @@ if(LINUX AND WITH_MKLDNN) set_tests_properties(test_quant2_int8_ernie_mkldnn PROPERTIES TIMEOUT 120) set_tests_properties(test_quant_int8_googlenet_mkldnn PROPERTIES TIMEOUT 120) set_tests_properties(test_quant2_int8_resnet50_mkldnn PROPERTIES TIMEOUT 120) + set_tests_properties(test_quant2_int8_lstm_mkldnn PROPERTIES TIMEOUT 120) endif() diff --git a/python/paddle/fluid/contrib/slim/tests/quant2_int8_lstm_model.py b/python/paddle/fluid/contrib/slim/tests/quant2_int8_lstm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0e33bd8ba1a4e085fc46ff132a20c1a4a06360bf --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/quant2_int8_lstm_model.py @@ -0,0 +1,219 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import numpy as np +import struct +import sys +import time +import unittest +from paddle import fluid +from paddle.fluid.core import AnalysisConfig, create_paddle_predictor + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--fp32_model', type=str, default='', help='A path to a FP32 model.') + parser.add_argument('--infer_data', type=str, default='', help='Data file.') + parser.add_argument( + '--num_threads', type=int, default=1, help='Number of threads.') + parser.add_argument( + '--warmup_iter', + type=int, + default=1, + help='Number of the first iterations to skip in performance statistics.') + parser.add_argument( + '--warmup_batch_size', + type=int, + default=1, + help='Number of batches to use in PTQ warmup. Default: 1.') + parser.add_argument( + '--acc_diff_threshold', + type=float, + default=0.01, + help='Accepted accuracy difference threshold.') + parser.add_argument( + '--mkldnn_cache_capacity', + type=int, + default=0, + help='Mkldnn cache capacity. The default value in Python API is 15, which can slow down int8 models. Default 0 means unlimited cache.' + ) + + test_args, args = parser.parse_known_args(namespace=unittest) + return test_args, sys.argv[:1] + args + + +class TestLstmModelPTQ(unittest.TestCase): + def get_warmup_tensor(self, data_path, place, warmup_batch_size): + data = [] + with open(data_path, 'rb') as in_f: + while True: + plen = in_f.read(4) + if plen is None or len(plen) != 4: + break + + alllen = struct.unpack('i', plen)[0] + label_len = alllen & 0xFFFF + seq_len = (alllen >> 16) & 0xFFFF + + label = in_f.read(4 * label_len) + label = np.frombuffer( + label, dtype=np.int32).reshape([len(label) // 4]) + feat = in_f.read(4 * seq_len * 8) + feat = np.frombuffer( + feat, dtype=np.float32).reshape([len(feat) // 4 // 8, 8]) + lod_feat = [feat.shape[0]] + minputs = fluid.create_lod_tensor(feat, [lod_feat], place) + + infer_data = fluid.core.PaddleTensor() + infer_data.lod = minputs.lod() + infer_data.data = fluid.core.PaddleBuf(np.array(minputs)) + infer_data.shape = minputs.shape() + infer_data.dtype = fluid.core.PaddleDType.FLOAT32 + infer_label = fluid.core.PaddleTensor() + infer_label.data = fluid.core.PaddleBuf(np.array(label)) + infer_label.shape = label.shape + infer_label.dtype = fluid.core.PaddleDType.INT32 + data.append([infer_data, infer_label]) + warmup_data = data[:warmup_batch_size] + inputs = data[warmup_batch_size:] + return warmup_data, inputs + + def set_config(self, + model_path, + num_threads, + mkldnn_cache_capacity, + warmup_batch_size, + warmup_data=None, + enable_int8=False): + config = AnalysisConfig(model_path) + config.disable_gpu() + config.switch_use_feed_fetch_ops(True) + config.switch_ir_optim(True) + config.set_cpu_math_library_num_threads(num_threads) + # This pass to work properly, must be added before fc_fuse_pass + config.pass_builder().insert_pass(5, "fc_lstm_fuse_pass") + config.enable_mkldnn() + config.set_mkldnn_cache_capacity(mkldnn_cache_capacity) + if enable_int8: + config.enable_quantizer() + config.quantizer_config().set_quant_data(warmup_data) + config.quantizer_config().set_quant_batch_size(warmup_batch_size) + return config + + def run_program(self, + model_path, + data_path, + num_threads, + mkldnn_cache_capacity, + warmup_iter, + warmup_batch_size, + enable_ptq_int8=False): + place = fluid.CPUPlace() + warmup_data, inputs = self.get_warmup_tensor(data_path, place, + warmup_batch_size) + warmup_data = [item[0] for item in warmup_data] + config = self.set_config(model_path, num_threads, mkldnn_cache_capacity, + warmup_batch_size, warmup_data, + enable_ptq_int8) + + predictor = create_paddle_predictor(config) + data = [item[0] for item in inputs] + label = np.array([item[1] for item in inputs]) + + all_hz_num = 0 + ok_hz_num = 0 + all_ctc_num = 0 + ok_ctc_num = 0 + + dataset_size = len(data) + start = time.time() + for i in range(dataset_size): + if i == warmup_iter: + start = time.time() + hz_out, ctc_out = predictor.run([data[i]]) + np_hz_out = np.array(hz_out.data.float_data()).reshape(-1) + np_ctc_out = np.array(ctc_out.data.int64_data()).reshape(-1) + + out_hz_label = np.argmax(np_hz_out) + + this_label = label[i] + this_label_data = np.array(this_label.data.int32_data()).reshape(-1) + if this_label.shape[0] == 1: + all_hz_num += 1 + best = this_label_data[0] + if out_hz_label == best: + ok_hz_num += 1 + + if this_label_data[0] <= 6350: + all_ctc_num += 1 + if np_ctc_out.shape[0] == 1 and np_ctc_out.all( + ) == this_label_data.all(): + ok_ctc_num += 1 + else: + all_ctc_num += 1 + if np_ctc_out.shape[0] == this_label.shape[ + 0] and np_ctc_out.all() == this_label_data.all(): + ok_ctc_num += 1 + + if all_ctc_num > 1000 or all_hz_num > 1000: + break + + end = time.time() + fps = (dataset_size - warmup_iter) / (end - start) + hx_acc = ok_hz_num / all_hz_num + ctc_acc = ok_ctc_num / all_ctc_num + return hx_acc, ctc_acc, fps + + def test_lstm_model(self): + if not fluid.core.is_compiled_with_mkldnn(): + return + + fp32_model = test_case_args.fp32_model + assert fp32_model, 'The FP32 model path cannot be empty. Please, use the --fp32_model option.' + infer_data = test_case_args.infer_data + assert infer_data, 'The dataset path cannot be empty. Please, use the --infer_data option.' + num_threads = test_case_args.num_threads + mkldnn_cache_capacity = test_case_args.mkldnn_cache_capacity + warmup_iter = test_case_args.warmup_iter + warmup_batch_size = test_case_args.warmup_batch_size + acc_diff_threshold = test_case_args.acc_diff_threshold + + (fp32_hx_acc, fp32_ctc_acc, fp32_fps) = self.run_program( + fp32_model, infer_data, num_threads, mkldnn_cache_capacity, + warmup_iter, warmup_batch_size, False) + + (int8_hx_acc, int8_ctc_acc, int8_fps) = self.run_program( + fp32_model, infer_data, num_threads, mkldnn_cache_capacity, + warmup_iter, warmup_batch_size, True) + + print("FP32: fps {0}, hx_acc {1}, ctc_acc {2}.".format( + fp32_fps, fp32_hx_acc, fp32_ctc_acc)) + + print("PTQ INT8: fps {0}, hx_acc {1}, ctc_acc {2}.".format( + int8_fps, int8_hx_acc, int8_ctc_acc)) + + sys.stdout.flush() + + hx_delta_value = fp32_hx_acc - int8_hx_acc + ctc_delta_value = fp32_ctc_acc - int8_ctc_acc + self.assertLess(hx_delta_value, acc_diff_threshold) + self.assertLess(ctc_delta_value, acc_diff_threshold) + + +if __name__ == "__main__": + global test_case_args + test_case_args, remaining_args = parse_args() + unittest.main(argv=remaining_args)