Unverified commit 6dca7a1d, authored by jakpiase, committed by GitHub

Added int8 kernel for oneDNN LSTM op (#31894)

Parent: 14b7e3cf
@@ -249,6 +249,18 @@ void FusionLSTMOpMaker::Make() {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<float>("Scale_data",
"Scale to be used for int8 input/output data."
"Only used with MKL-DNN INT8.")
.SetDefault(1.0f);
AddAttr<float>("Shift_data",
"Shift to be used for int8 input/output data."
"Only used with MKL-DNN INT8.")
.SetDefault(0.0f);
AddAttr<std::vector<float>>("Scale_weights",
"Scale_weights to be used for int8 weights data."
"Only used with MKL-DNN INT8.")
.SetDefault({1.0f});
AddAttr<bool>("force_fp32_output",
"(bool, default false) Force INT8 kernel output FP32, only "
"used in MKL-DNN INT8")
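As a quick worked example of how these attributes are meant to be used (taken from the int8 test added below, not from the kernel itself): with input activations roughly in [-1, 1), Scale_data = 63.0 and Shift_data = 64.0 map an fp32 value x onto uint8 as q = round(63 * x + 64), so the quantized inputs land roughly in [1, 127]:

import numpy as np

x = np.array([-1.0, 0.0, 0.984375], dtype=np.float32)
q = np.rint(x * 63.0 + 64.0).astype(np.uint8)  # array([  1,  64, 126], dtype=uint8)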
@@ -79,13 +79,11 @@ class LSTMMKLDNNHandler
MKLDNNMemoryFormat::ldgo);
auto hidden_md = MKLDNNMemDesc({Ti, N, OC}, MKLDNNGetDataType<T_out>(),
MKLDNNMemoryFormat::tnc);
auto h0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::ldnc);
auto c0_md = MKLDNNMemDesc(
{L, D, N, OC}, MKLDNNGetDataType<float>(), // Vanilla LSTM and LSTM
// with peepholes have c0 as
// fp32
MKLDNNMemoryFormat::ldnc);
auto c0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldnc);
// Create LSTM oneDNN primitive
const auto direction =
@@ -266,7 +264,7 @@ class LSTMMKLDNNHandler
this->fwd_pd_->src_iter_c_desc(), this->engine_);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
dnnl::reorder(user_c0_memory, *memory_p, this->attr_)
dnnl::reorder(user_c0_memory, *memory_p)
.execute(astream, user_c0_memory, *memory_p);
this->dev_ctx_.SetBlob(c0_key, memory_p);
@@ -360,6 +358,12 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel<T> {
weight_h_memory_p =
handler.template AcquireWeightHMemory<paddle::platform::bfloat16>(
weight_h);
} else {
h0_memory_p = handler.template AcquireH0Memory<uint8_t>(h0);
weight_x_memory_p =
handler.template AcquireWeightXMemory<int8_t>(weight_x);
weight_h_memory_p =
handler.template AcquireWeightHMemory<int8_t>(weight_h);
}
auto bias_memory_p = handler.AcquireBiasMemory(bias);
@@ -406,4 +410,5 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(fusion_lstm, MKLDNN, paddle::platform::CPUPlace,
ops::FusionLSTMMKLDNNKernel<float>,
ops::FusionLSTMMKLDNNKernel<paddle::platform::bfloat16>);
ops::FusionLSTMMKLDNNKernel<paddle::platform::bfloat16>,
ops::FusionLSTMMKLDNNKernel<uint8_t>);
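With the uint8 kernel registered above, quantized models can feed u8 activations to fusion_lstm while WeightX/WeightH are reordered to s8 inside the handler. A minimal sketch of how the per-output-channel Scale_weights attribute can be derived (the helper name lstm_weight_scales is made up for illustration; this mirrors what the new test below does and is not code from the handler):

import numpy as np

def lstm_weight_scales(wx, wh, s8_max=127.0):
    # wx: (IC, 4 * OC), wh: (OC, 4 * OC); one scale per output channel
    return s8_max / np.max(np.abs(np.concatenate([wx, wh], axis=0)), axis=0)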
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_fusion_lstm_op import fc, ACTIVATION, fusion_lstm


class TestFusionLSTMINT8MKLDNNOp(OpTest):
def set_confs(self):
        pass

    def setUp(self):
self.op_type = "fusion_lstm"
self.lod = [[2, 3, 5, 4]]
self.IC = 3
self.OC = 5
self.is_reverse = False
self.has_initial_state = False
self.act_cell = 'tanh'
self.act_gate = 'sigmoid'
self.act_cand = 'tanh'
self.use_peepholes = False # LSTM u8 doesn't support peepholes
self.use_mkldnn = True
self.force_fp32_output = False
self.error_margin = 1e-5
self.set_confs()
# RNN dimensions
T = sum(self.lod[0])
N = len(self.lod[0])
# Input data
x_f32 = np.random.rand(T, self.IC).astype('float32') * 2 - 1
scale_data = 63.0
shift_data = 64.0
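        # quantize the fp32 input with the same Scale_data/Shift_data passed to the op below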
x_u8 = np.rint(x_f32 * scale_data + shift_data).astype(np.uint8)
# WeightX/WeightH data
wx = np.random.rand(self.IC, 4 * self.OC).astype('float32') * 2 - 1
wh = np.random.rand(self.OC, 4 * self.OC).astype('float32') * 2 - 1
        # Calculate per-output-channel weight scales:
        # scales = 127 / max(abs(concat(weightsX, weightsH)), axis=0)
s8_max = 127.0
scale_weights = s8_max / np.max(
np.abs(np.concatenate(
[wx[:, :], wh[:, :]], axis=0)), axis=0)
scale_weights = scale_weights.astype('float')
if self.use_peepholes:
b = np.random.rand(1, 7 * self.OC).astype('float32')
else:
b = np.random.rand(1, 4 * self.OC).astype('float32')
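        # first 4 * OC values are the gate bias; the rest (peepholes only) are peephole weights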
w_b = np.copy(b[:, 0:4 * self.OC])
w_c = b[:, 4 * self.OC:] if self.use_peepholes else None
bx = np.random.normal(size=(1, 4 * self.OC)).astype('float32')
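        # the fused op takes a single Bias input, so fold the projection bias bx into b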
b[0, 0:4 * self.OC] += bx[0, :]
if self.has_initial_state:
h0 = np.random.rand(N, self.OC).astype('float32')
c0 = np.random.rand(N, self.OC).astype('float32')
else:
h0 = np.zeros((N, self.OC)).astype('float32')
c0 = np.zeros((N, self.OC)).astype('float32')
hidden_f32, c = fusion_lstm(
x_f32, self.lod, wx, bx, h0, c0, wh, w_b, w_c, self.is_reverse,
ACTIVATION[self.act_gate], ACTIVATION[self.act_cell],
ACTIVATION[self.act_cand])
self.inputs = {
'X': (x_u8, self.lod),
'WeightX': wx,
'WeightH': wh,
'Bias': b
}
if self.has_initial_state:
self.inputs['H0'] = h0
self.inputs['C0'] = c0
if self.force_fp32_output:
self.error_margin = 1e-1
self.outputs = {
'Hidden': (hidden_f32, self.lod),
'Cell': (c, self.lod)
}
else:
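            # u8 output: requantize the fp32 reference with the same scale/shift;
            # quantization rounding is why the tolerance is larger here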
self.error_margin = 2
hidden_u8 = np.rint(hidden_f32 * scale_data + shift_data).astype(
np.uint8)
self.outputs = {
'Hidden': (hidden_u8, self.lod),
'Cell': (c, self.lod)
}
self.attrs = {
'gate_activation': self.act_gate,
'cell_activation': self.act_cell,
'candidate_activation': self.act_cand,
'is_reverse': self.is_reverse,
'use_peepholes': self.use_peepholes,
'use_mkldnn': self.use_mkldnn,
'force_fp32_output': self.force_fp32_output,
'Scale_data': scale_data,
'Shift_data': shift_data,
'Scale_weights': scale_weights
        }

    def test_check_output(self):
for use_seq in {True, False}:
self.attrs['use_seq'] = use_seq
self.check_output(
check_dygraph=False,
no_check_set=["Cell"],
                atol=self.error_margin)


class TestFusionLSTMINT8MKLDNNOp2(TestFusionLSTMINT8MKLDNNOp):
def set_confs(self):
        self.force_fp32_output = True


class TestFusionLSTMINT8MKLDNNOp4(TestFusionLSTMINT8MKLDNNOp):
def set_confs(self):
        self.is_reverse = True


class TestFusionLSTMINT8MKLDNNOp5(TestFusionLSTMINT8MKLDNNOp):
def set_confs(self):
        self.has_initial_state = True


if __name__ == "__main__":
from paddle import enable_static
enable_static()
unittest.main()
@@ -606,6 +606,7 @@ STATIC_MODE_TESTING_LIST = [
'test_fusion_gru_bf16_mkldnn_op',
'test_fusion_gru_mkldnn_op',
'test_fusion_lstm_mkldnn_op',
'test_fusion_lstm_int8_mkldnn_op',
'test_fusion_lstm_bf16_mkldnn_op',
'test_gaussian_random_mkldnn_op',
'test_lrn_mkldnn_op',