Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
a270fdf2
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a270fdf2
编写于
11月 08, 2018
作者:
C
chengduo
提交者:
GitHub
11月 08, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix SelectedRowsAdd bug (#14309)
* fix selected_rows bug test=develop * refine cos_sim test=develop
上级
1001f8e1
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
7 addition
and
7 deletion
+7
-7
paddle/fluid/operators/math/cos_sim_functor.cu
paddle/fluid/operators/math/cos_sim_functor.cu
+1
-1
paddle/fluid/operators/math/selected_rows_functor.cu
paddle/fluid/operators/math/selected_rows_functor.cu
+6
-6
未找到文件。
paddle/fluid/operators/math/cos_sim_functor.cu
浏览文件 @
a270fdf2
...
...
@@ -51,7 +51,7 @@ struct CosSimDyFunctor<platform::CUDADeviceContext, T> {
T
*
dy
)
const
{
const
int
block_size
=
512
;
dim3
threads
(
block_size
,
1
);
dim3
grid
(
1
,
(
rows
+
block_size
-
1
)
/
block_size
);
dim3
grid
(
(
rows
+
block_size
-
1
)
/
block_size
,
1
);
CosSimDyKernel
<
T
><<<
grid
,
threads
,
0
,
ctx
.
stream
()
>>>
(
x_norm
,
y_norm
,
x
,
y
,
z
,
dz
,
rows
,
cols
,
dy
);
}
...
...
paddle/fluid/operators/math/selected_rows_functor.cu
浏览文件 @
a270fdf2
...
...
@@ -81,7 +81,7 @@ template <typename T, int block_size>
__global__
void
SelectedRowsAddTensorKernel
(
const
T
*
selected_rows
,
const
int64_t
*
rows
,
T
*
tensor_out
,
int64_t
row_numel
)
{
const
int
ty
=
blockIdx
.
y
;
const
int
ty
=
blockIdx
.
x
;
int
tid
=
threadIdx
.
x
;
selected_rows
+=
ty
*
row_numel
;
...
...
@@ -123,7 +123,7 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
const
int
block_size
=
256
;
dim3
threads
(
block_size
,
1
);
dim3
grid
(
1
,
in1_rows
.
size
()
);
dim3
grid
(
in1_rows
.
size
(),
1
);
SelectedRowsAddTensorKernel
<
T
,
block_size
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
in1_data
,
in1_rows
.
CUDAData
(
context
.
GetPlace
()),
out_data
,
...
...
@@ -188,7 +188,7 @@ __global__ void SelectedRowsAddToTensorKernel(const T* selected_rows,
const
int64_t
*
rows
,
T
*
tensor_out
,
int64_t
row_numel
)
{
const
int
ty
=
blockIdx
.
y
;
const
int
ty
=
blockIdx
.
x
;
int
tid
=
threadIdx
.
x
;
selected_rows
+=
ty
*
row_numel
;
...
...
@@ -221,7 +221,7 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
auto
*
in2_data
=
input2
->
data
<
T
>
();
const
int
block_size
=
256
;
dim3
threads
(
block_size
,
1
);
dim3
grid
(
1
,
in1_rows
.
size
()
);
dim3
grid
(
in1_rows
.
size
(),
1
);
SelectedRowsAddToTensorKernel
<
T
,
block_size
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
in1_data
,
in1_rows
.
CUDAData
(
context
.
GetPlace
()),
in2_data
,
...
...
@@ -388,7 +388,7 @@ template <typename T, int block_size>
__global__
void
UpdateToTensorKernel
(
const
T
*
selected_rows
,
const
int64_t
*
rows
,
const
ScatterOps
&
op
,
T
*
tensor_out
,
int64_t
row_numel
)
{
const
int
ty
=
blockIdx
.
y
;
const
int
ty
=
blockIdx
.
x
;
int
tid
=
threadIdx
.
x
;
selected_rows
+=
ty
*
row_numel
;
...
...
@@ -457,7 +457,7 @@ struct UpdateToTensor<platform::CUDADeviceContext, T> {
auto
*
in2_data
=
input2
->
data
<
T
>
();
dim3
threads
(
platform
::
PADDLE_CUDA_NUM_THREADS
,
1
);
dim3
grid
(
1
,
in1_rows
.
size
()
);
dim3
grid
(
in1_rows
.
size
(),
1
);
UpdateToTensorKernel
<
T
,
platform
::
PADDLE_CUDA_NUM_THREADS
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
in1_data
,
in1_rows
.
cuda_data
(),
op
,
in2_data
,
in1_row_numel
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录