提交 ef24bd78 编写于 作者: M malin10

add tdm_tree

上级 ecc59dcd
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace framework {
int Tree::load(std::string path, std::string tree_pipe_command_) {
uint64_t linenum = 0;
size_t idx = 0;
std::vector<std::string> lines;
std::vector<std::string> strs;
std::vector<std::string> items;
int err_no;
std::shared_ptr<FILE> fp_ = fs_open_read(path, &err_no, tree_pipe_command_);
string::LineFileReader reader;
while (reader.getline(&*(fp_.get()))) {
line = std::string(reader.get());
strs.clear();
boost::split(strs, line, boost::is_any_of("\t"));
if (0 == linenum) {
_total_node_num = boost::lexical_cast<size_t>(strs[0]);
_nodes = new Node[_total_node_num];
if (strs.size() > 1) {
_tree_height = boost::lexical_cast<int16_t>(strs[1]);
}
++linenum;
continue;
}
if (strs.size() < 4) {
LOG(WARNING) << "each line must has more than field";
return -1;
}
Node& node = _nodes[idx];
// id
node.id = boost::lexical_cast<uint64_t>(strs[0]);
// embedding
items.clear();
if (!strs[1].empty()) {
boost::split(items, strs[1], boost::is_any_of(" "));
for (size_t i = 0; i != items.size(); ++i) {
node.embedding.emplace_back(boost::lexical_cast<float>(items[i]));
}
}
// parent
items.clear();
if (!strs[2].empty()) {
node.parent_node = _nodes + boost::lexical_cast<int>(strs[2]);
}
// child
items.clear();
if (!strs[3].empty()) {
boost::split(items, strs[3], boost::is_any_of(" "));
// node.sub_nodes = new Node*[items.size()];
for (size_t i = 0; i != items.size(); ++i) {
node.sub_nodes.push_back(_nodes + boost::lexical_cast<int>(items[i]));
// node.sub_nodes[i] = _nodes + boost::lexical_cast<int>(items[i]);
}
// node.sub_node_num = items.size();
} else {
//没有孩子节点,当前节点是叶节点
_leaf_node_map[node.id] = &node;
// node.sub_node_num = 0;
}
if (strs.size() > 4) {
node.height = boost::lexical_cast<int16_t>(strs[4]);
}
++idx;
++linenum;
}
_head = _nodes + _total_node_num - 1;
LOG(INFO) << "all lines:" << linenum << ", all tree nodes:" << idx;
return 0;
}
void Tree::print_tree() {
/*
std::queue<Node*> q;
if (_head) {
q.push(_head);
}
while (!q.empty()) {
const Node* node = q.front();
q.pop();
std::cout << "node_id: " << node->id << std::endl;
std::cout << "node_embedding: ";
for (int i = 0; i != node->embedding.size(); ++i) {
std::cout << node->embedding[i] << " ";
}
std::cout << std::endl;
if (node->parent_node) {
std::cout << "parent_idx: " << node->parent_node - _nodes <<
std::endl;
}
if (node->sub_node_num > 0) {
for (int i = 0; i != node->sub_node_num; ++i) {
std::cout << "child_idx" << i << ": " << node->sub_nodes[i] - _nodes
<< std::endl;
}
}
std::cout << "-------------------------------------" << std::endl;
for (int i = 0; i != node->sub_node_num; ++i) {
Node* tmp_node = node->sub_nodes[i];
q.push(tmp_node);
}
}
*/
}
int Tree::dump_tree(const uint64_t table_id, int fea_value_dim,
const std::string tree_path) {
int ret;
std::shared_ptr<FILE> fp =
paddle::framework::fs_open(tree_path, "w", &ret, "");
std::vector<uint64_t> fea_keys, std::vector<float*> pull_result_ptr;
fea_keys.reserve(_total_node_num);
pull_result_ptr.reserve(_total_node_num);
for (size_t i = 0; i != _total_node_num; ++i) {
_nodes[i].embedding.resize(fea_value_dim);
fea_key.push_back(_nodes[i].id);
pull_result_ptr.push_back(_nodes[i].embedding.data());
}
std::string first_line = boost::lexical_cast<std::string>(_total_node_num) +
"\t" +
boost::lexical_cast<std::string>(_tree_height);
fwrite(first_line.c_str(), first_line.length(), 1, &*fp);
std::string line_break_str("\n");
std::string line("");
for (size_t i = 0; i != _total_node_num; ++i) {
line = line_break_str;
const Node& node = _nodes[i];
line += boost::lexical_cast<std::string>(node.id) + "\t";
if (!node.embedding.empty()) {
for (size_t j = 0; j != node.embedding.size() - 1; ++j) {
line += boost::lexical_cast<std::string>(node.embedding[j]) + " ";
}
line += boost::lexical_cast<std::string>(
node.embedding[node.embedding.size() - 1]);
} else {
LOG(WARNING) << "node_idx[" << i << "], id[" << node.id << "] "
<< "has no embeddings";
}
line += "\t";
if (node.parent_node) {
line += boost::lexical_cast<std::string>(node.parent_node - _nodes);
}
line += "\t";
if (node.sub_nodes.size() > 0) {
for (uint32_t j = 0; j < node.sub_nodes.size() - 1; ++j) {
line +=
boost::lexical_cast<std::string>(node.sub_nodes[j] - _nodes) + " ";
}
line += boost::lexical_cast<std::string>(
node.sub_nodes[node.sub_nodes.size() - 1] - _nodes);
}
line += "\t" + boost::lexical_cast<std::string>(node.height);
fwrite(line.c_str(), line.length(), 1, &*fp);
}
return 0;
}
bool Tree::trace_back(uint64_t id,
std::vector<std::pair<uint64_t, uint32_t>>& ids) {
ids.clear();
std::unordered_map<uint64_t, Node*>::iterator find_it =
_leaf_node_map.find(id);
if (find_it == _leaf_node_map.end()) {
return false;
} else {
uint32_t height = 0;
Node* node = find_it->second;
while (node != NULL) {
height++;
ids.emplace_back(node->id, 0);
node = node->parent_node;
}
for (auto& id : ids) {
id.second = height--;
}
}
return true;
}
Node* Tree::get_node() { return _nodes; }
size_t Tree::get_total_node_num() { return _total_node_num; }
} // end namespace framework
} // end namespace paddle
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
namespace paddle {
namespace framework {
struct Node {
Node::Node() : parent_node(NULL), id(0), height(0) {}
~Node(){};
std::vector<Node*> sub_nodes;
// uint32_t sub_node_num;
Node* parent_node;
uint64_t id;
std::vector<float> embedding;
int16_t height; //层级
};
class Tree {
public:
Tree() : _nodes(NULL), _head(NULL) {}
~Tree() {
if (_nodes) {
delete[] _nodes;
_nodes = NULL;
}
}
void print_tree();
int dump_tree(const uint64_t table_id, int fea_value_dim,
const std::string tree_path);
//采样:从叶节点回溯到根节点
void trace_back(uint64_t id, std::vector<std::pair<uint64_t, uint32_t>>& ids);
int load(std::string path);
Node* get_node();
size_t get_total_node_num();
private:
// tree data info
Node* _nodes{nullptr};
// head pointer
Node* _head{nullptr};
// total number of nodes
size_t _total_node_num{0};
// leaf node map
std::unordered_map<uint64_t, Node*> _leaf_node_map;
// version
std::string _version{""};
//树的高度
int16_t _tree_height{0};
};
using TreePtr = std::shared_ptr<Tree>;
class TreeWrapper {
public:
virtual ~TreeWrapper() {}
TreeWrapper() {}
// TreeWrapper singleton
static std::shared_ptr<TreeWrapper> GetInstance() {
if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::TreeWrapper());
}
return s_instance_;
}
void clear() { tree_map.clear(); }
void insert(std::string name, std::string tree_path) {
if (tree_map.find(name) != tree_map.end()) {
return;
}
TreePtr tree = new Tree();
tree.load(tree_path);
tree_map.insert(std::pair<std::string, TreePtr>{name, tree});
}
void dump(std::string name, const uint64_t table_id, int fea_value_dim,
const std::string tree_path) {
if (tree_map.find(name) == tree_map.end()) {
return;
}
tree_map.at(name)->dump_tree(table_id, fea_value_dim, tree_path);
}
void sample(const uint16_t sample_slot, const uint64_t type_slot,
std::vector<Record>& src_datas,
std::vector<Record>& sample_results) {
sample_results.clear();
for (auto& data : src_datas) {
uint64_t sample_feasign_idx = -1, type_feasign_idx = -1;
for (auto i = 0; i < data.uint64_feasigns_.size(); i++) {
if (data.uint64_feasigns_[i].slot() == sample_slot) {
sample_feasign_idx = i;
}
if (data.uint64_feasigns_.slot() == type_slot) {
type_feasign_idx = i;
}
}
if (sample_feasign_idx > 0) {
std::vector<std::pair<uint64_t, uint32_t>> trace_ids;
for (auto name : tree_map) {
bool in_tree = tree_map.at(name)->trace_back(
data.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_,
trace_ids);
if (in_tree) {
break;
} else {
PADDLE_ENFORCE_EQ(trace_ids.size(), 0, "");
}
}
for (auto i = 0; i < trace_ids.size(); i++) {
Record instance(data);
instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ =
trace_ids[i].first;
if (type_feasign_idx > 0)
instance.uint64_feasigns_[type_feasign_idx]
.sign()
.uint64_feasign_ += trace_ids[i].second * 100;
sample_results.push_back(instance);
}
}
}
return;
}
public:
std::unordered_map<std::string, TreePtr> tree_map;
private:
static std::shared_ptr<TreeWrapper> s_instance_;
};
} // end namespace framework
} // end namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册