Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
0dc7d425
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0dc7d425
编写于
11月 15, 2019
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feed improve dict/merge_patch
上级
01f2dca6
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
6892 addition
and
0 deletion
+6892
-0
build.sh
build.sh
+55
-0
paddle/fluid/feed/CMakeLists.txt
paddle/fluid/feed/CMakeLists.txt
+2
-0
paddle/fluid/feed/apply_feed_code.sh
paddle/fluid/feed/apply_feed_code.sh
+72
-0
paddle/fluid/feed/pybind/CMakeLists.txt
paddle/fluid/feed/pybind/CMakeLists.txt
+23
-0
paddle/fluid/feed/pybind/expand_api.cc
paddle/fluid/feed/pybind/expand_api.cc
+42
-0
paddle/fluid/feed/pybind/expand_api.h
paddle/fluid/feed/pybind/expand_api.h
+9
-0
paddle/fluid/feed/pybind/pybind.cc
paddle/fluid/feed/pybind/pybind.cc
+1753
-0
paddle/fluid/feed/src/CMakeLists.txt
paddle/fluid/feed/src/CMakeLists.txt
+2
-0
paddle/fluid/feed/src/common/CMakeLists.txt
paddle/fluid/feed/src/common/CMakeLists.txt
+1
-0
paddle/fluid/feed/src/common/bhopscotch_map.h
paddle/fluid/feed/src/common/bhopscotch_map.h
+675
-0
paddle/fluid/feed/src/common/bhopscotch_set.h
paddle/fluid/feed/src/common/bhopscotch_set.h
+529
-0
paddle/fluid/feed/src/common/dict_plugin.cc
paddle/fluid/feed/src/common/dict_plugin.cc
+42
-0
paddle/fluid/feed/src/common/dict_plugin.h
paddle/fluid/feed/src/common/dict_plugin.h
+128
-0
paddle/fluid/feed/src/common/hopscotch_growth_policy.h
paddle/fluid/feed/src/common/hopscotch_growth_policy.h
+348
-0
paddle/fluid/feed/src/common/hopscotch_hash.h
paddle/fluid/feed/src/common/hopscotch_hash.h
+1817
-0
paddle/fluid/feed/src/common/hopscotch_map.h
paddle/fluid/feed/src/common/hopscotch_map.h
+679
-0
paddle/fluid/feed/src/common/hopscotch_set.h
paddle/fluid/feed/src/common/hopscotch_set.h
+525
-0
paddle/fluid/feed/src/data_reader/CMakeLists.txt
paddle/fluid/feed/src/data_reader/CMakeLists.txt
+1
-0
paddle/fluid/feed/src/data_reader/data_set.cc
paddle/fluid/feed/src/data_reader/data_set.cc
+173
-0
paddle/fluid/feed/src/data_reader/data_set.h
paddle/fluid/feed/src/data_reader/data_set.h
+16
-0
未找到文件。
build.sh
0 → 100755
浏览文件 @
0dc7d425
#!/usr/bin/env bash
# Build driver for the FEED-enabled Paddle tree.
# Usage: sh build.sh all|make|clean
#
# The original shebang was `#!bash`, which is not a resolvable interpreter
# path, so executing the (mode 755) script directly failed.

# First positional argument selects the build mode: all | make | clean.
build_mode=$1
# Show the supported build modes, then end the script with status 0.
print_usage() {
    printf '%s\n' \
        "++++++++++++++++++++++++++++++++++++++++++++++++++++" \
        "sh build.sh all|make|clean" \
        "- all: will update all env && make it" \
        "- make: just do make, never update env" \
        "- clean: make clean" \
        "++++++++++++++++++++++++++++++++++++++++++++++++++++"
    exit 0
}
# A build mode is mandatory; print_usage terminates the script.
if [ $# -lt 1 ]; then
    print_usage
fi

# Resolve the home directory inside a subshell. The previous `cd ~; pwd; cd -`
# sequence changed the script's working directory and `cd -` echoed the old
# directory to stdout; the subshell leaves the caller's cwd untouched.
user_dir="$(cd ~ && pwd)"

# Python 2.7 toolchain expected under the user's jumbo installation.
python_binary="${user_dir}/.jumbo/bin/python2.7"
python_library="${user_dir}/.jumbo/lib/python2.7.so"
python_include_dir="${user_dir}/.jumbo/include/python2.7"

# Quoted so a home directory containing spaces does not break the test.
if [ ! -f "${python_binary}" ]; then
    echo "Miss python ${python_binary}, please install with this cmd: jumbo install python"
    # 255 is what bash produced for the former `exit -1`, spelled portably.
    exit 255
fi

# Apply the FEED code patches when the helper script is present. A failure is
# fatal: building a half-patched tree would only fail later and less clearly.
if [ -f "paddle/fluid/feed/apply_feed_code.sh" ]; then
    sh paddle/fluid/feed/apply_feed_code.sh || exit 255
fi
# Run the parallel build inside ./build.
#
# Returns make's exit status, or non-zero when ./build cannot be entered.
# The work happens in a subshell so the caller's working directory is the
# same on every path; previously a failed `cd build` ran make in the source
# root and the trailing `cd ..` then stepped out of the repository.
function makeit() {
    (
        cd build || exit 1
        make -j8
    )
}
# Configure the build tree in ./build with the FEED/PSLIB CPU settings.
#
# Reads python_include_dir, python_library and python_binary set earlier in
# the script. Returns cmake's exit status, or non-zero when ./build cannot be
# created or entered.
function cmake_all() {
    # -p: re-running `all` with an existing build directory is not an error.
    mkdir -p build || return 1
    (
        cd build || exit 1
        #make clean
        cmake -DCMAKE_INSTALL_PREFIX=./output/ \
              -DCMAKE_BUILD_TYPE=Release \
              -DWITH_PYTHON=ON \
              -DWITH_MKL=OFF \
              -DWITH_GPU=OFF \
              -DWITH_PSLIB=ON \
              "-DPYTHON_INCLUDE_DIR=${python_include_dir}" \
              "-DPYTHON_LIBRARY=${python_library}" \
              "-DPYTHON_EXECUTABLE=${python_binary}" \
              ..
    )
}
# Dispatch on the requested build mode.
#
# The former if/elif chain had lost the opening `[` of its "clean" test, so
# that branch tried to execute the mode string as a command and `clean` never
# ran. A case statement removes that class of mistake.
case "${build_mode}" in
    all)
        # Only build when configuration succeeded.
        cmake_all && makeit
        ;;
    make)
        makeit
        ;;
    clean)
        # Never run `make clean` outside the build directory.
        cd build && make clean
        ;;
    *)
        # Unknown mode: show the help text instead of silently doing nothing.
        print_usage
        ;;
esac
paddle/fluid/feed/CMakeLists.txt
0 → 100755
浏览文件 @
0dc7d425
# Register the FEED source tree and its Python binding directory with the build.
add_subdirectory(src)
add_subdirectory(pybind)
paddle/fluid/feed/apply_feed_code.sh
0 → 100755
浏览文件 @
0dc7d425
#!/usr/bin/env bash
# Applies the FEED customisations (e.g. FEED plugin registration) to the
# Paddle code base. Run this before compiling.
# Report a fatal error and abort the script.
#   $1 - message to print
#
# The message is written to stderr so it cannot be mistaken for regular
# output. The status is spelled 255 because `exit -1` is not a valid POSIX
# exit operand; bash maps it to 255, so callers see the same status as before.
function fatal_log() {
    echo "$1" >&2
    exit 255
}
# Handle the pybind extension.
#
# Patches paddle/fluid/pybind/pybind.cc so it includes expand_api.h and calls
# BindExpandApi, and patches paddle/fluid/pybind/CMakeLists.txt so PYBIND_DEPS
# gains feed_data_set, feed_paddle_pybind and fs. Lines left by an earlier run
# are deleted first, so re-running does not duplicate them. Aborts through
# fatal_log when a required file or anchor line is missing.
apply_pybind() {
    pybind_file='paddle/fluid/pybind/pybind.cc'
    [ ! -f ${pybind_file} ] && fatal_log "Missing Requied File: ${pybind_file}"

    # Each anchor must occur on exactly one line, otherwise the inserts below
    # would land in the wrong place or be repeated.
    find_inferece_api=$(grep 'inference_api.h' ${pybind_file} | wc -l)
    [ ${find_inferece_api} -ne 1 ] && fatal_log "Missing inference_api.h, Need Code Adjust"

    find_inferece_api=$(grep 'BindInferenceApi' ${pybind_file} | wc -l)
    [ ${find_inferece_api} -ne 1 ] && fatal_log "Missing BindInferenceApi, Need Code Adjust"

    makefile='paddle/fluid/pybind/CMakeLists.txt'
    [ ! -f ${makefile} ] && fatal_log "Missing Requied File: ${makefile}"

    # pybind.cc: drop previously injected lines, then inject after the anchors.
    sed -i '/expand_api/d' ${pybind_file}
    sed -i '/BindExpandApi/d' ${pybind_file}
    sed -i '/inference_api.h/a\#include "paddle/fluid/feed/pybind/expand_api.h"' ${pybind_file}
    sed -i '/BindInferenceApi/a\ BindExpandApi(&m);' ${pybind_file}

    # CMakeLists.txt: drop previously injected deps, then re-add them ahead of
    # the set(PYBIND_SRCS ...) line.
    sed -i '/feed_data_set/d' ${makefile}
    sed -i '/feed_paddle_pybind/d' ${makefile}
    sed -i '/APPEND PYBIND_DEPS fs/d' ${makefile}
    sed -i '/set(PYBIND_SRCS/i\list(APPEND PYBIND_DEPS feed_data_set)' ${makefile}
    sed -i '/set(PYBIND_SRCS/i\list(APPEND PYBIND_DEPS feed_paddle_pybind)' ${makefile}
    sed -i '/set(PYBIND_SRCS/i\list(APPEND PYBIND_DEPS fs)' ${makefile}
}
# Wire the FEED sources into the fluid build and register the FEED dataset.
#
# Adds add_subdirectory(feed) to paddle/fluid/CMakeLists.txt, then registers
# FeedMultiSlotDataset and its header in framework/dataset_factory.cc. Aborts
# through fatal_log when a required file or the pybind anchor is missing.
apply_feed_src() {
    makefile='paddle/fluid/CMakeLists.txt'
    [ ! -f ${makefile} ] && fatal_log "Missing Requied File: ${makefile}"

    # The feed directory is inserted relative to the single pybind line.
    find_py=$(grep 'pybind' ${makefile} | wc -l)
    [ ${find_py} -ne 1 ] && fatal_log "Missing pybind, Need Code Adjust"

    # NOTE(review): this removes every line containing "feed", not only the
    # one inserted below - confirm no unrelated line matches.
    sed -i '/feed/d' ${makefile}
    sed -i '/pybind/i\add_subdirectory(feed)' ${makefile}

    dataset_file='paddle/fluid/framework/dataset_factory.cc'
    [ ! -f ${dataset_file} ] && fatal_log "Missing Requied File: ${dataset_file}"

    # Remove what an earlier run injected, then inject after the anchors.
    sed -i '/FeedMultiSlotDataset/d' ${dataset_file}
    sed -i '/data_reader/d' ${dataset_file}
    sed -i '/REGISTER_DATASET_CLASS(MultiSlotDataset)/a\REGISTER_DATASET_CLASS(FeedMultiSlotDataset);' ${dataset_file}
    sed -i '/data_set.h/a\#include "paddle/fluid/feed/src/data_reader/data_set.h"' ${dataset_file}

    # Only the cleanup is active for the framework CMake file; the inserts
    # below are kept disabled as in the original.
    sed -i '/feed_data_set/d' paddle/fluid/framework/CMakeLists.txt
    #sed -i '/target_link_libraries(executor/a\target_link_libraries(feed_data_set)' paddle/fluid/framework/CMakeLists.txt
    #sed -i '/target_link_libraries(executor/a\add_dependencies(feed_data_set)' paddle/fluid/framework/CMakeLists.txt
}
# Script entry point: patch the pybind files first, then the fluid build files
# and the dataset factory.
apply_pybind
apply_feed_src
paddle/fluid/feed/pybind/CMakeLists.txt
0 → 100755
浏览文件 @
0dc7d425
# Libraries the FEED pybind extension depends on.
set(FEED_PYBIND_DEPS
    pybind python proto_desc memory executor fleet_wrapper box_wrapper
    pass_builder parallel_executor profiler layer tracer engine scope_pool
    dict_plugin fs shell)

# Sources compiled into the extension library.
set(FEED_PYBIND_SRCS expand_api.cc)

# The library is only defined for Python-enabled builds.
if(WITH_PYTHON)
  list(APPEND FEED_PYBIND_DEPS py_func_op)

  if(WITH_AMD_GPU)
    hip_library(feed_paddle_pybind
                SRCS ${FEED_PYBIND_SRCS}
                DEPS ARCHIVE_START ${FEED_PYBIND_DEPS} ARCHIVE_END)
  else()
    cc_library(feed_paddle_pybind
               SRCS ${FEED_PYBIND_SRCS}
               DEPS ${FEED_PYBIND_DEPS})
  endif()

  # Link the modules recorded in the global OS_DEPENDENCY_MODULES property.
  get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
  target_link_libraries(feed_paddle_pybind ${os_dependency_modules})
endif()
paddle/fluid/feed/pybind/expand_api.cc
0 → 100755
浏览文件 @
0dc7d425
#include "paddle/fluid/feed/pybind/expand_api.h"
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <cstring>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/feed/src/common/dict_plugin.h"
namespace py = pybind11;

namespace paddle {
namespace pybind {

using paddle::framework::DictPluginManager;
using paddle::framework::FeasignCacheDict;

void BindExpandDictPlugin(py::module *m);

// Entry point for the FEED extension bindings; currently this only registers
// the dictionary plugin classes.
void BindExpandApi(py::module *m) { BindExpandDictPlugin(m); }

// Exposes FeasignCacheDict and DictPluginManager to Python on module `m`.
void BindExpandDictPlugin(py::module *m) {
  // FeasignCacheDict: default- and copy-constructible, plus `load`.
  py::class_<FeasignCacheDict> feasign_cache_dict(*m, "FeasignCacheDict");
  feasign_cache_dict.def(py::init<>());
  feasign_cache_dict.def(py::init<const FeasignCacheDict &>());
  feasign_cache_dict.def("load", &FeasignCacheDict::Load);

  // DictPluginManager: constructor, static `instance` accessor and the two
  // dictionary operations.
  // NOTE(review): no return_value_policy is given for `instance`; if
  // Instance() returns a reference, pybind11's default policy hands Python a
  // copy rather than the shared object - confirm this is intended.
  py::class_<DictPluginManager> dict_plugin_manager(*m, "DictPluginManager");
  dict_plugin_manager.def(py::init<>());
  dict_plugin_manager.def_static("instance", &DictPluginManager::Instance);
  dict_plugin_manager.def("load_dict", &DictPluginManager::LoadDict);
  dict_plugin_manager.def("create_dict", &DictPluginManager::CreateDict);
}

}  // namespace pybind
}  // namespace paddle
paddle/fluid/feed/pybind/expand_api.h
0 → 100755
浏览文件 @
0dc7d425
#pragma once
#include <pybind11/pybind11.h>
namespace paddle {
namespace pybind {

// Registers the FEED extension bindings on the pybind11 module `m`.
void BindExpandApi(pybind11::module *m);

}  // namespace pybind
}  // namespace paddle
paddle/fluid/feed/pybind/pybind.cc
0 → 100755
浏览文件 @
0dc7d425
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <Python.h>
#include <algorithm>
#include <cstdlib>
#include <map>
#include <memory>
#include <mutex> // NOLINT // for call_once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/ir/coalesce_grad_tensor_pass.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/scope_pool.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/box_helper_py.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/fleet_wrapper_py.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h"
#include "paddle/fluid/pybind/ir.h"
#include "paddle/fluid/pybind/expand_api.h"
#ifndef _WIN32
#include "paddle/fluid/pybind/nccl_wrapper_py.h"
#endif
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/pybind/protobuf.h"
#include "paddle/fluid/pybind/pybind.h" // NOLINT
#include "paddle/fluid/pybind/reader_py.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/fluid/string/to_string.h"
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#endif
#include "paddle/fluid/platform/cuda_profiler.h"
#include "paddle/fluid/platform/gpu_info.h"
#endif
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/pybind/communicator_py.h"
#endif
#include "pybind11/stl.h"
// Command-line flag owned by this file; defaults to false.
DEFINE_bool(reader_queue_speed_test_mode, false,
            "If set true, the queue.pop will only get data from queue but not "
            "remove the data from queue for speed testing");

// Flags that are only declared here (presumably defined in another
// translation unit - confirm).
DECLARE_bool(use_mkldnn);
#ifdef PADDLE_WITH_NGRAPH
DECLARE_bool(use_ngraph);
#endif

// disable auto conversion to list in Python
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);
namespace
paddle
{
namespace
pybind
{
// True when the build defines PADDLE_WITH_CUDA, false otherwise.
bool IsCompiledWithCUDA() {
#ifdef PADDLE_WITH_CUDA
  return true;
#else
  return false;
#endif
}
// True when the build defines PADDLE_WITH_MKLDNN, false otherwise.
bool IsCompiledWithMKLDNN() {
#ifdef PADDLE_WITH_MKLDNN
  constexpr bool kWithMkldnn = true;
#else
  constexpr bool kWithMkldnn = false;
#endif
  return kWithMkldnn;
}
// True when the build defines PADDLE_WITH_NGRAPH, false otherwise.
bool IsCompiledWithNGRAPH() {
#if defined(PADDLE_WITH_NGRAPH)
  return true;
#else
  return false;
#endif
}
// True only when PADDLE_WITH_DISTRIBUTE is defined and PADDLE_WITH_GRPC is
// not; every other combination yields false.
bool IsCompiledWithBrpc() {
#if defined(PADDLE_WITH_DISTRIBUTE) && !defined(PADDLE_WITH_GRPC)
  return true;
#else
  return false;
#endif
}
// True when the build defines PADDLE_WITH_DISTRIBUTE, false otherwise.
bool IsCompiledWithDIST() {
#ifndef PADDLE_WITH_DISTRIBUTE
  return false;
#else
  return true;
#endif
}
// Converts both arguments to paddle::platform::Place and compares the results
// with operator==, so two different concrete place types can be compared.
template <typename PlaceType1, typename PlaceType2>
static inline bool IsSamePlace(const PlaceType1 &p1, const PlaceType2 &p2) {
  const auto lhs = paddle::platform::Place(p1);
  const auto rhs = paddle::platform::Place(p2);
  return lhs == rhs;
}
// Converts `p` to paddle::platform::Place and returns the result of its
// which() call as an int.
template <typename PlaceType>
static inline int PlaceIndex(const PlaceType &p) {
  const auto place = paddle::platform::Place(p);
  return static_cast<int>(place.which());
}
#ifdef PADDLE_WITH_AVX
PYBIND11_MODULE
(
core_avx
,
m
)
{
#else
PYBIND11_MODULE
(
core_noavx
,
m
)
{
#endif
// Not used, just make sure cpu_info.cc is linked.
paddle
::
platform
::
CpuTotalPhysicalMemory
();
paddle
::
memory
::
allocation
::
UseAllocatorStrategyGFlag
();
m
.
doc
()
=
"C++ core of PaddlePaddle"
;
// using framework in this function. Since it is inside a function, it will
// not cause namespace pollution.
using
namespace
paddle
::
framework
;
// NOLINT
BindException
(
&
m
);
m
.
def
(
"set_num_threads"
,
&
platform
::
SetNumThreads
);
m
.
def
(
"_append_python_callable_object_and_return_id"
,
[](
py
::
object
py_obj
)
->
size_t
{
return
paddle
::
operators
::
AppendPythonCallableObjectAndReturnId
(
py_obj
);
});
m
.
def
(
"_get_use_default_grad_op_desc_maker_ops"
,
[]
{
return
OpInfoMap
::
Instance
().
GetUseDefaultGradOpDescMakerOps
();
});
// NOTE(zjl): ctest would load environment variables at the beginning even
// though we have not `import paddle.fluid as fluid`. So we add this API
// to enable eager deletion mode in unittest.
m
.
def
(
"_set_eager_deletion_mode"
,
&
paddle
::
framework
::
SetEagerDeletionMode
);
m
.
def
(
"_set_fuse_parameter_group_size"
,
&
paddle
::
framework
::
ir
::
SetFuseParameterGroupsSize
);
m
.
def
(
"_set_fuse_parameter_memory_size"
,
&
paddle
::
framework
::
ir
::
SetFuseParameterMemorySize
);
m
.
add_object
(
"_cleanup"
,
py
::
capsule
([]()
{
ScopePool
::
Instance
().
Clear
();
}));
m
.
def
(
"_set_paddle_lib_path"
,
&
paddle
::
platform
::
dynload
::
SetPaddleLibPath
);
BindImperative
(
&
m
);
py
::
class_
<
Tensor
>
(
m
,
"Tensor"
,
py
::
buffer_protocol
())
.
def
(
"__array__"
,
[](
Tensor
&
self
)
{
return
TensorToPyArray
(
self
);
})
.
def
(
"_is_initialized"
,
[](
const
Tensor
&
self
)
{
return
self
.
IsInitialized
();
})
.
def
(
"_get_dims"
,
[](
const
Tensor
&
self
)
{
return
vectorize
(
self
.
dims
());
})
.
def
(
"_set_dims"
,
[](
Tensor
&
self
,
const
std
::
vector
<
int64_t
>
&
dim
)
{
self
.
Resize
(
make_ddim
(
dim
));
})
.
def
(
"_set_layout"
,
[](
Tensor
&
self
,
const
std
::
string
&
layout
)
{
self
.
set_layout
(
StringToDataLayout
(
layout
));
})
.
def
(
"_alloc_float"
,
[](
Tensor
&
self
,
paddle
::
platform
::
CUDAPlace
&
place
)
{
self
.
mutable_data
<
float
>
(
place
);
})
.
def
(
"_alloc_float"
,
[](
Tensor
&
self
,
paddle
::
platform
::
CPUPlace
&
place
)
{
self
.
mutable_data
<
float
>
(
place
);
})
.
def
(
"_alloc_double"
,
[](
Tensor
&
self
,
paddle
::
platform
::
CPUPlace
&
place
)
{
self
.
mutable_data
<
double
>
(
place
);
})
.
def
(
"_alloc_int"
,
[](
Tensor
&
self
,
paddle
::
platform
::
CPUPlace
&
place
)
{
self
.
mutable_data
<
int
>
(
place
);
})
.
def
(
"_alloc_int"
,
[](
Tensor
&
self
,
paddle
::
platform
::
CUDAPlace
&
place
)
{
self
.
mutable_data
<
int
>
(
place
);
})
.
def
(
"_alloc_int"
,
[](
Tensor
&
self
,
paddle
::
platform
::
CUDAPinnedPlace
&
place
)
{
self
.
mutable_data
<
int
>
(
place
);
})
.
def
(
"_alloc_float"
,
[](
Tensor
&
self
,
paddle
::
platform
::
CUDAPinnedPlace
&
place
)
{
self
.
mutable_data
<
float
>
(
place
);
})
.
def
(
"_clear"
,
&
Tensor
::
clear
)
.
def
(
"set"
,
PyCPUTensorSetFromArray
<
float
>
)
.
def
(
"set"
,
PyCPUTensorSetFromArray
<
int
>
)
.
def
(
"set"
,
PyCPUTensorSetFromArray
<
double
>
)
.
def
(
"set"
,
PyCPUTensorSetFromArray
<
int64_t
>
)
.
def
(
"set"
,
PyCPUTensorSetFromArray
<
bool
>
)
.
def
(
"set"
,
PyCPUTensorSetFromArray
<
uint16_t
>
)
.
def
(
"set"
,
PyCPUTensorSetFromArray
<
uint8_t
>
)
.
def
(
"set"
,
PyCPUTensorSetFromArray
<
int8_t
>
)
#ifdef PADDLE_WITH_CUDA
.
def
(
"set"
,
PyCUDATensorSetFromArray
<
float
>
)
.
def
(
"set"
,
PyCUDATensorSetFromArray
<
int
>
)
.
def
(
"set"
,
PyCUDATensorSetFromArray
<
double
>
)
.
def
(
"set"
,
PyCUDATensorSetFromArray
<
int64_t
>
)
.
def
(
"set"
,
PyCUDATensorSetFromArray
<
bool
>
)
.
def
(
"set"
,
PyCUDATensorSetFromArray
<
uint16_t
>
)
.
def
(
"set"
,
PyCUDATensorSetFromArray
<
uint8_t
>
)
.
def
(
"set"
,
PyCUDATensorSetFromArray
<
int8_t
>
)
.
def
(
"set"
,
PyCUDAPinnedTensorSetFromArray
<
float
>
)
.
def
(
"set"
,
PyCUDAPinnedTensorSetFromArray
<
int
>
)
.
def
(
"set"
,
PyCUDAPinnedTensorSetFromArray
<
double
>
)
.
def
(
"set"
,
PyCUDAPinnedTensorSetFromArray
<
int64_t
>
)
.
def
(
"set"
,
PyCUDAPinnedTensorSetFromArray
<
bool
>
)
.
def
(
"set"
,
PyCUDAPinnedTensorSetFromArray
<
uint16_t
>
)
.
def
(
"set"
,
PyCUDAPinnedTensorSetFromArray
<
uint8_t
>
)
.
def
(
"set"
,
PyCUDAPinnedTensorSetFromArray
<
int8_t
>
)
#endif
.
def
(
"shape"
,
[](
Tensor
&
self
)
{
return
vectorize
(
self
.
dims
());
})
.
def
(
"_set_float_element"
,
TensorSetElement
<
float
>
)
.
def
(
"_get_float_element"
,
TensorGetElement
<
float
>
)
.
def
(
"_set_double_element"
,
TensorSetElement
<
double
>
)
.
def
(
"_get_double_element"
,
TensorGetElement
<
double
>
)
.
def
(
"_place"
,
[](
Tensor
&
self
)
{
return
self
.
place
();
})
.
def
(
"_dtype"
,
[](
Tensor
&
self
)
{
return
self
.
type
();
})
.
def
(
"__getitem__"
,
PySliceTensor
,
py
::
return_value_policy
::
reference
)
.
def
(
"__str__"
,
[](
const
Tensor
&
self
)
{
std
::
stringstream
ostr
;
ostr
<<
self
;
return
ostr
.
str
();
});
py
::
class_
<
LoDTensor
,
Tensor
>
(
m
,
"LoDTensor"
,
R"DOC(
LoDTensor is a Tensor with optional LoD information.
np.array(lod_tensor) can convert LoDTensor to numpy array.
lod_tensor.lod() can retrieve the LoD information.
LoD is short for Level of Details and is usually used for varied sequence
length. You can skip the following comment if you don't need optional LoD.
For example, a LoDTensor X can look like the example below. It contains
2 sequences. The first has length 2 and the second has length 3, as
described by x.lod.
The first tensor dimension 5=2+3 is calculated from LoD if it's available.
It means the total number of sequence element. In X, each element has 2
columns, hence [5, 2].
x.lod = [[2, 3]]
x.data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
x.shape = [5, 2]
LoD can have multiple levels (for example, a paragraph can have multiple
sentences and a sentence can have multiple words). In the following
LodTensor Y, the lod_level is 2. It means there are 2 sequence, the
first sequence length is 2 (has 2 sub-sequences), the second one's
length is 1. The first sequence's 2 sub-sequences have length 2 and 2,
respectively. And the second sequence's 1 sub-sequence has length 3.
y.lod = [[2 1], [2 2 3]]
y.shape = [2+2+3, ...]
Examples:
.. code-block:: python
import paddle.fluid as fluid
t = fluid.LoDTensor()
Note:
In above description, LoD is length-based. In Paddle internal
implementation, lod is offset-based. Hence, internally,
y.lod is represented as [[0, 2, 3], [0, 2, 4, 7]] (length-based
equivlent would be [[2-0, 3-2], [2-0, 4-2, 7-4]]).
Sometimes LoD is called recursive_sequence_length to be more
self-explanatory. In this case, it must be length-based. Due to history
reasons. when LoD is called lod in public API, it might be offset-based.
Users should be careful about it.
)DOC"
)
.
def
(
"__array__"
,
[](
Tensor
&
self
)
{
return
TensorToPyArray
(
self
);
})
.
def
(
"__init__"
,
[](
LoDTensor
&
instance
,
const
std
::
vector
<
std
::
vector
<
size_t
>>
&
recursive_sequence_lengths
)
{
LoD
new_lod
;
new_lod
.
reserve
(
recursive_sequence_lengths
.
size
());
std
::
copy
(
recursive_sequence_lengths
.
begin
(),
recursive_sequence_lengths
.
end
(),
std
::
back_inserter
(
new_lod
));
LoD
new_offset_lod
=
ConvertToOffsetBasedLoD
(
new_lod
);
PADDLE_ENFORCE_EQ
(
CheckLoD
(
new_offset_lod
,
-
1
),
true
,
"the provided recursive_sequence_lengths info is invalid"
);
new
(
&
instance
)
LoDTensor
(
new_offset_lod
);
})
.
def
(
"__init__"
,
[](
LoDTensor
&
instance
)
{
new
(
&
instance
)
LoDTensor
();
})
// We implement offset based LOD in C++ while we use length based with
// Python API. So we changed set_lod to set_recursive_sequence_lengths to
// avoid misuse.
// The discussion is here:
// https://github.com/PaddlePaddle/Paddle/issues/10855
.
def
(
"set_lod"
,
[](
LoDTensor
&
self
,
const
std
::
vector
<
std
::
vector
<
size_t
>>
&
lod
)
{
// the input lod is offset-based level-of-detail info
LoD
new_lod
;
new_lod
.
reserve
(
lod
.
size
());
std
::
copy
(
lod
.
begin
(),
lod
.
end
(),
std
::
back_inserter
(
new_lod
));
PADDLE_ENFORCE_EQ
(
CheckLoD
(
new_lod
,
vectorize
(
self
.
dims
()).
front
()),
true
,
"the provided lod info is invalid"
);
self
.
set_lod
(
new_lod
);
},
py
::
arg
(
"lod"
),
R"DOC(
Set LoD of the LoDTensor.
Args:
lod (List[List[int]]): the lod to be set.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
t = fluid.LoDTensor()
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
t.set_lod([[0, 2, 5]])
)DOC"
)
.
def
(
"set_recursive_sequence_lengths"
,
[](
LoDTensor
&
self
,
const
std
::
vector
<
std
::
vector
<
size_t
>>
&
recursive_sequence_lengths
)
{
// the input recursive_sequence_lengths is length-based
// level-of-detail info
LoD
new_lod
;
new_lod
.
reserve
(
recursive_sequence_lengths
.
size
());
std
::
copy
(
recursive_sequence_lengths
.
begin
(),
recursive_sequence_lengths
.
end
(),
std
::
back_inserter
(
new_lod
));
LoD
new_offset_lod
=
ConvertToOffsetBasedLoD
(
new_lod
);
PADDLE_ENFORCE_EQ
(
CheckLoD
(
new_offset_lod
,
vectorize
(
self
.
dims
()).
front
()),
true
,
"the provided recursive_sequence_lengths info is invalid"
);
self
.
set_lod
(
new_offset_lod
);
},
py
::
arg
(
"recursive_sequence_lengths"
),
R"DOC(
Set LoD of the LoDTensor according to recursive sequence length.
For example, if recursive_sequence_lengths=[[2, 3]], meaning that
there are two sequences with length 2 and 3 respectively, the
corresponding lod would be [[0, 2, 2+3]], i.e, [[0, 2, 5]].
Args:
recursive_sequence_lengths (List[List[int]]): sequence lengths.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
t = fluid.LoDTensor()
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
t.set_recursive_sequence_lengths([[2, 3]])
)DOC"
)
.
def
(
"lod"
,
[](
LoDTensor
&
self
)
->
std
::
vector
<
std
::
vector
<
size_t
>>
{
// output the offset-based lod info
LoD
lod
=
self
.
lod
();
std
::
vector
<
std
::
vector
<
size_t
>>
new_lod
;
new_lod
.
reserve
(
lod
.
size
());
std
::
copy
(
lod
.
begin
(),
lod
.
end
(),
std
::
back_inserter
(
new_lod
));
return
new_lod
;
},
R"DOC(
Return the LoD of the LoDTensor.
Returns:
out (List[List[int]]): the lod of the LoDTensor.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
t = fluid.LoDTensor()
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
t.set_lod([[0, 2, 5]])
print(t.lod()) # [[0, 2, 5]]
)DOC"
)
// Set above comments of set_lod.
.
def
(
"recursive_sequence_lengths"
,
[](
LoDTensor
&
self
)
->
std
::
vector
<
std
::
vector
<
size_t
>>
{
// output the length-based lod info
LoD
lod
=
ConvertToLengthBasedLoD
(
self
.
lod
());
std
::
vector
<
std
::
vector
<
size_t
>>
new_lod
;
new_lod
.
reserve
(
lod
.
size
());
std
::
copy
(
lod
.
begin
(),
lod
.
end
(),
std
::
back_inserter
(
new_lod
));
return
new_lod
;
},
R"DOC(
Return the sequence length of the LoDTensor corresponding to LoD.
Returns:
out (List[List[int]): the sequence lengths.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
t = fluid.LoDTensor()
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
t.set_recursive_sequence_lengths([[2, 3]])
print(t.recursive_sequence_lengths()) # [[2, 3]]
)DOC"
)
.
def
(
"has_valid_recursive_sequence_lengths"
,
[](
LoDTensor
&
self
)
->
bool
{
// Check that the lod info is valid and match the outermost
// dimension of the LoDTensor data
return
CheckLoD
(
self
.
lod
(),
vectorize
(
self
.
dims
()).
front
());
},
R"DOC(
Check whether the lod of the LoDTensor is valid.
Returns:
out (bool): whether the lod is valid.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
t = fluid.LoDTensor()
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
t.set_recursive_sequence_lengths([[2, 3]])
print(t.has_valid_recursive_sequence_lengths()) # True
)DOC"
)
.
def
(
"__getitem__"
,
PySliceTensor
,
py
::
return_value_policy
::
reference
,
R"DOC(
Slice the original Tensor, and remove the LoD information.
Returns:
out (Tensor): new Tensor(NOT LoDTensor).
)DOC"
)
.
def
(
"__str__"
,
[](
const
LoDTensor
&
self
)
{
std
::
stringstream
ostr
;
ostr
<<
self
;
return
ostr
.
str
();
})
.
def
(
"_copy"
,
[](
const
LoDTensor
&
self
,
const
platform
::
Place
&
place
)
{
// follow fetch_op's inplementation
LoDTensor
dst
;
if
(
self
.
IsInitialized
()
&&
self
.
numel
()
>
0
)
{
TensorCopySync
(
self
,
place
,
&
dst
);
}
else
{
// Not copy, if the src tensor is empty.
dst
.
clear
();
dst
.
Resize
({
0
});
}
dst
.
set_lod
(
self
.
lod
());
return
dst
;
});
py
::
class_
<
SelectedRows
>
(
m
,
"SelectedRows"
)
.
def
(
"__init__"
,
[](
SelectedRows
&
instance
)
{
new
(
&
instance
)
SelectedRows
();
})
.
def
(
"__init__"
,
[](
SelectedRows
&
instance
,
const
std
::
vector
<
int64_t
>
rows
,
const
int64_t
&
height
)
{
new
(
&
instance
)
SelectedRows
(
rows
,
height
);
})
.
def
(
"get_tensor"
,
[](
SelectedRows
&
self
)
{
return
self
.
mutable_value
();
},
py
::
return_value_policy
::
reference
)
.
def
(
"numel"
,
[](
SelectedRows
&
self
)
->
int64_t
{
return
self
.
value
().
numel
();
})
.
def
(
"set_height"
,
&
SelectedRows
::
set_height
)
.
def
(
"height"
,
&
SelectedRows
::
height
)
.
def
(
"set_rows"
,
[](
SelectedRows
&
self
,
std
::
vector
<
int64_t
>
rows
)
{
#ifndef PADDLE_WITH_CUDA
self
.
set_rows
(
rows
);
#else
Vector
<
int64_t
>
new_rows
(
rows
);
self
.
set_rows
(
new_rows
);
#endif
})
.
def
(
"sync_index"
,
[](
SelectedRows
&
instance
)
{
instance
.
SyncIndex
();
})
.
def
(
"rows"
,
[](
SelectedRows
&
self
)
{
auto
rows
=
self
.
rows
();
std
::
vector
<
int64_t
>
new_rows
;
new_rows
.
reserve
(
rows
.
size
());
std
::
copy
(
rows
.
begin
(),
rows
.
end
(),
std
::
back_inserter
(
new_rows
));
return
new_rows
;
});
py
::
class_
<
Variable
>
(
m
,
"Variable"
,
R"DOC(Variable Class.
All parameter, weight, gradient are variables in Paddle.
)DOC"
)
.
def
(
py
::
init
<>
())
.
def
(
"is_int"
,
[](
const
Variable
&
var
)
{
return
var
.
IsType
<
int
>
();
})
.
def
(
"set_int"
,
[](
Variable
&
var
,
int
val
)
->
void
{
*
var
.
GetMutable
<
int
>
()
=
val
;
})
.
def
(
"get_int"
,
[](
const
Variable
&
var
)
->
int
{
return
var
.
Get
<
int
>
();
})
.
def
(
"is_float"
,
[](
const
Variable
&
var
)
{
return
var
.
IsType
<
float
>
();
})
.
def
(
"set_float"
,
[](
Variable
&
var
,
float
val
)
->
void
{
*
var
.
GetMutable
<
float
>
()
=
val
;
})
.
def
(
"get_float"
,
[](
const
Variable
&
var
)
->
float
{
return
var
.
Get
<
float
>
();
})
.
def
(
"get_tensor"
,
[](
Variable
&
self
)
->
LoDTensor
*
{
return
self
.
GetMutable
<
LoDTensor
>
();
},
py
::
return_value_policy
::
reference
)
.
def
(
"get_lod_rank_table"
,
[](
Variable
&
self
)
{
return
self
.
GetMutable
<
LoDRankTable
>
();
},
py
::
return_value_policy
::
reference
)
.
def
(
"get_selected_rows"
,
[](
Variable
&
self
)
->
SelectedRows
*
{
return
self
.
GetMutable
<
SelectedRows
>
();
},
py
::
return_value_policy
::
reference
)
.
def
(
"get_lod_tensor_array"
,
[](
Variable
&
self
)
{
return
self
.
GetMutable
<
LoDTensorArray
>
();
},
py
::
return_value_policy
::
reference
)
#if (defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
.
def
(
"get_communicator"
,
[](
Variable
&
self
)
->
platform
::
Communicator
*
{
return
self
.
GetMutable
<
platform
::
Communicator
>
();
},
py
::
return_value_policy
::
reference
)
#endif
.
def
(
"get_reader"
,
[](
Variable
&
self
)
->
framework
::
ReaderHolder
*
{
PADDLE_ENFORCE_EQ
(
self
.
IsType
<
framework
::
ReaderHolder
>
(),
true
);
return
self
.
GetMutable
<
framework
::
ReaderHolder
>
();
},
py
::
return_value_policy
::
reference
);
BindReader
(
&
m
);
using
LoDTensorBlockingQueue
=
::
paddle
::
operators
::
reader
::
LoDTensorBlockingQueue
;
using
LoDTensorBlockingQueueHolder
=
::
paddle
::
operators
::
reader
::
LoDTensorBlockingQueueHolder
;
py
::
class_
<
LoDTensorBlockingQueue
,
std
::
shared_ptr
<
LoDTensorBlockingQueue
>>
(
m
,
"LoDTensorBlockingQueue"
,
""
)
.
def
(
"push"
,
[](
LoDTensorBlockingQueue
&
self
,
const
std
::
vector
<
framework
::
LoDTensor
>
&
lod_tensor_vec
)
{
pybind11
::
gil_scoped_release
release
;
return
self
.
Push
(
lod_tensor_vec
);
})
.
def
(
"size"
,
&
LoDTensorBlockingQueue
::
Size
)
.
def
(
"capacity"
,
&
LoDTensorBlockingQueue
::
Cap
)
.
def
(
"close"
,
&
LoDTensorBlockingQueue
::
Close
)
.
def
(
"is_closed"
,
&
LoDTensorBlockingQueue
::
IsClosed
);
m
.
def
(
"init_lod_tensor_blocking_queue"
,
[](
Variable
&
var
,
size_t
capacity
)
->
std
::
shared_ptr
<
LoDTensorBlockingQueue
>
{
VLOG
(
1
)
<<
"init_lod_tensor_blocking_queue"
;
auto
*
holder
=
var
.
GetMutable
<
LoDTensorBlockingQueueHolder
>
();
holder
->
InitOnce
(
capacity
,
FLAGS_reader_queue_speed_test_mode
);
return
holder
->
GetQueue
();
},
py
::
return_value_policy
::
copy
);
py
::
class_
<
Scope
>
(
m
,
"_Scope"
,
R"DOC(
Scope is an association of a name to Variable. All variables belong to Scope.
Variables in a parent scope can be retrieved from local scope.
You need to specify a scope to run a Net, i.e., `exe.Run(&scope)`.
One net can run in different scopes and update different variable in the
scope.
You can create var in a scope and get it from the scope.
Examples:
.. code-block:: python
import paddle.fluid as fluid
# create tensor from a scope and set value to it.
param = scope.var('Param').get_tensor()
param_array = np.full((height, row_numel), 5.0).astype("float32")
param.set(param_array, place)
)DOC"
)
.
def
(
"_remove_from_pool"
,
[](
Scope
&
self
)
{
ScopePool
::
Instance
().
Remove
(
&
self
);
})
.
def
(
"var"
,
[](
Scope
&
self
,
const
std
::
string
&
name
)
->
Variable
*
{
return
self
.
Var
(
name
);
},
py
::
arg
(
"name"
),
R"DOC(
Find or create variable named :code:`name` in the current scope.
If the variable named :code:`name` does not exist in the
current scope, the variable would be created. Otherwise,
return the existing variable.
Args:
name (str): the variable name.
Returns:
out (core.Variable): the found or created variable.
)DOC"
,
py
::
return_value_policy
::
reference
)
.
def
(
"find_var"
,
&
Scope
::
FindVar
,
py
::
arg
(
"name"
),
R"DOC(
Find variable named :code:`name` in the current scope or
its parent scope. Return None if not found.
Args:
name (str): the variable name.
Returns:
out (core.Variable|None): the found variable or None.
)DOC"
,
py
::
return_value_policy
::
reference
)
.
def
(
"new_scope"
,
[](
Scope
&
self
)
->
Scope
*
{
return
&
self
.
NewScope
();
},
R"DOC(
Create a new sub-scope of the current scope.
Returns:
out (core._Scope): the created sub-scope.
)DOC"
,
py
::
return_value_policy
::
reference
)
.
def
(
"drop_kids"
,
&
Scope
::
DropKids
,
R"DOC(
Delete all sub-scopes of the current scope.
)DOC"
)
.
def
(
"_kids"
,
&
Scope
::
kids
);
m
.
def
(
"Scope"
,
[]()
->
Scope
*
{
auto
*
s
=
new
Scope
();
ScopePool
::
Instance
().
Insert
(
std
::
unique_ptr
<
Scope
>
(
s
));
return
s
;
},
R"DOC(
Create a new scope.
Returns:
out (core._Scope): the created scope.
)DOC"
,
py
::
return_value_policy
::
reference
);
//! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python.
// Walks the operator registry and returns the serialized proto of every
// entry that reports HasOpProtoAndChecker(), one py::bytes per operator.
m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
  std::vector<py::bytes> serialized_protos;
  for (auto &entry : OpInfoMap::Instance().map()) {
    auto &op_info = entry.second;
    if (!op_info.HasOpProtoAndChecker()) {
      continue;  // nothing to serialize for this entry
    }
    std::string buffer;
    PADDLE_ENFORCE_EQ(
        op_info.Proto().SerializeToString(&buffer), true,
        "Serialize OpProto Error. This could be a bug of Paddle.");
    serialized_protos.emplace_back(buffer);
  }
  return serialized_protos;
});
// "get_grad_op_desc": runs the registered grad-op maker for `op_desc` and
// returns a pair of (released raw OpDesc pointers, grad-to-var name map).
m.def("get_grad_op_desc",
      [](const OpDesc &op_desc,
         const std::unordered_set<std::string> &no_grad_set,
         const std::vector<BlockDesc *> &grad_sub_block) {
        std::unordered_map<std::string, std::string> grad_to_var;
        std::vector<std::unique_ptr<OpDesc>> owned_grad_ops =
            framework::OpInfoMap::Instance()
                .Get(op_desc.Type())
                .GradOpMaker()(op_desc, no_grad_set, &grad_to_var,
                               grad_sub_block);
        // Give up ownership of every generated description; the raw
        // pointers are what gets handed back to the caller.
        std::vector<OpDesc *> raw_grad_ops;
        raw_grad_ops.reserve(owned_grad_ops.size());
        for (auto &owned : owned_grad_ops) {
          raw_grad_ops.push_back(owned.release());
        }
        return std::make_pair(raw_grad_ops, grad_to_var);
      });
