提交 dc0491ca 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!508 [Dataset] Adding sync_wait operator for dataset

Merge pull request !508 from EricZ/master
......@@ -48,6 +48,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{kMap, &DEPipeline::ParseMapOp},
{kFilter, &DEPipeline::ParseFilterOp},
{kBatch, &DEPipeline::ParseBatchOp},
{kBarrier, &DEPipeline::ParseBarrierOp},
{kRepeat, &DEPipeline::ParseRepeatOp},
{kSkip, &DEPipeline::ParseSkipOp},
{kZip, &DEPipeline::ParseZipOp},
......@@ -627,6 +628,30 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
return Status::OK();
}
Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>();
// Right now barrier should only take num_rows_per_buffer = 1
// The reason for this is because having it otherwise can lead to blocking issues
// See barrier_op.h for more details
(void)builder->SetRowsPerBuffer(1);
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "condition_name") {
(void)builder->SetConditionName(ToString(value));
} else if (key == "condition_func") {
(void)builder->SetConditionFunc(value.cast<py::function>());
}
}
}
std::shared_ptr<BarrierOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op;
return Status::OK();
}
Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
int32_t prefetch_size = 0;
if (args.contains("prefetch_size")) {
......
......@@ -40,6 +40,7 @@ enum OpName {
kShuffle,
kMindrecord,
kBatch,
kBarrier,
kCache,
kRepeat,
kSkip,
......@@ -115,6 +116,8 @@ class DEPipeline {
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
......
......@@ -481,6 +481,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("STORAGE", OpName::kStorage)
.value("SHUFFLE", OpName::kShuffle)
.value("BATCH", OpName::kBatch)
.value("BARRIER", OpName::kBarrier)
.value("MINDRECORD", OpName::kMindrecord)
.value("CACHE", OpName::kCache)
.value("REPEAT", OpName::kRepeat)
......
......@@ -25,6 +25,7 @@
#include "dataset/core/tensor_shape.h"
#include "dataset/engine/data_schema.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/barrier_op.h"
#include "dataset/engine/datasetops/batch_op.h"
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/datasetops/device_queue_op.h"
......
......@@ -4,6 +4,7 @@ add_library(engine-datasetops OBJECT
dataset_op.cc
parallel_op.cc
pipeline_op.cc
barrier_op.cc
batch_op.cc
device_queue_op.cc
map_op.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/barrier_op.h"
#include <utility>
#include "dataset/core/constants.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
BarrierOp::Builder::Builder() {
// Some arguments to the BarrierOp constructor have a default argument that is taken
// from the client config.
// The user may choose to change these values for the construction of the BarrierOp by
// using the various builder set methods.
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_rows_per_buffer_ = cfg->rows_per_buffer();
builder_op_connector_size_ = cfg->op_connector_size();
}
Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); }
Status BarrierOp::Builder::Build(std::shared_ptr<BarrierOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<BarrierOp>(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_,
builder_condition_func_);
return Status::OK();
}
// Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions
BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
py::function condition_func)
: PipelineOp(op_connector_size),
rows_per_buffer_(rows_per_buffer),
buffer_id_(0),
clean_up_(false),
eof_(false),
condition_name_(condition_name),
condition_function_(condition_func) {}
// destructor
BarrierOp::~BarrierOp() {}
// Entry point for Barrier, called by launch()
Status BarrierOp::operator()() {
// The children_num_ parameter needs to be put here
// Synchronize with TaskManager once the thread is created.
TaskManager::FindMe()->Post();
// create child iterator, right now this barrier is a pipeline operator
int32_t worker_id = 0;
int32_t child_idx = 0;
child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx);
// Loop until eof is true
while (!eof_) {
// Create new table to put the new tensor rows
std::unique_ptr<TensorQTable> curr_table = std::make_unique<TensorQTable>();
RETURN_IF_NOT_OK(prepare(curr_table.get()));
// If an eof got picked up during the above prepare, then we're done
if (eof_) {
break;
}
// we have to output new buffer with possibly different buffer size, possibly one row
while (!clean_up_) {
// 1. If a previous loop iteration sent the current table out, then create a new one.
if (curr_table == nullptr) {
curr_table = std::make_unique<TensorQTable>();
}
// 2 fill the table. Note: clean_up mode might get turned on if epoch is finished
RETURN_IF_NOT_OK(fillBuffer(curr_table.get()));
// 3 create and update buffer and send it to the out connector
if (!curr_table->empty()) {
std::unique_ptr<DataBuffer> curr_buffer = std::make_unique<DataBuffer>(buffer_id_, DataBuffer::kDeBFlagNone);
curr_buffer->set_tensor_table(std::move(curr_table));
curr_buffer->set_column_name_map(col_name_id_map_);
MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols "
<< curr_buffer->NumCols() << ", map " << col_name_id_map_.size() << ".";
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
buffer_id_++;
}
}
// 4 handle drain state.
if (clean_up_) {
MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal.";
// Send the eoe up.
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
}
}
// 5 handle eof
// propagate eof here.
MS_LOG(INFO) << "Barrier operator got EOF, propagating.";
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
return Status::OK();
}
// Handles preprocessing of the main loop, used when starting new epoch
Status BarrierOp::prepare(TensorQTable *const table) {
MS_LOG(DEBUG) << "Barrier operator prepares for new epoch.";
clean_up_ = false;
buffer_id_ = 0;
if (table == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table.");
}
// fill initial row
TensorRow new_row = {};
// use iterator to get next row and invoke pyfunc wait
RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
// If the first row fetching resulted in eof, then we are done.
if (eof_) {
return Status::OK();
}
if (new_row.empty()) {
// This epoch is empty
return Status::OK();
}
// Pack this first row into our tensor table
// first row we also have to check if we should block
RETURN_IF_NOT_OK(blockCond());
table->push_back(std::move(new_row));
// At this point we have 1 row produced, we take the old column map id and use it in the new table
// Initializing col_name_id_map_ from the first data buffer.
col_name_id_map_ = child_iterator_->col_name_id_map();
// the update code below shouldn't do anything bad if the column name already exists.
return Status::OK();
}
// fillBuffer always expects a new table to fill
Status BarrierOp::fillBuffer(TensorQTable *const table) {
if (table == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer.");
}
TensorRow new_row = {};
while (table->size() < static_cast<size_t>(rows_per_buffer_)) {
RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
// Early exit the loop if we got empty row from any of our child iterations
if (new_row.empty()) {
return Status::OK();
}
// else we got a row so pack it into the tensor table.
RETURN_IF_NOT_OK(blockCond());
table->push_back(std::move(new_row));
}
return Status::OK();
}
// function executes a py_func and blocks until condition becomes true.
Status BarrierOp::blockCond() {
{
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
}
// we have condition name, however the flexibility is in python today
try {
// Invoke python function
py::object ret_py_obj = condition_function_();
// Process the return value
if (!py::isinstance<py::bool_>(ret_py_obj)) {
return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false");
}
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
}
}
return Status::OK();
}
// fetches next Barrier buffer row
Status BarrierOp::getNextTensorRow(TensorRow *new_row) {
// iterate over all iterators and generate a row
RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row));
// add each new row to iterator, check if row is empty, if row from iterator is empty return empty row
if (new_row->empty()) {
// If we did not get a row from any of the children, then it's the end of an epoch and we can move
// to drain state.
MS_LOG(INFO) << "Barrier operator child iterator produced empty row.";
clean_up_ = true;
// If we picked up an eof here, then we are completely done.
if ((child_iterator_)->eof_handled()) {
MS_LOG(INFO) << "Barrier operator iterator got EOF.";
eof_ = true;
}
return Status::OK();
}
return Status::OK();
}
// A function that prints info about the Operator
void BarrierOp::Print(std::ostream &out, bool show_all) const {
// Call base class printer first
PipelineOp::Print(out, show_all);
out << "\nBarrierOp:\n"
<< "\nCondition " << condition_name_ << "\n\n";
}
// overwrite function and handle eof
Status BarrierOp::EofReceived(int32_t) {
MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now.";
return Status::OK();
}
// overwrite function and handle eoe
Status BarrierOp::EoeReceived(int32_t) {
state_ = OpState::kDeOpIdle;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_
#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/pipeline_op.h"
#include "dataset/kernels/tensor_op.h"
namespace mindspore {
namespace dataset {
// Forward declare
class DataBuffer;
class ExecutionTree;
// BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has
// been received. This signal is given from python layer. The current barrier design respects the
// rows per buffer design and will only output a buffer with rows once it has received rows per buffer
// signals from python.
class BarrierOp : public PipelineOp {
public:
// The nested builder class inside of the BarrierOp is used to help manage all of
// the arguments for constructing it. Use the builder by setting each argument
// with the provided set methods, and then finally call the build method to execute
// the actual construction.
class Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @return This is a constructor.
Builder();
// Default destructor
~Builder() = default;
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
builder_rows_per_buffer_ = rows_per_buffer;
return *this;
}
// Setter method.
// @param int32_t op_connector_size
// @return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t op_connector_size) {
builder_op_connector_size_ = op_connector_size;
return *this;
}
// Setter method.
// @param const std::string & condition_name
// @return Builder setter method returns reference to the builder.
Builder &SetConditionName(const std::string &condition_name) {
builder_condition_name_ = condition_name;
return *this;
}
// Setter method.
// @param py::function condition_func - blocking condition function
// @return Builder setter method returns reference to the builder.
Builder &SetConditionFunc(py::function condition_func) {
builder_condition_func_ = condition_func;
return *this;
}
// The builder "build" method creates the BarrierOp dataset Operator.
// @return shared_ptr to the new BarrierOp object
Status Build(std::shared_ptr<BarrierOp> *);
private:
int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_;
std::string builder_condition_name_;
py::function builder_condition_func_;
Status SanityCheck() const;
};
// Constructor for BarrierOp
// @param rows_per_buffer - number of rows in output buffer
// @param op_connector_size - connector size
// @param condition_name - the condition name associated with this operator
// @param condition_func - the blocking condition check per row
// @note - currently rows_per_buffer should = 1 for barrier.
// The reason for this is having other values would complicate how the pipeline behaves with other operators
// One example of such case is having batch after barrier. Batch would be waiting for data and having
// rows per buffer in this case can result in hanging
BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
py::function condition_func);
// Destructor
~BarrierOp();
Status EofReceived(int32_t) override;
Status EoeReceived(int32_t) override;
// Print function for Barrier
// @param out - output stream to print to
// @param show_all - if it should print everything
void Print(std::ostream &out, bool show_all) const override;
// Provide stream operator for displaying it
friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) {
bo.Print(out, false);
return out;
}
// Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - The error code return
Status operator()() override;
// Handles preprocessing of the main loop, used when starting new epoch
// @param table - a table of tensors to be moved into a buffer
Status prepare(TensorQTable *const table);
// This function calls takes a table repeatedly adds rows to it.
// @param table - a table of tensors to be moved into a buffer
Status fillBuffer(TensorQTable *const table);
// Gets next tensor row and sets control signals
Status getNextTensorRow(TensorRow *new_row);
// This function runs the wait function on condition
Status blockCond();
private:
// clean up variable to return imcomplete buffer
bool clean_up_;
// end of file state, we stop reading data and shut down
bool eof_;
// rows per buffer
int32_t rows_per_buffer_;
// buffer_id
int32_t buffer_id_;
// local variable to keep track of the buffer information
std::unordered_map<std::string, int32_t> col_name_id_map_;
// iterator to pull new rows, we only have one child
std::unique_ptr<ChildIterator> child_iterator_;
// condition name, to support multiple barriers
std::string condition_name_;
// Function pointer of blocking function
py::function condition_function_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_
......@@ -34,7 +34,7 @@ class DataBuffer;
class ZipOp : public PipelineOp {
public:
// The nested builder class inside of the BatchOp is used to help manage all of
// The nested builder class inside of the ZipOp is used to help manage all of
// the arguments for constructing it. Use the builder by setting each argument
// with the provided set methods, and then finally call the build method to execute
// the actual construction.
......@@ -76,8 +76,8 @@ class ZipOp : public PipelineOp {
};
// Constructor for ZipOp
// @param rows_per_buffer number of rows in output buffer
// @param op_connector_size connector
// @param rows_per_buffer - number of rows in output buffer
// @param op_connector_size - connector size
ZipOp(int32_t rows_per_buffer, int32_t op_connector_size);
// Destructor
......@@ -88,8 +88,8 @@ class ZipOp : public PipelineOp {
Status EoeReceived(int32_t) override;
// Print function for Zip
// @param out output stream to print to
// @param show_all if it should print everything
// @param out - output stream to print to
// @param show_all - if it should print everything
void Print(std::ostream &out, bool show_all) const override;
// Provide stream operator for displaying it
......@@ -113,14 +113,14 @@ class ZipOp : public PipelineOp {
Status fillBuffer(TensorQTable *const table);
// Special handle case where an empty row has been received from child iterator
// @note we need to drain eoe signals from all children connectors.
// @details when this function is called, then we encountered eoe at child iterator
// @note - we need to drain eoe signals from all children connectors.
// @details - when this function is called, then we encountered eoe at child iterator
// we have to drain rows from other child iterators until we hit eoe from all other child iterators
Status drainPipeline();
// Merges 1 row from each childIterator together
// @param new_zip_row input and output, will return a non-empty row if all rows from childConnectors are non-empty
// @param updateColumnMapping generates a new column name to index mapping (mColNameIdMap) if set to true
// @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty
// @param updateColumnMapping - generates a new column name to index mapping (mColNameIdMap) if set to true
// @details merge rows from iterator together. This is the main functionality for ZipOp
// this function takes one row and fills it with tensors from rows fetched
// from childIterators.
......
......@@ -28,6 +28,7 @@ import multiprocessing
import queue
from enum import Enum
from importlib import import_module
import threading
import numpy as np
from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
......@@ -40,7 +41,7 @@ from .iterators import DictIterator, TupleIterator
from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
check_zip_dataset, check_add_column, check_textfiledataset
check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
try:
......@@ -141,6 +142,7 @@ class Dataset:
self._batch_size = None
self._num_classes = None
self._repeat_count = None
self._sync = False
def get_args(self):
"""
......@@ -198,6 +200,30 @@ class Dataset:
"""
return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns)
@check_sync_wait
def sync_wait(self, condition_name, num_batch=1, callback=None):
'''
Add a blocking condition to the input Dataset
Args:
input_dataset (Dataset): Input dataset to apply flow control
num_batch (int): the number of batches without blocking at the start of each epoch
condition_name (str): The condition name that is used to toggle sending next row
callback (function): The callback funciton that will be invoked when sync_update is called
Raises:
RuntimeError: If condition name already exists.
Examples:
>>> import mindspore.dataset as ds
>>> # data is an instance of Dataset object.
>>> data = data.sync_wait("callback1")
>>> data = data.batch(batch_size)
>>> for batch_data in data.create_dict_iterator():
>>> data = data.sync_update("callback1")
'''
return SyncWaitDataset(self, condition_name, num_batch, callback)
@check_shuffle
def shuffle(self, buffer_size):
"""
......@@ -220,6 +246,9 @@ class Dataset:
Returns:
ShuffleDataset, dataset shuffled.
Raises:
RuntimeError: If exist sync operators before shuffle.
Examples:
>>> import mindspore.dataset as ds
>>> # data is an instance of Dataset object
......@@ -821,6 +850,9 @@ class Dataset:
self._input_indexs = value
def _get_pipeline_info(self):
"""
Gets pipeline information.
"""
device_iter = TupleIterator(self)
self._output_shapes = device_iter.get_output_shapes()
self._output_types = device_iter.get_output_types()
......@@ -875,6 +907,30 @@ class Dataset:
return self.input[0].num_classes()
return None
def get_sync_notifiers(self):
if self.input:
return self.input[0].get_sync_notifiers()
return {}
def is_sync(self):
if self.input:
return self.input[0].is_sync()
return False
def sync_update(self, condition_name, num_batch=None, data=None):
"""
condition_name (str): The condition name that is used to toggle sending next row
step_size (int or None): The number of steps(rows) that are released
when pass_rows is None, will update the same number as sync_wait specified
data (dict or None): The data passed to the callback
"""
notifiers_dict = self.get_sync_notifiers()
if condition_name not in notifiers_dict:
raise RuntimeError("Condition name not found")
if num_batch is not None:
num_batch *= self.get_batch_size()
notifiers_dict[condition_name](num_batch, data)
def get_batch_size(self):
"""
Get the size of a batch.
......@@ -978,6 +1034,8 @@ class BatchDataset(DatasetOp):
if BatchDataset._is_ancestor_of_repeat(input_dataset):
logger.warning("Repeat is located before batch, data from two epochs can be batched together.")
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
self.batch_size = batch_size
self.drop_remainder = drop_remainder
self.per_batch_map = per_batch_map
......@@ -1034,6 +1092,20 @@ class BatchDataset(DatasetOp):
flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
return flag
@staticmethod
def _update_batch_size_for_syncwait(dataset, batch_size):
"""
Utility function to notify batch size to sync_wait.
Args:
dataset (Dataset): dataset to be checked
batchsize (int): batch size to notify
"""
if isinstance(dataset, SyncWaitDataset):
dataset.update_sync_batch_size(batch_size)
for input_dataset in dataset.input:
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
class BatchInfo(CBatchInfo):
"""
......@@ -1058,6 +1130,108 @@ class BatchInfo(CBatchInfo):
"""
return
class BlockReleasePair:
"""
The blocking condition class used by SyncWaitDataset
Args:
init_release_rows (int): Number of lines to allow through the pipeline
callback (function): The callback funciton that will be called when release is called
"""
def __init__(self, init_release_rows, callback=None):
self.row_count = -init_release_rows
self.cv = threading.Condition()
self.callback = callback
self.default_rows = init_release_rows
def __deepcopy__(self, memodict):
if id(self) in memodict:
return memodict[id(self)]
memodict[id(self)] = self
# condition variable and callback are the same, but reset the counter
self.reset()
return self
def reset(self):
with self.cv:
self.row_count = -self.default_rows
self.cv.notify_all()
def update_batched_size(self, batch_size):
# should only use before the pipeline creates
self.row_count *= batch_size
self.default_rows *= batch_size
def block_func(self):
with self.cv:
self.cv.wait_for(lambda: self.row_count < 0)
self.row_count += 1
return True
def release_func(self, pass_rows=None, data=None):
with self.cv:
if pass_rows is None:
pass_rows = self.default_rows
self.row_count -= pass_rows
if self.callback is not None:
self.callback(data)
self.cv.notify_all()
class SyncWaitDataset(DatasetOp):
"""
The result of adding a blocking condition to the input Dataset
Args:
input_dataset (Dataset): Input dataset to apply flow control
num_batch (int): the number of batches without blocking at the start of each epoch
condition_name (str): The condition name that is used to toggle sending next row
callback (function): The callback funciton that will be invoked when sync_update is called
Raises:
RuntimeError: If condition name already exists.
"""
def __init__(self, input_dataset, condition_name, num_batch, callback=None):
super().__init__()
self.input.append(input_dataset)
input_dataset.output.append(self)
# set to the default value, waiting for the batch to update it
self._condition_name = condition_name
self._pair = BlockReleasePair(num_batch, callback)
if self._condition_name in self.input[0].get_sync_notifiers():
raise RuntimeError("Condition name is already in use")
def get_sync_notifiers(self):
return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
def is_sync(self):
return True
def get_args(self):
args = super().get_args()
args["condition_name"] = self._condition_name
args["condition_func"] = self._pair.block_func
return args
def update_sync_batch_size(self, batch_size):
self._pair.update_batched_size(batch_size)
@staticmethod
def _is_ancestor_of_batch(dataset):
"""
Utility function to find the case where sync_wait is used before batch.
Args:
dataset (Dataset): dataset to be checked
Return:
True or False
"""
if isinstance(dataset, BatchDataset):
return True
flag = False
for input_dataset in dataset.input:
flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
return flag
class ShuffleDataset(DatasetOp):
"""
......@@ -1066,6 +1240,9 @@ class ShuffleDataset(DatasetOp):
Args:
input_dataset (Dataset): Input Dataset to be shuffled.
buffer_size (int): The size of the buffer.
Raises:
RuntimeError: If exist sync operators before shuffle.
"""
def __init__(self, input_dataset, buffer_size):
......@@ -1074,6 +1251,8 @@ class ShuffleDataset(DatasetOp):
self.input.append(input_dataset)
input_dataset.output.append(self)
self._input_indexs = input_dataset.input_indexs
if self.is_sync():
raise RuntimeError("No shuffle after sync operators")
def get_args(self):
args = super().get_args()
......@@ -1427,6 +1606,9 @@ class ZipDataset(DatasetOp):
"""
return None
def is_sync(self):
return any([c.is_sync() for c in self.input])
def get_args(self):
args = super().get_args()
return args
......
......@@ -129,6 +129,8 @@ class Iterator:
op_type = OpName.MINDRECORD
elif isinstance(dataset, de.BatchDataset):
op_type = OpName.BATCH
elif isinstance(dataset, de.SyncWaitDataset):
op_type = OpName.BARRIER
elif isinstance(dataset, de.ZipDataset):
op_type = OpName.ZIP
elif isinstance(dataset, de.MapDataset):
......
......@@ -652,6 +652,22 @@ def check_batch(method):
return new_method
def check_sync_wait(method):
"""check the input arguments of sync_wait."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
nreq_param_str = ['condition_name']
nreq_param_int = ['step_size']
check_param_type(nreq_param_int, param_dict, int)
check_param_type(nreq_param_str, param_dict, str)
return method(*args, **kwargs)
return new_method
def check_shuffle(method):
"""check the input arguments of shuffle."""
......
......@@ -12,8 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing configuration manager
"""
import filecmp
import glob
import os
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_basic():
ds.config.load('../data/dataset/declient.cfg')
......@@ -36,6 +46,34 @@ def test_basic():
assert ds.config.get_prefetch_size() == 4
assert ds.config.get_seed() == 5
def test_pipeline():
"""
Test that our configuration pipeline works when we set parameters at dataset interval
"""
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
ds.config.set_num_parallel_workers(2)
data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
ds.serialize(data1, "testpipeline.json")
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
ds.config.set_num_parallel_workers(4)
data2 = data2.map(input_columns=["image"], operations=[vision.Decode(True)])
ds.serialize(data2, "testpipeline2.json")
# check that the generated output is different
assert (filecmp.cmp('testpipeline.json', 'testpipeline2.json'))
# this test passes currently because our num_parallel_workers don't get updated.
# remove generated jason files
file_list = glob.glob('*.json')
for f in file_list:
try:
os.remove(f)
except IOError:
logger.info("Error while deleting: {}".format(f))
if __name__ == '__main__':
test_basic()
test_pipeline()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
import time
import numpy as np
def gen():
for i in range(100):
yield np.array(i),
class Augment:
def __init__(self, loss):
self.loss = loss
def preprocess(self, input):
return input
def update(self, data):
self.loss = data["loss"]
def test_simple_sync_wait():
"""
Test simple sync wait: test sync in dataset pipeline
"""
logger.info("test_simple_sync_wait")
batch_size = 4
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.batch(batch_size)
count = 0
for data in dataset.create_dict_iterator():
assert (data["input"][0] == count)
count += batch_size
data = {"loss": count}
dataset.sync_update(condition_name="policy", data=data)
def test_simple_shuffle_sync():
"""
Test simple shuffle sync: test shuffle before sync
"""
logger.info("test_simple_shuffle_sync")
shuffle_size = 4
batch_size = 10
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
dataset = dataset.shuffle(shuffle_size)
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.batch(batch_size)
count = 0
for data in dataset.create_dict_iterator():
count += 1
#time.sleep(0.5)
data = {"loss": count}
dataset.sync_update(condition_name="policy", data=data)
def test_two_sync():
"""
Test two sync: dataset pipeline with with two sync_operators
"""
logger.info("test_two_sync")
batch_size = 6
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
# notice that with our design, we need to have step_size = shuffle size
dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches")
dataset = dataset.batch(batch_size)
count = 0
for data in dataset.create_dict_iterator():
count += 1
data = {"loss": count}
dataset.sync_update(condition_name="every batch", data=data)
if count % 2 == 0:
dataset.sync_update(condition_name="every 2 batches")
def test_sync_epoch():
"""
Test sync wait with epochs: test sync with epochs in dataset pipeline
"""
logger.info("test_sync_epoch")
batch_size = 30
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.batch(batch_size, drop_remainder=True)
for epochs in range(3):
aug.update({"loss": 0})
count = 0
for data in dataset.create_dict_iterator():
assert (data["input"][0] == count)
count += batch_size
data = {"loss": count}
dataset.sync_update(condition_name="policy", data=data)
def test_sync_exception_01():
"""
Test sync: with shuffle in sync mode
"""
logger.info("test_sync_exception_01")
shuffle_size = 4
batch_size = 10
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
try:
dataset = dataset.shuffle(shuffle_size)
except BaseException as e:
assert "shuffle" in str(e)
dataset = dataset.batch(batch_size)
def test_sync_exception_02():
"""
Test sync: with duplicated condition name
"""
logger.info("test_sync_exception_02")
batch_size = 6
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
# notice that with our design, we need to have step_size = shuffle size
dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
try:
dataset = dataset.sync_wait(num_batch=2, condition_name="every batch")
except BaseException as e:
assert "name" in str(e)
dataset = dataset.batch(batch_size)
if __name__ == "__main__":
test_simple_sync_wait()
test_simple_shuffle_sync()
test_two_sync()
test_sync_exception_01()
test_sync_exception_02()
test_sync_epoch()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册