diff --git a/paddle/phi/kernels/gpu/norm_grad_kernel.cu b/paddle/phi/kernels/gpu/norm_grad_kernel.cu
index cb02cc713852c8e5c96edf0e82b02c183ad51b6c..ac29541326ec9d59bf0c7fa6185b763a58fe3d96 100644
--- a/paddle/phi/kernels/gpu/norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/norm_grad_kernel.cu
@@ -116,4 +116,5 @@ PD_REGISTER_KERNEL(norm_grad,
                    phi::NormGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/norm_kernel.cu b/paddle/phi/kernels/gpu/norm_kernel.cu
index 4843831ebfc683c715b0328a03b14614a86b2b2d..a933de0bffac30b1bdd9ffd14b6e2df081d1acc3 100644
--- a/paddle/phi/kernels/gpu/norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/norm_kernel.cu
@@ -43,7 +43,7 @@ __global__ void Normalize(const T* x,
                           const int pre,
                           const int axis_n,  // dim in axis
                           const int post,
-                          const T eps,
+                          const float eps,
                           T* y,
                           T* out_norm) {
   using MT = typename phi::dtype::MPTypeTrait<T>::Type;
@@ -86,7 +86,6 @@ void NormKernel(const Context& ctx,
   auto xdim = in_x->dims();
   if (axis < 0) axis = xdim.size() + axis;
-  T eps = static_cast<T>(epsilon);
 
   DenseTensor* out_norm;
   DenseTensor out_norm_tmp;
 
@@ -117,8 +116,8 @@ void NormKernel(const Context& ctx,
   int max_threads = ctx.GetMaxPhysicalThreadCount();
   const int max_blocks = std::max(max_threads / block, 1);
   int grid = std::min(max_blocks, pre * post);
-  Normalize<T, block>
-      <<<grid, block, 0, ctx.stream()>>>(x_ptr, pre, n, post, eps, y, norm_ptr);
+  Normalize<T, block><<<grid, block, 0, ctx.stream()>>>(
+      x_ptr, pre, n, post, epsilon, y, norm_ptr);
 }
 
 }  // namespace phi
@@ -129,4 +128,5 @@ PD_REGISTER_KERNEL(norm,
                    phi::NormKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/python/paddle/fluid/tests/unittests/test_norm_op.py b/python/paddle/fluid/tests/unittests/test_norm_op.py
index 7b4b8dc60a02af8349b71cfcc32295d9d7a60fa9..64a17969ddeb42edf7208a0d5f5253f9768887f4 100644
--- a/python/paddle/fluid/tests/unittests/test_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_norm_op.py
@@ -16,9 +16,11 @@ import unittest
 
 import numpy as np
 from eager_op_test import OpTest, skip_check_grad_ci
+from op_test import convert_float_to_uint16
 
 import paddle
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 
 
 def l2_norm(x, axis, epsilon):
@@ -157,6 +159,37 @@ class TestNormTestOp(OpTest):
         self.epsilon = 1e-8
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA and not support the bfloat16",
+)
+class TestNormBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "norm"
+        self.python_api = norm_wrapper
+        self.init_test_case()
+        self.dtype = "float32"
+        x = np.random.random(self.shape).astype(self.dtype)
+        y, norm = l2_norm(x, self.axis, self.epsilon)
+        self.inputs = {'X': convert_float_to_uint16(x)}
+        self.attrs = {'epsilon': self.epsilon, 'axis': self.axis}
+        self.outputs = {'Out': convert_float_to_uint16(y), 'Norm': norm}
+        self.python_out_sig = ['Out']
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CUDAPlace(0), atol=1e-1)
+
+    def test_check_grad(self):
+        self.check_grad_with_place(
+            core.CUDAPlace(0), ['X'], 'Out', max_relative_error=1e-2
+        )
+
+    def init_test_case(self):
+        self.shape = [2, 3, 4, 5]
+        self.axis = 1
+        self.epsilon = 1e-8
+
+
 class API_NormTest(unittest.TestCase):
     def test_errors(self):
         with fluid.program_guard(fluid.Program()):
diff --git a/python/paddle/fluid/tests/unittests/test_scale_op.py b/python/paddle/fluid/tests/unittests/test_scale_op.py
index a643f3136f400b409782f44345d42afe7241871d..b8029e9f7ac0f0a507cb1779975cf093a92054e5 100644
--- a/python/paddle/fluid/tests/unittests/test_scale_op.py
+++ b/python/paddle/fluid/tests/unittests/test_scale_op.py
@@ -148,12 +148,10 @@ class TestScaleFp16Op(TestScaleOp):
         self.dtype = np.float16
 
     def test_check_output(self):
-        place = core.CUDAPlace(0)
-        self.check_output_with_place(place, check_eager=True)
+        self.check_output(check_eager=True)
 
     def test_check_grad(self):
-        place = core.CUDAPlace(0)
-        self.check_grad_with_place(place, ["X"], "Out", check_eager=True)
+        self.check_grad(["X"], "Out", check_eager=True)
 
 
 class TestScaleBF16Op(OpTest):
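Note (editor's sketch, not part of the patch): the `eps` parameter is switched from `T` to `float` and the `static_cast<T>(epsilon)` is dropped, presumably because rounding epsilon through a 16-bit `T` loses it entirely in some cases (1e-8 is below float16's smallest subnormal, so the cast flushes it to zero); keeping it as `float` lets the mixed-precision math via `MPTypeTrait<T>::Type` consume the full value. Below is a minimal usage sketch of the newly registered bfloat16 path, assuming `paddle.nn.functional.normalize` dispatches to this GPU `norm` kernel (as the test's `norm_wrapper` suggests) and that a CUDA device is available.

```python
# Sketch only, under the assumptions stated above.
import paddle

paddle.set_device("gpu")
x = paddle.rand([2, 3, 4, 5]).astype(paddle.bfloat16)
# axis/epsilon mirror TestNormBF16Op.init_test_case.
y = paddle.nn.functional.normalize(x, axis=1, epsilon=1e-8)
print(y.dtype)  # paddle.bfloat16
```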