Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7e4d972f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7e4d972f
编写于
6月 01, 2020
作者:
W
Wei Luning
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug in do signature
上级
58464118
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
153 addition
and
59 deletion
+153
-59
.gitignore
.gitignore
+1
-0
mindspore/ccsrc/kernel/common_utils.cc
mindspore/ccsrc/kernel/common_utils.cc
+1
-1
mindspore/ccsrc/operator/composite/do_signature.cc
mindspore/ccsrc/operator/composite/do_signature.cc
+27
-24
mindspore/ccsrc/operator/prim_others.cc
mindspore/ccsrc/operator/prim_others.cc
+1
-1
mindspore/ccsrc/optimizer/irpass.cc
mindspore/ccsrc/optimizer/irpass.cc
+2
-2
mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
+5
-0
mindspore/ccsrc/pipeline/parse/function_block.cc
mindspore/ccsrc/pipeline/parse/function_block.cc
+3
-3
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
+2
-2
mindspore/ccsrc/pipeline/static_analysis/prim.cc
mindspore/ccsrc/pipeline/static_analysis/prim.cc
+1
-1
mindspore/common/dtype.py
mindspore/common/dtype.py
+1
-1
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+15
-5
mindspore/ops/operations/math_ops.py
mindspore/ops/operations/math_ops.py
+2
-2
mindspore/ops/operations/other_ops.py
mindspore/ops/operations/other_ops.py
+2
-0
tests/ut/python/ops/test_layer_switch.py
tests/ut/python/ops/test_layer_switch.py
+15
-0
tests/ut/python/ops/test_nn_ops.py
tests/ut/python/ops/test_nn_ops.py
+0
-17
tests/ut/python/ops/test_signature.py
tests/ut/python/ops/test_signature.py
+75
-0
未找到文件。
.gitignore
浏览文件 @
7e4d972f
...
...
@@ -65,6 +65,7 @@ test_temp_summary_event_file/
*.ckpt
*.shp
*.pkl
*.pb
.clangd
mindspore/version.py
mindspore/default_config.py
...
...
mindspore/ccsrc/kernel/common_utils.cc
浏览文件 @
7e4d972f
...
...
@@ -253,7 +253,7 @@ std::string Dtype2String(const std::string &dtypes) {
std
::
string
TypeId2String
(
TypeId
type_id
)
{
auto
iter
=
type_id_str_map
.
find
(
type_id
);
if
(
iter
==
type_id_str_map
.
end
())
{
MS_EXCEPTION
(
ArgumentError
)
<<
"Illegal input dtype."
<<
TypeIdLabel
(
type_id
);
return
std
::
string
(
TypeIdLabel
(
type_id
)
);
}
return
iter
->
second
;
}
...
...
mindspore/ccsrc/operator/composite/do_signature.cc
浏览文件 @
7e4d972f
...
...
@@ -47,16 +47,6 @@ const std::vector<Signature> &GetSignature(const ValuePtr &function) {
return
empty
;
}
const
std
::
string
GetOpName
(
const
ValuePtr
&
function
)
{
std
::
string
name
=
""
;
if
(
function
->
isa
<
Primitive
>
())
{
name
=
function
->
cast
<
PrimitivePyPtr
>
()
->
name
();
}
else
if
(
function
->
isa
<
MetaFuncGraph
>
())
{
name
=
function
->
cast
<
MetaFuncGraphPtr
>
()
->
name
();
}
return
name
;
}
void
ProcessDefault
(
const
std
::
string
&
func_name
,
const
AbstractBasePtrList
&
args_spec_list
,
const
std
::
vector
<
Signature
>
&
signature
,
bool
has_var
,
std
::
vector
<
AnfNodePtr
>
*
const
op_inputs
)
{
std
::
size_t
sig_size
=
signature
.
size
();
...
...
@@ -93,7 +83,8 @@ void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number,
*
max_type_number
=
type_number
;
}
TypeId
GetMaxTypeId
(
const
abstract
::
AbstractBasePtrList
&
args_spec_list
,
std
::
vector
<
size_t
>
indexs
)
{
TypeId
GetMaxTypeId
(
const
abstract
::
AbstractBasePtrList
&
args_spec_list
,
std
::
vector
<
size_t
>
indexs
,
const
std
::
set
<
size_t
>
&
write_indexs
)
{
TypeId
max_type_id
=
kTypeUnknown
;
TypeId
max_type
=
kTypeUnknown
;
size_t
max_type_number
=
0
;
...
...
@@ -103,7 +94,12 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
TypeId
arg_type
=
kTypeUnknown
;
AbstractBasePtr
arg_value
=
args_spec_list
[
index
];
if
(
arg_value
->
isa
<
abstract
::
AbstractRef
>
())
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref
();
auto
is_write
=
(
write_indexs
.
find
(
index
)
!=
write_indexs
.
end
());
if
(
is_write
)
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref_origin
();
}
else
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref
();
}
}
if
(
arg_value
->
isa
<
abstract
::
AbstractTensor
>
())
{
auto
tensor
=
arg_value
->
cast
<
abstract
::
AbstractTensorPtr
>
();
...
...
@@ -157,7 +153,8 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve
// Get the largest type of index in the same SignatureEnumDType of arguments.
std
::
map
<
SignatureEnumDType
,
TypeId
>
GetMaxDtype
(
const
std
::
vector
<
SignatureEnumDType
>
&
dtypes
,
const
abstract
::
AbstractBasePtrList
&
args_spec_list
)
{
const
abstract
::
AbstractBasePtrList
&
args_spec_list
,
const
std
::
set
<
size_t
>
&
write_indexs
)
{
// record index for signature.dtypes of the same type
// eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}}
std
::
map
<
SignatureEnumDType
,
std
::
vector
<
size_t
>>
type_indexs
;
...
...
@@ -192,7 +189,7 @@ std::map<SignatureEnumDType, TypeId> GetMaxDtype(const std::vector<SignatureEnum
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
kTypeUnknown
));
continue
;
}
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
GetMaxTypeId
(
args_spec_list
,
indexs
)));
(
void
)
dst_type
.
insert
(
std
::
make_pair
(
type
,
GetMaxTypeId
(
args_spec_list
,
indexs
,
write_indexs
)));
}
return
dst_type
;
}
...
...
@@ -205,9 +202,9 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap
return
NewCNode
({
cast_node
,
param
,
dtype_node
},
graph
);
}
void
DoAutoCast
(
const
std
::
vector
<
Signature
>
&
signature
,
const
abstract
::
AbstractBasePtrList
&
args_spec_list
,
const
FuncGraphPtr
&
graph
,
std
::
vector
<
AnfNodePtr
>
*
const
op_inputs
,
const
std
::
set
<
size_t
>
&
write_indexs
)
{
void
DoAutoCast
(
const
std
::
string
&
func_name
,
const
std
::
vector
<
Signature
>
&
signature
,
const
abstract
::
AbstractBasePtrList
&
args_spec_list
,
const
FuncGraphPtr
&
graph
,
std
::
vector
<
AnfNodePtr
>
*
const
op_inputs
,
const
std
::
set
<
size_t
>
&
write_indexs
)
{
std
::
vector
<
SignatureEnumDType
>
dtypes
;
(
void
)
std
::
transform
(
signature
.
begin
(),
signature
.
end
(),
std
::
back_inserter
(
dtypes
),
[](
const
Signature
&
sig
)
{
return
sig
.
dtype
;
});
...
...
@@ -216,16 +213,23 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
return
;
}
// Stat the index of the arguments with the largest type in the same SignatureEnumDType.
std
::
map
<
SignatureEnumDType
,
TypeId
>
dst_type
=
GetMaxDtype
(
dtypes
,
args_spec_list
);
std
::
map
<
SignatureEnumDType
,
TypeId
>
dst_type
=
GetMaxDtype
(
dtypes
,
args_spec_list
,
write_indexs
);
// Identify which arg requires auto cast
for
(
size_t
i
=
0
;
i
<
args_spec_list
.
size
();
++
i
)
{
auto
it
=
dst_type
.
find
(
dtypes
[
i
]);
if
(
it
==
dst_type
.
end
()
||
it
->
second
==
kTypeUnknown
)
{
continue
;
}
auto
rw_it
=
write_indexs
.
find
(
i
);
auto
is_write
=
(
rw_it
!=
write_indexs
.
end
());
AbstractBasePtr
arg_value
=
args_spec_list
[
i
];
if
(
arg_value
->
isa
<
abstract
::
AbstractRef
>
())
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref
();
if
(
is_write
)
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref_origin
();
}
else
{
arg_value
=
arg_value
->
cast
<
abstract
::
AbstractRefPtr
>
()
->
ref
();
}
}
TypeId
arg_type_id
=
kTypeUnknown
;
if
(
arg_value
->
isa
<
abstract
::
AbstractTensor
>
())
{
...
...
@@ -243,10 +247,9 @@ void DoAutoCast(const std::vector<Signature> &signature, const abstract::Abstrac
if
(
it_map
==
type_map
.
end
())
{
continue
;
}
auto
rw_it
=
write_indexs
.
find
(
i
);
if
(
rw_it
!=
write_indexs
.
end
())
{
if
(
is_write
)
{
if
(
arg_type_id
!=
it
->
second
)
{
MS_LOG
(
EXCEPTION
)
<<
"In op '"
<<
GetOpName
(
graph
)
<<
"', argument '"
<<
args_spec_list
[
i
]
MS_LOG
(
EXCEPTION
)
<<
"In op '"
<<
func_name
<<
"', argument '"
<<
args_spec_list
[
i
]
<<
"' can not cast type from '"
<<
TypeIdLabel
(
arg_type_id
)
<<
"' to '"
<<
TypeIdLabel
(
it
->
second
)
<<
"' automatically."
;
}
...
...
@@ -299,8 +302,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
if
(
sig
==
SignatureEnumRW
::
kRWRead
)
{
param
=
func_graph
->
NewCNode
({
NewValueNode
(
prim
::
kPrimGetRefValue
),
param
});
}
else
if
(
sig
==
SignatureEnumRW
::
kRWWrite
)
{
param
=
func_graph
->
NewCNode
({
NewValueNode
(
prim
::
kPrimGetRefOrigin
),
param
});
write_indexs
.
insert
(
i
);
param
=
func_graph
->
NewCNode
({
NewValueNode
(
prim
::
kPrimGetRefKey
),
param
});
}
// If sig is SignatureEnumRW::kRWRef, not do anything.
}
else
if
(
sig
==
SignatureEnumRW
::
kRWWrite
&&
type
->
type_id
()
!=
kObjectTypeRefKey
)
{
...
...
@@ -310,7 +313,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
}
// process default
ProcessDefault
(
func_name
,
args_spec_list
,
signature
,
has_var
,
&
op_inputs
);
DoAutoCast
(
signature
,
args_spec_list
,
func_graph
,
&
op_inputs
,
write_indexs
);
DoAutoCast
(
func_name
,
signature
,
args_spec_list
,
func_graph
,
&
op_inputs
,
write_indexs
);
return
func_graph
->
NewCNode
(
op_inputs
);
}
}
// namespace
...
...
mindspore/ccsrc/operator/prim_others.cc
浏览文件 @
7e4d972f
...
...
@@ -160,7 +160,7 @@ AbstractBasePtr InferImplGetRefOrigin(const AnalysisEnginePtr &, const Primitive
const
AbstractBasePtrList
&
args_spec_list
)
{
// arguments: value
if
(
args_spec_list
.
size
()
!=
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"get_ref_
value
requires 1 parameters, while the input size is "
<<
args_spec_list
.
size
()
MS_LOG
(
EXCEPTION
)
<<
"get_ref_
origin
requires 1 parameters, while the input size is "
<<
args_spec_list
.
size
()
<<
"."
;
}
TypePtr
type
=
args_spec_list
[
0
]
->
GetTypeTrack
();
...
...
mindspore/ccsrc/optimizer/irpass.cc
浏览文件 @
7e4d972f
...
...
@@ -81,8 +81,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Ref eliminate
make_ref_eliminate_
=
MakeSubstitution
(
MakeRefEliminater
(),
"make_ref_eliminate"
,
prim
::
kPrimMakeRef
);
get_make_ref_eliminate_
=
MakeSubstitution
(
GetMakeRefEliminater
(),
"get_make_ref_eliminate"
,
{
prim
::
kPrimGetRefKey
,
prim
::
kPrimGetRefValue
});
get_make_ref_eliminate_
=
MakeSubstitution
(
GetMakeRefEliminater
(),
"get_make_ref_eliminate"
,
{
prim
::
kPrimGetRefKey
,
prim
::
kPrimGetRefValue
,
prim
::
kPrimGetRefOrigin
});
replace_refkey_by_param_
=
MakeSubstitution
(
ReplaceRefkeyByParam
(),
"replace_refkey_by_param"
,
IsValueNode
<
RefKey
>
,
opt
::
FORCE_RENORM
);
...
...
mindspore/ccsrc/optimizer/irpass/ref_eliminate.h
浏览文件 @
7e4d972f
...
...
@@ -48,6 +48,7 @@ class MakeRefEliminater : public AnfVisitor {
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
class
GetMakeRefEliminater
:
public
AnfVisitor
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
override
{
...
...
@@ -71,6 +72,10 @@ class GetMakeRefEliminater : public AnfVisitor {
return
ref
->
input
(
2
);
}
if
(
cnode
->
IsApply
(
prim
::
kPrimGetRefOrigin
))
{
return
ref
->
input
(
3
);
}
return
nullptr
;
}
};
...
...
mindspore/ccsrc/pipeline/parse/function_block.cc
浏览文件 @
7e4d972f
...
...
@@ -315,7 +315,7 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
ValueNodePtr
make_tuple_op
=
NewValueNode
(
prim
::
kPrimMakeTuple
);
ValueNodePtr
depend_op
=
NewValueNode
(
prim
::
kPrimDepend
);
ValueNodePtr
get_ref
key_op
=
NewValueNode
(
prim
::
kPrimGetRefKey
);
ValueNodePtr
get_ref
_origin_op
=
NewValueNode
(
prim
::
kPrimGetRefOrigin
);
ValueNodePtr
stop_gradient_op
=
NewValueNode
(
prim
::
kPrimStopGradient
);
const
std
::
string
primitive_name
(
"assign"
);
const
std
::
string
module_name
(
"mindspore.ops.functional"
);
...
...
@@ -329,8 +329,8 @@ void FunctionBlock::InsertDependItemsBeforeReturn() {
vec_states
.
emplace_back
(
make_tuple_op
);
for
(
auto
&
item
:
state_assign_
)
{
auto
source
=
ReadVariable
(
item
.
second
);
auto
refkey
=
func_graph
()
->
NewCNode
({
get_refkey
_op
,
item
.
first
});
auto
assign
=
func_graph
()
->
NewCNode
({
assign_op
,
refkey
,
source
});
auto
origin
=
func_graph
()
->
NewCNode
({
get_ref_origin
_op
,
item
.
first
});
auto
assign
=
func_graph
()
->
NewCNode
({
assign_op
,
origin
,
source
});
MS_LOG
(
INFO
)
<<
"SetState read "
<<
item
.
first
->
ToString
()
<<
", "
<<
item
.
second
;
vec_states
.
emplace_back
(
assign
);
}
...
...
mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc
浏览文件 @
7e4d972f
...
...
@@ -801,8 +801,8 @@ bool AbstractRef::operator==(const AbstractBase &other) const {
std
::
string
AbstractRef
::
ToString
()
const
{
std
::
ostringstream
buffer
;
buffer
<<
type_name
()
<<
"("
<<
"key: "
<<
ref_key_
->
ToString
()
<<
"ref_value: "
<<
ref_
->
ToString
()
<<
"origin_value: "
<<
ref_origin_
->
ToString
();
<<
"key: "
<<
ref_key_
->
ToString
()
<<
"
ref_value: "
<<
ref_
->
ToString
()
<<
"
origin_value: "
<<
ref_origin_
->
ToString
();
auto
value
=
GetValueTrack
();
if
(
value
)
{
buffer
<<
", value: "
<<
value
->
ToString
();
...
...
mindspore/ccsrc/pipeline/static_analysis/prim.cc
浏览文件 @
7e4d972f
...
...
@@ -783,7 +783,7 @@ class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
AbstractBasePtr
abs
=
node_conf
->
GetEvaluatedValue
()
->
abstract
();
AbstractRefPtr
ref_abs
=
abs
->
cast
<
AbstractRefPtr
>
();
if
(
ref_abs
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"The first parameter of RefToEmbed should be Ref
."
;
MS_LOG
(
ERROR
)
<<
"The first parameter of RefToEmbed should be Ref
, but "
<<
abs
->
ToString
()
;
return
nullptr
;
}
auto
key_abs
=
ref_abs
->
ref_key
();
...
...
mindspore/common/dtype.py
浏览文件 @
7e4d972f
...
...
@@ -170,7 +170,7 @@ def get_py_obj_dtype(obj):
Type of MindSpore type.
"""
# Tensor
if
hasattr
(
obj
,
'dtype'
):
if
hasattr
(
obj
,
'dtype'
)
and
callable
(
obj
.
dtype
)
and
isinstance
(
obj
.
dtype
(),
typing
.
Type
)
:
return
tensor_type
(
obj
.
dtype
())
if
hasattr
(
obj
,
'__primitive_flag__'
)
or
hasattr
(
obj
,
'construct'
):
return
function
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
7e4d972f
...
...
@@ -31,7 +31,9 @@ from ...common.tensor import Tensor
from
..operations.math_ops
import
_infer_shape_reduce
from
.._utils
import
get_concat_offset
from
..primitive
import
Primitive
,
PrimitiveWithInfer
,
prim_attr_register
from
..._c_expression
import
signature_rw
as
sig_rw
from
..._c_expression
import
signature_kind
as
sig_kind
from
..._c_expression
import
signature_dtype
as
sig_dtype
def
_check_infer_attr_reduce
(
axis
,
keep_dims
,
prim_name
):
validator
.
check_value_type
(
'keep_dims'
,
keep_dims
,
[
bool
],
prim_name
)
...
...
@@ -2156,13 +2158,17 @@ class ScatterUpdate(PrimitiveWithInfer):
>>> input_x = mindspore.Parameter(Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32))
>>> indices = Tensor(np.array([[0, 0], [1, 1]]), mindspore.int32)
>>> update = Tensor(np.array([1.0, 2.2]), mindspore.float32)
>>> op = P.Scatter
Nd
Update()
>>> op = P.ScatterUpdate()
>>> output = op(input_x, indices, update)
"""
__mindspore_signature__
=
(
(
'x'
,
sig_rw
.
RW_WRITE
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T
),
(
'indices'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T1
),
(
'value'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T
)
)
@
prim_attr_register
def
__init__
(
self
,
use_locking
=
True
):
"""Init Scatter
Nd
Update"""
"""Init ScatterUpdate"""
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'indices'
,
'value'
],
outputs
=
[
'y'
])
def
infer_shape
(
self
,
x_shape
,
indices_shape
,
value_shape
):
...
...
@@ -2201,7 +2207,11 @@ class ScatterNdUpdate(PrimitiveWithInfer):
>>> op = P.ScatterNdUpdate()
>>> output = op(input_x, indices, update)
"""
__mindspore_signature__
=
(
(
'x'
,
sig_rw
.
RW_WRITE
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T
),
(
'indices'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T1
),
(
'value'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
,
sig_kind
.
KIND_EMPTY_DEFAULT_VALUE
,
sig_dtype
.
T
)
)
@
prim_attr_register
def
__init__
(
self
,
use_locking
=
True
):
"""Init ScatterNdUpdate"""
...
...
mindspore/ops/operations/math_ops.py
浏览文件 @
7e4d972f
...
...
@@ -179,7 +179,7 @@ class AssignAdd(PrimitiveWithInfer):
return
value
def
infer_dtype
(
self
,
variable
,
value
):
args
=
{
"value"
:
value
}
args
=
{
"va
riable"
:
variable
,
"va
lue"
:
value
}
validator
.
check_scalar_or_tensor_type_same
(
args
,
mstype
.
number_type
,
self
.
name
)
return
value
...
...
@@ -222,7 +222,7 @@ class AssignSub(PrimitiveWithInfer):
return
value
def
infer_dtype
(
self
,
variable
,
value
):
args
=
{
"value"
:
value
}
args
=
{
"va
riable"
:
variable
,
"va
lue"
:
value
}
validator
.
check_scalar_or_tensor_type_same
(
args
,
mstype
.
number_type
,
self
.
name
)
return
value
...
...
mindspore/ops/operations/other_ops.py
浏览文件 @
7e4d972f
...
...
@@ -58,6 +58,8 @@ class Assign(PrimitiveWithInfer):
return
variable
def
infer_dtype
(
self
,
variable
,
value
):
args
=
{
"variable"
:
variable
,
"value"
:
value
}
validator
.
check_tensor_type_same
(
args
,
(
mstype
.
bool_
,)
+
mstype
.
number_type
,
self
.
name
)
return
variable
...
...
tests/ut/python/ops/test_layer_switch.py
浏览文件 @
7e4d972f
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test layer switch"""
import
numpy
as
np
import
mindspore
...
...
tests/ut/python/ops/test_nn_ops.py
浏览文件 @
7e4d972f
...
...
@@ -345,19 +345,6 @@ class Conv2dNativeNet(nn.Cell):
return
self
.
flatten
(
self
.
conv
(
input_x
,
self
.
weight
))
class
MakeRefKeyNet
(
nn
.
Cell
):
""" MakeRefKeyNet definition """
def
__init__
(
self
):
super
(
MakeRefKeyNet
,
self
).
__init__
()
self
.
y
=
Parameter
(
Tensor
([
1.0
],
mindspore
.
float32
),
name
=
"y"
)
def
construct
(
self
,
x
):
key
=
P
.
MakeRefKey
(
"y"
)()
P
.
Assign
()(
key
,
x
)
return
x
class
StateNet
(
nn
.
Cell
):
""" StateTestTensor definition """
...
...
@@ -538,10 +525,6 @@ test_cases = [
'block'
:
Grad
(
NetWithLossClass
(
Conv2dNativeNet
())),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
1
,
3
,
16
,
16
],
np
.
float32
)),
Tensor
(
np
.
zeros
([
1
,
1764
],
np
.
float32
))],
}),
(
'MakeRefKey'
,
{
'block'
:
MakeRefKeyNet
(),
'desc_inputs'
:
[
Tensor
([
2.0
],
mindspore
.
float32
)],
}),
(
'StateTest'
,
{
'block'
:
StateNet
(),
'desc_inputs'
:
[
Tensor
(
np
.
ones
([
2
,
1
,
2
,
2
]).
astype
(
np
.
float32
))],
...
...
tests/ut/python/ops/test_signature.py
0 → 100644
浏览文件 @
7e4d972f
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
test assign sub
"""
import
numpy
as
np
import
pytest
import
mindspore.context
as
context
import
mindspore.nn
as
nn
import
mindspore.ops.operations
as
P
from
mindspore
import
Tensor
from
mindspore.common.initializer
import
initializer
from
mindspore.common.parameter
import
Parameter
import
mindspore
as
ms
class
AssignW
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
AssignW
,
self
).
__init__
()
self
.
assign
=
P
.
Assign
()
def
construct
(
self
,
x
,
w
):
self
.
assign
(
x
,
w
)
return
x
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
b
=
Parameter
(
initializer
(
'ones'
,
[
5
]),
name
=
'b'
)
self
.
assign
=
AssignW
()
def
construct
(
self
,
value
):
return
self
.
assign
(
self
.
b
,
value
)
def
test_assign_through_cell
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
net
=
Net
()
net
.
to_float
(
ms
.
float16
)
net
.
add_flags_recursive
(
fp16
=
False
)
input_data
=
Tensor
(
np
.
ones
([
5
]).
astype
(
np
.
float32
))
net
(
input_data
)
with
pytest
.
raises
(
TypeError
):
net
(
None
)
class
NetScatterNdUpdate
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetScatterNdUpdate
,
self
).
__init__
()
self
.
b
=
Parameter
(
initializer
(
'ones'
,
[
5
,
5
]),
name
=
'b'
)
self
.
scatter
=
P
.
ScatterNdUpdate
()
def
construct
(
self
,
idx
,
x
):
return
self
.
scatter
(
self
.
b
,
idx
,
x
)
def
test_scatter_nd_update
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
net
=
NetScatterNdUpdate
()
x
=
Tensor
(
np
.
ones
([
5
]).
astype
(
np
.
float16
))
idx
=
Tensor
(
np
.
ones
([
1
]).
astype
(
np
.
int32
))
net
(
idx
,
x
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录