Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
02f33a17
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
02f33a17
编写于
5月 27, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 27, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1489 GPU fix slice
Merge pull request !1489 from VectorSL/gpu-fix-slice
上级
2e684e89
f1bd5aba
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
7 addition
and
4 deletion
+7
-4
mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h
+3
-0
mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu
mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu
+4
-4
未找到文件。
mindspore/ccsrc/kernel/gpu/arrays/slice_gpu_kernel.h
浏览文件 @
02f33a17
...
...
@@ -79,6 +79,9 @@ class SliceGpuFwdKernel : public GpuKernel {
if
(
size_
[
i
]
<
0
)
{
size_
[
i
]
=
(
size_
[
i
]
+
input_shape_
[
i
])
>
0
?
(
size_
[
i
]
+
input_shape_
[
i
])
:
0
;
}
if
(
size_
[
i
]
==
0
)
{
size_
[
i
]
=
begin_
[
i
]
+
1
;
}
}
input_size_
=
IntToSize
(
input_shape_
[
0
]
*
input_shape_
[
1
]
*
input_shape_
[
2
]
*
input_shape_
[
3
])
*
sizeof
(
T
);
...
...
mindspore/ccsrc/kernel/gpu/cuda_impl/slice_impl.cu
浏览文件 @
02f33a17
...
...
@@ -47,7 +47,7 @@ __global__ void SliceGrad(const T* dy, int p, int start, int length, T* output)
}
template
<
typename
T
>
__global__
void
StridedSlice
(
const
T
*
input
,
int
p
,
int
start
,
int
begin
,
int
stride
,
int
ended
,
T
*
output
)
{
for
(
size_t
pos
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
pos
<
((
ended
-
1
-
begin
)
/
stride
)
+
1
;
for
(
size_t
pos
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
pos
<
std
::
ceil
(
static_cast
<
float
>
(
ended
-
begin
)
/
stride
)
;
pos
+=
blockDim
.
x
*
gridDim
.
x
)
{
output
[
p
+
pos
]
=
input
[
start
+
pos
*
stride
];
}
...
...
@@ -55,7 +55,7 @@ __global__ void StridedSlice(const T* input, int p, int start, int begin, int st
}
template
<
typename
T
>
__global__
void
StridedSliceGrad
(
const
T
*
dy
,
int
p
,
int
start
,
int
begin
,
int
stride
,
int
ended
,
T
*
dx
)
{
for
(
size_t
pos
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
pos
<
((
ended
-
1
-
begin
)
/
stride
)
+
1
;
for
(
size_t
pos
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
pos
<
std
::
ceil
(
static_cast
<
float
>
(
ended
-
begin
)
/
stride
)
;
pos
+=
blockDim
.
x
*
gridDim
.
x
)
{
dx
[
start
+
pos
*
stride
]
=
dy
[
p
+
pos
];
}
...
...
@@ -117,7 +117,7 @@ void CalStridedSlice(const size_t input_size, const T* input, const std::vector<
(
strides
[
2
]
>
0
?
k
:
2
*
begin
[
2
]
-
k
)
*
w
+
begin
[
3
];
StridedSlice
<<<
GET_BLOCKS
(
input_size
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
p
,
start
,
begin
[
3
],
strides
[
3
],
ended
,
output
);
p
=
p
+
(
end
[
3
]
-
1
-
begin
[
3
])
/
strides
[
3
]
+
1
;
p
=
p
+
std
::
ceil
(
static_cast
<
float
>
(
end
[
3
]
-
begin
[
3
])
/
strides
[
3
])
;
}
}
}
...
...
@@ -141,7 +141,7 @@ void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector
(
strides
[
2
]
>
0
?
k
:
2
*
begin
[
2
]
-
k
)
*
w
+
begin
[
3
];
StridedSliceGrad
<<<
GET_BLOCKS
(
input_size
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
dy
,
p
,
start
,
begin
[
3
],
strides
[
3
],
ended
,
dx
);
p
=
p
+
(
end
[
3
]
-
1
-
begin
[
3
])
/
strides
[
3
]
+
1
;
p
=
p
+
std
::
ceil
(
static_cast
<
float
>
(
end
[
3
]
-
begin
[
3
])
/
strides
[
3
])
;
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录