提交 98fbd30a 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!460 [Data]Add filter operation

Merge pull request !460 from xulei/filter_master
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "dataset/engine/datasetops/source/cifar_op.h" #include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h" #include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h" #include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "mindrecord/include/shard_category.h" #include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_sample.h" #include "mindrecord/include/shard_sample.h"
#include "mindrecord/include/shard_shuffle.h" #include "mindrecord/include/shard_shuffle.h"
...@@ -45,6 +46,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D ...@@ -45,6 +46,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{kShuffle, &DEPipeline::ParseShuffleOp}, {kShuffle, &DEPipeline::ParseShuffleOp},
{kMindrecord, &DEPipeline::ParseMindRecordOp}, {kMindrecord, &DEPipeline::ParseMindRecordOp},
{kMap, &DEPipeline::ParseMapOp}, {kMap, &DEPipeline::ParseMapOp},
{kFilter, &DEPipeline::ParseFilterOp},
{kBatch, &DEPipeline::ParseBatchOp}, {kBatch, &DEPipeline::ParseBatchOp},
{kRepeat, &DEPipeline::ParseRepeatOp}, {kRepeat, &DEPipeline::ParseRepeatOp},
{kSkip, &DEPipeline::ParseSkipOp}, {kSkip, &DEPipeline::ParseSkipOp},
...@@ -502,6 +504,41 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> * ...@@ -502,6 +504,41 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
return Status::OK(); return Status::OK();
} }
Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<FilterOp::Builder> builder = std::make_shared<FilterOp::Builder>();
if (args["predicate"].is_none()) {
RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set. \n");
}
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
} else if (key == "predicate") {
py::handle op = args["predicate"];
if (!py::isinstance<py::function>(op)) {
RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc).");
}
py::function predicate_func = op.cast<py::function>();
(void)builder->SetPredicateFunc(std::move(predicate_func));
} else if (key == "input_columns") {
std::vector<std::string> in_col_names = ToStringVector(args["input_columns"]);
(void)builder->SetInColNames(in_col_names);
} else {
RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
}
}
}
std::shared_ptr<FilterOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op;
return Status::OK();
}
Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
if (args["count"].is_none()) { if (args["count"].is_none()) {
std::string err_msg = "Error: count is invalid or not set."; std::string err_msg = "Error: count is invalid or not set.";
...@@ -671,8 +708,6 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> * ...@@ -671,8 +708,6 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *
return Status::OK(); return Status::OK();
} }
DsOpPtr DEPipeline::ParseFilterOp(const py::dict &args) const { return DsOpPtr(); }
Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) { Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
// Required arguments // Required arguments
std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>(); std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
......
...@@ -107,6 +107,8 @@ class DEPipeline { ...@@ -107,6 +107,8 @@ class DEPipeline {
Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); Status ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseFilterOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); Status ParseRepeatOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); Status ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
...@@ -121,8 +123,6 @@ class DEPipeline { ...@@ -121,8 +123,6 @@ class DEPipeline {
Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); Status ParseZipOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
DsOpPtr ParseFilterOp(const py::dict &args) const;
Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr); Status ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
......
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include "dataset/engine/datasetops/map_op.h" #include "dataset/engine/datasetops/map_op.h"
#include "dataset/engine/datasetops/project_op.h" #include "dataset/engine/datasetops/project_op.h"
#include "dataset/engine/datasetops/rename_op.h" #include "dataset/engine/datasetops/rename_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/repeat_op.h" #include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/skip_op.h" #include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/datasetops/shuffle_op.h" #include "dataset/engine/datasetops/shuffle_op.h"
......
...@@ -240,7 +240,7 @@ void Tensor::PrintItemAt(const std::vector<dsize_t> &index, std::ostream &out) c ...@@ -240,7 +240,7 @@ void Tensor::PrintItemAt(const std::vector<dsize_t> &index, std::ostream &out) c
DS_ASSERT(data_); DS_ASSERT(data_);
switch (type_.value()) { switch (type_.value()) {
CASE_PRINT_HEX(DataType::DE_BOOL, uint8_t); CASE_PRINT_HEX(DataType::DE_BOOL, bool);
CASE_PRINT_HEX(DataType::DE_INT8, int8_t); CASE_PRINT_HEX(DataType::DE_INT8, int8_t);
......
...@@ -14,5 +14,6 @@ add_library(engine-datasetops OBJECT ...@@ -14,5 +14,6 @@ add_library(engine-datasetops OBJECT
take_op.cc take_op.cc
shuffle_op.cc shuffle_op.cc
zip_op.cc zip_op.cc
filter_op.cc
) )
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/filter_op.h"
#include <algorithm>
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/kernels/tensor_op.h"
#include "utils/log_adapter.h"
#include "dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
Status FilterOp::Builder::SanityCheck() {
std::string err;
err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : "";
err += builder_num_workers_ <= 0 ? "filter num_parallel_workers <= 0\n" : "";
return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err));
}
FilterOp::Builder::Builder() {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_num_workers_ = cfg->num_parallel_workers();
builder_op_connector_size_ = cfg->op_connector_size();
}
Status FilterOp::Builder::Build(std::shared_ptr<FilterOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<FilterOp>(std::move(build_in_col_names_), builder_num_workers_, builder_op_connector_size_,
builder_predicate_func_);
return Status::OK();
}
FilterOp::FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
py::function predicate_func)
: ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {}
Status FilterOp::operator()() {
// The operator class just starts off threads by calling the tree_ function.
RETURN_UNEXPECTED_IF_NULL(tree_);
// Synchronize with TaskManager.
TaskManager::FindMe()->Post();
filter_queues_.Init(num_workers_, oc_queue_size_);
RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1)));
RETURN_IF_NOT_OK(Collector());
return Status::OK();
}
Status FilterOp::EofReceived(int32_t) { return Status::OK(); }
Status FilterOp::EoeReceived(int32_t) { return Status::OK(); }
// Validating if each of the input_columns exists in the DataBuffer.
Status FilterOp::ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map,
std::vector<std::string> *input_columns) {
for (const auto &inCol : *input_columns) {
bool found = col_name_id_map.find(inCol) != col_name_id_map.end() ? true : false;
if (!found) {
std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
return Status::OK();
}
// A print method typically used for debugging.
void FilterOp::Print(std::ostream &out, bool show_all) const {
// Call base class printer first.
ParallelOp::Print(out, show_all);
// Then display our own stuff.
out << "\nFilterOp:";
out << "\n Input column names:";
for (size_t i = 0; i < in_columns_.size(); i++) {
out << " " << in_columns_[i];
}
}
Status FilterOp::WorkerEntry(int32_t worker_id) {
// Handshake with TaskManager that thread creation is successful.
TaskManager::FindMe()->Post();
std::unique_ptr<DataBuffer> in_buffer;
bool worker_stop = false;
while (worker_stop == false) {
// Getting a databuffer to work on.
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id));
if (in_buffer->eoe()) {
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe));
continue;
} else if (in_buffer->eof()) {
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof));
worker_stop = true;
continue;
}
// Thread local variables to avoid lock. When in_columns_ is empty and workers will write
// the name of the first column into input_columns (thread local) instead of in_columns_ (thread global).
std::vector<std::string> input_columns = in_columns_;
// Indices of the columns to process.
std::vector<size_t> to_process_indices;
RETURN_IF_NOT_OK(WorkerEntryInit(in_buffer.get(), &to_process_indices, &input_columns));
// if the databuffer was all filtered, it is marked as kFilterEmpty.
// if the databuffer was partially filtered, it is marked as kFilterPartial.
// if the databuffer was not filtered, it is marked as kFilterFull.
int32_t num_rows = in_buffer->NumRows();
std::unique_ptr<TensorQTable> new_tensor_table;
RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), to_process_indices, &new_tensor_table));
if (new_tensor_table->empty()) {
RETURN_IF_NOT_OK(
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEmpty)));
} else if (new_tensor_table->size() == num_rows) {
in_buffer->set_tensor_table(std::move(new_tensor_table));
RETURN_IF_NOT_OK(
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterFull)));
} else { // kFilterPartial
in_buffer->set_tensor_table(std::move(new_tensor_table));
RETURN_IF_NOT_OK(
filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterPartial)));
}
}
return Status::OK();
}
Status FilterOp::WorkerCompute(DataBuffer *in_buffer, const std::vector<size_t> &to_proess_indices,
std::unique_ptr<TensorQTable> *out) {
*out = std::make_unique<TensorQTable>();
int32_t num_rows = in_buffer->NumRows();
for (int32_t i = 0; i < num_rows; i++) {
TensorRow to_process;
TensorRow cur_row;
RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row));
(void)std::transform(to_proess_indices.begin(), to_proess_indices.end(), std::back_inserter(to_process),
[&cur_row](const size_t &it) -> std::shared_ptr<Tensor> { return cur_row[it]; });
bool predicate = true;
RETURN_IF_NOT_OK(InvokePredicateFunc(to_process, &predicate));
if (predicate) {
(*out)->push_back(std::move(cur_row));
}
}
return Status::OK();
}
// if the filtered DataBuffer is written directly to out_connector_,
// the thread fetching data will block in a queue.
// Collector function will reorder the DataBuffer in order.
// for example in two work queues:
// int filter_queues_:
// queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof)
// queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe)
// after reorder in out_connector_:
// queue1: DB(data2) DB(data4) DB(eof)
// queue2: DB(eoe) DB(eoe)
Status FilterOp::Collector() {
bool collector_stop = false;
uint64_t task_id_cnt = 0;
uint64_t out_id_cnt = 0;
std::pair<std::unique_ptr<DataBuffer>, filterCtrl> in_pair;
while (collector_stop == false) {
uint32_t w_id = task_id_cnt % num_workers_;
RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair));
if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial ||
in_pair.second == filterCtrl::kFilterEoe) {
uint32_t out_task_id = out_id_cnt % num_workers_;
RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first)));
out_id_cnt++;
task_id_cnt++;
} else if (in_pair.second == filterCtrl::kFilterEof) {
uint32_t out_task_id = out_id_cnt % num_workers_;
RETURN_IF_NOT_OK(out_connector_->Add(static_cast<int>(out_task_id), std::move(in_pair.first)));
collector_stop = true;
} else { // kFilterEmpty
task_id_cnt++;
}
}
return Status::OK();
}
// initialize some internal data structure used by WorkerEntry().
Status FilterOp::WorkerEntryInit(const DataBuffer *in_buf, std::vector<size_t> *to_process_indices,
std::vector<std::string> *input_columns) {
int32_t num_rows = in_buf->NumRows();
int32_t num_cols = in_buf->NumCols();
if (num_rows == 0 || num_cols == 0) {
RETURN_STATUS_UNEXPECTED("FilterOp is getting an empty DataBuffer.");
}
std::unordered_map<std::string, int32_t> col_name_id_map = in_buf->column_name_map();
// Check if there is invalid column name in the inColumns.
RETURN_IF_NOT_OK(ValidateInColumns(col_name_id_map, input_columns));
if (input_columns->empty()) {
MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table.";
// sort the input colunms by column index.
std::vector<std::pair<std::string, int32_t>> sort_vec(col_name_id_map.begin(), col_name_id_map.end());
std::sort(sort_vec.begin(), sort_vec.end(),
[](const std::pair<std::string, int32_t> &a, const std::pair<std::string, int32_t> &b) {
return a.second < b.second;
});
(void)std::transform(sort_vec.begin(), sort_vec.end(), std::back_inserter(*input_columns),
[](const auto &it) -> std::string { return it.first; });
}
// initialize to_process_indices.
(void)std::transform(input_columns->begin(), input_columns->end(), std::back_inserter(*to_process_indices),
[&col_name_id_map](const auto &it) -> size_t { return col_name_id_map[it]; });
return Status::OK();
}
Status FilterOp::CheckInput(const TensorRow &input) const {
for (auto &item : input) {
if (item == nullptr) {
RETURN_STATUS_UNEXPECTED("input is null.");
}
}
return Status::OK();
}
Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate) {
RETURN_IF_NOT_OK(CheckInput(input));
// Acquire Python GIL.
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
}
try {
// Transform input tensor vector into numpy array vector.
py::tuple input_args(input.size());
for (size_t i = 0; i < input.size(); i++) {
py::array new_data;
RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data));
input_args[i] = new_data;
}
// Invoke python function.
py::object ret_py_obj = predicate_func_(*input_args);
*out_predicate = ret_py_obj.cast<py::bool_>();
} catch (const py::error_already_set &e) {
std::stringstream ss;
ss << e.what() << std::endl;
ss << "The type of the return value of python predicate function is not bool, or can not be convert to bool.";
return Status(StatusCode::kPyFuncException, ss.str());
}
return Status(StatusCode::kOK, "FilterOp predicate func call succeed");
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
#define DATASET_ENGINE_DATASETOPS_FILTER_OP_H_
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/kernels/tensor_op.h"
#include "dataset/util/queue.h"
namespace mindspore {
namespace dataset {
class FilterOp : public ParallelOp {
public:
// The nested builder class inside of the FilterOp is used to help manage all of
// the arguments for constructing it. Use the builder by setting each argument
// with the provided set methods, and then finally call the build method to execute
// the actual construction.
class Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args.
// @return This is a constructor.
Builder();
// Default destructor
~Builder() = default;
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetPredicateFunc(py::function func) {
builder_predicate_func_ = std::move(func);
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetInColNames(const std::vector<std::string> &in_col_names) {
build_in_col_names_ = in_col_names;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
builder_num_workers_ = num_workers;
return *this;
}
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t connector_size) {
builder_op_connector_size_ = connector_size;
return *this;
}
// The builder "build" method creates the final object.
// @param ptr The shared_ptr to the new FilterOp object.
// @return Status.
Status Build(std::shared_ptr<FilterOp> *ptr);
private:
// Sanity check for builder class args.
// @return Status - The error code return.
Status SanityCheck();
std::vector<std::string> build_in_col_names_;
py::function builder_predicate_func_;
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
};
enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 };
// Constructor of FilterOp
// @note The builder class should be used to call it.
// @param in_col_names A list of input column names,when it is empty the predicate will be
// applied all columns in the dataset.
// @param num_workers The number of worker threads.
// @param op_connector_size The size of each queue in the connector.
// @param predicate_func python callable which returns a boolean value.
FilterOp(const std::vector<std::string> &in_col_names, int32_t num_workers, int32_t op_queue_size,
py::function predicate_func);
// Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree),This class functor will
// provide the master loop that drives the logic for performing the work.
// @return Status The error code return
Status operator()() override;
// @param int32_t workerId.
// @return Status - The error code return.
Status EofReceived(int32_t) override;
// @param int32_t workerId.
// @return Status - The error code return.
Status EoeReceived(int32_t) override;
// A print method typically used for debugging.
// @param out The output stream to write output to.
// @param show_all A bool to control if you want to show all info or just a summary.
void Print(std::ostream &out, bool show_all) const override;
private:
// predicate_func python callable which returns a boolean value.
py::function predicate_func_;
// Variable to store the column name that will feed to predicate function.
std::vector<std::string> in_columns_;
// Internal queue for filter.
QueueList<std::pair<std::unique_ptr<DataBuffer>, filterCtrl>> filter_queues_;
// Private function for worker/thread to loop continuously. It comprises the main
// logic of FilterOp, getting the data from previous Op, validating user specified column names,
// applying predicate to each of the data, filter the data when predicate result is false.
// @param worker_id The id assigned to this thread/worker upon creation.
// @return Status The error code return.
Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_
// Filter the data by predicate function .
// @param in_buffer input data buffer.
// @param to_proess_indices Indices of columns to be processed.
// @param out data buffer that are filtered by predicate.
// @return Status The error code return.
Status WorkerCompute(DataBuffer *in_buffer, const std::vector<size_t> &to_proess_indices,
std::unique_ptr<TensorQTable> *out);
// Collector databuffer.
// @return Status The error code return.
Status Collector();
// @param input tensor vector.
// @return Status - The error code return.
Status CheckInput(const TensorRow &input) const;
// Invoke python func.
// @param input tensor vector.
// @param the result of predicate.
// @return Status - The error code return.
Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate);
// Private function for validating if each of the user specified input column names
// exist in the DataBuffer.
// @param col_name_id_map The column name to index mapping obtained from DataBuffer.
// @param input_columns The vector of input column names used in the current thread.
// @return Status The error code return.
Status ValidateInColumns(const std::unordered_map<std::string, int32_t> &col_name_id_map,
std::vector<std::string> *input_columns);
// Private function that initialize some internal data structure used by WorkerEntry().
// @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory
// and is not shared with other threads.
// @param[out] to_process_indices Indices of columns that will feed to predicate.
// @param input_columns The vector of input column names used in the current thread.
Status WorkerEntryInit(const DataBuffer *in_buf, std::vector<size_t> *to_process_indices,
std::vector<std::string> *input_columns);
};
} // namespace dataset
} // namespace mindspore
#endif
...@@ -35,7 +35,7 @@ from mindspore._c_expression import typing ...@@ -35,7 +35,7 @@ from mindspore._c_expression import typing
from mindspore import log as logger from mindspore import log as logger
from . import samplers from . import samplers
from .iterators import DictIterator, TupleIterator from .iterators import DictIterator, TupleIterator
from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \ from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
check_zip_dataset, check_add_column, check_textfiledataset check_zip_dataset, check_add_column, check_textfiledataset
...@@ -385,6 +385,32 @@ class Dataset: ...@@ -385,6 +385,32 @@ class Dataset:
""" """
return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers) return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers)
@check_filter
def filter(self, predicate, input_columns=None, num_parallel_workers=1):
"""
Filter dataset by predicate.
Note:
If input_columns not provided or empty, all columns will be used.
Args:
predicate: python callable which returns a boolean value.
input_columns: (list[str]): List of names of the input columns, when
default=None, the predicate will be applied on all columns in the dataset.
num_parallel_workers (int, optional): Number of workers to process the Dataset
in parallel (default=None).
Returns:
FilterDataset, dataset filter.
Examples:
>>> import mindspore.dataset as ds
>>> # generator data(0 ~ 63)
>>> # filter the data that greater than or equal to 11
>>> dataset_f = dataset.filter(predicate=lambda data: data < 11, input_columns = ["data"])
"""
return FilterDataset(self, predicate, input_columns, num_parallel_workers)
@check_repeat @check_repeat
def repeat(self, count=None): def repeat(self, count=None):
""" """
...@@ -1105,6 +1131,44 @@ class MapDataset(DatasetOp): ...@@ -1105,6 +1131,44 @@ class MapDataset(DatasetOp):
return self.input[0].get_dataset_size() return self.input[0].get_dataset_size()
class FilterDataset(DatasetOp):
"""
The result of applying filter predicate to the input Dataset.
Args:
input_dataset: Input Dataset to be mapped.
predicate: python callable which returns a boolean value.
input_columns: (list[str]): List of names of the input columns, when
default=None, the predicate will be applied all columns in the dataset.
num_parallel_workers (int, optional): Number of workers to process the Dataset
in parallel (default=None).
"""
def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None):
super().__init__(num_parallel_workers)
self.predicate = lambda *args: bool(predicate(*args))
self.input.append(input_dataset)
input_dataset.output.append(self)
if input_columns is not None and not isinstance(input_columns, list):
input_columns = [input_columns]
self.input_columns = input_columns
def get_args(self):
args = super().get_args()
args["predicate"] = self.predicate
args["input_columns"] = self.input_columns
return args
def get_dataset_size(self):
"""
Get the number of batches in an epoch.
the size cannot be determined before we run the pipeline
Return:
0
"""
return 0
class RepeatDataset(DatasetOp): class RepeatDataset(DatasetOp):
""" """
The result of applying Repeat operator to the input Dataset. The result of applying Repeat operator to the input Dataset.
......
...@@ -129,6 +129,8 @@ class Iterator: ...@@ -129,6 +129,8 @@ class Iterator:
op_type = OpName.ZIP op_type = OpName.ZIP
elif isinstance(dataset, de.MapDataset): elif isinstance(dataset, de.MapDataset):
op_type = OpName.MAP op_type = OpName.MAP
elif isinstance(dataset, de.FilterDataset):
op_type = OpName.FILTER
elif isinstance(dataset, de.RepeatDataset): elif isinstance(dataset, de.RepeatDataset):
op_type = OpName.REPEAT op_type = OpName.REPEAT
elif isinstance(dataset, de.SkipDataset): elif isinstance(dataset, de.SkipDataset):
......
...@@ -693,6 +693,26 @@ def check_map(method): ...@@ -693,6 +693,26 @@ def check_map(method):
return new_method return new_method
def check_filter(method):
""""check the input arguments of filter."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
predicate = param_dict.get("predicate")
if not callable(predicate):
raise ValueError("Predicate should be a python function or a callable python object.")
nreq_param_int = ['num_parallel_workers']
check_param_type(nreq_param_int, param_dict, int)
param_name = "input_columns"
param = param_dict.get(param_name)
if param is not None:
check_columns(param, param_name)
return method(*args, **kwargs)
return new_method
def check_repeat(method): def check_repeat(method):
"""check the input arguments of repeat.""" """check the input arguments of repeat."""
@wraps(method) @wraps(method)
......
...@@ -66,6 +66,8 @@ SET(DE_UT_SRCS ...@@ -66,6 +66,8 @@ SET(DE_UT_SRCS
celeba_op_test.cc celeba_op_test.cc
take_op_test.cc take_op_test.cc
text_file_op_test.cc) text_file_op_test.cc)
filter_op_test.cc
)
add_executable(de_ut_tests ${DE_UT_SRCS}) add_executable(de_ut_tests ${DE_UT_SRCS})
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/util/circular_pool.h"
#include "dataset/core/client.h"
#include "common/common.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
namespace de = mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestfilter_op : public UT::DatasetOpTesting {
};
std::shared_ptr<de::FilterOp> Filter() {
Status rc;
std::shared_ptr<de::FilterOp> op;
rc = de::FilterOp::Builder().Build(&op);
EXPECT_TRUE(rc.IsOk());
return op;
}
TEST_F(MindDataTestfilter_op, Testfilter_opFuntions) {
MS_LOG(INFO) << "Doing MindDataTest filter_op.";
auto my_tree = std::make_shared<ExecutionTree>();
std::shared_ptr<DatasetOp> parent_op = Filter();
std::shared_ptr<DatasetOp> leaf_op = Filter();
my_tree->AssociateNode(parent_op);
my_tree->AssociateNode(leaf_op);
ASSERT_NE(parent_op, nullptr);
ASSERT_NE(leaf_op, nullptr);
}
...@@ -158,6 +158,16 @@ TEST_F(MindDataTestTensorDE, InsertTensor) { ...@@ -158,6 +158,16 @@ TEST_F(MindDataTestTensorDE, InsertTensor) {
ASSERT_EQ(*t == *t6, true); ASSERT_EQ(*t == *t6, true);
} }
// Test the bug of Tensor::ToString will exec failed for Tensor which store bool values
TEST_F(MindDataTestTensorDE, BoolTensor) {
std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2}),
DataType(DataType::DE_BOOL));
t->SetItemAt<bool>({0}, true);
t->SetItemAt<bool>({1}, true);
std::string out = t->ToString();
ASSERT_TRUE(out.find("Template type and Tensor type are not compatible") == std::string::npos);
}
TEST_F(MindDataTestTensorDE, GetItemAt) { TEST_F(MindDataTestTensorDE, GetItemAt) {
std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 2}), DataType(DataType::DE_UINT8)); std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 2}), DataType(DataType::DE_UINT8));
t->Fill<uint8_t>(254); t->Fill<uint8_t>(254);
......
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as cde
import mindspore.dataset.transforms.c_transforms as C
import mindspore.common.dtype as mstype
from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
# test for predicate
def test_diff_predicate_func():
def test_filter(predicate_func):
transforms = [
cde.Decode(),
cde.Resize([64, 64])
]
type_cast_op = C.TypeCast(mstype.int32)
dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], shuffle=False)
dataset = dataset.map(input_columns=["image"], operations=transforms, num_parallel_workers=1)
dataset = dataset.filter(input_columns=["image", "label"], predicate=predicate_func, num_parallel_workers=4)
num_iter = 0
label_list = []
for data in dataset.create_dict_iterator():
num_iter += 1
ori_img = data["image"]
label = data["label"]
label_list.append(label)
assert num_iter == 1
assert label_list[0] == 3
test_filter(lambda image, label: label == 3)
test_filter(lambda image, label: label[0] == 3)
test_filter(lambda image, label: label == [3])
test_filter(lambda image, label: label == np.array([3]))
test_filter(lambda image, label: label == np.array(3))
def filter_func_ge(data):
if data > 10:
return False
return True
def generator_1d():
for i in range(64):
yield (np.array(i),)
# test with GeneratorDataset
def test_filter_by_generator_with_no():
dataset = ds.GeneratorDataset(generator_1d, ["data"])
dataset_f = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4)
num_iter = 0
expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
for item in dataset_f.create_dict_iterator():
assert item["data"] == expected_rs[num_iter]
num_iter += 1
# test with repeatOp before
def test_filter_by_generator_with_repeat():
dataset = ds.GeneratorDataset(generator_1d, ["data"])
dataset_r = dataset.repeat(4)
dataset_f = dataset_r.filter(predicate=filter_func_ge, num_parallel_workers=4)
num_iter = 0
ret_data = []
expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
for item in dataset_f.create_dict_iterator():
num_iter += 1
ret_data.append(item["data"])
assert num_iter == 44
for i in range(4):
for ii in range(len(expected_rs)):
index = i * len(expected_rs) + ii
assert ret_data[index] == expected_rs[ii]
# test with repeatOp after
def test_filter_by_generator_with_repeat_after():
dataset = ds.GeneratorDataset(generator_1d, ["data"])
dataset_f = dataset.filter(predicate=filter_func_ge, num_parallel_workers=4)
dataset_r = dataset_f.repeat(4)
num_iter = 0
ret_data = []
expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
for item in dataset_r.create_dict_iterator():
num_iter += 1
ret_data.append(item["data"])
assert num_iter == 44
for i in range(4):
for ii in range(len(expected_rs)):
index = i * len(expected_rs) + ii
assert ret_data[index] == expected_rs[ii]
def filter_func_batch(data):
if data[0] > 8:
return False
return True
def filter_func_batch_after(data):
if data > 20:
return False
return True
# test with batchOp before
def test_filter_by_generator_with_batch():
dataset = ds.GeneratorDataset(generator_1d, ["data"])
dataset_b = dataset.batch(4)
dataset_f = dataset_b.filter(predicate=filter_func_batch, num_parallel_workers=4)
num_iter = 0
ret_data = []
for item in dataset_f.create_dict_iterator():
num_iter += 1
ret_data.append(item["data"])
assert num_iter == 3
assert ret_data[0][0] == 0
assert ret_data[1][0] == 4
assert ret_data[2][0] == 8
# test with batchOp after
def test_filter_by_generator_with_batch_after():
dataset = ds.GeneratorDataset(generator_1d, ["data"])
dataset_f = dataset.filter(predicate=filter_func_batch_after, num_parallel_workers=4)
dataset_b = dataset_f.batch(4)
num_iter = 0
ret_data = []
for item in dataset_b.create_dict_iterator():
num_iter += 1
ret_data.append(item["data"])
assert num_iter == 6
assert ret_data[0][0] == 0
assert ret_data[1][0] == 4
assert ret_data[5][0] == 20
def filter_func_shuffle(data):
if data > 20:
return False
return True
# test with batchOp before
def test_filter_by_generator_with_shuffle():
dataset = ds.GeneratorDataset(generator_1d, ["data"])
dataset_s = dataset.shuffle(4)
dataset_f = dataset_s.filter(predicate=filter_func_shuffle, num_parallel_workers=4)
num_iter = 0
for item in dataset_f.create_dict_iterator():
num_iter += 1
assert num_iter == 21
def filter_func_shuffle_after(data):
if data > 20:
return False
return True
# test with batchOp after
def test_filter_by_generator_with_shuffle_after():
dataset = ds.GeneratorDataset(generator_1d, ["data"])
dataset_f = dataset.filter(predicate=filter_func_shuffle_after, num_parallel_workers=4)
dataset_s = dataset_f.shuffle(4)
num_iter = 0
for item in dataset_s.create_dict_iterator():
num_iter += 1
assert num_iter == 21
def generator_1d_zip1():
for i in range(64):
yield (np.array(i),)
def generator_1d_zip2():
for i in range(64):
yield (np.array(i+100),)
def filter_func_zip(data1, data2):
if data1 > 20:
return False
return True
def filter_func_zip_after(data1):
if data1 > 20:
return False
return True
# test with zipOp before
def test_filter_by_generator_with_zip():
dataset1 = ds.GeneratorDataset(generator_1d_zip1, ["data1"])
dataset2 = ds.GeneratorDataset(generator_1d_zip2, ["data2"])
dataz = ds.zip((dataset1, dataset2))
dataset_f = dataz.filter(predicate=filter_func_zip, num_parallel_workers=1)
num_iter = 0
ret_data = []
for item in dataset_f.create_dict_iterator():
num_iter += 1
ret_data.append({"data1": item["data1"], "data2":item["data2"]})
assert num_iter == 21
assert ret_data[0]["data1"] == 0
assert ret_data[0]["data2"] == 100
assert ret_data[5]["data1"] == 5
assert ret_data[5]["data2"] == 105
# test with zipOp after
def test_filter_by_generator_with_zip_after():
dataset1 = ds.GeneratorDataset(generator_1d_zip1, ["data1"])
dataset2 = ds.GeneratorDataset(generator_1d_zip1, ["data2"])
dt1 = dataset1.filter(predicate=filter_func_zip_after, num_parallel_workers=4)
dt2 = dataset2.filter(predicate=filter_func_zip_after, num_parallel_workers=4)
dataz = ds.zip((dt1, dt2))
num_iter = 0
ret_data = []
for item in dataz.create_dict_iterator():
num_iter += 1
ret_data.append({"data1": item["data1"], "data2":item["data2"]})
assert num_iter == 21
assert ret_data[0]["data1"] == 0
assert ret_data[0]["data2"] == 0
assert ret_data[5]["data1"] == 5
assert ret_data[5]["data2"] == 5
def filter_func_map(col1, col2):
if col1[0] > 8:
return True
return False
def filter_func_map_part(col1):
if col1 < 3:
return True
else:
return False
def filter_func_map_all(col1, col2):
return True
def generator_mc(maxid=20):
for i in range(maxid):
yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
def func_map(data_col1, data_col2):
return (data_col1, data_col2)
def func_map_part(data_col1):
return (data_col1)
# test with map
def test_filter_by_generator_with_map_all_col():
dataset = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"])
dataset_map = dataset.map( input_columns=["col1"], output_columns=["col1"] , operations=func_map_part)
# dataset_map = dataset.map( operations=func_map_part)
dataset_f = dataset_map.filter(input_columns=["col1"], predicate=filter_func_map_part, num_parallel_workers=1)
num_iter = 0
ret_data = []
for item in dataset_f.create_dict_iterator():
num_iter += 1
ret_data.append(item["col1"])
assert num_iter == 3
assert ret_data[0] == 0
assert ret_data[1] == 1
# test with map
def test_filter_by_generator_with_map_part_col():
dataset = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"])
dataset_map = dataset.map( input_columns=["col1"], output_columns=["out1"] , operations=func_map_part)
dataset_f = dataset_map.filter(input_columns=["out1", "col2"], predicate=filter_func_map, num_parallel_workers=4)
num_iter = 0
ret_data = []
for item in dataset_f.create_dict_iterator():
num_iter += 1
print(item)
ret_data.append(item["out1"])
assert num_iter == 3
assert ret_data[0] == 9
assert ret_data[2] == 11
def filter_func_rename(data):
if data> 8:
return True
return False
# test with rename before
def test_filter_by_generator_with_rename():
dataset = ds.GeneratorDataset(generator_1d, ["data"])
dataset_b = dataset.rename(input_columns=["data"], output_columns=["col1"])
dataset_f = dataset_b.filter(predicate=filter_func_rename, num_parallel_workers=4)
num_iter = 0
ret_data = []
for item in dataset_f.create_dict_iterator():
num_iter += 1
ret_data.append(item["col1"])
assert num_iter == 55
assert ret_data[0] == 9
assert ret_data[54] == 63
#test input_column
def filter_func_input_column1(col1, col2):
if col1[0] < 8:
return True
return False
def filter_func_input_column2(col1):
if col1[0] < 8:
return True
return False
def filter_func_input_column3(col1):
return True
# test with input_columns
def test_filter_by_generator_with_input_column():
dataset = ds.GeneratorDataset(generator_mc(64), ["col1", "col2"])
dataset_map = dataset.map( input_columns=["col1"], output_columns=["out1"] , operations=func_map_part)
dataset_f1 = dataset_map.filter(input_columns=["out1", "col2"], predicate=filter_func_input_column1, num_parallel_workers=4)
dataset_f2 = dataset_f1.filter(input_columns=["out1"], predicate=filter_func_input_column2, num_parallel_workers=4)
dataset_f3 = dataset_f2.filter(input_columns=["col2"], predicate=filter_func_input_column3, num_parallel_workers=4)
dataset_f4 = dataset_f3.filter(predicate=filter_func_input_column1, num_parallel_workers=4)
num_iter = 0
ret_data = []
for item in dataset_f4.create_dict_iterator():
num_iter += 1
ret_data.append(item["out1"])
assert num_iter == 8
assert ret_data[0] == 0
assert ret_data[7] == 7
#test kFilterPartial
def generator_mc_p0(maxid=20):
for i in range(maxid):
yield (np.array([i ]), np.array([i + 100]))
def generator_mc_p1(maxid=20):
for i in range(maxid):
yield (np.array([i + 200 ]), np.array([i + 300]))
def filter_func_Partial_0(col1, col2, col3, col4):
filter_data = [0,1,2,3,4, 11]
if col1[0] in filter_data:
return False
return True
# test with row_data_buffer > 1
def test_filter_by_generator_Partial0():
ds.config.load('../data/dataset/declient_filter.cfg')
dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"])
dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"])
dataset_zip = ds.zip((dataset1, dataset2))
dataset_f1 = dataset_zip.filter(predicate=filter_func_Partial_0, num_parallel_workers=2)
ret = []
for item in dataset_f1.create_dict_iterator():
ret.append(item["col1"])
assert ret[0] == 5
assert ret[6] == 12
# test with row_data_buffer > 1
def test_filter_by_generator_Partial1():
ds.config.load('../data/dataset/declient_filter.cfg')
dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"])
dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"])
dataset_zip = ds.zip((dataset1, dataset2))
dataset_f1 = dataset_zip.filter(predicate=filter_func_Partial_0, num_parallel_workers=2)
dataset_map = dataset_f1.map( input_columns=["col1"], output_columns=["out1"] , operations=lambda x1: x1 + 400)
ret = []
for item in dataset_map.create_dict_iterator():
ret.append(item["out1"])
assert ret[0] == 405
assert ret[6] == 412
# test with row_data_buffer > 1
def test_filter_by_generator_Partial2():
ds.config.load('../data/dataset/declient_filter.cfg')
dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"])
dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"])
dataset1f = dataset1.filter( input_columns= ["col1"], predicate=lambda x: x not in [3,7,9], num_parallel_workers=2)
dataset2f = dataset2.filter( input_columns= ["col3"], predicate=lambda x: x not in [203,207,209], num_parallel_workers=2)
dataset_zip = ds.zip((dataset1f, dataset2f))
dataset_map = dataset_zip.map( input_columns=["col1", "col3"], output_columns=["out1", "out3"] , operations=lambda x1,x3: (x1 + 400, x3+500))
ret1 = []
ret3 = []
for item in dataset_map.create_dict_iterator():
ret1.append(item["out1"])
ret3.append(item["out3"])
assert ret1[0] == 400
assert ret1[6] == 408
assert ret3[0] == 700
assert ret3[6] == 708
def filter_func_Partial(col1, col2):
if col1[0] % 3 == 0:
return True
return False
def generator_big(maxid=20):
for i in range(maxid):
yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]))
# test with row_data_buffer > 1
def test_filter_by_generator_Partial():
ds.config.load('../data/dataset/declient_filter.cfg')
dataset = ds.GeneratorDataset(source = generator_mc(99), column_names = ["col1", "col2"])
dataset_s = dataset.shuffle(4)
dataset_f1 = dataset_s.filter(input_columns=["col1", "col2"], predicate=filter_func_Partial, num_parallel_workers=1)
for item in dataset_f1.create_dict_iterator():
assert item["col1"] % 3 == 0
def filter_func_cifar(col1, col2):
if col2 % 3 == 0:
return True
return False
# test with cifar10
def test_filte_case_dataset_cifar10():
DATA_DIR_10 = "../data/dataset/testCifar10Data"
ds.config.load('../data/dataset/declient_filter.cfg')
dataset_c = ds.Cifar10Dataset(dataset_dir = DATA_DIR_10, num_samples = 100000, shuffle=False)
dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar, num_parallel_workers=1)
num_iter = 0
for item in dataset_f1.create_dict_iterator():
# in this example, each dictionary has keys "image" and "label"
assert item["label"] % 3 == 0
# column id sort
def generator_sort1(maxid=20):
for i in range(maxid):
yield (np.array([i]), np.array([i + 100]), np.array([i + 200]))
def generator_sort2(maxid=20):
for i in range(maxid):
yield (np.array([i + 300]), np.array([i + 400]), np.array([i + 500]))
def filter_func_part_sort(col1, col2, col3, col4, col5, col6):
return True
def filter_func_map_sort(col1, col2, col3):
return (col1, col2, col3)
def test_filter_by_generator_with_map_all_sort():
dataset1 = ds.GeneratorDataset(generator_sort1(10), ["col1", "col2", "col3"])
dataset2 = ds.GeneratorDataset(generator_sort2(10), ["col4 ", "col5", "col6"])
dataz = ds.zip((dataset1, dataset2))
dataset_f = dataz.filter(predicate=filter_func_part_sort, num_parallel_workers=1)
num_iter = 0
ret_data = []
for item in dataset_f.create_dict_iterator():
num_iter += 1
ret_data.append(item)
assert num_iter == 10
assert ret_data[0]["col1"] == 0
assert ret_data[9]["col6"] == 509
if __name__ == '__main__':
test_diff_predicate_func()
test_filte_case_dataset_cifar10()
test_filter_by_generator_Partial0()
test_filter_by_generator_Partial1()
test_filter_by_generator_Partial2()
test_filter_by_generator_with_batch()
test_filter_by_generator_with_batch_after()
test_filter_by_generator_with_input_column()
test_filter_by_generator_with_map_all_col()
test_filter_by_generator_with_map_all_sort()
test_filter_by_generator_with_map_part_col()
test_filter_by_generator_with_no()
test_filter_by_generator_with_rename()
test_filter_by_generator_with_repeat()
test_filter_by_generator_with_repeat_after()
test_filter_by_generator_with_shuffle()
test_filter_by_generator_with_shuffle_after()
test_filter_by_generator_with_zip()
test_filter_by_generator_with_zip_after()
test_filter_by_generator_Partial()
...@@ -25,8 +25,8 @@ COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", ...@@ -25,8 +25,8 @@ COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
def check(project_columns): def check(project_columns):
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS) data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS, shuffle=False)
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=project_columns) data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=project_columns, shuffle=False)
for data_actual, data_expected in zip(data1.create_tuple_iterator(project_columns), data2.create_tuple_iterator()): for data_actual, data_expected in zip(data1.create_tuple_iterator(project_columns), data2.create_tuple_iterator()):
assert len(data_actual) == len(data_expected) assert len(data_actual) == len(data_expected)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册