Commit f6029709 in magicwindyyd / mindspore (forked from MindSpore / mindspore)

Authored on May 07, 2020 by mindspore-ci-bot; committed via Gitee on May 07, 2020.
Parents: 4e25fec7, cc936462

!323 Gpu Concat support 4 inputs

Merge pull request !323 from chenweifeng/concat

Showing 5 changed files with 197 additions and 54 deletions (+197, -54).
Changed files:
- mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc  (+5, -7)
- mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h   (+51, -35)
- mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cu     (+73, -9)
- mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cuh    (+9, -3)
- tests/st/ops/gpu/test_concatv2_op.py                      (+59, -0)
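For orientation, here is a minimal usage sketch of what the change enables: calling P.Concat with four GPU tensors. This sketch is not part of the commit; the class and variable names are illustrative, and the imports and the context.set_context call are assumed from the standard MindSpore API of that period (the ops usage itself mirrors the test file changed below).

# Minimal usage sketch (illustrative; not part of the commit). Assumes a GPU
# build of MindSpore from around this commit.
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(device_target="GPU")

class Concat4Net(nn.Cell):
    def __init__(self):
        super(Concat4Net, self).__init__()
        self.cat = P.Concat(axis=1)

    def construct(self, x1, x2, x3, x4):
        return self.cat((x1, x2, x3, x4))

net = Concat4Net()
inputs = [Tensor(np.ones((2, k, 3), np.float32)) for k in (1, 2, 3, 4)]
out = net(*inputs)
print(out.asnumpy().shape)  # (2, 10, 3), i.e. 1 + 2 + 3 + 4 along axis 1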
mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc

@@ -19,15 +19,13 @@
 namespace mindspore {
 namespace kernel {
 MS_REG_GPU_KERNEL_ONE(
-  Concat, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   ConcatV2GpuFwdKernel, float)
-MS_REG_GPU_KERNEL_ONE(
-  Concat, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  ConcatV2GpuFwdKernel, int)
+MS_REG_GPU_KERNEL_ONE(
+  Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  ConcatV2GpuFwdKernel, int)
 MS_REG_GPU_KERNEL_ONE(
-  Concat, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   ConcatV2GpuFwdKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h

@@ -27,7 +27,7 @@ namespace kernel {
 template <typename T>
 class ConcatV2GpuFwdKernel : public GpuKernel {
  public:
-  ConcatV2GpuFwdKernel() : axis_(0), input0_size_(0), input1_size_(0), output_size_(0), workspace_size_(0) {}
+  ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {}
   ~ConcatV2GpuFwdKernel() override = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }

@@ -35,12 +35,32 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
               const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
-    T *input_0 = GetDeviceAddress<T>(inputs, 0);
-    T *input_1 = GetDeviceAddress<T>(inputs, 1);
-    T *output = GetDeviceAddress<T>(outputs, 0);
-    CalConcatV2(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output,
-                reinterpret_cast<cudaStream_t>(stream_ptr));
+    if (inputs.size() == 2) {
+      T *input_0 = GetDeviceAddress<T>(inputs, 0);
+      T *input_1 = GetDeviceAddress<T>(inputs, 1);
+      T *output = GetDeviceAddress<T>(outputs, 0);
+      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
+    if (inputs.size() == 3) {
+      T *input_0 = GetDeviceAddress<T>(inputs, 0);
+      T *input_1 = GetDeviceAddress<T>(inputs, 1);
+      T *input_2 = GetDeviceAddress<T>(inputs, 2);
+      T *output = GetDeviceAddress<T>(outputs, 0);
+      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output,
+                   reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
+    if (inputs.size() == 4) {
+      T *input_0 = GetDeviceAddress<T>(inputs, 0);
+      T *input_1 = GetDeviceAddress<T>(inputs, 1);
+      T *input_2 = GetDeviceAddress<T>(inputs, 2);
+      T *input_3 = GetDeviceAddress<T>(inputs, 3);
+      T *output = GetDeviceAddress<T>(outputs, 0);
+      ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3,
+                   output, reinterpret_cast<cudaStream_t>(stream_ptr));
+    }
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {

@@ -48,44 +68,44 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
       return false;
     }
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    input0_size_ = sizeof(T);
-    for (size_t i = 0; i < input_shape.size(); i++) {
-      input0_size_ *= input_shape[i];
-    }
-    auto input_shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
-    input1_size_ = sizeof(T);
-    for (size_t i = 0; i < input_shape1.size(); i++) {
-      input1_size_ *= input_shape1[i];
-    }
-    output_size_ = input0_size_ + input1_size_;
     axis_ = GetAttr<int>(kernel_node, "axis");
     if (axis_ < 0) {
+      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
       axis_ += SizeToInt(input_shape.size());
     }
-    w_[0] = 1;
-    w_[1] = 1;
-    for (size_t i = IntToSize(axis_); i < input_shape.size(); i++) {
-      w_[0] *= SizeToInt(input_shape[i]);
-      w_[1] *= SizeToInt(input_shape1[i]);
+    auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    for (size_t i = 0; i < input_num; i++) {
+      auto input_size = sizeof(T);
+      auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
+      for (size_t j = 0; j < input_shape.size(); j++) {
+        input_size *= SizeToInt(input_shape[j]);
+        if (j >= IntToSize(axis_)) {
+          w_[i] *= SizeToInt(input_shape[j]);
+        }
+        input_size_list_.push_back(input_size);
+      }
     }
+    auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    output_size_ = sizeof(T);
+    for (size_t i = 0; i < output_shape.size(); i++) {
+      output_size_ *= output_shape[i];
+    }
+    output_size_list_.push_back(output_size_);
     InitSizeLists();
     return true;
   }

  protected:
-  void InitSizeLists() override {
-    input_size_list_.push_back(input0_size_);
-    input_size_list_.push_back(input1_size_);
-    output_size_list_.push_back(output_size_);
-  }
+  void InitSizeLists() override {}

  private:
   bool CheckParam(const CNodePtr &kernel_node) {
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 2) {
-      MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs 2 inputs.";
+    if (input_num < 2 || input_num > 4) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs inputs between 2 and 4.";
       return false;
     }
     size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);

@@ -95,16 +115,12 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
     }
     return true;
   }
-  int w_[2] = {1};
+  int w_[4] = {1, 1, 1, 1};
   int axis_;
+  size_t output_size_;
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
-  size_t input0_size_;
-  size_t input1_size_;
-  size_t output_size_;
-  size_t workspace_size_;
 };
 }  // namespace kernel
 }  // namespace mindspore
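To make the size bookkeeping in the new Init() concrete: for each input, input_size ends up as the tensor's byte size, and w_[i] ends up as the number of elements of input i from the concat axis onward. Below is a small illustrative Python sketch of that arithmetic; it is not part of the commit and the helper name is hypothetical.

# Illustrative sketch (not part of the commit) of the per-input sizes and
# widths computed by Init(): byte size of each tensor, and its element count
# from the concat axis onward.
from functools import reduce
from operator import mul

def concat_sizes(shapes, axis, elem_bytes):
    sizes = [elem_bytes * reduce(mul, shape, 1) for shape in shapes]
    widths = [reduce(mul, shape[axis:], 1) for shape in shapes]
    return sizes, widths

# For the 3-input system test below: three (32, C, 224, 224) float32 tensors.
sizes, widths = concat_sizes([(32, 4, 224, 224), (32, 8, 224, 224), (32, 10, 224, 224)],
                             axis=1, elem_bytes=4)
print(widths)  # [200704, 401408, 501760], i.e. C * 224 * 224 for C in (4, 8, 10)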
mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cu

@@ -19,7 +19,7 @@
 #include <cuda_runtime.h>
 #include "kernel/gpu/cuda_impl/concatv2_impl.cuh"
 template <typename T>
-__global__ void ConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) {
+__global__ void Concat(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
     int n = pos / (w1 + w2);
     int m = pos % (w1 + w2);

@@ -29,16 +29,80 @@ __global__ void ConcatV2(const size_t size, const int w1, const int w2, const T*
 }
 template <typename T>
+__global__ void Concat(const size_t size, const int w1, const int w2, const int w3,
+                       const T* input_1, const T* input_2, const T* input_3, T* output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    int n = pos / (w1 + w2 + w3);
+    int m = pos % (w1 + w2 + w3);
+    output[pos] = m < w1 ? input_1[n * w1 + m]
+                : m < w1 + w2 ? input_2[n * w2 + m - w1]
+                : input_3[n * w3 + m - w1 - w2];
+  }
+  return;
+}
+
+template <typename T>
+__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4,
+                       const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
+    int n = pos / (w1 + w2 + w3 + w4);
+    int m = pos % (w1 + w2 + w3 + w4);
+    output[pos] = m < w1 ? input_1[n * w1 + m]
+                : m < w1 + w2 ? input_2[n * w2 + m - w1]
+                : m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2]
+                : input_4[n * w4 + m - w1 - w2 - w3];
+  }
+  return;
+}
+
+template <typename T>
-void CalConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
+void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
                  cudaStream_t cuda_stream) {
-  ConcatV2<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output);
+  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output);
   return;
 }
+
+template <typename T>
+void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const T* input_1, const T* input_2,
+                  const T* input_3, T* output, cudaStream_t cuda_stream) {
+  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, input_1, input_2, input_3, output);
+  return;
+}
+
+template <typename T>
+void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, const T* input_1,
+                  const T* input_2, const T* input_3, const T* input_4, T* output, cudaStream_t cuda_stream) {
+  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, w4, input_1, input_2, input_3,
+                                                            input_4, output);
+  return;
+}

-template void CalConcatV2(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2,
-                          float* output, cudaStream_t cuda_stream);
-template void CalConcatV2(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2,
-                          int* output, cudaStream_t cuda_stream);
-template void CalConcatV2(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2,
-                          half* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const float* input_1,
+                           const float* input_2, float* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const int* input_1,
+                           const int* input_2, int* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const half* input_1,
+                           const half* input_2, half* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const float* input_1,
+                           const float* input_2, const float* input_3, float* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int* input_1,
+                           const int* input_2, const int* input_3, int* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const half* input_1,
+                           const half* input_2, const half* input_3, half* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
+                           const float* input_1, const float* input_2, const float* input_3, const float* input_4,
+                           float* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
+                           const int* input_1, const int* input_2, const int* input_3, const int* input_4,
+                           int* output, cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
+                           const half* input_1, const half* input_2, const half* input_3, const half* input_4,
+                           half* output, cudaStream_t cuda_stream);
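All of the kernels above share one indexing scheme: each flat output position pos is split into a row index n = pos / (w1 + ... + wk) and an offset m = pos % (w1 + ... + wk), where w_k is the element count of input k from the concat axis onward, and m decides which input supplies the element. Here is a NumPy reference of the 3-input case (illustrative only, not part of the commit; the function name is hypothetical).

# NumPy reference (illustrative, not part of the commit) of the 3-input
# Concat kernel's indexing; the result equals np.concatenate(...).ravel().
import numpy as np

def concat3_reference(x1, x2, x3, w1, w2, w3):
    f1, f2, f3 = x1.ravel(), x2.ravel(), x3.ravel()
    out = np.empty(f1.size + f2.size + f3.size, dtype=x1.dtype)
    for pos in range(out.size):
        n, m = divmod(pos, w1 + w2 + w3)
        if m < w1:
            out[pos] = f1[n * w1 + m]
        elif m < w1 + w2:
            out[pos] = f2[n * w2 + m - w1]
        else:
            out[pos] = f3[n * w3 + m - w1 - w2]
    return out

# Concatenating (2, 3, 4), (2, 5, 4) and (2, 2, 4) tensors along axis 1
# gives w1, w2, w3 = 12, 20, 8.
a, b, c = (np.random.rand(2, k, 4).astype(np.float32) for k in (3, 5, 2))
ref = concat3_reference(a, b, c, 12, 20, 8)
assert (ref == np.concatenate((a, b, c), axis=1).ravel()).all()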
mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cuh

@@ -19,7 +19,13 @@
 #include "device/gpu/cuda_common.h"
 template <typename T>
-void CalConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
-                 cudaStream_t cuda_stream);
+void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
+                  cudaStream_t cuda_stream);
+template <typename T>
+void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const T* input_1, const T* input_2,
+                  const T* input_3, T* output, cudaStream_t cuda_stream);
+template <typename T>
+void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, const T* input_1,
+                  const T* input_2, const T* input_3, const T* input_4, T* output, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
tests/st/ops/gpu/test_concatv2_op.py

@@ -113,3 +113,62 @@ def test_axis21():
               [2., 3., 3., 4., 5.]]
     assert (output.asnumpy() == expect).all()
     print(output)
+
+
+class Concat3INet(nn.Cell):
+    def __init__(self):
+        super(Concat3INet, self).__init__()
+        self.cat = P.Concat(axis=1)
+
+    def construct(self, x1, x2, x3):
+        return self.cat((x1, x2, x3))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_concat_3i():
+    cat = Concat3INet()
+
+    x1_np = np.random.randn(32, 4, 224, 224).astype(np.float32)
+    x2_np = np.random.randn(32, 8, 224, 224).astype(np.float32)
+    x3_np = np.random.randn(32, 10, 224, 224).astype(np.float32)
+    output_np = np.concatenate((x1_np, x2_np, x3_np), axis=1)
+
+    x1_ms = Tensor(x1_np)
+    x2_ms = Tensor(x2_np)
+    x3_ms = Tensor(x3_np)
+    output_ms = cat(x1_ms, x2_ms, x3_ms)
+
+    error = np.ones(shape=output_np.shape) * 10e-6
+    diff = output_ms.asnumpy() - output_np
+    assert np.all(diff < error)
+
+
+class Concat4INet(nn.Cell):
+    def __init__(self):
+        super(Concat4INet, self).__init__()
+        self.cat = P.Concat(axis=1)
+
+    def construct(self, x1, x2, x3, x4):
+        return self.cat((x1, x2, x3, x4))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_concat_4i():
+    cat = Concat4INet()
+
+    x1_np = np.random.randn(32, 4, 224, 224).astype(np.float32)
+    x2_np = np.random.randn(32, 8, 224, 224).astype(np.float32)
+    x3_np = np.random.randn(32, 10, 224, 224).astype(np.float32)
+    x4_np = np.random.randn(32, 5, 224, 224).astype(np.float32)
+    output_np = np.concatenate((x1_np, x2_np, x3_np, x4_np), axis=1)
+
+    x1_ms = Tensor(x1_np)
+    x2_ms = Tensor(x2_np)
+    x3_ms = Tensor(x3_np)
+    x4_ms = Tensor(x4_np)
+    output_ms = cat(x1_ms, x2_ms, x3_ms, x4_ms)
+
+    error = np.ones(shape=output_np.shape) * 10e-6
+    diff = output_ms.asnumpy() - output_np
+    assert np.all(diff < error)
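The two new system-test cases exercise the 3- and 4-input kernel paths added above. Assuming a GPU machine with this MindSpore build and pytest installed, they could be run selectively with something like:

pytest -sv tests/st/ops/gpu/test_concatv2_op.py -k "concat_3i or concat_4i"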