Unverified commit 606611d3, authored by Jacek Czaja, committed by GitHub

[oneDNN] GRU BF16 kernel (#27731)

Parent 6c1acf34
@@ -1894,7 +1894,8 @@ PDNode *patterns::QuantizePlacement::operator()(
 PDNode *patterns::Bfloat16Placement::operator()(
     const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
-  std::unordered_set<std::string> supported_op_types{"conv2d"};
+  std::unordered_set<std::string> supported_op_types =
+      std::unordered_set<std::string>({"conv2d", "fusion_gru"});
   if (!bfloat16_enabled_op_types.empty()) {
     supported_op_types = bfloat16_enabled_op_types;
   }
@@ -86,7 +86,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
     // Weights for int8 kernel are of a type s8
     const auto weights_dt =
-        is_INT8 ? dnnl::memory::data_type::s8 : dnnl::memory::data_type::f32;
+        is_INT8 ? dnnl::memory::data_type::s8 : MKLDNNGetDataType<T>();

     // oneDNN RNN dimensions
     const int64_t D = 1;  // Directions
@@ -226,6 +226,8 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
   }

   // TODO(grygielski) H0 is for now persistable
+  // TODO(jczaja) H0 should be updated each iter and of T type (Fusion pass does
+  // not support it yet)
   std::shared_ptr<dnnl::memory> AcquireH0Memory(const Tensor* h0) {
     const std::string h0_key = memory_key_ + "@h0";
     auto memory_p =
@@ -397,14 +399,14 @@ template <typename T>
 class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    const bool is_INT8 = std::is_same<T, uint8_t>::value;
+    const bool is_bf16 = std::is_same<T, paddle::platform::bfloat16>::value;
     const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");

-    // TODO(grygielski) Add option for bfloat
-    if (!is_INT8 || force_fp32_output) {
+    // BF16 does not support force output
+    if (!is_bf16 && force_fp32_output) {
       RunKernel<float>(ctx);
     } else {
-      RunKernel<uint8_t>(ctx);
+      RunKernel<T>(ctx);
     }
   }
@@ -495,4 +497,5 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OP_KERNEL(fusion_gru, MKLDNN, paddle::platform::CPUPlace,
                    ops::FusionGRUMKLDNNKernel<float>,
+                   ops::FusionGRUMKLDNNKernel<paddle::platform::bfloat16>,
                    ops::FusionGRUMKLDNNKernel<uint8_t>);
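To make the dispatch change above easier to follow, here is a minimal plain-Python restatement of the kernel-type decision in FusionGRUMKLDNNKernel::Compute. It is a sketch of the logic only, not Paddle API; the function name and dtype strings are illustrative.

def pick_fusion_gru_kernel_dtype(registered_dtype, force_fp32_output):
    # registered_dtype mirrors the T the kernel was registered with:
    # "float32", "uint8" (int8 path) or "bfloat16".
    is_bf16 = registered_dtype == "bfloat16"
    # force_fp32_output is honoured only on the non-bf16 paths; the bf16
    # kernel always computes and outputs in its own type.
    if not is_bf16 and force_fp32_output:
        return "float32"
    return registered_dtype

assert pick_fusion_gru_kernel_dtype("uint8", True) == "float32"
assert pick_fusion_gru_kernel_dtype("bfloat16", True) == "bfloat16"
assert pick_fusion_gru_kernel_dtype("bfloat16", False) == "bfloat16"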
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import struct
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle.fluid.tests.unittests.test_fusion_gru_op import fusion_gru
from paddle.fluid.tests.unittests.test_fusion_lstm_op import fc, ACTIVATION

@unittest.skipIf(not core.supports_bfloat16(),
                 "place does not support BF16 evaluation")
class TestFusionGRUBF16MKLDNNOp(OpTest):
    def set_confs(self):
        self.mkldnn_data_type = False

    def setUp(self):
        self.op_type = "fusion_gru"
        self.lod = [[2, 4, 3]]
        self.M = 3
        self.D = 5
        self.is_reverse = False
        self.with_h0 = False
        self.use_mkldnn = True
        self._cpu_only = True
        self.with_bias = True
        self.act_state = 'tanh'
        self.act_gate = 'sigmoid'
        self.origin_mode = False
        self.force_fp32_output = False
        self.set_confs()

        T = sum(self.lod[0])
        N = len(self.lod[0])
        # fp32 X input for the reference implementation and the
        # corresponding bf16 data as input to the GRU oneDNN bf16 kernel
        # (a standalone sketch of this conversion follows the tests below)
        x_fp32 = np.random.rand(T, self.M).astype('float32')
        x_bf16 = convert_float_to_uint16(x_fp32)

        wx_fp32 = np.random.rand(self.M, 3 * self.D).astype('float32')
        wh_fp32 = np.random.rand(self.D, 3 * self.D).astype('float32')

        # bias is fp32 despite other inputs being in bf16
        bias = np.random.rand(
            1, 3 * self.D).astype('float32') if self.with_bias else np.zeros(
                (1, 3 * self.D), dtype='float32')

        h0_fp32 = np.random.rand(
            N, self.D).astype('float32') if self.with_h0 else np.zeros(
                (N, self.D), dtype='float32')

        _, _, _, hidden = fusion_gru(
            x_fp32, self.lod, h0_fp32, wx_fp32, wh_fp32, bias, self.is_reverse,
            self.origin_mode, ACTIVATION[self.act_state],
            ACTIVATION[self.act_gate])

        hidden_bf16 = convert_float_to_uint16(hidden)

        self.inputs = {
            'X': (x_bf16, self.lod),
            'WeightX': wx_fp32,
            'WeightH': wh_fp32
        }

        if self.with_bias:
            self.inputs['Bias'] = bias

        if self.with_h0:
            # convert h0 only when it is actually fed to the op
            h0_bf16 = convert_float_to_uint16(h0_fp32)
            self.inputs['H0'] = h0_bf16

        self.outputs = {'Hidden': (hidden_bf16, self.lod)}

        self.attrs = {
            'activation': self.act_state,
            'gate_activation': self.act_gate,
            'is_reverse': self.is_reverse,
            'origin_mode': self.origin_mode,
            'force_fp32_output': self.force_fp32_output,
            'use_mkldnn': self.use_mkldnn
        }

class TestFusionGRUBF16MKLDNNOp2(TestFusionGRUBF16MKLDNNOp):
    def set_confs(self):
        self.origin_mode = False


class TestFusionGRUBF16MKLDNNOp3(TestFusionGRUBF16MKLDNNOp):
    def set_confs(self):
        self.with_bias = False


if __name__ == "__main__":
    unittest.main()
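For context, here is a minimal standalone sketch (not part of the patch) of the fp32/bf16 representation the test relies on: bf16 values travel through the test as uint16 arrays holding the upper 16 bits of the fp32 bit pattern. The helper names below are illustrative, and the assumption that convert_float_to_uint16 truncates the low mantissa bits (rather than rounding) is mine.

import numpy as np


def float_to_bf16_uint16(x_fp32):
    # bfloat16 keeps the sign, the full fp32 exponent and the top 7 mantissa
    # bits, i.e. the upper 16 bits of the 32-bit pattern, stored as uint16.
    x_fp32 = np.ascontiguousarray(x_fp32, dtype=np.float32)
    return (x_fp32.view(np.uint32) >> 16).astype(np.uint16)


def bf16_uint16_to_float(x_bf16):
    # Widen back to fp32 by placing the stored 16 bits in the upper half
    # and zero-filling the discarded low mantissa bits.
    x_bf16 = np.ascontiguousarray(x_bf16, dtype=np.uint16)
    return (x_bf16.astype(np.uint32) << 16).view(np.float32)


if __name__ == "__main__":
    x = np.random.rand(4).astype(np.float32)
    # Round-trip error is bounded by the 8-bit bf16 significand
    # (roughly 2**-8 relative error).
    print(np.max(np.abs(x - bf16_uint16_to_float(float_to_bf16_uint16(x)))))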