diff --git a/paddle/fluid/framework/fleet/tree_wrapper.cc b/paddle/fluid/framework/fleet/tree_wrapper.cc new file mode 100644 index 0000000000000000000000000000000000000000..e02ae46d401748dc6f9647e1575a4ef56ca60ddc --- /dev/null +++ b/paddle/fluid/framework/fleet/tree_wrapper.cc @@ -0,0 +1,195 @@ +#pragma once +#include <memory> +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +#include "paddle/fluid/framework/data_feed.h" + +namespace paddle { +namespace framework { + +int Tree::load(std::string path, std::string tree_pipe_command_) { + uint64_t linenum = 0; + size_t idx = 0; + std::vector<std::string> lines; + std::vector<std::string> strs; + std::vector<std::string> items; + + int err_no; + std::shared_ptr<FILE> fp_ = fs_open_read(path, &err_no, tree_pipe_command_); + string::LineFileReader reader; + while (reader.getline(&*(fp_.get()))) { + line = std::string(reader.get()); + strs.clear(); + boost::split(strs, line, boost::is_any_of("\t")); + if (0 == linenum) { + _total_node_num = boost::lexical_cast<size_t>(strs[0]); + _nodes = new Node[_total_node_num]; + if (strs.size() > 1) { + _tree_height = boost::lexical_cast<int16_t>(strs[1]); + } + ++linenum; + continue; + } + if (strs.size() < 4) { + LOG(WARNING) << "each line must has more than field"; + return -1; + } + Node& node = _nodes[idx]; + // id + node.id = boost::lexical_cast<uint64_t>(strs[0]); + // embedding + items.clear(); + if (!strs[1].empty()) { + boost::split(items, strs[1], boost::is_any_of(" ")); + for (size_t i = 0; i != items.size(); ++i) { + node.embedding.emplace_back(boost::lexical_cast<float>(items[i])); + } + } + // parent + items.clear(); + if (!strs[2].empty()) { + node.parent_node = _nodes + boost::lexical_cast<int>(strs[2]); + } + // child + items.clear(); + if (!strs[3].empty()) { + boost::split(items, strs[3], boost::is_any_of(" ")); + // node.sub_nodes = new Node*[items.size()]; + for (size_t i = 0; i != items.size(); ++i) { + node.sub_nodes.push_back(_nodes + boost::lexical_cast<int>(items[i])); + // node.sub_nodes[i] = _nodes + boost::lexical_cast<int>(items[i]); + } + // node.sub_node_num = items.size(); + } else { + //没有å©å节点,当å‰èŠ‚点是å¶èŠ‚点 + _leaf_node_map[node.id] = &node; + // node.sub_node_num = 0; + } + if (strs.size() > 4) { + node.height = boost::lexical_cast<int16_t>(strs[4]); + } + ++idx; + ++linenum; + } + _head = _nodes + _total_node_num - 1; + LOG(INFO) << "all lines:" << linenum << ", all tree nodes:" << idx; + return 0; +} +void Tree::print_tree() { + /* + std::queue<Node*> q; + if (_head) { + q.push(_head); + } + while (!q.empty()) { + const Node* node = q.front(); + q.pop(); + std::cout << "node_id: " << node->id << std::endl; + std::cout << "node_embedding: "; + for (int i = 0; i != node->embedding.size(); ++i) { + std::cout << node->embedding[i] << " "; + } + std::cout << std::endl; + if (node->parent_node) { + std::cout << "parent_idx: " << node->parent_node - _nodes << + std::endl; + } + if (node->sub_node_num > 0) { + for (int i = 0; i != node->sub_node_num; ++i) { + std::cout << "child_idx" << i << ": " << node->sub_nodes[i] - _nodes + << std::endl; + } + } + std::cout << "-------------------------------------" << std::endl; + for (int i = 0; i != node->sub_node_num; ++i) { + Node* tmp_node = node->sub_nodes[i]; + q.push(tmp_node); + } + } + */ +} +int Tree::dump_tree(const uint64_t table_id, int fea_value_dim, + const std::string tree_path) { + int ret; + std::shared_ptr<FILE> fp = + paddle::framework::fs_open(tree_path, "w", &ret, ""); + + std::vector<uint64_t> fea_keys, std::vector<float*> pull_result_ptr; + + fea_keys.reserve(_total_node_num); + pull_result_ptr.reserve(_total_node_num); + for (size_t i = 0; i != _total_node_num; ++i) { + _nodes[i].embedding.resize(fea_value_dim); + fea_key.push_back(_nodes[i].id); + pull_result_ptr.push_back(_nodes[i].embedding.data()); + } + + std::string first_line = boost::lexical_cast<std::string>(_total_node_num) + + "\t" + + boost::lexical_cast<std::string>(_tree_height); + fwrite(first_line.c_str(), first_line.length(), 1, &*fp); + std::string line_break_str("\n"); + std::string line(""); + for (size_t i = 0; i != _total_node_num; ++i) { + line = line_break_str; + const Node& node = _nodes[i]; + line += boost::lexical_cast<std::string>(node.id) + "\t"; + if (!node.embedding.empty()) { + for (size_t j = 0; j != node.embedding.size() - 1; ++j) { + line += boost::lexical_cast<std::string>(node.embedding[j]) + " "; + } + line += boost::lexical_cast<std::string>( + node.embedding[node.embedding.size() - 1]); + } else { + LOG(WARNING) << "node_idx[" << i << "], id[" << node.id << "] " + << "has no embeddings"; + } + line += "\t"; + if (node.parent_node) { + line += boost::lexical_cast<std::string>(node.parent_node - _nodes); + } + line += "\t"; + if (node.sub_nodes.size() > 0) { + for (uint32_t j = 0; j < node.sub_nodes.size() - 1; ++j) { + line += + boost::lexical_cast<std::string>(node.sub_nodes[j] - _nodes) + " "; + } + line += boost::lexical_cast<std::string>( + node.sub_nodes[node.sub_nodes.size() - 1] - _nodes); + } + line += "\t" + boost::lexical_cast<std::string>(node.height); + fwrite(line.c_str(), line.length(), 1, &*fp); + } + return 0; +} + +bool Tree::trace_back(uint64_t id, + std::vector<std::pair<uint64_t, uint32_t>>& ids) { + ids.clear(); + std::unordered_map<uint64_t, Node*>::iterator find_it = + _leaf_node_map.find(id); + if (find_it == _leaf_node_map.end()) { + return false; + } else { + uint32_t height = 0; + Node* node = find_it->second; + while (node != NULL) { + height++; + ids.emplace_back(node->id, 0); + node = node->parent_node; + } + for (auto& id : ids) { + id.second = height--; + } + } + return true; +} + +Node* Tree::get_node() { return _nodes; } +size_t Tree::get_total_node_num() { return _total_node_num; } + +} // end namespace framework +} // end namespace paddle diff --git a/paddle/fluid/framework/fleet/tree_wrapper.h b/paddle/fluid/framework/fleet/tree_wrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..c48ee69de6c9f594c3fa444ce547300a47596c6a --- /dev/null +++ b/paddle/fluid/framework/fleet/tree_wrapper.h @@ -0,0 +1,140 @@ +#pragma once +#include <memory> +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +#include "paddle/fluid/framework/data_feed.h" + +namespace paddle { +namespace framework { + +struct Node { + Node::Node() : parent_node(NULL), id(0), height(0) {} + ~Node(){}; + std::vector<Node*> sub_nodes; + // uint32_t sub_node_num; + Node* parent_node; + uint64_t id; + std::vector<float> embedding; + int16_t height; //层级 +}; + +class Tree { + public: + Tree() : _nodes(NULL), _head(NULL) {} + ~Tree() { + if (_nodes) { + delete[] _nodes; + _nodes = NULL; + } + } + + void print_tree(); + int dump_tree(const uint64_t table_id, int fea_value_dim, + const std::string tree_path); + //é‡‡æ ·ï¼šä»Žå¶èŠ‚ç‚¹å›žæº¯åˆ°æ ¹èŠ‚ç‚¹ + void trace_back(uint64_t id, std::vector<std::pair<uint64_t, uint32_t>>& ids); + int load(std::string path); + Node* get_node(); + size_t get_total_node_num(); + + private: + // tree data info + Node* _nodes{nullptr}; + // head pointer + Node* _head{nullptr}; + // total number of nodes + size_t _total_node_num{0}; + // leaf node map + std::unordered_map<uint64_t, Node*> _leaf_node_map; + // version + std::string _version{""}; + //æ ‘çš„é«˜åº¦ + int16_t _tree_height{0}; +}; + +using TreePtr = std::shared_ptr<Tree>; +class TreeWrapper { + public: + virtual ~TreeWrapper() {} + TreeWrapper() {} + + // TreeWrapper singleton + static std::shared_ptr<TreeWrapper> GetInstance() { + if (NULL == s_instance_) { + s_instance_.reset(new paddle::framework::TreeWrapper()); + } + return s_instance_; + } + + void clear() { tree_map.clear(); } + + void insert(std::string name, std::string tree_path) { + if (tree_map.find(name) != tree_map.end()) { + return; + } + TreePtr tree = new Tree(); + tree.load(tree_path); + tree_map.insert(std::pair<std::string, TreePtr>{name, tree}); + } + + void dump(std::string name, const uint64_t table_id, int fea_value_dim, + const std::string tree_path) { + if (tree_map.find(name) == tree_map.end()) { + return; + } + tree_map.at(name)->dump_tree(table_id, fea_value_dim, tree_path); + } + + void sample(const uint16_t sample_slot, const uint64_t type_slot, + std::vector<Record>& src_datas, + std::vector<Record>& sample_results) { + sample_results.clear(); + for (auto& data : src_datas) { + uint64_t sample_feasign_idx = -1, type_feasign_idx = -1; + for (auto i = 0; i < data.uint64_feasigns_.size(); i++) { + if (data.uint64_feasigns_[i].slot() == sample_slot) { + sample_feasign_idx = i; + } + if (data.uint64_feasigns_.slot() == type_slot) { + type_feasign_idx = i; + } + } + if (sample_feasign_idx > 0) { + std::vector<std::pair<uint64_t, uint32_t>> trace_ids; + for (auto name : tree_map) { + bool in_tree = tree_map.at(name)->trace_back( + data.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_, + trace_ids); + if (in_tree) { + break; + } else { + PADDLE_ENFORCE_EQ(trace_ids.size(), 0, ""); + } + } + for (auto i = 0; i < trace_ids.size(); i++) { + Record instance(data); + instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ = + trace_ids[i].first; + if (type_feasign_idx > 0) + instance.uint64_feasigns_[type_feasign_idx] + .sign() + .uint64_feasign_ += trace_ids[i].second * 100; + sample_results.push_back(instance); + } + } + } + return; + } + + public: + std::unordered_map<std::string, TreePtr> tree_map; + + private: + static std::shared_ptr<TreeWrapper> s_instance_; +}; + +} // end namespace framework +} // end namespace paddle