提交 663a11ac 编写于 作者: T tensor-tang

bugfix and follow comment

上级 d0c65bff
......@@ -25,9 +25,8 @@ function (inference_analysis_test TARGET)
if(WITH_TESTING)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS EXTRA_DEPS)
set(multiValueArgs SRCS ARGS EXTRA_DEPS)
cmake_parse_arguments(analysis_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(mem_opt "")
if(WITH_GPU)
set(mem_opt "--fraction_of_gpu_memory_to_use=0.5")
......@@ -35,7 +34,7 @@ function (inference_analysis_test TARGET)
cc_test(${TARGET}
SRCS "${analysis_test_SRCS}"
DEPS analysis graph fc_fuse_pass graph_viz_pass infer_clean_graph_pass graph_pattern_detector pass ${analysis_test_EXTRA_DEPS}
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt})
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model ${mem_opt} ${analysis_test_ARGS})
set_tests_properties(${TARGET} PROPERTIES DEPENDS test_word2vec)
endif(WITH_TESTING)
endfunction(inference_analysis_test)
......@@ -70,8 +69,7 @@ inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
attention_lstm_fuse_pass
paddle_inference_api
pass
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model
--infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model
ARGS --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model
--infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt)
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
......@@ -93,8 +91,7 @@ if (NOT EXISTS ${CHINESE_NER_INSTALL_DIR})
inference_download_and_uncompress(${CHINESE_NER_INSTALL_DIR} ${CHINESE_NER_DATA_URL} "chinese_ner-data.txt.tar.gz")
endif()
inference_analysis_test(test_chinese_ner SRCS chinese_ner_tester.cc
inference_analysis_test(test_analyzer_ner SRCS analyzer_ner_tester.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api
ARGS --inference_model_dir=${PYTHON_TESTS_DIR}/book/word2vec.inference.model
--infer_model=${CHINESE_NER_INSTALL_DIR}/model
ARGS --infer_model=${CHINESE_NER_INSTALL_DIR}/model
--infer_data=${CHINESE_NER_INSTALL_DIR}/data.txt)
......@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/analyzer.h"
#include <google/protobuf/text_format.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
......
......@@ -62,14 +62,14 @@ void NativePaddlePredictor::PrepareFeedFetch() {
for (auto *op : inference_program_->Block(0).AllOps()) {
if (op->Type() == "feed") {
int idx = boost::get<int>(op->GetAttr("col"));
if (feeds_.size() <= (size_t)idx) {
if (feeds_.size() <= static_cast<size_t>(idx)) {
feeds_.resize(idx + 1);
}
feeds_[idx] = op;
feed_names_[op->Output("Out")[0]] = idx;
} else if (op->Type() == "fetch") {
int idx = boost::get<int>(op->GetAttr("col"));
if (fetchs_.size() <= (size_t)idx) {
if (fetchs_.size() <= static_cast<size_t>(idx)) {
fetchs_.resize(idx + 1);
}
fetchs_[idx] = op;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册