未验证 提交 8e67629c 编写于 作者: J jakpiase 提交者: GitHub

Added support for BF16 datatype for all oneDNN activation kernels (#40721)

* added missing BF16 activations

* added softplus bf16

* minor change

* disabled tests for GPU
上级 292011eb
......@@ -30,6 +30,21 @@ namespace operators {
class AbsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -72,8 +87,17 @@ class AbsGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(dtype, ctx.GetPlace());
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -315,15 +315,7 @@ using ExpMKLDNNGradUseOutFunctor = MKLDNNActivationGradUseOutFunc<
namespace ops = paddle::operators;
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_KERNEL(act_type, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationKernel<ops::functor<float>>); \
REGISTER_OP_KERNEL( \
act_type##_grad, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationGradKernel<ops::grad_functor<float>>);
#define REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(act_type, functor, \
grad_functor) \
#define REGISTER_ACTIVATION_MKLDNN_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_KERNEL( \
act_type, MKLDNN, ::paddle::platform::CPUPlace, \
ops::MKLDNNActivationKernel<ops::functor<float>>, \
......@@ -339,30 +331,27 @@ namespace ops = paddle::operators;
ops::MKLDNNActivationKernel<ops::functor<float>>);
#define FOR_EACH_MKLDNN_KERNEL_FUNCTOR(__macro) \
__macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor); \
__macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor); \
__macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
__macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradUseOutFunctor); \
__macro(abs, AbsMKLDNNFunctor, AbsMKLDNNGradFunctor); \
__macro(elu, EluMKLDNNFunctor, EluMKLDNNGradUseOutFunctor); \
__macro(exp, ExpMKLDNNFunctor, ExpMKLDNNGradUseOutFunctor);
__macro(exp, ExpMKLDNNFunctor, ExpMKLDNNGradUseOutFunctor); \
__macro(gelu, GeluMKLDNNFunctor, GeluMKLDNNGradFunctor); \
__macro(hard_swish, HardSwishMKLDNNFunctor, HardSwishMKLDNNGradFunctor); \
__macro(leaky_relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(mish, MishMKLDNNFunctor, MishMKLDNNGradFunctor); \
__macro(relu, ReluMKLDNNFunctor, ReluMKLDNNGradFunctor); \
__macro(relu6, Relu6MKLDNNFunctor, Relu6MKLDNNGradFunctor); \
__macro(sigmoid, SigmoidMKLDNNFunctor, SigmoidMKLDNNGradUseOutFunctor); \
__macro(sqrt, SqrtMKLDNNFunctor, SqrtMKLDNNGradUseOutFunctor); \
__macro(swish, SwishMKLDNNFunctor, SwishMKLDNNGradFunctor); \
__macro(tanh, TanhMKLDNNFunctor, TanhMKLDNNGradUseOutFunctor);
FOR_EACH_MKLDNN_KERNEL_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_KERNEL);
REGISTER_ACTIVATION_MKLDNN_KERNEL_FWD_ONLY(round, RoundMKLDNNFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(relu, ReluMKLDNNFunctor,
ReluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(gelu, GeluMKLDNNFunctor,
GeluMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sigmoid, SigmoidMKLDNNFunctor,
SigmoidMKLDNNGradUseOutFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(sqrt, SqrtMKLDNNFunctor,
SqrtMKLDNNGradUseOutFunctor);
REGISTER_ACTIVATION_MKLDNN_BF16_KERNEL(mish, MishMKLDNNFunctor,
MishMKLDNNGradFunctor);
REGISTER_ACTIVATION_MKLDNN_KERNEL_FWD_ONLY(round, RoundMKLDNNFunctor);
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(
softplus, MKLDNN, paddle::platform::CPUPlace,
ops::MKLDNNActivationKernel<ops::SoftplusMKLDNNFunctor<float>>);
ops::MKLDNNActivationKernel<ops::SoftplusMKLDNNFunctor<float>>,
ops::MKLDNNActivationKernel<
ops::SoftplusMKLDNNFunctor<paddle::platform::bfloat16>>);
......@@ -50,11 +50,11 @@ class MKLDNNBF16ActivationOp(object):
self.dtype = np.uint16
self.init_data()
self.config()
self.set_attrs()
self.out = self.op_forward(self.x)
self.inputs = {'X': convert_float_to_uint16(self.x)}
self.outputs = {'Out': self.out}
self.set_attrs()
def calculate_grads(self):
self.dx = self.op_grad(self.out, self.x)
......@@ -162,5 +162,110 @@ class TestMKLDNNMishBF16Op(MKLDNNBF16ActivationOp, TestActivation):
return dout * ((np.exp(x) * omega) / delta**2)
class TestMKLDNNRelu6BF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "relu6"
def op_forward(self, x):
return np.clip(x, 0, 6)
def op_grad(self, dout, x):
return np.where((x > 0) & (x <= 6), dout, 0)
class TestMKLDNNLeakyReluBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "leaky_relu"
def op_forward(self, x):
return np.where(x > 0, x, self.alpha * x)
def op_grad(self, dout, x):
return np.where(x > 0, dout, self.alpha * dout)
def set_attrs(self):
self.alpha = 0.2
self.attrs = {"use_mkldnn": True, "alpha": self.alpha}
class TestMKLDNNSwishBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "swish"
def expit(self, val):
return 1 / (1 + np.exp(-self.beta * val))
def op_forward(self, x):
return x * self.expit(x)
def op_grad(self, dout, x):
return dout * self.expit(x) * (1 + self.beta * x * (1 - self.expit(x)))
def set_attrs(self):
self.beta = 0.2
self.attrs = {"use_mkldnn": True, "beta": self.beta}
class TestMKLDNNHardSwishBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "hard_swish"
def op_forward(self, x):
result = np.where(x < -3, 0, x)
return np.where(result > 3, result, result * (result + 3) / 6)
def op_grad(self, dout, x):
result = np.where(x < -3, 0, x)
return np.where(result > 3, dout, dout * (2 * x + 3) / 6)
class TestMKLDNNTanhBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "tanh"
def op_forward(self, x):
return np.tanh(x)
def op_grad(self, dout, x):
return dout * (1 - np.tanh(x)**2)
class TestMKLDNNAbsBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "abs"
def op_forward(self, x):
return np.absolute(x)
def op_grad(self, dout, x):
return dout * np.sign(x)
class TestMKLDNNEluBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "elu"
def op_forward(self, x):
return np.where(x > 0, x, self.alpha * (np.exp(x) - 1))
def op_grad(self, dout, x):
return np.where(x > 0, dout, dout * self.alpha * np.exp(x))
def set_attrs(self):
self.alpha = 0.2
self.attrs = {"use_mkldnn": True, "alpha": self.alpha}
class TestMKLDNNExpBF16Op(MKLDNNBF16ActivationOp, TestActivation):
def config(self):
self.op_type = "exp"
def op_forward(self, x):
return np.exp(x)
def op_grad(self, dout, x):
return dout * np.exp(x)
if __name__ == '__main__':
unittest.main()
......@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
......@@ -30,23 +30,32 @@ def ref_softplus(x, beta, threshold):
return out
@OpTestTool.skip_if(not (isinstance(_current_expected_place(), core.CPUPlace)),
"GPU is not supported")
@OpTestTool.skip_if_not_cpu_bf16()
class TestSoftplusOneDNNOp(OpTest):
def setUp(self):
self.op_type = "softplus"
self.beta = 1
self.threshold = 20
self.config()
self.set_dtype()
self.attrs = {'use_mkldnn': True, 'beta': self.beta}
self.inputs = {'X': np.random.random(self.x_shape).astype(np.float32)}
self.x = np.random.random(self.x_shape)
self.out = ref_softplus(self.x, self.beta, self.threshold)
if self.dtype != np.float32:
self.x = convert_float_to_uint16(self.x)
self.inputs = {'X': self.out}
self.outputs = {
'Out': ref_softplus(self.inputs['X'], self.beta, self.threshold)
'Out': ref_softplus(self.out, self.beta, self.threshold)
}
def config(self):
self.x_shape = (10, 10)
def set_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output()
......@@ -73,6 +82,27 @@ class TestSoftplus3DExtendedFunctorOneDNNOp(TestSoftplusOneDNNOp):
self.beta = 0.4
class TestSoftplusBF16OneDNNOp(TestSoftplusOneDNNOp):
def set_dtype(self):
self.dtype = np.uint16
class TestSoftplus4DBF16OneDNNOp(TestSoftplus4DOneDNNOp):
def set_dtype(self):
self.dtype = np.uint16
class TestSoftplus6DBF16OneDNNOp(TestSoftplus6DOneDNNOp):
def set_dtype(self):
self.dtype = np.uint16
class TestSoftplus3DExtendedFunctorBF16OneDNNOp(
TestSoftplus3DExtendedFunctorOneDNNOp):
def set_dtype(self):
self.dtype = np.uint16
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册