/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include #include "paddle/fluid/operators/elementwise_op.h" #include "paddle/fluid/operators/elementwise_op_function.h" #include "paddle/fluid/platform/mkldnn_helper.h" #include "xbyak/xbyak.h" #include "xbyak/xbyak_util.h" namespace paddle { namespace operators { using framework::DataLayout; using mkldnn::memory; struct vector_mul : public Xbyak::CodeGenerator { vector_mul() { // RDI is ptr X // RSI is ptr Y // RDX is ptr Z // RCX is h // r8 is w push(rbx); xor_(rax, rax); xor_(r10, r10); vmovups(zmm3, ptr[rsi]); L("h_loop"); xor_(rbx, rbx); L("w_loop"); vmovups(zmm2, ptr[rdi + rax]); vmulps(zmm1, zmm2, zmm3); vmovups(ptr[rdx + rax], zmm1); add(rax, 64); inc(rbx); cmp(r8, rbx); jnz("w_loop"); inc(r10); cmp(r10, rcx); jnz("h_loop"); pop(rbx); ret(); } }; void check(const float* x, const float* y, float* z, int w) { for (int wi = 0; wi < w; wi++) { for (int i = 0; i < 16; i++) { z[wi * 16 + i] = x[wi * 16 + i] * y[i]; } } } static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) { std::transform(format.begin(), format.end(), format.begin(), ::tolower); if (!format.compare("nchw")) { return memory::format::nchw; } else if (!format.compare("nchw16c")) { return memory::format::nChw16c; } else if (!format.compare("nchw8c")) { return memory::format::nChw8c; } else if (!format.compare("nhwc")) { return memory::format::nhwc; } else { return memory::format::any; } } static void UpdateDataFormat(const framework::ExecutionContext& ctx, framework::Tensor* tensor, const char* attribute) { if (ctx.op().HasAttr(attribute)) { auto format_as_string = ctx.Attr(attribute); auto format = StringToMKLDNNFormat(format_as_string); if (format != memory::format::any) { tensor->set_format(format); } } } template static void ReorderInput(framework::Tensor* tensor, const platform::Place& place, const mkldnn::engine& engine, bool isFourDim) { using platform::to_void_cast; auto dims = paddle::framework::vectorize2int(tensor->dims()); framework::Tensor out_tensor; out_tensor.Resize(tensor->dims()); out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc); out_tensor.set_layout(tensor->layout()); mkldnn::memory input_memory = { {{dims, platform::MKLDNNGetDataType(), tensor->format()}, engine}, to_void_cast(tensor->data())}; mkldnn::memory output_memory = { {{dims, platform::MKLDNNGetDataType(), out_tensor.format()}, engine}, to_void_cast(out_tensor.mutable_data(place))}; platform::Reorder(input_memory, output_memory); tensor->ShareDataWith(out_tensor); } template class ElementwiseMulMKLDNNKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { using Tensor = framework::Tensor; int axis = ctx.Attr("axis"); auto* x = ctx.Input("X"); auto* y = ctx.Input("Y"); auto* z = ctx.Output("Out"); const T* x_data = x->data(); const T* y_data = y->data(); T* z_data = z->mutable_data(ctx.GetPlace()); auto x_dims = x->dims(); auto y_dims_untrimmed = y->dims(); auto x_int_dims = paddle::framework::vectorize2int(x_dims); UpdateDataFormat(ctx, (Tensor*)x, "x_data_format"); UpdateDataFormat(ctx, (Tensor*)y, "y_data_format"); const bool are_dims_divisable = !(x_int_dims[1] % 16); const bool is_x_format_correct = x->format() == memory::format::nChw16c; const bool is_y_format_correct = y->format() == memory::format::nc; if (is_x_format_correct && is_y_format_correct && are_dims_divisable) { int pre, n, post; get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post); if (post == 1) { PADDLE_THROW("Not implemented when post is 1"); } else { // Just check whether it works for RE-Resnext. PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions"); int n = x_dims[0]; int c = x_dims[1]; int h = x_dims[2]; int w = x_dims[3]; PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c, "Y should be in nc format"); constexpr int simd_width = 16; int C = c / simd_width; vector_mul mul; using mul_func_t = void (*)(const float*, const float*, float*, int, int); mul_func_t mul_func = (mul_func_t)mul.getCode(); #pragma omp parallel for collapse(2) for (int ni = 0; ni < n; ni++) { for (int ci = 0; ci < C; ci++) { auto ptr_x = x_data + ni * C * h * w * simd_width + ci * h * w * simd_width; auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; auto ptr_z = z_data + ni * C * h * w * simd_width + ci * h * w * simd_width; mul_func(ptr_x, ptr_y, ptr_z, h, w); } } } z->set_layout(DataLayout::kMKLDNN); z->set_format(x->format()); } else { // Fallback to naive version: const bool are_inputs_in_same_format = x->format() == y->format(); const bool is_x_nchw = x->format() == memory::format::nchw; const bool is_x_nc = x->format() == memory::format::nc; const bool is_y_nchw = y->format() == memory::format::nchw; const bool is_y_nc = y->format() == memory::format::nc; if (!are_inputs_in_same_format) { using platform::MKLDNNDeviceContext; auto& dev_ctx = ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); if (!(is_x_nchw || is_x_nc)) ReorderInput((Tensor*)x, ctx.GetPlace(), mkldnn_engine, x->dims().size() == 4); if (!(is_y_nchw || is_y_nc)) ReorderInput((Tensor*)y, ctx.GetPlace(), mkldnn_engine, y->dims().size() == 4); } auto mul_func = [](T a, T b) -> T { return a * b; }; TransformFunctor functor( x, y, z, ctx.template device_context(), mul_func); axis = (axis == -1 ? x_dims.size() - y_dims_untrimmed.size() : axis); PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), "Axis should be in range [0, x_dims)"); auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed); axis = (y_dims.size() == 0) ? x_dims.size() : axis; int pre, n, post; get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); if (post == 1) { functor.RunRowWise(n, pre); } else { functor.RunMidWise(n, pre, post); } z->set_layout(DataLayout::kMKLDNN); z->set_format(x->format()); } } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OP_KERNEL(elementwise_mul, MKLDNN, ::paddle::platform::CPUPlace, ops::ElementwiseMulMKLDNNKernel)