// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/tree2col.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;

template <typename DeviceContext, typename T>
class TreeConvKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    math::Tree2ColFunctor<DeviceContext, T> tree2col;
    math::SetConstant<DeviceContext, T> constant;

    auto *Edges = ctx.Input<Tensor>("EdgeSet");
    auto *Embeddings = ctx.Input<Tensor>("NodesVector");
    auto *Filter = ctx.Input<Tensor>("Filter");
    auto *output_emb = ctx.Output<Tensor>("Out");
    int max_depth = ctx.Attr<int>("max_depth");
    auto &dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);

    // Flatten the filter to a 2-D matrix so the per-sample convolution
    // becomes a single GEMM.
    Tensor W;
    W.ShareDataWith(*Filter);
    W.Resize(framework::flatten_to_2d(Filter->dims(), 2));

    int batch_size = static_cast<int>(Edges->dims()[0]);
    int n = static_cast<int>(Embeddings->dims()[1]);
    int out_size = static_cast<int>(Filter->dims()[2]);
    int num_filters = static_cast<int>(Filter->dims()[3]);
    output_emb->mutable_data<T>({batch_size, n, out_size, num_filters},
                                ctx.GetPlace());

    auto edge_set_slicedim = framework::slice_ddim(
        Edges->dims(), 1, static_cast<int>(Edges->dims().size()));
    auto embedding_slicedim = framework::slice_ddim(
        Embeddings->dims(), 1, static_cast<int>(Embeddings->dims().size()));
    auto output_slicedim = framework::slice_ddim(
        output_emb->dims(), 1, static_cast<int>(output_emb->dims().size()));
    output_slicedim = framework::flatten_to_2d(output_slicedim, 1);

    // For each sample: expand the tree into a patch matrix (tree2col),
    // then compute out = patch * W.
    for (int idx = 0; idx < batch_size; idx++) {
      auto edge_set = Edges->Slice(idx, idx + 1).Resize(edge_set_slicedim);
      auto embeddings =
          Embeddings->Slice(idx, idx + 1).Resize(embedding_slicedim);
      auto out_vec = output_emb->Slice(idx, idx + 1).Resize(output_slicedim);
      Tensor patch;
      tree2col(dev_ctx, edge_set, embeddings, &patch, max_depth);
      constant(dev_ctx, &out_vec, 0);
      blas.MatMul(patch, W, &out_vec);
    }
  }
};

template <typename DeviceContext, typename T>
class TreeConvGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *out_g = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto *in_g = ctx.Output<Tensor>(framework::GradVarName("NodesVector"));
    auto *filter_g = ctx.Output<Tensor>(framework::GradVarName("Filter"));
    int max_depth = ctx.Attr<int>("max_depth");
    auto *Embeddings = ctx.Input<Tensor>("NodesVector");
    auto *edges = ctx.Input<Tensor>("EdgeSet");
    auto *Filter = ctx.Input<Tensor>("Filter");
    math::Tree2ColFunctor<DeviceContext, T> tree2col;
    math::Col2TreeFunctor<DeviceContext, T> col2tree;
    math::SetConstant<DeviceContext, T> constant;
    auto &dev_ctx = ctx.template device_context<DeviceContext>();
    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);

    Tensor W;
    W.ShareDataWith(*Filter);
    W.Resize(framework::flatten_to_2d(Filter->dims(), 1));

    int batch_size = static_cast<int>(Embeddings->dims()[0]);

    auto edge_set_slicedim = framework::slice_ddim(
        edges->dims(), 1, static_cast<int>(edges->dims().size()));
    auto embedding_slicedim = framework::slice_ddim(
        Embeddings->dims(), 1, static_cast<int>(Embeddings->dims().size()));
    auto out_grad_dims = framework::slice_ddim(
        out_g->dims(), 1, static_cast<int>(out_g->dims().size()));
    out_grad_dims = framework::flatten_to_2d(out_grad_dims, 1);

    if (filter_g) {
      // Gradient w.r.t. the filter: accumulate patch^T * dOut over the batch.
      filter_g->mutable_data<T>(Filter->dims(), ctx.GetPlace());
      Tensor f_g;
      f_g.ShareDataWith(*filter_g);
      f_g.Resize(framework::flatten_to_2d(Filter->dims(), 2));
      constant(dev_ctx, filter_g, 0);
      for (int batch_id = 0; batch_id < batch_size; batch_id++) {
        auto edge_set =
            edges->Slice(batch_id, batch_id + 1).Resize(edge_set_slicedim);
        auto embeddings = Embeddings->Slice(batch_id, batch_id + 1)
                              .Resize(embedding_slicedim);
        auto out_grad =
            out_g->Slice(batch_id, batch_id + 1).Resize(out_grad_dims);
        Tensor patch;
        tree2col(dev_ctx, edge_set, embeddings, &patch, max_depth);
        blas.MatMul(patch, true, out_grad, false, T(1.0), &f_g, T(1.0));
      }
    }
    if (in_g) {
      // Gradient w.r.t. the node embeddings: scatter dOut back to the tree
      // layout (col2tree) and multiply by W^T.
      auto input_grad_dims = framework::slice_ddim(
          in_g->dims(), 1, static_cast<int>(in_g->dims().size()));
      in_g->mutable_data<T>(Embeddings->dims(), ctx.GetPlace());
      constant(dev_ctx, in_g, 0);
      for (int batch_id = 0; batch_id < batch_size; batch_id++) {
        auto edge_set =
            edges->Slice(batch_id, batch_id + 1).Resize(edge_set_slicedim);
        auto out_grad =
            out_g->Slice(batch_id, batch_id + 1).Resize(out_grad_dims);
        auto in_grad =
            in_g->Slice(batch_id, batch_id + 1).Resize(input_grad_dims);
        Tensor in_grad_temp;
        col2tree(dev_ctx, edge_set, out_grad, &in_grad_temp, max_depth);
        blas.MatMul(in_grad_temp, false, W, true, &in_grad);
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle
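
// Note: these templated kernels are instantiated and registered for concrete
// device/type combinations in the accompanying .cc/.cu files, not in this
// header. A minimal sketch of what that registration usually looks like in
// Paddle fluid (assuming a CPU build with float and double kernels; the exact
// macro arguments live in the operator's .cc file):
//
//   namespace ops = paddle::operators;
//   REGISTER_OP_CPU_KERNEL(
//       tree_conv,
//       ops::TreeConvKernel<paddle::platform::CPUDeviceContext, float>,
//       ops::TreeConvKernel<paddle::platform::CPUDeviceContext, double>);
//   REGISTER_OP_CPU_KERNEL(
//       tree_conv_grad,
//       ops::TreeConvGradKernel<paddle::platform::CPUDeviceContext, float>,
//       ops::TreeConvGradKernel<paddle::platform::CPUDeviceContext, double>);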