Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
5d0801ed
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5d0801ed
编写于
4月 11, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 11, 2020
浏览文件
操作
浏览文件
下载
差异文件
!188 Support pow's second input could be tensor and fix bug in bprop of pow
Merge pull request !188 from zhangbuxue/fix_pow_bprop
上级
ca326d1b
5841fe01
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
95 addition
and
139 deletion
+95
-139
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
+7
-0
mindspore/ccsrc/parallel/ops_info/elementary_function_info.cc
...spore/ccsrc/parallel/ops_info/elementary_function_info.cc
+0
-47
mindspore/ccsrc/parallel/ops_info/elementary_function_info.h
mindspore/ccsrc/parallel/ops_info/elementary_function_info.h
+0
-10
mindspore/nn/layer/pooling.py
mindspore/nn/layer/pooling.py
+1
-1
mindspore/ops/_grad/grad_math_ops.py
mindspore/ops/_grad/grad_math_ops.py
+5
-6
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+2
-2
mindspore/ops/operations/math_ops.py
mindspore/ops/operations/math_ops.py
+62
-58
mindspore/ops/primitive.py
mindspore/ops/primitive.py
+0
-3
tests/ut/cpp/parallel/ops_info/pow_info_test.cc
tests/ut/cpp/parallel/ops_info/pow_info_test.cc
+9
-9
tests/ut/python/ops/test_math_ops.py
tests/ut/python/ops/test_math_ops.py
+2
-1
tests/ut/python/ops/test_ops.py
tests/ut/python/ops/test_ops.py
+5
-1
tests/ut/python/parallel/test_element_wise_function.py
tests/ut/python/parallel/test_element_wise_function.py
+1
-1
tests/vm_impl/math_ops_vm_impl.py
tests/vm_impl/math_ops_vm_impl.py
+1
-0
未找到文件。
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h
浏览文件 @
5d0801ed
...
@@ -98,6 +98,13 @@ class FloorDivInfo : public ArithmeticBase {
...
@@ -98,6 +98,13 @@ class FloorDivInfo : public ArithmeticBase {
~
FloorDivInfo
()
override
=
default
;
~
FloorDivInfo
()
override
=
default
;
};
};
class
PowInfo
:
public
ArithmeticBase
{
public:
PowInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
:
ArithmeticBase
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{}
~
PowInfo
()
override
=
default
;
};
class
GreaterInfo
:
public
ArithmeticBase
{
class
GreaterInfo
:
public
ArithmeticBase
{
public:
public:
GreaterInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
GreaterInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
...
...
mindspore/ccsrc/parallel/ops_info/elementary_function_info.cc
已删除
100644 → 0
浏览文件 @
ca326d1b
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "parallel/ops_info/elementary_function_info.h"
namespace
mindspore
{
namespace
parallel
{
Status
PowInfo
::
InferMirrorOps
()
{
mirror_ops_
.
clear
();
Shape
tensor_map
=
inputs_tensor_map_
[
0
];
std
::
vector
<
Group
>
group
;
if
(
CreateGroupByTensorMap
(
tensor_map
,
&
group
)
!=
SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
name_
<<
" : Create group failed."
;
return
FAILED
;
}
OperatorVector
mirror_op
;
OperatorVector
op_for_value
;
if
(
group
.
empty
())
{
MS_LOG
(
INFO
)
<<
name_
<<
" : The mirror ops is empty."
;
return
SUCCESS
;
}
else
{
mirror_op
=
CreateMirrorOps
(
group
[
0
].
name
(),
group
[
0
].
GetDevNum
());
mirror_ops_
.
push_back
(
mirror_op
);
mirror_ops_
.
push_back
(
op_for_value
);
std
::
string
group_name
=
group
[
0
].
name
();
MS_LOG
(
INFO
)
<<
name_
<<
" : Create the mirror ops success, the group name is "
<<
group_name
;
}
return
SUCCESS
;
}
}
// namespace parallel
}
// namespace mindspore
mindspore/ccsrc/parallel/ops_info/elementary_function_info.h
浏览文件 @
5d0801ed
...
@@ -27,16 +27,6 @@
...
@@ -27,16 +27,6 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
parallel
{
namespace
parallel
{
class
PowInfo
:
public
ActivationOther
{
public:
PowInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
:
ActivationOther
(
name
,
inputs_shape
,
outputs_shape
,
attrs
)
{}
~
PowInfo
()
override
=
default
;
protected:
Status
InferMirrorOps
()
override
;
};
class
ExpInfo
:
public
ActivationOther
{
class
ExpInfo
:
public
ActivationOther
{
public:
public:
ExpInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
ExpInfo
(
const
std
::
string
&
name
,
const
Shapes
&
inputs_shape
,
const
Shapes
&
outputs_shape
,
const
PrimitiveAttrs
&
attrs
)
...
...
mindspore/nn/layer/pooling.py
浏览文件 @
5d0801ed
...
@@ -58,7 +58,7 @@ class _PoolNd(Cell):
...
@@ -58,7 +58,7 @@ class _PoolNd(Cell):
pass
pass
def
extend_repr
(
self
):
def
extend_repr
(
self
):
return
'kernel_size={kernel_size}, stride
s={strides
}, pad_mode={pad_mode}'
.
format
(
**
self
.
__dict__
)
return
'kernel_size={kernel_size}, stride
={stride
}, pad_mode={pad_mode}'
.
format
(
**
self
.
__dict__
)
class
MaxPool2d
(
_PoolNd
):
class
MaxPool2d
(
_PoolNd
):
...
...
mindspore/ops/_grad/grad_math_ops.py
浏览文件 @
5d0801ed
...
@@ -336,14 +336,13 @@ def get_bprop_log(self):
...
@@ -336,14 +336,13 @@ def get_bprop_log(self):
@
bprop_getters
.
register
(
P
.
Pow
)
@
bprop_getters
.
register
(
P
.
Pow
)
def
get_bprop_pow
(
self
):
def
get_bprop_pow
(
self
):
"""Grad definition for `Pow` operation."""
"""Grad definition for `Pow` operation."""
pow_
=
P
.
Pow
()
pow_op
=
P
.
Pow
()
cast
=
P
.
Cast
()
ln
=
P
.
Log
()
dtype
=
P
.
DType
()
def
bprop
(
x
,
power
,
out
,
dout
):
def
bprop
(
x
,
power
,
out
,
dout
):
g
=
cast
(
F
.
tuple_to_array
((
power
,)),
dtype
(
x
))
*
pow_
(
x
,
power
-
1.0
)
dx
=
power
*
pow_op
(
x
,
power
-
1.0
)
*
dout
d
x
=
g
*
dout
d
power
=
pow_op
(
x
,
power
)
*
ln
(
x
)
*
dout
return
dx
,
0
return
dx
,
dpower
return
bprop
return
bprop
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
5d0801ed
...
@@ -1097,7 +1097,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
...
@@ -1097,7 +1097,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
axis
=
self
.
axis
axis
=
self
.
axis
x_rank
=
len
(
x_shape
)
x_rank
=
len
(
x_shape
)
validator
.
check_int_range
(
"axis"
,
axis
,
-
x_rank
,
x_rank
,
Rel
.
INC_LEFT
)
validator
.
check_int_range
(
"axis"
,
axis
,
-
x_rank
,
x_rank
,
Rel
.
INC_LEFT
)
ouput_shape
=
_infer_shape_reduce
(
x_shape
,
self
.
axis
,
self
.
keep_dims
,
self
.
prim_name
()
)
ouput_shape
=
_infer_shape_reduce
(
x_shape
,
self
.
axis
,
self
.
keep_dims
,
self
.
name
)
return
ouput_shape
,
ouput_shape
return
ouput_shape
,
ouput_shape
def
infer_dtype
(
self
,
x_dtype
):
def
infer_dtype
(
self
,
x_dtype
):
...
@@ -1143,7 +1143,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
...
@@ -1143,7 +1143,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
axis
=
self
.
axis
axis
=
self
.
axis
x_rank
=
len
(
x_shape
)
x_rank
=
len
(
x_shape
)
validator
.
check_int_range
(
"axis"
,
axis
,
-
x_rank
,
x_rank
,
Rel
.
INC_LEFT
)
validator
.
check_int_range
(
"axis"
,
axis
,
-
x_rank
,
x_rank
,
Rel
.
INC_LEFT
)
ouput_shape
=
_infer_shape_reduce
(
x_shape
,
self
.
axis
,
self
.
keep_dims
,
self
.
prim_name
()
)
ouput_shape
=
_infer_shape_reduce
(
x_shape
,
self
.
axis
,
self
.
keep_dims
,
self
.
name
)
return
ouput_shape
,
ouput_shape
return
ouput_shape
,
ouput_shape
def
infer_dtype
(
self
,
x_dtype
):
def
infer_dtype
(
self
,
x_dtype
):
...
...
mindspore/ops/operations/math_ops.py
浏览文件 @
5d0801ed
...
@@ -74,7 +74,7 @@ class _BinaryOp(PrimitiveWithInfer):
...
@@ -74,7 +74,7 @@ class _BinaryOp(PrimitiveWithInfer):
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'y'
],
outputs
=
[
'output'
])
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'y'
],
outputs
=
[
'output'
])
def
infer_shape
(
self
,
x_shape
,
y_shape
):
def
infer_shape
(
self
,
x_shape
,
y_shape
):
return
_get_broadcast_shape
(
x_shape
,
y_shape
,
self
.
prim_name
()
)
return
_get_broadcast_shape
(
x_shape
,
y_shape
,
self
.
name
)
class
_MathBinaryOp
(
_BinaryOp
):
class
_MathBinaryOp
(
_BinaryOp
):
...
@@ -89,7 +89,7 @@ class _MathBinaryOp(_BinaryOp):
...
@@ -89,7 +89,7 @@ class _MathBinaryOp(_BinaryOp):
return
x_dtype
return
x_dtype
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
return
_MathBinaryOp
.
do_infer_dtype
(
x_dtype
,
y_dtype
,
mstype
.
number_type
,
self
.
prim_name
()
)
return
_MathBinaryOp
.
do_infer_dtype
(
x_dtype
,
y_dtype
,
mstype
.
number_type
,
self
.
name
)
class
TensorAdd
(
_MathBinaryOp
):
class
TensorAdd
(
_MathBinaryOp
):
...
@@ -158,7 +158,7 @@ class AssignAdd(PrimitiveWithInfer):
...
@@ -158,7 +158,7 @@ class AssignAdd(PrimitiveWithInfer):
def
infer_dtype
(
self
,
variable
,
value
):
def
infer_dtype
(
self
,
variable
,
value
):
args
=
{
"value"
:
value
}
args
=
{
"value"
:
value
}
validator
.
check_scalar_or_tensor_type_same
(
args
,
mstype
.
number_type
,
self
.
prim_name
()
)
validator
.
check_scalar_or_tensor_type_same
(
args
,
mstype
.
number_type
,
self
.
name
)
return
value
return
value
...
@@ -201,7 +201,7 @@ class AssignSub(PrimitiveWithInfer):
...
@@ -201,7 +201,7 @@ class AssignSub(PrimitiveWithInfer):
def
infer_dtype
(
self
,
variable
,
value
):
def
infer_dtype
(
self
,
variable
,
value
):
args
=
{
"value"
:
value
}
args
=
{
"value"
:
value
}
validator
.
check_scalar_or_tensor_type_same
(
args
,
mstype
.
number_type
,
self
.
prim_name
()
)
validator
.
check_scalar_or_tensor_type_same
(
args
,
mstype
.
number_type
,
self
.
name
)
return
value
return
value
...
@@ -222,16 +222,16 @@ class _Reduce(PrimitiveWithInfer):
...
@@ -222,16 +222,16 @@ class _Reduce(PrimitiveWithInfer):
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
,
keep_dims
=
False
):
def
__init__
(
self
,
keep_dims
=
False
):
"""init Reduce"""
"""init Reduce"""
validator
.
check_value_type
(
'keep_dims'
,
keep_dims
,
[
bool
],
self
.
prim_name
()
)
validator
.
check_value_type
(
'keep_dims'
,
keep_dims
,
[
bool
],
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'input_x'
,
'axis'
],
outputs
=
[
'y'
])
self
.
init_prim_io_names
(
inputs
=
[
'input_x'
,
'axis'
],
outputs
=
[
'y'
])
def
do_infer
(
self
,
input_x
,
axis
,
valid_dtype
=
mstype
.
number_type
):
def
do_infer
(
self
,
input_x
,
axis
,
valid_dtype
=
mstype
.
number_type
):
axis_v
=
axis
[
'value'
]
axis_v
=
axis
[
'value'
]
input_shp
=
input_x
[
'shape'
]
input_shp
=
input_x
[
'shape'
]
args
=
{
'input_x'
:
input_x
[
'dtype'
]}
args
=
{
'input_x'
:
input_x
[
'dtype'
]}
validator
.
check_tensor_type_same
(
args
,
valid_dtype
,
self
.
prim_name
()
)
validator
.
check_tensor_type_same
(
args
,
valid_dtype
,
self
.
name
)
input_shp
=
_infer_shape_reduce
(
input_shp
,
axis_v
,
self
.
keep_dims
,
self
.
prim_name
()
)
input_shp
=
_infer_shape_reduce
(
input_shp
,
axis_v
,
self
.
keep_dims
,
self
.
name
)
return
{
'shape'
:
input_shp
,
return
{
'shape'
:
input_shp
,
'dtype'
:
input_x
[
'dtype'
],
'dtype'
:
input_x
[
'dtype'
],
'value'
:
None
}
'value'
:
None
}
...
@@ -466,7 +466,7 @@ class CumProd(PrimitiveWithInfer):
...
@@ -466,7 +466,7 @@ class CumProd(PrimitiveWithInfer):
"""
"""
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
,
exclusive
=
False
,
reverse
=
False
):
def
__init__
(
self
,
exclusive
=
False
,
reverse
=
False
):
cls_name
=
self
.
prim_name
()
cls_name
=
self
.
name
self
.
exclusive
=
validator
.
check_value_type
(
"exclusive"
,
exclusive
,
[
bool
],
cls_name
)
self
.
exclusive
=
validator
.
check_value_type
(
"exclusive"
,
exclusive
,
[
bool
],
cls_name
)
self
.
reverse
=
validator
.
check_value_type
(
"reverse"
,
reverse
,
[
bool
],
cls_name
)
self
.
reverse
=
validator
.
check_value_type
(
"reverse"
,
reverse
,
[
bool
],
cls_name
)
...
@@ -474,7 +474,7 @@ class CumProd(PrimitiveWithInfer):
...
@@ -474,7 +474,7 @@ class CumProd(PrimitiveWithInfer):
return
x_shape
return
x_shape
def
infer_dtype
(
self
,
x_type
,
axis_type
):
def
infer_dtype
(
self
,
x_type
,
axis_type
):
cls_name
=
self
.
prim_name
()
cls_name
=
self
.
name
validator
.
check_tensor_type_same
({
'x'
:
x_type
},
mstype
.
number_type
,
cls_name
)
validator
.
check_tensor_type_same
({
'x'
:
x_type
},
mstype
.
number_type
,
cls_name
)
validator
.
check_subclass
(
"axis"
,
axis_type
,
mstype
.
int_
,
cls_name
)
validator
.
check_subclass
(
"axis"
,
axis_type
,
mstype
.
int_
,
cls_name
)
return
x_type
return
x_type
...
@@ -510,7 +510,7 @@ class MatMul(PrimitiveWithInfer):
...
@@ -510,7 +510,7 @@ class MatMul(PrimitiveWithInfer):
def
__init__
(
self
,
transpose_a
=
False
,
transpose_b
=
False
):
def
__init__
(
self
,
transpose_a
=
False
,
transpose_b
=
False
):
self
.
init_prim_io_names
(
inputs
=
[
'x1'
,
'x2'
],
outputs
=
[
'output'
])
self
.
init_prim_io_names
(
inputs
=
[
'x1'
,
'x2'
],
outputs
=
[
'output'
])
self
.
__setattr_flag__
=
True
self
.
__setattr_flag__
=
True
cls_name
=
self
.
prim_name
()
cls_name
=
self
.
name
validator
.
check_value_type
(
"transpose_a"
,
transpose_a
,
[
bool
],
cls_name
)
validator
.
check_value_type
(
"transpose_a"
,
transpose_a
,
[
bool
],
cls_name
)
validator
.
check_value_type
(
"transpose_b"
,
transpose_b
,
[
bool
],
cls_name
)
validator
.
check_value_type
(
"transpose_b"
,
transpose_b
,
[
bool
],
cls_name
)
...
@@ -521,7 +521,7 @@ class MatMul(PrimitiveWithInfer):
...
@@ -521,7 +521,7 @@ class MatMul(PrimitiveWithInfer):
def
infer_shape
(
self
,
x
,
y
):
def
infer_shape
(
self
,
x
,
y
):
self
.
check_shape_size
(
x
,
y
)
self
.
check_shape_size
(
x
,
y
)
cls_name
=
self
.
prim_name
()
cls_name
=
self
.
name
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
for
i
in
range
(
len
(
x
)
-
2
):
for
i
in
range
(
len
(
x
)
-
2
):
if
x
[
i
]
!=
y
[
i
]:
if
x
[
i
]
!=
y
[
i
]:
...
@@ -546,7 +546,7 @@ class MatMul(PrimitiveWithInfer):
...
@@ -546,7 +546,7 @@ class MatMul(PrimitiveWithInfer):
def
infer_dtype
(
self
,
x
,
y
):
def
infer_dtype
(
self
,
x
,
y
):
args
=
{
"x"
:
x
,
"y"
:
y
}
args
=
{
"x"
:
x
,
"y"
:
y
}
validator
.
check_tensor_type_same
(
args
,
mstype
.
float_type
+
mstype
.
int_type
,
self
.
prim_name
()
)
validator
.
check_tensor_type_same
(
args
,
mstype
.
float_type
+
mstype
.
int_type
,
self
.
name
)
return
x
return
x
...
@@ -590,7 +590,7 @@ class BatchMatMul(MatMul):
...
@@ -590,7 +590,7 @@ class BatchMatMul(MatMul):
def
__init__
(
self
,
transpose_a
=
False
,
transpose_b
=
False
):
def
__init__
(
self
,
transpose_a
=
False
,
transpose_b
=
False
):
self
.
init_prim_io_names
(
inputs
=
[
'x1'
,
'x2'
],
outputs
=
[
'output'
])
self
.
init_prim_io_names
(
inputs
=
[
'x1'
,
'x2'
],
outputs
=
[
'output'
])
self
.
__setattr_flag__
=
True
self
.
__setattr_flag__
=
True
cls_name
=
self
.
prim_name
()
cls_name
=
self
.
name
validator
.
check_value_type
(
"transpose_a"
,
transpose_a
,
[
bool
],
cls_name
)
validator
.
check_value_type
(
"transpose_a"
,
transpose_a
,
[
bool
],
cls_name
)
validator
.
check_value_type
(
"transpose_b"
,
transpose_b
,
[
bool
],
cls_name
)
validator
.
check_value_type
(
"transpose_b"
,
transpose_b
,
[
bool
],
cls_name
)
...
@@ -628,13 +628,13 @@ class CumSum(PrimitiveWithInfer):
...
@@ -628,13 +628,13 @@ class CumSum(PrimitiveWithInfer):
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
,
exclusive
=
False
,
reverse
=
False
):
def
__init__
(
self
,
exclusive
=
False
,
reverse
=
False
):
"""init cumsum"""
"""init cumsum"""
cls_name
=
self
.
prim_name
()
cls_name
=
self
.
name
validator
.
check_value_type
(
'exclusive'
,
exclusive
,
[
bool
],
cls_name
)
validator
.
check_value_type
(
'exclusive'
,
exclusive
,
[
bool
],
cls_name
)
validator
.
check_value_type
(
'reverse'
,
reverse
,
[
bool
],
cls_name
)
validator
.
check_value_type
(
'reverse'
,
reverse
,
[
bool
],
cls_name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'axis'
],
outputs
=
[
'y'
])
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'axis'
],
outputs
=
[
'y'
])
def
__infer__
(
self
,
x
,
axis
):
def
__infer__
(
self
,
x
,
axis
):
cls_name
=
self
.
prim_name
()
cls_name
=
self
.
name
x_shp
=
x
[
'shape'
]
x_shp
=
x
[
'shape'
]
validator
.
check_value_type
(
'axis'
,
axis
[
'value'
],
[
int
],
cls_name
)
validator
.
check_value_type
(
'axis'
,
axis
[
'value'
],
[
int
],
cls_name
)
valid_types
=
[
mstype
.
uint8
,
mstype
.
int8
,
mstype
.
int32
,
mstype
.
float16
,
mstype
.
float32
]
valid_types
=
[
mstype
.
uint8
,
mstype
.
int8
,
mstype
.
int32
,
mstype
.
float16
,
mstype
.
float32
]
...
@@ -679,7 +679,7 @@ class AddN(PrimitiveWithInfer):
...
@@ -679,7 +679,7 @@ class AddN(PrimitiveWithInfer):
self
.
init_prim_io_names
(
inputs
=
[
"inputs"
],
outputs
=
[
"sum"
])
self
.
init_prim_io_names
(
inputs
=
[
"inputs"
],
outputs
=
[
"sum"
])
def
infer_shape
(
self
,
inputs
):
def
infer_shape
(
self
,
inputs
):
cls_name
=
self
.
prim_name
()
cls_name
=
self
.
name
validator
.
check_integer
(
"inputs"
,
len
(
inputs
),
1
,
Rel
.
GE
,
cls_name
)
validator
.
check_integer
(
"inputs"
,
len
(
inputs
),
1
,
Rel
.
GE
,
cls_name
)
self
.
add_prim_attr
(
'n'
,
len
(
inputs
))
self
.
add_prim_attr
(
'n'
,
len
(
inputs
))
shp0
=
inputs
[
0
]
shp0
=
inputs
[
0
]
...
@@ -688,7 +688,7 @@ class AddN(PrimitiveWithInfer):
...
@@ -688,7 +688,7 @@ class AddN(PrimitiveWithInfer):
return
shp0
return
shp0
def
infer_dtype
(
self
,
inputs
):
def
infer_dtype
(
self
,
inputs
):
cls_name
=
self
.
prim_name
()
cls_name
=
self
.
name
validator
.
check_value_type
(
"inputs"
,
inputs
,
[
tuple
,
list
],
cls_name
)
validator
.
check_value_type
(
"inputs"
,
inputs
,
[
tuple
,
list
],
cls_name
)
validator
.
check_integer
(
"inputs"
,
len
(
inputs
),
1
,
Rel
.
GE
,
cls_name
)
validator
.
check_integer
(
"inputs"
,
len
(
inputs
),
1
,
Rel
.
GE
,
cls_name
)
args
=
{}
args
=
{}
...
@@ -718,7 +718,7 @@ class Neg(PrimitiveWithInfer):
...
@@ -718,7 +718,7 @@ class Neg(PrimitiveWithInfer):
return
input_x
return
input_x
def
infer_dtype
(
self
,
input_x
):
def
infer_dtype
(
self
,
input_x
):
validator
.
check_tensor_type_same
({
"input_x"
:
input_x
},
mstype
.
number_type
,
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
"input_x"
:
input_x
},
mstype
.
number_type
,
self
.
name
)
return
input_x
return
input_x
...
@@ -809,7 +809,7 @@ class Square(PrimitiveWithInfer):
...
@@ -809,7 +809,7 @@ class Square(PrimitiveWithInfer):
return
x_shape
return
x_shape
def
infer_dtype
(
self
,
x_type
):
def
infer_dtype
(
self
,
x_type
):
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
mstype
.
number_type
,
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
mstype
.
number_type
,
self
.
name
)
return
x_type
return
x_type
...
@@ -838,7 +838,7 @@ class Rsqrt(PrimitiveWithInfer):
...
@@ -838,7 +838,7 @@ class Rsqrt(PrimitiveWithInfer):
return
x_shape
return
x_shape
def
infer_dtype
(
self
,
x_type
):
def
infer_dtype
(
self
,
x_type
):
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
mstype
.
number_type
,
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
mstype
.
number_type
,
self
.
name
)
return
x_type
return
x_type
...
@@ -867,7 +867,7 @@ class Sqrt(PrimitiveWithInfer):
...
@@ -867,7 +867,7 @@ class Sqrt(PrimitiveWithInfer):
return
x_shape
return
x_shape
def
infer_dtype
(
self
,
x_type
):
def
infer_dtype
(
self
,
x_type
):
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
mstype
.
number_type
,
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
mstype
.
number_type
,
self
.
name
)
return
x_type
return
x_type
...
@@ -897,14 +897,29 @@ class Reciprocal(PrimitiveWithInfer):
...
@@ -897,14 +897,29 @@ class Reciprocal(PrimitiveWithInfer):
return
x
return
x
def
infer_dtype
(
self
,
x
):
def
infer_dtype
(
self
,
x
):
validator
.
check_subclass
(
"x"
,
x
,
mstype
.
tensor
,
self
.
prim_name
()
)
validator
.
check_subclass
(
"x"
,
x
,
mstype
.
tensor
,
self
.
name
)
return
x
return
x
class
Pow
(
PrimitiveWithInfer
):
class
Pow
(
_MathBinaryOp
):
"""
"""
Computes a tensor to the power of the second input.
Computes a tensor to the power of the second input.
The first input must be a tensor, and the second input should be a tensor or a number.
When the inputs are two tensors, the shapes of them could be broadcast,
and the data types of them should be the same.
When the inputs are one tensor and one scalar, the scalar could not be a parameter,
only could be a constant, and the type of the scalar is the same as the data type of the tensor.
Inputs:
- **input_x** (Union[Tensor]) - The first input is a tensor whose data type is number.
- **input_y** (Union[Tensor, Number]) - The second input is a tensor whose data type is same as 'input_x' or
a number.
Outputs:
Tensor, the shape is same as the shape after broadcasting, and the data type is same as 'input_x'.
Inputs:
Inputs:
- **input_x** (Tensor) - The input tensor.
- **input_x** (Tensor) - The input tensor.
- **input_y** (Union[Tensor, Number]) - The exponent part. If exponent is a tensor, its shape must be able to
- **input_y** (Union[Tensor, Number]) - The exponent part. If exponent is a tensor, its shape must be able to
...
@@ -927,17 +942,6 @@ class Pow(PrimitiveWithInfer):
...
@@ -927,17 +942,6 @@ class Pow(PrimitiveWithInfer):
[1.0, 16.0, 64.0]
[1.0, 16.0, 64.0]
"""
"""
@
prim_attr_register
def
__init__
(
self
):
"""init Multiply"""
def
infer_shape
(
self
,
x
,
power
):
return
x
def
infer_dtype
(
self
,
x
,
power
):
validator
.
check_tensor_type_same
({
"x"
:
x
},
mstype
.
number_type
,
self
.
prim_name
())
return
x
class
Exp
(
PrimitiveWithInfer
):
class
Exp
(
PrimitiveWithInfer
):
"""
"""
...
@@ -965,7 +969,7 @@ class Exp(PrimitiveWithInfer):
...
@@ -965,7 +969,7 @@ class Exp(PrimitiveWithInfer):
return
x_shape
return
x_shape
def
infer_dtype
(
self
,
x_type
):
def
infer_dtype
(
self
,
x_type
):
validator
.
check_subclass
(
"x"
,
x_type
,
mstype
.
tensor
,
self
.
prim_name
()
)
validator
.
check_subclass
(
"x"
,
x_type
,
mstype
.
tensor
,
self
.
name
)
return
x_type
return
x_type
...
@@ -994,7 +998,7 @@ class Log(PrimitiveWithInfer):
...
@@ -994,7 +998,7 @@ class Log(PrimitiveWithInfer):
return
x
return
x
def
infer_dtype
(
self
,
x
):
def
infer_dtype
(
self
,
x
):
validator
.
check_subclass
(
"x"
,
x
,
mstype
.
tensor
,
self
.
prim_name
()
)
validator
.
check_subclass
(
"x"
,
x
,
mstype
.
tensor
,
self
.
name
)
return
x
return
x
...
@@ -1176,7 +1180,7 @@ class Floor(PrimitiveWithInfer):
...
@@ -1176,7 +1180,7 @@ class Floor(PrimitiveWithInfer):
return
x_shape
return
x_shape
def
infer_dtype
(
self
,
x_dtype
):
def
infer_dtype
(
self
,
x_dtype
):
validator
.
check_tensor_type_same
({
"x"
:
x_dtype
},
mstype
.
float_type
,
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
"x"
:
x_dtype
},
mstype
.
float_type
,
self
.
name
)
return
x_dtype
return
x_dtype
...
@@ -1231,7 +1235,7 @@ class Acosh(PrimitiveWithInfer):
...
@@ -1231,7 +1235,7 @@ class Acosh(PrimitiveWithInfer):
return
x
return
x
def
infer_dtype
(
self
,
x
):
def
infer_dtype
(
self
,
x
):
validator
.
check_tensor_type_same
({
'x'
:
x
},
mstype
.
number_type
,
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
'x'
:
x
},
mstype
.
number_type
,
self
.
name
)
return
x
return
x
...
@@ -1247,7 +1251,7 @@ class _LogicBinaryOp(_BinaryOp):
...
@@ -1247,7 +1251,7 @@ class _LogicBinaryOp(_BinaryOp):
return
mstype
.
tensor_type
(
mstype
.
bool_
)
return
mstype
.
tensor_type
(
mstype
.
bool_
)
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
return
_LogicBinaryOp
.
do_infer_dtype
(
x_dtype
,
y_dtype
,
prim_name
=
self
.
prim_name
()
)
return
_LogicBinaryOp
.
do_infer_dtype
(
x_dtype
,
y_dtype
,
prim_name
=
self
.
name
)
class
Equal
(
_LogicBinaryOp
):
class
Equal
(
_LogicBinaryOp
):
...
@@ -1283,7 +1287,7 @@ class Equal(_LogicBinaryOp):
...
@@ -1283,7 +1287,7 @@ class Equal(_LogicBinaryOp):
"""
"""
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
return
_LogicBinaryOp
.
do_infer_dtype
(
x_dtype
,
y_dtype
,
mstype
.
number_type
+
(
mstype
.
bool_
,),
self
.
prim_name
()
)
return
_LogicBinaryOp
.
do_infer_dtype
(
x_dtype
,
y_dtype
,
mstype
.
number_type
+
(
mstype
.
bool_
,),
self
.
name
)
class
EqualCount
(
PrimitiveWithInfer
):
class
EqualCount
(
PrimitiveWithInfer
):
...
@@ -1318,7 +1322,7 @@ class EqualCount(PrimitiveWithInfer):
...
@@ -1318,7 +1322,7 @@ class EqualCount(PrimitiveWithInfer):
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
args
=
{
'x'
:
x_dtype
,
'y'
:
y_dtype
}
args
=
{
'x'
:
x_dtype
,
'y'
:
y_dtype
}
validator
.
check_tensor_type_same
(
args
,
mstype
.
number_type
+
(
mstype
.
bool_
,),
self
.
prim_name
()
)
validator
.
check_tensor_type_same
(
args
,
mstype
.
number_type
+
(
mstype
.
bool_
,),
self
.
name
)
return
x_dtype
return
x_dtype
...
@@ -1355,7 +1359,7 @@ class NotEqual(_LogicBinaryOp):
...
@@ -1355,7 +1359,7 @@ class NotEqual(_LogicBinaryOp):
"""
"""
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
return
_LogicBinaryOp
.
do_infer_dtype
(
x_dtype
,
y_dtype
,
mstype
.
number_type
+
(
mstype
.
bool_
,),
self
.
prim_name
()
)
return
_LogicBinaryOp
.
do_infer_dtype
(
x_dtype
,
y_dtype
,
mstype
.
number_type
+
(
mstype
.
bool_
,),
self
.
name
)
class
Greater
(
_LogicBinaryOp
):
class
Greater
(
_LogicBinaryOp
):
...
@@ -1491,7 +1495,7 @@ class LogicalNot(PrimitiveWithInfer):
...
@@ -1491,7 +1495,7 @@ class LogicalNot(PrimitiveWithInfer):
return
x_shape
return
x_shape
def
infer_dtype
(
self
,
x_dtype
):
def
infer_dtype
(
self
,
x_dtype
):
validator
.
check_tensor_type_same
({
"x"
:
x_dtype
},
[
mstype
.
bool_
],
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
"x"
:
x_dtype
},
[
mstype
.
bool_
],
self
.
name
)
return
mstype
.
tensor_type
(
mstype
.
bool_
)
return
mstype
.
tensor_type
(
mstype
.
bool_
)
...
@@ -1521,7 +1525,7 @@ class LogicalAnd(_LogicBinaryOp):
...
@@ -1521,7 +1525,7 @@ class LogicalAnd(_LogicBinaryOp):
"""
"""
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
return
_LogicBinaryOp
.
do_infer_dtype
(
x_dtype
,
y_dtype
,
(
mstype
.
bool_
,),
self
.
prim_name
()
)
return
_LogicBinaryOp
.
do_infer_dtype
(
x_dtype
,
y_dtype
,
(
mstype
.
bool_
,),
self
.
name
)
class
LogicalOr
(
_LogicBinaryOp
):
class
LogicalOr
(
_LogicBinaryOp
):
...
@@ -1550,7 +1554,7 @@ class LogicalOr(_LogicBinaryOp):
...
@@ -1550,7 +1554,7 @@ class LogicalOr(_LogicBinaryOp):
"""
"""
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
def
infer_dtype
(
self
,
x_dtype
,
y_dtype
):
return
_LogicBinaryOp
.
do_infer_dtype
(
x_dtype
,
y_dtype
,
(
mstype
.
bool_
,),
self
.
prim_name
()
)
return
_LogicBinaryOp
.
do_infer_dtype
(
x_dtype
,
y_dtype
,
(
mstype
.
bool_
,),
self
.
name
)
class
IsNan
(
PrimitiveWithInfer
):
class
IsNan
(
PrimitiveWithInfer
):
"""
"""
...
@@ -1699,13 +1703,13 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
...
@@ -1699,13 +1703,13 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
self
.
add_prim_attr
(
"_side_effect_flag"
,
True
)
self
.
add_prim_attr
(
"_side_effect_flag"
,
True
)
def
infer_shape
(
self
,
x_shape
):
def
infer_shape
(
self
,
x_shape
):
cls_name
=
self
.
prim_name
()
cls_name
=
self
.
name
validator
.
check_integer
(
"len(x_shape)"
,
len
(
x_shape
),
1
,
Rel
.
EQ
,
cls_name
)
validator
.
check_integer
(
"len(x_shape)"
,
len
(
x_shape
),
1
,
Rel
.
EQ
,
cls_name
)
validator
.
check_integer
(
"x_shape[0]"
,
x_shape
[
0
],
8
,
Rel
.
EQ
,
cls_name
)
validator
.
check_integer
(
"x_shape[0]"
,
x_shape
[
0
],
8
,
Rel
.
EQ
,
cls_name
)
return
[
8
]
return
[
8
]
def
infer_dtype
(
self
,
x_dtype
):
def
infer_dtype
(
self
,
x_dtype
):
validator
.
check_tensor_type_same
({
'x'
:
x_dtype
},
[
mstype
.
float32
],
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
'x'
:
x_dtype
},
[
mstype
.
float32
],
self
.
name
)
return
mstype
.
float32
return
mstype
.
float32
...
@@ -1741,13 +1745,13 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
...
@@ -1741,13 +1745,13 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
self
.
add_prim_attr
(
"_side_effect_flag"
,
True
)
self
.
add_prim_attr
(
"_side_effect_flag"
,
True
)
def
infer_shape
(
self
,
x_shape
):
def
infer_shape
(
self
,
x_shape
):
cls_name
=
self
.
prim_name
()
cls_name
=
self
.
name
validator
.
check_integer
(
"len(x_shape)"
,
len
(
x_shape
),
1
,
Rel
.
EQ
,
cls_name
)
validator
.
check_integer
(
"len(x_shape)"
,
len
(
x_shape
),
1
,
Rel
.
EQ
,
cls_name
)
validator
.
check_integer
(
"x_shape[0]"
,
x_shape
[
0
],
8
,
Rel
.
EQ
,
cls_name
)
validator
.
check_integer
(
"x_shape[0]"
,
x_shape
[
0
],
8
,
Rel
.
EQ
,
cls_name
)
return
[
8
]
return
[
8
]
def
infer_dtype
(
self
,
x_dtype
):
def
infer_dtype
(
self
,
x_dtype
):
validator
.
check_tensor_type_same
({
'x'
:
x_dtype
},
[
mstype
.
float32
],
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
'x'
:
x_dtype
},
[
mstype
.
float32
],
self
.
name
)
return
mstype
.
float32
return
mstype
.
float32
...
@@ -1775,7 +1779,7 @@ class Cos(PrimitiveWithInfer):
...
@@ -1775,7 +1779,7 @@ class Cos(PrimitiveWithInfer):
return
x
return
x
def
infer_dtype
(
self
,
x
):
def
infer_dtype
(
self
,
x
):
validator
.
check_tensor_type_same
({
'x'
:
x
},
mstype
.
number_type
,
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
'x'
:
x
},
mstype
.
number_type
,
self
.
name
)
return
x
return
x
...
@@ -1803,7 +1807,7 @@ class ACos(PrimitiveWithInfer):
...
@@ -1803,7 +1807,7 @@ class ACos(PrimitiveWithInfer):
return
x
return
x
def
infer_dtype
(
self
,
x
):
def
infer_dtype
(
self
,
x
):
validator
.
check_tensor_type_same
({
'x'
:
x
},
mstype
.
number_type
,
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
'x'
:
x
},
mstype
.
number_type
,
self
.
name
)
return
x
return
x
...
@@ -1831,7 +1835,7 @@ class Sin(PrimitiveWithInfer):
...
@@ -1831,7 +1835,7 @@ class Sin(PrimitiveWithInfer):
return
x
return
x
def
infer_dtype
(
self
,
x
):
def
infer_dtype
(
self
,
x
):
validator
.
check_tensor_type_same
({
'x'
:
x
},
mstype
.
number_type
,
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
'x'
:
x
},
mstype
.
number_type
,
self
.
name
)
return
x
return
x
...
@@ -1876,11 +1880,11 @@ class NMSWithMask(PrimitiveWithInfer):
...
@@ -1876,11 +1880,11 @@ class NMSWithMask(PrimitiveWithInfer):
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
,
iou_threshold
=
0.5
):
def
__init__
(
self
,
iou_threshold
=
0.5
):
"""Init NMSWithMask"""
"""Init NMSWithMask"""
validator
.
check_value_type
(
"iou_threshold"
,
iou_threshold
,
[
float
],
self
.
prim_name
()
)
validator
.
check_value_type
(
"iou_threshold"
,
iou_threshold
,
[
float
],
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'bboxes'
],
outputs
=
[
'selected_boxes'
,
'selected_idx'
,
'selected_mask'
])
self
.
init_prim_io_names
(
inputs
=
[
'bboxes'
],
outputs
=
[
'selected_boxes'
,
'selected_idx'
,
'selected_mask'
])
def
infer_shape
(
self
,
bboxes_shape
):
def
infer_shape
(
self
,
bboxes_shape
):
cls_name
=
self
.
prim_name
()
cls_name
=
self
.
name
validator
.
check_integer
(
"bboxes rank"
,
len
(
bboxes_shape
),
2
,
Rel
.
EQ
,
cls_name
)
validator
.
check_integer
(
"bboxes rank"
,
len
(
bboxes_shape
),
2
,
Rel
.
EQ
,
cls_name
)
validator
.
check_integer
(
"bboxes.shape()[0]"
,
bboxes_shape
[
0
],
0
,
Rel
.
GT
,
cls_name
)
validator
.
check_integer
(
"bboxes.shape()[0]"
,
bboxes_shape
[
0
],
0
,
Rel
.
GT
,
cls_name
)
validator
.
check_integer
(
"bboxes.shape()[1]"
,
bboxes_shape
[
1
],
5
,
Rel
.
EQ
,
cls_name
)
validator
.
check_integer
(
"bboxes.shape()[1]"
,
bboxes_shape
[
1
],
5
,
Rel
.
EQ
,
cls_name
)
...
@@ -1888,7 +1892,7 @@ class NMSWithMask(PrimitiveWithInfer):
...
@@ -1888,7 +1892,7 @@ class NMSWithMask(PrimitiveWithInfer):
return
(
bboxes_shape
,
(
num
,),
(
num
,))
return
(
bboxes_shape
,
(
num
,),
(
num
,))
def
infer_dtype
(
self
,
bboxes_dtype
):
def
infer_dtype
(
self
,
bboxes_dtype
):
validator
.
check_tensor_type_same
({
"bboxes"
:
bboxes_dtype
},
[
mstype
.
float16
,
mstype
.
float32
],
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
"bboxes"
:
bboxes_dtype
},
[
mstype
.
float16
,
mstype
.
float32
],
self
.
name
)
return
(
bboxes_dtype
,
mstype
.
int32
,
mstype
.
bool_
)
return
(
bboxes_dtype
,
mstype
.
int32
,
mstype
.
bool_
)
...
@@ -1917,7 +1921,7 @@ class Abs(PrimitiveWithInfer):
...
@@ -1917,7 +1921,7 @@ class Abs(PrimitiveWithInfer):
return
x_shape
return
x_shape
def
infer_dtype
(
self
,
x_type
):
def
infer_dtype
(
self
,
x_type
):
validator
.
check_tensor_type_same
({
'x'
:
x_type
},
mstype
.
number_type
,
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
'x'
:
x_type
},
mstype
.
number_type
,
self
.
name
)
return
x_type
return
x_type
def
infer_value
(
self
,
x
):
def
infer_value
(
self
,
x
):
...
@@ -1959,7 +1963,7 @@ class Sign(PrimitiveWithInfer):
...
@@ -1959,7 +1963,7 @@ class Sign(PrimitiveWithInfer):
return
x_shape
return
x_shape
def
infer_dtype
(
self
,
x_dtype
):
def
infer_dtype
(
self
,
x_dtype
):
validator
.
check_tensor_type_same
({
'x'
:
x_dtype
},
mstype
.
number_type
,
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
'x'
:
x_dtype
},
mstype
.
number_type
,
self
.
name
)
return
x_dtype
return
x_dtype
...
@@ -1988,7 +1992,7 @@ class Round(PrimitiveWithInfer):
...
@@ -1988,7 +1992,7 @@ class Round(PrimitiveWithInfer):
return
x_shape
return
x_shape
def
infer_dtype
(
self
,
x_type
):
def
infer_dtype
(
self
,
x_type
):
validator
.
check_tensor_type_same
({
'x'
:
x_type
},
mstype
.
number_type
,
self
.
prim_name
()
)
validator
.
check_tensor_type_same
({
'x'
:
x_type
},
mstype
.
number_type
,
self
.
name
)
return
x_type
return
x_type
...
...
mindspore/ops/primitive.py
浏览文件 @
5d0801ed
...
@@ -194,9 +194,6 @@ class PrimitiveWithInfer(Primitive):
...
@@ -194,9 +194,6 @@ class PrimitiveWithInfer(Primitive):
Primitive
.
__init__
(
self
,
name
)
Primitive
.
__init__
(
self
,
name
)
self
.
set_prim_type
(
prim_type
.
py_infer_shape
)
self
.
set_prim_type
(
prim_type
.
py_infer_shape
)
def
prim_name
(
self
):
return
self
.
__class__
.
__name__
def
_clone
(
self
):
def
_clone
(
self
):
"""
"""
Deeply clones the primitive object.
Deeply clones the primitive object.
...
...
tests/ut/cpp/parallel/ops_info/pow_info_test.cc
浏览文件 @
5d0801ed
...
@@ -19,7 +19,7 @@
...
@@ -19,7 +19,7 @@
#include <vector>
#include <vector>
#include "common/common_test.h"
#include "common/common_test.h"
#include "parallel/strategy.h"
#include "parallel/strategy.h"
#include "parallel/ops_info/
elementary_function
_info.h"
#include "parallel/ops_info/
arithmetic
_info.h"
#include "parallel/device_manager.h"
#include "parallel/device_manager.h"
#include "parallel/step_parallel.h"
#include "parallel/step_parallel.h"
...
@@ -56,14 +56,14 @@ void TestPowInfo::SetUp() {
...
@@ -56,14 +56,14 @@ void TestPowInfo::SetUp() {
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
attr
;
std
::
unordered_map
<
std
::
string
,
ValuePtr
>
attr
;
Shapes
inputs_shape
=
{{
32
,
64
,
128
}};
Shapes
inputs_shape
=
{{
32
,
64
,
128
}
,
{
32
,
64
,
128
}
};
Shapes
outputs_shape
=
{{
32
,
64
,
128
}};
Shapes
outputs_shape
=
{{
32
,
64
,
128
}};
pow
=
std
::
make_shared
<
PowInfo
>
(
"pow_info"
,
inputs_shape
,
outputs_shape
,
attr
);
pow
=
std
::
make_shared
<
PowInfo
>
(
"pow_info"
,
inputs_shape
,
outputs_shape
,
attr
);
}
}
TEST_F
(
TestPowInfo
,
InferDevMatrixShape1
)
{
TEST_F
(
TestPowInfo
,
InferDevMatrixShape1
)
{
std
::
vector
<
Dimensions
>
inputs
=
{{
2
,
4
,
8
}};
std
::
vector
<
Dimensions
>
inputs
=
{{
2
,
4
,
8
}
,
{
2
,
4
,
8
}
};
StrategyPtr
strategy
=
NewStrategy
(
0
,
inputs
);
StrategyPtr
strategy
=
NewStrategy
(
0
,
inputs
);
pow
->
Init
(
strategy
);
pow
->
Init
(
strategy
);
...
@@ -74,7 +74,7 @@ TEST_F(TestPowInfo, InferDevMatrixShape1) {
...
@@ -74,7 +74,7 @@ TEST_F(TestPowInfo, InferDevMatrixShape1) {
}
}
TEST_F
(
TestPowInfo
,
InferSliceShape1
)
{
TEST_F
(
TestPowInfo
,
InferSliceShape1
)
{
std
::
vector
<
Dimensions
>
str
=
{{
2
,
4
,
8
}};
std
::
vector
<
Dimensions
>
str
=
{{
2
,
4
,
8
}
,
{
2
,
4
,
8
}
};
StrategyPtr
strategy
=
NewStrategy
(
0
,
str
);
StrategyPtr
strategy
=
NewStrategy
(
0
,
str
);
pow
->
Init
(
strategy
);
pow
->
Init
(
strategy
);
...
@@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) {
...
@@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) {
}
}
TEST_F
(
TestPowInfo
,
GetTensorLayout1
)
{
TEST_F
(
TestPowInfo
,
GetTensorLayout1
)
{
std
::
vector
<
Dimensions
>
str
=
{{
2
,
4
,
8
}};
std
::
vector
<
Dimensions
>
str
=
{{
2
,
4
,
8
}
,
{
2
,
4
,
8
}
};
StrategyPtr
strategy
=
NewStrategy
(
0
,
str
);
StrategyPtr
strategy
=
NewStrategy
(
0
,
str
);
pow
->
Init
(
strategy
);
pow
->
Init
(
strategy
);
...
@@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) {
...
@@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) {
}
}
TEST_F
(
TestPowInfo
,
GetForwardOp1
)
{
TEST_F
(
TestPowInfo
,
GetForwardOp1
)
{
std
::
vector
<
Dimensions
>
inputs
=
{{
2
,
4
,
8
}};
std
::
vector
<
Dimensions
>
inputs
=
{{
2
,
4
,
8
}
,
{
2
,
4
,
8
}
};
StrategyPtr
strategy
=
NewStrategy
(
0
,
inputs
);
StrategyPtr
strategy
=
NewStrategy
(
0
,
inputs
);
pow
->
Init
(
strategy
);
pow
->
Init
(
strategy
);
...
@@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) {
...
@@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) {
}
}
TEST_F
(
TestPowInfo
,
GetMirrorOPs1
)
{
TEST_F
(
TestPowInfo
,
GetMirrorOPs1
)
{
std
::
vector
<
Dimensions
>
inputs
=
{{
2
,
4
,
8
}};
std
::
vector
<
Dimensions
>
inputs
=
{{
2
,
4
,
8
}
,
{
2
,
4
,
8
}
};
StrategyPtr
strategy
=
NewStrategy
(
0
,
inputs
);
StrategyPtr
strategy
=
NewStrategy
(
0
,
inputs
);
pow
->
Init
(
strategy
);
pow
->
Init
(
strategy
);
...
@@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) {
...
@@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) {
}
}
TEST_F
(
TestPowInfo
,
CheckStrategy2
)
{
TEST_F
(
TestPowInfo
,
CheckStrategy2
)
{
std
::
vector
<
Dimensions
>
inputs
=
{{
2
,
4
,
8
,
16
}};
std
::
vector
<
Dimensions
>
inputs
=
{{
2
,
4
,
8
,
16
}
,
{
2
,
4
,
8
,
16
}
};
StrategyPtr
strategy
=
NewStrategy
(
0
,
inputs
);
StrategyPtr
strategy
=
NewStrategy
(
0
,
inputs
);
Status
ret
=
pow
->
Init
(
strategy
);
Status
ret
=
pow
->
Init
(
strategy
);
...
@@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) {
...
@@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) {
}
}
TEST_F
(
TestPowInfo
,
CheckStrategy3
)
{
TEST_F
(
TestPowInfo
,
CheckStrategy3
)
{
std
::
vector
<
Dimensions
>
inputs
=
{{
2
,
4
,
8
}};
std
::
vector
<
Dimensions
>
inputs
=
{{
2
,
4
,
8
}
,
{
2
,
4
,
8
}
};
StrategyPtr
strategy
=
NewStrategy
(
0
,
inputs
);
StrategyPtr
strategy
=
NewStrategy
(
0
,
inputs
);
Status
ret
=
pow
->
Init
(
strategy
);
Status
ret
=
pow
->
Init
(
strategy
);
...
...
tests/ut/python/ops/test_math_ops.py
浏览文件 @
5d0801ed
...
@@ -82,9 +82,10 @@ def test_sqrt():
...
@@ -82,9 +82,10 @@ def test_sqrt():
def
test_pow
():
def
test_pow
():
""" test_pow """
""" test_pow """
input_tensor
=
Tensor
(
np
.
array
([[
2
,
2
],
[
3
,
3
]]))
input_tensor
=
Tensor
(
np
.
array
([[
2
,
2
],
[
3
,
3
]]))
power
=
Tensor
(
np
.
array
(
3.0
,
np
.
int64
))
testpow
=
P
.
Pow
()
testpow
=
P
.
Pow
()
expect
=
np
.
array
([[
8
,
8
],
[
27
,
27
]])
expect
=
np
.
array
([[
8
,
8
],
[
27
,
27
]])
result
=
testpow
(
input_tensor
,
3.0
)
result
=
testpow
(
input_tensor
,
power
)
assert
np
.
all
(
result
.
asnumpy
()
==
expect
)
assert
np
.
all
(
result
.
asnumpy
()
==
expect
)
...
...
tests/ut/python/ops/test_ops.py
浏览文件 @
5d0801ed
...
@@ -224,11 +224,15 @@ test_case_math_ops = [
...
@@ -224,11 +224,15 @@ test_case_math_ops = [
'block'
:
P
.
Minimum
(),
'block'
:
P
.
Minimum
(),
'desc_inputs'
:
[[
2
,
3
,
3
,
5
],
[
2
,
3
,
3
,
5
]],
'desc_inputs'
:
[[
2
,
3
,
3
,
5
],
[
2
,
3
,
3
,
5
]],
'desc_bprop'
:
[[
2
,
3
,
3
,
5
]]}),
'desc_bprop'
:
[[
2
,
3
,
3
,
5
]]}),
(
'Pow'
,
{
(
'Pow
_0
'
,
{
'block'
:
P
.
Pow
(),
'block'
:
P
.
Pow
(),
'desc_const'
:
[
2.0
],
'desc_const'
:
[
2.0
],
'desc_inputs'
:
[[
2
,
3
,
3
,
5
]],
'desc_inputs'
:
[[
2
,
3
,
3
,
5
]],
'desc_bprop'
:
[[
2
,
3
,
3
,
5
]]}),
'desc_bprop'
:
[[
2
,
3
,
3
,
5
]]}),
(
'Pow_1'
,
{
'block'
:
P
.
Pow
(),
'desc_inputs'
:
[[
3
,
5
],
[
2
,
3
,
3
,
5
]],
'desc_bprop'
:
[[
2
,
3
,
3
,
5
]]}),
(
'Exp'
,
{
(
'Exp'
,
{
'block'
:
P
.
Exp
(),
'block'
:
P
.
Exp
(),
'desc_inputs'
:
[[
2
,
3
]],
'desc_inputs'
:
[[
2
,
3
]],
...
...
tests/ut/python/parallel/test_element_wise_function.py
浏览文件 @
5d0801ed
...
@@ -59,7 +59,7 @@ def test_matmul_pow():
...
@@ -59,7 +59,7 @@ def test_matmul_pow():
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
)
context
.
set_auto_parallel_context
(
device_num
=
8
,
global_rank
=
0
)
strategy1
=
((
2
,
2
),
(
2
,
2
))
strategy1
=
((
2
,
2
),
(
2
,
2
))
strategy2
=
((
4
,
2
),
)
strategy2
=
((
4
,
2
),
()
)
net
=
GradWrap
(
NetWithLoss
(
Net
(
strategy1
,
strategy2
)))
net
=
GradWrap
(
NetWithLoss
(
Net
(
strategy1
,
strategy2
)))
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
)
context
.
set_auto_parallel_context
(
parallel_mode
=
"semi_auto_parallel"
)
...
...
tests/vm_impl/math_ops_vm_impl.py
浏览文件 @
5d0801ed
...
@@ -117,6 +117,7 @@ def vm_impl_pow(self):
...
@@ -117,6 +117,7 @@ def vm_impl_pow(self):
"""Generate vm_impl function for Pow."""
"""Generate vm_impl function for Pow."""
def
vm_impl
(
x
,
y
):
def
vm_impl
(
x
,
y
):
x
=
x
.
asnumpy
()
x
=
x
.
asnumpy
()
y
=
y
.
asnumpy
()
res
=
vm
.
power
(
x
,
y
)
res
=
vm
.
power
(
x
,
y
)
return
Tensor
(
res
)
return
Tensor
(
res
)
return
vm_impl
return
vm_impl
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录