Commit 8a08d0c3 authored by: J Jesse Lee

Phase 2 of CacheOp

Parent b11ef57b
@@ -24,6 +24,11 @@ if (ENABLE_TDTQUE)
add_definitions(-D ENABLE_TDTQUE)
message(STATUS "TDT queue is enabled")
endif ()
if (MS_BUILD_GRPC)
set (ENABLE_CACHE true)
add_definitions(-D ENABLE_CACHE)
message(STATUS "Cache is enabled")
endif()
# code coverage
# option(ENABLE_COVERAGE "Enable code coverage report" OFF)
@@ -47,10 +52,6 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
################## Include sub-modules ###############################
add_subdirectory(util)
add_subdirectory(core)
@@ -70,8 +71,6 @@ add_dependencies(engine-datasetops-source-sampler core)
add_dependencies(engine-datasetops core)
add_dependencies(engine-datasetops-mapop core)
add_dependencies(engine-opt core)
add_dependencies(engine-cache-client core)
add_dependencies(engine-cache-server core)
add_dependencies(engine-perf core)
add_dependencies(engine-gnn core)
add_dependencies(engine core)
@@ -85,7 +84,11 @@ endif()
if (ENABLE_TDTQUE)
add_dependencies(engine-tdt core)
endif ()
if (ENABLE_CACHE)
add_dependencies(engine-datasetops engine-cache-client)
add_dependencies(engine-cache-client core)
add_dependencies(engine-cache-server core)
endif ()
################### Create _c_dataengine Library ######################
set(submodules
$<TARGET_OBJECTS:core>
@@ -105,7 +108,6 @@ set(submodules
$<TARGET_OBJECTS:engine-datasetops>
$<TARGET_OBJECTS:engine-opt>
$<TARGET_OBJECTS:engine-cache-client>
$<TARGET_OBJECTS:engine-cache-server>
$<TARGET_OBJECTS:engine>
$<TARGET_OBJECTS:text>
$<TARGET_OBJECTS:text-kernels>
@@ -123,8 +125,6 @@ else ()
add_library(_c_dataengine SHARED ${submodules})
endif ()
add_dependencies(_c_dataengine generated_engine_files)
if (ENABLE_PYTHON)
set_target_properties(_c_dataengine PROPERTIES
PREFIX "${PYTHON_MODULE_PREFIX}"
@@ -187,6 +187,6 @@ else()
endif ()
endif()
if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
if (MS_BUILD_GRPC)
target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++)
endif()
\ No newline at end of file
endif()
@@ -22,7 +22,25 @@ namespace dataset {
PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) {
(void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
.def(py::init<uint32_t, uint64_t, bool>());
.def(
py::init([](session_id_type id, uint64_t mem_sz, bool spill, int32_t port, int32_t prefetch_sz) {
std::shared_ptr<CacheClient> cc;
CacheClient::Builder builder;
builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPort(port).SetPrefetchSize(
prefetch_sz);
THROW_IF_ERROR(builder.Build(&cc));
return cc;
}))
.def("GetStat", [](CacheClient &cc) {
CacheServiceStat stat{};
THROW_IF_ERROR(cc.GetStat(&stat));
return stat;
});
(void)py::class_<CacheServiceStat>(*m, "CacheServiceStat")
.def(py::init<>())
.def_readwrite("avg_cache_sz", &CacheServiceStat::avg_cache_sz)
.def_readwrite("num_mem_cached", &CacheServiceStat::num_mem_cached)
.def_readwrite("num_disk_cached", &CacheServiceStat::num_disk_cached);
}));
} // namespace dataset
@@ -72,7 +72,8 @@ constexpr uint32_t kCfgMonitorSamplingInterval = 10;
// Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
constexpr uint8_t kCVInvalidType = 255;
using connection_id_type = int64_t;
using connection_id_type = uint64_t;
using session_id_type = uint32_t;
using row_id_type = int64_t;
} // namespace dataset
} // namespace mindspore
@@ -20,10 +20,8 @@ if (ENABLE_PYTHON)
target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
endif()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf engine-cache-client engine-datasetops-mapop)
if (ENABLE_TDTQUE)
add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf
engine-cache-client engine-cache-server engine-datasetops-mapop)
else ()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf
engine-cache-client engine-cache-server engine-datasetops-mapop)
add_dependencies(engine engine-tdt)
endif ()
include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-cache-client OBJECT
cache_client.cc
cache_fbb.cc
cache_request.cc)
add_library(engine-cache-server OBJECT
cache_service.cc
cache_server.cc)
if (ENABLE_CACHE)
ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto)
target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} cache_grpc_client.cc)
add_library(engine-cache-server OBJECT
${CACHE_GRPC_SRCS}
cache_grpc_server.cc
cache_arena.cc
cache_service.cc
cache_server.cc)
add_executable(cache_server cache_main.cc)
target_link_libraries(cache_server
engine-cache-server
$<TARGET_OBJECTS:utils>
mindspore
mindspore::glog
mindspore::protobuf
mindspore::grpc++
mindspore_gvar
${PYTHON_LIBRARIES}
${SECUREC_LIBRARY}
pthread)
add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES} mindspore::glog)
add_dependencies(engine-cache-server generated_engine_files)
else ()
ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PROTO_HDRS cache_grpc.proto)
target_sources(engine-cache-client PUBLIC ${CACHE_PROTO_SRCS})
endif ()
add_dependencies(engine-cache-client generated_engine_files)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <unistd.h>
#include <iostream>
#ifdef USE_GLOG
#include <glog/logging.h>
#endif
#include "minddata/dataset/engine/cache/cache_admin_arg.h"
namespace ds = mindspore::dataset;
int main(int argc, char **argv) {
ds::Status rc;
ds::CacheAdminArgHandler args;
std::stringstream arg_stream;
#ifdef USE_GLOG
FLAGS_log_dir = "/tmp";
google::InitGoogleLogging(argv[0]);
#endif
std::string warningMsg;
warningMsg.reserve(512);
warningMsg += "WARNING:\n";
warningMsg += "cache_admin and the cache server that it controls are currently only used for experimental research";
warningMsg += " purposes at this time.\n";
warningMsg += "It is not intended for general availability yet as it may not be stable. Use it at your own risk.\n";
// A warning message until the code is mature enough.
std::cerr << warningMsg << std::endl;
if (argc == 1) {
args.Help();
return 0;
}
// ingest all the args into a string stream for parsing
for (int i = 1; i < argc; ++i) {
arg_stream << " " << std::string(argv[i]);
}
// Parse the args
rc = args.ParseArgStream(&arg_stream);
if (!rc.IsOk()) {
std::cerr << rc.ToString() << std::endl;
return 1;
}
// Execute the command
rc = args.RunCommand();
if (!rc.IsOk()) {
std::cerr << rc.ToString() << std::endl;
return 1;
}
return 0;
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_admin_arg.h"
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <unistd.h>
#include <cerrno>
#include <iostream>
#include <string>
#include <cstdlib>
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
const char CacheAdminArgHandler::kDefaultHost[] = "127.0.0.1";
const char CacheAdminArgHandler::kServerBinary[] = "cache_server";
const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp";
CacheAdminArgHandler::CacheAdminArgHandler()
: port_(kDefaultPort),
session_id_(0),
num_workers_(kDefaultNumWorkers),
shm_mem_sz_(kDefaultSharedMemorySizeInGB),
log_level_(kDefaultLogLevel),
hostname_(kDefaultHost),
spill_dir_(kDefaultSpillDir),
command_id_(CommandId::kCmdUnknown) {
// Initialize the command mappings
arg_map_["-h"] = ArgValue::kArgHost;
arg_map_["--hostname"] = ArgValue::kArgHost;
arg_map_["-p"] = ArgValue::kArgPort;
arg_map_["--port"] = ArgValue::kArgPort;
arg_map_["--start"] = ArgValue::kArgStart;
arg_map_["--stop"] = ArgValue::kArgStop;
arg_map_["--help"] = ArgValue::kArgHelp;
arg_map_["--generate_session"] = ArgValue::kArgGenerateSession;
arg_map_["-g"] = ArgValue::kArgGenerateSession;
arg_map_["--destroy_session"] = ArgValue::kArgDestroySession;
arg_map_["-d"] = ArgValue::kArgDestroySession;
arg_map_["--spilldir"] = ArgValue::kArgSpillDir;
arg_map_["-s"] = ArgValue::kArgSpillDir;
arg_map_["-w"] = ArgValue::kArgNumWorkers;
arg_map_["--workers"] = ArgValue::kArgNumWorkers;
arg_map_["-m"] = ArgValue::kArgSharedMemorySize;
arg_map_["--shared_memory_size"] = ArgValue::kArgSharedMemorySize;
arg_map_["-l"] = ArgValue::kArgLogLevel;
arg_map_["--minloglevel"] = ArgValue::kArgLogLevel;
// Initialize argument tracker with false values
for (int16_t i = 0; i < static_cast<int16_t>(ArgValue::kArgNumArgs); ++i) {
ArgValue currAV = static_cast<ArgValue>(i);
used_args_[currAV] = false;
}
}
Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
CommandId command_id) {
// Detect if the user tried to provide this argument more than once
ArgValue selected_arg = arg_map_[option];
if (used_args_[selected_arg]) {
std::string err_msg = "The " + option + " argument was given more than once.";
return Status(StatusCode::kSyntaxError, err_msg);
}
// Flag that this arg is used now
used_args_[selected_arg] = true;
// Some options are just arguments, for example "--port 50052" is not a command, it's just a argument.
// Other options are actual commands, for example "--destroy_session 1234". This executes the destroy session.
// If this option is also a command, make sure there has not been multiple commands given before assigning it.
if (command_id != CommandId::kCmdUnknown) {
if (command_id_ != CommandId::kCmdUnknown) {
std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
return Status(StatusCode::kSyntaxError, err_msg);
} else {
command_id_ = command_id;
}
}
std::string value_as_string;
// Fetch the argument from the arg stream into a string
*arg_stream >> value_as_string;
if (value_as_string.empty()) {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg);
}
// Now, attempt to convert the value from its string form into a numeric value
try {
*out_arg = std::stoul(value_as_string);
} catch (const std::exception &e) {
std::string err_msg = "Invalid numeric value: " + value_as_string;
return Status(StatusCode::kSyntaxError, err_msg);
}
return Status::OK();
}
Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
CommandId command_id) {
// Detect if the user tried to provide this argument more than once
ArgValue selected_arg = arg_map_[option];
if (used_args_[selected_arg]) {
std::string err_msg = "The " + option + " argument was given more than once.";
return Status(StatusCode::kSyntaxError, err_msg);
}
// Flag that this arg is used now
used_args_[selected_arg] = true;
// Some options are just arguments, for example "--hostname "127.0.0.1" is not a command, it's just an argument.
// Other options are actual commands, for example "--start".
// If this option is also a command, make sure there has not been multiple commands given before assigning it.
if (command_id != CommandId::kCmdUnknown) {
if (command_id_ != CommandId::kCmdUnknown) {
std::string err_msg = "Only one command at a time is allowed. Invalid command: " + option;
return Status(StatusCode::kSyntaxError, err_msg);
} else {
command_id_ = command_id;
}
}
// If there is no argument to get, such as the --start command, then out_arg will be a nullptr.
if (out_arg != nullptr) {
// Fetch the argument from the arg stream into a string
*arg_stream >> *out_arg;
if (out_arg->empty()) {
std::string err_msg = option + " option requires an argument field. Syntax: " + option + " <field>";
return Status(StatusCode::kSyntaxError, err_msg);
}
}
return Status::OK();
}
Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) {
std::string tok;
while (*arg_stream >> tok) {
switch (arg_map_[tok]) {
case ArgValue::kArgHost: {
RETURN_IF_NOT_OK(AssignArg(tok, &hostname_, arg_stream));
break;
}
case ArgValue::kArgPort: {
RETURN_IF_NOT_OK(AssignArg(tok, &port_, arg_stream));
break;
}
case ArgValue::kArgStart: {
RETURN_IF_NOT_OK(AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdStart));
break;
}
case ArgValue::kArgStop: {
RETURN_IF_NOT_OK(AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdStop));
break;
}
case ArgValue::kArgGenerateSession: {
RETURN_IF_NOT_OK(
AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdGenerateSession));
break;
}
case ArgValue::kArgHelp: {
command_id_ = CommandId::kCmdHelp;
break;
}
case ArgValue::kArgDestroySession: {
// session_id is an unsigned type. We may need to template the AssignArg function so that
// it can handle different flavours of integers instead of just int32_t.
int32_t session_int;
RETURN_IF_NOT_OK(AssignArg(tok, &session_int, arg_stream, CommandId::kCmdDestroySession));
session_id_ = session_int;
break;
}
case ArgValue::kArgNumWorkers: {
RETURN_IF_NOT_OK(AssignArg(tok, &num_workers_, arg_stream));
break;
}
case ArgValue::kArgSpillDir: {
RETURN_IF_NOT_OK(AssignArg(tok, &spill_dir_, arg_stream));
break;
}
case ArgValue::kArgSharedMemorySize: {
RETURN_IF_NOT_OK(AssignArg(tok, &shm_mem_sz_, arg_stream));
break;
}
case ArgValue::kArgLogLevel: {
RETURN_IF_NOT_OK(AssignArg(tok, &log_level_, arg_stream));
break;
}
default: {
// Save space delimited trailing arguments
trailing_args_ += (" " + tok);
break;
}
}
}
RETURN_IF_NOT_OK(Validate());
return Status::OK();
}
Status CacheAdminArgHandler::Validate() {
// This sanity check is delayed until now in case there may be valid use-cases of trailing args.
// Any unhandled arguments at this point are an error.
if (!trailing_args_.empty()) {
std::string err_msg = "Invalid arguments provided: " + trailing_args_;
return Status(StatusCode::kSyntaxError, err_msg);
}
// The user must pick at least one command. i.e. it's meaningless to just give a hostname or port but no command to
// run.
if (command_id_ == CommandId::kCmdUnknown) {
std::string err_msg = "No command provided";
return Status(StatusCode::kSyntaxError, err_msg);
}
// Additional checks here
if (num_workers_ < 1) return Status(StatusCode::kSyntaxError, "Number of workers must be a positive value.");
if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3).");
// port range check?
return Status::OK();
}
Status CacheAdminArgHandler::RunCommand() {
switch (command_id_) {
case CommandId::kCmdHelp: {
Help();
break;
}
case CommandId::kCmdStart: {
RETURN_IF_NOT_OK(StartServer());
break;
}
case CommandId::kCmdStop: {
RETURN_IF_NOT_OK(StopServer());
break;
}
case CommandId::kCmdGenerateSession: {
CacheClientGreeter comm(hostname_, port_, 1);
RETURN_IF_NOT_OK(comm.ServiceStart());
auto rq = std::make_shared<GenerateSessionIdRequest>();
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
std::cout << rq->GetSessionId() << std::endl;
break;
}
case CommandId::kCmdDestroySession: {
CacheClientGreeter comm(hostname_, port_, 1);
RETURN_IF_NOT_OK(comm.ServiceStart());
CacheClientInfo cinfo;
cinfo.set_session_id(session_id_);
auto rq = std::make_shared<DropSessionRequest>(cinfo);
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
std::cout << "Drop session successful" << std::endl;
break;
}
default: {
RETURN_STATUS_UNEXPECTED("Invalid cache admin command id.");
break;
}
}
return Status::OK();
}
Status CacheAdminArgHandler::StartServer() {
// There is currently no "install path" or method to identify the path where the installed binaries will
// exist. As a temporary approach, we will assume that the server binary shall exist in the same path as the
// cache_admin binary (this process).
const std::string self_proc = "/proc/self/exe";
std::string canonical_path;
canonical_path.resize(400); // PATH_MAX is large. This value should be big enough for our use.
// Some lower level OS library calls are needed here to determine the binary path.
// Fetch the path of the cache_admin binary (this process) into a C character array and then truncate off the
// binary name so that we are left with only the absolute path
if (realpath(self_proc.data(), canonical_path.data()) == nullptr) {
std::string err_msg = "Failed to identify cache admin binary path: " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(err_msg);
}
canonical_path.resize(strlen(canonical_path.data()));
auto last_separator = canonical_path.find_last_of('/');
CHECK_FAIL_RETURN_UNEXPECTED(last_separator != std::string::npos, "No / found");
// truncate the binary name so we are left with the absolute path of the cache_admin binary
canonical_path.resize(last_separator + 1);
std::string cache_server_binary = canonical_path + std::string(kServerBinary);
// Create a pipe before we fork. If all goes well, the child will run as a daemon in the background
// and never returns until shutdown. If there is any error, the child will notify us through the pipe.
int fd[2];
if (pipe(fd) == -1) {
std::string err_msg = "Failed to create a pipe for communication " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(err_msg);
}
// fork the child process to become the daemon
pid_t pid;
pid = fork();
// failed to fork
if (pid < 0) {
std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(err_msg);
} else if (pid > 0) {
// As a parent, we close the write end. We only listen.
close(fd[1]);
dup2(fd[0], 0);
close(fd[0]);
wait(nullptr);
std::string msg;
const int32_t buf_sz = 1024;
msg.resize(buf_sz);
auto n = read(0, msg.data(), buf_sz);
if (n < 0) {
std::string err_msg = "Failed to read from pipeline " + std::to_string(errno);
RETURN_STATUS_UNEXPECTED(err_msg);
}
msg.resize(n);
std::cout << msg << std::endl;
return Status::OK();
} else {
// Child here ...
// Close all stdin, redirect stdout and stderr to the write end of the pipe.
close(fd[0]);
dup2(fd[1], 1);
dup2(fd[1], 2);
close(0);
close(fd[1]);
// exec the cache server binary in this process
std::string port_string = std::to_string(port_);
std::string workers_string = std::to_string(num_workers_);
std::string shared_memory_string = std::to_string(shm_mem_sz_);
std::string minloglevel_string = std::to_string(log_level_);
std::string daemonize_string = "true";
char *argv[8];
argv[0] = cache_server_binary.data(); // First arg is usually the binary name
argv[1] = spill_dir_.data();
argv[2] = workers_string.data();
argv[3] = port_string.data();
argv[4] = shared_memory_string.data();
argv[5] = minloglevel_string.data();
argv[6] = daemonize_string.data();
argv[7] = nullptr;
// Now exec the binary
execv(argv[0], argv);
// If the exec was successful, this line will never be reached due to process image being replaced.
// ..unless exec failed.
std::string err_msg = "Failed to exec cache server: " + cache_server_binary;
std::cerr << err_msg << std::endl;
RETURN_STATUS_UNEXPECTED(err_msg);
}
}
Status CacheAdminArgHandler::StopServer() {
CacheClientGreeter comm(hostname_, port_, 1);
RETURN_IF_NOT_OK(comm.ServiceStart());
auto rq = std::make_shared<ShutdownRequest>();
RETURN_IF_NOT_OK(comm.HandleRequest(rq));
return Status::OK();
}
void CacheAdminArgHandler::Help() {
std::cerr << "Syntax:\n";
std::cerr << " cache_admin [--start | --stop]\n";
std::cerr << " [ [-h | --hostname] <hostname> ]\n";
std::cerr << " [ [-p | --port] <port number> ]\n";
std::cerr << " [ [-g | --generate_session] ]\n";
std::cerr << " [ [-d | --destroy_session] <session id> ]\n";
std::cerr << " [ [-w | --workers] <number of workers> ]\n";
std::cerr << " [ [-s | --spilldir] <spilling directory> ]\n";
std::cerr << " [ [-m | --shared_memory_size] <shared memory size> ]\n";
std::cerr << " [ [-l | --minloglevel] <log level> ]\n";
std::cerr << " [--help]" << std::endl;
}
} // namespace dataset
} // namespace mindspore
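For reference, a few representative invocations of the admin binary, pieced together from the Help() text and the defaults above (a sketch; the session id 1234 is a placeholder):

```
cache_admin --start -w 32 -s /tmp -m 4
cache_admin --generate_session
cache_admin --destroy_session 1234
cache_admin --stop
```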
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <sstream>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/cache/cache_client.h"
namespace mindspore {
namespace dataset {
class CacheAdminArgHandler {
public:
static constexpr int32_t kDefaultPort = 50052;
static constexpr int32_t kDefaultNumWorkers = 32;
static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
static constexpr int32_t kDefaultLogLevel = 1;
static const char kDefaultHost[];
static const char kServerBinary[];
static const char kDefaultSpillDir[];
// These are the actual command types to execute
enum class CommandId : int16_t {
kCmdHelp = 0,
kCmdStart = 1,
kCmdStop = 2,
kCmdGenerateSession = 3,
kCmdDestroySession = 4,
kCmdUnknown = 32767
};
CacheAdminArgHandler();
~CacheAdminArgHandler() = default;
Status ParseArgStream(std::stringstream *arg_stream);
Status RunCommand();
void Help();
private:
// These are the supported argument string integer mappings
enum class ArgValue : int16_t {
kArgUnknown = 0, // Must be at position 0. invalid map lookups in arg_map_ default to value 0
kArgStart = 1,
kArgStop = 2,
kArgHost = 3,
kArgPort = 4,
kArgHelp = 5,
kArgGenerateSession = 6,
kArgDestroySession = 7,
kArgSpillDir = 8,
kArgNumWorkers = 9,
kArgSharedMemorySize = 10,
kArgLogLevel = 11,
kArgNumArgs = 12 // Must be the last position to provide a count
};
Status StartServer();
Status StopServer();
Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown);
Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
CommandId command_id = CommandId::kCmdUnknown);
Status Validate();
CommandId command_id_;
int32_t port_;
int32_t num_workers_;
int32_t shm_mem_sz_;
int32_t log_level_;
session_id_type session_id_;
std::string hostname_;
std::string spill_dir_;
std::string trailing_args_;
std::map<std::string, ArgValue> arg_map_;
std::map<ArgValue, bool> used_args_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB)
: Arena::Arena(val_in_GB * 1024), port_(port), shmid_(-1) {}
CachedSharedMemoryArena::~CachedSharedMemoryArena() {
#if CACHE_LOCAL_CLIENT
if (this->ptr_ != nullptr && this->ptr_ != reinterpret_cast<void *>(-1)) {
shmdt(this->ptr_);
}
this->ptr_ = nullptr;
if (shmid_ != -1) {
shmctl(shmid_, IPC_RMID, nullptr);
// Also remove the path we use to generate ftok.
Path p(PortToUnixSocketPath(port_));
(void)p.Remove();
}
#endif
}
Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
size_t val_in_GB) {
RETURN_UNEXPECTED_IF_NULL(out);
#if CACHE_LOCAL_CLIENT
auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
if (ba == nullptr) {
return Status(StatusCode::kOutOfMemory);
}
// Transfer the ownership of this pointer. Any future error in the processing will be handled by
// the destructor of *out.
(*out).reset(ba);
// Generate the ftok key from the port.
int err;
auto shm_key = PortToFtok(port, &err);
if (shm_key == (key_t)-1) {
std::string errMsg = "Ftok failed with errno " + std::to_string(err);
RETURN_STATUS_UNEXPECTED(errMsg);
}
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
ba->shmid_ = shmget(shm_key, ba->size_in_bytes_, IPC_CREAT | IPC_EXCL | access_mode);
if (ba->shmid_ != -1) {
ba->ptr_ = shmat(ba->shmid_, nullptr, 0);
if (ba->ptr_ == reinterpret_cast<void *>(-1)) {
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
}
} else {
RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno));
}
uint64_t num_blks = ba->size_in_bytes_ / ARENA_BLK_SZ;
MS_LOG(DEBUG) << "Size of memory pool is " << num_blks << ", number of blocks of size is " << ARENA_BLK_SZ << ".";
ba->tr_.Insert(0, num_blks);
#endif
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
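A minimal usage sketch of CreateArena, assuming the default port and shared-memory size defined in CacheAdminArgHandler (the call site itself is hypothetical):

```cpp
#include "minddata/dataset/engine/cache/cache_arena.h"

using mindspore::dataset::CachedSharedMemoryArena;
using mindspore::dataset::Status;

// Hypothetical call site: 50052 and 4 GB are the defaults used elsewhere in this commit.
std::unique_ptr<CachedSharedMemoryArena> arena;
Status rc = CachedSharedMemoryArena::CreateArena(&arena, /*port=*/50052, /*val_in_GB=*/4);
if (rc.IsOk()) {
  // Local clients exchange offsets relative to this base address, never raw pointers.
  const void *base = arena->SharedMemoryBaseAddr();
}
```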
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
#include <memory>
#include <string>
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/engine/cache/cache_common.h"
namespace mindspore {
namespace dataset {
/// This is a derived class of Arena but resides in shared memory
class CachedSharedMemoryArena : public Arena {
public:
~CachedSharedMemoryArena() override;
/// \brief Create an Arena in shared memory
/// \param[out] out Pointer to a unique_ptr that receives the arena
/// \param[in] port TCP/IP port, used to derive the shared memory key
/// \param[in] val_in_GB size of shared memory in gigabytes
/// \return Status object
static Status CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, size_t val_in_GB);
/// \brief This returns where we attach to the shared memory.
/// Some gRPC requests will ask for a shared memory block, and
/// we can't return the absolute address as this makes no sense
/// in the client. So instead we will return an address relative
/// to the base address of the shared memory where we attach to.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return this->ptr_; }
private:
int32_t port_;
int shmid_;
/// Private constructor. Not to be called directly.
CachedSharedMemoryArena(int32_t port, size_t val_in_GB);
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
@@ -17,29 +17,45 @@
#include <iomanip>
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
#include "minddata/dataset/util/bit.h"
namespace mindspore {
namespace dataset {
// Constructor
CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill)
: server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {}
CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname,
int32_t port, int32_t num_workers, int32_t prefetch_size)
: server_connection_id_(0),
cache_mem_sz_(cache_mem_sz),
spill_(spill),
local_bypass_(false),
hostname_(std::move(hostname)),
port_(port),
num_workers_(num_workers),
prefetch_size_(prefetch_size) {
cinfo_.set_session_id(session_id);
comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_workers_);
}
// Print method for displaying cache details
void CacheClient::Print(std::ostream &out) const {
out << " Session id: " << session_id_ << "\n Cache crc: " << cache_crc_
<< "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << cache_mem_sz_
<< "\n Spilling: " << std::boolalpha << spill_;
out << " Session id: " << session_id() << "\n Cache crc: " << cinfo_.crc()
<< "\n Server cache id: " << server_connection_id_ << "\n Cache mem size: " << getCacheMemSz()
<< "\n Spilling: " << std::boolalpha << isSpill() << "\n Hostname: " << getHostname()
<< "\n Port: " << getPort() << "\n Number of rpc workers: " << getNumWorkers()
<< "\n Prefetch size: " << getPrefetchSize() << "\n Local client support: " << std::boolalpha
<< SupportLocalClient();
}
Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
CacheRowRequest rq(server_connection_id_, cookie());
RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row));
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
auto rq = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(this, row));
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
if (row_id_from_server != nullptr) {
*row_id_from_server = rq.GetRowIdAfterCache();
*row_id_from_server = rq->GetRowIdAfterCache();
}
return Status::OK();
}
@@ -47,29 +63,19 @@ Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_serv
Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
std::unique_ptr<DataBuffer> db_ptr = std::move(in);
auto num_rows = db_ptr->NumRows();
std::vector<TensorRow> all_rows;
// We will send the requests async first on all rows and do a final wait.
if (num_rows > 0) {
all_rows.reserve(num_rows);
// Break down the DataBuffer into TensorRow. We will send the requests async
// and then do a final wait.
MemGuard<CacheRowRequest> rq_arr;
RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie()));
CacheServer &cs = CacheServer::GetInstance();
auto arr = std::make_unique<std::shared_ptr<CacheRowRequest>[]>(num_rows);
for (auto i = 0; i < num_rows; ++i) {
TensorRow row;
auto rq = rq_arr[i];
RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row));
RETURN_IF_NOT_OK(cs.PushRequest(rq));
// We can't let row go out of scope. Otherwise it will free all the tensor memory.
// So park it in the vector. When this function go out of scope, its memory
// will be freed.
all_rows.push_back(std::move(row));
arr[i] = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
RETURN_IF_NOT_OK(arr[i]->SerializeCacheRowRequest(this, row));
RETURN_IF_NOT_OK(PushRequest(arr[i]));
}
// Now we wait for the requests to be done.
// Now we wait for them to come back
for (auto i = 0; i < num_rows; ++i) {
auto rq = rq_arr[i];
RETURN_IF_NOT_OK(rq->Wait());
RETURN_IF_NOT_OK(arr[i]->Wait());
}
}
return Status::OK();
@@ -77,11 +83,21 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
BatchFetchRequest rq(server_connection_id_, row_id);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
RETURN_IF_NOT_OK(rq.RestoreRows(out));
return Status::OK();
auto rq = std::make_shared<BatchFetchRequest>(server_connection_id_, row_id, SupportLocalClient());
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
int64_t mem_addr = -1;
Status rc = rq->RestoreRows(out, comm_->SharedMemoryBaseAddr(), &mem_addr);
// Free the memory by sending a request back to the server.
if (mem_addr != -1) {
auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, mem_addr);
Status rc2 = PushRequest(mfree_req);
// But we won't wait for the result for the sake of performance.
if (rc.IsOk() && rc2.IsError()) {
rc = rc2;
}
}
return rc;
}
Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
@@ -108,40 +124,44 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
// to create a cache and some other tree is trying to use the same cache.
// That is allowed, however the crc better match!
if (server_connection_id_) {
if (cache_crc_ != tree_crc) {
if (cinfo_.crc() != tree_crc) {
RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!");
}
// Check the state of the server. For non-mappable case where there is a build phase and a fetch phase, we should
// skip the build phase.
lck.Unlock(); // GetStat will grab the mutex again. So unlock it to prevent deadlock.
CacheClient::ServiceStat stat{};
CacheServiceStat stat{};
RETURN_IF_NOT_OK(GetStat(&stat));
if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) {
return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
}
} else {
cache_crc_ = tree_crc; // It's really a new cache we're creating so save our crc in the client
// Combine the session and crc. This will form our client cache identifier.
connection_id_type connection_identification = (static_cast<uint64_t>(session_id_) << 32) | cache_crc_;
cinfo_.set_crc(tree_crc); // It's really a new cache we're creating so save our crc in the client
// Now execute the cache create request using this identifier and other configs
BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone;
CreateCacheRequest::CreateCacheFlag createFlag = CreateCacheRequest::CreateCacheFlag::kNone;
if (spill_) {
createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk;
createFlag |= CreateCacheRequest::CreateCacheFlag::kSpillToDisk;
}
if (generate_id) {
createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId;
createFlag |= CreateCacheRequest::CreateCacheFlag::kGenerateRowId;
}
CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
Status rc = rq.Wait();
// Start the comm layer to receive reply
RETURN_IF_NOT_OK(comm_->ServiceStart());
// Initiate connection
auto rq = std::make_shared<CreateCacheRequest>(cinfo_, cache_mem_sz_, createFlag);
RETURN_IF_NOT_OK(PushRequest(rq));
Status rc = rq->Wait();
if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) {
server_connection_id_ = rq.GetServerConnectionId();
std::string cookie;
rq->ParseResult(&server_connection_id_, &cookie);
if (rc.IsOk()) {
// The 1st guy creating the cache will get a cookie back.
// But this object may be shared among pipelines and we don't want
// to overwrite it.
cookie_ = rq.cookie();
cookie_ = cookie;
}
// Attach to shared memory for local client
RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_));
}
// We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
// CacheOp to bypass the build phase.
@@ -152,57 +172,57 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
Status CacheClient::PurgeCache() {
UniqueLock lck(&mux_);
PurgeCacheRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
return rq.Wait();
auto rq = std::make_shared<PurgeCacheRequest>(server_connection_id_);
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
return Status::OK();
}
Status CacheClient::DestroyCache() {
UniqueLock lck(&mux_);
DestroyCacheRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
return rq.Wait();
auto rq = std::make_shared<DestroyCacheRequest>(server_connection_id_);
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
return Status::OK();
}
Status CacheClient::GetStat(ServiceStat *stat) {
Status CacheClient::GetStat(CacheServiceStat *stat) {
SharedLock lck(&mux_);
RETURN_UNEXPECTED_IF_NULL(stat);
GetStatRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
stat->num_disk_cached = rq.GetNumDiskCached();
stat->num_mem_cached = rq.GetNumMemCached();
stat->min_row_id = rq.GetMinRowId();
stat->max_row_id = rq.GetMaxRowId();
stat->cache_service_state = rq.GetState();
auto rq = std::make_shared<GetStatRequest>(server_connection_id_);
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
rq->GetStat(stat);
return Status::OK();
}
Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
SharedLock lck(&mux_);
CacheSchemaRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map));
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
auto rq = std::make_shared<CacheSchemaRequest>(server_connection_id_);
RETURN_IF_NOT_OK(rq->SerializeCacheSchemaRequest(map));
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
return Status::OK();
}
Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) {
SharedLock lck(&mux_);
RETURN_UNEXPECTED_IF_NULL(map);
FetchSchemaRequest rq(server_connection_id_);
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
*map = rq.GetColumnMap();
auto rq = std::make_shared<FetchSchemaRequest>(server_connection_id_);
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
*map = rq->GetColumnMap();
return Status::OK();
}
Status CacheClient::BuildPhaseDone() const {
SharedLock lck(&mux_);
BuildPhaseDoneRequest rq(server_connection_id_, cookie());
RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
RETURN_IF_NOT_OK(rq.Wait());
auto rq = std::make_shared<BuildPhaseDoneRequest>(server_connection_id_, cookie());
RETURN_IF_NOT_OK(PushRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
return Status::OK();
}
Status CacheClient::PushRequest(std::shared_ptr<BaseRequest> rq) const { return comm_->HandleRequest(std::move(rq)); }
} // namespace dataset
} // namespace mindspore
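As a quick illustration of the new request flow, a hedged sketch of fetching statistics through a built client (assumes `cc` was created successfully and the cache server is running):

```cpp
// 'cc' is a std::shared_ptr<CacheClient> built via CacheClient::Builder (see header below).
CacheServiceStat stat{};
Status rc = cc->GetStat(&stat);
if (rc.IsOk()) {
  MS_LOG(INFO) << "Rows in memory: " << stat.num_mem_cached
               << ", rows on disk: " << stat.num_disk_cached;
}
```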
@@ -23,9 +23,13 @@
#include <utility>
#include <vector>
#include "minddata/dataset/core/config_manager.h"
#ifdef ENABLE_CACHE
#include "minddata/dataset/engine/cache/cache_grpc_client.h"
#else
#include "minddata/dataset/engine/cache/stub/cache_grpc_client.h"
#endif
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/util/lock.h"
namespace mindspore {
@@ -35,18 +39,120 @@ namespace dataset {
/// rows, etc.
class CacheClient {
public:
friend class CacheMergeOp;
/// \brief A builder to help creating a CacheClient object
class Builder {
public:
Builder() : session_id_(0), cache_mem_sz_(0), spill_(false), port_(0), num_workers_(0), prefetch_size_(0) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
hostname_ = "127.0.0.1";
port_ = 50052;
num_workers_ = cfg->num_parallel_workers();
prefetch_size_ = 20; // rows_per_buf is too small (1 by default).
}
/// Setter function to set the session id
/// \param session_id
/// \return Builder object itself.
Builder &SetSessionId(session_id_type session_id) {
session_id_ = session_id;
return *this;
}
/// Setter function to set the cache memory size
/// \param cache_mem_sz
/// \return Builder object itself
Builder &SetCacheMemSz(uint64_t cache_mem_sz) {
cache_mem_sz_ = cache_mem_sz;
return *this;
}
/// Setter function to set the spill attribute
/// \param spill
/// \return Builder object itself
Builder &SetSpill(bool spill) {
spill_ = spill;
return *this;
}
/// Setter function to set rpc hostname
/// \param host
/// \return Builder object itself
Builder &SetHostname(std::string host) {
hostname_ = std::move(host);
return *this;
}
/// Setter function to set the tcp/ip port
/// \param port
/// \return Builder object itself.
Builder &SetPort(int32_t port) {
port_ = port;
return *this;
}
/// Setter function to set number of async rpc workers
/// \param num_workers
/// \return Builder object itself
Builder &SetNumWorkers(int32_t num_workers) {
num_workers_ = num_workers;
return *this;
}
/// Setter function to set prefetch amount for fetching rows from cache server
/// \param prefetch_sz
/// \return Builder object itself
Builder &SetPrefetchSize(int32_t prefetch_sz) {
prefetch_size_ = prefetch_sz;
return *this;
}
/// Getter functions
session_id_type getSessionId() const { return session_id_; }
uint64_t getCacheMemSz() const { return cache_mem_sz_; }
bool isSpill() const { return spill_; }
const std::string &getHostname() const { return hostname_; }
int32_t getPort() const { return port_; }
int32_t getNumWorkers() const { return num_workers_; }
int32_t getPrefetchSize() const { return prefetch_size_; }
Status SanityCheck() {
CHECK_FAIL_RETURN_UNEXPECTED(session_id_ > 0, "session id must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(cache_mem_sz_ >= 0, "cache memory size must not be negative (0 implies unlimited).");
CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "rpc workers must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(prefetch_size_ > 0, "prefetch size must be positive");
CHECK_FAIL_RETURN_UNEXPECTED(!hostname_.empty(), "hostname must not be empty");
return Status::OK();
}
Status Build(std::shared_ptr<CacheClient> *out) {
RETURN_UNEXPECTED_IF_NULL(out);
RETURN_IF_NOT_OK(SanityCheck());
*out = std::make_shared<CacheClient>(session_id_, cache_mem_sz_, spill_, hostname_, port_, num_workers_,
prefetch_size_);
return Status::OK();
}
private:
session_id_type session_id_;
uint64_t cache_mem_sz_;
bool spill_;
std::string hostname_;
int32_t port_;
int32_t num_workers_;
int32_t prefetch_size_;
};
/// \brief Constructor
/// \param session_id A user assigned session id for the current pipeline
/// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
/// \param spill Spill to disk if out of memory
/// \param hostname Hostname of the cache server
/// \param port TCP/IP port of the cache server
/// \param num_workers Number of async rpc workers
/// \param prefetch_size Prefetch size for fetching rows from the cache server
CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill);
CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname, int32_t port,
int32_t num_workers, int32_t prefetch_size);
/// \brief Destructor
~CacheClient() = default;
/// \brief Getter function for returning the current session id
/// \return session id
uint64_t session_id() const { return session_id_; }
~CacheClient() { (void)comm_->ServiceStop(); }
/// \brief Send a TensorRow to the cache server
/// \param[in] row
@@ -83,14 +189,7 @@ class CacheClient {
/// \brief Get the statistics from a cache.
/// \param[in/out] Pointer to a pre-allocated CacheServiceStat object
/// \return Status object
struct ServiceStat {
int64_t num_mem_cached;
int64_t num_disk_cached;
row_id_type min_row_id;
row_id_type max_row_id;
int8_t cache_service_state;
};
Status GetStat(ServiceStat *);
Status GetStat(CacheServiceStat *);
/// \brief Cache the schema at the cache server
/// \param map The unordered map of the schema
@@ -122,18 +221,45 @@ class CacheClient {
/// \return Cookie
std::string cookie() const { return cookie_; }
/// \brief Send a request async to the server
/// \param rq BaseRequest
/// \return Status object
Status PushRequest(std::shared_ptr<BaseRequest> rq) const;
/// \brief If the remote server supports local bypass using shared memory
/// \return boolean value
bool SupportLocalClient() const { return local_bypass_; }
/// \brief Return the base memory address if we attach to any shared memory.
auto SharedMemoryBaseAddr() const { return comm_->SharedMemoryBaseAddr(); }
/// Getter functions
session_id_type session_id() const { return cinfo_.session_id(); }
uint64_t getCacheMemSz() const { return cache_mem_sz_; }
bool isSpill() const { return spill_; }
const std::string &getHostname() const { return hostname_; }
int32_t getPort() const { return port_; }
int32_t getNumWorkers() const { return num_workers_; }
int32_t getPrefetchSize() const { return prefetch_size_; }
private:
mutable RWLock mux_;
uint64_t cache_mem_sz_;
bool spill_;
// The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow
// sharing of the cache.
uint32_t session_id_;
uint32_t cache_crc_;
CacheClientInfo cinfo_;
// The server_connection_id_ is the actual id we use for operations after the cache is built
connection_id_type server_connection_id_;
// Some magic cookie returned from the cache server.
std::string cookie_;
// Comm layer
bool local_bypass_;
std::string hostname_;
int32_t port_;
int32_t num_workers_;
int32_t prefetch_size_;
mutable std::shared_ptr<CacheClientGreeter> comm_;
};
} // namespace dataset
} // namespace mindspore
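The Builder above is the intended construction path; a minimal sketch (all values hypothetical) mirroring the Python binding registered earlier in this commit:

```cpp
std::shared_ptr<CacheClient> cc;
CacheClient::Builder builder;
builder.SetSessionId(1234)  // e.g. obtained from cache_admin --generate_session
  .SetCacheMemSz(0)         // 0 implies unlimited
  .SetSpill(true)
  .SetHostname("127.0.0.1")
  .SetPort(50052)
  .SetNumWorkers(4)
  .SetPrefetchSize(20);
Status rc = builder.Build(&cc);  // runs SanityCheck() before constructing
```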
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
/// \note This header file contains common header files and some inlines used by
/// both client- and server-side code. Do not put code that is not common here.
/// There are client and server specific header files.
// On platform like Windows, we may support only tcp/ip clients
#if !defined(_WIN32) && !defined(_WIN64)
#define CACHE_LOCAL_CLIENT 1
#endif
#ifdef CACHE_LOCAL_CLIENT
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#else
typedef int key_t;
#endif
#ifdef ENABLE_CACHE
#include <grpcpp/grpcpp.h>
#endif
#include <string>
#ifdef ENABLE_CACHE
#include "proto/cache_grpc.grpc.pb.h"
#endif
#include "proto/cache_grpc.pb.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
namespace mindspore {
namespace dataset {
/// \brief CacheRow and BatchFetch requests will switch to use shared memory method (if supported
/// on the platform) when the amount of bytes sent is greater than the following number.
/// For small amounts we won't get any benefit from the shared memory method because it needs
/// two rpc requests instead of one.
constexpr static int32_t kLocalByPassThreshold = 64 * 1024;
/// \brief A flag used by the BatchFetch request (client side) if it can support local bypass
constexpr static uint32_t kLocalClientSupport = 1;
/// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is
/// inline in the protobuf. This also implies kLocalClientSupport is also true.
constexpr static uint32_t kDataIsInSharedMemory = 2;
/// \brief Convert a Status object into a protobuf
/// \param[in] rc Status object
/// \param[in/out] reply pointer to pre-allocated protobuf object
inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
reply->set_rc(static_cast<google::int32>(rc.get_code()));
reply->set_msg(rc.ToString());
}
/// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number
/// \param port
/// \return unix socket url
inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); }
/// \brief Generate a shared memory key using the tcp/ip port.
/// \note It must be called after the cache server generates the unix socket or ftok will fail.
/// \note Caller must check the return value. -1 means ftok failed.
/// \param[in] port
/// \param[out] err. If not null and ftok fails, this will contain the value of errno
/// \return key
inline key_t PortToFtok(int port, int *err) {
key_t shmkey = -1;
#ifdef CACHE_LOCAL_CLIENT
const std::string unix_path = PortToUnixSocketPath(port);
shmkey = ftok(unix_path.data(), 'a');
if (err != nullptr && shmkey == (key_t)-1) {
*err = errno;
}
#endif
return shmkey;
}
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
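As the notes above stress, PortToFtok only succeeds after the server has created the unix socket file; a hedged sketch of the call sequence on the default port:

```cpp
int err = 0;
key_t shm_key = PortToFtok(50052, &err);  // 50052 is the default port in this commit
if (shm_key == (key_t)-1) {
  // Typical cause: /tmp/cache_server_p50052 does not exist yet, so ftok fails.
  MS_LOG(ERROR) << "ftok failed with errno " << err;
}
```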
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_fbb.h"
namespace mindspore {
namespace dataset {
/// A private function used by SerializeTensorRowHeader to serialize each column in a tensor row
/// \note Not to be called by outside world
/// \return Status object
Status SerializeOneTensorMeta(const std::shared_ptr<flatbuffers::FlatBufferBuilder> &fbb,
const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off) {
RETURN_UNEXPECTED_IF_NULL(out_off);
const Tensor *ts = ts_ptr.get();
auto shape_off = fbb->CreateVector(ts->shape().AsVector());
const auto ptr = ts->GetBuffer();
if (ptr == nullptr) {
RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
}
auto src = ts->type().value();
TensorType dest;
#define CASE(t) \
case DataType::t: \
dest = TensorType::TensorType_##t; \
break
// Map the type to fill in the flat buffer.
switch (src) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
default:
MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
RETURN_STATUS_UNEXPECTED("Unknown type");
}
#undef CASE
TensorMetaMsgBuilder ts_builder(*fbb);
ts_builder.add_dims(shape_off);
ts_builder.add_type(dest);
auto ts_off = ts_builder.Finish();
*out_off = ts_off;
return Status::OK();
}
Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *out_fbb) {
RETURN_UNEXPECTED_IF_NULL(out_fbb);
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
try {
fbb = std::make_shared<flatbuffers::FlatBufferBuilder>();
std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
std::vector<int64_t> tensor_sz;
v.reserve(row.size());
tensor_sz.reserve(row.size());
// We will go through each column in the row.
for (const std::shared_ptr<Tensor> &ts_ptr : row) {
flatbuffers::Offset<TensorMetaMsg> ts_off;
RETURN_IF_NOT_OK(SerializeOneTensorMeta(fbb, ts_ptr, &ts_off));
v.push_back(ts_off);
tensor_sz.push_back(ts_ptr->SizeInBytes());
}
auto column_off = fbb->CreateVector(v);
auto data_sz_off = fbb->CreateVector(tensor_sz);
TensorRowHeaderMsgBuilder row_builder(*fbb);
row_builder.add_column(column_off);
row_builder.add_data_sz(data_sz_off);
// Pass the row_id even if it may not be known.
row_builder.add_row_id(row.getId());
row_builder.add_size_of_this(-1); // fill in later after we call Finish.
auto out = row_builder.Finish();
fbb->Finish(out);
// Now go back to fill in size_of_this in the flat buffer.
auto msg = GetMutableTensorRowHeaderMsg(fbb->GetBufferPointer());
auto success = msg->mutate_size_of_this(fbb->GetSize());
if (!success) {
RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
}
(*out_fbb) = std::move(fbb);
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out) {
RETURN_UNEXPECTED_IF_NULL(col_ts);
auto shape_in = col_ts->dims();
auto type_in = col_ts->type();
std::vector<dsize_t> v;
v.reserve(shape_in->size());
v.assign(shape_in->begin(), shape_in->end());
TensorShape shape(v);
DataType::Type dest = DataType::DE_UNKNOWN;
#define CASE(t) \
case TensorType_##t: \
dest = DataType::Type::t; \
break
switch (type_in) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
}
#undef CASE
DataType type(dest);
std::shared_ptr<Tensor> ts;
RETURN_IF_NOT_OK(
Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize(), &ts));
// Next we restore the real data which can be embedded or stored separately.
if (ts->SizeInBytes() != data.GetSize()) {
MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
<< "Dumping tensor\n"
<< *ts << "\n";
RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
}
*out = std::move(ts);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
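A short sketch of the serialize path above (assumes a populated TensorRow `row`); the accessor names come from the generated flatbuffer code already used in this file:

```cpp
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
Status rc = SerializeTensorRowHeader(row, &fbb);
if (rc.IsOk()) {
  // size_of_this was patched in after Finish(), so it matches the buffer size.
  auto msg = GetMutableTensorRowHeaderMsg(fbb->GetBufferPointer());
  MS_LOG(DEBUG) << "Serialized header bytes: " << msg->size_of_this();
}
```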
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_
/// This header contains serialization and deserialization functions for tensor rows using
/// Google Flatbuffers
#include <memory>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/util/slice.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
/// \brief Function to serialize TensorRow header used by CacheRowRequest
/// \param row TensorRow
/// \param fbb [in/out] fbb that contains the serialized data
/// \return Status object
Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *fbb);
/// \brief A function used by BatchFetchRequest to deserialize a flat buffer back to a tensor row.
/// \param col_ts A serialized version of Tensor meta data
/// \param data Tensor data wrapped in a slice
/// \param out Tensor
/// \return Status object
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto3";
package mindspore.dataset;
option cc_enable_arenas = true;
// The session_id and crc work together to uniquely identify this particular cache and allow
// sharing of the cache.
message CacheClientInfo {
uint32 session_id = 1;
uint32 crc = 2;
}
message CacheRequest {
// Type of rpc request
int32 type = 1;
// Extra optional flag used by individual request if needed
uint32 flag = 2;
oneof connect_info {
// The connection_id is the actual id we use for operations after the cache is built
int64 connection_id = 3;
// But for some requests like CreateCache, we have to use the session id and crc to connect to the server.
CacheClientInfo connection_info = 4;
}
// Everything else is just a vector of buffers
repeated bytes buf_data = 5;
}
message CacheReply {
int32 rc = 1;
string msg = 2;
// Extra optional flag used by individual request if needed
uint32 flag = 3;
// What the server sends back is a plain buffer
bytes result = 4;
}
service CacheServerGreeter {
rpc CacheServerRequest (CacheRequest) returns (CacheReply) {}
}
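// Illustrative flow (a sketch, not part of the schema): a client first sends a CacheRequest of
// type CreateCache carrying connection_info (session_id + crc); the server replies with a
// connection id inside CacheReply.result, and every subsequent request from that client sets
// connection_id directly instead of connection_info.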
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_grpc_client.h"
#include <chrono>
namespace mindspore {
namespace dataset {
Status CacheClientRequestTag::MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
std::unique_ptr<CacheClientRequestTag> &&tag) {
// If there is anything extra we need to do before we send.
RETURN_IF_NOT_OK(tag->base_rq_->Prepare());
// One minute timeout
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
tag->ctx_.set_deadline(deadline);
tag->rpc_ = stub->PrepareAsyncCacheServerRequest(&tag->ctx_, tag->base_rq_->rq_, cq);
tag->rpc_->StartCall();
// Last step is we release the ownership and transfer it to the completion queue.
// The memory will be released by WorkerEntry or by the destructor when we drain the queue
auto ccReqTag = tag.release();
ccReqTag->rpc_->Finish(&ccReqTag->base_rq_->reply_, &ccReqTag->rc_,
ccReqTag); // inject this object into the completion queue
return Status::OK();
}
CacheClientGreeter::~CacheClientGreeter() {
(void)ServiceStop();
// Detach from shared memory if any
if (shmat_addr_ != nullptr) {
shmdt(shmat_addr_);
shmat_addr_ = nullptr;
}
}
CacheClientGreeter::CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers)
: num_workers_(num_workers), shm_key_(-1), shm_id_(-1), shmat_addr_(nullptr) {
grpc::ChannelArguments args;
// We need to bump up the message size to unlimited. The default receiving
// message limit is 4MB, which is not big enough.
args.SetMaxReceiveMessageSize(-1);
#if CACHE_LOCAL_CLIENT
// Try to connect locally to the unix socket first as the first preference.
// Ideally we would resolve the hostname to an ip address rather than doing a string compare.
if (hostname == "127.0.0.1") {
std::string target = "unix://" + PortToUnixSocketPath(port);
channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
} else {
#endif
std::string target = hostname + ":" + std::to_string(port);
channel_ = grpc::CreateCustomChannel(target, grpc::InsecureChannelCredentials(), args);
#if CACHE_LOCAL_CLIENT
}
#endif
stub_ = CacheServerGreeter::NewStub(channel_);
}
Status CacheClientGreeter::AttachToSharedMemory(int32_t port, bool *local_bypass) {
*local_bypass = false;
#if CACHE_LOCAL_CLIENT
int err;
shm_key_ = PortToFtok(port, &err);
if (shm_key_ == (key_t)-1) {
std::string errMsg = "Ftok failed with errno " + std::to_string(err);
RETURN_STATUS_UNEXPECTED(errMsg);
}
// Attach to the shared memory
shm_id_ = shmget(shm_key_, 0, 0);
if (shm_id_ == -1) {
RETURN_STATUS_UNEXPECTED("Shmget failed. Errno " + std::to_string(errno));
}
shmat_addr_ = shmat(shm_id_, nullptr, 0);
if (shmat_addr_ == reinterpret_cast<void *>(-1)) {
RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
}
*local_bypass = true;
#endif
return Status::OK();
}
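// Note: addresses exchanged with the server over this segment are relative offsets.
// A caller resolves them against the attached base, e.g. (see CacheRowRequest for the real use):
//   auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(SharedMemoryBaseAddr()) + offset);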
Status CacheClientGreeter::DoServiceStart() {
RETURN_IF_NOT_OK(vg_.ServiceStart());
RETURN_IF_NOT_OK(DispatchWorkers(num_workers_));
return Status::OK();
}
Status CacheClientGreeter::DoServiceStop() {
// Shutdown the queue. We don't accept any more new incomers.
cq_.Shutdown();
// Shutdown the TaskGroup.
vg_.interrupt_all();
vg_.join_all(Task::WaitFlag::kNonBlocking);
// Drain the queue
bool success;
void *tag;
while (cq_.Next(&tag, &success)) {
auto r = reinterpret_cast<CacheClientRequestTag *>(tag);
delete r;
}
return Status::OK();
}
Status CacheClientGreeter::HandleRequest(std::shared_ptr<BaseRequest> rq) {
auto tag = std::make_unique<CacheClientRequestTag>(std::move(rq));
return tag->MakeCall(stub_.get(), &cq_, std::move(tag));
}
Status CacheClientGreeter::WorkerEntry() {
TaskManager::FindMe()->Post();
do {
bool success;
void *tag;
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
// Set a timeout for one second. Check for interrupt if we need to do early exit.
auto r = cq_.AsyncNext(&tag, &success, deadline);
if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
auto rq = reinterpret_cast<CacheClientRequestTag *>(tag);
if (success) {
auto &rc = rq->rc_;
if (!rc.ok()) {
auto error_code = rq->rc_.error_code();
std::string errMsg = rq->rc_.error_message() + ". GRPC Code " + std::to_string(error_code);
Status remote_rc = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
Status2CacheReply(remote_rc, &rq->base_rq_->reply_);
}
// Notify the waiting thread.
rq->Notify();
}
// We can now free the memory
delete rq;
} else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) {
// If we are interrupted, exit. Otherwise wait again.
RETURN_IF_INTERRUPTED();
} else {
// Queue is drained.
break;
}
} while (true);
return Status::OK();
}
Status CacheClientGreeter::DispatchWorkers(int32_t num_workers) {
auto f = std::bind(&CacheClientGreeter::WorkerEntry, this);
for (auto i = 0; i < num_workers; ++i) {
RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Async reply", f));
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
#include <memory>
#include <string>
#include <utility>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/util/service.h"
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
/// \brief A client view of gRPC request
/// Like the class CacheServerRequest, this is used as a tag to inject into the gRPC
/// completion queue. The thread that makes the rpc request will wait on a wait post
/// area for the reply to come back. Since this tag will be deleted from memory once the
/// reply is processed, we need to work on a shared pointer of the BaseRequest such that
/// its use count is at least two. Otherwise either thread will be referencing stale memory.
/// \see CacheServerRequest
class CacheClientRequestTag {
public:
friend class CacheClientGreeter;
explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq) : base_rq_(std::move(rq)) {}
~CacheClientRequestTag() = default;
/// \brief Make a RPC call
/// \param stub from CacheClientGreeter
/// \param cq from CacheClientGreeter
/// \return Status object
static Status MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
std::unique_ptr<CacheClientRequestTag> &&tag);
/// \brief Notify the client that a result has come back from the server
void Notify() { base_rq_->wp_.Set(); }
private:
std::shared_ptr<BaseRequest> base_rq_;
grpc::Status rc_;
grpc::ClientContext ctx_;
std::unique_ptr<grpc::ClientAsyncResponseReader<CacheReply>> rpc_;
};
/// \brief A GRPC layer to convert BaseRequest into protobuf and send to the cache server using gRPC
/// \see BaseRequest
class CacheClientGreeter : public Service {
friend class CacheClient;
public:
explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers);
~CacheClientGreeter();
/// Override base Service class
Status DoServiceStart() override;
Status DoServiceStop() override;
/// \brief Send the request to the server
/// \return Status object
Status HandleRequest(std::shared_ptr<BaseRequest> rq);
/// \brief A handful of threads will be handling async reply from the server
/// \return
Status WorkerEntry();
/// \brief Kick off threads to receive reply from the server
Status DispatchWorkers(int32_t num_workers);
/// \brief Attach to shared memory for local client
/// \note Called after we have established a connection.
/// \return Status object.
Status AttachToSharedMemory(int32_t port, bool *local_bypass);
/// \brief This returns where we attach to the shared memory.
/// \return Base address of the shared memory.
const void *SharedMemoryBaseAddr() const { return shmat_addr_; }
private:
std::shared_ptr<grpc::Channel> channel_;
std::unique_ptr<CacheServerGreeter::Stub> stub_;
grpc::CompletionQueue cq_;
TaskGroup vg_;
int32_t num_workers_;
key_t shm_key_;
int32_t shm_id_;
void *shmat_addr_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_grpc_server.h"
#include <limits>
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/path.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
CacheServerGreeterImpl::CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb)
: port_(port), shm_pool_sz_in_gb_(shared_memory_sz_in_gb) {
// Setup a path for unix socket.
unix_socket_ = PortToUnixSocketPath(port);
// We can't generate the ftok key until the unix_socket_ is created
}
void CacheServerGreeterImpl::Shutdown() {
if (server_) {
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
server_->Shutdown(deadline);
}
// Always shutdown the completion queue after the server.
if (cq_) {
cq_->Shutdown();
// We need to drain the queue. All the tags come from
// the Services pool, which will be shut down as well, so we
// ignore them.
void *tag;
bool success;
while (cq_->Next(&tag, &success)) {
}
}
}
CacheServerGreeterImpl::~CacheServerGreeterImpl() { Shutdown(); }
Status CacheServerGreeterImpl::IpcResourceCleanup() {
#if CACHE_LOCAL_CLIENT
int err;
auto shm_key = PortToFtok(port_, &err);
// We expect the unix path not to exist yet.
if (shm_key == (key_t)-1) {
return Status::OK();
}
// Attach to the shared memory
auto shm_id = shmget(shm_key, 0, 0);
if (shm_id == -1) {
return Status::OK();
}
struct shmid_ds ds {};
auto inx = shmctl(shm_id, IPC_STAT, &ds);
if (inx == -1) {
std::string errMsg = "Unable to query shared memory with id " + std::to_string(shm_id);
errMsg += "\nPlesae remove it manually using ipcrm -m command";
RETURN_STATUS_UNEXPECTED(errMsg);
}
if (ds.shm_nattch == 0) {
// Stale shared memory from last time.
// Remove both the memory and the socket path
inx = shmctl(shm_id, IPC_RMID, nullptr);
if (inx == -1) {
std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id);
errMsg += ". Errno :" + std::to_string(errno);
errMsg += "\nPlesae remove it manually using ipcrm -m command";
RETURN_STATUS_UNEXPECTED(errMsg);
}
Path p(unix_socket_);
(void)p.Remove();
} else {
// Server is already up.
MS_LOG(ERROR) << "Cache server is already up and running";
// We return a duplicate-key error; main() will intercept it
// and output a proper message
return Status(StatusCode::kDuplicateKey);
}
#endif
return Status::OK();
}
Status CacheServerGreeterImpl::Run() {
// To listen on all interfaces, use 0.0.0.0
// Use 127.0.0.1 if running just locally on the same machine.
std::string host("0.0.0.0"); // listen on all interfaces.
std::string server_address = host + ":" + std::to_string(port_);
grpc::ServerBuilder builder;
// Default message size for gRPC is 4MB. Increase it to 2GB - 1.
builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max());
int port_tcpip = 0;
#if CACHE_LOCAL_CLIENT
int port_local = 0;
// Check if we need to clean up the shared memory if the server
// came down unexpectedly (e.g. from a SEGV)
RETURN_IF_NOT_OK(IpcResourceCleanup());
// We also optimize on local clients on the same machine using unix socket
builder.AddListeningPort("unix://" + unix_socket_, grpc::InsecureServerCredentials(), &port_local);
#endif
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_tcpip);
builder.RegisterService(&svc_);
cq_ = builder.AddCompletionQueue();
server_ = builder.BuildAndStart();
if (server_) {
MS_LOG(INFO) << "Server listening on " << server_address;
#if CACHE_LOCAL_CLIENT
RETURN_IF_NOT_OK(CachedSharedMemoryArena::CreateArena(&shm_pool_, port_, shm_pool_sz_in_gb_));
MS_LOG(INFO) << "Creation of local socket and shared memory successful";
#endif
} else {
std::string errMsg = "Fail to start server. ";
if (port_tcpip != port_) {
errMsg += "Unable to bind to tcpip port " + std::to_string(port_) + ".";
}
#if CACHE_LOCAL_CLIENT
if (port_local == 0) {
errMsg += " Unable to create unix socket " + unix_socket_ + ".";
}
#endif
RETURN_STATUS_UNEXPECTED(errMsg);
}
return Status::OK();
}
Status CacheServerGreeterImpl::HandleRequest(int32_t worker_id) {
bool success;
void *tag;
// We loop through the grpc completion queue. Each successful event comes back
// with our own tag, which is an instance of CacheServerRequest, and we simply
// call its functor. But first we need to create these instances
// and inject them into the grpc queue.
CacheServerRequest *p;
// Get a free tag from my free list.
RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(worker_id, &p));
RETURN_IF_NOT_OK((*p)(&svc_, cq_.get()));
do {
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(1);
// Set a timeout for one second. Check for interrupt if we need to do early exit.
auto r = cq_->AsyncNext(&tag, &success, deadline);
if (r == grpc_impl::CompletionQueue::NextStatus::GOT_EVENT) {
if (success) {
auto rq = static_cast<CacheServerRequest *>(tag);
RETURN_IF_NOT_OK((*rq)(&svc_, cq_.get()));
}
} else if (r == grpc_impl::CompletionQueue::NextStatus::TIMEOUT) {
// If we are interrupted, exit. Otherwise wait again.
RETURN_IF_INTERRUPTED();
} else {
// Queue is drained.
break;
}
} while (true);
return Status::OK();
}
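// The functor below implements a small three-state machine, a common pattern for async gRPC servers:
// CREATE  - register this tag with the async service so it can be matched to one incoming rpc;
// PROCESS - an rpc has arrived; arm a fresh tag for the next rpc, then hand this one to a worker queue;
// FINISH  - the reply has been sent; recycle the tag back to the free list.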
Status CacheServerRequest::operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq) {
auto myQID = getQid();
if (st_ == STATE::CREATE) {
st_ = STATE::PROCESS;
svc->RequestCacheServerRequest(&ctx_, &rq_, &responder_, cq, cq, this);
} else if (st_ == STATE::PROCESS) {
// Get a new tag and handle the next request before we serve the current request.
// The tag will be recycled when its state is changed to FINISH
CacheServerRequest *next_rq;
RETURN_IF_NOT_OK(CacheServer::GetFreeRequestTag(myQID, &next_rq));
RETURN_IF_NOT_OK((*next_rq)(svc, cq));
// Now we continue with the current request.
// First we need to extract the type from the incoming request.
// When this object was first created (i.e. STATE::CREATE), we set the type to UNKNOWN.
type_ = static_cast<RequestType>(rq_.type());
// Now we pass the address of this instance to CacheServer's main loop.
MS_LOG(DEBUG) << "Handle request " << *this;
auto &cs = CacheServer::GetInstance();
RETURN_IF_NOT_OK(cs.PushRequest(myQID, this));
} else if (st_ == STATE::FINISH) {
MS_LOG(DEBUG) << *this << " Finished.";
// Return back to the free list.
RETURN_IF_NOT_OK(CacheServer::ReturnRequestTag(this));
}
return Status::OK();
}
void CacheServerRequest::Print(std::ostream &out) const {
if (rq_.has_connection_info()) {
out << "Session Id: " << rq_.connection_info().session_id() << " CRC: " << rq_.connection_info().crc();
} else {
out << "Connection Id: " << rq_.connection_id();
}
out << " ";
BaseRequest::Print(out);
}
} // namespace dataset
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
/// \brief Server side view of BaseRequest. Incoming request are in the form of protobuf objects
/// and this class is used to translate from protobuf to structures understood by CacheService class.
/// \see CacheService
class CacheServerRequest : public BaseRequest {
public:
friend class CacheServer;
enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
explicit CacheServerRequest(int32_t queue_id)
: BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown),
qid_(queue_id),
st_(STATE::CREATE),
responder_(&ctx_) {}
~CacheServerRequest() = default;
/// \brief Functor. Used mainly by the CacheServerGreeterImpl class to tag each incoming request; this
/// functor will translate each protobuf into some form understood by the CacheService class.
/// \param svc Async service
/// \param cq Completion queue
/// \return Status object
Status operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq);
/// \brief Override the base class Print method
/// \param out
void Print(std::ostream &out) const override;
/// \brief Getter of the queue id
/// \return The queue where the request should go to
int32_t getQid() const { return qid_; }
private:
int32_t qid_;
Status rc_;
STATE st_;
grpc::ServerContext ctx_;
grpc::ServerAsyncResponseWriter<CacheReply> responder_;
};
/// \brief Implementation of CacheServerGreeter
/// \note It is an async server
/// \see cache_grpc.proto
class CacheServerGreeterImpl final {
friend class CacheServer;
public:
explicit CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb);
virtual ~CacheServerGreeterImpl();
/// \brief Brings up gRPC server
/// \return Status object
Status Run();
/// \brief Entry function to handle cache server request
Status HandleRequest(int32_t worker_id);
/// \brief Return the shared memory pool.
/// \return Pointer to the shared memory pool
CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); }
void Shutdown();
Status IpcResourceCleanup();
private:
int32_t port_;
size_t shm_pool_sz_in_gb_;
std::string unix_socket_;
CacheServerGreeter::AsyncService svc_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> server_;
std::unique_ptr<CachedSharedMemoryArena> shm_pool_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_server.h"
#include <sys/types.h>
#include <unistd.h>
#ifdef USE_GLOG
#include <glog/logging.h>
#endif
#include <cstdlib>
namespace ds = mindspore::dataset;
int main(int argc, char **argv) {
ds::Status rc;
ds::CacheServer::Builder builder;
// This executable is not to be called directly; it should be invoked by the cache_admin executable.
if (argc != 7) {
rc = ds::Status(ds::StatusCode::kSyntaxError);
std::cerr << rc.ToString() << std::endl;
return static_cast<int>(rc.get_code());
}
builder.SetRootDirectory(argv[1])
.SetNumWorkers(strtol(argv[2], nullptr, 10))
.SetPort(strtol(argv[3], nullptr, 10))
.SetSharedMemorySizeInGB(strtol(argv[4], nullptr, 10));
#ifdef USE_GLOG
FLAGS_minloglevel = strtol(argv[5], nullptr, 10);
#endif
auto daemonize_string = argv[6];
bool daemonize = strcmp(daemonize_string, "true") == 0 || strcmp(daemonize_string, "TRUE") == 0 ||
strcmp(daemonize_string, "t") == 0 || strcmp(daemonize_string, "T") == 0;
// We always change directory to / on unix rather than staying in the directory where cache_server
// was invoked. This is a standard procedure for daemonizing a process on unix.
if (chdir("/") == -1) {
std::string errMsg = "Unable to change directory to /. Errno = " + std::to_string(errno);
std::cerr << errMsg << std::endl;
return -1;
}
// Simple check of the parameters before we move on.
rc = builder.SanityCheck();
if (rc.IsError()) {
std::cerr << rc.ToString() << std::endl;
return static_cast<int>(rc.get_code());
}
#ifdef USE_GLOG
FLAGS_log_dir = "/tmp";
google::InitGoogleLogging(argv[0]);
#endif
if (daemonize) {
// fork the child process to become the daemon
pid_t pid = fork();
// failed to fork
if (pid < 0) {
std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
std::cerr << err_msg << std::endl;
return errno;
} else if (pid > 0) {
// Parent
std::cerr << "cache server daemon process has been created as process id: " << pid
<< "\nCheck log file for any start up error" << std::endl;
signal(SIGCHLD, SIG_IGN);  // ignore the SIGCHLD signal.
return 0;
} else {
// The child process continues from here when daemonizing; the parent has already exited.
// If we are running in the foreground, none of the code in the block below will be run.
pid_t sid;
umask(0);
sid = setsid();
if (sid < 0) {
MS_LOG(ERROR) << "Failed to setsid(). Errno = " << std::to_string(errno);
return errno;
}
close(0);
close(1);
close(2);
}
}
// Dump the summary
MS_LOG(INFO) << builder << std::endl;
rc = builder.Build();
if (rc.IsOk()) {
ds::CacheServer &cs = ds::CacheServer::GetInstance();
// Kick off the threads. Loop forever and never return unless error.
rc = cs.Run();
if (rc.get_code() == ds::StatusCode::kDuplicateKey) {
std::string errMsg = "Server is already started";
MS_LOG(ERROR) << errMsg;
std::cerr << errMsg << std::endl;
return 0;
}
}
if (rc.IsError()) {
MS_LOG(ERROR) << rc.ToString();
std::cerr << rc.ToString() << std::endl;
return static_cast<int>(rc.get_code());
}
return 0;
}
......@@ -14,154 +14,149 @@
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_request.h"
#include <cstdlib>
#include <thread>
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/engine/cache/cache_fbb.h"
namespace mindspore {
namespace dataset {
Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) {
buffers_.reserve(row.size() + 1);
RETURN_IF_NOT_OK(SerializeTensorRowHeader(row));
buffers_.push_back(fbb_->GetBufferPointer());
for (const auto &ts : row) {
buffers_.push_back(ts->GetBuffer());
}
Status BaseRequest::Wait() {
RETURN_IF_NOT_OK(wp_.Wait());
Status remote_rc(static_cast<StatusCode>(reply_.rc()), reply_.msg());
RETURN_IF_NOT_OK(remote_rc);
// Any extra work to do before we return back to the client.
RETURN_IF_NOT_OK(PostReply());
return Status::OK();
}
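// Note: wp_ is posted by CacheClientRequestTag::Notify() on the async reply path, so the Wait()
// above blocks the caller until a gRPC worker thread has filled in reply_.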
Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) {
try {
fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
std::vector<int64_t> tensor_sz;
v.reserve(row.size());
tensor_sz.reserve(row.size());
// We will go through each column in the row.
for (const std::shared_ptr<Tensor> &ts_ptr : row) {
flatbuffers::Offset<TensorMetaMsg> ts_off;
RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off));
v.push_back(ts_off);
tensor_sz.push_back(ts_ptr->SizeInBytes());
Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row) {
CHECK_FAIL_RETURN_UNEXPECTED(row.size() > 0, "Empty tensor row");
CHECK_FAIL_RETURN_UNEXPECTED(cc->SupportLocalClient() == support_local_bypass_, "Local bypass mismatch");
// Calculate how many bytes (not counting the cookie) we are sending to the server. We only
// use shared memory (if supported) if we exceed a certain threshold
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb));
sz_ += fbb->GetSize();
for (const auto &ts : row) {
sz_ += ts->SizeInBytes();
}
bool sent_using_local_bypass = support_local_bypass_ ? (sz_ >= kLocalByPassThreshold) : false;
uint32_t flag = 0;
if (support_local_bypass_) {
BitSet(&flag, kLocalClientSupport);
}
if (sent_using_local_bypass) {
BitSet(&flag, kDataIsInSharedMemory);
}
rq_.set_flag(flag);
if (sent_using_local_bypass) {
MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data";
// Allocate shared memory from the server
auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), sz_);
RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
RETURN_IF_NOT_OK(mem_rq->Wait());
addr_ = mem_rq->GetAddr();
// Now we need to add that to the base address of where we attach.
auto base = cc->SharedMemoryBaseAddr();
auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr_);
// Now we copy the data onto shared memory.
WritableSlice all(p, sz_);
auto offset = fbb->GetSize();
ReadableSlice header(fbb->GetBufferPointer(), fbb->GetSize());
Status copy_rc;
copy_rc = WritableSlice::Copy(&all, header);
if (copy_rc.IsOk()) {
for (const auto &ts : row) {
WritableSlice row_data(all, offset, ts->SizeInBytes());
ReadableSlice src(ts->GetBuffer(), ts->SizeInBytes());
copy_rc = WritableSlice::Copy(&row_data, src);
if (copy_rc.IsError()) {
break;
}
offset += ts->SizeInBytes();
}
// Fill in where to find the data
AddDataLocation();
}
auto column_off = fbb_->CreateVector(v);
auto data_sz_off = fbb_->CreateVector(tensor_sz);
TensorRowHeaderMsgBuilder row_builder(*fbb_);
row_builder.add_column(column_off);
row_builder.add_data_sz(data_sz_off);
// Pass the row_id even if it may not be known.
row_builder.add_row_id(row.getId());
row_builder.add_size_of_this(-1); // fill in later after we call Finish.
auto out = row_builder.Finish();
fbb_->Finish(out);
// Now go back to fill in size_of_this in the flat buffer.
auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer());
auto success = msg->mutate_size_of_this(fbb_->GetSize());
if (!success) {
RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
if (copy_rc.IsError()) {
// We need to return the memory back to the server
auto mfree_req = GenerateFreeBlockRequest();
Status rc = cc->PushRequest(mfree_req);
// But we won't wait for the result for the sake of performance.
if (rc.IsError()) {
MS_LOG(ERROR) << "Push request for free memory failed.";
}
return copy_rc;
}
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
} else {
// We have already filled the first buffer which is the cookie.
sz_ += rq_.buf_data(0).size();
rq_.add_buf_data(fbb->GetBufferPointer(), fbb->GetSize());
for (const auto &ts : row) {
rq_.add_buf_data(ts->GetBuffer(), ts->SizeInBytes());
}
MS_LOG(DEBUG) << "Sending " << sz_ << " bytes of tensor data in " << rq_.buf_data_size() << " segments";
}
return Status::OK();
}
Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr,
flatbuffers::Offset<TensorMetaMsg> *out_off) {
RETURN_UNEXPECTED_IF_NULL(out_off);
const Tensor *ts = ts_ptr.get();
auto shape_off = fbb_->CreateVector(ts->shape().AsVector());
const auto ptr = ts->GetBuffer();
if (ptr == nullptr) {
RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
}
auto src = ts->type().value();
TensorType dest;
#define CASE(t) \
case DataType::t: \
dest = TensorType::TensorType_##t; \
break
// Map the type to fill in the flat buffer.
switch (src) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
default:
MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
RETURN_STATUS_UNEXPECTED("Unknown type");
Status CacheRowRequest::PostReply() {
if (!reply_.result().empty()) {
row_id_from_server_ = strtoll(reply_.result().data(), nullptr, 10);
}
#undef CASE
TensorMetaMsgBuilder ts_builder(*fbb_);
ts_builder.add_dims(shape_off);
ts_builder.add_type(dest);
auto ts_off = ts_builder.Finish();
*out_off = ts_off;
return Status::OK();
}
Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data,
std::shared_ptr<Tensor> *out) {
RETURN_UNEXPECTED_IF_NULL(col_ts);
auto shape_in = col_ts->dims();
auto type_in = col_ts->type();
std::vector<dsize_t> v;
v.reserve(shape_in->size());
v.assign(shape_in->begin(), shape_in->end());
TensorShape shape(v);
DataType::Type dest = DataType::DE_UNKNOWN;
#define CASE(t) \
case TensorType_##t: \
dest = DataType::Type::t; \
break
switch (type_in) {
CASE(DE_BOOL);
CASE(DE_INT8);
CASE(DE_UINT8);
CASE(DE_INT16);
CASE(DE_UINT16);
CASE(DE_INT32);
CASE(DE_UINT32);
CASE(DE_INT64);
CASE(DE_UINT64);
CASE(DE_FLOAT16);
CASE(DE_FLOAT32);
CASE(DE_FLOAT64);
CASE(DE_STRING);
Status CacheRowRequest::Prepare() {
if (BitTest(rq_.flag(), kDataIsInSharedMemory)) {
// First one is cookie, followed by address and then size.
CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() == 3, "Incomplete rpc data");
} else {
// The first one is the cookie, the second one is the google flatbuffer, followed by a number of buffers.
// But we are not going to decode them to verify.
CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() >= 3, "Incomplete rpc data");
}
#undef CASE
DataType type(dest);
std::shared_ptr<Tensor> ts;
RETURN_IF_NOT_OK(
Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize(), &ts));
// Next we restore the real data which can be embedded or stored separately.
if (ts->SizeInBytes() != data.GetSize()) {
MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
<< "Dumping tensor\n"
<< *ts << "\n";
RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
}
*out = std::move(ts);
return Status::OK();
}
Status BatchFetchRequest::RestoreRows(TensorTable *out) {
BatchFetchRequest::BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id,
bool local_bypass)
: BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(local_bypass), row_id_(row_id) {
rq_.set_connection_id(connection_id);
rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0);
// Convert the row id into a flatbuffer
flatbuffers::FlatBufferBuilder fbb;
auto off_t = fbb.CreateVector(row_id);
TensorRowIdsBuilder bld(fbb);
bld.add_row_id(off_t);
auto off = bld.Finish();
fbb.Finish(off);
rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
}
Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr) {
RETURN_UNEXPECTED_IF_NULL(out);
auto num_elements = row_id_.size();
auto *offset_array = reinterpret_cast<const int64_t *>(mem_.GetPointer());
const char *ptr = nullptr;
int64_t sz = 0;
// Tap into the reply flag to see where we can find the data. The server may decide the amount is
// so small that it doesn't use the shared memory method.
auto flag = reply_.flag();
bool dataOnSharedMemory = support_local_bypass_ ? (BitTest(flag, kDataIsInSharedMemory)) : false;
if (dataOnSharedMemory) {
auto addr = strtoll(reply_.result().data(), nullptr, 10);
RETURN_UNEXPECTED_IF_NULL(baseAddr);
ptr = reinterpret_cast<const char *>(reinterpret_cast<int64_t>(baseAddr) + addr);
*out_addr = addr;
} else {
ptr = reply_.result().data();
*out_addr = -1;
}
auto *offset_array = reinterpret_cast<const int64_t *>(ptr);
sz = offset_array[num_elements];
CHECK_FAIL_RETURN_UNEXPECTED(support_local_bypass_ || sz == reply_.result().length(), "Length mismatch");
TensorTable tbl;
tbl.reserve(num_elements);
ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes());
ReadableSlice all(ptr, sz);
for (auto i = 0; i < num_elements; ++i) {
auto len = offset_array[i + 1] - offset_array[i];
TensorRow row;
......@@ -178,10 +173,12 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out) {
auto col_ts = msg->column()->Get(k);
std::shared_ptr<Tensor> ts;
ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k));
RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts));
RETURN_IF_NOT_OK(mindspore::dataset::RestoreOneTensor(col_ts, data, &ts));
row.push_back(ts);
ts_offset += data.GetSize();
}
} else {
CHECK_FAIL_RETURN_UNEXPECTED(len == 0, "Data corruption detected.");
}
tbl.push_back(std::move(row));
}
......@@ -189,36 +186,69 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out) {
return Status::OK();
}
CreateCacheRequest::CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
CreateCacheRequest::CreateCacheFlag flag)
: BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag) {
// Type has been set already in the base constructor. So we need to fill in the connection info.
// On successful return, we will get the connection id
rq_.mutable_connection_info()->operator=(cinfo);
}
Status CreateCacheRequest::Prepare() {
try {
flatbuffers::FlatBufferBuilder fbb;
CreateCacheRequestMsgBuilder bld(fbb);
bld.add_cache_mem_sz(cache_mem_sz_);
bld.add_flag(static_cast<uint32_t>(flag_));
auto off = bld.Finish();
fbb.Finish(off);
rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
try {
fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
flatbuffers::FlatBufferBuilder fbb;
std::vector<flatbuffers::Offset<ColumnNameMsg>> v;
v.reserve(map.size());
for (auto &column : map) {
auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second);
auto c = CreateColumnNameMsg(fbb, fbb.CreateString(column.first), column.second);
v.push_back(c);
}
auto v_off = fbb_->CreateVector(v);
auto final_off = CreateSchemaMsg(*fbb_, v_off);
fbb_->Finish(final_off);
buf_ = fbb_->GetBufferPointer();
len_of_buf_ = fbb_->GetSize();
auto v_off = fbb.CreateVector(v);
auto final_off = CreateSchemaMsg(fbb, v_off);
fbb.Finish(final_off);
rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
return Status::OK();
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
}
}
std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() {
if (column_name_id_map_.empty()) {
auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(mem_.GetPointer());
auto v = map_msg->column();
for (auto i = 0; i < v->size(); ++i) {
auto col = map_msg->column()->Get(i);
column_name_id_map_.emplace(col->name()->str(), col->id());
}
Status FetchSchemaRequest::PostReply() {
auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(reply_.result().data());
auto v = map_msg->column();
for (auto i = 0; i < v->size(); ++i) {
auto col = map_msg->column()->Get(i);
column_name_id_map_.emplace(col->name()->str(), col->id());
}
return column_name_id_map_;
return Status::OK();
}
std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() { return column_name_id_map_; }
Status GetStatRequest::PostReply() {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(reply_.result().data());
stat_.num_disk_cached = msg->num_disk_cached();
stat_.num_mem_cached = msg->num_mem_cached();
stat_.avg_cache_sz = msg->avg_cache_sz();
stat_.max_row_id = msg->max_row_id();
stat_.min_row_id = msg->min_row_id();
stat_.cache_service_state = msg->state();
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
......@@ -18,11 +18,16 @@
#include <algorithm>
#include <memory>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#ifdef ENABLE_CACHE
#include "proto/cache_grpc.grpc.pb.h"
#endif
#include "proto/cache_grpc.pb.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/util/slice.h"
......@@ -30,6 +35,17 @@
namespace mindspore {
namespace dataset {
class CacheClient;
/// \brief Statistic structure for GetStat request
struct CacheServiceStat {
int64_t num_mem_cached;
int64_t num_disk_cached;
int64_t avg_cache_sz;
row_id_type min_row_id;
row_id_type max_row_id;
int8_t cache_service_state;
};
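/// A minimal usage sketch (assumes a connected CacheClient `cc`; normally wrapped by CacheClient::GetStat):
///   auto rq = std::make_shared<GetStatRequest>(connection_id);
///   RETURN_IF_NOT_OK(cc->PushRequest(rq));
///   RETURN_IF_NOT_OK(rq->Wait());
///   CacheServiceStat stat{};
///   rq->GetStat(&stat);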
/// \brief CacheClient communicates with CacheServer using Requests.
class BaseRequest {
public:
......@@ -44,195 +60,301 @@ class BaseRequest {
kCacheSchema = 6,
kFetchSchema = 7,
kBuildPhaseDone = 8,
kDropSession = 9,
kGenerateSessionId = 10,
kAllocateSharedBlock = 11,
kFreeSharedBlock = 12,
kStopService = 13,
// Add new requests before this one.
kRequestUnknown = 32767
};
// For kCreateCache
enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L };
friend class CacheServer;
friend class CacheServerRequest;
friend class CacheClientGreeter;
friend class CacheClientRequestTag;
/// \brief Base class of a cache server request
/// \param connection_id A combination of session id and crc that uniquely identifies a connection.
/// \param type Type of the request
explicit BaseRequest(connection_id_type connection_id, RequestType type)
: type_(type), connection_id_(connection_id) {}
explicit BaseRequest(RequestType type) : type_(type) { rq_.set_type(static_cast<google::int32>(type_)); }
virtual ~BaseRequest() = default;
/// \brief Wait for the completion of a request
/// \return Status returned from the cache server
Status Wait() {
RETURN_IF_NOT_OK(wp_.Wait());
return rc_;
/// \brief A print method for debugging
/// \param out The output stream to write output to
virtual void Print(std::ostream &out) const { out << "Request type: " << static_cast<int16_t>(type_); }
/// \brief << Stream output operator overload
/// \param out reference to the output stream
/// \param rq reference to the BaseRequest
/// \return the output stream
friend std::ostream &operator<<(std::ostream &out, const BaseRequest &rq) {
rq.Print(out);
return out;
}
/// \brief Getter function of the current connection id
/// \return Connection id
connection_id_type GetServerConnectionId() const { return connection_id_; }
/// \brief Derived class can implement extra work to be done before the request is sent to the server
virtual Status Prepare() { return Status::OK(); }
/// \brief Derived class can implement extra work to be done after the server sends the request
virtual Status PostReply() { return Status::OK(); }
/// \brief A method for the client to wait for the availability of the result back from the server.
/// \return Status object
Status Wait();
protected:
CacheRequest rq_; // This is what we send to the server
CacheReply reply_;  // This is what the server sends back
private:
RequestType type_;
connection_id_type connection_id_;
Status rc_;
WaitPost wp_;
WaitPost wp_; // A sync area used by the client side.
};
class FreeSharedBlockRequest : public BaseRequest {
public:
friend class CacheServer;
explicit FreeSharedBlockRequest(connection_id_type connection_id, int64_t addr)
: BaseRequest(RequestType::kFreeSharedBlock) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(std::to_string(addr));
}
~FreeSharedBlockRequest() = default;
};
/// \brief Request to cache a single TensorRow
class CacheRowRequest : public BaseRequest {
public:
friend class CacheServer;
explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie)
: BaseRequest(connection_id, RequestType::kCacheRow), row_id_from_server_(-1), cookie_(cookie) {}
friend class CacheClient;
explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie, bool local_bypass)
: BaseRequest(RequestType::kCacheRow),
support_local_bypass_(local_bypass),
addr_(-1),
sz_(0),
row_id_from_server_(-1) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(cookie);
}
~CacheRowRequest() = default;
/// \brief Serialize a TensorRow for streaming to the cache server
/// \param row TensorRow
/// \return Status object
Status SerializeCacheRowRequest(const TensorRow &row);
Status SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row);
/// \brief Sanity check before we send the row.
/// \return Status object
Status Prepare() override;
/// \brief Override the base function get the row id returned from the server
/// \return Status object
Status PostReply() override;
/// \brief Return the row id assigned to this row for non-mappable dataset
/// \return row id of the cached row
row_id_type GetRowIdAfterCache() { return row_id_from_server_; }
/// \brief If we are doing local bypass, fill in extra request information about where the data is located.
void AddDataLocation() {
if (support_local_bypass_) {
rq_.add_buf_data(std::to_string(addr_));
rq_.add_buf_data(std::to_string(sz_));
}
}
/// \brief If we fail to send the data to the server using the shared memory method, we should release
/// the shared memory by sending another request. The following function will generate a suitable
/// request for the CacheClient to send.
std::shared_ptr<FreeSharedBlockRequest> GenerateFreeBlockRequest() {
return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), addr_);
}
private:
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
bool support_local_bypass_;
int64_t addr_;
int64_t sz_;
row_id_type row_id_from_server_;
std::vector<const void *> buffers_;
std::string cookie_;
/// \brief Private function to serialize one TensorRow
/// \param row TensorRow
/// \return Status object
Status SerializeTensorRowHeader(const TensorRow &row);
/// \brief Private function to serialize one Tensor
/// \param ts_ptr Tensor
/// \return Status object
Status SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off);
};
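/// A minimal usage sketch (an illustration; CacheClient normally drives this internally):
///   auto rq = std::make_shared<CacheRowRequest>(connection_id, cookie, cc->SupportLocalClient());
///   RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(cc, row));
///   RETURN_IF_NOT_OK(cc->PushRequest(rq));
///   RETURN_IF_NOT_OK(rq->Wait());
///   row_id_type id = rq->GetRowIdAfterCache();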
/// \brief Request to fetch rows in batch
class BatchFetchRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheService;
BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id)
: BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {}
BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass);
~BatchFetchRequest() = default;
Status RestoreRows(TensorTable *out);
Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr);
private:
bool support_local_bypass_;
std::vector<row_id_type> row_id_;
MemGuard<uint8_t> mem_;
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
};
/// \brief Request to create a cache for the current connection
class CreationCacheRequest : public BaseRequest {
class CreateCacheRequest : public BaseRequest {
public:
friend class CacheServer;
enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L };
/// \brief Constructor
/// \param cinfo Client connection info (session id and crc)
/// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited
/// \param flag Attributes of the cache.
explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz,
CreateCacheFlag flag = CreateCacheFlag::kNone)
: BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {}
~CreationCacheRequest() = default;
explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
CreateCacheFlag flag = CreateCacheFlag::kNone);
~CreateCacheRequest() = default;
void ParseResult(connection_id_type *id, std::string *out) {
auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data());
*id = p->connection_id();
*out = p->cookie()->str();
}
std::string cookie() const { return cookie_; }
/// Overload the base class Prepare
Status Prepare() override;
private:
uint64_t cache_mem_sz;
uint64_t cache_mem_sz_;
CreateCacheFlag flag_;
std::string cookie_;
};
/// \brief Request to purge a cache.
class PurgeCacheRequest : public BaseRequest {
public:
friend class CacheServer;
explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kPurgeCache) {}
explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kPurgeCache) {
rq_.set_connection_id(connection_id);
}
~PurgeCacheRequest() = default;
};
/// \brief Request to destroy a cache
class DestroyCacheRequest : public BaseRequest {
public:
friend class CacheServer;
explicit DestroyCacheRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kDestroyCache) {}
/// \brief Destructor
explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) {
rq_.set_connection_id(connection_id);
}
~DestroyCacheRequest() = default;
};
/// \brief Obtain the statistics of the current connection
class GetStatRequest : public BaseRequest {
public:
friend class CacheServer;
friend class CacheService;
explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {}
explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetStat) {
rq_.set_connection_id(connection_id);
}
~GetStatRequest() = default;
row_id_type GetMinRowId() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->min_row_id();
}
row_id_type GetMaxRowId() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->max_row_id();
}
int64_t GetNumMemCached() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->num_mem_cached();
}
int64_t GetNumDiskCached() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->num_disk_cached();
}
uint8_t GetState() const {
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
return msg->state();
/// \brief Override base function to process the result.
Status PostReply() override;
void GetStat(CacheServiceStat *stat) {
if (stat != nullptr) {
(*stat) = stat_;
}
}
private:
MemGuard<uint8_t> mem_;
CacheServiceStat stat_{};
};
/// \brief Request to cache a schema
class CacheSchemaRequest : public BaseRequest {
public:
friend class CacheServer;
explicit CacheSchemaRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {}
explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) {
rq_.set_connection_id(connection_id);
}
~CacheSchemaRequest() = default;
Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map);
const void *GetBuffer() const { return buf_; }
private:
std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
const void *buf_;
int64_t len_of_buf_;
};
/// \brief Request to fetch a schema
class FetchSchemaRequest : public BaseRequest {
public:
friend class CacheServer;
explicit FetchSchemaRequest(connection_id_type connection_id)
: BaseRequest(connection_id, RequestType::kFetchSchema) {}
explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) {
rq_.set_connection_id(connection_id);
}
~FetchSchemaRequest() = default;
Status PostReply() override;
std::unordered_map<std::string, int32_t> GetColumnMap();
private:
MemGuard<uint8_t> mem_;
std::unordered_map<std::string, int32_t> column_name_id_map_;
};
/// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only.
class BuildPhaseDoneRequest : public BaseRequest {
public:
friend class CacheServer;
BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie)
: BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {}
: BaseRequest(RequestType::kBuildPhaseDone), cookie_(cookie) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(cookie_);
}
~BuildPhaseDoneRequest() = default;
private:
std::string cookie_;
};
/// \brief Request to drop all the caches in the current session
class DropSessionRequest : public BaseRequest {
public:
friend class CacheServer;
explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) {
rq_.mutable_connection_info()->operator=(cinfo);
}
~DropSessionRequest() = default;
};
class GenerateSessionIdRequest : public BaseRequest {
public:
friend class CacheServer;
GenerateSessionIdRequest() : BaseRequest(RequestType::kGenerateSessionId) {
// We don't have any client info or connection id to send. But we will manually
// set the connection id to 0.
rq_.set_connection_id(0);
}
~GenerateSessionIdRequest() = default;
session_id_type GetSessionId() { return atoi(reply_.result().data()); }
};
class AllocateSharedBlockRequest : public BaseRequest {
public:
friend class CacheServer;
explicit AllocateSharedBlockRequest(connection_id_type connection_id, size_t requestedSz)
: BaseRequest(RequestType::kAllocateSharedBlock) {
rq_.set_connection_id(connection_id);
rq_.add_buf_data(std::to_string(requestedSz));
}
~AllocateSharedBlockRequest() = default;
/// \brief On return from the server, we get the (relative) address where
/// the free block is located.
/// \return Relative address (offset) of the block within the shared memory pool
int64_t GetAddr() {
auto addr = strtoll(reply_.result().data(), nullptr, 10);
return addr;
}
};
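/// A minimal usage sketch (an illustration; see CacheRowRequest::SerializeCacheRowRequest for the real flow):
///   auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(connection_id, sz);
///   RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
///   RETURN_IF_NOT_OK(mem_rq->Wait());
///   auto *p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(cc->SharedMemoryBaseAddr()) + mem_rq->GetAddr());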
class ShutdownRequest : public BaseRequest {
public:
friend class CacheServer;
ShutdownRequest() : BaseRequest(RequestType::kStopService) {}
~ShutdownRequest() = default;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_
......@@ -24,8 +24,11 @@
#include <utility>
#include <vector>
#include <map>
#include <set>
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_grpc_server.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/cache_pool.h"
#include "minddata/dataset/util/lock.h"
......@@ -37,43 +40,131 @@
namespace mindspore {
namespace dataset {
class BaseRequest;
/// \brief A server which provides CacheService services.
class CacheServer : public Service {
public:
friend class Services;
using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;
class Builder {
public:
Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4) {}
/// \brief Getter functions
const std::string &getTop() const { return top_; }
int32_t getNumWorkers() const { return num_workers_; }
int32_t getPort() const { return port_; }
int32_t getSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; }
Builder &SetRootDirectory(std::string root) {
top_ = std::move(root);
return *this;
}
Builder &SetNumWorkers(int32_t n) {
num_workers_ = n;
return *this;
}
Builder &SetPort(int32_t p) {
port_ = p;
return *this;
}
Builder &SetSharedMemorySizeInGB(int32_t sz) {
shared_memory_sz_in_gb_ = sz;
return *this;
}
Status SanityCheck();
void Print(std::ostream &out) const {
out << "Summary of the cache server configuration\n"
<< "Spill directory: " << getTop() << "\n"
<< "Number of parallel workers: " << getNumWorkers() << "\n"
<< "Tcp/ip port: " << getPort() << "\n"
<< "Shared memory size (in GB): " << getSharedMemorySzInGb();
}
friend std::ostream &operator<<(std::ostream &out, const Builder &bld) {
bld.Print(out);
return out;
}
Status Build() {
RETURN_IF_NOT_OK(SanityCheck());
// We need to bring up the Task Manager by bringing up the Services singleton.
RETURN_IF_NOT_OK(Services::CreateInstance());
RETURN_IF_NOT_OK(CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_));
return Status::OK();
}
private:
std::string top_;
int32_t num_workers_;
int32_t port_;
int32_t shared_memory_sz_in_gb_;
};
CacheServer(const CacheServer &) = delete;
CacheServer &operator=(const CacheServer &) = delete;
CacheServer(CacheServer &&) = delete;
CacheServer &operator=(CacheServer &) = delete;
static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); }
Status DoServiceStart() override;
Status DoServiceStop() override;
~CacheServer() { (void)ServiceStop(); }
static Status CreateInstance(const std::string &spill_path, int32_t num_workers, int32_t port,
int32_t shared_memory_sz) {
std::call_once(init_instance_flag_, [&]() -> Status {
auto &svcManager = Services::GetInstance();
RETURN_IF_NOT_OK(svcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz));
return Status::OK();
});
return Status::OK();
}
static CacheServer &GetInstance() { return *instance_; }
/// \brief For the current demonstration, a cache client contacts cache server using a Queue.
/// \param queue_id The id of the queue to push the request to
/// \param rq The request
/// \return Status object
Status PushRequest(BaseRequest *rq) {
Status PushRequest(int32_t queue_id, CacheServerRequest *rq) {
RETURN_UNEXPECTED_IF_NULL(rq);
RETURN_IF_NOT_OK(cache_q_->Add(rq));
RETURN_IF_NOT_OK(cache_q_->operator[](queue_id)->Add(rq));
return Status::OK();
}
/// \brief Kick off server threads. Never returns unless an error occurs.
Status Run();
/// \brief Get a free tag
/// \param queue_id The queue id of the worker
/// \param[out] q Pointer to a pointer to a CacheServerRequest
/// \return Status object
static Status GetFreeRequestTag(int32_t queue_id, CacheServerRequest **q);
/// \brief Return a tag to the free list
/// \param[in] p Pointer to an already finished CacheServerRequest tag
/// \return Status object
static Status ReturnRequestTag(CacheServerRequest *p);
private:
static std::once_flag init_instance_flag_;
static CacheServer *instance_;
mutable RWLock rwLock_;
std::string top_;
cache_index all_caches_;
std::shared_ptr<Queue<BaseRequest *>> cache_q_;
std::set<session_id_type> history_sessions_;
std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_;
std::shared_ptr<QueueList<CacheServerRequest *>> free_list_;
std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_;
std::shared_ptr<CacheServerGreeterImpl> comm_layer_;
std::shared_ptr<MemoryPool> mp_;
TaskGroup vg_;
int32_t num_workers_;
int32_t port_;
int32_t shared_memory_sz_in_gb_;
std::atomic<bool> global_shutdown_;
/// \brief Constructor
/// \param spill_path Top directory for spilling buffers to.
/// \param num_workers Number of threads for handling requests.
/// \param port TCP/IP port the server listens on.
/// \param shared_memory_sz_in_gb Size of the shared memory segment, in GB.
explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port, int32_t shared_memory_sz_in_gb);
/// \brief Locate a cache service from connection id.
/// \return Pointer to cache service. Null if not found
......@@ -82,16 +173,65 @@ class CacheServer : public Service {
/// \brief Create a cache service. We allow multiple clients to create the same cache service.
/// Subsequent duplicate requests are ignored. The first cache client to create the service will be given
/// a special unique cookie.
/// \param[in] rq The CreateCache request from a cache client.
/// \param[out] reply Carries back the connection id; only the first cache client will be given a special
/// cookie to identify the creator.
/// \return Status object
Status CreateService(CacheRequest *rq, CacheReply *reply);
/// \brief Destroy a cache service
/// \param cs The cache service
/// \param rq The request
/// \return Status object
Status DestroyCache(CacheService *cs, CacheRequest *rq);
Status PurgeCache(CacheService *cs);
/// \brief Entry point for all internal server threads.
Status ServerRequest(int32_t worker_id);
/// \brief Entry point for all grpc threads.
/// \return Status object
Status RpcRequest(int32_t worker_id);
Status DestroySession(CacheRequest *rq);
/// \brief Create a connection id from a session id and a crc
/// \param session_id
/// \param crc
/// \return connection id
connection_id_type GetConnectionID(session_id_type session_id, uint32_t crc) const;
/// \brief Extract the session id from a connection id
/// \param connection_id
/// \return session id
session_id_type GetSessionID(connection_id_type connection_id) const;
/// \brief Generate a session ID for the client
/// \return Session ID
session_id_type GenerateSessionID() const;
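/// One plausible packing (an assumption, not necessarily the server's exact scheme) is to
/// place the 32-bit session id in the upper half of the 64-bit connection id and the crc in
/// the lower half, which makes GetSessionID a simple shift:
///   connection_id_type id = session_id;
///   return (id << 32u) | crc;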
/// \brief Handle kAllocateSharedBlock request
/// \param rq CacheRequest
/// \param reply CacheReply
/// \return Status object
Status AllocateSharedMemory(CacheRequest *rq, CacheReply *reply);
/// \brief Handle kFreeSharedBlock request
/// \param rq
/// \return Status object
Status FreeSharedMemory(CacheRequest *rq);
/// \brief Handle kFastCacheRow request
/// \return Status object
Status FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply);
/// \brief Internal function to do row batch fetch
/// \param cs CacheService
/// \param rq Request
/// \param reply Reply
/// \return Status object
Status BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply);
/// \brief A proper shutdown of the server
/// \return Status object
Status GlobalShutdown();
};
} // namespace dataset
} // namespace mindspore
......
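For context, a rough sketch of how a standalone admin binary might drive the Builder above; the main() wrapper is hypothetical and not part of this commit, only the Builder/Run calls come from the header:

// Hypothetical launcher, for illustration only.
#include "minddata/dataset/engine/cache/cache_server.h"
using mindspore::dataset::CacheServer;
int main() {
CacheServer::Builder builder;
builder.SetRootDirectory("/tmp").SetNumWorkers(32).SetPort(50052).SetSharedMemorySizeInGB(4);
if (builder.Build().IsOk()) {
// Run() kicks off the worker and rpc threads and blocks until shutdown.
return CacheServer::GetInstance().Run().IsOk() ? 0 : 1;
}
return 1;
}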
......@@ -76,7 +76,7 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
*row_id_generated = GetNextRowId();
// Some debug information on how many rows we have generated so far.
if ((*row_id_generated) % 1000 == 0) {
MS_LOG(DEBUG) << "Number of rows cached: " << *row_id_generated;
MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1;
}
} else {
if (msg->row_id() < 0) {
......@@ -114,6 +114,45 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
RETURN_STATUS_UNEXPECTED(e.what());
}
}
Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) {
SharedLock rw(&rw_lock_);
RETURN_UNEXPECTED_IF_NULL(row_id_generated);
if (st_ == State::kFetchPhase) {
// For this kind of cache service, once we move from the build phase into the fetch phase,
// we can't allow anyone to cache more rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
try {
// If we are not asked to generate a row id, we must find it in the buffer.
if (generate_id_) {
*row_id_generated = GetNextRowId();
// Some debug information on how many rows we have generated so far.
if ((*row_id_generated) % 1000 == 0) {
MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1;
}
} else {
auto msg = GetTensorRowHeaderMsg(src.GetPointer());
if (msg->row_id() < 0) {
std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id());
RETURN_STATUS_UNEXPECTED(errMsg);
}
*row_id_generated = msg->row_id();
}
// Now we cache the flat buffer.
CachePool::key_type key;
RETURN_IF_NOT_OK(cp_->Insert({src}, &key));
Status rc = map_->DoInsert(*row_id_generated, key);
if (rc == Status(StatusCode::kDuplicateKey)) {
MS_LOG(DEBUG) << "Ignoring duplicate key.";
} else {
RETURN_IF_NOT_OK(rc);
}
return Status::OK();
} catch (const std::exception &e) {
RETURN_STATUS_UNEXPECTED(e.what());
}
}
std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
// Then show any custom derived-internal stuff
out << "\nCache memory size: " << cs.cache_mem_sz_;
......@@ -155,20 +194,15 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) {
}
return Status::OK();
}
Status CacheService::PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *out,
int64_t *mem_sz) {
SharedLock rw(&rw_lock_);
if (st_ == State::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
RETURN_UNEXPECTED_IF_NULL(out);
RETURN_UNEXPECTED_IF_NULL(mem_sz);
const auto num_elements = v.size();
*mem_sz = (num_elements + 1) * sizeof(int64_t);
(*out).reserve(num_elements);
for (auto row_id : v) {
auto r = map_->Search(row_id);
if (r.second) {
......@@ -180,25 +214,33 @@ Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint
errMsg += std::to_string(key);
RETURN_STATUS_UNEXPECTED(errMsg);
}
(*out).emplace_back(key, sz);
(*mem_sz) += sz;
} else {
(*out).emplace_back(-1, 0);
}
}
return Status::OK();
}
Status CacheService::BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &info,
WritableSlice *out) const {
RETURN_UNEXPECTED_IF_NULL(out);
SharedLock rw(&rw_lock_);
if (st_ == State::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
const auto num_elements = v.size();
int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
offset_array[0] = data_offset;
for (auto i = 0; i < num_elements; ++i) {
auto sz = info.at(i).second;
offset_array[i + 1] = offset_array[i] + sz;
if (sz > 0) {
WritableSlice row_data(*out, offset_array[i], sz);
auto key = info.at(i).first;
size_t bytesRead = 0;
RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead));
if (bytesRead != sz) {
......@@ -208,7 +250,6 @@ Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint
}
}
}
return Status::OK();
}
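// Illustrative decode on the consuming side (a sketch, not code from this commit): the first
// (num_elements + 1) int64 values written above are offsets, row i lives in the interval
// [offset[i], offset[i+1]), and a zero-length interval marks a cache miss.
//   auto *offsets = reinterpret_cast<const int64_t *>(base);
//   for (size_t i = 0; i < num_elements; ++i) {
//     int64_t sz = offsets[i + 1] - offsets[i];
//     if (sz == 0) continue;  // cache miss: the caller fetches this row from the source
//     ReadableSlice row(base + offsets[i], sz);
//   }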
Status CacheService::CacheSchema(const void *buf, int64_t len) {
......@@ -232,18 +273,26 @@ Status CacheService::CacheSchema(const void *buf, int64_t len) {
}
return Status::OK();
}
Status CacheService::FetchSchema(std::string *out) const {
SharedLock rw(&rw_lock_);
if (st_ == State::kBuildPhase) {
// For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
}
RETURN_UNEXPECTED_IF_NULL(out);
// We are going to use std::string to allocate and hold the result, which will eventually be
// 'moved' into the protobuf message (underneath, also a std::string) to minimize memory copies.
std::string mem;
if (schema_key_ >= 0) {
auto len = cp_->GetSize(schema_key_);
try {
mem.resize(len);
CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= len, "Programming error");
} catch (const std::bad_alloc &e) {
return Status(StatusCode::kOutOfMemory);
}
auto slice = WritableSlice(mem.data(), len);
RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice));
*out = std::move(mem);
} else {
......
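The std::string return type above exists so the serialized schema can be moved straight into the reply message; a minimal caller sketch (the reply setter name is an assumption, not confirmed by this diff):

std::string schema_blob;
RETURN_IF_NOT_OK(cs->FetchSchema(&schema_blob));
reply->set_result(std::move(schema_blob));  // hands the bytes to the protobuf without copying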
......@@ -28,7 +28,6 @@
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/btree.h"
#include "minddata/dataset/util/cache_pool.h"
......@@ -38,7 +37,8 @@
namespace mindspore {
namespace dataset {
struct CacheStat;
/// Typedef used by BatchFetch
using key_size_pair = std::pair<CachePool::key_type, size_t>;
/// \brief A cache service for storing/fetching buffers in an in-memory cache, which may spill to disk
/// if the cache service is created with spilling enabled
class CacheService : public Service {
......@@ -69,12 +69,26 @@ class CacheService : public Service {
/// \param[out] row_id_generated The row id assigned to this row if any
/// \return Status object
Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated);
/// \brief A fast version of CacheRow where all the data is already in one contiguous piece.
/// \param src Slice of the data
/// \param row_id_generated
/// \return Status object
Status FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated);
/// \brief This function is used in preparation for batch fetching.
/// It calculates how much memory we should allocate and which row ids are present.
/// \param[in] v A vector of row ids.
/// \param[out] out Pointer to a vector of <CachePool::key_type, size_t> pairs, one per requested row.
/// \param[out] mem_sz How much memory is required for the batch fetch.
/// \return Status object
Status PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *out, int64_t *mem_sz);
/// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
/// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
/// \param[in] v A vector of row ids.
/// \param[in] info The <key, size> pairs computed by PreBatchFetch.
/// \param[out] out A contiguous memory buffer that holds the requested rows.
/// \return Status object
Status BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &info, WritableSlice *out) const;
/// \brief Getter function
/// \return Spilling path
......@@ -102,7 +116,7 @@ class CacheService : public Service {
/// \brief Fetch schema
/// \param out A contiguous memory that contains the serialized form of schema.
/// \return Status object
Status FetchSchema(std::string *out) const;
/// \brief Purge the content of a cache
/// \return Status object
Status Purge();
......
......@@ -60,10 +60,11 @@ table TensorRowIds {
}
/// Statistics returned from each cache service
/// \note It must match CacheServiceStat
table ServiceStatMsg {
num_mem_cached:int64;
num_disk_cached:int64;
avg_cache_sz:int64;
min_row_id:int64;
max_row_id:int64;
state:int8;
......@@ -79,3 +80,15 @@ table ColumnNameMsg {
table SchemaMsg {
column:[ColumnNameMsg];
}
/// Part of the CreateCacheRequest
table CreateCacheRequestMsg {
cache_mem_sz:int64;
flag:uint32;
}
/// Return result of CreateCacheRequest
table CreateCacheReplyMsg {
connection_id:int64;
cookie:string;
}
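For illustration, serializing the new CreateCacheRequestMsg from C++ would follow the standard flatc-generated API (the generated function name is assumed from flatc conventions; the flag value is arbitrary):

flatbuffers::FlatBufferBuilder fbb;
// 512 MB in-memory budget; flag bits are whatever CreateCacheFlag encodes.
auto off = CreateCreateCacheRequestMsg(fbb, /*cache_mem_sz=*/512LL << 20, /*flag=*/1);
fbb.Finish(off);
// fbb.GetBufferPointer() / fbb.GetSize() then travel inside the CacheRequest protobuf.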
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_
#include <memory>
#include <string>
#include "proto/cache_grpc.pb.h"
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/util/service.h"
namespace mindspore {
namespace dataset {
class CacheClientGreeter : public Service {
public:
explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers) {}
~CacheClientGreeter() override {}
Status DoServiceStart() override { RETURN_STATUS_UNEXPECTED("Not supported"); }
Status DoServiceStop() override { RETURN_STATUS_UNEXPECTED("Not supported"); }
void *SharedMemoryBaseAddr() { return nullptr; }
Status HandleRequest(std::shared_ptr<BaseRequest> rq) { RETURN_STATUS_UNEXPECTED("Not supported"); }
Status AttachToSharedMemory(int32_t port, bool *local_bypass) { RETURN_STATUS_UNEXPECTED("Not supported"); }
protected:
private:
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_
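The stub above keeps builds without gRPC compiling by mirroring the real greeter's surface. A sketch of how a translation unit might select between them (the header paths and the guarding macro are assumptions):

#ifdef ENABLE_CACHE
#include "minddata/dataset/engine/cache/cache_grpc_client.h"  // real gRPC greeter (path assumed)
#else
#include "minddata/dataset/engine/cache/stub/cache_grpc_client.h"  // the stub above (path assumed)
#endif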
......@@ -16,6 +16,7 @@
#include "minddata/dataset/engine/datasetops/cache_base_op.h"
#include <iomanip>
#include <iostream>
#include <utility>
#include "minddata/dataset/engine/execution_tree.h"
namespace mindspore {
......@@ -47,22 +48,39 @@ Status CacheBase::Reset() {
}
CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
row_cnt_(0),
num_cache_miss_(0),
cache_client_(std::move(cache_client)),
rows_per_buffer_(rows_per_buf),
// We can cause deadlock if this internal Connector size is too small.
keys_miss_(num_workers_, 1, connector_capacity_),
prefetch_size_(cache_client_->getPrefetchSize()) {
io_block_queues_.Init(num_workers, op_connector_size);
prefetch_queues_.Init(num_workers, op_connector_size);
sampler_queue_ = std::make_unique<Queue<std::shared_ptr<Tensor>>>(op_connector_size);
}
// Common function to fetch samples from the sampler and send them using the io_block_queues to
// the parallel workers
Status CacheBase::FetchSamplesToWorkers() {
int64_t buf_cnt = 0;
int64_t wait_cnt = 0;
// Kick off several threads which will prefetch prefetch_size_ rows in advance. rows_per_buffer_
// is too small (1 by default) to help performance by itself.
RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Dispatcher", std::bind(&CacheBase::Dispatcher, this)));
RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1)));
// Instead of sending sampler id to WorkerEntry, we send them to the Prefetcher which will redirect them
// to the WorkerEntry.
do {
epoch_sync_.Clear();
if (AllowCacheMiss() && wait_cnt > 0) {
MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_
<< " Total number of rows : " << row_cnt_;
}
num_cache_miss_ = 0;
row_cnt_ = 0;
++wait_cnt;
std::vector<row_id_type> keys;
keys.reserve(rows_per_buffer_);
std::unique_ptr<DataBuffer> sampler_buffer;
RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
......@@ -70,10 +88,13 @@ Status CacheBase::FetchSamplesToWorkers() {
TensorRow sample_row;
RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
std::shared_ptr<Tensor> sample_ids = sample_row[0];
// Send the sampler tensor to the other thread for prefetching. We use a shared pointer so it
// won't go out of scope until it is really no longer in use.
RETURN_IF_NOT_OK(sampler_queue_->Add(sample_ids));
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
keys.push_back(*itr);
++row_cnt_;
if (row_cnt_ % rows_per_buffer_ == 0) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
keys.clear();
......@@ -90,7 +111,7 @@ Status CacheBase::FetchSamplesToWorkers() {
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
// If repeat but the not last repeat, wait for reset.
if (!IsLastIteration()) {
MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt;
RETURN_IF_NOT_OK(epoch_sync_.Wait());
} else {
// We can break out from the loop.
......@@ -101,13 +122,21 @@ Status CacheBase::FetchSamplesToWorkers() {
// Flow the eof before exit
RETURN_IF_NOT_OK(
io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
// Shutdown threads
std::shared_ptr<Tensor> empty;
RETURN_IF_NOT_OK(sampler_queue_->Add(std::move(empty)));
for (int32_t i = 0; i < num_workers_; i++) {
RETURN_IF_NOT_OK(
io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
}
// Dump the last epoch result (approximately) without waiting for the worker threads to come back.
if (AllowCacheMiss()) {
MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_
<< " Total number of rows : " << row_cnt_;
}
return Status::OK();
}
Status CacheBase::FetchFromCache(int32_t worker_id) {
int64_t buffer_id = worker_id;
std::unique_ptr<IOBlock> blk;
......@@ -133,23 +162,16 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
}
std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>();
std::vector<row_id_type> cache_miss;
cache_miss.reserve(keys.size());
for (auto row_id : keys) {
TensorRow row;
// Block until the row shows up in the pool.
RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, &row));
if (row.empty()) {
cache_miss.push_back(row_id);
}
que->push_back(std::move(row));
}
db->set_tensor_table(std::move(que));
if (AllowCacheMiss()) {
......@@ -162,12 +184,17 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
} while (true);
return Status::OK();
}
Status CacheBase::RegisterResources() {
RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));
RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks()));
return Status::OK();
}
CacheBase::~CacheBase() {}
CacheBase::~CacheBase() = default;
Status CacheBase::UpdateColumnMapFromCache() {
Status rc;
// Get the schema from the server. It may not be there yet. So tolerate the error.
......@@ -180,5 +207,77 @@ Status CacheBase::UpdateColumnMapFromCache() {
}
return rc;
}
Status CacheBase::Dispatcher() {
TaskManager::FindMe()->Post();
int64_t buf_cnt = 0;
int64_t num_row = 0;
std::vector<row_id_type> keys;
keys.reserve(prefetch_size_);
do {
keys.clear();
std::shared_ptr<Tensor> sample_ids;
RETURN_IF_NOT_OK(sampler_queue_->PopFront(&sample_ids));
if (sample_ids == nullptr) {
// A null shared pointer signals it is time to quit.
// Also signal all prefetchers to quit.
for (int32_t i = 0; i < num_workers_; i++) {
RETURN_IF_NOT_OK(
prefetch_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
}
break;
}
// Now we distribute the sampler ids to each prefetcher according to the prefetch size.
for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
keys.push_back(*itr);
++num_row;
if (num_row % prefetch_size_ == 0) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
keys.clear();
}
}
// Send the remaining sample id
if (!keys.empty()) {
auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
}
} while (true);
return Status::OK();
}
Status CacheBase::Prefetcher(int32_t worker_id) {
TaskManager::FindMe()->Post();
std::vector<row_id_type> prefetch_keys;
prefetch_keys.reserve(prefetch_size_);
do {
prefetch_keys.clear();
std::unique_ptr<IOBlock> blk;
RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk));
RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys));
if (prefetch_keys.empty()) {
// Empty keys mean time to quit.
break;
}
TensorTable ttbl;
RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl));
auto row_it = ttbl.begin();
for (auto row_id : prefetch_keys) {
auto &row = *row_it;
if (row.empty()) {
if (AllowCacheMiss()) {
++num_cache_miss_;
} else {
std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
RETURN_STATUS_UNEXPECTED(errMsg);
}
}
// Put the prefetched row into the pool and wake up any WorkerEntry waiting for the row
RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
++row_it;
}
} while (true);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
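To summarize the hand-off built above: the sampler feeds sampler_queue_, the Dispatcher slices the ids into prefetch_queues_, the Prefetchers fill the prefetch_ pool, and the FetchFromCache workers block on it per row id. A minimal sketch of the QueueMap contract, using only the calls exercised in this file:

QueueMap<row_id_type, TensorRow> pool;
// Producer side (Prefetcher): publish the row under its id.
RETURN_IF_NOT_OK(pool.Add(row_id, std::move(row)));
// Consumer side (FetchFromCache): block until that id becomes available.
TensorRow out;
RETURN_IF_NOT_OK(pool.PopFront(row_id, &out));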
......@@ -16,6 +16,8 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <utility>
......@@ -28,8 +30,9 @@
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/queue_map.h"
#include "minddata/dataset/util/semaphore.h"
#include "minddata/dataset/util/wait_post.h"
#include "minddata/dataset/engine/datasetops/cache_base_op.h"
namespace mindspore {
namespace dataset {
/// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities.
......@@ -82,10 +85,13 @@ class CacheBase : public ParallelOp {
protected:
constexpr static int32_t eoe_row_id = -1;
int64_t row_cnt_;
std::atomic<int64_t> num_cache_miss_;
std::shared_ptr<CacheClient> cache_client_;
WaitPost epoch_sync_;
int32_t rows_per_buffer_;
Connector<std::vector<row_id_type>> keys_miss_;
QueueMap<row_id_type, TensorRow> prefetch_;
/// \brief Common function to register resources for interrupt
/// \note Derived should override this function for extra resources to be registered
......@@ -103,7 +109,15 @@ class CacheBase : public ParallelOp {
private:
constexpr static int32_t connector_capacity_ = 1024;
int32_t prefetch_size_;
QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
QueueList<std::unique_ptr<IOBlock>> prefetch_queues_;
std::unique_ptr<Queue<std::shared_ptr<Tensor>>> sampler_queue_;
Status Dispatcher();
/// \brief Prefetcher. It prefetches rows from the cache server.
/// \return Status object.
Status Prefetcher(int32_t worker_id);
};
} // namespace dataset
} // namespace mindspore
......
......@@ -16,8 +16,10 @@
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include <algorithm>
#include <chrono>
#include <functional>
#include <iomanip>
#include <utility>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/global_context.h"
......@@ -41,9 +43,13 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const {
out << "\n\n";
}
}
CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler)
: ParallelOp(numWorkers, opConnectorSize, sampler),
num_cleaners_(numCleaners),
cache_client_(std::move(cache_client)) {}
Status CacheMergeOp::operator()() {
// A queue of row id to let cleaner send cache miss rows to the cache server
// We don't want a small queue as this will block the parallel op workers.
......@@ -62,6 +68,7 @@ Status CacheMergeOp::operator()() {
TaskManager::FindMe()->Post();
return Status::OK();
}
// Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait
// until it shows up in the pool.
Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
......@@ -82,10 +89,8 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
if (row.empty()) {
auto row_id = row.getId();
// Block until the row shows up in the pool.
RETURN_IF_NOT_OK(cache_miss_.PopFront(row_id, &row));
}
tbl->push_back(std::move(row));
}
......@@ -97,6 +102,7 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
RETURN_IF_NOT_OK(EofReceived(worker_id));
return Status::OK();
}
Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
TaskManager::FindMe()->Post();
// We will simply pop TensorRow from the stream and insert them into the pool and
......@@ -123,17 +129,27 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
std::string errMsg = "Expect positive row id: " + std::to_string(row_id);
RETURN_STATUS_UNEXPECTED(errMsg);
}
// Technically, the number of times this row shows up in the cache miss stream equals the
// number of P() calls. However, the cleaner wants it too, so we need an extra copy.
TensorRowCacheRequest *rq;
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) {
// We will send the request async. But any error we most
// likely ignore and continue.
Status rc;
rc = rq->AsyncSendCacheRequest(cache_client_, row);
if (rc.IsOk()) {
RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
}
}
RETURN_IF_NOT_OK(cache_miss_.Add(row_id, std::move(row)));
}
}
RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId));
}
return Status::OK();
}
Status CacheMergeOp::Cleaner() {
TaskManager::FindMe()->Post();
while (true) {
......@@ -142,45 +158,28 @@ Status CacheMergeOp::Cleaner() {
if (row_id < 0) {
break;
}
// Locate the cache request
TensorRowCacheRequest *rq;
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
// If already flushed, move on to the next one.
if (rq->GetState() == TensorRowCacheRequest::State::kClean) {
continue;
}
Status rc = rq->CheckCacheResult();
if (rc.IsError()) {
// If interrupt, time to quit.
if (rc.get_code() == StatusCode::kInterrupted) {
return Status::OK();
}
MS_LOG(INFO) << "Cache row not successful: " << rc.ToString();
// Bad rc should not bring down the pipeline. We will simply continue and
// change the state back to empty. We don't need a CAS from CLEAN back to EMPTY.
rq->SetState(TensorRowCacheRequest::State::kEmpty);
}
}
return Status::OK();
}
Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from super class first before adding our own
// specific logic
CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children");
......@@ -199,6 +198,7 @@ Status CacheMergeOp::PrepareNodePostAction() { // Run any common code from supe
RETURN_IF_NOT_OK(rc);
return Status::OK();
}
Status CacheMergeOp::ComputeColMap() {
CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty");
if (column_name_id_map().empty()) {
......@@ -207,53 +207,13 @@ Status CacheMergeOp::ComputeColMap() {
CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected");
return Status::OK();
}
// Builder constructor. Creates the builder object.
CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
build_num_workers_ = cfg->num_parallel_workers();
build_op_connector_size_ = cfg->op_connector_size();
build_num_cleaners_ = cfg->num_parallel_workers();
}
// Check if the required parameters are set by the builder.
......@@ -311,5 +271,60 @@ Status CacheMergeOp::EofReceived(int32_t worker_id) {
MS_LOG(DEBUG) << "Cache merge sending eof";
return DatasetOp::EofReceived(worker_id);
}
Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowCacheRequest **out) {
RETURN_UNEXPECTED_IF_NULL(out);
std::unique_lock<std::mutex> lock(mux_);
auto it = io_request_.find(row_id);
if (it != io_request_.end()) {
*out = it->second.GetMutablePointer();
} else {
// We will create a new one.
auto alloc = Services::GetAllocator<TensorRowCacheRequest>();
auto r = io_request_.emplace(row_id, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>(alloc));
if (r.second) {
auto &mem = r.first->second;
RETURN_IF_NOT_OK(mem.allocate(1));
*out = mem.GetMutablePointer();
} else {
RETURN_STATUS_UNEXPECTED("Map insert fail.");
}
}
return Status::OK();
}
Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::shared_ptr<CacheClient> &cc,
const TensorRow &row) {
auto expected = State::kEmpty;
if (st_.compare_exchange_strong(expected, State::kDirty)) {
// We will do a deep copy but write directly into CacheRequest protobuf or shared memory
Status rc;
cleaner_copy_ =
std::make_shared<CacheRowRequest>(cc->server_connection_id_, cc->cookie(), cc->SupportLocalClient());
rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);
if (rc.IsOk()) {
// Send the request async. The cleaner will check the return code.
rc = cc->PushRequest(cleaner_copy_);
}
if (rc.IsError()) {
// Clean up the shared pointer and reset the state back to empty
cleaner_copy_.reset();
st_ = State::kEmpty;
}
}
return Status::OK();
}
Status CacheMergeOp::TensorRowCacheRequest::CheckCacheResult() {
auto expected = State::kDirty;
if (st_.compare_exchange_strong(expected, State::kClean)) {
// Success or not, we will release the memory.
// We simply move it out of the structure and let it go out of scope.
auto cache_request = std::move(cleaner_copy_);
RETURN_IF_NOT_OK(cache_request->Wait());
return Status::OK();
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
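Taken together, each cache-miss row drives a small state machine: kEmpty to kDirty when AsyncSendCacheRequest wins the compare-and-swap and fires the async write, kDirty to kClean when the Cleaner harvests the result, and back to kEmpty on failure so a later attempt can retry. A condensed sketch of the two call sites above:

// Miss worker thread:
TensorRowCacheRequest *rq;
RETURN_IF_NOT_OK(GetRq(row_id, &rq));
if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) {
if (rq->AsyncSendCacheRequest(cache_client_, row).IsOk()) {  // kEmpty -> kDirty
RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));              // hand the row id to the Cleaner
}
}
// Cleaner thread, later:
RETURN_IF_NOT_OK(rq->CheckCacheResult());                    // kDirty -> kClean, or reset to kEmpty on error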
......@@ -142,7 +142,7 @@ Status CacheOp::WaitForCachingAllRows() {
}
// Get statistics from the server, and if we are not the one to create the cache,
// wait until the state changes from build phase to fetch phase.
CacheServiceStat stat{};
bool BuildPhaseDone = true;
do {
RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
......@@ -157,6 +157,7 @@ Status CacheOp::WaitForCachingAllRows() {
MS_LOG(INFO) << "Number of rows cached: " << num_rows_;
MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached;
MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached;
MS_LOG(INFO) << "Average cache size : " << stat.avg_cache_sz;
// Now all rows are cached and we have done a sync-point check. The next phase is to
// pick up fetch input from the sampler and pass it up to the caller.
RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
......
......@@ -392,6 +392,13 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
ss_str = std::regex_replace(ss_str, std::regex("Num workers.*\n"), "");
ss_str = std::regex_replace(ss_str, std::regex("\\[workers.*\\]"), "");
// Filter out tcp/ip information
ss_str = std::regex_replace(ss_str, std::regex("Hostname.*\n"), "");
ss_str = std::regex_replace(ss_str, std::regex("Port.*\n"), "");
ss_str = std::regex_replace(ss_str, std::regex("Number of rpc workers.*\n"), "");
ss_str = std::regex_replace(ss_str, std::regex("Prefetch size.*\n"), "");
ss_str = std::regex_replace(ss_str, std::regex("Local client support.*\n"), "");
// Filter out Number of rows when generating the check sum
ss_str = std::regex_replace(ss_str, std::regex("Number of rows.*\n"), "");
......
......@@ -73,6 +73,7 @@ enum class StatusCode : char {
kProfilingError = 10,
kBoundingBoxOutOfBounds = 11,
kBoundingBoxInvalidShape = 12,
kSyntaxError = 13,
// Make this error code the last one. Add new error code above it.
kUnexpectedError = 127
};
......
......@@ -168,9 +168,9 @@ class MemGuard {
size_t GetSizeInBytes() const { return n_ * sizeof(T); }
private:
size_t n_;
allocator alloc_;
std::unique_ptr<T[]> ptr_;
};
} // namespace dataset
} // namespace mindspore
......
......@@ -82,6 +82,7 @@ class CachePool : public Service {
struct CacheStat {
int64_t num_mem_cached;
int64_t num_disk_cached;
int64_t average_cache_sz;
};
/// \brief Constructor
......
This diff is collapsed.
......@@ -86,6 +86,7 @@ class ReadableSlice {
class WritableSlice : public ReadableSlice {
public:
friend class StorageContainer;
friend class CacheService;
/// \brief Default constructor
WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {}
/// \brief This form of a constructor takes a pointer and its size.
......
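The friendship granted here lets CacheService carve zero-copy views over a caller-provided buffer, the pattern BatchFetch relies on (a sketch; the constructor forms are inferred from the usage shown above):

WritableSlice all(base_ptr, total_sz);                  // view over the whole destination buffer
WritableSlice row(all, /*offset=*/off, /*sz=*/row_sz);  // sub-slice for one row, no data copied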