From bcd40f21195ad95d195aa0451f54d2350627a97d Mon Sep 17 00:00:00 2001
From: wuhuanzhou
Date: Tue, 18 May 2021 15:05:19 +0800
Subject: [PATCH] relu supports bfloat16 data type (#32542)

---
 paddle/fluid/operators/activation_op.cu       | 33 +++++++++-
 paddle/fluid/operators/cast_op.cu             | 18 ++++++
 .../paddle/fluid/tests/unittests/op_test.py   | 60 ++++++++++++++++++-
 .../tests/unittests/test_activation_op.py     | 46 ++++++++++++--
 4 files changed, 147 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu
index 618f17031b1..002fae60120 100644
--- a/paddle/fluid/operators/activation_op.cu
+++ b/paddle/fluid/operators/activation_op.cu
@@ -13,6 +13,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/operators/math/math_cuda_utils.h"
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/cuda_device_function.h"
 
 namespace paddle {
@@ -1437,9 +1438,9 @@ REGISTER_OP_CUDA_KERNEL(
 /* ========================================================================== */
 
 /* ===========================   relu register  ============================ */
+#ifdef PADDLE_WITH_HIP
 REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor,
                                 CudaReluGradFunctor);
-
 REGISTER_OP_CUDA_KERNEL(
     relu_grad_grad,
     ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                     ops::ReluGradGradFunctor<float>>,
     ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                     ops::ReluGradGradFunctor<double>>,
     ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
                                     ops::ReluGradGradFunctor<plat::float16>>);
+#else
+REGISTER_OP_CUDA_KERNEL(
+    relu, ops::ActivationCudaKernel<plat::CUDADeviceContext,
+                                    ops::CudaReluFunctor<float>>,
+    ops::ActivationCudaKernel<plat::CUDADeviceContext,
+                              ops::CudaReluFunctor<double>>,
+    ops::ActivationCudaKernel<plat::CUDADeviceContext,
+                              ops::CudaReluFunctor<plat::float16>>,
+    ops::ActivationCudaKernel<plat::CUDADeviceContext,
+                              ops::CudaReluFunctor<plat::bfloat16>>);
+REGISTER_OP_CUDA_KERNEL(
+    relu_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
+                                             ops::CudaReluGradFunctor<float>>,
+    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
+                                  ops::CudaReluGradFunctor<double>>,
+    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
+                                  ops::CudaReluGradFunctor<plat::float16>>,
+    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
+                                  ops::CudaReluGradFunctor<plat::bfloat16>>);
+REGISTER_OP_CUDA_KERNEL(
+    relu_grad_grad,
+    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
+                                    ops::ReluGradGradFunctor<float>>,
+    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
+                                    ops::ReluGradGradFunctor<double>>,
+    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
+                                    ops::ReluGradGradFunctor<plat::float16>>,
+    ops::ActivationDoubleGradKernel<plat::CUDADeviceContext,
+                                    ops::ReluGradGradFunctor<plat::bfloat16>>);
+#endif
 /* ========================================================================== */
 
 /* ===========================   tanh register  ============================ */
diff --git a/paddle/fluid/operators/cast_op.cu b/paddle/fluid/operators/cast_op.cu
index 13759633d01..2ef5b9ae3ac 100644
--- a/paddle/fluid/operators/cast_op.cu
+++ b/paddle/fluid/operators/cast_op.cu
@@ -95,6 +95,7 @@ struct CastOpFunctor {
 
 namespace ops = paddle::operators;
 
+#ifdef PADDLE_WITH_HIP
 REGISTER_OP_CUDA_KERNEL(
     cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>,
     ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>,
@@ -108,3 +109,20 @@ REGISTER_OP_CUDA_KERNEL(
                       paddle::platform::complex64>,
     ops::CastOpKernel<paddle::platform::CUDADeviceContext,
                       paddle::platform::complex128>);
+#else
+REGISTER_OP_CUDA_KERNEL(
+    cast, ops::CastOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext, bool>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext,
+                      paddle::platform::float16>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext,
+                      paddle::platform::bfloat16>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext,
+                      paddle::platform::complex64>,
+    ops::CastOpKernel<paddle::platform::CUDADeviceContext,
+                      paddle::platform::complex128>);
+#endif
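The registrations above boil down to simple numerics: a bfloat16 value is the high 16 bits of an IEEE-754 float32, so what the new relu kernel computes can be reproduced with plain NumPy, assuming truncation when narrowing (the CUDA kernels may round instead). The helper names below are illustrative only and are not part of the patch; the convert_float_to_uint16/convert_uint16_to_float helpers in op_test.py play the same role in the tests.

    # Standalone NumPy sketch of the bfloat16 relu semantics (assumption:
    # bf16 = high 16 bits of a float32, narrowing by truncation).
    import numpy as np

    def float32_to_bf16_bits(x):
        # Reinterpret float32 values as uint32 and keep the high 16 bits.
        return (np.asarray(x, dtype=np.float32).view(np.uint32) >>
                np.uint32(16)).astype(np.uint16)

    def bf16_bits_to_float32(bits):
        # Widen uint16 bit patterns back to float32 (low 16 bits are zero).
        return (np.asarray(bits, dtype=np.uint16).astype(np.uint32) <<
                np.uint32(16)).view(np.float32)

    x = np.random.uniform(-1, 1, [11, 17]).astype(np.float32)
    x_bf16 = bf16_bits_to_float32(float32_to_bf16_bits(x))  # value after the cast
    out = np.maximum(x_bf16, 0)                             # what relu returns
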
diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py
index a2e467ad747..3524d1e553d 100644
--- a/python/paddle/fluid/tests/unittests/op_test.py
+++ b/python/paddle/fluid/tests/unittests/op_test.py
@@ -132,6 +132,8 @@ def get_numeric_gradient(place,
         tensor_to_check_dtype = np.float16
         # set delta as np.float16, will automatic convert to float32, float64
         delta = np.array(delta).astype(np.float16)
+    elif tensor_to_check_dtype == core.VarDesc.VarType.BF16:
+        tensor_to_check_dtype = np.float32
     else:
         raise ValueError("Not supported data type " + str(
             tensor_to_check_dtype))
@@ -140,9 +142,10 @@ def get_numeric_gradient(place,
         sum = []
         op.run(scope, place)
         for output_name in output_names:
-            sum.append(
-                np.array(scope.find_var(output_name).get_tensor()).astype(
-                    tensor_to_check_dtype).mean())
+            output_numpy = np.array(scope.find_var(output_name).get_tensor())
+            if tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
+                output_numpy = convert_uint16_to_float(output_numpy)
+            sum.append(output_numpy.astype(tensor_to_check_dtype).mean())
         return tensor_to_check_dtype(np.array(sum).sum() / len(output_names))
 
     gradient_flat = np.zeros(shape=(tensor_size, ), dtype=tensor_to_check_dtype)
@@ -152,6 +155,11 @@ def get_numeric_gradient(place,
             numpy_tensor = np.array(tensor).astype(np.float16)
             numpy_tensor = numpy_tensor.flatten()
             return numpy_tensor[i]
+        elif tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
+            numpy_tensor = np.array(tensor).astype(np.uint16)
+            numpy_tensor = numpy_tensor.flatten()
+            return struct.unpack('<f', struct.pack(
+                '<I', np.uint32(numpy_tensor[i]) << np.uint32(16)))[0]
                 abs_a[np.logical_and(abs_a > 1e-10, abs_a <= 1e-8)] *= 1e4
                 abs_a[np.logical_and(abs_a > 1e-8, abs_a <= 1e-6)] *= 1e2
+            elif self.is_bfloat16_op():
+                abs_a[abs_a < 1e-2] = 1
             else:
                 abs_a[abs_a < 1e-3] = 1
@@ -1500,6 +1517,13 @@ class OpTest(unittest.TestCase):
             dygraph_grad = self._get_dygraph_grad(
                 inputs_to_check, place, output_names, user_defined_grad_outputs,
                 no_grad_set)
+            fp32_grads = []
+            for grad in dygraph_grad:
+                if grad.dtype == np.uint16:
+                    grad = convert_uint16_to_float(grad)
+                    max_relative_error = 0.03
+                fp32_grads.append(grad)
+            dygraph_grad = fp32_grads
             self._assert_is_close(numeric_grads, dygraph_grad, inputs_to_check,
                                   max_relative_error,
                                   "Gradient Check On %s" % str(place))
@@ -1544,6 +1568,21 @@ class OpTest(unittest.TestCase):
                 outputs=outputs,
                 attrs=attrs_outputs if hasattr(self, "attrs") else None)
 
+            if self.dtype == np.uint16:
+                cast_inputs = self._find_var_in_dygraph(outputs,
+                                                        output_names[0])
+                cast_outputs = block.create_var(
+                    dtype="float32", shape=cast_inputs[0].shape)
+                cast_op = block.append_op(
+                    inputs={"X": cast_inputs},
+                    outputs={"Out": cast_outputs},
+                    type="cast",
+                    attrs={
+                        "in_dtype": core.VarDesc.VarType.BF16,
+                        "out_dtype": core.VarDesc.VarType.FP32
+                    })
+                outputs = {output_names[0]: cast_outputs}
+
             outputs_valid = {}
             for output_name in output_names:
                 outputs_valid[output_name] = self._find_var_in_dygraph(
@@ -1659,6 +1698,21 @@ class OpTest(unittest.TestCase):
             feed_dict = self.feed_var(inputs, place)
 
             if user_defined_grad_outputs is None:
+                if self.dtype == np.uint16:
+                    cast_inputs = list(map(block.var, output_names))
+                    cast_outputs = block.create_var(
+                        dtype="float32", shape=cast_inputs[0].shape)
+                    cast_op = block.append_op(
+                        inputs={"X": cast_inputs},
+                        outputs={"Out": cast_outputs},
+                        type="cast",
+                        attrs={
+                            "in_dtype": core.VarDesc.VarType.BF16,
+                            "out_dtype": core.VarDesc.VarType.FP32
+                        })
+                    cast_op.desc.infer_var_type(block.desc)
+                    cast_op.desc.infer_shape(block.desc)
+                    output_names = [cast_outputs.name]
                 loss = append_loss_ops(block, output_names)
                 param_grad_list = append_backward(
                     loss=loss,
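As a quick standalone check of the widening used in __get_elem__ above, the snippet below round-trips one value: a bfloat16 element stored as uint16 becomes a Python float by shifting its bits into the upper half of a little-endian float32 word. It is only an illustration; the bf16_bits_to_py_float name is not part of the patch.

    import struct

    import numpy as np

    def bf16_bits_to_py_float(v):
        # Place the 16 bfloat16 bits in the upper half of a 32-bit word and
        # reinterpret that word as a little-endian float32.
        word = np.uint32(v) << np.uint32(16)
        return struct.unpack('<f', struct.pack('<I', word))[0]

    assert bf16_bits_to_py_float(np.uint16(0x3F80)) == 1.0   # bf16 bits of 1.0
    assert bf16_bits_to_py_float(np.uint16(0xBF80)) == -1.0  # bf16 bits of -1.0
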
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index 31589ca4ae3..ef5ac46cede 100755
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -18,7 +18,7 @@ import unittest
 import numpy as np
 from scipy.special import expit, erf
 
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
@@ -1103,12 +1103,19 @@ class TestRelu(TestActivation):
         self.init_dtype()
 
         np.random.seed(1024)
-        x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
-        # The same reason with TestAbs
-        x[np.abs(x) < 0.005] = 0.02
-        out = np.maximum(x, 0)
+        if self.dtype == np.uint16:
+            x = np.random.uniform(-1, 1, [11, 17]).astype(np.float32)
+            # The same reason with TestAbs
+            x[np.abs(x) < 0.005] = 0.02
+            out = convert_float_to_uint16(np.maximum(x, 0))
+            self.inputs = {'X': convert_float_to_uint16(x)}
+        else:
+            x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
+            # The same reason with TestAbs
+            x[np.abs(x) < 0.005] = 0.02
+            out = np.maximum(x, 0)
+            self.inputs = {'X': x}
 
-        self.inputs = {'X': x}
         self.outputs = {'Out': out}
 
     def test_check_grad(self):
@@ -2739,5 +2746,32 @@ create_test_act_fp16_class(TestHardSigmoid)
 create_test_act_fp16_class(TestSwish, grad_atol=0.85)
 create_test_act_fp16_class(TestHardSwish)
 
+
+def create_test_act_bf16_class(parent,
+                               atol=1e-2,
+                               grad_check=True,
+                               grad_atol=0.80):
+    @unittest.skipIf(not paddle.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestActBF16(parent):
+        def init_dtype(self):
+            self.dtype = np.uint16
+
+        def test_check_output(self):
+            place = core.CUDAPlace(0)
+            self.check_output_with_place(place, atol=atol)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            self.check_grad_with_place(
+                place, ['X'], 'Out', max_relative_error=grad_atol)
+
+    cls_name = "{0}_{1}".format(parent.__name__, "bf16")
+    TestActBF16.__name__ = cls_name
+    globals()[cls_name] = TestActBF16
+
+
+create_test_act_bf16_class(TestRelu)
+
 if __name__ == "__main__":
     unittest.main()
--
GitLab
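The create_test_act_bf16_class factory added above mirrors the existing create_test_act_fp16_class pattern: it subclasses a given activation test, overrides init_dtype to report np.uint16 (the carrier type for bfloat16 bit patterns), pins the checks to core.CUDAPlace(0) with relaxed tolerances, renames the generated class, and registers it in globals() so that unittest discovery picks it up (here as TestRelu_bf16). A minimal, Paddle-free sketch of that class-factory pattern, using a stand-in parent in place of the real TestRelu, looks like this:

    import unittest

    import numpy as np

    class TestReluStub(unittest.TestCase):
        # Stand-in for the real TestRelu; only carries the dtype hook.
        def init_dtype(self):
            self.dtype = np.float64

        def test_dtype(self):
            self.init_dtype()
            self.assertIn(self.dtype, (np.float64, np.uint16))

    def create_bf16_variant(parent):
        class TestActBF16(parent):
            def init_dtype(self):
                # bfloat16 data is carried around as raw uint16 bit patterns.
                self.dtype = np.uint16

        cls_name = "{0}_{1}".format(parent.__name__, "bf16")
        TestActBF16.__name__ = cls_name
        globals()[cls_name] = TestActBF16  # visible to unittest discovery

    create_bf16_variant(TestReluStub)      # registers TestReluStub_bf16

    if __name__ == "__main__":
        unittest.main()

The loose atol=1e-2 and grad_atol=0.80 defaults reflect bfloat16's 8-bit mantissa, which carries only two to three significant decimal digits.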