Commit 0dc7d425 authored by X xiexionghang

feed improve dict/merge_patch

Parent 01f2dca6
#!/bin/bash
build_mode=$1
function print_usage() {
echo "++++++++++++++++++++++++++++++++++++++++++++++++++++"
echo "sh build.sh all|make|clean"
echo "- all: will update all env && make it"
echo "- make: just do make, never update env"
echo "- clean: make clean"
echo "++++++++++++++++++++++++++++++++++++++++++++++++++++"
exit 0
}
if [ $# -lt 1 ];then
print_usage
fi
cd ~
user_dir=`pwd`
cd -
python_binary=${user_dir}/.jumbo/bin/python2.7
python_library=${user_dir}/.jumbo/lib/python2.7.so
python_include_dir=${user_dir}/.jumbo/include/python2.7
if [ ! -f ${python_binary} ];then
echo "Miss python ${python_binary}, please install with this cmd: jumbo install python"
exit -1
fi
#apply feed code
if [ -f "paddle/fluid/feed/apply_feed_code.sh" ];then
sh paddle/fluid/feed/apply_feed_code.sh
fi
function makeit() {
cd build
make -j8
cd ..
}
function cmake_all() {
mkdir -p build
cd build
#make clean
cmake -DCMAKE_INSTALL_PREFIX=./output/ -DCMAKE_BUILD_TYPE=Release -DWITH_PYTHON=ON -DWITH_MKL=OFF -DWITH_GPU=OFF -DWITH_PSLIB=ON -DPYTHON_INCLUDE_DIR=${python_include_dir} -DPYTHON_LIBRARY=${python_library} -DPYTHON_EXECUTABLE=${python_binary} ..
cd ..
}
if [ "${build_mode}" = "all" ];then
cmake_all
makeit
elif [ "${build_mode}" = "make" ];then
makeit
elif "${build_mode}" = "clean" ];then
cd build
make clean
fi
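# Example invocations (a sketch; assumes the jumbo-installed python2.7 layout
# checked above):
#   sh build.sh all    # regenerate the cmake config, then build
#   sh build.sh make   # incremental rebuild without updating the env
#   sh build.sh clean  # run `make clean` inside build/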
add_subdirectory(src)
add_subdirectory(pybind)
#!/bin/bash
#Apply the FEED customizations to the Paddle code base (e.g. FEED plugin registration); run before compiling
function fatal_log() {
echo "$1"
exit -1
}
#Patch the pybind extension
function apply_pybind() {
pybind_file='paddle/fluid/pybind/pybind.cc'
if [ ! -f ${pybind_file} ];then
fatal_log "Missing Requied File:${pybind_file}"
fi
find_inference_api=`grep 'inference_api.h' ${pybind_file} |wc -l`
if [ ${find_inference_api} -ne 1 ];then
fatal_log "Missing inference_api.h, the code needs adjustment"
fi
find_inference_api=`grep 'BindInferenceApi' ${pybind_file} |wc -l`
if [ ${find_inference_api} -ne 1 ];then
fatal_log "Missing BindInferenceApi, the code needs adjustment"
fi
makefile='paddle/fluid/pybind/CMakeLists.txt'
if [ ! -f ${makefile} ];then
fatal_log "Missing Requied File:${makefile}"
fi
sed -i '/expand_api/d' ${pybind_file}
sed -i '/BindExpandApi/d' ${pybind_file}
sed -i '/feed_data_set/d' ${makefile}
sed -i '/feed_paddle_pybind/d' ${makefile}
sed -i '/APPEND PYBIND_DEPS fs/d' ${makefile}
sed -i '/inference_api.h/a\#include "paddle/fluid/feed/pybind/expand_api.h"' ${pybind_file}
sed -i '/BindInferenceApi/a\ BindExpandApi(&m);' ${pybind_file}
sed -i '/set(PYBIND_SRCS/i\list(APPEND PYBIND_DEPS feed_data_set)' ${makefile}
sed -i '/set(PYBIND_SRCS/i\list(APPEND PYBIND_DEPS feed_paddle_pybind)' ${makefile}
sed -i '/set(PYBIND_SRCS/i\list(APPEND PYBIND_DEPS fs)' ${makefile}
}
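# After apply_pybind runs, the patched files should contain the following
# (a sketch derived from the sed rules above; surrounding lines omitted):
#   paddle/fluid/pybind/pybind.cc:
#     #include "paddle/fluid/pybind/inference_api.h"
#     #include "paddle/fluid/feed/pybind/expand_api.h"
#     ...
#     BindInferenceApi(&m);
#     BindExpandApi(&m);
#   paddle/fluid/pybind/CMakeLists.txt, just before set(PYBIND_SRCS ...):
#     list(APPEND PYBIND_DEPS feed_data_set)
#     list(APPEND PYBIND_DEPS feed_paddle_pybind)
#     list(APPEND PYBIND_DEPS fs)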
function apply_feed_src() {
makefile='paddle/fluid/CMakeLists.txt'
if [ ! -f ${makefile} ];then
fatal_log "Missing Requied File:${makefile}"
fi
find_py=`grep 'pybind' ${makefile} |wc -l`
if [ ${find_py} -ne 1 ];then
fatal_log "Missing pybind, Need Code Adjust"
fi
sed -i '/feed/d' ${makefile}
sed -i '/pybind/i\add_subdirectory(feed)' ${makefile}
dataset_file='paddle/fluid/framework/dataset_factory.cc'
if [ ! -f ${dataset_file} ];then
fatal_log "Missing Requied File:${dataset_file}"
fi
sed -i '/FeedMultiSlotDataset/d' ${dataset_file}
sed -i '/data_reader/d' ${dataset_file}
sed -i '/REGISTER_DATASET_CLASS(MultiSlotDataset)/a\REGISTER_DATASET_CLASS(FeedMultiSlotDataset);' ${dataset_file}
sed -i '/data_set.h/a\#include "paddle/fluid/feed/src/data_reader/data_set.h"' ${dataset_file}
sed -i '/feed_data_set/d' paddle/fluid/framework/CMakeLists.txt
#sed -i '/target_link_libraries(executor/a\target_link_libraries(feed_data_set)' paddle/fluid/framework/CMakeLists.txt
#sed -i '/target_link_libraries(executor/a\add_dependencies(feed_data_set)' paddle/fluid/framework/CMakeLists.txt
}
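# After apply_feed_src runs, dataset_factory.cc should contain the following
# (a sketch derived from the sed rules above; surrounding lines omitted):
#   #include "paddle/fluid/framework/data_set.h"
#   #include "paddle/fluid/feed/src/data_reader/data_set.h"
#   ...
#   REGISTER_DATASET_CLASS(MultiSlotDataset);
#   REGISTER_DATASET_CLASS(FeedMultiSlotDataset);
# and paddle/fluid/CMakeLists.txt gains add_subdirectory(feed) right before
# its pybind line.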
apply_pybind
apply_feed_src
set(FEED_PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper
pass_builder parallel_executor profiler layer tracer engine scope_pool
dict_plugin fs shell)
if(WITH_PYTHON)
list(APPEND FEED_PYBIND_DEPS py_func_op)
endif()
set(FEED_PYBIND_SRCS
expand_api.cc
)
if(WITH_PYTHON)
if(WITH_AMD_GPU)
hip_library(feed_paddle_pybind SRCS ${FEED_PYBIND_SRCS} DEPS ARCHIVE_START ${FEED_PYBIND_DEPS} ARCHIVE_END)
else()
cc_library(feed_paddle_pybind SRCS ${FEED_PYBIND_SRCS} DEPS ${FEED_PYBIND_DEPS})
endif(WITH_AMD_GPU)
get_property (os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(feed_paddle_pybind ${os_dependency_modules})
endif(WITH_PYTHON)
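# Note: feed_paddle_pybind is linked into the main paddle pybind module through
# the `list(APPEND PYBIND_DEPS feed_paddle_pybind)` line that
# apply_feed_code.sh injects into paddle/fluid/pybind/CMakeLists.txt.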
#include "paddle/fluid/feed/pybind/expand_api.h"
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <cstring>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/feed/src/common/dict_plugin.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
using paddle::framework::DictPluginManager;
using paddle::framework::FeasignCacheDict;
void BindExpandDictPlugin(py::module *m);
void BindExpandApi(py::module *m) {
BindExpandDictPlugin(m);
}
void BindExpandDictPlugin(py::module *m) {
py::class_<FeasignCacheDict>(*m, "FeasignCacheDict")
.def(py::init<>())
.def(py::init<const FeasignCacheDict &>())
.def("load", &FeasignCacheDict::Load);
py::class_<DictPluginManager>(*m, "DictPluginManager")
.def(py::init<>())
.def_static("instance", &DictPluginManager::Instance)
.def("load_dict", &DictPluginManager::LoadDict)
.def("create_dict", &DictPluginManager::CreateDict);
}
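// Example Python usage of these bindings (a sketch; the argument lists of
// load_dict/create_dict are assumptions, since their C++ signatures are not
// shown here):
//   from paddle.fluid import core
//   mgr = core.DictPluginManager.instance()
//   mgr.create_dict("feasign_cache")                 # hypothetical dict name
//   mgr.load_dict("feasign_cache", "/path/to/dict")  # hypothetical arguments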
} // namespace pybind
} // namespace paddle
#pragma once
#include <pybind11/pybind11.h>
namespace paddle {
namespace pybind {
void BindExpandApi(pybind11::module *m);
} // namespace pybind
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <Python.h>
#include <algorithm>
#include <cstdlib>
#include <map>
#include <memory>
#include <mutex> // NOLINT // for call_once
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/ir/coalesce_grad_tensor_pass.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/scope_pool.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/box_helper_py.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/data_set_py.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/fleet_wrapper_py.h"
#include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h"
#include "paddle/fluid/pybind/ir.h"
#include "paddle/fluid/pybind/expand_api.h"
#ifndef _WIN32
#include "paddle/fluid/pybind/nccl_wrapper_py.h"
#endif
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/pybind/protobuf.h"
#include "paddle/fluid/pybind/pybind.h" // NOLINT
#include "paddle/fluid/pybind/reader_py.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/fluid/string/to_string.h"
#ifdef PADDLE_WITH_CUDA
#ifndef _WIN32
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#endif
#include "paddle/fluid/platform/cuda_profiler.h"
#include "paddle/fluid/platform/gpu_info.h"
#endif
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/pybind/communicator_py.h"
#endif
#include "pybind11/stl.h"
DEFINE_bool(reader_queue_speed_test_mode, false,
"If set true, the queue.pop will only get data from queue but not "
"remove the data from queue for speed testing");
DECLARE_bool(use_mkldnn);
#ifdef PADDLE_WITH_NGRAPH
DECLARE_bool(use_ngraph);
#endif
// disable auto conversion to list in Python
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);
namespace paddle {
namespace pybind {
bool IsCompiledWithCUDA() {
#ifndef PADDLE_WITH_CUDA
return false;
#else
return true;
#endif
}
bool IsCompiledWithMKLDNN() {
#ifndef PADDLE_WITH_MKLDNN
return false;
#else
return true;
#endif
}
bool IsCompiledWithNGRAPH() {
#ifndef PADDLE_WITH_NGRAPH
return false;
#else
return true;
#endif
}
bool IsCompiledWithBrpc() {
#ifndef PADDLE_WITH_DISTRIBUTE
return false;
#endif
#ifdef PADDLE_WITH_GRPC
return false;
#endif
return true;
}
bool IsCompiledWithDIST() {
#ifdef PADDLE_WITH_DISTRIBUTE
return true;
#else
return false;
#endif
}
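// These IsCompiledWith* helpers back core.is_compiled_with_cuda(),
// core.is_compiled_with_mkldnn(), etc. (bound further below), so Python code
// can branch on build-time features, e.g.:
//   place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda() else fluid.CPUPlace()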
template <typename PlaceType1, typename PlaceType2>
static inline bool IsSamePlace(const PlaceType1 &p1, const PlaceType2 &p2) {
return paddle::platform::Place(p1) == paddle::platform::Place(p2);
}
template <typename PlaceType>
static inline int PlaceIndex(const PlaceType &p) {
return static_cast<int>(paddle::platform::Place(p).which());
}
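// IsSamePlace/PlaceIndex back the _equals/_type methods bound on each place
// class below; e.g. fluid.CPUPlace()._equals(fluid.CPUPlace()) is True because
// both sides convert to the same paddle::platform::Place variant.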
#ifdef PADDLE_WITH_AVX
PYBIND11_MODULE(core_avx, m) {
#else
PYBIND11_MODULE(core_noavx, m) {
#endif
// Not used, just make sure cpu_info.cc is linked.
paddle::platform::CpuTotalPhysicalMemory();
paddle::memory::allocation::UseAllocatorStrategyGFlag();
m.doc() = "C++ core of PaddlePaddle";
// using framework in this function. Since it is inside a function, it will
// not cause namespace pollution.
using namespace paddle::framework; // NOLINT
BindException(&m);
m.def("set_num_threads", &platform::SetNumThreads);
m.def(
"_append_python_callable_object_and_return_id",
[](py::object py_obj) -> size_t {
return paddle::operators::AppendPythonCallableObjectAndReturnId(py_obj);
});
m.def("_get_use_default_grad_op_desc_maker_ops",
[] { return OpInfoMap::Instance().GetUseDefaultGradOpDescMakerOps(); });
// NOTE(zjl): ctest would load environment variables at the beginning even
// though we have not `import paddle.fluid as fluid`. So we add this API
// to enable eager deletion mode in unittest.
m.def("_set_eager_deletion_mode", &paddle::framework::SetEagerDeletionMode);
m.def("_set_fuse_parameter_group_size",
&paddle::framework::ir::SetFuseParameterGroupsSize);
m.def("_set_fuse_parameter_memory_size",
&paddle::framework::ir::SetFuseParameterMemorySize);
m.add_object("_cleanup",
py::capsule([]() { ScopePool::Instance().Clear(); }));
m.def("_set_paddle_lib_path", &paddle::platform::dynload::SetPaddleLibPath);
BindImperative(&m);
py::class_<Tensor>(m, "Tensor", py::buffer_protocol())
.def("__array__", [](Tensor &self) { return TensorToPyArray(self); })
.def("_is_initialized",
[](const Tensor &self) { return self.IsInitialized(); })
.def("_get_dims",
[](const Tensor &self) { return vectorize(self.dims()); })
.def("_set_dims",
[](Tensor &self, const std::vector<int64_t> &dim) {
self.Resize(make_ddim(dim));
})
.def("_set_layout",
[](Tensor &self, const std::string &layout) {
self.set_layout(StringToDataLayout(layout));
})
.def("_alloc_float",
[](Tensor &self, paddle::platform::CUDAPlace &place) {
self.mutable_data<float>(place);
})
.def("_alloc_float",
[](Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<float>(place);
})
.def("_alloc_double",
[](Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<double>(place);
})
.def("_alloc_int",
[](Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<int>(place);
})
.def("_alloc_int",
[](Tensor &self, paddle::platform::CUDAPlace &place) {
self.mutable_data<int>(place);
})
.def("_alloc_int",
[](Tensor &self, paddle::platform::CUDAPinnedPlace &place) {
self.mutable_data<int>(place);
})
.def("_alloc_float",
[](Tensor &self, paddle::platform::CUDAPinnedPlace &place) {
self.mutable_data<float>(place);
})
.def("_clear", &Tensor::clear)
.def("set", PyCPUTensorSetFromArray<float>)
.def("set", PyCPUTensorSetFromArray<int>)
.def("set", PyCPUTensorSetFromArray<double>)
.def("set", PyCPUTensorSetFromArray<int64_t>)
.def("set", PyCPUTensorSetFromArray<bool>)
.def("set", PyCPUTensorSetFromArray<uint16_t>)
.def("set", PyCPUTensorSetFromArray<uint8_t>)
.def("set", PyCPUTensorSetFromArray<int8_t>)
#ifdef PADDLE_WITH_CUDA
.def("set", PyCUDATensorSetFromArray<float>)
.def("set", PyCUDATensorSetFromArray<int>)
.def("set", PyCUDATensorSetFromArray<double>)
.def("set", PyCUDATensorSetFromArray<int64_t>)
.def("set", PyCUDATensorSetFromArray<bool>)
.def("set", PyCUDATensorSetFromArray<uint16_t>)
.def("set", PyCUDATensorSetFromArray<uint8_t>)
.def("set", PyCUDATensorSetFromArray<int8_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<float>)
.def("set", PyCUDAPinnedTensorSetFromArray<int>)
.def("set", PyCUDAPinnedTensorSetFromArray<double>)
.def("set", PyCUDAPinnedTensorSetFromArray<int64_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<bool>)
.def("set", PyCUDAPinnedTensorSetFromArray<uint16_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<uint8_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<int8_t>)
#endif
.def("shape", [](Tensor &self) { return vectorize(self.dims()); })
.def("_set_float_element", TensorSetElement<float>)
.def("_get_float_element", TensorGetElement<float>)
.def("_set_double_element", TensorSetElement<double>)
.def("_get_double_element", TensorGetElement<double>)
.def("_place", [](Tensor &self) { return self.place(); })
.def("_dtype", [](Tensor &self) { return self.type(); })
.def("__getitem__", PySliceTensor, py::return_value_policy::reference)
.def("__str__", [](const Tensor &self) {
std::stringstream ostr;
ostr << self;
return ostr.str();
});
py::class_<LoDTensor, Tensor>(m, "LoDTensor", R"DOC(
LoDTensor is a Tensor with optional LoD information.
np.array(lod_tensor) can convert LoDTensor to numpy array.
lod_tensor.lod() can retrieve the LoD information.
LoD is short for Level of Details and is usually used for variable sequence
lengths. You can skip the following comment if you don't need optional LoD.
For example, a LoDTensor X can look like the example below. It contains
2 sequences. The first has length 2 and the second has length 3, as
described by x.lod.
The first tensor dimension 5=2+3 is calculated from LoD if it's available.
It means the total number of sequence elements. In X, each element has 2
columns, hence [5, 2].
x.lod = [[2, 3]]
x.data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
x.shape = [5, 2]
LoD can have multiple levels (for example, a paragraph can have multiple
sentences and a sentence can have multiple words). In the following
LoDTensor Y, the lod_level is 2. It means there are 2 sequences: the
first sequence's length is 2 (it has 2 sub-sequences) and the second one's
length is 1. The first sequence's 2 sub-sequences have lengths 2 and 2,
respectively, and the second sequence's 1 sub-sequence has length 3.
y.lod = [[2, 1], [2, 2, 3]]
y.shape = [2+2+3, ...]
Examples:
.. code-block:: python
import paddle.fluid as fluid
t = fluid.LoDTensor()
Note:
In the above description, LoD is length-based. In Paddle's internal
implementation, lod is offset-based. Hence, internally,
y.lod is represented as [[0, 2, 3], [0, 2, 4, 7]] (the length-based
equivalent would be [[2-0, 3-2], [2-0, 4-2, 7-4]]).
Sometimes LoD is called recursive_sequence_length to be more
self-explanatory. In this case, it must be length-based. Due to historical
reasons, when LoD is called lod in the public API, it might be offset-based.
Users should be careful about it.
)DOC")
.def("__array__", [](Tensor &self) { return TensorToPyArray(self); })
.def("__init__",
[](LoDTensor &instance, const std::vector<std::vector<size_t>>
&recursive_sequence_lengths) {
LoD new_lod;
new_lod.reserve(recursive_sequence_lengths.size());
std::copy(recursive_sequence_lengths.begin(),
recursive_sequence_lengths.end(),
std::back_inserter(new_lod));
LoD new_offset_lod = ConvertToOffsetBasedLoD(new_lod);
PADDLE_ENFORCE_EQ(
CheckLoD(new_offset_lod, -1), true,
"the provided recursive_sequence_lengths info is invalid");
new (&instance) LoDTensor(new_offset_lod);
})
.def("__init__", [](LoDTensor &instance) { new (&instance) LoDTensor(); })
// We implement offset based LOD in C++ while we use length based with
// Python API. So we changed set_lod to set_recursive_sequence_lengths to
// avoid misuse.
// The discussion is here:
// https://github.com/PaddlePaddle/Paddle/issues/10855
.def("set_lod",
[](LoDTensor &self, const std::vector<std::vector<size_t>> &lod) {
// the input lod is offset-based level-of-detail info
LoD new_lod;
new_lod.reserve(lod.size());
std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
PADDLE_ENFORCE_EQ(
CheckLoD(new_lod, vectorize(self.dims()).front()), true,
"the provided lod info is invalid");
self.set_lod(new_lod);
},
py::arg("lod"), R"DOC(
Set LoD of the LoDTensor.
Args:
lod (List[List[int]]): the lod to be set.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
t = fluid.LoDTensor()
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
t.set_lod([[0, 2, 5]])
)DOC")
.def("set_recursive_sequence_lengths",
[](LoDTensor &self, const std::vector<std::vector<size_t>>
&recursive_sequence_lengths) {
// the input recursive_sequence_lengths is length-based
// level-of-detail info
LoD new_lod;
new_lod.reserve(recursive_sequence_lengths.size());
std::copy(recursive_sequence_lengths.begin(),
recursive_sequence_lengths.end(),
std::back_inserter(new_lod));
LoD new_offset_lod = ConvertToOffsetBasedLoD(new_lod);
PADDLE_ENFORCE_EQ(
CheckLoD(new_offset_lod, vectorize(self.dims()).front()), true,
"the provided recursive_sequence_lengths info is invalid");
self.set_lod(new_offset_lod);
},
py::arg("recursive_sequence_lengths"), R"DOC(
Set LoD of the LoDTensor according to recursive sequence length.
For example, if recursive_sequence_lengths=[[2, 3]], meaning that
there are two sequences with length 2 and 3 respectively, the
corresponding lod would be [[0, 2, 2+3]], i.e., [[0, 2, 5]].
Args:
recursive_sequence_lengths (List[List[int]]): sequence lengths.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
t = fluid.LoDTensor()
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
t.set_recursive_sequence_lengths([[2, 3]])
)DOC")
.def("lod",
[](LoDTensor &self) -> std::vector<std::vector<size_t>> {
// output the offset-based lod info
LoD lod = self.lod();
std::vector<std::vector<size_t>> new_lod;
new_lod.reserve(lod.size());
std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
return new_lod;
},
R"DOC(
Return the LoD of the LoDTensor.
Returns:
out (List[List[int]]): the lod of the LoDTensor.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
t = fluid.LoDTensor()
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
t.set_lod([[0, 2, 5]])
print(t.lod()) # [[0, 2, 5]]
)DOC")
// Set above comments of set_lod.
.def("recursive_sequence_lengths",
[](LoDTensor &self) -> std::vector<std::vector<size_t>> {
// output the length-based lod info
LoD lod = ConvertToLengthBasedLoD(self.lod());
std::vector<std::vector<size_t>> new_lod;
new_lod.reserve(lod.size());
std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
return new_lod;
},
R"DOC(
Return the sequence length of the LoDTensor corresponding to LoD.
Returns:
out (List[List[int]]): the sequence lengths.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
t = fluid.LoDTensor()
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
t.set_recursive_sequence_lengths([[2, 3]])
print(t.recursive_sequence_lengths()) # [[2, 3]]
)DOC")
.def("has_valid_recursive_sequence_lengths",
[](LoDTensor &self) -> bool {
// Check that the lod info is valid and match the outermost
// dimension of the LoDTensor data
return CheckLoD(self.lod(), vectorize(self.dims()).front());
},
R"DOC(
Check whether the lod of the LoDTensor is valid.
Returns:
out (bool): whether the lod is valid.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
t = fluid.LoDTensor()
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
t.set_recursive_sequence_lengths([[2, 3]])
print(t.has_valid_recursive_sequence_lengths()) # True
)DOC")
.def("__getitem__", PySliceTensor, py::return_value_policy::reference,
R"DOC(
Slice the original Tensor, and remove the LoD information.
Returns:
out (Tensor): new Tensor(NOT LoDTensor).
)DOC")
.def("__str__",
[](const LoDTensor &self) {
std::stringstream ostr;
ostr << self;
return ostr.str();
})
.def("_copy", [](const LoDTensor &self, const platform::Place &place) {
// follow fetch_op's implementation
LoDTensor dst;
if (self.IsInitialized() && self.numel() > 0) {
TensorCopySync(self, place, &dst);
} else {
// Not copy, if the src tensor is empty.
dst.clear();
dst.Resize({0});
}
dst.set_lod(self.lod());
return dst;
});
py::class_<SelectedRows>(m, "SelectedRows")
.def("__init__",
[](SelectedRows &instance) { new (&instance) SelectedRows(); })
.def("__init__",
[](SelectedRows &instance, const std::vector<int64_t> rows,
const int64_t &height) {
new (&instance) SelectedRows(rows, height);
})
.def("get_tensor",
[](SelectedRows &self) { return self.mutable_value(); },
py::return_value_policy::reference)
.def("numel",
[](SelectedRows &self) -> int64_t { return self.value().numel(); })
.def("set_height", &SelectedRows::set_height)
.def("height", &SelectedRows::height)
.def("set_rows",
[](SelectedRows &self, std::vector<int64_t> rows) {
#ifndef PADDLE_WITH_CUDA
self.set_rows(rows);
#else
Vector<int64_t> new_rows(rows);
self.set_rows(new_rows);
#endif
})
.def("sync_index", [](SelectedRows &instance) { instance.SyncIndex(); })
.def("rows", [](SelectedRows &self) {
auto rows = self.rows();
std::vector<int64_t> new_rows;
new_rows.reserve(rows.size());
std::copy(rows.begin(), rows.end(), std::back_inserter(new_rows));
return new_rows;
});
py::class_<Variable>(m, "Variable", R"DOC(Variable Class.
All parameters, weights, and gradients are variables in Paddle.
)DOC")
.def(py::init<>())
.def("is_int", [](const Variable &var) { return var.IsType<int>(); })
.def("set_int",
[](Variable &var, int val) -> void { *var.GetMutable<int>() = val; })
.def("get_int", [](const Variable &var) -> int { return var.Get<int>(); })
.def("is_float", [](const Variable &var) { return var.IsType<float>(); })
.def("set_float",
[](Variable &var, float val) -> void {
*var.GetMutable<float>() = val;
})
.def("get_float",
[](const Variable &var) -> float { return var.Get<float>(); })
.def("get_tensor",
[](Variable &self) -> LoDTensor * {
return self.GetMutable<LoDTensor>();
},
py::return_value_policy::reference)
.def("get_lod_rank_table",
[](Variable &self) { return self.GetMutable<LoDRankTable>(); },
py::return_value_policy::reference)
.def("get_selected_rows",
[](Variable &self) -> SelectedRows * {
return self.GetMutable<SelectedRows>();
},
py::return_value_policy::reference)
.def("get_lod_tensor_array",
[](Variable &self) { return self.GetMutable<LoDTensorArray>(); },
py::return_value_policy::reference)
#if (defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
.def("get_communicator",
[](Variable &self) -> platform::Communicator * {
return self.GetMutable<platform::Communicator>();
},
py::return_value_policy::reference)
#endif
.def("get_reader",
[](Variable &self) -> framework::ReaderHolder * {
PADDLE_ENFORCE_EQ(self.IsType<framework::ReaderHolder>(), true);
return self.GetMutable<framework::ReaderHolder>();
},
py::return_value_policy::reference);
BindReader(&m);
using LoDTensorBlockingQueue =
::paddle::operators::reader::LoDTensorBlockingQueue;
using LoDTensorBlockingQueueHolder =
::paddle::operators::reader::LoDTensorBlockingQueueHolder;
py::class_<LoDTensorBlockingQueue, std::shared_ptr<LoDTensorBlockingQueue>>(
m, "LoDTensorBlockingQueue", "")
.def("push",
[](LoDTensorBlockingQueue &self,
const std::vector<framework::LoDTensor> &lod_tensor_vec) {
pybind11::gil_scoped_release release;
return self.Push(lod_tensor_vec);
})
.def("size", &LoDTensorBlockingQueue::Size)
.def("capacity", &LoDTensorBlockingQueue::Cap)
.def("close", &LoDTensorBlockingQueue::Close)
.def("is_closed", &LoDTensorBlockingQueue::IsClosed);
m.def("init_lod_tensor_blocking_queue",
[](Variable &var,
size_t capacity) -> std::shared_ptr<LoDTensorBlockingQueue> {
VLOG(1) << "init_lod_tensor_blocking_queue";
auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
holder->InitOnce(capacity, FLAGS_reader_queue_speed_test_mode);
return holder->GetQueue();
},
py::return_value_policy::copy);
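// Example Python usage (a sketch; "data_queue" is a hypothetical variable name,
// and the pushed items must be LoDTensor instances):
//   q = core.init_lod_tensor_blocking_queue(scope.var("data_queue"), 10)
//   q.push([t])   # blocks if the queue is full, releasing the GIL meanwhile
//   q.close()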
py::class_<Scope>(m, "_Scope", R"DOC(
Scope is an association of names to Variables. All variables belong to a Scope.
Variables in a parent scope can be retrieved from local scope.
You need to specify a scope to run a Net, i.e., `exe.Run(&scope)`.
One net can run in different scopes and update different variables in
each scope.
You can create var in a scope and get it from the scope.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
# create a tensor in a scope and set a value to it.
scope = fluid.global_scope()
place = fluid.CPUPlace()
height, row_numel = 2, 3
param = scope.var('Param').get_tensor()
param_array = np.full((height, row_numel), 5.0).astype("float32")
param.set(param_array, place)
)DOC")
.def("_remove_from_pool",
[](Scope &self) { ScopePool::Instance().Remove(&self); })
.def("var",
[](Scope &self, const std::string &name) -> Variable * {
return self.Var(name);
},
py::arg("name"),
R"DOC(
Find or create variable named :code:`name` in the current scope.
If the variable named :code:`name` does not exist in the
current scope, the variable would be created. Otherwise,
return the existing variable.
Args:
name (str): the variable name.
Returns:
out (core.Variable): the found or created variable.
)DOC",
py::return_value_policy::reference)
.def("find_var", &Scope::FindVar, py::arg("name"),
R"DOC(
Find variable named :code:`name` in the current scope or
its parent scope. Return None if not found.
Args:
name (str): the variable name.
Returns:
out (core.Variable|None): the found variable or None.
)DOC",
py::return_value_policy::reference)
.def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); },
R"DOC(
Create a new sub-scope of the current scope.
Returns:
out (core._Scope): the created sub-scope.
)DOC",
py::return_value_policy::reference)
.def("drop_kids", &Scope::DropKids,
R"DOC(
Delete all sub-scopes of the current scope.
)DOC")
.def("_kids", &Scope::kids);
m.def("Scope",
[]() -> Scope * {
auto *s = new Scope();
ScopePool::Instance().Insert(std::unique_ptr<Scope>(s));
return s;
},
R"DOC(
Create a new scope.
Returns:
out (core._Scope): the created scope.
)DOC",
py::return_value_policy::reference);
//! @note: Be careful! PyBind will return std::string as a unicode, not
//! Python str. If you want a str object, you should cast them in Python.
m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
std::vector<py::bytes> ret_values;
for (auto &iter : OpInfoMap::Instance().map()) {
auto &info = iter.second;
if (info.HasOpProtoAndChecker()) {
std::string str;
PADDLE_ENFORCE_EQ(
info.Proto().SerializeToString(&str), true,
"Serialize OpProto Error. This could be a bug of Paddle.");
ret_values.emplace_back(str);
}
}
return ret_values;
});
m.def(
"get_grad_op_desc", [](const OpDesc &op_desc,
const std::unordered_set<std::string> &no_grad_set,
const std::vector<BlockDesc *> &grad_sub_block) {
std::unordered_map<std::string, std::string> grad_to_var;
std::vector<std::unique_ptr<OpDesc>> grad_op_descs =
framework::OpInfoMap::Instance()
.Get(op_desc.Type())
.GradOpMaker()(op_desc, no_grad_set, &grad_to_var,
grad_sub_block);
std::vector<OpDesc *> grad_op_desc_ptrs(grad_op_descs.size());
std::transform(grad_op_descs.begin(), grad_op_descs.end(),
grad_op_desc_ptrs.begin(),
[](std::unique_ptr<OpDesc> &p) { return p.release(); });
return std::make_pair(grad_op_desc_ptrs, grad_to_var);
});
m.def("has_grad_op_maker", [](const std::string op_type) {
return framework::OpInfoMap::Instance().Get(op_type).HasGradOpMaker();
});
m.def("has_infer_inplace", [](const std::string op_type) {
return framework::OpInfoMap::Instance().Get(op_type).HasInferInplace();
});
m.def("get_flags_use_mkldnn", []() { return FLAGS_use_mkldnn; });
#ifdef PADDLE_WITH_NGRAPH
m.def("get_flags_use_ngraph", []() { return FLAGS_use_ngraph; });
#endif
m.def("prune", [](const ProgramDesc &origin,
const std::set<std::string> &feeded_var_names,
const std::vector<std::array<size_t, 2>> &targets) {
ProgramDesc prog_with_targets(origin);
for (const auto &t : targets) {
prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true);
}
proto::ProgramDesc pruned_desc;
Prune(*prog_with_targets.Proto(), feeded_var_names, &pruned_desc);
return new ProgramDesc(pruned_desc);
});
m.def("prune_backward", [](const framework::ProgramDesc &program) {
return PruneBackward(program);
});
m.def("empty_var_name",
[]() { return std::string(framework::kEmptyVarName); });
m.def("grad_var_suffix",
[]() { return std::string(framework::kGradVarSuffix); });
m.def_submodule(
"var_names",
"The module will return special predefined variable name in Paddle")
.def("empty", []() { return kEmptyVarName; })
.def("temp", []() { return kTempVarName; });
// clang-format off
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
.def_static("create",
[](paddle::platform::CPUPlace& place)
-> paddle::platform::DeviceContext* {
return new paddle::platform::CPUDeviceContext();
})
.def_static("create",
[](paddle::platform::CUDAPlace& place)
-> paddle::platform::DeviceContext* {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("CUDAPlace is not supported in CPU device.");
#else
return new paddle::platform::CUDADeviceContext(place);
#endif
})
.def_static("create",
[](paddle::platform::CUDAPinnedPlace& place)
-> paddle::platform::DeviceContext* {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW(
"CUDAPinnedPlace is not supported in CPU device.");
#else
return new paddle::platform::CUDAPinnedDeviceContext(place);
#endif
});
// clang-format on
#if (defined(PADDLE_WITH_CUDA) && !defined(_WIN32))
py::class_<platform::Communicator>(m, "Communicator").def(py::init<>());
#endif
py::class_<platform::CUDAPlace>(m, "CUDAPlace", R"DOC(
CUDAPlace is a descriptor of a device. It represents a GPU, and each CUDAPlace
has a dev_id indicating which GPU card it represents.
The memory of CUDAPlaces with different dev_ids is not mutually accessible.
Examples:
.. code-block:: python
import paddle.fluid as fluid
gpu_place = fluid.CUDAPlace(0)
)DOC")
.def("__init__",
[](platform::CUDAPlace &self, int dev_id) {
#ifdef PADDLE_WITH_CUDA
if (UNLIKELY(dev_id < 0)) {
LOG(ERROR) << string::Sprintf(
"Invalid CUDAPlace(%d), device id must be 0 or "
"positive integer",
dev_id);
std::exit(-1);
}
if (UNLIKELY(dev_id >= platform::GetCUDADeviceCount())) {
if (platform::GetCUDADeviceCount() == 0) {
LOG(ERROR) << "Cannot use GPU because there is no GPU "
"detected on your "
"machine.";
std::exit(-1);
} else {
LOG(ERROR) << string::Sprintf(
"Invalid CUDAPlace(%d), must inside [0, %d), because GPU "
"number on your machine is %d",
dev_id, platform::GetCUDADeviceCount(),
platform::GetCUDADeviceCount());
std::exit(-1);
}
}
new (&self) platform::CUDAPlace(dev_id);
#else
LOG(ERROR) << string::Sprintf(
"Cannot use GPU because you have installed CPU version "
"PaddlePaddle.\n"
"If you want to use GPU, please try to install GPU version "
"PaddlePaddle by: pip install paddlepaddle-gpu\n"
"If you only have CPU, please change CUDAPlace(%d) to be "
"CPUPlace().\n",
dev_id);
std::exit(-1);
#endif
})
.def("_type", &PlaceIndex<platform::CUDAPlace>)
.def("_equals", &IsSamePlace<platform::CUDAPlace, platform::Place>)
.def("_equals", &IsSamePlace<platform::CUDAPlace, platform::CUDAPlace>)
.def("_equals", &IsSamePlace<platform::CUDAPlace, platform::CPUPlace>)
.def("_equals",
&IsSamePlace<platform::CUDAPlace, platform::CUDAPinnedPlace>)
.def("__str__", string::to_string<const platform::CUDAPlace &>);
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace", R"DOC(
CPUPlace is a descriptor of a device. It represents a CPU, and the memory
of CPUPlace can be accessed by the CPU.
Examples:
.. code-block:: python
import paddle.fluid as fluid
cpu_place = fluid.CPUPlace()
)DOC")
.def(py::init<>())
.def("_type", &PlaceIndex<platform::CPUPlace>)
.def("_equals", &IsSamePlace<platform::CPUPlace, platform::Place>)
.def("_equals", &IsSamePlace<platform::CPUPlace, platform::CUDAPlace>)
.def("_equals", &IsSamePlace<platform::CPUPlace, platform::CPUPlace>)
.def("_equals",
&IsSamePlace<platform::CPUPlace, platform::CUDAPinnedPlace>)
.def("__str__", string::to_string<const platform::CPUPlace &>);
py::class_<paddle::platform::CUDAPinnedPlace>(m, "CUDAPinnedPlace", R"DOC(
CUDAPinnedPlace is a descriptor of a device. The memory of CUDAPinnedPlace
can be accessed by GPU and CPU.
Examples:
.. code-block:: python
import paddle.fluid as fluid
place = fluid.CUDAPinnedPlace()
)DOC")
.def("__init__",
[](platform::CUDAPinnedPlace &self) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("Cannot use CUDAPinnedPlace in CPU only version");
#endif
new (&self) platform::CUDAPinnedPlace();
})
.def("_type", &PlaceIndex<platform::CUDAPinnedPlace>)
.def("_equals", &IsSamePlace<platform::CUDAPinnedPlace, platform::Place>)
.def("_equals",
&IsSamePlace<platform::CUDAPinnedPlace, platform::CUDAPlace>)
.def("_equals",
&IsSamePlace<platform::CUDAPinnedPlace, platform::CPUPlace>)
.def("_equals",
&IsSamePlace<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>)
.def("__str__", string::to_string<const platform::CUDAPinnedPlace &>);
py::class_<platform::Place>(m, "Place")
.def(py::init<>())
.def("_type", &PlaceIndex<platform::Place>)
.def("_equals", &IsSamePlace<platform::Place, platform::Place>)
.def("_equals", &IsSamePlace<platform::Place, platform::CUDAPlace>)
.def("_equals", &IsSamePlace<platform::Place, platform::CPUPlace>)
.def("_equals", &IsSamePlace<platform::Place, platform::CUDAPinnedPlace>)
.def("is_gpu_place",
[](platform::Place &self) { return platform::is_gpu_place(self); })
.def("is_cpu_place",
[](platform::Place &self) { return platform::is_cpu_place(self); })
.def("is_cuda_pinned_place",
[](platform::Place &self) {
return platform::is_cuda_pinned_place(self);
})
.def("gpu_device_id",
[](platform::Place &self) {
return boost::get<platform::CUDAPlace>(self).device;
})
.def("set_place", [](platform::Place &self,
const platform::Place &other) { self = other; })
.def("set_place",
[](platform::Place &self, const platform::CPUPlace &cpu_place) {
self = cpu_place;
})
.def("set_place",
[](platform::Place &self, const platform::CUDAPlace &gpu_place) {
self = gpu_place;
})
.def("set_place", [](platform::Place &self,
const platform::CUDAPinnedPlace &cuda_pinned_place) {
self = cuda_pinned_place;
});
py::class_<OperatorBase>(m, "Operator")
.def_static(
"create",
[](py::bytes protobin) {
proto::OpDesc desc;
PADDLE_ENFORCE_EQ(desc.ParsePartialFromString(protobin), true,
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE_EQ(desc.IsInitialized(), true,
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc);
})
.def("run",
[](OperatorBase &self, const Scope &scope,
const platform::CPUPlace &place) { self.Run(scope, place); })
.def("run",
[](OperatorBase &self, const Scope &scope,
const platform::CUDAPlace &place) { self.Run(scope, place); })
.def("run",
[](OperatorBase &self, const Scope &scope,
const platform::CUDAPinnedPlace &place) {
self.Run(scope, place);
})
.def("type",
[](const OperatorBase &op) -> std::string { return op.Type(); })
.def("outputs",
[](const OperatorBase &op)
-> std::map<std::string, std::vector<std::string>> {
return op.Outputs();
})
.def("output_vars",
[](const OperatorBase &op) { return op.OutputVars(true); })
.def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
.def("input_vars", [](const OperatorBase &op) { return op.InputVars(); })
.def("__str__", &OperatorBase::DebugString)
.def("no_intermediate_outputs",
[](const OperatorBase &op) { return op.OutputVars(false); })
.def("support_gpu", &OperatorBase::SupportGPU);
py::class_<framework::ExecutorPrepareContext>(m, "ExecutorPrepareContext")
.def(py::init<const ProgramDesc &, size_t>());
py::class_<framework::Executor>(m, "Executor")
.def(py::init<const platform::Place &>())
.def("close", &Executor::Close)
.def("run_from_dataset", &Executor::RunFromDataset,
py::call_guard<py::gil_scoped_release>())
.def("run_prepared_ctx",
[](Executor &self, ExecutorPrepareContext *ctx, Scope *scope,
std::map<std::string, const LoDTensor *> *feed_targets,
std::map<std::string, LoDTensor *> *fetch_targets,
bool create_local_scope = true, bool create_vars = true,
const std::string &feed_holder_name = "feed",
const std::string &fetch_holder_name = "fetch") {
pybind11::gil_scoped_release release;
self.RunPreparedContext(ctx, scope, feed_targets, fetch_targets,
create_local_scope, create_vars,
feed_holder_name, fetch_holder_name);
})
.def("run_cached_prepared_ctx",
[](Executor &self, ExecutorPrepareContext *ctx, Scope *scope,
bool create_local_scope = true, bool create_vars = true,
bool keep_kids = false) {
pybind11::gil_scoped_release release;
self.RunPreparedContext(ctx, scope, create_local_scope,
create_vars, keep_kids);
})
.def("prepare_ctx_cache", &Executor::PrepareCtxCache,
py::call_guard<py::gil_scoped_release>())
.def("create_variables", &Executor::CreateVariables,
py::call_guard<py::gil_scoped_release>())
.def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
int block_id, bool create_local_scope, bool create_vars,
const std::vector<std::string> &fetch_vars) {
pybind11::gil_scoped_release release;
self.Run(prog, scope, block_id, create_local_scope, create_vars,
fetch_vars);
});
m.def("init_gflags", framework::InitGflags);
m.def("init_glog", framework::InitGLOG);
m.def("init_dgc", framework::InitDGC);
m.def("init_devices",
[](bool init_p2p) { framework::InitDevices(init_p2p); });
m.def("is_compiled_with_ngraph", IsCompiledWithNGRAPH);
m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN);
m.def("is_compiled_with_brpc", IsCompiledWithBrpc);
m.def("is_compiled_with_dist", IsCompiledWithDIST);
#ifdef PADDLE_WITH_CUDA
m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
// Only GPUs with Compute Capability >= 53 support float16
return platform::GetCUDAComputeCapability(place.device) >= 53;
});
#endif
m.def("set_feed_variable", framework::SetFeedVariable);
m.def("get_fetch_variable", framework::GetFetchVariable);
m.def("get_variable_tensor", framework::GetVariableTensor);
m.def("_is_program_version_supported", IsProgramVersionSupported);
BindProgramDesc(&m);
BindBlockDesc(&m);
BindVarDsec(&m);
BindOpDesc(&m);
BindConstValue(&m);
py::class_<framework::LoDRankTable>(m, "LodRankTable")
.def("items", [](framework::LoDRankTable &table) {
std::vector<std::pair<size_t, size_t>> res;
for (auto &item : table.items()) {
res.push_back({item.index, item.length});
}
return res;
});
py::class_<LoDTensorArray>(m, "LoDTensorArray", R"DOC(
Array of LoDTensor.
Examples:
.. code-block:: python
import paddle.fluid as fluid
arr = fluid.LoDTensorArray()
)DOC")
.def("__init__",
[](LoDTensorArray &instance) { new (&instance) LoDTensorArray(); })
.def("__getitem__",
[](LoDTensorArray &self, size_t i) { return &self.at(i); },
py::return_value_policy::reference)
.def("__len__", [](LoDTensorArray &self) { return self.size(); })
.def("__setitem__",
[](LoDTensorArray &self, size_t i, const LoDTensor &t) {
PADDLE_ENFORCE_LT(i, self.size());
self[i].ShareDataWith(t);
self[i].set_lod(t.lod());
})
.def("append",
[](LoDTensorArray &self, const LoDTensor &t) {
self.emplace_back();
self.back().ShareDataWith(t);
self.back().set_lod(t.lod());
},
py::arg("tensor"), R"DOC(
Append a LoDTensor to LoDTensorArray.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
arr = fluid.LoDTensorArray()
t = fluid.LoDTensor()
t.set(np.ndarray([5, 30]), fluid.CPUPlace())
arr.append(t)
)DOC")
.def("_move_to_list",
[](LoDTensorArray &self) -> py::list {
py::list res(self.size());
for (size_t i = 0; i < self.size(); ++i) {
res[i] = py::cast(std::move(self[i]));
}
self.clear();
return res;
},
py::return_value_policy::take_ownership);
m.def("op_support_gpu", OpSupportGPU);
#ifdef PADDLE_WITH_CUDA
m.def("get_cuda_device_count", platform::GetCUDADeviceCount);
#ifndef _WIN32
m.def("nvprof_init", platform::CudaProfilerInit);
m.def("nvprof_start", platform::CudaProfilerStart);
m.def("nvprof_stop", platform::CudaProfilerStop);
#endif
#endif
py::enum_<platform::ProfilerState>(m, "ProfilerState", py::arithmetic())
.value("kDisabled", platform::ProfilerState::kDisabled)
.value("kCPU", platform::ProfilerState::kCPU)
.value("kCUDA", platform::ProfilerState::kCUDA)
.value("kAll", platform::ProfilerState::kAll)
.export_values();
py::enum_<platform::EventSortingKey>(m, "EventSortingKey", py::arithmetic())
.value("kDefault", platform::EventSortingKey::kDefault)
.value("kCalls", platform::EventSortingKey::kCalls)
.value("kTotal", platform::EventSortingKey::kTotal)
.value("kMin", platform::EventSortingKey::kMin)
.value("kMax", platform::EventSortingKey::kMax)
.value("kAve", platform::EventSortingKey::kAve)
.export_values();
m.def("enable_profiler", platform::EnableProfiler);
m.def("disable_profiler", platform::DisableProfiler);
m.def("is_profiler_enabled", platform::IsProfileEnabled);
m.def("reset_profiler", platform::ResetProfiler);
m.def("get_pass", [](const std::string &pass_type) {
auto pass = framework::ir::PassRegistry::Instance().Get(pass_type);
return std::shared_ptr<framework::ir::Pass>(std::move(pass));
});
m.def("size_of_dtype", framework::SizeOfType);
using VarQuantScale =
std::unordered_map<std::string, std::pair<bool, LoDTensor>>;
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init())
.def("has", &ir::Pass::Has)
.def("set_not_owned",
[](ir::Pass &self, const std::string &attr_name, ProgramDesc &attr) {
self.SetNotOwned<ProgramDesc>(attr_name, &attr);
})
.def(
"set",
[](ir::Pass &self, const std::string &name, const std::string &attr) {
self.Set<std::string>(name, new std::string(attr));
})
.def("set", [](ir::Pass &self, const std::string &name,
int val) { self.Set<const int>(name, new int(val)); })
.def("set",
[](ir::Pass &self, const std::string &name,
std::unordered_set<std::string> set) {
self.Set(name, new std::unordered_set<std::string>(set));
})
.def("set",
[](ir::Pass &self, const std::string &name,
std::unordered_set<int> set) {
self.Set(name, new std::unordered_set<int>(set));
})
.def("set",
[](ir::Pass &self, const std::string &name, VarQuantScale scales) {
self.Set(name, new VarQuantScale(scales));
})
.def("type", &ir::Pass::Type)
.def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {
self.Apply(graph.get());
});
py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb(
m, "PassBuilder");
pb.def(py::init())
.def("append_pass",
[](ir::PassBuilder &self,
const std::string &pass_type) -> std::shared_ptr<ir::Pass> {
return self.AppendPass(pass_type);
})
.def("all_passes", [](ir::PassBuilder &self) { return self.AllPasses(); })
.def("insert_pass",
[](ir::PassBuilder &self, size_t idx, const std::string &pass_type) {
return self.InsertPass(idx, pass_type);
})
.def("remove_pass",
[](ir::PassBuilder &self, size_t idx) { self.RemovePass(idx); });
// -- python binds for parallel executor.
py::class_<ParallelExecutor> pe(m, "ParallelExecutor");
py::class_<ExecutionStrategy> exec_strategy(pe, "ExecutionStrategy", R"DOC(
ExecutionStrategy allows the user to more precisely control how to run
the program in ParallelExecutor by setting the property.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_loss = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_loss)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 4
train_exe = fluid.ParallelExecutor(use_cuda=False,
loss_name=avg_loss.name,
exec_strategy=exec_strategy)
)DOC");
exec_strategy.def(py::init())
.def_property(
"num_threads",
[](const ExecutionStrategy &self) { return self.num_threads_; },
[](ExecutionStrategy &self, size_t num_threads) {
self.num_threads_ = num_threads;
},
R"DOC(The type is INT, num_threads represents the size of thread pool that
used to run the operators of the current program in ParallelExecutor.
If :math:`num\_threads=1`, all the operators will execute one by one,
but the order maybe difference between iterations.
If it is not set, it will be set in ParallelExecutor according to the
device type and device count, for GPU, :math:`num\_threads=device\_count*4`, for CPU,
:math:`num\_threads=CPU\_NUM*4`, the explanation of:math:`CPU\_NUM` is in ParallelExecutor.
if it is not set, ParallelExecutor will get the cpu count by calling
`multiprocessing.cpu_count()`. Default 0.)DOC")
.def_property(
"use_cuda",
[](const ExecutionStrategy &self) { return self.use_cuda_; },
[](ExecutionStrategy &self, bool use_cuda) {
self.use_cuda_ = use_cuda;
}) // FIXME(chengduo): Doesn't add doc for 'use_cuda', use_cuda may
// make user confuse, because ParallelExecutor has a parameter named
// 'use_cuda' too, in current implementation, ParallelExecutor's
// 'use_cuda' will rewrite ExecutionStrategy's 'use_cuda'.
.def_property(
"allow_op_delay",
[](const ExecutionStrategy &self) { return self.allow_op_delay_; },
[](ExecutionStrategy &self, bool allow_op_delay) {
self.allow_op_delay_ = allow_op_delay;
},
R"DOC(The type is BOOL, allow_op_delay represents whether to delay the
communication operators to run, it may make the execution faster.
Note that this option is invalid now, and it will be removed in
next version. Default False.)DOC")
.def_property(
"num_iteration_per_drop_scope",
[](const ExecutionStrategy &self) {
return self.num_iteration_per_drop_scope_;
},
[](ExecutionStrategy &self, size_t num_iteration_per_drop_scope) {
self.num_iteration_per_drop_scope_ = num_iteration_per_drop_scope;
},
R"DOC(The type is INT, num_iteration_per_drop_scope indicates how
many iterations to clean up the temp variables which
is generated during execution. It may make the execution faster,
because the temp variable's shape maybe the same between two iterations.
Default 1.
NOTES:
1. If you fetch data when calling the 'run', the ParallelExecutor
will clean up the temp variables at the end of the current iteration.
2. In some NLP model, it may cause the GPU memory is insufficient,
in this case, you should reduce `num_iteration_per_drop_scope`.
)DOC")
.def_property(
"num_iteration_per_run",
[](const ExecutionStrategy &self) {
return self.num_iteration_per_run_;
},
[](ExecutionStrategy &self, size_t num_iteration_per_run) {
self.num_iteration_per_run_ = num_iteration_per_run;
},
R"DOC(This config that how many iteration the executor will run when
user call pe.run() in python
)DOC")
.def_property("_dry_run",
[](const ExecutionStrategy &self) { return self.dry_run_; },
[](ExecutionStrategy &self, bool dry_run) {
self.dry_run_ = dry_run;
});
exec_strategy.def_property(
"use_experimental_executor",
[](const ExecutionStrategy &self) {
return self.type_ == ExecutionStrategy::kExperimental;
},
[](ExecutionStrategy &self, bool experimental) {
self.type_ = experimental ? ExecutionStrategy::kExperimental
: ExecutionStrategy::kDefault;
});
py::class_<BuildStrategy> build_strategy(pe, "BuildStrategy", R"DOC(
BuildStrategy allows the user to more precisely control how to
build the SSA Graph in ParallelExecutor by setting the property.
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
)DOC");
py::enum_<BuildStrategy::ReduceStrategy>(build_strategy, "ReduceStrategy")
.value("Reduce", BuildStrategy::ReduceStrategy::kReduce)
.value("AllReduce", BuildStrategy::ReduceStrategy::kAllReduce);
py::enum_<BuildStrategy::GradientScaleStrategy>(build_strategy,
"GradientScaleStrategy")
.value("CoeffNumDevice",
BuildStrategy::GradientScaleStrategy::kCoeffNumDevice)
.value("One", BuildStrategy::GradientScaleStrategy::kOne)
.value("Customized", BuildStrategy::GradientScaleStrategy::kCustomized);
build_strategy.def(py::init())
.def_property(
"reduce_strategy",
[](const BuildStrategy &self) { return self.reduce_; },
[](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) {
PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
"BuildStrategy is finlaized.");
self.reduce_ = strategy;
},
R"DOC(The type is fluid.BuildStrategy.ReduceStrategy, there are two reduce
strategies in ParallelExecutor, AllReduce and Reduce. If you want
all the parameters' optimization to be done on all devices independently,
you should choose AllReduce; if you choose Reduce, all the parameters'
optimization will be evenly distributed to different devices, and the
optimized parameters will then be broadcast to the other devices.
Default 'AllReduce'.
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
)DOC")
.def_property(
"gradient_scale_strategy",
[](const BuildStrategy &self) { return self.gradient_scale_; },
[](BuildStrategy &self,
BuildStrategy::GradientScaleStrategy strategy) {
PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
"BuildStrategy is finalized.");
self.gradient_scale_ = strategy;
},
R"DOC(The type is fluid.BuildStrategy.GradientScaleStrategy, there are three
ways of defining :math:`loss@grad` in ParallelExecutor, CoeffNumDevice,
One and Customized. By default, ParallelExecutor sets the :math:`loss@grad`
according to the number of devices. If you want to customize :math:`loss@grad`,
you can choose Customized. Default 'CoeffNumDevice'.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid.compiler as compiler
import numpy
import os
use_cuda = True
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
# NOTE: If you use CPU to run the program, you need
# to specify the CPU_NUM, otherwise fluid will use
# the number of logical cores as the CPU_NUM;
# in that case, the batch size of the input should be
# greater than CPU_NUM, otherwise the process will
# fail with an exception.
if not use_cuda:
os.environ['CPU_NUM'] = str(2)
places = fluid.cpu_places()
else:
places = fluid.cuda_places()
data = fluid.layers.data(name='X', shape=[1], dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
loss = fluid.layers.mean(hidden)
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)
fluid.default_startup_program().random_seed=1
exe.run(fluid.default_startup_program())
build_strategy = fluid.BuildStrategy()
build_strategy.gradient_scale_strategy = \
fluid.BuildStrategy.GradientScaleStrategy.Customized
compiled_prog = compiler.CompiledProgram(
fluid.default_main_program()).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy,
places = places)
dev_count = len(places)
x = numpy.random.random(size=(10, 1)).astype('float32')
loss_grad = numpy.ones((dev_count)).astype("float32") * 0.01
loss_grad_name = loss.name+"@GRAD"
loss_data = exe.run(compiled_prog,
feed={"X": x, loss_grad_name : loss_grad},
fetch_list=[loss.name, loss_grad_name])
)DOC")
.def_property(
"debug_graphviz_path",
[](const BuildStrategy &self) { return self.debug_graphviz_path_; },
[](BuildStrategy &self, const std::string &path) {
PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
"BuildStrategy is finlaized.");
self.debug_graphviz_path_ = path;
},
R"DOC(The type is STR, debug_graphviz_path indicates the path that
writing the SSA Graph to file in the form of graphviz.
It is useful for debugging. Default ""
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.debug_graphviz_path = "./graph"
)DOC")
.def_property(
"enable_sequential_execution",
[](const BuildStrategy &self) {
return self.enable_sequential_execution_;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
"BuildStrategy is finlaized.");
self.enable_sequential_execution_ = b;
},
R"DOC(The type is BOOL. If set True, the execution order of ops would
be the same as what is in the program. Default False.
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.enable_sequential_execution = True
)DOC")
.def_property(
"remove_unnecessary_lock",
[](const BuildStrategy &self) {
return self.remove_unnecessary_lock_;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
"BuildStrategy is finlaized.");
self.remove_unnecessary_lock_ = b;
},
R"DOC(The type is BOOL. If set True, some locks in GPU ops would be
released and ParallelExecutor would run faster. Default True.
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.remove_unnecessary_lock = True
)DOC")
.def_property(
"num_trainers",
[](const BuildStrategy &self) { return self.num_trainers_; },
[](BuildStrategy &self, int num_trainers) {
#ifdef WIN32
PADDLE_THROW("Windows has NO support to distribute mode.");
#endif
self.num_trainers_ = num_trainers;
})
.def_property(
"trainers_endpoints",
[](const BuildStrategy &self) { return self.trainers_endpoints_; },
[](BuildStrategy &self,
const std::vector<std::string> &trainers_endpoints) {
self.trainers_endpoints_ = trainers_endpoints;
})
.def_property("trainer_id",
[](const BuildStrategy &self) { return self.trainer_id_; },
[](BuildStrategy &self, int trainer_id) {
self.trainer_id_ = trainer_id;
})
.def_property(
"nccl_comm_num",
[](const BuildStrategy &self) { return self.nccl_comm_num_; },
[](BuildStrategy &self, int nccl_comm_num) {
self.nccl_comm_num_ = nccl_comm_num;
})
.def_property("use_hierarchical_allreduce",
[](const BuildStrategy &self) {
return self.use_hierarchical_allreduce_;
},
[](BuildStrategy &self, bool use) {
self.use_hierarchical_allreduce_ = use;
})
.def_property("hierarchical_allreduce_inter_nranks",
[](const BuildStrategy &self) {
return self.hierarchical_allreduce_inter_nranks_;
},
[](BuildStrategy &self, int nranks) {
self.hierarchical_allreduce_inter_nranks_ = nranks;
})
.def_property(
"fuse_elewise_add_act_ops",
[](const BuildStrategy &self) {
return self.fuse_elewise_add_act_ops_;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
"BuildStrategy is finlaized.");
self.fuse_elewise_add_act_ops_ = b;
},
R"DOC(The type is BOOL, fuse_elewise_add_act_ops indicate whether
to fuse elementwise_add_op and activation_op,
it may make the execution faster. Default False
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_elewise_add_act_ops = True
)DOC")
.def_property(
"fuse_relu_depthwise_conv",
[](const BuildStrategy &self) {
return self.fuse_relu_depthwise_conv_;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
"BuildStrategy is finlaized.");
self.fuse_relu_depthwise_conv_ = b;
},
R"DOC(The type is BOOL, fuse_relu_depthwise_conv indicate whether
to fuse relu and depthwise_conv2d,
it will save GPU memory and may make the execution faster.
This options is only available in GPU devices.
Default False.
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_relu_depthwise_conv = True
)DOC")
.def_property("fuse_broadcast_ops",
[](const BuildStrategy &self) {
return self.fuse_broadcast_ops_ == true ||
self.fuse_broadcast_ops_ == boost::none;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
"BuildStrategy is finlaized.");
self.fuse_broadcast_ops_ = b;
},
R"DOC(The type is BOOL, fuse_broadcast_op indicates whether
to fuse the broadcast ops. Note that, in Reduce mode,
fusing broadcast ops may make the program faster. Because
fusing broadcast OP equals delaying the execution of all
broadcast Ops, in this case, all nccl streams are used only
for NCCLReduce operations for a period of time. Default False.)DOC")
.def_property("fuse_all_optimizer_ops",
[](const BuildStrategy &self) {
return self.fuse_all_optimizer_ops_ == true ||
self.fuse_all_optimizer_ops_ == boost::none;
},
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
"BuildStrategy is finlaized.");
self.fuse_all_optimizer_ops_ = b;
})
.def_property(
"sync_batch_norm",
[](const BuildStrategy &self) { return self.sync_batch_norm_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_EQ(!self.IsFinalized(), true,
"BuildStrategy is finlaized.");
self.sync_batch_norm_ = b;
},
R"DOC(The type is BOOL, sync_batch_norm indicates whether to use
synchronous batch normalization which synchronizes the mean
and variance through multi-devices in training phase.
Current implementation doesn't support FP16 training and CPU.
And only synchronous on one machine, not all machines.
Default False
Examples:
.. code-block:: python
import paddle.fluid as fluid
build_strategy = fluid.BuildStrategy()
build_strategy.sync_batch_norm = True
)DOC")
.def_property(
"memory_optimize",
[](const BuildStrategy &self) -> py::object {
if (self.memory_optimize_) {
return py::cast(self.memory_optimize_.get());
} else {
return py::cast(nullptr);
}
},
[](BuildStrategy &self, const py::handle &value) {
auto *py_obj = value.ptr();
if (py_obj == nullptr || py_obj == Py_None) {
self.memory_optimize_ = boost::none;
} else if (PyBool_Check(py_obj)) {
self.memory_optimize_ = (py_obj == Py_True);
} else {
PADDLE_THROW(
"BuildStrategy.memory_optimize must be None, False or True");
}
},
R"DOC(The type is BOOL or None, memory opitimize aims to save total memory
consumption, set to True to enable it.
Default None. None means framework would choose to use or not use
this strategy automatically. Currently, None means that it is
enabled when GC is disabled, and disabled when GC is enabled.
True means enabling and False means disabling. Default None.)DOC")
.def_property(
"is_distribution",
[](const BuildStrategy &self) { return self.is_distribution_; },
[](BuildStrategy &self, bool b) {
#ifdef WIN32
if (b) {
PADDLE_THROW("Windows has NO support to distribute mode.");
}
#else
self.is_distribution_ = b;
#endif
})
.def_property("async_mode",
[](const BuildStrategy &self) { return self.async_mode_; },
[](BuildStrategy &self, bool b) { self.async_mode_ = b; })
.def_property(
"enable_inplace",
[](const BuildStrategy &self) { return self.enable_inplace_; },
[](BuildStrategy &self, bool b) { self.enable_inplace_ = b; })
.def_property(
"fuse_all_reduce_ops",
[](const BuildStrategy &self) {
return self.fuse_all_reduce_ops_ == true ||
self.fuse_all_reduce_ops_ == boost::none;
},
[](BuildStrategy &self, bool b) { self.fuse_all_reduce_ops_ = b; })
.def_property("enable_backward_optimizer_op_deps",
[](const BuildStrategy &self) {
return self.enable_backward_optimizer_op_deps_;
},
[](BuildStrategy &self, bool b) {
self.enable_backward_optimizer_op_deps_ = b;
})
.def_property(
"cache_runtime_context",
[](const BuildStrategy &self) { return self.cache_runtime_context_; },
[](BuildStrategy &self, bool b) { self.cache_runtime_context_ = b; })
.def_property(
"mkldnn_enabled_op_types",
[](const BuildStrategy &self) {
return self.mkldnn_enabled_op_types_;
},
[](BuildStrategy &self,
const std::unordered_set<std::string> &mkldnn_enabled_op_types) {
self.mkldnn_enabled_op_types_ = mkldnn_enabled_op_types;
})
.def("_finalize_strategy_and_create_passes",
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
return self.CreatePassesFromStrategy(true);
},
R"DOC(Allow user to customized passes. Normally model-specific
optimization passes should be defined in this way. BuildStrategy
cannot be updated after being finalized.)DOC");
pe.def(py::init<const std::vector<platform::Place> &,
const std::vector<std::string> &, const std::string &,
Scope *, std::vector<Scope *> &, const ExecutionStrategy &,
const BuildStrategy &, ir::Graph *>())
// NOTE: even though we return a vec<Scope*>* to Python with the reference
// policy, we still cannot get local_scope from this vector, since the
// elements of vec<Scope*> would be freed by the Python GC. We can only
// return Scope* one by one and mark them as references.
.def("local_scopes",
[](ParallelExecutor &self) -> std::vector<Scope *> * {
return &self.GetLocalScopes();
},
py::return_value_policy::reference)
.def("drop_local_exe_scopes", &ParallelExecutor::DropLocalExeScopes)
.def("_need_create_local_exe_scopes",
&ParallelExecutor::NeedCreateLocalExeScope)
.def("feed_tensors_into_local_scopes",
&ParallelExecutor::FeedTensorsIntoLocalScopes)
.def("feed_and_split_tensor_into_local_scopes",
&ParallelExecutor::FeedAndSplitTensorIntoLocalScopes)
.def("run", [](ParallelExecutor &self,
const std::vector<std::string> &fetch_tensors) {
pybind11::gil_scoped_release release;
return self.Run(fetch_tensors);
});
BindFleetWrapper(&m);
BindBoxHelper(&m);
#ifndef _WIN32
BindNCCLWrapper(&m);
#endif
BindGraph(&m);
BindNode(&m);
BindInferenceApi(&m);
BindExpandApi(&m);
BindDataset(&m);
#ifdef PADDLE_WITH_DISTRIBUTE
BindCommunicator(&m);
#endif
}
} // namespace pybind
} // namespace paddle
add_subdirectory(common)
add_subdirectory(data_reader)
cc_library(dict_plugin SRCS dict_plugin.cc DEPS glog boost fs)
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_BHOPSCOTCH_MAP_H
#define TSL_BHOPSCOTCH_MAP_H
#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <map>
#include <memory>
#include <type_traits>
#include <utility>
#include "paddle/fluid/feed/src/common/hopscotch_hash.h"
namespace tsl {
/**
* Similar to tsl::hopscotch_map but instead of using a list for overflowing elements it uses
* a binary search tree. It thus needs an additional template parameter Compare. Compare should
* be arithmetically coherent with KeyEqual.
*
* The binary search tree allows the map to have a worst-case scenario of O(log n) for search
* and delete, even if the hash function maps all the elements to the same bucket.
* For insert, the amortized worst case is O(log n), but the worst case is O(n) in case of rehash.
*
* This makes the map resistant to DoS attacks (but it doesn't remove the need for a good hash
* function, as an element in the bucket array is faster to retrieve than one in the tree).
*
* @copydoc hopscotch_map
*/
template<class Key,
class T,
class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Compare = std::less<Key>,
class Allocator = std::allocator<std::pair<const Key, T>>,
unsigned int NeighborhoodSize = 62,
bool StoreHash = false,
class GrowthPolicy = tsl::hh::power_of_two_growth_policy<2>>
class bhopscotch_map {
private:
template<typename U>
using has_is_transparent = tsl::detail_hopscotch_hash::has_is_transparent<U>;
class KeySelect {
public:
using key_type = Key;
const key_type& operator()(const std::pair<const Key, T>& key_value) const {
return key_value.first;
}
const key_type& operator()(std::pair<const Key, T>& key_value) {
return key_value.first;
}
};
class ValueSelect {
public:
using value_type = T;
const value_type& operator()(const std::pair<const Key, T>& key_value) const {
return key_value.second;
}
value_type& operator()(std::pair<Key, T>& key_value) {
return key_value.second;
}
};
// TODO: Not optimal, as we have to use std::pair<const Key, T> as ValueType, which forbids
// us from moving the key into the bucket array; we have to use a copy. Optimize.
using overflow_container_type = std::map<Key, T, Compare, Allocator>;
using ht = detail_hopscotch_hash::hopscotch_hash<std::pair<const Key, T>, KeySelect, ValueSelect,
Hash, KeyEqual,
Allocator, NeighborhoodSize,
StoreHash, GrowthPolicy,
overflow_container_type>;
public:
using key_type = typename ht::key_type;
using mapped_type = T;
using value_type = typename ht::value_type;
using size_type = typename ht::size_type;
using difference_type = typename ht::difference_type;
using hasher = typename ht::hasher;
using key_equal = typename ht::key_equal;
using key_compare = Compare;
using allocator_type = typename ht::allocator_type;
using reference = typename ht::reference;
using const_reference = typename ht::const_reference;
using pointer = typename ht::pointer;
using const_pointer = typename ht::const_pointer;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
/*
* Constructors
*/
bhopscotch_map() : bhopscotch_map(ht::DEFAULT_INIT_BUCKETS_SIZE) {
}
explicit bhopscotch_map(size_type bucket_count,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator(),
const Compare& comp = Compare()) :
m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR, comp)
{
}
bhopscotch_map(size_type bucket_count,
const Allocator& alloc) : bhopscotch_map(bucket_count, Hash(), KeyEqual(), alloc)
{
}
bhopscotch_map(size_type bucket_count,
const Hash& hash,
const Allocator& alloc) : bhopscotch_map(bucket_count, hash, KeyEqual(), alloc)
{
}
explicit bhopscotch_map(const Allocator& alloc) : bhopscotch_map(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {
}
template<class InputIt>
bhopscotch_map(InputIt first, InputIt last,
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()) : bhopscotch_map(bucket_count, hash, equal, alloc)
{
insert(first, last);
}
template<class InputIt>
bhopscotch_map(InputIt first, InputIt last,
size_type bucket_count,
const Allocator& alloc) : bhopscotch_map(first, last, bucket_count, Hash(), KeyEqual(), alloc)
{
}
template<class InputIt>
bhopscotch_map(InputIt first, InputIt last,
size_type bucket_count,
const Hash& hash,
const Allocator& alloc) : bhopscotch_map(first, last, bucket_count, hash, KeyEqual(), alloc)
{
}
bhopscotch_map(std::initializer_list<value_type> init,
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()) :
bhopscotch_map(init.begin(), init.end(), bucket_count, hash, equal, alloc)
{
}
bhopscotch_map(std::initializer_list<value_type> init,
size_type bucket_count,
const Allocator& alloc) :
bhopscotch_map(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc)
{
}
bhopscotch_map(std::initializer_list<value_type> init,
size_type bucket_count,
const Hash& hash,
const Allocator& alloc) :
bhopscotch_map(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc)
{
}
bhopscotch_map& operator=(std::initializer_list<value_type> ilist) {
m_ht.clear();
m_ht.reserve(ilist.size());
m_ht.insert(ilist.begin(), ilist.end());
return *this;
}
allocator_type get_allocator() const { return m_ht.get_allocator(); }
/*
* Iterators
*/
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
* Capacity
*/
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
/*
* Modifiers
*/
void clear() noexcept { m_ht.clear(); }
std::pair<iterator, bool> insert(const value_type& value) {
return m_ht.insert(value);
}
template<class P, typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
std::pair<iterator, bool> insert(P&& value) {
return m_ht.insert(std::forward<P>(value));
}
std::pair<iterator, bool> insert(value_type&& value) {
return m_ht.insert(std::move(value));
}
iterator insert(const_iterator hint, const value_type& value) {
return m_ht.insert(hint, value);
}
template<class P, typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
iterator insert(const_iterator hint, P&& value) {
return m_ht.insert(hint, std::forward<P>(value));
}
iterator insert(const_iterator hint, value_type&& value) {
return m_ht.insert(hint, std::move(value));
}
template<class InputIt>
void insert(InputIt first, InputIt last) {
m_ht.insert(first, last);
}
void insert(std::initializer_list<value_type> ilist) {
m_ht.insert(ilist.begin(), ilist.end());
}
template<class M>
std::pair<iterator, bool> insert_or_assign(const key_type& k, M&& obj) {
return m_ht.insert_or_assign(k, std::forward<M>(obj));
}
template<class M>
std::pair<iterator, bool> insert_or_assign(key_type&& k, M&& obj) {
return m_ht.insert_or_assign(std::move(k), std::forward<M>(obj));
}
template<class M>
iterator insert_or_assign(const_iterator hint, const key_type& k, M&& obj) {
return m_ht.insert_or_assign(hint, k, std::forward<M>(obj));
}
template<class M>
iterator insert_or_assign(const_iterator hint, key_type&& k, M&& obj) {
return m_ht.insert_or_assign(hint, std::move(k), std::forward<M>(obj));
}
/**
* Due to the way elements are stored, emplace will need to move or copy the key-value once.
* The method is equivalent to insert(value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_map interface.
*/
template<class... Args>
std::pair<iterator, bool> emplace(Args&&... args) {
return m_ht.emplace(std::forward<Args>(args)...);
}
/**
* Due to the way elements are stored, emplace_hint will need to move or copy the key-value once.
* The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_map interface.
*/
template<class... Args>
iterator emplace_hint(const_iterator hint, Args&&... args) {
return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
}
template<class... Args>
std::pair<iterator, bool> try_emplace(const key_type& k, Args&&... args) {
return m_ht.try_emplace(k, std::forward<Args>(args)...);
}
template<class... Args>
std::pair<iterator, bool> try_emplace(key_type&& k, Args&&... args) {
return m_ht.try_emplace(std::move(k), std::forward<Args>(args)...);
}
template<class... Args>
iterator try_emplace(const_iterator hint, const key_type& k, Args&&... args) {
return m_ht.try_emplace(hint, k, std::forward<Args>(args)...);
}
template<class... Args>
iterator try_emplace(const_iterator hint, key_type&& k, Args&&... args) {
return m_ht.try_emplace(hint, std::move(k), std::forward<Args>(args)...);
}
iterator erase(iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); }
size_type erase(const key_type& key) { return m_ht.erase(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup of the value if you already have the hash.
*/
size_type erase(const key_type& key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
* and Compare::is_transparent exist.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
size_type erase(const K& key) { return m_ht.erase(key); }
/**
* @copydoc erase(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup of the value if you already have the hash.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
size_type erase(const K& key, std::size_t precalculated_hash) { return m_ht.erase(key, precalculated_hash); }
void swap(bhopscotch_map& other) { other.m_ht.swap(m_ht); }
/*
* Lookup
*/
T& at(const Key& key) { return m_ht.at(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
T& at(const Key& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); }
const T& at(const Key& key) const { return m_ht.at(key); }
/**
* @copydoc at(const Key& key, std::size_t precalculated_hash)
*/
const T& at(const Key& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); }
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
* and Compare::is_transparent exist.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
T& at(const K& key) { return m_ht.at(key); }
/**
* @copydoc at(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
T& at(const K& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); }
/**
* @copydoc at(const K& key)
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
const T& at(const K& key) const { return m_ht.at(key); }
/**
* @copydoc at(const K& key, std::size_t precalculated_hash)
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
const T& at(const K& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); }
T& operator[](const Key& key) { return m_ht[key]; }
T& operator[](Key&& key) { return m_ht[std::move(key)]; }
size_type count(const Key& key) const { return m_ht.count(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
size_type count(const Key& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
* and Compare::is_transparent exist.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
size_type count(const K& key) const { return m_ht.count(key); }
/**
* @copydoc count(const K& key) const
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
size_type count(const K& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
iterator find(const Key& key) { return m_ht.find(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
const_iterator find(const Key& key) const { return m_ht.find(key); }
/**
* @copydoc find(const Key& key, std::size_t precalculated_hash)
*/
const_iterator find(const Key& key, std::size_t precalculated_hash) const { return m_ht.find(key, precalculated_hash); }
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
* and Compare::is_transparent exist.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
iterator find(const K& key) { return m_ht.find(key); }
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
/**
* @copydoc find(const K& key)
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
const_iterator find(const K& key) const { return m_ht.find(key); }
/**
* @copydoc find(const K& key, std::size_t precalculated_hash)
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
const_iterator find(const K& key, std::size_t precalculated_hash) const { return m_ht.find(key, precalculated_hash); }
std::pair<iterator, iterator> equal_range(const Key& key) { return m_ht.equal_range(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
std::pair<iterator, iterator> equal_range(const Key& key, std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
std::pair<const_iterator, const_iterator> equal_range(const Key& key) const { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range(const Key& key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
* and Compare::is_transparent exist.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key) { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key, std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* @copydoc equal_range(const K& key)
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key) const { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const K& key, std::size_t precalculated_hash)
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/*
* Bucket interface
*/
size_type bucket_count() const { return m_ht.bucket_count(); }
size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
/*
* Hash policy
*/
float load_factor() const { return m_ht.load_factor(); }
float max_load_factor() const { return m_ht.max_load_factor(); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
void rehash(size_type count_) { m_ht.rehash(count_); }
void reserve(size_type count_) { m_ht.reserve(count_); }
/*
* Observers
*/
hasher hash_function() const { return m_ht.hash_function(); }
key_equal key_eq() const { return m_ht.key_eq(); }
key_compare key_comp() const { return m_ht.key_comp(); }
/*
* Other
*/
/**
* Convert a const_iterator to an iterator.
*/
iterator mutable_iterator(const_iterator pos) {
return m_ht.mutable_iterator(pos);
}
size_type overflow_size() const noexcept { return m_ht.overflow_size(); }
friend bool operator==(const bhopscotch_map& lhs, const bhopscotch_map& rhs) {
if(lhs.size() != rhs.size()) {
return false;
}
for(const auto& element_lhs : lhs) {
const auto it_element_rhs = rhs.find(element_lhs.first);
if(it_element_rhs == rhs.cend() || element_lhs.second != it_element_rhs->second) {
return false;
}
}
return true;
}
friend bool operator!=(const bhopscotch_map& lhs, const bhopscotch_map& rhs) {
return !operator==(lhs, rhs);
}
friend void swap(bhopscotch_map& lhs, bhopscotch_map& rhs) {
lhs.swap(rhs);
}
private:
ht m_ht;
};
/**
* Same as `tsl::bhopscotch_map<Key, T, Hash, KeyEqual, Compare, Allocator, NeighborhoodSize, StoreHash, tsl::hh::prime_growth_policy>`.
*/
template<class Key,
class T,
class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Compare = std::less<Key>,
class Allocator = std::allocator<std::pair<const Key, T>>,
unsigned int NeighborhoodSize = 62,
bool StoreHash = false>
using bhopscotch_pg_map = bhopscotch_map<Key, T, Hash, KeyEqual, Compare, Allocator, NeighborhoodSize, StoreHash, tsl::hh::prime_growth_policy>;
} // end namespace tsl
#endif
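// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the vendored header above): how the
// feed code can use tsl::bhopscotch_map, including the precalculated_hash
// lookup overloads documented above. All names below are hypothetical.
// ---------------------------------------------------------------------------
#include <cstddef>
#include <cstdint>
#include <iostream>
#include "paddle/fluid/feed/src/common/bhopscotch_map.h"
inline void example_bhopscotch_map_usage() {
  tsl::bhopscotch_map<std::uint64_t, std::uint32_t> dict;
  dict.emplace(42u, 1u);  // insert a key-value pair
  dict[7u] = 2u;          // or insert/overwrite through operator[]
  // Compute the hash once and reuse it for several lookups on the same key.
  const std::uint64_t key = 42u;
  const std::size_t h = dict.hash_function()(key);
  if (dict.count(key, h) != 0) {
    std::cout << "42 -> " << dict.at(key, h) << std::endl;  // prints "42 -> 1"
  }
}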
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_BHOPSCOTCH_SET_H
#define TSL_BHOPSCOTCH_SET_H
#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <memory>
#include <set>
#include <type_traits>
#include <utility>
#include "paddle/fluid/feed/src/common/hopscotch_hash.h"
namespace tsl {
/**
* Similar to tsl::hopscotch_set but instead of using a list for overflowing elements it uses
* a binary search tree. It thus needs an additional template parameter Compare. Compare should
* be arithmetically coherent with KeyEqual.
*
* The binary search tree allows the set to have a worst-case scenario of O(log n) for search
* and delete, even if the hash function maps all the elements to the same bucket.
* For insert, the amortized worst case is O(log n), but the worst case is O(n) in case of rehash.
*
* This makes the set resistant to DoS attacks (but it doesn't remove the need for a good hash
* function, as an element in the bucket array is faster to retrieve than one in the tree).
*
* @copydoc hopscotch_set
*/
template<class Key,
class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Compare = std::less<Key>,
class Allocator = std::allocator<Key>,
unsigned int NeighborhoodSize = 62,
bool StoreHash = false,
class GrowthPolicy = tsl::hh::power_of_two_growth_policy<2>>
class bhopscotch_set {
private:
template<typename U>
using has_is_transparent = tsl::detail_hopscotch_hash::has_is_transparent<U>;
class KeySelect {
public:
using key_type = Key;
const key_type& operator()(const Key& key) const {
return key;
}
key_type& operator()(Key& key) {
return key;
}
};
using overflow_container_type = std::set<Key, Compare, Allocator>;
using ht = tsl::detail_hopscotch_hash::hopscotch_hash<Key, KeySelect, void,
Hash, KeyEqual,
Allocator, NeighborhoodSize,
StoreHash, GrowthPolicy,
overflow_container_type>;
public:
using key_type = typename ht::key_type;
using value_type = typename ht::value_type;
using size_type = typename ht::size_type;
using difference_type = typename ht::difference_type;
using hasher = typename ht::hasher;
using key_equal = typename ht::key_equal;
using key_compare = Compare;
using allocator_type = typename ht::allocator_type;
using reference = typename ht::reference;
using const_reference = typename ht::const_reference;
using pointer = typename ht::pointer;
using const_pointer = typename ht::const_pointer;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
/*
* Constructors
*/
bhopscotch_set() : bhopscotch_set(ht::DEFAULT_INIT_BUCKETS_SIZE) {
}
explicit bhopscotch_set(size_type bucket_count,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator(),
const Compare& comp = Compare()) :
m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR, comp)
{
}
bhopscotch_set(size_type bucket_count,
const Allocator& alloc) : bhopscotch_set(bucket_count, Hash(), KeyEqual(), alloc)
{
}
bhopscotch_set(size_type bucket_count,
const Hash& hash,
const Allocator& alloc) : bhopscotch_set(bucket_count, hash, KeyEqual(), alloc)
{
}
explicit bhopscotch_set(const Allocator& alloc) : bhopscotch_set(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {
}
template<class InputIt>
bhopscotch_set(InputIt first, InputIt last,
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()) : bhopscotch_set(bucket_count, hash, equal, alloc)
{
insert(first, last);
}
template<class InputIt>
bhopscotch_set(InputIt first, InputIt last,
size_type bucket_count,
const Allocator& alloc) : bhopscotch_set(first, last, bucket_count, Hash(), KeyEqual(), alloc)
{
}
template<class InputIt>
bhopscotch_set(InputIt first, InputIt last,
size_type bucket_count,
const Hash& hash,
const Allocator& alloc) : bhopscotch_set(first, last, bucket_count, hash, KeyEqual(), alloc)
{
}
bhopscotch_set(std::initializer_list<value_type> init,
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()) :
bhopscotch_set(init.begin(), init.end(), bucket_count, hash, equal, alloc)
{
}
bhopscotch_set(std::initializer_list<value_type> init,
size_type bucket_count,
const Allocator& alloc) :
bhopscotch_set(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc)
{
}
bhopscotch_set(std::initializer_list<value_type> init,
size_type bucket_count,
const Hash& hash,
const Allocator& alloc) :
bhopscotch_set(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc)
{
}
bhopscotch_set& operator=(std::initializer_list<value_type> ilist) {
m_ht.clear();
m_ht.reserve(ilist.size());
m_ht.insert(ilist.begin(), ilist.end());
return *this;
}
allocator_type get_allocator() const { return m_ht.get_allocator(); }
/*
* Iterators
*/
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
* Capacity
*/
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
/*
* Modifiers
*/
void clear() noexcept { m_ht.clear(); }
std::pair<iterator, bool> insert(const value_type& value) { return m_ht.insert(value); }
std::pair<iterator, bool> insert(value_type&& value) { return m_ht.insert(std::move(value)); }
iterator insert(const_iterator hint, const value_type& value) { return m_ht.insert(hint, value); }
iterator insert(const_iterator hint, value_type&& value) { return m_ht.insert(hint, std::move(value)); }
template<class InputIt>
void insert(InputIt first, InputIt last) { m_ht.insert(first, last); }
void insert(std::initializer_list<value_type> ilist) { m_ht.insert(ilist.begin(), ilist.end()); }
/**
* Due to the way elements are stored, emplace will need to move or copy the key once.
* The method is equivalent to insert(value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_set interface.
*/
template<class... Args>
std::pair<iterator, bool> emplace(Args&&... args) { return m_ht.emplace(std::forward<Args>(args)...); }
/**
* Due to the way elements are stored, emplace_hint will need to move or copy the key once.
* The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_set interface.
*/
template<class... Args>
iterator emplace_hint(const_iterator hint, Args&&... args) {
return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
}
iterator erase(iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); }
size_type erase(const key_type& key) { return m_ht.erase(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup of the value if you already have the hash.
*/
size_type erase(const key_type& key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
* and Compare::is_transparent exist.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
size_type erase(const K& key) { return m_ht.erase(key); }
/**
* @copydoc erase(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup of the value if you already have the hash.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
size_type erase(const K& key, std::size_t precalculated_hash) { return m_ht.erase(key, precalculated_hash); }
void swap(bhopscotch_set& other) { other.m_ht.swap(m_ht); }
/*
* Lookup
*/
size_type count(const Key& key) const { return m_ht.count(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
size_type count(const Key& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
* and Compare::is_transparent exist.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
size_type count(const K& key) const { return m_ht.count(key); }
/**
* @copydoc count(const K& key) const
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
size_type count(const K& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
iterator find(const Key& key) { return m_ht.find(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
const_iterator find(const Key& key) const { return m_ht.find(key); }
/**
* @copydoc find(const Key& key, std::size_t precalculated_hash)
*/
const_iterator find(const Key& key, std::size_t precalculated_hash) const { return m_ht.find(key, precalculated_hash); }
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
* and Compare::is_transparent exist.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
iterator find(const K& key) { return m_ht.find(key); }
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
/**
* @copydoc find(const K& key)
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
const_iterator find(const K& key) const { return m_ht.find(key); }
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
const_iterator find(const K& key, std::size_t precalculated_hash) const { return m_ht.find(key, precalculated_hash); }
std::pair<iterator, iterator> equal_range(const Key& key) { return m_ht.equal_range(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
std::pair<iterator, iterator> equal_range(const Key& key, std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
std::pair<const_iterator, const_iterator> equal_range(const Key& key) const { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range(const Key& key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent
* and Compare::is_transparent exist.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key) { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
* as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key, std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* @copydoc equal_range(const K& key)
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key) const { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const K& key, std::size_t precalculated_hash)
*/
template<class K, class KE = KeyEqual, class CP = Compare,
typename std::enable_if<has_is_transparent<KE>::value && has_is_transparent<CP>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/*
* Bucket interface
*/
size_type bucket_count() const { return m_ht.bucket_count(); }
size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
/*
* Hash policy
*/
float load_factor() const { return m_ht.load_factor(); }
float max_load_factor() const { return m_ht.max_load_factor(); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
void rehash(size_type count_) { m_ht.rehash(count_); }
void reserve(size_type count_) { m_ht.reserve(count_); }
/*
* Observers
*/
hasher hash_function() const { return m_ht.hash_function(); }
key_equal key_eq() const { return m_ht.key_eq(); }
key_compare key_comp() const { return m_ht.key_comp(); }
/*
* Other
*/
/**
* Convert a const_iterator to an iterator.
*/
iterator mutable_iterator(const_iterator pos) {
return m_ht.mutable_iterator(pos);
}
size_type overflow_size() const noexcept { return m_ht.overflow_size(); }
friend bool operator==(const bhopscotch_set& lhs, const bhopscotch_set& rhs) {
if(lhs.size() != rhs.size()) {
return false;
}
for(const auto& element_lhs : lhs) {
const auto it_element_rhs = rhs.find(element_lhs);
if(it_element_rhs == rhs.cend()) {
return false;
}
}
return true;
}
friend bool operator!=(const bhopscotch_set& lhs, const bhopscotch_set& rhs) {
return !operator==(lhs, rhs);
}
friend void swap(bhopscotch_set& lhs, bhopscotch_set& rhs) {
lhs.swap(rhs);
}
private:
ht m_ht;
};
/**
* Same as `tsl::bhopscotch_set<Key, Hash, KeyEqual, Compare, Allocator, NeighborhoodSize, StoreHash, tsl::hh::prime_growth_policy>`.
*/
template<class Key,
class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Compare = std::less<Key>,
class Allocator = std::allocator<Key>,
unsigned int NeighborhoodSize = 62,
bool StoreHash = false>
using bhopscotch_pg_set = bhopscotch_set<Key, Hash, KeyEqual, Compare, Allocator, NeighborhoodSize, StoreHash, tsl::hh::prime_growth_policy>;
} // end namespace tsl
#endif
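// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the vendored header above): heterogeneous
// lookup on tsl::bhopscotch_set. The find(const K&) overloads documented
// above are only enabled when KeyEqual::is_transparent and
// Compare::is_transparent exist, e.g. with the C++14 std::equal_to<> and
// std::less<>, plus a hash functor that accepts the alternative key type.
// TransparentStringHash below is a hypothetical helper, assuming C++14.
// ---------------------------------------------------------------------------
#include <cstddef>
#include <functional>
#include <string>
#include "paddle/fluid/feed/src/common/bhopscotch_set.h"
struct TransparentStringHash {
  std::size_t operator()(const std::string& s) const {
    return std::hash<std::string>()(s);
  }
  std::size_t operator()(const char* s) const {
    return std::hash<std::string>()(std::string(s));  // must agree with above
  }
};
inline bool example_transparent_find() {
  tsl::bhopscotch_set<std::string, TransparentStringHash,
                      std::equal_to<>, std::less<>> slots{"slot_a", "slot_b"};
  // Looks up a const char* key without converting it to std::string first
  // (except inside the hash shim above).
  return slots.find("slot_a") != slots.end();
}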
#include <iostream>
#include "paddle/fluid/feed/src/common/dict_plugin.h"
#include "paddle/fluid/framework/io/fs.h"
namespace paddle {
namespace framework {
int FeasignCacheDict::Load(
const std::string& path, const std::string& converter) {
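// Double-buffer reload: versioned_entity_ holds two KvEntity buffers.
// Readers keep using versioned_entity_[version_] while the other buffer
// is filled below; version_ is only flipped after loading succeeds.
// Input format (as parsed below): each line starts with an unsigned
// integer feasign id; the value stored for it is its insertion index
// (entity.Size() at the time of Append), and lines whose length is <= 1
// are skipped.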
auto version = version_ + 1;
if (version >= versioned_entity_.size()) {
version = 0;
}
auto& entity = versioned_entity_[version];
uint64_t data_count = 0;
auto file_list = fs_list(path);
for (auto& file_path : file_list) {
int err_no = 0;
ssize_t line_len = 0;  // getline() returns ssize_t
size_t buffer_size = 0;
char* buffer = nullptr;
char* data_ptr = nullptr;
auto file = fs_open_read(file_path, &err_no, converter);
CHECK(err_no == 0);
while ((line_len = getline(&buffer, &buffer_size, file.get())) > 0) {
if (line_len <= 1) {
continue;
}
++data_count;
entity.Append(strtoul(buffer, &data_ptr, 10), entity.Size());
}
if (buffer != nullptr) {
free(buffer);
}
}
version_ = version;
std::cerr << "Load success data_count" << data_count << " to version:" << version_ << std::endl;
return 0;
}
} // namespace framework
} // namespace paddle
#pragma once
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <glog/logging.h>
#include "paddle/fluid/feed/src/common/bhopscotch_map.h"
namespace paddle {
namespace framework {
class DictPlugin {
public:
DictPlugin() {}
virtual ~DictPlugin() {}
virtual int Load(const std::string& path, const std::string& converter) = 0;
};
template <class K, class V>
class KvEntity {
public:
KvEntity() {}
~KvEntity() {}
uint32_t Size() {
return _key_list.size();
}
void Append(const K& k, const V& v) {
if (_dict_data.find(k) != _dict_data.end()) {
return;
}
_key_list.push_back(k);
_dict_data.emplace(k, v);
}
std::vector<K> _key_list;
tsl::bhopscotch_pg_map<K, V> _dict_data;
};
template <class K, class V>
class KvDictPlugin : public DictPlugin {
public:
KvDictPlugin() {
versioned_entity_.resize(2);
}
virtual ~KvDictPlugin() {}
// Get the value under a specific version. Returns 0 on hit, -1 on miss.
virtual int GetValueWithVersion(uint32_t version, const K& key, V& v) {
CHECK(version < versioned_entity_.size());
auto& entity = versioned_entity_[version];
auto itr = entity._dict_data.find(key);
if (itr == entity._dict_data.end()) {
return -1; // miss
}
v = itr->second;
return 0;
}
// Get the value under the current version and return the version used. Returns 0 on hit, -1 on miss.
virtual int GetValue(const K& key, V& v, uint32_t& version) {
version = version_;
auto& entity = versioned_entity_[version];
auto itr = entity._dict_data.find(key);
if (itr == entity._dict_data.end()) {
return -1; // miss
}
v = itr->second;
return 0;
}
virtual int GetVersion() {
return version_;
}
protected:
uint32_t version_ = 0;
// double buffer supporting two versions: 0 and 1
std::vector<KvEntity<K, V>> versioned_entity_;
};
class FeasignCacheDict : public KvDictPlugin<uint64_t, uint32_t> {
public:
FeasignCacheDict(){}
virtual ~FeasignCacheDict(){}
virtual int Load(const std::string& path, const std::string& converter);
};
class DictPluginManager {
public:
DictPluginManager() {}
virtual ~DictPluginManager(){}
static DictPluginManager& Instance() {
static DictPluginManager manager;
return manager;
}
inline int CreateDict(const std::string& dict_name) {
#define PADDLE_DICT_PLUGIN_REGIST(dict) \
if (dict_name == #dict) { \
dicts_map_[dict_name].reset(new dict()); \
return 0; \
}
PADDLE_DICT_PLUGIN_REGIST(FeasignCacheDict)
#undef PADDLE_DICT_PLUGIN_REGIST
return -1;
}
inline DictPlugin* GetDict(const std::string& dict_name) {
if (dicts_map_.count(dict_name)) {
return dicts_map_[dict_name].get();
}
return nullptr;
}
inline int LoadDict(const std::string& dict_name,
const std::string& path, const std::string& converter) {
auto dict = GetDict(dict_name);
if (!dict) {
return -1;
}
return dict->Load(path, converter);
}
private:
std::unordered_map<std::string, std::shared_ptr<DictPlugin>> dicts_map_;
};
} // namespace framework
} // namespace paddle
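// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of this header): driving
// DictPluginManager end to end. The dict name must match a class registered
// through PADDLE_DICT_PLUGIN_REGIST ("FeasignCacheDict"); the path and
// converter values below are hypothetical.
// ---------------------------------------------------------------------------
inline int example_dict_plugin_usage() {
  auto& manager = paddle::framework::DictPluginManager::Instance();
  if (manager.CreateDict("FeasignCacheDict") != 0) {
    return -1;  // unknown dict name
  }
  if (manager.LoadDict("FeasignCacheDict", "/path/to/dict_files", "") != 0) {
    return -1;  // dict missing or load failed
  }
  auto* dict = static_cast<paddle::framework::FeasignCacheDict*>(
      manager.GetDict("FeasignCacheDict"));
  uint64_t feasign = 12345;  // hypothetical key
  uint32_t index = 0;
  uint32_t version = 0;
  return dict->GetValue(feasign, index, version);  // 0 on hit, -1 on miss
}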
/**
* MIT License
*
* Copyright (c) 2018 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_HOPSCOTCH_GROWTH_POLICY_H
#define TSL_HOPSCOTCH_GROWTH_POLICY_H
#include <algorithm>
#include <array>
#include <climits>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <limits>
#include <ratio>
#include <stdexcept>
namespace tsl {
namespace hh {
/**
* Grow the hash table by a factor of GrowthFactor keeping the bucket count to a power of two. It allows
* the table to use a mask operation instead of a modulo operation to map a hash to a bucket.
*
* GrowthFactor must be a power of two >= 2.
*/
template<std::size_t GrowthFactor>
class power_of_two_growth_policy {
public:
/**
* Called on the hash table creation and on rehash. The number of buckets for the table is passed as a parameter.
* This number is a minimum, the policy may update this value with a higher value if needed (but not lower).
*
* If 0 is given, min_bucket_count_in_out must still be 0 after the policy creation and
* bucket_for_hash must always return 0 in this case.
*/
explicit power_of_two_growth_policy(std::size_t& min_bucket_count_in_out) {
if(min_bucket_count_in_out > max_bucket_count()) {
throw std::length_error("The hash table exceeds its maxmimum size.");
}
if(min_bucket_count_in_out > 0) {
min_bucket_count_in_out = round_up_to_power_of_two(min_bucket_count_in_out);
m_mask = min_bucket_count_in_out - 1;
}
else {
m_mask = 0;
}
}
/**
* Return the bucket [0, bucket_count()) to which the hash belongs.
* If bucket_count() is 0, it must always return 0.
*/
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
return hash & m_mask;
}
/**
* Return the bucket count to use when the bucket array grows on rehash.
*/
std::size_t next_bucket_count() const {
if((m_mask + 1) > max_bucket_count() / GrowthFactor) {
throw std::length_error("The hash table exceeds its maxmimum size.");
}
return (m_mask + 1) * GrowthFactor;
}
/**
* Return the maximum number of buckets supported by the policy.
*/
std::size_t max_bucket_count() const {
// Largest power of two.
return (std::numeric_limits<std::size_t>::max() / 2) + 1;
}
/**
* Reset the growth policy as if it was created with a bucket count of 0.
* After a clear, the policy must always return 0 when bucket_for_hash is called.
*/
void clear() noexcept {
m_mask = 0;
}
private:
static std::size_t round_up_to_power_of_two(std::size_t value) {
if(is_power_of_two(value)) {
return value;
}
if(value == 0) {
return 1;
}
--value;
for(std::size_t i = 1; i < sizeof(std::size_t) * CHAR_BIT; i *= 2) {
value |= value >> i;
}
return value + 1;
}
static constexpr bool is_power_of_two(std::size_t value) {
return value != 0 && (value & (value - 1)) == 0;
}
private:
static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2, "GrowthFactor must be a power of two >= 2.");
std::size_t m_mask;
};
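// Worked example (illustrative): constructing the policy with
// min_bucket_count_in_out == 10 rounds the count up to 16 and stores
// m_mask == 15, so bucket_for_hash(h) == (h & 15), which equals h % 16
// because 16 is a power of two -- a mask instead of a modulo.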
/**
* Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo to map a hash
* to a bucket. Slower, but it can be useful if you want slower growth.
*/
template<class GrowthFactor = std::ratio<3, 2>>
class mod_growth_policy {
public:
explicit mod_growth_policy(std::size_t& min_bucket_count_in_out) {
if(min_bucket_count_in_out > max_bucket_count()) {
throw std::length_error("The hash table exceeds its maxmimum size.");
}
if(min_bucket_count_in_out > 0) {
m_mod = min_bucket_count_in_out;
}
else {
m_mod = 1;
}
}
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
return hash % m_mod;
}
std::size_t next_bucket_count() const {
if(m_mod == max_bucket_count()) {
throw std::length_error("The hash table exceeds its maxmimum size.");
}
const double next_bucket_count = std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
if(!std::isnormal(next_bucket_count)) {
throw std::length_error("The hash table exceeds its maxmimum size.");
}
if(next_bucket_count > double(max_bucket_count())) {
return max_bucket_count();
}
else {
return std::size_t(next_bucket_count);
}
}
std::size_t max_bucket_count() const {
return MAX_BUCKET_COUNT;
}
void clear() noexcept {
m_mod = 1;
}
private:
static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR = 1.0 * GrowthFactor::num / GrowthFactor::den;
static const std::size_t MAX_BUCKET_COUNT =
std::size_t(double(
std::numeric_limits<std::size_t>::max() / REHASH_SIZE_MULTIPLICATION_FACTOR
));
static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1, "Growth factor should be >= 1.1.");
std::size_t m_mod;
};
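// Worked example (illustrative): with the default GrowthFactor of
// std::ratio<3, 2>, REHASH_SIZE_MULTIPLICATION_FACTOR is 1.5, so a table
// with m_mod == 1000 grows to ceil(1000 * 1.5) == 1500 buckets on the next
// rehash, then to 2250, and so on.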
namespace detail {
static constexpr const std::array<std::size_t, 186> PRIMES = {{
1ull, 3ull, 5ull, 7ull, 11ull, 13ull, 17ull, 23ull, 29ull, 37ull, 47ull,
59ull, 73ull, 97ull, 127ull, 151ull, 197ull, 251ull, 313ull, 397ull,
499ull, 631ull, 797ull, 1009ull, 1259ull, 1597ull, 2011ull, 2539ull,
3203ull, 4027ull, 5087ull, 6421ull, 8089ull, 10193ull, 12853ull, 16193ull,
20399ull, 25717ull, 32401ull, 40823ull, 51437ull, 64811ull, 81649ull,
102877ull, 129607ull, 163307ull, 205759ull, 259229ull, 326617ull,
411527ull, 518509ull, 653267ull, 823117ull, 1037059ull, 1306601ull,
1646237ull, 2074129ull, 2613229ull, 3292489ull, 4148279ull, 5226491ull,
6584983ull, 8296553ull, 10453007ull, 13169977ull, 16593127ull, 20906033ull,
26339969ull, 33186281ull, 41812097ull, 52679969ull, 66372617ull,
83624237ull, 105359939ull, 132745199ull, 167248483ull, 210719881ull,
265490441ull, 334496971ull, 421439783ull, 530980861ull, 668993977ull,
842879579ull, 1061961721ull, 1337987929ull, 1685759167ull, 2123923447ull,
2675975881ull, 3371518343ull, 4247846927ull, 5351951779ull, 6743036717ull,
8495693897ull, 10703903591ull, 13486073473ull, 16991387857ull,
21407807219ull, 26972146961ull, 33982775741ull, 42815614441ull,
53944293929ull, 67965551447ull, 85631228929ull, 107888587883ull,
135931102921ull, 171262457903ull, 215777175787ull, 271862205833ull,
342524915839ull, 431554351609ull, 543724411781ull, 685049831731ull,
863108703229ull, 1087448823553ull, 1370099663459ull, 1726217406467ull,
2174897647073ull, 2740199326961ull, 3452434812973ull, 4349795294267ull,
5480398654009ull, 6904869625999ull, 8699590588571ull, 10960797308051ull,
13809739252051ull, 17399181177241ull, 21921594616111ull, 27619478504183ull,
34798362354533ull, 43843189232363ull, 55238957008387ull, 69596724709081ull,
87686378464759ull, 110477914016779ull, 139193449418173ull,
175372756929481ull, 220955828033581ull, 278386898836457ull,
350745513859007ull, 441911656067171ull, 556773797672909ull,
701491027718027ull, 883823312134381ull, 1113547595345903ull,
1402982055436147ull, 1767646624268779ull, 2227095190691797ull,
2805964110872297ull, 3535293248537579ull, 4454190381383713ull,
5611928221744609ull, 7070586497075177ull, 8908380762767489ull,
11223856443489329ull, 14141172994150357ull, 17816761525534927ull,
22447712886978529ull, 28282345988300791ull, 35633523051069991ull,
44895425773957261ull, 56564691976601587ull, 71267046102139967ull,
89790851547914507ull, 113129383953203213ull, 142534092204280003ull,
179581703095829107ull, 226258767906406483ull, 285068184408560057ull,
359163406191658253ull, 452517535812813007ull, 570136368817120201ull,
718326812383316683ull, 905035071625626043ull, 1140272737634240411ull,
1436653624766633509ull, 1810070143251252131ull, 2280545475268481167ull,
2873307249533267101ull, 3620140286502504283ull, 4561090950536962147ull,
5746614499066534157ull, 7240280573005008577ull, 9122181901073924329ull,
11493228998133068689ull, 14480561146010017169ull, 18446744073709551557ull
}};
template<unsigned int IPrime>
static constexpr std::size_t mod(std::size_t hash) { return hash % PRIMES[IPrime]; }
// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for a faster modulo as the
// compiler can optimize the modulo code better with a constant known at compile time.
static constexpr const std::array<std::size_t(*)(std::size_t), 186> MOD_PRIME = {{
&mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>, &mod<7>, &mod<8>, &mod<9>, &mod<10>,
&mod<11>, &mod<12>, &mod<13>, &mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>,
&mod<21>, &mod<22>, &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>, &mod<28>, &mod<29>, &mod<30>,
&mod<31>, &mod<32>, &mod<33>, &mod<34>, &mod<35>, &mod<36>, &mod<37>, &mod<38>, &mod<39>, &mod<40>,
&mod<41>, &mod<42>, &mod<43>, &mod<44>, &mod<45>, &mod<46>, &mod<47>, &mod<48>, &mod<49>, &mod<50>,
&mod<51>, &mod<52>, &mod<53>, &mod<54>, &mod<55>, &mod<56>, &mod<57>, &mod<58>, &mod<59>, &mod<60>,
&mod<61>, &mod<62>, &mod<63>, &mod<64>, &mod<65>, &mod<66>, &mod<67>, &mod<68>, &mod<69>, &mod<70>,
&mod<71>, &mod<72>, &mod<73>, &mod<74>, &mod<75>, &mod<76>, &mod<77>, &mod<78>, &mod<79>, &mod<80>,
&mod<81>, &mod<82>, &mod<83>, &mod<84>, &mod<85>, &mod<86>, &mod<87>, &mod<88>, &mod<89>, &mod<90>,
&mod<91>, &mod<92>, &mod<93>, &mod<94>, &mod<95>, &mod<96>, &mod<97>, &mod<98>, &mod<99>, &mod<100>,
&mod<101>, &mod<102>, &mod<103>, &mod<104>, &mod<105>, &mod<106>, &mod<107>, &mod<108>, &mod<109>, &mod<110>,
&mod<111>, &mod<112>, &mod<113>, &mod<114>, &mod<115>, &mod<116>, &mod<117>, &mod<118>, &mod<119>, &mod<120>,
&mod<121>, &mod<122>, &mod<123>, &mod<124>, &mod<125>, &mod<126>, &mod<127>, &mod<128>, &mod<129>, &mod<130>,
&mod<131>, &mod<132>, &mod<133>, &mod<134>, &mod<135>, &mod<136>, &mod<137>, &mod<138>, &mod<139>, &mod<140>,
&mod<141>, &mod<142>, &mod<143>, &mod<144>, &mod<145>, &mod<146>, &mod<147>, &mod<148>, &mod<149>, &mod<150>,
&mod<151>, &mod<152>, &mod<153>, &mod<154>, &mod<155>, &mod<156>, &mod<157>, &mod<158>, &mod<159>, &mod<160>,
&mod<161>, &mod<162>, &mod<163>, &mod<164>, &mod<165>, &mod<166>, &mod<167>, &mod<168>, &mod<169>, &mod<170>,
&mod<171>, &mod<172>, &mod<173>, &mod<174>, &mod<175>, &mod<176>, &mod<177>, &mod<178>, &mod<179>, &mod<180>,
&mod<181>, &mod<182>, &mod<183>, &mod<184>, &mod<185>
}};
}
/**
* Grow the hash table by using prime numbers as bucket count. Slower than tsl::hh::power_of_two_growth_policy in
* general but will probably distribute the values around better in the buckets with a poor hash function.
*
* To allow the compiler to optimize the modulo operation, a lookup table is used with constant prime numbers.
*
* With a switch the code would look like:
* \code
* switch(iprime) { // iprime is the current prime of the hash table
* case 0: hash % 1ul;
* break;
* case 1: hash % 3ul;
* break;
* case 2: hash % 5ul;
* break;
* ...
* }
* \endcode
*
* Due to the constant variable in the modulo the compiler is able to optimize the operation
* by a series of multiplications, subtractions and shifts.
*
* The 'hash % 5' could become something like 'hash - ((hash * 0xCCCCCCCD) >> 34) * 5' in a 64 bits environment.
*/
class prime_growth_policy {
public:
explicit prime_growth_policy(std::size_t& min_bucket_count_in_out) {
auto it_prime = std::lower_bound(detail::PRIMES.begin(),
detail::PRIMES.end(), min_bucket_count_in_out);
if(it_prime == detail::PRIMES.end()) {
throw std::length_error("The hash table exceeds its maxmimum size.");
}
m_iprime = static_cast<unsigned int>(std::distance(detail::PRIMES.begin(), it_prime));
if(min_bucket_count_in_out > 0) {
min_bucket_count_in_out = *it_prime;
}
else {
min_bucket_count_in_out = 0;
}
}
std::size_t bucket_for_hash(std::size_t hash) const noexcept {
return detail::MOD_PRIME[m_iprime](hash);
}
std::size_t next_bucket_count() const {
if(m_iprime + 1 >= detail::PRIMES.size()) {
throw std::length_error("The hash table exceeds its maxmimum size.");
}
return detail::PRIMES[m_iprime + 1];
}
std::size_t max_bucket_count() const {
return detail::PRIMES.back();
}
void clear() noexcept {
m_iprime = 0;
}
private:
unsigned int m_iprime;
static_assert(std::numeric_limits<decltype(m_iprime)>::max() >= detail::PRIMES.size(),
"The type of m_iprime is not big enough.");
};
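/**
 * Example usage (a minimal sketch, assuming the enclosing tsl::hh namespace):
 * the constructor rounds the requested bucket count up to the next prime of
 * the PRIMES table, and the modulo is dispatched through MOD_PRIME.
 * \code
 * std::size_t bucket_count = 30;
 * tsl::hh::prime_growth_policy policy(bucket_count);
 * assert(bucket_count == 37);                // rounded up to the next prime
 * assert(policy.bucket_for_hash(40) == 3);   // 40 % 37
 * assert(policy.next_bucket_count() == 47);  // next prime in the table
 * \endcode
 */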
}
}
#endif
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_HOPSCOTCH_HASH_H
#define TSL_HOPSCOTCH_HASH_H
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <exception>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <memory>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "paddle/fluid/feed/src/common/hopscotch_growth_policy.h"
#if (defined(__GNUC__) && (__GNUC__ == 4) && (__GNUC_MINOR__ < 9))
# define TSL_HH_NO_RANGE_ERASE_WITH_CONST_ITERATOR
#endif
/*
* Only activate tsl_hh_assert if TSL_DEBUG is defined.
* This way we avoid the performance hit of assert when NDEBUG is not defined, as tsl_hh_assert is used a lot
* (people usually compile with "-O3" and not "-O3 -DNDEBUG").
*/
#ifdef TSL_DEBUG
# define tsl_hh_assert(expr) assert(expr)
#else
# define tsl_hh_assert(expr) (static_cast<void>(0))
#endif
namespace tsl {
namespace detail_hopscotch_hash {
template<typename T>
struct make_void {
using type = void;
};
template<typename T, typename = void>
struct has_is_transparent : std::false_type {
};
template<typename T>
struct has_is_transparent<T, typename make_void<typename T::is_transparent>::type> : std::true_type {
};
template<typename T, typename = void>
struct has_key_compare : std::false_type {
};
template<typename T>
struct has_key_compare<T, typename make_void<typename T::key_compare>::type> : std::true_type {
};
template<typename U>
struct is_power_of_two_policy: std::false_type {
};
template<std::size_t GrowthFactor>
struct is_power_of_two_policy<tsl::hh::power_of_two_growth_policy<GrowthFactor>>: std::true_type {
};
/*
* smallest_type_for_min_bits::type returns the smallest type that can fit MinBits.
*/
static const std::size_t SMALLEST_TYPE_MAX_BITS_SUPPORTED = 64;
template<unsigned int MinBits, typename Enable = void>
class smallest_type_for_min_bits {
};
template<unsigned int MinBits>
class smallest_type_for_min_bits<MinBits, typename std::enable_if<(MinBits > 0) && (MinBits <= 8)>::type> {
public:
using type = std::uint_least8_t;
};
template<unsigned int MinBits>
class smallest_type_for_min_bits<MinBits, typename std::enable_if<(MinBits > 8) && (MinBits <= 16)>::type> {
public:
using type = std::uint_least16_t;
};
template<unsigned int MinBits>
class smallest_type_for_min_bits<MinBits, typename std::enable_if<(MinBits > 16) && (MinBits <= 32)>::type> {
public:
using type = std::uint_least32_t;
};
template<unsigned int MinBits>
class smallest_type_for_min_bits<MinBits, typename std::enable_if<(MinBits > 32) && (MinBits <= 64)>::type> {
public:
using type = std::uint_least64_t;
};
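// Illustrative compile-time checks of the mapping above (these hold by
// construction of the specializations):
static_assert(std::is_same<smallest_type_for_min_bits<8>::type, std::uint_least8_t>::value, "");
static_assert(std::is_same<smallest_type_for_min_bits<9>::type, std::uint_least16_t>::value, "");
static_assert(std::is_same<smallest_type_for_min_bits<33>::type, std::uint_least64_t>::value, "");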
/*
* Each bucket may store up to three elements:
* - An aligned storage to store a value_type object with placement-new.
* - An (optional) hash of the value in the bucket.
* - An unsigned integer of type neighborhood_bitmap used to tell us which buckets in the neighborhood of the
* current bucket contain a value with a hash belonging to the current bucket.
*
* For a bucket 'bct', a bit 'i' (counting from 0 and from the least significant bit to the most significant)
* set to 1 means that the bucket 'bct + i' contains a value with a hash belonging to bucket 'bct'.
* The bits used for that start from the third least significant bit.
* The two least significant bits are reserved:
* - The least significant bit is set to 1 if there is a value in the bucket storage.
* - The second least significant bit is set to 1 if there is an overflow: more than NeighborhoodSize values
*   hashed to the same bucket, and the overflowing values are stored in the m_overflow_elements list of the map.
*
* Details regarding hopscotch hashing and its implementation can be found here:
* https://tessil.github.io/2016/08/29/hopscotch-hashing.html
*/
static const std::size_t NB_RESERVED_BITS_IN_NEIGHBORHOOD = 2;
using truncated_hash_type = std::uint_least32_t;
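/*
 * Worked example: with NeighborhoodSize = 4, a bucket 'bct' whose bitmap is
 * 0b010101 decodes as follows (reading from the least significant bit):
 * - bit 0 = 1: the storage of 'bct' holds a value.
 * - bit 1 = 0: no value belonging to 'bct' overflowed into m_overflow_elements.
 * - bit 2 = 1: 'bct + 0' (i.e. 'bct' itself) holds a value whose hash belongs to 'bct'.
 * - bit 4 = 1: 'bct + 2' holds a value whose hash belongs to 'bct'.
 */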
/**
* Helper class that stores a truncated hash if StoreHash is true and nothing otherwise.
*/
template<bool StoreHash>
class hopscotch_bucket_hash {
public:
bool bucket_hash_equal(std::size_t /*hash*/) const noexcept {
return true;
}
truncated_hash_type truncated_bucket_hash() const noexcept {
return 0;
}
protected:
void copy_hash(const hopscotch_bucket_hash& ) noexcept {
}
void set_hash(truncated_hash_type /*hash*/) noexcept {
}
};
template<>
class hopscotch_bucket_hash<true> {
public:
bool bucket_hash_equal(std::size_t hash) const noexcept {
return m_hash == truncated_hash_type(hash);
}
truncated_hash_type truncated_bucket_hash() const noexcept {
return m_hash;
}
protected:
void copy_hash(const hopscotch_bucket_hash& bucket) noexcept {
m_hash = bucket.m_hash;
}
void set_hash(truncated_hash_type hash) noexcept {
m_hash = hash;
}
private:
truncated_hash_type m_hash;
};
template<typename ValueType, unsigned int NeighborhoodSize, bool StoreHash>
class hopscotch_bucket: public hopscotch_bucket_hash<StoreHash> {
private:
static const std::size_t MIN_NEIGHBORHOOD_SIZE = 4;
static const std::size_t MAX_NEIGHBORHOOD_SIZE = SMALLEST_TYPE_MAX_BITS_SUPPORTED - NB_RESERVED_BITS_IN_NEIGHBORHOOD;
static_assert(NeighborhoodSize >= 4, "NeighborhoodSize should be >= 4.");
// We can't put a variable in the message, ensure coherence
static_assert(MIN_NEIGHBORHOOD_SIZE == 4, "");
static_assert(NeighborhoodSize <= 62, "NeighborhoodSize should be <= 62.");
// We can't put a variable in the message, ensure coherence
static_assert(MAX_NEIGHBORHOOD_SIZE == 62, "");
static_assert(!StoreHash || NeighborhoodSize <= 30,
"NeighborhoodSize should be <= 30 if StoreHash is true.");
// We can't put a variable in the message, ensure coherence
static_assert(MAX_NEIGHBORHOOD_SIZE - 32 == 30, "");
using bucket_hash = hopscotch_bucket_hash<StoreHash>;
public:
using value_type = ValueType;
using neighborhood_bitmap =
typename smallest_type_for_min_bits<NeighborhoodSize + NB_RESERVED_BITS_IN_NEIGHBORHOOD>::type;
hopscotch_bucket() noexcept: bucket_hash(), m_neighborhood_infos(0) {
tsl_hh_assert(empty());
}
hopscotch_bucket(const hopscotch_bucket& bucket)
noexcept(std::is_nothrow_copy_constructible<value_type>::value): bucket_hash(bucket),
m_neighborhood_infos(0)
{
if(!bucket.empty()) {
::new (static_cast<void*>(std::addressof(m_value))) value_type(bucket.value());
}
m_neighborhood_infos = bucket.m_neighborhood_infos;
}
hopscotch_bucket(hopscotch_bucket&& bucket)
noexcept(std::is_nothrow_move_constructible<value_type>::value) : bucket_hash(std::move(bucket)),
m_neighborhood_infos(0)
{
if(!bucket.empty()) {
::new (static_cast<void*>(std::addressof(m_value))) value_type(std::move(bucket.value()));
}
m_neighborhood_infos = bucket.m_neighborhood_infos;
}
hopscotch_bucket& operator=(const hopscotch_bucket& bucket)
noexcept(std::is_nothrow_copy_constructible<value_type>::value)
{
if(this != &bucket) {
remove_value();
bucket_hash::operator=(bucket);
if(!bucket.empty()) {
::new (static_cast<void*>(std::addressof(m_value))) value_type(bucket.value());
}
m_neighborhood_infos = bucket.m_neighborhood_infos;
}
return *this;
}
hopscotch_bucket& operator=(hopscotch_bucket&& ) = delete;
~hopscotch_bucket() noexcept {
if(!empty()) {
destroy_value();
}
}
neighborhood_bitmap neighborhood_infos() const noexcept {
return neighborhood_bitmap(m_neighborhood_infos >> NB_RESERVED_BITS_IN_NEIGHBORHOOD);
}
void set_overflow(bool has_overflow) noexcept {
if(has_overflow) {
m_neighborhood_infos = neighborhood_bitmap(m_neighborhood_infos | 2);
}
else {
m_neighborhood_infos = neighborhood_bitmap(m_neighborhood_infos & ~2);
}
}
bool has_overflow() const noexcept {
return (m_neighborhood_infos & 2) != 0;
}
bool empty() const noexcept {
return (m_neighborhood_infos & 1) == 0;
}
void toggle_neighbor_presence(std::size_t ineighbor) noexcept {
tsl_hh_assert(ineighbor <= NeighborhoodSize);
m_neighborhood_infos = neighborhood_bitmap(
m_neighborhood_infos ^ (1ull << (ineighbor + NB_RESERVED_BITS_IN_NEIGHBORHOOD)));
}
bool check_neighbor_presence(std::size_t ineighbor) const noexcept {
tsl_hh_assert(ineighbor <= NeighborhoodSize);
if(((m_neighborhood_infos >> (ineighbor + NB_RESERVED_BITS_IN_NEIGHBORHOOD)) & 1) == 1) {
return true;
}
return false;
}
value_type& value() noexcept {
tsl_hh_assert(!empty());
return *reinterpret_cast<value_type*>(std::addressof(m_value));
}
const value_type& value() const noexcept {
tsl_hh_assert(!empty());
return *reinterpret_cast<const value_type*>(std::addressof(m_value));
}
template<typename... Args>
void set_value_of_empty_bucket(truncated_hash_type hash, Args&&... value_type_args) {
tsl_hh_assert(empty());
::new (static_cast<void*>(std::addressof(m_value))) value_type(std::forward<Args>(value_type_args)...);
set_empty(false);
this->set_hash(hash);
}
void swap_value_into_empty_bucket(hopscotch_bucket& empty_bucket) {
tsl_hh_assert(empty_bucket.empty());
if(!empty()) {
::new (static_cast<void*>(std::addressof(empty_bucket.m_value))) value_type(std::move(value()));
empty_bucket.copy_hash(*this);
empty_bucket.set_empty(false);
destroy_value();
set_empty(true);
}
}
void remove_value() noexcept {
if(!empty()) {
destroy_value();
set_empty(true);
}
}
void clear() noexcept {
if(!empty()) {
destroy_value();
}
m_neighborhood_infos = 0;
tsl_hh_assert(empty());
}
static truncated_hash_type truncate_hash(std::size_t hash) noexcept {
return truncated_hash_type(hash);
}
private:
void set_empty(bool is_empty) noexcept {
if(is_empty) {
m_neighborhood_infos = neighborhood_bitmap(m_neighborhood_infos & ~1);
}
else {
m_neighborhood_infos = neighborhood_bitmap(m_neighborhood_infos | 1);
}
}
void destroy_value() noexcept {
tsl_hh_assert(!empty());
value().~value_type();
}
private:
using storage = typename std::aligned_storage<sizeof(value_type), alignof(value_type)>::type;
neighborhood_bitmap m_neighborhood_infos;
storage m_value;
};
/**
* Internal common class used by (b)hopscotch_map and (b)hopscotch_set.
*
* ValueType is what will be stored by hopscotch_hash (usually std::pair<Key, T> for a map and Key for a set).
*
* KeySelect should be a FunctionObject which takes a ValueType in parameter and returns a reference to the key.
*
* ValueSelect should be a FunctionObject which takes a ValueType in parameter and returns a reference to the value.
* ValueSelect should be void if there is no value (in a set for example).
*
* OverflowContainer will be used as containers for overflown elements. Usually it should be a list<ValueType>
* or a set<Key>/map<Key, T>.
*/
template<class ValueType,
class KeySelect,
class ValueSelect,
class Hash,
class KeyEqual,
class Allocator,
unsigned int NeighborhoodSize,
bool StoreHash,
class GrowthPolicy,
class OverflowContainer>
class hopscotch_hash: private Hash, private KeyEqual, private GrowthPolicy {
private:
template<typename U>
using has_mapped_type = typename std::integral_constant<bool, !std::is_same<U, void>::value>;
static_assert(noexcept(std::declval<GrowthPolicy>().bucket_for_hash(std::size_t(0))), "GrowthPolicy::bucket_for_hash must be noexcept.");
static_assert(noexcept(std::declval<GrowthPolicy>().clear()), "GrowthPolicy::clear must be noexcept.");
public:
template<bool IsConst>
class hopscotch_iterator;
using key_type = typename KeySelect::key_type;
using value_type = ValueType;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
using hasher = Hash;
using key_equal = KeyEqual;
using allocator_type = Allocator;
using reference = value_type&;
using const_reference = const value_type&;
using pointer = value_type*;
using const_pointer = const value_type*;
using iterator = hopscotch_iterator<false>;
using const_iterator = hopscotch_iterator<true>;
private:
using hopscotch_bucket = tsl::detail_hopscotch_hash::hopscotch_bucket<ValueType, NeighborhoodSize, StoreHash>;
using neighborhood_bitmap = typename hopscotch_bucket::neighborhood_bitmap;
using buckets_allocator = typename std::allocator_traits<allocator_type>::template rebind_alloc<hopscotch_bucket>;
using buckets_container_type = std::vector<hopscotch_bucket, buckets_allocator>;
using overflow_container_type = OverflowContainer;
static_assert(std::is_same<typename overflow_container_type::value_type, ValueType>::value,
"OverflowContainer should have ValueType as type.");
static_assert(std::is_same<typename overflow_container_type::allocator_type, Allocator>::value,
"Invalid allocator, not the same type as the value_type.");
using iterator_buckets = typename buckets_container_type::iterator;
using const_iterator_buckets = typename buckets_container_type::const_iterator;
using iterator_overflow = typename overflow_container_type::iterator;
using const_iterator_overflow = typename overflow_container_type::const_iterator;
public:
/**
* The `operator*()` and `operator->()` methods return a const reference and const pointer respectively to the
* stored value type.
*
* In case of a map, to get a modifiable reference to the value associated to a key (the `.second` in the
* stored pair), you have to call `value()`.
*/
template<bool IsConst>
class hopscotch_iterator {
friend class hopscotch_hash;
private:
using iterator_bucket = typename std::conditional<IsConst,
typename hopscotch_hash::const_iterator_buckets,
typename hopscotch_hash::iterator_buckets>::type;
using iterator_overflow = typename std::conditional<IsConst,
typename hopscotch_hash::const_iterator_overflow,
typename hopscotch_hash::iterator_overflow>::type;
hopscotch_iterator(iterator_bucket buckets_iterator, iterator_bucket buckets_end_iterator,
iterator_overflow overflow_iterator) noexcept :
m_buckets_iterator(buckets_iterator), m_buckets_end_iterator(buckets_end_iterator),
m_overflow_iterator(overflow_iterator)
{
}
public:
using iterator_category = std::forward_iterator_tag;
using value_type = const typename hopscotch_hash::value_type;
using difference_type = std::ptrdiff_t;
using reference = value_type&;
using pointer = value_type*;
hopscotch_iterator() noexcept {
}
// Copy constructor from iterator to const_iterator.
template<bool TIsConst = IsConst, typename std::enable_if<TIsConst>::type* = nullptr>
hopscotch_iterator(const hopscotch_iterator<!TIsConst>& other) noexcept :
m_buckets_iterator(other.m_buckets_iterator), m_buckets_end_iterator(other.m_buckets_end_iterator),
m_overflow_iterator(other.m_overflow_iterator)
{
}
hopscotch_iterator(const hopscotch_iterator& other) = default;
hopscotch_iterator(hopscotch_iterator&& other) = default;
hopscotch_iterator& operator=(const hopscotch_iterator& other) = default;
hopscotch_iterator& operator=(hopscotch_iterator&& other) = default;
const typename hopscotch_hash::key_type& key() const {
if(m_buckets_iterator != m_buckets_end_iterator) {
return KeySelect()(m_buckets_iterator->value());
}
return KeySelect()(*m_overflow_iterator);
}
template<class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
typename std::conditional<
IsConst,
const typename U::value_type&,
typename U::value_type&>::type value() const
{
if(m_buckets_iterator != m_buckets_end_iterator) {
return U()(m_buckets_iterator->value());
}
return U()(*m_overflow_iterator);
}
reference operator*() const {
if(m_buckets_iterator != m_buckets_end_iterator) {
return m_buckets_iterator->value();
}
return *m_overflow_iterator;
}
pointer operator->() const {
if(m_buckets_iterator != m_buckets_end_iterator) {
return std::addressof(m_buckets_iterator->value());
}
return std::addressof(*m_overflow_iterator);
}
hopscotch_iterator& operator++() {
if(m_buckets_iterator == m_buckets_end_iterator) {
++m_overflow_iterator;
return *this;
}
do {
++m_buckets_iterator;
} while(m_buckets_iterator != m_buckets_end_iterator && m_buckets_iterator->empty());
return *this;
}
hopscotch_iterator operator++(int) {
hopscotch_iterator tmp(*this);
++*this;
return tmp;
}
friend bool operator==(const hopscotch_iterator& lhs, const hopscotch_iterator& rhs) {
return lhs.m_buckets_iterator == rhs.m_buckets_iterator &&
lhs.m_overflow_iterator == rhs.m_overflow_iterator;
}
friend bool operator!=(const hopscotch_iterator& lhs, const hopscotch_iterator& rhs) {
return !(lhs == rhs);
}
private:
iterator_bucket m_buckets_iterator;
iterator_bucket m_buckets_end_iterator;
iterator_overflow m_overflow_iterator;
};
public:
template<class OC = OverflowContainer, typename std::enable_if<!has_key_compare<OC>::value>::type* = nullptr>
hopscotch_hash(size_type bucket_count,
const Hash& hash,
const KeyEqual& equal,
const Allocator& alloc,
float max_load_factor) : Hash(hash),
KeyEqual(equal),
GrowthPolicy(bucket_count),
m_buckets_data(alloc),
m_overflow_elements(alloc),
m_buckets(static_empty_bucket_ptr()),
m_nb_elements(0)
{
if(bucket_count > max_bucket_count()) {
throw std::length_error("The map exceeds its maxmimum size.");
}
if(bucket_count > 0) {
static_assert(NeighborhoodSize - 1 > 0, "");
// Can't directly construct with the appropriate size in the initializer
// as m_buckets_data(bucket_count, alloc) is not supported by GCC 4.8
m_buckets_data.resize(bucket_count + NeighborhoodSize - 1);
m_buckets = m_buckets_data.data();
}
this->max_load_factor(max_load_factor);
// Check in the constructor instead of outside of a function to avoid compilation issues
// when value_type is not complete.
static_assert(std::is_nothrow_move_constructible<value_type>::value ||
std::is_copy_constructible<value_type>::value,
"value_type must be either copy constructible or nothrow move constructible.");
}
template<class OC = OverflowContainer, typename std::enable_if<has_key_compare<OC>::value>::type* = nullptr>
hopscotch_hash(size_type bucket_count,
const Hash& hash,
const KeyEqual& equal,
const Allocator& alloc,
float max_load_factor,
const typename OC::key_compare& comp) : Hash(hash),
KeyEqual(equal),
GrowthPolicy(bucket_count),
m_buckets_data(alloc),
m_overflow_elements(comp, alloc),
m_buckets(static_empty_bucket_ptr()),
m_nb_elements(0)
{
if(bucket_count > max_bucket_count()) {
throw std::length_error("The map exceeds its maxmimum size.");
}
if(bucket_count > 0) {
static_assert(NeighborhoodSize - 1 > 0, "");
// Can't directly construct with the appropriate size in the initializer
// as m_buckets_data(bucket_count, alloc) is not supported by GCC 4.8
m_buckets_data.resize(bucket_count + NeighborhoodSize - 1);
m_buckets = m_buckets_data.data();
}
this->max_load_factor(max_load_factor);
// Check in the constructor instead of outside of a function to avoid compilation issues
// when value_type is not complete.
static_assert(std::is_nothrow_move_constructible<value_type>::value ||
std::is_copy_constructible<value_type>::value,
"value_type must be either copy constructible or nothrow move constructible.");
}
hopscotch_hash(const hopscotch_hash& other):
Hash(other),
KeyEqual(other),
GrowthPolicy(other),
m_buckets_data(other.m_buckets_data),
m_overflow_elements(other.m_overflow_elements),
m_buckets(m_buckets_data.empty()?static_empty_bucket_ptr():
m_buckets_data.data()),
m_nb_elements(other.m_nb_elements),
m_max_load_factor(other.m_max_load_factor),
m_max_load_threshold_rehash(other.m_max_load_threshold_rehash),
m_min_load_threshold_rehash(other.m_min_load_threshold_rehash)
{
}
hopscotch_hash(hopscotch_hash&& other)
noexcept(
std::is_nothrow_move_constructible<Hash>::value &&
std::is_nothrow_move_constructible<KeyEqual>::value &&
std::is_nothrow_move_constructible<GrowthPolicy>::value &&
std::is_nothrow_move_constructible<buckets_container_type>::value &&
std::is_nothrow_move_constructible<overflow_container_type>::value
):
Hash(std::move(static_cast<Hash&>(other))),
KeyEqual(std::move(static_cast<KeyEqual&>(other))),
GrowthPolicy(std::move(static_cast<GrowthPolicy&>(other))),
m_buckets_data(std::move(other.m_buckets_data)),
m_overflow_elements(std::move(other.m_overflow_elements)),
m_buckets(m_buckets_data.empty()?static_empty_bucket_ptr():
m_buckets_data.data()),
m_nb_elements(other.m_nb_elements),
m_max_load_factor(other.m_max_load_factor),
m_max_load_threshold_rehash(other.m_max_load_threshold_rehash),
m_min_load_threshold_rehash(other.m_min_load_threshold_rehash)
{
other.GrowthPolicy::clear();
other.m_buckets_data.clear();
other.m_overflow_elements.clear();
other.m_buckets = static_empty_bucket_ptr();
other.m_nb_elements = 0;
other.m_max_load_threshold_rehash = 0;
other.m_min_load_threshold_rehash = 0;
}
hopscotch_hash& operator=(const hopscotch_hash& other) {
if(&other != this) {
Hash::operator=(other);
KeyEqual::operator=(other);
GrowthPolicy::operator=(other);
m_buckets_data = other.m_buckets_data;
m_overflow_elements = other.m_overflow_elements;
m_buckets = m_buckets_data.empty()?static_empty_bucket_ptr():
m_buckets_data.data();
m_nb_elements = other.m_nb_elements;
m_max_load_factor = other.m_max_load_factor;
m_max_load_threshold_rehash = other.m_max_load_threshold_rehash;
m_min_load_threshold_rehash = other.m_min_load_threshold_rehash;
}
return *this;
}
hopscotch_hash& operator=(hopscotch_hash&& other) {
other.swap(*this);
other.clear();
return *this;
}
allocator_type get_allocator() const {
return m_buckets_data.get_allocator();
}
/*
* Iterators
*/
iterator begin() noexcept {
auto begin = m_buckets_data.begin();
while(begin != m_buckets_data.end() && begin->empty()) {
++begin;
}
return iterator(begin, m_buckets_data.end(), m_overflow_elements.begin());
}
const_iterator begin() const noexcept {
return cbegin();
}
const_iterator cbegin() const noexcept {
auto begin = m_buckets_data.cbegin();
while(begin != m_buckets_data.cend() && begin->empty()) {
++begin;
}
return const_iterator(begin, m_buckets_data.cend(), m_overflow_elements.cbegin());
}
iterator end() noexcept {
return iterator(m_buckets_data.end(), m_buckets_data.end(), m_overflow_elements.end());
}
const_iterator end() const noexcept {
return cend();
}
const_iterator cend() const noexcept {
return const_iterator(m_buckets_data.cend(), m_buckets_data.cend(), m_overflow_elements.cend());
}
/*
* Capacity
*/
bool empty() const noexcept {
return m_nb_elements == 0;
}
size_type size() const noexcept {
return m_nb_elements;
}
size_type max_size() const noexcept {
return m_buckets_data.max_size();
}
/*
* Modifiers
*/
void clear() noexcept {
for(auto& bucket: m_buckets_data) {
bucket.clear();
}
m_overflow_elements.clear();
m_nb_elements = 0;
}
std::pair<iterator, bool> insert(const value_type& value) {
return insert_impl(value);
}
template<class P, typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
std::pair<iterator, bool> insert(P&& value) {
return insert_impl(value_type(std::forward<P>(value)));
}
std::pair<iterator, bool> insert(value_type&& value) {
return insert_impl(std::move(value));
}
iterator insert(const_iterator hint, const value_type& value) {
if(hint != cend() && compare_keys(KeySelect()(*hint), KeySelect()(value))) {
return mutable_iterator(hint);
}
return insert(value).first;
}
template<class P, typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
iterator insert(const_iterator hint, P&& value) {
return emplace_hint(hint, std::forward<P>(value));
}
iterator insert(const_iterator hint, value_type&& value) {
if(hint != cend() && compare_keys(KeySelect()(*hint), KeySelect()(value))) {
return mutable_iterator(hint);
}
return insert(std::move(value)).first;
}
template<class InputIt>
void insert(InputIt first, InputIt last) {
if(std::is_base_of<std::forward_iterator_tag,
typename std::iterator_traits<InputIt>::iterator_category>::value)
{
const auto nb_elements_insert = std::distance(first, last);
const std::size_t nb_elements_in_buckets = m_nb_elements - m_overflow_elements.size();
const std::size_t nb_free_buckets = m_max_load_threshold_rehash - nb_elements_in_buckets;
tsl_hh_assert(m_nb_elements >= m_overflow_elements.size());
tsl_hh_assert(m_max_load_threshold_rehash >= nb_elements_in_buckets);
if(nb_elements_insert > 0 && nb_free_buckets < std::size_t(nb_elements_insert)) {
reserve(nb_elements_in_buckets + std::size_t(nb_elements_insert));
}
}
for(; first != last; ++first) {
insert(*first);
}
}
template<class M>
std::pair<iterator, bool> insert_or_assign(const key_type& k, M&& obj) {
return insert_or_assign_impl(k, std::forward<M>(obj));
}
template<class M>
std::pair<iterator, bool> insert_or_assign(key_type&& k, M&& obj) {
return insert_or_assign_impl(std::move(k), std::forward<M>(obj));
}
template<class M>
iterator insert_or_assign(const_iterator hint, const key_type& k, M&& obj) {
if(hint != cend() && compare_keys(KeySelect()(*hint), k)) {
auto it = mutable_iterator(hint);
it.value() = std::forward<M>(obj);
return it;
}
return insert_or_assign(k, std::forward<M>(obj)).first;
}
template<class M>
iterator insert_or_assign(const_iterator hint, key_type&& k, M&& obj) {
if(hint != cend() && compare_keys(KeySelect()(*hint), k)) {
auto it = mutable_iterator(hint);
it.value() = std::forward<M>(obj);
return it;
}
return insert_or_assign(std::move(k), std::forward<M>(obj)).first;
}
template<class... Args>
std::pair<iterator, bool> emplace(Args&&... args) {
return insert(value_type(std::forward<Args>(args)...));
}
template<class... Args>
iterator emplace_hint(const_iterator hint, Args&&... args) {
return insert(hint, value_type(std::forward<Args>(args)...));
}
template<class... Args>
std::pair<iterator, bool> try_emplace(const key_type& k, Args&&... args) {
return try_emplace_impl(k, std::forward<Args>(args)...);
}
template<class... Args>
std::pair<iterator, bool> try_emplace(key_type&& k, Args&&... args) {
return try_emplace_impl(std::move(k), std::forward<Args>(args)...);
}
template<class... Args>
iterator try_emplace(const_iterator hint, const key_type& k, Args&&... args) {
if(hint != cend() && compare_keys(KeySelect()(*hint), k)) {
return mutable_iterator(hint);
}
return try_emplace(k, std::forward<Args>(args)...).first;
}
template<class... Args>
iterator try_emplace(const_iterator hint, key_type&& k, Args&&... args) {
if(hint != cend() && compare_keys(KeySelect()(*hint), k)) {
return mutable_iterator(hint);
}
return try_emplace(std::move(k), std::forward<Args>(args)...).first;
}
/**
* Here to avoid `template<class K> size_type erase(const K& key)` being used when
* we use an iterator instead of a const_iterator.
*/
iterator erase(iterator pos) {
return erase(const_iterator(pos));
}
iterator erase(const_iterator pos) {
const std::size_t ibucket_for_hash = bucket_for_hash(hash_key(pos.key()));
if(pos.m_buckets_iterator != pos.m_buckets_end_iterator) {
auto it_bucket = m_buckets_data.begin() + std::distance(m_buckets_data.cbegin(), pos.m_buckets_iterator);
erase_from_bucket(*it_bucket, ibucket_for_hash);
return ++iterator(it_bucket, m_buckets_data.end(), m_overflow_elements.begin());
}
else {
auto it_next_overflow = erase_from_overflow(pos.m_overflow_iterator, ibucket_for_hash);
return iterator(m_buckets_data.end(), m_buckets_data.end(), it_next_overflow);
}
}
iterator erase(const_iterator first, const_iterator last) {
if(first == last) {
return mutable_iterator(first);
}
auto to_delete = erase(first);
while(to_delete != last) {
to_delete = erase(to_delete);
}
return to_delete;
}
template<class K>
size_type erase(const K& key) {
return erase(key, hash_key(key));
}
template<class K>
size_type erase(const K& key, std::size_t hash) {
const std::size_t ibucket_for_hash = bucket_for_hash(hash);
hopscotch_bucket* bucket_found = find_in_buckets(key, hash, m_buckets + ibucket_for_hash);
if(bucket_found != nullptr) {
erase_from_bucket(*bucket_found, ibucket_for_hash);
return 1;
}
if(m_buckets[ibucket_for_hash].has_overflow()) {
auto it_overflow = find_in_overflow(key);
if(it_overflow != m_overflow_elements.end()) {
erase_from_overflow(it_overflow, ibucket_for_hash);
return 1;
}
}
return 0;
}
void swap(hopscotch_hash& other) {
using std::swap;
swap(static_cast<Hash&>(*this), static_cast<Hash&>(other));
swap(static_cast<KeyEqual&>(*this), static_cast<KeyEqual&>(other));
swap(static_cast<GrowthPolicy&>(*this), static_cast<GrowthPolicy&>(other));
swap(m_buckets_data, other.m_buckets_data);
swap(m_overflow_elements, other.m_overflow_elements);
swap(m_buckets, other.m_buckets);
swap(m_nb_elements, other.m_nb_elements);
swap(m_max_load_factor, other.m_max_load_factor);
swap(m_max_load_threshold_rehash, other.m_max_load_threshold_rehash);
swap(m_min_load_threshold_rehash, other.m_min_load_threshold_rehash);
}
/*
* Lookup
*/
template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
typename U::value_type& at(const K& key) {
return at(key, hash_key(key));
}
template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
typename U::value_type& at(const K& key, std::size_t hash) {
return const_cast<typename U::value_type&>(static_cast<const hopscotch_hash*>(this)->at(key, hash));
}
template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
const typename U::value_type& at(const K& key) const {
return at(key, hash_key(key));
}
template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
const typename U::value_type& at(const K& key, std::size_t hash) const {
using T = typename U::value_type;
const T* value = find_value_impl(key, hash, m_buckets + bucket_for_hash(hash));
if(value == nullptr) {
throw std::out_of_range("Couldn't find key.");
}
else {
return *value;
}
}
template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
typename U::value_type& operator[](K&& key) {
using T = typename U::value_type;
const std::size_t hash = hash_key(key);
const std::size_t ibucket_for_hash = bucket_for_hash(hash);
T* value = find_value_impl(key, hash, m_buckets + ibucket_for_hash);
if(value != nullptr) {
return *value;
}
else {
return insert_value(ibucket_for_hash, hash, std::piecewise_construct,
std::forward_as_tuple(std::forward<K>(key)),
std::forward_as_tuple()).first.value();
}
}
template<class K>
size_type count(const K& key) const {
return count(key, hash_key(key));
}
template<class K>
size_type count(const K& key, std::size_t hash) const {
return count_impl(key, hash, m_buckets + bucket_for_hash(hash));
}
template<class K>
iterator find(const K& key) {
return find(key, hash_key(key));
}
template<class K>
iterator find(const K& key, std::size_t hash) {
return find_impl(key, hash, m_buckets + bucket_for_hash(hash));
}
template<class K>
const_iterator find(const K& key) const {
return find(key, hash_key(key));
}
template<class K>
const_iterator find(const K& key, std::size_t hash) const {
return find_impl(key, hash, m_buckets + bucket_for_hash(hash));
}
template<class K>
std::pair<iterator, iterator> equal_range(const K& key) {
return equal_range(key, hash_key(key));
}
template<class K>
std::pair<iterator, iterator> equal_range(const K& key, std::size_t hash) {
iterator it = find(key, hash);
return std::make_pair(it, (it == end())?it:std::next(it));
}
template<class K>
std::pair<const_iterator, const_iterator> equal_range(const K& key) const {
return equal_range(key, hash_key(key));
}
template<class K>
std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t hash) const {
const_iterator it = find(key, hash);
return std::make_pair(it, (it == cend())?it:std::next(it));
}
/*
* Bucket interface
*/
size_type bucket_count() const {
/*
* So that the last bucket can have NeighborhoodSize neighbors, the size of the bucket array is a little
* bigger than the real number of buckets when not empty.
* We could use some of the buckets at the beginning, but it is faster this way as we avoid extra checks.
*/
if(m_buckets_data.empty()) {
return 0;
}
return m_buckets_data.size() - NeighborhoodSize + 1;
}
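/*
 * Example of the arithmetic above: with NeighborhoodSize = 62, a request for
 * 100 buckets resizes m_buckets_data to 100 + 62 - 1 = 161 entries, and
 * bucket_count() reports 161 - 62 + 1 = 100.
 */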
size_type max_bucket_count() const {
const std::size_t max_bucket_count = std::min(GrowthPolicy::max_bucket_count(), m_buckets_data.max_size());
return max_bucket_count - NeighborhoodSize + 1;
}
/*
* Hash policy
*/
float load_factor() const {
if(bucket_count() == 0) {
return 0;
}
return float(m_nb_elements)/float(bucket_count());
}
float max_load_factor() const {
return m_max_load_factor;
}
void max_load_factor(float ml) {
m_max_load_factor = std::max(0.1f, std::min(ml, 0.95f));
m_max_load_threshold_rehash = size_type(float(bucket_count())*m_max_load_factor);
m_min_load_threshold_rehash = size_type(float(bucket_count())*MIN_LOAD_FACTOR_FOR_REHASH);
}
void rehash(size_type count_) {
count_ = std::max(count_, size_type(std::ceil(float(size())/max_load_factor())));
rehash_impl(count_);
}
void reserve(size_type count_) {
rehash(size_type(std::ceil(float(count_)/max_load_factor())));
}
/*
* Observers
*/
hasher hash_function() const {
return static_cast<const Hash&>(*this);
}
key_equal key_eq() const {
return static_cast<const KeyEqual&>(*this);
}
/*
* Other
*/
iterator mutable_iterator(const_iterator pos) {
if(pos.m_buckets_iterator != pos.m_buckets_end_iterator) {
// Get a non-const iterator
auto it = m_buckets_data.begin() + std::distance(m_buckets_data.cbegin(), pos.m_buckets_iterator);
return iterator(it, m_buckets_data.end(), m_overflow_elements.begin());
}
else {
// Get a non-const iterator
auto it = mutable_overflow_iterator(pos.m_overflow_iterator);
return iterator(m_buckets_data.end(), m_buckets_data.end(), it);
}
}
size_type overflow_size() const noexcept {
return m_overflow_elements.size();
}
template<class U = OverflowContainer, typename std::enable_if<has_key_compare<U>::value>::type* = nullptr>
typename U::key_compare key_comp() const {
return m_overflow_elements.key_comp();
}
private:
template<class K>
std::size_t hash_key(const K& key) const {
return Hash::operator()(key);
}
template<class K1, class K2>
bool compare_keys(const K1& key1, const K2& key2) const {
return KeyEqual::operator()(key1, key2);
}
std::size_t bucket_for_hash(std::size_t hash) const {
const std::size_t bucket = GrowthPolicy::bucket_for_hash(hash);
tsl_hh_assert(bucket < m_buckets_data.size() || (bucket == 0 && m_buckets_data.empty()));
return bucket;
}
template<typename U = value_type,
typename std::enable_if<std::is_nothrow_move_constructible<U>::value>::type* = nullptr>
void rehash_impl(size_type count_) {
hopscotch_hash new_map = new_hopscotch_hash(count_);
if(!m_overflow_elements.empty()) {
new_map.m_overflow_elements.swap(m_overflow_elements);
new_map.m_nb_elements += new_map.m_overflow_elements.size();
for(const value_type& value : new_map.m_overflow_elements) {
const std::size_t ibucket_for_hash = new_map.bucket_for_hash(new_map.hash_key(KeySelect()(value)));
new_map.m_buckets[ibucket_for_hash].set_overflow(true);
}
}
try {
const bool use_stored_hash = USE_STORED_HASH_ON_REHASH(new_map.bucket_count());
for(auto it_bucket = m_buckets_data.begin(); it_bucket != m_buckets_data.end(); ++it_bucket) {
if(it_bucket->empty()) {
continue;
}
const std::size_t hash = use_stored_hash?
it_bucket->truncated_bucket_hash():
new_map.hash_key(KeySelect()(it_bucket->value()));
const std::size_t ibucket_for_hash = new_map.bucket_for_hash(hash);
new_map.insert_value(ibucket_for_hash, hash, std::move(it_bucket->value()));
erase_from_bucket(*it_bucket, bucket_for_hash(hash));
}
}
/*
* The call to insert_value may throw an exception if an element is added to the overflow
* list. Rollback the elements in this case.
*/
catch(...) {
m_overflow_elements.swap(new_map.m_overflow_elements);
const bool use_stored_hash = USE_STORED_HASH_ON_REHASH(new_map.bucket_count());
for(auto it_bucket = new_map.m_buckets_data.begin(); it_bucket != new_map.m_buckets_data.end(); ++it_bucket) {
if(it_bucket->empty()) {
continue;
}
const std::size_t hash = use_stored_hash?
it_bucket->truncated_bucket_hash():
hash_key(KeySelect()(it_bucket->value()));
const std::size_t ibucket_for_hash = bucket_for_hash(hash);
// The elements we insert were not in the overflow list before the switch.
// They will not go into the overflow list if we roll back the switch.
insert_value(ibucket_for_hash, hash, std::move(it_bucket->value()));
}
throw;
}
new_map.swap(*this);
}
template<typename U = value_type,
typename std::enable_if<std::is_copy_constructible<U>::value &&
!std::is_nothrow_move_constructible<U>::value>::type* = nullptr>
void rehash_impl(size_type count_) {
hopscotch_hash new_map = new_hopscotch_hash(count_);
const bool use_stored_hash = USE_STORED_HASH_ON_REHASH(new_map.bucket_count());
for(const hopscotch_bucket& bucket: m_buckets_data) {
if(bucket.empty()) {
continue;
}
const std::size_t hash = use_stored_hash?
bucket.truncated_bucket_hash():
new_map.hash_key(KeySelect()(bucket.value()));
const std::size_t ibucket_for_hash = new_map.bucket_for_hash(hash);
new_map.insert_value(ibucket_for_hash, hash, bucket.value());
}
for(const value_type& value: m_overflow_elements) {
const std::size_t hash = new_map.hash_key(KeySelect()(value));
const std::size_t ibucket_for_hash = new_map.bucket_for_hash(hash);
new_map.insert_value(ibucket_for_hash, hash, value);
}
new_map.swap(*this);
}
#ifdef TSL_HH_NO_RANGE_ERASE_WITH_CONST_ITERATOR
iterator_overflow mutable_overflow_iterator(const_iterator_overflow it) {
return std::next(m_overflow_elements.begin(), std::distance(m_overflow_elements.cbegin(), it));
}
#else
iterator_overflow mutable_overflow_iterator(const_iterator_overflow it) {
return m_overflow_elements.erase(it, it);
}
#endif
// iterator is in overflow list
iterator_overflow erase_from_overflow(const_iterator_overflow pos, std::size_t ibucket_for_hash) {
#ifdef TSL_HH_NO_RANGE_ERASE_WITH_CONST_ITERATOR
auto it_next = m_overflow_elements.erase(mutable_overflow_iterator(pos));
#else
auto it_next = m_overflow_elements.erase(pos);
#endif
m_nb_elements--;
// Check if we can remove the overflow flag
tsl_hh_assert(m_buckets[ibucket_for_hash].has_overflow());
for(const value_type& value: m_overflow_elements) {
const std::size_t bucket_for_value = bucket_for_hash(hash_key(KeySelect()(value)));
if(bucket_for_value == ibucket_for_hash) {
return it_next;
}
}
m_buckets[ibucket_for_hash].set_overflow(false);
return it_next;
}
/**
* bucket_for_value is the bucket in which the value is.
* ibucket_for_hash is the bucket where the value belongs.
*/
void erase_from_bucket(hopscotch_bucket& bucket_for_value, std::size_t ibucket_for_hash) noexcept {
const std::size_t ibucket_for_value = std::distance(m_buckets_data.data(), &bucket_for_value);
tsl_hh_assert(ibucket_for_value >= ibucket_for_hash);
bucket_for_value.remove_value();
m_buckets[ibucket_for_hash].toggle_neighbor_presence(ibucket_for_value - ibucket_for_hash);
m_nb_elements--;
}
template<class K, class M>
std::pair<iterator, bool> insert_or_assign_impl(K&& key, M&& obj) {
auto it = try_emplace_impl(std::forward<K>(key), std::forward<M>(obj));
if(!it.second) {
it.first.value() = std::forward<M>(obj);
}
return it;
}
template<typename P, class... Args>
std::pair<iterator, bool> try_emplace_impl(P&& key, Args&&... args_value) {
const std::size_t hash = hash_key(key);
const std::size_t ibucket_for_hash = bucket_for_hash(hash);
// Check if the key is already present
auto it_find = find_impl(key, hash, m_buckets + ibucket_for_hash);
if(it_find != end()) {
return std::make_pair(it_find, false);
}
return insert_value(ibucket_for_hash, hash, std::piecewise_construct,
std::forward_as_tuple(std::forward<P>(key)),
std::forward_as_tuple(std::forward<Args>(args_value)...));
}
template<typename P>
std::pair<iterator, bool> insert_impl(P&& value) {
const std::size_t hash = hash_key(KeySelect()(value));
const std::size_t ibucket_for_hash = bucket_for_hash(hash);
// Check if the key is already present
auto it_find = find_impl(KeySelect()(value), hash, m_buckets + ibucket_for_hash);
if(it_find != end()) {
return std::make_pair(it_find, false);
}
return insert_value(ibucket_for_hash, hash, std::forward<P>(value));
}
template<typename... Args>
std::pair<iterator, bool> insert_value(std::size_t ibucket_for_hash, std::size_t hash, Args&&... value_type_args) {
if((m_nb_elements - m_overflow_elements.size()) >= m_max_load_threshold_rehash) {
rehash(GrowthPolicy::next_bucket_count());
ibucket_for_hash = bucket_for_hash(hash);
}
std::size_t ibucket_empty = find_empty_bucket(ibucket_for_hash);
if(ibucket_empty < m_buckets_data.size()) {
do {
tsl_hh_assert(ibucket_empty >= ibucket_for_hash);
// Empty bucket is in range of NeighborhoodSize, use it
if(ibucket_empty - ibucket_for_hash < NeighborhoodSize) {
auto it = insert_in_bucket(ibucket_empty, ibucket_for_hash,
hash, std::forward<Args>(value_type_args)...);
return std::make_pair(iterator(it, m_buckets_data.end(), m_overflow_elements.begin()), true);
}
}
// else, try to swap values to get a closer empty bucket
while(swap_empty_bucket_closer(ibucket_empty));
}
auto it = insert_in_overflow(ibucket_for_hash, std::forward<Args>(value_type_args)...);
return std::make_pair(iterator(m_buckets_data.end(), m_buckets_data.end(), it), true);
// Never rehash here for memory safety
//////////////////////////////////////////////////////////////////////////////////////////////////////
// Load factor is too low or a rehash will not change the neighborhood, put the value in overflow list
// if(size() < m_min_load_threshold_rehash || !will_neighborhood_change_on_rehash(ibucket_for_hash)) {
// auto it = insert_in_overflow(ibucket_for_hash, std::forward<Args>(value_type_args)...);
// return std::make_pair(iterator(m_buckets_data.end(), m_buckets_data.end(), it), true);
// }
// rehash(GrowthPolicy::next_bucket_count());
// ibucket_for_hash = bucket_for_hash(hash);
// return insert_value(ibucket_for_hash, hash, std::forward<Args>(value_type_args)...);
//////////////////////////////////////////////////////////////////////////////////////////////////////
}
/*
* Return true if a rehash will change the position of a key-value in the neighborhood of
* ibucket_neighborhood_check. In this case a rehash is needed instead of putting the value in the overflow list.
*/
bool will_neighborhood_change_on_rehash(size_t ibucket_neighborhood_check) const {
std::size_t expand_bucket_count = GrowthPolicy::next_bucket_count();
GrowthPolicy expand_growth_policy(expand_bucket_count);
const bool use_stored_hash = USE_STORED_HASH_ON_REHASH(expand_bucket_count);
for(size_t ibucket = ibucket_neighborhood_check;
ibucket < m_buckets_data.size() && (ibucket - ibucket_neighborhood_check) < NeighborhoodSize;
++ibucket)
{
tsl_hh_assert(!m_buckets[ibucket].empty());
const size_t hash = use_stored_hash?
m_buckets[ibucket].truncated_bucket_hash():
hash_key(KeySelect()(m_buckets[ibucket].value()));
if(bucket_for_hash(hash) != expand_growth_policy.bucket_for_hash(hash)) {
return true;
}
}
return false;
}
/*
* Return the index of an empty bucket in m_buckets_data.
* If none, the returned index equals m_buckets_data.size()
*/
std::size_t find_empty_bucket(std::size_t ibucket_start) const {
const std::size_t limit = std::min(ibucket_start + MAX_PROBES_FOR_EMPTY_BUCKET, m_buckets_data.size());
for(; ibucket_start < limit; ibucket_start++) {
if(m_buckets[ibucket_start].empty()) {
return ibucket_start;
}
}
return m_buckets_data.size();
}
/*
* Insert value in ibucket_empty where value originally belongs to ibucket_for_hash
*
* Return bucket iterator to ibucket_empty
*/
template<typename... Args>
iterator_buckets insert_in_bucket(std::size_t ibucket_empty, std::size_t ibucket_for_hash,
std::size_t hash, Args&&... value_type_args)
{
tsl_hh_assert(ibucket_empty >= ibucket_for_hash );
tsl_hh_assert(m_buckets[ibucket_empty].empty());
m_buckets[ibucket_empty].set_value_of_empty_bucket(hopscotch_bucket::truncate_hash(hash), std::forward<Args>(value_type_args)...);
tsl_hh_assert(!m_buckets[ibucket_for_hash].empty());
m_buckets[ibucket_for_hash].toggle_neighbor_presence(ibucket_empty - ibucket_for_hash);
m_nb_elements++;
return m_buckets_data.begin() + ibucket_empty;
}
template<class... Args, class U = OverflowContainer, typename std::enable_if<!has_key_compare<U>::value>::type* = nullptr>
iterator_overflow insert_in_overflow(std::size_t ibucket_for_hash, Args&&... value_type_args) {
auto it = m_overflow_elements.emplace(m_overflow_elements.end(), std::forward<Args>(value_type_args)...);
m_buckets[ibucket_for_hash].set_overflow(true);
m_nb_elements++;
return it;
}
template<class... Args, class U = OverflowContainer, typename std::enable_if<has_key_compare<U>::value>::type* = nullptr>
iterator_overflow insert_in_overflow(std::size_t ibucket_for_hash, Args&&... value_type_args) {
auto it = m_overflow_elements.emplace(std::forward<Args>(value_type_args)...).first;
m_buckets[ibucket_for_hash].set_overflow(true);
m_nb_elements++;
return it;
}
/*
* Try to swap the bucket ibucket_empty_in_out with a bucket preceding it while keeping the neighborhood
* conditions correct.
*
* If a swap was possible, the position of ibucket_empty_in_out will be closer to 0 and true will be returned.
*/
bool swap_empty_bucket_closer(std::size_t& ibucket_empty_in_out) {
tsl_hh_assert(ibucket_empty_in_out >= NeighborhoodSize);
const std::size_t neighborhood_start = ibucket_empty_in_out - NeighborhoodSize + 1;
for(std::size_t to_check = neighborhood_start; to_check < ibucket_empty_in_out; to_check++) {
neighborhood_bitmap neighborhood_infos = m_buckets[to_check].neighborhood_infos();
std::size_t to_swap = to_check;
while(neighborhood_infos != 0 && to_swap < ibucket_empty_in_out) {
if((neighborhood_infos & 1) == 1) {
tsl_hh_assert(m_buckets[ibucket_empty_in_out].empty());
tsl_hh_assert(!m_buckets[to_swap].empty());
m_buckets[to_swap].swap_value_into_empty_bucket(m_buckets[ibucket_empty_in_out]);
tsl_hh_assert(!m_buckets[to_check].check_neighbor_presence(ibucket_empty_in_out - to_check));
tsl_hh_assert(m_buckets[to_check].check_neighbor_presence(to_swap - to_check));
m_buckets[to_check].toggle_neighbor_presence(ibucket_empty_in_out - to_check);
m_buckets[to_check].toggle_neighbor_presence(to_swap - to_check);
ibucket_empty_in_out = to_swap;
return true;
}
to_swap++;
neighborhood_infos = neighborhood_bitmap(neighborhood_infos >> 1);
}
}
return false;
}
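/*
 * Worked example of the swap above, with NeighborhoodSize = 4: if the empty
 * bucket is at index 10, the checked neighborhoods start at index 7. Suppose
 * bucket 7 owns a value stored at index 8 (bit 1 of its neighborhood set).
 * That value is moved into bucket 10 (still inside the neighborhood of 7
 * since 10 - 7 < 4), bits 3 and 1 of bucket 7's neighborhood are toggled, and
 * ibucket_empty_in_out becomes 8, two positions closer to the target bucket.
 */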
template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
typename U::value_type* find_value_impl(const K& key, std::size_t hash, hopscotch_bucket* bucket_for_hash) {
return const_cast<typename U::value_type*>(
static_cast<const hopscotch_hash*>(this)->find_value_impl(key, hash, bucket_for_hash));
}
/*
* Avoid the creation of an iterator to just get the value for operator[] and at() in maps. Faster this way.
*
* Return nullptr if there is no value for the key (TODO use std::optional when available).
*/
template<class K, class U = ValueSelect, typename std::enable_if<has_mapped_type<U>::value>::type* = nullptr>
const typename U::value_type* find_value_impl(const K& key, std::size_t hash,
const hopscotch_bucket* bucket_for_hash) const
{
const hopscotch_bucket* bucket_found = find_in_buckets(key, hash, bucket_for_hash);
if(bucket_found != nullptr) {
return std::addressof(ValueSelect()(bucket_found->value()));
}
if(bucket_for_hash->has_overflow()) {
auto it_overflow = find_in_overflow(key);
if(it_overflow != m_overflow_elements.end()) {
return std::addressof(ValueSelect()(*it_overflow));
}
}
return nullptr;
}
template<class K>
size_type count_impl(const K& key, std::size_t hash, const hopscotch_bucket* bucket_for_hash) const {
if(find_in_buckets(key, hash, bucket_for_hash) != nullptr) {
return 1;
}
else if(bucket_for_hash->has_overflow() && find_in_overflow(key) != m_overflow_elements.cend()) {
return 1;
}
else {
return 0;
}
}
template<class K>
iterator find_impl(const K& key, std::size_t hash, hopscotch_bucket* bucket_for_hash) {
hopscotch_bucket* bucket_found = find_in_buckets(key, hash, bucket_for_hash);
if(bucket_found != nullptr) {
return iterator(m_buckets_data.begin() + std::distance(m_buckets_data.data(), bucket_found),
m_buckets_data.end(), m_overflow_elements.begin());
}
if(!bucket_for_hash->has_overflow()) {
return end();
}
return iterator(m_buckets_data.end(), m_buckets_data.end(), find_in_overflow(key));
}
template<class K>
const_iterator find_impl(const K& key, std::size_t hash, const hopscotch_bucket* bucket_for_hash) const {
const hopscotch_bucket* bucket_found = find_in_buckets(key, hash, bucket_for_hash);
if(bucket_found != nullptr) {
return const_iterator(m_buckets_data.cbegin() + std::distance(m_buckets_data.data(), bucket_found),
m_buckets_data.cend(), m_overflow_elements.cbegin());
}
if(!bucket_for_hash->has_overflow()) {
return cend();
}
return const_iterator(m_buckets_data.cend(), m_buckets_data.cend(), find_in_overflow(key));
}
template<class K>
hopscotch_bucket* find_in_buckets(const K& key, std::size_t hash, hopscotch_bucket* bucket_for_hash) {
const hopscotch_bucket* bucket_found =
static_cast<const hopscotch_hash*>(this)->find_in_buckets(key, hash, bucket_for_hash);
return const_cast<hopscotch_bucket*>(bucket_found);
}
/**
* Return a pointer to the bucket which has the value, nullptr otherwise.
*/
template<class K>
const hopscotch_bucket* find_in_buckets(const K& key, std::size_t hash, const hopscotch_bucket* bucket_for_hash) const {
(void) hash; // Avoid an unused variable warning when StoreHash is false.
// TODO Try to optimize the function.
// I tried to use ffs and __builtin_ffs functions but I could not reduce the time the function
// takes with -march=native
neighborhood_bitmap neighborhood_infos = bucket_for_hash->neighborhood_infos();
while(neighborhood_infos != 0) {
if((neighborhood_infos & 1) == 1) {
// Check StoreHash before calling bucket_hash_equal. Functionally it doesn't change anything.
// If StoreHash is false, bucket_hash_equal is a no-op. Avoiding the call helps GCC optimize
// the `hash` parameter away; it seems unable to do so without this hint.
if((!StoreHash || bucket_for_hash->bucket_hash_equal(hash)) &&
compare_keys(KeySelect()(bucket_for_hash->value()), key))
{
return bucket_for_hash;
}
}
++bucket_for_hash;
neighborhood_infos = neighborhood_bitmap(neighborhood_infos >> 1);
}
return nullptr;
}
template<class K, class U = OverflowContainer, typename std::enable_if<!has_key_compare<U>::value>::type* = nullptr>
iterator_overflow find_in_overflow(const K& key) {
return std::find_if(m_overflow_elements.begin(), m_overflow_elements.end(),
[&](const value_type& value) {
return compare_keys(key, KeySelect()(value));
});
}
template<class K, class U = OverflowContainer, typename std::enable_if<!has_key_compare<U>::value>::type* = nullptr>
const_iterator_overflow find_in_overflow(const K& key) const {
return std::find_if(m_overflow_elements.cbegin(), m_overflow_elements.cend(),
[&](const value_type& value) {
return compare_keys(key, KeySelect()(value));
});
}
template<class K, class U = OverflowContainer, typename std::enable_if<has_key_compare<U>::value>::type* = nullptr>
iterator_overflow find_in_overflow(const K& key) {
return m_overflow_elements.find(key);
}
template<class K, class U = OverflowContainer, typename std::enable_if<has_key_compare<U>::value>::type* = nullptr>
const_iterator_overflow find_in_overflow(const K& key) const {
return m_overflow_elements.find(key);
}
template<class U = OverflowContainer, typename std::enable_if<!has_key_compare<U>::value>::type* = nullptr>
hopscotch_hash new_hopscotch_hash(size_type bucket_count) {
return hopscotch_hash(bucket_count, static_cast<Hash&>(*this), static_cast<KeyEqual&>(*this),
get_allocator(), m_max_load_factor);
}
template<class U = OverflowContainer, typename std::enable_if<has_key_compare<U>::value>::type* = nullptr>
hopscotch_hash new_hopscotch_hash(size_type bucket_count) {
return hopscotch_hash(bucket_count, static_cast<Hash&>(*this), static_cast<KeyEqual&>(*this),
get_allocator(), m_max_load_factor, m_overflow_elements.key_comp());
}
public:
static const size_type DEFAULT_INIT_BUCKETS_SIZE = 0;
static constexpr float DEFAULT_MAX_LOAD_FACTOR = (NeighborhoodSize <= 30) ? 0.8f : 0.9f;
private:
static const std::size_t MAX_PROBES_FOR_EMPTY_BUCKET = 12*NeighborhoodSize;
static constexpr float MIN_LOAD_FACTOR_FOR_REHASH = 0.2f;
/**
 * We can only reuse the stored hash on rehash if the size of the hash type is the same as the
 * stored one, or if we use a power of two modulo. With a power of two modulo, the bucket index
 * is just the least significant bits of the hash, so we only have to check that
 * truncated_hash_type didn't truncate away the bits we need (e.g. a 32-bit
 * truncated_hash_type supports any bucket_count up to 2^32).
 */
template<class T = size_type, typename std::enable_if<std::is_same<T, truncated_hash_type>::value>::type* = nullptr>
static bool USE_STORED_HASH_ON_REHASH(size_type /*bucket_count*/) {
return StoreHash;
}
template<class T = size_type, typename std::enable_if<!std::is_same<T, truncated_hash_type>::value>::type* = nullptr>
static bool USE_STORED_HASH_ON_REHASH(size_type bucket_count) {
(void) bucket_count;
if(StoreHash && is_power_of_two_policy<GrowthPolicy>::value) {
tsl_hh_assert(bucket_count > 0);
return (bucket_count - 1) <= std::numeric_limits<truncated_hash_type>::max();
}
else {
return false;
}
}
/**
 * Return an always valid pointer to a static empty hopscotch_bucket.
*/
hopscotch_bucket* static_empty_bucket_ptr() {
static hopscotch_bucket empty_bucket;
return &empty_bucket;
}
private:
buckets_container_type m_buckets_data;
overflow_container_type m_overflow_elements;
/**
 * Points to m_buckets_data.data() if !m_buckets_data.empty(), otherwise points to static_empty_bucket_ptr().
* This variable is useful to avoid the cost of checking if m_buckets_data is empty when trying
* to find an element.
*
* TODO Remove m_buckets_data and only use a pointer+size instead of a pointer+vector to save some space in the hopscotch_hash object.
*/
hopscotch_bucket* m_buckets;
size_type m_nb_elements;
float m_max_load_factor;
/**
* Max size of the hash table before a rehash occurs automatically to grow the table.
*/
size_type m_max_load_threshold_rehash;
/**
 * Min size of the hash table before a rehash can occur automatically (except if m_max_load_threshold_rehash is reached).
 * If the neighborhood of a bucket is full before this minimum is reached, the elements are put into m_overflow_elements.
*/
size_type m_min_load_threshold_rehash;
};
} // end namespace detail_hopscotch_hash
} // end namespace tsl
#endif
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_HOPSCOTCH_MAP_H
#define TSL_HOPSCOTCH_MAP_H
#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <list>
#include <memory>
#include <type_traits>
#include <utility>
#include "paddle/fluid/feed/src/common/hopscotch_hash.h"
namespace tsl {
/**
* Implementation of a hash map using the hopscotch hashing algorithm.
*
 * The Key and the value T must be either nothrow move-constructible, copy-constructible or both.
*
* The size of the neighborhood (NeighborhoodSize) must be > 0 and <= 62 if StoreHash is false.
* When StoreHash is true, 32-bits of the hash will be stored alongside the neighborhood limiting
* the NeighborhoodSize to <= 30. There is no memory usage difference between
* 'NeighborhoodSize 62; StoreHash false' and 'NeighborhoodSize 30; StoreHash true'.
*
* Storing the hash may improve performance on insert during the rehash process if the hash takes time
* to compute. It may also improve read performance if the KeyEqual function takes time (or incurs a cache-miss).
* If used with simple Hash and KeyEqual it may slow things down.
*
* StoreHash can only be set if the GrowthPolicy is set to tsl::power_of_two_growth_policy.
*
* GrowthPolicy defines how the map grows and consequently how a hash value is mapped to a bucket.
* By default the map uses tsl::power_of_two_growth_policy. This policy keeps the number of buckets
* to a power of two and uses a mask to map the hash to a bucket instead of the slow modulo.
* You may define your own growth policy, check tsl::power_of_two_growth_policy for the interface.
*
* If the destructors of Key or T throw an exception, behaviour of the class is undefined.
*
* Iterators invalidation:
* - clear, operator=, reserve, rehash: always invalidate the iterators.
* - insert, emplace, emplace_hint, operator[]: if there is an effective insert, invalidate the iterators
 *                                              if a displacement is needed to resolve a collision (which means that most of the time,
 *                                              insert will invalidate the iterators), or if there is a rehash.
 * - erase: the iterator on the erased element is the only one which becomes invalid.
*/
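/*
 * A usage sketch (illustrative only, not part of the original header), using only the
 * interface declared below; it needs <iostream> and <cassert>:
 *
 *   tsl::hopscotch_map<std::string, int> map = {{"foo", 1}, {"bar", 2}};
 *   map["baz"] = 3;                      // operator[] default-inserts missing keys
 *   map.insert({"qux", 4});
 *   auto it = map.find("foo");
 *   if(it != map.end()) {
 *       std::cout << it->first << " -> " << it->second << '\n';
 *   }
 *   map.erase("bar");
 *   assert(map.count("bar") == 0);
 */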
template<class Key,
class T,
class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = std::allocator<std::pair<Key, T>>,
unsigned int NeighborhoodSize = 62,
bool StoreHash = false,
class GrowthPolicy = tsl::hh::power_of_two_growth_policy<2>>
class hopscotch_map {
private:
template<typename U>
using has_is_transparent = tsl::detail_hopscotch_hash::has_is_transparent<U>;
class KeySelect {
public:
using key_type = Key;
const key_type& operator()(const std::pair<Key, T>& key_value) const {
return key_value.first;
}
key_type& operator()(std::pair<Key, T>& key_value) {
return key_value.first;
}
};
class ValueSelect {
public:
using value_type = T;
const value_type& operator()(const std::pair<Key, T>& key_value) const {
return key_value.second;
}
value_type& operator()(std::pair<Key, T>& key_value) {
return key_value.second;
}
};
using overflow_container_type = std::list<std::pair<Key, T>, Allocator>;
using ht = detail_hopscotch_hash::hopscotch_hash<std::pair<Key, T>, KeySelect, ValueSelect,
Hash, KeyEqual,
Allocator, NeighborhoodSize,
StoreHash, GrowthPolicy,
overflow_container_type>;
public:
using key_type = typename ht::key_type;
using mapped_type = T;
using value_type = typename ht::value_type;
using size_type = typename ht::size_type;
using difference_type = typename ht::difference_type;
using hasher = typename ht::hasher;
using key_equal = typename ht::key_equal;
using allocator_type = typename ht::allocator_type;
using reference = typename ht::reference;
using const_reference = typename ht::const_reference;
using pointer = typename ht::pointer;
using const_pointer = typename ht::const_pointer;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
/*
* Constructors
*/
hopscotch_map() : hopscotch_map(ht::DEFAULT_INIT_BUCKETS_SIZE) {
}
explicit hopscotch_map(size_type bucket_count,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()) :
m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR)
{
}
hopscotch_map(size_type bucket_count,
const Allocator& alloc) : hopscotch_map(bucket_count, Hash(), KeyEqual(), alloc)
{
}
hopscotch_map(size_type bucket_count,
const Hash& hash,
const Allocator& alloc) : hopscotch_map(bucket_count, hash, KeyEqual(), alloc)
{
}
explicit hopscotch_map(const Allocator& alloc) : hopscotch_map(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {
}
template<class InputIt>
hopscotch_map(InputIt first, InputIt last,
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()) : hopscotch_map(bucket_count, hash, equal, alloc)
{
insert(first, last);
}
template<class InputIt>
hopscotch_map(InputIt first, InputIt last,
size_type bucket_count,
const Allocator& alloc) : hopscotch_map(first, last, bucket_count, Hash(), KeyEqual(), alloc)
{
}
template<class InputIt>
hopscotch_map(InputIt first, InputIt last,
size_type bucket_count,
const Hash& hash,
const Allocator& alloc) : hopscotch_map(first, last, bucket_count, hash, KeyEqual(), alloc)
{
}
hopscotch_map(std::initializer_list<value_type> init,
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()) :
hopscotch_map(init.begin(), init.end(), bucket_count, hash, equal, alloc)
{
}
hopscotch_map(std::initializer_list<value_type> init,
size_type bucket_count,
const Allocator& alloc) :
hopscotch_map(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc)
{
}
hopscotch_map(std::initializer_list<value_type> init,
size_type bucket_count,
const Hash& hash,
const Allocator& alloc) :
hopscotch_map(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc)
{
}
hopscotch_map& operator=(std::initializer_list<value_type> ilist) {
m_ht.clear();
m_ht.reserve(ilist.size());
m_ht.insert(ilist.begin(), ilist.end());
return *this;
}
allocator_type get_allocator() const { return m_ht.get_allocator(); }
/*
* Iterators
*/
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
* Capacity
*/
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
/*
* Modifiers
*/
void clear() noexcept { m_ht.clear(); }
std::pair<iterator, bool> insert(const value_type& value) {
return m_ht.insert(value);
}
template<class P, typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
std::pair<iterator, bool> insert(P&& value) {
return m_ht.insert(std::forward<P>(value));
}
std::pair<iterator, bool> insert(value_type&& value) {
return m_ht.insert(std::move(value));
}
iterator insert(const_iterator hint, const value_type& value) {
return m_ht.insert(hint, value);
}
template<class P, typename std::enable_if<std::is_constructible<value_type, P&&>::value>::type* = nullptr>
iterator insert(const_iterator hint, P&& value) {
return m_ht.insert(hint, std::forward<P>(value));
}
iterator insert(const_iterator hint, value_type&& value) {
return m_ht.insert(hint, std::move(value));
}
template<class InputIt>
void insert(InputIt first, InputIt last) {
m_ht.insert(first, last);
}
void insert(std::initializer_list<value_type> ilist) {
m_ht.insert(ilist.begin(), ilist.end());
}
template<class M>
std::pair<iterator, bool> insert_or_assign(const key_type& k, M&& obj) {
return m_ht.insert_or_assign(k, std::forward<M>(obj));
}
template<class M>
std::pair<iterator, bool> insert_or_assign(key_type&& k, M&& obj) {
return m_ht.insert_or_assign(std::move(k), std::forward<M>(obj));
}
template<class M>
iterator insert_or_assign(const_iterator hint, const key_type& k, M&& obj) {
return m_ht.insert_or_assign(hint, k, std::forward<M>(obj));
}
template<class M>
iterator insert_or_assign(const_iterator hint, key_type&& k, M&& obj) {
return m_ht.insert_or_assign(hint, std::move(k), std::forward<M>(obj));
}
/**
* Due to the way elements are stored, emplace will need to move or copy the key-value once.
* The method is equivalent to insert(value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_map interface.
*/
template<class... Args>
std::pair<iterator, bool> emplace(Args&&... args) {
return m_ht.emplace(std::forward<Args>(args)...);
}
/**
* Due to the way elements are stored, emplace_hint will need to move or copy the key-value once.
* The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
*
* Mainly here for compatibility with the std::unordered_map interface.
*/
template<class... Args>
iterator emplace_hint(const_iterator hint, Args&&... args) {
return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
}
template<class... Args>
std::pair<iterator, bool> try_emplace(const key_type& k, Args&&... args) {
return m_ht.try_emplace(k, std::forward<Args>(args)...);
}
template<class... Args>
std::pair<iterator, bool> try_emplace(key_type&& k, Args&&... args) {
return m_ht.try_emplace(std::move(k), std::forward<Args>(args)...);
}
template<class... Args>
iterator try_emplace(const_iterator hint, const key_type& k, Args&&... args) {
return m_ht.try_emplace(hint, k, std::forward<Args>(args)...);
}
template<class... Args>
iterator try_emplace(const_iterator hint, key_type&& k, Args&&... args) {
return m_ht.try_emplace(hint, std::move(k), std::forward<Args>(args)...);
}
iterator erase(iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); }
size_type erase(const key_type& key) { return m_ht.erase(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup of the value if you already have the hash.
*/
size_type erase(const key_type& key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type erase(const K& key) { return m_ht.erase(key); }
/**
* @copydoc erase(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup of the value if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type erase(const K& key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
void swap(hopscotch_map& other) { other.m_ht.swap(m_ht); }
/*
* Lookup
*/
T& at(const Key& key) { return m_ht.at(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
T& at(const Key& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); }
const T& at(const Key& key) const { return m_ht.at(key); }
/**
* @copydoc at(const Key& key, std::size_t precalculated_hash)
*/
const T& at(const Key& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); }
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
T& at(const K& key) { return m_ht.at(key); }
/**
* @copydoc at(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
T& at(const K& key, std::size_t precalculated_hash) { return m_ht.at(key, precalculated_hash); }
/**
* @copydoc at(const K& key)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const T& at(const K& key) const { return m_ht.at(key); }
/**
* @copydoc at(const K& key, std::size_t precalculated_hash)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const T& at(const K& key, std::size_t precalculated_hash) const { return m_ht.at(key, precalculated_hash); }
T& operator[](const Key& key) { return m_ht[key]; }
T& operator[](Key&& key) { return m_ht[std::move(key)]; }
size_type count(const Key& key) const { return m_ht.count(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
size_type count(const Key& key, std::size_t precalculated_hash) const {
return m_ht.count(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type count(const K& key) const { return m_ht.count(key); }
/**
* @copydoc count(const K& key) const
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type count(const K& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
iterator find(const Key& key) { return m_ht.find(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
const_iterator find(const Key& key) const { return m_ht.find(key); }
/**
* @copydoc find(const Key& key, std::size_t precalculated_hash)
*/
const_iterator find(const Key& key, std::size_t precalculated_hash) const {
return m_ht.find(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
iterator find(const K& key) { return m_ht.find(key); }
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
/**
* @copydoc find(const K& key)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const_iterator find(const K& key) const { return m_ht.find(key); }
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const_iterator find(const K& key, std::size_t precalculated_hash) const {
return m_ht.find(key, precalculated_hash);
}
std::pair<iterator, iterator> equal_range(const Key& key) { return m_ht.equal_range(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
std::pair<iterator, iterator> equal_range(const Key& key, std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
std::pair<const_iterator, const_iterator> equal_range(const Key& key) const { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range(const Key& key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key) { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key, std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* @copydoc equal_range(const K& key)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key) const { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const K& key, std::size_t precalculated_hash)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/*
* Bucket interface
*/
size_type bucket_count() const { return m_ht.bucket_count(); }
size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
/*
* Hash policy
*/
float load_factor() const { return m_ht.load_factor(); }
float max_load_factor() const { return m_ht.max_load_factor(); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
void rehash(size_type count_) { m_ht.rehash(count_); }
void reserve(size_type count_) { m_ht.reserve(count_); }
/*
* Observers
*/
hasher hash_function() const { return m_ht.hash_function(); }
key_equal key_eq() const { return m_ht.key_eq(); }
/*
* Other
*/
/**
* Convert a const_iterator to an iterator.
*/
iterator mutable_iterator(const_iterator pos) {
return m_ht.mutable_iterator(pos);
}
size_type overflow_size() const noexcept { return m_ht.overflow_size(); }
friend bool operator==(const hopscotch_map& lhs, const hopscotch_map& rhs) {
if(lhs.size() != rhs.size()) {
return false;
}
for(const auto& element_lhs : lhs) {
const auto it_element_rhs = rhs.find(element_lhs.first);
if(it_element_rhs == rhs.cend() || element_lhs.second != it_element_rhs->second) {
return false;
}
}
return true;
}
friend bool operator!=(const hopscotch_map& lhs, const hopscotch_map& rhs) {
return !operator==(lhs, rhs);
}
friend void swap(hopscotch_map& lhs, hopscotch_map& rhs) {
lhs.swap(rhs);
}
private:
ht m_ht;
};
/**
* Same as `tsl::hopscotch_map<Key, T, Hash, KeyEqual, Allocator, NeighborhoodSize, StoreHash, tsl::hh::prime_growth_policy>`.
*/
template<class Key,
class T,
class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = std::allocator<std::pair<Key, T>>,
unsigned int NeighborhoodSize = 62,
bool StoreHash = false>
using hopscotch_pg_map = hopscotch_map<Key, T, Hash, KeyEqual, Allocator, NeighborhoodSize, StoreHash, tsl::hh::prime_growth_policy>;
} // end namespace tsl
#endif
/**
* MIT License
*
* Copyright (c) 2017 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_HOPSCOTCH_SET_H
#define TSL_HOPSCOTCH_SET_H
#include <algorithm>
#include <cstddef>
#include <functional>
#include <initializer_list>
#include <list>
#include <memory>
#include <type_traits>
#include <utility>
#include "paddle/fluid/feed/src/common/hopscotch_hash.h"
namespace tsl {
/**
* Implementation of a hash set using the hopscotch hashing algorithm.
*
 * The Key must be either nothrow move-constructible, copy-constructible or both.
*
* The size of the neighborhood (NeighborhoodSize) must be > 0 and <= 62 if StoreHash is false.
* When StoreHash is true, 32-bits of the hash will be stored alongside the neighborhood limiting
* the NeighborhoodSize to <= 30. There is no memory usage difference between
* 'NeighborhoodSize 62; StoreHash false' and 'NeighborhoodSize 30; StoreHash true'.
*
* Storing the hash may improve performance on insert during the rehash process if the hash takes time
* to compute. It may also improve read performance if the KeyEqual function takes time (or incurs a cache-miss).
* If used with simple Hash and KeyEqual it may slow things down.
*
* StoreHash can only be set if the GrowthPolicy is set to tsl::power_of_two_growth_policy.
*
* GrowthPolicy defines how the set grows and consequently how a hash value is mapped to a bucket.
* By default the set uses tsl::power_of_two_growth_policy. This policy keeps the number of buckets
 * to a power of two and uses a mask to map the hash to a bucket instead of the slow modulo.
* You may define your own growth policy, check tsl::power_of_two_growth_policy for the interface.
*
* If the destructor of Key throws an exception, behaviour of the class is undefined.
*
* Iterators invalidation:
* - clear, operator=, reserve, rehash: always invalidate the iterators.
* - insert, emplace, emplace_hint, operator[]: if there is an effective insert, invalidate the iterators
 *                                              if a displacement is needed to resolve a collision (which means that most of the time,
 *                                              insert will invalidate the iterators), or if there is a rehash.
 * - erase: the iterator on the erased element is the only one which becomes invalid.
*/
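/*
 * A usage sketch (illustrative only, not part of the original header); it needs <cassert>:
 *
 *   tsl::hopscotch_set<std::string> set = {"foo", "bar"};
 *   set.insert("baz");
 *   if(set.find("foo") != set.end()) {
 *       set.erase("foo");
 *   }
 *   assert(set.count("foo") == 0);
 */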
template<class Key,
class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = std::allocator<Key>,
unsigned int NeighborhoodSize = 62,
bool StoreHash = false,
class GrowthPolicy = tsl::hh::power_of_two_growth_policy<2>>
class hopscotch_set {
private:
template<typename U>
using has_is_transparent = tsl::detail_hopscotch_hash::has_is_transparent<U>;
class KeySelect {
public:
using key_type = Key;
const key_type& operator()(const Key& key) const {
return key;
}
key_type& operator()(Key& key) {
return key;
}
};
using overflow_container_type = std::list<Key, Allocator>;
using ht = detail_hopscotch_hash::hopscotch_hash<Key, KeySelect, void,
Hash, KeyEqual,
Allocator, NeighborhoodSize,
StoreHash, GrowthPolicy,
overflow_container_type>;
public:
using key_type = typename ht::key_type;
using value_type = typename ht::value_type;
using size_type = typename ht::size_type;
using difference_type = typename ht::difference_type;
using hasher = typename ht::hasher;
using key_equal = typename ht::key_equal;
using allocator_type = typename ht::allocator_type;
using reference = typename ht::reference;
using const_reference = typename ht::const_reference;
using pointer = typename ht::pointer;
using const_pointer = typename ht::const_pointer;
using iterator = typename ht::iterator;
using const_iterator = typename ht::const_iterator;
/*
* Constructors
*/
hopscotch_set() : hopscotch_set(ht::DEFAULT_INIT_BUCKETS_SIZE) {
}
explicit hopscotch_set(size_type bucket_count,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()) :
m_ht(bucket_count, hash, equal, alloc, ht::DEFAULT_MAX_LOAD_FACTOR)
{
}
hopscotch_set(size_type bucket_count,
const Allocator& alloc) : hopscotch_set(bucket_count, Hash(), KeyEqual(), alloc)
{
}
hopscotch_set(size_type bucket_count,
const Hash& hash,
const Allocator& alloc) : hopscotch_set(bucket_count, hash, KeyEqual(), alloc)
{
}
explicit hopscotch_set(const Allocator& alloc) : hopscotch_set(ht::DEFAULT_INIT_BUCKETS_SIZE, alloc) {
}
template<class InputIt>
hopscotch_set(InputIt first, InputIt last,
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()) : hopscotch_set(bucket_count, hash, equal, alloc)
{
insert(first, last);
}
template<class InputIt>
hopscotch_set(InputIt first, InputIt last,
size_type bucket_count,
const Allocator& alloc) : hopscotch_set(first, last, bucket_count, Hash(), KeyEqual(), alloc)
{
}
template<class InputIt>
hopscotch_set(InputIt first, InputIt last,
size_type bucket_count,
const Hash& hash,
const Allocator& alloc) : hopscotch_set(first, last, bucket_count, hash, KeyEqual(), alloc)
{
}
hopscotch_set(std::initializer_list<value_type> init,
size_type bucket_count = ht::DEFAULT_INIT_BUCKETS_SIZE,
const Hash& hash = Hash(),
const KeyEqual& equal = KeyEqual(),
const Allocator& alloc = Allocator()) :
hopscotch_set(init.begin(), init.end(), bucket_count, hash, equal, alloc)
{
}
hopscotch_set(std::initializer_list<value_type> init,
size_type bucket_count,
const Allocator& alloc) :
hopscotch_set(init.begin(), init.end(), bucket_count, Hash(), KeyEqual(), alloc)
{
}
hopscotch_set(std::initializer_list<value_type> init,
size_type bucket_count,
const Hash& hash,
const Allocator& alloc) :
hopscotch_set(init.begin(), init.end(), bucket_count, hash, KeyEqual(), alloc)
{
}
hopscotch_set& operator=(std::initializer_list<value_type> ilist) {
m_ht.clear();
m_ht.reserve(ilist.size());
m_ht.insert(ilist.begin(), ilist.end());
return *this;
}
allocator_type get_allocator() const { return m_ht.get_allocator(); }
/*
* Iterators
*/
iterator begin() noexcept { return m_ht.begin(); }
const_iterator begin() const noexcept { return m_ht.begin(); }
const_iterator cbegin() const noexcept { return m_ht.cbegin(); }
iterator end() noexcept { return m_ht.end(); }
const_iterator end() const noexcept { return m_ht.end(); }
const_iterator cend() const noexcept { return m_ht.cend(); }
/*
* Capacity
*/
bool empty() const noexcept { return m_ht.empty(); }
size_type size() const noexcept { return m_ht.size(); }
size_type max_size() const noexcept { return m_ht.max_size(); }
/*
* Modifiers
*/
void clear() noexcept { m_ht.clear(); }
std::pair<iterator, bool> insert(const value_type& value) { return m_ht.insert(value); }
std::pair<iterator, bool> insert(value_type&& value) { return m_ht.insert(std::move(value)); }
iterator insert(const_iterator hint, const value_type& value) { return m_ht.insert(hint, value); }
iterator insert(const_iterator hint, value_type&& value) { return m_ht.insert(hint, std::move(value)); }
template<class InputIt>
void insert(InputIt first, InputIt last) { m_ht.insert(first, last); }
void insert(std::initializer_list<value_type> ilist) { m_ht.insert(ilist.begin(), ilist.end()); }
/**
 * Due to the way elements are stored, emplace will need to move or copy the key once.
 * The method is equivalent to insert(value_type(std::forward<Args>(args)...));
 *
 * Mainly here for compatibility with the std::unordered_set interface.
*/
template<class... Args>
std::pair<iterator, bool> emplace(Args&&... args) { return m_ht.emplace(std::forward<Args>(args)...); }
/**
 * Due to the way elements are stored, emplace_hint will need to move or copy the key once.
 * The method is equivalent to insert(hint, value_type(std::forward<Args>(args)...));
 *
 * Mainly here for compatibility with the std::unordered_set interface.
*/
template<class... Args>
iterator emplace_hint(const_iterator hint, Args&&... args) {
return m_ht.emplace_hint(hint, std::forward<Args>(args)...);
}
iterator erase(iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator pos) { return m_ht.erase(pos); }
iterator erase(const_iterator first, const_iterator last) { return m_ht.erase(first, last); }
size_type erase(const key_type& key) { return m_ht.erase(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup of the value if you already have the hash.
*/
size_type erase(const key_type& key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type erase(const K& key) { return m_ht.erase(key); }
/**
* @copydoc erase(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup of the value if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type erase(const K& key, std::size_t precalculated_hash) {
return m_ht.erase(key, precalculated_hash);
}
void swap(hopscotch_set& other) { other.m_ht.swap(m_ht); }
/*
* Lookup
*/
size_type count(const Key& key) const { return m_ht.count(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
size_type count(const Key& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type count(const K& key) const { return m_ht.count(key); }
/**
* @copydoc count(const K& key) const
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
size_type count(const K& key, std::size_t precalculated_hash) const { return m_ht.count(key, precalculated_hash); }
iterator find(const Key& key) { return m_ht.find(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
iterator find(const Key& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
const_iterator find(const Key& key) const { return m_ht.find(key); }
/**
* @copydoc find(const Key& key, std::size_t precalculated_hash)
*/
const_iterator find(const Key& key, std::size_t precalculated_hash) const { return m_ht.find(key, precalculated_hash); }
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
iterator find(const K& key) { return m_ht.find(key); }
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
iterator find(const K& key, std::size_t precalculated_hash) { return m_ht.find(key, precalculated_hash); }
/**
* @copydoc find(const K& key)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const_iterator find(const K& key) const { return m_ht.find(key); }
/**
* @copydoc find(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
const_iterator find(const K& key, std::size_t precalculated_hash) const { return m_ht.find(key, precalculated_hash); }
std::pair<iterator, iterator> equal_range(const Key& key) { return m_ht.equal_range(key); }
/**
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
std::pair<iterator, iterator> equal_range(const Key& key, std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
std::pair<const_iterator, const_iterator> equal_range(const Key& key) const { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const Key& key, std::size_t precalculated_hash)
*/
std::pair<const_iterator, const_iterator> equal_range(const Key& key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* This overload only participates in the overload resolution if the typedef KeyEqual::is_transparent exists.
* If so, K must be hashable and comparable to Key.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key) { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const K& key)
*
* Use the hash value 'precalculated_hash' instead of hashing the key. The hash value should be the same
 * as hash_function()(key). Useful to speed up the lookup if you already have the hash.
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<iterator, iterator> equal_range(const K& key, std::size_t precalculated_hash) {
return m_ht.equal_range(key, precalculated_hash);
}
/**
* @copydoc equal_range(const K& key)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key) const { return m_ht.equal_range(key); }
/**
* @copydoc equal_range(const K& key, std::size_t precalculated_hash)
*/
template<class K, class KE = KeyEqual, typename std::enable_if<has_is_transparent<KE>::value>::type* = nullptr>
std::pair<const_iterator, const_iterator> equal_range(const K& key, std::size_t precalculated_hash) const {
return m_ht.equal_range(key, precalculated_hash);
}
/*
* Bucket interface
*/
size_type bucket_count() const { return m_ht.bucket_count(); }
size_type max_bucket_count() const { return m_ht.max_bucket_count(); }
/*
* Hash policy
*/
float load_factor() const { return m_ht.load_factor(); }
float max_load_factor() const { return m_ht.max_load_factor(); }
void max_load_factor(float ml) { m_ht.max_load_factor(ml); }
void rehash(size_type count_) { m_ht.rehash(count_); }
void reserve(size_type count_) { m_ht.reserve(count_); }
/*
* Observers
*/
hasher hash_function() const { return m_ht.hash_function(); }
key_equal key_eq() const { return m_ht.key_eq(); }
/*
* Other
*/
/**
* Convert a const_iterator to an iterator.
*/
iterator mutable_iterator(const_iterator pos) {
return m_ht.mutable_iterator(pos);
}
size_type overflow_size() const noexcept { return m_ht.overflow_size(); }
friend bool operator==(const hopscotch_set& lhs, const hopscotch_set& rhs) {
if(lhs.size() != rhs.size()) {
return false;
}
for(const auto& element_lhs : lhs) {
const auto it_element_rhs = rhs.find(element_lhs);
if(it_element_rhs == rhs.cend()) {
return false;
}
}
return true;
}
friend bool operator!=(const hopscotch_set& lhs, const hopscotch_set& rhs) {
return !operator==(lhs, rhs);
}
friend void swap(hopscotch_set& lhs, hopscotch_set& rhs) {
lhs.swap(rhs);
}
private:
ht m_ht;
};
/**
* Same as `tsl::hopscotch_set<Key, Hash, KeyEqual, Allocator, NeighborhoodSize, StoreHash, tsl::hh::prime_growth_policy>`.
*/
template<class Key,
class Hash = std::hash<Key>,
class KeyEqual = std::equal_to<Key>,
class Allocator = std::allocator<Key>,
unsigned int NeighborhoodSize = 62,
bool StoreHash = false>
using hopscotch_pg_set = hopscotch_set<Key, Hash, KeyEqual, Allocator, NeighborhoodSize, StoreHash, tsl::hh::prime_growth_policy>;
} // end namespace tsl
#endif
cc_library(feed_data_set SRCS data_set.cc DEPS operator)
#include "paddle/fluid/feed/src/data_reader/data_set.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/io/fs.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace framework {
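// Creates one DataFeed reader per preload thread (preload_thread_num_ falls back to
// thread_num_ when unset) and wires every reader to the shared file list and the common
// input channel; the output and consume channels are left null because preloading only
// fills input_channel_.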
void FeedMultiSlotDataset::CreatePreLoadReaders() {
VLOG(3) << "Begin CreatePreLoadReaders";
if (preload_thread_num_ == 0) {
preload_thread_num_ = thread_num_;
}
CHECK(preload_thread_num_ > 0) << "thread num should be > 0";
CHECK(input_channel_ != nullptr);
preload_readers_.clear();
for (int i = 0; i < preload_thread_num_; ++i) {
preload_readers_.push_back(
DataFeedFactory::CreateDataFeed(data_feed_desc_.name()));
preload_readers_[i]->Init(data_feed_desc_);
preload_readers_[i]->SetThreadId(i);
preload_readers_[i]->SetThreadNum(preload_thread_num_);
preload_readers_[i]->SetFileListMutex(&mutex_for_pick_file_);
preload_readers_[i]->SetFileListIndex(&file_idx_);
preload_readers_[i]->SetFileList(filelist_);
preload_readers_[i]->SetParseInsId(parse_ins_id_);
preload_readers_[i]->SetParseContent(parse_content_);
preload_readers_[i]->SetInputChannel(input_channel_.get());
preload_readers_[i]->SetOutputChannel(nullptr);
preload_readers_[i]->SetConsumeChannel(nullptr);
}
VLOG(3) << "End CreatePreLoadReaders";
}
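// Drains all output channels, sorts the records by ins_id_ and merges each group of
// records sharing an ins_id_ into a single Record. A group is dropped when
// min_merge_size_ > 0 and the group size differs from it, or when two records in the
// group carry the same slot (a conflict); the merged records are shuffled and then
// redistributed evenly across the output channels.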
void FeedMultiSlotDataset::MergeByInsId() {
VLOG(3) << "MultiSlotDataset::MergeByInsId begin";
if (!merge_by_insid_) {
VLOG(3) << "merge_by_insid=false, will not MergeByInsId";
return;
}
auto multi_slot_desc = data_feed_desc_.multi_slot_desc();
std::vector<std::string> use_slots;
for (size_t i = 0; i < multi_slot_desc.slots_size(); ++i) {
const auto& slot = multi_slot_desc.slots(i);
if (slot.is_used()) {
use_slots.push_back(slot.name());
}
}
CHECK(multi_output_channel_.size() != 0); // NOLINT
auto channel_data = paddle::framework::MakeChannel<Record>();
VLOG(3) << "multi_output_channel_.size() " << multi_output_channel_.size();
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
std::vector<Record> vec_data;
multi_output_channel_[i]->Close();
multi_output_channel_[i]->ReadAll(vec_data);
channel_data->Write(std::move(vec_data));
vec_data.clear();
vec_data.shrink_to_fit();
multi_output_channel_[i]->Clear();
}
channel_data->Close();
std::vector<Record> recs;
recs.reserve(channel_data->Size());
channel_data->ReadAll(recs);
channel_data->Clear();
std::sort(recs.begin(), recs.end(), [](const Record& a, const Record& b) {
return a.ins_id_ < b.ins_id_;
});
std::vector<Record> results;
uint64_t drop_ins_num = 0;
std::unordered_set<uint16_t> all_int64;
std::unordered_set<uint16_t> all_float;
std::unordered_set<uint16_t> local_uint64;
std::unordered_set<uint16_t> local_float;
VLOG(3) << "recs.size() " << recs.size();
for (size_t i = 0; i < recs.size();) {
size_t j = i + 1;
while (j < recs.size() && recs[j].ins_id_ == recs[i].ins_id_) {
j++;
}
if (min_merge_size_ > 0 && j - i != min_merge_size_) {
drop_ins_num += j - i;
LOG(WARNING) << "drop ins " << recs[i].ins_id_ << " size=" << j - i
<< ", because merge_size=" << min_merge_size_;
i = j;
continue;
}
all_int64.clear();
all_float.clear();
bool has_conflict_slot = false;
uint16_t conflict_slot = 0;
Record rec;
rec.ins_id_ = recs[i].ins_id_;
rec.content_ = recs[i].content_;
for (size_t k = i; k < j; k++) {
local_uint64.clear();
local_float.clear();
for (auto& feature : recs[k].uint64_feasigns_) {
uint16_t slot = feature.slot();
if (all_int64.find(slot) != all_int64.end()) {
has_conflict_slot = true;
conflict_slot = slot;
break;
}
local_uint64.insert(slot);
rec.uint64_feasigns_.push_back(std::move(feature));
}
if (has_conflict_slot) {
break;
}
all_int64.insert(local_uint64.begin(), local_uint64.end());
for (auto& feature : recs[k].float_feasigns_) {
uint16_t slot = feature.slot();
if (all_float.find(slot) != all_float.end()) {
has_conflict_slot = true;
conflict_slot = slot;
break;
}
local_float.insert(slot);
rec.float_feasigns_.push_back(std::move(feature));
}
if (has_conflict_slot) {
break;
}
all_float.insert(local_float.begin(), local_float.end());
}
if (has_conflict_slot) {
LOG(WARNING) << "drop ins " << recs[i].ins_id_ << " size=" << j - i
<< ", because conflict_slot=" << use_slots[conflict_slot];
drop_ins_num += j - i;
} else {
results.push_back(std::move(rec));
}
i = j;
}
std::vector<Record>().swap(recs);
VLOG(3) << "results size " << results.size();
LOG(WARNING) << "total drop ins num: " << drop_ins_num;
results.shrink_to_fit();
auto fleet_ptr = FleetWrapper::GetInstance();
std::shuffle(results.begin(), results.end(), fleet_ptr->LocalRandomEngine());
channel_data->Open();
channel_data->Write(std::move(results));
channel_data->Close();
results.clear();
results.shrink_to_fit();
VLOG(3) << "channel data size " << channel_data->Size();
channel_data->SetBlockSize(channel_data->Size() / channel_num_ + 1);
VLOG(3) << "channel data block size " << channel_data->BlockSize();
for (size_t i = 0; i < multi_output_channel_.size(); ++i) {
std::vector<Record> vec_data;
channel_data->Read(vec_data);
multi_output_channel_[i]->Open();
multi_output_channel_[i]->Write(std::move(vec_data));
vec_data.clear();
vec_data.shrink_to_fit();
}
CHECK(channel_data->Size() == 0); // NOLINT
channel_data->Clear();
VLOG(3) << "MultiSlotDataset::MergeByInsId end";
}
} // end namespace framework
} // end namespace paddle
#pragma once
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
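// Feed-specific dataset: a MultiSlotDataset whose MergeByInsId and CreatePreLoadReaders
// are overridden in src/data_reader/data_set.cc.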
class FeedMultiSlotDataset : public MultiSlotDataset {
public:
FeedMultiSlotDataset() {}
virtual void MergeByInsId();
virtual void CreatePreLoadReaders();
virtual ~FeedMultiSlotDataset() {}
};
} // end namespace framework
} // end namespace paddle