未验证 提交 fb635089 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

optimize CUDA implementation of randint OP (#39952)

* change CUDA implementation of randint OP, move distribution common func to phi

* fix CI

* fix CI
上级 9af72957
......@@ -21,12 +21,11 @@ limitations under the License. */
#include <hiprand_kernel.h>
#endif
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/core/hostdevice.h"
#if defined(__NVCC__) || defined(__HIPCC__)
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
......@@ -40,7 +39,7 @@ limitations under the License. */
#endif
namespace phi {
namespace distribution {
namespace funcs {
/********************* Transformation Function **********************/
template <typename T>
......@@ -64,8 +63,9 @@ struct exponential_transform {
};
template <typename T>
struct uniform_transform {
explicit uniform_transform(T min, T max) : range_(max - min), min_(min) {}
struct uniform_real_transform {
explicit uniform_real_transform(T min, T max)
: range_(max - min), min_(min) {}
HOSTDEVICE inline T operator()(T val) const {
if (UNLIKELY(val == static_cast<T>(1.0))) {
......@@ -80,6 +80,22 @@ struct uniform_transform {
T min_;
};
// Maps a raw random integer of type R onto the half-open integer interval
// [min, max) and returns it converted to T.
//
// The mapping is `min + (rand % (max - min))`, so the result is only as
// uniform as a plain modulo reduction of `rand` allows.
//
// NOTE(review): callers are expected to pass min < max. With min == max the
// stored range is zero and operator() would evaluate a modulo by zero.
template <typename T, typename R>
struct uniform_int_transform {
  explicit uniform_int_transform(int min, int max) {
    // Subtract in unsigned arithmetic. `max - min` evaluated as `int`
    // overflows (undefined behavior) when the span exceeds INT_MAX, e.g.
    // min = INT_MIN, max = INT_MAX; the unsigned difference is well-defined
    // and is identical whenever the signed difference was representable.
    range_ = static_cast<uint32_t>(max) - static_cast<uint32_t>(min);
    min_ = min;
  }

  HOSTDEVICE inline T operator()(R rand) const {
    // Offset within the range, always < range_.
    const uint32_t offset = static_cast<uint32_t>(rand % range_);
    // Add the lower bound modulo 2^32 so a large offset combined with a
    // negative min_ cannot trigger signed overflow; converting the wrapped
    // sum back to int yields min_ + offset, which lies in [min, max).
    const uint32_t shifted = offset + static_cast<uint32_t>(min_);
    return static_cast<T>(static_cast<int>(shifted));
  }

 private:
  uint32_t range_;  // Width of the interval, max - min.
  int min_;         // Inclusive lower bound.
};
template <typename T>
struct normal_transform {
explicit normal_transform(T mean, T std) : mean_(mean), std_(std) {}
......@@ -120,6 +136,27 @@ struct uniform_distribution<double> {
static constexpr int kReturnsCount = 2;
};
// uint32_t specialization (cuRAND): each call forwards to curand4 on the
// given Philox state and hands back the resulting uint4 unchanged.
template <>
struct uniform_distribution<uint32_t> {
  // One call yields a uint4, i.e. four values.
  static constexpr int kReturnsCount = 4;

  __device__ inline uint4 operator()(curandStatePhilox4_32_10_t *state) const {
    const uint4 draws = curand4(state);
    return draws;
  }
};
// uint64_t specialization (cuRAND): one curand4 call supplies four
// components, which are combined pairwise into two 64-bit values.
template <>
struct uniform_distribution<uint64_t> {
  // One call yields a ulonglong2, i.e. two values.
  static constexpr int kReturnsCount = 2;

  __device__ inline ulonglong2 operator()(
      curandStatePhilox4_32_10_t *state) const {
    const uint4 words = curand4(state);
    // The first component of each pair becomes the upper 32 bits, the
    // second component the lower 32 bits.
    const uint64_t first = (static_cast<uint64_t>(words.x) << 32) | words.y;
    const uint64_t second = (static_cast<uint64_t>(words.z) << 32) | words.w;
    ulonglong2 packed;
    packed.x = first;
    packed.y = second;
    return packed;
  }
};
template <>
struct normal_distribution<float> {
__device__ inline float4 operator()(curandStatePhilox4_32_10_t *state) const {
......@@ -156,6 +193,27 @@ struct uniform_distribution<double> {
static constexpr int kReturnsCount = 2;
};
// uint32_t specialization (hipRAND): each call forwards to hiprand4 on the
// given Philox state and hands back the resulting uint4 unchanged.
template <>
struct uniform_distribution<uint32_t> {
  // One call yields a uint4, i.e. four values.
  static constexpr int kReturnsCount = 4;

  __device__ inline uint4 operator()(hiprandStatePhilox4_32_10_t *state) const {
    const uint4 draws = hiprand4(state);
    return draws;
  }
};
// uint64_t specialization (hipRAND): one hiprand4 call supplies four
// components, which are combined pairwise into two 64-bit values.
template <>
struct uniform_distribution<uint64_t> {
  // One call yields a ulonglong2, i.e. two values.
  static constexpr int kReturnsCount = 2;

  __device__ inline ulonglong2 operator()(
      hiprandStatePhilox4_32_10_t *state) const {
    const uint4 words = hiprand4(state);
    // The first component of each pair becomes the upper 32 bits, the
    // second component the lower 32 bits.
    const uint64_t first = (static_cast<uint64_t>(words.x) << 32) | words.y;
    const uint64_t second = (static_cast<uint64_t>(words.z) << 32) | words.w;
    ulonglong2 packed;
    packed.x = first;
    packed.y = second;
    return packed;
  }
};
template <>
struct normal_distribution<float> {
__device__ inline float4 operator()(
......@@ -209,19 +267,21 @@ __global__ void DistributionKernel(size_t size,
}
template <typename T, typename DistOp, typename TransformOp>
void distribution_and_transform(const GPUContext &dev_ctx,
void distribution_and_transform(const GPUContext &ctx,
DenseTensor *out,
DistOp dist,
TransformOp trans) {
T *out_data = dev_ctx.template Alloc<T>(out);
T *out_data = ctx.template Alloc<T>(out);
auto size = out->numel();
int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
auto gen_cuda = dev_ctx.GetGenerator();
if (size == 0) return;
auto gen_cuda = ctx.GetGenerator();
size_t block_size = 256;
size_t expect_grid_size = (size + block_size - 1) / block_size;
const auto &prop = backends::gpu::GetDeviceProperties(device_id);
int64_t device_id = ctx.GetPlace().GetDeviceId();
const auto &prop = phi::backends::gpu::GetDeviceProperties(device_id);
size_t max_grid_size = (prop.maxThreadsPerMultiProcessor / block_size) *
prop.multiProcessorCount;
size_t grid_size =
......@@ -237,13 +297,13 @@ void distribution_and_transform(const GPUContext &dev_ctx,
uint64_t seed = seed_offset.first;
uint64_t offset = seed_offset.second;
DistributionKernel<
T,
DistOp,
TransformOp><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
DistributionKernel<T,
DistOp,
TransformOp><<<grid_size, block_size, 0, ctx.stream()>>>(
size, seed, offset, dist, trans, out_data, total_thread);
}
#endif
} // namespace distribution
} // namespace funcs
} // namespace phi
......@@ -29,9 +29,9 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/bernoulli_kernel.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/operators/distribution_helper.h"
#include "paddle/fluid/platform/transform.h"
DECLARE_bool(use_curand);
......@@ -77,7 +77,7 @@ __global__ void bernoulli_cuda_kernel(
size_t total_thread = gridDim.x * blockDim.x;
for (size_t i = 4 * thread_idx; i < size; i += total_thread * 4) {
paddle::distribution::uniform_distribution<float> dist;
funcs::uniform_distribution<float> dist;
float4 rand = dist(&state);
#pragma unroll
for (size_t j = 0; j < 4; j++) {
......
......@@ -18,10 +18,13 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/memcpy.h"
DECLARE_bool(use_curand);
namespace phi {
template <typename T, typename Context>
......@@ -32,34 +35,39 @@ void RandintRawKernel(const Context& dev_ctx,
DataType dtype,
int seed,
DenseTensor* out) {
DenseTensor tmp;
tmp.Resize(phi::make_ddim(shape.GetData()));
T* tmp_data = dev_ctx.template HostAlloc<T>(&tmp);
out->Resize(tmp.dims());
out->Resize(phi::make_ddim(shape.GetData()));
T* data = dev_ctx.template Alloc<T>(out);
std::shared_ptr<std::mt19937_64> engine;
if (seed) {
engine = std::make_shared<std::mt19937_64>();
engine->seed(seed);
if (FLAGS_use_curand) {
funcs::uniform_distribution<uint32_t> dist;
funcs::uniform_int_transform<T, uint32_t> trans(low, high);
funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
} else {
engine = dev_ctx.GetHostGenerator()->GetCPUEngine();
}
DenseTensor tmp;
tmp.Resize(phi::make_ddim(shape.GetData()));
T* tmp_data = dev_ctx.template HostAlloc<T>(&tmp);
std::uniform_int_distribution<T> dist(low, high - 1);
auto numel = out->numel();
for (int64_t i = 0; i < numel; ++i) {
tmp_data[i] = dist(*engine);
}
std::shared_ptr<std::mt19937_64> engine;
if (seed) {
engine = std::make_shared<std::mt19937_64>();
engine->seed(seed);
} else {
engine = dev_ctx.GetHostGenerator()->GetCPUEngine();
}
std::uniform_int_distribution<T> dist(low, high - 1);
auto numel = out->numel();
for (int64_t i = 0; i < numel; ++i) {
tmp_data[i] = dist(*engine);
}
paddle::memory::Copy<phi::GPUPlace, phi::Place>(
out->place(),
data,
tmp.place(),
tmp_data,
numel * paddle::experimental::SizeOf(out->dtype()),
0);
paddle::memory::Copy<phi::GPUPlace, phi::Place>(
out->place(),
data,
tmp.place(),
tmp_data,
numel * paddle::experimental::SizeOf(out->dtype()),
0);
}
}
template <typename T, typename Context>
......
......@@ -116,9 +116,9 @@ void UniformRandomRawKernel(const Context& dev_ctx,
if (generator->GetIsInitPy() && seed_flag) {
if (FLAGS_use_curand) {
using MT = typename kps::details::MPTypeTrait<T>::Type;
distribution::uniform_distribution<MT> dist;
distribution::uniform_transform<MT> trans(min, max);
distribution::distribution_and_transform<T>(dev_ctx, out, dist, trans);
funcs::uniform_distribution<MT> dist;
funcs::uniform_real_transform<MT> trans(min, max);
funcs::distribution_and_transform<T>(dev_ctx, out, dist, trans);
} else {
auto seed_offset = generator->IncrementOffset(1);
int64_t gen_offset = size * seed_offset.second;
......
......@@ -93,11 +93,11 @@ class TestGeneratorSeed(unittest.TestCase):
fluid.enable_dygraph()
gen = paddle.seed(12312321111)
paddle.seed(12312321111)
x = paddle.randint(low=10, shape=[10], dtype="int32")
st1 = gen.get_state()
st1 = paddle.get_cuda_rng_state()
x1 = paddle.randint(low=10, shape=[10], dtype="int32")
gen.set_state(st1)
paddle.set_cuda_rng_state(st1)
x2 = paddle.randint(low=10, shape=[10], dtype="int32")
paddle.seed(12312321111)
x3 = paddle.randint(low=10, shape=[10], dtype="int32")
......
......@@ -20,6 +20,9 @@ from op_test import OpTest
import paddle
from paddle.fluid import core
from paddle.static import program_guard, Program
import os
paddle.enable_static()
def output_hist(out):
......@@ -156,5 +159,47 @@ class TestRandintImperative(unittest.TestCase):
paddle.enable_static()
class TestRandomValue(unittest.TestCase):
def test_fixed_random_number(self):
# Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
if not paddle.is_compiled_with_cuda():
return
# Different GPU generatte different random value. Only test V100 here.
if not "V100" in paddle.device.cuda.get_device_name():
return
if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
return
print("Test Fixed Random number on GPU------>")
paddle.disable_static()
paddle.set_device('gpu')
paddle.seed(100)
x = paddle.randint(
-10000, 10000, [32, 3, 1024, 1024], dtype='int32').numpy()
self.assertTrue(x.mean(), -0.7517569760481516)
self.assertTrue(x.std(), 5773.696619107639)
expect = [2535, 2109, 5916, -5011, -261]
self.assertTrue(np.array_equal(x[10, 0, 100, 100:105], expect))
expect = [3465, 7206, -8660, -9628, -6574]
self.assertTrue(np.array_equal(x[20, 1, 600, 600:605], expect))
expect = [881, 1560, 1100, 9664, 1669]
self.assertTrue(np.array_equal(x[30, 2, 1000, 1000:1005], expect))
x = paddle.randint(
-10000, 10000, [32, 3, 1024, 1024], dtype='int64').numpy()
self.assertTrue(x.mean(), -1.461287518342336)
self.assertTrue(x.std(), 5773.023477548159)
expect = [7213, -9597, 754, 8129, -1158]
self.assertTrue(np.array_equal(x[10, 0, 100, 100:105], expect))
expect = [-7159, 8054, 7675, 6980, 8506]
self.assertTrue(np.array_equal(x[20, 1, 600, 600:605], expect))
expect = [3581, 3420, -8027, -5237, -2436]
self.assertTrue(np.array_equal(x[30, 2, 1000, 1000:1005], expect))
paddle.enable_static()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册