Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
98fbd30a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
98fbd30a
编写于
4月 20, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 20, 2020
浏览文件
操作
浏览文件
下载
差异文件
!460 [Data]Add filter operation
Merge pull request !460 from xulei/filter_master
上级
822a3160
c705ea5e
变更
16
显示空白变更内容
内联
并排
Showing
16 changed file
with
1156 addition
and
8 deletion
+1156
-8
mindspore/ccsrc/dataset/api/de_pipeline.cc
mindspore/ccsrc/dataset/api/de_pipeline.cc
+37
-2
mindspore/ccsrc/dataset/api/de_pipeline.h
mindspore/ccsrc/dataset/api/de_pipeline.h
+2
-2
mindspore/ccsrc/dataset/core/client.h
mindspore/ccsrc/dataset/core/client.h
+1
-0
mindspore/ccsrc/dataset/core/tensor.cc
mindspore/ccsrc/dataset/core/tensor.cc
+1
-1
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
+1
-0
mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc
mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc
+273
-0
mindspore/ccsrc/dataset/engine/datasetops/filter_op.h
mindspore/ccsrc/dataset/engine/datasetops/filter_op.h
+180
-0
mindspore/dataset/engine/datasets.py
mindspore/dataset/engine/datasets.py
+65
-1
mindspore/dataset/engine/iterators.py
mindspore/dataset/engine/iterators.py
+2
-0
mindspore/dataset/engine/validators.py
mindspore/dataset/engine/validators.py
+20
-0
tests/ut/cpp/dataset/CMakeLists.txt
tests/ut/cpp/dataset/CMakeLists.txt
+2
-0
tests/ut/cpp/dataset/filter_op_test.cc
tests/ut/cpp/dataset/filter_op_test.cc
+53
-0
tests/ut/cpp/dataset/tensor_test.cc
tests/ut/cpp/dataset/tensor_test.cc
+10
-0
tests/ut/data/dataset/declient_filter.cfg
tests/ut/data/dataset/declient_filter.cfg
+3
-0
tests/ut/python/dataset/test_filterop.py
tests/ut/python/dataset/test_filterop.py
+504
-0
tests/ut/python/dataset/test_iterator.py
tests/ut/python/dataset/test_iterator.py
+2
-2
未找到文件。
mindspore/ccsrc/dataset/api/de_pipeline.cc
浏览文件 @
98fbd30a
...
...
@@ -29,6 +29,7 @@
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_shuffle.h"
...
...
@@ -45,6 +46,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{
kShuffle
,
&
DEPipeline
::
ParseShuffleOp
},
{
kMindrecord
,
&
DEPipeline
::
ParseMindRecordOp
},
{
kMap
,
&
DEPipeline
::
ParseMapOp
},
{
kFilter
,
&
DEPipeline
::
ParseFilterOp
},
{
kBatch
,
&
DEPipeline
::
ParseBatchOp
},
{
kRepeat
,
&
DEPipeline
::
ParseRepeatOp
},
{
kSkip
,
&
DEPipeline
::
ParseSkipOp
},
...
...
@@ -502,6 +504,41 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
return
Status
::
OK
();
}
// Builds a FilterOp from the argument dictionary handed over by the Python layer.
// @param args Operator arguments. 'predicate' must be present and not None. Recognised keys are
//     'num_parallel_workers', 'predicate' and 'input_columns'; entries whose value is None are
//     skipped, and any other key with a non-None value is reported as an error.
// @param ptr Output: receives the newly built operator.
// @return Status The error code return.
Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
  auto builder = std::make_shared<FilterOp::Builder>();

  if (args["predicate"].is_none()) {
    RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set.\n");
  }

  for (auto arg : args) {
    std::string key = py::str(arg.first);
    py::handle value = arg.second;
    if (value.is_none()) {
      // Nothing to apply for this key; the builder keeps whatever it already holds.
      continue;
    }

    if (key == "num_parallel_workers") {
      (void)builder->SetNumWorkers(ToInt(value));
    } else if (key == "predicate") {
      // 'value' is the very entry stored under "predicate", so no second lookup is needed.
      if (!py::isinstance<py::function>(value)) {
        RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc).");
      }
      py::function predicate_func = value.cast<py::function>();
      (void)builder->SetPredicateFunc(std::move(predicate_func));
    } else if (key == "input_columns") {
      std::vector<std::string> in_col_names = ToStringVector(args["input_columns"]);
      (void)builder->SetInColNames(in_col_names);
    } else {
      RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
    }
  }

  std::shared_ptr<FilterOp> op;
  RETURN_IF_NOT_OK(builder->Build(&op));
  *ptr = op;
  return Status::OK();
}
Status
DEPipeline
::
ParseRepeatOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
)
{
if
(
args
[
"count"
].
is_none
())
{
std
::
string
err_msg
=
"Error: count is invalid or not set."
;
...
...
@@ -671,8 +708,6 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *
return
Status
::
OK
();
}
// Stub overload: ignores 'args' and returns a default-constructed DsOpPtr.
DsOpPtr DEPipeline::ParseFilterOp(const py::dict &args) const { return DsOpPtr(); }
Status
DEPipeline
::
ParseTFReaderOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
)
{
// Required arguments
std
::
shared_ptr
<
TFReaderOp
::
Builder
>
builder
=
std
::
make_shared
<
TFReaderOp
::
Builder
>
();
...
...
mindspore/ccsrc/dataset/api/de_pipeline.h
浏览文件 @
98fbd30a
...
...
@@ -107,6 +107,8 @@ class DEPipeline {
Status
ParseMapOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
Status
ParseFilterOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
Status
ParseRepeatOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
Status
ParseSkipOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
...
...
@@ -121,8 +123,6 @@ class DEPipeline {
Status
ParseZipOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
DsOpPtr
ParseFilterOp
(
const
py
::
dict
&
args
)
const
;
Status
ParseDeviceQueueOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
Status
ParseTFReaderOp
(
const
py
::
dict
&
args
,
std
::
shared_ptr
<
DatasetOp
>
*
ptr
);
...
...
mindspore/ccsrc/dataset/core/client.h
浏览文件 @
98fbd30a
...
...
@@ -31,6 +31,7 @@
#include "dataset/engine/datasetops/map_op.h"
#include "dataset/engine/datasetops/project_op.h"
#include "dataset/engine/datasetops/rename_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/datasetops/shuffle_op.h"
...
...
mindspore/ccsrc/dataset/core/tensor.cc
浏览文件 @
98fbd30a
...
...
@@ -240,7 +240,7 @@ void Tensor::PrintItemAt(const std::vector<dsize_t> &index, std::ostream &out) c
DS_ASSERT
(
data_
);
switch
(
type_
.
value
())
{
CASE_PRINT_HEX
(
DataType
::
DE_BOOL
,
uint8_t
);
CASE_PRINT_HEX
(
DataType
::
DE_BOOL
,
bool
);
CASE_PRINT_HEX
(
DataType
::
DE_INT8
,
int8_t
);
...
...
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
浏览文件 @
98fbd30a
...
...
@@ -14,5 +14,6 @@ add_library(engine-datasetops OBJECT
take_op.cc
shuffle_op.cc
zip_op.cc
filter_op.cc
)
mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc
0 → 100644
浏览文件 @
98fbd30a
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/filter_op.h"
#include <algorithm>
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/kernels/tensor_op.h"
#include "utils/log_adapter.h"
#include "dataset/util/task_manager.h"
namespace
mindspore
{
namespace
dataset
{
// Validates the builder settings before a FilterOp is constructed.
// Every violated constraint contributes one line to the error text, so a single failure
// reports all problems at once.
// @return Status OK when both the connector size and the worker count are positive,
//     otherwise kUnexpectedError carrying the collected messages.
Status FilterOp::Builder::SanityCheck() {
  std::string problems;
  if (builder_op_connector_size_ <= 0) {
    problems += "connector size <= 0\n";
  }
  if (builder_num_workers_ <= 0) {
    problems += "filter num_parallel_workers <= 0\n";
  }
  if (problems.empty()) {
    return Status::OK();
  }
  return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(problems));
}
// Builder constructor. Seeds the worker count and the connector size from the global
// ConfigManager; the setters can override them afterwards.
FilterOp::Builder::Builder() {
  std::shared_ptr<ConfigManager> config = GlobalContext::config_manager();
  builder_num_workers_ = config->num_parallel_workers();
  builder_op_connector_size_ = config->op_connector_size();
}
// Creates the FilterOp from the values collected by the builder.
// @param ptr Output: receives the new FilterOp. Left untouched when the sanity check fails.
// @return Status The error code return.
Status FilterOp::Builder::Build(std::shared_ptr<FilterOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  std::shared_ptr<FilterOp> filter_op = std::make_shared<FilterOp>(
    std::move(build_in_col_names_), builder_num_workers_, builder_op_connector_size_, builder_predicate_func_);
  *ptr = filter_op;
  return Status::OK();
}
// Constructor of FilterOp.
// @param in_col_names Names of the columns handed to the predicate (copied into the op).
// @param num_workers Number of worker threads, forwarded to ParallelOp.
// @param op_queue_size Queue size, forwarded to ParallelOp.
// @param predicate_func Python callable used to decide whether a row is kept.
FilterOp::FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
                   py::function predicate_func)
    : ParallelOp(num_workers, op_queue_size),
      predicate_func_(std::move(predicate_func)),
      in_columns_(in_col_names) {}
// Main loop of the operator thread.
// Sets up one internal queue per worker, launches the workers and then runs Collector()
// on this thread until it returns.
// @return Status The error code return.
Status FilterOp::operator()() {
  // The operator class just starts off threads by calling the tree_ function.
  RETURN_UNEXPECTED_IF_NULL(tree_);
  // Synchronize with TaskManager.
  TaskManager::FindMe()->Post();

  // Internal worker -> collector queues; they are registered with the tree's task group.
  filter_queues_.Init(num_workers_, oc_queue_size_);
  RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks()));

  auto worker_entry = std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1);
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, worker_entry));

  RETURN_IF_NOT_OK(Collector());
  return Status::OK();
}
// EOF hook: intentionally a no-op that reports success. The worker id argument is unused.
Status FilterOp::EofReceived(int32_t) { return Status::OK(); }
// EOE hook: intentionally a no-op that reports success. The worker id argument is unused.
Status FilterOp::EoeReceived(int32_t) { return Status::OK(); }
// Checks that every requested input column is present in the buffer's column map.
// @param col_name_id_map Column name -> column index mapping of the current DataBuffer.
// @param input_columns Column names to verify; only read, never modified here.
// @return Status OK when all names are known, otherwise an unexpected-error Status naming
//     the first missing column.
Status FilterOp::ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map,
                                   std::vector<std::string> *input_columns) {
  for (const std::string &name : *input_columns) {
    if (col_name_id_map.count(name) == 0) {
      std::string err_msg = "input column name: " + name + " doesn't exist in the dataset columns.";
      RETURN_STATUS_UNEXPECTED(err_msg);
    }
  }
  return Status::OK();
}
// Debug printer.
// Emits the ParallelOp information first, then the FilterOp header and the configured
// input column names separated by single spaces.
// @param out The output stream to write to.
// @param show_all Passed through unchanged to the base class printer.
void FilterOp::Print(std::ostream &out, bool show_all) const {
  ParallelOp::Print(out, show_all);

  out << "\nFilterOp:";
  out << "\nInput column names:";
  for (const auto &column_name : in_columns_) {
    out << " " << column_name;
  }
}
// Worker thread body.
// Repeatedly fetches a DataBuffer from the first child, applies the predicate to its rows
// and pushes the result, tagged with a filterCtrl value, onto this worker's internal queue:
//   - kFilterEoe / kFilterEof : control buffers, forwarded untouched (EOF also ends the loop);
//   - kFilterEmpty            : no row passed the predicate;
//   - kFilterFull             : every row passed the predicate;
//   - kFilterPartial          : only some rows passed the predicate.
// Every queue push is checked, including the EOE/EOF pushes, so a failed push is returned
// to the caller instead of being dropped.
// @param worker_id The id assigned to this thread/worker upon creation.
// @return Status The error code return.
Status FilterOp::WorkerEntry(int32_t worker_id) {
  // Handshake with TaskManager that thread creation is successful.
  TaskManager::FindMe()->Post();
  std::unique_ptr<DataBuffer> in_buffer;
  bool worker_stop = false;
  while (worker_stop == false) {
    // Getting a databuffer to work on.
    RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id));
    if (in_buffer->eoe()) {
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)));
      continue;
    }
    if (in_buffer->eof()) {
      RETURN_IF_NOT_OK(
        filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)));
      worker_stop = true;
      continue;
    }

    // Work on a per-iteration copy of in_columns_: when it is empty, WorkerEntryInit fills the
    // copy with the buffer's column names, so the shared member is never written by a worker.
    std::vector<std::string> input_columns = in_columns_;

    // Indices of the columns to process.
    std::vector<size_t> to_process_indices;
    RETURN_IF_NOT_OK(WorkerEntryInit(in_buffer.get(), &to_process_indices, &input_columns));

    // Remember the row count before WorkerCompute pops the rows out of the buffer.
    int32_t num_rows = in_buffer->NumRows();
    std::unique_ptr<TensorQTable> new_tensor_table;
    RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), to_process_indices, &new_tensor_table));

    filterCtrl ctrl = filterCtrl::kFilterEmpty;
    if (!new_tensor_table->empty()) {
      ctrl = (new_tensor_table->size() == static_cast<size_t>(num_rows)) ? filterCtrl::kFilterFull
                                                                          : filterCtrl::kFilterPartial;
      // Hand the surviving rows back to the buffer before it is queued.
      in_buffer->set_tensor_table(std::move(new_tensor_table));
    }
    RETURN_IF_NOT_OK(filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), ctrl)));
  }
  return Status::OK();
}
// Applies the predicate to every row of a buffer and collects the rows that pass.
// Rows are popped from 'in_buffer' one by one, so the buffer is drained by this call.
// @param in_buffer Input data buffer.
// @param to_proess_indices Indices of the columns that are handed to the predicate.
// @param out Output: a freshly allocated table holding the rows for which the predicate is true.
// @return Status The error code return.
Status FilterOp::WorkerCompute(DataBuffer *in_buffer, const std::vector<size_t> &to_proess_indices,
                               std::unique_ptr<TensorQTable> *out) {
  *out = std::make_unique<TensorQTable>();
  int32_t row_count = in_buffer->NumRows();
  for (int32_t row_idx = 0; row_idx < row_count; row_idx++) {
    TensorRow predicate_args;
    TensorRow current_row;
    RETURN_IF_NOT_OK(in_buffer->PopRow(&current_row));

    // Pick the selected columns, in the requested order, as predicate arguments.
    for (const size_t &col_idx : to_proess_indices) {
      predicate_args.push_back(current_row[col_idx]);
    }

    bool keep_row = true;
    RETURN_IF_NOT_OK(InvokePredicateFunc(predicate_args, &keep_row));
    if (keep_row) {
      (*out)->push_back(std::move(current_row));
    }
  }
  return Status::OK();
}
// Moves buffers from the per-worker filter queues to out_connector_.
// Workers do not write to out_connector_ directly; the collector visits the worker queues in
// round-robin order and hands out the output slots in round-robin order as well, skipping
// buffers that were filtered out completely. Example with two worker queues:
//   filter_queues_:
//     queue1: DB(data1 kFilterEmpty) DB(eoe)                DB(data4) DB(eof)
//     queue2: DB(data2)              DB(data3 kFilterEmpty) DB(eoe)
//   out_connector_ after reordering:
//     queue1: DB(data2) DB(data4) DB(eof)
//     queue2: DB(eoe)   DB(eoe)
// @return Status The error code return.
Status FilterOp::Collector() {
  uint64_t task_id_cnt = 0;  // number of input buffers consumed so far; selects the worker queue
  uint64_t out_id_cnt = 0;   // number of buffers forwarded so far; selects the output slot
  std::pair<std::unique_ptr<DataBuffer>, filterCtrl> in_pair;
  bool collector_stop = false;
  while (!collector_stop) {
    uint32_t w_id = task_id_cnt % num_workers_;
    RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair));
    switch (in_pair.second) {
      case filterCtrl::kFilterFull:
      case filterCtrl::kFilterPartial:
      case filterCtrl::kFilterEoe: {
        uint32_t out_task_id = out_id_cnt % num_workers_;
        RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first)));
        out_id_cnt++;
        task_id_cnt++;
        break;
      }
      case filterCtrl::kFilterEof: {
        // EOF is forwarded once and ends the collector loop.
        uint32_t out_task_id = out_id_cnt % num_workers_;
        RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first)));
        collector_stop = true;
        break;
      }
      default: {
        // kFilterEmpty: drop the buffer, only advance to the next worker queue.
        task_id_cnt++;
        break;
      }
    }
  }
  return Status::OK();
}
// Prepares the per-buffer column selection used by WorkerEntry().
// @param in_buf The buffer about to be filtered; must contain at least one row and one column.
// @param to_process_indices Output: column indices matching 'input_columns', in the same order.
// @param input_columns In/out: the column names to feed to the predicate. When empty on entry
//     it is filled with every column name of the buffer, ordered by column index.
// @return Status The error code return.
Status FilterOp::WorkerEntryInit(const DataBuffer *in_buf, std::vector<size_t> *to_process_indices,
                                 std::vector<std::string> *input_columns) {
  int32_t row_count = in_buf->NumRows();
  int32_t col_count = in_buf->NumCols();
  if (row_count == 0 || col_count == 0) {
    RETURN_STATUS_UNEXPECTED("FilterOp is getting an empty DataBuffer.");
  }

  std::unordered_map<std::string, int32_t> col_name_id_map = in_buf->column_name_map();
  // Check if there is invalid column name in the inColumns.
  RETURN_IF_NOT_OK(ValidateInColumns(col_name_id_map, input_columns));

  if (input_columns->empty()) {
    MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table.";
    // The map is unordered, so order the (name, index) pairs by column index first.
    std::vector<std::pair<std::string, int32_t>> by_index(col_name_id_map.begin(), col_name_id_map.end());
    std::sort(by_index.begin(), by_index.end(),
              [](const std::pair<std::string, int32_t> &lhs, const std::pair<std::string, int32_t> &rhs) {
                return lhs.second < rhs.second;
              });
    for (const auto &name_and_index : by_index) {
      input_columns->push_back(name_and_index.first);
    }
  }

  // Translate the column names into column indices.
  for (const auto &column_name : *input_columns) {
    to_process_indices->push_back(static_cast<size_t>(col_name_id_map[column_name]));
  }
  return Status::OK();
}
// Verifies that none of the tensors in a row is a null pointer.
// @param input The row about to be passed to the predicate.
// @return Status OK when every entry is non-null, otherwise an unexpected-error Status.
Status FilterOp::CheckInput(const TensorRow &input) const {
  const bool has_null =
    std::any_of(input.begin(), input.end(), [](const auto &tensor) { return tensor == nullptr; });
  if (has_null) {
    RETURN_STATUS_UNEXPECTED("input is null.");
  }
  return Status::OK();
}
// Calls the Python predicate on one row.
// Each tensor of 'input' is converted to a numpy array and the arrays are passed to the
// predicate as positional arguments; the return value is converted to bool.
// @param input The tensors to hand to the predicate; none of them may be null.
// @param out_predicate Output: the boolean result of the predicate.
// @return Status kPythonInterpreterFailure when the interpreter is no longer initialized,
//     kPyFuncException when the call or the conversion raises a Python error, the failing
//     Status of the input check or numpy conversion, otherwise an OK status.
Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate) {
  RETURN_IF_NOT_OK(CheckInput(input));
  // Acquire Python GIL.
  py::gil_scoped_acquire gil_acquire;
  if (Py_IsInitialized() == 0) {
    return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
  }
  try {
    // One numpy array per input tensor, packed into the argument tuple.
    py::tuple input_args(input.size());
    for (size_t arg_idx = 0; arg_idx < input.size(); arg_idx++) {
      py::array as_numpy;
      RETURN_IF_NOT_OK(input.at(arg_idx)->GetDataAsNumpy(&as_numpy));
      input_args[arg_idx] = as_numpy;
    }
    // Invoke python function.
    py::object ret_py_obj = predicate_func_(*input_args);
    *out_predicate = ret_py_obj.cast<py::bool_>();
  } catch (const py::error_already_set &e) {
    std::stringstream ss;
    ss << e.what() << std::endl;
    ss << "The type of the return value of python predicate function is not bool, or can not be convert to bool.";
    return Status(StatusCode::kPyFuncException, ss.str());
  }
  return Status(StatusCode::kOK, "FilterOp predicate func call succeed");
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/filter_op.h
0 → 100644
浏览文件 @
98fbd30a
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
#define DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/queue.h"
namespace
mindspore
{
namespace
dataset
{
// Dataset operator that keeps only the rows for which a Python predicate returns true.
class FilterOp : public ParallelOp {
 public:
  // Helper that gathers the constructor arguments of a FilterOp. Set each argument through
  // the setter methods, then call Build() to create the operator.
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args.
    // @return This is a constructor.
    Builder();

    // Default destructor
    ~Builder() = default;

    // Sets the Python predicate.
    // @param func Python callable used as the predicate.
    // @return Reference to this builder, so that setters can be chained.
    Builder &SetPredicateFunc(py::function func) {
      builder_predicate_func_ = std::move(func);
      return *this;
    }

    // Sets the names of the columns that are passed to the predicate.
    // @param in_col_names Input column names.
    // @return Reference to this builder, so that setters can be chained.
    Builder &SetInColNames(const std::vector<std::string> &in_col_names) {
      build_in_col_names_ = in_col_names;
      return *this;
    }

    // Sets the number of worker threads.
    // @param num_workers Number of workers.
    // @return Reference to this builder, so that setters can be chained.
    Builder &SetNumWorkers(int32_t num_workers) {
      builder_num_workers_ = num_workers;
      return *this;
    }

    // Sets the connector size.
    // @param connector_size Size of each connector queue.
    // @return Reference to this builder, so that setters can be chained.
    Builder &SetOpConnectorSize(int32_t connector_size) {
      builder_op_connector_size_ = connector_size;
      return *this;
    }

    // The builder "build" method creates the final object.
    // @param ptr The shared_ptr to the new FilterOp object.
    // @return Status.
    Status Build(std::shared_ptr<FilterOp> *ptr);

   private:
    // Sanity check for builder class args.
    // @return Status - The error code return.
    Status SanityCheck();

    std::vector<std::string> build_in_col_names_;
    py::function builder_predicate_func_;
    int32_t builder_num_workers_;
    int32_t builder_op_connector_size_;
  };

  // Tag attached to every buffer placed on the internal filter queues.
  enum filterCtrl : int8_t {
    kFilterEmpty = 0,    // no row of the buffer passed the predicate
    kFilterPartial = 1,  // some rows of the buffer passed the predicate
    kFilterFull = 2,     // every row of the buffer passed the predicate
    kFilterEoe = 3,      // the buffer is an EOE control buffer
    kFilterEof = 4       // the buffer is an EOF control buffer
  };

  // Constructor of FilterOp
  // @note The builder class should be used to call it.
  // @param in_col_names A list of input column names; when it is empty the predicate is
  //     applied to all columns in the dataset.
  // @param num_workers The number of worker threads.
  // @param op_queue_size The size of each queue in the connector.
  // @param predicate_func Python callable which returns a boolean value.
  FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
           py::function predicate_func);

  // Class functor operator () override.
  // All dataset ops operate by launching a thread (see ExecutionTree). This class functor
  // provides the master loop that drives the logic for performing the work.
  // @return Status The error code return
  Status operator()() override;

  // @param int32_t workerId.
  // @return Status - The error code return.
  Status EofReceived(int32_t) override;

  // @param int32_t workerId.
  // @return Status - The error code return.
  Status EoeReceived(int32_t) override;

  // A print method typically used for debugging.
  // @param out The output stream to write output to.
  // @param show_all A bool to control if you want to show all info or just a summary.
  void Print(std::ostream &out, bool show_all) const override;

 private:
  // Python callable which returns a boolean value.
  py::function predicate_func_;

  // Names of the columns that are fed to the predicate function.
  std::vector<std::string> in_columns_;

  // Internal queues for filter: pairs of a buffer and its filterCtrl tag.
  QueueList<std::pair<std::unique_ptr<DataBuffer>, filterCtrl>> filter_queues_;

  // Private function for worker/thread to loop continuously. It comprises the main
  // logic of FilterOp: getting the data from the previous op, validating the user specified
  // column names, applying the predicate to each row and dropping the rows for which the
  // predicate result is false.
  // @param worker_id The id assigned to this thread/worker upon creation.
  // @return Status The error code return.
  Status WorkerEntry(int32_t worker_id) override;

  // Filters the rows of a buffer with the predicate function.
  // @param in_buffer Input data buffer.
  // @param to_proess_indices Indices of columns to be processed.
  // @param out Table of the rows that passed the predicate.
  // @return Status The error code return.
  Status WorkerCompute(DataBuffer *in_buffer, const std::vector<size_t> &to_proess_indices,
                       std::unique_ptr<TensorQTable> *out);

  // Collects the buffers produced by the workers.
  // @return Status The error code return.
  Status Collector();

  // @param input Tensor vector.
  // @return Status - The error code return.
  Status CheckInput(const TensorRow &input) const;

  // Invokes the Python predicate.
  // @param input Tensor vector.
  // @param out_predicate The result of the predicate.
  // @return Status - The error code return.
  Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate);

  // Private function for validating if each of the user specified input column names
  // exist in the DataBuffer.
  // @param col_name_id_map The column name to index mapping obtained from DataBuffer.
  // @param input_columns The vector of input column names used in the current thread.
  // @return Status The error code return.
  Status ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map,
                           std::vector<std::string> *input_columns);

  // Private function that initializes some internal data structures used by WorkerEntry().
  // @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function
  //     does not manage memory and is not shared with other threads.
  // @param[out] to_process_indices Indices of columns that will be fed to the predicate.
  // @param input_columns The vector of input column names used in the current thread.
  Status WorkerEntryInit(const DataBuffer *in_buf, std::vector<size_t> *to_process_indices,
                         std::vector<std::string> *input_columns);
};
}
// namespace dataset
}
// namespace mindspore
#endif
mindspore/dataset/engine/datasets.py
浏览文件 @
98fbd30a
...
...
@@ -35,7 +35,7 @@ from mindspore._c_expression import typing
from
mindspore
import
log
as
logger
from
.
import
samplers
from
.iterators
import
DictIterator
,
TupleIterator
from
.validators
import
check
,
check_batch
,
check_shuffle
,
check_map
,
check_repeat
,
check_skip
,
check_zip
,
check_rename
,
\
from
.validators
import
check
,
check_batch
,
check_shuffle
,
check_map
,
check_
filter
,
check_
repeat
,
check_skip
,
check_zip
,
check_rename
,
\
check_take
,
check_project
,
check_imagefolderdatasetv2
,
check_mnist_cifar_dataset
,
check_manifestdataset
,
\
check_tfrecorddataset
,
check_vocdataset
,
check_celebadataset
,
check_minddataset
,
check_generatordataset
,
\
check_zip_dataset
,
check_add_column
,
check_textfiledataset
...
...
@@ -385,6 +385,32 @@ class Dataset:
"""
return
MapDataset
(
self
,
input_columns
,
operations
,
output_columns
,
columns_order
,
num_parallel_workers
)
@check_filter
def filter(self, predicate, input_columns=None, num_parallel_workers=1):
    """
    Filter dataset by predicate.

    Note:
        If input_columns is not provided or is empty, all columns will be used.

    Args:
        predicate: Python callable which returns a boolean value. Rows for which it
            returns a false value are dropped.
        input_columns (list[str], optional): List of names of the input columns. When
            it is None (the default), the predicate is applied to all columns in the
            dataset.
        num_parallel_workers (int, optional): Number of workers to process the Dataset
            in parallel (default=1).

    Returns:
        FilterDataset, the filtered dataset.

    Examples:
        >>> import mindspore.dataset as ds
        >>> # generator data(0 ~ 63)
        >>> # keep the rows whose data is less than 11
        >>> dataset_f = dataset.filter(predicate=lambda data: data < 11, input_columns = ["data"])
    """
    return FilterDataset(self, predicate, input_columns=input_columns,
                         num_parallel_workers=num_parallel_workers)
@
check_repeat
def
repeat
(
self
,
count
=
None
):
"""
...
...
@@ -1105,6 +1131,44 @@ class MapDataset(DatasetOp):
return
self
.
input
[
0
].
get_dataset_size
()
class FilterDataset(DatasetOp):
    """
    The result of applying a filter predicate to the input Dataset.

    Args:
        input_dataset: Input Dataset to be filtered.
        predicate: Python callable which returns a boolean value.
        input_columns (list[str], optional): List of names of the input columns. When
            it is None (the default), the predicate is applied to all columns in the
            dataset. A single non-list value is wrapped into a one-element list.
        num_parallel_workers (int, optional): Number of workers to process the Dataset
            in parallel (default=None).
    """

    def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
        super().__init__(num_parallel_workers)
        # Wrap the user callable so that whatever it returns is coerced to a plain bool.
        self.predicate = lambda *args: bool(predicate(*args))
        # Link this node into the dataset graph in both directions.
        self.input.append(input_dataset)
        input_dataset.output.append(self)
        if input_columns is None or isinstance(input_columns, list):
            self.input_columns = input_columns
        else:
            self.input_columns = [input_columns]

    def get_args(self):
        """Return the base class arguments extended with 'predicate' and 'input_columns'."""
        args = super().get_args()
        args["predicate"] = self.predicate
        args["input_columns"] = self.input_columns
        return args

    def get_dataset_size(self):
        """
        Get the number of batches in an epoch.

        The size cannot be determined before the pipeline is run.

        Return:
            0
        """
        return 0