// "has_grad_op_maker": reports whether the operator registered under
// `op_type` has a grad-op maker. The name is taken by const reference so the
// string is not copied again inside the binding.
m.def("has_grad_op_maker", [](const std::string &op_type) {
  return framework::OpInfoMap::Instance().Get(op_type).HasGradOpMaker();
});
// "has_infer_inplace": reports whether the operator registered under
// `op_type` has an infer-inplace function. The name is taken by const
// reference so the string is not copied again inside the binding.
m.def("has_infer_inplace", [](const std::string &op_type) {
  return framework::OpInfoMap::Instance().Get(op_type).HasInferInplace();
});
// Read-only accessors for global flags.
m.def("get_flags_use_mkldnn", []() { return FLAGS_use_mkldnn; });
#ifdef PADDLE_WITH_NGRAPH
// Only bound when the build defines PADDLE_WITH_NGRAPH.
m.def("get_flags_use_ngraph", []() { return FLAGS_use_ngraph; });
#endif
// "prune": copies `origin`, marks every op addressed by `targets` as a
// target, runs Prune() on the copy's proto and returns a newly allocated
// ProgramDesc built from the pruned proto.
// Each element of `targets` is used as {block index, op index}.
m.def("prune",
      [](const ProgramDesc &origin,
         const std::set<std::string> &feeded_var_names,
         const std::vector<std::array<size_t, 2>> &targets) {
        ProgramDesc marked_program(origin);
        for (const auto &target : targets) {
          const size_t block_idx = target[0];
          const size_t op_idx = target[1];
          marked_program.MutableBlock(block_idx)
              ->Op(op_idx)
              ->SetIsTarget(true);
        }
        proto::ProgramDesc pruned_proto;
        Prune(*marked_program.Proto(), feeded_var_names, &pruned_proto);
        return new ProgramDesc(pruned_proto);
      });
// "prune_backward": thin forwarder to PruneBackward().
m.def("prune_backward", [](const framework::ProgramDesc &program) {
  return PruneBackward(program);
});

// Predefined variable-name constants, returned as std::string copies.
m.def("empty_var_name",
      []() { return std::string(framework::kEmptyVarName); });
m.def("grad_var_suffix",
      []() { return std::string(framework::kGradVarSuffix); });

// Sub-module "var_names" exposing the same kind of constants directly.
m.def_submodule(
     "var_names",
     "The module will return special predefined variable name in Paddle")
    .def("empty", []() { return kEmptyVarName; })
    .def("temp", []() { return kTempVarName; });
// clang-format off
// "DeviceContext": three static `create` overloads, one per place type.
// The CUDA overloads throw when the build does not define PADDLE_WITH_CUDA.
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
    .def_static(
        "create",
        [](paddle::platform::CPUPlace& place)
            -> paddle::platform::DeviceContext* {
          return new paddle::platform::CPUDeviceContext();
        })
    .def_static(
        "create",
        [](paddle::platform::CUDAPlace& place)
            -> paddle::platform::DeviceContext* {
#ifndef PADDLE_WITH_CUDA
          PADDLE_THROW("CUDAPlace is not supported in CPU device.");
#else
          return new paddle::platform::CUDADeviceContext(place);
#endif
        })
    .def_static(
        "create",
        [](paddle::platform::CUDAPinnedPlace& place)
            -> paddle::platform::DeviceContext* {
#ifndef PADDLE_WITH_CUDA
          PADDLE_THROW("CUDAPinnedPlace is not supported in CPU device.");
#else
          return new paddle::platform::CUDAPinnedDeviceContext(place);
#endif
        });
// clang-format on
#if (defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
// "Communicator" is only bound for CUDA builds on non-Windows targets; it
// exposes a default constructor and nothing else.
py::class_<platform::Communicator>(m, "Communicator").def(py::init<>());
#endif
// "CUDAPlace": constructor validates the device id, then placement-news the
// object into the storage supplied by pybind11. The remaining methods expose
// the place index, equality against every place type and a string form.
py::class_<platform::CUDAPlace>(m, "CUDAPlace", R"DOC(
CUDAPlace is a descriptor of a device. It represents a GPU, and each CUDAPlace
has a dev_id to indicate the number of cards represented by the current CUDAPlace.
The memory of CUDAPlace with different dev_id is not accessible.
Examples:
.. code-block:: python
import paddle.fluid as fluid
gpu_place = fluid.CUDAPlace(0)
)DOC")
    .def("__init__",
         [](platform::CUDAPlace &self, int dev_id) {
           // NOTE(review): every failure path below logs and calls
           // std::exit(-1), which terminates the hosting process instead of
           // raising a Python exception - confirm this is still intended.
#ifdef PADDLE_WITH_CUDA
           // Negative ids are rejected outright.
           if (UNLIKELY(dev_id < 0)) {
             LOG(ERROR) << string::Sprintf(
                 "Invalid CUDAPlace(%d), device id must be 0 or "
                 "positive integer",
                 dev_id);
             std::exit(-1);
           }

           // Ids at or beyond the device count: distinguish "no GPU at all"
           // from "id out of range" in the message.
           if (UNLIKELY(dev_id >= platform::GetCUDADeviceCount())) {
             if (platform::GetCUDADeviceCount() == 0) {
               LOG(ERROR) << "Cannot use GPU because there is no GPU "
                             "detected on your "
                             "machine.";
               std::exit(-1);
             } else {
               LOG(ERROR) << string::Sprintf(
                   "Invalid CUDAPlace(%d), must inside [0, %d), because GPU "
                   "number on your machine is %d",
                   dev_id, platform::GetCUDADeviceCount(),
                   platform::GetCUDADeviceCount());
               std::exit(-1);
             }
           }

           new (&self) platform::CUDAPlace(dev_id);
#else
           // CPU-only build: constructing a CUDAPlace is always an error.
           LOG(ERROR) << string::Sprintf(
               "Cannot use GPU because you have installed CPU version "
               "PaddlePaddle.\n"
               "If you want to use GPU, please try to install GPU version "
               "PaddlePaddle by: pip install paddlepaddle-gpu\n"
               "If you only have CPU, please change CUDAPlace(%d) to be "
               "CPUPlace().\n",
               dev_id);
           std::exit(-1);
#endif
         })
    .def("_type", &PlaceIndex<platform::CUDAPlace>)
    .def("_equals", &IsSamePlace<platform::CUDAPlace, platform::Place>)
    .def("_equals", &IsSamePlace<platform::CUDAPlace, platform::CUDAPlace>)
    .def("_equals", &IsSamePlace<platform::CUDAPlace, platform::CPUPlace>)
    .def("_equals",
         &IsSamePlace<platform::CUDAPlace, platform::CUDAPinnedPlace>)
    .def("__str__", string::to_string<const platform::CUDAPlace &>);
// "CPUPlace": default-constructible; exposes the place index, equality
// against every place type and a string form.
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace", R"DOC(
CPUPlace is a descriptor of a device. It represents a CPU, and the memory
CPUPlace can be accessed by CPU.
Examples:
.. code-block:: python
import paddle.fluid as fluid
cpu_place = fluid.CPUPlace()
)DOC")
    .def(py::init<>())
    .def("_type", &PlaceIndex<platform::CPUPlace>)
    .def("_equals", &IsSamePlace<platform::CPUPlace, platform::Place>)
    .def("_equals", &IsSamePlace<platform::CPUPlace, platform::CUDAPlace>)
    .def("_equals", &IsSamePlace<platform::CPUPlace, platform::CPUPlace>)
    .def("_equals",
         &IsSamePlace<platform::CPUPlace, platform::CUDAPinnedPlace>)
    .def("__str__", string::to_string<const platform::CPUPlace &>);
// "CUDAPinnedPlace": the constructor throws when the build does not define
// PADDLE_WITH_CUDA, otherwise it placement-news a default instance.
py::class_<paddle::platform::CUDAPinnedPlace>(m, "CUDAPinnedPlace", R"DOC(
CUDAPinnedPlace is a descriptor of a device. The memory of CUDAPinnedPlace
can be accessed by GPU and CPU.
Examples:
.. code-block:: python
import paddle.fluid as fluid
place = fluid.CUDAPinnedPlace()
)DOC")
    .def("__init__",
         [](platform::CUDAPinnedPlace &self) {
#ifndef PADDLE_WITH_CUDA
           PADDLE_THROW("Cannot use CUDAPinnedPlace in CPU only version");
#endif
           new (&self) platform::CUDAPinnedPlace();
         })
    .def("_type", &PlaceIndex<platform::CUDAPinnedPlace>)
    .def("_equals", &IsSamePlace<platform::CUDAPinnedPlace, platform::Place>)
    .def("_equals",
         &IsSamePlace<platform::CUDAPinnedPlace, platform::CUDAPlace>)
    .def("_equals",
         &IsSamePlace<platform::CUDAPinnedPlace, platform::CPUPlace>)
    .def("_equals",
         &IsSamePlace<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>)
    .def("__str__", string::to_string<const platform::CUDAPinnedPlace &>);
// "Place": the generic place holder. Besides index/equality helpers it
// offers kind predicates, access to the GPU device id and `set_place`
// overloads that assign from any concrete place type.
py::class_<platform::Place>(m, "Place")
    .def(py::init<>())
    .def("_type", &PlaceIndex<platform::Place>)
    .def("_equals", &IsSamePlace<platform::Place, platform::Place>)
    .def("_equals", &IsSamePlace<platform::Place, platform::CUDAPlace>)
    .def("_equals", &IsSamePlace<platform::Place, platform::CPUPlace>)
    .def("_equals", &IsSamePlace<platform::Place, platform::CUDAPinnedPlace>)
    // Kind predicates.
    .def("is_gpu_place",
         [](platform::Place &self) { return platform::is_gpu_place(self); })
    .def("is_cpu_place",
         [](platform::Place &self) { return platform::is_cpu_place(self); })
    .def("is_cuda_pinned_place",
         [](platform::Place &self) {
           return platform::is_cuda_pinned_place(self);
         })
    // Extracts the held CUDAPlace via boost::get and returns its device.
    .def("gpu_device_id",
         [](platform::Place &self) {
           return boost::get<platform::CUDAPlace>(self).device;
         })
    // Assignment from each supported place type.
    .def("set_place",
         [](platform::Place &self, const platform::Place &other) {
           self = other;
         })
    .def("set_place",
         [](platform::Place &self, const platform::CPUPlace &cpu_place) {
           self = cpu_place;
         })
    .def("set_place",
         [](platform::Place &self, const platform::CUDAPlace &gpu_place) {
           self = gpu_place;
         })
    .def("set_place",
         [](platform::Place &self,
            const platform::CUDAPinnedPlace &cuda_pinned_place) {
           self = cuda_pinned_place;
         });
// "Operator": wraps OperatorBase. `create` builds an operator from a
// serialized OpDesc; `run` is overloaded per place type; the remaining
// methods are read-only accessors.
py::class_<OperatorBase>(m, "Operator")
    .def_static(
        "create",
        [](py::bytes protobin) {
          proto::OpDesc parsed_desc;
          // Parse first, then verify initialization separately so the
          // second check can report the initialization error string.
          PADDLE_ENFORCE_EQ(parsed_desc.ParsePartialFromString(protobin),
                            true, "Cannot parse user input to OpDesc");
          PADDLE_ENFORCE_EQ(parsed_desc.IsInitialized(), true,
                            "User OpDesc is not initialized, reason %s",
                            parsed_desc.InitializationErrorString());
          return OpRegistry::CreateOp(parsed_desc);
        })
    .def("run",
         [](OperatorBase &self, const Scope &scope,
            const platform::CPUPlace &place) { self.Run(scope, place); })
    .def("run",
         [](OperatorBase &self, const Scope &scope,
            const platform::CUDAPlace &place) { self.Run(scope, place); })
    .def("run",
         [](OperatorBase &self, const Scope &scope,
            const platform::CUDAPinnedPlace &place) {
           self.Run(scope, place);
         })
    .def("type",
         [](const OperatorBase &op) -> std::string { return op.Type(); })
    .def("outputs",
         [](const OperatorBase &op)
             -> std::map<std::string, std::vector<std::string>> {
           return op.Outputs();
         })
    .def("output_vars",
         [](const OperatorBase &op) { return op.OutputVars(true); })
    .def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
    .def("input_vars", [](const OperatorBase &op) { return op.InputVars(); })
    .def("__str__", &OperatorBase::DebugString)
    .def("no_intermediate_outputs",
         [](const OperatorBase &op) { return op.OutputVars(false); })
    .def("support_gpu", &OperatorBase::SupportGPU);
// "ExecutorPrepareContext": only the (const ProgramDesc&, size_t)
// constructor is exposed.
py::class_<framework::ExecutorPrepareContext>(m, "ExecutorPrepareContext")
    .def(py::init<const ProgramDesc &, size_t>());
// "Executor": every bound method that runs work releases the GIL, either
// through py::call_guard or through a local gil_scoped_release.
py::class_<framework::Executor>(m, "Executor")
    .def(py::init<const platform::Place &>())
    .def("close", &Executor::Close)
    .def("run_from_dataset", &Executor::RunFromDataset,
         py::call_guard<py::gil_scoped_release>())
    // NOTE(review): the C++ default values on the lambda parameters below
    // are presumably not visible to Python, since no py::arg defaults are
    // declared - confirm against the pybind11 documentation.
    .def("run_prepared_ctx",
         [](Executor &self, ExecutorPrepareContext *ctx, Scope *scope,
            std::map<std::string, const LoDTensor *> *feed_targets,
            std::map<std::string, LoDTensor *> *fetch_targets,
            bool create_local_scope = true, bool create_vars = true,
            const std::string &feed_holder_name = "feed",
            const std::string &fetch_holder_name = "fetch") {
           pybind11::gil_scoped_release gil_release;
           self.RunPreparedContext(ctx, scope, feed_targets, fetch_targets,
                                   create_local_scope, create_vars,
                                   feed_holder_name, fetch_holder_name);
         })
    .def("run_cached_prepared_ctx",
         [](Executor &self, ExecutorPrepareContext *ctx, Scope *scope,
            bool create_local_scope = true, bool create_vars = true,
            bool keep_kids = false) {
           pybind11::gil_scoped_release gil_release;
           self.RunPreparedContext(ctx, scope, create_local_scope,
                                   create_vars, keep_kids);
         })
    .def("prepare_ctx_cache", &Executor::PrepareCtxCache,
         py::call_guard<py::gil_scoped_release>())
    .def("create_variables", &Executor::CreateVariables,
         py::call_guard<py::gil_scoped_release>())
    .def("run",
         [](Executor &self, const ProgramDesc &prog, Scope *scope,
            int block_id, bool create_local_scope, bool create_vars,
            const std::vector<std::string> &fetch_vars) {
           pybind11::gil_scoped_release gil_release;
           self.Run(prog, scope, block_id, create_local_scope, create_vars,
                    fetch_vars);
         });
// Initialization entry points.
m.def("init_gflags", framework::InitGflags);
m.def("init_glog", framework::InitGLOG);
m.def("init_dgc", framework::InitDGC);
m.def("init_devices",
      [](bool init_p2p) { framework::InitDevices(init_p2p); });

// Build-configuration queries.
m.def("is_compiled_with_ngraph", IsCompiledWithNGRAPH);
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN);
m.def("is_compiled_with_brpc", IsCompiledWithBrpc);
m.def("is_compiled_with_dist", IsCompiledWithDIST);
#ifdef PADDLE_WITH_CUDA
// "is_float16_supported": true when the compute capability reported for the
// place's device is at least 53.
m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
  // Only GPUs with Compute Capability >= 53 support float16
  return platform::GetCUDAComputeCapability(place.device) >= 53;
});
#endif
// Feed/fetch helpers and the program-version check, bound directly.
m.def("set_feed_variable", framework::SetFeedVariable);
m.def("get_fetch_variable", framework::GetFetchVariable);
m.def("get_variable_tensor", framework::GetVariableTensor);
m.def("_is_program_version_supported", IsProgramVersionSupported);

// Register the description classes and constants on this module.
BindProgramDesc(&m);
BindBlockDesc(&m);
BindVarDsec(&m);
BindOpDesc(&m);
BindConstValue(&m);
// "LodRankTable": `items` copies the table into a list of
// (index, length) pairs.
py::class_<framework::LoDRankTable>(m, "LodRankTable")
    .def("items", [](framework::LoDRankTable &table) {
      std::vector<std::pair<size_t, size_t>> index_length_pairs;
      for (auto &entry : table.items()) {
        index_length_pairs.push_back({entry.index, entry.length});
      }
      return index_length_pairs;
    });
// "LoDTensorArray": list-like wrapper. Element writes and `append` share the
// source tensor's data (ShareDataWith) and copy its LoD.
// Fix: the `append` docstring said "LoDensor"; it now reads "LoDTensor".
py::class_<LoDTensorArray>(m, "LoDTensorArray", R"DOC(
Array of LoDTensor.
Examples:
.. code-block:: python
import paddle.fluid as fluid
arr = fluid.LoDTensorArray()
)DOC")
    .def("__init__",
         [](LoDTensorArray &instance) { new (&instance) LoDTensorArray(); })
    // at() is used here, so the index is bounds-checked by the container.
    .def("__getitem__",
         [](LoDTensorArray &self, size_t i) { return &self.at(i); },
         py::return_value_policy::reference)
    .def("__len__", [](LoDTensorArray &self) { return self.size(); })
    .def("__setitem__",
         [](LoDTensorArray &self, size_t i, const LoDTensor &t) {
           // Index must refer to an existing slot; this never grows the array.
           PADDLE_ENFORCE_LT(i, self.size());
           self[i].ShareDataWith(t);
           self[i].set_lod(t.lod());
         })
    .def("append",
         [](LoDTensorArray &self, const LoDTensor &t) {
           self.emplace_back();
           self.back().ShareDataWith(t);
           self.back().set_lod(t.lod());
         },
         py::arg("tensor"), R"DOC(
Append a LoDTensor to LoDTensorArray.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
arr = fluid.LoDTensorArray()
t = fluid.LoDTensor()
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
arr.append(t)
)DOC")
    // Moves every element into a Python list and leaves the array empty.
    .def("_move_to_list",
         [](LoDTensorArray &self) -> py::list {
           py::list res(self.size());
           for (size_t i = 0; i < self.size(); ++i) {
             res[i] = py::cast(std::move(self[i]));
           }
           self.clear();
           return res;
         },
         py::return_value_policy::take_ownership);
m.def("op_support_gpu", OpSupportGPU);
#ifdef PADDLE_WITH_CUDA
// CUDA-only helpers.
m.def("get_cuda_device_count", platform::GetCUDADeviceCount);

#ifndef _WIN32
// The nvprof hooks are not bound on Windows.
m.def("nvprof_init", platform::CudaProfilerInit);
m.def("nvprof_start", platform::CudaProfilerStart);
m.def("nvprof_stop", platform::CudaProfilerStop);
#endif
#endif
// Profiler enums; both are arithmetic and export their values into the
// enclosing module scope.
py::enum_<platform::ProfilerState>(m, "ProfilerState", py::arithmetic())
    .value("kDisabled", platform::ProfilerState::kDisabled)
    .value("kCPU", platform::ProfilerState::kCPU)
    .value("kCUDA", platform::ProfilerState::kCUDA)
    .value("kAll", platform::ProfilerState::kAll)
    .export_values();

py::enum_<platform::EventSortingKey>(m, "EventSortingKey", py::arithmetic())
    .value("kDefault", platform::EventSortingKey::kDefault)
    .value("kCalls", platform::EventSortingKey::kCalls)
    .value("kTotal", platform::EventSortingKey::kTotal)
    .value("kMin", platform::EventSortingKey::kMin)
    .value("kMax", platform::EventSortingKey::kMax)
    .value("kAve", platform::EventSortingKey::kAve)
    .export_values();
// Profiler control functions, bound directly.
m.def("enable_profiler", platform::EnableProfiler);
m.def("disable_profiler", platform::DisableProfiler);
m.def("is_profiler_enabled", platform::IsProfileEnabled);
m.def("reset_profiler", platform::ResetProfiler);
// "get_pass": fetches the pass registered as `pass_type` and hands it to
// Python wrapped in a shared_ptr.
m.def("get_pass", [](const std::string &pass_type) {
  auto fetched_pass = framework::ir::PassRegistry::Instance().Get(pass_type);
  return std::shared_ptr<framework::ir::Pass>(std::move(fetched_pass));
});
m.def("size_of_dtype", framework::SizeOfType);

// Variable name -> (bool, LoDTensor) map, used as an attribute type by the
// Pass bindings.
using VarQuantScale =
    std::unordered_map<std::string, std::pair<bool, LoDTensor>>;
// "Pass": held by shared_ptr. The `set` overloads heap-allocate a copy of
// the supplied value and hand the pointer to Pass::Set.
// Fix: the container overloads take their argument by value, so the value is
// now moved into the heap object instead of being copied a second time.
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init())
    .def("has", &ir::Pass::Has)
    // Stores a pointer to `attr` via SetNotOwned; nothing is copied here.
    .def("set_not_owned",
         [](ir::Pass &self, const std::string &attr_name, ProgramDesc &attr) {
           self.SetNotOwned<ProgramDesc>(attr_name, &attr);
         })
    .def("set",
         [](ir::Pass &self, const std::string &name,
            const std::string &attr) {
           self.Set<std::string>(name, new std::string(attr));
         })
    .def("set",
         [](ir::Pass &self, const std::string &name, int val) {
           self.Set<const int>(name, new int(val));
         })
    .def("set",
         [](ir::Pass &self, const std::string &name,
            std::unordered_set<std::string> set) {
           self.Set(name,
                    new std::unordered_set<std::string>(std::move(set)));
         })
    .def("set",
         [](ir::Pass &self, const std::string &name,
            std::unordered_set<int> set) {
           self.Set(name, new std::unordered_set<int>(std::move(set)));
         })
    .def("set",
         [](ir::Pass &self, const std::string &name, VarQuantScale scales) {
           self.Set(name, new VarQuantScale(std::move(scales)));
         })
    .def("type", &ir::Pass::Type)
    .def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {
      self.Apply(graph.get());
    });
// "PassBuilder": held by shared_ptr; each method forwards to the
// PassBuilder method of the corresponding name.
py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb(
    m, "PassBuilder");
pb.def(py::init())
    .def("append_pass",
         [](ir::PassBuilder &self,
            const std::string &pass_type) -> std::shared_ptr<ir::Pass> {
           return self.AppendPass(pass_type);
         })
    .def("all_passes",
         [](ir::PassBuilder &self) { return self.AllPasses(); })
    .def("insert_pass",
         [](ir::PassBuilder &self, size_t idx,
            const std::string &pass_type) {
           return self.InsertPass(idx, pass_type);
         })
    .def("remove_pass",
         [](ir::PassBuilder &self, size_t idx) { self.RemovePass(idx); });
// -- python binds for parallel executor.
py::class_<ParallelExecutor> pe(m, "ParallelExecutor");
// ExecutionStrategy is registered in the scope of ParallelExecutor (`pe`).
// Fix: docstring said "more preciously control"; it now reads "precisely".
py::class_<ExecutionStrategy> exec_strategy(pe, "ExecutionStrategy", R"DOC(
ExecutionStrategy allows the user to more precisely control how to run
the program in ParallelExecutor by setting the property.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_loss)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 4
train_exe = fluid.ParallelExecutor(use_cuda=False,
loss_name=avg_loss.name,
exec_strategy=exec_strategy)
)DOC");
// Properties of ExecutionStrategy; each getter/setter pair reads or writes
// the member of the matching name.
// Fix (num_threads docstring): "of:math:" lacked the space the RST role
// needs, and "maybe difference" is corrected to "may differ".
exec_strategy.def(py::init())
    .def_property(
        "num_threads",
        [](const ExecutionStrategy &self) { return self.num_threads_; },
        [](ExecutionStrategy &self, size_t num_threads) {
          self.num_threads_ = num_threads;
        },
        R"DOC(The type is INT, num_threads represents the size of thread pool that
used to run the operators of the current program in ParallelExecutor.
If :math:`num\_threads=1`, all the operators will execute one by one,
but the order may differ between iterations.
If it is not set, it will be set in ParallelExecutor according to the
device type and device count, for GPU, :math:`num\_threads=device\_count*4`, for CPU,
:math:`num\_threads=CPU\_NUM*4`, the explanation of :math:`CPU\_NUM` is in ParallelExecutor.
if it is not set, ParallelExecutor will get the cpu count by calling
`multiprocessing.cpu_count()`. Default 0.)DOC")
    .def_property(
        "use_cuda",
        [](const ExecutionStrategy &self) { return self.use_cuda_; },
        [](ExecutionStrategy &self, bool use_cuda) {
          self.use_cuda_ = use_cuda;
        })
    // FIXME(chengduo): Doesn't add doc for 'use_cuda', use_cuda may
    // make user confuse, because ParallelExecutor has a parameter named
    // 'use_cuda' too, in current implementation, ParallelExecutor's
    // 'use_cuda' will rewrite ExecutionStrategy's 'use_cuda'.
    .def_property(
        "allow_op_delay",
        [](const ExecutionStrategy &self) { return self.allow_op_delay_; },
        [](ExecutionStrategy &self, bool allow_op_delay) {
          self.allow_op_delay_ = allow_op_delay;
        },
        R"DOC(The type is BOOL, allow_op_delay represents whether to delay the
communication operators to run, it may make the execution faster.
Note that this option is invalid now, and it will be removed in
next version. Default False.)DOC")
    .def_property(
        "num_iteration_per_drop_scope",
        [](const ExecutionStrategy &self) {
          return self.num_iteration_per_drop_scope_;
        },
        [](ExecutionStrategy &self, size_t num_iteration_per_drop_scope) {
          self.num_iteration_per_drop_scope_ = num_iteration_per_drop_scope;
        },
        R"DOC(The type is INT, num_iteration_per_drop_scope indicates how
many iterations to clean up the temp variables which
is generated during execution. It may make the execution faster,
because the temp variable's shape maybe the same between two iterations.
Default 1.
NOTES:
1. If you fetch data when calling the 'run', the ParallelExecutor
will clean up the temp variables at the end of the current iteration.
2. In some NLP model, it may cause the GPU memory is insufficient,
in this case, you should reduce `num_iteration_per_drop_scope`.
)DOC")
    .def_property(
        "num_iteration_per_run",
        [](const ExecutionStrategy &self) {
          return self.num_iteration_per_run_;
        },
        [](ExecutionStrategy &self, size_t num_iteration_per_run) {
          self.num_iteration_per_run_ = num_iteration_per_run;
        },
        R"DOC(This config that how many iteration the executor will run when
user call pe.run() in python
)DOC")
    .def_property(
        "_dry_run",
        [](const ExecutionStrategy &self) { return self.dry_run_; },
        [](ExecutionStrategy &self, bool dry_run) {
          self.dry_run_ = dry_run;
        });
