From ef24bd78cc9f65782b708134e4d2d399a9d4e1e7 Mon Sep 17 00:00:00 2001
From: malin10
Date: Tue, 8 Sep 2020 18:23:27 +0800
Subject: [PATCH] add tdm_tree

---
Review notes: the extracted text of this patch had lost every <...> span
(#include targets and template arguments). They are reconstructed here from
how each name is used; the embedding element type (float) and the level type
(uint32_t) in particular should be confirmed against the original commit.
The blob hashes on the index lines are those of the original patch.

 paddle/fluid/framework/fleet/tree_wrapper.cc | 250 +++++++++++++++++++
 paddle/fluid/framework/fleet/tree_wrapper.h  | 169 +++++++++++++
 2 files changed, 419 insertions(+)
 create mode 100644 paddle/fluid/framework/fleet/tree_wrapper.cc
 create mode 100644 paddle/fluid/framework/fleet/tree_wrapper.h

diff --git a/paddle/fluid/framework/fleet/tree_wrapper.cc b/paddle/fluid/framework/fleet/tree_wrapper.cc
new file mode 100644
index 00000000000..e02ae46d401
--- /dev/null
+++ b/paddle/fluid/framework/fleet/tree_wrapper.cc
@@ -0,0 +1,250 @@
+#include "paddle/fluid/framework/fleet/tree_wrapper.h"
+
+#include <cstdint>
+#include <cstdio>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <boost/algorithm/string.hpp>
+#include <boost/lexical_cast.hpp>
+
+#include "paddle/fluid/framework/data_feed.h"
+
+namespace paddle {
+namespace framework {
+
+// Loads a tree from the tab-separated text file at `path`.
+//
+// The first line is "total_node_num[\ttree_height]". Each later line is one
+// node, stored at the next slot of the node array:
+//   id \t embedding \t parent index \t child indices [\t height]
+// with the embedding and child indices space-separated and every index being
+// a position in the node array. A node with an empty child field is
+// registered in _leaf_node_map.
+//
+// `tree_pipe_command_` is passed through to fs_open_read unchanged.
+//
+// Returns 0 on success and -1 when the file handle is null, a node line has
+// fewer than 4 fields, there are more node lines than declared, or an index
+// is out of range. boost::bad_lexical_cast propagates for non-numeric fields.
+int Tree::load(std::string path, std::string tree_pipe_command_) {
+  int err_no = 0;
+  std::shared_ptr<FILE> fp_ = fs_open_read(path, &err_no, tree_pipe_command_);
+  if (fp_ == nullptr) {
+    LOG(WARNING) << "failed to open tree file: " << path;
+    return -1;
+  }
+
+  // Release any previously loaded tree so a second load() neither leaks the
+  // old node array nor keeps leaf pointers into it.
+  delete[] _nodes;
+  _nodes = nullptr;
+  _head = nullptr;
+  _total_node_num = 0;
+  _leaf_node_map.clear();
+
+  // Parses a node-array index and rejects values outside the array, which
+  // would otherwise become wild parent/child pointers.
+  auto parse_index = [this](const std::string& field, size_t* index) {
+    *index = boost::lexical_cast<size_t>(field);
+    if (*index >= _total_node_num) {
+      LOG(WARNING) << "node index " << *index << " is out of range, total "
+                   << _total_node_num;
+      return false;
+    }
+    return true;
+  };
+
+  uint64_t linenum = 0;
+  size_t idx = 0;
+  std::string line;
+  std::vector<std::string> strs;
+  std::vector<std::string> items;
+  string::LineFileReader reader;
+  while (reader.getline(fp_.get())) {
+    line = std::string(reader.get());
+    strs.clear();
+    boost::split(strs, line, boost::is_any_of("\t"));
+    if (0 == linenum) {
+      _total_node_num = boost::lexical_cast<size_t>(strs[0]);
+      _nodes = new Node[_total_node_num];
+      if (strs.size() > 1) {
+        _tree_height = boost::lexical_cast<int16_t>(strs[1]);
+      }
+      ++linenum;
+      continue;
+    }
+    if (strs.size() < 4) {
+      LOG(WARNING) << "each line must have at least 4 fields";
+      return -1;
+    }
+    if (idx >= _total_node_num) {
+      LOG(WARNING) << "more node lines than the declared node count "
+                   << _total_node_num;
+      return -1;
+    }
+    Node& node = _nodes[idx];
+    // id
+    node.id = boost::lexical_cast<uint64_t>(strs[0]);
+    // embedding
+    if (!strs[1].empty()) {
+      items.clear();
+      boost::split(items, strs[1], boost::is_any_of(" "));
+      node.embedding.reserve(items.size());
+      for (const auto& item : items) {
+        node.embedding.emplace_back(boost::lexical_cast<float>(item));
+      }
+    }
+    // parent
+    if (!strs[2].empty()) {
+      size_t parent_idx = 0;
+      if (!parse_index(strs[2], &parent_idx)) {
+        return -1;
+      }
+      node.parent_node = _nodes + parent_idx;
+    }
+    // children
+    if (!strs[3].empty()) {
+      items.clear();
+      boost::split(items, strs[3], boost::is_any_of(" "));
+      node.sub_nodes.reserve(items.size());
+      for (const auto& item : items) {
+        size_t child_idx = 0;
+        if (!parse_index(item, &child_idx)) {
+          return -1;
+        }
+        node.sub_nodes.push_back(_nodes + child_idx);
+      }
+    } else {
+      // No child nodes, so the current node is a leaf.
+      _leaf_node_map[node.id] = &node;
+    }
+    if (strs.size() > 4) {
+      node.height = boost::lexical_cast<int16_t>(strs[4]);
+    }
+    ++idx;
+    ++linenum;
+  }
+  // The last slot of the node array is the head. Skip this for an empty
+  // tree so _head never points before the array.
+  if (_total_node_num > 0) {
+    _head = _nodes + _total_node_num - 1;
+  }
+  LOG(INFO) << "all lines:" << linenum << ", all tree nodes:" << idx;
+  return 0;
+}
+
+// Debug printer; currently a no-op. The breadth-first printing code this
+// function used to contain was commented out and still referred to the
+// removed Node::sub_node_num field.
+void Tree::print_tree() {}
+
+// Writes the tree to `tree_path` in the text layout that load() parses.
+//
+// Each node's embedding is resized to `fea_value_dim` before it is written,
+// so the in-memory tree is modified as a side effect. `table_id` is unused.
+// NOTE(review): the earlier code also gathered node ids and embedding
+// pointers into local vectors that were never read, presumably for a
+// parameter pull that was never wired in -- confirm whether it is wanted.
+//
+// Returns 0 on success and -1 when `fea_value_dim` is negative, the file
+// handle is null, or a write fails.
+int Tree::dump_tree(const uint64_t table_id, int fea_value_dim,
+                    const std::string tree_path) {
+  (void)table_id;
+  if (fea_value_dim < 0) {
+    LOG(WARNING) << "invalid fea_value_dim: " << fea_value_dim;
+    return -1;
+  }
+  int ret = 0;
+  std::shared_ptr<FILE> fp =
+      paddle::framework::fs_open(tree_path, "w", &ret, "");
+  if (fp == nullptr) {
+    LOG(WARNING) << "failed to open " << tree_path << " for writing";
+    return -1;
+  }
+
+  std::string line = boost::lexical_cast<std::string>(_total_node_num) + "\t" +
+                     boost::lexical_cast<std::string>(_tree_height);
+  if (fwrite(line.c_str(), line.length(), 1, fp.get()) != 1) {
+    LOG(WARNING) << "failed to write tree header to " << tree_path;
+    return -1;
+  }
+  for (size_t i = 0; i != _total_node_num; ++i) {
+    Node& node = _nodes[i];
+    node.embedding.resize(static_cast<size_t>(fea_value_dim));
+    // Each node record begins with the line break that ends the previous
+    // line, so the file has no trailing newline.
+    line = "\n";
+    line += boost::lexical_cast<std::string>(node.id);
+    line += "\t";
+    if (node.embedding.empty()) {
+      LOG(WARNING) << "node_idx[" << i << "], id[" << node.id << "] "
+                   << "has no embeddings";
+    }
+    for (size_t j = 0; j != node.embedding.size(); ++j) {
+      if (j != 0) {
+        line += " ";
+      }
+      line += boost::lexical_cast<std::string>(node.embedding[j]);
+    }
+    line += "\t";
+    if (node.parent_node) {
+      line += boost::lexical_cast<std::string>(node.parent_node - _nodes);
+    }
+    line += "\t";
+    for (size_t j = 0; j != node.sub_nodes.size(); ++j) {
+      if (j != 0) {
+        line += " ";
+      }
+      line += boost::lexical_cast<std::string>(node.sub_nodes[j] - _nodes);
+    }
+    line += "\t";
+    line += boost::lexical_cast<std::string>(node.height);
+    if (fwrite(line.c_str(), line.length(), 1, fp.get()) != 1) {
+      LOG(WARNING) << "failed to write node " << i << " to " << tree_path;
+      return -1;
+    }
+  }
+  return 0;
+}
+
+// Collects the path from leaf `id` up to the root into `ids`.
+//
+// `ids` is cleared first. On success it holds one (node id, level) pair per
+// node, leaf first: the leaf's level is the number of nodes on the path and
+// the last node, the one without a parent, has level 1.
+//
+// Returns false and leaves `ids` empty when `id` is not a known leaf or the
+// parent chain is longer than the node count, i.e. it contains a cycle.
+bool Tree::trace_back(uint64_t id,
+                      std::vector<std::pair<uint64_t, uint32_t>>& ids) {
+  ids.clear();
+  const auto find_it = _leaf_node_map.find(id);
+  if (find_it == _leaf_node_map.end()) {
+    return false;
+  }
+  for (const Node* node = find_it->second; node != nullptr;
+       node = node->parent_node) {
+    if (ids.size() >= _total_node_num) {
+      ids.clear();
+      return false;
+    }
+    ids.emplace_back(node->id, 0);
+  }
+  uint32_t height = static_cast<uint32_t>(ids.size());
+  for (auto& entry : ids) {
+    entry.second = height--;
+  }
+  return true;
+}
+
+// Returns the start of the node array, or nullptr if nothing was loaded.
+Node* Tree::get_node() { return _nodes; }
+// Returns the node count read from the first line of the loaded file.
+size_t Tree::get_total_node_num() { return _total_node_num; }
+
+}  // end namespace framework
+}  // end namespace paddle
diff --git a/paddle/fluid/framework/fleet/tree_wrapper.h b/paddle/fluid/framework/fleet/tree_wrapper.h
new file mode 100644
index 00000000000..c48ee69de6c
--- /dev/null
+++ b/paddle/fluid/framework/fleet/tree_wrapper.h
@@ -0,0 +1,169 @@
+#pragma once
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "paddle/fluid/framework/data_feed.h"
+
+namespace paddle {
+namespace framework {
+
+// One node of a Tree. Nodes live in the Tree's node array and refer to each
+// other through pointers into that array.
+struct Node {
+  // Child nodes; empty for a leaf.
+  std::vector<Node*> sub_nodes;
+  // Parent node, or nullptr when the node has none.
+  Node* parent_node{nullptr};
+  uint64_t id{0};
+  // NOTE(review): element type reconstructed as float -- confirm.
+  std::vector<float> embedding;
+  // Level of the node in the tree.
+  int16_t height{0};
+};
+
+// Owns an array of Nodes loaded from a text file and indexes its leaves by id.
+class Tree {
+ public:
+  Tree() = default;
+  ~Tree() { delete[] _nodes; }
+  // The node array is held through a raw owning pointer, so a copy would
+  // delete it twice.
+  Tree(const Tree&) = delete;
+  Tree& operator=(const Tree&) = delete;
+
+  void print_tree();
+  int dump_tree(const uint64_t table_id, int fea_value_dim,
+                const std::string tree_path);
+  // Sampling: trace back from a leaf node to the root. Returns false when
+  // `id` is not a leaf of this tree.
+  bool trace_back(uint64_t id,
+                  std::vector<std::pair<uint64_t, uint32_t>>& ids);
+  int load(std::string path, std::string tree_pipe_command_ = "");
+  Node* get_node();
+  size_t get_total_node_num();
+
+ private:
+  // tree data info
+  Node* _nodes{nullptr};
+  // head pointer
+  Node* _head{nullptr};
+  // total number of nodes
+  size_t _total_node_num{0};
+  // leaf node map
+  std::unordered_map<uint64_t, Node*> _leaf_node_map;
+  // version
+  std::string _version{""};
+  // height of the tree
+  int16_t _tree_height{0};
+};
+
+using TreePtr = std::shared_ptr<Tree>;
+
+// Registry of named trees, reachable through GetInstance().
+class TreeWrapper {
+ public:
+  virtual ~TreeWrapper() {}
+  TreeWrapper() {}
+
+  // TreeWrapper singleton. Creation is not synchronized.
+  // NOTE(review): s_instance_ is only declared in this class; it still needs
+  // a definition in exactly one .cc file, which this patch does not contain.
+  static std::shared_ptr<TreeWrapper> GetInstance() {
+    if (nullptr == s_instance_) {
+      s_instance_.reset(new paddle::framework::TreeWrapper());
+    }
+    return s_instance_;
+  }
+
+  void clear() { tree_map.clear(); }
+
+  // Loads the tree at `tree_path` and registers it as `name`. Does nothing
+  // when `name` is already registered; a tree that fails to load is not
+  // registered.
+  void insert(std::string name, std::string tree_path) {
+    if (tree_map.find(name) != tree_map.end()) {
+      return;
+    }
+    TreePtr tree = std::make_shared<Tree>();
+    if (tree->load(tree_path) != 0) {
+      LOG(WARNING) << "failed to load tree " << name << " from " << tree_path;
+      return;
+    }
+    tree_map.emplace(name, tree);
+  }
+
+  // Dumps the tree registered as `name`; does nothing for an unknown name.
+  void dump(std::string name, const uint64_t table_id, int fea_value_dim,
+            const std::string tree_path) {
+    const auto it = tree_map.find(name);
+    if (it == tree_map.end()) {
+      return;
+    }
+    it->second->dump_tree(table_id, fea_value_dim, tree_path);
+  }
+
+  // Expands every record of `src_datas` along a leaf-to-root tree path.
+  //
+  // For each record, the last feasign whose slot equals `sample_slot` is
+  // taken as a leaf id and looked up in the registered trees until one of
+  // them contains it. One copy of the record is appended to
+  // `sample_results` per node on the path, with that feasign replaced by the
+  // node id. If the record also has a feasign in `type_slot`, that feasign
+  // is increased in the copy by 100 times the node's level. Records with no
+  // feasign in `sample_slot`, or whose leaf is in no tree, add nothing.
+  void sample(const uint16_t sample_slot, const uint64_t type_slot,
+              std::vector<Record>& src_datas,
+              std::vector<Record>& sample_results) {
+    sample_results.clear();
+    for (auto& data : src_datas) {
+      // size() is never a valid index, so it marks "not found". An unsigned
+      // -1 tested with "> 0" would accept the sentinel and reject index 0.
+      const size_t not_found = data.uint64_feasigns_.size();
+      size_t sample_feasign_idx = not_found;
+      size_t type_feasign_idx = not_found;
+      for (size_t i = 0; i < data.uint64_feasigns_.size(); ++i) {
+        if (data.uint64_feasigns_[i].slot() == sample_slot) {
+          sample_feasign_idx = i;
+        }
+        if (data.uint64_feasigns_[i].slot() == type_slot) {
+          type_feasign_idx = i;
+        }
+      }
+      if (sample_feasign_idx == not_found) {
+        continue;
+      }
+      const uint64_t leaf_id =
+          data.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_;
+      std::vector<std::pair<uint64_t, uint32_t>> trace_ids;
+      for (const auto& name_and_tree : tree_map) {
+        if (name_and_tree.second->trace_back(leaf_id, trace_ids)) {
+          break;
+        }
+        PADDLE_ENFORCE_EQ(trace_ids.size(), 0, "");
+      }
+      for (const auto& trace : trace_ids) {
+        Record instance(data);
+        instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ =
+            trace.first;
+        if (type_feasign_idx != not_found) {
+          instance.uint64_feasigns_[type_feasign_idx].sign().uint64_feasign_ +=
+              trace.second * 100;
+        }
+        sample_results.push_back(std::move(instance));
+      }
+    }
+  }
+
+ public:
+  std::unordered_map<std::string, TreePtr> tree_map;
+
+ private:
+  static std::shared_ptr<TreeWrapper> s_instance_;
+};
+
+}  // end namespace framework
+}  // end namespace paddle
-- 
GitLab