diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index a7f09df4917532e7261cee471c711897c8eb3447..5f21dae60586e926472fc512eca7bcbb55dc8eda 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -44,6 +44,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, << dst_place; return; } +#ifdef PADDLE_WITH_MKLDNN + if (src.layout() == DataLayout::kMKLDNN) { + dst->set_mkldnn_prim_desc(src.get_mkldnn_prim_desc()); + } +#endif memory::Copy(boost::get(dst_place), dst_ptr, boost::get(src_place), src_ptr, size); } diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 8f7b6f31dec72a09c414654133dfe717606b0824..d9ac73b0638ad356501a9883b49e65f8f3e32245 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -110,7 +110,7 @@ set(TRANSFORMER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/transformer") download_model_and_data(${TRANSFORMER_INSTALL_DIR} "temp%2Ftransformer_model.tar.gz" "temp%2Ftransformer_data.txt.tar.gz") inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_tester.cc EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} - ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8) + ARGS --infer_model=${TRANSFORMER_INSTALL_DIR}/model --infer_data=${TRANSFORMER_INSTALL_DIR}/data.txt --batch_size=8 SERIAL) # ocr set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr") diff --git a/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc b/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc index 9d17f38ab764148d4e1a63124289425c7e7aa983..f765f556112915bcfa07b5361a473d39292f711a 100644 --- a/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc @@ -183,10 +183,13 @@ void SetInput(std::vector> *inputs) { } // Easy for profiling independently. -TEST(Analyzer_Transformer, profile) { +void profile(bool use_mkldnn = false) { AnalysisConfig cfg; SetConfig(&cfg); std::vector outputs; + if (use_mkldnn) { + cfg.EnableMKLDNN(); + } std::vector> input_slots_all; SetInput(&input_slots_all); @@ -194,6 +197,11 @@ TEST(Analyzer_Transformer, profile) { input_slots_all, &outputs, FLAGS_num_threads); } +TEST(Analyzer_Transformer, profile) { profile(); } +#ifdef PADDLE_WITH_MKLDNN +TEST(Analyzer_Transformer, profile_mkldnn) { profile(true); } +#endif + // Check the fuse status TEST(Analyzer_Transformer, fuse_statis) { AnalysisConfig cfg; @@ -206,9 +214,12 @@ TEST(Analyzer_Transformer, fuse_statis) { } // Compare result of NativeConfig and AnalysisConfig -TEST(Analyzer_Transformer, compare) { +void compare(bool use_mkldnn = false) { AnalysisConfig cfg; SetConfig(&cfg); + if (use_mkldnn) { + cfg.EnableMKLDNN(); + } std::vector> input_slots_all; SetInput(&input_slots_all); @@ -216,5 +227,10 @@ TEST(Analyzer_Transformer, compare) { reinterpret_cast(&cfg), input_slots_all); } +TEST(Analyzer_Transformer, compare) { compare(); } +#ifdef PADDLE_WITH_MKLDNN +TEST(Analyzer_Transformer, compare_mkldnn) { compare(true /* use_mkldnn */); } +#endif + } // namespace inference } // namespace paddle