diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 449881a9f8feb88d80527ded9a1fecb1d7aaf6a7..ed2863e8bf7985462c0ca400095553b42d132569 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1894,7 +1894,8 @@ PDNode *patterns::QuantizePlacement::operator()( PDNode *patterns::Bfloat16Placement::operator()( const std::unordered_set &bfloat16_enabled_op_types) { - std::unordered_set supported_op_types{"conv2d"}; + std::unordered_set supported_op_types = + std::unordered_set({"conv2d", "fusion_gru"}); if (!bfloat16_enabled_op_types.empty()) { supported_op_types = bfloat16_enabled_op_types; } diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc index 58ecc6731f00bd475c10443ad9bc08dcb2d31a85..e51d94e4b1e05a9b394e96fd2c0e561b46453793 100644 --- a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc +++ b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc @@ -86,7 +86,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { // Weights for int8 kernel are of a type s8 const auto weights_dt = - is_INT8 ? dnnl::memory::data_type::s8 : dnnl::memory::data_type::f32; + is_INT8 ? dnnl::memory::data_type::s8 : MKLDNNGetDataType(); // oneDNN RNN dimensions const int64_t D = 1; // Directions @@ -226,6 +226,8 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { } // TODO(grygielski) H0 is for now persistable + // TODO(jczaja) H0 should be updated each iter and of T type (Fusion pass does + // not support in yet) std::shared_ptr AcquireH0Memory(const Tensor* h0) { const std::string h0_key = memory_key_ + "@h0"; auto memory_p = @@ -397,14 +399,14 @@ template class FusionGRUMKLDNNKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - const bool is_INT8 = std::is_same::value; + const bool is_bf16 = std::is_same::value; const bool force_fp32_output = ctx.Attr("force_fp32_output"); - // TODO(grygielski) Add option for bfloat - if (!is_INT8 || force_fp32_output) { + // BF16 does not support force output + if (!is_bf16 && force_fp32_output) { RunKernel(ctx); } else { - RunKernel(ctx); + RunKernel(ctx); } } @@ -495,4 +497,5 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_KERNEL(fusion_gru, MKLDNN, paddle::platform::CPUPlace, ops::FusionGRUMKLDNNKernel, + ops::FusionGRUMKLDNNKernel, ops::FusionGRUMKLDNNKernel); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_bf16_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..83b636650ab41cbd6b31677860b944497bfd4aa3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_bf16_mkldnn_op.py @@ -0,0 +1,113 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import struct +import paddle.fluid.core as core +from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16 +from paddle.fluid.tests.unittests.op_test import OpTest +from paddle.fluid.tests.unittests.test_fusion_gru_op import fusion_gru +from paddle.fluid.tests.unittests.test_fusion_lstm_op import fc, ACTIVATION + + +@unittest.skipIf(not core.supports_bfloat16(), + "place does not support BF16 evaluation") +class TestFusionGRUBF16MKLDNNOp(OpTest): + def set_confs(self): + self.mkldnn_data_type = False + + def setUp(self): + self.op_type = "fusion_gru" + self.lod = [[2, 4, 3]] + self.M = 3 + self.D = 5 + self.is_reverse = False + self.with_h0 = False + self.use_mkldnn = True + self._cpu_only = True + self.with_bias = True + self.act_state = 'tanh' + self.act_gate = 'sigmoid' + self.origin_mode = False + self.use_mkldnn = True + self.force_fp32_output = False + self.set_confs() + + T = sum(self.lod[0]) + N = len(self.lod[0]) + + # fp32 X input for reference implementation and + # corressponding bf16 data as input to GRU oneDNN bf16 kernel + x_fp32 = np.random.rand(T, self.M).astype('float32') + x_bf16 = convert_float_to_uint16(x_fp32) + + wx_fp32 = np.random.rand(self.M, 3 * self.D).astype('float32') + wh_fp32 = np.random.rand(self.D, 3 * self.D).astype('float32') + + # bias is fp32 despite other inputs being in bf16 + bias = np.random.rand( + 1, 3 * self.D).astype('float32') if self.with_bias else np.zeros( + (1, 3 * self.D), dtype='float32') + + h0_fp32 = np.random.rand( + N, self.D).astype('float32') if self.with_h0 else np.zeros( + (N, self.D), dtype='float32') + + _, _, _, hidden = fusion_gru( + x_fp32, self.lod, h0_fp32, wx_fp32, wh_fp32, bias, self.is_reverse, + self.origin_mode, ACTIVATION[self.act_state], + ACTIVATION[self.act_gate]) + + hidden_bf16 = convert_float_to_uint16(hidden) + + self.inputs = { + 'X': (x_bf16, self.lod), + 'WeightX': wx_fp32, + 'WeightH': wh_fp32 + } + + if self.with_bias: + self.inputs['Bias'] = bias + + if self.with_h0: + self.inputs['H0'] = h0_bf16 + + h0_bf16 = convert_float_to_uint16(h0_fp32) + self.outputs = {'Hidden': (hidden_bf16, self.lod)} + + self.attrs = { + 'activation': self.act_state, + 'gate_activation': self.act_gate, + 'is_reverse': self.is_reverse, + 'origin_mode': self.origin_mode, + 'force_fp32_output': self.force_fp32_output, + 'use_mkldnn': self.use_mkldnn + } + + +class TestFusionGRUINT8MKLDNNOp2(TestFusionGRUBF16MKLDNNOp): + def set_confs(self): + self.origin_mode = False + + +class TestFusionGRUINT8MKLDNNOp3(TestFusionGRUBF16MKLDNNOp): + def set_confs(self): + self.with_bias = False + + +if __name__ == "__main__": + unittest.main()