提交 6ae7cbe2 编写于 作者: T tensor-tang

follow comments

上级 99d00cce
......@@ -40,8 +40,9 @@ inference_test(recommender_system)
inference_test(word2vec)
# This is an unly work around to make this test run
# TODO(TJ): clean me up
cc_test(test_inference_nlp
SRCS test_inference_nlp.cc
DEPS paddle_fluid
ARGS
--modelpath=${PADDLE_BINARY_DIR}/python/paddle/fluid/tests/book/recognize_digits_mlp.inference.model)
--model_path=${PADDLE_BINARY_DIR}/python/paddle/fluid/tests/book/recognize_digits_mlp.inference.model)
......@@ -24,8 +24,8 @@ limitations under the License. */
#include <omp.h>
#endif
DEFINE_string(modelpath, "", "Directory of the inference model.");
DEFINE_string(datafile, "", "File of input index data.");
DEFINE_string(model_path, "", "Directory of the inference model.");
DEFINE_string(data_file, "", "File of input index data.");
DEFINE_int32(repeat, 100, "Running the inference program repeat times");
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run inference");
DEFINE_bool(prepare_vars, true, "Prepare variables before executor");
......@@ -65,6 +65,7 @@ size_t LoadData(std::vector<paddle::framework::LoDTensor>* out,
ids.push_back(stoi(field));
}
if (ids.size() >= 1024) {
// Synced with NLP guys, they will ignore input larger then 1024
continue;
}
......@@ -142,18 +143,18 @@ void ThreadRunInfer(
}
TEST(inference, nlp) {
if (FLAGS_modelpath.empty()) {
LOG(FATAL) << "Usage: ./example --modelpath=path/to/your/model";
if (FLAGS_model_path.empty()) {
LOG(FATAL) << "Usage: ./example --model_path=path/to/your/model";
}
if (FLAGS_datafile.empty()) {
LOG(WARNING) << " Not data file provided, will use dummy data!"
if (FLAGS_data_file.empty()) {
LOG(WARNING) << "No data file provided, will use dummy data!"
<< "Note: if you use nlp model, please provide data file.";
}
LOG(INFO) << "Model Path: " << FLAGS_modelpath;
LOG(INFO) << "Data File: " << FLAGS_datafile;
LOG(INFO) << "Model Path: " << FLAGS_model_path;
LOG(INFO) << "Data File: " << FLAGS_data_file;
std::vector<paddle::framework::LoDTensor> datasets;
size_t num_total_words = LoadData(&datasets, FLAGS_datafile);
size_t num_total_words = LoadData(&datasets, FLAGS_data_file);
LOG(INFO) << "Number of samples (seq_len<1024): " << datasets.size();
LOG(INFO) << "Total number of words: " << num_total_words;
......@@ -168,7 +169,7 @@ TEST(inference, nlp) {
// 2. Initialize the inference_program and load parameters
std::unique_ptr<paddle::framework::ProgramDesc> inference_program;
inference_program =
InitProgram(&executor, scope.get(), FLAGS_modelpath, model_combined);
InitProgram(&executor, scope.get(), FLAGS_model_path, model_combined);
if (FLAGS_use_mkldnn) {
EnableMKLDNN(inference_program);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册