提交 4aa7ef3c 编写于 作者: J Jacek Czaja

- Compensation fix to LRN MKL-DNN op

test=develop
上级 fa286b10
...@@ -54,6 +54,7 @@ else() ...@@ -54,6 +54,7 @@ else()
message(WARNING "These tests has been disabled in OSX or WITH_MKL=OFF before being fixed: \n test_analyzer_seq_pool1") message(WARNING "These tests has been disabled in OSX or WITH_MKL=OFF before being fixed: \n test_analyzer_seq_pool1")
endif() endif()
# RNN2 # RNN2
set(RNN2_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn2") set(RNN2_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn2")
download_model_and_data(${RNN2_INSTALL_DIR} "rnn2_model.tar.gz" "rnn2_data.txt.tar.gz") download_model_and_data(${RNN2_INSTALL_DIR} "rnn2_model.tar.gz" "rnn2_data.txt.tar.gz")
...@@ -115,6 +116,10 @@ if (NOT EXISTS ${MOBILENET_INSTALL_DIR}) ...@@ -115,6 +116,10 @@ if (NOT EXISTS ${MOBILENET_INSTALL_DIR})
endif() endif()
inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc SERIAL) inference_analysis_api_test_with_refer_result(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc SERIAL)
# googlenet
inference_analysis_api_test_with_fake_data(test_analyzer_googlenet
"${INFERENCE_DEMO_INSTALL_DIR}/googlenet" analyzer_resnet50_tester.cc "googlenet.tar.gz" SERIAL)
# resnet50 # resnet50
inference_analysis_api_test_with_fake_data(test_analyzer_resnet50 inference_analysis_api_test_with_fake_data(test_analyzer_resnet50
"${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz" SERIAL) "${INFERENCE_DEMO_INSTALL_DIR}/resnet50" analyzer_resnet50_tester.cc "resnet50_model.tar.gz" SERIAL)
......
...@@ -67,7 +67,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -67,7 +67,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mid->mutable_data<T>(ctx.GetPlace()); mid->mutable_data<T>(ctx.GetPlace());
const int n = ctx.Attr<int>("n"); const int n = ctx.Attr<int>("n");
const float alpha = ctx.Attr<float>("alpha"); const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta"); const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k"); const float k = ctx.Attr<float>("k");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
...@@ -156,7 +156,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -156,7 +156,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const std::string key_workspace_memory = key + "@lrn_workspace_memory"; const std::string key_workspace_memory = key + "@lrn_workspace_memory";
const int n = ctx.Attr<int>("n"); const int n = ctx.Attr<int>("n");
const float alpha = ctx.Attr<float>("alpha"); const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta"); const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k"); const float k = ctx.Attr<float>("k");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册