From b9675acc9d4326b73f5b3167265a1d3f6e98dac9 Mon Sep 17 00:00:00 2001
From: Zhou Wei <1183042833@qq.com>
Date: Wed, 23 Feb 2022 16:55:49 +0800
Subject: [PATCH] change CUDA implementation of bernoulli OP (#39732)

* change CUDA implementation of bernoulli OP

* fix CI
---
 paddle/fluid/operators/distribution_helper.h  |  9 +-
 paddle/phi/backends/gpu/gpu_launch_config.h   |  1 +
 paddle/phi/kernels/gpu/bernoulli_kernel.cu    | 82 +++++++++++++++----
 .../tests/unittests/test_bernoulli_op.py      | 39 +++++++++
 .../tests/unittests/test_exponential_op.py    | 11 +--
 .../unittests/test_gaussian_random_op.py      |  9 +-
 .../fluid/tests/unittests/test_poisson_op.py  |  7 +-
 .../tests/unittests/test_uniform_random_op.py |  9 +-
 8 files changed, 135 insertions(+), 32 deletions(-)

diff --git a/paddle/fluid/operators/distribution_helper.h b/paddle/fluid/operators/distribution_helper.h
index ca6bcb1147a..c13bf687af2 100644
--- a/paddle/fluid/operators/distribution_helper.h
+++ b/paddle/fluid/operators/distribution_helper.h
@@ -180,8 +180,8 @@ struct normal_distribution<double> {
 /******** Launch GPU function of distribution and transformation *********/
 template <typename T, typename DistOp, typename TransformOp>
 __global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset,
-                                   DistOp dist, TransformOp trans,
-                                   T *out_data) {
+                                   DistOp dist, TransformOp trans, T *out_data,
+                                   size_t stride) {
   size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
   static constexpr int kCount = DistOp::kReturnsCount;
 #if defined(__NVCC__)
@@ -201,7 +201,8 @@ __global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset,
     kps::ElementwiseUnary<T, T, kCount, 1, 1, TransformOp>(&result[0], &args[0],
                                                            trans);
     kps::WriteData<T, T, kCount, 1, 1, true>(out_data + i, &result[0], size - i,
-                                             1, total_thread, 1);
+                                             1, stride, 1);
+    __syncthreads();
   }
 }
 
@@ -234,7 +235,7 @@ void distribution_and_transform(const platform::CUDADeviceContext &dev_ctx,
 
   DistributionKernel<
       T, DistOp, TransformOp><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
-      size, seed, offset, dist, trans, out_data);
+      size, seed, offset, dist, trans, out_data, total_thread);
 }
 
 #endif
diff --git a/paddle/phi/backends/gpu/gpu_launch_config.h b/paddle/phi/backends/gpu/gpu_launch_config.h
index 5aa569e0197..e45b4651225 100644
--- a/paddle/phi/backends/gpu/gpu_launch_config.h
+++ b/paddle/phi/backends/gpu/gpu_launch_config.h
@@ -29,6 +29,7 @@
 #include <string>
 #include <vector>
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/enforce.h"
 
 #ifdef __HIPCC__
 // HIP results in error or nan if > 256
diff --git a/paddle/phi/kernels/gpu/bernoulli_kernel.cu b/paddle/phi/kernels/gpu/bernoulli_kernel.cu
index 6127bceef50..ac69d398b8a 100644
--- a/paddle/phi/kernels/gpu/bernoulli_kernel.cu
+++ b/paddle/phi/kernels/gpu/bernoulli_kernel.cu
@@ -12,19 +12,30 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <thrust/execution_policy.h>
 #include <thrust/random.h>
 #include <thrust/transform.h>
+#ifdef __NVCC__
+#include <curand_kernel.h>
+#endif
+#ifdef __HIPCC__
+#include <hiprand_kernel.h>
+#endif
+
 #include <algorithm>
 #include <vector>
+
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/bernoulli_kernel.h"
 
 // See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/operators/distribution_helper.h"
 #include "paddle/fluid/platform/transform.h"
 
+DECLARE_bool(use_curand);
+
 namespace phi {
 
 template <typename T>
@@ -49,26 +60,69 @@ struct BernoulliCudaFunctor {
   }
 };
 
