Commit eadcb341 authored by M mindspore-ci-bot, committed by Gitee

!2891 CacheOp phase 1

Merge pull request !2891 from Jamie/CacheOp_dev
......@@ -47,6 +47,8 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/dataset/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${CMAKE_BINARY_DIR})
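# Note: de_tensor.fbs is the flatbuffers schema shared by the cache client and server; the
# generated_engine_files target produces the de_tensor_generated.h header that the cache code includes.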
################## Include sub-modules ###############################
add_subdirectory(util)
add_subdirectory(core)
......@@ -55,7 +57,7 @@ add_subdirectory(engine)
add_subdirectory(api)
add_subdirectory(text)
######################################################################
add_dependencies(core utils)
add_dependencies(utils core)
add_dependencies(kernels-image core)
add_dependencies(kernels-data core)
add_dependencies(kernels core)
......@@ -89,6 +91,8 @@ set(submodules
$<TARGET_OBJECTS:engine-perf>
$<TARGET_OBJECTS:engine-datasetops>
$<TARGET_OBJECTS:engine-opt>
$<TARGET_OBJECTS:engine-cache-client>
$<TARGET_OBJECTS:engine-cache-server>
$<TARGET_OBJECTS:engine>
$<TARGET_OBJECTS:text>
$<TARGET_OBJECTS:text-kernels>
......@@ -106,6 +110,8 @@ else ()
add_library(_c_dataengine SHARED ${submodules})
endif ()
add_dependencies(_c_dataengine generated_engine_files)
set_target_properties(_c_dataengine PROPERTIES
PREFIX "${PYTHON_MODULE_PREFIX}"
SUFFIX "${PYTHON_MODULE_EXTENSION}"
......
......@@ -21,8 +21,10 @@
#include "common/utils.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/engine/datasetops/bucket_batch_by_length_op.h"
#include "dataset/engine/datasetops/cache_op.h"
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
......@@ -34,6 +36,7 @@
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/kernels/py_func_op.h"
#include "dataset/util/random.h"
#include "dataset/util/status.h"
......@@ -441,6 +444,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
MapOp::Builder map_builder;
std::vector<std::shared_ptr<TensorOp>> tensor_op_list;
std::vector<std::string> project_columns;
std::shared_ptr<CacheClient> cache_client = nullptr;
int num_workers = 0;
if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set. \n");
......@@ -456,7 +461,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
} else if (key == "columns_order") {
project_columns = ToStringVector(value);
} else if (key == "num_parallel_workers") {
(void)map_builder.SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)map_builder.SetNumWorkers(num_workers);
} else if (key == "prefetch_size") {
(void)map_builder.SetOpConnectorSize(ToInt(value));
} else if (key == "operations") {
......@@ -477,6 +483,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
}
if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set.");
(void)map_builder.SetTensorFuncs(std::move(tensor_op_list));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else {
RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
}
......@@ -499,6 +507,15 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
*bottom = map_op;
}
// Additionally, add a cache if required. This will go over top of the project op if one
// was created, otherwise it goes over top of the map op.
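// For example, if both a columns_order project and a cache were requested, the subtree returned
// to the caller is (illustrative):
//
//     CacheOp      <-- *top
//        |
//    ProjectOp
//        |
//      MapOp       <-- *bottom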
if (cache_client) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, *top, &cache_op));
*top = cache_op;
*bottom = map_op;
}
return Status::OK();
}
......@@ -809,6 +826,9 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
std::shared_ptr<DatasetOp> *bottom) {
// Required arguments
std::vector<std::string> files_list;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
int num_workers = 0;
std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
if (!args["dataset_files"].is_none()) {
files_list = ToStringVector(args["dataset_files"]);
......@@ -828,7 +848,8 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "columns_list") {
columns_to_load = ToStringVector(value);
(void)builder->SetColumnsToLoad(columns_to_load);
......@@ -848,6 +869,11 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
(void)builder->SetDeviceId(ToInt(value));
} else if (key == "shard_equal_rows") {
(void)builder->SetShardEqualRows(ToBool(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
}
}
}
......@@ -860,12 +886,27 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
}
(void)builder->SetDataSchema(std::move(schema));
}
// If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
// because TFReaderOp is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we save the sampler here in a leaf node that does not use sampling.
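// If the user asked for a cache but did not give a sampler, a default SequentialSampler is created
// below so that the caching layer still has a sampler it can inherit.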
if (sampler) {
(void)builder->SetSampler(std::move(sampler));
} else if (cache_client) {
int64_t num_samples = 0;
int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
(void)builder->SetSampler(std::move(sampler));
}
std::shared_ptr<TFReaderOp> tf_op;
RETURN_IF_NOT_OK(builder->Build(&tf_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op));
*top = tf_op;
if (shuffle_required) {
if (!cache_client && shuffle_required) {
const bool estimate = true;
const int64_t workers = 8;
std::shared_ptr<DatasetOp> shuffle_op = nullptr;
......@@ -882,6 +923,15 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
*bottom = tf_op;
}
// Add a cache op over this op if required and update the output subtree (top/bottom)
if (cache_client) {
// Note, it is not allowed to have both shuffle and cache
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, tf_op, &cache_op));
*top = cache_op;
*bottom = tf_op;
}
return Status::OK();
}
......@@ -906,6 +956,8 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
std::string err_msg = "Error: No dataset path specified";
RETURN_STATUS_UNEXPECTED(err_msg);
}
int num_workers = 0;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<ImageFolderOp::Builder> builder = std::make_shared<ImageFolderOp::Builder>();
(void)builder->SetImageFolderDir(ToString(args["dataset_dir"]));
......@@ -915,7 +967,8 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
py::handle value = arg.second;
if (!value.is_none()) {
if (key == "num_parallel_workers") {
(void)builder->SetNumWorkers(ToInt(value));
num_workers = ToInt(value);
(void)builder->SetNumWorkers(num_workers);
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
......@@ -926,12 +979,27 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
(void)builder->SetClassIndex(ToStringMap(value));
} else if (key == "decode") {
(void)builder->SetDecode(ToBool(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
}
}
}
std::shared_ptr<ImageFolderOp> op;
RETURN_IF_NOT_OK(builder->Build(&op));
*top = op;
std::shared_ptr<ImageFolderOp> if_op;
RETURN_IF_NOT_OK(builder->Build(&if_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(if_op));
*top = if_op;
// Additionally, add a cache if required.
// Note that this cache op is only acting as a placeholder for the caching position
// within the tree. Later, a pre-pass will execute a tree transform to set up the actual
// caching logic in the tree.
if (cache_client) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, if_op, &cache_op));
*top = cache_op;
*bottom = if_op;
}
return Status::OK();
}
......@@ -1130,9 +1198,12 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
std::shared_ptr<DatasetOp> *bottom) {
// Required arguments
RandomDataOp::Builder builder;
std::shared_ptr<CacheClient> cache_client = nullptr;
std::shared_ptr<Sampler> sampler = nullptr;
int num_workers = 0;
if (args["num_samples"].is_none()) {
std::string err_msg = "Error: num_samples is a required argument";
if (args["total_rows"].is_none()) {
std::string err_msg = "Error: total_rows is a required argument";
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::vector<std::string> columns_to_load;
......@@ -1141,16 +1212,23 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
for (auto arg : args) {
std::string key = py::str(arg.first);
py::handle value = arg.second;
if (key == "num_parallel_workers") {
(void)builder.SetNumWorkers(ToInt(value));
} else if (key == "schema_file_path" || key == "schema_json_string") {
schema_exists = true;
} else if (key == "columns_list") {
columns_to_load = ToStringVector(value);
} else if (key == "num_samples") {
// This is not sampling here. The random data op needs to know how much data to
// generate. It does not currently support sampling.
(void)builder.SetTotalRows(ToInt(value));
if (!value.is_none()) {
if (key == "num_parallel_workers") {
num_workers = ToInt(value);
(void)builder.SetNumWorkers(num_workers);
} else if (key == "schema_file_path" || key == "schema_json_string") {
schema_exists = true;
} else if (key == "columns_list") {
columns_to_load = ToStringVector(value);
} else if (key == "total_rows") {
// This is not sampling here. The random data op needs to know how much data to generate.
(void)builder.SetTotalRows(ToInt(value));
} else if (key == "cache") {
cache_client = value.cast<std::shared_ptr<CacheClient>>();
} else if (key == "sampler") {
auto create = py::reinterpret_borrow<py::object>(value).attr("create");
sampler = create().cast<std::shared_ptr<Sampler>>();
}
}
}
if (schema_exists) {
......@@ -1162,9 +1240,34 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
}
(void)builder.SetDataSchema(std::move(schema));
}
std::shared_ptr<RandomDataOp> op;
RETURN_IF_NOT_OK(builder.Build(&op));
*top = op;
// If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
// because RandomDataOp is a non-mappable dataset that does not support sampling.
// However, if a cache operator is injected at some other place higher in the tree, that cache can
// inherit this sampler from the leaf, providing sampling support from the caching layer.
// That is why we save the sampler here in a leaf node that does not use sampling.
if (sampler) {
(void)builder.SetSampler(std::move(sampler));
} else if (cache_client) {
int64_t num_samples = 0;
int64_t start_index = 0;
sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
(void)builder.SetSampler(std::move(sampler));
}
std::shared_ptr<RandomDataOp> random_op = nullptr;
RETURN_IF_NOT_OK(builder.Build(&random_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(random_op));
*top = random_op;
// Add a cache op over this op if required and update the output subtree (top/bottom)
if (cache_client) {
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, random_op, &cache_op));
*top = cache_op;
*bottom = random_op;
}
return Status::OK();
}
......@@ -1425,6 +1528,31 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
return Status::OK();
}
// Helper function to inject the cache operator over top of the current operation being built.
Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers,
std::shared_ptr<DatasetOp> input_op, std::shared_ptr<DatasetOp> *cache_op) {
std::shared_ptr<CacheOp> new_cache_op = nullptr;
CacheOp::Builder cache_builder;
// Use the same number of workers as the leaf. Some optimization is needed here; the user does not
// give the cache op a number of workers directly.
if (num_workers != 0) {
(void)cache_builder.SetNumWorkers(num_workers);
}
(void)cache_builder.SetClient(cache_client);
RETURN_IF_NOT_OK(cache_builder.Build(&new_cache_op));
RETURN_IF_NOT_OK(tree_->AssociateNode(new_cache_op));
RETURN_IF_NOT_OK(new_cache_op->AddChild(input_op));
// We have now created:
//
// CacheOp
// |
// input_op
//
*cache_op = new_cache_op;
return Status::OK();
}
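// Illustrative call pattern, matching the parse functions above:
//   std::shared_ptr<DatasetOp> cache_op = nullptr;
//   RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, leaf_op, &cache_op));
//   *top = cache_op;
//   *bottom = leaf_op;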
// Helper function to inject a shuffle operator over top of the current operation being built.
Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
std::shared_ptr<DatasetOp> *shuffle_op) {
......
......@@ -35,6 +35,8 @@ namespace mindspore {
namespace dataset {
using DsOpPtr = std::shared_ptr<DatasetOp>;
class CacheClient;
// enum for the dataset operator names
enum OpName {
kShuffle,
......@@ -181,6 +183,16 @@ class DEPipeline {
static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
/// \brief Helper function to inject a cache operator over top of the current operation being built.
/// \param[in] cache_client The client to use for caching
/// \param[in] num_workers The number of workers to use in the cache op
/// \param[in] input_op The operator to build the cache on top of
/// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be
/// the cache operator
/// \return Status return code
Status AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers, std::shared_ptr<DatasetOp> input_op,
std::shared_ptr<DatasetOp> *cache_op);
/// \brief Helper function to inject a shuffle operator over top of the current operation being built.
/// \param[in] shuffle_size The size to use in the shuffle buffer
/// \param[in] input_op The operator to build shuffle on top of
......
......@@ -35,6 +35,7 @@
#include "dataset/engine/datasetops/source/text_file_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/gnn/graph.h"
#include "dataset/engine/jagged_connector.h"
#include "dataset/kernels/data/concatenate_op.h"
......@@ -768,6 +769,11 @@ void bindInfoObjects(py::module *m) {
.def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
}
void bindCacheClient(py::module *m) {
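// The init arguments correspond to CacheClient(session_id, cache_mem_sz, spill); see cache_client.h.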
(void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
.def(py::init<uint32_t, uint64_t, bool>());
}
void bindVocabObjects(py::module *m) {
(void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
.def(py::init<>())
......@@ -939,6 +945,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
bindSamplerOps(&m);
bindDatasetOps(&m);
bindInfoObjects(&m);
bindCacheClient(&m);
bindVocabObjects(&m);
bindGraphData(&m);
bindDependIcuTokenizerOps(&m);
......
......@@ -2,6 +2,7 @@ add_subdirectory(datasetops)
add_subdirectory(opt)
add_subdirectory(gnn)
add_subdirectory(perf)
add_subdirectory(cache)
if (ENABLE_TDTQUE)
add_subdirectory(tdt)
endif ()
......@@ -17,7 +18,9 @@ add_library(engine OBJECT
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
if (ENABLE_TDTQUE)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf)
else()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf
engine-cache-client engine-cache-server)
else ()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf
engine-cache-client engine-cache-server)
endif ()
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-cache-client OBJECT
cache_client.cc
cache_request.cc)
add_library(engine-cache-server OBJECT
cache_service.cc
cache_server.cc)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iomanip>
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/cache/cache_request.h"
#include "dataset/util/bit.h"
namespace mindspore {
namespace dataset {
// Constructor
CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill)
: server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {}
// Print method for displaying cache details
void CacheClient::Print(std::ostream &out) const {
out << " Session id: " << session_id_ << "\n Cache crc: " << cache_crc_
<< "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << cache_mem_sz_
<< "\n Spilling: " << std::boolalpha << spill_;
}
Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
CacheRowRequest rq(server_connection_id_, cookie());
RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row));
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
if (row_id_from_server != nullptr) {
*row_id_from_server = rq.GetRowIdAfterCache();
}
return Status::OK();
}
Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
std::unique_ptr<DataBuffer> db_ptr = std::move(in);
auto num_rows = db_ptr->NumRows();
std::vector<TensorRow> all_rows;
if (num_rows > 0) {
all_rows.reserve(num_rows);
// Break down the DataBuffer into TensorRows. We will send the requests asynchronously
// and then do a final wait.
MemGuard<CacheRowRequest> rq_arr;
RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie()));
CacheServer &cs = CacheServer::GetInstance();
for (auto i = 0; i < num_rows; ++i) {
TensorRow row;
auto rq = rq_arr[i];
RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row));
RETURN_IF_NOT_OK(cs.PushRequest(rq));
// We can't let row go out of scope. Otherwise it will free all the tensor memory.
// So park it in the vector. When this function goes out of scope, its memory
// will be freed.
all_rows.push_back(std::move(row));
}
// Now we wait for the requests to be done.
for (auto i = 0; i < num_rows; ++i) {
auto rq = rq_arr[i];
RETURN_IF_NOT_OK(rq->Wait());
}
}
return Status::OK();
}
Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
BatchFetchRequest rq(server_connection_id_, row_id);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
RETURN_IF_NOT_OK(rq.RestoreRows(out));
return Status::OK();
}
Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
UniqueLock lck(&mux_);
// To create a cache, we identify ourselves at the client by:
// - the shared session id
// - a crc for the tree nodes from the cache downward
// Pack these 2 into a single 64 bit request id
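// i.e. connection_id = (session_id << 32) | crc, as computed further below.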
//
// Consider this example:
// tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch
// tree2: cifar10 --> map(rotate) --> cache (session id = 1, crc = 456) --> batch
// These are different trees in a single session, but the user wants to share the cache.
// This is not allowed because the data of these caches are different.
//
// Consider this example:
// tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch
// tree2: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> map(rotate) --> batch
// These are different trees in the same session, but the cached data is the same, so it is okay
// to allow the sharing of this cache between these pipelines.
// The CRC is computed by the tree prepare phase and passed to this function when creating the cache.
// If we already have a server_connection_id_, then it means this same cache client has already been used
// to create a cache and some other tree is trying to use the same cache.
// That is allowed, however the crc better match!
if (server_connection_id_) {
if (cache_crc_ != tree_crc) {
RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!");
}
// Check the state of the server. For the non-mappable case, where there is a build phase and a fetch phase, we should
// skip the build phase.
lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock.
CacheClient::ServiceStat stat{};
RETURN_IF_NOT_OK(GetStat(&stat));
if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) {
return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
}
} else {
cache_crc_ = tree_crc; // It's really a new cache we're creating so save our crc in the client
// Combine the session and crc. This will form our client cache identifier.
connection_id_type connection_identification = (static_cast<uint64_t>(session_id_) << 32) | cache_crc_;
// Now execute the cache create request using this identifier and other configs
BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone;
if (spill_) {
createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk;
}
if (generate_id) {
createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId;
}
CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
Status rc = rq.Wait();
if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) {
server_connection_id_ = rq.GetServerConnectionId();
if (rc.IsOk()) {
// The first client to create the cache will get a cookie back.
// But this object may be shared among pipelines and we don't want to
// overwrite it.
cookie_ = rq.cookie();
}
}
// We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
// CacheOp to bypass the build phase.
return rc;
}
return Status::OK();
}
Status CacheClient::PurgeCache() {
UniqueLock lck(&mux_);
PurgeCacheRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
return rq.Wait();
}
Status CacheClient::DestroyCache() {
UniqueLock lck(&mux_);
DestroyCacheRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
return rq.Wait();
}
Status CacheClient::GetStat(ServiceStat *stat) {
SharedLock lck(&mux_);
RETURN_UNEXPECTED_IF_NULL(stat);
GetStatRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
stat->num_disk_cached = rq.GetNumDiskCached();
stat->num_mem_cached = rq.GetNumMemCached();
stat->min_row_id = rq.GetMinRowId();
stat->max_row_id = rq.GetMaxRowId();
stat->cache_service_state = rq.GetState();
return Status::OK();
}
Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
SharedLock lck(&mux_);
CacheSchemaRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map));
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
return Status::OK();
}
Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) {
SharedLock lck(&mux_);
RETURN_UNEXPECTED_IF_NULL(map);
FetchSchemaRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
*map = rq.GetColumnMap();
return Status::OK();
}
Status CacheClient::BuildPhaseDone() const {
SharedLock lck(&mux_);
BuildPhaseDoneRequest rq(server_connection_id_, cookie());
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_CLIENT_H_
#define DATASET_ENGINE_CACHE_CLIENT_H_
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "./de_tensor_generated.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/cache/cache_server.h"
#include "dataset/util/lock.h"
namespace mindspore {
namespace dataset {
/// \brief A CacheClient is a bridge between a DatasetOp and a CacheServer. All communications go through
/// a CacheClient. Typical tasks include creating a cache service, caching a data buffer, restoring previously
/// cached rows, etc.
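/// A typical (illustrative) sequence: CreateCache() once the tree crc is known, WriteRow()/WriteBuffer()
/// during the build phase, BuildPhaseDone() for non-mappable datasets, then GetRows() to read rows back.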
class CacheClient {
public:
/// \brief Constructor
/// \param session_id A user assigned session id for the current pipeline
/// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
/// \param spill Spill to disk if out of memory
CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill);
/// \brief Destructor
~CacheClient() = default;
/// \brief Getter function for returning the current session id
/// \return session id
uint64_t session_id() const { return session_id_; }
/// \brief Send a TensorRow to the cache server
/// \param[in] row
/// \param[out] row_id_from_server Optional. The row id assigned by the server for non-mappable dataset
/// \return return code
Status WriteRow(const TensorRow &row, row_id_type *row_id_from_server = nullptr) const;
/// \brief Send a DataBuffer to the cache server
/// \param in Unique pointer of the DataBuffer to be cached
/// \return return code
Status WriteBuffer(std::unique_ptr<DataBuffer> &&in) const;
/// \brief Fetch a list of rows from the cache server. An empty TensorRow will be returned if there is
/// any cache miss
/// \param row_id A vector of row id's
/// \param out A TensorTable of TensorRows.
/// \return return code
Status GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const;
/// \brief Create a cache.
/// \param tree_crc A crc that was generated during tree prepare phase
/// \param generate_id Let the cache service generate row id
/// \return Status object
Status CreateCache(uint32_t tree_crc, bool generate_id);
/// \brief Purge a cache. Cache can be reused after reset.
/// \return Status object
Status PurgeCache();
/// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused.
/// \return Status object
Status DestroyCache();
struct ServiceStat {
int64_t num_mem_cached;
int64_t num_disk_cached;
row_id_type min_row_id;
row_id_type max_row_id;
int8_t cache_service_state;
};
/// \brief Get the statistics from a cache.
/// \param[in/out] stat Pointer to a pre-allocated ServiceStat object
/// \return Status object
Status GetStat(ServiceStat *stat);
/// \brief Cache the schema at the cache server
/// \param map The unordered map of the schema
/// \return Status object
Status CacheSchema(const std::unordered_map<std::string, int32_t> &map);
/// \brief Fetch the schema from the cache server
/// \param map Pointer to pre-allocated map object
/// \return Status object.
Status FetchSchema(std::unordered_map<std::string, int32_t> *map);
/// \brief Change the state from build phase to read phase. Applicable to non-mappable dataset only. Only the cache
/// client that holds the cookie is allowed to make this request
/// \return Status object
Status BuildPhaseDone() const;
/// \brief A print method typically used for debugging
/// \param out The output stream to write output to
void Print(std::ostream &out) const;
/// \brief Stream output operator overload
/// \return The output stream
friend std::ostream &operator<<(std::ostream &out, const CacheClient &cc) {
cc.Print(out);
return out;
}
/// \brief Every cache service has a cookie which uniquely identifies the CacheClient that created it.
/// \return Cookie
std::string cookie() const { return cookie_; }
private:
mutable RWLock mux_;
uint64_t cache_mem_sz_;
bool spill_;
// The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow
// sharing of the cache.
uint32_t session_id_;
uint32_t cache_crc_;
// The server_connection_id_ is the actual id we use for operations after the cache is built
connection_id_type server_connection_id_;
// Some magic cookie returned from the cache server.
std::string cookie_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_CACHE_CLIENT_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/cache/cache_request.h"
namespace mindspore {
namespace dataset {
Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) {
buffers_.reserve(row.size() + 1);
RETURN_IF_NOT_OK(SerializeTensorRowHeader(row));
buffers_.push_back(fbb_->GetBufferPointer());
for (const auto &ts : row) {
buffers_.push_back(ts->GetBuffer());
}
return Status::OK();
}
Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) {
try {
fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
std::vector<int64_t> tensor_sz;
v.reserve(row.size());
tensor_sz.reserve(row.size());
// We will go through each column in the row.
for (const std::shared_ptr<Tensor> &ts_ptr : row) {
flatbuffers::Offset<TensorMetaMsg> ts_off;
RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off));
v.push_back(ts_off);
tensor_sz.push_back(ts_ptr->SizeInBytes());
}
auto column_off = fbb_->CreateVector(v);
auto data_sz_off = fbb_->CreateVector(tensor_sz);
TensorRowHeaderMsgBuilder row_builder(*fbb_);
row_builder.add_column(column_off);
row_builder.add_data_sz(data_sz_off);
// Pass the row_id even if it may not be known.
row_builder.add_row_id(row.getId());
row_builder.add_size_of_this(-1); // fill in later after we call Finish.
auto out = row_builder.Finish();
fbb_->Finish(out);
// Now go back to fill in size_of_this in the flat buffer.
auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer());
auto success = msg->mutate_size_of_this(fbb_->GetSize());
if (!success) {
RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
}
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr,
flatbuffers::Offset<TensorMetaMsg> *out_off) {
RETURN_UNEXPECTED_IF_NULL(out_off);
const Tensor *ts = ts_ptr.get();
auto shape_off = fbb_->CreateVector(ts->shape().AsVector());
const auto ptr = ts->GetBuffer();
if (ptr == nullptr) {
RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
}
auto src = ts->type().value();
TensorType dest;
#define CASE(t) \
case DataType::t: \
dest = TensorType::TensorType_##t; \
break
// Map the type to fill in the flat buffer.
switch (src) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
default:
MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
RETURN_STATUS_UNEXPECTED("Unknown type");
}
#undef CASE
TensorMetaMsgBuilder ts_builder(*fbb_);
ts_builder.add_dims(shape_off);
ts_builder.add_type(dest);
auto ts_off = ts_builder.Finish();
*out_off = ts_off;
return Status::OK();
}
Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data,
std::shared_ptr<Tensor> *out) {
RETURN_UNEXPECTED_IF_NULL(col_ts);
auto shape_in = col_ts->dims();
auto type_in = col_ts->type();
std::vector<dsize_t> v;
v.reserve(shape_in->size());
v.assign(shape_in->begin(), shape_in->end());
TensorShape shape(v);
DataType::Type dest = DataType::DE_UNKNOWN;
#define CASE(t) \
case TensorType_##t: \
dest = DataType::Type::t; \
break
switch (type_in) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
}
#undef CASE
DataType type(dest);
std::shared_ptr<Tensor> ts =
std::make_shared<Tensor>(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize());
// Next we restore the real data which can be embedded or stored separately.
if (ts->SizeInBytes() != data.GetSize()) {
MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
<< "Dumping tensor\n"
<< *ts << "\n";
RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
}
*out = std::move(ts);
return Status::OK();
}
Status BatchFetchRequest::RestoreRows(TensorTable *out) {
RETURN_UNEXPECTED_IF_NULL(out);
auto num_elements = row_id_.size();
auto *offset_array = reinterpret_cast<const int64_t *>(mem_.GetPointer());
TensorTable tbl;
tbl.reserve(num_elements);
ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes());
for (auto i = 0; i < num_elements; ++i) {
auto len = offset_array[i + 1] - offset_array[i];
TensorRow row;
row.setId(row_id_.at(i));
if (len > 0) {
ReadableSlice row_data(all, offset_array[i], len);
// Next we de-serialize flat buffer to get back each column
auto msg = GetTensorRowHeaderMsg(row_data.GetPointer());
auto msg_sz = msg->size_of_this();
// Start of the tensor data
auto ts_offset = msg_sz;
row.reserve(msg->column()->size());
for (auto k = 0; k < msg->column()->size(); ++k) {
auto col_ts = msg->column()->Get(k);
std::shared_ptr<Tensor> ts;
ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k));
RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts));
row.push_back(ts);
ts_offset += data.GetSize();
}
}
tbl.push_back(std::move(row));
}
*out = std::move(tbl);
return Status::OK();
}
Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
try {
fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
std::vector<flatbuffers::Offset<ColumnNameMsg>> v;
v.reserve(map.size());
for (auto &column : map) {
auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second);
v.push_back(c);
}
auto v_off = fbb_->CreateVector(v);
auto final_off = CreateSchemaMsg(*fbb_, v_off);
fbb_->Finish(final_off);
buf_ = fbb_->GetBufferPointer();
len_of_buf_ = fbb_->GetSize();
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() {
if (column_name_id_map_.empty()) {
auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(mem_.GetPointer());
auto v = map_msg->column();
for (auto i = 0; i < v->size(); ++i) {
auto col = map_msg->column()->Get(i);
column_name_id_map_.emplace(col->name()->str(), col->id());
}
}
return column_name_id_map_;
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_REQ_H_
#define DATASET_ENGINE_CACHE_REQ_H_
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "./de_tensor_generated.h"
#include "dataset/core/tensor_row.h"
#include "dataset/util/slice.h"
#include "dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
/// \brief CacheClient communicates with CacheServer using Requests.
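/// A request is constructed by the client, handed to the server via CacheServer::PushRequest(), and the
/// caller then blocks on Wait() until a server thread fills in the result status and posts completion.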
class BaseRequest {
public:
// Request types
enum class RequestType : int16_t {
kCacheRow = 0,
kBatchFetchRows = 1,
kCreateCache = 2,
kPurgeCache = 3,
kDestroyCache = 4,
kGetStat = 5,
kCacheSchema = 6,
kFetchSchema = 7,
kBuildPhaseDone = 8,
// Add new request before it.
kRequestUnknown = 32767
};
// For kCreateCache
enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L };
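// These flags form a bitmask and may be combined, e.g. kSpillToDisk | kGenerateRowId
// (see CacheClient::CreateCache).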
friend class CacheServer;
/// \brief Base class of a cache server request
/// \param connection_id A combination of session id and crc that uniquely identifies a connection.
/// \param type Type of the request
explicit BaseRequest(connection_id_type connection_id, RequestType type)
: type_(type), connection_id_(connection_id) {}
virtual ~BaseRequest() = default;
/// \brief Wait for the completion of a request
/// \return Status returned from the cache server
Status Wait() {
RETURN_IF_NOT_OK(wp_.Wait());
return rc_;
}
/// \brief Getter function of the current connection id
/// \return Connection id
connection_id_type GetServerConnectionId() const { return connection_id_; }
private:
RequestType type_;
connection_id_type connection_id_;
Status rc_;
WaitPost wp_;
};
/// \brief Request to cache a single TensorRow
class CacheRowRequest : public BaseRequest {
public:
friend class CacheServer;
explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie)
: BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {}
~CacheRowRequest() = default;
/// \brief Serialize a TensorRow for streaming to the cache server
/// \param row TensorRow
/// \return Status object
Status SerializeCacheRowRequest(const TensorRow &row);
/// \brief Return the row id assigned to this row for non-mappable dataset
/// \return row id of the cached row
row_id_type GetRowIdAfterCache() { return row_id_from_server_; }
private:
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
row_id_type row_id_from_server_;
std::vector<const void *> buffers_;
std::string cookie_;
/// \brief Private function to serialize one TensorRow
/// \param row TensorRow
/// \return Status object
Status SerializeTensorRowHeader(const TensorRow &row);
/// \brief Private function to serialize one Tensor
/// \param ts_ptr Tensor
/// \return Status object
Status SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off);
};
/// \brief Request to fetch rows in batch
class BatchFetchRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheService;
BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id)
: BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {}
Status RestoreRows(TensorTable *out);
private:
std::vector<row_id_type> row_id_;
MemGuard<uint8_t> mem_;
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
};
/// \brief Request to create a cache for the current connection
class CreationCacheRequest : public BaseRequest {
public:
friend class CacheServer;
/// \brief Constructor
/// \param connection_id
/// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited
/// \param flag Attributes of the cache.
explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz,
CreateCacheFlag flag = CreateCacheFlag::kNone)
: BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {}
std::string cookie() const { return cookie_; }
private:
uint64_t cache_mem_sz;
CreateCacheFlag flag_;
std::string cookie_;
};
/// \brief Request to purge a cache.
class PurgeCacheRequest : public BaseRequest {
public:
friend class CacheServer;
explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {}
};
/// \brief Request to destroy a cache
class DestroyCacheRequest : public BaseRequest {
public:
friend class CacheServer;
explicit DestroyCacheRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kDestroyCache) {}
};
/// \brief Obtain the statistics of the current connection
class GetStatRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheService;
explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {}
row_id_type GetMinRowId() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->min_row_id();
}
row_id_type GetMaxRowId() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->max_row_id();
}
int64_t GetNumMemCached() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->num_mem_cached();
}
int64_t GetNumDiskCached() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->num_disk_cached();
}
uint8_t GetState() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->state();
}
private:
MemGuard<uint8_t> mem_;
};
/// \brief Request to cache a schema
class CacheSchemaRequest : public BaseRequest {
public:
friend class CacheServer;
explicit CacheSchemaRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {}
~CacheSchemaRequest() = default;
Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map);
const void *GetBuffer() const { return buf_; }
private:
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
const void *buf_;
int64_t len_of_buf_;
};
/// \brief Request to fetch a schema
class FetchSchemaRequest : public BaseRequest {
public:
friend class CacheServer;
explicit FetchSchemaRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kFetchSchema) {}
~FetchSchemaRequest() = default;
std::unordered_map<std::string, int32_t> GetColumnMap();
private:
MemGuard<uint8_t> mem_;
std::unordered_map<std::string, int32_t> column_name_id_map_;
};
/// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only.
class BuildPhaseDoneRequest : public BaseRequest {
public:
friend class CacheServer;
BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie)
: BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {}
private:
std::string cookie_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_CACHE_REQ_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/cache/cache_server.h"
#include "dataset/engine/cache/cache_service.h"
#include "dataset/engine/cache/cache_request.h"
#include "dataset/util/bit.h"
namespace mindspore {
namespace dataset {
Status CacheServer::DoServiceStart() {
if (!top_.empty()) {
Path spill(top_);
RETURN_IF_NOT_OK(spill.CreateDirectories());
MS_LOG(INFO) << "CacheServer will use disk folder: " << top_;
}
RETURN_IF_NOT_OK(vg_.ServiceStart());
cache_q_ = std::make_shared<Queue<BaseRequest *>>(1024);
RETURN_IF_NOT_OK(cache_q_->Register(&vg_));
auto f = std::bind(&CacheServer::ServerRequest, this);
// Spawn a few threads to serve the requests.
for (auto i = 0; i < num_workers_; ++i) {
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache server", f));
}
return Status::OK();
}
Status CacheServer::DoServiceStop() {
Status rc;
Status rc2;
// First stop all the threads.
RETURN_IF_NOT_OK(vg_.ServiceStop());
// Clean up all the caches if any.
UniqueLock lck(&rwLock_);
auto it = all_caches_.begin();
while (it != all_caches_.end()) {
auto cs = std::move(it->second);
rc2 = cs->ServiceStop();
if (rc2.IsError()) {
rc = rc2;
}
++it;
}
return rc;
}
CacheService *CacheServer::GetService(connection_id_type id) const {
SharedLock lck(&rwLock_);
auto it = all_caches_.find(id);
if (it != all_caches_.end()) {
return it->second.get();
}
return nullptr;
}
Status CacheServer::CreateService(connection_id_type connection_id, uint64_t cache_mem_sz,
BaseRequest::CreateCacheFlag flag, std::string *out_cookie) {
// We can't do spilling unless this server is setup with a spill path in the first place
bool spill = (flag & BaseRequest::CreateCacheFlag::kSpillToDisk) == BaseRequest::CreateCacheFlag::kSpillToDisk;
bool generate_id =
(flag & BaseRequest::CreateCacheFlag::kGenerateRowId) == BaseRequest::CreateCacheFlag::kGenerateRowId;
if (spill && top_.empty()) {
RETURN_STATUS_UNEXPECTED("Server is not set up with spill support.");
}
RETURN_UNEXPECTED_IF_NULL(out_cookie);
*out_cookie = "";
// Before creating the cache, first check if this is a request for shared usage of an existing cache.
// If two CreateService requests come in with identical connection_id, we need to serialize the creates.
// The first create will succeed and be given a special cookie.
UniqueLock lck(&rwLock_);
auto end = all_caches_.end();
auto it = all_caches_.find(connection_id);
if (it == end) {
std::unique_ptr<CacheService> cs;
try {
cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id);
RETURN_IF_NOT_OK(cs->ServiceStart());
*out_cookie = cs->cookie();
all_caches_.emplace(connection_id, std::move(cs));
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
}
} else {
MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service";
// We could return OK, but we return a duplicate key instead so the user can act accordingly:
// either ignore it or treat it as OK.
return Status(StatusCode::kDuplicateKey);
}
return Status::OK();
}
/// This is the main loop the cache server thread(s) are running.
/// Each thread will pop a request and save the result in the same request.
/// The sender will wait on the wait post in the request. Once the request
/// is fulfilled, the server thread will do a post signalling that the request
/// has been processed.
/// \return Status object
Status CacheServer::ServerRequest() {
TaskManager::FindMe()->Post();
// Loop forever until we are interrupted.
while (true) {
BaseRequest *base_rq = nullptr;
RETURN_IF_NOT_OK(cache_q_->PopFront(&base_rq));
auto cs = GetService(base_rq->connection_id_);
// Except when creating a new cache, we expect cs not to be null.
switch (base_rq->type_) {
case BaseRequest::RequestType::kCacheRow: {
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<CacheRowRequest *>(base_rq);
// Only if the cookie matches can we accept an insert into a cache that has a build phase
if (!cs->HasBuildPhase() || rq->cookie_ == cs->cookie()) {
rq->rc_ = cs->CacheRow(rq->buffers_, &rq->row_id_from_server_);
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
}
break;
}
case BaseRequest::RequestType::kBatchFetchRows: {
if (cs == nullptr) {
std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<BatchFetchRequest *>(base_rq);
rq->rc_ = cs->BatchFetch(rq->row_id_, &rq->mem_);
}
break;
}
case BaseRequest::RequestType::kCreateCache: {
// If the cache has already been created we still need to run the creation so that we do sanity checks on the
// client id and return the cache id back to the user.
auto *rq = reinterpret_cast<CreationCacheRequest *>(base_rq);
rq->rc_ = CreateService(rq->connection_id_, rq->cache_mem_sz, rq->flag_, &rq->cookie_);
break;
}
case BaseRequest::RequestType::kPurgeCache: {
if (cs != nullptr) {
base_rq->rc_ = cs->Purge();
} else {
// it is already purged. Ignore it.
base_rq->rc_ = Status::OK();
}
break;
}
case BaseRequest::RequestType::kDestroyCache: {
if (cs != nullptr) {
// We need a strong lock to protect the map.
connection_id_type id = base_rq->connection_id_;
UniqueLock lck(&rwLock_);
// Erasing from the std::map will invoke the destructor of CacheService, so we don't need to do anything else here.
auto n = all_caches_.erase(id);
if (n == 0) {
// It has been destroyed by another duplicate request.
MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to create cache service";
}
base_rq->rc_ = Status::OK();
} else {
// it is already destroyed. Ignore it.
base_rq->rc_ = Status::OK();
}
break;
}
case BaseRequest::RequestType::kGetStat: {
if (cs == nullptr) {
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<GetStatRequest *>(base_rq);
CacheService::ServiceStat svc_stat;
rq->rc_ = cs->GetStat(&svc_stat);
if (rq->rc_.IsOk()) {
flatbuffers::FlatBufferBuilder fbb;
ServiceStatMsgBuilder bld(fbb);
bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
bld.add_max_row_id(svc_stat.max_);
bld.add_min_row_id(svc_stat.min_);
bld.add_state(svc_stat.state_);
auto offset = bld.Finish();
fbb.Finish(offset);
rq->rc_ = rq->mem_.allocate(fbb.GetSize());
if (rq->rc_.IsOk()) {
WritableSlice dest(rq->mem_.GetMutablePointer(), fbb.GetSize());
ReadableSlice src(fbb.GetBufferPointer(), fbb.GetSize());
RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src));
}
}
}
break;
}
case BaseRequest::RequestType::kCacheSchema: {
if (cs == nullptr) {
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<CacheSchemaRequest *>(base_rq);
rq->rc_ = cs->CacheSchema(rq->buf_, rq->len_of_buf_);
}
break;
}
case BaseRequest::RequestType::kFetchSchema: {
if (cs == nullptr) {
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<FetchSchemaRequest *>(base_rq);
rq->rc_ = cs->FetchSchema(&rq->mem_);
}
break;
}
case BaseRequest::RequestType::kBuildPhaseDone: {
if (cs == nullptr) {
std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
} else {
auto *rq = reinterpret_cast<BuildPhaseDoneRequest *>(base_rq);
// We can only allow the phase switch if the cookie matches.
if (rq->cookie_ == cs->cookie()) {
rq->rc_ = cs->BuildPhaseDone();
} else {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
}
}
break;
}
default:
base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unknown request type");
}
// Notify it is done, and move on to the next request.
base_rq->wp_.Set();
}
return Status::OK();
}
CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers)
: top_(spill_path), num_workers_(num_workers) {}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_SERVER_H_
#define DATASET_ENGINE_CACHE_SERVER_H_
#include <algorithm>
#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <map>
#include "dataset/engine/cache/cache_service.h"
#include "dataset/core/tensor.h"
#include "dataset/util/arena.h"
#include "dataset/util/cache_pool.h"
#include "dataset/util/lock.h"
#include "dataset/util/service.h"
#include "dataset/util/services.h"
#include "dataset/util/system_pool.h"
#include "dataset/util/queue.h"
#include "dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
class BaseRequest;
/// \brief A server which provides CacheService services.
class CacheServer : public Service {
public:
friend class Services;
using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;
CacheServer(const CacheServer &) = delete;
CacheServer &operator=(const CacheServer &) = delete;
CacheServer(CacheServer &&) = delete;
CacheServer &operator=(CacheServer &) = delete;
static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); }
Status DoServiceStart() override;
Status DoServiceStop() override;
~CacheServer() { (void)ServiceStop(); }
/// \brief For the current demonstration, a cache client contacts the cache server using a Queue.
/// \param rq The request to be pushed to the server
/// \return Status object
Status PushRequest(BaseRequest *rq) {
RETURN_UNEXPECTED_IF_NULL(rq);
RETURN_IF_NOT_OK(cache_q_->Add(rq));
return Status::OK();
}
private:
mutable RWLock rwLock_;
std::string top_;
cache_index all_caches_;
std::shared_ptr<Queue<BaseRequest *>> cache_q_;
TaskGroup vg_;
int32_t num_workers_;
/// \brief Constructor
/// \param spill_path Top directory for spilling buffers to.
/// \param num_workers Number of threads for handling requests.
explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3);
/// \brief Locate a cache service from connection id.
/// \return Pointer to cache service. Null if not found
CacheService *GetService(connection_id_type id) const;
/// \brief Create a cache service. We allow multiple clients to create the same cache service.
/// Subsequent duplicate requests are ignored. The first cache client to create the service will be given
/// a special unique cookie.
/// \param[in] connection_id This is from a Cache client.
/// \param[in] cache_mem_sz
/// \param[in] flag
/// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator
/// \return Status object
Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag,
std::string *out_cookie);
/// \brief Entry point for all server threads.
Status ServerRequest();
};
} // namespace dataset
} // namespace mindspore
#endif  // DATASET_ENGINE_CACHE_SERVER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/cache/cache_service.h"
#include "dataset/util/slice.h"
namespace mindspore {
namespace dataset {
CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id)
: root_(root),
cache_mem_sz_(mem_sz),
cp_(nullptr),
map_(nullptr),
next_id_(0),
generate_id_(generate_id),
schema_key_(-1),
st_(generate_id ? State::kBuildPhase : State::kNone) {}
CacheService::~CacheService() { (void)ServiceStop(); }
bool CacheService::UseArena() {
// If fixed size, use Arena instead of the pool from global context.
return (cache_mem_sz_ > 0);
}
Status CacheService::DoServiceStart() {
std::shared_ptr<MemoryPool> mp_;
if (UseArena()) {
// Create a fixed size arena based on the parameter.
std::shared_ptr<Arena> arena;
RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_));
mp_ = std::move(arena);
} else {
// Unlimited size. Simply use a system pool. Another choice is CircularPool.
mp_ = std::make_shared<SystemPool>();
}
// Put together a CachePool for backing up the Tensor
cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp_), root_);
RETURN_IF_NOT_OK(cp_->ServiceStart());
// Set up the B+ tree as well. But use the system pool instead.
map_ = std::make_shared<row_map>();
// Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name.
cookie_ = cp_->MyName();
return Status::OK();
}
Status CacheService::DoServiceStop() {
if (cp_ != nullptr) {
RETURN_IF_NOT_OK(cp_->ServiceStop());
}
return Status::OK();
}
Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(row_id_generated);
if (st_ == State::kFetchPhase) {
// For this kind of cache service, once we move from the build phase into the fetch phase, we can't
// allow others to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
try {
// The first buffer is a flatbuffer which describes the buffers that follow
auto fb = buf.front();
RETURN_UNEXPECTED_IF_NULL(fb);
auto msg = GetTensorRowHeaderMsg(fb);
// If the server side is designed to ignore incoming row id, we generate row id.
if (generate_id_) {
*row_id_generated = GetNextRowId();
// Some debug information on how many rows we have generated so far.
if ((*row_id_generated) % 1000 == 0) {
MS_LOG(DEBUG) << "Number of rows cached: " << *row_id_generated;
}
} else {
if (msg->row_id() < 0) {
std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id());
RETURN_STATUS_UNEXPECTED(errMsg);
}
*row_id_generated = msg->row_id();
}
auto size_of_this = msg->size_of_this();
auto column_hdr = msg->column();
// The number of tensor buffers should match the number of columns plus one.
if (buf.size() != column_hdr->size() + 1) {
std::string errMsg = "Column count does not match. Expect " + std::to_string(column_hdr->size() + 1) +
" but get " + std::to_string(buf.size());
RETURN_STATUS_UNEXPECTED(errMsg);
}
// Next we store in either memory or on disk. Low level code will consolidate everything in one piece.
std::vector<ReadableSlice> all_data;
all_data.reserve(column_hdr->size() + 1);
all_data.emplace_back(fb, size_of_this);
for (auto i = 0; i < column_hdr->size(); ++i) {
all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i));
}
// Now we cache the flat buffer.
CachePool::key_type key;
RETURN_IF_NOT_OK(cp_->Insert(all_data, &key));
Status rc = map_->DoInsert(*row_id_generated, key);
if (rc == Status(StatusCode::kDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key";
} else {
RETURN_IF_NOT_OK(rc);
}
return Status::OK();
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
}
std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
// Then show any custom derived-internal stuff
out << "\nCache memory size: " << cs.cache_mem_sz_;
out << "\nSpill path: ";
if (cs.root_.empty()) {
out << "None";
} else {
out << cs.GetSpillPath();
}
return out;
}
Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); }
Status CacheService::Purge() {
// First we must lock exclusively. No one else can cache/restore anything.
UniqueLock rw(&rw_lock_);
RETURN_IF_NOT_OK(cp_->ServiceStop());
auto new_map = std::make_shared<row_map>();
map_.reset();
map_ = std::move(new_map);
next_id_ = 0;
RETURN_IF_NOT_OK(cp_->ServiceStart());
return Status::OK();
}
Status CacheService::GetStat(CacheService::ServiceStat *out) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(out);
if (st_ == State::kNone || st_ == State::kFetchPhase) {
out->stat_ = cp_->GetStat();
out->state_ = static_cast<ServiceStat::state_type>(st_);
auto it = map_->begin();
if (it != map_->end()) {
out->min_ = it.key();
auto end_it = map_->end();
--end_it;
out->max_ = end_it.key();
}
} else {
out->state_ = static_cast<ServiceStat::state_type>(st_);
}
return Status::OK();
}
Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
SharedLock rw(&rw_lock_);
if (st_ == State::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
const auto num_elements = v.size();
int64_t mem_sz = (num_elements + 1) * sizeof(int64_t);
int64_t data_offset = mem_sz;
std::vector<int64_t> sz_v;
std::vector<CachePool::key_type> keys;
sz_v.reserve(num_elements);
keys.reserve(num_elements);
for (auto row_id : v) {
auto r = map_->Search(row_id);
if (r.second) {
auto &it = r.first;
CachePool::key_type key = it.value();
auto sz = cp_->GetSize(key);
if (sz == 0) {
std::string errMsg = "Key not found: ";
errMsg += std::to_string(key);
RETURN_STATUS_UNEXPECTED(errMsg);
}
keys.push_back(key);
sz_v.push_back(sz);
mem_sz += sz;
} else {
keys.push_back(-1);
sz_v.push_back(0);
}
}
MemGuard<uint8_t> mem;
RETURN_IF_NOT_OK(mem.allocate(mem_sz));
auto *offset_array = reinterpret_cast<int64_t *>(mem.GetMutablePointer());
offset_array[0] = data_offset;
WritableSlice all(mem.GetMutablePointer(), mem.GetSizeInBytes());
for (auto i = 0; i < num_elements; ++i) {
auto sz = sz_v.at(i);
offset_array[i + 1] = offset_array[i] + sz;
if (sz > 0) {
WritableSlice row_data(all, offset_array[i], sz);
auto key = keys.at(i);
size_t bytesRead = 0;
RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead));
if (bytesRead != sz) {
MS_LOG(ERROR) << "Unexpected length. Read " << bytesRead << ". Expected " << sz << "."
<< " Internal key: " << key << "\n";
RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
}
}
}
*out = std::move(mem);
return Status::OK();
}
Status CacheService::CacheSchema(const void *buf, int64_t len) {
SharedLock rw(&rw_lock_);
if (st_ == State::kFetchPhase) {
// For this kind of cache service, once we move from the build phase into the fetch phase, we can't
// allow others to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
// This is a special request and we need to remember where we store it.
// In case we are calling the same function from multiple threads, only
// the first one is considered. The rest are ignored.
CachePool::key_type cur_key = schema_key_;
CachePool::key_type key;
if (cur_key < 0) {
RETURN_IF_NOT_OK(cp_->Insert({ReadableSlice(buf, len)}, &key));
auto result = std::atomic_compare_exchange_strong(&schema_key_, &cur_key, key);
MS_LOG(DEBUG) << "Caching Schema. Result = " << result;
} else {
MS_LOG(DEBUG) << "Caching Schema already done";
}
return Status::OK();
}
Status CacheService::FetchSchema(MemGuard<uint8_t> *out) const {
SharedLock rw(&rw_lock_);
if (st_ == State::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
RETURN_UNEXPECTED_IF_NULL(out);
MemGuard<uint8_t> mem;
if (schema_key_ >= 0) {
auto len = cp_->GetSize(schema_key_);
RETURN_IF_NOT_OK(mem.allocate(len));
auto slice = WritableSlice(mem.GetMutablePointer(), len);
RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice));
*out = std::move(mem);
} else {
return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached");
}
return Status::OK();
}
Status CacheService::BuildPhaseDone() {
if (HasBuildPhase()) {
// Exclusive lock to switch phase
UniqueLock rw(&rw_lock_);
st_ = State::kFetchPhase;
return Status::OK();
} else {
RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase");
}
}
} // namespace dataset
} // namespace mindspore
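// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the MindSpore source): BatchFetch above
// returns one contiguous buffer whose first (n + 1) int64 values are byte
// offsets, followed by the packed row payloads; an empty row, i.e.
// offset[i+1] == offset[i], marks a cache miss. The standalone program below
// builds and decodes that layout for a few string "rows" to make the format
// concrete; everything here is hypothetical helper code.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Pretend these are the serialized rows; the empty string is a cache miss.
  std::vector<std::string> rows = {"row-0 payload", "", "row-2 payload"};
  const size_t n = rows.size();

  // First pass: compute total size = offset table + payload bytes.
  int64_t mem_sz = static_cast<int64_t>((n + 1) * sizeof(int64_t));
  for (const auto &r : rows) mem_sz += static_cast<int64_t>(r.size());

  std::vector<uint8_t> buf(mem_sz);
  auto *offsets = reinterpret_cast<int64_t *>(buf.data());
  offsets[0] = static_cast<int64_t>((n + 1) * sizeof(int64_t));  // data_offset

  // Second pass: fill the offset table and copy each payload into place.
  for (size_t i = 0; i < n; ++i) {
    offsets[i + 1] = offsets[i] + static_cast<int64_t>(rows[i].size());
    std::memcpy(buf.data() + offsets[i], rows[i].data(), rows[i].size());
  }

  // Decoding side (what a cache client would do with the reply).
  for (size_t i = 0; i < n; ++i) {
    const int64_t sz = offsets[i + 1] - offsets[i];
    if (sz == 0) {
      std::cout << "row " << i << ": cache miss\n";
    } else {
      std::string payload(reinterpret_cast<const char *>(buf.data()) + offsets[i], sz);
      std::cout << "row " << i << ": " << payload << "\n";
    }
  }
  return 0;
}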
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_SERVICE_H_
#define DATASET_ENGINE_CACHE_SERVICE_H_
#include <algorithm>
#include <atomic>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "./de_tensor_generated.h"
#include "dataset/core/global_context.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/cache/cache_request.h"
#include "dataset/util/arena.h"
#include "dataset/util/btree.h"
#include "dataset/util/cache_pool.h"
#include "dataset/util/service.h"
#include "dataset/util/services.h"
#include "dataset/util/system_pool.h"
namespace mindspore {
namespace dataset {
struct CacheStat;
/// \brief A cache service for storing/fetching buffers in an in-memory cache, which may spill to disk if the
/// cache service is created with spilling support
class CacheService : public Service {
public:
friend class CacheServer;
using row_map = BPlusTree<row_id_type, CachePool::key_type>;
enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase };
/// \brief Constructor
/// \param mem_sz Memory size to be set aside for the in memory cache. 0 means unlimited
/// \param root Spill path. Empty string means no spilling
/// \param generate_id If the cache service should generate row id for buffer that is cached.
/// For non-mappable dataset, this should be set to true.
CacheService(uint64_t mem_sz, const std::string &root, bool generate_id);
~CacheService();
/// \brief For fixed size memory, we will create an Arena.
/// \return false if unlimited memory.
bool UseArena();
Status DoServiceStart() override;
Status DoServiceStop() override;
/// \brief Main function to cache a row which comes in the form of a series of buffers.
/// The first buffer is a Google flatbuffer which describes the buffers that follow.
/// \param[in] buf Vector of buffer
/// \param[out] row_id_generated The row id assigned to this row if any
/// \return Status object
Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated);
/// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
/// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
/// \param[in] v A vector of row id.
/// \param[out] out A contiguous memory buffer that holds the requested rows.
/// \return Status object
Status BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const;
/// \brief Getter function
/// \return Spilling path
Path GetSpillPath() const;
/// \brief A structure returned from the cache server for statistics request.
class ServiceStat {
public:
using state_type = std::underlying_type<State>::type;
ServiceStat() : min_(0), max_(0), state_(0) {}
CachePool::CacheStat stat_{};
row_id_type min_;
row_id_type max_;
state_type state_;
};
/// \brief Statistics for the current service
/// \param[in/out] A pointer to a pre-allocated ServiceStat structure
/// \return Status Object
Status GetStat(ServiceStat *);
/// \brief Cache schema
/// \param buf A Google Flatbuffer that contains the schema
/// \param len size of the buffer
/// \return Status object
Status CacheSchema(const void *buf, int64_t len);
/// \brief Fetch schema
/// \param out A contiguous memory that contains the serialized form of schema.
/// \return Status object
Status FetchSchema(MemGuard<uint8_t> *out) const;
/// \brief Purge the content of a cache
/// \return Status object
Status Purge();
/// \brief Overload the << operator to print a cache service
/// \param out std::ostream
/// \param cs A cache service
/// \return std::ostream
friend std::ostream &operator<<(std::ostream &out, const CacheService &cs);
/// \brief Every cache service has a cookie. If the cookie of a CacheClient matches this cookie, this CacheClient
/// is the creator
/// \return Cookie
std::string cookie() const { return cookie_; }
/// \brief If this cache service generates row id for buffer cached, it is divided into two phases, a build phase and
/// a read phase.
/// \return True if has two phases.
bool HasBuildPhase() const { return generate_id_; }
/// \brief Change from write phase to read phase. Only the creator of this service is allowed to make this call.
/// \return Status object
Status BuildPhaseDone();
private:
mutable RWLock rw_lock_;
std::string root_;
uint64_t cache_mem_sz_;
std::shared_ptr<CachePool> cp_;
std::shared_ptr<row_map> map_;
std::atomic<row_id_type> next_id_;
bool generate_id_;
std::atomic<CachePool::key_type> schema_key_;
std::string cookie_;
State st_;
/// \brief Private function to generate a row id
/// \return Row id assigned.
row_id_type GetNextRowId() { return next_id_.fetch_add(1); }
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_CACHE_SERVICE_H_
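// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the MindSpore source): CacheSchema above
// lets many threads race to publish the schema but keeps only the first one,
// using an atomic compare-and-exchange on a key initialized to -1. The
// standalone program below shows that "first writer wins" idiom; the losing
// threads simply ignore their own insert, exactly as the service does.
// ---------------------------------------------------------------------------
#include <atomic>
#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  std::atomic<int64_t> schema_key{-1};   // -1 means "no schema cached yet"

  auto publish = [&](int64_t my_key) {
    int64_t expected = -1;
    // Succeeds only if schema_key is still -1; otherwise someone beat us to it.
    if (schema_key.compare_exchange_strong(expected, my_key)) {
      std::cout << "thread published key " << my_key << "\n";
    } else {
      std::cout << "key " << expected << " already cached, ignoring " << my_key << "\n";
    }
  };

  std::vector<std::thread> threads;
  for (int64_t k = 100; k < 104; ++k) threads.emplace_back(publish, k);
  for (auto &t : threads) t.join();

  std::cout << "final schema key: " << schema_key.load() << "\n";
  return 0;
}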
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
namespace mindspore.dataset;
/// Type of a Tensor
enum TensorType : byte {
DE_UNKNOWN = 0,
DE_BOOL = 1,
DE_INT8 = 2,
DE_UINT8 = 3,
DE_INT16 = 4,
DE_UINT16 = 5,
DE_INT32 = 6,
DE_UINT32 = 7,
DE_INT64 = 8,
DE_UINT64 = 9,
DE_FLOAT16 = 10,
DE_FLOAT32 = 11,
DE_FLOAT64 = 12,
DE_STRING = 13
}
/// The meta information of a Tensor
/// \note Only the type and shape are considered meta information. Tensor data is excluded.
table TensorMetaMsg {
dims:[int64] (required);
type:TensorType;
}
/// This is the first buffer that is sent to a Cache server when a TensorRow is serialized.
/// \param row_id is the row id of the TensorRow.
/// \param column The meta information of each Tensor in the row
/// \param size of this serialized buffer
/// \param size of each tensor data buffer that follows
table TensorRowHeaderMsg {
row_id:int64;
column:[TensorMetaMsg] (required);
size_of_this:int64;
data_sz:[int64] (required);
}
root_type TensorRowHeaderMsg;
/// A row of row id's
table TensorRowIds {
row_id:[int64] (required);
}
/// Statistics returned from each cache service
/// \note It must match CacheService::ServiceStat
table ServiceStatMsg {
num_mem_cached:int64;
num_disk_cached:int64;
min_row_id:int64;
max_row_id:int64;
state:int8;
}
/// Column description of each column in a schema
table ColumnNameMsg {
name:string;
id:int32;
}
/// Serialized form of a schema
table SchemaMsg {
column:[ColumnNameMsg];
}
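// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the MindSpore source): this shows how a
// client could serialize and re-read the TensorRowHeaderMsg defined above,
// assuming flatc's standard C++ code generation (Create*/Get* helpers and
// TensorType_* enum constants). The exact generated names are an assumption
// based on those conventions, and size_of_this is left as a placeholder here.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <iostream>
#include <vector>
#include "./de_tensor_generated.h"  // produced from de_tensor.fbs by the build

int main() {
  using namespace mindspore::dataset;
  flatbuffers::FlatBufferBuilder fbb;

  // One column: an int64 tensor of shape [2, 3].
  auto dims = fbb.CreateVector<int64_t>({2, 3});
  auto meta = CreateTensorMetaMsg(fbb, dims, TensorType_DE_INT64);
  auto column = fbb.CreateVector(std::vector<flatbuffers::Offset<TensorMetaMsg>>{meta});

  // One data buffer of 2 * 3 * 8 bytes follows this header on the wire.
  auto data_sz = fbb.CreateVector<int64_t>({48});
  auto hdr = CreateTensorRowHeaderMsg(fbb, /*row_id=*/7, column,
                                      /*size_of_this=*/0, data_sz);
  fbb.Finish(hdr);

  // Receiving side: essentially what CacheService::CacheRow does with buf.front().
  auto msg = GetTensorRowHeaderMsg(fbb.GetBufferPointer());
  std::cout << "row id: " << msg->row_id()
            << ", columns: " << msg->column()->size()
            << ", first data size: " << msg->data_sz()->Get(0) << "\n";
  return 0;
}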
......@@ -24,10 +24,8 @@ namespace dataset {
// Description: This is the main constructor that is used for making a buffer
DataBuffer::DataBuffer(int32_t id, BufferFlags flags) : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {}
// Name: print()
// Description: A function that prints info about the DataBuffer (base class version)
void DataBuffer::Print(std::ostream &out, // In: The output stream to print to
bool show_all) const { // In: T/F if it should show everything
// A method for debug printing of the buffer
void DataBuffer::Print(std::ostream &out, bool show_all) const {
out << "bufferId: " << buffer_id_ << "\nflags: " << std::hex << buffer_flags_ << std::dec << "\n";
// If the column counts are set then it means that data has been set into
......@@ -46,11 +44,6 @@ void DataBuffer::Print(std::ostream &out, // In: The output stream to print
}
}
Status DataBuffer::Load() {
std::string err_msg = "Base class load called, but it does not have an implementation!";
RETURN_STATUS_UNEXPECTED(err_msg);
}
// Remove me!! Callers should fetch rows via pop
Status DataBuffer::GetTensor(std::shared_ptr<Tensor> *ptr, int32_t row_id, int32_t col_id) const {
if (row_id < tensor_table_->size() && col_id < tensor_table_->at(row_id).size()) {
......@@ -92,8 +85,5 @@ Status DataBuffer::SliceOff(int64_t number_of_rows) {
return Status::OK();
}
// Destructor
DataBuffer::~DataBuffer() {}
} // namespace dataset
} // namespace mindspore
......@@ -29,11 +29,9 @@
namespace mindspore {
namespace dataset {
// The DataBuffer class is a base class that will represent the data for n values based
// on a unique row id for each row of data.
// There can be different types of DataBuffers to abstract over how the data is stored
// in memory and acquired from storage.
// Each buffer holds a range of consecutive row id's.
/// \brief The DataBuffer class is a container of tensor data and is the unit of transmission between
/// connectors of dataset operators. Inside the buffer, tensors are organized into a table-like format
/// where n TensorRows may consist of m tensors (columns).
class DataBuffer {
public:
// Buffer flags
......@@ -47,13 +45,13 @@ class DataBuffer {
// Description: This is the main constructor that is used for making a buffer
DataBuffer(int32_t id, BufferFlags flags);
// Destructor
virtual ~DataBuffer();
/// \brief default destructor
~DataBuffer() = default;
// Name: print()
// Description: A function that prints info about the DataBuffer (base class version)
virtual void Print(std::ostream &out, // In: The output stream to print to
bool show_all) const; // In: T/F if it should show everything
/// \brief A method for debug printing of the buffer
/// \param[inout] out The stream to write to
/// \param[in] show_all A boolean to toggle between details and summary printing
void Print(std::ostream &out, bool show_all) const;
// Provide stream operator for displaying it
friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) {
......@@ -61,10 +59,6 @@ class DataBuffer {
return out;
}
// Name: load()
// Description: populates the DataBuffer with data based on it's id
virtual Status Load();
// Convenience getter functions for flag checking
bool eof() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOF)); }
......
......@@ -17,7 +17,11 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES
take_op.cc
shuffle_op.cc
zip_op.cc
concat_op.cc
concat_op.cc
cache_base_op.cc
cache_lookup_op.cc
cache_op.cc
cache_merge_op.cc
)
if (ENABLE_PYTHON)
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/cache_base_op.h"
#include <iomanip>
#include <iostream>
#include "dataset/engine/execution_tree.h"
namespace mindspore {
namespace dataset {
// A print method typically used for debugging
void CacheBase::Print(std::ostream &out, bool show_all) const {
// Always show the id and name as first line regardless if this is summary or detailed print
out << "(" << std::setw(2) << operator_id_ << ") <" << Name() << ">:";
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\nCache client:\n" << *cache_client_ << "\n\n";
}
}
// Overrides base class reset method. When an operator does a reset, it cleans up any state
// info from its previous execution and then initializes itself so that it can be executed
// again.
Status CacheBase::Reset() {
if (sampler_ != nullptr) {
RETURN_IF_NOT_OK(sampler_->ResetSampler());
}
// Wake up the workers to get them going again in a new epoch
MS_LOG(DEBUG) << Name() << " resetting.";
epoch_sync_.Set();
return Status::OK();
}
CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, op_connector_size, sampler),
cache_client_(cache_client),
rows_per_buffer_(rows_per_buf),
// We can cause deadlock if this internal Connector size is too small.
keys_miss_(num_workers_, 1, 1024) {
io_block_queues_.Init(num_workers, op_connector_size);
}
// Common function to fetch samples from the sampler and send them using the io_block_queues to
// the parallel workers
Status CacheBase::FetchSamplesToWorkers() {
int64_t buf_cnt = 0;
int64_t wait_cnt = 0;
do {
epoch_sync_.Clear();
std::vector<row_id_type> keys;
int64_t row_cnt = 0;
keys.reserve(rows_per_buffer_);
std::unique_ptr<DataBuffer> sampler_buffer;
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
while (!sampler_buffer->eoe()) {
TensorRow sample_row;
RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
std::shared_ptr<Tensor> sample_ids = sample_row[0];
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
keys.push_back(*itr);
++row_cnt;
if (row_cnt % rows_per_buffer_ == 0) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
keys.clear();
}
}
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
}
if (!keys.empty()) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
}
// send the eoe
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
// If repeating but not the last repeat, wait for reset.
if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt;
RETURN_IF_NOT_OK(epoch_sync_.Wait());
} else {
// We can break out from the loop.
break;
}
} while (true);
// Flow the eof before exit
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
// Ask all the workers to quit.
for (int32_t i = 0; i < num_workers_; i++) {
RETURN_IF_NOT_OK(
io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
}
return Status::OK();
}
Status CacheBase::FetchFromCache(int32_t worker_id) {
int64_t buffer_id = worker_id;
std::unique_ptr<IOBlock> blk;
do {
RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk));
if (blk->eof()) {
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
} else if (blk->eoe()) {
if (AllowCacheMiss()) {
// This code path is for CacheLookupOp acting as a sampler. If we get an eoe from
// the sampler, send an eoe to the physical leaf op as well.
std::vector<row_id_type> eoe;
eoe.push_back(eoe_row_id);
RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe));
}
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
} else {
std::vector<int64_t> keys;
RETURN_IF_NOT_OK(blk->GetKeys(&keys));
if (keys.empty()) {
// empty key is a quit signal for workers
break;
}
std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>();
TensorTable ttbl;
RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl));
auto row_it = ttbl.begin();
std::vector<row_id_type> cache_miss;
cache_miss.reserve(keys.size());
for (auto row_id : keys) {
auto &row = *row_it;
if (row.empty()) {
if (AllowCacheMiss()) {
cache_miss.push_back(row_id);
} else {
std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
RETURN_STATUS_UNEXPECTED(errMsg);
}
}
que->push_back(std::move(row));
++row_it;
}
db->set_tensor_table(std::move(que));
if (AllowCacheMiss()) {
// Because of the way the connector works, we push unconditionally even if cache_miss is empty.
RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss));
}
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db)));
buffer_id += num_workers_;
}
} while (true);
return Status::OK();
}
Status CacheBase::RegisterResources() {
RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
return Status::OK();
}
CacheBase::~CacheBase() {}
Status CacheBase::UpdateColumnMapFromCache() {
Status rc;
// Get the schema from the server. It may not be there yet. So tolerate the error.
if (column_name_id_map_.empty()) {
rc = cache_client_->FetchSchema(&column_name_id_map_);
if (rc == Status(StatusCode::kFileNotExist)) {
MS_LOG(DEBUG) << "Schema not in the server yet.";
rc = Status::OK();
}
}
return rc;
}
} // namespace dataset
} // namespace mindspore
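// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the MindSpore source): FetchSamplesToWorkers
// above batches sampled row ids into groups of rows_per_buffer_ and deals the
// batches round-robin over the per-worker io_block_queues_. The standalone
// snippet below shows that batching/dispatch shape with plain std::vector
// queues; all names here are hypothetical.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int num_workers = 3;
  const int rows_per_buffer = 4;
  std::vector<std::vector<std::vector<int64_t>>> worker_queues(num_workers);

  std::vector<int64_t> keys;
  keys.reserve(rows_per_buffer);
  int64_t buf_cnt = 0;

  // Pretend the sampler produced row ids 0..9.
  for (int64_t row_id = 0; row_id < 10; ++row_id) {
    keys.push_back(row_id);
    if (keys.size() == static_cast<size_t>(rows_per_buffer)) {
      worker_queues[buf_cnt++ % num_workers].push_back(keys);  // hand off one block
      keys.clear();
    }
  }
  if (!keys.empty()) {  // flush the partial last block
    worker_queues[buf_cnt++ % num_workers].push_back(keys);
  }

  for (int w = 0; w < num_workers; ++w) {
    std::cout << "worker " << w << " got " << worker_queues[w].size() << " block(s)\n";
  }
  return 0;
}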
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/cache/cache_service.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/engine/datasetops/source/sampler/sampler.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/util/queue.h"
#include "dataset/util/wait_post.h"
#include "dataset/engine/datasetops/cache_base_op.h"
namespace mindspore {
namespace dataset {
/// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities.
/// \see CacheOp
/// \see CacheLookupOp
class CacheBase : public ParallelOp {
public:
/// \brief Base class constructor
/// \param num_workers Number of parallel workers
/// \param op_connector_size Connector size
/// \param rows_per_buf Number of rows per buffer
/// \param cache_client CacheClient for communication to the CacheServer
/// \param sampler Sampler which is mandatory
CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
/// \brief Destructor
~CacheBase();
constexpr static int eoe_row_id = -1;
/// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state
/// info from its previous execution and then initializes itself so that it can be executed
/// again.
/// \return Status - The error code return
Status Reset() override;
/// \brief A print method typically used for debugging
/// \param out The output stream to write output to
/// \param show_all A bool to control if you want to show all info or just a summary
void Print(std::ostream &out, bool show_all) const override;
/// \brief << Stream output operator overload
/// \notes This allows you to write the debug print info using stream operators
/// \param out reference to the output stream being overloaded
/// \param mo reference to the CacheOp to display
/// \return the output stream must be returned
friend std::ostream &operator<<(std::ostream &out, const CacheBase &mo) {
mo.Print(out, false);
return out;
}
/// \brief Getter for the cache client
/// \return shared ptr to the cache client
std::shared_ptr<CacheClient> cache_client() { return cache_client_; }
/// \brief Setter for the cache client
void SetCacheClient(std::shared_ptr<CacheClient> cache_client) { cache_client_ = std::move(cache_client); }
/// \brief Derived class must implement this method to indicate whether a cache miss is allowed or treated as an error
virtual bool AllowCacheMiss() = 0;
protected:
std::shared_ptr<CacheClient> cache_client_;
WaitPost epoch_sync_;
int32_t rows_per_buffer_;
Connector<std::vector<row_id_type>> keys_miss_;
/// \brief Common function to register resources for interrupt
/// \note Derived should override this function for extra resources to be registered
virtual Status RegisterResources();
/// \brief This function is called by main thread to send samples to the worker thread.
/// \note It is a non-virtual function
/// \return Status object
Status FetchSamplesToWorkers();
/// \brief This function is called by each worker to fetch rows from the cache server for a given set of
/// sample row id's
/// \return Status object
Status FetchFromCache(int32_t worker_id);
/// \brief Get the column map from cache server
Status UpdateColumnMapFromCache();
private:
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/cache_lookup_op.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/engine/execution_tree.h"
#include "utils/log_adapter.h"
#include "utils/system/crc32c.h"
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
build_num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
build_op_connector_size_ = cfg->op_connector_size();
}
// Check if the required parameters are set by the builder.
Status CacheLookupOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheLookupOp requires a CacheClient");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"Cache client for CacheLookupOp is missing session id");
}
return Status::OK();
}
// The builder "build" method creates the final object and does some init on it
Status CacheLookupOp::Builder::Build(std::shared_ptr<CacheLookupOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<CacheLookupOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_,
build_cache_client_, build_sampler_);
return Status::OK();
}
Status CacheLookupOp::operator()() {
if (!sampler_) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"CacheLookupOp requires a sampler before it can be executed!");
}
RETURN_IF_NOT_OK(RegisterResources());
// Kick off the workers
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&CacheLookupOp::WorkerEntry, this, std::placeholders::_1)));
// required task group sync after launching workers
TaskManager::FindMe()->Post();
// We have to wait until the leaf op has handshake with us.
RETURN_IF_NOT_OK(leaf_op_wp_.Wait());
RETURN_IF_NOT_OK(FetchSamplesToWorkers());
return Status::OK();
}
Status CacheLookupOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(FetchFromCache(worker_id));
return Status::OK();
}
Status CacheLookupOp::ResetSampler() { return Status::OK(); }
Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) {
// We act like a sampler and as a dataset op. During handshake with leaf op,
// we must wait until the leaf op has indexed everything.
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op));
// Now we notify the main thread handshake has finished.
leaf_op_wp_.Set();
return Status::OK();
}
Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); }
void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); }
Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
std::vector<row_id_type> cache_miss;
RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss));
// If there is no cache miss, we can't return an empty sample, so keep popping until we get one.
while (cache_miss.empty()) {
RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss));
}
// Special code for eoe
if (cache_miss.at(0) == eoe_row_id) {
*out_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
} else {
std::shared_ptr<Tensor> sample_ts;
RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ts, cache_miss.size()));
(*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
auto idPtr = sample_ts->begin<int64_t>();
for (auto i = 0; i < cache_miss.size(); ++i) {
*idPtr = cache_miss.at(i);
++idPtr;
}
TensorRow row;
row.push_back(sample_ts);
(*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
}
return Status::OK();
}
Status CacheLookupOp::RegisterResources() {
RETURN_IF_NOT_OK(CacheBase::RegisterResources());
RETURN_IF_NOT_OK(leaf_op_wp_.Register(tree_->AllTasks()));
return Status::OK();
}
Status CacheLookupOp::ComputeColMap() {
// We don't know the column map at this point unless we contact the cache server
// to fetch the schema but the cache server may not have it at this point either.
// So we will just return OK and let MergeOp (our parent) handle it.
return Status::OK();
}
// Visitor accept method for NodePass
Status CacheLookupOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CacheLookupOp>(), modified);
}
} // namespace dataset
} // namespace mindspore
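// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the MindSpore source): CacheLookupOp above
// feeds cache-miss row ids back to the leaf op through keys_miss_, and encodes
// end-of-epoch in band as a one-element vector holding eoe_row_id (-1). The
// standalone snippet below shows that in-band sentinel protocol; the names
// used here are hypothetical.
// ---------------------------------------------------------------------------
#include <cstdint>
#include <deque>
#include <iostream>
#include <vector>

int main() {
  constexpr int64_t kEoeRowId = -1;                // same role as eoe_row_id
  std::deque<std::vector<int64_t>> miss_channel;

  // Producer side (the cache lookup workers): push misses, then the sentinel.
  miss_channel.push_back({});                      // no miss in this block
  miss_channel.push_back({5, 9});                  // rows 5 and 9 were misses
  miss_channel.push_back({kEoeRowId});             // end-of-epoch marker

  // Consumer side (the leaf op asking for the next "sample" of misses):
  while (!miss_channel.empty()) {
    std::vector<int64_t> misses = miss_channel.front();
    miss_channel.pop_front();
    if (misses.empty()) continue;                  // skip empty blocks, as GetNextSample does
    if (misses.front() == kEoeRowId) {
      std::cout << "epoch finished\n";
      break;
    }
    std::cout << "need to load " << misses.size() << " row(s) from the leaf\n";
  }
  return 0;
}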
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/engine/datasetops/cache_base_op.h"
namespace mindspore {
namespace dataset {
/// \brief Provides a memory/disk cache that acts as a save-point within a mappable dataset.
/// \note For non-mappable dataset, please see CacheOp
/// \see CacheOp
class CacheLookupOp : public CacheBase, public Sampler {
public:
class Builder {
public:
/// \brief Builder constructor. Creates the builder object.
/// \note No default args
Builder();
/// Default destructor
~Builder() = default;
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
build_num_workers_ = num_workers;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t connector_size) {
build_op_connector_size_ = connector_size;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
build_cache_client_ = cache_client;
return *this;
}
/// \brief Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
/// \brief The builder "build" method creates the final object and does some init on it.
/// \param ptr The shared_ptr to the new CacheLookupOp object
/// \return Status
Status Build(std::shared_ptr<CacheLookupOp> *ptr);
private:
int32_t build_num_workers_;
int32_t rows_per_buffer_;
int32_t build_op_connector_size_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
// Check if the required parameters are set by the builder.
// \return Status The error code return
Status SanityCheck() const;
};
/// \brief Constructor
/// \note It takes the same argument as the base class.
/// \see CacheBase
CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {}
~CacheLookupOp() = default;
// As a parallel op, we override these two functions
Status operator()() override;
Status WorkerEntry(int32_t worker_id) override;
// As a sampler, we override the following functions
Status ResetSampler() override;
Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
Status InitSampler() override;
Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;
void Print(std::ostream &out, bool show_all) const override;
bool AllowCacheMiss() override { return true; }
std::string Name() const override { return "CacheLookupOp"; }
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
protected:
Status ComputeColMap() override;
private:
WaitPost leaf_op_wp_;
Status RegisterResources() override;
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <functional>
#include <iomanip>
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/engine/datasetops/cache_merge_op.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
CacheMergeOp::~CacheMergeOp() = default;
void CacheMergeOp::Print(std::ostream &out, bool show_all) const {
// Always show the id and name as first line regardless if this is summary or detailed print
out << "(" << std::setw(2) << operator_id_ << ") <CacheMergeOp>:";
if (!show_all) {
// Call the super class for displaying any common 1-liner info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal 1-liner info for this op
out << "\n";
} else {
// Call the super class for displaying any common detailed info
ParallelOp::Print(out, show_all);
// Then show any custom derived-internal stuff
out << "\n\n";
}
}
CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler)
: ParallelOp(numWorkers, opConnectorSize, sampler), num_cleaners_(numCleaners), cache_client_(cache_client) {}
Status CacheMergeOp::operator()() {
// A queue of row id to let cleaner send cache miss rows to the cache server
// We don't want a small queue as this will block the parallel op workers.
// A row id is an 8-byte integer, so a bigger queue doesn't consume a lot of memory.
io_que_ = std::make_unique<Queue<row_id_type>>(512);
RETURN_IF_NOT_OK(io_que_->Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1)));
RETURN_IF_NOT_OK(
tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1)));
// Dedicated cleaner threads to move TensorRows from the pool to the cache server
for (auto i = 0; i < num_cleaners_; ++i) {
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Cleaner", std::bind(&CacheMergeOp::Cleaner, this)));
}
TaskManager::FindMe()->Post();
return Status::OK();
}
// Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait
// until it shows up in the pool.
Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::shared_ptr<DatasetOp> cache_hit_stream = child_[kCacheHitChildIdx];
std::unique_ptr<DataBuffer> db_ptr;
RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
while (!db_ptr->eof()) {
if (db_ptr->eoe()) {
RETURN_IF_NOT_OK(EoeReceived(worker_id));
db_ptr.reset();
RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
} else {
// See if there is any missing row
auto tbl = std::make_unique<TensorQTable>();
while (db_ptr->NumRows() > 0) {
TensorRow row;
RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
if (row.empty()) {
auto row_id = row.getId();
TensorRowRequest *rq = nullptr;
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
// Block until the row shows up in the pool.
RETURN_IF_NOT_OK(rq->Wait(&row));
}
tbl->push_back(std::move(row));
}
db_ptr->set_tensor_table(std::move(tbl));
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr)));
RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
}
}
RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr)));
return Status::OK();
}
Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
TaskManager::FindMe()->Post();
// We will simply pop TensorRows from the stream, insert them into the pool, and
// wake up any worker that is waiting on the missing TensorRow.
// If we see an eoe, ignore it. For eof, we exit.
std::shared_ptr<DatasetOp> cache_missing_stream = child_[kCacheMissChildIdx];
// Before we start, cache the schema at the server. Pick one of the workers to
// do it. The schema should have been generated at prepare time.
if (workerId == 0) {
RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));
}
std::unique_ptr<DataBuffer> db_ptr;
RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId));
while (!db_ptr->eof()) {
if (db_ptr->eoe()) {
// Ignore it.
MS_LOG(DEBUG) << "Ignore eoe";
} else {
while (db_ptr->NumRows() > 0) {
TensorRow row;
RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
row_id_type row_id = row.getId();
if (row_id < 0) {
std::string errMsg = "Expect positive row id: " + std::to_string(row_id);
RETURN_STATUS_UNEXPECTED(errMsg);
}
TensorRowRequest *rq = nullptr;
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
rq->WakeUpAny(std::move(row));
// Let the cleaner flush out this row (async) to the cache server.
RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
}
}
RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId));
}
return Status::OK();
}
Status CacheMergeOp::Cleaner() {
TaskManager::FindMe()->Post();
while (true) {
row_id_type row_id;
RETURN_IF_NOT_OK(io_que_->PopFront(&row_id));
if (row_id < 0) {
break;
}
TensorRowRequest *rq = nullptr;
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
if (rq->GetState() == TensorRowRequest::State::kClean) {
// If already flushed, move on to the next one.
continue;
}
TensorRow row;
RETURN_IF_NOT_OK(rq->Release(&row));
CHECK_FAIL_RETURN_UNEXPECTED(!row.empty(), "Programming error");
Status rc = cache_client_->WriteRow(row);
// Bad rc should not bring down the pipeline
if (rc.IsError()) {
MS_LOG(WARNING) << "Cache not successful." << rc.ToString();
}
rq->SetState(TensorRowRequest::State::kClean);
}
return Status::OK();
}
Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowRequest **out) {
RETURN_UNEXPECTED_IF_NULL(out);
std::unique_lock<std::mutex> lck(mux_);
auto it = cache_miss_map_.find(row_id);
if (it != cache_miss_map_.end()) {
*out = it->second.GetMutablePointer();
} else {
// We will create a new one.
auto alloc = Services::GetAllocator<TensorRowRequest>();
auto r = cache_miss_map_.emplace(row_id, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>(alloc));
if (r.second) {
auto &mem = r.first->second;
RETURN_IF_NOT_OK(mem.allocate(1, row_id));
*out = mem.GetMutablePointer();
} else {
RETURN_STATUS_UNEXPECTED("Map insert fail.");
}
}
return Status::OK();
}
Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from super class first before adding our own
// specific logic
CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children");
RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
// Get the computed check sum from all ops in the cache miss class
uint32_t cache_crc = DatasetOp::GenerateCRC(child_[kCacheMissChildIdx]);
// This is a mappable case, so the row id's come from the leaf op and do not need to be generated.
// Construct the cache
const bool generate_ids = false;
Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
if (rc.get_code() == StatusCode::kDuplicateKey) {
// We are told the cache has been created already.
MS_LOG(INFO) << "Cache created already";
rc = Status::OK();
}
RETURN_IF_NOT_OK(rc);
return Status::OK();
}
Status CacheMergeOp::ComputeColMap() {
CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty");
if (column_name_id_map().empty()) {
column_name_id_map_ = child_[kCacheMissChildIdx]->column_name_id_map();
}
CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected");
return Status::OK();
}
Status CacheMergeOp::TensorRowRequest::Wait(TensorRow *out) {
RETURN_UNEXPECTED_IF_NULL(out);
// Block until the missing row is in the pool.
RETURN_IF_NOT_OK(use_count_.P());
std::unique_lock<std::mutex> lck(dq_mux_);
CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error");
*out = std::move(row_.front());
row_.pop_front();
return Status::OK();
}
void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) {
std::unique_lock<std::mutex> lck(dq_mux_);
// Technically, the number of times this row shows up in the cache miss stream is equal to the number
// of P() calls. However, the cleaner wants it too, so we need an extra copy.
if (GetState() == State::kEmpty) {
// We will do a deep copy
for (auto &ts : row) {
auto out_ts = std::make_shared<Tensor>(ts->shape(), ts->type(), ts->GetBuffer(), ts->SizeInBytes());
cleaner_copy_.push_back(out_ts);
}
cleaner_copy_.setId(row.getId());
// Change the state to dirty
SetState(State::kDirty);
}
row_.push_back(std::move(row));
// Bump up the use count by 1. This wakes up any parallel worker which is waiting
// for this row.
use_count_.V();
}
Status CacheMergeOp::TensorRowRequest::Release(TensorRow *out) {
RETURN_UNEXPECTED_IF_NULL(out);
// We are not holding any mutex here because the cleaner isn't really touching the deque row_.
// In case we have multiple cleaners and they all see the copy, only one of them will
// get it.
auto expected = State::kDirty;
if (st_.compare_exchange_strong(expected, State::kClean)) {
*out = std::move(cleaner_copy_);
}
return Status::OK();
}
// Builder constructor. Creates the builder object.
CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
build_num_workers_ = cfg->num_parallel_workers();
build_op_connector_size_ = cfg->op_connector_size();
build_num_cleaners_ = 1;
}
// Check if the required parameters are set by the builder.
Status CacheMergeOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheMergeOp requires a CacheClient");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"Cache client for CacheMergeOp is missing session id");
}
return Status::OK();
}
// The builder "build" method creates the final object and does some init on it
Status CacheMergeOp::Builder::Build(std::shared_ptr<CacheMergeOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<CacheMergeOp>(build_num_workers_, build_op_connector_size_, build_num_cleaners_,
build_cache_client_, build_sampler_);
return Status::OK();
}
// Pre-Visitor accept method for NodePass
Status CacheMergeOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->PreRunOnNode(shared_from_base<CacheMergeOp>(), modified);
}
// Visitor accept method for NodePass
Status CacheMergeOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CacheMergeOp>(), modified);
}
Status CacheMergeOp::EoeReceived(int32_t worker_id) {
// If we are in a repeat path, send the eoe up.
// Otherwise ignore it.
if (BitTest(op_ctrl_flags_, kDeOpRepeated)) {
return DatasetOp::EoeReceived(worker_id);
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
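// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the MindSpore source): TensorRowRequest
// above uses a semaphore so a merge worker can block until the cache-miss
// worker delivers the row, plus an atomic kEmpty/kDirty/kClean state so that
// exactly one cleaner flushes the extra copy to the server. The standalone
// program below reproduces that handshake with strings standing in for rows;
// all names are hypothetical.
// ---------------------------------------------------------------------------
#include <atomic>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>

enum class State { kEmpty, kDirty, kClean };

struct RowRequest {
  std::mutex m;
  std::condition_variable cv;
  int available = 0;                 // plays the role of the semaphore count
  std::string row;                   // the delivered row
  std::string cleaner_copy;          // extra copy kept for the cleaner
  std::atomic<State> st{State::kEmpty};

  void WakeUpAny(std::string r) {    // cache-miss worker side
    std::lock_guard<std::mutex> lk(m);
    if (st.load() == State::kEmpty) {
      cleaner_copy = r;              // keep one copy for the cleaner
      st = State::kDirty;
    }
    row = std::move(r);
    ++available;
    cv.notify_one();                 // semaphore V()
  }
  std::string Wait() {               // merge worker side, semaphore P()
    std::unique_lock<std::mutex> lk(m);
    cv.wait(lk, [this] { return available > 0; });
    --available;
    return row;
  }
  bool Release(std::string *out) {   // cleaner side: only one cleaner wins
    State expected = State::kDirty;
    if (st.compare_exchange_strong(expected, State::kClean)) {
      *out = std::move(cleaner_copy);
      return true;
    }
    return false;
  }
};

int main() {
  RowRequest rq;
  std::thread merge_worker([&] { std::cout << "merged: " << rq.Wait() << "\n"; });
  std::thread miss_worker([&] { rq.WakeUpAny("row 42 payload"); });
  miss_worker.join();
  merge_worker.join();

  std::string to_flush;
  if (rq.Release(&to_flush)) std::cout << "cleaner flushes: " << to_flush << "\n";
  return 0;
}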
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include "dataset/core/tensor_row.h"
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/util/queue.h"
#include "dataset/util/semaphore.h"
namespace mindspore {
namespace dataset {
/// \brief Provides method to merge two streams (one from CacheLookup and one from cache miss stream) into one single
/// stream
class CacheMergeOp : public ParallelOp {
public:
// Some handshake structures among the main thread, cleaner threads and parallel op threads.
class TensorRowRequest {
public:
enum class State : uint8_t {
kEmpty = 0, // No row in the deque
kDirty = 1, // Cleaner hasn't flushed it to the cache server yet.
kClean = 2 // The row has been flushed already.
};
explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {}
~TensorRowRequest() = default;
State GetState() const { return st_; }
void SetState(State newState) { st_ = newState; }
Status Wait(TensorRow *out);
void WakeUpAny(TensorRow &&row);
Status Release(TensorRow *out);
private:
std::mutex dq_mux_;
std::atomic<State> st_;
Semaphore use_count_;
std::deque<TensorRow> row_;
TensorRow cleaner_copy_;
};
constexpr static int kCacheHitChildIdx = 0; // Cache hit stream
constexpr static int kCacheMissChildIdx = 1; // Cache miss stream
/// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of
/// the arguments for constructing it. Use the builder by setting each argument
/// with the provided set methods, and then finally call the build method to execute
/// the actual construction.
class Builder {
public:
/// Builder constructor. Creates the builder object.
/// \note No default args
Builder();
/// Default destructor
~Builder() = default;
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
build_num_workers_ = num_workers;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t connector_size) {
build_op_connector_size_ = connector_size;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
build_cache_client_ = cache_client;
return *this;
}
/// \brief Setter method
/// \param sampler
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
/// \brief Setter method
/// \param num_cleaners
/// \return Builder setter method returns reference to the builder.
Builder &SetNumCleaner(int32_t num_cleaners) {
build_num_cleaners_ = num_cleaners;
return *this;
}
/// The builder "build" method creates the final object and does some init on it.
/// \param ptr The shared_ptr to the new CacheMergeOp object
/// \return Status
Status Build(std::shared_ptr<CacheMergeOp> *ptr);
private:
int32_t build_num_workers_;
int32_t build_op_connector_size_;
int32_t build_num_cleaners_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
/// Check if the required parameters are set by the builder.
/// \return Status The error code return
Status SanityCheck() const;
};
/// \brief Constructor
/// \param numWorkers Number of parallel workers as a derived class of ParallelOp
/// \param opConnectorSize Connector size as a derived class of ParallelOp
/// \param numCleaners Number of cleaners to move cache miss rows into the cache server
/// \param cache_client CacheClient to communicate with the Cache server
/// \param sampler Sampler, passed to the ParallelOp base class
CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler);
~CacheMergeOp();
void Print(std::ostream &out, bool show_all) const override;
friend std::ostream &operator<<(std::ostream &out, const CacheMergeOp &mo) {
mo.Print(out, false);
return out;
}
/// \brief Master thread, responsible for spawning all the necessary worker threads for the two streams and
/// the threads for the cleaners.
/// \return Status object
Status operator()() override;
/// \brief Entry function for a worker thread that fetches rows from the CacheLookupOp
/// \param workerId
/// \return Status object
Status WorkerEntry(int32_t workerId) override;
Status PrepareNodePostAction() override;
/// \brief Entry function for a worker thread that fetches rows from the cache miss stream
/// \param workerId
/// \return Status object
Status CacheMissWorkerEntry(int32_t workerId);
/// \brief Locate (or create) the TensorRowRequest for a given row id
/// \param[in] row_id The row id to look up
/// \param[out] out The request object associated with this row id
/// \return Status object
Status GetRq(row_id_type row_id, TensorRowRequest **out);
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for eoe handling
/// \param worker_id
/// \return Status object
Status EoeReceived(int32_t worker_id) override;
protected:
Status ComputeColMap() override;
private:
std::mutex mux_;
std::map<row_id_type, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>> cache_miss_map_;
std::unique_ptr<Queue<row_id_type>> io_que_;
std::shared_ptr<CacheClient> cache_client_;
int32_t num_cleaners_;
/// \brief Entry function for the cleaner threads. Each cleaner is responsible for
/// moving cache miss TensorRows into the CacheServer.
/// \return Status object
Status Cleaner();
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/cache_op.h"
#include <memory>
#include <vector>
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/util/task_manager.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
build_num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
build_op_connector_size_ = cfg->op_connector_size();
}
// Check if the required parameters are set by the builder.
Status CacheOp::Builder::SanityCheck() const {
if (build_cache_client_ == nullptr) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheOp requires a CacheClient");
}
// Make sure the cache client has a valid session
if (!build_cache_client_->session_id()) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cache client for CacheOp is missing session id");
}
return Status::OK();
}
// The builder "build" method creates the final object and does some init on it
Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
RETURN_IF_NOT_OK(SanityCheck());
*ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_, build_cache_client_,
build_sampler_);
RETURN_IF_NOT_OK((*ptr)->InitCache());
return Status::OK();
}
// Constructor of CacheOp
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler),
num_guys_in_(0),
phase_(Phase::kBuildPhase) {}
// Destructor
CacheOp::~CacheOp() = default;
// Private function for cache setup/init work just after construction
Status CacheOp::InitCache() { return Status::OK(); }
// This class functor will provide the master loop that drives the logic for performing the work
Status CacheOp::operator()() {
if (!sampler_) {
return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
"CacheOp requires a sampler before it can be executed!");
}
RETURN_IF_NOT_OK(RegisterResources());
// Kick off the workers
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1)));
// required task group sync after launching workers
TaskManager::FindMe()->Post();
// Wait for the workers to finish caching the rows.
RETURN_IF_NOT_OK(WaitForCachingAllRows());
RETURN_IF_NOT_OK(FetchSamplesToWorkers());
return Status::OK();
}
Status CacheOp::CacheAllRows(int32_t worker_id) {
// If the current phase is to fill the cache, do it then.
if (phase_ == Phase::kBuildPhase) {
// We will take the chance to cache the schema at the server.
// Just do it once and pick one worker to do it.
if (worker_id == 0) {
RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));
}
MS_LOG(INFO) << "CacheOp first epoch SAVE mode started. Worker: " << worker_id;
// SAVE mode loop
std::unique_ptr<DataBuffer> db_ptr;
RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
while (!db_ptr->eof()) {
if (!db_ptr->eoe()) {
RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
} else {
// In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up
// as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got the
// eoe to indicate the end of the epoch, we should next expect to get the eof.
// Drain this eof so that we don't leave it sitting there on a connector that we'll never fetch
// from again.
RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
if (!db_ptr->eof()) {
RETURN_STATUS_UNEXPECTED("Cache op expects to get an eof after eoe from child.");
}
}
RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
}
}
// Let the main thread know we are done.
auto last_guy_in = num_guys_in_.fetch_add(1);
if ((last_guy_in + 1) == num_workers_) {
rows_cache_done_.Set();
} else {
// Let's do a sync up here.
RETURN_IF_NOT_OK(rows_cache_done_.Wait());
}
return Status::OK();
}
Status CacheOp::WaitForCachingAllRows() {
// Wait for the workers to finish caching the rows.
RETURN_IF_NOT_OK(rows_cache_done_.Wait());
// Move from build phase to fetch phase if we are the one to fill the cache
if (phase_ == Phase::kBuildPhase) {
RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone());
// Move to the next phase
phase_ = Phase::kFetchPhase;
}
// Get statistics from the server, and if we are not the one to create the cache,
// wait until the state has changed from the build phase to the fetch phase.
CacheClient::ServiceStat stat{};
bool BuildPhaseDone = true;
do {
RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase);
if (!BuildPhaseDone) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
} while (!BuildPhaseDone);
const row_id_type min_key = stat.min_row_id;
const row_id_type max_key = stat.max_row_id;
num_rows_ = max_key - min_key + 1;
MS_LOG(INFO) << "Number of rows cached: " << num_rows_;
MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached;
MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached;
// Now all rows are cached and we have done a sync point check up. The next phase is
// to pick up fetch requests from the sampler and pass them up to the caller.
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
return Status::OK();
}
Status CacheOp::WorkerEntry(int32_t worker_id) {
TaskManager::FindMe()->Post();
RETURN_IF_NOT_OK(CacheAllRows(worker_id));
RETURN_IF_NOT_OK(FetchFromCache(worker_id));
return Status::OK();
}
Status CacheOp::RegisterResources() {
RETURN_IF_NOT_OK(CacheBase::RegisterResources());
RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks()));
return Status::OK();
}
// Base-class override for setting specific CacheOp configurations. This code will be called
// during the execution tree prepare phase BEFORE traversing down to child operators.
uint32_t CacheOp::PrepareFlags() const { return ExecutionTree::kDePrepCache; }
// Base-class override for special eoe handler.
// CacheOp must override this because it shall not perform default handling of eoe. Instead
// the CacheOp manages actions related to the end of the epoch.
Status CacheOp::EoeReceived(int32_t worker_id) {
state_ = OpState::kDeOpIdle;
return Status::OK();
}
// Base-class override for handling cases when an eof is received.
Status CacheOp::EofReceived(int32_t worker_id) {
// EofReceived is overridden because we want to manually handle this eof.
// Specifically, the default behaviour is to pack it and flow it up to the next connector.
// In this case, we want a no-op behaviour so that we can perform the correct action.
return Status::OK();
}
// Pre-Visitor accept method for NodePass
Status CacheOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->PreRunOnNode(shared_from_base<CacheOp>(), modified);
}
// Visitor accept method for NodePass
Status CacheOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CacheOp>(), modified);
}
// A public wrapper for creating the cache through the client
Status CacheOp::CreateCache(uint32_t cache_crc) {
// This is a non-mappable cache op so the row ids need to be generated.
// Construct the cache
const bool generate_ids = true;
Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
if (rc.get_code() == StatusCode::kDuplicateKey) {
// We are told the cache has been created already. So we skip the build phase.
phase_ = Phase::kFetchPhase;
rc = Status::OK();
}
RETURN_IF_NOT_OK(rc);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
#include <atomic>
#include <string>
#include <utility>
#include <memory>
#include "dataset/engine/datasetops/cache_base_op.h"
namespace mindspore {
namespace dataset {
/// \brief CacheOp provides a memory/disk cache that acts as a save-point within a non-mappable dataset.
/// \note For mappable dataset, please see CacheLookupOp.
/// \see CacheLookupOp
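/// \par Typical placement
/// An illustrative sketch only (not a normative layout): the CacheOp sits between a repeat above it
/// and a non-mappable leaf below it, caching the leaf output on the first pass and serving it afterwards.
/// \code
///   RepeatOp
///     +-- CacheOp
///           +-- TFReaderOp   (non-mappable leaf)
/// \endcode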
class CacheOp : public CacheBase, public RandomAccessOp {
public:
// This CacheOp is for the non-mappable case, where the work is divided into two phases.
// In the first phase we cache all the rows from the child (and let the cache server
// assign the row ids). There is no read access in the first phase. Once the cache is fully built,
// we switch to the second phase and fetch requests from the sampler.
enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 };
/// \brief The nested builder class inside of the CacheOp is used to help manage all of
/// the arguments for constructing it. Use the builder by setting each argument
/// with the provided set methods, and then finally call the build method to execute
/// the actual construction.
class Builder {
public:
/// \brief Builder constructor. Creates the builder object.
/// \note No default args
Builder();
/// \brief Default destructor
~Builder() = default;
/// \brief Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetNumWorkers(int32_t num_workers) {
build_num_workers_ = num_workers;
return *this;
}
/// \brief Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetOpConnectorSize(int32_t connector_size) {
build_op_connector_size_ = connector_size;
return *this;
}
/// Setter method.
/// \return Builder setter method returns reference to the builder.
Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
build_cache_client_ = cache_client;
return *this;
}
/// \brief Setter method
/// \param rows_per_buffer
/// \return Builder setter method returns reference to the builder.
Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
rows_per_buffer_ = rows_per_buffer;
return *this;
}
/// \brief Setter method
/// \param sampler
/// \return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
build_sampler_ = std::move(sampler);
return *this;
}
/// \brief The builder "build" method creates the final object and does some init on it.
/// \param ptr The shared_ptr to the new CacheOp object
/// \return Status
Status Build(std::shared_ptr<CacheOp> *ptr);
private:
int32_t build_num_workers_;
int32_t rows_per_buffer_;
int32_t build_op_connector_size_;
std::shared_ptr<CacheClient> build_cache_client_;
std::shared_ptr<Sampler> build_sampler_;
/// \brief Check if the required parameters are set by the builder.
/// \return Status The error code return
Status SanityCheck() const;
};
/// \brief Constructor of CacheOp
/// \note The builder class should be used to call it.
/// \param num_workers The number of worker threads.
/// \param op_connector_size The size of each queue in the connector.
/// \param rows_per_buf The number of rows to pack into each DataBuffer.
/// \param cache_client The CacheClient used to communicate with the cache server.
/// \param sampler The sampler used during the fetch phase.
CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
// Destructor
~CacheOp();
/// \brief Base-class override for setting specific CacheOp configurations. This code will be called
/// during the execution tree prepare phase BEFORE traversing down to child operators.
uint32_t PrepareFlags() const override;
/// \brief Base-class override for special eoe handler.
/// CacheOp must override this because it shall not perform default handling of eoe. Instead
/// the CacheOp manages actions related to the end of the epoch.
/// \return Status - The error code return
Status EoeReceived(int32_t worker_id) override;
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for handling cases when an eof is received.
/// \param worker_id - The worker id
/// \return Status - The error code return
Status EofReceived(int32_t worker_id) override;
Status operator()() override;
Status WorkerEntry(int32_t worker_id) override;
/// \brief Base-class override indicating whether this op allows a cache miss (CacheOp does not)
bool AllowCacheMiss() override { return false; }
/// \brief Base-class override for the name of this operator
std::string Name() const override { return "CacheOp"; }
/// \brief A public wrapper for creating the cache through the client
/// \param[in] cache_crc The crc that identifies the cache
/// \see cache_pass.cc
/// \return Status return code
Status CreateCache(uint32_t cache_crc);
private:
WaitPost rows_cache_done_;
std::atomic<int64_t> num_guys_in_;
Phase phase_;
/// \brief The main thread will wait until all the rows are cached and will start the handshake with the sampler.
/// \return Status object
Status WaitForCachingAllRows();
/// \brief For non-mappable dataset, there is a build phase where we cache all the rows.
/// \return Status object
Status CacheAllRows(int32_t worker_id);
Status RegisterResources() override;
/// \brief Private function for cache setup/init work just after construction
/// \return Status The error code return
Status InitCache();
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
......@@ -61,46 +61,39 @@ void ConcatOp::Print(std::ostream &out, bool show_all) const {
Status ConcatOp::operator()() {
// The children_num_ parameter needs to be put here
children_num_ = static_cast<int32_t>(child_.size());
TaskManager::FindMe()->Post();
std::unique_ptr<DataBuffer> buf;
RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
int eof_count = 0;
while (eof_count != children_num_) {
while (eof_count == 0) {
for (int i = 0; i < children_num_; i++) {
// 1. Throw the eof buffer away when we meet it
if (buf->eof() || buf->eoe()) {
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
// 1. Read the first buffer
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
if (buf->eof()) {
eof_count++;
continue;
}
// 2. Do verification as for column name, column data type and rank of column data
RETURN_IF_NOT_OK(Verify(i, buf));
if (!buf->eoe()) {
RETURN_IF_NOT_OK(Verify(i, buf));
}
// 3. Put the data into output_connector
while (!buf->eoe() && !buf->eof()) {
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
}
// 4. Throw the eoe buffer away when we meet it
if (buf->eoe() && (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat))) {
RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
}
// 5. Add eoe buffer after get buffer from all child
if (i == (children_num_ - 1)) {
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
}
if (buf->eof()) {
eof_count++;
}
}
// 4. Add eoe buffer after get buffer from all child
if (eof_count == 0) {
auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
}
}
// 6. Add eof buffer in the end manually
CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_,
"Something went wrong, eof count does not match the number of children.");
// 5. Manually add an eof buffer at the end
MS_LOG(DEBUG) << "Add the eof buffer manually at the end.";
auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
return Status::OK();
}
......@@ -126,12 +119,6 @@ Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) {
return Status::OK();
}
Status ConcatOp::PrepareNodePostAction() {
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
tree_->AddToEOEOpStack(shared_from_this());
return Status::OK();
}
// We need to override the super class ComputeColMap here because the number of children is more than 1.
Status ConcatOp::ComputeColMap() {
if (column_name_id_map_.empty()) {
......
......@@ -75,12 +75,6 @@ class ConcatOp : public PipelineOp {
// @return Status - The error code return
Status operator()() override;
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call its superclass version first
// before providing their own implementations.
Status PrepareNodePostAction() override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "ConcatOp"; }
......
......@@ -153,16 +153,38 @@ Status DatasetOp::Remove() {
}
}
// Finally, clear "this" op's parent and child pointers since we have just
// disconnected it from the tree and invalidate its fields.
child_.clear();
parent_.clear();
operator_id_ = kInvalidOperatorId;
tree_ = nullptr;
return Status::OK();
}
// Getter function to get a shared pointer to our child
std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
std::shared_ptr<DatasetOp> return_op = nullptr;
if (child_.empty()) {
return return_op;
}
MS_ASSERT(child_index < static_cast<int>(child_.size()));
// Return a shared pointer
return child_[child_index];
}
// Getter function to get the parent pointer
void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
if (parent_.empty()) {
// common case if this is a root node
*parent = nullptr;
} else {
MS_ASSERT(parent_index < static_cast<int>(parent_.size()));
*parent = parent_[parent_index];
}
}
// Creates the connector within this operator
void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
......@@ -264,19 +286,11 @@ Status DatasetOp::EofReceived(int32_t worker_id) {
// During tree prepare phase, operators may have specific pre-operations to perform depending on
// their role.
Status DatasetOp::PrepareNodePreAction() {
if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated);
return Status::OK();
}
Status DatasetOp::PrepareNodePreAction() { return Status::OK(); }
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
Status DatasetOp::PrepareNodePostAction() {
// If this op does not have any children and it is in a repeat path of the tree...
if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
// push ourselves onto the eoe operator stack. Later, a repeat/epoch ctrl operator
// above us will consume them.
tree_->AddToEOEOpStack(shared_from_this());
}
// Creating Connector object for each op.
// The consumer of the root node is assumed to be one thread.
// If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
......@@ -346,34 +360,13 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) {
return p->RunOnNode(shared_from_this(), modified);
}
// A helper function with some common code that leaf nodes can use during
// prepare phase for checking if they need to assign a sampler to the cache.
Status DatasetOp::SaveSamplerForCache(bool random_access_op) {
// If we are a descendant under a cache op and we have a sampler, then save this sampler
// to a stack so that the cache can pick it up during it's processing above us.
if (sampler_) {
if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
// use move semantic to set our sampler_ to null after the move. This is okay because a sampler is
// useless to a random data op. It was only being used as a temporary holding until the cache can
// be created
tree_->AddToSamplerStack(sampler_);
MS_LOG(INFO) << "Preparing a leaf op: passing sampler up the tree for Cache handling.";
} else if (!random_access_op) {
// A sampler exists, but we are not in a caching tree and we are not a random access mappable leaf.
// This is an error because that type of leaf does not use sampling unless there's a cache to hook it into.
RETURN_STATUS_UNEXPECTED(
"Non-mappable leaf op has a sampler, but it only supports sampling if there is a cache after it in the tree");
}
}
if (!random_access_op) {
// Since we don't truly need the sampler for this non-mappable dataset and it's been saved for the cache
// we can remove it now from the base.
sampler_.reset();
}
// Getter for the sampler, and it also removes the sampler from the op
Status DatasetOp::FetchRemoveSampler(std::shared_ptr<Sampler> *sampler) {
*sampler = sampler_;  // It's okay if sampler_ points to nullptr
sampler_.reset(); // clear our member-copy of this pointer. We no longer have this sampler
return Status::OK();
}
uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
std::stringstream ss;
op->tree_->Print(ss, op);
......
......@@ -45,10 +45,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
public:
static constexpr int32_t kInvalidOperatorId = -1;
// Flags that control operator runtime behaviours
// Operator control flags
enum OpControlFlags {
kDeOpNone = 0,
kDeOpRepeated = 1, // Operator is a leaf node in a repeat path
kDeOpRepeated = 1, // Operator is a node in a repeat path
kDeOpLastRepeat = 1 << 1 // We are in the last repeat loop
};
......@@ -71,17 +71,23 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \param child - shared pointer to the child to remove.
Status RemoveChild(std::shared_ptr<DatasetOp> child);
/// \brief Removes this node from the tree and connects it's parent/child together.
/// \brief Removes this node from the tree and connects its parent/child together
/// \return Status error code returned
Status Remove();
/// \brief Getter function to get a shared pointer to our child
/// \param child_index - An operator can have n children. Indicates choose which child to return.
/// \param[in] child_index An operator can have n children. Indicates which child to return.
/// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index
std::shared_ptr<DatasetOp> child(int32_t child_index) const;
/// \brief Inserts an operator as the parent of the current op.
/// Inserted op will become the sole parent of the current op.
/// The existing parent of the current op will be transferred to the inserted op.
/// \brief Getter function to get the pointer to our parent
/// If there are no parents, it returns null regardless of the given index
/// \param[in] parent_index An operator can have n parents. Indicates which parent to return.
void Parent(DatasetOp **parent, int32_t parent_index) const;
// Inserts an operator as the parent of the current op.
// Inserted op will become the sole parent of the current op.
// The existing parent of the current op will be transferred to the inserted op.
Status InsertAsParent(std::shared_ptr<DatasetOp> to_add);
/// \brief Creates the connector within this operator
......@@ -161,16 +167,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Status - The error code return
virtual Status Reset();
/// \brief This calls the reset function on this subtree in pre-order
/// \return Status - The error code return
virtual Status ResetSubtree() {
RETURN_IF_NOT_OK(Reset());
for (const auto &c : child_) {
RETURN_IF_NOT_OK(c->ResetSubtree());
}
return Status::OK();
}
/// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on
/// their role.
/// \notes Derived versions of this function should always call its superclass version first
......@@ -296,7 +292,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Shared pointer to the sampler (may return nullptr)
std::shared_ptr<Sampler> sampler() { return sampler_; }
/// Computes a CRC value for the operator
/// \brief Getter for the sampler, and it also removes the sampler from the op
/// \param[out] sampler A pointer to the output sampler that was removed
/// \return Status error code
Status FetchRemoveSampler(std::shared_ptr<Sampler> *sampler);
// Computes a CRC value for the operator
static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op);
/// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
......@@ -307,17 +308,24 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
return std::static_pointer_cast<Derived>(shared_from_this());
}
protected:
/// Adds a parent operator to this operator
/// \notes External callers do not have access to this function.
/// \param parent - The parent node to add
void AddParent(DatasetOp *parent);
/// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one.
void SetSampler(std::shared_ptr<Sampler> sampler) { sampler_ = sampler; }
/// \brief Checks if this is a leaf node (0 children)
/// \return boolean returns true if it's a leaf
bool IsLeaf() { return (child_.empty()); }
/// Removes a parent operator from this operator
/// \notes External callers do not have access to this function.
/// \param parent - The parent node to remove
protected:
/// \brief Removes a parent operator from this operator
/// \notes External callers do not have access to this function
/// \param[in] parent The parent node to remove
void RemoveParent(const DatasetOp *parent);
/// \brief Adds a parent operator to this operator
/// \notes External callers do not have access to this function
/// \param[in] parent The parent node to add
void AddParent(DatasetOp *parent);
/// Compute the current op's column map using its child's column map.
/// Gets called during the tree post-prepare phase in PrepareNodePostAction.
/// This base implementation just inherits the map from child 0, and can only be used if the number of children is 1.
......@@ -325,12 +333,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return - Status
virtual Status ComputeColMap();
/// A helper function with some common code that leaf nodes can use during the
/// prepare phase for checking if they need to assign a sampler to the cache.
/// \param random_access_op - indicate if this is a mappable random access leaf or not
/// \return - Status
Status SaveSamplerForCache(bool random_access_op);
std::vector<std::shared_ptr<DatasetOp>> child_; // Child nodes
std::vector<DatasetOp *> parent_; // Parent nodes. No ownership
std::shared_ptr<Sampler> sampler_; // Some leaf ops might have a sampler
......
......@@ -77,26 +77,6 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
}
}
// Base-class override for executing specific RepeatOp configurations. This code will be called
// during the execution tree prepare phase when it is visiting this operator.
Status RepeatOp::PrepareNodePostAction() {
// Run any common code from super class first before adding our own specific logic
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromEOEOpStack();
while (leaf_op != nullptr) {
// Track the leaf operators that are under this repeat op.
eoe_ops_.push_back(leaf_op);
leaf_op = tree_->PopFromEOEOpStack();
}
// Push ourselves to the stack in case one of our ascendants is repeat too.
tree_->AddToEOEOpStack(shared_from_this());
return Status::OK();
}
// Base-class override for setting specific RepeatOp configurations. This code will be called
// during the execution tree prepare phase BEFORE traversing down to child operators.
uint32_t RepeatOp::PrepareFlags() const { return ExecutionTree::kDePrepRepeat; }
// This function returns the buffer that is at the top of our output connector. The caller is
// typically our parent node, when the parent is asking us to provide the next buffer of data.
// Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
......@@ -130,7 +110,8 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
// Base-class override for handling cases when an eoe is received.
Status RepeatOp::EoeReceived(int32_t worker_id) {
repeat_count_++;
MS_LOG(DEBUG) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
MS_LOG(DEBUG) << "Repeat operator (" << operator_id_
<< ") end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
// If we've reached the requested repeat count, then flag the eoe nodes
......@@ -149,8 +130,12 @@ Status RepeatOp::EoeReceived(int32_t worker_id) {
return Status::OK();
}
// base-class ResetSubtree
return (DatasetOp::ResetSubtree());
// Invoke a reset against the eoe nodes only.
for (auto &eoe_op : eoe_ops_) {
RETURN_IF_NOT_OK(eoe_op->Reset());
}
return Status::OK();
}
// Class functor operator () override.
......@@ -178,6 +163,18 @@ int32_t RepeatOp::num_consumers() const {
}
}
// Drive reset actions if needed
Status RepeatOp::Reset() {
// If there are nested repeats, an ascendant repeat may have us listed as one of its eoe ops.
// In that case, we now have to bounce the reset down to our own eoe ops.
MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset.";
for (auto &eoe_op : eoe_ops_) {
RETURN_IF_NOT_OK(eoe_op->Reset());
}
state_ = OpState::kDeOpRunning;
return Status::OK();
}
int32_t RepeatOp::num_producers() const {
if (child_.empty() || child_[0] == nullptr) {
MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0.";
......@@ -187,6 +184,12 @@ int32_t RepeatOp::num_producers() const {
}
}
// Pre-Visitor accept method for NodePass
Status RepeatOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call the pre-visitation
return p->PreRunOnNode(shared_from_base<RepeatOp>(), modified);
}
// Visitor accept method for NodePass
Status RepeatOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
......
......@@ -18,6 +18,7 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/engine/datasetops/pipeline_op.h"
......@@ -82,14 +83,6 @@ class RepeatOp : public PipelineOp {
// @return Status - The error code return
Status operator()() override;
// Base-class override for setting specific RepeatOp configurations. This code will be called
// during the execution tree prepare phase BEFORE traversing down to child operators.
uint32_t PrepareFlags() const override;
// Base-class override for executing specific RepeatOp configurations. This code will be called
// during the execution tree post-prepare phase when it is visiting this operator.
Status PrepareNodePostAction() override;
// This function returns the buffer that is at the top of our output connector. The caller is
// typically our parent node, when the parent is asking us to provide the next buffer of data.
// Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
......@@ -110,6 +103,10 @@ class RepeatOp : public PipelineOp {
// @param worker_id - The worker id
Status EofReceived(int32_t worker_id) override;
/// \brief Reset the op
/// \return Status - The error code return
Status Reset() override;
// Base-class override. Return the number of workers in the first parent.
// @param workerId - The worker id
int32_t num_consumers() const override;
......@@ -118,16 +115,26 @@ class RepeatOp : public PipelineOp {
// @param workerId - The worker id
int32_t num_producers() const override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
/// \brief Base-class override for NodePass pre-visit acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p The node to visit
/// \param[out] modified Indicator if the node was modified
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "RepeatOp"; }
/// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
/// \param[in] eoe_op The input leaf/eoe operator to add to the list
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
private:
int32_t max_repeats_; // The number of repeats that the user requested
int32_t repeat_count_; // A counter for the current number of executed repeats
......
......@@ -22,6 +22,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/data_schema.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/kernels/image/image_utils.h"
namespace mindspore {
......@@ -408,6 +409,12 @@ Status CelebAOp::Reset() {
return Status::OK();
}
// Visitor accept method for NodePass
Status CelebAOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CelebAOp>(), modified);
}
Status CelebAOp::ComputeColMap() {
// Set the column name map (base class field)
if (column_name_id_map_.empty()) {
......
......@@ -169,6 +169,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// @return Status - The error code return
Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer);
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p Pointer to the NodePass to be accepted
/// \param[out] modified Indicator if the node was changed at all
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const { return "CelebAOp"; }
......
......@@ -26,6 +26,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
......@@ -450,6 +451,12 @@ Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *
}
}
// Visitor accept method for NodePass
Status CifarOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CifarOp>(), modified);
}
Status CifarOp::ComputeColMap() {
// set the column name map (base class field)
if (column_name_id_map_.empty()) {
......
......@@ -155,6 +155,12 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// @return
static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p Pointer to the NodePass to be accepted
/// \param[out] modified Indicator if the node was changed at all
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "CifarOp"; }
......
......@@ -24,6 +24,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
......@@ -624,6 +625,12 @@ Status CocoOp::GetClassIndexing(const std::string &dir, const std::string &file,
return Status::OK();
}
// Visitor accept method for NodePass
Status CocoOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<CocoOp>(), modified);
}
Status CocoOp::ComputeColMap() {
// Set the column name map (base class field)
if (column_name_id_map_.empty()) {
......
......@@ -200,6 +200,12 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p Pointer to the NodePass to be accepted
/// \param[out] modified Indicator if the node was changed at all
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
......
......@@ -26,6 +26,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
......@@ -416,6 +417,12 @@ Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dic
return Status::OK();
}
// Visitor accept method for NodePass
Status ManifestOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<ManifestOp>(), modified);
}
Status ManifestOp::ComputeColMap() {
// Set the column name map (base class field)
if (column_name_id_map_.empty()) {
......
......@@ -172,6 +172,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
std::map<std::string, int32_t> *output_class_indexing);
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p Pointer to the NodePass to be accepted
/// \param[out] modified Indicator if the node was changed at all
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "ManifestOp"; }
......
......@@ -23,6 +23,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
......@@ -428,6 +429,12 @@ Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) {
return Status::OK();
}
// Visitor accept method for NodePass
Status MnistOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<MnistOp>(), modified);
}
Status MnistOp::ComputeColMap() {
// set the column name map (base class field)
if (column_name_id_map_.empty()) {
......
......@@ -152,6 +152,12 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
// @return
static Status CountTotalRows(const std::string &dir, int64_t *count);
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p Pointer to the NodePass to be accepted
/// \param[out] modified Indicator if the node was changed at all
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "MnistOp"; }
......
......@@ -22,6 +22,7 @@
#include "dataset/util/random.h"
#include "dataset/util/wait_post.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
......@@ -406,6 +407,12 @@ Status RandomDataOp::Reset() {
return Status::OK();
}
// Visitor accept method for NodePass
Status RandomDataOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<RandomDataOp>(), modified);
}
Status RandomDataOp::ComputeColMap() {
// Extract the column name mapping from the schema and save it in the class.
if (column_name_id_map_.empty()) {
......@@ -415,15 +422,5 @@ Status RandomDataOp::ComputeColMap() {
}
return Status::OK();
}
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
Status RandomDataOp::PrepareNodePostAction() {
// Run common code from super class before adding RandomDataOp specific handling
RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
// Specific handling for this op, we need to do cache op work to assign the sampler to the cache.
RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(false));
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
......@@ -203,12 +203,6 @@ class RandomDataOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return "RandomDataOp"; }
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call its superclass version first
// before providing their own implementations.
Status PrepareNodePostAction() override;
private:
/**
* The entry point code for when workers are launched
......@@ -266,6 +260,12 @@ class RandomDataOp : public ParallelOp {
return ++buffer_id_;
}
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
// Private function for computing the assignment of the column name map.
// @return - Status
Status ComputeColMap() override;
......
......@@ -1019,31 +1019,28 @@ Status TFReaderOp::ComputeColMap() {
return Status::OK();
}
// If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing
// a sampler for fetching the data. As such, any options in the tf reader need to be reset to their defaults so
// that this tf reader will produce the full set of data into the cache.
void TFReaderOp::MakeSimpleProducer() {
device_id_ = 0;
num_devices_ = 1;
total_rows_ = 0;
shuffle_files_ = false;
equal_rows_per_shard_ = false;
}
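// Illustrative sketch only (an assumption, not shown in this hunk): the cache transform pass that
// injects a cache over this leaf is expected to call MakeSimpleProducer() so that the reader stops
// sharding/shuffling and emits the full dataset into the cache, e.g.
//   leaf_op->MakeSimpleProducer();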
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
Status TFReaderOp::PrepareNodePostAction() {
// Run common code from super class before adding TFReaderOp specific handling
RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
// Specific handling for this op, we need to do cache op work so assign the sampler to the cache
// TF is a special case because it can support file-based sharding/shuffling, or, if there
// is a cache, then it can also do row-based sampler using the sampler on the cache.
// Thus, pass true for random access op flag when saving the sampler. This is a special case,
// since usually a non-mappable dataset would pass false here.
RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(true));
// Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into
// a simpler producer of all data (no shuffling or sharding or anything)
if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
device_id_ = 0;
num_devices_ = 1;
total_rows_ = 0;
shuffle_files_ = false;
equal_rows_per_shard_ = false;
sampler_.reset(); // Normally SaveSampler code did this for us, but we passed in true above (See comment)
} else {
if (!BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
// This sanity check had been delayed until now in the prepare loop.
// If we are not in a cache path, then we can validate the the file-based sharding config.
// If we are not in a cache path, then we can validate the file-based sharding config.
// If we are in a cache path, there is no file-based sharding so the check is not correct in that
// situation.
if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast<uint32_t>(num_devices_)) {
......
......@@ -246,6 +246,11 @@ class TFReaderOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return dataset_files_list_; }
/// \brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the tf reader need to be reset to their defaults so
/// that this tf reader will produce the full set of data into the cache.
void MakeSimpleProducer();
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call its superclass version first
......
......@@ -25,6 +25,7 @@
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/engine/db_connector.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
using tinyxml2::XMLDocument;
using tinyxml2::XMLElement;
......@@ -449,6 +450,11 @@ Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_t
return Status::OK();
}
// Visitor accept method for NodePass
Status VOCOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<VOCOp>(), modified);
}
Status VOCOp::ComputeColMap() {
// Set the column name map (base class field)
......
......@@ -205,6 +205,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing);
/// \brief Base-class override for NodePass visitor acceptor
/// \param[in] p Pointer to the NodePass to be accepted
/// \param[out] modified Indicator if the node was changed at all
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "VOCOp"; }
......
......@@ -127,12 +127,6 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D
return Status::OK();
}
Status TakeOp::PrepareNodePostAction() {
RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
tree_->AddToEOEOpStack(shared_from_this());
return Status::OK();
}
// Visitor accept method for NodePass
Status TakeOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
......
......@@ -78,12 +78,6 @@ class TakeOp : public PipelineOp {
// @return Status - The error code return
Status operator()() override;
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call its superclass version first
// before providing their own implementations.
Status PrepareNodePostAction() override;
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
......
......@@ -21,6 +21,8 @@
#include "dataset/util/task_manager.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/engine/opt/pre/removal_pass.h"
#include "dataset/engine/opt/pre/cache_transform_pass.h"
#include "dataset/engine/opt/post/repeat_pass.h"
#include "dataset/engine/perf/profiling.h"
#include "dataset/engine/perf/monitor.h"
......@@ -215,18 +217,33 @@ Status ExecutionTree::PrepareTreePreAction() {
bool modified = false;
std::vector<std::unique_ptr<Pass>> pre_actions;
// Construct pre actions
MS_LOG(INFO) << "Running pre pass";
pre_actions.push_back(std::make_unique<RemovalPass>(RemovalPass()));
MS_LOG(INFO) << "Running pre pass loops.";
pre_actions.push_back(std::make_unique<RemovalPass>());
pre_actions.push_back(std::make_unique<CacheTransformPass>());
// Apply pre action passes
for (auto &pass : pre_actions) {
RETURN_IF_NOT_OK(pass->Run(this, &modified));
}
MS_LOG(INFO) << "Pre passes complete.";
return Status::OK();
}
Status ExecutionTree::PrepareTreePostAction() {
// The tree is ready to be prepared.
tree_state_ = kDeTStatePrepare;
bool modified = false;
std::vector<std::unique_ptr<Pass>> post_actions;
// Construct post actions
MS_LOG(INFO) << "Running post pass loops.";
post_actions.push_back(std::make_unique<RepeatPass>());
// Apply post action passes
for (auto &pass : post_actions) {
RETURN_IF_NOT_OK(pass->Run(this, &modified));
}
MS_LOG(INFO) << "Post passes complete.";
return Status::OK();
}
......@@ -280,31 +297,5 @@ Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op)
return Status::OK();
}
// Adds an operator to the eoe operator stack during prepare phase.
void ExecutionTree::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }
// Pops an operator from the eoe operator stack during prepare phase.
std::shared_ptr<DatasetOp> ExecutionTree::PopFromEOEOpStack() {
std::shared_ptr<DatasetOp> top_op = nullptr;
if (!eoe_stack_.empty()) {
top_op = eoe_stack_.top();
eoe_stack_.pop();
}
return top_op;
}
// Adds a sampler to the sampler stack during prepare phase.
void ExecutionTree::AddToSamplerStack(std::shared_ptr<Sampler> sampler) { sampler_stack_.push(sampler); }
// Pops an operator from the sampler stack during prepare phase.
std::shared_ptr<Sampler> ExecutionTree::PopFromSamplerStack() {
std::shared_ptr<Sampler> top_sampler = nullptr;
if (!sampler_stack_.empty()) {
top_sampler = sampler_stack_.top();
sampler_stack_.pop();
}
return top_sampler;
}
} // namespace dataset
} // namespace mindspore
......@@ -200,24 +200,6 @@ class ExecutionTree {
// @return Status - The error code return
Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op);
/// Adds an operator to the eoe operator stack during prepare phase.
/// \param dataset_op - The dataset op to add to the eoe stack
/// \return Status - The error code return
void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);
/// Pops an operator from the eoe operator stack during prepare phase.
/// \return shared_ptr to the popped operator
std::shared_ptr<DatasetOp> PopFromEOEOpStack();
/// Adds a sampler to the sampler stack during prepare phase.
/// \param sampler - The sampler to add to the sampler stack
/// \return Status - The error code return
void AddToSamplerStack(std::shared_ptr<Sampler> sampler);
/// Pops an operator from the sampler stack during prepare phase.
/// \return shared_ptr to the popped operator
std::shared_ptr<Sampler> PopFromSamplerStack();
// Return the pointer to the TaskGroup
// @return raw pointer to the TaskGroup
TaskGroup *AllTasks() const { return tg_.get(); }
......@@ -248,8 +230,6 @@ class ExecutionTree {
TreeState tree_state_; // Tracking the current tree state
std::unique_ptr<Monitor> perf_monitor_; // Performance Monitor
std::unique_ptr<ProfilingManager> profiling_manager_; // Profiling manager
std::stack<std::shared_ptr<DatasetOp>> eoe_stack_; // A stack used during prepare phase
std::stack<std::shared_ptr<Sampler>> sampler_stack_; // A stack used during prepare phase
};
} // namespace dataset
} // namespace mindspore
......
......@@ -2,6 +2,9 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-opt OBJECT
pass.cc
post/repeat_pass.cc
pre/cache_pass.cc
pre/cache_transform_pass.cc
pre/removal_nodes.cc
pre/removal_pass.cc
util/printer_pass.cc
......
......@@ -16,6 +16,9 @@
#include "dataset/engine/opt/pass.h"
#include "dataset/engine/datasetops/batch_op.h"
#include "dataset/engine/datasetops/cache_op.h"
#include "dataset/engine/datasetops/cache_merge_op.h"
#include "dataset/engine/datasetops/cache_lookup_op.h"
#include "dataset/engine/datasetops/dataset_op.h"
#include "dataset/engine/datasetops/device_queue_op.h"
#include "dataset/engine/datasetops/map_op.h"
......@@ -24,8 +27,15 @@
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/skip_op.h"
#include "dataset/engine/datasetops/shuffle_op.h"
#include "dataset/engine/datasetops/source/celeba_op.h"
#include "dataset/engine/datasetops/source/cifar_op.h"
#include "dataset/engine/datasetops/source/coco_op.h"
#include "dataset/engine/datasetops/source/manifest_op.h"
#include "dataset/engine/datasetops/source/mindrecord_op.h"
#include "dataset/engine/datasetops/source/mnist_op.h"
#include "dataset/engine/datasetops/source/random_data_op.h"
#include "dataset/engine/datasetops/source/tf_reader_op.h"
#include "dataset/engine/datasetops/source/voc_op.h"
#ifdef ENABLE_PYTHON
#include "dataset/engine/datasetops/filter_op.h"
#include "dataset/engine/datasetops/source/generator_op.h"
......@@ -145,6 +155,11 @@ Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
}
#endif
Status NodePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
......@@ -164,5 +179,70 @@ Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified)
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) {
// Fallback to base class visitor by default
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
} // namespace dataset
} // namespace mindspore
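The RunOnNode overloads above implement a visitor-style double dispatch: a concrete pass overrides only the node types it cares about, and every other type falls through to the generic DatasetOp overload. A minimal sketch of such a derived pass follows; CacheCounterPass and its counter member are hypothetical and exist only to illustrate the override pattern against the NodePass interface shown here.

#include <memory>

#include "dataset/engine/opt/pass.h"
#include "dataset/engine/datasetops/cache_op.h"

namespace mindspore {
namespace dataset {
// Hypothetical pass: counts CacheOp nodes and leaves the tree untouched.
// Every other node type falls back to NodePass::RunOnNode(std::shared_ptr<DatasetOp>, ...).
class CacheCounterPass : public NodePass {
 public:
  Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override {
    ++cache_count_;     // reached only for CacheOp, via the double dispatch
    *modified = false;  // this pass does not modify the tree
    return Status::OK();
  }
  int32_t cache_count() const { return cache_count_; }

 private:
  int32_t cache_count_ = 0;
};
}  // namespace dataset
}  // namespace mindspore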
......@@ -47,6 +47,10 @@ class FilterOp;
class GeneratorOp;
#endif
class RandomDataOp;
class RepeatOp;
class TakeOp;
class ZipOp;
......@@ -55,6 +59,24 @@ class DeviceQueueOp;
class ImageFolderOp;
class CacheOp;
class MnistOp;
class ManifestOp;
class CifarOp;
class VOCOp;
class CocoOp;
class CelebAOp;
class CacheMergeOp;
class CacheLookupOp;
// The base class Pass is the basic unit of tree transformation.
// The actual implementation of the passes will be derived from here.
class Pass : public std::enable_shared_from_this<Pass> {
......@@ -138,14 +160,42 @@ class NodePass : public Pass {
virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);
#endif
virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
private:
// Helper function to perform DFS visit
Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "dataset/engine/opt/post/repeat_pass.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/cache_op.h"
#include "dataset/engine/datasetops/cache_lookup_op.h"
#include "dataset/engine/datasetops/cache_merge_op.h"
namespace mindspore {
namespace dataset {
RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(false), cache_lookup_(nullptr) {}
// Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// If we are already repeated, then this is a nested repeat.
if (is_repeated_) {
nested_repeats_++;
}
is_repeated_ = true;
return Status::OK();
}
// Identifies the subtree below this node as being in a cache merge path
Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Turn on the flag that we're under a merge op
is_merge_ = true;
return Status::OK();
}
// Hooks up any identified eoe nodes under this repeat.
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
// Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
while (leaf_op != nullptr) {
node->AddToEoeList(leaf_op);
leaf_op = PopFromEOEOpStack();
}
// If we are a repeat op in the descendant tree of a merge op, then we take the saved lookup op
// and add it to the list of eoe/leaf ops for this repeat, removing it from the save area.
if (is_merge_ && cache_lookup_) {
cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated);
node->AddToEoeList(std::move(cache_lookup_));
}
// If we are a nested repeat, then we add ourselves to the eoe op stack for the next repeat above us.
// A nested repeat acts like an eoe/leaf for the repeat above it in the tree.
if (nested_repeats_ > 0) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
AddToEOEOpStack(node);
nested_repeats_--;
}
// If we are not nested, or we were the top-most repeat, now we clear the flag
if (nested_repeats_ == 0) {
is_repeated_ = false;
}
return Status::OK();
}
// CacheOp removes previous leaf ops and replaces them with itself
Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
if (is_repeated_) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
// if we are a cache within a repeat path of the tree, then there will be
// eoe-generating ops on the eoe op stack. They are flagged as such so that the
// repeat or epoch ctrl operators can work with them for repeat activity during runtime.
// However, since a cache is present:
// - unflag those ops as being repeated ops
// - remove them from the eoe op stack so that the repeat op above in the tree won't know about them
// - add ourselves (the cache op) as an eoe op
// We do this so that those old leaf ops become one-time use (up to eoe), never repeated. Instead,
// the repeating behaviours shall be invoked against the cache op.
std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
while (leaf_op != nullptr) {
leaf_op->ClearControlFlag(DatasetOp::kDeOpLastRepeat);
leaf_op->ClearControlFlag(DatasetOp::kDeOpRepeated);
leaf_op = PopFromEOEOpStack();
}
AddToEOEOpStack(std::static_pointer_cast<DatasetOp>(node));
}
return Status::OK();
}
// All operators have a flag that might be set related to the repeat, and any leaf nodes need to be set up
// for use with a controlling repeat above them.
Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
// If we are in a repeat path, then set our repeated flag
if (is_repeated_) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
// if we are a leaf node then save ourselves on the stack for the repeat operator above us
if (node->IsLeaf()) {
AddToEOEOpStack(node);
}
}
return Status::OK();
}
// Turns off the tracking for operations under merge op
Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
// Setting the flag is needed since we didn't call the base class DatasetOp version
if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated);
is_merge_ = false;
cache_lookup_.reset(); // If a repeat op did not consume this then it's no longer needed
return Status::OK();
}
// Saves the lookup op in case it needs to be referenced by a repeat
Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) {
if (!node->IsLeaf()) {
// By definition, the CacheLookup must be a leaf op. Make that clear here.
RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!");
}
// If we are in a repeat path already, then there must be a repeat above the merge op.
// In this case, we naturally are a repeating leaf op so add the required setup for leaf ops under repeat here.
if (is_repeated_) {
node->set_control_flag(DatasetOp::kDeOpRepeated);
AddToEOEOpStack(node);
} else {
// save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
// may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourselves
// into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
}
return Status::OK();
}
// Adds an operator to the eoe operator stack save area
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }
// Pops an operator from the eoe operator stack save area
std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() {
std::shared_ptr<DatasetOp> top_op = nullptr;
if (!eoe_stack_.empty()) {
top_op = eoe_stack_.top();
eoe_stack_.pop();
}
return top_op;
}
} // namespace dataset
} // namespace mindspore
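For a concrete picture of the save-area stack, here is a hedged trace of this pass over a simple repeated pipeline. The op names are illustrative, and the visit order assumes the usual NodePass traversal: PreRunOnNode on the way down, RunOnNode on the way back up after the children.

// Tree being walked:  RepeatOp -> MapOp -> TFReaderOp (leaf)
//
// PreRunOnNode(RepeatOp)  : is_repeated_ = true, no nesting yet
// RunOnNode(TFReaderOp)   : leaf under a repeat -> set kDeOpRepeated, push onto eoe_stack_
// RunOnNode(MapOp)        : not a leaf -> only set kDeOpRepeated
// RunOnNode(RepeatOp)     : pop TFReaderOp from eoe_stack_ and AddToEoeList(TFReaderOp);
//                           nested_repeats_ == 0, so is_repeated_ is cleared again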
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_
#define DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_
#include <memory>
#include <stack>
#include <utility>
#include "dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
/// \class RepeatPass repeat_pass.h
/// \brief This is a NodePass whose job is to perform setup actions for RepeatOps. A RepeatOp needs to have references
/// to the eoe-producing (typically leaf) nodes underneath it.
class RepeatPass : public NodePass {
public:
/// \brief Constructor
RepeatPass();
/// \brief Identifies the subtree below this node as being in a repeated path of the tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
/// \brief Identifies the subtree below this node as being in a cache merge path
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;
/// \brief Hooks up any identified eoe nodes under this repeat.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;
/// \brief CacheOp removes previous leaf ops and replaces them with itself
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Turns off the tracking for operations under a merge op
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;
/// \brief Saves the lookup op in case it needs to be referenced by a repeat
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override;
/// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
/// for use with a controlling repeat above it.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override;
private:
/// \brief Adds an operator to the eoe operator stack save area
/// \param dataset_op - The dataset op to add to the eoe stack
void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);
/// \brief Pops an operator from the eoe operator stack save area
/// \return shared_ptr to the popped operator
std::shared_ptr<DatasetOp> PopFromEOEOpStack();
bool is_repeated_; // T/F if we are processing under a repeat
bool is_merge_; // T/F if we are processing under a cache merge op
int32_t nested_repeats_; // A counter for nested repeats
std::stack<std::shared_ptr<DatasetOp>> eoe_stack_; // A save area for leaf/eoe ops
std::shared_ptr<DatasetOp> cache_lookup_; // A save area for a cache lookup op
};
} // namespace dataset
} // namespace mindspore
#endif // DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_
(This diff has been collapsed.)
(This diff has been collapsed.)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "dataset/engine/opt/pre/cache_pass.h"
#include "dataset/engine/opt/pre/cache_transform_pass.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/datasetops/cache_lookup_op.h"
#include "dataset/engine/datasetops/cache_merge_op.h"
#include "dataset/engine/datasetops/cache_op.h"
namespace mindspore {
namespace dataset {
// constructor
CacheTransformPass::CacheTransformPass() {}
// Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) {
MS_LOG(INFO) << "Pre pass: Cache transform pass started.";
// Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will
// use to execute a transform.
std::unique_ptr<Pass> cache_pass = std::make_unique<CachePass>(this);
RETURN_IF_NOT_OK(cache_pass->Run(tree, modified));
// Then, execute the transform for each pair
for (auto cache_pair : cache_pairs_) {
MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform.";
RETURN_IF_NOT_OK(ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client()));
}
MS_LOG(INFO) << "Pre pass: Cache transform pass complete.";
return Status::OK();
}
// Helper function to execute the cache transformation.
Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<DatasetOp> cache_op,
std::shared_ptr<CacheClient> cache_client) {
// Get local pointers to the child/parent of the cache op. It's possible that the parent is null if the cache was
// the root node. It is also possible that cache_child == leaf_op
std::shared_ptr<DatasetOp> cache_child = cache_op->child(0);
DatasetOp *cache_parent = nullptr;
cache_op->Parent(&cache_parent, 0); // fetch the cache op's parent
// Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later.
std::shared_ptr<Sampler> leaf_sampler = leaf_op->sampler();
// Construct the merge op with defaults
std::shared_ptr<CacheMergeOp> merge_op;
CacheMergeOp::Builder merge_builder;
RETURN_IF_NOT_OK(merge_builder.SetClient(cache_client).Build(&merge_op));
RETURN_IF_NOT_OK(tree->AssociateNode(merge_op));
// Construct the cache lookup op with defaults
std::shared_ptr<CacheLookupOp> cache_lookup_op;
CacheLookupOp::Builder lookup_builder;
RETURN_IF_NOT_OK(lookup_builder.SetClient(cache_client).SetSampler(std::move(leaf_sampler)).Build(&cache_lookup_op));
RETURN_IF_NOT_OK(tree->AssociateNode(cache_lookup_op));
// Overwrite the old sampler in this leaf op to become the lookup op
leaf_op->SetSampler(cache_lookup_op);
// If the cache had a parent, then go into that parent to remove the cache from its child list and then
// replace it with the merge op.
if (cache_parent != nullptr) {
RETURN_IF_NOT_OK(cache_parent->RemoveChild(cache_op));
RETURN_IF_NOT_OK(cache_parent->AddChild(merge_op));
} else {
// If we didn't have a parent, then the merge op is the root node
RETURN_IF_NOT_OK(tree->AssignRoot(merge_op));
}
// Set the cache op to no longer be a parent over its child. This will fully disconnect the old cache op.
// We maintain a local pointer to the old child though.
RETURN_IF_NOT_OK(cache_op->RemoveChild(cache_child));
// Connect the merge op
RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_lookup_op)));
RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_child)));
// At this point, the cache op has already had its children and parents taken away. Calling remove
// on it at this point will not do any node hookups, and instead set internal fields to invalid.
RETURN_IF_NOT_OK(cache_op->Remove());
return Status::OK();
}
// Assigns the leaf and cache operators that are involved in a cache transformation
void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op,
std::shared_ptr<CacheOp> cache_op) {
cache_pairs_.push_back(std::make_pair(leaf_op, cache_op));
}
} // namespace dataset
} // namespace mindspore
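A hedged before/after sketch of the rewiring ExecuteCacheTransform performs for a mappable cache; the parent and leaf names are placeholders, and the subtree between the cache and the leaf is whatever the pipeline built.

// Before the transform:
//   SomeParentOp -> CacheOp -> ... -> LeafOp(sampler = S)
//
// After the transform:
//   SomeParentOp -> CacheMergeOp
//                     +-- child 0: CacheLookupOp (takes over the leaf's original sampler S)
//                     +-- child 1: ... -> LeafOp (its sampler is now the CacheLookupOp)
//
// The old CacheOp is disconnected from both parent and child and removed from the tree.
// If the CacheOp was the root, the CacheMergeOp becomes the new root instead.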
......@@ -34,6 +34,18 @@ class RemovalNodes : public NodePass {
/// \param[in] removal_pass Raw pointer back to controlling tree pass
explicit RemovalNodes(RemovalPass *removal_pass);
/// \brief Identifies the subtree below this node as a cached descendant tree.
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Resets the tracking of the cache within the tree
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
/// \brief Perform ShuffleOp removal check
/// \param[in] node The node being visited
/// \param[inout] modified Indicator if the node was changed at all
......
......@@ -24,6 +24,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset
TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset
from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
WeightedRandomSampler, Sampler
from .engine.cache_client import DatasetCache
from .engine.serializer_deserializer import serialize, deserialize, show
from .engine.graphdata import GraphData
......
(This diff has been collapsed.)
(This diff has been collapsed.)
(This diff has been collapsed.)
(This diff has been collapsed.)
(This diff has been collapsed.)
(This diff has been collapsed.)