From 663a11ac7c5d276ea7df54c901d2445effb1858b Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Fri, 31 Aug 2018 20:22:54 +0800 Subject: [PATCH] bugfix and follow comment --- paddle/fluid/inference/analysis/CMakeLists.txt | 13 +++++-------- ...chinese_ner_tester.cc => analyzer_ner_tester.cc} | 2 +- paddle/fluid/inference/api/api_impl.cc | 4 ++-- 3 files changed, 8 insertions(+), 11 deletions(-) rename paddle/fluid/inference/analysis/{chinese_ner_tester.cc => analyzer_ner_tester.cc} (100%) diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt index d43ecc722ea..817e36401f7 100644 --- a/paddle/fluid/inference/analysis/CMakeLists.txt +++ b/paddle/fluid/inference/analysis/CMakeLists.txt @@ -25,9 +25,8 @@ function (inference_analysis_test TARGET) if(WITH_TESTING) set(options "") set(oneValueArgs "") - set(multiValueArgs SRCS EXTRA_DEPS) + set(multiValueArgs SRCS ARGS EXTRA_DEPS) cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - set(mem_opt "") if(WITH_GPU) set(mem_opt "--fraction_of_gpu_memory_to_use=0.5") @@ -35,7 +34,7 @@ function (inference_analysis_test TARGET) cc_test(${TARGET} SRCS "${analysis_test_SRCS}" DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detector pass ${analysis_test_EXTRA_DEPS} - ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt}) + ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt} ${analysis_test_ARGS}) set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec) endif(WITH_TESTING) endfunction(inference_analysis_test) @@ -70,8 +69,7 @@ inference_analysis_test(test_analyzer SRCS analyzer_tester.cc attention_lstm_fuse_pass paddle_inference_api pass - ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model - --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model + ARGS --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model --infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt) inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc) @@ -93,8 +91,7 @@ if (NOT EXISTS ${CHINESE_NER_INSTALL_DIR}) inference_download_and_uncompress(${CHINESE_NER_INSTALL_DIR} ${CHINESE_NER_DATA_URL} "chinese_ner-data.txt.tar.gz") endif() -inference_analysis_test(test_chinese_ner SRCS chinese_ner_tester.cc +inference_analysis_test(test_analyzer_ner SRCS analyzer_ner_tester.cc EXTRA_DEPS paddle_inference_api paddle_fluid_api - ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model - --infer_model=${CHINESE_NER_INSTALL_DIR}/model + ARGS --infer_model=${CHINESE_NER_INSTALL_DIR}/model --infer_data=${CHINESE_NER_INSTALL_DIR}/data.txt) diff --git a/paddle/fluid/inference/analysis/chinese_ner_tester.cc b/paddle/fluid/inference/analysis/analyzer_ner_tester.cc similarity index 100% rename from paddle/fluid/inference/analysis/chinese_ner_tester.cc rename to paddle/fluid/inference/analysis/analyzer_ner_tester.cc index 9088a29d504..720a8811db7 100644 --- a/paddle/fluid/inference/analysis/chinese_ner_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_ner_tester.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/inference/analysis/analyzer.h" #include #include #include "paddle/fluid/framework/ir/pass.h" -#include "paddle/fluid/inference/analysis/analyzer.h" #include "paddle/fluid/inference/analysis/ut_helper.h" #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index 530274f0c92..2e02f6d974c 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -62,14 +62,14 @@ void NativePaddlePredictor::PrepareFeedFetch() { for (auto *op : inference_program_->Block(0).AllOps()) { if (op->Type() == "feed") { int idx = boost::get(op->GetAttr("col")); - if (feeds_.size() <= (size_t)idx) { + if (feeds_.size() <= static_cast(idx)) { feeds_.resize(idx + 1); } feeds_[idx] = op; feed_names_[op->Output("Out")[0]] = idx; } else if (op->Type() == "fetch") { int idx = boost::get(op->GetAttr("col")); - if (fetchs_.size() <= (size_t)idx) { + if (fetchs_.size() <= static_cast(idx)) { fetchs_.resize(idx + 1); } fetchs_[idx] = op; -- GitLab