// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "lite/backends/x86/math/tree2col.h" #include #include namespace paddle { namespace lite { namespace x86 { namespace math { std::vector Tree2ColUtil::construct_patch( size_t root, int max_depth, const std::vector> &tr) { std::stack> stack; std::map visited; std::vector patch; stack.push(TreeNode(root, 1, 1, 0)); patch.emplace_back(TreeNode(root, 1, 1, 0)); visited[root] = true; while (!stack.empty()) { TreeNode &u = stack.top(); bool end = true; size_t node = u.get_node(), sz = tr[node].size(); visited[node] = true; for (size_t i = 0; i < sz; i++) { size_t v = tr[node][i]; if (!visited[v] && static_cast(u.get_depth()) + 1 < max_depth) { visited[v] = true; stack.push(TreeNode(v, i, sz, u.get_depth() + 1)); patch.push_back(TreeNode(v, i + 1, sz, u.get_depth() + 1)); end = false; } } if (end) { stack.pop(); } } return patch; } void Tree2ColUtil::construct_tree(const lite::Tensor &EdgeSet, std::vector> *tr, size_t *node_count) { auto edge_set_dims = EdgeSet.dims(); PADDLE_ENFORCE_EQ(edge_set_dims[1], 2); int64_t edge_count = EdgeSet.numel(); const int *edge_data = EdgeSet.data(); for (int64_t i = 0; i < edge_count; i += 2) { int u = edge_data[i], v = edge_data[i + 1]; if (u != 0 && v != 0) (*node_count)++; } (*node_count)++; tr->resize(static_cast(*node_count + 1)); for (int64_t i = 0; i < edge_count; i += 2) { int u = edge_data[i], v = edge_data[i + 1]; if (u != 0 && v != 0) { tr->at(u).push_back(v); } else { break; } } } template class Tree2ColFunctor { public: void operator()(const lite::X86Context &context, const lite::Tensor &EdgeSet, const lite::Tensor &node_features, lite::Tensor *patch, int max_depth) { std::vector> tr; auto feature_dims = node_features.dims(); math::SetConstant constant; int64_t feature_size = feature_dims[1]; size_t patch_elem_size = 3 * static_cast(feature_size); size_t node_count = 0, patch_count = 0, patch_size; Tree2ColUtil::construct_tree(EdgeSet, &tr, &node_count); std::vector> processing_list; for (size_t u = 1; u <= node_count; u++) { std::vector temp_patch = Tree2ColUtil::construct_patch(u, max_depth, tr); if (!temp_patch.empty()) { processing_list.emplace_back(temp_patch); } } patch_size = processing_list.size(); // T *patch_data = // patch->template mutable_data({static_cast(patch_size), // static_cast(patch_elem_size)}, // cpu_place); patch->Resize({static_cast(patch_size), static_cast(patch_elem_size)}); auto *patch_data = patch->template mutable_data(lite::TargetType::kX86); constant(context, patch, 0); const T *features = node_features.data(); for (auto &patch_item : processing_list) { size_t pointer_base = patch_count * patch_elem_size; for (auto &v : patch_item) { T eta_l = v.eta_l(max_depth), eta_r = v.eta_r(max_depth), eta_t = v.eta_t(max_depth); size_t id = v.get_node() - 1; for (int i = 0; i < feature_size; i++) { patch_data[pointer_base + i * 3] += eta_l * features[id * feature_size + i]; patch_data[pointer_base + i * 3 + 1] += eta_r * features[id * feature_size + i]; patch_data[pointer_base + i * 3 + 2] += eta_t * features[id * feature_size + i]; } } patch_count++; } patch->Resize({static_cast(patch_count), static_cast(patch_elem_size)}); } }; template class Col2TreeFunctor { public: void operator()(const lite::X86Context &context, const lite::Tensor &EdgeSet, const lite::Tensor &out_grad, lite::Tensor *in_grad, int max_depth) { std::vector> tr; auto output_dims = out_grad.dims(); // auto cpu_place = boost::get(context.GetPlace()); math::SetConstant constant; int64_t output_size = output_dims[1]; size_t grad_elem_size = 3 * static_cast(output_size); size_t node_count = 0, grad_count = 0; Tree2ColUtil::construct_tree(EdgeSet, &tr, &node_count); std::vector> processing_list; std::vector> grad_list; grad_list.resize(node_count); for (size_t u = 1; u <= node_count; u++) { std::vector tmp = Tree2ColUtil::construct_patch(u, max_depth, tr); if (!tmp.empty()) { processing_list.push_back(tmp); } } for (size_t patch_id = 0; patch_id < processing_list.size(); patch_id++) { for (auto v : processing_list[patch_id]) { grad_list[v.get_node() - 1].push_back(v.change_node(patch_id + 1)); } } // T *grad_data = // in_grad->template mutable_data({static_cast(node_count), // static_cast(grad_elem_size)}, // cpu_place); in_grad->Resize({static_cast(node_count), static_cast(grad_elem_size)}); auto *grad_data = in_grad->template mutable_data(lite::TargetType::kX86); constant(context, in_grad, 0); const T *out_g = out_grad.data(); for (auto &patch_item : grad_list) { size_t pointer_base = grad_count * grad_elem_size; for (auto &v : patch_item) { T eta_l = v.eta_l(max_depth), eta_r = v.eta_r(max_depth), eta_t = v.eta_t(max_depth); size_t id = v.get_node() - 1; for (int i = 0; i < output_size; i++) { grad_data[pointer_base + i * 3] += eta_l * out_g[id * output_size + i]; grad_data[pointer_base + i * 3 + 1] += eta_r * out_g[id * output_size + i]; grad_data[pointer_base + i * 3 + 2] += eta_t * out_g[id * output_size + i]; } } grad_count++; } } }; template class Tree2ColFunctor; template class Tree2ColFunctor; template class Col2TreeFunctor; template class Col2TreeFunctor; } // namespace math } // namespace x86 } // namespace lite } // namespace paddle