Commit cd945187 authored by eric

# This is a combination of 2 commits.

Initial commit for dataset op python

Added signature to barrier

Added compiling barrier code

Rebasing, fixed new compile errors

Final fix for make_unique

Added pybind API for barrier

Fixed pyfunc invocation

python interface - sync_wait

!1 sync_wait python interface
* python interface - sync_wait

fix test

update test

update test

Added new test case

add test case

test for shuffle + batch

Added two-sync test case

Restricted: no shuffle allowed after sync

Added sync to pipeline info

Block first DataBuffer as well

Intelligently get batch size

Fix default case

Lock pair shared among all iterators

Added fix for empty character

Fixed up test case formatting

Fix end of epoch in sync_wait

Fixing CI
Parent 4a0b2b4a
......@@ -48,6 +48,7 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {{kStorage, &D
{kMap, &DEPipeline::ParseMapOp},
{kFilter, &DEPipeline::ParseFilterOp},
{kBatch, &DEPipeline::ParseBatchOp},
{kBarrier, &DEPipeline::ParseBarrierOp},
{kRepeat, &DEPipeline::ParseRepeatOp},
{kSkip, &DEPipeline::ParseSkipOp},
{kZip, &DEPipeline::ParseZipOp},
......@@ -627,6 +628,30 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp>
return Status::OK();
}
Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
std::shared_ptr<BarrierOp::Builder> builder = std::make_shared<BarrierOp::Builder>();
// Right now barrier should only take num_rows_per_buffer = 1
// The reason is that other values can lead to blocking issues
// See barrier_op.h for more details
(void)builder->SetRowsPerBuffer(1);
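// The args dict is built on the Python side by SyncWaitDataset.get_args(), which supplies
// condition_name and condition_func.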
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "condition_name") {
(void)builder->SetConditionName(ToString(value));
} else if (key == "condition_func") {
(void)builder->SetConditionFunc(value.cast<py::function>());
}
}
}
std::shared_ptr<BarrierOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*ptr = op;
return Status::OK();
}
Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr) {
int32_t prefetch_size = 0;
if (args.contains("prefetch_size")) {
......
......@@ -40,6 +40,7 @@ enum OpName {
kShuffle,
kMindrecord,
kBatch,
kBarrier,
kCache,
kRepeat,
kSkip,
......@@ -115,6 +116,8 @@ class DEPipeline {
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseRenameOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
......
......@@ -476,6 +476,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("STORAGE", OpName::kStorage)
.value("SHUFFLE", OpName::kShuffle)
.value("BATCH", OpName::kBatch)
.value("BARRIER", OpName::kBarrier)
.value("MINDRECORD", OpName::kMindrecord)
.value("CACHE", OpName::kCache)
.value("REPEAT", OpName::kRepeat)
......
......@@ -25,6 +25,7 @@
#include "dataset/core/tensor_shape.h"
#include "dataset/engine/data_schema.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/barrier_op.h"
#include "dataset/engine/datasetops/batch_op.h"
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/datasetops/device_queue_op.h"
......
......@@ -4,6 +4,7 @@ add_library(engine-datasetops OBJECT
dataset_op.cc
parallel_op.cc
pipeline_op.cc
barrier_op.cc
batch_op.cc
device_queue_op.cc
map_op.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/barrier_op.h"
#include <utility>
#include "dataset/core/constants.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
BarrierOp::Builder::Builder() {
// Some arguments to the BarrierOp constructor have a default argument that is taken
// from the client config.
// The user may choose to change these values for the construction of the BarrierOp by
// using the various builder set methods.
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_rows_per_buffer_ = cfg->rows_per_buffer();
builder_op_connector_size_ = cfg->op_connector_size();
}
Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); }
Status BarrierOp::Builder::Build(std::shared_ptr<BarrierOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<BarrierOp>(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_,
builder_condition_func_);
return Status::OK();
}
// Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions
BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
py::function condition_func)
: PipelineOp(op_connector_size),
rows_per_buffer_(rows_per_buffer),
buffer_id_(0),
clean_up_(false),
eof_(false),
condition_name_(condition_name),
condition_function_(condition_func) {}
// destructor
BarrierOp::~BarrierOp() {}
// Entry point for Barrier, called by launch()
Status BarrierOp::operator()() {
// The children_num_ parameter needs to be put here
// Synchronize with TaskManager once the thread is created.
TaskManager::FindMe()->Post();
// create child iterator, right now this barrier is a pipeline operator
int32_t worker_id = 0;
int32_t child_idx = 0;
child_iterator_ = std::make_unique<ChildIterator>(this, worker_id, child_idx);
// Loop until eof is true
while (!eof_) {
// Create new table to put the new tensor rows
std::unique_ptr<TensorQTable> curr_table = std::make_unique<TensorQTable>();
RETURN_IF_NOT_OK(prepare(curr_table.get()));
// If an eof got picked up during the above prepare, then we're done
if (eof_) {
break;
}
// we have to output new buffers with possibly different buffer sizes, possibly one row each
while (!clean_up_) {
// 1. If a previous loop iteration sent the current table out, then create a new one.
if (curr_table == nullptr) {
curr_table = std::make_unique<TensorQTable>();
}
// 2 fill the table. Note: clean_up mode might get turned on if epoch is finished
RETURN_IF_NOT_OK(fillBuffer(curr_table.get()));
// 3 create and update buffer and send it to the out connector
if (!curr_table->empty()) {
std::unique_ptr<DataBuffer> curr_buffer = std::make_unique<DataBuffer>(buffer_id_, DataBuffer::kDeBFlagNone);
curr_buffer->set_tensor_table(std::move(curr_table));
curr_buffer->set_column_name_map(col_name_id_map_);
MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols "
<< curr_buffer->NumCols() << ", map " << col_name_id_map_.size() << ".";
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer)));
buffer_id_++;
}
}
// 4 handle drain state.
if (clean_up_) {
MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal.";
// Send the eoe up.
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE))));
}
}
// 5 handle eof
// propagate eof here.
MS_LOG(INFO) << "Barrier operator got EOF, propagating.";
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF))));
return Status::OK();
}
// Handles preprocessing of the main loop, used when starting new epoch
Status BarrierOp::prepare(TensorQTable *const table) {
MS_LOG(DEBUG) << "Barrier operator prepares for new epoch.";
clean_up_ = false;
buffer_id_ = 0;
if (table == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table.");
}
// fill initial row
TensorRow new_row = {};
// use iterator to get next row and invoke pyfunc wait
RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
// If the first row fetching resulted in eof, then we are done.
if (eof_) {
return Status::OK();
}
if (new_row.empty()) {
// This epoch is empty
return Status::OK();
}
// Pack this first row into our tensor table
// first row we also have to check if we should block
RETURN_IF_NOT_OK(blockCond());
table->push_back(std::move(new_row));
// At this point we have produced one row; initialize col_name_id_map_ from the child
// iterator's column map so the new table uses the same column name to index mapping.
col_name_id_map_ = child_iterator_->col_name_id_map();
// re-initializing this map on a later pass is harmless if the column names already exist.
return Status::OK();
}
// fillBuffer always expects a new table to fill
Status BarrierOp::fillBuffer(TensorQTable *const table) {
if (table == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer.");
}
TensorRow new_row = {};
while (table->size() < static_cast<size_t>(rows_per_buffer_)) {
RETURN_IF_NOT_OK(getNextTensorRow(&new_row));
// Early exit the loop if we got empty row from any of our child iterations
if (new_row.empty()) {
return Status::OK();
}
// else we got a row so pack it into the tensor table.
RETURN_IF_NOT_OK(blockCond());
table->push_back(std::move(new_row));
}
return Status::OK();
}
// function executes a py_func and blocks until condition becomes true.
Status BarrierOp::blockCond() {
{
py::gil_scoped_acquire gil_acquire;
if (Py_IsInitialized() == 0) {
return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized");
}
// We have a condition name, but the condition logic itself currently lives on the Python side
try {
// Invoke python function
py::object ret_py_obj = condition_function_();
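// Note: the Python callable (in this commit, BlockReleasePair.block_func) is expected to block
// internally on a condition variable and return True once this row may be sent.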
// Process the return value
if (!py::isinstance<py::bool_>(ret_py_obj)) {
return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false");
}
} catch (const py::error_already_set &e) {
return Status(StatusCode::kPyFuncException, e.what());
}
}
return Status::OK();
}
// fetches next Barrier buffer row
Status BarrierOp::getNextTensorRow(TensorRow *new_row) {
// fetch the next row from the child iterator
RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row));
// check whether the fetched row is empty; an empty row signals the end of the epoch (or end of data)
if (new_row->empty()) {
// If we did not get a row from any of the children, then it's the end of an epoch and we can move
// to drain state.
MS_LOG(INFO) << "Barrier operator child iterator produced empty row.";
clean_up_ = true;
// If we picked up an eof here, then we are completely done.
if ((child_iterator_)->eof_handled()) {
MS_LOG(INFO) << "Barrier operator iterator got EOF.";
eof_ = true;
}
return Status::OK();
}
return Status::OK();
}
// A function that prints info about the Operator
void BarrierOp::Print(std::ostream &out, bool show_all) const {
// Call base class printer first
PipelineOp::Print(out, show_all);
out << "\nBarrierOp:\n"
<< "\nCondition " << condition_name_ << "\n\n";
}
// override base-class handler for eof
Status BarrierOp::EofReceived(int32_t) {
MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now.";
return Status::OK();
}
// override base-class handler for eoe
Status BarrierOp::EoeReceived(int32_t) {
state_ = OpState::kDeOpIdle;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_
#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "dataset/core/tensor.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/pipeline_op.h"
#include "dataset/kernels/tensor_op.h"
namespace mindspore {
namespace dataset {
// Forward declare
class DataBuffer;
class ExecutionTree;
// BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has
// been received. This signal is given from python layer. The current barrier design respects the
// rows per buffer design and will only output a buffer with rows once it has received rows per buffer
// signals from python.
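// On the Python side this operator is created by Dataset.sync_wait() and released through
// Dataset.sync_update().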
class BarrierOp : public PipelineOp {
public:
// The nested builder class inside of the BarrierOp is used to help manage all of
// the arguments for constructing it. Use the builder by setting each argument
// with the provided set methods, and then finally call the build method to execute
// the actual construction.
class Builder {
public:
// Builder constructor. Creates the builder object.
// @note No default args
// @return This is a constructor.
Builder();
// Default destructor
~Builder() = default;
// Setter method.
// @return Builder setter method returns reference to the builder.
Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
builder_rows_per_buffer_ = rows_per_buffer;
return *this;
}
// Setter method.
// @param int32_t op_connector_size
// @return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t op_connector_size) {
builder_op_connector_size_ = op_connector_size;
return *this;
}
// Setter method.
// @param const std::string & condition_name
// @return Builder setter method returns reference to the builder.
Builder &SetConditionName(const std::string &condition_name) {
builder_condition_name_ = condition_name;
return *this;
}
// Setter method.
// @param py::function condition_func - blocking condition function
// @return Builder setter method returns reference to the builder.
Builder &SetConditionFunc(py::function condition_func) {
builder_condition_func_ = condition_func;
return *this;
}
// The builder "build" method creates the BarrierOp dataset Operator.
// @return shared_ptr to the new BarrierOp object
Status Build(std::shared_ptr<BarrierOp> *);
private:
int32_t builder_rows_per_buffer_;
int32_t builder_op_connector_size_;
std::string builder_condition_name_;
py::function builder_condition_func_;
Status SanityCheck() const;
};
// Constructor for BarrierOp
// @param rows_per_buffer - number of rows in output buffer
// @param op_connector_size - connector size
// @param condition_name - the condition name associated with this operator
// @param condition_func - the blocking condition check per row
// @note - currently rows_per_buffer should be 1 for barrier.
// Other values would complicate how the pipeline behaves with downstream operators.
// For example, with batch after barrier, batch would be waiting for data, and a larger
// rows_per_buffer in that case can result in hanging
BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
py::function condition_func);
// Destructor
~BarrierOp();
Status EofReceived(int32_t) override;
Status EoeReceived(int32_t) override;
// Print function for Barrier
// @param out - output stream to print to
// @param show_all - if it should print everything
void Print(std::ostream &out, bool show_all) const override;
// Provide stream operator for displaying it
friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) {
bo.Print(out, false);
return out;
}
// Class functor operator () override.
// All dataset ops operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - The error code return
Status operator()() override;
// Handles preprocessing of the main loop, used when starting new epoch
// @param table - a table of tensors to be moved into a buffer
Status prepare(TensorQTable *const table);
// This function takes a table and repeatedly adds rows to it.
// @param table - a table of tensors to be moved into a buffer
Status fillBuffer(TensorQTable *const table);
// Gets next tensor row and sets control signals
Status getNextTensorRow(TensorRow *new_row);
// This function runs the wait function on condition
Status blockCond();
private:
// clean up variable to return incomplete buffer
bool clean_up_;
// end of file state, we stop reading data and shut down
bool eof_;
// rows per buffer
int32_t rows_per_buffer_;
// buffer_id
int32_t buffer_id_;
// local variable to keep track of the column name to index mapping for the buffer
std::unordered_map<std::string, int32_t> col_name_id_map_;
// iterator to pull new rows, we only have one child
std::unique_ptr<ChildIterator> child_iterator_;
// condition name, to support multiple barriers
std::string condition_name_;
// Function pointer of blocking function
py::function condition_function_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_
......@@ -34,7 +34,7 @@ class DataBuffer;
class ZipOp : public PipelineOp {
public:
// The nested builder class inside of the BatchOp is used to help manage all of
// The nested builder class inside of the ZipOp is used to help manage all of
// the arguments for constructing it. Use the builder by setting each argument
// with the provided set methods, and then finally call the build method to execute
// the actual construction.
......@@ -76,8 +76,8 @@ class ZipOp : public PipelineOp {
};
// Constructor for ZipOp
// @param rows_per_buffer number of rows in output buffer
// @param op_connector_size connector
// @param rows_per_buffer - number of rows in output buffer
// @param op_connector_size - connector size
ZipOp(int32_t rows_per_buffer, int32_t op_connector_size);
// Destructor
......@@ -88,8 +88,8 @@ class ZipOp : public PipelineOp {
Status EoeReceived(int32_t) override;
// Print function for Zip
// @param out output stream to print to
// @param show_all if it should print everything
// @param out - output stream to print to
// @param show_all - if it should print everything
void Print(std::ostream &out, bool show_all) const override;
// Provide stream operator for displaying it
......@@ -113,14 +113,14 @@ class ZipOp : public PipelineOp {
Status fillBuffer(TensorQTable *const table);
// Special handle case where an empty row has been received from child iterator
// @note we need to drain eoe signals from all children connectors.
// @details when this function is called, then we encountered eoe at child iterator
// @note - we need to drain eoe signals from all children connectors.
// @details - when this function is called, then we encountered eoe at child iterator
// we have to drain rows from other child iterators until we hit eoe from all other child iterators
Status drainPipeline();
// Merges 1 row from each childIterator together
// @param new_zip_row input and output, will return a non-empty row if all rows from childConnectors are non-empty
// @param updateColumnMapping generates a new column name to index mapping (mColNameIdMap) if set to true
// @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty
// @param updateColumnMapping - generates a new column name to index mapping (mColNameIdMap) if set to true
// @details merge rows from iterator together. This is the main functionality for ZipOp
// this function takes one row and fills it with tensors from rows fetched
// from childIterators.
......
......@@ -26,6 +26,7 @@ import random
import uuid
from enum import Enum
from importlib import import_module
import threading
import numpy as np
from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \
......@@ -38,7 +39,7 @@ from .iterators import DictIterator, TupleIterator
from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \
check_zip_dataset, check_add_column, check_textfiledataset
check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
try:
......@@ -139,6 +140,7 @@ class Dataset:
self._batch_size = None
self._num_classes = None
self._repeat_count = None
self._sync = False
def get_args(self):
"""
......@@ -196,6 +198,30 @@ class Dataset:
"""
return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns)
@check_sync_wait
def sync_wait(self, condition_name, num_batch=1, callback=None):
"""
Add a blocking condition to the input Dataset.
Args:
condition_name (str): The condition name that is used to toggle sending next row.
num_batch (int): The number of batches without blocking at the start of each epoch.
callback (function): The callback function that will be invoked when sync_update is called.
Returns:
SyncWaitDataset, dataset with a blocking condition added.
Raises:
RuntimeError: If condition name already exists.
Examples:
>>> import mindspore.dataset as ds
>>> # data is an instance of Dataset object.
>>> data = data.sync_wait("callback1")
>>> data = data.batch(batch_size)
>>> for batch_data in data.create_dict_iterator():
>>>     data = data.sync_update("callback1")
"""
return SyncWaitDataset(self, condition_name, num_batch, callback)
@check_shuffle
def shuffle(self, buffer_size):
"""
......@@ -218,6 +244,9 @@ class Dataset:
Returns:
ShuffleDataset, dataset shuffled.
Raises:
RuntimeError: If sync operators exist before shuffle.
Examples:
>>> import mindspore.dataset as ds
>>> # data is an instance of Dataset object
......@@ -816,6 +845,9 @@ class Dataset:
self._input_indexs = value
def _get_pipeline_info(self):
"""
Gets pipeline information.
"""
device_iter = TupleIterator(self)
self._output_shapes = device_iter.get_output_shapes()
self._output_types = device_iter.get_output_types()
......@@ -870,6 +902,30 @@ class Dataset:
return self.input[0].num_classes()
return None
def get_sync_notifiers(self):
if self.input:
return self.input[0].get_sync_notifiers()
return {}
def is_sync(self):
if self.input:
return self.input[0].is_sync()
return False
def sync_update(self, condition_name, num_batch=None, data=None):
"""
condition_name (str): The condition name that is used to toggle sending next row
step_size (int or None): The number of steps(rows) that are released
when pass_rows is None, will update the same number as sync_wait specified
data (dict or None): The data passed to the callback
"""
notifiers_dict = self.get_sync_notifiers()
if condition_name not in notifiers_dict:
raise RuntimeError("Condition name not found")
if num_batch is not None:
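# the notifier (BlockReleasePair.release_func) counts rows, so convert batches to rows here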
num_batch *= self.get_batch_size()
notifiers_dict[condition_name](num_batch, data)
def get_batch_size(self):
"""
Get the size of a batch.
......@@ -973,6 +1029,8 @@ class BatchDataset(DatasetOp):
if BatchDataset._is_ancestor_of_repeat(input_dataset):
logger.warning("Repeat is located before batch, data from two epochs can be batched together.")
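# Notify any upstream SyncWaitDataset of the batch size so its per-batch counters can be converted to per-row counters.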
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
self.batch_size = batch_size
self.drop_remainder = drop_remainder
self.per_batch_map = per_batch_map
......@@ -1029,6 +1087,20 @@ class BatchDataset(DatasetOp):
flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset)
return flag
@staticmethod
def _update_batch_size_for_syncwait(dataset, batch_size):
"""
Utility function to propagate the batch size to sync_wait.
Args:
dataset (Dataset): dataset to be checked
batch_size (int): batch size to notify
"""
if isinstance(dataset, SyncWaitDataset):
dataset.update_sync_batch_size(batch_size)
for input_dataset in dataset.input:
BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size)
class BatchInfo(CBatchInfo):
"""
......@@ -1053,6 +1125,108 @@ class BatchInfo(CBatchInfo):
"""
return
class BlockReleasePair:
"""
The blocking condition class used by SyncWaitDataset
Args:
init_release_rows (int): Number of rows to allow through the pipeline
callback (function): The callback function that will be called when release is called
"""
def __init__(self, init_release_rows, callback=None):
self.row_count = -init_release_rows
self.cv = threading.Condition()
self.callback = callback
self.default_rows = init_release_rows
def __deepcopy__(self, memodict):
if id(self) in memodict:
return memodict[id(self)]
memodict[id(self)] = self
# condition variable and callback are the same, but reset the counter
self.reset()
return self
def reset(self):
with self.cv:
self.row_count = -self.default_rows
self.cv.notify_all()
def update_batched_size(self, batch_size):
# should only be called before the pipeline is created
self.row_count *= batch_size
self.default_rows *= batch_size
def block_func(self):
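# Passed to BarrierOp as condition_func; called once per row and blocks until the release counter allows another row through.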
with self.cv:
self.cv.wait_for(lambda: self.row_count < 0)
self.row_count += 1
return True
def release_func(self, pass_rows=None, data=None):
with self.cv:
if pass_rows is None:
pass_rows = self.default_rows
self.row_count -= pass_rows
if self.callback is not None:
self.callback(data)
self.cv.notify_all()
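# A minimal sketch of the counter semantics, for documentation only (the pair is normally
# driven by BarrierOp and sync_update, not called directly like this):
#
#   pair = BlockReleasePair(init_release_rows=2)
#   pair.block_func()   # row_count goes from -2 to -1, returns immediately
#   pair.block_func()   # row_count goes from -1 to 0, returns immediately
#   # a third block_func() call now waits; another thread must call
#   # pair.release_func() (row_count -= default_rows, then notify_all) to wake it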
class SyncWaitDataset(DatasetOp):
"""
The result of adding a blocking condition to the input Dataset
Args:
input_dataset (Dataset): Input dataset to apply flow control
num_batch (int): the number of batches without blocking at the start of each epoch
condition_name (str): The condition name that is used to toggle sending next row
callback (function): The callback function that will be invoked when sync_update is called
Raises:
RuntimeError: If condition name already exists.
"""
def __init__(self, input_dataset, condition_name, num_batch, callback=None):
super().__init__()
self.input.append(input_dataset)
input_dataset.output.append(self)
# num_batch is counted in batches here; BatchDataset will convert it to rows later via update_sync_batch_size
self._condition_name = condition_name
self._pair = BlockReleasePair(num_batch, callback)
if self._condition_name in self.input[0].get_sync_notifiers():
raise RuntimeError("Condition name is already in use")
def get_sync_notifiers(self):
return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}}
def is_sync(self):
return True
def get_args(self):
args = super().get_args()
args["condition_name"] = self._condition_name
args["condition_func"] = self._pair.block_func
return args
def update_sync_batch_size(self, batch_size):
self._pair.update_batched_size(batch_size)
@staticmethod
def _is_ancestor_of_batch(dataset):
"""
Utility function to find the case where sync_wait is used before batch.
Args:
dataset (Dataset): dataset to be checked
Returns:
True or False
"""
if isinstance(dataset, BatchDataset):
return True
flag = False
for input_dataset in dataset.input:
flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset)
return flag
class ShuffleDataset(DatasetOp):
"""
......@@ -1061,6 +1235,9 @@ class ShuffleDataset(DatasetOp):
Args:
input_dataset (Dataset): Input Dataset to be shuffled.
buffer_size (int): The size of the buffer.
Raises:
RuntimeError: If sync operators exist before shuffle.
"""
def __init__(self, input_dataset, buffer_size):
......@@ -1069,6 +1246,8 @@ class ShuffleDataset(DatasetOp):
self.input.append(input_dataset)
input_dataset.output.append(self)
self._input_indexs = input_dataset.input_indexs
if self.is_sync():
raise RuntimeError("No shuffle after sync operators")
def get_args(self):
args = super().get_args()
......@@ -1335,6 +1514,9 @@ class ZipDataset(DatasetOp):
"""
return None
def is_sync(self):
return any([c.is_sync() for c in self.input])
def get_args(self):
args = super().get_args()
return args
......
......@@ -125,6 +125,8 @@ class Iterator:
op_type = OpName.MINDRECORD
elif isinstance(dataset, de.BatchDataset):
op_type = OpName.BATCH
elif isinstance(dataset, de.SyncWaitDataset):
op_type = OpName.BARRIER
elif isinstance(dataset, de.ZipDataset):
op_type = OpName.ZIP
elif isinstance(dataset, de.MapDataset):
......
......@@ -652,6 +652,22 @@ def check_batch(method):
return new_method
def check_sync_wait(method):
"""check the input arguments of sync_wait."""
@wraps(method)
def new_method(*args, **kwargs):
param_dict = make_param_dict(method, args, kwargs)
nreq_param_str = ['condition_name']
nreq_param_int = ['num_batch']
check_param_type(nreq_param_int, param_dict, int)
check_param_type(nreq_param_str, param_dict, str)
return method(*args, **kwargs)
return new_method
def check_shuffle(method):
"""check the input arguments of shuffle."""
......
......@@ -12,8 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing configuration manager
"""
import filecmp
import glob
import os
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def test_basic():
ds.config.load('../data/dataset/declient.cfg')
......@@ -36,6 +46,34 @@ def test_basic():
assert ds.config.get_prefetch_size() == 4
assert ds.config.get_seed() == 5
def test_pipeline():
"""
Test that the configuration pipeline works when parameters are set between dataset operations
"""
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
ds.config.set_num_parallel_workers(2)
data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])
ds.serialize(data1, "testpipeline.json")
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False)
ds.config.set_num_parallel_workers(4)
data2 = data2.map(input_columns=["image"], operations=[vision.Decode(True)])
ds.serialize(data2, "testpipeline2.json")
# the two serialized pipelines currently compare equal because num_parallel_workers
# does not get propagated into the serialized output
assert filecmp.cmp('testpipeline.json', 'testpipeline2.json')
# remove generated JSON files
file_list = glob.glob('*.json')
for f in file_list:
try:
os.remove(f)
except IOError:
logger.info("Error while deleting: {}".format(f))
if __name__ == '__main__':
test_basic()
test_pipeline()
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore.dataset as ds
from mindspore import log as logger
import time
import numpy as np
def gen():
for i in range(100):
yield np.array(i),
class Augment:
def __init__(self, loss):
self.loss = loss
def preprocess(self, input):
return input
def update(self, data):
self.loss = data["loss"]
def test_simple_sync_wait():
"""
Test simple sync wait: test sync in dataset pipeline
"""
logger.info("test_simple_sync_wait")
batch_size = 4
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.batch(batch_size)
count = 0
for data in dataset.create_dict_iterator():
assert (data["input"][0] == count)
count += batch_size
data = {"loss": count}
dataset.sync_update(condition_name="policy", data=data)
def test_simple_shuffle_sync():
"""
Test simple shuffle sync: test shuffle before sync
"""
logger.info("test_simple_shuffle_sync")
shuffle_size = 4
batch_size = 10
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
dataset = dataset.shuffle(shuffle_size)
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.batch(batch_size)
count = 0
for data in dataset.create_dict_iterator():
count += 1
#time.sleep(0.5)
data = {"loss": count}
dataset.sync_update(condition_name="policy", data=data)
def test_two_sync():
"""
Test two sync: dataset pipeline with two sync operators
"""
logger.info("test_two_sync")
batch_size = 6
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
# notice that with our design, we need to have step_size = shuffle size
dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches")
dataset = dataset.batch(batch_size)
count = 0
for data in dataset.create_dict_iterator():
count += 1
data = {"loss": count}
dataset.sync_update(condition_name="every batch", data=data)
if count % 2 == 0:
dataset.sync_update(condition_name="every 2 batches")
def test_sync_epoch():
"""
Test sync wait with epochs: test sync with epochs in dataset pipeline
"""
logger.info("test_sync_epoch")
batch_size = 30
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
dataset = dataset.batch(batch_size, drop_remainder=True)
for epochs in range(3):
aug.update({"loss": 0})
count = 0
for data in dataset.create_dict_iterator():
assert (data["input"][0] == count)
count += batch_size
data = {"loss": count}
dataset.sync_update(condition_name="policy", data=data)
def test_sync_exception_01():
"""
Test sync: with shuffle in sync mode
"""
logger.info("test_sync_exception_01")
shuffle_size = 4
batch_size = 10
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
try:
dataset = dataset.shuffle(shuffle_size)
except BaseException as e:
assert "shuffle" in str(e)
dataset = dataset.batch(batch_size)
def test_sync_exception_02():
"""
Test sync: with duplicated condition name
"""
logger.info("test_sync_exception_02")
batch_size = 6
dataset = ds.GeneratorDataset(gen, column_names=["input"])
aug = Augment(0)
# notice that with our design, we need to have step_size = shuffle size
dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
try:
dataset = dataset.sync_wait(num_batch=2, condition_name="every batch")
except BaseException as e:
assert "name" in str(e)
dataset = dataset.batch(batch_size)
if __name__ == "__main__":
test_simple_sync_wait()
test_simple_shuffle_sync()
test_two_sync()
test_sync_exception_01()
test_sync_exception_02()
test_sync_epoch()
\ No newline at end of file