diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index 423c39813f05af0d6aaade184914e6777c9b8a83..07b9e0e051bce13f6caeca54a664019c55d80fa6 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -54,6 +54,7 @@ else()
   message(WARNING "These tests has been disabled in OSX or WITH_MKL=OFF before being fixed: \n test_analyzer_seq_pool1")
 endif()
+
 # RNN2
 set(RNN2_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn2")
 download_model_and_data(${RNN2_INSTALL_DIR} "rnn2_model.tar.gz" "rnn2_data.txt.tar.gz")
@@ -115,6 +116,10 @@ if (NOT EXISTS ${MOBILENET_INSTALL_DIR})
 endif()
 inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc SERIAL)
 
+# googlenet
+inference_analysis_api_test_with_fake_data(test_analyzer_googlenet
+    "${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz" SERIAL)
+
 # resnet50
 inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
     "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz" SERIAL)
diff --git a/paddle/fluid/operators/lrn_mkldnn_op.cc b/paddle/fluid/operators/lrn_mkldnn_op.cc
index 4e4f977fcc742856b877ef0b7f9a3cc9879aefce..097ba01d401dbc7969e30f576cac2567c874ed99 100644
--- a/paddle/fluid/operators/lrn_mkldnn_op.cc
+++ b/paddle/fluid/operators/lrn_mkldnn_op.cc
@@ -67,7 +67,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     mid->mutable_data<T>(ctx.GetPlace());
 
     const int n = ctx.Attr<int>("n");
-    const float alpha = ctx.Attr<float>("alpha");
+    // MKL-DNN implements LRN the Caffe way:
+    // http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
+    // where the sum of squares is divided by the size of the normalization
+    // window; this is not the case for PaddlePaddle LRN.
+    // Hence we need to compensate for this difference by
+    // multiplying alpha by the size of the window (n).
+    const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
     const float beta = ctx.Attr<float>("beta");
     const float k = ctx.Attr<float>("k");
     const bool is_test = ctx.Attr<bool>("is_test");
@@ -78,10 +84,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto dims = paddle::framework::vectorize2int(x->dims());
 
     auto src_md = paddle::platform::MKLDNNMemDesc(
-        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
-
-    auto dst_md = paddle::platform::MKLDNNMemDesc(
-        dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
+        dims, mkldnn::memory::data_type::f32, x->format());
 
     auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
                                                   mkldnn::lrn_across_channels,
@@ -92,8 +95,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                                                   k};
 
     auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
-    auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
-                                     static_cast<void*>(output_data)};
 
     if (!is_test) {
       const std::string key = ctx.op().Output("Out");
@@ -110,11 +111,16 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       src_memory->set_data_handle(
           static_cast<void*>(const_cast<T*>(input_data)));
 
+      auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(),
+                                       static_cast<void*>(output_data));
       auto workspace_memory = insert_to_context<mkldnn::memory>(
           key_workspace_memory, dev_ctx,
           forward_pd->workspace_primitive_desc());
 
       run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
+
+      out->set_layout(framework::DataLayout::kMKLDNN);
+      out->set_format(platform::GetMKLDNNFormat(dst_memory));
     } else {
       auto forward_pd =
           mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
@@ -122,8 +128,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
       auto workspace_memory =
           mkldnn::memory{forward_pd.workspace_primitive_desc()};
+      auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
+                                       static_cast<void*>(output_data));
 
       run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
+
+      out->set_layout(framework::DataLayout::kMKLDNN);
+      out->set_format(platform::GetMKLDNNFormat(dst_memory));
     }
   }
 };
@@ -151,7 +162,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const std::string key_workspace_memory = key + "@lrn_workspace_memory";
 
     const int n = ctx.Attr<int>("n");
-    const float alpha = ctx.Attr<float>("alpha");
+    const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
     const float beta = ctx.Attr<float>("beta");
     const float k = ctx.Attr<float>("k");
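
Reviewer note (not part of the patch): a minimal standalone sketch of why scaling alpha by the window size reconciles the two formulations, assuming the Caffe/MKL-DNN-style denominator is k + (alpha / n) * sum_sq while PaddlePaddle's LRN uses k + alpha * sum_sq. The helper names and sample values below are illustrative only, not taken from the Paddle code base.

#include <cassert>
#include <cmath>

// Caffe/MKL-DNN-style LRN: the squared sum is divided by the window size n
// before being scaled by alpha.
float lrn_caffe_style(float src, float sum_sq, float k, float alpha, int n,
                      float beta) {
  return src / std::pow(k + (alpha / static_cast<float>(n)) * sum_sq, beta);
}

// PaddlePaddle-style LRN: alpha scales the raw squared sum directly.
float lrn_paddle_style(float src, float sum_sq, float k, float alpha,
                       float beta) {
  return src / std::pow(k + alpha * sum_sq, beta);
}

int main() {
  const float src = 0.5f, sum_sq = 2.0f, k = 1.0f, alpha = 1e-4f, beta = 0.75f;
  const int n = 5;
  // Passing alpha * n to the Caffe-style formula cancels its internal
  // division by n, reproducing PaddlePaddle's result.
  const float compensated = lrn_caffe_style(src, sum_sq, k, alpha * n, n, beta);
  const float reference = lrn_paddle_style(src, sum_sq, k, alpha, beta);
  assert(std::fabs(compensated - reference) < 1e-6f);
  return 0;
}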