diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index 5b8dfc57ba020cea259041f55a66472ea26b4eec..bc48fd3b479157d4aea390cd5f4dc61ea46dca4b 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -147,10 +147,9 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, "Input tensor type is not supported: ", in.type().name()); memory::data_type out_type = in_type; - memory::format in_format = - in_tz.size() == 2 ? memory::format::nc : in.format(); - memory::format out_format = - out_tz.size() == 2 ? memory::format::nc : ToMKLDNNFormat(out_layout); + auto in_format = MKLDNNFormatForSize(in_tz.size(), in.format()); + auto out_format = + MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout)); void* in_data = GetDataFromTensor(in, in_type); diff --git a/paddle/fluid/framework/data_layout_transform.h b/paddle/fluid/framework/data_layout_transform.h index 2ba84ce57fd8aa3d9aa651bdaa2930e459c74e88..67f91e4e48d3e11ed493c5e6943cb9071aff60c4 100644 --- a/paddle/fluid/framework/data_layout_transform.h +++ b/paddle/fluid/framework/data_layout_transform.h @@ -61,6 +61,13 @@ inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) { if (iter != dict.end()) return iter->second; return MKLDNNDataType::data_undef; } + +inline MKLDNNFormat MKLDNNFormatForSize(size_t dims_size, + MKLDNNFormat default_format) { + return (dims_size == 1 + ? mkldnn::memory::format::x + : dims_size == 2 ? mkldnn::memory::format::nc : default_format); +} #endif void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, diff --git a/paddle/fluid/framework/data_transform.cc b/paddle/fluid/framework/data_transform.cc index b8fcc92697ca1bf1d971f8fef020f31d405605a9..5f15e20c78fd5a333523fe9e73542c037a161cae 100644 --- a/paddle/fluid/framework/data_transform.cc +++ b/paddle/fluid/framework/data_transform.cc @@ -47,9 +47,13 @@ void DataTransform(const OpKernelType& expected_kernel_type, #ifdef PADDLE_WITH_MKLDNN // Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel // Just set layout/format. No real transform occur + + auto out_format = + MKLDNNFormatForSize(in.dims().size(), ToMKLDNNFormat(lin)); + out.ShareDataWith(input_tensor); out.set_layout(DataLayout::kMKLDNN); - out.set_format(ToMKLDNNFormat(lin)); + out.set_format(out_format); #endif } else { // Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel diff --git a/paddle/fluid/operators/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise_add_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3f612256840825a75f49944ab97ff957d572a863 --- /dev/null +++ b/paddle/fluid/operators/elementwise_add_mkldnn_op.cc @@ -0,0 +1,190 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/elementwise_add_op.h" +#include "paddle/fluid/operators/elementwise_op_function.h" + +#include "paddle/fluid/platform/mkldnn_helper.h" + +namespace paddle { +namespace operators { + +using framework::DataLayout; +using framework::Tensor; +using mkldnn::memory; +using mkldnn::reorder; +using mkldnn::primitive; +using mkldnn::stream; +using mkldnn::sum; + +template +class EltwiseAddMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = + ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + const T* x_data = x->data(); + const T* y_data = y->data(); + T* z_data = z->mutable_data(ctx.GetPlace()); + + int axis = ctx.Attr("axis"); + + auto x_dims = x->dims(); + auto y_dims = y->dims(); + auto z_dims = z->dims(); + + // Execute default elementwise_add operator when + // broadcast operations need to performed. + if (x_dims != y_dims) { + auto sum_func = [](T a, T b) -> T { return a + b; }; + + TransformFunctor + functor( + x, y, z, + ctx.template device_context(), + sum_func); + + axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + + trim_trailing_singular_dims(&y_dims); + axis = (y_dims.size() == 0) ? x_dims.size() : axis; + + int pre, n, post; + get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post); + + if (post == 1) { + functor.RunRowWise(n, pre); + } else { + functor.RunMidWise(n, pre, post); + } + z->set_layout(DataLayout::kMKLDNN); + z->set_format(x->format()); + } else { + PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN && + x->format() != memory::format::format_undef, + "Wrong layout/format set for X tensor"); + PADDLE_ENFORCE(y->layout() == DataLayout::kMKLDNN && + y->format() != memory::format::format_undef, + "Wrong layout/format set for X tensor"); + + std::vector src_x_tz = framework::vectorize2int(x_dims); + std::vector src_y_tz = framework::vectorize2int(y_dims); + std::vector dst_tz = framework::vectorize2int(z_dims); + + std::vector srcs_pd; + std::vector srcs; + std::vector scales = {1.0f, 1.0f}; + + auto src_x_pd = memory::primitive_desc( + {{src_x_tz}, memory::data_type::f32, x->format()}, mkldnn_engine); + auto src_y_pd = memory::primitive_desc( + {{src_y_tz}, memory::data_type::f32, y->format()}, mkldnn_engine); + auto src_x_memory = + memory(src_x_pd, paddle::platform::to_void_cast(x_data)); + auto src_y_memory = + memory(src_y_pd, paddle::platform::to_void_cast(y_data)); + + srcs_pd.push_back(src_x_pd); + srcs_pd.push_back(src_y_pd); + srcs.push_back(src_x_memory); + srcs.push_back(src_y_memory); + + auto dst_md = + memory::desc({dst_tz}, memory::data_type::f32, memory::format::any); + + // create primitive descriptor for sum + auto sum_pd = sum::primitive_desc(dst_md, scales, srcs_pd); + + // create mkldnn memory for dst + memory dst_memory = memory(sum_pd.dst_primitive_desc(), z_data); + + std::vector inputs; + inputs.push_back(srcs[0]); + inputs.push_back(srcs[1]); + + // create sum primitive + auto sum_prim = sum(sum_pd, inputs, dst_memory); + + std::vector pipeline; + pipeline.push_back(sum_prim); + stream(stream::kind::eager).submit(pipeline).wait(); + + z->set_layout(DataLayout::kMKLDNN); + z->set_format( + (memory::format)dst_memory.get_primitive_desc().desc().data.format); + } + } +}; + +template +class EltwiseAddMKLDNNGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using Tensor = framework::Tensor; + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* out = ctx.Input("Out"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + int axis = ctx.Attr("axis"); + + auto set_mkldnn_format = [](Tensor* in, const Tensor* out) { + in->set_layout(DataLayout::kMKLDNN); + in->set_format(out->format()); + }; + + if (x->dims() == y->dims()) { + auto blas = math::GetBlas(ctx); + if (dx) { + blas.VCOPY(dout->numel(), dout->data(), + dx->mutable_data(ctx.GetPlace())); + set_mkldnn_format(dx, dout); + } + + if (dy) { + blas.VCOPY(dout->numel(), dout->data(), + dy->mutable_data(ctx.GetPlace())); + set_mkldnn_format(dy, dout); + } + } else { + // Execute default kernel when broadcast is needed + ElemwiseGradCompute, IdentityGrad>( + ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad(), + IdentityGrad()); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_KERNEL(elementwise_add, MKLDNN, ::paddle::platform::CPUPlace, + ops::EltwiseAddMKLDNNKernel) + +REGISTER_OP_KERNEL(elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace, + ops::EltwiseAddMKLDNNGradKernel) diff --git a/paddle/fluid/operators/elementwise_op.h b/paddle/fluid/operators/elementwise_op.h index 12364fff96c03c5f9dff23c7c00ceedd043803a6..bb88970e42c194d9437609b62435f1a89e2b446b 100644 --- a/paddle/fluid/operators/elementwise_op.h +++ b/paddle/fluid/operators/elementwise_op.h @@ -14,8 +14,12 @@ limitations under the License. */ #pragma once #include +#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif namespace paddle { namespace operators { @@ -40,6 +44,21 @@ class ElementwiseOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", x_dim); ctx->ShareLoD("X", /*->*/ "Out"); } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("X")->type()); + +#ifdef PADDLE_WITH_MKLDNN + if (platform::CanMKLDNNBeUsed(ctx)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class ElementwiseOpInferVarType : public framework::VarTypeInference { @@ -65,6 +84,8 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { "for broadcasting Y onto X.") .SetDefault(-1) .EqualGreaterThan(-1); + AddAttr("use_mkldnn", "(bool, default false). Used by MKLDNN.") + .SetDefault(false); AddComment(string::Sprintf(R"DOC( Limited Elementwise %s Operator @@ -138,6 +159,21 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { ctx->SetOutputDim(y_grad_name, y_dims); } } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("X")->type()); + +#ifdef PADDLE_WITH_MKLDNN + if (platform::CanMKLDNNBeUsed(ctx)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..bcdbfc8e527d0dc9a95eddaf040f8035207b6c20 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_mkldnn_op.py @@ -0,0 +1,130 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest +from test_elementwise_add_op import * +''' +Some tests differ from the tests defined in test_elementwise_add_op.py +because MKLDNN does not support tensors of number of dimensions 3. +Such dimensions cause exceptions in MKLDNN reorder primitive. +''' + + +class TestMKLDNNElementwiseAddOp(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype(self.dtype) + self.out = np.add(self.x, self.y) + + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNElementwiseAddOp_scalar(TestElementwiseAddOp_scalar): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + self.out = self.x + self.y + + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNElementwiseAddOp_scalar2(TestElementwiseAddOp_scalar2): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(1, 1).astype(self.dtype) + self.out = self.x + self.y + + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNElementwiseAddOp_Vector(TestElementwiseAddOp_Vector): + def init_kernel_type(self): + self.use_mkldnn = True + + +class TesMKLDNNtElementwiseAddOp_broadcast_0(TestElementwiseAddOp_broadcast_0): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(2).astype(self.dtype) + self.out = self.x + self.y.reshape(2, 1, 1, 1) + + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNElementwiseAddOp_broadcast_1(TestElementwiseAddOp_broadcast_1): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(3).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 3, 1, 1) + + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNElementwiseAddOp_broadcast_2(TestElementwiseAddOp_broadcast_2): + def init_input_output(self): + self.x = np.random.rand(2, 2, 3, 4).astype(self.dtype) + self.y = np.random.rand(4).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 1, 1, 4) + + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNElementwiseAddOp_broadcast_3(TestElementwiseAddOp_broadcast_3): + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNElementwiseAddOp_broadcast_4(TestElementwiseAddOp_broadcast_4): + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNElementwiseAddOp_rowwise_add_0( + TestElementwiseAddOp_rowwise_add_0): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(3, 4).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 3, 4, 1) + + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNElementwiseAddOp_rowwise_add_1( + TestElementwiseAddOp_rowwise_add_1): + def init_kernel_type(self): + self.use_mkldnn = True + + +class TestMKLDNNElementwiseAddOp_channelwise_add( + TestElementwiseAddOp_channelwise_add): + def init_input_output(self): + self.x = np.random.rand(3, 5, 20, 20).astype(self.dtype) + self.y = np.random.rand(3, 1, 1, 1).astype(self.dtype) + self.out = self.x + self.y + + def init_kernel_type(self): + self.use_mkldnn = True + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py index 96d47906a0606bba4b1d2207f7da85b058e42a2b..fb9a496126f0b6efcad73590c78efe5a47b88cd6 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_add_op.py @@ -18,19 +18,23 @@ from op_test import OpTest class TestElementwiseAddOp(OpTest): + def init_kernel_type(self): + self.use_mkldnn = False + def setUp(self): self.op_type = "elementwise_add" self.dtype = np.float32 self.axis = -1 self.init_dtype() self.init_input_output() + self.init_kernel_type() self.init_axis() self.inputs = { 'X': OpTest.np_dtype_to_fluid_dtype(self.x), 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) } - self.attrs = {'axis': self.axis} + self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn} self.outputs = {'Out': self.out} def test_check_output(self):