Commit 3b71bd0d in magicwindyyd / mindspore (fork of MindSpore / mindspore)
Authored on Jun 20, 2020 by xulei2020

rename input to children, output to parent

Parent: 5dac9c4c

Showing 4 changed files with 87 additions and 87 deletions (+87 -87):

    mindspore/dataset/engine/datasets.py                   +71 -71
    mindspore/dataset/engine/iterators.py                  +11 -11
    mindspore/dataset/engine/serializer_deserializer.py     +4 -4
    tests/ut/python/dataset/test_iterator.py                 +1 -1
mindspore/dataset/engine/datasets.py

@@ -134,8 +134,8 @@ class Dataset:
     """
     def __init__(self, num_parallel_workers=None):
-        self.input = []
-        self.output = []
+        self.children = []
+        self.parent = []
         self.num_parallel_workers = num_parallel_workers
         self._device_iter = 0
         self._input_indexs = ()
@@ -1006,9 +1006,9 @@ class Dataset:
                 dev_id = output_dataset.shard_id
                 return "", dev_id
-            if not output_dataset.input:
+            if not output_dataset.children:
                 raise RuntimeError("Unknown output_dataset: {}".format(type(output_dataset)))
-            input_dataset = output_dataset.input[0]
+            input_dataset = output_dataset.children[0]
             return get_distribution(input_dataset)

         distribution_path, device_id = get_distribution(self)
@@ -1129,8 +1129,8 @@ class Dataset:
         Return:
             Number, number of batches.
         """
-        if self.input:
-            return self.input[0].get_dataset_size()
+        if self.children:
+            return self.children[0].get_dataset_size()
         return None

     def num_classes(self):
@@ -1140,23 +1140,23 @@ class Dataset:
         Return:
             Number, number of classes.
         """
-        if self.input:
-            return self.input[0].num_classes()
+        if self.children:
+            return self.children[0].num_classes()
         return None

     def get_sync_notifiers(self):
-        if self.input:
-            return self.input[0].get_sync_notifiers()
+        if self.children:
+            return self.children[0].get_sync_notifiers()
         return {}

     def disable_sync(self):
-        if self.input:
-            return self.input[0].disable_sync()
+        if self.children:
+            return self.children[0].disable_sync()
         return {}

     def is_sync(self):
-        if self.input:
-            return self.input[0].is_sync()
+        if self.children:
+            return self.children[0].is_sync()
         return False

     def sync_update(self, condition_name, num_batch=None, data=None):
@@ -1190,8 +1190,8 @@ class Dataset:
         Return:
             Number, the number of data in a batch.
         """
-        if self.input:
-            return self.input[0].get_batch_size()
+        if self.children:
+            return self.children[0].get_batch_size()
         return 1

     def get_repeat_count(self):
@@ -1201,8 +1201,8 @@ class Dataset:
         Return:
             Number, the count of repeat.
         """
-        if self.input:
-            return self.input[0].get_repeat_count()
+        if self.children:
+            return self.children[0].get_repeat_count()
         return 1

     def get_class_indexing(self):
@@ -1212,22 +1212,22 @@ class Dataset:
         Return:
             Dict, A str-to-int mapping from label name to index.
         """
-        if self.input:
-            return self.input[0].get_class_indexing()
+        if self.children:
+            return self.children[0].get_class_indexing()
         raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self)))

     def reset(self):
         """Reset the dataset for next epoch."""

     def is_shuffled(self):
-        for input_dataset in self.input:
+        for input_dataset in self.children:
             if input_dataset.is_shuffled():
                 return True

         return False

     def is_sharded(self):
-        for input_dataset in self.input:
+        for input_dataset in self.children:
             if input_dataset.is_sharded():
                 return True
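The getters in the hunks above all follow one delegation pattern: a node that cannot answer a query locally forwards it to its first child, so the call walks down to the leaf (source) node. A self-contained sketch of that pattern, assuming a single-child chain:

    # Sketch of the delegation pattern: queries recurse down children[0]
    # until a node (the leaf source) can answer, mirroring get_dataset_size,
    # num_classes, get_batch_size, etc. Illustrative, not MindSpore code.
    class Node:
        def __init__(self, size=None, children=None):
            self.size = size
            self.children = children or []

        def get_dataset_size(self):
            if self.size is not None:
                return self.size
            if self.children:
                return self.children[0].get_dataset_size()
            return None

    leaf = Node(size=100)
    root = Node(children=[Node(children=[leaf])])
    assert root.get_dataset_size() == 100   # answered at the leaf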
@@ -1466,8 +1466,8 @@ class BucketBatchByLengthDataset(DatasetOp):
         self.pad_to_bucket_boundary = pad_to_bucket_boundary
         self.drop_remainder = drop_remainder

-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs

     def get_args(self):
@@ -1529,8 +1529,8 @@ class BatchDataset(DatasetOp):
         self.per_batch_map = per_batch_map
         self.input_columns = input_columns
         self.pad_info = pad_info

-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs

     def get_args(self):
@@ -1549,7 +1549,7 @@ class BatchDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.input[0].get_dataset_size()
+        child_size = self.children[0].get_dataset_size()
         if child_size is not None:
             if self.drop_remainder:
                 return math.floor(child_size / self.batch_size)
@@ -1578,7 +1578,7 @@ class BatchDataset(DatasetOp):
         if isinstance(dataset, RepeatDataset):
             return True
         flag = False
-        for input_dataset in dataset.input:
+        for input_dataset in dataset.children:
             flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
         return flag
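The `_is_ancestor_of_repeat` hunk above only swaps the attribute it recurses over; the traversal itself ORs the recursive result across every child, so the flag is set when any node below is a RepeatDataset. A hedged sketch with stand-in classes:

    # Sketch of the descendant scan in _is_ancestor_of_repeat above (and in
    # SyncWaitDataset._is_ancestor_of_batch further down): OR the recursive
    # result over every child. `Node` and `Repeat` are illustrative stand-ins.
    class Node:
        def __init__(self, children=None):
            self.children = children or []

    class Repeat(Node):
        pass

    def contains_descendant(node, cls):
        if isinstance(node, cls):
            return True
        flag = False
        for child in node.children:
            flag = flag | contains_descendant(child, cls)
        return flag

    tree = Node(children=[Node(children=[Repeat()])])
    assert contains_descendant(tree, Repeat)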
@@ -1593,7 +1593,7 @@ class BatchDataset(DatasetOp):
         """
         if isinstance(dataset, SyncWaitDataset):
             dataset.update_sync_batch_size(batch_size)
-        for input_dataset in dataset.input:
+        for input_dataset in dataset.children:
             BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
@@ -1699,21 +1699,21 @@ class SyncWaitDataset(DatasetOp):
     def __init__(self, input_dataset, condition_name, num_batch, callback=None):
         super().__init__()

-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         # set to the default value, waiting for the batch to update it
         self._condition_name = condition_name
         if isinstance(num_batch, int) and num_batch <= 0:
             raise ValueError("num_batch need to be greater than 0.")

         self._pair = BlockReleasePair(num_batch, callback)
-        if self._condition_name in self.input[0].get_sync_notifiers():
+        if self._condition_name in self.children[0].get_sync_notifiers():
             raise RuntimeError("Condition name is already in use")
         logger.warning("Please remember to add dataset.sync_update(condition=%s), otherwise will result in hanging",
                        condition_name)

     def get_sync_notifiers(self):
-        return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
+        return {**self.children[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}

     def is_sync(self):
         return True
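`get_sync_notifiers` keeps the same dict-merge idiom after the rename: the first child's notifier map is combined with this node's own condition, and `__init__` rejects a duplicate condition name up front because later keys silently win in the merge. The merge behaves like this:

    # The {**child, **own} merge used by get_sync_notifiers: later keys win
    # on collision, which is why __init__ first checks the child's notifiers
    # for a duplicate condition name. Values here are stand-in callables.
    child_notifiers = {"cond_a": lambda: None}
    own_notifier = {"cond_b": lambda: None}
    merged = {**child_notifiers, **own_notifier}
    assert set(merged) == {"cond_a", "cond_b"}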
@@ -1746,7 +1746,7 @@ class SyncWaitDataset(DatasetOp):
         if isinstance(dataset, BatchDataset):
             return True
         flag = False
-        for input_dataset in dataset.input:
+        for input_dataset in dataset.children:
             flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
         return flag
@@ -1766,9 +1766,9 @@ class ShuffleDataset(DatasetOp):
     def __init__(self, input_dataset, buffer_size):
         super().__init__()
         self.buffer_size = buffer_size
-        self.input.append(input_dataset)
+        self.children.append(input_dataset)
         self.reshuffle_each_epoch = None
-        input_dataset.output.append(self)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs
         if self.is_sync():
             raise RuntimeError("No shuffle after sync operators")
@@ -1864,7 +1864,7 @@ class MapDataset(DatasetOp):
     def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None,
                  num_parallel_workers=None, python_multiprocessing=False):
         super().__init__(num_parallel_workers)
-        self.input.append(input_dataset)
+        self.children.append(input_dataset)
         if input_columns is not None and not isinstance(input_columns, list):
             input_columns = [input_columns]
         self.input_columns = input_columns
@@ -1881,7 +1881,7 @@ class MapDataset(DatasetOp):
                 and self.columns_order is None:
             raise ValueError("When (len(input_columns) != len(output_columns)), columns_order must be specified.")

-        input_dataset.output.append(self)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs
         self.python_multiprocessing = python_multiprocessing
         self.process_pool = None
@@ -1901,7 +1901,7 @@ class MapDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        return self.input[0].get_dataset_size()
+        return self.children[0].get_dataset_size()

     def __deepcopy__(self, memodict):
         if id(self) in memodict:
@@ -1909,12 +1909,12 @@ class MapDataset(DatasetOp):
         cls = self.__class__
         new_op = cls.__new__(cls)
         memodict[id(self)] = new_op
-        new_op.input = copy.deepcopy(self.input, memodict)
+        new_op.children = copy.deepcopy(self.children, memodict)
         new_op.input_columns = copy.deepcopy(self.input_columns, memodict)
         new_op.output_columns = copy.deepcopy(self.output_columns, memodict)
         new_op.columns_order = copy.deepcopy(self.columns_order, memodict)
         new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
-        new_op.output = copy.deepcopy(self.output, memodict)
+        new_op.parent = copy.deepcopy(self.parent, memodict)
         new_op.input_indexs = copy.deepcopy(self._input_indexs, memodict)
         new_op.python_multiprocessing = copy.deepcopy(self.python_multiprocessing, memodict)
         new_op.operations = self.operations
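In `__deepcopy__`, `new_op` is registered in `memodict` before `children` and `parent` are copied; a child's `parent` list points back at this node, so the memo entry is what stops infinite recursion. A runnable sketch of just that mechanism, using a stripped-down `Op` class:

    # Sketch of the cycle-safe deepcopy above: registering the half-built
    # copy in memodict first means the back-reference in a child's `parent`
    # list resolves to the new object instead of recursing forever.
    import copy

    class Op:
        def __init__(self):
            self.children = []
            self.parent = []

        def __deepcopy__(self, memodict):
            new_op = self.__class__.__new__(self.__class__)
            memodict[id(self)] = new_op          # register before recursing
            new_op.children = copy.deepcopy(self.children, memodict)
            new_op.parent = copy.deepcopy(self.parent, memodict)
            return new_op

    child, root = Op(), Op()
    root.children.append(child)
    child.parent.append(root)
    root_copy = copy.deepcopy(root)
    assert root_copy.children[0].parent[0] is root_copy   # cycle preserved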
@@ -1975,8 +1975,8 @@ class FilterDataset(DatasetOp):
     def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
         super().__init__(num_parallel_workers)
         self.predicate = lambda *args: bool(predicate(*args))
-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         if input_columns is not None and not isinstance(input_columns, list):
             input_columns = [input_columns]
         self.input_columns = input_columns
@@ -2012,8 +2012,8 @@ class RepeatDataset(DatasetOp):
             self.count = -1
         else:
             self.count = count
-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs

     def get_args(self):
@@ -2028,7 +2028,7 @@ class RepeatDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.input[0].get_dataset_size()
+        child_size = self.children[0].get_dataset_size()
         if child_size is not None:
             return child_size
         return None
@@ -2055,8 +2055,8 @@ class SkipDataset(DatasetOp):
     def __init__(self, input_dataset, count):
         super().__init__()
         self.count = count
-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs

     def get_args(self):
@@ -2071,7 +2071,7 @@ class SkipDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.input[0].get_dataset_size()
+        child_size = self.children[0].get_dataset_size()
         output_size = 0
         if self.count >= 0 and self.count < child_size:
             output_size = child_size - self.count
@@ -2090,8 +2090,8 @@ class TakeDataset(DatasetOp):
     def __init__(self, input_dataset, count):
         super().__init__()
         self.count = count
-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs

     def get_args(self):
@@ -2106,7 +2106,7 @@ class TakeDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        child_size = self.input[0].get_dataset_size()
+        child_size = self.children[0].get_dataset_size()
         if child_size < self.count:
             return child_size
         return self.count
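The SkipDataset and TakeDataset size hunks above only touch the `children[0]` lookup; the arithmetic is unchanged. A worked example for a child of size 10 (helper names are illustrative, and only non-negative counts are shown):

    # Worked example of the SkipDataset/TakeDataset size logic above.
    def skip_size(child_size, count):
        return child_size - count if 0 <= count < child_size else 0

    def take_size(child_size, count):
        return child_size if child_size < count else count

    assert skip_size(10, 3) == 7     # skip 3 of 10 leaves 7
    assert skip_size(10, 12) == 0    # skipping past the end leaves nothing
    assert take_size(10, 3) == 3     # take 3 of 10 yields 3
    assert take_size(10, 20) == 10   # take is capped at the child size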
@@ -2130,8 +2130,8 @@ class ZipDataset(DatasetOp):
             raise TypeError("The parameter %s of zip has type error!" % (dataset))
         self.datasets = datasets
         for data in datasets:
-            self.input.append(data)
-            data.output.append(self)
+            self.children.append(data)
+            data.parent.append(self)

     def get_dataset_size(self):
         """
@@ -2140,7 +2140,7 @@ class ZipDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        children_sizes = [c.get_dataset_size() for c in self.input]
+        children_sizes = [c.get_dataset_size() for c in self.children]
         if all(c is not None for c in children_sizes):
             return min(children_sizes)
         return None
@@ -2155,7 +2155,7 @@ class ZipDataset(DatasetOp):
         return None

     def is_sync(self):
-        return any([c.is_sync() for c in self.input])
+        return any([c.is_sync() for c in self.children])

     def get_args(self):
         args = super().get_args()
@@ -2180,8 +2180,8 @@ class ConcatDataset(DatasetOp):
             raise TypeError("The parameter %s of concat has type error!" % (dataset))
         self.datasets = datasets
         for data in datasets:
-            self.input.append(data)
-            data.output.append(self)
+            self.children.append(data)
+            data.parent.append(self)

     def get_dataset_size(self):
         """
@@ -2190,7 +2190,7 @@ class ConcatDataset(DatasetOp):
         Return:
             Number, number of batches.
         """
-        children_sizes = [c.get_dataset_size() for c in self.input]
+        children_sizes = [c.get_dataset_size() for c in self.children]
         dataset_size = sum(children_sizes)
         return dataset_size
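Unlike the single-child operators, Zip and Concat compute their sizes over the whole `children` list: zip is bounded by its shortest input, concat sums all of them. For children of sizes 10, 7 and 12:

    # Worked example of the Zip/Concat size hunks above.
    children_sizes = [10, 7, 12]
    assert min(children_sizes) == 7    # ZipDataset: shortest input wins
    assert sum(children_sizes) == 29   # ConcatDataset: the sizes add up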
@@ -2213,8 +2213,8 @@ class RenameDataset(DatasetOp):
             output_columns = [output_columns]
         self.input_column_names = input_columns
         self.output_column_names = output_columns
-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs

     def get_args(self):
@@ -2240,10 +2240,10 @@ class ProjectDataset(DatasetOp):
         if not isinstance(columns, list):
             columns = [columns]
         self.columns = columns
-        self.input.append(input_dataset)
+        self.children.append(input_dataset)
         self.prefetch_size = prefetch_size

-        input_dataset.output.append(self)
+        input_dataset.parent.append(self)
         self._input_indexs = input_dataset.input_indexs

     def get_args(self):
@@ -2267,8 +2267,8 @@ class TransferDataset(DatasetOp):
     def __init__(self, input_dataset, queue_name, device_id, device_type, num_batch=None):
         super().__init__()
-        self.input.append(input_dataset)
-        input_dataset.output.append(self)
+        self.children.append(input_dataset)
+        input_dataset.parent.append(self)
         self.queue_name = queue_name
         self._input_indexs = input_dataset.input_indexs
         self._device_type = device_type
@@ -3170,8 +3170,8 @@ class GeneratorDataset(MappableDataset):
         cls = self.__class__
         new_op = cls.__new__(cls)
         memodict[id(self)] = new_op
-        new_op.input = copy.deepcopy(self.input, memodict)
-        new_op.output = copy.deepcopy(self.output, memodict)
+        new_op.children = copy.deepcopy(self.children, memodict)
+        new_op.parent = copy.deepcopy(self.parent, memodict)
         new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
         new_op.column_types = copy.deepcopy(self.column_types, memodict)
         new_op.column_names = copy.deepcopy(self.column_names, memodict)
@@ -4879,14 +4879,14 @@ class BuildVocabDataset(DatasetOp):
                  prefetch_size=None):
         super().__init__()
         self.columns = columns
-        self.input.append(input_dataset)
+        self.children.append(input_dataset)
         self.prefetch_size = prefetch_size
         self.vocab = vocab
         self.freq_range = freq_range
         self.top_k = top_k
         self.special_tokens = special_tokens
         self.special_first = special_first
-        input_dataset.output.append(self)
+        input_dataset.parent.append(self)

     def get_args(self):
         args = super().get_args()
@@ -4905,11 +4905,11 @@ class BuildVocabDataset(DatasetOp):
         cls = self.__class__
         new_op = cls.__new__(cls)
         memodict[id(self)] = new_op
-        new_op.input = copy.deepcopy(self.input, memodict)
+        new_op.children = copy.deepcopy(self.children, memodict)
         new_op.columns = copy.deepcopy(self.columns, memodict)
         new_op.num_parallel_workers = copy.deepcopy(self.num_parallel_workers, memodict)
         new_op.prefetch_size = copy.deepcopy(self.prefetch_size, memodict)
-        new_op.output = copy.deepcopy(self.output, memodict)
+        new_op.parent = copy.deepcopy(self.parent, memodict)
         new_op.freq_range = copy.deepcopy(self.freq_range, memodict)
         new_op.top_k = copy.deepcopy(self.top_k, memodict)
         new_op.vocab = self.vocab
mindspore/dataset/engine/iterators.py

@@ -38,13 +38,13 @@ def _cleanup():
 def alter_tree(node):
     """Traversing the python Dataset tree/graph to perform some alteration to some specific nodes."""
-    if not node.input:
+    if not node.children:
         return _alter_node(node)

     converted_children = []
-    for input_op in node.input:
+    for input_op in node.children:
         converted_children.append(alter_tree(input_op))
-    node.input = converted_children
+    node.children = converted_children
     return _alter_node(node)
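`alter_tree` is a post-order rewrite: it converts the subtrees hanging off `node.children` first, writes the converted list back, and only then alters the node itself. A compact sketch of that shape, with `alter` standing in for the module's `_alter_node`:

    # Sketch of alter_tree's post-order shape after the rename. `alter` is a
    # stand-in for _alter_node; nodes only need a `children` list.
    def alter_tree(node, alter):
        if not node.children:
            return alter(node)
        node.children = [alter_tree(child, alter) for child in node.children]
        return alter(node)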
@@ -86,14 +86,14 @@ class Iterator:
     def __is_tree_node(self, node):
         """Check if a node is tree node."""
-        if not node.input:
-            if len(node.output) > 1:
+        if not node.children:
+            if len(node.parent) > 1:
                 return False

-        if len(node.output) > 1:
+        if len(node.parent) > 1:
             return False

-        for input_node in node.input:
+        for input_node in node.children:
             cls = self.__is_tree_node(input_node)
             if not cls:
                 return False
@@ -174,7 +174,7 @@ class Iterator:
         op_type = self.__get_dataset_type(node)
         c_node = self.depipeline.AddNodeToTree(op_type, node.get_args())

-        for py_child in node.input:
+        for py_child in node.children:
             c_child = self.__convert_node_postorder(py_child)
             self.depipeline.AddChildToParentNode(c_child, c_node)
@@ -184,7 +184,7 @@ class Iterator:
         """Recursively get batch node in the dataset tree."""
         if isinstance(dataset, de.BatchDataset):
             return
-        for input_op in dataset.input:
+        for input_op in dataset.children:
             self.__batch_node(input_op, level + 1)

     @staticmethod
@@ -194,11 +194,11 @@ class Iterator:
         ptr = hex(id(dataset))
         for _ in range(level):
             logger.info("\t", end='')
-        if not dataset.input:
+        if not dataset.children:
             logger.info("-%s (%s)", name, ptr)
         else:
             logger.info("+%s (%s)", name, ptr)
-        for input_op in dataset.input:
+        for input_op in dataset.children:
             Iterator.__print_local(input_op, level + 1)

     def print(self):
mindspore/dataset/engine/serializer_deserializer.py

@@ -182,11 +182,11 @@ def traverse(node):
         node_repr['shard_id'] = None

     # Leaf node doesn't have input attribute.
-    if not node.input:
+    if not node.children:
         return node_repr

     # Recursively traverse the child and assign it to the current node_repr['children'].
-    for child in node.input:
+    for child in node.children:
         node_repr["children"].append(traverse(child))

     return node_repr
@@ -226,11 +226,11 @@ def construct_pipeline(node):
     # Instantiate python Dataset object based on the current dictionary element
     dataset = create_node(node)
     # Initially it is not connected to any other object.
-    dataset.input = []
+    dataset.children = []

     # Construct the children too and add edge between the children and parent.
     for child in node['children']:
-        dataset.input.append(construct_pipeline(child))
+        dataset.children.append(construct_pipeline(child))

     return dataset
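The two serializer hunks are mirror images: `traverse` flattens the live `children` links into nested dicts, and `construct_pipeline` rebuilds `children` from those dicts. A round-trip sketch with a toy `Node` class (name field only, not the real node_repr schema):

    # Round-trip sketch of the serializer hunks above: object tree -> nested
    # dicts -> object tree, with the linkage carried by `children` throughout.
    class Node:
        def __init__(self, name):
            self.name = name
            self.children = []

    def traverse(node):
        return {"name": node.name,
                "children": [traverse(c) for c in node.children]}

    def construct_pipeline(node_repr):
        node = Node(node_repr["name"])
        node.children = [construct_pipeline(c) for c in node_repr["children"]]
        return node

    root = Node("batch")
    root.children.append(Node("source"))
    rebuilt = construct_pipeline(traverse(root))
    assert rebuilt.children[0].name == "source"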
tests/ut/python/dataset/test_iterator.py

@@ -103,7 +103,7 @@ def test_tree_copy():
     itr = data1.create_tuple_iterator()

     assert id(data1) != id(itr.dataset)
-    assert id(data) != id(itr.dataset.input[0])
+    assert id(data) != id(itr.dataset.children[0])
     assert id(data1.operations[0]) == id(itr.dataset.operations[0])

     itr.release()