未验证 提交 b9675acc 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

change CUDA implementaion of bernoulli OP (#39732)

* change CUDA implementaion of bernoulli OP

* fix CI
上级 69a04209
......@@ -180,8 +180,8 @@ struct normal_distribution<double> {
/******** Launch GPU function of distribution and transformation *********/
template <typename T, typename DistOp, typename TransformOp>
__global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset,
DistOp dist, TransformOp trans,
T *out_data) {
DistOp dist, TransformOp trans, T *out_data,
size_t stride) {
size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
static constexpr int kCount = DistOp::kReturnsCount;
#if defined(__NVCC__)
......@@ -201,7 +201,8 @@ __global__ void DistributionKernel(size_t size, uint64_t seed, uint64_t offset,
kps::ElementwiseUnary<T, T, kCount, 1, 1, TransformOp>(&result[0], &args[0],
trans);
kps::WriteData<T, T, kCount, 1, 1, true>(out_data + i, &result[0], size - i,
1, total_thread, 1);
1, stride, 1);
__syncthreads();
}
}
......@@ -234,7 +235,7 @@ void distribution_and_transform(const platform::CUDADeviceContext &dev_ctx,
DistributionKernel<
T, DistOp, TransformOp><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
size, seed, offset, dist, trans, out_data);
size, seed, offset, dist, trans, out_data, total_thread);
}
#endif
......
......@@ -29,6 +29,7 @@
#include <string>
#include <vector>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/enforce.h"
#ifdef __HIPCC__
// HIP results in error or nan if > 256
......
......@@ -12,19 +12,30 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/execution_policy.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#ifdef __NVCC__
#include <curand_kernel.h>
#endif
#ifdef __HIPCC__
#include <hiprand_kernel.h>
#endif
#include <algorithm>
#include <vector>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/bernoulli_kernel.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/operators/distribution_helper.h"
#include "paddle/fluid/platform/transform.h"
DECLARE_bool(use_curand);
namespace phi {
template <typename T>
......@@ -49,26 +60,69 @@ struct BernoulliCudaFunctor {
}
};
// 'curand_uniform4/hiprand_uniform4' generate 4 random number each time
template <typename T>
__global__ void bernoulli_cuda_kernel(
size_t size, uint64_t seed, uint64_t offset, const T* x_data, T* out_data) {
size_t thread_idx =
static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
#if defined(__NVCC__)
curandStatePhilox4_32_10_t state;
curand_init(seed, thread_idx, offset, &state);
#else
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed, thread_idx, offset, &state);
#endif
size_t total_thread = gridDim.x * blockDim.x;
for (size_t i = 4 * thread_idx; i < size; i += total_thread * 4) {
paddle::distribution::uniform_distribution<float> dist;
float4 rand = dist(&state);
#pragma unroll
for (size_t j = 0; j < 4; j++) {
size_t idx = i + j;
if (idx < size) {
out_data[idx] = static_cast<T>((&rand.x)[j] <= x_data[idx]);
}
}
}
}
template <typename T, typename Context>
void BernoulliKernel(const Context& ctx,
const DenseTensor& x,
DenseTensor* out) {
auto numel = x.numel();
auto* x_data = x.data<T>();
const T* x_data = x.data<T>();
T* out_data = ctx.template Alloc<T>(out);
auto numel = x.numel();
auto gen_cuda = ctx.GetGenerator();
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = numel * seed_offset.second;
paddle::platform::Transform<phi::GPUContext> trans;
thrust::counting_iterator<int64_t> index_sequence_begin(0);
trans(ctx,
index_sequence_begin,
index_sequence_begin + numel,
x_data,
out_data,
BernoulliCudaFunctor<T>(static_cast<int64_t>(seed_offset.first),
static_cast<int64_t>(gen_offset)));
if (FLAGS_use_curand) {
auto seed_offset = gen_cuda->IncrementOffset(12);
uint64_t seed = seed_offset.first;
uint64_t offset = seed_offset.second;
auto gpu_config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 4);
size_t grid_size = gpu_config.GetGridSize();
size_t block_size = gpu_config.GetBlockSize();
bernoulli_cuda_kernel<<<grid_size, block_size, 0, ctx.stream()>>>(
numel, seed, offset, x_data, out_data);
} else {
auto seed_offset = gen_cuda->IncrementOffset(1);
int64_t gen_offset = numel * seed_offset.second;
paddle::platform::Transform<phi::GPUContext> trans;
thrust::counting_iterator<int64_t> index_sequence_begin(0);
trans(ctx,
index_sequence_begin,
index_sequence_begin + numel,
x_data,
out_data,
BernoulliCudaFunctor<T>(static_cast<int64_t>(seed_offset.first),
static_cast<int64_t>(gen_offset)));
}
}
} // namespace phi
......
......@@ -18,6 +18,7 @@ import unittest
import paddle
from op_test import OpTest
import numpy as np
import os
def output_hist(out):
......@@ -68,5 +69,43 @@ class TestBernoulliApi(unittest.TestCase):
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
class TestRandomValue(unittest.TestCase):
def test_fixed_random_number(self):
# Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
if not paddle.is_compiled_with_cuda():
return
if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
return
print("Test Fixed Random number on GPU------>")
paddle.disable_static()
paddle.set_device('gpu')
paddle.seed(100)
np.random.seed(100)
x_np = np.random.rand(32, 1024, 1024)
x = paddle.to_tensor(x_np, dtype='float64')
y = paddle.bernoulli(x).numpy()
index0, index1, index2 = np.nonzero(y)
self.assertEqual(np.sum(index0), 260028995)
self.assertEqual(np.sum(index1), 8582429431)
self.assertEqual(np.sum(index2), 8581445798)
expect = [0., 0., 0., 0., 0., 0., 0., 1., 1., 1.]
self.assertTrue(np.array_equal(y[16, 500, 500:510], expect))
x = paddle.to_tensor(x_np, dtype='float32')
y = paddle.bernoulli(x).numpy()
index0, index1, index2 = np.nonzero(y)
self.assertEqual(np.sum(index0), 260092343)
self.assertEqual(np.sum(index1), 8583509076)
self.assertEqual(np.sum(index2), 8582778540)
expect = [0., 0., 1., 1., 1., 1., 0., 1., 1., 1.]
self.assertTrue(np.array_equal(y[16, 500, 500:510], expect))
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
......@@ -16,6 +16,7 @@ import unittest
import paddle
import numpy as np
from op_test import OpTest
import os
paddle.enable_static()
paddle.seed(100)
......@@ -90,18 +91,18 @@ class TestExponentialAPI(unittest.TestCase):
self.assertTrue(np.min(x.numpy()) >= 0)
paddle.enable_static()
# Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
def test_fixed_random_number(self):
# Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
if not paddle.is_compiled_with_cuda():
return
# Note(zhouwei): The Number of threads is determined by
# 'multiProcessorCount * maxThreadsPerMultiProcessor'. So, different
# GPU have different number of threads, which result in different
# random value. Only test on V100 GPU here.
# Different GPU generatte different random value. Only test V100 here.
if not "V100" in paddle.device.cuda.get_device_name():
return
if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
return
print("Test Fixed Random number on V100 GPU------>")
paddle.disable_static()
paddle.set_device('gpu')
......
......@@ -14,6 +14,7 @@
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle
......@@ -293,13 +294,13 @@ class TestRandomValue(unittest.TestCase):
if not paddle.is_compiled_with_cuda():
return
# Note(zhouwei): The Number of threads is determined by
# 'multiProcessorCount * maxThreadsPerMultiProcessor'. So, different
# GPU have different number of threads, which result in different
# random value. Only test on V100 GPU here.
# Different GPU generatte different random value. Only test V100 here.
if not "V100" in paddle.device.cuda.get_device_name():
return
if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
return
def _check_random_value(dtype, expect, expect_mean, expect_std):
x = paddle.randn([32, 3, 1024, 1024], dtype=dtype)
actual = x.numpy()
......
......@@ -17,6 +17,7 @@ import paddle
import numpy as np
from op_test import OpTest
import math
import os
paddle.enable_static()
paddle.seed(100)
......@@ -101,11 +102,15 @@ class TestPoissonAPI(unittest.TestCase):
self.assertTrue(np.min(y.numpy()) >= 0)
paddle.enable_static()
# Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
def test_fixed_random_number(self):
# Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
if not paddle.is_compiled_with_cuda():
return
if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
return
print("Test Fixed Random number on GPU------>")
paddle.disable_static()
paddle.set_device('gpu')
paddle.seed(2021)
......
......@@ -15,6 +15,7 @@
from __future__ import print_function
import sys
import os
import subprocess
import unittest
import numpy as np
......@@ -568,13 +569,13 @@ class TestRandomValue(unittest.TestCase):
if not paddle.is_compiled_with_cuda():
return
# Note(zhouwei): The Number of threads is determined by
# 'multiProcessorCount * maxThreadsPerMultiProcessor'. So, different
# GPU have different number of threads, which result in different
# random value. Only test on V100 GPU here.
# Different GPU generate different random value. Only test V100 here.
if not "V100" in paddle.device.cuda.get_device_name():
return
if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
return
def _check_random_value(dtype, expect, expect_mean, expect_std):
x = paddle.rand([32, 3, 1024, 1024], dtype=dtype)
actual = x.numpy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册