Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
17879045
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
17879045
编写于
12月 07, 2022
作者:
Z
zhoutianzi666
提交者:
GitHub
12月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize nchw<->nhwc kernel in fp16 model (#48692)
上级
e5bc2eec
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
105 addition
and
5 deletion
+105
-5
paddle/phi/kernels/funcs/math_function.cu
paddle/phi/kernels/funcs/math_function.cu
+76
-5
paddle/phi/kernels/funcs/math_function.h
paddle/phi/kernels/funcs/math_function.h
+3
-0
paddle/phi/kernels/transfer_layout_kernel.cc
paddle/phi/kernels/transfer_layout_kernel.cc
+26
-0
未找到文件。
paddle/phi/kernels/funcs/math_function.cu
浏览文件 @
17879045
...
...
@@ -27,11 +27,83 @@ limitations under the License. */
namespace
phi
{
namespace
funcs
{
// The following part of the code refers to NVIDIA-cutlass
// https://github.com/NVIDIA/cutlass/blob/master/tools/util/include/cutlass/util/device_nchw_to_nhwc.h
// Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
// reserved. SPDX-License-Identifier: BSD-3-Clause
template
<
typename
T
>
__global__
void
batch_transpose_kernel
(
T
*
output
,
const
T
*
input
,
const
int
batch
,
const
int
M
,
const
int
N
)
{
const
int
num
=
M
*
N
;
// "+1" to avoid smem bank conflict
__shared__
T
shbuf
[
32
*
(
32
+
1
)];
const
int32_t
tid
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
const
int32_t
wid
=
tid
/
32
;
const
int32_t
lid
=
tid
%
32
;
const
int32_t
batch_i
=
blockIdx
.
z
;
const
int32_t
mi0
=
blockIdx
.
y
*
32
;
const
int32_t
ni0
=
blockIdx
.
x
*
32
;
const
size_t
input_idx
=
batch_i
*
num
+
(
mi0
+
wid
)
*
N
+
ni0
;
const
T
*
A
=
input
+
input_idx
;
if
(
ni0
+
lid
<
N
)
{
const
int
lid_x_33
=
lid
*
33
;
if
((
mi0
+
32
)
<=
M
)
{
int
mi
=
wid
;
// between 0 and 7
#pragma unroll
for
(
int
mLoopIdx
=
0
;
mLoopIdx
<
4
;
mLoopIdx
++
)
{
shbuf
[
lid_x_33
+
mi
]
=
A
[
lid
];
A
=
&
A
[
8
*
N
];
mi
+=
8
;
}
}
else
{
for
(
int
mi
=
wid
;
mi
<
32
;
mi
+=
8
)
{
if
((
mi
+
mi0
)
<
M
)
{
shbuf
[
lid_x_33
+
mi
]
=
A
[
lid
];
}
A
=
&
A
[
8
*
N
];
}
}
}
__syncthreads
();
const
int32_t
miOut
=
mi0
+
lid
;
output
=
&
output
[
batch_i
*
num
+
miOut
];
if
(
miOut
<
M
)
{
if
(
ni0
+
32
<
N
)
{
int
nI
=
wid
;
#pragma unroll
for
(
int
nLoopIdx
=
0
;
nLoopIdx
<
4
;
++
nLoopIdx
)
{
output
[(
ni0
+
nI
)
*
M
]
=
shbuf
[(
nI
)
*
33
+
lid
];
nI
+=
8
;
}
}
else
{
for
(
int
nI
=
wid
;
nI
<
32
;
nI
+=
8
)
{
if
(
ni0
+
nI
<
N
)
{
output
[(
ni0
+
nI
)
*
M
]
=
shbuf
[(
nI
)
*
33
+
lid
];
}
}
}
}
}
template
<
typename
T
>
void
BatchTranspose
(
T
*
output
,
const
T
*
input
,
int
batch
,
int
m
,
int
n
)
{
dim3
grid
((
n
+
31
)
/
32
,
(
m
+
31
)
/
32
,
batch
);
dim3
block
(
32
,
8
);
batch_transpose_kernel
<<<
grid
,
block
>>>
(
output
,
input
,
batch
,
m
,
n
);
}
using
float16
=
phi
::
dtype
::
float16
;
using
bfloat16
=
phi
::
dtype
::
bfloat16
;
template
struct
SetConstant
<
phi
::
GPUContext
,
phi
::
dtype
::
float16
>;
template
struct
SetConstant
<
phi
::
GPUContext
,
phi
::
dtype
::
bfloat16
>;
template
void
BatchTranspose
(
float16
*
output
,
const
float16
*
input
,
int
batch
,
int
m
,
int
n
);
template
void
BatchTranspose
(
float
*
output
,
const
float
*
input
,
int
batch
,
int
m
,
int
n
);
template
struct
SetConstant
<
phi
::
GPUContext
,
float16
>;
template
struct
SetConstant
<
phi
::
GPUContext
,
bfloat16
>;
template
struct
SetConstant
<
phi
::
GPUContext
,
float
>;
template
struct
SetConstant
<
phi
::
GPUContext
,
double
>;
template
struct
SetConstant
<
phi
::
GPUContext
,
uint8_t
>;
...
...
@@ -42,10 +114,9 @@ template struct SetConstant<phi::GPUContext, bool>;
template
struct
SetConstant
<
phi
::
GPUContext
,
phi
::
dtype
::
complex
<
float
>
>
;
template
struct
SetConstant
<
phi
::
GPUContext
,
phi
::
dtype
::
complex
<
double
>
>
;
template
struct
SetConstant
<
paddle
::
platform
::
CUDAPinnedDeviceContext
,
float16
>;
template
struct
SetConstant
<
paddle
::
platform
::
CUDAPinnedDeviceContext
,
phi
::
dtype
::
float16
>;
template
struct
SetConstant
<
paddle
::
platform
::
CUDAPinnedDeviceContext
,
phi
::
dtype
::
bfloat16
>;
bfloat16
>;
template
struct
SetConstant
<
paddle
::
platform
::
CUDAPinnedDeviceContext
,
float
>;
template
struct
SetConstant
<
paddle
::
platform
::
CUDAPinnedDeviceContext
,
double
>;
template
struct
SetConstant
<
paddle
::
platform
::
CUDAPinnedDeviceContext
,
uint8_t
>;
...
...
paddle/phi/kernels/funcs/math_function.h
浏览文件 @
17879045
...
...
@@ -29,6 +29,9 @@ limitations under the License. */
namespace
phi
{
namespace
funcs
{
template
<
typename
T
>
void
BatchTranspose
(
T
*
output
,
const
T
*
input
,
int
batch
,
int
m
,
int
n
);
template
<
typename
DeviceContext
,
typename
T
>
struct
TransposeNormal
{
// for dims >= 7 situation
...
...
paddle/phi/kernels/transfer_layout_kernel.cc
浏览文件 @
17879045
...
...
@@ -70,6 +70,32 @@ void TransferLayoutGeneral(const Context& dev_ctx,
out
->
Resize
(
phi
::
make_ddim
(
dst_dim
));
dev_ctx
.
Alloc
(
out
,
x
.
dtype
());
// In GPU fp16 model, we will insert many transfer_layout ops in
// conv2d_fusion_layout_transfer_pass, so we optimize this kernel on GPU
if
(
std
::
is_same
<
Context
,
phi
::
GPUContext
>::
value
)
{
std
::
vector
<
int
>
axis_nchw_nhwc
=
{
0
,
2
,
3
,
1
};
std
::
vector
<
int
>
axis_nhwc_nchw
=
{
0
,
3
,
1
,
2
};
const
int
batch
=
src_dim
[
0
];
int
row_len
=
src_dim
[
1
];
int
col_len
=
src_dim
[
2
]
*
src_dim
[
3
];
if
(
axis
==
axis_nhwc_nchw
)
{
row_len
=
src_dim
[
1
]
*
src_dim
[
2
];
col_len
=
src_dim
[
3
];
}
if
(
x
.
dtype
()
==
phi
::
DataType
::
FLOAT16
)
{
funcs
::
BatchTranspose
(
out
->
data
<
phi
::
dtype
::
float16
>
(),
x
.
data
<
phi
::
dtype
::
float16
>
(),
batch
,
row_len
,
col_len
);
return
;
}
else
if
(
x
.
dtype
()
==
phi
::
DataType
::
FLOAT32
)
{
funcs
::
BatchTranspose
(
out
->
data
<
float
>
(),
x
.
data
<
float
>
(),
batch
,
row_len
,
col_len
);
return
;
}
}
PD_VISIT_ALL_TYPES
(
x
.
dtype
(),
"CastDataLayout"
,
([
&
]
{
CastDataLayout
<
data_t
,
Context
>
(
dev_ctx
,
x
,
axis
,
out
);
}));
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录