Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
277aba53
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
277aba53
编写于
6月 25, 2020
作者:
C
Cathy Wong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dataset: Fixup docs; remove pylint disabled messages in UT
上级
e11c9532
变更
8
隐藏空白更改
内联
并排
Showing
8 changed files
with
39 additions
and
73 deletions
+39
-73
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+3
-3
mindspore/dataset/transforms/vision/c_transforms.py
mindspore/dataset/transforms/vision/c_transforms.py
+3
-3
tests/ut/data/dataset/declient.cfg
tests/ut/data/dataset/declient.cfg
+2
-1
tests/ut/python/dataset/test_batch.py
tests/ut/python/dataset/test_batch.py
+3
-5
tests/ut/python/dataset/test_center_crop.py
tests/ut/python/dataset/test_center_crop.py
+3
-8
tests/ut/python/dataset/test_config.py
tests/ut/python/dataset/test_config.py
+6
-1
tests/ut/python/dataset/test_filterop.py
tests/ut/python/dataset/test_filterop.py
+14
-43
tests/ut/python/dataset/test_pad.py
tests/ut/python/dataset/test_pad.py
+5
-9
未找到文件。
mindspore/dataset/engine/datasets.py
浏览文件 @
277aba53
...
...
@@ -1040,7 +1040,7 @@ class Dataset:
Args:
columns (list[str], optional): List of columns to be used to specify the order of columns
(default
s
=None, means all columns).
(default=None, means all columns).
Returns:
Iterator, list of ndarray.
...
...
@@ -3382,7 +3382,7 @@ class ManifestDataset(MappableDataset):
class_indexing (dict, optional): A str-to-int mapping from label name to index
(default=None, the folder names will be sorted alphabetically and each
class will be given a unique index starting from 0).
decode (bool, optional): decode the images after reading (default
s
=False).
decode (bool, optional): decode the images after reading (default=False).
num_shards (int, optional): Number of shards that the dataset should be divided
into (default=None).
shard_id (int, optional): The shard ID within num_shards (default=None). This
...
...
@@ -4760,7 +4760,7 @@ class _NumpySlicesDataset:
def
process_dict
(
self
,
input_data
):
"""
Convert the dict like data into tuple format, when input is a tuple of dict then compose it into a dict first.
Convert the dict like data into tuple format, when input is a tuple of dict
s
then compose it into a dict first.
"""
# Convert pandas like dict(has "values" column) into General dict
data_keys
=
list
(
input_data
.
keys
())
...
...
mindspore/dataset/transforms/vision/c_transforms.py
浏览文件 @
277aba53
...
...
@@ -202,7 +202,7 @@ class RandomHorizontalFlip(cde.RandomHorizontalFlipOp):
Flip the input image horizontally, randomly with a given probability.
Args:
prob (float): Probability of the image being flipped (default=0.5).
prob (float
, optional
): Probability of the image being flipped (default=0.5).
"""
@
check_prob
...
...
@@ -217,7 +217,7 @@ class RandomHorizontalFlipWithBBox(cde.RandomHorizontalFlipWithBBoxOp):
Maintains data integrity by also flipping bounding boxes in an object detection pipeline.
Args:
prob (float): Probability of the image being flipped (default=0.5).
prob (float
, optional
): Probability of the image being flipped (default=0.5).
"""
@
check_prob
...
...
@@ -231,7 +231,7 @@ class RandomVerticalFlip(cde.RandomVerticalFlipOp):
Flip the input image vertically, randomly with a given probability.
Args:
prob (float): Probability of the image being flipped (default=0.5).
prob (float
, optional
): Probability of the image being flipped (default=0.5).
"""
@
check_prob
...
...
tests/ut/data/dataset/declient.cfg
浏览文件 @
277aba53
...
...
@@ -4,6 +4,7 @@
"numParallelWorkers": 4,
"workerConnectorSize": 16,
"opConnectorSize": 16,
"seed": 5489
"seed": 5489,
"monitor_sampling_interval": 15
}
tests/ut/python/dataset/test_batch.py
浏览文件 @
277aba53
...
...
@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
util
import
save_and_check
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
from
util
import
save_and_check
# Note: Number of rows in test.data dataset: 12
DATA_DIR
=
[
"../data/dataset/testTFTestAllTypes/test.data"
]
...
...
@@ -434,7 +433,6 @@ def test_batch_exception_11():
assert
"drop_remainder"
in
str
(
e
)
# pylint: disable=redundant-keyword-arg
def
test_batch_exception_12
():
"""
Test batch exception: wrong input order, drop_remainder wrongly used as batch_size
...
...
@@ -447,12 +445,12 @@ def test_batch_exception_12():
# apply dataset operations
data1
=
ds
.
TFRecordDataset
(
DATA_DIR
)
try
:
data1
=
data1
.
batch
(
drop_remainder
,
batch_size
=
batch_size
)
data1
=
data1
.
batch
(
drop_remainder
,
batch_size
)
sum
([
1
for
_
in
data1
])
except
Exception
as
e
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
assert
"
batch_size
"
in
str
(
e
)
assert
"
drop_remainder
"
in
str
(
e
)
def
test_batch_exception_13
():
...
...
tests/ut/python/dataset/test_center_crop.py
浏览文件 @
277aba53
...
...
@@ -109,23 +109,18 @@ def test_center_crop_comp(height=375, width=375, plot=False):
visualize_list
(
image_c_cropped
,
image_py_cropped
,
visualize_mode
=
2
)
# pylint: disable=unnecessary-lambda
def
test_crop_grayscale
(
height
=
375
,
width
=
375
):
"""
Test that centercrop works with pad and grayscale images
"""
def
channel_swap
(
image
):
"""
Py func hack for our pytransforms to work with c transforms
"""
return
(
image
.
transpose
(
1
,
2
,
0
)
*
255
).
astype
(
np
.
uint8
)
# Note: image.transpose performs channel swap to allow py transforms to
# work with c transforms
transforms
=
[
py_vision
.
Decode
(),
py_vision
.
Grayscale
(
1
),
py_vision
.
ToTensor
(),
(
lambda
image
:
channel_swap
(
image
))
(
lambda
image
:
(
image
.
transpose
(
1
,
2
,
0
)
*
255
).
astype
(
np
.
uint8
))
]
transform
=
py_vision
.
ComposeOp
(
transforms
)
...
...
tests/ut/python/dataset/test_config.py
浏览文件 @
277aba53
...
...
@@ -37,6 +37,7 @@ def test_basic():
num_parallel_workers_original
=
ds
.
config
.
get_num_parallel_workers
()
prefetch_size_original
=
ds
.
config
.
get_prefetch_size
()
seed_original
=
ds
.
config
.
get_seed
()
monitor_sampling_interval_original
=
ds
.
config
.
get_monitor_sampling_interval
()
ds
.
config
.
load
(
'../data/dataset/declient.cfg'
)
...
...
@@ -45,23 +46,27 @@ def test_basic():
# assert ds.config.get_worker_connector_size() == 16
assert
ds
.
config
.
get_prefetch_size
()
==
16
assert
ds
.
config
.
get_seed
()
==
5489
# assert ds.config.get_monitor_sampling_interval() == 15
# ds.config.set_rows_per_buffer(1)
ds
.
config
.
set_num_parallel_workers
(
2
)
# ds.config.set_worker_connector_size(3)
ds
.
config
.
set_prefetch_size
(
4
)
ds
.
config
.
set_seed
(
5
)
ds
.
config
.
set_monitor_sampling_interval
(
45
)
# assert ds.config.get_rows_per_buffer() == 1
assert
ds
.
config
.
get_num_parallel_workers
()
==
2
# assert ds.config.get_worker_connector_size() == 3
assert
ds
.
config
.
get_prefetch_size
()
==
4
assert
ds
.
config
.
get_seed
()
==
5
assert
ds
.
config
.
get_monitor_sampling_interval
()
==
45
# Restore original configuration values
ds
.
config
.
set_num_parallel_workers
(
num_parallel_workers_original
)
ds
.
config
.
set_prefetch_size
(
prefetch_size_original
)
ds
.
config
.
set_seed
(
seed_original
)
ds
.
config
.
set_monitor_sampling_interval
(
monitor_sampling_interval_original
)
def
test_get_seed
():
...
...
@@ -150,7 +155,7 @@ def test_deterministic_run_fail():
def
test_deterministic_run_pass
():
"""
Test deterministic run with
with
setting the seed
Test deterministic run with setting the seed
"""
logger
.
info
(
"test_deterministic_run_pass"
)
...
...
tests/ut/python/dataset/test_filterop.py
浏览文件 @
277aba53
...
...
@@ -50,9 +50,7 @@ def test_diff_predicate_func():
def
filter_func_ge
(
data
):
if
data
>
10
:
return
False
return
True
return
data
<=
10
def
generator_1d
():
...
...
@@ -108,15 +106,11 @@ def test_filter_by_generator_with_repeat_after():
def
filter_func_batch
(
data
):
if
data
[
0
]
>
8
:
return
False
return
True
return
data
[
0
]
<=
8
def
filter_func_batch_after
(
data
):
if
data
>
20
:
return
False
return
True
return
data
<=
20
# test with batchOp before
...
...
@@ -152,9 +146,7 @@ def test_filter_by_generator_with_batch_after():
def
filter_func_shuffle
(
data
):
if
data
>
20
:
return
False
return
True
return
data
<=
20
# test with batchOp before
...
...
@@ -169,9 +161,7 @@ def test_filter_by_generator_with_shuffle():
def
filter_func_shuffle_after
(
data
):
if
data
>
20
:
return
False
return
True
return
data
<=
20
# test with batchOp after
...
...
@@ -197,15 +187,11 @@ def generator_1d_zip2():
def
filter_func_zip
(
data1
,
data2
):
_
=
data2
if
data1
>
20
:
return
False
return
True
return
data1
<=
20
def
filter_func_zip_after
(
data1
):
if
data1
>
20
:
return
False
return
True
return
data1
<=
20
# test with zipOp before
...
...
@@ -247,16 +233,11 @@ def test_filter_by_generator_with_zip_after():
def
filter_func_map
(
col1
,
col2
):
_
=
col2
if
col1
[
0
]
>
8
:
return
True
return
False
return
col1
[
0
]
>
8
# pylint: disable=simplifiable-if-statement
def
filter_func_map_part
(
col1
):
if
col1
<
3
:
return
True
return
False
return
col1
<
3
def
filter_func_map_all
(
col1
,
col2
):
...
...
@@ -311,9 +292,7 @@ def test_filter_by_generator_with_map_part_col():
def
filter_func_rename
(
data
):
if
data
>
8
:
return
True
return
False
return
data
>
8
# test with rename before
...
...
@@ -334,15 +313,11 @@ def test_filter_by_generator_with_rename():
# test input_column
def
filter_func_input_column1
(
col1
,
col2
):
_
=
col2
if
col1
[
0
]
<
8
:
return
True
return
False
return
col1
[
0
]
<
8
def
filter_func_input_column2
(
col1
):
if
col1
[
0
]
<
8
:
return
True
return
False
return
col1
[
0
]
<
8
def
filter_func_input_column3
(
col1
):
...
...
@@ -439,9 +414,7 @@ def test_filter_by_generator_Partial2():
def
filter_func_Partial
(
col1
,
col2
):
_
=
col2
if
col1
[
0
]
%
3
==
0
:
return
True
return
False
return
col1
[
0
]
%
3
==
0
def
generator_big
(
maxid
=
20
):
...
...
@@ -461,9 +434,7 @@ def test_filter_by_generator_Partial():
def
filter_func_cifar
(
col1
,
col2
):
_
=
col1
if
col2
%
3
==
0
:
return
True
return
False
return
col2
%
3
==
0
# test with cifar10
...
...
tests/ut/python/dataset/test_pad.py
浏览文件 @
277aba53
...
...
@@ -16,12 +16,12 @@
Testing Pad op in DE
"""
import
numpy
as
np
from
util
import
diff_mse
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
c_vision
import
mindspore.dataset.transforms.vision.py_transforms
as
py_vision
from
mindspore
import
log
as
logger
from
util
import
diff_mse
DATA_DIR
=
[
"../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"
]
SCHEMA_DIR
=
"../data/dataset/test_tf_file_3_images/datasetSchema.json"
...
...
@@ -69,23 +69,19 @@ def test_pad_op():
assert
mse
<
0.01
# pylint: disable=unnecessary-lambda
def
test_pad_grayscale
():
"""
Tests that the pad works for grayscale images
"""
def
channel_swap
(
image
):
"""
Py func hack for our pytransforms to work with c transforms
"""
return
(
image
.
transpose
(
1
,
2
,
0
)
*
255
).
astype
(
np
.
uint8
)
# Note: image.transpose performs channel swap to allow py transforms to
# work with c transforms
transforms
=
[
py_vision
.
Decode
(),
py_vision
.
Grayscale
(
1
),
py_vision
.
ToTensor
(),
(
lambda
image
:
channel_swap
(
image
))
(
lambda
image
:
(
image
.
transpose
(
1
,
2
,
0
)
*
255
).
astype
(
np
.
uint8
))
]
transform
=
py_vision
.
ComposeOp
(
transforms
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录