Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
e6a12946
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e6a12946
编写于
8月 01, 2020
作者:
Y
YangLuo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Updata voc build() & add C++ Op CreateTupleIterator
上级
b7ebe2be
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
305 addition
and
24 deletion
+305
-24
mindspore/ccsrc/minddata/dataset/api/datasets.cc
mindspore/ccsrc/minddata/dataset/api/datasets.cc
+38
-15
mindspore/ccsrc/minddata/dataset/api/iterator.cc
mindspore/ccsrc/minddata/dataset/api/iterator.cc
+13
-0
mindspore/ccsrc/minddata/dataset/api/transforms.cc
mindspore/ccsrc/minddata/dataset/api/transforms.cc
+10
-5
mindspore/ccsrc/minddata/dataset/include/datasets.h
mindspore/ccsrc/minddata/dataset/include/datasets.h
+8
-1
mindspore/ccsrc/minddata/dataset/include/iterator.h
mindspore/ccsrc/minddata/dataset/include/iterator.h
+7
-0
mindspore/ccsrc/minddata/dataset/include/transforms.h
mindspore/ccsrc/minddata/dataset/include/transforms.h
+5
-3
tests/ut/cpp/dataset/CMakeLists.txt
tests/ut/cpp/dataset/CMakeLists.txt
+1
-0
tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc
tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc
+223
-0
未找到文件。
mindspore/ccsrc/minddata/dataset/api/datasets.cc
浏览文件 @
e6a12946
...
...
@@ -61,11 +61,19 @@ namespace api {
} while (false)
// Function to create the iterator, which will build and launch the execution tree.
std
::
shared_ptr
<
Iterator
>
Dataset
::
CreateIterator
()
{
std
::
shared_ptr
<
Iterator
>
Dataset
::
CreateIterator
(
std
::
vector
<
std
::
string
>
columns
)
{
std
::
shared_ptr
<
Iterator
>
iter
;
try
{
auto
ds
=
shared_from_this
();
// The specified columns will be selected from the dataset and passed down the pipeline
// in the order specified, other columns will be discarded.
if
(
!
columns
.
empty
())
{
ds
=
ds
->
Project
(
columns
);
}
iter
=
std
::
make_shared
<
Iterator
>
();
Status
rc
=
iter
->
BuildAndLaunchTree
(
shared_from_this
()
);
Status
rc
=
iter
->
BuildAndLaunchTree
(
ds
);
if
(
rc
.
IsError
())
{
MS_LOG
(
ERROR
)
<<
"CreateIterator failed."
<<
rc
;
return
nullptr
;
...
...
@@ -629,13 +637,13 @@ bool VOCDataset::ValidateParams() {
}
Path
imagesets_file
=
dir
/
"ImageSets"
/
"Segmentation"
/
mode_
+
".txt"
;
if
(
!
imagesets_file
.
Exists
())
{
MS_LOG
(
ERROR
)
<<
"
[Segmentation] imagesets_file is invalid or not exist
"
;
MS_LOG
(
ERROR
)
<<
"
Invalid mode: "
<<
mode_
<<
", file
\"
"
<<
imagesets_file
<<
"
\"
is not exists!
"
;
return
false
;
}
}
else
if
(
task_
==
"Detection"
)
{
Path
imagesets_file
=
dir
/
"ImageSets"
/
"Main"
/
mode_
+
".txt"
;
if
(
!
imagesets_file
.
Exists
())
{
MS_LOG
(
ERROR
)
<<
"
[Detection] imagesets_file is invalid or not exist.
"
;
MS_LOG
(
ERROR
)
<<
"
Invalid mode: "
<<
mode_
<<
", file
\"
"
<<
imagesets_file
<<
"
\"
is not exists!
"
;
return
false
;
}
}
else
{
...
...
@@ -655,18 +663,33 @@ std::vector<std::shared_ptr<DatasetOp>> VOCDataset::Build() {
sampler_
=
CreateDefaultSampler
();
}
std
::
shared_ptr
<
VOCOp
::
Builder
>
builder
=
std
::
make_shared
<
VOCOp
::
Builder
>
();
(
void
)
builder
->
SetDir
(
dataset_dir_
);
(
void
)
builder
->
SetTask
(
task_
);
(
void
)
builder
->
SetMode
(
mode_
);
(
void
)
builder
->
SetNumWorkers
(
num_workers_
);
(
void
)
builder
->
SetSampler
(
std
::
move
(
sampler_
->
Build
()));
(
void
)
builder
->
SetDecode
(
decode_
);
(
void
)
builder
->
SetClassIndex
(
class_index_
);
auto
schema
=
std
::
make_unique
<
DataSchema
>
();
VOCOp
::
TaskType
task_type_
;
std
::
shared_ptr
<
VOCOp
>
op
;
RETURN_EMPTY_IF_ERROR
(
builder
->
Build
(
&
op
));
node_ops
.
push_back
(
op
);
if
(
task_
==
"Segmentation"
)
{
task_type_
=
VOCOp
::
TaskType
::
Segmentation
;
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
std
::
string
(
kColumnImage
),
DataType
(
DataType
::
DE_UINT8
),
TensorImpl
::
kFlexible
,
1
)));
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
std
::
string
(
kColumnTarget
),
DataType
(
DataType
::
DE_UINT8
),
TensorImpl
::
kFlexible
,
1
)));
}
else
if
(
task_
==
"Detection"
)
{
task_type_
=
VOCOp
::
TaskType
::
Detection
;
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
std
::
string
(
kColumnImage
),
DataType
(
DataType
::
DE_UINT8
),
TensorImpl
::
kFlexible
,
1
)));
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
std
::
string
(
kColumnBbox
),
DataType
(
DataType
::
DE_FLOAT32
),
TensorImpl
::
kFlexible
,
1
)));
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
std
::
string
(
kColumnLabel
),
DataType
(
DataType
::
DE_UINT32
),
TensorImpl
::
kFlexible
,
1
)));
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
std
::
string
(
kColumnDifficult
),
DataType
(
DataType
::
DE_UINT32
),
TensorImpl
::
kFlexible
,
1
)));
RETURN_EMPTY_IF_ERROR
(
schema
->
AddColumn
(
ColDescriptor
(
std
::
string
(
kColumnTruncate
),
DataType
(
DataType
::
DE_UINT32
),
TensorImpl
::
kFlexible
,
1
)));
}
std
::
shared_ptr
<
VOCOp
>
voc_op
;
voc_op
=
std
::
make_shared
<
VOCOp
>
(
task_type_
,
mode_
,
dataset_dir_
,
class_index_
,
num_workers_
,
rows_per_buffer_
,
connector_que_size_
,
decode_
,
std
::
move
(
schema
),
std
::
move
(
sampler_
->
Build
()));
node_ops
.
push_back
(
voc_op
);
return
node_ops
;
}
...
...
mindspore/ccsrc/minddata/dataset/api/iterator.cc
浏览文件 @
e6a12946
...
...
@@ -30,6 +30,19 @@ void Iterator::GetNextRow(TensorMap *row) {
}
}
// Get the next row from the data pipeline.
void
Iterator
::
GetNextRow
(
TensorVec
*
row
)
{
TensorRow
tensor_row
;
Status
rc
=
iterator_
->
FetchNextTensorRow
(
&
tensor_row
);
if
(
rc
.
IsError
())
{
MS_LOG
(
ERROR
)
<<
"GetNextRow: Failed to get next row."
;
row
->
clear
();
}
// Generate a vector as return
row
->
clear
();
std
::
copy
(
tensor_row
.
begin
(),
tensor_row
.
end
(),
std
::
back_inserter
(
*
row
));
}
// Shut down the data pipeline.
void
Iterator
::
Stop
()
{
// Releasing the iterator_ unique_ptre. This should trigger the destructor of iterator_.
...
...
mindspore/ccsrc/minddata/dataset/api/transforms.cc
浏览文件 @
e6a12946
...
...
@@ -116,8 +116,9 @@ std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float>
// Function to create RandomCropOperation.
std
::
shared_ptr
<
RandomCropOperation
>
RandomCrop
(
std
::
vector
<
int32_t
>
size
,
std
::
vector
<
int32_t
>
padding
,
bool
pad_if_needed
,
std
::
vector
<
uint8_t
>
fill_value
)
{
auto
op
=
std
::
make_shared
<
RandomCropOperation
>
(
size
,
padding
,
pad_if_needed
,
fill_value
);
bool
pad_if_needed
,
std
::
vector
<
uint8_t
>
fill_value
,
BorderType
padding_mode
)
{
auto
op
=
std
::
make_shared
<
RandomCropOperation
>
(
size
,
padding
,
pad_if_needed
,
fill_value
,
padding_mode
);
// Input validation
if
(
!
op
->
ValidateParams
())
{
return
nullptr
;
...
...
@@ -403,8 +404,12 @@ std::shared_ptr<TensorOp> RandomColorAdjustOperation::Build() {
// RandomCropOperation
RandomCropOperation
::
RandomCropOperation
(
std
::
vector
<
int32_t
>
size
,
std
::
vector
<
int32_t
>
padding
,
bool
pad_if_needed
,
std
::
vector
<
uint8_t
>
fill_value
)
:
size_
(
size
),
padding_
(
padding
),
pad_if_needed_
(
pad_if_needed
),
fill_value_
(
fill_value
)
{}
std
::
vector
<
uint8_t
>
fill_value
,
BorderType
padding_mode
)
:
size_
(
size
),
padding_
(
padding
),
pad_if_needed_
(
pad_if_needed
),
fill_value_
(
fill_value
),
padding_mode_
(
padding_mode
)
{}
bool
RandomCropOperation
::
ValidateParams
()
{
if
(
size_
.
empty
()
||
size_
.
size
()
>
2
)
{
...
...
@@ -443,7 +448,7 @@ std::shared_ptr<TensorOp> RandomCropOperation::Build() {
}
auto
tensor_op
=
std
::
make_shared
<
RandomCropOp
>
(
crop_height
,
crop_width
,
pad_top
,
pad_bottom
,
pad_left
,
pad_right
,
BorderType
::
kConstant
,
pad_if_needed_
,
fill_r
,
fill_g
,
fill_b
);
padding_mode_
,
pad_if_needed_
,
fill_r
,
fill_g
,
fill_b
);
return
tensor_op
;
}
...
...
mindspore/ccsrc/minddata/dataset/include/datasets.h
浏览文件 @
e6a12946
...
...
@@ -196,8 +196,9 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
}
/// \brief Function to create an Iterator over the Dataset pipeline
/// \param[in] columns List of columns to be used to specify the order of columns
/// \return Shared pointer to the Iterator
std
::
shared_ptr
<
Iterator
>
CreateIterator
();
std
::
shared_ptr
<
Iterator
>
CreateIterator
(
std
::
vector
<
std
::
string
>
columns
=
{}
);
/// \brief Function to create a BatchDataset
/// \notes Combines batch_size number of consecutive rows into batches
...
...
@@ -452,6 +453,12 @@ class VOCDataset : public Dataset {
bool
ValidateParams
()
override
;
private:
const
std
::
string
kColumnImage
=
"image"
;
const
std
::
string
kColumnTarget
=
"target"
;
const
std
::
string
kColumnBbox
=
"bbox"
;
const
std
::
string
kColumnLabel
=
"label"
;
const
std
::
string
kColumnDifficult
=
"difficult"
;
const
std
::
string
kColumnTruncate
=
"truncate"
;
std
::
string
dataset_dir_
;
std
::
string
task_
;
std
::
string
mode_
;
...
...
mindspore/ccsrc/minddata/dataset/include/iterator.h
浏览文件 @
e6a12946
...
...
@@ -37,6 +37,7 @@ namespace api {
class
Dataset
;
using
TensorMap
=
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Tensor
>>
;
using
TensorVec
=
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
;
// Abstract class for iterating over the dataset.
class
Iterator
{
...
...
@@ -53,9 +54,15 @@ class Iterator {
Status
BuildAndLaunchTree
(
std
::
shared_ptr
<
Dataset
>
ds
);
/// \brief Function to get the next row from the data pipeline.
/// \note Type of return data is a map(with column name).
/// \param[out] row - the output tensor row.
void
GetNextRow
(
TensorMap
*
row
);
/// \brief Function to get the next row from the data pipeline.
/// \note Type of return data is a vector(without column name).
/// \param[out] row - the output tensor row.
void
GetNextRow
(
TensorVec
*
row
);
/// \brief Function to shut down the data pipeline.
void
Stop
();
...
...
mindspore/ccsrc/minddata/dataset/include/transforms.h
浏览文件 @
e6a12946
...
...
@@ -148,8 +148,8 @@ std::shared_ptr<RandomColorAdjustOperation> RandomColorAdjust(std::vector<float>
/// fill R, G, B channels respectively.
/// \return Shared pointer to the current TensorOperation.
std
::
shared_ptr
<
RandomCropOperation
>
RandomCrop
(
std
::
vector
<
int32_t
>
size
,
std
::
vector
<
int32_t
>
padding
=
{
0
,
0
,
0
,
0
},
bool
pad_if_needed
=
false
,
std
::
vector
<
uint8_t
>
fill_value
=
{
0
,
0
,
0
}
);
bool
pad_if_needed
=
false
,
std
::
vector
<
uint8_t
>
fill_value
=
{
0
,
0
,
0
},
BorderType
padding_mode
=
BorderType
::
kConstant
);
/// \brief Function to create a RandomHorizontalFlip TensorOperation.
/// \notes Tensor operation to perform random horizontal flip.
...
...
@@ -311,7 +311,8 @@ class RandomColorAdjustOperation : public TensorOperation {
class
RandomCropOperation
:
public
TensorOperation
{
public:
RandomCropOperation
(
std
::
vector
<
int32_t
>
size
,
std
::
vector
<
int32_t
>
padding
=
{
0
,
0
,
0
,
0
},
bool
pad_if_needed
=
false
,
std
::
vector
<
uint8_t
>
fill_value
=
{
0
,
0
,
0
});
bool
pad_if_needed
=
false
,
std
::
vector
<
uint8_t
>
fill_value
=
{
0
,
0
,
0
},
BorderType
padding_mode
=
BorderType
::
kConstant
);
~
RandomCropOperation
()
=
default
;
...
...
@@ -324,6 +325,7 @@ class RandomCropOperation : public TensorOperation {
std
::
vector
<
int32_t
>
padding_
;
bool
pad_if_needed_
;
std
::
vector
<
uint8_t
>
fill_value_
;
BorderType
padding_mode_
;
};
class
RandomHorizontalFlipOperation
:
public
TensorOperation
{
...
...
tests/ut/cpp/dataset/CMakeLists.txt
浏览文件 @
e6a12946
...
...
@@ -95,6 +95,7 @@ SET(DE_UT_SRCS
c_api_dataset_coco_test.cc
c_api_dataset_voc_test.cc
c_api_datasets_test.cc
c_api_dataset_iterator_test.cc
tensor_op_fusion_pass_test.cc
sliding_window_op_test.cc
epoch_ctrl_op_test.cc
...
...
tests/ut/cpp/dataset/c_api_dataset_iterator_test.cc
0 → 100644
浏览文件 @
e6a12946
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <fstream>
#include <iostream>
#include <memory>
#include <vector>
#include <string>
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "securec.h"
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/include/status.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/include/iterator.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/include/samplers.h"
using
namespace
mindspore
::
dataset
::
api
;
using
mindspore
::
MsLogLevel
::
ERROR
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
LogStream
;
using
mindspore
::
dataset
::
Tensor
;
using
mindspore
::
dataset
::
TensorShape
;
using
mindspore
::
dataset
::
TensorImpl
;
using
mindspore
::
dataset
::
DataType
;
using
mindspore
::
dataset
::
Status
;
using
mindspore
::
dataset
::
BorderType
;
using
mindspore
::
dataset
::
dsize_t
;
class
MindDataTestPipeline
:
public
UT
::
DatasetOpTesting
{
protected:
};
TEST_F
(
MindDataTestPipeline
,
TestIteratorOneColumn
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestIteratorOneColumn."
;
// Create a Mnist Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testMnistData/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
Mnist
(
folder_path
,
RandomSampler
(
false
,
4
));
EXPECT_NE
(
ds
,
nullptr
);
// Create a Batch operation on ds
int32_t
batch_size
=
2
;
ds
=
ds
->
Batch
(
batch_size
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
// Only select "image" column and drop others
std
::
vector
<
std
::
string
>
columns
=
{
"image"
};
std
::
shared_ptr
<
Iterator
>
iter
=
ds
->
CreateIterator
(
columns
);
EXPECT_NE
(
iter
,
nullptr
);
// Iterate the dataset and get each row
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
row
;
iter
->
GetNextRow
(
&
row
);
TensorShape
expect
({
2
,
28
,
28
,
1
});
uint64_t
i
=
0
;
while
(
row
.
size
()
!=
0
)
{
for
(
auto
&
v
:
row
)
{
MS_LOG
(
INFO
)
<<
"image shape:"
<<
v
->
shape
();
EXPECT_EQ
(
expect
,
v
->
shape
());
}
iter
->
GetNextRow
(
&
row
);
i
++
;
}
EXPECT_EQ
(
i
,
2
);
// Manually terminate the pipeline
iter
->
Stop
();
}
TEST_F
(
MindDataTestPipeline
,
TestIteratorTwoColumns
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestIteratorTwoColumns."
;
// Create a VOC Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testVOC2012_2"
;
std
::
shared_ptr
<
Dataset
>
ds
=
VOC
(
folder_path
,
"Detection"
,
"train"
,
{},
false
,
SequentialSampler
(
0
,
4
));
EXPECT_NE
(
ds
,
nullptr
);
// Create a Repeat operation on ds
int32_t
repeat_num
=
2
;
ds
=
ds
->
Repeat
(
repeat_num
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
// Only select "image" and "bbox" column
std
::
vector
<
std
::
string
>
columns
=
{
"image"
,
"bbox"
};
std
::
shared_ptr
<
Iterator
>
iter
=
ds
->
CreateIterator
(
columns
);
EXPECT_NE
(
iter
,
nullptr
);
// Iterate the dataset and get each row
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
row
;
iter
->
GetNextRow
(
&
row
);
std
::
vector
<
TensorShape
>
expect
=
{
TensorShape
({
173673
}),
TensorShape
({
1
,
4
}),
TensorShape
({
173673
}),
TensorShape
({
1
,
4
}),
TensorShape
({
147025
}),
TensorShape
({
1
,
4
}),
TensorShape
({
211653
}),
TensorShape
({
1
,
4
})};
uint64_t
i
=
0
;
uint64_t
j
=
0
;
while
(
row
.
size
()
!=
0
)
{
MS_LOG
(
INFO
)
<<
"row[0]:"
<<
row
[
0
]
->
shape
()
<<
", row[1]:"
<<
row
[
1
]
->
shape
();
EXPECT_EQ
(
2
,
row
.
size
());
EXPECT_EQ
(
expect
[
j
++
],
row
[
0
]
->
shape
());
EXPECT_EQ
(
expect
[
j
++
],
row
[
1
]
->
shape
());
iter
->
GetNextRow
(
&
row
);
i
++
;
j
=
(
j
==
expect
.
size
())
?
0
:
j
;
}
EXPECT_EQ
(
i
,
8
);
// Manually terminate the pipeline
iter
->
Stop
();
}
TEST_F
(
MindDataTestPipeline
,
TestIteratorEmptyColumn
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestIteratorEmptyColumn."
;
// Create a Cifar10 Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testCifar10Data/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
Cifar10
(
folder_path
,
RandomSampler
(
false
,
5
));
EXPECT_NE
(
ds
,
nullptr
);
// Create a Rename operation on ds
ds
=
ds
->
Rename
({
"image"
,
"label"
},
{
"col1"
,
"col2"
});
EXPECT_NE
(
ds
,
nullptr
);
// No columns are specified, use all columns
std
::
shared_ptr
<
Iterator
>
iter
=
ds
->
CreateIterator
();
EXPECT_NE
(
iter
,
nullptr
);
// Iterate the dataset and get each row
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
row
;
iter
->
GetNextRow
(
&
row
);
TensorShape
expect0
({
32
,
32
,
3
});
TensorShape
expect1
({});
uint64_t
i
=
0
;
while
(
row
.
size
()
!=
0
)
{
MS_LOG
(
INFO
)
<<
"row[0]:"
<<
row
[
0
]
->
shape
()
<<
", row[1]:"
<<
row
[
1
]
->
shape
();
EXPECT_EQ
(
expect0
,
row
[
0
]
->
shape
());
EXPECT_EQ
(
expect1
,
row
[
1
]
->
shape
());
iter
->
GetNextRow
(
&
row
);
i
++
;
}
EXPECT_EQ
(
i
,
5
);
// Manually terminate the pipeline
iter
->
Stop
();
}
TEST_F
(
MindDataTestPipeline
,
TestIteratorReOrder
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestIteratorReOrder."
;
// Create a Cifar10 Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testCifar10Data/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
Cifar10
(
folder_path
,
SequentialSampler
(
false
,
4
));
EXPECT_NE
(
ds
,
nullptr
);
// Create a Take operation on ds
ds
=
ds
->
Take
(
2
);
EXPECT_NE
(
ds
,
nullptr
);
// Create an iterator over the result of the above dataset
// Reorder "image" and "label" column
std
::
vector
<
std
::
string
>
columns
=
{
"label"
,
"image"
};
std
::
shared_ptr
<
Iterator
>
iter
=
ds
->
CreateIterator
(
columns
);
EXPECT_NE
(
iter
,
nullptr
);
// Iterate the dataset and get each row
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
row
;
iter
->
GetNextRow
(
&
row
);
TensorShape
expect0
({
32
,
32
,
3
});
TensorShape
expect1
({});
// Check if we will catch "label" before "image" in row
std
::
vector
<
std
::
string
>
expect
=
{
"label"
,
"image"
};
uint64_t
i
=
0
;
while
(
row
.
size
()
!=
0
)
{
MS_LOG
(
INFO
)
<<
"row[0]:"
<<
row
[
0
]
->
shape
()
<<
", row[1]:"
<<
row
[
1
]
->
shape
();
EXPECT_EQ
(
expect1
,
row
[
0
]
->
shape
());
EXPECT_EQ
(
expect0
,
row
[
1
]
->
shape
());
iter
->
GetNextRow
(
&
row
);
i
++
;
}
EXPECT_EQ
(
i
,
2
);
// Manually terminate the pipeline
iter
->
Stop
();
}
TEST_F
(
MindDataTestPipeline
,
TestIteratorWrongColumn
)
{
MS_LOG
(
INFO
)
<<
"Doing MindDataTestPipeline-TestIteratorOneColumn."
;
// Create a Mnist Dataset
std
::
string
folder_path
=
datasets_root_path_
+
"/testMnistData/"
;
std
::
shared_ptr
<
Dataset
>
ds
=
Mnist
(
folder_path
,
RandomSampler
(
false
,
4
));
EXPECT_NE
(
ds
,
nullptr
);
// Pass wrong column name
std
::
vector
<
std
::
string
>
columns
=
{
"digital"
};
std
::
shared_ptr
<
Iterator
>
iter
=
ds
->
CreateIterator
(
columns
);
EXPECT_EQ
(
iter
,
nullptr
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录