Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
64a40442
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
64a40442
编写于
1月 08, 2020
作者:
L
liu zhengxi
提交者:
GitHub
1月 08, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add double register op_data_type of pad2d and fix compile error, test=develop (#22075)
上级
7ba7acd1
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
18 addition
and
10 deletion
+18
-10
paddle/fluid/operators/pad2d_op.cc
paddle/fluid/operators/pad2d_op.cc
+5
-2
paddle/fluid/operators/pad2d_op.cu
paddle/fluid/operators/pad2d_op.cu
+13
-8
未找到文件。
paddle/fluid/operators/pad2d_op.cc
浏览文件 @
64a40442
...
...
@@ -661,5 +661,8 @@ REGISTER_OPERATOR(pad2d, ops::Pad2dOp, ops::Pad2dOpMaker,
ops
::
Pad2dOpGradMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
pad2d_grad
,
ops
::
Pad2dOpGrad
,
ops
::
Pad2dOpGradNoNeedBufferVarsInference
);
REGISTER_OP_CPU_KERNEL
(
pad2d
,
ops
::
Pad2dCPUKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
pad2d_grad
,
ops
::
Pad2dGradCPUKernel
<
float
>
);
REGISTER_OP_CPU_KERNEL
(
pad2d
,
ops
::
Pad2dCPUKernel
<
float
>
,
ops
::
Pad2dCPUKernel
<
double
>
,
ops
::
Pad2dCPUKernel
<
int
>
,
ops
::
Pad2dCPUKernel
<
int64_t
>
);
REGISTER_OP_CPU_KERNEL
(
pad2d_grad
,
ops
::
Pad2dGradCPUKernel
<
float
>
,
ops
::
Pad2dGradCPUKernel
<
double
>
);
paddle/fluid/operators/pad2d_op.cu
浏览文件 @
64a40442
...
...
@@ -215,8 +215,9 @@ __global__ void Pad2DGradReflectNCHW(const int out_size, T* d_in_data,
in_w
=
max
(
in_w
,
-
in_w
);
in_h
=
min
(
in_h
,
2
*
in_height
-
in_h
-
2
);
in_w
=
min
(
in_w
,
2
*
in_width
-
in_w
-
2
);
atomicAdd
(
&
d_in_data
[(
nc
*
in_height
+
in_h
)
*
in_width
+
in_w
],
d_out_data
[
out_index
]);
platform
::
CudaAtomicAdd
(
&
d_in_data
[(
nc
*
in_height
+
in_h
)
*
in_width
+
in_w
],
d_out_data
[
out_index
]);
}
}
...
...
@@ -240,7 +241,7 @@ __global__ void Pad2DGradReflectNHWC(const int out_size, T* d_in_data,
in_w
=
max
(
in_w
,
-
in_w
);
in_h
=
min
(
in_h
,
in_height
*
2
-
in_h
-
2
);
in_w
=
min
(
in_w
,
in_width
*
2
-
in_w
-
2
);
a
tomicAdd
(
platform
::
CudaA
tomicAdd
(
&
d_in_data
[((
n
*
in_height
+
in_h
)
*
in_width
+
in_w
)
*
channels
+
c
],
d_out_data
[
out_index
]);
}
...
...
@@ -260,8 +261,9 @@ __global__ void Pad2DGradEdgeNCHW(const int out_size, T* d_in_data,
nc
/=
out_height
;
const
int
in_h
=
min
(
in_height
-
1
,
max
(
out_h
-
pad_top
,
0
));
const
int
in_w
=
min
(
in_width
-
1
,
max
(
out_w
-
pad_left
,
0
));
atomicAdd
(
&
d_in_data
[(
nc
*
in_height
+
in_h
)
*
in_width
+
in_w
],
d_out_data
[
out_index
]);
platform
::
CudaAtomicAdd
(
&
d_in_data
[(
nc
*
in_height
+
in_h
)
*
in_width
+
in_w
],
d_out_data
[
out_index
]);
}
}
...
...
@@ -281,7 +283,7 @@ __global__ void Pad2DGradEdgeNHWC(const int out_size, T* d_in_data,
n
/=
out_height
;
const
int
in_h
=
min
(
in_height
-
1
,
max
(
out_h
-
pad_top
,
0
));
const
int
in_w
=
min
(
in_width
-
1
,
max
(
out_w
-
pad_left
,
0
));
a
tomicAdd
(
platform
::
CudaA
tomicAdd
(
&
d_in_data
[((
n
*
in_height
+
in_h
)
*
in_width
+
in_w
)
*
channels
+
c
],
d_out_data
[
out_index
]);
}
...
...
@@ -459,5 +461,8 @@ class Pad2dGradCUDAKernel : public framework::OpKernel<T> {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
pad2d
,
ops
::
Pad2dCUDAKernel
<
float
>
);
REGISTER_OP_CUDA_KERNEL
(
pad2d_grad
,
ops
::
Pad2dGradCUDAKernel
<
float
>
);
REGISTER_OP_CUDA_KERNEL
(
pad2d
,
ops
::
Pad2dCUDAKernel
<
float
>
,
ops
::
Pad2dCUDAKernel
<
double
>
,
ops
::
Pad2dCUDAKernel
<
int
>
,
ops
::
Pad2dCUDAKernel
<
int64_t
>
);
REGISTER_OP_CUDA_KERNEL
(
pad2d_grad
,
ops
::
Pad2dGradCUDAKernel
<
float
>
,
ops
::
Pad2dGradCUDAKernel
<
double
>
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录