Unverified commit 7821759d, authored by joanna.wozna.intel, committed by GitHub

Add bfloat16 softmax and gelu (#28394)

* Add bfloat16 softmax and gelu

* Add pass attr bfloat16_enabled_op_types

* Changes from review
Parent ba036b88
@@ -2101,8 +2101,9 @@ PDNode *patterns::QuantizePlacement::operator()(
PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>(
{"concat", "conv2d", "fusion_gru", "reshape2", "transpose2", "sum"});
std::unordered_set<std::string>({"concat", "conv2d", "fusion_gru", "gelu",
"reshape2", "softmax", "sum",
"transpose2"});
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
......
@@ -33,7 +33,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
if (type == "conv2d") {
op->SetAttr("name", name);
op->SetInput("Input", {inputs[0]});
} else if (type == "relu") {
} else if (type == "gelu") {
op->SetInput("X", inputs);
} else if (type == "concat") {
op->SetAttr("axis", 1);
@@ -71,7 +71,7 @@ ProgramDesc BuildProgramDesc() {
SetOp(&prog, "concat", "concat1", {"a", "b"}, {"c"});
SetOp(&prog, "conv2d", "conv1", {"c"}, {"f"});
SetOp(&prog, "relu", "relu1", {"f"}, {"g"});
SetOp(&prog, "gelu", "gelu1", {"f"}, {"g"});
SetOp(&prog, "pool2d", "pool1", {"g"}, {"h"});
SetOp(&prog, "conv2d", "conv2", {"h"}, {"k"});
SetOp(&prog, "pool2d", "pool2", {"k"}, {"l"});
@@ -126,7 +126,7 @@ void DefaultAttrTest(unsigned expected_bfloat16_data_type_count) {
}
TEST(Bfloat16PlacementPass, enable_all) {
MainTest({"conv2d", "pool2d", "relu", "concat", "sum"}, 8);
MainTest({"conv2d", "pool2d", "gelu", "concat", "sum"}, 8);
}
TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
@@ -134,7 +134,7 @@ TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
MainTest({"conv2d", "pool2d"}, 3);
}
TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(6); }
TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(7); }
} // namespace ir
} // namespace framework
......
@@ -79,6 +79,10 @@ void IRPassManager::CreatePasses(Argument *argument,
} else if (pass_name == "cpu_quantize_pass") {
pass->Set("quant_var_scales",
new VarQuantScale(argument->quant_var_scales()));
} else if (pass_name == "cpu_bfloat16_placement_pass") {
pass->Set("bfloat16_enabled_op_types",
new std::unordered_set<std::string>(
argument->bfloat16_enabled_op_types()));
#endif
} else if (pass_name == "tensorrt_subgraph_pass") {
pass->Set("workspace_size", new int(argument->tensorrt_workspace_size()));
......
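For context, the op-type set wired through here ultimately comes from the user-facing inference config: the analysis Argument carries it down to cpu_bfloat16_placement_pass. A minimal Python sketch of turning the feature on, assuming the enable_mkldnn() and enable_mkldnn_bfloat16() switches exposed on paddle.inference.Config (the setter that narrows the op-type list is not part of this diff, so it is left out):

    import paddle.inference as paddle_infer

    # Hypothetical model files; any saved inference model would do.
    config = paddle_infer.Config("model/__model__", "model/__params__")
    config.enable_mkldnn()            # bfloat16 kernels run on the oneDNN path
    config.enable_mkldnn_bfloat16()   # schedules cpu_bfloat16_placement_pass
    predictor = paddle_infer.create_predictor(config)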
@@ -111,6 +111,11 @@ class GeluOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "int8", "bfloat16"});
AddAttr<bool>("use_cudnn",
"(bool, default false) Only used in cudnn kernel, need "
"install cudnn")
......
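The new mkldnn_data_type attribute is the hook the placement pass flips: ops whose type is in the enabled set get it switched from "float32" to "bfloat16". A rough sketch of that decision, with the default set copied from the Bfloat16Placement pattern above and set_attr standing in (hypothetically) for the graph mutation the C++ pass performs:

    SUPPORTED = {"concat", "conv2d", "fusion_gru", "gelu",
                 "reshape2", "softmax", "sum", "transpose2"}

    def mark_bfloat16(op_type, enabled_op_types, set_attr):
        # An empty user-supplied set falls back to the built-in default,
        # mirroring patterns::Bfloat16Placement; set_attr is a hypothetical
        # callable that writes an attribute on the matched op.
        allowed = enabled_op_types if enabled_op_types else SUPPORTED
        if op_type in allowed:
            set_attr("mkldnn_data_type", "bfloat16")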
@@ -83,14 +83,14 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
const auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Output<Tensor>("Out");
T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// paddle uses beta but mkldnn uses alpha for swish
if (algorithm == mkldnn::algorithm::eltwise_swish) {
std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<T>("threshold");
alpha = ctx.Attr<float>("threshold");
}
PADDLE_ENFORCE(
@@ -128,14 +128,14 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
T alpha = ctx.HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
T beta = ctx.HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// paddle uses beta but mkldnn uses alpha for swish
if (algorithm == mkldnn::algorithm::eltwise_swish) {
std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<T>("threshold");
alpha = ctx.Attr<float>("threshold");
}
auto diff_dst_tz = framework::vectorize<int64_t>(diff_y->dims());
@@ -272,11 +272,20 @@ namespace ops = paddle::operators;
act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);
#define REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(act_type, functor, \
grad_functor) \
REGISTER_OP_KERNEL( \
act_type, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationKernel<ops::functor<float>>, \
ops::MKLDNNActivationKernel<ops::functor<paddle::platform::bfloat16>>); \
REGISTER_OP_KERNEL( \
act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
__macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor); \
__macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(gelu, GeluMKLDNNFunctor, GeluMKLDNNGradFunctor); \
__macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor); \
__macro(sigmoid, SigmoidMKLDNNFunctor, SigmoidMKLDNNGradFunctor); \
__macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradFunctor); \
@@ -284,3 +293,5 @@ namespace ops = paddle::operators;
__macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor);
FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(gelu, GeluMKLDNNFunctor,
GeluMKLDNNGradFunctor);
@@ -181,6 +181,7 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(softmax, MKLDNN, ::paddle::platform::CPUPlace,
ops::SoftmaxMKLDNNKernel<float>);
ops::SoftmaxMKLDNNKernel<float>,
ops::SoftmaxMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(softmax_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::SoftmaxMKLDNNGradKernel<float>);
@@ -115,6 +115,11 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "bfloat16"});
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
......
@@ -18,7 +18,7 @@ import unittest
import numpy as np
from scipy.special import expit
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle.fluid.tests.unittests.test_activation_op import TestActivation, TestRelu, TestTanh, TestSqrt, TestAbs, TestLeakyRelu, TestSwish, TestRelu6, TestSigmoid
from paddle.fluid.tests.unittests.test_gelu_op import gelu
from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd
@@ -79,6 +79,44 @@ class TestMKLDNNGeluDim2Approx(TestActivation):
self.attrs = {"use_mkldnn": True, "approximate": True}
class TestMKLDNNGeluBf16Dim2(TestActivation):
def setUp(self):
self.op_type = "gelu"
self.dtype = np.uint16
x = np.random.uniform(-1, 1, [11, 17]).astype(np.float32)
out = convert_float_to_uint16(gelu(x, False))
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': out}
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
pass
class TestMKLDNNGeluBf16Dim2Approx(TestActivation):
def setUp(self):
self.op_type = "gelu"
self.dtype = np.uint16
x = np.random.uniform(-1, 1, [11, 17]).astype(np.float32)
out = convert_float_to_uint16(gelu(x, True))
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': out}
self.attrs = {"use_mkldnn": True, "approximate": True}
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
pass
class TestMKLDNNTanhDim2(TestTanh):
def setUp(self):
super(TestMKLDNNTanhDim2, self).setUp()
@@ -187,6 +225,44 @@ class TestMKLDNNGeluDim4Approx(TestActivation):
self.attrs = {"use_mkldnn": True, "approximate": True}
class TestMKLDNNGeluBf16Dim4(TestActivation):
def setUp(self):
self.op_type = "gelu"
self.dtype = np.uint16
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype(np.float32)
out = convert_float_to_uint16(gelu(x, False))
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': out}
self.attrs = {"use_mkldnn": True}
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
pass
class TestMKLDNNGeluBf16Dim4Approx(TestActivation):
def setUp(self):
self.op_type = "gelu"
self.dtype = np.uint16
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype(np.float32)
out = convert_float_to_uint16(gelu(x, True))
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': out}
self.attrs = {"use_mkldnn": True, "approximate": True}
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
pass
class TestMKLDNNTanhDim4(TestTanh):
def setUp(self):
super(TestMKLDNNTanhDim4, self).setUp()
......
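The bfloat16 tests above store inputs and expected outputs as uint16 via convert_float_to_uint16. A minimal sketch of what that conversion amounts to, assuming plain truncation of each float32 to its upper 16 bits (the OpTest helper may round instead):

    import numpy as np

    def float_to_bf16_bits(arr):
        # Reinterpret float32 values as uint32 and keep the upper 16 bits,
        # which is the bfloat16 bit pattern a uint16 tensor stores.
        arr = np.asarray(arr, dtype=np.float32)
        return (arr.view(np.uint32) >> 16).astype(np.uint16)

    x = np.random.uniform(-1, 1, [11, 17]).astype(np.float32)
    x_bf16 = float_to_bf16_bits(x)  # same shape, dtype uint16

test_check_grad is a deliberate no-op in these classes: as the registration macro above shows, only the forward gelu kernel gets a bfloat16 instantiation, while the grad kernel stays float-only.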
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import convert_float_to_uint16
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.test_softmax_op import TestSoftmaxOp, TestSoftmaxOp2, TestSoftmaxOp3, TestSoftmaxOp4, TestSoftmaxOp5, TestSoftmaxOp6
from paddle import enable_static
def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx = x - np.max(x).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
class TestSoftmaxMKLDNNOp(TestSoftmaxOp):
def get_x_shape(self):
return [10, 10]
def get_axis(self):
return -1
def setUp(self):
self.op_type = "softmax"
self.use_mkldnn = True
self.dtype = np.uint16
self.init_kernel_type()
self.shape = self.get_x_shape()
self.axis = self.get_axis()
x = np.random.uniform(0.1, 1, self.shape).astype(np.float32)
out = convert_float_to_uint16(
np.apply_along_axis(stable_softmax, self.axis, x))
self.inputs = {'X': convert_float_to_uint16(x)}
self.outputs = {'Out': out}
self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn}
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
pass
def init_kernel_type(self):
self.use_mkldnn = True
class TestSoftmaxMKLDNNOp2(TestSoftmaxOp2):
def init_kernel_type(self):
self.use_mkldnn = True
class TestSoftmaxMKLDNNOp3(TestSoftmaxOp3):
def init_kernel_type(self):
self.use_mkldnn = True
class TestSoftmaxMKLDNNOp4(TestSoftmaxOp4):
def init_kernel_type(self):
self.use_mkldnn = True
class TestSoftmaxMKLDNNOp5(TestSoftmaxOp5):
def init_kernel_type(self):
self.use_mkldnn = True
class TestSoftmaxMKLDNNOp6(TestSoftmaxOp6):
def init_kernel_type(self):
self.use_mkldnn = True
if __name__ == '__main__':
enable_static()
unittest.main()
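As a quick sanity check of the reference the uint16 output is compared against, stable_softmax on a small vector gives the usual softmax values (an illustration only, not part of the test file):

    import numpy as np

    x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    shiftx = x - np.max(x).clip(-64.)
    probs = np.exp(shiftx) / np.sum(np.exp(shiftx))
    # probs is approximately [0.09003057, 0.24472848, 0.66524094] and sums to 1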
@@ -601,6 +601,7 @@ STATIC_MODE_TESTING_LIST = [
'test_quantize_mkldnn_op',
'test_requantize_mkldnn_op',
'test_softmax_mkldnn_op',
'test_softmax_bf16_mkldnn_op',
'test_sum_mkldnn_op',
'test_sum_bf16_mkldnn_op',
'test_transpose_int8_mkldnn_op',
......