Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
3f7054dc
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3f7054dc
编写于
4月 09, 2020
作者:
J
jzw
提交者:
jiangzhiwen
4月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add skip dataset op
上级
f69a668d
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
535 addition
and
1 deletion
+535
-1
mindspore/ccsrc/dataset/api/de_pipeline.cc
mindspore/ccsrc/dataset/api/de_pipeline.cc
+12
-0
mindspore/ccsrc/dataset/api/de_pipeline.h
mindspore/ccsrc/dataset/api/de_pipeline.h
+3
-0
mindspore/ccsrc/dataset/api/python_bindings.cc
mindspore/ccsrc/dataset/api/python_bindings.cc
+1
-0
mindspore/ccsrc/dataset/core/client.h
mindspore/ccsrc/dataset/core/client.h
+1
-0
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
+1
-0
mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
+128
-0
mindspore/ccsrc/dataset/engine/datasetops/skip_op.h
mindspore/ccsrc/dataset/engine/datasetops/skip_op.h
+95
-0
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+53
-1
mindspore/dataset/engine/iterators.py
mindspore/dataset/engine/iterators.py
+2
-0
mindspore/dataset/engine/serializer_deserializer.py
mindspore/dataset/engine/serializer_deserializer.py
+3
-0
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+14
-0
tests/ut/cpp/dataset/CMakeLists.txt
tests/ut/cpp/dataset/CMakeLists.txt
+1
-0
tests/ut/cpp/dataset/skip_op_test.cc
tests/ut/cpp/dataset/skip_op_test.cc
+91
-0
tests/ut/python/dataset/test_skip.py
tests/ut/python/dataset/test_skip.py
+130
-0
未找到文件。
mindspore/ccsrc/dataset/api/de_pipeline.cc
浏览文件 @
3f7054dc
...
...
@@ -47,6 +47,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{
kMap
,
&
DEPipeline
::
ParseMapOp
},
{
kBatch
,
&
DEPipeline
::
ParseBatchOp
},
{
kRepeat
,
&
DEPipeline
::
ParseRepeatOp
},
{
kSkip
,
&
DEPipeline
::
ParseSkipOp
},
{
kZip
,
&
DEPipeline
::
ParseZipOp
},
{
kRename
,
&
DEPipeline
::
ParseRenameOp
},
{
kDeviceQueue
,
&
DEPipeline
::
ParseDeviceQueueOp
},
...
...
@@ -511,6 +512,17 @@ Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp
return
Status
::
OK
();
}
// Creates a SkipOp from the argument dictionary handed over by the Python layer.
// @param args - op arguments; the "count" entry is the number of rows to skip
// @param ptr - output: set to the newly built op on success
// @return Status - The error code return
Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  // A None count cannot be converted to an integer, so reject it up front.
  if (args["count"].is_none()) {
    RETURN_STATUS_UNEXPECTED("Error: count is invalid or not set.");
  }

  // The builder validates the count before constructing the op.
  SkipOp::Builder skip_builder(ToInt(args["count"]));
  std::shared_ptr<SkipOp> skip_op;
  RETURN_IF_NOT_OK(skip_builder.Build(&skip_op));

  *ptr = skip_op;
  return Status::OK();
}
Status
DEPipeline
::
ParseGeneratorOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
)
{
std
::
shared_ptr
<
GeneratorOp
::
Builder
>
builder
=
std
::
make_shared
<
GeneratorOp
::
Builder
>
();
for
(
auto
arg
:
args
)
{
...
...
mindspore/ccsrc/dataset/api/de_pipeline.h
浏览文件 @
3f7054dc
...
...
@@ -42,6 +42,7 @@ enum OpName {
kBatch
,
kCache
,
kRepeat
,
kSkip
,
kTake
,
kZip
,
kMap
,
...
...
@@ -107,6 +108,8 @@ class DEPipeline {
Status
ParseRepeatOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
Status
ParseSkipOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
Status
ParseBatchOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
Status
ParseGeneratorOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
...
...
mindspore/ccsrc/dataset/api/python_bindings.cc
浏览文件 @
3f7054dc
...
...
@@ -446,6 +446,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
.
value
(
"MINDRECORD"
,
OpName
::
kMindrecord
)
.
value
(
"CACHE"
,
OpName
::
kCache
)
.
value
(
"REPEAT"
,
OpName
::
kRepeat
)
.
value
(
"SKIP"
,
OpName
::
kSkip
)
.
value
(
"TAKE"
,
OpName
::
kTake
)
.
value
(
"ZIP"
,
OpName
::
kZip
)
.
value
(
"MAP"
,
OpName
::
kMap
)
...
...
mindspore/ccsrc/dataset/core/client.h
浏览文件 @
3f7054dc
...
...
@@ -32,6 +32,7 @@
#include "dataset/engine/datasetops/project_op.h"
#include "dataset/engine/datasetops/rename_op.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/engine/datasetops/source/generator_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
...
...
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
浏览文件 @
3f7054dc
...
...
@@ -11,6 +11,7 @@ add_library(engine-datasetops OBJECT
project_op.cc
rename_op.cc
repeat_op.cc
skip_op.cc
shuffle_op.cc
zip_op.cc
)
...
...
mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc
0 → 100644
浏览文件 @
3f7054dc
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <utility>
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "utils/log_adapter.h"
namespace
mindspore
{
namespace
dataset
{
// Builder constructor. Creates the builder object.
SkipOp
::
Builder
::
Builder
(
int32_t
count
)
:
build_max_skips_
(
count
)
{}
// Validates the builder settings. The only rule is that the skip count is not negative.
// @return Status - OK when the count is usable, an unexpected-error status otherwise
Status SkipOp::Builder::SanityCheck() const {
  const bool count_is_valid = (build_max_skips_ >= 0);
  if (!count_is_valid) {
    RETURN_STATUS_UNEXPECTED("Skip count must be positive integer or 0.");
  }
  return Status::OK();
}
// The builder "build" method: validates the settings, then creates the SkipOp.
// @param ptr - output: receives the shared_ptr to the new SkipOp
// @return Status - The error code return
Status SkipOp::Builder::Build(std::shared_ptr<SkipOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());

  auto new_op = std::make_shared<SkipOp>(build_max_skips_);
  *ptr = std::move(new_op);
  return Status::OK();
}
// Constructor of the SkipOp.
SkipOp
::
SkipOp
(
int32_t
count
)
:
PipelineOp
(
0
),
max_skips_
(
count
),
skip_count_
(
0
)
{}
// Destructor. SkipOp has nothing of its own to release.
SkipOp::~SkipOp() = default;
// A print method typically used for debugging
void
SkipOp
::
Print
(
std
::
ostream
&
out
,
bool
show_all
)
const
{
// Call base class printer first
PipelineOp
::
Print
(
out
,
show_all
);
// Then display our own stuff
out
<<
"SkipOp:"
<<
"
\n
Current skip count: "
<<
skip_count_
<<
"
\n
Max skip count: "
<<
max_skips_
;
}
// Returns the next buffer of this op's output.
// Rows are dropped from the front of the child's output until max_skips_ rows
// have been skipped. A buffer may hold several rows, so a buffer that is only
// partly consumed by the skip is returned with its remaining rows.
// @param p_buffer - output pointer to the buffer that it will fetch.
// @param worker_id - The worker id, forwarded to the child
// @param retry_if_eoe - not consulted here; the child is always asked with true
// @return Status - The error code return
Status SkipOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  if (child_.empty()) {
    RETURN_STATUS_UNEXPECTED("SkipOp can't be the leaf node.");
  }

  std::unique_ptr<DataBuffer> buf;
  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));

  // Drop rows until max_skips_ rows are gone or the child signals eoe/eof.
  // An eoe/eof buffer must leave this loop untouched: treating it as an
  // "empty" buffer and fetching a replacement would hide the end of the epoch
  // from our parent whenever the skip count is not smaller than the epoch.
  while (skip_count_ < max_skips_ && !buf->eoe() && !buf->eof()) {
    TensorRow drop_row;
    int row_num = buf->NumRows();
    for (int i = 0; i < row_num && skip_count_ < max_skips_; i++) {
      RETURN_IF_NOT_OK(buf->PopRow(&drop_row));
      ++skip_count_;
    }

    // Nothing left in this data buffer, so continue with the child's next one.
    if (buf->NumRows() == 0) {
      RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true));
    }
  }

  // Both eoe and eof reset the per-epoch state through EoeReceived.
  // NOTE(review): EofReceived is never called from here; eof goes through
  // EoeReceived exactly as before - confirm that this is intended.
  if (buf->eoe() || buf->eof()) {
    RETURN_IF_NOT_OK(EoeReceived(worker_id));
  }

  *p_buffer = std::move(buf);
  return Status::OK();
}
// Base-class override for handling cases when an eoe is received.
// Marks the op idle and zeroes the skip counter.
// @param worker_id - The worker id (not used)
Status SkipOp::EoeReceived(int32_t worker_id) {
  state_ = OpState::kDeOpIdle;
  skip_count_ = 0;
  return Status::OK();
}
// Class functor operator () override.
// Dataset ops are normally run by launching this functor in a thread (see
// ExecutionTree). SkipOp runs inlined inside another operator instead, so a call
// to this functor is a logic error and always returns an error status.
Status SkipOp::operator()() {
  const std::string err_msg = "Logic error. SkipOp is an inlined operator.";
  RETURN_STATUS_UNEXPECTED(err_msg);
}
// Base-class override for handling cases when an eof is received.
// Only writes an INFO log entry; no state is changed.
// @param worker_id - The worker id (not used)
Status SkipOp::EofReceived(int32_t worker_id) {
  MS_LOG(INFO) << "Skip operator EOF received, do nothing now.";
  return Status::OK();
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/skip_op.h
0 → 100644
浏览文件 @
3f7054dc
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_SKIP_OP_H_
#define DATASET_ENGINE_DATASETOPS_SKIP_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "dataset/engine/datasetops/pipeline_op.h"
namespace
mindspore
{
namespace
dataset
{
// SkipOp drops the first `count` rows produced by its child and passes the rest on.
// It is an inlined operator: its functor must not be launched (see operator()).
class SkipOp : public PipelineOp {
 public:
  // Helper that validates the skip count and then constructs a SkipOp.
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @param count - The number of rows to skip
    // @return This is a constructor.
    explicit Builder(int32_t count);

    // Default destructor
    ~Builder() = default;

    // The builder "build" method creates the final object.
    // The parameter is an output: it receives the shared_ptr to the new SkipOp object.
    // @return Status - The error code return
    Status Build(std::shared_ptr<SkipOp> *);

   private:
    // The skip count that will be handed to the SkipOp constructor
    int32_t build_max_skips_;

    // Checks the builder settings before the op is constructed.
    // @return Status - The error code return
    Status SanityCheck() const;
  };

  // Constructor of the SkipOp.
  // @note The builder class should be used to call it
  // @param count - The number of rows to skip
  explicit SkipOp(int32_t count);

  // Destructor
  ~SkipOp();

  // A print method typically used for debugging
  // @param out - The output stream to write output to
  // @param show_all - A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override;

  // Class functor operator () override.
  // Most dataset ops operate by launching a thread (see ExecutionTree).
  // However, the SkipOp is defined as an inlined operator, so it is invalid to launch the
  // functor since this op runs inlined inside another operator. The function is overridden to
  // ensure that it is not called by mistake (it will generate an error).
  // @return Status - The error code return
  Status operator()() override;

  // This function returns the next buffer of our output. The caller is typically our parent
  // node, when the parent is asking us to provide the next buffer of data.
  // Since SkipOp is an inlined op, the buffer is fetched directly from our child; rows are
  // dropped from the front of the child's output until the requested count has been skipped.
  // @param p_buffer - output pointer to the buffer that it will fetch.
  // @param worker_id - The worker id
  // @param retry_if_eoe - not consulted by the implementation in skip_op.cc, which always
  //                       passes true to the child
  // @return Status - The error code return
  Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;

  // Base-class override for handling cases when an eoe is received.
  // @param worker_id - The worker id
  Status EoeReceived(int32_t worker_id) override;

  // Base-class override for handling cases when an eof is received.
  // @param worker_id - The worker id
  Status EofReceived(int32_t worker_id) override;

 private:
  // The number of skips that the user requested
  int32_t max_skips_;
  // A counter for the current number of executed skips
  int32_t skip_count_;
};
}
// namespace dataset
}
// namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_SKIP_OP_H_
mindspore/dataset/engine/datasets.py
浏览文件 @
3f7054dc
...
...
@@ -35,7 +35,7 @@ from mindspore._c_expression import typing
from
mindspore
import
log
as
logger
from
.
import
samplers
from
.iterators
import
DictIterator
,
TupleIterator
from
.validators
import
check
,
check_batch
,
check_shuffle
,
check_map
,
check_repeat
,
check_zip
,
check_rename
,
\
from
.validators
import
check
,
check_batch
,
check_shuffle
,
check_map
,
check_repeat
,
check_
skip
,
check_
zip
,
check_rename
,
\
check_project
,
check_imagefolderdatasetv2
,
check_mnist_cifar_dataset
,
check_manifestdataset
,
\
check_tfrecorddataset
,
check_vocdataset
,
check_celebadataset
,
check_minddataset
,
check_generatordataset
,
\
check_zip_dataset
,
check_add_column
...
...
@@ -423,6 +423,25 @@ class Dataset:
return
self
return
RepeatDataset
(
self
,
count
)
@check_skip
def skip(self, count):
    """
    Skip the first N elements of this dataset.

    Args:
        count (int): Number of elements to skip from the start of the dataset.

    Returns:
        SkipDataset, the dataset with the leading elements skipped.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # data is an instance of Dataset object.
        >>> # creates a dataset which skips first 3 elements from data
        >>> data = data.skip(3)
    """
    skipped_dataset = SkipDataset(self, count)
    return skipped_dataset
