#pragma once #include #include #include #include #include #include "paddle/fluid/framework/data_feed.h" namespace paddle { namespace framework { int Tree::load(std::string path, std::string tree_pipe_command_) { uint64_t linenum = 0; size_t idx = 0; std::vector lines; std::vector strs; std::vector items; int err_no; std::shared_ptr fp_ = fs_open_read(path, &err_no, tree_pipe_command_); string::LineFileReader reader; while (reader.getline(&*(fp_.get()))) { line = std::string(reader.get()); strs.clear(); boost::split(strs, line, boost::is_any_of("\t")); if (0 == linenum) { _total_node_num = boost::lexical_cast(strs[0]); _nodes = new Node[_total_node_num]; if (strs.size() > 1) { _tree_height = boost::lexical_cast(strs[1]); } ++linenum; continue; } if (strs.size() < 4) { LOG(WARNING) << "each line must has more than field"; return -1; } Node& node = _nodes[idx]; // id node.id = boost::lexical_cast(strs[0]); // embedding items.clear(); if (!strs[1].empty()) { boost::split(items, strs[1], boost::is_any_of(" ")); for (size_t i = 0; i != items.size(); ++i) { node.embedding.emplace_back(boost::lexical_cast(items[i])); } } // parent items.clear(); if (!strs[2].empty()) { node.parent_node = _nodes + boost::lexical_cast(strs[2]); } // child items.clear(); if (!strs[3].empty()) { boost::split(items, strs[3], boost::is_any_of(" ")); // node.sub_nodes = new Node*[items.size()]; for (size_t i = 0; i != items.size(); ++i) { node.sub_nodes.push_back(_nodes + boost::lexical_cast(items[i])); // node.sub_nodes[i] = _nodes + boost::lexical_cast(items[i]); } // node.sub_node_num = items.size(); } else { //没有孩子节点,当前节点是叶节点 _leaf_node_map[node.id] = &node; // node.sub_node_num = 0; } if (strs.size() > 4) { node.height = boost::lexical_cast(strs[4]); } ++idx; ++linenum; } _head = _nodes + _total_node_num - 1; LOG(INFO) << "all lines:" << linenum << ", all tree nodes:" << idx; return 0; } void Tree::print_tree() { /* std::queue q; if (_head) { q.push(_head); } while (!q.empty()) { const Node* node = q.front(); q.pop(); std::cout << "node_id: " << node->id << std::endl; std::cout << "node_embedding: "; for (int i = 0; i != node->embedding.size(); ++i) { std::cout << node->embedding[i] << " "; } std::cout << std::endl; if (node->parent_node) { std::cout << "parent_idx: " << node->parent_node - _nodes << std::endl; } if (node->sub_node_num > 0) { for (int i = 0; i != node->sub_node_num; ++i) { std::cout << "child_idx" << i << ": " << node->sub_nodes[i] - _nodes << std::endl; } } std::cout << "-------------------------------------" << std::endl; for (int i = 0; i != node->sub_node_num; ++i) { Node* tmp_node = node->sub_nodes[i]; q.push(tmp_node); } } */ } int Tree::dump_tree(const uint64_t table_id, int fea_value_dim, const std::string tree_path) { int ret; std::shared_ptr fp = paddle::framework::fs_open(tree_path, "w", &ret, ""); std::vector fea_keys, std::vector pull_result_ptr; fea_keys.reserve(_total_node_num); pull_result_ptr.reserve(_total_node_num); for (size_t i = 0; i != _total_node_num; ++i) { _nodes[i].embedding.resize(fea_value_dim); fea_key.push_back(_nodes[i].id); pull_result_ptr.push_back(_nodes[i].embedding.data()); } std::string first_line = boost::lexical_cast(_total_node_num) + "\t" + boost::lexical_cast(_tree_height); fwrite(first_line.c_str(), first_line.length(), 1, &*fp); std::string line_break_str("\n"); std::string line(""); for (size_t i = 0; i != _total_node_num; ++i) { line = line_break_str; const Node& node = _nodes[i]; line += boost::lexical_cast(node.id) + "\t"; if (!node.embedding.empty()) { for (size_t j = 0; j != node.embedding.size() - 1; ++j) { line += boost::lexical_cast(node.embedding[j]) + " "; } line += boost::lexical_cast( node.embedding[node.embedding.size() - 1]); } else { LOG(WARNING) << "node_idx[" << i << "], id[" << node.id << "] " << "has no embeddings"; } line += "\t"; if (node.parent_node) { line += boost::lexical_cast(node.parent_node - _nodes); } line += "\t"; if (node.sub_nodes.size() > 0) { for (uint32_t j = 0; j < node.sub_nodes.size() - 1; ++j) { line += boost::lexical_cast(node.sub_nodes[j] - _nodes) + " "; } line += boost::lexical_cast( node.sub_nodes[node.sub_nodes.size() - 1] - _nodes); } line += "\t" + boost::lexical_cast(node.height); fwrite(line.c_str(), line.length(), 1, &*fp); } return 0; } bool Tree::trace_back(uint64_t id, std::vector>& ids) { ids.clear(); std::unordered_map::iterator find_it = _leaf_node_map.find(id); if (find_it == _leaf_node_map.end()) { return false; } else { uint32_t height = 0; Node* node = find_it->second; while (node != NULL) { height++; ids.emplace_back(node->id, 0); node = node->parent_node; } for (auto& id : ids) { id.second = height--; } } return true; } Node* Tree::get_node() { return _nodes; } size_t Tree::get_total_node_num() { return _total_node_num; } } // end namespace framework } // end namespace paddle