Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0c40d889
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0c40d889
编写于
9月 15, 2022
作者:
L
Li Min
提交者:
GitHub
9月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add determine action for embed_grad and index_add. (#46040)
上级
54a43981
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
15 addition
and
0 deletion
+15
-0
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
+7
-0
paddle/phi/kernels/gpu/index_add_kernel.cu
paddle/phi/kernels/gpu/index_add_kernel.cu
+8
-0
未找到文件。
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
浏览文件 @
0c40d889
...
@@ -23,6 +23,8 @@
...
@@ -23,6 +23,8 @@
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
DECLARE_bool
(
cudnn_deterministic
);
namespace
phi
{
namespace
phi
{
template
<
typename
InT
,
typename
OutT
>
template
<
typename
InT
,
typename
OutT
>
...
@@ -101,6 +103,11 @@ struct EmbeddingGradCUDAFunctor {
...
@@ -101,6 +103,11 @@ struct EmbeddingGradCUDAFunctor {
const
int
gridx
=
2
*
dev_ctx_
.
GetSMCount
();
const
int
gridx
=
2
*
dev_ctx_
.
GetSMCount
();
dim3
threads
(
128
,
8
);
dim3
threads
(
128
,
8
);
dim3
grids
(
gridx
,
1
);
dim3
grids
(
gridx
,
1
);
if
(
FLAGS_cudnn_deterministic
)
{
VLOG
(
2
)
<<
"Run grad kernel of embedding with single thread."
;
grids
.
x
=
1
;
}
EmbeddingGrad
<
T
,
IdT
><<<
grids
,
threads
,
0
,
dev_ctx_
.
stream
()
>>>
(
EmbeddingGrad
<
T
,
IdT
><<<
grids
,
threads
,
0
,
dev_ctx_
.
stream
()
>>>
(
d_table
,
d_output
,
ids
,
N
,
K
,
D
);
d_table
,
d_output
,
ids
,
N
,
K
,
D
);
}
}
...
...
paddle/phi/kernels/gpu/index_add_kernel.cu
浏览文件 @
0c40d889
...
@@ -20,6 +20,8 @@
...
@@ -20,6 +20,8 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/core/utils/data_type.h"
DECLARE_bool
(
cudnn_deterministic
);
namespace
phi
{
namespace
phi
{
using
paddle
::
platform
::
PADDLE_CUDA_NUM_THREADS
;
using
paddle
::
platform
::
PADDLE_CUDA_NUM_THREADS
;
...
@@ -79,6 +81,12 @@ void IndexAddKernel(const Context& ctx,
...
@@ -79,6 +81,12 @@ void IndexAddKernel(const Context& ctx,
// todo(@limin29): inplace do not need copy.
// todo(@limin29): inplace do not need copy.
phi
::
Copy
(
ctx
,
x
,
ctx
.
GetPlace
(),
false
,
output
);
phi
::
Copy
(
ctx
,
x
,
ctx
.
GetPlace
(),
false
,
output
);
if
(
FLAGS_cudnn_deterministic
)
{
VLOG
(
2
)
<<
"Run grad kernel of index_add with single thread."
;
block_dim
=
1
;
grid_dim
.
x
=
1
;
}
if
(
index_type
==
phi
::
DataType
::
INT64
)
{
if
(
index_type
==
phi
::
DataType
::
INT64
)
{
const
int64_t
*
index_data
=
index
.
data
<
int64_t
>
();
const
int64_t
*
index_data
=
index
.
data
<
int64_t
>
();
index_add_cuda_kernel
<
T
,
int64_t
>
index_add_cuda_kernel
<
T
,
int64_t
>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录