// "use_experimental_executor": boolean view over `type_`. Reads true when
// the type is kExperimental; writing picks kExperimental or kDefault.
exec_strategy.def_property(
    "use_experimental_executor",
    [](const ExecutionStrategy &self) {
      return self.type_ == ExecutionStrategy::kExperimental;
    },
    [](ExecutionStrategy &self, bool experimental) {
      self.type_ = experimental ? ExecutionStrategy::kExperimental
                                : ExecutionStrategy::kDefault;
    });
// BuildStrategy is registered in the scope of ParallelExecutor (`pe`).
// Fix: docstring said "more preciously control"; it now reads "precisely".
py::class_<BuildStrategy> build_strategy(pe, "BuildStrategy", R"DOC(
BuildStrategy allows the user to more precisely control how to
build the SSA Graph in ParallelExecutor by setting the property.
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
)DOC");
// Enums nested inside BuildStrategy on the Python side.
py::enum_<BuildStrategy::ReduceStrategy>(build_strategy, "ReduceStrategy")
    .value("Reduce", BuildStrategy::ReduceStrategy::kReduce)
    .value("AllReduce", BuildStrategy::ReduceStrategy::kAllReduce);

py::enum_<BuildStrategy::GradientScaleStrategy>(build_strategy,
                                                "GradientScaleStrategy")
    .value("CoeffNumDevice",
           BuildStrategy::GradientScaleStrategy::kCoeffNumDevice)
    .value("One", BuildStrategy::GradientScaleStrategy::kOne)
    .value("Customized", BuildStrategy::GradientScaleStrategy::kCustomized);
build_strategy
.
def
(
py
::
init
())
// "reduce_strategy": the setter is rejected once the strategy is finalized.
// Fix: the error message was misspelled "finlaized".
.def_property(
    "reduce_strategy",
    [](const BuildStrategy &self) { return self.reduce_; },
    [](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) {
      PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
                        "BuildStrategy is finalized.");
      self.reduce_ = strategy;
    },
    R"DOC(The type is fluid.BuildStrategy.ReduceStrategy, there are two reduce
strategies in ParallelExecutor, AllReduce and Reduce. If you want
that all the parameters' optimization are done on all devices independently,
you should choose AllReduce; if you choose Reduce, all the parameters'
optimization will be evenly distributed to different devices, and then
broadcast the optimized parameter to other devices.
Default 'AllReduce'.
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
)DOC")
.
def_property
(
"gradient_scale_strategy"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
gradient_scale_
;
},
[](
BuildStrategy
&
self
,
BuildStrategy
::
GradientScaleStrategy
strategy
)
{
PADDLE_ENFORCE_EQ
(
!
self
.
IsFinalized
(),
true
,
"BuildStrategy is finalized."
);
self
.
gradient_scale_
=
strategy
;
},
R"DOC(The type is fluid.BuildStrategy.GradientScaleStrategy, there are three
ways of defining :math:`loss@grad` in ParallelExecutor, CoeffNumDevice,
One and Customized. By default, ParallelExecutor sets the :math:`loss@grad`
according to the number of devices. If you want to customize :math:`loss@grad`,
you can choose Customized. Default 'CoeffNumDevice'.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid.compiler as compiler
import numpy
import os
use_cuda = True
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
# NOTE: If you use CPU to run the program, you need
# to specify the CPU_NUM, otherwise, fluid will use
# all the number of the logic core as the CPU_NUM,
# in that case, the batch size of the input should be
# greater than CPU_NUM, if not, the process will be
# failed by an exception.
if not use_cuda:
os.environ['CPU_NUM'] = str(2)
places = fluid.cpu_places()
else:
places = places = fluid.cuda_places()
data = fluid.layers.data(name='X', shape=[1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)
fluid.default_startup_program().random_seed=1
exe.run(fluid.default_startup_program())
build_strategy = fluid.BuildStrategy()
build_strategy.gradient_scale_strategy = \
fluid.BuildStrategy.GradientScaleStrategy.Customized
compiled_prog = compiler.CompiledProgram(
fluid.default_main_program()).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy,
places = places)
dev_count = len(places)
x = numpy.random.random(size=(10, 1)).astype('float32')
loss_grad = numpy.ones((dev_count)).astype("float32") * 0.01
loss_grad_name = loss.name+"@GRAD"
loss_data = exe.run(compiled_prog,
feed={"X": x, loss_grad_name : loss_grad},
fetch_list=[loss.name, loss_grad_name])
)DOC"
)
// "debug_graphviz_path": the setter is rejected once the strategy is
// finalized.
// Fix: the error message was misspelled "finlaized".
.def_property(
    "debug_graphviz_path",
    [](const BuildStrategy &self) { return self.debug_graphviz_path_; },
    [](BuildStrategy &self, const std::string &path) {
      PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
                        "BuildStrategy is finalized.");
      self.debug_graphviz_path_ = path;
    },
    R"DOC(The type is STR, debug_graphviz_path indicates the path that
writing the SSA Graph to file in the form of graphviz.
It is useful for debugging. Default ""
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.debug_graphviz_path = "./graph"
)DOC")
// "enable_sequential_execution": the setter is rejected once the strategy is
// finalized.
// Fix: the error message was misspelled "finlaized".
.def_property(
    "enable_sequential_execution",
    [](const BuildStrategy &self) {
      return self.enable_sequential_execution_;
    },
    [](BuildStrategy &self, bool b) {
      PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
                        "BuildStrategy is finalized.");
      self.enable_sequential_execution_ = b;
    },
    R"DOC(The type is BOOL. If set True, the execution order of ops would
be the same as what is in the program. Default False.
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.enable_sequential_execution = True
)DOC")
// "remove_unnecessary_lock": the setter is rejected once the strategy is
// finalized.
// Fix: the error message was misspelled "finlaized".
.def_property(
    "remove_unnecessary_lock",
    [](const BuildStrategy &self) { return self.remove_unnecessary_lock_; },
    [](BuildStrategy &self, bool b) {
      PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
                        "BuildStrategy is finalized.");
      self.remove_unnecessary_lock_ = b;
    },
    R"DOC(The type is BOOL. If set True, some locks in GPU ops would be
released and ParallelExecutor would run faster. Default True.
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.remove_unnecessary_lock = True
)DOC")
// "num_trainers": on Windows builds the setter always throws.
// Fix: the guard tested only `WIN32`, while the other platform checks in
// this file use `_WIN32`. Both spellings are accepted now, so the guard no
// longer depends on `WIN32` being defined by the build.
.def_property(
    "num_trainers",
    [](const BuildStrategy &self) { return self.num_trainers_; },
    [](BuildStrategy &self, int num_trainers) {
#if defined(_WIN32) || defined(WIN32)
      PADDLE_THROW("Windows has NO support to distribute mode.");
#endif
      self.num_trainers_ = num_trainers;
    })
.
def_property
(
"trainers_endpoints"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
trainers_endpoints_
;
},
[](
BuildStrategy
&
self
,
const
std
::
vector
<
std
::
string
>
&
trainers_endpoints
)
{
self
.
trainers_endpoints_
=
trainers_endpoints
;
})
.
def_property
(
"trainer_id"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
trainer_id_
;
},
[](
BuildStrategy
&
self
,
int
trainer_id
)
{
self
.
trainer_id_
=
trainer_id
;
})
.
def_property
(
"nccl_comm_num"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
nccl_comm_num_
;
},
[](
BuildStrategy
&
self
,
int
nccl_comm_num
)
{
self
.
nccl_comm_num_
=
nccl_comm_num
;
})
.
def_property
(
"use_hierarchical_allreduce"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
use_hierarchical_allreduce_
;
},
[](
BuildStrategy
&
self
,
bool
use
)
{
self
.
use_hierarchical_allreduce_
=
use
;
})
.
def_property
(
"hierarchical_allreduce_inter_nranks"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
hierarchical_allreduce_inter_nranks_
;
},
[](
BuildStrategy
&
self
,
int
nranks
)
{
self
.
hierarchical_allreduce_inter_nranks_
=
nranks
;
})
// "fuse_elewise_add_act_ops": the setter is rejected once the strategy is
// finalized.
// Fix: the error message was misspelled "finlaized".
.def_property(
    "fuse_elewise_add_act_ops",
    [](const BuildStrategy &self) { return self.fuse_elewise_add_act_ops_; },
    [](BuildStrategy &self, bool b) {
      PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
                        "BuildStrategy is finalized.");
      self.fuse_elewise_add_act_ops_ = b;
    },
    R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
to fuse elementwise_add_op and activation_op,
it may make the execution faster. Default False
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_elewise_add_act_ops = True
)DOC")
// "fuse_relu_depthwise_conv": the setter is rejected once the strategy is
// finalized.
// Fix: the error message was misspelled "finlaized".
.def_property(
    "fuse_relu_depthwise_conv",
    [](const BuildStrategy &self) { return self.fuse_relu_depthwise_conv_; },
    [](BuildStrategy &self, bool b) {
      PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
                        "BuildStrategy is finalized.");
      self.fuse_relu_depthwise_conv_ = b;
    },
    R"DOC(The type is BOOL, fuse_relu_depthwise_conv indicate whether
to fuse relu and depthwise_conv2d,
it will save GPU memory and may make the execution faster.
This options is only available in GPU devices.
Default False.
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_relu_depthwise_conv = True
)DOC")
.
def_property
(
"fuse_broadcast_ops"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
fuse_broadcast_ops_
==
true
||
self
.
fuse_broadcast_ops_
==
boost
::
none
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
PADDLE_ENFORCE_EQ
(
!
self
.
IsFinalized
(),
true
,
"BuildStrategy is finlaized."
);
self
.
fuse_broadcast_ops_
=
b
;
},
R"DOC(The type is BOOL, fuse_broadcast_op indicates whether
to fuse the broadcast ops. Note that, in Reduce mode,
fusing broadcast ops may make the program faster. Because
fusing broadcast OP equals delaying the execution of all
broadcast Ops, in this case, all nccl streams are used only
for NCCLReduce operations for a period of time. Default False.)DOC"
)
.
def_property
(
"fuse_all_optimizer_ops"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
fuse_all_optimizer_ops_
==
true
||
self
.
fuse_all_optimizer_ops_
==
boost
::
none
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
PADDLE_ENFORCE_EQ
(
!
self
.
IsFinalized
(),
true
,
"BuildStrategy is finlaized."
);
self
.
fuse_all_optimizer_ops_
=
b
;
})
.
def_property
(
"sync_batch_norm"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
sync_batch_norm_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
PADDLE_ENFORCE_EQ
(
!
self
.
IsFinalized
(),
true
,
"BuildStrategy is finlaized."
);
self
.
sync_batch_norm_
=
b
;
},
R"DOC(The type is BOOL, sync_batch_norm indicates whether to use
synchronous batch normalization which synchronizes the mean
and variance through multi-devices in training phase.
Current implementation doesn't support FP16 training and CPU.
And only synchronous on one machine, not all machines.
Default False
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.sync_batch_norm = True
)DOC"
)
.
def_property
(
"memory_optimize"
,
[](
const
BuildStrategy
&
self
)
->
py
::
object
{
if
(
self
.
memory_optimize_
)
{
return
py
::
cast
(
self
.
memory_optimize_
.
get
());
}
else
{
return
py
::
cast
(
nullptr
);
}
},
[](
BuildStrategy
&
self
,
const
py
::
handle
&
value
)
{
auto
*
py_obj
=
value
.
ptr
();
if
(
py_obj
==
nullptr
||
py_obj
==
Py_None
)
{
self
.
memory_optimize_
=
boost
::
none
;
}
else
if
(
PyBool_Check
(
py_obj
))
{
self
.
memory_optimize_
=
(
py_obj
==
Py_True
);
}
else
{
PADDLE_THROW
(
"BuildStrategy.memory_optimize must be None, False or True"
);
}
},
R"DOC(The type is BOOL or None, memory opitimize aims to save total memory
consumption, set to True to enable it.
Default None. None means framework would choose to use or not use
this strategy automatically. Currently, None means that it is
enabled when GC is disabled, and disabled when GC is enabled.
True means enabling and False means disabling. Default None.)DOC"
)
.
def_property
(
"is_distribution"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
is_distribution_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
#ifdef WIN32
if
(
b
)
{
PADDLE_THROW
(
"Windows has NO support to distribute mode."
);
}
#else
self
.
is_distribution_
=
b
;
#endif
})
.
def_property
(
"async_mode"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
async_mode_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
async_mode_
=
b
;
})
.
def_property
(
"enable_inplace"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
enable_inplace_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
enable_inplace_
=
b
;
})
.
def_property
(
"fuse_all_reduce_ops"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
fuse_all_reduce_ops_
==
true
||
self
.
fuse_all_reduce_ops_
==
boost
::
none
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
fuse_all_reduce_ops_
=
b
;
})
.
def_property
(
"enable_backward_optimizer_op_deps"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
enable_backward_optimizer_op_deps_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
enable_backward_optimizer_op_deps_
=
b
;
})
.
def_property
(
"cache_runtime_context"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
cache_runtime_context_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
cache_runtime_context_
=
b
;
})
.
def_property
(
"mkldnn_enabled_op_types"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
mkldnn_enabled_op_types_
;
},
[](
BuildStrategy
&
self
,
const
std
::
unordered_set
<
std
::
string
>
&
mkldnn_enabled_op_types
)
{
self
.
mkldnn_enabled_op_types_
=
mkldnn_enabled_op_types
;
})
.
def
(
"_finalize_strategy_and_create_passes"
,
[](
BuildStrategy
&
self
)
->
std
::
shared_ptr
<
ir
::
PassBuilder
>
{
return
self
.
CreatePassesFromStrategy
(
true
);
},
R"DOC(Allow user to customized passes. Normally model-specific
optimization passes should be defined in this way. BuildStrategy
cannot be updated after being finalized.)DOC"
);
pe
.
def
(
py
::
init
<
const
std
::
vector
<
platform
::
Place
>
&
,
const
std
::
vector
<
std
::
string
>
&
,
const
std
::
string
&
,
Scope
*
,
std
::
vector
<
Scope
*>
&
,
const
ExecutionStrategy
&
,
const
BuildStrategy
&
,
ir
::
Graph
*>
())
// NOTE: even we return a vec<Scope*>* to Python use reference policy.
// We still cannot get local_scope from this vector, since the element
// of vec<Scope*> will be freed by Python GC. We can only return Scope*
// one by one and mark them as reference.
.
def
(
"local_scopes"
,
[](
ParallelExecutor
&
self
)
->
std
::
vector
<
Scope
*>
*
{
return
&
self
.
GetLocalScopes
();
},
py
::
return_value_policy
::
reference
)
.
def
(
"drop_local_exe_scopes"
,
&
ParallelExecutor
::
DropLocalExeScopes
)
.
def
(
"_need_create_local_exe_scopes"
,
&
ParallelExecutor
::
NeedCreateLocalExeScope
)
.
def
(
"feed_tensors_into_local_scopes"
,
&
ParallelExecutor
::
FeedTensorsIntoLocalScopes
)
.
def
(
"feed_and_split_tensor_into_local_scopes"
,
&
ParallelExecutor
::
FeedAndSplitTensorIntoLocalScopes
)
.
def
(
"run"
,
[](
ParallelExecutor
&
self
,
const
std
::
vector
<
std
::
string
>
&
fetch_tensors
)
{
pybind11
::
gil_scoped_release
release
;
return
self
.
Run
(
fetch_tensors
);
});
BindFleetWrapper
(
&
m
);
BindBoxHelper
(
&
m
);
#ifndef _WIN32
BindNCCLWrapper
(
&
m
);
#endif
BindGraph
(
&
m
);
BindNode
(
&
m
);
BindInferenceApi
(
&
m
);
BindExpandApi
(
&
m
);
BindDataset
(
&
m
);
#ifdef PADDLE_WITH_DISTRIBUTE
BindCommunicator
(
&
m
);
#endif
}
}
// namespace pybind
}
// namespace paddle
paddle/fluid/feed/src/CMakeLists.txt
0 → 100755
浏览文件 @
0dc7d425
# Descend into the feed source sub-directories so their targets get built.
add_subdirectory(common)
add_subdirectory(data_reader)
paddle/fluid/feed/src/common/CMakeLists.txt
0 → 100755
浏览文件 @
0dc7d425
# dict_plugin library: compiled from dict_plugin.cc, depends on glog, boost and fs.
cc_library(dict_plugin SRCS dict_plugin.cc DEPS glog boost fs)
paddle/fluid/feed/src/common/bhopscotch_map.h
0 → 100755
浏览文件 @
0dc7d425
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_BHOPSCOTCH_MAP_H
#define TSL_BHOPSCOTCH_MAP_H
#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <map>
#include <memory>
#include <type_traits>
#include <utility>
#include "paddle/fluid/feed/src/common/hopscotch_hash.h"
namespace
tsl
{
/**
 * Similar to tsl::hopscotch_map but instead of using a list for overflowing elements it uses
 * a binary search tree. It thus needs an additional template parameter Compare. Compare should
 * be arithmetically coherent with KeyEqual.
 *
 * The binary search tree allows the map to have a worst-case scenario of O(log n) for search
 * and delete, even if the hash function maps all the elements to the same bucket.
 * For insert, the amortized worst case is O(log n), but the worst case is O(n) in case of rehash.
 *
 * This makes the map resistant to DoS attacks (but does not remove the need for a good hash
 * function, as an element in the bucket array is faster to retrieve than one in the tree).
 *
 * Every public member is a thin forwarder to the underlying hopscotch_hash instance (m_ht).
 *
 * @copydoc hopscotch_map
 */
template<class Key,
         class T,
         class Hash = std::hash<Key>,
         class KeyEqual = std::equal_to<Key>,
         class Compare = std::less<Key>,
         class Allocator = std::allocator<std::pair<const Key, T>>,
         unsigned int NeighborhoodSize = 62,
         bool StoreHash = false,
         class GrowthPolicy = tsl::hh::power_of_two_growth_policy<2>>
class bhopscotch_map {
private:
    template<typename U>
    using has_is_transparent = tsl::detail_hopscotch_hash::has_is_transparent<U>;

    /** Functor extracting the key from a stored key/value pair. */
    class KeySelect {
    public:
        using key_type = Key;

        const key_type& operator()(const std::pair<const Key, T>& key_value) const {
            return key_value.first;
        }

        const key_type& operator()(std::pair<const Key, T>& key_value) {
            return key_value.first;
        }
    };

    /** Functor extracting the mapped value from a stored key/value pair. */
    class ValueSelect {
    public:
        using value_type = T;

        const value_type& operator()(const std::pair<const Key, T>& key_value) const {
            return key_value.second;
        }

        // The stored element type is std::pair<const Key, T> (see `ht` below). The parameter
        // must use that exact type: a non-const lvalue reference to std::pair<Key, T> can never
        // bind to a stored element, which would make this mutable overload unreachable.
        value_type& operator()(std::pair<const Key, T>& key_value) {
            return key_value.second;
        }
    };

    // TODO Not optimal as we have to use std::pair<const Key, T> as ValueType, which forbids
    // us from moving the key in the bucket array; we have to copy it. Optimize.
    using overflow_container_type = std::map<Key, T, Compare, Allocator>;
    using ht = detail_hopscotch_hash::hopscotch_hash<std::pair<const Key, T>, KeySelect, ValueSelect,
                                                     Hash, KeyEqual,
                                                     Allocator, NeighborhoodSize,
                                                     StoreHash, GrowthPolicy,
                                                     overflow_container_type>;

public:
    using key_type = typename ht::key_type;
    using mapped_type = T;
    using value_type = typename ht::value_type;
    using size_type = typename ht::size_type;
    using difference_type = typename ht::difference_type;
    using hasher = typename ht::hasher;
    using key_equal = typename ht::key_equal;
    using key_compare = Compare;
    using allocator_type = typename ht::allocator_type;
    using reference = typename ht::reference;
    using const_reference = typename ht::const_reference;
    using pointer = typename ht::pointer;
    using const_pointer = typename ht::const_pointer;
    using iterator = typename ht::iterator;
    using const_iterator = typename ht::const_iterator;


    /*
     * Constructors
     */
    bhopscotch_map() : bhopscotch_map(ht::DEFAULT_INIT_BUCKETS_SIZE) {
    }

    /** Main constructor: all other constructors eventually delegate to this one. */
    explicit bhopscotch_map(size_type bucket_count,
                            const Hash& hash = Hash(),
                            const KeyEqual& equal = KeyEqual(),
                            const Allocator& alloc = Allocator(),
                            const Compare& comp = Compare()) :
                            m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR, comp)
    {
    }

    bhopscotch_map(size_type bucket_count,
                   const Allocator& alloc) : bhopscotch_map(bucket_count, Hash(), KeyEqual(), alloc)
    {
    }

    bhopscotch_map(size_type bucket_count,
                   const Hash& hash,
                   const Allocator& alloc) : bhopscotch_map(bucket_count, hash, KeyEqual(), alloc)
    {
    }

    explicit bhopscotch_map(const Allocator& alloc) : bhopscotch_map(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {
    }

    /** Range constructor: builds an empty map, then inserts [first, last). */
    template<class InputIt>
    bhopscotch_map(InputIt first, InputIt last,
                   size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
                   const Hash& hash = Hash(),
                   const KeyEqual& equal = KeyEqual(),
                   const Allocator& alloc = Allocator()) : bhopscotch_map(bucket_count, hash, equal, alloc)
    {
        insert(first, last);
    }

    template<class InputIt>
    bhopscotch_map(InputIt first, InputIt last,
                   size_type bucket_count,
                   const Allocator& alloc) : bhopscotch_map(first, last, bucket_count, Hash(), KeyEqual(), alloc)
    {
    }

    template<class InputIt>
    bhopscotch_map(InputIt first, InputIt last,
                   size_type bucket_count,
                   const Hash& hash,
                   const Allocator& alloc) : bhopscotch_map(first, last, bucket_count, hash, KeyEqual(), alloc)
    {
    }

    bhopscotch_map(std::initializer_list<value_type> init,
                   size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
                   const Hash& hash = Hash(),
                   const KeyEqual& equal = KeyEqual(),
                   const Allocator& alloc = Allocator()) :
                   bhopscotch_map(init.begin(), init.end(), bucket_count, hash, equal, alloc)
    {
    }

    bhopscotch_map(std::initializer_list<value_type> init,
                   size_type bucket_count,
                   const Allocator& alloc) :
                   bhopscotch_map(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc)
    {
    }

    bhopscotch_map(std::initializer_list<value_type> init,
                   size_type bucket_count,
                   const Hash& hash,
                   const Allocator& alloc) :
                   bhopscotch_map(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc)
    {
    }

    /** Replaces the content of the map with the elements of ilist. */
    bhopscotch_map& operator=(std::initializer_list<value_type> ilist) {
        m_ht.clear();

        // Reserve up front so the bulk insert below does not rehash repeatedly.
        m_ht.reserve(ilist.size());
        m_ht.insert(ilist.begin(), ilist.end());

        return *this;
    }

    allocator_type get_allocator() const { return m_ht.get_allocator(); }


    /*
     * Iterators
     */
    iterator begin() noexcept { return m_ht.begin(); }
    const_iterator begin() const noexcept { return m_ht.begin(); }
    const_iterator cbegin() const noexcept { return m_ht.cbegin(); }

    iterator end() noexcept { return m_ht.end(); }
    const_iterator end() const noexcept { return m_ht.end(); }
    const_iterator cend() const noexcept { return m_ht.cend(); }


    /*
     * Capacity
     */
    bool empty() const noexcept { return m_ht.empty(); }
    size_type size() const noexcept { return m_ht.size(); }
    size_type max_size() const noexcept { return m_ht.max_size(); }

    /*
     * Modifiers
     */
    void clear() noexcept { m_ht.clear(); }


    std::pair<iterator, bool> insert(const value_type& value) {
        return m_ht.insert(value);
    }

    /** Participates in overload resolution only if value_type is constructible from P&&. */
    template<class P, typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
    std::pair<iterator, bool> insert(P&& value) {
        return m_ht.insert(std::forward<P>(value));
    }

    std::pair<iterator, bool> insert(value_type&& value) {
        return m_ht.insert(std::move(value));
    }


    iterator insert(const_iterator hint, const value_type& value) {
        return m_ht.insert(hint, value);
    }

    /** Participates in overload resolution only if value_type is constructible from P&&. */
    template<class P, typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
    iterator insert(const_iterator hint, P&& value) {
        return m_ht.insert(hint, std::forward<P>(value));
    }

    iterator insert(const_iterator hint, value_type&& value) {
        return m_ht.insert(hint, std::move(value));
    }


    template<class InputIt>
    void insert(InputIt first, InputIt last) {
        m_ht.insert(first, last);
    }

    void insert(std::initializer_list<value_type> ilist) {
        m_ht.insert(ilist.begin(), ilist.end());
    }


    template<class M>
    std::pair<iterator, bool> insert_or_assign(const key_type& k, M&& obj) {
        return m_ht.insert_or_assign(k, std::forward<M>(obj));
    }

    template<class M>
    std::pair<iterator, bool> insert_or_assign(key_type&& k, M&& obj) {
        return m_ht.insert_or_assign(std::move(k), std::forward<M>(obj));
    }

    template<class M>
    iterator insert_or_assign(const_iterator hint, const key_type& k, M&& obj) {
        return m_ht.insert_or_assign(hint, k, std::forward<M>(obj));
    }

    template<class M>
    iterator insert_or_assign(const_iterator hint, key_type&& k, M&& obj) {
        return m_ht.insert_or_assign(hint, std::move(k), std::forward<M>(obj));
    }


    /**
     * Due to the way elements are stored, emplace will need to move or copy the key-value once.
     * The method is equivalent to insert(value_type(std::forward<Args>(args)...));
     *
     * Mainly here for compatibility with the std::unordered_map interface.
     */
    template<class... Args>
    std::pair<iterator, bool> emplace(Args&&... args) {
        return m_ht.emplace(std::forward<Args>(args)...);
    }

    /**
     * Due to the way elements are stored, emplace_hint will need to move or copy the key-value once.
     * The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
     *
     * Mainly here for compatibility with the std::unordered_map interface.
     */
    template<class... Args>
    iterator emplace_hint(const_iterator hint, Args&&... args) {
        return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
    }


    template<class... Args>
    std::pair<iterator, bool> try_emplace(const key_type& k, Args&&... args) {
        return m_ht.try_emplace(k, std::forward<Args>(args)...);
    }

    template<class... Args>
    std::pair<iterator, bool> try_emplace(key_type&& k, Args&&... args) {
        return m_ht.try_emplace(std::move(k), std::forward<Args>(args)...);
    }

    template<class... Args>
    iterator try_emplace(const_iterator hint, const key_type& k, Args&&... args) {
        return m_ht.try_emplace(hint, k, std::forward<Args>(args)...);
    }

    template<class... Args>
    iterator try_emplace(const_iterator hint, key_type&& k, Args&&... args) {
        return m_ht.try_emplace(hint, std::move(k), std::forward<Args>(args)...);
    }


    iterator erase(iterator pos) { return m_ht.erase(pos); }
    iterator erase(const_iterator pos) { return m_ht.erase(pos); }
    iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); }
    size_type erase(const key_type& key) { return m_ht.erase(key); }

    /**
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup to the value if you already have the hash.
     */
    size_type erase(const key_type& key, std::size_t precalculated_hash) {
        return m_ht.erase(key, precalculated_hash);
    }

    /**
     * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
     * and Compare::is_transparent exist.
     * If so, K must be hashable and comparable to Key.
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    size_type erase(const K& key) { return m_ht.erase(key); }

    /**
     * @copydoc erase(const K& key)
     *
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup to the value if you already have the hash.
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    size_type erase(const K& key, std::size_t precalculated_hash) {
        return m_ht.erase(key, precalculated_hash);
    }


    /** Exchanges the content of this map with `other` by swapping the underlying tables. */
    void swap(bhopscotch_map& other) { other.m_ht.swap(m_ht); }


    /*
     * Lookup
     */
    T& at(const Key& key) { return m_ht.at(key); }

    /**
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    T& at(const Key& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); }

    const T& at(const Key& key) const { return m_ht.at(key); }

    /**
     * @copydoc at(const Key& key, std::size_t precalculated_hash)
     */
    const T& at(const Key& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); }

    /**
     * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
     * and Compare::is_transparent exist.
     * If so, K must be hashable and comparable to Key.
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    T& at(const K& key) { return m_ht.at(key); }

    /**
     * @copydoc at(const K& key)
     *
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    T& at(const K& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); }

    /**
     * @copydoc at(const K& key)
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    const T& at(const K& key) const { return m_ht.at(key); }

    /**
     * @copydoc at(const K& key, std::size_t precalculated_hash)
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    const T& at(const K& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); }


    T& operator[](const Key& key) { return m_ht[key]; }
    T& operator[](Key&& key) { return m_ht[std::move(key)]; }


    size_type count(const Key& key) const { return m_ht.count(key); }

    /**
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    size_type count(const Key& key, std::size_t precalculated_hash) const {
        return m_ht.count(key, precalculated_hash);
    }

    /**
     * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
     * and Compare::is_transparent exist.
     * If so, K must be hashable and comparable to Key.
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    size_type count(const K& key) const { return m_ht.count(key); }

    /**
     * @copydoc count(const K& key) const
     *
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    size_type count(const K& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }


    iterator find(const Key& key) { return m_ht.find(key); }

    /**
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }

    const_iterator find(const Key& key) const { return m_ht.find(key); }

    /**
     * @copydoc find(const Key& key, std::size_t precalculated_hash)
     */
    const_iterator find(const Key& key, std::size_t precalculated_hash) const {
        return m_ht.find(key, precalculated_hash);
    }

    /**
     * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
     * and Compare::is_transparent exist.
     * If so, K must be hashable and comparable to Key.
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    iterator find(const K& key) { return m_ht.find(key); }

    /**
     * @copydoc find(const K& key)
     *
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }

    /**
     * @copydoc find(const K& key)
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    const_iterator find(const K& key) const { return m_ht.find(key); }

    /**
     * @copydoc find(const K& key, std::size_t precalculated_hash)
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    const_iterator find(const K& key, std::size_t precalculated_hash) const {
        return m_ht.find(key, precalculated_hash);
    }


    std::pair<iterator, iterator> equal_range(const Key& key) { return m_ht.equal_range(key); }

    /**
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    std::pair<iterator, iterator> equal_range(const Key& key, std::size_t precalculated_hash) {
        return m_ht.equal_range(key, precalculated_hash);
    }

    std::pair<const_iterator, const_iterator> equal_range(const Key& key) const { return m_ht.equal_range(key); }

    /**
     * @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
     */
    std::pair<const_iterator, const_iterator> equal_range(const Key& key, std::size_t precalculated_hash) const {
        return m_ht.equal_range(key, precalculated_hash);
    }

    /**
     * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
     * and Compare::is_transparent exist.
     * If so, K must be hashable and comparable to Key.
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    std::pair<iterator, iterator> equal_range(const K& key) { return m_ht.equal_range(key); }

    /**
     * @copydoc equal_range(const K& key)
     *
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    std::pair<iterator, iterator> equal_range(const K& key, std::size_t precalculated_hash) {
        return m_ht.equal_range(key, precalculated_hash);
    }

    /**
     * @copydoc equal_range(const K& key)
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    std::pair<const_iterator, const_iterator> equal_range(const K& key) const { return m_ht.equal_range(key); }

    /**
     * @copydoc equal_range(const K& key, std::size_t precalculated_hash)
     */
    template<class K, class KE = KeyEqual, class CP = Compare,
             typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
    std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t precalculated_hash) const {
        return m_ht.equal_range(key, precalculated_hash);
    }


    /*
     * Bucket interface
     */
    size_type bucket_count() const { return m_ht.bucket_count(); }
    size_type max_bucket_count() const { return m_ht.max_bucket_count(); }


    /*
     * Hash policy
     */
    float load_factor() const { return m_ht.load_factor(); }
    float max_load_factor() const { return m_ht.max_load_factor(); }
    void max_load_factor(float ml) { m_ht.max_load_factor(ml); }

    void rehash(size_type count_) { m_ht.rehash(count_); }
    void reserve(size_type count_) { m_ht.reserve(count_); }


    /*
     * Observers
     */
    hasher hash_function() const { return m_ht.hash_function(); }
    key_equal key_eq() const { return m_ht.key_eq(); }
    key_compare key_comp() const { return m_ht.key_comp(); }

    /*
     * Other
     */

    /**
     * Convert a const_iterator to an iterator.
     */
    iterator mutable_iterator(const_iterator pos) {
        return m_ht.mutable_iterator(pos);
    }

    /** Forwards to the underlying table's overflow_size(). */
    size_type overflow_size() const noexcept { return m_ht.overflow_size(); }

    /**
     * Two maps are equal when they have the same size and every key of lhs is present in rhs
     * with a mapped value that does not compare unequal (uses operator!= of T).
     */
    friend bool operator==(const bhopscotch_map& lhs, const bhopscotch_map& rhs) {
        if(lhs.size() != rhs.size()) {
            return false;
        }

        for(const auto& element_lhs : lhs) {
            const auto it_element_rhs = rhs.find(element_lhs.first);
            if(it_element_rhs == rhs.cend() || element_lhs.second != it_element_rhs->second) {
                return false;
            }
        }

        return true;
    }

    friend bool operator!=(const bhopscotch_map& lhs, const bhopscotch_map& rhs) {
        return !operator==(lhs, rhs);
    }

    friend void swap(bhopscotch_map& lhs, bhopscotch_map& rhs) {
        lhs.swap(rhs);
    }

private:
    ht m_ht;
};
/**
 * Convenience alias: a tsl::bhopscotch_map whose GrowthPolicy is fixed to
 * tsl::hh::prime_growth_policy. All other template parameters keep the same
 * meaning and defaults as on bhopscotch_map.
 */
template<class Key,
         class T,
         class Hash = std::hash<Key>,
         class KeyEqual = std::equal_to<Key>,
         class Compare = std::less<Key>,
         class Allocator = std::allocator<std::pair<const Key, T>>,
         unsigned int NeighborhoodSize = 62,
         bool StoreHash = false>
