Commit 5c0962ac in magicwindyyd / mindspore (forked from MindSpore / mindspore)
Authored on July 15, 2020 by zhaoting
Parent: 45ad430a

add gpu split and restructure gpu concat
Showing 8 changed files with 406 additions and 130 deletions (+406, -130).
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h (+48, -44)
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc (+31, -0)
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h (+153, -0)
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu (+39, -78)
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh (+3, -8)
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu (+50, -0)
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh (+24, -0)
tests/st/ops/gpu/test_split.py (+58, -0)
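The restructure replaces the fixed two-, three- and four-input Concat code paths with a single kernel that receives an array of device input pointers plus each input's length along the concat axis, so the GPU ConcatV2 kernel is no longer limited to four inputs; the new Split kernel uses the same indexing scheme in reverse. Below is a minimal usage sketch of what this enables from Python; the shapes are hypothetical and the snippet is illustrative rather than part of the commit (it assumes a GPU build of MindSpore with these kernels registered):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class ConcatNet(nn.Cell):
    def __init__(self, axis=0):
        super(ConcatNet, self).__init__()
        self.concat = P.Concat(axis)

    def construct(self, x1, x2, x3, x4, x5):
        return self.concat((x1, x2, x3, x4, x5))

# Five inputs along axis 1, one more than the old fixed-arity kernels supported.
xs = [np.random.randn(2, i + 1, 3).astype(np.float32) for i in range(5)]
out = ConcatNet(axis=1)(*[Tensor(x) for x in xs])
assert (out.asnumpy() == np.concatenate(xs, axis=1)).all()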
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h
@@ -18,6 +18,7 @@
 #define MINDSPORE_CCSRC_KERNEL_GPU_CONCATV2_GPU_KERNEL_H

 #include <vector>
+#include <memory>
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
 #include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
@@ -27,40 +28,35 @@ namespace kernel {
 template <typename T>
 class ConcatV2GpuFwdKernel : public GpuKernel {
  public:
-  ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {}
+  ConcatV2GpuFwdKernel()
+      : axis_(0), input_num_(1), output_size_(0), all_size_before_axis_(1), all_size_axis_(1),
+        inputs_host_(nullptr), len_axis_(nullptr) {}
   ~ConcatV2GpuFwdKernel() override = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    if (inputs.size() == 2) {
-      T *input_0 = GetDeviceAddress<T>(inputs, 0);
-      T *input_1 = GetDeviceAddress<T>(inputs, 1);
-      T *output = GetDeviceAddress<T>(outputs, 0);
-      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output,
-                   reinterpret_cast<cudaStream_t>(stream_ptr));
-    }
-    if (inputs.size() == 3) {
-      T *input_0 = GetDeviceAddress<T>(inputs, 0);
-      T *input_1 = GetDeviceAddress<T>(inputs, 1);
-      T *input_2 = GetDeviceAddress<T>(inputs, 2);
-      T *output = GetDeviceAddress<T>(outputs, 0);
-      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output,
-                   reinterpret_cast<cudaStream_t>(stream_ptr));
-    }
-    if (inputs.size() == 4) {
-      T *input_0 = GetDeviceAddress<T>(inputs, 0);
-      T *input_1 = GetDeviceAddress<T>(inputs, 1);
-      T *input_2 = GetDeviceAddress<T>(inputs, 2);
-      T *input_3 = GetDeviceAddress<T>(inputs, 3);
-      T *output = GetDeviceAddress<T>(outputs, 0);
-      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3,
-                   output, reinterpret_cast<cudaStream_t>(stream_ptr));
-    }
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
+    T *output = GetDeviceAddress<T>(outputs, 0);
+    T **inputs_device = GetDeviceAddress<T *>(workspace, 0);
+    int *len_axis_device = GetDeviceAddress<int>(workspace, 1);
+    for (size_t i = 0; i < inputs.size(); i++) {
+      inputs_host_[i] = GetDeviceAddress<T>(inputs, i);
+    }
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(inputs_device, inputs_host_.get(), sizeof(T *) * input_num_,
+                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "ConcatV2 opt cudaMemcpyAsync inputs failed");
+    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(len_axis_device, len_axis_.get(), sizeof(int) * input_num_,
+                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+                               "ConcatV2 opt cudaMemcpyAsync length on axis failed");
+    ConcatKernel(output_size_, input_num_, all_size_before_axis_, all_size_axis_, len_axis_device, inputs_device,
+                 output, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
@@ -74,25 +70,34 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
       axis_ += SizeToInt(input_shape.size());
     }
-    auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    for (size_t i = 0; i < input_num; i++) {
-      auto input_size = sizeof(T);
+    input_num_ = SizeToInt(AnfAlgo::GetInputTensorNum(kernel_node));
+    inputs_host_ = std::make_unique<T *[]>(input_num_);
+    len_axis_ = std::make_unique<int[]>(input_num_);
+    for (int i = 0; i < input_num_; i++) {
+      int input_size = 1;
       auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
       for (size_t j = 0; j < input_shape.size(); j++) {
         input_size *= SizeToInt(input_shape[j]);
-        if (j >= IntToSize(axis_)) {
-          w_[i] *= SizeToInt(input_shape[j]);
-        }
       }
-      input_size_list_.push_back(input_size);
+      input_size_list_.push_back(IntToSize(input_size * sizeof(T)));
+      len_axis_[i] = SizeToInt(input_shape[axis_]);
     }
+    workspace_size_list_.push_back(sizeof(T *) * input_num_);
+    workspace_size_list_.push_back(sizeof(int) * input_num_);

     auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-    output_size_ = sizeof(T);
-    for (size_t i = 0; i < output_shape.size(); i++) {
+    output_size_ = 1;
+    for (int i = 0; i < SizeToInt(output_shape.size()); i++) {
       output_size_ *= output_shape[i];
+      if (i > axis_) {
+        all_size_before_axis_ *= output_shape[i];
+        all_size_axis_ *= output_shape[i];
+      }
+      if (i == axis_) {
+        all_size_before_axis_ *= output_shape[i];
+      }
     }
-    output_size_list_.push_back(output_size_);
+    output_size_list_.push_back(IntToSize(output_size_ * sizeof(T)));

     InitSizeLists();
     return true;
@@ -103,11 +108,6 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
  private:
   bool CheckParam(const CNodePtr &kernel_node) {
-    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num < 2 || input_num > 4) {
-      MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs inputs between 2 and 4.";
-      return false;
-    }
     size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
     if (output_num != 1) {
       MS_LOG(ERROR) << "Output number is " << output_num << ", but ConcatV2GpuFwdKernel needs 1 output.";
@@ -115,9 +115,13 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
     }
     return true;
   }
-  int w_[4] = {1, 1, 1, 1};
   int axis_;
-  size_t output_size_;
+  int input_num_;
+  int output_size_;
+  int all_size_before_axis_;
+  int all_size_axis_;
+  std::unique_ptr<T *[]> inputs_host_;
+  std::unique_ptr<int[]> len_axis_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
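To make the bookkeeping in Init() concrete, the sketch below works through a hypothetical concat of three inputs along axis 1. The names mirror the member variables, and the values shown are what would be staged into the two workspaces and passed to ConcatKernel; it is illustrative only, and the pointer size is assumed to be 8 bytes:

import numpy as np

shapes = [(2, 1, 3), (2, 4, 3), (2, 2, 3)]                # hypothetical input shapes
axis = 1
len_axis = [s[axis] for s in shapes]                      # [1, 4, 2], copied to len_axis_device
out_shape = (2, sum(len_axis), 3)                         # (2, 7, 3)
output_size = int(np.prod(out_shape))                     # 42 output elements
all_size_axis = int(np.prod(out_shape[axis + 1:]))        # 3, elements to the right of the axis
all_size_before_axis = out_shape[axis] * all_size_axis    # 21, elements from the axis onward
workspace_bytes = (8 * len(shapes), 4 * len(shapes))      # sizeof(T*) * input_num, sizeof(int) * input_num
print(len_axis, output_size, all_size_axis, all_size_before_axis, workspace_bytes)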
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.cc
new file mode 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
  Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  SplitGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(
  Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  SplitGpuFwdKernel, int)
MS_REG_GPU_KERNEL_ONE(
  Split, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  SplitGpuFwdKernel, half)
}  // namespace kernel
}  // namespace mindspore
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h
new file mode 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
#define MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
#include <vector>
#include <memory>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class SplitGpuFwdKernel : public GpuKernel {
 public:
  SplitGpuFwdKernel()
      : axis_(0), output_num_(1), input_size_(1), axis_step_(1), all_size_before_axis_(1), all_size_axis_(1),
        outputs_host_(nullptr) {}
  ~SplitGpuFwdKernel() override = default;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *input = GetDeviceAddress<T>(inputs, 0);
    T **outputs_device = GetDeviceAddress<T *>(workspace, 0);
    for (size_t i = 0; i < outputs.size(); i++) {
      outputs_host_[i] = GetDeviceAddress<T>(outputs, i);
    }
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(outputs_device, outputs_host_.get(), sizeof(T *) * output_num_,
                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "Split opt cudaMemcpyAsync outputs failed");
    SplitKernel(input_size_, axis_step_, all_size_before_axis_, all_size_axis_, input, outputs_device,
                reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    axis_ = GetAttr<int>(kernel_node, "axis");
    if (axis_ < 0) {
      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
      axis_ += SizeToInt(input_shape.size());
    }
    output_num_ = GetAttr<int>(kernel_node, "output_num");
    if (!CheckParam(kernel_node)) {
      return false;
    }
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    input_size_ = 1;
    all_size_before_axis_ = 1;
    all_size_axis_ = 1;
    for (int i = 0; i < SizeToInt(input_shape.size()); i++) {
      input_size_ *= input_shape[i];
      if (i > axis_) {
        all_size_before_axis_ *= input_shape[i];
        all_size_axis_ *= input_shape[i];
      }
      if (i == axis_) {
        all_size_before_axis_ *= input_shape[i];
      }
    }
    input_size_list_.push_back(IntToSize(input_size_ * sizeof(T)));
    axis_step_ = input_shape[axis_] / output_num_;
    for (int i = 0; i < output_num_; i++) {
      size_t output_size = 1;
      auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, i);
      for (size_t j = 0; j < output_shape.size(); j++) {
        output_size *= output_shape[j];
      }
      output_size_list_.push_back(output_size * sizeof(T));
    }
    workspace_size_list_.push_back(sizeof(T *) * output_num_);
    InitSizeLists();
    outputs_host_ = std::make_unique<T *[]>(output_num_);
    return true;
  }

 protected:
  void InitSizeLists() override {}

 private:
  bool CheckParam(const CNodePtr &kernel_node) {
    auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    int dims = SizeToInt(input_shape.size());
    int output_num = SizeToInt(AnfAlgo::GetOutputTensorNum(kernel_node));
    if (input_num != 1) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but Split needs 1 input.";
      return false;
    }
    if (dims == 0) {
      MS_LOG(ERROR) << "Input dims is " << dims << ", scalar is not supported.";
      return false;
    }
    if (axis_ < -dims || axis_ >= dims) {
      MS_LOG(ERROR) << "Attr axis " << axis_ << " must be in " << -dims << "~" << dims;
      return false;
    }
    if (output_num_ > SizeToInt(input_shape[axis_])) {
      MS_LOG(ERROR) << "Attr output_num " << output_num_ << " must be less than " << input_shape[axis_];
      return false;
    }
    if (input_shape[axis_] % output_num_ != 0) {
      MS_LOG(ERROR) << "Attr output_num " << output_num_ << " must evenly divide " << input_shape[axis_];
      return false;
    }
    if (output_num_ != output_num) {
      MS_LOG(ERROR) << "Output num is " << output_num << ", but need " << output_num_;
      return false;
    }
    return true;
  }

  int axis_;
  int output_num_;
  int input_size_;
  int axis_step_;
  int all_size_before_axis_;
  int all_size_axis_;
  std::unique_ptr<T *[]> outputs_host_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_SPLIT_GPU_KERNEL_H
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu
@@ -19,90 +19,51 @@
 #include <cuda_runtime.h>
 #include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
 template <typename T>
-__global__ void Concat(const size_t size, const int w1, const int w2, const T *input_1, const T *input_2,
-                       T *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    int n = pos / (w1 + w2);
-    int m = pos % (w1 + w2);
-    output[pos] = m >= w1 ? input_2[n * w2 + m - w1] : input_1[n * w1 + m];
-  }
-  return;
-}
-
-template <typename T>
-__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const T *input_1,
-                       const T *input_2, const T *input_3, T *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    int n = pos / (w1 + w2 + w3);
-    int m = pos % (w1 + w2 + w3);
-    output[pos] = m < w1 ? input_1[n * w1 + m]
-                         : m < w1 + w2 ? input_2[n * w2 + m - w1] : input_3[n * w3 + m - w1 - w2];
-  }
-  return;
-}
-
-template <typename T>
-__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                       const T *input_1, const T *input_2, const T *input_3, const T *input_4, T *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
-    int n = pos / (w1 + w2 + w3 + w4);
-    int m = pos % (w1 + w2 + w3 + w4);
-    output[pos] = m < w1 ? input_1[n * w1 + m]
-                         : m < w1 + w2 ? input_2[n * w2 + m - w1]
-                                       : m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2]
-                                                          : input_4[n * w4 + m - w1 - w2 - w3];
+__global__ void Concat(const int size, const int input_num, const int all_size_before_axis,
+                       const int all_size_axis, int *len_axis, T **inputs, T *output) {
+  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    int num = pos % all_size_before_axis / all_size_axis;
+    int block = -1;
+    int axis_inc = 0;
+    int block_len = 0;
+    for (int i = 0; i < input_num; i++) {
+      if (axis_inc <= num) {
+        block++;
+        axis_inc += len_axis[i];
+      } else {
+        break;
+      }
+    }
+    block_len = len_axis[block];
+    axis_inc -= len_axis[block];
+    int block_pos = pos / all_size_before_axis * block_len * all_size_axis +
+                    (num - axis_inc) * all_size_axis + pos % all_size_axis;
+    output[pos] = inputs[block][block_pos];
   }
   return;
 }
 
 template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const T *input_1, const T *input_2, T *output,
-                  cudaStream_t cuda_stream) {
-  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output);
-  return;
-}
-
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const T *input_1, const T *input_2,
-                  const T *input_3, T *output, cudaStream_t cuda_stream) {
-  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, input_1, input_2, input_3, output);
-  return;
-}
-
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, const T *input_1,
-                  const T *input_2, const T *input_3, const T *input_4, T *output, cudaStream_t cuda_stream) {
-  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, w4, input_1, input_2, input_3,
-                                                            input_4, output);
+void ConcatKernel(const int size, const int input_num, const int all_size_before_axis, const int all_size_axis,
+                  int *len_axis, T **inputs, T *output, cudaStream_t cuda_stream) {
+  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num, all_size_before_axis, all_size_axis,
+                                                            len_axis, inputs, output);
   return;
 }
 
-template void ConcatKernel(const size_t size, const int w1, const int w2, const float *input_1,
-                           const float *input_2, float *output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int *input_1, const int *input_2,
-                           int *output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const half *input_1, const half *input_2,
-                           half *output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const float *input_1,
-                           const float *input_2, const float *input_3, float *output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int *input_1,
-                           const int *input_2, const int *input_3, int *output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const half *input_1,
-                           const half *input_2, const half *input_3, half *output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                           const float *input_1, const float *input_2, const float *input_3, const float *input_4,
-                           float *output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                           const int *input_1, const int *input_2, const int *input_3, const int *input_4,
-                           int *output, cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
-                           const half *input_1, const half *input_2, const half *input_3, const half *input_4,
-                           half *output, cudaStream_t cuda_stream);
+template void ConcatKernel(const int size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, float **inputs, float *output,
+                           cudaStream_t cuda_stream);
+template void ConcatKernel(const int size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, int **inputs, int *output,
+                           cudaStream_t cuda_stream);
+template void ConcatKernel(const int size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, half **inputs, half *output,
+                           cudaStream_t cuda_stream);
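The new Concat kernel maps each flat output index back to an (input, offset) pair using only all_size_before_axis, all_size_axis and the per-input axis lengths. The NumPy sketch below emulates that arithmetic on the host and checks it against np.concatenate; it is illustrative only, with names chosen to mirror the kernel parameters:

import numpy as np

def emulate_concat(shapes, axis):
    # Inputs agree on every dimension except `axis`.
    inputs = [np.random.randn(*s).astype(np.float32) for s in shapes]
    len_axis = [s[axis] for s in shapes]
    out_shape = list(shapes[0])
    out_shape[axis] = sum(len_axis)
    all_size_axis = int(np.prod(out_shape[axis + 1:]))       # elements to the right of the axis
    all_size_before_axis = out_shape[axis] * all_size_axis   # elements from the axis onward
    flat_inputs = [x.reshape(-1) for x in inputs]
    out = np.empty(int(np.prod(out_shape)), dtype=np.float32)
    for pos in range(out.size):
        num = pos % all_size_before_axis // all_size_axis    # position along the concat axis
        block, axis_inc = -1, 0                              # find which input owns `num`
        for length in len_axis:
            if axis_inc <= num:
                block += 1
                axis_inc += length
            else:
                break
        axis_inc -= len_axis[block]
        block_len = len_axis[block]
        block_pos = (pos // all_size_before_axis * block_len * all_size_axis
                     + (num - axis_inc) * all_size_axis
                     + pos % all_size_axis)
        out[pos] = flat_inputs[block][block_pos]
    return out.reshape(out_shape), np.concatenate(inputs, axis=axis)

emulated, reference = emulate_concat([(2, 1, 3), (2, 4, 3), (2, 2, 3)], axis=1)
assert (emulated == reference).all()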
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh
@@ -19,13 +19,8 @@
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const T *input_1, const T *input_2, T *output,
-                  cudaStream_t cuda_stream);
-
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const T *input_1, const T *input_2,
-                  const T *input_3, T *output, cudaStream_t cuda_stream);
-
-template <typename T>
-void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, const T *input_1,
-                  const T *input_2, const T *input_3, const T *input_4, T *output, cudaStream_t cuda_stream);
+void ConcatKernel(const int size, const int input_num, const int all_size_before_axis, const int all_size_axis,
+                  int *len_axis, T **inputs, T *output, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cu
new file mode 100755
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdio.h>
#include <stdint.h>
#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh"
template <typename T>
__global__ void Split(const int size, const int axis_step, const int all_size_before_axis,
                      const int all_size_axis, const T *input, T **outputs) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    int num = pos % all_size_before_axis / all_size_axis;
    int block = num / axis_step;
    int block_pos = pos / all_size_before_axis * axis_step * all_size_axis +
                    num % axis_step * all_size_axis + pos % all_size_axis;
    outputs[block][block_pos] = input[pos];
  }
  return;
}

template <typename T>
void SplitKernel(const int size, const int axis_step, const int all_size_before_axis, const int all_size_axis,
                 const T *input, T **outputs, cudaStream_t cuda_stream) {
  Split<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, axis_step, all_size_before_axis, all_size_axis,
                                                           input, outputs);
  return;
}

template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
                          const int all_size_axis, const float *input, float **outputs, cudaStream_t cuda_stream);
template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
                          const int all_size_axis, const int *input, int **outputs, cudaStream_t cuda_stream);
template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
                          const int all_size_axis, const half *input, half **outputs, cudaStream_t cuda_stream);
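Split applies the inverse mapping: every flat input index is scattered into one of output_num equally sized blocks of width axis_step along the split axis. The NumPy sketch below mirrors the kernel's index arithmetic and checks it against np.split (illustrative only):

import numpy as np

def emulate_split(shape, axis, output_num):
    x = np.random.randn(*shape).astype(np.float32)
    all_size_axis = int(np.prod(shape[axis + 1:]))        # elements to the right of the axis
    all_size_before_axis = shape[axis] * all_size_axis    # elements from the axis onward
    axis_step = shape[axis] // output_num                 # slice width per output along the axis
    out_shape = list(shape)
    out_shape[axis] = axis_step
    outputs = [np.empty(int(np.prod(out_shape)), dtype=np.float32) for _ in range(output_num)]
    flat = x.reshape(-1)
    for pos in range(flat.size):
        num = pos % all_size_before_axis // all_size_axis  # position along the split axis
        block = num // axis_step                           # which output receives this element
        block_pos = (pos // all_size_before_axis * axis_step * all_size_axis
                     + (num % axis_step) * all_size_axis
                     + pos % all_size_axis)
        outputs[block][block_pos] = flat[pos]
    for got, want in zip(outputs, np.split(x, output_num, axis=axis)):
        assert (got.reshape(out_shape) == want).all()

emulate_split((2, 6, 4), axis=1, output_num=3)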
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh
new file mode 100755
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void SplitKernel(const int size, const int axis_step, const int all_size_before_axis, const int all_size_axis,
                 const T *input, T **outputs, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_
tests/st/ops/gpu/test_split.py
new file mode 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.ops import operations as P


class Net(nn.Cell):
    def __init__(self, axis=0, out_nums=1):
        super(Net, self).__init__()
        self.split = P.Split(axis, out_nums)

    def construct(self, x):
        return self.split(x)


context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_split():
    x = np.array([[[1, -1, 1], [2, -2, 2]],
                  [[3, -3, 3], [4, -4, 4]],
                  [[5, -5, 5], [6, -6, 6]]]).astype(np.float32)
    split_op = Net(0, 3)
    outputs = split_op(Tensor(x))
    for i, out in enumerate(outputs):
        assert (out.asnumpy() == x[i]).all()


def test_split_4d():
    x_np = np.random.randn(2, 6, 4, 4).astype(np.float32)
    y = np.split(x_np, 3, axis=1)
    split_op = Net(1, 3)
    outputs = split_op(Tensor(x_np))
    for i, out in enumerate(outputs):
        assert (out.asnumpy() == y[i]).all()