提交 8e4c0a9d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3212 GetDatasize feature

Merge pull request !3212 from anzhengqi/epochs-ready
......@@ -25,6 +25,8 @@
#include "minddata/dataset/engine/dataset_iterator.h"
#include "minddata/dataset/engine/datasetops/bucket_batch_by_length_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/filter_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
......@@ -84,7 +86,8 @@ static std::unordered_map<uint32_t, pFunction> g_parse_op_func_ = {
{kRandomData, &DEPipeline::ParseRandomDataOp},
{kTextFile, &DEPipeline::ParseTextFileOp},
{kBuildVocab, &DEPipeline::ParseBuildVocabOp},
{kClue, &DEPipeline::ParseClueOp}};
{kClue, &DEPipeline::ParseClueOp},
{kEpochCtrl, &DEPipeline::ParseEpochCtrlOp}};
DEPipeline::DEPipeline() : iterator_(nullptr) {
try {
......@@ -166,8 +169,8 @@ Status DEPipeline::AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &
Status DEPipeline::AssignRootNode(const DsOpPtr &dataset_op) { return (tree_->AssignRoot(dataset_op)); }
// Function to launch the tree execution.
Status DEPipeline::LaunchTreeExec() {
RETURN_IF_NOT_OK(tree_->Prepare());
Status DEPipeline::LaunchTreeExec(const int32_t num_epochs) {
RETURN_IF_NOT_OK(tree_->Prepare(num_epochs));
RETURN_IF_NOT_OK(tree_->Launch());
iterator_ = std::make_unique<DatasetIterator>(tree_);
if (iterator_ == nullptr) RETURN_STATUS_UNEXPECTED("Cannot create an Iterator.");
......@@ -252,6 +255,16 @@ int DEPipeline::GetRepeatCount() const { return repeat_num_; }
float ToFloat(const py::handle &handle) { return py::reinterpret_borrow<py::float_>(handle); }
// Signal the device-queue sink at the root of the executing tree to stop pushing data.
// @return Status - error if the tree root is not a DeviceQueueOp
Status DEPipeline::StopSend() {
// tree_.root() must be DeviceQueueOp
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_->root().get());
if (op == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "StopSend only supported by DeviceQueueOp");
}
// Sets the op's stop_send_ flag; its send loops poll the flag and exit gracefully.
op->StopSend();
return Status::OK();
}
int ToInt(const py::handle &handle) { return py::reinterpret_borrow<py::int_>(handle); }
bool ToBool(const py::handle &handle) { return py::reinterpret_borrow<py::bool_>(handle); }
......@@ -804,6 +817,18 @@ Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr<DatasetOp>
return Status::OK();
}
// Parse the python arguments for an EpochCtrl node and build the EpochCtrlOp.
// @param args - python dict; "count" (number of epochs) is required and must not be None
// @param top - output: receives the newly built EpochCtrlOp
// @param bottom - unused here; this parse produces a single op, so it is left untouched
// @return Status - error if "count" is missing/None or the builder fails
Status DEPipeline::ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom) {
if (args["count"].is_none()) {
std::string err_msg = "Error: count is invalid or not set.";
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::shared_ptr<EpochCtrlOp> op;
RETURN_IF_NOT_OK(EpochCtrlOp::Builder(ToInt(args["count"])).Build(&op));
*top = op;
return Status::OK();
}
Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom) {
std::shared_ptr<GeneratorOp::Builder> builder = std::make_shared<GeneratorOp::Builder>();
......@@ -973,8 +998,8 @@ Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr<Data
(void)builder->SetDeviceType(ToString(value));
} else if (key == "device_id") {
(void)builder->SetDeviceId(ToInt(value));
} else if (key == "num_batch") {
(void)builder->SetNumBatch(ToInt(value));
} else if (key == "send_epoch_end") {
(void)builder->SetSendEpochEnd(ToBool(value));
}
}
}
......
......@@ -70,7 +70,8 @@ enum OpName {
kRandomData,
kTextFile,
kBuildVocab,
kClue
kClue,
kEpochCtrl
};
// The C++ binder class that we expose to the python script.
......@@ -90,7 +91,7 @@ class DEPipeline {
Status AssignRootNode(const DsOpPtr &dataset_op);
// Function to launch the tree execution.
Status LaunchTreeExec();
Status LaunchTreeExec(int32_t num_epochs);
// Get a row of data as dictionary of column name to the value.
Status GetNextAsMap(py::dict *output);
......@@ -143,6 +144,10 @@ class DEPipeline {
Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
std::shared_ptr<DatasetOp> *bottom);
Status ParseEpochCtrlOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseBatchOp(const py::dict &args, std::shared_ptr<DatasetOp> *ptr);
Status ParseBarrierOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status ParseGeneratorOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
......@@ -189,6 +194,8 @@ class DEPipeline {
Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
Status StopSend();
Status ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp> *top, std::shared_ptr<DatasetOp> *bottom);
private:
......
......@@ -159,7 +159,7 @@ void bindDEPipeline(py::module *m) {
[](DEPipeline &de, const DsOpPtr &dataset_op) { THROW_IF_ERROR(de.AssignRootNode(dataset_op)); })
.def("SetBatchParameters",
[](DEPipeline &de, const py::dict &args) { THROW_IF_ERROR(de.SetBatchParameters(args)); })
.def("LaunchTreeExec", [](DEPipeline &de) { THROW_IF_ERROR(de.LaunchTreeExec()); })
.def("LaunchTreeExec", [](DEPipeline &de, int32_t num_epochs) { THROW_IF_ERROR(de.LaunchTreeExec(num_epochs)); })
.def("GetNextAsMap",
[](DEPipeline &de) {
py::dict out;
......@@ -188,6 +188,7 @@ void bindDEPipeline(py::module *m) {
.def("GetBatchSize", &DEPipeline::GetBatchSize)
.def("GetNumClasses", &DEPipeline::GetNumClasses)
.def("GetRepeatCount", &DEPipeline::GetRepeatCount)
.def("StopSend", [](DEPipeline &de) { THROW_IF_ERROR(de.StopSend()); })
.def("SaveDataset", [](DEPipeline &de, const std::vector<std::string> &file_names, const std::string &file_type) {
THROW_IF_ERROR(de.SaveDataset(file_names, file_type));
return true;
......@@ -999,7 +1000,8 @@ PYBIND11_MODULE(_c_dataengine, m) {
.value("BUILDVOCAB", OpName::kBuildVocab)
.value("CELEBA", OpName::kCelebA)
.value("TEXTFILE", OpName::kTextFile)
.value("CLUE", OpName::kClue);
.value("CLUE", OpName::kClue)
.value("EPOCHCTRL", OpName::kEpochCtrl);
(void)py::enum_<JiebaMode>(m, "JiebaMode", py::arithmetic())
.value("DE_JIEBA_MIX", JiebaMode::kMix)
......
......@@ -40,7 +40,9 @@ Status IteratorBase::GetNextAsMap(TensorMap *out_map) {
out_map->clear();
TensorRow curr_row;
MS_LOG(INFO) << "get next as map start.";
RETURN_IF_NOT_OK(FetchNextTensorRow(&curr_row));
MS_LOG(INFO) << "fetchNextTensor success.";
// Return empty map if there's no data
if (curr_row.empty()) {
......@@ -105,7 +107,8 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
// Once eof is handled, always return empty row. Class must be destroyed and recreated if you
// want to iterate again.
if (eof_handled_) {
return Status::OK();
std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
RETURN_STATUS_UNEXPECTED(err);
}
// Check if we need to get a new DataBuffer to iterate.
......@@ -119,36 +122,22 @@ Status DatasetIterator::FetchNextTensorRow(TensorRow *out_row) {
// Since GetNextBuffer was used rather than GetNextInput(), it means we need to manually
// handle eoe and eof messages here.
//
// An eoe buffer means we have iterated fully to the end of the tree.
// An eoe buffer will be immediately followed by an eof buffer, which signals the shutdown of
// all operators.
// An eoe buffer means we have iterated an epoch.
// The next buffer in the pipeline might be an EOF or a databuffer for next epoch
if (curr_buffer_->eoe()) {
MS_LOG(DEBUG) << "End of data iteration. Fetch eof and then return empty row.";
// Before returning the last empty vector, fetch the eof buffer which should be the last
// buffer, and then free it.
RETURN_IF_NOT_OK(root_->GetNextBuffer(&curr_buffer_));
if (!curr_buffer_->eof()) {
RETURN_STATUS_UNEXPECTED("Non-eof after getting eoe in iterator!");
}
eof_handled_ = true;
curr_buffer_.reset(); // explicitly free the eof buffer
// Set tree to Finished state
root_->Tree()->SetFinished();
MS_LOG(INFO) << "End of data iteration.";
curr_buffer_.reset(); // explicitly free the eoe buffer
return Status::OK();
}
// An eof buffer means it is the end of execution and all operators are shutting down.
// Because there is no more data to return to the caller, this will change `eof_handled_` state and
// returns status unexpected error.
if (curr_buffer_->eof()) {
// An eof by itself, without being preceded by an eoe, is possible if a repeat operator
// exists below us in the stack. Repeat operator eats eoe's but eventually allows the
// flow of an eof up the pipeline by itself.
eof_handled_ = true;
curr_buffer_.reset(); // explicitly free the eof buffer
// Set tree to Finished state
root_->Tree()->SetFinished();
return Status::OK();
std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
RETURN_STATUS_UNEXPECTED(err);
}
}
......@@ -208,20 +197,24 @@ Status ChildIterator::FetchNextTensorRow(TensorRow *out_row) {
// Once eof is handled, always return empty row. Class must be destroyed and recreated if you
// want to iterate again.
if (eof_handled_) {
return Status::OK();
std::string err = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs.";
RETURN_STATUS_UNEXPECTED(err);
}
// Check if we need to get a new DataBuffer to iterate.
if (curr_buffer_ == nullptr || curr_buffer_->NumRows() == 0) {
// GetNextInput() depends on current_op's EoeReceived. So, EOE buffer might be already be handled and
// this child iterator might not see EOE buffer.
RETURN_IF_NOT_OK(current_op_->GetNextInput(&curr_buffer_, worker_id_, child_idx_));
// Unlike the DatasetIterator, this child iterator does not quit after eoe.
// Instead, if an eoe is picked up here, we simply return an empty vector and it's up to the
// If an eoe is picked up here, we simply return an empty vector and it's up to the
// caller to decide what it wants to do next.
if (curr_buffer_->eoe()) {
MS_LOG(DEBUG) << "Child iterator picked up EOE.";
end_epoch_ = true;
return Status::OK();
} else {
end_epoch_ = false;
}
if (curr_buffer_->eof()) {
......
......@@ -144,6 +144,9 @@ class ChildIterator : public IteratorBase {
// @return The string to column id mapping.
std::unordered_map<std::string, int32_t> GetColumnNameMap() const override;
// Return T/F if end of epoch
bool end_of_epoch() { return end_epoch_; }
private:
DatasetOp *current_op_; // The parent operator. We consume from it's children.
int32_t child_idx_; // The specific child this iterator will fetch from.
......
......@@ -18,6 +18,7 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES
shuffle_op.cc
zip_op.cc
concat_op.cc
epoch_ctrl_op.cc
cache_base_op.cc
cache_lookup_op.cc
cache_op.cc
......
......@@ -17,11 +17,13 @@
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include <algorithm>
#include <iomanip>
#include <limits>
#include <string>
#include <unordered_map>
#include <utility>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
......@@ -202,5 +204,29 @@ BuildVocabOp::Builder::Builder()
builder_num_workers_ = cfg->num_parallel_workers();
builder_connector_size_ = cfg->op_connector_size();
}
// Debug print method. Emits the op id/name header, the common info from the
// parallel-op base class, and (in detailed mode) a derived-class section.
void BuildVocabOp::Print(std::ostream &out, bool show_all) const {
  // The id and name header is the first line in both summary and detailed form.
  out << "(" << std::setw(2) << operator_id_ << ") <BuildVocabOp>:";
  // Base class prints the common info appropriate for the requested mode.
  ParallelOp::Print(out, show_all);
  if (show_all) {
    // Detailed mode: derived-internal info follows the common section.
    out << "\nCode is needed here to show more info about the op."
        << "\n\n";
  } else {
    // Summary mode: just terminate the one-liner.
    out << "\n";
  }
}
// Pre-Visitor accept method for NodePass
// @param p - The NodePass performing the visit
// @param modified - Output flag the pass sets if this node was modified
// @return Status of the node visit
Status BuildVocabOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->PreRunOnNode(shared_from_base<BuildVocabOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -131,6 +131,21 @@ class BuildVocabOp : public ParallelOp {
~BuildVocabOp() = default;
/// \brief A print method typically used for debugging
/// \param[out] out The output stream to write output to
/// \param[in] show_all A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
/// \brief Stream output operator overload
/// \notes This allows you to write the debug print info using stream operators
/// \param[out] out Reference to the output stream being overloaded
/// \param[in] vop - reference to the BuildVocabOp to display
/// \return - the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const BuildVocabOp &vop) {
vop.Print(out, false);
return out;
}
Status WorkerEntry(int32_t worker_id) override;
// collect the work product from each worker
......@@ -152,6 +167,12 @@ class BuildVocabOp : public ParallelOp {
Status Reset() override { RETURN_STATUS_UNEXPECTED("Reset shouldn't be called in BuildVocabOp"); }
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
private:
const int32_t interval_;
bool special_first_;
......
......@@ -96,7 +96,7 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
}
}
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr)));
RETURN_IF_NOT_OK(EofReceived(worker_id));
return Status::OK();
}
Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
......@@ -298,5 +298,19 @@ Status CacheMergeOp::EoeReceived(int32_t worker_id) {
}
return Status::OK();
}
// Base-class override for handling cases when an eof is received.
// @param worker_id - The id of the worker that received the eof
// @return Status - The error code returned
Status CacheMergeOp::EofReceived(int32_t worker_id) {
// If we are not in a repeated path, then the merge op gets a eof by itself, without first
// getting an eoe. However, the logic demands that all epochs close with an eoe first before eof.
// Thus, generate an eoe first, before flowing up the eof in the non-repeated case. Base class
// provides that for us.
if (!BitTest(op_ctrl_flags_, kDeOpRepeated)) {
MS_LOG(DEBUG) << "Cache merge sending eoe";
RETURN_IF_NOT_OK(DatasetOp::EoeReceived(worker_id));
}
MS_LOG(DEBUG) << "Cache merge sending eof";
return DatasetOp::EofReceived(worker_id);
}
} // namespace dataset
} // namespace mindspore
......@@ -176,6 +176,11 @@ class CacheMergeOp : public ParallelOp {
/// \return Status object
Status EoeReceived(int32_t worker_id) override;
/// \brief Base-class override for handling cases when an eof is received.
/// \param worker_id - The worker id
/// \return Status - The error code return
Status EofReceived(int32_t worker_id) override;
protected:
Status ComputeColMap() override;
......
......@@ -26,6 +26,7 @@
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/opt/pass.h"
......@@ -102,6 +103,15 @@ Status DatasetOp::InsertAsParent(std::shared_ptr<DatasetOp> to_add) {
}
return Status::OK();
}
// Detach all child operators from this operator.
// Each child first drops its back-pointer to us, then our child list is emptied.
Status DatasetOp::RemoveChildren() {
  // Sever each child's parent link before discarding our own references to them.
  for (size_t i = 0; i < child_.size(); ++i) {
    child_[i]->RemoveParent(this);
  }
  child_.clear();
  return Status::OK();
}
// Adds a parent operator to this operator
void DatasetOp::AddParent(DatasetOp *parent) { parent_.push_back(parent); }
......@@ -185,6 +195,12 @@ void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
}
}
// Getter function to get all of our children.
// Note: returns the child list by value (a copy of the vector of shared_ptrs).
std::vector<std::shared_ptr<DatasetOp>> DatasetOp::children() const { return child_; }
// Getter function to get all of our parents.
// Note: returns non-owning raw pointers, by value (a copy of the vector).
std::vector<DatasetOp *> DatasetOp::parents() const { return parent_; }
// Creates the connector within this operator
void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
......
......@@ -76,6 +76,9 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Status error code returned
Status Remove();
// Removes child operator in this operator.
Status RemoveChildren();
/// \brief Getter function to get a shared pointer to our child
/// \param[in] child_index An operator can have n children. Indicates which child to return.
/// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index
......@@ -86,6 +89,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \param[in] parent_index An operator can have n parents. Indicates which parent to return.
void Parent(DatasetOp **parent, int32_t parent_index) const;
// Getter function to get all of our children.
std::vector<std::shared_ptr<DatasetOp>> children() const;
// Getter function to get all of our parents.
std::vector<DatasetOp *> parents() const;
// Inserts a operator as the parent current op.
// Inserted op will become the sole parent of the current op.
// The existing parent of the current op will be transferred to the inserted op.
......
......@@ -25,19 +25,21 @@
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/perf/device_queue_tracing.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
int32_t op_connector_size, int64_t num_batch)
int32_t op_connector_size, bool send_epoch_end)
: PipelineOp(op_connector_size),
channel_name_(channel_name),
device_type_(device_type),
device_id_(device_id),
prefetch_size_(prefetch_size),
num_batch_(num_batch) {}
send_epoch_end_(send_epoch_end),
stop_send_(false) {}
DeviceQueueOp::~DeviceQueueOp() {}
......@@ -53,8 +55,7 @@ DeviceQueueOp::Builder::Builder(int32_t prefetch_size)
: builder_prefetch_size_(prefetch_size),
builder_device_id_(0),
builder_device_type_(DeviceType::CPU),
builder_channel_name_(""),
builder_num_batch_(0) {
builder_channel_name_("") {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
builder_op_connector_size_ = cfg->op_connector_size();
}
......@@ -64,6 +65,18 @@ Status DeviceQueueOp::EoeReceived(int32_t worker_id) {
return Status::OK();
}
// Check whether the data buffer is eligible to be sent to the device (TDT).
// Only the first row is probed; string (non-numeric) tensors cannot be pushed.
// @param buffer - The buffer about to be sent
// @return Status - error if row fetch fails or a tensor is of string type
Status DeviceQueueOp::CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const {
// this method checks if the buffer meets the conditions to be sent to TDT
if (buffer->NumRows() != 0) {
TensorRow row;
// Propagate a failed row fetch instead of silently ignoring the returned Status.
RETURN_IF_NOT_OK(buffer->GetRow(0, &row));
for (const auto &item : row) {
CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device.");
}
}
return Status::OK();
}
Status DeviceQueueOp::operator()() {
TaskManager::FindMe()->Post();
......@@ -82,23 +95,10 @@ Status DeviceQueueOp::operator()() {
return Status::OK();
}
Status DeviceQueueOp::CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const {
// this method checks if the buffer meets the conditions to be sent to TDT
if (buffer->NumRows() != 0) {
TensorRow row;
buffer->GetRow(0, &row);
for (const auto &item : row) {
CHECK_FAIL_RETURN_UNEXPECTED(item->type().IsNumeric(), "Cannot send tensor of string type to device.");
}
}
return Status::OK();
}
#ifdef ENABLE_TDTQUE
Status DeviceQueueOp::SendDataToAscend() {
MS_LOG(INFO) << "Device queue, sending data to Ascend.";
int64_t total_batch = 0;
bool is_break_loop = false;
double batch_start_time, end_time;
int32_t batch_cost, tdt_cost;
int32_t connector_size = 0;
......@@ -115,15 +115,20 @@ Status DeviceQueueOp::SendDataToAscend() {
std::unique_ptr<DataBuffer> current_buffer;
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
while (!current_buffer->eof() && !is_break_loop) {
while (!current_buffer->eoe() && !is_break_loop) {
while (!current_buffer->eof()) {
while (!current_buffer->eoe()) {
RETURN_IF_NOT_OK(CheckExceptions(current_buffer));
TensorRow currRow;
for (int row_id = 0; row_id < current_buffer->NumRows() && !is_break_loop; row_id++) {
for (int row_id = 0; row_id < current_buffer->NumRows(); row_id++) {
RETURN_IF_NOT_OK(current_buffer->GetRow(row_id, &currRow));
auto status = tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost);
if (status == TdtStatus::FAILED) {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
if (stop_send_) {
MS_LOG(INFO) << "stop_send received";
return Status::OK();
} else {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
}
}
if (isProfilingEnable) {
......@@ -140,9 +145,6 @@ Status DeviceQueueOp::SendDataToAscend() {
profiling_node->Record(CONNECTOR_DEPTH, connector_capacity, total_batch + 1, connector_size);
}
total_batch++;
if (num_batch_ > 0 && total_batch == num_batch_) {
is_break_loop = true;
}
}
if (isProfilingEnable) {
connector_size = ChildOpConnectorSize();
......@@ -150,6 +152,19 @@ Status DeviceQueueOp::SendDataToAscend() {
}
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
}
if (current_buffer->eoe() && send_epoch_end_) {
TensorRow currRow;
auto status =
tdtInstancePtr->hostPush(currRow, true, channel_name_, isProfilingEnable, tdt_cost, tdt::TDT_END_OF_SEQUENCE);
if (status == TdtStatus::FAILED) {
if (stop_send_) {
MS_LOG(INFO) << "stop_send received";
return Status::OK();
} else {
return Status(StatusCode::kTDTPushFailure, "TDT Push Failed");
}
}
}
if (isProfilingEnable) {
connector_size = ChildOpConnectorSize();
connector_capacity = ChildOpConnectorCapacity();
......@@ -158,7 +173,7 @@ Status DeviceQueueOp::SendDataToAscend() {
}
tree_->SetFinished();
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
MS_LOG(INFO) << "Device queue total batch is " << total_batch;
return Status::OK();
}
......@@ -196,9 +211,6 @@ Status DeviceQueueOp::SendDataToGPU() {
}
RETURN_IF_NOT_OK(RetryPushGPUData(data_size, curr_row, handle));
total_batch++;
if (num_batch_ > 0 && total_batch == num_batch_) {
is_break_loop = true;
}
}
if (!TaskManager::FindMe()->Interrupted())
RETURN_IF_NOT_OK(GetNextInput(&current_buffer));
......@@ -211,12 +223,10 @@ Status DeviceQueueOp::SendDataToGPU() {
is_break_loop = true;
}
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ".";
GpuBufferMgr::GetInstance().Close(handle);
GpuBufferMgr::GetInstance().CloseConfirm();
return Status::OK();
}
......@@ -240,8 +250,11 @@ Status DeviceQueueOp::RetryPushGPUData(const std::vector<size_t> &data_size, con
if (ret == BlockQueueStatus_T::ERROR_INPUT) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "invalid input Data, please check it.");
} else {
MS_LOG(WARNING) << "Retry pushing data...";
continue;
if (!stop_send_) {
MS_LOG(WARNING) << "Retry pushing data...";
continue;
}
break;
}
} else {
break;
......@@ -283,13 +296,11 @@ Status DeviceQueueOp::SendDataToCPU() {
MS_LOG(DEBUG) << "Feature size is " << curr_row[0]->SizeInBytes() << ".";
MS_LOG(DEBUG) << "Label size is " << curr_row[1]->SizeInBytes() << ".";
total_batch++;
if (num_batch_ > 0 && total_batch == num_batch_) {
break;
}
if (stop_send_) break;
}
}
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ", number of batches is " << num_batch_ << ".";
MS_LOG(INFO) << "Device queue total batch is " << total_batch << ".";
return Status::OK();
}
......
......@@ -21,6 +21,7 @@
#include <vector>
#include "minddata/dataset/engine/datasetops/pipeline_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/util/status.h"
#ifdef ENABLE_TDTQUE
......@@ -84,8 +85,8 @@ class DeviceQueueOp : public PipelineOp {
return *this;
}
Builder &SetNumBatch(int64_t num_batch) {
builder_num_batch_ = num_batch;
Builder &SetSendEpochEnd(bool send_epoch_end) {
builder_send_epoch_end_ = send_epoch_end;
return *this;
}
......@@ -94,8 +95,9 @@ class DeviceQueueOp : public PipelineOp {
// to call this Build() method. It will instantiate the DeviceQueueOp
// and return it to caller as a shared pointer.
Status Build(std::shared_ptr<DeviceQueueOp> *ptr) {
*ptr = std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_,
builder_prefetch_size_, builder_op_connector_size_, builder_num_batch_);
*ptr =
std::make_shared<DeviceQueueOp>(builder_channel_name_, builder_device_type_, builder_device_id_,
builder_prefetch_size_, builder_op_connector_size_, builder_send_epoch_end_);
return Status::OK();
}
......@@ -104,14 +106,14 @@ class DeviceQueueOp : public PipelineOp {
int32_t builder_device_id_;
DeviceType builder_device_type_;
std::string builder_channel_name_;
int64_t builder_num_batch_;
int32_t builder_op_connector_size_;
bool builder_send_epoch_end_;
};
// Name: constructor
// Description
DeviceQueueOp(std::string channel_name, DeviceType device_type, int32_t device_id, int32_t prefetch_size,
int32_t op_connector_size, int64_t num_batch);
int32_t op_connector_size, bool send_epoch_end);
// Name: destructor
// Description
......@@ -121,6 +123,8 @@ class DeviceQueueOp : public PipelineOp {
const int32_t get_prefetch_size() { return prefetch_size_; }
void StopSend() { stop_send_ = true; }
// Name: Print()
// Description: A function that prints info about the node
void Print(std::ostream &out, // In: The output stream to print to
......@@ -149,6 +153,7 @@ class DeviceQueueOp : public PipelineOp {
// Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp
Status CheckExceptions(const std::unique_ptr<DataBuffer> &buffer) const;
private:
#ifdef ENABLE_TDTQUE
Status SendDataToAscend();
#endif
......@@ -164,7 +169,8 @@ class DeviceQueueOp : public PipelineOp {
DeviceType device_type_;
const int32_t device_id_;
const int32_t prefetch_size_;
const int64_t num_batch_;
const bool send_epoch_end_;
bool stop_send_;
#ifdef ENABLE_TDTQUE
std::shared_ptr<TdtPlugin> tdtInstancePtr;
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iomanip>
#include <iostream>
#include <utility>
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/db_connector.h"
#include "minddata/dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
// The builder "build" method creates the final object.
Status EpochCtrlOp::Builder::Build(std::shared_ptr<EpochCtrlOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<EpochCtrlOp>(build_max_repeats_);
return Status::OK();
}
// Constructor
EpochCtrlOp::EpochCtrlOp(int32_t num_epoch) : RepeatOp(num_epoch) { MS_LOG(INFO) << "Welcome to Epoch Ctrl Op."; }
// Destructor
EpochCtrlOp::~EpochCtrlOp() {}
// Debug print method. Summary mode shows the configured epoch count on one line;
// detailed mode also dumps the current/max epoch counters and the eoe (leaf)
// operators this op drives.
void EpochCtrlOp::Print(std::ostream &out, bool show_all) const {
  // The id and name header is always emitted first, in either mode.
  out << "(" << std::setw(2) << operator_id_ << ") <EpochCtrlOp>:";
  // Base class prints the common info for the requested mode.
  PipelineOp::Print(out, show_all);
  if (!show_all) {
    // Summary: just the epoch count one-liner.
    out << " [epochs: " << max_repeats_ << "]\n";
    return;
  }
  // Detailed: internal counters followed by the registered leaf nodes.
  out << "\nCurrent epoch count: " << repeat_count_ << "\nMax epoch count: " << max_repeats_
      << "\nLeaf Nodes in execution path:";
  if (eoe_ops_.empty()) {
    out << " None.";
  } else {
    for (const auto &leaf : eoe_ops_) {
      out << "\n  Operator: " << leaf->id();
    }
  }
  out << "\n\n";
}
// Pulls the next buffer from the first child and forwards it to the caller.
// EOE buffers trigger EoeReceived() bookkeeping but are NOT consumed here:
// unlike its RepeatOp base, this op lets the EOE continue up the tree so the
// consumer can observe epoch boundaries. Data buffers and EOF pass through
// untouched; this op spawns no threads, so EOF needs no extra cleanup.
Status EpochCtrlOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) {
  // This op sits above a child pipeline; it cannot be a leaf node.
  if (child_.empty()) {
    RETURN_STATUS_UNEXPECTED("EpochCtrlOp can't be the leaf node.");
  }
  std::unique_ptr<DataBuffer> incoming;
  // Pass `false` for retry_if_eoe: EpochCtrlOp must see the EOE, not skip it.
  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&incoming, worker_id, false));
  // Intercept EOE only for epoch bookkeeping; the buffer itself still flows up.
  if (incoming->eoe()) {
    RETURN_IF_NOT_OK(EoeReceived(worker_id));
  }
  *p_buffer = std::move(incoming);
  return Status::OK();
}
// End-of-epoch bookkeeping: bumps the epoch counter, flags the leaf ops when the
// second-to-last epoch completes, and resets the leaf ops for the next epoch
// unless the requested number of epochs has been reached.
// @param worker_id - The worker id
Status EpochCtrlOp::EoeReceived(int32_t worker_id) {
repeat_count_++;
MS_LOG(DEBUG) << "Epoch Control operator received end of epoch. Epoch count is now: " << repeat_count_
<< ". Repeated: " << BitTest(op_ctrl_flags_, kDeOpRepeated) << ". Max epochs: " << max_repeats_;
// If we've reached the requested epoch count, then flag the leaf nodes
// to tell them they've got one more epoch to perform. When they reach the end
// of the last epoch, they quit rather than loop again.
if (max_repeats_ != kInfiniteRepeat && repeat_count_ == (max_repeats_ - 1)) {
for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "EpochCtrl setting last repeat for eoe_op: " << eoe_op->id();
eoe_op->set_control_flag(kDeOpLastRepeat);
}
}
// This will allow GetNextInput in DatasetOp class to pass EOE buffer instead of eating it.
state_ = OpState::kDeOpIdle;
// Unless this was the final epoch, drive a Reset() so leaves can produce the next epoch.
if (repeat_count_ != max_repeats_) {
for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "Epoch Control driving reset to op: " << eoe_op->id();
RETURN_IF_NOT_OK(eoe_op->Reset());
}
}
return Status::OK();
}
// Pre-Visitor accept method for NodePass
// @param p - The NodePass performing the visit
// @param modified - Output flag the pass sets if this node was modified
// @return Status of the node visit
Status EpochCtrlOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->PreRunOnNode(shared_from_base<EpochCtrlOp>(), modified);
}
// Visitor accept method for NodePass
// @param p - The NodePass performing the visit
// @param modified - Output flag the pass sets if this node was modified
// @return Status of the node visit
Status EpochCtrlOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the visitation (RunOnNode, not the pre-visit)
return p->RunOnNode(shared_from_base<EpochCtrlOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_
#define DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/pipeline_op.h"
namespace mindspore {
namespace dataset {
// EpochCtrlOp is a special repeat node injected at (or just below) the root of the
// execution tree to drive multi-epoch runs. It is derived from RepeatOp but, unlike
// a plain repeat, it does not eat the EOE buffer -- it passes it up to its parent so
// the consumer of the tree can observe each epoch boundary.
class EpochCtrlOp : public RepeatOp {
 public:
  class Builder : public RepeatOp::Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    // @param count - The number of epochs to perform
    // @return This is a constructor.
    explicit Builder(int32_t count) : RepeatOp::Builder(count) {}

    // Default destructor
    ~Builder() = default;

    // The builder "build" method creates the final object.
    // @param[out] the shared_ptr that receives the new EpochCtrlOp object
    // @return Status - The error code return
    Status Build(std::shared_ptr<EpochCtrlOp> *);
  };

  // Constructor
  // @param num_epoch - The total number of epochs this op will drive
  explicit EpochCtrlOp(int32_t num_epoch);

  // Destructor
  ~EpochCtrlOp();

  // A print method typically used for debugging
  // @param out - The output stream to write output to
  // @param show_all - A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override;

  // This function returns the buffer that is at the top of our output connector. The caller is
  // typically our parent node, when the parent is asking us to provide the next buffer of data.
  // Since EpochCtrlOp is derived from RepeatOp which is an inlined op, getting a buffer from us
  // will simply bounce you to get a buffer from our child.
  // Epoch Control Op does not eat the EOE, it will pass the EOE to the next op.
  // @param[out] p_buffer - Receives the fetched buffer
  // @param worker_id - The worker id
  // @param retry_if_eoe - Flag governing retry behaviour when an EOE is popped (see base class contract)
  Status GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id, bool retry_if_eoe) override;

  // Base-class override for handling cases when an eoe is received.
  // @param worker_id - The worker id
  Status EoeReceived(int32_t worker_id) override;

  /// \brief Base-class override for NodePass pre-visit acceptor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
  /// \return Status of the node visit
  Status PreAccept(NodePass *p, bool *modified) override;

  /// \brief Base-class override for NodePass visitor acceptor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
  /// \return Status of the node visit
  Status Accept(NodePass *p, bool *modified) override;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_EPOCH_CTRL_OP_H_
......@@ -132,6 +132,7 @@ Status RepeatOp::EoeReceived(int32_t worker_id) {
// Invoke a reset against the eoe nodes only.
for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "Repeat operator sending reset to operator: " << eoe_op->id();
RETURN_IF_NOT_OK(eoe_op->Reset());
}
......@@ -167,8 +168,9 @@ int32_t RepeatOp::num_consumers() const {
Status RepeatOp::Reset() {
// If there's nested repeats, an ascendant repeat may have ourself listed as an eoe op.
// In that case, we now have to bounce the reset down to our own eoe ops.
MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset.";
MS_LOG(DEBUG) << "Repeat operator " << operator_id_ << " got reset.";
for (auto &eoe_op : eoe_ops_) {
MS_LOG(DEBUG) << "Nested repeat operator bouncing a reset to operator: " << eoe_op->id();
RETURN_IF_NOT_OK(eoe_op->Reset());
}
state_ = OpState::kDeOpRunning;
......
......@@ -46,7 +46,7 @@ class RepeatOp : public PipelineOp {
// @return shared_ptr to the new RepeatOp object
Status Build(std::shared_ptr<RepeatOp> *);
private:
protected:
int32_t build_max_repeats_;
Status SanityCheck() const;
......@@ -131,11 +131,11 @@ class RepeatOp : public PipelineOp {
// @return Name of the current Op
std::string Name() const override { return "RepeatOp"; }
/// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
/// \param[in] eoe_op The input leaf/eoe operator to add to the list
// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
// \param[in] eoe_op The input leaf/eoe operator to add to the list
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
private:
protected:
int32_t max_repeats_; // The number of repeats that the user requested
int32_t repeat_count_; // A counter for the current number of executed repeats
std::vector<std::shared_ptr<DatasetOp>> eoe_ops_; // List of operators that can generate EOE underneath this repeat.
......
......@@ -132,8 +132,9 @@ Status ZipOp::prepare(TensorQTable *const table) {
if (eof_) {
return Status::OK();
}
// One of our child iterators encounter EOE. Returns and proceed with draining phase.
if (new_row.empty()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "ZipOp prepare phase got empty row!");
return Status::OK();
}
// Pack this first row into our tensor table
......
......@@ -23,6 +23,7 @@
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/perf/monitor.h"
......@@ -50,11 +51,11 @@ Status ExecutionTree::AssociateNode(const std::shared_ptr<DatasetOp> &op) {
if (op->tree_ == this) {
return Status::OK();
}
if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) {
if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding && tree_state_ != kDeTStatePrepare) {
std::string err_msg =
"Invalid tree state for adding a node. Current state: " + std::to_string(static_cast<int>(tree_state_)) +
" Expected states: " + std::to_string(static_cast<int>(kDeTStateInit)) + " or " +
std::to_string(static_cast<int>(kDeTStateBuilding));
std::to_string(static_cast<int>(kDeTStateBuilding)) + " or " + std::to_string(static_cast<int>(kDeTStatePrepare));
RETURN_STATUS_UNEXPECTED(err_msg);
}
......@@ -200,7 +201,9 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
// For example, repeatOp inlining
//
// @return Status - The error code return
Status ExecutionTree::Prepare() {
Status ExecutionTree::Prepare(int32_t num_epochs) {
num_epochs_ = num_epochs;
// Pre optimization compulsory transformation
RETURN_IF_NOT_OK(this->PrepareTreePreAction());
......@@ -222,6 +225,7 @@ Status ExecutionTree::PrepareTreePreAction() {
std::vector<std::unique_ptr<Pass>> pre_actions;
// Construct pre actions
MS_LOG(INFO) << "Running pre pass loops.";
pre_actions.push_back(std::make_unique<InjectionPass>());
pre_actions.push_back(std::make_unique<RemovalPass>());
pre_actions.push_back(std::make_unique<CacheTransformPass>());
// Apply pre action passes
......@@ -278,6 +282,11 @@ Status ExecutionTree::PrepareDeprecated() {
" Expected state: " + std::to_string(static_cast<int>(kDeTStatePrepare));
RETURN_STATUS_UNEXPECTED(err_msg);
}
if (root_ == nullptr) {
RETURN_STATUS_UNEXPECTED("Please assign one operator as the root of this tree.");
}
// Start the recursive prepare
RETURN_IF_NOT_OK(this->PrepareNode(root_));
tree_state_ = kDeTStateReady;
......
......@@ -176,7 +176,7 @@ class ExecutionTree {
// For example, repeatOp inlining
//
// @return Status - The error code return
Status Prepare();
Status Prepare(int num_epochs = -1);
// Compulsory transformation/action pre optimization.
// @return Status - The error code return
......@@ -193,6 +193,7 @@ class ExecutionTree {
// The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
// it ready for execution.
// @param Total number of epochs that will be run on this tree
// @return Status - The error code return
Status PrepareDeprecated();
......@@ -231,6 +232,10 @@ class ExecutionTree {
// Optional optimizations status
bool OptimizationEnabled() const { return optimize_; }
// Getter function to get the total number of epochs to be run on this tree.
// @return total number of epochs
int32_t num_epochs() { return num_epochs_; }
private:
// A helper functions for doing the recursive printing
// @param dataset_op - The dataset op to print
......@@ -245,6 +250,7 @@ class ExecutionTree {
int32_t id_count_; // Counter for generating operator id's
uint32_t prepare_flags_; // Flags used during tree prepare
TreeState tree_state_; // Tracking the current tree state
int32_t num_epochs_; // Total number of epochs to run for this tree
std::unique_ptr<Monitor> perf_monitor_; // Performance Monitor
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
bool optimize_; // Flag to enable optional optimizations
......
......@@ -5,6 +5,7 @@ add_library(engine-opt OBJECT
post/repeat_pass.cc
pre/cache_pass.cc
pre/cache_transform_pass.cc
pre/injection_pass.cc
pre/removal_nodes.cc
pre/removal_pass.cc
optional/tensor_op_fusion_pass.cc
......
......@@ -16,11 +16,13 @@
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/datasetops/build_vocab_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/map_op.h"
#include "minddata/dataset/engine/datasetops/project_op.h"
#include "minddata/dataset/engine/datasetops/rename_op.h"
......@@ -230,6 +232,11 @@ Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified)
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
  // Fallback to base class visitor by default: a pass that does not override the
  // EpochCtrlOp-specific visit gets the generic DatasetOp handling instead.
  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
......@@ -244,5 +251,15 @@ Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
  // Fallback to base class visitor by default: a pass that does not override the
  // EpochCtrlOp-specific pre-visit gets the generic DatasetOp handling instead.
  return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
  // Fallback to base class visitor by default: a pass that does not override the
  // BuildVocabOp-specific pre-visit gets the generic DatasetOp handling instead.
  return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -77,6 +77,10 @@ class CacheMergeOp;
class CacheLookupOp;
class EpochCtrlOp;
class BuildVocabOp;
// The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> {
......@@ -190,12 +194,18 @@ class NodePass : public Pass {
virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified);
private:
// Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
......
......@@ -20,6 +20,7 @@
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
namespace mindspore {
namespace dataset {
......@@ -28,6 +29,9 @@ RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(fa
// Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Create a new stack for eoe operators and push onto our stack of stacks.
std::unique_ptr<eoe_op_stack> new_stack = std::make_unique<eoe_op_stack>();
eoe_op_stacks_.push(std::move(new_stack));
// If we are already repeated, then this is a nested repeat.
if (is_repeated_) {
nested_repeats_++;
......@@ -36,6 +40,18 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified)
return Status::OK();
}
// Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
  // EpochCtrl is derived from RepeatOp and generally requires the same setup that
  // RepeatOp performs on the way down. However, epoch control can only exist as the
  // root node, so none of the nested-repeat bookkeeping is needed here.
  // Open a fresh save-area stack of eoe operators for the subtree beneath us.
  eoe_op_stacks_.push(std::make_unique<eoe_op_stack>());
  is_repeated_ = true;
  return Status::OK();
}
// Identifies the subtree below this node as being in a cache merge path
Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Turn on the flag that we're under a merge op
......@@ -47,13 +63,24 @@ Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modifi
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
while (leaf_op != nullptr) {
node->AddToEoeList(leaf_op);
leaf_op = PopFromEOEOpStack();
}
// At this point, we are done with the save area stack. It's a unique pointer to an empty stack
// at this time, so we can pop it to get rid of it.
eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
if (!current_stack->empty()) {
RETURN_STATUS_UNEXPECTED("The eoe op stack should be empty right now!");
}
eoe_op_stacks_.pop();
// We are a repeat op in the descendant tree of a merge op, then we take the saved lookup up
// and add it to the list of eoe/leaf ops for the repeat, removing it from the save area.
// and add it to the list of eoe/leaf ops for the repeat. It is important that the op is removed
// from the save area, because the merge op above us may also take action on it later for a different
// case when there is no repeat in the merge leg.
if (is_merge_ && cache_lookup_) {
cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated);
node->AddToEoeList(std::move(cache_lookup_));
......@@ -65,16 +92,29 @@ Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
AddToEOEOpStack(node);
nested_repeats_--;
}
// If we are not nested, or we were the top-most repeat, now we clear the flag
if (nested_repeats_ == 0) {
} else {
// If we are not nested, or we were the top-most repeat, now we clear the flag
if (nested_repeats_ != 0) {
RETURN_STATUS_UNEXPECTED("Nested repeat counter cannot be negative!");
}
is_repeated_ = false;
}
return Status::OK();
}
// Hooks up any identified eoe nodes under this repeat.
Status RepeatPass::RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) {
  // Drain the save-area stack, transferring each saved leaf/eoe op into this
  // epoch control node's tracking list.
  for (std::shared_ptr<DatasetOp> eoe_op = PopFromEOEOpStack(); eoe_op != nullptr; eoe_op = PopFromEOEOpStack()) {
    node->AddToEoeList(eoe_op);
  }
  // NOTE(review): unlike the RepeatOp visitor, this does not pop eoe_op_stacks_;
  // presumably acceptable since EpochCtrl can only be the root and the pass ends
  // here -- confirm.
  is_repeated_ = false;
  return Status::OK();
}
// CacheOp removes previous leaf ops and replaces them with itself
Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
if (is_repeated_) {
......@@ -118,9 +158,16 @@ Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
// Turns off the tracking for operations under merge op
Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Setting the flag is needed since we didn't call the base class DatasetOp version
if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated);
if (is_repeated_) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
// If there was not any repeat in the merge cache miss leg, then the cache_lookup
// would not have been consumed yet. In that case, we need to assign it to the upper repeat eoe stack
if (cache_lookup_) {
AddToEOEOpStack(std::move(cache_lookup_));
}
}
cache_lookup_.reset(); // If we are not repeated then the saved lookup is no longer needed or used
is_merge_ = false;
cache_lookup_.reset(); // If a repeat op did not consume this then it's no longer needed
return Status::OK();
}
......@@ -135,25 +182,32 @@ Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified
// In this case, we naturally are a repeating leaf op so add the required setup for leafs under repeat here.
if (is_repeated_) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
AddToEOEOpStack(node);
} else {
// save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
// may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
// Delay the assigment of this leap to the eoe stack and allow the merge op processing to handle that.
}
// save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
// may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
// Further, if there's a repeat above the merge but no repeat in the cache miss leg, then the merge op will
// add the lookup to the eoe stack
cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
return Status::OK();
}
// Adds an operator to the eoe operator stack save area
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) {
eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
current_stack->push(dataset_op);
}
// Pops an operator from the eoe operator stack save area
std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() {
std::shared_ptr<DatasetOp> top_op = nullptr;
if (!eoe_stack_.empty()) {
top_op = eoe_stack_.top();
eoe_stack_.pop();
eoe_op_stack *current_stack = eoe_op_stacks_.top().get();
if (current_stack != nullptr && !current_stack->empty()) {
top_op = current_stack->top();
current_stack->pop();
}
return top_op;
}
......
......@@ -30,6 +30,8 @@ namespace dataset {
/// to the eoe-producing (typically leaf) nodes underneath it.
class RepeatPass : public NodePass {
public:
using eoe_op_stack = std::stack<std::shared_ptr<DatasetOp>>;
/// \brief Constructor
RepeatPass();
......@@ -39,6 +41,12 @@ class RepeatPass : public NodePass {
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
/// \brief Identifies the subtree below this node as being in a repeated path of the tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override;
/// \brief Identifies the subtree below this node as being in a cache merge path
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
......@@ -51,6 +59,12 @@ class RepeatPass : public NodePass {
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
/// \brief Hooks up any identified eoe nodes under this repeat.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<EpochCtrlOp> node, bool *modified) override;
/// \brief CacheOp removes previous leaf ops and replaces them with itself
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
......@@ -86,11 +100,11 @@ class RepeatPass : public NodePass {
/// \return shared_ptr to the popped operator
std::shared_ptr<DatasetOp> PopFromEOEOpStack();
bool is_repeated_; // T/F if we are processing under a repeat
bool is_merge_; // T/F if we are processing under a cache merge op
int32_t nested_repeats_; // A counter for nested repeats
std::stack<std::shared_ptr<DatasetOp>> eoe_stack_; // A save area for leaf/eoe ops
std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op
bool is_repeated_; // T/F if we are processing under a repeat
bool is_merge_; // T/F if we are processing under a cache merge op
int32_t nested_repeats_; // A counter for nested repeats
std::stack<std::unique_ptr<eoe_op_stack>> eoe_op_stacks_; // A save area for leaf/eoe ops (with nesting)
std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op
};
} // namespace dataset
} // namespace mindspore
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
namespace mindspore {
namespace dataset {
// Constructor - keeps a back-pointer to the owning InjectionPass so that finder
// results (e.g. the epoch control bypass decision) can be recorded on it.
InjectionPass::InjectionFinder::InjectionFinder(InjectionPass *injection_pass) : injection_pass_(injection_pass) {}
// Performs finder work for BuildVocabOp that has special rules about epoch control injection
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
  // Guard clause: the finder must have been wired to its owning pass at construction.
  if (injection_pass_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
  }
  // The presence of a BuildVocabOp means epoch control must not be injected into this tree.
  injection_pass_->epoch_ctrl_bypass_ = true;
  return Status::OK();
}
// Temporary code to prevent the injection of epoch control when cache op is present
// Remove this code in cache op phase 2
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
  // Guard clause: the finder must have been wired to its owning pass at construction.
  if (injection_pass_ == nullptr) {
    RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
  }
  // The presence of a CacheOp means epoch control must not be injected into this tree.
  injection_pass_->epoch_ctrl_bypass_ = true;
  return Status::OK();
}
// Constructor - epoch control injection is enabled until a finder rule turns the bypass on.
InjectionPass::InjectionPass() : epoch_ctrl_bypass_(false) {}
// Runs an injection pass to inject in operators needed at the pre pass stage
// @param[inout] tree The tree to operate on
// @param[inout] modified Indicator if the tree was modified
// @return Status The error code return
Status InjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) {
  MS_LOG(INFO) << "Pre pass: Injection pass started.";

  // First, run the finder to perform any injection info before we can go ahead to drive the op injection work.
  // The finder can make updates to the InjectionPass object.
  InjectionPass::InjectionFinder finder(this);
  // Bug fix: the finder's Status was previously discarded, silently masking traversal failures.
  RETURN_IF_NOT_OK(finder.Run(tree, modified));

  // The first injection logic is to check if we should inject the epoch control op as the root node.
  // Do not inject the op if the number of epochs is 1 (a single pass needs no epoch control).
  int32_t num_epochs = tree->num_epochs();
  if (num_epochs != 1 && !epoch_ctrl_bypass_) {
    std::shared_ptr<EpochCtrlOp> epoch_ctrl_op;
    RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op));
    RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op));
    std::shared_ptr<DatasetOp> node = tree->root();
    if (std::dynamic_pointer_cast<DeviceQueueOp>(node) == nullptr) {
      // Root is not a device queue: epoch control becomes the new root of the tree.
      RETURN_IF_NOT_OK(tree->root()->InsertAsParent(epoch_ctrl_op));
    } else {
      // Keep the DeviceQueueOp on top; inject epoch control directly underneath it.
      RETURN_IF_NOT_OK(tree->root()->child(0)->InsertAsParent(epoch_ctrl_op));
    }
  }
  MS_LOG(INFO) << "Pre pass: Injection pass complete.";
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#define DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#include <memory>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class DatasetOp;
/// \class InjectionPass injection_pass.h
/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api
///     parsing.
class InjectionPass : public TreePass {
  /// \class InjectionFinder
  /// \brief This is a nested node pass class whose job is to parse the tree and perform any identification logic for
  ///     operators that need to be injected. It is run first by the main injection pass to find out what operators
  ///     it may need to inject.
  class InjectionFinder : public NodePass {
   public:
    /// \brief Constructor
    /// \param[in] injection_pass Back-pointer to the owning pass; the finder records its findings there
    explicit InjectionFinder(InjectionPass *injection_pass);

    /// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection.
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The error code return
    Status PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) override;

    /// \brief Temporary code to prevent the injection of epoch control when cache op is present.
    ///     Remove this code in cache op phase 2
    /// \param[in] node The node being visited
    /// \param[inout] modified Indicator if the node was changed at all
    /// \return Status The error code return
    Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;

   private:
    InjectionPass *injection_pass_;  // Owning pass; receives the epoch control bypass decision
  };

 public:
  /// \brief Constructor
  InjectionPass();

  /// \brief Runs an injection pass to inject in operators needed at the pre pass stage
  /// \param[inout] tree The tree to operate on
  /// \param[inout] modified Indicator if the tree was modified
  /// \return Status The error code return
  Status RunOnTree(ExecutionTree *tree, bool *modified) override;

 private:
  bool epoch_ctrl_bypass_;  // When true, the epoch control op must not be injected (BuildVocab or Cache present)
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
......@@ -29,20 +29,27 @@ std::shared_ptr<TdtPlugin> TdtPlugin::GetInstance() {
return instance_ptr_;
}
TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time) {
TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profiling, int32_t &time,
tdt::TdtDataType tdt_type) {
MS_LOG(DEBUG) << "TDT channel name is " << channel_name << ".";
std::vector<DataItem> items;
double start_time;
auto ret = translate(ts_row, items);
if (ret != SUCCESS) {
MS_LOG(ERROR) << "TDT converting tensor failed!";
return FAILED;
if (tdt_type == tdt::TDT_TENSOR) {
auto ret = translate(ts_row, items);
if (ret != SUCCESS) {
MS_LOG(ERROR) << "TDT converting tensor failed!";
return FAILED;
}
} else if (tdt_type == tdt::TDT_END_OF_SEQUENCE) {
DataItem data_item;
data_item.dataType_ = tdt::TDT_END_OF_SEQUENCE;
items.emplace_back(data_item);
MS_LOG(INFO) << "TDT data type is TDT_END_OF_SEQUENCE";
}
if (profiling) {
start_time = ProfilingTime::GetCurMilliSecond();
}
if (tdt::TdtHostPushData(channel_name, items) != 0) {
MS_LOG(ERROR) << "TDT pushing data failed!";
return FAILED;
}
if (profiling) {
......@@ -122,8 +129,8 @@ TdtStatus TdtPlugin::translate(const TensorRow &ts_row, std::vector<DataItem> &i
data_item.dataPtr_ =
std::shared_ptr<void>(reinterpret_cast<uchar *>(&(*ts->begin<uint8_t>())), [](const void *elem) {});
items.emplace_back(data_item);
MS_LOG(DEBUG) << "TDT data type is " << datatype << ", data shape is " << dataShapes << ", data length is "
<< ts->Size() << ".";
MS_LOG(INFO) << "TDT data type is TDT_TENSOR, tensor type is " << datatype << ", tensor shape is " << dataShapes
<< ", data length is " << ts->Size() << ".";
}
return SUCCESS;
}
......
......@@ -38,7 +38,8 @@ class TdtPlugin {
public:
static std::shared_ptr<TdtPlugin> GetInstance();
TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time);
TdtStatus hostPush(TensorRow ts_row, bool is_wait, std::string channel_name, bool profilig, int32_t &time,
tdt::TdtDataType tdt_type = tdt::TDT_TENSOR);
private:
TdtPlugin() {}
......
......@@ -797,6 +797,9 @@ bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t ba
(void)InitBackend();
}
#endif
if (iter_num == -1) {
iter_num = INT32_MAX;
}
if (name == kMsConvert || name == kMsVm) {
return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run);
}
......
......@@ -44,7 +44,7 @@ from .validators import check_batch, check_shuffle, check_map, check_filter, che
check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \
check_tfrecorddataset, check_vocdataset, check_cocodataset, check_celebadataset, check_minddataset, \
check_generatordataset, check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset, check_concat, \
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_positive_int32, check_save
check_random_dataset, check_split, check_bucket_batch_by_length, check_cluedataset, check_save
from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist
try:
......@@ -946,14 +946,14 @@ class Dataset:
raise TypeError("apply_func must return a dataset.")
return dataset
@check_positive_int32
def device_que(self, prefetch_size=None):
def device_que(self, prefetch_size=None, send_epoch_end=True):
"""
Return a transferredDataset that transfer data through device.
Args:
prefetch_size (int, optional): prefetch number of records ahead of the
user's request (default=None).
send_epoch_end (bool, optional): whether send end of sequence to device or not.(default=True)
Note:
If device is Ascend, features of data will be transferred one by one. The limitation
......@@ -962,15 +962,14 @@ class Dataset:
Return:
TransferDataset, dataset for transferring.
"""
return self.to_device()
return self.to_device(send_epoch_end=send_epoch_end)
@check_positive_int32
def to_device(self, num_batch=None):
def to_device(self, send_epoch_end=True):
"""
Transfer data through CPU, GPU or Ascend devices.
Args:
num_batch (int, optional): limit the number of batch to be sent to device (default=None).
send_epoch_end (bool, optional): Whether to send an end-of-sequence marker to the device (default=True).
Note:
If device is Ascend, features of data will be transferred one by one. The limitation
......@@ -982,19 +981,9 @@ class Dataset:
Raises:
TypeError: If device_type is empty.
ValueError: If device_type is not 'Ascend', 'GPU' or 'CPU'.
ValueError: If num_batch is not positive or larger than int_max.
ValueError: If dataset size is None or 0.
RuntimeError: If dataset is unknown.
RuntimeError: If distribution file path is given but failed to read.
"""
if self.get_dataset_size() is None or 0:
raise ValueError("dataset size is None or 0.")
if num_batch is None:
num_batch = self.get_dataset_size()
repeat_count = self.get_repeat_count()
num_batch = num_batch * repeat_count
queue_name = str(uuid.uuid1())
if context:
......@@ -1008,9 +997,6 @@ class Dataset:
if device_type not in ('Ascend', 'GPU', 'CPU'):
raise ValueError("Only support CPU, Ascend, GPU")
if num_batch == 0:
raise ValueError("num_batch is 0.")
def get_distribution(output_dataset):
dev_id = 0
if isinstance(output_dataset, (Cifar10Dataset, Cifar100Dataset, GeneratorDataset, ImageFolderDatasetV2,
......@@ -1032,7 +1018,7 @@ class Dataset:
distribution_path, device_id = get_distribution(self)
if distribution_path == "":
return TransferDataset(self, queue_name, device_id, device_type, num_batch)
return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
try:
with open(distribution_path, 'r') as distribution_f:
dist = json.load(distribution_f)
......@@ -1042,7 +1028,7 @@ class Dataset:
except Exception:
raise RuntimeError("Distribution file failed to read")
return TransferDataset(self, queue_name, device_id, device_type, num_batch)
return TransferDataset(self, queue_name, device_id, device_type, send_epoch_end)
@check_save
def save(self, file_name, num_files=1, file_type='mindrecord'):
......@@ -1072,7 +1058,7 @@ class Dataset:
return SaveOp(self).save(file_names, file_type)
def create_tuple_iterator(self, columns=None):
def create_tuple_iterator(self, columns=None, num_epochs=-1):
"""
Create an Iterator over the dataset. The data retrieved will be a list of ndarray of data.
......@@ -1098,9 +1084,9 @@ class Dataset:
"""
if self._noop_mode():
return DummyIterator(self, 'tuple')
return TupleIterator(self, columns)
return TupleIterator(self, columns, num_epochs)
def create_dict_iterator(self):
def create_dict_iterator(self, num_epochs=-1):
"""
Create an Iterator over the dataset.
......@@ -1123,7 +1109,7 @@ class Dataset:
"""
if self._noop_mode():
return DummyIterator(self, 'dict')
return DictIterator(self)
return DictIterator(self, num_epochs)
def __iter__(self):
"""Create an Iterator over the dataset."""
......@@ -1149,7 +1135,7 @@ class Dataset:
self._batch_size = device_iter.get_batch_size()
self._num_classes = device_iter.num_classes()
self._repeat_count = device_iter.get_repeat_count()
device_iter.release()
device_iter.stop()
def output_shapes(self):
"""
......@@ -2085,7 +2071,7 @@ class RepeatDataset(DatasetOp):
"""
child_size = self.children[0].get_dataset_size()
if child_size is not None:
return child_size
return child_size * self.count
return None
def get_repeat_count(self):
......@@ -2097,7 +2083,6 @@ class RepeatDataset(DatasetOp):
"""
return self.count
class SkipDataset(DatasetOp):
"""
The result of applying Skip operator to the input Dataset.
......@@ -2317,10 +2302,10 @@ class TransferDataset(DatasetOp):
queue_name (str): Name of device queue.
device_id (int): Id of device.
device_type (str): Type of device, including "CPU", "GPU", and "Ascend".
num_batch (int): limit the number of batch to be sent to device (default=None).
send_epoch_end (bool, optional): Whether to send an end-of-sequence marker to the device (default=True).
"""
def __init__(self, input_dataset, queue_name, device_id, device_type, num_batch=None):
def __init__(self, input_dataset, queue_name, device_id, device_type, send_epoch_end=True):
super().__init__()
self.children.append(input_dataset)
input_dataset.parent.append(self)
......@@ -2328,7 +2313,7 @@ class TransferDataset(DatasetOp):
self._input_indexs = input_dataset.input_indexs
self._device_type = device_type
self._device_id = device_id
self.__num_batch = num_batch
self._send_epoch_end = send_epoch_end
self.iterator = None
def get_args(self):
......@@ -2336,13 +2321,13 @@ class TransferDataset(DatasetOp):
args["queue_name"] = self.queue_name
args["device_type"] = self._device_type
args["device_id"] = self._device_id
args["num_batch"] = self.__num_batch
args["send_epoch_end"] = self._send_epoch_end
return args
def create_dict_iterator(self):
def create_dict_iterator(self, num_epochs=-1):
raise RuntimeError("TransferDataset is not iterable")
def create_tuple_iterator(self, columns=None):
def create_tuple_iterator(self, columns=None, num_epochs=-1):
raise RuntimeError("TransferDataset is not iterable")
def __iter__(self):
......@@ -2354,12 +2339,14 @@ class TransferDataset(DatasetOp):
def output_types(self):
raise RuntimeError("TransferDataset does not support output_types")
def send(self):
def send(self, num_epochs=-1):
# need to keep iterator alive so the executionTree is not destroyed
if self._noop_mode():
return
self.iterator = TupleIterator(self)
self.iterator = TupleIterator(self, num_epochs=-1)
def stop_send(self):
self.iterator.depipeline.StopSend()
class RangeDataset(MappableDataset):
"""
......
......@@ -29,7 +29,6 @@ from . import datasets as de
ITERATORS_LIST = list()
def _cleanup():
"""Release all the Iterator."""
for itr_ref in ITERATORS_LIST:
......@@ -60,7 +59,6 @@ def _alter_node(node):
node.iterator_bootstrap()
return node
class Iterator:
"""
General Iterator over a dataset.
......@@ -69,10 +67,21 @@ class Iterator:
dataset: Dataset to be iterated over
"""
def __init__(self, dataset):
def __init__(self, dataset, num_epochs=-1):
self.num_epochs = num_epochs
ITERATORS_LIST.append(weakref.ref(self))
# create a copy of tree and work on it.
self.dataset = copy.deepcopy(dataset)
self.parent_subtree = []
# The dataset passed into the iterator is not the root of the tree.
# Trim the tree by saving the parent subtree into self.parent_subtree and
# restore it after launching our c++ pipeline.
if self.dataset.parent:
logger.warning("The dataset passed in is not the root of the pipeline. Ignoring parent subtree.")
self.parent_subtree = self.dataset.parent
self.dataset.parent = []
self.dataset = alter_tree(self.dataset)
if not self.__is_tree():
raise ValueError("The data pipeline is not a tree (i.e., one node has 2 consumers)")
......@@ -83,9 +92,17 @@ class Iterator:
root = self.__convert_node_postorder(self.dataset)
self.depipeline.AssignRootNode(root)
self.depipeline.LaunchTreeExec()
self.depipeline.LaunchTreeExec(self.num_epochs)
self._index = 0
def stop(self):
"""
Manually terminate the Python iterator instead of relying on out-of-scope destruction.
"""
logger.info("terminating python iterator. This will also terminate c++ pipeline.")
if hasattr(self, 'depipeline') and self.depipeline:
del self.depipeline
def __is_tree_node(self, node):
"""Check if a node is tree node."""
if not node.children:
......@@ -214,9 +231,14 @@ class Iterator:
@abstractmethod
def get_next(self):
pass
raise RuntimeError("Calling base class Iterator's get_next is invalid.")
def __next__(self):
if not self.depipeline:
logger.warning("Iterator does not have a running c++ pipeline." +
"It can be because Iterator stop() had been called, or c++ pipeline crashed silently.")
raise RuntimeError("Iterator does not have a running c++ pipeline.")
data = self.get_next()
if not data:
if self._index == 0:
......@@ -293,12 +315,12 @@ class TupleIterator(Iterator):
def check_node_type(self, node):
pass
def __init__(self, dataset, columns=None):
def __init__(self, dataset, columns=None, num_epochs=-1):
if columns is not None:
if not isinstance(columns, list):
columns = [columns]
dataset = dataset.project(columns)
super().__init__(dataset)
super().__init__(dataset, num_epochs)
def __iter__(self):
return self
......
......@@ -57,7 +57,8 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
# transform data format
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
exec_dataset = exec_dataset.device_que()
send_epoch_end = bool(dataset_size == -1)
exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end)
_executor.init_dataset(exec_dataset.queue_name,
dataset_size,
......@@ -126,7 +127,7 @@ def _construct_tensor_list(types, shapes, batch_expand_num=1):
def _to_tensor(elem, scaling_sens=None):
"""Conver numpy to tensor, adapt to minddata feed solution."""
"""Convert numpy to tensor, adapt to feed the data from host solution."""
lst = []
if not isinstance(elem, (tuple, list)):
elem = [elem]
......@@ -145,7 +146,8 @@ def _to_tensor(elem, scaling_sens=None):
def _to_full_tensor(elem, device_num, global_rank, scaling_sens=None):
"""Conver numpy to tensor, expanding batch dimension according to device_num, adapt to minddata feed solution."""
"""Convert numpy to tensor, expanding batch dimension according to device_num, adapt to feed the data
from host solution."""
lst = []
if not isinstance(elem, (tuple, list)):
elem = [elem]
......
......@@ -16,7 +16,7 @@
import math
import os
from mindspore._checkparam import check_bool
from mindspore._checkparam import check_bool, check_int
from .. import context
from ._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
_construct_tensor_list, _to_full_shapes, _to_full_tensor
......@@ -42,17 +42,23 @@ class DatasetHelper:
Each iteration of DatasetHelper yields one epoch of data.
Args:
dataset (DataSet): The dataset.
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host.
Default: True.
dataset (DataSet): The training dataset iterator.
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True.
sink_size (int): Control the amount of data each sink.
If sink_size=-1, sink the complete dataset each epoch.
If sink_size>0, sink sink_size data each epoch. Default: -1.
Examples:
>>> dataset_helper = DatasetHelper(dataset)
>>> for inputs in dataset_helper:
>>> outputs = network(*inputs)
"""
def __init__(self, dataset, dataset_sink_mode=True):
def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1):
check_bool(dataset_sink_mode)
check_int(sink_size)
if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
if dataset_sink_mode:
if context.get_context("enable_ge"):
......@@ -68,9 +74,10 @@ class DatasetHelper:
iterclass = _DatasetIterMS
elif context.get_context("device_target") == "CPU":
raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
self.iter = iterclass(dataset, sink_size)
else:
iterclass = _DatasetIterFeed
self.iter = iterclass(dataset)
iterclass = _DatasetIterNormal
self.iter = iterclass(dataset)
def __iter__(self):
return self.iter.__iter__()
......@@ -80,21 +87,26 @@ class DatasetHelper:
"""Get the types and shapes from dataset on current config."""
return self.iter.types_shapes()
def loop_size(self):
"""Get loop_size for every iteration."""
return self.iter.loop_size
def sink_size(self):
"""Get sink_size for every iteration."""
return self.iter.get_sink_size()
def stop_send(self):
"""Free up resources about data sink."""
self.iter.stop_send()
class _DatasetIter:
"""Base iter for dataset help"""
def __init__(self, dataset):
if not hasattr(dataset, '__loop_size__'):
self.loop_size = dataset.get_dataset_size()
else:
self.loop_size = dataset.__loop_size__
"""Base iter for dataset helper"""
def __init__(self, dataset, sink_size):
self.dataset = dataset
self.sink_size = sink_size
self.sink_count = 1
if not hasattr(dataset, '__ME_INITED__'):
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
if not hasattr(dataset, '__TRANSFER_DATASET__'):
if hasattr(dataset, '__loop_size__'):
self.sink_size = dataset.__loop_size__
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
if not hasattr(dataset, '__no_send__'):
......@@ -102,43 +114,70 @@ class _DatasetIter:
else:
_send_data(dataset)
self.ind = 0
self.dataset = dataset
dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
def __iter__(self):
self.ind = 0
self.index = 0
return self
def __next__(self):
if self.ind >= self.loop_count:
if self.index >= self.sink_count:
raise StopIteration()
self.ind += 1
self.index += 1
return self.op()
def types_shapes(self):
return self.dataset_types, self.dataset_shapes
def get_loop_count(self, dataset):
loop_count = 1
def get_sink_count(self, dataset):
sink_count = 1
if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__
if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
f'loop_size {loop_size} are not matched.')
loop_count = math.ceil(dataset.get_dataset_size() / loop_size)
return loop_count
f'sink_size {loop_size} are not matched.')
sink_count = math.ceil(dataset.get_dataset_size() / loop_size)
return sink_count
def get_sink_size(self):
"""get sink_size to device"""
sink_size = 1
if hasattr(self.dataset, '__loop_size__'):
sink_size = self.dataset.__loop_size__
else:
if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend":
if self.sink_size > 0:
sink_size = self.sink_size
else:
sink_size = self.dataset.get_dataset_size()
return sink_size
class _DatasetIterGE(_DatasetIter):
"""Iter for GE."""
def __init__(self, dataset, sink_size):
super().__init__(dataset, sink_size)
self.sink_count = self.get_sink_count(dataset)
batch_expand_num = 1
if _need_to_full():
batch_expand_num = _get_device_num()
tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)
def op():
return tensor_list_run
self.op = op
class _DatasetIterMSLoopSink(_DatasetIter):
"""Iter for context (device_target=Ascend)"""
def __init__(self, dataset):
super(_DatasetIterMSLoopSink, self).__init__(dataset)
self.loop_count = self.get_loop_count(dataset)
def __init__(self, dataset, sink_size):
super().__init__(dataset, sink_size)
self.sink_count = self.get_sink_count(dataset)
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
self.loop_count = 1
self.sink_count = 1
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
# use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
# compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
......@@ -153,66 +192,42 @@ class _DatasetIterMSLoopSink(_DatasetIter):
class _DatasetIterMS(_DatasetIter):
"""Iter for context (device_target=GPU)"""
def __init__(self, dataset):
super(_DatasetIterMS, self).__init__(dataset)
self.loop_count = dataset.get_dataset_size()
self.loop_size = 1
"""Iter for MS(enable_loop_sink=False)."""
def __init__(self, dataset, sink_size):
super().__init__(dataset, sink_size)
if sink_size > 0:
self.sink_count = sink_size
else:
self.sink_count = dataset.get_dataset_size()
queue_name = dataset.__ME_INITED__
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
class _DatasetIterPSLite(_DatasetIter):
"""Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED"""
def __init__(self, dataset):
super(_DatasetIterPSLite, self).__init__(dataset)
self.loop_count = 1
self.loop_size = 1
def __init__(self, dataset, sink_size):
super().__init__(dataset, sink_size)
self.sink_count = 1
self.sink_size = 1
self.op = None
def op():
return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1)
self.op = op
class _DatasetIterGE(_DatasetIter):
"""Iter for ge"""
def __init__(self, dataset):
super(_DatasetIterGE, self).__init__(dataset)
self.loop_count = self.get_loop_count(dataset)
batch_expand_num = 1
if _need_to_full():
batch_expand_num = _get_device_num()
tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)
def op():
return tensor_list_run
self.op = op
class _DatasetIterFeed:
class _DatasetIterNormal:
"""Iter for normal(non sink) mode, feed the data from host."""
def __init__(self, dataset):
self.dataset = dataset
self.device_num = _get_device_num()
self.global_rank = _get_global_rank()
self.repeat_count = dataset.get_repeat_count()
self.repeat_ind = 0
self.loop_count = dataset.get_dataset_size()
self.ind = 0
def __iter__(self):
if self.repeat_ind % self.repeat_count == 0:
self.iter = self.dataset.__iter__()
self.repeat_ind += 1
self.ind = 0
self.iter = self.dataset.create_tuple_iterator()
return self
def __next__(self):
if self.ind >= self.loop_count:
raise StopIteration()
self.ind += 1
data = self.iter.__next__()
if _need_to_full():
return _to_full_tensor(data, self.device_num, self.global_rank)
......
......@@ -21,7 +21,7 @@ import numpy as np
from mindspore import log as logger
from ..common.tensor import Tensor
from ..nn.metrics import get_metrics
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
from .._checkparam import check_input_data, check_output_data, check_int_positive, check_bool, check_int
from .callback import _InternalCallbackParam, RunContext, _CallbackManager
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
......@@ -225,7 +225,7 @@ class Model:
scaling_sens /= self._device_number
return scaling_sens
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode):
def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1):
"""Initializes dataset."""
need_wrap = False
if dataset_sink_mode:
......@@ -237,7 +237,7 @@ class Model:
if not is_train:
dataset.__loop_size__ = 1
dataset_helper = DatasetHelper(dataset, dataset_sink_mode)
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size)
# remove later to deal with loop sink
if need_wrap:
......@@ -317,7 +317,7 @@ class Model:
self._eval_network.compile(*inputs)
break
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
"""
Training.
......@@ -332,6 +332,7 @@ class Model:
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
Configure pynative mode, the training process will be performed with
dataset not sink.
sink_size (int): Control the amount of data each sink. Default: -1.
"""
epoch = check_int_positive(epoch)
self._train_network.set_train()
......@@ -342,7 +343,10 @@ class Model:
cb_params = _InternalCallbackParam()
cb_params.train_network = self._train_network
cb_params.epoch_num = epoch
cb_params.batch_num = train_dataset.get_dataset_size()
if dataset_sink_mode and sink_size > 0:
cb_params.batch_num = sink_size
else:
cb_params.batch_num = train_dataset.get_dataset_size()
cb_params.mode = "train"
cb_params.loss_fn = self._loss_fn
cb_params.optimizer = self._optimizer
......@@ -364,7 +368,7 @@ class Model:
"So the training process will be performed with dataset not sink.")
self._train_process(epoch, train_dataset, list_callback, cb_params)
else:
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params)
self._train_dataset_sink_process(epoch, train_dataset, list_callback, cb_params, sink_size)
@staticmethod
def _transform_callbacks(callbacks):
......@@ -377,7 +381,7 @@ class Model:
return [callbacks]
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
"""
Training process. The data would be passed to network through dataset channel.
......@@ -390,17 +394,18 @@ class Model:
function respectively.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
sink_size (int): Control the amount of data each sink. Default: -1.
"""
dataset_helper, train_network = self._exec_preprocess(self._train_network,
is_train=True,
phase='train',
dataset=train_dataset,
dataset_sink_mode=True)
dataset_sink_mode=True,
sink_size=sink_size)
self._train_network = train_network
cb_params.train_network = self._train_network
cb_params.cur_step_num = 0
loop_size = dataset_helper.loop_size()
run_context = RunContext(cb_params)
list_callback.begin(run_context)
......@@ -412,9 +417,9 @@ class Model:
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
for inputs in dataset_helper:
cb_params.cur_step_num += loop_size
list_callback.step_begin(run_context)
outputs = self._train_network(*inputs)
cb_params.cur_step_num += dataset_helper.sink_size()
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
......@@ -422,6 +427,7 @@ class Model:
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:
break
dataset_helper.stop_send()
list_callback.end(run_context)
......@@ -490,7 +496,7 @@ class Model:
list_callback.end(run_context)
def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1):
"""
Training API where the iteration is controlled by python front-end.
......@@ -515,7 +521,10 @@ class Model:
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
Configure pynative mode, the training process will be performed with
dataset not sink.
sink_size (int): Control the amount of data each sink.
If sink_size=-1, sink the complete dataset each epoch.
If sink_size>0, sink sink_size data each epoch.
If dataset_sink_mode is False, set sink_size invalid. Default: -1.
Examples:
>>> dataset = get_dataset()
......@@ -526,17 +535,19 @@ class Model:
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.train(2, dataset)
"""
repeat_count = train_dataset.get_repeat_count()
if epoch != repeat_count and dataset_sink_mode is True:
logger.warning(f"The epoch_size {epoch} is not the same with dataset repeat_count {repeat_count}")
check_bool(dataset_sink_mode)
check_int(sink_size)
if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
_device_number_check(self._parallel_mode, self._device_number)
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
self._train(epoch,
train_dataset,
callbacks=callbacks,
dataset_sink_mode=dataset_sink_mode)
dataset_sink_mode=dataset_sink_mode,
sink_size=sink_size)
def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None):
"""
......
......@@ -43,7 +43,7 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, cfg.epoch_size)
ds_train = create_dataset_cifar10(args.data_path, cfg.batch_size, 1)
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
lr = Tensor(get_lr(0, cfg.learning_rate, cfg.epoch_size, ds_train.get_dataset_size()))
......
......@@ -57,7 +57,7 @@ if __name__ == '__main__':
ds_train = create_dataset(args_opt.dataset_path,
train_mode=True,
epochs=train_config.train_epochs,
epochs=1,
batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format),
rank_size=rank_size,
......@@ -82,7 +82,7 @@ if __name__ == '__main__':
if args_opt.do_eval:
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
epochs=train_config.train_epochs,
epochs=1,
batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format))
eval_callback = EvalCallBack(model, ds_eval, auc_metric,
......
......@@ -66,7 +66,7 @@ if __name__ == "__main__":
init()
args_opt.base_size = config.crop_size
args_opt.crop_size = config.crop_size
train_dataset = create_dataset(args_opt, args_opt.data_url, config.epoch_size, config.batch_size, usage="train")
train_dataset = create_dataset(args_opt, args_opt.data_url, 1, config.batch_size, usage="train")
dataset_size = train_dataset.get_dataset_size()
time_cb = TimeMonitor(data_size=dataset_size)
callback = [time_cb, LossCallBack()]
......
......@@ -94,7 +94,7 @@ if __name__ == '__main__':
loss_scale = float(config.loss_scale)
# When creating a MindDataset, use the first mindrecord file, such as FasterRcnn.mindrecord0.
dataset = create_fasterrcnn_dataset(mindrecord_file, repeat_num=config.epoch_size,
dataset = create_fasterrcnn_dataset(mindrecord_file, repeat_num=1,
batch_size=config.batch_size, device_num=device_num, rank_id=rank)
dataset_size = dataset.get_dataset_size()
......
......@@ -78,7 +78,7 @@ if __name__ == '__main__':
mirror_mean=True)
init()
dataset = create_dataset(cfg.data_path, cfg.epoch_size)
dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()
net = GoogleNet(num_classes=cfg.num_classes)
......
......@@ -45,8 +45,7 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset(os.path.join(args.data_path, "train"),
cfg.batch_size,
cfg.epoch_size)
cfg.batch_size)
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
......
......@@ -44,7 +44,7 @@ args = parser.parse_args()
if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, 1)
step_size = ds_train.get_dataset_size()
# define fusion network
......
......@@ -77,7 +77,7 @@ if __name__ == '__main__':
model = Model(network, loss, opt, {'acc': Accuracy()})
print("============== Starting Training ==============")
ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, cfg.num_epochs)
ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, 1)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)
......
......@@ -249,7 +249,7 @@ def train_parallel(config: TransformerConfig):
pre_train_dataset = load_dataset(
data_files=config.pre_train_dataset,
batch_size=config.batch_size, epoch_count=config.epochs,
batch_size=config.batch_size, epoch_count=1,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step,
rank_size=MultiAscend.get_group_size(),
......@@ -257,7 +257,7 @@ def train_parallel(config: TransformerConfig):
) if config.pre_train_dataset else None
fine_tune_dataset = load_dataset(
data_files=config.fine_tune_dataset,
batch_size=config.batch_size, epoch_count=config.epochs,
batch_size=config.batch_size, epoch_count=1,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step,
rank_size=MultiAscend.get_group_size(),
......@@ -265,7 +265,7 @@ def train_parallel(config: TransformerConfig):
) if config.fine_tune_dataset else None
test_dataset = load_dataset(
data_files=config.test_dataset,
batch_size=config.batch_size, epoch_count=config.epochs,
batch_size=config.batch_size, epoch_count=1,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step,
rank_size=MultiAscend.get_group_size(),
......@@ -288,17 +288,17 @@ def train_single(config: TransformerConfig):
print(" | Starting training on single device.")
pre_train_dataset = load_dataset(data_files=config.pre_train_dataset,
batch_size=config.batch_size,
epoch_count=config.epochs,
epoch_count=1,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step) if config.pre_train_dataset else None
fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
batch_size=config.batch_size,
epoch_count=config.epochs,
epoch_count=1,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step) if config.fine_tune_dataset else None
test_dataset = load_dataset(data_files=config.test_dataset,
batch_size=config.batch_size,
epoch_count=config.epochs,
epoch_count=1,
sink_mode=config.dataset_sink_mode,
sink_step=config.dataset_sink_step) if config.test_dataset else None
......
......@@ -180,7 +180,7 @@ if __name__ == '__main__':
do_train=True,
config=config_gpu,
platform=args_opt.platform,
repeat_num=epoch_size,
repeat_num=1,
batch_size=config_gpu.batch_size)
step_size = dataset.get_dataset_size()
# resume
......@@ -239,7 +239,7 @@ if __name__ == '__main__':
do_train=True,
config=config_ascend,
platform=args_opt.platform,
repeat_num=epoch_size,
repeat_num=1,
batch_size=config_ascend.batch_size)
step_size = dataset.get_dataset_size()
if args_opt.pre_trained:
......
......@@ -86,7 +86,7 @@ if __name__ == '__main__':
do_train=True,
config=config,
device_target=args_opt.device_target,
repeat_num=epoch_size,
repeat_num=1,
batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
# load pre trained ckpt
......
......@@ -181,7 +181,7 @@ if __name__ == '__main__':
do_train=True,
config=config_gpu,
platform=args_opt.platform,
repeat_num=epoch_size,
repeat_num=1,
batch_size=config_gpu.batch_size)
step_size = dataset.get_dataset_size()
# resume
......@@ -240,7 +240,7 @@ if __name__ == '__main__':
do_train=True,
config=config_ascend,
platform=args_opt.platform,
repeat_num=epoch_size,
repeat_num=1,
batch_size=config_ascend.batch_size)
step_size = dataset.get_dataset_size()
if args_opt.pre_trained:
......
......@@ -36,12 +36,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
_cur_dir = os.getcwd()
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
""" do train """
if load_checkpoint_path == "":
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size()
epoch_num = dataset.get_repeat_count()
# optimizer
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
......@@ -176,11 +175,11 @@ def run_classifier():
assessment_method=assessment_method)
if args_opt.do_train.lower() == "true":
ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
assessment_method=assessment_method,
data_file_path=args_opt.train_data_file_path,
schema_file_path=args_opt.schema_file_path)
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path)
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
if args_opt.do_eval.lower() == "true":
if save_finetune_checkpoint_path == "":
......@@ -191,7 +190,7 @@ def run_classifier():
ds.get_dataset_size(), epoch_num, "classifier")
if args_opt.do_eval.lower() == "true":
ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
assessment_method=assessment_method,
data_file_path=args_opt.eval_data_file_path,
schema_file_path=args_opt.schema_file_path)
......
......@@ -38,12 +38,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
_cur_dir = os.getcwd()
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
""" do train """
if load_checkpoint_path == "":
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size()
epoch_num = dataset.get_repeat_count()
# optimizer
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
......@@ -204,10 +203,10 @@ def run_ner():
use_crf=(args_opt.use_crf.lower() == "true"),
tag_to_index=tag_to_index, dropout_prob=0.1)
if args_opt.do_train.lower() == "true":
ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path,
schema_file_path=args_opt.schema_file_path)
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path)
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
if args_opt.do_eval.lower() == "true":
if save_finetune_checkpoint_path == "":
......@@ -218,7 +217,7 @@ def run_ner():
ds.get_dataset_size(), epoch_num, "ner")
if args_opt.do_eval.lower() == "true":
ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path,
schema_file_path=args_opt.schema_file_path)
do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path,
......
......@@ -100,11 +100,12 @@ def run_pretrain():
bert_net_cfg.compute_type = mstype.float32
ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
args_opt.enable_data_sink, args_opt.data_sink_steps,
args_opt.data_dir, args_opt.schema_dir)
ds = create_bert_dataset(1, device_num, rank, args_opt.do_shuffle,
args_opt.enable_data_sink, args_opt.data_sink_steps,
args_opt.data_dir, args_opt.schema_dir)
new_repeat_count = args_opt.epoch_size
if args_opt.train_steps > 0:
new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
new_repeat_count = min(args_opt.epoch_size, args_opt.train_steps // args_opt.data_sink_steps)
netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
if cfg.optimizer == 'Lamb':
......
......@@ -38,12 +38,11 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
_cur_dir = os.getcwd()
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
""" do train """
if load_checkpoint_path == "":
raise ValueError("Pretrain model missed, finetune task must load pretrain model!")
steps_per_epoch = dataset.get_dataset_size()
epoch_num = dataset.get_repeat_count()
# optimizer
if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
......@@ -181,10 +180,10 @@ def run_squad():
netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
if args_opt.do_train.lower() == "true":
ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
data_file_path=args_opt.train_data_file_path,
schema_file_path=args_opt.schema_file_path)
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path)
do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path, epoch_num)
if args_opt.do_eval.lower() == "true":
if save_finetune_checkpoint_path == "":
load_finetune_checkpoint_dir = _cur_dir
......@@ -194,7 +193,7 @@ def run_squad():
ds.get_dataset_size(), epoch_num, "squad")
if args_opt.do_eval.lower() == "true":
ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
data_file_path=args_opt.eval_data_file_path,
schema_file_path=args_opt.schema_file_path, is_training=False)
do_eval(ds, args_opt.vocab_file_path, args_opt.eval_json_path,
......
......@@ -54,7 +54,6 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
# apply batch operations
ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
ds = ds.repeat(max(new_repeat_count, repeat_count))
logger.info("data size: {}".format(ds.get_dataset_size()))
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
return ds, new_repeat_count
......
......@@ -17,7 +17,6 @@
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as deC
from mindspore import log as logger
from .config import transformer_net_cfg
def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle="true", enable_data_sink="true",
......@@ -42,7 +41,4 @@ def create_transformer_dataset(epoch_count=1, rank_size=1, rank_id=0, do_shuffle
ds = ds.batch(transformer_net_cfg.batch_size, drop_remainder=True)
ds = ds.repeat(repeat_count)
ds.channel_name = 'transformer'
logger.info("data size: {}".format(ds.get_dataset_size()))
logger.info("repeatcount: {}".format(ds.get_repeat_count()))
return ds, repeat_count
return ds
......@@ -125,10 +125,10 @@ def run_transformer_train():
else:
device_num = 1
rank_id = 0
dataset, repeat_count = create_transformer_dataset(epoch_count=args.epoch_size, rank_size=device_num,
rank_id=rank_id, do_shuffle=args.do_shuffle,
enable_data_sink=args.enable_data_sink,
dataset_path=args.data_path)
dataset = create_transformer_dataset(epoch_count=1, rank_size=device_num,
rank_id=rank_id, do_shuffle=args.do_shuffle,
enable_data_sink=args.enable_data_sink,
dataset_path=args.data_path)
netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)
......@@ -165,7 +165,7 @@ def run_transformer_train():
netwithgrads.set_train(True)
model = Model(netwithgrads)
model.train(repeat_count, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"))
if __name__ == '__main__':
run_transformer_train()
......@@ -88,10 +88,10 @@ if __name__ == '__main__':
# create dataset
if args_opt.net == "resnet50":
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=config.epoch_size,
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
batch_size=config.batch_size, target=target)
else:
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=config.epoch_size,
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1,
batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
......
......@@ -105,7 +105,7 @@ if __name__ == '__main__':
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
if args_opt.do_train:
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True,
repeat_num=epoch_size, batch_size=config.batch_size)
batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
......
......@@ -91,7 +91,7 @@ def main():
loss_scale = float(args_opt.loss_scale)
# When creating MindDataset, use the first mindrecord file, such as ssd.mindrecord0.
dataset = create_ssd_dataset(mindrecord_file, repeat_num=args_opt.epoch_size,
dataset = create_ssd_dataset(mindrecord_file, repeat_num=1,
batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
dataset_size = dataset.get_dataset_size()
......
......@@ -83,7 +83,7 @@ if __name__ == '__main__':
mirror_mean=True)
init()
dataset = vgg_create_dataset(args_opt.data_path, cfg.epoch_size)
dataset = vgg_create_dataset(args_opt.data_path, 1)
batch_num = dataset.get_dataset_size()
net = vgg16(num_classes=cfg.num_classes)
......
......@@ -63,7 +63,7 @@ def test_train(configure):
data_path = configure.data_path
batch_size = configure.batch_size
epochs = configure.epochs
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size)
ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
net_builder = ModelBuilder()
......
......@@ -67,8 +67,8 @@ def test_train_eval(config):
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size)
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size)
ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size)
ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size)
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
......
......@@ -85,14 +85,14 @@ def train_and_eval(config):
if config.full_batch:
context.set_auto_parallel_context(full_batch=True)
de.config.set_seed(1)
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size*get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size*get_group_size())
else:
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
......
......@@ -74,9 +74,9 @@ def train_and_eval(config):
batch_size = config.batch_size
epochs = config.epochs
print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
......
......@@ -121,7 +121,7 @@ def main():
loss_scale = float(args_opt.loss_scale)
# When creating MindDataset, use the first mindrecord file, such as yolo.mindrecord0.
dataset = create_yolo_dataset(mindrecord_file, repeat_num=args_opt.epoch_size,
dataset = create_yolo_dataset(mindrecord_file,
batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
dataset_size = dataset.get_dataset_size()
print("Create dataset done!")
......
......@@ -50,13 +50,20 @@ class MindData:
def input_indexs(self):
return self._input_indexs
def device_que(self):
def device_que(self, send_epoch_end=True):
self.queue_name = '6ba41974-209e-11ea-88b0-a24efeb2c736'
self.send_epoch_end = send_epoch_end
return self
def create_tuple_iterator(self):
return self.__iter__()
def send(self):
pass
def stop_send(self):
pass
def __len__(self):
return self._size
......
......@@ -73,7 +73,7 @@ if __name__ == "__main__":
epoch_size = 3
args_opt.base_size = config.crop_size
args_opt.crop_size = config.crop_size
train_dataset = create_dataset(args_opt, args_opt.data_url, epoch_size, config.batch_size,
train_dataset = create_dataset(args_opt, args_opt.data_url, 1, config.batch_size,
usage="train", shuffle=False)
dataset_size = train_dataset.get_dataset_size()
callback = LossCallBack(dataset_size)
......
......@@ -120,10 +120,10 @@ def test_transformer():
batch_size = 96
epoch_size = 3
config = get_config(version=version, batch_size=batch_size)
dataset, repeat_count = create_transformer_dataset(epoch_count=epoch_size,
do_shuffle="false",
enable_data_sink="false",
dataset_path=DATA_DIR)
dataset = create_transformer_dataset(epoch_count=1,
do_shuffle="false",
enable_data_sink="false",
dataset_path=DATA_DIR)
netwithloss = TransformerNetworkWithLoss(config, True)
......@@ -146,7 +146,7 @@ def test_transformer():
netwithgrads.set_train(True)
time_monitor_callback = TimeMonitor(dataset.get_dataset_size())
model = Model(netwithgrads)
model.train(repeat_count, dataset, callbacks=[time_monitor_callback, callback], dataset_sink_mode=False)
model.train(epoch_size, dataset, callbacks=[time_monitor_callback, callback], dataset_sink_mode=False)
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list)
......
......@@ -79,9 +79,9 @@ def test_train_eval():
batch_size = config.batch_size
epochs = config.epochs
print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size,
ds_train = create_dataset(data_path, train_mode=True, epochs=1, batch_size=batch_size,
data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size,
ds_eval = create_dataset(data_path, train_mode=False, epochs=1, batch_size=batch_size,
data_type=DataType.MINDRECORD, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
......
......@@ -76,9 +76,9 @@ def test_train_eval():
batch_size = config.batch_size
epochs = config.epochs
print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=epochs,
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1,
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=batch_size, rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
......
......@@ -113,7 +113,7 @@ def test_yolov3():
loss_scale = float(loss_scale)
# When creating MindDataset, use the first mindrecord file, such as yolo.mindrecord0.
dataset = create_yolo_dataset(mindrecord_file, repeat_num=epoch_size,
dataset = create_yolo_dataset(mindrecord_file, repeat_num=1,
batch_size=batch_size, device_num=device_num, rank=rank)
dataset_size = dataset.get_dataset_size()
print("Create dataset done!")
......@@ -146,12 +146,12 @@ def test_yolov3():
assert loss_value[2] < expect_loss_value[2]
epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2]
expect_epoch_mseconds = 950
expect_epoch_mseconds = 2000
print("epoch mseconds: {}".format(epoch_mseconds))
assert epoch_mseconds <= expect_epoch_mseconds
per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2]
expect_per_step_mseconds = 110
expect_per_step_mseconds = 220
print("per step mseconds: {}".format(per_step_mseconds))
assert per_step_mseconds <= expect_per_step_mseconds
print("yolov3 test case passed.")
......@@ -91,6 +91,7 @@ def me_de_train_dataset(sink_mode=False):
"""test me de train dataset"""
# apply repeat operations
repeat_count = 1
sink_size = -1
batch_size = 16
ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids",
"next_sentence_labels", "masked_lm_positions",
......@@ -99,9 +100,9 @@ def me_de_train_dataset(sink_mode=False):
new_repeat_count = repeat_count
if sink_mode:
repeat_count = 30
sink_steps = 100
sink_size = 100
ori_dataaet_size = ds.get_dataset_size()
new_size = sink_steps * batch_size
new_size = sink_size * batch_size
ds.set_dataset_size(new_size)
new_repeat_count = int(repeat_count * ori_dataaet_size // ds.get_dataset_size())
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
......@@ -112,10 +113,9 @@ def me_de_train_dataset(sink_mode=False):
ds = ds.map(input_columns="input_ids", operations=type_cast_op)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_count)
logger.info("data size: {}".format(ds.get_dataset_size()))
logger.info("repeat_count: {}".format(ds.get_repeat_count()))
return ds, new_repeat_count
return ds, new_repeat_count, sink_size
def weight_variable(shape):
......@@ -157,7 +157,7 @@ class TimeMonitor(Callback):
def test_bert_percision():
"""test bert percision"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
ds, new_repeat_count = me_de_train_dataset()
ds, new_repeat_count, _ = me_de_train_dataset()
version = os.getenv('VERSION', 'large')
batch_size = 16
config = get_config(version=version, batch_size=batch_size)
......@@ -215,7 +215,7 @@ def test_bert_percision():
def test_bert_performance():
"""test bert performance"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
ds, new_repeat_count = me_de_train_dataset(sink_mode=True)
ds, new_repeat_count, sink_size = me_de_train_dataset(sink_mode=True)
version = os.getenv('VERSION', 'large')
batch_size = 16
config = get_config(version=version, batch_size=batch_size)
......@@ -251,7 +251,7 @@ def test_bert_performance():
param.default_input = weight_variable(value.asnumpy().shape)
time_monitor_callback = TimeMonitor(ds.get_dataset_size())
model.train(new_repeat_count, ds, callbacks=[time_monitor_callback, callback],
dataset_sink_mode=True)
dataset_sink_mode=True, sink_size=sink_size)
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list)
......
......@@ -79,7 +79,7 @@ def test_deeplabv3_1p():
args_opt.base_size = config.crop_size
args_opt.crop_size = config.crop_size
args_opt.batch_size = config.batch_size
train_dataset = create_dataset(args_opt, data_url, epoch_size, config.batch_size,
train_dataset = create_dataset(args_opt, data_url, 1, config.batch_size,
usage="eval")
dataset_size = train_dataset.get_dataset_size()
callback = LossCallBack(dataset_size)
......
......@@ -155,7 +155,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
# train dataset
dataset = create_dataset(dataset_path=dataset_path, do_train=True,
repeat_num=epoch_size, batch_size=config.batch_size)
repeat_num=1, batch_size=config.batch_size)
step_size = dataset.get_dataset_size()
eval_interval = config.eval_interval
......@@ -163,7 +163,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
# evaluation dataset
eval_dataset = create_dataset(dataset_path=eval_path, do_train=False,
repeat_num=epoch_size, batch_size=config.eval_batch_size)
repeat_num=1, batch_size=config.eval_batch_size)
# loss scale
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
......@@ -260,14 +260,14 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
# train dataset
dataset = create_dataset(dataset_path=dataset_path, do_train=True,
repeat_num=epoch_size, batch_size=thor_config.batch_size)
repeat_num=1, batch_size=thor_config.batch_size)
step_size = dataset.get_dataset_size()
eval_interval = thor_config.eval_interval
# evaluation dataset
eval_dataset = create_dataset(dataset_path=eval_path, do_train=False,
repeat_num=epoch_size, batch_size=thor_config.eval_batch_size)
repeat_num=1, batch_size=thor_config.eval_batch_size)
# loss scale
loss_scale = FixedLossScaleManager(thor_config.loss_scale, drop_overflow_update=False)
......
......@@ -136,7 +136,7 @@ if __name__ == '__main__':
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
if args_opt.do_train:
dataset = create_dataset(epoch_size)
dataset = create_dataset(1)
batch_num = dataset.get_dataset_size()
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_cifar10", directory="./", config=config_ck)
......
......@@ -140,7 +140,7 @@ def train_process(epoch_size, num_classes, batch_size):
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
dataset = create_dataset(epoch_size, training=True, batch_size=batch_size)
dataset = create_dataset(1, training=True, batch_size=batch_size)
loss_cb = LossGet()
model.train(epoch_size, dataset, callbacks=[loss_cb])
......
......@@ -164,7 +164,7 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size,
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
dataset = create_dataset(epoch_size, training=True,
dataset = create_dataset(1, training=True,
batch_size=batch_size, rank_id=device_id, rank_size=device_num,
enable_hccl=enable_hccl)
......
......@@ -91,8 +91,9 @@ SET(DE_UT_SRCS
cyclic_array_test.cc
perf_data_test.cc
c_api_test.cc
tensor_op_fusion_pass_test.cc
tensor_op_fusion_pass_test.cc
sliding_window_op_test.cc
epoch_ctrl_op_test.cc
)
add_executable(de_ut_tests ${DE_UT_SRCS})
......
......@@ -397,23 +397,21 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
std::shared_ptr<CacheClient> myClient = std::make_shared<CacheClient>(1, 0, true);
std::shared_ptr<CacheMergeOp> myMergeOp;
rc = CacheMergeOp::Builder().SetNumWorkers(3).SetOpConnectorSize(3).SetNumCleaner(2).SetClient(myClient).Build(
&myMergeOp);
EXPECT_TRUE(rc.IsOk());
// In a mappable dataset, it uses a complex interactions of cache lookup op and cache merge op.
// Rather than manually build this, the way to do it is to choose the position of the cache in the tree by
// adding a CacheOp. Then, the tree prepare code will drive a transform that will remove the CacheOp and
// replace it with the required tree structures for cache lookup op and cache merge op.
std::shared_ptr<CacheLookupOp> myLookupOp;
rc = CacheLookupOp::Builder()
.SetNumWorkers(3)
.SetOpConnectorSize(3)
std::shared_ptr<CacheOp> myCacheOp;
rc = CacheOp::Builder()
.SetNumWorkers(4)
.SetClient(myClient)
.SetSampler(seq_sampler)
.Build(&myLookupOp);
EXPECT_TRUE(rc.IsOk());
.SetRowsPerBuffer(3)
.Build(&myCacheOp);
std::shared_ptr<ImageFolderOp> so;
ImageFolderOp::Builder builder;
builder.SetSampler(myLookupOp)
builder.SetSampler(std::move(seq_sampler))
.SetOpConnectorSize(3)
.SetNumWorkers(3)
.SetRowsPerBuffer(2)
......@@ -432,20 +430,18 @@ TEST_F(MindDataTestCacheOp, TestImageFolderCacheMerge) {
auto myTree = std::make_shared<ExecutionTree>();
rc = myTree->AssociateNode(so);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myLookupOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myMergeOp);
rc = myTree->AssociateNode(myCacheOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssociateNode(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
rc = myTree->AssignRoot(myRepeatOp);
EXPECT_TRUE(rc.IsOk());
rc = myRepeatOp->AddChild(myMergeOp);
EXPECT_TRUE(rc.IsOk());
rc = myMergeOp->AddChild(myLookupOp);
rc = myRepeatOp->AddChild(myCacheOp);
EXPECT_TRUE(rc.IsOk());
rc = myMergeOp->AddChild(so);
rc = myCacheOp->AddChild(so);
EXPECT_TRUE(rc.IsOk());
rc = myTree->Prepare();
......
此差异已折叠。
......@@ -46,7 +46,8 @@ TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) {
ASSERT_TRUE(rc.IsOk());
rc = my_tree->AssociateNode(my_tfreader_op);
ASSERT_TRUE(rc.IsOk());
my_tree->AssociateNode(parent_op);
rc = my_tree->AssociateNode(parent_op);
ASSERT_TRUE(rc.IsOk());
ASSERT_NE(parent_op, nullptr);
ASSERT_NE(my_tfreader_op, nullptr);
parent_op->AddChild(std::move(my_tfreader_op));
......
......@@ -104,9 +104,11 @@ def test_cache_map_basic3():
decode_op = c_vision.Decode()
ds1 = ds1.repeat(4)
ds1 = ds1.map(input_columns=["image"], operations=decode_op, cache=some_cache)
print("ds1.dataset_size is ", ds1.get_dataset_size())
num_iter = 0
for _ in ds1.create_dict_iterator():
print("get data from dataset")
num_iter += 1
logger.info("Number of data in ds1: {} ".format(num_iter))
......@@ -152,6 +154,10 @@ def test_cache_map_failure1():
if __name__ == '__main__':
test_cache_map_basic1()
print("test_cache_map_basic1 success.")
test_cache_map_basic2()
print("test_cache_map_basic2 success.")
test_cache_map_basic3()
print("test_cache_map_basic3 success.")
test_cache_map_failure1()
print("test_cache_map_failure1 success.")
......@@ -238,7 +238,7 @@ def test_tfrecord_shard_equal_rows():
def test_tfrecord_no_schema_columns_list():
logger.info("test_tfrecord_no_schema_columns_list")
data = ds.TFRecordDataset(FILES, shuffle=False, columns_list=["col_sint16"])
row = data.create_dict_iterator().get_next()
row = data.create_dict_iterator().__next__()
assert row["col_sint16"] == [-32768]
with pytest.raises(KeyError) as info:
......@@ -258,7 +258,7 @@ def test_tfrecord_schema_columns_list():
schema.add_column('col_sint32', de_type=mstype.int64, shape=[1])
schema.add_column('col_sint64', de_type=mstype.int64, shape=[1])
data = ds.TFRecordDataset(FILES, schema=schema, shuffle=False, columns_list=["col_sint16"])
row = data.create_dict_iterator().get_next()
row = data.create_dict_iterator().__next__()
assert row["col_sint16"] == [-32768]
with pytest.raises(KeyError) as info:
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import time
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
......@@ -35,6 +37,8 @@ def test_case_0():
data = data.device_que()
data.send()
time.sleep(0.1)
data.stop_send()
def test_case_1():
......@@ -58,6 +62,8 @@ def test_case_1():
data = data.device_que()
data.send()
time.sleep(0.1)
data.stop_send()
def test_case_2():
......@@ -84,6 +90,8 @@ def test_case_2():
data = data.device_que()
assert data.get_repeat_count() == 2
data.send()
time.sleep(0.1)
data.stop_send()
def test_case_3():
......@@ -109,13 +117,17 @@ def test_case_3():
data = data.device_que()
data.send()
time.sleep(0.1)
data.stop_send()
def test_case_tf_file():
data = ds.TFRecordDataset(TF_FILES, TF_SCHEMA_FILE, shuffle=ds.Shuffle.FILES)
data = data.to_device(num_batch=10)
data = data.to_device()
data.send()
time.sleep(0.1)
data.stop_send()
if __name__ == '__main__':
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing Epoch Control op in DE
"""
import itertools
import cv2
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
def diff_mse(in1, in2):
    """
    Return the mean squared error between two images, scaled to [0, 100].

    Both inputs are normalized from the [0, 255] pixel range to [0, 1]
    before the squared difference is averaged.
    """
    normed_1 = in1.astype(float) / 255
    normed_2 = in2.astype(float) / 255
    return np.mean(np.square(normed_1 - normed_2)) * 100
def test_cifar10():
    """
    Drive a Cifar10 pipeline for several epochs through one reusable iterator.

    Builds Cifar10Dataset -> repeat(2) -> batch(32, drop_remainder) and checks
    that a tuple iterator created once yields the expected batch count on each
    of the 5 epochs, plus the expected grand total.
    """
    logger.info("Test dataset parameter")
    data_dir_10 = "../data/dataset/testCifar10Data"
    num_repeat = 2
    batch_size = 32
    limit_dataset = 100
    # apply dataset operations
    data1 = ds.Cifar10Dataset(data_dir_10, limit_dataset)
    data1 = data1.repeat(num_repeat)
    data1 = data1.batch(batch_size, True)
    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shutdown.
    iter1 = data1.create_tuple_iterator()
    epoch_count = 0
    sample_count = 0
    for _ in range(num_epoch):
        row_count = 0
        for _ in iter1:
            # in this example, each dictionary has keys "image" and "label"
            row_count += 1
        assert row_count == int(limit_dataset * num_repeat / batch_size)
        # Use lazy %-style log formatting. The original code passed the value
        # as a stray positional argument with no placeholder in the message
        # ('logger.debug("row_count: ", row_count)'), which breaks standard
        # logging's %-formatting and drops the value from the output.
        logger.debug("row_count: %d", row_count)
        epoch_count += 1
        sample_count += row_count
    assert epoch_count == num_epoch
    logger.debug("total epochs: %d", epoch_count)
    assert sample_count == int(limit_dataset * num_repeat / batch_size) * num_epoch
    logger.debug("total sample: %d", sample_count)
def test_decode_op():
    """
    Compare in-pipeline Decode output against a manual cv2 decode, per epoch.

    data1 decodes images inside the pipeline (Decode(True) => RGB); data2
    yields the raw encoded bytes, which are decoded manually with cv2.
    iter1 has no epoch limit and is stopped explicitly at the end, while
    iter2 is bounded to num_epoch and must raise once exhausted.
    """
    logger.info("test_decode_op")

    # Decode with rgb format set to True
    data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)

    # Serialize and Load dataset requires using vision.Decode instead of vision.Decode().
    data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)])

    # Second dataset
    data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)

    num_epoch = 5
    # iter1 will always assume there is a next epoch and never shutdown.
    iter1 = data1.create_dict_iterator()
    # iter 2 will stop and shutdown pipeline after num_epoch
    iter2 = data2.create_dict_iterator(num_epoch)
    for _ in range(num_epoch):
        i = 0
        for item1, item2 in itertools.zip_longest(iter1, iter2):
            actual = item1["image"]
            # cv2 decodes to BGR, so convert to RGB to match Decode(True).
            expected = cv2.imdecode(item2["image"], cv2.IMREAD_COLOR)
            expected = cv2.cvtColor(expected, cv2.COLOR_BGR2RGB)
            assert actual.shape == expected.shape
            # Pixel-exact comparison: both paths must produce identical data.
            diff = actual - expected
            mse = np.sum(np.power(diff, 2))
            assert mse == 0
            i = i + 1
        # The test TFRecord is expected to contain exactly 3 images per epoch.
        assert i == 3

    # Users have the option to manually stop the iterator, or rely on garbage collector.
    iter1.stop()
    # Expect a AttributeError since iter1 has been stopped.
    with pytest.raises(AttributeError) as info:
        iter1.__next__()
    assert "object has no attribute 'depipeline'" in str(info.value)

    # iter2 has consumed its num_epoch budget, so one more fetch must fail.
    with pytest.raises(RuntimeError) as info:
        iter2.__next__()
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    assert err_msg in str(info.value)
# Generate 1d int numpy array from 0 - 63
def generator_1d():
    """Yield 64 one-tuples, each holding a single-element array of 0..63."""
    for value in range(64):
        yield (np.array([value]),)
def test_generator_dict_0():
    """
    Drive a 1D GeneratorDataset through a dict iterator created inline.

    The iterator is created directly in the for statement, so the pipeline
    is built and torn down within a single pass over the data.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    i = 0
    # create the iterator inside the loop declaration
    for item in data1.create_dict_iterator():  # each data is a dictionary
        golden = np.array([i])
        assert np.array_equal(item["data"], golden)
        i = i + 1
    # Consistency with the sibling tests: every other test_generator_dict_*
    # verifies the row count; this one silently passed even if the pipeline
    # produced zero rows. Enforce the expected 64 rows here as well.
    assert i == 64
def test_generator_dict_1():
    """
    Re-create a dict iterator on every epoch (anti-pattern demonstration).

    Each of the 10 passes builds a brand-new iterator; all 64 rows must come
    back in order every time.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    for _ in range(10):
        # BAD. Do not create iterator every time inside.
        # Create iterator outside the epoch for loop.
        expected_value = 0
        for row in data1.create_dict_iterator():  # each data is a dictionary
            assert np.array_equal(row["data"], np.array([expected_value]))
            expected_value += 1
        assert expected_value == 64
def test_generator_dict_2():
    """
    Reuse one dict iterator across 10 epochs, then keep it alive.

    After the epochs finish, the unbounded iterator must still serve another
    row; cleanup is left to the garbage collector.
    """
    logger.info("Test 1D Generator : 0 - 63")

    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])

    iter1 = data1.create_dict_iterator()
    for _ in range(10):
        expected_value = 0
        for row in iter1:  # each data is a dictionary
            assert np.array_equal(row["data"], np.array([expected_value]))
            expected_value += 1
        assert expected_value == 64

    # iter1 is still alive and running.
    item1 = iter1.__next__()
    assert item1

    # rely on garbage collector to destroy iter1
def test_generator_dict_3():
    """
    test generator dict 3
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator()
    for _ in range(10):
        count = 0
        for row in iter1:  # each row is a dictionary
            assert np.array_equal(row["data"], np.array([count]))
            count += 1
        assert count == 64
    # optional explicit shutdown of the underlying pipeline
    iter1.stop()
    # After stop() the iterator has released its pipeline, so the
    # next fetch must raise AttributeError.
    with pytest.raises(AttributeError) as info:
        next(iter1)
    assert "object has no attribute 'depipeline'" in str(info.value)
def test_generator_dict_4():
    """
    test generator dict 4
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator(num_epochs=10)
    for _ in range(10):
        count = 0
        for row in iter1:  # each row is a dictionary
            assert np.array_equal(row["data"], np.array([count]))
            count += 1
        assert count == 64
    # All 10 epochs consumed; one more fetch must fail.
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    with pytest.raises(RuntimeError) as info:
        next(iter1)
    assert err_msg in str(info.value)
def test_generator_dict_4_1():
    """
    test generator dict 4_1
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    # epoch ctrl op will not be injected if num_epochs is 1.
    iter1 = data1.create_dict_iterator(num_epochs=1)
    for _ in range(1):
        count = 0
        for row in iter1:  # each row is a dictionary
            assert np.array_equal(row["data"], np.array([count]))
            count += 1
        assert count == 64
    # The single epoch is consumed; one more fetch must fail.
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    with pytest.raises(RuntimeError) as info:
        next(iter1)
    assert err_msg in str(info.value)
def test_generator_dict_4_2():
    """
    test generator dict 4_2
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    # repeat will not be injected when num repeat is 1.
    data1 = data1.repeat(1)
    # epoch ctrl op will not be injected if num_epochs is 1.
    iter1 = data1.create_dict_iterator(num_epochs=1)
    for _ in range(1):
        count = 0
        for row in iter1:  # each row is a dictionary
            assert np.array_equal(row["data"], np.array([count]))
            count += 1
        assert count == 64
    # The single epoch is consumed; one more fetch must fail.
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    with pytest.raises(RuntimeError) as info:
        next(iter1)
    assert err_msg in str(info.value)
def test_generator_dict_5():
    """
    test generator dict 5
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_dict_iterator(num_epochs=11)
    for _ in range(10):
        count = 0
        for row in iter1:  # each row is a dictionary
            assert np.array_equal(row["data"], np.array([count]))
            count += 1
        assert count == 64
    # still one more epoch left in the iter1.
    count = 0
    for row in iter1:  # each row is a dictionary
        assert np.array_equal(row["data"], np.array([count]))
        count += 1
    assert count == 64
    # now iter1 has been exhausted, c++ pipeline has been shut down.
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    with pytest.raises(RuntimeError) as info:
        next(iter1)
    assert err_msg in str(info.value)
# Test tuple iterator
def test_generator_tuple_0():
    """
    test generator tuple 0
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    i = 0
    # create the iterator inside the loop declaration
    for item in data1.create_tuple_iterator():  # each data is a list of columns
        golden = np.array([i])
        assert np.array_equal(item[0], golden)
        i = i + 1
    # Fix: without this final check the test passes vacuously if the
    # iterator yields zero rows; sibling tests all verify the row count.
    assert i == 64
def test_generator_tuple_1():
    """
    test generator tuple 1
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    for _ in range(10):
        count = 0
        # BAD. Do not create iterator every time inside.
        # Create iterator outside the epoch for loop.
        for row in data1.create_tuple_iterator():  # each row is a list of columns
            assert np.array_equal(row[0], np.array([count]))
            count += 1
        assert count == 64
def test_generator_tuple_2():
    """
    test generator tuple 2
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        count = 0
        for row in iter1:  # each row is a list of columns
            assert np.array_equal(row[0], np.array([count]))
            count += 1
        assert count == 64
    # iter1 is still alive and running: a further fetch must succeed.
    assert next(iter1)
    # rely on garbage collector to destroy iter1
def test_generator_tuple_3():
    """
    test generator tuple 3
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        count = 0
        for row in iter1:  # each row is a list of columns
            assert np.array_equal(row[0], np.array([count]))
            count += 1
        assert count == 64
    # optional explicit shutdown of the underlying pipeline
    iter1.stop()
    # After stop() the iterator has released its pipeline, so the
    # next fetch must raise AttributeError.
    with pytest.raises(AttributeError) as info:
        next(iter1)
    assert "object has no attribute 'depipeline'" in str(info.value)
def test_generator_tuple_4():
    """
    test generator tuple 4
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator(num_epochs=10)
    for _ in range(10):
        count = 0
        for row in iter1:  # each row is a list of columns
            assert np.array_equal(row[0], np.array([count]))
            count += 1
        assert count == 64
    # All 10 epochs consumed; one more fetch must fail.
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    with pytest.raises(RuntimeError) as info:
        next(iter1)
    assert err_msg in str(info.value)
def test_generator_tuple_5():
    """
    test generator tuple 5
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    iter1 = data1.create_tuple_iterator(num_epochs=11)
    for _ in range(10):
        count = 0
        for row in iter1:  # each row is a list of columns
            assert np.array_equal(row[0], np.array([count]))
            count += 1
        assert count == 64
    # still one more epoch left in the iter1.
    count = 0
    for row in iter1:  # each row is a list of columns
        assert np.array_equal(row[0], np.array([count]))
        count += 1
    assert count == 64
    # now iter1 has been exhausted, c++ pipeline has been shut down.
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    with pytest.raises(RuntimeError) as info:
        next(iter1)
    assert err_msg in str(info.value)
# Test with repeat
def test_generator_tuple_repeat_1():
    """
    test generator tuple repeat 1
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    iter1 = data1.create_tuple_iterator(num_epochs=11)
    for _ in range(10):
        count = 0
        for row in iter1:  # each row is a list of columns
            assert np.array_equal(row[0], np.array([count % 64]))
            count += 1
        assert count == 64 * 2
    # still one more epoch left in the iter1.
    count = 0
    for row in iter1:  # each row is a list of columns
        assert np.array_equal(row[0], np.array([count % 64]))
        count += 1
    assert count == 64 * 2
    # now iter1 has been exhausted, c++ pipeline has been shut down.
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    with pytest.raises(RuntimeError) as info:
        next(iter1)
    assert err_msg in str(info.value)
# Test with repeat
def test_generator_tuple_repeat_repeat_1():
    """
    test generator tuple repeat repeat 1
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator(num_epochs=11)
    for _ in range(10):
        count = 0
        for row in iter1:  # each row is a list of columns
            assert np.array_equal(row[0], np.array([count % 64]))
            count += 1
        assert count == 64 * 2 * 3
    # still one more epoch left in the iter1.
    count = 0
    for row in iter1:  # each row is a list of columns
        assert np.array_equal(row[0], np.array([count % 64]))
        count += 1
    assert count == 64 * 2 * 3
    # now iter1 has been exhausted, c++ pipeline has been shut down.
    err_msg = "EOF buffer encountered. Users try to fetch data beyond the specified number of epochs."
    with pytest.raises(RuntimeError) as info:
        next(iter1)
    assert err_msg in str(info.value)
def test_generator_tuple_repeat_repeat_2():
    """
    test generator tuple repeat repeat 2
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        count = 0
        for row in iter1:  # each row is a list of columns
            assert np.array_equal(row[0], np.array([count % 64]))
            count += 1
        assert count == 64 * 2 * 3
    # optional explicit shutdown of the underlying pipeline
    iter1.stop()
    # After stop() the iterator has released its pipeline, so the
    # next fetch must raise AttributeError.
    with pytest.raises(AttributeError) as info:
        next(iter1)
    assert "object has no attribute 'depipeline'" in str(info.value)
def test_generator_tuple_repeat_repeat_3():
    """
    test generator tuple repeat repeat 3
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator()
    # First batch of epochs.
    for _ in range(10):
        count = 0
        for row in iter1:  # each row is a list of columns
            assert np.array_equal(row[0], np.array([count % 64]))
            count += 1
        assert count == 64 * 2 * 3
    # Iterator keeps working across a second batch of epochs.
    for _ in range(5):
        count = 0
        for row in iter1:  # each row is a list of columns
            assert np.array_equal(row[0], np.array([count % 64]))
            count += 1
        assert count == 64 * 2 * 3
    # rely on garbage collector to destroy iter1
def test_generator_reusedataset():
    """
    test generator reusedataset
    """
    logger.info("Test 1D Generator : 0 - 63")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1d, ["data"])
    data1 = data1.repeat(2)
    iter1 = data1.create_tuple_iterator()
    for _ in range(10):
        count = 0
        for row in iter1:  # each row is a list of columns
            assert np.array_equal(row[0], np.array([count % 64]))
            count += 1
        assert count == 64 * 2
    # Extend the same dataset with another repeat and re-create the iterator.
    data1 = data1.repeat(3)
    iter1 = data1.create_tuple_iterator()
    for _ in range(5):
        count = 0
        for row in iter1:  # each row is a list of columns
            assert np.array_equal(row[0], np.array([count % 64]))
            count += 1
        assert count == 64 * 2 * 3
    # Extend again with a batch op and switch to a dict iterator.
    data1 = data1.batch(2)
    iter1 = data1.create_dict_iterator()
    for _ in range(5):
        count = 0
        batches = 0
        for row in iter1:  # each row is a dictionary
            expected = np.array([[count % 64], [(count + 1) % 64]])
            assert np.array_equal(row["data"], expected)
            count += 2
            batches += 1
        assert batches == 64 * 3
    # rely on garbage collector to destroy iter1
......@@ -87,7 +87,7 @@ def test_five_crop_error_msg():
data = data.map(input_columns=["image"], operations=transform())
with pytest.raises(RuntimeError) as info:
data.create_tuple_iterator().get_next()
data.create_tuple_iterator().__next__()
error_msg = "TypeError: img should be PIL Image or Numpy array. Got <class 'tuple'>"
# error msg comes from ToTensor()
......
......@@ -41,18 +41,18 @@ def test_case1():
assert data.get_batch_size() == 2
assert data.get_repeat_count() == 1
data = data.repeat(10)
assert data.get_dataset_size() == 6
assert data.get_dataset_size() == 60
assert data.get_batch_size() == 2
assert data.get_repeat_count() == 10
data = data.project(["new_column"])
assert data.get_dataset_size() == 6
assert data.get_dataset_size() == 60
assert data.get_batch_size() == 2
assert data.get_repeat_count() == 10
data2 = ds.TFRecordDataset(FILES, SCHEMA_FILE).batch(2).repeat(10)
data1 = data.zip(data2)
assert data1.get_dataset_size() == 6
assert data1.get_dataset_size() == 60
def test_case2():
......@@ -65,14 +65,14 @@ def test_case2():
data = data.rename("col_sint64", "new_column")
assert data.get_dataset_size() == 3
data = data.repeat(10)
assert data.get_dataset_size() == 3
assert data.get_dataset_size() == 30
data = data.project(["new_column"])
assert data.get_dataset_size() == 3
assert data.get_dataset_size() == 30
data2 = ds.TFRecordDataset(FILES, num_samples=6).batch(2).repeat(10)
data1 = data.zip(data2)
assert data1.get_dataset_size() == 3
assert data1.get_dataset_size() == 30
def test_case3():
......@@ -94,11 +94,11 @@ def test_case4():
data2 = data2.shuffle(100)
assert data2.get_dataset_size() == 6
data2 = data2.repeat(3)
assert data2.get_dataset_size() == 6
assert data2.get_dataset_size() == 18
data3 = ds.zip((data1, data2))
assert data3.get_dataset_size() == 6
assert data3.get_dataset_size() == 18
def test_case5():
......
......@@ -73,7 +73,7 @@ def test_iterator_weak_ref():
_cleanup()
with pytest.raises(AttributeError) as info:
itr2.get_next()
itr2.__next__()
assert "object has no attribute 'depipeline'" in str(info.value)
del itr1
......
......@@ -251,6 +251,49 @@ def test_nested_repeat11():
assert sum([1 for _ in data]) == 2 * 3 * 4 * 5 * 3
def test_repeat_count1():
    """Repeat-then-batch: get_dataset_size must match the iterated row count."""
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is {}".format(data1_size))
    batch_size, repeat_count = 2, 4
    resize_height, resize_width = 32, 32
    decode_op = vision.Decode()
    resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)
    data1 = data1.map(input_columns=["image"], operations=decode_op)
    data1 = data1.map(input_columns=["image"], operations=resize_op)
    data1 = data1.repeat(repeat_count)
    data1 = data1.batch(batch_size, drop_remainder=False)
    dataset_size = data1.get_dataset_size()
    logger.info("dataset repeat then batch's size is {}".format(dataset_size))
    rows_seen = 0
    for _ in data1.create_dict_iterator():
        rows_seen += 1
    # 3 source rows, repeated 4x = 12 rows, batched by 2 = 6 batches.
    assert data1_size == 3
    assert dataset_size == rows_seen == 6
def test_repeat_count2():
    """Batch-then-repeat: get_dataset_size must match the iterated row count."""
    data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False)
    data1_size = data1.get_dataset_size()
    logger.info("dataset size is {}".format(data1_size))
    batch_size, repeat_count = 2, 4
    resize_height, resize_width = 32, 32
    decode_op = vision.Decode()
    resize_op = vision.Resize((resize_height, resize_width), interpolation=ds.transforms.vision.Inter.LINEAR)
    data1 = data1.map(input_columns=["image"], operations=decode_op)
    data1 = data1.map(input_columns=["image"], operations=resize_op)
    data1 = data1.batch(batch_size, drop_remainder=False)
    data1 = data1.repeat(repeat_count)
    dataset_size = data1.get_dataset_size()
    logger.info("dataset batch then repeat's size is {}".format(dataset_size))
    rows_seen = 0
    for _ in data1.create_dict_iterator():
        rows_seen += 1
    # 3 source rows, batched by 2 = 2 batches, repeated 4x = 8 batches.
    assert data1_size == 3
    assert dataset_size == rows_seen == 8
