未验证 提交 51a33962 编写于 作者: P pangyoki 提交者: GitHub

add unittest (#36511)

上级 dd1d3789
......@@ -33,18 +33,22 @@ namespace operators {
template <typename T>
__global__ void NormalizeProbability(T* norm_probs, const T* in_data,
T* sum_rows) {
T* sum_rows, int64_t num_distributions,
int64_t num_categories) {
int id = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;
PADDLE_ENFORCE(
in_data[id] >= 0.0,
"The input of multinomial distribution should be >= 0, but got %f.",
in_data[id]);
PADDLE_ENFORCE(sum_rows[blockIdx.y] > 0.0,
"The sum of one multinomial distribution probability should "
"be > 0, but got %f.",
sum_rows[blockIdx.y]);
norm_probs[id] = in_data[id] / sum_rows[blockIdx.y];
if (id < num_distributions * num_categories) {
PADDLE_ENFORCE(
in_data[id] >= 0.0,
"The input of multinomial distribution should be >= 0, but got %f.",
in_data[id]);
int64_t row_id = id / num_categories;
PADDLE_ENFORCE(sum_rows[row_id] > 0.0,
"The sum of one multinomial distribution probability should "
"be > 0, but got %f.",
sum_rows[row_id]);
norm_probs[id] = in_data[id] / sum_rows[row_id];
}
}
template <typename T>
......@@ -52,12 +56,10 @@ __global__ void GetCumulativeProbs(T* norm_probs_data,
int64_t num_distributions,
int64_t num_categories,
T* cumulative_probs) {
for (int id = blockIdx.x; id < num_distributions; id += gridDim.x) {
thrust::inclusive_scan(thrust::device,
norm_probs_data + id * num_categories,
norm_probs_data + (id + 1) * num_categories,
cumulative_probs + id * num_categories);
}
int id = blockIdx.x;
thrust::inclusive_scan(thrust::device, norm_probs_data + id * num_categories,
norm_probs_data + (id + 1) * num_categories,
cumulative_probs + id * num_categories);
}
template <typename T>
......@@ -108,23 +110,19 @@ __global__ void sampleMultinomialWithReplacement(
// use binary search to get the selected category sample id.
// let cumulative_probs[id-1] < rng_data < cumulative_probs[id].
int idx = threadIdx.x + blockIdx.x * blockDim.x +
blockIdx.y * gridDim.x * blockDim.x;
// for every distribution
for (int dist = blockIdx.y; dist < num_distributions; dist += gridDim.y) {
// for every sample
for (int sample = blockIdx.x * blockDim.x + threadIdx.x;
sample < num_samples; sample += blockDim.x * gridDim.x) {
T rng_number = rng_data[sample + dist * num_samples];
// Find the bucket that a uniform random number lies in
int selected_category = binarySearchFunctor<T>(
cumulative_probs + dist * num_categories,
norm_probs_data + dist * num_categories, num_categories, rng_number);
out_data[sample + dist * num_samples] = selected_category;
}
int dist = blockIdx.y;
// for every sample
int sample = blockIdx.x * blockDim.x + threadIdx.x;
if (sample < num_samples) {
T rng_number = rng_data[sample + dist * num_samples];
// Find the bucket that a uniform random number lies in
int selected_category = binarySearchFunctor<T>(
cumulative_probs + dist * num_categories,
norm_probs_data + dist * num_categories, num_categories, rng_number);
out_data[sample + dist * num_samples] = selected_category;
}
}
......@@ -215,10 +213,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
// number of threads in a block is min(num_categories, 512)
dim3 block_norm(num_categories < 512 ? num_categories : 512);
dim3 grid_norm((num_categories - 1) / block_norm.x + 1, num_distributions);
dim3 grid_norm((num_distributions * num_categories - 1) / block_norm.x + 1);
NormalizeProbability<
T><<<grid_norm, block_norm, 0, ctx.cuda_device_context().stream()>>>(
norm_probs_data, in_data, sum_rows_data);
norm_probs_data, in_data, sum_rows_data, num_distributions,
num_categories);
// Get cumulative probability of each distribution. It's the same function
// of
......
......@@ -141,6 +141,14 @@ class TestMultinomialApi(unittest.TestCase):
"replacement is False. categories can't be sampled repeatedly")
paddle.enable_static()
def test_dygraph4(self):
paddle.disable_static()
logits = -1 * paddle.ones([2800])
# Categorical.sample API will call multinomial op with replacement=True
cat = paddle.distribution.Categorical(logits.exp())
cat.sample([1])
paddle.enable_static()
def test_static(self):
paddle.enable_static()
startup_program = fluid.Program()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册