未验证 提交 b9191902 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #15531 from jczaja/prv-googlenet-fix

Performance and functional fixes to LRN
......@@ -54,6 +54,7 @@ else()
message(WARNING "These tests has been disabled in OSX or WITH_MKL=OFF before being fixed: \n test_analyzer_seq_pool1")
endif()
# RNN2
set(RNN2_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn2")
download_model_and_data(${RNN2_INSTALL_DIR} "rnn2_model.tar.gz" "rnn2_data.txt.tar.gz")
......@@ -115,6 +116,10 @@ if (NOT EXISTS ${MOBILENET_INSTALL_DIR})
endif()
inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc SERIAL)
# googlenet
inference_analysis_api_test_with_fake_data(test_analyzer_googlenet
"${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz" SERIAL)
# resnet50
inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
"${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz" SERIAL)
......
......@@ -67,7 +67,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mid->mutable_data<T>(ctx.GetPlace());
const int n = ctx.Attr<int>("n");
const float alpha = ctx.Attr<float>("alpha");
// MKL-DNN implements LRN in a caffe way:
// http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
// Where sum of squares is divided by size of normalization window
// this is not the case for PaddlePaddle LRN.
// Hence we need to compensate for this diffrence by
// multipliing alpha by size of window(n)
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
const bool is_test = ctx.Attr<bool>("is_test");
......@@ -78,10 +84,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dims = paddle::framework::vectorize2int(x->dims());
auto src_md = paddle::platform::MKLDNNMemDesc(
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
auto dst_md = paddle::platform::MKLDNNMemDesc(
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
dims, mkldnn::memory::data_type::f32, x->format());
auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
mkldnn::lrn_across_channels,
......@@ -92,8 +95,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
k};
auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
auto dst_memory = mkldnn::memory{{dst_md, mkldnn_engine},
static_cast<void*>(output_data)};
if (!is_test) {
const std::string key = ctx.op().Output("Out");
......@@ -110,11 +111,16 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory->set_data_handle(
static_cast<void*>(const_cast<T*>(input_data)));
auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(),
static_cast<void*>(output_data));
auto workspace_memory = insert_to_context<mkldnn::memory>(
key_workspace_memory, dev_ctx,
forward_pd->workspace_primitive_desc());
run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory));
} else {
auto forward_pd =
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
......@@ -122,8 +128,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
auto workspace_memory =
mkldnn::memory{forward_pd.workspace_primitive_desc()};
auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
static_cast<void*>(output_data));
run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory));
}
}
};
......@@ -151,7 +162,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const std::string key_workspace_memory = key + "@lrn_workspace_memory";
const int n = ctx.Attr<int>("n");
const float alpha = ctx.Attr<float>("alpha");
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册