magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit c0fa7b4b
Authored on May 11, 2020 by ms_yan

init commit of concat dataset
change to use __add__ operation instead of ds.concat

Parent: e2951707
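The second message line describes the user-facing change: Dataset.__add__ now delegates to the new concat() method, so pipelines can be joined with the "+" operator. A minimal usage sketch of the two equivalent forms, mirroring the GeneratorDataset pattern from the unit tests added in this commit (the generator names here are illustrative, not part of the commit):

    import numpy as np
    import mindspore.dataset as ds

    # Two toy sources with the same column name, data type and rank,
    # which is what ConcatOp::Verify requires of the children.
    def gen_a():
        for i in range(3):
            yield (np.array([i]),)

    def gen_b():
        for i in range(3, 10):
            yield (np.array([i]),)

    data1 = ds.GeneratorDataset(gen_a, ["col1"])
    data2 = ds.GeneratorDataset(gen_b, ["col1"])

    data3 = data1 + data2          # __add__ form introduced by this commit
    # data3 = data1.concat(data2)  # equivalent explicit form

    for row in data3:
        print(row[0])              # rows 0..9 across both sources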
Changes: 14 changed files with 852 additions and 1 deletion (+852, -1)
mindspore/ccsrc/dataset/api/de_pipeline.cc                    +9    -0
mindspore/ccsrc/dataset/api/de_pipeline.h                     +3    -0
mindspore/ccsrc/dataset/api/python_bindings.cc                +1    -0
mindspore/ccsrc/dataset/core/client.h                         +1    -0
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt      +1    -0
mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc        +145  -0
mindspore/ccsrc/dataset/engine/datasetops/concat_op.h         +95   -0
mindspore/dataset/engine/datasets.py                          +68   -1
mindspore/dataset/engine/iterators.py                         +2    -0
mindspore/dataset/engine/serializer_deserializer.py           +4    -0
mindspore/dataset/engine/validators.py                        +20   -0
tests/ut/cpp/dataset/CMakeLists.txt                           +1    -0
tests/ut/cpp/dataset/concat_op_test.cc                        +125  -0
tests/ut/python/dataset/test_concat.py                        +377  -0
mindspore/ccsrc/dataset/api/de_pipeline.cc

@@ -53,6 +53,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
    {kRepeat, &DEPipeline::ParseRepeatOp},
    {kSkip, &DEPipeline::ParseSkipOp},
    {kZip, &DEPipeline::ParseZipOp},
    {kConcat, &DEPipeline::ParseConcatOp},
    {kRename, &DEPipeline::ParseRenameOp},
    {kDeviceQueue, &DEPipeline::ParseDeviceQueueOp},
    {kGenerator, &DEPipeline::ParseGeneratorOp},

@@ -757,6 +758,14 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *
  return Status::OK();
}

Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  std::shared_ptr<ConcatOp::Builder> builder = std::make_shared<ConcatOp::Builder>();
  std::shared_ptr<ConcatOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *ptr = op;
  return Status::OK();
}

Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  // Required arguments
  std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
mindspore/ccsrc/dataset/api/de_pipeline.h

@@ -46,6 +46,7 @@ enum OpName {
  kSkip,
  kTake,
  kZip,
  kConcat,
  kMap,
  kFilter,
  kDeviceQueue,

@@ -127,6 +128,8 @@ class DEPipeline {
  Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

  Status ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

  Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);

  Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
mindspore/ccsrc/dataset/api/python_bindings.cc

@@ -468,6 +468,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
    .value("SKIP", OpName::kSkip)
    .value("TAKE", OpName::kTake)
    .value("ZIP", OpName::kZip)
    .value("CONCAT", OpName::kConcat)
    .value("MAP", OpName::kMap)
    .value("FILTER", OpName::kFilter)
    .value("DEVICEQUEUE", OpName::kDeviceQueue)
mindspore/ccsrc/dataset/core/client.h

@@ -42,6 +42,7 @@
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/datasetops/zip_op.h"
#include "dataset/engine/datasetops/concat_op.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/util/status.h"
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt

@@ -17,6 +17,7 @@ add_library(engine-datasetops OBJECT
        take_op.cc
        shuffle_op.cc
        zip_op.cc
        concat_op.cc
        filter_op.cc
        )
mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iomanip>
#include <utility>
#include "common/utils.h"
#include "dataset/core/config_manager.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/datasetops/concat_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"

namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
ConcatOp::Builder::Builder() {
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  builder_op_connector_size_ = cfg->op_connector_size();
}

// The builder "build" method creates the final object.
Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
  *ptr = std::make_shared<ConcatOp>(builder_op_connector_size_);
  return Status::OK();
}

// Constructor of the ConcatOp.
ConcatOp::ConcatOp(int32_t op_connector_size) : PipelineOp(op_connector_size), children_num_(0) {}

// A function that prints info about the Operator
void ConcatOp::Print(std::ostream &out, bool show_all) const {
  // Always show the id and name as first line regardless if this is summary or detailed print
  out << "(" << std::setw(2) << operator_id_ << ") <ConcatOp>:";
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    PipelineOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    PipelineOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nDatasets: " << children_num_ << "\n\n";
  }
}

// Main entry point for Concat
Status ConcatOp::operator()() {
  // The children_num_ parameter needs to be put here
  children_num_ = static_cast<int32_t>(child_.size());
  TaskManager::FindMe()->Post();
  std::unique_ptr<DataBuffer> buf;
  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));

  // Obtain columns_name_id_map from child_[0]
  column_name_id_map_ = child_[0]->column_name_id_map();
  if (column_name_id_map_.empty()) {
    RETURN_STATUS_UNEXPECTED("Child column name map cannot be empty!");
  }
  int eof_count = 0;
  while (eof_count != children_num_) {
    for (int i = 0; i < children_num_; i++) {
      // 1. Throw the eof buffer when meet it
      if (buf->eof() || buf->eoe()) {
        RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
      }
      // 2. Do verification as for column name, column data type and rank of column data
      RETURN_IF_NOT_OK(Verify(i, buf));

      // 3. Put the data into output_connector
      while (!buf->eoe() && !buf->eof()) {
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
        RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
      }

      // 4. Throw the eoe buffer when meet it
      if (buf->eoe() && (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat))) {
        RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
      }
      // 5. Add eoe buffer after get buffer from all child
      if (i == (children_num_ - 1)) {
        auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
      }
      if (buf->eof()) {
        eof_count++;
      }
    }
  }
  // 6. Add eof buffer in the end manually
  MS_LOG(DEBUG) << "Add the eof buffer manually in the end.";
  auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
  RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));

  return Status::OK();
}

Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) {
  TensorRow new_row;
  buf->GetRow(0, &new_row);

  if (id == 0) {
    // Obtain the column name, data type and data rank in child[0]
    column_name_id_ = child_[id]->column_name_id_map();
    for (auto item : new_row) {
      data_type_.push_back(item->type());
      data_rank_.push_back(item->Rank());
    }
  } else {
    // Compare the column name, data type and data rank with these in child[0]
    if (child_[id]->column_name_id_map() != column_name_id_) {
      RETURN_STATUS_UNEXPECTED("The column name or column order is not the same with previous dataset.");
    }
    int32_t index = 0;
    for (auto item : new_row) {
      if ((item->type() != data_type_[index]) || item->Rank() != data_rank_[index++]) {
        RETURN_STATUS_UNEXPECTED("The data type or data rank is not the same with previous dataset.");
      }
    }
  }
  return Status::OK();
}

Status ConcatOp::PrepareNodePostAction() {
  RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
  tree_->AddToRepeatStack(shared_from_this());
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/concat_op.h (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_
#define DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "dataset/engine/datasetops/pipeline_op.h"

namespace mindspore {
namespace dataset {
class ConcatOp : public PipelineOp {
 public:
  // The nested builder class inside of the ConcatOp is used to help manage all of the arguments
  // for constructing it. This Concat op is very simple though, so this builder is really just
  // provided for a consistent look and feel for creators of Dataset operators overall.
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @return This is a constructor.
    Builder();

    // Default destructor
    ~Builder() = default;

    // The builder "build" method creates the final object.
    // @return shared_ptr to the new ConcatOp object
    Status Build(std::shared_ptr<ConcatOp> *);

   private:
    int32_t builder_op_connector_size_;
  };

  // Constructor of the ConcatOp.
  // @note The builder class should be used to call it
  // @param op_connector_size - connector size
  explicit ConcatOp(int32_t op_connector_size);

  // Destructor
  ~ConcatOp() = default;

  // A print method typically used for debugging
  // @param out - The output stream to write output to
  // @param show_all - A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override;

  // << Stream output operator overload
  // @notes This allows you to write the debug print info using stream operators
  // @param out - reference to the output stream being overloaded
  // @param ro - reference to the ConcatOp to display
  // @return - the output stream must be returned
  friend std::ostream &operator<<(std::ostream &out, const ConcatOp &ro) {
    ro.Print(out, false);
    return out;
  }

  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
  // provide the master loop that drives the logic for performing the work
  // @return Status - The error code return
  Status operator()() override;

  // During tree prepare phase, operators may have specific post-operations to perform depending on
  // their role.
  // @notes Derived versions of this function should always call its superclass version first
  // before providing their own implementations.
  Status PrepareNodePostAction() override;

 private:
  Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);

  int32_t children_num_;                                      // The num of child of parent node.
  std::unordered_map<std::string, int32_t> column_name_id_;   // Mapping between col index and col name
  std::vector<DataType> data_type_;
  std::vector<dsize_t> data_rank_;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_DATASETOPS_CONCAT_OP_H_
mindspore/dataset/engine/datasets.py

@@ -44,7 +44,7 @@ from .validators import check, check_batch, check_shuffle, check_map, check_filt
     check_rename, \
     check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
     check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
-    check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset
+    check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat
 from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist

 try:

@@ -147,6 +147,9 @@ class Dataset:
         self._repeat_count = None
         self._sync = False

     def __add__(self, datasets):
         return self.concat(datasets)

     def get_args(self):
         """
         Returns attributes (member variables) related to the current class.

@@ -560,6 +563,37 @@
             raise TypeError("The zip function %s type error!" % (datasets))
         return ZipDataset(datasets)

     @check_concat
     def concat(self, datasets):
         """
         Concat the datasets in the input list of datasets; the "+" operator is also supported for the concat operation.

         Note:
             The column name, column data type and rank of column data should be the same in the input datasets.

         Args:
             datasets (list or class Dataset): A list of datasets or a single class Dataset
                 to be concatenated together with this dataset.

         Returns:
             ConcatDataset, dataset concatenated.

         Examples:
             >>> import mindspore.dataset as ds
             >>> # ds1 and ds2 are instances of Dataset object
             >>> # creates a dataset by concatenating ds1 and ds2 with "+" operation
             >>> data1 = ds1 + ds2
             >>> # creates a dataset by concatenating ds1 and ds2 with concat operation
             >>> data1 = ds1.concat(ds2)
         """
         if isinstance(datasets, Dataset):
             datasets = [self] + [datasets]
         elif isinstance(datasets, list):
             datasets = [self] + datasets
         else:
             raise TypeError("The concat_dataset function %s type error!" % (datasets))
         return ConcatDataset(datasets)

     @check_rename
     def rename(self, input_columns, output_columns):
         """

@@ -1658,6 +1692,39 @@ class ZipDataset(DatasetOp):
         return args


 class ConcatDataset(DatasetOp):
     """
     The result of applying concat dataset operator to the input Dataset.

     Args:
         datasets (list): A list of datasets to be concatenated together.

     Raises:
         TypeError: If dataset is not an instance of Dataset.
     """

     def __init__(self, datasets):
         super().__init__()
         for dataset in datasets:
             if not isinstance(dataset, Dataset):
                 raise TypeError("The parameter %s of concat has type error!" % (dataset))
         self.datasets = datasets
         for data in datasets:
             self.input.append(data)
             data.output.append(self)

     def get_dataset_size(self):
         """
         Get the number of batches in an epoch.

         Return:
             Number, number of batches.
         """
         children_sizes = [c.get_dataset_size() for c in self.input]
         dataset_size = np.sum(children_sizes)
         return dataset_size


 class RenameDataset(DatasetOp):
     """
     The result of applying Rename operator to the input Dataset.
mindspore/dataset/engine/iterators.py

@@ -156,6 +156,8 @@ class Iterator:
             op_type = OpName.BARRIER
         elif isinstance(dataset, de.ZipDataset):
             op_type = OpName.ZIP
         elif isinstance(dataset, de.ConcatDataset):
             op_type = OpName.CONCAT
         elif isinstance(dataset, de.MapDataset):
             op_type = OpName.MAP
         elif isinstance(dataset, de.FilterDataset):
mindspore/dataset/engine/serializer_deserializer.py

@@ -335,6 +335,10 @@ def create_node(node):
         # Create ZipDataset instance, giving dummy input dataset that will be overridden in the caller.
         pyobj = de.ZipDataset((de.Dataset(), de.Dataset()))

     elif dataset_op == 'ConcatDataset':
         # Create ConcatDataset instance, giving dummy input dataset that will be overridden in the caller.
         pyobj = de.ConcatDataset((de.Dataset(), de.Dataset()))

     elif dataset_op == 'RenameDataset':
         pyobj = de.Dataset().rename(node['input_columns'], node['output_columns'])
mindspore/dataset/engine/validators.py

@@ -875,6 +875,26 @@ def check_zip_dataset(method):
     return new_method


 def check_concat(method):
     """check the input arguments of concat_dataset method in `Dataset`."""

     @wraps(method)
     def new_method(*args, **kwargs):
         param_dict = make_param_dict(method, args, kwargs)

         # check datasets; required argument
         ds = param_dict.get("datasets")
         if ds is None:
             raise ValueError("datasets is not provided.")

         if not isinstance(ds, (list, datasets.Dataset)):
             raise ValueError("datasets is not list or of type Dataset.")

         return method(*args, **kwargs)

     return new_method


 def check_rename(method):
     """check the input arguments of rename."""
tests/ut/cpp/dataset/CMakeLists.txt

@@ -66,6 +66,7 @@ SET(DE_UT_SRCS
     take_op_test.cc
     text_file_op_test.cc
     filter_op_test.cc
     concat_op_test.cc
     )

 add_executable(de_ut_tests ${DE_UT_SRCS})
tests/ut/cpp/dataset/concat_op_test.cc (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <memory>
#include <vector>
#include "common/common.h"
#include "common/utils.h"
#include "dataset/core/client.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"

namespace common = mindspore::common;

using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;

class MindDataTestConcatOp : public UT::DatasetOpTesting {};

TEST_F(MindDataTestConcatOp, TestConcatProject) {
  /* Tree:
   *
   *              OpId(2) ConcatOp
   *               /            \
   *  OpId(0) TFReaderOp    OpId(1) TFReaderOp
   *
   * Start with an empty execution tree
   */
  MS_LOG(INFO) << "UT test TestConcatProject.";
  auto my_tree = std::make_shared<ExecutionTree>();

  std::string dataset_path;
  dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";

  // TFReaderOp1
  std::shared_ptr<TFReaderOp> my_tfreader_op1;
  TFReaderOp::Builder builder1;
  builder1.SetDatasetFilesList({dataset_path})
    .SetRowsPerBuffer(16)
    .SetWorkerConnectorSize(16)
    .SetNumWorkers(16);
  std::unique_ptr<DataSchema> schema1 = std::make_unique<DataSchema>();
  schema1->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {});
  builder1.SetDataSchema(std::move(schema1));
  Status rc = builder1.Build(&my_tfreader_op1);
  ASSERT_TRUE(rc.IsOk());
  rc = my_tree->AssociateNode(my_tfreader_op1);
  ASSERT_TRUE(rc.IsOk());

  // TFReaderOp2
  std::shared_ptr<TFReaderOp> my_tfreader_op2;
  TFReaderOp::Builder builder2;
  builder2.SetDatasetFilesList({dataset_path})
    .SetRowsPerBuffer(16)
    .SetWorkerConnectorSize(16)
    .SetNumWorkers(16);
  std::unique_ptr<DataSchema> schema2 = std::make_unique<DataSchema>();
  schema2->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {});
  builder2.SetDataSchema(std::move(schema2));
  rc = builder2.Build(&my_tfreader_op2);
  ASSERT_TRUE(rc.IsOk());
  rc = my_tree->AssociateNode(my_tfreader_op2);
  ASSERT_TRUE(rc.IsOk());

  // Creating ConcatOp
  std::shared_ptr<ConcatOp> concat_op;
  rc = ConcatOp::Builder().Build(&concat_op);
  EXPECT_TRUE(rc.IsOk());

  rc = my_tree->AssociateNode(concat_op);
  EXPECT_TRUE(rc.IsOk());
  rc = concat_op->AddChild(std::move(my_tfreader_op1));
  EXPECT_TRUE(rc.IsOk());
  rc = concat_op->AddChild(std::move(my_tfreader_op2));
  EXPECT_TRUE(rc.IsOk());
  rc = my_tree->AssignRoot(concat_op);
  EXPECT_TRUE(rc.IsOk());
  rc = my_tree->Prepare();
  EXPECT_TRUE(rc.IsOk());

  // Launch the tree execution to kick off threads and start running the pipeline
  MS_LOG(INFO) << "Launching my tree.";
  rc = my_tree->Launch();
  EXPECT_TRUE(rc.IsOk());

  // Simulate a parse of data from our pipeline.
  std::shared_ptr<DatasetOp> rootNode = my_tree->root();

  DatasetIterator di(my_tree);
  TensorRow tensor_list;
  rc = di.FetchNextTensorRow(&tensor_list);
  EXPECT_TRUE(rc.IsOk());

  int row_count = 0;
  while (!tensor_list.empty()) {
    MS_LOG(INFO) << "Row display for row #: " << row_count << ".";

    // Display the tensor by calling the printer on it
    for (int i = 0; i < tensor_list.size(); i++) {
      std::ostringstream ss;
      ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
      MS_LOG(INFO) << "Tensor print: " << common::SafeCStr(ss.str()) << ".";
    }
    rc = di.FetchNextTensorRow(&tensor_list);
    EXPECT_TRUE(rc.IsOk());
    row_count++;
  }
  ASSERT_EQ(row_count, 24);  // Should be 24 rows fetched
}
\ No newline at end of file
tests/ut/python/dataset/test_concat.py (new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.py_transforms as F
import mindspore.dataset.transforms.c_transforms as C
import mindspore.common.dtype as mstype
from mindspore import log as logger
import numpy as np


# In generator dataset: Number of rows is 3, its values are 0, 1, 2
def generator():
    for i in range(3):
        yield (np.array([i]),)


# In generator_10 dataset: Number of rows is 7, its values are 3, 4, 5 ... 9
def generator_10():
    for i in range(3, 10):
        yield (np.array([i]),)


# In generator_20 dataset: Number of rows is 10, its values are 10, 11, 12 ... 19
def generator_20():
    for i in range(10, 20):
        yield (np.array([i]),)


def test_concat_01():
    """
    Test concat: test concat 2 datasets that have the same column name and data type
    """
    logger.info("test_concat_01")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1 + data2

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_02():
    """
    Test concat: test concat 2 datasets using the concat operation rather than the "+" operation
    """
    logger.info("test_concat_02")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1.concat(data2)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_03():
    """
    Test concat: test concat datasets that have different column names
    """
    logger.info("test_concat_03")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col2"])

    data3 = data1 + data2

    try:
        for i, d in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_04():
    """
    Test concat: test concat datasets that have different ranks
    """
    logger.info("test_concat_04")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col2"])
    data2 = data2.batch(3)

    data3 = data1 + data2

    try:
        for i, d in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_05():
    """
    Test concat: test concat datasets that have different data types
    """
    logger.info("test_concat_05")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    type_cast_op = C.TypeCast(mstype.float32)
    data1 = data1.map(input_columns=["col1"], operations=type_cast_op)

    data3 = data1 + data2

    try:
        for i, d in enumerate(data3):
            pass
        assert False
    except RuntimeError:
        pass


def test_concat_06():
    """
    Test concat: test concat multiple datasets in one expression
    """
    logger.info("test_concat_06")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])

    dataset = data1 + data2 + data3

    # Here i refers to index, d refers to data element
    for i, d in enumerate(dataset):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in dataset]) == 20


def test_concat_07():
    """
    Test concat: test concat one dataset with multiple datasets (datasets list)
    """
    logger.info("test_concat_07")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])
    data3 = ds.GeneratorDataset(generator_20, ["col1"])

    dataset = [data2] + [data3]
    data4 = data1 + dataset

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data4):
        logger.info("data: %i", d[0][0])
        assert i == d[0][0]

    assert sum([1 for _ in data4]) == 20


def test_concat_08():
    """
    Test concat: test concat 2 datasets, and then repeat
    """
    logger.info("test_concat_08")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data3 = data1 + data2
    data3 = data3.repeat(2)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert i % 10 == d[0][0]

    assert sum([1 for _ in data3]) == 20


def test_concat_09():
    """
    Test concat: test concat 2 datasets, both of which have been repeated before
    """
    logger.info("test_concat_09")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data1 = data1.repeat(2)
    data2 = data2.repeat(2)
    data3 = data1 + data2

    res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8, 9]

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 20


def test_concat_10():
    """
    Test concat: test concat 2 datasets, one of which has been repeated before
    """
    logger.info("test_concat_10")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data1 = data1.repeat(2)
    data3 = data1 + data2

    res = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 13


def test_concat_11():
    """
    Test concat: test dataset batch then concat
    """
    logger.info("test_concat_11")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])

    data1 = data1.batch(3)
    data2 = data2.batch(5)
    data3 = data1 + data2

    res = [0, 10, 15, 20]

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 3


def test_concat_12():
    """
    Test concat: test dataset concat then shuffle
    """
    logger.info("test_concat_12")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_10, ["col1"])

    data1.set_dataset_size(3)
    data2.set_dataset_size(7)

    data3 = data1 + data2
    res = [8, 6, 2, 5, 0, 4, 9, 3, 7, 1]

    ds.config.set_seed(1)
    assert data3.get_dataset_size() == 10

    data3 = data3.shuffle(buffer_size=10)

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 10


def test_concat_13():
    """
    Test concat: test dataset batch then shuffle and concat
    """
    logger.info("test_concat_13")
    data1 = ds.GeneratorDataset(generator, ["col1"])
    data2 = ds.GeneratorDataset(generator_20, ["col1"])

    data1.set_dataset_size(3)
    data2.set_dataset_size(10)

    data1 = data1.batch(3)
    data2 = data2.batch(5)

    data3 = data1 + data2
    res = [15, 0, 10]

    ds.config.set_seed(1)
    assert data3.get_dataset_size() == 3

    data3 = data3.shuffle(buffer_size=int(data3.get_dataset_size()))

    # Here i refers to index, d refers to data element
    for i, d in enumerate(data3):
        logger.info("data: %i", d[0][0])
        assert res[i] == d[0][0]

    assert sum([1 for _ in data3]) == 3


def test_concat_14():
    """
    Test concat: create datasets from different dataset folders, apply different operations, then concat
    """
    logger.info("test_concat_14")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = "../data/dataset/testImageNetData/train/"

    data1 = ds.ImageFolderDatasetV2(DATA_DIR, num_samples=3)
    data2 = ds.ImageFolderDatasetV2(DATA_DIR2, num_samples=2)

    transforms1 = F.ComposeOp([F.Decode(),
                               F.Resize((224, 224)),
                               F.ToTensor()])

    data1 = data1.map(input_columns=["image"], operations=transforms1())
    data2 = data2.map(input_columns=["image"], operations=transforms1())
    data3 = data1 + data2

    expected, output = [], []
    for d in data1:
        expected.append(d[0])
    for d in data2:
        expected.append(d[0])
    for d in data3:
        output.append(d[0])

    assert len(expected) == len(output)
    np.array_equal(np.array(output), np.array(expected))

    assert sum([1 for _ in data3]) == 5
    assert data3.get_dataset_size() == 5


def test_concat_15():
    """
    Test concat: create datasets from different formats of dataset files, and then concat
    """
    logger.info("test_concat_15")
    DATA_DIR = "../data/dataset/testPK/data"
    DATA_DIR2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]

    data1 = ds.ImageFolderDatasetV2(DATA_DIR)
    data2 = ds.TFRecordDataset(DATA_DIR2, columns_list=["image"])

    data1 = data1.project(["image"])
    data3 = data1 + data2

    assert sum([1 for _ in data3]) == 47


if __name__ == "__main__":
    test_concat_01()
    test_concat_02()
    test_concat_03()
    test_concat_04()
    test_concat_05()
    test_concat_06()
    test_concat_07()
    test_concat_08()
    test_concat_09()
    test_concat_10()
    test_concat_11()
    test_concat_12()
    test_concat_13()
    test_concat_14()
    test_concat_15()
\ No newline at end of file