Unverified commit 58970995 authored by zhouweiwei2014, committed by GitHub

change CUDA implementation of multinomial OP (#40752)

Parent 95d3ebc8
@@ -50,11 +50,15 @@ struct exponential_transform {
HOSTDEVICE inline T operator()(T val) const {
#if defined(__NVCC__) || defined(__HIPCC__)
if (std::is_same<T, double>::value) {
return static_cast<T>(-1.0) / lambda_ * log(val);
} else {
return static_cast<T>(-1.0) / lambda_ * __logf(val);
T log = -std::numeric_limits<T>::epsilon() / 2;
if (val < static_cast<T>(1.) - std::numeric_limits<T>::epsilon() / 2) {
if (std::is_same<T, double>::value) {
log = logf(val);
} else {
log = __logf(val);
}
}
return static_cast<T>(-1.0) / lambda_ * log;
#else
return static_cast<T>(-1.0) / lambda_ * std::log(static_cast<T>(1.0) - val);
#endif
......
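The change above hardens exponential_transform against val == 1.0: curand_uniform draws from the half-open interval (0, 1], so log(val) can be exactly zero and the transform would return a zero exponential sample, which the multinomial kernel below then divides weights by. The new code clamps the logarithm to -epsilon/2 whenever val is within half an epsilon of 1, so the result stays strictly positive. A minimal host-side sketch of the clamped transform, assuming lambda = 1 and T = float (illustrative only, not the Paddle functor):

#include <cmath>
#include <cstdio>
#include <limits>

float ClampedExponentialTransform(float val) {
  // Mirror of the device code above: keep log(val) away from 0 so that
  // val == 1.0f still produces a small positive exponential sample.
  float log_v = -std::numeric_limits<float>::epsilon() / 2;
  if (val < 1.0f - std::numeric_limits<float>::epsilon() / 2) {
    log_v = std::log(val);
  }
  return -1.0f * log_v;  // lambda_ assumed to be 1
}

int main() {
  std::printf("%g\n", ClampedExponentialTransform(1.0f));  // ~6e-8, not 0
  std::printf("%g\n", ClampedExponentialTransform(0.5f));  // ~0.693
}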
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include <thrust/device_ptr.h>
#include <thrust/iterator/reverse_iterator.h>
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/for_range.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/malloc.h"
namespace phi {
namespace funcs {
template <typename T>
struct IsComplex : public std::false_type {};
template <>
struct IsComplex<::phi::dtype::complex<float>> : public std::true_type {};
template <>
struct IsComplex<::phi::dtype::complex<double>> : public std::true_type {};
template <typename InputIterator, typename OutputIterator, typename BinaryOp>
static void CubInclusiveScan(InputIterator x_iter,
OutputIterator y_iter,
size_t n,
BinaryOp op,
const phi::GPUContext &dev_ctx) {
paddle::memory::allocation::AllocationPtr allocation;
void *temp_storage = nullptr;
size_t temp_storage_bytes = 0;
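// Two-pass CUB idiom: the first iteration passes temp_storage == nullptr so
// InclusiveScan only reports temp_storage_bytes; the second iteration runs
// the actual scan with the allocated workspace.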
for (size_t i = 0; i < 2; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(
cub::DeviceScan::InclusiveScan(temp_storage,
temp_storage_bytes,
x_iter,
y_iter,
op,
static_cast<int>(n),
dev_ctx.stream()));
if (i == 0 && temp_storage_bytes > 0) {
allocation =
paddle::memory::Alloc(dev_ctx.GetPlace(), temp_storage_bytes);
temp_storage = allocation->ptr();
}
}
}
template <typename T>
static auto MakeThrustReverseIterator(T *x) {
return thrust::reverse_iterator<thrust::device_ptr<T>>(
thrust::device_pointer_cast(x));
}
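// One thread per (outer, inner) index pair: walk the mid_dim elements of that
// logical row with stride inner_dim, accumulate with op_, and write the
// running value, from the back when kReverse is true.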
template <typename T, typename BinaryOp, bool kReverse>
struct InclusiveScanOuterOrMidDimFunctor {
HOSTDEVICE InclusiveScanOuterOrMidDimFunctor(
const T *x, T *y, size_t mid_dim, size_t inner_dim, T init, BinaryOp op)
: x_(x),
y_(y),
mid_dim_(mid_dim),
inner_dim_(inner_dim),
init_(init),
op_(op) {}
HOSTDEVICE void operator()(size_t idx) const {
auto outer_idx = idx / inner_dim_;
auto inner_idx = idx % inner_dim_;
if (kReverse) {
idx = outer_idx * mid_dim_ * inner_dim_ + (mid_dim_ - 1) * inner_dim_ +
inner_idx;
} else {
idx = outer_idx * mid_dim_ * inner_dim_ + inner_idx;
}
auto x_ptr = x_ + idx;
auto y_ptr = y_ + idx;
T acc_value = init_;
for (size_t i = 0; i < mid_dim_; ++i) {
acc_value = op_(acc_value, *x_ptr);
*y_ptr = acc_value;
if (kReverse) {
x_ptr -= inner_dim_;
y_ptr -= inner_dim_;
} else {
x_ptr += inner_dim_;
y_ptr += inner_dim_;
}
}
}
private:
const T *x_;
T *y_;
size_t mid_dim_;
size_t inner_dim_;
T init_;
BinaryOp op_;
};
template <typename T,
typename BinaryOp,
size_t kThreadNumX,
size_t kThreadNumY,
bool kReverse>
static __global__ void InclusiveScanInnerDimCUDAKernel(
const T *x, T *y, size_t num_rows, size_t row_size, T init, BinaryOp op) {
using RealT = phi::dtype::Real<T>;
constexpr auto kSharedBufferSize =
IsComplex<T>::value ? 4 * kThreadNumX : 2 * kThreadNumX;
__shared__ RealT sbuf[kThreadNumY][kSharedBufferSize];
T *row_buf = reinterpret_cast<T *>(sbuf[threadIdx.y]);
size_t block_row = static_cast<size_t>(blockIdx.x * kThreadNumY);
size_t block_row_stride = static_cast<size_t>(gridDim.x * kThreadNumY);
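// Each block scans kThreadNumY rows; a row is processed in tiles of
// 2 * kThreadNumX elements staged in shared memory, with block_total carrying
// the running accumulation across consecutive tiles of the same row.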
for (; block_row < num_rows; block_row += block_row_stride) {
size_t row = block_row + threadIdx.y;
T block_total = init;
const T *row_x = x + row * row_size;
T *row_y = y + row * row_size;
for (size_t block_col = 0; block_col < row_size;
block_col += 2 * kThreadNumX) {
size_t col1, col2;
if (kReverse) {
col1 = row_size - 1 - block_col - threadIdx.x;
col2 = col1 - kThreadNumX;
} else {
col1 = block_col + threadIdx.x;
col2 = col1 + kThreadNumX;
}
if (row < num_rows) {
if (col1 < row_size) {
row_buf[threadIdx.x] = row_x[col1];
} else {
row_buf[threadIdx.x] = init;
}
if (col2 < row_size) {
row_buf[kThreadNumX + threadIdx.x] = row_x[col2];
} else {
row_buf[kThreadNumX + threadIdx.x] = init;
}
if (threadIdx.x == 0) {
row_buf[0] = op(row_buf[0], block_total);
}
}
__syncthreads();
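// Up-sweep (reduce) phase of the work-efficient in-block scan over the
// 2 * kThreadNumX staged elements.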
for (size_t s = kThreadNumX, d = 1; s >= 1; s >>= 1, d <<= 1) {
if (row < num_rows && threadIdx.x < s) {
size_t offset = (2 * threadIdx.x + 1) * d - 1;
row_buf[offset + d] = op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
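// Down-sweep phase: propagate the partial results so row_buf holds the
// inclusive scan of the current tile.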
for (size_t s = 2, d = kThreadNumX / 2; d >= 1; s <<= 1, d >>= 1) {
if (row < num_rows && threadIdx.x < s - 1) {
size_t offset = 2 * (threadIdx.x + 1) * d - 1;
row_buf[offset + d] = op(row_buf[offset], row_buf[offset + d]);
}
__syncthreads();
}
if (row < num_rows) {
if (col1 < row_size) row_y[col1] = row_buf[threadIdx.x];
if (col2 < row_size) row_y[col2] = row_buf[kThreadNumX + threadIdx.x];
}
block_total = row_buf[2 * kThreadNumX - 1];
__syncthreads();
}
}
}
template <typename T, typename BinaryOp>
static void InclusiveScanInnerDim(const T *x,
T *y,
size_t outer_dim,
size_t inner_dim,
T init,
BinaryOp op,
bool reverse,
const phi::GPUContext &dev_ctx) {
constexpr size_t kThreadNumX = 16;
constexpr size_t kThreadNumY = 32;
size_t grid_dim = (outer_dim + kThreadNumY - 1) / kThreadNumY;
grid_dim = std::min<size_t>(grid_dim, dev_ctx.GetCUDAMaxGridDimSize()[0]);
dim3 thread_dims(kThreadNumX, kThreadNumY);
if (reverse) {
InclusiveScanInnerDimCUDAKernel<
T,
BinaryOp,
kThreadNumX,
kThreadNumY,
/*kReverse=*/true><<<grid_dim, thread_dims, 0, dev_ctx.stream()>>>(
x, y, outer_dim, inner_dim, init, op);
} else {
InclusiveScanInnerDimCUDAKernel<
T,
BinaryOp,
kThreadNumX,
kThreadNumY,
/*kReverse=*/false><<<grid_dim, thread_dims, 0, dev_ctx.stream()>>>(
x, y, outer_dim, inner_dim, init, op);
}
}
template <typename T, typename BinaryOp>
void InclusiveScan(const T *x,
T *y,
size_t outer_dim,
size_t mid_dim,
size_t inner_dim,
T init,
BinaryOp op,
bool reverse,
const phi::GPUContext &dev_ctx) {
if (outer_dim == 0 || mid_dim == 0 || inner_dim == 0) return;
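// Dispatch on shape: a single contiguous row goes to CUB (through reverse
// iterators when needed), inner_dim > 1 uses one sequential scan per
// (outer, inner) pair, and the remaining case scans each row with the
// shared-memory kernel above.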
if (outer_dim == 1 && inner_dim == 1) {
if (reverse) {
auto x_reverse_iter = MakeThrustReverseIterator(x + mid_dim);
auto y_reverse_iter = MakeThrustReverseIterator(y + mid_dim);
CubInclusiveScan(x_reverse_iter, y_reverse_iter, mid_dim, op, dev_ctx);
} else {
CubInclusiveScan(x, y, mid_dim, op, dev_ctx);
}
} else if (inner_dim != 1) {
phi::funcs::ForRange<phi::GPUContext> for_range(dev_ctx,
outer_dim * inner_dim);
if (reverse) {
for_range(
InclusiveScanOuterOrMidDimFunctor<T, BinaryOp, /*kReverse=*/true>(
x, y, mid_dim, inner_dim, init, op));
} else {
for_range(
InclusiveScanOuterOrMidDimFunctor<T, BinaryOp, /*kReverse=*/false>(
x, y, mid_dim, inner_dim, init, op));
}
} else {
InclusiveScanInnerDim<T, BinaryOp>(
x, y, outer_dim, mid_dim, init, op, reverse, dev_ctx);
}
}
} // namespace funcs
} // namespace phi
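The multinomial kernel below calls InclusiveScan with (outer_dim, mid_dim, inner_dim) = (num_distributions, num_categories, 1) to build one cumulative distribution per row. A CPU reference of the same (outer, mid, inner) scan semantics can help when checking the GPU path; this is a hedged sketch assuming op is addition, and it is not part of phi:

#include <cstddef>
#include <vector>

// Per (outer, inner) pair, accumulate along mid (optionally back to front),
// matching InclusiveScanOuterOrMidDimFunctor / InclusiveScanInnerDim above.
template <typename T>
void InclusiveScanReference(const std::vector<T> &x, std::vector<T> *y,
                            size_t outer, size_t mid, size_t inner,
                            bool reverse) {
  y->resize(x.size());
  for (size_t o = 0; o < outer; ++o) {
    for (size_t i = 0; i < inner; ++i) {
      T acc = T(0);
      for (size_t m = 0; m < mid; ++m) {
        size_t m_idx = reverse ? mid - 1 - m : m;
        size_t idx = (o * mid + m_idx) * inner + i;
        acc = acc + x[idx];
        (*y)[idx] = acc;
      }
    }
  }
}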
@@ -23,11 +23,32 @@ limitations under the License. */
#include <thrust/scan.h>
#include <thrust/transform.h>
#include "paddle/fluid/platform/transform.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/arg_min_max_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/inclusive_scan.h"
#include "paddle/phi/kernels/funcs/multinomial_functor.h"
#include "paddle/phi/kernels/top_k_kernel.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/transform.h"
DECLARE_bool(use_curand);
namespace phi {
@@ -57,12 +78,12 @@ template <typename T>
__global__ void GetCumulativeProbs(T* norm_probs_data,
int64_t num_distributions,
int64_t num_categories,
T* cumulative_probs) {
T* cumulative_probs_data) {
int id = blockIdx.x;
thrust::inclusive_scan(thrust::device,
norm_probs_data + id * num_categories,
norm_probs_data + (id + 1) * num_categories,
cumulative_probs + id * num_categories);
cumulative_probs_data + id * num_categories);
}
template <typename T>
@@ -80,7 +101,7 @@ struct RandomGeneratorCudaFunctor {
};
template <typename T>
__device__ int binarySearchFunctor(T* cumulative_probs,
__device__ int binarySearchFunctor(T* cumulative_probs_data,
T* norm_probs_data,
int num_categories,
T rng_number) {
@@ -90,7 +111,7 @@ __device__ int binarySearchFunctor(T* cumulative_probs,
while (right - left > 0) {
int mid = left + (right - left) / 2;
T temp_prob = cumulative_probs[mid];
T temp_prob = cumulative_probs_data[mid];
if (temp_prob < rng_number) {
left = mid + 1;
} else {
@@ -114,26 +135,35 @@ __global__ void sampleMultinomialWithReplacement(
int64_t* out_data,
const int64_t num_distributions,
const int64_t num_categories,
T* cumulative_probs,
T* norm_probs_data) {
T* cumulative_probs_data,
T* norm_probs_data,
uint64_t seed,
uint64_t offset,
bool use_curand) {
// use binary search to get the selected category sample id.
// let cumulative_probs[id-1] < rng_data < cumulative_probs[id].
// let cumulative_probs_data[id-1] < rng_data < cumulative_probs_data[id].
size_t idx = gridDim.x * blockDim.x * blockIdx.y + blockDim.x * blockIdx.x +
threadIdx.x;
// for every distribution
int dist = blockIdx.y;
// for every sample
int sample = blockIdx.x * blockDim.x + threadIdx.x;
if (sample < num_samples) {
T rng_number = rng_data[sample + dist * num_samples];
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, offset, &state);
// Find the bucket that a uniform random number lies in
int selected_category =
binarySearchFunctor<T>(cumulative_probs + dist * num_categories,
norm_probs_data + dist * num_categories,
num_categories,
rng_number);
int sample = blockIdx.x * blockDim.x + threadIdx.x;
for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
if (sample < num_samples) {
T rng_number = rng_data[sample + dist * num_samples];
if (use_curand) {
rng_number = static_cast<T>(curand_uniform4(&state).x);
}
// Find the bucket that a uniform random number lies in
int selected_category =
binarySearchFunctor<T>(cumulative_probs_data + dist * num_categories,
norm_probs_data + dist * num_categories,
num_categories,
rng_number);
out_data[sample + dist * num_samples] = selected_category;
out_data[sample + dist * num_samples] = selected_category;
}
}
}
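The kernel above draws each sample by inverting the per-distribution CDF: a uniform number is mapped to the first category whose cumulative probability is not smaller than it, which binarySearchFunctor locates in O(log num_categories); each thread also seeds a Philox state from (seed, thread index, offset) so the draw is reproducible for a fixed generator state. A minimal host-side sketch of the inverse-CDF lookup, with illustrative names and no curand (the real functor also receives norm_probs_data to handle zero-probability buckets):

#include <vector>

// Returns the index of the first bucket whose cumulative probability is
// >= u, i.e. a lower-bound binary search over an inclusive prefix sum.
int SampleFromCdf(const std::vector<float> &cdf, float u) {
  int left = 0;
  int right = static_cast<int>(cdf.size());
  while (right - left > 0) {
    int mid = left + (right - left) / 2;
    if (cdf[mid] < u) {
      left = mid + 1;
    } else {
      right = mid;
    }
  }
  // Guard against u falling above the last (numerically < 1) cdf entry.
  if (left >= static_cast<int>(cdf.size())) {
    left = static_cast<int>(cdf.size()) - 1;
  }
  return left;
}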
@@ -172,6 +202,54 @@ void MultinomialKernel(const Context& dev_ctx,
in_data_numel * sizeof(T),
cudaMemcpyDeviceToHost);
#endif
if (FLAGS_use_curand) {
for (size_t i = 0; i < num_distributions; ++i) {
int zero_num = 0;
for (size_t j = 0; j < num_categories; ++j) {
T weight = cpu_in_data[i * num_distributions + j];
PADDLE_ENFORCE_GE(
weight,
0,
errors::InvalidArgument(
"Each element of multinomial'input must >= 0, but got %f.",
weight));
if (weight == static_cast<T>(0)) {
zero_num++;
}
}
int valid_samples = num_categories - zero_num;
PADDLE_ENFORCE_LE(
num_samples,
valid_samples,
errors::InvalidArgument("When replacement=False, 'num_samples' "
"must less than or eaqual to the number of "
"positive item of input"));
}
// Sample without replacement via the exponential-races (Gumbel top-k) trick:
// dividing each weight by an independent Exponential(1) draw and keeping the
// indices of the largest ratios selects categories with probability
// proportional to the weights.
DenseTensor rand = EmptyLike<T, Context>(dev_ctx, x);
T* rand_data = rand.data<T>();
funcs::uniform_distribution<T> dist;
funcs::exponential_transform<T> trans(1.0);
funcs::distribution_and_transform<T>(dev_ctx, &rand, dist, trans);
funcs::ForRange<Context> for_range(dev_ctx, x.numel());
for_range([rand_data, in_data] __device__(size_t idx) {
rand_data[idx] = in_data[idx] / rand_data[idx];
});
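// With the ratios weight / Exponential(1) stored in 'rand', drawing
// num_samples categories without replacement reduces to taking the indices
// of the largest ratios: a single argmax when num_samples == 1, top-k
// otherwise.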
if (num_samples == 1) {
ArgMaxKernel<T, Context>(
dev_ctx, rand, -1, true, false, 3 /*proto::VarType::INT64*/, out);
} else {
std::vector<int64_t> out_dim_vec = vectorize<int64_t>(out->dims());
DenseTensor value =
Empty<T, Context>(dev_ctx, ScalarArray(out_dim_vec));
TopkKernel<T, Context>(
dev_ctx, rand, Scalar(num_samples), -1, true, true, &value, out);
}
return;
}
funcs::MultinomialFunctor<T>(dev_ctx,
cpu_out_data,
@@ -228,7 +306,8 @@ void MultinomialKernel(const Context& dev_ctx,
auto* norm_probs_data = dev_ctx.template Alloc<T>(&norm_probs_tensor);
// number of threads in a block is min(num_categories, 512)
dim3 block_norm(num_categories < 512 ? num_categories : 512);
int block_size = num_categories < 512 ? num_categories : 512;
dim3 block_norm(block_size);
dim3 grid_norm((num_distributions * num_categories - 1) / block_norm.x + 1);
NormalizeProbability<T><<<grid_norm, block_norm, 0, dev_ctx.stream()>>>(
norm_probs_data,
@@ -238,16 +317,34 @@ void MultinomialKernel(const Context& dev_ctx,
num_categories);
// Get cumulative probability of each distribution. It's the same function
// of
// ``cumsum`` op.
// of ``cumsum`` op.
DenseTensor cumulative_probs_tensor;
cumulative_probs_tensor.Resize({num_distributions, num_categories});
auto* cumulative_probs = dev_ctx.template Alloc<T>(&cumulative_probs_tensor);
dim3 block_cumsum(1);
dim3 grid_cumsum(num_distributions);
GetCumulativeProbs<T><<<grid_cumsum, block_cumsum, 0, dev_ctx.stream()>>>(
norm_probs_data, num_distributions, num_categories, cumulative_probs);
auto* cumulative_probs_data =
dev_ctx.template Alloc<T>(&cumulative_probs_tensor);
if (FLAGS_use_curand) {
// 'phi::funcs::InclusiveScan' has higher accuracy than
// 'thrust::inclusive_scan'
funcs::InclusiveScan<T, std::plus<T>>(
/*in*/ norm_probs_data,
/*out*/ cumulative_probs_data,
/*outer_dim*/ static_cast<size_t>(num_distributions),
/*mid_dim*/ static_cast<size_t>(num_categories),
/*inner_dim*/ static_cast<size_t>(1),
/*init*/ static_cast<T>(0),
std::plus<T>(),
/*reverse=*/false,
dev_ctx);
} else {
dim3 block_cumsum(1);
dim3 grid_cumsum(num_distributions);
GetCumulativeProbs<T><<<grid_cumsum, block_cumsum, 0, dev_ctx.stream()>>>(
norm_probs_data,
num_distributions,
num_categories,
cumulative_probs_data);
}
// Generate random number for each sample.
std::random_device rd;
@@ -266,16 +363,30 @@ void MultinomialKernel(const Context& dev_ctx,
RandomGeneratorCudaFunctor<T>(seed));
// Sample the multinomial distributions.
dim3 block_sample(128);
dim3 grid_sample((num_samples - 1) / block_sample.x + 1, num_distributions);
sampleMultinomialWithReplacement<
T><<<grid_sample, block_sample, 0, dev_ctx.stream()>>>(rng_data,
num_samples,
out_data,
num_distributions,
num_categories,
cumulative_probs,
norm_probs_data);
dim3 block(128);
int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id);
int grid_y = std::min<int64_t>(num_distributions, prop.maxGridSize[1]);
dim3 grid((num_samples - 1) / block.x + 1, grid_y);
auto gen_cuda = dev_ctx.GetGenerator();
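// Reserve Philox offsets for the curand_uniform4 calls each thread may make,
// so consecutive launches consume disjoint portions of the random stream.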
size_t curand4_loop_times =
(num_distributions + 4 * grid_y - 1) / (4 * grid_y);
// 'increment' should be a multiple of 4
uint64_t increment = curand4_loop_times * 4;
auto seed_offset = gen_cuda->IncrementOffset(increment);
sampleMultinomialWithReplacement<T><<<grid, block, 0, dev_ctx.stream()>>>(
rng_data,
num_samples,
out_data,
num_distributions,
num_categories,
cumulative_probs_data,
norm_probs_data,
seed_offset.first,
seed_offset.second,
FLAGS_use_curand);
}
} // namespace phi
......
@@ -20,6 +20,7 @@ import paddle.fluid as fluid
from paddle.fluid import core
from op_test import OpTest
import numpy as np
import os
def sample_output_one_dimension(out, dim):
@@ -216,5 +217,59 @@ class TestMultinomialError(unittest.TestCase):
self.assertRaises(ValueError, test_dim_less_than_1)
class TestRandomValue(unittest.TestCase):
def test_fixed_random_number(self):
# Test fixed GPU random numbers, which are generated by 'curandStatePhilox4_32_10_t'
if not paddle.is_compiled_with_cuda():
return
# Different GPUs generate different random values. Only test V100 here.
if not "V100" in paddle.device.cuda.get_device_name():
return
if os.getenv("FLAGS_use_curand", None) in ('0', 'False', None):
return
print("Test Fixed Random number on V100 GPU------>")
paddle.disable_static()
paddle.set_device('gpu')
paddle.seed(100)
x = paddle.randint(0, 100, [1024, 10000]).astype('float32')
y = paddle.multinomial(x, 1, replacement=False).numpy()
self.assertEqual(np.sum(y), 5187793)
self.assertEqual(np.mean(y), 5066.2041015625)
expect = [9982, 1655, 4741, 1323, 9319, 3298, 6473, 7477, 2507, 2628]
self.assertTrue(np.array_equal(y[100:110, :].flatten(), expect))
y = paddle.multinomial(x, 5000, replacement=False).numpy()
self.assertEqual(np.sum(y), 25603962316)
self.assertEqual(np.mean(y), 5000.77388984375)
expect = [7300, 6055, 8714, 5401, 7360, 161, 5035, 7002, 6788, 2916]
self.assertTrue(np.array_equal(y[100, 1000:1010], expect))
y = paddle.multinomial(x, 5000, replacement=False).numpy()
self.assertEqual(np.sum(y), 25592855710)
self.assertEqual(np.mean(y), 4998.604630859375)
expect = [5700, 6567, 4399, 5688, 7472, 545, 6894, 526, 2124, 385]
self.assertTrue(np.array_equal(y[300, 3000:3010], expect))
y = paddle.multinomial(x, 20000, replacement=True).numpy()
self.assertEqual(np.sum(y), 102371362581)
self.assertEqual(np.mean(y), 4998.60168852539)
self.assertEqual(np.std(y), 2886.316308500771)
expect = [7630, 8235, 8445, 3275, 5580, 4591, 1331, 342, 1662, 7156]
self.assertTrue(np.array_equal(y[100, 0:10], expect))
y = paddle.multinomial(x, 20000, replacement=True).numpy()
self.assertEqual(np.sum(y), 102400672117)
self.assertEqual(np.mean(y), 5000.032818212891)
self.assertEqual(np.std(y), 2886.913426124017)
expect = [4159, 7849, 9305, 5759, 4422, 122, 345, 2897, 5200, 5911]
self.assertTrue(np.array_equal(y[100, 0:10], expect))
paddle.enable_static()
if __name__ == "__main__":
unittest.main()