Commit 9fb1904e authored by Nat Sutyanyong

Refactoring opt/pre

Parent b13c7a3d
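Summary of the change, with an illustrative sketch: the diff below folds the helper node-visitor passes (CachePass, RemovalNodes) into their owning tree passes as nested classes and has them report results through getters instead of calling back into the owner through a raw pointer, and it renames InjectionPass to EpochInjectionPass with the injection point chosen by the finder. The following standalone sketch is not MindSpore code; the Node, Collector, and RemovalPassSketch names are invented for illustration of the collector-with-getter pattern the refactor applies.

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a DatasetOp in the execution tree.
struct Node {
  explicit Node(std::string n) : name(std::move(n)) {}
  std::string name;
};

class RemovalPassSketch {
  // Nested collector: visits every node and records the ones to remove.
  class Collector {
   public:
    void Visit(const std::shared_ptr<Node> &node) {
      if (node->name == "ShuffleOp") nodes_to_remove_.push_back(node);
    }
    // Getter replaces the old AddToRemovalList() call on a back-pointer.
    const std::vector<std::shared_ptr<Node>> &nodes_to_remove() const { return nodes_to_remove_; }

   private:
    std::vector<std::shared_ptr<Node>> nodes_to_remove_;
  };

 public:
  void Run(const std::vector<std::shared_ptr<Node>> &tree) {
    Collector collector;
    for (const auto &node : tree) collector.Visit(node);
    // Act on what the nested pass collected.
    for (const auto &node : collector.nodes_to_remove()) std::cout << "would remove: " << node->name << "\n";
  }
};

int main() {
  RemovalPassSketch pass;
  pass.Run({std::make_shared<Node>("CacheOp"), std::make_shared<Node>("ShuffleOp")});
  return 0;
}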
......@@ -23,7 +23,7 @@
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/opt/post/repeat_pass.h"
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h"
#include "mindspore/ccsrc/minddata/dataset/engine/opt/optional/tensor_op_fusion_pass.h"
#include "minddata/dataset/engine/perf/profiling.h"
#include "minddata/dataset/engine/perf/monitor.h"
......@@ -225,7 +225,7 @@ Status ExecutionTree::PrepareTreePreAction() {
std::vector<std::unique_ptr<Pass>> pre_actions;
// Construct pre actions
MS_LOG(INFO) << "Running pre pass loops.";
pre_actions.push_back(std::make_unique<InjectionPass>());
pre_actions.push_back(std::make_unique<EpochInjectionPass>());
pre_actions.push_back(std::make_unique<RemovalPass>());
pre_actions.push_back(std::make_unique<CacheTransformPass>());
// Apply pre action passes
......
......@@ -3,10 +3,8 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
add_library(engine-opt OBJECT
pass.cc
post/repeat_pass.cc
pre/cache_pass.cc
pre/cache_transform_pass.cc
pre/injection_pass.cc
pre/removal_nodes.cc
pre/epoch_injection_pass.cc
pre/removal_pass.cc
optional/tensor_op_fusion_pass.cc
util/printer_pass.cc
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "minddata/dataset/engine/opt/pre/cache_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
namespace mindspore {
namespace dataset {
// Constructor
CachePass::CachePass(CacheTransformPass *transform_pass)
: transform_pass_(transform_pass), is_caching_(false), leaf_op_(nullptr) {}
// Identifies the subtree below this node as a cached descendant tree.
Status CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!");
}
is_caching_ = true;
return Status::OK();
}
// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
// transformation
Status CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
is_caching_ = false;  // We are no longer in a cache subtree. Clear the flag.
if (leaf_op_) {
MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
// Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op,
// using base class pointers.
transform_pass_->AddMappableCacheOperators(std::move(leaf_op_), node);
} else {
// If there was no leaf_op set, then this is a non-mappable scenario.
if (sampler_) {
// Grab the sampler that was saved from the leaf and plug it into the cache op
node->SetSampler(std::move(sampler_));
MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf.";
} else {
// We're a cache op but no sampler was saved from leaf, so create a default sampler
const int64_t num_samples = 0;
const int64_t start_index = 0;
sampler_ = std::make_shared<SequentialSampler>(num_samples, start_index);
node->SetSampler(std::move(sampler_));
MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op.";
}
// Get the computed checksum from all ops in our cache path below us and ask the cache op to create its cache
uint32_t cache_crc = DatasetOp::GenerateCRC(node);
RETURN_IF_NOT_OK(node->CreateCache(cache_crc));
}
return Status::OK();
}
// Common code for mappable leaf setup.
Status CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (is_caching_ && leaf_op_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
}
// If we are a leaf in the caching path, then save this leaf.
if (is_caching_) {
MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
leaf_op_ = std::move(leaf_op);
}
return Status::OK();
}
// Common code for non mappable leaf setup.
Status CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (is_caching_ && leaf_op_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
}
// A sampler for a non-mappable dataset only works if there is a downstream cache. Remove it from the leaf
// and save it for use by the cache op in the ascendant tree.
if (is_caching_) {
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_));
MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
} else {
// If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can
// remove it here. The leaf itself will provide its own methods of fetching the data (not sampler-based)
std::shared_ptr<Sampler> sampler_from_leaf;
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf));
}
return Status::OK();
}
// Perform leaf node cache transform identifications
Status CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
if (is_caching_) {
// If we are a TF Reader in a caching tree, then change our config so that it becomes a basic
// TF reader that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) {
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
#include <memory>
#include <string>
#include <utility>
#include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
class CacheTransformPass;
/// \class CachePass cache_pass.h
/// \brief This is a NodePass whose job is to identify and set up the nodes that will be involved in a cache
/// transformation. It works in conjunction with the CacheTransformPass
class CachePass : public NodePass {
public:
/// \brief Constructor
/// \param[in] transform_pass Raw pointer back to controlling tree pass
explicit CachePass(CacheTransformPass *transform_pass);
/// \brief Destructor
~CachePass() = default;
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
/// transformation
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;
private:
/// \brief Common code for mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
/// \brief Common code for non-mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
bool is_caching_;
std::shared_ptr<DatasetOp> leaf_op_;
std::shared_ptr<Sampler> sampler_;
CacheTransformPass *transform_pass_; // Back pointer to the owning transform pass
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_CACHE_PASS_H_
......@@ -15,17 +15,177 @@
*/
#include <vector>
#include "minddata/dataset/engine/opt/pre/cache_pass.h"
#include "minddata/dataset/engine/opt/pre/cache_transform_pass.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/engine/datasetops/source/cifar_op.h"
#include "minddata/dataset/engine/datasetops/source/coco_op.h"
#include "minddata/dataset/engine/datasetops/source/generator_op.h"
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
#include "minddata/dataset/engine/datasetops/source/manifest_op.h"
#include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
#include "minddata/dataset/engine/datasetops/source/mnist_op.h"
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
namespace mindspore {
namespace dataset {
// Constructor
CacheTransformPass::CachePass::CachePass() : is_caching_(false), leaf_op_(nullptr) {}
// Identifies the subtree below this node as a cached descendant tree.
Status CacheTransformPass::CachePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Cache transform pass: CacheOp found, identified descendant tree.";
if (is_caching_) {
RETURN_STATUS_UNEXPECTED("Nested cache operations is not supported!");
}
is_caching_ = true;
return Status::OK();
}
// Resets the tracking of the cache within the tree and assigns the operators that will be involved in a cache
// transformation
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
is_caching_ = false;  // We are no longer in a cache subtree. Clear the flag.
if (leaf_op_) {
MS_LOG(INFO) << "Cache transform pass: Set up transformation nodes for mappable cache.";
// Assign the leaf op into the transform pass, using move to null our copy of it, and also assign the cache op,
// using base class pointers.
AddMappableCacheOperators(std::move(leaf_op_), node);
} else {
// If there was no leaf_op set, then this is a non-mappable scenario.
if (sampler_) {
// Grab the sampler that was saved from the leaf and plug it into the cache op
node->SetSampler(std::move(sampler_));
MS_LOG(INFO) << "Cache transform pass: Set up cache sampler from non-mappable leaf.";
} else {
// We're a cache op but no sampler was saved from leaf, so create a default sampler
int64_t num_samples = 0;
int64_t start_index = 0;
sampler_ = std::make_shared<SequentialSampler>(num_samples, start_index);
node->SetSampler(std::move(sampler_));
MS_LOG(INFO) << "Cache transform pass: Creating default sequential sampler for cache op.";
}
// Get the computed checksum from all ops in our cache path below us and ask the cache op to create its cache
uint32_t cache_crc = DatasetOp::GenerateCRC(node);
RETURN_IF_NOT_OK(node->CreateCache(cache_crc));
}
return Status::OK();
}
// Common code for mappable leaf setup.
Status CacheTransformPass::CachePass::MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (is_caching_ && leaf_op_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
}
// If we are a leaf in the caching path, then save this leaf.
if (is_caching_) {
MS_LOG(DEBUG) << "Cache transform pass: Mappable leaf in a cache descendant tree detected";
leaf_op_ = std::move(leaf_op);
}
return Status::OK();
}
// Common code for non mappable leaf setup.
Status CacheTransformPass::CachePass::NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op) {
// If a leaf has already been assigned, then we have more than one leaf inside this cache descendant tree.
if (is_caching_ && leaf_op_) {
RETURN_STATUS_UNEXPECTED("There is currently no support for multiple leaf nodes under cache.");
}
// A sampler for a non-mappable dataset only works if there is a downstream cache. Remove it from the leaf
// and save it for use by the cache op in the ascendant tree.
if (is_caching_) {
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_));
MS_LOG(DEBUG) << "Cache transform pass: Non mappable leaf in a cache descendant tree detected";
} else {
// If we are a non-mappable leaf and are not in a cache tree, then this sampler is not used so we can
// remove it here. The leaf itself will provide its own methods of fetching the data (not sampler-based)
std::shared_ptr<Sampler> sampler_from_leaf;
RETURN_IF_NOT_OK(leaf_op->FetchRemoveSampler(&sampler_from_leaf));
}
return Status::OK();
}
// Perform leaf node cache transform identifications
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) {
if (is_caching_) {
// If we are a TF Reader in a caching tree, then change our config so that it becomes a basic
// TF reader that parses all files. Selection of data will come from the sampler on the cache instead.
node->MakeSimpleProducer();
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) {
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identifications
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) {
return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Assigns the leaf and cache operators that are involved in a cache transformation
void CacheTransformPass::CachePass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<CacheOp> cache_op) {
cache_pairs_.push_back(std::make_pair(leaf_op, cache_op));
}
// constructor
CacheTransformPass::CacheTransformPass() {}
......@@ -34,11 +194,11 @@ Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) {
MS_LOG(INFO) << "Pre pass: Cache transform pass started.";
// Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will
// use to execute a transform.
std::unique_ptr<Pass> cache_pass = std::make_unique<CachePass>(this);
RETURN_IF_NOT_OK(cache_pass->Run(tree, modified));
CachePass cache_pass = CachePass();
RETURN_IF_NOT_OK(cache_pass.Run(tree, modified));
// Then, execute the transform for each pair
for (auto cache_pair : cache_pairs_) {
for (auto cache_pair : cache_pass.cache_pairs()) {
MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform.";
ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client());
}
......@@ -98,11 +258,5 @@ Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::share
return Status::OK();
}
// Assigns the leaf and cache operators that are involved in a cache transformation
void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<CacheOp> cache_op) {
cache_pairs_.push_back(std::make_pair(leaf_op, cache_op));
}
} // namespace dataset
} // namespace mindspore
......@@ -33,6 +33,123 @@ class CacheClient;
/// \brief This is a tree pass that will invoke a tree transformation to inject the correct operators for caching
/// operations
class CacheTransformPass : public TreePass {
/// \class CachePass
/// \brief This is a NodePass whose job is to identify and set up the nodes that will be involved in a cache
/// transformation. It works in conjunction with the CacheTransformPass
class CachePass : public NodePass {
public:
/// \brief Constructor
CachePass();
/// \brief Destructor
~CachePass() = default;
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Resets the tracking of the cache within the tree and assigns the operators that
/// will be involved in a cache transformation
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) override;
/// \brief Perform leaf node cache transform identifications
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<MindRecordOp> node, bool *modified) override;
/// \brief Getter
std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs() { return cache_pairs_; }
private:
/// \brief Common code for mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
Status MappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
/// \brief Common code for non-mappable leaf setup.
/// \param[in] node The leaf node performing setup work.
/// \return Status The error code return
Status NonMappableCacheLeafSetup(std::shared_ptr<DatasetOp> leaf_op);
/// \brief Assigns the leaf and cache operators that are involved in a cache transformation
/// \param[in] leaf_op The leaf operator involved in the cache transform
/// \param[in] cache_op The cache operator involved in the cache transform
void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op);
bool is_caching_;
std::shared_ptr<DatasetOp> leaf_op_;
std::shared_ptr<Sampler> sampler_;
// The two operators that work together to establish the cache transform
std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_;
};
public:
/// \brief Constructor
CacheTransformPass();
......@@ -46,11 +163,6 @@ class CacheTransformPass : public TreePass {
/// \return Status The error code return
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
/// \brief Assigns the leaf and cache operators that are involved in a cache transformation
/// \param[in] leaf_op The leaf operator involved in the cache transform
/// \param[in] cache_op The cache operator involved in the cache transform
void AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op, std::shared_ptr<CacheOp> cache_op);
private:
/// \brief Helper function to execute the cache transformation.
///
......@@ -72,9 +184,6 @@ class CacheTransformPass : public TreePass {
/// \return Status The error code return
Status ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<DatasetOp> cache_op, std::shared_ptr<CacheClient> cache_client);
// The two operators that work together to establish the cache transform
std::vector<std::pair<std::shared_ptr<DatasetOp>, std::shared_ptr<CacheOp>>> cache_pairs_;
};
} // namespace dataset
} // namespace mindspore
......
......@@ -16,7 +16,7 @@
#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/opt/pre/injection_pass.h"
#include "minddata/dataset/engine/opt/pre/epoch_injection_pass.h"
#include "minddata/dataset/engine/execution_tree.h"
#include "minddata/dataset/engine/datasetops/epoch_ctrl_op.h"
#include "minddata/dataset/engine/datasetops/device_queue_op.h"
......@@ -25,64 +25,55 @@ namespace mindspore {
namespace dataset {
// constructor
InjectionPass::InjectionFinder::InjectionFinder(InjectionPass *injection_pass) : injection_pass_(injection_pass) {}
EpochInjectionPass::InjectionFinder::InjectionFinder(std::shared_ptr<DatasetOp> node) : injection_point_(node) {}
// Performs finder work for BuildVocabOp that has special rules about epoch control injection
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
if (injection_pass_) {
injection_pass_->epoch_ctrl_bypass_ = true;
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
}
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildVocabOp> node, bool *modified) {
injection_point_ = nullptr;
return Status::OK();
}
// Performs finder work for BuildSentencePieceVocabOp that has special rules about epoch control injection
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node, bool *modified) {
if (injection_pass_) {
injection_pass_->epoch_ctrl_bypass_ = true;
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
}
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<BuildSentencePieceVocabOp> node,
bool *modified) {
injection_point_ = nullptr;
return Status::OK();
}
// Temporary code to prevent the injection of epoch control when cache op is present
// Remove this code in cache op phase 2
Status InjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
if (injection_pass_) {
injection_pass_->epoch_ctrl_bypass_ = true;
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED("Missing outer injection pass object from inside InjectionFinder!");
}
Status EpochInjectionPass::InjectionFinder::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
injection_point_ = nullptr;
return Status::OK();
}
Status EpochInjectionPass::InjectionFinder::RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) {
// Assumption: There is only one DeviceQueueOp in a pipeline. This assumption is not validated here.
injection_point_ = node->child(0);
return Status::OK();
}
// constructor
InjectionPass::InjectionPass() : epoch_ctrl_bypass_(false) {}
EpochInjectionPass::EpochInjectionPass() {}
// Runs an injection pass to inject operators needed at the pre pass stage
Status InjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) {
Status EpochInjectionPass::RunOnTree(ExecutionTree *tree, bool *modified) {
MS_LOG(INFO) << "Pre pass: Injection pass started.";
// First, run the finder to gather any injection info before we can go ahead to drive the op injection work.
// The finder can make updates to the InjectionPass object.
InjectionPass::InjectionFinder finder(this);
finder.Run(tree, modified);
// The finder can make updates to the EpochInjectionPass object.
EpochInjectionPass::InjectionFinder finder(tree->root());
RETURN_IF_NOT_OK(finder.Run(tree, modified));
// The first injection logic is to check if we should inject the epoch control op as the root node.
// Do not inject the op if the number of epochs is 1.
int32_t num_epochs = tree->num_epochs();
if (num_epochs != 1 && !epoch_ctrl_bypass_) {
std::shared_ptr<DatasetOp> epoch_inject_node = finder.injection_point();
if (num_epochs != 1 && epoch_inject_node != nullptr) {
std::shared_ptr<EpochCtrlOp> epoch_ctrl_op;
RETURN_IF_NOT_OK(EpochCtrlOp::Builder(num_epochs).Build(&epoch_ctrl_op));
RETURN_IF_NOT_OK(tree->AssociateNode(epoch_ctrl_op));
std::shared_ptr<DatasetOp> node = tree->root();
if (std::dynamic_pointer_cast<DeviceQueueOp>(node) == nullptr) {
tree->root()->InsertAsParent(epoch_ctrl_op);
} else {
tree->root()->child(0)->InsertAsParent(epoch_ctrl_op);
}
epoch_inject_node->InsertAsParent(epoch_ctrl_op);
}
MS_LOG(INFO) << "Pre pass: Injection pass complete.";
......
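The rewritten pass above drops the epoch_ctrl_bypass_ flag: the finder starts from the tree root, moves the injection point below a DeviceQueueOp, and clears it entirely when a BuildVocabOp, BuildSentencePieceVocabOp, or CacheOp is present. Here is a standalone toy sketch of that selection logic; the Op struct and FindInjectionPoint are illustrative stand-ins, not MindSpore's API.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Illustrative stand-in for a DatasetOp node.
struct Op {
  std::string name;
  std::vector<std::shared_ptr<Op>> children;
};

// Mirrors the finder: default to the root, step below a DeviceQueueOp,
// and return nullptr (no injection) when an op that forbids epoch control is found.
std::shared_ptr<Op> FindInjectionPoint(const std::shared_ptr<Op> &root) {
  std::shared_ptr<Op> point = root;
  std::vector<std::shared_ptr<Op>> stack = {root};
  while (!stack.empty()) {
    auto op = stack.back();
    stack.pop_back();
    if (op->name == "DeviceQueueOp" && !op->children.empty()) point = op->children[0];
    if (op->name == "BuildVocabOp" || op->name == "BuildSentencePieceVocabOp" || op->name == "CacheOp") return nullptr;
    for (const auto &child : op->children) stack.push_back(child);
  }
  return point;
}

int main() {
  auto map_op = std::make_shared<Op>(Op{"MapOp", {}});
  auto device_queue = std::make_shared<Op>(Op{"DeviceQueueOp", {map_op}});
  auto point = FindInjectionPoint(device_queue);
  // With num_epochs != 1, EpochCtrlOp would be inserted as the parent of this node,
  // i.e. between DeviceQueueOp and MapOp.
  std::cout << "inject above: " << (point ? point->name : "nowhere") << "\n";
  return 0;
}

With this selection in place, the pass itself only needs finder.injection_point() and InsertAsParent(), as shown in the diff above.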
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#define DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#ifndef DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
#define DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
#include <memory>
#include <vector>
......@@ -26,10 +26,10 @@ namespace dataset {
class DatasetOp;
/// \class InjectionPass injection_pass.h
/// \class EpochInjectionPass epoch_injection_pass.h
/// \brief This is a pre pass that drives the injection of any nodes that could not be directly injected from the api
/// parsing.
class InjectionPass : public TreePass {
class EpochInjectionPass : public TreePass {
/// \class InjectionFinder
/// \brief This is a nested node pass class who's job is to parse the tree and perform any identification logic for
/// operators that need to be injected. It is run first by the main injection pass to find out what operators
......@@ -37,7 +37,10 @@ class InjectionPass : public TreePass {
class InjectionFinder : public NodePass {
public:
/// \brief Constructor
explicit InjectionFinder(InjectionPass *injection_pass);
explicit InjectionFinder(std::shared_ptr<DatasetOp> node);
/// \brief Destructor
~InjectionFinder() = default;
/// \brief Performs finder work for BuildVocabOp that has special rules about epoch control injection.
/// \param[in] node The node being visited
......@@ -58,24 +61,30 @@ class InjectionPass : public TreePass {
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Register the DeviceQueueOp for further action.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified) override;
/// \brief Getter
std::shared_ptr<DatasetOp> injection_point() { return injection_point_; }
private:
InjectionPass *injection_pass_;
std::shared_ptr<DatasetOp> injection_point_;
};
public:
/// \brief Constructor
InjectionPass();
EpochInjectionPass();
/// \brief Runs an injection pass to inject operators needed at the pre pass stage
/// \param[inout] tree The tree to operate on.
/// \param[inout] modified Indicator if the tree was modified.
/// \return Status The error code return
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
private:
bool epoch_ctrl_bypass_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_PRE_INJECTION_PASS_H_
#endif // DATASET_ENGINE_OPT_PASS_PRE_EPOCH_INJECTION_PASS_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "minddata/dataset/engine/opt/pre/removal_nodes.h"
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
namespace mindspore {
namespace dataset {
RemovalNodes::RemovalNodes(RemovalPass *removal_pass) : removal_pass_(removal_pass), is_caching_(false) {}
// Identifies the subtree below this node as a cached descendant tree.
Status RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree.";
is_caching_ = true;
return Status::OK();
}
// Resets the tracking of the cache within the tree
Status RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Removal pass: cache descendant tree complete.";
is_caching_ = false;
return Status::OK();
}
// Perform ShuffleOp removal check.
Status RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
*modified = false;
// If we are in a cache descendant tree, then this shuffle op needs to be removed
if (is_caching_) {
MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
if (removal_pass_) {
removal_pass_->AddToRemovalList(std::static_pointer_cast<DatasetOp>(node));
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Back reference to removal pass is missing!");
}
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
#include <memory>
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
namespace mindspore {
namespace dataset {
/// \class RemovalNodes removal_nodes.h
/// \brief This is a NodePass whose job is to identify which nodes should be removed.
/// It works in conjunction with the removal_pass.
class RemovalNodes : public NodePass {
public:
/// \brief Constructor
/// \param[in] removal_pass Raw pointer back to controlling tree pass
explicit RemovalNodes(RemovalPass *removal_pass);
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Resets the tracking of the cache within the tree
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Destructor
~RemovalNodes() = default;
/// \brief Perform ShuffleOp removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
private:
bool is_caching_;
RemovalPass *removal_pass_; // Back pointer to the owning removal pass
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_PASS_PRE_REMOVAL_NODES_H_
......@@ -16,32 +16,58 @@
#include <vector>
#include <algorithm>
#include "minddata/dataset/engine/opt/pre/removal_nodes.h"
#include "minddata/dataset/engine/opt/pre/removal_pass.h"
#include "minddata/dataset/engine/datasetops/shuffle_op.h"
#include "minddata/dataset/engine/execution_tree.h"
namespace mindspore {
namespace dataset {
RemovalPass::RemovalNodes::RemovalNodes() : is_caching_(false) {}
// Identifies the subtree below this node as a cached descendant tree.
Status RemovalPass::RemovalNodes::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Removal pass: CacheOp found, identified descendant tree.";
is_caching_ = true;
return Status::OK();
}
// Resets the tracking of the cache within the tree
Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
*modified = false;
MS_LOG(INFO) << "Removal pass: cache descendant tree complete.";
is_caching_ = false;
return Status::OK();
}
// Perform ShuffleOp removal check.
Status RemovalPass::RemovalNodes::RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) {
*modified = false;
// If we are in a cache descendant tree, then this shuffle op needs to be removed
if (is_caching_) {
MS_LOG(INFO) << "ShuffleOp identified for removal (CacheOp is in ascendant tree)";
nodes_to_remove_.push_back(std::static_pointer_cast<DatasetOp>(node));
}
return Status::OK();
}
// constructor
RemovalPass::RemovalPass() {}
// Runs a removal_nodes pass first to find out which nodes to remove, then removes them.
// Walks the tree to collect the nodes to remove, then removes them.
Status RemovalPass::RunOnTree(ExecutionTree *tree, bool *modified) {
MS_LOG(INFO) << "Pre pass: removal pass started.";
// Create the removal node pass which can identify which nodes need to be removed.
std::unique_ptr<Pass> removal_nodes = std::make_unique<RemovalNodes>(this);
std::unique_ptr<RemovalPass::RemovalNodes> removal_nodes = std::make_unique<RemovalPass::RemovalNodes>();
RETURN_IF_NOT_OK(removal_nodes->Run(tree, modified));
// Then, execute the removal of any nodes that were set up for removal
for (auto node : removal_nodes_) {
for (auto node : removal_nodes->nodes_to_remove()) {
node->Remove();
}
MS_LOG(INFO) << "Pre pass: removal pass complete.";
return Status::OK();
}
// Adds an operator to the list of operators to be removed
void RemovalPass::AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op) { removal_nodes_.push_back(dataset_op); }
} // namespace dataset
} // namespace mindspore
......@@ -30,6 +30,45 @@ class DatasetOp;
/// \brief This is a tree pass that will remove nodes. It uses removal_nodes to first identify which
/// nodes should be removed, and then removes them.
class RemovalPass : public TreePass {
/// \class RemovalNodes
/// \brief This is a NodePass whose job is to identify which nodes should be removed.
/// It works in conjunction with the removal_pass.
class RemovalNodes : public NodePass {
public:
/// \brief Constructor
RemovalNodes();
/// \brief Destructor
~RemovalNodes() = default;
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Resets the tracking of the cache within the tree
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Perform ShuffleOp removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<ShuffleOp> node, bool *modified) override;
/// \brief Getter
/// \return All the nodes to be removed
std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove() { return nodes_to_remove_; }
private:
bool is_caching_;
std::vector<std::shared_ptr<DatasetOp>> nodes_to_remove_;
};
public:
/// \brief Constructor
RemovalPass();
......@@ -42,13 +81,6 @@ class RemovalPass : public TreePass {
/// \param[inout] modified Indicator if the tree was modified.
/// \return Status The error code return
Status RunOnTree(ExecutionTree *tree, bool *modified) override;
/// \brief Adds an operator to the list of operators to be removed
/// \param[in] dataset_op The operator to add to the removal list
void AddToRemovalList(std::shared_ptr<DatasetOp> dataset_op);
private:
std::vector<std::shared_ptr<DatasetOp>> removal_nodes_;
};
} // namespace dataset
} // namespace mindspore
......
......@@ -189,7 +189,7 @@ def test_minddataset_invalidate_num_shards():
num_iter = 0
for _ in data_set.create_dict_iterator():
num_iter += 1
assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info)
assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value)
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
......@@ -203,7 +203,7 @@ def test_minddataset_invalidate_shard_id():
num_iter = 0
for _ in data_set.create_dict_iterator():
num_iter += 1
assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info)
assert 'Input shard_id is not within the required interval of (0 to 0).' in str(error_info.value)
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
......@@ -217,14 +217,14 @@ def test_minddataset_shard_id_bigger_than_num_shard():
num_iter = 0
for _ in data_set.create_dict_iterator():
num_iter += 1
assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info)
assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value)
with pytest.raises(Exception) as error_info:
data_set = ds.MindDataset(CV_FILE_NAME, columns_list, num_readers, True, 2, 5)
num_iter = 0
for _ in data_set.create_dict_iterator():
num_iter += 1
assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info)
assert 'Input shard_id is not within the required interval of (0 to 1).' in str(error_info.value)
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))
......@@ -245,7 +245,7 @@ def test_cv_minddataset_partition_num_samples_equals_0():
num_iter += 1
with pytest.raises(Exception) as error_info:
partitions(5)
assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info)
assert 'num_samples should be a positive integer value, but got num_samples=0' in str(error_info.value)
os.remove(CV_FILE_NAME)
os.remove("{}.db".format(CV_FILE_NAME))