From 9df2d8b5baa9edd1977d87d4bf519435a35c195e Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Wed, 5 Sep 2018 13:29:20 +0800 Subject: [PATCH] test/add text-classification test (#13081) --- .../fluid/framework/ir/fc_lstm_fuse_pass.cc | 32 +++-- paddle/fluid/framework/ir/fc_lstm_fuse_pass.h | 2 + .../framework/ir/graph_pattern_detector.cc | 1 - .../framework/ir/graph_pattern_detector.h | 3 + .../fluid/inference/analysis/CMakeLists.txt | 18 ++- paddle/fluid/inference/analysis/analyzer.cc | 37 +++--- paddle/fluid/inference/analysis/analyzer.h | 27 +++-- .../inference/analysis/analyzer_tester.cc | 9 +- paddle/fluid/inference/analysis/flags.h | 22 ++++ .../inference/analysis/fluid_to_ir_pass.h | 9 +- .../analysis/test_text_classification.cc | 109 ++++++++++++++++++ paddle/fluid/inference/api/CMakeLists.txt | 14 ++- .../fluid/inference/api/analysis_predictor.cc | 28 +++-- .../fluid/inference/api/analysis_predictor.h | 6 +- paddle/fluid/inference/api/api_impl.cc | 3 +- .../inference/api/paddle_inference_api.h | 15 +++ paddle/fluid/operators/lookup_table_op.h | 2 +- 17 files changed, 276 insertions(+), 61 deletions(-) create mode 100644 paddle/fluid/inference/analysis/flags.h create mode 100644 paddle/fluid/inference/analysis/test_text_classification.cc diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index 55153ecc3ed..0d69dfa79aa 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -86,15 +86,24 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, } op_desc.SetInput("Bias", {new_bias_var}); } - #undef GET_NODE + // Create temp variables. + scope->Var(name_scope + "/BatchedInput.new") + ->GetMutable(); + scope->Var(name_scope + "/BatchCellPreAct.new") + ->GetMutable(); + scope->Var(name_scope + "/BatchedGate.new") + ->GetMutable(); + op_desc.SetInput("H0", {}); op_desc.SetInput("C0", {}); op_desc.SetOutput("Hidden", {hidden_n->Name()}); op_desc.SetOutput("Cell", {cell_n->Name()}); op_desc.SetOutput("XX", {xx_n->Name()}); - op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"}); + op_desc.SetOutput("BatchedGate", {name_scope + "/BatchedGate.new"}); + op_desc.SetOutput("BatchCellPreAct", {name_scope + "/BatchCellPreAct.new"}); + op_desc.SetOutput("BatchedInput", {name_scope + "/BatchedInput.new"}); op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse")); op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes")); // TODO(TJ): get from attr @@ -130,8 +139,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, int fusion_count{0}; - auto fc_no_bias_handler = [&]( - const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { #define GET_NODE(name__) \ std::string name__##key = name_scope + "/" + #name__; \ auto* name__##n = pattern->RetrieveNode(name__##key); \ @@ -152,21 +161,24 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, if (with_fc_bias) { GET_NODE(fc_bias); + GET_NODE(elementwise_add); lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, fc_bias); + // Remove unneeded nodes. + std::unordered_set marked_nodes( + {mul_n, lstm_n, elementwise_add_n}); + GraphSafeRemoveNodes(graph, marked_nodes); } else { lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, -1); + // Remove unneeded nodes. + std::unordered_set marked_nodes({mul_n, lstm_n}); + GraphSafeRemoveNodes(graph, marked_nodes); } #undef GET_NODE - // Remove unneeded nodes. - std::unordered_set marked_nodes({mul_n, lstm_n}); - - GraphSafeRemoveNodes(graph, marked_nodes); - ++fusion_count; }; - gpd(graph, fc_no_bias_handler); + gpd(graph, handler); return fusion_count; } diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h index 5a6687872eb..3ee32c63a46 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.h @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#pragma once + #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index a4da69a0a2e..434bee4ccee 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -73,7 +73,6 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) { void GraphPatternDetector::operator()(Graph* graph, GraphPatternDetector::handle_t handler) { if (!MarkPDNodesInGraph(*graph)) { - LOG(INFO) << "Mark failed"; return; } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 9d67c4a6997..eacea1750f6 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -19,6 +19,9 @@ #endif #include +#include +#include +#include #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/inference/analysis/dot.h" diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index dadc8a53706..f2e18a461fd 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -58,7 +58,7 @@ endif() inference_analysis_test(test_analyzer SRCS analyzer_tester.cc EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor ARGS --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model - --infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt) + --infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt) inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc) inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc) @@ -74,7 +74,7 @@ inference_analysis_test(test_model_store_pass SRCS model_store_pass_tester.cc) set(CHINESE_NER_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/chinese_ner_model.tar.gz") set(CHINESE_NER_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/chinese_ner-data.txt.tar.gz") set(CHINESE_NER_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/chinese_ner" CACHE PATH "Chinese ner model and data root." FORCE) -if (NOT EXISTS ${CHINESE_NER_INSTALL_DIR} AND WITH_TESTING) +if (NOT EXISTS ${CHINESE_NER_INSTALL_DIR} AND WITH_TESTING AND WITH_INFERENCE) inference_download_and_uncompress(${CHINESE_NER_INSTALL_DIR} ${CHINESE_NER_MODEL_URL} "chinese_ner_model.tar.gz") inference_download_and_uncompress(${CHINESE_NER_INSTALL_DIR} ${CHINESE_NER_DATA_URL} "chinese_ner-data.txt.tar.gz") endif() @@ -87,7 +87,7 @@ inference_analysis_test(test_analyzer_ner SRCS analyzer_ner_tester.cc set(LAC_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/lac_model.tar.gz") set(LAC_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/lac_data.txt.tar.gz") set(LAC_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/lac" CACHE PATH "LAC model and data root." FORCE) -if (NOT EXISTS ${LAC_INSTALL_DIR} AND WITH_TESTING) +if (NOT EXISTS ${LAC_INSTALL_DIR} AND WITH_TESTING AND WITH_INFERENCE) inference_download_and_uncompress(${LAC_INSTALL_DIR} ${LAC_MODEL_URL} "lac_model.tar.gz") inference_download_and_uncompress(${LAC_INSTALL_DIR} ${LAC_DATA_URL} "lac_data.txt.tar.gz") endif() @@ -96,3 +96,15 @@ inference_analysis_test(test_analyzer_lac SRCS analyzer_lac_tester.cc EXTRA_DEPS paddle_inference_api paddle_fluid_api ARGS --infer_model=${LAC_INSTALL_DIR}/model --infer_data=${LAC_INSTALL_DIR}/data.txt) + + +set(TEXT_CLASSIFICATION_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/text-classification-Senta.tar.gz") +set(TEXT_CLASSIFICATION_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/text_classification" CACHE PATH "Text Classification model and data root." FORCE) + +if (NOT EXISTS ${TEXT_CLASSIFICATION_INSTALL_DIR} AND WITH_TESTING AND WITH_INFERENCE) + inference_download_and_uncompress(${TEXT_CLASSIFICATION_INSTALL_DIR} ${TEXT_CLASSIFICATION_MODEL_URL} "text-classification-Senta.tar.gz") +endif() + +inference_analysis_test(test_text_classification SRCS test_text_classification.cc + EXTRA_DEPS paddle_inference_api paddle_fluid_api analysis_predictor + ARGS --infer_model=${TEXT_CLASSIFICATION_INSTALL_DIR}/text-classification-Senta) diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 192ac2daa6a..ca834406451 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/inference/analysis/analyzer.h" #include +#include #include "paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.h" #include "paddle/fluid/inference/analysis/dfg_graphviz_draw_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" @@ -41,20 +42,16 @@ class DfgPassManagerImpl final : public DfgPassManager { public: DfgPassManagerImpl() { // TODO(Superjomn) set the key with pass reprs. - LOG(INFO) - << "-----------------------------------------------------------------"; - if (FLAGS_IA_enable_ir) { - AddPass("fluid-to-ir-pass", new FluidToIrPass); - } else { + if (!FLAGS_IA_enable_ir) { AddPass("fluid-to-data-flow-graph", new FluidToDataFlowGraphPass); + } else { + AddPass("fluid-to-ir-pass", new FluidToIrPass); } TryAddTensorRtPass(); AddPass("data-flow-graph-to-fluid", new DataFlowGraphToFluidPass); if (!FLAGS_IA_output_storage_path.empty()) { AddPass("model-store-pass", new ModelStorePass); } - LOG(INFO) - << "-----------------------------------------------------------------"; } std::string repr() const override { return "dfg-pass-manager"; } @@ -101,19 +98,16 @@ class DfgPassManagerImpl final : public DfgPassManager { Analyzer::Analyzer() { Register("manager1", new DfgPassManagerImpl); } void Analyzer::Run(Argument* argument) { + std::vector passes; + for (auto& pass : all_ir_passes_) { + if (!disabled_ir_passes_.count(pass)) { + passes.push_back(pass); + passes.push_back("graph_viz_pass"); // add graphviz for debug. + } + } + passes.push_back("graph_viz_pass"); // Ugly support fluid-to-ir-pass - argument->Set(kFluidToIrPassesAttr, - new std::vector({ - // Manual update the passes here. - "graph_viz_pass", // - "infer_clean_graph_pass", "graph_viz_pass", // - "attention_lstm_fuse_pass", "graph_viz_pass", // - "fc_lstm_fuse_pass", "graph_viz_pass", // - "mul_lstm_fuse_pass", "graph_viz_pass", // - "seq_concat_fc_fuse_pass", "graph_viz_pass", // - "fc_fuse_pass", "graph_viz_pass" // - - })); + argument->Set(kFluidToIrPassesAttr, new std::vector(passes)); for (auto& x : data_) { PADDLE_ENFORCE(x->Initialize(argument)); @@ -122,6 +116,11 @@ void Analyzer::Run(Argument* argument) { } } +Analyzer& Analyzer::DisableIrPasses(const std::vector& passes) { + disabled_ir_passes_.insert(passes.begin(), passes.end()); + return *this; +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/analyzer.h b/paddle/fluid/inference/analysis/analyzer.h index 2e107c82dd5..3fdd2b9ec75 100644 --- a/paddle/fluid/inference/analysis/analyzer.h +++ b/paddle/fluid/inference/analysis/analyzer.h @@ -36,16 +36,10 @@ limitations under the License. */ */ #include +#include "paddle/fluid/inference/analysis/flags.h" #include "paddle/fluid/inference/analysis/pass.h" #include "paddle/fluid/inference/analysis/pass_manager.h" -// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this -// flag if not available. -DECLARE_bool(IA_enable_tensorrt_subgraph_engine); -DECLARE_string(IA_graphviz_log_root); -DECLARE_string(IA_output_storage_path); -DECLARE_bool(IA_enable_ir); - namespace paddle { namespace inference { namespace analysis { @@ -57,7 +51,26 @@ class Analyzer : public OrderedRegistry { void Run(Argument* argument); + Analyzer& DisableIrPasses(const std::vector& passes); + DISABLE_COPY_AND_ASSIGN(Analyzer); + + private: + // All avaiable IR passes. + // The bigger fuse comes first, so that the small operators prefer to be + // merged in a larger fuse op. The small fusion will not break the pattern of + // larger fusion. + const std::vector all_ir_passes_{{ + // Manual update the passes here. + "infer_clean_graph_pass", // + "attention_lstm_fuse_pass", // + "fc_lstm_fuse_pass", // + "mul_lstm_fuse_pass", // + "seq_concat_fc_fuse_pass", // + "fc_fuse_pass", // + }}; + + std::unordered_set disabled_ir_passes_; }; } // namespace analysis diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index 59e103e1179..94be6733f63 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -271,17 +271,22 @@ void TestDituRNNPrediction(const std::string &model_path, const std::string &data_path, int batch_size, bool use_analysis, bool activate_ir, int num_times = 1) { - NativeConfig config; + AnalysisConfig config; config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__"; config.param_file = FLAGS_infer_ditu_rnn_model + "/param"; config.use_gpu = false; config.device = 0; config.specify_input_name = true; + config.enable_ir_optim = activate_ir; + PADDLE_ENFORCE(config.ir_mode == + AnalysisConfig::IrPassMode::kExclude); // default + config.ir_passes.clear(); // Do not exclude any pass. auto base_predictor = CreatePaddlePredictor(config); auto predictor = - CreatePaddlePredictor(config); + CreatePaddlePredictor( + config); std::vector input_slots; DataRecord data(data_path, batch_size); // Prepare inputs. diff --git a/paddle/fluid/inference/analysis/flags.h b/paddle/fluid/inference/analysis/flags.h new file mode 100644 index 00000000000..717e543f01d --- /dev/null +++ b/paddle/fluid/inference/analysis/flags.h @@ -0,0 +1,22 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +// TODO(Superjomn) add a definition flag like PADDLE_WITH_TENSORRT and hide this +// flag if not available. +DECLARE_bool(IA_enable_tensorrt_subgraph_engine); +DECLARE_string(IA_graphviz_log_root); +DECLARE_string(IA_output_storage_path); +DECLARE_bool(IA_enable_ir); diff --git a/paddle/fluid/inference/analysis/fluid_to_ir_pass.h b/paddle/fluid/inference/analysis/fluid_to_ir_pass.h index 6731b1f7593..3086085710d 100644 --- a/paddle/fluid/inference/analysis/fluid_to_ir_pass.h +++ b/paddle/fluid/inference/analysis/fluid_to_ir_pass.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/inference/analysis/flags.h" #include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include "paddle/fluid/inference/analysis/pass.h" @@ -85,9 +86,11 @@ class FluidToIrPass final : public DataFlowGraphPass { new Scope *(&argument_->Get(ir::kParamScopeAttr))); } - const auto &ir_passes_to_apply = - argument_->Get>(kFluidToIrPassesAttr); - ir_passes.Apply(ir_passes_to_apply); + if (FLAGS_IA_enable_ir) { + const auto &ir_passes_to_apply = + argument_->Get>(kFluidToIrPassesAttr); + ir_passes.Apply(ir_passes_to_apply); + } PADDLE_ENFORCE(argument_->main_dfg.get()); argument_->main_dfg->Build(ir_passes.graph()); diff --git a/paddle/fluid/inference/analysis/test_text_classification.cc b/paddle/fluid/inference/analysis/test_text_classification.cc new file mode 100644 index 00000000000..2913824f623 --- /dev/null +++ b/paddle/fluid/inference/analysis/test_text_classification.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include // use glog instead of PADDLE_ENFORCE to avoid importing other paddle header files. +#include +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/inference/analysis/analyzer.h" +#include "paddle/fluid/inference/analysis/ut_helper.h" +#include "paddle/fluid/inference/api/paddle_inference_api.h" +#include "paddle/fluid/inference/api/timer.h" + +DEFINE_string(infer_model, "", "Directory of the inference model."); +DEFINE_string(infer_data, "", "Path of the dataset."); +DEFINE_int32(batch_size, 1, "batch size."); +DEFINE_int32(repeat, 1, "How many times to repeat run."); + +namespace paddle { + +template +std::string to_string(const std::vector &vec) { + std::stringstream ss; + for (const auto &c : vec) { + ss << c << " "; + } + return ss.str(); +} + +void PrintTime(const double latency, const int bs, const int repeat) { + LOG(INFO) << "===========profile result==========="; + LOG(INFO) << "batch_size: " << bs << ", repeat: " << repeat + << ", avg latency: " << latency / repeat << "ms"; + LOG(INFO) << "====================================="; +} + +void Main(int batch_size) { + // Three sequence inputs. + std::vector input_slots(1); + // one batch starts + // data -- + int64_t data0[] = {0, 1, 2}; + for (auto &input : input_slots) { + input.data.Reset(data0, sizeof(data0)); + input.shape = std::vector({3, 1}); + // dtype -- + input.dtype = PaddleDType::INT64; + // LoD -- + input.lod = std::vector>({{0, 3}}); + } + + // shape -- + // Create Predictor -- + AnalysisConfig config; + config.model_dir = FLAGS_infer_model; + config.use_gpu = false; + config.enable_ir_optim = true; + config.ir_passes.push_back("fc_lstm_fuse_pass"); + auto predictor = + CreatePaddlePredictor( + config); + + inference::Timer timer; + double sum = 0; + std::vector output_slots; + for (int i = 0; i < FLAGS_repeat; i++) { + timer.tic(); + CHECK(predictor->Run(input_slots, &output_slots)); + sum += timer.toc(); + } + PrintTime(sum, batch_size, FLAGS_repeat); + + // Get output + LOG(INFO) << "get outputs " << output_slots.size(); + + for (auto &output : output_slots) { + LOG(INFO) << "output.shape: " << to_string(output.shape); + // no lod ? + CHECK_EQ(output.lod.size(), 0UL); + LOG(INFO) << "output.dtype: " << output.dtype; + std::stringstream ss; + for (int i = 0; i < 5; i++) { + ss << static_cast(output.data.data())[i] << " "; + } + LOG(INFO) << "output.data summary: " << ss.str(); + // one batch ends + } +} + +TEST(text_classification, basic) { Main(FLAGS_batch_size); } + +} // namespace paddle + +USE_PASS(fc_fuse_pass); +USE_PASS(seq_concat_fc_fuse_pass); +USE_PASS(fc_lstm_fuse_pass); +USE_PASS(graph_viz_pass); +USE_PASS(infer_clean_graph_pass); +USE_PASS(attention_lstm_fuse_pass); diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 3a43c72e33b..ea00bf36495 100644 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -44,7 +44,19 @@ function(inference_api_test TARGET_NAME) endfunction(inference_api_test) cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor) -cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api) +cc_library(analysis_predictor SRCS analysis_predictor.cc DEPS paddle_inference_api + analysis + ir_pass_manager + pass + fc_fuse_pass + fc_lstm_fuse_pass + seq_concat_fc_fuse_pass + graph_viz_pass + infer_clean_graph_pass + graph_pattern_detector + infer_clean_graph_pass + attention_lstm_fuse_pass + ) cc_test(test_paddle_inference_api SRCS api_tester.cc diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index e87abd2feef..a8fa677202d 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -14,6 +14,8 @@ #include "paddle/fluid/inference/api/analysis_predictor.h" #include +#include +#include #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/scope.h" @@ -28,6 +30,8 @@ bool AnalysisPredictor::Init( VLOG(3) << "Predictor::init()"; if (config_.use_gpu) { place_ = paddle::platform::CUDAPlace(config_.device); + LOG(WARNING) << "ir optimize only supports CPU currently"; + config_.enable_ir_optim = false; } else { place_ = paddle::platform::CPUPlace(); } @@ -73,7 +77,7 @@ bool AnalysisPredictor::Init( void AnalysisPredictor::OptimizeInferenceProgram() { LOG(INFO) << "optimize begin"; - FLAGS_IA_enable_ir = true; + FLAGS_IA_enable_ir = config_.enable_ir_optim; FLAGS_IA_enable_tensorrt_subgraph_engine = false; FLAGS_IA_output_storage_path = ""; // Don't output the model. // Analyze inference_program @@ -90,24 +94,26 @@ void AnalysisPredictor::OptimizeInferenceProgram() { } argument_.origin_program_desc.reset( new ProgramDesc(*inference_program_->Proto())); - Analyzer().Run(&argument_); + PADDLE_ENFORCE(config_.ir_mode == AnalysisConfig::IrPassMode::kExclude, + "Only kExclude is supported yet."); + Analyzer().DisableIrPasses(config_.ir_passes).Run(&argument_); + CHECK(argument_.transformed_program_desc); VLOG(5) << "to prepare executor"; - // LOG(INFO) << "transformed_parogram_desc " << - // argument.transformed_program_desc->DebugString(); inference_program_.reset( new framework::ProgramDesc(*argument_.transformed_program_desc)); - PADDLE_ENFORCE(argument_.Has(framework::ir::kParamScopeAttr)); - // Update scope. - scope_.reset( - argument_.Release(framework::ir::kParamScopeAttr)); - LOG(INFO) << "optimize end =="; + if (argument_.Has(framework::ir::kParamScopeAttr)) { + // Update scope. + scope_.reset( + argument_.Release(framework::ir::kParamScopeAttr)); + } + LOG(INFO) << "== optimize end =="; } template <> std::unique_ptr CreatePaddlePredictor< - NativeConfig, PaddleEngineKind::kAnalysis>(const NativeConfig& config) { - VLOG(3) << "create NativePredictor"; + AnalysisConfig, PaddleEngineKind::kAnalysis>(const AnalysisConfig& config) { + VLOG(3) << "create AnalysisConfig"; if (config.use_gpu) { // 1. GPU memeroy PADDLE_ENFORCE_GT( diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index e32b6185f60..e53925366e9 100644 --- a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/api/api_impl.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" @@ -28,7 +30,7 @@ using framework::proto::ProgramDesc; */ class AnalysisPredictor : public NativePaddlePredictor { public: - explicit AnalysisPredictor(const NativeConfig& config) + explicit AnalysisPredictor(const AnalysisConfig& config) : NativePaddlePredictor(config), config_(config) {} bool Init(const std::shared_ptr& parent_scope); @@ -44,7 +46,7 @@ class AnalysisPredictor : public NativePaddlePredictor { Argument& analysis_argument() { return argument_; } private: - NativeConfig config_; + AnalysisConfig config_; Argument argument_; }; diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 38b11d9113e..bd9b4b1a814 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -176,7 +176,8 @@ bool NativePaddlePredictor::SetFeed(const std::vector &inputs, framework::Scope *scope) { VLOG(3) << "Predictor::set_feed"; if (inputs.size() != feeds_.size()) { - LOG(ERROR) << "wrong feed input size."; + LOG(ERROR) << "wrong feed input size, need " << feeds_.size() << " but get " + << inputs.size(); return false; } for (size_t i = 0; i < inputs.size(); ++i) { diff --git a/paddle/fluid/inference/api/paddle_inference_api.h b/paddle/fluid/inference/api/paddle_inference_api.h index 1baa64c249f..995da11e4a3 100644 --- a/paddle/fluid/inference/api/paddle_inference_api.h +++ b/paddle/fluid/inference/api/paddle_inference_api.h @@ -150,6 +150,21 @@ struct TensorRTConfig : public NativeConfig { int workspace_size{1 << 30}; }; +// NOTE WIP, not stable yet. +struct AnalysisConfig : public NativeConfig { + // + enum class IrPassMode { + kSystem, // Use system default passes, not customize. + kInclude, // Specify the passes in `ir_passes`. + kExclude // Specify the disabled passes in `ir_passes`. + }; + + bool enable_ir_optim = true; + IrPassMode ir_mode{IrPassMode::kExclude}; + // attention lstm fuse works only on some specific models, disable as default. + std::vector ir_passes{"attention_lstm_fuse_pass"}; +}; + // A factory to help create different predictors. // // FOR EXTENSION DEVELOPER: diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index f5c10ced830..58463dc4d6f 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -57,7 +57,7 @@ class LookupTableKernel : public framework::OpKernel { memset(output + i * row_width, 0, row_width * sizeof(T)); } else { PADDLE_ENFORCE_LT(ids[i], row_number); - PADDLE_ENFORCE_GE(ids[i], 0); + PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i); memcpy(output + i * row_width, table + ids[i] * row_width, row_width * sizeof(T)); } -- GitLab