using bhopscotch_pg_map = bhopscotch_map<Key, T,
                                         Hash, KeyEqual, Compare,
                                         Allocator,
                                         NeighborhoodSize, StoreHash,
                                         tsl::hh::prime_growth_policy>;
}
// end namespace tsl
#endif
paddle/fluid/feed/src/common/bhopscotch_set.h
0 → 100755
浏览文件 @
0dc7d425
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_BHOPSCOTCH_SET_H
#define TSL_BHOPSCOTCH_SET_H
#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <memory>
#include <set>
#include <type_traits>
#include <utility>
#include "paddle/fluid/feed/src/common/hopscotch_hash.h"
namespace
tsl
{
/**
 * Counterpart of tsl::hopscotch_set in which the elements that overflow the bucket array are
 * stored in a binary search tree (a std::set) rather than in a list. This is why the class takes
 * the extra template parameter Compare, which should be arithmetically coherent with KeyEqual.
 *
 * Thanks to the tree, search and delete have a worst case of O(log n) even when the hash function
 * sends every element to the same bucket. Insert has an amortized worst case of O(log n), but a
 * single insert may cost O(n) when it triggers a rehash.
 *
 * The set is therefore resistant to DoS attacks. A good hash function is still worthwhile, as an
 * element is faster to retrieve from the bucket array than from the tree.
 *
 * @copydoc hopscotch_set
 */
template<class Key,
         class Hash = std::hash<Key>,
         class KeyEqual = std::equal_to<Key>,
         class Compare = std::less<Key>,
         class Allocator = std::allocator<Key>,
         unsigned int NeighborhoodSize = 62,
         bool StoreHash = false,
         class GrowthPolicy = tsl::hh::power_of_two_growth_policy<2>>
class bhopscotch_set {
private:
    template<typename U>
    using has_is_transparent = tsl::detail_hopscotch_hash::has_is_transparent<U>;

    // SFINAE guard shared by every heterogeneous overload below: it only names a type (void*)
    // when both KE and CP expose an `is_transparent` typedef.
    template<class KE, class CP>
    using require_transparent = typename std::enable_if<has_is_transparent<KE>::value &&
                                                        has_is_transparent<CP>::value>::type*;

    // In a set the stored value is the key itself, so key extraction is the identity.
    class KeySelect {
    public:
        using key_type = Key;

        const key_type& operator()(const Key& key) const { return key; }

        key_type& operator()(Key& key) { return key; }
    };

    using overflow_container_type = std::set<Key, Compare, Allocator>;
    using ht = tsl::detail_hopscotch_hash::hopscotch_hash<Key, KeySelect, void,
                                                          Hash, KeyEqual,
                                                          Allocator, NeighborhoodSize,
                                                          StoreHash, GrowthPolicy,
                                                          overflow_container_type>;

public:
    using key_type = typename ht::key_type;
    using value_type = typename ht::value_type;
    using size_type = typename ht::size_type;
    using difference_type = typename ht::difference_type;
    using hasher = typename ht::hasher;
    using key_equal = typename ht::key_equal;
    using key_compare = Compare;
    using allocator_type = typename ht::allocator_type;
    using reference = typename ht::reference;
    using const_reference = typename ht::const_reference;
    using pointer = typename ht::pointer;
    using const_pointer = typename ht::const_pointer;
    using iterator = typename ht::iterator;
    using const_iterator = typename ht::const_iterator;


    /*
     * Constructors
     *
     * Every constructor eventually delegates to the (bucket_count, hash, equal, alloc, comp)
     * one, which is the only place where the underlying table is built.
     */
    bhopscotch_set() : bhopscotch_set(ht::DEFAULT_INIT_BUCKETS_SIZE) {
    }

    explicit bhopscotch_set(size_type bucket_count,
                            const Hash& hash = Hash(),
                            const KeyEqual& equal = KeyEqual(),
                            const Allocator& alloc = Allocator(),
                            const Compare& comp = Compare()) :
            m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR, comp) {
    }

    bhopscotch_set(size_type bucket_count,
                   const Allocator& alloc) : bhopscotch_set(bucket_count, Hash(), KeyEqual(), alloc) {
    }

    bhopscotch_set(size_type bucket_count,
                   const Hash& hash,
                   const Allocator& alloc) : bhopscotch_set(bucket_count, hash, KeyEqual(), alloc) {
    }

    explicit bhopscotch_set(const Allocator& alloc) : bhopscotch_set(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {
    }

    template<class InputIt>
    bhopscotch_set(InputIt first, InputIt last,
                   size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
                   const Hash& hash = Hash(),
                   const KeyEqual& equal = KeyEqual(),
                   const Allocator& alloc = Allocator()) : bhopscotch_set(bucket_count, hash, equal, alloc) {
        insert(first, last);
    }

    template<class InputIt>
    bhopscotch_set(InputIt first, InputIt last,
                   size_type bucket_count,
                   const Allocator& alloc) : bhopscotch_set(first, last, bucket_count, Hash(), KeyEqual(), alloc) {
    }

    template<class InputIt>
    bhopscotch_set(InputIt first, InputIt last,
                   size_type bucket_count,
                   const Hash& hash,
                   const Allocator& alloc) : bhopscotch_set(first, last, bucket_count, hash, KeyEqual(), alloc) {
    }

    bhopscotch_set(std::initializer_list<value_type> init,
                   size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
                   const Hash& hash = Hash(),
                   const KeyEqual& equal = KeyEqual(),
                   const Allocator& alloc = Allocator()) :
            bhopscotch_set(init.begin(), init.end(), bucket_count, hash, equal, alloc) {
    }

    bhopscotch_set(std::initializer_list<value_type> init,
                   size_type bucket_count,
                   const Allocator& alloc) :
            bhopscotch_set(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc) {
    }

    bhopscotch_set(std::initializer_list<value_type> init,
                   size_type bucket_count,
                   const Hash& hash,
                   const Allocator& alloc) :
            bhopscotch_set(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc) {
    }

    // Replaces the whole content with the elements of `ilist`.
    bhopscotch_set& operator=(std::initializer_list<value_type> ilist) {
        m_ht.clear();

        // Reserve up front so inserting the list does not need intermediate growth steps.
        m_ht.reserve(ilist.size());
        m_ht.insert(ilist.begin(), ilist.end());

        return *this;
    }

    allocator_type get_allocator() const { return m_ht.get_allocator(); }


    /*
     * Iterators
     */
    iterator begin() noexcept { return m_ht.begin(); }
    const_iterator begin() const noexcept { return m_ht.begin(); }
    const_iterator cbegin() const noexcept { return m_ht.cbegin(); }

    iterator end() noexcept { return m_ht.end(); }
    const_iterator end() const noexcept { return m_ht.end(); }
    const_iterator cend() const noexcept { return m_ht.cend(); }


    /*
     * Capacity
     */
    bool empty() const noexcept { return m_ht.empty(); }
    size_type size() const noexcept { return m_ht.size(); }
    size_type max_size() const noexcept { return m_ht.max_size(); }


    /*
     * Modifiers
     */
    void clear() noexcept { m_ht.clear(); }

    std::pair<iterator, bool> insert(const value_type& value) { return m_ht.insert(value); }
    std::pair<iterator, bool> insert(value_type&& value) { return m_ht.insert(std::move(value)); }

    iterator insert(const_iterator hint, const value_type& value) { return m_ht.insert(hint, value); }
    iterator insert(const_iterator hint, value_type&& value) { return m_ht.insert(hint, std::move(value)); }

    template<class InputIt>
    void insert(InputIt first, InputIt last) { m_ht.insert(first, last); }

    void insert(std::initializer_list<value_type> ilist) { m_ht.insert(ilist.begin(), ilist.end()); }


    /**
     * Due to the way elements are stored, emplace will need to move or copy the element once.
     * The method is equivalent to insert(value_type(std::forward<Args>(args)...));
     *
     * Mainly here for compatibility with the standard unordered container interface.
     */
    template<class... Args>
    std::pair<iterator, bool> emplace(Args&&... args) { return m_ht.emplace(std::forward<Args>(args)...); }

    /**
     * Due to the way elements are stored, emplace_hint will need to move or copy the element once.
     * The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
     *
     * Mainly here for compatibility with the standard unordered container interface.
     */
    template<class... Args>
    iterator emplace_hint(const_iterator hint, Args&&... args) {
        return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
    }

    iterator erase(iterator pos) { return m_ht.erase(pos); }
    iterator erase(const_iterator pos) { return m_ht.erase(pos); }
    iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); }
    size_type erase(const key_type& key) { return m_ht.erase(key); }

    /**
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup to the value if you already have the hash.
     */
    size_type erase(const key_type& key, std::size_t precalculated_hash) {
        return m_ht.erase(key, precalculated_hash);
    }

    /**
     * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
     * and Compare::is_transparent exist.
     * If so, K must be hashable and comparable to Key.
     */
    template<class K, class KE = KeyEqual, class CP = Compare, require_transparent<KE, CP> = nullptr>
    size_type erase(const K& key) { return m_ht.erase(key); }

    /**
     * @copydoc erase(const K& key)
     *
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup to the value if you already have the hash.
     */
    template<class K, class KE = KeyEqual, class CP = Compare, require_transparent<KE, CP> = nullptr>
    size_type erase(const K& key, std::size_t precalculated_hash) {
        return m_ht.erase(key, precalculated_hash);
    }

    void swap(bhopscotch_set& other) { other.m_ht.swap(m_ht); }


    /*
     * Lookup
     */
    size_type count(const Key& key) const { return m_ht.count(key); }

    /**
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    size_type count(const Key& key, std::size_t precalculated_hash) const {
        return m_ht.count(key, precalculated_hash);
    }

    /**
     * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
     * and Compare::is_transparent exist.
     * If so, K must be hashable and comparable to Key.
     */
    template<class K, class KE = KeyEqual, class CP = Compare, require_transparent<KE, CP> = nullptr>
    size_type count(const K& key) const { return m_ht.count(key); }

    /**
     * @copydoc count(const K& key) const
     *
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    template<class K, class KE = KeyEqual, class CP = Compare, require_transparent<KE, CP> = nullptr>
    size_type count(const K& key, std::size_t precalculated_hash) const {
        return m_ht.count(key, precalculated_hash);
    }


    iterator find(const Key& key) { return m_ht.find(key); }

    /**
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }

    const_iterator find(const Key& key) const { return m_ht.find(key); }

    /**
     * @copydoc find(const Key& key, std::size_t precalculated_hash)
     */
    const_iterator find(const Key& key, std::size_t precalculated_hash) const {
        return m_ht.find(key, precalculated_hash);
    }

    /**
     * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
     * and Compare::is_transparent exist.
     * If so, K must be hashable and comparable to Key.
     */
    template<class K, class KE = KeyEqual, class CP = Compare, require_transparent<KE, CP> = nullptr>
    iterator find(const K& key) { return m_ht.find(key); }

    /**
     * @copydoc find(const K& key)
     *
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    template<class K, class KE = KeyEqual, class CP = Compare, require_transparent<KE, CP> = nullptr>
    iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }

    /**
     * @copydoc find(const K& key)
     */
    template<class K, class KE = KeyEqual, class CP = Compare, require_transparent<KE, CP> = nullptr>
    const_iterator find(const K& key) const { return m_ht.find(key); }

    /**
     * @copydoc find(const K& key)
     *
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    template<class K, class KE = KeyEqual, class CP = Compare, require_transparent<KE, CP> = nullptr>
    const_iterator find(const K& key, std::size_t precalculated_hash) const {
        return m_ht.find(key, precalculated_hash);
    }


    std::pair<iterator, iterator> equal_range(const Key& key) { return m_ht.equal_range(key); }

    /**
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    std::pair<iterator, iterator> equal_range(const Key& key, std::size_t precalculated_hash) {
        return m_ht.equal_range(key, precalculated_hash);
    }

    std::pair<const_iterator, const_iterator> equal_range(const Key& key) const { return m_ht.equal_range(key); }

    /**
     * @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
     */
    std::pair<const_iterator, const_iterator> equal_range(const Key& key, std::size_t precalculated_hash) const {
        return m_ht.equal_range(key, precalculated_hash);
    }

    /**
     * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
     * and Compare::is_transparent exist.
     * If so, K must be hashable and comparable to Key.
     */
    template<class K, class KE = KeyEqual, class CP = Compare, require_transparent<KE, CP> = nullptr>
    std::pair<iterator, iterator> equal_range(const K& key) { return m_ht.equal_range(key); }

    /**
     * @copydoc equal_range(const K& key)
     *
     * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
     * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
     */
    template<class K, class KE = KeyEqual, class CP = Compare, require_transparent<KE, CP> = nullptr>
    std::pair<iterator, iterator> equal_range(const K& key, std::size_t precalculated_hash) {
        return m_ht.equal_range(key, precalculated_hash);
    }

    /**
     * @copydoc equal_range(const K& key)
     */
    template<class K, class KE = KeyEqual, class CP = Compare, require_transparent<KE, CP> = nullptr>
    std::pair<const_iterator, const_iterator> equal_range(const K& key) const { return m_ht.equal_range(key); }

    /**
     * @copydoc equal_range(const K& key, std::size_t precalculated_hash)
     */
    template<class K, class KE = KeyEqual, class CP = Compare, require_transparent<KE, CP> = nullptr>
    std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t precalculated_hash) const {
        return m_ht.equal_range(key, precalculated_hash);
    }


    /*
     * Bucket interface
     */
    size_type bucket_count() const { return m_ht.bucket_count(); }
    size_type max_bucket_count() const { return m_ht.max_bucket_count(); }


    /*
     * Hash policy
     */
    float load_factor() const { return m_ht.load_factor(); }
    float max_load_factor() const { return m_ht.max_load_factor(); }
    void max_load_factor(float ml) { m_ht.max_load_factor(ml); }

    void rehash(size_type count_) { m_ht.rehash(count_); }
    void reserve(size_type count_) { m_ht.reserve(count_); }


    /*
     * Observers
     */
    hasher hash_function() const { return m_ht.hash_function(); }
    key_equal key_eq() const { return m_ht.key_eq(); }
    key_compare key_comp() const { return m_ht.key_comp(); }


    /*
     * Other
     */

    /**
     * Convert a const_iterator to an iterator.
     */
    iterator mutable_iterator(const_iterator pos) { return m_ht.mutable_iterator(pos); }

    size_type overflow_size() const noexcept { return m_ht.overflow_size(); }

    // Two sets are equal when they have the same size and every element of `lhs` is found in `rhs`.
    friend bool operator==(const bhopscotch_set& lhs, const bhopscotch_set& rhs) {
        if (lhs.size() != rhs.size()) {
            return false;
        }

        for (const auto& element : lhs) {
            if (rhs.find(element) == rhs.cend()) {
                return false;
            }
        }

        return true;
    }

    friend bool operator!=(const bhopscotch_set& lhs, const bhopscotch_set& rhs) {
        return !operator==(lhs, rhs);
    }

    friend void swap(bhopscotch_set& lhs, bhopscotch_set& rhs) {
        lhs.swap(rhs);
    }

private:
    ht m_ht;
};
/**
 * Convenience alias: identical to
 * `tsl::bhopscotch_set<Key, Hash, KeyEqual, Compare, Allocator, NeighborhoodSize, StoreHash, tsl::hh::prime_growth_policy>`,
 * i.e. a bhopscotch_set whose growth policy is fixed to tsl::hh::prime_growth_policy.
 */
template<class Key,
         class Hash = std::hash<Key>,
         class KeyEqual = std::equal_to<Key>,
         class Compare = std::less<Key>,
         class Allocator = std::allocator<Key>,
         unsigned int NeighborhoodSize = 62,
         bool StoreHash = false>
