Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
80978cf3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
80978cf3
编写于
4月 09, 2020
作者:
B
buxue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support operator ** // % for scalar and tensor, and in not in for dict, ang str concat
上级
8f6b941a
变更
23
隐藏空白更改
内联
并排
Showing
23 changed file
with
435 addition
and
103 deletion
+435
-103
mindspore/_extends/parse/resources.py
mindspore/_extends/parse/resources.py
+5
-5
mindspore/ccsrc/operator/cc_implementations.cc
mindspore/ccsrc/operator/cc_implementations.cc
+32
-8
mindspore/ccsrc/operator/cc_implementations.h
mindspore/ccsrc/operator/cc_implementations.h
+2
-1
mindspore/ccsrc/operator/composite/do_signature.cc
mindspore/ccsrc/operator/composite/do_signature.cc
+16
-9
mindspore/ccsrc/operator/ops.cc
mindspore/ccsrc/operator/ops.cc
+4
-0
mindspore/ccsrc/operator/ops.h
mindspore/ccsrc/operator/ops.h
+4
-2
mindspore/ccsrc/operator/prim_nn.cc
mindspore/ccsrc/operator/prim_nn.cc
+5
-5
mindspore/ccsrc/operator/prim_statement.cc
mindspore/ccsrc/operator/prim_statement.cc
+31
-0
mindspore/ccsrc/operator/prim_structures.cc
mindspore/ccsrc/operator/prim_structures.cc
+39
-24
mindspore/ccsrc/operator/prim_to_function.cc
mindspore/ccsrc/operator/prim_to_function.cc
+29
-30
mindspore/ccsrc/pipeline/static_analysis/prim.cc
mindspore/ccsrc/pipeline/static_analysis/prim.cc
+5
-0
mindspore/ccsrc/pipeline/static_analysis/prim.h
mindspore/ccsrc/pipeline/static_analysis/prim.h
+6
-0
mindspore/ops/composite/multitype_ops/__init__.py
mindspore/ops/composite/multitype_ops/__init__.py
+6
-0
mindspore/ops/composite/multitype_ops/add_impl.py
mindspore/ops/composite/multitype_ops/add_impl.py
+17
-4
mindspore/ops/composite/multitype_ops/div_impl.py
mindspore/ops/composite/multitype_ops/div_impl.py
+2
-4
mindspore/ops/composite/multitype_ops/floordiv_impl.py
mindspore/ops/composite/multitype_ops/floordiv_impl.py
+50
-0
mindspore/ops/composite/multitype_ops/mod_impl.py
mindspore/ops/composite/multitype_ops/mod_impl.py
+50
-0
mindspore/ops/composite/multitype_ops/mul_impl.py
mindspore/ops/composite/multitype_ops/mul_impl.py
+2
-4
mindspore/ops/composite/multitype_ops/pow_impl.py
mindspore/ops/composite/multitype_ops/pow_impl.py
+50
-0
mindspore/ops/composite/multitype_ops/sub_impl.py
mindspore/ops/composite/multitype_ops/sub_impl.py
+2
-4
mindspore/ops/functional.py
mindspore/ops/functional.py
+7
-1
mindspore/ops/operations/math_ops.py
mindspore/ops/operations/math_ops.py
+2
-2
tests/ut/python/pipeline/parse/test_operator.py
tests/ut/python/pipeline/parse/test_operator.py
+69
-0
未找到文件。
mindspore/_extends/parse/resources.py
浏览文件 @
80978cf3
...
@@ -83,9 +83,9 @@ convert_object_map = {
...
@@ -83,9 +83,9 @@ convert_object_map = {
T
.
mul
:
multitype_ops
.
mul
,
T
.
mul
:
multitype_ops
.
mul
,
T
.
truediv
:
multitype_ops
.
div
,
T
.
truediv
:
multitype_ops
.
div
,
T
.
getitem
:
multitype_ops
.
getitem
,
T
.
getitem
:
multitype_ops
.
getitem
,
T
.
floordiv
:
NO_IMPLEMENT
,
T
.
floordiv
:
multitype_ops
.
floordiv
,
T
.
mod
:
F
.
scalar_
mod
,
T
.
mod
:
multitype_ops
.
mod
,
T
.
pow
:
F
.
scalar_pow
,
T
.
pow
:
multitype_ops
.
pow_
,
T
.
matmul
:
F
.
dot
,
T
.
matmul
:
F
.
dot
,
T
.
lshift
:
NO_IMPLEMENT
,
T
.
lshift
:
NO_IMPLEMENT
,
T
.
rshift
:
NO_IMPLEMENT
,
T
.
rshift
:
NO_IMPLEMENT
,
...
@@ -104,8 +104,8 @@ convert_object_map = {
...
@@ -104,8 +104,8 @@ convert_object_map = {
T
.
ge
:
multitype_ops
.
greater_equal
,
T
.
ge
:
multitype_ops
.
greater_equal
,
T
.
is_
:
F
.
is_
,
T
.
is_
:
F
.
is_
,
T
.
is_not
:
F
.
is_not
,
T
.
is_not
:
F
.
is_not
,
T
.
contains
:
NO_IMPLEMENT
,
T
.
contains
:
F
.
in_dict
,
T
.
not_contains
:
NO_IMPLEMENT
,
T
.
not_contains
:
F
.
not_in_dict
,
# system function
# system function
T
.
len
:
M
.
ms_len
,
T
.
len
:
M
.
ms_len
,
...
...
mindspore/ccsrc/operator/cc_implementations.cc
浏览文件 @
80978cf3
...
@@ -103,7 +103,7 @@ T InnerScalarMul(T x, T y) {
...
@@ -103,7 +103,7 @@ T InnerScalarMul(T x, T y) {
}
}
template
<
typename
T
>
template
<
typename
T
>
T
InnerScalarDiv
(
T
x
,
T
y
)
{
float
InnerScalarDiv
(
T
x
,
T
y
)
{
if
(
y
==
0
)
{
if
(
y
==
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Divisor could not be zero"
;
MS_LOG
(
EXCEPTION
)
<<
"Divisor could not be zero"
;
}
}
...
@@ -111,23 +111,41 @@ T InnerScalarDiv(T x, T y) {
...
@@ -111,23 +111,41 @@ T InnerScalarDiv(T x, T y) {
MS_LOG
(
EXCEPTION
)
<<
"Overflow of the div of two signed number x: "
<<
std
::
to_string
(
x
)
MS_LOG
(
EXCEPTION
)
<<
"Overflow of the div of two signed number x: "
<<
std
::
to_string
(
x
)
<<
", y: "
<<
std
::
to_string
(
y
)
<<
"."
;
<<
", y: "
<<
std
::
to_string
(
y
)
<<
"."
;
}
}
return
x
/
y
;
return
static_cast
<
float
>
(
x
)
/
static_cast
<
float
>
(
y
)
;
}
}
int32_t
InnerScalarMod
(
int32_t
x
,
int32_t
y
)
{
template
<
typename
T
>
T
InnerScalarFloordiv
(
T
x
,
T
y
)
{
auto
ret
=
std
::
floor
(
InnerScalarDiv
(
x
,
y
));
if
(
std
::
is_integral
<
T
>::
value
)
{
return
static_cast
<
int
>
(
ret
);
}
return
ret
;
}
template
<
typename
T
>
T
InnerScalarMod
(
T
x
,
T
y
)
{
if
(
y
==
0
)
{
if
(
y
==
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Could not mod to zero."
;
MS_LOG
(
EXCEPTION
)
<<
"Could not mod to zero."
;
}
}
if
(
IsSignedIntOverflow
(
x
,
y
,
OpType
::
MOD
))
{
if
(
std
::
is_integral
<
T
>::
value
&&
std
::
is_signed
<
T
>::
value
&&
IsSignedIntOverflow
(
x
,
y
,
OpType
::
MOD
))
{
MS_LOG
(
EXCEPTION
)
<<
"Overflow of the mod of two signed number x: "
<<
std
::
to_string
(
x
)
MS_LOG
(
EXCEPTION
)
<<
"Overflow of the mod of two signed number x: "
<<
std
::
to_string
(
x
)
<<
", y: "
<<
std
::
to_string
(
y
)
<<
"."
;
<<
", y: "
<<
std
::
to_string
(
y
)
<<
"."
;
}
}
return
x
%
y
;
if
(
std
::
is_integral
<
T
>::
value
)
{
return
static_cast
<
int
>
(
x
)
%
static_cast
<
int
>
(
y
);
}
float
x_int
=
std
::
floor
(
x
);
float
y_int
=
std
::
ceil
(
y
);
float
max
=
x_int
/
y_int
;
float
ret
=
x
-
y
*
max
;
return
ret
;
}
}
float
InnerScalarMod
(
float
,
float
)
{
MS_LOG
(
EXCEPTION
)
<<
"Float does not support mod operator."
;
}
template
<
typename
T
,
typename
U
>
T
InnerScalarPow
(
T
x
,
U
y
)
{
double
InnerScalarMod
(
double
,
double
)
{
MS_LOG
(
EXCEPTION
)
<<
"Double does not support mod operator."
;
}
return
std
::
pow
(
x
,
y
);
}
template
<
typename
T
,
typename
U
>
template
<
typename
T
,
typename
U
>
bool
InnerScalarEq
(
T
x
,
U
y
)
{
bool
InnerScalarEq
(
T
x
,
U
y
)
{
...
@@ -193,6 +211,8 @@ SCALAR_OP(Sub)
...
@@ -193,6 +211,8 @@ SCALAR_OP(Sub)
SCALAR_OP
(
Mul
)
SCALAR_OP
(
Mul
)
SCALAR_OP
(
Div
)
SCALAR_OP
(
Div
)
SCALAR_OP
(
Mod
)
SCALAR_OP
(
Mod
)
SCALAR_OP
(
Pow
)
SCALAR_OP
(
Floordiv
)
#define LOGIC_OP(op_t) \
#define LOGIC_OP(op_t) \
ValuePtr Scalar##op_t(const ValuePtrList& list) { \
ValuePtr Scalar##op_t(const ValuePtrList& list) { \
...
@@ -227,6 +247,10 @@ SCALAR_OP(Mod)
...
@@ -227,6 +247,10 @@ SCALAR_OP(Mod)
bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int>(y)); \
bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int>(y)); \
return MakeValue(sum); \
return MakeValue(sum); \
} \
} \
if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<float>(y)); \
return MakeValue(sum); \
} \
if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int>(y)); \
bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int>(y)); \
return MakeValue(sum); \
return MakeValue(sum); \
...
...
mindspore/ccsrc/operator/cc_implementations.h
浏览文件 @
80978cf3
...
@@ -37,9 +37,10 @@ ValuePtr ScalarSub(const ValuePtrList& list);
...
@@ -37,9 +37,10 @@ ValuePtr ScalarSub(const ValuePtrList& list);
ValuePtr
ScalarMul
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarMul
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarDiv
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarDiv
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarMod
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarMod
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarPow
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarFloordiv
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarUAdd
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarUAdd
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarUSub
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarUSub
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarUSub
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarLog
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarLog
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarEq
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarEq
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarLt
(
const
ValuePtrList
&
list
);
ValuePtr
ScalarLt
(
const
ValuePtrList
&
list
);
...
...
mindspore/ccsrc/operator/composite/do_signature.cc
浏览文件 @
80978cf3
...
@@ -88,14 +88,17 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur
...
@@ -88,14 +88,17 @@ std::map<SignatureEnumDType, size_t> GetMaxDtypeIndex(const std::vector<Signatur
if
(
indexs
.
size
()
<
2
)
{
if
(
indexs
.
size
()
<
2
)
{
continue
;
continue
;
}
}
size_t
m_index
=
indexs
[
0
];
for
(
size_t
i
=
1
;
i
<
indexs
.
size
();
++
i
)
{
for
(
const
auto
&
index
:
indexs
)
{
if
(
args_spec_list
[
indexs
[
i
]]
->
isa
<
abstract
::
AbstractTensor
>
())
{
AbstractBasePtr
arg_value
=
args_spec_list
[
index
];
m_index
=
indexs
[
i
];
if
(
arg_value
->
isa
<
abstract
::
AbstractRef
>
())
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref
();
}
if
(
arg_value
->
isa
<
abstract
::
AbstractTensor
>
())
{
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
index
));
break
;
}
}
}
if
(
args_spec_list
[
m_index
]
->
isa
<
abstract
::
AbstractTensor
>
())
{
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
m_index
));
}
}
}
}
return
dst_type
;
return
dst_type
;
...
@@ -119,15 +122,19 @@ void DoAutoCast(const std::vector<Signature>& signature, const abstract::Abstrac
...
@@ -119,15 +122,19 @@ void DoAutoCast(const std::vector<Signature>& signature, const abstract::Abstrac
(
void
)
std
::
transform
(
signature
.
begin
(),
signature
.
end
(),
std
::
back_inserter
(
dtypes
),
(
void
)
std
::
transform
(
signature
.
begin
(),
signature
.
end
(),
std
::
back_inserter
(
dtypes
),
[](
const
Signature
&
sig
)
{
return
sig
.
dtype
;
});
[](
const
Signature
&
sig
)
{
return
sig
.
dtype
;
});
int
empty_dtype_count
=
std
::
count
(
dtypes
.
begin
(),
dtypes
.
end
(),
SignatureEnumDType
::
kDTypeEmptyDefaultValue
);
int
empty_dtype_count
=
std
::
count
(
dtypes
.
begin
(),
dtypes
.
end
(),
SignatureEnumDType
::
kDTypeEmptyDefaultValue
);
if
(
dtypes
.
size
()
==
0
||
static_cast
<
int
>
(
dtypes
.
size
())
==
empty_dtype_count
)
{
if
(
dtypes
.
empty
()
||
static_cast
<
int
>
(
dtypes
.
size
())
==
empty_dtype_count
)
{
return
;
return
;
}
}
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
std
::
map
<
SignatureEnumDType
,
size_t
>
dst_type
=
GetMaxDtypeIndex
(
dtypes
,
args_spec_list
);
std
::
map
<
SignatureEnumDType
,
size_t
>
dst_type
=
GetMaxDtypeIndex
(
dtypes
,
args_spec_list
);
// Identify which arg requires auto cast
// Identify which arg requires auto cast
for
(
size_t
i
=
0
;
i
<
args_spec_list
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
args_spec_list
.
size
();
++
i
)
{
AbstractBasePtr
arg_value
=
args_spec_list
[
i
];
if
(
arg_value
->
isa
<
abstract
::
AbstractRef
>
())
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref
();
}
auto
it
=
dst_type
.
find
(
dtypes
[
i
]);
auto
it
=
dst_type
.
find
(
dtypes
[
i
]);
if
(
it
==
dst_type
.
end
()
||
it
->
second
==
i
||
!
arg
s_spec_list
[
i
]
->
isa
<
abstract
::
AbstractScalar
>
())
{
if
(
it
==
dst_type
.
end
()
||
it
->
second
==
i
||
!
arg
_value
->
isa
<
abstract
::
AbstractScalar
>
())
{
continue
;
continue
;
}
}
// get source node for cast
// get source node for cast
...
...
mindspore/ccsrc/operator/ops.cc
浏览文件 @
80978cf3
...
@@ -28,6 +28,7 @@ const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");
...
@@ -28,6 +28,7 @@ const PrimitivePtr kPrimScalarAdd = std::make_shared<Primitive>("scalar_add");
const
PrimitivePtr
kPrimScalarSub
=
std
::
make_shared
<
Primitive
>
(
"scalar_sub"
);
const
PrimitivePtr
kPrimScalarSub
=
std
::
make_shared
<
Primitive
>
(
"scalar_sub"
);
const
PrimitivePtr
kPrimScalarMul
=
std
::
make_shared
<
Primitive
>
(
"scalar_mul"
);
const
PrimitivePtr
kPrimScalarMul
=
std
::
make_shared
<
Primitive
>
(
"scalar_mul"
);
const
PrimitivePtr
kPrimScalarDiv
=
std
::
make_shared
<
Primitive
>
(
"scalar_div"
);
const
PrimitivePtr
kPrimScalarDiv
=
std
::
make_shared
<
Primitive
>
(
"scalar_div"
);
const
PrimitivePtr
kPrimScalarFloordiv
=
std
::
make_shared
<
Primitive
>
(
"scalar_floordiv"
);
const
PrimitivePtr
kPrimScalarMod
=
std
::
make_shared
<
Primitive
>
(
"scalar_mod"
);
const
PrimitivePtr
kPrimScalarMod
=
std
::
make_shared
<
Primitive
>
(
"scalar_mod"
);
const
PrimitivePtr
kPrimScalarPow
=
std
::
make_shared
<
Primitive
>
(
"scalar_pow"
);
const
PrimitivePtr
kPrimScalarPow
=
std
::
make_shared
<
Primitive
>
(
"scalar_pow"
);
const
PrimitivePtr
kPrimScalarTrunc
=
std
::
make_shared
<
Primitive
>
(
"scalar_trunc"
);
const
PrimitivePtr
kPrimScalarTrunc
=
std
::
make_shared
<
Primitive
>
(
"scalar_trunc"
);
...
@@ -78,6 +79,7 @@ const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_ins
...
@@ -78,6 +79,7 @@ const PrimitivePtr kPrimCreateInstance = std::make_shared<Primitive>("create_ins
// Structure
// Structure
const
PrimitivePtr
kPrimStringEqual
=
std
::
make_shared
<
Primitive
>
(
"string_equal"
);
const
PrimitivePtr
kPrimStringEqual
=
std
::
make_shared
<
Primitive
>
(
"string_equal"
);
const
PrimitivePtr
kPrimStringConcat
=
std
::
make_shared
<
Primitive
>
(
"string_concat"
);
const
PrimitivePtr
kPrimMakeTuple
=
std
::
make_shared
<
Primitive
>
(
"make_tuple"
);
const
PrimitivePtr
kPrimMakeTuple
=
std
::
make_shared
<
Primitive
>
(
"make_tuple"
);
const
PrimitivePtr
kPrimMakeList
=
std
::
make_shared
<
Primitive
>
(
"make_list"
);
const
PrimitivePtr
kPrimMakeList
=
std
::
make_shared
<
Primitive
>
(
"make_list"
);
const
PrimitivePtr
kPrimMakeDict
=
std
::
make_shared
<
Primitive
>
(
"make_dict"
);
const
PrimitivePtr
kPrimMakeDict
=
std
::
make_shared
<
Primitive
>
(
"make_dict"
);
...
@@ -221,6 +223,8 @@ const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("Bro
...
@@ -221,6 +223,8 @@ const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("Bro
const
PrimitivePtr
kPrimControlDepend
=
std
::
make_shared
<
Primitive
>
(
"ControlDepend"
);
const
PrimitivePtr
kPrimControlDepend
=
std
::
make_shared
<
Primitive
>
(
"ControlDepend"
);
const
PrimitivePtr
kPrimIs_
=
std
::
make_shared
<
Primitive
>
(
"is_"
);
const
PrimitivePtr
kPrimIs_
=
std
::
make_shared
<
Primitive
>
(
"is_"
);
const
PrimitivePtr
kPrimIsNot
=
std
::
make_shared
<
Primitive
>
(
"is_not"
);
const
PrimitivePtr
kPrimIsNot
=
std
::
make_shared
<
Primitive
>
(
"is_not"
);
const
PrimitivePtr
kPrimInDict
=
std
::
make_shared
<
Primitive
>
(
"in_dict"
);
const
PrimitivePtr
kPrimNotInDict
=
std
::
make_shared
<
Primitive
>
(
"not_in_dict"
);
// Comm ops
// Comm ops
const
PrimitivePtr
kPrimMirror
=
std
::
make_shared
<
Primitive
>
(
"_MirrorOperator"
);
const
PrimitivePtr
kPrimMirror
=
std
::
make_shared
<
Primitive
>
(
"_MirrorOperator"
);
...
...
mindspore/ccsrc/operator/ops.h
浏览文件 @
80978cf3
...
@@ -34,6 +34,7 @@ extern const PrimitivePtr kPrimScalarAdd;
...
@@ -34,6 +34,7 @@ extern const PrimitivePtr kPrimScalarAdd;
extern
const
PrimitivePtr
kPrimScalarSub
;
extern
const
PrimitivePtr
kPrimScalarSub
;
extern
const
PrimitivePtr
kPrimScalarMul
;
extern
const
PrimitivePtr
kPrimScalarMul
;
extern
const
PrimitivePtr
kPrimScalarDiv
;
extern
const
PrimitivePtr
kPrimScalarDiv
;
extern
const
PrimitivePtr
kPrimScalarFloordiv
;
extern
const
PrimitivePtr
kPrimScalarMod
;
extern
const
PrimitivePtr
kPrimScalarMod
;
extern
const
PrimitivePtr
kPrimScalarPow
;
extern
const
PrimitivePtr
kPrimScalarPow
;
extern
const
PrimitivePtr
kPrimScalarTrunc
;
extern
const
PrimitivePtr
kPrimScalarTrunc
;
...
@@ -84,6 +85,7 @@ extern const PrimitivePtr kPrimCreateInstance;
...
@@ -84,6 +85,7 @@ extern const PrimitivePtr kPrimCreateInstance;
// Structure
// Structure
extern
const
PrimitivePtr
kPrimStringEqual
;
extern
const
PrimitivePtr
kPrimStringEqual
;
extern
const
PrimitivePtr
kPrimStringConcat
;
extern
const
PrimitivePtr
kPrimMakeTuple
;
extern
const
PrimitivePtr
kPrimMakeTuple
;
extern
const
PrimitivePtr
kPrimMakeList
;
extern
const
PrimitivePtr
kPrimMakeList
;
extern
const
PrimitivePtr
kPrimMakeDict
;
extern
const
PrimitivePtr
kPrimMakeDict
;
...
@@ -227,8 +229,8 @@ extern const PrimitivePtr kPrimBroadcastGradientArgs;
...
@@ -227,8 +229,8 @@ extern const PrimitivePtr kPrimBroadcastGradientArgs;
extern
const
PrimitivePtr
kPrimControlDepend
;
extern
const
PrimitivePtr
kPrimControlDepend
;
extern
const
PrimitivePtr
kPrimIs_
;
extern
const
PrimitivePtr
kPrimIs_
;
extern
const
PrimitivePtr
kPrimIsNot
;
extern
const
PrimitivePtr
kPrimIsNot
;
extern
const
PrimitivePtr
kPrim
MinimumGrad
;
extern
const
PrimitivePtr
kPrim
InDict
;
extern
const
PrimitivePtr
kPrim
MaximumGrad
;
extern
const
PrimitivePtr
kPrim
NotInDict
;
// Comm ops
// Comm ops
extern
const
PrimitivePtr
kPrimMirror
;
extern
const
PrimitivePtr
kPrimMirror
;
...
...
mindspore/ccsrc/operator/prim_nn.cc
浏览文件 @
80978cf3
...
@@ -114,12 +114,12 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr
...
@@ -114,12 +114,12 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr
AbstractTensorPtr
arg
=
CheckArg
<
AbstractTensor
>
(
op_name
,
args_spec_list
,
i
);
AbstractTensorPtr
arg
=
CheckArg
<
AbstractTensor
>
(
op_name
,
args_spec_list
,
i
);
ShapePtr
arg_shape
=
dyn_cast
<
Shape
>
(
arg
->
GetShapeTrack
());
ShapePtr
arg_shape
=
dyn_cast
<
Shape
>
(
arg
->
GetShapeTrack
());
if
(
arg_shape
==
nullptr
)
{
if
(
arg_shape
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" type of args["
<<
i
<<
"] should be Shape, but "
<<
arg
->
ToString
();
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" type of args["
<<
i
<<
"] should be Shape, but "
<<
arg
->
ToString
();
}
}
if
(
i
==
0
)
{
if
(
i
==
0
)
{
if
(
arg_shape
->
shape
().
size
()
<
2
)
{
if
(
arg_shape
->
shape
().
size
()
<
2
)
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" shape of args["
<<
i
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" shape of args["
<<
i
<<
"] should be TensorShape with dimension greater than 1, but shape: "
<<
"] should be TensorShape with dimension greater than 1, but shape: "
<<
arg_shape
->
ToString
();
<<
arg_shape
->
ToString
();
}
}
...
@@ -127,7 +127,7 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr
...
@@ -127,7 +127,7 @@ void FusedBatchNormCheckDim(const PrimitivePtr &primitive, const AbstractBasePtr
}
}
if
(
arg_shape
->
shape
().
size
()
!=
1
)
{
if
(
arg_shape
->
shape
().
size
()
!=
1
)
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" shape of args["
<<
i
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" shape of args["
<<
i
<<
"] should be TensorShape with dimension: 1, but shape: "
<<
arg_shape
->
ToString
();
<<
"] should be TensorShape with dimension: 1, but shape: "
<<
arg_shape
->
ToString
();
}
}
}
}
...
@@ -159,7 +159,7 @@ AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const Primiti
...
@@ -159,7 +159,7 @@ AbstractBasePtr InferImplFusedBatchNorm(const AnalysisEnginePtr &, const Primiti
MS_LOG
(
EXCEPTION
)
<<
"Arg shape size should >= 1."
;
MS_LOG
(
EXCEPTION
)
<<
"Arg shape size should >= 1."
;
}
}
if
(
arg_shape_list
[
0
]
!=
input_shape_list
[
1
])
{
if
(
arg_shape_list
[
0
]
!=
input_shape_list
[
1
])
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" size of tensor param["
<<
i
<<
"](which is "
<<
arg_shape_list
[
0
]
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" size of tensor param["
<<
i
<<
"](which is "
<<
arg_shape_list
[
0
]
<<
") should match the second dimension of tensor"
<<
") should match the second dimension of tensor"
" param[0](which is "
" param[0](which is "
<<
input_shape_list
[
1
]
<<
")."
;
<<
input_shape_list
[
1
]
<<
")."
;
...
@@ -378,7 +378,7 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
...
@@ -378,7 +378,7 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
TypePtr
prob_type
=
keep_prob
->
element
()
->
BuildType
();
TypePtr
prob_type
=
keep_prob
->
element
()
->
BuildType
();
if
((
prob_type
->
type_id
()
!=
kNumberTypeFloat16
)
&&
(
prob_type
->
type_id
()
!=
kNumberTypeFloat32
))
{
if
((
prob_type
->
type_id
()
!=
kNumberTypeFloat16
)
&&
(
prob_type
->
type_id
()
!=
kNumberTypeFloat32
))
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" keep_prob type should be float16 or float32, but "
<<
prob_type
->
ToString
()
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" keep_prob type should be float16 or float32, but "
<<
prob_type
->
ToString
()
<<
"."
;
<<
"."
;
}
}
...
...
mindspore/ccsrc/operator/prim_statement.cc
浏览文件 @
80978cf3
...
@@ -169,5 +169,36 @@ AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &pr
...
@@ -169,5 +169,36 @@ AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &pr
return
std
::
make_shared
<
AbstractScalar
>
(
!
(
*
t
==
*
x
));
return
std
::
make_shared
<
AbstractScalar
>
(
!
(
*
t
==
*
x
));
}
}
bool
IsInDict
(
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
const
std
::
string
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
auto
key
=
CheckArg
<
AbstractScalar
>
(
op_name
,
args_spec_list
,
0
);
auto
dict
=
CheckArg
<
AbstractDictionary
>
(
op_name
,
args_spec_list
,
1
);
ValuePtr
key_value
=
key
->
BuildValue
();
if
(
!
key_value
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator key should be string, but got "
<<
key_value
->
ToString
();
}
auto
key_str
=
GetValue
<
std
::
string
>
(
key_value
);
std
::
vector
<
AbstractAttribute
>
dict_elems
=
dict
->
elements
();
auto
it
=
std
::
find_if
(
dict_elems
.
begin
(),
dict_elems
.
end
(),
[
key_str
](
const
AbstractAttribute
&
item
)
{
return
item
.
first
==
key_str
;
});
return
it
!=
dict_elems
.
end
();
}
AbstractBasePtr
InferImplInDict
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// statement: x in t
// Inputs: x, t
return
std
::
make_shared
<
AbstractScalar
>
(
IsInDict
(
primitive
,
args_spec_list
));
}
AbstractBasePtr
InferImplNotInDict
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// statement: x not in t
// Inputs: x, t
return
std
::
make_shared
<
AbstractScalar
>
(
!
IsInDict
(
primitive
,
args_spec_list
));
}
}
// namespace abstract
}
// namespace abstract
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/operator/prim_structures.cc
浏览文件 @
80978cf3
...
@@ -36,7 +36,7 @@ AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitiveP
...
@@ -36,7 +36,7 @@ AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitiveP
ValuePtr
value_x
=
scalar_x
->
BuildValue
();
ValuePtr
value_x
=
scalar_x
->
BuildValue
();
ValuePtr
value_y
=
scalar_y
->
BuildValue
();
ValuePtr
value_y
=
scalar_y
->
BuildValue
();
if
(
!
value_x
->
isa
<
StringImm
>
()
||
!
value_y
->
isa
<
StringImm
>
())
{
if
(
!
value_x
->
isa
<
StringImm
>
()
||
!
value_y
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" requires 2 parameters are string, but got param0: "
<<
value_x
->
ToString
()
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" requires 2 parameters are string, but got param0: "
<<
value_x
->
ToString
()
<<
", param1: "
<<
value_y
->
ToString
();
<<
", param1: "
<<
value_y
->
ToString
();
}
}
...
@@ -44,6 +44,25 @@ AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitiveP
...
@@ -44,6 +44,25 @@ AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitiveP
return
std
::
make_shared
<
AbstractScalar
>
(
ret
);
return
std
::
make_shared
<
AbstractScalar
>
(
ret
);
}
}
AbstractBasePtr
InferImplStringConcat
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// Inputs: two scalars whose value is a string.
const
std
::
string
op_name
=
primitive
->
name
();
CheckArgsSize
(
op_name
,
args_spec_list
,
2
);
AbstractScalarPtr
scalar_x
=
CheckArg
<
AbstractScalar
>
(
op_name
,
args_spec_list
,
0
);
AbstractScalarPtr
scalar_y
=
CheckArg
<
AbstractScalar
>
(
op_name
,
args_spec_list
,
1
);
ValuePtr
value_x
=
scalar_x
->
BuildValue
();
ValuePtr
value_y
=
scalar_y
->
BuildValue
();
if
(
!
value_x
->
isa
<
StringImm
>
()
||
!
value_y
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" requires 2 parameters are string, but got param0: "
<<
value_x
->
ToString
()
<<
", param1: "
<<
value_y
->
ToString
();
}
std
::
string
ret
=
(
value_x
->
cast
<
StringImmPtr
>
()
->
value
()
+
value_y
->
cast
<
StringImmPtr
>
()
->
value
());
return
std
::
make_shared
<
AbstractScalar
>
(
ret
);
}
AbstractBasePtr
InferImplMakeTuple
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
AbstractBasePtr
InferImplMakeTuple
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
)
{
const
AbstractBasePtrList
&
args_spec_list
)
{
return
std
::
make_shared
<
AbstractTuple
>
(
args_spec_list
);
return
std
::
make_shared
<
AbstractTuple
>
(
args_spec_list
);
...
@@ -64,7 +83,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
...
@@ -64,7 +83,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
size_t
keys_size
=
keys
->
size
();
size_t
keys_size
=
keys
->
size
();
if
(
values
->
size
()
!=
keys_size
)
{
if
(
values
->
size
()
!=
keys_size
)
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" evaluator keys' size is not equal with values' size"
;
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator keys' size is not equal with values' size"
;
}
}
std
::
vector
<
AbstractAttribute
>
key_value
;
std
::
vector
<
AbstractAttribute
>
key_value
;
...
@@ -76,7 +95,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
...
@@ -76,7 +95,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
ValuePtr
keyPtr
=
key
->
BuildValue
();
ValuePtr
keyPtr
=
key
->
BuildValue
();
MS_EXCEPTION_IF_NULL
(
keyPtr
);
MS_EXCEPTION_IF_NULL
(
keyPtr
);
if
(
!
keyPtr
->
isa
<
StringImm
>
())
{
if
(
!
keyPtr
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" evaluator keys should be string, but got "
<<
keyPtr
->
ToString
();
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator keys should be string, but got "
<<
keyPtr
->
ToString
();
}
}
std
::
string
key_string
=
GetValue
<
std
::
string
>
(
keyPtr
);
std
::
string
key_string
=
GetValue
<
std
::
string
>
(
keyPtr
);
key_value
.
emplace_back
(
key_string
,
value_list
[
index
]);
key_value
.
emplace_back
(
key_string
,
value_list
[
index
]);
...
@@ -93,7 +112,7 @@ AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr
...
@@ -93,7 +112,7 @@ AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr
ValuePtr
keyPtr
=
key
->
BuildValue
();
ValuePtr
keyPtr
=
key
->
BuildValue
();
if
(
!
keyPtr
->
isa
<
StringImm
>
())
{
if
(
!
keyPtr
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" evaluator key should be string, but got "
<<
keyPtr
->
ToString
();
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator key should be string, but got "
<<
keyPtr
->
ToString
();
}
}
std
::
string
key_string
=
GetValue
<
std
::
string
>
(
keyPtr
);
std
::
string
key_string
=
GetValue
<
std
::
string
>
(
keyPtr
);
return
std
::
make_shared
<
AbstractKeywordArg
>
(
key_string
,
args_spec_list
[
1
]);
return
std
::
make_shared
<
AbstractKeywordArg
>
(
key_string
,
args_spec_list
[
1
]);
...
@@ -109,14 +128,13 @@ AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const Primitive
...
@@ -109,14 +128,13 @@ AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const Primitive
ValuePtr
key_value
=
key
->
BuildValue
();
ValuePtr
key_value
=
key
->
BuildValue
();
if
(
!
key_value
->
isa
<
StringImm
>
())
{
if
(
!
key_value
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" evaluator key should be string, but got "
<<
key_value
->
ToString
();
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator key should be string, but got "
<<
key_value
->
ToString
();
}
}
std
::
string
key_input
=
GetValue
<
std
::
string
>
(
key_value
);
std
::
string
key_input
=
GetValue
<
std
::
string
>
(
key_value
);
std
::
string
key_actual
=
kwarg
->
get_key
();
std
::
string
key_actual
=
kwarg
->
get_key
();
if
(
key_actual
!=
key_input
)
{
if
(
key_actual
!=
key_input
)
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator input key should be same as AbstractKeywordArg' key, but input is "
<<
" evaluator input key should be same as AbstractKeywordArg' key, but input is "
<<
key_input
<<
key_input
<<
", AbstractKeywordArg' key is "
<<
key_actual
;
<<
", AbstractKeywordArg' key is "
<<
key_actual
;
}
}
return
kwarg
->
get_arg
();
return
kwarg
->
get_arg
();
}
}
...
@@ -187,13 +205,12 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra
...
@@ -187,13 +205,12 @@ AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const Abstra
ValuePtr
index_value
=
index
->
BuildValue
();
ValuePtr
index_value
=
index
->
BuildValue
();
if
(
!
index_value
->
isa
<
Int32Imm
>
())
{
if
(
!
index_value
->
isa
<
Int32Imm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" evaluator index should be an int32 number, but got "
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator index should be an int32 number, but got "
<<
index_value
->
ToString
();
<<
index_value
->
ToString
();
}
}
int
idx_v
=
GetValue
<
int
>
(
index_value
);
int
idx_v
=
GetValue
<
int
>
(
index_value
);
std
::
size_t
nelems
=
queue
->
elements
().
size
();
std
::
size_t
nelems
=
queue
->
elements
().
size
();
if
(
idx_v
>=
SizeToInt
(
nelems
)
||
idx_v
<
-
SizeToInt
(
nelems
))
{
if
(
idx_v
>=
SizeToInt
(
nelems
)
||
idx_v
<
-
SizeToInt
(
nelems
))
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" evaluator index should be in range[-"
<<
SizeToInt
(
nelems
)
<<
", "
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator index should be in range[-"
<<
SizeToInt
(
nelems
)
<<
", "
<<
SizeToInt
(
nelems
)
<<
"), but got "
<<
idx_v
<<
"."
;
<<
SizeToInt
(
nelems
)
<<
"), but got "
<<
idx_v
<<
"."
;
}
}
...
@@ -215,8 +232,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
...
@@ -215,8 +232,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
ValuePtr
index_value
=
index
->
BuildValue
();
ValuePtr
index_value
=
index
->
BuildValue
();
if
(
!
index_value
->
isa
<
Int32Imm
>
())
{
if
(
!
index_value
->
isa
<
Int32Imm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" evaluator index should be an int32 number, but got "
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator index should be an int32 number, but got "
<<
index_value
->
ToString
();
<<
index_value
->
ToString
();
}
}
int
idx_v
=
GetValue
<
int
>
(
index_value
);
int
idx_v
=
GetValue
<
int
>
(
index_value
);
if
(
idx_v
<
0
)
{
if
(
idx_v
<
0
)
{
...
@@ -227,8 +243,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
...
@@ -227,8 +243,7 @@ AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const Abstra
AbstractBasePtrList
elements
=
queue
->
elements
();
AbstractBasePtrList
elements
=
queue
->
elements
();
std
::
size_t
nelems
=
elements
.
size
();
std
::
size_t
nelems
=
elements
.
size
();
if
(
uidx_v
>=
nelems
)
{
if
(
uidx_v
>=
nelems
)
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" evaluator the index: "
<<
uidx_v
<<
" to set out of range: "
<<
nelems
-
1
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator the index: "
<<
uidx_v
<<
" to set out of range: "
<<
nelems
-
1
<<
"."
;
<<
"."
;
}
}
elements
[
uidx_v
]
=
args_spec_list
[
2
];
elements
[
uidx_v
]
=
args_spec_list
[
2
];
return
std
::
make_shared
<
T
>
(
elements
);
return
std
::
make_shared
<
T
>
(
elements
);
...
@@ -264,12 +279,12 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
...
@@ -264,12 +279,12 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
ValuePtr
key_value
=
key
->
BuildValue
();
ValuePtr
key_value
=
key
->
BuildValue
();
if
(
!
key_value
->
isa
<
StringImm
>
())
{
if
(
!
key_value
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" evaluator key should be string, but got "
<<
key_value
->
ToString
();
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator key should be string, but got "
<<
key_value
->
ToString
();
}
}
std
::
string
key_str
=
GetValue
<
std
::
string
>
(
key_value
);
auto
key_str
=
GetValue
<
std
::
string
>
(
key_value
);
std
::
vector
<
AbstractAttribute
>
dict_elems
=
dict
->
elements
();
std
::
vector
<
AbstractAttribute
>
dict_elems
=
dict
->
elements
();
auto
it
=
std
::
find_if
(
dict_elems
.
begin
(),
dict_elems
.
end
(),
auto
it
=
std
::
find_if
(
dict_elems
.
begin
(),
dict_elems
.
end
(),
[
key_str
](
AbstractAttribute
&
item
)
{
return
item
.
first
==
key_str
;
});
[
key_str
](
const
AbstractAttribute
&
item
)
{
return
item
.
first
==
key_str
;
});
if
(
it
==
dict_elems
.
end
())
{
if
(
it
==
dict_elems
.
end
())
{
MS_LOG
(
EXCEPTION
)
<<
"The key "
<<
key_str
<<
" does not exist in the dict:"
<<
args_spec_list
[
0
]
->
ToString
();
MS_LOG
(
EXCEPTION
)
<<
"The key "
<<
key_str
<<
" does not exist in the dict:"
<<
args_spec_list
[
0
]
->
ToString
();
...
@@ -287,7 +302,7 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
...
@@ -287,7 +302,7 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
ValuePtr
key_value
=
key
->
BuildValue
();
ValuePtr
key_value
=
key
->
BuildValue
();
if
(
!
key_value
->
isa
<
StringImm
>
())
{
if
(
!
key_value
->
isa
<
StringImm
>
())
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" evaluator key should be string, but got "
<<
key_value
->
ToString
();
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator key should be string, but got "
<<
key_value
->
ToString
();
}
}
std
::
string
key_str
=
GetValue
<
std
::
string
>
(
key_value
);
std
::
string
key_str
=
GetValue
<
std
::
string
>
(
key_value
);
std
::
vector
<
AbstractAttribute
>
dict_elems
=
dict
->
elements
();
std
::
vector
<
AbstractAttribute
>
dict_elems
=
dict
->
elements
();
...
@@ -446,27 +461,27 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP
...
@@ -446,27 +461,27 @@ AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitiveP
auto
x_shp_value
=
shape_x
->
BuildValue
();
auto
x_shp_value
=
shape_x
->
BuildValue
();
if
(
x_shp_value
->
isa
<
AnyValue
>
())
{
if
(
x_shp_value
->
isa
<
AnyValue
>
())
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator shape's data field can't be anything: "
<<
args_spec_list
[
1
]
->
ToString
();
<<
" evaluator shape's data field can't be anything: "
<<
args_spec_list
[
1
]
->
ToString
();
}
}
// Axis can be scalar, tuple or None
// Axis can be scalar, tuple or None
AbstractTuplePtr
axis
=
nullptr
;
AbstractTuplePtr
axis
=
nullptr
;
if
(
args_spec_list
[
1
]
->
isa
<
AbstractScalar
>
())
{
if
(
args_spec_list
[
1
]
->
isa
<
AbstractScalar
>
())
{
MS_LOG
(
DEBUG
)
<<
""
<<
op_name
<<
" evaluator second parameter is scalar"
;
MS_LOG
(
DEBUG
)
<<
op_name
<<
" evaluator second parameter is scalar"
;
AbstractBasePtrList
axis_list
=
{
dyn_cast
<
AbstractScalar
>
(
args_spec_list
[
1
])};
AbstractBasePtrList
axis_list
=
{
dyn_cast
<
AbstractScalar
>
(
args_spec_list
[
1
])};
axis
=
std
::
make_shared
<
AbstractTuple
>
(
axis_list
);
axis
=
std
::
make_shared
<
AbstractTuple
>
(
axis_list
);
}
else
if
(
args_spec_list
[
1
]
->
isa
<
AbstractTuple
>
())
{
}
else
if
(
args_spec_list
[
1
]
->
isa
<
AbstractTuple
>
())
{
MS_LOG
(
DEBUG
)
<<
""
<<
op_name
<<
" evaluator second parameter is tuple"
;
MS_LOG
(
DEBUG
)
<<
op_name
<<
" evaluator second parameter is tuple"
;
axis
=
args_spec_list
[
1
]
->
cast
<
AbstractTuplePtr
>
();
axis
=
args_spec_list
[
1
]
->
cast
<
AbstractTuplePtr
>
();
}
else
{
}
else
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
<<
" evaluator second parameter should be a scalar or tuple, but got "
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator second parameter should be a scalar or tuple, but got "
<<
args_spec_list
[
1
]
->
ToString
();
<<
args_spec_list
[
1
]
->
ToString
();
}
}
auto
axis_value
=
axis
->
BuildValue
();
auto
axis_value
=
axis
->
BuildValue
();
if
(
axis_value
->
isa
<
AnyValue
>
())
{
if
(
axis_value
->
isa
<
AnyValue
>
())
{
MS_LOG
(
EXCEPTION
)
<<
""
<<
op_name
MS_LOG
(
EXCEPTION
)
<<
op_name
<<
" evaluator shape's data field can't be anything: "
<<
args_spec_list
[
1
]
->
ToString
();
<<
" evaluator shape's data field can't be anything: "
<<
args_spec_list
[
1
]
->
ToString
();
}
}
auto
axis_value_ptr
=
axis_value
->
cast
<
ValueTuplePtr
>
();
auto
axis_value_ptr
=
axis_value
->
cast
<
ValueTuplePtr
>
();
...
...
mindspore/ccsrc/operator/prim_to_function.cc
浏览文件 @
80978cf3
...
@@ -24,36 +24,35 @@ namespace mindspore {
...
@@ -24,36 +24,35 @@ namespace mindspore {
namespace
prim
{
namespace
prim
{
PrimToFunction
::
PrimToFunction
()
PrimToFunction
::
PrimToFunction
()
:
prim_func_type_map_
({
:
prim_func_type_map_
({
// ONE_ARG prim
// ONE_ARG prim
{
"bool_not"
,
kPrimTypeOneArg
},
{
"bool_not"
,
kPrimTypeOneArg
},
{
"scalar_cos"
,
kPrimTypeOneArg
},
{
"scalar_cos"
,
kPrimTypeOneArg
},
{
"scalar_exp"
,
kPrimTypeOneArg
},
{
"scalar_exp"
,
kPrimTypeOneArg
},
{
"scalar_floor"
,
kPrimTypeOneArg
},
{
"scalar_floor"
,
kPrimTypeOneArg
},
{
"scalar_log"
,
kPrimTypeOneArg
},
{
"scalar_log"
,
kPrimTypeOneArg
},
{
"scalar_sin"
,
kPrimTypeOneArg
},
{
"scalar_sin"
,
kPrimTypeOneArg
},
{
"scalar_tan"
,
kPrimTypeOneArg
},
{
"scalar_tan"
,
kPrimTypeOneArg
},
{
"scalar_trunc"
,
kPrimTypeOneArg
},
{
"scalar_trunc"
,
kPrimTypeOneArg
},
{
"typeof"
,
kPrimTypeOneArg
},
{
"typeof"
,
kPrimTypeOneArg
},
{
"scalar_uadd"
,
kPrimTypeOneArg
},
{
"scalar_uadd"
,
kPrimTypeOneArg
},
{
"scalar_usub"
,
kPrimTypeOneArg
},
{
"scalar_usub"
,
kPrimTypeOneArg
},
// TWO_ARGS prim
// TWO_ARGS prim
{
"scalar_add"
,
kPrimTypeTwoArgs
},
{
"scalar_add"
,
kPrimTypeTwoArgs
},
{
"bool_and"
,
kPrimTypeTwoArgs
},
{
"bool_and"
,
kPrimTypeTwoArgs
},
{
"bool_eq"
,
kPrimTypeTwoArgs
},
{
"bool_eq"
,
kPrimTypeTwoArgs
},
{
"bool_or"
,
kPrimTypeTwoArgs
},
{
"bool_or"
,
kPrimTypeTwoArgs
},
{
"scalar_div"
,
kPrimTypeTwoArgs
},
{
"scalar_div"
,
kPrimTypeTwoArgs
},
{
"scalar_eq"
,
kPrimTypeTwoArgs
},
{
"scalar_eq"
,
kPrimTypeTwoArgs
},
{
"scalar_ge"
,
kPrimTypeTwoArgs
},
{
"scalar_ge"
,
kPrimTypeTwoArgs
},
{
"scalar_gt"
,
kPrimTypeTwoArgs
},
{
"scalar_gt"
,
kPrimTypeTwoArgs
},
{
"scalar_le"
,
kPrimTypeTwoArgs
},
{
"scalar_le"
,
kPrimTypeTwoArgs
},
{
"scalar_lt"
,
kPrimTypeTwoArgs
},
{
"scalar_lt"
,
kPrimTypeTwoArgs
},
{
"scalar_ne"
,
kPrimTypeTwoArgs
},
{
"scalar_ne"
,
kPrimTypeTwoArgs
},
{
"scalar_mod"
,
kPrimTypeTwoArgs
},
{
"scalar_mod"
,
kPrimTypeTwoArgs
},
{
"scalar_mul"
,
kPrimTypeTwoArgs
},
{
"scalar_mul"
,
kPrimTypeTwoArgs
},
{
"scalar_pow"
,
kPrimTypeTwoArgs
},
{
"scalar_pow"
,
kPrimTypeTwoArgs
},
{
"scalar_sub"
,
kPrimTypeTwoArgs
},
{
"scalar_sub"
,
kPrimTypeTwoArgs
},
{
"scalar_floordiv"
,
kPrimTypeTwoArgs
}})
{}
})
{}
bool
PrimToFunction
::
GetFunction
(
const
PrimitivePtr
&
prim
,
FunctionPtr
*
const
func
)
const
{
bool
PrimToFunction
::
GetFunction
(
const
PrimitivePtr
&
prim
,
FunctionPtr
*
const
func
)
const
{
bool
result
=
false
;
bool
result
=
false
;
...
...
mindspore/ccsrc/pipeline/static_analysis/prim.cc
浏览文件 @
80978cf3
...
@@ -52,6 +52,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
...
@@ -52,6 +52,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{
prim
::
kPrimSwitch
,
{
InferImplSwitch
,
true
}},
{
prim
::
kPrimSwitch
,
{
InferImplSwitch
,
true
}},
{
prim
::
kPrimIs_
,
{
InferImplIs_
,
true
}},
{
prim
::
kPrimIs_
,
{
InferImplIs_
,
true
}},
{
prim
::
kPrimIsNot
,
{
InferImplIsNot
,
true
}},
{
prim
::
kPrimIsNot
,
{
InferImplIsNot
,
true
}},
{
prim
::
kPrimInDict
,
{
InferImplInDict
,
true
}},
{
prim
::
kPrimNotInDict
,
{
InferImplNotInDict
,
true
}},
// Maths
// Maths
{
prim
::
kPrimMaximumGrad
,
{
InferImplMinOrMaxGrad
,
true
}},
{
prim
::
kPrimMaximumGrad
,
{
InferImplMinOrMaxGrad
,
true
}},
{
prim
::
kPrimMinimumGrad
,
{
InferImplMinOrMaxGrad
,
true
}},
{
prim
::
kPrimMinimumGrad
,
{
InferImplMinOrMaxGrad
,
true
}},
...
@@ -91,6 +93,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
...
@@ -91,6 +93,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{
prim
::
kPrimMakeRange
,
{
InferImplMakeRange
,
false
}},
{
prim
::
kPrimMakeRange
,
{
InferImplMakeRange
,
false
}},
{
prim
::
kPrimStopGradient
,
{
InferImplStopGradient
,
false
}},
{
prim
::
kPrimStopGradient
,
{
InferImplStopGradient
,
false
}},
{
prim
::
kPrimStringEqual
,
{
InferImplStringEqual
,
false
}},
{
prim
::
kPrimStringEqual
,
{
InferImplStringEqual
,
false
}},
{
prim
::
kPrimStringConcat
,
{
InferImplStringConcat
,
false
}},
{
prim
::
kPrimDictLen
,
{
InferImplDictLen
,
false
}},
{
prim
::
kPrimDictLen
,
{
InferImplDictLen
,
false
}},
// NN
// NN
{
prim
::
kPrimPooling
,
{
InferImplPooling
,
true
}},
{
prim
::
kPrimPooling
,
{
InferImplPooling
,
true
}},
...
@@ -988,6 +991,8 @@ PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
...
@@ -988,6 +991,8 @@ PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
{
prim
::
kPrimScalarMul
,
{
prim
::
ScalarMul
,
true
,
nullptr
,
true
}},
{
prim
::
kPrimScalarMul
,
{
prim
::
ScalarMul
,
true
,
nullptr
,
true
}},
{
prim
::
kPrimScalarDiv
,
{
prim
::
ScalarDiv
,
true
,
nullptr
,
true
}},
{
prim
::
kPrimScalarDiv
,
{
prim
::
ScalarDiv
,
true
,
nullptr
,
true
}},
{
prim
::
kPrimScalarMod
,
{
prim
::
ScalarMod
,
true
,
nullptr
,
true
}},
{
prim
::
kPrimScalarMod
,
{
prim
::
ScalarMod
,
true
,
nullptr
,
true
}},
{
prim
::
kPrimScalarPow
,
{
prim
::
ScalarPow
,
true
,
nullptr
,
true
}},
{
prim
::
kPrimScalarFloordiv
,
{
prim
::
ScalarFloordiv
,
true
,
nullptr
,
true
}},
{
prim
::
kPrimScalarUadd
,
{
prim
::
ScalarUAdd
,
true
,
nullptr
,
true
}},
{
prim
::
kPrimScalarUadd
,
{
prim
::
ScalarUAdd
,
true
,
nullptr
,
true
}},
{
prim
::
kPrimScalarUsub
,
{
prim
::
ScalarUSub
,
true
,
nullptr
,
true
}},
{
prim
::
kPrimScalarUsub
,
{
prim
::
ScalarUSub
,
true
,
nullptr
,
true
}},
{
prim
::
kPrimScalarLog
,
{
prim
::
ScalarLog
,
true
,
nullptr
,
true
}},
{
prim
::
kPrimScalarLog
,
{
prim
::
ScalarLog
,
true
,
nullptr
,
true
}},
...
...
mindspore/ccsrc/pipeline/static_analysis/prim.h
浏览文件 @
80978cf3
...
@@ -178,6 +178,10 @@ AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
...
@@ -178,6 +178,10 @@ AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
const
AbstractBasePtrList
&
args_spec_list
);
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplIsNot
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
AbstractBasePtr
InferImplIsNot
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplInDict
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplNotInDict
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplPooling
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
AbstractBasePtr
InferImplPooling
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplPoolingGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
AbstractBasePtr
InferImplPoolingGrad
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
...
@@ -287,6 +291,8 @@ AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const Primitive
...
@@ -287,6 +291,8 @@ AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const Primitive
const
AbstractBasePtrList
&
args_spec_list
);
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplStringEqual
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
AbstractBasePtr
InferImplStringEqual
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplStringConcat
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
AbstractBasePtr
InferImplDictLen
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
AbstractBasePtr
InferImplDictLen
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
);
const
AbstractBasePtrList
&
args_spec_list
);
...
...
mindspore/ops/composite/multitype_ops/__init__.py
浏览文件 @
80978cf3
...
@@ -19,6 +19,9 @@ from .add_impl import add
...
@@ -19,6 +19,9 @@ from .add_impl import add
from
.sub_impl
import
sub
from
.sub_impl
import
sub
from
.mul_impl
import
mul
from
.mul_impl
import
mul
from
.div_impl
import
div
from
.div_impl
import
div
from
.pow_impl
import
pow_
from
.floordiv_impl
import
floordiv
from
.mod_impl
import
mod
from
.getitem_impl
import
getitem
from
.getitem_impl
import
getitem
from
.zeros_like_impl
import
zeros_like
from
.zeros_like_impl
import
zeros_like
from
.ones_like_impl
import
ones_like
from
.ones_like_impl
import
ones_like
...
@@ -38,6 +41,9 @@ __all__ = [
...
@@ -38,6 +41,9 @@ __all__ = [
'sub'
,
'sub'
,
'mul'
,
'mul'
,
'div'
,
'div'
,
'pow_'
,
'floordiv'
,
'mod'
,
'uadd'
,
'uadd'
,
'zeros_like'
,
'zeros_like'
,
'ones_like'
,
'ones_like'
,
...
...
mindspore/ops/composite/multitype_ops/add_impl.py
浏览文件 @
80978cf3
...
@@ -69,6 +69,21 @@ def _scalar_add_scalar(x, y):
...
@@ -69,6 +69,21 @@ def _scalar_add_scalar(x, y):
return
F
.
scalar_add
(
x
,
y
)
return
F
.
scalar_add
(
x
,
y
)
@
add
.
register
(
"String"
,
"String"
)
def
_string_concat_string
(
x
,
y
):
"""
Concatenate the string y to the string x.
Args:
x (str): The first input string.
y (str): the second input string.
Returns:
str, concatenate the y to the x.
"""
return
F
.
string_concat
(
x
,
y
)
@
add
.
register
(
"Number"
,
"Tensor"
)
@
add
.
register
(
"Number"
,
"Tensor"
)
def
_scalar_add_tensor
(
x
,
y
):
def
_scalar_add_tensor
(
x
,
y
):
"""
"""
...
@@ -81,8 +96,7 @@ def _scalar_add_tensor(x, y):
...
@@ -81,8 +96,7 @@ def _scalar_add_tensor(x, y):
Returns:
Returns:
Tensor, has the same dtype as x.
Tensor, has the same dtype as x.
"""
"""
z
=
F
.
scalar_to_tensor
(
x
,
F
.
dtype
(
y
))
return
F
.
tensor_add
(
x
,
y
)
return
F
.
tensor_add
(
z
,
y
)
@
add
.
register
(
"Tensor"
,
"Number"
)
@
add
.
register
(
"Tensor"
,
"Number"
)
...
@@ -97,8 +111,7 @@ def _tensor_add_scalar(x, y):
...
@@ -97,8 +111,7 @@ def _tensor_add_scalar(x, y):
Returns:
Returns:
Tensor, has the same dtype as x.
Tensor, has the same dtype as x.
"""
"""
z
=
F
.
scalar_to_tensor
(
y
,
F
.
dtype
(
x
))
return
F
.
tensor_add
(
x
,
y
)
return
F
.
tensor_add
(
x
,
z
)
@
add
.
register
(
"Tensor"
,
"Tensor"
)
@
add
.
register
(
"Tensor"
,
"Tensor"
)
...
...
mindspore/ops/composite/multitype_ops/div_impl.py
浏览文件 @
80978cf3
...
@@ -68,8 +68,7 @@ def _scalar_div_tensor(x, y):
...
@@ -68,8 +68,7 @@ def _scalar_div_tensor(x, y):
Returns:
Returns:
Tensor, has the same dtype as x.
Tensor, has the same dtype as x.
"""
"""
z
=
F
.
scalar_to_tensor
(
x
,
F
.
dtype
(
y
))
return
F
.
tensor_div
(
x
,
y
)
return
F
.
tensor_div
(
z
,
y
)
@
div
.
register
(
"Tensor"
,
"Number"
)
@
div
.
register
(
"Tensor"
,
"Number"
)
...
@@ -84,5 +83,4 @@ def _tensor_div_scalar(x, y):
...
@@ -84,5 +83,4 @@ def _tensor_div_scalar(x, y):
Returns:
Returns:
Tensor, has the same dtype as x.
Tensor, has the same dtype as x.
"""
"""
z
=
F
.
scalar_to_tensor
(
y
,
F
.
dtype
(
x
))
return
F
.
tensor_div
(
x
,
y
)
return
F
.
tensor_div
(
x
,
z
)
mindspore/ops/composite/multitype_ops/floordiv_impl.py
0 → 100644
浏览文件 @
80978cf3
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implementation for internal polymorphism `floordiv` operations."""
from
...composite
import
base
from
...
import
functional
as
F
floordiv
=
base
.
MultitypeFuncGraph
(
"floordiv"
)
"""
`floordiv` is a metafuncgraph object which will compute the floordiv of two objects
using ".register" decorator.
"""
@
floordiv
.
register
(
"Number"
,
"Number"
)
def
_floordiv_scalar
(
x
,
y
):
"""Returns x // y where x and y are all scalars."""
return
F
.
scalar_floordiv
(
x
,
y
)
@
floordiv
.
register
(
"Tensor"
,
"Tensor"
)
def
_floordiv_tensor
(
x
,
y
):
"""Returns x // y where x and y are all tensors and have save dtype."""
return
F
.
tensor_floordiv
(
x
,
y
)
@
floordiv
.
register
(
"Tensor"
,
"Number"
)
def
_tensor_floordiv_scalar
(
x
,
y
):
"""Returns x // y where x is a tensor and y is a scalar. x and y should have same dtype."""
return
F
.
tensor_floordiv
(
x
,
y
)
@
floordiv
.
register
(
"Number"
,
"Tensor"
)
def
_scalar_floordiv_tensor
(
x
,
y
):
"""Returns x // y where x is a scalar and y is a tensor. x and y should have same dtype."""
return
F
.
tensor_floordiv
(
x
,
y
)
mindspore/ops/composite/multitype_ops/mod_impl.py
0 → 100644
浏览文件 @
80978cf3
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implementation for internal polymorphism `mod` operations."""
from
...composite
import
base
from
...
import
functional
as
F
mod
=
base
.
MultitypeFuncGraph
(
"mod"
)
"""
`mod` is a metafuncgraph object which will compute the mod of two objects
using ".register" decorator.
"""
@
mod
.
register
(
"Number"
,
"Number"
)
def
_mod_scalar
(
x
,
y
):
"""Returns x % y where x and y are all scalars."""
return
F
.
scalar_mod
(
x
,
y
)
@
mod
.
register
(
"Tensor"
,
"Tensor"
)
def
_mod_tensor
(
x
,
y
):
"""Returns x % y where x and y are all tensors and have save dtype."""
return
F
.
tensor_mod
(
x
,
y
)
@
mod
.
register
(
"Tensor"
,
"Number"
)
def
_tensor_mod_scalar
(
x
,
y
):
"""Returns x % y where x is a tensor and y is a scalar. x and y should have same dtype."""
return
F
.
tensor_mod
(
x
,
y
)
@
mod
.
register
(
"Number"
,
"Tensor"
)
def
_scalar_mod_tensor
(
x
,
y
):
"""Returns x % y where x is a scalar and y is a tensor. x and y should have same dtype."""
return
F
.
tensor_mod
(
x
,
y
)
mindspore/ops/composite/multitype_ops/mul_impl.py
浏览文件 @
80978cf3
...
@@ -56,8 +56,7 @@ def _scalar_mul_tensor(x, y):
...
@@ -56,8 +56,7 @@ def _scalar_mul_tensor(x, y):
Outputs:
Outputs:
Tensor, has the same dtype as x.
Tensor, has the same dtype as x.
"""
"""
z
=
F
.
scalar_to_tensor
(
x
,
F
.
dtype
(
y
))
return
F
.
tensor_mul
(
x
,
y
)
return
F
.
tensor_mul
(
z
,
y
)
@
mul
.
register
(
"Tensor"
,
"Number"
)
@
mul
.
register
(
"Tensor"
,
"Number"
)
...
@@ -68,5 +67,4 @@ def _tensor_mul_scalar(x, y):
...
@@ -68,5 +67,4 @@ def _tensor_mul_scalar(x, y):
Outputs:
Outputs:
Tensor, has the same dtype as x.
Tensor, has the same dtype as x.
"""
"""
z
=
F
.
scalar_to_tensor
(
y
,
F
.
dtype
(
x
))
return
F
.
tensor_mul
(
x
,
y
)
return
F
.
tensor_mul
(
x
,
z
)
mindspore/ops/composite/multitype_ops/pow_impl.py
0 → 100644
浏览文件 @
80978cf3
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Implementation for internal polymorphism `pow` operations."""
from
...composite
import
base
from
...
import
functional
as
F
pow_
=
base
.
MultitypeFuncGraph
(
"pow"
)
"""
`pow` is a metafuncgraph object which will compute the pow of two objects
using ".register" decorator.
"""
@
pow_
.
register
(
"Number"
,
"Number"
)
def
_pow_scalar
(
x
,
y
):
"""Returns x ** y where x and y are all scalars."""
return
F
.
scalar_pow
(
x
,
y
)
@
pow_
.
register
(
"Tensor"
,
"Tensor"
)
def
_pow_tensor
(
x
,
y
):
"""Returns x ** y where x and y are all tensors and have save dtype."""
return
F
.
tensor_pow
(
x
,
y
)
@
pow_
.
register
(
"Tensor"
,
"Number"
)
def
_tensor_pow_scalar
(
x
,
y
):
"""Returns x ** y where x is a tensor and y is a scalar. x and y should have same dtype."""
return
F
.
tensor_pow
(
x
,
y
)
@
pow_
.
register
(
"Number"
,
"Tensor"
)
def
_scalar_pow_tensor
(
x
,
y
):
"""Returns x ** y where x is a scalar and y is a tensor. x and y should have same dtype."""
return
F
.
tensor_pow
(
x
,
y
)
mindspore/ops/composite/multitype_ops/sub_impl.py
浏览文件 @
80978cf3
...
@@ -41,12 +41,10 @@ def _sub_tensor(x, y):
...
@@ -41,12 +41,10 @@ def _sub_tensor(x, y):
@
sub
.
register
(
"Number"
,
"Tensor"
)
@
sub
.
register
(
"Number"
,
"Tensor"
)
def
_scalar_sub_tensor
(
x
,
y
):
def
_scalar_sub_tensor
(
x
,
y
):
"""Returns x - y where x is a scalar and y is a tensor. x and y should have same dtype."""
"""Returns x - y where x is a scalar and y is a tensor. x and y should have same dtype."""
z
=
F
.
scalar_to_tensor
(
x
,
F
.
dtype
(
y
))
return
F
.
tensor_sub
(
x
,
y
)
return
F
.
tensor_sub
(
z
,
y
)
@
sub
.
register
(
"Tensor"
,
"Number"
)
@
sub
.
register
(
"Tensor"
,
"Number"
)
def
_tensor_sub_scalar
(
x
,
y
):
def
_tensor_sub_scalar
(
x
,
y
):
"""Returns x - y where x is a tensor and y is a scalar. x and y should have same dtype."""
"""Returns x - y where x is a tensor and y is a scalar. x and y should have same dtype."""
z
=
F
.
scalar_to_tensor
(
y
,
F
.
dtype
(
x
))
return
F
.
tensor_sub
(
x
,
y
)
return
F
.
tensor_sub
(
x
,
z
)
mindspore/ops/functional.py
浏览文件 @
80978cf3
...
@@ -48,6 +48,9 @@ tensor_ge = P.GreaterEqual()
...
@@ -48,6 +48,9 @@ tensor_ge = P.GreaterEqual()
tensor_sub
=
P
.
Sub
()
tensor_sub
=
P
.
Sub
()
tensor_mul
=
P
.
Mul
()
tensor_mul
=
P
.
Mul
()
tensor_div
=
P
.
RealDiv
()
tensor_div
=
P
.
RealDiv
()
tensor_floordiv
=
P
.
FloorDiv
()
tensor_pow
=
P
.
Pow
()
tensor_mod
=
P
.
FloorMod
()
strided_slice
=
P
.
StridedSlice
()
strided_slice
=
P
.
StridedSlice
()
same_type_shape
=
P
.
SameTypeShape
()
same_type_shape
=
P
.
SameTypeShape
()
equal
=
P
.
Equal
()
equal
=
P
.
Equal
()
...
@@ -83,6 +86,7 @@ scalar_add = Primitive('scalar_add')
...
@@ -83,6 +86,7 @@ scalar_add = Primitive('scalar_add')
scalar_mul
=
Primitive
(
'scalar_mul'
)
scalar_mul
=
Primitive
(
'scalar_mul'
)
scalar_sub
=
Primitive
(
'scalar_sub'
)
scalar_sub
=
Primitive
(
'scalar_sub'
)
scalar_div
=
Primitive
(
'scalar_div'
)
scalar_div
=
Primitive
(
'scalar_div'
)
scalar_floordiv
=
Primitive
(
'scalar_floordiv'
)
scalar_log
=
Primitive
(
'scalar_log'
)
scalar_log
=
Primitive
(
'scalar_log'
)
scalar_pow
=
Primitive
(
'scalar_pow'
)
scalar_pow
=
Primitive
(
'scalar_pow'
)
scalar_gt
=
Primitive
(
'scalar_gt'
)
scalar_gt
=
Primitive
(
'scalar_gt'
)
...
@@ -95,6 +99,7 @@ scalar_uadd = Primitive('scalar_uadd')
...
@@ -95,6 +99,7 @@ scalar_uadd = Primitive('scalar_uadd')
scalar_usub
=
Primitive
(
'scalar_usub'
)
scalar_usub
=
Primitive
(
'scalar_usub'
)
scalar_mod
=
Primitive
(
'scalar_mod'
)
scalar_mod
=
Primitive
(
'scalar_mod'
)
string_eq
=
Primitive
(
'string_equal'
)
string_eq
=
Primitive
(
'string_equal'
)
string_concat
=
Primitive
(
'string_concat'
)
bool_not
=
Primitive
(
"bool_not"
)
bool_not
=
Primitive
(
"bool_not"
)
bool_or
=
Primitive
(
"bool_or"
)
bool_or
=
Primitive
(
"bool_or"
)
bool_and
=
Primitive
(
"bool_and"
)
bool_and
=
Primitive
(
"bool_and"
)
...
@@ -104,7 +109,8 @@ logical_not = P.LogicalNot()
...
@@ -104,7 +109,8 @@ logical_not = P.LogicalNot()
array_to_scalar
=
Primitive
(
'array_to_scalar'
)
array_to_scalar
=
Primitive
(
'array_to_scalar'
)
is_
=
Primitive
(
"is_"
)
is_
=
Primitive
(
"is_"
)
is_not
=
Primitive
(
"is_not"
)
is_not
=
Primitive
(
"is_not"
)
in_dict
=
Primitive
(
"in_dict"
)
not_in_dict
=
Primitive
(
"not_in_dict"
)
broadcast_gradient_args
=
Primitive
(
'BroadcastGradientArgs'
)
broadcast_gradient_args
=
Primitive
(
'BroadcastGradientArgs'
)
dot
=
Primitive
(
'dot'
)
dot
=
Primitive
(
'dot'
)
array_reduce
=
Primitive
(
'array_reduce'
)
array_reduce
=
Primitive
(
'array_reduce'
)
...
...
mindspore/ops/operations/math_ops.py
浏览文件 @
80978cf3
...
@@ -667,8 +667,8 @@ class AddN(PrimitiveWithInfer):
...
@@ -667,8 +667,8 @@ class AddN(PrimitiveWithInfer):
>>> return self.addN(z)
>>> return self.addN(z)
>>>
>>>
>>> net = NetAddN()
>>> net = NetAddN()
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.
in
t32)
>>> input_x = Tensor(np.array([1, 2, 3]), mindspore.
floa
t32)
>>> input_y = Tensor(np.array([4, 5, 6]), mindspore.
in
t32)
>>> input_y = Tensor(np.array([4, 5, 6]), mindspore.
floa
t32)
>>> net(input_x, input_y, input_x, input_y)
>>> net(input_x, input_y, input_x, input_y)
Tensor([10, 14, 18], shape=(3,), dtype=mindspore.int32)
Tensor([10, 14, 18], shape=(3,), dtype=mindspore.int32)
"""
"""
...
...
tests/ut/python/pipeline/parse/test_operator.py
浏览文件 @
80978cf3
...
@@ -131,3 +131,72 @@ def test_ME_arithmetic_operator_0070():
...
@@ -131,3 +131,72 @@ def test_ME_arithmetic_operator_0070():
def
test_ME_logical_operator_0020
():
def
test_ME_logical_operator_0020
():
""" test_ME_logical_operator_0020 """
""" test_ME_logical_operator_0020 """
logical_operator_base
(
'or'
)
logical_operator_base
(
'or'
)
def
test_ops
():
class
OpsNet
(
Cell
):
""" OpsNet definition """
def
__init__
(
self
,
x
,
y
):
super
(
OpsNet
,
self
).
__init__
()
self
.
x
=
x
self
.
y
=
y
self
.
int
=
4
self
.
float
=
3.2
self
.
str_a
=
"hello"
self
.
str_b
=
"world"
def
construct
(
self
,
x
,
y
):
h
=
x
//
y
m
=
x
**
y
n
=
x
%
y
r
=
self
.
x
//
self
.
y
s
=
self
.
x
**
self
.
y
t
=
self
.
x
%
self
.
y
p
=
h
+
m
+
n
q
=
r
+
s
+
t
ret_pow
=
p
**
q
+
q
**
p
ret_mod
=
p
%
q
+
q
%
p
ret_floor
=
p
//
q
+
q
//
p
ret
=
ret_pow
+
ret_mod
+
ret_floor
if
self
.
int
>
self
.
float
:
if
self
.
str_a
+
self
.
str_b
==
"helloworld"
:
return
ret
return
x
net
=
OpsNet
(
9
,
2
)
x
=
Tensor
(
np
.
random
.
randint
(
low
=
1
,
high
=
10
,
size
=
(
2
,
3
,
4
),
dtype
=
np
.
int32
))
y
=
Tensor
(
np
.
random
.
randint
(
low
=
10
,
high
=
20
,
size
=
(
2
,
3
,
4
),
dtype
=
np
.
int32
))
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
net
(
x
,
y
)
def
test_in_dict
():
class
InDictNet
(
Cell
):
""" InDictNet definition """
def
__init__
(
self
,
key_in
,
key_not_in
):
super
(
InDictNet
,
self
).
__init__
()
self
.
key_in
=
key_in
self
.
key_not_in
=
key_not_in
def
construct
(
self
,
x
,
y
,
z
):
d
=
{
"a"
:
x
,
"b"
:
y
}
ret_in
=
1
ret_not_in
=
2
if
self
.
key_in
in
d
:
ret_in
=
d
[
self
.
key_in
]
if
self
.
key_not_in
not
in
d
:
ret_not_in
=
z
ret
=
ret_in
+
ret_not_in
return
ret
net
=
InDictNet
(
"a"
,
"c"
)
x
=
Tensor
(
np
.
random
.
randint
(
low
=
1
,
high
=
10
,
size
=
(
2
,
3
,
4
),
dtype
=
np
.
int32
))
y
=
Tensor
(
np
.
random
.
randint
(
low
=
10
,
high
=
20
,
size
=
(
2
,
3
,
4
),
dtype
=
np
.
int32
))
z
=
Tensor
(
np
.
random
.
randint
(
low
=
20
,
high
=
30
,
size
=
(
2
,
3
,
4
),
dtype
=
np
.
int32
))
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
net
(
x
,
y
,
z
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录