class
RepeatDataset
(
DatasetOp
):
"""
The result of applying Repeat operator to the input Dataset.
...
...
mindspore/dataset/engine/iterators.py
浏览文件 @
98fbd30a
...
...
@@ -129,6 +129,8 @@ class Iterator:
op_type
=
OpName
.
ZIP
elif
isinstance
(
dataset
,
de
.
MapDataset
):
op_type
=
OpName
.
MAP
elif
isinstance
(
dataset
,
de
.
FilterDataset
):
op_type
=
OpName
.
FILTER
elif
isinstance
(
dataset
,
de
.
RepeatDataset
):
op_type
=
OpName
.
REPEAT
elif
isinstance
(
dataset
,
de
.
SkipDataset
):
...
...
mindspore/dataset/engine/validators.py
浏览文件 @
98fbd30a
...
...
@@ -693,6 +693,26 @@ def check_map(method):
return
new_method
def check_filter(method):
    """Check the input arguments of filter."""

    @wraps(method)
    def new_method(*args, **kwargs):
        param_dict = make_param_dict(method, args, kwargs)

        # The predicate is mandatory and has to be callable.
        if not callable(param_dict.get("predicate")):
            raise ValueError("Predicate should be a python function or a callable python object.")

        check_param_type(['num_parallel_workers'], param_dict, int)

        # input_columns is optional; it is only validated when it was supplied.
        input_columns = param_dict.get("input_columns")
        if input_columns is not None:
            check_columns(input_columns, "input_columns")

        return method(*args, **kwargs)

    return new_method
def
check_repeat
(
method
):
"""check the input arguments of repeat."""
@
wraps
(
method
)
...
...
tests/ut/cpp/dataset/CMakeLists.txt
浏览文件 @
98fbd30a
...
...
@@ -66,6 +66,8 @@ SET(DE_UT_SRCS
celeba_op_test.cc
take_op_test.cc
text_file_op_test.cc
)
filter_op_test.cc
)
add_executable
(
de_ut_tests
${
DE_UT_SRCS
}
)
...
...
tests/ut/cpp/dataset/filter_op_test.cc
0 → 100644
浏览文件 @
98fbd30a
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/circular_pool.h"
#include "dataset/core/client.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
using
namespace
mindspore
::
dataset
;
namespace
de
=
mindspore
::
dataset
;
using
mindspore
::
MsLogLevel
::
INFO
;
using
mindspore
::
ExceptionType
::
NoExceptionType
;
using
mindspore
::
LogStream
;
// Test fixture for the FilterOp unit tests; adds nothing on top of UT::DatasetOpTesting.
class MindDataTestfilter_op : public UT::DatasetOpTesting {};
// Builds a FilterOp from a default-configured builder and expects the build to succeed.
// @return The built operator (whatever Build() stored in the output pointer).
std::shared_ptr<de::FilterOp> Filter() {
  std::shared_ptr<de::FilterOp> filter_op;
  Status rc = de::FilterOp::Builder().Build(&filter_op);
  EXPECT_TRUE(rc.IsOk());
  return filter_op;
}
// Builds two FilterOps with default settings and associates both with an execution tree.
TEST_F(MindDataTestfilter_op, Testfilter_opFuntions) {
  MS_LOG(INFO) << "Doing MindDataTest filter_op.";
  auto my_tree = std::make_shared<ExecutionTree>();

  std::shared_ptr<DatasetOp> parent_op = Filter();
  std::shared_ptr<DatasetOp> leaf_op = Filter();

  // Verify the operators before they are handed to the tree: a null check that runs only
  // after the pointers have already been used cannot protect the calls below.
  ASSERT_NE(parent_op, nullptr);
  ASSERT_NE(leaf_op, nullptr);

  my_tree->AssociateNode(parent_op);
  my_tree->AssociateNode(leaf_op);
}
tests/ut/cpp/dataset/tensor_test.cc
浏览文件 @
98fbd30a
...
...
@@ -158,6 +158,16 @@ TEST_F(MindDataTestTensorDE, InsertTensor) {
ASSERT_EQ
(
*
t
==
*
t6
,
true
);
}
// Regression test: Tensor::ToString() used to fail for tensors that store bool
// values. The printed text must not contain the type-mismatch error message.
TEST_F(MindDataTestTensorDE, BoolTensor) {
  auto t = std::make_shared<Tensor>(TensorShape({2}), DataType(DataType::DE_BOOL));
  t->SetItemAt<bool>({0}, true);
  t->SetItemAt<bool>({1}, true);

  const std::string out = t->ToString();
  const auto error_pos = out.find("Template type and Tensor type are not compatible");
  ASSERT_TRUE(error_pos == std::string::npos);
}
TEST_F
(
MindDataTestTensorDE
,
GetItemAt
)
{
std
::
shared_ptr
<
Tensor
>
t
=
std
::
make_shared
<
Tensor
>
(
TensorShape
({
2
,
2
}),
DataType
(
DataType
::
DE_UINT8
));
t
->
Fill
<
uint8_t
>
(
254
);
...
...
tests/ut/data/dataset/declient_filter.cfg
0 → 100644
浏览文件 @
98fbd30a
{
"rowsPerBuffer": 10,
}
tests/ut/python/dataset/test_filterop.py
0 → 100644
浏览文件 @
98fbd30a
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
numpy
as
np
import
mindspore.dataset
as
ds
import
mindspore.dataset.transforms.vision.c_transforms
as
cde
import
mindspore.dataset.transforms.c_transforms
as
C
import
mindspore.common.dtype
as
mstype
from
mindspore
import
log
as
logger
DATA_DIR
=
[
"../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"
]
SCHEMA_DIR
=
"../data/dataset/test_tf_file_3_images/datasetSchema.json"
# test for predicate
def test_diff_predicate_func():
    """Run the same TFRecord pipeline with five spellings of `label == 3`.

    Every variant must leave exactly one row, and that row's label must be 3.
    """

    def test_filter(predicate_func):
        # Decode and resize the image column, then filter on (image, label).
        transforms = [cde.Decode(), cde.Resize([64, 64])]
        type_cast_op = C.TypeCast(mstype.int32)
        source = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], shuffle=False)
        mapped = source.map(input_columns=["image"], operations=transforms, num_parallel_workers=1)
        filtered = mapped.filter(input_columns=["image", "label"], predicate=predicate_func,
                                 num_parallel_workers=4)

        label_list = []
        for row in filtered.create_dict_iterator():
            ori_img = row["image"]
            label_list.append(row["label"])

        assert len(label_list) == 1
        assert label_list[0] == 3

    predicates = (
        lambda image, label: label == 3,
        lambda image, label: label[0] == 3,
        lambda image, label: label == [3],
        lambda image, label: label == np.array([3]),
        lambda image, label: label == np.array(3),
    )
    for predicate in predicates:
        test_filter(predicate)
def filter_func_ge(data):
    """Predicate: keep a row unless its value is greater than 10."""
    return not data > 10
def generator_1d():
    """Yield 64 single-column rows; row i is the 1-tuple (np.array(i),)."""
    yield from ((np.array(value),) for value in range(64))
# test with GeneratorDataset
def test_filter_by_generator_with_no():
    """Filter a plain GeneratorDataset with `data < 11` and check the surviving rows.

    The row count is asserted explicitly: comparing rows only inside the loop
    would let the test pass if the filter returned nothing at all.
    """
    dataset = ds.GeneratorDataset(generator_1d, ["data"])
    dataset_f = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
    expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    num_iter = 0
    for item in dataset_f.create_dict_iterator():
        assert item["data"] == expected_rs[num_iter]
        num_iter += 1
    assert num_iter == len(expected_rs)
# test with repeatOp before
def test_filter_by_generator_with_repeat():
    """Repeat the generator 4 times, then filter; expect 0..10 four times in a row."""
    repeated = ds.GeneratorDataset(generator_1d, ["data"]).repeat(4)
    filtered = repeated.filter(predicate=filter_func_ge, num_parallel_workers=4)

    ret_data = [item["data"] for item in filtered.create_dict_iterator()]
    assert len(ret_data) == 44

    expected_rs = list(range(11))
    period = len(expected_rs)
    for index in range(4 * period):
        assert ret_data[index] == expected_rs[index % period]
# test with repeatOp after
def test_filter_by_generator_with_repeat_after():
    """Filter first, then repeat 4 times; expect 0..10 four times in a row."""
    filtered = ds.GeneratorDataset(generator_1d, ["data"]).filter(predicate=filter_func_ge,
                                                                  num_parallel_workers=4)
    repeated = filtered.repeat(4)

    ret_data = [item["data"] for item in repeated.create_dict_iterator()]
    assert len(ret_data) == 44

    expected_rs = list(range(11))
    period = len(expected_rs)
    for index in range(4 * period):
        assert ret_data[index] == expected_rs[index % period]
def filter_func_batch(data):
    """Predicate: drop a row when its first element is greater than 8."""
    return False if data[0] > 8 else True
def filter_func_batch_after(data):
    """Predicate: keep a row unless its value is greater than 20."""
    keep = True
    if data > 20:
        keep = False
    return keep
# test with batchOp before
def test_filter_by_generator_with_batch():
    """Batch by 4, then filter with filter_func_batch; three batches starting at 0, 4, 8 remain."""
    batched = ds.GeneratorDataset(generator_1d, ["data"]).batch(4)
    filtered = batched.filter(predicate=filter_func_batch, num_parallel_workers=4)

    ret_data = [item["data"] for item in filtered.create_dict_iterator()]
    assert len(ret_data) == 3
    for position, first_value in enumerate((0, 4, 8)):
        assert ret_data[position][0] == first_value
# test with batchOp after
def test_filter_by_generator_with_batch_after():
    """Filter with filter_func_batch_after, then batch by 4; six batches are expected."""
    filtered = ds.GeneratorDataset(generator_1d, ["data"]).filter(predicate=filter_func_batch_after,
                                                                  num_parallel_workers=4)
    batched = filtered.batch(4)

    ret_data = [item["data"] for item in batched.create_dict_iterator()]
    assert len(ret_data) == 6
    # Spot-check the first element of the first, second and last batch.
    for position, first_value in ((0, 0), (1, 4), (5, 20)):
        assert ret_data[position][0] == first_value
def filter_func_shuffle(data):
    """Predicate: keep a row unless its value is greater than 20."""
    return not data > 20
# test with shuffleOp before
def test_filter_by_generator_with_shuffle():
    """Shuffle with buffer size 4, then filter; 21 rows must remain."""
    shuffled = ds.GeneratorDataset(generator_1d, ["data"]).shuffle(4)
    filtered = shuffled.filter(predicate=filter_func_shuffle, num_parallel_workers=4)

    num_iter = sum(1 for _ in filtered.create_dict_iterator())
    assert num_iter == 21
def filter_func_shuffle_after(data):
    """Predicate: drop a row when its value is greater than 20."""
    return False if data > 20 else True
# test with shuffleOp after
def test_filter_by_generator_with_shuffle_after():
    """Filter first, then shuffle with buffer size 4; 21 rows must remain."""
    filtered = ds.GeneratorDataset(generator_1d, ["data"]).filter(predicate=filter_func_shuffle_after,
                                                                  num_parallel_workers=4)
    shuffled = filtered.shuffle(4)

    num_iter = sum(1 for _ in shuffled.create_dict_iterator())
    assert num_iter == 21
def generator_1d_zip1():
    """Yield 64 single-column rows; row i is (np.array(i),)."""
    value = 0
    while value < 64:
        yield (np.array(value),)
        value += 1
def generator_1d_zip2():
    """Yield 64 single-column rows; row i is (np.array(i + 100),)."""
    yield from ((np.array(value),) for value in range(100, 164))
def filter_func_zip(data1, data2):
    """Predicate over two columns: keep the row unless data1 is greater than 20 (data2 is ignored)."""
    return not data1 > 20
def filter_func_zip_after(data1):
    """Predicate: drop a row when data1 is greater than 20."""
    return False if data1 > 20 else True
# test with zipOp before
def test_filter_by_generator_with_zip():
    """Zip two generator datasets, then filter on both columns; 21 rows must remain."""
    left = ds.GeneratorDataset(generator_1d_zip1, ["data1"])
    right = ds.GeneratorDataset(generator_1d_zip2, ["data2"])
    zipped = ds.zip((left, right))
    filtered = zipped.filter(predicate=filter_func_zip, num_parallel_workers=1)

    ret_data = [{"data1": item["data1"], "data2": item["data2"]}
                for item in filtered.create_dict_iterator()]
    assert len(ret_data) == 21

    # (row index, column, expected value)
    for row, column, expected in ((0, "data1", 0), (0, "data2", 100), (5, "data1", 5), (5, "data2", 105)):
        assert ret_data[row][column] == expected
# test with zipOp after
def test_filter_by_generator_with_zip_after():
    """Filter two generator datasets separately, then zip them; 21 rows must remain.

    Both inputs are built from generator_1d_zip1, so the two columns carry equal values.
    """
    left = ds.GeneratorDataset(generator_1d_zip1, ["data1"])
    right = ds.GeneratorDataset(generator_1d_zip1, ["data2"])
    left_filtered = left.filter(predicate=filter_func_zip_after, num_parallel_workers=4)
    right_filtered = right.filter(predicate=filter_func_zip_after, num_parallel_workers=4)
    zipped = ds.zip((left_filtered, right_filtered))

    ret_data = [{"data1": item["data1"], "data2": item["data2"]}
                for item in zipped.create_dict_iterator()]
    assert len(ret_data) == 21

    # (row index, column, expected value)
    for row, column, expected in ((0, "data1", 0), (0, "data2", 0), (5, "data1", 5), (5, "data2", 5)):
        assert ret_data[row][column] == expected
def filter_func_map(col1, col2):
    """Predicate over two columns: keep the row when col1[0] is greater than 8 (col2 is ignored)."""
    return True if col1[0] > 8 else False
def filter_func_map_part(col1):
    """Predicate: keep the row when col1 is less than 3."""
    return bool(col1 < 3)
def filter_func_map_all(col1, col2):
    """Predicate that accepts every row, whatever the two column values are."""
    return True
def generator_mc(maxid=20):
    """Yield `maxid` two-column rows.

    Row i is (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])).
    """
    for row_id in range(maxid):
        first_col = np.array([row_id])
        second_col = np.array([[row_id, row_id + 1], [row_id + 2, row_id + 3]])
        yield (first_col, second_col)
def func_map(data_col1, data_col2):
    """Identity map operation over two columns: return both inputs as a 2-tuple."""
    pair = data_col1, data_col2
    return pair
def func_map_part(data_col1):
    """Identity map operation over one column: return the input object itself, not a tuple."""
    return data_col1
# test with map
def test_filter_by_generator_with_map_all_col():
    """Map col1 onto itself, then filter on col1 with filter_func_map_part; three rows remain."""
    source = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"])
    mapped = source.map(input_columns=["col1"], output_columns=["col1"], operations=func_map_part)
    filtered = mapped.filter(input_columns=["col1"], predicate=filter_func_map_part,
                             num_parallel_workers=1)

    ret_data = [item["col1"] for item in filtered.create_dict_iterator()]
    assert len(ret_data) == 3
    assert ret_data[0] == 0
    assert ret_data[1] == 1
# test with map
def test_filter_by_generator_with_map_part_col():
    """Map col1 to out1, then filter on (out1, col2) with filter_func_map; three rows remain."""
    source = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"])
    mapped = source.map(input_columns=["col1"], output_columns=["out1"], operations=func_map_part)
    filtered = mapped.filter(input_columns=["out1", "col2"], predicate=filter_func_map,
                             num_parallel_workers=4)

    ret_data = []
    for row in filtered.create_dict_iterator():
        print(row)
        ret_data.append(row["out1"])

    assert len(ret_data) == 3
    assert ret_data[0] == 9
    assert ret_data[2] == 11
def filter_func_rename(data):
    """Predicate: keep the row when its value is greater than 8."""
    return True if data > 8 else False
# test with rename before
def test_filter_by_generator_with_rename():
    """Rename data to col1, then filter with filter_func_rename; rows 9..63 (55 rows) remain."""
    renamed = ds.GeneratorDataset(generator_1d, ["data"]).rename(input_columns=["data"],
                                                                 output_columns=["col1"])
    filtered = renamed.filter(predicate=filter_func_rename, num_parallel_workers=4)

    ret_data = [item["col1"] for item in filtered.create_dict_iterator()]
    assert len(ret_data) == 55
    assert ret_data[0] == 9
    assert ret_data[54] == 63
#test input_column
def filter_func_input_column1(col1, col2):
    """Predicate over two columns: keep the row when col1[0] is less than 8 (col2 is ignored)."""
    return bool(col1[0] < 8)
def filter_func_input_column2(col1):
    """Predicate: keep the row when col1[0] is less than 8."""
    return True if col1[0] < 8 else False
def filter_func_input_column3(col1):
    """Predicate that accepts every row regardless of col1."""
    return True
# test with input_columns
def test_filter_by_generator_with_input_column():
    """Chain four filters with different input_columns settings; eight rows (0..7) remain."""
    source = ds.GeneratorDataset(generator_mc(64), ["col1", "col2"])
    pipeline = source.map(input_columns=["col1"], output_columns=["out1"], operations=func_map_part)

    # Each stage: (filter keyword arguments, predicate). The last stage passes no input_columns.
    stages = (
        ({"input_columns": ["out1", "col2"]}, filter_func_input_column1),
        ({"input_columns": ["out1"]}, filter_func_input_column2),
        ({"input_columns": ["col2"]}, filter_func_input_column3),
        ({}, filter_func_input_column1),
    )
    for column_kwargs, predicate in stages:
        pipeline = pipeline.filter(predicate=predicate, num_parallel_workers=4, **column_kwargs)

    ret_data = [item["out1"] for item in pipeline.create_dict_iterator()]
    assert len(ret_data) == 8
    assert ret_data[0] == 0
    assert ret_data[7] == 7
#test kFilterPartial
def generator_mc_p0(maxid=20):
    """Yield `maxid` two-column rows; row i is (np.array([i]), np.array([i + 100]))."""
    yield from ((np.array([row_id]), np.array([row_id + 100])) for row_id in range(maxid))
def generator_mc_p1(maxid=20):
    """Yield `maxid` two-column rows; row i is (np.array([i + 200]), np.array([i + 300]))."""
    yield from ((np.array([row_id + 200]), np.array([row_id + 300])) for row_id in range(maxid))
def filter_func_Partial_0(col1, col2, col3, col4):
    """Predicate over four columns: drop the row when col1[0] is one of 0, 1, 2, 3, 4, 11.

    Only col1 is inspected; the other columns are ignored.
    """
    return col1[0] not in [0, 1, 2, 3, 4, 11]
# test with row_data_buffer > 1
def test_filter_by_generator_Partial0():
    """Load the filter test config, zip two generator datasets and filter on col1."""
    ds.config.load('../data/dataset/declient_filter.cfg')
    left = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
    right = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])
    zipped = ds.zip((left, right))
    filtered = zipped.filter(predicate=filter_func_Partial_0, num_parallel_workers=2)

    ret = [item["col1"] for item in filtered.create_dict_iterator()]
    assert ret[0] == 5
    assert ret[6] == 12
# test with row_data_buffer > 1
def test_filter_by_generator_Partial1():
    """Same pipeline as the zip-then-filter case, followed by a map that adds 400 to col1."""
    ds.config.load('../data/dataset/declient_filter.cfg')
    left = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
    right = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])
    zipped = ds.zip((left, right))
    filtered = zipped.filter(predicate=filter_func_Partial_0, num_parallel_workers=2)
    mapped = filtered.map(input_columns=["col1"], output_columns=["out1"],
                          operations=lambda x1: x1 + 400)

    ret = [item["out1"] for item in mapped.create_dict_iterator()]
    assert ret[0] == 405
    assert ret[6] == 412
# test with row_data_buffer > 1
def test_filter_by_generator_Partial2():
    """Filter two generator datasets independently, zip them, then map both key columns."""
    ds.config.load('../data/dataset/declient_filter.cfg')
    left = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
    right = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])

    left_filtered = left.filter(input_columns=["col1"], predicate=lambda x: x not in [3, 7, 9],
                                num_parallel_workers=2)
    right_filtered = right.filter(input_columns=["col3"], predicate=lambda x: x not in [203, 207, 209],
                                  num_parallel_workers=2)
    zipped = ds.zip((left_filtered, right_filtered))
    mapped = zipped.map(input_columns=["col1", "col3"], output_columns=["out1", "out3"],
                        operations=lambda x1, x3: (x1 + 400, x3 + 500))

    ret1, ret3 = [], []
    for row in mapped.create_dict_iterator():
        ret1.append(row["out1"])
        ret3.append(row["out3"])

    assert ret1[0] == 400
    assert ret1[6] == 408
    assert ret3[0] == 700
    assert ret3[6] == 708
def filter_func_Partial(col1, col2):
    """Predicate over two columns: keep the row when col1[0] is divisible by 3 (col2 is ignored)."""
    return bool(col1[0] % 3 == 0)
def generator_big(maxid=20):
    """Yield `maxid` two-column rows.

    Row i is (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])).
    """
    row_id = 0
    while row_id < maxid:
        yield (np.array([row_id]), np.array([[row_id, row_id + 1], [row_id + 2, row_id + 3]]))
        row_id += 1
# test with row_data_buffer > 1
def test_filter_by_generator_Partial():
    """Shuffle 99 generated rows, keep those whose col1 is divisible by 3, and verify them.

    The row count is asserted as well: with only the in-loop assertion the test
    would pass even if the filter produced no rows at all.
    """
    ds.config.load('../data/dataset/declient_filter.cfg')
    dataset = ds.GeneratorDataset(source=generator_mc(99), column_names=["col1", "col2"])
    dataset_s = dataset.shuffle(4)
    dataset_f1 = dataset_s.filter(input_columns=["col1", "col2"], predicate=filter_func_Partial,
                                  num_parallel_workers=1)

    num_iter = 0
    for item in dataset_f1.create_dict_iterator():
        assert item["col1"] % 3 == 0
        num_iter += 1
    # generator_mc(99) emits col1 = 0..98; 33 of those values are multiples of 3.
    assert num_iter == 33
def filter_func_cifar(col1, col2):
    """Predicate over two columns: keep the row when col2 is divisible by 3 (col1 is ignored)."""
    return True if col2 % 3 == 0 else False
# test with cifar10
def test_filte_case_dataset_cifar10():
    """Filter the Cifar10 test data on its label column and check every surviving label.

    The unused `num_iter` local from the previous version has been dropped; it was
    assigned once and never read or incremented.
    """
    DATA_DIR_10 = "../data/dataset/testCifar10Data"
    ds.config.load('../data/dataset/declient_filter.cfg')
    dataset_c = ds.Cifar10Dataset(dataset_dir=DATA_DIR_10, num_samples=100000, shuffle=False)
    dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar,
                                  num_parallel_workers=1)
    for item in dataset_f1.create_dict_iterator():
        # in this example, each dictionary has keys "image" and "label"
        assert item["label"] % 3 == 0
# column id sort
def generator_sort1(maxid=20):
    """Yield `maxid` three-column rows; row i is ([i], [i + 100], [i + 200]) as numpy arrays."""
    for row_id in range(maxid):
        yield tuple(np.array([row_id + offset]) for offset in (0, 100, 200))
def generator_sort2(maxid=20):
    """Yield `maxid` three-column rows; row i is ([i + 300], [i + 400], [i + 500]) as numpy arrays."""
    for row_id in range(maxid):
        yield tuple(np.array([row_id + offset]) for offset in (300, 400, 500))
def filter_func_part_sort(col1, col2, col3, col4, col5, col6):
    """Six-column predicate that accepts every row; none of the inputs is inspected."""
    return True
def filter_func_map_sort(col1, col2, col3):
    """Return the three inputs unchanged as a 3-tuple, in the order they were given."""
    columns = col1, col2, col3
    return columns
def test_filter_by_generator_with_map_all_sort():
    """Zip two three-column datasets and pass all six columns through an accept-all filter."""
    left = ds.GeneratorDataset(generator_sort1(10), ["col1", "col2", "col3"])
    # NOTE(review): "col4 " carries a trailing space in the original test; kept as-is
    # because it is not clear from here whether that is intentional - confirm.
    right = ds.GeneratorDataset(generator_sort2(10), ["col4 ", "col5", "col6"])
    zipped = ds.zip((left, right))
    filtered = zipped.filter(predicate=filter_func_part_sort, num_parallel_workers=1)

    ret_data = [item for item in filtered.create_dict_iterator()]
    assert len(ret_data) == 10
    assert ret_data[0]["col1"] == 0
    assert ret_data[9]["col6"] == 509
if __name__ == '__main__':
    # Run every test in this file, in a fixed order, when executed as a script.
    for test_case in (
            test_diff_predicate_func,
            test_filte_case_dataset_cifar10,
            test_filter_by_generator_Partial0,
            test_filter_by_generator_Partial1,
            test_filter_by_generator_Partial2,
            test_filter_by_generator_with_batch,
            test_filter_by_generator_with_batch_after,
            test_filter_by_generator_with_input_column,
            test_filter_by_generator_with_map_all_col,
            test_filter_by_generator_with_map_all_sort,
            test_filter_by_generator_with_map_part_col,
            test_filter_by_generator_with_no,
            test_filter_by_generator_with_rename,
            test_filter_by_generator_with_repeat,
            test_filter_by_generator_with_repeat_after,
            test_filter_by_generator_with_shuffle,
            test_filter_by_generator_with_shuffle_after,
            test_filter_by_generator_with_zip,
            test_filter_by_generator_with_zip_after,
            test_filter_by_generator_Partial,
    ):
        test_case()
tests/ut/python/dataset/test_iterator.py
浏览文件 @
98fbd30a
...
...
@@ -25,8 +25,8 @@ COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
def
check
(
project_columns
):
data1
=
ds
.
TFRecordDataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
COLUMNS
)
data2
=
ds
.
TFRecordDataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
project_columns
)
data1
=
ds
.
TFRecordDataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
COLUMNS
,
shuffle
=
False
)
data2
=
ds
.
TFRecordDataset
(
DATA_DIR
,
SCHEMA_DIR
,
columns_list
=
project_columns
,
shuffle
=
False
)
for
data_actual
,
data_expected
in
zip
(
data1
.
create_tuple_iterator
(
project_columns
),
data2
.
create_tuple_iterator
()):
assert
len
(
data_actual
)
==
len
(
data_expected
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录