From e235882c18cf4975efd4a232d8da140e412df154 Mon Sep 17 00:00:00 2001 From: xiaolil1 <39753926+xiaolil1@users.noreply.github.com> Date: Fri, 22 Mar 2019 17:02:54 +0800 Subject: [PATCH] Enable MKL-DNN INT8 Concat Kernel. (#16156) * Enable INT8 Concat Kernel to improve the performance of MobileNet-SSD. test=develop * Optimize UT format. test=develop * Fix UT file address issue. test=develop * Refine the license year. test=develop * Optimize code for new API. test=develop * Restructure INT8 Concat kernel. test=develop --- .../operators/mkldnn/concat_mkldnn_op.cc | 120 ++++++++++++++--- .../mkldnn/test_concat_int8_mkldnn_op.py | 124 ++++++++++++++++++ 2 files changed, 223 insertions(+), 21 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_concat_int8_mkldnn_op.py diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc index 54c6a71111a..97387af92ff 100644 --- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/concat_op.h" #include "paddle/fluid/platform/mkldnn_helper.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" namespace paddle { namespace operators { @@ -38,15 +39,20 @@ static void EnforceLayouts(const std::vector inputs) { } static memory::primitive_desc CreateMemPrimDesc(const Tensor& input, - const mkldnn::engine& engine) { - constexpr auto data_type = mkldnn::memory::f32; + const mkldnn::engine& engine, + const memory::data_type& dt) { const auto dims = paddle::framework::vectorize2int(input.dims()); const auto format = input.format(); - auto description = memory::desc(dims, data_type, format); + auto description = memory::desc(dims, dt, format); auto mem_prim_desc = memory::primitive_desc(description, engine); return mem_prim_desc; } +static mkldnn::memory::format GetDstMemFormat( + const concat::primitive_desc& concat_pd) { + return (memory::format)concat_pd.dst_primitive_desc().desc().data.format; +} + static platform::CPUPlace GetCpuPlace( const paddle::framework::ExecutionContext& ctx) { auto place = ctx.GetPlace(); @@ -61,14 +67,30 @@ static const mkldnn::engine& GetMKLDNNEngine( return dev_ctx.GetEngine(); } +std::string CreateKey(const paddle::framework::ExecutionContext& ctx, + const std::vector multi_input, + const int64_t& concat_axis, const memory::data_type& dt) { + std::string key; + key.reserve(platform::MKLDNNHandler::MaxKeyLength); + for (size_t i = 0; i < multi_input.size(); i++) { + platform::MKLDNNHandler::AppendKeyDims( + &key, paddle::framework::vectorize2int(multi_input[i]->dims())); + } + platform::MKLDNNHandler::AppendKey(&key, std::to_string(concat_axis)); + platform::MKLDNNHandler::AppendKey(&key, ctx.op().Output("Out")); + platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt)); + return key; +} + template class ConcatPrimitiveFactory { public: concat::primitive_desc CreateConcatPrimDescriptor( const std::vector multi_input, Tensor* output, - int concat_axis, const mkldnn::engine& mkldnn_engine) { - CreateSourcesDescriptors(multi_input, mkldnn_engine); - auto dst_desc = CreateDstMemDescriptor(output); + int concat_axis, const mkldnn::engine& mkldnn_engine, + const memory::data_type& dt = memory::data_type::f32) { + CreateSourcesDescriptors(multi_input, mkldnn_engine, dt); + auto dst_desc = CreateDstMemDescriptor(output, dt); return concat::primitive_desc(dst_desc, concat_axis, srcs_pd); } @@ -79,23 +101,39 @@ class ConcatPrimitiveFactory { return concat(concat_pd, inputs, dst_mem.get()); } + void SetSrcDataHandleByIndex(const std::vector& srcs, const size_t& i, + void* handler) { + srcs[i].set_data_handle(handler); + } + + void SetDstDataHandle(const memory& dst_mem, void* handler) { + dst_mem.set_data_handle(handler); + } + + std::vector GetSrcs() { return srcs; } + + memory GetDst() { return dst_mem.get(); } + private: - memory::desc CreateDstMemDescriptor(Tensor* output) { + memory::desc CreateDstMemDescriptor(Tensor* output, + const memory::data_type& dt) { auto dst_dims = paddle::framework::vectorize2int(output->dims()); - return memory::desc(dst_dims, platform::MKLDNNGetDataType(), - memory::format::any); + return memory::desc(dst_dims, dt, memory::format::any); } mkldnn::memory CreateDstMemory(const concat::primitive_desc& concat_pd, - Tensor* output, platform::CPUPlace place) { + Tensor* output, + const platform::CPUPlace& place) { return memory(concat_pd.dst_primitive_desc(), output->mutable_data(place)); } void CreateSourcesDescriptors(const std::vector multi_input, - const mkldnn::engine& mkldnn_engine) { + const mkldnn::engine& mkldnn_engine, + const memory::data_type& dt) { for (size_t i = 0; i < multi_input.size(); i++) { - auto mem_prim_desc = CreateMemPrimDesc(*multi_input[i], mkldnn_engine); + auto mem_prim_desc = + CreateMemPrimDesc(*multi_input[i], mkldnn_engine, dt); srcs_pd.push_back(mem_prim_desc); srcs.push_back( memory(mem_prim_desc, to_void_cast(multi_input[i]->data()))); @@ -120,21 +158,59 @@ template class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { public: void Compute(const paddle::framework::ExecutionContext& ctx) const override { - auto place = GetCpuPlace(ctx); - const auto& mkldnn_engine = GetMKLDNNEngine(ctx); - auto multi_input = ctx.MultiInput("X"); EnforceLayouts(multi_input); Tensor* output = ctx.Output("Out"); int64_t concat_axis = static_cast(ctx.Attr("axis")); + auto& dev_ctx = + ctx.template device_context(); + auto place = GetCpuPlace(ctx); + + memory::data_type dt = + paddle::framework::ToMKLDNNDataType(multi_input[0]->type()); ConcatPrimitiveFactory prim_creator; - auto concat_pd = prim_creator.CreateConcatPrimDescriptor( - multi_input, output, static_cast(concat_axis), mkldnn_engine); - auto concat = prim_creator.CreateConcatPrimitive(concat_pd, output, place); - stream(stream::kind::eager).submit({concat}).wait(); + std::string key = CreateKey(ctx, multi_input, concat_axis, dt); + const std::string key_prim = key + "@concat_p"; + const std::string key_concat_pd = key + "@concat_pd"; + const std::string key_srcs = key + "@concat_srcs"; + const std::string key_dst = key + "@concat_dst"; + + std::shared_ptr concat_pd; + std::shared_ptr> srcs; + std::shared_ptr dst_mem; + auto concat_p = std::static_pointer_cast(dev_ctx.GetBlob(key_prim)); + + if (concat_p == nullptr) { + const auto& mkldnn_engine = dev_ctx.GetEngine(); + concat_pd = std::make_shared( + prim_creator.CreateConcatPrimDescriptor(multi_input, output, + static_cast(concat_axis), + mkldnn_engine, dt)); + concat_p = std::make_shared( + prim_creator.CreateConcatPrimitive(*concat_pd, output, place)); + srcs = std::make_shared>(prim_creator.GetSrcs()); + dst_mem = std::make_shared(prim_creator.GetDst()); + dev_ctx.SetBlob(key_prim, concat_p); + dev_ctx.SetBlob(key_concat_pd, concat_pd); + dev_ctx.SetBlob(key_srcs, srcs); + dev_ctx.SetBlob(key_dst, dst_mem); + } else { + srcs = std::static_pointer_cast>( + dev_ctx.GetBlob(key_srcs)); + dst_mem = std::static_pointer_cast(dev_ctx.GetBlob(key_dst)); + concat_pd = std::static_pointer_cast( + dev_ctx.GetBlob(key_concat_pd)); + for (size_t i = 0; i < multi_input.size(); i++) { + prim_creator.SetSrcDataHandleByIndex( + *srcs, i, to_void_cast(multi_input[i]->data())); + } + prim_creator.SetDstDataHandle(*dst_mem, output->mutable_data(place)); + } + + stream(stream::kind::eager).submit({*concat_p}).wait(); - output->set_mkldnn_prim_desc(concat_pd.dst_primitive_desc()); + output->set_mkldnn_prim_desc(concat_pd->dst_primitive_desc()); } }; } // namespace operators @@ -143,4 +219,6 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_KERNEL(concat, MKLDNN, ::paddle::platform::CPUPlace, - ops::ConcatMKLDNNOpKernel) + ops::ConcatMKLDNNOpKernel, + ops::ConcatMKLDNNOpKernel, + ops::ConcatMKLDNNOpKernel); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_concat_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_concat_int8_mkldnn_op.py new file mode 100644 index 00000000000..0b6556746cd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_concat_int8_mkldnn_op.py @@ -0,0 +1,124 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.op_test import OpTest + + +class TestConcatOp(OpTest): + def setUp(self): + self.op_type = "concat" + self.use_mkldnn = True + self._cpu_only = True + self.init_axis() + self.init_shape() + self.init_test_data() + self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]} + self.attrs = {'axis': self.axis, 'use_mkldnn': True} + + self.output = np.concatenate( + (self.x0, self.x1, self.x2), axis=self.axis).astype('int') + + self.outputs = {'Out': self.output} + + def test_check_output(self): + self.check_output() + +#--------------------test concat s8 in with axis 0-------------------- + + def init_test_data(self): + self.x0 = (np.random.randint(0, 100, self.x0_shape) - 50).astype('int8') + self.x1 = (np.random.randint(0, 80, self.x1_shape) - 30).astype('int8') + self.x2 = (np.random.randint(0, 110, self.x2_shape) - 80).astype('int8') + + def init_axis(self): + self.axis = 0 + + def init_shape(self): + self.x0_shape = [2, 2, 1, 2] + self.x1_shape = [1, 2, 1, 2] + self.x2_shape = [3, 2, 1, 2] + + +#--------------------test concat u8 in with axis 0-------------------- + + +class TestConcatOp2(TestConcatOp): + def init_test_data(self): + self.x0 = (np.random.randint(0, 100, self.x0_shape)).astype('uint8') + self.x1 = (np.random.randint(0, 50, self.x1_shape)).astype('uint8') + self.x2 = (np.random.randint(0, 80, self.x2_shape)).astype('uint8') + + def init_axis(self): + self.axis = 0 + + def init_shape(self): + self.x0_shape = [2, 1, 5, 5] + self.x1_shape = [1, 1, 5, 5] + self.x2_shape = [3, 1, 5, 5] + + +def create_test_int8_class(parent): + + #--------------------test concat s8/u8 in with axis 1-------------------- + + class TestAxis1Case(parent): + def init_axis(self): + self.axis = 1 + + def init_shape(self): + self.x0_shape = [1, 1, 5, 5] + self.x1_shape = [1, 2, 5, 5] + self.x2_shape = [1, 3, 5, 5] + +#--------------------test concat s8/u8 in with axis 2-------------------- + + class TestAxis2Case(parent): + def init_axis(self): + self.axis = 2 + + def init_shape(self): + self.x0_shape = [2, 3, 4, 5] + self.x1_shape = [2, 3, 5, 5] + self.x2_shape = [2, 3, 6, 5] + +#--------------------test concat s8/u8 in with axis 3-------------------- + + class TestAxis3Case(parent): + def init_axis(self): + self.axis = 3 + + def init_shape(self): + self.x0_shape = [2, 3, 5, 5] + self.x1_shape = [2, 3, 5, 6] + self.x2_shape = [2, 3, 5, 7] + + cls_name_1 = "{0}_axis_{1}".format(parent.__name__, "1") + cls_name_2 = "{0}_axis_{1}".format(parent.__name__, "2") + cls_name_3 = "{0}_axis_{1}".format(parent.__name__, "3") + TestAxis1Case.__name__ = cls_name_1 + TestAxis2Case.__name__ = cls_name_2 + TestAxis3Case.__name__ = cls_name_3 + globals()[cls_name_1] = TestAxis1Case + globals()[cls_name_2] = TestAxis2Case + globals()[cls_name_3] = TestAxis3Case + +create_test_int8_class(TestConcatOp) +create_test_int8_class(TestConcatOp2) + +if __name__ == '__main__': + unittest.main() -- GitLab