diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index a02d995147a655eb07a3b29d68c15659a1b43ba4..c3dfeafe48e1899cf27344a5231f24023ce22f63 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -48,6 +48,7 @@ static std::unordered_map g_parse_op_func_ = {{kStorage, &D {kMap, &DEPipeline::ParseMapOp}, {kFilter, &DEPipeline::ParseFilterOp}, {kBatch, &DEPipeline::ParseBatchOp}, + {kBarrier, &DEPipeline::ParseBarrierOp}, {kRepeat, &DEPipeline::ParseRepeatOp}, {kSkip, &DEPipeline::ParseSkipOp}, {kZip, &DEPipeline::ParseZipOp}, @@ -627,6 +628,30 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr return Status::OK(); } +Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr *ptr) { + std::shared_ptr builder = std::make_shared(); + // Right now barrier should only take num_rows_per_buffer = 1 + // The reason for this is because having it otherwise can lead to blocking issues + // See barrier_op.h for more details + (void)builder->SetRowsPerBuffer(1); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "condition_name") { + (void)builder->SetConditionName(ToString(value)); + } else if (key == "condition_func") { + (void)builder->SetConditionFunc(value.cast()); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *ptr = op; + return Status::OK(); +} + Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *ptr) { int32_t prefetch_size = 0; if (args.contains("prefetch_size")) { diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h index 25919afe588cbb7fe36707e14190bf3b20e8432e..7f9c6c459a58149b0c2ab51b7acf23e87156cb92 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -40,6 +40,7 @@ enum OpName { kShuffle, kMindrecord, kBatch, + kBarrier, kCache, kRepeat, kSkip, @@ -115,6 +116,8 @@ class DEPipeline { Status ParseBatchOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseBarrierOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseGeneratorOp(const py::dict &args, std::shared_ptr *ptr); Status ParseRenameOp(const py::dict &args, std::shared_ptr *ptr); diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 9865396a7dc1d860ee028769468906d78d1d9bfc..2b8ce4e896bcd12b3d52dd04ec0c346a72cdfd2d 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -476,6 +476,7 @@ PYBIND11_MODULE(_c_dataengine, m) { .value("STORAGE", OpName::kStorage) .value("SHUFFLE", OpName::kShuffle) .value("BATCH", OpName::kBatch) + .value("BARRIER", OpName::kBarrier) .value("MINDRECORD", OpName::kMindrecord) .value("CACHE", OpName::kCache) .value("REPEAT", OpName::kRepeat) diff --git a/mindspore/ccsrc/dataset/core/client.h b/mindspore/ccsrc/dataset/core/client.h index 15064dee6b817175af45c2722737e1a4c38ac881..40de887aea95834d9c5173c3eaab9bf45d4c13b0 100644 --- a/mindspore/ccsrc/dataset/core/client.h +++ b/mindspore/ccsrc/dataset/core/client.h @@ -25,6 +25,7 @@ #include "dataset/core/tensor_shape.h" #include "dataset/engine/data_schema.h" #include "dataset/engine/dataset_iterator.h" +#include "dataset/engine/datasetops/barrier_op.h" #include "dataset/engine/datasetops/batch_op.h" #include "dataset/engine/datasetops/dataset_op.h" 
#include "dataset/engine/datasetops/device_queue_op.h" diff --git a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt index 7de62d9d1109b81ae67477c4f65f627f64e4e8cd..9e8272d513337d557cb30aed85b029738df86078 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt @@ -4,6 +4,7 @@ add_library(engine-datasetops OBJECT dataset_op.cc parallel_op.cc pipeline_op.cc + barrier_op.cc batch_op.cc device_queue_op.cc map_op.cc diff --git a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b0ea7dbd077405881ef97e98d466951c61b4953f --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc @@ -0,0 +1,235 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/engine/datasetops/barrier_op.h" +#include +#include "dataset/core/constants.h" +#include "dataset/engine/data_buffer.h" +#include "dataset/engine/db_connector.h" +#include "dataset/core/config_manager.h" +#include "dataset/core/global_context.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +BarrierOp::Builder::Builder() { + // Some arguments to the BarrierOp constructor have a default argument that is taken + // from the client config. + // The user may choose to change these values for the construction of the BarrierOp by + // using the various builder set methods. + + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); } + +Status BarrierOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_, + builder_condition_func_); + return Status::OK(); +} + +// Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions +BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, + py::function condition_func) + : PipelineOp(op_connector_size), + rows_per_buffer_(rows_per_buffer), + buffer_id_(0), + clean_up_(false), + eof_(false), + condition_name_(condition_name), + condition_function_(condition_func) {} + +// destructor +BarrierOp::~BarrierOp() {} + +// Entry point for Barrier, called by launch() +Status BarrierOp::operator()() { + // The children_num_ parameter needs to be put here + // Synchronize with TaskManager once the thread is created. 
+ TaskManager::FindMe()->Post(); + + // create child iterator, right now this barrier is a pipeline operator + int32_t worker_id = 0; + int32_t child_idx = 0; + child_iterator_ = std::make_unique(this, worker_id, child_idx); + + // Loop until eof is true + while (!eof_) { + // Create new table to put the new tensor rows + std::unique_ptr curr_table = std::make_unique(); + RETURN_IF_NOT_OK(prepare(curr_table.get())); + + // If an eof got picked up during the above prepare, then we're done + if (eof_) { + break; + } + + // we have to output new buffer with possibly different buffer size, possibly one row + while (!clean_up_) { + // 1. If a previous loop iteration sent the current table out, then create a new one. + + if (curr_table == nullptr) { + curr_table = std::make_unique(); + } + + // 2 fill the table. Note: clean_up mode might get turned on if epoch is finished + RETURN_IF_NOT_OK(fillBuffer(curr_table.get())); + + // 3 create and update buffer and send it to the out connector + if (!curr_table->empty()) { + std::unique_ptr curr_buffer = std::make_unique(buffer_id_, DataBuffer::kDeBFlagNone); + curr_buffer->set_tensor_table(std::move(curr_table)); + curr_buffer->set_column_name_map(col_name_id_map_); + MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols " + << curr_buffer->NumCols() << ", map " << col_name_id_map_.size() << "."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); + buffer_id_++; + } + } + + // 4 handle drain state. + if (clean_up_) { + MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal."; + // Send the eoe up. + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); + } + } + // 5 handle eof + // propagate eof here. + MS_LOG(INFO) << "Barrier operator got EOF, propagating."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + return Status::OK(); +} + +// Handles preprocessing of the main loop, used when starting new epoch +Status BarrierOp::prepare(TensorQTable *const table) { + MS_LOG(DEBUG) << "Barrier operator prepares for new epoch."; + clean_up_ = false; + buffer_id_ = 0; + if (table == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table."); + } + // fill initial row + TensorRow new_row = {}; + // use iterator to get next row and invoke pyfunc wait + RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); + + // If the first row fetching resulted in eof, then we are done. + if (eof_) { + return Status::OK(); + } + if (new_row.empty()) { + // This epoch is empty + return Status::OK(); + } + // Pack this first row into our tensor table + // first row we also have to check if we should block + RETURN_IF_NOT_OK(blockCond()); + + table->push_back(std::move(new_row)); + // At this point we have 1 row produced, we take the old column map id and use it in the new table + // Initializing col_name_id_map_ from the first data buffer. + col_name_id_map_ = child_iterator_->col_name_id_map(); + // the update code below shouldn't do anything bad if the column name already exists. 
+ return Status::OK(); +} + +// fillBuffer always expects a new table to fill +Status BarrierOp::fillBuffer(TensorQTable *const table) { + if (table == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer."); + } + TensorRow new_row = {}; + while (table->size() < static_cast(rows_per_buffer_)) { + RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); + // Early exit the loop if we got empty row from any of our child iterations + if (new_row.empty()) { + return Status::OK(); + } + // else we got a row so pack it into the tensor table. + RETURN_IF_NOT_OK(blockCond()); + + table->push_back(std::move(new_row)); + } + return Status::OK(); +} + +// function executes a py_func and blocks until condition becomes true. +Status BarrierOp::blockCond() { + { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + // we have condition name, however the flexibility is in python today + try { + // Invoke python function + py::object ret_py_obj = condition_function_(); + // Process the return value + if (!py::isinstance(ret_py_obj)) { + return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false"); + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + } + return Status::OK(); +} + +// fetches next Barrier buffer row +Status BarrierOp::getNextTensorRow(TensorRow *new_row) { + // iterate over all iterators and generate a row + RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row)); + // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row + if (new_row->empty()) { + // If we did not get a row from any of the children, then it's the end of an epoch and we can move + // to drain state. + MS_LOG(INFO) << "Barrier operator child iterator produced empty row."; + clean_up_ = true; + // If we picked up an eof here, then we are completely done. + if ((child_iterator_)->eof_handled()) { + MS_LOG(INFO) << "Barrier operator iterator got EOF."; + eof_ = true; + } + return Status::OK(); + } + return Status::OK(); +} + +// A function that prints info about the Operator +void BarrierOp::Print(std::ostream &out, bool show_all) const { + // Call base class printer first + PipelineOp::Print(out, show_all); + out << "\nBarrierOp:\n" + << "\nCondition " << condition_name_ << "\n\n"; +} + +// overwrite function and handle eof +Status BarrierOp::EofReceived(int32_t) { + MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now."; + return Status::OK(); +} + +// overwrite function and handle eoe +Status BarrierOp::EoeReceived(int32_t) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8be55fba7ec28d8fc2b152e03065ac60dc71c76b --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h @@ -0,0 +1,172 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ +#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ + +#include +#include +#include +#include +#include +#include "dataset/core/tensor.h" +#include "dataset/engine/dataset_iterator.h" +#include "dataset/engine/datasetops/pipeline_op.h" +#include "dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +// Forward declare +class DataBuffer; +class ExecutionTree; + +// BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has +// been received. This signal is given from python layer. The current barrier design respects the +// rows per buffer design and will only output a buffer with rows once it has received rows per buffer +// signals from python. + +class BarrierOp : public PipelineOp { + public: + // The nested builder class inside of the BarrierOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @param int32_t op_connector_size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @param const std::string & condition_name + // @return Builder setter method returns reference to the builder. + Builder &SetConditionName(const std::string &condition_name) { + builder_condition_name_ = condition_name; + return *this; + } + + // Setter method. + // @param py::function condition_func - blocking condition function + // @return Builder setter method returns reference to the builder. + Builder &SetConditionFunc(py::function condition_func) { + builder_condition_func_ = condition_func; + return *this; + } + + // The builder "build" method creates the BarrierOp dataset Operator. + // @return shared_ptr to the new BarrierOp object + Status Build(std::shared_ptr *); + + private: + int32_t builder_rows_per_buffer_; + int32_t builder_op_connector_size_; + std::string builder_condition_name_; + py::function builder_condition_func_; + + Status SanityCheck() const; + }; + + // Constructor for BarrierOp + // @param rows_per_buffer - number of rows in output buffer + // @param op_connector_size - connector size + // @param condition_name - the condition name associated with this operator + // @param condition_func - the blocking condition check per row + // @note - currently rows_per_buffer should = 1 for barrier. 
+ // The reason is that other values would complicate how the pipeline behaves with other operators. + // One example is having batch after barrier: batch would be waiting for data, and a larger + // rows per buffer value can then result in hanging + BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, + py::function condition_func); + + // Destructor + ~BarrierOp(); + + Status EofReceived(int32_t) override; + + Status EoeReceived(int32_t) override; + + // Print function for Barrier + // @param out - output stream to print to + // @param show_all - if it should print everything + void Print(std::ostream &out, bool show_all) const override; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) { + bo.Print(out, false); + return out; + } + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - The error code return + Status operator()() override; + + // Handles preprocessing of the main loop, used when starting new epoch + // @param table - a table of tensors to be moved into a buffer + Status prepare(TensorQTable *const table); + + // This function repeatedly adds rows to the given table until a buffer is filled. + // @param table - a table of tensors to be moved into a buffer + Status fillBuffer(TensorQTable *const table); + + // Gets next tensor row and sets control signals + Status getNextTensorRow(TensorRow *new_row); + + // This function runs the wait function on condition + Status blockCond(); + + private: + // clean up variable to return incomplete buffer + bool clean_up_; + // end of file state, we stop reading data and shut down + bool eof_; + // rows per buffer + int32_t rows_per_buffer_; + // buffer_id + int32_t buffer_id_; + // local variable to keep track of the buffer information + std::unordered_map col_name_id_map_; + // iterator to pull new rows, we only have one child + std::unique_ptr child_iterator_; + // condition name, to support multiple barriers + std::string condition_name_; + // Function pointer of blocking function + py::function condition_function_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h index f14ecba733e9fbad4266f634b87a422e935a866c..04d8ab012174222586c4093d037f1fec7ce06f42 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h @@ -34,7 +34,7 @@ class DataBuffer; class ZipOp : public PipelineOp { public: - // The nested builder class inside of the BatchOp is used to help manage all of + // The nested builder class inside of the ZipOp is used to help manage all of // the arguments for constructing it. Use the builder by setting each argument // with the provided set methods, and then finally call the build method to execute // the actual construction.
@@ -76,8 +76,8 @@ class ZipOp : public PipelineOp { }; // Constructor for ZipOp - // @param rows_per_buffer number of rows in output buffer - // @param op_connector_size connector + // @param rows_per_buffer - number of rows in output buffer + // @param op_connector_size - connector size ZipOp(int32_t rows_per_buffer, int32_t op_connector_size); // Destructor @@ -88,8 +88,8 @@ class ZipOp : public PipelineOp { Status EoeReceived(int32_t) override; // Print function for Zip - // @param out output stream to print to - // @param show_all if it should print everything + // @param out - output stream to print to + // @param show_all - if it should print everything void Print(std::ostream &out, bool show_all) const override; // Provide stream operator for displaying it @@ -113,14 +113,14 @@ class ZipOp : public PipelineOp { Status fillBuffer(TensorQTable *const table); // Special handle case where an empty row has been received from child iterator - // @note we need to drain eoe signals from all children connectors. - // @details when this function is called, then we encountered eoe at child iterator + // @note - we need to drain eoe signals from all children connectors. + // @details - when this function is called, then we encountered eoe at child iterator // we have to drain rows from other child iterators until we hit eoe from all other child iterators Status drainPipeline(); // Merges 1 row from each childIterator together - // @param new_zip_row input and output, will return a non-empty row if all rows from childConnectors are non-empty - // @param updateColumnMapping generates a new column name to index mapping (mColNameIdMap) if set to true + // @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty + // @param updateColumnMapping - generates a new column name to index mapping (mColNameIdMap) if set to true // @details merge rows from iterator together. This is the main functionality for ZipOp // this function takes one row and fills it with tensors from rows fetched // from childIterators. 
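Before the Python-side changes below, a note on the synchronization mechanism: the C++ BarrierOp blocks on a Python condition callable for every row it forwards, and the Python layer implements that callable with a threading.Condition guarding a signed row counter (the BlockReleasePair class added to datasets.py below). A minimal standalone sketch of that counter-and-condition idea, with illustrative names only and no MindSpore dependencies:

import threading

class CounterGate:
    # Stand-in for BlockReleasePair: block() lets one row pass only while the
    # counter is negative; release() lowers the counter to admit more rows.
    def __init__(self, init_release_rows):
        self.row_count = -init_release_rows  # negative means rows may still pass
        self.cv = threading.Condition()

    def block(self):
        with self.cv:
            self.cv.wait_for(lambda: self.row_count < 0)
            self.row_count += 1  # one row consumed one permit
        return True

    def release(self, pass_rows=1):
        with self.cv:
            self.row_count -= pass_rows  # grant pass_rows more permits
            self.cv.notify_all()

The real BlockReleasePair additionally carries a callback, a reset for new epochs, and a batch-size multiplier, as shown in the datasets.py diff that follows.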
diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 855e4609bbd310b23c826cbbed5ebf812ce294b8..f67461eee3b31fa9b4afab41bb3ae1774dd12c13 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -26,6 +26,7 @@ import random import uuid from enum import Enum from importlib import import_module +import threading import numpy as np from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \ @@ -38,7 +39,7 @@ from .iterators import DictIterator, TupleIterator from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ - check_zip_dataset, check_add_column, check_textfiledataset + check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist try: @@ -139,6 +140,7 @@ class Dataset: self._batch_size = None self._num_classes = None self._repeat_count = None + self._sync = False def get_args(self): """ @@ -196,6 +198,30 @@ class Dataset: """ return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns) + @check_sync_wait + def sync_wait(self, condition_name, num_batch=1, callback=None): + ''' + Add a blocking condition to the input Dataset. + + Args: + input_dataset (Dataset): Input dataset to apply flow control. + num_batch (int): The number of batches without blocking at the start of each epoch. + condition_name (str): The condition name that is used to toggle sending the next row. + callback (function): The callback function that will be invoked when sync_update is called. + + Raises: + RuntimeError: If condition name already exists. + + Examples: + >>> import mindspore.dataset as ds + >>> # data is an instance of Dataset object. + >>> data = data.sync_wait("callback1") + >>> data = data.batch(batch_size) + >>> for batch_data in data.create_dict_iterator(): + >>> data.sync_update("callback1") + ''' + return SyncWaitDataset(self, condition_name, num_batch, callback) + @check_shuffle def shuffle(self, buffer_size): """ @@ -218,6 +244,9 @@ class Dataset: Returns: ShuffleDataset, dataset shuffled. + Raises: + RuntimeError: If sync operators exist before shuffle. + Examples: >>> import mindspore.dataset as ds >>> # data is an instance of Dataset object @@ -816,6 +845,9 @@ class Dataset: self._input_indexs = value def _get_pipeline_info(self): + """ + Gets pipeline information.
+ """ device_iter = TupleIterator(self) self._output_shapes = device_iter.get_output_shapes() self._output_types = device_iter.get_output_types() @@ -870,6 +902,30 @@ class Dataset: return self.input[0].num_classes() return None + def get_sync_notifiers(self): + if self.input: + return self.input[0].get_sync_notifiers() + return {} + + def is_sync(self): + if self.input: + return self.input[0].is_sync() + return False + + def sync_update(self, condition_name, num_batch=None, data=None): + """ + condition_name (str): The condition name that is used to toggle sending next row + step_size (int or None): The number of steps(rows) that are released + when pass_rows is None, will update the same number as sync_wait specified + data (dict or None): The data passed to the callback + """ + notifiers_dict = self.get_sync_notifiers() + if condition_name not in notifiers_dict: + raise RuntimeError("Condition name not found") + if num_batch is not None: + num_batch *= self.get_batch_size() + notifiers_dict[condition_name](num_batch, data) + def get_batch_size(self): """ Get the size of a batch. @@ -973,6 +1029,8 @@ class BatchDataset(DatasetOp): if BatchDataset._is_ancestor_of_repeat(input_dataset): logger.warning("Repeat is located before batch, data from two epochs can be batched together.") + BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) + self.batch_size = batch_size self.drop_remainder = drop_remainder self.per_batch_map = per_batch_map @@ -1029,6 +1087,20 @@ class BatchDataset(DatasetOp): flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset) return flag + @staticmethod + def _update_batch_size_for_syncwait(dataset, batch_size): + """ + Utility function to notify batch size to sync_wait. + + Args: + dataset (Dataset): dataset to be checked + batchsize (int): batch size to notify + """ + if isinstance(dataset, SyncWaitDataset): + dataset.update_sync_batch_size(batch_size) + for input_dataset in dataset.input: + BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) + class BatchInfo(CBatchInfo): """ @@ -1053,6 +1125,108 @@ class BatchInfo(CBatchInfo): """ return +class BlockReleasePair: + """ + The blocking condition class used by SyncWaitDataset + + Args: + init_release_rows (int): Number of lines to allow through the pipeline + callback (function): The callback funciton that will be called when release is called + """ + def __init__(self, init_release_rows, callback=None): + self.row_count = -init_release_rows + self.cv = threading.Condition() + self.callback = callback + self.default_rows = init_release_rows + + def __deepcopy__(self, memodict): + if id(self) in memodict: + return memodict[id(self)] + memodict[id(self)] = self + # condition variable and callback are the same, but reset the counter + self.reset() + return self + + def reset(self): + with self.cv: + self.row_count = -self.default_rows + self.cv.notify_all() + + def update_batched_size(self, batch_size): + # should only use before the pipeline creates + self.row_count *= batch_size + self.default_rows *= batch_size + + def block_func(self): + with self.cv: + self.cv.wait_for(lambda: self.row_count < 0) + self.row_count += 1 + return True + + def release_func(self, pass_rows=None, data=None): + with self.cv: + if pass_rows is None: + pass_rows = self.default_rows + self.row_count -= pass_rows + if self.callback is not None: + self.callback(data) + self.cv.notify_all() + +class SyncWaitDataset(DatasetOp): + """ + The result of adding a blocking condition to the input Dataset + + 
Args: + input_dataset (Dataset): Input dataset to apply flow control. + num_batch (int): The number of batches without blocking at the start of each epoch. + condition_name (str): The condition name that is used to toggle sending the next row. + callback (function): The callback function that will be invoked when sync_update is called. + + Raises: + RuntimeError: If condition name already exists. + """ + + def __init__(self, input_dataset, condition_name, num_batch, callback=None): + super().__init__() + self.input.append(input_dataset) + input_dataset.output.append(self) + # set to the default value, waiting for the batch to update it + self._condition_name = condition_name + self._pair = BlockReleasePair(num_batch, callback) + if self._condition_name in self.input[0].get_sync_notifiers(): + raise RuntimeError("Condition name is already in use") + + def get_sync_notifiers(self): + return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}} + + def is_sync(self): + return True + + def get_args(self): + args = super().get_args() + args["condition_name"] = self._condition_name + args["condition_func"] = self._pair.block_func + return args + + def update_sync_batch_size(self, batch_size): + self._pair.update_batched_size(batch_size) + + @staticmethod + def _is_ancestor_of_batch(dataset): + """ + Utility function to find the case where sync_wait is used before batch. + + Args: + dataset (Dataset): dataset to be checked + Returns: + True or False + """ + if isinstance(dataset, BatchDataset): + return True + flag = False + for input_dataset in dataset.input: + flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset) + return flag class ShuffleDataset(DatasetOp): """ @@ -1061,6 +1235,9 @@ class ShuffleDataset(DatasetOp): Args: input_dataset (Dataset): Input Dataset to be shuffled. buffer_size (int): The size of the buffer. + + Raises: + RuntimeError: If sync operators exist before shuffle.
""" def __init__(self, input_dataset, buffer_size): @@ -1069,6 +1246,8 @@ class ShuffleDataset(DatasetOp): self.input.append(input_dataset) input_dataset.output.append(self) self._input_indexs = input_dataset.input_indexs + if self.is_sync(): + raise RuntimeError("No shuffle after sync operators") def get_args(self): args = super().get_args() @@ -1335,6 +1514,9 @@ class ZipDataset(DatasetOp): """ return None + def is_sync(self): + return any([c.is_sync() for c in self.input]) + def get_args(self): args = super().get_args() return args diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 6af6c7dba8e5080a216ff1993b05883080a7f235..a8d20df5f338332423247404240d492c1b15b695 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -125,6 +125,8 @@ class Iterator: op_type = OpName.MINDRECORD elif isinstance(dataset, de.BatchDataset): op_type = OpName.BATCH + elif isinstance(dataset, de.SyncWaitDataset): + op_type = OpName.BARRIER elif isinstance(dataset, de.ZipDataset): op_type = OpName.ZIP elif isinstance(dataset, de.MapDataset): diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index a68d723f1d359f4df69b85b2ea5e788739334132..a8d18ab2c106878c26da35ef5329400c1885aaa4 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -652,6 +652,22 @@ def check_batch(method): return new_method +def check_sync_wait(method): + """check the input arguments of sync_wait.""" + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + nreq_param_str = ['condition_name'] + nreq_param_int = ['step_size'] + + check_param_type(nreq_param_int, param_dict, int) + + check_param_type(nreq_param_str, param_dict, str) + + return method(*args, **kwargs) + + return new_method def check_shuffle(method): """check the input arguments of shuffle.""" diff --git a/tests/ut/python/dataset/test_config.py b/tests/ut/python/dataset/test_config.py index 8cabe81aaa4724fe039586b22335f665346fe03c..0c1e0073af75291be3d586068432371059049c92 100644 --- a/tests/ut/python/dataset/test_config.py +++ b/tests/ut/python/dataset/test_config.py @@ -12,8 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +""" +Testing configuration manager +""" +import filecmp +import glob +import os + import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as vision +from mindspore import log as logger +DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] +SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" def test_basic(): ds.config.load('../data/dataset/declient.cfg') @@ -36,6 +46,34 @@ def test_basic(): assert ds.config.get_prefetch_size() == 4 assert ds.config.get_seed() == 5 +def test_pipeline(): + """ + Test that our configuration pipeline works when we set parameters in between dataset operations + """ + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + ds.config.set_num_parallel_workers(2) + data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)]) + ds.serialize(data1, "testpipeline.json") + + data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + ds.config.set_num_parallel_workers(4) + data2 = data2.map(input_columns=["image"], operations=[vision.Decode(True)]) + ds.serialize(data2, "testpipeline2.json") + + # check the generated output; the two pipelines currently serialize identically + assert (filecmp.cmp('testpipeline.json', 'testpipeline2.json')) + + # this test passes currently because num_parallel_workers does not get updated during serialization. + + # remove generated json files + file_list = glob.glob('*.json') + for f in file_list: + try: + os.remove(f) + except IOError: + logger.info("Error while deleting: {}".format(f)) + if __name__ == '__main__': test_basic() + test_pipeline() diff --git a/tests/ut/python/dataset/test_sync_wait.py b/tests/ut/python/dataset/test_sync_wait.py new file mode 100644 index 0000000000000000000000000000000000000000..277499d9ae0506b98faa64f1a3e1b1ee95be3df5 --- /dev/null +++ b/tests/ut/python/dataset/test_sync_wait.py @@ -0,0 +1,182 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== + +import mindspore.dataset as ds +from mindspore import log as logger +import time +import numpy as np + + +def gen(): + for i in range(100): + yield np.array(i), + + +class Augment: + def __init__(self, loss): + self.loss = loss + + def preprocess(self, input): + return input + + def update(self, data): + self.loss = data["loss"] + + +def test_simple_sync_wait(): + """ + Test simple sync wait: test sync in dataset pipeline + """ + logger.info("test_simple_sync_wait") + batch_size = 4 + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + dataset = dataset.batch(batch_size) + + count = 0 + for data in dataset.create_dict_iterator(): + assert (data["input"][0] == count) + count += batch_size + data = {"loss": count} + dataset.sync_update(condition_name="policy", data=data) + + +def test_simple_shuffle_sync(): + """ + Test simple shuffle sync: test shuffle before sync + """ + logger.info("test_simple_shuffle_sync") + shuffle_size = 4 + batch_size = 10 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = dataset.shuffle(shuffle_size) + dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + dataset = dataset.batch(batch_size) + + count = 0 + for data in dataset.create_dict_iterator(): + count += 1 + #time.sleep(0.5) + data = {"loss": count} + dataset.sync_update(condition_name="policy", data=data) + + +def test_two_sync(): + """ + Test two sync: dataset pipeline with two sync operators + """ + logger.info("test_two_sync") + batch_size = 6 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + # notice that with our design, we need to have step_size = shuffle size + dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) + + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + + dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches") + + dataset = dataset.batch(batch_size) + + count = 0 + for data in dataset.create_dict_iterator(): + count += 1 + data = {"loss": count} + dataset.sync_update(condition_name="every batch", data=data) + if count % 2 == 0: + dataset.sync_update(condition_name="every 2 batches") + +def test_sync_epoch(): + """ + Test sync wait with epochs: test sync with epochs in dataset pipeline + """ + logger.info("test_sync_epoch") + batch_size = 30 + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + dataset = dataset.batch(batch_size, drop_remainder=True) + + for epochs in range(3): + aug.update({"loss": 0}) + count = 0 + for data in dataset.create_dict_iterator(): + assert (data["input"][0] == count) + count += batch_size + data = {"loss": count} + dataset.sync_update(condition_name="policy", data=data) + + +def test_sync_exception_01(): + """ + Test sync: with shuffle in sync mode + """ + logger.info("test_sync_exception_01") + shuffle_size = 4 + batch_size = 10 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = 
dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + + try: + dataset = dataset.shuffle(shuffle_size) + except BaseException as e: + assert "shuffle" in str(e) + dataset = dataset.batch(batch_size) + + +def test_sync_exception_02(): + """ + Test sync: with duplicated condition name + """ + logger.info("test_sync_exception_02") + batch_size = 6 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + # notice that with our design, we need to have step_size = shuffle size + dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) + + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + + try: + dataset = dataset.sync_wait(num_batch=2, condition_name="every batch") + except BaseException as e: + assert "name" in str(e) + dataset = dataset.batch(batch_size) + + +if __name__ == "__main__": + test_simple_sync_wait() + test_simple_shuffle_sync() + test_two_sync() + test_sync_exception_01() + test_sync_exception_02() + test_sync_epoch() \ No newline at end of file
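For reference, an end-to-end usage sketch of the new sync_wait/sync_update API, condensed from the sync_wait docstring and test_simple_sync_wait above (the generator source and column name are illustrative):

import numpy as np
import mindspore.dataset as ds

def gen():
    # toy generator producing single-element rows
    for i in range(100):
        yield (np.array(i),)

data = ds.GeneratorDataset(gen, column_names=["input"])
data = data.sync_wait(condition_name="policy")  # barrier inserted into the pipeline here
data = data.batch(4)                            # batch notifies the barrier of its size

for batch in data.create_dict_iterator():
    # ... consume the batch, e.g. compute a loss from batch["input"] ...
    data.sync_update(condition_name="policy")   # release the next batch worth of rows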