diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc
index 36e88cd7895c35f7f0d268a78f421f63d0e84ff0..58aadd003313df698067e41c28b99d98a8c0bd34 100644
--- a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc
@@ -95,6 +95,26 @@ static void UpdateDataFormat(const framework::ExecutionContext& ctx,
   }
 }
 
+template <typename T>
+static void ReorderInput(framework::Tensor* tensor,
+                         const platform::Place& place,
+                         const mkldnn::engine& engine,
+                         bool isFourDim) {
+  using platform::to_void_cast;
+  auto dims = paddle::framework::vectorize2int(tensor->dims());
+  framework::Tensor out_tensor;
+  out_tensor.Resize(tensor->dims());
+  out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc);
+  out_tensor.set_layout(tensor->layout());
+  mkldnn::memory input_memory = {{{dims, platform::MKLDNNGetDataType<T>(),
+      tensor->format()}, engine}, to_void_cast<T>(tensor->data<T>())};
+  mkldnn::memory output_memory = {{{dims, platform::MKLDNNGetDataType<T>(),
+      out_tensor.format()}, engine},
+      to_void_cast<T>(out_tensor.mutable_data<T>(place))};
+  platform::Reorder(input_memory, output_memory);
+  tensor->ShareDataWith(out_tensor);
+}
+
 template <typename T>
 class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
  public:
@@ -111,63 +131,78 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
     auto x_dims = x->dims();
     auto y_dims_untrimmed = y->dims();
+    auto x_int_dims = paddle::framework::vectorize2int(x_dims);
 
     UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
     UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");
 
-    if (x->format() == memory::format::nChw16c && y->format() == memory::format::nc) {
-      if (x_dims != y_dims_untrimmed) {
-        int pre, n, post;
-        get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
+    const bool are_dims_divisable = !(x_int_dims[1] % 16);
+    const bool is_x_format_correct = x->format() == memory::format::nChw16c;
+    const bool is_y_format_correct = y->format() == memory::format::nc;
+    if (is_x_format_correct && is_y_format_correct && are_dims_divisable) {
+      int pre, n, post;
+      get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
 
-        if (post == 1) {
-          PADDLE_THROW("Not implemented when post is 1");
-        } else {
-          // Just check whether it works for RE-Resnext.
-          PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");
+      if (post == 1) {
+        PADDLE_THROW("Not implemented when post is 1");
+      } else {
+        // Just check whether it works for RE-Resnext.
+        PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");
 
-          int n = x_dims[0];
-          int c = x_dims[1];
-          int h = x_dims[2];
-          int w = x_dims[3];
+        int n = x_dims[0];
+        int c = x_dims[1];
+        int h = x_dims[2];
+        int w = x_dims[3];
 
-          PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
-                         "Y should be in nc format");
+        PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
+                       "Y should be in nc format");
 
-          constexpr int simd_width = 16;
-          int C = c / simd_width;
+        constexpr int simd_width = 16;
+        int C = c / simd_width;
 
-          vector_mul mul;
+        vector_mul mul;
 
-          using mul_func_t =
-              void (*)(const float *, const float *, float *, int, int);
+        using mul_func_t =
+            void (*)(const float *, const float *, float *, int, int);
 
-          mul_func_t mul_func = (mul_func_t) mul.getCode();
+        mul_func_t mul_func = (mul_func_t) mul.getCode();
 
-          #pragma omp parallel for collapse(2)
-          for (int ni = 0; ni < n; ni++) {
-            for (int ci = 0; ci < C; ci++) {
-              auto ptr_x =
-                  x_data + ni * C * h * w * simd_width +
-                  ci * h * w * simd_width;
+        #pragma omp parallel for collapse(2)
+        for (int ni = 0; ni < n; ni++) {
+          for (int ci = 0; ci < C; ci++) {
+            auto ptr_x =
+                x_data + ni * C * h * w * simd_width +
+                ci * h * w * simd_width;
 
-              auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
-              auto ptr_z =
-                  z_data + ni * C * h * w * simd_width +
-                  ci * h * w * simd_width;
+            auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
+            auto ptr_z =
+                z_data + ni * C * h * w * simd_width +
+                ci * h * w * simd_width;
 
-              mul_func(ptr_x, ptr_y, ptr_z, h, w);
-            }
+            mul_func(ptr_x, ptr_y, ptr_z, h, w);
           }
         }
-
-        z->set_layout(DataLayout::kMKLDNN);
-        z->set_format(x->format());
-      } else {
-        PADDLE_THROW("Not implemented when dims are equal");
       }
+
+      z->set_layout(DataLayout::kMKLDNN);
+      z->set_format(x->format());
     } else {
       // Fallback to naive version:
+      const bool are_inputs_in_same_format = x->format() == y->format();
+      const bool is_x_nchw = x->format() == memory::format::nchw;
+      const bool is_x_nc = x->format() == memory::format::nc;
+      const bool is_y_nchw = y->format() == memory::format::nchw;
+      const bool is_y_nc = y->format() == memory::format::nc;
+      if (!are_inputs_in_same_format) {
+        using platform::MKLDNNDeviceContext;
+        auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+        const auto& mkldnn_engine = dev_ctx.GetEngine();
+        if (!(is_x_nchw || is_x_nc))
+          ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine, x->dims().size() == 4);
+        if (!(is_y_nchw || is_y_nc))
+          ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine, y->dims().size() == 4);
+      }
+
       auto mul_func = [](T a, T b) -> T { return a * b; };
 
       TransformFunctor
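Note (reviewer aid, not part of the patch): the fast path above assumes x and z are nChw16c-blocked, i.e. logically [n][C][h][w][16] with C = c / 16, while y stays in nc, so each block of 16 channel values taken from y scales all h * w spatial positions of the matching block in x. The scalar sketch below restates what the JIT-ed vector_mul kernel computes, using the same ptr_x / ptr_y / ptr_z offsets as in the diff; the function name, the main() harness, and the inner loops are illustrative assumptions only.

// Scalar reference of the nChw16c broadcast multiply used by the fast path.
#include <cassert>
#include <vector>

// x, z: [n][C][h][w][16] (nChw16c, C = c / 16); y: [n][C][16] (nc).
static void BroadcastMulNChw16cRef(const float* x, const float* y, float* z,
                                   int n, int C, int h, int w) {
  constexpr int simd_width = 16;  // channel block width implied by nChw16c
  for (int ni = 0; ni < n; ++ni) {
    for (int ci = 0; ci < C; ++ci) {
      // Same offsets as ptr_x / ptr_y / ptr_z in the kernel above.
      const float* ptr_x =
          x + ni * C * h * w * simd_width + ci * h * w * simd_width;
      const float* ptr_y = y + ni * C * simd_width + ci * simd_width;
      float* ptr_z =
          z + ni * C * h * w * simd_width + ci * h * w * simd_width;
      // Every one of the h * w spatial positions reuses the same 16
      // per-channel scales from y for this (ni, ci) block.
      for (int hw = 0; hw < h * w; ++hw) {
        for (int k = 0; k < simd_width; ++k) {
          ptr_z[hw * simd_width + k] = ptr_x[hw * simd_width + k] * ptr_y[k];
        }
      }
    }
  }
}

int main() {
  const int n = 1, C = 2, h = 2, w = 2;  // i.e. c = 32 channels
  std::vector<float> x(n * C * h * w * 16, 2.0f);
  std::vector<float> y(n * C * 16, 3.0f);
  std::vector<float> z(x.size(), 0.0f);
  BroadcastMulNChw16cRef(x.data(), y.data(), z.data(), n, C, h, w);
  assert(z.front() == 6.0f && z.back() == 6.0f);  // 2 * 3 broadcast everywhere
  return 0;
}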