Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
兔爷不爱我
mindspore
提交
b750e3e1
M
mindspore
项目概览
兔爷不爱我
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
b750e3e1
编写于
7月 18, 2020
作者:
Z
zhaoting
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix gpu Split and Concat memory allocation bug
上级
5c0962ac
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
19 addition
and
19 deletion
+19
-19
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h
.../backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h
+5
-5
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h
...src/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h
+2
-2
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu
...rc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu
+5
-5
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh
...c/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh
+1
-1
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu
...ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu
+5
-5
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh
...csrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh
+1
-1
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h
浏览文件 @
b750e3e1
...
...
@@ -74,12 +74,12 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
inputs_host_
=
std
::
make_unique
<
T
*
[]
>
(
input_num_
);
len_axis_
=
std
::
make_unique
<
int
[]
>
(
input_num_
);
for
(
int
i
=
0
;
i
<
input_num_
;
i
++
)
{
in
t
input_size
=
1
;
size_
t
input_size
=
1
;
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
i
);
for
(
size_t
j
=
0
;
j
<
input_shape
.
size
();
j
++
)
{
input_size
*=
SizeToInt
(
input_shape
[
j
])
;
input_size
*=
input_shape
[
j
]
;
}
input_size_list_
.
push_back
(
IntToSize
(
input_size
*
sizeof
(
T
)
));
input_size_list_
.
push_back
(
input_size
*
sizeof
(
T
));
len_axis_
[
i
]
=
SizeToInt
(
input_shape
[
axis_
]);
}
workspace_size_list_
.
push_back
(
sizeof
(
T
*
)
*
input_num_
);
...
...
@@ -97,7 +97,7 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
all_size_before_axis_
*=
output_shape
[
i
];
}
}
output_size_list_
.
push_back
(
IntToSize
(
output_size_
*
sizeof
(
T
)
));
output_size_list_
.
push_back
(
output_size_
*
sizeof
(
T
));
InitSizeLists
();
return
true
;
...
...
@@ -117,7 +117,7 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
}
int
axis_
;
int
input_num_
;
in
t
output_size_
;
size_
t
output_size_
;
int
all_size_before_axis_
;
int
all_size_axis_
;
std
::
unique_ptr
<
T
*
[]
>
inputs_host_
;
...
...
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h
浏览文件 @
b750e3e1
...
...
@@ -83,7 +83,7 @@ class SplitGpuFwdKernel : public GpuKernel {
all_size_before_axis_
*=
input_shape
[
i
];
}
}
input_size_list_
.
push_back
(
IntToSize
(
input_size_
*
sizeof
(
T
)
));
input_size_list_
.
push_back
(
input_size_
*
sizeof
(
T
));
axis_step_
=
input_shape
[
axis_
]
/
output_num_
;
for
(
int
i
=
0
;
i
<
output_num_
;
i
++
)
{
...
...
@@ -138,7 +138,7 @@ class SplitGpuFwdKernel : public GpuKernel {
}
int
axis_
;
int
output_num_
;
in
t
input_size_
;
size_
t
input_size_
;
int
axis_step_
;
int
all_size_before_axis_
;
int
all_size_axis_
;
...
...
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu
浏览文件 @
b750e3e1
...
...
@@ -19,7 +19,7 @@
#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
template
<
typename
T
>
__global__
void
Concat
(
const
in
t
size
,
const
int
input_num
,
__global__
void
Concat
(
const
size_
t
size
,
const
int
input_num
,
const
int
all_size_before_axis
,
const
int
all_size_axis
,
int
*
len_axis
,
T
**
inputs
,
T
*
output
)
{
for
(
int
pos
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
pos
<
(
size
);
pos
+=
blockDim
.
x
*
gridDim
.
x
)
{
...
...
@@ -45,7 +45,7 @@ __global__ void Concat(const int size, const int input_num,
}
template
<
typename
T
>
void
ConcatKernel
(
const
in
t
size
,
const
int
input_num
,
void
ConcatKernel
(
const
size_
t
size
,
const
int
input_num
,
const
int
all_size_before_axis
,
const
int
all_size_axis
,
int
*
len_axis
,
T
**
inputs
,
T
*
output
,
cudaStream_t
cuda_stream
)
{
...
...
@@ -55,15 +55,15 @@ void ConcatKernel(const int size, const int input_num,
return
;
}
template
void
ConcatKernel
(
const
in
t
size
,
const
int
input_num
,
template
void
ConcatKernel
(
const
size_
t
size
,
const
int
input_num
,
const
int
all_size_before_axis
,
const
int
all_size_axis
,
int
*
len_axis
,
float
**
inputs
,
float
*
output
,
cudaStream_t
cuda_stream
);
template
void
ConcatKernel
(
const
in
t
size
,
const
int
input_num
,
template
void
ConcatKernel
(
const
size_
t
size
,
const
int
input_num
,
const
int
all_size_before_axis
,
const
int
all_size_axis
,
int
*
len_axis
,
int
**
inputs
,
int
*
output
,
cudaStream_t
cuda_stream
);
template
void
ConcatKernel
(
const
in
t
size
,
const
int
input_num
,
template
void
ConcatKernel
(
const
size_
t
size
,
const
int
input_num
,
const
int
all_size_before_axis
,
const
int
all_size_axis
,
int
*
len_axis
,
half
**
inputs
,
half
*
output
,
cudaStream_t
cuda_stream
);
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh
浏览文件 @
b750e3e1
...
...
@@ -19,7 +19,7 @@
#include "runtime/device/gpu/cuda_common.h"
template
<
typename
T
>
void
ConcatKernel
(
const
in
t
size
,
const
int
input_num
,
void
ConcatKernel
(
const
size_
t
size
,
const
int
input_num
,
const
int
all_size_before_axis
,
const
int
all_size_axis
,
int
*
len_axis
,
T
**
inputs
,
T
*
output
,
cudaStream_t
cuda_stream
);
...
...
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu
浏览文件 @
b750e3e1
...
...
@@ -19,7 +19,7 @@
#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh"
template
<
typename
T
>
__global__
void
Split
(
const
in
t
size
,
const
int
axis_step
,
const
int
all_size_before_axis
,
__global__
void
Split
(
const
size_
t
size
,
const
int
axis_step
,
const
int
all_size_before_axis
,
const
int
all_size_axis
,
const
T
*
input
,
T
**
outputs
)
{
for
(
size_t
pos
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
pos
<
size
;
pos
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
num
=
pos
%
all_size_before_axis
/
all_size_axis
;
...
...
@@ -32,19 +32,19 @@ __global__ void Split(const int size, const int axis_step, const int all_size_be
}
template
<
typename
T
>
void
SplitKernel
(
const
in
t
size
,
const
int
axis_step
,
const
int
all_size_before_axis
,
void
SplitKernel
(
const
size_
t
size
,
const
int
axis_step
,
const
int
all_size_before_axis
,
const
int
all_size_axis
,
const
T
*
input
,
T
**
outputs
,
cudaStream_t
cuda_stream
)
{
Split
<<<
GET_BLOCKS
(
size
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
size
,
axis_step
,
all_size_before_axis
,
all_size_axis
,
input
,
outputs
);
return
;
}
template
void
SplitKernel
(
const
in
t
size
,
const
int
axis_step
,
const
int
all_size_before_axis
,
template
void
SplitKernel
(
const
size_
t
size
,
const
int
axis_step
,
const
int
all_size_before_axis
,
const
int
all_size_axis
,
const
float
*
input
,
float
**
outputs
,
cudaStream_t
cuda_stream
);
template
void
SplitKernel
(
const
in
t
size
,
const
int
axis_step
,
const
int
all_size_before_axis
,
template
void
SplitKernel
(
const
size_
t
size
,
const
int
axis_step
,
const
int
all_size_before_axis
,
const
int
all_size_axis
,
const
int
*
input
,
int
**
outputs
,
cudaStream_t
cuda_stream
);
template
void
SplitKernel
(
const
in
t
size
,
const
int
axis_step
,
const
int
all_size_before_axis
,
template
void
SplitKernel
(
const
size_
t
size
,
const
int
axis_step
,
const
int
all_size_before_axis
,
const
int
all_size_axis
,
const
half
*
input
,
half
**
outputs
,
cudaStream_t
cuda_stream
);
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh
浏览文件 @
b750e3e1
...
...
@@ -19,6 +19,6 @@
#include "runtime/device/gpu/cuda_common.h"
template
<
typename
T
>
void
SplitKernel
(
const
in
t
size
,
const
int
axis_step
,
const
int
all_size_before_axis
,
void
SplitKernel
(
const
size_
t
size
,
const
int
axis_step
,
const
int
all_size_before_axis
,
const
int
all_size_axis
,
const
T
*
input
,
T
**
outputs
,
cudaStream_t
cuda_stream
);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录