diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index dd0ffe8b9fd0d44b9bf87b4538bf6da5907f16a0..5334b08248992b1df71c717f6640a60a0dbc395c 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1619,6 +1619,26 @@ PDNode *patterns::Reshape::operator()() {
   return reshape_out;
 }
 
+PDNode *patterns::Slice::operator()() {
+  auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
+
+  auto slice_op = pattern->NewNode(slice_op_repr())->assert_is_op("slice");
+
+  auto slice_in = pattern->NewNode(slice_in_repr())
+                      ->AsInput()
+                      ->assert_is_op_input("slice", "Input");
+  auto slice_out = pattern->NewNode(slice_out_repr())
+                       ->AsOutput()
+                       ->assert_is_op_output("slice", "Out");
+
+  auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
+
+  prev_op->LinksTo({slice_in});
+  slice_op->LinksFrom({slice_in}).LinksTo({slice_out});
+  next_op->LinksFrom({slice_out});
+  return slice_out;
+}
+
 PDNode *patterns::Matmul::operator()() {
   auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
 
@@ -2315,7 +2335,7 @@ PDNode *patterns::QuantizePlacement::operator()(
       std::unordered_set<std::string>({"concat", "conv2d", "elementwise_add",
                                        "fc", "matmul", "pool2d", "prior_box",
                                        "reshape2", "transpose2", "fusion_gru",
-                                       "fusion_lstm", "multi_gru"});
+                                       "fusion_lstm", "multi_gru", "slice"});
   if (!quantize_enabled_op_types.empty()) {
     supported_op_types = quantize_enabled_op_types;
   }
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index d7bfdc57d1c7ed525d3df765b7c3f3d221e72a3f..fa8504d074a8842999b1d85b586b0e1d3c84001b 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -980,6 +980,20 @@ struct Reshape : public PatternBase {
   PATTERN_DECL_NODE(reshape_out);
   PATTERN_DECL_NODE(next_op);
 };
+// Slice op
+// Forward pass for slice.
+// slice_out is the result of the operator.
+struct Slice : public PatternBase {
+  Slice(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "slice") {}
+
+  PDNode* operator()();
+  PATTERN_DECL_NODE(prev_op);
+  PATTERN_DECL_NODE(slice_in);
+  PATTERN_DECL_NODE(slice_op);
+  PATTERN_DECL_NODE(slice_out);
+  PATTERN_DECL_NODE(next_op);
+};
 
 // Matmul op
 // Forward pass for matmul.
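Note: the new patterns::Slice matches the chain prev_op -> slice_in -> slice -> slice_out -> next_op, anchoring one slice op together with its producer and consumer so a pass can inspect both neighbors. A minimal, hypothetical sketch of how such a pattern is consumed (it mirrors the QuantizeSlice handler added below; `graph` is assumed to be the ir::Graph* under transformation, "example_scope" is an invented name):

    GraphPatternDetector gpd;
    patterns::Slice slice_pattern{gpd.mutable_pattern(), "example_scope"};
    slice_pattern();  // wires prev_op -> slice_in -> slice -> slice_out -> next_op

    auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                       Graph* g) {
      // Bind a matched IR node to a local variable by its pattern name.
      GET_IR_NODE_FROM_SUBGRAPH(slice_op, slice_op, slice_pattern);
      VLOG(4) << "matched slice op: " << slice_op->Name();
    };
    gpd(graph, handler);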
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
index 2bf8a3b64f0a78039d896c141aaad7e2f3e1fe2c..3df4a844705242f78e2c0f59ebd012101fb628b0 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
@@ -676,6 +676,57 @@ void CPUQuantizePass::QuantizeReshape(Graph* graph) const {
   PrettyLogDetail("--- quantized %d reshape ops", quantize_reshape_count);
 }
 
+void CPUQuantizePass::QuantizeSlice(Graph* graph) const {
+  GraphPatternDetector gpd;
+  auto pattern = gpd.mutable_pattern();
+  patterns::Slice slice_pattern{pattern, name_scope_};
+  slice_pattern();
+
+  int quantize_slice_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "Quantize slice op";
+    GET_IR_NODE_FROM_SUBGRAPH(slice_op, slice_op, slice_pattern);
+
+    // skip if it should not be quantized
+    if (!platform::HasOpINT8DataType(slice_op->Op())) {
+      LogQuantizationDisabled(slice_op);
+      return;
+    }
+    GET_IR_NODE_FROM_SUBGRAPH(prev_op, prev_op, slice_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, slice_pattern);
+
+    // skip if prev op and next op are not quantized
+    if (!IsOpDequantized(prev_op) && !IsOpQuantized(next_op)) {
+      return;
+    }
+    GET_IR_NODE_FROM_SUBGRAPH(slice_in, slice_in, slice_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(slice_out, slice_out, slice_pattern);
+
+    if (!AreScalesPresentForNodes({slice_out})) {
+      LogCannotQuantizeOp(slice_op);
+      return;
+    }
+
+    bool is_input_unsigned{false};
+    auto input_scale = GetScaleValueForNode(slice_out, &is_input_unsigned);
+    QuantizeInput(g, slice_op, slice_in, "Input", input_scale,
+                  is_input_unsigned);
+
+    bool is_output_unsigned{false};
+    auto output_scale = GetScaleValueForNode(slice_out, &is_output_unsigned);
+    DequantizeOutput(g, slice_op, slice_out, "Out", output_scale,
+                     is_output_unsigned);
+
+    ++quantize_slice_count;
+  };
+
+  gpd(graph, handler);
+  AddStatis(quantize_slice_count);
+
+  PrettyLogDetail("--- quantized %d slice ops", quantize_slice_count);
+}
+
 void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
   GraphPatternDetector gpd;
   auto pattern = gpd.mutable_pattern();
@@ -1024,6 +1075,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
   QuantizeFusionGru(graph);
   QuantizeMultiGru(graph);
   QuantizeFusionLSTM(graph);
+  QuantizeSlice(graph);
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h
index 18735633c0d69a231c2873ddb3ce73112ea5cdae..b3ee98263c0c0a6f53ab18ec87b59a421e4ccdf3 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h
@@ -61,6 +61,7 @@ class CPUQuantizePass : public FusePassBase {
   void QuantizeFusionGru(Graph* graph) const;
   void QuantizeMultiGru(Graph* graph) const;
   void QuantizeFusionLSTM(Graph* graph) const;
+  void QuantizeSlice(Graph* graph) const;
 
   void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name,
                      double scale_to_one, bool is_input_unsigned,
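Note that QuantizeInput and DequantizeOutput above both read the scale of slice_out: slice only copies a sub-tensor, so its input and output share one quantization scale. For intuition, the affine mapping realized by the inserted quantize/dequantize ops looks roughly like this (an illustrative sketch only, not the oneDNN implementation; the scale is assumed to map the calibrated |max| of the tensor onto 127):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Hypothetical reference helpers for the signed-int8 case.
    int8_t QuantizeToS8(float x, float scale) {
      float q = std::round(x * scale);  // scale ~= 127 / max_abs
      return static_cast<int8_t>(std::min(127.0f, std::max(-128.0f, q)));
    }

    float DequantizeFromS8(int8_t q, float scale) { return q / scale; }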
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
index b6a8de263aa2afb3934226e925961be1592f38dd..838912f659ff7c57683fba3920ede9f2d9829edd 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
@@ -55,6 +55,10 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
     op->SetInput("X", {inputs[0]});
     op->SetOutput("Out", {outputs[0]});
     op->SetAttr("mkldnn_data_type", mkldnn_data_type);
+  } else if (type == "slice") {
+    op->SetInput("Input", {inputs[0]});
+    op->SetOutput("Out", {outputs[0]});
+    op->SetAttr("mkldnn_data_type", mkldnn_data_type);
   } else if (type == "dropout") {
     op->SetInput("X", {inputs[0]});
     op->SetOutput("Out", {outputs[0]});
@@ -784,6 +788,113 @@ TEST(CpuQuantizePass, reshapeBetweenNonQuantizedOp) {
                   added_nodes_count, 2.0f * 127);
 }
 
+static const std::initializer_list<std::string> variable_names_slice = {
+    "a", "b", "c", "d"};
+
+// a->Dequantize->b
+// b->Slice->c
+// c->Dropout->d
+ProgramDesc BuildProgramDescSlice() {
+  ProgramDesc prog;
+  for (auto& v : variable_names_slice) {
+    prog.MutableBlock(0)->Var(v);
+  }
+  SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
+  SetOp(&prog, "slice", "Slice", {"b"}, {"c"}, true, "int8");
+  SetOp(&prog, "dropout", "Dropout", {"c"}, {"d"}, true, "float32");
+
+  return prog;
+}
+
+// a->Transpose->b
+// b->slice->c
+// c->Dropout->d
+ProgramDesc BuildProgramDescSliceBetweenNonQuantizedOp() {
+  ProgramDesc prog;
+  for (auto& v : variable_names_slice) {
+    prog.MutableBlock(0)->Var(v);
+  }
+
+  SetOp(&prog, "transpose2", "Transpose2", {"a"}, {"b"}, true, "float32");
+  SetOp(&prog, "slice", "Slice", {"b"}, {"c"}, true, "int8");
+  SetOp(&prog, "dropout", "Dropout", {"c"}, {"d"}, true, "float32");
+
+  return prog;
+}
+
+void MainTestSlice(const ProgramDesc& prog, int transpose_count,
+                   int slice_count, int quant_count, int dequant_count,
+                   int added_nodes_count, float scale) {
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+  int original_nodes_num, current_nodes_num;
+  PreparePass(&graph, prog, variable_names_slice, &original_nodes_num,
+              &current_nodes_num);
+
+  float quant_scale = 1.0f;
+  float dequant_scale = 1.0f;
+  int quantize_nodes_count = 0;
+  int dequantize_nodes_count = 0;
+  int transpose_nodes_count = 0;
+  int slice_nodes_count = 0;
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp()) {
+      auto* op = node->Op();
+      if (op->Type() == "transpose2") {
+        transpose_nodes_count++;
+      } else if (op->Type() == "slice") {
+        slice_nodes_count++;
+      } else if (op->Type() == "quantize") {
+        quantize_nodes_count++;
+        quant_scale = BOOST_GET_CONST(float, op->GetAttr("Scale"));
+        EXPECT_EQ(quant_scale, scale) << "Scale for node '" + op->Type() + "'.";
+      } else if (op->Type() == "dequantize") {
+        dequantize_nodes_count++;
+        auto op_name = op->GetAttrIfExists<std::string>("name");
+        VLOG(3) << op_name << "\n";
+        if (op_name != "Dequantize1") {
+          dequant_scale = BOOST_GET_CONST(float, op->GetAttr("Scale"));
+          EXPECT_EQ(dequant_scale, scale)
+              << "Scale for node '" + op->Type() + "'.";
+        }
+      }
+    }
+  }
+  EXPECT_EQ(transpose_nodes_count, transpose_count);
+  EXPECT_EQ(slice_nodes_count, slice_count);
+  EXPECT_EQ(quantize_nodes_count, quant_count);
+  EXPECT_EQ(dequantize_nodes_count, dequant_count);
+  EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
+}
+
+TEST(CpuQuantizePass, slice) {
+  // a->Dequantize->b
+  // b2->Quant->b3->slice->c1->Dequant->c2
+  // c2->Dropout->d
+  int slice_count = 1;
+  int transpose_count = 0;
+  int quant_count = 1;
+  int dequant_count = 2;
+  // 1 Quant + 1 IN + 1 DeQuant + 1 OUT
+  int added_nodes_count = 4;
+  MainTestSlice(BuildProgramDescSlice(), transpose_count, slice_count,
+                quant_count, dequant_count, added_nodes_count, 2.0f * 127);
+}
+
+TEST(CpuQuantizePass, sliceBetweenNonQuantizedOp) {
+  // a->Transpose2->b
+  // b->slice->c
+  // c->Dropout->d
+  int slice_count = 1;
+  int transpose_count = 1;
+  int quant_count = 0;
+  int dequant_count = 0;
+  // 0 Quant + 0 IN + 0 DeQuant + 0 OUT
+  int added_nodes_count = 0;
+  MainTestSlice(BuildProgramDescSliceBetweenNonQuantizedOp(), transpose_count,
+                slice_count, quant_count, dequant_count, added_nodes_count,
+                2.0f * 127);
+}
+
 static const std::initializer_list<std::string> variable_names_matmul = {
     "a", "b", "c", "d", "e", "f"};
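Spelling out the node arithmetic of TEST(CpuQuantizePass, slice): quantizing one slice op inserts a quantize op with a new output variable in front of it and a dequantize op with a new output variable behind it, which is the "1 Quant + 1 IN + 1 DeQuant + 1 OUT" comment above made explicit:

    // Node accounting for the quantized-slice test.
    const int added_ops = 2;   // one quantize op + one dequantize op node
    const int added_vars = 2;  // one new output variable per inserted op
    const int added_nodes_count = added_ops + added_vars;  // == 4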
diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc
index 654b58a2ded341a5e422f9ff9a3abb43a646b808..aa29b779e471b3c75acaaafa0934512a6f701c61 100644
--- a/paddle/fluid/inference/api/mkldnn_quantizer.cc
+++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc
@@ -134,6 +134,16 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
         scales_[var_name] = scales_[input_var_name];
       }
       compute_scale = false;
+    } else if (op->Type() == "slice") {
+      auto input_var_name = op->Input("Input")[0];
+      PADDLE_ENFORCE_NE(scales_.find(input_var_name), scales_.end(),
+                        platform::errors::PreconditionNotMet(
+                            "Input scales must be calculated before the "
+                            "output scales to infer if output is unsigned."));
+      if (scales_.find(input_var_name) != scales_.end()) {
+        scales_[var_name] = scales_[input_var_name];
+      }
+      compute_scale = false;
     } else if (op->Type() == "concat") {
       // output of ops with unsigned input must be unsigned
       is_unsigned = true;
diff --git a/paddle/fluid/inference/api/mkldnn_quantizer_config.cc b/paddle/fluid/inference/api/mkldnn_quantizer_config.cc
index 5a07cc7e240d5e760ea4464e5519fe2795fe767a..6642a2c030b2662ca5ce32969423a41780518674 100644
--- a/paddle/fluid/inference/api/mkldnn_quantizer_config.cc
+++ b/paddle/fluid/inference/api/mkldnn_quantizer_config.cc
@@ -42,6 +42,9 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
   rules_["transpose2"]["X"] = ScaleAlgo::KL;
   rules_["transpose2"]["Out"] = ScaleAlgo::NONE;
 
+  rules_["slice"]["Input"] = ScaleAlgo::KL;
+  rules_["slice"]["Out"] = ScaleAlgo::NONE;
+
   rules_["fc"]["Input"] = ScaleAlgo::KL;
   rules_["fc"]["W"] = ScaleAlgo::MAX_CH_T;
   rules_["fc"]["Bias"] = ScaleAlgo::NONE;
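The quantizer change treats slice the same way as transpose2/reshape2: a scale-preserving op whose output inherits the input's calibrated scale instead of receiving its own KL/MAX estimate (hence ScaleAlgo::NONE for its Out). A simplified sketch of that reuse; the real scales_ map stores std::pair<bool, LoDTensor>, and a plain float stands in for the tensor here:

    #include <map>
    #include <string>
    #include <utility>

    // Simplified stand-in for MkldnnQuantizer::scales_.
    std::map<std::string, std::pair<bool, float>> scales;

    void ReuseScaleForSlice(const std::string& in, const std::string& out) {
      // Scale-preserving op: the output variable inherits the input's
      // (is_unsigned, scale) entry; no new calibration is needed.
      scales[out] = scales.at(in);
    }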
"${INFERENCE_DEMO_INSTALL_DIR}/Ernie") download_model_and_data(${ERNIE_INSTALL_DIR} "Ernie_model.tar.gz" aa59192dd41ed377f9f168e3a1309fa6 "Ernie_data.txt.tar.gz" 5396e63548edad7ca561e7e26a9476d1) download_result(${ERNIE_INSTALL_DIR} "Ernie_result.txt.tar.gz" 73beea65abda2edb61c1662cd3180c62) inference_analysis_api_test(test_analyzer_ernie ${ERNIE_INSTALL_DIR} analyzer_ernie_tester.cc) +inference_analysis_api_int8_test(test_analyzer_ernie_int8 ${ERNIE_INSTALL_DIR} analyzer_ernie_int8_tester.cc) -#Ernie large +# Ernie large set(ERNIE_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_Large") download_model_and_data(${ERNIE_INSTALL_DIR} "Ernie_large_model.tar.gz" af7715245ed32cc77374625d4c80f7ef "Ernie_large_data.txt.tar.gz" edb2113eec93783cad56ed76d47ba57f) download_result(${ERNIE_INSTALL_DIR} "Ernie_large_result.txt.tar.gz" 1facda98eef1085dc9d435ebf3f23a73) @@ -426,7 +438,7 @@ if(WITH_MKLDNN) # TODO(grygielski) Enable after MKL-DNN 1.0 merge set(INT8_VGG16_MODEL_DIR "${INT8_DATA_DIR}/vgg16") download_int8_data_without_verify(${INT8_VGG16_MODEL_DIR} "VGG16_int8_model.tar.gz" ) -# inference_analysis_api_int8_test_run(test_analyzer_int8_vgg16 ${INT8_IMG_CLASS_TEST_APP} ${INT8_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH}) +# inference_analysis_api_int8_test_run(test_analyzer_int8_vgg16 ${INT8_IMG_CLASS_TEST_APP} ${INT8_VGG16_MODEL_DIR} ${IMAGENET_DATA_PATH}) # vgg19 int8 # TODO(grygielski) Enable after MKL-DNN 1.0 merge @@ -730,6 +742,7 @@ set_tests_properties(test_analyzer_mobilenet_transpose PROPERTIES TIMEOUT 120) set_tests_properties(test_analyzer_resnet50 PROPERTIES TIMEOUT 120) set_tests_properties(test_analyzer_ner PROPERTIES TIMEOUT 120) set_tests_properties(test_analyzer_ernie PROPERTIES TIMEOUT 120) +set_tests_properties(test_analyzer_ernie_int8 PROPERTIES TIMEOUT 120) set_tests_properties(test_analyzer_googlenet PROPERTIES TIMEOUT 120) set_tests_properties(test_analyzer_small_dam PROPERTIES TIMEOUT 120) set_tests_properties(test_analyzer_transformer PROPERTIES TIMEOUT 120) diff --git a/paddle/fluid/inference/tests/api/analyzer_ernie_int8_tester.cc b/paddle/fluid/inference/tests/api/analyzer_ernie_int8_tester.cc new file mode 100644 index 0000000000000000000000000000000000000000..b85726647b548c77f1ef5e50a592584b8af1ee9e --- /dev/null +++ b/paddle/fluid/inference/tests/api/analyzer_ernie_int8_tester.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/inference/tests/api/analyzer_ernie_tester.h" + +namespace paddle { +namespace inference { + +using paddle::PaddleTensor; + +#ifdef PADDLE_WITH_MKLDNN +void SetInt8Config(AnalysisConfig *cfg, + std::vector data) { + cfg->SetModel(FLAGS_infer_model); + cfg->EnableMKLDNN(); + cfg->EnableMkldnnQuantizer(); + auto warmup_data = std::make_shared>(data); + cfg->mkldnn_quantizer_config()->SetWarmupData(warmup_data); + cfg->mkldnn_quantizer_config()->SetWarmupBatchSize(FLAGS_batch_size); + cfg->SwitchSpecifyInputNames(); + cfg->SwitchIrOptim(); + cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads); +} + +// Compare result of NativeConfig and AnalysisConfig +void compare_int8(bool use_mkldnn = false) { + std::vector> inputs; + LoadInputData(&inputs); + + AnalysisConfig cfg; + SetInt8Config(&cfg, inputs[0]); + + CompareNativeAndAnalysis( + reinterpret_cast(&cfg), inputs); +} + +TEST(Analyzer_ernie, compare_int8_mkldnn) { + compare_int8(true /* use_mkldnn */); +} +#endif + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tests/api/analyzer_ernie_tester.cc b/paddle/fluid/inference/tests/api/analyzer_ernie_tester.cc index 0c2a140023e293895bc8476bcfd0f2e98efb4eab..d6ff3e422368bd9427e4cd3412429baf571c3303 100644 --- a/paddle/fluid/inference/tests/api/analyzer_ernie_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_ernie_tester.cc @@ -12,142 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/inference/tests/api/tester_helper.h" +#include "paddle/fluid/inference/tests/api/analyzer_ernie_tester.h" namespace paddle { namespace inference { using paddle::PaddleTensor; -template -void GetValueFromStream(std::stringstream *ss, T *t) { - (*ss) >> (*t); -} - -template <> -void GetValueFromStream(std::stringstream *ss, std::string *t) { - *t = ss->str(); -} - -// Split string to vector -template -void Split(const std::string &line, char sep, std::vector *v) { - std::stringstream ss; - T t; - for (auto c : line) { - if (c != sep) { - ss << c; - } else { - GetValueFromStream(&ss, &t); - v->push_back(std::move(t)); - ss.str({}); - ss.clear(); - } - } - - if (!ss.str().empty()) { - GetValueFromStream(&ss, &t); - v->push_back(std::move(t)); - ss.str({}); - ss.clear(); - } -} - -// Parse tensor from string -template -bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) { - std::vector data; - Split(field, ':', &data); - if (data.size() < 2) return false; - - std::string shape_str = data[0]; - - std::vector shape; - Split(shape_str, ' ', &shape); - - std::string mat_str = data[1]; - - std::vector mat; - Split(mat_str, ' ', &mat); - - tensor->shape = shape; - auto size = - std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) * - sizeof(T); - tensor->data.Resize(size); - std::copy(mat.begin(), mat.end(), static_cast(tensor->data.data())); - tensor->dtype = GetPaddleDType(); - - return true; -} - -// Parse input tensors from string -bool ParseLine(const std::string &line, - std::vector *tensors) { - std::vector fields; - Split(line, ';', &fields); - - tensors->clear(); - tensors->reserve(4); - - int i = 0; - auto input_name = FLAGS_ernie_large ? 
"eval_placeholder_" : "placeholder_"; - for (; i < 3; i++) { - paddle::PaddleTensor temp; - ParseTensor(fields[i], &temp); - temp.name = input_name + std::to_string(i); - tensors->push_back(temp); - } - - // input_mask - paddle::PaddleTensor input_mask; - ParseTensor(fields[i], &input_mask); - input_mask.name = input_name + std::to_string(i); - tensors->push_back(input_mask); - - return true; -} - -bool LoadInputData(std::vector> *inputs) { - if (FLAGS_infer_data.empty()) { - LOG(ERROR) << "please set input data path"; - return false; - } - - std::ifstream fin(FLAGS_infer_data); - std::string line; - int sample = 0; - - // The unit-test dataset only have 10 samples, each sample have 5 feeds. - while (std::getline(fin, line)) { - std::vector feed_data; - ParseLine(line, &feed_data); - inputs->push_back(std::move(feed_data)); - sample++; - if (!FLAGS_test_all_data && sample == FLAGS_batch_size) break; - } - LOG(INFO) << "number of samples: " << sample; - return true; -} - -void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false, - bool use_gpu = false) { - cfg->SetModel(FLAGS_infer_model); - if (use_mkldnn) { - cfg->EnableMKLDNN(); - } - if (use_gpu) { - cfg->EnableUseGpu(100, 0); - } else { - cfg->DisableGpu(); - } - cfg->SwitchSpecifyInputNames(); - cfg->SwitchIrOptim(); - cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads); -} - void profile(bool use_mkldnn = false, bool use_gpu = false) { AnalysisConfig config; + SetConfig(&config, use_mkldnn, use_gpu); std::vector> outputs; @@ -189,11 +63,12 @@ TEST(Analyzer_Ernie, fuse_statis) { // Compare result of NativeConfig and AnalysisConfig void compare(bool use_mkldnn = false) { + std::vector> inputs; + LoadInputData(&inputs); + AnalysisConfig cfg; SetConfig(&cfg, use_mkldnn, false); - std::vector> inputs; - LoadInputData(&inputs); CompareNativeAndAnalysis( reinterpret_cast(&cfg), inputs); } diff --git a/paddle/fluid/inference/tests/api/analyzer_ernie_tester.h b/paddle/fluid/inference/tests/api/analyzer_ernie_tester.h new file mode 100644 index 0000000000000000000000000000000000000000..dd3faac7592104ba47c7f7db54c8c0114c8cb1f1 --- /dev/null +++ b/paddle/fluid/inference/tests/api/analyzer_ernie_tester.h @@ -0,0 +1,152 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/paddle/fluid/inference/tests/api/analyzer_ernie_tester.h b/paddle/fluid/inference/tests/api/analyzer_ernie_tester.h
new file mode 100644
index 0000000000000000000000000000000000000000..dd3faac7592104ba47c7f7db54c8c0114c8cb1f1
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/analyzer_ernie_tester.h
@@ -0,0 +1,152 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/inference/tests/api/tester_helper.h"
+
+namespace paddle {
+namespace inference {
+
+using paddle::PaddleTensor;
+
+template <typename T>
+void GetValueFromStream(std::stringstream *ss, T *t) {
+  (*ss) >> (*t);
+}
+
+template <>
+void GetValueFromStream<std::string>(std::stringstream *ss, std::string *t) {
+  *t = ss->str();
+}
+
+// Split string to vector
+template <typename T>
+void Split(const std::string &line, char sep, std::vector<T> *v) {
+  std::stringstream ss;
+  T t;
+  for (auto c : line) {
+    if (c != sep) {
+      ss << c;
+    } else {
+      GetValueFromStream<T>(&ss, &t);
+      v->push_back(std::move(t));
+      ss.str({});
+      ss.clear();
+    }
+  }
+
+  if (!ss.str().empty()) {
+    GetValueFromStream<T>(&ss, &t);
+    v->push_back(std::move(t));
+    ss.str({});
+    ss.clear();
+  }
+}
+
+// Parse tensor from string
+template <typename T>
+bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) {
+  std::vector<std::string> data;
+  Split(field, ':', &data);
+  if (data.size() < 2) return false;
+
+  std::string shape_str = data[0];
+
+  std::vector<int> shape;
+  Split(shape_str, ' ', &shape);
+
+  std::string mat_str = data[1];
+
+  std::vector<T> mat;
+  Split(mat_str, ' ', &mat);
+
+  tensor->shape = shape;
+  auto size =
+      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>()) *
+      sizeof(T);
+  tensor->data.Resize(size);
+  std::copy(mat.begin(), mat.end(), static_cast<T *>(tensor->data.data()));
+  tensor->dtype = GetPaddleDType<T>();
+
+  return true;
+}
+
+// Parse input tensors from string
+bool ParseLine(const std::string &line,
+               std::vector<paddle::PaddleTensor> *tensors) {
+  std::vector<std::string> fields;
+  Split(line, ';', &fields);
+
+  tensors->clear();
+  tensors->reserve(4);
+
+  int i = 0;
+  auto input_name = FLAGS_ernie_large ? "eval_placeholder_" : "placeholder_";
+  for (; i < 3; i++) {
+    paddle::PaddleTensor temp;
+    ParseTensor<int64_t>(fields[i], &temp);
+    temp.name = input_name + std::to_string(i);
+    tensors->push_back(temp);
+  }
+
+  // input_mask
+  paddle::PaddleTensor input_mask;
+  ParseTensor<float>(fields[i], &input_mask);
+  input_mask.name = input_name + std::to_string(i);
+  tensors->push_back(input_mask);
+
+  return true;
+}
+
+bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs) {
+  if (FLAGS_infer_data.empty()) {
+    LOG(ERROR) << "please set input data path";
+    return false;
+  }
+
+  std::ifstream fin(FLAGS_infer_data);
+  std::string line;
+  int sample = 0;
+
+  // The unit-test dataset only has 10 samples; each sample has 5 feeds.
+  while (std::getline(fin, line)) {
+    std::vector<paddle::PaddleTensor> feed_data;
+    ParseLine(line, &feed_data);
+    inputs->push_back(std::move(feed_data));
+    sample++;
+    if (!FLAGS_test_all_data && sample == FLAGS_batch_size) break;
+  }
+  LOG(INFO) << "number of samples: " << sample;
+  return true;
+}
+
+void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false,
+               bool use_gpu = false) {
+  cfg->SetModel(FLAGS_infer_model);
+  if (use_mkldnn) {
+    cfg->EnableMKLDNN();
+  }
+  if (use_gpu) {
+    cfg->EnableUseGpu(100, 0);
+  } else {
+    cfg->DisableGpu();
+  }
+  cfg->SwitchSpecifyInputNames();
+  cfg->SwitchIrOptim();
+  cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
+}
+
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc
index d9bd843a9d0cf07ea23074c7605849cc147734ef..e5f70fa10e3751b2a6af62af3a52ef0d61b8580c 100644
--- a/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/slice_mkldnn_op.cc
@@ -227,6 +227,8 @@ class SliceGradMKLDNNKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OP_KERNEL(slice, MKLDNN, paddle::platform::CPUPlace,
                    ops::SliceMKLDNNKernel<float>,
+                   ops::SliceMKLDNNKernel<int8_t>,
+                   ops::SliceMKLDNNKernel<uint8_t>,
                    ops::SliceMKLDNNKernel<paddle::platform::bfloat16>);
 
 namespace ops = paddle::operators;
diff --git a/paddle/fluid/operators/slice_op.cc b/paddle/fluid/operators/slice_op.cc
index a5513ba648776c1906d2a67bd51890ca51dc01fd..4965e5e156c342be084938d6caf31820ca3428bc 100644
--- a/paddle/fluid/operators/slice_op.cc
+++ b/paddle/fluid/operators/slice_op.cc
@@ -244,7 +244,7 @@ class SliceOpMaker : public framework::OpProtoAndCheckerMaker {
         "mkldnn_data_type",
         "(string, default \"float32\"). Data type of mkldnn kernel")
         .SetDefault("float32")
-        .InEnum({"float32", "bfloat16"})
+        .InEnum({"float32", "int8", "bfloat16"})
         .AsExtra();
     AddComment(R"DOC(
 Slice Operator.
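With "int8" now accepted by the mkldnn_data_type attribute and int8_t/uint8_t kernels registered, a slice op can opt into quantized execution; CPUQuantizePass detects this via platform::HasOpINT8DataType. A hedged sketch of marking an op this way, mirroring SetOp in the pass tester above ('block' is an assumed framework::BlockDesc* obtained from some ProgramDesc):

    auto* op = block->AppendOp();
    op->SetType("slice");
    op->SetInput("Input", {"x"});
    op->SetOutput("Out", {"out"});
    op->SetAttr("use_mkldnn", true);
    op->SetAttr("mkldnn_data_type", std::string("int8"));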
diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
index 4c9c4058318a9785ade7210de752f65ae2cbb378..0627bf2123adbdee88aede4ef4c66f4beb8f9dc1 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
@@ -62,7 +62,9 @@ class Quant2Int8MkldnnPass(object):
         self._ops_to_quantize = _ops_to_quantize
         self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
             [-1])
-        self._scale_immutable_ops = ['transpose2', 'reshape2', 'pool2d']
+        self._scale_immutable_ops = [
+            'transpose2', 'reshape2', 'pool2d', 'slice'
+        ]
         self._scale_ops = ['scale']
         self._conv_ops = ['conv2d', 'depthwise_conv2d']
         self._pool_ops = ['pool2d']
@@ -241,7 +243,10 @@ class Quant2Int8MkldnnPass(object):
         waiting_for_scale = set()
         for op in graph.all_op_nodes():
             if op.name() in self._scale_immutable_ops:
-                input_name = op.input("X")[0]
+                if op.name() == 'slice':
+                    input_name = op.input("Input")[0]
+                else:
+                    input_name = op.input("X")[0]
                 output_name = op.output("Out")[0]
                 tensor_names = [input_name, output_name]
 
diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
index 03503111fca9a6e259aefe8657ac07a69e6bcaf1..94d7a2ed1534880445bce372fe29cf67d21a7dd7 100644
--- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
+++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
@@ -253,7 +253,7 @@ if(LINUX AND WITH_MKLDNN)
       set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz")
       set(FP32_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_float")
       download_quant_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE} 114f38804a3ef8c45e7259e68bbd838b)
-      set(QUANT2_ERNIE_OPS_TO_QUANTIZE "fc,reshape2,transpose2,matmul,elementwise_add")
+      set(QUANT2_ERNIE_OPS_TO_QUANTIZE "fc,reshape2,transpose2,matmul,elementwise_add,slice")
       inference_quant2_int8_nlp_test(test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH} ${QUANT2_ERNIE_OPS_TO_QUANTIZE})
 
       # Quant2 GRU