Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
兔爷不爱我
mindspore
提交
e2ea1fa0
M
mindspore
项目概览
兔爷不爱我
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
e2ea1fa0
编写于
7月 16, 2020
作者:
L
liyong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
activate num_samples in distributed samplers
上级
11732f0e
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
108 addition
and
11 deletion
+108
-11
mindspore/ccsrc/minddata/dataset/api/python_bindings.cc
mindspore/ccsrc/minddata/dataset/api/python_bindings.cc
+1
-1
mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h
...rc/minddata/mindrecord/include/shard_distributed_sample.h
+3
-2
mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h
mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h
+1
-1
mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc
...csrc/minddata/mindrecord/meta/shard_distributed_sample.cc
+5
-4
mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc
mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc
+8
-2
mindspore/dataset/engine/samplers.py
mindspore/dataset/engine/samplers.py
+3
-1
tests/ut/python/dataset/test_minddataset.py
tests/ut/python/dataset/test_minddataset.py
+66
-0
tests/ut/python/dataset/test_minddataset_exception.py
tests/ut/python/dataset/test_minddataset_exception.py
+21
-0
未找到文件。
mindspore/ccsrc/minddata/dataset/api/python_bindings.cc
浏览文件 @
e2ea1fa0
...
...
@@ -784,7 +784,7 @@ void bindSamplerOps(py::module *m) {
(
void
)
py
::
class_
<
mindrecord
::
ShardDistributedSample
,
mindrecord
::
ShardSample
,
std
::
shared_ptr
<
mindrecord
::
ShardDistributedSample
>>
(
*
m
,
"MindrecordDistributedSampler"
)
.
def
(
py
::
init
<
int64_t
,
int64_t
,
bool
,
uint32_t
>
());
.
def
(
py
::
init
<
int64_t
,
int64_t
,
bool
,
uint32_t
,
int64_t
>
());
(
void
)
py
::
class_
<
mindrecord
::
ShardShuffle
,
mindrecord
::
ShardOperator
,
std
::
shared_ptr
<
mindrecord
::
ShardShuffle
>>
(
*
m
,
"MindrecordRandomSampler"
)
...
...
mindspore/ccsrc/minddata/mindrecord/include/shard_distributed_sample.h
浏览文件 @
e2ea1fa0
...
...
@@ -29,9 +29,10 @@ namespace mindspore {
namespace
mindrecord
{
class
ShardDistributedSample
:
public
ShardSample
{
public:
ShardDistributedSample
(
int
num_shards
,
int
shard_id
,
int
no_of_padded_samples
,
bool
shuffle
,
uint32_t
seed
);
ShardDistributedSample
(
int
num_shards
,
int
shard_id
,
int
no_of_padded_samples
,
bool
shuffle
,
uint32_t
seed
,
int
no_of_samples
=
0
);
ShardDistributedSample
(
int
num_shards
,
int
shard_id
,
bool
shuffle
,
uint32_t
seed
);
ShardDistributedSample
(
int
num_shards
,
int
shard_id
,
bool
shuffle
,
uint32_t
seed
,
int
no_of_samples
=
0
);
void
SetNumPaddedSamples
(
int
no_of_padded_samples
)
{
no_of_padded_samples_
=
no_of_padded_samples
;
}
...
...
mindspore/ccsrc/minddata/mindrecord/include/shard_sample.h
浏览文件 @
e2ea1fa0
...
...
@@ -32,7 +32,7 @@ class ShardSample : public ShardOperator {
ShardSample
(
int
num
,
int
den
);
ShardSample
(
int
num
,
int
den
,
int
par
);
ShardSample
(
int
num
,
int
den
,
int
par
,
int
no_of_samples
=
0
);
ShardSample
(
const
std
::
vector
<
int64_t
>
&
indices
,
uint32_t
seed
);
...
...
mindspore/ccsrc/minddata/mindrecord/meta/shard_distributed_sample.cc
浏览文件 @
e2ea1fa0
...
...
@@ -23,16 +23,17 @@ using mindspore::MsLogLevel::ERROR;
namespace
mindspore
{
namespace
mindrecord
{
ShardDistributedSample
::
ShardDistributedSample
(
int
num_shards
,
int
shard_id
,
int
no_of_padded_samples
,
bool
shuffle
,
uint32_t
seed
)
:
ShardSample
(
1
,
num_shards
,
shard_id
),
uint32_t
seed
,
int
no_of_samples
)
:
ShardSample
(
1
,
num_shards
,
shard_id
,
no_of_samples
),
shuffle_
(
shuffle
),
no_of_padded_samples_
(
no_of_padded_samples
),
first_epoch_
(
true
)
{
shuffle_op_
=
std
::
make_shared
<
ShardShuffle
>
(
seed
,
kShuffleSample
);
}
ShardDistributedSample
::
ShardDistributedSample
(
int
num_shards
,
int
shard_id
,
bool
shuffle
,
uint32_t
seed
)
:
ShardDistributedSample
(
num_shards
,
shard_id
,
0
,
shuffle
,
seed
)
{}
ShardDistributedSample
::
ShardDistributedSample
(
int
num_shards
,
int
shard_id
,
bool
shuffle
,
uint32_t
seed
,
int
no_of_samples
)
:
ShardDistributedSample
(
num_shards
,
shard_id
,
0
,
shuffle
,
seed
,
no_of_samples
)
{}
int64_t
ShardDistributedSample
::
GetNumSamples
(
int64_t
dataset_size
,
int64_t
num_classes
)
{
if
(
no_of_padded_samples_
<=
0
)
{
...
...
mindspore/ccsrc/minddata/mindrecord/meta/shard_sample.cc
浏览文件 @
e2ea1fa0
...
...
@@ -38,11 +38,11 @@ ShardSample::ShardSample(int num, int den)
indices_
({}),
sampler_type_
(
kCustomTopPercentSampler
)
{}
ShardSample
::
ShardSample
(
int
num
,
int
den
,
int
par
)
ShardSample
::
ShardSample
(
int
num
,
int
den
,
int
par
,
int
no_of_samples
)
:
numerator_
(
num
),
denominator_
(
den
),
partition_id_
(
par
),
no_of_samples_
(
0
),
no_of_samples_
(
no_of_samples
),
indices_
({}),
sampler_type_
(
kCustomTopPercentSampler
)
{}
...
...
@@ -110,8 +110,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
new_tasks
.
InsertTask
(
tasks
.
GetTaskByID
(
index
));
// different mod result between c and python
}
}
else
{
int
count
=
0
;
for
(
int
i
=
partition_id_
*
taking
;
i
<
(
partition_id_
+
1
)
*
taking
;
i
++
)
{
if
(
no_of_samples_
!=
0
&&
count
==
no_of_samples_
)
break
;
new_tasks
.
InsertTask
(
tasks
.
GetTaskByID
(
i
%
total_no
));
// rounding up. if overflow, go back to start
count
++
;
}
}
std
::
swap
(
tasks
,
new_tasks
);
...
...
@@ -121,8 +124,11 @@ MSRStatus ShardSample::Execute(ShardTask &tasks) {
return
FAILED
;
}
total_no
=
static_cast
<
int
>
(
tasks
.
permutation_
.
size
());
int
count
=
0
;
for
(
size_t
i
=
partition_id_
*
taking
;
i
<
(
partition_id_
+
1
)
*
taking
;
i
++
)
{
if
(
no_of_samples_
!=
0
&&
count
==
no_of_samples_
)
break
;
new_tasks
.
InsertTask
(
tasks
.
GetTaskByID
(
tasks
.
permutation_
[
i
%
total_no
]));
count
++
;
}
std
::
swap
(
tasks
,
new_tasks
);
}
...
...
mindspore/dataset/engine/samplers.py
浏览文件 @
e2ea1fa0
...
...
@@ -270,7 +270,9 @@ class DistributedSampler(BuiltinSampler):
return
c_sampler
def
create_for_minddataset
(
self
):
c_sampler
=
cde
.
MindrecordDistributedSampler
(
self
.
num_shards
,
self
.
shard_id
,
self
.
shuffle
,
self
.
seed
)
num_samples
=
self
.
num_samples
if
self
.
num_samples
is
not
None
else
0
c_sampler
=
cde
.
MindrecordDistributedSampler
(
self
.
num_shards
,
self
.
shard_id
,
self
.
shuffle
,
self
.
seed
,
num_samples
)
c_child_sampler
=
self
.
create_child_for_minddataset
()
c_sampler
.
add_child
(
c_child_sampler
)
return
c_sampler
...
...
tests/ut/python/dataset/test_minddataset.py
浏览文件 @
e2ea1fa0
...
...
@@ -238,6 +238,72 @@ def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
assert
partitions
(
5
)
==
2
assert
partitions
(
9
)
==
2
def
test_cv_minddataset_partition_num_samples_0
(
add_and_remove_cv_file
):
"""tutorial for cv minddataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
def
partitions
(
num_shards
):
for
partition_id
in
range
(
num_shards
):
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
num_shards
=
num_shards
,
shard_id
=
partition_id
,
num_samples
=
1
)
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- partition : {} ------------------------"
.
format
(
partition_id
))
logger
.
info
(
"-------------- item[file_name]: {}-----------------------"
.
format
(
item
[
"file_name"
]))
logger
.
info
(
"-------------- item[label]: {} -----------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
return
num_iter
assert
partitions
(
4
)
==
1
assert
partitions
(
5
)
==
1
assert
partitions
(
9
)
==
1
def
test_cv_minddataset_partition_num_samples_1
(
add_and_remove_cv_file
):
"""tutorial for cv minddataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
def
partitions
(
num_shards
):
for
partition_id
in
range
(
num_shards
):
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
num_shards
=
num_shards
,
shard_id
=
partition_id
,
num_samples
=
2
)
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- partition : {} ------------------------"
.
format
(
partition_id
))
logger
.
info
(
"-------------- item[file_name]: {}-----------------------"
.
format
(
item
[
"file_name"
]))
logger
.
info
(
"-------------- item[label]: {} -----------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
return
num_iter
assert
partitions
(
4
)
==
2
assert
partitions
(
5
)
==
2
assert
partitions
(
9
)
==
2
def
test_cv_minddataset_partition_num_samples_2
(
add_and_remove_cv_file
):
"""tutorial for cv minddataset."""
columns_list
=
[
"data"
,
"file_name"
,
"label"
]
num_readers
=
4
def
partitions
(
num_shards
):
for
partition_id
in
range
(
num_shards
):
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
+
"0"
,
columns_list
,
num_readers
,
num_shards
=
num_shards
,
shard_id
=
partition_id
,
num_samples
=
3
)
num_iter
=
0
for
item
in
data_set
.
create_dict_iterator
():
logger
.
info
(
"-------------- partition : {} ------------------------"
.
format
(
partition_id
))
logger
.
info
(
"-------------- item[file_name]: {}-----------------------"
.
format
(
item
[
"file_name"
]))
logger
.
info
(
"-------------- item[label]: {} -----------------------"
.
format
(
item
[
"label"
]))
num_iter
+=
1
return
num_iter
assert
partitions
(
4
)
==
3
assert
partitions
(
5
)
==
2
assert
partitions
(
9
)
==
2
def
test_cv_minddataset_partition_tutorial_check_shuffle_result
(
add_and_remove_cv_file
):
"""tutorial for cv minddataset."""
...
...
tests/ut/python/dataset/test_minddataset_exception.py
浏览文件 @
e2ea1fa0
...
...
@@ -228,3 +228,24 @@ def test_minddataset_shard_id_bigger_than_num_shard():
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
def
test_cv_minddataset_partition_num_samples_equals_0
():
"""tutorial for cv minddataset."""
create_cv_mindrecord
(
1
)
columns_list
=
[
"data"
,
"label"
]
num_readers
=
4
def
partitions
(
num_shards
):
for
partition_id
in
range
(
num_shards
):
data_set
=
ds
.
MindDataset
(
CV_FILE_NAME
,
columns_list
,
num_readers
,
num_shards
=
num_shards
,
shard_id
=
partition_id
,
num_samples
=
0
)
num_iter
=
0
for
_
in
data_set
.
create_dict_iterator
():
num_iter
+=
1
with
pytest
.
raises
(
Exception
)
as
error_info
:
partitions
(
5
)
assert
'num_samples should be a positive integer value, but got num_samples=0'
in
str
(
error_info
)
os
.
remove
(
CV_FILE_NAME
)
os
.
remove
(
"{}.db"
.
format
(
CV_FILE_NAME
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录