Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
51a33962
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
51a33962
编写于
10月 27, 2021
作者:
P
pangyoki
提交者:
GitHub
10月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add unittest (#36511)
上级
dd1d3789
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
41 addition
and
34 deletion
+41
-34
paddle/fluid/operators/multinomial_op.cu
paddle/fluid/operators/multinomial_op.cu
+33
-34
python/paddle/fluid/tests/unittests/test_multinomial_op.py
python/paddle/fluid/tests/unittests/test_multinomial_op.py
+8
-0
未找到文件。
paddle/fluid/operators/multinomial_op.cu
浏览文件 @
51a33962
...
...
@@ -33,18 +33,22 @@ namespace operators {
template
<
typename
T
>
__global__
void
NormalizeProbability
(
T
*
norm_probs
,
const
T
*
in_data
,
T
*
sum_rows
)
{
T
*
sum_rows
,
int64_t
num_distributions
,
int64_t
num_categories
)
{
int
id
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
+
blockIdx
.
y
*
gridDim
.
x
*
blockDim
.
x
;
PADDLE_ENFORCE
(
in_data
[
id
]
>=
0.0
,
"The input of multinomial distribution should be >= 0, but got %f."
,
in_data
[
id
]);
PADDLE_ENFORCE
(
sum_rows
[
blockIdx
.
y
]
>
0.0
,
"The sum of one multinomial distribution probability should "
"be > 0, but got %f."
,
sum_rows
[
blockIdx
.
y
]);
norm_probs
[
id
]
=
in_data
[
id
]
/
sum_rows
[
blockIdx
.
y
];
if
(
id
<
num_distributions
*
num_categories
)
{
PADDLE_ENFORCE
(
in_data
[
id
]
>=
0.0
,
"The input of multinomial distribution should be >= 0, but got %f."
,
in_data
[
id
]);
int64_t
row_id
=
id
/
num_categories
;
PADDLE_ENFORCE
(
sum_rows
[
row_id
]
>
0.0
,
"The sum of one multinomial distribution probability should "
"be > 0, but got %f."
,
sum_rows
[
row_id
]);
norm_probs
[
id
]
=
in_data
[
id
]
/
sum_rows
[
row_id
];
}
}
template
<
typename
T
>
...
...
@@ -52,12 +56,10 @@ __global__ void GetCumulativeProbs(T* norm_probs_data,
int64_t
num_distributions
,
int64_t
num_categories
,
T
*
cumulative_probs
)
{
for
(
int
id
=
blockIdx
.
x
;
id
<
num_distributions
;
id
+=
gridDim
.
x
)
{
thrust
::
inclusive_scan
(
thrust
::
device
,
norm_probs_data
+
id
*
num_categories
,
norm_probs_data
+
(
id
+
1
)
*
num_categories
,
cumulative_probs
+
id
*
num_categories
);
}
int
id
=
blockIdx
.
x
;
thrust
::
inclusive_scan
(
thrust
::
device
,
norm_probs_data
+
id
*
num_categories
,
norm_probs_data
+
(
id
+
1
)
*
num_categories
,
cumulative_probs
+
id
*
num_categories
);
}
template
<
typename
T
>
...
...
@@ -108,23 +110,19 @@ __global__ void sampleMultinomialWithReplacement(
// use binary search to get the selected category sample id.
// let cumulative_probs[id-1] < rng_data < cumulative_probs[id].
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
+
blockIdx
.
y
*
gridDim
.
x
*
blockDim
.
x
;
// for every distribution
for
(
int
dist
=
blockIdx
.
y
;
dist
<
num_distributions
;
dist
+=
gridDim
.
y
)
{
// for every sample
for
(
int
sample
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
sample
<
num_samples
;
sample
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
rng_number
=
rng_data
[
sample
+
dist
*
num_samples
];
// Find the bucket that a uniform random number lies in
int
selected_category
=
binarySearchFunctor
<
T
>
(
cumulative_probs
+
dist
*
num_categories
,
norm_probs_data
+
dist
*
num_categories
,
num_categories
,
rng_number
);
out_data
[
sample
+
dist
*
num_samples
]
=
selected_category
;
}
int
dist
=
blockIdx
.
y
;
// for every sample
int
sample
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
sample
<
num_samples
)
{
T
rng_number
=
rng_data
[
sample
+
dist
*
num_samples
];
// Find the bucket that a uniform random number lies in
int
selected_category
=
binarySearchFunctor
<
T
>
(
cumulative_probs
+
dist
*
num_categories
,
norm_probs_data
+
dist
*
num_categories
,
num_categories
,
rng_number
);
out_data
[
sample
+
dist
*
num_samples
]
=
selected_category
;
}
}
...
...
@@ -215,10 +213,11 @@ class MultinomialOpKernel<platform::CUDADeviceContext, T>
// number of threads in a block is min(num_categories, 512)
dim3
block_norm
(
num_categories
<
512
?
num_categories
:
512
);
dim3
grid_norm
((
num_
categories
-
1
)
/
block_norm
.
x
+
1
,
num_distributions
);
dim3
grid_norm
((
num_
distributions
*
num_categories
-
1
)
/
block_norm
.
x
+
1
);
NormalizeProbability
<
T
><<<
grid_norm
,
block_norm
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
norm_probs_data
,
in_data
,
sum_rows_data
);
norm_probs_data
,
in_data
,
sum_rows_data
,
num_distributions
,
num_categories
);
// Get cumulative probability of each distribution. It's the same function
// of
...
...
python/paddle/fluid/tests/unittests/test_multinomial_op.py
浏览文件 @
51a33962
...
...
@@ -141,6 +141,14 @@ class TestMultinomialApi(unittest.TestCase):
"replacement is False. categories can't be sampled repeatedly"
)
paddle
.
enable_static
()
def
test_dygraph4
(
self
):
paddle
.
disable_static
()
logits
=
-
1
*
paddle
.
ones
([
2800
])
# Categorical.sample API will call multinomial op with replacement=True
cat
=
paddle
.
distribution
.
Categorical
(
logits
.
exp
())
cat
.
sample
([
1
])
paddle
.
enable_static
()
def
test_static
(
self
):
paddle
.
enable_static
()
startup_program
=
fluid
.
Program
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录