@
check_zip_dataset
def
zip
(
self
,
datasets
):
"""
...
...
@@ -1081,6 +1100,39 @@ class RepeatDataset(DatasetOp):
"""
return
self
.
count
class SkipDataset(DatasetOp):
    """
    The result of applying Skip operator to the input Dataset.

    Args:
        input_dataset (Dataset): The dataset whose leading rows are skipped.
        count (int): Number of rows to skip.
    """

    def __init__(self, input_dataset, count):
        super().__init__()
        self.count = count
        # Link this node and its input in both directions.
        self.input.append(input_dataset)
        input_dataset.output.append(self)
        self._input_indexs = input_dataset.input_indexs

    def get_args(self):
        # Extend the base arguments with the skip count.
        op_args = super().get_args()
        op_args["count"] = self.count
        return op_args

    def get_dataset_size(self):
        """
        Get the number of batches in an epoch.

        Return:
            Number, the input's size minus the skip count, or 0 when the count
            is negative or not smaller than the input's size.
        """
        child_size = self.input[0].get_dataset_size()
        if 0 <= self.count < child_size:
            return child_size - self.count
        return 0
class
ZipDataset
(
DatasetOp
):
"""
...
...
mindspore/dataset/engine/iterators.py
浏览文件 @
3f7054dc
...
...
@@ -127,6 +127,8 @@ class Iterator:
op_type
=
OpName
.
MAP
elif
isinstance
(
dataset
,
de
.
RepeatDataset
):
op_type
=
OpName
.
REPEAT
elif
isinstance
(
dataset
,
de
.
SkipDataset
):
op_type
=
OpName
.
SKIP
elif
isinstance
(
dataset
,
de
.
StorageDataset
):
op_type
=
OpName
.
STORAGE
elif
isinstance
(
dataset
,
de
.
ImageFolderDatasetV2
):
...
...
mindspore/dataset/engine/serializer_deserializer.py
浏览文件 @
3f7054dc
...
...
@@ -297,6 +297,9 @@ def create_node(node):
elif
dataset_op
==
'RepeatDataset'
:
pyobj
=
de
.
Dataset
().
repeat
(
node
.
get
(
'count'
))
elif
dataset_op
==
'SkipDataset'
:
pyobj
=
de
.
Dataset
().
skip
(
node
.
get
(
'count'
))
elif
dataset_op
==
'MapDataset'
:
tensor_ops
=
construct_tensor_ops
(
node
.
get
(
'operations'
))
pyobj
=
de
.
Dataset
().
map
(
node
.
get
(
'input_columns'
),
tensor_ops
,
node
.
get
(
'output_columns'
),
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
3f7054dc
...
...
@@ -709,6 +709,20 @@ def check_repeat(method):
return
new_method
def check_skip(method):
    """Validate the input arguments of skip: count must be an int that is not negative."""

    @wraps(method)
    def new_method(*args, **kwargs):
        params = make_param_dict(method, args, kwargs)
        skip_count = params.get('count')

        check_type(skip_count, 'count', int)
        if skip_count < 0:
            raise ValueError("Skip count must be positive integer or 0.")

        return method(*args, **kwargs)

    return new_method
def
check_zip
(
method
):
"""check the input arguments of zip."""
...
...
tests/ut/cpp/dataset/CMakeLists.txt
浏览文件 @
3f7054dc
...
...
@@ -41,6 +41,7 @@ SET(DE_UT_SRCS
random_vertical_flip_op_test.cc
rename_op_test.cc
repeat_op_test.cc
skip_op_test.cc
rescale_op_test.cc
resize_bilinear_op_test.cc
resize_op_test.cc
...
...
tests/ut/cpp/dataset/skip_op_test.cc
0 → 100644
浏览文件 @
3f7054dc
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/circular_pool.h"
#include "dataset/core/client.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
using
namespace
mindspore
::
dataset
;
using
mindspore
::
MsLogLevel
::
INFO
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
LogStream
;
class
MindDataTestSkipOp
:
public
UT
::
DatasetOpTesting
{};
// Builds the tree TFReaderOp -> SkipOp(5), runs it and counts the rows that come out.
// The test expects 7 rows to remain after the skip.
TEST_F(MindDataTestSkipOp, TestSkipOpFuntions) {
  // Start with an empty execution tree
  auto tree = std::make_shared<ExecutionTree>();

  // Leaf node: a TFReaderOp over the test data file
  std::string dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
  TFReaderOp::Builder reader_builder;
  reader_builder.SetDatasetFilesList({dataset_path})
    .SetRowsPerBuffer(16)
    .SetWorkerConnectorSize(16)
    .SetNumWorkers(16);

  // NOTE(review): the result of LoadSchemaFile is not checked here - confirm whether it should be.
  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  schema->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {});
  reader_builder.SetDataSchema(std::move(schema));

  std::shared_ptr<TFReaderOp> reader_op;
  Status status = reader_builder.Build(&reader_op);
  ASSERT_TRUE(status.IsOk());
  status = tree->AssociateNode(reader_op);
  ASSERT_TRUE(status.IsOk());

  // SkipOp dropping the first 5 rows
  std::shared_ptr<SkipOp> skip_op = std::make_shared<SkipOp>(5);
  status = tree->AssociateNode(skip_op);
  ASSERT_TRUE(status.IsOk());

  // Set children/root layout.
  status = skip_op->AddChild(reader_op);
  ASSERT_TRUE(status.IsOk());
  status = tree->AssignRoot(skip_op);
  ASSERT_TRUE(status.IsOk());

  MS_LOG(INFO) << "Launching tree and begin iteration.";
  status = tree->Prepare();
  ASSERT_TRUE(status.IsOk());
  status = tree->Launch();
  ASSERT_TRUE(status.IsOk());

  // Start the loop of reading tensors from our pipeline
  DatasetIterator row_iter(tree);
  TensorRow current_row;
  status = row_iter.FetchNextTensorRow(&current_row);
  ASSERT_TRUE(status.IsOk());

  int row_count = 0;
  // An empty row ends the iteration.
  while (!current_row.empty()) {
    MS_LOG(INFO) << "Row display for row #: " << row_count << ".";

    // Display the tensor by calling the printer on it
    for (size_t col = 0; col < current_row.size(); ++col) {
      std::ostringstream ss;
      ss << "(" << current_row[col] << "): " << *current_row[col] << std::endl;
      MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
    }

    status = row_iter.FetchNextTensorRow(&current_row);
    ASSERT_TRUE(status.IsOk());
    ++row_count;
  }

  ASSERT_EQ(row_count, 7);
}
\ No newline at end of file
tests/ut/python/dataset/test_skip.py
0 → 100644
浏览文件 @
3f7054dc
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
numpy
as
np
import
mindspore.dataset.transforms.vision.c_transforms
as
vision
import
mindspore.dataset
as
ds
from
mindspore
import
log
as
logger
DATA_DIR_TF2
=
[
"../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"
]
SCHEMA_DIR_TF2
=
"../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_tf_skip():
    """Decode and resize a TFRecord dataset, skip 2 rows, and expect exactly one row to remain."""
    height, width = 32, 32
    decode_op = vision.Decode()
    resize_op = vision.Resize((height, width), interpolation=ds.transforms.vision.Inter.LINEAR)

    pipeline = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    pipeline = pipeline.map(input_columns=["image"], operations=decode_op)
    pipeline = pipeline.map(input_columns=["image"], operations=resize_op)
    pipeline = pipeline.skip(2)

    # Count the rows that survive the skip.
    num_iter = sum(1 for _ in pipeline.create_dict_iterator())
    assert num_iter == 1
def generator_md():
    """Yield five one-column rows holding the values 0 to 4, each as a one-element array."""
    # Create a dataset with [0, 1, 2, 3, 4]
    yield from ((np.array([value]),) for value in range(5))
def test_generator_skip():
    """Skipping 3 of the 5 generated rows must leave 2 rows."""
    dataset = ds.GeneratorDataset(generator_md, ["data"])

    # Here the dataset should be [3, 4]
    dataset = dataset.skip(3)

    buf = [row[0][0] for row in dataset]
    assert len(buf) == 2
def test_skip_1():
    """Skipping more rows (7) than the generator yields (5) must leave nothing."""
    dataset = ds.GeneratorDataset(generator_md, ["data"])

    # Here the dataset should be []
    dataset = dataset.skip(7)

    buf = [row[0][0] for row in dataset]
    assert len(buf) == 0
def test_skip_2():
    """Skipping 0 rows must leave all 5 generated rows."""
    dataset = ds.GeneratorDataset(generator_md, ["data"])

    # Here the dataset should be [0, 1, 2, 3, 4]
    dataset = dataset.skip(0)

    buf = [row[0][0] for row in dataset]
    assert len(buf) == 5
def test_skip_repeat_1():
    """repeat(2) followed by skip(3) must leave 7 of the 10 rows."""
    dataset = ds.GeneratorDataset(generator_md, ["data"])

    # Here the dataset should be [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    dataset = dataset.repeat(2)

    # Here the dataset should be [3, 4, 0, 1, 2, 3, 4]
    dataset = dataset.skip(3)

    buf = [row[0][0] for row in dataset]
    assert len(buf) == 7
def test_skip_repeat_2():
    """skip(3) followed by repeat(2) must yield 4 rows."""
    dataset = ds.GeneratorDataset(generator_md, ["data"])

    # Here the dataset should be [3, 4]
    dataset = dataset.skip(3)

    # Here the dataset should be [3, 4, 3, 4]
    dataset = dataset.repeat(2)

    buf = [row[0][0] for row in dataset]
    assert len(buf) == 4
def test_skip_repeat_3():
    """repeat(2), skip(8), then repeat(3) must yield 6 rows."""
    dataset = ds.GeneratorDataset(generator_md, ["data"])

    # Here the dataset should be [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
    dataset = dataset.repeat(2)

    # Here the dataset should be [3, 4]
    dataset = dataset.skip(8)

    # Here the dataset should be [3, 4, 3, 4, 3, 4]
    dataset = dataset.repeat(3)

    buf = [row[0][0] for row in dataset]
    assert len(buf) == 6
if __name__ == "__main__":
    # Run every test in the order it is defined in this file.
    for run_test in (
            test_tf_skip,
            test_generator_skip,
            test_skip_1,
            test_skip_2,
            test_skip_repeat_1,
            test_skip_repeat_2,
            test_skip_repeat_3,
    ):
        run_test()
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录