Commit f3ebc731 (magicwindyyd/mindspore, fork of MindSpore / mindspore)

fix: MindDataset padded log error

Authored by jonyguo on Jun 13, 2020
Parent: b3f09b1d
Showing 3 changed files with 111 additions and 19 deletions.
mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h   +2 -0
mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc     +10 -3
tests/ut/python/dataset/test_minddataset_padded.py              +99 -16
mindspore/ccsrc/mindrecord/include/shard_distributed_sample.h  +2 -0

@@ -40,6 +40,8 @@ class ShardDistributedSample : public ShardSample {
  private:
   bool shuffle_;
   int no_of_padded_samples_;
+  bool init_judgment_;
+  // we should judge the (num_sample + num_padded) % num_shards == 0 in first time
 };
 }  // namespace mindrecord
 }  // namespace mindspore
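A note on why the flag exists (my inference; the commit itself only adds the comment above): with a repeated pipeline, PreExecute can run again on later epochs, so the divisibility judgment is latched to the first call. A minimal Python sketch of that latch pattern, with hypothetical names:

```python
class DistributedSampleCheck:
    """Sketch of the init_judgment_ latch: validate the task count only once."""

    def __init__(self, denominator: int):
        self.denominator = denominator  # stands in for denominator_ (the shard count)
        self.judged = False             # stands in for init_judgment_

    def pre_execute(self, total_no: int) -> bool:
        if not self.judged:  # we only judge this the first time
            self.judged = True
            if total_no % self.denominator != 0:
                return False  # the C++ code logs the detailed message and returns FAILED
        return True
```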
mindspore/ccsrc/mindrecord/meta/shard_distributed_sample.cc  +10 -3

@@ -24,7 +24,10 @@ namespace mindspore {
 namespace mindrecord {
 ShardDistributedSample::ShardDistributedSample(int num_shards, int shard_id, int no_of_padded_samples, bool shuffle,
                                                uint32_t seed)
-    : ShardSample(1, num_shards, shard_id), shuffle_(shuffle), no_of_padded_samples_(no_of_padded_samples) {
+    : ShardSample(1, num_shards, shard_id),
+      shuffle_(shuffle),
+      no_of_padded_samples_(no_of_padded_samples),
+      init_judgment_(false) {
   shuffle_op_ = std::make_shared<ShardShuffle>(seed, kShuffleSample);
 }
@@ -45,11 +48,15 @@ int64_t ShardDistributedSample::GetNumSamples(int64_t dataset_size, int64_t num_
   }
   return 0;
 }

 MSRStatus ShardDistributedSample::PreExecute(ShardTask &tasks) {
   auto total_no = tasks.Size();
-  if (no_of_padded_samples_ > 0) {
+  if (no_of_padded_samples_ > 0 && init_judgment_ == false) {  // we only judge this in first time
+    init_judgment_ = true;
     if (total_no % denominator_ != 0) {
-      MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards.";
+      MS_LOG(ERROR) << "Dataset size plus number of padded samples is not divisible by number of shards. "
+                    << "task size: " << total_no << ", number padded: " << no_of_padded_samples_
+                    << ", denominator: " << denominator_;
       return FAILED;
     }
   }
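The header comment states the rule the new log fields expose: (num_sample + num_padded) % num_shards == 0. A worked check of that rule; check_padded_divisible is a hypothetical helper for illustration, not MindSpore API:

```python
def check_padded_divisible(dataset_size: int, num_padded: int, num_shards: int) -> bool:
    """True when padding makes the dataset split evenly across shards."""
    total = dataset_size + num_padded  # the "task size" field in the new log message
    return total % num_shards == 0     # num_shards plays the role of "denominator"

assert check_padded_divisible(10, 2, 4)      # (10 + 2) % 4 == 0: accepted
assert not check_padded_divisible(10, 1, 4)  # (10 + 1) % 4 != 0: the ERROR path above
```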
tests/ut/python/dataset/test_minddataset_padded.py  +99 -16

@@ -120,7 +120,7 @@ def test_cv_minddataset_reader_basic_padded_samples(add_and_remove_cv_file):
             assert item['label'] == padded_sample['label']
             assert (item['data'] == np.array(list(padded_sample['data']))).all()
         num_iter += 1
-        assert num_padded_iter == 5
+    assert num_padded_iter == 5
     assert num_iter == 15
@@ -135,6 +135,8 @@ def test_cv_minddataset_partition_padded_samples(add_and_remove_cv_file):
     num_readers = 4

     def partitions(num_shards, num_padded, dataset_size):
+        num_padded_iter = 0
+        num_iter = 0
         for partition_id in range(num_shards):
             data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                       num_shards=num_shards,
@@ -142,8 +144,6 @@ def test_cv_minddataset_partition_padded_samples(add_and_remove_cv_file):
                                       padded_sample=padded_sample,
                                       num_padded=num_padded)
             assert data_set.get_dataset_size() == dataset_size
-            num_iter = 0
-            num_padded_iter = 0
             for item in data_set.create_dict_iterator():
                 logger.info("-------------- partition : {} ------------------------".format(partition_id))
                 logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
@@ -156,11 +156,53 @@ def test_cv_minddataset_partition_padded_samples(add_and_remove_cv_file):
                     assert item['label'] == padded_sample['label']
                     assert (item['data'] == np.array(list(padded_sample['data']))).all()
                 num_iter += 1
-            return num_iter
+        assert num_padded_iter == num_padded
+        return num_iter == dataset_size * num_shards
+
+    partitions(4, 2, 3)
+    partitions(5, 5, 3)
+    partitions(9, 8, 2)
+
+
+def test_cv_minddataset_partition_padded_samples_multi_epoch(add_and_remove_cv_file):
+    """tutorial for cv minddataset."""
+    columns_list = ["data", "file_name", "label"]
+    data = get_data(CV_DIR_NAME)
+    padded_sample = data[0]
+    padded_sample['label'] = -2
+    padded_sample['file_name'] = 'dummy.jpg'
+    num_readers = 4
+
+    def partitions(num_shards, num_padded, dataset_size):
+        repeat_size = 5
+        num_padded_iter = 0
+        num_iter = 0
+        for partition_id in range(num_shards):
+            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
+                                      num_shards=num_shards,
+                                      shard_id=partition_id,
+                                      padded_sample=padded_sample,
+                                      num_padded=num_padded)
+            assert data_set.get_dataset_size() == dataset_size
+            data_set = data_set.repeat(repeat_size)
+            for item in data_set.create_dict_iterator():
+                logger.info("-------------- partition : {} ------------------------".format(partition_id))
+                logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
+                logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
+                logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
+                if item['label'] == -2:
+                    num_padded_iter += 1
+                    assert item['file_name'] == bytes(padded_sample['file_name'], encoding='utf8')
+                    assert item['label'] == padded_sample['label']
+                    assert (item['data'] == np.array(list(padded_sample['data']))).all()
+                num_iter += 1
+        assert num_padded_iter == num_padded * repeat_size
+        assert num_iter == dataset_size * num_shards * repeat_size

-    assert partitions(4, 2, 3) == 3
-    assert partitions(5, 5, 3) == 3
-    assert partitions(9, 8, 2) == 2
+    partitions(4, 2, 3)
+    partitions(5, 5, 3)
+    partitions(9, 8, 2)


 def test_cv_minddataset_partition_padded_samples_no_dividsible(add_and_remove_cv_file):
     """tutorial for cv minddataset."""
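For orientation, the triples passed to partitions are consistent with a 10-sample CV file (the basic test above sees 15 iterations with 5 padded samples, implying 10 real ones); every call keeps the padded total divisible by the shard count. A quick check under that assumption:

```python
# Assumes the test CV MindRecord file holds 10 samples, inferred from
# num_iter == 15 with num_padded == 5 in the basic test above.
for num_shards, num_padded, dataset_size in [(4, 2, 3), (5, 5, 3), (9, 8, 2)]:
    total = 10 + num_padded
    assert total % num_shards == 0               # satisfies the PreExecute check
    assert total // num_shards == dataset_size   # the per-shard get_dataset_size() value
```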
tests/ut/python/dataset/test_minddataset_padded.py (continued)

@@ -308,6 +350,8 @@ def test_nlp_minddataset_reader_basic_padded_samples(add_and_remove_nlp_file):
     num_readers = 4

     def partitions(num_shards, num_padded, dataset_size):
+        num_padded_iter = 0
+        num_iter = 0
         for partition_id in range(num_shards):
             data_set = ds.MindDataset(NLP_FILE_NAME + "0", columns_list, num_readers,
                                       num_shards=num_shards,
@@ -315,22 +359,61 @@ def test_nlp_minddataset_reader_basic_padded_samples(add_and_remove_nlp_file):
                                       padded_sample=padded_sample,
                                       num_padded=num_padded)
             assert data_set.get_dataset_size() == dataset_size
-            num_iter = 0
             for item in data_set.create_dict_iterator():
                 logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
                 logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
                 logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(item["input_ids"], item["input_ids"].shape))
-                if item['id'] == '-1':
+                if item['id'] == bytes('-1', encoding='utf-8'):
                     num_padded_iter += 1
-                    assert item['id'] == padded_sample['id']
-                    assert item['input_ids'] == padded_sample['input_ids']
-                    assert item['rating'] == padded_sample['rating']
+                    assert item['id'] == bytes(padded_sample['id'], encoding='utf-8')
+                    assert (item['input_ids'] == padded_sample['input_ids']).all()
+                    assert (item['rating'] == padded_sample['rating']).all()
                 num_iter += 1
-            return num_iter
+        assert num_padded_iter == num_padded
+        assert num_iter == dataset_size * num_shards
+
+    partitions(4, 6, 4)
+    partitions(5, 5, 3)
+    partitions(9, 8, 2)
+
+
+def test_nlp_minddataset_reader_basic_padded_samples_multi_epoch(add_and_remove_nlp_file):
+    columns_list = ["input_ids", "id", "rating"]
+    data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10)]
+    padded_sample = data[0]
+    padded_sample['id'] = "-1"
+    padded_sample['input_ids'] = np.array([-1, -1, -1, -1], dtype=np.int64)
+    padded_sample['rating'] = 1.0
+    num_readers = 4
+    repeat_size = 3
+
+    def partitions(num_shards, num_padded, dataset_size):
+        num_padded_iter = 0
+        num_iter = 0
+        for partition_id in range(num_shards):
+            data_set = ds.MindDataset(NLP_FILE_NAME + "0", columns_list, num_readers,
+                                      num_shards=num_shards,
+                                      shard_id=partition_id,
+                                      padded_sample=padded_sample,
+                                      num_padded=num_padded)
+            assert data_set.get_dataset_size() == dataset_size
+            data_set = data_set.repeat(repeat_size)
+            for item in data_set.create_dict_iterator():
+                logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
+                logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
+                logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(item["input_ids"], item["input_ids"].shape))
+                if item['id'] == bytes('-1', encoding='utf-8'):
+                    num_padded_iter += 1
+                    assert item['id'] == bytes(padded_sample['id'], encoding='utf-8')
+                    assert (item['input_ids'] == padded_sample['input_ids']).all()
+                    assert (item['rating'] == padded_sample['rating']).all()
+                num_iter += 1
+        assert num_padded_iter == num_padded * repeat_size
+        assert num_iter == dataset_size * num_shards * repeat_size

-    assert partitions(4, 6, 4) == 4
-    assert partitions(5, 5, 3) == 3
-    assert partitions(9, 8, 2) == 2
+    partitions(4, 6, 4)
+    partitions(5, 5, 3)
+    partitions(9, 8, 2)


 def get_data(dir_name):
     """
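One detail worth calling out in the NLP changes: the padded-row check now compares item['id'] against bytes('-1', encoding='utf-8') instead of the str '-1'. That is consistent with the dict iterator returning string columns as bytes, in which case the old comparison could never match. A minimal sketch of the mismatch, assuming bytes output as the updated asserts do:

```python
item_id = b"-1"  # a string column as the iterator is assumed to yield it
assert not (item_id == "-1")                     # bytes never equal str in Python 3
assert item_id == bytes("-1", encoding="utf-8")  # the comparison the tests now use
```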