diff --git a/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc b/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc
index 1318fbcbc4022457354fb34c727cf56ce26e12ec..529a0174c8542f5226e70ef4a47bde069220ecc2 100644
--- a/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc
@@ -94,11 +94,15 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
 }
 
 // Easy for profiling independently.
-TEST(Analyzer_MM_DNN, profile) {
+void profile(bool use_mkldnn = false) {
   contrib::AnalysisConfig cfg;
   SetConfig(&cfg);
   std::vector<PaddleTensor> outputs;
 
+  if (use_mkldnn) {
+    cfg.EnableMKLDNN();
+  }
+
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
   TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
@@ -119,6 +123,11 @@ TEST(Analyzer_MM_DNN, profile) {
   }
 }
 
+TEST(Analyzer_MM_DNN, profile) { profile(); }
+#ifdef PADDLE_WITH_MKLDNN
+TEST(Analyzer_MM_DNN, profile_mkldnn) { profile(true /* use_mkldnn */); }
+#endif
+
 // Check the fuse status
 TEST(Analyzer_MM_DNN, fuse_statis) {
   contrib::AnalysisConfig cfg;
@@ -131,16 +140,25 @@ TEST(Analyzer_MM_DNN, fuse_statis) {
 }
 
 // Compare result of NativeConfig and AnalysisConfig
-TEST(Analyzer_MM_DNN, compare) {
+void compare(bool use_mkldnn = false) {
   contrib::AnalysisConfig cfg;
   SetConfig(&cfg);
 
+  if (use_mkldnn) {
+    cfg.EnableMKLDNN();
+  }
+
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
   CompareNativeAndAnalysis(
       reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }
 
+TEST(Analyzer_MM_DNN, compare) { compare(); }
+#ifdef PADDLE_WITH_MKLDNN
+TEST(Analyzer_MM_DNN, compare_mkldnn) { compare(true /* use_mkldnn */); }
+#endif
+
 // Compare Deterministic result
 TEST(Analyzer_MM_DNN, compare_determine) {
   AnalysisConfig cfg;
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc
index 4c73a70ed1ce2435bfc1a0f3d45afe9b6e3c4cf6..04e8800bbc888540c4df21360c767688eb19c423 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc
@@ -134,16 +134,18 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
     const bool are_inputs_in_same_format = x->format() == y->format();
     const bool is_x_nchw = x->format() == memory::format::nchw;
     const bool is_x_nc = x->format() == memory::format::nc;
+    const bool is_x_x = x->format() == memory::format::x;
     const bool is_y_nchw = y->format() == memory::format::nchw;
     const bool is_y_nc = y->format() == memory::format::nc;
+    const bool is_y_x = y->format() == memory::format::x;
     if (!are_inputs_in_same_format) {
       using platform::MKLDNNDeviceContext;
       auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
       const auto& mkldnn_engine = dev_ctx.GetEngine();
-      if (!(is_x_nchw || is_x_nc))
+      if (!(is_x_nchw || is_x_nc || is_x_x))
         ReorderInput<T>(const_cast<Tensor*>(x), ctx.GetPlace(), mkldnn_engine,
                         x->dims().size() == 4);
-      if (!(is_y_nchw || is_y_nc))
+      if (!(is_y_nchw || is_y_nc || is_y_x))
        ReorderInput<T>(const_cast<Tensor*>(y), ctx.GetPlace(), mkldnn_engine,
                        y->dims().size() == 4);
     }