提交 6512be59 编写于 作者: M mozga-intel

MKLDNN layout: the code-review changes

上级 96b4904d
......@@ -118,19 +118,18 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_mpd);
std::shared_ptr<memory> dst_mem;
if (in_place)
if (in_place) {
dst_mem.reset(new memory(sum_pd.dst_primitive_desc()));
else
} else {
dst_mem.reset(new memory(sum_pd.dst_primitive_desc(), output_data));
}
std::vector<mkldnn::primitive::at> inputs;
for (size_t i = 0; i < srcs_mem.size(); ++i) {
inputs.push_back(srcs_mem[i]);
}
auto sum_prim = mkldnn::sum(sum_pd, inputs, *dst_mem);
output_format =
(memory::format)sum_pd.dst_primitive_desc().desc().data.format;
output_format = (memory::format)platform::GetMKLDNNFormat(sum_pd);
primitive reorder_prim;
std::shared_ptr<memory> target_mem;
......
......@@ -99,5 +99,11 @@ inline mkldnn::memory::format GetMKLDNNFormat(const mkldnn::memory memory) {
memory.get_primitive_desc().desc().data.format);
}
inline mkldnn::memory::format GetMKLDNNFormat(
const mkldnn::sum::primitive_desc& memory) {
return static_cast<mkldnn::memory::format>(
memory.dst_primitive_desc().desc().data.format);
}
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册