提交 8829ceb1 编写于 作者: W wangyihong01

feasign parser

Change-Id: I9febb7270ff6248b6d4d7194f9d87501a156e3b3
上级 8e040ca3
...@@ -5,7 +5,6 @@ GLOBAL_CFLAGS_STR = '-g -O0 -pipe -fopenmp ' ...@@ -5,7 +5,6 @@ GLOBAL_CFLAGS_STR = '-g -O0 -pipe -fopenmp '
CFLAGS(GLOBAL_CFLAGS_STR) CFLAGS(GLOBAL_CFLAGS_STR)
GLOBAL_CXXFLAGS_STR = GLOBAL_CFLAGS_STR + ' -std=c++11 ' GLOBAL_CXXFLAGS_STR = GLOBAL_CFLAGS_STR + ' -std=c++11 '
CXXFLAGS(GLOBAL_CXXFLAGS_STR) CXXFLAGS(GLOBAL_CXXFLAGS_STR)
INCPATHS('./') INCPATHS('./')
INCPATHS('$OUT/../') INCPATHS('$OUT/../')
INCPATHS('../../third-party') INCPATHS('../../third-party')
...@@ -37,7 +36,6 @@ CONFIGS('baidu/third-party/python@gcc482output@git_branch') ...@@ -37,7 +36,6 @@ CONFIGS('baidu/third-party/python@gcc482output@git_branch')
CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag') CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag')
CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch') CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch')
CONFIGS('third-64/gtest@base') CONFIGS('third-64/gtest@base')
HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/') HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/')
HEADERS('paddle/fluid/memory/detail/*.h', '$INC/paddle/fluid/memory/detail/') HEADERS('paddle/fluid/memory/detail/*.h', '$INC/paddle/fluid/memory/detail/')
HEADERS('paddle/fluid/memory/allocation/*.h', '$INC/paddle/fluid/memory/allocation/') HEADERS('paddle/fluid/memory/allocation/*.h', '$INC/paddle/fluid/memory/allocation/')
...@@ -58,12 +56,9 @@ HEADERS('paddle/fluid/pybind/pybind.h', '$INC/paddle/fluid/pybind') ...@@ -58,12 +56,9 @@ HEADERS('paddle/fluid/pybind/pybind.h', '$INC/paddle/fluid/pybind')
HEADERS('paddle/fluid/inference/api/*.h', '$INC/paddle/fluid/inference/api/') HEADERS('paddle/fluid/inference/api/*.h', '$INC/paddle/fluid/inference/api/')
HEADERS(GLOB_GEN_SRCS('paddle/fluid/framework/*pb.h'), '$INC/paddle/fluid/framework') HEADERS(GLOB_GEN_SRCS('paddle/fluid/framework/*pb.h'), '$INC/paddle/fluid/framework')
HEADERS(GLOB_GEN_SRCS('paddle/fluid/platform/*pb.h'), '$INC/paddle/fluid/platform') HEADERS(GLOB_GEN_SRCS('paddle/fluid/platform/*pb.h'), '$INC/paddle/fluid/platform')
PROTOC('../../third-party/protobuf/bin/protoc') PROTOC('../../third-party/protobuf/bin/protoc')
#proto #proto
StaticLibrary("fake_paddle_proto", Sources(GLOB("paddle/fluid/framework/*.proto"), GLOB("paddle/fluid/platform/*.proto"))) StaticLibrary("fake_paddle_proto", Sources(GLOB("paddle/fluid/framework/*.proto"), GLOB("paddle/fluid/platform/*.proto")))
#feed #feed
HEADERS('paddle/fluid/train/custom_trainer/feed/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/') HEADERS('paddle/fluid/train/custom_trainer/feed/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/')
HEADERS('paddle/fluid/train/custom_trainer/feed/common/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/common/') HEADERS('paddle/fluid/train/custom_trainer/feed/common/*.h', '$INC/paddle/fluid/train/custom_trainer/feed/common/')
...@@ -83,9 +78,7 @@ def UT_FILE(filename): ...@@ -83,9 +78,7 @@ def UT_FILE(filename):
UT_DIR = 'paddle/fluid/train/custom_trainer/feed/unit_test' UT_DIR = 'paddle/fluid/train/custom_trainer/feed/unit_test'
import os import os
return os.path.join(UT_DIR, filename) return os.path.join(UT_DIR, filename)
custom_trainer_src = GLOB('paddle/fluid/train/custom_trainer/feed/*/*.cc', Exclude(UT_FILE('*'))) custom_trainer_src = GLOB('paddle/fluid/train/custom_trainer/feed/*/*.cc', Exclude(UT_FILE('*')))
CPPFLAGS_STR = '-DHPPL_STUB_FUNC -DLAPACK_FOUND -DPADDLE_DISABLE_PROFILER -DPADDLE_NO_PYTHON -DCUSTOM_TRAINER -DPADDLE_ON_INFERENCE -DPADDLE_USE_DSO -DPADDLE_USE_PTHREAD_BARRIER -DPADDLE_USE_PTHREAD_SPINLOCK -DPADDLE_VERSION=0.0.0 -DPADDLE_WITH_AVX -DPADDLE_WITH_MKLML -DPADDLE_WITH_XBYAK -DXBYAK64 -DXBYAK_NO_OP_NAMES -D_GNU_SOURCE -D__STDC_LIMIT_MACROS -DPYBIND_AVX_MKLML' + r" -DPADDLE_REVISION=\"%s@%s@%s\"" % (REPO_URL(), REPO_BRANCH(), REPO_REVISION()) CPPFLAGS_STR = '-DHPPL_STUB_FUNC -DLAPACK_FOUND -DPADDLE_DISABLE_PROFILER -DPADDLE_NO_PYTHON -DCUSTOM_TRAINER -DPADDLE_ON_INFERENCE -DPADDLE_USE_DSO -DPADDLE_USE_PTHREAD_BARRIER -DPADDLE_USE_PTHREAD_SPINLOCK -DPADDLE_VERSION=0.0.0 -DPADDLE_WITH_AVX -DPADDLE_WITH_MKLML -DPADDLE_WITH_XBYAK -DXBYAK64 -DXBYAK_NO_OP_NAMES -D_GNU_SOURCE -D__STDC_LIMIT_MACROS -DPYBIND_AVX_MKLML' + r" -DPADDLE_REVISION=\"%s@%s@%s\"" % (REPO_URL(), REPO_BRANCH(), REPO_REVISION())
CFLAGS_STR = '-m64 -fPIC -fno-omit-frame-pointer -Werror -Wall -Wextra -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wno-unused-parameter -Wno-unused-function -Wno-error=literal-suffix -Wno-error=sign-compare -Wno-error=unused-local-typedefs -Wno-error=maybe-uninitialized -fopenmp -mavx -O0 -DNDEBUG ' CFLAGS_STR = '-m64 -fPIC -fno-omit-frame-pointer -Werror -Wall -Wextra -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wno-unused-parameter -Wno-unused-function -Wno-error=literal-suffix -Wno-error=sign-compare -Wno-error=unused-local-typedefs -Wno-error=maybe-uninitialized -fopenmp -mavx -O0 -DNDEBUG '
......
...@@ -518,11 +518,11 @@ Archive<AR>& operator>>(Archive<AR>& ar, std::tuple<T...>& x) { ...@@ -518,11 +518,11 @@ Archive<AR>& operator>>(Archive<AR>& ar, std::tuple<T...>& x) {
} \ } \
template <class AR, class KEY, class VALUE, class... ARGS> \ template <class AR, class KEY, class VALUE, class... ARGS> \
Archive<AR>& operator>>(Archive<AR>& ar, MAP_TYPE<KEY, VALUE, ARGS...>& p) { \ Archive<AR>& operator>>(Archive<AR>& ar, MAP_TYPE<KEY, VALUE, ARGS...>& p) { \
size_t size = ar.template get<size_t>(); \ size_t size = ar.template Get<size_t>(); \
p.clear(); \ p.clear(); \
RESERVE_STATEMENT; \ RESERVE_STATEMENT; \
for (size_t i = 0; i < size; i++) { \ for (size_t i = 0; i < size; i++) { \
p.insert(ar.template get<std::pair<KEY, VALUE>>()); \ p.insert(ar.template Get<std::pair<KEY, VALUE>>()); \
} \ } \
return ar; \ return ar; \
} }
...@@ -539,11 +539,11 @@ Archive<AR>& operator>>(Archive<AR>& ar, std::tuple<T...>& x) { ...@@ -539,11 +539,11 @@ Archive<AR>& operator>>(Archive<AR>& ar, std::tuple<T...>& x) {
} \ } \
template <class AR, class KEY, class VALUE, class... ARGS> \ template <class AR, class KEY, class VALUE, class... ARGS> \
Archive<AR>& operator>>(Archive<AR>& ar, MAP_TYPE<KEY, VALUE, ARGS...>& p) { \ Archive<AR>& operator>>(Archive<AR>& ar, MAP_TYPE<KEY, VALUE, ARGS...>& p) { \
size_t size = ar.template get<uint64_t>(); \ size_t size = ar.template Get<uint64_t>(); \
p.clear(); \ p.clear(); \
RESERVE_STATEMENT; \ RESERVE_STATEMENT; \
for (size_t i = 0; i < size; i++) { \ for (size_t i = 0; i < size; i++) { \
p.insert(ar.template get<std::pair<KEY, VALUE>>()); \ p.insert(ar.template Get<std::pair<KEY, VALUE>>()); \
} \ } \
return ar; \ return ar; \
} }
...@@ -568,11 +568,11 @@ ARCHIVE_REPEAT(std::unordered_multimap, p.reserve(size)) ...@@ -568,11 +568,11 @@ ARCHIVE_REPEAT(std::unordered_multimap, p.reserve(size))
} \ } \
template <class AR, class KEY, class... ARGS> \ template <class AR, class KEY, class... ARGS> \
Archive<AR>& operator>>(Archive<AR>& ar, SET_TYPE<KEY, ARGS...>& p) { \ Archive<AR>& operator>>(Archive<AR>& ar, SET_TYPE<KEY, ARGS...>& p) { \
size_t size = ar.template get<size_t>(); \ size_t size = ar.template Get<size_t>(); \
p.clear(); \ p.clear(); \
RESERVE_STATEMENT; \ RESERVE_STATEMENT; \
for (size_t i = 0; i < size; i++) { \ for (size_t i = 0; i < size; i++) { \
p.insert(ar.template get<KEY>()); \ p.insert(ar.template Get<KEY>()); \
} \ } \
return ar; \ return ar; \
} }
...@@ -588,11 +588,11 @@ ARCHIVE_REPEAT(std::unordered_multimap, p.reserve(size)) ...@@ -588,11 +588,11 @@ ARCHIVE_REPEAT(std::unordered_multimap, p.reserve(size))
} \ } \
template <class AR, class KEY, class... ARGS> \ template <class AR, class KEY, class... ARGS> \
Archive<AR>& operator>>(Archive<AR>& ar, SET_TYPE<KEY, ARGS...>& p) { \ Archive<AR>& operator>>(Archive<AR>& ar, SET_TYPE<KEY, ARGS...>& p) { \
size_t size = ar.template get<uint64_t>(); \ size_t size = ar.template Get<uint64_t>(); \
p.clear(); \ p.clear(); \
RESERVE_STATEMENT; \ RESERVE_STATEMENT; \
for (size_t i = 0; i < size; i++) { \ for (size_t i = 0; i < size; i++) { \
p.insert(ar.template get<KEY>()); \ p.insert(ar.template Get<KEY>()); \
} \ } \
return ar; \ return ar; \
} }
......
...@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#if defined _WIN32 || defined __APPLE__ //#if defined _WIN32 || defined __APPLE__
#else //#else
#define _LINUX #define _LINUX
#endif //#endif
#include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/data_feed.h"
#ifdef _LINUX #ifdef _LINUX
......
文件模式从 100644 更改为 100755
...@@ -24,26 +24,6 @@ ...@@ -24,26 +24,6 @@
namespace paddle { namespace paddle {
namespace string { namespace string {
// Returns the number of leading whitespace characters of the C string `s`.
// The cast to unsigned char is required: passing a plain (possibly negative)
// char to isspace() is undefined behavior for non-ASCII bytes.
inline size_t count_spaces(const char* s) {
  size_t count = 0;
  while (*s != 0 && isspace((unsigned char)*s++)) {
    count++;
  }
  return count;
}
// Returns the length of the leading run of non-whitespace characters in `s`.
// The cast to unsigned char is required: passing a plain (possibly negative)
// char to isspace() is undefined behavior for non-ASCII bytes.
inline size_t count_nonspaces(const char* s) {
  size_t count = 0;
  while (*s != 0 && !isspace((unsigned char)*s++)) {
    count++;
  }
  return count;
}
// remove leading and tailing spaces // remove leading and tailing spaces
std::string trim_spaces(const std::string& str) { std::string trim_spaces(const std::string& str) {
const char* p = str.c_str(); const char* p = str.c_str();
......
...@@ -26,9 +26,25 @@ ...@@ -26,9 +26,25 @@
namespace paddle { namespace paddle {
namespace string { namespace string {
// Returns the number of leading whitespace characters of the C string `s`.
// (Reconstructed from a diff-mangled span.) The cast to unsigned char is
// required: passing a plain (possibly negative) char to isspace() is
// undefined behavior for non-ASCII bytes.
inline size_t count_spaces(const char* s) {
  size_t count = 0;
  while (*s != 0 && isspace((unsigned char)*s++)) {
    count++;
  }
  return count;
}
// Returns the length of the leading run of non-whitespace characters in `s`.
// The cast to unsigned char is required: passing a plain (possibly negative)
// char to isspace() is undefined behavior for non-ASCII bytes.
inline size_t count_nonspaces(const char* s) {
  size_t count = 0;
  while (*s != 0 && !isspace((unsigned char)*s++)) {
    count++;
  }
  return count;
}
template <class... ARGS> template <class... ARGS>
void format_string_append(std::string& str, const char* fmt, // NOLINT void format_string_append(std::string& str, const char* fmt, // NOLINT
......
...@@ -58,6 +58,378 @@ public: ...@@ -58,6 +58,378 @@ public:
}; };
REGISTER_CLASS(DataParser, LineDataParser); REGISTER_CLASS(DataParser, LineDataParser);
/********************************
 * Compressed feasign layout. Every encoded group starts with a 16-bit
 * slot id; the top nibble of the following byte is a flag selecting one
 * of four cases ("hot" signs are 28/32-bit cache indexes, "cold" signs
 * are raw 64-bit values):
 * Case 1: slot + single hot sign (flag nibble shares the sign's first byte)
 *   |4b|4b|4b|4b|4b| 28b |
 *   |slot       |0 |sign |
 * Case 2: slot + n hot signs
 *   |4b|4b|4b|4b|4b|4b|4b|4b|32b*n|
 *   |slot       |1 |0 |len  |sign |
 * Case 3: slot + single cold sign
 *   |4b|4b|4b|4b|4b|4b| 64b |
 *   |slot       |2 |0 |sign |
 * Case 4: slot + n cold signs
 *   |4b|4b|4b|4b|4b|4b|4b|4b|64b*n|
 *   |slot       |3 |0 |len  |sign |
 ********************************/
class ArchiveDataParse : public DataParser {
public:
    static const uint8_t HOT_SIGN_SIZE = 4;   // bytes per hot (compressed) sign
    static const uint8_t COLD_SIGN_SIZE = 8;  // bytes per cold (raw 64-bit) sign

public:
    ArchiveDataParse() {}
    virtual ~ArchiveDataParse() {}

    // One parsed text sample, prior to binary serialization.
    struct Record {
        int show, clk;
        std::string tags;
        std::map<std::string, std::vector<float>> vec_feas;
        int sample_type;
        std::map<std::string, std::vector<int>> auc_category_info_map;  // per-category info for fine-grained AUC
        std::vector<FeatureItem> hot_feas, cold_feas;  // cold: raw uint64 sign; hot: int32 cache index

        // Resets the record for reuse (the parser keeps a thread_local Record).
        void clear() {
            show = 0;
            clk = 0;
            tags.clear();
            vec_feas.clear();
            sample_type = 0;
            auc_category_info_map.clear();
            hot_feas.clear();
            cold_feas.clear();
        }

        // Returns the exact byte length parse_feas() will write.
        // Feasigns are grouped by runs of consecutive equal slot ids;
        // keep this in sync with the layout table above.
        uint32_t calc_compress_feas_lens() const {
            uint32_t hot_len = hot_feas.size();
            uint32_t cold_len = cold_feas.size();
            uint32_t cursor = 0;
            int32_t pre_slot = -1;
            uint32_t k = 0;
            // hot signs
            if (hot_len > 0) {
                pre_slot = hot_feas[0].slot();
                for (uint32_t i = 0; i < hot_len + 1; ++i) {
                    if (i == hot_len || pre_slot != hot_feas[i].slot()) {
                        cursor += 2;  // slot id
                        // case 2: one flag byte + one length byte
                        if (i - k > 1) {
                            cursor += 2;
                        }
                        // case 1/2 sign payload
                        cursor += (HOT_SIGN_SIZE * (i - k));
                        k = i;
                    }
                    // fix: guard the lookahead — the extra final iteration only
                    // flushes the last group and must not index past the end.
                    if (i < hot_len) {
                        pre_slot = hot_feas[i].slot();
                    }
                }
            }
            // cold signs
            if (cold_len > 0) {
                pre_slot = cold_feas[0].slot();
                k = 0;
                for (uint32_t i = 0; i < cold_len + 1; ++i) {
                    if (i == cold_len || pre_slot != cold_feas[i].slot()) {
                        cursor += 2;  // slot id
                        // case 4: flag byte + length byte
                        if (i - k > 1) {
                            cursor += 2;
                        } else {  // case 3: flag byte only
                            cursor++;
                        }
                        // case 3/4 sign payload
                        cursor += (COLD_SIGN_SIZE * (i - k));
                        k = i;
                    }
                    if (i < cold_len) {  // fix: same out-of-bounds guard as above
                        pre_slot = cold_feas[i].slot();
                    }
                }
            }
            return cursor;
        }

        // Serializes hot_feas/cold_feas into `buffer` using the layout above.
        // `buffer` must hold at least calc_compress_feas_lens() bytes.
        // Slot ids and length bytes are written in host byte order
        // (little-endian on the targeted x86 builds); hot signs are written
        // big-endian so the flag nibble occupies the first byte.
        // Assumes hot (case 1) indexes fit in 28 bits so the top nibble is 0.
        void parse_feas(char* buffer) const {
            if (buffer == nullptr) {
                return;
            }
            uint32_t cursor = 0;
            uint32_t hot_len = hot_feas.size();
            uint32_t cold_len = cold_feas.size();
            int32_t pre_slot = -1;
            int32_t hot_sign;
            uint8_t flag = 0, len = 0;  // flag doubles as a one-byte scratch below
            uint32_t k = 0;
            // hot signs
            if (hot_len > 0) {
                pre_slot = hot_feas[0].slot();
                for (uint32_t i = 0; i < hot_len + 1; ++i) {
                    if (i == hot_len || pre_slot != hot_feas[i].slot()) {
                        memcpy(buffer + cursor, &pre_slot, 2);
                        cursor += 2;
                        // case 2
                        if (i - k > 1) {
                            flag = 0x10;
                            memcpy(buffer + cursor, &flag, 1);
                            cursor++;
                            len = i - k;
                            memcpy(buffer + cursor, &len, 1);
                            cursor++;
                        }
                        // case 1/2: emit each sign big-endian, byte by byte
                        for (uint32_t j = k; j < i; ++j) {
                            hot_sign = (int32_t) hot_feas[j].sign();
                            for (uint8_t b = 0; b < HOT_SIGN_SIZE; ++b) {
                                flag = (hot_sign >> ((HOT_SIGN_SIZE - b - 1) * 8)) & 0xFF;
                                memcpy(buffer + cursor, &flag, 1);
                                cursor++;
                            }
                        }
                        k = i;
                    }
                    if (i < hot_len) {  // fix: out-of-bounds guard (see calc_compress_feas_lens)
                        pre_slot = hot_feas[i].slot();
                    }
                }
            }
            // cold signs
            if (cold_len > 0) {
                pre_slot = cold_feas[0].slot();
                k = 0;
                for (uint32_t i = 0; i < cold_len + 1; ++i) {
                    if (i == cold_len || pre_slot != cold_feas[i].slot()) {
                        memcpy(buffer + cursor, &pre_slot, 2);
                        cursor += 2;
                        // case 4
                        if (i - k > 1) {
                            flag = 0x30;
                            memcpy(buffer + cursor, &flag, 1);
                            cursor++;
                            len = i - k;
                            memcpy(buffer + cursor, &len, 1);
                            cursor++;
                        }
                        // case 3/4
                        for (uint32_t j = k; j < i; ++j) {
                            if (i - k == 1) {  // case 3: flag byte precedes the single sign
                                flag = 0x20;
                                memcpy(buffer + cursor, &flag, 1);
                                cursor++;
                            }
                            memcpy(buffer + cursor, &cold_feas[j].sign(), COLD_SIGN_SIZE);
                            cursor += COLD_SIGN_SIZE;
                        }
                        k = i;
                    }
                    if (i < cold_len) {  // fix: out-of-bounds guard
                        pre_slot = cold_feas[i].slot();
                    }
                }
            }
        }
    };

    // Captures the sign cache dictionary used to split feasigns into
    // hot (cached, compressible) and cold (raw) signs.
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        _index = context->cache_dict;
        return 0;
    }

    // Parses one text line of the form
    //   <id> <show> <clk> [<sign>:<slot>]... [*tag] [$sample_type]
    //   [#name:v1:v2...] [@cat1=v,v|cat2=v...]
    // and serializes it into data.data as a BinaryArchive header followed by
    // the compressed feasign block. Returns 0 on success, -1 on a malformed
    // id; other format violations abort via CHECK.
    virtual int parse(const char* str, size_t len, DataItem& data) const {
        size_t pos = paddle::string::count_nonspaces(str);
        if (pos >= len) {
            VLOG(2) << "fail to parse line: " << std::string(str, len) << ", strlen: " << len;
            return -1;
        }
        VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len;
        data.id.assign(str, pos);
        str += pos;
        static thread_local std::vector<float> vec_feas;
        static thread_local Record rec;
        rec.clear();
        // fix: str was already advanced past the id, so only len - pos bytes
        // remain; the original `str + len` overshot the buffer by pos bytes.
        const char* line_end = str + (len - pos);
        char* cursor = NULL;
        CHECK((rec.show = (int)strtol(str, &cursor, 10), cursor != str));
        str = cursor;
        CHECK((rec.clk = (int)strtol(str, &cursor, 10), cursor != str));
        str = cursor;
        CHECK(rec.show >= 1 && rec.clk >= 0 && rec.clk <= rec.show);
        // fix: skip the whitespace between tokens (count_spaces, not
        // count_nonspaces) so the '*', '$', '#', '@' marker branches are
        // actually reachable and trailing blanks do not trip the CHECKs below.
        while (*(str += paddle::string::count_spaces(str)) != 0) {
            if (*str == '*') {  // *tag
                str++;
                size_t tok_len = paddle::string::count_nonspaces(str);
                std::string tag(str, str + tok_len);
                rec.tags = tag;
                str += tok_len;
            } else if (*str == '$') {  // $sample_type
                str++;
                CHECK((rec.sample_type = (int)strtol(str, &cursor, 10), cursor != str)) << " sample type parse err:" << str;
                str = cursor;
            } else if (*str == '#') {  // #name:v1:v2... dense float features
                str++;
                size_t tok_len = std::find_if_not(str, line_end,
                    [](char c) { return std::isalnum(c) != 0 || c == '_'; }) - str;
                CHECK(tok_len > 0 && *(str + tok_len) == ':');
                std::string name(str, tok_len);
                str += tok_len;
                vec_feas.clear();
                while (*str == ':') {
                    float val = 0;
                    CHECK((val = strtof(str + 1, &cursor), cursor > str));
                    vec_feas.push_back(val);
                    str = cursor;
                }
                CHECK(rec.vec_feas.insert({name, vec_feas}).second);
            } else if (*str == '@') {  // @category info for fine-grained AUC
                str++;
                size_t tok_len = paddle::string::count_nonspaces(str);
                std::string all_str(str, str + tok_len);
                str += tok_len;
                // category_name1=value1,value2,value3|category_name2=value1,value2|....
                std::vector<std::string> all_category_vec = paddle::string::split_string(all_str, "|");
                for (size_t i = 0; i < all_category_vec.size(); ++i) {
                    std::string& single_category_str = all_category_vec[i];
                    std::vector<std::string> str_vec = paddle::string::split_string(single_category_str, "=");
                    CHECK(str_vec.size() == 2);
                    /*std::string category_name = str_vec[0];
                    std::vector<int> category_info_vec = paddle::string::split_string<int>(str_vec[1], ",");
                    CHECK(category_info_vec.size() > 0);
                    CHECK(rec.auc_category_info_map.insert({category_name, category_info_vec}).second);*/
                }
            } else {  // <sign>:<slot> feasign; route to hot/cold via the cache dict
                uint64_t sign = 0;
                int slot = -1;
                CHECK((sign = (uint64_t) strtoull(str, &cursor, 10), cursor != str));
                str = cursor;
                CHECK(*str++ == ':');
                CHECK(!isspace(*str));
                CHECK((slot = (int) strtol(str, &cursor, 10), cursor != str)) << " format error: " << str;
                CHECK((uint16_t) slot == slot);
                str = cursor;
                int32_t compress_sign = _index->sign2index(sign);
                if (compress_sign < 0) {
                    rec.cold_feas.emplace_back(sign, (uint16_t)slot);
                } else {
                    rec.hot_feas.emplace_back(compress_sign, (uint16_t)slot);
                }
            }
        }
        paddle::framework::BinaryArchive bar;
        bar << rec.show;
        bar << rec.clk;
        bar << rec.tags;
        bar << rec.vec_feas;
        bar << rec.sample_type;
        bar << rec.auc_category_info_map;
        uint32_t feas_len = rec.calc_compress_feas_lens();
        bar << feas_len;
        bar.Resize(bar.Length() + feas_len);
        // NOTE(review): assumes bar.Cursor() points at the region reserved by
        // Resize() (i.e. right after the serialized header), not at the buffer
        // start — confirm against BinaryArchive's cursor semantics.
        rec.parse_feas(bar.Cursor());
        data.data.assign(bar.Buffer(), bar.Length());
        return 0;
    }

    // Length-less overload is not supported by this parser.
    virtual int parse(const char* str, DataItem& data) const {
        // fix: the original fell off the end of a non-void function, which is
        // undefined behavior; explicitly report "not implemented".
        return -1;
    }

    // Deserializes a blob produced by parse() into a SampleInstance.
    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const {
        instance.id = data.id;
        if (data.data.empty()) {
            return -1;
        }
        // FIXME: show/clk/tags/vec_feas/sample_type/auc info are decoded but
        // not yet attached to the instance; only the feature list is filled.
        int show, clk;
        std::string tags;
        std::map<std::string, std::vector<float>> vec_feas;
        int sample_type;
        std::map<std::string, std::vector<int>> auc_category_info_map;
        uint32_t feas_len = 0;
        paddle::framework::BinaryArchive bar;
        bar.SetReadBuffer(const_cast<char*>(&data.data[0]), data.data.size(), nullptr);
        bar >> show;
        bar >> clk;
        bar >> tags;
        bar >> vec_feas;
        bar >> sample_type;
        bar >> auc_category_info_map;
        bar >> feas_len;
        parse_feas_to_ins(bar.Cursor(), feas_len, instance.features);
        return 0;
    }

private:
    // Decodes the compressed feasign block written by Record::parse_feas()
    // back into FeatureItems, translating hot indexes to raw signs through
    // the cache dictionary. `len` is the block length in bytes.
    void parse_feas_to_ins(char* buffer, uint32_t len, std::vector<FeatureItem>& ins) const {
        if (buffer == nullptr) {
            return;
        }
        uint32_t cursor = 0;
        uint16_t slot;
        uint8_t flag;
        while (cursor < len) {
            memcpy(&slot, buffer + cursor, 2);
            cursor += 2;
            // Peek the flag nibble. For case 1 (0x00) the byte also carries
            // sign payload, so the cursor only advances for cases 2/3/4.
            memcpy(&flag, buffer + cursor, 1);
            flag &= 0xF0;
            CHECK(flag == 0x00 || flag == 0x10 || flag == 0x20 || flag == 0x30);
            if (flag == 0x00 || flag == 0x10) {  // hot signs
                uint8_t group = 1;
                if (flag == 0x10) {
                    cursor++;  // flag byte
                    memcpy(&group, buffer + cursor, 1);
                    cursor++;  // length byte
                }
                for (uint8_t i = 0; i < group; ++i) {
                    int32_t sign;
                    // hot signs are stored big-endian; reassemble host order
                    for (uint8_t j = 0; j < HOT_SIGN_SIZE; ++j) {
                        memcpy((char*)&sign + HOT_SIGN_SIZE - j - 1, buffer + cursor, 1);
                        cursor++;
                    }
                    // NOTE(review): the mask strips the flag nibble for case 1
                    // but also truncates case-2 signs to 28 bits — assumes hot
                    // cache indexes never exceed 2^28; confirm.
                    uint64_t sign64 = sign & 0x0FFFFFFF;
                    sign64 = _index->index2sign((int32_t)sign64);
                    ins.emplace_back(sign64, slot);
                }
            }
            if (flag == 0x20 || flag == 0x30) {  // cold signs
                uint8_t group = 1;
                cursor++;  // flag byte
                if (flag == 0x30) {
                    memcpy(&group, buffer + cursor, 1);
                    cursor++;
                }
                for (uint8_t i = 0; i < group; ++i) {
                    uint64_t sign64;
                    memcpy(&sign64, buffer + cursor, COLD_SIGN_SIZE);
                    cursor += COLD_SIGN_SIZE;
                    ins.emplace_back(sign64, slot);
                }
            }
        }
    }

private:
    std::shared_ptr<SignCacheDict> _index;  // hot-sign cache dictionary
};
REGISTER_CLASS(DataParser, ArchiveDataParse);
int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) { int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
_parser.reset(CREATE_CLASS(DataParser, config["parser"]["class"].as<std::string>())); _parser.reset(CREATE_CLASS(DataParser, config["parser"]["class"].as<std::string>()));
if (_parser == nullptr) { if (_parser == nullptr) {
......
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
#include <yaml-cpp/yaml.h> #include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pipeline.h" #include "paddle/fluid/train/custom_trainer/feed/common/pipeline.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h" #include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
namespace paddle { namespace paddle {
...@@ -18,13 +20,38 @@ namespace feed { ...@@ -18,13 +20,38 @@ namespace feed {
class TrainerContext; class TrainerContext;
struct FeatureItem {
public:
    // Compact 10-byte feature: the 8-byte sign lives in a raw char buffer
    // (so the struct is not padded to 8-byte alignment) plus a 16-bit slot.
    FeatureItem() {
    }
    FeatureItem(uint64_t sign_, uint16_t slot_) {
        sign() = sign_;
        slot() = slot_;
    }
    // NOTE(review): returns a uint64_t reference into a char buffer, relying
    // on unaligned access and type punning being tolerated by the target
    // platform/compiler (x86 + gcc here) — confirm before porting.
    uint64_t& sign() {
        return *(uint64_t*)sign_buffer();
    }
    const uint64_t& sign() const {
        return *(const uint64_t*)sign_buffer();
    }
    uint16_t& slot() {
        return _slot;
    }
    const uint16_t& slot() const {
        return _slot;
    }

private:
    char _sign[sizeof(uint64_t)];  // sign payload, byte-addressed
    uint16_t _slot;                // feature slot id
    char* sign_buffer() const {
        return (char*)_sign;
    }
};
struct SampleInstance { struct SampleInstance {
std::string id; std::string id;
std::vector<float> lables; std::vector<float> labels;
std::vector<FeatureItem> features; std::vector<FeatureItem> features;
std::vector<float> embedx; std::vector<float> embedx;
}; };
......
...@@ -33,6 +33,17 @@ private: ...@@ -33,6 +33,17 @@ private:
int _id; int _id;
}; };
// Dictionary mapping raw 64-bit feature signs to compact "hot" cache
// indexes and back. This is currently a stub implementation: every sign is
// reported as uncached (-1) and every index maps back to sign 0, so all
// feasigns take the cold path until a real dictionary is plugged in.
class SignCacheDict {
public:
    // Returns the compact cache index for `sign`, or -1 when it is not cached.
    int32_t sign2index(uint64_t sign) {
        return -1;
    }
    // Returns the raw 64-bit sign stored at compact `index`.
    uint64_t index2sign(int32_t index) {
        return 0;
    }
};
class TrainerContext { class TrainerContext {
public: public:
YAML::Node trainer_config; YAML::Node trainer_config;
...@@ -44,6 +55,7 @@ std::vector<TableMeta> params_table_list; //参数表 ...@@ -44,6 +55,7 @@ std::vector<TableMeta> params_table_list; //参数表
std::shared_ptr<EpochAccessor> epoch_accessor; //训练轮次控制 std::shared_ptr<EpochAccessor> epoch_accessor; //训练轮次控制
std::shared_ptr<RuntimeEnvironment> environment; //运行环境 std::shared_ptr<RuntimeEnvironment> environment; //运行环境
std::vector<std::shared_ptr<Process>> process_list; //训练流程 std::vector<std::shared_ptr<Process>> process_list; //训练流程
std::shared_ptr<SignCacheDict> cache_dict; //大模型cache词典
}; };
} // namespace feed } // namespace feed
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册