From 5b4f8aac82c62e73ae0434e7cefe9d7f9ca0f967 Mon Sep 17 00:00:00 2001
From: jakpiase <62569058+jakpiase@users.noreply.github.com>
Date: Thu, 4 Mar 2021 04:23:55 +0100
Subject: [PATCH] Added LSTM BF16 and fixed GRU BF16 (#31234)

---
 .../fluid/operators/fused/fusion_lstm_op.cc   |   4 +
 .../fused/mkldnn/fusion_gru_mkldnn_op.cc      |  64 ++++---
 .../fused/mkldnn/fusion_lstm_mkldnn_op.cc     |  74 +++++---
 .../fused/mkldnn/fusion_rnn_mkldnn.h          |  18 +-
 .../mkldnn/test_fusion_gru_bf16_mkldnn_op.py  |  38 ++++-
 .../mkldnn/test_fusion_gru_int8_mkldnn_op.py  |   2 +
 .../mkldnn/test_fusion_lstm_bf16_mkldnn_op.py | 159 ++++++++++++++++++
 .../paddle/fluid/tests/unittests/op_test.py   |  19 +++
 tools/static_mode_white_list.py               |   2 +
 9 files changed, 322 insertions(+), 58 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_fusion_lstm_bf16_mkldnn_op.py

diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc
index f14a0514251..3c82be2c4e4 100644
--- a/paddle/fluid/operators/fused/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc
@@ -249,6 +249,10 @@ void FusionLSTMOpMaker::Make() {
   AddAttr<bool>("use_mkldnn",
                 "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
+  AddAttr<bool>("force_fp32_output",
+                "(bool, default false) Force INT8 kernel output FP32, only "
+                "used in MKL-DNN INT8")
+      .SetDefault(false);
   AddComment(R"DOC(
 Fusion Long-Short Term Memory (LSTM) Operator.
 This operator fuse the X into LSTM, more details can refer to LSTM op.
diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc
index a3b59419b7f..8e0627fc15c 100644
--- a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc
+++ b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc
@@ -89,6 +89,7 @@ class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> {
     }
   }
 
+  template <typename U>
   std::shared_ptr<dnnl::memory> AcquireWeightXMemory(const Tensor* weight_x,
                                                      const bool origin_mode) {
     const std::string wx_key = this->memory_key_ + "@weight_x";
     auto memory_p =
@@ -98,18 +99,18 @@ class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> {
 
     if (!memory_p) {
       auto user_md =
           MKLDNNMemDesc({1, 1, this->IC, this->G, this->OC},
-                        MKLDNNGetDataType<float>(), MKLDNNMemoryFormat::ldigo);
+                        MKLDNNGetDataType<U>(), MKLDNNMemoryFormat::ldigo);
       auto user_memory = dnnl::memory(user_md, this->engine_);
 
-      auto* weight_x_data =
-          reinterpret_cast<float*>(user_memory.get_data_handle());
-      memcpy(weight_x_data, weight_x->data<float>(),
-             sizeof(float) * this->IC * this->G * this->OC);
+      auto* weight_x_data = reinterpret_cast<U*>(user_memory.get_data_handle());
+      memcpy(weight_x_data, weight_x->data<U>(),
+             sizeof(U) * this->IC * this->G * this->OC);
 
       if (origin_mode == false) {
         for (int64_t i = 0; i < this->IC; ++i) {
           for (int64_t j = 0; j < this->OC; ++j) {
-            weight_x_data[j] *= -1;
+            U minus_one(-1.0f);
+            weight_x_data[j] = minus_one * weight_x_data[j];
           }
           weight_x_data += 3 * this->OC;
         }
@@ -127,6 +128,7 @@ class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> {
     return memory_p;
   }
 
+  template <typename U>
   std::shared_ptr<dnnl::memory> AcquireWeightHMemory(const Tensor* weight_h,
                                                      const bool origin_mode) {
     const std::string wh_key = this->memory_key_ + "@weight_h";
     auto memory_p =
@@ -136,34 +138,33 @@ class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> {
 
     if (!memory_p) {
       auto user_md =
           MKLDNNMemDesc({1, 1, this->OC, this->G, this->OC},
-                        MKLDNNGetDataType<float>(), MKLDNNMemoryFormat::ldigo);
+                        MKLDNNGetDataType<U>(), MKLDNNMemoryFormat::ldigo);
       auto user_memory = dnnl::memory(user_md, this->engine_);
 
       // Reorder weights_h from PP format [OC, 2OC] + [OC, OC] to
       // oneDNN format [OC, 3OC]
-      auto* weight_h_data =
-          reinterpret_cast<float*>(user_memory.get_data_handle());
-      auto* user_weight_h_data = weight_h->data<float>();
+      auto* weight_h_data = reinterpret_cast<U*>(user_memory.get_data_handle());
+      auto* user_weight_h_data = weight_h->data<U>();
 
       auto src1_iter = user_weight_h_data;
       auto src2_iter = user_weight_h_data + 2 * this->OC * this->OC;
 
       for (int64_t c = 0; c < this->OC; ++c) {
-        memcpy(weight_h_data, src1_iter, 2 * this->OC * sizeof(float));
-        memcpy(weight_h_data + 2 * this->OC, src2_iter,
-               this->OC * sizeof(float));
+        memcpy(weight_h_data, src1_iter, 2 * this->OC * sizeof(U));
+        memcpy(weight_h_data + 2 * this->OC, src2_iter, this->OC * sizeof(U));
 
         src1_iter += 2 * this->OC;
         src2_iter += this->OC;
         weight_h_data += 3 * this->OC;
       }
 
-      weight_h_data = reinterpret_cast<float*>(user_memory.get_data_handle());
+      weight_h_data = reinterpret_cast<U*>(user_memory.get_data_handle());
 
       if (origin_mode == false) {
         for (int64_t i = 0; i < this->OC; ++i) {
           for (int64_t j = 0; j < this->OC; ++j) {
-            weight_h_data[j] *= -1;
+            U minus_one(-1.0f);
+            weight_h_data[j] = minus_one * weight_h_data[j];
           }
           weight_h_data += 3 * this->OC;
         }
@@ -273,11 +274,34 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
     auto input_memory_p =
         handler.AcquireInputMemoryWithReorder(input, is_reverse);
-    auto h0_memory_p = handler.AcquireH0Memory(h0);
-    auto weight_x_memory_p =
-        handler.AcquireWeightXMemory(weight_x, origin_mode);
-    auto weight_h_memory_p =
-        handler.AcquireWeightHMemory(weight_h, origin_mode);
+
+    std::shared_ptr<dnnl::memory> h0_memory_p, weight_h_memory_p,
+        weight_x_memory_p;
+
+    if (weight_h->type() == paddle::framework::proto::VarType_Type_FP32) {
+      h0_memory_p = handler.template AcquireH0Memory<float>(h0);
+      weight_x_memory_p =
+          handler.template AcquireWeightXMemory<float>(weight_x, origin_mode);
+      weight_h_memory_p =
+          handler.template AcquireWeightHMemory<float>(weight_h, origin_mode);
+    } else if (weight_h->type() ==
+               paddle::framework::proto::VarType_Type_BF16) {
+      h0_memory_p =
+          handler.template AcquireH0Memory<paddle::platform::bfloat16>(h0);
+      weight_x_memory_p =
+          handler.template AcquireWeightXMemory<paddle::platform::bfloat16>(
+              weight_x, origin_mode);
+      weight_h_memory_p =
+          handler.template AcquireWeightHMemory<paddle::platform::bfloat16>(
+              weight_h, origin_mode);
+    } else {
+      h0_memory_p = handler.template AcquireH0Memory<uint8_t>(h0);
+      weight_x_memory_p =
+          handler.template AcquireWeightXMemory<int8_t>(weight_x, origin_mode);
+      weight_h_memory_p =
+          handler.template AcquireWeightHMemory<int8_t>(weight_h, origin_mode);
+    }
+
     auto bias_memory_p = handler.AcquireBiasMemory(bias, origin_mode);
     auto hidden_onednn_memory_p = handler.AcquireOutputMemory();
 
diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc
index f5ad0644c6a..cf39968a900 100644
--- a/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc
+++ b/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc
@@ -81,8 +81,11 @@ class LSTMMKLDNNHandler
                                      MKLDNNMemoryFormat::tnc);
       auto h0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType<T>(),
                                  MKLDNNMemoryFormat::ldnc);
-      auto c0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType<T>(),
-                                 MKLDNNMemoryFormat::ldnc);
+      auto c0_md = MKLDNNMemDesc(
+          {L, D, N, OC}, MKLDNNGetDataType<float>(),  // Vanilla LSTM and LSTM
+                                                      // with peepholes have c0
+                                                      // as fp32
+          MKLDNNMemoryFormat::ldnc);
 
       // Create LSTM oneDNN primitive
       const auto direction =
@@ -110,13 +113,14 @@ class LSTMMKLDNNHandler
   // needed
   // PaddlePaddle: {c, i, f, o}
   // oneDNN: {i, f, c, o}
-  void ReorderGates(float* weights, int64_t I) {
+  template <typename U>
+  void ReorderGates(U* weights, int64_t I) {
     size_t inner_block_size = this->OC;
     size_t block_size = inner_block_size * this->G;
     for (size_t i = 0; i < (size_t)I; ++i) {
       size_t offset = i * block_size;
 
-      float* base_pos = weights + offset;
+      U* base_pos = weights + offset;
       std::swap_ranges(base_pos, base_pos + inner_block_size,
                        base_pos + inner_block_size);  // c <-> i
       std::swap_ranges(base_pos + inner_block_size,
@@ -125,6 +129,7 @@ class LSTMMKLDNNHandler
     }
   }
 
+  template <typename U>
   std::shared_ptr<dnnl::memory> AcquireWeightXMemory(const Tensor* weight_x) {
     const std::string wx_key = this->memory_key_ + "@weight_x";
     auto memory_p =
@@ -133,13 +138,12 @@ class LSTMMKLDNNHandler
 
     if (!memory_p) {
       auto user_md =
           MKLDNNMemDesc({1, 1, this->IC, this->G, this->OC},
-                        MKLDNNGetDataType<float>(), MKLDNNMemoryFormat::ldigo);
+                        MKLDNNGetDataType<U>(), MKLDNNMemoryFormat::ldigo);
       auto user_memory = dnnl::memory(user_md, this->engine_);
 
-      auto* weight_x_data =
-          reinterpret_cast<float*>(user_memory.get_data_handle());
-      memcpy(weight_x_data, weight_x->data<float>(),
-             sizeof(float) * this->IC * this->G * this->OC);
+      auto* weight_x_data = reinterpret_cast<U*>(user_memory.get_data_handle());
+      memcpy(weight_x_data, weight_x->data<U>(),
+             sizeof(U) * this->IC * this->G * this->OC);
 
       ReorderGates(weight_x_data, this->IC);
 
@@ -155,6 +159,7 @@ class LSTMMKLDNNHandler
     return memory_p;
   }
 
+  template <typename U>
   std::shared_ptr<dnnl::memory> AcquireWeightHMemory(const Tensor* weight_h) {
     const std::string wh_key = this->memory_key_ + "@weight_h";
     auto memory_p =
@@ -163,13 +168,12 @@ class LSTMMKLDNNHandler
 
     if (!memory_p) {
       auto user_md =
           MKLDNNMemDesc({1, 1, this->OC, this->G, this->OC},
-                        MKLDNNGetDataType<float>(), MKLDNNMemoryFormat::ldigo);
+                        MKLDNNGetDataType<U>(), MKLDNNMemoryFormat::ldigo);
       auto user_memory = dnnl::memory(user_md, this->engine_);
 
-      auto* weight_h_data =
-          reinterpret_cast<float*>(user_memory.get_data_handle());
-      memcpy(weight_h_data, weight_h->data<float>(),
-             sizeof(float) * this->OC * this->G * this->OC);
+      auto* weight_h_data = reinterpret_cast<U*>(user_memory.get_data_handle());
+      memcpy(weight_h_data, weight_h->data<U>(),
+             sizeof(U) * this->OC * this->G * this->OC);
 
       ReorderGates(weight_h_data, this->OC);
 
@@ -258,8 +262,8 @@ class LSTMMKLDNNHandler
         memset(user_c0_memory.get_data_handle(), 0,
                sizeof(float) * this->N * this->OC);
       }
-      memory_p = std::make_shared<dnnl::memory>(this->fwd_pd_->src_iter_desc(),
-                                                this->engine_);
+      memory_p = std::make_shared<dnnl::memory>(
+          this->fwd_pd_->src_iter_c_desc(), this->engine_);
 
       auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
       dnnl::reorder(user_c0_memory, *memory_p, this->attr_)
@@ -275,7 +279,15 @@ template <typename T>
 class FusionLSTMMKLDNNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    RunKernel<T>(ctx);
+    const bool is_bf16 = std::is_same<T, paddle::platform::bfloat16>::value;
+    const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
+
+    // BF16 does not support force output
+    if (!is_bf16 && force_fp32_output) {
+      RunKernel<float>(ctx);
+    } else {
+      RunKernel<T>(ctx);
+    }
   }
 
   template <typename Tout = T>
@@ -327,10 +339,29 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel<T> {
     auto input_memory_p =
         handler.AcquireInputMemoryWithReorder(input, is_reverse);
-    auto h0_memory_p = handler.AcquireH0Memory(h0);
     auto c0_memory_p = handler.AcquireC0Memory(c0);
-    auto weight_x_memory_p = handler.AcquireWeightXMemory(weight_x);
-    auto weight_h_memory_p = handler.AcquireWeightHMemory(weight_h);
+
+    std::shared_ptr<dnnl::memory> h0_memory_p, weight_h_memory_p,
+        weight_x_memory_p;
+
+    if (weight_h->type() == paddle::framework::proto::VarType_Type_FP32) {
+      h0_memory_p = handler.template AcquireH0Memory<float>(h0);
+      weight_x_memory_p =
+          handler.template AcquireWeightXMemory<float>(weight_x);
+      weight_h_memory_p =
+          handler.template AcquireWeightHMemory<float>(weight_h);
+    } else if (weight_h->type() ==
+               paddle::framework::proto::VarType_Type_BF16) {
+      h0_memory_p =
+          handler.template AcquireH0Memory<paddle::platform::bfloat16>(h0);
+      weight_x_memory_p =
+          handler.template AcquireWeightXMemory<paddle::platform::bfloat16>(
+              weight_x);
+      weight_h_memory_p =
+          handler.template AcquireWeightHMemory<paddle::platform::bfloat16>(
+              weight_h);
+    }
+
     auto bias_memory_p = handler.AcquireBiasMemory(bias);
     auto hidden_onednn_memory_p = handler.AcquireOutputMemory();
 
@@ -374,4 +405,5 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 
 REGISTER_OP_KERNEL(fusion_lstm, MKLDNN, paddle::platform::CPUPlace,
-                   ops::FusionLSTMMKLDNNKernel<float>);
+                   ops::FusionLSTMMKLDNNKernel<float>,
+                   ops::FusionLSTMMKLDNNKernel<paddle::platform::bfloat16>);
diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h b/paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h
index f102c535fdf..5ef84eac4e6 100644
--- a/paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h
+++ b/paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h
@@ -179,6 +179,7 @@ class RNNMKLDNNHandler : public platform::MKLDNNHandlerT<T, T_alg> {
   // TODO(grygielski) H0 is for now persistable
   // TODO(jczaja) H0 should be updated each iter and of T type (Fusion pass does
   // not support in yet)
+  template <typename U>
   std::shared_ptr<dnnl::memory> AcquireH0Memory(const Tensor* h0) {
     const std::string h0_key = memory_key_ + "@h0";
     auto memory_p =
@@ -187,17 +188,14 @@ class RNNMKLDNNHandler : public platform::MKLDNNHandlerT<T, T_alg> {
 
     if (!memory_p) {
       auto user_h0_memory = dnnl::memory();
       if (h0) {
-        user_h0_memory =
-            dnnl::memory({{1, 1, N, OC},
-                          MKLDNNGetDataType<float>(),
-                          MKLDNNMemoryFormat::ldnc},
-                         this->engine_, to_void_cast(h0->data<float>()));
+        user_h0_memory = dnnl::memory(
+            {{1, 1, N, OC}, MKLDNNGetDataType<U>(), MKLDNNMemoryFormat::ldnc},
+            this->engine_, to_void_cast(h0->data<U>()));
       } else {
-        user_h0_memory = dnnl::memory({{1, 1, N, OC},
-                                       MKLDNNGetDataType<float>(),
-                                       MKLDNNMemoryFormat::ldnc},
-                                      this->engine_);
-        memset(user_h0_memory.get_data_handle(), 0, sizeof(float) * N * OC);
+        user_h0_memory = dnnl::memory(
+            {{1, 1, N, OC}, MKLDNNGetDataType<U>(), MKLDNNMemoryFormat::ldnc},
+            this->engine_);
+        memset(user_h0_memory.get_data_handle(), 0, sizeof(U) * N * OC);
       }
       memory_p = std::make_shared<dnnl::memory>(this->fwd_pd_->src_iter_desc(),
                                                 this->engine_);
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_bf16_mkldnn_op.py
index 90140a3474f..c024ffbdb4b 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_bf16_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_bf16_mkldnn_op.py
@@ -30,6 +30,11 @@ class TestFusionGRUBF16MKLDNNOp(OpTest):
     def set_confs(self):
         self.mkldnn_data_type = False
 
+    def test_check_output(self):
+        for use_seq in {True, False}:
+            self.attrs['use_seq'] = use_seq
+            self.check_output(check_dygraph=False)
+
     def setUp(self):
         self.op_type = "fusion_gru"
         self.lod = [[2, 4, 3]]
@@ -45,6 +50,7 @@ class TestFusionGRUBF16MKLDNNOp(OpTest):
         self.origin_mode = False
         self.use_mkldnn = True
         self.force_fp32_output = False
+        self.weights_dtype = 'fp32'
         self.set_confs()
 
         T = sum(self.lod[0])
@@ -58,6 +64,9 @@ class TestFusionGRUBF16MKLDNNOp(OpTest):
         wx_fp32 = np.random.rand(self.M, 3 * self.D).astype('float32')
         wh_fp32 = np.random.rand(self.D, 3 * self.D).astype('float32')
 
+        wx_bf16 = convert_float_to_uint16(wx_fp32)
+        wh_bf16 = convert_float_to_uint16(wh_fp32)
+
         # bias is fp32 despite other inputs being in bf16
         bias = np.random.rand(
             1, 3 * self.D).astype('float32') if self.with_bias else np.zeros(
@@ -74,20 +83,30 @@ class TestFusionGRUBF16MKLDNNOp(OpTest):
 
         hidden_bf16 = convert_float_to_uint16(hidden)
 
-        self.inputs = {
-            'X': (x_bf16, self.lod),
-            'WeightX': wx_fp32,
-            'WeightH': wh_fp32
-        }
+        if self.weights_dtype == 'bf16':
+            self.inputs = {
+                'X': (x_bf16, self.lod),
+                'WeightX': wx_bf16,
+                'WeightH': wh_bf16
+            }
+        elif self.weights_dtype == 'fp32':
+            self.inputs = {
+                'X': (x_bf16, self.lod),
+                'WeightX': wx_fp32,
+                'WeightH': wh_fp32
+            }
 
         if self.with_bias:
             self.inputs['Bias'] = bias
 
         if self.with_h0:
-            self.inputs['H0'] = h0_bf16
+            if self.weights_dtype == 'bf16':
+                self.inputs['H0'] = h0_bf16
+            elif self.weights_dtype == 'fp32':
+                self.inputs['H0'] = h0_fp32
 
         h0_bf16 = convert_float_to_uint16(h0_fp32)
 
-        self.outputs = {'Hidden': (hidden_bf16, self.lod)}
+        self.outputs = {'Hidden': (hidden, self.lod)}
 
         self.attrs = {
             'activation': self.act_state,
@@ -109,6 +128,11 @@ class TestFusionGRUINT8MKLDNNOp3(TestFusionGRUBF16MKLDNNOp):
         self.with_bias = False
 
 
+class TestFusionGRUINT8MKLDNNBF16WeightsOp(TestFusionGRUBF16MKLDNNOp):
+    def set_confs(self):
+        self.weights_dtype = 'bf16'
+
+
 if __name__ == "__main__":
     from paddle import enable_static
     enable_static()
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_int8_mkldnn_op.py
index 89343c9fae4..2d3caf0be97 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_int8_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_int8_mkldnn_op.py
@@ -146,4 +146,6 @@ class TestFusionGRUINT8MKLDNNOp5(TestFusionGRUINT8MKLDNNOp):
 
 
 if __name__ == "__main__":
+    from paddle import enable_static
+    enable_static()
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_lstm_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_lstm_bf16_mkldnn_op.py
new file mode 100644
index 00000000000..46bdbb1a420
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_lstm_bf16_mkldnn_op.py
@@ -0,0 +1,159 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import struct
+import paddle.fluid.core as core
+from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16, convert_uint16_to_float
+from paddle.fluid.tests.unittests.test_fusion_lstm_op import TestFusionLSTMOp, fc, ACTIVATION, fusion_lstm
+from paddle.fluid.tests.unittests.test_fusion_gru_op import fusion_gru
+
+
+@unittest.skipIf(not core.supports_bfloat16(),
+                 "place does not support BF16 evaluation")
+class TestFusionLSTMBF16ONEDNNOp(OpTest):
+    def set_confs(self):
+        self.mkldnn_data_type = False
+
+    def test_check_output(self):
+        for use_seq in {True, False}:
+            self.attrs['use_seq'] = use_seq
+            self.check_output(check_dygraph=False, no_check_set=["Cell"])
+
+    def setUp(self):
+        self.op_type = 'fusion_lstm'
+        self.lod = [[2, 3, 5, 4]]
+        self.M = 8
+        self.D = 16
+        self.has_initial_state = False
+        self.use_peepholes = False
+        self.is_reverse = False
+        self._cpu_only = True
+        self.act_gate = 'sigmoid'
+        self.act_cell = 'tanh'
+        self.act_cand = 'tanh'
+        self.use_mkldnn = True
+        self.force_fp32_output = False
+        self.weights_dtype = 'fp32'
+        self.set_confs()
+
+        T = sum(self.lod[0])
+        bs = len(self.lod[0])
+
+        # fp32 X input for reference implementation and
+        # corresponding bf16 data as input to LSTM oneDNN bf16 kernel
+        x = np.random.normal(size=(T, self.M)).astype('float32')
+
+        x_bf16 = convert_float_to_uint16(x)
+
+        if self.has_initial_state:
+            h0 = np.random.normal(size=(bs, self.D)).astype('float32')
+            c0 = np.random.normal(size=(bs, self.D)).astype('float32')
+        else:
+            h0 = np.zeros((bs, self.D)).astype('float32')
+            c0 = np.zeros((bs, self.D)).astype('float32')
+
+        wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float32')
+
+        h0_bf16 = convert_float_to_uint16(h0)
+
+        if self.use_peepholes:
+            b = np.random.normal(size=(1, 7 * self.D)).astype('float32')
+        else:
+            b = np.random.normal(size=(1, 4 * self.D)).astype('float32')
+        w_b = np.copy(b[:, 0:4 * self.D])
+        w_c = b[:, 4 * self.D:] if self.use_peepholes else None
+
+        wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float32')
+
+        wx_bf16 = convert_float_to_uint16(wx)
+        wh_bf16 = convert_float_to_uint16(wh)
+
+        bx = np.random.normal(size=(1, 4 * self.D)).astype('float32')
+        b[0, 0:4 * self.D] += bx[0, :]
+
+        hidden, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c,
+                                self.is_reverse, ACTIVATION[self.act_gate],
+                                ACTIVATION[self.act_cell],
+                                ACTIVATION[self.act_cand])
+
+        hidden = hidden.astype('float32')
+        hidden_bf16 = convert_float_to_uint16(hidden)
+
+        if self.weights_dtype == 'bf16':
+            self.inputs = {
+                'X': (x_bf16, self.lod),
+                'WeightX': wx_bf16,
+                'WeightH': wh_bf16,
+                'Bias': b
+            }
+        elif self.weights_dtype == 'fp32':
+            self.inputs = {
+                'X': (x_bf16, self.lod),
+                'WeightX': wx,
+                'WeightH': wh,
+                'Bias': b
+            }
+
+        if self.has_initial_state:
+            if self.weights_dtype == 'bf16':
+                self.inputs['H0'] = h0_bf16
+            elif self.weights_dtype == 'fp32':
+                self.inputs['H0'] = h0
+
+            self.inputs['C0'] = c0
+
+        self.outputs = {
+            'Hidden': (hidden, self.lod),
+            'Cell': (c, self.lod),
+        }
+
+        self.attrs = {
+            'use_peepholes': self.use_peepholes,
+            'is_reverse': self.is_reverse,
+            'gate_activation': self.act_gate,
+            'cell_activation': self.act_cell,
+            'candidate_activation': self.act_cand,
+            'force_fp32_output': self.force_fp32_output,
+            'use_mkldnn': self.use_mkldnn
+        }
+
+
+class TestFusionLSTMBF16ONEDNNPeepholesOp(TestFusionLSTMBF16ONEDNNOp):
+    def set_confs(self):
+        self.use_peepholes = True
+
+
+class TestFusionLSTMBF16ONEDNNInitializedStateOp(TestFusionLSTMBF16ONEDNNOp):
+    def set_confs(self):
+        self.has_initial_state = True
+
+
+class TestFusionLSTMBF16ONEDNNReverseOp(TestFusionLSTMBF16ONEDNNOp):
+    def set_confs(self):
+        self.is_reverse = True
+
+
+class TestFusionLSTMBF16ONEDNNBF16WeightsOp(TestFusionLSTMBF16ONEDNNOp):
+    def set_confs(self):
+        self.weights_dtype = 'bf16'
+
+
+if __name__ == "__main__":
+    from paddle import enable_static
+    enable_static()
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index f5c58eb4517..47c187a80c8 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -235,6 +235,19 @@ def convert_float_to_uint16(float_list, data_format="NCHW"):
     return new_output
 
 
+def copy_bits_from_uint16_to_float(i):
+    i = np.uint32(i) << 16
+    return struct.unpack('