Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
91b4d907
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
91b4d907
编写于
7月 20, 2020
作者:
M
Mahdi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added ZipOp
上级
eeba0461
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
112 addition
and
3 deletion
+112
-3
mindspore/ccsrc/minddata/dataset/api/datasets.cc
mindspore/ccsrc/minddata/dataset/api/datasets.cc
+30
-0
mindspore/ccsrc/minddata/dataset/api/iterator.cc
mindspore/ccsrc/minddata/dataset/api/iterator.cc
+3
-1
mindspore/ccsrc/minddata/dataset/include/datasets.h
mindspore/ccsrc/minddata/dataset/include/datasets.h
+25
-0
tests/ut/cpp/dataset/c_api_test.cc
tests/ut/cpp/dataset/c_api_test.cc
+54
-2
未找到文件。
mindspore/ccsrc/minddata/dataset/api/datasets.cc
浏览文件 @
91b4d907
...
...
@@ -28,6 +28,7 @@
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/zip_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/random_sampler.h"
...
...
@@ -53,6 +54,7 @@ std::shared_ptr<Iterator> Dataset::CreateIterator() {
iter
=
std
::
make_shared
<
Iterator
>
();
Status
rc
=
iter
->
BuildAndLaunchTree
(
shared_from_this
());
if
(
rc
.
IsError
())
{
MS_LOG
(
ERROR
)
<<
rc
;
MS_LOG
(
ERROR
)
<<
"CreateIterator failed."
;
return
nullptr
;
}
...
...
@@ -184,6 +186,21 @@ std::shared_ptr<ProjectDataset> Dataset::Project(const std::vector<std::string>
return
ds
;
}
// Function to create a Zip dataset
std
::
shared_ptr
<
ZipDataset
>
Dataset
::
Zip
(
const
std
::
vector
<
std
::
shared_ptr
<
Dataset
>>
&
datasets
)
{
// Default values
auto
ds
=
std
::
make_shared
<
ZipDataset
>
();
if
(
!
ds
->
ValidateParams
())
{
return
nullptr
;
}
for
(
auto
dataset
:
datasets
)
{
ds
->
children
.
push_back
(
dataset
);
}
return
ds
;
}
// Helper function to create default RandomSampler.
std
::
shared_ptr
<
SamplerObj
>
CreateDefaultSampler
()
{
int32_t
num_samples
=
0
;
// 0 means to sample all ids.
...
...
@@ -441,6 +458,19 @@ std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> ProjectDataset::Build()
return
std
::
make_shared
<
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>>
(
node_ops
);
}
// Default constructor: ZipDataset declares no data members of its own, so there is nothing to set up.
ZipDataset::ZipDataset() = default;
// Parameter validation for ZipDataset.
// No checks are performed here, so validation always reports success.
bool ZipDataset::ValidateParams() {
  constexpr bool kParamsValid = true;
  return kParamsValid;
}
std
::
shared_ptr
<
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>>
ZipDataset
::
Build
()
{
// A vector containing shared pointer to the Dataset Ops that this object will create
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>
node_ops
;
node_ops
.
push_back
(
std
::
make_shared
<
ZipOp
>
(
rows_per_buffer_
,
connector_que_size_
));
return
std
::
make_shared
<
std
::
vector
<
std
::
shared_ptr
<
DatasetOp
>>>
(
node_ops
);
}
}
// namespace api
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/minddata/dataset/api/iterator.cc
浏览文件 @
91b4d907
...
...
@@ -52,7 +52,9 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
// Iterative BFS converting Dataset tree into runtime Execution tree.
std
::
queue
<
std
::
pair
<
std
::
shared_ptr
<
Dataset
>
,
std
::
shared_ptr
<
DatasetOp
>>>
q
;
if
(
ds
!=
nullptr
)
{
if
(
ds
==
nullptr
)
{
RETURN_STATUS_UNEXPECTED
(
"Input is null pointer"
);
}
else
{
// Convert the current root node.
auto
root_op
=
ds
->
Build
()
->
front
();
RETURN_UNEXPECTED_IF_NULL
(
root_op
);
...
...
mindspore/ccsrc/minddata/dataset/include/datasets.h
浏览文件 @
91b4d907
...
...
@@ -48,6 +48,7 @@ class MapDataset;
class
ShuffleDataset
;
class
Cifar10Dataset
;
class
ProjectDataset
;
class
ZipDataset
;
/// \brief Function to create an ImageFolderDataset
/// \notes A source dataset that reads images from a tree of directories
...
...
@@ -165,6 +166,12 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
/// \return Shared pointer to the current Dataset
std
::
shared_ptr
<
ProjectDataset
>
Project
(
const
std
::
vector
<
std
::
string
>
&
columns
);
/// \brief Function to create a Zip Dataset
/// \notes Creates a new ZipDataset whose children are the datasets passed in `datasets`
/// \param[in] datasets A list of shared pointers to the datasets that we want to zip
/// \return Shared pointer to the newly created ZipDataset
std
::
shared_ptr
<
ZipDataset
>
Zip
(
const
std
::
vector
<
std
::
shared_ptr
<
Dataset
>>
&
datasets
);
protected:
std
::
vector
<
std
::
shared_ptr
<
Dataset
>>
children
;
std
::
shared_ptr
<
Dataset
>
parent
;
...
...
@@ -351,6 +358,24 @@ class ProjectDataset : public Dataset {
private:
std
::
vector
<
std
::
string
>
columns_
;
};
/// \class ZipDataset
/// \brief Dataset node for the zip operation; it declares no data members beyond those of Dataset
class ZipDataset : public Dataset {
 public:
  /// \brief Constructor
  ZipDataset();

  /// \brief Destructor
  ~ZipDataset() = default;

  /// \brief Parameters validation
  /// \return bool true if all the params are valid
  bool ValidateParams() override;

  /// \brief A base class override function to create the required runtime dataset op objects for this class
  /// \return Shared pointer to the list of newly created DatasetOps
  std::shared_ptr<std::vector<std::shared_ptr<DatasetOp>>> Build() override;
};
}
// namespace api
}
// namespace dataset
}
// namespace mindspore
...
...
tests/ut/cpp/dataset/c_api_test.cc
浏览文件 @
91b4d907
...
...
@@ -764,8 +764,60 @@ TEST_F(MindDataTestPipeline, TestProjectMap) {
iter
->
GetNextRow
(
&
row
);
}
EXPECT_
TRUE
(
i
==
20
);
EXPECT_
EQ
(
i
,
20
);
// Manually terminate the pipeline
iter
->
Stop
();
}
\ No newline at end of file
}
// Zips the "image" column of an ImageFolder dataset with the "label" column of a Cifar10 dataset,
// batches the result and checks that 10 rows containing both columns are produced.
// Null checks that guard a later dereference use ASSERT_TRUE so a failed setup step aborts the
// test instead of dereferencing a null pointer.
TEST_F(MindDataTestPipeline, TestZip) {
  // Create an ImageFolder Dataset
  std::string folder_path = datasets_root_path_ + "/testPK/data/";
  std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
  ASSERT_TRUE(ds != nullptr);

  // Create a Project operation on ds
  std::vector<std::string> column_project = {"image"};
  ds = ds->Project(column_project);
  ASSERT_TRUE(ds != nullptr);

  // Create a Cifar10 Dataset
  folder_path = datasets_root_path_ + "/testCifar10Data/";
  std::shared_ptr<Dataset> ds1 = Cifar10(folder_path, 0, RandomSampler(false, 10));
  ASSERT_TRUE(ds1 != nullptr);

  // Create a Project operation on ds1
  column_project = {"label"};
  ds1 = ds1->Project(column_project);
  ASSERT_TRUE(ds1 != nullptr);

  // Create a Zip operation on the datasets
  ds = ds->Zip({ds, ds1});
  ASSERT_TRUE(ds != nullptr);

  // Create a Batch operation on ds
  int32_t batch_size = 1;
  ds = ds->Batch(batch_size);
  ASSERT_TRUE(ds != nullptr);

  // Create an iterator over the result of the above dataset
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  ASSERT_TRUE(iter != nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);

  uint64_t i = 0;
  while (!row.empty()) {
    ++i;

    // Look the columns up with find(): operator[] would silently insert a null tensor for a
    // missing column, which would then be dereferenced below.
    const auto image_it = row.find("image");
    const bool has_image = image_it != row.end() && image_it->second != nullptr;
    EXPECT_TRUE(has_image);
    // The zipped row must also carry the column contributed by the second dataset.
    EXPECT_TRUE(row.find("label") != row.end());
    if (!has_image) {
      // Leave the loop rather than returning so the pipeline is still stopped below.
      break;
    }

    MS_LOG(INFO) << "Tensor image shape: " << image_it->second->shape();
    iter->GetNextRow(&row);
  }

  EXPECT_EQ(i, 10);

  // Manually terminate the pipeline
  iter->Stop();
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录