Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
4e25fec7
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
4e25fec7
编写于
5月 07, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 07, 2020
浏览文件
操作
浏览文件
下载
差异文件
!324 Gpu Slice kernel performance improve
Merge pull request !324 from chenweifeng/slice
上级
378a7122
5b7790a2
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
60 addition
and
33 deletion
+60
-33
mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h
+3
-2
mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu
mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu
+33
-29
mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cuh
mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cuh
+5
-2
tests/st/ops/gpu/test_slice.py
tests/st/ops/gpu/test_slice.py
+19
-0
未找到文件。
mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h
浏览文件 @
4e25fec7
...
...
@@ -41,8 +41,9 @@ class SliceGpuFwdKernel : public GpuKernel {
CalStridedSlice
(
output_size_
/
sizeof
(
T
),
input
,
input_shape_
,
begin_
,
size_
,
strides_
,
output
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
}
else
{
CalSlice
(
output_size_
/
sizeof
(
T
),
input
,
input_shape_
,
begin_
,
size_
,
output
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
Slice4DKernel
(
begin_
[
0
],
begin_
[
1
],
begin_
[
2
],
begin_
[
3
],
size_
[
0
],
size_
[
1
],
size_
[
2
],
size_
[
3
],
input_shape_
[
0
],
input_shape_
[
1
],
input_shape_
[
2
],
input_shape_
[
3
],
input
,
output
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
}
return
true
;
}
...
...
mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu
浏览文件 @
4e25fec7
...
...
@@ -21,11 +21,22 @@
#include "kernel/gpu/cuda_impl/slice_impl.cuh"
template
<
typename
T
>
__global__
void
Slice
(
const
T
*
input
,
int
p
,
int
start
,
int
length
,
T
*
output
)
{
for
(
size_t
pos
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
pos
<
(
length
);
pos
+=
blockDim
.
x
*
gridDim
.
x
)
{
output
[
p
+
pos
]
=
input
[
start
+
pos
];
__global__
void
Slice4D
(
const
int
s1
,
const
int
s2
,
const
int
s3
,
const
int
s4
,
const
int
l1
,
const
int
l2
,
const
int
l3
,
const
int
l4
,
const
int
d1
,
const
int
d2
,
const
int
d3
,
const
int
d4
,
const
T
*
input
,
T
*
output
)
{
for
(
size_t
pos
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
pos
<
(
l1
*
l2
*
l3
*
l4
);
pos
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
i
=
pos
/
(
l2
*
l3
*
l4
)
%
l1
;
int
j
=
pos
/
(
l3
*
l4
)
%
l2
;
int
k
=
pos
/
l4
%
l3
;
int
o
=
pos
%
l4
;
int
offset
=
(
i
+
s1
)
*
(
d2
*
d3
*
d4
)
+
(
j
+
s2
)
*
(
d3
*
d4
)
+
(
k
+
s3
)
*
d4
+
(
o
+
s4
);
output
[
pos
]
=
input
[
offset
];
}
return
;
}
template
<
typename
T
>
__global__
void
SliceGrad
(
const
T
*
dy
,
int
p
,
int
start
,
int
length
,
T
*
output
)
{
...
...
@@ -64,22 +75,12 @@ void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaSt
return
;
}
template
<
typename
T
>
void
CalSlice
(
const
size_t
input_size
,
const
T
*
input
,
const
std
::
vector
<
int
>
in_shape
,
const
std
::
vector
<
int
>
begin
,
const
std
::
vector
<
int
>
size
,
T
*
output
,
cudaStream_t
cuda_stream
)
{
int
block
=
in_shape
[
1
]
*
in_shape
[
2
]
*
in_shape
[
3
];
int
map
=
in_shape
[
2
]
*
in_shape
[
3
];
int
w
=
in_shape
[
3
];
int
length
=
size
[
3
];
int
p
=
0
;
for
(
int
i
=
begin
[
0
];
i
<
size
[
0
]
+
begin
[
0
];
i
++
)
{
for
(
int
j
=
begin
[
1
];
j
<
size
[
1
]
+
begin
[
1
];
j
++
)
{
for
(
int
k
=
begin
[
2
];
k
<
size
[
2
]
+
begin
[
2
];
k
++
)
{
Slice
<<<
GET_BLOCKS
(
input_size
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
p
,
i
*
block
+
j
*
map
+
k
*
w
+
begin
[
3
],
length
,
output
);
p
=
p
+
size
[
3
];
}
}
}
void
Slice4DKernel
(
const
int
s1
,
const
int
s2
,
const
int
s3
,
const
int
s4
,
const
int
l1
,
const
int
l2
,
const
int
l3
,
const
int
l4
,
const
int
d1
,
const
int
d2
,
const
int
d3
,
const
int
d4
,
const
T
*
input
,
T
*
output
,
cudaStream_t
stream
)
{
Slice4D
<<<
GET_BLOCKS
(
l1
*
l2
*
l3
*
l4
),
GET_THREADS
,
0
,
stream
>>>
(
s1
,
s2
,
s3
,
s4
,
l1
,
l2
,
l3
,
l4
,
d1
,
d2
,
d3
,
d4
,
input
,
output
);
}
template
<
typename
T
>
void
CalSliceGrad
(
const
size_t
input_size
,
const
T
*
dy
,
const
std
::
vector
<
int
>
in_shape
,
const
std
::
vector
<
int
>
begin
,
...
...
@@ -147,9 +148,10 @@ void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector
}
template
void
FillDeviceArray
<
float
>(
const
size_t
input_size
,
float
*
addr
,
const
float
value
,
cudaStream_t
cuda_stream
);
template
void
CalSlice
<
float
>(
const
size_t
input_size
,
const
float
*
input
,
const
std
::
vector
<
int
>
in_shape
,
const
std
::
vector
<
int
>
begin
,
const
std
::
vector
<
int
>
size
,
float
*
output
,
cudaStream_t
cuda_stream
);
template
void
Slice4DKernel
(
const
int
s1
,
const
int
s2
,
const
int
s3
,
const
int
s4
,
const
int
l1
,
const
int
l2
,
const
int
l3
,
const
int
l4
,
const
int
d1
,
const
int
d2
,
const
int
d3
,
const
int
d4
,
const
float
*
input
,
float
*
output
,
cudaStream_t
stream
);
template
void
CalSliceGrad
<
float
>(
const
size_t
input_size
,
const
float
*
dy
,
const
std
::
vector
<
int
>
in_shape
,
const
std
::
vector
<
int
>
begin
,
const
std
::
vector
<
int
>
size
,
float
*
output
,
cudaStream_t
cuda_stream
);
...
...
@@ -160,9 +162,10 @@ template void CalStridedSliceGrad<float>(const size_t input_size, const float* d
const
std
::
vector
<
int
>
begin
,
const
std
::
vector
<
int
>
end
,
const
std
::
vector
<
int
>
strides
,
float
*
dx
,
cudaStream_t
cuda_stream
);
template
void
FillDeviceArray
<
half
>(
const
size_t
input_size
,
half
*
addr
,
const
float
value
,
cudaStream_t
cuda_stream
);
template
void
CalSlice
<
half
>(
const
size_t
input_size
,
const
half
*
input
,
const
std
::
vector
<
int
>
in_shape
,
const
std
::
vector
<
int
>
begin
,
const
std
::
vector
<
int
>
size
,
half
*
output
,
cudaStream_t
cuda_stream
);
template
void
Slice4DKernel
(
const
int
s1
,
const
int
s2
,
const
int
s3
,
const
int
s4
,
const
int
l1
,
const
int
l2
,
const
int
l3
,
const
int
l4
,
const
int
d1
,
const
int
d2
,
const
int
d3
,
const
int
d4
,
const
half
*
input
,
half
*
output
,
cudaStream_t
stream
);
template
void
CalSliceGrad
<
half
>(
const
size_t
input_size
,
const
half
*
dy
,
const
std
::
vector
<
int
>
in_shape
,
const
std
::
vector
<
int
>
begin
,
const
std
::
vector
<
int
>
size
,
half
*
output
,
cudaStream_t
cuda_stream
);
...
...
@@ -173,9 +176,10 @@ template void CalStridedSliceGrad<half>(const size_t input_size, const half* dy,
const
std
::
vector
<
int
>
begin
,
const
std
::
vector
<
int
>
end
,
const
std
::
vector
<
int
>
strides
,
half
*
dx
,
cudaStream_t
cuda_stream
);
template
void
FillDeviceArray
<
int
>(
const
size_t
input_size
,
int
*
addr
,
const
float
value
,
cudaStream_t
cuda_stream
);
template
void
CalSlice
<
int
>(
const
size_t
input_size
,
const
int
*
input
,
const
std
::
vector
<
int
>
in_shape
,
const
std
::
vector
<
int
>
begin
,
const
std
::
vector
<
int
>
size
,
int
*
output
,
cudaStream_t
cuda_stream
);
template
void
Slice4DKernel
(
const
int
s1
,
const
int
s2
,
const
int
s3
,
const
int
s4
,
const
int
l1
,
const
int
l2
,
const
int
l3
,
const
int
l4
,
const
int
d1
,
const
int
d2
,
const
int
d3
,
const
int
d4
,
const
int
*
input
,
int
*
output
,
cudaStream_t
stream
);
template
void
CalSliceGrad
<
int
>(
const
size_t
input_size
,
const
int
*
dy
,
const
std
::
vector
<
int
>
in_shape
,
const
std
::
vector
<
int
>
begin
,
const
std
::
vector
<
int
>
size
,
int
*
output
,
cudaStream_t
cuda_stream
);
...
...
mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cuh
浏览文件 @
4e25fec7
...
...
@@ -21,9 +21,12 @@
#include <vector>
#include "device/gpu/cuda_common.h"
template
<
typename
T
>
void
CalSlice
(
const
size_t
input_size
,
const
T
*
input
,
const
std
::
vector
<
int
>
in_shape
,
const
std
::
vector
<
int
>
begin
,
const
std
::
vector
<
int
>
size
,
T
*
output
,
cudaStream_t
cuda_stream
);
void
Slice4DKernel
(
const
int
s1
,
const
int
s2
,
const
int
s3
,
const
int
s4
,
const
int
l1
,
const
int
l2
,
const
int
l3
,
const
int
l4
,
const
int
d1
,
const
int
d2
,
const
int
d3
,
const
int
d4
,
const
T
*
input
,
T
*
output
,
cudaStream_t
stream
);
template
<
typename
T
>
void
CalSliceGrad
(
const
size_t
input_size
,
const
T
*
input
,
const
std
::
vector
<
int
>
in_shape
,
const
std
::
vector
<
int
>
begin
,
const
std
::
vector
<
int
>
size
,
T
*
output
,
cudaStream_t
cuda_stream
);
...
...
tests/st/ops/gpu/test_slice.py
浏览文件 @
4e25fec7
...
...
@@ -43,3 +43,22 @@ def test_slice():
slice
=
Slice
()
output
=
slice
(
x
)
assert
(
output
.
asnumpy
()
==
expect
).
all
()
class
SliceNet
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
SliceNet
,
self
).
__init__
()
self
.
slice
=
P
.
Slice
()
def
construct
(
self
,
x
):
return
self
.
slice
(
x
,
(
0
,
11
,
0
,
0
),
(
32
,
7
,
224
,
224
))
def
test_slice_4d
():
x_np
=
np
.
random
.
randn
(
32
,
24
,
224
,
224
).
astype
(
np
.
float32
)
output_np
=
x_np
[:,
11
:
18
,
:,
:]
x_ms
=
Tensor
(
x_np
)
net
=
SliceNet
()
output_ms
=
net
(
x_ms
)
assert
(
output_ms
.
asnumpy
()
==
output_np
).
all
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录