diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc
index 69609fa5bcdeb24e73e0ea6608a16445388477e4..75bbe2ba8c3ec88e9034a1130e01638a23a44dee 100644
--- a/paddle/fluid/operators/matmul_op.cc
+++ b/paddle/fluid/operators/matmul_op.cc
@@ -17,6 +17,9 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/blas.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -409,6 +412,21 @@ class MatMulOp : public framework::OperatorWithKernel {
     context->SetOutputDim("Out", framework::make_ddim(dim_out));
     context->ShareLoD("X", /*->*/ "Out");
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+#ifdef PADDLE_WITH_MKLDNN
+    using mkldnn::memory;
+    if (platform::CanMKLDNNBeUsed(ctx)) {
+      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -426,6 +444,30 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
         )DOC")
         .SetDefault(false);
     AddAttr<float>("alpha", "The scale of Out").SetDefault(1.0f);
+    AddAttr<bool>(
+        "use_mkldnn",
+        "(bool, default false) Indicates if MKL-DNN kernel will be used")
+        .SetDefault(false);
+    /* int8 parameters */
+    AddAttr<bool>("use_quantizer",
+                  "(bool, default false) "
+                  "Set to true for operators that should be quantized and use "
+                  "int8 kernel. "
+                  "Only used on CPU.")
+        .SetDefault(false);
+    AddAttr<float>("Scale_x",
+                   "(float, default 1.0f), The quantize scale of X tensor")
+        .SetDefault(1.0f);
+    AddAttr<float>("Scale_y",
+                   "(float, default 1.0f), The quantize scale of Y tensor")
+        .SetDefault(1.0f);
+    AddAttr<float>("Scale_out",
+                   "(float, default 1.0f), The quantize scale of output data")
+        .SetDefault(1.0f);
+    AddAttr<bool>("force_fp32_output",
+                  "(bool, default false) Force INT8 kernel output FP32, only "
+                  "used in MKL-DNN INT8")
+        .SetDefault(false);
 #if defined(PADDLE_WITH_MKLML) && !defined(PADDLE_WITH_CUDA)
     AddAttr<int>("head_number", "The number of heads of the matrix")
         .SetDefault(1);
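Editor's note on the `Scale_x`/`Scale_y`/`Scale_out` attributes added above: they carry symmetric int8 quantization factors of the form `127 / max|tensor|`, mirroring the `quantize()` helper in the unit test added at the bottom of this diff. A minimal numpy sketch of that contract (illustrative only, not part of the patch and not Paddle API):

```python
import numpy as np

def quantize(tensor):
    # Symmetric int8: map max|tensor| onto 127. Scale_x / Scale_y / Scale_out
    # carry exactly this factor into the operator's attributes.
    scale = 127. / np.abs(np.amax(tensor))
    quantized = np.round(scale * tensor).astype("int8")
    return scale, quantized

x = np.random.random((4, 3)).astype("float32")
scale_x, x_int8 = quantize(x)
# Dequantizing recovers x up to half a step of the int8 grid.
assert np.allclose(x_int8 / scale_x, x, atol=0.5 / scale_x)
```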
diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..338a5206356e336a7fda002d33077400e8e306de
--- /dev/null
+++ b/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc
@@ -0,0 +1,258 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "mkldnn.hpp"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
+
+namespace paddle {
+namespace operators {
+
+using dnnl::memory;
+using dnnl::primitive;
+using platform::to_void_cast;
+using framework::DataLayout;
+using platform::GetMKLDNNFormat;
+using platform::MKLDNNGetDataType;
+using platform::MKLDNNDeviceContext;
+using framework::ExecutionContext;
+using Tensor = framework::Tensor;
+
+// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
+// original x_dim is returned.
+static framework::DDim RowMatrixDimsFromVector(const framework::DDim& x_dim) {
+  return x_dim.size() > 1 ? x_dim : framework::make_ddim({1, x_dim[0]});
+}
+
+// Get column matrix shape from a vector shape. If the rank of y_dim > 1, the
+// original y_dim is returned.
+static framework::DDim ColumnMatrixDimsFromVector(
+    const framework::DDim& y_dim) {
+  return y_dim.size() > 1 ? y_dim : framework::make_ddim({y_dim[0], 1});
+}
+
+template <typename XT, typename YT, typename OT>
+class MatMulFactory {
+ public:
+  void CreateAndExecute(const ExecutionContext& ctx) {
+    SetDNNLEngine(ctx);
+    if (IsInitialized()) {
+      UpdateDataPointers(ctx);
+      Execute();
+      SetOutputFormat(ctx);
+      return;
+    }
+    CreateMemories(ctx);
+    CreatePrimitive(ctx);
+    Execute();
+    SetOutputFormat(ctx);
+    SetInitialized();
+  }
+
+ private:
+  struct MatMulDims {
+    const memory::dim BS, M, N, K;
+  };
+
+  void SetDNNLEngine(const ExecutionContext& ctx) {
+    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    engine_ = dev_ctx.GetEngine();
+  }
+
+  template <typename T>
+  dnnl::memory CreateMemory(const memory::dims& dims,
+                            const memory::dims& strides, const T* data) {
+    auto md = memory::desc(dims, MKLDNNGetDataType<T>(), strides);
+    return dnnl::memory(md, engine_, to_void_cast(data));
+  }
+
+  MatMulDims GetMatmulDims(const ExecutionContext& ctx) {
+    auto mat_dim_x = math::CreateMatrixDescriptor(
+        RowMatrixDimsFromVector(ctx.Input<Tensor>("X")->dims()), 0,
+        ctx.Attr<bool>("transpose_X"));
+    auto mat_dim_y = math::CreateMatrixDescriptor(
+        ColumnMatrixDimsFromVector(ctx.Input<Tensor>("Y")->dims()), 0,
+        ctx.Attr<bool>("transpose_Y"));
+
+    const auto x_bs = mat_dim_x.batch_size_;
+    const auto y_bs = mat_dim_y.batch_size_;
+    PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs, false,
+                      platform::errors::InvalidArgument(
+                          "If batch sizes of X and Y are positive, "
+                          "they have to be equal."));
+
+    // Store 1 if both batches are zero, otherwise save the nonzero batch
+    const memory::dim BS = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
+    const memory::dim M = mat_dim_x.height_;
+    const memory::dim N = mat_dim_y.width_;
+    const memory::dim K = mat_dim_x.width_;
+    return {BS, M, N, K};
+  }
+
+  void CreateMemories(const ExecutionContext& ctx) {
+    auto matmul_dims = GetMatmulDims(ctx);
+    auto BS = matmul_dims.BS;
+    auto M = matmul_dims.M;
+    auto N = matmul_dims.N;
+    auto K = matmul_dims.K;
+    bool x_trans = ctx.Attr<bool>("transpose_X");
+    bool y_trans = ctx.Attr<bool>("transpose_Y");
+
+    typedef memory::dims dims;
+    dims x_dims = {BS, M, K};
+    dims y_dims = {BS, K, N};
+    dims out_dims = {BS, M, N};
+
+    // Translate transA and transB
+    dims x_strides = !x_trans ? dims{M * K, K, 1} : dims{M * K, 1, M};
+    dims y_strides = !y_trans ? dims{N * K, N, 1} : dims{N * K, 1, K};
+    dims out_strides = {M * N, N, 1};
+
+    x_mem_ = CreateMemory<XT>(x_dims, x_strides,
+                              ctx.Input<Tensor>("X")->data<XT>());
+    y_mem_ = CreateMemory<YT>(y_dims, y_strides,
+                              ctx.Input<Tensor>("Y")->data<YT>());
+    out_mem_ = CreateMemory<OT>(
+        out_dims, out_strides,
+        ctx.Output<Tensor>("Out")->mutable_data<OT>(ctx.GetPlace()));
+  }
+
+  float ComputeOutputScale(const ExecutionContext& ctx) {
+    float scale_x = ctx.Attr<float>("Scale_x");
+    float scale_y = ctx.Attr<float>("Scale_y");
+    bool force_fp32_out = ctx.Attr<bool>("force_fp32_output");
+    float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
+    float alpha = ctx.Attr<float>("alpha");
+    return alpha * scale_out / (scale_x * scale_y);
+  }
+
+  void CreatePrimitive(const ExecutionContext& ctx) {
+    dnnl::primitive_attr attr;
+    float scale_out = ComputeOutputScale(ctx);
+    if (scale_out != 1.0f) {
+      constexpr unsigned tensor_wide_scale = 0;
+      attr.set_output_scales(tensor_wide_scale, {scale_out});
+    }
+
+    auto matmul_d = dnnl::matmul::desc(x_mem_.get_desc(), y_mem_.get_desc(),
+                                       out_mem_.get_desc());
+    auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine_);
+    matmul_prim_ = dnnl::matmul(matmul_pd);
+  }
+
+  void Execute() {
+    dnnl::stream stream(engine_);
+    matmul_prim_.execute(stream, {
+                                     {MKLDNN_ARG_SRC, x_mem_},
+                                     {MKLDNN_ARG_WEIGHTS, y_mem_},
+                                     {MKLDNN_ARG_DST, out_mem_},
+                                 });
+    stream.wait();
+  }
+
+  void SetOutputFormat(const ExecutionContext& ctx) {
+    using platform::MKLDNNFormatForSize;
+    auto* out = ctx.Output<Tensor>("Out");
+    auto format =
+        MKLDNNFormatForSize(out->dims().size(), MKLDNNMemoryFormat::nchw);
+    out->set_format(format);
+    out->set_layout(DataLayout::kMKLDNN);
+  }
+
+  void UpdateDataPointers(const ExecutionContext& ctx) {
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* out = ctx.Output<Tensor>("Out");
+    x_mem_.set_data_handle(to_void_cast(x->data<XT>()));
+    y_mem_.set_data_handle(to_void_cast(y->data<YT>()));
+    out_mem_.set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
+  }
+
+  // If initialized, x memory should've been already initialized
+  bool IsInitialized() { return initialized_; }
+
+  void SetInitialized() { initialized_ = true; }
+
+ private:
+  dnnl::engine engine_;
+  dnnl::memory x_mem_;
+  dnnl::memory y_mem_;
+  dnnl::memory out_mem_;
+  dnnl::matmul matmul_prim_;
+  bool initialized_ = false;
+};
+
+template <typename XT, typename YT, typename OT>
+static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
+    const ExecutionContext& ctx) {
+  const auto x_dims = framework::vectorize(ctx.Input<Tensor>("X")->dims());
+  const auto y_dims = framework::vectorize(ctx.Input<Tensor>("Y")->dims());
+  const auto& out_name = ctx.OutputName("Out");
+  const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+
+  const std::string key =
+      platform::CreateKey(platform::ThreadIDasStr(), x_dims, y_dims, out_name);
+
+  auto factory = std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(
+      dev_ctx.GetBlob(key));
+  if (factory == nullptr) {
+    factory = std::make_shared<MatMulFactory<XT, YT, OT>>();
+    dev_ctx.SetBlob(key, factory);
+  }
+
+  return factory;
+}
+
+template <typename T>
+constexpr bool IsInt8() {
+  return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
+}
+
+// Choose appropriate primitive factory implementation based on inferred
+// output type (uint8, int8 or float).
+template <typename XT, typename YT>
+static void ExecuteMatMul(const ExecutionContext& ctx) {
+  constexpr bool is_int8 = IsInt8<XT>();
+  const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
+  constexpr bool fuse_relu = false;  // TODO(intel): Enable eltwise fuses
+  if (!is_int8 || force_fp32_output) {
+    GetPrimitiveFactory<XT, YT, float>(ctx)->CreateAndExecute(ctx);
+  } else if (fuse_relu) {
+    GetPrimitiveFactory<XT, YT, uint8_t>(ctx)->CreateAndExecute(ctx);
+  } else {
+    GetPrimitiveFactory<XT, YT, int8_t>(ctx)->CreateAndExecute(ctx);
+  }
+}
+
+template <typename T>
+class DNNLMatMulKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    if (ctx.HasAttr("head_number")) {
+      PADDLE_ENFORCE_EQ(ctx.Attr<int>("head_number"), 1,
+                        platform::errors::Unimplemented(
+                            "DNNL matmul doesn't support multiple heads."));
+    }
+    ExecuteMatMul<T, T>(ctx);
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+namespace ops = paddle::operators;
+
+REGISTER_OP_KERNEL(matmul, MKLDNN, ::paddle::platform::CPUPlace,
+                   ops::DNNLMatMulKernel<float>,
+                   ops::DNNLMatMulKernel<int8_t>,
+                   ops::DNNLMatMulKernel<uint8_t>);
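Editor's note on `CreateMemories()` above: `transpose_X`/`transpose_Y` never copy data; they only swap the last two strides of the DNNL memory descriptor (`{M * K, 1, M}` instead of `{M * K, K, 1}`). A minimal numpy sketch of the same trick (illustrative only, not part of the patch; numpy strides are in bytes, hence the itemsize factor):

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

K, M = 3, 4
# Physical buffer holds a row-major (K, M) matrix, element (k, m) at k * M + m.
x_buffer = np.arange(K * M, dtype=np.float32).reshape(K, M)
# Describe a logical (M, K) matrix over the same bytes with strides (1, M):
# element (m, k) resolves to offset m + k * M, i.e. buffer element (k, m).
item = x_buffer.itemsize
logical_x = as_strided(x_buffer, shape=(M, K), strides=(1 * item, M * item))
assert np.array_equal(logical_x, x_buffer.T)  # transposed view, zero copies
```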
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index 13230cab0d17237a353bb24da054d48fde920823..cc5e6b0da14a0c3f11711a348a937a249b43455d 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-#include <mkldnn.h>
 #include <algorithm>
 #include <iostream>
 #include <memory>
 #include <string>
 #include <vector>
+#include "mkldnn.hpp"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/place.h"
 namespace paddle {
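Editor's note before the tests: `ComputeOutputScale()` above folds everything into the single factor `alpha * Scale_out / (Scale_x * Scale_y)` because the int32 accumulator of an int8 matmul already carries `Scale_x * Scale_y`; one multiplier removes the input scales and applies `alpha` and the output scale in the same step. A minimal numpy sketch (illustrative only, not part of the patch; the scale values are made up):

```python
import numpy as np

np.random.seed(0)
x = np.random.random((2, 3)).astype(np.float32)
y = np.random.random((3, 2)).astype(np.float32)
scale_x, scale_y, scale_out, alpha = 127.0, 127.0, 63.5, 0.5

x_q = np.round(scale_x * x).astype(np.int32)  # ~ scale_x * x
y_q = np.round(scale_y * y).astype(np.int32)  # ~ scale_y * y
acc = x_q @ y_q  # int32 accumulator ~ scale_x * scale_y * (x @ y)
out_q = np.round(alpha * scale_out / (scale_x * scale_y) * acc)

# On the output grid, out_q / scale_out approximates alpha * (x @ y),
# up to accumulated rounding error of the two int8 grids.
assert np.allclose(out_q / scale_out, alpha * (x @ y), atol=3.0 / scale_out)
```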
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..de547a6a19052ff9c0d50e7115acebc6833c2ca5
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_matmul_mkldnn_op.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest, os
+import numpy as np
+from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
+
+
+@skip_check_grad_ci(reason="DNNL's MatMul doesn't implement grad kernel.")
+class TestDnnlMatMulOp(OpTest):
+    def generate_data(self):
+        self.x = np.random.random((25, 2, 2)).astype("float32")
+        self.y = np.random.random((25, 2, 2)).astype("float32")
+        self.alpha = 1.0
+        self.out = self.alpha * np.matmul(self.x, self.y)
+
+    def set_attributes(self):
+        self.alpha = self.alpha if hasattr(self, 'alpha') else 1.0
+        self.attrs = {'alpha': self.alpha}
+
+    def setUp(self):
+        # Set max isa, otherwise fails on SKX and earlier
+        os.environ["DNNL_MAX_CPU_ISA"] = "AVX"
+        self.op_type = "matmul"
+        self._cpu_only = True
+        self.use_mkldnn = True
+        self.generate_data()
+        self.set_attributes()
+        self.attrs['use_mkldnn'] = True
+
+        self.inputs = {'X': self.x, 'Y': self.y}
+        self.outputs = {'Out': self.out}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestDnnlMatMulOpAlpha(TestDnnlMatMulOp):
+    def generate_data(self):
+        self.x = np.random.random((17, 2, 3)).astype("float32")
+        self.y = np.random.random((17, 3, 2)).astype("float32")
+        self.alpha = 2.0
+        self.out = self.alpha * np.matmul(self.x, self.y)
+
+
+class TestDnnlMatMulOp2D(TestDnnlMatMulOp):
+    def print_tensor(self, name, tensor):
+        print(name)
+        print(tensor)
+
+    def generate_data(self):
+        self.x = np.random.random((12, 9)).astype("float32")
+        self.y = np.random.random((9, 12)).astype("float32")
+        self.out = np.matmul(self.x, self.y)
+
+
+class TestDnnlMatMulOpTransposeX(TestDnnlMatMulOp):
+    def generate_data(self):
+        self.x = np.random.random((12, 9)).astype("float32")
+        self.y = np.random.random((12, 9)).astype("float32")
+        self.out = np.matmul(np.transpose(self.x), self.y)
+
+    def set_attributes(self):
+        self.attrs = {'transpose_X': True}
+
+
+class TestDnnlMatMulOpTransposeY(TestDnnlMatMulOp):
+    def generate_data(self):
+        self.x = np.random.random((12, 9)).astype("float32")
+        self.y = np.random.random((12, 9)).astype("float32")
+        self.out = np.matmul(self.x, np.transpose(self.y))
+
+    def set_attributes(self):
+        self.attrs = {'transpose_Y': True}
+
+
+class TestDnnlMatMulOpTransposeY3D(TestDnnlMatMulOp):
+    def generate_data(self):
+        self.x = np.random.random((17, 3, 2)).astype("float32")
+        self.y = np.random.random((17, 3, 2)).astype("float32")
+        self.out = np.matmul(self.x, np.transpose(self.y, (0, 2, 1)))
+
+    def set_attributes(self):
+        self.attrs = {'transpose_Y': True}
+
+
+class TestDnnlMatMulOpInt8NoScales(TestDnnlMatMulOp):
+    def generate_data(self):
+        self.x = np.random.random((12, 9)).astype("int8")
+        self.y = np.random.random((9, 12)).astype("int8")
+        self.out = np.matmul(self.x, self.y)
+
+
+class TestDnnlMatMulOpInt8(TestDnnlMatMulOp):
+    def quantize(self, tensor):
+        scale = 127. / np.abs(np.amax(tensor))
+        quantized = np.round(scale * tensor).astype("int8")
+        return scale, quantized
+
+    def generate_data(self):
+        x_float = np.random.random((12, 9)).astype("float32")
+        self.x_scale, self.x = self.quantize(x_float)
+
+        y_float = np.random.random((9, 12)).astype("float32")
+        self.y_scale, self.y = self.quantize(y_float)
+
+        out_float = np.matmul(x_float, y_float)
+        self.out_scale, self.out = self.quantize(out_float)
+
+    def set_attributes(self):
+        self.attrs = {
+            'Scale_x': self.x_scale,
+            'Scale_y': self.y_scale,
+            'Scale_out': self.out_scale,
+        }
+
+    def test_check_output(self):
+        int_atol = 1
+        self.check_output(atol=int_atol)
+
+
+class TestDnnlMatMulOpInt8ForceFP32(TestDnnlMatMulOpInt8):
+    def generate_data(self):
+        x_float = np.random.random((12, 9)).astype("float32")
+        self.x_scale, self.x = self.quantize(x_float)
+
+        y_float = np.random.random((9, 12)).astype("float32")
+        self.y_scale, self.y = self.quantize(y_float)
+
+        out_float = np.matmul(x_float, y_float)
+        self.out = out_float
+
+    def set_attributes(self):
+        self.attrs = {
+            'Scale_x': self.x_scale,
+            'Scale_y': self.y_scale,
+            'force_fp32_output': True
+        }
+
+
+class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp):
+    def generate_data(self):
+        self.x = np.random.randint(0, 3, (12, 9)).astype("int8")
+        self.y = np.random.randint(0, 3, (9, 12)).astype("int8")
+        self.out = np.matmul(self.x, self.y).astype("float32")
+
+    def set_attributes(self):
+        self.attrs = {'force_fp32_output': True}
+
+
+if __name__ == "__main__":
+    unittest.main()
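Editor's closing note on `TestDnnlMatMulOpInt8ForceFP32BasicScales`: with every scale left at its `1.0f` default, the kernel's output multiplier is 1, so `check_output()` can compare exactly rather than with the `atol=1` slack the requantizing test needs. A numpy sketch of why (illustrative only, not part of the patch):

```python
import numpy as np

x = np.random.randint(0, 3, (12, 9)).astype(np.int8)
y = np.random.randint(0, 3, (9, 12)).astype(np.int8)

# Entries are in {0, 1, 2} and K = 9, so every dot product is at most 36:
# the int32 accumulator is exact and the cast to float32 is lossless.
acc = x.astype(np.int32) @ y.astype(np.int32)
out = acc.astype(np.float32)
assert np.array_equal(out, np.matmul(x, y).astype(np.float32))
```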