From 213ec37d6ad84c3774f1a5e203566dc47a1b63da Mon Sep 17 00:00:00 2001
From: Tomasz Patejko
Date: Thu, 25 Oct 2018 16:18:04 +0200
Subject: [PATCH] MKLDNN elementwise_mul: simple initial implementation of the
 operator for MKLDNN format

---
 .../operators/elementwise_mul_mkldnn_op.cc    | 122 ++++++++++++++++++
 1 file changed, 122 insertions(+)
 create mode 100644 paddle/fluid/operators/elementwise_mul_mkldnn_op.cc

diff --git a/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc
new file mode 100644
index 000000000..22289ab41
--- /dev/null
+++ b/paddle/fluid/operators/elementwise_mul_mkldnn_op.cc
@@ -0,0 +1,122 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <cstdint>
+
+#include "paddle/fluid/operators/elementwise_op.h"
+#include "paddle/fluid/operators/elementwise_op_function.h"
+
+#include "paddle/fluid/platform/mkldnn_helper.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::DataLayout;
+
+// Element-wise multiplication kernel registered for the MKLDNN library.
+//
+// Only one case is implemented: X is a 4-D tensor whose buffer is read as
+// [batch][channels / 16][height][width][16], and Y is a 2-D
+// [batch, channels] tensor that is broadcast over the height and width of X.
+// Every other combination of shapes is rejected with an error.
+//
+// NOTE(review): the memory format of X is assumed to match the indexing above
+// but is not verified here -- confirm that only such tensors are dispatched.
+template <typename T>
+class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    using Tensor = framework::Tensor;
+
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* z = ctx.Output<Tensor>("Out");
+    const T* x_data = x->data<T>();
+    const T* y_data = y->data<T>();
+    T* z_data = z->mutable_data<T>(ctx.GetPlace());
+
+    auto x_dims = x->dims();
+    auto y_dims_untrimmed = y->dims();
+
+    if (x_dims == y_dims_untrimmed) {
+      PADDLE_THROW("Not implemented when dims are equal");
+    }
+
+    // An axis of -1 asks for Y to be aligned with the trailing dimensions of
+    // X. Resolve and validate it first: get_mid_dims uses axis as an index
+    // into x_dims.
+    int axis = ctx.Attr<int>("axis");
+    if (axis == -1) {
+      axis = x_dims.size() - y_dims_untrimmed.size();
+    }
+    PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
+                   "Axis should be in range [0, x_dims)");
+
+    int pre, mid, post;
+    get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &mid, &post);
+
+    if (post == 1) {
+      PADDLE_THROW("Not implemented when post is 1");
+    }
+
+    // Just check whether it works for RE-Resnext.
+    PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");
+
+    const int64_t batch = x_dims[0];
+    const int64_t channels = x_dims[1];
+    const int64_t height = x_dims[2];
+    const int64_t width = x_dims[3];
+
+    // The indexing below pairs Y[n, c] with X[n, c, :, :], which is only
+    // correct when Y is 2-D and aligned with the leading dimensions of X.
+    PADDLE_ENFORCE(axis == 0 && y_dims_untrimmed.size() == 2 &&
+                       y_dims_untrimmed[0] == batch &&
+                       y_dims_untrimmed[1] == channels,
+                   "Y should be in nc format");
+
+    // Channels are processed in whole groups of kSimdWidth; a remainder would
+    // be skipped and would also shift every offset computed below.
+    constexpr int64_t kSimdWidth = 16;
+    PADDLE_ENFORCE(channels % kSimdWidth == 0,
+                   "Number of channels in X should be a multiple of 16");
+
+    const int64_t num_blocks = batch * (channels / kSimdWidth);
+    const int64_t block_size = height * width * kSimdWidth;
+
+    // Each (batch, channel group) block of X and Out is contiguous and holds
+    // height * width runs of kSimdWidth values; the kSimdWidth values of Y
+    // belonging to that block are reused for every run.
+    for (int64_t block = 0; block < num_blocks; ++block) {
+      const T* x_block = x_data + block * block_size;
+      const T* y_block = y_data + block * kSimdWidth;
+      T* z_block = z_data + block * block_size;
+
+      for (int64_t offset = 0; offset < block_size; offset += kSimdWidth) {
+        for (int64_t i = 0; i < kSimdWidth; ++i) {
+          z_block[offset + i] = x_block[offset + i] * y_block[i];
+        }
+      }
+    }
+
+    z->set_layout(DataLayout::kMKLDNN);
+    z->set_format(x->format());
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_KERNEL(elementwise_mul, MKLDNN, ::paddle::platform::CPUPlace,
+                   ops::ElementwiseMulMKLDNNKernel<float>)
-- 
GitLab