Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
cfa41733
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cfa41733
编写于
7月 02, 2020
作者:
B
buxue
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support Python built-in function 'enumerate'
上级
dd666ec3
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
240 addition
and
40 deletion
+240
-40
mindspore/_extends/parse/resources.py
mindspore/_extends/parse/resources.py
+1
-0
mindspore/_extends/parse/standard_method.py
mindspore/_extends/parse/standard_method.py
+16
-0
mindspore/_extends/parse/trope.py
mindspore/_extends/parse/trope.py
+2
-2
mindspore/ccsrc/operator/composite/map.cc
mindspore/ccsrc/operator/composite/map.cc
+1
-1
mindspore/ccsrc/operator/composite/zip_operation.cc
mindspore/ccsrc/operator/composite/zip_operation.cc
+21
-19
mindspore/ccsrc/operator/prim_statement.cc
mindspore/ccsrc/operator/prim_statement.cc
+1
-0
mindspore/ccsrc/pipeline/parse/parse.cc
mindspore/ccsrc/pipeline/parse/parse.cc
+2
-3
mindspore/ccsrc/pipeline/static_analysis/param_validator.h
mindspore/ccsrc/pipeline/static_analysis/param_validator.h
+1
-0
mindspore/ccsrc/pynative/pynative_execute.cc
mindspore/ccsrc/pynative/pynative_execute.cc
+5
-8
mindspore/common/dtype.py
mindspore/common/dtype.py
+2
-0
mindspore/ops/composite/multitype_ops/_compile_utils.py
mindspore/ops/composite/multitype_ops/_compile_utils.py
+2
-2
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+1
-1
tests/ut/python/ops/test_ops.py
tests/ut/python/ops/test_ops.py
+4
-4
tests/ut/python/pipeline/parse/test_enumerate.py
tests/ut/python/pipeline/parse/test_enumerate.py
+181
-0
未找到文件。
mindspore/_extends/parse/resources.py
浏览文件 @
cfa41733
...
...
@@ -116,6 +116,7 @@ convert_object_map = {
T
.
partial
:
F
.
partial
,
T
.
zip
:
C
.
zip_operation
,
T
.
print
:
F
.
print_
,
T
.
enumerate
:
M
.
enumerate_
,
# custom define operation
T
.
iter
:
M
.
ms_iter
,
...
...
mindspore/_extends/parse/standard_method.py
浏览文件 @
cfa41733
...
...
@@ -104,6 +104,15 @@ def bool_(x):
return
x
.
__bool__
()
def
enumerate_
(
x
,
start
=
0
):
"""Enumerate list or tuple."""
x_type
=
F
.
typeof
(
x
)
ret
=
()
if
check_is_tuple_or_list
(
x_type
,
"enumerate"
):
ret
=
zip
(
range
(
start
,
start
+
len
(
x
)),
x
)
return
ret
def
while_cond
(
x
):
"""For while condtion, if the condition is a tensor, the loop will not be unrolled"""
if
F
.
issubclass_
(
F
.
typeof
(
x
),
F
.
typeof
(
mstype
.
tensor
)):
...
...
@@ -113,6 +122,13 @@ def while_cond(x):
return
x
@
constexpr
def
check_is_tuple_or_list
(
x
,
op_name
):
"""check whether x is list or tuple."""
if
isinstance
(
x
,
(
mstype
.
list_type
,
mstype
.
tuple_type
)):
return
True
raise
TypeError
(
f
"For '
{
op_name
}
', the input parameter should be tuple or list, but got
{
x
}
."
)
@
constexpr
def
check_is_tensor_bool_cond
(
shp
):
"""check if tensor is a bool condition"""
...
...
mindspore/_extends/parse/trope.py
浏览文件 @
cfa41733
...
...
@@ -27,7 +27,7 @@ from operator import ( # noqa
# support system function call
from
builtins
import
(
# noqa
bool
,
getattr
,
setattr
,
len
,
iter
,
next
,
pow
,
range
,
map
,
zip
,
print
bool
,
getattr
,
setattr
,
len
,
iter
,
next
,
pow
,
range
,
map
,
zip
,
print
,
enumerate
)
# support functools
...
...
@@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
'not_'
,
'and_'
,
'or_'
,
'xor'
,
'lshift'
,
'rshift'
,
'invert'
,
'is_'
,
'is_not'
,
'contains'
,
'matmul'
,
'getitem'
,
'setitem'
,
'bool'
,
'getattr'
,
'setattr'
,
'len'
,
'iter'
,
'next'
,
'pow'
,
'range'
,
'map'
,
'zip'
,
'partial'
,
'print'
,
'partial'
,
'print'
,
'enumerate'
,
'exp'
,
'log'
,
'sin'
,
'cos'
,
'tan'
]
...
...
mindspore/ccsrc/operator/composite/map.cc
浏览文件 @
cfa41733
...
...
@@ -181,7 +181,7 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr<Class> &type, const FuncGrap
}
AnfNodePtr
Map
::
Make
(
const
FuncGraphPtr
&
func_graph
,
const
AnfNodePtr
&
fn_arg
,
const
ArgsPairList
&
arg_pairs
)
{
if
(
arg_pairs
.
size
()
<
1
)
{
if
(
arg_pairs
.
empty
()
)
{
MS_EXCEPTION
(
TypeError
)
<<
"map() must have at least two arguments"
;
}
bool
found
=
false
;
...
...
mindspore/ccsrc/operator/composite/zip_operation.cc
浏览文件 @
cfa41733
...
...
@@ -18,44 +18,44 @@
#include "operator/composite/zip_operation.h"
#include <algorithm>
#include <utility>
#include "pipeline/static_analysis/abstract_value.h"
#include "ir/anf.h"
#include "pipeline/static_analysis/dshape.h"
#include "pipeline/static_analysis/param_validator.h"
#include "operator/cc_implementations.h"
#include "optimizer/opt.h"
#include "utils/symbolic.h"
#include "./common.h"
#include "pybind_api/api_register.h"
namespace
mindspore
{
// namespace to support composite operators definition
namespace
prim
{
using
mindspore
::
abstract
::
AbstractBase
;
using
mindspore
::
abstract
::
AbstractList
;
using
mindspore
::
abstract
::
AbstractSequeue
;
using
mindspore
::
abstract
::
AbstractSequeuePtr
;
using
mindspore
::
abstract
::
AbstractTuple
;
FuncGraphPtr
ZipOperation
::
GenerateFuncGraph
(
const
AbstractBasePtrList
&
args_spec_list
)
{
// zip operation:
// input: tuple arguments
// output: tuple of items of input iterated on every input
if
(
args_spec_list
.
size
()
==
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"
zip arguments input should not be empty
"
;
if
(
args_spec_list
.
empty
()
)
{
MS_LOG
(
EXCEPTION
)
<<
"
For 'zip', there is at least one input.
"
;
}
auto
is_all_tuple
=
std
::
all_of
(
args_spec_list
.
begin
(),
args_spec_list
.
end
(),
[](
const
AbstractBasePtr
&
abs
)
->
bool
{
MS_EXCEPTION_IF_NULL
(
abs
);
return
abs
->
isa
<
AbstractTuple
>
();
});
if
(
!
is_all_tuple
)
{
MS_LOG
(
EXCEPTION
)
<<
"zip input args should be tuple"
;
auto
is_all_sequeue
=
std
::
all_of
(
args_spec_list
.
begin
(),
args_spec_list
.
end
(),
[](
const
AbstractBasePtr
&
abs
)
->
bool
{
MS_EXCEPTION_IF_NULL
(
abs
);
return
abs
->
isa
<
AbstractSequeue
>
();
});
if
(
!
is_all_sequeue
)
{
MS_LOG
(
EXCEPTION
)
<<
"For 'zip', all inputs must be sequence."
;
}
auto
min_abs
=
std
::
min_element
(
args_spec_list
.
begin
(),
args_spec_list
.
end
(),
[](
const
AbstractBasePtr
&
x
,
const
AbstractBasePtr
&
y
)
{
return
(
x
->
cast
<
AbstractTuplePtr
>
()
->
size
()
<
y
->
cast
<
AbstractTupl
ePtr
>
()
->
size
());
});
auto
min_abs
=
std
::
min_element
(
args_spec_list
.
begin
(),
args_spec_list
.
end
(),
[](
const
AbstractBasePtr
&
x
,
const
AbstractBasePtr
&
y
)
{
return
(
x
->
cast
<
AbstractSequeuePtr
>
()
->
size
()
<
y
->
cast
<
AbstractSequeu
ePtr
>
()
->
size
());
});
FuncGraphPtr
ret_graph
=
std
::
make_shared
<
FuncGraph
>
();
ret_graph
->
set_flag
(
FUNC_GRAPH_FLAG_CORE
,
true
);
for
(
size_t
idx
=
0
;
idx
<
args_spec_list
.
size
();
idx
++
)
{
...
...
@@ -65,12 +65,14 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spe
// generate tuple output of ziped arguments input
std
::
vector
<
AnfNodePtr
>
make_tuple_nodes
;
make_tuple_nodes
.
push_back
(
NewValueNode
(
prim
::
kPrimMakeTuple
));
for
(
size_t
idx
=
0
;
idx
<
(
*
min_abs
)
->
cast
<
Abstract
Tupl
ePtr
>
()
->
size
();
idx
++
)
{
for
(
size_t
idx
=
0
;
idx
<
(
*
min_abs
)
->
cast
<
Abstract
Sequeu
ePtr
>
()
->
size
();
idx
++
)
{
std
::
vector
<
AnfNodePtr
>
make_tuple_zip_nodes
;
make_tuple_zip_nodes
.
push_back
(
NewValueNode
(
prim
::
kPrimMakeTuple
));
std
::
string
module_name
=
"mindspore.ops.composite.multitype_ops.getitem_impl"
;
ValuePtr
op
=
prim
::
GetPythonOps
(
"getitem"
,
module_name
);
for
(
size_t
arg_idx
=
0
;
arg_idx
<
args_spec_list
.
size
();
arg_idx
++
)
{
std
::
vector
<
AnfNodePtr
>
tuple_get_item_nodes
{
NewValueNode
(
prim
::
kPrimTupleGetItem
)
,
ret_graph
->
parameters
()[
arg_idx
],
NewValueNode
(
SizeToInt
(
idx
))};
std
::
vector
<
AnfNodePtr
>
tuple_get_item_nodes
{
NewValueNode
(
op
),
ret_graph
->
parameters
()[
arg_idx
]
,
NewValueNode
(
SizeToInt
(
idx
))};
auto
tuple_get_item_op
=
ret_graph
->
NewCNode
(
tuple_get_item_nodes
);
make_tuple_zip_nodes
.
push_back
(
tuple_get_item_op
);
}
...
...
mindspore/ccsrc/operator/prim_statement.cc
浏览文件 @
cfa41733
...
...
@@ -229,6 +229,7 @@ AbstractBasePtr InferImplNotInDict(const AnalysisEnginePtr &, const PrimitivePtr
// Inputs: x, t
return
std
::
make_shared
<
AbstractScalar
>
(
!
IsInDict
(
primitive
,
args_spec_list
));
}
AbstractBasePtr
InferImplIsConstant
(
const
AnalysisEnginePtr
&
,
const
PrimitivePtr
&
primitive
,
const
AbstractBasePtrList
&
args_spec_list
)
{
// statement: isconstant(x)
...
...
mindspore/ccsrc/pipeline/parse/parse.cc
浏览文件 @
cfa41733
...
...
@@ -1048,11 +1048,10 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec
CNodePtr
app
=
body_block
->
func_graph
()
->
NewCNode
({
op_next
,
iter_param
});
CNodePtr
target_app
=
body_block
->
func_graph
()
->
NewCNode
({
op_getitem
,
app
,
NewValueNode
(
0
)});
py
::
object
target_node
=
python_adapter
::
GetPyObjAttr
(
node
,
"target"
);
auto
name_id
=
py
::
cast
<
std
::
string
>
(
python_adapter
::
GetPyObjAttr
(
target_node
,
"id"
));
target_app
->
debug_info
()
->
set_name
(
name_id
);
CNodePtr
iter2_app
=
body_block
->
func_graph
()
->
NewCNode
({
op_getitem
,
app
,
NewValueNode
(
1
)});
body_block
->
WriteVariable
(
name_id
,
target_app
);
WriteAssignVars
(
body_block
,
target_node
,
target_app
);
// link the variable name with the target
auto
it_info
=
std
::
make_shared
<
TraceIterator
>
(
target_app
->
debug_info
());
iter_param
->
debug_info
()
->
set_trace_info
(
it_info
);
...
...
mindspore/ccsrc/pipeline/static_analysis/param_validator.h
浏览文件 @
cfa41733
...
...
@@ -67,6 +67,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Type)
ABSTRACT_REPORT_NAME_TRAITS
(
KeywordArg
)
ABSTRACT_REPORT_NAME_TRAITS
(
Class
)
ABSTRACT_REPORT_NAME_TRAITS
(
IndexedSlices
)
ABSTRACT_REPORT_NAME_TRAITS
(
Sequeue
)
template
<
typename
T
>
std
::
shared_ptr
<
T
>
CheckArg
(
const
std
::
string
&
op
,
const
AbstractBasePtrList
&
args_spec_list
,
size_t
index
)
{
...
...
mindspore/ccsrc/pynative/pynative_execute.cc
浏览文件 @
cfa41733
...
...
@@ -226,11 +226,8 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
AbstractBasePtrList
args_spec_list
;
for
(
size_t
i
=
0
;
i
<
size
;
i
++
)
{
ValuePtr
input_value
=
PyAttrValue
(
py_args
[
i
]);
if
(
!
py
::
hasattr
(
prim
->
GetPyObj
(),
"const_value"
)
&&
input_value
->
isa
<
tensor
::
Tensor
>
())
{
args_spec_list
.
emplace_back
(
abstract
::
FromValueInside
(
input_value
,
true
));
}
else
{
args_spec_list
.
emplace_back
(
abstract
::
FromValueInside
(
input_value
,
false
));
}
args_spec_list
.
emplace_back
(
abstract
::
FromValueInside
(
input_value
,
!
py
::
hasattr
(
prim
->
GetPyObj
(),
"const_value"
)
&&
input_value
->
isa
<
tensor
::
Tensor
>
()));
}
AbstractBasePtr
infer_res
=
EvalOnePrim
(
prim
,
args_spec_list
)
->
abstract
();
op_exec_info
->
abstract
=
infer_res
;
...
...
@@ -512,7 +509,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
return
result
;
}
py
::
object
RunOpWithBackendPolicy
(
MsBackendPolicy
backend_policy
,
const
OpExecInfoPtr
op_exec_info
,
py
::
object
RunOpWithBackendPolicy
(
MsBackendPolicy
backend_policy
,
const
OpExecInfoPtr
&
op_exec_info
,
PynativeStatusCode
*
const
status
)
{
MS_EXCEPTION_IF_NULL
(
status
);
py
::
object
result
;
...
...
@@ -550,7 +547,7 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
}
AnfNodePtr
PynativeExecutor
::
MakeCNode
(
const
OpExecInfoPtr
&
op_exec_info
,
const
py
::
args
&
args
,
const
py
::
tuple
&
out
)
{
if
(
!
grad_flag_
||
graph_info_map_
.
size
()
==
0
)
{
if
(
!
grad_flag_
||
graph_info_map_
.
empty
()
)
{
return
nullptr
;
}
std
::
vector
<
AnfNodePtr
>
inputs
;
...
...
@@ -753,7 +750,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o
if
(
py
::
isinstance
<
py
::
none
>
(
name_attr
))
{
MS_LOG
(
EXCEPTION
)
<<
"Parameter object should have name attribute"
;
}
std
::
string
param_name
=
py
::
cast
<
std
::
string
>
(
name_attr
);
auto
param_name
=
py
::
cast
<
std
::
string
>
(
name_attr
);
if
(
graph_info_map_
[
df_builder_
].
param_map
.
count
(
obj_id
)
==
0
)
{
auto
free_param
=
df_builder_
->
add_parameter
();
free_param
->
set_name
(
param_name
);
...
...
mindspore/common/dtype.py
浏览文件 @
cfa41733
...
...
@@ -97,6 +97,8 @@ tensor_type = typing.TensorType
anything_type
=
typing
.
TypeAnything
slice_type
=
typing
.
Slice
ellipsis_type
=
typing
.
TypeEllipsis
list_type
=
typing
.
List
tuple_type
=
typing
.
Tuple
number_type
=
(
int8
,
int16
,
...
...
mindspore/ops/composite/multitype_ops/_compile_utils.py
浏览文件 @
cfa41733
...
...
@@ -65,9 +65,9 @@ def _generate_indices_from_tuple_of_mixed_tensors(data, tuple_index, op_name):
tuple_len
=
len
(
tuple_index
)
for
i
in
range
(
tuple_len
):
if
i
in
int_positions
:
tuple_index_new
=
tuple_index_new
+
(
F
.
scalar_to_tensor
(
tuple_index
[
i
],
mstype
.
int32
),)
tuple_index_new
+=
(
F
.
scalar_to_tensor
(
tuple_index
[
i
],
mstype
.
int32
),)
else
:
tuple_index_new
=
tuple_index_new
+
(
tuple_index
[
i
],)
tuple_index_new
+=
(
tuple_index
[
i
],)
indexes_types
=
hyper_map
(
F
.
typeof
,
tuple_index_new
)
tensor_positions
,
slice_positions
,
ellipsis_position
=
\
const_utils
.
separate_mixed_tensors_index
(
indexes_types
,
op_name
)
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
cfa41733
...
...
@@ -1469,7 +1469,7 @@ class Concat(PrimitiveWithInfer):
def
_get_pack_shape
(
x_shape
,
x_type
,
axis
,
prim_name
):
"""for pack output shape"""
validator
.
check_value_type
(
"shape"
,
x_shape
,
[
tuple
,
list
],
prim_name
)
validator
.
check_integer
(
"len of input_x"
,
len
(
x_shape
),
1
,
Rel
.
G
T
,
prim_name
)
validator
.
check_integer
(
"len of input_x"
,
len
(
x_shape
),
1
,
Rel
.
G
E
,
prim_name
)
validator
.
check_subclass
(
"input_x[0]"
,
x_type
[
0
],
mstype
.
tensor
,
prim_name
)
rank_base
=
len
(
x_shape
[
0
])
N
=
len
(
x_shape
)
...
...
tests/ut/python/ops/test_ops.py
浏览文件 @
cfa41733
...
...
@@ -1747,6 +1747,10 @@ test_case_array_ops = [
'desc_inputs'
:
[[
128
,
128
],
[
128
,
128
]],
'desc_bprop'
:
[[
2
,
128
,
128
]],
}),
(
'Pack_3'
,
{
'block'
:
NetForPackInput
(
P
.
Pack
()),
'desc_inputs'
:
[[
2
,
2
]],
'desc_bprop'
:
[[
1
,
2
,
2
]]}),
(
'Unpack_0'
,
{
'block'
:
NetForUnpackInput
(
P
.
Unpack
(
axis
=
0
)),
'desc_inputs'
:
[[
2
,
4
]],
...
...
@@ -2206,10 +2210,6 @@ raise_set = [
Tensor
(
np
.
ones
((
2
,
2
),
np
.
float32
)),
Tensor
(
np
.
ones
((
2
,),
np
.
float32
))),
'desc_bprop'
:
[[
2
,
3
]]}),
(
'Pack'
,
{
'block'
:
(
NetForPackInput
(
P
.
Pack
()),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[[
2
,
2
]],
'desc_bprop'
:
[[
1
,
2
,
2
]]}),
(
'PReLU'
,
{
'block'
:
(
P
.
PReLU
(),
{
'exception'
:
ValueError
}),
'desc_inputs'
:
[[
2
],
[
1
]],
...
...
tests/ut/python/pipeline/parse/test_enumerate.py
0 → 100644
浏览文件 @
cfa41733
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test enumerate"""
import
numpy
as
np
import
pytest
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore
import
context
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
def
test_enumerate_list_const
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
value
=
[
11
,
22
,
33
,
44
]
def
construct
(
self
):
index_sum
=
0
value_sum
=
0
for
i
,
j
in
enumerate
(
self
.
value
):
index_sum
+=
i
value_sum
+=
j
return
index_sum
,
value_sum
net
=
Net
()
assert
net
()
==
(
6
,
110
)
def
test_enumerate_tuple_const
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
value
=
(
11
,
22
,
33
,
44
)
def
construct
(
self
):
index_sum
=
0
value_sum
=
0
for
i
,
j
in
enumerate
(
self
.
value
):
index_sum
+=
i
value_sum
+=
j
return
index_sum
,
value_sum
net
=
Net
()
assert
net
()
==
(
6
,
110
)
def
test_enumerate_list_parameter
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
index_sum
=
0
value
=
[
x
,
y
,
z
]
ret
=
()
for
i
,
j
in
enumerate
(
value
):
index_sum
+=
i
ret
+=
(
j
,)
return
index_sum
,
ret
x
=
Tensor
(
np
.
arange
(
3
*
4
*
5
).
reshape
((
3
,
4
,
5
)))
net
=
Net
()
net
(
x
,
x
,
x
)
def
test_enumerate_tuple_parameter
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
index_sum
=
0
value
=
(
x
,
y
,
z
)
ret
=
()
for
i
,
j
in
enumerate
(
value
):
index_sum
+=
i
ret
+=
(
j
,)
return
index_sum
,
ret
x
=
Tensor
(
np
.
arange
(
3
*
4
*
5
).
reshape
((
3
,
4
,
5
)))
net
=
Net
()
net
(
x
,
x
,
x
)
def
test_enumerate_tuple_const_1
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
value
=
(
11
,
22
,
33
,
44
)
def
construct
(
self
):
index_sum
=
0
value_sum
=
0
for
i
in
enumerate
(
self
.
value
):
index_sum
+=
i
[
0
]
value_sum
+=
i
[
1
]
return
index_sum
,
value_sum
net
=
Net
()
assert
net
()
==
(
6
,
110
)
def
test_enumerate_tuple_parameter_1
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
index_sum
=
0
value
=
(
x
,
y
,
z
)
ret
=
()
for
i
in
enumerate
(
value
):
index_sum
+=
i
[
0
]
ret
+=
(
i
[
1
],)
return
index_sum
,
ret
x
=
Tensor
(
np
.
arange
(
3
*
4
*
5
).
reshape
((
3
,
4
,
5
)))
net
=
Net
()
net
(
x
,
x
,
x
)
def
test_enumerate_tuple_const_2
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
value
=
(
11
,
22
,
33
,
44
)
def
construct
(
self
):
index_sum
=
0
value_sum
=
0
for
i
in
enumerate
(
self
.
value
,
1
):
index_sum
+=
i
[
0
]
value_sum
+=
i
[
1
]
return
index_sum
,
value_sum
net
=
Net
()
assert
net
()
==
(
10
,
110
)
def
test_enumerate_tuple_parameter_2
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
def
construct
(
self
,
x
,
y
,
z
):
index_sum
=
0
value
=
(
x
,
y
,
z
)
ret
=
()
for
i
in
enumerate
(
value
,
2
):
index_sum
+=
i
[
0
]
ret
+=
(
i
[
1
],)
return
index_sum
,
ret
x
=
Tensor
(
np
.
arange
(
3
*
4
*
5
).
reshape
((
3
,
4
,
5
)))
net
=
Net
()
net
(
x
,
x
,
x
)
def
test_enumerate_parameter_type_error
():
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
def
construct
(
self
,
x
):
return
enumerate
(
x
)
x
=
Tensor
(
np
.
arange
(
3
*
4
*
5
).
reshape
((
3
,
4
,
5
)))
net
=
Net
()
with
pytest
.
raises
(
TypeError
)
as
ex
:
net
(
x
)
assert
"For 'enumerate', the input parameter should be tuple or list"
in
str
(
ex
.
value
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录