未验证 提交 d8ffb261 编写于 作者: P pangyoki 提交者: GitHub

【Cherry-pick PR 36511】fix out_of_range bug of multinomial op's cuda kernel (#36511) (#36808)

Cherry-pick PR #36511
上级 e3db65d5
...@@ -33,18 +33,22 @@ namespace operators { ...@@ -33,18 +33,22 @@ namespace operators {
template <typename T> template <typename T>
__global__ void NormalizeProbability(T* norm_probs, const T* in_data, __global__ void NormalizeProbability(T* norm_probs, const T* in_data,
T* sum_rows) { T* sum_rows, int64_t num_distributions,
int64_t num_categories) {
int id = threadIdx.x + blockIdx.x * blockDim.x + int id = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x; blockIdx.y * gridDim.x * blockDim.x;
PADDLE_ENFORCE( if (id < num_distributions * num_categories) {
in_data[id] >= 0.0, PADDLE_ENFORCE(
"The input of multinomial distribution should be >= 0, but got %f.", in_data[id] >= 0.0,
in_data[id]); "The input of multinomial distribution should be >= 0, but got %f.",
PADDLE_ENFORCE(sum_rows[blockIdx.y] > 0.0, in_data[id]);
"The sum of one multinomial distribution probability should " int64_t row_id = id / num_categories;
"be > 0, but got %f.", PADDLE_ENFORCE(sum_rows[row_id] > 0.0,
sum_rows[blockIdx.y]); "The sum of one multinomial distribution probability should "
norm_probs[id] = in_data[id] / sum_rows[blockIdx.y]; "be > 0, but got %f.",
sum_rows[row_id]);
norm_probs[id] = in_data[id] / sum_rows[row_id];
}
} }
template <typename T> template <typename T>
...@@ -52,12 +56,10 @@ __global__ void GetCumulativeProbs(T* norm_probs_data, ...@@ -52,12 +56,10 @@ __global__ void GetCumulativeProbs(T* norm_probs_data,
int64_t num_distributions, int64_t num_distributions,
int64_t num_categories, int64_t num_categories,
T* cumulative_probs) { T* cumulative_probs) {
for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) { int id = blockIdx.x;
thrust::inclusive_scan(thrust::device, thrust::inclusive_scan(thrust::device, norm_probs_data + id * num_categories,
norm_probs_data + id * num_categories, norm_probs_data + (id + 1) * num_categories,
norm_probs_data + (id + 1) * num_categories, cumulative_probs + id * num_categories);
cumulative_probs + id * num_categories);
}
} }
template <typename T> template <typename T>
...@@ -108,23 +110,19 @@ __global__ void sampleMultinomialWithReplacement( ...@@ -108,23 +110,19 @@ __global__ void sampleMultinomialWithReplacement(
// use binary search to get the selected category sample id. // use binary search to get the selected category sample id.
// let cumulative_probs[id-1] < rng_data < cumulative_probs[id]. // let cumulative_probs[id-1] < rng_data < cumulative_probs[id].
int idx = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;
// for every distribution // for every distribution
for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) { int dist = blockIdx.y;
// for every sample // for every sample
for (int sample = blockIdx.x * blockDim.x + threadIdx.x; int sample = blockIdx.x * blockDim.x + threadIdx.x;
sample < num_samples; sample += blockDim.x * gridDim.x) { if (sample < num_samples) {
T rng_number = rng_data[sample + dist * num_samples]; T rng_number = rng_data[sample + dist * num_samples];
// Find the bucket that a uniform random number lies in // Find the bucket that a uniform random number lies in
int selected_category = binarySearchFunctor<T>( int selected_category = binarySearchFunctor<T>(
cumulative_probs + dist * num_categories, cumulative_probs + dist * num_categories,
norm_probs_data + dist * num_categories, num_categories, rng_number); norm_probs_data + dist * num_categories, num_categories, rng_number);
out_data[sample + dist * num_samples] = selected_category; out_data[sample + dist * num_samples] = selected_category;
}
} }
} }
...@@ -215,10 +213,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T> ...@@ -215,10 +213,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
// number of threads in a block is min(num_categories, 512) // number of threads in a block is min(num_categories, 512)
dim3 block_norm(num_categories < 512 ? num_categories : 512); dim3 block_norm(num_categories < 512 ? num_categories : 512);
dim3 grid_norm((num_categories - 1) / block_norm.x + 1, num_distributions); dim3 grid_norm((num_distributions * num_categories - 1) / block_norm.x + 1);
NormalizeProbability< NormalizeProbability<
T><<<grid_norm, block_norm, 0, ctx.cuda_device_context().stream()>>>( T><<<grid_norm, block_norm, 0, ctx.cuda_device_context().stream()>>>(
norm_probs_data, in_data, sum_rows_data); norm_probs_data, in_data, sum_rows_data, num_distributions,
num_categories);
// Get cumulative probability of each distribution. It's the same function // Get cumulative probability of each distribution. It's the same function
// of // of
......
...@@ -141,6 +141,14 @@ class TestMultinomialApi(unittest.TestCase): ...@@ -141,6 +141,14 @@ class TestMultinomialApi(unittest.TestCase):
"replacement is False. categories can't be sampled repeatedly") "replacement is False. categories can't be sampled repeatedly")
paddle.enable_static() paddle.enable_static()
def test_dygraph4(self):
paddle.disable_static()
logits = -1 * paddle.ones([2800])
# Categorical.sample API will call multinomial op with replacement=True
cat = paddle.distribution.Categorical(logits.exp())
cat.sample([1])
paddle.enable_static()
def test_static(self): def test_static(self):
paddle.enable_static() paddle.enable_static()
startup_program = fluid.Program() startup_program = fluid.Program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册