Commit 6512be59 authored by mozga-intel

MKLDNN layout: the code-review changes

Parent 96b4904d
...
@@ -118,19 +118,18 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd);
     std::shared_ptr<memory> dst_mem;
-    if (in_place)
+    if (in_place) {
       dst_mem.reset(new memory(sum_pd.dst_primitive_desc()));
-    else
+    } else {
       dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data));
+    }
     std::vector<mkldnn::primitive::at> inputs;
     for (size_t i = 0; i < srcs_mem.size(); ++i) {
       inputs.push_back(srcs_mem[i]);
     }
     auto sum_prim = mkldnn::sum(sum_pd, inputs, *dst_mem);
-    output_format =
-        (memory::format)sum_pd.dst_primitive_desc().desc().data.format;
+    output_format = (memory::format)platform::GetMKLDNNFormat(sum_pd);
     primitive reorder_prim;
     std::shared_ptr<memory> target_mem;
...
@@ -99,5 +99,11 @@ inline mkldnn::memory::format GetMKLDNNFormat(const mkldnn::memory memory) {
       memory.get_primitive_desc().desc().data.format);
 }
+
+inline mkldnn::memory::format GetMKLDNNFormat(
+    const mkldnn::sum::primitive_desc& memory) {
+  return static_cast<mkldnn::memory::format>(
+      memory.dst_primitive_desc().desc().data.format);
+}
 } // namespace platform
 } // namespace paddle
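
For reference, a minimal sketch of how the new overload is meant to be called, assuming the MKLDNN 0.x C++ API used in the hunks above; InferSumOutputFormat is a hypothetical wrapper named here only for illustration and is not part of the patch:

    #include <mkldnn.hpp>  // MKLDNN 0.x C++ API
    #include <vector>

    // Assumes the GetMKLDNNFormat(const mkldnn::sum::primitive_desc&) overload
    // added in the second hunk above is declared in paddle::platform.
    // Builds the sum primitive descriptor exactly as the kernel does, then lets
    // the helper report the destination memory format instead of reaching into
    // dst_primitive_desc().desc().data.format at the call site.
    mkldnn::memory::format InferSumOutputFormat(
        const mkldnn::memory::desc& dst_md, const std::vector<float>& scales,
        const std::vector<mkldnn::memory::primitive_desc>& srcs_mpd) {
      auto sum_pd = mkldnn::sum::primitive_desc(dst_md, scales, srcs_mpd);
      return paddle::platform::GetMKLDNNFormat(sum_pd);
    }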