Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f8238411
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f8238411
编写于
8月 25, 2020
作者:
W
whs
提交者:
GitHub
8月 25, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix atomicAdd in grid sample op and affine grid op (#26647)
test=develop
上级
32ba8602
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
7 addition
and
7 deletion
+7
-7
paddle/fluid/operators/affine_grid_op.cu
paddle/fluid/operators/affine_grid_op.cu
+6
-6
paddle/fluid/operators/grid_sampler_op.cu
paddle/fluid/operators/grid_sampler_op.cu
+1
-1
未找到文件。
paddle/fluid/operators/affine_grid_op.cu
浏览文件 @
f8238411
...
...
@@ -86,14 +86,14 @@ __global__ void affine_grid_grad_kernel(const int count, int n, int out_h,
int
theta_offset
=
n
*
6
;
// 2 * 3;
T
out_grad_x
=
out_grad
[
index
*
2
];
a
tomicAdd
(
theta_grad
+
theta_offset
,
out_grad_x
*
h_coor
);
a
tomicAdd
(
theta_grad
+
theta_offset
+
1
,
out_grad_x
*
w_coor
);
a
tomicAdd
(
theta_grad
+
theta_offset
+
2
,
out_grad_x
);
platform
::
CudaA
tomicAdd
(
theta_grad
+
theta_offset
,
out_grad_x
*
h_coor
);
platform
::
CudaA
tomicAdd
(
theta_grad
+
theta_offset
+
1
,
out_grad_x
*
w_coor
);
platform
::
CudaA
tomicAdd
(
theta_grad
+
theta_offset
+
2
,
out_grad_x
);
T
out_grad_y
=
out_grad
[
index
*
2
+
1
];
a
tomicAdd
(
theta_grad
+
theta_offset
+
3
,
out_grad_y
*
h_coor
);
a
tomicAdd
(
theta_grad
+
theta_offset
+
4
,
out_grad_y
*
w_coor
);
a
tomicAdd
(
theta_grad
+
theta_offset
+
5
,
out_grad_y
);
platform
::
CudaA
tomicAdd
(
theta_grad
+
theta_offset
+
3
,
out_grad_y
*
h_coor
);
platform
::
CudaA
tomicAdd
(
theta_grad
+
theta_offset
+
4
,
out_grad_y
*
w_coor
);
platform
::
CudaA
tomicAdd
(
theta_grad
+
theta_offset
+
5
,
out_grad_y
);
}
}
...
...
paddle/fluid/operators/grid_sampler_op.cu
浏览文件 @
f8238411
...
...
@@ -31,7 +31,7 @@ static __forceinline__ __device__ void atomic_add(T* data, int h, int w, int sH,
int
sW
,
int
H
,
int
W
,
T
delta
)
{
if
(
in_bounds
(
h
,
w
,
H
,
W
))
{
a
tomicAdd
(
data
+
h
*
sH
+
w
*
sW
,
delta
);
platform
::
CudaA
tomicAdd
(
data
+
h
*
sH
+
w
*
sW
,
delta
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录