Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8921a609
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8921a609
编写于
7月 20, 2020
作者:
C
Cathy Wong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
C++ API Support for Skip Dataset Op and UTs
上级
4bbbf2dc
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
147 addition
and
2 deletion
+147
-2
mindspore/ccsrc/minddata/dataset/api/datasets.cc
mindspore/ccsrc/minddata/dataset/api/datasets.cc
+37
-0
mindspore/ccsrc/minddata/dataset/include/datasets.h
mindspore/ccsrc/minddata/dataset/include/datasets.h
+27
-0
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+2
-2
tests/ut/cpp/dataset/c_api_test.cc
tests/ut/cpp/dataset/c_api_test.cc
+53
-0
tests/ut/python/dataset/test_skip.py
tests/ut/python/dataset/test_skip.py
+28
-0
未找到文件。
mindspore/ccsrc/minddata/dataset/api/datasets.cc
浏览文件 @
8921a609
...
...
@@ -27,6 +27,7 @@
#include "minddata/dataset/engine/datasetops/map_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
...
...
@@ -173,6 +174,20 @@ std::shared_ptr<ShuffleDataset> Dataset::Shuffle(int32_t shuffle_size) {
return
ds
;
}
// Function to create a SkipDataset.
std
::
shared_ptr
<
SkipDataset
>
Dataset
::
Skip
(
int32_t
count
)
{
auto
ds
=
std
::
make_shared
<
SkipDataset
>
(
count
);
// Call derived class validation method.
if
(
!
ds
->
ValidateParams
())
{
return
nullptr
;
}
ds
->
children
.
push_back
(
shared_from_this
());
return
ds
;
}
// Function to create a ProjectDataset.
std
::
shared_ptr
<
ProjectDataset
>
Dataset
::
Project
(
const
std
::
vector
<
std
::
string
>
&
columns
)
{
auto
ds
=
std
::
make_shared
<
ProjectDataset
>
(
columns
);
...
...
@@ -400,6 +415,28 @@ bool ShuffleDataset::ValidateParams() {
return
true
;
}
// Constructor for SkipDataset
SkipDataset
::
SkipDataset
(
int32_t
count
)
:
skip_count_
(
count
)
{}
// Function to build the SkipOp
std
::
shared_ptr
<
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>>
SkipDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
node_ops
.
push_back
(
std
::
make_shared
<
SkipOp
>
(
skip_count_
,
connector_que_size_
));
return
std
::
make_shared
<
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>>
(
node_ops
);
}
// Function to validate the parameters for SkipDataset
bool
SkipDataset
::
ValidateParams
()
{
if
(
skip_count_
<=
-
1
)
{
MS_LOG
(
ERROR
)
<<
"Skip: Invalid input, skip_count: "
<<
skip_count_
;
return
false
;
}
return
true
;
}
// Constructor for Cifar10Dataset
Cifar10Dataset
::
Cifar10Dataset
(
const
std
::
string
&
dataset_dir
,
int32_t
num_samples
,
std
::
shared_ptr
<
SamplerObj
>
sampler
)
:
dataset_dir_
(
dataset_dir
),
num_samples_
(
num_samples
),
sampler_
(
sampler
)
{}
...
...
mindspore/ccsrc/minddata/dataset/include/datasets.h
浏览文件 @
8921a609
...
...
@@ -46,6 +46,7 @@ class BatchDataset;
class
RepeatDataset
;
class
MapDataset
;
class
ShuffleDataset
;
class
SkipDataset
;
class
Cifar10Dataset
;
class
ProjectDataset
;
class
ZipDataset
;
...
...
@@ -160,6 +161,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current ShuffleDataset
std
::
shared_ptr
<
ShuffleDataset
>
Shuffle
(
int32_t
shuffle_size
);
/// \brief Function to create a SkipDataset
/// \notes Skips count elements in this dataset.
/// \param[in] count Number of elements the dataset to be skipped.
/// \return Shared pointer to the current SkipDataset
std
::
shared_ptr
<
SkipDataset
>
Skip
(
int32_t
count
);
/// \brief Function to create a Project Dataset
/// \notes Applies project to the dataset
/// \param[in] columns The name of columns to project
...
...
@@ -293,6 +300,26 @@ class ShuffleDataset : public Dataset {
bool
reset_every_epoch_
;
};
class
SkipDataset
:
public
Dataset
{
public:
/// \brief Constructor
explicit
SkipDataset
(
int32_t
count
);
/// \brief Destructor
~
SkipDataset
()
=
default
;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \return shared pointer to the list of newly created DatasetOps
std
::
shared_ptr
<
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>>
Build
()
override
;
/// \brief Parameters validation
/// \return bool true if all the params are valid
bool
ValidateParams
()
override
;
private:
int32_t
skip_count_
;
};
class
MapDataset
:
public
Dataset
{
public:
/// \brief Constructor
...
...
mindspore/dataset/engine/datasets.py
浏览文件 @
8921a609
...
...
@@ -2094,8 +2094,8 @@ class SkipDataset(DatasetOp):
The result of applying Skip operator to the input Dataset.
Args:
input_dataset (
tuple): A tuple of datasets to be
skipped.
count (int): Number of rows
the dataset should
be skipped.
input_dataset (
Dataset): Input dataset to have rows
skipped.
count (int): Number of rows
in the dataset to
be skipped.
"""
def
__init__
(
self
,
input_dataset
,
count
):
...
...
tests/ut/cpp/dataset/c_api_test.cc
浏览文件 @
8921a609
...
...
@@ -573,6 +573,59 @@ TEST_F(MindDataTestPipeline, TestShuffleDataset) {
iter
->
Stop
();
}
TEST_F
(
MindDataTestPipeline
,
TestSkipDataset
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestSkipDataset."
;
// Create an ImageFolder Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testPK/data/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
ImageFolder
(
folder_path
,
true
,
RandomSampler
(
false
,
10
));
EXPECT_TRUE
(
ds
!=
nullptr
);
// Create a Skip operation on ds
int32_t
count
=
3
;
ds
=
ds
->
Skip
(
count
);
EXPECT_TRUE
(
ds
!=
nullptr
);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std
::
shared_ptr
<
Iterator
>
iter
=
ds
->
CreateIterator
();
EXPECT_TRUE
(
iter
!=
nullptr
);
// Iterate the dataset and get each row
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Tensor
>>
row
;
iter
->
GetNextRow
(
&
row
);
uint64_t
i
=
0
;
while
(
row
.
size
()
!=
0
)
{
i
++
;
auto
image
=
row
[
"image"
];
MS_LOG
(
INFO
)
<<
"Tensor image shape: "
<<
image
->
shape
();
iter
->
GetNextRow
(
&
row
);
}
MS_LOG
(
INFO
)
<<
"Number of rows: "
<<
i
;
// Expect 10-3=7 rows
EXPECT_TRUE
(
i
==
7
);
// Manually terminate the pipeline
iter
->
Stop
();
}
TEST_F
(
MindDataTestPipeline
,
TestSkipDatasetError1
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestSkipDatasetError1."
;
// Create an ImageFolder Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testPK/data/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
ImageFolder
(
folder_path
,
true
,
RandomSampler
(
false
,
10
));
EXPECT_TRUE
(
ds
!=
nullptr
);
// Create a Skip operation on ds with invalid count input
int32_t
count
=
-
1
;
ds
=
ds
->
Skip
(
count
);
// Expect nullptr for invalid input skip_count
EXPECT_TRUE
(
ds
==
nullptr
);
}
TEST_F
(
MindDataTestPipeline
,
TestCifar10Dataset
)
{
// Create a Cifar10 Dataset
...
...
tests/ut/python/dataset/test_skip.py
浏览文件 @
8921a609
...
...
@@ -13,9 +13,12 @@
# limitations under the License.
# ==============================================================================
import
numpy
as
np
import
pytest
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
from
mindspore
import
log
as
logger
DATA_DIR_TF2
=
[
"../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"
]
SCHEMA_DIR_TF2
=
"../data/dataset/test_tf_file_3_images/datasetSchema.json"
...
...
@@ -196,6 +199,29 @@ def test_skip_filter_2():
assert
buf
==
[
5
,
6
,
7
,
8
,
9
,
10
]
def
test_skip_exception_1
():
data1
=
ds
.
GeneratorDataset
(
generator_md
,
[
"data"
])
try
:
data1
=
data1
.
skip
(
count
=-
1
)
num_iter
=
0
for
_
in
data1
.
create_dict_iterator
():
num_iter
+=
1
except
RuntimeError
as
e
:
logger
.
info
(
"Got an exception in DE: {}"
.
format
(
str
(
e
)))
assert
"Skip count must be positive integer or 0."
in
str
(
e
)
def
test_skip_exception_2
():
ds1
=
ds
.
GeneratorDataset
(
generator_md
,
[
"data"
])
with
pytest
.
raises
(
ValueError
)
as
e
:
ds1
=
ds1
.
skip
(
-
2
)
assert
"Input count is not within the required interval"
in
str
(
e
.
value
)
if
__name__
==
"__main__"
:
test_tf_skip
()
test_generator_skip
()
...
...
@@ -208,3 +234,5 @@ if __name__ == "__main__":
test_skip_take_2
()
test_skip_filter_1
()
test_skip_filter_2
()
test_skip_exception_1
()
test_skip_exception_2
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录