From 6a0d60d27ff44ea425e35afa9b3f4bd884fb6506 Mon Sep 17 00:00:00 2001
From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com>
Date: Mon, 7 Mar 2022 13:20:07 +0800
Subject: [PATCH] [bf16] add bf16 kernel: gaussian_random fill_constant
 fill_any_like (#40027)

* add gaussian random

* add full

* refine reduce

* refine code

* refine gaussian_random unittest

* add unittest for fill_any_like fill_constant
---
 paddle/fluid/operators/gaussian_random_op.cu  |  3 +-
 .../phi/kernels/funcs/distribution_helper.h   |  9 ++--
 paddle/phi/kernels/gpu/full_kernel.cu         | 10 ++--
 .../phi/kernels/gpu/gaussian_random_kernel.cu | 13 ++++--
 .../kernels/primitive/compute_primitives.h    |  1 +
 .../tests/unittests/test_fill_any_like_op.py  | 21 ++++++++-
 .../tests/unittests/test_fill_constant_op.py  | 21 +++++++++
 .../unittests/test_gaussian_random_op.py      | 46 ++++++++++++++++++-
 8 files changed, 110 insertions(+), 14 deletions(-)

diff --git a/paddle/fluid/operators/gaussian_random_op.cu b/paddle/fluid/operators/gaussian_random_op.cu
index 717ec774414..00ce10bfe3b 100644
--- a/paddle/fluid/operators/gaussian_random_op.cu
+++ b/paddle/fluid/operators/gaussian_random_op.cu
@@ -45,7 +45,8 @@ struct GaussianGenerator {
     thrust::minstd_rand rng;
     rng.seed(seed_);
     using MT = typename details::MPTypeTrait<T>::Type;
-    thrust::normal_distribution<MT> dist(mean_, std_);
+    thrust::normal_distribution<MT> dist(static_cast<MT>(mean_),
+                                         static_cast<MT>(std_));
     unsigned int new_n = n + offset_;
     rng.discard(new_n);
     MT out = dist(rng);
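Note on the GaussianGenerator change above: thrust::normal_distribution has no 16-bit specialization, so the functor draws in the higher-precision compute type MT (float when T is bfloat16 or float16) and now casts mean_, std_, and the final sample between T and MT explicitly. The functor is also stateless: every element re-seeds the same engine and discards ahead to its own index, so each output can be recomputed independently on device. A minimal Python sketch of that pattern, pure numpy; gaussian_at is an illustrative name, not Paddle API:

    import numpy as np

    def gaussian_at(n, seed=10, offset=0, mean=1.0, std=2.0):
        # Re-seed the same counter-based engine for every element, then skip
        # ahead to this element's position (mirrors rng.discard(n + offset_)).
        rng = np.random.Generator(np.random.Philox(key=seed))
        rng.bit_generator.advance(n + offset)
        # Draw in float32, the "MT" compute type; the caller casts the result
        # down to the 16-bit storage type afterwards.
        return np.float32(rng.normal(mean, std))

    # Element i of the output is gaussian_at(i); recomputing any element
    # reproduces the same value, which is what makes the functor safe to
    # call from many GPU threads without shared state.
    print(gaussian_at(0), gaussian_at(1))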
diff --git a/paddle/phi/kernels/funcs/distribution_helper.h b/paddle/phi/kernels/funcs/distribution_helper.h
index 3ef39dc55d1..acc31d68b78 100644
--- a/paddle/phi/kernels/funcs/distribution_helper.h
+++ b/paddle/phi/kernels/funcs/distribution_helper.h
@@ -23,6 +23,7 @@ limitations under the License. */
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/generator.h"
 #include "paddle/phi/core/hostdevice.h"
@@ -255,11 +256,13 @@ __global__ void DistributionKernel(size_t size,
   using SType = hiprandStatePhilox4_32_10_t;
 #endif
   size_t total_thread = GRID_NUM_X * BLOCK_NUM_X;
-  T args[kCount];
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  MT args[kCount];
   T result[kCount];
   for (size_t i = idx; i < size; i += total_thread * kCount) {
-    kps::ElementwiseRandom<SType, T, kCount, 1>(&args[0], dist, &state);
-    kps::ElementwiseUnary<T, T, kCount, 1, 1, TransformOp>(
+    kps::ElementwiseRandom<SType, MT, kCount, 1>(
+        &args[0], dist, &state);
+    kps::ElementwiseUnary<MT, T, kCount, 1, 1, TransformOp>(
         &result[0], &args[0], trans);
     kps::WriteData<T, T, kCount, 1, 1, true>(
         out_data + i, &result[0], size - i, 1, stride, 1);
diff --git a/paddle/phi/kernels/gpu/full_kernel.cu b/paddle/phi/kernels/gpu/full_kernel.cu
index 1f756bfdbed..a905979f08b 100644
--- a/paddle/phi/kernels/gpu/full_kernel.cu
+++ b/paddle/phi/kernels/gpu/full_kernel.cu
@@ -63,9 +63,11 @@ void FullLikeKernel(const Context& dev_ctx,
   auto value = val.to<float>();
   using CommonType = typename std::common_type<
       float,
-      typename std::conditional<std::is_same<T, phi::dtype::float16>::value,
-                                float,
-                                T>::type>::type;
+      typename std::conditional<
+          std::is_same<T, phi::dtype::float16>::value ||
+              std::is_same<T, phi::dtype::bfloat16>::value,
+          float,
+          T>::type>::type;
 
   auto common_type_value = static_cast<CommonType>(value);
 
@@ -110,6 +112,7 @@ PD_REGISTER_KERNEL(full,
                    int64_t,
                    bool,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
 
@@ -123,6 +126,7 @@ PD_REGISTER_KERNEL(full_like,
                    int,
                    int64_t,
                    bool,
+                   phi::dtype::bfloat16,
                    phi::dtype::float16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
diff --git a/paddle/phi/kernels/gpu/gaussian_random_kernel.cu b/paddle/phi/kernels/gpu/gaussian_random_kernel.cu
index da16800ad02..e2fe2190c1c 100644
--- a/paddle/phi/kernels/gpu/gaussian_random_kernel.cu
+++ b/paddle/phi/kernels/gpu/gaussian_random_kernel.cu
@@ -18,8 +18,8 @@
 #include <thrust/host_vector.h>
 #include <thrust/random.h>
 #include <thrust/transform.h>
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
@@ -46,8 +46,9 @@ struct GaussianGenerator {
   __host__ __device__ T operator()(const unsigned int n) const {
     thrust::minstd_rand rng;
     rng.seed(seed_);
-    using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
-    thrust::normal_distribution<MT> dist(mean_, std_);
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    thrust::normal_distribution<MT> dist(static_cast<MT>(mean_),
+                                         static_cast<MT>(std_));
     unsigned int new_n = n + offset_;
     rng.discard(new_n);
     MT out = dist(rng);
@@ -83,9 +84,10 @@ void GaussianRandomKernel(const Context& dev_ctx,
 
   if (gen_cuda->GetIsInitPy() && seed_flag) {
     if (FLAGS_use_curand) {
-      using MT = typename phi::kps::details::MPTypeTrait<T>::Type;
+      using MT = typename phi::dtype::MPTypeTrait<T>::Type;
       funcs::normal_distribution<MT> dist;
-      funcs::normal_transform<MT> trans(mean, std);
+      funcs::normal_transform<MT> trans(static_cast<MT>(mean),
+                                        static_cast<MT>(std));
       funcs::distribution_and_transform<T>(dev_ctx, tensor, dist, trans);
     } else {
       auto seed_offset = gen_cuda->IncrementOffset(1);
@@ -110,5 +112,6 @@ PD_REGISTER_KERNEL(gaussian_random,
                    ALL_LAYOUT,
                    phi::GaussianRandomKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double) {}
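For intuition about why these kernels compute in float and only store bfloat16: numpy has no native bfloat16, but the format is exactly the upper 16 bits of an IEEE float32, so the uint16 storage used by the tests below can be emulated directly. A minimal sketch, assuming round-to-nearest-even conversion; the in-kernel cast may treat NaN and overflow edge cases differently:

    import numpy as np

    def float32_to_bf16_bits(x):
        # Keep the top 16 bits of each float32 pattern, rounding to nearest
        # even rather than truncating.
        bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
        bits = bits + 0x7FFF + ((bits >> 16) & 1)
        return (bits >> 16).astype(np.uint16)

    def bf16_bits_to_float32(b):
        # Re-expand the stored 16 bits into a float32 for inspection.
        return (b.astype(np.uint32) << 16).view(np.float32)

    # The fill value 3.8 used by the new tests is not representable in
    # bfloat16; it round-trips to the nearest representable value:
    x = np.array([3.8], dtype=np.float32)
    print(bf16_bits_to_float32(float32_to_bf16_bits(x)))  # [3.796875]

With only 7 mantissa bits left, value checks are better done in float, which is presumably why FullLikeKernel's CommonType now maps bfloat16 (like float16) to float before validating the fill value.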
diff --git a/paddle/phi/kernels/primitive/compute_primitives.h b/paddle/phi/kernels/primitive/compute_primitives.h
index 19427551fb3..632ad00f6d0 100644
--- a/paddle/phi/kernels/primitive/compute_primitives.h
+++ b/paddle/phi/kernels/primitive/compute_primitives.h
@@ -22,6 +22,7 @@
 #endif
 
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
+// #include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/common/float16.h"
 
 namespace phi {
diff --git a/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py b/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py
index 5bc2d1cda18..9be2e57ff0c 100644
--- a/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fill_any_like_op.py
@@ -21,7 +21,7 @@ from paddle.fluid import Program, program_guard
 import paddle.compat as cpt
 import unittest
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 
 
 class TestFillAnyLikeOp(OpTest):
@@ -47,6 +47,25 @@ class TestFillAnyLikeOpFloat32(TestFillAnyLikeOp):
         self.value = 0.0
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestFillAnyLikeOpBfloat16(OpTest):
+    def setUp(self):
+        self.op_type = "fill_any_like"
+        self.dtype = np.uint16
+        self.value = 0.0
+        self.inputs = {'X': np.random.random((219, 232)).astype(np.float32)}
+        self.attrs = {'value': self.value, 'dtype': core.VarDesc.VarType.BF16}
+        self.outputs = {
+            'Out':
+            convert_float_to_uint16(self.value * np.ones_like(self.inputs["X"]))
+        }
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+
 class TestFillAnyLikeOpValue1(TestFillAnyLikeOp):
     def init(self):
         self.value = 1.0
diff --git a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
index 822c952893e..15071b2b6aa 100644
--- a/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fill_constant_op.py
@@ -83,6 +83,27 @@ class TestFillConstantOp4(OpTest):
         self.check_output()
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestFillConstantBF16Op(OpTest):
+    def setUp(self):
+        '''Test fill_constant op with specified value
+        '''
+        self.op_type = "fill_constant"
+        self.dtype = np.uint16
+        self.inputs = {}
+        self.attrs = {
+            'shape': [123, 92],
+            'value': 3.8,
+            'dtype': core.VarDesc.VarType.BF16
+        }
+        self.outputs = {'Out': convert_float_to_uint16(np.full((123, 92), 3.8))}
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+
 class TestFillConstantOpWithSelectedRows(unittest.TestCase):
     def check_with_place(self, place):
         scope = core.Scope()
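With the registrations above in place, both fill kernels are reachable from the Python API. A hypothetical smoke test follows: paddle.full and paddle.full_like are the public entry points behind fill_constant and fill_any_like, but the 'bfloat16' dtype string and a CUDA build with bf16 support are assumptions here, and some releases spell the dtype 'uint16' instead:

    import paddle

    paddle.set_device('gpu')
    # fill_constant: materialize a bf16 tensor with a scalar value.
    x = paddle.full([123, 92], 3.8, dtype='bfloat16')
    # fill_any_like: same shape and dtype as x, filled with zeros.
    y = paddle.full_like(x, 0.0)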
CUDA") +class TestGaussianRandomBF16Op(OpTest): + def setUp(self): + self.op_type = "gaussian_random" + self.set_attrs() + self.inputs = {} + self.use_mkldnn = False + self.attrs = { + "shape": [123, 92], + "mean": self.mean, + "std": self.std, + "seed": 10, + "dtype": paddle.fluid.core.VarDesc.VarType.BF16, + "use_mkldnn": self.use_mkldnn + } + paddle.seed(10) + + self.outputs = {'Out': np.zeros((123, 92), dtype='float32')} + + def set_attrs(self): + self.mean = 1.0 + self.std = 2. + + def test_check_output(self): + self.check_output_with_place_customized( + self.verify_output, place=core.CUDAPlace(0)) + + def verify_output(self, outs): + outs = convert_uint16_to_float(outs) + self.assertEqual(outs[0].shape, (123, 92)) + hist, _ = np.histogram(outs[0], range=(-3, 5)) + hist = hist.astype("float32") + hist /= float(outs[0].size) + data = np.random.normal(size=(123, 92), loc=1, scale=2) + hist2, _ = np.histogram(data, range=(-3, 5)) + hist2 = hist2.astype("float32") + hist2 /= float(outs[0].size) + self.assertTrue( + np.allclose( + hist, hist2, rtol=0, atol=0.05), + "hist: " + str(hist) + " hist2: " + str(hist2)) + + class TestMeanStdAreInt(TestGaussianRandomOp): def set_attrs(self): self.mean = 1 -- GitLab