/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include #include #include "paddle/fluid/framework/data_feed.h" #include "paddle/fluid/framework/fleet/tree_wrapper.h" namespace paddle { namespace framework { int Tree::load(std::string path, std::string tree_pipe_command_) { uint64_t linenum = 0; size_t idx = 0; std::vector lines; std::vector strs; std::vector items; int err_no; std::shared_ptr fp_ = fs_open_read(path, &err_no, tree_pipe_command_); string::LineFileReader reader; while (reader.getline(&*(fp_.get()))) { line = std::string(reader.get()); strs.clear(); boost::split(strs, line, boost::is_any_of("\t")); if (0 == linenum) { _total_node_num = boost::lexical_cast(strs[0]); _nodes = new Node[_total_node_num]; if (strs.size() > 1) { _tree_height = boost::lexical_cast(strs[1]); } ++linenum; continue; } if (strs.size() < 4) { LOG(WARNING) << "each line must has more than field"; return -1; } Node& node = _nodes[idx]; // id node.id = boost::lexical_cast(strs[0]); // embedding items.clear(); if (!strs[1].empty()) { boost::split(items, strs[1], boost::is_any_of(" ")); for (size_t i = 0; i != items.size(); ++i) { node.embedding.emplace_back(boost::lexical_cast(items[i])); } } // parent items.clear(); if (!strs[2].empty()) { node.parent_node = _nodes + boost::lexical_cast(strs[2]); } // child items.clear(); if (!strs[3].empty()) { boost::split(items, strs[3], boost::is_any_of(" ")); // node.sub_nodes = new Node*[items.size()]; for (size_t i = 0; i != items.size(); ++i) { node.sub_nodes.push_back(_nodes + boost::lexical_cast(items[i])); // node.sub_nodes[i] = _nodes + boost::lexical_cast(items[i]); } // node.sub_node_num = items.size(); } else { //没有孩子节点,当前节点是叶节点 _leaf_node_map[node.id] = &node; // node.sub_node_num = 0; } if (strs.size() > 4) { node.height = boost::lexical_cast(strs[4]); } ++idx; ++linenum; } _head = _nodes + _total_node_num - 1; LOG(INFO) << "all lines:" << linenum << ", all tree nodes:" << idx; return 0; } void Tree::print_tree() { /* std::queue q; if (_head) { q.push(_head); } while (!q.empty()) { const Node* node = q.front(); q.pop(); std::cout << "node_id: " << node->id << std::endl; std::cout << "node_embedding: "; for (int i = 0; i != node->embedding.size(); ++i) { std::cout << node->embedding[i] << " "; } std::cout << std::endl; if (node->parent_node) { std::cout << "parent_idx: " << node->parent_node - _nodes << std::endl; } if (node->sub_node_num > 0) { for (int i = 0; i != node->sub_node_num; ++i) { std::cout << "child_idx" << i << ": " << node->sub_nodes[i] - _nodes << std::endl; } } std::cout << "-------------------------------------" << std::endl; for (int i = 0; i != node->sub_node_num; ++i) { Node* tmp_node = node->sub_nodes[i]; q.push(tmp_node); } } */ } int Tree::dump_tree(const uint64_t table_id, int fea_value_dim, const std::string tree_path) { int ret; std::shared_ptr fp = paddle::framework::fs_open(tree_path, "w", &ret, ""); std::vector fea_keys, std::vector pull_result_ptr; fea_keys.reserve(_total_node_num); pull_result_ptr.reserve(_total_node_num); for (size_t i = 0; i != _total_node_num; ++i) { _nodes[i].embedding.resize(fea_value_dim); fea_key.push_back(_nodes[i].id); pull_result_ptr.push_back(_nodes[i].embedding.data()); } std::string first_line = boost::lexical_cast(_total_node_num) + "\t" + boost::lexical_cast(_tree_height); fwrite(first_line.c_str(), first_line.length(), 1, &*fp); std::string line_break_str("\n"); std::string line(""); for (size_t i = 0; i != _total_node_num; ++i) { line = line_break_str; const Node& node = _nodes[i]; line += boost::lexical_cast(node.id) + "\t"; if (!node.embedding.empty()) { for (size_t j = 0; j != node.embedding.size() - 1; ++j) { line += boost::lexical_cast(node.embedding[j]) + " "; } line += boost::lexical_cast( node.embedding[node.embedding.size() - 1]); } else { LOG(WARNING) << "node_idx[" << i << "], id[" << node.id << "] " << "has no embeddings"; } line += "\t"; if (node.parent_node) { line += boost::lexical_cast(node.parent_node - _nodes); } line += "\t"; if (node.sub_nodes.size() > 0) { for (uint32_t j = 0; j < node.sub_nodes.size() - 1; ++j) { line += boost::lexical_cast(node.sub_nodes[j] - _nodes) + " "; } line += boost::lexical_cast( node.sub_nodes[node.sub_nodes.size() - 1] - _nodes); } line += "\t" + boost::lexical_cast(node.height); fwrite(line.c_str(), line.length(), 1, &*fp); } return 0; } bool Tree::trace_back(uint64_t id, std::vector>* ids) { ids.clear(); std::unordered_map::iterator find_it = _leaf_node_map.find(id); if (find_it == _leaf_node_map.end()) { return false; } else { uint32_t height = 0; Node* node = find_it->second; while (node != NULL) { height++; ids->emplace_back(node->id, 0); node = node->parent_node; } for (auto& pair_id : *ids) { pair_id.second = height--; } } return true; } Node* Tree::get_node() { return _nodes; } size_t Tree::get_total_node_num() { return _total_node_num; } } // end namespace framework } // end namespace paddle