提交 ac177d61 编写于 作者: Z Zhang, Guoming

merge is_test feature for pooling op(mkldnn)

上级 fa164241
...@@ -87,6 +87,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -87,6 +87,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
bool is_test = ctx.Attr<bool>("is_test");
if (ctx.Attr<bool>("global_pooling")) { if (ctx.Attr<bool>("global_pooling")) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0; paddings[i] = 0;
...@@ -145,17 +147,11 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -145,17 +147,11 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd = std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
CreatePrimitiveDesc(src_md, dst_md, propagation, strides, padding_left_top, CreatePrimitiveDesc(src_md, dst_md, propagation, strides, padding_left_top,
padding_right_bottom, ksize, pooling_type, padding_right_bottom, ksize, pooling_type,
mkldnn_engine, ceil_mode); mkldnn_engine, ceil_mode,is_test);
// save pool_pd into global device context to be referred in backward path // save pool_pd into global device context to be referred in backward path
dev_ctx.SetBlob(key_pool_pd, pool_pd); dev_ctx.SetBlob(key_pool_pd, pool_pd);
std::shared_ptr<mkldnn::memory> workspace_memory =
CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine);
// save pool_workspace_memory to be referred in backward path
dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
auto src_memory = std::make_shared<memory>(pool_pd->src_primitive_desc(), auto src_memory = std::make_shared<memory>(pool_pd->src_primitive_desc(),
to_void_cast<T>(input_data)); to_void_cast<T>(input_data));
auto dst_memory = auto dst_memory =
...@@ -164,14 +160,17 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -164,14 +160,17 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dev_ctx.SetBlob(key_pool_src_mem_p, src_memory); dev_ctx.SetBlob(key_pool_src_mem_p, src_memory);
dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory); dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory);
if (propagation == mkldnn::prop_kind::forward_training) { if (is_test) {
pool_p = std::make_shared<pooling_forward>(*pool_pd, *(src_memory.get()), pool_p = std::make_shared<pooling_forward>(
*(dst_memory.get()), *pool_pd, *(src_memory.get()), *(dst_memory.get()));
*workspace_memory); } else {
} else{ std::shared_ptr<mkldnn::memory> workspace_memory =
pool_p = std::make_shared<pooling_forward>(*pool_pd, *(src_memory.get()), CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine);
*(dst_memory.get()));//, // save pool_workspace_memory to be referred in backward path
//*workspace_memory); dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
pool_p = std::make_shared<pooling_forward>(
*pool_pd, *(src_memory.get()), *(dst_memory.get()),
*workspace_memory);
} }
dev_ctx.SetBlob(key_pool_p, pool_p); dev_ctx.SetBlob(key_pool_p, pool_p);
...@@ -211,17 +210,22 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -211,17 +210,22 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::vector<int>& stride, const std::vector<int>& padding_left_top, const std::vector<int>& stride, const std::vector<int>& padding_left_top,
const std::vector<int>& padding_right_bot, const std::vector<int>& kernel, const std::vector<int>& padding_right_bot, const std::vector<int>& kernel,
const std::string& pooling_type, const mkldnn::engine& engine, const std::string& pooling_type, const mkldnn::engine& engine,
bool ceil_mode) const { bool ceil_mode, bool is_test) const {
auto pool_desc = mkldnn::pooling_forward::desc(
propagation, auto mkldnn_forward_prop_kind = is_test
? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training;
std::cout<<is_test<<" "<<__LINE__<<std::endl;
auto pool_desc = mkldnn::pooling_forward::desc(
mkldnn_forward_prop_kind,
pooling_type == "max" ? mkldnn::algorithm::pooling_max pooling_type == "max" ? mkldnn::algorithm::pooling_max
: mkldnn::algorithm::pooling_avg, : mkldnn::algorithm::pooling_avg,
src, dst, stride, kernel, padding_left_top, padding_right_bot, src, dst, stride, kernel, padding_left_top, padding_right_bot,
mkldnn::padding_kind::zero); mkldnn::padding_kind::zero);
auto p_pool_pd = auto p_pool_pd =
new mkldnn::pooling_forward::primitive_desc(pool_desc, engine); new mkldnn::pooling_forward::primitive_desc(pool_desc, engine);
return std::unique_ptr<mkldnn::pooling_forward::primitive_desc>(p_pool_pd); return std::unique_ptr<mkldnn::pooling_forward::primitive_desc>(p_pool_pd);
} }
std::unique_ptr<mkldnn::memory> CreateWorkspaceMemory( std::unique_ptr<mkldnn::memory> CreateWorkspaceMemory(
...@@ -233,7 +237,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -233,7 +237,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
: mkldnn::memory::primitive_desc({{}, : mkldnn::memory::primitive_desc({{},
platform::MKLDNNGetDataType<T>(), platform::MKLDNNGetDataType<T>(),
mkldnn::memory::format::nchw}, mkldnn::memory::format::nchw},
engine); engine);
auto p_workspace_memory = new mkldnn::memory(workspace_md); auto p_workspace_memory = new mkldnn::memory(workspace_md);
return std::unique_ptr<mkldnn::memory>(p_workspace_memory); return std::unique_ptr<mkldnn::memory>(p_workspace_memory);
......
...@@ -206,6 +206,11 @@ void Pool2dOpMaker::Make() { ...@@ -206,6 +206,11 @@ void Pool2dOpMaker::Make() {
"Defaults to \"NHWC\". Specify the data format of the output data, " "Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ") "the input will be transformed automatically. ")
.SetDefault("AnyLayout"); .SetDefault("AnyLayout");
AddAttr<bool>("is_test",
"(bool, default false) If true, the forward pass is not "
"part of training."
"MKL-DNN might be faster if this is set to true.")
.SetDefault(false);
// TODO(dzhwinter): need to registered layout transform function // TODO(dzhwinter): need to registered layout transform function
AddComment(R"DOC( AddComment(R"DOC(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册