Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
43a2e998
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
43a2e998
编写于
4月 16, 2020
作者:
J
Junhan Hu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add python sampler support for CPP dataset
上级
3ad73b7d
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
296 addition
and
16 deletion
+296
-16
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+5
-0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt
...c/dataset/engine/datasetops/source/sampler/CMakeLists.txt
+1
-0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc
...ataset/engine/datasetops/source/sampler/python_sampler.cc
+83
-0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h
...dataset/engine/datasetops/source/sampler/python_sampler.h
+58
-0
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
...ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
+0
-3
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
...et/engine/datasetops/source/sampler/sequential_sampler.cc
+1
-0
mindspore/dataset/__init__.py
mindspore/dataset/__init__.py
+1
-1
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+1
-1
mindspore/dataset/engine/samplers.py
mindspore/dataset/engine/samplers.py
+86
-6
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+3
-5
tests/ut/python/dataset/test_sampler.py
tests/ut/python/dataset/test_sampler.py
+57
-0
未找到文件。
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
43a2e998
...
...
@@ -53,6 +53,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/kernels/data/to_float16_op.h"
...
...
@@ -415,6 +416,7 @@ void bindSamplerOps(py::module *m) {
(
void
)
py
::
class_
<
SequentialSampler
,
Sampler
,
std
::
shared_ptr
<
SequentialSampler
>>
(
*
m
,
"SequentialSampler"
)
.
def
(
py
::
init
<>
());
(
void
)
py
::
class_
<
SubsetRandomSampler
,
Sampler
,
std
::
shared_ptr
<
SubsetRandomSampler
>>
(
*
m
,
"SubsetRandomSampler"
)
.
def
(
py
::
init
<
std
::
vector
<
int64_t
>>
(),
py
::
arg
(
"indices"
));
...
...
@@ -425,6 +427,9 @@ void bindSamplerOps(py::module *m) {
(
void
)
py
::
class_
<
WeightedRandomSampler
,
Sampler
,
std
::
shared_ptr
<
WeightedRandomSampler
>>
(
*
m
,
"WeightedRandomSampler"
)
.
def
(
py
::
init
<
std
::
vector
<
double
>
,
int64_t
,
bool
>
(),
py
::
arg
(
"weights"
),
py
::
arg
(
"numSamples"
),
py
::
arg
(
"replacement"
));
(
void
)
py
::
class_
<
PythonSampler
,
Sampler
,
std
::
shared_ptr
<
PythonSampler
>>
(
*
m
,
"PythonSampler"
)
.
def
(
py
::
init
<
py
::
object
>
(),
py
::
arg
(
"pySampler"
));
}
void
bindInfoObjects
(
py
::
module
*
m
)
{
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt
浏览文件 @
43a2e998
add_library
(
engine-datasetops-source-sampler OBJECT
distributed_sampler.cc
pk_sampler.cc
python_sampler.cc
random_sampler.cc
sampler.cc
sequential_sampler.cc
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc
0 → 100644
浏览文件 @
43a2e998
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/source/sampler/python_sampler.h"
#include <memory>
namespace
mindspore
{
namespace
dataset
{
PythonSampler
::
PythonSampler
(
py
::
object
py_sampler_instance
,
int64_t
samples_per_buffer
)
:
Sampler
(
samples_per_buffer
),
py_sampler_instance
(
py_sampler_instance
),
need_to_reset_
(
false
)
{}
Status
PythonSampler
::
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
{
if
(
need_to_reset_
)
{
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagEOE
);
}
else
{
std
::
shared_ptr
<
Tensor
>
sample_ids
;
{
py
::
gil_scoped_acquire
gil_acquire
;
(
*
out_buffer
)
=
std
::
make_unique
<
DataBuffer
>
(
0
,
DataBuffer
::
kDeBFlagNone
);
if
(
Py_IsInitialized
()
==
0
)
{
return
Status
(
StatusCode
::
kPythonInterpreterFailure
,
"Python Interpreter is finalized"
);
}
try
{
py
::
object
py_ret
=
py_sampler_instance
.
attr
(
"_get_indices"
)();
py
::
array
np_sample_ids
=
py_ret
.
cast
<
py
::
array
>
();
Tensor
::
CreateTensor
(
&
sample_ids
,
np_sample_ids
);
// copy numpy to tensor
}
catch
(
const
py
::
error_already_set
&
e
)
{
return
Status
(
StatusCode
::
kPyFuncException
,
e
.
what
());
}
}
TensorRow
row
(
1
,
sample_ids
);
(
*
out_buffer
)
->
set_tensor_table
(
std
::
make_unique
<
TensorQTable
>
(
1
,
row
));
need_to_reset_
=
true
;
}
return
Status
::
OK
();
}
Status
PythonSampler
::
InitSampler
()
{
CHECK_FAIL_RETURN_UNEXPECTED
(
num_rows_
>
0
,
"ERROR num_rows_ should be greater than 0"
);
{
py
::
gil_scoped_acquire
gil_acquire
;
if
(
Py_IsInitialized
()
==
0
)
{
return
Status
(
StatusCode
::
kPythonInterpreterFailure
,
"Python Interpreter is finalized"
);
}
try
{
py_sampler_instance
.
attr
(
"_handshake"
)(
num_rows_
,
num_samples_
);
}
catch
(
const
py
::
error_already_set
&
e
)
{
return
Status
(
StatusCode
::
kPyFuncException
,
e
.
what
());
}
}
return
Status
::
OK
();
}
Status
PythonSampler
::
Reset
()
{
CHECK_FAIL_RETURN_UNEXPECTED
(
need_to_reset_
,
"ERROR Reset() called not at end of an epoch"
);
need_to_reset_
=
false
;
py
::
gil_scoped_acquire
gil_acquire
;
if
(
Py_IsInitialized
()
==
0
)
{
return
Status
(
StatusCode
::
kPythonInterpreterFailure
,
"Python Interpreter is finalized"
);
}
try
{
py_sampler_instance
.
attr
(
"reset"
)();
}
catch
(
const
py
::
error_already_set
&
e
)
{
return
Status
(
StatusCode
::
kPyFuncException
,
e
.
what
());
}
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h
0 → 100644
浏览文件 @
43a2e998
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_
#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_
#include <limits>
#include <memory>
#include "dataset/engine/datasetops/source/sampler/sampler.h"
namespace
mindspore
{
namespace
dataset
{
class
PythonSampler
:
public
Sampler
{
public:
// Constructor
// @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call
explicit
PythonSampler
(
py
::
object
py_sampler_instance
,
int64_t
samples_per_buffer
=
std
::
numeric_limits
<
int64_t
>::
max
());
// Destructor.
~
PythonSampler
()
=
default
;
// Initialize the sampler.
// @return Status
Status
InitSampler
()
override
;
// for next epoch of sampleIds
// @return - The error code return
Status
Reset
()
override
;
// Op calls this to get next Buffer that contains all the sampleIds
// @param std::unique_ptr<DataBuffer> pBuffer - Buffer to be returned to StorageOp
// @param int32_t workerId - not meant to be used
// @return - The error code return
Status
GetNextBuffer
(
std
::
unique_ptr
<
DataBuffer
>
*
out_buffer
)
override
;
private:
bool
need_to_reset_
;
// Whether Reset() should be called before calling GetNextBuffer()
py
::
object
py_sampler_instance
;
// The handle to the py_sampler python object
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc
浏览文件 @
43a2e998
...
...
@@ -48,9 +48,6 @@ Status Sampler::GetAllIdsThenReset(py::array *data) {
std
::
unique_ptr
<
DataBuffer
>
db
;
std
::
shared_ptr
<
Tensor
>
sample_ids
;
// check samples_per_buffer is properly set and doesn't overflow
CHECK_FAIL_RETURN_UNEXPECTED
(
samples_per_buffer_
+
1
>
1
,
"samples_per_buffer invalid"
);
// A call to derived class to get sample ids wrapped inside a buffer
RETURN_IF_NOT_OK
(
GetNextBuffer
(
&
db
));
// Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch
...
...
mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc
浏览文件 @
43a2e998
...
...
@@ -42,6 +42,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr<DataBuffer> *out_buffer)
}
Status
SequentialSampler
::
InitSampler
()
{
num_samples_
=
(
num_samples_
<=
0
)
?
num_rows_
:
num_samples_
;
// if num_samples < 0, try if num_rows is set
CHECK_FAIL_RETURN_UNEXPECTED
(
num_samples_
>
0
&&
samples_per_buffer_
>
0
,
"Fail to init Sequential Sampler"
);
samples_per_buffer_
=
samples_per_buffer_
>
num_samples_
?
num_samples_
:
samples_per_buffer_
;
return
Status
::
OK
();
...
...
mindspore/dataset/__init__.py
浏览文件 @
43a2e998
...
...
@@ -23,7 +23,7 @@ from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDataset
GeneratorDataset
,
ManifestDataset
,
Cifar10Dataset
,
Cifar100Dataset
,
VOCDataset
,
CelebADataset
,
Schema
,
\
Shuffle
,
zip
from
.engine.samplers
import
DistributedSampler
,
PKSampler
,
RandomSampler
,
SequentialSampler
,
SubsetRandomSampler
,
\
WeightedRandomSampler
WeightedRandomSampler
,
Sampler
from
.engine.serializer_deserializer
import
serialize
,
deserialize
,
show
__all__
=
[
"config"
,
"ImageFolderDatasetV2"
,
"MnistDataset"
,
"StorageDataset"
,
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
43a2e998
...
...
@@ -2032,7 +2032,7 @@ class GeneratorDataset(SourceDataset):
if
self
.
sampler
is
not
None
and
hasattr
(
source
,
"__getitem__"
):
if
isinstance
(
self
.
sampler
,
(
samplers
.
SequentialSampler
,
samplers
.
DistributedSampler
,
samplers
.
RandomSampler
,
samplers
.
SubsetRandomSampler
,
samplers
.
WeightedRandomSampler
)):
samplers
.
WeightedRandomSampler
,
samplers
.
Sampler
)):
if
num_samples
is
None
:
num_samples
=
len
(
source
)
sampler_instance
=
self
.
sampler
.
create
()
...
...
mindspore/dataset/engine/samplers.py
浏览文件 @
43a2e998
...
...
@@ -16,11 +16,90 @@
Sampler module provides several samplers to generate sampling data from dataset.
There are following samplers: DistributedSampler, PKSampler, RandomSampler,
SequentialSampler, SubsetRandomSampler, WeightedRandomSampler.
User can also define custom sampler by extending from Sampler class.
"""
import
mindspore._c_dataengine
as
cde
import
numpy
as
np
class
DistributedSampler
():
class
Sampler
:
"""
Base class for user defined sampler.
User defined sampler can be used with any existing dataset with sampler support.
An required _iter_() method should by overridden by user for sample index generation.
An optional reset() method can be overridden for per repeat reset,
dataset_size and num_samples will be set by dataset once a dataset iterator is created.
Examples:
>>> import mindspore.dataset as ds
>>>
>>> class ReverseSampler(ds,Sampler):
>>> def __iter__(self):
>>> for i in range(self.dataset_size - 1, -1, -1):
>>> yield i
>>>
>>> ds = ds.ImageFolderDatasetV2(path, sampler=ReverseSampler())
"""
def
__init__
(
self
):
self
.
dataset_size
=
0
self
.
num_samples
=
0
def
__iter__
(
self
):
"""
User defined iterator, must be overridden.
_handshake is guaranteed to be called prior to iterator construction
"""
raise
NotImplementedError
def
reset
(
self
):
"""
Per repeat reset callback, override this method if necessary
"""
# Initialization handshake callback
# Do not override this method!
def
_handshake
(
self
,
ds_size
,
num_samples
):
self
.
dataset_size
=
ds_size
self
.
num_samples
=
num_samples
# Indices fetcher
# Do not override this method!
def
_get_indices
(
self
):
sampler_iter
=
iter
(
self
)
ret
=
[]
for
_
in
range
(
self
.
num_samples
):
try
:
idx
=
next
(
sampler_iter
)
ret
.
append
(
idx
)
except
StopIteration
:
break
return
np
.
array
(
ret
)
# Instance fetcher
# Do not override this method!
def
create
(
self
):
return
cde
.
PythonSampler
(
self
)
class
BuiltinSampler
:
"""
Base class for BuiltinSampler.
User should not extend this class.
"""
def
__init__
(
self
):
pass
def
create
(
self
):
pass
class
DistributedSampler
(
BuiltinSampler
):
"""
Sampler that access a shard of the dataset.
...
...
@@ -65,7 +144,7 @@ class DistributedSampler():
return
cde
.
DistributedSampler
(
self
.
num_shards
,
self
.
shard_id
,
self
.
shuffle
,
self
.
seed
)
class
PKSampler
():
class
PKSampler
(
BuiltinSampler
):
"""
Samples K elements for each P class in the dataset.
...
...
@@ -106,7 +185,7 @@ class PKSampler():
return
cde
.
PKSampler
(
self
.
num_val
,
self
.
shuffle
)
class
RandomSampler
():
class
RandomSampler
(
BuiltinSampler
):
"""
Samples the elements randomly.
...
...
@@ -147,7 +226,7 @@ class RandomSampler():
return
cde
.
RandomSampler
(
self
.
replacement
,
self
.
num_samples
)
class
SequentialSampler
():
class
SequentialSampler
(
BuiltinSampler
):
"""
Samples the dataset elements sequentially, same as not having a sampler.
...
...
@@ -165,7 +244,7 @@ class SequentialSampler():
return
cde
.
SequentialSampler
()
class
SubsetRandomSampler
():
class
SubsetRandomSampler
(
BuiltinSampler
):
"""
Samples the elements randomly from a sequence of indices.
...
...
@@ -196,7 +275,8 @@ class SubsetRandomSampler():
def
_create_for_minddataset
(
self
):
return
cde
.
MindrecordSubsetRandomSampler
(
self
.
indices
)
class
WeightedRandomSampler
():
class
WeightedRandomSampler
(
BuiltinSampler
):
"""
Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities).
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
43a2e998
...
...
@@ -297,9 +297,7 @@ def check_sampler_shuffle_shard_options(param_dict):
shuffle
,
sampler
=
param_dict
.
get
(
'shuffle'
),
param_dict
.
get
(
'sampler'
)
num_shards
,
shard_id
=
param_dict
.
get
(
'num_shards'
),
param_dict
.
get
(
'shard_id'
)
if
sampler
is
not
None
and
not
isinstance
(
sampler
,
(
samplers
.
DistributedSampler
,
samplers
.
PKSampler
,
samplers
.
RandomSampler
,
samplers
.
SequentialSampler
,
samplers
.
SubsetRandomSampler
,
samplers
.
WeightedRandomSampler
)):
if
sampler
is
not
None
and
not
isinstance
(
sampler
,
(
samplers
.
BuiltinSampler
,
samplers
.
Sampler
)):
raise
ValueError
(
"sampler is not a valid Sampler type."
)
if
sampler
is
not
None
:
...
...
@@ -579,11 +577,11 @@ def check_generatordataset(method):
raise
ValueError
(
"PKSampler is not supported by GeneratorDataset"
)
if
not
isinstance
(
sampler
,
(
samplers
.
SequentialSampler
,
samplers
.
DistributedSampler
,
samplers
.
RandomSampler
,
samplers
.
SubsetRandomSampler
,
samplers
.
WeightedRandomSampler
)):
samplers
.
WeightedRandomSampler
,
samplers
.
Sampler
)):
try
:
iter
(
sampler
)
except
TypeError
:
raise
TypeError
(
"sampler should be either iterable or from
dataset.samplers.py
"
)
raise
TypeError
(
"sampler should be either iterable or from
mindspore.dataset.samplers
"
)
return
method
(
*
args
,
**
kwargs
)
...
...
tests/ut/python/dataset/test_sampler.py
浏览文件 @
43a2e998
...
...
@@ -14,6 +14,7 @@
# ==============================================================================
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
import
numpy
as
np
# test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631]
...
...
@@ -107,8 +108,64 @@ def test_sampler_py_api():
sampler
.
get_indices
()
def
test_python_sampler
():
manifest_file
=
"../data/dataset/testManifestData/test5trainimgs.json"
map
=
{(
172876
,
0
):
0
,
(
54214
,
0
):
1
,
(
54214
,
1
):
2
,
(
173673
,
0
):
3
,
(
64631
,
1
):
4
}
class
Sp1
(
ds
.
Sampler
):
def
__iter__
(
self
):
return
iter
([
i
for
i
in
range
(
self
.
dataset_size
)])
class
Sp2
(
ds
.
Sampler
):
def
__init__
(
self
):
super
(
Sp2
,
self
).
__init__
()
# at this stage, self.dataset_size and self.num_samples are not yet known
self
.
cnt
=
0
def
__iter__
(
self
):
# first epoch, all 0, second epoch all 1, third all 2 etc.. ...
return
iter
([
self
.
cnt
for
i
in
range
(
self
.
num_samples
)])
def
reset
(
self
):
self
.
cnt
=
(
self
.
cnt
+
1
)
%
self
.
dataset_size
def
test_config
(
num_samples
,
num_repeats
,
sampler
):
data1
=
ds
.
ManifestDataset
(
manifest_file
,
num_samples
=
num_samples
,
sampler
=
sampler
)
if
num_repeats
is
not
None
:
data1
=
data1
.
repeat
(
num_repeats
)
res
=
[]
for
item
in
data1
.
create_dict_iterator
():
logger
.
info
(
"item[image].shape[0]: {}, item[label].item(): {}"
.
format
(
item
[
"image"
].
shape
[
0
],
item
[
"label"
].
item
()))
res
.
append
(
map
[(
item
[
"image"
].
shape
[
0
],
item
[
"label"
].
item
())])
# print(res)
return
res
def
test_generator
():
class
MySampler
(
ds
.
Sampler
):
def
__iter__
(
self
):
for
i
in
range
(
99
,
-
1
,
-
1
):
yield
i
data1
=
ds
.
GeneratorDataset
([(
np
.
array
(
i
),)
for
i
in
range
(
100
)],
[
"data"
],
sampler
=
MySampler
())
i
=
99
for
data
in
data1
:
assert
data
[
0
]
==
(
np
.
array
(
i
),)
i
=
i
-
1
assert
test_config
(
5
,
2
,
Sp1
())
==
[
0
,
1
,
2
,
3
,
4
,
0
,
1
,
2
,
3
,
4
]
assert
test_config
(
2
,
6
,
Sp2
())
==
[
0
,
0
,
1
,
1
,
2
,
2
,
3
,
3
,
4
,
4
,
0
,
0
]
test_generator
()
sp1
=
Sp1
().
create
()
sp1
.
set_num_rows
(
5
)
sp1
.
set_num_samples
(
5
)
sp1
.
initialize
()
assert
list
(
sp1
.
get_indices
())
==
[
0
,
1
,
2
,
3
,
4
]
if
__name__
==
'__main__'
:
test_sequential_sampler
(
True
)
test_random_sampler
(
True
)
test_random_sampler_multi_iter
(
True
)
test_sampler_py_api
()
test_python_sampler
()
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录