From 6dca7a1de70a85b16e2fa8d7f1affd5c632ca10c Mon Sep 17 00:00:00 2001
From: jakpiase <62569058+jakpiase@users.noreply.github.com>
Date: Tue, 30 Mar 2021 11:04:07 +0200
Subject: [PATCH] Added int8 kernel for oneDNN LSTM op (#31894)

---
 .../fluid/operators/fused/fusion_lstm_op.cc   |  12 ++
 .../fused/mkldnn/fusion_lstm_mkldnn_op.cc     |  19 ++-
 .../mkldnn/test_fusion_lstm_int8_mkldnn_op.py | 153 ++++++++++++++++++
 tools/static_mode_white_list.py               |   1 +
 4 files changed, 178 insertions(+), 7 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_fusion_lstm_int8_mkldnn_op.py

diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc
index 3c82be2c4e4..6cca6b5a972 100644
--- a/paddle/fluid/operators/fused/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc
@@ -249,6 +249,18 @@ void FusionLSTMOpMaker::Make() {
   AddAttr<bool>("use_mkldnn",
                 "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
+  AddAttr<float>("Scale_data",
+                 "Scale to be used for int8 input/output data."
+                 "Only used with MKL-DNN INT8.")
+      .SetDefault(1.0f);
+  AddAttr<float>("Shift_data",
+                 "Shift to be used for int8 input/output data."
+                 "Only used with MKL-DNN INT8.")
+      .SetDefault(0.0f);
+  AddAttr<std::vector<float>>("Scale_weights",
+                              "Scale_weights to be used for int8 weights data."
+                              "Only used with MKL-DNN INT8.")
+      .SetDefault({1.0f});
   AddAttr<bool>("force_fp32_output",
                 "(bool, default false) Force INT8 kernel output FP32, only "
                 "used in MKL-DNN INT8")
diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc
index cf39968a900..1adbd5cd9e7 100644
--- a/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc
+++ b/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc
@@ -79,13 +79,11 @@ class LSTMMKLDNNHandler
                                  MKLDNNMemoryFormat::ldgo);
     auto hidden_md = MKLDNNMemDesc({Ti, N, OC}, MKLDNNGetDataType<T_out>(),
                                    MKLDNNMemoryFormat::tnc);
+
     auto h0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType<T>(),
                                MKLDNNMemoryFormat::ldnc);
-    auto c0_md = MKLDNNMemDesc(
-        {L, D, N, OC}, MKLDNNGetDataType<float>(),  // Vanilla LSTM and LSTM
-                                                    // with peepoles has c0 as
-                                                    // fp32
-        MKLDNNMemoryFormat::ldnc);
+    auto c0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType<float>(),
+                               MKLDNNMemoryFormat::ldnc);
 
     // Create LSTM oneDNN primitive
     const auto direction =
@@ -266,7 +264,7 @@ class LSTMMKLDNNHandler
           this->fwd_pd_->src_iter_c_desc(), this->engine_);
 
       auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
-      dnnl::reorder(user_c0_memory, *memory_p, this->attr_)
+      dnnl::reorder(user_c0_memory, *memory_p)
           .execute(astream, user_c0_memory, *memory_p);
 
       this->dev_ctx_.SetBlob(c0_key, memory_p);
@@ -360,6 +358,12 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel<T> {
       weight_h_memory_p =
           handler.template AcquireWeightHMemory<paddle::platform::bfloat16>(
               weight_h);
+    } else {
+      h0_memory_p = handler.template AcquireH0Memory<uint8_t>(h0);
+      weight_x_memory_p =
+          handler.template AcquireWeightXMemory<int8_t>(weight_x);
+      weight_h_memory_p =
+          handler.template AcquireWeightHMemory<int8_t>(weight_h);
     }
 
     auto bias_memory_p = handler.AcquireBiasMemory(bias);
@@ -406,4 +410,5 @@
 namespace ops = paddle::operators;
 REGISTER_OP_KERNEL(fusion_lstm, MKLDNN, paddle::platform::CPUPlace,
                    ops::FusionLSTMMKLDNNKernel<float>,
-                   ops::FusionLSTMMKLDNNKernel<paddle::platform::bfloat16>);
+                   ops::FusionLSTMMKLDNNKernel<paddle::platform::bfloat16>,
+                   ops::FusionLSTMMKLDNNKernel<uint8_t>);
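Context, not part of the patch: the new Scale_data and Shift_data attributes describe an affine uint8 quantization of the LSTM data path. A minimal numpy sketch of that mapping, reusing the values the new test below picks (scale 63.0 and shift 64.0 map fp32 values from [-1, 1] onto [1, 127]); the helper names are illustrative, not from the patch:

    import numpy as np

    def quantize_u8(x_f32, scale, shift):
        # u8 = round(f32 * scale + shift), saturated to the uint8 range
        return np.clip(np.rint(x_f32 * scale + shift), 0, 255).astype(np.uint8)

    def dequantize_u8(x_u8, scale, shift):
        # approximate inverse: f32 ~= (u8 - shift) / scale
        return (x_u8.astype(np.float32) - shift) / scale

    x = np.random.rand(8, 3).astype(np.float32) * 2 - 1  # fp32 data in [-1, 1]
    x_u8 = quantize_u8(x, 63.0, 64.0)
    x_back = dequantize_u8(x_u8, 63.0, 64.0)
    assert np.max(np.abs(x - x_back)) <= 0.5 / 63.0 + 1e-6  # rounding error only

The same scale and shift apply to the u8 Hidden output, which is why the test quantizes its fp32 reference with the identical formula before comparing.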
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_lstm_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_lstm_int8_mkldnn_op.py
new file mode 100644
index 00000000000..93dc45f2650
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_lstm_int8_mkldnn_op.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from paddle.fluid.tests.unittests.op_test import OpTest
+from paddle.fluid.tests.unittests.test_fusion_lstm_op import fc, ACTIVATION, fusion_lstm
+
+
+class TestFusionLSTMINT8MKLDNNOp(OpTest):
+    def set_confs(self):
+        pass
+
+    def setUp(self):
+        self.op_type = "fusion_lstm"
+        self.lod = [[2, 3, 5, 4]]
+        self.IC = 3
+        self.OC = 5
+        self.is_reverse = False
+        self.has_initial_state = False
+        self.act_cell = 'tanh'
+        self.act_gate = 'sigmoid'
+        self.act_cand = 'tanh'
+        self.use_peepholes = False  # LSTM u8 doesn't support peepholes
+        self.use_mkldnn = True
+        self.force_fp32_output = False
+        self.error_margin = 1e-5
+        self.set_confs()
+
+        # RNN dimensions
+        T = sum(self.lod[0])
+        N = len(self.lod[0])
+
+        # Input data
+        x_f32 = np.random.rand(T, self.IC).astype('float32') * 2 - 1
+        scale_data = 63.0
+        shift_data = 64.0
+        x_u8 = np.rint(x_f32 * scale_data + shift_data).astype(np.uint8)
+
+        # WeightX/WeightH data
+        wx = np.random.rand(self.IC, 4 * self.OC).astype('float32') * 2 - 1
+        wh = np.random.rand(self.OC, 4 * self.OC).astype('float32') * 2 - 1
+
+        # Calculating weight scales
+        # scales = 127 / max(abs(channel_wise(weightsX + weightsH)))
+        s8_max = 127.0
+
+        scale_weights = s8_max / np.max(
+            np.abs(np.concatenate(
+                [wx[:, :], wh[:, :]], axis=0)), axis=0)
+
+        scale_weights = scale_weights.astype('float')
+
+        if self.use_peepholes:
+            b = np.random.rand(1, 7 * self.OC).astype('float32')
+        else:
+            b = np.random.rand(1, 4 * self.OC).astype('float32')
+        w_b = np.copy(b[:, 0:4 * self.OC])
+        w_c = b[:, 4 * self.OC:] if self.use_peepholes else None
+
+        bx = np.random.normal(size=(1, 4 * self.OC)).astype('float32')
+        b[0, 0:4 * self.OC] += bx[0, :]
+
+        if self.has_initial_state:
+            h0 = np.random.rand(N, self.OC).astype('float32')
+            c0 = np.random.rand(N, self.OC).astype('float32')
+        else:
+            h0 = np.zeros((N, self.OC)).astype('float32')
+            c0 = np.zeros((N, self.OC)).astype('float32')
+
+        hidden_f32, c = fusion_lstm(
+            x_f32, self.lod, wx, bx, h0, c0, wh, w_b, w_c, self.is_reverse,
+            ACTIVATION[self.act_gate], ACTIVATION[self.act_cell],
+            ACTIVATION[self.act_cand])
+
+        self.inputs = {
+            'X': (x_u8, self.lod),
+            'WeightX': wx,
+            'WeightH': wh,
+            'Bias': b
+        }
+
+        if self.has_initial_state:
+            self.inputs['H0'] = h0
+            self.inputs['C0'] = c0
+
+        if self.force_fp32_output:
+            self.error_margin = 1e-1
+            self.outputs = {
+                'Hidden': (hidden_f32, self.lod),
+                'Cell': (c, self.lod)
+            }
+        else:
+            self.error_margin = 2
+            hidden_u8 = np.rint(hidden_f32 * scale_data + shift_data).astype(
+                np.uint8)
+            self.outputs = {
+                'Hidden': (hidden_u8, self.lod),
+                'Cell': (c, self.lod)
+            }
+
+        self.attrs = {
+            'gate_activation': self.act_gate,
+            'cell_activation': self.act_cell,
+            'candidate_activation': self.act_cand,
+            'is_reverse': self.is_reverse,
+            'use_peepholes': self.use_peepholes,
+            'use_mkldnn': self.use_mkldnn,
+            'force_fp32_output': self.force_fp32_output,
+            'Scale_data': scale_data,
+            'Shift_data': shift_data,
+            'Scale_weights': scale_weights
+        }
+
+    def test_check_output(self):
+        for use_seq in {True, False}:
+            self.attrs['use_seq'] = use_seq
+            self.check_output(
+                check_dygraph=False,
+                no_check_set=["Cell"],
+                atol=self.error_margin)
+
+
+class TestFusionLSTMINT8MKLDNNOp2(TestFusionLSTMINT8MKLDNNOp):
+    def set_confs(self):
+        self.force_fp32_output = True
+
+
+class TestFusionLSTMINT8MKLDNNOp4(TestFusionLSTMINT8MKLDNNOp):
+    def set_confs(self):
+        self.is_reverse = True
+
+
+class TestFusionLSTMINT8MKLDNNOp5(TestFusionLSTMINT8MKLDNNOp):
+    def set_confs(self):
+        self.has_initial_state = True
+
+
+if __name__ == "__main__":
+    from paddle import enable_static
+    enable_static()
+    unittest.main()
diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py
index 6453eb48d70..ab5b6516b90 100644
--- a/tools/static_mode_white_list.py
+++ b/tools/static_mode_white_list.py
@@ -606,6 +606,7 @@ STATIC_MODE_TESTING_LIST = [
     'test_fusion_gru_bf16_mkldnn_op',
     'test_fusion_gru_mkldnn_op',
     'test_fusion_lstm_mkldnn_op',
+    'test_fusion_lstm_int8_mkldnn_op',
     'test_fusion_lstm_bf16_mkldnn_op',
     'test_gaussian_random_mkldnn_op',
     'test_lrn_mkldnn_op',
-- 
GitLab
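Context, not part of the patch: Scale_weights carries one scale per gate output channel, computed as in the test above: 127 divided by the largest absolute value found in that column of WeightX and WeightH combined, so that symmetric int8 quantization never saturates. A standalone numpy sketch with the test's dimensions (all names illustrative):

    import numpy as np

    IC, OC = 3, 5  # input/output channels, as in the test
    wx = np.random.rand(IC, 4 * OC).astype(np.float32) * 2 - 1  # WeightX
    wh = np.random.rand(OC, 4 * OC).astype(np.float32) * 2 - 1  # WeightH

    # scale[c] = 127 / max(|WeightX[:, c]|, |WeightH[:, c]|), one per gate column
    scale_weights = 127.0 / np.max(
        np.abs(np.concatenate([wx, wh], axis=0)), axis=0)

    # with these scales, symmetric s8 quantization stays inside [-127, 127]
    wx_s8 = np.rint(wx * scale_weights).astype(np.int8)
    wh_s8 = np.rint(wh * scale_weights).astype(np.int8)
    assert wx_s8.min() >= -127 and wx_s8.max() <= 127
    assert wh_s8.min() >= -127 and wh_s8.max() <= 127

The kernel forwards these per-channel scales to oneDNN (presumably via its RNN weights quantization attributes), while Scale_data/Shift_data cover the uint8 data path; the cell state stays fp32 throughout, which is why the patch drops this->attr_ from the c0 reorder.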