From e2ba9668b4a0b9b8c820f8fe152b1f6fc65310e9 Mon Sep 17 00:00:00 2001
From: zhaozhehao
Date: Fri, 18 Jan 2019 15:24:26 +0800
Subject: [PATCH] Tree conv op (#15217)

* refactor tree2col operator with new memory mechanism test=develop

* test=develop

* test=develop

* Modified API according to panyx0718 test=develop

* fix API change according to heavengate test=develop

* Modify API comment test=develop
---
 paddle/fluid/API.spec                              |   1 +
 paddle/fluid/operators/CMakeLists.txt              |   2 +-
 paddle/fluid/operators/math/CMakeLists.txt         |   1 +
 paddle/fluid/operators/math/tree2col.cc            | 197 +++++++++++++++++
 paddle/fluid/operators/math/tree2col.cu            | 208 ++++++++++++++++++
 paddle/fluid/operators/math/tree2col.h             |  90 ++++++++
 paddle/fluid/operators/tree_conv_op.cc             | 129 +++++++++++
 paddle/fluid/operators/tree_conv_op.cu             |  24 ++
 paddle/fluid/operators/tree_conv_op.h              | 146 ++++++++++++
 python/paddle/fluid/layers/nn.py                   |  71 ++++++
 .../tests/unittests/test_tree_conv_op.py           | 120 ++++++++++
 11 files changed, 988 insertions(+), 1 deletion(-)
 create mode 100644 paddle/fluid/operators/math/tree2col.cc
 create mode 100644 paddle/fluid/operators/math/tree2col.cu
 create mode 100644 paddle/fluid/operators/math/tree2col.h
 create mode 100644 paddle/fluid/operators/tree_conv_op.cc
 create mode 100644 paddle/fluid/operators/tree_conv_op.cu
 create mode 100644 paddle/fluid/operators/tree_conv_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_tree_conv_op.py

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index f6bf54d339..0a4edea2c3 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -215,6 +215,7 @@ paddle.fluid.layers.py_func ArgSpec(args=['func', 'x', 'out', 'backward_func', '
 paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.teacher_student_sigmoid_loss ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0))
 paddle.fluid.layers.huber_loss ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None)
+paddle.fluid.layers.tree_conv ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
 paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
 paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index e53a6a562a..992a2bdd5a 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -65,7 +65,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
-set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler)
+set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler tree2col)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions)
 if (WITH_GPU)
   set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index 600ab14d37..dc27e543f0 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -60,6 +60,7 @@ math_library(matrix_bit_code)
 math_library(unpooling)
 math_library(vol2col)
 math_library(prelu)
+math_library(tree2col DEPS math_function)
 
 cc_test(math_function_test SRCS math_function_test.cc DEPS math_function)
 cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
diff --git a/paddle/fluid/operators/math/tree2col.cc b/paddle/fluid/operators/math/tree2col.cc
new file mode 100644
index 0000000000..05ce5bc7a2
--- /dev/null
+++ b/paddle/fluid/operators/math/tree2col.cc
@@ -0,0 +1,197 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/operators/math/tree2col.h" +#include +#include + +namespace paddle { +namespace operators { +namespace math { +using Tensor = framework::Tensor; +std::vector Tree2ColUtil::construct_patch( + size_t root, int max_depth, const std::vector> &tr) { + std::stack> stack; + std::unordered_map visited; + std::vector patch; + + stack.push(TreeNode(root, 1, 1, 0)); + patch.emplace_back(TreeNode(root, 1, 1, 0)); + visited[root] = true; + + while (!stack.empty()) { + TreeNode &u = stack.top(); + bool end = true; + size_t node = u.get_node(), sz = tr[node].size(); + visited[node] = true; + for (size_t i = 0; i < sz; i++) { + size_t v = tr[node][i]; + if (!visited[v] && static_cast(u.get_depth()) + 1 < max_depth) { + visited[v] = true; + stack.push(TreeNode(v, i, sz, u.get_depth() + 1)); + patch.push_back(TreeNode(v, i + 1, sz, u.get_depth() + 1)); + end = false; + } + } + if (end) { + stack.pop(); + } + } + return patch; +} + +void Tree2ColUtil::construct_tree(const paddle::Tensor &EdgeSet, + std::vector> *tr, + size_t *node_count) { + auto edge_set_dims = EdgeSet.dims(); + PADDLE_ENFORCE_EQ(edge_set_dims[1], 2); + int64_t edge_count = EdgeSet.numel(); + + const int *edge_data = EdgeSet.data(); + + for (int64_t i = 0; i < edge_count; i += 2) { + int u = edge_data[i], v = edge_data[i + 1]; + if (u != 0 && v != 0) (*node_count)++; + } + (*node_count)++; + + tr->resize(static_cast(*node_count + 1)); + + for (int64_t i = 0; i < edge_count; i += 2) { + int u = edge_data[i], v = edge_data[i + 1]; + if (u != 0 && v != 0) { + tr->at(u).push_back(v); + } else { + break; + } + } +} + +template +class Tree2ColFunctor { + public: + void operator()(const platform::CPUDeviceContext &context, + const framework::Tensor &EdgeSet, + const framework::Tensor &node_features, + framework::Tensor *patch, int max_depth) { + std::vector> tr; + auto feature_dims = node_features.dims(); + auto cpu_place = boost::get(context.GetPlace()); + math::SetConstant constant; + int64_t feature_size = feature_dims[1]; + size_t patch_elem_size = 3 * static_cast(feature_size); + size_t node_count = 0, patch_count = 0, patch_size; + Tree2ColUtil::construct_tree(EdgeSet, &tr, &node_count); + std::vector> processing_list; + for (size_t u = 1; u <= node_count; u++) { + std::vector temp_patch = + Tree2ColUtil::construct_patch(u, max_depth, tr); + if (!temp_patch.empty()) { + processing_list.emplace_back(temp_patch); + } + } + patch_size = processing_list.size(); + + T *patch_data = + patch->mutable_data({static_cast(patch_size), + static_cast(patch_elem_size)}, + cpu_place); + constant(context, patch, 0); + const T *features = node_features.data(); + + for (auto &patch_item : processing_list) { + size_t pointer_base = patch_count * patch_elem_size; + for (auto &v : patch_item) { + T eta_l = v.eta_l(max_depth), eta_r = v.eta_r(max_depth), + eta_t = v.eta_t(max_depth); + size_t id = v.get_node() - 1; + for (int i = 0; i < feature_size; i++) { + patch_data[pointer_base + i * 3] += + eta_l * features[id * feature_size + i]; + patch_data[pointer_base + i * 3 + 1] += + eta_r * features[id * feature_size + i]; + patch_data[pointer_base + i * 3 + 2] += + eta_t * features[id * feature_size + i]; + } + } + patch_count++; + } + patch->Resize({static_cast(patch_count), + static_cast(patch_elem_size)}); + } +}; +template +class Col2TreeFunctor { + public: + void operator()(const platform::CPUDeviceContext &context, + const framework::Tensor &EdgeSet, + const framework::Tensor &out_grad, framework::Tensor *in_grad, + int max_depth) { 
+    std::vector<std::vector<int>> tr;
+    auto output_dims = out_grad.dims();
+    auto cpu_place = boost::get<platform::CPUPlace>(context.GetPlace());
+    math::SetConstant<platform::CPUDeviceContext, T> constant;
+    int64_t output_size = output_dims[1];
+    size_t grad_elem_size = 3 * static_cast<size_t>(output_size);
+    size_t node_count = 0, grad_count = 0;
+    Tree2ColUtil::construct_tree(EdgeSet, &tr, &node_count);
+    std::vector<std::vector<TreeNode>> processing_list;
+    std::vector<std::vector<TreeNode>> grad_list;
+    grad_list.resize(node_count);
+    for (size_t u = 1; u <= node_count; u++) {
+      std::vector<TreeNode> tmp =
+          Tree2ColUtil::construct_patch(u, max_depth, tr);
+      if (!tmp.empty()) {
+        processing_list.push_back(tmp);
+      }
+    }
+    for (size_t patch_id = 0; patch_id < processing_list.size(); patch_id++) {
+      for (auto v : processing_list[patch_id]) {
+        grad_list[v.get_node() - 1].push_back(v.change_node(patch_id + 1));
+      }
+    }
+    T *grad_data =
+        in_grad->mutable_data<T>({static_cast<int64_t>(node_count),
+                                  static_cast<int64_t>(grad_elem_size)},
+                                 cpu_place);
+
+    constant(context, in_grad, 0);
+    const T *out_g = out_grad.data<T>();
+    for (auto &patch_item : grad_list) {
+      size_t pointer_base = grad_count * grad_elem_size;
+      for (auto &v : patch_item) {
+        T eta_l = v.eta_l<T>(max_depth), eta_r = v.eta_r<T>(max_depth),
+          eta_t = v.eta_t<T>(max_depth);
+        size_t id = v.get_node() - 1;
+        for (int i = 0; i < output_size; i++) {
+          grad_data[pointer_base + i * 3] +=
+              eta_l * out_g[id * output_size + i];
+          grad_data[pointer_base + i * 3 + 1] +=
+              eta_r * out_g[id * output_size + i];
+          grad_data[pointer_base + i * 3 + 2] +=
+              eta_t * out_g[id * output_size + i];
+        }
+      }
+      grad_count++;
+    }
+  }
+};
+
+template class Tree2ColFunctor<platform::CPUDeviceContext, float>;
+template class Tree2ColFunctor<platform::CPUDeviceContext, double>;
+template class Col2TreeFunctor<platform::CPUDeviceContext, float>;
+template class Col2TreeFunctor<platform::CPUDeviceContext, double>;
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/tree2col.cu b/paddle/fluid/operators/math/tree2col.cu
new file mode 100644
index 0000000000..3c50a525c2
--- /dev/null
+++ b/paddle/fluid/operators/math/tree2col.cu
@@ -0,0 +1,208 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stack>
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/tree2col.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+using Tensor = framework::Tensor;
+using Node = paddle::operators::math::TreeNode;
+template <typename T>
+__global__ void tree2col(const T* eta, const int* node, const int* index,
+                         const T* vectors, T* result, int feature_size, int n) {
+  const int thread_id =
+      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
+  const int patch_id = thread_id / feature_size;
+  const int j = thread_id % feature_size;
+  if (patch_id < n) {
+    const int begin_o = patch_id * 3 * feature_size;
+    const int begin = index[patch_id * 2], end = index[patch_id * 2 + 1];
+    T res_l = 0, res_r = 0, res_t = 0;
+    for (int i = begin; i < end; i++) {
+      const int id = node[i];
+      const T vec = vectors[id * feature_size + j];
+      res_l += eta[i * 3] * vec;
+      res_r += eta[i * 3 + 1] * vec;
+      res_t += eta[i * 3 + 2] * vec;
+    }
+    result[begin_o + j * 3] = res_l;
+    result[begin_o + j * 3 + 1] = res_r;
+    result[begin_o + j * 3 + 2] = res_t;
+  }
+}
+template <typename T>
+class Tree2ColFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const paddle::platform::CUDADeviceContext& context,
+                  const framework::Tensor& EdgeSet,
+                  const framework::Tensor& node_features,
+                  framework::Tensor* patch, int max_depth) {
+    std::vector<std::vector<int>> tr;
+    auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());
+    auto cpu_place = platform::CPUPlace();
+    auto stream = context.stream();
+    auto feature_dims = node_features.dims();
+    math::SetConstant<platform::CUDADeviceContext, T> constant;
+
+    Tensor EdgeSet_cpu;
+    framework::TensorCopy(EdgeSet, cpu_place, &EdgeSet_cpu);
+    int64_t feature_size = feature_dims[1];
+    size_t patch_elem_size = 3 * static_cast<size_t>(feature_size);
+    size_t node_count = 0, patch_count = 0, total_size = 0;
+    size_t max_size = feature_dims[0];
+    Tree2ColUtil::construct_tree(EdgeSet_cpu, &tr, &node_count);
+
+    std::vector<std::vector<TreeNode>> processing_list;
+    for (size_t u = 1; u <= node_count; u++) {
+      std::vector<TreeNode> tmp = Tree2ColUtil::construct_patch(u, max_depth, tr);
+      if (!tmp.empty()) {
+        processing_list.push_back(tmp);
+        total_size += tmp.size();
+      }
+    }
+
+    size_t patch_size = processing_list.size();
+    Tensor node_cpu, node_gpu, eta_cpu, eta_gpu, index_cpu, index_gpu;
+    int* node = node_cpu.mutable_data<int>({static_cast<int64_t>(total_size)},
+                                           cpu_place);
+    T* eta = eta_cpu.mutable_data<T>({static_cast<int64_t>(total_size * 3)},
+                                     cpu_place);
+    int* index = index_cpu.mutable_data<int>(
+        {static_cast<int64_t>(patch_size * 2)}, cpu_place);
+
+    int idx = 0, index_idx = 0;
+    for (auto& tmp : processing_list) {
+      index[index_idx++] = idx;
+      for (auto& v : tmp) {
+        node[idx] = static_cast<int>(v.node - 1);
+        eta[idx * 3] = v.eta_l<T>(max_depth);
+        eta[idx * 3 + 1] = v.eta_r<T>(max_depth);
+        eta[idx * 3 + 2] = v.eta_t<T>(max_depth);
+        idx++;
+      }
+      index[index_idx++] = idx;
+    }
+    framework::TensorCopy(node_cpu, gpu_place, context, &node_gpu);
+    framework::TensorCopy(eta_cpu, gpu_place, context, &eta_gpu);
+    framework::TensorCopy(index_cpu, gpu_place, context, &index_gpu);
+
+    int elem_size = patch_size * feature_size;
+    int blocks = (elem_size + 1024 - 1) / 1024;
+    int block_x = 512;
+    int block_y = (blocks + 512 - 1) / 512;
+    dim3 threads(1024, 1);
+    dim3 grid(block_x, block_y);
+
+    patch->mutable_data<T>(
+        {static_cast<int64_t>(max_size), static_cast<int64_t>(patch_elem_size)},
+        gpu_place);
+    constant(context, patch, 0);
+    tree2col<T><<<grid, threads, 0, stream>>>(
+        eta_gpu.data<T>(), node_gpu.data<int>(), index_gpu.data<int>(),
+        node_features.data<T>(), patch->data<T>(), feature_size, patch_size);
+  }
+};
+template <typename T>
+class Col2TreeFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor& EdgeSet,
+                  const framework::Tensor& patch_grad,
+                  framework::Tensor* embedding_grad, int max_depth) {
+    std::vector<std::vector<int>> tr;
+    auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());
+    auto cpu_place = platform::CPUPlace();
+    auto stream = context.stream();
+    auto output_dims = patch_grad.dims();
+    math::SetConstant<platform::CUDADeviceContext, T> constant;
+
+    Tensor EdgeSet_cpu;
+    framework::TensorCopy(EdgeSet, cpu_place, &EdgeSet_cpu);
+    int64_t output_size = output_dims[1];
+    size_t patch_elem_size = 3 * static_cast<size_t>(output_size);
+    size_t node_count = 0, patch_count = 0;
+    size_t max_size = output_dims[0];
+    Tree2ColUtil::construct_tree(EdgeSet_cpu, &tr, &node_count);
+    std::vector<std::vector<TreeNode>> processing_list;
+    std::vector<std::vector<TreeNode>> grad_list;
+    grad_list.resize(node_count);
+    size_t total_size = 0, grad_size = node_count;
+    for (size_t u = 1; u <= node_count; u++) {
+      std::vector<TreeNode> tmp = Tree2ColUtil::construct_patch(u, max_depth, tr);
+      if (!tmp.empty()) {
+        processing_list.push_back(tmp);
+      }
+    }
+    for (size_t patch_id = 0; patch_id < processing_list.size(); patch_id++) {
+      for (auto v : processing_list[patch_id]) {
+        grad_list[v.get_node() - 1].push_back(v.change_node(patch_id + 1));
+      }
+    }
+    for (auto& tmp : grad_list) {
+      total_size += tmp.size();
+    }
+
+    Tensor node_cpu, node_gpu, eta_cpu, eta_gpu, index_cpu, index_gpu;
+    int* node = node_cpu.mutable_data<int>({static_cast<int64_t>(total_size)},
+                                           cpu_place);
+    T* eta = eta_cpu.mutable_data<T>({static_cast<int64_t>(total_size * 3)},
+                                     cpu_place);
+    int* index = index_cpu.mutable_data<int>(
+        {static_cast<int64_t>(grad_size * 2)}, cpu_place);
+
+    size_t idx = 0, index_idx = 0;
+    for (auto& tmp : grad_list) {
+      index[index_idx++] = idx;
+      for (auto& v : tmp) {
+        node[idx] = static_cast<int>(v.node - 1);
+        eta[idx * 3] = v.eta_l<T>(max_depth);
+        eta[idx * 3 + 1] = v.eta_r<T>(max_depth);
+        eta[idx * 3 + 2] = v.eta_t<T>(max_depth);
+        idx++;
+      }
+      index[index_idx++] = idx;
+    }
+    framework::TensorCopy(node_cpu, gpu_place, &node_gpu);
+    framework::TensorCopy(eta_cpu, gpu_place, &eta_gpu);
+    framework::TensorCopy(index_cpu, gpu_place, &index_gpu);
+
+    int elem_size = output_size * grad_size;
+    int blocks = (elem_size + 1024 - 1) / 1024;
+    int block_x = 512;
+    int block_y = (blocks + 512 - 1) / 512;
+    dim3 threads(1024, 1);
+    dim3 grid(block_x, block_y);
+
+    embedding_grad->mutable_data<T>(
+        {static_cast<int64_t>(max_size), static_cast<int64_t>(patch_elem_size)},
+        gpu_place);
+
+    constant(context, embedding_grad, 0);
+    tree2col<T><<<grid, threads, 0, stream>>>(
+        eta_gpu.data<T>(), node_gpu.data<int>(), index_gpu.data<int>(),
+        patch_grad.data<T>(), embedding_grad->data<T>(), output_size,
+        grad_size);
+  }
+};
+
+template class Tree2ColFunctor<platform::CUDADeviceContext, float>;
+template class Tree2ColFunctor<platform::CUDADeviceContext, double>;
+template class Col2TreeFunctor<platform::CUDADeviceContext, float>;
+template class Col2TreeFunctor<platform::CUDADeviceContext, double>;
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/tree2col.h b/paddle/fluid/operators/math/tree2col.h
new file mode 100644
index 0000000000..478ba78e25
--- /dev/null
+++ b/paddle/fluid/operators/math/tree2col.h
@@ -0,0 +1,90 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <stack>
+#include <unordered_map>
+#include <vector>
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+using Tensor = framework::Tensor;
+using DDim = framework::DDim;
+namespace operators {
+namespace math {
+class TreeNode {
+ public:
+  size_t node;
+  explicit TreeNode(size_t node = 0, size_t index = 0, size_t pclen = 0,
+                    size_t depth = 0)
+      : node(node), index(index), pclen(pclen), depth(depth) {}
+  template <typename T>
+  T eta_t(T filter_depth) {
+    return ((filter_depth - this->depth) / filter_depth);
+  }
+  template <typename T>
+  T eta_l(T filter_depth) {
+    T temp;
+    if (this->pclen == 1) {
+      temp = 0.5;
+    } else {
+      temp = (this->index - 1.0) / (this->pclen - 1.0);
+    }
+    return (1.0 - this->eta_t(filter_depth)) * temp;
+  }
+  template <typename T>
+  T eta_r(T filter_depth) {
+    return (1.0 - this->eta_t(filter_depth)) *
+           (1.0 - this->eta_l(filter_depth));
+  }
+  TreeNode change_node(size_t v) {
+    return TreeNode(v, this->index, this->pclen, this->depth);
+  }
+  size_t get_node() { return this->node; }
+  size_t get_depth() { return this->depth; }
+
+ private:
+  size_t index, pclen, depth;
+};
+class Tree2ColUtil {
+ public:
+  static std::vector<TreeNode> construct_patch(
+      size_t root, int max_depth, const std::vector<std::vector<int>> &tr);
+
+  static void construct_tree(const Tensor &EdgeSet,
+                             std::vector<std::vector<int>> *tr,
+                             size_t *node_count);
+};
+
+template <typename DeviceContext, typename T>
+class Tree2ColFunctor {
+ public:
+  void operator()(const DeviceContext &context,
+                  const framework::Tensor &EdgeSet,
+                  const framework::Tensor &node_features,
+                  framework::Tensor *patch, int max_depth);
+};
+template <typename DeviceContext, typename T>
+class Col2TreeFunctor {
+ public:
+  void operator()(const DeviceContext &context,
+                  const framework::Tensor &EdgeSet,
+                  const framework::Tensor &out_grad, framework::Tensor *in_grad,
+                  int max_depth);
+};
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/tree_conv_op.cc b/paddle/fluid/operators/tree_conv_op.cc
new file mode 100644
index 0000000000..615ea285e5
--- /dev/null
+++ b/paddle/fluid/operators/tree_conv_op.cc
@@ -0,0 +1,129 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/tree_conv_op.h"
+#include <string>
+
+namespace paddle {
+namespace operators {
+class TreeConvOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("NodesVector",
+             "(Tensor) The feature vector of every node on the tree. "
+             "The shape of the feature vector must be "
+             "[max_tree_node_size, feature_size].");
+    AddInput("EdgeSet",
+             "(Tensor) The Edges of Tree. The edge must be directional. "
+             "The shape of the edge set must be [max_tree_node_size, 2].");
+    AddInput("Filter",
+             "(Tensor) The feature detector. "
+             "The shape of the filter is "
+             "[feature_size, 3, output_size, num_filters].");
+    AddOutput("Out",
+              "(Tensor) The feature vector of subtrees. "
+              "The shape of the output tensor is [max_tree_node_size, "
+              "output_size, num_filters]. "
+              "The output tensor could be a new feature "
+              "vector for next tree convolution layers.");
+    AddAttr<int>("max_depth",
+                 "(int, default: 2) The depth of feature detector.")
+        .SetDefault(2)
+        .GreaterThan(1);
+    AddComment(R"DOC(
+**Tree-Based Convolution Operator**
+
+Tree-Based Convolution is a kind of convolution defined over a tree structure.
+It is the core building block of the Tree-Based Convolution Neural Network (TBCNN),
+which is used to classify tree-structured data such as Abstract Syntax Trees.
+The operator relies on a data structure called the continuous binary tree,
+which regards a multiway tree as a binary tree.
+The paper describing Tree-Based Convolution can be found at:
+https://arxiv.org/abs/1409.5718v1
+)DOC");
+  }
+};
+class TreeConvOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasOutput("Out"));
+    auto edge_dims = ctx->GetInputDim("EdgeSet");
+    auto vector_dims = ctx->GetInputDim("NodesVector");
+    auto filter_dims = ctx->GetInputDim("Filter");
+    PADDLE_ENFORCE_EQ(edge_dims[2], 2, "Input(EdgeSet) dim[2] should be 2");
+    PADDLE_ENFORCE_EQ(edge_dims.size(), 3,
+                      "The dimension of EdgeSet Tensor should be 3");
+    PADDLE_ENFORCE_EQ(vector_dims.size(), 3,
+                      "The dimension of NodesVector Tensor should be 3");
+    PADDLE_ENFORCE_EQ(filter_dims.size(), 4,
+                      "The dimension of Filter Tensor should be 4");
+    PADDLE_ENFORCE_EQ(filter_dims[1], 3, "Input(Filter) dim[1] should be 3");
+    PADDLE_ENFORCE_EQ(
+        filter_dims[0], vector_dims[2],
+        "Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]");
+    auto output_dims = framework::make_ddim(
+        {vector_dims[0], vector_dims[1], filter_dims[2], filter_dims[3]});
+    ctx->SetOutputDim("Out", output_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::Tensor>("NodesVector")->type(),
+        ctx.device_context());
+  }
+};
+
+class TreeConvGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    auto vectors_dims = ctx->GetInputDim("NodesVector");
+    auto filter_dims = ctx->GetInputDim("Filter");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "the gradient of output(Out) must not be null");
+    if (ctx->HasOutput(framework::GradVarName("Filter"))) {
+      ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
+    }
+    if (ctx->HasOutput(framework::GradVarName("NodesVector"))) {
+      ctx->SetOutputDim(framework::GradVarName("NodesVector"), vectors_dims);
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::Tensor>("NodesVector")->type(),
+        ctx.device_context());
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(tree_conv, ops::TreeConvOp, ops::TreeConvOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+
+REGISTER_OPERATOR(tree_conv_grad, ops::TreeConvGradOp);
+
+REGISTER_OP_CPU_KERNEL(
+    tree_conv, ops::TreeConvKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TreeConvKernel<paddle::platform::CPUDeviceContext, double>);
+
+REGISTER_OP_CPU_KERNEL(
+    tree_conv_grad,
+    ops::TreeConvGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::TreeConvGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/tree_conv_op.cu b/paddle/fluid/operators/tree_conv_op.cu
new file mode 100644
index 0000000000..eebfe412bd
--- /dev/null
+++ b/paddle/fluid/operators/tree_conv_op.cu
@@ -0,0 +1,24 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/tree_conv_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    tree_conv, ops::TreeConvKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TreeConvKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    tree_conv_grad,
+    ops::TreeConvGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::TreeConvGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/tree_conv_op.h b/paddle/fluid/operators/tree_conv_op.h
new file mode 100644
index 0000000000..a84589b32f
--- /dev/null
+++ b/paddle/fluid/operators/tree_conv_op.h
@@ -0,0 +1,146 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/tree2col.h"
+
+namespace paddle {
+namespace operators {
+using Tensor = framework::Tensor;
+using DDim = framework::DDim;
+template <typename DeviceContext, typename T>
+class TreeConvKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    math::Tree2ColFunctor<DeviceContext, T> tree2col;
+    math::SetConstant<DeviceContext, T> constant;
+
+    auto *Edges = ctx.Input<Tensor>("EdgeSet");
+    auto *Embeddings = ctx.Input<Tensor>("NodesVector");
+    auto *Filter = ctx.Input<Tensor>("Filter");
+    auto *output_emb = ctx.Output<Tensor>("Out");
+    int max_depth = ctx.Attr<int>("max_depth");
+
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+
+    Tensor W;
+    W.ShareDataWith(*Filter);
+    W.Resize(framework::flatten_to_2d(Filter->dims(), 2));
+
+    int batch_size = static_cast<int>(Edges->dims()[0]);
+    int n = static_cast<int>(Embeddings->dims()[1]);
+    int out_size = static_cast<int>(Filter->dims()[2]);
+    int num_filters = static_cast<int>(Filter->dims()[3]);
+    output_emb->mutable_data<T>({batch_size, n, out_size, num_filters},
+                                ctx.GetPlace());
+
+    auto edge_set_slicedim = framework::slice_ddim(
+        Edges->dims(), 1, static_cast<int>(Edges->dims().size()));
+
+    auto embedding_slicedim = framework::slice_ddim(
+        Embeddings->dims(), 1, static_cast<int>(Embeddings->dims().size()));
+
+    auto output_slicedim = framework::slice_ddim(
+        output_emb->dims(), 1, static_cast<int>(output_emb->dims().size()));
+
+    output_slicedim = framework::flatten_to_2d(output_slicedim, 1);
+
+    for (int idx = 0; idx < batch_size; idx++) {
+      auto edge_set = Edges->Slice(idx, idx + 1).Resize(edge_set_slicedim);
+      auto embeddings =
+          Embeddings->Slice(idx, idx + 1).Resize(embedding_slicedim);
+      auto out_vec = output_emb->Slice(idx, idx + 1).Resize(output_slicedim);
+      Tensor patch;
+      tree2col(dev_ctx, edge_set, embeddings, &patch, max_depth);
+      constant(dev_ctx, &out_vec, 0);
+      blas.MatMul(patch, W, &out_vec);
+    }
+  }
+};
+template <typename DeviceContext, typename T>
+class TreeConvGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *out_g = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto *in_g = ctx.Output<Tensor>(framework::GradVarName("NodesVector"));
+    auto *filter_g = ctx.Output<Tensor>(framework::GradVarName("Filter"));
+    int max_depth = ctx.Attr<int>("max_depth");
+    auto *Embeddings = ctx.Input<Tensor>("NodesVector");
+    auto *edges = ctx.Input<Tensor>("EdgeSet");
+    auto *Filter = ctx.Input<Tensor>("Filter");
+    math::Tree2ColFunctor<DeviceContext, T> tree2col;
+    math::Col2TreeFunctor<DeviceContext, T> col2tree;
+    math::SetConstant<DeviceContext, T> constant;
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+
+    Tensor W;
+    W.ShareDataWith(*Filter);
+    W.Resize(framework::flatten_to_2d(Filter->dims(), 1));
+
+    int batch_size = static_cast<int>(Embeddings->dims()[0]);
+
+    auto edge_set_slicedim = framework::slice_ddim(
+        edges->dims(), 1, static_cast<int>(edges->dims().size()));
+
+    auto embedding_slicedim = framework::slice_ddim(
+        Embeddings->dims(), 1, static_cast<int>(Embeddings->dims().size()));
+
+    auto out_grad_dims = framework::slice_ddim(
+        out_g->dims(), 1, static_cast<int>(out_g->dims().size()));
+    out_grad_dims = framework::flatten_to_2d(out_grad_dims, 1);
+    if (filter_g) {
+      filter_g->mutable_data<T>(Filter->dims(), ctx.GetPlace());
+      Tensor f_g;
+      f_g.ShareDataWith(*filter_g);
+      f_g.Resize(framework::flatten_to_2d(Filter->dims(), 2));
+      constant(dev_ctx, filter_g, 0);
+      for (int batch_id = 0; batch_id < batch_size; batch_id++) {
+        auto edge_set =
+            edges->Slice(batch_id, batch_id + 1).Resize(edge_set_slicedim);
+        auto embeddings = Embeddings->Slice(batch_id, batch_id + 1)
+                              .Resize(embedding_slicedim);
+        auto out_grad =
+            out_g->Slice(batch_id, batch_id + 1).Resize(out_grad_dims);
+        Tensor patch;
+        tree2col(dev_ctx, edge_set, embeddings, &patch, max_depth);
+        blas.MatMul(patch, true, out_grad, false, T(1.0), &f_g, T(1.0));
+      }
+    }
+    if (in_g) {
+      auto input_grad_dims = framework::slice_ddim(
+          in_g->dims(), 1, static_cast<int>(in_g->dims().size()));
+      in_g->mutable_data<T>(Embeddings->dims(), ctx.GetPlace());
+      constant(dev_ctx, in_g, 0);
+      for (int batch_id = 0; batch_id < batch_size; batch_id++) {
+        auto edge_set =
+            edges->Slice(batch_id, batch_id + 1).Resize(edge_set_slicedim);
+        auto out_grad =
+            out_g->Slice(batch_id, batch_id + 1).Resize(out_grad_dims);
+        auto in_grad =
+            in_g->Slice(batch_id, batch_id + 1).Resize(input_grad_dims);
+        Tensor in_grad_temp;
+        col2tree(dev_ctx, edge_set, out_grad, &in_grad_temp, max_depth);
+        blas.MatMul(in_grad_temp, false, W, true, &in_grad);
+      }
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 093571a93b..ea88d8b4d0 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -183,6 +183,7 @@ __all__ = [
     'psroi_pool',
     'teacher_student_sigmoid_loss',
     'huber_loss',
+    'tree_conv',
 ]
 
 kIgnoreIndex = -100
@@ -9930,3 +9931,73 @@ def huber_loss(input, label, delta):
                'Residual': residual},
         attrs={'delta': delta})
     return out
+
+
+@templatedoc()
+def tree_conv(nodes_vector,
+              edge_set,
+              output_size,
+              num_filters=1,
+              max_depth=2,
+              act='tanh',
+              param_attr=None,
+              bias_attr=None,
+              name=None):
+    """
+    ${comment}
+
+    Args:
+        nodes_vector(${nodes_vector_type}): ${nodes_vector_comment}
+        edge_set(${edge_set_type}): ${edge_set_comment}
+        output_size(int): output feature width
+        num_filters(int): number of filters, Default 1
+        max_depth(int): max depth of filters, Default 2
+        act(str): activation function, Default tanh
+        param_attr(ParamAttr): the parameter attribute for the filters, Default None
+        bias_attr(ParamAttr): the parameter attribute for the bias of this layer, Default None
+        name(str): a name of this layer(optional). If set None, the layer will be named automatically, Default None
+
+    Returns:
+        out(${out_type}): ${out_comment}
+
+    Examples:
+        .. code-block:: python
+
+          nodes_vector = layers.data(name='vectors', shape=[None, 10, 5], dtype='float32')
+          # None for batch size, 10 for max_node_size of dataset, 5 for vector width
+          edge_set = layers.data(name='edge_set', shape=[None, 10, 2], dtype='int32')
+          # None for batch size, 10 for max_node_size of dataset, 2 because every edge has two nodes
+          # edges must be directional
+          out_vector = layers.tree_conv(nodes_vector, edge_set, 6, 1, 2, 'tanh',
+              ParamAttr(initializer=Constant(1.0)), ParamAttr(initializer=Constant(1.0)))
+          # the shape of output will be [None, 10, 6, 1],
+          # None for batch size, 10 for max_node_size of dataset, 6 for output size, 1 for 1 filter
+          out_vector = layers.reshape(out_vector, shape=[-1, 10, 6])
+          # After reshape, the output tensor can serve as nodes_vector for the next tree_conv layer
+          out_vector_2 = layers.tree_conv(out_vector, edge_set, 3, 4, 2, 'tanh',
+              ParamAttr(initializer=Constant(1.0)), ParamAttr(initializer=Constant(1.0)))
+          # the output tensor can also be pooled (the pooling in the paper is called global pooling)
+          pooled = layers.reduce_max(out_vector, dim=2)  # global pooling
+    """
+    helper = LayerHelper("tree_conv", **locals())
+    dtype = helper.input_dtype('nodes_vector')
+    feature_size = nodes_vector.shape[2]
+    W_shape = [feature_size, 3, output_size, num_filters]
+    W = helper.create_parameter(
+        attr=param_attr, shape=W_shape, dtype=dtype, is_bias=False)
+    if name is None:
+        out = helper.create_variable_for_type_inference(dtype=dtype)
+    else:
+        out = helper.create_variable(name=name, dtype=dtype, persistable=False)
+    helper.append_op(
+        type='tree_conv',
+        inputs={'NodesVector': nodes_vector,
+                'EdgeSet': edge_set,
+                'Filter': W},
+        outputs={'Out': out, },
+        attrs={'max_depth': max_depth})
+    if helper.bias_attr:
+        pre_activation = helper.append_bias_op(out)
+    else:
+        pre_activation = out
+    return helper.append_activation(pre_activation)
diff --git a/python/paddle/fluid/tests/unittests/test_tree_conv_op.py b/python/paddle/fluid/tests/unittests/test_tree_conv_op.py
new file mode 100644
index 0000000000..712453d291
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_tree_conv_op.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+from op_test import OpTest
+
+
+def collect_node_patch(og, max_depth):
+    """
+    The naive method to construct patches
+    :param og: original graph
+    :param max_depth: the depth of convolution filters
+    :return: convolution patches
+    """
+
+    def gen(node, max_depth):
+        collected = [(node, 1, 1, 0, max_depth)]
+
+        def recurse_helper(node, depth):
+            if depth > max_depth:
+                return
+            l = len(og[node])
+            for idx, c in enumerate(og[node], 1):
+                if depth + 1 < max_depth:
+                    collected.append((c, idx, l, depth + 1, max_depth))
+                recurse_helper(c, depth + 1)
+
+        recurse_helper(node, 0)
+        return collected
+
+    res = []
+    for u in range(1, len(og)):
+        lis = gen(u, max_depth)
+        if len(lis) > 0:
+            res.append(lis)
+    return res
+
+
+class TestTreeConvOp(OpTest):
+    def setUp(self):
+        self.n = 17
+        self.fea_size = 3
+        self.output_size = 1
+        self.max_depth = 2
+        self.batch_size = 1
+        self.num_filters = 1
+        adj_array = [
+            1, 2, 1, 3, 1, 4, 1, 5, 2, 6, 2, 7, 2, 8, 4, 9, 4, 10, 5, 11, 6,
+            12, 6, 13, 9, 14, 9, 15, 9, 16, 9, 17
+        ]
+        adj = np.array(adj_array).reshape((1, self.n - 1, 2)).astype('int32')
+        adj = np.tile(adj, (self.batch_size, 1, 1))
+        self.op_type = 'tree_conv'
+        vectors = np.random.random(
+            (self.batch_size, self.n, self.fea_size)).astype('float32')
+        self.inputs = {
+            'EdgeSet': adj,
+            'NodesVector': vectors,
+            'Filter': np.random.random((self.fea_size, 3, self.output_size,
+                                        self.num_filters)).astype('float32')
+        }
+        self.attrs = {'max_depth': self.max_depth}
+        vectors = []
+        for i in range(self.batch_size):
+            vector = self.get_output_naive(i)
+            vectors.append(vector)
+        self.outputs = {'Out': np.array(vectors).astype('float32'), }
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['NodesVector', 'Filter'], 'Out', max_relative_error=0.5)
+
+    def get_output_naive(self, batch_id):
+        og = [[] for i in range(1, self.n + 2)]
+        st = np.array(self.inputs['EdgeSet'][batch_id]).tolist()
+        for e in st:
+            og[e[0]].append(e[1])
+        patches = collect_node_patch(og, self.max_depth)
+        W = np.array(self.inputs['Filter']).astype('float32')
+        W = np.transpose(W, axes=[1, 0, 2, 3])
+        vec = []
+        for i, patch in enumerate(patches, 1):
+            result = np.zeros((1, W.shape[2], W.shape[3]))
+            for v in patch:
+                eta_t = float(v[4] - v[3]) / float(v[4])
+                eta_l = (1.0 - eta_t) * (0.5 if v[2] == 1 else
+                                         float(v[1] - 1.0) / float(v[2] - 1.0))
+                eta_r = (1.0 - eta_t) * (1.0 - eta_l)
+                x = self.inputs['NodesVector'][batch_id][v[0] - 1]
+                eta = np.array([eta_l, eta_r, eta_t]).reshape(
+                    (3, 1)).astype('float32')
+                Wconvi = np.tensordot(eta, W, axes=([0], [0]))
+                x = np.array(x).reshape((1, 1, self.fea_size))
+                res = np.tensordot(x, Wconvi, axes=2)
+                result = result + res
+            vec.append(result)
+        vec = np.concatenate(vec, axis=0)
+        vec = np.concatenate(
+            [
+                vec, np.zeros(
+                    (self.n - vec.shape[0], W.shape[2], W.shape[3]),
+                    dtype='float32')
+            ],
+            axis=0)
+        return vec
-- 
GitLab
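
Usage sketch. For reference, a minimal end-to-end program exercising the new fluid.layers.tree_conv API, mirroring the 17-node tree and shapes used in test_tree_conv_op.py above. This is a sketch assuming the standard fluid 1.x Executor workflow; variable names such as exe and the feed keys are illustrative, while tree_conv itself, its argument order, and the [batch, max_nodes - 1, 2] int32 EdgeSet layout come from the diff.

    import numpy as np
    import paddle.fluid as fluid

    # Shapes mirror the unit test: one tree, 17 nodes, 3-wide node features.
    # layers.data prepends the batch dimension automatically.
    nodes_vector = fluid.layers.data(
        name='nodes_vector', shape=[17, 3], dtype='float32')
    edge_set = fluid.layers.data(name='edge_set', shape=[16, 2], dtype='int32')
    # output_size=6, num_filters=1, max_depth=2, as in the docstring example.
    out = fluid.layers.tree_conv(nodes_vector, edge_set, 6, 1, 2, 'tanh')

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # Edges are (parent, child) pairs with 1-based node ids, as in the test.
    adj = np.array([1, 2, 1, 3, 1, 4, 1, 5, 2, 6, 2, 7, 2, 8, 4, 9, 4, 10,
                    5, 11, 6, 12, 6, 13, 9, 14, 9, 15, 9, 16, 9, 17],
                   dtype='int32').reshape(1, 16, 2)
    vec = np.random.random((1, 17, 3)).astype('float32')
    result, = exe.run(feed={'nodes_vector': vec, 'edge_set': adj},
                      fetch_list=[out])
    # (1, 17, 6, 1): [batch, max_nodes, output_size, num_filters]
    print(result.shape)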
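Continuous-binary-tree weights. The three per-node coefficients computed by TreeNode::eta_t/eta_l/eta_r in tree2col.h (and by the naive reference in the unit test) can be restated in plain Python as below. The helper name eta_coeffs is hypothetical; the formulas are taken directly from the header, including the detail that eta_r is scaled by (1 - eta_l) where eta_l is the already-scaled left weight.

    def eta_coeffs(index, pclen, depth, max_depth):
        """Continuous-binary-tree weights for one node inside a patch.

        index:     1-based position among its parent's children
        pclen:     number of children the parent has
        depth:     depth of the node inside the patch (patch root = 0)
        max_depth: depth of the feature detector (the filter_depth)
        """
        eta_t = float(max_depth - depth) / float(max_depth)  # "top" weight
        if pclen == 1:
            frac = 0.5                                       # only child
        else:
            frac = float(index - 1) / float(pclen - 1)       # 0..1, left to right
        eta_l = (1.0 - eta_t) * frac                         # "left" weight
        eta_r = (1.0 - eta_t) * (1.0 - eta_l)                # "right" weight
        return eta_l, eta_r, eta_t

    # The patch root (index=1, pclen=1, depth=0) gets eta_t=1 and eta_l=eta_r=0,
    # so only the "top" slice of the filter sees the root's feature vector.
    print(eta_coeffs(1, 1, 0, 2))  # (0.0, 0.0, 1.0)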