提交 f44d2135 编写于 作者: J Junhan Hu

MindData optimizer infrastructure.

上级 6cbde2b3
......@@ -65,6 +65,7 @@ set(submodules
$<TARGET_OBJECTS:engine-datasetops-source>
$<TARGET_OBJECTS:engine-datasetops-source-sampler>
$<TARGET_OBJECTS:engine-datasetops>
$<TARGET_OBJECTS:engine-opt>
$<TARGET_OBJECTS:engine>
)
......
add_subdirectory(datasetops)
add_subdirectory(opt)
if (ENABLE_TDTQUE)
add_subdirectory(tdt)
endif ()
......@@ -14,7 +15,7 @@ add_library(engine OBJECT
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
if (ENABLE_TDTQUE)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt)
else()
add_dependencies(engine engine-datasetops engine-datasetops-source)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt)
endif ()
......@@ -22,6 +22,7 @@
#include "dataset/core/pybind_support.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
using float16 = Eigen::half;
......@@ -462,5 +463,11 @@ Status BatchOp::PadHelper(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> d
return Status::OK();
}
// Visitor accept method for NodePass
Status BatchOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<BatchOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -192,6 +192,12 @@ class BatchOp : public ParallelOp {
Status PadTensor(std::shared_ptr<Tensor> src, std::shared_ptr<Tensor> *dst, const std::vector<dsize_t> &pad_shape,
float pad_val);
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
// recursive helper function. This function could be very expensive if called on a multi-dimensional tensor
// it is only meant to be called by PadTensor.
......
......@@ -25,6 +25,7 @@
#include "dataset/engine/datasetops/device_queue_op.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
......@@ -249,5 +250,11 @@ Status DatasetOp::AssignColMapFromChild() {
}
return Status::OK();
}
Status DatasetOp::Accept(NodePass *p, bool *modified) {
// DatasetOp is the base class of visitor target.
// This method will only be called if its derived class does not implement one.
return p->RunOnNode(shared_from_this(), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -32,6 +32,8 @@ class ExecutionTree;
class DataBuffer;
class NodePass;
// The base class DatasetOp is the main tree node. It is an abstract class, so
// the actual implementation of the operators will be derived from here.
class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
......@@ -209,6 +211,16 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
// @return - the column name map as a string
std::string ColumnNameMapAsString() const;
// Children Getter
// @return Vector or Children
std::vector<std::shared_ptr<DatasetOp>> Children() const { return child_; }
// Base method for NodePass visit.
// Subclass needs to override this if it requires special node visit access.
// Check "dataset/engine/opt/pass.h" for more details.
// @return Statue of the node visit
virtual Status Accept(NodePass *p, bool *modified);
protected:
// Adds a parent operator to this operator
// @notes External callers do not have access to this function.
......
......@@ -24,6 +24,7 @@
#include "dataset/engine/dataset_iterator.h"
#include "dataset/util/status.h"
#include "dataset/util/task_manager.h"
#include "dataset/engine/opt/pass.h"
#ifdef ENABLE_TDTQUE
#include "tdt/tsd_client.h"
......@@ -265,5 +266,12 @@ void DeviceQueueOp::Print(std::ostream &out, bool show_all) const {
out << "\nChannel name: " << channel_name_ << "\nPrefetch size: " << prefetch_size_ << "\n\n";
}
}
// Visitor accept method for NodePass
Status DeviceQueueOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<DeviceQueueOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -134,6 +134,12 @@ class DeviceQueueOp : public PipelineOp {
Status operator()() override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
// Name: checkExceptions(DataBuffer);
// Description: Check whether the dataBuffer meets the condition for performing DeviceQueueOp
......
......@@ -27,6 +27,7 @@
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/kernels/tensor_op.h"
#include "utils/log_adapter.h"
#include "dataset/util/task_manager.h"
......@@ -259,5 +260,11 @@ Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate
}
return Status(StatusCode::kOK, "FilterOp predicate func call succeed");
}
// Visitor accept method for NodePass
Status FilterOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<FilterOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -121,6 +121,12 @@ class FilterOp : public ParallelOp {
// @param show_all A bool to control if you want to show all info or just a summary.
void Print(std::ostream &out, bool show_all) const override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
// predicate_func python callable which returns a boolean value.
py::function predicate_func_;
......
......@@ -27,6 +27,7 @@
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/kernels/tensor_op.h"
#include "utils/log_adapter.h"
#include "dataset/util/task_manager.h"
......@@ -370,5 +371,11 @@ void MapOp::CreateFinalColMap(std::unordered_map<std::string, int32_t> *col_name
column_name_id_map_ = final_col_name_id_map;
}
}
// Visitor accept method for NodePass
Status MapOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<MapOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -171,6 +171,12 @@ class MapOp : public ParallelOp {
// @return the number of threads consuming data from previous op's output Connector.
int32_t num_consumers() const override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
// Local queues where worker threads can pop from.
// Popping directly from the Connector can block if the previous designated threads haven't pop.
......
......@@ -25,6 +25,7 @@
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
namespace mindspore {
......@@ -144,5 +145,11 @@ Status ProjectOp::EoeReceived(int32_t worker_id) {
}
Status ProjectOp::EofReceived(int32_t worker_id) { return Status::OK(); }
// Visitor accept method for NodePass
Status ProjectOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<ProjectOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -101,6 +101,12 @@ class ProjectOp : public PipelineOp {
// @return Status - The error code returned.
Status EofReceived(int32_t worker_id) override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
std::vector<std::string> columns_to_project_;
std::vector<int32_t> projected_column_indices_;
......
......@@ -24,6 +24,7 @@
#include "dataset/core/global_context.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
namespace mindspore {
......@@ -170,5 +171,11 @@ Status RenameOp::EoeReceived(int32_t) {
state_ = OpState::kDeOpIdle;
return Status::OK();
}
// Visitor accept method for NodePass
Status RenameOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<RenameOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -110,6 +110,12 @@ class RenameOp : public PipelineOp {
// @return Status - The error code return
Status operator()() override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
protected:
// Rename core functionality
Status RenameColumns();
......
......@@ -21,6 +21,7 @@
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
......@@ -187,5 +188,11 @@ int32_t RepeatOp::num_producers() const {
return child_[0]->num_producers();
}
}
// Visitor accept method for NodePass
Status RepeatOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<RepeatOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -118,6 +118,12 @@ class RepeatOp : public PipelineOp {
// @param workerId - The worker id
int32_t num_producers() const override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
int32_t max_repeats_; // The number of repeats that the user requested
int32_t repeat_count_; // A counter for the current number of executed repeats
......
......@@ -30,6 +30,7 @@
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/util/random.h"
#include "dataset/util/status.h"
......@@ -296,5 +297,11 @@ Status ShuffleOp::EoeReceived(int32_t worker_id) {
state_ = OpState::kDeOpIdle;
return Status::OK();
}
// Visitor accept method for NodePass
Status ShuffleOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<ShuffleOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -155,6 +155,12 @@ class ShuffleOp : public PipelineOp {
// @return Status - The error code return
Status EoeReceived(int32_t worker_id) override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
// Private function to add a new row to the shuffle buffer.
// @return Status - The error code return
......
......@@ -22,6 +22,7 @@
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
......@@ -128,5 +129,11 @@ Status SkipOp::EofReceived(int32_t worker_id) {
MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now.";
return Status::OK();
}
// Visitor accept method for NodePass
Status SkipOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<SkipOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -74,6 +74,12 @@ class SkipOp : public PipelineOp {
// @param worker_id - The worker id
Status EofReceived(int32_t worker_id) override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
int32_t max_skips_; // The number of skips that the user requested
int32_t skip_count_; // A counter for the current number of executed skips
......
......@@ -20,6 +20,7 @@
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/util/task_manager.h"
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
......@@ -250,5 +251,11 @@ Status GeneratorOp::Reset() {
wp_.Set();
return Status(StatusCode::kOK, "GeneratorOp Reset Succeed");
}
// Visitor accept method for NodePass
Status GeneratorOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<GeneratorOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -121,6 +121,12 @@ class GeneratorOp : public PipelineOp {
// @return Status - The error code return
Status Reset() override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
py::function generator_function_;
std::vector<std::string> column_names_;
......
......@@ -22,6 +22,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
......@@ -451,5 +452,11 @@ Status ImageFolderOp::CountRowsAndClasses(const std::string &path, const int64_t
(*num_rows) = (row_cnt / num_dev) + (row_cnt % num_dev == 0 ? 0 : 1);
return Status::OK();
}
// Visitor accept method for NodePass
Status ImageFolderOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<ImageFolderOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -225,6 +225,12 @@ class ImageFolderOp : public ParallelOp, public RandomAccessOp {
const std::set<std::string> &exts, int64_t *num_rows, int64_t *num_classes,
int64_t dev_id = 0, int64_t num_dev = 1);
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
......
......@@ -29,6 +29,7 @@
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "utils/log_adapter.h"
namespace mindspore {
......@@ -684,5 +685,11 @@ Status MindRecordOp::CountTotalRows(const std::vector<std::string> dataset_path,
}
return Status::OK();
}
// Visitor accept method for NodePass
Status MindRecordOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<MindRecordOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -195,6 +195,12 @@ class MindRecordOp : public ParallelOp {
Status SetColumnsBlob();
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);
......
......@@ -37,6 +37,7 @@
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/util/path.h"
#include "dataset/util/queue.h"
#include "dataset/util/random.h"
......@@ -1037,5 +1038,11 @@ int64_t TFReaderOp::CountTotalRowsSectioned(const std::vector<std::string> &file
return rows_read;
}
// Visitor accept method for NodePass
Status TFReaderOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<TFReaderOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -222,6 +222,12 @@ class TFReaderOp : public ParallelOp {
static Status CountTotalRows(int64_t *out_total_rows, const std::vector<std::string> &filenames, int64_t threads = 1,
bool estimate = false);
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
......
......@@ -22,6 +22,7 @@
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
......@@ -132,5 +133,11 @@ Status TakeOp::PrepareNodePostAction() {
tree_->AddToRepeatStack(shared_from_this());
return Status::OK();
}
// Visitor accept method for NodePass
Status TakeOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<TakeOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -84,6 +84,12 @@ class TakeOp : public PipelineOp {
// before providing their own implementations.
Status PrepareNodePostAction() override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
int32_t max_takes_; // The number of takes that the user requested
int32_t take_count_; // A counter for the current number of executed takes
......
......@@ -19,6 +19,7 @@
#include "dataset/core/constants.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/core/config_manager.h"
#include "dataset/core/global_context.h"
#include "utils/log_adapter.h"
......@@ -250,5 +251,11 @@ Status ZipOp::EoeReceived(int32_t) {
state_ = OpState::kDeOpIdle;
return Status::OK();
}
// Visitor accept method for NodePass
Status ZipOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(std::static_pointer_cast<ZipOp>(shared_from_this()), modified);
}
} // namespace dataset
} // namespace mindspore
......@@ -104,6 +104,12 @@ class ZipOp : public PipelineOp {
// @return Status - The error code return
Status operator()() override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
private:
// Handles preprocessing of the main loop, used when starting new epoch
Status prepare(TensorQTable *const table);
......
......@@ -20,6 +20,8 @@
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/util/task_manager.h"
#include "dataset/engine/opt/util/printer_pass.h"
namespace mindspore {
namespace dataset {
// Constructor
......@@ -161,10 +163,54 @@ Status ExecutionTree::LaunchWorkers(int32_t num_workers, std::function<Status(ui
return Status::OK();
}
// The driver of the prepare phase of the execution tree.
// Prepare phase consists of three sub phases
//
// 1. PrepareTreePreAction()
// Compulsory transformation/action pre optimization.
// For example, CacheOp Insertion
//
// 2. Optimize()
// Optimization transformation/action, optional
// For example, MapOp Fusion
//
// 3. PrepareTreePostAction()
// Compulsory transformation/action post optimization.
// For example, repeatOp inlining
//
// @return Status - The error code return
Status ExecutionTree::Prepare() {
// Pre optimization compulsory transformation
RETURN_IF_NOT_OK(this->PrepareTreePreAction());
// Optimization transformation
RETURN_IF_NOT_OK(this->Optimize());
// Post optimization compulsory transformation
RETURN_IF_NOT_OK(this->PrepareTreePostAction());
// Existing transformation implementation, will be removed later
RETURN_IF_NOT_OK(this->PrepareDeprecated());
return Status::OK();
}
Status ExecutionTree::PrepareTreePreAction() { return Status::OK(); }
Status ExecutionTree::PrepareTreePostAction() { return Status::OK(); }
Status ExecutionTree::Optimize() {
// auto pp = new PrinterPass();
// bool modified = false;
// pp->Run(this, &modified);
return Status::OK();
}
// The driver of the prepare phase of the execution tree. The prepare phase will recursively
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
// it ready for execution.
Status ExecutionTree::Prepare() {
//
// This driver is deprecated.
Status ExecutionTree::PrepareDeprecated() {
// Tree must be in pending prepare state before we can assign root to it
if (tree_state_ != kDeTStatePrepare) {
std::string err_msg =
......
......@@ -152,11 +152,41 @@ class ExecutionTree {
// @return the prepare flags
uint32_t PrepareFlags() const { return prepare_flags_; }
// The driver of the prepare phase of the execution tree. The prepare phase will recursively
// The driver of the prepare phase of the execution tree.
// Prepare phase consists of three sub phases
//
// 1. PrepareTreePreAction()
// Compulsory transformation/action pre optimization.
// For example, CacheOp Insertion
//
// 2. Optimize()
// Optimization transformation/action, optional
// For example, MapOp Fusion
//
// 3. PrepareTreePostAction()
// Compulsory transformation/action post optimization.
// For example, repeatOp inlining
//
// @return Status - The error code return
Status Prepare();
// Compulsory transformation/action pre optimization.
// @return Status - The error code return
Status PrepareTreePreAction();
// Compulsory transformation/action post optimization.
// @return Status - The error code return
Status PrepareTreePostAction();
// Optimization transformation/action, optional.
// @return Status - The error code return
Status Optimize();
// The DEPRECATED driver of the prepare phase of the execution tree. The prepare phase will recursively
// walk the tree to perform modifications to the tree or specific nodes within the tree to get
// it ready for execution.
// @return Status - The error code return
Status Prepare();
Status PrepareDeprecated();
// Recursive function used during prepare phase to visit a node and drive any pre- and post-
// node actions during a tree walk.
......
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-opt OBJECT
pass.cc
util/printer_pass.cc
)
\ No newline at end of file
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/opt/pass.h"
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/datasetops/batch_op.h"
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/datasetops/device_queue_op.h"
#include "dataset/engine/datasetops/map_op.h"
#include "dataset/engine/datasetops/project_op.h"
#include "dataset/engine/datasetops/rename_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/engine/datasetops/source/generator_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/storage_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/image_folder_op.h"
#include "dataset/engine/datasetops/take_op.h"
#include "dataset/engine/datasetops/zip_op.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Driver method for TreePass
Status TreePass::Run(ExecutionTree *tree, bool *modified) { return this->RunOnTree(tree, modified); }
// Driver method for NodePass
Status NodePass::Run(ExecutionTree *tree, bool *modified) {
std::shared_ptr<DatasetOp> root = tree->root();
if (traversalOrder_ == Order::DFS) {
// DFS
return DFSNodeVisit(root, modified);
} else if (traversalOrder_ == Order::BFS) {
// BFS
return BFSNodeVisit(root, modified);
}
return Status::OK();
}
// Helper function to perform DFS visit
Status NodePass::DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified) {
for (const auto &c : node->Children()) {
RETURN_IF_NOT_OK(this->DFSNodeVisit(c, modified));
}
return node->Accept(this, modified);
}
// Helper function to perform BFS visit
Status NodePass::BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified) {
// Initialize bfs queue with root
std::queue<std::shared_ptr<DatasetOp>> bfsQueue;
bfsQueue.push(root);
// BFS loop
while (!bfsQueue.empty()) {
// Pop the front of the bfs queue
auto curNode = bfsQueue.front();
bfsQueue.pop();
// Run node pass
RETURN_IF_NOT_OK(curNode->Accept(this, modified));
// Push children into bfs queue
for (const auto &c : curNode->Children()) {
bfsQueue.push(c);
}
}
return Status::OK();
}
Status NodePass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_H_
#define DATASET_ENGINE_OPT_PASS_H_
#include <memory>
#include <queue>
#include "dataset/engine/execution_tree.h"
#include "dataset/util/status.h"
namespace mindspore {
namespace dataset {
class BatchOp;
class MapOp;
class ProjectOp;
class RenameOp;
class FilterOp;
class SkipOp;
class ShuffleOp;
class GeneratorOp;
class MindRecordOp;
class TFReaderOp;
class TakeOp;
class ZipOp;
class DeviceQueueOp;
class ImageFolderOp;
// The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> {
public:
// Run the transformation pass again the execution tree.
// @param tree - Pointer to the execution tree to be transformed.
// @param modified - Pointer to the modified flag,
virtual Status Run(ExecutionTree *tree, bool *modified) { return Status::OK(); }
};
// TreePass is a basic Pass class which performs transformation on ExecutionTree directly.
class TreePass : public Pass {
public:
// Run the transformation pass against the execution tree.
// @param tree - Pointer to the execution tree to be transformed.
// @param modified - Pointer to the modified flag,
Status Run(ExecutionTree *tree, bool *modified) final;
// Derived classes may implement the runOnTree function to implement tree transformation.
// "modified" flag needs to be set to true if tree is modified during the pass execution.
// @return Status - The error code return
virtual Status RunOnTree(ExecutionTree *tree, bool *modified) { return Status::OK(); }
};
// NodePass is a basic Pass class which performs transformation on Node visiting.
// NodePass implements Visitor design pattern.
class NodePass : public Pass {
public:
// Tree traversal order
enum Order { DFS, BFS };
// Constructor
// Default DFS traversal
explicit NodePass(Order order = Order::DFS) { traversalOrder_ = order; }
// Run the transformation pass against the execution tree.
// @param tree - Pointer to the execution tree to be transformed.
// @param modified - Pointer to the modified flag,
Status Run(ExecutionTree *tree, bool *modified) final;
// Derived classes may implement the runOnNode function to implement node level tree transformation.
// "modified" flag needs to be set to true if tree is modified during the pass execution.
// @return Status - The error code return
virtual Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) { return Status::OK(); }
// Visit methods to be overridden.
// Note that member template can not be virtual, any op which wants to work with NodePass should declare RunOnNode
// of its own type and override "Accept" from DatasetOp.
virtual Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);
private:
// Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
// Helper function to perform BFS visit
Status BFSNodeVisit(std::shared_ptr<DatasetOp> root, bool *modified);
// Tree traversal order of the NodePass
Order traversalOrder_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "dataset/engine/opt/util/printer_pass.h"
namespace mindspore {
namespace dataset {
Status PrinterPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting DatasetOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting BatchOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<MapOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting MapOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting ProjectOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting RenameOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting FilterOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting SkipOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting ShuffleOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting GeneratorOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting MindRecordOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting TFReaderOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting TakeOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting ZipOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting DeviceQueueOp" << '\n';
return Status::OK();
}
Status PrinterPass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
*modified = false;
std::cout << "Visiting ImageFolderOp" << '\n';
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
#define DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
#include <memory>
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class PrinterPass : public NodePass {
public:
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<BatchOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<MapOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<ProjectOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<RenameOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<FilterOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<SkipOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_UTIL_PRINTER_H
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
# Generate 1d int numpy array from 0 - 63
def generator_1d():
for i in range(64):
yield (np.array([i]),)
def test_case_0():
"""
Test 1D Generator
"""
# apply dataset operations
data1 = ds.GeneratorDataset(generator_1d, ["data"])
data1 = data1.shuffle(2)
data1 = data1.map(["data"], operations=(lambda x : x))
data1 = data1.batch(2)
i = 0
for item in data1.create_dict_iterator(): # each data is a dictionary
pass
if __name__ == "__main__":
test_case_0()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册