+// 'curand_uniform4/hiprand_uniform4' generate 4 random number each time
+template <typename T>
+__global__ void bernoulli_cuda_kernel(
+    size_t size, uint64_t seed, uint64_t offset, const T* x_data, T* out_data) {
+  size_t thread_idx =
+      static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
+
+#if defined(__NVCC__)
+  curandStatePhilox4_32_10_t state;
+  curand_init(seed, thread_idx, offset, &state);
+#else
+  hiprandStatePhilox4_32_10_t state;
+  hiprand_init(seed, thread_idx, offset, &state);
+#endif
+
+  size_t total_thread = gridDim.x * blockDim.x;
+  for (size_t i = 4 * thread_idx; i < size; i += total_thread * 4) {
+    paddle::distribution::uniform_distribution<float> dist;
+    float4 rand = dist(&state);
+#pragma unroll
+    for (size_t j = 0; j < 4; j++) {
+      size_t idx = i + j;
+      if (idx < size) {
+        out_data[idx] = static_cast<T>((&rand.x)[j] <= x_data[idx]);
+      }
+    }
+  }
+}
+
 template <typename T, typename Context>
 void BernoulliKernel(const Context& ctx,
                      const DenseTensor& x,
                      DenseTensor* out) {
-  auto numel = x.numel();
-  auto* x_data = x.data<T>();
+  const T* x_data = x.data<T>();
   T* out_data = ctx.template Alloc<T>(out);
+  auto numel = x.numel();
 
   auto gen_cuda = ctx.GetGenerator();
-  auto seed_offset = gen_cuda->IncrementOffset(1);
-  int64_t gen_offset = numel * seed_offset.second;
-  paddle::platform::Transform<phi::GPUContext> trans;
-  thrust::counting_iterator<int64_t> index_sequence_begin(0);
-  trans(ctx,
-        index_sequence_begin,
-        index_sequence_begin + numel,
-        x_data,
-        out_data,
-        BernoulliCudaFunctor<T>(static_cast<int64_t>(seed_offset.first),
-                                static_cast<int64_t>(gen_offset)));
+
+  if (FLAGS_use_curand) {
+    auto seed_offset = gen_cuda->IncrementOffset(12);
+    uint64_t seed = seed_offset.first;
+    uint64_t offset = seed_offset.second;
+
+    auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 4);
+    size_t grid_size = gpu_config.GetGridSize();
+    size_t block_size = gpu_config.GetBlockSize();
+
+    bernoulli_cuda_kernel<<<grid_size, block_size, 0, ctx.stream()>>>(
+        numel, seed, offset, x_data, out_data);
+  } else {
+    auto seed_offset = gen_cuda->IncrementOffset(1);
+    int64_t gen_offset = numel * seed_offset.second;
+    paddle::platform::Transform<phi::GPUContext> trans;
+    thrust::counting_iterator<int64_t> index_sequence_begin(0);
+    trans(ctx,
+          index_sequence_begin,
+          index_sequence_begin + numel,
+          x_data,
+          out_data,
+          BernoulliCudaFunctor<T>(static_cast<int64_t>(seed_offset.first),
+                                  static_cast<int64_t>(gen_offset)));
+  }
 }
 
 }  // namespace phi
diff --git a/python/paddle/fluid/tests/unittests/test_bernoulli_op.py b/python/paddle/fluid/tests/unittests/test_bernoulli_op.py
index 471caeb77bf..426d5d463f4 100644
--- a/python/paddle/fluid/tests/unittests/test_bernoulli_op.py
+++ b/python/paddle/fluid/tests/unittests/test_bernoulli_op.py
@@ -18,6 +18,7 @@ import unittest
 import paddle
 from op_test import OpTest
 import numpy as np
+import os
 
 
 def output_hist(out):
@@ -68,5 +69,43 @@ class TestBernoulliApi(unittest.TestCase):
                 hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
 
 
+class TestRandomValue(unittest.TestCase):
+    def test_fixed_random_number(self):
+        # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
+        if not paddle.is_compiled_with_cuda():
+            return
+
+        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
+            return
+
+        print("Test Fixed Random number on GPU------>")
+        paddle.disable_static()
+        paddle.set_device('gpu')
+        paddle.seed(100)
+        np.random.seed(100)
+
+        x_np = np.random.rand(32, 1024, 1024)
+
+        x = paddle.to_tensor(x_np, dtype='float64')
+        y = paddle.bernoulli(x).numpy()
+        index0, index1, index2 = np.nonzero(y)
+        self.assertEqual(np.sum(index0), 260028995)
+        self.assertEqual(np.sum(index1), 8582429431)
+        self.assertEqual(np.sum(index2), 8581445798)
+        expect = [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.]
+        self.assertTrue(np.array_equal(y[16, 500, 500:510], expect))
+
+        x = paddle.to_tensor(x_np, dtype='float32')
+        y = paddle.bernoulli(x).numpy()
+        index0, index1, index2 = np.nonzero(y)
+        self.assertEqual(np.sum(index0), 260092343)
+        self.assertEqual(np.sum(index1), 8583509076)
+        self.assertEqual(np.sum(index2), 8582778540)
+        expect = [0., 0., 1., 1., 1., 1., 0., 1., 1., 1.]
+        self.assertTrue(np.array_equal(y[16, 500, 500:510], expect))
+
+        paddle.enable_static()
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_exponential_op.py b/python/paddle/fluid/tests/unittests/test_exponential_op.py
index 7d43ebadf41..ccbc0a16763 100644
--- a/python/paddle/fluid/tests/unittests/test_exponential_op.py
+++ b/python/paddle/fluid/tests/unittests/test_exponential_op.py
@@ -16,6 +16,7 @@ import unittest
 import paddle
 import numpy as np
 from op_test import OpTest
