magicwindyyd / mindspore — forked from MindSpore / mindspore (in sync with the upstream fork source)
Commit a5572f15
Authored May 07, 2020 by mindspore-ci-bot; committed via Gitee, May 07, 2020

!955 fix some python code format

Merge pull request !955 from panfengfeng/panff/fix_code_format

Parents: e75d7585, ee3be682
Showing 1 changed file with 46 additions and 46 deletions.

mindspore/dataset/engine/datasets.py (+46, -46)
@@ -205,12 +205,12 @@ class Dataset:
     @check_sync_wait
     def sync_wait(self, condition_name, num_batch=1, callback=None):
         '''
-        Add a blocking condition to the input Dataset
+        Add a blocking condition to the input Dataset.

         Args:
-            num_batch (int): the number of batches without blocking at the start of each epoch
-            condition_name (str): The condition name that is used to toggle sending next row
-            callback (function): The callback funciton that will be invoked when sync_update is called
+            num_batch (int): the number of batches without blocking at the start of each epoch.
+            condition_name (str): The condition name that is used to toggle sending next row.
+            callback (function): The callback funciton that will be invoked when sync_update is called.

         Raises:
             RuntimeError: If condition name already exists.

@@ -920,13 +920,13 @@ class Dataset:
     def sync_update(self, condition_name, num_batch=None, data=None):
         """
-        Release a blocking condition and triger callback with given data
+        Release a blocking condition and triger callback with given data.

         Args:
-            condition_name (str): The condition name that is used to toggle sending next row
-            num_batch (int or None): The number of batches(rows) that are released
-                When num_batch is None, it will default to the number specified by the sync_wait operator
-            data (dict or None): The data passed to the callback
+            condition_name (str): The condition name that is used to toggle sending next row.
+            num_batch (int or None): The number of batches(rows) that are released.
+                When num_batch is None, it will default to the number specified by the sync_wait operator.
+            data (dict or None): The data passed to the callback.
         """
         notifiers_dict = self.get_sync_notifiers()
         if condition_name not in notifiers_dict:
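For context, the two methods touched above work as a pair: sync_wait inserts a blocking point into the pipeline and sync_update releases it from the training loop. A minimal usage sketch follows, assuming the GeneratorDataset/map-style API of this period; the generator, the Augment helper, and the "policy" condition name are illustrative and not part of this commit.

import numpy as np
import mindspore.dataset as ds

def gen():
    for i in range(100):
        yield (np.array(i),)

class Augment:
    """Toy callback target: sync_update pushes new training state into the pipeline."""
    def __init__(self):
        self.state = 0
    def update(self, data):
        self.state = data["loss"]

aug = Augment()
dataset = ds.GeneratorDataset(gen, ["input"])
# Block after each batch until the training loop calls sync_update("policy", ...).
dataset = dataset.sync_wait(condition_name="policy", num_batch=1, callback=aug.update)
dataset = dataset.batch(10)

for step, _ in enumerate(dataset.create_dict_iterator()):
    # Release the pipeline and hand the callback the latest training signal.
    dataset.sync_update(condition_name="policy", data={"loss": step})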
@@ -948,7 +948,7 @@ class Dataset:
     def get_repeat_count(self):
         """
-        Get the replication times in RepeatDataset else 1
+        Get the replication times in RepeatDataset else 1.

         Return:
             Number, the count of repeat.

@@ -969,7 +969,7 @@ class Dataset:
         raise NotImplementedError("Dataset {} has not supported api get_class_indexing yet.".format(type(self)))

     def reset(self):
-        """Reset the dataset for next epoch"""
+        """Reset the dataset for next epoch."""


 class SourceDataset(Dataset):

@@ -1085,9 +1085,9 @@ class BatchDataset(DatasetOp):
         Utility function to find the case where repeat is used before batch.

         Args:
-            dataset (Dataset): dataset to be checked
+            dataset (Dataset): dataset to be checked.

         Return:
-            True or False
+            True or False.
         """
         if isinstance(dataset, RepeatDataset):
             return True

@@ -1102,8 +1102,8 @@ class BatchDataset(DatasetOp):
         Utility function to notify batch size to sync_wait.

         Args:
-            dataset (Dataset): dataset to be checked
-            batchsize (int): batch size to notify
+            dataset (Dataset): dataset to be checked.
+            batchsize (int): batch size to notify.
         """
         if isinstance(dataset, SyncWaitDataset):
             dataset.update_sync_batch_size(batch_size)

@@ -1136,11 +1136,11 @@ class BatchInfo(CBatchInfo):
 class BlockReleasePair:
     """
-    The blocking condition class used by SyncWaitDataset
+    The blocking condition class used by SyncWaitDataset.

     Args:
-        init_release_rows (int): Number of lines to allow through the pipeline
-        callback (function): The callback funciton that will be called when release is called
+        init_release_rows (int): Number of lines to allow through the pipeline.
+        callback (function): The callback funciton that will be called when release is called.
     """

     def __init__(self, init_release_rows, callback=None):
         self.row_count = -init_release_rows
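The visible context (self.row_count = -init_release_rows) hints at how the pair behaves: the counter starts as a negative budget of rows allowed through, each row moves it toward zero, and a release pushes it back down. Below is a standalone sketch of that block/release mechanism built on threading.Condition; it illustrates the idea and is not the verbatim MindSpore class.

import threading

class BlockReleasePairSketch:
    """Allow init_release_rows rows through, then block until released."""

    def __init__(self, init_release_rows, callback=None):
        self.row_count = -init_release_rows   # negative budget of rows allowed through
        self.default_rows = init_release_rows
        self.callback = callback
        self.cv = threading.Condition()

    def block_func(self):
        # Called once per row on the pipeline side; waits while the budget is used up.
        with self.cv:
            self.cv.wait_for(lambda: self.row_count < 0)
            self.row_count += 1
        return True

    def release_func(self, pass_rows=None, data=None):
        # Called on release (e.g. by sync_update): extend the budget, fire the
        # callback, and wake the waiting pipeline thread.
        with self.cv:
            if pass_rows is None:
                pass_rows = self.default_rows
            self.row_count -= pass_rows
            if self.callback is not None:
                self.callback(data)
            self.cv.notify_all()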
@@ -1183,13 +1183,13 @@ class BlockReleasePair:
 class SyncWaitDataset(DatasetOp):
     """
-    The result of adding a blocking condition to the input Dataset
+    The result of adding a blocking condition to the input Dataset.

     Args:
-        input_dataset (Dataset): Input dataset to apply flow control
-        num_batch (int): the number of batches without blocking at the start of each epoch
-        condition_name (str): The condition name that is used to toggle sending next row
-        callback (function): The callback funciton that will be invoked when sync_update is called
+        input_dataset (Dataset): Input dataset to apply flow control.
+        num_batch (int): the number of batches without blocking at the start of each epoch.
+        condition_name (str): The condition name that is used to toggle sending next row.
+        callback (function): The callback funciton that will be invoked when sync_update is called.

     Raises:
         RuntimeError: If condition name already exists.

@@ -1226,9 +1226,9 @@ class SyncWaitDataset(DatasetOp):
         Utility function to find the case where sync_wait is used before batch.

         Args:
-            dataset (Dataset): dataset to be checked
+            dataset (Dataset): dataset to be checked.

         Return:
-            True or False
+            True or False.
         """
         if isinstance(dataset, BatchDataset):
             return True

@@ -1289,7 +1289,7 @@ def _pyfunc_worker_exec(index, *args):
 # PythonCallable wrapper for multiprocess pyfunc
 class _PythonCallable:
     """
-    Internal python function wrapper for multiprocessing pyfunc
+    Internal python function wrapper for multiprocessing pyfunc.
     """

     def __init__(self, py_callable, idx, pool=None):
         # Original python callable from user.

@@ -1467,7 +1467,7 @@ class FilterDataset(DatasetOp):
     def get_dataset_size(self):
         """
         Get the number of batches in an epoch.
-        the size cannot be determined before we run the pipeline
+        the size cannot be determined before we run the pipeline.

         Return:
             0
         """

@@ -1759,7 +1759,7 @@ class StorageDataset(SourceDataset):
         columns_list (list[str], optional): List of columns to be read (default=None, read all columns).
         num_parallel_workers (int, optional): Number of parallel working threads (default=None).
         deterministic_output (bool, optional): Whether the result of this dataset can be reproduced
             or not (default=True). If True, performance might be affected.
         prefetch_size (int, optional): Prefetch number of records ahead of the user's request (default=None).

     Raises:

@@ -1889,11 +1889,11 @@ def _select_sampler(num_samples, input_sampler, shuffle, num_shards, shard_id):
     Create sampler based on user input.

     Args:
-        num_samples (int): Number of samples
-        input_sampler (Iterable / Sampler): Sampler from user
-        shuffle (bool): Shuffle
-        num_shards (int): Number of shard for sharding
-        shard_id (int): Shard ID
+        num_samples (int): Number of samples.
+        input_sampler (Iterable / Sampler): Sampler from user.
+        shuffle (bool): Shuffle.
+        num_shards (int): Number of shard for sharding.
+        shard_id (int): Shard ID.
     """
     if shuffle is None:
         if input_sampler is not None:
@@ -2265,7 +2265,7 @@ class MindDataset(SourceDataset):
 def _iter_fn(dataset, num_samples):
     """
-    Generator function wrapper for iterable dataset
+    Generator function wrapper for iterable dataset.
     """
     if num_samples is not None:
         ds_iter = iter(dataset)

@@ -2284,7 +2284,7 @@ def _iter_fn(dataset, num_samples):
 def _generator_fn(generator, num_samples):
     """
-    Generator function wrapper for generator function dataset
+    Generator function wrapper for generator function dataset.
     """
     if num_samples is not None:
         gen_iter = generator()

@@ -2302,7 +2302,7 @@ def _generator_fn(generator, num_samples):
 def _py_sampler_fn(sampler, num_samples, dataset):
     """
-    Generator function wrapper for mappable dataset with python sampler
+    Generator function wrapper for mappable dataset with python sampler.
     """
     if num_samples is not None:
         sampler_iter = iter(sampler)

@@ -2323,7 +2323,7 @@ def _py_sampler_fn(sampler, num_samples, dataset):
 def _cpp_sampler_fn(sampler, dataset):
     """
-    Generator function wrapper for mappable dataset with cpp sampler
+    Generator function wrapper for mappable dataset with cpp sampler.
     """
     indices = sampler.get_indices()
     for i in indices:

@@ -2334,7 +2334,7 @@ def _cpp_sampler_fn(sampler, dataset):
 def _cpp_sampler_fn_mp(sampler, dataset, num_worker):
     """
-    Multiprocessing generator function wrapper for mappable dataset with cpp sampler
+    Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
     """
     indices = sampler.get_indices()
     return _sampler_fn_mp(indices, dataset, num_worker)

@@ -2342,7 +2342,7 @@ def _cpp_sampler_fn_mp(sampler, dataset, num_worker):
 def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):
     """
-    Multiprocessing generator function wrapper for mappable dataset with python sampler
+    Multiprocessing generator function wrapper for mappable dataset with python sampler.
     """
     indices = _fetch_py_sampler_indices(sampler, num_samples)
     return _sampler_fn_mp(indices, dataset, num_worker)

@@ -2350,7 +2350,7 @@ def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):
 def _fetch_py_sampler_indices(sampler, num_samples):
     """
-    Indices fetcher for python sampler
+    Indices fetcher for python sampler.
     """
     if num_samples is not None:
         sampler_iter = iter(sampler)

@@ -2367,7 +2367,7 @@ def _fetch_py_sampler_indices(sampler, num_samples):
 def _fill_worker_indices(workers, indices, idx):
     """
-    Worker index queue filler, fill worker index queue in round robin order
+    Worker index queue filler, fill worker index queue in round robin order.
     """
     num_worker = len(workers)
     while idx < len(indices):
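From the visible lines (num_worker = len(workers), while idx < len(indices):) this helper walks the index list and deals indices to the workers in round-robin order, stopping when a queue is full. A framework-free sketch of that pattern follows; the bounded-queue workers and the return value are assumptions for illustration, not the MindSpore implementation.

import queue

def fill_worker_indices_sketch(workers, indices, idx):
    """Deal indices[idx:] to worker queues round-robin; stop when a queue is full.

    Returns the position of the first index that was not handed out.
    """
    num_worker = len(workers)
    while idx < len(indices):
        try:
            # Worker i receives every num_worker-th index, starting at offset i.
            workers[idx % num_worker].put_nowait(indices[idx])
        except queue.Full:
            break
        idx += 1
    return idx

# Usage with plain bounded queues standing in for worker processes:
workers = [queue.Queue(maxsize=4) for _ in range(3)]
next_idx = fill_worker_indices_sketch(workers, list(range(20)), 0)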
@@ -2381,7 +2381,7 @@ def _fill_worker_indices(workers, indices, idx):
 def _sampler_fn_mp(indices, dataset, num_worker):
     """
-    Multiprocessing generator function wrapper master process
+    Multiprocessing generator function wrapper master process.
     """
     workers = []
     # Event for end of epoch

@@ -2423,7 +2423,7 @@ def _sampler_fn_mp(indices, dataset, num_worker):
 def _generator_worker_loop(dataset, idx_queue, result_queue, eoe):
     """
-    Multiprocessing generator worker process loop
+    Multiprocessing generator worker process loop.
     """
     while True:
         # Fetch index, block
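The signature (dataset, idx_queue, result_queue, eoe) and the "# Fetch index, block" comment outline the worker protocol: each worker process pulls an index, looks the row up in the user dataset, and pushes the result back, with eoe signalling end of epoch. Below is a hedged sketch of that loop; the None sentinel is an assumption for illustration rather than the actual MindSpore convention.

def generator_worker_loop_sketch(dataset, idx_queue, result_queue, eoe):
    """Worker process body: consume indices, produce rows, stop at end of epoch."""
    while True:
        # Fetch index, block until the master process hands one over.
        idx = idx_queue.get()
        if idx is None:          # sentinel used here to mean "no more indices"
            if eoe.is_set():     # end-of-epoch event set by the master process
                return
            continue
        # Random access into the user dataset, then hand the row back.
        result_queue.put(dataset[idx])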
@@ -2448,7 +2448,7 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe):
 class _GeneratorWorker(multiprocessing.Process):
     """
-    Worker process for multiprocess Generator
+    Worker process for multiprocess Generator.
     """

     def __init__(self, dataset, eoe):
         self.idx_queue = multiprocessing.Queue(16)

@@ -2932,7 +2932,7 @@ class ManifestDataset(SourceDataset):
     def get_class_indexing(self):
         """
-        Get the class index
+        Get the class index.

         Return:
             Dict, A str-to-int mapping from label name to index.

@@ -3500,7 +3500,7 @@ class VOCDataset(SourceDataset):
 class CelebADataset(SourceDataset):
     """
-    A source dataset for reading and parsing CelebA dataset.Only support list_attr_celeba.txt currently
+    A source dataset for reading and parsing CelebA dataset.Only support list_attr_celeba.txt currently.

     Note:
         The generated dataset has two columns ['image', 'attr'].