Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b3283f4c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b3283f4c
编写于
9月 15, 2022
作者:
傅
傅剑寒
提交者:
GitHub
9月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize flip kernel by eliminating H2D data transfer, test=develop (#46046)
上级
65bdd80b
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
74 addition
and
77 deletion
+74
-77
paddle/phi/kernels/gpu/flip_kernel.cu
paddle/phi/kernels/gpu/flip_kernel.cu
+74
-77
未找到文件。
paddle/phi/kernels/gpu/flip_kernel.cu
浏览文件 @
b3283f4c
...
...
@@ -13,126 +13,123 @@
// limitations under the License.
#include "paddle/phi/kernels/flip_kernel.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/array.h"
namespace
phi
{
template
<
typename
T
>
template
<
typename
T
,
size_t
Rank
>
__global__
void
flip_cuda_kernel
(
const
int
N
,
const
T
*
in_data
,
T
*
out_data
,
int64_t
*
x_shape
,
int64_t
*
x_stride
,
int
*
flip_dims
,
int
flip_dims_size
,
int
total_dims
)
{
phi
::
Array
<
int64_t
,
Rank
>
shape
,
phi
::
Array
<
int64_t
,
Rank
>
stride
,
phi
::
Array
<
int
,
Rank
>
flip_dims
,
int
flip_dims_size
)
{
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
idx
>=
N
)
{
return
;
}
int
cur_indices
=
idx
,
rem
=
0
,
dst_offset
=
0
;
for
(
int
i
=
0
;
i
<
total_dims
;
++
i
)
{
for
(
int
i
=
0
;
i
<
Rank
;
++
i
)
{
int64_t
temp
=
cur_indices
;
cur_indices
=
cur_indices
/
x_
stride
[
i
];
rem
=
temp
-
cur_indices
*
x_
stride
[
i
];
cur_indices
=
cur_indices
/
stride
[
i
];
rem
=
temp
-
cur_indices
*
stride
[
i
];
// flip the indices if it is in flip_dims
for
(
int
j
=
0
;
j
<
flip_dims_size
;
++
j
)
{
if
(
i
==
flip_dims
[
j
])
{
cur_indices
=
x_
shape
[
i
]
-
1
-
cur_indices
;
cur_indices
=
shape
[
i
]
-
1
-
cur_indices
;
}
}
dst_offset
+=
cur_indices
*
x_
stride
[
i
];
dst_offset
+=
cur_indices
*
stride
[
i
];
cur_indices
=
rem
;
}
out_data
[
idx
]
=
in_data
[
dst_offset
];
}
template
<
typename
T
,
typename
Context
>
void
FlipKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
std
::
vector
<
int
>&
axis
,
DenseTensor
*
out
)
{
const
auto
gplace
=
dev_ctx
.
GetPlace
();
auto
cplace
=
phi
::
CPUPlace
();
std
::
vector
<
int
>
flip_dims
=
axis
;
template
<
typename
T
,
typename
Context
,
size_t
N
>
void
launch_flip_cuda_kernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
std
::
vector
<
int
>&
axis
,
DenseTensor
*
out
)
{
std
::
vector
<
int
>
flip_dims_v
=
axis
;
auto
*
in_data
=
x
.
data
<
T
>
();
auto
*
out_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
const
int
flip_dims_size
=
static_cast
<
int
>
(
flip_dims
.
size
());
auto
x_dims
=
x
.
dims
();
const
int
total_dims
=
x_dims
.
size
();
const
int
N
=
x
.
numel
();
const
int
numel
=
x
.
numel
();
int
block_size
=
512
;
dim3
dim_block
(
block_size
);
dim3
dim_grid
((
N
+
block_size
-
1
)
/
block_size
);
dim3
dim_grid
((
numel
+
block_size
-
1
)
/
block_size
);
for
(
size_t
i
=
0
;
i
<
flip_dims
.
size
();
++
i
)
{
if
(
flip_dims
[
i
]
<
0
)
{
flip_dims
[
i
]
+=
total_dims
;
for
(
size_t
i
=
0
;
i
<
flip_dims
_v
.
size
();
++
i
)
{
if
(
flip_dims
_v
[
i
]
<
0
)
{
flip_dims
_v
[
i
]
+=
total_dims
;
}
}
auto
x_stride
=
phi
::
stride
(
x_dims
);
std
::
vector
<
int64_t
>
x_dims_v
=
phi
::
vectorize
(
x_dims
);
std
::
vector
<
int64_t
>
x_stride_v
=
phi
::
vectorize
(
x_stride
);
int
bytes
=
total_dims
*
sizeof
(
int64_t
);
auto
x_strides_array_tmp
=
paddle
::
memory
::
Alloc
(
dev_ctx
.
GetPlace
(),
bytes
,
phi
::
Stream
(
reinterpret_cast
<
phi
::
StreamId
>
(
dev_ctx
.
stream
())));
int64_t
*
x_strides_array_gpu
=
reinterpret_cast
<
int64_t
*>
(
x_strides_array_tmp
->
ptr
());
paddle
::
memory
::
Copy
(
gplace
,
x_strides_array_gpu
,
cplace
,
x_stride_v
.
data
(),
bytes
,
dev_ctx
.
stream
());
auto
x_shape_array_tmp
=
paddle
::
memory
::
Alloc
(
dev_ctx
.
GetPlace
(),
bytes
,
phi
::
Stream
(
reinterpret_cast
<
phi
::
StreamId
>
(
dev_ctx
.
stream
())));
int64_t
*
x_shape_array_gpu
=
reinterpret_cast
<
int64_t
*>
(
x_shape_array_tmp
->
ptr
());
paddle
::
memory
::
Copy
(
gplace
,
x_shape_array_gpu
,
cplace
,
x_dims_v
.
data
(),
bytes
,
dev_ctx
.
stream
());
bytes
=
flip_dims_size
*
sizeof
(
int
)
;
auto
flip_dims_array_tmp
=
paddle
::
memory
::
Alloc
(
dev_ctx
.
GetPlace
(),
bytes
,
phi
::
Stream
(
reinterpret_cast
<
phi
::
StreamId
>
(
dev_ctx
.
stream
())));
int
*
flip_dims_array_gpu
=
reinterpret_cast
<
int
*>
(
flip_dims_array_tmp
->
ptr
())
;
paddle
::
memory
::
Copy
(
gplace
,
flip_dims_array_gpu
,
cplace
,
flip_dims
.
data
(),
bytes
,
dev_ctx
.
stream
());
phi
::
Array
<
int64_t
,
N
>
stride_a
;
phi
::
Array
<
int64_t
,
N
>
shape_a
;
phi
::
Array
<
int
,
N
>
flip_dims_a
;
size_t
flip_dims_size
=
flip_dims_v
.
size
();
for
(
size_t
idx
=
0
;
idx
<
N
;
++
idx
)
{
stride_a
[
idx
]
=
x_stride
[
idx
]
;
shape_a
[
idx
]
=
x_dims
[
idx
];
flip_dims_a
[
idx
]
=
idx
<
flip_dims_size
?
flip_dims_v
[
idx
]
:
0
;
}
flip_cuda_kernel
<
T
,
N
><<<
dim_grid
,
dim_block
,
0
,
dev_ctx
.
stream
()
>>>
(
numel
,
in_data
,
out_data
,
shape_a
,
stride_a
,
flip_dims_a
,
flip_dims_size
);
}
flip_cuda_kernel
<
T
>
<<<
dim_grid
,
dim_block
,
0
,
dev_ctx
.
stream
()
>>>
(
N
,
in_data
,
out_data
,
x_shape_array_gpu
,
x_strides_array_gpu
,
flip_dims_array_gpu
,
flip_dims_size
,
total_dims
);
template
<
typename
T
,
typename
Context
>
void
FlipKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
std
::
vector
<
int
>&
axis
,
DenseTensor
*
out
)
{
const
size_t
total_dims
=
x
.
dims
().
size
();
switch
(
total_dims
)
{
case
1
:
launch_flip_cuda_kernel
<
T
,
Context
,
1
>
(
dev_ctx
,
x
,
axis
,
out
);
break
;
case
2
:
launch_flip_cuda_kernel
<
T
,
Context
,
2
>
(
dev_ctx
,
x
,
axis
,
out
);
break
;
case
3
:
launch_flip_cuda_kernel
<
T
,
Context
,
3
>
(
dev_ctx
,
x
,
axis
,
out
);
break
;
case
4
:
launch_flip_cuda_kernel
<
T
,
Context
,
4
>
(
dev_ctx
,
x
,
axis
,
out
);
break
;
case
5
:
launch_flip_cuda_kernel
<
T
,
Context
,
5
>
(
dev_ctx
,
x
,
axis
,
out
);
break
;
case
6
:
launch_flip_cuda_kernel
<
T
,
Context
,
6
>
(
dev_ctx
,
x
,
axis
,
out
);
break
;
case
7
:
launch_flip_cuda_kernel
<
T
,
Context
,
7
>
(
dev_ctx
,
x
,
axis
,
out
);
break
;
case
8
:
launch_flip_cuda_kernel
<
T
,
Context
,
8
>
(
dev_ctx
,
x
,
axis
,
out
);
break
;
case
9
:
launch_flip_cuda_kernel
<
T
,
Context
,
9
>
(
dev_ctx
,
x
,
axis
,
out
);
break
;
default:
PADDLE_THROW
(
phi
::
errors
::
InvalidArgument
(
"dims of input tensor should be less than 10, But received"
"%d"
,
x
.
dims
().
size
()));
}
}
}
// namespace phi
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录