Commit ed857585 (unverified)
Authored by xiaoxiaohehe001 on Jul 28, 2022; committed via GitHub on Jul 28, 2022.
[Paddle Inference] Support depthwise_conv2d fp16. (#44642)
* depthwise_fp16
* depthwise_fp16
* depthwise_fp16
* depthwise_fp16
Parent 20759c30

Showing 3 changed files with 41 additions and 24 deletions (+41 −24)
paddle/phi/kernels/gpu/depthwise_conv.h               +37 −22
paddle/phi/kernels/gpu/depthwise_conv_grad_kernel.cu   +2 −1
paddle/phi/kernels/gpu/depthwise_conv_kernel.cu        +2 −1
paddle/phi/kernels/gpu/depthwise_conv.h
@@ -153,7 +153,7 @@ __device__ __inline__ void KernelDepthwiseConvNCHW(
   const int c_in = c_out / filter_multiplier;
   const T* weight = filter_data + c_out * filter_height * filter_width;
-  T value = 0;
+  T value(0);
   const int h_in_start = -padding_height + h_out * stride_height;
   const int w_in_start = -padding_width + w_out * stride_width;
   const int h_in_end = h_in_start + filter_height * dilate_height;
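The same one-line change recurs throughout this header: the accumulator moves from copy-initialization (`T value = 0;`) to direct initialization (`T value(0);`). Below is a minimal standalone sketch, not Paddle code, of how the two forms can differ once T may be a half-precision wrapper; the toy Half type and its explicit constructor are assumptions made only for illustration.

// Toy illustration (not Paddle's float16): direct vs. copy initialization
// when the element type only offers an explicit constructor from float.
#include <cstdio>

struct Half {                        // hypothetical fp16-like wrapper
  float v = 0.f;                     // stored as float purely for the demo
  Half() = default;
  explicit Half(float x) : v(x) {}   // explicit: no implicit 0 -> Half
};

template <typename T>
T zero_accumulator() {
  // T value = 0;   // copy-initialization: ill-formed if T's ctor is explicit
  T value(0);       // direct initialization: may call the explicit ctor
  return value;
}

int main() {
  Half h = zero_accumulator<Half>();
  float f = zero_accumulator<float>();
  std::printf("%f %f\n", h.v, f);
  return 0;
}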
@@ -176,7 +176,7 @@ __device__ __inline__ void KernelDepthwiseConvNCHW(
         int offset = in_offset + h_in * input_width + w_in;
         T in_data = input_data[offset];
         if (fuse_relu_before_conv) {
-          value += weight[weight_offset] * max(0.0f, in_data);
+          value += weight[weight_offset] * T(max(0.0f, double(in_data)));
         } else {
           value += weight[weight_offset] * in_data;
         }
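The other recurring edit is in the fused-ReLU branch: `max(0.0f, in_data)` becomes `T(max(0.0f, double(in_data)))`, i.e. the operand is promoted to a common arithmetic type before max and the result is converted back to T explicitly. A small host-side sketch of the same shape follows; it is only an analogy (it uses std::max with two doubles rather than CUDA's device max), and fused_relu_term is a hypothetical helper, not a function from the patch.

// Host-side analogy of the fused-ReLU term rewritten in the diff.
#include <algorithm>
#include <cstdio>

template <typename T>
T fused_relu_term(T weight, T in_data) {
  // Mirrors: value += weight[weight_offset] * T(max(0.0f, double(in_data)));
  // Promoting to double gives max() two operands of one arithmetic type, and
  // the explicit T(...) converts the result back to the accumulator type.
  return weight * T(std::max(0.0, double(in_data)));
}

int main() {
  std::printf("%f\n", fused_relu_term(0.5f, -2.0f));  // negative input -> 0
  std::printf("%f\n", fused_relu_term(0.5f, 4.0f));   // 0.5 * 4 = 2
  return 0;
}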
@@ -205,7 +205,7 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
   const int batch = idx / output_width / output_height / output_channels;
   const int c_in = c_out / filter_multiplier;
-  T value = 0;
+  T value(0);
   const int h_in_start = -padding_height + h_out * stride_height;
   const int w_in_start = -padding_width + w_out * stride_width;
   const int h_in_end = h_in_start + filter_height * dilate_height;
@@ -228,7 +228,7 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
         T in_data = input_data[offset];
         const T* weight = filter_data + weight_offset * output_channels + c_out;
         if (fuse_relu_before_conv) {
-          value += weight[0] * max(0.0f, in_data);
+          value += weight[0] * T(max(0.0f, double(in_data)));
         } else {
           value += weight[0] * in_data;
         }
@@ -258,7 +258,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
   const int c_out = blockIdx.x;
   const int c_in = c_out / filter_multiplier;
-  T value = 0;
+  T value(0);
   const int h_in_start = -padding_height + h_out * stride_height;
   const int w_in_start = -padding_width + w_out * stride_width;
   const int h_in_end = h_in_start + c_filter * dilate_height;
@@ -281,7 +281,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
         int offset = in_offset + h_in * input_width + w_in;
         if (fuse_relu_before_conv) {
-          value += r_weight[h_f * c_filter + w_f] *
-                   max(0.0f, input_data[offset]);
+          value += r_weight[h_f * c_filter + w_f] *
+                   T(max(0.0f, double(input_data[offset])));
         } else {
           value += r_weight[h_f * c_filter + w_f] * input_data[offset];
         }
@@ -325,7 +325,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
       if (w_out >= output_width) {
         continue;
       }
-      T value = 0;
+      T value(0);
       const int w_in_start = -padding_width + w_out * stride_width;
       for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
            h_in += dilate_height, h_f++) {
@@ -337,7 +337,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
             in_offset + (h_in * input_width + w_in) * input_channels + c_in;
         if (fuse_relu_before_conv) {
-          value += r_weight[h_f * c_filter + w_f] *
-                   max(0.0f, input_data[offset]);
+          value += r_weight[h_f * c_filter + w_f] *
+                   T(max(0.0, double(input_data[offset])));
         } else {
           value += r_weight[h_f * c_filter + w_f] * input_data[offset];
         }
@@ -482,13 +482,13 @@ __device__ __inline__ void KernelDepthwiseConvInputGradNCHW(
           w_in - (filter_width - 1) * dilate_width + padding_width;
       int w_out_end = w_in + padding_width;
-      T value = 0;
+      T value(0);
       int index =
           ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
           w_in;
       if (fuse_relu_before_conv) {
-        if (input_data[index] <= 0) {
+        if (input_data[index] <= T(0)) {
           input_grad_data[index] = 0;
           continue;
         }
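A third recurring pattern appears in the input-grad kernels: the pre-ReLU test changes from `input_data[index] <= 0` to `input_data[index] <= T(0)`, comparing against a zero of the element type instead of a bare int literal. The sketch below is a toy analogy only; the Half type, its single comparison operator, and is_non_positive are assumptions for illustration, not Paddle code.

// Toy illustration: comparing against a typed zero. With a wrapper type that
// only defines operator<= for its own type (an assumption for this demo),
// "x <= 0" would not compile, while "x <= T(0)" does.
#include <cstdio>

struct Half {
  float v = 0.f;
  explicit Half(float x) : v(x) {}
  bool operator<=(const Half& o) const { return v <= o.v; }
};

template <typename T>
bool is_non_positive(T x) {
  // return x <= 0;   // needs an implicit 0 -> T conversion or a mixed overload
  return x <= T(0);   // well-formed given T's own comparison operator
}

int main() {
  std::printf("%d\n", is_non_positive(Half(-1.f)));  // 1
  std::printf("%d\n", is_non_positive(2.5f));        // 0
  return 0;
}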
@@ -539,12 +539,12 @@ __device__ __inline__ void KernelDepthwiseConvInputGradNHWC(
       int w_out_start =
           w_in - (filter_width - 1) * dilate_width + padding_width;
-      T value = 0;
+      T value(0);
       int index =
           ((batch * input_height + h_in) * input_width + w_in) * input_channels +
           c_in;
       if (fuse_relu_before_conv) {
-        if (input_data[index] <= 0) {
+        if (input_data[index] <= T(0)) {
           input_grad_data[index] = 0;
           continue;
         }
@@ -603,12 +603,12 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilterNCHW(
       int h_out_start =
           h_in - (c_filter - 1) * dilate_height + padding_height;
       int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;
-      T value = 0;
+      T value(0);
       int index =
           ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
           w_in;
       if (fuse_relu_before_conv) {
-        if (input_data[index] <= 0) {
+        if (input_data[index] <= T(0)) {
           input_grad_data[index] = 0;
           continue;
         }
@@ -676,12 +676,12 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilterNHWC(
       }
       int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;
-      T value = 0;
+      T value(0);
       int index =
           ((batch * input_height + h_in) * input_width + w_in) * input_channels +
           c_in;
       if (fuse_relu_before_conv) {
-        if (input_data[index] <= 0) {
+        if (input_data[index] <= T(0)) {
           input_grad_data[index] = 0;
           continue;
         }
@@ -854,7 +854,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
     const int dilate_height,
     const int dilate_width,
     T* filter_grad_data) {
-  T s = 0;
+  T s(0);
   int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x;
   for (int image_w = threadIdx.x; image_w < output_width;
@@ -880,7 +880,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
                            image_wk;
         if (fuse_relu_before_conv) {
-          s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
-               max(0.0f, input_data[input_id]);
+          s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
+               T(max(0.0f, double(input_data[input_id])));
         } else {
           s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
                input_data[input_id];
@@ -921,7 +921,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
   int kernel_ih = blockIdx.x / filter_width;
   for (int kernel_id = threadIdx.x; kernel_id < output_channels;
        kernel_id += blockDim.x) {
-    T s = 0;
+    T s(0);
     int gbid =
         ((kernel_id * filter_height) + kernel_ih) * filter_width + kernel_iw;
     for (int image_w = threadIdx.y; image_w < output_width;
@@ -941,7 +941,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
                          kernel_id / filter_multiplier;
         if (fuse_relu_before_conv) {
-          s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
-               max(0.0f, input_data[input_id]);
+          s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
+               T(max(0.0f, double(input_data[input_id])));
         } else {
           s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
                input_data[input_id];
@@ -1010,9 +1010,10 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
               ((bid * output_height + image_h) * output_width + image_w) *
                   output_channels +
               kernel_id;
-          T s = 0;
+          T s(0);
           if (fuse_relu_before_conv) {
-            s = output_grad_data[output_id] * max(0.0f, input_data[input_id]);
+            s = output_grad_data[output_id] *
+                T(max(0.0f, double(input_data[input_id])));
           } else {
             s = output_grad_data[output_id] * input_data[input_id];
           }
@@ -1672,21 +1673,35 @@ class DepthwiseConvFilterGradFunctor<phi::GPUContext,
 template class DepthwiseConvFunctor<phi::GPUContext, float, false>;
 template class DepthwiseConvFunctor<phi::GPUContext, double, false>;
+template class DepthwiseConvFunctor<phi::GPUContext, platform::float16, false>;
 
 template class DepthwiseConvInputGradFunctor<phi::GPUContext, float, false>;
 template class DepthwiseConvInputGradFunctor<phi::GPUContext, double, false>;
+template class DepthwiseConvInputGradFunctor<phi::GPUContext,
+                                             platform::float16,
+                                             false>;
 
 template class DepthwiseConvFilterGradFunctor<phi::GPUContext, float, false>;
 template class DepthwiseConvFilterGradFunctor<phi::GPUContext, double, false>;
+template class DepthwiseConvFilterGradFunctor<phi::GPUContext,
+                                              platform::float16,
+                                              false>;
 
 template class DepthwiseConvFunctor<phi::GPUContext, float, true>;
 template class DepthwiseConvFunctor<phi::GPUContext, double, true>;
+template class DepthwiseConvFunctor<phi::GPUContext, platform::float16, true>;
 
 template class DepthwiseConvInputGradFunctor<phi::GPUContext, float, true>;
 template class DepthwiseConvInputGradFunctor<phi::GPUContext, double, true>;
+template class DepthwiseConvInputGradFunctor<phi::GPUContext,
+                                             platform::float16,
+                                             true>;
 
 template class DepthwiseConvFilterGradFunctor<phi::GPUContext, float, true>;
 template class DepthwiseConvFilterGradFunctor<phi::GPUContext, double, true>;
+template class DepthwiseConvFilterGradFunctor<phi::GPUContext,
+                                              platform::float16,
+                                              true>;
 
 }  // namespace math
 }  // namespace operators
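The block above extends the file's explicit-instantiation list: every functor is now also instantiated for platform::float16, with and without the fused ReLU. A toy sketch of the pattern follows, under assumed names (ToyDepthwiseFunctor is not a Paddle type); the point is just that each `template class ...;` line forces code generation for one (element type, fuse flag) combination in this translation unit.

// Toy example of explicit template instantiation (illustrative names only).
#include <cstdio>

template <typename T, bool FuseRelu>
struct ToyDepthwiseFunctor {
  T apply(T x) const { return (FuseRelu && x < T(0)) ? T(0) : x; }
};

// Each line below emits code for one combination, mirroring how the diff adds
// float16 instantiations next to the existing float/double ones.
template struct ToyDepthwiseFunctor<float, false>;
template struct ToyDepthwiseFunctor<float, true>;

int main() {
  ToyDepthwiseFunctor<float, true> f;
  std::printf("%f\n", f.apply(-1.0f));  // fused ReLU clamps negatives to 0
  return 0;
}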
paddle/phi/kernels/gpu/depthwise_conv_grad_kernel.cu
@@ -139,4 +139,5 @@ PD_REGISTER_KERNEL(depthwise_conv2d_grad,
                    ALL_LAYOUT,
                    phi::DepthwiseConvGradKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::float16) {}
paddle/phi/kernels/gpu/depthwise_conv_kernel.cu
@@ -124,4 +124,5 @@ PD_REGISTER_KERNEL(depthwise_conv2d,
                    ALL_LAYOUT,
                    phi::DepthwiseConvKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::float16) {}
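In both .cu files the change is simply to append phi::dtype::float16 to the dtype list of the existing PD_REGISTER_KERNEL call, so the forward and grad kernels can be selected for fp16 tensors. As a rough mental model only (this is not how Paddle's registration macro is implemented), registering one more dtype amounts to adding one more entry to a type-to-kernel dispatch table:

// Hypothetical sketch of per-dtype kernel dispatch; names are illustrative.
#include <cstdio>
#include <functional>
#include <map>
#include <string>

using Kernel = std::function<void()>;
static std::map<std::string, Kernel> depthwise_conv2d_kernels = {
    {"float32", [] { std::printf("run fp32 depthwise_conv2d\n"); }},
    {"float64", [] { std::printf("run fp64 depthwise_conv2d\n"); }},
    // The patch effectively adds the third entry:
    {"float16", [] { std::printf("run fp16 depthwise_conv2d\n"); }},
};

int main() {
  depthwise_conv2d_kernels.at("float16")();  // selected for fp16 inputs
  return 0;
}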