using bhopscotch_pg_set = bhopscotch_set<Key, Hash, KeyEqual, Compare, Allocator,
                                         NeighborhoodSize, StoreHash,
                                         tsl::hh::prime_growth_policy>;
}
// end namespace tsl
#endif
paddle/fluid/feed/src/common/dict_plugin.cc
0 → 100755
浏览文件 @
0dc7d425
#include <iostream>
#include "paddle/fluid/feed/src/common/dict_plugin.h"
#include "paddle/fluid/framework/io/fs.h"
namespace
paddle
{
namespace
framework
{
/**
 * Loads feasign keys from every file listed under `path` into the buffer that follows the
 * current version (wrapping back to slot 0), then publishes that buffer as the current version.
 *
 * Each line is parsed as a base-10 unsigned integer key. The value stored for a new key is the
 * number of keys the target buffer held right before the insertion; keys already present in the
 * buffer are left untouched (see KvEntity::Append).
 *
 * @param path       Location handed to fs_list() to enumerate the files to read.
 * @param converter  Passed through to fs_open_read() for every file.
 * @return Always 0. A file that cannot be opened trips the CHECK instead of returning an error.
 */
int FeasignCacheDict::Load(const std::string& path, const std::string& converter) {
    // Select the slot after the active one; with two slots this alternates 0 <-> 1.
    auto version = version_ + 1;
    if (version >= versioned_entity_.size()) {
        version = 0;
    }
    auto& entity = versioned_entity_[version];
    // NOTE(review): the target slot is not cleared before loading, so keys from an earlier
    // load into the same slot stay in place - confirm whether that accumulation is intended.

    uint64_t data_count = 0;
    auto file_list = fs_list(path);
    for (auto& file_path : file_list) {
        int err_no = 0;
        auto file = fs_open_read(file_path, &err_no, converter);
        CHECK(err_no == 0);

        // getline() (re)allocates `buffer` as needed; it is released once per file below.
        size_t buffer_size = 0;
        char* buffer = nullptr;
        ssize_t line_len = 0;
        while ((line_len = getline(&buffer, &buffer_size, file.get())) > 0) {
            if (line_len <= 1) {
                continue;  // nothing but (at most) the line terminator
            }

            // strtoull rather than strtoul: the key type is uint64_t, which `unsigned long`
            // cannot hold on platforms where it is 32 bits wide.
            char* parse_end = nullptr;
            const uint64_t feasign = strtoull(buffer, &parse_end, 10);
            if (parse_end == buffer) {
                continue;  // no digits on this line; do not register a bogus key 0
            }

            ++data_count;
            entity.Append(feasign, entity.Size());
        }
        if (buffer != nullptr) {
            free(buffer);
        }
    }

    version_ = version;
    std::cerr << "Load success data_count:" << data_count
              << " to version:" << version_ << std::endl;
    return 0;
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/feed/src/common/dict_plugin.h
0 → 100755
浏览文件 @
0dc7d425
#pragma once
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <glog/logging.h>
#include "paddle/fluid/feed/src/common/bhopscotch_map.h"
namespace
paddle
{
namespace
framework
{
// Abstract interface implemented by every dictionary that DictPluginManager can hold.
class DictPlugin {
public:
    DictPlugin() = default;
    virtual ~DictPlugin() = default;

    // Populates the dictionary from `path`; `converter` is forwarded to the implementation.
    // Returns an int status code (presumably 0 on success - confirm against implementations).
    virtual int Load(const std::string& path, const std::string& converter) = 0;
};
// One version of a key/value dictionary: a hash map for lookups plus the keys in the order
// in which they were first appended.
template <class K, class V>
class KvEntity {
public:
    KvEntity() {}
    ~KvEntity() {}

    // Number of distinct keys appended so far (narrowed to 32 bits, as callers expect).
    uint32_t Size() const {
        return static_cast<uint32_t>(_key_list.size());
    }

    // Registers `k` with value `v`. A key that is already present is left untouched, so the
    // first value appended for a key wins.
    void Append(const K& k, const V& v) {
        // A single emplace both detects duplicates and inserts, instead of find() + emplace().
        if (_dict_data.emplace(k, v).second) {
            _key_list.push_back(k);
        }
    }

    std::vector<K> _key_list;                 // keys in first-insertion order
    tsl::bhopscotch_pg_map<K, V> _dict_data;  // key -> value
};
template
<
class
K
,
class
V
>
class
KvDictPlugin
:
public
DictPlugin
{
public:
KvDictPlugin
()
{
versioned_entity_
.
resize
(
2
);
}
virtual
~
KvDictPlugin
()
{}
// GetValue with version, Return: value
virtual
int
GetValueWithVersion
(
uint32_t
version
,
const
K
&
key
,
V
&
v
)
{
CHECK
(
version
<
versioned_entity_
.
size
());
auto
&
entity
=
versioned_entity_
[
version
];
auto
itr
=
entity
.
_dict_data
.
find
(
key
);
if
(
itr
==
entity
.
_dict_data
.
end
())
{
return
-
1
;
// miss
}
v
=
itr
->
second
;
return
0
;
}
// GetValue without version, Return: value version
virtual
int
GetValue
(
const
K
&
key
,
V
&
v
,
uint32_t
&
version
)
{
version
=
version_
;
auto
&
entity
=
versioned_entity_
[
version
];
auto
itr
=
entity
.
_dict_data
.
find
(
key
);
if
(
itr
==
entity
.
_dict_data
.
end
())
{
return
-
1
;
// miss
}
v
=
itr
->
second
;
return
0
;
}
virtual
int
GetVersion
()
{
return
version_
;
}
protected:
uint32_t
version_
=
0
;
// double-buffer support version:0 1
std
::
vector
<
KvEntity
<
K
,
V
>>
versioned_entity_
;
};
// Dictionary mapping uint64_t feasign keys to uint32_t values; Load() is defined out of line.
class FeasignCacheDict : public KvDictPlugin<uint64_t, uint32_t> {
public:
    FeasignCacheDict() = default;
    ~FeasignCacheDict() override = default;

    int Load(const std::string& path, const std::string& converter) override;
};
class
DictPluginManager
{
public:
DictPluginManager
()
{}
virtual
~
DictPluginManager
(){}
static
DictPluginManager
&
Instance
()
{
static
DictPluginManager
manager
;
return
manager
;
}
inline
int
CreateDict
(
const
std
::
string
&
dict_name
)
{
#define PADDLE_DICT_PLUGIN_REGIST(dict) \
if (dict_name == #dict) { \
dicts_map_[dict_name].reset(new dict()); \
return 0; \
}
PADDLE_DICT_PLUGIN_REGIST
(
FeasignCacheDict
)
#undef PADDLE_DICT_PLUGIN_REGIST
return
-
1
;
}
inline
DictPlugin
*
GetDict
(
const
std
::
string
&
dict_name
)
{
if
(
dicts_map_
.
count
(
dict_name
))
{
return
dicts_map_
[
dict_name
].
get
();
}
return
nullptr
;
}
inline
int
LoadDict
(
const
std
::
string
&
dict_name
,
const
std
::
string
&
path
,
const
std
::
string
converter
)
{
auto
dict
=
GetDict
(
dict_name
);
if
(
!
dict
)
{
return
-
1
;
}
return
dict
->
Load
(
path
,
converter
);
}
private:
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
DictPlugin
>>
dicts_map_
;
};
}
// namespace framework
}
// namespace paddle
paddle/fluid/feed/src/common/hopscotch_growth_policy.h
0 → 100755
浏览文件 @
0dc7d425
/**
* MIT License
*
* Copyright (c) 2018 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_HOPSCOTCH_GROWTH_POLICY_H
#define TSL_HOPSCOTCH_GROWTH_POLICY_H
#include <algorithm>
#include <array>
#include <climits>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <limits>
#include <ratio>
#include <stdexcept>
namespace
tsl
{
namespace
hh
{
/**
 * Growth policy that multiplies the bucket count by GrowthFactor on each growth step while
 * keeping it a power of two, so a hash is mapped to a bucket with a bit mask rather than a modulo.
 *
 * GrowthFactor must be a power of two >= 2.
 */
template<std::size_t GrowthFactor>
class power_of_two_growth_policy {
public:
    /**
     * Called on the hash table creation and on rehash. `min_bucket_count_in_out` carries the
     * minimum number of buckets requested; the policy rounds it up to the next power of two
     * in place (never down).
     *
     * If 0 is given, min_bucket_count_in_out is left at 0 and bucket_for_hash always returns 0.
     *
     * Throws std::length_error when the request is larger than max_bucket_count().
     */
    explicit power_of_two_growth_policy(std::size_t& min_bucket_count_in_out) {
        if (min_bucket_count_in_out > max_bucket_count()) {
            throw std::length_error("The hash table exceeds its maxmimum size.");
        }

        if (min_bucket_count_in_out == 0) {
            m_mask = 0;
            return;
        }

        min_bucket_count_in_out = round_up_to_power_of_two(min_bucket_count_in_out);
        m_mask = min_bucket_count_in_out - 1;
    }

    /**
     * Return the bucket [0, bucket_count()) to which the hash belongs.
     * With a bucket count of 0 the mask is 0, so the result is always 0.
     */
    std::size_t bucket_for_hash(std::size_t hash) const noexcept {
        return hash & m_mask;
    }

    /**
     * Return the bucket count to use when the bucket array grows on rehash.
     * Throws std::length_error if growing would go past max_bucket_count().
     */
    std::size_t next_bucket_count() const {
        const std::size_t current = m_mask + 1;
        if (current > max_bucket_count() / GrowthFactor) {
            throw std::length_error("The hash table exceeds its maxmimum size.");
        }

        return current * GrowthFactor;
    }

    /**
     * Return the maximum number of buckets supported by the policy.
     */
    std::size_t max_bucket_count() const {
        // Largest power of two.
        return (std::numeric_limits<std::size_t>::max() / 2) + 1;
    }

    /**
     * Reset the growth policy as if it was created with a bucket count of 0.
     * After a clear, bucket_for_hash always returns 0.
     */
    void clear() noexcept {
        m_mask = 0;
    }

private:
    static std::size_t round_up_to_power_of_two(std::size_t value) {
        if (is_power_of_two(value)) {
            return value;
        }
        if (value == 0) {
            return 1;
        }

        // Smear the highest set bit of (value - 1) into every lower position, then add one.
        --value;
        const std::size_t bit_width = sizeof(std::size_t) * CHAR_BIT;
        std::size_t shift = 1;
        while (shift < bit_width) {
            value |= value >> shift;
            shift *= 2;
        }

        return value + 1;
    }

    static constexpr bool is_power_of_two(std::size_t value) {
        return value != 0 && (value & (value - 1)) == 0;
    }

private:
    static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2, "GrowthFactor must be a power of two >= 2.");

    std::size_t m_mask;
};
/**
 * Growth policy that multiplies the bucket count by GrowthFactor::num / GrowthFactor::den on
 * each growth step and maps a hash to a bucket with a modulo. Slower than a mask, but it can be
 * useful if you want a slower growth.
 */
template<class GrowthFactor = std::ratio<3, 2>>
class mod_growth_policy {
public:
    // Keeps the requested bucket count as the modulus (1 when 0 buckets are requested, so the
    // modulo stays well defined). Throws std::length_error above max_bucket_count().
    explicit mod_growth_policy(std::size_t& min_bucket_count_in_out) {
        if (min_bucket_count_in_out > max_bucket_count()) {
            throw std::length_error("The hash table exceeds its maxmimum size.");
        }

        m_mod = (min_bucket_count_in_out > 0) ? min_bucket_count_in_out : 1;
    }

    std::size_t bucket_for_hash(std::size_t hash) const noexcept {
        return hash % m_mod;
    }

    // Next bucket count: ceil(current * factor), clamped to max_bucket_count().
    // Throws std::length_error when already at the maximum or when the product is not a
    // normal floating point value.
    std::size_t next_bucket_count() const {
        if (m_mod == max_bucket_count()) {
            throw std::length_error("The hash table exceeds its maxmimum size.");
        }

        const double grown = std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
        if (!std::isnormal(grown)) {
            throw std::length_error("The hash table exceeds its maxmimum size.");
        }

        if (grown > double(max_bucket_count())) {
            return max_bucket_count();
        }
        return std::size_t(grown);
    }

    std::size_t max_bucket_count() const {
        return MAX_BUCKET_COUNT;
    }

    // Back to the state obtained with a bucket count of 0.
    void clear() noexcept {
        m_mod = 1;
    }

private:
    static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR = 1.0 * GrowthFactor::num / GrowthFactor::den;
    static const std::size_t MAX_BUCKET_COUNT =
            std::size_t(double(
                    std::numeric_limits<std::size_t>::max() / REHASH_SIZE_MULTIPLICATION_FACTOR
            ));

    static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1, "Growth factor should be >= 1.1.");

    std::size_t m_mod;
};
namespace
detail
{
// Table of the 186 moduli used by mod<IPrime>() below, in strictly increasing order
// (the first entry is 1).
// NOTE(review): the larger literals only fit a 64-bit std::size_t.
static constexpr const std::array<std::size_t, 186> PRIMES = {{
    1ull, 3ull, 5ull, 7ull, 11ull, 13ull,
    17ull, 23ull, 29ull, 37ull, 47ull, 59ull,
    73ull, 97ull, 127ull, 151ull, 197ull, 251ull,
    313ull, 397ull, 499ull, 631ull, 797ull, 1009ull,
    1259ull, 1597ull, 2011ull, 2539ull, 3203ull, 4027ull,
    5087ull, 6421ull, 8089ull, 10193ull, 12853ull, 16193ull,
    20399ull, 25717ull, 32401ull, 40823ull, 51437ull, 64811ull,
    81649ull, 102877ull, 129607ull, 163307ull, 205759ull, 259229ull,
    326617ull, 411527ull, 518509ull, 653267ull, 823117ull, 1037059ull,
    1306601ull, 1646237ull, 2074129ull, 2613229ull, 3292489ull, 4148279ull,
    5226491ull, 6584983ull, 8296553ull, 10453007ull, 13169977ull, 16593127ull,
    20906033ull, 26339969ull, 33186281ull, 41812097ull, 52679969ull, 66372617ull,
    83624237ull, 105359939ull, 132745199ull, 167248483ull, 210719881ull, 265490441ull,
    334496971ull, 421439783ull, 530980861ull, 668993977ull, 842879579ull, 1061961721ull,
    1337987929ull, 1685759167ull, 2123923447ull, 2675975881ull, 3371518343ull, 4247846927ull,
    5351951779ull, 6743036717ull, 8495693897ull, 10703903591ull, 13486073473ull, 16991387857ull,
    21407807219ull, 26972146961ull, 33982775741ull, 42815614441ull, 53944293929ull, 67965551447ull,
    85631228929ull, 107888587883ull, 135931102921ull, 171262457903ull, 215777175787ull, 271862205833ull,
    342524915839ull, 431554351609ull, 543724411781ull, 685049831731ull, 863108703229ull, 1087448823553ull,
    1370099663459ull, 1726217406467ull, 2174897647073ull, 2740199326961ull, 3452434812973ull, 4349795294267ull,
    5480398654009ull, 6904869625999ull, 8699590588571ull, 10960797308051ull, 13809739252051ull, 17399181177241ull,
    21921594616111ull, 27619478504183ull, 34798362354533ull, 43843189232363ull, 55238957008387ull, 69596724709081ull,
    87686378464759ull, 110477914016779ull, 139193449418173ull, 175372756929481ull, 220955828033581ull, 278386898836457ull,
    350745513859007ull, 441911656067171ull, 556773797672909ull, 701491027718027ull, 883823312134381ull, 1113547595345903ull,
    1402982055436147ull, 1767646624268779ull, 2227095190691797ull, 2805964110872297ull, 3535293248537579ull, 4454190381383713ull,
    5611928221744609ull, 7070586497075177ull, 8908380762767489ull, 11223856443489329ull, 14141172994150357ull, 17816761525534927ull,
    22447712886978529ull, 28282345988300791ull, 35633523051069991ull, 44895425773957261ull, 56564691976601587ull, 71267046102139967ull,
    89790851547914507ull, 113129383953203213ull, 142534092204280003ull, 179581703095829107ull, 226258767906406483ull, 285068184408560057ull,
    359163406191658253ull, 452517535812813007ull, 570136368817120201ull, 718326812383316683ull, 905035071625626043ull, 1140272737634240411ull,
    1436653624766633509ull, 1810070143251252131ull, 2280545475268481167ull, 2873307249533267101ull, 3620140286502504283ull, 4561090950536962147ull,
    5746614499066534157ull, 7240280573005008577ull, 9122181901073924329ull, 11493228998133068689ull, 14480561146010017169ull, 18446744073709551557ull
}};
// Reduces `hash` modulo PRIMES[IPrime]. The index is a template argument, so each
// instantiation divides by a value fixed at compile time.
template<unsigned int IPrime>
static constexpr std::size_t mod(std::size_t hash) {
    return hash % PRIMES[IPrime];
}
// MOD_PRIME[iprime](hash) evaluates hash % PRIMES[iprime]. Going through this table of mod<N>
// instantiations allows for a faster modulo, as the compiler can optimize each one against a
// divisor that is a constant known at compilation. Entry N must stay equal to &mod<N>.
static constexpr const std::array<std::size_t(*)(std::size_t), 186> MOD_PRIME = {{
    &mod<0>,   &mod<1>,   &mod<2>,   &mod<3>,   &mod<4>,   &mod<5>,   &mod<6>,   &mod<7>,
    &mod<8>,   &mod<9>,   &mod<10>,  &mod<11>,  &mod<12>,  &mod<13>,  &mod<14>,  &mod<15>,
    &mod<16>,  &mod<17>,  &mod<18>,  &mod<19>,  &mod<20>,  &mod<21>,  &mod<22>,  &mod<23>,
    &mod<24>,  &mod<25>,  &mod<26>,  &mod<27>,  &mod<28>,  &mod<29>,  &mod<30>,  &mod<31>,
    &mod<32>,  &mod<33>,  &mod<34>,  &mod<35>,  &mod<36>,  &mod<37>,  &mod<38>,  &mod<39>,
    &mod<40>,  &mod<41>,  &mod<42>,  &mod<43>,  &mod<44>,  &mod<45>,  &mod<46>,  &mod<47>,
    &mod<48>,  &mod<49>,  &mod<50>,  &mod<51>,  &mod<52>,  &mod<53>,  &mod<54>,  &mod<55>,
    &mod<56>,  &mod<57>,  &mod<58>,  &mod<59>,  &mod<60>,  &mod<61>,  &mod<62>,  &mod<63>,
    &mod<64>,  &mod<65>,  &mod<66>,  &mod<67>,  &mod<68>,  &mod<69>,  &mod<70>,  &mod<71>,
    &mod<72>,  &mod<73>,  &mod<74>,  &mod<75>,  &mod<76>,  &mod<77>,  &mod<78>,  &mod<79>,
    &mod<80>,  &mod<81>,  &mod<82>,  &mod<83>,  &mod<84>,  &mod<85>,  &mod<86>,  &mod<87>,
    &mod<88>,  &mod<89>,  &mod<90>,  &mod<91>,  &mod<92>,  &mod<93>,  &mod<94>,  &mod<95>,
    &mod<96>,  &mod<97>,  &mod<98>,  &mod<99>,  &mod<100>, &mod<101>, &mod<102>, &mod<103>,
    &mod<104>, &mod<105>, &mod<106>, &mod<107>, &mod<108>, &mod<109>, &mod<110>, &mod<111>,
    &mod<112>, &mod<113>, &mod<114>, &mod<115>, &mod<116>, &mod<117>, &mod<118>, &mod<119>,
    &mod<120>, &mod<121>, &mod<122>, &mod<123>, &mod<124>, &mod<125>, &mod<126>, &mod<127>,
    &mod<128>, &mod<129>, &mod<130>, &mod<131>, &mod<132>, &mod<133>, &mod<134>, &mod<135>,
    &mod<136>, &mod<137>, &mod<138>, &mod<139>, &mod<140>, &mod<141>, &mod<142>, &mod<143>,
    &mod<144>, &mod<145>, &mod<146>, &mod<147>, &mod<148>, &mod<149>, &mod<150>, &mod<151>,
    &mod<152>, &mod<153>, &mod<154>, &mod<155>, &mod<156>, &mod<157>, &mod<158>, &mod<159>,
    &mod<160>, &mod<161>, &mod<162>, &mod<163>, &mod<164>, &mod<165>, &mod<166>, &mod<167>,
    &mod<168>, &mod<169>, &mod<170>, &mod<171>, &mod<172>, &mod<173>, &mod<174>, &mod<175>,
    &mod<176>, &mod<177>, &mod<178>, &mod<179>, &mod<180>, &mod<181>, &mod<182>, &mod<183>,
    &mod<184>, &mod<185>
}};
<
19
>
,
&
mod
<
20
>
,
&
mod
<
21
>
,
&
mod
<
22
>
,
&
mod
<
23
>
,
&
mod
<
24
>
,
&
mod
<
25
>
,
&
mod
<
26
>
,
&
mod
<
27
>
,
&
mod
<
28
>
,
&
mod
<
29
>
,
&
mod
<
30
>
,
&
mod
<
31
>
,
&
mod
<
32
>
,
&
mod
<
33
>
,
&
mod
<
34
>
,
&
mod
<
35
>
,
&
mod
<
36
>
,
&
mod
<
37
>
,
&
mod
<
38
>
,
&
mod
<
39
>
,
&
mod
<
40
>
,
&
mod
<
41
>
,
&
mod
<
42
>
,
&
mod
<
43
>
,
&
mod
<
44
>
,
&
mod
<
45
>
,
&
mod
<
46
>
,
&
mod
<
47
>
,
&
mod
<
48
>
,
&
mod
<
49
>
,
&
mod
<
50
>
,
&
mod
<
51
>
,
&
mod
<
52
>
,
&
mod
<
53
>
,
&
mod
<
54
>
,
&
mod
<
55
>
,
&
mod
<
56
>
,
&
mod
<
57
>
,
&
mod
<
58
>
,
&
mod
<
59
>
,
&
mod
<
60
>
,
&
mod
<
61
>
,
&
mod
<
62
>
,
&
mod
<
63
>
,
&
mod
<
64
>
,
&
mod
<
65
>
,
&
mod
<
66
>
,
&
mod
<
67
>
,
&
mod
<
68
>
,
&
mod
<
69
>
,
&
mod
<
70
>
,
&
mod
<
71
>
,
&
mod
<
72
>
,
&
mod
<
73
>
,
&
mod
<
74
>
,
&
mod
<
75
>
,
&
mod
<
76
>
,
&
mod
<
77
>
,
&
mod
<
78
>
,
&
mod
<
79
>
,
&
mod
<
80
>
,
&
mod
<
81
>
,
&
mod
<
82
>
,
&
mod
<
83
>
,
&
mod
<
84
>
,
&
mod
<
85
>
,
&
mod
<
86
>
,
&
mod
<
87
>
,
&
mod
<
88
>
,
&
mod
<
89
>
,
&
mod
<
90
>
,
&
mod
<
91
>
,
&
mod
<
92
>
,
&
mod
<
93
>
,
&
mod
<
94
>
,
&
mod
<
95
>
,
&
mod
<
96
>
,
&
mod
<
97
>
,
&
mod
<
98
>
,
&
mod
<
99
>
,
&
mod
<
100
>
,
&
mod
<
101
>
,
&
mod
<
102
>
,
&
mod
<
103
>
,
&
mod
<
104
>
,
&
mod
<
105
>
,
&
mod
<
106
>
,
&
mod
<
107
>
,
&
mod
<
108
>
,
&
mod
<
109
>
,
&
mod
<
110
>
,
&
mod
<
111
>
,
&
mod
<
112
>
,
&
mod
<
113
>
,
&
mod
<
114
>
,
&
mod
<
115
>
,
&
mod
<
116
>
,
&
mod
<
117
>
,
&
mod
<
118
>
,
&
mod
<
119
>
,
&
mod
<
120
>
,
&
mod
<
121
>
,
&
mod
<
122
>
,
&
mod
<
123
>
,
&
mod
<
124
>
,
&
mod
<
125
>
,
&
mod
<
126
>
,
&
mod
<
127
>
,
&
mod
<
128
>
,
&
mod
<
129
>
,
&
mod
<
130
>
,
&
mod
<
131
>
,
&
mod
<
132
>
,
&
mod
<
133
>
,
&
mod
<
134
>
,
&
mod
<
135
>
,
&
mod
<
136
>
,
&
mod
<
137
>
,
&
mod
<
138
>
,
&
mod
<
139
>
,
&
mod
<
140
>
,
&
mod
<
141
>
,
&
mod
<
142
>
,
&
mod
<
143
>
,
&
mod
<
144
>
,
&
mod
<
145
>
,
&
mod
<
146
>
,
&
mod
<
147
>
,
&
mod
<
148
>
,
&
mod
<
149
>
,
&
mod
<
150
>
,
&
mod
<
151
>
,
&
mod
<
152
>
,
&
mod
<
153
>
,
&
mod
<
154
>
,
&
mod
<
155
>
,
&
mod
<
156
>
,
&
mod
<
157
>
,
&
mod
<
158
>
,
&
mod
<
159
>
,
&
mod
<
160
>
,
&
mod
<
161
>
,
&
mod
<
162
>
,
&
mod
<
163
>
,
&
mod
<
164
>
,
&
mod
<
165
>
,
&
mod
<
166
>
,
&
mod
<
167
>
,
&
mod
<
168
>
,
&
mod
<
169
>
,
&
mod
<
170
>
,
&
mod
<
171
>
,
&
mod
<
172
>
,
&
mod
<
173
>
,
&
mod
<
174
>
,
&
mod
<
175
>
,
&
mod
<
176
>
,
&
mod
<
177
>
,
&
mod
<
178
>
,
&
mod
<
179
>
,
&
mod
<
180
>
,
&
mod
<
181
>
,
&
mod
<
182
>
,
&
mod
<
183
>
,
&
mod
<
184
>
,
&
mod
<
185
>
}};
}
/**
* Grow the hash table by using prime numbers as bucket count. Slower than tsl::hh::power_of_two_growth_policy in
* general but will probably distribute the values around better in the buckets with a poor hash function.
*
* To allow the compiler to optimize the modulo operation, a lookup table is used with constant primes numbers.
*
* With a switch the code would look like:
* \code
* switch(iprime) { // iprime is the current prime of the hash table
* case 0: hash % 5ul;
* break;
* case 1: hash % 17ul;
* break;
* case 2: hash % 29ul;
* break;
* ...
* }
* \endcode
*
* Due to the constant variable in the modulo the compiler is able to optimize the operation
* by a series of multiplications, substractions and shifts.
*
* The 'hash % 5' could become something like 'hash - (hash * 0xCCCCCCCD) >> 34) * 5' in a 64 bits environement.
*/
class
prime_growth_policy
{
public:
explicit
prime_growth_policy
(
std
::
size_t
&
min_bucket_count_in_out
)
{
auto
it_prime
=
std
::
lower_bound
(
detail
::
PRIMES
.
begin
(),
detail
::
PRIMES
.
end
(),
min_bucket_count_in_out
);
if
(
it_prime
==
detail
::
PRIMES
.
end
())
{
throw
std
::
length_error
(
"The hash table exceeds its maxmimum size."
);
}
m_iprime
=
static_cast
<
unsigned
int
>
(
std
::
distance
(
detail
::
PRIMES
.
begin
(),
it_prime
));
if
(
min_bucket_count_in_out
>
0
)
{
min_bucket_count_in_out
=
*
it_prime
;
}
else
{
min_bucket_count_in_out
=
0
;
}
}
std
::
size_t
bucket_for_hash
(
std
::
size_t
hash
)
const
noexcept
{
return
detail
::
MOD_PRIME
[
m_iprime
](
hash
);
}
std
::
size_t
next_bucket_count
()
const
{
if
(
m_iprime
+
1
>=
detail
::
PRIMES
.
size
())
{
throw
std
::
length_error
(
"The hash table exceeds its maxmimum size."
);
}
return
detail
::
PRIMES
[
m_iprime
+
1
];
}
std
::
size_t
max_bucket_count
()
const
{
return
detail
::
PRIMES
.
back
();
}
void
clear
()
noexcept
{
m_iprime
=
0
;
}
private:
unsigned
int
m_iprime
;
static_assert
(
std
::
numeric_limits
<
decltype
(
m_iprime
)
>::
max
()
>=
detail
::
PRIMES
.
size
(),
"The type of m_iprime is not big enough."
);
};
}
}
#endif
paddle/fluid/feed/src/common/hopscotch_hash.h
0 → 100755
浏览文件 @
0dc7d425
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_HOPSCOTCH_HASH_H
#define TSL_HOPSCOTCH_HASH_H
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <memory>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "paddle/fluid/feed/src/common/hopscotch_growth_policy.h"
#if (defined(__GNUC__) && (__GNUC__ == 4) && (__GNUC_MINOR__ < 9))
# define TSL_HH_NO_RANGE_ERASE_WITH_CONST_ITERATOR
#endif
/*
* Only activate tsl_hh_assert if TSL_DEBUG is defined.
* This way we avoid the performance hit when NDEBUG is not defined with assert as tsl_hh_assert is used a lot
* (people usually compile with "-O3" and not "-O3 -DNDEBUG").
*/
#ifdef TSL_DEBUG
# define tsl_hh_assert(expr) assert(expr)
#else
# define tsl_hh_assert(expr) (static_cast<void>(0))
#endif
namespace
tsl
{
namespace
detail_hopscotch_hash
{
template
<
typename
T
>
struct
make_void
{
using
type
=
void
;
};
template
<
typename
T
,
typename
=
void
>
struct
has_is_transparent
:
std
::
false_type
{
};
template
<
typename
T
>
struct
has_is_transparent
<
T
,
typename
make_void
<
typename
T
::
is_transparent
>::
type
>
:
std
::
true_type
{
};
template
<
typename
T
,
typename
=
void
>
struct
has_key_compare
:
std
::
false_type
{
};
template
<
typename
T
>
struct
has_key_compare
<
T
,
typename
make_void
<
typename
T
::
key_compare
>::
type
>
:
std
::
true_type
{
};
template
<
typename
U
>
struct
is_power_of_two_policy
:
std
::
false_type
{
};
template
<
std
::
size_t
GrowthFactor
>
struct
is_power_of_two_policy
<
tsl
::
hh
::
power_of_two_growth_policy
<
GrowthFactor
>>:
std
::
true_type
{
};
/*
 * smallest_type_for_min_bits<MinBits>::type is the smallest of the uint_leastN_t types
 * (N = 8, 16, 32, 64) able to hold MinBits bits.
 *
 * The primary template has no `type` member, so a MinBits of 0 or above 64 has no result.
 */
static const std::size_t SMALLEST_TYPE_MAX_BITS_SUPPORTED = 64;

template<unsigned int MinBits, typename Enable = void>
class smallest_type_for_min_bits {
};

// 1 to 8 bits.
template<unsigned int MinBits>
class smallest_type_for_min_bits<
    MinBits, typename std::enable_if<(MinBits > 0) && (MinBits <= 8)>::type>
{
public:
    typedef std::uint_least8_t type;
};

// 9 to 16 bits.
template<unsigned int MinBits>
class smallest_type_for_min_bits<
    MinBits, typename std::enable_if<(MinBits > 8) && (MinBits <= 16)>::type>
{
public:
    typedef std::uint_least16_t type;
};

// 17 to 32 bits.
template<unsigned int MinBits>
class smallest_type_for_min_bits<
    MinBits, typename std::enable_if<(MinBits > 16) && (MinBits <= 32)>::type>
{
public:
    typedef std::uint_least32_t type;
};

// 33 to 64 bits.
template<unsigned int MinBits>
class smallest_type_for_min_bits<
    MinBits, typename std::enable_if<(MinBits > 32) && (MinBits <= 64)>::type>
{
public:
    typedef std::uint_least64_t type;
};
/*
* Each bucket may store up to three elements:
* - An aligned storage to store a value_type object with placement-new.
* - An (optional) hash of the value in the bucket.
* - An unsigned integer of type neighborhood_bitmap used to tell us which buckets in the neighborhood of the
* current bucket contain a value with a hash belonging to the current bucket.
*
* For a bucket 'bct', a bit 'i' (counting from 0 and from the least significant bit to the most significant)
* set to 1 means that the bucket 'bct + i' contains a value with a hash belonging to bucket 'bct'.
* The bits used for that, start from the third least significant bit.
* The two least significant bits are reserved:
* - The least significant bit is set to 1 if there is a value in the bucket storage.
* - The second least significant bit is set to 1 if there is an overflow. More than NeighborhoodSize values
* give the same hash, all overflow values are stored in the m_overflow_elements list of the map.
*
* Details regarding hopscotch hashing and its implementation can be found here:
* https://tessil.github.io/2016/08/29/hopscotch-hashing.html
*/
static
const
std
::
size_t
NB_RESERVED_BITS_IN_NEIGHBORHOOD
=
2
;
using truncated_hash_type = std::uint_least32_t;

/**
 * Base class of a bucket that optionally keeps a truncated copy of the hash of its value.
 *
 * This primary template is the version that stores nothing: every hash compares equal, the
 * stored hash reads as 0, and the setters do nothing.
 */
template<bool StoreHash>
class hopscotch_bucket_hash {
public:
    bool bucket_hash_equal(std::size_t /*hash*/) const noexcept {
        return true;
    }

    truncated_hash_type truncated_bucket_hash() const noexcept {
        return 0;
    }

protected:
    void copy_hash(const hopscotch_bucket_hash&) noexcept {
    }

    void set_hash(truncated_hash_type /*hash*/) noexcept {
    }
};

/**
 * Version used when StoreHash is true: a truncated_hash_type is kept in the bucket.
 */
template<>
class hopscotch_bucket_hash<true> {
public:
    /** Compare the stored hash with `hash` truncated to truncated_hash_type. */
    bool bucket_hash_equal(std::size_t hash) const noexcept {
        return truncated_hash_type(hash) == m_hash;
    }

    truncated_hash_type truncated_bucket_hash() const noexcept {
        return m_hash;
    }

protected:
    void copy_hash(const hopscotch_bucket_hash& bucket) noexcept {
        m_hash = bucket.m_hash;
    }

    void set_hash(truncated_hash_type hash) noexcept {
        m_hash = hash;
    }

private:
    truncated_hash_type m_hash;
};
/**
 * One bucket of the hopscotch table.
 *
 * It holds raw storage for a value_type (constructed in place with placement-new), the
 * optional truncated hash inherited from hopscotch_bucket_hash<StoreHash>, and
 * m_neighborhood_infos, an unsigned integer where:
 *  - bit 0 is set when the storage contains a value,
 *  - bit 1 is the overflow flag,
 *  - bit (i + NB_RESERVED_BITS_IN_NEIGHBORHOOD) is the presence flag of neighbor i.
 */
template<typename ValueType, unsigned int NeighborhoodSize, bool StoreHash>
class hopscotch_bucket : public hopscotch_bucket_hash<StoreHash> {
private:
    static const std::size_t MIN_NEIGHBORHOOD_SIZE = 4;
    static const std::size_t MAX_NEIGHBORHOOD_SIZE =
        SMALLEST_TYPE_MAX_BITS_SUPPORTED - NB_RESERVED_BITS_IN_NEIGHBORHOOD;

    // A static_assert message can't contain a variable, so every limit is written as a literal
    // and a second static_assert checks that the literal agrees with the constant.
    static_assert(NeighborhoodSize >= 4, "NeighborhoodSize should be >= 4.");
    static_assert(MIN_NEIGHBORHOOD_SIZE == 4, "");

    static_assert(NeighborhoodSize <= 62, "NeighborhoodSize should be <= 62.");
    static_assert(MAX_NEIGHBORHOOD_SIZE == 62, "");

    static_assert(!StoreHash || NeighborhoodSize <= 30,
                  "NeighborhoodSize should be <= 30 if StoreHash is true.");
    static_assert(MAX_NEIGHBORHOOD_SIZE - 32 == 30, "");

    using bucket_hash = hopscotch_bucket_hash<StoreHash>;

public:
    using value_type = ValueType;
    using neighborhood_bitmap = typename smallest_type_for_min_bits<
        NeighborhoodSize + NB_RESERVED_BITS_IN_NEIGHBORHOOD>::type;

    /** Create an empty bucket with all flags cleared. */
    hopscotch_bucket() noexcept : bucket_hash(), m_neighborhood_infos(0) {
        tsl_hh_assert(empty());
    }

    /**
     * Copy the hash, the value (if `bucket` holds one) and the flags of `bucket`.
     * The flags are only taken over once the value has been constructed.
     */
    hopscotch_bucket(const hopscotch_bucket& bucket)
        noexcept(std::is_nothrow_copy_constructible<value_type>::value)
        : bucket_hash(bucket), m_neighborhood_infos(0)
    {
        if (!bucket.empty()) {
            ::new (raw_storage()) value_type(bucket.value());
        }

        m_neighborhood_infos = bucket.m_neighborhood_infos;
    }

    /**
     * Same as the copy constructor, but the value is move-constructed from the one of `bucket`.
     * `bucket` keeps its flags and its (moved-from) value.
     */
    hopscotch_bucket(hopscotch_bucket&& bucket)
        noexcept(std::is_nothrow_move_constructible<value_type>::value)
        : bucket_hash(std::move(bucket)), m_neighborhood_infos(0)
    {
        if (!bucket.empty()) {
            ::new (raw_storage()) value_type(std::move(bucket.value()));
        }

        m_neighborhood_infos = bucket.m_neighborhood_infos;
    }

    /**
     * Destroy the current value (if any), then copy the hash, value and flags of `bucket`.
     * Self-assignment is a no-op.
     */
    hopscotch_bucket& operator=(const hopscotch_bucket& bucket)
        noexcept(std::is_nothrow_copy_constructible<value_type>::value)
    {
        if (this == &bucket) {
            return *this;
        }

        remove_value();

        bucket_hash::operator=(bucket);
        if (!bucket.empty()) {
            ::new (raw_storage()) value_type(bucket.value());
        }

        m_neighborhood_infos = bucket.m_neighborhood_infos;

        return *this;
    }

    hopscotch_bucket& operator=(hopscotch_bucket&&) = delete;

    ~hopscotch_bucket() noexcept {
        if (!empty()) {
            destroy_value();
        }
    }

    /** Neighbor presence bits only, i.e. the flags with the reserved low bits shifted out. */
    neighborhood_bitmap neighborhood_infos() const noexcept {
        return neighborhood_bitmap(m_neighborhood_infos >> NB_RESERVED_BITS_IN_NEIGHBORHOOD);
    }

    /** Set or clear the overflow flag (bit 1). */
    void set_overflow(bool has_overflow) noexcept {
        m_neighborhood_infos = has_overflow ? neighborhood_bitmap(m_neighborhood_infos | 2)
                                            : neighborhood_bitmap(m_neighborhood_infos & ~2);
    }

    bool has_overflow() const noexcept {
        return (m_neighborhood_infos & 2) != 0;
    }

    /** True when the storage holds no value (bit 0 is cleared). */
    bool empty() const noexcept {
        return (m_neighborhood_infos & 1) == 0;
    }

    /** Flip the presence bit of neighbor `ineighbor`. */
    void toggle_neighbor_presence(std::size_t ineighbor) noexcept {
        tsl_hh_assert(ineighbor <= NeighborhoodSize);
        m_neighborhood_infos = neighborhood_bitmap(
            m_neighborhood_infos ^ (1ull << (ineighbor + NB_RESERVED_BITS_IN_NEIGHBORHOOD)));
    }

    /** True when the presence bit of neighbor `ineighbor` is set. */
    bool check_neighbor_presence(std::size_t ineighbor) const noexcept {
        tsl_hh_assert(ineighbor <= NeighborhoodSize);
        return ((m_neighborhood_infos >> (ineighbor + NB_RESERVED_BITS_IN_NEIGHBORHOOD)) & 1) == 1;
    }

    /** Access the stored value. The bucket must not be empty. */
    value_type& value() noexcept {
        tsl_hh_assert(!empty());
        return *reinterpret_cast<value_type*>(std::addressof(m_value));
    }

    const value_type& value() const noexcept {
        tsl_hh_assert(!empty());
        return *reinterpret_cast<const value_type*>(std::addressof(m_value));
    }

    /**
     * Construct a value in place from `value_type_args`, mark the bucket as used and record
     * `hash`. The bucket must be empty.
     */
    template<typename... Args>
    void set_value_of_empty_bucket(truncated_hash_type hash, Args&&... value_type_args) {
        tsl_hh_assert(empty());

        ::new (raw_storage()) value_type(std::forward<Args>(value_type_args)...);
        set_empty(false);
        this->set_hash(hash);
    }

    /**
     * Move the value and hash of this bucket into `empty_bucket`, leaving this bucket empty.
     * Nothing happens when this bucket holds no value. Neighbor and overflow bits of both
     * buckets are left as they are.
     */
    void swap_value_into_empty_bucket(hopscotch_bucket& empty_bucket) {
        tsl_hh_assert(empty_bucket.empty());
        if (empty()) {
            return;
        }

        ::new (empty_bucket.raw_storage()) value_type(std::move(value()));
        empty_bucket.copy_hash(*this);
        empty_bucket.set_empty(false);

        destroy_value();
        set_empty(true);
    }

    /** Destroy the value, if any, and mark the bucket empty. Other flags are kept. */
    void remove_value() noexcept {
        if (empty()) {
            return;
        }

        destroy_value();
        set_empty(true);
    }

    /** Destroy the value, if any, and reset every flag. */
    void clear() noexcept {
        if (!empty()) {
            destroy_value();
        }

        m_neighborhood_infos = 0;
        tsl_hh_assert(empty());
    }

    static truncated_hash_type truncate_hash(std::size_t hash) noexcept {
        return truncated_hash_type(hash);
    }

private:
    // Address of the raw storage, as expected by placement-new.
    void* raw_storage() noexcept {
        return static_cast<void*>(std::addressof(m_value));
    }

    void set_empty(bool is_empty) noexcept {
        m_neighborhood_infos = is_empty ? neighborhood_bitmap(m_neighborhood_infos & ~1)
                                        : neighborhood_bitmap(m_neighborhood_infos | 1);
    }

    // Run the destructor of the stored value; the flags are not touched.
    void destroy_value() noexcept {
        tsl_hh_assert(!empty());
        value().~value_type();
    }

private:
    using storage = typename std::aligned_storage<sizeof(value_type), alignof(value_type)>::type;

    neighborhood_bitmap m_neighborhood_infos;
    storage m_value;
};
/**
* Internal common class used by (b)hopscotch_map and (b)hopscotch_set.
*
* ValueType is what will be stored by hopscotch_hash (usually std::pair<Key, T> for a map and Key for a set).
*
* KeySelect should be a FunctionObject which takes a ValueType in parameter and returns a reference to the key.
*
* ValueSelect should be a FunctionObject which takes a ValueType in parameter and returns a reference to the value.
* ValueSelect should be void if there is no value (in a set for example).
*
* OverflowContainer will be used as containers for overflown elements. Usually it should be a list<ValueType>
* or a set<Key>/map<Key, T>.
*/
template
<
class
ValueType
,
class
KeySelect
,
class
ValueSelect
,
class
Hash
,
class
KeyEqual
,
class
Allocator
,
unsigned
int
NeighborhoodSize
,
bool
StoreHash
,
class
GrowthPolicy
,
class
OverflowContainer
>
class
hopscotch_hash
:
private
Hash
,
private
KeyEqual
,
private
GrowthPolicy
{
private:
template
<
typename
U
>
using
has_mapped_type
=
typename
std
::
integral_constant
<
bool
,
!
std
::
is_same
<
U
,
void
>::
value
>
;
static_assert
(
noexcept
(
std
::
declval
<
GrowthPolicy
>
().
bucket_for_hash
(
std
::
size_t
(
0
))),
"GrowthPolicy::bucket_for_hash must be noexcept."
);
static_assert
(
noexcept
(
std
::
declval
<
GrowthPolicy
>
().
clear
()),
"GrowthPolicy::clear must be noexcept."
);
public:
template
<
bool
IsConst
>
class
hopscotch_iterator
;
using
key_type
=
typename
KeySelect
::
key_type
;
using
value_type
=
ValueType
;
using
size_type
=
std
::
size_t
;
using
difference_type
=
std
::
ptrdiff_t
;
using
hasher
=
Hash
;
using
key_equal
=
KeyEqual
;
using
allocator_type
=
Allocator
;
using
reference
=
value_type
&
;
using
const_reference
=
const
value_type
&
;
using
pointer
=
value_type
*
;
using
const_pointer
=
const
value_type
*
;
using
iterator
=
hopscotch_iterator
<
false
>
;
using
const_iterator
=
hopscotch_iterator
<
true
>
;
private:
using
hopscotch_bucket
=
tsl
::
detail_hopscotch_hash
::
hopscotch_bucket
<
ValueType
,
NeighborhoodSize
,
StoreHash
>
;
using
neighborhood_bitmap
=
typename
hopscotch_bucket
::
neighborhood_bitmap
;
using
buckets_allocator
=
typename
std
::
allocator_traits
<
allocator_type
>::
template
rebind_alloc
<
hopscotch_bucket
>;
using
buckets_container_type
=
std
::
vector
<
hopscotch_bucket
,
buckets_allocator
>
;
using
overflow_container_type
=
OverflowContainer
;
static_assert
(
std
::
is_same
<
typename
overflow_container_type
::
value_type
,
ValueType
>::
value
,
"OverflowContainer should have ValueType as type."
);
static_assert
(
std
::
is_same
<
typename
overflow_container_type
::
allocator_type
,
Allocator
>::
value
,
"Invalid allocator, not the same type as the value_type."
);
using
iterator_buckets
=
typename
buckets_container_type
::
iterator
;
using
const_iterator_buckets
=
typename
buckets_container_type
::
const_iterator
;
using
iterator_overflow
=
typename
overflow_container_type
::
iterator
;
using
const_iterator_overflow
=
typename
overflow_container_type
::
const_iterator
;
public:
/**
 * Forward iterator over the table: it walks the non-empty buckets first and the overflow
 * container afterwards.
 *
 * `operator*()` and `operator->()` give a const reference and a const pointer to the stored
 * value type. In case of a map, call `value()` to get a modifiable reference to the value
 * associated to a key (the `.second` of the stored pair).
 */
template<bool IsConst>
class hopscotch_iterator {
    friend class hopscotch_hash;

private:
    using iterator_bucket = typename std::conditional<
        IsConst,
        typename hopscotch_hash::const_iterator_buckets,
        typename hopscotch_hash::iterator_buckets>::type;
    using iterator_overflow = typename std::conditional<
        IsConst,
        typename hopscotch_hash::const_iterator_overflow,
        typename hopscotch_hash::iterator_overflow>::type;

    hopscotch_iterator(iterator_bucket buckets_iterator,
                       iterator_bucket buckets_end_iterator,
                       iterator_overflow overflow_iterator) noexcept
        : m_buckets_iterator(buckets_iterator),
          m_buckets_end_iterator(buckets_end_iterator),
          m_overflow_iterator(overflow_iterator)
    {
    }

public:
    using iterator_category = std::forward_iterator_tag;
    using value_type = const typename hopscotch_hash::value_type;
    using difference_type = std::ptrdiff_t;
    using reference = value_type&;
    using pointer = value_type*;

    hopscotch_iterator() noexcept {
    }

    // Copy constructor from iterator to const_iterator.
    template<bool TIsConst = IsConst, typename std::enable_if<TIsConst>::type* = nullptr>
    hopscotch_iterator(const hopscotch_iterator<!TIsConst>& other) noexcept
        : m_buckets_iterator(other.m_buckets_iterator),
          m_buckets_end_iterator(other.m_buckets_end_iterator),
          m_overflow_iterator(other.m_overflow_iterator)
    {
    }

    hopscotch_iterator(const hopscotch_iterator& other) = default;
    hopscotch_iterator(hopscotch_iterator&& other) = default;
    hopscotch_iterator& operator=(const hopscotch_iterator& other) = default;
    hopscotch_iterator& operator=(hopscotch_iterator&& other) = default;

    /** Key of the pointed element, read from the bucket or from the overflow container. */
    const typename hopscotch_hash::key_type& key() const {
        if (m_buckets_iterator == m_buckets_end_iterator) {
            return KeySelect()(*m_overflow_iterator);
        }

        return KeySelect()(m_buckets_iterator->value());
    }

    /**
     * Mapped value of the pointed element; only available when ValueSelect is not void.
     * The reference is const for a const iterator and modifiable otherwise.
     */
    template<class U = ValueSelect,
             typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
    typename std::conditional<IsConst,
                              const typename U::value_type&,
                              typename U::value_type&>::type value() const
    {
        if (m_buckets_iterator == m_buckets_end_iterator) {
            return U()(*m_overflow_iterator);
        }

        return U()(m_buckets_iterator->value());
    }

    reference operator*() const {
        if (m_buckets_iterator == m_buckets_end_iterator) {
            return *m_overflow_iterator;
        }

        return m_buckets_iterator->value();
    }

    pointer operator->() const {
        if (m_buckets_iterator == m_buckets_end_iterator) {
            return std::addressof(*m_overflow_iterator);
        }

        return std::addressof(m_buckets_iterator->value());
    }

    /**
     * While in the buckets, advance to the next non-empty bucket (or to the end of the
     * buckets); once the buckets are exhausted, advance inside the overflow container.
     */
    hopscotch_iterator& operator++() {
        if (m_buckets_iterator != m_buckets_end_iterator) {
            ++m_buckets_iterator;
            while (m_buckets_iterator != m_buckets_end_iterator && m_buckets_iterator->empty()) {
                ++m_buckets_iterator;
            }
        }
        else {
            ++m_overflow_iterator;
        }

        return *this;
    }

    hopscotch_iterator operator++(int) {
        hopscotch_iterator previous(*this);
        ++*this;

        return previous;
    }

    // Two iterators are equal when both their bucket and overflow positions match.
    friend bool operator==(const hopscotch_iterator& lhs, const hopscotch_iterator& rhs) {
        return lhs.m_buckets_iterator == rhs.m_buckets_iterator &&
               lhs.m_overflow_iterator == rhs.m_overflow_iterator;
    }

    friend bool operator!=(const hopscotch_iterator& lhs, const hopscotch_iterator& rhs) {
        return !(lhs == rhs);
    }

private:
    iterator_bucket m_buckets_iterator;
    iterator_bucket m_buckets_end_iterator;
    iterator_overflow m_overflow_iterator;
};
public:
/**
 * Constructor selected when the overflow container has no `key_compare` type.
 *
 * `bucket_count` is first adjusted by the growth policy. When the adjusted count is not 0,
 * bucket_count + NeighborhoodSize - 1 buckets are allocated; otherwise the table keeps
 * pointing at the static empty bucket.
 *
 * Throws std::length_error if the adjusted bucket count is above max_bucket_count().
 */
template<class OC = OverflowContainer,
         typename std::enable_if<!has_key_compare<OC>::value>::type* = nullptr>
hopscotch_hash(size_type bucket_count,
               const Hash& hash,
               const KeyEqual& equal,
               const Allocator& alloc,
               float max_load_factor)
    : Hash(hash),
      KeyEqual(equal),
      GrowthPolicy(bucket_count),
      m_buckets_data(alloc),
      m_overflow_elements(alloc),
      m_buckets(static_empty_bucket_ptr()),
      m_nb_elements(0)
{
    // Checked inside the constructor instead of at class scope to avoid compilation issues
    // when value_type is not complete.
    static_assert(std::is_nothrow_move_constructible<value_type>::value ||
                  std::is_copy_constructible<value_type>::value,
                  "value_type must be either copy constructible or nothrow move constructible.");
    static_assert(NeighborhoodSize - 1 > 0, "");

    if (bucket_count > max_bucket_count()) {
        throw std::length_error("The map exceeds its maxmimum size.");
    }

    if (bucket_count > 0) {
        // Can't directly construct with the appropriate size in the initializer
        // as m_buckets_data(bucket_count, alloc) is not supported by GCC 4.8
        m_buckets_data.resize(bucket_count + NeighborhoodSize - 1);
        m_buckets = m_buckets_data.data();
    }

    this->max_load_factor(max_load_factor);
}
/**
 * Constructor selected when the overflow container has a `key_compare` type; `comp` is
 * forwarded to the overflow container together with the allocator.
 *
 * `bucket_count` is first adjusted by the growth policy. When the adjusted count is not 0,
 * bucket_count + NeighborhoodSize - 1 buckets are allocated; otherwise the table keeps
 * pointing at the static empty bucket.
 *
 * Throws std::length_error if the adjusted bucket count is above max_bucket_count().
 */
template<class OC = OverflowContainer,
         typename std::enable_if<has_key_compare<OC>::value>::type* = nullptr>
hopscotch_hash(size_type bucket_count,
               const Hash& hash,
               const KeyEqual& equal,
               const Allocator& alloc,
               float max_load_factor,
               const typename OC::key_compare& comp)
    : Hash(hash),
      KeyEqual(equal),
      GrowthPolicy(bucket_count),
      m_buckets_data(alloc),
      m_overflow_elements(comp, alloc),
      m_buckets(static_empty_bucket_ptr()),
      m_nb_elements(0)
{
    // Checked inside the constructor instead of at class scope to avoid compilation issues
    // when value_type is not complete.
    static_assert(std::is_nothrow_move_constructible<value_type>::value ||
                  std::is_copy_constructible<value_type>::value,
                  "value_type must be either copy constructible or nothrow move constructible.");
    static_assert(NeighborhoodSize - 1 > 0, "");

    if (bucket_count > max_bucket_count()) {
        throw std::length_error("The map exceeds its maxmimum size.");
    }

    if (bucket_count > 0) {
        // Can't directly construct with the appropriate size in the initializer
        // as m_buckets_data(bucket_count, alloc) is not supported by GCC 4.8
        m_buckets_data.resize(bucket_count + NeighborhoodSize - 1);
        m_buckets = m_buckets_data.data();
    }

    this->max_load_factor(max_load_factor);
}
/**
 * Copy constructor: copies the functors, the growth policy, both containers and the load
 * bookkeeping of `other`. m_buckets is pointed at the freshly copied bucket array, or at the
 * static empty bucket when that array is empty.
 */
hopscotch_hash(const hopscotch_hash& other)
    : Hash(other),
      KeyEqual(other),
      GrowthPolicy(other),
      m_buckets_data(other.m_buckets_data),
      m_overflow_elements(other.m_overflow_elements),
      m_buckets(m_buckets_data.empty() ? static_empty_bucket_ptr() : m_buckets_data.data()),
      m_nb_elements(other.m_nb_elements),
      m_max_load_factor(other.m_max_load_factor),
      m_max_load_threshold_rehash(other.m_max_load_threshold_rehash),
      m_min_load_threshold_rehash(other.m_min_load_threshold_rehash)
{
}
/**
 * Move constructor: takes over the functors, growth policy and containers of `other`.
 *
 * Afterwards `other` is reset to an empty table: its growth policy and containers are
 * cleared, it points at the static empty bucket, and its element count and rehash thresholds
 * are 0. Its max load factor is left unchanged.
 */
hopscotch_hash(hopscotch_hash&& other)
    noexcept(
        std::is_nothrow_move_constructible<Hash>::value &&
        std::is_nothrow_move_constructible<KeyEqual>::value &&
        std::is_nothrow_move_constructible<GrowthPolicy>::value &&
        std::is_nothrow_move_constructible<buckets_container_type>::value &&
        std::is_nothrow_move_constructible<overflow_container_type>::value
    )
    : Hash(std::move(static_cast<Hash&>(other))),
      KeyEqual(std::move(static_cast<KeyEqual&>(other))),
      GrowthPolicy(std::move(static_cast<GrowthPolicy&>(other))),
      m_buckets_data(std::move(other.m_buckets_data)),
      m_overflow_elements(std::move(other.m_overflow_elements)),
      m_buckets(m_buckets_data.empty() ? static_empty_bucket_ptr() : m_buckets_data.data()),
      m_nb_elements(other.m_nb_elements),
      m_max_load_factor(other.m_max_load_factor),
      m_max_load_threshold_rehash(other.m_max_load_threshold_rehash),
      m_min_load_threshold_rehash(other.m_min_load_threshold_rehash)
{
    // Put the moved-from table back into a well-defined empty state.
    other.GrowthPolicy::clear();
    other.m_buckets_data.clear();
    other.m_overflow_elements.clear();
    other.m_buckets = static_empty_bucket_ptr();
    other.m_nb_elements = 0;
    other.m_max_load_threshold_rehash = 0;
    other.m_min_load_threshold_rehash = 0;
}
/**
 * Copy assignment: member-wise copy of `other`, after which m_buckets is re-pointed at this
 * table's own bucket array (or at the static empty bucket when it is empty).
 * Self-assignment is a no-op.
 */
hopscotch_hash& operator=(const hopscotch_hash& other) {
    if (&other == this) {
        return *this;
    }

    Hash::operator=(other);
    KeyEqual::operator=(other);
    GrowthPolicy::operator=(other);

    m_buckets_data = other.m_buckets_data;
    m_overflow_elements = other.m_overflow_elements;
    m_buckets = m_buckets_data.empty() ? static_empty_bucket_ptr()
                                       : m_buckets_data.data();
    m_nb_elements = other.m_nb_elements;

    m_max_load_factor = other.m_max_load_factor;
    m_max_load_threshold_rehash = other.m_max_load_threshold_rehash;
    m_min_load_threshold_rehash = other.m_min_load_threshold_rehash;

    return *this;
}
/**
 * Move assignment: swaps the content of the two tables, then clears `other`, which at that
 * point holds the previous content of this table.
 */
hopscotch_hash& operator=(hopscotch_hash&& other) {
    other.swap(*this);
    other.clear();

    return *this;
}
/** Allocator of the table, obtained from the bucket container. */
allocator_type get_allocator() const {
    return m_buckets_data.get_allocator();
}
/*
* Iterators
*/
/**
 * Iterator to the first non-empty bucket. When every bucket is empty, the bucket part is at
 * its end and the iterator starts at the beginning of the overflow container.
 */
iterator begin() noexcept {
    auto first_used = m_buckets_data.begin();
    for (; first_used != m_buckets_data.end() && first_used->empty(); ++first_used) {
    }

    return iterator(first_used, m_buckets_data.end(), m_overflow_elements.begin());
}
/** Const overload of begin(); forwards to cbegin(). */
const_iterator begin() const noexcept {
    return cbegin();
}
/**
 * Const iterator to the first non-empty bucket. When every bucket is empty, the bucket part
 * is at its end and the iterator starts at the beginning of the overflow container.
 */
const_iterator cbegin() const noexcept {
    auto first_used = m_buckets_data.cbegin();
    for (; first_used != m_buckets_data.cend() && first_used->empty(); ++first_used) {
    }

    return const_iterator(first_used, m_buckets_data.cend(), m_overflow_elements.cbegin());
}
/** Past-the-end iterator: end of the buckets and end of the overflow container. */
iterator end() noexcept {
    return iterator(m_buckets_data.end(),
                    m_buckets_data.end(),
                    m_overflow_elements.end());
}
/** Const overload of end(); forwards to cend(). */
const_iterator end() const noexcept {
    return cend();
}
/** Past-the-end const iterator: end of the buckets and end of the overflow container. */
const_iterator cend() const noexcept {
    return const_iterator(m_buckets_data.cend(),
                          m_buckets_data.cend(),
                          m_overflow_elements.cend());
}
/*
* Capacity
*/
/** True when the table holds no element. */
bool empty() const noexcept {
    return m_nb_elements == 0;
}
/** Number of elements in the table (m_nb_elements). */
size_type size() const noexcept {
    return m_nb_elements;
}
/** Maximum size, as reported by the bucket container. */
size_type max_size() const noexcept {
    return m_buckets_data.max_size();
}
/*
* Modifiers
*/
/**
 * Remove every element: each bucket is cleared in place (the bucket array keeps its size),
 * the overflow container is emptied and the element count is reset to 0.
 */
void clear() noexcept {
    for (auto it = m_buckets_data.begin(); it != m_buckets_data.end(); ++it) {
        it->clear();
    }

    m_overflow_elements.clear();
    m_nb_elements = 0;
}
/** Insert a copy of `value`; forwards to insert_impl. */
std::pair<iterator, bool> insert(const value_type& value) {
    return insert_impl(value);
}
/**
 * Insert from any `P` that value_type can be constructed from: a value_type is built from
 * the forwarded argument and handed to insert_impl.
 */
template<class P,
         typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
std::pair<iterator, bool> insert(P&& value) {
    return insert_impl(value_type(std::forward<P>(value)));
}
/** Insert by moving `value`; forwards to insert_impl. */
std::pair<iterator, bool> insert(value_type&& value) {
    return insert_impl(std::move(value));
}
/**
 * Insert with hint. If `hint` is not the end iterator and compare_keys matches its key
 * against the key of `value`, the hint itself is returned and nothing is inserted.
 * Otherwise falls back to the regular insert.
 */
iterator insert(const_iterator hint, const value_type& value) {
    if (hint == cend() || !compare_keys(KeySelect()(*hint), KeySelect()(value))) {
        return insert(value).first;
    }

    return mutable_iterator(hint);
}
/**
 * Insert with hint from any `P` that value_type can be constructed from; forwards to
 * emplace_hint.
 */
template<class P,
         typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
iterator insert(const_iterator hint, P&& value) {
    return emplace_hint(hint, std::forward<P>(value));
}
/**
 * Insert with hint, moving `value`. If `hint` is not the end iterator and compare_keys
 * matches its key against the key of `value`, the hint itself is returned and `value` is
 * left untouched. Otherwise falls back to the regular move insert.
 */
iterator insert(const_iterator hint, value_type&& value) {
    if (hint == cend() || !compare_keys(KeySelect()(*hint), KeySelect()(value))) {
        return insert(std::move(value)).first;
    }

    return mutable_iterator(hint);
}
/**
 * Insert every element of [first, last).
 *
 * For forward (or stronger) iterators the range length is measured first; if the room left
 * before the max load threshold is smaller than that length, reserve() is called once up
 * front.
 */
template<class InputIt>
void insert(InputIt first, InputIt last) {
    using range_category = typename std::iterator_traits<InputIt>::iterator_category;

    if (std::is_base_of<std::forward_iterator_tag, range_category>::value) {
        const auto nb_elements_insert = std::distance(first, last);
        const std::size_t nb_elements_in_buckets = m_nb_elements - m_overflow_elements.size();
        const std::size_t nb_free_buckets = m_max_load_threshold_rehash - nb_elements_in_buckets;
        tsl_hh_assert(m_nb_elements >= m_overflow_elements.size());
        tsl_hh_assert(m_max_load_threshold_rehash >= nb_elements_in_buckets);

        if (nb_elements_insert > 0 && nb_free_buckets < std::size_t(nb_elements_insert)) {
            reserve(nb_elements_in_buckets + std::size_t(nb_elements_insert));
        }
    }

    while (first != last) {
        insert(*first);
        ++first;
    }
}
/**
 * Insert (k, obj) or, if `k` is already present, assign `obj` to its mapped value.
 * Delegates to insert_or_assign_impl() with the key taken by const reference.
 */
template<class M>
std::pair<iterator, bool> insert_or_assign(const key_type& k, M&& obj) {
    return insert_or_assign_impl(k, std::forward<M>(obj));
}
/**
 * Insert (k, obj) or, if `k` is already present, assign `obj` to its mapped value.
 * Delegates to insert_or_assign_impl() with the key passed as an rvalue.
 */
template<class M>
std::pair<iterator, bool> insert_or_assign(key_type&& k, M&& obj) {
    return insert_or_assign_impl(std::move(k), std::forward<M>(obj));
}
/**
 * Hinted insert_or_assign with the key taken by const reference.
 *
 * If `hint` is not cend() and designates the element with key `k`, `obj` is assigned to that
 * element's value and its iterator is returned without any lookup. Otherwise the non-hinted
 * insert_or_assign() is used.
 */
template<class M>
iterator insert_or_assign(const_iterator hint, const key_type& k, M&& obj) {
    if(hint == cend() || !compare_keys(KeySelect()(*hint), k)) {
        return insert_or_assign(k, std::forward<M>(obj)).first;
    }

    auto it = mutable_iterator(hint);
    it.value() = std::forward<M>(obj);

    return it;
}
/**
 * Hinted insert_or_assign with the key passed as an rvalue.
 *
 * If `hint` is not cend() and designates the element with key `k`, `obj` is assigned to that
 * element's value and its iterator is returned; `k` is not moved from in that case. Otherwise
 * the key is moved into the non-hinted insert_or_assign().
 */
template<class M>
iterator insert_or_assign(const_iterator hint, key_type&& k, M&& obj) {
    if(hint == cend() || !compare_keys(KeySelect()(*hint), k)) {
        return insert_or_assign(std::move(k), std::forward<M>(obj)).first;
    }

    auto it = mutable_iterator(hint);
    it.value() = std::forward<M>(obj);

    return it;
}
/**
 * Build a value_type from `args` and insert it.
 *
 * The value is always constructed before the lookup, then handed to insert().
 */
template<class... Args>
std::pair<iterator, bool> emplace(Args&&... args) {
    return insert(value_type(std::forward<Args>(args)...));
}
/**
 * Build a value_type from `args` and insert it through the hinted insert() overload.
 */
template<class... Args>
iterator emplace_hint(const_iterator hint, Args&&... args) {
    return insert(hint, value_type(std::forward<Args>(args)...));
}
/**
 * Insert a value built from `k` and `args` unless `k` is already present.
 * Delegates to try_emplace_impl() with the key taken by const reference.
 */
template<class... Args>
std::pair<iterator, bool> try_emplace(const key_type& k, Args&&... args) {
    return try_emplace_impl(k, std::forward<Args>(args)...);
}
/**
 * Insert a value built from `k` and `args` unless `k` is already present.
 * Delegates to try_emplace_impl() with the key passed as an rvalue.
 */
template<class... Args>
std::pair<iterator, bool> try_emplace(key_type&& k, Args&&... args) {
    return try_emplace_impl(std::move(k), std::forward<Args>(args)...);
}
/**
 * Hinted try_emplace with the key taken by const reference.
 *
 * If `hint` is not cend() and designates the element with key `k`, that element is returned
 * unchanged. Otherwise the non-hinted try_emplace() is used.
 */
template<class... Args>
iterator try_emplace(const_iterator hint, const key_type& k, Args&&... args) {
    if(hint == cend() || !compare_keys(KeySelect()(*hint), k)) {
        return try_emplace(k, std::forward<Args>(args)...).first;
    }

    return mutable_iterator(hint);
}
/**
 * Hinted try_emplace with the key passed as an rvalue.
 *
 * If `hint` is not cend() and designates the element with key `k`, that element is returned
 * unchanged and `k` is not moved from. Otherwise the key is moved into the non-hinted
 * try_emplace().
 */
template<class... Args>
iterator try_emplace(const_iterator hint, key_type&& k, Args&&... args) {
    if(hint == cend() || !compare_keys(KeySelect()(*hint), k)) {
        return try_emplace(std::move(k), std::forward<Args>(args)...).first;
    }

    return mutable_iterator(hint);
}
/**
 * Erase the element designated by a mutable iterator.
 *
 * This overload exists so that a call with an `iterator` does not select
 * `template<class K> size_type erase(const K& key)`; it converts to const_iterator and
 * forwards to the const_iterator overload.
 */
iterator erase(iterator pos) {
    return erase(const_iterator(pos));
}
/**
 * Erase the element designated by `pos` and return an iterator to the element following it.
 *
 * `pos` is dereferenced (its key is hashed) before anything else, so it must designate an
 * element. The element lives either in the bucket array or in the overflow container; the
 * position stored in `pos` tells which one.
 */
iterator erase(const_iterator pos) {
    const std::size_t ibucket_for_hash = bucket_for_hash(hash_key(pos.key()));

    const bool stored_in_overflow = (pos.m_buckets_iterator == pos.m_buckets_end_iterator);
    if(stored_in_overflow) {
        auto it_next_overflow = erase_from_overflow(pos.m_overflow_iterator, ibucket_for_hash);
        return iterator(m_buckets_data.end(), m_buckets_data.end(), it_next_overflow);
    }

    // Rebuild a non-const bucket iterator from the offset of the const one.
    auto it_bucket = m_buckets_data.begin() +
                     std::distance(m_buckets_data.cbegin(), pos.m_buckets_iterator);
    erase_from_bucket(*it_bucket, ibucket_for_hash);

    // The bucket is now empty: build an iterator on it and step to the next element.
    iterator it_next(it_bucket, m_buckets_data.end(), m_overflow_elements.begin());
    ++it_next;

    return it_next;
}
/**
 * Erase the elements in [first, last) one by one and return an iterator equal to `last`.
 * An empty range erases nothing and returns a mutable iterator for `first`.
 */
iterator erase(const_iterator first, const_iterator last) {
    if(first == last) {
        return mutable_iterator(first);
    }

    iterator it_next = erase(first);
    while(it_next != last) {
        it_next = erase(it_next);
    }

    return it_next;
}
/**
 * Erase the element with the given key, if any. Returns the number of erased elements.
 * The hash is computed with hash_key() and the call forwards to erase(key, hash).
 */
template<class K>
size_type erase(const K& key) {
    return erase(key, hash_key(key));
}
/**
 * Erase the element with the given key using a precomputed hash.
 *
 * The neighborhood of the home bucket is searched first; the overflow container is only
 * searched when the home bucket carries the overflow flag. Returns 1 if an element was erased,
 * 0 otherwise.
 */
template<class K>
size_type erase(const K& key, std::size_t hash) {
    const std::size_t ibucket_for_hash = bucket_for_hash(hash);

    hopscotch_bucket* bucket_found = find_in_buckets(key, hash, m_buckets + ibucket_for_hash);
    if(bucket_found != nullptr) {
        erase_from_bucket(*bucket_found, ibucket_for_hash);
        return 1;
    }

    if(!m_buckets[ibucket_for_hash].has_overflow()) {
        return 0;
    }

    auto it_overflow = find_in_overflow(key);
    if(it_overflow == m_overflow_elements.end()) {
        return 0;
    }

    erase_from_overflow(it_overflow, ibucket_for_hash);
    return 1;
}
/**
 * Exchange the complete state of this table with `other`.
 *
 * The Hash, KeyEqual and GrowthPolicy base sub-objects are swapped first, followed by every
 * data member. swap is looked up unqualified after `using std::swap` so that type-specific
 * overloads are picked up when they exist.
 */
void swap(hopscotch_hash& other) {
    using std::swap;

    // Base sub-objects.
    swap(static_cast<Hash&>(*this), static_cast<Hash&>(other));
    swap(static_cast<KeyEqual&>(*this), static_cast<KeyEqual&>(other));
    swap(static_cast<GrowthPolicy&>(*this), static_cast<GrowthPolicy&>(other));

    // Storage.
    swap(m_buckets_data, other.m_buckets_data);
    swap(m_overflow_elements, other.m_overflow_elements);
    swap(m_buckets, other.m_buckets);

    // Bookkeeping.
    swap(m_nb_elements, other.m_nb_elements);
    swap(m_max_load_factor, other.m_max_load_factor);
    swap(m_max_load_threshold_rehash, other.m_max_load_threshold_rehash);
    swap(m_min_load_threshold_rehash, other.m_min_load_threshold_rehash);
}
/*
* Lookup
*/
/**
 * Return a reference to the mapped value of `key`.
 * Only available when ValueSelect has a mapped type; forwards to at(key, hash).
 */
template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
typename U::value_type& at(const K& key) {
    return at(key, hash_key(key));
}
/**
 * Return a reference to the mapped value of `key`, using a precomputed hash.
 *
 * Implemented on top of the const overload: the call is made through a const view of this
 * object and the constness of the result is cast away.
 */
template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
typename U::value_type& at(const K& key, std::size_t hash) {
    const hopscotch_hash& const_self = *this;
    return const_cast<typename U::value_type&>(const_self.at(key, hash));
}
/**
 * Return a const reference to the mapped value of `key`.
 * Only available when ValueSelect has a mapped type; forwards to at(key, hash) const.
 */
template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
const typename U::value_type& at(const K& key) const {
    return at(key, hash_key(key));
}
/**
 * Return a const reference to the mapped value of `key`, using a precomputed hash.
 *
 * @throws std::out_of_range if find_value_impl() finds no value for the key.
 */
template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
const typename U::value_type& at(const K& key, std::size_t hash) const {
    using T = typename U::value_type;

    const T* mapped = find_value_impl(key, hash, m_buckets + bucket_for_hash(hash));
    if(mapped != nullptr) {
        return *mapped;
    }

    throw std::out_of_range("Couldn't find key.");
}
/**
 * Return a reference to the mapped value of `key`, inserting one if the key is absent.
 *
 * On a miss the element is built in place through insert_value() with
 * std::piecewise_construct: the key is forwarded and the mapped value is constructed from an
 * empty argument tuple.
 */
template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
typename U::value_type& operator[](K&& key) {
    using T = typename U::value_type;

    const std::size_t hash = hash_key(key);
    const std::size_t ibucket_for_hash = bucket_for_hash(hash);

    T* mapped = find_value_impl(key, hash, m_buckets + ibucket_for_hash);
    if(mapped == nullptr) {
        return insert_value(ibucket_for_hash, hash, std::piecewise_construct,
                            std::forward_as_tuple(std::forward<K>(key)),
                            std::forward_as_tuple()).first.value();
    }

    return *mapped;
}
/**
 * Number of elements with the given key; hashes the key and forwards to count(key, hash).
 */
template<class K>
size_type count(const K& key) const {
    return count(key, hash_key(key));
}
/**
 * Number of elements with the given key, using a precomputed hash.
 * The search starts at the home bucket of `hash` and is carried out by count_impl().
 */
template<class K>
size_type count(const K& key, std::size_t hash) const {
    return count_impl(key, hash, m_buckets + bucket_for_hash(hash));
}
/**
 * Find the element with the given key; hashes the key and forwards to find(key, hash).
 */
template<class K>
iterator find(const K& key) {
    return find(key, hash_key(key));
}
/**
 * Find the element with the given key, using a precomputed hash.
 * The search starts at the home bucket of `hash` and is carried out by find_impl().
 */
template<class K>
iterator find(const K& key, std::size_t hash) {
    return find_impl(key, hash, m_buckets + bucket_for_hash(hash));
}
/**
 * Const lookup of the element with the given key; hashes the key and forwards to
 * find(key, hash) const.
 */
template<class K>
const_iterator find(const K& key) const {
    return find(key, hash_key(key));
}
/**
 * Const lookup of the element with the given key, using a precomputed hash.
 * The search starts at the home bucket of `hash` and is carried out by find_impl() const.
 */
template<class K>
const_iterator find(const K& key, std::size_t hash) const {
    return find_impl(key, hash, m_buckets + bucket_for_hash(hash));
}
/**
 * Range of elements matching `key`; hashes the key and forwards to equal_range(key, hash).
 */
template<class K>
std::pair<iterator, iterator> equal_range(const K& key) {
    return equal_range(key, hash_key(key));
}
template
<
class
K
>
std
::
pair
<
iterator
,
iterator
>
equal_range
(
const
K
&
key
,
std
::
size_t
hash
)
{
iterator
it
=
find
(
key
,
hash
);
return
std
::
make_pair
(
it
,
(
it
==
end
())
?
it
:
std
::
next
(
it
));
}
/**
 * Const range of elements matching `key`; hashes the key and forwards to
 * equal_range(key, hash) const.
 */
template<class K>
std::pair<const_iterator, const_iterator> equal_range(const K& key) const {
    return equal_range(key, hash_key(key));
}
template
<
class
K
>
std
::
pair
<
const_iterator
,
const_iterator
>
equal_range
(
const
K
&
key
,
std
::
size_t
hash
)
const
{
const_iterator
it
=
find
(
key
,
hash
);
return
std
::
make_pair
(
it
,
(
it
==
cend
())
?
it
:
std
::
next
(
it
));
}
/*
* Bucket interface
*/
/**
 * Number of addressable buckets.
 *
 * So that the last bucket can have NeighborhoodSize neighbors, the bucket array is a little
 * bigger than the real number of buckets when not empty: NeighborhoodSize - 1 extra slots sit
 * at the end. Some of the buckets at the beginning could be used instead, but it is faster
 * this way as extra checks are avoided. An empty array reports 0.
 */
size_type bucket_count() const {
    return m_buckets_data.empty() ? 0 : m_buckets_data.size() - NeighborhoodSize + 1;
}
/**
 * Largest bucket count the table can report: the smaller of the growth policy limit and the
 * bucket container's max_size(), minus the NeighborhoodSize - 1 padding slots.
 */
size_type max_bucket_count() const {
    const std::size_t max_array_size = std::min(GrowthPolicy::max_bucket_count(),
                                                m_buckets_data.max_size());
    return max_array_size - NeighborhoodSize + 1;
}
/*
* Hash policy
*/
/**
 * Current load factor: m_nb_elements divided by bucket_count(), or 0 when there is no bucket.
 */
float load_factor() const {
    if(bucket_count() != 0) {
        return float(m_nb_elements)/float(bucket_count());
    }

    return 0;
}
/**
 * Return the configured maximum load factor.
 */
float max_load_factor() const {
    return m_max_load_factor;
}
/**
 * Set the maximum load factor.
 *
 * `ml` is clamped to the range [0.1, 0.95]. Both rehash thresholds are then recomputed from
 * the current bucket_count(); no rehash is performed by this call.
 */
void max_load_factor(float ml) {
    // The float suffix must be attached to the number: `0.1 f` does not form a float literal.
    m_max_load_factor = std::max(0.1f, std::min(ml, 0.95f));
    m_max_load_threshold_rehash = size_type(float(bucket_count())*m_max_load_factor);
    m_min_load_threshold_rehash = size_type(float(bucket_count())*MIN_LOAD_FACTOR_FOR_REHASH);
}
/**
 * Rebuild the table with at least `count_` buckets.
 *
 * The request is raised to ceil(size() / max_load_factor()) when it is smaller, then handed
 * to rehash_impl().
 */
void rehash(size_type count_) {
    const size_type min_count_for_size = size_type(std::ceil(float(size())/max_load_factor()));
    count_ = std::max(count_, min_count_for_size);

    rehash_impl(count_);
}
/**
 * Make room for `count_` elements: rehash to ceil(count_ / max_load_factor()) buckets.
 */
void reserve(size_type count_) {
    const size_type nb_buckets_needed = size_type(std::ceil(float(count_)/max_load_factor()));
    rehash(nb_buckets_needed);
}
/*
* Observers
*/
/**
 * Return a copy of the hash functor, which is stored as the Hash base sub-object.
 */
hasher hash_function() const {
    const Hash& hash_base = *this;
    return hash_base;
}
/**
 * Return a copy of the key equality functor, which is stored as the KeyEqual base sub-object.
 */
key_equal key_eq() const {
    const KeyEqual& key_equal_base = *this;
    return key_equal_base;
}
/*
* Other
*/
/**
 * Convert a const_iterator of this table into the equivalent mutable iterator.
 *
 * A position inside the bucket array is rebuilt from its offset; a position inside the
 * overflow container is converted with mutable_overflow_iterator().
 */
iterator mutable_iterator(const_iterator pos) {
    const bool stored_in_overflow = (pos.m_buckets_iterator == pos.m_buckets_end_iterator);
    if(stored_in_overflow) {
        auto it_overflow = mutable_overflow_iterator(pos.m_overflow_iterator);
        return iterator(m_buckets_data.end(), m_buckets_data.end(), it_overflow);
    }

    auto it_bucket = m_buckets_data.begin() +
                     std::distance(m_buckets_data.cbegin(), pos.m_buckets_iterator);
    return iterator(it_bucket, m_buckets_data.end(), m_overflow_elements.begin());
}
/**
 * Number of elements currently stored in the overflow container.
 */
size_type overflow_size() const noexcept {
    return m_overflow_elements.size();
}
/**
 * Return the key comparator of the overflow container.
 * Only available when OverflowContainer exposes a key_compare type.
 */
template<class U = OverflowContainer, typename std::enable_if<has_key_compare<U>::value>::type* = nullptr>
typename U::key_compare key_comp() const {
    return m_overflow_elements.key_comp();
}
private:
/**
 * Hash `key` with the Hash base sub-object.
 */
template<class K>
std::size_t hash_key(const K& key) const {
    return Hash::operator()(key);
}
/**
 * Compare two keys with the KeyEqual base sub-object.
 */
template<class K1, class K2>
bool compare_keys(const K1& key1, const K2& key2) const {
    return KeyEqual::operator()(key1, key2);
}
/**
 * Map a hash to the index of its home bucket through the growth policy.
 *
 * The index is expected to be inside the bucket array, or to be 0 when the array is empty.
 */
std::size_t bucket_for_hash(std::size_t hash) const {
    const std::size_t ibucket = GrowthPolicy::bucket_for_hash(hash);
    tsl_hh_assert(ibucket < m_buckets_data.size() || (ibucket == 0 && m_buckets_data.empty()));

    return ibucket;
}
/**
 * Rehash into a table of `count_` buckets by moving the elements.
 * Selected when value_type is nothrow move constructible.
 *
 * Steps:
 *  1. Build an empty table with the requested bucket count.
 *  2. Hand the whole overflow container over to it and flag the home bucket of each overflow
 *     element in the new table.
 *  3. Move every bucket element into the new table, erasing it from this one as it goes.
 *  4. Swap the new table with *this.
 *
 * If step 3 throws, the overflow container is swapped back, the elements already moved into
 * the new table's buckets are moved back into this one, and the exception is rethrown.
 */
template<typename U = value_type,
         typename std::enable_if<std::is_nothrow_move_constructible<U>::value>::type* = nullptr>
void rehash_impl(size_type count_) {
    hopscotch_hash new_map = new_hopscotch_hash(count_);

    if(!m_overflow_elements.empty()) {
        new_map.m_overflow_elements.swap(m_overflow_elements);
        new_map.m_nb_elements += new_map.m_overflow_elements.size();

        for(const value_type& overflow_value : new_map.m_overflow_elements) {
            const std::size_t ihome = new_map.bucket_for_hash(new_map.hash_key(KeySelect()(overflow_value)));
            new_map.m_buckets[ihome].set_overflow(true);
        }
    }

    try {
        const bool use_stored_hash = USE_STORED_HASH_ON_REHASH(new_map.bucket_count());

        for(hopscotch_bucket& old_bucket : m_buckets_data) {
            if(old_bucket.empty()) {
                continue;
            }

            const std::size_t hash = use_stored_hash?
                                         old_bucket.truncated_bucket_hash():
                                         new_map.hash_key(KeySelect()(old_bucket.value()));
            const std::size_t ibucket_for_hash = new_map.bucket_for_hash(hash);

            new_map.insert_value(ibucket_for_hash, hash, std::move(old_bucket.value()));

            // The value has been moved out: release the slot in the old table.
            erase_from_bucket(old_bucket, bucket_for_hash(hash));
        }
    }
    /*
     * The call to insert_value may throw an exception if an element is added to the overflow
     * list. Rollback the elements in this case.
     */
    catch(...) {
        m_overflow_elements.swap(new_map.m_overflow_elements);

        const bool use_stored_hash = USE_STORED_HASH_ON_REHASH(new_map.bucket_count());

        for(hopscotch_bucket& moved_bucket : new_map.m_buckets_data) {
            if(moved_bucket.empty()) {
                continue;
            }

            const std::size_t hash = use_stored_hash?
                                         moved_bucket.truncated_bucket_hash():
                                         hash_key(KeySelect()(moved_bucket.value()));
            const std::size_t ibucket_for_hash = bucket_for_hash(hash);

            // The elements we insert were not in the overflow list before the switch.
            // They will not go in the overflow list if we rollback the switch.
            insert_value(ibucket_for_hash, hash, std::move(moved_bucket.value()));
        }

        throw;
    }

    new_map.swap(*this);
}
/**
 * Rehash into a table of `count_` buckets by copying the elements.
 * Selected when value_type is copy constructible but not nothrow move constructible.
 *
 * Every bucket element and then every overflow element is copied into a fresh table, which is
 * finally swapped with *this. This table is left untouched until that swap.
 */
template<typename U = value_type,
         typename std::enable_if<std::is_copy_constructible<U>::value &&
                                 !std::is_nothrow_move_constructible<U>::value>::type* = nullptr>
void rehash_impl(size_type count_) {
    hopscotch_hash new_map = new_hopscotch_hash(count_);

    const bool use_stored_hash = USE_STORED_HASH_ON_REHASH(new_map.bucket_count());
    for(const hopscotch_bucket& old_bucket : m_buckets_data) {
        if(old_bucket.empty()) {
            continue;
        }

        const std::size_t hash = use_stored_hash?
                                     old_bucket.truncated_bucket_hash():
                                     new_map.hash_key(KeySelect()(old_bucket.value()));
        const std::size_t ibucket_for_hash = new_map.bucket_for_hash(hash);

        new_map.insert_value(ibucket_for_hash, hash, old_bucket.value());
    }

    // Overflow elements carry no stored hash, so it is always recomputed for them.
    for(const value_type& overflow_value : m_overflow_elements) {
        const std::size_t hash = new_map.hash_key(KeySelect()(overflow_value));
        const std::size_t ibucket_for_hash = new_map.bucket_for_hash(hash);

        new_map.insert_value(ibucket_for_hash, hash, overflow_value);
    }

    new_map.swap(*this);
}
/**
 * Convert a const iterator of the overflow container into a mutable one.
 *
 * When TSL_HH_NO_RANGE_ERASE_WITH_CONST_ITERATOR is defined, the iterator is rebuilt by
 * advancing begin() by the distance from cbegin(). Otherwise the empty range erase(it, it) is
 * used: it removes nothing and returns a mutable iterator at the same position.
 */
#ifdef TSL_HH_NO_RANGE_ERASE_WITH_CONST_ITERATOR
iterator_overflow mutable_overflow_iterator(const_iterator_overflow it) {
    const auto offset = std::distance(m_overflow_elements.cbegin(), it);
    return std::next(m_overflow_elements.begin(), offset);
}
#else
iterator_overflow mutable_overflow_iterator(const_iterator_overflow it) {
    return m_overflow_elements.erase(it, it);
}
#endif
/**
 * Erase the element at `pos` from the overflow container.
 *
 * `ibucket_for_hash` is the home bucket of the erased element. After the removal the overflow
 * flag of that bucket is cleared unless another overflow element still hashes to it.
 *
 * Returns the iterator following the erased element in the overflow container.
 */
iterator_overflow erase_from_overflow(const_iterator_overflow pos, std::size_t ibucket_for_hash) {
#ifdef TSL_HH_NO_RANGE_ERASE_WITH_CONST_ITERATOR
    auto it_next = m_overflow_elements.erase(mutable_overflow_iterator(pos));
#else
    auto it_next = m_overflow_elements.erase(pos);
#endif
    m_nb_elements--;

    // Check if we can remove the overflow flag
    tsl_hh_assert(m_buckets[ibucket_for_hash].has_overflow());

    bool bucket_still_overflows = false;
    for(const value_type& remaining : m_overflow_elements) {
        const std::size_t ibucket_remaining = bucket_for_hash(hash_key(KeySelect()(remaining)));
        if(ibucket_remaining == ibucket_for_hash) {
            bucket_still_overflows = true;
            break;
        }
    }

    if(!bucket_still_overflows) {
        m_buckets[ibucket_for_hash].set_overflow(false);
    }

    return it_next;
}
/**
 * Remove the value stored in a bucket.
 *
 * `bucket_for_value` is the bucket in which the value is stored.
 * `ibucket_for_hash` is the index of the bucket where the value belongs (its home bucket).
 *
 * The value is removed, the matching presence bit of the home bucket is toggled and the
 * element count is decremented.
 */
void erase_from_bucket(hopscotch_bucket& bucket_for_value, std::size_t ibucket_for_hash) noexcept {
    // Index of the storing bucket, derived from its address inside the bucket array.
    const std::size_t ibucket_for_value = std::distance(m_buckets_data.data(), &bucket_for_value);
    tsl_hh_assert(ibucket_for_value >= ibucket_for_hash);

    bucket_for_value.remove_value();

    const std::size_t neighbor_offset = ibucket_for_value - ibucket_for_hash;
    m_buckets[ibucket_for_hash].toggle_neighbor_presence(neighbor_offset);

    m_nb_elements--;
}
/**
 * Shared implementation of insert_or_assign().
 *
 * try_emplace_impl() is attempted first. When it reports that the key was already present
 * (second == false), `obj` is assigned to the existing element's value instead.
 */
template<class K, class M>
std::pair<iterator, bool> insert_or_assign_impl(K&& key, M&& obj) {
    auto result = try_emplace_impl(std::forward<K>(key), std::forward<M>(obj));

    const bool inserted = result.second;
    if(!inserted) {
        result.first.value() = std::forward<M>(obj);
    }

    return result;
}
/**
 * Shared implementation of try_emplace().
 *
 * Returns (iterator to the existing element, false) when the key is already present.
 * Otherwise the element is constructed piecewise from the forwarded key and the forwarded
 * value arguments through insert_value().
 */
template<typename P, class... Args>
std::pair<iterator, bool> try_emplace_impl(P&& key, Args&&... args_value) {
    const std::size_t hash = hash_key(key);
    const std::size_t ibucket_for_hash = bucket_for_hash(hash);

    // Check whether the key is already present
    auto it_existing = find_impl(key, hash, m_buckets + ibucket_for_hash);
    if(it_existing != end()) {
        return std::make_pair(it_existing, false);
    }

    return insert_value(ibucket_for_hash, hash, std::piecewise_construct,
                        std::forward_as_tuple(std::forward<P>(key)),
                        std::forward_as_tuple(std::forward<Args>(args_value)...));
}
/**
 * Shared implementation of insert().
 *
 * Returns (iterator to the existing element, false) when an element with the same key is
 * already present. Otherwise the value is forwarded to insert_value().
 */
template<typename P>
std::pair<iterator, bool> insert_impl(P&& value) {
    const std::size_t hash = hash_key(KeySelect()(value));
    const std::size_t ibucket_for_hash = bucket_for_hash(hash);

    // Check whether the key is already present
    auto it_existing = find_impl(KeySelect()(value), hash, m_buckets + ibucket_for_hash);
    if(it_existing != end()) {
        return std::make_pair(it_existing, false);
    }

    return insert_value(ibucket_for_hash, hash, std::forward<P>(value));
}
/**
 * Store a new element whose key is known to be absent.
 *
 * 1. If the number of elements held in buckets (overflow excluded) has reached
 *    m_max_load_threshold_rehash, the table is grown first and the home bucket recomputed.
 * 2. An empty bucket is searched from the home bucket. While it is too far away, it is pulled
 *    closer with swap_empty_bucket_closer(); as soon as it lies inside the neighborhood the
 *    element is constructed there.
 * 3. If no usable bucket is obtained, the element goes into the overflow container.
 *
 * Always returns (iterator to the new element, true).
 */
template<typename... Args>
std::pair<iterator, bool> insert_value(std::size_t ibucket_for_hash, std::size_t hash, Args&&... value_type_args) {
    const std::size_t nb_elements_in_buckets = m_nb_elements - m_overflow_elements.size();
    if(nb_elements_in_buckets >= m_max_load_threshold_rehash) {
        rehash(GrowthPolicy::next_bucket_count());
        ibucket_for_hash = bucket_for_hash(hash);
    }

    std::size_t ibucket_empty = find_empty_bucket(ibucket_for_hash);
    if(ibucket_empty < m_buckets_data.size()) {
        do {
            tsl_hh_assert(ibucket_empty >= ibucket_for_hash);

            // Empty bucket is in range of NeighborhoodSize, use it
            if(ibucket_empty - ibucket_for_hash < NeighborhoodSize) {
                auto it_bucket = insert_in_bucket(ibucket_empty, ibucket_for_hash,
                                                  hash, std::forward<Args>(value_type_args)...);
                return std::make_pair(iterator(it_bucket, m_buckets_data.end(), m_overflow_elements.begin()),
                                      true);
            }
        }
        // else, try to swap values to get a closer empty bucket
        while(swap_empty_bucket_closer(ibucket_empty));
    }

    auto it_overflow = insert_in_overflow(ibucket_for_hash, std::forward<Args>(value_type_args)...);
    return std::make_pair(iterator(m_buckets_data.end(), m_buckets_data.end(), it_overflow), true);

    // Never rehash here for memory safety
    //////////////////////////////////////////////////////////////////////////////////////////////////////
    // Disabled fallback, kept for reference:
    // Load factor is too low or a rehash will not change the neighborhood, put the value in overflow list
    // if(size() < m_min_load_threshold_rehash || !will_neighborhood_change_on_rehash(ibucket_for_hash)) {
    //     auto it = insert_in_overflow(ibucket_for_hash, std::forward<Args>(value_type_args)...);
    //     return std::make_pair(iterator(m_buckets_data.end(), m_buckets_data.end(), it), true);
    // }
    // rehash(GrowthPolicy::next_bucket_count());
    // ibucket_for_hash = bucket_for_hash(hash);
    // return insert_value(ibucket_for_hash, hash, std::forward<Args>(value_type_args)...);
    //////////////////////////////////////////////////////////////////////////////////////////////////////
}
/*
* Return true if a rehash will change the position of a key-value in the neighborhood of
* ibucket_neighborhood_check. In this case a rehash is needed instead of puting the value in overflow list.
*/
bool
will_neighborhood_change_on_rehash
(
size_t
ibucket_neighborhood_check
)
const
{
std
::
size_t
expand_bucket_count
=
GrowthPolicy
::
next_bucket_count
();
GrowthPolicy
expand_growth_policy
(
expand_bucket_count
);
const
bool
use_stored_hash
=
USE_STORED_HASH_ON_REHASH
(
expand_bucket_count
);
for
(
size_t
ibucket
=
ibucket_neighborhood_check
;
ibucket
<
m_buckets_data
.
size
()
&&
(
ibucket
-
ibucket_neighborhood_check
)
<
NeighborhoodSize
;
++
ibucket
)
{
tsl_hh_assert
(
!
m_buckets
[
ibucket
].
empty
());
const
size_t
hash
=
use_stored_hash
?
m_buckets
[
ibucket
].
truncated_bucket_hash
()
:
hash_key
(
KeySelect
()(
m_buckets
[
ibucket
].
value
()));
if
(
bucket_for_hash
(
hash
)
!=
expand_growth_policy
.
bucket_for_hash
(
hash
))
{
return
true
;
}
}
return
false
;
}
/*
* Return the index of an empty bucket in m_buckets_data.
* If none, the returned index equals m_buckets_data.size()
*/
std
::
size_t
find_empty_bucket
(
std
::
size_t
ibucket_start
)
const
{
const
std
::
size_t
limit
=
std
::
min
(
ibucket_start
+
MAX_PROBES_FOR_EMPTY_BUCKET
,
m_buckets_data
.
size
());
for
(;
ibucket_start
<
limit
;
ibucket_start
++
)
{
if
(
m_buckets
[
ibucket_start
].
empty
())
{
return
ibucket_start
;
}
}
return
m_buckets_data
.
size
();
}
/*
 * Construct a value in the empty bucket ibucket_empty, where the value originally belongs to
 * ibucket_for_hash. The truncated hash is stored with the value, the presence bit of the home
 * bucket is toggled and the element count is incremented.
 *
 * Return a bucket iterator to ibucket_empty.
 */
template<typename... Args>
iterator_buckets insert_in_bucket(std::size_t ibucket_empty, std::size_t ibucket_for_hash,
                                  std::size_t hash, Args&&... value_type_args)
{
    tsl_hh_assert(ibucket_empty >= ibucket_for_hash);
    tsl_hh_assert(m_buckets[ibucket_empty].empty());

    m_buckets[ibucket_empty].set_value_of_empty_bucket(hopscotch_bucket::truncate_hash(hash),
                                                       std::forward<Args>(value_type_args)...);

    tsl_hh_assert(!m_buckets[ibucket_for_hash].empty());

    const std::size_t neighbor_offset = ibucket_empty - ibucket_for_hash;
    m_buckets[ibucket_for_hash].toggle_neighbor_presence(neighbor_offset);

    m_nb_elements++;

    return m_buckets_data.begin() + ibucket_empty;
}
/**
 * Append a new element to the overflow container.
 * Selected when OverflowContainer has no key_compare type.
 *
 * The element is emplaced at end(), the home bucket is flagged as having overflow and the
 * element count is incremented. Returns an iterator to the new element.
 */
template<class... Args, class U = OverflowContainer,
         typename std::enable_if<!has_key_compare<U>::value>::type* = nullptr>
iterator_overflow insert_in_overflow(std::size_t ibucket_for_hash, Args&&... value_type_args) {
    auto it_new = m_overflow_elements.emplace(m_overflow_elements.end(),
                                              std::forward<Args>(value_type_args)...);

    m_buckets[ibucket_for_hash].set_overflow(true);
    m_nb_elements++;

    return it_new;
}
/**
 * Insert a new element into the overflow container.
 * Selected when OverflowContainer has a key_compare type.
 *
 * The element is emplaced without a position (the iterator is the `first` member of the
 * emplace result), the home bucket is flagged as having overflow and the element count is
 * incremented. Returns an iterator to the element.
 */
template<class... Args, class U = OverflowContainer,
         typename std::enable_if<has_key_compare<U>::value>::type* = nullptr>
iterator_overflow insert_in_overflow(std::size_t ibucket_for_hash, Args&&... value_type_args) {
    auto it_new = m_overflow_elements.emplace(std::forward<Args>(value_type_args)...).first;

    m_buckets[ibucket_for_hash].set_overflow(true);
    m_nb_elements++;

    return it_new;
}
/*
* Try to swap the bucket ibucket_empty_in_out with a bucket preceding it while keeping the neighborhood
* conditions correct.
*
* If a swap was possible, the position of ibucket_empty_in_out will be closer to 0 and true will re returned.
*/
bool
swap_empty_bucket_closer
(
std
::
size_t
&
ibucket_empty_in_out
)
{
tsl_hh_assert
(
ibucket_empty_in_out
>=
NeighborhoodSize
);
const
std
::
size_t
neighborhood_start
=
ibucket_empty_in_out
-
NeighborhoodSize
+
1
;
for
(
std
::
size_t
to_check
=
neighborhood_start
;
to_check
<
ibucket_empty_in_out
;
to_check
++
)
{
neighborhood_bitmap
neighborhood_infos
=
m_buckets
[
to_check
].
neighborhood_infos
();
std
::
size_t
to_swap
=
to_check
;
while
(
neighborhood_infos
!=
0
&&
to_swap
<
ibucket_empty_in_out
)
{
if
((
neighborhood_infos
&
1
)
==
1
)
{
tsl_hh_assert
(
m_buckets
[
ibucket_empty_in_out
].
empty
());
tsl_hh_assert
(
!
m_buckets
[
to_swap
].
empty
());
m_buckets
[
to_swap
].
swap_value_into_empty_bucket
(
m_buckets
[
ibucket_empty_in_out
]);
tsl_hh_assert
(
!
m_buckets
[
to_check
].
check_neighbor_presence
(
ibucket_empty_in_out
-
to_check
));
tsl_hh_assert
(
m_buckets
[
to_check
].
check_neighbor_presence
(
to_swap
-
to_check
));
m_buckets
[
to_check
].
toggle_neighbor_presence
(
ibucket_empty_in_out
-
to_check
);
m_buckets
[
to_check
].
toggle_neighbor_presence
(
to_swap
-
to_check
);
ibucket_empty_in_out
=
to_swap
;
return
true
;
}
to_swap
++
;
neighborhood_infos
=
neighborhood_bitmap
(
neighborhood_infos
>>
1
);
}
}
return
false
;
}
/**
 * Mutable variant of find_value_impl().
 *
 * Calls the const overload through a const view of this object and casts the constness of the
 * returned pointer away. Returns nullptr when there is no value for the key.
 */
template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
typename U::value_type* find_value_impl(const K& key, std::size_t hash, hopscotch_bucket* bucket_for_hash) {
    const hopscotch_hash& const_self = *this;
    return const_cast<typename U::value_type*>(const_self.find_value_impl(key, hash, bucket_for_hash));
}
/*
 * Avoid the creation of an iterator to just get the value for operator[] and at() in maps.
 * Faster this way.
 *
 * The neighborhood of bucket_for_hash is searched first; the overflow container is only
 * searched when that bucket carries the overflow flag.
 *
 * Return null if no value for the key (TODO use std::optional when available).
 */
template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
const typename U::value_type* find_value_impl(const K& key, std::size_t hash,
                                              const hopscotch_bucket* bucket_for_hash) const
{
    const hopscotch_bucket* bucket_found = find_in_buckets(key, hash, bucket_for_hash);
    if(bucket_found != nullptr) {
        return std::addressof(ValueSelect()(bucket_found->value()));
    }

    if(!bucket_for_hash->has_overflow()) {
        return nullptr;
    }

    auto it_overflow = find_in_overflow(key);
    if(it_overflow == m_overflow_elements.end()) {
        return nullptr;
    }

    return std::addressof(ValueSelect()(*it_overflow));
}
/**
 * Return 1 if `key` is stored in the neighborhood of bucket_for_hash, or in the overflow
 * container when that bucket carries the overflow flag; 0 otherwise.
 */
template<class K>
size_type count_impl(const K& key, std::size_t hash, const hopscotch_bucket* bucket_for_hash) const {
    if(find_in_buckets(key, hash, bucket_for_hash) != nullptr) {
        return 1;
    }

    if(bucket_for_hash->has_overflow() && find_in_overflow(key) != m_overflow_elements.cend()) {
        return 1;
    }

    return 0;
}
/**
 * Locate `key` starting from its home bucket and return a mutable iterator.
 *
 * A hit in the neighborhood yields an iterator into the bucket array. Otherwise end() is
 * returned when the home bucket has no overflow flag, else an iterator built from the result
 * of the overflow search.
 */
template<class K>
iterator find_impl(const K& key, std::size_t hash, hopscotch_bucket* bucket_for_hash) {
    hopscotch_bucket* bucket_found = find_in_buckets(key, hash, bucket_for_hash);
    if(bucket_found != nullptr) {
        const auto ibucket_found = std::distance(m_buckets_data.data(), bucket_found);
        return iterator(m_buckets_data.begin() + ibucket_found,
                        m_buckets_data.end(), m_overflow_elements.begin());
    }

    if(!bucket_for_hash->has_overflow()) {
        return end();
    }

    return iterator(m_buckets_data.end(), m_buckets_data.end(), find_in_overflow(key));
}
/**
 * Locate `key` starting from its home bucket and return a const iterator.
 *
 * A hit in the neighborhood yields an iterator into the bucket array. Otherwise cend() is
 * returned when the home bucket has no overflow flag, else an iterator built from the result
 * of the overflow search.
 */
template<class K>
const_iterator find_impl(const K& key, std::size_t hash, const hopscotch_bucket* bucket_for_hash) const {
    const hopscotch_bucket* bucket_found = find_in_buckets(key, hash, bucket_for_hash);
    if(bucket_found != nullptr) {
        const auto ibucket_found = std::distance(m_buckets_data.data(), bucket_found);
        return const_iterator(m_buckets_data.cbegin() + ibucket_found,
                              m_buckets_data.cend(), m_overflow_elements.cbegin());
    }

    if(!bucket_for_hash->has_overflow()) {
        return cend();
    }

    return const_iterator(m_buckets_data.cend(), m_buckets_data.cend(), find_in_overflow(key));
}
/**
 * Mutable variant of find_in_buckets().
 *
 * Calls the const overload through a const view of this object and casts the constness of the
 * returned bucket pointer away. Returns nullptr when the key is not in the neighborhood.
 */
template<class K>
hopscotch_bucket* find_in_buckets(const K& key, std::size_t hash, hopscotch_bucket* bucket_for_hash) {
    const hopscotch_hash& const_self = *this;
    const hopscotch_bucket* bucket_found = const_self.find_in_buckets(key, hash, bucket_for_hash);

    return const_cast<hopscotch_bucket*>(bucket_found);
}
/**
 * Return a pointer to the bucket which has the value, nullptr otherwise.
 *
 * The presence bitmap of the home bucket is consumed from its lowest bit while the bucket
 * pointer advances in step; the loop stops as soon as no set bit remains.
 */
template<class K>
const hopscotch_bucket* find_in_buckets(const K& key, std::size_t hash,
                                        const hopscotch_bucket* bucket_for_hash) const
{
    (void) hash; // Avoid warning of unused variable when StoreHash is false;

    // TODO Try to optimize the function.
    // I tried to use ffs and __builtin_ffs functions but I could not reduce the time the function
    // takes with -march=native

    neighborhood_bitmap presence_bits = bucket_for_hash->neighborhood_infos();
    while(presence_bits != 0) {
        if((presence_bits & 1) == 1) {
            // Check StoreHash before calling bucket_hash_equal. Functionally it doesn't change anything.
            // If StoreHash is false, bucket_hash_equal is a no-op. Avoiding the call is there to help
            // GCC optimize the `hash` parameter away, it seems to not be able to do so without this hint.
            if((!StoreHash || bucket_for_hash->bucket_hash_equal(hash)) &&
               compare_keys(KeySelect()(bucket_for_hash->value()), key))
            {
                return bucket_for_hash;
            }
        }

        ++bucket_for_hash;
        presence_bits = neighborhood_bitmap(presence_bits >> 1);
    }

    return nullptr;
}
/**
 * Linear search of `key` in the overflow container (mutable iterator).
 * Selected when OverflowContainer has no key_compare type. Returns end() when not found.
 */
template<class K, class U = OverflowContainer, typename std::enable_if<!has_key_compare<U>::value>::type* = nullptr>
iterator_overflow find_in_overflow(const K& key) {
    auto has_same_key = [&](const value_type& value) {
        return compare_keys(key, KeySelect()(value));
    };

    return std::find_if(m_overflow_elements.begin(), m_overflow_elements.end(), has_same_key);
}
/**
 * Linear search of `key` in the overflow container (const iterator).
 * Selected when OverflowContainer has no key_compare type. Returns cend() when not found.
 */
template<class K, class U = OverflowContainer, typename std::enable_if<!has_key_compare<U>::value>::type* = nullptr>
const_iterator_overflow find_in_overflow(const K& key) const {
    auto has_same_key = [&](const value_type& value) {
        return compare_keys(key, KeySelect()(value));
    };

    return std::find_if(m_overflow_elements.cbegin(), m_overflow_elements.cend(), has_same_key);
}
/**
 * Look `key` up with the overflow container's own find() (mutable iterator).
 * Selected when OverflowContainer has a key_compare type.
 */
template<class K, class U = OverflowContainer, typename std::enable_if<has_key_compare<U>::value>::type* = nullptr>
iterator_overflow find_in_overflow(const K& key) {
    return m_overflow_elements.find(key);
}
/**
 * Look `key` up with the overflow container's own find() (const iterator).
 * Selected when OverflowContainer has a key_compare type.
 */
template<class K, class U = OverflowContainer, typename std::enable_if<has_key_compare<U>::value>::type* = nullptr>
const_iterator_overflow find_in_overflow(const K& key) const {
    return m_overflow_elements.find(key);
}
/**
 * Create an empty table with `bucket_count` buckets that shares this table's hash functor,
 * key equality functor, allocator and max load factor.
 * Selected when OverflowContainer has no key_compare type.
 */
template<class U = OverflowContainer, typename std::enable_if<!has_key_compare<U>::value>::type* = nullptr>
hopscotch_hash new_hopscotch_hash(size_type bucket_count) {
    return hopscotch_hash(bucket_count,
                          static_cast<Hash&>(*this),
                          static_cast<KeyEqual&>(*this),
                          get_allocator(),
                          m_max_load_factor);
}
/**
 * Create an empty table with `bucket_count` buckets that shares this table's hash functor,
 * key equality functor, allocator, max load factor and overflow key comparator.
 * Selected when OverflowContainer has a key_compare type.
 */
template<class U = OverflowContainer, typename std::enable_if<has_key_compare<U>::value>::type* = nullptr>
hopscotch_hash new_hopscotch_hash(size_type bucket_count) {
    return hopscotch_hash(bucket_count,
                          static_cast<Hash&>(*this),
                          static_cast<KeyEqual&>(*this),
                          get_allocator(),
                          m_max_load_factor,
                          m_overflow_elements.key_comp());
}
public:
    // Presumably the bucket count used when none is requested; confirm against the constructors.
    static const size_type DEFAULT_INIT_BUCKETS_SIZE = 0;

    // 0.8 for neighborhoods of at most 30 buckets, 0.9 for larger ones. The float suffix must
    // be attached to the number: `0.8 f` does not form a float literal.
    static constexpr float DEFAULT_MAX_LOAD_FACTOR = (NeighborhoodSize <= 30)?0.8f:0.9f;

private:
    // Upper bound on the number of buckets probed by find_empty_bucket().
    static const std::size_t MAX_PROBES_FOR_EMPTY_BUCKET = 12*NeighborhoodSize;

    // Fraction of bucket_count() used by max_load_factor(float) to compute
    // m_min_load_threshold_rehash.
    static constexpr float MIN_LOAD_FACTOR_FOR_REHASH = 0.2f;
/**
 * Tell whether the hash stored in a bucket may be reused on rehash.
 *
 * We can only use the stored hash on rehash if the size of the hash type is the same as the
 * stored one or if we use a power of two modulo. In the case of the power of two modulo, we
 * just mask the least significant bits, so we only have to check that truncated_hash_type
 * didn't truncate too many bits.
 *
 * This overload is selected when size_type and truncated_hash_type are the same type: the
 * answer is then simply StoreHash.
 */
template<class T = size_type, typename std::enable_if<std::is_same<T, truncated_hash_type>::value>::type* = nullptr>
static bool USE_STORED_HASH_ON_REHASH(size_type /*bucket_count*/) {
    return StoreHash;
}
/**
 * Tell whether the hash stored in a bucket may be reused on rehash.
 *
 * This overload is selected when size_type and truncated_hash_type differ. The stored hash is
 * usable only if hashes are stored, the growth policy is a power of two policy and
 * bucket_count - 1 fits in truncated_hash_type.
 */
template<class T = size_type, typename std::enable_if<!std::is_same<T, truncated_hash_type>::value>::type* = nullptr>
static bool USE_STORED_HASH_ON_REHASH(size_type bucket_count) {
    (void) bucket_count;

    if(!(StoreHash && is_power_of_two_policy<GrowthPolicy>::value)) {
        return false;
    }

    tsl_hh_assert(bucket_count > 0);
    return (bucket_count - 1) <= std::numeric_limits<truncated_hash_type>::max();
}
/**
 * Return an always valid pointer to a static empty hopscotch_bucket.
 * The bucket is a function-local static, default constructed on first use.
 */
hopscotch_bucket* static_empty_bucket_ptr() {
    static hopscotch_bucket empty_bucket;

    hopscotch_bucket* const empty_bucket_ptr = &empty_bucket;
    return empty_bucket_ptr;
}
private:
/**
* Bucket array. When not empty it holds NeighborhoodSize - 1 more slots than bucket_count() reports.
*/
buckets_container_type
m_buckets_data
;
/**
* Elements that insert_value() could not place in the neighborhood of their home bucket.
*/
overflow_container_type
m_overflow_elements
;
/**
* Points to m_buckets_data.data() if !m_buckets_data.empty(), otherwise points to static_empty_bucket_ptr.
* This variable is useful to avoid the cost of checking if m_buckets_data is empty when trying
* to find an element.
*
* TODO Remove m_buckets_data and only use a pointer+size instead of a pointer+vector to save some space in the hopscotch_hash object.
*/
hopscotch_bucket
*
m_buckets
;
/**
* Total number of elements: those stored in buckets plus those stored in m_overflow_elements.
*/
size_type
m_nb_elements
;
/**
* Maximum load factor; max_load_factor(float) clamps it before storing it here.
*/
float
m_max_load_factor
;
/**
* Max size of the hash table before a rehash occurs automatically to grow the table.
* insert_value() compares it with the number of elements stored in buckets (overflow excluded).
*/
size_type
m_max_load_threshold_rehash
;
/**
* Min size of the hash table before a rehash can occur automatically (except if m_max_load_threshold_rehash is reached).
* If the neighborhood of a bucket is full before the min is reached, the elements are put into m_overflow_elements.
*
* NOTE(review): the check using this threshold is commented out in insert_value(); confirm whether it is still read anywhere.
*/
size_type
m_min_load_threshold_rehash
;
};
}
// end namespace detail_hopscotch_hash
}
// end namespace tsl
#endif
paddle/fluid/feed/src/common/hopscotch_map.h
0 → 100755
浏览文件 @
0dc7d425
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_HOPSCOTCH_MAP_H
#define TSL_HOPSCOTCH_MAP_H
#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <list>
#include <memory>
#include <type_traits>
#include <utility>
#include "paddle/fluid/feed/src/common/hopscotch_hash.h"
namespace
tsl
{
/**
* Implementation of a hash map using the hopscotch hashing algorithm.
*
* The Key and the value T must be either nothrow move-constructible, copy-constructible or both.
*
* The size of the neighborhood (NeighborhoodSize) must be > 0 and <= 62 if StoreHash is false.
* When StoreHash is true, 32-bits of the hash will be stored alongside the neighborhood limiting
* the NeighborhoodSize to <= 30. There is no memory usage difference between
* 'NeighborhoodSize 62; StoreHash false' and 'NeighborhoodSize 30; StoreHash true'.
*
* Storing the hash may improve performance on insert during the rehash process if the hash takes time
* to compute. It may also improve read performance if the KeyEqual function takes time (or incurs a cache-miss).
* If used with simple Hash and KeyEqual it may slow things down.
*
* StoreHash can only be set if the GrowthPolicy is set to tsl::power_of_two_growth_policy.
*
* GrowthPolicy defines how the map grows and consequently how a hash value is mapped to a bucket.
* By default the map uses tsl::power_of_two_growth_policy. This policy keeps the number of buckets
* to a power of two and uses a mask to map the hash to a bucket instead of the slow modulo.
* You may define your own growth policy, check tsl::power_of_two_growth_policy for the interface.
*
* If the destructors of Key or T throw an exception, behaviour of the class is undefined.
*
* Iterators invalidation:
* - clear, operator=, reserve, rehash: always invalidate the iterators.
* - insert, emplace, emplace_hint, operator[]: if there is an effective insert, invalidate the iterators
* if a displacement is needed to resolve a collision (which means that most of the time,
* insert will invalidate the iterators). Or if there is a rehash.
* - erase: the iterator on the erased element is the only one which becomes invalid.
*/
template
<
class
Key
,
class
T
,
class
Hash
=
std
::
hash
<
Key
>,
class
KeyEqual
=
std
::
equal_to
<
Key
>
,
class
Allocator
=
std
::
allocator
<
std
::
pair
<
Key
,
T
>>
,
unsigned
int
NeighborhoodSize
=
62
,
bool
StoreHash
=
false
,
class
GrowthPolicy
=
tsl
::
hh
::
power_of_two_growth_policy
<
2
>>
class
hopscotch_map
{
private:
template
<
typename
U
>
using
has_is_transparent
=
tsl
::
detail_hopscotch_hash
::
has_is_transparent
<
U
>
;
/**
 * Functor returning the key, i.e. the `first` member, of a std::pair<Key, T> element.
 * It is given to the underlying hopscotch_hash as its KeySelect template argument.
 */
class KeySelect {
public:
    using key_type = Key;

    // Read-only access to the key of a const element.
    const key_type& operator()(const std::pair<Key, T>& key_value) const {
        return key_value.first;
    }

    // Mutable access to the key of a non-const element.
    key_type& operator()(std::pair<Key, T>& key_value) {
        return key_value.first;
    }
};
/**
 * Functor returning the mapped value, i.e. the `second` member, of a std::pair<Key, T> element.
 * It is given to the underlying hopscotch_hash as its ValueSelect template argument.
 */
class ValueSelect {
public:
    using value_type = T;

    // Read-only access to the mapped value of a const element.
    const value_type& operator()(const std::pair<Key, T>& key_value) const {
        return key_value.second;
    }

    // Mutable access to the mapped value of a non-const element.
    value_type& operator()(std::pair<Key, T>& key_value) {
        return key_value.second;
    }
};
// List of std::pair<Key, T> using the map's Allocator; passed as the last template argument of ht.
using overflow_container_type = std::list<std::pair<Key, T>, Allocator>;

// Underlying hopscotch hash table storing std::pair<Key, T> elements.
// Every public member function of the map forwards to an instance of this type.
using ht = detail_hopscotch_hash::hopscotch_hash<std::pair<Key, T>, KeySelect, ValueSelect,
                                                 Hash, KeyEqual,
                                                 Allocator, NeighborhoodSize,
                                                 StoreHash, GrowthPolicy,
                                                 overflow_container_type>;

public:
// Public member types. All but mapped_type are re-exported from the underlying table.
using key_type = typename ht::key_type;
using mapped_type = T;
using value_type = typename ht::value_type;
using size_type = typename ht::size_type;
using difference_type = typename ht::difference_type;
using hasher = typename ht::hasher;
using key_equal = typename ht::key_equal;
using allocator_type = typename ht::allocator_type;
using reference = typename ht::reference;
using const_reference = typename ht::const_reference;
using pointer = typename ht::pointer;
using const_pointer = typename ht::const_pointer;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
/*
 * Constructors
 */

/**
 * Default constructor: delegates to the bucket-count constructor with ht::DEFAULT_INIT_BUCKETS_SIZE.
 */
hopscotch_map() : hopscotch_map(ht::DEFAULT_INIT_BUCKETS_SIZE) {
}

/**
 * Build the underlying table from the bucket count, hasher, key-equality functor and allocator,
 * using ht::DEFAULT_MAX_LOAD_FACTOR. Every other constructor declared here delegates,
 * directly or indirectly, to this one.
 */
explicit hopscotch_map(size_type bucket_count,
                       const Hash& hash = Hash(),
                       const KeyEqual& equal = KeyEqual(),
                       const Allocator& alloc = Allocator()) :
        m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR)
{
}

// Allocator-aware variants: the omitted Hash and/or KeyEqual are default-constructed.
hopscotch_map(size_type bucket_count,
              const Allocator& alloc) : hopscotch_map(bucket_count, Hash(), KeyEqual(), alloc)
{
}

hopscotch_map(size_type bucket_count,
              const Hash& hash,
              const Allocator& alloc) : hopscotch_map(bucket_count, hash, KeyEqual(), alloc)
{
}

explicit hopscotch_map(const Allocator& alloc) : hopscotch_map(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {
}

/**
 * Range constructor: creates an empty map with the given parameters,
 * then inserts the elements of [first, last).
 */
template<class InputIt>
hopscotch_map(InputIt first, InputIt last,
              size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
              const Hash& hash = Hash(),
              const KeyEqual& equal = KeyEqual(),
              const Allocator& alloc = Allocator()) : hopscotch_map(bucket_count, hash, equal, alloc)
{
    insert(first, last);
}

// Range variants with default-constructed Hash and/or KeyEqual.
template<class InputIt>
hopscotch_map(InputIt first, InputIt last,
              size_type bucket_count,
              const Allocator& alloc) : hopscotch_map(first, last, bucket_count, Hash(), KeyEqual(), alloc)
{
}

template<class InputIt>
hopscotch_map(InputIt first, InputIt last,
              size_type bucket_count,
              const Hash& hash,
              const Allocator& alloc) : hopscotch_map(first, last, bucket_count, hash, KeyEqual(), alloc)
{
}

/**
 * Initializer-list constructor: delegates to the range constructor with init.begin() / init.end().
 */
hopscotch_map(std::initializer_list<value_type> init,
              size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
              const Hash& hash = Hash(),
              const KeyEqual& equal = KeyEqual(),
              const Allocator& alloc = Allocator()) :
        hopscotch_map(init.begin(), init.end(), bucket_count, hash, equal, alloc)
{
}

// Initializer-list variants with default-constructed Hash and/or KeyEqual.
hopscotch_map(std::initializer_list<value_type> init,
              size_type bucket_count,
              const Allocator& alloc) :
        hopscotch_map(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc)
{
}

hopscotch_map(std::initializer_list<value_type> init,
              size_type bucket_count,
              const Hash& hash,
              const Allocator& alloc) :
        hopscotch_map(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc)
{
}
/**
 * Replace the content of the map with the elements of `ilist`.
 *
 * The table is cleared, then reserve() is called with ilist.size() before
 * the elements are inserted.
 */
hopscotch_map& operator=(std::initializer_list<value_type> ilist) {
    m_ht.clear();

    m_ht.reserve(ilist.size());
    m_ht.insert(ilist.begin(), ilist.end());

    return *this;
}

// Returns the allocator reported by the underlying table.
allocator_type get_allocator() const { return m_ht.get_allocator(); }
/*
 * Iterators
 *
 * All of these forward to the member of the same name on the underlying table.
 */
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }

iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
 * Capacity
 *
 * All of these forward to the member of the same name on the underlying table.
 */
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
/*
 * Modifiers
 */
// Forwards to clear() on the underlying table.
void clear() noexcept { m_ht.clear(); }

// Insert a copy of `value`; the result comes straight from the underlying table.
std::pair<iterator, bool> insert(const value_type& value) {
    return m_ht.insert(value);
}

// Generic overload, enabled only when value_type is constructible from P&&;
// `value` is perfect-forwarded to the underlying table.
template<class P,
         typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
std::pair<iterator, bool> insert(P&& value) {
    return m_ht.insert(std::forward<P>(value));
}

// Rvalue overload: `value` is passed on with std::move.
std::pair<iterator, bool> insert(value_type&& value) {
    return m_ht.insert(std::move(value));
}

// Hinted overloads: same as above with `hint` passed through to the underlying table.
iterator insert(const_iterator hint, const value_type& value) {
    return m_ht.insert(hint, value);
}

template<class P,
         typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
iterator insert(const_iterator hint, P&& value) {
    return m_ht.insert(hint, std::forward<P>(value));
}

iterator insert(const_iterator hint, value_type&& value) {
    return m_ht.insert(hint, std::move(value));
}

// Insert the elements of the range [first, last).
template<class InputIt>
void insert(InputIt first, InputIt last) {
    m_ht.insert(first, last);
}

// Insert the elements of `ilist` through the range overload of the underlying table.
void insert(std::initializer_list<value_type> ilist) {
    m_ht.insert(ilist.begin(), ilist.end());
}
// insert_or_assign overloads: all forward to insert_or_assign() on the underlying table.
// The key is passed by const reference or with std::move depending on the overload,
// and `obj` is always perfect-forwarded.
template<class M>
std::pair<iterator, bool> insert_or_assign(const key_type& k, M&& obj) {
    return m_ht.insert_or_assign(k, std::forward<M>(obj));
}

template<class M>
std::pair<iterator, bool> insert_or_assign(key_type&& k, M&& obj) {
    return m_ht.insert_or_assign(std::move(k), std::forward<M>(obj));
}

// Hinted variants: `hint` is passed through to the underlying table.
template<class M>
iterator insert_or_assign(const_iterator hint, const key_type& k, M&& obj) {
    return m_ht.insert_or_assign(hint, k, std::forward<M>(obj));
}

template<class M>
iterator insert_or_assign(const_iterator hint, key_type&& k, M&& obj) {
    return m_ht.insert_or_assign(hint, std::move(k), std::forward<M>(obj));
}
/**
 * Due to the way elements are stored, emplace will need to move or copy the key-value once.
 * The method is equivalent to insert(value_type(std::forward<Args>(args)...));
 *
 * Mainly here for compatibility with the std::unordered_map interface.
 * The arguments are perfect-forwarded to emplace() on the underlying table.
 */
template<class... Args>
std::pair<iterator, bool> emplace(Args&&... args) {
    return m_ht.emplace(std::forward<Args>(args)...);
}

/**
 * Due to the way elements are stored, emplace_hint will need to move or copy the key-value once.
 * The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
 *
 * Mainly here for compatibility with the std::unordered_map interface.
 * `hint` and the perfect-forwarded arguments go to emplace_hint() on the underlying table.
 */
template<class... Args>
iterator emplace_hint(const_iterator hint, Args&&... args) {
    return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
}
// try_emplace overloads: all forward to try_emplace() on the underlying table.
// The key is passed by const reference or with std::move depending on the overload,
// and the remaining arguments are perfect-forwarded.
template<class... Args>
std::pair<iterator, bool> try_emplace(const key_type& k, Args&&... args) {
    return m_ht.try_emplace(k, std::forward<Args>(args)...);
}

template<class... Args>
std::pair<iterator, bool> try_emplace(key_type&& k, Args&&... args) {
    return m_ht.try_emplace(std::move(k), std::forward<Args>(args)...);
}

// Hinted variants: `hint` is passed through to the underlying table.
template<class... Args>
iterator try_emplace(const_iterator hint, const key_type& k, Args&&... args) {
    return m_ht.try_emplace(hint, k, std::forward<Args>(args)...);
}

template<class... Args>
iterator try_emplace(const_iterator hint, key_type&& k, Args&&... args) {
    return m_ht.try_emplace(hint, std::move(k), std::forward<Args>(args)...);
}
// Position and range overloads: forward to erase() on the underlying table.
iterator erase(iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); }