+import os
 
 paddle.enable_static()
 paddle.seed(100)
@@ -90,18 +91,18 @@ class TestExponentialAPI(unittest.TestCase):
         self.assertTrue(np.min(x.numpy()) >= 0)
         paddle.enable_static()
 
-    # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
     def test_fixed_random_number(self):
+        # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
         if not paddle.is_compiled_with_cuda():
             return
 
-        # Note(zhouwei): The Number of threads is determined by
-        # 'multiProcessorCount * maxThreadsPerMultiProcessor'. So, different
-        # GPU have different number of threads, which result in different
-        # random value. Only test on V100 GPU here.
+        # Different GPU generate different random value. Only test V100 here.
         if not "V100" in paddle.device.cuda.get_device_name():
             return
 
+        if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
+            return
+
         print("Test Fixed Random number on V100 GPU------>")
         paddle.disable_static()
         paddle.set_device('gpu')
diff --git a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
index 43bcc3438ee..31caf4bd6be 100644
--- a/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
+++ b/python/paddle/fluid/tests/unittests/test_gaussian_random_op.py
@@ -14,6 +14,7 @@
 
 from __future__ import print_function
 
+import os
 import unittest
 import numpy as np
 import paddle
@@ -293,13 +294,13 @@ class TestRandomValue(unittest.TestCase):
         if not paddle.is_compiled_with_cuda():
             return
 
-        # Note(zhouwei): The Number of threads is determined by
-        # 'multiProcessorCount * maxThreadsPerMultiProcessor'. So, different
-        # GPU have different number of threads, which result in different
-        # random value. Only test on V100 GPU here.
+        # Different GPU generate different random value. Only test V100 here.
if not "V100" in paddle.device.cuda.get_device_name(): return + if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): + return + def _check_random_value(dtype, expect, expect_mean, expect_std): x = paddle.randn([32, 3, 1024, 1024], dtype=dtype) actual = x.numpy() diff --git a/python/paddle/fluid/tests/unittests/test_poisson_op.py b/python/paddle/fluid/tests/unittests/test_poisson_op.py index dc4dc3284e9..2123d4e0e7e 100644 --- a/python/paddle/fluid/tests/unittests/test_poisson_op.py +++ b/python/paddle/fluid/tests/unittests/test_poisson_op.py @@ -17,6 +17,7 @@ import paddle import numpy as np from op_test import OpTest import math +import os paddle.enable_static() paddle.seed(100) @@ -101,11 +102,15 @@ class TestPoissonAPI(unittest.TestCase): self.assertTrue(np.min(y.numpy()) >= 0) paddle.enable_static() - # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t' def test_fixed_random_number(self): + # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t' if not paddle.is_compiled_with_cuda(): return + if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): + return + + print("Test Fixed Random number on GPU------>") paddle.disable_static() paddle.set_device('gpu') paddle.seed(2021) diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py index a84c3b20da2..41b6ed36d65 100644 --- a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py @@ -15,6 +15,7 @@ from __future__ import print_function import sys +import os import subprocess import unittest import numpy as np @@ -568,13 +569,13 @@ class TestRandomValue(unittest.TestCase): if not paddle.is_compiled_with_cuda(): return - # Note(zhouwei): The Number of threads is determined by - # 'multiProcessorCount * maxThreadsPerMultiProcessor'. So, different - # GPU have different number of threads, which result in different - # random value. Only test on V100 GPU here. + # Different GPU generate different random value. Only test V100 here. if not "V100" in paddle.device.cuda.get_device_name(): return + if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None): + return + def _check_random_value(dtype, expect, expect_mean, expect_std): x = paddle.rand([32, 3, 1024, 1024], dtype=dtype) actual = x.numpy() -- GitLab