Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
3ad73b7d
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3ad73b7d
编写于
4月 19, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 19, 2020
浏览文件
操作
浏览文件
下载
差异文件
!463 !330 dataset: add take operation
Merge pull request !463 from ms_yan/take_op_merge
上级
663ae4d7
f0c07c3f
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
779 addition
and
5 deletion
+779
-5
mindspore/ccsrc/dataset/api/de_pipeline.cc
mindspore/ccsrc/dataset/api/de_pipeline.cc
+11
-1
mindspore/ccsrc/dataset/api/de_pipeline.h
mindspore/ccsrc/dataset/api/de_pipeline.h
+1
-1
mindspore/ccsrc/dataset/core/client.h
mindspore/ccsrc/dataset/core/client.h
+1
-0
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
+4
-0
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
+146
-0
mindspore/ccsrc/dataset/engine/datasetops/take_op.h
mindspore/ccsrc/dataset/engine/datasetops/take_op.h
+107
-0
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+64
-1
mindspore/dataset/engine/iterators.py
mindspore/dataset/engine/iterators.py
+2
-0
mindspore/dataset/engine/serializer_deserializer.py
mindspore/dataset/engine/serializer_deserializer.py
+3
-0
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+18
-1
tests/ut/cpp/dataset/CMakeLists.txt
tests/ut/cpp/dataset/CMakeLists.txt
+1
-0
tests/ut/cpp/dataset/take_op_test.cc
tests/ut/cpp/dataset/take_op_test.cc
+103
-0
tests/ut/python/dataset/test_take.py
tests/ut/python/dataset/test_take.py
+317
-0
未找到文件。
mindspore/ccsrc/dataset/api/de_pipeline.cc
浏览文件 @
3ad73b7d
...
...
@@ -54,6 +54,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{
kGenerator
,
&
DEPipeline
::
ParseGeneratorOp
},
{
kTfReader
,
&
DEPipeline
::
ParseTFReaderOp
},
{
kProject
,
&
DEPipeline
::
ParseProjectOp
},
{
kTake
,
&
DEPipeline
::
ParseTakeOp
},
{
kImageFolder
,
&
DEPipeline
::
ParseImageFolderOp
},
{
kMnist
,
&
DEPipeline
::
ParseMnistOp
},
{
kManifest
,
&
DEPipeline
::
ParseManifestOp
},
...
...
@@ -650,7 +651,16 @@ Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp
return
Status
::
OK
();
}
DsOpPtr
DEPipeline
::
ParseTakeOp
(
const
py
::
dict
&
args
)
const
{
return
DsOpPtr
();
}
// Creates a TakeOp from the argument dictionary handed over by the Python layer.
// @param args - operator arguments; only the "count" entry is read
// @param ptr - output pointer that receives the newly built operator
// @return Status - an error when "count" is None or when the builder fails, otherwise OK
Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  if (args["count"].is_none()) {
    RETURN_STATUS_UNEXPECTED("Error: count is invalid or not set.");
  }

  // The builder validates the count before the operator is created.
  TakeOp::Builder take_builder(ToInt(args["count"]));
  std::shared_ptr<TakeOp> take_op;
  RETURN_IF_NOT_OK(take_builder.Build(&take_op));

  *ptr = take_op;
  return Status::OK();
}
Status
DEPipeline
::
ParseZipOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
)
{
std
::
shared_ptr
<
ZipOp
::
Builder
>
builder
=
std
::
make_shared
<
ZipOp
::
Builder
>
();
...
...
mindspore/ccsrc/dataset/api/de_pipeline.h
浏览文件 @
3ad73b7d
...
...
@@ -116,7 +116,7 @@ class DEPipeline {
Status
ParseRenameOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
DsOpPtr
ParseTakeOp
(
const
py
::
dict
&
args
)
const
;
Status
ParseTakeOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
)
;
Status
ParseZipOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
...
...
mindspore/ccsrc/dataset/core/client.h
浏览文件 @
3ad73b7d
...
...
@@ -38,6 +38,7 @@
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/storage_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/datasetops/zip_op.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/util/status.h"
...
...
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
浏览文件 @
3ad73b7d
...
...
@@ -5,13 +5,13 @@ add_library(engine-datasetops OBJECT
parallel_op.cc
pipeline_op.cc
batch_op.cc
batch_op.cc
device_queue_op.cc
map_op.cc
project_op.cc
rename_op.cc
repeat_op.cc
skip_op.cc
take_op.cc
shuffle_op.cc
zip_op.cc
)
...
...
mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
浏览文件 @
3ad73b7d
...
...
@@ -88,6 +88,10 @@ Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t work
// If buffer is none or the rows of buffer is 0,
// then get a buffer from child.
if
(
!
buf
||
buf
->
NumRows
()
==
0
)
{
if
(
buf
&&
buf
->
eof
())
{
*
p_buffer
=
std
::
move
(
buf
);
return
Status
::
OK
();
}
RETURN_IF_NOT_OK
(
child_
[
0
]
->
GetNextBuffer
(
&
buf
,
worker_id
,
true
));
}
...
...
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
0 → 100644
浏览文件 @
3ad73b7d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <utility>
#include "common/utils.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
namespace
mindspore
{
namespace
dataset
{
// Builder constructor. Creates the builder object.
TakeOp
::
Builder
::
Builder
(
int32_t
count
)
:
build_max_takes_
(
count
)
{}
// Validates the builder settings.
// @return Status - OK when the stored count is positive, otherwise an unexpected-error status
Status TakeOp::Builder::SanityCheck() const {
  if (build_max_takes_ > 0) {
    return Status::OK();
  }
  std::string err_msg("Take count must be greater than 0.");
  RETURN_STATUS_UNEXPECTED(err_msg);
}
// The builder "build" method: runs the sanity check and then creates the final TakeOp.
// @param ptr - output pointer that receives the new TakeOp
// @return Status - the sanity check error if any, otherwise OK
Status TakeOp::Builder::Build(std::shared_ptr<TakeOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  std::shared_ptr<TakeOp> take_op = std::make_shared<TakeOp>(build_max_takes_);
  *ptr = std::move(take_op);
  return Status::OK();
}
// Constructor of the TakeOp.
TakeOp
::
TakeOp
(
int32_t
count
)
:
PipelineOp
(
0
),
max_takes_
(
count
),
take_count_
(
0
)
{}
// A print method typically used for debugging
void
TakeOp
::
Print
(
std
::
ostream
&
out
,
bool
show_all
)
const
{
// Call base class printer first
PipelineOp
::
Print
(
out
,
show_all
);
// Then display our own stuff
out
<<
"TakeOp:"
<<
"
\n
Current take count: "
<<
take_count_
<<
"
\n
Max take count: "
<<
max_takes_
;
}
// Called repeatedly by the consumer to obtain the next buffer.
// While fewer than max_takes_ rows have been handed out, buffers fetched from the first child are
// passed on (trimmed by FillBuffer when needed). Once the limit is reached an EOE buffer is emitted
// while the op is in the running state, and an EOF buffer on the call after that.
// @param p_buffer - output pointer to the buffer that is returned
// @param worker_id - The worker id, forwarded to the child
// @param retry_if_eoe - not consulted here; the child is always asked with `true`
// @return Status - The error code return
Status TakeOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  if (child_.empty()) {
    RETURN_STATUS_UNEXPECTED("TakeOp can't be the leaf node.");
  }

  std::unique_ptr<DataBuffer> child_buf;

  // Either this op is not under a repeat at all, or it is flagged as being in the last repetition.
  const bool is_repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
  const bool last_repeat = !is_repeated || BitTest(op_ctrl_flags_, kDeOpLastRepeat);

  if (take_count_ == max_takes_) {
    if (state_ != OpState::kDeOpRunning) {
      // The EOE was already sent on an earlier call, so finish with an EOF.
      MS_LOG(INFO) << "meet max count and push-back eof buffer.";
      *p_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
      take_count_ = 0;
      return Status::OK();
    }

    MS_LOG(INFO) << "meet max count and push-back eoe buffer.";
    *p_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
    state_ = OpState::kDeOpIdle;

    // Reset the count and drain whatever the child still produces, up to its EOE/EOF.
    if (!last_repeat) {
      take_count_ = 0;
      do {
        RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_buf, worker_id, true));
      } while (!child_buf->eoe() && !child_buf->eof());
    }
    return Status::OK();
  }

  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&child_buf, worker_id, true));

  // An EOE from the child resets the counter; EOE and EOF buffers are both passed through as-is.
  const bool child_sent_eoe = child_buf->eoe();
  if (child_sent_eoe) {
    take_count_ = 0;
  }
  if (child_sent_eoe || child_buf->eof()) {
    *p_buffer = std::move(child_buf);
    return Status::OK();
  }

  // A data buffer: hand it on while the take limit has not been reached yet.
  if (take_count_ < max_takes_) {
    RETURN_IF_NOT_OK(FillBuffer(&child_buf, p_buffer));
  }
  return Status::OK();
}
// Prepares the buffer that is returned to the consumer.
// If the whole input buffer fits below the take limit it is forwarded untouched; otherwise only as
// many rows as are still allowed are popped from it and forwarded.
// @param buffer - the data buffer fetched from the child; it is moved from in both cases
// @param data_buffer - output pointer that receives the buffer to return
// @return Status - the error from PopRow if any, otherwise OK
Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer) {
  int32_t buffer_size = (*buffer)->NumRows();

  // Compare against the remaining budget instead of computing take_count_ + buffer_size:
  // the sum could overflow a signed 32-bit integer when max_takes_ is close to INT32_MAX.
  const int32_t remaining = max_takes_ - take_count_;
  if (buffer_size < remaining) {
    take_count_ += buffer_size;
    *data_buffer = std::move(*buffer);
  } else {
    MS_LOG(INFO) << "In last buffer: Push one buffer.";
    std::unique_ptr<TensorQTable> new_tensor_table = std::make_unique<TensorQTable>();
    while (take_count_ < max_takes_) {
      TensorRow new_row;
      RETURN_IF_NOT_OK((*buffer)->PopRow(&new_row));
      take_count_++;
      // The row is not used again in this iteration, so move it instead of copying.
      new_tensor_table->push_back(std::move(new_row));
    }
    // Replace the buffer's table with the trimmed one before handing the buffer on.
    (*buffer)->set_tensor_table(std::move(new_tensor_table));
    *data_buffer = std::move(*buffer);
  }
  return Status::OK();
}
// Class functor operator () override.
// Most dataset ops operate by launching a thread (see ExecutionTree).
// The TakeOp however is an inlined operator that runs inside another operator, so launching
// this functor is never valid. The override exists only to report that mistake as an error.
// @return Status - always an unexpected-error status
Status TakeOp::operator()() {
  const std::string err_msg("Logic error. TakeOp is an inlined operator.");
  RETURN_STATUS_UNEXPECTED(err_msg);
}
// Post action of the tree prepare phase.
// Runs the PipelineOp version first and then hands this operator to the tree's AddToRepeatStack.
// @return Status - the base class error if any, otherwise OK
Status TakeOp::PrepareNodePostAction() {
  RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
  auto self = shared_from_this();
  tree_->AddToRepeatStack(self);
  return Status::OK();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/take_op.h
0 → 100644
浏览文件 @
3ad73b7d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
#define DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "dataset/engine/datasetops/pipeline_op.h"
namespace
mindspore
{
namespace
dataset
{
class TakeOp : public PipelineOp {
 public:
  // Nested builder used to collect the arguments for constructing a TakeOp.
  // The take op needs a single argument only, so the builder mostly exists to keep the
  // construction style consistent with the other Dataset operators.
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @param count - The number of takes to do
    explicit Builder(int32_t count);

    // Default destructor
    ~Builder() = default;

    // The builder "build" method creates the final TakeOp object.
    // The new object is handed back through the output pointer argument.
    // @return Status - The error code return
    Status Build(std::shared_ptr<TakeOp> *);

   private:
    int32_t build_max_takes_;

    Status SanityCheck() const;
  };

  // Constructor of the TakeOp.
  // @note The builder class should be used to call it
  // @param count - The number of takes to do
  explicit TakeOp(int32_t count);

  // Destructor
  ~TakeOp() = default;

  // A print method typically used for debugging
  // @param out - The output stream to write output to
  // @param show_all - A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override;

  // << Stream output operator overload
  // @notes This allows you to write the debug print info using stream operators
  // @param out - reference to the output stream being overloaded
  // @param take_op - reference to the TakeOp to display
  // @return - the output stream must be returned
  friend std::ostream &operator<<(std::ostream &out, const TakeOp &take_op) {
    take_op.Print(out, false);
    return out;
  }

  // Class functor operator () override.
  // Most dataset ops operate by launching a thread (see ExecutionTree).
  // The TakeOp is an inlined operator that runs inside another operator, so launching the
  // functor is invalid. It is overridden so that a mistaken call generates an error.
  // @return Status - The error code return
  Status operator()() override;

  // Gets a buffer from the child node. The caller is typically our parent node.
  // @param p_buffer - output pointer to the buffer that it will fetch.
  // @param worker_id - The worker id
  // @param retry_if_eoe Set this flag to true to allow calling pop() again after the first pop() returns EOE.
  // @return Status - The error code return
  Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;

  // During tree prepare phase, operators may have specific post-operations to perform depending on
  // their role.
  // @notes Derived versions of this function should always call their superclass version first
  // before providing their own implementations.
  Status PrepareNodePostAction() override;

 private:
  int32_t max_takes_;   // The number of takes that the user requested
  int32_t take_count_;  // A counter for the current number of executed takes

  // Fills `data_buffer` from `buffer`; used by GetNextBuffer for data buffers.
  Status FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<DataBuffer> *data_buffer);
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_TAKE_OP_H_
mindspore/dataset/engine/datasets.py
浏览文件 @
3ad73b7d
...
...
@@ -36,7 +36,7 @@ from mindspore import log as logger
from
.
import
samplers
from
.iterators
import
DictIterator
,
TupleIterator
from
.validators
import
check
,
check_batch
,
check_shuffle
,
check_map
,
check_repeat
,
check_skip
,
check_zip
,
check_rename
,
\
check_project
,
check_imagefolderdatasetv2
,
check_mnist_cifar_dataset
,
check_manifestdataset
,
\
check_
take
,
check_
project
,
check_imagefolderdatasetv2
,
check_mnist_cifar_dataset
,
check_manifestdataset
,
\
check_tfrecorddataset
,
check_vocdataset
,
check_celebadataset
,
check_minddataset
,
check_generatordataset
,
\
check_zip_dataset
,
check_add_column
from
..core.datatypes
import
mstype_to_detype
,
mstypelist_to_detypelist
...
...
@@ -442,6 +442,33 @@ class Dataset:
"""
return
SkipDataset
(
self
,
count
)
@check_take
def take(self, count=-1):
    """
    Take at most the given number of elements from the dataset.

    Note:
        1. If count is greater than the number of elements in the dataset or equal to -1,
           all the elements in the dataset will be taken.
        2. The order of take and batch matters. If take is applied before the batch operation,
           the given number of rows is taken, otherwise the given number of batches is taken.

    Args:
        count (int, optional): Number of elements to be taken from the dataset (default=-1).

    Returns:
        TakeDataset, dataset taken. When count is -1 the dataset itself is returned.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # data is an instance of Dataset object.
        >>> # creates a dataset which includes at most 50 elements.
        >>> data = data.take(50)
    """
    take_everything = count == -1
    return self if take_everything else TakeDataset(self, count)
@
check_zip_dataset
def
zip
(
self
,
datasets
):
"""
...
...
@@ -1100,6 +1127,7 @@ class RepeatDataset(DatasetOp):
"""
return
self
.
count
class
SkipDataset
(
DatasetOp
):
"""
The result of applying Skip operator to the input Dataset.
...
...
@@ -1134,6 +1162,41 @@ class SkipDataset(DatasetOp):
output_size
=
child_size
-
self
.
count
return
output_size
class TakeDataset(DatasetOp):
    """
    Dataset node produced by applying the Take operator to an input Dataset.

    Args:
        input_dataset (Dataset): Input Dataset to take elements from.
        count (int): Number of elements to be taken from the dataset.
    """

    def __init__(self, input_dataset, count):
        super().__init__()
        self.count = count
        # Link this node and its input in both directions.
        self.input.append(input_dataset)
        input_dataset.output.append(self)
        self._input_indexs = input_dataset.input_indexs

    def get_args(self):
        """Return the base class arguments extended with the take count."""
        args = super().get_args()
        args["count"] = self.count
        return args

    def get_dataset_size(self):
        """
        Get the number of batches in an epoch.

        Return:
            Number, the size of the input dataset when it is smaller than count, otherwise count.
        """
        child_size = self.input[0].get_dataset_size()
        # min() yields its first argument on ties, i.e. count unless the child is strictly smaller.
        return min(self.count, child_size)
class
ZipDataset
(
DatasetOp
):
"""
The result of applying Zip operator to the input Dataset.
...
...
mindspore/dataset/engine/iterators.py
浏览文件 @
3ad73b7d
...
...
@@ -129,6 +129,8 @@ class Iterator:
op_type
=
OpName
.
REPEAT
elif
isinstance
(
dataset
,
de
.
SkipDataset
):
op_type
=
OpName
.
SKIP
elif
isinstance
(
dataset
,
de
.
TakeDataset
):
op_type
=
OpName
.
TAKE
elif
isinstance
(
dataset
,
de
.
StorageDataset
):
op_type
=
OpName
.
STORAGE
elif
isinstance
(
dataset
,
de
.
ImageFolderDatasetV2
):
...
...
mindspore/dataset/engine/serializer_deserializer.py
浏览文件 @
3ad73b7d
...
...
@@ -304,6 +304,9 @@ def create_node(node):
elif
dataset_op
==
'SkipDataset'
:
pyobj
=
de
.
Dataset
().
skip
(
node
.
get
(
'count'
))
elif
dataset_op
==
'TakeDataset'
:
pyobj
=
de
.
Dataset
().
take
(
node
.
get
(
'count'
))
elif
dataset_op
==
'MapDataset'
:
tensor_ops
=
construct_tensor_ops
(
node
.
get
(
'operations'
))
pyobj
=
de
.
Dataset
().
map
(
node
.
get
(
'input_columns'
),
tensor_ops
,
node
.
get
(
'output_columns'
),
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
3ad73b7d
...
...
@@ -602,7 +602,7 @@ def check_batch_size(batch_size):
def check_count(count):
    """
    Validate a count argument.

    The type is checked with check_type first; afterwards the value must be either -1 or a
    positive integer that does not exceed INT32_MAX.

    Args:
        count (int): The value to validate.

    Raises:
        ValueError: If count is 0, below -1, or greater than INT32_MAX.
    """
    check_type(count, 'count', int)
    if (count <= 0 and count != -1) or count > INT32_MAX:
        # One generic message only: this check is not specific to repeat, and a second
        # raise directly after the first one could never be reached.
        raise ValueError("count should be either -1 or positive integer.")
def
check_columns
(
columns
,
name
):
...
...
@@ -709,6 +709,7 @@ def check_repeat(method):
return
new_method
def
check_skip
(
method
):
"""check the input arguments of skip."""
@
wraps
(
method
)
...
...
@@ -724,6 +725,21 @@ def check_skip(method):
return
new_method
def check_take(method):
    """Decorator that validates the `count` argument of take before calling it."""

    @wraps(method)
    def new_method(*args, **kwargs):
        # Resolve the call arguments by name and validate the count among them.
        check_count(make_param_dict(method, args, kwargs).get('count'))
        return method(*args, **kwargs)

    return new_method
def
check_zip
(
method
):
"""check the input arguments of zip."""
@
wraps
(
method
)
...
...
@@ -759,6 +775,7 @@ def check_zip_dataset(method):
return
new_method
def
check_rename
(
method
):
"""check the input arguments of rename."""
@
wraps
(
method
)
...
...
tests/ut/cpp/dataset/CMakeLists.txt
浏览文件 @
3ad73b7d
...
...
@@ -64,6 +64,7 @@ SET(DE_UT_SRCS
voc_op_test.cc
cifar_op_test.cc
celeba_op_test.cc
take_op_test.cc
)
add_executable
(
de_ut_tests
${
DE_UT_SRCS
}
)
...
...
tests/ut/cpp/dataset/take_op_test.cc
0 → 100644
浏览文件 @
3ad73b7d
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include <vector>
#include "common/common.h"
#include "common/utils.h"
#include "dataset/core/client.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
namespace
common
=
mindspore
::
common
;
using
namespace
mindspore
::
dataset
;
using
mindspore
::
MsLogLevel
::
INFO
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
LogStream
;
// Test fixture for the TakeOp tests; it adds nothing on top of UT::DatasetOpTesting.
class MindDataTestTakeOp : public UT::DatasetOpTesting {
};
// Builds a two node tree (TFReaderOp -> TakeOp with a count of 5), runs it and
// checks that exactly 5 rows come out of the pipeline.
TEST_F(MindDataTestTakeOp, TestTakeProject) {
  // Start with an empty execution tree
  auto my_tree = std::make_shared<ExecutionTree>();

  std::string dataset_path;
  dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";

  // TFReaderOp
  std::shared_ptr<TFReaderOp> my_tfreader_op;
  TFReaderOp::Builder builder;
  builder.SetDatasetFilesList({dataset_path})
    .SetRowsPerBuffer(16)
    .SetWorkerConnectorSize(16)
    .SetNumWorkers(16);
  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  schema->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {});
  builder.SetDataSchema(std::move(schema));
  Status rc = builder.Build(&my_tfreader_op);
  ASSERT_TRUE(rc.IsOk());

  // TakeOp
  std::shared_ptr<TakeOp> my_take_op;
  TakeOp::Builder builder_take(5);
  rc = builder_take.Build(&my_take_op);
  ASSERT_TRUE(rc.IsOk());

  rc = my_tree->AssociateNode(my_tfreader_op);
  ASSERT_TRUE(rc.IsOk());
  rc = my_tree->AssociateNode(my_take_op);
  ASSERT_TRUE(rc.IsOk());

  // Set children/root layout.
  rc = my_take_op->AddChild(my_tfreader_op);
  ASSERT_TRUE(rc.IsOk());
  rc = my_tree->AssignRoot(my_take_op);
  ASSERT_TRUE(rc.IsOk());

  MS_LOG(INFO) << "Launching tree and begin iteration.";
  rc = my_tree->Prepare();
  ASSERT_TRUE(rc.IsOk());

  rc = my_tree->Launch();
  ASSERT_TRUE(rc.IsOk());

  // Start the loop of reading tensors from our pipeline
  DatasetIterator di(my_tree);
  TensorRow tensor_list;
  rc = di.FetchNextTensorRow(&tensor_list);
  ASSERT_TRUE(rc.IsOk());

  int row_count = 0;
  // An empty row marks the end of the data.
  while (!tensor_list.empty()) {
    MS_LOG(INFO) << "Row display for row #: " << row_count << ".";

    // Display the tensor by calling the printer on it.
    // size_t index: avoids comparing a signed int against the unsigned size().
    for (size_t i = 0; i < tensor_list.size(); i++) {
      std::ostringstream ss;
      ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
      MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
    }

    rc = di.FetchNextTensorRow(&tensor_list);
    ASSERT_TRUE(rc.IsOk());
    row_count++;
  }

  ASSERT_EQ(row_count, 5);
}
tests/ut/python/dataset/test_take.py
0 → 100644
浏览文件 @
3ad73b7d
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
from
mindspore
import
log
as
logger
import
numpy
as
np
# Source for the generator dataset: yields 3 rows holding the values 0, 1, 2.
# Every row is a one-element tuple containing a one-element numpy array.
def generator():
    for value in range(3):
        yield (np.array([value]),)
# Source for the generator dataset: yields 10 rows holding the values 0, 1, 2 ... 9.
# Every row is a one-element tuple containing a one-element numpy array.
def generator_10():
    for value in range(10):
        yield (np.array([value]),)
def test_take_01():
    """
    Test take: the source has 3 rows and 1 row is taken; in this case eoe and eof will not be met
    """
    logger.info("test_take_01")
    dataset = ds.GeneratorDataset(generator, ["data"]).take(1).repeat(2)

    # Every element that comes out must be the first generated value
    for _, row in enumerate(dataset):
        assert 0 == row[0][0]

    assert sum(1 for _ in dataset) == 2
def test_take_02():
    """
    Test take: the source has 3 rows and 2 rows are taken; in this case eoe will be met
    """
    logger.info("test_take_02")
    dataset = ds.GeneratorDataset(generator, ["data"]).take(2).repeat(2)

    # idx is the position in the output, row is the data element
    for idx, row in enumerate(dataset):
        assert idx % 2 == row[0][0]

    assert sum(1 for _ in dataset) == 4
def test_take_03():
    """
    Test take: the source has 3 rows and 3 rows are taken; in this case eoe and eof will be met
    """
    logger.info("test_take_03")
    dataset = ds.GeneratorDataset(generator, ["data"]).take(3).repeat(2)

    # idx is the position in the output, row is the data element
    for idx, row in enumerate(dataset):
        assert idx % 3 == row[0][0]

    assert sum(1 for _ in dataset) == 6
def test_take_04():
    """
    Test take: the source has 3 rows and 4 rows are requested, which is more than the total rows
    """
    logger.info("test_take_04")
    dataset = ds.GeneratorDataset(generator, ["data"]).take(4).repeat(2)

    # idx is the position in the output, row is the data element
    for idx, row in enumerate(dataset):
        assert idx % 3 == row[0][0]

    assert sum(1 for _ in dataset) == 6
def test_take_05():
    """
    Test take: no repeat op in the pipeline
    """
    logger.info("test_take_05")
    dataset = ds.GeneratorDataset(generator, ["data"]).take(2)

    # idx is the position in the output, row is the data element
    for idx, row in enumerate(dataset):
        assert idx == row[0][0]

    assert sum(1 for _ in dataset) == 2
def test_take_06():
    """
    Test take: repeat is applied before take
    """
    logger.info("test_take_06")
    dataset = ds.GeneratorDataset(generator, ["data"]).repeat(2).take(4)

    # idx is the position in the output, row is the data element
    for idx, row in enumerate(dataset):
        assert idx % 3 == row[0][0]

    assert sum(1 for _ in dataset) == 4
def test_take_07():
    """
    Test take: take is before batch, so for take(N) the N refers to the number of rows
    """
    logger.info("test_take_07")
    dataset = ds.GeneratorDataset(generator, ["data"]).take(2).batch(2)

    assert sum(1 for _ in dataset) == 1
def test_take_08():
    """
    Test take: take is after batch, so for take(N) the N refers to the number of batches
    """
    logger.info("test_take_08")
    dataset = ds.GeneratorDataset(generator, ["data"]).batch(2).take(2)

    assert sum(1 for _ in dataset) == 2
def test_take_09():
    """
    Test take: take count is -1, so the whole dataset is read; take is applied after repeat
    """
    logger.info("test_take_09")
    dataset = ds.GeneratorDataset(generator, ["data"]).repeat(2).take(-1)

    # idx is the position in the output, row is the data element
    for idx, row in enumerate(dataset):
        assert idx % 3 == row[0][0]

    assert sum(1 for _ in dataset) == 6
def test_take_10():
    """
    Test take: take count is -1, so the whole dataset is read; take is applied before repeat
    """
    logger.info("test_take_10")
    dataset = ds.GeneratorDataset(generator, ["data"]).take(-1).repeat(2)

    # idx is the position in the output, row is the data element
    for idx, row in enumerate(dataset):
        assert idx % 3 == row[0][0]

    assert sum(1 for _ in dataset) == 6
def test_take_11():
    """
    Test take: batch first, then apply the repeat and take operations
    """
    logger.info("test_take_11")
    dataset = ds.GeneratorDataset(generator, ["data"]).batch(2).repeat(2).take(-1)

    # idx is the position in the output, row is the data element
    for idx, row in enumerate(dataset):
        assert 2 * (idx % 2) == row[0][0]

    assert sum(1 for _ in dataset) == 4
def test_take_12():
    """
    Test take: take first, then apply the batch and repeat operations
    """
    logger.info("test_take_12")
    dataset = ds.GeneratorDataset(generator, ["data"]).take(2).batch(2).repeat(2)

    # Every batch must start with the first generated value
    for _, row in enumerate(dataset):
        assert 0 == row[0][0]

    assert sum(1 for _ in dataset) == 2
def test_take_13():
    """
    Test take: skip first, then apply the take, batch and repeat operations
    """
    logger.info("test_take_13")
    dataset = ds.GeneratorDataset(generator, ["data"]).skip(2).take(-1).batch(2).repeat(2)

    # Only the last generated value is left after the skip
    for _, row in enumerate(dataset):
        assert 2 == row[0][0]

    assert sum(1 for _ in dataset) == 2
def test_take_14():
    """
    Test take: take first, then apply the batch, skip and repeat operations
    """
    logger.info("test_take_14")
    dataset = ds.GeneratorDataset(generator, ["data"]).take(-1).batch(2).skip(1).repeat(2)

    # Only the second batch is left after the skip
    for _, row in enumerate(dataset):
        assert 2 == row[0][0]

    assert sum(1 for _ in dataset) == 2
def test_take_15():
    """
    Test take: larger amount of data, take a part of it, then apply the skip operation
    """
    logger.info("test_take_15")
    dataset = ds.GeneratorDataset(generator_10, ["data"]).take(6).skip(2)

    # idx is the position in the output, row is the data element
    for idx, row in enumerate(dataset):
        assert (idx + 2) == row[0][0]

    assert sum(1 for _ in dataset) == 4
def test_take_16():
    """
    Test take: larger amount of data, skip a part of it, then apply the take operation
    """
    logger.info("test_take_16")
    dataset = ds.GeneratorDataset(generator_10, ["data"]).skip(3).take(5)

    # idx is the position in the output, row is the data element
    for idx, row in enumerate(dataset):
        assert (idx + 3) == row[0][0]

    assert sum(1 for _ in dataset) == 5
if __name__ == '__main__':
    # Run every take test case in order when the file is executed as a script.
    for test_case in (test_take_01, test_take_02, test_take_03, test_take_04,
                      test_take_05, test_take_06, test_take_07, test_take_08,
                      test_take_09, test_take_10, test_take_11, test_take_12,
                      test_take_13, test_take_14, test_take_15, test_take_16):
        test_case()
    logger.info('== test take operation finished ==')
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录