// Key overload: forwards to erase(key) on the underlying table and returns its size_type result.
size_type erase(const key_type& key) { return m_ht.erase(key); }

/**
 * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed-up the lookup to the value if you already have the hash.
 */
size_type erase(const key_type& key, std::size_t precalculated_hash) {
    return m_ht.erase(key, precalculated_hash);
}

/**
 * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
 * If so, K must be hashable and comparable to Key.
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type erase(const K& key) { return m_ht.erase(key); }

/**
 * @copydoc erase(const K& key)
 *
 * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed-up the lookup to the value if you already have the hash.
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type erase(const K& key, std::size_t precalculated_hash) {
    return m_ht.erase(key, precalculated_hash);
}
/**
 * Exchange the content of this map with `other` by swapping the two underlying tables.
 */
void swap(hopscotch_map& other) { other.m_ht.swap(m_ht); }
/*
 * Lookup
 */
// Forwards to at() on the underlying table and returns a reference to the mapped value.
T& at(const Key& key) { return m_ht.at(key); }

/**
 * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
 */
T& at(const Key& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); }

// Const counterpart of at(const Key& key).
const T& at(const Key& key) const { return m_ht.at(key); }

/**
 * @copydoc at(const Key& key, std::size_t precalculated_hash)
 */
const T& at(const Key& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); }