if __name__ == "__main__":
test_tf_repeat_01()
......@@ -268,3 +311,5 @@ if __name__ == "__main__":
test_nested_repeat9()
test_nested_repeat10()
test_nested_repeat11()
test_repeat_count1()
test_repeat_count2()
......@@ -252,14 +252,14 @@ def test_zip_exception_06():
if __name__ == '__main__':
test_zip_01()
test_zip_02()
test_zip_03()
test_zip_04()
test_zip_05()
test_zip_06()
test_zip_exception_01()
test_zip_exception_02()
test_zip_exception_03()
test_zip_exception_04()
test_zip_exception_05()
test_zip_exception_06()
#test_zip_02()
#test_zip_03()
#test_zip_04()
#test_zip_05()
#test_zip_06()
#test_zip_exception_01()
#test_zip_exception_02()
#test_zip_exception_03()
#test_zip_exception_04()
#test_zip_exception_05()
#test_zip_exception_06()
此差异已折叠。
......@@ -274,6 +274,9 @@ class DatasetLenet():
def get_repeat_count(self):
return 1
def create_tuple_iterator(self):
return self
def test_train_32k_8p(batch_size=32, num_classes=32768):
dev_num = 8
......
......@@ -61,6 +61,9 @@ class DatasetLenet():
def get_repeat_count(self):
return 1
def create_tuple_iterator(self):
return self
class Net(nn.Cell):
def __init__(self):
......
......@@ -58,6 +58,9 @@ class Dataset():
def get_repeat_count(self):
return 1
def create_tuple_iterator(self):
return self
class GatherV2(_Loss):
def __init__(self, index_dim, strategy, index_size=16):
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test dataset helper."""
import pytest
import numpy as np
import mindspore.context as context
from mindspore.communication.management import init
from mindspore.train.dataset_helper import DatasetHelper
from ....dataset_mock import MindData
def get_dataset(batch_size=1):
    """Build a MindData mock: 7 int32 columns, 2 batches, given batch size."""
    shapes = ((batch_size, 128), (batch_size, 128), (batch_size, 128), (batch_size, 1),
              (batch_size, 20), (batch_size, 20), (batch_size, 20))
    types = (np.int32,) * 7
    return MindData(size=2, batch_size=batch_size, np_types=types,
                    output_shapes=shapes, input_indexs=(0, 1))
def test_dataset_helper_dataset_sink_mode_str():
    """dataset_sink_mode passed as a string must raise TypeError."""
    with pytest.raises(TypeError):
        DatasetHelper(get_dataset(32), dataset_sink_mode="True")
def test_dataset_helper_dataset_sink_mode_int():
    """dataset_sink_mode passed as an int must raise TypeError."""
    with pytest.raises(TypeError):
        DatasetHelper(get_dataset(32), dataset_sink_mode=1)
def test_dataset_helper_sink_size_bool():
    """sink_size passed as a bool must raise TypeError."""
    with pytest.raises(TypeError):
        DatasetHelper(get_dataset(32), dataset_sink_mode=True, sink_size=True)
def test_dataset_helper_sink_size_float():
    """sink_size passed as a float must raise TypeError."""
    with pytest.raises(TypeError):
        DatasetHelper(get_dataset(32), dataset_sink_mode=True, sink_size=1.0)
def test_dataset_helper_sink_size_negative():
    """sink_size below -1 must raise ValueError."""
    with pytest.raises(ValueError):
        DatasetHelper(get_dataset(32), dataset_sink_mode=True, sink_size=-2)
def test_dataset_iter_normal():
    """Non-sink iteration over two passes with a reset between them."""
    dataset = get_dataset(32)
    helper = DatasetHelper(dataset, dataset_sink_mode=False)
    steps = 0
    for _ in range(2):
        for _ in helper:
            steps += 1
        dataset.reset()
    assert steps == 6
@pytest.mark.skipif('not context.get_context("enable_ge")')
def test_dataset_iter_ge():
    """GE sink mode: one helper step per epoch loop."""
    init()
    helper = DatasetHelper(get_dataset(32), dataset_sink_mode=True, sink_size=10)
    steps = 0
    for _ in range(2):
        for _ in helper:
            steps += 1
    assert steps == 2
@pytest.mark.skipif('context.get_context("enable_ge")')
def test_dataset_iter_ms_loop_sink():
    """Loop-sink mode: each helper step yields an empty tuple."""
    init()
    context.set_context(enable_loop_sink=True)
    helper = DatasetHelper(get_dataset(32), dataset_sink_mode=True, sink_size=10)
    steps = 0
    for _ in range(2):
        for step_inputs in helper:
            steps += 1
            assert step_inputs == tuple()
    assert steps == 2
@pytest.mark.skipif('context.get_context("enable_ge")')
def test_dataset_iter_ms():
    """Sink mode with loop sink disabled: helper construction must not raise."""
    init()
    context.set_context(enable_loop_sink=False)
    DatasetHelper(get_dataset(32), dataset_sink_mode=True, sink_size=10)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册