提交 ce20dfa2 编写于 作者: T tensor-tang

enable more choices

上级 602e28bf
...@@ -19,6 +19,10 @@ limitations under the License. */ ...@@ -19,6 +19,10 @@ limitations under the License. */
#include "paddle/fluid/inference/tests/test_helper.h" #include "paddle/fluid/inference/tests/test_helper.h"
DEFINE_string(dirname, "", "Directory of the inference model."); DEFINE_string(dirname, "", "Directory of the inference model.");
DEFINE_int32(repeat, 100, "Running the inference program repeat times");
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run inference");
DEFINE_bool(prepare_vars, true, "Prepare variables before executor");
DEFINE_bool(prepare_context, true, "Prepare Context before executor");
TEST(inference, understand_sentiment) { TEST(inference, understand_sentiment) {
if (FLAGS_dirname.empty()) { if (FLAGS_dirname.empty()) {
...@@ -61,10 +65,29 @@ TEST(inference, understand_sentiment) { ...@@ -61,10 +65,29 @@ TEST(inference, understand_sentiment) {
std::vector<paddle::framework::LoDTensor*> cpu_fetchs1; std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
cpu_fetchs1.push_back(&output1); cpu_fetchs1.push_back(&output1);
int repeat = 100;
// Run inference on CPU // Run inference on CPU
TestInference<paddle::platform::CPUPlace, true, true>(dirname, cpu_feeds, const bool model_combined = false;
cpu_fetchs1, repeat); if (FLAGS_prepare_vars) {
if (FLAGS_prepare_context) {
TestInference<paddle::platform::CPUPlace, false, true>(
dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, model_combined,
FLAGS_use_mkldnn);
} else {
TestInference<paddle::platform::CPUPlace, false, false>(
dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, model_combined,
FLAGS_use_mkldnn);
}
} else {
if (FLAGS_prepare_context) {
TestInference<paddle::platform::CPUPlace, true, true>(
dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, model_combined,
FLAGS_use_mkldnn);
} else {
TestInference<paddle::platform::CPUPlace, true, false>(
dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, model_combined,
FLAGS_use_mkldnn);
}
}
LOG(INFO) << output1.lod(); LOG(INFO) << output1.lod();
LOG(INFO) << output1.dims(); LOG(INFO) << output1.dims();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册