/**
 * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
 * If so, K must be hashable and comparable to Key.
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
T& at(const K& key) { return m_ht.at(key); }

/**
 * @copydoc at(const K& key)
 *
 * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
T& at(const K& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); }

/**
 * @copydoc at(const K& key)
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const T& at(const K& key) const { return m_ht.at(key); }

/**
 * @copydoc at(const K& key, std::size_t precalculated_hash)
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const T& at(const K& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); }
// Forwards to operator[] of the underlying table with the key taken by const reference.
T& operator[](const Key& key) { return m_ht[key]; }
// Same, with the key passed on as an rvalue through std::move.
T& operator[](Key&& key) { return m_ht[std::move(key)]; }
// Forwards to count() on the underlying table.
size_type count(const Key& key) const { return m_ht.count(key); }

/**
 * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
 */
size_type count(const Key& key, std::size_t precalculated_hash) const {
    return m_ht.count(key, precalculated_hash);
}

/**
 * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
 * If so, K must be hashable and comparable to Key.
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type count(const K& key) const { return m_ht.count(key); }

/**
 * @copydoc count(const K& key) const
 *
 * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type count(const K& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
// Forwards to find() on the underlying table and returns the iterator it produces.
iterator find(const Key& key) { return m_ht.find(key); }

/**
 * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
 */
iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }

// Const counterpart of find(const Key& key).
const_iterator find(const Key& key) const { return m_ht.find(key); }

/**
 * @copydoc find(const Key& key, std::size_t precalculated_hash)
 */
const_iterator find(const Key& key, std::size_t precalculated_hash) const {
    return m_ht.find(key, precalculated_hash);
}

/**
 * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
 * If so, K must be hashable and comparable to Key.
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
iterator find(const K& key) { return m_ht.find(key); }

/**
 * @copydoc find(const K& key)
 *
 * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }

/**
 * @copydoc find(const K& key)
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const_iterator find(const K& key) const { return m_ht.find(key); }

/**
 * @copydoc find(const K& key)
 *
 * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const_iterator find(const K& key, std::size_t precalculated_hash) const {
    return m_ht.find(key, precalculated_hash);
}
// Forwards to equal_range() on the underlying table and returns the iterator pair it produces.
std::pair<iterator, iterator> equal_range(const Key& key) { return m_ht.equal_range(key); }

/**
 * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
 */
std::pair<iterator, iterator> equal_range(const Key& key, std::size_t precalculated_hash) {
    return m_ht.equal_range(key, precalculated_hash);
}

// Const counterpart of equal_range(const Key& key).
std::pair<const_iterator, const_iterator> equal_range(const Key& key) const { return m_ht.equal_range(key); }

/**
 * @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
 */
std::pair<const_iterator, const_iterator> equal_range(const Key& key, std::size_t precalculated_hash) const {
    return m_ht.equal_range(key, precalculated_hash);
}

/**
 * This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
 * If so, K must be hashable and comparable to Key.
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key) { return m_ht.equal_range(key); }

/**
 * @copydoc equal_range(const K& key)
 *
 * Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed-up the lookup if you already have the hash.
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key, std::size_t precalculated_hash) {
    return m_ht.equal_range(key, precalculated_hash);
}

/**
 * @copydoc equal_range(const K& key)
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key) const { return m_ht.equal_range(key); }

/**
 * @copydoc equal_range(const K& key, std::size_t precalculated_hash)
 */
template<class K, class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t precalculated_hash) const {
    return m_ht.equal_range(key, precalculated_hash);
}
/*
 * Bucket interface
 *
 * Both accessors forward to the member of the same name on the underlying table.
 */
size_type bucket_count() const { return m_ht.bucket_count(); }
size_type max_bucket_count() const { return m_ht.max_bucket_count(); }

/*
 * Hash policy
 *
 * Getters and setters forwarding to the member of the same name on the underlying table.
 */
float load_factor() const { return m_ht.load_factor(); }
float max_load_factor() const { return m_ht.max_load_factor(); }
// Setter overload: hands `ml` to the underlying table.
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }

// `count_` is passed unchanged to rehash() / reserve() on the underlying table.
void rehash(size_type count_) { m_ht.rehash(count_); }
void reserve(size_type count_) { m_ht.reserve(count_); }
/*
 * Observers
 *
 * Return, by value, the hasher and key-equality functor reported by the underlying table.
 */
hasher hash_function() const { return m_ht.hash_function(); }
key_equal key_eq() const { return m_ht.key_eq(); }

/*
 * Other
 */

/**
 * Convert a const_iterator to an iterator.
 */
iterator mutable_iterator(const_iterator pos) {
    return m_ht.mutable_iterator(pos);
}

// Forwards to overflow_size() on the underlying table.
size_type overflow_size() const noexcept { return m_ht.overflow_size(); }
/**
 * Two maps are equal when they have the same size and every (key, value) of `lhs`
 * has an entry in `rhs` with the same key whose mapped value does not compare `!=`.
 */
friend bool operator==(const hopscotch_map& lhs, const hopscotch_map& rhs) {
    if(lhs.size() == rhs.size()) {
        for(const auto& entry: lhs) {
            const auto match = rhs.find(entry.first);

            // Key missing on the right-hand side.
            if(match == rhs.cend()) {
                return false;
            }

            // Key present but mapped values differ.
            if(entry.second != match->second) {
                return false;
            }
        }

        return true;
    }

    return false;
}
/**
 * Negation of operator== for two maps.
 */
friend bool operator!=(const hopscotch_map& lhs, const hopscotch_map& rhs) {
    return !(lhs == rhs);
}
/**
 * Non-member swap: calls lhs.swap(rhs).
 */
friend void swap(hopscotch_map& lhs, hopscotch_map& rhs) {
    lhs.swap(rhs);
}
private:
ht
m_ht
;
};
/**
 * Same as `tsl::hopscotch_map<Key, T, Hash, KeyEqual, Allocator, NeighborhoodSize, StoreHash, tsl::hh::prime_growth_policy>`.
 *
 * Convenience alias: the template parameters and their defaults mirror hopscotch_map,
 * except that the growth policy is fixed to tsl::hh::prime_growth_policy.
 */
template<class Key,
         class T,
         class Hash = std::hash<Key>,
         class KeyEqual = std::equal_to<Key>,
         class Allocator = std::allocator<std::pair<Key, T>>,
         unsigned int NeighborhoodSize = 62,
         bool StoreHash = false>
using hopscotch_pg_map = hopscotch_map<Key, T, Hash, KeyEqual, Allocator, NeighborhoodSize, StoreHash, tsl::hh::prime_growth_policy>;
}
// end namespace tsl
#endif
paddle/fluid/feed/src/common/hopscotch_set.h
0 → 100755
浏览文件 @
0dc7d425
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_HOPSCOTCH_SET_H
#define TSL_HOPSCOTCH_SET_H
#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <list>
#include <memory>
#include <type_traits>
#include <utility>
#include "paddle/fluid/feed/src/common/hopscotch_hash.h"
namespace
tsl
{
/**
* Implementation of a hash set using the hopscotch hashing algorithm.
*
* The Key must be either nothrow move-constructible, copy-constructible or both.
*
* The size of the neighborhood (NeighborhoodSize) must be > 0 and <= 62 if StoreHash is false.
* When StoreHash is true, 32-bits of the hash will be stored alongside the neighborhood limiting
* the NeighborhoodSize to <= 30. There is no memory usage difference between
* 'NeighborhoodSize 62; StoreHash false' and 'NeighborhoodSize 30; StoreHash true'.
*
* Storing the hash may improve performance on insert during the rehash process if the hash takes time
* to compute. It may also improve read performance if the KeyEqual function takes time (or incurs a cache-miss).
* If used with simple Hash and KeyEqual it may slow things down.
*
* StoreHash can only be set if the GrowthPolicy is set to tsl::power_of_two_growth_policy.
*
* GrowthPolicy defines how the set grows and consequently how a hash value is mapped to a bucket.
* By default the set uses tsl::power_of_two_growth_policy. This policy keeps the number of buckets
* to a power of two and uses a mask to map the hash to a bucket instead of the slow modulo.
* You may define your own growth policy, check tsl::power_of_two_growth_policy for the interface.
*
* If the destructor of Key throws an exception, behaviour of the class is undefined.
*
* Iterators invalidation:
* - clear, operator=, reserve, rehash: always invalidate the iterators.
* - insert, emplace, emplace_hint: if there is an effective insert, invalidate the iterators
* if a displacement is needed to resolve a collision (which means that most of the time,
* insert will invalidate the iterators). Or if there is a rehash.
* - erase: the iterator on the erased element is the only one which becomes invalid.
*/
template
<
class
Key
,
class
Hash
=
std
::
hash
<
Key
>,
class
KeyEqual
=
std
::
equal_to
<
Key
>
,
class
Allocator
=
std
::
allocator
<
Key
>
,
unsigned
int
NeighborhoodSize
=
62
,
bool
StoreHash
=
false
,
class
GrowthPolicy
=
tsl
::
hh
::
power_of_two_growth_policy
<
2
>>
class
hopscotch_set
{
private:
template
<
typename
U
>
using
has_is_transparent
=
tsl
::
detail_hopscotch_hash
::
has_is_transparent
<
U
>
;
/**
 * Identity functor: in a set the stored element is the key itself, so the
 * argument is returned unchanged. It is given to the underlying hopscotch_hash
 * as its KeySelect template argument.
 */
class KeySelect {
public:
    using key_type = Key;

    // Read-only access for a const element.
    const key_type& operator()(const Key& key) const {
        return key;
    }

    // Mutable access for a non-const element.
    key_type& operator()(Key& key) {
        return key;
    }
};
// List of Key using the set's Allocator; passed as the last template argument of ht.
using overflow_container_type = std::list<Key, Allocator>;

// Underlying hopscotch hash table storing Key elements. `void` is passed where
// hopscotch_map passes its ValueSelect, since a set has no mapped value.
using ht = detail_hopscotch_hash::hopscotch_hash<Key, KeySelect, void,
                                                 Hash, KeyEqual,
                                                 Allocator, NeighborhoodSize,
                                                 StoreHash, GrowthPolicy,
                                                 overflow_container_type>;

public:
// Public member types, all re-exported from the underlying table.
using key_type = typename ht::key_type;
using value_type = typename ht::value_type;
using size_type = typename ht::size_type;
using difference_type = typename ht::difference_type;
using hasher = typename ht::hasher;
using key_equal = typename ht::key_equal;
using allocator_type = typename ht::allocator_type;
using reference = typename ht::reference;
using const_reference = typename ht::const_reference;
using pointer = typename ht::pointer;
using const_pointer = typename ht::const_pointer;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
/*
 * Constructors
 */

/**
 * Default constructor: delegates to the bucket-count constructor with ht::DEFAULT_INIT_BUCKETS_SIZE.
 */
hopscotch_set() : hopscotch_set(ht::DEFAULT_INIT_BUCKETS_SIZE) {
}

/**
 * Build the underlying table from the bucket count, hasher, key-equality functor and allocator,
 * using ht::DEFAULT_MAX_LOAD_FACTOR. Every other constructor declared here delegates,
 * directly or indirectly, to this one.
 */
explicit hopscotch_set(size_type bucket_count,
                       const Hash& hash = Hash(),
                       const KeyEqual& equal = KeyEqual(),
                       const Allocator& alloc = Allocator()) :
        m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR)
{
}

// Allocator-aware variants: the omitted Hash and/or KeyEqual are default-constructed.
hopscotch_set(size_type bucket_count,
              const Allocator& alloc) : hopscotch_set(bucket_count, Hash(), KeyEqual(), alloc)
{
}

hopscotch_set(size_type bucket_count,
              const Hash& hash,
              const Allocator& alloc) : hopscotch_set(bucket_count, hash, KeyEqual(), alloc)
{
}

explicit hopscotch_set(const Allocator& alloc) : hopscotch_set(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {
}

/**
 * Range constructor: creates an empty set with the given parameters,
 * then inserts the elements of [first, last).
 */
template<class InputIt>
hopscotch_set(InputIt first, InputIt last,
              size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
              const Hash& hash = Hash(),
              const KeyEqual& equal = KeyEqual(),
              const Allocator& alloc = Allocator()) : hopscotch_set(bucket_count, hash, equal, alloc)
{
    insert(first, last);
}

// Range variants with default-constructed Hash and/or KeyEqual.
template<class InputIt>
hopscotch_set(InputIt first, InputIt last,
              size_type bucket_count,
              const Allocator& alloc) : hopscotch_set(first, last, bucket_count, Hash(), KeyEqual(), alloc)
{
}

template<class InputIt>
hopscotch_set(InputIt first, InputIt last,
              size_type bucket_count,
              const Hash& hash,
              const Allocator& alloc) : hopscotch_set(first, last, bucket_count, hash, KeyEqual(), alloc)
{
}

/**
 * Initializer-list constructor: delegates to the range constructor with init.begin() / init.end().
 */
hopscotch_set(std::initializer_list<value_type> init,
              size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
              const Hash& hash = Hash(),
              const KeyEqual& equal = KeyEqual(),
              const Allocator& alloc = Allocator()) :
        hopscotch_set(init.begin(), init.end(), bucket_count, hash, equal, alloc)
{
}

// Initializer-list variants with default-constructed Hash and/or KeyEqual.
hopscotch_set(std::initializer_list<value_type> init,
              size_type bucket_count,
              const Allocator& alloc) :
        hopscotch_set(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc)
{
}

hopscotch_set(std::initializer_list<value_type> init,
              size_type bucket_count,
              const Hash& hash,
              const Allocator& alloc) :
        hopscotch_set(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc)
{
}
/**
 * Replace the content of the set with the elements of `ilist`.
 *
 * The table is cleared, then reserve() is called with ilist.size() before
 * the elements are inserted.
 */
hopscotch_set& operator=(std::initializer_list<value_type> ilist) {
    m_ht.clear();

    m_ht.reserve(ilist.size());
    m_ht.insert(ilist.begin(), ilist.end());

    return *this;
}

// Returns the allocator reported by the underlying table.
allocator_type get_allocator() const { return m_ht.get_allocator(); }
/*
 * Iterators
 *
 * All of these forward to the member of the same name on the underlying table.
 */
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }

iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
 * Capacity
 *
 * All of these forward to the member of the same name on the underlying table.
 */
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
/*
* Modifiers
*/
void
clear
()
noexcept
{
m_ht
.
clear
();
}
std
::
pair
<
iterator
,
bool
>
insert
(
const
value_type
&
value
)
{
return
m_ht
.
insert
(
value
);
}
std
::
pair
<
iterator
,
bool
>
insert
(
value_type
&&
value
)
{
return
m_ht
.
insert
(
std
::
move
(
value
));
}
iterator
insert
(
const_iterator
hint
,
const
value_type
&
value
)
{
return
m_ht
.
insert
(
hint
,
value
);
}
iterator
insert
(
const_iterator
hint
,
value_type
&&
value
)
{
return
m_ht
.
insert
(
hint
,
std
::
move
(
value
));
}
/** Range insert: hands the iterator pair [first, last) to m_ht.insert(). */
template<class InputIt>
void insert(InputIt first, InputIt last) {
    m_ht.insert(first, last);
}
/** Inserts every element of `ilist` by passing its begin/end range to m_ht.insert(). */
void insert(std::initializer_list<value_type> ilist) {
    m_ht.insert(ilist.begin(), ilist.end());
}
/**
 * Due to the way elements are stored, emplace will need to move or copy the value once.
 * The method is equivalent to insert(value_type(std::forward<Args>(args)...));
 *
 * Mainly provided for interface compatibility with the standard unordered containers.
 */
template<class... Args>
std::pair<iterator, bool> emplace(Args&&... args) {
    return m_ht.emplace(std::forward<Args>(args)...);
}
/**
 * Due to the way elements are stored, emplace_hint will need to move or copy the value once.
 * The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
 *
 * Mainly provided for interface compatibility with the standard unordered containers.
 */
template<class... Args>
iterator emplace_hint(const_iterator hint, Args&&... args) {
    return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
}
/** Erase by mutable iterator; forwards `pos` to m_ht.erase() and returns its result. */
iterator erase(iterator pos) {
    return m_ht.erase(pos);
}
/** Erase by const iterator; forwards `pos` to m_ht.erase() and returns its result. */
iterator erase(const_iterator pos) {
    return m_ht.erase(pos);
}
/** Range erase; forwards the iterator pair (first, last) to m_ht.erase(). */
iterator erase(const_iterator first, const_iterator last) {
    return m_ht.erase(first, last);
}
/** Erase by key; returns the count reported by m_ht.erase(key). */
size_type erase(const key_type& key) {
    return m_ht.erase(key);
}
/**
 * Erase by key using the hash value 'precalculated_hash' instead of hashing the key.
 * The hash value should be the same as hash_function()(key). Useful to speed up the
 * lookup to the value if you already have the hash.
 */
size_type erase(const key_type& key, std::size_t precalculated_hash) {
    return m_ht.erase(key, precalculated_hash);
}
/**
 * Heterogeneous-key erase. This overload only participates in the overload resolution
 * if the typedef KeyEqual::is_transparent exists. If so, K must be hashable and
 * comparable to Key.
 */
template<class K,
         class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type erase(const K& key) {
    return m_ht.erase(key);
}
/**
 * Heterogeneous-key erase with a caller-supplied hash. Only participates in the overload
 * resolution if the typedef KeyEqual::is_transparent exists.
 *
 * Uses the hash value 'precalculated_hash' instead of hashing the key. The hash value
 * should be the same as hash_function()(key). Useful to speed up the lookup to the value
 * if you already have the hash.
 */
template<class K,
         class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type erase(const K& key, std::size_t precalculated_hash) {
    return m_ht.erase(key, precalculated_hash);
}
/** Exchanges the underlying hash tables of this set and `other`. */
void swap(hopscotch_set& other) {
    other.m_ht.swap(m_ht);
}
/*
* Lookup
*/
/** Forwards `key` to m_ht.count() and returns its result. */
size_type count(const Key& key) const {
    return m_ht.count(key);
}
/**
 * Same as count(key) but uses the hash value 'precalculated_hash' instead of hashing
 * the key. The hash value should be the same as hash_function()(key). Useful to speed up
 * the lookup if you already have the hash.
 */
size_type count(const Key& key, std::size_t precalculated_hash) const {
    return m_ht.count(key, precalculated_hash);
}
/**
 * Heterogeneous-key count. This overload only participates in the overload resolution
 * if the typedef KeyEqual::is_transparent exists. If so, K must be hashable and
 * comparable to Key.
 */
template<class K,
         class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type count(const K& key) const {
    return m_ht.count(key);
}
/**
 * Heterogeneous-key count with a caller-supplied hash. Only participates in the overload
 * resolution if the typedef KeyEqual::is_transparent exists.
 *
 * Uses the hash value 'precalculated_hash' instead of hashing the key. The hash value
 * should be the same as hash_function()(key). Useful to speed up the lookup if you
 * already have the hash.
 */
template<class K,
         class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type count(const K& key, std::size_t precalculated_hash) const {
    return m_ht.count(key, precalculated_hash);
}
/** Looks `key` up through m_ht.find() and returns the resulting mutable iterator. */
iterator find(const Key& key) {
    return m_ht.find(key);
}
/**
 * Same as find(key) but uses the hash value 'precalculated_hash' instead of hashing
 * the key. The hash value should be the same as hash_function()(key). Useful to speed up
 * the lookup if you already have the hash.
 */
iterator find(const Key& key, std::size_t precalculated_hash) {
    return m_ht.find(key, precalculated_hash);
}
/** Const lookup: forwards `key` to m_ht.find() and returns a const_iterator. */
const_iterator find(const Key& key) const {
    return m_ht.find(key);
}
/**
 * Const lookup that uses the hash value 'precalculated_hash' instead of hashing the key.
 * The hash value should be the same as hash_function()(key).
 */
const_iterator find(const Key& key, std::size_t precalculated_hash) const {
    return m_ht.find(key, precalculated_hash);
}
/**
 * Heterogeneous-key find. This overload only participates in the overload resolution
 * if the typedef KeyEqual::is_transparent exists. If so, K must be hashable and
 * comparable to Key.
 */
template<class K,
         class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
iterator find(const K& key) {
    return m_ht.find(key);
}
/**
 * Heterogeneous-key find with a caller-supplied hash. Only participates in the overload
 * resolution if the typedef KeyEqual::is_transparent exists.
 *
 * Uses the hash value 'precalculated_hash' instead of hashing the key. The hash value
 * should be the same as hash_function()(key). Useful to speed up the lookup if you
 * already have the hash.
 */
template<class K,
         class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
iterator find(const K& key, std::size_t precalculated_hash) {
    return m_ht.find(key, precalculated_hash);
}
/**
 * Const heterogeneous-key find. Only participates in the overload resolution if the
 * typedef KeyEqual::is_transparent exists. If so, K must be hashable and comparable to Key.
 */
template<class K,
         class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const_iterator find(const K& key) const {
    return m_ht.find(key);
}
/**
 * Const heterogeneous-key find with a caller-supplied hash. Only participates in the
 * overload resolution if the typedef KeyEqual::is_transparent exists.
 *
 * Uses the hash value 'precalculated_hash' instead of hashing the key. The hash value
 * should be the same as hash_function()(key). Useful to speed up the lookup if you
 * already have the hash.
 */
template<class K,
         class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const_iterator find(const K& key, std::size_t precalculated_hash) const {
    return m_ht.find(key, precalculated_hash);
}
/** Forwards `key` to m_ht.equal_range() and returns its pair of mutable iterators. */
std::pair<iterator, iterator> equal_range(const Key& key) {
    return m_ht.equal_range(key);
}
/**
 * Same as equal_range(key) but uses the hash value 'precalculated_hash' instead of
 * hashing the key. The hash value should be the same as hash_function()(key). Useful to
 * speed up the lookup if you already have the hash.
 */
std::pair<iterator, iterator> equal_range(const Key& key, std::size_t precalculated_hash) {
    return m_ht.equal_range(key, precalculated_hash);
}
/** Const variant: forwards `key` to m_ht.equal_range() and returns a pair of const_iterators. */
std::pair<const_iterator, const_iterator> equal_range(const Key& key) const {
    return m_ht.equal_range(key);
}
/**
 * Const variant that uses the hash value 'precalculated_hash' instead of hashing the key.
 * The hash value should be the same as hash_function()(key).
 */
std::pair<const_iterator, const_iterator> equal_range(const Key& key, std::size_t precalculated_hash) const {
    return m_ht.equal_range(key, precalculated_hash);
}
/**
 * Heterogeneous-key equal_range. This overload only participates in the overload
 * resolution if the typedef KeyEqual::is_transparent exists. If so, K must be hashable
 * and comparable to Key.
 */
template<class K,
         class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key) {
    return m_ht.equal_range(key);
}
/**
 * Heterogeneous-key equal_range with a caller-supplied hash. Only participates in the
 * overload resolution if the typedef KeyEqual::is_transparent exists.
 *
 * Uses the hash value 'precalculated_hash' instead of hashing the key. The hash value
 * should be the same as hash_function()(key). Useful to speed up the lookup if you
 * already have the hash.
 */
template<class K,
         class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key, std::size_t precalculated_hash) {
    return m_ht.equal_range(key, precalculated_hash);
}
/**
 * Const heterogeneous-key equal_range. Only participates in the overload resolution if
 * the typedef KeyEqual::is_transparent exists. If so, K must be hashable and comparable
 * to Key.
 */
template<class K,
         class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key) const {
    return m_ht.equal_range(key);
}
/**
 * Const heterogeneous-key equal_range with a caller-supplied hash. Only participates in
 * the overload resolution if the typedef KeyEqual::is_transparent exists.
 *
 * Uses the hash value 'precalculated_hash' instead of hashing the key. The hash value
 * should be the same as hash_function()(key).
 */
template<class K,
         class KE = KeyEqual,
         typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t precalculated_hash) const {
    return m_ht.equal_range(key, precalculated_hash);
}
/*
* Bucket interface
*/
/** Forwards to m_ht.bucket_count(). */
size_type bucket_count() const {
    return m_ht.bucket_count();
}
/** Forwards to m_ht.max_bucket_count(). */
size_type max_bucket_count() const {
    return m_ht.max_bucket_count();
}
/*
* Hash policy
*/
/** Forwards to m_ht.load_factor(). */
float load_factor() const {
    return m_ht.load_factor();
}
/** Getter: forwards to m_ht.max_load_factor(). */
float max_load_factor() const {
    return m_ht.max_load_factor();
}
/** Setter: passes `ml` to m_ht.max_load_factor(). */
void max_load_factor(float ml) {
    m_ht.max_load_factor(ml);
}
/** Passes `count_` to m_ht.rehash(). */
void rehash(size_type count_) {
    m_ht.rehash(count_);
}
/** Passes `count_` to m_ht.reserve(). */
void reserve(size_type count_) {
    m_ht.reserve(count_);
}
/*
* Observers
*/
/** Returns the hash function object reported by m_ht.hash_function(). */
hasher hash_function() const {
    return m_ht.hash_function();
}
/** Returns the key-equality object reported by m_ht.key_eq(). */
key_equal key_eq() const {
    return m_ht.key_eq();
}
/*
* Other
*/
/**
 * Convert a const_iterator to an iterator (delegates to m_ht.mutable_iterator()).
 */
iterator mutable_iterator(const_iterator pos) {
    return m_ht.mutable_iterator(pos);
}
/** Forwards to m_ht.overflow_size(). */
size_type overflow_size() const noexcept {
    return m_ht.overflow_size();
}
/**
 * Two sets compare equal when they have the same size and every element of `lhs`
 * can be found in `rhs`.
 */
friend bool operator==(const hopscotch_set& lhs, const hopscotch_set& rhs) {
    // Different sizes can never be equal; skip the per-element lookups.
    if(lhs.size() != rhs.size()) {
        return false;
    }

    for(auto it_lhs = lhs.begin(); it_lhs != lhs.end(); ++it_lhs) {
        const bool missing_in_rhs = (rhs.find(*it_lhs) == rhs.cend());
        if(missing_in_rhs) {
            return false;
        }
    }

    return true;
}
/** Negation of operator==. */
friend bool operator!=(const hopscotch_set& lhs, const hopscotch_set& rhs) {
    const bool equal = operator==(lhs, rhs);
    return !equal;
}
/** Non-member swap; calls the member function lhs.swap(rhs). */
friend void swap(hopscotch_set& lhs, hopscotch_set& rhs) {
    lhs.swap(rhs);
}
private:
ht
m_ht
;
};
/**
 * Convenience alias: same as
 * `tsl::hopscotch_set<Key, Hash, KeyEqual, Allocator, NeighborhoodSize, StoreHash, tsl::hh::prime_growth_policy>`,
 * i.e. a hopscotch_set whose growth policy is fixed to tsl::hh::prime_growth_policy while
 * all other template parameters are passed through.
 */
template<class Key,
         class Hash = std::hash<Key>,
         class KeyEqual = std::equal_to<Key>,
         class Allocator = std::allocator<Key>,
         unsigned int NeighborhoodSize = 62,
         bool StoreHash = false>
using hopscotch_pg_set = hopscotch_set<Key,
                                       Hash,
                                       KeyEqual,
                                       Allocator,
                                       NeighborhoodSize,
                                       StoreHash,
                                       tsl::hh::prime_growth_policy>;
}
// end namespace tsl
#endif
paddle/fluid/feed/src/data_reader/CMakeLists.txt
0 → 100644
浏览文件 @
0dc7d425
cc_library
(
feed_data_set SRCS data_set.cc DEPS operator
)
paddle/fluid/feed/src/data_reader/data_set.cc
0 → 100644
浏览文件 @
0dc7d425
#include "paddle/fluid/feed/src/data_reader/data_set.h"
#include <algorithm>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/timer.h"
namespace
paddle
{
namespace
framework
{
// Rebuilds preload_readers_ with one data feed per preload thread.
//
// When preload_thread_num_ is 0 it is first set to thread_num_. Each reader is
// created from data_feed_desc_ via DataFeedFactory, initialised, and wired to
// the shared file list (plus its mutex and index) and to input_channel_. The
// output and consume channels are explicitly set to nullptr.
void FeedMultiSlotDataset::CreatePreLoadReaders() {
  VLOG(3) << "Begin CreatePreLoadReaders";

  // No explicit preload thread count: fall back to the regular thread count.
  if (preload_thread_num_ == 0) {
    preload_thread_num_ = thread_num_;
  }
  CHECK(preload_thread_num_ > 0) << "thread num should > 0";
  CHECK(input_channel_ != nullptr);

  preload_readers_.clear();
  for (int tid = 0; tid < preload_thread_num_; ++tid) {
    preload_readers_.push_back(
        DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
    // The vector was cleared above, so the reader just appended sits at
    // index `tid`.
    auto& reader = preload_readers_[tid];

    reader->Init(data_feed_desc_);
    reader->SetThreadId(tid);
    reader->SetThreadNum(preload_thread_num_);

    // All readers pick files from the same list, guarded by one mutex.
    reader->SetFileListMutex(&mutex_for_pick_file_);
    reader->SetFileListIndex(&file_idx_);
    reader->SetFileList(filelist_);

    reader->SetParseInsId(parse_ins_id_);
    reader->SetParseContent(parse_content_);

    reader->SetInputChannel(input_channel_.get());
    reader->SetOutputChannel(nullptr);
    reader->SetConsumeChannel(nullptr);
  }

  VLOG(3) << "End CreatePreLoadReaders";
}
// Merges all records that share the same ins_id_ into a single record.
//
// Does nothing unless merge_by_insid_ is set. Otherwise:
//   1. every record is drained from multi_output_channel_ and sorted by
//      ins_id_;
//   2. each run of records with an equal ins_id_ is folded into one Record
//      that takes ins_id_/content_ from the first record of the run and the
//      concatenation of the uint64 and float feasigns of all its records;
//   3. a run is dropped (and counted in the "total drop ins num" log) when
//        - min_merge_size_ > 0 and the run length differs from it, or
//        - a slot contributed by an earlier record of the run appears again
//          in a later record of the same run (a "conflict slot"). Repeats of
//          a slot inside one record are not conflicts;
//   4. the surviving records are shuffled with the fleet wrapper's local
//      random engine and handed back to the output channels, one Read() of
//      block size `Size() / channel_num_ + 1` per channel.
void FeedMultiSlotDataset::MergeByInsId() {
  VLOG(3) << "MultiSlotDataset::MergeByInsId begin";
  if (!merge_by_insid_) {
    VLOG(3) << "merge_by_insid=false, will not MergeByInsId";
    return;
  }

  // Bind by const reference: the descriptor is only read, so copying the
  // whole message is unnecessary.
  const auto& multi_slot_desc = data_feed_desc_.multi_slot_desc();
  // Names of the used slots, collected only for the conflict log message.
  std::vector<std::string> use_slots;
  // `int` index avoids a signed/unsigned comparison with slots_size().
  for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
    const auto& slot = multi_slot_desc.slots(i);
    if (slot.is_used()) {
      use_slots.push_back(slot.name());
    }
  }

  CHECK(multi_output_channel_.size() != 0);  // NOLINT
  auto channel_data = paddle::framework::MakeChannel<Record>();
  VLOG(3) << "multi_output_channel_.size() " << multi_output_channel_.size();

  // Drain every output channel into one temporary channel.
  for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
    std::vector<Record> vec_data;
    multi_output_channel_[i]->Close();
    multi_output_channel_[i]->ReadAll(vec_data);
    channel_data->Write(std::move(vec_data));
    vec_data.clear();
    vec_data.shrink_to_fit();
    multi_output_channel_[i]->Clear();
  }
  channel_data->Close();

  std::vector<Record> recs;
  recs.reserve(channel_data->Size());
  channel_data->ReadAll(recs);
  channel_data->Clear();

  // Sorting makes records with the same ins_id_ adjacent.
  std::sort(recs.begin(), recs.end(), [](const Record& a, const Record& b) {
    return a.ins_id_ < b.ins_id_;
  });

  std::vector<Record> results;
  uint64_t drop_ins_num = 0;
  // all_*   : slots seen in previous records of the current run.
  // local_* : slots seen in the record currently being merged.
  std::unordered_set<uint16_t> all_int64;
  std::unordered_set<uint16_t> all_float;
  std::unordered_set<uint16_t> local_uint64;
  std::unordered_set<uint16_t> local_float;

  VLOG(3) << "recs.size() " << recs.size();
  for (size_t i = 0; i < recs.size();) {
    // [i, j) is the run of records sharing recs[i].ins_id_.
    size_t j = i + 1;
    while (j < recs.size() && recs[j].ins_id_ == recs[i].ins_id_) {
      j++;
    }
    const size_t group_size = j - i;

    // The cast is safe: it is only evaluated when min_merge_size_ > 0.
    if (min_merge_size_ > 0 &&
        group_size != static_cast<size_t>(min_merge_size_)) {
      drop_ins_num += group_size;
      LOG(WARNING) << "drop ins " << recs[i].ins_id_ << " size=" << group_size
                   << ", because merge_size=" << min_merge_size_;
      i = j;
      continue;
    }

    all_int64.clear();
    all_float.clear();
    bool has_conflict_slot = false;
    uint16_t conflict_slot = 0;

    Record rec;
    rec.ins_id_ = recs[i].ins_id_;
    rec.content_ = recs[i].content_;

    for (size_t k = i; k < j; k++) {
      local_uint64.clear();
      local_float.clear();

      for (auto& feature : recs[k].uint64_feasigns_) {
        uint16_t slot = feature.slot();
        if (all_int64.find(slot) != all_int64.end()) {
          has_conflict_slot = true;
          conflict_slot = slot;
          break;
        }
        local_uint64.insert(slot);
        rec.uint64_feasigns_.push_back(std::move(feature));
      }
      if (has_conflict_slot) {
        break;
      }
      // Publish this record's slots only after it has been fully scanned, so
      // a slot repeated within the same record is not reported as a conflict.
      all_int64.insert(local_uint64.begin(), local_uint64.end());

      for (auto& feature : recs[k].float_feasigns_) {
        uint16_t slot = feature.slot();
        if (all_float.find(slot) != all_float.end()) {
          has_conflict_slot = true;
          conflict_slot = slot;
          break;
        }
        local_float.insert(slot);
        rec.float_feasigns_.push_back(std::move(feature));
      }
      if (has_conflict_slot) {
        break;
      }
      all_float.insert(local_float.begin(), local_float.end());
    }

    if (has_conflict_slot) {
      // conflict_slot comes from the record data; never index use_slots out
      // of range just to build a log message. Fall back to the numeric id.
      const std::string conflict_name =
          static_cast<size_t>(conflict_slot) < use_slots.size()
              ? use_slots[conflict_slot]
              : std::to_string(conflict_slot);
      LOG(WARNING) << "drop ins " << recs[i].ins_id_ << " size=" << group_size
                   << ", because conflict_slot=" << conflict_name;
      drop_ins_num += group_size;
    } else {
      results.push_back(std::move(rec));
    }
    i = j;
  }

  // Release the memory held by the unmerged records before shuffling.
  std::vector<Record>().swap(recs);
  VLOG(3) << "results size " << results.size();
  LOG(WARNING) << "total drop ins num: " << drop_ins_num;
  results.shrink_to_fit();

  auto fleet_ptr = FleetWrapper::GetInstance();
  std::shuffle(results.begin(), results.end(), fleet_ptr->LocalRandomEngine());

  channel_data->Open();
  channel_data->Write(std::move(results));
  channel_data->Close();
  results.clear();
  results.shrink_to_fit();
  VLOG(3) << "channel data size " << channel_data->Size();

  channel_data->SetBlockSize(channel_data->Size() / channel_num_ + 1);
  VLOG(3) << "channel data block size " << channel_data->BlockSize();

  // Hand the merged records back to the output channels, one block each.
  for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
    std::vector<Record> vec_data;
    channel_data->Read(vec_data);
    multi_output_channel_[i]->Open();
    multi_output_channel_[i]->Write(std::move(vec_data));
    vec_data.clear();
    vec_data.shrink_to_fit();
  }
  CHECK(channel_data->Size() == 0);  // NOLINT
  channel_data->Clear();
  VLOG(3) << "MultiSlotDataset::MergeByInsId end";
}
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/feed/src/data_reader/data_set.h
0 → 100644
浏览文件 @
0dc7d425
#pragma once
#include "paddle/fluid/framework/data_set.h"
namespace
paddle
{
namespace
framework
{
class
FeedMultiSlotDataset
:
public
MultiSlotDataset
{
public:
FeedMultiSlotDataset
()
{}
virtual
void
MergeByInsId
();
virtual
void
CreatePreLoadReaders
();
virtual
~
FeedMultiSlotDataset
()
{}
};
}
// end namespace framework
}
// end namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录