magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 5841fe01
Authored Apr 08, 2020 by buxue
Support pow's second input could be tensor and fix bug in bprop of pow
Parent: 7cec2852

Showing 13 changed files with 95 additions and 139 deletions (+95, -139).
- mindspore/ccsrc/parallel/ops_info/arithmetic_info.h (+7, -0)
- mindspore/ccsrc/parallel/ops_info/elementary_function_info.cc (+0, -47)
- mindspore/ccsrc/parallel/ops_info/elementary_function_info.h (+0, -10)
- mindspore/nn/layer/pooling.py (+1, -1)
- mindspore/ops/_grad/grad_math_ops.py (+5, -6)
- mindspore/ops/operations/array_ops.py (+2, -2)
- mindspore/ops/operations/math_ops.py (+62, -58)
- mindspore/ops/primitive.py (+0, -3)
- tests/ut/cpp/parallel/ops_info/pow_info_test.cc (+9, -9)
- tests/ut/python/ops/test_math_ops.py (+2, -1)
- tests/ut/python/ops/test_ops.py (+5, -1)
- tests/ut/python/parallel/test_element_wise_function.py (+1, -1)
- tests/vm_impl/math_ops_vm_impl.py (+1, -0)
mindspore/ccsrc/parallel/ops_info/arithmetic_info.h

```diff
@@ -98,6 +98,13 @@ class FloorDivInfo : public ArithmeticBase {
   ~FloorDivInfo() override = default;
 };
 
+class PowInfo : public ArithmeticBase {
+ public:
+  PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
+          const PrimitiveAttrs& attrs)
+      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
+  ~PowInfo() override = default;
+};
+
 class GreaterInfo : public ArithmeticBase {
  public:
   GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
```
mindspore/ccsrc/parallel/ops_info/elementary_function_info.cc (deleted, file mode 100644 → 0)

```cpp
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "parallel/ops_info/elementary_function_info.h"

namespace mindspore {
namespace parallel {
Status PowInfo::InferMirrorOps() {
  mirror_ops_.clear();
  Shape tensor_map = inputs_tensor_map_[0];
  std::vector<Group> group;
  if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
    MS_LOG(ERROR) << name_ << " : Create group failed.";
    return FAILED;
  }

  OperatorVector mirror_op;
  OperatorVector op_for_value;
  if (group.empty()) {
    MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
    return SUCCESS;
  } else {
    mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
    mirror_ops_.push_back(mirror_op);
    mirror_ops_.push_back(op_for_value);
    std::string group_name = group[0].name();
    MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
  }

  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore
```
mindspore/ccsrc/parallel/ops_info/elementary_function_info.h

```diff
@@ -27,16 +27,6 @@
 namespace mindspore {
 namespace parallel {
-class PowInfo : public ActivationOther {
- public:
-  PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
-          const PrimitiveAttrs& attrs)
-      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
-  ~PowInfo() override = default;
-
- protected:
-  Status InferMirrorOps() override;
-};
-
 class ExpInfo : public ActivationOther {
  public:
   ExpInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
           const PrimitiveAttrs& attrs)
```
mindspore/nn/layer/pooling.py

```diff
@@ -58,7 +58,7 @@ class _PoolNd(Cell):
         pass
 
     def extend_repr(self):
-        return 'kernel_size={kernel_size}, strides={strides}, pad_mode={pad_mode}'.format(**self.__dict__)
+        return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__)
 
 
 class MaxPool2d(_PoolNd):
```
mindspore/ops/_grad/grad_math_ops.py

```diff
@@ -336,14 +336,13 @@ def get_bprop_log(self):
 @bprop_getters.register(P.Pow)
 def get_bprop_pow(self):
     """Grad definition for `Pow` operation."""
-    pow_ = P.Pow()
-    cast = P.Cast()
-    dtype = P.DType()
+    pow_op = P.Pow()
+    ln = P.Log()
 
     def bprop(x, power, out, dout):
-        g = cast(F.tuple_to_array((power,)), dtype(x)) * pow_(x, power - 1.0)
-        dx = g * dout
-        return dx, 0
+        dx = power * pow_op(x, power - 1.0) * dout
+        dpower = pow_op(x, power) * ln(x) * dout
+        return dx, dpower
 
     return bprop
```
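The old bprop cast the exponent to a constant and returned a zero gradient for it; the rewritten version also differentiates with respect to the exponent. A standalone NumPy sketch (not part of the commit) checking the two identities behind the new formulas against finite differences:

```python
# Sketch only: verifies d/dx x^p = p*x^(p-1) and d/dp x^p = x^p*ln(x),
# the calculus identities the rewritten bprop relies on. Pure NumPy.
import numpy as np

def pow_bprop(x, power, dout):
    # Mirrors the new bprop: gradients w.r.t. both the base and the exponent.
    dx = power * np.power(x, power - 1.0) * dout
    dpower = np.power(x, power) * np.log(x) * dout
    return dx, dpower

x = np.array([2.0, 3.0])
p = np.array([3.0, 0.5])
dout = np.ones_like(x)
dx, dp = pow_bprop(x, p, dout)

eps = 1e-6  # central finite differences as a numeric reference
num_dx = (np.power(x + eps, p) - np.power(x - eps, p)) / (2 * eps)
num_dp = (np.power(x, p + eps) - np.power(x, p - eps)) / (2 * eps)
assert np.allclose(dx, num_dx, rtol=1e-4)
assert np.allclose(dp, num_dp, rtol=1e-4)
```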
mindspore/ops/operations/array_ops.py

```diff
@@ -1097,7 +1097,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
         axis = self.axis
         x_rank = len(x_shape)
         validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
-        ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
+        ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
         return ouput_shape, ouput_shape
 
     def infer_dtype(self, x_dtype):
@@ -1143,7 +1143,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
         axis = self.axis
         x_rank = len(x_shape)
         validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
-        ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
+        ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
         return ouput_shape, ouput_shape
 
     def infer_dtype(self, x_dtype):
```
mindspore/ops/operations/math_ops.py

```diff
@@ -74,7 +74,7 @@ class _BinaryOp(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
 
     def infer_shape(self, x_shape, y_shape):
-        return _get_broadcast_shape(x_shape, y_shape, self.prim_name())
+        return _get_broadcast_shape(x_shape, y_shape, self.name)
 
 
 class _MathBinaryOp(_BinaryOp):
@@ -89,7 +89,7 @@ class _MathBinaryOp(_BinaryOp):
         return x_dtype
 
     def infer_dtype(self, x_dtype, y_dtype):
-        return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.prim_name())
+        return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.name)
 
 
 class TensorAdd(_MathBinaryOp):
@@ -158,7 +158,7 @@ class AssignAdd(PrimitiveWithInfer):
     def infer_dtype(self, variable, value):
         args = {"value": value}
-        validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name())
+        validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
         return value
@@ -201,7 +201,7 @@ class AssignSub(PrimitiveWithInfer):
     def infer_dtype(self, variable, value):
         args = {"value": value}
-        validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name())
+        validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
         return value
@@ -222,16 +222,16 @@ class _Reduce(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, keep_dims=False):
         """init Reduce"""
-        validator.check_value_type('keep_dims', keep_dims, [bool], self.prim_name())
+        validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
         self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])
 
     def do_infer(self, input_x, axis, valid_dtype=mstype.number_type):
         axis_v = axis['value']
         input_shp = input_x['shape']
         args = {'input_x': input_x['dtype']}
-        validator.check_tensor_type_same(args, valid_dtype, self.prim_name())
-        input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.prim_name())
+        validator.check_tensor_type_same(args, valid_dtype, self.name)
+        input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
         return {'shape': input_shp,
                 'dtype': input_x['dtype'],
                 'value': None}
@@ -466,7 +466,7 @@ class CumProd(PrimitiveWithInfer):
     """
     @prim_attr_register
     def __init__(self, exclusive=False, reverse=False):
-        cls_name = self.prim_name()
+        cls_name = self.name
         self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name)
         self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name)
@@ -474,7 +474,7 @@ class CumProd(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type, axis_type):
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name)
         validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
         return x_type
@@ -510,7 +510,7 @@ class MatMul(PrimitiveWithInfer):
     def __init__(self, transpose_a=False, transpose_b=False):
         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
         self.__setattr_flag__ = True
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
         validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
@@ -521,7 +521,7 @@ class MatMul(PrimitiveWithInfer):
     def infer_shape(self, x, y):
         self.check_shape_size(x, y)
-        cls_name = self.prim_name()
+        cls_name = self.name
         # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
         for i in range(len(x) - 2):
             if x[i] != y[i]:
@@ -546,7 +546,7 @@ class MatMul(PrimitiveWithInfer):
     def infer_dtype(self, x, y):
         args = {"x": x, "y": y}
-        validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.prim_name())
+        validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
         return x
@@ -590,7 +590,7 @@ class BatchMatMul(MatMul):
     def __init__(self, transpose_a=False, transpose_b=False):
         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
         self.__setattr_flag__ = True
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
         validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
@@ -628,13 +628,13 @@ class CumSum(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, exclusive=False, reverse=False):
         """init cumsum"""
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_value_type('exclusive', exclusive, [bool], cls_name)
         validator.check_value_type('reverse', reverse, [bool], cls_name)
         self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y'])
 
     def __infer__(self, x, axis):
-        cls_name = self.prim_name()
+        cls_name = self.name
         x_shp = x['shape']
         validator.check_value_type('axis', axis['value'], [int], cls_name)
         valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
@@ -679,7 +679,7 @@ class AddN(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
 
     def infer_shape(self, inputs):
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
         self.add_prim_attr('n', len(inputs))
         shp0 = inputs[0]
@@ -688,7 +688,7 @@ class AddN(PrimitiveWithInfer):
         return shp0
 
     def infer_dtype(self, inputs):
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
         validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
         args = {}
@@ -718,7 +718,7 @@ class Neg(PrimitiveWithInfer):
         return input_x
 
     def infer_dtype(self, input_x):
-        validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
         return input_x
@@ -809,7 +809,7 @@ class Square(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
         return x_type
@@ -838,7 +838,7 @@ class Rsqrt(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
         return x_type
@@ -867,7 +867,7 @@ class Sqrt(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
         return x_type
```
```diff
@@ -897,14 +897,29 @@ class Reciprocal(PrimitiveWithInfer):
         return x
 
     def infer_dtype(self, x):
-        validator.check_subclass("x", x, mstype.tensor, self.prim_name())
+        validator.check_subclass("x", x, mstype.tensor, self.name)
         return x
 
 
-class Pow(PrimitiveWithInfer):
+class Pow(_MathBinaryOp):
     """
     Computes a tensor to the power of the second input.
 
+    The first input must be a tensor, and the second input should be a tensor or a number.
+    When the inputs are two tensors, the shapes of them could be broadcast,
+    and the data types of them should be the same.
+    When the inputs are one tensor and one scalar, the scalar could not be a parameter,
+    only could be a constant, and the type of the scalar is the same as the data type of the tensor.
+
+    Inputs:
+        - **input_x** (Union[Tensor]) - The first input is a tensor whose data type is number.
+        - **input_y** (Union[Tensor, Number]) - The second input is a tensor whose data type is same as 'input_x' or
+          a number.
+
+    Outputs:
+        Tensor, the shape is same as the shape after broadcasting, and the data type is same as 'input_x'.
+
-    Inputs:
-        - **input_x** (Tensor) - The input tensor.
-        - **input_y** (Union[Tensor, Number]) - The exponent part. If exponent is a tensor, its shape must be able to
@@ -927,17 +942,6 @@ class Pow(PrimitiveWithInfer):
     [1.0, 16.0, 64.0]
     """
 
-    @prim_attr_register
-    def __init__(self):
-        """init Multiply"""
-
-    def infer_shape(self, x, power):
-        return x
-
-    def infer_dtype(self, x, power):
-        validator.check_tensor_type_same({"x": x}, mstype.number_type, self.prim_name())
-        return x
-
 
 class Exp(PrimitiveWithInfer):
     """
```
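With Pow now deriving from _MathBinaryOp, shape and dtype inference come from the shared broadcasting base class rather than the deleted per-class methods, which is what lets the exponent be a tensor. A hedged usage sketch (the input values are illustrative, chosen to reproduce the docstring's [1.0, 16.0, 64.0] output; requires a MindSpore build of this vintage):

```python
# Illustrative only: exercising Pow with a per-element tensor exponent.
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P

pow_op = P.Pow()
x = Tensor(np.array([1.0, 2.0, 4.0]), mindspore.float32)
y = Tensor(np.array([2.0, 4.0, 3.0]), mindspore.float32)  # broadcastable exponents
print(pow_op(x, y))  # expected [1., 16., 64.]  (1^2, 2^4, 4^3)
```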
```diff
@@ -965,7 +969,7 @@ class Exp(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type):
-        validator.check_subclass("x", x_type, mstype.tensor, self.prim_name())
+        validator.check_subclass("x", x_type, mstype.tensor, self.name)
         return x_type
@@ -994,7 +998,7 @@ class Log(PrimitiveWithInfer):
         return x
 
     def infer_dtype(self, x):
-        validator.check_subclass("x", x, mstype.tensor, self.prim_name())
+        validator.check_subclass("x", x, mstype.tensor, self.name)
         return x
@@ -1176,7 +1180,7 @@ class Floor(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.prim_name())
+        validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.name)
         return x_dtype
@@ -1231,7 +1235,7 @@ class Acosh(PrimitiveWithInfer):
         return x
 
     def infer_dtype(self, x):
-        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
         return x
@@ -1247,7 +1251,7 @@ class _LogicBinaryOp(_BinaryOp):
         return mstype.tensor_type(mstype.bool_)
 
     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.prim_name())
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.name)
 
 
 class Equal(_LogicBinaryOp):
@@ -1283,7 +1287,7 @@ class Equal(_LogicBinaryOp):
     """
 
     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name())
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
 
 
 class EqualCount(PrimitiveWithInfer):
@@ -1318,7 +1322,7 @@ class EqualCount(PrimitiveWithInfer):
     def infer_dtype(self, x_dtype, y_dtype):
         args = {'x': x_dtype, 'y': y_dtype}
-        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.prim_name())
+        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
         return x_dtype
@@ -1355,7 +1359,7 @@ class NotEqual(_LogicBinaryOp):
     """
 
     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name())
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
 
 
 class Greater(_LogicBinaryOp):
@@ -1491,7 +1495,7 @@ class LogicalNot(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.prim_name())
+        validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.name)
         return mstype.tensor_type(mstype.bool_)
@@ -1521,7 +1525,7 @@ class LogicalAnd(_LogicBinaryOp):
     """
 
     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name)
 
 
 class LogicalOr(_LogicBinaryOp):
@@ -1550,7 +1554,7 @@ class LogicalOr(_LogicBinaryOp):
     """
 
     def infer_dtype(self, x_dtype, y_dtype):
-        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())
+        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name)
 
 
 class IsNan(PrimitiveWithInfer):
     """
@@ -1699,13 +1703,13 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
         self.add_prim_attr("_side_effect_flag", True)
 
     def infer_shape(self, x_shape):
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
         validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
         return [8]
 
     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name())
+        validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.name)
         return mstype.float32
@@ -1741,13 +1745,13 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
         self.add_prim_attr("_side_effect_flag", True)
 
     def infer_shape(self, x_shape):
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
         validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
         return [8]
 
     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name())
+        validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.name)
         return mstype.float32
@@ -1775,7 +1779,7 @@ class Cos(PrimitiveWithInfer):
         return x
 
     def infer_dtype(self, x):
-        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
         return x
@@ -1803,7 +1807,7 @@ class ACos(PrimitiveWithInfer):
         return x
 
     def infer_dtype(self, x):
-        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
         return x
@@ -1831,7 +1835,7 @@ class Sin(PrimitiveWithInfer):
         return x
 
     def infer_dtype(self, x):
-        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
         return x
@@ -1876,11 +1880,11 @@ class NMSWithMask(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, iou_threshold=0.5):
         """Init NMSWithMask"""
-        validator.check_value_type("iou_threshold", iou_threshold, [float], self.prim_name())
+        validator.check_value_type("iou_threshold", iou_threshold, [float], self.name)
         self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask'])
 
     def infer_shape(self, bboxes_shape):
-        cls_name = self.prim_name()
+        cls_name = self.name
         validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
         validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
         validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
@@ -1888,7 +1892,7 @@ class NMSWithMask(PrimitiveWithInfer):
         return (bboxes_shape, (num,), (num,))
 
     def infer_dtype(self, bboxes_dtype):
-        validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.prim_name())
+        validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.name)
         return (bboxes_dtype, mstype.int32, mstype.bool_)
@@ -1917,7 +1921,7 @@ class Abs(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
         return x_type
 
     def infer_value(self, x):
@@ -1959,7 +1963,7 @@ class Sign(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_dtype):
-        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
         return x_dtype
@@ -1988,7 +1992,7 @@ class Round(PrimitiveWithInfer):
         return x_shape
 
     def infer_dtype(self, x_type):
-        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name())
+        validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
         return x_type
```
mindspore/ops/primitive.py

```diff
@@ -194,9 +194,6 @@ class PrimitiveWithInfer(Primitive):
         Primitive.__init__(self, name)
         self.set_prim_type(prim_type.py_infer_shape)
 
-    def prim_name(self):
-        return self.__class__.__name__
-
     def _clone(self):
         """
         Deeply clones the primitive object.
```
tests/ut/cpp/parallel/ops_info/pow_info_test.cc

Because PowInfo now derives from ArithmeticBase and takes two inputs, every test supplies a second input shape and a second strategy dimension.

```diff
@@ -19,7 +19,7 @@
 #include <vector>
 #include "common/common_test.h"
 #include "parallel/strategy.h"
-#include "parallel/ops_info/elementary_function_info.h"
+#include "parallel/ops_info/arithmetic_info.h"
 #include "parallel/device_manager.h"
 #include "parallel/step_parallel.h"
@@ -56,14 +56,14 @@ void TestPowInfo::SetUp() {
   std::unordered_map<std::string, ValuePtr> attr;
 
-  Shapes inputs_shape = {{32, 64, 128}};
+  Shapes inputs_shape = {{32, 64, 128}, {32, 64, 128}};
   Shapes outputs_shape = {{32, 64, 128}};
 
   pow = std::make_shared<PowInfo>("pow_info", inputs_shape, outputs_shape, attr);
 }
 
 TEST_F(TestPowInfo, InferDevMatrixShape1) {
-  std::vector<Dimensions> inputs = {{2, 4, 8}};
+  std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);
 
   pow->Init(strategy);
@@ -74,7 +74,7 @@ TEST_F(TestPowInfo, InferDevMatrixShape1) {
 }
 
 TEST_F(TestPowInfo, InferSliceShape1) {
-  std::vector<Dimensions> str = {{2, 4, 8}};
+  std::vector<Dimensions> str = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, str);
 
   pow->Init(strategy);
@@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) {
 }
 
 TEST_F(TestPowInfo, GetTensorLayout1) {
-  std::vector<Dimensions> str = {{2, 4, 8}};
+  std::vector<Dimensions> str = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, str);
 
   pow->Init(strategy);
@@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) {
 }
 
 TEST_F(TestPowInfo, GetForwardOp1) {
-  std::vector<Dimensions> inputs = {{2, 4, 8}};
+  std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);
 
   pow->Init(strategy);
@@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) {
 }
 
 TEST_F(TestPowInfo, GetMirrorOPs1) {
-  std::vector<Dimensions> inputs = {{2, 4, 8}};
+  std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);
 
   pow->Init(strategy);
@@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) {
 }
 
 TEST_F(TestPowInfo, CheckStrategy2) {
-  std::vector<Dimensions> inputs = {{2, 4, 8, 16}};
+  std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}};
   StrategyPtr strategy = NewStrategy(0, inputs);
 
   Status ret = pow->Init(strategy);
@@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) {
 }
 
 TEST_F(TestPowInfo, CheckStrategy3) {
-  std::vector<Dimensions> inputs = {{2, 4, 8}};
+  std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
   StrategyPtr strategy = NewStrategy(0, inputs);
 
   Status ret = pow->Init(strategy);
```
tests/ut/python/ops/test_math_ops.py

```diff
@@ -82,9 +82,10 @@ def test_sqrt():
 def test_pow():
     """ test_pow """
     input_tensor = Tensor(np.array([[2, 2], [3, 3]]))
+    power = Tensor(np.array(3.0, np.int64))
     testpow = P.Pow()
     expect = np.array([[8, 8], [27, 27]])
-    result = testpow(input_tensor, 3.0)
+    result = testpow(input_tensor, power)
     assert np.all(result.asnumpy() == expect)
```
tests/ut/python/ops/test_ops.py

```diff
@@ -224,11 +224,15 @@ test_case_math_ops = [
         'block': P.Minimum(),
         'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]],
         'desc_bprop': [[2, 3, 3, 5]]}),
-    ('Pow', {
+    ('Pow_0', {
         'block': P.Pow(),
         'desc_const': [2.0],
         'desc_inputs': [[2, 3, 3, 5]],
         'desc_bprop': [[2, 3, 3, 5]]}),
+    ('Pow_1', {
+        'block': P.Pow(),
+        'desc_inputs': [[3, 5], [2, 3, 3, 5]],
+        'desc_bprop': [[2, 3, 3, 5]]}),
     ('Exp', {
         'block': P.Exp(),
         'desc_inputs': [[2, 3]],
```
tests/ut/python/parallel/test_element_wise_function.py

The Pow strategy now covers both inputs; the scalar exponent gets an empty partition tuple.

```diff
@@ -59,7 +59,7 @@ def test_matmul_pow():
     context.set_auto_parallel_context(device_num=8, global_rank=0)
     strategy1 = ((2, 2), (2, 2))
-    strategy2 = ((4, 2), )
+    strategy2 = ((4, 2), ())
     net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
```
tests/vm_impl/math_ops_vm_impl.py

```diff
@@ -117,6 +117,7 @@ def vm_impl_pow(self):
     """Generate vm_impl function for Pow."""
     def vm_impl(x, y):
         x = x.asnumpy()
+        y = y.asnumpy()
         res = vm.power(x, y)
         return Tensor(res)
     return vm_impl
```
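Since the exponent can now arrive as a Tensor, the vm stub converts both operands to NumPy before calling the internal power routine. A quick sketch of the broadcasting behavior it relies on (an assumption based on the stub above; vm.power itself is MindSpore-internal):

```python
# Sketch: the stub reduces to NumPy's broadcasting power semantics.
import numpy as np

x = np.array([[2.0, 2.0], [3.0, 3.0]])
y = np.array(3.0)  # scalar exponent arriving as a 0-d array
print(np.power(x, y))  # [[ 8.  8.] [27. 27.]]
```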