From 9b5017366a5e2ab020a30f3a67bd444bf07a5e1a Mon Sep 17 00:00:00 2001 From: Yihua Xu Date: Mon, 10 Jun 2019 10:43:57 +0800 Subject: [PATCH] Fix the format issue when 'X' is not nchw. (#17833) test=develop --- .../mkldnn/elementwise_add_mkldnn_op.cc | 38 ++++++++++++++++++- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index 6a6741d8f..2779f6dd9 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/platform/mkldnn_helper.h" namespace paddle { @@ -53,12 +54,45 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel { // Execute default elementwise_add operator when // broadcast operations need to performed. if (x_dims != y_dims_untrimed) { + Tensor _x; + mkldnn::memory::format format; + std::vector src_x_tz = framework::vectorize2int(x_dims); + + if ((src_x_tz.size() == 3 && + x->format() != (format = memory::format::ncw)) || + (src_x_tz.size() == 4 && + x->format() != (format = memory::format::nchw)) || + (src_x_tz.size() == 5 && + x->format() != (format = memory::format::ncdhw))) { + _x.Resize(x_dims); + auto user_x_memory_pd = memory::primitive_desc( + {{src_x_tz}, memory::data_type::f32, x->format()}, mkldnn_engine); + auto x_memory_pd = memory::primitive_desc( + {{src_x_tz}, memory::data_type::f32, format}, mkldnn_engine); + auto size = x_memory_pd.get_size(); + _x.mutable_data(ctx.GetPlace(), paddle::memory::Allocator::kDefault, + size); + auto user_x_memory = + memory(user_x_memory_pd, paddle::platform::to_void_cast(x_data)); + auto x_memory = memory(x_memory_pd, + paddle::platform::to_void_cast(_x.data())); + + auto x_reorder = reorder(user_x_memory, x_memory); + + std::vector pipeline; + pipeline.push_back(x_reorder); + stream(stream::kind::eager).submit(pipeline).wait(); + } else { + format = x->format(); + _x.ShareDataWith(*x); + } + auto sum_func = [](T a, T b) -> T { return a + b; }; TransformFunctor functor( - x, y, z, + &_x, y, z, ctx.template device_context(), sum_func); @@ -78,7 +112,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel { functor.RunMidWise(n, pre, post); } z->set_layout(DataLayout::kMKLDNN); - z->set_format(x->format()); + z->set_format(format); } else { PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN && x->format() != memory::format::format_undef, -- GitLab