未验证 提交 0b678d40 编写于 作者: J Jacek Czaja 提交者: GitHub

- sum (#28233)

test=develop
上级 c11d9b30
...@@ -80,8 +80,6 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -80,8 +80,6 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto& input0 = in_vars[0]->Get<LoDTensor>(); auto& input0 = in_vars[0]->Get<LoDTensor>();
in_place = (input0.numel() > 0) && (input0.data<T>() == output_data); in_place = (input0.numel() > 0) && (input0.data<T>() == output_data);
MKLDNNMemoryFormat input_format = input0.format();
for (size_t i = 0; i < in_vars.size(); i++) { for (size_t i = 0; i < in_vars.size(); i++) {
auto& input_it = in_vars[i]->Get<LoDTensor>(); auto& input_it = in_vars[i]->Get<LoDTensor>();
if (input_it.numel() == 0) { if (input_it.numel() == 0) {
...@@ -89,6 +87,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -89,6 +87,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
const T* input_data = input_it.data<T>(); const T* input_data = input_it.data<T>();
MKLDNNMemoryFormat input_format = input_it.format();
auto src_md = memory::desc(src_tz, memory::data_type::f32, input_format); auto src_md = memory::desc(src_tz, memory::data_type::f32, input_format);
auto src_mem = memory(src_md, mkldnn_engine, to_void_cast(input_data)); auto src_mem = memory(src_md, mkldnn_engine, to_void_cast(input_data));
...@@ -115,7 +114,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -115,7 +114,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::shared_ptr<mkldnn::reorder> reorder_p; std::shared_ptr<mkldnn::reorder> reorder_p;
std::shared_ptr<memory> target_mem; std::shared_ptr<memory> target_mem;
if (in_place) { if (in_place) {
output_format = input_format; output_format = input0.format();
target_mem.reset( target_mem.reset(
new memory({{src_tz}, memory::data_type::f32, output_format}, new memory({{src_tz}, memory::data_type::f32, output_format},
mkldnn_engine, output_data)); mkldnn_engine, output_data));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册