提交 0dc7d425 编写于 作者: X xiexionghang

feed improve dict/merge_patch

上级 01f2dca6
#!/bin/bash
# build.sh -- configure/build helper for the feed-enabled Paddle tree.
build_mode=$1

# Print usage and exit successfully.
function print_usage() {
    echo "++++++++++++++++++++++++++++++++++++++++++++++++++++"
    echo "sh build.sh all|make|clean"
    echo "- all: will update all env && make it"
    echo "- make: just do make, never update env"
    echo "- clean: make clean"
    echo "++++++++++++++++++++++++++++++++++++++++++++++++++++"
    exit 0
}

if [ $# -lt 1 ];then
    print_usage
fi

# Locate the jumbo-installed python2.7 under the user's home directory.
# Subshell avoids mutating the caller's cwd (the old `cd ~; pwd; cd -`
# also echoed the previous directory to stdout).
user_dir=$(cd ~ && pwd)
python_binary=${user_dir}/.jumbo/bin/python2.7
python_library=${user_dir}/.jumbo/lib/python2.7.so
python_include_dir=${user_dir}/.jumbo/include/python2.7
if [ ! -f "${python_binary}" ];then
    echo "Miss python ${python_binary}, please install with this cmd: jumbo install python"
    exit 1
fi

# Apply feed customization patches before building.
if [ -f "paddle/fluid/feed/apply_feed_code.sh" ];then
    sh paddle/fluid/feed/apply_feed_code.sh
fi
# Build inside ./build. Bail out if the build dir is missing so make never
# runs against the source root (the original unchecked `cd` allowed that).
function makeit() {
    cd build || exit 1
    make -j8
    cd ..
}
# Configure the project with cmake into ./build (created if absent).
# `mkdir -p` makes re-runs idempotent; the original `mkdir build` errored
# when the directory already existed.
function cmake_all() {
    mkdir -p build
    cd build || exit 1
    #make clean
    cmake -DCMAKE_INSTALL_PREFIX=./output/ -DCMAKE_BUILD_TYPE=Release -DWITH_PYTHON=ON -DWITH_MKL=OFF -DWITH_GPU=OFF -DWITH_PSLIB=ON -DPYTHON_INCLUDE_DIR=${python_include_dir} -DPYTHON_LIBRARY=${python_library} -DPYTHON_EXECUTABLE=${python_binary} ..
    cd ..
}
# Dispatch on the requested build mode.
if [ "${build_mode}" = "all" ];then
    cmake_all
    makeit
elif [ "${build_mode}" = "make" ];then
    makeit
# BUGFIX: the original line read `elif "${build_mode}" = "clean" ];then`,
# missing the opening `[`, so "clean" was executed as a command and failed.
elif [ "${build_mode}" = "clean" ];then
    cd build || exit 1
    make clean
    cd ..
else
    # Unknown mode: show help instead of silently doing nothing.
    print_usage
fi
# Feed module layout: core sources (src/) and the python binding layer (pybind/).
add_subdirectory(src)
add_subdirectory(pybind)
#!/bin/bash
# Applies the FEED customization to the Paddle code base (e.g. FEED plugin
# registration). Run before compiling.

# Print an error message and abort the whole script with a failure status.
function fatal_log() {
    echo "$1"
    exit 1
}
# Patch the pybind extension: expose the feed expand API from Paddle's
# python module. Verifies the anchor lines exist exactly once, deletes any
# previously-inserted feed lines (so re-runs are idempotent), then inserts
# the include/registration lines and the extra deps into pybind's CMakeLists.
function apply_pybind() {
    pybind_file='paddle/fluid/pybind/pybind.cc'
    if [ ! -f ${pybind_file} ];then
        fatal_log "Missing Required File:${pybind_file}"
    fi
    # Each anchor must appear exactly once, or the insert points are ambiguous.
    find_inference_api=`grep 'inference_api.h' ${pybind_file} |wc -l`
    if [ ${find_inference_api} -ne 1 ];then
        fatal_log "Missing inference_api.h, Need Code Adjust"
    fi
    find_inference_api=`grep 'BindInferenceApi' ${pybind_file} |wc -l`
    if [ ${find_inference_api} -ne 1 ];then
        fatal_log "Missing BindInferenceApi, Need Code Adjust"
    fi
    makefile='paddle/fluid/pybind/CMakeLists.txt'
    if [ ! -f ${makefile} ];then
        fatal_log "Missing Required File:${makefile}"
    fi
    # Drop lines from any previous run so the patch stays idempotent.
    sed -i '/expand_api/d' ${pybind_file}
    sed -i '/BindExpandApi/d' ${pybind_file}
    sed -i '/feed_data_set/d' ${makefile}
    sed -i '/feed_paddle_pybind/d' ${makefile}
    sed -i '/APPEND PYBIND_DEPS fs/d' ${makefile}
    # Insert the feed include/registration after the anchors, and the extra
    # deps before pybind's source list.
    sed -i '/inference_api.h/a\#include "paddle/fluid/feed/pybind/expand_api.h"' ${pybind_file}
    sed -i '/BindInferenceApi/a\ BindExpandApi(&m);' ${pybind_file}
    sed -i '/set(PYBIND_SRCS/i\list(APPEND PYBIND_DEPS feed_data_set)' ${makefile}
    sed -i '/set(PYBIND_SRCS/i\list(APPEND PYBIND_DEPS feed_paddle_pybind)' ${makefile}
    sed -i '/set(PYBIND_SRCS/i\list(APPEND PYBIND_DEPS fs)' ${makefile}
}
# Wire the feed sources into Paddle's build and dataset factory:
# adds add_subdirectory(feed) to paddle/fluid, registers FeedMultiSlotDataset
# in dataset_factory.cc, scrubbing stale lines first (idempotent re-runs).
function apply_feed_src() {
    makefile='paddle/fluid/CMakeLists.txt'
    if [ ! -f ${makefile} ];then
        fatal_log "Missing Required File:${makefile}"
    fi
    # The pybind anchor must appear exactly once for the insert below.
    find_py=`grep 'pybind' ${makefile} |wc -l`
    if [ ${find_py} -ne 1 ];then
        fatal_log "Missing pybind, Need Code Adjust"
    fi
    sed -i '/feed/d' ${makefile}
    sed -i '/pybind/i\add_subdirectory(feed)' ${makefile}
    dataset_file='paddle/fluid/framework/dataset_factory.cc'
    if [ ! -f ${dataset_file} ];then
        fatal_log "Missing Required File:${dataset_file}"
    fi
    # Remove stale feed registrations, then re-insert after the anchors.
    sed -i '/FeedMultiSlotDataset/d' ${dataset_file}
    sed -i '/data_reader/d' ${dataset_file}
    sed -i '/REGISTER_DATASET_CLASS(MultiSlotDataset)/a\REGISTER_DATASET_CLASS(FeedMultiSlotDataset);' ${dataset_file}
    sed -i '/data_set.h/a\#include "paddle/fluid/feed/src/data_reader/data_set.h"' ${dataset_file}
    sed -i '/feed_data_set/d' paddle/fluid/framework/CMakeLists.txt
    #sed -i '/target_link_libraries(executor/a\target_link_libraries(feed_data_set)' paddle/fluid/framework/CMakeLists.txt
    #sed -i '/target_link_libraries(executor/a\add_dependencies(feed_data_set)' paddle/fluid/framework/CMakeLists.txt
}
# Entry point: patch the pybind registration, then wire feed sources into the build.
apply_pybind
apply_feed_src
# Link dependencies for the feed pybind extension library.
set(FEED_PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapper
    pass_builder parallel_executor profiler layer tracer engine scope_pool
    dict_plugin fs shell)

if(WITH_PYTHON)
  list(APPEND FEED_PYBIND_DEPS py_func_op)
endif()

set(FEED_PYBIND_SRCS
    expand_api.cc
)

if(WITH_PYTHON)
  if(WITH_AMD_GPU)
    hip_library(feed_paddle_pybind SRCS ${FEED_PYBIND_SRCS} DEPS ARCHIVE_START ${FEED_PYBIND_DEPS} ARCHIVE_END)
  else()
    cc_library(feed_paddle_pybind SRCS ${FEED_PYBIND_SRCS} DEPS ${FEED_PYBIND_DEPS})
  endif()
  # Pull in any OS-specific dependency modules collected by the build.
  get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
  target_link_libraries(feed_paddle_pybind ${os_dependency_modules})
endif()
#include "paddle/fluid/feed/pybind/expand_api.h"

#include <pybind11/numpy.h>
#include <pybind11/stl.h>

#include <cstring>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/feed/src/common/dict_plugin.h"

namespace py = pybind11;

namespace paddle {
namespace pybind {

using paddle::framework::DictPluginManager;
using paddle::framework::FeasignCacheDict;

// Expose the dict-plugin classes to Python: the concrete feasign cache dict
// and the process-wide plugin manager.
void BindExpandDictPlugin(py::module *m) {
  py::class_<FeasignCacheDict>(*m, "FeasignCacheDict")
      .def(py::init<>())
      .def(py::init<const FeasignCacheDict &>())
      .def("load", &FeasignCacheDict::Load);
  py::class_<DictPluginManager>(*m, "DictPluginManager")
      .def(py::init<>())
      .def_static("instance", &DictPluginManager::Instance)
      .def("load_dict", &DictPluginManager::LoadDict)
      .def("create_dict", &DictPluginManager::CreateDict);
}

// Entry point invoked from Paddle's pybind registration (see
// apply_feed_code.sh, which inserts the BindExpandApi(&m) call).
void BindExpandApi(py::module *m) {
  BindExpandDictPlugin(m);
}

}  // namespace pybind
}  // namespace paddle
#pragma once
#include <pybind11/pybind11.h>
namespace paddle {
namespace pybind {
// Registers the feed "expand" Python bindings (dict plugin classes) on the
// given pybind11 module. Implementation lives in expand_api.cc.
void BindExpandApi(pybind11::module *m);
}  // namespace pybind
}  // namespace paddle
此差异已折叠。
# Feed source tree: shared helpers (common/) and dataset readers (data_reader/).
add_subdirectory(common)
add_subdirectory(data_reader)
# dict_plugin: versioned in-memory dictionaries loaded via the fs helpers.
cc_library(dict_plugin SRCS dict_plugin.cc DEPS glog boost fs)
此差异已折叠。
此差异已折叠。
#include <iostream>
#include "paddle/fluid/feed/src/common/dict_plugin.h"
#include "paddle/fluid/framework/io/fs.h"
namespace paddle {
namespace framework {
// Loads feasign keys from every file under `path` into the inactive buffer
// of the double-buffered dict, then flips `version_` to publish it.
// Each non-empty line is parsed with strtoul: the leading unsigned integer
// becomes the key and the value is the key's insertion index (entity.Size()).
// `converter` is forwarded untouched to fs_open_read -- presumably a
// decompression/conversion command; confirm against the fs implementation.
// Always returns 0; a file-open error aborts via CHECK.
// NOTE(review): the target buffer is not cleared before loading, so entries
// from two loads ago may persist in the reused slot -- confirm intended.
int FeasignCacheDict::Load(
    const std::string& path, const std::string& converter) {
  // Select the buffer slot NOT currently published (double-buffer flip).
  auto version = version_ + 1;
  if (version >= versioned_entity_.size()) {
    version = 0;
  }
  auto& entity = versioned_entity_[version];
  uint64_t data_count = 0;
  auto file_list = fs_list(path);
  for (auto& file_path : file_list) {
    int err_no = 0;
    int line_len = 0;
    size_t buffer_size = 0;
    char *buffer = nullptr;  // getline-managed buffer; freed after each file
    char* data_ptr = NULL;
    auto file = fs_open_read(file_path, &err_no, converter);
    CHECK(err_no == 0);
    // getline returns -1 at EOF and reallocates `buffer` as needed.
    while ((line_len = getline(&buffer, &buffer_size, file.get())) > 0) {
      if (line_len <= 1) {
        // Skip empty lines (a bare '\n' has length 1).
        continue;
      }
      ++data_count;
      entity.Append(strtoul(buffer, &data_ptr, 10), entity.Size());
    }
    if (buffer != nullptr) {
      free(buffer);
    }
  }
  // Publish the freshly loaded buffer.
  version_ = version;
  std::cerr << "Load success data_count" << data_count << " to version:" << version_ << std::endl;
  return 0;
}
}  // namespace framework
}  // namespace paddle
#pragma once
#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <glog/logging.h>
#include "paddle/fluid/feed/src/common/bhopscotch_map.h"
namespace paddle {
namespace framework {
// Abstract interface for a loadable dictionary plugin managed by
// DictPluginManager below.
class DictPlugin {
public:
  DictPlugin() {}
  virtual ~DictPlugin() {}
  // Loads dictionary content from `path`; `converter` is forwarded to the
  // underlying reader. Implementations return 0 on success (see
  // FeasignCacheDict::Load in dict_plugin.cc).
  virtual int Load(const std::string& path, const std::string& converter) = 0;
};
// Insertion-ordered key/value buffer: keeps a hash map for lookup and the
// list of distinct keys in insertion order. Duplicate keys are ignored
// (first value wins).
template <class K, class V>
class KvEntity {
 public:
  KvEntity() {}
  ~KvEntity() {}

  // Number of distinct keys stored so far (const; cast made explicit --
  // the original silently narrowed size_t to uint32_t).
  uint32_t Size() const {
    return static_cast<uint32_t>(_key_list.size());
  }

  // Inserts (k, v) unless k is already present. Single-lookup emplace
  // replaces the original find-then-emplace double lookup; behavior is
  // identical (first value wins).
  void Append(const K& k, const V& v) {
    if (_dict_data.emplace(k, v).second) {
      _key_list.push_back(k);
    }
  }

  std::vector<K> _key_list;             // distinct keys, insertion order
  tsl::bhopscotch_pg_map<K, V> _dict_data;  // key -> value lookup
};
// Double-buffered key/value dict plugin. Two KvEntity buffers are kept so a
// reload can fill the inactive buffer while the published one stays intact.
// NOTE(review): no synchronization is visible here; confirm the intended
// concurrency guarantees between Load and the Get* methods with callers.
template <class K, class V>
class KvDictPlugin : public DictPlugin {
public:
  KvDictPlugin() {
    // Two slots: versions 0 and 1.
    versioned_entity_.resize(2);
  }
  virtual ~KvDictPlugin() {}

  // GetValue with version, Return: value
  // Looks up `key` in the buffer for `version`; returns 0 and sets `v` on
  // hit, -1 on miss. CHECK aborts on an out-of-range version.
  virtual int GetValueWithVersion(uint32_t version, const K& key, V& v) {
    CHECK(version < versioned_entity_.size());
    auto& entity = versioned_entity_[version];
    auto itr = entity._dict_data.find(key);
    if (itr == entity._dict_data.end()) {
      return -1;  // miss
    }
    v = itr->second;
    return 0;
  }

  // GetValue without version, Return: value version
  // Looks up `key` in the currently published buffer and reports which
  // version was used via `version`. Returns 0 on hit, -1 on miss.
  virtual int GetValue(const K& key, V& v, uint32_t& version) {
    version = version_;
    auto& entity = versioned_entity_[version];
    auto itr = entity._dict_data.find(key);
    if (itr == entity._dict_data.end()) {
      return -1;  // miss
    }
    v = itr->second;
    return 0;
  }

  // Index of the currently published buffer.
  virtual int GetVersion() {
    return version_;
  }

protected:
  uint32_t version_ = 0;  // published buffer index; flipped by Load
  // double-buffer support version:0 1
  std::vector<KvEntity<K, V>> versioned_entity_;
};
// Concrete dict mapping a 64-bit feature sign to its 32-bit insertion index.
// Load() is implemented in dict_plugin.cc and reads files via the fs helpers.
class FeasignCacheDict : public KvDictPlugin<uint64_t, uint32_t> {
public:
  FeasignCacheDict(){}
  virtual ~FeasignCacheDict(){}
  virtual int Load(const std::string& path, const std::string& converter);
};
// Process-wide registry of named DictPlugin instances, exposed to Python
// via expand_api.cc.
class DictPluginManager {
public:
  DictPluginManager() {}
  virtual ~DictPluginManager(){}

  // Singleton accessor.
  static DictPluginManager& Instance() {
    static DictPluginManager manager;
    return manager;
  }

  // Creates (or replaces) the dict registered under `dict_name` when the
  // name matches a known plugin class. Returns 0 on success, -1 otherwise.
  inline int CreateDict(const std::string& dict_name) {
#define PADDLE_DICT_PLUGIN_REGIST(dict) \
  if (dict_name == #dict) { \
    dicts_map_[dict_name].reset(new dict()); \
    return 0; \
  }
    PADDLE_DICT_PLUGIN_REGIST(FeasignCacheDict)
#undef PADDLE_DICT_PLUGIN_REGIST
    return -1;
  }

  // Returns the plugin registered under `dict_name`, or nullptr.
  // Single find() replaces the original count() + operator[] double lookup.
  inline DictPlugin* GetDict(const std::string& dict_name) {
    auto itr = dicts_map_.find(dict_name);
    return itr == dicts_map_.end() ? nullptr : itr->second.get();
  }

  // Loads data into an existing dict. Returns -1 for an unknown name,
  // otherwise the plugin's Load() result. `converter` is now taken by
  // const reference, consistent with the other string parameters (the
  // original passed it by value).
  inline int LoadDict(const std::string& dict_name,
                      const std::string& path, const std::string& converter) {
    auto dict = GetDict(dict_name);
    if (!dict) {
      return -1;
    }
    return dict->Load(path, converter);
  }

private:
  std::unordered_map<std::string, std::shared_ptr<DictPlugin>> dicts_map_;
};
} // namespace framework
} // namespace paddle
/**
* MIT License
*
* Copyright (c) 2018 Tessil
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
#ifndef TSL_HOPSCOTCH_GROWTH_POLICY_H
#define TSL_HOPSCOTCH_GROWTH_POLICY_H
#include <algorithm>
#include <array>
#include <climits>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <limits>
#include <ratio>
#include <stdexcept>
namespace tsl {
namespace hh {
/**
* Grow the hash table by a factor of GrowthFactor keeping the bucket count to a power of two. It allows
* the table to use a mask operation instead of a modulo operation to map a hash to a bucket.
*
* GrowthFactor must be a power of two >= 2.
*/
/**
 * Growth policy keeping the bucket count at a power of two, so a hash can be
 * mapped to a bucket with a cheap mask instead of a modulo.
 *
 * GrowthFactor must be a power of two >= 2.
 */
template<std::size_t GrowthFactor>
class power_of_two_growth_policy {
public:
    /**
     * Called on table creation and on rehash with the minimum bucket count
     * requested; the value is rounded up in place to the next power of two
     * (never lowered). A request of 0 stays 0, and bucket_for_hash then
     * always yields 0.
     */
    explicit power_of_two_growth_policy(std::size_t& min_bucket_count_in_out) {
        if(min_bucket_count_in_out > max_bucket_count()) {
            throw std::length_error("The hash table exceeds its maxmimum size.");
        }

        if(min_bucket_count_in_out == 0) {
            m_mask = 0;
        }
        else {
            min_bucket_count_in_out = round_up_to_power_of_two(min_bucket_count_in_out);
            m_mask = min_bucket_count_in_out - 1;
        }
    }

    /** Bucket in [0, bucket_count()) for the hash; 0 when bucket_count() is 0. */
    std::size_t bucket_for_hash(std::size_t hash) const noexcept {
        return hash & m_mask;
    }

    /** Bucket count for the next rehash: current count times GrowthFactor. */
    std::size_t next_bucket_count() const {
        const std::size_t current_count = m_mask + 1;
        if(current_count > max_bucket_count() / GrowthFactor) {
            throw std::length_error("The hash table exceeds its maxmimum size.");
        }

        return current_count * GrowthFactor;
    }

    /** Largest bucket count the policy supports: the largest power of two. */
    std::size_t max_bucket_count() const {
        return (std::numeric_limits<std::size_t>::max() / 2) + 1;
    }

    /** Reset to the state of a policy created with a bucket count of 0. */
    void clear() noexcept {
        m_mask = 0;
    }

private:
    static std::size_t round_up_to_power_of_two(std::size_t value) {
        if(value == 0) {
            return 1;
        }

        if(is_power_of_two(value)) {
            return value;
        }

        // Smear the highest set bit into every lower position, then add one.
        --value;
        for(std::size_t shift = 1; shift < sizeof(std::size_t) * CHAR_BIT; shift *= 2) {
            value |= value >> shift;
        }

        return value + 1;
    }

    static constexpr bool is_power_of_two(std::size_t value) {
        return value != 0 && (value & (value - 1)) == 0;
    }

private:
    static_assert(is_power_of_two(GrowthFactor) && GrowthFactor >= 2, "GrowthFactor must be a power of two >= 2.");

    std::size_t m_mask;
};
/**
* Grow the hash table by GrowthFactor::num / GrowthFactor::den and use a modulo to map a hash
* to a bucket. Slower but it can be useful if you want a slower growth.
*/
/**
 * Growth policy multiplying the bucket count by GrowthFactor::num /
 * GrowthFactor::den and mapping hashes with a modulo. Slower than the
 * power-of-two policy but allows a gentler growth curve.
 */
template<class GrowthFactor = std::ratio<3, 2>>
class mod_growth_policy {
public:
    /** Remember the requested bucket count as the modulus (1 when 0 is requested). */
    explicit mod_growth_policy(std::size_t& min_bucket_count_in_out) {
        if(min_bucket_count_in_out > max_bucket_count()) {
            throw std::length_error("The hash table exceeds its maxmimum size.");
        }

        m_mod = (min_bucket_count_in_out > 0) ? min_bucket_count_in_out : 1;
    }

    /** Bucket in [0, bucket_count()) for the hash. */
    std::size_t bucket_for_hash(std::size_t hash) const noexcept {
        return hash % m_mod;
    }

    /** Next bucket count: current count scaled by the growth factor, capped at the maximum. */
    std::size_t next_bucket_count() const {
        if(m_mod == max_bucket_count()) {
            throw std::length_error("The hash table exceeds its maxmimum size.");
        }

        const double grown_count = std::ceil(double(m_mod) * REHASH_SIZE_MULTIPLICATION_FACTOR);
        if(!std::isnormal(grown_count)) {
            throw std::length_error("The hash table exceeds its maxmimum size.");
        }

        return (grown_count > double(max_bucket_count())) ? max_bucket_count()
                                                          : std::size_t(grown_count);
    }

    /** Largest bucket count the policy supports. */
    std::size_t max_bucket_count() const {
        return MAX_BUCKET_COUNT;
    }

    /** Reset to the state of a policy created with a bucket count of 0. */
    void clear() noexcept {
        m_mod = 1;
    }

private:
    static constexpr double REHASH_SIZE_MULTIPLICATION_FACTOR = 1.0 * GrowthFactor::num / GrowthFactor::den;
    static const std::size_t MAX_BUCKET_COUNT =
        std::size_t(double(
            std::numeric_limits<std::size_t>::max() / REHASH_SIZE_MULTIPLICATION_FACTOR
        ));

    static_assert(REHASH_SIZE_MULTIPLICATION_FACTOR >= 1.1, "Growth factor should be >= 1.1.");

    std::size_t m_mod;
};
namespace detail {
// Precomputed bucket-count primes, roughly geometric in progression, up to
// the largest prime representable in 64 bits.
static constexpr const std::array<std::size_t, 186> PRIMES = {{
    1ull, 3ull, 5ull, 7ull, 11ull, 13ull, 17ull, 23ull, 29ull, 37ull, 47ull,
    59ull, 73ull, 97ull, 127ull, 151ull, 197ull, 251ull, 313ull, 397ull,
    499ull, 631ull, 797ull, 1009ull, 1259ull, 1597ull, 2011ull, 2539ull,
    3203ull, 4027ull, 5087ull, 6421ull, 8089ull, 10193ull, 12853ull, 16193ull,
    20399ull, 25717ull, 32401ull, 40823ull, 51437ull, 64811ull, 81649ull,
    102877ull, 129607ull, 163307ull, 205759ull, 259229ull, 326617ull,
    411527ull, 518509ull, 653267ull, 823117ull, 1037059ull, 1306601ull,
    1646237ull, 2074129ull, 2613229ull, 3292489ull, 4148279ull, 5226491ull,
    6584983ull, 8296553ull, 10453007ull, 13169977ull, 16593127ull, 20906033ull,
    26339969ull, 33186281ull, 41812097ull, 52679969ull, 66372617ull,
    83624237ull, 105359939ull, 132745199ull, 167248483ull, 210719881ull,
    265490441ull, 334496971ull, 421439783ull, 530980861ull, 668993977ull,
    842879579ull, 1061961721ull, 1337987929ull, 1685759167ull, 2123923447ull,
    2675975881ull, 3371518343ull, 4247846927ull, 5351951779ull, 6743036717ull,
    8495693897ull, 10703903591ull, 13486073473ull, 16991387857ull,
    21407807219ull, 26972146961ull, 33982775741ull, 42815614441ull,
    53944293929ull, 67965551447ull, 85631228929ull, 107888587883ull,
    135931102921ull, 171262457903ull, 215777175787ull, 271862205833ull,
    342524915839ull, 431554351609ull, 543724411781ull, 685049831731ull,
    863108703229ull, 1087448823553ull, 1370099663459ull, 1726217406467ull,
    2174897647073ull, 2740199326961ull, 3452434812973ull, 4349795294267ull,
    5480398654009ull, 6904869625999ull, 8699590588571ull, 10960797308051ull,
    13809739252051ull, 17399181177241ull, 21921594616111ull, 27619478504183ull,
    34798362354533ull, 43843189232363ull, 55238957008387ull, 69596724709081ull,
    87686378464759ull, 110477914016779ull, 139193449418173ull,
    175372756929481ull, 220955828033581ull, 278386898836457ull,
    350745513859007ull, 441911656067171ull, 556773797672909ull,
    701491027718027ull, 883823312134381ull, 1113547595345903ull,
    1402982055436147ull, 1767646624268779ull, 2227095190691797ull,
    2805964110872297ull, 3535293248537579ull, 4454190381383713ull,
    5611928221744609ull, 7070586497075177ull, 8908380762767489ull,
    11223856443489329ull, 14141172994150357ull, 17816761525534927ull,
    22447712886978529ull, 28282345988300791ull, 35633523051069991ull,
    44895425773957261ull, 56564691976601587ull, 71267046102139967ull,
    89790851547914507ull, 113129383953203213ull, 142534092204280003ull,
    179581703095829107ull, 226258767906406483ull, 285068184408560057ull,
    359163406191658253ull, 452517535812813007ull, 570136368817120201ull,
    718326812383316683ull, 905035071625626043ull, 1140272737634240411ull,
    1436653624766633509ull, 1810070143251252131ull, 2280545475268481167ull,
    2873307249533267101ull, 3620140286502504283ull, 4561090950536962147ull,
    5746614499066534157ull, 7240280573005008577ull, 9122181901073924329ull,
    11493228998133068689ull, 14480561146010017169ull, 18446744073709551557ull
}};

// Modulo by the IPrime-th prime; the compile-time constant lets the compiler
// strength-reduce the division (see MOD_PRIME below).
template<unsigned int IPrime>
static constexpr std::size_t mod(std::size_t hash) { return hash % PRIMES[IPrime]; }

// MOD_PRIME[iprime](hash) returns hash % PRIMES[iprime]. This table allows for faster modulo as the
// compiler can optimize the modulo code better with a constant known at the compilation.
static constexpr const std::array<std::size_t(*)(std::size_t), 186> MOD_PRIME = {{
    &mod<0>, &mod<1>, &mod<2>, &mod<3>, &mod<4>, &mod<5>, &mod<6>, &mod<7>, &mod<8>, &mod<9>, &mod<10>,
    &mod<11>, &mod<12>, &mod<13>, &mod<14>, &mod<15>, &mod<16>, &mod<17>, &mod<18>, &mod<19>, &mod<20>,
    &mod<21>, &mod<22>, &mod<23>, &mod<24>, &mod<25>, &mod<26>, &mod<27>, &mod<28>, &mod<29>, &mod<30>,
    &mod<31>, &mod<32>, &mod<33>, &mod<34>, &mod<35>, &mod<36>, &mod<37>, &mod<38>, &mod<39>, &mod<40>,
    &mod<41>, &mod<42>, &mod<43>, &mod<44>, &mod<45>, &mod<46>, &mod<47>, &mod<48>, &mod<49>, &mod<50>,
    &mod<51>, &mod<52>, &mod<53>, &mod<54>, &mod<55>, &mod<56>, &mod<57>, &mod<58>, &mod<59>, &mod<60>,
    &mod<61>, &mod<62>, &mod<63>, &mod<64>, &mod<65>, &mod<66>, &mod<67>, &mod<68>, &mod<69>, &mod<70>,
    &mod<71>, &mod<72>, &mod<73>, &mod<74>, &mod<75>, &mod<76>, &mod<77>, &mod<78>, &mod<79>, &mod<80>,
    &mod<81>, &mod<82>, &mod<83>, &mod<84>, &mod<85>, &mod<86>, &mod<87>, &mod<88>, &mod<89>, &mod<90>,
    &mod<91>, &mod<92>, &mod<93>, &mod<94>, &mod<95>, &mod<96>, &mod<97>, &mod<98>, &mod<99>, &mod<100>,
    &mod<101>, &mod<102>, &mod<103>, &mod<104>, &mod<105>, &mod<106>, &mod<107>, &mod<108>, &mod<109>, &mod<110>,
    &mod<111>, &mod<112>, &mod<113>, &mod<114>, &mod<115>, &mod<116>, &mod<117>, &mod<118>, &mod<119>, &mod<120>,
    &mod<121>, &mod<122>, &mod<123>, &mod<124>, &mod<125>, &mod<126>, &mod<127>, &mod<128>, &mod<129>, &mod<130>,
    &mod<131>, &mod<132>, &mod<133>, &mod<134>, &mod<135>, &mod<136>, &mod<137>, &mod<138>, &mod<139>, &mod<140>,
    &mod<141>, &mod<142>, &mod<143>, &mod<144>, &mod<145>, &mod<146>, &mod<147>, &mod<148>, &mod<149>, &mod<150>,
    &mod<151>, &mod<152>, &mod<153>, &mod<154>, &mod<155>, &mod<156>, &mod<157>, &mod<158>, &mod<159>, &mod<160>,
    &mod<161>, &mod<162>, &mod<163>, &mod<164>, &mod<165>, &mod<166>, &mod<167>, &mod<168>, &mod<169>, &mod<170>,
    &mod<171>, &mod<172>, &mod<173>, &mod<174>, &mod<175>, &mod<176>, &mod<177>, &mod<178>, &mod<179>, &mod<180>,
    &mod<181>, &mod<182>, &mod<183>, &mod<184>, &mod<185>
}};
}  // namespace detail
/**
* Grow the hash table by using prime numbers as bucket count. Slower than tsl::hh::power_of_two_growth_policy in
* general but will probably distribute the values around better in the buckets with a poor hash function.
*
* To allow the compiler to optimize the modulo operation, a lookup table is used with constant primes numbers.
*
* With a switch the code would look like:
* \code
* switch(iprime) { // iprime is the current prime of the hash table
* case 0: hash % 5ul;
* break;
* case 1: hash % 17ul;
* break;
* case 2: hash % 29ul;
* break;
* ...
* }
* \endcode
*
* Due to the constant variable in the modulo the compiler is able to optimize the operation
* by a series of multiplications, substractions and shifts.
*
* The 'hash % 5' could become something like 'hash - (hash * 0xCCCCCCCD) >> 34) * 5' in a 64 bits environement.
*/
// Prime-number growth policy: picks bucket counts from detail::PRIMES and
// maps hashes through detail::MOD_PRIME so the modulo runs against a
// compile-time constant (see the explanatory comment above).
class prime_growth_policy {
public:
    // Snap the requested bucket count up to the next prime in the table
    // (0 stays 0); remember its index for bucket_for_hash.
    explicit prime_growth_policy(std::size_t& min_bucket_count_in_out) {
        auto it_prime = std::lower_bound(detail::PRIMES.begin(),
                                         detail::PRIMES.end(), min_bucket_count_in_out);
        if(it_prime == detail::PRIMES.end()) {
            throw std::length_error("The hash table exceeds its maxmimum size.");
        }
        // Index of the smallest prime >= the requested bucket count.
        m_iprime = static_cast<unsigned int>(std::distance(detail::PRIMES.begin(), it_prime));
        if(min_bucket_count_in_out > 0) {
            min_bucket_count_in_out = *it_prime;
        }
        else {
            min_bucket_count_in_out = 0;
        }
    }
    // Bucket in [0, bucket_count()) via the constant-modulo dispatch table.
    std::size_t bucket_for_hash(std::size_t hash) const noexcept {
        return detail::MOD_PRIME[m_iprime](hash);
    }
    // Next (larger) prime bucket count for a rehash.
    std::size_t next_bucket_count() const {
        if(m_iprime + 1 >= detail::PRIMES.size()) {
            throw std::length_error("The hash table exceeds its maxmimum size.");
        }
        return detail::PRIMES[m_iprime + 1];
    }
    // Largest prime in the table.
    std::size_t max_bucket_count() const {
        return detail::PRIMES.back();
    }
    // Reset to the state of a policy created with a bucket count of 0.
    void clear() noexcept {
        m_iprime = 0;
    }
private:
    unsigned int m_iprime;
    static_assert(std::numeric_limits<decltype(m_iprime)>::max() >= detail::PRIMES.size(),
                  "The type of m_iprime is not big enough.");
};
}
}
#endif
此差异已折叠。
此差异已折叠。
此差异已折叠。
# feed_data_set: FeedMultiSlotDataset implementation (see data_set.h).
cc_library(feed_data_set SRCS data_set.cc DEPS operator)
此差异已折叠。
#pragma once
#include "paddle/fluid/framework/data_set.h"
namespace paddle {
namespace framework {
// Feed-specific dataset overriding MultiSlotDataset's instance merge and
// preload-reader creation. Registered into the dataset factory by
// apply_feed_code.sh (REGISTER_DATASET_CLASS(FeedMultiSlotDataset)).
class FeedMultiSlotDataset : public MultiSlotDataset {
public:
  FeedMultiSlotDataset() {}
  virtual void MergeByInsId();
  virtual void CreatePreLoadReaders();
  virtual ~FeedMultiSlotDataset() {}
};
}  // end namespace framework
}  // end namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册