Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0c3e4cf6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0c3e4cf6
编写于
8月 29, 2023
作者:
D
duanyanhui
提交者:
GitHub
8月 29, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[DCU] support cum & multinomial for dcu (#56612)
* support cum & multinomial for dcu * rm commt
上级
76b328bc
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
12 addition
and
12 deletion
+12
-12
paddle/phi/kernels/gpu/cum_grad_kernel.cu
paddle/phi/kernels/gpu/cum_grad_kernel.cu
+1
-0
paddle/phi/kernels/gpu/cum_kernel.cu
paddle/phi/kernels/gpu/cum_kernel.cu
+1
-0
paddle/phi/kernels/gpu/multinomial_kernel.cu
paddle/phi/kernels/gpu/multinomial_kernel.cu
+10
-8
python/paddle/tensor/random.py
python/paddle/tensor/random.py
+0
-4
未找到文件。
paddle/phi/kernels/gpu/cum_grad_kernel.cu
浏览文件 @
0c3e4cf6
...
...
@@ -70,6 +70,7 @@ PD_REGISTER_KERNEL(cumsum_grad,
phi
::
CumsumGradKernel
,
float
,
double
,
phi
::
dtype
::
float16
,
int16_t
,
int
,
int64_t
)
{}
...
...
paddle/phi/kernels/gpu/cum_kernel.cu
浏览文件 @
0c3e4cf6
...
...
@@ -435,6 +435,7 @@ PD_REGISTER_KERNEL(cumsum,
ALL_LAYOUT
,
phi
::
CumsumKernel
,
float
,
phi
::
dtype
::
float16
,
double
,
int16_t
,
int
,
...
...
paddle/phi/kernels/gpu/multinomial_kernel.cu
浏览文件 @
0c3e4cf6
...
...
@@ -12,10 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_WITH_HIP
// To-do(qili93): fix this after issue resolved
// https://github.com/ROCmSoftwarePlatform/rocPRIM/issues/202
#include "paddle/phi/kernels/multinomial_kernel.h"
#ifdef __NVCC__
...
...
@@ -107,14 +103,22 @@ __global__ void sampleMultinomialWithReplacement(
size_t
idx
=
gridDim
.
x
*
blockDim
.
x
*
blockIdx
.
y
+
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
#if defined(__NVCC__)
curandStatePhilox4_32_10_t
state
;
curand_init
(
seed
,
idx
,
offset
,
&
state
);
#else
hiprandStatePhilox4_32_10_t
state
;
hiprand_init
(
seed
,
idx
,
offset
,
&
state
);
#endif
int
sample
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(
int
dist
=
blockIdx
.
y
;
dist
<
num_distributions
;
dist
+=
gridDim
.
y
)
{
if
(
sample
<
num_samples
)
{
#if defined(__NVCC__)
T
rng_number
=
static_cast
<
T
>
(
curand_uniform4
(
&
state
).
x
);
// Find the bucket that a uniform random number lies in
#else
T
rng_number
=
static_cast
<
T
>
(
hiprand_uniform4
(
&
state
).
x
);
#endif
int
selected_category
=
binarySearchFunctor
<
T
>
(
cumulative_probs_data
+
dist
*
num_categories
,
norm_probs_data
+
dist
*
num_categories
,
...
...
@@ -283,7 +287,7 @@ void MultinomialKernel(const Context& dev_ctx,
}
// namespace phi
PD_REGISTER_KERNEL
(
multinomial
,
// cuda_only
PD_REGISTER_KERNEL
(
multinomial
,
GPU
,
ALL_LAYOUT
,
phi
::
MultinomialKernel
,
...
...
@@ -293,5 +297,3 @@ PD_REGISTER_KERNEL(multinomial, // cuda_only
double
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
INT64
);
}
#endif
python/paddle/tensor/random.py
浏览文件 @
0c3e4cf6
...
...
@@ -183,10 +183,6 @@ def multinomial(x, num_samples=1, replacement=False, name=None):
"""
assert
(
not
core
.
is_compiled_with_rocm
()
),
"multinomial op is not supported on ROCM yet."
if
in_dynamic_mode
():
return
_C_ops
.
multinomial
(
x
,
num_samples
,
replacement
)
else
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录