Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d8ffb261
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d8ffb261
编写于
10月 28, 2021
作者:
P
pangyoki
提交者:
GitHub
10月 28, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【Cherry-pick PR 36511】fix out_of_range bug of multinomial op's cuda kernel (#36511) (#36808)
Cherry-pick PR #36511
上级
e3db65d5
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
41 addition
and
34 deletion
+41
-34
paddle/fluid/operators/multinomial_op.cu
paddle/fluid/operators/multinomial_op.cu
+33
-34
python/paddle/fluid/tests/unittests/test_multinomial_op.py
python/paddle/fluid/tests/unittests/test_multinomial_op.py
+8
-0
未找到文件。
paddle/fluid/operators/multinomial_op.cu
浏览文件 @
d8ffb261
...
...
@@ -33,18 +33,22 @@ namespace operators {
template
<
typename
T
>
__global__
void
NormalizeProbability
(
T
*
norm_probs
,
const
T
*
in_data
,
T
*
sum_rows
)
{
T
*
sum_rows
,
int64_t
num_distributions
,
int64_t
num_categories
)
{
int
id
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
+
blockIdx
.
y
*
gridDim
.
x
*
blockDim
.
x
;
PADDLE_ENFORCE
(
in_data
[
id
]
>=
0.0
,
"The input of multinomial distribution should be >= 0, but got %f."
,
in_data
[
id
]);
PADDLE_ENFORCE
(
sum_rows
[
blockIdx
.
y
]
>
0.0
,
"The sum of one multinomial distribution probability should "
"be > 0, but got %f."
,
sum_rows
[
blockIdx
.
y
]);
norm_probs
[
id
]
=
in_data
[
id
]
/
sum_rows
[
blockIdx
.
y
];
if
(
id
<
num_distributions
*
num_categories
)
{
PADDLE_ENFORCE
(
in_data
[
id
]
>=
0.0
,
"The input of multinomial distribution should be >= 0, but got %f."
,
in_data
[
id
]);
int64_t
row_id
=
id
/
num_categories
;
PADDLE_ENFORCE
(
sum_rows
[
row_id
]
>
0.0
,
"The sum of one multinomial distribution probability should "
"be > 0, but got %f."
,
sum_rows
[
row_id
]);
norm_probs
[
id
]
=
in_data
[
id
]
/
sum_rows
[
row_id
];
}
}
template
<
typename
T
>
...
...
@@ -52,12 +56,10 @@ __global__ void GetCumulativeProbs(T* norm_probs_data,
int64_t
num_distributions
,
int64_t
num_categories
,
T
*
cumulative_probs
)
{
for
(
int
id
=
blockIdx
.
x
;
id
<
num_distributions
;
id
+=
gridDim
.
x
)
{
thrust
::
inclusive_scan
(
thrust
::
device
,
norm_probs_data
+
id
*
num_categories
,
norm_probs_data
+
(
id
+
1
)
*
num_categories
,
cumulative_probs
+
id
*
num_categories
);
}
int
id
=
blockIdx
.
x
;
thrust
::
inclusive_scan
(
thrust
::
device
,
norm_probs_data
+
id
*
num_categories
,
norm_probs_data
+
(
id
+
1
)
*
num_categories
,
cumulative_probs
+
id
*
num_categories
);
}
template
<
typename
T
>
...
...
@@ -108,23 +110,19 @@ __global__ void sampleMultinomialWithReplacement(
// use binary search to get the selected category sample id.
// let cumulative_probs[id-1] < rng_data < cumulative_probs[id].
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
+
blockIdx
.
y
*
gridDim
.
x
*
blockDim
.
x
;
// for every distribution
for
(
int
dist
=
blockIdx
.
y
;
dist
<
num_distributions
;
dist
+=
gridDim
.
y
)
{
// for every sample
for
(
int
sample
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
sample
<
num_samples
;
sample
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
rng_number
=
rng_data
[
sample
+
dist
*
num_samples
];
// Find the bucket that a uniform random number lies in
int
selected_category
=
binarySearchFunctor
<
T
>
(
cumulative_probs
+
dist
*
num_categories
,
norm_probs_data
+
dist
*
num_categories
,
num_categories
,
rng_number
);
out_data
[
sample
+
dist
*
num_samples
]
=
selected_category
;
}
int
dist
=
blockIdx
.
y
;
// for every sample
int
sample
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
sample
<
num_samples
)
{
T
rng_number
=
rng_data
[
sample
+
dist
*
num_samples
];
// Find the bucket that a uniform random number lies in
int
selected_category
=
binarySearchFunctor
<
T
>
(
cumulative_probs
+
dist
*
num_categories
,
norm_probs_data
+
dist
*
num_categories
,
num_categories
,
rng_number
);
out_data
[
sample
+
dist
*
num_samples
]
=
selected_category
;
}
}
...
...
@@ -215,10 +213,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
// number of threads in a block is min(num_categories, 512)
dim3
block_norm
(
num_categories
<
512
?
num_categories
:
512
);
dim3
grid_norm
((
num_
categories
-
1
)
/
block_norm
.
x
+
1
,
num_distributions
);
dim3
grid_norm
((
num_
distributions
*
num_categories
-
1
)
/
block_norm
.
x
+
1
);
NormalizeProbability
<
T
><<<
grid_norm
,
block_norm
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
norm_probs_data
,
in_data
,
sum_rows_data
);
norm_probs_data
,
in_data
,
sum_rows_data
,
num_distributions
,
num_categories
);
// Get cumulative probability of each distribution. It's the same function
// of
...
...
python/paddle/fluid/tests/unittests/test_multinomial_op.py
浏览文件 @
d8ffb261
...
...
@@ -141,6 +141,14 @@ class TestMultinomialApi(unittest.TestCase):
"replacement is False. categories can't be sampled repeatedly"
)
paddle
.
enable_static
()
def
test_dygraph4
(
self
):
paddle
.
disable_static
()
logits
=
-
1
*
paddle
.
ones
([
2800
])
# Categorical.sample API will call multinomial op with replacement=True
cat
=
paddle
.
distribution
.
Categorical
(
logits
.
exp
())
cat
.
sample
([
1
])
paddle
.
enable_static
()
def
test_static
(
self
):
paddle
.
enable_static
()
startup_program
=
fluid
.
Program
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录