Repository: magicwindyyd / mindspore (fork of MindSpore / mindspore)

Commit 8a08d0c3 — Phase 2 of CacheOp
Author: Jesse Lee
Date: July 18, 2020
Parent: b11ef57b

Showing 52 changed files with 3929 additions and 808 deletions (+3929, -808)
Changed files:

mindspore/ccsrc/minddata/dataset/CMakeLists.txt (+12, -12)
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc (+19, -1)
mindspore/ccsrc/minddata/dataset/core/constants.h (+2, -1)
mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt (+3, -5)
mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt (+42, -3)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc (+70, -0)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc (+396, -0)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h (+105, -0)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc (+73, -0)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h (+52, -0)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc (+91, -71)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h (+144, -18)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h (+90, -0)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_fbb.cc (+151, -0)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_fbb.h (+46, -0)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc.proto (+54, -0)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc (+161, -0)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h (+102, -0)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc (+203, -0)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h (+103, -0)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc (+121, -0)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc (+175, -145)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h (+207, -85)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc (+550, -123)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h (+154, -14)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc (+79, -30)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h (+18, -4)
mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs (+14, -1)
mindspore/ccsrc/minddata/dataset/engine/cache/stub/cache_grpc_client.h (+45, -0)
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc (+119, -20)
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h (+15, -1)
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc (+95, -80)
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h (+30, -14)
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc (+2, -1)
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc (+7, -0)
mindspore/ccsrc/minddata/dataset/include/status.h (+1, -0)
mindspore/ccsrc/minddata/dataset/util/allocator.h (+1, -1)
mindspore/ccsrc/minddata/dataset/util/arena.h (+15, -15)
mindspore/ccsrc/minddata/dataset/util/cache_pool.cc (+9, -0)
mindspore/ccsrc/minddata/dataset/util/cache_pool.h (+1, -0)
mindspore/ccsrc/minddata/dataset/util/queue_map.h (+127, -0)
mindspore/ccsrc/minddata/dataset/util/services.cc (+11, -43)
mindspore/ccsrc/minddata/dataset/util/services.h (+20, -12)
mindspore/ccsrc/minddata/dataset/util/slice.h (+1, -0)
mindspore/ccsrc/minddata/dataset/util/status.cc (+3, -0)
mindspore/ccsrc/minddata/dataset/util/status.h (+1, -0)
mindspore/ccsrc/minddata/dataset/util/task_manager.cc (+2, -0)
mindspore/ccsrc/minddata/dataset/util/task_manager.h (+12, -1)
mindspore/dataset/engine/cache_client.py (+11, -2)
tests/ut/cpp/dataset/cache_op_test.cc (+131, -100)
tests/ut/python/dataset/test_cache_map.py (+7, -4)
tests/ut/python/dataset/test_cache_nomap.py (+26, -1)
mindspore/ccsrc/minddata/dataset/CMakeLists.txt

@@ -24,6 +24,11 @@ if (ENABLE_TDTQUE)
     add_definitions(-D ENABLE_TDTQUE)
     message(STATUS "TDT queue is enabled")
 endif ()
+if (MS_BUILD_GRPC)
+    set(ENABLE_CACHE true)
+    add_definitions(-D ENABLE_CACHE)
+    message(STATUS "Cache is enabled")
+endif ()
 # conde coverage
 # option(ENABLE_COVERAGE "Enable code coverage report" OFF)

@@ -47,10 +52,6 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/minddata/dataset/include
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib")
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
-include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
-set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
-ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
 ################## Include sub-modules ###############################
 add_subdirectory(util)
 add_subdirectory(core)

@@ -70,8 +71,6 @@ add_dependencies(engine-datasetops-source-sampler core)
 add_dependencies(engine-datasetops core)
 add_dependencies(engine-datasetops-mapop core)
 add_dependencies(engine-opt core)
-add_dependencies(engine-cache-client core)
-add_dependencies(engine-cache-server core)
 add_dependencies(engine-perf core)
 add_dependencies(engine-gnn core)
 add_dependencies(engine core)

@@ -85,7 +84,11 @@ endif()
 if (ENABLE_TDTQUE)
     add_dependencies(engine-tdt core)
 endif ()
+if (ENABLE_CACHE)
+    add_dependencies(engine-datasetops engine-cache-client)
+    add_dependencies(engine-cache-client core)
+    add_dependencies(engine-cache-server core)
+endif ()
 ################### Create _c_dataengine Library ######################
 set(submodules
     $<TARGET_OBJECTS:core>

@@ -105,7 +108,6 @@ set(submodules
     $<TARGET_OBJECTS:engine-datasetops>
     $<TARGET_OBJECTS:engine-opt>
     $<TARGET_OBJECTS:engine-cache-client>
-    $<TARGET_OBJECTS:engine-cache-server>
     $<TARGET_OBJECTS:engine>
     $<TARGET_OBJECTS:text>
     $<TARGET_OBJECTS:text-kernels>

@@ -123,8 +125,6 @@ else ()
     add_library(_c_dataengine SHARED ${submodules})
 endif ()
-add_dependencies(_c_dataengine generated_engine_files)
 if (ENABLE_PYTHON)
     set_target_properties(_c_dataengine PROPERTIES
         PREFIX "${PYTHON_MODULE_PREFIX}"

@@ -187,6 +187,6 @@ else()
     endif ()
 endif ()
 if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
     if (MS_BUILD_GRPC)
         target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++)
     endif ()
 endif ()
\ No newline at end of file
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/cache/bindings.cc

@@ -22,7 +22,25 @@ namespace dataset {
 PYBIND_REGISTER(CacheClient, 0, ([](const py::module *m) {
                   (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
-                    .def(py::init<uint32_t, uint64_t, bool>());
+                    .def(py::init([](session_id_type id, uint64_t mem_sz, bool spill, int32_t port,
+                                     int32_t prefetch_sz) {
+                      std::shared_ptr<CacheClient> cc;
+                      CacheClient::Builder builder;
+                      builder.SetSessionId(id).SetCacheMemSz(mem_sz).SetSpill(spill).SetPort(port).SetPrefetchSize(
+                        prefetch_sz);
+                      THROW_IF_ERROR(builder.Build(&cc));
+                      return cc;
+                    }))
+                    .def("GetStat", [](CacheClient &cc) {
+                      CacheServiceStat stat{};
+                      THROW_IF_ERROR(cc.GetStat(&stat));
+                      return stat;
+                    });
+                  (void)py::class_<CacheServiceStat>(*m, "CacheServiceStat")
+                    .def(py::init<>())
+                    .def_readwrite("avg_cache_sz", &CacheServiceStat::avg_cache_sz)
+                    .def_readwrite("num_mem_cached", &CacheServiceStat::num_mem_cached)
+                    .def_readwrite("num_disk_cached", &CacheServiceStat::num_disk_cached);
                 }));
 }  // namespace dataset
mindspore/ccsrc/minddata/dataset/core/constants.h

@@ -72,7 +72,8 @@ constexpr uint32_t kCfgMonitorSamplingInterval = 10;
 // Invalid OpenCV type should not be from 0 to 7 (opencv4/opencv2/core/hal/interface.h)
 constexpr uint8_t kCVInvalidType = 255;
 
-using connection_id_type = int64_t;
+using connection_id_type = uint64_t;
+using session_id_type = uint32_t;
 using row_id_type = int64_t;
 }  // namespace dataset
 }  // namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/CMakeLists.txt

@@ -20,10 +20,8 @@ if (ENABLE_PYTHON)
     target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
 endif ()
-add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf engine-cache-client engine-datasetops-mapop)
 if (ENABLE_TDTQUE)
+    add_dependencies(engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf
+            engine-cache-client engine-cache-server engine-datasetops-mapop)
+else ()
+    add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf
+            engine-cache-client engine-cache-server engine-datasetops-mapop)
-    add_dependencies(engine engine-tdt)
 endif ()
mindspore/ccsrc/minddata/dataset/engine/cache/CMakeLists.txt

 include_directories("${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
 set(MD_FLATBUFFER_OU "${CMAKE_BINARY_DIR}/minddata/dataset/engine/cache")
 ms_build_flatbuffers("de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${MD_FLATBUFFER_OU})
 file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
 add_library(engine-cache-client OBJECT
     cache_client.cc
     cache_fbb.cc
     cache_request.cc)
-add_library(engine-cache-server OBJECT
-    cache_service.cc
-    cache_server.cc)
 if (ENABLE_CACHE)
     ms_grpc_generate(CACHE_GRPC_SRCS CACHE_GRPC_HDRS cache_grpc.proto)
     target_sources(engine-cache-client PUBLIC ${CACHE_GRPC_SRCS} cache_grpc_client.cc)
     add_library(engine-cache-server OBJECT
         ${CACHE_GRPC_SRCS}
         cache_grpc_server.cc
         cache_arena.cc
         cache_service.cc
         cache_server.cc)
     add_executable(cache_server cache_main.cc)
     target_link_libraries(cache_server
         engine-cache-server
         $<TARGET_OBJECTS:utils>
         mindspore
         mindspore::glog
         mindspore::protobuf
         mindspore::grpc++
         mindspore_gvar
         ${PYTHON_LIBRARIES}
         ${SECUREC_LIBRARY}
         pthread)
     add_executable(cache_admin cache_admin.cc cache_admin_arg.cc)
     target_link_libraries(cache_admin _c_dataengine _c_mindrecord ${PYTHON_LIBRARIES} mindspore::glog)
     add_dependencies(engine-cache-server generated_engine_files)
 else ()
     ms_protobuf_generate(CACHE_PROTO_SRCS CACHE_PRTO_HDRS cache_grpc.proto)
     target_sources(engine-cache-client PUBLIC ${CACHE_PROTO_SRCS})
 endif ()
 add_dependencies(engine-cache-client generated_engine_files)
mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin.cc (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <unistd.h>
#include <iostream>
#ifdef USE_GLOG
#include <glog/logging.h>
#endif
#include "minddata/dataset/engine/cache/cache_admin_arg.h"

namespace ds = mindspore::dataset;

int main(int argc, char **argv) {
  ds::Status rc;
  ds::CacheAdminArgHandler args;
  std::stringstream arg_stream;

#ifdef USE_GLOG
  FLAGS_log_dir = "/tmp";
  google::InitGoogleLogging(argv[0]);
#endif

  std::string warningMsg;
  warningMsg.reserve(512);
  warningMsg += "WARNING:\n";
  warningMsg += "cache_admin and the cache server that it controls are currently only used for experimental research";
  warningMsg += " purposes at this time.\n";
  warningMsg += "It is not intended for general availability yet as it may not be stable.  Use it at your own risk.\n";

  // A warning message until the code is mature enough.
  std::cerr << warningMsg << std::endl;

  if (argc == 1) {
    args.Help();
    return 0;
  }

  // ingest all the args into a string stream for parsing
  for (int i = 1; i < argc; ++i) {
    arg_stream << " " << std::string(argv[i]);
  }

  // Parse the args
  rc = args.ParseArgStream(&arg_stream);
  if (!rc.IsOk()) {
    std::cerr << rc.ToString() << std::endl;
    return 1;
  }

  // Execute the command
  rc = args.RunCommand();
  if (!rc.IsOk()) {
    std::cerr << rc.ToString() << std::endl;
    return 1;
  }

  return 0;
}
mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.cc (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_admin_arg.h"
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/wait.h>
#include <unistd.h>
#include <cerrno>
#include <iostream>
#include <string>
#include <cstdlib>
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/cache_client.h"
#include "minddata/dataset/util/path.h"

namespace mindspore {
namespace dataset {
const char CacheAdminArgHandler::kDefaultHost[] = "127.0.0.1";
const char CacheAdminArgHandler::kServerBinary[] = "cache_server";
const char CacheAdminArgHandler::kDefaultSpillDir[] = "/tmp";

CacheAdminArgHandler::CacheAdminArgHandler()
    : port_(kDefaultPort),
      session_id_(0),
      num_workers_(kDefaultNumWorkers),
      shm_mem_sz_(kDefaultSharedMemorySizeInGB),
      log_level_(kDefaultLogLevel),
      hostname_(kDefaultHost),
      spill_dir_(kDefaultSpillDir),
      command_id_(CommandId::kCmdUnknown) {
  // Initialize the command mappings
  arg_map_["-h"] = ArgValue::kArgHost;
  arg_map_["--hostname"] = ArgValue::kArgHost;
  arg_map_["-p"] = ArgValue::kArgPort;
  arg_map_["--port"] = ArgValue::kArgPort;
  arg_map_["--start"] = ArgValue::kArgStart;
  arg_map_["--stop"] = ArgValue::kArgStop;
  arg_map_["--help"] = ArgValue::kArgHelp;
  arg_map_["--generate_session"] = ArgValue::kArgGenerateSession;
  arg_map_["-g"] = ArgValue::kArgGenerateSession;
  arg_map_["--destroy_session"] = ArgValue::kArgDestroySession;
  arg_map_["-d"] = ArgValue::kArgDestroySession;
  arg_map_["--spilldir"] = ArgValue::kArgSpillDir;
  arg_map_["-s"] = ArgValue::kArgSpillDir;
  arg_map_["-w"] = ArgValue::kArgNumWorkers;
  arg_map_["--workers"] = ArgValue::kArgNumWorkers;
  arg_map_["-m"] = ArgValue::kArgSharedMemorySize;
  arg_map_["--shared_memory_size"] = ArgValue::kArgSharedMemorySize;
  arg_map_["-l"] = ArgValue::kArgLogLevel;
  arg_map_["--minloglevel"] = ArgValue::kArgLogLevel;
  // Initialize argument tracker with false values
  for (int16_t i = 0; i < static_cast<int16_t>(ArgValue::kArgNumArgs); ++i) {
    ArgValue currAV = static_cast<ArgValue>(i);
    used_args_[currAV] = false;
  }
}

Status CacheAdminArgHandler::AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
                                       CommandId command_id) {
  // Detect if the user tried to provide this argument more than once
  ArgValue selected_arg = arg_map_[option];
  if (used_args_[selected_arg]) {
    std::string err_msg = "The " + option + " argument was given more than once.";
    return Status(StatusCode::kSyntaxError, err_msg);
  }
  // Flag that this arg is used now
  used_args_[selected_arg] = true;
  // Some options are just arguments, for example "--port 50052" is not a command, it's just an argument.
  // Other options are actual commands, for example "--destroy_session 1234".  This executes the destroy session.
  // If this option is also a command, make sure there has not been multiple commands given before assigning it.
  if (command_id != CommandId::kCmdUnknown) {
    if (command_id_ != CommandId::kCmdUnknown) {
      std::string err_msg = "Only one command at a time is allowed.  Invalid command: " + option;
      return Status(StatusCode::kSyntaxError, err_msg);
    } else {
      command_id_ = command_id;
    }
  }
  std::string value_as_string;
  // Fetch the argument from the arg stream into a string
  *arg_stream >> value_as_string;
  if (value_as_string.empty()) {
    std::string err_msg = option + " option requires an argument field.  Syntax: " + option + " <field>";
    return Status(StatusCode::kSyntaxError, err_msg);
  }
  // Now, attempt to convert the value into it's string format for output
  try {
    *out_arg = std::stoul(value_as_string);
  } catch (const std::exception &e) {
    std::string err_msg = "Invalid numeric value: " + value_as_string;
    return Status(StatusCode::kSyntaxError, err_msg);
  }
  return Status::OK();
}

Status CacheAdminArgHandler::AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
                                       CommandId command_id) {
  // Detect if the user tried to provide this argument more than once
  ArgValue selected_arg = arg_map_[option];
  if (used_args_[selected_arg]) {
    std::string err_msg = "The " + option + " argument was given more than once.";
    return Status(StatusCode::kSyntaxError, err_msg);
  }
  // Flag that this arg is used now
  used_args_[selected_arg] = true;
  // Some options are just arguments, for example "--hostname "127.0.0.1" is not a command, it's just an argument.
  // Other options are actual commands, for example "--start".
  // If this option is also a command, make sure there has not been multiple commands given before assigning it.
  if (command_id != CommandId::kCmdUnknown) {
    if (command_id_ != CommandId::kCmdUnknown) {
      std::string err_msg = "Only one command at a time is allowed.  Invalid command: " + option;
      return Status(StatusCode::kSyntaxError, err_msg);
    } else {
      command_id_ = command_id;
    }
  }
  // If there is no argument to get, such as the --start command, then out_arg will be a nullptr.
  if (out_arg != nullptr) {
    // Fetch the argument from the arg stream into a string
    *arg_stream >> *out_arg;
    if (out_arg->empty()) {
      std::string err_msg = option + " option requires an argument field.  Syntax: " + option + " <field>";
      return Status(StatusCode::kSyntaxError, err_msg);
    }
  }
  return Status::OK();
}

Status CacheAdminArgHandler::ParseArgStream(std::stringstream *arg_stream) {
  std::string tok;
  while (*arg_stream >> tok) {
    switch (arg_map_[tok]) {
      case ArgValue::kArgHost: {
        RETURN_IF_NOT_OK(AssignArg(tok, &hostname_, arg_stream));
        break;
      }
      case ArgValue::kArgPort: {
        RETURN_IF_NOT_OK(AssignArg(tok, &port_, arg_stream));
        break;
      }
      case ArgValue::kArgStart: {
        RETURN_IF_NOT_OK(AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdStart));
        break;
      }
      case ArgValue::kArgStop: {
        RETURN_IF_NOT_OK(AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdStop));
        break;
      }
      case ArgValue::kArgGenerateSession: {
        RETURN_IF_NOT_OK(
          AssignArg(tok, static_cast<std::string *>(nullptr), arg_stream, CommandId::kCmdGenerateSession));
        break;
      }
      case ArgValue::kArgHelp: {
        command_id_ = CommandId::kCmdHelp;
        break;
      }
      case ArgValue::kArgDestroySession: {
        // session_id is an unsigned type. We may need to template the AssignArg function so that
        // it can handle different flavours of integers instead of just int32_t.
        int32_t session_int;
        RETURN_IF_NOT_OK(AssignArg(tok, &session_int, arg_stream, CommandId::kCmdDestroySession));
        session_id_ = session_int;
        break;
      }
      case ArgValue::kArgNumWorkers: {
        RETURN_IF_NOT_OK(AssignArg(tok, &num_workers_, arg_stream));
        break;
      }
      case ArgValue::kArgSpillDir: {
        RETURN_IF_NOT_OK(AssignArg(tok, &spill_dir_, arg_stream));
        break;
      }
      case ArgValue::kArgSharedMemorySize: {
        RETURN_IF_NOT_OK(AssignArg(tok, &shm_mem_sz_, arg_stream));
        break;
      }
      case ArgValue::kArgLogLevel: {
        RETURN_IF_NOT_OK(AssignArg(tok, &log_level_, arg_stream));
        break;
      }
      default: {
        // Save space delimited trailing arguments
        trailing_args_ += (" " + tok);
        break;
      }
    }
  }
  RETURN_IF_NOT_OK(Validate());
  return Status::OK();
}

Status CacheAdminArgHandler::Validate() {
  // This sanity check is delayed until now in case there may be valid use-cases of trailing args.
  // Any unhandled arguments at this point is an error.
  if (!trailing_args_.empty()) {
    std::string err_msg = "Invalid arguments provided: " + trailing_args_;
    return Status(StatusCode::kSyntaxError, err_msg);
  }
  // The user must pick at least one command.  i.e. it's meaningless to just give a hostname or port but no command to
  // run.
  if (command_id_ == CommandId::kCmdUnknown) {
    std::string err_msg = "No command provided";
    return Status(StatusCode::kSyntaxError, err_msg);
  }
  // Additional checks here
  if (num_workers_ < 1) return Status(StatusCode::kSyntaxError, "Number of workers must be positive value.");
  if (log_level_ < 0 || log_level_ > 3) return Status(StatusCode::kSyntaxError, "Log level must be in range (0..3).");
  // port range check?
  return Status::OK();
}

Status CacheAdminArgHandler::RunCommand() {
  switch (command_id_) {
    case CommandId::kCmdHelp: {
      Help();
      break;
    }
    case CommandId::kCmdStart: {
      RETURN_IF_NOT_OK(StartServer());
      break;
    }
    case CommandId::kCmdStop: {
      RETURN_IF_NOT_OK(StopServer());
      break;
    }
    case CommandId::kCmdGenerateSession: {
      CacheClientGreeter comm(hostname_, port_, 1);
      RETURN_IF_NOT_OK(comm.ServiceStart());
      auto rq = std::make_shared<GenerateSessionIdRequest>();
      RETURN_IF_NOT_OK(comm.HandleRequest(rq));
      RETURN_IF_NOT_OK(rq->Wait());
      std::cout << rq->GetSessionId() << std::endl;
      break;
    }
    case CommandId::kCmdDestroySession: {
      CacheClientGreeter comm(hostname_, port_, 1);
      RETURN_IF_NOT_OK(comm.ServiceStart());
      CacheClientInfo cinfo;
      cinfo.set_session_id(session_id_);
      auto rq = std::make_shared<DropSessionRequest>(cinfo);
      RETURN_IF_NOT_OK(comm.HandleRequest(rq));
      RETURN_IF_NOT_OK(rq->Wait());
      std::cout << "Drop session successful" << std::endl;
      break;
    }
    default: {
      RETURN_STATUS_UNEXPECTED("Invalid cache admin command id.");
      break;
    }
  }
  return Status::OK();
}

Status CacheAdminArgHandler::StartServer() {
  // There currently does not exist any "install path" or method to identify which path the installed binaries will
  // exist in.  As a temporary approach, we will assume that the server binary shall exist in the same path as the
  // cache_admin binary (this process).
  const std::string self_proc = "/proc/self/exe";
  std::string canonical_path;
  canonical_path.resize(400);  // PATH_MAX is large.  This value should be big enough for our use.
  // Some lower level OS library calls are needed here to determine the binary path.
  // Fetch the path of this binary for admin_cache into C character array and then truncate off the binary name so that
  // we are left with only the absolute path
  if (realpath(self_proc.data(), canonical_path.data()) == nullptr) {
    std::string err_msg = "Failed to identify cache admin binary path: " + std::to_string(errno);
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  canonical_path.resize(strlen(canonical_path.data()));
  int last_seperator = canonical_path.find_last_of('/');
  CHECK_FAIL_RETURN_UNEXPECTED(last_seperator != std::string::npos, "No / found");
  // truncate the binary name so we are left with the absolute path of cache_admin binary
  canonical_path.resize(last_seperator + 1);
  std::string cache_server_binary = canonical_path + std::string(kServerBinary);

  // Create a pipe before we fork. If all goes well, the child will run as a daemon in the background
  // and never returns until shutdown. If there is any error, the child will notify us through the pipe.
  int fd[2];
  if (pipe(fd) == -1) {
    std::string err_msg = "Failed to create a pipe for communication " + std::to_string(errno);
    RETURN_STATUS_UNEXPECTED(err_msg);
  }

  // fork the child process to become the daemon
  pid_t pid;
  pid = fork();

  // failed to fork
  if (pid < 0) {
    std::string err_msg = "Failed to fork process for cache server: " + std::to_string(errno);
    RETURN_STATUS_UNEXPECTED(err_msg);
  } else if (pid > 0) {
    // As a parent, we close the write end. We only listen.
    close(fd[1]);
    dup2(fd[0], 0);
    close(fd[0]);
    wait(nullptr);
    std::string msg;
    const int32_t buf_sz = 1024;
    msg.resize(buf_sz);
    auto n = read(0, msg.data(), buf_sz);
    if (n < 0) {
      std::string err_msg = "Failed to read from pipeline " + std::to_string(errno);
      RETURN_STATUS_UNEXPECTED(err_msg);
    }
    msg.resize(n);
    std::cout << msg << std::endl;
    return Status::OK();
  } else {
    // Child here ...
    // Close all stdin, redirect stdout and stderr to the write end of the pipe.
    close(fd[0]);
    dup2(fd[1], 1);
    dup2(fd[1], 2);
    close(0);
    close(fd[1]);
    // exec the cache server binary in this process
    std::string port_string = std::to_string(port_);
    std::string workers_string = std::to_string(num_workers_);
    std::string shared_memory_string = std::to_string(shm_mem_sz_);
    std::string minloglevel_string = std::to_string(log_level_);
    std::string daemonize_string = "true";
    char *argv[8];
    argv[0] = cache_server_binary.data();  // First arg is usually the binary name
    argv[1] = spill_dir_.data();
    argv[2] = workers_string.data();
    argv[3] = port_string.data();
    argv[4] = shared_memory_string.data();
    argv[5] = minloglevel_string.data();
    argv[6] = daemonize_string.data();
    argv[7] = nullptr;

    // Now exec the binary
    execv(argv[0], argv);
    // If the exec was successful, this line will never be reached due to process image being replaced.
    // ..unless exec failed.
    std::string err_msg = "Failed to exec cache server: " + cache_server_binary;
    std::cerr << err_msg << std::endl;
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
}

Status CacheAdminArgHandler::StopServer() {
  CacheClientGreeter comm(hostname_, port_, 1);
  RETURN_IF_NOT_OK(comm.ServiceStart());
  auto rq = std::make_shared<ShutdownRequest>();
  RETURN_IF_NOT_OK(comm.HandleRequest(rq));
  return Status::OK();
}

void CacheAdminArgHandler::Help() {
  std::cerr << "Syntax:\n";
  std::cerr << "   cache_admin [--start | --stop]\n";
  std::cerr << "               [ [-h | --hostname] <hostname> ]\n";
  std::cerr << "               [ [-p | --port] <port number> ]\n";
  std::cerr << "               [ [-g | --generate_session] ]\n";
  std::cerr << "               [ [-d | --destroy_session] <session id> ]\n";
  std::cerr << "               [ [-w | --workers] <number of workers> ]\n";
  std::cerr << "               [ [-s | --spilldir] <spilling directory> ]\n";
  std::cerr << "               [ [-m | --shared_memory_size] <shared memory size> ]\n";
  std::cerr << "               [ [-l | --minloglevel] <log level> ]\n";
  std::cerr << "               [--help]" << std::endl;
}
}  // namespace dataset
}  // namespace mindspore
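The warning banner, argument parser, and command dispatch above add up to a small CLI. Illustrative invocations (the flag spellings are exactly those registered in arg_map_; the worker count, port, and session id are made-up values):

  cache_admin --start -w 16 -s /tmp -m 4
  cache_admin --generate_session
  cache_admin --destroy_session 1234 -p 50052
  cache_admin --stop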
mindspore/ccsrc/minddata/dataset/engine/cache/cache_admin_arg.h (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <sstream>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/cache/cache_client.h"

namespace mindspore {
namespace dataset {
class CacheAdminArgHandler {
 public:
  static constexpr int32_t kDefaultPort = 50052;
  static constexpr int32_t kDefaultNumWorkers = 32;
  static constexpr int32_t kDefaultSharedMemorySizeInGB = 4;
  static constexpr int32_t kDefaultLogLevel = 1;
  static const char kDefaultHost[];
  static const char kServerBinary[];
  static const char kDefaultSpillDir[];

  // These are the actual command types to execute
  enum class CommandId : int16_t {
    kCmdHelp = 0,
    kCmdStart = 1,
    kCmdStop = 2,
    kCmdGenerateSession = 3,
    kCmdDestroySession = 4,
    kCmdUnknown = 32767
  };

  CacheAdminArgHandler();

  ~CacheAdminArgHandler() = default;

  Status ParseArgStream(std::stringstream *arg_stream);

  Status RunCommand();

  void Help();

 private:
  // These are the supported argument string integer mappings
  enum class ArgValue : int16_t {
    kArgUnknown = 0,  // Must be at position 0.  invalid map lookups in arg_map_ default to value 0
    kArgStart = 1,
    kArgStop = 2,
    kArgHost = 3,
    kArgPort = 4,
    kArgHelp = 5,
    kArgGenerateSession = 6,
    kArgDestroySession = 7,
    kArgSpillDir = 8,
    kArgNumWorkers = 9,
    kArgSharedMemorySize = 10,
    kArgLogLevel = 11,
    kArgNumArgs = 12  // Must be the last position to provide a count
  };

  Status StartServer();

  Status StopServer();

  Status AssignArg(std::string option, int32_t *out_arg, std::stringstream *arg_stream,
                   CommandId command_id = CommandId::kCmdUnknown);

  Status AssignArg(std::string option, std::string *out_arg, std::stringstream *arg_stream,
                   CommandId command_id = CommandId::kCmdUnknown);

  Status Validate();

  CommandId command_id_;
  int32_t port_;
  int32_t num_workers_;
  int32_t shm_mem_sz_;
  int32_t log_level_;
  session_id_type session_id_;
  std::string hostname_;
  std::string spill_dir_;
  std::string trailing_args_;
  std::map<std::string, ArgValue> arg_map_;
  std::map<ArgValue, bool> used_args_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ADMIN_ARG_H_
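The comment on kArgUnknown is load-bearing: ParseArgStream looks tokens up with arg_map_[tok], and std::map::operator[] value-initializes missing keys. A minimal standalone sketch of that behavior (local enum and flag string are hypothetical, not from the header):

#include <cstdint>
#include <map>
#include <string>

enum class ArgValue : int16_t { kArgUnknown = 0, kArgStart = 1 };

int main() {
  std::map<std::string, ArgValue> arg_map;
  arg_map["--start"] = ArgValue::kArgStart;
  // operator[] inserts a value-initialized ArgValue for an unseen key,
  // i.e. ArgValue{0} == kArgUnknown -- which is why kArgUnknown must sit at 0.
  ArgValue v = arg_map["--no-such-flag"];
  return v == ArgValue::kArgUnknown ? 0 : 1;
}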
mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.cc (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/util/path.h"
namespace mindspore {
namespace dataset {
CachedSharedMemoryArena::CachedSharedMemoryArena(int32_t port, size_t val_in_GB)
    : Arena::Arena(val_in_GB * 1024), port_(port), shmid_(-1) {}

CachedSharedMemoryArena::~CachedSharedMemoryArena() {
#if CACHE_LOCAL_CLIENT
  if (this->ptr_ != nullptr && this->ptr_ != reinterpret_cast<void *>(-1)) {
    shmdt(this->ptr_);
  }
  this->ptr_ = nullptr;
  if (shmid_ != -1) {
    shmctl(shmid_, IPC_RMID, nullptr);
    // Also remove the path we use to generate ftok.
    Path p(PortToUnixSocketPath(port_));
    (void)p.Remove();
  }
#endif
}

Status CachedSharedMemoryArena::CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port,
                                            size_t val_in_GB) {
  RETURN_UNEXPECTED_IF_NULL(out);
#if CACHE_LOCAL_CLIENT
  auto ba = new (std::nothrow) CachedSharedMemoryArena(port, val_in_GB);
  if (ba == nullptr) {
    return Status(StatusCode::kOutOfMemory);
  }
  // Transfer the ownership of this pointer. Any future error in the processing we will have
  // the destructor of *out to deal.
  (*out).reset(ba);
  // Generate the ftok using a combination of port.
  int err;
  auto shm_key = PortToFtok(port, &err);
  if (shm_key == (key_t)-1) {
    std::string errMsg = "Ftok failed with errno " + std::to_string(err);
    RETURN_STATUS_UNEXPECTED(errMsg);
  }
  auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
  ba->shmid_ = shmget(shm_key, ba->size_in_bytes_, IPC_CREAT | IPC_EXCL | access_mode);
  if (ba->shmid_) {
    ba->ptr_ = shmat(ba->shmid_, nullptr, 0);
    if (ba->ptr_ == reinterpret_cast<void *>(-1)) {
      RETURN_STATUS_UNEXPECTED("Shared memory attach failed. Errno " + std::to_string(errno));
    }
  } else {
    RETURN_STATUS_UNEXPECTED("Shared memory creation failed. Errno " + std::to_string(errno));
  }
  uint64_t num_blks = ba->size_in_bytes_ / ARENA_BLK_SZ;
  MS_LOG(DEBUG) << "Size of memory pool is " << num_blks << ", number of blocks of size is " << ARENA_BLK_SZ << ".";
  ba->tr_.Insert(0, num_blks);
#endif
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/cache/cache_arena.h (new file, mode 100644)

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
#include <memory>
#include <string>
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/engine/cache/cache_common.h"
namespace mindspore {
namespace dataset {
/// This is a derived class of Arena but resides in shared memory
class CachedSharedMemoryArena : public Arena {
 public:
  ~CachedSharedMemoryArena() override;
  /// \brief Create an Arena in shared memory
  /// \param[out] p_ba Pointer to a unique_ptr
  /// \param shmkey Shared memory key
  /// \param val_in_GB size of shared memory in gigabyte
  /// \return Status object
  static Status CreateArena(std::unique_ptr<CachedSharedMemoryArena> *out, int32_t port, size_t val_in_GB);

  /// \brief This returns where we attach to the shared memory.
  /// Some gRPC requests will ask for a shared memory block, and
  /// we can't return the absolute address as this makes no sense
  /// in the client. So instead we will return an address relative
  /// to the base address of the shared memory where we attach to.
  /// \return Base address of the shared memory.
  const void *SharedMemoryBaseAddr() const { return this->ptr_; }

 private:
  int32_t port_;
  int shmid_;
  /// Private constructor. Not to be called directly.
  CachedSharedMemoryArena(int32_t port, size_t val_in_GB);
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_ARENA_H_
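A sketch of the relative-addressing scheme that SharedMemoryBaseAddr() enables. Only SharedMemoryBaseAddr() comes from the code above; server_arena, row_ptr, and client_comm are hypothetical names. The point is that server and client attach the same segment at different virtual addresses, so only offsets cross the rpc boundary:

// Server side: compute an offset relative to the server's own attach point.
const char *server_base = static_cast<const char *>(server_arena->SharedMemoryBaseAddr());
const char *row_ptr = server_base + 4096;            // where a row was written inside the arena
int64_t offset = row_ptr - server_base;              // only this offset is sent in the reply

// Client side: rebase the offset onto the client's own attach address.
const char *client_base = static_cast<const char *>(client_comm->SharedMemoryBaseAddr());
const void *row = client_base + offset;              // same bytes, different virtual address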
mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.cc

@@ -17,29 +17,45 @@
 #include <iomanip>
 #include "minddata/dataset/engine/cache/cache_client.h"
 #include "minddata/dataset/engine/cache/cache_request.h"
-#include "minddata/dataset/engine/cache/cache_service.h"
+#include "minddata/dataset/engine/cache/cache_fbb.h"
 #include "minddata/dataset/util/bit.h"
 
 namespace mindspore {
 namespace dataset {
 
 // Constructor
-CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill)
-    : server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {}
+CacheClient::CacheClient(session_id_type session_id, uint64_t cache_mem_sz, bool spill, std::string hostname,
+                         int32_t port, int32_t num_workers, int32_t prefetch_size)
+    : server_connection_id_(0),
+      cache_mem_sz_(cache_mem_sz),
+      spill_(spill),
+      local_bypass_(false),
+      hostname_(std::move(hostname)),
+      port_(port),
+      num_workers_(num_workers),
+      prefetch_size_(prefetch_size) {
+  cinfo_.set_session_id(session_id);
+  comm_ = std::make_shared<CacheClientGreeter>(hostname_, port_, num_workers_);
+}
 
 // print method for display cache details
 void CacheClient::Print(std::ostream &out) const {
-  out << "  Session id: " << session_id_ << "\n  Cache crc: " << cache_crc_
-      << "\n  Server cache id: " << server_connection_id_ << "\n  Cache mem size: " << cache_mem_sz_
-      << "\n  Spilling: " << std::boolalpha << spill_;
+  out << "  Session id: " << session_id() << "\n  Cache crc: " << cinfo_.crc()
+      << "\n  Server cache id: " << server_connection_id_ << "\n  Cache mem size: " << getCacheMemSz()
+      << "\n  Spilling: " << std::boolalpha << isSpill() << "\n  Hostname: " << getHostname()
+      << "\n  Port: " << getPort() << "\n  Number of rpc workers: " << getNumWorkers()
+      << "\n  Prefetch size: " << getPrefetchSize() << "\n  Local client support: " << std::boolalpha
+      << SupportLocalClient();
 }
 
 Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
-  CacheRowRequest rq(server_connection_id_, cookie());
-  RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row));
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
+  auto rq = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
+  RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(this, row));
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
   if (row_id_from_server != nullptr) {
-    *row_id_from_server = rq.GetRowIdAfterCache();
+    *row_id_from_server = rq->GetRowIdAfterCache();
   }
   return Status::OK();
 }

@@ -47,29 +63,19 @@ Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_serv
 Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
   std::unique_ptr<DataBuffer> db_ptr = std::move(in);
   auto num_rows = db_ptr->NumRows();
-  std::vector<TensorRow> all_rows;
+  // We will send the requests async first on all rows and do a final wait.
   if (num_rows > 0) {
-    all_rows.reserve(num_rows);
-    // Break down the DataBuffer into TensorRow. We will send the requests async
-    // and then do a final wait.
-    MemGuard<CacheRowRequest> rq_arr;
-    RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie()));
-    CacheServer &cs = CacheServer::GetInstance();
+    auto arr = std::make_unique<std::shared_ptr<CacheRowRequest>[]>(num_rows);
     for (auto i = 0; i < num_rows; ++i) {
       TensorRow row;
-      auto rq = rq_arr[i];
       RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
-      RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row));
-      RETURN_IF_NOT_OK(cs.PushRequest(rq));
-      // We can't let row go out of scope. Otherwise it will free all the tensor memory.
-      // So park it in the vector. When this function go out of scope, its memory
-      // will be freed.
-      all_rows.push_back(std::move(row));
+      arr[i] = std::make_shared<CacheRowRequest>(server_connection_id_, cookie(), SupportLocalClient());
+      RETURN_IF_NOT_OK(arr[i]->SerializeCacheRowRequest(this, row));
+      RETURN_IF_NOT_OK(PushRequest(arr[i]));
     }
-    // Now we wait for the requests to be done.
+    // Now we wait for them to come back
    for (auto i = 0; i < num_rows; ++i) {
-      auto rq = rq_arr[i];
-      RETURN_IF_NOT_OK(rq->Wait());
+      RETURN_IF_NOT_OK(arr[i]->Wait());
    }
  }
  return Status::OK();

@@ -77,11 +83,21 @@ Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
 Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
   RETURN_UNEXPECTED_IF_NULL(out);
-  BatchFetchRequest rq(server_connection_id_, row_id);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
-  RETURN_IF_NOT_OK(rq.RestoreRows(out));
-  return Status::OK();
+  auto rq = std::make_shared<BatchFetchRequest>(server_connection_id_, row_id, SupportLocalClient());
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  int64_t mem_addr;
+  Status rc = rq->RestoreRows(out, comm_->SharedMemoryBaseAddr(), &mem_addr);
+  // Free the memory by sending a request back to the server.
+  if (mem_addr != -1) {
+    auto mfree_req = std::make_shared<FreeSharedBlockRequest>(server_connection_id_, mem_addr);
+    Status rc2 = PushRequest(mfree_req);
+    // But we won't wait for the result for the sake of performance.
+    if (rc.IsOk() && rc2.IsError()) {
+      rc = rc2;
+    }
+  }
+  return rc;
 }
 
 Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {

@@ -108,40 +124,44 @@ Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
   // to create a cache and some other tree is trying to use the same cache.
   // That is allowed, however the crc better match!
   if (server_connection_id_) {
-    if (cache_crc_ != tree_crc) {
+    if (cinfo_.crc() != tree_crc) {
       RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!");
     }
     // Check the state of the server. For non-mappable case where there is a build phase and a fetch phase, we should
     // skip the build phase.
     lck.Unlock();  // GetStat will grab the mutex again. So unlock it to prevent deadlock.
-    CacheClient::ServiceStat stat{};
+    CacheServiceStat stat{};
     RETURN_IF_NOT_OK(GetStat(&stat));
     if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) {
       return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
     }
   } else {
-    cache_crc_ = tree_crc;  // It's really a new cache we're creating so save our crc in the client
-    // Combine the session and crc. This will form our client cache identifier.
-    connection_id_type connection_identification = (static_cast<uint64_t>(session_id_) << 32) | cache_crc_;
+    cinfo_.set_crc(tree_crc);  // It's really a new cache we're creating so save our crc in the client
     // Now execute the cache create request using this identifier and other configs
-    BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone;
+    CreateCacheRequest::CreateCacheFlag createFlag = CreateCacheRequest::CreateCacheFlag::kNone;
     if (spill_) {
-      createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk;
+      createFlag |= CreateCacheRequest::CreateCacheFlag::kSpillToDisk;
     }
     if (generate_id) {
-      createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId;
+      createFlag |= CreateCacheRequest::CreateCacheFlag::kGenerateRowId;
    }
-    CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag);
-    RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-    Status rc = rq.Wait();
+    // Start the comm layer to receive reply
+    RETURN_IF_NOT_OK(comm_->ServiceStart());
+    // Initiate connection
+    auto rq = std::make_shared<CreateCacheRequest>(cinfo_, cache_mem_sz_, createFlag);
+    RETURN_IF_NOT_OK(PushRequest(rq));
+    Status rc = rq->Wait();
    if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) {
-      server_connection_id_ = rq.GetServerConnectionId();
+      std::string cookie;
+      rq->ParseResult(&server_connection_id_, &cookie);
      if (rc.IsOk()) {
        // The 1st guy creating the cache will get a cookie back.
        // But this object may be shared among pipelines and we don't want
        // overwrite it.
-        cookie_ = rq.cookie();
+        cookie_ = cookie;
      }
+      // Attach to shared memory for local client
+      RETURN_IF_NOT_OK(comm_->AttachToSharedMemory(port_, &local_bypass_));
    }
    // We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
    // CacheOp to bypass the build phase.

@@ -152,57 +172,57 @@
 Status CacheClient::PurgeCache() {
   UniqueLock lck(&mux_);
-  PurgeCacheRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  return rq.Wait();
+  auto rq = std::make_shared<PurgeCacheRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  return Status::OK();
 }
 
 Status CacheClient::DestroyCache() {
   UniqueLock lck(&mux_);
-  DestroyCacheRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  return rq.Wait();
+  auto rq = std::make_shared<DestroyCacheRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  return Status::OK();
 }
 
-Status CacheClient::GetStat(ServiceStat *stat) {
+Status CacheClient::GetStat(CacheServiceStat *stat) {
   SharedLock lck(&mux_);
   RETURN_UNEXPECTED_IF_NULL(stat);
-  GetStatRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
-  stat->num_disk_cached = rq.GetNumDiskCached();
-  stat->num_mem_cached = rq.GetNumMemCached();
-  stat->min_row_id = rq.GetMinRowId();
-  stat->max_row_id = rq.GetMaxRowId();
-  stat->cache_service_state = rq.GetState();
+  auto rq = std::make_shared<GetStatRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  rq->GetStat(stat);
   return Status::OK();
 }
 
 Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
   SharedLock lck(&mux_);
-  CacheSchemaRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map));
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
+  auto rq = std::make_shared<CacheSchemaRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(rq->SerializeCacheSchemaRequest(map));
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
   return Status::OK();
 }
 
 Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) {
   SharedLock lck(&mux_);
   RETURN_UNEXPECTED_IF_NULL(map);
-  FetchSchemaRequest rq(server_connection_id_);
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
-  *map = rq.GetColumnMap();
+  auto rq = std::make_shared<FetchSchemaRequest>(server_connection_id_);
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
+  *map = rq->GetColumnMap();
   return Status::OK();
 }
 
 Status CacheClient::BuildPhaseDone() const {
   SharedLock lck(&mux_);
-  BuildPhaseDoneRequest rq(server_connection_id_, cookie());
-  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
-  RETURN_IF_NOT_OK(rq.Wait());
+  auto rq = std::make_shared<BuildPhaseDoneRequest>(server_connection_id_, cookie());
+  RETURN_IF_NOT_OK(PushRequest(rq));
+  RETURN_IF_NOT_OK(rq->Wait());
   return Status::OK();
 }
 
+Status CacheClient::PushRequest(std::shared_ptr<BaseRequest> rq) const { return comm_->HandleRequest(std::move(rq)); }
 }  // namespace dataset
 }  // namespace mindspore
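The identifier arithmetic being deleted from CreateCache above is worth spelling out, since the same session-id/crc pairing now travels inside CacheClientInfo and explains the switch of connection_id_type to uint64_t in constants.h: the 64-bit connection id packed the session id into the high 32 bits and the tree crc into the low 32 bits. A worked sketch with made-up values:

uint32_t session_id = 1234;        // user-assigned session (0x4D2)
uint32_t tree_crc = 0x9ABCDEF0;    // crc computed over the execution tree
uint64_t connection_id = (static_cast<uint64_t>(session_id) << 32) | tree_crc;
// connection_id == 0x000004D29ABCDEF0: same session + same tree ==> same cache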
mindspore/ccsrc/minddata/dataset/engine/cache/cache_client.h
浏览文件 @
8a08d0c3
...
...
@@ -23,9 +23,13 @@
#include <utility>
#include <vector>
#include "minddata/dataset/core/config_manager.h"
#ifdef ENABLE_CACHE
#include "minddata/dataset/engine/cache/cache_grpc_client.h"
#else
#include "minddata/dataset/engine/cache/stub/cache_grpc_client.h"
#endif
#include "minddata/dataset/engine/data_buffer.h"
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/util/lock.h"
namespace
mindspore
{
...
...
@@ -35,18 +39,120 @@ namespace dataset {
/// rows, etc.
class
CacheClient
{
public:
friend
class
CacheMergeOp
;
/// \brief A builder to help creating a CacheClient object
class
Builder
{
public:
Builder
()
:
session_id_
(
0
),
cache_mem_sz_
(
0
),
spill_
(
false
),
port_
(
0
),
num_workers_
(
0
),
prefetch_size_
(
0
)
{
std
::
shared_ptr
<
ConfigManager
>
cfg
=
GlobalContext
::
config_manager
();
hostname_
=
"127.0.0.1"
;
port_
=
50052
;
num_workers_
=
cfg
->
num_parallel_workers
();
prefetch_size_
=
20
;
// rows_per_buf is too small (1 by default).
}
/// Setter function to set the session id
/// \param session_id
/// \return Builder object itself.
Builder
&
SetSessionId
(
session_id_type
session_id
)
{
session_id_
=
session_id
;
return
*
this
;
}
/// Setter function to set the cache memory size
/// \param cache_mem_sz
/// \return Builder object itself
Builder
&
SetCacheMemSz
(
uint64_t
cache_mem_sz
)
{
cache_mem_sz_
=
cache_mem_sz
;
return
*
this
;
}
/// Setter function to spill attribute
/// \param spill
/// Builder object itself
Builder
&
SetSpill
(
bool
spill
)
{
spill_
=
spill
;
return
*
this
;
}
/// Setter function to set rpc hostname
/// \param host
/// \return Builder object itself
Builder
&
SetHostname
(
std
::
string
host
)
{
hostname_
=
std
::
move
(
host
);
return
*
this
;
}
/// Setter function to set tcpip port
/// \param port
/// \return Builder object itself.
Builder
&
SetPort
(
int32_t
port
)
{
port_
=
port
;
return
*
this
;
}
/// Setter function to set number of async rpc workers
/// \param num_workers
/// \return Builder object itself
Builder
&
SetNumWorkers
(
int32_t
num_workers
)
{
num_workers_
=
num_workers
;
return
*
this
;
}
/// Setter function to set prefetch amount for fetching rows from cache server
/// \param prefetch_sz
/// \return Builder object itself
Builder
&
SetPrefetchSize
(
int32_t
prefetch_sz
)
{
prefetch_size_
=
prefetch_sz
;
return
*
this
;
}
/// Getter functions
session_id_type
getSessionId
()
const
{
return
session_id_
;
}
uint64_t
getCacheMemSz
()
const
{
return
cache_mem_sz_
;
}
bool
isSpill
()
const
{
return
spill_
;
}
const
std
::
string
&
getHostname
()
const
{
return
hostname_
;
}
int32_t
getPort
()
const
{
return
port_
;
}
int32_t
getNumWorkers
()
const
{
return
num_workers_
;
}
int32_t
getPrefetchSize
()
const
{
return
prefetch_size_
;
}
Status
SanityCheck
()
{
CHECK_FAIL_RETURN_UNEXPECTED
(
session_id_
>
0
,
"session id must be positive"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
cache_mem_sz_
>=
0
,
"cache memory size must not be negative. (0 implies unlimited"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
num_workers_
>
0
,
"rpc workers must be positive"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
prefetch_size_
>
0
,
"prefetch size must be positive"
);
CHECK_FAIL_RETURN_UNEXPECTED
(
!
hostname_
.
empty
(),
"hostname must not be empty"
);
return
Status
::
OK
();
}
Status
Build
(
std
::
shared_ptr
<
CacheClient
>
*
out
)
{
RETURN_UNEXPECTED_IF_NULL
(
out
);
RETURN_IF_NOT_OK
(
SanityCheck
());
*
out
=
std
::
make_shared
<
CacheClient
>
(
session_id_
,
cache_mem_sz_
,
spill_
,
hostname_
,
port_
,
num_workers_
,
prefetch_size_
);
return
Status
::
OK
();
}
private:
session_id_type
session_id_
;
uint64_t
cache_mem_sz_
;
bool
spill_
;
std
::
string
hostname_
;
int32_t
port_
;
int32_t
num_workers_
;
int32_t
prefetch_size_
;
};
/// \brief Constructor
/// \param session_id A user assigned session id for the current pipeline
/// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
/// \param spill Spill to disk if out of memory
CacheClient
(
uint32_t
session_id
,
uint64_t
cache_mem_sz
,
bool
spill
);
CacheClient
(
session_id_type
session_id
,
uint64_t
cache_mem_sz
,
bool
spill
,
std
::
string
hostname
,
int32_t
port
,
int32_t
num_workers
,
int32_t
prefetch_size
);
/// \brief Destructor
~
CacheClient
()
=
default
;
/// \brief Getter function for returning the current session id
/// \return session id
uint64_t
session_id
()
const
{
return
session_id_
;
}
~
CacheClient
()
{
(
void
)
comm_
->
ServiceStop
();
}
/// \brief Send a TensorRow to the cache server
/// \param[in] row
...
...
@@ -83,14 +189,7 @@ class CacheClient {
/// \brief Get the statistics from a cache.
/// \param[in/out] Pointer to a pre-allocated ServiceStat object
/// \return Status object
struct
ServiceStat
{
int64_t
num_mem_cached
;
int64_t
num_disk_cached
;
row_id_type
min_row_id
;
row_id_type
max_row_id
;
int8_t
cache_service_state
;
};
Status
GetStat
(
ServiceStat
*
);
Status
GetStat
(
CacheServiceStat
*
);
/// \brief Cache the schema at the cache server
/// \param map The unordered map of the schema
...
...
@@ -122,18 +221,45 @@ class CacheClient {
  /// \return Cookie
  std::string cookie() const { return cookie_; }

  /// \brief Send a request async to the server
  /// \param rq BaseRequest
  /// \return Status object
  Status PushRequest(std::shared_ptr<BaseRequest> rq) const;

  /// \brief If the remote server supports local bypass using shared memory
  /// \return boolean value
  bool SupportLocalClient() const { return local_bypass_; }

  /// \brief Return the base memory address if we attach to any shared memory.
  auto SharedMemoryBaseAddr() const { return comm_->SharedMemoryBaseAddr(); }

  /// Getter functions
  session_id_type session_id() const { return cinfo_.session_id(); }
  uint64_t getCacheMemSz() const { return cache_mem_sz_; }
  bool isSpill() const { return spill_; }
  const std::string &getHostname() const { return hostname_; }
  int32_t getPort() const { return port_; }
  int32_t getNumWorkers() const { return num_workers_; }
  int32_t getPrefetchSize() const { return prefetch_size_; }

 private:
  mutable RWLock mux_;
  uint64_t cache_mem_sz_;
  bool spill_;
  // The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow
  // sharing of the cache.
  uint32_t session_id_;
  uint32_t cache_crc_;
  CacheClientInfo cinfo_;
  // The server_connection_id_ is the actual id we use for operations after the cache is built
  connection_id_type server_connection_id_;
  // Some magic cookie returned from the cache server.
  std::string cookie_;
  // Comm layer
  bool local_bypass_;
  std::string hostname_;
  int32_t port_;
  int32_t num_workers_;
  int32_t prefetch_size_;
  mutable std::shared_ptr<CacheClientGreeter> comm_;
};
}  // namespace dataset
}  // namespace mindspore
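The PushRequest/Wait pair above is the calling convention for every request type in this patch. A sketch of the shape (AllocateBlock is a hypothetical wrapper; the request class and its GetAddr accessor appear in cache_request.h below):

// The caller holds a shared_ptr so the gRPC worker thread and the waiting
// thread never race on freed memory.
Status AllocateBlock(const CacheClient &cc, connection_id_type conn_id, int64_t sz, int64_t *addr) {
  auto rq = std::make_shared<AllocateSharedBlockRequest>(conn_id, sz);
  RETURN_IF_NOT_OK(cc.PushRequest(rq));  // enqueue asynchronously
  RETURN_IF_NOT_OK(rq->Wait());          // block until the reply comes back
  *addr = rq->GetAddr();
  return Status::OK();
}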
...
...
mindspore/ccsrc/minddata/dataset/engine/cache/cache_common.h
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_

/// \note This header file contains common header files and some inlines used by
/// both client and server side codes. Do not put code that is not common here.
/// There are client and server specific header files.

// On platform like Windows, we may support only tcp/ip clients
#if !defined(_WIN32) && !defined(_WIN64)
#define CACHE_LOCAL_CLIENT 1
#endif

#ifdef CACHE_LOCAL_CLIENT
#include <sys/types.h>
#include <sys/ipc.h>
#include <sys/shm.h>
#else
typedef int key_t;
#endif
#ifdef ENABLE_CACHE
#include <grpcpp/grpcpp.h>
#endif
#include <string>
#ifdef ENABLE_CACHE
#include "proto/cache_grpc.grpc.pb.h"
#endif
#include "proto/cache_grpc.pb.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"

namespace mindspore {
namespace dataset {
/// \brief CacheRow and BatchFetch requests will switch to use shared memory method (if supported
/// on the platform) when the amount of bytes sent is greater than the following number.
/// For too small amount, we won't get any benefit using shared memory method because we need
/// two rpc requests to use shared memory method.
constexpr static int32_t kLocalByPassThreshold = 64 * 1024;
/// \brief A flag used by the BatchFetch request (client side) if it can support local bypass
constexpr static uint32_t kLocalClientSupport = 1;
/// \brief A flag used by CacheRow request (client side) and BatchFetch (server side) reply to indicate if the data is
/// inline in the protobuf. This also implies kLocalClientSupport is also true.
constexpr static uint32_t kDataIsInSharedMemory = 2;
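These two constants are bit flags, not enum values; CacheRow composes them as shown below. The BitSet/BitTest definitions here are plausible stand-ins for the helpers the requests actually use elsewhere in the tree:

inline void BitSet(uint32_t *bits, uint32_t flag) { *bits |= flag; }
inline bool BitTest(uint32_t bits, uint32_t flag) { return (bits & flag) == flag; }

// Mirrors the flag logic of CacheRowRequest::SerializeCacheRowRequest further down.
inline uint32_t MakeCacheRowFlag(bool local_bypass, int64_t payload_sz) {
  uint32_t flag = 0;
  if (local_bypass) {
    BitSet(&flag, kLocalClientSupport);
    if (payload_sz >= kLocalByPassThreshold) {
      BitSet(&flag, kDataIsInSharedMemory);  // payload big enough to justify the extra rpc round trip
    }
  }
  return flag;
}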
/// \brief Convert a Status object into a protobuf
/// \param rc[in] Status object
/// \param reply[in/out] pointer to pre-allocated protobuf object
inline void Status2CacheReply(const Status &rc, CacheReply *reply) {
  reply->set_rc(static_cast<google::int32>(rc.get_code()));
  reply->set_msg(rc.ToString());
}

/// \brief Generate the unix socket file we use on both client/server side given a tcp/ip port number
/// \param port
/// \return unix socket url
inline std::string PortToUnixSocketPath(int port) { return "/tmp/cache_server_p" + std::to_string(port); }
/// \brief Generate a shared memory key using the tcp/ip port.
/// \note It must be called after the cache server generates the unix socket or ftok will fail.
/// \note Caller must check the return value. -1 means ftok failed.
/// \param[in] port
/// \param[out] err. If not null and ftok fails, this will contain the value of errno
/// \return key
inline key_t PortToFtok(int port, int *err) {
  key_t shmkey = -1;
#ifdef CACHE_LOCAL_CLIENT
  const std::string unix_path = PortToUnixSocketPath(port);
  shmkey = ftok(unix_path.data(), 'a');
  if (err != nullptr && shmkey == (key_t)-1) {
    *err = errno;
  }
#endif
  return shmkey;
}
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_COMMON_H_
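A standalone Linux sketch of the attach sequence built on these helpers (compile with g++; the socket file /tmp/cache_server_p<port> must already exist because the server creates it before any client calls ftok):

#include <sys/ipc.h>
#include <sys/shm.h>
#include <cstdio>
#include <string>

int AttachDemo(int port) {
  std::string unix_path = "/tmp/cache_server_p" + std::to_string(port);
  key_t key = ftok(unix_path.data(), 'a');
  if (key == (key_t)-1) { perror("ftok"); return -1; }
  int shm_id = shmget(key, 0, 0);          // size 0: attach to an existing segment only
  if (shm_id == -1) { perror("shmget"); return -1; }
  void *base = shmat(shm_id, nullptr, 0);  // let the kernel pick the address
  if (base == reinterpret_cast<void *>(-1)) { perror("shmat"); return -1; }
  // ... every address the server hands out is an offset relative to `base` ...
  shmdt(base);
  return 0;
}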
mindspore/ccsrc/minddata/dataset/engine/cache/cache_fbb.cc
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_fbb.h"
namespace
mindspore
{
namespace
dataset
{
/// A private function used by SerializeTensorRowHeader to serialize each column in a tensor
/// \note Not to be called by outside world
/// \return Status object
Status
SerializeOneTensorMeta
(
const
std
::
shared_ptr
<
flatbuffers
::
FlatBufferBuilder
>
&
fbb
,
const
std
::
shared_ptr
<
Tensor
>
&
ts_ptr
,
flatbuffers
::
Offset
<
TensorMetaMsg
>
*
out_off
)
{
RETURN_UNEXPECTED_IF_NULL
(
out_off
);
const
Tensor
*
ts
=
ts_ptr
.
get
();
auto
shape_off
=
fbb
->
CreateVector
(
ts
->
shape
().
AsVector
());
const
auto
ptr
=
ts
->
GetBuffer
();
if
(
ptr
==
nullptr
)
{
RETURN_STATUS_UNEXPECTED
(
"Tensor buffer is null"
);
}
auto
src
=
ts
->
type
().
value
();
TensorType
dest
;
#define CASE(t) \
case DataType::t: \
dest = TensorType::TensorType_##t; \
break
// Map the type to fill in the flat buffer.
switch
(
src
)
{
CASE
(
DE_BOOL
);
CASE
(
DE_INT8
);
CASE
(
DE_UINT8
);
CASE
(
DE_INT16
);
CASE
(
DE_UINT16
);
CASE
(
DE_INT32
);
CASE
(
DE_UINT32
);
CASE
(
DE_INT64
);
CASE
(
DE_UINT64
);
CASE
(
DE_FLOAT16
);
CASE
(
DE_FLOAT32
);
CASE
(
DE_FLOAT64
);
CASE
(
DE_STRING
);
default:
MS_LOG
(
ERROR
)
<<
"Unknown tensor. Dumping content:
\n
"
<<
*
ts
;
RETURN_STATUS_UNEXPECTED
(
"Unknown type"
);
}
#undef CASE
TensorMetaMsgBuilder
ts_builder
(
*
fbb
);
ts_builder
.
add_dims
(
shape_off
);
ts_builder
.
add_type
(
dest
);
auto
ts_off
=
ts_builder
.
Finish
();
*
out_off
=
ts_off
;
return
Status
::
OK
();
}
Status
SerializeTensorRowHeader
(
const
TensorRow
&
row
,
std
::
shared_ptr
<
flatbuffers
::
FlatBufferBuilder
>
*
out_fbb
)
{
RETURN_UNEXPECTED_IF_NULL
(
out_fbb
);
auto
fbb
=
std
::
make_shared
<
flatbuffers
::
FlatBufferBuilder
>
();
try
{
fbb
=
std
::
make_shared
<
flatbuffers
::
FlatBufferBuilder
>
();
std
::
vector
<
flatbuffers
::
Offset
<
TensorMetaMsg
>>
v
;
std
::
vector
<
int64_t
>
tensor_sz
;
v
.
reserve
(
row
.
size
());
tensor_sz
.
reserve
(
row
.
size
());
// We will go through each column in the row.
for
(
const
std
::
shared_ptr
<
Tensor
>
&
ts_ptr
:
row
)
{
flatbuffers
::
Offset
<
TensorMetaMsg
>
ts_off
;
RETURN_IF_NOT_OK
(
SerializeOneTensorMeta
(
fbb
,
ts_ptr
,
&
ts_off
));
v
.
push_back
(
ts_off
);
tensor_sz
.
push_back
(
ts_ptr
->
SizeInBytes
());
}
auto
column_off
=
fbb
->
CreateVector
(
v
);
auto
data_sz_off
=
fbb
->
CreateVector
(
tensor_sz
);
TensorRowHeaderMsgBuilder
row_builder
(
*
fbb
);
row_builder
.
add_column
(
column_off
);
row_builder
.
add_data_sz
(
data_sz_off
);
// Pass the row_id even if it may not be known.
row_builder
.
add_row_id
(
row
.
getId
());
row_builder
.
add_size_of_this
(
-
1
);
// fill in later after we call Finish.
auto
out
=
row_builder
.
Finish
();
fbb
->
Finish
(
out
);
// Now go back to fill in size_of_this in the flat buffer.
auto
msg
=
GetMutableTensorRowHeaderMsg
(
fbb
->
GetBufferPointer
());
auto
success
=
msg
->
mutate_size_of_this
(
fbb
->
GetSize
());
if
(
!
success
)
{
RETURN_STATUS_UNEXPECTED
(
"Unable to set size_of_this"
);
}
(
*
out_fbb
)
=
std
::
move
(
fbb
);
return
Status
::
OK
();
}
catch
(
const
std
::
bad_alloc
&
e
)
{
return
Status
(
StatusCode
::
kOutOfMemory
,
__LINE__
,
__FILE__
);
}
}
Status
RestoreOneTensor
(
const
TensorMetaMsg
*
col_ts
,
const
ReadableSlice
&
data
,
std
::
shared_ptr
<
Tensor
>
*
out
)
{
RETURN_UNEXPECTED_IF_NULL
(
col_ts
);
auto
shape_in
=
col_ts
->
dims
();
auto
type_in
=
col_ts
->
type
();
std
::
vector
<
dsize_t
>
v
;
v
.
reserve
(
shape_in
->
size
());
v
.
assign
(
shape_in
->
begin
(),
shape_in
->
end
());
TensorShape
shape
(
v
);
DataType
::
Type
dest
=
DataType
::
DE_UNKNOWN
;
#define CASE(t) \
case TensorType_##t: \
dest = DataType::Type::t; \
break
switch
(
type_in
)
{
CASE
(
DE_BOOL
);
CASE
(
DE_INT8
);
CASE
(
DE_UINT8
);
CASE
(
DE_INT16
);
CASE
(
DE_UINT16
);
CASE
(
DE_INT32
);
CASE
(
DE_UINT32
);
CASE
(
DE_INT64
);
CASE
(
DE_UINT64
);
CASE
(
DE_FLOAT16
);
CASE
(
DE_FLOAT32
);
CASE
(
DE_FLOAT64
);
CASE
(
DE_STRING
);
}
#undef CASE
DataType
type
(
dest
);
std
::
shared_ptr
<
Tensor
>
ts
;
RETURN_IF_NOT_OK
(
Tensor
::
CreateFromMemory
(
shape
,
type
,
static_cast
<
const
unsigned
char
*>
(
data
.
GetPointer
()),
data
.
GetSize
(),
&
ts
));
// Next we restore the real data which can be embedded or stored separately.
if
(
ts
->
SizeInBytes
()
!=
data
.
GetSize
())
{
MS_LOG
(
ERROR
)
<<
"Unexpected length. Read "
<<
data
.
GetSize
()
<<
". Expected "
<<
ts
->
SizeInBytes
()
<<
".
\n
"
<<
"Dumping tensor
\n
"
<<
*
ts
<<
"
\n
"
;
RETURN_STATUS_UNEXPECTED
(
"Length mismatch. See log file for details."
);
}
*
out
=
std
::
move
(
ts
);
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
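A round-trip sketch using only the two helpers in this file. Two assumptions are made here: TensorRow supports operator[], and the generated accessor GetMutableTensorRowHeaderMsg is reused for reading, just as RestoreRows does later in this patch:

Status RoundTrip(const TensorRow &in, TensorRow *out) {
  std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
  RETURN_IF_NOT_OK(SerializeTensorRowHeader(in, &fbb));
  auto *msg = GetMutableTensorRowHeaderMsg(fbb->GetBufferPointer());
  for (uint32_t k = 0; k < msg->column()->size(); ++k) {
    // Each column's raw bytes travel separately; only the metadata lives in the flatbuffer.
    ReadableSlice data(in[k]->GetBuffer(), in[k]->SizeInBytes());
    std::shared_ptr<Tensor> ts;
    RETURN_IF_NOT_OK(RestoreOneTensor(msg->column()->Get(k), data, &ts));
    out->push_back(ts);
  }
  return Status::OK();
}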
mindspore/ccsrc/minddata/dataset/engine/cache/cache_fbb.h
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_

/// This header contains some serialize and deserialize functions for tensor row using
/// Google Flatbuffer

#include <memory>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/util/slice.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
/// \brief Function to serialize TensorRow header used by CacheRowRequest
/// \param row TensorRow
/// \param fbb [in/out] fbb that contains the serialized data
/// \return Status object
Status SerializeTensorRowHeader(const TensorRow &row, std::shared_ptr<flatbuffers::FlatBufferBuilder> *fbb);

/// \brief A function used by BatchFetchRequest to deserialize a flat buffer back to a tensor row.
/// \param col_ts A serialized version of Tensor meta data
/// \param data Tensor data wrapped in a slice
/// \param out Tensor
/// \return Status object
Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_FBB_H_
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc.proto
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto3";
package mindspore.dataset;
option cc_enable_arenas = true;

// The session_id and crc work together to uniquely identify this particular cache and allow
// sharing of the cache.
message CacheClientInfo {
  uint32 session_id = 1;
  uint32 crc = 2;
}

message CacheRequest {
  // Type of rpc request
  int32 type = 1;
  // Extra optional flag used by individual request if needed
  uint32 flag = 2;
  oneof connect_info {
    // The server_connection_id is the actual id we use for operations after the cache is built
    int64 connection_id = 3;
    // But for some requests like CreateCache we have to use the session id and crc to connect to the server.
    CacheClientInfo connection_info = 4;
  }
  // Everything else is just vector of buffers
  repeated bytes buf_data = 5;
}

message CacheReply {
  int32 rc = 1;
  string msg = 2;
  // Extra optional flag used by individual request if needed
  uint32 flag = 3;
  // What the server sends back is a plain buffer
  bytes result = 4;
}

service CacheServerGreeter {
  rpc CacheServerRequest(CacheRequest) returns (CacheReply) {}
}
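The oneof means a request carries exactly one identity. In the generated C++ it looks like this (the literal type value is illustrative; the real one comes from BaseRequest::RequestType):

CacheRequest MakeCreateCacheRq(uint32_t session_id, uint32_t crc) {
  CacheRequest rq;
  rq.set_type(1);  // illustrative only
  CacheClientInfo cinfo;
  cinfo.set_session_id(session_id);
  cinfo.set_crc(crc);
  // Setting connection_info clears connection_id and vice versa; that is the oneof contract.
  *rq.mutable_connection_info() = cinfo;
  return rq;
}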
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.cc
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_grpc_client.h"
#include <chrono>
namespace
mindspore
{
namespace
dataset
{
Status
CacheClientRequestTag
::
MakeCall
(
CacheServerGreeter
::
Stub
*
stub
,
grpc
::
CompletionQueue
*
cq
,
std
::
unique_ptr
<
CacheClientRequestTag
>
&&
tag
)
{
// If there is anything extra we need to do before we send.
RETURN_IF_NOT_OK
(
tag
->
base_rq_
->
Prepare
());
// One minute timeout
auto
deadline
=
std
::
chrono
::
system_clock
::
now
()
+
std
::
chrono
::
seconds
(
60
);
tag
->
ctx_
.
set_deadline
(
deadline
);
tag
->
rpc_
=
stub
->
PrepareAsyncCacheServerRequest
(
&
tag
->
ctx_
,
tag
->
base_rq_
->
rq_
,
cq
);
tag
->
rpc_
->
StartCall
();
// Last step is we release the ownership and transfer it to the completion queue.
// The memory will be released by WorkerEntry or by the destructor when we drain the queue
auto
ccReqTag
=
tag
.
release
();
ccReqTag
->
rpc_
->
Finish
(
&
ccReqTag
->
base_rq_
->
reply_
,
&
ccReqTag
->
rc_
,
ccReqTag
);
// inject this object into the completion queue
return
Status
::
OK
();
}
CacheClientGreeter
::~
CacheClientGreeter
()
{
(
void
)
ServiceStop
();
// Detach from shared memory if any
if
(
shmat_addr_
!=
nullptr
)
{
shmdt
(
shmat_addr_
);
shmat_addr_
=
nullptr
;
}
}
CacheClientGreeter
::
CacheClientGreeter
(
const
std
::
string
&
hostname
,
int32_t
port
,
int32_t
num_workers
)
:
num_workers_
(
num_workers
),
shm_key_
(
-
1
),
shm_id_
(
-
1
),
shmat_addr_
(
nullptr
)
{
grpc
::
ChannelArguments
args
;
// We need to bump up the message size to unlimited. The default receiving
// message limit is 4MB which is not big enough.
args
.
SetMaxReceiveMessageSize
(
-
1
);
#if CACHE_LOCAL_CLIENT
// Try connect locally to the unix_socket first as the first preference
// Need to resolve hostname to ip address rather than to do a string compare
if
(
hostname
==
"127.0.0.1"
)
{
std
::
string
target
=
"unix://"
+
PortToUnixSocketPath
(
port
);
channel_
=
grpc
::
CreateCustomChannel
(
target
,
grpc
::
InsecureChannelCredentials
(),
args
);
}
else
{
#endif
std
::
string
target
=
hostname
+
":"
+
std
::
to_string
(
port
);
channel_
=
grpc
::
CreateCustomChannel
(
target
,
grpc
::
InsecureChannelCredentials
(),
args
);
#if CACHE_LOCAL_CLIENT
}
#endif
stub_
=
CacheServerGreeter
::
NewStub
(
channel_
);
}
Status
CacheClientGreeter
::
AttachToSharedMemory
(
int32_t
port
,
bool
*
local_bypass
)
{
*
local_bypass
=
false
;
#if CACHE_LOCAL_CLIENT
int
err
;
shm_key_
=
PortToFtok
(
port
,
&
err
);
if
(
shm_key_
==
(
key_t
)
-
1
)
{
std
::
string
errMsg
=
"Ftok failed with errno "
+
std
::
to_string
(
err
);
RETURN_STATUS_UNEXPECTED
(
errMsg
);
}
// Attach to the shared memory
shm_id_
=
shmget
(
shm_key_
,
0
,
0
);
if
(
shm_id_
==
-
1
)
{
RETURN_STATUS_UNEXPECTED
(
"Shmget failed. Errno "
+
std
::
to_string
(
errno
));
}
shmat_addr_
=
shmat
(
shm_id_
,
nullptr
,
0
);
if
(
shmat_addr_
==
reinterpret_cast
<
void
*>
(
-
1
))
{
RETURN_STATUS_UNEXPECTED
(
"Shared memory attach failed. Errno "
+
std
::
to_string
(
errno
));
}
*
local_bypass
=
true
;
#endif
return
Status
::
OK
();
}
Status
CacheClientGreeter
::
DoServiceStart
()
{
RETURN_IF_NOT_OK
(
vg_
.
ServiceStart
());
RETURN_IF_NOT_OK
(
DispatchWorkers
(
num_workers_
));
return
Status
::
OK
();
}
Status
CacheClientGreeter
::
DoServiceStop
()
{
// Shutdown the queue. We don't accept any more new incomers.
cq_
.
Shutdown
();
// Shutdown the TaskGroup.
vg_
.
interrupt_all
();
vg_
.
join_all
(
Task
::
WaitFlag
::
kNonBlocking
);
// Drain the queue
bool
success
;
void
*
tag
;
while
(
cq_
.
Next
(
&
tag
,
&
success
))
{
auto
r
=
reinterpret_cast
<
CacheClientRequestTag
*>
(
tag
);
delete
r
;
}
return
Status
::
OK
();
}
Status
CacheClientGreeter
::
HandleRequest
(
std
::
shared_ptr
<
BaseRequest
>
rq
)
{
auto
tag
=
std
::
make_unique
<
CacheClientRequestTag
>
(
std
::
move
(
rq
));
return
tag
->
MakeCall
(
stub_
.
get
(),
&
cq_
,
std
::
move
(
tag
));
}
Status
CacheClientGreeter
::
WorkerEntry
()
{
TaskManager
::
FindMe
()
->
Post
();
do
{
bool
success
;
void
*
tag
;
auto
deadline
=
std
::
chrono
::
system_clock
::
now
()
+
std
::
chrono
::
seconds
(
1
);
// Set a timeout for one second. Check for interrupt if we need to do early exit.
auto
r
=
cq_
.
AsyncNext
(
&
tag
,
&
success
,
deadline
);
if
(
r
==
grpc_impl
::
CompletionQueue
::
NextStatus
::
GOT_EVENT
)
{
auto
rq
=
reinterpret_cast
<
CacheClientRequestTag
*>
(
tag
);
if
(
success
)
{
auto
&
rc
=
rq
->
rc_
;
if
(
!
rc
.
ok
())
{
auto
error_code
=
rq
->
rc_
.
error_code
();
std
::
string
errMsg
=
rq
->
rc_
.
error_message
()
+
". GRPC Code "
+
std
::
to_string
(
error_code
);
Status
remote_rc
=
Status
(
StatusCode
::
kUnexpectedError
,
__LINE__
,
__FILE__
,
errMsg
);
Status2CacheReply
(
remote_rc
,
&
rq
->
base_rq_
->
reply_
);
}
// Notify the waiting thread.
rq
->
Notify
();
}
// We can now free the memory
delete
rq
;
}
else
if
(
r
==
grpc_impl
::
CompletionQueue
::
NextStatus
::
TIMEOUT
)
{
// If we are interrupted, exit. Otherwise wait again.
RETURN_IF_INTERRUPTED
();
}
else
{
// Queue is drained.
break
;
}
}
while
(
true
);
return
Status
::
OK
();
}
Status
CacheClientGreeter
::
DispatchWorkers
(
int32_t
num_workers
)
{
auto
f
=
std
::
bind
(
&
CacheClientGreeter
::
WorkerEntry
,
this
);
for
(
auto
i
=
0
;
i
<
num_workers
;
++
i
)
{
RETURN_IF_NOT_OK
(
vg_
.
CreateAsyncTask
(
"Async reply"
,
f
));
}
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
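The ownership hand-off in MakeCall()/WorkerEntry() is the classic async completion-queue tag pattern. A gRPC-free analogue that compiles standalone (all names here are illustrative):

#include <future>
#include <iostream>
#include <memory>
#include <queue>
#include <thread>

struct Tag {
  std::promise<int> done;  // stands in for the request's WaitPost
};

int main() {
  std::queue<Tag *> cq;  // stands in for grpc::CompletionQueue (single-producer demo)
  auto tag = std::make_unique<Tag>();
  std::future<int> result = tag->done.get_future();
  cq.push(tag.release());  // ownership moves into the queue, as in MakeCall()

  std::thread worker([&cq] {  // stands in for WorkerEntry()
    Tag *t = cq.front();
    cq.pop();
    t->done.set_value(42);  // Notify() the waiter first...
    delete t;               // ...then free the tag
  });
  std::cout << result.get() << std::endl;  // the caller's Wait()
  worker.join();
  return 0;
}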
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_client.h
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_

#include <memory>
#include <string>
#include <utility>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/util/service.h"
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
/// \brief A client view of gRPC request
/// Like the class CacheServerRequest, this is used as a tag to inject into the gRPC
/// completion queue. The thread that makes the rpc request will wait on a wait post
/// area for the reply to come back. Since this tag will be deleted from memory, we
/// need to work on a shared pointer of the BaseRequest such that its use count is at
/// least two. Otherwise either thread will be referencing stale memory.
/// \see CacheServerRequest
class CacheClientRequestTag {
 public:
  friend class CacheClientGreeter;
  explicit CacheClientRequestTag(std::shared_ptr<BaseRequest> rq) : base_rq_(std::move(rq)) {}
  ~CacheClientRequestTag() = default;

  /// \brief Make a RPC call
  /// \param stub from CacheClientGreeter
  /// \param cq from CacheClientGreeter
  /// \return Status object
  static Status MakeCall(CacheServerGreeter::Stub *stub, grpc::CompletionQueue *cq,
                         std::unique_ptr<CacheClientRequestTag> &&tag);

  /// \brief Notify the client that a result has come back from the server
  void Notify() { base_rq_->wp_.Set(); }

 private:
  std::shared_ptr<BaseRequest> base_rq_;
  grpc::Status rc_;
  grpc::ClientContext ctx_;
  std::unique_ptr<grpc::ClientAsyncResponseReader<CacheReply>> rpc_;
};

/// \brief A GRPC layer to convert BaseRequest into protobuf and send to the cache server using gRPC
/// \see BaseRequest
class CacheClientGreeter : public Service {
  friend class CacheClient;

 public:
  explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers);
  ~CacheClientGreeter();

  /// Override base Service class
  Status DoServiceStart() override;
  Status DoServiceStop() override;

  /// \brief Send the request to the server
  /// \return Status object
  Status HandleRequest(std::shared_ptr<BaseRequest> rq);

  /// \brief A handful of threads will be handling async reply from the server
  /// \return
  Status WorkerEntry();

  /// \brief Kick off threads to receive reply from the server
  Status DispatchWorkers(int32_t num_workers);

  /// \brief Attach to shared memory for local client
  /// \note Called after we have established a connection.
  /// \return Status object.
  Status AttachToSharedMemory(int32_t port, bool *local_bypass);

  /// \brief This returns where we attach to the shared memory.
  /// \return Base address of the shared memory.
  const void *SharedMemoryBaseAddr() const { return shmat_addr_; }

 private:
  std::shared_ptr<grpc::Channel> channel_;
  std::unique_ptr<CacheServerGreeter::Stub> stub_;
  grpc::CompletionQueue cq_;
  TaskGroup vg_;
  int32_t num_workers_;
  key_t shm_key_;
  int32_t shm_id_;
  void *shmat_addr_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_CLIENT_H_
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.cc
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_grpc_server.h"
#include <limits>
#include "minddata/dataset/engine/cache/cache_server.h"
#include "minddata/dataset/util/path.h"
#include "utils/log_adapter.h"
namespace
mindspore
{
namespace
dataset
{
CacheServerGreeterImpl
::
CacheServerGreeterImpl
(
int32_t
port
,
int32_t
shared_memory_sz_in_gb
)
:
port_
(
port
),
shm_pool_sz_in_gb_
(
shared_memory_sz_in_gb
)
{
// Setup a path for unix socket.
unix_socket_
=
PortToUnixSocketPath
(
port
);
// We can't generate the ftok key yet until the unix_socket_ is created
}
void
CacheServerGreeterImpl
::
Shutdown
()
{
if
(
server_
)
{
auto
deadline
=
std
::
chrono
::
system_clock
::
now
()
+
std
::
chrono
::
seconds
(
1
);
server_
->
Shutdown
(
deadline
);
}
// Always shutdown the completion queue after the server.
if
(
cq_
)
{
cq_
->
Shutdown
();
// We need to drain the queue. All the tag is coming from
// the Services pool which will be shutdown as well. So we
// ignore the tag.
void
*
tag
;
bool
success
;
while
(
cq_
->
Next
(
&
tag
,
&
success
))
{
}
}
}
CacheServerGreeterImpl
::~
CacheServerGreeterImpl
()
{
Shutdown
();
}
Status
CacheServerGreeterImpl
::
IpcResourceCleanup
()
{
#if CACHE_LOCAL_CLIENT
int
err
;
auto
shm_key
=
PortToFtok
(
port_
,
&
err
);
// We are expecting the unix path doesn't exist.
if
(
shm_key
==
(
key_t
)
-
1
)
{
return
Status
::
OK
();
}
// Attach to the shared memory
auto
shm_id
=
shmget
(
shm_key
,
0
,
0
);
if
(
shm_id
==
-
1
)
{
return
Status
::
OK
();
}
struct
shmid_ds
ds
{};
auto
inx
=
shmctl
(
shm_id
,
IPC_STAT
,
&
ds
);
if
(
inx
==
-
1
)
{
std
::
string
errMsg
=
"Unable to query shared memory with id "
+
std
::
to_string
(
shm_id
);
errMsg
+=
"
\n
Plesae remove it manually using ipcrm -m command"
;
RETURN_STATUS_UNEXPECTED
(
errMsg
);
}
if
(
ds
.
shm_nattch
==
0
)
{
// Stale shared memory from last time.
// Remove both the memory and the socket path
inx
=
shmctl
(
shm_id
,
IPC_RMID
,
nullptr
);
if
(
inx
==
-
1
)
{
std
::
string
errMsg
=
"Unable to remove shared memory with id "
+
std
::
to_string
(
shm_id
);
errMsg
+=
". Errno :"
+
std
::
to_string
(
errno
);
errMsg
+=
"
\n
Plesae remove it manually using ipcrm -m command"
;
RETURN_STATUS_UNEXPECTED
(
errMsg
);
}
Path
p
(
unix_socket_
);
(
void
)
p
.
Remove
();
}
else
{
// Server is already up.
MS_LOG
(
ERROR
)
<<
"Cache server is already up and running"
;
// We return a duplicate error. The main() will intercept
// and output a proper message
return
Status
(
StatusCode
::
kDuplicateKey
);
}
#endif
return
Status
::
OK
();
}
Status
CacheServerGreeterImpl
::
Run
()
{
// To listen on all interfaces, use 0.0.0.0
// Use 127.0.0.1 if just locally on the same machine.
std
::
string
host
(
"0.0.0.0"
);
// listen on all interfaces.
std
::
string
server_address
=
host
+
":"
+
std
::
to_string
(
port_
);
grpc
::
ServerBuilder
builder
;
// Default message size for gRPC is 4MB. Increase it to 2g-1
builder
.
SetMaxReceiveMessageSize
(
std
::
numeric_limits
<
int32_t
>::
max
());
int
port_tcpip
=
0
;
#if CACHE_LOCAL_CLIENT
int
port_local
=
0
;
// Check if we need to do clean up on the shared memory if the server
// came down unexpectedly like SEGV
RETURN_IF_NOT_OK
(
IpcResourceCleanup
());
// We also optimize on local clients on the same machine using unix socket
builder
.
AddListeningPort
(
"unix://"
+
unix_socket_
,
grpc
::
InsecureServerCredentials
(),
&
port_local
);
#endif
builder
.
AddListeningPort
(
server_address
,
grpc
::
InsecureServerCredentials
(),
&
port_tcpip
);
builder
.
RegisterService
(
&
svc_
);
cq_
=
builder
.
AddCompletionQueue
();
server_
=
builder
.
BuildAndStart
();
if
(
server_
)
{
MS_LOG
(
INFO
)
<<
"Server listening on "
<<
server_address
;
#if CACHE_LOCAL_CLIENT
RETURN_IF_NOT_OK
(
CachedSharedMemoryArena
::
CreateArena
(
&
shm_pool_
,
port_
,
shm_pool_sz_in_gb_
));
MS_LOG
(
INFO
)
<<
"Creation of local socket and shared memory successful"
;
#endif
}
else
{
std
::
string
errMsg
=
"Fail to start server. "
;
if
(
port_tcpip
!=
port_
)
{
errMsg
+=
"Unable to bind to tcpip port "
+
std
::
to_string
(
port_
)
+
"."
;
}
#if CACHE_LOCAL_CLIENT
if
(
port_local
==
0
)
{
errMsg
+=
" Unable to create unix socket "
+
unix_socket_
+
"."
;
}
#endif
RETURN_STATUS_UNEXPECTED
(
errMsg
);
}
return
Status
::
OK
();
}
Status
CacheServerGreeterImpl
::
HandleRequest
(
int32_t
worker_id
)
{
bool
success
;
void
*
tag
;
// We loop through the grpc queue. Each connection if successful
// will come back with our own tag which is an instance of CacheServerRequest
// and we simply call its functor. But first we need to create these instances
// and inject them into the grpc queue.
CacheServerRequest
*
p
;
// Get a free tag from my free list.
RETURN_IF_NOT_OK
(
CacheServer
::
GetFreeRequestTag
(
worker_id
,
&
p
));
RETURN_IF_NOT_OK
((
*
p
)(
&
svc_
,
cq_
.
get
()));
do
{
auto
deadline
=
std
::
chrono
::
system_clock
::
now
()
+
std
::
chrono
::
seconds
(
1
);
// Set a timeout for one second. Check for interrupt if we need to do early exit.
auto
r
=
cq_
->
AsyncNext
(
&
tag
,
&
success
,
deadline
);
if
(
r
==
grpc_impl
::
CompletionQueue
::
NextStatus
::
GOT_EVENT
)
{
if
(
success
)
{
auto
rq
=
static_cast
<
CacheServerRequest
*>
(
tag
);
RETURN_IF_NOT_OK
((
*
rq
)(
&
svc_
,
cq_
.
get
()));
}
}
else
if
(
r
==
grpc_impl
::
CompletionQueue
::
NextStatus
::
TIMEOUT
)
{
// If we are interrupted, exit. Otherwise wait again.
RETURN_IF_INTERRUPTED
();
}
else
{
// Queue is drained.
break
;
}
}
while
(
true
);
return
Status
::
OK
();
}
Status
CacheServerRequest
::
operator
()(
CacheServerGreeter
::
AsyncService
*
svc
,
grpc
::
ServerCompletionQueue
*
cq
)
{
auto
myQID
=
getQid
();
if
(
st_
==
STATE
::
CREATE
)
{
st_
=
STATE
::
PROCESS
;
svc
->
RequestCacheServerRequest
(
&
ctx_
,
&
rq_
,
&
responder_
,
cq
,
cq
,
this
);
}
else
if
(
st_
==
STATE
::
PROCESS
)
{
// Get a new tag and handle the next request before we serve the current request.
// The tag will be recycled when its state is changed to FINISH
CacheServerRequest
*
next_rq
;
RETURN_IF_NOT_OK
(
CacheServer
::
GetFreeRequestTag
(
myQID
,
&
next_rq
));
RETURN_IF_NOT_OK
((
*
next_rq
)(
svc
,
cq
));
// Now we continue with the current request.
// First thing we need to extract the type from the incoming request.
// When this object was first created (i.e. STATE::CREATE), we set the type to UNKNOWN.
type_
=
static_cast
<
RequestType
>
(
rq_
.
type
());
// Now we pass the address of this instance to CacheServer's main loop.
MS_LOG
(
DEBUG
)
<<
"Handle request "
<<
*
this
;
auto
&
cs
=
CacheServer
::
GetInstance
();
RETURN_IF_NOT_OK
(
cs
.
PushRequest
(
myQID
,
this
));
}
else
if
(
st_
==
STATE
::
FINISH
)
{
MS_LOG
(
DEBUG
)
<<
*
this
<<
" Finished."
;
// Return back to the free list.
RETURN_IF_NOT_OK
(
CacheServer
::
ReturnRequestTag
(
this
));
}
return
Status
::
OK
();
}
void
CacheServerRequest
::
Print
(
std
::
ostream
&
out
)
const
{
if
(
rq_
.
has_connection_info
())
{
out
<<
"Session Id: "
<<
rq_
.
connection_info
().
session_id
()
<<
" CRC: "
<<
rq_
.
connection_info
().
crc
();
}
else
{
out
<<
"Connection Id: "
<<
rq_
.
connection_id
();
}
out
<<
" "
;
BaseRequest
::
Print
(
out
);
}
}
// namespace dataset
}
// namespace mindspore
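The shm_nattch == 0 test in IpcResourceCleanup() is worth seeing in isolation; this standalone probe applies the same rule (zero attachments means a crashed server left the segment behind):

#include <sys/ipc.h>
#include <sys/shm.h>
#include <cstdio>

int ProbeSegment(key_t key) {
  int shm_id = shmget(key, 0, 0);
  if (shm_id == -1) return 0;  // nothing to clean up
  shmid_ds ds{};
  if (shmctl(shm_id, IPC_STAT, &ds) == -1) { perror("shmctl"); return -1; }
  if (ds.shm_nattch == 0) {
    return shmctl(shm_id, IPC_RMID, nullptr);  // stale segment: safe to remove
  }
  std::printf("segment busy: %lu attachment(s)\n", static_cast<unsigned long>(ds.shm_nattch));
  return 0;
}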
mindspore/ccsrc/minddata/dataset/engine/cache/cache_grpc_server.h
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_arena.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
/// \brief Server side view of BaseRequest. Incoming requests are in the form of protobuf objects
/// and this class is used to translate from protobuf to structures understood by CacheService class.
/// \see CacheService
class CacheServerRequest : public BaseRequest {
 public:
  friend class CacheServer;
  enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
  explicit CacheServerRequest(int32_t queue_id)
      : BaseRequest::BaseRequest(BaseRequest::RequestType::kRequestUnknown),
        qid_(queue_id),
        st_(STATE::CREATE),
        responder_(&ctx_) {}

  ~CacheServerRequest() = default;

  /// \brief Functor. Used mainly by CacheServerGreeterImpl class to tag each incoming request and this
  /// functor will translate each protobuf into some form understood by the CacheService class.
  /// \param svc Async service
  /// \param cq Completion queue
  /// \return Status object
  Status operator()(CacheServerGreeter::AsyncService *svc, grpc::ServerCompletionQueue *cq);

  /// \brief Override the base class Print method
  /// \param out
  void Print(std::ostream &out) const override;

  /// \brief Getter of the queue id
  /// \return The queue where the request should go to
  int32_t getQid() const { return qid_; }

 private:
  int32_t qid_;
  Status rc_;
  STATE st_;
  grpc::ServerContext ctx_;
  grpc::ServerAsyncResponseWriter<CacheReply> responder_;
};

/// \brief Implementation of CacheServerGreeter
/// \note It is an async server
/// \see cache_grpc.proto
class CacheServerGreeterImpl final {
  friend class CacheServer;

 public:
  explicit CacheServerGreeterImpl(int32_t port, int32_t shared_memory_sz_in_gb);
  virtual ~CacheServerGreeterImpl();

  /// \brief Brings up gRPC server
  /// \return none
  Status Run();

  /// \brief Entry function to handle cache server request
  Status HandleRequest(int32_t worker_id);

  /// Return the shared memory pool.
  /// \return Return the shared memory pool
  CachedSharedMemoryArena *GetSharedMemoryPool() { return shm_pool_.get(); }

  void Shutdown();

  Status IpcResourceCleanup();

 private:
  int32_t port_;
  size_t shm_pool_sz_in_gb_;
  std::string unix_socket_;
  CacheServerGreeter::AsyncService svc_;
  std::unique_ptr<grpc::ServerCompletionQueue> cq_;
  std::unique_ptr<grpc::Server> server_;
  std::unique_ptr<CachedSharedMemoryArena> shm_pool_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_GRPC_SERVER_H_
mindspore/ccsrc/minddata/dataset/engine/cache/cache_main.cc
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/cache/cache_server.h"
#include <sys/types.h>
#include <unistd.h>
#ifdef USE_GLOG
#include <glog/logging.h>
#endif
#include <cstdlib>
namespace
ds
=
mindspore
::
dataset
;
int
main
(
int
argc
,
char
**
argv
)
{
ds
::
Status
rc
;
ds
::
CacheServer
::
Builder
builder
;
// This executable is not to be called directly, and should be invoked by cache_admin executable.
if
(
argc
!=
7
)
{
rc
=
ds
::
Status
(
ds
::
StatusCode
::
kSyntaxError
);
std
::
cerr
<<
rc
.
ToString
()
<<
std
::
endl
;
return
static_cast
<
int
>
(
rc
.
get_code
());
}
builder
.
SetRootDirectory
(
argv
[
1
])
.
SetNumWorkers
(
strtol
(
argv
[
2
],
nullptr
,
10
))
.
SetPort
(
strtol
(
argv
[
3
],
nullptr
,
10
))
.
SetSharedMemorySizeInGB
(
strtol
(
argv
[
4
],
nullptr
,
10
));
#ifdef USE_GLOG
FLAGS_minloglevel
=
strtol
(
argv
[
5
],
nullptr
,
10
);
#endif
auto
daemonize_string
=
argv
[
6
];
bool
daemonize
=
strcmp
(
daemonize_string
,
"true"
)
==
0
||
strcmp
(
daemonize_string
,
"TRUE"
)
==
0
||
strcmp
(
daemonize_string
,
"t"
)
==
0
||
strcmp
(
daemonize_string
,
"T"
)
==
0
;
// We always change directory to / on unix rather than using the directory where the cache_server
// is called. This is a standard procedure for daemonize a process on unix.
if
(
chdir
(
"/"
)
==
-
1
)
{
std
::
string
errMsg
=
"Unable to change directory to /. Errno = "
+
std
::
to_string
(
errno
);
std
::
cerr
<<
errMsg
<<
std
::
endl
;
return
-
1
;
}
// Simple check of the parameters before we move on.
rc
=
builder
.
SanityCheck
();
if
(
rc
.
IsError
())
{
std
::
cerr
<<
rc
.
ToString
()
<<
std
::
endl
;
return
static_cast
<
int
>
(
rc
.
get_code
());
}
#ifdef USE_GLOG
FLAGS_log_dir
=
"/tmp"
;
google
::
InitGoogleLogging
(
argv
[
0
]);
#endif
if
(
daemonize
)
{
// fork the child process to become the daemon
pid_t
pid
=
fork
();
// failed to fork
if
(
pid
<
0
)
{
std
::
string
err_msg
=
"Failed to fork process for cache server: "
+
std
::
to_string
(
errno
);
std
::
cerr
<<
err_msg
<<
std
::
endl
;
return
errno
;
}
else
if
(
pid
>
0
)
{
// Parent
std
::
cerr
<<
"cache server daemon process has been created as process id: "
<<
pid
<<
"
\n
Check log file for any start up error"
<<
std
::
endl
;
signal
(
SIGCHLD
,
SIG_IGN
);
// ignore sig child signal.
return
0
;
}
else
{
// Child process will continue from here if daemonize and parent has already exited.
// If we are running in the foreground, none of the code in block below will be run.
pid_t
sid
;
umask
(
0
);
sid
=
setsid
();
if
(
sid
<
0
)
{
MS_LOG
(
ERROR
)
<<
"Failed to setsid(). Errno = "
<<
std
::
to_string
(
errno
);
return
errno
;
}
close
(
0
);
close
(
1
);
close
(
2
);
}
}
// Dump the summary
MS_LOG
(
INFO
)
<<
builder
<<
std
::
endl
;
rc
=
builder
.
Build
();
if
(
rc
.
IsOk
())
{
ds
::
CacheServer
&
cs
=
ds
::
CacheServer
::
GetInstance
();
// Kick off the threads. Loop forever and never return unless error.
rc
=
cs
.
Run
();
if
(
rc
.
get_code
()
==
ds
::
StatusCode
::
kDuplicateKey
)
{
std
::
string
errMsg
=
"Server is already started"
;
MS_LOG
(
ERROR
)
<<
errMsg
;
std
::
cerr
<<
errMsg
<<
std
::
endl
;
return
0
;
}
}
if
(
rc
.
IsError
())
{
MS_LOG
(
ERROR
)
<<
rc
.
ToString
();
std
::
cerr
<<
rc
.
ToString
()
<<
std
::
endl
;
return
static_cast
<
int
>
(
rc
.
get_code
());
}
return
0
;
}
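The daemonize branch follows the standard unix recipe; compressed into a runnable sketch:

#include <sys/stat.h>
#include <unistd.h>
#include <cstdlib>

int Daemonize() {
  pid_t pid = fork();
  if (pid < 0) return -1;  // fork failed
  if (pid > 0) _exit(0);   // parent exits; the child carries on detached
  umask(0);
  if (setsid() < 0) return -1;  // new session: drop the controlling terminal
  if (chdir("/") == -1) return -1;
  close(0);  // a daemon must not hold the launching shell's descriptors
  close(1);
  close(2);
  return 0;
}

int main() {
  if (Daemonize() != 0) return EXIT_FAILURE;
  // the long-running server loop would start here
  return 0;
}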
mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.cc
...
...
@@ -14,154 +14,149 @@
 * limitations under the License.
 */
#include "minddata/dataset/engine/cache/cache_request.h"
+#include <cstdlib>
+#include <thread>
+#include "minddata/dataset/core/constants.h"
+#include "minddata/dataset/engine/cache/cache_client.h"
+#include "minddata/dataset/engine/cache/cache_fbb.h"
namespace mindspore {
namespace dataset {
-Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) {
-  buffers_.reserve(row.size() + 1);
-  RETURN_IF_NOT_OK(SerializeTensorRowHeader(row));
-  buffers_.push_back(fbb_->GetBufferPointer());
-  for (const auto &ts : row) {
-    buffers_.push_back(ts->GetBuffer());
-  }
+Status BaseRequest::Wait() {
+  RETURN_IF_NOT_OK(wp_.Wait());
+  Status remote_rc(static_cast<StatusCode>(reply_.rc()), reply_.msg());
+  RETURN_IF_NOT_OK(remote_rc);
+  // Any extra work to do before we return back to the client.
+  RETURN_IF_NOT_OK(PostReply());
  return Status::OK();
}

-Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) {
-  try {
-    fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
-    std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
-    std::vector<int64_t> tensor_sz;
-    v.reserve(row.size());
-    tensor_sz.reserve(row.size());
-    // We will go through each column in the row.
-    for (const std::shared_ptr<Tensor> &ts_ptr : row) {
-      flatbuffers::Offset<TensorMetaMsg> ts_off;
-      RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off));
-      v.push_back(ts_off);
-      tensor_sz.push_back(ts_ptr->SizeInBytes());
-    }
+Status CacheRowRequest::SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row) {
+  CHECK_FAIL_RETURN_UNEXPECTED(row.size() > 0, "Empty tensor row");
+  CHECK_FAIL_RETURN_UNEXPECTED(cc->SupportLocalClient() == support_local_bypass_, "Local bypass mismatch");
+  // Calculate how many bytes (not counting the cookie) we are sending to the server. We only
+  // use shared memory (if supported) if we exceed certain amount
+  std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb;
+  RETURN_IF_NOT_OK(::mindspore::dataset::SerializeTensorRowHeader(row, &fbb));
+  sz_ += fbb->GetSize();
+  for (const auto &ts : row) {
+    sz_ += ts->SizeInBytes();
+  }
+  bool sent_using_local_bypass = support_local_bypass_ ? (sz_ >= kLocalByPassThreshold) : false;
+  uint32_t flag = 0;
+  if (support_local_bypass_) {
+    BitSet(&flag, kLocalClientSupport);
+  }
+  if (sent_using_local_bypass) {
+    BitSet(&flag, kDataIsInSharedMemory);
+  }
+  rq_.set_flag(flag);
+  if (sent_using_local_bypass) {
+    MS_LOG(DEBUG) << "Requesting " << sz_ << " bytes of shared memory data";
+    // Allocate shared memory from the server
+    auto mem_rq = std::make_shared<AllocateSharedBlockRequest>(rq_.connection_id(), sz_);
+    RETURN_IF_NOT_OK(cc->PushRequest(mem_rq));
+    RETURN_IF_NOT_OK(mem_rq->Wait());
+    addr_ = mem_rq->GetAddr();
+    // Now we need to add that to the base address of where we attach.
+    auto base = cc->SharedMemoryBaseAddr();
+    auto p = reinterpret_cast<void *>(reinterpret_cast<int64_t>(base) + addr_);
+    // Now we copy the data onto shared memory.
+    WritableSlice all(p, sz_);
+    auto offset = fbb->GetSize();
+    ReadableSlice header(fbb->GetBufferPointer(), fbb->GetSize());
+    Status copy_rc;
+    copy_rc = WritableSlice::Copy(&all, header);
+    if (copy_rc.IsOk()) {
+      for (const auto &ts : row) {
+        WritableSlice row_data(all, offset, ts->SizeInBytes());
+        ReadableSlice src(ts->GetBuffer(), ts->SizeInBytes());
+        copy_rc = WritableSlice::Copy(&row_data, src);
+        if (copy_rc.IsError()) {
+          break;
+        }
+        offset += ts->SizeInBytes();
+      }
+      // Fill in where to find the data
+      AddDataLocation();
+    }
-    auto column_off = fbb_->CreateVector(v);
-    auto data_sz_off = fbb_->CreateVector(tensor_sz);
-    TensorRowHeaderMsgBuilder row_builder(*fbb_);
-    row_builder.add_column(column_off);
-    row_builder.add_data_sz(data_sz_off);
-    // Pass the row_id even if it may not be known.
-    row_builder.add_row_id(row.getId());
-    row_builder.add_size_of_this(-1);  // fill in later after we call Finish.
-    auto out = row_builder.Finish();
-    fbb_->Finish(out);
-    // Now go back to fill in size_of_this in the flat buffer.
-    auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer());
-    auto success = msg->mutate_size_of_this(fbb_->GetSize());
-    if (!success) {
-      RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
-    }
+    if (copy_rc.IsError()) {
+      // We need to return the memory back to the server
+      auto mfree_req = GenerateFreeBlockRequest();
+      Status rc = cc->PushRequest(mfree_req);
+      // But we won't wait for the result for the sake of performance.
+      if (rc.IsError()) {
+        MS_LOG(ERROR) << "Push request for free memory failed.";
+      }
+      return copy_rc;
+    }
-    return Status::OK();
-  } catch (const std::bad_alloc &e) {
-    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
-  }
+  } else {
+    // We have already filled the first buffer which is the cookie.
+    sz_ += rq_.buf_data(0).size();
+    rq_.add_buf_data(fbb->GetBufferPointer(), fbb->GetSize());
+    for (const auto &ts : row) {
+      rq_.add_buf_data(ts->GetBuffer(), ts->SizeInBytes());
+    }
+    MS_LOG(DEBUG) << "Sending " << sz_ << " bytes of tensor data in " << rq_.buf_data_size() << " segments";
+  }
  return Status::OK();
}
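What the WritableSlice copies lay down in the shared block, reproduced with plain memcpy on a local buffer (PackRow is an illustrative helper; the request itself then ships only the (addr, sz) pair to the server):

#include <cstring>
#include <string>
#include <vector>

std::vector<char> PackRow(const std::string &header, const std::vector<std::string> &tensors) {
  size_t sz = header.size();
  for (const auto &t : tensors) sz += t.size();
  std::vector<char> block(sz);  // stands in for the shared-memory slice at base + addr
  std::memcpy(block.data(), header.data(), header.size());  // flatbuffer header goes first
  size_t offset = header.size();
  for (const auto &t : tensors) {  // raw tensor bytes follow back to back
    std::memcpy(block.data() + offset, t.data(), t.size());
    offset += t.size();
  }
  return block;
}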
-Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr,
-                                               flatbuffers::Offset<TensorMetaMsg> *out_off) {
-  RETURN_UNEXPECTED_IF_NULL(out_off);
-  const Tensor *ts = ts_ptr.get();
-  auto shape_off = fbb_->CreateVector(ts->shape().AsVector());
-  const auto ptr = ts->GetBuffer();
-  if (ptr == nullptr) {
-    RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
-  }
-  auto src = ts->type().value();
-  TensorType dest;
-#define CASE(t)                        \
-  case DataType::t:                    \
-    dest = TensorType::TensorType_##t; \
-    break
-  // Map the type to fill in the flat buffer.
-  switch (src) {
-    CASE(DE_BOOL);
-    CASE(DE_INT8);
-    CASE(DE_UINT8);
-    CASE(DE_INT16);
-    CASE(DE_UINT16);
-    CASE(DE_INT32);
-    CASE(DE_UINT32);
-    CASE(DE_INT64);
-    CASE(DE_UINT64);
-    CASE(DE_FLOAT16);
-    CASE(DE_FLOAT32);
-    CASE(DE_FLOAT64);
-    CASE(DE_STRING);
-    default:
-      MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
-      RETURN_STATUS_UNEXPECTED("Unknown type");
-  }
-#undef CASE
-  TensorMetaMsgBuilder ts_builder(*fbb_);
-  ts_builder.add_dims(shape_off);
-  ts_builder.add_type(dest);
-  auto ts_off = ts_builder.Finish();
-  *out_off = ts_off;
-  return Status::OK();
-}
+Status CacheRowRequest::PostReply() {
+  if (!reply_.result().empty()) {
+    row_id_from_server_ = strtoll(reply_.result().data(), nullptr, 10);
+  }
+  return Status::OK();
+}

-Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data,
-                                           std::shared_ptr<Tensor> *out) {
-  RETURN_UNEXPECTED_IF_NULL(col_ts);
-  auto shape_in = col_ts->dims();
-  auto type_in = col_ts->type();
-  std::vector<dsize_t> v;
-  v.reserve(shape_in->size());
-  v.assign(shape_in->begin(), shape_in->end());
-  TensorShape shape(v);
-  DataType::Type dest = DataType::DE_UNKNOWN;
-#define CASE(t)               \
-  case TensorType_##t:        \
-    dest = DataType::Type::t; \
-    break
-  switch (type_in) {
-    CASE(DE_BOOL);
-    CASE(DE_INT8);
-    CASE(DE_UINT8);
-    CASE(DE_INT16);
-    CASE(DE_UINT16);
-    CASE(DE_INT32);
-    CASE(DE_UINT32);
-    CASE(DE_INT64);
-    CASE(DE_UINT64);
-    CASE(DE_FLOAT16);
-    CASE(DE_FLOAT32);
-    CASE(DE_FLOAT64);
-    CASE(DE_STRING);
-  }
+Status CacheRowRequest::Prepare() {
+  if (BitTest(rq_.flag(), kDataIsInSharedMemory)) {
+    // First one is cookie, followed by address and then size.
+    CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() == 3, "Incomplete rpc data");
+  } else {
+    // First one is cookie. 2nd one is the google flat buffers followed by a number of buffers.
+    // But we are not going to decode them to verify.
+    CHECK_FAIL_RETURN_UNEXPECTED(rq_.buf_data_size() >= 3, "Incomplete rpc data");
+  }
+  return Status::OK();
+}
-#undef CASE
-  DataType type(dest);
-  std::shared_ptr<Tensor> ts;
-  RETURN_IF_NOT_OK(Tensor::CreateFromMemory(shape, type, static_cast<const unsigned char *>(data.GetPointer()),
-                                            data.GetSize(), &ts));
-  // Next we restore the real data which can be embedded or stored separately.
-  if (ts->SizeInBytes() != data.GetSize()) {
-    MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
-                  << "Dumping tensor\n" << *ts << "\n";
-    RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
-  }
-  *out = std::move(ts);
-  return Status::OK();
-}

-Status BatchFetchRequest::RestoreRows(TensorTable *out) {
+BatchFetchRequest::BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id,
+                                     bool local_bypass)
+    : BaseRequest(RequestType::kBatchFetchRows), support_local_bypass_(local_bypass), row_id_(row_id) {
+  rq_.set_connection_id(connection_id);
+  rq_.set_flag(support_local_bypass_ ? kLocalClientSupport : 0);
+  // Convert the row id into a flatbuffer
+  flatbuffers::FlatBufferBuilder fbb;
+  auto off_t = fbb.CreateVector(row_id);
+  TensorRowIdsBuilder bld(fbb);
+  bld.add_row_id(off_t);
+  auto off = bld.Finish();
+  fbb.Finish(off);
+  rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
+}
+Status BatchFetchRequest::RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr) {
  RETURN_UNEXPECTED_IF_NULL(out);
  auto num_elements = row_id_.size();
- auto *offset_array = reinterpret_cast<const int64_t *>(mem_.GetPointer());
+ const char *ptr = nullptr;
+ int64_t sz = 0;
+ // Tap into the reply flag to see where we can find the data. Server may decide the amount is
+ // so small that it doesn't use shared memory method.
+ auto flag = reply_.flag();
+ bool dataOnSharedMemory = support_local_bypass_ ? (BitTest(flag, kDataIsInSharedMemory)) : false;
+ if (dataOnSharedMemory) {
+   auto addr = strtoll(reply_.result().data(), nullptr, 10);
+   ptr = reinterpret_cast<const char *>(reinterpret_cast<int64_t>(baseAddr) + addr);
+   RETURN_UNEXPECTED_IF_NULL(out);
+   *out_addr = addr;
+ } else {
+   ptr = reply_.result().data();
+   *out_addr = -1;
+ }
+ auto *offset_array = reinterpret_cast<const int64_t *>(ptr);
+ sz = offset_array[num_elements];
+ CHECK_FAIL_RETURN_UNEXPECTED(support_local_bypass_ || sz == reply_.result().length(), "Length mismatch");
  TensorTable tbl;
  tbl.reserve(num_elements);
- ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes());
+ ReadableSlice all(ptr, sz);
  for (auto i = 0; i < num_elements; ++i) {
    auto len = offset_array[i + 1] - offset_array[i];
    TensorRow row;
...
...
@@ -178,10 +173,12 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out) {
        auto col_ts = msg->column()->Get(k);
        std::shared_ptr<Tensor> ts;
        ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k));
-       RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts));
+       RETURN_IF_NOT_OK(mindspore::dataset::RestoreOneTensor(col_ts, data, &ts));
        row.push_back(ts);
        ts_offset += data.GetSize();
      }
+   } else {
+     CHECK_FAIL_RETURN_UNEXPECTED(len == 0, "Data corruption detected.");
    }
    tbl.push_back(std::move(row));
  }
...
...
@@ -189,36 +186,69 @@ Status BatchFetchRequest::RestoreRows(TensorTable *out) {
  return Status::OK();
}
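The wire layout RestoreRows() walks is an (n+1)-entry offset table whose last entry is the total size, so each row's length is offset[i+1] - offset[i]. A standalone pack/unpack pair under that assumption:

#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

std::vector<char> Pack(const std::vector<std::string> &rows) {
  size_t n = rows.size();
  std::vector<int64_t> offsets(n + 1);
  offsets[0] = static_cast<int64_t>((n + 1) * sizeof(int64_t));  // data starts right after the table
  for (size_t i = 0; i < n; ++i) offsets[i + 1] = offsets[i] + static_cast<int64_t>(rows[i].size());
  std::vector<char> buf(static_cast<size_t>(offsets[n]));  // offsets[n] == total size
  std::memcpy(buf.data(), offsets.data(), (n + 1) * sizeof(int64_t));
  for (size_t i = 0; i < n; ++i) std::memcpy(buf.data() + offsets[i], rows[i].data(), rows[i].size());
  return buf;
}

std::vector<std::string> Unpack(const std::vector<char> &buf, size_t n) {
  auto *offsets = reinterpret_cast<const int64_t *>(buf.data());
  std::vector<std::string> rows;
  for (size_t i = 0; i < n; ++i) rows.emplace_back(buf.data() + offsets[i], offsets[i + 1] - offsets[i]);
  return rows;
}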
+CreateCacheRequest::CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
+                                       CreateCacheRequest::CreateCacheFlag flag)
+    : BaseRequest(RequestType::kCreateCache), cache_mem_sz_(cache_mem_sz), flag_(flag) {
+  // Type has been set already in the base constructor. So we need to fill in the connection info.
+  // On successful return, we will get the connection id
+  rq_.mutable_connection_info()->operator=(cinfo);
+}
+
+Status CreateCacheRequest::Prepare() {
+  try {
+    flatbuffers::FlatBufferBuilder fbb;
+    CreateCacheRequestMsgBuilder bld(fbb);
+    bld.add_cache_mem_sz(cache_mem_sz_);
+    bld.add_flag(static_cast<uint32_t>(flag_));
+    auto off = bld.Finish();
+    fbb.Finish(off);
+    rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
+    return Status::OK();
+  } catch (const std::bad_alloc &e) {
+    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
+  }
+}
+
 Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
   try {
-    fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
+    flatbuffers::FlatBufferBuilder fbb;
     std::vector<flatbuffers::Offset<ColumnNameMsg>> v;
     v.reserve(map.size());
     for (auto &column : map) {
-      auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second);
+      auto c = CreateColumnNameMsg(fbb, fbb.CreateString(column.first), column.second);
       v.push_back(c);
     }
-    auto v_off = fbb_->CreateVector(v);
-    auto final_off = CreateSchemaMsg(*fbb_, v_off);
-    fbb_->Finish(final_off);
-    buf_ = fbb_->GetBufferPointer();
-    len_of_buf_ = fbb_->GetSize();
+    auto v_off = fbb.CreateVector(v);
+    auto final_off = CreateSchemaMsg(fbb, v_off);
+    fbb.Finish(final_off);
+    rq_.add_buf_data(fbb.GetBufferPointer(), fbb.GetSize());
     return Status::OK();
   } catch (const std::bad_alloc &e) {
     return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
   }
 }

-std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() {
-  if (column_name_id_map_.empty()) {
-    auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(mem_.GetPointer());
-    auto v = map_msg->column();
-    for (auto i = 0; i < v->size(); ++i) {
-      auto col = map_msg->column()->Get(i);
-      column_name_id_map_.emplace(col->name()->str(), col->id());
-    }
-  }
-  return column_name_id_map_;
-}
+Status FetchSchemaRequest::PostReply() {
+  auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(reply_.result().data());
+  auto v = map_msg->column();
+  for (auto i = 0; i < v->size(); ++i) {
+    auto col = map_msg->column()->Get(i);
+    column_name_id_map_.emplace(col->name()->str(), col->id());
+  }
+  return Status::OK();
+}
+
+std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() { return column_name_id_map_; }
+
+Status GetStatRequest::PostReply() {
+  auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(reply_.result().data());
+  stat_.num_disk_cached = msg->num_disk_cached();
+  stat_.num_mem_cached = msg->num_mem_cached();
+  stat_.avg_cache_sz = msg->avg_cache_sz();
+  stat_.max_row_id = msg->max_row_id();
+  stat_.min_row_id = msg->min_row_id();
+  stat_.cache_service_state = msg->state();
+  return Status::OK();
+}
}  // namespace dataset
}  // namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/cache/cache_request.h
...
...
@@ -18,11 +18,16 @@
#include <algorithm>
#include <memory>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#ifdef ENABLE_CACHE
#include "proto/cache_grpc.grpc.pb.h"
#endif
#include "proto/cache_grpc.pb.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/cache/de_tensor_generated.h"
#include "minddata/dataset/util/slice.h"
...
...
@@ -30,6 +35,17 @@
namespace mindspore {
namespace dataset {
+class CacheClient;
+
+/// \brief Statistic structure for GetStat request
+struct CacheServiceStat {
+  int64_t num_mem_cached;
+  int64_t num_disk_cached;
+  int64_t avg_cache_sz;
+  row_id_type min_row_id;
+  row_id_type max_row_id;
+  int8_t cache_service_state;
+};
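Read-out sketch for this struct via the GetStat overload declared in cache_client.h above (PrintStat is a hypothetical caller, not part of the patch):

Status PrintStat(CacheClient *cc) {
  CacheServiceStat stat{};
  RETURN_IF_NOT_OK(cc->GetStat(&stat));
  MS_LOG(INFO) << "mem rows: " << stat.num_mem_cached << " disk rows: " << stat.num_disk_cached
               << " avg row size: " << stat.avg_cache_sz;
  return Status::OK();
}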
/// \brief CacheClient communicates with CacheServer using Requests.
class BaseRequest {
 public:
...
@@ -44,195 +60,301 @@ class BaseRequest {
    kCacheSchema = 6,
    kFetchSchema = 7,
    kBuildPhaseDone = 8,
    kDropSession = 9,
    kGenerateSessionId = 10,
    kAllocateSharedBlock = 11,
    kFreeSharedBlock = 12,
    kStopService = 13,
    // Add new request before it.
    kRequestUnknown = 32767
  };
  friend class CacheServer;
  friend class CacheServerRequest;
  friend class CacheClientGreeter;
  friend class CacheClientRequestTag;

  /// \brief Base class of a cache server request
  /// \param type Type of the request
  explicit BaseRequest(RequestType type) : type_(type) { rq_.set_type(static_cast<google::int32>(type_)); }
  virtual ~BaseRequest() = default;

  /// \brief A print method for debugging
  /// \param out The output stream to write output to
  virtual void Print(std::ostream &out) const { out << "Request type: " << static_cast<int16_t>(type_); }

  /// \brief << Stream output operator overload
  /// \param out reference to the output stream
  /// \param rq reference to the BaseRequest
  /// \return the output stream
  friend std::ostream &operator<<(std::ostream &out, const BaseRequest &rq) {
    rq.Print(out);
    return out;
  }

  /// \brief Derived class can implement extra work to be done before the request is sent to the server
  virtual Status Prepare() { return Status::OK(); }

  /// \brief Derived class can implement extra work to be done after the server sends back the reply
  virtual Status PostReply() { return Status::OK(); }

  /// \brief A method for the client to wait for the availability of the result back from the server.
  /// \return Status object
  Status Wait();

 protected:
  CacheRequest rq_;   // This is what we send to the server
  CacheReply reply_;  // This is what the server sends back

 private:
  RequestType type_;
  Status rc_;
  WaitPost wp_;  // A sync area used by the client side.
};
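For orientation, a minimal sketch of the life cycle that every request built on BaseRequest shares. SendRequest() here is a hypothetical stand-in for the client transport (in this patch that role is played by CacheClientGreeter), and the exact point where PostReply() runs relative to Wait() is an implementation detail of the client:

Status Drive(const std::shared_ptr<BaseRequest> &rq) {
  // The concrete constructor has already filled rq_ (a CacheRequest protobuf).
  RETURN_IF_NOT_OK(rq->Prepare());    // client-side work, e.g. serialization
  RETURN_IF_NOT_OK(SendRequest(rq));  // hypothetical transport call; the server's CacheReply lands in reply_
  RETURN_IF_NOT_OK(rq->Wait());       // block on wp_ until the reply arrives, surface rc_
  return rq->PostReply();             // decode reply_ into request-specific fields
}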
class FreeSharedBlockRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit FreeSharedBlockRequest(connection_id_type connection_id, int64_t addr)
      : BaseRequest(RequestType::kFreeSharedBlock) {
    rq_.set_connection_id(connection_id);
    rq_.add_buf_data(std::to_string(addr));
  }
  ~FreeSharedBlockRequest() = default;
};
/// \brief Request to cache a single TensorRow
class CacheRowRequest : public BaseRequest {
 public:
  friend class CacheServer;
  friend class CacheClient;
  explicit CacheRowRequest(connection_id_type connection_id, const std::string &cookie, bool local_bypass)
      : BaseRequest(RequestType::kCacheRow),
        support_local_bypass_(local_bypass),
        addr_(-1),
        sz_(0),
        row_id_from_server_(-1) {
    rq_.set_connection_id(connection_id);
    rq_.add_buf_data(cookie);
  }
  ~CacheRowRequest() = default;

  /// \brief Serialize a TensorRow for streaming to the cache server
  /// \param cc The CacheClient that issues the request
  /// \param row TensorRow
  /// \return Status object
  Status SerializeCacheRowRequest(const CacheClient *cc, const TensorRow &row);

  /// \brief Sanity check before we send the row.
  /// \return Status object
  Status Prepare() override;

  /// \brief Override the base function to get the row id returned from the server
  /// \return Status object
  Status PostReply() override;

  /// \brief Return the row id assigned to this row for a non-mappable dataset
  /// \return row id of the cached row
  row_id_type GetRowIdAfterCache() { return row_id_from_server_; }

  /// \brief If we are doing local bypass, fill in extra request information of where the data is located.
  void AddDataLocation() {
    if (support_local_bypass_) {
      rq_.add_buf_data(std::to_string(addr_));
      rq_.add_buf_data(std::to_string(sz_));
    }
  }

  /// \brief If we fail to send the data to the server using the shared memory method, we should release
  /// the shared memory by sending another request. The following function will generate a suitable
  /// request for the CacheClient to send.
  std::shared_ptr<FreeSharedBlockRequest> GenerateFreeBlockRequest() {
    return std::make_shared<FreeSharedBlockRequest>(rq_.connection_id(), addr_);
  }

 private:
  std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
  bool support_local_bypass_;
  int64_t addr_;
  int64_t sz_;
  row_id_type row_id_from_server_;
  std::vector<const void *> buffers_;

  /// \brief Private function to serialize one TensorRow
  /// \param row TensorRow
  /// \return Status object
  Status SerializeTensorRowHeader(const TensorRow &row);
  /// \brief Private function to serialize one Tensor
  /// \param ts_ptr Tensor
  /// \return Status object
  Status SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off);
};
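A minimal sketch of how a client might drive a CacheRowRequest when local bypass (shared memory) is in play. Error paths and threading are simplified, and SendRequest() is a hypothetical stand-in for the client's actual transport call:

Status CacheRowExample(CacheClient *cc, connection_id_type conn_id, const std::string &cookie, const TensorRow &row) {
  auto rq = std::make_shared<CacheRowRequest>(conn_id, cookie, /*local_bypass=*/true);
  RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(cc, row));  // serializes the row; with bypass, into shared memory
  Status rc = SendRequest(rq);                              // hypothetical transport helper
  if (rc.IsError()) {
    // The row never reached the server, so give back the shared-memory block.
    (void)SendRequest(rq->GenerateFreeBlockRequest());
    return rc;
  }
  RETURN_IF_NOT_OK(rq->Wait());  // blocks until the server replies; PostReply() extracts the row id
  MS_LOG(INFO) << "Row cached with id " << rq->GetRowIdAfterCache();
  return Status::OK();
}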
/// \brief Request to fetch rows in batch
class BatchFetchRequest : public BaseRequest {
 public:
  friend class CacheServer;
  friend class CacheService;
  BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id, bool local_bypass);
  ~BatchFetchRequest() = default;
  Status RestoreRows(TensorTable *out, const void *baseAddr, int64_t *out_addr);

 private:
  bool support_local_bypass_;
  std::vector<row_id_type> row_id_;
  MemGuard<uint8_t> mem_;
  Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
};
/// \brief Request to create a cache for the current connection
class CreateCacheRequest : public BaseRequest {
 public:
  friend class CacheServer;
  enum class CreateCacheFlag : uint32_t { kNone = 0, kSpillToDisk = 1, kGenerateRowId = 1u << 1L };

  /// \brief Constructor
  /// \param cinfo Client information (e.g. session id) that identifies the requester
  /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited
  /// \param flag Attributes of the cache.
  explicit CreateCacheRequest(const CacheClientInfo &cinfo, uint64_t cache_mem_sz,
                              CreateCacheFlag flag = CreateCacheFlag::kNone);
  ~CreateCacheRequest() = default;

  void ParseResult(connection_id_type *id, std::string *out) {
    auto p = flatbuffers::GetRoot<CreateCacheReplyMsg>(reply_.result().data());
    *id = p->connection_id();
    *out = p->cookie()->str();
  }

  /// Overload the base class Prepare
  Status Prepare() override;

 private:
  uint64_t cache_mem_sz_;
  CreateCacheFlag flag_;
};
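A sketch of creating a cache and decoding the flatbuffer reply. SendRequest() is again a hypothetical transport helper, and the CacheClientInfo setter shown assumes the usual protobuf-generated accessors for that message:

CacheClientInfo cinfo;
cinfo.set_session_id(my_session_id);  // assumption: protobuf-style setter on CacheClientInfo
auto rq = std::make_shared<CreateCacheRequest>(cinfo, /*cache_mem_sz=*/0,
                                               CreateCacheRequest::CreateCacheFlag::kSpillToDisk);
RETURN_IF_NOT_OK(SendRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());
connection_id_type conn_id;
std::string cookie;
rq->ParseResult(&conn_id, &cookie);  // only the first creator receives a non-empty cookie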
/// \brief Request to purge a cache.
class PurgeCacheRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit PurgeCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kPurgeCache) {
    rq_.set_connection_id(connection_id);
  }
  ~PurgeCacheRequest() = default;
};
/// \brief Request to destroy a cache
class DestroyCacheRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit DestroyCacheRequest(connection_id_type connection_id) : BaseRequest(RequestType::kDestroyCache) {
    rq_.set_connection_id(connection_id);
  }
  /// \brief Destructor
  ~DestroyCacheRequest() = default;
};
/// \brief Obtain the statistics of the current connection
class GetStatRequest : public BaseRequest {
 public:
  friend class CacheServer;
  friend class CacheService;
  explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(RequestType::kGetStat) {
    rq_.set_connection_id(connection_id);
  }
  ~GetStatRequest() = default;

  /// \brief Override base function to process the result.
  Status PostReply() override;

  void GetStat(CacheServiceStat *stat) {
    if (stat != nullptr) {
      (*stat) = stat_;
    }
  }

 private:
  CacheServiceStat stat_{};
};
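Usage is a simple round trip; a sketch, with SendRequest() the same hypothetical transport helper as above:

auto rq = std::make_shared<GetStatRequest>(conn_id);
RETURN_IF_NOT_OK(SendRequest(rq));
RETURN_IF_NOT_OK(rq->Wait());  // PostReply() decodes the ServiceStatMsg flatbuffer into stat_
CacheServiceStat stat;
rq->GetStat(&stat);
MS_LOG(INFO) << "mem rows: " << stat.num_mem_cached << " disk rows: " << stat.num_disk_cached;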
/// \brief Request to cache a schema
class CacheSchemaRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit CacheSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kCacheSchema) {
    rq_.set_connection_id(connection_id);
  }
  ~CacheSchemaRequest() = default;

  Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map);

 private:
  std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
};
/// \brief Request to fetch a schema
class FetchSchemaRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit FetchSchemaRequest(connection_id_type connection_id) : BaseRequest(RequestType::kFetchSchema) {
    rq_.set_connection_id(connection_id);
  }
  ~FetchSchemaRequest() = default;

  Status PostReply() override;

  std::unordered_map<std::string, int32_t> GetColumnMap();

 private:
  std::unordered_map<std::string, int32_t> column_name_id_map_;
};
/// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only.
class BuildPhaseDoneRequest : public BaseRequest {
 public:
  friend class CacheServer;
  BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie)
      : BaseRequest(RequestType::kBuildPhaseDone), cookie_(cookie) {
    rq_.set_connection_id(connection_id);
    rq_.add_buf_data(cookie_);
  }
  ~BuildPhaseDoneRequest() = default;

 private:
  std::string cookie_;
};
/// \brief Request to drop all the caches in the current session
class DropSessionRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit DropSessionRequest(const CacheClientInfo &cinfo) : BaseRequest(RequestType::kDropSession) {
    rq_.mutable_connection_info()->operator=(cinfo);
  }
  ~DropSessionRequest() = default;
};
class GenerateSessionIdRequest : public BaseRequest {
 public:
  friend class CacheServer;
  GenerateSessionIdRequest() : BaseRequest(RequestType::kGenerateSessionId) {
    // We don't have any client info nor connection id to send. But we will manually
    // set the connection id to 0.
    rq_.set_connection_id(0);
  }
  ~GenerateSessionIdRequest() = default;
  session_id_type GetSessionId() { return atoi(reply_.result().data()); }
};
class AllocateSharedBlockRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit AllocateSharedBlockRequest(connection_id_type connection_id, size_t requestedSz)
      : BaseRequest(RequestType::kAllocateSharedBlock) {
    rq_.set_connection_id(connection_id);
    rq_.add_buf_data(std::to_string(requestedSz));
  }
  ~AllocateSharedBlockRequest() = default;

  /// \brief On return from the server, we get the (relative) address where
  /// the free block is located.
  /// \return The relative address of the allocated block
  int64_t GetAddr() {
    auto addr = strtoll(reply_.result().data(), nullptr, 10);
    return addr;
  }
};
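A sketch of the shared-memory round trip this request enables. The address returned is relative to the base of the shared-memory segment, so the client adds its own mapped base address before writing; base_addr and SendRequest() are hypothetical names for illustration:

auto alloc_rq = std::make_shared<AllocateSharedBlockRequest>(conn_id, /*requestedSz=*/1 << 20);
RETURN_IF_NOT_OK(SendRequest(alloc_rq));
RETURN_IF_NOT_OK(alloc_rq->Wait());
int64_t rel_addr = alloc_rq->GetAddr();
void *dst = static_cast<char *>(base_addr) + rel_addr;  // client's mmap'ed view of the segment
// ... serialize the row into dst, then send a CacheRowRequest carrying rel_addr and size ...
// On failure, return the block with FreeSharedBlockRequest(conn_id, rel_addr).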
class ShutdownRequest : public BaseRequest {
 public:
  friend class CacheServer;
  ShutdownRequest() : BaseRequest(RequestType::kStopService) {}
  ~ShutdownRequest() = default;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_SERVICE_H_
mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.cc
mindspore/ccsrc/minddata/dataset/engine/cache/cache_server.h
@@ -24,8 +24,11 @@
#include <utility>
#include <vector>
#include <map>
#include <set>
#include "minddata/dataset/engine/cache/cache_service.h"
#include "minddata/dataset/engine/cache/cache_grpc_server.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/cache_pool.h"
#include "minddata/dataset/util/lock.h"
@@ -37,43 +40,131 @@
namespace mindspore {
namespace dataset {
class BaseRequest;
/// \brief A server which provides CacheService services.
class CacheServer : public Service {
 public:
  friend class Services;
  using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;

  class Builder {
   public:
    Builder() : top_("/tmp"), num_workers_(32), port_(50052), shared_memory_sz_in_gb_(4) {}

    /// \brief Getter functions
    const std::string &getTop() const { return top_; }
    int32_t getNumWorkers() const { return num_workers_; }
    int32_t getPort() const { return port_; }
    int32_t getSharedMemorySzInGb() const { return shared_memory_sz_in_gb_; }

    Builder &SetRootDirectory(std::string root) {
      top_ = std::move(root);
      return *this;
    }
    Builder &SetNumWorkers(int32_t n) {
      num_workers_ = n;
      return *this;
    }
    Builder &SetPort(int32_t p) {
      port_ = p;
      return *this;
    }
    Builder &SetSharedMemorySizeInGB(int32_t sz) {
      shared_memory_sz_in_gb_ = sz;
      return *this;
    }

    Status SanityCheck();

    void Print(std::ostream &out) const {
      out << "Summary of the cache server configuration\n"
          << "Spill directory: " << getTop() << "\n"
          << "Number of parallel workers: " << getNumWorkers() << "\n"
          << "Tcp/ip port: " << getPort() << "\n"
          << "Shared memory size (in GB): " << getSharedMemorySzInGb();
    }

    friend std::ostream &operator<<(std::ostream &out, const Builder &bld) {
      bld.Print(out);
      return out;
    }

    Status Build() {
      RETURN_IF_NOT_OK(SanityCheck());
      // We need to bring up the Task Manager by bringing up the Services singleton.
      RETURN_IF_NOT_OK(Services::CreateInstance());
      RETURN_IF_NOT_OK(CacheServer::CreateInstance(top_, num_workers_, port_, shared_memory_sz_in_gb_));
      return Status::OK();
    }

   private:
    std::string top_;
    int32_t num_workers_;
    int32_t port_;
    int32_t shared_memory_sz_in_gb_;
  };
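A sketch of standing up a cache server through the Builder; the values shown are the Builder's own defaults, spelled out for clarity:

CacheServer::Builder builder;
builder.SetRootDirectory("/tmp").SetNumWorkers(32).SetPort(50052).SetSharedMemorySizeInGB(4);
MS_LOG(INFO) << builder;                  // prints the configuration summary via Print()
Status rc = builder.Build();              // sanity-checks, then creates the singleton
if (rc.IsOk()) {
  rc = CacheServer::GetInstance().Run();  // kicks off server threads; does not return unless an error occurs
}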
  CacheServer(const CacheServer &) = delete;
  CacheServer &operator=(const CacheServer &) = delete;
  CacheServer(CacheServer &&) = delete;
  CacheServer &operator=(CacheServer &&) = delete;
  Status DoServiceStart() override;
  Status DoServiceStop() override;
  ~CacheServer() { (void)ServiceStop(); }

  static Status CreateInstance(const std::string &spill_path, int32_t num_workers, int32_t port,
                               int32_t shared_memory_sz) {
    std::call_once(init_instance_flag_, [&]() -> Status {
      auto &svcManager = Services::GetInstance();
      RETURN_IF_NOT_OK(svcManager.AddHook(&instance_, spill_path, num_workers, port, shared_memory_sz));
      return Status::OK();
    });
    return Status::OK();
  }

  static CacheServer &GetInstance() { return *instance_; }

  /// \brief A cache client contacts the cache server by pushing a request onto one of its worker queues.
  /// \param queue_id Which worker queue to push onto
  /// \param rq The request
  /// \return Status object
  Status PushRequest(int32_t queue_id, CacheServerRequest *rq) {
    RETURN_UNEXPECTED_IF_NULL(rq);
    RETURN_IF_NOT_OK(cache_q_->operator[](queue_id)->Add(rq));
    return Status::OK();
  }
  /// \brief Kick off server threads. Never returns unless an error occurs.
  Status Run();

  /// \brief Get a free tag
  /// \param q[in] pointer to a pointer to a CacheServerRequest
  /// \return Status object
  static Status GetFreeRequestTag(int32_t queue_id, CacheServerRequest **q);

  /// \brief Return a tag to the free list
  /// \param p[in] pointer to an already finished CacheServerRequest tag
  /// \return Status object
  static Status ReturnRequestTag(CacheServerRequest *p);

 private:
  static std::once_flag init_instance_flag_;
  static CacheServer *instance_;
  mutable RWLock rwLock_;
  std::string top_;
  cache_index all_caches_;
  std::set<session_id_type> history_sessions_;
  std::shared_ptr<QueueList<CacheServerRequest *>> cache_q_;
  std::shared_ptr<QueueList<CacheServerRequest *>> free_list_;
  std::vector<std::unique_ptr<MemGuard<CacheServerRequest, Allocator<CacheServerRequest>>>> tag_;
  std::shared_ptr<CacheServerGreeterImpl> comm_layer_;
  std::shared_ptr<MemoryPool> mp_;
  TaskGroup vg_;
  int32_t num_workers_;
  int32_t port_;
  int32_t shared_memory_sz_in_gb_;
  std::atomic<bool> global_shutdown_;

  /// \brief Constructor
  /// \param spill_path Top directory for spilling buffers to.
  /// \param num_workers Number of threads for handling requests.
  explicit CacheServer(const std::string &spill_path, int32_t num_workers, int32_t port,
                       int32_t share_memory_sz_in_gb);

  /// \brief Locate a cache service from connection id.
  /// \return Pointer to cache service. Null if not found
@@ -82,16 +173,65 @@ class CacheServer : public Service {
  /// \brief Create a cache service. We allow multiple clients to create the same cache service.
  /// Subsequent duplicate requests are ignored. The first cache client to create the service will be given
  /// a special unique cookie.
  /// \return Status object
  Status CreateService(CacheRequest *rq, CacheReply *reply);

  /// \brief Destroy a cache service
  /// \param cs The cache service
  /// \param rq The request
  /// \return Status object
  Status DestroyCache(CacheService *cs, CacheRequest *rq);
  Status PurgeCache(CacheService *cs);

  /// \brief Entry point for all internal server threads.
  Status ServerRequest(int32_t worker_id);

  /// \brief Entry point for all grpc threads.
  /// \return Status object
  Status RpcRequest(int32_t worker_id);

  Status DestroySession(CacheRequest *rq);

  /// \brief Create a connection id from a session id and a crc
  /// \param session_id
  /// \param crc
  /// \return connection id
  connection_id_type GetConnectionID(session_id_type session_id, uint32_t crc) const;

  /// \brief Extract the session id from a connection id
  /// \param connection_id
  /// \return session id
  session_id_type GetSessionID(connection_id_type connection_id) const;

  /// \brief Generate a session ID for the client
  /// \return Session ID
  session_id_type GenerateSessionID() const;

  /// \brief Handle kAllocateSharedBlock request
  /// \param rq CacheRequest
  /// \param reply CacheReply
  /// \return Status object
  Status AllocateSharedMemory(CacheRequest *rq, CacheReply *reply);

  /// \brief Handle kFreeSharedBlock request
  /// \param rq
  /// \return Status object
  Status FreeSharedMemory(CacheRequest *rq);

  /// \brief Handle kFastCacheRow request
  /// \return Status object
  Status FastCacheRow(CacheService *cs, CacheRequest *rq, CacheReply *reply);

  /// \brief Internal function to do row batch fetch
  /// \param cs CacheService
  /// \param rq Request
  /// \param reply Reply
  /// \return Status object
  Status BatchFetchRows(CacheService *cs, CacheRequest *rq, CacheReply *reply);

  /// \brief A proper shutdown of the server
  /// \return Status object
  Status GlobalShutdown();
};
}  // namespace dataset
}  // namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.cc
@@ -76,7 +76,7 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
      *row_id_generated = GetNextRowId();
      // Some debug information on how many rows we have generated so far.
      if ((*row_id_generated) % 1000 == 0) {
        MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1;
      }
    } else {
      if (msg->row_id() < 0) {
@@ -114,6 +114,45 @@ Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type
    RETURN_STATUS_UNEXPECTED(e.what());
  }
}

Status CacheService::FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(row_id_generated);
  if (st_ == State::kFetchPhase) {
    // For this kind of cache service, once the build phase is over and we enter the fetch phase,
    // we can't allow anyone to cache more rows.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
  }
  try {
    // If we don't need to generate an id, we need to find it from the buffer.
    if (generate_id_) {
      *row_id_generated = GetNextRowId();
      // Some debug information on how many rows we have generated so far.
      if ((*row_id_generated) % 1000 == 0) {
        MS_LOG(DEBUG) << "Number of rows cached: " << (*row_id_generated) + 1;
      }
    } else {
      auto msg = GetTensorRowHeaderMsg(src.GetPointer());
      if (msg->row_id() < 0) {
        std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id());
        RETURN_STATUS_UNEXPECTED(errMsg);
      }
      *row_id_generated = msg->row_id();
    }
    // Now we cache the flat buffer.
    CachePool::key_type key;
    RETURN_IF_NOT_OK(cp_->Insert({src}, &key));
    Status rc = map_->DoInsert(*row_id_generated, key);
    if (rc == Status(StatusCode::kDuplicateKey)) {
      MS_LOG(DEBUG) << "Ignoring duplicate key.";
    } else {
      RETURN_IF_NOT_OK(rc);
    }
    return Status::OK();
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
}
std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
  // Then show any custom derived-internal stuff
  out << "\nCache memory size: " << cs.cache_mem_sz_;
@@ -155,20 +194,15 @@ Status CacheService::GetStat(CacheService::ServiceStat *out) {
  }
  return Status::OK();
}

Status CacheService::PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *out,
                                   int64_t *mem_sz) {
  SharedLock rw(&rw_lock_);
  if (st_ == State::kBuildPhase) {
    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in build phase");
  }
  RETURN_UNEXPECTED_IF_NULL(out);
  RETURN_UNEXPECTED_IF_NULL(mem_sz);
  const auto num_elements = v.size();
  *mem_sz = (num_elements + 1) * sizeof(int64_t);
  (*out).reserve(num_elements);
  for (auto row_id : v) {
    auto r = map_->Search(row_id);
    if (r.second) {
@@ -180,25 +214,33 @@ Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint
        errMsg += std::to_string(key);
        RETURN_STATUS_UNEXPECTED(errMsg);
      }
      (*out).emplace_back(key, sz);
      (*mem_sz) += sz;
    } else {
      (*out).emplace_back(-1, 0);
    }
  }
  return Status::OK();
}

Status CacheService::BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &info,
                                WritableSlice *out) const {
  RETURN_UNEXPECTED_IF_NULL(out);
  SharedLock rw(&rw_lock_);
  if (st_ == State::kBuildPhase) {
    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in build phase");
  }
  const auto num_elements = v.size();
  int64_t data_offset = (num_elements + 1) * sizeof(int64_t);
  auto *offset_array = reinterpret_cast<int64_t *>(out->GetMutablePointer());
  offset_array[0] = data_offset;
  for (auto i = 0; i < num_elements; ++i) {
    auto sz = info.at(i).second;
    offset_array[i + 1] = offset_array[i] + sz;
    if (sz > 0) {
      WritableSlice row_data(*out, offset_array[i], sz);
      auto key = info.at(i).first;
      size_t bytesRead = 0;
      RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead));
      if (bytesRead != sz) {
@@ -208,7 +250,6 @@ Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint
      }
    }
  }
  return Status::OK();
}
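The buffer that BatchFetch fills has a simple layout: the first (num_rows + 1) int64 values are absolute offsets, row i occupies bytes [offset[i], offset[i+1]), and a zero-length span marks a cache miss. A sketch of how a reader could walk it (buf is a hypothetical pointer to the start of the fetched buffer):

void WalkFetchedRows(const void *buf, size_t num_rows) {
  auto *offset_array = reinterpret_cast<const int64_t *>(buf);
  for (size_t i = 0; i < num_rows; ++i) {
    int64_t sz = offset_array[i + 1] - offset_array[i];
    if (sz == 0) {
      // Cache miss for this row id; the caller decides whether that is an error.
      continue;
    }
    const auto *row_start = reinterpret_cast<const char *>(buf) + offset_array[i];
    // ... deserialize the flatbuffer row header and tensors from row_start ...
    (void)row_start;
  }
}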
Status CacheService::CacheSchema(const void *buf, int64_t len) {
@@ -232,18 +273,26 @@ Status CacheService::CacheSchema(const void *buf, int64_t len) {
  }
  return Status::OK();
}

Status CacheService::FetchSchema(std::string *out) const {
  SharedLock rw(&rw_lock_);
  if (st_ == State::kBuildPhase) {
    // For this kind of cache service, we can't fetch yet until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in build phase");
  }
  RETURN_UNEXPECTED_IF_NULL(out);
  // We are going to use std::string to allocate and hold the result which will be eventually
  // 'moved' to the protobuf message (which underneath is also a std::string) for the purpose
  // of minimizing memory copy.
  std::string mem;
  if (schema_key_ >= 0) {
    auto len = cp_->GetSize(schema_key_);
    try {
      mem.resize(len);
      CHECK_FAIL_RETURN_UNEXPECTED(mem.capacity() >= len, "Programming error");
    } catch (const std::bad_alloc &e) {
      return Status(StatusCode::kOutOfMemory);
    }
    auto slice = WritableSlice(mem.data(), len);
    RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice));
    *out = std::move(mem);
  } else {
mindspore/ccsrc/minddata/dataset/engine/cache/cache_service.h
@@ -28,7 +28,6 @@
#include "minddata/dataset/core/global_context.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/util/arena.h"
#include "minddata/dataset/util/btree.h"
#include "minddata/dataset/util/cache_pool.h"
@@ -38,7 +37,8 @@
namespace mindspore {
namespace dataset {
/// Some typedefs used for BatchFetch
using key_size_pair = std::pair<CachePool::key_type, size_t>;
/// \brief A cache service for storing/fetching buffers in an in-memory cache, which may spill to disk
/// if the cache service is created to support spilling
class CacheService : public Service {
@@ -69,12 +69,26 @@ class CacheService : public Service {
  /// \param[out] row_id_generated The row id assigned to this row if any
  /// \return Status object
  Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated);

  /// \brief A fast version of CacheRow where all the data is already in one contiguous piece.
  /// \param src Slice of the data
  /// \param row_id_generated The row id assigned to this row if any
  /// \return Status object
  Status FastCacheRow(const ReadableSlice &src, row_id_type *row_id_generated);

  /// \brief This function is used in preparation for batch fetching.
  /// It calculates how much memory we should allocate and which row ids are present.
  /// \param[in] v A vector of row id.
  /// \param[in/out] out Pointer to a vector of <CachePool::key_type, size_t>
  /// \param[in/out] mem_sz How much memory is required to batch fetch
  /// \return Status object
  Status PreBatchFetch(const std::vector<row_id_type> &v, std::vector<key_size_pair> *out, int64_t *mem_sz);

  /// \brief Main function to fetch rows in batch. The output is a contiguous memory which will be decoded
  /// by the CacheClient. Cache miss is not an error, and will be coded in the output to mark an empty row.
  /// \param[in] v A vector of row id.
  /// \param[in] info The <key, size> pairs computed by PreBatchFetch.
  /// \param[out] out A contiguous memory buffer that holds the requested rows.
  /// \return Status object
  Status BatchFetch(const std::vector<row_id_type> &v, const std::vector<key_size_pair> &info,
                    WritableSlice *out) const;

  /// \brief Getter function
  /// \return Spilling path
@@ -102,7 +116,7 @@ class CacheService : public Service {
  /// \brief Fetch schema
  /// \param out A contiguous memory that contains the serialized form of schema.
  /// \return Status object
  Status FetchSchema(std::string *out) const;

  /// \brief Purge the content of a cache
  /// \return Status object
  Status Purge();
mindspore/ccsrc/minddata/dataset/engine/cache/de_tensor.fbs
@@ -60,10 +60,11 @@ table TensorRowIds {
}

/// Statistics returned from each cache service
/// \note It must match CacheServiceStat
table ServiceStatMsg {
    num_mem_cached:int64;
    num_disk_cached:int64;
    avg_cache_sz:int64;
    min_row_id:int64;
    max_row_id:int64;
    state:int8;
@@ -79,3 +80,15 @@ table ColumnNameMsg {
table SchemaMsg {
    column:[ColumnNameMsg];
}

/// Part of the CreateCacheRequest
table CreateCacheRequestMsg {
    cache_mem_sz:int64;
    flag:uint32;
}

/// Return result of CreateCacheRequest
table CreateCacheReplyMsg {
    connection_id:int64;
    cookie:string;
}
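For reference, a sketch of round-tripping a ServiceStatMsg, which is what GetStatRequest::PostReply decodes on the client side. The CreateServiceStatMsg helper and field accessors follow flatc's usual generated-code conventions for the table above; treat the exact names as assumptions:

flatbuffers::FlatBufferBuilder fbb;
auto stat_off = CreateServiceStatMsg(fbb,
                                     /*num_mem_cached=*/1000,
                                     /*num_disk_cached=*/50,
                                     /*avg_cache_sz=*/4096,
                                     /*min_row_id=*/0,
                                     /*max_row_id=*/1049,
                                     /*state=*/0);
fbb.Finish(stat_off);
// The server copies fbb.GetBufferPointer()/fbb.GetSize() into the reply's result field;
// the client then reads it back with GetRoot, exactly as PostReply does:
auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(fbb.GetBufferPointer());
MS_LOG(DEBUG) << "rows in memory: " << msg->num_mem_cached();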
mindspore/ccsrc/minddata/dataset/engine/cache/stub/cache_grpc_client.h
0 → 100644
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_

#include <memory>
#include <string>
#include "proto/cache_grpc.pb.h"
#include "minddata/dataset/engine/cache/cache_common.h"
#include "minddata/dataset/engine/cache/cache_request.h"
#include "minddata/dataset/util/service.h"

namespace mindspore {
namespace dataset {
// A no-op stand-in for the gRPC client greeter on platforms where the cache is not supported.
class CacheClientGreeter : public Service {
 public:
  explicit CacheClientGreeter(const std::string &hostname, int32_t port, int32_t num_workers) {}
  ~CacheClientGreeter() override {}
  Status DoServiceStart() override { RETURN_STATUS_UNEXPECTED("Not supported"); }
  Status DoServiceStop() override { RETURN_STATUS_UNEXPECTED("Not supported"); }
  void *SharedMemoryBaseAddr() { return nullptr; }
  Status HandleRequest(std::shared_ptr<BaseRequest> rq) { RETURN_STATUS_UNEXPECTED("Not supported"); }
  Status AttachToSharedMemory(int32_t port, bool *local_bypass) { RETURN_STATUS_UNEXPECTED("Not supported"); }
};
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CACHE_STUB_H_
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.cc
@@ -16,6 +16,7 @@
#include "minddata/dataset/engine/datasetops/cache_base_op.h"
#include <iomanip>
#include <iostream>
#include <utility>
#include "minddata/dataset/engine/execution_tree.h"
namespace mindspore {
@@ -47,22 +48,39 @@ Status CacheBase::Reset() {
}

CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
                     std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
      row_cnt_(0),
      num_cache_miss_(0),
      cache_client_(std::move(cache_client)),
      rows_per_buffer_(rows_per_buf),
      // We can cause deadlock if this internal Connector size is too small.
      keys_miss_(num_workers_, 1, connector_capacity_),
      prefetch_size_(cache_client_->getPrefetchSize()) {
  io_block_queues_.Init(num_workers, op_connector_size);
  prefetch_queues_.Init(num_workers, op_connector_size);
  sampler_queue_ = std::make_unique<Queue<std::shared_ptr<Tensor>>>(op_connector_size);
}

// Common function to fetch samples from the sampler and send them using the io_block_queues to
// the parallel workers
Status CacheBase::FetchSamplesToWorkers() {
  int64_t buf_cnt = 0;
  int64_t wait_cnt = 0;
  // Kick off several threads which will prefetch prefetch_size_ rows in advance. The rows_per_buffer_
  // is too small (1 by default) and won't help performance.
  RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Dispatcher", std::bind(&CacheBase::Dispatcher, this)));
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheBase::Prefetcher, this, std::placeholders::_1)));
  // Instead of sending sampler ids to WorkerEntry, we send them to the Prefetcher which will redirect them
  // to the WorkerEntry.
  do {
    epoch_sync_.Clear();
    if (AllowCacheMiss() && wait_cnt > 0) {
      MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_
                      << " Total number of rows : " << row_cnt_;
    }
    num_cache_miss_ = 0;
    row_cnt_ = 0;
    ++wait_cnt;
    std::vector<row_id_type> keys;
    keys.reserve(rows_per_buffer_);
    std::unique_ptr<DataBuffer> sampler_buffer;
    RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
@@ -70,10 +88,13 @@ Status CacheBase::FetchSamplesToWorkers() {
      TensorRow sample_row;
      RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
      std::shared_ptr<Tensor> sample_ids = sample_row[0];
      // Send the sampler tensor to the other thread for prefetching. We are using a shared pointer so it
      // won't go out of scope until it is really no longer in use.
      RETURN_IF_NOT_OK(sampler_queue_->Add(sample_ids));
      for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
        keys.push_back(*itr);
        ++row_cnt_;
        if (row_cnt_ % rows_per_buffer_ == 0) {
          auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
          RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
          keys.clear();
@@ -90,7 +111,7 @@ Status CacheBase::FetchSamplesToWorkers() {
      io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
    // If this is a repeat but not the last repeat, wait for reset.
    if (!IsLastIteration()) {
      MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << wait_cnt << " Buffer sent " << buf_cnt;
      RETURN_IF_NOT_OK(epoch_sync_.Wait());
    } else {
      // We can break out from the loop.
@@ -101,13 +122,21 @@ Status CacheBase::FetchSamplesToWorkers() {
  // Flow the eof before exit
  RETURN_IF_NOT_OK(
    io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
  // Ask all the workers to quit.
  // Shutdown threads
  std::shared_ptr<Tensor> empty;
  RETURN_IF_NOT_OK(sampler_queue_->Add(std::move(empty)));
  for (int32_t i = 0; i < num_workers_; i++) {
    RETURN_IF_NOT_OK(
      io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
  }
  // Dump the last epoch result (approximately) without waiting for the worker threads to come back.
  if (AllowCacheMiss()) {
    MS_LOG(WARNING) << "Epoch: " << wait_cnt << " Cache Miss : " << num_cache_miss_
                    << " Total number of rows : " << row_cnt_;
  }
  return Status::OK();
}

Status CacheBase::FetchFromCache(int32_t worker_id) {
  int64_t buffer_id = worker_id;
  std::unique_ptr<IOBlock> blk;
@@ -133,23 +162,16 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
    }
    std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
    std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>();
    std::vector<row_id_type> cache_miss;
    cache_miss.reserve(keys.size());
    for (auto row_id : keys) {
      TensorRow row;
      // Block until the row shows up in the pool.
      RETURN_IF_NOT_OK(prefetch_.PopFront(row_id, &row));
      if (row.empty()) {
        if (AllowCacheMiss()) {
          cache_miss.push_back(row_id);
        } else {
          std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
          RETURN_STATUS_UNEXPECTED(errMsg);
        }
      }
      que->push_back(std::move(row));
    }
    db->set_tensor_table(std::move(que));
    if (AllowCacheMiss()) {
@@ -162,12 +184,17 @@ Status CacheBase::FetchFromCache(int32_t worker_id) {
  } while (true);
  return Status::OK();
}

Status CacheBase::RegisterResources() {
  RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(prefetch_queues_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(sampler_queue_->Register(tree_->AllTasks()));
  return Status::OK();
}

CacheBase::~CacheBase() = default;

Status CacheBase::UpdateColumnMapFromCache() {
  Status rc;
  // Get the schema from the server. It may not be there yet. So tolerate the error.
@@ -180,5 +207,77 @@ Status CacheBase::UpdateColumnMapFromCache() {
  }
  return rc;
}

Status CacheBase::Dispatcher() {
  TaskManager::FindMe()->Post();
  int64_t buf_cnt = 0;
  int64_t num_row = 0;
  std::vector<row_id_type> keys;
  keys.reserve(prefetch_size_);
  do {
    keys.clear();
    std::shared_ptr<Tensor> sample_ids;
    RETURN_IF_NOT_OK(sampler_queue_->PopFront(&sample_ids));
    if (sample_ids == nullptr) {
      // A null shared pointer signals it is time to quit.
      // Also signal all prefetchers to quit.
      for (int32_t i = 0; i < num_workers_; i++) {
        RETURN_IF_NOT_OK(
          prefetch_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
      }
      break;
    }
    // Now we distribute the sampler ids to each prefetcher according to the prefetch size.
    for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
      keys.push_back(*itr);
      ++num_row;
      if (num_row % prefetch_size_ == 0) {
        auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
        RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
        keys.clear();
      }
    }
    // Send the remaining sample ids
    if (!keys.empty()) {
      auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
      RETURN_IF_NOT_OK(prefetch_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
    }
  } while (true);
  return Status::OK();
}

Status CacheBase::Prefetcher(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  std::vector<row_id_type> prefetch_keys;
  prefetch_keys.reserve(prefetch_size_);
  do {
    prefetch_keys.clear();
    std::unique_ptr<IOBlock> blk;
    RETURN_IF_NOT_OK(prefetch_queues_[worker_id]->PopFront(&blk));
    RETURN_IF_NOT_OK(blk->GetKeys(&prefetch_keys));
    if (prefetch_keys.empty()) {
      // Empty keys mean time to quit.
      break;
    }
    TensorTable ttbl;
    RETURN_IF_NOT_OK(cache_client_->GetRows(prefetch_keys, &ttbl));
    auto row_it = ttbl.begin();
    for (auto row_id : prefetch_keys) {
      auto &row = *row_it;
      if (row.empty()) {
        if (AllowCacheMiss()) {
          ++num_cache_miss_;
        } else {
          std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
          RETURN_STATUS_UNEXPECTED(errMsg);
        }
      }
      // Put the prefetched row into the pool and wake up any WorkerEntry waiting for the row
      RETURN_IF_NOT_OK(prefetch_.Add(row_id, std::move(row)));
      ++row_it;
    }
  } while (true);
  return Status::OK();
}
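The rendezvous this pipeline relies on is the QueueMap<row_id_type, TensorRow> member prefetch_: Add(key, row) publishes a row under its row id, and PopFront(key, &row) blocks until a row for that key exists. This is what lets FetchFromCache wait for rows a Prefetcher is still fetching. A minimal illustration, using only the two calls that appear in the code above:

QueueMap<row_id_type, TensorRow> pool;
// Prefetcher thread: publish row 42 once it arrives from the server.
//   RETURN_IF_NOT_OK(pool.Add(42, std::move(row)));
// Worker thread: block until row 42 is available, then consume it.
//   TensorRow row;
//   RETURN_IF_NOT_OK(pool.PopFront(42, &row));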
}  // namespace dataset
}  // namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_base_op.h
@@ -16,6 +16,8 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <utility>
@@ -28,8 +30,9 @@
#include "minddata/dataset/engine/datasetops/source/sampler/sampler.h"
#include "minddata/dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/queue_map.h"
#include "minddata/dataset/util/semaphore.h"
#include "minddata/dataset/util/wait_post.h"
namespace mindspore {
namespace dataset {
/// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities.
@@ -82,10 +85,13 @@ class CacheBase : public ParallelOp {
 protected:
  constexpr static int32_t eoe_row_id = -1;
  int64_t row_cnt_;
  std::atomic<int64_t> num_cache_miss_;
  std::shared_ptr<CacheClient> cache_client_;
  WaitPost epoch_sync_;
  int32_t rows_per_buffer_;
  Connector<std::vector<row_id_type>> keys_miss_;
  QueueMap<row_id_type, TensorRow> prefetch_;

  /// \brief Common function to register resources for interrupt
  /// \note Derived should override this function for extra resources to be registered
@@ -103,7 +109,15 @@ class CacheBase : public ParallelOp {
 private:
  constexpr static int32_t connector_capacity_ = 1024;
  int32_t prefetch_size_;
  QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
  QueueList<std::unique_ptr<IOBlock>> prefetch_queues_;
  std::unique_ptr<Queue<std::shared_ptr<Tensor>>> sampler_queue_;

  /// \brief Dispatcher. It distributes the sampler ids to the prefetchers.
  /// \return Status object.
  Status Dispatcher();
  /// \brief Prefetcher. It prefetches rows from the cache server.
  /// \return Status object.
  Status Prefetcher(int32_t worker_id);
};
}  // namespace dataset
}  // namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.cc
@@ -16,8 +16,10 @@
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include <algorithm>
#include <chrono>
#include <functional>
#include <iomanip>
#include <utility>
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/core/constants.h"
#include "minddata/dataset/core/global_context.h"
@@ -41,9 +43,13 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const {
    out << "\n\n";
  }
}

CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
                           std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler)
    : ParallelOp(numWorkers, opConnectorSize, sampler),
      num_cleaners_(numCleaners),
      cache_client_(std::move(cache_client)) {}

Status CacheMergeOp::operator()() {
  // A queue of row ids to let the cleaner send cache-miss rows to the cache server.
  // We don't want a small queue as this will block the parallel op workers.
@@ -62,6 +68,7 @@ Status CacheMergeOp::operator()() {
  TaskManager::FindMe()->Post();
  return Status::OK();
}

// Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait
// until it shows up in the pool.
Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
@@ -82,10 +89,8 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
      RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
      if (row.empty()) {
        auto row_id = row.getId();
        // Block until the row shows up in the pool.
        RETURN_IF_NOT_OK(cache_miss_.PopFront(row_id, &row));
      }
      tbl->push_back(std::move(row));
    }
@@ -97,6 +102,7 @@ Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
  RETURN_IF_NOT_OK(EofReceived(worker_id));
  return Status::OK();
}

Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
  TaskManager::FindMe()->Post();
  // We will simply pop TensorRow from the stream and insert them into the pool and
@@ -123,17 +129,27 @@ Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
          std::string errMsg = "Expect positive row id: " + std::to_string(row_id);
          RETURN_STATUS_UNEXPECTED(errMsg);
        }
        // Technically the number of times this row shows up in the cache miss stream is equal to the
        // number of P() calls. However the cleaner wants it too. So we need an extra copy.
        TensorRowCacheRequest *rq;
        RETURN_IF_NOT_OK(GetRq(row_id, &rq));
        if (rq->GetState() == TensorRowCacheRequest::State::kEmpty) {
          // We will send the request async. But any error we most
          // likely ignore and continue.
          Status rc;
          rc = rq->AsyncSendCacheRequest(cache_client_, row);
          if (rc.IsOk()) {
            RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
          }
        }
        RETURN_IF_NOT_OK(cache_miss_.Add(row_id, std::move(row)));
      }
    }
    RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId));
  }
  return Status::OK();
}

Status CacheMergeOp::Cleaner() {
  TaskManager::FindMe()->Post();
  while (true) {
@@ -142,45 +158,28 @@ Status CacheMergeOp::Cleaner() {
    if (row_id < 0) {
      break;
    }
    // Locate the cache request
    TensorRowCacheRequest *rq;
    RETURN_IF_NOT_OK(GetRq(row_id, &rq));
    // If already flushed, move on to the next one.
    if (rq->GetState() == TensorRowCacheRequest::State::kClean) {
      continue;
    }
    Status rc = rq->CheckCacheResult();
    if (rc.IsError()) {
      // If interrupted, time to quit.
      if (rc.get_code() == StatusCode::kInterrupted) {
        return Status::OK();
      }
      MS_LOG(INFO) << "Cache row not successful: " << rc.ToString();
      // A bad rc should not bring down the pipeline. We will simply continue and
      // change the state back to empty. We don't need a CAS from CLEAN back to EMPTY.
      rq->SetState(TensorRowCacheRequest::State::kEmpty);
    }
  }
  return Status::OK();
}
Status CacheMergeOp::PrepareNodePostAction() {
  // Run any common code from the super class first before adding our own
  // specific logic
  CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children");
@@ -199,6 +198,7 @@ Status CacheMergeOp::PrepareNodePostAction() {
  RETURN_IF_NOT_OK(rc);
  return Status::OK();
}

Status CacheMergeOp::ComputeColMap() {
  CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty");
  if (column_name_id_map().empty()) {
@@ -207,53 +207,13 @@ Status CacheMergeOp::ComputeColMap() {
  CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected");
  return Status::OK();
}

// Builder constructor. Creates the builder object.
CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  build_num_workers_ = cfg->num_parallel_workers();
  build_op_connector_size_ = cfg->op_connector_size();
  build_num_cleaners_ = cfg->num_parallel_workers();
}

// Check if the required parameters are set by the builder.
@@ -311,5 +271,60 @@ Status CacheMergeOp::EofReceived(int32_t worker_id) {
  MS_LOG(DEBUG) << "Cache merge sending eof";
  return DatasetOp::EofReceived(worker_id);
}

Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowCacheRequest **out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  std::unique_lock<std::mutex> lock(mux_);
  auto it = io_request_.find(row_id);
  if (it != io_request_.end()) {
    *out = it->second.GetMutablePointer();
  } else {
    // We will create a new one.
    auto alloc = Services::GetAllocator<TensorRowCacheRequest>();
    auto r = io_request_.emplace(row_id, MemGuard<TensorRowCacheRequest, Allocator<TensorRowCacheRequest>>(alloc));
    if (r.second) {
      auto &mem = r.first->second;
      RETURN_IF_NOT_OK(mem.allocate(1));
      *out = mem.GetMutablePointer();
    } else {
      RETURN_STATUS_UNEXPECTED("Map insert fail.");
    }
  }
  return Status::OK();
}

Status CacheMergeOp::TensorRowCacheRequest::AsyncSendCacheRequest(const std::shared_ptr<CacheClient> &cc,
                                                                  const TensorRow &row) {
  auto expected = State::kEmpty;
  if (st_.compare_exchange_strong(expected, State::kDirty)) {
    // We will do a deep copy but write directly into the CacheRequest protobuf or shared memory
    Status rc;
    cleaner_copy_ =
      std::make_shared<CacheRowRequest>(cc->server_connection_id_, cc->cookie(), cc->SupportLocalClient());
    rc = cleaner_copy_->SerializeCacheRowRequest(cc.get(), row);
    if (rc.IsOk()) {
      // Send the request async. The cleaner will check the return code.
      rc = cc->PushRequest(cleaner_copy_);
    }
    if (rc.IsError()) {
      // Clean up the shared pointer and reset the state back to empty
      cleaner_copy_.reset();
      st_ = State::kEmpty;
    }
  }
  return Status::OK();
}

Status CacheMergeOp::TensorRowCacheRequest::CheckCacheResult() {
  auto expected = State::kDirty;
  if (st_.compare_exchange_strong(expected, State::kClean)) {
    // Success or not, we will release the memory.
    // We simply move it out of the structure and let it go out of scope.
    auto cache_request = std::move(cleaner_copy_);
    RETURN_IF_NOT_OK(cache_request->Wait());
    return Status::OK();
  }
  return Status::OK();
}
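For reviewers, the three-state protocol these two methods implement, read directly off the code above. The atomic compare-and-swap guarantees each transition happens exactly once even with several cache-miss workers and cleaners racing on the same row:

// kEmpty --(AsyncSendCacheRequest: CAS kEmpty -> kDirty)--> kDirty
// kDirty --(CheckCacheResult:      CAS kDirty -> kClean)--> kClean
// kClean --(on send/cache failure: plain store)-----------> kEmpty
//
// A failed CAS simply means another thread already owns that transition,
// so the caller returns OK and moves on.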
}  // namespace dataset
}  // namespace mindspore
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_merge_op.h
mindspore/ccsrc/minddata/dataset/engine/datasetops/cache_op.cc
@@ -142,7 +142,7 @@ Status CacheOp::WaitForCachingAllRows() {
  }
  // Get statistics from the server, and if we are not the one to create the cache,
  // wait until the state changes from build phase to fetch phase.
  CacheServiceStat stat{};
  bool BuildPhaseDone = true;
  do {
    RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
@@ -157,6 +157,7 @@ Status CacheOp::WaitForCachingAllRows() {
  MS_LOG(INFO) << "Number of rows cached: " << num_rows_;
  MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached;
  MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached;
  MS_LOG(INFO) << "Average cache size : " << stat.avg_cache_sz;
  // Now all rows are cached and we have done a sync point checkup. The next phase is
  // to pick up fetch input from the sampler and pass it up to the caller.
  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
mindspore/ccsrc/minddata/dataset/engine/datasetops/dataset_op.cc
@@ -392,6 +392,13 @@ uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
  ss_str = std::regex_replace(ss_str, std::regex("Num workers.*\n"), "");
  ss_str = std::regex_replace(ss_str, std::regex("\\[workers.*\\]"), "");
  // Filter out tcp/ip information so the CRC is not affected by where the cache server runs
  ss_str = std::regex_replace(ss_str, std::regex("Hostname.*\n"), "");
  ss_str = std::regex_replace(ss_str, std::regex("Port.*\n"), "");
  ss_str = std::regex_replace(ss_str, std::regex("Number of rpc workers.*\n"), "");
  ss_str = std::regex_replace(ss_str, std::regex("Prefetch size.*\n"), "");
  ss_str = std::regex_replace(ss_str, std::regex("Local client support.*\n"), "");
  // Filter out Number of rows when generating the check sum
  ss_str = std::regex_replace(ss_str, std::regex("Number of rows.*\n"), "");
mindspore/ccsrc/minddata/dataset/include/status.h
@@ -73,6 +73,7 @@ enum class StatusCode : char {
  kProfilingError = 10,
  kBoundingBoxOutOfBounds = 11,
  kBoundingBoxInvalidShape = 12,
  kSyntaxError = 13,
  // Make this error code the last one. Add new error code above it.
  kUnexpectedError = 127
};
mindspore/ccsrc/minddata/dataset/util/allocator.h
@@ -168,9 +168,9 @@ class MemGuard {
  size_t GetSizeInBytes() const { return n_ * sizeof(T); }

 private:
  allocator alloc_;
  std::unique_ptr<T[]> ptr_;
  size_t n_;
};
}  // namespace dataset
}  // namespace mindspore
mindspore/ccsrc/minddata/dataset/util/arena.h
mindspore/ccsrc/minddata/dataset/util/cache_pool.cc
mindspore/ccsrc/minddata/dataset/util/cache_pool.h
@@ -82,6 +82,7 @@ class CachePool : public Service {
  struct CacheStat {
    int64_t num_mem_cached;
    int64_t num_disk_cached;
    int64_t average_cache_sz;
  };
  /// \brief Constructor
mindspore/ccsrc/minddata/dataset/util/queue_map.h
0 → 100644
mindspore/ccsrc/minddata/dataset/util/services.cc
mindspore/ccsrc/minddata/dataset/util/services.h
mindspore/ccsrc/minddata/dataset/util/slice.h
@@ -86,6 +86,7 @@ class ReadableSlice {
class WritableSlice : public ReadableSlice {
 public:
  friend class StorageContainer;
  friend class CacheService;
  /// \brief Default constructor
  WritableSlice() : ReadableSlice(), mutable_data_(nullptr) {}
  /// \brief This form of a constructor takes a pointer and its size.
mindspore/ccsrc/minddata/dataset/util/status.cc
mindspore/ccsrc/minddata/dataset/util/status.h
mindspore/ccsrc/minddata/dataset/util/task_manager.cc
mindspore/ccsrc/minddata/dataset/util/task_manager.h
mindspore/dataset/engine/cache_client.py
tests/ut/cpp/dataset/cache_op_test.cc
tests/ut/python/dataset/test_cache_map.py
tests/ut/python/dataset/test_cache_nomap.py