Unverified commit bcd40f21, authored by wuhuanzhou, committed by GitHub

relu supports bfloat16 data type (#32542)

Parent: b5882c6e
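
This change registers bfloat16 CUDA kernels for relu (forward, grad, double grad) and for cast, and extends the op-test framework so BF16 outputs and gradients can be checked. A minimal end-to-end sketch of what the new kernels enable, assuming a GPU build of this branch and that paddle.cast accepts the 'bfloat16' dtype string (both are assumptions, not part of this diff):

import numpy as np
import paddle
import paddle.nn.functional as F

paddle.set_device('gpu')
x = paddle.to_tensor(np.random.uniform(-1, 1, [11, 17]).astype('float32'))
x_bf16 = paddle.cast(x, 'bfloat16')   # exercises the new bf16 cast kernel
y_bf16 = F.relu(x_bf16)               # exercises the new bf16 relu kernel
y = paddle.cast(y_bf16, 'float32')    # cast back for inspection
assert y.numpy().min() >= 0.0         # relu output is non-negative
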
......@@ -13,6 +13,7 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/cuda_device_function.h"
namespace paddle {
......@@ -1437,9 +1438,9 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* =========================== relu register ============================ */
#ifdef PADDLE_WITH_HIP
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor,
CudaReluGradFunctor);
REGISTER_OP_CUDA_KERNEL(
relu_grad_grad,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
......@@ -1448,6 +1449,36 @@ REGISTER_OP_CUDA_KERNEL(
ops::ReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::ReluGradGradFunctor<plat::float16>>);
#else
REGISTER_OP_CUDA_KERNEL(
relu, ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaReluFunctor<float>>,
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaReluFunctor<double>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaReluFunctor<plat::float16>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaReluFunctor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
relu_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaReluGradFunctor<float>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaReluGradFunctor<double>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaReluGradFunctor<plat::float16>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaReluGradFunctor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
relu_grad_grad,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradGradFunctor<float>>,
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::ReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::ReluGradGradFunctor<plat::float16>>,
ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
ops::ReluGradGradFunctor<plat::bfloat16>>);
#endif
/* ========================================================================== */
/* =========================== tanh register ============================ */
......
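
Note that the new registrations sit on the #else branch of PADDLE_WITH_HIP: the ROCm build keeps only the float/double/float16 relu kernels, while the plain CUDA build additionally registers plat::bfloat16 for relu, relu_grad, and relu_grad_grad. Tests can mirror that guard from Python; a sketch, assuming paddle.is_compiled_with_rocm() is available on this branch:

import paddle

# bf16 kernels are only registered on the CUDA (non-HIP) build
run_bf16_tests = paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm()
print('bfloat16 relu kernels expected:', run_bf16_tests)
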
......@@ -95,6 +95,7 @@ struct CastOpFunctor<platform::CUDADeviceContext, InT> {
namespace ops = paddle::operators;
#ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL(
cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>,
......@@ -108,3 +109,20 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::complex64>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
#else
REGISTER_OP_CUDA_KERNEL(
cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::bfloat16>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::CastOpKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
#endif
......@@ -132,6 +132,8 @@ def get_numeric_gradient(place,
tensor_to_check_dtype = np.float16
# set delta as np.float16; it will automatically be converted to float32/float64
delta = np.array(delta).astype(np.float16)
elif tensor_to_check_dtype == core.VarDesc.VarType.BF16:
tensor_to_check_dtype = np.float32
else:
raise ValueError("Not supported data type " + str(
tensor_to_check_dtype))
......@@ -140,9 +142,10 @@ def get_numeric_gradient(place,
sum = []
op.run(scope, place)
for output_name in output_names:
sum.append(
np.array(scope.find_var(output_name).get_tensor()).astype(
tensor_to_check_dtype).mean())
output_numpy = np.array(scope.find_var(output_name).get_tensor())
if tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
output_numpy = convert_uint16_to_float(output_numpy)
sum.append(output_numpy.astype(tensor_to_check_dtype).mean())
return tensor_to_check_dtype(np.array(sum).sum() / len(output_names))
gradient_flat = np.zeros(shape=(tensor_size, ), dtype=tensor_to_check_dtype)
......@@ -152,6 +155,11 @@ def get_numeric_gradient(place,
numpy_tensor = np.array(tensor).astype(np.float16)
numpy_tensor = numpy_tensor.flatten()
return numpy_tensor[i]
elif tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
numpy_tensor = np.array(tensor).astype(np.uint16)
numpy_tensor = numpy_tensor.flatten()
return struct.unpack('<f', struct.pack('<I', numpy_tensor[i]
<< 16))[0]
elif tensor_to_check_dtype == np.float32:
return tensor._get_float_element(i)
elif tensor_to_check_dtype == np.float64:
......@@ -168,6 +176,13 @@ def get_numeric_gradient(place,
numpy_tensor[i] = e
numpy_tensor = numpy_tensor.reshape(shape)
tensor.set(numpy_tensor, place)
elif tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
numpy_tensor = np.array(tensor).astype(np.uint16)
shape = numpy_tensor.shape
numpy_tensor = numpy_tensor.flatten()
numpy_tensor[i] = np.uint16(copy_bits_from_float_to_uint16(e))
numpy_tensor = numpy_tensor.reshape(shape)
tensor.set(numpy_tensor, place)
elif tensor_to_check_dtype == np.float32:
tensor._set_float_element(i, e)
elif tensor_to_check_dtype == np.float64:
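
The BF16 branches above store bfloat16 tensors as raw np.uint16 bit patterns: an element is read back by shifting its 16 bits into the upper half of a float32 word (the struct.unpack expression), and written by keeping the upper 16 bits of the float32 representation (copy_bits_from_float_to_uint16). A self-contained sketch of that round trip; the helper names are illustrative, and truncation (rather than rounding) is an assumption about the framework helper:

import struct
import numpy as np

def float32_to_bf16_bits(f):
    # keep the high 16 bits of the float32 bit pattern (truncation)
    return np.uint16(struct.unpack('<I', struct.pack('<f', f))[0] >> 16)

def bf16_bits_to_float32(u16):
    # inverse: the same expression as used in get_numeric_gradient above
    return struct.unpack('<f', struct.pack('<I', int(u16) << 16))[0]

v = 0.15625                                                  # representable in bfloat16, round-trips exactly
assert bf16_bits_to_float32(float32_to_bf16_bits(v)) == v
print(bf16_bits_to_float32(float32_to_bf16_bits(3.14159)))   # 3.140625: bf16 precision loss
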
......@@ -1353,6 +1368,8 @@ class OpTest(unittest.TestCase):
abs_a[abs_a < 1e-10] = 1e-3
abs_a[np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)] *= 1e4
abs_a[np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)] *= 1e2
elif self.is_bfloat16_op():
abs_a[abs_a < 1e-2] = 1
else:
abs_a[abs_a < 1e-3] = 1
......@@ -1500,6 +1517,13 @@ class OpTest(unittest.TestCase):
dygraph_grad = self._get_dygraph_grad(
inputs_to_check, place, output_names, user_defined_grad_outputs,
no_grad_set)
fp32_grads = []
for grad in dygraph_grad:
if grad.dtype == np.uint16:
grad = convert_uint16_to_float(grad)
max_relative_error = 0.03
fp32_grads.append(grad)
dygraph_grad = fp32_grads
self._assert_is_close(numeric_grads, dygraph_grad, inputs_to_check,
max_relative_error,
"Gradient Check On %s" % str(place))
......@@ -1544,6 +1568,21 @@ class OpTest(unittest.TestCase):
outputs=outputs,
attrs=attrs_outputs if hasattr(self, "attrs") else None)
if self.dtype == np.uint16:
cast_inputs = self._find_var_in_dygraph(outputs,
output_names[0])
cast_outputs = block.create_var(
dtype="float32", shape=cast_inputs[0].shape)
cast_op = block.append_op(
inputs={"X": cast_inputs},
outputs={"Out": cast_outputs},
type="cast",
attrs={
"in_dtype": core.VarDesc.VarType.BF16,
"out_dtype": core.VarDesc.VarType.FP32
})
outputs = {output_names[0]: cast_outputs}
outputs_valid = {}
for output_name in output_names:
outputs_valid[output_name] = self._find_var_in_dygraph(
......@@ -1659,6 +1698,21 @@ class OpTest(unittest.TestCase):
feed_dict = self.feed_var(inputs, place)
if user_defined_grad_outputs is None:
if self.dtype == np.uint16:
cast_inputs = list(map(block.var, output_names))
cast_outputs = block.create_var(
dtype="float32", shape=cast_inputs[0].shape)
cast_op = block.append_op(
inputs={"X": cast_inputs},
outputs={"Out": cast_outputs},
type="cast",
attrs={
"in_dtype": core.VarDesc.VarType.BF16,
"out_dtype": core.VarDesc.VarType.FP32
})
cast_op.desc.infer_var_type(block.desc)
cast_op.desc.infer_shape(block.desc)
output_names = [cast_outputs.name]
loss = append_loss_ops(block, output_names)
param_grad_list = append_backward(
loss=loss,
......
......@@ -18,7 +18,7 @@ import unittest
import numpy as np
from scipy.special import expit, erf
from op_test import OpTest
from op_test import OpTest, convert_float_to_uint16
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
......@@ -1103,12 +1103,19 @@ class TestRelu(TestActivation):
self.init_dtype()
np.random.seed(1024)
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
out = np.maximum(x, 0)
if self.dtype == np.uint16:
x = np.random.uniform(-1, 1, [11, 17]).astype(np.float32)
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
out = convert_float_to_uint16(np.maximum(x, 0))
self.inputs = {'X': convert_float_to_uint16(x)}
else:
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
out = np.maximum(x, 0)
self.inputs = {'X': x}
self.inputs = {'X': x}
self.outputs = {'Out': out}
def test_check_grad(self):
......@@ -2739,5 +2746,32 @@ create_test_act_fp16_class(TestHardSigmoid)
create_test_act_fp16_class(TestSwish, grad_atol=0.85)
create_test_act_fp16_class(TestHardSwish)
def create_test_act_bf16_class(parent,
atol=1e-2,
grad_check=True,
grad_atol=0.80):
@unittest.skipIf(not paddle.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestActBF16(parent):
def init_dtype(self):
self.dtype = np.uint16
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=atol)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', max_relative_error=grad_atol)
cls_name = "{0}_{1}".format(parent.__name__, "bf16")
TestActBF16.__name__ = cls_name
globals()[cls_name] = TestActBF16
create_test_act_bf16_class(TestRelu)
if __name__ == "__main__":
unittest.main()