Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
d3936b9f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d3936b9f
编写于
5月 08, 2020
作者:
Z
ZPaC
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
GPU kernels adapt with special dimensions.
上级
6c79c00a
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
122 addition
and
56 deletion
+122
-56
mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.h
+13
-13
mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h
+8
-0
mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h
+1
-1
mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h
+12
-1
tests/st/ops/gpu/test_reduce_max_op.py
tests/st/ops/gpu/test_reduce_max_op.py
+24
-9
tests/st/ops/gpu/test_reduce_mean_op.py
tests/st/ops/gpu/test_reduce_mean_op.py
+33
-17
tests/st/ops/gpu/test_reduce_sum_op.py
tests/st/ops/gpu/test_reduce_sum_op.py
+31
-15
未找到文件。
mindspore/ccsrc/kernel/gpu/arrays/array_reduce_gpu_kernel.h
浏览文件 @
d3936b9f
...
@@ -43,7 +43,7 @@ class ArrayReduceGpuKernel : public GpuKernel {
...
@@ -43,7 +43,7 @@ class ArrayReduceGpuKernel : public GpuKernel {
inputA_descriptor_
(
nullptr
),
inputA_descriptor_
(
nullptr
),
outputC_descriptor_
(
nullptr
),
outputC_descriptor_
(
nullptr
),
keep_dims_
(
false
),
keep_dims_
(
false
),
is_reduce_dim_one_
(
tru
e
),
all_match_
(
fals
e
),
is_null_input_
(
false
),
is_null_input_
(
false
),
input_size_
(
0
),
input_size_
(
0
),
output_size_
(
0
),
output_size_
(
0
),
...
@@ -65,7 +65,9 @@ class ArrayReduceGpuKernel : public GpuKernel {
...
@@ -65,7 +65,9 @@ class ArrayReduceGpuKernel : public GpuKernel {
const
float
alpha
=
1
;
const
float
alpha
=
1
;
const
float
beta
=
0
;
const
float
beta
=
0
;
if
(
is_reduce_dim_one_
)
{
if
(
all_match_
)
{
MS_LOG
(
WARNING
)
<<
"The corresponding dimensions of the input and output tensors all match. No need to call cuDNN kernel."
;
CHECK_CUDA_RET_WITH_EXCEPT
(
cudaMemcpyAsync
(
output_addr
,
input_addr
,
inputs
[
0
]
->
size
,
cudaMemcpyDeviceToDevice
,
CHECK_CUDA_RET_WITH_EXCEPT
(
cudaMemcpyAsync
(
output_addr
,
input_addr
,
inputs
[
0
]
->
size
,
cudaMemcpyDeviceToDevice
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
)),
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
)),
"cudaMemcpyAsync failed in ArrayReduceGpuKernel::Launch."
);
"cudaMemcpyAsync failed in ArrayReduceGpuKernel::Launch."
);
...
@@ -178,6 +180,7 @@ class ArrayReduceGpuKernel : public GpuKernel {
...
@@ -178,6 +180,7 @@ class ArrayReduceGpuKernel : public GpuKernel {
void
InferInAndOutDesc
(
const
std
::
vector
<
size_t
>
&
input_shape
,
const
std
::
vector
<
size_t
>
&
output_shape
)
{
void
InferInAndOutDesc
(
const
std
::
vector
<
size_t
>
&
input_shape
,
const
std
::
vector
<
size_t
>
&
output_shape
)
{
std
::
vector
<
size_t
>
inputA_shape
=
input_shape
;
std
::
vector
<
size_t
>
inputA_shape
=
input_shape
;
std
::
vector
<
size_t
>
outputC_shape
=
output_shape
;
std
::
vector
<
size_t
>
outputC_shape
=
output_shape
;
std
::
vector
<
int
>
real_input_shape
;
int
shapeA_n
,
shapeA_c
,
shapeA_h
,
shapeA_w
;
int
shapeA_n
,
shapeA_c
,
shapeA_h
,
shapeA_w
;
shapeA_n
=
inputA_shape
.
size
()
<
4
?
1
:
SizeToInt
(
inputA_shape
[
inputA_shape
.
size
()
-
4
]);
shapeA_n
=
inputA_shape
.
size
()
<
4
?
1
:
SizeToInt
(
inputA_shape
[
inputA_shape
.
size
()
-
4
]);
shapeA_c
=
inputA_shape
.
size
()
<
3
?
1
:
SizeToInt
(
inputA_shape
[
inputA_shape
.
size
()
-
3
]);
shapeA_c
=
inputA_shape
.
size
()
<
3
?
1
:
SizeToInt
(
inputA_shape
[
inputA_shape
.
size
()
-
3
]);
...
@@ -196,7 +199,9 @@ class ArrayReduceGpuKernel : public GpuKernel {
...
@@ -196,7 +199,9 @@ class ArrayReduceGpuKernel : public GpuKernel {
CHECK_CUDNN_RET_WITH_EXCEPT
(
cudnnSetTensor4dDescriptor
(
outputC_descriptor_
,
CUDNN_TENSOR_NCHW
,
data_type_
,
CHECK_CUDNN_RET_WITH_EXCEPT
(
cudnnSetTensor4dDescriptor
(
outputC_descriptor_
,
CUDNN_TENSOR_NCHW
,
data_type_
,
shapeC_n
,
shapeC_c
,
shapeC_h
,
shapeC_w
),
shapeC_n
,
shapeC_c
,
shapeC_h
,
shapeC_w
),
"cudnnSetTensor4dDescriptor failed"
);
"cudnnSetTensor4dDescriptor failed"
);
is_reduce_dim_one_
=
false
;
if
(
shapeA_n
==
shapeC_n
&&
shapeA_c
==
shapeC_c
&&
shapeA_h
==
shapeC_h
&&
shapeA_w
==
shapeC_w
)
{
all_match_
=
true
;
}
return
;
return
;
}
}
...
@@ -205,21 +210,16 @@ class ArrayReduceGpuKernel : public GpuKernel {
...
@@ -205,21 +210,16 @@ class ArrayReduceGpuKernel : public GpuKernel {
(
void
)(
outputC_shape
.
insert
(
outputC_shape
.
begin
()
+
i
,
1
));
(
void
)(
outputC_shape
.
insert
(
outputC_shape
.
begin
()
+
i
,
1
));
}
}
}
}
for
(
auto
i
:
axis_
)
{
if
(
inputA_shape
[
IntToSize
(
i
)]
!=
1
)
{
// To avoid cudnnReduceTensor bug when the dimension which needs to be
// reduced is already 1.
is_reduce_dim_one_
=
false
;
}
}
shapeC_n
=
outputC_shape
.
size
()
<
4
?
1
:
SizeToInt
(
outputC_shape
[
outputC_shape
.
size
()
-
4
]);
shapeC_n
=
outputC_shape
.
size
()
<
4
?
1
:
SizeToInt
(
outputC_shape
[
outputC_shape
.
size
()
-
4
]);
shapeC_c
=
outputC_shape
.
size
()
<
3
?
1
:
SizeToInt
(
outputC_shape
[
outputC_shape
.
size
()
-
3
]);
shapeC_c
=
outputC_shape
.
size
()
<
3
?
1
:
SizeToInt
(
outputC_shape
[
outputC_shape
.
size
()
-
3
]);
shapeC_h
=
outputC_shape
.
size
()
<
2
?
1
:
SizeToInt
(
outputC_shape
[
outputC_shape
.
size
()
-
2
]);
shapeC_h
=
outputC_shape
.
size
()
<
2
?
1
:
SizeToInt
(
outputC_shape
[
outputC_shape
.
size
()
-
2
]);
shapeC_w
=
SizeToInt
(
outputC_shape
[
outputC_shape
.
size
()
-
1
]);
shapeC_w
=
outputC_shape
.
size
()
==
0
?
1
:
SizeToInt
(
outputC_shape
[
outputC_shape
.
size
()
-
1
]);
CHECK_CUDNN_RET_WITH_EXCEPT
(
cudnnSetTensor4dDescriptor
(
outputC_descriptor_
,
CUDNN_TENSOR_NCHW
,
data_type_
,
shapeC_n
,
CHECK_CUDNN_RET_WITH_EXCEPT
(
cudnnSetTensor4dDescriptor
(
outputC_descriptor_
,
CUDNN_TENSOR_NCHW
,
data_type_
,
shapeC_n
,
shapeC_c
,
shapeC_h
,
shapeC_w
),
shapeC_c
,
shapeC_h
,
shapeC_w
),
"cudnnSetTensor4dDescriptor failed"
);
"cudnnSetTensor4dDescriptor failed"
);
if
(
shapeA_n
==
shapeC_n
&&
shapeA_c
==
shapeC_c
&&
shapeA_h
==
shapeC_h
&&
shapeA_w
==
shapeC_w
)
{
all_match_
=
true
;
}
return
;
return
;
}
}
...
@@ -234,7 +234,7 @@ class ArrayReduceGpuKernel : public GpuKernel {
...
@@ -234,7 +234,7 @@ class ArrayReduceGpuKernel : public GpuKernel {
std
::
vector
<
int
>
axis_
;
std
::
vector
<
int
>
axis_
;
bool
keep_dims_
;
bool
keep_dims_
;
bool
is_reduce_dim_one
_
;
bool
all_match
_
;
bool
is_null_input_
;
bool
is_null_input_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
...
...
mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h
浏览文件 @
d3936b9f
...
@@ -114,6 +114,14 @@ class BinaryOpGpuKernel : public GpuKernel {
...
@@ -114,6 +114,14 @@ class BinaryOpGpuKernel : public GpuKernel {
InferBinaryType
(
kernel_node
);
InferBinaryType
(
kernel_node
);
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
auto
input_shapeB
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
1
);
auto
input_shapeB
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
1
);
auto
output_shape
=
AnfAlgo
::
GetOutputInferShape
(
kernel_node
,
0
);
if
(
input_shape
!=
output_shape
&&
input_shapeB
!=
output_shape
)
{
MS_LOG
(
ERROR
)
<<
"Double-sided broadcast was not supported in cudnn of cudnnOpTensor:
\n
"
"InputA must match the corresponding dimension of the destination tensor outC, and each "
"dimension of the inputB "
"must match the corresponding dimension of outC or must be equal to 1."
;
return
false
;
}
is_null_input_
=
CHECK_NULL_INPUT
(
input_shape
)
||
CHECK_NULL_INPUT
(
input_shapeB
);
is_null_input_
=
CHECK_NULL_INPUT
(
input_shape
)
||
CHECK_NULL_INPUT
(
input_shapeB
);
if
(
is_null_input_
)
{
if
(
is_null_input_
)
{
MS_LOG
(
WARNING
)
<<
"BinaryOpGpuKernel input is null"
;
MS_LOG
(
WARNING
)
<<
"BinaryOpGpuKernel input is null"
;
...
...
mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h
浏览文件 @
d3936b9f
...
@@ -90,7 +90,7 @@ class TensorAddGpuFwdKernel : public GpuKernel {
...
@@ -90,7 +90,7 @@ class TensorAddGpuFwdKernel : public GpuKernel {
if
(
input_shape
!=
output_shape
&&
input_shapeB
!=
output_shape
)
{
if
(
input_shape
!=
output_shape
&&
input_shapeB
!=
output_shape
)
{
MS_LOG
(
ERROR
)
<<
"Double-sided broadcast was not supported in cudnn of cudnnOpTensor:
\n
"
MS_LOG
(
ERROR
)
<<
"Double-sided broadcast was not supported in cudnn of cudnnOpTensor:
\n
"
"InputA must match the corresponding dimension of the destination tensor outC, and each "
"InputA must match the corresponding dimension of the destination tensor outC, and each "
"dimension of the inputB"
"dimension of the inputB
"
"must match the corresponding dimension of outC or must be equal to 1."
;
"must match the corresponding dimension of outC or must be equal to 1."
;
return
false
;
return
false
;
}
}
...
...
mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h
浏览文件 @
d3936b9f
...
@@ -50,7 +50,11 @@ template <typename T>
...
@@ -50,7 +50,11 @@ template <typename T>
class
UnaryOpGpuKernel
:
public
GpuKernel
{
class
UnaryOpGpuKernel
:
public
GpuKernel
{
public:
public:
UnaryOpGpuKernel
()
UnaryOpGpuKernel
()
:
unary_op_type_
(
UNARY_OP_INVALID_TYPE
),
input_size_
(
sizeof
(
T
)),
output_size_
(
sizeof
(
T
)),
workspace_size_
(
0
)
{}
:
unary_op_type_
(
UNARY_OP_INVALID_TYPE
),
input_size_
(
sizeof
(
T
)),
output_size_
(
sizeof
(
T
)),
workspace_size_
(
0
),
is_null_input_
(
false
)
{}
~
UnaryOpGpuKernel
()
override
=
default
;
~
UnaryOpGpuKernel
()
override
=
default
;
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
...
@@ -121,6 +125,12 @@ class UnaryOpGpuKernel : public GpuKernel {
...
@@ -121,6 +125,12 @@ class UnaryOpGpuKernel : public GpuKernel {
return
false
;
return
false
;
}
}
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
is_null_input_
=
CHECK_NULL_INPUT
(
input_shape
);
if
(
is_null_input_
)
{
MS_LOG
(
WARNING
)
<<
"UnaryOpGpuKernel input is null"
;
InitSizeLists
();
return
true
;
}
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
i
++
)
{
input_size_
*=
input_shape
[
i
];
input_size_
*=
input_shape
[
i
];
}
}
...
@@ -140,6 +150,7 @@ class UnaryOpGpuKernel : public GpuKernel {
...
@@ -140,6 +150,7 @@ class UnaryOpGpuKernel : public GpuKernel {
size_t
input_size_
;
size_t
input_size_
;
size_t
output_size_
;
size_t
output_size_
;
size_t
workspace_size_
;
size_t
workspace_size_
;
bool
is_null_input_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
...
...
tests/st/ops/gpu/test_reduce_max_op.py
浏览文件 @
d3936b9f
...
@@ -55,6 +55,11 @@ x7 = np.random.rand(2, 3, 4, 4).astype(np.float32)
...
@@ -55,6 +55,11 @@ x7 = np.random.rand(2, 3, 4, 4).astype(np.float32)
axis7
=
(
-
2
,
-
1
)
axis7
=
(
-
2
,
-
1
)
keep_dims7
=
True
keep_dims7
=
True
x8
=
np
.
random
.
rand
(
1
,
1
,
1
,
1
).
astype
(
np
.
float32
)
axis8
=
()
np_axis8
=
None
keep_dims8
=
True
context
.
set_context
(
device_target
=
'GPU'
)
context
.
set_context
(
device_target
=
'GPU'
)
...
@@ -94,6 +99,10 @@ class ReduceMax(nn.Cell):
...
@@ -94,6 +99,10 @@ class ReduceMax(nn.Cell):
self
.
axis7
=
axis7
self
.
axis7
=
axis7
self
.
keep_dims7
=
keep_dims7
self
.
keep_dims7
=
keep_dims7
self
.
x8
=
Tensor
(
x8
)
self
.
axis8
=
axis8
self
.
keep_dims8
=
keep_dims8
@
ms_function
@
ms_function
def
construct
(
self
):
def
construct
(
self
):
return
(
P
.
ReduceMax
(
self
.
keep_dims0
)(
self
.
x0
,
self
.
axis0
),
return
(
P
.
ReduceMax
(
self
.
keep_dims0
)(
self
.
x0
,
self
.
axis0
),
...
@@ -103,7 +112,8 @@ class ReduceMax(nn.Cell):
...
@@ -103,7 +112,8 @@ class ReduceMax(nn.Cell):
P
.
ReduceMax
(
self
.
keep_dims4
)(
self
.
x4
,
self
.
axis4
),
P
.
ReduceMax
(
self
.
keep_dims4
)(
self
.
x4
,
self
.
axis4
),
P
.
ReduceMax
(
self
.
keep_dims5
)(
self
.
x5
,
self
.
axis5
),
P
.
ReduceMax
(
self
.
keep_dims5
)(
self
.
x5
,
self
.
axis5
),
P
.
ReduceMax
(
self
.
keep_dims6
)(
self
.
x6
,
self
.
axis6
),
P
.
ReduceMax
(
self
.
keep_dims6
)(
self
.
x6
,
self
.
axis6
),
P
.
ReduceMax
(
self
.
keep_dims7
)(
self
.
x7
,
self
.
axis7
))
P
.
ReduceMax
(
self
.
keep_dims7
)(
self
.
x7
,
self
.
axis7
),
P
.
ReduceMax
(
self
.
keep_dims8
)(
self
.
x8
,
self
.
axis8
))
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
...
@@ -114,48 +124,53 @@ def test_ReduceMax():
...
@@ -114,48 +124,53 @@ def test_ReduceMax():
output
=
reduce_max
()
output
=
reduce_max
()
expect0
=
np
.
max
(
x0
,
axis
=
axis0
,
keepdims
=
keep_dims0
)
expect0
=
np
.
max
(
x0
,
axis
=
axis0
,
keepdims
=
keep_dims0
)
diff0
=
output
[
0
].
asnumpy
()
-
expect0
diff0
=
abs
(
output
[
0
].
asnumpy
()
-
expect0
)
error0
=
np
.
ones
(
shape
=
expect0
.
shape
)
*
1.0e-5
error0
=
np
.
ones
(
shape
=
expect0
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff0
<
error0
)
assert
np
.
all
(
diff0
<
error0
)
assert
(
output
[
0
].
shape
()
==
expect0
.
shape
)
assert
(
output
[
0
].
shape
()
==
expect0
.
shape
)
expect1
=
np
.
max
(
x1
,
axis
=
axis1
,
keepdims
=
keep_dims1
)
expect1
=
np
.
max
(
x1
,
axis
=
axis1
,
keepdims
=
keep_dims1
)
diff1
=
output
[
1
].
asnumpy
()
-
expect1
diff1
=
abs
(
output
[
1
].
asnumpy
()
-
expect1
)
error1
=
np
.
ones
(
shape
=
expect1
.
shape
)
*
1.0e-5
error1
=
np
.
ones
(
shape
=
expect1
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff1
<
error1
)
assert
np
.
all
(
diff1
<
error1
)
assert
(
output
[
1
].
shape
()
==
expect1
.
shape
)
assert
(
output
[
1
].
shape
()
==
expect1
.
shape
)
expect2
=
np
.
max
(
x2
,
axis
=
axis2
,
keepdims
=
keep_dims2
)
expect2
=
np
.
max
(
x2
,
axis
=
axis2
,
keepdims
=
keep_dims2
)
diff2
=
output
[
2
].
asnumpy
()
-
expect2
diff2
=
abs
(
output
[
2
].
asnumpy
()
-
expect2
)
error2
=
np
.
ones
(
shape
=
expect2
.
shape
)
*
1.0e-5
error2
=
np
.
ones
(
shape
=
expect2
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff2
<
error2
)
assert
np
.
all
(
diff2
<
error2
)
assert
(
output
[
2
].
shape
()
==
expect2
.
shape
)
assert
(
output
[
2
].
shape
()
==
expect2
.
shape
)
expect3
=
np
.
max
(
x3
,
axis
=
axis3
,
keepdims
=
keep_dims3
)
expect3
=
np
.
max
(
x3
,
axis
=
axis3
,
keepdims
=
keep_dims3
)
diff3
=
output
[
3
].
asnumpy
()
-
expect3
diff3
=
abs
(
output
[
3
].
asnumpy
()
-
expect3
)
error3
=
np
.
ones
(
shape
=
expect3
.
shape
)
*
1.0e-5
error3
=
np
.
ones
(
shape
=
expect3
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff3
<
error3
)
assert
np
.
all
(
diff3
<
error3
)
assert
(
output
[
3
].
shape
()
==
expect3
.
shape
)
assert
(
output
[
3
].
shape
()
==
expect3
.
shape
)
expect4
=
np
.
max
(
x4
,
axis
=
np_axis4
,
keepdims
=
keep_dims4
)
expect4
=
np
.
max
(
x4
,
axis
=
np_axis4
,
keepdims
=
keep_dims4
)
diff4
=
output
[
4
].
asnumpy
()
-
expect4
diff4
=
abs
(
output
[
4
].
asnumpy
()
-
expect4
)
error4
=
np
.
ones
(
shape
=
expect4
.
shape
)
*
1.0e-5
error4
=
np
.
ones
(
shape
=
expect4
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff4
<
error4
)
assert
np
.
all
(
diff4
<
error4
)
assert
(
output
[
4
].
shape
()
==
expect4
.
shape
)
assert
(
output
[
4
].
shape
()
==
expect4
.
shape
)
expect5
=
np
.
max
(
x5
,
axis
=
np_axis5
,
keepdims
=
keep_dims5
)
expect5
=
np
.
max
(
x5
,
axis
=
np_axis5
,
keepdims
=
keep_dims5
)
diff5
=
output
[
5
].
asnumpy
()
-
expect5
diff5
=
abs
(
output
[
5
].
asnumpy
()
-
expect5
)
error5
=
np
.
ones
(
shape
=
expect5
.
shape
)
*
1.0e-5
error5
=
np
.
ones
(
shape
=
expect5
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff5
<
error5
)
assert
np
.
all
(
diff5
<
error5
)
assert
(
output
[
5
].
shape
()
==
expect5
.
shape
)
assert
(
output
[
5
].
shape
()
==
expect5
.
shape
)
expect6
=
np
.
max
(
x6
,
axis
=
axis6
,
keepdims
=
keep_dims6
)
expect6
=
np
.
max
(
x6
,
axis
=
axis6
,
keepdims
=
keep_dims6
)
diff6
=
output
[
6
].
asnumpy
()
-
expect6
diff6
=
abs
(
output
[
6
].
asnumpy
()
-
expect6
)
error6
=
np
.
ones
(
shape
=
expect6
.
shape
)
*
1.0e-5
error6
=
np
.
ones
(
shape
=
expect6
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff6
<
error6
)
assert
np
.
all
(
diff6
<
error6
)
assert
(
output
[
6
].
shape
()
==
expect6
.
shape
)
assert
(
output
[
6
].
shape
()
==
expect6
.
shape
)
expect7
=
np
.
max
(
x7
,
axis
=
axis7
,
keepdims
=
keep_dims7
)
expect7
=
np
.
max
(
x7
,
axis
=
axis7
,
keepdims
=
keep_dims7
)
diff7
=
output
[
7
].
asnumpy
()
-
expect7
diff7
=
abs
(
output
[
7
].
asnumpy
()
-
expect7
)
error7
=
np
.
ones
(
shape
=
expect7
.
shape
)
*
1.0e-5
error7
=
np
.
ones
(
shape
=
expect7
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff7
<
error7
)
assert
np
.
all
(
diff7
<
error7
)
expect8
=
np
.
max
(
x8
,
axis
=
np_axis8
,
keepdims
=
keep_dims8
)
diff8
=
abs
(
output
[
8
].
asnumpy
()
-
expect8
)
error8
=
np
.
ones
(
shape
=
expect8
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff8
<
error8
)
tests/st/ops/gpu/test_reduce_mean_op.py
浏览文件 @
d3936b9f
...
@@ -77,6 +77,11 @@ x13 = np.random.rand(2, 3, 4, 4).astype(np.float32)
...
@@ -77,6 +77,11 @@ x13 = np.random.rand(2, 3, 4, 4).astype(np.float32)
axis13
=
(
-
2
,
-
1
)
axis13
=
(
-
2
,
-
1
)
keep_dims13
=
True
keep_dims13
=
True
x14
=
np
.
random
.
rand
(
1
,
1
,
1
,
1
).
astype
(
np
.
float32
)
axis14
=
()
np_axis14
=
None
keep_dims14
=
True
context
.
set_context
(
device_target
=
'GPU'
)
context
.
set_context
(
device_target
=
'GPU'
)
...
@@ -140,6 +145,10 @@ class ReduceMean(nn.Cell):
...
@@ -140,6 +145,10 @@ class ReduceMean(nn.Cell):
self
.
axis13
=
axis13
self
.
axis13
=
axis13
self
.
keep_dims13
=
keep_dims13
self
.
keep_dims13
=
keep_dims13
self
.
x14
=
Tensor
(
x14
)
self
.
axis14
=
axis14
self
.
keep_dims14
=
keep_dims14
@
ms_function
@
ms_function
def
construct
(
self
):
def
construct
(
self
):
return
(
P
.
ReduceMean
(
self
.
keep_dims0
)(
self
.
x0
,
self
.
axis0
),
return
(
P
.
ReduceMean
(
self
.
keep_dims0
)(
self
.
x0
,
self
.
axis0
),
...
@@ -155,7 +164,8 @@ class ReduceMean(nn.Cell):
...
@@ -155,7 +164,8 @@ class ReduceMean(nn.Cell):
P
.
ReduceMean
(
self
.
keep_dims10
)(
self
.
x10
,
self
.
axis10
),
P
.
ReduceMean
(
self
.
keep_dims10
)(
self
.
x10
,
self
.
axis10
),
P
.
ReduceMean
(
self
.
keep_dims11
)(
self
.
x11
,
self
.
axis11
),
P
.
ReduceMean
(
self
.
keep_dims11
)(
self
.
x11
,
self
.
axis11
),
P
.
ReduceMean
(
self
.
keep_dims12
)(
self
.
x12
,
self
.
axis12
),
P
.
ReduceMean
(
self
.
keep_dims12
)(
self
.
x12
,
self
.
axis12
),
P
.
ReduceMean
(
self
.
keep_dims13
)(
self
.
x13
,
self
.
axis13
))
P
.
ReduceMean
(
self
.
keep_dims13
)(
self
.
x13
,
self
.
axis13
),
P
.
ReduceMean
(
self
.
keep_dims14
)(
self
.
x14
,
self
.
axis14
))
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
...
@@ -166,85 +176,91 @@ def test_ReduceMean():
...
@@ -166,85 +176,91 @@ def test_ReduceMean():
output
=
reduce_mean
()
output
=
reduce_mean
()
expect0
=
np
.
mean
(
x0
,
axis
=
axis0
,
keepdims
=
keep_dims0
)
expect0
=
np
.
mean
(
x0
,
axis
=
axis0
,
keepdims
=
keep_dims0
)
diff0
=
output
[
0
].
asnumpy
()
-
expect0
diff0
=
abs
(
output
[
0
].
asnumpy
()
-
expect0
)
error0
=
np
.
ones
(
shape
=
expect0
.
shape
)
*
1.0e-5
error0
=
np
.
ones
(
shape
=
expect0
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff0
<
error0
)
assert
np
.
all
(
diff0
<
error0
)
assert
(
output
[
0
].
shape
()
==
expect0
.
shape
)
assert
(
output
[
0
].
shape
()
==
expect0
.
shape
)
expect1
=
np
.
mean
(
x1
,
axis
=
axis1
,
keepdims
=
keep_dims1
)
expect1
=
np
.
mean
(
x1
,
axis
=
axis1
,
keepdims
=
keep_dims1
)
diff1
=
output
[
1
].
asnumpy
()
-
expect1
diff1
=
abs
(
output
[
1
].
asnumpy
()
-
expect1
)
error1
=
np
.
ones
(
shape
=
expect1
.
shape
)
*
1.0e-5
error1
=
np
.
ones
(
shape
=
expect1
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff1
<
error1
)
assert
np
.
all
(
diff1
<
error1
)
assert
(
output
[
1
].
shape
()
==
expect1
.
shape
)
assert
(
output
[
1
].
shape
()
==
expect1
.
shape
)
expect2
=
np
.
mean
(
x2
,
axis
=
axis2
,
keepdims
=
keep_dims2
)
expect2
=
np
.
mean
(
x2
,
axis
=
axis2
,
keepdims
=
keep_dims2
)
diff2
=
output
[
2
].
asnumpy
()
-
expect2
diff2
=
abs
(
output
[
2
].
asnumpy
()
-
expect2
)
error2
=
np
.
ones
(
shape
=
expect2
.
shape
)
*
1.0e-5
error2
=
np
.
ones
(
shape
=
expect2
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff2
<
error2
)
assert
np
.
all
(
diff2
<
error2
)
assert
(
output
[
2
].
shape
()
==
expect2
.
shape
)
assert
(
output
[
2
].
shape
()
==
expect2
.
shape
)
expect3
=
np
.
mean
(
x3
,
axis
=
axis3
,
keepdims
=
keep_dims3
)
expect3
=
np
.
mean
(
x3
,
axis
=
axis3
,
keepdims
=
keep_dims3
)
diff3
=
output
[
3
].
asnumpy
()
-
expect3
diff3
=
abs
(
output
[
3
].
asnumpy
()
-
expect3
)
error3
=
np
.
ones
(
shape
=
expect3
.
shape
)
*
1.0e-5
error3
=
np
.
ones
(
shape
=
expect3
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff3
<
error3
)
assert
np
.
all
(
diff3
<
error3
)
assert
(
output
[
3
].
shape
()
==
expect3
.
shape
)
assert
(
output
[
3
].
shape
()
==
expect3
.
shape
)
expect4
=
np
.
mean
(
x4
,
axis
=
axis4
,
keepdims
=
keep_dims4
)
expect4
=
np
.
mean
(
x4
,
axis
=
axis4
,
keepdims
=
keep_dims4
)
diff4
=
output
[
4
].
asnumpy
()
-
expect4
diff4
=
abs
(
output
[
4
].
asnumpy
()
-
expect4
)
error4
=
np
.
ones
(
shape
=
expect4
.
shape
)
*
1.0e-5
error4
=
np
.
ones
(
shape
=
expect4
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff4
<
error4
)
assert
np
.
all
(
diff4
<
error4
)
assert
(
output
[
4
].
shape
()
==
expect4
.
shape
)
assert
(
output
[
4
].
shape
()
==
expect4
.
shape
)
expect5
=
np
.
mean
(
x5
,
axis
=
axis5
,
keepdims
=
keep_dims5
)
expect5
=
np
.
mean
(
x5
,
axis
=
axis5
,
keepdims
=
keep_dims5
)
diff5
=
output
[
5
].
asnumpy
()
-
expect5
diff5
=
abs
(
output
[
5
].
asnumpy
()
-
expect5
)
error5
=
np
.
ones
(
shape
=
expect5
.
shape
)
*
1.0e-5
error5
=
np
.
ones
(
shape
=
expect5
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff5
<
error5
)
assert
np
.
all
(
diff5
<
error5
)
assert
(
output
[
5
].
shape
()
==
expect5
.
shape
)
assert
(
output
[
5
].
shape
()
==
expect5
.
shape
)
expect6
=
np
.
mean
(
x6
,
axis
=
axis6
,
keepdims
=
keep_dims6
)
expect6
=
np
.
mean
(
x6
,
axis
=
axis6
,
keepdims
=
keep_dims6
)
diff6
=
output
[
6
].
asnumpy
()
-
expect6
diff6
=
abs
(
output
[
6
].
asnumpy
()
-
expect6
)
error6
=
np
.
ones
(
shape
=
expect6
.
shape
)
*
1.0e-5
error6
=
np
.
ones
(
shape
=
expect6
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff6
<
error6
)
assert
np
.
all
(
diff6
<
error6
)
assert
(
output
[
6
].
shape
()
==
expect6
.
shape
)
assert
(
output
[
6
].
shape
()
==
expect6
.
shape
)
expect7
=
np
.
mean
(
x7
,
axis
=
axis7
,
keepdims
=
keep_dims7
)
expect7
=
np
.
mean
(
x7
,
axis
=
axis7
,
keepdims
=
keep_dims7
)
diff7
=
output
[
7
].
asnumpy
()
-
expect7
diff7
=
abs
(
output
[
7
].
asnumpy
()
-
expect7
)
error7
=
np
.
ones
(
shape
=
expect7
.
shape
)
*
1.0e-5
error7
=
np
.
ones
(
shape
=
expect7
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff7
<
error7
)
assert
np
.
all
(
diff7
<
error7
)
assert
(
output
[
7
].
shape
()
==
expect7
.
shape
)
assert
(
output
[
7
].
shape
()
==
expect7
.
shape
)
expect8
=
np
.
mean
(
x8
,
axis
=
axis8
,
keepdims
=
keep_dims8
)
expect8
=
np
.
mean
(
x8
,
axis
=
axis8
,
keepdims
=
keep_dims8
)
diff8
=
output
[
8
].
asnumpy
()
-
expect8
diff8
=
abs
(
output
[
8
].
asnumpy
()
-
expect8
)
error8
=
np
.
ones
(
shape
=
expect8
.
shape
)
*
1.0e-5
error8
=
np
.
ones
(
shape
=
expect8
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff8
<
error8
)
assert
np
.
all
(
diff8
<
error8
)
assert
(
output
[
8
].
shape
()
==
expect8
.
shape
)
assert
(
output
[
8
].
shape
()
==
expect8
.
shape
)
expect9
=
np
.
mean
(
x9
,
axis
=
axis9
,
keepdims
=
keep_dims9
)
expect9
=
np
.
mean
(
x9
,
axis
=
axis9
,
keepdims
=
keep_dims9
)
diff9
=
output
[
9
].
asnumpy
()
-
expect9
diff9
=
abs
(
output
[
9
].
asnumpy
()
-
expect9
)
error9
=
np
.
ones
(
shape
=
expect9
.
shape
)
*
1.0e-5
error9
=
np
.
ones
(
shape
=
expect9
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff9
<
error9
)
assert
np
.
all
(
diff9
<
error9
)
assert
(
output
[
9
].
shape
()
==
expect9
.
shape
)
assert
(
output
[
9
].
shape
()
==
expect9
.
shape
)
expect10
=
np
.
mean
(
x10
,
axis
=
axis10
,
keepdims
=
keep_dims10
)
expect10
=
np
.
mean
(
x10
,
axis
=
axis10
,
keepdims
=
keep_dims10
)
diff10
=
output
[
10
].
asnumpy
()
-
expect10
diff10
=
abs
(
output
[
10
].
asnumpy
()
-
expect10
)
error10
=
np
.
ones
(
shape
=
expect10
.
shape
)
*
1.0e-5
error10
=
np
.
ones
(
shape
=
expect10
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff10
<
error10
)
assert
np
.
all
(
diff10
<
error10
)
assert
(
output
[
10
].
shape
()
==
expect10
.
shape
)
assert
(
output
[
10
].
shape
()
==
expect10
.
shape
)
expect11
=
np
.
mean
(
x11
,
axis
=
axis11
,
keepdims
=
keep_dims11
)
expect11
=
np
.
mean
(
x11
,
axis
=
axis11
,
keepdims
=
keep_dims11
)
diff11
=
output
[
11
].
asnumpy
()
-
expect11
diff11
=
abs
(
output
[
11
].
asnumpy
()
-
expect11
)
error11
=
np
.
ones
(
shape
=
expect11
.
shape
)
*
1.0e-5
error11
=
np
.
ones
(
shape
=
expect11
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff11
<
error11
)
assert
np
.
all
(
diff11
<
error11
)
assert
(
output
[
11
].
shape
()
==
expect11
.
shape
)
assert
(
output
[
11
].
shape
()
==
expect11
.
shape
)
expect12
=
np
.
sum
(
x12
,
axis
=
axis12
,
keepdims
=
keep_dims12
)
expect12
=
np
.
mean
(
x12
,
axis
=
axis12
,
keepdims
=
keep_dims12
)
diff12
=
output
[
12
].
asnumpy
()
-
expect12
diff12
=
abs
(
output
[
12
].
asnumpy
()
-
expect12
)
error12
=
np
.
ones
(
shape
=
expect12
.
shape
)
*
1.0e-5
error12
=
np
.
ones
(
shape
=
expect12
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff12
<
error12
)
assert
np
.
all
(
diff12
<
error12
)
assert
(
output
[
12
].
shape
()
==
expect12
.
shape
)
assert
(
output
[
12
].
shape
()
==
expect12
.
shape
)
expect13
=
np
.
sum
(
x13
,
axis
=
axis13
,
keepdims
=
keep_dims13
)
expect13
=
np
.
mean
(
x13
,
axis
=
axis13
,
keepdims
=
keep_dims13
)
diff13
=
output
[
13
].
asnumpy
()
-
expect13
diff13
=
abs
(
output
[
13
].
asnumpy
()
-
expect13
)
error13
=
np
.
ones
(
shape
=
expect13
.
shape
)
*
1.0e-5
error13
=
np
.
ones
(
shape
=
expect13
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff13
<
error13
)
assert
np
.
all
(
diff13
<
error13
)
assert
(
output
[
13
].
shape
()
==
expect13
.
shape
)
assert
(
output
[
13
].
shape
()
==
expect13
.
shape
)
expect14
=
np
.
mean
(
x14
,
axis
=
np_axis14
,
keepdims
=
keep_dims14
)
diff14
=
abs
(
output
[
14
].
asnumpy
()
-
expect14
)
error14
=
np
.
ones
(
shape
=
expect14
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff14
<
error14
)
assert
(
output
[
14
].
shape
()
==
expect14
.
shape
)
tests/st/ops/gpu/test_reduce_sum_op.py
浏览文件 @
d3936b9f
...
@@ -79,6 +79,11 @@ x13 = np.random.rand(2, 3, 4, 4).astype(np.float32)
...
@@ -79,6 +79,11 @@ x13 = np.random.rand(2, 3, 4, 4).astype(np.float32)
axis13
=
(
-
2
,
-
1
)
axis13
=
(
-
2
,
-
1
)
keep_dims13
=
True
keep_dims13
=
True
x14
=
np
.
random
.
rand
(
1
,
1
,
1
,
1
).
astype
(
np
.
float32
)
axis14
=
()
np_axis14
=
None
keep_dims14
=
True
context
.
set_context
(
device_target
=
'GPU'
)
context
.
set_context
(
device_target
=
'GPU'
)
...
@@ -142,6 +147,10 @@ class ReduceSum(nn.Cell):
...
@@ -142,6 +147,10 @@ class ReduceSum(nn.Cell):
self
.
axis13
=
axis13
self
.
axis13
=
axis13
self
.
keep_dims13
=
keep_dims13
self
.
keep_dims13
=
keep_dims13
self
.
x14
=
Tensor
(
x14
)
self
.
axis14
=
axis14
self
.
keep_dims14
=
keep_dims14
@
ms_function
@
ms_function
def
construct
(
self
):
def
construct
(
self
):
return
(
P
.
ReduceSum
(
self
.
keep_dims0
)(
self
.
x0
,
self
.
axis0
),
return
(
P
.
ReduceSum
(
self
.
keep_dims0
)(
self
.
x0
,
self
.
axis0
),
...
@@ -157,7 +166,8 @@ class ReduceSum(nn.Cell):
...
@@ -157,7 +166,8 @@ class ReduceSum(nn.Cell):
P
.
ReduceSum
(
self
.
keep_dims10
)(
self
.
x10
,
self
.
axis10
),
P
.
ReduceSum
(
self
.
keep_dims10
)(
self
.
x10
,
self
.
axis10
),
P
.
ReduceSum
(
self
.
keep_dims11
)(
self
.
x11
,
self
.
axis11
),
P
.
ReduceSum
(
self
.
keep_dims11
)(
self
.
x11
,
self
.
axis11
),
P
.
ReduceSum
(
self
.
keep_dims12
)(
self
.
x12
,
self
.
axis12
),
P
.
ReduceSum
(
self
.
keep_dims12
)(
self
.
x12
,
self
.
axis12
),
P
.
ReduceSum
(
self
.
keep_dims13
)(
self
.
x13
,
self
.
axis13
))
P
.
ReduceSum
(
self
.
keep_dims13
)(
self
.
x13
,
self
.
axis13
),
P
.
ReduceSum
(
self
.
keep_dims14
)(
self
.
x14
,
self
.
axis14
))
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
level0
...
@@ -168,85 +178,91 @@ def test_ReduceSum():
...
@@ -168,85 +178,91 @@ def test_ReduceSum():
output
=
reduce_sum
()
output
=
reduce_sum
()
expect0
=
np
.
sum
(
x0
,
axis
=
axis0
,
keepdims
=
keep_dims0
)
expect0
=
np
.
sum
(
x0
,
axis
=
axis0
,
keepdims
=
keep_dims0
)
diff0
=
output
[
0
].
asnumpy
()
-
expect0
diff0
=
abs
(
output
[
0
].
asnumpy
()
-
expect0
)
error0
=
np
.
ones
(
shape
=
expect0
.
shape
)
*
1.0e-5
error0
=
np
.
ones
(
shape
=
expect0
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff0
<
error0
)
assert
np
.
all
(
diff0
<
error0
)
assert
(
output
[
0
].
shape
()
==
expect0
.
shape
)
assert
(
output
[
0
].
shape
()
==
expect0
.
shape
)
expect1
=
np
.
sum
(
x1
,
axis
=
axis1
,
keepdims
=
keep_dims1
)
expect1
=
np
.
sum
(
x1
,
axis
=
axis1
,
keepdims
=
keep_dims1
)
diff1
=
output
[
1
].
asnumpy
()
-
expect1
diff1
=
abs
(
output
[
1
].
asnumpy
()
-
expect1
)
error1
=
np
.
ones
(
shape
=
expect1
.
shape
)
*
1.0e-5
error1
=
np
.
ones
(
shape
=
expect1
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff1
<
error1
)
assert
np
.
all
(
diff1
<
error1
)
assert
(
output
[
1
].
shape
()
==
expect1
.
shape
)
assert
(
output
[
1
].
shape
()
==
expect1
.
shape
)
expect2
=
np
.
sum
(
x2
,
axis
=
axis2
,
keepdims
=
keep_dims2
)
expect2
=
np
.
sum
(
x2
,
axis
=
axis2
,
keepdims
=
keep_dims2
)
diff2
=
output
[
2
].
asnumpy
()
-
expect2
diff2
=
abs
(
output
[
2
].
asnumpy
()
-
expect2
)
error2
=
np
.
ones
(
shape
=
expect2
.
shape
)
*
1.0e-5
error2
=
np
.
ones
(
shape
=
expect2
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff2
<
error2
)
assert
np
.
all
(
diff2
<
error2
)
assert
(
output
[
2
].
shape
()
==
expect2
.
shape
)
assert
(
output
[
2
].
shape
()
==
expect2
.
shape
)
expect3
=
np
.
sum
(
x3
,
axis
=
axis3
,
keepdims
=
keep_dims3
)
expect3
=
np
.
sum
(
x3
,
axis
=
axis3
,
keepdims
=
keep_dims3
)
diff3
=
output
[
3
].
asnumpy
()
-
expect3
diff3
=
abs
(
output
[
3
].
asnumpy
()
-
expect3
)
error3
=
np
.
ones
(
shape
=
expect3
.
shape
)
*
1.0e-5
error3
=
np
.
ones
(
shape
=
expect3
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff3
<
error3
)
assert
np
.
all
(
diff3
<
error3
)
assert
(
output
[
3
].
shape
()
==
expect3
.
shape
)
assert
(
output
[
3
].
shape
()
==
expect3
.
shape
)
expect4
=
np
.
sum
(
x4
,
axis
=
np_axis4
,
keepdims
=
keep_dims4
)
expect4
=
np
.
sum
(
x4
,
axis
=
np_axis4
,
keepdims
=
keep_dims4
)
diff4
=
output
[
4
].
asnumpy
()
-
expect4
diff4
=
abs
(
output
[
4
].
asnumpy
()
-
expect4
)
error4
=
np
.
ones
(
shape
=
expect4
.
shape
)
*
1.0e-5
error4
=
np
.
ones
(
shape
=
expect4
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff4
<
error4
)
assert
np
.
all
(
diff4
<
error4
)
assert
(
output
[
4
].
shape
()
==
expect4
.
shape
)
assert
(
output
[
4
].
shape
()
==
expect4
.
shape
)
expect5
=
np
.
sum
(
x5
,
axis
=
np_axis5
,
keepdims
=
keep_dims5
)
expect5
=
np
.
sum
(
x5
,
axis
=
np_axis5
,
keepdims
=
keep_dims5
)
diff5
=
output
[
5
].
asnumpy
()
-
expect5
diff5
=
abs
(
output
[
5
].
asnumpy
()
-
expect5
)
error5
=
np
.
ones
(
shape
=
expect5
.
shape
)
*
1.0e-5
error5
=
np
.
ones
(
shape
=
expect5
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff5
<
error5
)
assert
np
.
all
(
diff5
<
error5
)
assert
(
output
[
5
].
shape
()
==
expect5
.
shape
)
assert
(
output
[
5
].
shape
()
==
expect5
.
shape
)
expect6
=
np
.
sum
(
x6
,
axis
=
axis6
,
keepdims
=
keep_dims6
)
expect6
=
np
.
sum
(
x6
,
axis
=
axis6
,
keepdims
=
keep_dims6
)
diff6
=
output
[
6
].
asnumpy
()
-
expect6
diff6
=
abs
(
output
[
6
].
asnumpy
()
-
expect6
)
error6
=
np
.
ones
(
shape
=
expect6
.
shape
)
*
1.0e-5
error6
=
np
.
ones
(
shape
=
expect6
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff6
<
error6
)
assert
np
.
all
(
diff6
<
error6
)
assert
(
output
[
6
].
shape
()
==
expect6
.
shape
)
assert
(
output
[
6
].
shape
()
==
expect6
.
shape
)
expect7
=
np
.
sum
(
x7
,
axis
=
axis7
,
keepdims
=
keep_dims7
)
expect7
=
np
.
sum
(
x7
,
axis
=
axis7
,
keepdims
=
keep_dims7
)
diff7
=
output
[
7
].
asnumpy
()
-
expect7
diff7
=
abs
(
output
[
7
].
asnumpy
()
-
expect7
)
error7
=
np
.
ones
(
shape
=
expect7
.
shape
)
*
1.0e-5
error7
=
np
.
ones
(
shape
=
expect7
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff7
<
error7
)
assert
np
.
all
(
diff7
<
error7
)
assert
(
output
[
7
].
shape
()
==
expect7
.
shape
)
assert
(
output
[
7
].
shape
()
==
expect7
.
shape
)
expect8
=
np
.
sum
(
x8
,
axis
=
axis8
,
keepdims
=
keep_dims8
)
expect8
=
np
.
sum
(
x8
,
axis
=
axis8
,
keepdims
=
keep_dims8
)
diff8
=
output
[
8
].
asnumpy
()
-
expect8
diff8
=
abs
(
output
[
8
].
asnumpy
()
-
expect8
)
error8
=
np
.
ones
(
shape
=
expect8
.
shape
)
*
1.0e-5
error8
=
np
.
ones
(
shape
=
expect8
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff8
<
error8
)
assert
np
.
all
(
diff8
<
error8
)
assert
(
output
[
8
].
shape
()
==
expect8
.
shape
)
assert
(
output
[
8
].
shape
()
==
expect8
.
shape
)
expect9
=
np
.
sum
(
x9
,
axis
=
axis9
,
keepdims
=
keep_dims9
)
expect9
=
np
.
sum
(
x9
,
axis
=
axis9
,
keepdims
=
keep_dims9
)
diff9
=
output
[
9
].
asnumpy
()
-
expect9
diff9
=
abs
(
output
[
9
].
asnumpy
()
-
expect9
)
error9
=
np
.
ones
(
shape
=
expect9
.
shape
)
*
1.0e-5
error9
=
np
.
ones
(
shape
=
expect9
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff9
<
error9
)
assert
np
.
all
(
diff9
<
error9
)
assert
(
output
[
9
].
shape
()
==
expect9
.
shape
)
assert
(
output
[
9
].
shape
()
==
expect9
.
shape
)
expect10
=
np
.
sum
(
x10
,
axis
=
axis10
,
keepdims
=
keep_dims10
)
expect10
=
np
.
sum
(
x10
,
axis
=
axis10
,
keepdims
=
keep_dims10
)
diff10
=
output
[
10
].
asnumpy
()
-
expect10
diff10
=
abs
(
output
[
10
].
asnumpy
()
-
expect10
)
error10
=
np
.
ones
(
shape
=
expect10
.
shape
)
*
1.0e-5
error10
=
np
.
ones
(
shape
=
expect10
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff10
<
error10
)
assert
np
.
all
(
diff10
<
error10
)
assert
(
output
[
10
].
shape
()
==
expect10
.
shape
)
assert
(
output
[
10
].
shape
()
==
expect10
.
shape
)
expect11
=
np
.
sum
(
x11
,
axis
=
axis11
,
keepdims
=
keep_dims11
)
expect11
=
np
.
sum
(
x11
,
axis
=
axis11
,
keepdims
=
keep_dims11
)
diff11
=
output
[
11
].
asnumpy
()
-
expect11
diff11
=
abs
(
output
[
11
].
asnumpy
()
-
expect11
)
error11
=
np
.
ones
(
shape
=
expect11
.
shape
)
*
1.0e-5
error11
=
np
.
ones
(
shape
=
expect11
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff11
<
error11
)
assert
np
.
all
(
diff11
<
error11
)
assert
(
output
[
11
].
shape
()
==
expect11
.
shape
)
assert
(
output
[
11
].
shape
()
==
expect11
.
shape
)
expect12
=
np
.
sum
(
x12
,
axis
=
axis12
,
keepdims
=
keep_dims12
)
expect12
=
np
.
sum
(
x12
,
axis
=
axis12
,
keepdims
=
keep_dims12
)
diff12
=
output
[
12
].
asnumpy
()
-
expect12
diff12
=
abs
(
output
[
12
].
asnumpy
()
-
expect12
)
error12
=
np
.
ones
(
shape
=
expect12
.
shape
)
*
1.0e-5
error12
=
np
.
ones
(
shape
=
expect12
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff12
<
error12
)
assert
np
.
all
(
diff12
<
error12
)
assert
(
output
[
12
].
shape
()
==
expect12
.
shape
)
assert
(
output
[
12
].
shape
()
==
expect12
.
shape
)
expect13
=
np
.
sum
(
x13
,
axis
=
axis13
,
keepdims
=
keep_dims13
)
expect13
=
np
.
sum
(
x13
,
axis
=
axis13
,
keepdims
=
keep_dims13
)
diff13
=
output
[
13
].
asnumpy
()
-
expect13
diff13
=
abs
(
output
[
13
].
asnumpy
()
-
expect13
)
error13
=
np
.
ones
(
shape
=
expect13
.
shape
)
*
1.0e-5
error13
=
np
.
ones
(
shape
=
expect13
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff13
<
error13
)
assert
np
.
all
(
diff13
<
error13
)
assert
(
output
[
13
].
shape
()
==
expect13
.
shape
)
assert
(
output
[
13
].
shape
()
==
expect13
.
shape
)
expect14
=
np
.
sum
(
x14
,
axis
=
np_axis14
,
keepdims
=
keep_dims14
)
diff14
=
abs
(
output
[
14
].
asnumpy
()
-
expect14
)
error14
=
np
.
ones
(
shape
=
expect14
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff14
<
error14
)
assert
(
output
[
14
].
shape
()
==
expect14
.
shape
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录