diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake
index ad7b7c2c2ab1a756d359e6963867ca577275f56d..c0adda0da31ae1e7425ddfb352971444c09d5615 100644
--- a/cmake/external/mkldnn.cmake
+++ b/cmake/external/mkldnn.cmake
@@ -19,8 +19,8 @@ SET(MKLDNN_PREFIX_DIR ${THIRD_PARTY_PATH}/mkldnn)
 SET(MKLDNN_SOURCE_DIR ${THIRD_PARTY_PATH}/mkldnn/src/extern_mkldnn)
 SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
 SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
-SET(MKLDNN_REPOSITORY https://github.com/intel/mkl-dnn.git)
-SET(MKLDNN_TAG 4c05c181b40cf7132f8943411fb3fab1786df0f7)
+SET(MKLDNN_REPOSITORY https://github.com/oneapi-src/oneDNN.git)
+SET(MKLDNN_TAG 64a48f9565aa72f6359917b3406328075a409939)
 
 # Introduce variables:
 # * CMAKE_INSTALL_LIBDIR
diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc
index f731a78f77bfcdd4602fa212ca82b7208347ed60..4013906609603e31b798e333d55ecccba197506a 100644
--- a/paddle/fluid/operators/fused/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fused/fusion_gru_op.cc
@@ -206,6 +206,27 @@ void FusionGRUOpMaker::Make() {
   AddAttr<bool>("use_mkldnn",
                 "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
+  AddAttr<std::string>(
+      "mkldnn_data_type",
+      "(string, default \"float32\"). Data type of mkldnn kernel")
+      .SetDefault("float32")
+      .InEnum({"float32", "int8", "bfloat16"});
+  AddAttr<float>("Scale_data",
+                 "Scale to be used for int8 input/output data."
+                 "Only used with MKL-DNN INT8.")
+      .SetDefault(1.0f);
+  AddAttr<float>("Shift_data",
+                 "Shift to be used for int8 input/output data."
+                 "Only used with MKL-DNN INT8.")
+      .SetDefault(0.0f);
+  AddAttr<std::vector<float>>("Scale_weights",
+                              "Scale_weights to be used for int8 weights data."
+                              "Only used with MKL-DNN INT8.")
+      .SetDefault({1.0f});
+  AddAttr<bool>("force_fp32_output",
+                "(bool, default false) Force INT8 kernel output FP32, only "
+                "used in MKL-DNN INT8")
+      .SetDefault(false);
   AddComment(R"DOC(
 The Fusion complete GRU Operator.
 This operator fuse the fully-connected operator into GRU,
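
[Reviewer note, not part of the diff] The new `Scale_data`/`Shift_data` attributes describe an affine uint8 quantization of fp32 input/output data: `q = f * scale + shift`, cast to uint8. A minimal numpy sketch of that mapping and its inverse; the concrete values are only illustrative, chosen like the new unit test below (scale 63, shift 64) so that data in [-1, 1) lands in [1, 127):

```python
import numpy as np

# Illustrative values; the new unit test below picks scale=63, shift=64.
scale_data, shift_data = 63.0, 64.0

x_f32 = np.random.rand(6, 3).astype('float32') * 2 - 1      # fp32 in [-1, 1)
x_u8 = (x_f32 * scale_data + shift_data).astype(np.uint8)   # quantize
x_deq = (x_u8.astype('float32') - shift_data) / scale_data  # dequantize

# The cast truncates, so the round-trip error stays below one step (1/scale).
assert np.max(np.abs(x_deq - x_f32)) < 1.0 / scale_data
```
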
diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc
index a31fe1684396255b0c78e07fc3801ffb09f8d6b6..5fad1b116de6437e62e311318832ad77e24a40cc 100644
--- a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc
+++ b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc
@@ -21,11 +21,12 @@ namespace operators {
 using paddle::framework::LoDTensor;
 using paddle::framework::Tensor;
 using paddle::platform::CPUDeviceContext;
+using paddle::platform::CreateKey;
 using paddle::platform::MKLDNNGetDataType;
 using paddle::platform::MKLDNNMemDesc;
 using platform::to_void_cast;
 
-template <typename T>
+template <typename T, typename T_out = T>
 class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
  public:
   GRUMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
@@ -38,7 +39,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
                    const std::string& unique_name)
       : platform::MKLDNNHandlerT<T, dnnl::gru_forward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(unique_name, Ti)),
+            CreateKey(unique_name, MKLDNNGetDataType<T>(), Ti)),
         N(N),
         Ti(Ti),
         IC(IC),
@@ -47,9 +48,29 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
     // do not depend on Ti size but primitive and input/output memory do
     if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
         platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
-      memory_key_ = unique_name;
+      memory_key_ = CreateKey(unique_name, MKLDNNGetDataType<T>());
     } else {
-      memory_key_ = unique_name + "-t:" + platform::ThreadIDasStr();
+      memory_key_ = CreateKey(unique_name, MKLDNNGetDataType<T>(), "-t:",
+                              platform::ThreadIDasStr());
+    }
+
+    // Is it int8 kernel
+    const bool is_INT8 = std::is_same<T, uint8_t>::value;
+
+    if (is_INT8) {
+      // Int8 attributes
+      const float scale_data = ctx.Attr<float>("Scale_data");
+      const float shift_data = ctx.Attr<float>("Shift_data");
+      const auto scale_weights = ctx.Attr<std::vector<float>>("Scale_weights");
+
+      const int weights_scale_mask =
+          0 +
+          (1 << 3)  // bit, indicating the unique scales for `g` dim in `ldigo`
+          +
+          (1 << 4);  // bit, indicating the unique scales for `o` dim in `ldigo`
+
+      attr_.set_rnn_data_qparams(scale_data, shift_data);
+      attr_.set_rnn_weights_qparams(weights_scale_mask, scale_weights);
     }
 
     if (!this->isCached()) {
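
[Reviewer note, not part of the diff] `weights_scale_mask = (1 << 3) + (1 << 4)` marks the `g` (gate) and `o` (output-channel) dimensions of oneDNN's `ldigo` weights layout as having their own scales, so `Scale_weights` carries `G * OC = 3 * OC` values for a GRU. A simplified numpy sketch of deriving such per-(gate, channel) scales from fp32 weights; here WeightH is treated as a plain `[OC, 3 * OC]` array, whereas the test below additionally accounts for PaddlePaddle's split `[OC, 2 * OC] + [OC, OC]` WeightH storage:

```python
import numpy as np

IC, OC = 3, 5
wx = np.random.rand(IC, 3 * OC).astype('float32') * 2 - 1  # WeightX [IC, 3*OC]
wh = np.random.rand(OC, 3 * OC).astype('float32') * 2 - 1  # WeightH [OC, 3*OC]

# One scale per (gate, output channel): reduce |w| over the input dimension
# of both weight matrices; 63 keeps quantized weights inside s8's [-63, 63].
w_all = np.concatenate([wx, wh], axis=0)              # [IC + OC, 3 * OC]
scale_weights = 63.0 / np.max(np.abs(w_all), axis=0)  # 3 * OC scales

assert scale_weights.shape == (3 * OC,)
```
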
@@ -63,6 +84,10 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
           platform::errors::Unimplemented(
               "oneDNN fusion_gru supports only tanh as an activation."));
 
+      // Weights for int8 kernel are of a type s8
+      const auto weights_dt =
+          is_INT8 ? dnnl::memory::data_type::s8 : dnnl::memory::data_type::f32;
+
       // oneDNN RNN dimensions
       const int64_t D = 1;  // Directions
       const int64_t L = 1;  // Layers (PP supports only 1 stacked layer)
@@ -71,19 +96,16 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
       // Create memory descriptors
       auto input_md = MKLDNNMemDesc({Ti, N, IC}, MKLDNNGetDataType<T>(),
                                     MKLDNNMemoryFormat::any);
-      auto weight_x_md = MKLDNNMemDesc(
-          {L, D, IC, G, OC}, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
-      auto weight_h_md = MKLDNNMemDesc(
-          {L, D, OC, G, OC}, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
+      auto weight_x_md =
+          MKLDNNMemDesc({L, D, IC, G, OC}, weights_dt, MKLDNNMemoryFormat::any);
+      auto weight_h_md =
+          MKLDNNMemDesc({L, D, OC, G, OC}, weights_dt, MKLDNNMemoryFormat::any);
       auto bias_md = MKLDNNMemDesc({L, D, G, OC}, MKLDNNGetDataType<float>(),
                                    MKLDNNMemoryFormat::ldgo);
-      auto hidden_md = MKLDNNMemDesc({Ti, N, OC}, MKLDNNGetDataType<T>(),
+      auto hidden_md = MKLDNNMemDesc({Ti, N, OC}, MKLDNNGetDataType<T_out>(),
                                      MKLDNNMemoryFormat::any);
-      auto h0_md = dnnl::memory::desc();
-      if (h0) {
-        h0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType<T>(),
-                              MKLDNNMemoryFormat::ldnc);
-      }
+      auto h0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType<T>(),
+                                 MKLDNNMemoryFormat::ldnc);
 
       // Create GRU oneDNN primitive
       const auto direction =
@@ -91,7 +113,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
               : dnnl::rnn_direction::unidirectional_left2right;
 
       this->AcquireForwardPrimitiveDescriptor(
-          dnnl::prop_kind::forward_inference, direction, input_md, h0_md,
+          attr_, dnnl::prop_kind::forward_inference, direction, input_md, h0_md,
           weight_x_md, weight_h_md, bias_md, hidden_md, dnnl::memory::desc());
     }
   }
@@ -101,29 +123,31 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
                                 dnnl::memory::format_tag::ntc);
   }
 
-  void reorderRNNdata(const T* input_data, T* output_data,
+  void reorderRNNdata(void* input_data, void* output_data,
                       std::vector<size_t> lod, const bool is_reverse,
                       platform::RNNReorderType reorder_type) {
     switch (reorder_type) {
       // Reorder input memory [WORDS, C] + LoD -> [N, T, C]
       case platform::RNNReorderType::PP_NTC: {
-        auto* input_data_iter = input_data;
+        auto* input_data_iter = reinterpret_cast<T*>(input_data);
+        auto* output_data_iter = reinterpret_cast<T*>(output_data);
         for (int n = 0; n < N; ++n) {
           const auto num_elements = (lod[n + 1] - lod[n]) * IC;
           const auto offset = is_reverse ? (Ti * IC - num_elements) : 0;
-          memcpy(output_data + n * Ti * IC + offset, input_data_iter,
+          memcpy(output_data_iter + n * Ti * IC + offset, input_data_iter,
                  sizeof(T) * num_elements);
           input_data_iter += num_elements;
         }
       } break;
       // Reorder input memory [WORDS, C] + LoD -> [T, N, C]
       case platform::RNNReorderType::PP_TNC: {
-        auto* input_data_iter = input_data;
+        auto* input_data_iter = reinterpret_cast<T*>(input_data);
+        auto* output_data_iter = reinterpret_cast<T*>(output_data);
         for (int n = 0; n < N; ++n) {
           const auto num_elements = (lod[n + 1] - lod[n]);
           const auto offset = is_reverse ? (Ti - num_elements) : 0;
           for (size_t t = 0; t < num_elements; ++t) {
-            memcpy(output_data + (t + offset) * N * IC + n * IC,
+            memcpy(output_data_iter + (t + offset) * N * IC + n * IC,
                    input_data_iter, sizeof(T) * IC);
             input_data_iter += IC;
           }
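
[Reviewer note, not part of the diff] For intuition: `reorderRNNdata` packs PaddlePaddle's LoD layout `[WORDS, C]` (all sequences concatenated) into a zero-padded batch, right-aligning sequences when `is_reverse` is set. A numpy sketch of the `PP_NTC` branch, with made-up dimensions:

```python
import numpy as np

lod = [0, 2, 6, 9]   # cumulative sequence offsets -> lengths 2, 4, 3
N, Ti, IC = 3, 4, 3  # batch size, longest sequence, input channels
words = np.arange(9 * IC, dtype=np.float32).reshape(9, IC)  # [WORDS, C]

is_reverse = False
ntc = np.zeros((N, Ti, IC), dtype=np.float32)  # zero padding, as the memset
for n in range(N):
    length = lod[n + 1] - lod[n]
    offset = Ti - length if is_reverse else 0  # right-align reversed inputs
    ntc[n, offset:offset + length] = words[lod[n]:lod[n + 1]]
```
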
@@ -131,24 +155,27 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
       } break;
       // Reorder output values to PP format [N, T, C] -> [WORDS, C]
       case platform::RNNReorderType::NTC_PP: {
-        auto* output_data_iter = output_data;
+        auto* input_data_iter = reinterpret_cast<T_out*>(input_data);
+        auto* output_data_iter = reinterpret_cast<T_out*>(output_data);
         for (int n = 0; n < N; ++n) {
           const auto num_elements = (lod[n + 1] - lod[n]) * OC;
           const auto offset = is_reverse ? (Ti * OC - num_elements) : 0;
-          memcpy(output_data_iter, input_data + n * Ti * OC + offset,
-                 sizeof(T) * num_elements);
+          memcpy(output_data_iter, input_data_iter + n * Ti * OC + offset,
+                 sizeof(T_out) * num_elements);
           output_data_iter += num_elements;
         }
       } break;
       // Reorder output values to PP format [T, N, C] -> [WORDS, C]
       case platform::RNNReorderType::TNC_PP: {
-        auto* output_data_iter = output_data;
+        auto* input_data_iter = reinterpret_cast<T_out*>(input_data);
+        auto* output_data_iter = reinterpret_cast<T_out*>(output_data);
         for (int n = 0; n < N; ++n) {
           const auto num_elements = lod[n + 1] - lod[n];
           const auto offset = is_reverse ? (Ti - num_elements) : 0;
           for (size_t t = 0; t < num_elements; ++t) {
             memcpy(output_data_iter,
-                   input_data + (t + offset) * N * OC + n * OC, sizeof(T) * OC);
+                   input_data_iter + (t + offset) * N * OC + n * OC,
+                   sizeof(T_out) * OC);
             output_data_iter += OC;
           }
         }
@@ -169,9 +196,9 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
     }
 
     const auto& input_lod = input->lod()[0];
-    auto* x_data = input->data<T>();
+    auto* x_data = to_void_cast(input->data<T>());
 
-    auto* x_onednn_data = reinterpret_cast<T*>(memory_p->get_data_handle());
+    auto* x_onednn_data = memory_p->get_data_handle();
     memset(x_onednn_data, 0, sizeof(T) * N * Ti * IC);
 
     if (platform::GetMKLDNNFormat(this->fwd_pd_->src_desc()) ==
@@ -198,19 +225,35 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
     return memory_p;
   }
 
+  // TODO(grygielski) H0 is for now persistable
   std::shared_ptr<dnnl::memory> AcquireH0Memory(const Tensor* h0) {
     const std::string h0_key = memory_key_ + "@h0";
     auto memory_p =
         std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(h0_key));
-    auto* h0_data = to_void_cast(h0->data<T>());
-
     if (!memory_p) {
-      memory_p = std::make_shared<dnnl::memory>(
-          this->fwd_pd_->weights_layer_desc(), this->engine_, h0_data);
+      auto user_h0_memory = dnnl::memory();
+      if (h0) {
+        user_h0_memory =
+            dnnl::memory({{1, 1, N, OC},
+                          MKLDNNGetDataType<float>(),
+                          MKLDNNMemoryFormat::ldnc},
+                         this->engine_, to_void_cast(h0->data<float>()));
+      } else {
+        user_h0_memory = dnnl::memory({{1, 1, N, OC},
+                                       MKLDNNGetDataType<float>(),
+                                       MKLDNNMemoryFormat::ldnc},
+                                      this->engine_);
+        memset(user_h0_memory.get_data_handle(), 0, sizeof(float) * N * OC);
+      }
+      memory_p = std::make_shared<dnnl::memory>(this->fwd_pd_->src_iter_desc(),
+                                                this->engine_);
+
+      dnnl::stream astream(this->engine_);
+      dnnl::reorder(user_h0_memory, *memory_p, attr_)
+          .execute(astream, user_h0_memory, *memory_p);
+
       this->dev_ctx_.SetBlob(h0_key, memory_p);
-    } else {
-      memory_p->set_data_handle(h0_data);
     }
     return memory_p;
   }
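
[Reviewer note, not part of the diff] `AcquireH0Memory` now always produces a `src_iter` memory: when no `H0` input exists, a zeroed fp32 state is built and reordered with `attr_` into the primitive's `src_iter` format, presumably so the int8 path can quantize it on the way in. Under the data qparams convention sketched earlier, an all-zero state simply quantizes to the shift value:

```python
import numpy as np

scale_data, shift_data = 63.0, 64.0
h0_f32 = np.zeros((1, 1, 3, 5), dtype=np.float32)  # ldnc-shaped zero state
h0_u8 = (h0_f32 * scale_data + shift_data).astype(np.uint8)

assert (h0_u8 == 64).all()  # zero initial state == shift after quantization
```
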
@@ -245,7 +288,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
           this->fwd_pd_->weights_layer_desc(), this->engine_);
 
       dnnl::stream astream(this->engine_);
-      dnnl::reorder(user_memory, *memory_p)
+      dnnl::reorder(user_memory, *memory_p, attr_)
           .execute(astream, user_memory, *memory_p);
 
       this->dev_ctx_.SetBlob(wx_key, memory_p);
@@ -298,7 +341,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
           this->fwd_pd_->weights_iter_desc(), this->engine_);
 
       dnnl::stream astream(this->engine_);
-      dnnl::reorder(user_memory, *memory_p)
+      dnnl::reorder(user_memory, *memory_p, attr_)
           .execute(astream, user_memory, *memory_p);
 
       this->dev_ctx_.SetBlob(wh_key, memory_p);
@@ -347,12 +390,26 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
   // Memory size of weights, bias and h0 does not depend
   // on Ti size, thus we need another key to cache them
   std::string memory_key_;
+  dnnl::primitive_attr attr_;
 };
 
 template <typename T>
 class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    const bool is_INT8 = std::is_same<T, uint8_t>::value;
+    const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
+
+    // TODO(grygielski) Add option for bfloat
+    if (!is_INT8 || force_fp32_output) {
+      RunKernel<float>(ctx);
+    } else {
+      RunKernel<uint8_t>(ctx);
+    }
+  }
+
+  template <typename Tout = T>
+  void RunKernel(const framework::ExecutionContext& ctx) const {
     auto& dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
     const auto& mkldnn_engine = dev_ctx.GetEngine();
@@ -390,12 +447,14 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
     const int64_t IC = x_mat_dims_vec[1];  // Input channels
     const int64_t OC = weight_h_dims[0];   // Output channels
 
-    GRUMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(),
-                                input, weight_h, h0, is_reverse, N, Ti, IC, OC,
-                                ctx.InputName("X") + ctx.InputName("WeightH"));
+    GRUMKLDNNHandler<T, Tout> handler(
+        ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, weight_h, h0,
+        is_reverse, N, Ti, IC, OC,
+        ctx.InputName("X") + ctx.InputName("WeightH"));
 
     auto input_memory_p =
         handler.AcquireInputMemoryWithReorder(input, is_reverse);
+    auto h0_memory_p = handler.AcquireH0Memory(h0);
     auto weight_x_memory_p =
         handler.AcquireWeightXMemory(weight_x, origin_mode);
     auto weight_h_memory_p =
@@ -405,25 +464,21 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
 
     std::unordered_map<int, dnnl::memory> gru_args = {
         {DNNL_ARG_SRC_LAYER, *input_memory_p},
+        {DNNL_ARG_SRC_ITER, *h0_memory_p},
         {DNNL_ARG_WEIGHTS_LAYER, *weight_x_memory_p},
         {DNNL_ARG_WEIGHTS_ITER, *weight_h_memory_p},
        {DNNL_ARG_BIAS, *bias_memory_p},
         {DNNL_ARG_DST_LAYER, *hidden_onednn_memory_p}};
 
-    if (h0) {
-      auto h0_memory_p = handler.AcquireH0Memory(h0);
-      gru_args.insert({DNNL_ARG_SRC_ITER, *h0_memory_p});
-    }
-
     auto gru_forward_p = handler.AcquireForwardPrimitive();
 
     dnnl::stream astream(mkldnn_engine);
     gru_forward_p->execute(astream, gru_args);
     astream.wait();
 
-    auto* hidden_onednn_data =
-        reinterpret_cast<T*>(hidden_onednn_memory_p->get_data_handle());
-    auto* hidden_data = hidden->mutable_data<T>(ctx.GetPlace());
+    auto* hidden_onednn_data = hidden_onednn_memory_p->get_data_handle();
+    auto* hidden_data =
+        to_void_cast(hidden->mutable_data<Tout>(ctx.GetPlace()));
     if (handler.is_NTC()) {
       handler.reorderRNNdata(hidden_onednn_data, hidden_data, input_lod,
                              is_reverse, platform::RNNReorderType::NTC_PP);
@@ -439,4 +494,5 @@
 
 namespace ops = paddle::operators;
 REGISTER_OP_KERNEL(fusion_gru, MKLDNN, paddle::platform::CPUPlace,
-                   ops::FusionGRUMKLDNNKernel<float>);
+                   ops::FusionGRUMKLDNNKernel<float>,
+                   ops::FusionGRUMKLDNNKernel<uint8_t>);
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_int8_mkldnn_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff4531f0e250e325f39ef69161c8d1ee751a2336
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_int8_mkldnn_op.py
@@ -0,0 +1,145 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from paddle.fluid.tests.unittests.op_test import OpTest
+from paddle.fluid.tests.unittests.test_fusion_gru_op import fusion_gru
+from paddle.fluid.tests.unittests.test_fusion_lstm_op import fc, ACTIVATION
+
+
+class TestFusionGRUINT8MKLDNNOp(OpTest):
+    def set_confs(self):
+        pass
+
+    def setUp(self):
+        self.op_type = "fusion_gru"
+        self.lod = [[2, 4, 3]]
+        self.IC = 3
+        self.OC = 5
+        self.is_reverse = False
+        self.with_h0 = False
+        self.with_bias = True
+        self.act_state = 'tanh'
+        self.act_gate = 'sigmoid'
+        self.origin_mode = True
+        self.use_mkldnn = True
+        self.force_fp32_output = True
+        self.error_margin = 1e-5
+        self.set_confs()
+
+        # RNN dimensions
+        T = sum(self.lod[0])
+        N = len(self.lod[0])
+
+        # Input data
+        x_f32 = np.random.rand(T, self.IC).astype('float32') * 2 - 1
+        scale_data = 63
+        shift_data = 64
+        x_u8 = (x_f32 * scale_data + shift_data).astype(np.uint8)
+
+        # WeightX/WeightH data
+        wx = np.random.rand(self.IC, 3 * self.OC).astype('float32') * 2 - 1
+        wh = np.random.rand(self.OC, 3 * self.OC).astype('float32') * 2 - 1
+
+        # Calculating weight scales
+        # scales = 63 / max(abs(channel_wise(weightsX + weightsH)))
+        # WeightX data shape in PP: [IC, 3 * OC]
+        # WeightH data shape in PP: [OC, 2 * OC] + [OC, OC]
+        # Scales shape in oneDNN: [3, OC]
+        scale_ur = 63 / np.max(np.abs(
+            np.concatenate(
+                [
+                    wx[:, :2 * self.OC], wh.flatten()[:2 * self.OC * self.OC]
+                    .reshape(self.OC, 2 * self.OC)
+                ],
+                axis=0)),
+                               axis=0)
+        scale_o = 63 / np.max(np.abs(
+            np.concatenate(
+                [
+                    wx[:, 2 * self.OC:], wh.flatten()[2 * self.OC * self.OC:]
+                    .reshape(self.OC, self.OC)
+                ],
+                axis=0)),
+                              axis=0)
+
+        scale_weights = np.concatenate([scale_ur, scale_o]).astype('float')
+
+        bias = np.random.rand(
+            1, 3 * self.OC).astype('float32') if self.with_bias else np.zeros(
+                (1, 3 * self.OC), dtype='float32')
+        h0 = np.random.rand(
+            N, self.OC).astype('float32') if self.with_h0 else np.zeros(
+                (N, self.OC), dtype='float32')
+
+        _, _, _, hidden_f32 = fusion_gru(x_f32, self.lod, h0, wx, wh, bias,
+                                         self.is_reverse, self.origin_mode,
+                                         ACTIVATION[self.act_state],
+                                         ACTIVATION[self.act_gate])
+
+        self.inputs = {'X': (x_u8, self.lod), 'WeightX': wx, 'WeightH': wh}
+
+        if self.with_bias:
+            self.inputs['Bias'] = bias
+
+        if self.with_h0:
+            self.inputs['H0'] = h0
+
+        if self.force_fp32_output:
+            self.error_margin = 1e-1
+            self.outputs = {'Hidden': (hidden_f32, self.lod)}
+        else:
+            self.error_margin = 1
+            hidden_u8 = (hidden_f32 * scale_data + shift_data).astype(np.uint8)
+            self.outputs = {'Hidden': (hidden_u8, self.lod)}
+
+        self.attrs = {
+            'activation': self.act_state,
+            'gate_activation': self.act_gate,
+            'is_reverse': self.is_reverse,
+            'origin_mode': self.origin_mode,
+            'use_mkldnn': self.use_mkldnn,
+            'force_fp32_output': self.force_fp32_output,
+            'Scale_data': scale_data,
+            'Shift_data': shift_data,
+            'Scale_weights': scale_weights
+        }
+
+    def test_check_output(self):
+        self.check_output(check_dygraph=False, atol=self.error_margin)
+
+
+class TestFusionGRUINT8MKLDNNOp2(TestFusionGRUINT8MKLDNNOp):
+    def set_confs(self):
+        self.force_fp32_output = False
+
+
+class TestFusionGRUINT8MKLDNNOp3(TestFusionGRUINT8MKLDNNOp):
+    def set_confs(self):
+        self.origin_mode = False
+
+
+class TestFusionGRUINT8MKLDNNOp4(TestFusionGRUINT8MKLDNNOp):
+    def set_confs(self):
+        self.with_bias = False
+
+
+class TestFusionGRUINT8MKLDNNOp5(TestFusionGRUINT8MKLDNNOp):
+    def set_confs(self):
+        self.with_h0 = False
+
+
+if __name__ == "__main__":
+    unittest.main()
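
[Reviewer note, not part of the diff] On the test tolerances: a quantized uint8 output can legitimately land one step away from a uint8 reference quantized on the host, hence `error_margin = 1` in the u8 case; the `force_fp32_output` case uses `1e-1`, presumably to absorb accumulated int8 arithmetic error against the fp32 reference. A sketch of the u8 off-by-one effect:

```python
import numpy as np

scale_data, shift_data = 63.0, 64.0
h = np.float32(0.5)  # some fp32 hidden value

ref_u8 = np.uint8(h * scale_data + shift_data)  # reference quantized on host
# A kernel computing h with a tiny numeric difference may land one step away:
ker_u8 = np.uint8((h + 1e-2) * scale_data + shift_data)

assert abs(int(ker_u8) - int(ref_u8)) <= 1  # within the test's atol of 1
```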