提交 fe581b0e 编写于 作者: A Adam 提交者: Tao Luo

Minor GetMKLDNNFormat changes (#20055)

test=develop
上级 54e07994
...@@ -167,9 +167,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -167,9 +167,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
z->set_layout(DataLayout::kMKLDNN); z->set_layout(DataLayout::kMKLDNN);
z->set_format((MKLDNNMemoryFormat)dst_memory->get_primitive_desc() z->set_format(platform::GetMKLDNNFormat(*dst_memory));
.desc()
.data.format);
} }
} }
}; };
......
...@@ -47,11 +47,6 @@ static memory::primitive_desc CreateMemPrimDesc(const Tensor& input, ...@@ -47,11 +47,6 @@ static memory::primitive_desc CreateMemPrimDesc(const Tensor& input,
return mem_prim_desc; return mem_prim_desc;
} }
static MKLDNNMemoryFormat GetDstMemFormat(
const concat::primitive_desc& concat_pd) {
return (MKLDNNMemoryFormat)concat_pd.dst_primitive_desc().desc().data.format;
}
static platform::CPUPlace GetCpuPlace( static platform::CPUPlace GetCpuPlace(
const paddle::framework::ExecutionContext& ctx) { const paddle::framework::ExecutionContext& ctx) {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
...@@ -198,7 +193,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -198,7 +193,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
stream(stream::kind::eager).submit({*concat_p}).wait(); stream(stream::kind::eager).submit({*concat_p}).wait();
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetDstMemFormat(*concat_pd)); output->set_format(platform::GetMKLDNNFormat(*dst_mem));
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -71,7 +71,7 @@ class FCPrimitiveFactory { ...@@ -71,7 +71,7 @@ class FCPrimitiveFactory {
input_->set_data_handle(const_cast<T*>(in->data<T>())); input_->set_data_handle(const_cast<T*>(in->data<T>()));
output_->set_data_handle(out->mutable_data<T>(ctx.GetPlace())); output_->set_data_handle(out->mutable_data<T>(ctx.GetPlace()));
if (out->format() == MKLDNNMemoryFormat::format_undef) { if (out->format() == MKLDNNMemoryFormat::format_undef) {
auto output_format = output_->get_primitive_desc().desc().data.format; auto output_format = platform::GetMKLDNNFormat(*output_);
out->set_format((MKLDNNMemoryFormat)output_format); out->set_format((MKLDNNMemoryFormat)output_format);
} }
} }
...@@ -199,8 +199,9 @@ class FCPrimitiveFactory { ...@@ -199,8 +199,9 @@ class FCPrimitiveFactory {
auto dst_prim_desc = fc_prim_desc.dst_primitive_desc(); auto dst_prim_desc = fc_prim_desc.dst_primitive_desc();
auto buffer_size = dst_prim_desc.get_size(); auto buffer_size = dst_prim_desc.get_size();
T* output_data = output->mutable_data<T>(ctx.GetPlace(), buffer_size); T* output_data = output->mutable_data<T>(ctx.GetPlace(), buffer_size);
output->set_format((MKLDNNMemoryFormat)dst_prim_desc.desc().data.format); memory dst_mem(dst_prim_desc, to_void_cast<T>(output_data));
return memory(dst_prim_desc, to_void_cast<T>(output_data)); output->set_format(platform::GetMKLDNNFormat(dst_mem));
return dst_mem;
} }
void RecomputeOutputDims(const ExecutionContext& ctx, const LoDTensor* input, void RecomputeOutputDims(const ExecutionContext& ctx, const LoDTensor* input,
......
...@@ -77,13 +77,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -77,13 +77,8 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<mkldnn::primitive> pipeline = {*lrn_p}; std::vector<mkldnn::primitive> pipeline = {*lrn_p};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
auto output_format =
(mkldnn::memory::format)dst_memory->get_primitive_desc()
.desc()
.data.format;
out->set_layout(framework::DataLayout::kMKLDNN); out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(output_format); out->set_format(platform::GetMKLDNNFormat(*dst_memory));
} }
}; };
...@@ -129,13 +124,8 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -129,13 +124,8 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::vector<mkldnn::primitive> pipeline = {*lrn_bwd}; std::vector<mkldnn::primitive> pipeline = {*lrn_bwd};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
auto output_format =
(mkldnn::memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format;
x_grad->set_layout(framework::DataLayout::kMKLDNN); x_grad->set_layout(framework::DataLayout::kMKLDNN);
x_grad->set_format(output_format); x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -107,7 +107,7 @@ class MulPrimitiveFactory { ...@@ -107,7 +107,7 @@ class MulPrimitiveFactory {
output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace())); output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
if (out->format() == MKLDNNMemoryFormat::format_undef) { if (out->format() == MKLDNNMemoryFormat::format_undef) {
auto output_format = output_->get_primitive_desc().desc().data.format; auto output_format = platform::GetMKLDNNFormat(*output_);
out->set_format((MKLDNNMemoryFormat)output_format); out->set_format((MKLDNNMemoryFormat)output_format);
} }
} }
...@@ -139,8 +139,9 @@ class MulPrimitiveFactory { ...@@ -139,8 +139,9 @@ class MulPrimitiveFactory {
auto buffer_size = dst_prim_desc.get_size(); auto buffer_size = dst_prim_desc.get_size();
OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size); OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size);
output->set_format((MKLDNNMemoryFormat)dst_prim_desc.desc().data.format); memory dst_mem(dst_prim_desc, to_void_cast<OT>(output_data));
return memory(dst_prim_desc, to_void_cast<OT>(output_data)); output->set_format(platform::GetMKLDNNFormat(dst_mem));
return dst_mem;
} }
memory Reorder(const memory::desc &src_desc, const memory::desc &dst_desc, memory Reorder(const memory::desc &src_desc, const memory::desc &dst_desc,
......
...@@ -95,11 +95,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -95,11 +95,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<mkldnn::primitive> pipeline{*pool_p}; std::vector<mkldnn::primitive> pipeline{*pool_p};
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
auto output_format =
(MKLDNNMemoryFormat)dst_memory->get_primitive_desc().desc().data.format;
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(output_format); output->set_format(platform::GetMKLDNNFormat(*dst_memory));
} }
}; };
...@@ -179,12 +176,8 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -179,12 +176,8 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
pipeline.push_back(*pool_bwd_p); pipeline.push_back(*pool_bwd_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
auto in_x_grad_format =
(MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc()
.desc()
.data.format;
in_x_grad->set_layout(DataLayout::kMKLDNN); in_x_grad->set_layout(DataLayout::kMKLDNN);
in_x_grad->set_format(in_x_grad_format); in_x_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory));
} // Compute() } // Compute()
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册