Commit 7c1bc519 — Initial Drop of CacheOp Phase I
Repository: magicwindyyd / mindspore (forked from MindSpore / mindspore)
Author: Jesse Lee
Date: April 29, 2020
Parent: be9b3c53

82 changed files with 5,868 additions and 374 deletions.
Changed files (82):

mindspore/ccsrc/dataset/CMakeLists.txt  +7 -1
mindspore/ccsrc/dataset/api/de_pipeline.cc  +150 -22
mindspore/ccsrc/dataset/api/de_pipeline.h  +12 -0
mindspore/ccsrc/dataset/api/python_bindings.cc  +7 -0
mindspore/ccsrc/dataset/engine/CMakeLists.txt  +6 -3
mindspore/ccsrc/dataset/engine/cache/CMakeLists.txt  +8 -0
mindspore/ccsrc/dataset/engine/cache/cache_client.cc  +208 -0
mindspore/ccsrc/dataset/engine/cache/cache_client.h  +141 -0
mindspore/ccsrc/dataset/engine/cache/cache_request.cc  +223 -0
mindspore/ccsrc/dataset/engine/cache/cache_request.h  +225 -0
mindspore/ccsrc/dataset/engine/cache/cache_server.cc  +252 -0
mindspore/ccsrc/dataset/engine/cache/cache_server.h  +98 -0
mindspore/ccsrc/dataset/engine/cache/cache_service.cc  +265 -0
mindspore/ccsrc/dataset/engine/cache/cache_service.h  +143 -0
mindspore/ccsrc/dataset/engine/cache/de_tensor.fbs  +81 -0
mindspore/ccsrc/dataset/engine/data_buffer.cc  +2 -12
mindspore/ccsrc/dataset/engine/data_buffer.h  +9 -15
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt  +5 -1
mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.cc  +185 -0
mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.h  +108 -0
mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.cc  +130 -0
mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.h  +122 -0
mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.cc  +301 -0
mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.h  +196 -0
mindspore/ccsrc/dataset/engine/datasetops/cache_op.cc  +219 -0
mindspore/ccsrc/dataset/engine/datasetops/cache_op.h  +168 -0
mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc  +17 -30
mindspore/ccsrc/dataset/engine/datasetops/concat_op.h  +0 -6
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc  +30 -37
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h  +34 -32
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc  +26 -23
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h  +19 -12
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc  +7 -0
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h  +6 -0
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc  +7 -0
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h  +6 -0
mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc  +7 -0
mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h  +6 -0
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc  +7 -0
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h  +6 -0
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc  +7 -0
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h  +6 -0
mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc  +7 -10
mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h  +6 -6
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc  +13 -16
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h  +5 -0
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc  +6 -0
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h  +6 -0
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc  +0 -6
mindspore/ccsrc/dataset/engine/datasetops/take_op.h  +0 -6
mindspore/ccsrc/dataset/engine/execution_tree.cc  +19 -28
mindspore/ccsrc/dataset/engine/execution_tree.h  +0 -20
mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt  +3 -0
mindspore/ccsrc/dataset/engine/opt/pass.cc  +80 -0
mindspore/ccsrc/dataset/engine/opt/pass.h  +50 -0
mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.cc  +161 -0
mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.h  +98 -0
mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.cc  +181 -0
mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.h  +138 -0
mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.cc  +108 -0
mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.h  +79 -0
mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc  +17 -1
mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h  +12 -0
mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc  +2 -0
mindspore/ccsrc/dataset/util/allocator.h  +4 -3
mindspore/ccsrc/dataset/util/cache_pool.cc  +0 -5
mindspore/ccsrc/dataset/util/services.cc  +22 -6
mindspore/ccsrc/dataset/util/services.h  +6 -2
mindspore/dataset/__init__.py  +1 -0
mindspore/dataset/engine/cache_client.py  +49 -0
mindspore/dataset/engine/datasets.py  +78 -25
mindspore/dataset/engine/serializer_deserializer.py  +3 -1
mindspore/dataset/engine/validators.py  +45 -21
mindspore/dataset/text/validators.py  +6 -6
mindspore/dataset/transforms/vision/validators.py  +7 -7
tests/ut/cpp/dataset/c_api_test.cc  +1 -1
tests/ut/cpp/dataset/cache_op_test.cc  +579 -0
tests/ut/data/dataset/golden/cache_map_01_result.npz  +0 -0
tests/ut/data/dataset/golden/cache_map_02_result.npz  +0 -0
tests/ut/python/dataset/test_cache_map.py  +157 -0
tests/ut/python/dataset/test_cache_nomap.py  +429 -0
tests/ut/python/dataset/test_random_dataset.py  +28 -10
mindspore/ccsrc/dataset/CMakeLists.txt

```diff
@@ -47,6 +47,8 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/ccsrc/dataset/include)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ORIGIN:$ORIGIN/lib")
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
+
+ms_build_flatbuffers("engine/cache/de_tensor.fbs" ${CMAKE_CURRENT_SOURCE_DIR} generated_engine_files ${CMAKE_BINARY_DIR})
 ################## Include sub-modules ###############################
 add_subdirectory(util)
 add_subdirectory(core)
@@ -55,7 +57,7 @@ add_subdirectory(engine)
 add_subdirectory(api)
 add_subdirectory(text)
 ######################################################################
 add_dependencies(core utils)
 add_dependencies(utils core)
 add_dependencies(kernels-image core)
 add_dependencies(kernels-data core)
 add_dependencies(kernels core)
@@ -89,6 +91,8 @@ set(submodules
     $<TARGET_OBJECTS:engine-perf>
     $<TARGET_OBJECTS:engine-datasetops>
     $<TARGET_OBJECTS:engine-opt>
+    $<TARGET_OBJECTS:engine-cache-client>
+    $<TARGET_OBJECTS:engine-cache-server>
     $<TARGET_OBJECTS:engine>
     $<TARGET_OBJECTS:text>
     $<TARGET_OBJECTS:text-kernels>
@@ -106,6 +110,8 @@ else ()
     add_library(_c_dataengine SHARED ${submodules})
 endif ()
+
+add_dependencies(_c_dataengine generated_engine_files)
 set_target_properties(_c_dataengine PROPERTIES
     PREFIX "${PYTHON_MODULE_PREFIX}"
     SUFFIX "${PYTHON_MODULE_EXTENSION}"
```
mindspore/ccsrc/dataset/api/de_pipeline.cc

```diff
@@ -21,8 +21,10 @@
 #include "common/utils.h"
 #include "dataset/core/tensor.h"
+#include "dataset/engine/cache/cache_client.h"
 #include "dataset/engine/dataset_iterator.h"
 #include "dataset/engine/datasetops/bucket_batch_by_length_op.h"
+#include "dataset/engine/datasetops/cache_op.h"
 #include "dataset/engine/datasetops/filter_op.h"
 #include "dataset/engine/datasetops/source/celeba_op.h"
 #include "dataset/engine/datasetops/source/cifar_op.h"
@@ -34,6 +36,7 @@
 #include "dataset/engine/datasetops/source/random_data_op.h"
 #include "dataset/engine/datasetops/source/text_file_op.h"
 #include "dataset/engine/datasetops/source/voc_op.h"
+#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
 #include "dataset/kernels/py_func_op.h"
 #include "dataset/util/random.h"
 #include "dataset/util/status.h"
@@ -441,6 +444,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
   MapOp::Builder map_builder;
   std::vector<std::shared_ptr<TensorOp>> tensor_op_list;
   std::vector<std::string> project_columns;
+  std::shared_ptr<CacheClient> cache_client = nullptr;
+  int num_workers = 0;
   if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set.\n");
@@ -456,7 +461,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
     } else if (key == "columns_order") {
       project_columns = ToStringVector(value);
     } else if (key == "num_parallel_workers") {
-      (void)map_builder.SetNumWorkers(ToInt(value));
+      num_workers = ToInt(value);
+      (void)map_builder.SetNumWorkers(num_workers);
     } else if (key == "prefetch_size") {
       (void)map_builder.SetOpConnectorSize(ToInt(value));
     } else if (key == "operations") {
@@ -477,6 +483,8 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
       }
       if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set.");
       (void)map_builder.SetTensorFuncs(std::move(tensor_op_list));
+    } else if (key == "cache") {
+      cache_client = value.cast<std::shared_ptr<CacheClient>>();
     } else {
       RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key);
     }
@@ -499,6 +507,15 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *
     *bottom = map_op;
   }
+
+  // Additionally, add a cache if required. This will go over top of the project op if one
+  // was created, otherwise it goes over top of the map op
+  if (cache_client) {
+    std::shared_ptr<DatasetOp> cache_op = nullptr;
+    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, *top, &cache_op));
+    *top = cache_op;
+    *bottom = map_op;
+  }
+
   return Status::OK();
 }
@@ -809,6 +826,9 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
                                    std::shared_ptr<DatasetOp> *bottom) {
   // Required arguments
   std::vector<std::string> files_list;
+  std::shared_ptr<CacheClient> cache_client = nullptr;
+  std::shared_ptr<Sampler> sampler = nullptr;
+  int num_workers = 0;
   std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
   if (!args["dataset_files"].is_none()) {
     files_list = ToStringVector(args["dataset_files"]);
@@ -828,7 +848,8 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
     py::handle value = arg.second;
     if (!value.is_none()) {
       if (key == "num_parallel_workers") {
-        (void)builder->SetNumWorkers(ToInt(value));
+        num_workers = ToInt(value);
+        (void)builder->SetNumWorkers(num_workers);
       } else if (key == "columns_list") {
         columns_to_load = ToStringVector(value);
         (void)builder->SetColumnsToLoad(columns_to_load);
@@ -848,6 +869,11 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
         (void)builder->SetDeviceId(ToInt(value));
       } else if (key == "shard_equal_rows") {
         (void)builder->SetShardEqualRows(ToBool(value));
+      } else if (key == "cache") {
+        cache_client = value.cast<std::shared_ptr<CacheClient>>();
+      } else if (key == "sampler") {
+        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
+        sampler = create().cast<std::shared_ptr<Sampler>>();
       }
     }
   }
@@ -860,12 +886,27 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
     }
     (void)builder->SetDataSchema(std::move(schema));
   }
+
+  // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
+  // because TFReaderOp is a non-mappable dataset that does not support sampling.
+  // However, if a cache operator is injected at some other place higher in the tree, that cache can
+  // inherit this sampler from the leaf, providing sampling support from the caching layer.
+  // That is why we save the sampler here in a leaf node that does not use sampling.
+  if (sampler) {
+    (void)builder->SetSampler(std::move(sampler));
+  } else if (cache_client) {
+    int64_t num_samples = 0;
+    int64_t start_index = 0;
+    sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+    (void)builder->SetSampler(std::move(sampler));
+  }
+
   std::shared_ptr<TFReaderOp> tf_op;
   RETURN_IF_NOT_OK(builder->Build(&tf_op));
   RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op));
   *top = tf_op;
-  if (shuffle_required) {
+  if (!cache_client && shuffle_required) {
     const bool estimate = true;
     const int64_t workers = 8;
     std::shared_ptr<DatasetOp> shuffle_op = nullptr;
@@ -882,6 +923,15 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
     *bottom = tf_op;
   }
+
+  // Add a cache op over this op if required and update the output subtree (top/bottom)
+  if (cache_client) {
+    // Note, it is not allowed to have both shuffle and cache
+    std::shared_ptr<DatasetOp> cache_op = nullptr;
+    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, tf_op, &cache_op));
+    *top = cache_op;
+    *bottom = tf_op;
+  }
+
   return Status::OK();
 }
@@ -906,6 +956,8 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
     std::string err_msg = "Error: No dataset path specified";
     RETURN_STATUS_UNEXPECTED(err_msg);
   }
+  int num_workers = 0;
+  std::shared_ptr<CacheClient> cache_client = nullptr;
   std::shared_ptr<ImageFolderOp::Builder> builder = std::make_shared<ImageFolderOp::Builder>();
   (void)builder->SetImageFolderDir(ToString(args["dataset_dir"]));
@@ -915,7 +967,8 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
     py::handle value = arg.second;
     if (!value.is_none()) {
       if (key == "num_parallel_workers") {
-        (void)builder->SetNumWorkers(ToInt(value));
+        num_workers = ToInt(value);
+        (void)builder->SetNumWorkers(num_workers);
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
         std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
@@ -926,12 +979,27 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
         (void)builder->SetClassIndex(ToStringMap(value));
       } else if (key == "decode") {
         (void)builder->SetDecode(ToBool(value));
+      } else if (key == "cache") {
+        cache_client = value.cast<std::shared_ptr<CacheClient>>();
       }
     }
   }
-  std::shared_ptr<ImageFolderOp> op;
-  RETURN_IF_NOT_OK(builder->Build(&op));
-  *top = op;
+  std::shared_ptr<ImageFolderOp> if_op;
+  RETURN_IF_NOT_OK(builder->Build(&if_op));
+  RETURN_IF_NOT_OK(tree_->AssociateNode(if_op));
+  *top = if_op;
+
+  // Additionally, add a cache if required.
+  // Note that this cache op is only acting as a place holder for the caching position
+  // within the tree. Later, a pre-pass will execute a tree transform to set up the actual
+  // caching logic in the tree.
+  if (cache_client) {
+    std::shared_ptr<DatasetOp> cache_op = nullptr;
+    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, if_op, &cache_op));
+    *top = cache_op;
+    *bottom = if_op;
+  }
+
   return Status::OK();
 }
@@ -1130,9 +1198,12 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
                                      std::shared_ptr<DatasetOp> *bottom) {
   // Required arguments
   RandomDataOp::Builder builder;
+  std::shared_ptr<CacheClient> cache_client = nullptr;
+  std::shared_ptr<Sampler> sampler = nullptr;
+  int num_workers = 0;
-  if (args["num_samples"].is_none()) {
-    std::string err_msg = "Error: num_samples is a required argument";
+  if (args["total_rows"].is_none()) {
+    std::string err_msg = "Error: total_rows is a required argument";
     RETURN_STATUS_UNEXPECTED(err_msg);
   }
   std::vector<std::string> columns_to_load;
@@ -1141,16 +1212,23 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
   for (auto arg : args) {
     std::string key = py::str(arg.first);
     py::handle value = arg.second;
-    if (key == "num_parallel_workers") {
-      (void)builder.SetNumWorkers(ToInt(value));
-    } else if (key == "schema_file_path" || key == "schema_json_string") {
-      schema_exists = true;
-    } else if (key == "columns_list") {
-      columns_to_load = ToStringVector(value);
-    } else if (key == "num_samples") {
-      // This is not sampling here. The random data op needs to know how much data to
-      // generate. It does not currently support sampling.
-      (void)builder.SetTotalRows(ToInt(value));
+    if (!value.is_none()) {
+      if (key == "num_parallel_workers") {
+        num_workers = ToInt(value);
+        (void)builder.SetNumWorkers(num_workers);
+      } else if (key == "schema_file_path" || key == "schema_json_string") {
+        schema_exists = true;
+      } else if (key == "columns_list") {
+        columns_to_load = ToStringVector(value);
+      } else if (key == "total_rows") {
+        // This is not sampling here. The random data op needs to know how much data to generate.
+        (void)builder.SetTotalRows(ToInt(value));
+      } else if (key == "cache") {
+        cache_client = value.cast<std::shared_ptr<CacheClient>>();
+      } else if (key == "sampler") {
+        auto create = py::reinterpret_borrow<py::object>(value).attr("create");
+        sampler = create().cast<std::shared_ptr<Sampler>>();
+      }
     }
   }
   if (schema_exists) {
@@ -1162,9 +1240,34 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
     }
     (void)builder.SetDataSchema(std::move(schema));
   }
-  std::shared_ptr<RandomDataOp> op;
-  RETURN_IF_NOT_OK(builder.Build(&op));
-  *top = op;
+
+  // If the user gave a sampler, but they did not ask for a cache, then by itself this is not allowed
+  // because RandomDataOp is a non-mappable dataset that does not support sampling.
+  // However, if a cache operator is injected at some other place higher in the tree, that cache can
+  // inherit this sampler from the leaf, providing sampling support from the caching layer.
+  // That is why we save the sampler here in a leaf node that does not use sampling.
+  if (sampler) {
+    (void)builder.SetSampler(std::move(sampler));
+  } else if (cache_client) {
+    int64_t num_samples = 0;
+    int64_t start_index = 0;
+    sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+    (void)builder.SetSampler(std::move(sampler));
+  }
+
+  std::shared_ptr<RandomDataOp> random_op = nullptr;
+  RETURN_IF_NOT_OK(builder.Build(&random_op));
+  RETURN_IF_NOT_OK(tree_->AssociateNode(random_op));
+  *top = random_op;
+
+  // Add a cache op over this op if required and update the output subtree (top/bottom)
+  if (cache_client) {
+    std::shared_ptr<DatasetOp> cache_op = nullptr;
+    RETURN_IF_NOT_OK(AddCacheOp(cache_client, num_workers, random_op, &cache_op));
+    *top = cache_op;
+    *bottom = random_op;
+  }
+
   return Status::OK();
 }
@@ -1425,6 +1528,31 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
   return Status::OK();
 }
+
+// Helper function to inject the cache operator over top of the current operation being built.
+Status DEPipeline::AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers,
+                              std::shared_ptr<DatasetOp> input_op, std::shared_ptr<DatasetOp> *cache_op) {
+  std::shared_ptr<CacheOp> new_cache_op = nullptr;
+  CacheOp::Builder cache_builder;
+  // use the same number of workers as the leaf. We need some optimization here, the user does not
+  // give the cache op number of workers directly.
+  if (num_workers != 0) {
+    (void)cache_builder.SetNumWorkers(num_workers);
+  }
+  (void)cache_builder.SetClient(cache_client);
+  RETURN_IF_NOT_OK(cache_builder.Build(&new_cache_op));
+  RETURN_IF_NOT_OK(tree_->AssociateNode(new_cache_op));
+  RETURN_IF_NOT_OK(new_cache_op->AddChild(input_op));
+  // We have now created:
+  //
+  // CacheOp
+  //   |
+  // input_op
+  //
+  *cache_op = new_cache_op;
+  return Status::OK();
+}
+
 // Helper function to inject a shuffle operator over top of the current operation being built.
 Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr<DatasetOp> input_op,
                                 std::shared_ptr<DatasetOp> *shuffle_op) {
```
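The AddCacheOp helper above always produces the same two-node shape: the new CacheOp adopts the op being built as its child, the caller's *top moves up to the cache, and *bottom stays at the original op. A standalone sketch of that injection pattern follows; the simplified Node type is an assumption for illustration only, not the commit's DatasetOp API.

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Minimal stand-in for a dataset tree node with the AddChild edge used above.
struct Node {
  std::string name;
  std::vector<std::shared_ptr<Node>> children;
  explicit Node(std::string n) : name(std::move(n)) {}
  void AddChild(const std::shared_ptr<Node> &c) { children.push_back(c); }
};

int main() {
  auto leaf = std::make_shared<Node>("TFReaderOp");
  std::shared_ptr<Node> top = leaf;  // *top before injection
  // Inject a cache over the current top, as AddCacheOp does via CacheOp::Builder.
  auto cache = std::make_shared<Node>("CacheOp");
  cache->AddChild(top);
  top = cache;  // the subtree is now CacheOp -> TFReaderOp; *bottom stays at the leaf
  std::cout << top->name << " -> " << top->children[0]->name << "\n";
}
```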
mindspore/ccsrc/dataset/api/de_pipeline.h

```diff
@@ -35,6 +35,8 @@ namespace mindspore {
 namespace dataset {
 using DsOpPtr = std::shared_ptr<DatasetOp>;
 
+class CacheClient;
+
 // enum for the dataset operator names
 enum OpName {
   kShuffle,
@@ -181,6 +183,16 @@ class DEPipeline {
   static Status ParsePadInfo(py::handle value, PadInfo *pad_info);
 
+  /// \brief Helper function to inject a cache operator over top of the current operation being built.
+  /// \param[in] cache_client The client to use for caching
+  /// \param[in] num_workers The number of workers to use in the cache op
+  /// \param[in] input_op The operator to build the cache on top of
+  /// \param[out] cache_op The top node of the created subtree (subtree contains two nodes). In this case it will be
+  ///     the cache operator
+  /// \return Status return code
+  Status AddCacheOp(std::shared_ptr<CacheClient> cache_client, int num_workers, std::shared_ptr<DatasetOp> input_op,
+                    std::shared_ptr<DatasetOp> *cache_op);
+
   /// \brief Helper function to inject a shuffle operator over top of the current operation being built.
   /// \param[in] shuffle_size The size to use in the shuffle buffer
   /// \param[in] input_op The operator to build shuffle on top of
```
mindspore/ccsrc/dataset/api/python_bindings.cc

```diff
@@ -35,6 +35,7 @@
 #include "dataset/engine/datasetops/source/text_file_op.h"
 #include "dataset/engine/datasetops/source/tf_reader_op.h"
 #include "dataset/engine/datasetops/source/voc_op.h"
+#include "dataset/engine/cache/cache_client.h"
 #include "dataset/engine/gnn/graph.h"
 #include "dataset/engine/jagged_connector.h"
 #include "dataset/kernels/data/concatenate_op.h"
@@ -768,6 +769,11 @@ void bindInfoObjects(py::module *m) {
     .def("get_batch_num", &BatchOp::CBatchInfo::get_batch_num);
 }
+
+void bindCacheClient(py::module *m) {
+  (void)py::class_<CacheClient, std::shared_ptr<CacheClient>>(*m, "CacheClient")
+    .def(py::init<uint32_t, uint64_t, bool>());
+}
+
 void bindVocabObjects(py::module *m) {
   (void)py::class_<Vocab, std::shared_ptr<Vocab>>(*m, "Vocab")
     .def(py::init<>())
@@ -939,6 +945,7 @@ PYBIND11_MODULE(_c_dataengine, m) {
   bindSamplerOps(&m);
   bindDatasetOps(&m);
   bindInfoObjects(&m);
+  bindCacheClient(&m);
   bindVocabObjects(&m);
   bindGraphData(&m);
   bindDependIcuTokenizerOps(&m);
```
...
mindspore/ccsrc/dataset/engine/CMakeLists.txt
浏览文件 @
7c1bc519
...
...
@@ -2,6 +2,7 @@ add_subdirectory(datasetops)
add_subdirectory
(
opt
)
add_subdirectory
(
gnn
)
add_subdirectory
(
perf
)
add_subdirectory
(
cache
)
if
(
ENABLE_TDTQUE
)
add_subdirectory
(
tdt
)
endif
()
...
...
@@ -17,7 +18,9 @@ add_library(engine OBJECT
target_include_directories
(
engine PRIVATE
${
pybind11_INCLUDE_DIRS
}
)
if
(
ENABLE_TDTQUE
)
add_dependencies
(
engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf
)
else
()
add_dependencies
(
engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf
)
add_dependencies
(
engine engine-datasetops engine-datasetops-source engine-tdt engine-opt engine-gnn engine-perf
engine-cache-client engine-cache-server
)
else
()
add_dependencies
(
engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf
engine-cache-client engine-cache-server
)
endif
()
mindspore/ccsrc/dataset/engine/cache/CMakeLists.txt (new file, mode 100644)

```cmake
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-cache-client OBJECT
    cache_client.cc
    cache_request.cc)
add_library(engine-cache-server OBJECT
    cache_service.cc
    cache_server.cc)
```
mindspore/ccsrc/dataset/engine/cache/cache_client.cc (new file, mode 100644)

```cpp
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iomanip>
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/cache/cache_request.h"
#include "dataset/util/bit.h"

namespace mindspore {
namespace dataset {
// Constructor
CacheClient::CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill)
    : server_connection_id_(0), session_id_(session_id), cache_crc_(0), cache_mem_sz_(cache_mem_sz), spill_(spill) {}

// print method for display cache details
void CacheClient::Print(std::ostream &out) const {
  out << "  Session id: " << session_id_ << "\n  Cache crc: " << cache_crc_
      << "\n  Server cache id: " << server_connection_id_ << "\n  Cache mem size: " << cache_mem_sz_
      << "\n  Spilling: " << std::boolalpha << spill_;
}

Status CacheClient::WriteRow(const TensorRow &row, row_id_type *row_id_from_server) const {
  CacheRowRequest rq(server_connection_id_, cookie());
  RETURN_IF_NOT_OK(rq.SerializeCacheRowRequest(row));
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  RETURN_IF_NOT_OK(rq.Wait());
  if (row_id_from_server != nullptr) {
    *row_id_from_server = rq.GetRowIdAfterCache();
  }
  return Status::OK();
}

Status CacheClient::WriteBuffer(std::unique_ptr<DataBuffer> &&in) const {
  std::unique_ptr<DataBuffer> db_ptr = std::move(in);
  auto num_rows = db_ptr->NumRows();
  std::vector<TensorRow> all_rows;
  if (num_rows > 0) {
    all_rows.reserve(num_rows);
    // Break down the DataBuffer into TensorRow. We will send the requests async
    // and then do a final wait.
    MemGuard<CacheRowRequest> rq_arr;
    RETURN_IF_NOT_OK(rq_arr.allocate(num_rows, server_connection_id_, cookie()));
    CacheServer &cs = CacheServer::GetInstance();
    for (auto i = 0; i < num_rows; ++i) {
      TensorRow row;
      auto rq = rq_arr[i];
      RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
      RETURN_IF_NOT_OK(rq->SerializeCacheRowRequest(row));
      RETURN_IF_NOT_OK(cs.PushRequest(rq));
      // We can't let row go out of scope. Otherwise it will free all the tensor memory.
      // So park it in the vector. When this function go out of scope, its memory
      // will be freed.
      all_rows.push_back(std::move(row));
    }
    // Now we wait for the requests to be done.
    for (auto i = 0; i < num_rows; ++i) {
      auto rq = rq_arr[i];
      RETURN_IF_NOT_OK(rq->Wait());
    }
  }
  return Status::OK();
}

Status CacheClient::GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const {
  RETURN_UNEXPECTED_IF_NULL(out);
  BatchFetchRequest rq(server_connection_id_, row_id);
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  RETURN_IF_NOT_OK(rq.Wait());
  RETURN_IF_NOT_OK(rq.RestoreRows(out));
  return Status::OK();
}

Status CacheClient::CreateCache(uint32_t tree_crc, bool generate_id) {
  UniqueLock lck(&mux_);
  // To create a cache, we identify ourself at the client by:
  // - the shared session id
  // - a crc for the tree nodes from the cache downward
  // Pack these 2 into a single 64 bit request id
  //
  // Consider this example:
  // tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch
  // tree2: cifar10 --> map(rotate) --> cache (session id = 1, crc = 456) --> batch
  // These are different trees in a single session, but the user wants to share the cache.
  // This is not allowed because the data of these caches are different.
  //
  // Consider this example:
  // tree1: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> batch
  // tree2: tfreader --> map(decode) --> cache (session id = 1, crc = 123) --> map(rotate) --> batch
  // These are different trees in the same session, but the cached data is the same, so it is okay
  // to allow the sharing of this cache between these pipelines.

  // The CRC is computed by the tree prepare phase and passed to this function when creating the cache.
  // If we already have a server_connection_id_, then it means this same cache client has already been used
  // to create a cache and some other tree is trying to use the same cache.
  // That is allowed, however the crc better match!
  if (server_connection_id_) {
    if (cache_crc_ != tree_crc) {
      RETURN_STATUS_UNEXPECTED("Attempt to re-use a cache for a different tree!");
    }
    // Check the state of the server. For non-mappable case where there is a build phase and a fetch phase, we should
    // skip the build phase.
    lck.Unlock();  // GetStat will grab the mutex again. So unlock it to prevent deadlock.
    CacheClient::ServiceStat stat{};
    RETURN_IF_NOT_OK(GetStat(&stat));
    if (stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase)) {
      return Status(StatusCode::kDuplicateKey, __LINE__, __FILE__, "Not an error and we should bypass the build phase");
    }
  } else {
    cache_crc_ = tree_crc;  // It's really a new cache we're creating so save our crc in the client
    // Combine the session and crc. This will form our client cache identifier.
    connection_id_type connection_identification = (static_cast<uint64_t>(session_id_) << 32) | cache_crc_;
    // Now execute the cache create request using this identifier and other configs
    BaseRequest::CreateCacheFlag createFlag = BaseRequest::CreateCacheFlag::kNone;
    if (spill_) {
      createFlag |= BaseRequest::CreateCacheFlag::kSpillToDisk;
    }
    if (generate_id) {
      createFlag |= BaseRequest::CreateCacheFlag::kGenerateRowId;
    }
    CreationCacheRequest rq(connection_identification, cache_mem_sz_, createFlag);
    RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
    Status rc = rq.Wait();
    if (rc.IsOk() || rc.get_code() == StatusCode::kDuplicateKey) {
      server_connection_id_ = rq.GetServerConnectionId();
      if (rc.IsOk()) {
        // The 1st guy creating the cache will get a cookie back.
        // But this object may be shared among pipelines and we don't want
        // overwrite it.
        cookie_ = rq.cookie();
      }
    }
    // We are not resetting the Duplicate key return code. We are passing it back to the CacheOp. This will tell the
    // CacheOp to bypass the build phase.
    return rc;
  }
  return Status::OK();
}

Status CacheClient::PurgeCache() {
  UniqueLock lck(&mux_);
  PurgeCacheRequest rq(server_connection_id_);
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  return rq.Wait();
}

Status CacheClient::DestroyCache() {
  UniqueLock lck(&mux_);
  DestroyCacheRequest rq(server_connection_id_);
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  return rq.Wait();
}

Status CacheClient::GetStat(ServiceStat *stat) {
  SharedLock lck(&mux_);
  RETURN_UNEXPECTED_IF_NULL(stat);
  GetStatRequest rq(server_connection_id_);
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  RETURN_IF_NOT_OK(rq.Wait());
  stat->num_disk_cached = rq.GetNumDiskCached();
  stat->num_mem_cached = rq.GetNumMemCached();
  stat->min_row_id = rq.GetMinRowId();
  stat->max_row_id = rq.GetMaxRowId();
  stat->cache_service_state = rq.GetState();
  return Status::OK();
}

Status CacheClient::CacheSchema(const std::unordered_map<std::string, int32_t> &map) {
  SharedLock lck(&mux_);
  CacheSchemaRequest rq(server_connection_id_);
  RETURN_IF_NOT_OK(rq.SerializeCacheSchemaRequest(map));
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  RETURN_IF_NOT_OK(rq.Wait());
  return Status::OK();
}

Status CacheClient::FetchSchema(std::unordered_map<std::string, int32_t> *map) {
  SharedLock lck(&mux_);
  RETURN_UNEXPECTED_IF_NULL(map);
  FetchSchemaRequest rq(server_connection_id_);
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  RETURN_IF_NOT_OK(rq.Wait());
  *map = rq.GetColumnMap();
  return Status::OK();
}

Status CacheClient::BuildPhaseDone() const {
  SharedLock lck(&mux_);
  BuildPhaseDoneRequest rq(server_connection_id_, cookie());
  RETURN_IF_NOT_OK(CacheServer::GetInstance().PushRequest(&rq));
  RETURN_IF_NOT_OK(rq.Wait());
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
```
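CreateCache above packs the user-visible session id and the tree crc into one 64-bit connection id, which is what makes the sharing rules in its comment concrete: two pipelines share a cache exactly when both halves match. A small self-contained illustration of that packing (the session/crc values are hypothetical examples taken from the comment):

```cpp
#include <cstdint>
#include <iostream>

// High 32 bits: user-assigned session id; low 32 bits: crc of the tree
// from the cache node downward, exactly as computed in CreateCache.
uint64_t MakeConnectionId(uint32_t session_id, uint32_t cache_crc) {
  return (static_cast<uint64_t>(session_id) << 32) | cache_crc;
}

int main() {
  // Same session, different tree crcs => different connection ids (no sharing).
  std::cout << std::hex << MakeConnectionId(1, 123) << "\n";  // 10000007b
  std::cout << std::hex << MakeConnectionId(1, 456) << "\n";  // 1000001c8
  // Same session and same crc => identical id, so the pipelines share one cache.
  std::cout << (MakeConnectionId(1, 123) == MakeConnectionId(1, 123)) << "\n";  // 1
}
```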
mindspore/ccsrc/dataset/engine/cache/cache_client.h (new file, mode 100644)

```cpp
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_CACHE_CLIENT_H_
#define DATASET_ENGINE_CACHE_CLIENT_H_

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "./de_tensor_generated.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/cache/cache_server.h"
#include "dataset/util/lock.h"

namespace mindspore {
namespace dataset {
/// \brief A CacheClient is a bridge between a DatasetOp and a CacheServer. All communications are through
/// a CacheClient. Typical tasks include creating a cache service, caching a data buffer, restoring previously
/// cached rows, etc.
class CacheClient {
 public:
  /// \brief Constructor
  /// \param session_id A user assigned session id for the current pipeline
  /// \param cache_mem_sz Size of the memory set aside for the row caching. 0 for unlimited
  /// \param spill Spill to disk if out of memory
  CacheClient(uint32_t session_id, uint64_t cache_mem_sz, bool spill);

  /// \brief Destructor
  ~CacheClient() = default;

  /// \brief Getter function for returning the current session id
  /// \return session id
  uint64_t session_id() const { return session_id_; }

  /// \brief Send a TensorRow to the cache server
  /// \param[in] row
  /// \param[out] row_id_from_server Optional. The row id assigned by the server for non-mappable dataset
  /// \return return code
  Status WriteRow(const TensorRow &row, row_id_type *row_id_from_server = nullptr) const;

  /// \brief Send a DataBuffer to the cache server
  /// \param in Unique pointer of the DataBuffer to be cached
  /// \return return code
  Status WriteBuffer(std::unique_ptr<DataBuffer> &&in) const;

  /// \brief Fetch a list of rows from the cache server. An empty TensorRow will be returned if there is
  /// any cache miss
  /// \param row_id A vector of row id's
  /// \param out A TensorTable of TensorRows.
  /// \return return code
  Status GetRows(const std::vector<row_id_type> &row_id, TensorTable *out) const;

  /// \brief Create a cache.
  /// \param tree_crc A crc that was generated during tree prepare phase
  /// \param generate_id Let the cache service generate row id
  /// \return Status object
  Status CreateCache(uint32_t tree_crc, bool generate_id);

  /// \brief Purge a cache. Cache can be reused after reset.
  /// \return Status object
  Status PurgeCache();

  /// \brief Destroy a cache. Like Purge but the cache is deleted and can't be reused.
  /// \return Status object
  Status DestroyCache();

  /// \brief Get the statistics from a cache.
  /// \param[in/out] Pointer to a pre-allocated ServiceStat object
  /// \return Status object
  struct ServiceStat {
    int64_t num_mem_cached;
    int64_t num_disk_cached;
    row_id_type min_row_id;
    row_id_type max_row_id;
    int8_t cache_service_state;
  };
  Status GetStat(ServiceStat *);

  /// \brief Cache the schema at the cache server
  /// \param map The unordered map of the schema
  /// \return Status object
  Status CacheSchema(const std::unordered_map<std::string, int32_t> &map);

  /// \brief Fetch the schema from the cache server
  /// \param map Pointer to pre-allocated map object
  /// \return Status object.
  Status FetchSchema(std::unordered_map<std::string, int32_t> *map);

  /// \brief Change the state from build phase to read phase. Applicable to non-mappable dataset only. Only the cache
  /// client that holds cookie can be allowed to make this request
  /// \return Status object
  Status BuildPhaseDone() const;

  /// \brief A print method typically used for debugging
  /// \param out The output stream to write output to
  void Print(std::ostream &out) const;

  /// \brief Stream output operator overload
  /// \return the output stream must be returned
  friend std::ostream &operator<<(std::ostream &out, const CacheClient &cc) {
    cc.Print(out);
    return out;
  }

  /// \brief Every cache server has a cookie which uniquely identifies the CacheClient that creates it.
  /// \return Cookie
  std::string cookie() const { return cookie_; }

 private:
  mutable RWLock mux_;
  uint64_t cache_mem_sz_;
  bool spill_;
  // The session_id_ and cache_crc_ work together to uniquely identify this particular cache and allow
  // sharing of the cache.
  uint32_t session_id_;
  uint32_t cache_crc_;
  // The server_connection_id_ is the actual id we use for operations after the cache is built
  connection_id_type server_connection_id_;
  // Some magic cookie returned from the cache server.
  std::string cookie_;
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_CACHE_CLIENT_H_
```
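Taken together, the header above defines the whole client-side lifecycle. The following rough sketch of the call sequence a non-mappable cache op drives is an illustration only, not a test from this commit: the helper name and the argument values are made up, and the real callers are the new cache ops under engine/datasetops.

```cpp
#include "dataset/engine/cache/cache_client.h"

namespace mindspore {
namespace dataset {
// Hypothetical helper: cache one row, flip the cache to its fetch phase,
// then read the row back through the same client.
Status CacheRoundTrip(const TensorRow &row) {
  CacheClient cc(/*session_id=*/1, /*cache_mem_sz=*/0 /* 0 = unlimited */, /*spill=*/true);
  // In the real pipeline the crc comes from the tree prepare phase.
  RETURN_IF_NOT_OK(cc.CreateCache(/*tree_crc=*/123, /*generate_id=*/true));
  row_id_type id = -1;
  RETURN_IF_NOT_OK(cc.WriteRow(row, &id));   // server assigns the row id
  RETURN_IF_NOT_OK(cc.BuildPhaseDone());     // non-mappable: build phase -> fetch phase
  TensorTable table;
  RETURN_IF_NOT_OK(cc.GetRows({id}, &table));  // an empty row would signal a cache miss
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
```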
mindspore/ccsrc/dataset/engine/cache/cache_request.cc (new file, mode 100644)

```cpp
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/engine/cache/cache_request.h"

namespace mindspore {
namespace dataset {
Status CacheRowRequest::SerializeCacheRowRequest(const TensorRow &row) {
  buffers_.reserve(row.size() + 1);
  RETURN_IF_NOT_OK(SerializeTensorRowHeader(row));
  buffers_.push_back(fbb_->GetBufferPointer());
  for (const auto &ts : row) {
    buffers_.push_back(ts->GetBuffer());
  }
  return Status::OK();
}

Status CacheRowRequest::SerializeTensorRowHeader(const TensorRow &row) {
  try {
    fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
    std::vector<flatbuffers::Offset<TensorMetaMsg>> v;
    std::vector<int64_t> tensor_sz;
    v.reserve(row.size());
    tensor_sz.reserve(row.size());
    // We will go through each column in the row.
    for (const std::shared_ptr<Tensor> &ts_ptr : row) {
      flatbuffers::Offset<TensorMetaMsg> ts_off;
      RETURN_IF_NOT_OK(SerializeOneTensorMeta(ts_ptr, &ts_off));
      v.push_back(ts_off);
      tensor_sz.push_back(ts_ptr->SizeInBytes());
    }
    auto column_off = fbb_->CreateVector(v);
    auto data_sz_off = fbb_->CreateVector(tensor_sz);
    TensorRowHeaderMsgBuilder row_builder(*fbb_);
    row_builder.add_column(column_off);
    row_builder.add_data_sz(data_sz_off);
    // Pass the row_id even if it may not be known.
    row_builder.add_row_id(row.getId());
    row_builder.add_size_of_this(-1);  // fill in later after we call Finish.
    auto out = row_builder.Finish();
    fbb_->Finish(out);
    // Now go back to fill in size_of_this in the flat buffer.
    auto msg = GetMutableTensorRowHeaderMsg(fbb_->GetBufferPointer());
    auto success = msg->mutate_size_of_this(fbb_->GetSize());
    if (!success) {
      RETURN_STATUS_UNEXPECTED("Unable to set size_of_this");
    }
    return Status::OK();
  } catch (const std::bad_alloc &e) {
    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
  }
}

Status CacheRowRequest::SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr,
                                               flatbuffers::Offset<TensorMetaMsg> *out_off) {
  RETURN_UNEXPECTED_IF_NULL(out_off);
  const Tensor *ts = ts_ptr.get();
  auto shape_off = fbb_->CreateVector(ts->shape().AsVector());
  const auto ptr = ts->GetBuffer();
  if (ptr == nullptr) {
    RETURN_STATUS_UNEXPECTED("Tensor buffer is null");
  }
  auto src = ts->type().value();
  TensorType dest;
#define CASE(t)                        \
  case DataType::t:                    \
    dest = TensorType::TensorType_##t; \
    break
  // Map the type to fill in the flat buffer.
  switch (src) {
    CASE(DE_BOOL);
    CASE(DE_INT8);
    CASE(DE_UINT8);
    CASE(DE_INT16);
    CASE(DE_UINT16);
    CASE(DE_INT32);
    CASE(DE_UINT32);
    CASE(DE_INT64);
    CASE(DE_UINT64);
    CASE(DE_FLOAT16);
    CASE(DE_FLOAT32);
    CASE(DE_FLOAT64);
    CASE(DE_STRING);
    default:
      MS_LOG(ERROR) << "Unknown tensor. Dumping content:\n" << *ts;
      RETURN_STATUS_UNEXPECTED("Unknown type");
  }
#undef CASE
  TensorMetaMsgBuilder ts_builder(*fbb_);
  ts_builder.add_dims(shape_off);
  ts_builder.add_type(dest);
  auto ts_off = ts_builder.Finish();
  *out_off = ts_off;
  return Status::OK();
}

Status BatchFetchRequest::RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data,
                                           std::shared_ptr<Tensor> *out) {
  RETURN_UNEXPECTED_IF_NULL(col_ts);
  auto shape_in = col_ts->dims();
  auto type_in = col_ts->type();
  std::vector<dsize_t> v;
  v.reserve(shape_in->size());
  v.assign(shape_in->begin(), shape_in->end());
  TensorShape shape(v);
  DataType::Type dest = DataType::DE_UNKNOWN;
#define CASE(t)               \
  case TensorType_##t:        \
    dest = DataType::Type::t; \
    break
  switch (type_in) {
    CASE(DE_BOOL);
    CASE(DE_INT8);
    CASE(DE_UINT8);
    CASE(DE_INT16);
    CASE(DE_UINT16);
    CASE(DE_INT32);
    CASE(DE_UINT32);
    CASE(DE_INT64);
    CASE(DE_UINT64);
    CASE(DE_FLOAT16);
    CASE(DE_FLOAT32);
    CASE(DE_FLOAT64);
    CASE(DE_STRING);
  }
#undef CASE
  DataType type(dest);
  std::shared_ptr<Tensor> ts =
    std::make_shared<Tensor>(shape, type, static_cast<const unsigned char *>(data.GetPointer()), data.GetSize());
  // Next we restore the real data which can be embedded or stored separately.
  if (ts->SizeInBytes() != data.GetSize()) {
    MS_LOG(ERROR) << "Unexpected length. Read " << data.GetSize() << ". Expected " << ts->SizeInBytes() << ".\n"
                  << "Dumping tensor\n" << *ts << "\n";
    RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
  }
  *out = std::move(ts);
  return Status::OK();
}

Status BatchFetchRequest::RestoreRows(TensorTable *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  auto num_elements = row_id_.size();
  auto *offset_array = reinterpret_cast<const int64_t *>(mem_.GetPointer());
  TensorTable tbl;
  tbl.reserve(num_elements);
  ReadableSlice all(mem_.GetPointer(), mem_.GetSizeInBytes());
  for (auto i = 0; i < num_elements; ++i) {
    auto len = offset_array[i + 1] - offset_array[i];
    TensorRow row;
    row.setId(row_id_.at(i));
    if (len > 0) {
      ReadableSlice row_data(all, offset_array[i], len);
      // Next we de-serialize flat buffer to get back each column
      auto msg = GetTensorRowHeaderMsg(row_data.GetPointer());
      auto msg_sz = msg->size_of_this();
      // Start of the tensor data
      auto ts_offset = msg_sz;
      row.reserve(msg->column()->size());
      for (auto k = 0; k < msg->column()->size(); ++k) {
        auto col_ts = msg->column()->Get(k);
        std::shared_ptr<Tensor> ts;
        ReadableSlice data(row_data, ts_offset, msg->data_sz()->Get(k));
        RETURN_IF_NOT_OK(RestoreOneTensor(col_ts, data, &ts));
        row.push_back(ts);
        ts_offset += data.GetSize();
      }
    }
    tbl.push_back(std::move(row));
  }
  *out = std::move(tbl);
  return Status::OK();
}

Status CacheSchemaRequest::SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map) {
  try {
    fbb_ = std::make_shared<flatbuffers::FlatBufferBuilder>();
    std::vector<flatbuffers::Offset<ColumnNameMsg>> v;
    v.reserve(map.size());
    for (auto &column : map) {
      auto c = CreateColumnNameMsg(*fbb_, fbb_->CreateString(column.first), column.second);
      v.push_back(c);
    }
    auto v_off = fbb_->CreateVector(v);
    auto final_off = CreateSchemaMsg(*fbb_, v_off);
    fbb_->Finish(final_off);
    buf_ = fbb_->GetBufferPointer();
    len_of_buf_ = fbb_->GetSize();
    return Status::OK();
  } catch (const std::bad_alloc &e) {
    return Status(StatusCode::kOutOfMemory, __LINE__, __FILE__);
  }
}

std::unordered_map<std::string, int32_t> FetchSchemaRequest::GetColumnMap() {
  if (column_name_id_map_.empty()) {
    auto *map_msg = flatbuffers::GetRoot<SchemaMsg>(mem_.GetPointer());
    auto v = map_msg->column();
    for (auto i = 0; i < v->size(); ++i) {
      auto col = map_msg->column()->Get(i);
      column_name_id_map_.emplace(col->name()->str(), col->id());
    }
  }
  return column_name_id_map_;
}
}  // namespace dataset
}  // namespace mindspore
```
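RestoreRows above depends on a simple framing convention in the server's reply buffer: a leading array of int64 offsets with one more entry than there are rows, where row i occupies bytes [offset[i], offset[i+1]) and a zero-length span marks a cache miss. A standalone illustration of that arithmetic with made-up offsets (an assumption for the example; the real offsets come from the server):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical offsets for 3 rows: row 0 has 120 bytes, row 1 is a miss,
  // row 2 has 64 bytes. The payload starts after a 32-byte prefix.
  std::vector<int64_t> offset = {32, 152, 152, 216};
  for (size_t i = 0; i + 1 < offset.size(); ++i) {
    int64_t len = offset[i + 1] - offset[i];  // same subtraction as RestoreRows
    std::cout << "row " << i << ": "
              << (len > 0 ? std::to_string(len) + " bytes" : std::string("cache miss (empty row)")) << "\n";
  }
}
```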
mindspore/ccsrc/dataset/engine/cache/cache_request.h
0 → 100644
浏览文件 @
7c1bc519
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_REQ_H_
#define DATASET_ENGINE_CACHE_REQ_H_
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "./de_tensor_generated.h"
#include "dataset/core/tensor_row.h"
#include "dataset/util/slice.h"
#include "dataset/util/wait_post.h"
namespace
mindspore
{
namespace
dataset
{
/// \brief CacheClient communicates with CacheServer using Requests.
class
BaseRequest
{
public:
// Request types
enum
class
RequestType
:
int16_t
{
kCacheRow
=
0
,
kBatchFetchRows
=
1
,
kCreateCache
=
2
,
kPurgeCache
=
3
,
kDestroyCache
=
4
,
kGetStat
=
5
,
kCacheSchema
=
6
,
kFetchSchema
=
7
,
kBuildPhaseDone
=
8
,
// Add new request before it.
kRequestUnknown
=
32767
};
// For kCreateCache
enum
class
CreateCacheFlag
:
uint32_t
{
kNone
=
0
,
kSpillToDisk
=
1
,
kGenerateRowId
=
1u
<<
1L
};
friend
class
CacheServer
;
/// \brief Base class of a cache server request
/// \param connection_id A combination of session id and crc that uniquely identifies a connection.
/// \param type Type of the request
explicit
BaseRequest
(
connection_id_type
connection_id
,
RequestType
type
)
:
type_
(
type
),
connection_id_
(
connection_id
)
{}
virtual
~
BaseRequest
()
=
default
;
/// \brief Wait for the completion of a request
/// \return Status returned from the cache server
Status
Wait
()
{
RETURN_IF_NOT_OK
(
wp_
.
Wait
());
return
rc_
;
}
/// \brief Getter function of the current connection id
/// \return Connection id
connection_id_type
GetServerConnectionId
()
const
{
return
connection_id_
;
}
private:
RequestType
type_
;
connection_id_type
connection_id_
;
Status
rc_
;
WaitPost
wp_
;
};
/// \brief Request to cache a single TensorRow
class
CacheRowRequest
:
public
BaseRequest
{
public:
friend
class
CacheServer
;
explicit
CacheRowRequest
(
connection_id_type
connection_id
,
const
std
::
string
&
cookie
)
:
BaseRequest
(
connection_id
,
RequestType
::
kCacheRow
),
row_id_from_server_
(
-
1
),
cookie_
(
cookie
)
{}
~
CacheRowRequest
()
=
default
;
/// \brief Serialize a TensorRow for streaming to the cache server
/// \param row TensorRow
/// \return Status object
Status
SerializeCacheRowRequest
(
const
TensorRow
&
row
);
/// \brief Return the row id assigned to this row for non-mappable dataset
/// \return row id of the cached row
row_id_type
GetRowIdAfterCache
()
{
return
row_id_from_server_
;
}
private:
std
::
shared_ptr
<
flatbuffers
::
  FlatBufferBuilder> fbb_;
  row_id_type row_id_from_server_;
  std::vector<const void *> buffers_;
  std::string cookie_;

  /// \brief Private function to serialize one TensorRow
  /// \param row TensorRow
  /// \return Status object
  Status SerializeTensorRowHeader(const TensorRow &row);

  /// \brief Private function to serialize one Tensor
  /// \param ts_ptr Tensor
  /// \return Status object
  Status SerializeOneTensorMeta(const std::shared_ptr<Tensor> &ts_ptr, flatbuffers::Offset<TensorMetaMsg> *out_off);
};

/// \brief Request to fetch rows in batch
class BatchFetchRequest : public BaseRequest {
 public:
  friend class CacheServer;
  friend class CacheService;
  BatchFetchRequest(connection_id_type connection_id, const std::vector<row_id_type> &row_id)
      : BaseRequest(connection_id, RequestType::kBatchFetchRows), row_id_(row_id) {}
  Status RestoreRows(TensorTable *out);

 private:
  std::vector<row_id_type> row_id_;
  MemGuard<uint8_t> mem_;
  Status RestoreOneTensor(const TensorMetaMsg *col_ts, const ReadableSlice &data, std::shared_ptr<Tensor> *out);
};

/// \brief Request to create a cache for the current connection
class CreationCacheRequest : public BaseRequest {
 public:
  friend class CacheServer;
  /// \brief Constructor
  /// \param connection_id
  /// \param cache_mem_sz Maximum memory assigned for this connection. 0 means unlimited
  /// \param flag Attributes of the cache.
  explicit CreationCacheRequest(connection_id_type connection_id, uint64_t cache_mem_sz,
                                CreateCacheFlag flag = CreateCacheFlag::kNone)
      : BaseRequest(connection_id, RequestType::kCreateCache), cache_mem_sz(cache_mem_sz), flag_(flag) {}

  std::string cookie() const { return cookie_; }

 private:
  uint64_t cache_mem_sz;
  CreateCacheFlag flag_;
  std::string cookie_;
};

/// \brief Request to purge a cache.
class PurgeCacheRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit PurgeCacheRequest(connection_id_type connection_id)
      : BaseRequest(connection_id, RequestType::kPurgeCache) {}
};

/// \brief Request to destroy a cache
class DestroyCacheRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit DestroyCacheRequest(connection_id_type connection_id)
      : BaseRequest(connection_id, RequestType::kDestroyCache) {}
};

/// \brief Obtain the statistics of the current connection
class GetStatRequest : public BaseRequest {
 public:
  friend class CacheServer;
  friend class CacheService;
  explicit GetStatRequest(connection_id_type connection_id) : BaseRequest(connection_id, RequestType::kGetStat) {}
  row_id_type GetMinRowId() const {
    auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
    return msg->min_row_id();
  }
  row_id_type GetMaxRowId() const {
    auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
    return msg->max_row_id();
  }
  int64_t GetNumMemCached() const {
    auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
    return msg->num_mem_cached();
  }
  int64_t GetNumDiskCached() const {
    auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
    return msg->num_disk_cached();
  }
  uint8_t GetState() const {
    auto *msg = flatbuffers::GetRoot<ServiceStatMsg>(mem_.GetPointer());
    return msg->state();
  }

 private:
  MemGuard<uint8_t> mem_;
};

/// \brief Request to cache a schema
class CacheSchemaRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit CacheSchemaRequest(connection_id_type connection_id)
      : BaseRequest(connection_id, RequestType::kCacheSchema), buf_(nullptr), len_of_buf_(0) {}
  ~CacheSchemaRequest() = default;

  Status SerializeCacheSchemaRequest(const std::unordered_map<std::string, int32_t> &map);
  const void *GetBuffer() const { return buf_; }

 private:
  std::shared_ptr<flatbuffers::FlatBufferBuilder> fbb_;
  const void *buf_;
  int64_t len_of_buf_;
};

/// \brief Request to fetch a schema
class FetchSchemaRequest : public BaseRequest {
 public:
  friend class CacheServer;
  explicit FetchSchemaRequest(connection_id_type connection_id)
      : BaseRequest(connection_id, RequestType::kFetchSchema) {}
  ~FetchSchemaRequest() = default;

  std::unordered_map<std::string, int32_t> GetColumnMap();

 private:
  MemGuard<uint8_t> mem_;
  std::unordered_map<std::string, int32_t> column_name_id_map_;
};

/// \brief Request to change a cache from build phase to read phase. Applies to non-mappable cache only.
class BuildPhaseDoneRequest : public BaseRequest {
 public:
  friend class CacheServer;
  BuildPhaseDoneRequest(connection_id_type connection_id, const std::string &cookie)
      : BaseRequest(connection_id, RequestType::kBuildPhaseDone), cookie_(cookie) {}

 private:
  std::string cookie_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // DATASET_ENGINE_CACHE_SERVICE_H_
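Aside: every request above bundles three things into one object: its input parameters, a result slot (rc_), and a wait post (wp_) that a server thread signals once the work is done. Below is a minimal, hypothetical sketch of that round trip, using std::promise in place of WaitPost and a plain struct in place of Status; it illustrates the pattern only, not the repo's actual API.

#include <future>
#include <iostream>
#include <string>
#include <thread>

// Stand-in for Status/rc_; the real code uses mindspore::dataset::Status.
struct FakeStatus {
  bool ok = true;
  std::string msg;
};

// A request carries its own result slot and a one-shot synchronization object,
// mirroring BaseRequest's rc_ and wp_ members.
struct Request {
  int connection_id;
  FakeStatus rc;            // filled in by the server thread
  std::promise<void> done;  // plays the role of WaitPost
  void Post() { done.set_value(); }          // server side: like wp_.Set()
  void Wait() { done.get_future().wait(); }  // client side: block until fulfilled
};

int main() {
  Request rq{42};
  // Imagine a server thread doing the work and posting completion:
  std::thread server([&rq] {
    rq.rc = FakeStatus{true, "cached"};
    rq.Post();
  });
  rq.Wait();  // client blocks here, then reads the result out of the same object
  std::cout << "rc: " << rq.rc.msg << "\n";
  server.join();
}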
mindspore/ccsrc/dataset/engine/cache/cache_server.cc
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/cache/cache_server.h"
#include "dataset/engine/cache/cache_service.h"
#include "dataset/engine/cache/cache_request.h"
#include "dataset/util/bit.h"
namespace mindspore {
namespace dataset {
Status CacheServer::DoServiceStart() {
  if (!top_.empty()) {
    Path spill(top_);
    RETURN_IF_NOT_OK(spill.CreateDirectories());
    MS_LOG(INFO) << "CacheServer will use disk folder: " << top_;
  }
  RETURN_IF_NOT_OK(vg_.ServiceStart());
  cache_q_ = std::make_shared<Queue<BaseRequest *>>(1024);
  RETURN_IF_NOT_OK(cache_q_->Register(&vg_));
  auto f = std::bind(&CacheServer::ServerRequest, this);
  // Spawn a few threads to serve the requests.
  for (auto i = 0; i < num_workers_; ++i) {
    RETURN_IF_NOT_OK(vg_.CreateAsyncTask("Cache server", f));
  }
  return Status::OK();
}

Status CacheServer::DoServiceStop() {
  Status rc;
  Status rc2;
  // First stop all the threads.
  RETURN_IF_NOT_OK(vg_.ServiceStop());
  // Clean up all the caches if any.
  UniqueLock lck(&rwLock_);
  auto it = all_caches_.begin();
  while (it != all_caches_.end()) {
    auto cs = std::move(it->second);
    rc2 = cs->ServiceStop();
    if (rc2.IsError()) {
      rc = rc2;
    }
    ++it;
  }
  return rc;
}

CacheService *CacheServer::GetService(connection_id_type id) const {
  SharedLock lck(&rwLock_);
  auto it = all_caches_.find(id);
  if (it != all_caches_.end()) {
    return it->second.get();
  }
  return nullptr;
}

Status CacheServer::CreateService(connection_id_type connection_id, uint64_t cache_mem_sz,
                                  BaseRequest::CreateCacheFlag flag, std::string *out_cookie) {
  // We can't do spilling unless this server is set up with a spill path in the first place.
  bool spill = (flag & BaseRequest::CreateCacheFlag::kSpillToDisk) == BaseRequest::CreateCacheFlag::kSpillToDisk;
  bool generate_id =
    (flag & BaseRequest::CreateCacheFlag::kGenerateRowId) == BaseRequest::CreateCacheFlag::kGenerateRowId;
  if (spill && top_.empty()) {
    RETURN_STATUS_UNEXPECTED("Server is not set up with spill support.");
  }
  RETURN_UNEXPECTED_IF_NULL(out_cookie);
  *out_cookie = "";
  // Before creating the cache, first check if this is a request for shared usage of an existing cache.
  // If two CreateService calls come in with identical connection_id, we need to serialize the create.
  // The first create will be successful and be given a special cookie.
  UniqueLock lck(&rwLock_);
  auto end = all_caches_.end();
  auto it = all_caches_.find(connection_id);
  if (it == end) {
    std::unique_ptr<CacheService> cs;
    try {
      cs = std::make_unique<CacheService>(cache_mem_sz, spill ? top_ : "", generate_id);
      RETURN_IF_NOT_OK(cs->ServiceStart());
      *out_cookie = cs->cookie();
      all_caches_.emplace(connection_id, std::move(cs));
    } catch (const std::bad_alloc &e) {
      return Status(StatusCode::kOutOfMemory);
    }
  } else {
    MS_LOG(INFO) << "Duplicate request for " + std::to_string(connection_id) + " to create cache service";
    // We could return OK, but we return a duplicate key instead so the user can act accordingly,
    // i.e. either ignore it or treat it as OK.
    return Status(StatusCode::kDuplicateKey);
  }
  return Status::OK();
}

/// This is the main loop the cache server thread(s) are running.
/// Each thread will pop a request and save the result in the same request.
/// The sender will wait on the wait post in the request. Once the request
/// is fulfilled, the server thread will do a post signalling the request
/// is processed.
/// \return Status object
Status CacheServer::ServerRequest() {
  TaskManager::FindMe()->Post();
  // Loop forever until we are interrupted.
  while (true) {
    BaseRequest *base_rq = nullptr;
    RETURN_IF_NOT_OK(cache_q_->PopFront(&base_rq));
    auto cs = GetService(base_rq->connection_id_);
    // Except for creating a new session, we expect cs is not null.
    switch (base_rq->type_) {
      case BaseRequest::RequestType::kCacheRow: {
        if (cs == nullptr) {
          std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found";
          base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
        } else {
          auto *rq = reinterpret_cast<CacheRowRequest *>(base_rq);
          // Only if the cookie matches can we accept an insert into a cache that has a build phase.
          if (!cs->HasBuildPhase() || rq->cookie_ == cs->cookie()) {
            rq->rc_ = cs->CacheRow(rq->buffers_, &rq->row_id_from_server_);
          } else {
            return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
          }
        }
        break;
      }
      case BaseRequest::RequestType::kBatchFetchRows: {
        if (cs == nullptr) {
          std::string errMsg = "Cache id " + std::to_string(base_rq->connection_id_) + " not found";
          base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
        } else {
          auto *rq = reinterpret_cast<BatchFetchRequest *>(base_rq);
          rq->rc_ = cs->BatchFetch(rq->row_id_, &rq->mem_);
        }
        break;
      }
      case BaseRequest::RequestType::kCreateCache: {
        // If the cache is already created we still need to run the creation so that we do sanity checks on the
        // client id and return the cache id back to the user.
        auto *rq = reinterpret_cast<CreationCacheRequest *>(base_rq);
        rq->rc_ = CreateService(rq->connection_id_, rq->cache_mem_sz, rq->flag_, &rq->cookie_);
        break;
      }
      case BaseRequest::RequestType::kPurgeCache: {
        if (cs != nullptr) {
          base_rq->rc_ = cs->Purge();
        } else {
          // It is already purged. Ignore it.
          base_rq->rc_ = Status::OK();
        }
        break;
      }
      case BaseRequest::RequestType::kDestroyCache: {
        if (cs != nullptr) {
          // We need a strong lock to protect the map.
          connection_id_type id = base_rq->connection_id_;
          UniqueLock lck(&rwLock_);
          // std::map will invoke the destructor of CacheService. So we don't need to do anything here.
          auto n = all_caches_.erase(id);
          if (n == 0) {
            // It has been destroyed by another duplicate request.
            MS_LOG(INFO) << "Duplicate request for " + std::to_string(id) + " to destroy cache service";
          }
          base_rq->rc_ = Status::OK();
        } else {
          // It is already destroyed. Ignore it.
          base_rq->rc_ = Status::OK();
        }
        break;
      }
      case BaseRequest::RequestType::kGetStat: {
        if (cs == nullptr) {
          std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
          base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
        } else {
          auto *rq = reinterpret_cast<GetStatRequest *>(base_rq);
          CacheService::ServiceStat svc_stat;
          rq->rc_ = cs->GetStat(&svc_stat);
          if (rq->rc_.IsOk()) {
            flatbuffers::FlatBufferBuilder fbb;
            ServiceStatMsgBuilder bld(fbb);
            bld.add_num_disk_cached(svc_stat.stat_.num_disk_cached);
            bld.add_num_mem_cached(svc_stat.stat_.num_mem_cached);
            bld.add_max_row_id(svc_stat.max_);
            bld.add_min_row_id(svc_stat.min_);
            bld.add_state(svc_stat.state_);
            auto offset = bld.Finish();
            fbb.Finish(offset);
            rq->rc_ = rq->mem_.allocate(fbb.GetSize());
            if (rq->rc_.IsOk()) {
              WritableSlice dest(rq->mem_.GetMutablePointer(), fbb.GetSize());
              ReadableSlice src(fbb.GetBufferPointer(), fbb.GetSize());
              RETURN_IF_NOT_OK(WritableSlice::Copy(&dest, src));
            }
          }
        }
        break;
      }
      case BaseRequest::RequestType::kCacheSchema: {
        if (cs == nullptr) {
          std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
          base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
        } else {
          auto *rq = reinterpret_cast<CacheSchemaRequest *>(base_rq);
          rq->rc_ = cs->CacheSchema(rq->buf_, rq->len_of_buf_);
        }
        break;
      }
      case BaseRequest::RequestType::kFetchSchema: {
        if (cs == nullptr) {
          std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
          base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
        } else {
          auto *rq = reinterpret_cast<FetchSchemaRequest *>(base_rq);
          rq->rc_ = cs->FetchSchema(&rq->mem_);
        }
        break;
      }
      case BaseRequest::RequestType::kBuildPhaseDone: {
        if (cs == nullptr) {
          std::string errMsg = "Session " + std::to_string(base_rq->connection_id_) + " not found";
          base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, errMsg);
        } else {
          auto *rq = reinterpret_cast<BuildPhaseDoneRequest *>(base_rq);
          // We can only allow the phase switch if the cookie matches.
          if (rq->cookie_ == cs->cookie()) {
            rq->rc_ = cs->BuildPhaseDone();
          } else {
            return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cookie mismatch");
          }
        }
        break;
      }
      default:
        base_rq->rc_ = Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Unknown request type");
    }
    // Notify it is done, and move on to the next request.
    base_rq->wp_.Set();
  }
  return Status::OK();
}

CacheServer::CacheServer(const std::string &spill_path, int32_t num_workers)
    : top_(spill_path), num_workers_(num_workers) {}
}  // namespace dataset
}  // namespace mindspore
mindspore/ccsrc/dataset/engine/cache/cache_server.h
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_SERVER_H_
#define DATASET_ENGINE_CACHE_SERVER_H_
#include <algorithm>
#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <map>
#include "dataset/engine/cache/cache_service.h"
#include "dataset/core/tensor.h"
#include "dataset/util/arena.h"
#include "dataset/util/cache_pool.h"
#include "dataset/util/lock.h"
#include "dataset/util/service.h"
#include "dataset/util/services.h"
#include "dataset/util/system_pool.h"
#include "dataset/util/queue.h"
#include "dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
class BaseRequest;
/// \brief A server which provides CacheService services.
class CacheServer : public Service {
 public:
  friend class Services;
  using cache_index = std::map<connection_id_type, std::unique_ptr<CacheService>>;

  CacheServer(const CacheServer &) = delete;
  CacheServer &operator=(const CacheServer &) = delete;
  CacheServer(CacheServer &&) = delete;
  CacheServer &operator=(CacheServer &&) = delete;
  static CacheServer &GetInstance() noexcept { return Services::getCacheServer(); }
  Status DoServiceStart() override;
  Status DoServiceStop() override;
  ~CacheServer() { (void)ServiceStop(); }

  /// \brief For the current demonstration, a cache client contacts the cache server using a Queue.
  /// \param rq
  /// \return Status object
  Status PushRequest(BaseRequest *rq) {
    RETURN_UNEXPECTED_IF_NULL(rq);
    RETURN_IF_NOT_OK(cache_q_->Add(rq));
    return Status::OK();
  }

 private:
  mutable RWLock rwLock_;
  std::string top_;
  cache_index all_caches_;
  std::shared_ptr<Queue<BaseRequest *>> cache_q_;
  TaskGroup vg_;
  int32_t num_workers_;

  /// \brief Constructor
  /// \param spill_path Top directory for spilling buffers to.
  /// \param num_workers Number of threads for handling requests.
  explicit CacheServer(const std::string &spill_path, int32_t num_workers = 3);

  /// \brief Locate a cache service from a connection id.
  /// \return Pointer to the cache service. Null if not found.
  CacheService *GetService(connection_id_type id) const;

  /// \brief Create a cache service. We allow multiple clients to create the same cache service.
  /// Subsequent duplicate requests are ignored. The first cache client to create the service will be given
  /// a special unique cookie.
  /// \param[in] connection_id This is from a Cache client.
  /// \param[in] cache_mem_sz
  /// \param[in] flag
  /// \param[out] out_cookie Only the first cache client will be given a special cookie to identify the creator
  /// \return Status object
  Status CreateService(connection_id_type connection_id, uint64_t cache_mem_sz, BaseRequest::CreateCacheFlag flag,
                       std::string *out_cookie);

  /// \brief Entry point for all server threads.
  Status ServerRequest();
};
}  // namespace dataset
}  // namespace mindspore
#endif  // DATASET_ENGINE_CACHE_SERVER_H_
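Aside: DoServiceStart wires num_workers_ threads to one shared Queue<BaseRequest *>, each looping in ServerRequest. The sketch below reproduces that worker-pool-over-a-blocking-queue shape using only the standard library; BlockingQueue and the empty-task quit signal are stand-ins (the real code uses Queue and the task group's interrupt service), not the repo's types.

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// A tiny blocking queue standing in for Queue<BaseRequest *>.
template <typename T>
class BlockingQueue {
 public:
  void Push(T v) {
    {
      std::lock_guard<std::mutex> lk(mu_);
      q_.push(std::move(v));
    }
    cv_.notify_one();
  }
  T Pop() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return !q_.empty(); });
    T v = std::move(q_.front());
    q_.pop();
    return v;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<T> q_;
};

int main() {
  BlockingQueue<std::function<void()>> q;
  std::vector<std::thread> workers;
  const int num_workers = 3;  // same default as CacheServer's num_workers
  for (int i = 0; i < num_workers; ++i) {
    workers.emplace_back([&q] {
      // Each worker loops popping requests, like CacheServer::ServerRequest.
      while (auto task = q.Pop()) {  // an empty function is the quit signal
        task();
      }
    });
  }
  for (int i = 0; i < 5; ++i) q.Push([i] { std::cout << "request " << i << " served\n"; });
  for (int i = 0; i < num_workers; ++i) q.Push(nullptr);  // poison pills to stop the pool
  for (auto &t : workers) t.join();
}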
mindspore/ccsrc/dataset/engine/cache/cache_service.cc
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/cache/cache_service.h"
#include "dataset/util/slice.h"
namespace mindspore {
namespace dataset {
CacheService::CacheService(uint64_t mem_sz, const std::string &root, bool generate_id)
    : root_(root),
      cache_mem_sz_(mem_sz),
      cp_(nullptr),
      map_(nullptr),
      next_id_(0),
      generate_id_(generate_id),
      schema_key_(-1),
      st_(generate_id ? State::kBuildPhase : State::kNone) {}

CacheService::~CacheService() { (void)ServiceStop(); }

bool CacheService::UseArena() {
  // If fixed size, use an Arena instead of the pool from the global context.
  return (cache_mem_sz_ > 0);
}

Status CacheService::DoServiceStart() {
  std::shared_ptr<MemoryPool> mp_;
  if (UseArena()) {
    // Create a fixed size arena based on the parameter.
    std::shared_ptr<Arena> arena;
    RETURN_IF_NOT_OK(Arena::CreateArena(&arena, cache_mem_sz_));
    mp_ = std::move(arena);
  } else {
    // Unlimited size. Simply use a system pool. Another choice is CircularPool.
    mp_ = std::make_shared<SystemPool>();
  }
  // Put together a CachePool for backing up the Tensor.
  cp_ = std::make_shared<CachePool>(CachePool::value_allocator(mp_), root_);
  RETURN_IF_NOT_OK(cp_->ServiceStart());
  // Set up the B+ tree as well. But use the system pool instead.
  map_ = std::make_shared<row_map>();
  // Assign a name to this cache. Used for exclusive connection. But we can just use CachePool's name.
  cookie_ = cp_->MyName();
  return Status::OK();
}

Status CacheService::DoServiceStop() {
  if (cp_ != nullptr) {
    RETURN_IF_NOT_OK(cp_->ServiceStop());
  }
  return Status::OK();
}

Status CacheService::CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(row_id_generated);
  if (st_ == State::kFetchPhase) {
    // For this kind of cache service, once we move from the build phase into the fetch phase, we can't
    // allow others to cache more rows.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
  }
  try {
    // The first buffer is a flatbuffer which describes the rest of the buffers that follow.
    auto fb = buf.front();
    RETURN_UNEXPECTED_IF_NULL(fb);
    auto msg = GetTensorRowHeaderMsg(fb);
    // If the server side is designed to ignore the incoming row id, we generate a row id.
    if (generate_id_) {
      *row_id_generated = GetNextRowId();
      // Some debug information on how many rows we have generated so far.
      if ((*row_id_generated) % 1000 == 0) {
        MS_LOG(DEBUG) << "Number of rows cached: " << *row_id_generated;
      }
    } else {
      if (msg->row_id() < 0) {
        std::string errMsg = "Expect positive row id: " + std::to_string(msg->row_id());
        RETURN_STATUS_UNEXPECTED(errMsg);
      }
      *row_id_generated = msg->row_id();
    }
    auto size_of_this = msg->size_of_this();
    auto column_hdr = msg->column();
    // The number of tensor buffers should match the number of columns plus one.
    if (buf.size() != column_hdr->size() + 1) {
      std::string errMsg = "Column count does not match. Expect " + std::to_string(column_hdr->size() + 1) +
                           " but get " + std::to_string(buf.size());
      RETURN_STATUS_UNEXPECTED(errMsg);
    }
    // Next we store in either memory or on disk. Low level code will consolidate everything in one piece.
    std::vector<ReadableSlice> all_data;
    all_data.reserve(column_hdr->size() + 1);
    all_data.emplace_back(fb, size_of_this);
    for (auto i = 0; i < column_hdr->size(); ++i) {
      all_data.emplace_back(buf.at(i + 1), msg->data_sz()->Get(i));
    }
    // Now we cache the flat buffer.
    CachePool::key_type key;
    RETURN_IF_NOT_OK(cp_->Insert(all_data, &key));
    Status rc = map_->DoInsert(*row_id_generated, key);
    if (rc == Status(StatusCode::kDuplicateKey)) {
      MS_LOG(DEBUG) << "Ignoring duplicate key";
    } else {
      RETURN_IF_NOT_OK(rc);
    }
    return Status::OK();
  } catch (const std::exception &e) {
    RETURN_STATUS_UNEXPECTED(e.what());
  }
}

std::ostream &operator<<(std::ostream &out, const CacheService &cs) {
  // Then show any custom derived-internal stuff
  out << "\nCache memory size: " << cs.cache_mem_sz_;
  out << "\nSpill path: ";
  if (cs.root_.empty()) {
    out << "None";
  } else {
    out << cs.GetSpillPath();
  }
  return out;
}

Path CacheService::GetSpillPath() const { return cp_->GetSpillPath(); }

Status CacheService::Purge() {
  // First we must lock exclusively. No one else can cache/restore anything.
  UniqueLock rw(&rw_lock_);
  RETURN_IF_NOT_OK(cp_->ServiceStop());
  auto new_map = std::make_shared<row_map>();
  map_.reset();
  map_ = std::move(new_map);
  next_id_ = 0;
  RETURN_IF_NOT_OK(cp_->ServiceStart());
  return Status::OK();
}

Status CacheService::GetStat(CacheService::ServiceStat *out) {
  SharedLock rw(&rw_lock_);
  RETURN_UNEXPECTED_IF_NULL(out);
  if (st_ == State::kNone || st_ == State::kFetchPhase) {
    out->stat_ = cp_->GetStat();
    out->state_ = static_cast<ServiceStat::state_type>(st_);
    auto it = map_->begin();
    if (it != map_->end()) {
      out->min_ = it.key();
      auto end_it = map_->end();
      --end_it;
      out->max_ = end_it.key();
    }
  } else {
    out->state_ = static_cast<ServiceStat::state_type>(st_);
  }
  return Status::OK();
}

Status CacheService::BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const {
  RETURN_UNEXPECTED_IF_NULL(out);
  SharedLock rw(&rw_lock_);
  if (st_ == State::kBuildPhase) {
    // For this kind of cache service, we can't fetch until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in build phase");
  }
  const auto num_elements = v.size();
  int64_t mem_sz = (num_elements + 1) * sizeof(int64_t);
  int64_t data_offset = mem_sz;
  std::vector<int64_t> sz_v;
  std::vector<CachePool::key_type> keys;
  sz_v.reserve(num_elements);
  keys.reserve(num_elements);
  for (auto row_id : v) {
    auto r = map_->Search(row_id);
    if (r.second) {
      auto &it = r.first;
      CachePool::key_type key = it.value();
      auto sz = cp_->GetSize(key);
      if (sz == 0) {
        std::string errMsg = "Key not found: ";
        errMsg += std::to_string(key);
        RETURN_STATUS_UNEXPECTED(errMsg);
      }
      keys.push_back(key);
      sz_v.push_back(sz);
      mem_sz += sz;
    } else {
      keys.push_back(-1);
      sz_v.push_back(0);
    }
  }
  MemGuard<uint8_t> mem;
  RETURN_IF_NOT_OK(mem.allocate(mem_sz));
  auto *offset_array = reinterpret_cast<int64_t *>(mem.GetMutablePointer());
  offset_array[0] = data_offset;
  WritableSlice all(mem.GetMutablePointer(), mem.GetSizeInBytes());
  for (auto i = 0; i < num_elements; ++i) {
    auto sz = sz_v.at(i);
    offset_array[i + 1] = offset_array[i] + sz;
    if (sz > 0) {
      WritableSlice row_data(all, offset_array[i], sz);
      auto key = keys.at(i);
      size_t bytesRead = 0;
      RETURN_IF_NOT_OK(cp_->Read(key, &row_data, &bytesRead));
      if (bytesRead != sz) {
        MS_LOG(ERROR) << "Unexpected length. Read " << bytesRead << ". Expected " << sz << "."
                      << " Internal key: " << key << "\n";
        RETURN_STATUS_UNEXPECTED("Length mismatch. See log file for details.");
      }
    }
  }
  *out = std::move(mem);
  return Status::OK();
}

Status CacheService::CacheSchema(const void *buf, int64_t len) {
  SharedLock rw(&rw_lock_);
  if (st_ == State::kFetchPhase) {
    // For this kind of cache service, once we move from the build phase into the fetch phase, we can't
    // allow others to cache more rows.
    RETURN_STATUS_UNEXPECTED("Can't accept cache request in fetch phase");
  }
  // This is a special request and we need to remember where we store it.
  // In case we are calling the same function from multiple threads, only
  // the first one is considered. The rest are ignored.
  CachePool::key_type cur_key = schema_key_;
  CachePool::key_type key;
  if (cur_key < 0) {
    RETURN_IF_NOT_OK(cp_->Insert({ReadableSlice(buf, len)}, &key));
    auto result = std::atomic_compare_exchange_strong(&schema_key_, &cur_key, key);
    MS_LOG(DEBUG) << "Caching Schema. Result = " << result;
  } else {
    MS_LOG(DEBUG) << "Caching Schema already done";
  }
  return Status::OK();
}

Status CacheService::FetchSchema(MemGuard<uint8_t> *out) const {
  SharedLock rw(&rw_lock_);
  if (st_ == State::kBuildPhase) {
    // For this kind of cache service, we can't fetch until we are done with caching all the rows.
    RETURN_STATUS_UNEXPECTED("Can't accept fetch request in build phase");
  }
  RETURN_UNEXPECTED_IF_NULL(out);
  MemGuard<uint8_t> mem;
  if (schema_key_ >= 0) {
    auto len = cp_->GetSize(schema_key_);
    RETURN_IF_NOT_OK(mem.allocate(len));
    auto slice = WritableSlice(mem.GetMutablePointer(), len);
    RETURN_IF_NOT_OK(cp_->Read(schema_key_, &slice));
    *out = std::move(mem);
  } else {
    return Status(StatusCode::kFileNotExist, __LINE__, __FILE__, "No schema has been cached");
  }
  return Status::OK();
}

Status CacheService::BuildPhaseDone() {
  if (HasBuildPhase()) {
    // Exclusive lock to switch phase
    UniqueLock rw(&rw_lock_);
    st_ = State::kFetchPhase;
    return Status::OK();
  } else {
    RETURN_STATUS_UNEXPECTED("Not a cache that has a build phase");
  }
}
}  // namespace dataset
}  // namespace mindspore
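Aside: BatchFetch packs its reply as one contiguous block: an array of n+1 int64 offsets followed by the concatenated row payloads, where offset[i+1] == offset[i] flags a cache miss for row i. Below is a toy encoder/decoder for that exact layout; all names here are illustrative, and strings stand in for serialized rows.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Build a toy buffer in the BatchFetch layout.
  std::vector<std::string> rows = {"row-zero", "", "row-two"};  // "" = cache miss
  const size_t n = rows.size();
  std::vector<int64_t> offsets(n + 1);
  offsets[0] = static_cast<int64_t>((n + 1) * sizeof(int64_t));  // data starts after the offset array
  for (size_t i = 0; i < n; ++i) {
    offsets[i + 1] = offsets[i] + static_cast<int64_t>(rows[i].size());
  }
  std::vector<uint8_t> buf(offsets[n]);
  std::memcpy(buf.data(), offsets.data(), (n + 1) * sizeof(int64_t));
  for (size_t i = 0; i < n; ++i) {
    std::memcpy(buf.data() + offsets[i], rows[i].data(), rows[i].size());
  }

  // Decode, the way a CacheClient would walk the result of BatchFetch.
  auto *off = reinterpret_cast<const int64_t *>(buf.data());
  for (size_t i = 0; i < n; ++i) {
    int64_t sz = off[i + 1] - off[i];
    if (sz == 0) {
      std::cout << "row " << i << ": cache miss\n";
    } else {
      std::cout << "row " << i << ": "
                << std::string(reinterpret_cast<const char *>(buf.data() + off[i]), sz) << "\n";
    }
  }
}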
mindspore/ccsrc/dataset/engine/cache/cache_service.h
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_CACHE_SERVICE_H_
#define DATASET_ENGINE_CACHE_SERVICE_H_
#include <algorithm>
#include <atomic>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "./de_tensor_generated.h"
#include "dataset/core/global_context.h"
#include "dataset/core/tensor.h"
#include "dataset/engine/cache/cache_request.h"
#include "dataset/util/arena.h"
#include "dataset/util/btree.h"
#include "dataset/util/cache_pool.h"
#include "dataset/util/service.h"
#include "dataset/util/services.h"
#include "dataset/util/system_pool.h"
namespace mindspore {
namespace dataset {
struct CacheStat;
/// \brief A cache service for storing/fetching buffers in an in-memory cache, which may spill to disk if the
/// cache service is created with spill support.
class CacheService : public Service {
 public:
  friend class CacheServer;
  using row_map = BPlusTree<row_id_type, CachePool::key_type>;

  enum class State : uint8_t { kNone = 0, kBuildPhase, kFetchPhase };

  /// \brief Constructor
  /// \param mem_sz Memory size to be set aside for the in-memory cache. 0 means unlimited.
  /// \param root Spill path. Empty string means no spilling.
  /// \param generate_id If the cache service should generate a row id for each buffer that is cached.
  /// For a non-mappable dataset, this should be set to true.
  CacheService(uint64_t mem_sz, const std::string &root, bool generate_id);
  ~CacheService();

  /// \brief For fixed size memory, we will create an Arena.
  /// \return false if unlimited memory.
  bool UseArena();

  Status DoServiceStart() override;
  Status DoServiceStop() override;

  /// \brief Main function to cache a row, which comes in the form of a series of buffers.
  /// The first buffer is a Google flatbuffer which describes the rest of the buffers that follow.
  /// \param[in] buf Vector of buffers
  /// \param[out] row_id_generated The row id assigned to this row if any
  /// \return Status object
  Status CacheRow(const std::vector<const void *> &buf, row_id_type *row_id_generated);

  /// \brief Main function to fetch rows in batch. The output is one contiguous memory block which will be decoded
  /// by the CacheClient. A cache miss is not an error, and will be coded in the output to mark an empty row.
  /// \param[in] v A vector of row ids.
  /// \param[out] out A contiguous memory buffer that holds the requested rows.
  /// \return Status object
  Status BatchFetch(const std::vector<row_id_type> &v, MemGuard<uint8_t> *out) const;

  /// \brief Getter function
  /// \return Spilling path
  Path GetSpillPath() const;

  /// \brief A structure returned from the cache server for a statistics request.
  class ServiceStat {
   public:
    using state_type = std::underlying_type<State>::type;
    ServiceStat() : min_(0), max_(0), state_(0) {}
    CachePool::CacheStat stat_{};
    row_id_type min_;
    row_id_type max_;
    state_type state_;
  };

  /// \brief Statistics for the current service
  /// \param[in/out] A pointer to a pre-allocated ServiceStat structure
  /// \return Status object
  Status GetStat(ServiceStat *);

  /// \brief Cache a schema
  /// \param buf A Google flatbuffer that contains the schema
  /// \param len Size of the buffer
  /// \return Status object
  Status CacheSchema(const void *buf, int64_t len);

  /// \brief Fetch the schema
  /// \param out A contiguous memory block that contains the serialized form of the schema.
  /// \return Status object
  Status FetchSchema(MemGuard<uint8_t> *out) const;

  /// \brief Purge the content of a cache
  /// \return Status object
  Status Purge();

  /// \brief Overload the << operator to print a cache service
  /// \param out std::ostream
  /// \param cs A cache service
  /// \return std::ostream
  friend std::ostream &operator<<(std::ostream &out, const CacheService &cs);

  /// \brief Every cache service has a cookie. If the cookie of a CacheClient matches this cookie, this CacheClient
  /// is the creator.
  /// \return Cookie
  std::string cookie() const { return cookie_; }

  /// \brief If this cache service generates a row id for each buffer cached, it is divided into two phases, a build
  /// phase and a read phase.
  /// \return True if it has two phases.
  bool HasBuildPhase() const { return generate_id_; }

  /// \brief Change from the write phase to the read phase. Only the creator of this service is allowed to make
  /// this call.
  /// \return Status object
  Status BuildPhaseDone();

 private:
  mutable RWLock rw_lock_;
  std::string root_;
  uint64_t cache_mem_sz_;
  std::shared_ptr<CachePool> cp_;
  std::shared_ptr<row_map> map_;
  std::atomic<row_id_type> next_id_;
  bool generate_id_;
  std::atomic<CachePool::key_type> schema_key_;
  std::string cookie_;
  State st_;

  /// \brief Private function to generate a row id
  /// \return Row id assigned.
  row_id_type GetNextRowId() { return next_id_.fetch_add(1); }
};
}  // namespace dataset
}  // namespace mindspore
#endif  // DATASET_ENGINE_CACHE_SERVICE_H_
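Aside: CacheSchema publishes schema_key_ with a first-writer-wins compare-and-swap from the sentinel -1, so concurrent writers race harmlessly and only one schema key is ever published. A minimal sketch of that idiom follows; int64_t stands in for CachePool::key_type and the keys are made up for illustration.

#include <atomic>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  // schema_key starts at -1 (no schema), like CacheService::schema_key_.
  std::atomic<int64_t> schema_key{-1};
  std::vector<std::thread> writers;
  for (int64_t candidate_key = 100; candidate_key < 104; ++candidate_key) {
    writers.emplace_back([&schema_key, candidate_key] {
      int64_t expected = -1;
      // Only the first CAS from -1 succeeds; the rest see a published key and give up.
      if (schema_key.compare_exchange_strong(expected, candidate_key)) {
        std::cout << "key " << candidate_key << " won\n";
      }
    });
  }
  for (auto &t : writers) t.join();
  std::cout << "published schema key: " << schema_key.load() << "\n";
}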
mindspore/ccsrc/dataset/engine/cache/de_tensor.fbs
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
namespace mindspore.dataset;
/// Type of a Tensor
enum TensorType : byte {
DE_UNKNOWN = 0,
DE_BOOL = 1,
DE_INT8 = 2,
DE_UINT8 = 3,
DE_INT16 = 4,
DE_UINT16 = 5,
DE_INT32 = 6,
DE_UINT32 = 7,
DE_INT64 = 8,
DE_UINT64 = 9,
DE_FLOAT16 = 10,
DE_FLOAT32 = 11,
DE_FLOAT64 = 12,
DE_STRING = 13
}
/// The meta information of a Tensor
/// \note Only the type and shape are considered meta information. Tensor data is excluded.
table TensorMetaMsg {
dims:[int64] (required);
type:TensorType;
}
/// This is the first buffer that is sent to a Cache server when a TensorRow is serialized.
/// \param row_id is the row id of the TensorRow.
/// \param column The meta information of each Tensor in the row
/// \param size of this serialized buffer
/// \param size of each tensor data buffer that follows
table TensorRowHeaderMsg {
row_id:int64;
column:[TensorMetaMsg] (required);
size_of_this:int64;
data_sz:[int64] (required);
}
root_type TensorRowHeaderMsg;
/// A row of row id's
table TensorRowIds {
row_id:[int64] (required);
}
/// Statistics returned from each cache service
/// \note It must match CacheService::ServiceStat
table ServiceStatMsg {
num_mem_cached:int64;
num_disk_cached:int64;
min_row_id:int64;
max_row_id:int64;
state:int8;
}
/// Column description of each column in a schema
table ColumnNameMsg {
name:string;
id:int32;
}
/// Serialized form of a schema
table SchemaMsg {
column:[ColumnNameMsg];
}
mindspore/ccsrc/dataset/engine/data_buffer.cc
...
@@ -24,10 +24,8 @@ namespace dataset {
 // Description: This is the main constructor that is used for making a buffer
 DataBuffer::DataBuffer(int32_t id, BufferFlags flags)
     : buffer_id_(id), tensor_table_(nullptr), buffer_flags_(flags) {}
-// Name: print()
-// Description: A function that prints info about the DataBuffer (base class version)
-void DataBuffer::Print(std::ostream &out,      // In: The output stream to print to
-                       bool show_all) const {  // In: T/F if it should show everything
+// A method for debug printing of the buffer
+void DataBuffer::Print(std::ostream &out, bool show_all) const {
   out << "bufferId: " << buffer_id_ << "\nflags: " << std::hex << buffer_flags_ << std::dec << "\n";
   // If the column counts are set then it means that data has been set into
...
@@ -46,11 +44,6 @@ void DataBuffer::Print(std::ostream &out,  // In: The output stream to print
   }
 }
-Status DataBuffer::Load() {
-  std::string err_msg = "Base class load called, but it does not have an implementation!";
-  RETURN_STATUS_UNEXPECTED(err_msg);
-}
 // Remove me!! Callers should fetch rows via pop
 Status DataBuffer::GetTensor(std::shared_ptr<Tensor> *ptr, int32_t row_id, int32_t col_id) const {
   if (row_id < tensor_table_->size() && col_id < tensor_table_->at(row_id).size()) {
...
@@ -92,8 +85,5 @@ Status DataBuffer::SliceOff(int64_t number_of_rows) {
   return Status::OK();
 }
-// Destructor
-DataBuffer::~DataBuffer() {}
 }  // namespace dataset
 }  // namespace mindspore
mindspore/ccsrc/dataset/engine/data_buffer.h
...
@@ -29,11 +29,9 @@
 namespace mindspore {
 namespace dataset {
-// The DataBuffer class is a base class that will represent the data for n values based
-// on a unique row id for each row of data.
-// There can be different types of DataBuffers to abstract over how the data is stored
-// in memory and acquired from storage.
-// Each buffer holds a range of consecutive row id's.
+/// \brief The DataBuffer class is a container of tensor data and is the unit of transmission between
+///     connectors of dataset operators. Inside the buffer, tensors are organized into a table-like format
+///     where n TensorRows may consist of m tensors (columns).
 class DataBuffer {
  public:
   // Buffer flags
...
@@ -47,13 +45,13 @@ class DataBuffer {
   // Description: This is the main constructor that is used for making a buffer
   DataBuffer(int32_t id, BufferFlags flags);
-  // Destructor
-  virtual ~DataBuffer();
+  /// \brief default destructor
+  ~DataBuffer() = default;
-  // Name: print()
-  // Description: A function that prints info about the DataBuffer (base class version)
-  virtual void Print(std::ostream &out,     // In: The output stream to print to
-                     bool show_all) const;  // In: T/F if it should show everything
+  /// \brief A method for debug printing of the buffer
+  /// \param[inout] out The stream to write to
+  /// \param[in] show_all A boolean to toggle between details and summary printing
+  void Print(std::ostream &out, bool show_all) const;
   // Provide stream operator for displaying it
   friend std::ostream &operator<<(std::ostream &out, const DataBuffer &cb) {
...
@@ -61,10 +59,6 @@ class DataBuffer {
     return out;
   }
-  // Name: load()
-  // Description: populates the DataBuffer with data based on it's id
-  virtual Status Load();
   // Convenience getter functions for flag checking
   bool eof() const { return (static_cast<uint32_t>(buffer_flags_) & static_cast<uint32_t>(kDeBFlagEOF)); }
...
mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt
...
@@ -17,7 +17,11 @@ set(DATASET_ENGINE_DATASETOPS_SRC_FILES
         take_op.cc
         shuffle_op.cc
         zip_op.cc
-        concat_op.cc
+        concat_op.cc
+        cache_base_op.cc
+        cache_lookup_op.cc
+        cache_op.cc
+        cache_merge_op.cc
         )
 if (ENABLE_PYTHON)
...
mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.cc
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/cache_base_op.h"
#include <iomanip>
#include <iostream>
#include "dataset/engine/execution_tree.h"
namespace mindspore {
namespace dataset {
// A print method typically used for debugging
void CacheBase::Print(std::ostream &out, bool show_all) const {
  // Always show the id and name as the first line regardless if this is a summary or detailed print
  out << "(" << std::setw(2) << operator_id_ << ") <" << Name() << ">:";
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\nCache client:\n" << *cache_client_ << "\n\n";
  }
}

// Overrides base class reset method. When an operator does a reset, it cleans up any state
// info from its previous execution and then initializes itself so that it can be executed
// again.
Status CacheBase::Reset() {
  if (sampler_ != nullptr) {
    RETURN_IF_NOT_OK(sampler_->ResetSampler());
  }
  // Wake up the workers to get them going again in a new epoch
  MS_LOG(DEBUG) << Name() << " resetting.";
  epoch_sync_.Set();
  return Status::OK();
}

CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
                     std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
    : ParallelOp(num_workers, op_connector_size, sampler),
      cache_client_(cache_client),
      rows_per_buffer_(rows_per_buf),
      // We can cause a deadlock if this internal Connector size is too small.
      keys_miss_(num_workers_, 1, 1024) {
  io_block_queues_.Init(num_workers, op_connector_size);
}

// Common function to fetch samples from the sampler and send them using the io_block_queues to
// the parallel workers
Status CacheBase::FetchSamplesToWorkers() {
  int64_t buf_cnt = 0;
  int64_t wait_cnt = 0;
  do {
    epoch_sync_.Clear();
    std::vector<row_id_type> keys;
    int64_t row_cnt = 0;
    keys.reserve(rows_per_buffer_);
    std::unique_ptr<DataBuffer> sampler_buffer;
    RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
    while (!sampler_buffer->eoe()) {
      TensorRow sample_row;
      RETURN_IF_NOT_OK(sampler_buffer->PopRow(&sample_row));
      std::shared_ptr<Tensor> sample_ids = sample_row[0];
      for (auto itr = sample_ids->begin<int64_t>(); itr != sample_ids->end<int64_t>(); itr++) {
        keys.push_back(*itr);
        ++row_cnt;
        if (row_cnt % rows_per_buffer_ == 0) {
          auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
          RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
          keys.clear();
        }
      }
      RETURN_IF_NOT_OK(sampler_->GetNextSample(&sampler_buffer));
    }
    if (!keys.empty()) {
      auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
      RETURN_IF_NOT_OK(io_block_queues_[buf_cnt++ % num_workers_]->Add(std::move(blk)));
    }
    // Send the eoe
    RETURN_IF_NOT_OK(
      io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEoe)));
    // If repeated but this is not the last repeat, wait for reset.
    if (BitTest(op_ctrl_flags_, kDeOpRepeated) && !BitTest(op_ctrl_flags_, kDeOpLastRepeat)) {
      MS_LOG(DEBUG) << Name() << " Waiting for reset. Count " << ++wait_cnt << " Buffer sent " << buf_cnt;
      RETURN_IF_NOT_OK(epoch_sync_.Wait());
    } else {
      // We can break out from the loop.
      break;
    }
  } while (true);
  // Flow the eof before exit
  RETURN_IF_NOT_OK(
    io_block_queues_[(buf_cnt++) % num_workers_]->Add(std::make_unique<IOBlock>(IOBlock::kDeIoBlockFlagEof)));
  // Ask all the workers to quit.
  for (int32_t i = 0; i < num_workers_; i++) {
    RETURN_IF_NOT_OK(
      io_block_queues_[i]->Add(std::make_unique<IOBlock>(std::vector<int64_t>(), IOBlock::kDeIoBlockNone)));
  }
  return Status::OK();
}

Status CacheBase::FetchFromCache(int32_t worker_id) {
  int64_t buffer_id = worker_id;
  std::unique_ptr<IOBlock> blk;
  do {
    RETURN_IF_NOT_OK(io_block_queues_[worker_id]->PopFront(&blk));
    if (blk->eof()) {
      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF)));
    } else if (blk->eoe()) {
      if (AllowCacheMiss()) {
        // This code path is for CacheLookupOp acting as a sampler. If we get an eoe from
        // the sampler, send an eoe to the physical leaf op as well.
        std::vector<row_id_type> eoe;
        eoe.push_back(eoe_row_id);
        RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, eoe));
      }
      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE)));
    } else {
      std::vector<int64_t> keys;
      RETURN_IF_NOT_OK(blk->GetKeys(&keys));
      if (keys.empty()) {
        // An empty key is a quit signal for workers
        break;
      }
      std::unique_ptr<DataBuffer> db = std::make_unique<DataBuffer>(buffer_id, DataBuffer::kDeBFlagNone);
      std::unique_ptr<TensorQTable> que = std::make_unique<TensorQTable>();
      TensorTable ttbl;
      RETURN_IF_NOT_OK(cache_client_->GetRows(keys, &ttbl));
      auto row_it = ttbl.begin();
      std::vector<row_id_type> cache_miss;
      cache_miss.reserve(keys.size());
      for (auto row_id : keys) {
        auto &row = *row_it;
        if (row.empty()) {
          if (AllowCacheMiss()) {
            cache_miss.push_back(row_id);
          } else {
            std::string errMsg = "Row id " + std::to_string(row_id) + " not found.";
            RETURN_STATUS_UNEXPECTED(errMsg);
          }
        }
        que->push_back(std::move(row));
        ++row_it;
      }
      db->set_tensor_table(std::move(que));
      if (AllowCacheMiss()) {
        // Because of the way the connector works, we push unconditionally even if cache_miss is empty.
        RETURN_IF_NOT_OK(keys_miss_.Push(worker_id, cache_miss));
      }
      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db)));
      buffer_id += num_workers_;
    }
  } while (true);
  return Status::OK();
}

Status CacheBase::RegisterResources() {
  RETURN_IF_NOT_OK(epoch_sync_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(io_block_queues_.Register(tree_->AllTasks()));
  return Status::OK();
}

CacheBase::~CacheBase() {}

Status CacheBase::UpdateColumnMapFromCache() {
  Status rc;
  // Get the schema from the server. It may not be there yet, so tolerate the error.
  if (column_name_id_map_.empty()) {
    rc = cache_client_->FetchSchema(&column_name_id_map_);
    if (rc == Status(StatusCode::kFileNotExist)) {
      MS_LOG(DEBUG) << "Schema not in the server yet.";
      rc = Status::OK();
    }
  }
  return rc;
}
}  // namespace dataset
}  // namespace mindspore
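Aside: FetchSamplesToWorkers deals sampler ids into fixed-size batches and hands the batches to workers round-robin via `buf_cnt++ % num_workers_`. Below is a standalone sketch of just that dealing scheme; the queues and IOBlocks are elided, and all constants are made up for illustration.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int num_workers = 3;
  const size_t rows_per_buffer = 4;
  // per_worker[w] collects the batches dealt to worker w.
  std::vector<std::vector<std::vector<int64_t>>> per_worker(num_workers);
  std::vector<int64_t> keys;
  int64_t buf_cnt = 0;
  for (int64_t row_id = 0; row_id < 10; ++row_id) {
    keys.push_back(row_id);
    if (keys.size() == rows_per_buffer) {
      per_worker[buf_cnt++ % num_workers].push_back(keys);  // round-robin deal
      keys.clear();
    }
  }
  if (!keys.empty()) per_worker[buf_cnt++ % num_workers].push_back(keys);  // leftover partial batch
  for (int w = 0; w < num_workers; ++w) {
    std::cout << "worker " << w << ": ";
    for (auto &blk : per_worker[w]) std::cout << "[" << blk.front() << ".." << blk.back() << "] ";
    std::cout << "\n";
  }
}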
mindspore/ccsrc/dataset/engine/datasetops/cache_base_op.h
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/cache/cache_service.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/engine/datasetops/source/sampler/sampler.h"
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
#include "dataset/util/queue.h"
#include "dataset/util/wait_post.h"
#include "dataset/engine/datasetops/cache_base_op.h"
namespace mindspore {
namespace dataset {
/// \brief This is the base class for CacheOp and CacheLookupOp which share many similarities.
/// \see CacheOp
/// \see CacheLookupOp
class CacheBase : public ParallelOp {
 public:
  /// \brief Base class constructor
  /// \param num_workers Number of parallel workers
  /// \param op_connector_size Connector size
  /// \param rows_per_buf Number of rows per buffer
  /// \param cache_client CacheClient for communication to the CacheServer
  /// \param sampler Sampler which is mandatory
  CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
            std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
  /// \brief Destructor
  ~CacheBase();

  constexpr static int eoe_row_id = -1;

  /// \brief Overrides base class reset method. When an operator does a reset, it cleans up any state
  /// info from its previous execution and then initializes itself so that it can be executed
  /// again.
  /// \return Status - The error code returned
  Status Reset() override;

  /// \brief A print method typically used for debugging
  /// \param out The output stream to write output to
  /// \param show_all A bool to control if you want to show all info or just a summary
  void Print(std::ostream &out, bool show_all) const override;

  /// \brief << Stream output operator overload
  /// \notes This allows you to write the debug print info using stream operators
  /// \param out reference to the output stream being overloaded
  /// \param mo reference to the CacheOp to display
  /// \return the output stream must be returned
  friend std::ostream &operator<<(std::ostream &out, const CacheBase &mo) {
    mo.Print(out, false);
    return out;
  }

  /// \brief Getter for the cache client
  /// \return shared ptr to the cache client
  std::shared_ptr<CacheClient> cache_client() { return cache_client_; }

  /// \brief Setter for the cache client
  void SetCacheClient(std::shared_ptr<CacheClient> cache_client) { cache_client_ = std::move(cache_client); }

  /// \brief Derived class must implement this method if a cache miss is treated as an error
  virtual bool AllowCacheMiss() = 0;

 protected:
  std::shared_ptr<CacheClient> cache_client_;
  WaitPost epoch_sync_;
  int32_t rows_per_buffer_;
  Connector<std::vector<row_id_type>> keys_miss_;

  /// \brief Common function to register resources for interrupt
  /// \note Derived classes should override this function for extra resources to be registered
  virtual Status RegisterResources();

  /// \brief This function is called by the main thread to send samples to the worker threads.
  /// \note It is a non-virtual function
  /// \return Status object
  Status FetchSamplesToWorkers();

  /// \brief This function is called by each worker to fetch rows from the cache server for a given set of
  /// sample row ids
  /// \return Status object
  Status FetchFromCache(int32_t worker_id);

  /// \brief Get the column map from the cache server
  Status UpdateColumnMapFromCache();

 private:
  QueueList<std::unique_ptr<IOBlock>> io_block_queues_;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // DATASET_ENGINE_DATASETOPS_CACHE_BASE_OP_H_
mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.cc
0 → 100644
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dataset/engine/datasetops/cache_lookup_op.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/engine/execution_tree.h"
#include "utils/log_adapter.h"
#include "utils/system/crc32c.h"
namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  build_num_workers_ = cfg->num_parallel_workers();
  rows_per_buffer_ = cfg->rows_per_buffer();
  build_op_connector_size_ = cfg->op_connector_size();
}

// Check if the required parameters are set by the builder.
Status CacheLookupOp::Builder::SanityCheck() const {
  if (build_cache_client_ == nullptr) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheLookupOp requires a CacheClient");
  }
  // Make sure the cache client has a valid session
  if (!build_cache_client_->session_id()) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
                  "Cache client for CacheLookupOp is missing session id");
  }
  return Status::OK();
}

// The builder "build" method creates the final object and does some init on it
Status CacheLookupOp::Builder::Build(std::shared_ptr<CacheLookupOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<CacheLookupOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_,
                                         build_cache_client_, build_sampler_);
  return Status::OK();
}

Status CacheLookupOp::operator()() {
  if (!sampler_) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
                  "CacheLookupOp requires a sampler before it can be executed!");
  }
  RETURN_IF_NOT_OK(RegisterResources());
  // Kick off the workers
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&CacheLookupOp::WorkerEntry, this, std::placeholders::_1)));
  // Required task group sync after launching workers
  TaskManager::FindMe()->Post();
  // We have to wait until the leaf op has handshaked with us.
  RETURN_IF_NOT_OK(leaf_op_wp_.Wait());
  RETURN_IF_NOT_OK(FetchSamplesToWorkers());
  return Status::OK();
}

Status CacheLookupOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(FetchFromCache(worker_id));
  return Status::OK();
}

Status CacheLookupOp::ResetSampler() { return Status::OK(); }

Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) {
  // We act both as a sampler and as a dataset op. During the handshake with the leaf op,
  // we must wait until the leaf op has indexed everything.
  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(op));
  // Now we notify the main thread that the handshake has finished.
  leaf_op_wp_.Set();
  return Status::OK();
}

Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); }

void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); }

Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
  std::vector<row_id_type> cache_miss;
  RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss));
  // Keep popping while we have no cache miss; we can't return empty samples.
  while (cache_miss.empty()) {
    RETURN_IF_NOT_OK(keys_miss_.Pop(0, &cache_miss));
  }
  // Special code for eoe
  if (cache_miss.at(0) == eoe_row_id) {
    *out_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
  } else {
    std::shared_ptr<Tensor> sample_ts;
    RETURN_IF_NOT_OK(CreateSamplerTensor(&sample_ts, cache_miss.size()));
    (*out_buffer) = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagNone);
    auto idPtr = sample_ts->begin<int64_t>();
    for (auto i = 0; i < cache_miss.size(); ++i) {
      *idPtr = cache_miss.at(i);
      ++idPtr;
    }
    TensorRow row;
    row.push_back(sample_ts);
    (*out_buffer)->set_tensor_table(std::make_unique<TensorQTable>(1, row));
  }
  return Status::OK();
}

Status CacheLookupOp::RegisterResources() {
  RETURN_IF_NOT_OK(CacheBase::RegisterResources());
  RETURN_IF_NOT_OK(leaf_op_wp_.Register(tree_->AllTasks()));
  return Status::OK();
}

Status CacheLookupOp::ComputeColMap() {
  // We don't know the column map at this point unless we contact the cache server
  // to fetch the schema, but the cache server may not have it at this point either.
  // So we will just return OK and let MergeOp (our parent) handle it.
  return Status::OK();
}

// Visitor accept method for NodePass
Status CacheLookupOp::Accept(NodePass *p, bool *modified) {
  // Downcast shared pointer then call visitor
  return p->RunOnNode(shared_from_base<CacheLookupOp>(), modified);
}
}  // namespace dataset
}  // namespace mindspore
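Aside: GetNextSample above turns the cache-miss stream into sampler output, with eoe_row_id (-1) used as an in-band epoch marker. Below is a toy consumer of such a stream; plain vectors stand in for the Connector, and the batch contents are made up for illustration.

#include <iostream>
#include <vector>

int main() {
  const long long eoe_row_id = -1;  // same sentinel CacheBase defines
  // A pretend miss stream: two data batches, then the epoch marker.
  std::vector<std::vector<long long>> miss_stream = {{3, 7, 9}, {12}, {eoe_row_id}};
  for (auto &batch : miss_stream) {
    if (!batch.empty() && batch.front() == eoe_row_id) {
      std::cout << "epoch done\n";  // corresponds to emitting an EOE buffer
    } else {
      std::cout << "fetch from leaf: ";
      for (auto id : batch) std::cout << id << " ";
      std::cout << "\n";
    }
  }
}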
mindspore/ccsrc/dataset/engine/datasetops/cache_lookup_op.h
0 → 100644
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_

#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/engine/datasetops/cache_base_op.h"

namespace mindspore {
namespace dataset {
/// \brief Provides a memory/disk cache that acts as a save-point within a mappable dataset.
/// \note For a non-mappable dataset, please see CacheOp.
/// \see CacheOp
class CacheLookupOp : public CacheBase, public Sampler {
 public:
  class Builder {
   public:
    /// \brief Builder constructor. Creates the builder object.
    /// \note No default args
    Builder();

    /// Default destructor
    ~Builder() = default;

    /// Setter method.
    /// \return Builder setter method returns reference to the builder.
    Builder &SetNumWorkers(int32_t num_workers) {
      build_num_workers_ = num_workers;
      return *this;
    }

    /// Setter method.
    /// \return Builder setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t connector_size) {
      build_op_connector_size_ = connector_size;
      return *this;
    }

    /// Setter method.
    /// \return Builder setter method returns reference to the builder.
    Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
      build_cache_client_ = cache_client;
      return *this;
    }

    /// \brief Setter method.
    /// \return Builder setter method returns reference to the builder.
    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
      build_sampler_ = std::move(sampler);
      return *this;
    }

    /// \brief The builder "build" method creates the final object and does some init on it.
    /// \param ptr The shared_ptr to the new CacheLookupOp object
    /// \return Status
    Status Build(std::shared_ptr<CacheLookupOp> *ptr);

   private:
    int32_t build_num_workers_;
    int32_t rows_per_buffer_;
    int32_t build_op_connector_size_;
    std::shared_ptr<CacheClient> build_cache_client_;
    std::shared_ptr<Sampler> build_sampler_;

    // Check if the required parameters are set by the builder.
    // \return Status The error code return
    Status SanityCheck() const;
  };

  /// \brief Constructor
  /// \note It takes the same arguments as the base class.
  /// \see CacheBase
  CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
                std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
      : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {}

  ~CacheLookupOp() = default;

  // As a parallel op, we override these two functions
  Status operator()() override;
  Status WorkerEntry(int32_t worker_id) override;

  // As a sampler, we override the following functions
  Status ResetSampler() override;
  Status HandshakeRandomAccessOp(const RandomAccessOp *op) override;
  Status InitSampler() override;
  Status GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) override;

  void Print(std::ostream &out, bool show_all) const override;

  bool AllowCacheMiss() override { return true; }

  std::string Name() const override { return "CacheLookupOp"; }

  /// \brief Base-class override for NodePass visitor acceptor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
  /// \return Status of the node visit
  Status Accept(NodePass *p, bool *modified) override;

 protected:
  Status ComputeColMap() override;

 private:
  WaitPost leaf_op_wp_;

  Status RegisterResources() override;
};
}  // namespace dataset
}  // namespace mindspore
#endif  // DATASET_ENGINE_DATASETOPS_CACHE_LOOKUP_OP_H_
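Every op in this drop exposes the same nested-builder idiom seen above: chainable setters that return *this, a SanityCheck() gate, and a Build() that allocates the op. The standalone sketch below illustrates just that idiom; Widget, its fields, and the simplified Status type are hypothetical stand-ins, not part of the MindSpore API.

// Standalone sketch of the fluent-builder idiom used by these ops.
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>

struct Status {
  bool ok;
  std::string msg;
  static Status OK() { return {true, ""}; }
};

class Widget {
 public:
  Widget(int32_t num_workers, int32_t connector_size) : num_workers_(num_workers), connector_size_(connector_size) {}

  class Builder {
   public:
    Builder() : num_workers_(4), connector_size_(16) {}  // defaults, like the config manager supplies
    Builder &SetNumWorkers(int32_t n) { num_workers_ = n; return *this; }
    Builder &SetOpConnectorSize(int32_t n) { connector_size_ = n; return *this; }
    // Build validates first (the SanityCheck step) and only then constructs.
    Status Build(std::shared_ptr<Widget> *out) {
      if (num_workers_ <= 0) return {false, "num_workers must be positive"};
      *out = std::make_shared<Widget>(num_workers_, connector_size_);
      return Status::OK();
    }
   private:
    int32_t num_workers_;
    int32_t connector_size_;
  };

  int32_t num_workers() const { return num_workers_; }
 private:
  int32_t num_workers_;
  int32_t connector_size_;
};

int main() {
  std::shared_ptr<Widget> w;
  // Setters chain in a single expression because each returns *this by reference.
  Status rc = Widget::Builder().SetNumWorkers(8).SetOpConnectorSize(32).Build(&w);
  std::cout << (rc.ok ? "built with " + std::to_string(w->num_workers()) + " workers" : rc.msg) << "\n";
  return 0;
}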
mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.cc
new file (mode 100644)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <algorithm>
#include <functional>
#include <iomanip>
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/engine/datasetops/cache_merge_op.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/util/task_manager.h"

namespace mindspore {
namespace dataset {
CacheMergeOp::~CacheMergeOp() = default;

void CacheMergeOp::Print(std::ostream &out, bool show_all) const {
  // Always show the id and name as the first line regardless if this is a summary or detailed print
  out << "(" << std::setw(2) << operator_id_ << ") <CacheMergeOp>:";
  if (!show_all) {
    // Call the super class for displaying any common 1-liner info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal 1-liner info for this op
    out << "\n";
  } else {
    // Call the super class for displaying any common detailed info
    ParallelOp::Print(out, show_all);
    // Then show any custom derived-internal stuff
    out << "\n\n";
  }
}

CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
                           std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler)
    : ParallelOp(numWorkers, opConnectorSize, sampler), num_cleaners_(numCleaners), cache_client_(cache_client) {}

Status CacheMergeOp::operator()() {
  // A queue of row ids to let the cleaners send cache miss rows to the cache server.
  // We don't want a small queue as this will block the parallel op workers.
  // A row id is an 8-byte integer, so a bigger queue doesn't consume a lot of memory.
  io_que_ = std::make_unique<Queue<row_id_type>>(512);
  RETURN_IF_NOT_OK(io_que_->Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::WorkerEntry, this, std::placeholders::_1)));
  RETURN_IF_NOT_OK(
    tree_->LaunchWorkers(num_workers_, std::bind(&CacheMergeOp::CacheMissWorkerEntry, this, std::placeholders::_1)));
  // Dedicated thread(s) to move TensorRow from the pool to the cache server
  for (auto i = 0; i < num_cleaners_; ++i) {
    RETURN_IF_NOT_OK(tree_->AllTasks()->CreateAsyncTask("Cleaner", std::bind(&CacheMergeOp::Cleaner, this)));
  }
  TaskManager::FindMe()->Post();
  return Status::OK();
}

// Each parallel worker will pop from the CacheHit stream. If there is a missing TensorRow, we will wait
// until it shows up in the pool.
Status CacheMergeOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  std::shared_ptr<DatasetOp> cache_hit_stream = child_[kCacheHitChildIdx];
  std::unique_ptr<DataBuffer> db_ptr;
  RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
  while (!db_ptr->eof()) {
    if (db_ptr->eoe()) {
      RETURN_IF_NOT_OK(EoeReceived(worker_id));
      db_ptr.reset();
      RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
    } else {
      // See if there is any missing row
      auto tbl = std::make_unique<TensorQTable>();
      while (db_ptr->NumRows() > 0) {
        TensorRow row;
        RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
        if (row.empty()) {
          auto row_id = row.getId();
          TensorRowRequest *rq = nullptr;
          RETURN_IF_NOT_OK(GetRq(row_id, &rq));
          // Block until the row shows up in the pool.
          RETURN_IF_NOT_OK(rq->Wait(&row));
        }
        tbl->push_back(std::move(row));
      }
      db_ptr->set_tensor_table(std::move(tbl));
      RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr)));
      RETURN_IF_NOT_OK(cache_hit_stream->GetNextBuffer(&db_ptr, worker_id));
    }
  }
  RETURN_IF_NOT_OK(out_connector_->Add(worker_id, std::move(db_ptr)));
  return Status::OK();
}

Status CacheMergeOp::CacheMissWorkerEntry(int32_t workerId) {
  TaskManager::FindMe()->Post();
  // We will simply pop TensorRow from the stream, insert them into the pool and
  // wake up any worker that is awaiting the missing TensorRow.
  // If we see an eoe, ignore it. For eof, we exit.
  std::shared_ptr<DatasetOp> cache_missing_stream = child_[kCacheMissChildIdx];
  // Before we start, cache the schema at the server. Pick one of the workers to
  // do it. The schema should have been done at prepare time.
  if (workerId == 0) {
    RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));
  }
  std::unique_ptr<DataBuffer> db_ptr;
  RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId));
  while (!db_ptr->eof()) {
    if (db_ptr->eoe()) {
      // Ignore it.
      MS_LOG(DEBUG) << "Ignore eoe";
    } else {
      while (db_ptr->NumRows() > 0) {
        TensorRow row;
        RETURN_IF_NOT_OK(db_ptr->PopRow(&row));
        row_id_type row_id = row.getId();
        if (row_id < 0) {
          std::string errMsg = "Expect positive row id: " + std::to_string(row_id);
          RETURN_STATUS_UNEXPECTED(errMsg);
        }
        TensorRowRequest *rq = nullptr;
        RETURN_IF_NOT_OK(GetRq(row_id, &rq));
        rq->WakeUpAny(std::move(row));
        // Let the cleaner flush out this row (async) to the cache server.
        RETURN_IF_NOT_OK(io_que_->EmplaceBack(row_id));
      }
    }
    RETURN_IF_NOT_OK(cache_missing_stream->GetNextBuffer(&db_ptr, workerId));
  }
  return Status::OK();
}

Status CacheMergeOp::Cleaner() {
  TaskManager::FindMe()->Post();
  while (true) {
    row_id_type row_id;
    RETURN_IF_NOT_OK(io_que_->PopFront(&row_id));
    if (row_id < 0) {
      break;
    }
    TensorRowRequest *rq = nullptr;
    RETURN_IF_NOT_OK(GetRq(row_id, &rq));
    if (rq->GetState() == TensorRowRequest::State::kClean) {
      // If already flushed, move on to the next one.
      continue;
    }
    TensorRow row;
    RETURN_IF_NOT_OK(rq->Release(&row));
    CHECK_FAIL_RETURN_UNEXPECTED(!row.empty(), "Programming error");
    Status rc = cache_client_->WriteRow(row);
    // A bad rc should not bring down the pipeline.
    if (rc.IsError()) {
      MS_LOG(WARNING) << "Cache not successful." << rc.ToString();
    }
    rq->SetState(TensorRowRequest::State::kClean);
  }
  return Status::OK();
}

Status CacheMergeOp::GetRq(row_id_type row_id, CacheMergeOp::TensorRowRequest **out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  std::unique_lock<std::mutex> lck(mux_);
  auto it = cache_miss_map_.find(row_id);
  if (it != cache_miss_map_.end()) {
    *out = it->second.GetMutablePointer();
  } else {
    // We will create a new one.
    auto alloc = Services::GetAllocator<TensorRowRequest>();
    auto r = cache_miss_map_.emplace(row_id, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>(alloc));
    if (r.second) {
      auto &mem = r.first->second;
      RETURN_IF_NOT_OK(mem.allocate(1, row_id));
      *out = mem.GetMutablePointer();
    } else {
      RETURN_STATUS_UNEXPECTED("Map insert fail.");
    }
  }
  return Status::OK();
}

Status CacheMergeOp::PrepareNodePostAction() {
  // Run any common code from the super class first before adding our own specific logic
  CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 2, "Incorrect number of children");
  RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
  // Get the computed checksum from all ops in the cache miss class
  uint32_t cache_crc = DatasetOp::GenerateCRC(child_[kCacheMissChildIdx]);
  // This is a mappable cache op, so the row ids are already known and do not need to be generated.
  // Construct the cache.
  const bool generate_ids = false;
  Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
  if (rc.get_code() == StatusCode::kDuplicateKey) {
    // We are told the cache has been created already.
    MS_LOG(INFO) << "Cache created already";
    rc = Status::OK();
  }
  RETURN_IF_NOT_OK(rc);
  return Status::OK();
}

Status CacheMergeOp::ComputeColMap() {
  CHECK_FAIL_RETURN_UNEXPECTED(child_[kCacheMissChildIdx] != nullptr, "Cache miss stream empty");
  if (column_name_id_map().empty()) {
    column_name_id_map_ = child_[kCacheMissChildIdx]->column_name_id_map();
  }
  CHECK_FAIL_RETURN_UNEXPECTED(!column_name_id_map().empty(), "No column map detected");
  return Status::OK();
}

Status CacheMergeOp::TensorRowRequest::Wait(TensorRow *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  // Block until the missing row is in the pool.
  RETURN_IF_NOT_OK(use_count_.P());
  std::unique_lock<std::mutex> lck(dq_mux_);
  CHECK_FAIL_RETURN_UNEXPECTED(!row_.empty(), "Programming error");
  *out = std::move(row_.front());
  row_.pop_front();
  return Status::OK();
}

void CacheMergeOp::TensorRowRequest::WakeUpAny(TensorRow &&row) {
  std::unique_lock<std::mutex> lck(dq_mux_);
  // Technically, the number of times this row shows up in the cache miss stream is equal to the
  // number of P() calls. However, the cleaner wants it too, so we need an extra copy.
  if (GetState() == State::kEmpty) {
    // We will do a deep copy
    for (auto &ts : row) {
      auto out_ts = std::make_shared<Tensor>(ts->shape(), ts->type(), ts->GetBuffer(), ts->SizeInBytes());
      cleaner_copy_.push_back(out_ts);
    }
    cleaner_copy_.setId(row.getId());
    // Change the state to dirty
    SetState(State::kDirty);
  }
  row_.push_back(std::move(row));
  // Bump up the use count by 1. This wakes up any parallel worker which is waiting
  // for this row.
  use_count_.V();
}

Status CacheMergeOp::TensorRowRequest::Release(TensorRow *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  // We are not holding any mutex here because the cleaner isn't really touching the deque row_.
  // In case we have multiple cleaners and they all see the copy, only one of them will get it.
  auto expected = State::kDirty;
  if (st_.compare_exchange_strong(expected, State::kClean)) {
    *out = std::move(cleaner_copy_);
  }
  return Status::OK();
}

// Builder constructor. Creates the builder object.
CacheMergeOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  build_num_workers_ = cfg->num_parallel_workers();
  build_op_connector_size_ = cfg->op_connector_size();
  build_num_cleaners_ = 1;
}

// Check if the required parameters are set by the builder.
Status CacheMergeOp::Builder::SanityCheck() const {
  if (build_cache_client_ == nullptr) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheMergeOp requires a CacheClient");
  }
  // Make sure the cache client has a valid session
  if (!build_cache_client_->session_id()) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
                  "Cache client for CacheMergeOp is missing session id");
  }
  return Status::OK();
}

// The builder "build" method creates the final object and does some init on it
Status CacheMergeOp::Builder::Build(std::shared_ptr<CacheMergeOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<CacheMergeOp>(build_num_workers_, build_op_connector_size_, build_num_cleaners_,
                                        build_cache_client_, build_sampler_);
  return Status::OK();
}

// Pre-Visitor accept method for NodePass
Status CacheMergeOp::PreAccept(NodePass *p, bool *modified) {
  // Downcast shared pointer then call the pre-visitation
  return p->PreRunOnNode(shared_from_base<CacheMergeOp>(), modified);
}

// Visitor accept method for NodePass
Status CacheMergeOp::Accept(NodePass *p, bool *modified) {
  // Downcast shared pointer then call visitor
  return p->RunOnNode(shared_from_base<CacheMergeOp>(), modified);
}

Status CacheMergeOp::EoeReceived(int32_t worker_id) {
  // If we are in a repeat path, send the eoe up. Otherwise ignore it.
  if (BitTest(op_ctrl_flags_, kDeOpRepeated)) {
    return DatasetOp::EoeReceived(worker_id);
  }
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
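The TensorRowRequest handshake above boils down to a counting-semaphore exchange: a merge worker parks in Wait() on use_count_.P() until a cache-miss worker deposits the row and signals via V(). Below is a minimal standalone sketch of that exchange, assuming MindSpore's Semaphore behaves like the classic P/V counter shown; the string payload stands in for a TensorRow.

// Standalone sketch of the Wait()/WakeUpAny() handshake in TensorRowRequest.
#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>

class Semaphore {
 public:
  explicit Semaphore(int initial) : count_(initial) {}
  void P() {  // wait
    std::unique_lock<std::mutex> lk(m_);
    cv_.wait(lk, [this] { return count_ > 0; });
    --count_;
  }
  void V() {  // signal
    std::lock_guard<std::mutex> lk(m_);
    ++count_;
    cv_.notify_one();
  }
 private:
  std::mutex m_;
  std::condition_variable cv_;
  int count_;
};

class RowRequest {
 public:
  RowRequest() : use_count_(0) {}
  // Consumer side: block until a row is available, then take it.
  std::string Wait() {
    use_count_.P();
    std::lock_guard<std::mutex> lk(dq_mux_);
    std::string out = std::move(row_.front());
    row_.pop_front();
    return out;
  }
  // Producer side: deposit a row and wake exactly one waiter.
  void WakeUpAny(std::string &&row) {
    {
      std::lock_guard<std::mutex> lk(dq_mux_);
      row_.push_back(std::move(row));
    }
    use_count_.V();
  }
 private:
  std::mutex dq_mux_;
  Semaphore use_count_;
  std::deque<std::string> row_;
};

int main() {
  RowRequest rq;
  std::thread consumer([&] { std::cout << "got: " << rq.Wait() << "\n"; });
  std::thread producer([&] { rq.WakeUpAny("row 42"); });
  producer.join();
  consumer.join();
  return 0;
}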
mindspore/ccsrc/dataset/engine/datasetops/cache_merge_op.h
new file (mode 100644)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_

#include <atomic>
#include <deque>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include "dataset/core/tensor_row.h"
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/datasetops/parallel_op.h"
#include "dataset/engine/dataset_iterator.h"
#include "dataset/util/queue.h"
#include "dataset/util/semaphore.h"

namespace mindspore {
namespace dataset {
/// \brief Provides a method to merge two streams (one from CacheLookup and one from the cache miss stream) into
/// one single stream.
class CacheMergeOp : public ParallelOp {
 public:
  // Some handshake structures among the main thread, cleaner threads and parallel op threads.
  class TensorRowRequest {
   public:
    enum class State : uint8_t {
      kEmpty = 0,  // No row in the deque
      kDirty = 1,  // Cleaner hasn't flushed it to the cache server yet.
      kClean = 2   // The row has been flushed already.
    };
    explicit TensorRowRequest(row_id_type id) : st_(State::kEmpty), use_count_(0) {}
    ~TensorRowRequest() = default;
    State GetState() const { return st_; }
    void SetState(State newState) { st_ = newState; }
    Status Wait(TensorRow *out);
    void WakeUpAny(TensorRow &&row);
    Status Release(TensorRow *out);

   private:
    std::mutex dq_mux_;
    std::atomic<State> st_;
    Semaphore use_count_;
    std::deque<TensorRow> row_;
    TensorRow cleaner_copy_;
  };

  constexpr static int kCacheHitChildIdx = 0;   // Cache hit stream
  constexpr static int kCacheMissChildIdx = 1;  // Cache miss stream

  /// \brief The nested builder class inside of the CacheMergeOp is used to help manage all of
  /// the arguments for constructing it. Use the builder by setting each argument
  /// with the provided set methods, and then finally call the build method to execute
  /// the actual construction.
  class Builder {
   public:
    /// Builder constructor. Creates the builder object.
    /// \note No default args
    Builder();

    /// Default destructor
    ~Builder() = default;

    /// Setter method.
    /// \return Builder setter method returns reference to the builder.
    Builder &SetNumWorkers(int32_t num_workers) {
      build_num_workers_ = num_workers;
      return *this;
    }

    /// Setter method.
    /// \return Builder setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t connector_size) {
      build_op_connector_size_ = connector_size;
      return *this;
    }

    /// Setter method.
    /// \return Builder setter method returns reference to the builder.
    Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
      build_cache_client_ = cache_client;
      return *this;
    }

    /// \brief Setter method
    /// \param sampler
    /// \return Builder setter method returns reference to the builder.
    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
      build_sampler_ = std::move(sampler);
      return *this;
    }

    /// \brief Setter method
    /// \param num_cleaners
    /// \return Builder setter method returns reference to the builder.
    Builder &SetNumCleaner(int32_t num_cleaners) {
      build_num_cleaners_ = num_cleaners;
      return *this;
    }

    /// The builder "build" method creates the final object and does some init on it.
    /// \param ptr The shared_ptr to the new CacheMergeOp object
    /// \return Status
    Status Build(std::shared_ptr<CacheMergeOp> *ptr);

   private:
    int32_t build_num_workers_;
    int32_t build_op_connector_size_;
    int32_t build_num_cleaners_;
    std::shared_ptr<CacheClient> build_cache_client_;
    std::shared_ptr<Sampler> build_sampler_;

    /// Check if the required parameters are set by the builder.
    /// \return Status The error code return
    Status SanityCheck() const;
  };

  /// \brief Constructor
  /// \param numWorkers Number of parallel workers as a derived class of ParallelOp
  /// \param opConnectorSize Connector size as a derived class of ParallelOp
  /// \param numCleaners Number of cleaners to move cache miss rows into the cache server
  /// \param cache_client CacheClient to communicate with the cache server
  /// \param sampler Sampler as a derived class of ParallelOp
  CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
               std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler);
  ~CacheMergeOp();
  void Print(std::ostream &out, bool show_all) const override;
  friend std::ostream &operator<<(std::ostream &out, const CacheMergeOp &mo) {
    mo.Print(out, false);
    return out;
  }
  /// \brief Master thread responsible to spawn all the necessary worker threads for the two streams and
  /// the threads for the cleaners.
  /// \return Status object
  Status operator()() override;
  /// \brief Entry function for worker threads that fetch rows from CacheLookupOp
  /// \param workerId
  /// \return Status object
  Status WorkerEntry(int32_t workerId) override;
  Status PrepareNodePostAction() override;
  /// \brief Entry function for worker threads that fetch rows from the cache miss stream
  /// \param workerId
  /// \return Status object
  Status CacheMissWorkerEntry(int32_t workerId);
  Status GetRq(row_id_type row_id, TensorRowRequest **out);
  /// \brief Base-class override for NodePass pre-visit acceptor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
  /// \return Status of the node visit
  Status PreAccept(NodePass *p, bool *modified) override;
  /// \brief Base-class override for NodePass visitor acceptor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
  /// \return Status of the node visit
  Status Accept(NodePass *p, bool *modified) override;
  /// \brief Base-class override for eoe handling
  /// \param worker_id
  /// \return Status object
  Status EoeReceived(int32_t worker_id) override;

 protected:
  Status ComputeColMap() override;

 private:
  std::mutex mux_;
  std::map<row_id_type, MemGuard<TensorRowRequest, Allocator<TensorRowRequest>>> cache_miss_map_;
  std::unique_ptr<Queue<row_id_type>> io_que_;
  std::shared_ptr<CacheClient> cache_client_;
  int32_t num_cleaners_;

  /// \brief This is the entry function for the cleaner threads. Each cleaner is responsible for
  /// moving cache miss TensorRow into the CacheServer.
  /// \return Status object
  Status Cleaner();
};
}  // namespace dataset
}  // namespace mindspore
#endif  // DATASET_ENGINE_DATASETOPS_CACHE_MERGE_OP_H_
mindspore/ccsrc/dataset/engine/datasetops/cache_op.cc
new file (mode 100644)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dataset/engine/datasetops/cache_op.h"

#include <memory>
#include <vector>
#include "dataset/core/config_manager.h"
#include "dataset/core/constants.h"
#include "dataset/core/global_context.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/data_buffer.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/opt/pass.h"
#include "dataset/util/task_manager.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace dataset {
// Builder constructor. Creates the builder object.
CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
  std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
  build_num_workers_ = cfg->num_parallel_workers();
  rows_per_buffer_ = cfg->rows_per_buffer();
  build_op_connector_size_ = cfg->op_connector_size();
}

// Check if the required parameters are set by the builder.
Status CacheOp::Builder::SanityCheck() const {
  if (build_cache_client_ == nullptr) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "CacheOp requires a CacheClient");
  }
  // Make sure the cache client has a valid session
  if (!build_cache_client_->session_id()) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Cache client for CacheOp is missing session id");
  }
  return Status::OK();
}

// The builder "build" method creates the final object and does some init on it
Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
  RETURN_IF_NOT_OK(SanityCheck());
  *ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_,
                                   build_cache_client_, build_sampler_);
  RETURN_IF_NOT_OK((*ptr)->InitCache());
  return Status::OK();
}

// Constructor of CacheOp
CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
                 std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
    : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler),
      num_guys_in_(0),
      phase_(Phase::kBuildPhase) {}

// Destructor
CacheOp::~CacheOp() = default;

// Private function for cache setup/init work just after construction
Status CacheOp::InitCache() { return Status::OK(); }

// This class functor will provide the master loop that drives the logic for performing the work
Status CacheOp::operator()() {
  if (!sampler_) {
    return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__,
                  "CacheOp requires a sampler before it can be executed!");
  }
  RETURN_IF_NOT_OK(RegisterResources());
  // Kick off the workers
  RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&CacheOp::WorkerEntry, this, std::placeholders::_1)));
  // required task group sync after launching workers
  TaskManager::FindMe()->Post();
  // Wait for the workers to finish caching the rows.
  RETURN_IF_NOT_OK(WaitForCachingAllRows());
  RETURN_IF_NOT_OK(FetchSamplesToWorkers());
  return Status::OK();
}

Status CacheOp::CacheAllRows(int32_t worker_id) {
  // If the current phase is to fill the cache, do it then.
  if (phase_ == Phase::kBuildPhase) {
    // We will take the chance to cache the schema at the server.
    // Just do it once and pick one worker to do it.
    if (worker_id == 0) {
      RETURN_IF_NOT_OK(cache_client_->CacheSchema(column_name_id_map()));
    }
    MS_LOG(INFO) << "CacheOp first epoch SAVE mode started. Worker: " << worker_id;
    // SAVE mode loop
    std::unique_ptr<DataBuffer> db_ptr;
    RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
    while (!db_ptr->eof()) {
      if (!db_ptr->eoe()) {
        RETURN_IF_NOT_OK(cache_client_->WriteBuffer(std::move(db_ptr)));
      } else {
        // In a repeat-over-cache scenario, any of the "real" leaf operators below us have been set up
        // as non-repeating leaf ops. As such, they only do one epoch and then quit. Since we got
        // the eoe to indicate the end of the epoch, we should next expect to get the eof.
        // Drain this eof so that we don't leave it sitting there on a connector that we'll never
        // fetch from again.
        RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
        if (!db_ptr->eof()) {
          RETURN_STATUS_UNEXPECTED("Cache op expects to get an eof after eoe from child.");
        }
      }
      RETURN_IF_NOT_OK(this->GetNextInput(&db_ptr, worker_id, 0));
    }
  }
  // Let the main guy know we are done.
  auto last_guy_in = num_guys_in_.fetch_add(1);
  if ((last_guy_in + 1) == num_workers_) {
    rows_cache_done_.Set();
  } else {
    // Let's do a sync up here.
    RETURN_IF_NOT_OK(rows_cache_done_.Wait());
  }
  return Status::OK();
}

Status CacheOp::WaitForCachingAllRows() {
  // Wait for the workers to finish caching the rows.
  RETURN_IF_NOT_OK(rows_cache_done_.Wait());
  // Move from build phase to fetch phase if we are the one to fill the cache
  if (phase_ == Phase::kBuildPhase) {
    RETURN_IF_NOT_OK(cache_client_->BuildPhaseDone());
    // Move to the next phase
    phase_ = Phase::kFetchPhase;
  }
  // Get statistics from the server, and if we are not the one to create the cache,
  // wait until the state changes from build phase to fetch phase.
  CacheClient::ServiceStat stat{};
  bool BuildPhaseDone = true;
  do {
    RETURN_IF_NOT_OK(cache_client_->GetStat(&stat));
    BuildPhaseDone = stat.cache_service_state == static_cast<uint8_t>(CacheService::State::kFetchPhase);
    if (!BuildPhaseDone) {
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }
  } while (!BuildPhaseDone);
  const row_id_type min_key = stat.min_row_id;
  const row_id_type max_key = stat.max_row_id;
  num_rows_ = max_key - min_key + 1;
  MS_LOG(INFO) << "Number of rows cached: " << num_rows_;
  MS_LOG(INFO) << "Number of rows cached in memory : " << stat.num_mem_cached;
  MS_LOG(INFO) << "Number of rows spilled to disk : " << stat.num_disk_cached;
  // Now all rows are cached and we have done a sync-point checkup. The next phase is to
  // pick up fetch input from the sampler and pass it up to the caller.
  RETURN_IF_NOT_OK(sampler_->HandshakeRandomAccessOp(this));
  return Status::OK();
}

Status CacheOp::WorkerEntry(int32_t worker_id) {
  TaskManager::FindMe()->Post();
  RETURN_IF_NOT_OK(CacheAllRows(worker_id));
  RETURN_IF_NOT_OK(FetchFromCache(worker_id));
  return Status::OK();
}

Status CacheOp::RegisterResources() {
  RETURN_IF_NOT_OK(CacheBase::RegisterResources());
  RETURN_IF_NOT_OK(rows_cache_done_.Register(tree_->AllTasks()));
  RETURN_IF_NOT_OK(keys_miss_.Register(tree_->AllTasks()));
  return Status::OK();
}

// Base-class override for setting specific CacheOp configurations. This code will be called
// during the execution tree prepare phase BEFORE traversing down to child operators.
uint32_t CacheOp::PrepareFlags() const { return ExecutionTree::kDePrepCache; }

// Base-class override for special eoe handler.
// CacheOp must override this because it shall not perform default handling of eoe. Instead
// the CacheOp manages actions related to the end of the epoch.
Status CacheOp::EoeReceived(int32_t worker_id) {
  state_ = OpState::kDeOpIdle;
  return Status::OK();
}

// Base-class override for handling cases when an eof is received.
Status CacheOp::EofReceived(int32_t worker_id) {
  // EofReceived is overloaded because we want to manually handle this eof.
  // Specifically, the default behaviour is to pack it and flow it up to the next connection.
  // In this case, we want a no-op behaviour so that we can perform the correct action.
  return Status::OK();
}

// Pre-Visitor accept method for NodePass
Status CacheOp::PreAccept(NodePass *p, bool *modified) {
  // Downcast shared pointer then call the pre-visitation
  return p->PreRunOnNode(shared_from_base<CacheOp>(), modified);
}

// Visitor accept method for NodePass
Status CacheOp::Accept(NodePass *p, bool *modified) {
  // Downcast shared pointer then call visitor
  return p->RunOnNode(shared_from_base<CacheOp>(), modified);
}

// A public wrapper for creating the cache through the client
Status CacheOp::CreateCache(uint32_t cache_crc) {
  // This is a non-mappable cache op, so the row ids need to be generated.
  // Construct the cache.
  const bool generate_ids = true;
  Status rc = cache_client_->CreateCache(cache_crc, generate_ids);
  if (rc.get_code() == StatusCode::kDuplicateKey) {
    // We are told the cache has been created already, so we skip the build phase.
    phase_ = Phase::kFetchPhase;
    rc = Status::OK();
  }
  RETURN_IF_NOT_OK(rc);
  return Status::OK();
}
}  // namespace dataset
}  // namespace mindspore
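CacheAllRows() and WaitForCachingAllRows() together implement a barrier: each worker bumps num_guys_in_, the last one in flips the phase, and everyone polls until the server reports the fetch phase. The following is a standalone sketch of that synchronization pattern; all names are illustrative and an atomic flag stands in for the server-side GetStat() poll.

// Standalone sketch of CacheOp's two-phase sync (build phase -> fetch phase).
#include <atomic>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <thread>
#include <vector>

enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 };

int main() {
  const int num_workers = 4;
  std::atomic<int> num_guys_in{0};
  std::atomic<Phase> phase{Phase::kBuildPhase};

  auto worker = [&](int id) {
    // ... pretend to cache rows here ...
    if (num_guys_in.fetch_add(1) + 1 == num_workers) {
      phase.store(Phase::kFetchPhase);  // the last worker in marks the build phase done
      std::cout << "worker " << id << " marked build phase done\n";
    }
    // Everyone waits until the phase has changed, mirroring the 100 ms
    // GetStat() polling loop in WaitForCachingAllRows().
    while (phase.load() != Phase::kFetchPhase) {
      std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }
    // fetch phase starts here
  };

  std::vector<std::thread> pool;
  for (int i = 0; i < num_workers; ++i) pool.emplace_back(worker, i);
  for (auto &t : pool) t.join();
  std::cout << "all workers in fetch phase\n";
  return 0;
}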
mindspore/ccsrc/dataset/engine/datasetops/cache_op.h
new file (mode 100644)
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
#define DATASET_ENGINE_DATASETOPS_CACHE_OP_H_

#include <atomic>
#include <string>
#include <utility>
#include <memory>
#include "dataset/engine/datasetops/cache_base_op.h"

namespace mindspore {
namespace dataset {
/// \brief CacheOp provides a memory/disk cache that acts as a save-point within a non-mappable dataset.
/// \note For a mappable dataset, please see CacheLookupOp.
/// \see CacheLookupOp
class CacheOp : public CacheBase, public RandomAccessOp {
 public:
  // This CacheOp is for the non-mappable case, where the work is divided into two phases.
  // In the first phase we cache all the rows from the child (and let the cache server
  // assign row ids). There is no read access in the first phase. Once the cache is fully
  // built, we switch to the second phase and fetch requests from the sampler.
  enum class Phase : uint8_t { kBuildPhase = 0, kFetchPhase = 1 };

  /// \brief The nested builder class inside of the CacheOp is used to help manage all of
  /// the arguments for constructing it. Use the builder by setting each argument
  /// with the provided set methods, and then finally call the build method to execute
  /// the actual construction.
  class Builder {
   public:
    // Builder constructor. Creates the builder object.
    // @note No default args
    Builder();

    // Default destructor
    ~Builder() = default;

    /// \brief Setter method.
    /// \return Builder setter method returns reference to the builder.
    Builder &SetNumWorkers(int32_t num_workers) {
      build_num_workers_ = num_workers;
      return *this;
    }

    /// \brief Setter method.
    /// \return Builder setter method returns reference to the builder.
    Builder &SetOpConnectorSize(int32_t connector_size) {
      build_op_connector_size_ = connector_size;
      return *this;
    }

    /// Setter method.
    /// \return Builder setter method returns reference to the builder.
    Builder &SetClient(std::shared_ptr<CacheClient> cache_client) {
      build_cache_client_ = cache_client;
      return *this;
    }

    /// \brief Setter method
    /// \param rows_per_buffer
    /// \return Builder setter method returns reference to the builder.
    Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
      rows_per_buffer_ = rows_per_buffer;
      return *this;
    }

    /// \brief Setter method
    /// \param sampler
    /// \return Builder setter method returns reference to the builder.
    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
      build_sampler_ = std::move(sampler);
      return *this;
    }

    /// \brief The builder "build" method creates the final object and does some init on it.
    /// \param ptr The shared_ptr to the new CacheOp object
    /// \return Status
    Status Build(std::shared_ptr<CacheOp> *ptr);

   private:
    int32_t build_num_workers_;
    int32_t rows_per_buffer_;
    int32_t build_op_connector_size_;
    std::shared_ptr<CacheClient> build_cache_client_;
    std::shared_ptr<Sampler> build_sampler_;

    /// \brief Check if the required parameters are set by the builder.
    /// \return Status The error code return
    Status SanityCheck() const;
  };

  /// \brief Constructor of CacheOp
  /// \note The builder class should be used to call it.
  /// \param num_workers The number of worker threads.
  /// \param op_connector_size The size of each queue in the connector.
  CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
          std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);

  // Destructor
  ~CacheOp();

  /// \brief Base-class override for setting specific CacheOp configurations. This code will be called
  /// during the execution tree prepare phase BEFORE traversing down to child operators.
  uint32_t PrepareFlags() const override;
  /// \brief Base-class override for special eoe handler.
  /// CacheOp must override this because it shall not perform default handling of eoe. Instead
  /// the CacheOp manages actions related to the end of the epoch.
  /// \return Status - The error code return
  Status EoeReceived(int32_t worker_id) override;
  /// \brief Base-class override for NodePass pre-visit acceptor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
  /// \return Status of the node visit
  Status PreAccept(NodePass *p, bool *modified) override;
  /// \brief Base-class override for NodePass visitor acceptor
  /// \param[in] p The node to visit
  /// \param[out] modified Indicator if the node was modified
  /// \return Status of the node visit
  Status Accept(NodePass *p, bool *modified) override;
  /// \brief Base-class override for handling cases when an eof is received.
  /// \param worker_id - The worker id
  /// \return Status - The error code return
  Status EofReceived(int32_t worker_id) override;
  Status operator()() override;
  Status WorkerEntry(int32_t worker_id) override;
  /// \brief Base-class override for handling cases if we allow cache miss
  bool AllowCacheMiss() override { return false; }
  /// \brief Base-class override for the name of this operator
  std::string Name() const override { return "CacheOp"; }
  /// \brief A public wrapper for creating the cache through the client
  /// \param[in] cache_crc The crc that identifies the cache
  /// \see cache_pass.cc
  /// \return Status return code
  Status CreateCache(uint32_t cache_crc);

 private:
  WaitPost rows_cache_done_;
  std::atomic<int64_t> num_guys_in_;
  Phase phase_;

  /// \brief The main thread will wait until all the rows are cached and will start the handshake with the sampler.
  /// \return Status object
  Status WaitForCachingAllRows();
  /// \brief For a non-mappable dataset, there is a build phase where we cache all the rows.
  /// \return Status object
  Status CacheAllRows(int32_t worker_id);
  Status RegisterResources() override;
  /// \brief Private function for cache setup/init work just after construction
  /// \return Status The error code return
  Status InitCache();
};
}  // namespace dataset
}  // namespace mindspore
#endif  // DATASET_ENGINE_DATASETOPS_CACHE_OP_H_
mindspore/ccsrc/dataset/engine/datasetops/concat_op.cc
@@ -61,46 +61,39 @@ void ConcatOp::Print(std::ostream &out, bool show_all) const {
 Status ConcatOp::operator()() {
   // The children_num_ parameter needs to be put here
   children_num_ = static_cast<int32_t>(child_.size());
   TaskManager::FindMe()->Post();
   std::unique_ptr<DataBuffer> buf;
-  RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf));
   int eof_count = 0;
-  while (eof_count != children_num_) {
+  while (eof_count == 0) {
     for (int i = 0; i < children_num_; i++) {
-      // 1. Throw the eof buffer when meet it
-      if (buf->eof() || buf->eoe()) {
-        RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
-      }
+      // 1. Read the first buffer
+      RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
+      if (buf->eof()) {
+        eof_count++;
+        continue;
+      }
       // 2. Do verification as for column name, column data type and rank of column data
-      RETURN_IF_NOT_OK(Verify(i, buf));
+      if (!buf->eoe()) {
+        RETURN_IF_NOT_OK(Verify(i, buf));
+      }
       // 3. Put the data into output_connector
       while (!buf->eoe() && !buf->eof()) {
         RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buf)));
         RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
       }
-      // 4. Throw the eoe buffer when meet it
-      if (buf->eoe() && (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat))) {
-        RETURN_IF_NOT_OK(child_[i]->GetNextBuffer(&buf));
-      }
-      // 5. Add eoe buffer after get buffer from all child
-      if (i == (children_num_ - 1)) {
-        auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
-        RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
-      }
-      if (buf->eof()) {
-        eof_count++;
-      }
     }
+    // 4. Add an eoe buffer after getting a buffer from every child
+    if (eof_count == 0) {
+      auto eoe_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOE);
+      RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer)));
+    }
   }
-  // 6. Add eof buffer in the end manually
+  CHECK_FAIL_RETURN_UNEXPECTED(eof_count == children_num_,
+                               "Something went wrong, eof count does not match the number of children.");
+  // 5. Add eof buffer in the end manually
   MS_LOG(DEBUG) << "Add the eof buffer manually in the end.";
   auto eof_buffer = std::make_unique<DataBuffer>(0, DataBuffer::kDeBFlagEOF);
   RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer)));
   return Status::OK();
 }
@@ -126,12 +119,6 @@ Status ConcatOp::Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf) {
   return Status::OK();
 }

-Status ConcatOp::PrepareNodePostAction() {
-  RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
-  tree_->AddToEOEOpStack(shared_from_this());
-  return Status::OK();
-}
-
 // We need to overwrite the super class ComputeColMap here because the number of children is more than 1.
 Status ConcatOp::ComputeColMap() {
   if (column_name_id_map_.empty()) {
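The reworked loop keys everything off a single eof counter: each pass reads every child in order, forwards rows until that child's eoe, emits one eoe per full pass, and stops once the children start reporting eof. Below is a standalone sketch of that control flow, modeling children as queues of strings with "EOE"/"EOF" sentinels in place of the buffer flags.

// Standalone sketch of the reworked concat loop.
#include <deque>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Two children, each holding one epoch of data followed by EOE then EOF.
  std::vector<std::deque<std::string>> children = {
    {"a1", "a2", "EOE", "EOF"},
    {"b1", "EOE", "EOF"},
  };
  const int children_num = static_cast<int>(children.size());
  int eof_count = 0;
  while (eof_count == 0) {
    for (int i = 0; i < children_num; i++) {
      // 1. Read the first buffer from child i
      std::string buf = children[i].front();
      children[i].pop_front();
      if (buf == "EOF") {
        eof_count++;
        continue;
      }
      // 3. Forward data until this child's EOE/EOF
      while (buf != "EOE" && buf != "EOF") {
        std::cout << buf << "\n";
        buf = children[i].front();
        children[i].pop_front();
      }
    }
    // 4. One EOE after a full pass over all children
    if (eof_count == 0) std::cout << "EOE\n";
  }
  // 5. One EOF at the very end; by now eof_count == children_num, as the real op asserts.
  std::cout << "EOF\n";
  return 0;
}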
mindspore/ccsrc/dataset/engine/datasetops/concat_op.h
@@ -75,12 +75,6 @@ class ConcatOp : public PipelineOp {
   // @return Status - The error code return
   Status operator()() override;

-  // During tree prepare phase, operators may have specific post-operations to perform depending on
-  // their role.
-  // @notes Derived versions of this function should always call its superclass version first
-  // before providing their own implementations.
-  Status PrepareNodePostAction() override;
-
   // Op name getter
   // @return Name of the current Op
   std::string Name() const override { return "ConcatOp"; }
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.cc
@@ -153,16 +153,38 @@ Status DatasetOp::Remove() {
     }
   }

   // Finally, clear "this" op's parent and child pointers since we have just
   // disconnected it from the tree and invalidate it's fields.
   child_.clear();
   parent_.clear();
   operator_id_ = kInvalidOperatorId;
   tree_ = nullptr;

   return Status::OK();
 }

-// Adds a operator to become our child.
+// Getter function to get a shared pointer to our child
+std::shared_ptr<DatasetOp> DatasetOp::child(int32_t child_index) const {
+  std::shared_ptr<DatasetOp> return_op = nullptr;
+  if (child_.empty()) {
+    return return_op;
+  }
+  MS_ASSERT(child_index < static_cast<int>(child_.size()));
+  // Return a shared pointer
+  return child_[child_index];
+}
+
+// Getter function to get the parent pointer
+void DatasetOp::Parent(DatasetOp **parent, int32_t parent_index) const {
+  if (parent_.empty()) {
+    // common case if this is a root node
+    *parent = nullptr;
+  } else {
+    MS_ASSERT(parent_index < static_cast<int>(parent_.size()));
+    *parent = parent_[parent_index];
+  }
+}
+
 // Creates the connector within this operator
 void DatasetOp::CreateConnector(int32_t num_producers, int32_t num_consumers) {
   MS_LOG(DEBUG) << "Creating connector in tree operator: " << operator_id_ << ". Producer: " << num_producers
@@ -264,19 +286,11 @@ Status DatasetOp::EofReceived(int32_t worker_id) {
 // During tree prepare phase, operators may have specific pre-operations to perform depending on
 // their role.
-Status DatasetOp::PrepareNodePreAction() {
-  if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepRepeat)) set_control_flag(kDeOpRepeated);
-  return Status::OK();
-}
+Status DatasetOp::PrepareNodePreAction() { return Status::OK(); }

 // During tree prepare phase, operators may have specific post-operations to perform depending on
 // their role.
 Status DatasetOp::PrepareNodePostAction() {
-  // If this op does not have any children and it is in a repeat path of the tree...
-  if (child_.empty() && BitTest(op_ctrl_flags_, kDeOpRepeated)) {
-    // push ourselves onto the eoe operator stack. Later, a repeat/epoch ctrl operator
-    // above us will consume them.
-    tree_->AddToEOEOpStack(shared_from_this());
-  }
   // Creating Connector object for each op.
   // The consumer of the root node is assumed to be one thread.
   // If multiple threads are consuming from the root node, they will get the ordered data in round robin fashion.
@@ -346,34 +360,13 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) {
   return p->RunOnNode(shared_from_this(), modified);
 }

-// A helper function with some common code that leaf nodes can use during
-// prepare phase for checking if they need to assign a sampler to the cache.
-Status DatasetOp::SaveSamplerForCache(bool random_access_op) {
-  // If we are a descendant under a cache op and we have a sampler, then save this sampler
-  // to a stack so that the cache can pick it up during it's processing above us.
-  if (sampler_) {
-    if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
-      // use move semantic to set our sampler_ to null after the move. This is okay because a sampler is
-      // useless to a random data op. It was only being used as a temporary holding until the cache can
-      // be created
-      tree_->AddToSamplerStack(sampler_);
-      MS_LOG(INFO) << "Preparing a leaf op: passing sampler up the tree for Cache handling.";
-    } else if (!random_access_op) {
-      // A sampler exists, but we are not in a caching tree and we are not a random access mappable leaf.
-      // This is an error because that type of leaf does not use sampling unless there's a cache to hook it into.
-      RETURN_STATUS_UNEXPECTED(
-        "Non-mappable leaf op has a sampler, but it only supports sampling if there is a cache after it in the tree");
-    }
-  }
-  if (!random_access_op) {
-    // Since we don't truly need the sampler for this non-mappable dataset and it's been saved for the cache
-    // we can remove it now from the base.
-    sampler_.reset();
-  }
+// Getter for the sampler, and it also removes the sampler from the op
+Status DatasetOp::FetchRemoveSampler(std::shared_ptr<Sampler> *sampler) {
+  *sampler = sampler_;  // It's okay if sampler_ points to nullptr
+  sampler_.reset();     // clear our member-copy of this pointer. We no longer have this sampler
   return Status::OK();
 }

 uint32_t DatasetOp::GenerateCRC(const std::shared_ptr<DatasetOp> &op) {
   std::stringstream ss;
   op->tree_->Print(ss, op);
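GenerateCRC (kept as context above) prints the subtree into a stringstream and derives a checksum from it, which is how the cache ops decide whether an identical pipeline already has a cache on the server: same structure, same id. Here is a standalone sketch of the idea, with a textbook CRC-32 standing in for whatever checksum routine the util library actually provides.

// Standalone sketch: checksum a serialized pipeline description as a cache identity.
#include <cstdint>
#include <iostream>
#include <string>

uint32_t Crc32(const std::string &data) {
  uint32_t crc = 0xFFFFFFFFu;
  for (unsigned char byte : data) {
    crc ^= byte;
    for (int k = 0; k < 8; ++k) {
      // (0u - bit) is 0x00000000 or 0xFFFFFFFF, masking in the polynomial only when the low bit is set.
      crc = (crc >> 1) ^ (0xEDB88320u & (0u - (crc & 1u)));
    }
  }
  return ~crc;
}

int main() {
  // Two textually identical pipelines hash to the same id; any difference changes it.
  std::string pipeline_a = "TFReaderOp(files=[a,b]) -> MapOp(decode)";
  std::string pipeline_b = "TFReaderOp(files=[a,b]) -> MapOp(decode)";
  std::string pipeline_c = "TFReaderOp(files=[a,c]) -> MapOp(decode)";
  std::cout << std::hex << Crc32(pipeline_a) << " " << Crc32(pipeline_b) << " " << Crc32(pipeline_c) << "\n";
  return 0;
}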
mindspore/ccsrc/dataset/engine/datasetops/dataset_op.h
@@ -45,10 +45,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
  public:
   static constexpr int32_t kInvalidOperatorId = -1;

-  // Flags that control operator runtime behaviours
+  // Operator control flags
   enum OpControlFlags {
     kDeOpNone = 0,
-    kDeOpRepeated = 1,        // Operator is a leaf node in a repeat path
+    kDeOpRepeated = 1,        // Operator is a node in a repeat path
     kDeOpLastRepeat = 1 << 1  // We are in the last repeat loop
   };
@@ -71,17 +71,23 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   /// \param child - shared pointer to the child to remove.
   Status RemoveChild(std::shared_ptr<DatasetOp> child);

-  /// \brief Removes this node from the tree and connects it's parent/child together.
+  /// \brief Removes this node from the tree and connects it's parent/child together
   /// \return Status error code returned
   Status Remove();

   /// \brief Getter function to get a shared pointer to our child
-  /// \param child_index - An operator can have n children. Indicates choose which child to return.
+  /// \param[in] child_index An operator can have n children. Indicates which child to return.
   /// \return The shared pointer to the child. If there are no children, it returns null regardless of the given index
   std::shared_ptr<DatasetOp> child(int32_t child_index) const;

-  /// \brief Inserts a operator as the parent current op.
-  ///     Inserted op will become the sole parent of the current op.
-  ///     The existing parent of the current op will be transferred to the inserted op.
+  /// \brief Getter function to get the pointer to our parent
+  ///     If there are no parents, it returns null regardless of the given index
+  /// \param[in] parent_index An operator can have n parents. Indicates which parent to return.
+  void Parent(DatasetOp **parent, int32_t parent_index) const;
+
+  // Inserts a operator as the parent current op.
+  // Inserted op will become the sole parent of the current op.
+  // The existing parent of the current op will be transferred to the inserted op.
   Status InsertAsParent(std::shared_ptr<DatasetOp> to_add);

   /// \brief Creates the connector within this operator
@@ -161,16 +167,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   /// \return Status - The error code return
   virtual Status Reset();

-  /// \brief This calls the reset function on this subtree in pre-order
-  /// \return Status - The error code return
-  virtual Status ResetSubtree() {
-    RETURN_IF_NOT_OK(Reset());
-    for (const auto &c : child_) {
-      RETURN_IF_NOT_OK(c->ResetSubtree());
-    }
-    return Status::OK();
-  }
-
   /// \brief During tree prepare phase, operators may have specific pre-operations to perform depending on
   /// their role.
   /// \notes Derived versions of this function should always call it's superclass version first
@@ -296,7 +292,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   /// \return Shared pointer to the sampler (may return nullptr)
   std::shared_ptr<Sampler> sampler() { return sampler_; }

-  /// Computes a CRC value for the operator
+  /// \brief Getter for the sampler, and it also removes the sampler from the op
+  /// \param[out] sampler A pointer to the output sampler that was removed
+  /// \return Status error code
+  Status FetchRemoveSampler(std::shared_ptr<Sampler> *sampler);
+
+  // Computes a CRC value for the operator
   static uint32_t GenerateCRC(const std::shared_ptr<DatasetOp> &op);

   /// \brief A helper templated function for casting "this" pointer to shared_ptr<derived>
@@ -307,17 +308,24 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
     return std::static_pointer_cast<Derived>(shared_from_this());
   }

- protected:
-  /// Adds a parent operator to this operator
-  /// \notes External callers do not have access to this function.
-  /// \param parent - The parent node to add
-  void AddParent(DatasetOp *parent);
+  /// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one.
+  void SetSampler(std::shared_ptr<Sampler> sampler) { sampler_ = sampler; }
+
+  /// \brief Checks if this is a leaf node (0 children)
+  /// \return boolean returns true if it's a leaf
+  bool IsLeaf() { return (child_.empty()); }

-  /// Removes a parent operator from this operator
-  /// \notes External callers do not have access to this function.
-  /// \param parent - The parent node to remove
+ protected:
+  /// \brief Removes a parent operator from this operator
+  /// \notes External callers do not have access to this function
+  /// \param[in] parent The parent node to remove
   void RemoveParent(const DatasetOp *parent);

+  /// \brief Adds a parent operator to this operator
+  /// \notes External callers do not have access to this function
+  /// \param[in] parent The parent node to add
+  void AddParent(DatasetOp *parent);
+
   /// Compute the current op's column map using its child's column map.
   /// Get called during the tree post-prepare phase in PrepareNodePostAction.
   /// This base implementation just inherits the map from child 0, and can only be used if the number of children is 1.
@@ -325,12 +333,6 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   /// \return - Status
   virtual Status ComputeColMap();

-  /// A helper function with some common code that leaf nodes can use during
-  /// prepare phase for checking if they need to assign a sampler to the cache.
-  /// \param random_access_op - indicate if this is a mappable random access leaf or not
-  /// \return - Status
-  Status SaveSamplerForCache(bool random_access_op);
-
   std::vector<std::shared_ptr<DatasetOp>> child_;  // Child nodes
   std::vector<DatasetOp *> parent_;                // Parent nodes. No ownership
   std::shared_ptr<Sampler> sampler_;               // Some leaf ops might have a sampler
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.cc
@@ -77,26 +77,6 @@ void RepeatOp::Print(std::ostream &out, bool show_all) const {
   }
 }

-// Base-class override for executing specific RepeatOp configurations. This code will be called
-// during the execution tree prepare phase when it is visiting this operator.
-Status RepeatOp::PrepareNodePostAction() {
-  // Run any common code from super class first before adding our own specific logic
-  RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
-  std::shared_ptr<DatasetOp> leaf_op = tree_->PopFromEOEOpStack();
-  while (leaf_op != nullptr) {
-    // Track the leaf operators that are under this repeat op.
-    eoe_ops_.push_back(leaf_op);
-    leaf_op = tree_->PopFromEOEOpStack();
-  }
-  // Push ourselves to the stack in case one of our ascendants is repeat too.
-  tree_->AddToEOEOpStack(shared_from_this());
-  return Status::OK();
-}
-
-// Base-class override for setting specific RepeatOp configurations. This code will be called
-// during the execution tree prepare phase BEFORE traversing down to child operators.
-uint32_t RepeatOp::PrepareFlags() const { return ExecutionTree::kDePrepRepeat; }
-
 // This function returns the buffer that is at the top of our output connector. The caller is
 // typically our parent node, when the parent is asking us to provide the next buffer of data.
 // Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
@@ -130,7 +110,8 @@ Status RepeatOp::GetNextBuffer(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
 // Base-class override for handling cases when an eoe is received.
 Status RepeatOp::EoeReceived(int32_t worker_id) {
   repeat_count_++;
-  MS_LOG(DEBUG) << "Repeat operator end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
+  MS_LOG(DEBUG) << "Repeat operator (" << operator_id_
+                << ") end of epoch message received. Repeat count is now: " << repeat_count_ << ".";
   bool repeated = BitTest(op_ctrl_flags_, kDeOpRepeated);
   bool last_repeat = BitTest(op_ctrl_flags_, kDeOpLastRepeat);
   // If we've reached the requested repeat count, then flag the eoe nodes
@@ -149,8 +130,12 @@ Status RepeatOp::EoeReceived(int32_t worker_id) {
     return Status::OK();
   }

-  // base-class ResetSubtree
-  return (DatasetOp::ResetSubtree());
+  // Invoke a reset against the eoe nodes only.
+  for (auto &eoe_op : eoe_ops_) {
+    RETURN_IF_NOT_OK(eoe_op->Reset());
+  }
+
+  return Status::OK();
 }

 // Class functor operator () override.
@@ -178,6 +163,18 @@ int32_t RepeatOp::num_consumers() const {
   }
 }

+// Drive reset actions if needed
+Status RepeatOp::Reset() {
+  // If there are nested repeats, an ascendant repeat may have ourself listed as an eoe op.
+  // In that case, we now have to bounce the reset down to our own eoe ops.
+  MS_LOG(DEBUG) << "Repeat operator (" << operator_id_ << ") reset.";
+  for (auto &eoe_op : eoe_ops_) {
+    RETURN_IF_NOT_OK(eoe_op->Reset());
+  }
+  state_ = OpState::kDeOpRunning;
+  return Status::OK();
+}
+
 int32_t RepeatOp::num_producers() const {
   if (child_.empty() || child_[0] == nullptr) {
     MS_LOG(DEBUG) << "Repeat operator, pointer to child node is null. Returning 0.";
@@ -187,6 +184,12 @@ int32_t RepeatOp::num_producers() const {
   }
 }

+// Pre-Visitor accept method for NodePass
+Status RepeatOp::PreAccept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call the pre-visitation
+  return p->PreRunOnNode(shared_from_base<RepeatOp>(), modified);
+}
+
 // Visitor accept method for NodePass
 Status RepeatOp::Accept(NodePass *p, bool *modified) {
   // Downcast shared pointer then call visitor
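The net effect of these changes is that RepeatOp no longer resets its whole subtree; it resets only the leaf/eoe ops registered on it (now via AddToEoeList rather than the removed EOE-op stack). A toy sketch of that bookkeeping follows; Node and Repeat here are illustrative, not the real classes.

// Standalone sketch of targeted eoe-op resets instead of a full subtree reset.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string name;
  explicit Node(std::string n) : name(std::move(n)) {}
  void Reset() { std::cout << "reset " << name << "\n"; }
};

struct Repeat {
  std::vector<std::shared_ptr<Node>> eoe_ops_;  // leaves tracked under this repeat
  void AddToEoeList(std::shared_ptr<Node> op) { eoe_ops_.push_back(std::move(op)); }
  void EoeReceived() {
    // Invoke a reset against the tracked eoe nodes only.
    for (auto &op : eoe_ops_) op->Reset();
  }
};

int main() {
  Repeat repeat;
  repeat.AddToEoeList(std::make_shared<Node>("leaf-reader"));
  repeat.AddToEoeList(std::make_shared<Node>("cache-lookup"));
  repeat.EoeReceived();  // resets only the tracked leaves, not every descendant
  return 0;
}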
mindspore/ccsrc/dataset/engine/datasetops/repeat_op.h
...
...
@@ -18,6 +18,7 @@
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "dataset/engine/datasetops/pipeline_op.h"
...
...
@@ -82,14 +83,6 @@ class RepeatOp : public PipelineOp {
// @return Status - The error code return
Status
operator
()()
override
;
// Base-class override for setting specific RepeatOp configurations. This code will be called
// during the execution tree prepare phase BEFORE traversing down to child operators.
uint32_t
PrepareFlags
()
const
override
;
// Base-class override for executing specific RepeatOp configurations. This code will be called
// during the execution tree post-prepare phase when it is visiting this operator.
Status
PrepareNodePostAction
()
override
;
// This function returns the buffer that is at the top of our output connector. The caller is
// typically our parent node, when the parent is asking us to provide the next buffer of data.
// Since RepeatOp is an inlined op, getting a buffer from us will simply bounce you to get
...
...
@@ -110,6 +103,10 @@ class RepeatOp : public PipelineOp {
// @param worker_id - The worker id
Status
EofReceived
(
int32_t
worker_id
)
override
;
/// \brief reset Op
/// \@return Status - The error code return
Status
Reset
()
override
;
// Base-class override. Return the number of workers in the first parent.
// @param workerId - The worker id
int32_t
num_consumers
()
const
override
;
...
...
@@ -118,16 +115,26 @@ class RepeatOp : public PipelineOp {
   // @param workerId - The worker id
   int32_t num_producers() const override;

-  // Base-class override for NodePass visitor acceptor.
-  // @param p - Pointer to the NodePass to be accepted.
-  // @param modified - Whether this node visit modified the pipeline.
-  // @return - Status of the node visit.
+  /// \brief Base-class override for NodePass pre-visit acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
+  Status PreAccept(NodePass *p, bool *modified) override;
+
+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p The node to visit
+  /// \param[out] modified Indicator if the node was modified
+  /// \return Status of the node visit
   Status Accept(NodePass *p, bool *modified) override;

   // Op name getter
   // @return Name of the current Op
   std::string Name() const override { return "RepeatOp"; }

+  /// \brief Adds an operator to the repeat op's list of tracked leaf/eoe nodes
+  /// \param[in] eoe_op The input leaf/eoe operator to add to the list
+  void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
+
  private:
   int32_t max_repeats_;   // The number of repeats that the user requested
   int32_t repeat_count_;  // A counter for the current number of executed repeats
...
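Reviewer note: to make the new eoe-list mechanism concrete, here is a minimal standalone sketch (simplified stand-in types, not the real DatasetOp/RepeatOp interfaces) of how a repeat now resets only its registered eoe-producing leaf ops instead of walking an entire subtree:

#include <iostream>
#include <memory>
#include <vector>

// Simplified stand-ins; the real signatures are in the headers above.
struct Op {
  virtual ~Op() = default;
  virtual void Reset() { std::cout << "leaf reset\n"; }
};

struct Repeat : Op {
  // Mirrors RepeatOp::AddToEoeList from the diff above.
  void AddToEoeList(std::shared_ptr<Op> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }
  void Reset() override {
    // Like RepeatOp::Reset in the diff: bounce the reset down to the tracked eoe ops only.
    for (auto &eoe_op : eoe_ops_) eoe_op->Reset();
  }
  std::vector<std::shared_ptr<Op>> eoe_ops_;
};

int main() {
  auto leaf = std::make_shared<Op>();
  Repeat repeat;
  repeat.AddToEoeList(leaf);  // in the real pipeline, RepeatPass performs this wiring
  repeat.Reset();             // resets only the registered leaf
}

In the real pipeline this wiring is performed by the new RepeatPass (added later in this commit), not by user code.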
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.cc
...
@@ -22,6 +22,7 @@
 #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
 #include "dataset/engine/data_schema.h"
 #include "dataset/engine/execution_tree.h"
+#include "dataset/engine/opt/pass.h"
 #include "dataset/kernels/image/image_utils.h"

 namespace mindspore {
...
@@ -408,6 +409,12 @@ Status CelebAOp::Reset() {
   return Status::OK();
 }

+// Visitor accept method for NodePass
+Status CelebAOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<CelebAOp>(), modified);
+}
+
 Status CelebAOp::ComputeColMap() {
   // Set the column name map (base class field)
   if (column_name_id_map_.empty()) {
...
mindspore/ccsrc/dataset/engine/datasetops/source/celeba_op.h
...
@@ -169,6 +169,12 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
   // @return Status - The error code return
   Status AddIOBlock(std::unique_ptr<DataBuffer> *data_buffer);

+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p Pointer to the NodePass to be accepted
+  /// \param[out] modified Indicator if the node was changed at all
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+
   // Op name getter
   // @return Name of the current Op
   std::string Name() const { return "CelebAOp"; }
...
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc
...
@@ -26,6 +26,7 @@
 #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
 #include "dataset/engine/db_connector.h"
 #include "dataset/engine/execution_tree.h"
+#include "dataset/engine/opt/pass.h"

 namespace mindspore {
 namespace dataset {
...
@@ -450,6 +451,12 @@ Status CifarOp::CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *
   }
 }

+// Visitor accept method for NodePass
+Status CifarOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<CifarOp>(), modified);
+}
+
 Status CifarOp::ComputeColMap() {
   // set the column name map (base class field)
   if (column_name_id_map_.empty()) {
...
mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.h
...
@@ -155,6 +155,12 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
   // @return
   static Status CountTotalRows(const std::string &dir, bool isCIFAR10, int64_t *count);

+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p Pointer to the NodePass to be accepted
+  /// \param[out] modified Indicator if the node was changed at all
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+
   // Op name getter
   // @return Name of the current Op
   std::string Name() const override { return "CifarOp"; }
...
mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.cc
...
@@ -24,6 +24,7 @@
 #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
 #include "dataset/engine/db_connector.h"
 #include "dataset/engine/execution_tree.h"
+#include "dataset/engine/opt/pass.h"

 namespace mindspore {
 namespace dataset {
...
@@ -624,6 +625,12 @@ Status CocoOp::GetClassIndexing(const std::string &dir, const std::string &file,
   return Status::OK();
 }

+// Visitor accept method for NodePass
+Status CocoOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<CocoOp>(), modified);
+}
+
 Status CocoOp::ComputeColMap() {
   // Set the column name map (base class field)
   if (column_name_id_map_.empty()) {
...
mindspore/ccsrc/dataset/engine/datasetops/source/coco_op.h
...
@@ -200,6 +200,12 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
   static Status GetClassIndexing(const std::string &dir, const std::string &task_type,
                                  const std::string &task_mode,
                                  std::vector<std::pair<std::string, std::vector<int32_t>>> *output_class_indexing);

+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p Pointer to the NodePass to be accepted
+  /// \param[out] modified Indicator if the node was changed at all
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+
  private:
   // Initialize Sampler, calls sampler->Init() within
   // @return Status - The error code return
...
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.cc
...
@@ -26,6 +26,7 @@
 #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
 #include "dataset/engine/db_connector.h"
 #include "dataset/engine/execution_tree.h"
+#include "dataset/engine/opt/pass.h"

 namespace mindspore {
 namespace dataset {
...
@@ -416,6 +417,12 @@ Status ManifestOp::GetClassIndexing(const std::string &file, const py::dict &dic
   return Status::OK();
 }

+// Visitor accept method for NodePass
+Status ManifestOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<ManifestOp>(), modified);
+}
+
 Status ManifestOp::ComputeColMap() {
   // Set the column name map (base class field)
   if (column_name_id_map_.empty()) {
...
mindspore/ccsrc/dataset/engine/datasetops/source/manifest_op.h
...
@@ -172,6 +172,12 @@ class ManifestOp : public ParallelOp, public RandomAccessOp {
   static Status GetClassIndexing(const std::string &file, const py::dict &dict, const std::string &usage,
                                  std::map<std::string, int32_t> *output_class_indexing);

+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p Pointer to the NodePass to be accepted
+  /// \param[out] modified Indicator if the node was changed at all
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+
   // Op name getter
   // @return Name of the current Op
   std::string Name() const override { return "ManifestOp"; }
...
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.cc
...
@@ -23,6 +23,7 @@
 #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
 #include "dataset/engine/db_connector.h"
 #include "dataset/engine/execution_tree.h"
+#include "dataset/engine/opt/pass.h"

 namespace mindspore {
 namespace dataset {
...
@@ -428,6 +429,12 @@ Status MnistOp::CountTotalRows(const std::string &dir, int64_t *count) {
   return Status::OK();
 }

+// Visitor accept method for NodePass
+Status MnistOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<MnistOp>(), modified);
+}
+
 Status MnistOp::ComputeColMap() {
   // set the column name map (base class field)
   if (column_name_id_map_.empty()) {
...
mindspore/ccsrc/dataset/engine/datasetops/source/mnist_op.h
...
@@ -152,6 +152,12 @@ class MnistOp : public ParallelOp, public RandomAccessOp {
   // @return
   static Status CountTotalRows(const std::string &dir, int64_t *count);

+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p Pointer to the NodePass to be accepted
+  /// \param[out] modified Indicator if the node was changed at all
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+
   // Op name getter
   // @return Name of the current Op
   std::string Name() const override { return "MnistOp"; }
...
mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.cc
...
@@ -22,6 +22,7 @@
 #include "dataset/util/random.h"
 #include "dataset/util/wait_post.h"
 #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
+#include "dataset/engine/opt/pass.h"

 namespace mindspore {
 namespace dataset {
...
@@ -406,6 +407,12 @@ Status RandomDataOp::Reset() {
   return Status::OK();
 }

+// Visitor accept method for NodePass
+Status RandomDataOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<RandomDataOp>(), modified);
+}
+
 Status RandomDataOp::ComputeColMap() {
   // Extract the column name mapping from the schema and save it in the class.
   if (column_name_id_map_.empty()) {
...
@@ -415,15 +422,5 @@ Status RandomDataOp::ComputeColMap() {
   }
   return Status::OK();
 }
-
-// During tree prepare phase, operators may have specific post-operations to perform depending on
-// their role.
-Status RandomDataOp::PrepareNodePostAction() {
-  // Run common code from super class before adding RandomDataOp specific handling
-  RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
-
-  // Specific handling for this op, we need to do cache op work to assign the sampler to the cache.
-  RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(false));
-
-  return Status::OK();
-}
 }  // namespace dataset
 }  // namespace mindspore
mindspore/ccsrc/dataset/engine/datasetops/source/random_data_op.h
...
@@ -203,12 +203,6 @@ class RandomDataOp : public ParallelOp {
   // @return Name of the current Op
   std::string Name() const override { return "RandomDataOp"; }

-  // During tree prepare phase, operators may have specific post-operations to perform depending on
-  // their role.
-  // @notes Derived versions of this function should always call its superclass version first
-  // before providing their own implementations.
-  Status PrepareNodePostAction() override;
-
  private:
   /**
    * The entry point code for when workers are launched
...
@@ -266,6 +260,12 @@ class RandomDataOp : public ParallelOp {
     return ++buffer_id_;
   }

+  // Base-class override for NodePass visitor acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status Accept(NodePass *p, bool *modified) override;
+
   // Private function for computing the assignment of the column name map.
   // @return - Status
   Status ComputeColMap() override;
...
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc
...
@@ -1019,31 +1019,28 @@ Status TFReaderOp::ComputeColMap() {
   return Status::OK();
 }

+// Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing
+// a sampler for fetching the data. As such, any options in the tf reader need to be reset to their defaults so
+// that this tf reader will produce the full set of data into the cache.
+void TFReaderOp::MakeSimpleProducer() {
+  device_id_ = 0;
+  num_devices_ = 1;
+  total_rows_ = 0;
+  shuffle_files_ = false;
+  equal_rows_per_shard_ = false;
+}
+
 // During tree prepare phase, operators may have specific post-operations to perform depending on
 // their role.
 Status TFReaderOp::PrepareNodePostAction() {
   // Run common code from super class before adding TFReaderOp specific handling
   RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());

-  // Specific handling for this op, we need to do cache op work so assign the sampler to the cache
-  // TF is a special case because it can support file-based sharding/shuffling, or, if there
-  // is a cache, then it can also do row-based sampler using the sampler on the cache.
-  // Thus, pass true for random access op flag when saving the sampler. This is a special case,
-  // since usually a non-mappable dataset would pass false here.
-  RETURN_IF_NOT_OK(DatasetOp::SaveSamplerForCache(true));
-
-  // Now that the sampler has been saved for the cache, we need to adjust the TFReaderOp to turn it into
-  // a simpler producer of all data (no shuffling or sharding or anything)
-  if (BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
-    device_id_ = 0;
-    num_devices_ = 1;
-    total_rows_ = 0;
-    shuffle_files_ = false;
-    equal_rows_per_shard_ = false;
-    sampler_.reset();  // Normally SaveSampler code did this for us, but we passed in true above (See comment)
-  } else {
+  if (!BitTest(tree_->PrepareFlags(), ExecutionTree::kDePrepCache)) {
     // This sanity check had been delayed until now in the prepare loop.
-    // If we are not in a cache path, then we can validate the the file-based sharding config.
+    // If we are not in a cache path, then we can validate the file-based sharding config.
     // If we are in a cache path, there is no file-based sharding so the check is not correct in that
     // situation.
     if (!equal_rows_per_shard_ && dataset_files_list_.size() < static_cast<uint32_t>(num_devices_)) {
...
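Reviewer note: a rough illustration of why the reader must become a simple producer under a cache. The struct below is hypothetical and only mirrors the fields the diff resets; with sharding left on, an 8-device run would populate the cache with only about 1/8 of the rows.

#include <cstdint>

// Hypothetical mirror of the TFReaderOp fields reset by MakeSimpleProducer().
struct ReaderConfig {
  int32_t device_id = 3;             // shard index: reader would emit only its shard
  int32_t num_devices = 8;           // 8-way sharding: roughly 1/8 of rows per reader
  int64_t total_rows = 1000;         // row cap: reader would stop early
  bool shuffle_files = true;         // file order would differ run to run
  bool equal_rows_per_shard = true;  // padding/truncation to equalize shards
};

// Mirrors the diff: turn the reader into a plain producer of the full dataset,
// so the cache sees every row exactly once; the lookup op's sampler re-applies
// any sharding/shuffling afterwards.
void MakeSimpleProducer(ReaderConfig *cfg) {
  cfg->device_id = 0;
  cfg->num_devices = 1;
  cfg->total_rows = 0;  // 0 == no cap
  cfg->shuffle_files = false;
  cfg->equal_rows_per_shard = false;
}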
mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h
...
@@ -246,6 +246,11 @@ class TFReaderOp : public ParallelOp {
   // @return Vector of the input file names
   std::vector<std::string> FileNames() { return dataset_files_list_; }

+  /// \brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing
+  ///     a sampler for fetching the data. As such, any options in the tf reader need to be reset to their defaults so
+  ///     that this tf reader will produce the full set of data into the cache.
+  void MakeSimpleProducer();
+
   // During tree prepare phase, operators may have specific post-operations to perform depending on
   // their role.
   // @notes Derived versions of this function should always call its superclass version first
...
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.cc
...
@@ -25,6 +25,7 @@
 #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
 #include "dataset/engine/db_connector.h"
 #include "dataset/engine/execution_tree.h"
+#include "dataset/engine/opt/pass.h"

 using tinyxml2::XMLDocument;
 using tinyxml2::XMLElement;
...
@@ -449,6 +450,11 @@ Status VOCOp::GetClassIndexing(const std::string &dir, const std::string &task_t
   return Status::OK();
 }

+// Visitor accept method for NodePass
+Status VOCOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<VOCOp>(), modified);
+}
+
 Status VOCOp::ComputeColMap() {
   // Set the column name map (base class field)
...
mindspore/ccsrc/dataset/engine/datasetops/source/voc_op.h
...
@@ -205,6 +205,12 @@ class VOCOp : public ParallelOp, public RandomAccessOp {
   static Status GetClassIndexing(const std::string &dir, const std::string &task_type, const std::string &task_mode,
                                  const py::dict &dict, std::map<std::string, int32_t> *output_class_indexing);

+  /// \brief Base-class override for NodePass visitor acceptor
+  /// \param[in] p Pointer to the NodePass to be accepted
+  /// \param[out] modified Indicator if the node was changed at all
+  /// \return Status of the node visit
+  Status Accept(NodePass *p, bool *modified) override;
+
   // Op name getter
   // @return Name of the current Op
   std::string Name() const override { return "VOCOp"; }
...
mindspore/ccsrc/dataset/engine/datasetops/take_op.cc
...
@@ -127,12 +127,6 @@ Status TakeOp::FillBuffer(std::unique_ptr<DataBuffer> *buffer, std::unique_ptr<D
   return Status::OK();
 }

-Status TakeOp::PrepareNodePostAction() {
-  RETURN_IF_NOT_OK(PipelineOp::PrepareNodePostAction());
-  tree_->AddToEOEOpStack(shared_from_this());
-  return Status::OK();
-}
-
 // Visitor accept method for NodePass
 Status TakeOp::Accept(NodePass *p, bool *modified) {
   // Downcast shared pointer then call visitor
...
mindspore/ccsrc/dataset/engine/datasetops/take_op.h
...
@@ -78,12 +78,6 @@ class TakeOp : public PipelineOp {
   // @return Status - The error code return
   Status operator()() override;

-  // During tree prepare phase, operators may have specific post-operations to perform depending on
-  // their role.
-  // @notes Derived versions of this function should always call its superclass version first
-  // before providing their own implementations.
-  Status PrepareNodePostAction() override;
-
   // Base-class override for NodePass visitor acceptor.
   // @param p - Pointer to the NodePass to be accepted.
   // @param modified - Whether this node visit modified the pipeline.
...
mindspore/ccsrc/dataset/engine/execution_tree.cc
...
@@ -21,6 +21,8 @@
 #include "dataset/util/task_manager.h"
 #include "dataset/engine/opt/pass.h"
 #include "dataset/engine/opt/pre/removal_pass.h"
+#include "dataset/engine/opt/pre/cache_transform_pass.h"
+#include "dataset/engine/opt/post/repeat_pass.h"
 #include "dataset/engine/perf/profiling.h"
 #include "dataset/engine/perf/monitor.h"
...
@@ -215,18 +217,33 @@ Status ExecutionTree::PrepareTreePreAction() {
   bool modified = false;
   std::vector<std::unique_ptr<Pass>> pre_actions;
   // Construct pre actions
-  MS_LOG(INFO) << "Running pre pass";
-  pre_actions.push_back(std::make_unique<RemovalPass>(RemovalPass()));
+  MS_LOG(INFO) << "Running pre pass loops.";
+  pre_actions.push_back(std::make_unique<RemovalPass>());
+  pre_actions.push_back(std::make_unique<CacheTransformPass>());
   // Apply pre action passes
   for (auto &pass : pre_actions) {
     RETURN_IF_NOT_OK(pass->Run(this, &modified));
   }
+  MS_LOG(INFO) << "Pre passes complete.";
   return Status::OK();
 }

 Status ExecutionTree::PrepareTreePostAction() {
   // The tree is ready to be prepared.
   tree_state_ = kDeTStatePrepare;

+  bool modified = false;
+  std::vector<std::unique_ptr<Pass>> post_actions;
+  // Construct post actions
+  MS_LOG(INFO) << "Running post pass loops.";
+  post_actions.push_back(std::make_unique<RepeatPass>());
+
+  // Apply post action passes
+  for (auto &pass : post_actions) {
+    RETURN_IF_NOT_OK(pass->Run(this, &modified));
+  }
+  MS_LOG(INFO) << "Post passes complete.";
   return Status::OK();
 }
...
@@ -280,31 +297,5 @@ Status ExecutionTree::PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op)
   return Status::OK();
 }
-
-// Adds an operator to the eoe operator stack during prepare phase.
-void ExecutionTree::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }
-
-// Pops an operator from the eoe operator stack during prepare phase.
-std::shared_ptr<DatasetOp> ExecutionTree::PopFromEOEOpStack() {
-  std::shared_ptr<DatasetOp> top_op = nullptr;
-  if (!eoe_stack_.empty()) {
-    top_op = eoe_stack_.top();
-    eoe_stack_.pop();
-  }
-  return top_op;
-}
-
-// Adds a sampler to the sampler stack during prepare phase.
-void ExecutionTree::AddToSamplerStack(std::shared_ptr<Sampler> sampler) { sampler_stack_.push(sampler); }
-
-// Pops a sampler from the sampler stack during prepare phase.
-std::shared_ptr<Sampler> ExecutionTree::PopFromSamplerStack() {
-  std::shared_ptr<Sampler> top_sampler = nullptr;
-  if (!sampler_stack_.empty()) {
-    top_sampler = sampler_stack_.top();
-    sampler_stack_.pop();
-  }
-  return top_sampler;
-}
 }  // namespace dataset
 }  // namespace mindspore
mindspore/ccsrc/dataset/engine/execution_tree.h
...
@@ -200,24 +200,6 @@ class ExecutionTree {
   // @return Status - The error code return
   Status PrepareNode(const std::shared_ptr<DatasetOp> &dataset_op);

-  /// Adds an operator to the eoe operator stack during prepare phase.
-  /// \param dataset_op - The dataset op to add to the eoe stack
-  void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);
-
-  /// Pops an operator from the eoe operator stack during prepare phase.
-  /// \return shared_ptr to the popped operator
-  std::shared_ptr<DatasetOp> PopFromEOEOpStack();
-
-  /// Adds a sampler to the sampler stack during prepare phase.
-  /// \param sampler - The sampler to add to the sampler stack
-  void AddToSamplerStack(std::shared_ptr<Sampler> sampler);
-
-  /// Pops a sampler from the sampler stack during prepare phase.
-  /// \return shared_ptr to the popped sampler
-  std::shared_ptr<Sampler> PopFromSamplerStack();
-
   // Return the pointer to the TaskGroup
   // @return raw pointer to the TaskGroup
   TaskGroup *AllTasks() const { return tg_.get(); }
...
@@ -248,8 +230,6 @@ class ExecutionTree {
   TreeState tree_state_;                                  // Tracking the current tree state
   std::unique_ptr<Monitor> perf_monitor_;                 // Performance Monitor
   std::unique_ptr<ProfilingManager> profiling_manager_;   // Profiling manager
-  std::stack<std::shared_ptr<DatasetOp>> eoe_stack_;      // A stack used during prepare phase
-  std::stack<std::shared_ptr<Sampler>> sampler_stack_;    // A stack used during prepare phase
 };
 }  // namespace dataset
 }  // namespace mindspore
...
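Reviewer note: both prepare actions follow the same driver pattern (build a vector of passes, then run each against the tree). A minimal standalone sketch of that pattern with stand-in Pass/Tree types (hypothetical names, mirroring the loops above):

#include <memory>
#include <vector>

struct Tree;  // stand-in for ExecutionTree

// Stand-in for the Pass interface: each pass may rewrite the tree and reports whether it did.
struct Pass {
  virtual ~Pass() = default;
  virtual bool Run(Tree *tree) = 0;  // returns true if the tree was modified
};

bool RunPasses(Tree *tree, std::vector<std::unique_ptr<Pass>> *passes) {
  bool modified = false;
  for (auto &pass : *passes) {
    // Order matters: e.g. RemovalPass runs before CacheTransformPass in the diff above,
    // so the cache transform sees an already-simplified tree.
    modified |= pass->Run(tree);
  }
  return modified;
}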
mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt
...
@@ -2,6 +2,9 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
 set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
 add_library(engine-opt OBJECT
     pass.cc
+    post/repeat_pass.cc
+    pre/cache_pass.cc
+    pre/cache_transform_pass.cc
     pre/removal_nodes.cc
     pre/removal_pass.cc
     util/printer_pass.cc
...
mindspore/ccsrc/dataset/engine/opt/pass.cc
...
@@ -16,6 +16,9 @@
 #include "dataset/engine/opt/pass.h"
 #include "dataset/engine/datasetops/batch_op.h"
 #include "dataset/engine/datasetops/cache_op.h"
+#include "dataset/engine/datasetops/cache_merge_op.h"
+#include "dataset/engine/datasetops/cache_lookup_op.h"
 #include "dataset/engine/datasetops/dataset_op.h"
 #include "dataset/engine/datasetops/device_queue_op.h"
 #include "dataset/engine/datasetops/map_op.h"
...
@@ -24,8 +27,15 @@
 #include "dataset/engine/datasetops/repeat_op.h"
 #include "dataset/engine/datasetops/skip_op.h"
 #include "dataset/engine/datasetops/shuffle_op.h"
+#include "dataset/engine/datasetops/source/celeba_op.h"
+#include "dataset/engine/datasetops/source/cifar_op.h"
+#include "dataset/engine/datasetops/source/coco_op.h"
+#include "dataset/engine/datasetops/source/manifest_op.h"
 #include "dataset/engine/datasetops/source/mindrecord_op.h"
+#include "dataset/engine/datasetops/source/mnist_op.h"
+#include "dataset/engine/datasetops/source/random_data_op.h"
 #include "dataset/engine/datasetops/source/tf_reader_op.h"
+#include "dataset/engine/datasetops/source/voc_op.h"
 #ifdef ENABLE_PYTHON
 #include "dataset/engine/datasetops/filter_op.h"
 #include "dataset/engine/datasetops/source/generator_op.h"
...
@@ -145,6 +155,11 @@ Status NodePass::RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified) {
 }
 #endif

+Status NodePass::RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
 Status NodePass::RunOnNode(std::shared_ptr<TakeOp> node, bool *modified) {
   // Fallback to base class visitor by default
   return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
...
@@ -164,5 +179,70 @@ Status NodePass::RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified)
   // Fallback to base class visitor by default
   return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
 }

+Status NodePass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
 }  // namespace dataset
 }  // namespace mindspore
mindspore/ccsrc/dataset/engine/opt/pass.h
...
@@ -47,6 +47,10 @@ class FilterOp;
 class GeneratorOp;
 #endif

+class RandomDataOp;
+
+class RepeatOp;
+
 class TakeOp;

 class ZipOp;
...
@@ -55,6 +59,24 @@ class DeviceQueueOp;
 class ImageFolderOp;

 class CacheOp;

+class MnistOp;
+
+class ManifestOp;
+
+class CifarOp;
+
+class VOCOp;
+
+class CocoOp;
+
+class CelebAOp;
+
+class CacheMergeOp;
+
+class CacheLookupOp;
+
 // The base class Pass is the basic unit of tree transformation.
 // The actual implementation of the passes will be derived from here.
 class Pass : public std::enable_shared_from_this<Pass> {
...
@@ -138,14 +160,42 @@ class NodePass : public Pass {
   virtual Status RunOnNode(std::shared_ptr<GeneratorOp> node, bool *modified);
 #endif

+  virtual Status RunOnNode(std::shared_ptr<RandomDataOp> node, bool *modified);
+
   virtual Status RunOnNode(std::shared_ptr<TakeOp> node, bool *modified);

   virtual Status RunOnNode(std::shared_ptr<ZipOp> node, bool *modified);

   virtual Status RunOnNode(std::shared_ptr<DeviceQueueOp> node, bool *modified);

   virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);

   virtual Status RunOnNode(std::shared_ptr<ImageFolderOp> node, bool *modified);

+  virtual Status RunOnNode(std::shared_ptr<MnistOp> node, bool *modified);
+
+  virtual Status RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified);
+
+  virtual Status RunOnNode(std::shared_ptr<CifarOp> node, bool *modified);
+
+  virtual Status RunOnNode(std::shared_ptr<VOCOp> node, bool *modified);
+
+  virtual Status RunOnNode(std::shared_ptr<CocoOp> node, bool *modified);
+
+  virtual Status RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified);
+
+  virtual Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
+
+  virtual Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
+
+  virtual Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified);
+
+  virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
+
+  virtual Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified);
+
+  virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);
+
  private:
   // Helper function to perform DFS visit
   Status DFSNodeVisit(std::shared_ptr<DatasetOp> node, bool *modified);
...
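Reviewer note: the typed RunOnNode/PreRunOnNode overloads give passes double dispatch: each op's Accept downcasts itself and calls the matching overload, and any type a pass does not override falls back to the DatasetOp version shown in pass.cc. A sketch of a custom pass written against this interface (hypothetical CacheOpCounter, not part of this commit; assumes dataset/engine/opt/pass.h and dataset/engine/datasetops/cache_op.h are included):

#include <memory>

namespace mindspore {
namespace dataset {
// Counts CacheOp nodes during a tree walk; every other op type falls through to the
// base-class RunOnNode(std::shared_ptr<DatasetOp>, ...) default and is left untouched.
class CacheOpCounter : public NodePass {
 public:
  Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override {
    ++count_;
    *modified = false;  // visit-only pass, no tree rewrite
    return Status::OK();
  }

  int32_t count() const { return count_; }

 private:
  int32_t count_ = 0;
};
}  // namespace dataset
}  // namespace mindspore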
mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.cc (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include "dataset/engine/opt/post/repeat_pass.h"
#include "dataset/engine/datasetops/repeat_op.h"
#include "dataset/engine/datasetops/cache_op.h"
#include "dataset/engine/datasetops/cache_lookup_op.h"
#include "dataset/engine/datasetops/cache_merge_op.h"

namespace mindspore {
namespace dataset {
RepeatPass::RepeatPass() : is_repeated_(false), nested_repeats_(0), is_merge_(false), cache_lookup_(nullptr) {}

// Identifies the subtree below this node as being in a repeated path of the tree.
Status RepeatPass::PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
  // If we are already repeated, then this is a nested repeat.
  if (is_repeated_) {
    nested_repeats_++;
  }
  is_repeated_ = true;
  return Status::OK();
}

// Identifies the subtree below this node as being in a cache merge path
Status RepeatPass::PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
  // Turn on the flag that we're under a merge op
  is_merge_ = true;
  return Status::OK();
}

// Hooks up any identified eoe nodes under this repeat.
Status RepeatPass::RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) {
  // Pop the leaf ops from the save-area stack and add them to the repeat op's eoe node tracking
  std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
  while (leaf_op != nullptr) {
    node->AddToEoeList(leaf_op);
    leaf_op = PopFromEOEOpStack();
  }

  // If we are a repeat op in the descendant tree of a merge op, then we take the saved lookup op
  // and add it to the list of eoe/leaf ops for the repeat, removing it from the save area.
  if (is_merge_ && cache_lookup_) {
    cache_lookup_->set_control_flag(DatasetOp::kDeOpRepeated);
    node->AddToEoeList(std::move(cache_lookup_));
  }

  // If we are a nested repeat, then we add ourself to the repeat stack for the next one above us.
  // A nested repeat acts like an eoe/leaf for the repeat in the ascendant tree.
  if (nested_repeats_ > 0) {
    node->set_control_flag(DatasetOp::kDeOpRepeated);
    AddToEOEOpStack(node);
    nested_repeats_--;
  }

  // If we are not nested, or we were the top-most repeat, now we clear the flag
  if (nested_repeats_ == 0) {
    is_repeated_ = false;
  }
  return Status::OK();
}

// CacheOp removes previous leaf ops and replaces them with itself
Status RepeatPass::RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
  if (is_repeated_) {
    node->set_control_flag(DatasetOp::kDeOpRepeated);
    // If we are a cache within a repeat path of the tree, then there will be
    // eoe-generating ops in the eoe op stack in the tree. They are flagged as such so that the
    // repeat or epoch ctrl operators can work with them for repeat activity during runtime.
    // However, since a cache is present:
    // - unflag those ops as being repeated ops
    // - remove them from the eoe op stack so that the repeat op above in the tree won't know about them
    // - add ourself (the cache op) as an eoe op
    // We do this so that those old leafs become 1-time use (up to eoe), never repeated. Instead,
    // the repeating behaviours shall be invoked against the cache op.
    std::shared_ptr<DatasetOp> leaf_op = PopFromEOEOpStack();
    while (leaf_op != nullptr) {
      leaf_op->ClearControlFlag(DatasetOp::kDeOpLastRepeat);
      leaf_op->ClearControlFlag(DatasetOp::kDeOpRepeated);
      leaf_op = PopFromEOEOpStack();
    }
    AddToEOEOpStack(std::static_pointer_cast<DatasetOp>(node));
  }
  return Status::OK();
}

// All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
// for use with a controlling repeat above it.
Status RepeatPass::RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) {
  // If we are in a repeat path, then set our repeated flag
  if (is_repeated_) {
    node->set_control_flag(DatasetOp::kDeOpRepeated);
    // If we are a leaf node then save ourself in a stack for the repeat operator above us
    if (node->IsLeaf()) {
      AddToEOEOpStack(node);
    }
  }
  return Status::OK();
}

// Turns off the tracking for operations under merge op
Status RepeatPass::RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) {
  // Setting the flag is needed since we didn't call the base class DatasetOp version
  if (is_repeated_) node->set_control_flag(DatasetOp::kDeOpRepeated);
  is_merge_ = false;
  cache_lookup_.reset();  // If a repeat op did not consume this, then it's no longer needed
  return Status::OK();
}

// Saves the lookup op in case it needs to be referenced by a repeat
Status RepeatPass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) {
  if (!node->IsLeaf()) {
    // By definition, the CacheLookup must be a leaf op. Make that clear here.
    RETURN_STATUS_UNEXPECTED("CacheLookupOp must be a leaf node!");
  }

  // If we are in a repeat path already, then there must be a repeat above the merge op.
  // In this case, we naturally are a repeating leaf op, so add the required setup for leafs under repeat here.
  if (is_repeated_) {
    node->set_control_flag(DatasetOp::kDeOpRepeated);
    AddToEOEOpStack(node);
  } else {
    // Save the lookup op. There could be a repeat in the cache miss leg of the merge op, in which case we
    // may still need to be flagged as a repeating leaf. We can't decide that here though, so save ourself
    // into the pass so that the decision can be made during the processing of the cache miss leg of the merge.
    cache_lookup_ = std::static_pointer_cast<DatasetOp>(node);
  }
  return Status::OK();
}

// Adds an operator to the eoe operator stack save area
void RepeatPass::AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op) { eoe_stack_.push(dataset_op); }

// Pops an operator from the eoe operator stack save area
std::shared_ptr<DatasetOp> RepeatPass::PopFromEOEOpStack() {
  std::shared_ptr<DatasetOp> top_op = nullptr;
  if (!eoe_stack_.empty()) {
    top_op = eoe_stack_.top();
    eoe_stack_.pop();
  }
  return top_op;
}
}  // namespace dataset
}  // namespace mindspore
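Reviewer note: the net effect of RepeatPass for a repeat sitting over a cache, sketched from the handlers above:

    RepeatOp                     after the pass:  RepeatOp.eoe_ops_ = { CacheOp }
       |
    CacheOp                      the old leaf is unflagged and feeds the cache only
       |                         once (up to eoe); repeat behaviour is then driven
    LeafOp                       against the CacheOp instead of the leaf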
mindspore/ccsrc/dataset/engine/opt/post/repeat_pass.h (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_
#define DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_

#include <memory>
#include <stack>
#include <utility>
#include "dataset/engine/opt/pass.h"

namespace mindspore {
namespace dataset {
/// \class RepeatPass repeat_pass.h
/// \brief This is a NodePass whose job is to perform setup actions for RepeatOps. A RepeatOp needs to have references
///     to the eoe-producing (typically leaf) nodes underneath it.
class RepeatPass : public NodePass {
 public:
  /// \brief Constructor
  RepeatPass();

  /// \brief Identifies the subtree below this node as being in a repeated path of the tree.
  /// \param[in] node The node being visited
  /// \param[inout] modified Indicator if the node was changed at all
  /// \return Status The error code return
  Status PreRunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;

  /// \brief Identifies the subtree below this node as being in a cache merge path
  /// \param[in] node The node being visited
  /// \param[inout] modified Indicator if the node was changed at all
  /// \return Status The error code return
  Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;

  /// \brief Hooks up any identified eoe nodes under this repeat.
  /// \param[in] node The node being visited
  /// \param[inout] modified Indicator if the node was changed at all
  /// \return Status The error code return
  Status RunOnNode(std::shared_ptr<RepeatOp> node, bool *modified) override;

  /// \brief CacheOp removes previous leaf ops and replaces them with itself
  /// \param[in] node The node being visited
  /// \param[inout] modified Indicator if the node was changed at all
  /// \return Status The error code return
  Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;

  /// \brief Turns off the tracking for operations under merge op
  /// \param[in] node The node being visited
  /// \param[inout] modified Indicator if the node was changed at all
  /// \return Status The error code return
  Status RunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified) override;

  /// \brief Saves the lookup op in case it needs to be referenced by a repeat
  /// \param[in] node The node being visited
  /// \param[inout] modified Indicator if the node was changed at all
  /// \return Status The error code return
  Status RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified) override;

  /// \brief All operators have a flag that might be set related to the repeat and any leaf nodes need to be set up
  ///     for use with a controlling repeat above it.
  /// \param[in] node The node being visited
  /// \param[inout] modified Indicator if the node was changed at all
  /// \return Status The error code return
  Status RunOnNode(std::shared_ptr<DatasetOp> node, bool *modified) override;

 private:
  /// \brief Adds an operator to the eoe operator stack save area
  /// \param[in] dataset_op The dataset op to add to the eoe stack
  void AddToEOEOpStack(std::shared_ptr<DatasetOp> dataset_op);

  /// \brief Pops an operator from the eoe operator stack save area
  /// \return shared_ptr to the popped operator
  std::shared_ptr<DatasetOp> PopFromEOEOpStack();

  bool is_repeated_;        // T/F if we are processing under a repeat
  bool is_merge_;           // T/F if we are processing under a cache merge op
  int32_t nested_repeats_;  // A counter for nested repeats
  std::stack<std::shared_ptr<DatasetOp>> eoe_stack_;  // A save area for leaf/eoe ops
  std::shared_ptr<DatasetOp> cache_lookup_;           // A save area for a cache lookup op
};
}  // namespace dataset
}  // namespace mindspore
#endif  // DATASET_ENGINE_OPT_PASS_POST_REPEAT_PASS_
mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.cc (new file, mode 100644; diff collapsed, not shown)
mindspore/ccsrc/dataset/engine/opt/pre/cache_pass.h (new file, mode 100644; diff collapsed, not shown)
mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.cc (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <vector>
#include "dataset/engine/opt/pre/cache_pass.h"
#include "dataset/engine/opt/pre/cache_transform_pass.h"
#include "dataset/engine/execution_tree.h"
#include "dataset/engine/cache/cache_client.h"
#include "dataset/engine/datasetops/cache_lookup_op.h"
#include "dataset/engine/datasetops/cache_merge_op.h"
#include "dataset/engine/datasetops/cache_op.h"

namespace mindspore {
namespace dataset {
// constructor
CacheTransformPass::CacheTransformPass() {}

// Runs a cache_pass first to set up the transformation nodes, and then drives any of these transformations
Status CacheTransformPass::RunOnTree(ExecutionTree *tree, bool *modified) {
  MS_LOG(INFO) << "Pre pass: Cache transform pass started.";
  // Create the cache pass and run it. The cache pass identifies and creates the leaf/cache pairs that we will
  // use to execute a transform.
  std::unique_ptr<Pass> cache_pass = std::make_unique<CachePass>(this);
  RETURN_IF_NOT_OK(cache_pass->Run(tree, modified));

  // Then, execute the transform for each pair
  for (auto cache_pair : cache_pairs_) {
    MS_LOG(DEBUG) << "Cache transform pass: Executing a cache op mappable transform.";
    ExecuteCacheTransform(tree, cache_pair.first, cache_pair.second, cache_pair.second->cache_client());
  }
  MS_LOG(INFO) << "Pre pass: Cache transform pass complete.";
  return Status::OK();
}

// Helper function to execute the cache transformation.
Status CacheTransformPass::ExecuteCacheTransform(ExecutionTree *tree, std::shared_ptr<DatasetOp> leaf_op,
                                                 std::shared_ptr<DatasetOp> cache_op,
                                                 std::shared_ptr<CacheClient> cache_client) {
  // Get local pointers to the child/parent of the cache op. It's possible that the parent is null if the cache was
  // the root node. It is also possible that cache_child == leaf_op
  std::shared_ptr<DatasetOp> cache_child = cache_op->child(0);
  DatasetOp *cache_parent = nullptr;
  cache_op->Parent(&cache_parent, 0);  // fetch the cache op's parent

  // Extract the sampler from the leaf. We will overwrite this sampler with the lookup op later.
  std::shared_ptr<Sampler> leaf_sampler = leaf_op->sampler();

  // Construct the merge op with defaults
  std::shared_ptr<CacheMergeOp> merge_op;
  CacheMergeOp::Builder merge_builder;
  RETURN_IF_NOT_OK(merge_builder.SetClient(cache_client).Build(&merge_op));
  RETURN_IF_NOT_OK(tree->AssociateNode(merge_op));

  // Construct the cache lookup op with defaults
  std::shared_ptr<CacheLookupOp> cache_lookup_op;
  CacheLookupOp::Builder lookup_builder;
  RETURN_IF_NOT_OK(lookup_builder.SetClient(cache_client).SetSampler(std::move(leaf_sampler)).Build(&cache_lookup_op));
  RETURN_IF_NOT_OK(tree->AssociateNode(cache_lookup_op));

  // Overwrite the old sampler in this leaf op to become the lookup op
  leaf_op->SetSampler(cache_lookup_op);

  // If the cache had a parent, then go into that parent to remove the cache from its child list and then
  // replace it with the merge op.
  if (cache_parent != nullptr) {
    RETURN_IF_NOT_OK(cache_parent->RemoveChild(cache_op));
    RETURN_IF_NOT_OK(cache_parent->AddChild(merge_op));
  } else {
    // If we didn't have a parent, then the merge op is the root node
    RETURN_IF_NOT_OK(tree->AssignRoot(merge_op));
  }

  // Set the cache op to no longer be a parent over its child. This will fully disconnect the old cache op.
  // We maintain a local pointer to the old child though.
  RETURN_IF_NOT_OK(cache_op->RemoveChild(cache_child));

  // Connect the merge op
  RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_lookup_op)));
  RETURN_IF_NOT_OK(merge_op->AddChild(std::move(cache_child)));

  // At this point, the cache op has already had its children and parents taken away. Calling remove
  // on it at this point will not do any node hookups, and instead set internal fields to invalid.
  RETURN_IF_NOT_OK(cache_op->Remove());
  return Status::OK();
}

// Assigns the leaf and cache operators that are involved in a cache transformation
void CacheTransformPass::AddMappableCacheOperators(std::shared_ptr<DatasetOp> leaf_op,
                                                   std::shared_ptr<CacheOp> cache_op) {
  cache_pairs_.push_back(std::make_pair(leaf_op, cache_op));
}
}  // namespace dataset
}  // namespace mindspore
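Reviewer note: schematically, ExecuteCacheTransform rewrites the mappable-leaf case as follows (a sketch of the hookups performed above):

    before:  parent -> CacheOp -> LeafOp (sampler S)

    after:   parent -> CacheMergeOp -> CacheLookupOp (sampler S)    // cache-hit leg
                                    -> LeafOp (sampler = lookup)    // cache-miss leg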
mindspore/ccsrc/dataset/engine/opt/pre/cache_transform_pass.h (new file, mode 100644; diff collapsed, not shown)
mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.cc (diff collapsed, not shown)
mindspore/ccsrc/dataset/engine/opt/pre/removal_nodes.h
...
@@ -34,6 +34,18 @@ class RemovalNodes : public NodePass {
   /// \param[in] removal_pass Raw pointer back to controlling tree pass
   explicit RemovalNodes(RemovalPass *removal_pass);

+  /// \brief Identifies the subtree below this node as a cached descendant tree.
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
+
+  /// \brief Resets the tracking of the cache within the tree
+  /// \param[in] node The node being visited
+  /// \param[inout] modified Indicator if the node was changed at all
+  /// \return Status The error code return
+  Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified) override;
+
   /// \brief Perform ShuffleOp removal check
   /// \param[in] node The node being visited
   /// \param[inout] modified Indicator if the node was changed at all
...
mindspore/ccsrc/dataset/engine/opt/pre/removal_pass.cc (diff collapsed, not shown)
mindspore/ccsrc/dataset/util/allocator.h (diff collapsed, not shown)
mindspore/ccsrc/dataset/util/cache_pool.cc (diff collapsed, not shown)
mindspore/ccsrc/dataset/util/services.cc (diff collapsed, not shown)
mindspore/ccsrc/dataset/util/services.h (diff collapsed, not shown)
mindspore/dataset/__init__.py
...
@@ -24,6 +24,7 @@ from .engine.datasets import TFRecordDataset, ImageFolderDatasetV2, MnistDataset
     TextFileDataset, CLUEDataset, Schema, Shuffle, zip, RandomDataset
 from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \
     WeightedRandomSampler, Sampler
+from .engine.cache_client import DatasetCache
 from .engine.serializer_deserializer import serialize, deserialize, show
 from .engine.graphdata import GraphData
...
mindspore/dataset/engine/cache_client.py (new file, mode 100644; diff collapsed, not shown)
mindspore/dataset/engine/datasets.py (diff collapsed, not shown)
mindspore/dataset/engine/serializer_deserializer.py (diff collapsed, not shown)
mindspore/dataset/engine/validators.py (diff collapsed, not shown)
mindspore/dataset/text/validators.py (diff collapsed, not shown)
mindspore/dataset/transforms/vision/validators.py (diff collapsed, not shown)
tests/ut/cpp/dataset/c_api_test.cc (diff collapsed, not shown)
tests/ut/cpp/dataset/cache_op_test.cc (new file, mode 100644; diff collapsed, not shown)
tests/ut/data/dataset/golden/cache_map_01_result.npz (new file, mode 100644; binary file added)
tests/ut/data/dataset/golden/cache_map_02_result.npz (new file, mode 100644; diff collapsed, not shown)
tests/ut/python/dataset/test_cache_map.py (new file, mode 100644; diff collapsed, not shown)
tests/ut/python/dataset/test_cache_nomap.py (new file, mode 100644; diff collapsed, not shown)
tests/ut/python/dataset/test_random_dataset.py (diff collapsed, not shown)