From 39012536762576eda7c72aa5413de8202dba8e7d Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Fri, 19 Nov 2021 12:32:04 +0800
Subject: [PATCH] Add paddle.incubate.graph_send_recv API (#37205)

* add cpu version, using set: sum, min, max
* add cpu version: mean
* improve cpu code and fix dynamic memory allocation problem
* fix arg error, add index judge, delete fp16
* fix bug in CudaAtomicMax and CudaAtomicMin
* add CUDA version
* fix grad_op bug for index
* add op test, add correct cpu grad op
* Add correct CUDA Mean grad
* [Add] Successful MEAN and SUM
* [Add] Successful MIN and MAX in CPU
* [Add] Successful MIN and MAX in CUDA
* fix windows dtype ci
* fix ROCM ci by adding HIP flag
* rename fused_gather_scatter to send_recv
* unify name as send and recv
* change zero index return time
* add send_recv incubate api
* fix index data type, add unittest case for API
* delete redundant input tensor
* fix en example and docs, add default value in pool_type
* add shape judge and max grid judge
* fix comment
* fix index type bug
* add const &
* fix en docs
* delete numpy in examples
* add unittest for int input
* fix send_recv comment
* change send_recv to graph_send_recv
---
 paddle/fluid/operators/graph_send_recv_op.cc |  183 ++++++++
 paddle/fluid/operators/graph_send_recv_op.cu |  419 ++++++++++++++++++
 paddle/fluid/operators/graph_send_recv_op.h  |  291 ++++++++++++
 paddle/fluid/platform/cuda_primitives.h      |   28 +-
 .../unittests/test_graph_send_recv_op.py     |  309 +++++++++++++
 python/paddle/incubate/__init__.py           |    2 +
 python/paddle/incubate/operators/__init__.py |    1 +
 .../incubate/operators/graph_send_recv.py    |  108 +++++
 8 files changed, 1335 insertions(+), 6 deletions(-)
 create mode 100644 paddle/fluid/operators/graph_send_recv_op.cc
 create mode 100644 paddle/fluid/operators/graph_send_recv_op.cu
 create mode 100644 paddle/fluid/operators/graph_send_recv_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py
 create mode 100644 python/paddle/incubate/operators/graph_send_recv.py

diff --git a/paddle/fluid/operators/graph_send_recv_op.cc b/paddle/fluid/operators/graph_send_recv_op.cc
new file mode 100644
index 0000000000..6af8388d9e
--- /dev/null
+++ b/paddle/fluid/operators/graph_send_recv_op.cc
@@ -0,0 +1,183 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/graph_send_recv_op.h"
+
+namespace paddle {
+namespace operators {
+
+class GraphSendRecvOP : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GraphSendRecv");
+    OP_INOUT_CHECK(ctx->HasInput("Src_index"), "Input", "Src_index",
+                   "GraphSendRecv");
+    OP_INOUT_CHECK(ctx->HasInput("Dst_index"), "Input", "Dst_index",
+                   "GraphSendRecv");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "GraphSendRecv");
+
+    auto src_index_dims = ctx->GetInputDim("Src_index");
+    if (src_index_dims.size() == 2) {
+      PADDLE_ENFORCE_EQ(src_index_dims[1], 1,
+                        platform::errors::InvalidArgument(
+                            "The last dim of Src_index should be 1 when it "
+                            "is 2D, but we get %d",
+                            src_index_dims[1]));
+    } else {
+      PADDLE_ENFORCE_EQ(
+          src_index_dims.size(), 1,
+          platform::errors::InvalidArgument(
+              "The Src_index should be 1D, when it is not 2D, but we get %d",
+              src_index_dims.size()));
+    }
+
+    auto dst_index_dims = ctx->GetInputDim("Dst_index");
+    if (dst_index_dims.size() == 2) {
+      PADDLE_ENFORCE_EQ(dst_index_dims[1], 1,
+                        platform::errors::InvalidArgument(
+                            "The last dim of Dst_index should be 1 when it "
+                            "is 2D, but we get %d",
+                            dst_index_dims[1]));
+    } else {
+      PADDLE_ENFORCE_EQ(
+          dst_index_dims.size(), 1,
+          platform::errors::InvalidArgument("The Dst_index should be 1D, "
+                                            "when it is not 2D, but we get %d",
+                                            dst_index_dims.size()));
+    }
+
+    PADDLE_ENFORCE_EQ(
+        src_index_dims[0], dst_index_dims[0],
+        platform::errors::InvalidArgument(
+            "Src_index and Dst_index should have the same shape."));
+
+    auto dims = ctx->GetInputDim("X");
+    ctx->SetOutputDim("Out", dims);
+
+    if (ctx->Attrs().Get<std::string>("pool_type") == "MEAN") {
+      OP_INOUT_CHECK(ctx->HasOutput("Dst_count"), "Output", "Dst_count",
+                     "GraphSendRecv");
+      ctx->SetOutputDim("Dst_count", {dims[0]});
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
+  }
+};
+
+class GraphSendRecvGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    auto in_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
+  }
+};
+
+class GraphSendRecvOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "The input tensor with data type float32, float64, int32, "
+             "int64.");
+    AddInput("Src_index", "The source index tensor.");
+    AddInput("Dst_index", "The destination index tensor.");
+    AddOutput("Out", "Output tensor of graph_send_recv op.");
+    AddOutput("Dst_count",
+              "Count tensor of Dst_index, mainly for MEAN pool_type.")
+        .AsIntermediate();
+    AddAttr<std::string>("pool_type",
+                         "(string, default 'SUM')"
+                         "Define different pool types to receive the result "
+                         "tensors of Dst_index.")
+        .SetDefault("SUM")
+        .InEnum({"SUM", "MEAN", "MIN", "MAX"});
+    AddComment(R"DOC(
+Graph Learning Send_Recv combine operator.
+
+$Out = Recv(Send(X, Src_index), Dst_index, pool_type)$
+
+This operator is mainly used in the Graph Learning domain, and its main purpose
+is to reduce intermediate memory consumption in the process of message passing.
+Taking `x` as the input tensor, we first use `src_index` to gather the
+corresponding data, and then use `dst_index` to update the corresponding
+positions of the output tensor with different pooling types, like sum, mean,
+max, or min.
+
+)DOC");
+  }
+};
+
+template <typename T>
+class GraphSendRecvGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("graph_send_recv_grad");
+    op->SetInput("Src_index", this->Input("Src_index"));
+    op->SetInput("Dst_index", this->Input("Dst_index"));
+
+    if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MEAN") {
+      op->SetInput("Dst_count", this->Output("Dst_count"));
+    }
+
+    if (BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MIN" ||
+        BOOST_GET_CONST(std::string, this->GetAttr("pool_type")) == "MAX") {
+      op->SetInput("X", this->Input("X"));
+      op->SetInput("Out", this->Output("Out"));
+    }
+
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    op->SetAttrMap(this->Attrs());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using CPU = paddle::platform::CPUDeviceContext;
+
+REGISTER_OPERATOR(graph_send_recv, ops::GraphSendRecvOP,
+                  ops::GraphSendRecvOpMaker,
+                  ops::GraphSendRecvGradOpMaker<paddle::framework::OpDesc>,
+                  ops::GraphSendRecvGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(graph_send_recv_grad, ops::GraphSendRecvGradOp);
+REGISTER_OP_CPU_KERNEL(graph_send_recv, ops::GraphSendRecvOpKernel<CPU, float>,
+                       ops::GraphSendRecvOpKernel<CPU, double>,
+                       ops::GraphSendRecvOpKernel<CPU, int>,
+                       ops::GraphSendRecvOpKernel<CPU, int64_t>);
+
+REGISTER_OP_CPU_KERNEL(graph_send_recv_grad,
+                       ops::GraphSendRecvGradOpKernel<CPU, float>,
+                       ops::GraphSendRecvGradOpKernel<CPU, double>,
+                       ops::GraphSendRecvGradOpKernel<CPU, int>,
+                       ops::GraphSendRecvGradOpKernel<CPU, int64_t>);
diff --git a/paddle/fluid/operators/graph_send_recv_op.cu b/paddle/fluid/operators/graph_send_recv_op.cu
new file mode 100644
index 0000000000..d9f56ec4dc
--- /dev/null
+++ b/paddle/fluid/operators/graph_send_recv_op.cu
@@ -0,0 +1,419 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <thrust/device_vector.h>
+#include <thrust/fill.h>
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/graph_send_recv_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T, typename IndexT>
+struct GraphSendRecvSumCUDAFunctor {
+  DEVICE inline void operator()(const T* params, T* output, const IndexT& in_i,
+                                const IndexT& out_i) {
+    paddle::platform::CudaAtomicAdd(output + out_i, *(params + in_i));
+  }
+};
+
+template <typename T, typename IndexT>
+struct GraphSendRecvMaxCUDAFunctor {
+  DEVICE inline void operator()(const T* params, T* output, const IndexT& in_i,
+                                const IndexT& out_i) {
+    paddle::platform::CudaAtomicMax(output + out_i, *(params + in_i));
+  }
+};
+
+template <typename T, typename IndexT>
+struct GraphSendRecvMinCUDAFunctor {
+  DEVICE inline void operator()(const T* params, T* output, const IndexT& in_i,
+                                const IndexT& out_i) {
+    paddle::platform::CudaAtomicMin(output + out_i, *(params + in_i));
+  }
+};
+
+template <typename T, typename IndexT, typename Functor>
+__global__ void GraphSendRecvCUDAKernel(const T* params,
+                                        const IndexT* src_indices,
+                                        const IndexT* dst_indices, T* output,
+                                        size_t index_size, size_t slice_size,
+                                        Functor functor) {
+  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
+    int64_t indices_i = i / slice_size;
+    int64_t slice_i = i - indices_i * slice_size;
+    IndexT src_i = src_indices[indices_i];
+    IndexT dst_i = dst_indices[indices_i];
+    int64_t in_i = src_i * slice_size + slice_i;
+    int64_t out_i = dst_i * slice_size + slice_i;
+    functor(params, output, in_i, out_i);
+  }
+}
+
+// For max
+template <typename T>
+__global__ void InputResetMaxCUDAKernel(T* output, size_t input_size,
+                                        size_t slice_size) {
+  CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
+    if (*(output + i) == std::numeric_limits<T>::min()) {
+      *(output + i) = 0;
+    }
+  }
+}
+
+// For min
+template <typename T>
+__global__ void InputResetMinCUDAKernel(T* output, size_t input_size,
+                                        size_t slice_size) {
+  CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
+    if (*(output + i) == std::numeric_limits<T>::max()) {
+      *(output + i) = 0;
+    }
+  }
+}
+
+// Get dst_count
+template <typename T, typename IndexT>
+__global__ void ComputeCountCUDAKernel(int* count, const IndexT* dst_indices,
+                                       size_t index_size) {
+  CUDA_KERNEL_LOOP_TYPE(i, index_size, int64_t) {
+    IndexT dst_i = dst_indices[i];
+    paddle::platform::CudaAtomicAdd(count + dst_i, 1);
+  }
+}
+
+// For forward mean
+template <typename T>
+__global__ void ManipulateMeanCUDAKernel(T* output, int* count,
+                                         size_t input_size,
+                                         size_t slice_size) {
+  CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
+    int64_t c_index = i / slice_size;
+    if (*(count + c_index) > 1) {
+      *(output + i) = *(output + i) / *(count + c_index);
+    }
+  }
+}
+
+// For backward mean
+template <typename T, typename IndexT>
+__global__ void ManipulateMeanGradCUDAKernel(
+    const T* params, const IndexT* src_indices, const IndexT* dst_indices,
+    T* output, size_t index_size, size_t slice_size, const int* dst_count) {
+  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
+    int64_t indices_i = i / slice_size;
+    int64_t slice_i = i - indices_i * slice_size;
+    IndexT src_i = src_indices[indices_i];
+    IndexT dst_i = dst_indices[indices_i];
+    int64_t in_i = src_i * slice_size + slice_i;
+    int64_t out_i = dst_i * slice_size + slice_i;
+    paddle::platform::CudaAtomicAdd(output + out_i,
+                                    *(params + in_i) / dst_count[src_i]);
+  }
+}
+
+// For backward min and max
+template <typename T, typename IndexT>
+__global__ void ManipulateMinMaxGradCUDAKernel(
+    const T* params, const IndexT* src_indices, const IndexT* dst_indices,
+    T* output, size_t index_size, size_t slice_size, const T* ptr_input,
+    const T* ptr_output) {
+  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
+    int64_t indices_i = i / slice_size;
+    int64_t slice_i = i - indices_i * slice_size;
+    IndexT src_i = src_indices[indices_i];
+    IndexT dst_i = dst_indices[indices_i];
+    int64_t in_i = src_i * slice_size + slice_i;
+    int64_t out_i = dst_i * slice_size + slice_i;
+    paddle::platform::CudaAtomicAdd(
+        output + out_i,
+        *(params + in_i) * (*(ptr_input + out_i) == *(ptr_output + in_i)));
+  }
+}
+
+template <typename T, typename IndexT>
+void GraphSendRecvOpCUDAKernelLaunchHelper(
+    const framework::ExecutionContext& ctx, const Tensor& src_index,
+    const Tensor& dst_index) {
+  auto* X = ctx.Input<Tensor>("X");
+  auto* Y = ctx.Output<Tensor>("Out");
+  std::string pool_type = ctx.Attr<std::string>("pool_type");
+
+  const int& index_size = src_index.dims()[0];
+
+  T* p_output = Y->mutable_data<T>(ctx.GetPlace());
+  const auto& src_dims = X->dims();
+  int64_t memset_size = 1;
+  for (int i = 0; i < src_dims.size(); ++i) {
+    memset_size *= src_dims[i];
+  }
+  const size_t& memset_bytes = memset_size * sizeof(T);
+  if (pool_type == "SUM" || pool_type == "MEAN") {
+#ifdef PADDLE_WITH_HIP
+    hipMemset(p_output, 0, memset_bytes);
+#else
+    cudaMemset(p_output, 0, memset_bytes);
+#endif
+  } else if (pool_type == "MAX") {
+    thrust::device_ptr<T> p_output_ptr(p_output);
+    thrust::fill(thrust::device, p_output_ptr, p_output_ptr + memset_size,
+                 std::numeric_limits<T>::min());
+  } else if (pool_type == "MIN") {
+    thrust::device_ptr<T> p_output_ptr(p_output);
+    thrust::fill(thrust::device, p_output_ptr, p_output_ptr + memset_size,
+                 std::numeric_limits<T>::max());
+  }
+
+  if (index_size == 0) return;
+
+  int64_t slice_size = 1;
+  for (int i = 1; i < src_dims.size(); ++i) {
+    slice_size *= src_dims[i];
+  }
+  const T* p_src = X->data<T>();
+  const IndexT* s_index = src_index.data<IndexT>();
+  const IndexT* d_index = dst_index.data<IndexT>();
+
+#ifdef PADDLE_WITH_HIP
+  int block = 256;
+#else
+  int block = 1024;
+#endif
+  int64_t n = slice_size * index_size;
+  const auto& dev_ctx = ctx.cuda_device_context();
+  int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize().x;
+  int64_t grid_tmp = (n + block - 1) / block;
+  int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
+  int64_t input_size = src_dims[0];
+  if (pool_type == "SUM") {
+    GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
+    GraphSendRecvCUDAKernel<T, IndexT,
+                            GraphSendRecvSumCUDAFunctor<T, IndexT>><<<
+        grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                            ctx.device_context())
+                            .stream()>>>(p_src, s_index, d_index, p_output,
+                                         index_size, slice_size, functor);
+  } else if (pool_type == "MAX") {
+    GraphSendRecvMaxCUDAFunctor<T, IndexT> functor;
+    GraphSendRecvCUDAKernel<T, IndexT,
+                            GraphSendRecvMaxCUDAFunctor<T, IndexT>><<<
+        grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                            ctx.device_context())
+                            .stream()>>>(p_src, s_index, d_index, p_output,
+                                         index_size, slice_size, functor);
+
+    int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block;
+    int64_t grid_max =
+        grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx;
+    InputResetMaxCUDAKernel<
+        T><<<grid_max, block, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(
+                 ctx.device_context())
+                 .stream()>>>(p_output, input_size, slice_size);
+  } else if (pool_type == "MIN") {
+    GraphSendRecvMinCUDAFunctor<T, IndexT> functor;
+    GraphSendRecvCUDAKernel<T, IndexT,
+                            GraphSendRecvMinCUDAFunctor<T, IndexT>><<<
+        grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                            ctx.device_context())
+                            .stream()>>>(p_src, s_index, d_index, p_output,
+                                         index_size, slice_size, functor);
+
+    int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block;
+    int64_t grid_min =
+        grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx;
+    InputResetMinCUDAKernel<
+        T><<<grid_min, block, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(
+                 ctx.device_context())
+                 .stream()>>>(p_output, input_size, slice_size);
+  } else if (pool_type == "MEAN") {
+    GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
+    GraphSendRecvCUDAKernel<T, IndexT,
+                            GraphSendRecvSumCUDAFunctor<T, IndexT>><<<
+        grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                            ctx.device_context())
+                            .stream()>>>(p_src, s_index, d_index, p_output,
+                                         index_size, slice_size, functor);
+
+    auto* dst_count = ctx.Output<Tensor>("Dst_count");
+    int* p_dst_count = dst_count->mutable_data<int>(ctx.GetPlace());
+
+#ifdef PADDLE_WITH_HIP
+    hipMemset(p_dst_count, 0, input_size * sizeof(int));
+#else
+    cudaMemset(p_dst_count, 0, input_size * sizeof(int));
+#endif
+
+    int64_t grid_count = (index_size + block - 1) / block;
+    ComputeCountCUDAKernel<
+        T, IndexT><<<grid_count, block, 0,
+                     reinterpret_cast<const platform::CUDADeviceContext&>(
+                         ctx.device_context())
+                         .stream()>>>(p_dst_count, d_index, index_size);
+
+    int64_t grid_mean_tmp = (input_size * slice_size + block - 1) / block;
+    int64_t grid_mean =
+        grid_mean_tmp < max_grid_dimx ? grid_mean_tmp : max_grid_dimx;
+    ManipulateMeanCUDAKernel<
+        T><<<grid_mean, block, 0,
+             reinterpret_cast<const platform::CUDADeviceContext&>(
+                 ctx.device_context())
+                 .stream()>>>(p_output, p_dst_count, input_size, slice_size);
+  }
+}
+
+template <typename T, typename IndexT>
+void GraphSendRecvGradOpCUDAKernelLaunchHelper(
+    const framework::ExecutionContext& ctx, const Tensor& src_index,
+    const Tensor& dst_index) {
+  auto* X = ctx.Input<Tensor>(framework::GradVarName("Out"));
+  auto* Y = ctx.Output<Tensor>(framework::GradVarName("X"));
+  std::string pool_type = ctx.Attr<std::string>("pool_type");
+
+  const int& index_size = src_index.dims()[0];
+
+  T* p_output = Y->mutable_data<T>(ctx.GetPlace());
+  const auto& src_dims = X->dims();
+  int64_t memset_size = 1;
+  for (int i = 0; i < src_dims.size(); ++i) {
+    memset_size *= src_dims[i];
+  }
+  const size_t& memset_bytes = memset_size * sizeof(T);
+
+#ifdef PADDLE_WITH_HIP
+  hipMemset(p_output, 0, memset_bytes);
+#else
+  cudaMemset(p_output, 0, memset_bytes);
+#endif
+
+  if (index_size == 0) return;
+
+  int64_t slice_size = 1;
+  for (int i = 1; i < src_dims.size(); ++i) {
+    slice_size *= src_dims[i];
+  }
+  const T* p_src = X->data<T>();
+  const IndexT* s_index = src_index.data<IndexT>();
+  const IndexT* d_index = dst_index.data<IndexT>();
+
+#ifdef PADDLE_WITH_HIP
+  int block = 256;
+#else
+  int block = 1024;
+#endif
+  int64_t n = slice_size * index_size;
+  const auto& dev_ctx = ctx.cuda_device_context();
+  int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize().x;
+  int64_t grid_tmp = (n + block - 1) / block;
+  int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
+  int64_t input_size = src_dims[0];
+  if (pool_type == "SUM") {
+    GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
+    GraphSendRecvCUDAKernel<T, IndexT,
+                            GraphSendRecvSumCUDAFunctor<T, IndexT>><<<
+        grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                            ctx.device_context())
+                            .stream()>>>(p_src, s_index, d_index, p_output,
+                                         index_size, slice_size, functor);
+  } else if (pool_type == "MEAN") {
+    auto* dst_count = ctx.Input<Tensor>("Dst_count");
+    const int* s_count = dst_count->data<int>();
+    ManipulateMeanGradCUDAKernel<T, IndexT><<<
+        grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                            ctx.device_context())
+                            .stream()>>>(p_src, s_index, d_index, p_output,
+                                         index_size, slice_size, s_count);
+  } else if (pool_type == "MAX" || pool_type == "MIN") {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* output = ctx.Input<Tensor>("Out");
+    const T* ptr_input = input->data<T>();
+    const T* ptr_output = output->data<T>();
+    ManipulateMinMaxGradCUDAKernel<T, IndexT><<<
+        grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                            ctx.device_context())
+                            .stream()>>>(p_src, s_index, d_index, p_output,
+                                         index_size, slice_size, ptr_input,
+                                         ptr_output);
+  }
+}
+
+template <typename DeviceContext, typename T>
+class GraphSendRecvOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* src_index = ctx.Input<Tensor>("Src_index");
+    auto* dst_index = ctx.Input<Tensor>("Dst_index");
+    auto index_type = src_index->type();
+
+    if (index_type == framework::proto::VarType::INT32) {
+      GraphSendRecvOpCUDAKernelLaunchHelper<T, int>(ctx, *src_index,
+                                                    *dst_index);
+    } else if (index_type == framework::proto::VarType::INT64) {
+      GraphSendRecvOpCUDAKernelLaunchHelper<T, int64_t>(ctx, *src_index,
+                                                        *dst_index);
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Unsupported Src_index or Dst_index dtype, expected int, int64, "
+          "but got %s.",
+          index_type));
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class GraphSendRecvGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* src_index = ctx.Input<Tensor>("Dst_index");
+    auto* dst_index = ctx.Input<Tensor>("Src_index");
+    auto index_type = src_index->type();
+
+    if (index_type == framework::proto::VarType::INT32) {
+      GraphSendRecvGradOpCUDAKernelLaunchHelper<T, int>(ctx, *src_index,
+                                                        *dst_index);
+    } else if (index_type == framework::proto::VarType::INT64) {
+      GraphSendRecvGradOpCUDAKernelLaunchHelper<T, int64_t>(ctx, *src_index,
+                                                            *dst_index);
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Unsupported Src_index or Dst_index dtype, expected int, int64, "
+          "but got %s.",
+          index_type));
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+using CUDA = paddle::platform::CUDADeviceContext;
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(graph_send_recv,
+                        ops::GraphSendRecvOpCUDAKernel<CUDA, float>,
+                        ops::GraphSendRecvOpCUDAKernel<CUDA, double>,
+                        ops::GraphSendRecvOpCUDAKernel<CUDA, int>,
+                        ops::GraphSendRecvOpCUDAKernel<CUDA, int64_t>);
+
+REGISTER_OP_CUDA_KERNEL(graph_send_recv_grad,
+                        ops::GraphSendRecvGradOpCUDAKernel<CUDA, float>,
+                        ops::GraphSendRecvGradOpCUDAKernel<CUDA, double>,
+                        ops::GraphSendRecvGradOpCUDAKernel<CUDA, int>,
+                        ops::GraphSendRecvGradOpCUDAKernel<CUDA, int64_t>);
diff --git a/paddle/fluid/operators/graph_send_recv_op.h b/paddle/fluid/operators/graph_send_recv_op.h
new file mode 100644
index 0000000000..1c7ea74be2
--- /dev/null
+++ b/paddle/fluid/operators/graph_send_recv_op.h
@@ -0,0 +1,291 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+struct GraphSendRecvSumFunctor {
+  void operator()(const bool& first_flag, const Tensor& src_slice,
+                  Tensor* dst_slice) {
+    auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
+    auto eigen_dst = framework::EigenVector<T>::Flatten(*dst_slice);
+    eigen_dst += eigen_src;
+  }
+};
+
+template <typename T>
+struct GraphSendRecvMinFunctor {
+  void operator()(const bool& first_flag, const Tensor& src_slice,
+                  Tensor* dst_slice) {
+    auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
+    auto eigen_dst = framework::EigenVector<T>::Flatten(*dst_slice);
+    if (first_flag) {
+      eigen_dst += eigen_src;
+    } else {
+      eigen_dst = eigen_dst.cwiseMin(eigen_src);
+    }
+  }
+};
+
+template <typename T>
+struct GraphSendRecvMaxFunctor {
+  void operator()(const int& first_flag, const Tensor& src_slice,
+                  Tensor* dst_slice) {
+    auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
+    auto eigen_dst = framework::EigenVector<T>::Flatten(*dst_slice);
+    if (first_flag) {
+      eigen_dst += eigen_src;
+    } else {
+      eigen_dst = eigen_dst.cwiseMax(eigen_src);
+    }
+  }
+};
+
+template <typename T, typename IndexT, typename Functor>
+void elementwise_inner_operation(const Tensor& src, Tensor* dst,
+                                 const IndexT& src_index,
+                                 const IndexT& dst_index,
+                                 const bool& first_flag, Functor functor) {
+  auto src_slice = src.Slice(src_index, src_index + 1);
+  auto dst_slice = dst->Slice(dst_index, dst_index + 1);
+
+  functor(first_flag, src_slice, &dst_slice);
+}
+
+template <typename T, typename IndexT, typename Functor>
+void graph_send_recv_cpu_for_loop(const int& input_size, const int& index_size,
+                                  const IndexT* s_index, const IndexT* d_index,
+                                  const Tensor& src, Tensor* dst,
+                                  const std::string& pool_type,
+                                  int* dst_count = nullptr) {
+  Functor functor;
+  if (pool_type == "SUM") {
+    for (int i = 0; i < index_size; ++i) {
+      const IndexT& src_idx = s_index[i];
+      const IndexT& dst_idx = d_index[i];
+      elementwise_inner_operation<T, IndexT, Functor>(src, dst, src_idx,
+                                                      dst_idx, false, functor);
+    }
+  } else if (pool_type == "MEAN") {
+    for (int i = 0; i < index_size; ++i) {
+      const IndexT& src_idx = s_index[i];
+      const IndexT& dst_idx = d_index[i];
+      elementwise_inner_operation<T, IndexT, Functor>(src, dst, src_idx,
+                                                      dst_idx, false, functor);
+    }
+    for (int i = 0; i < index_size; ++i) {
+      IndexT dst_idx = d_index[i];
+      *(dst_count + dst_idx) += 1;
+    }
+    for (int i = 0; i < input_size; ++i) {
+      if (*(dst_count + i) == 0) continue;
+      auto dst_slice = dst->Slice(i, i + 1);
+      auto eigen_dst = framework::EigenVector<T>::Flatten(dst_slice);
+      eigen_dst = eigen_dst / static_cast<T>(*(dst_count + i));
+    }
+  } else if (pool_type == "MIN" || pool_type == "MAX") {
+    std::set<IndexT> existed_dst;
+    for (int i = 0; i < index_size; ++i) {
+      const IndexT& src_idx = s_index[i];
+      const IndexT& dst_idx = d_index[i];
+      bool in_set = existed_dst.find(dst_idx) != existed_dst.end();
+      if (!in_set) {
+        elementwise_inner_operation<T, IndexT, Functor>(src, dst, src_idx,
+                                                        dst_idx, true,
+                                                        functor);
+        existed_dst.emplace(dst_idx);
+      } else {
+        elementwise_inner_operation<T, IndexT, Functor>(
+            src, dst, src_idx, dst_idx, false, functor);
+      }
+    }
+  }
+}
+
+template <typename T, typename IndexT, typename Functor>
+void graph_send_recv_cpu_for_loop_grad(
+    const int& input_size, const int& index_size, const IndexT* s_index,
+    const IndexT* d_index, const Tensor& src, Tensor* dst,
+    const std::string& pool_type, const int* dst_count = nullptr,
+    const Tensor* input = nullptr, const Tensor* output = nullptr) {
+  if (pool_type == "SUM") {
+    Functor functor;
+    for (int i = 0; i < index_size; ++i) {
+      const IndexT& src_idx = s_index[i];
+      const IndexT& dst_idx = d_index[i];
+      elementwise_inner_operation<T, IndexT, Functor>(src, dst, src_idx,
+                                                      dst_idx, false, functor);
+    }
+  } else if (pool_type == "MEAN") {
+    for (int i = 0; i < index_size; ++i) {
+      const IndexT& src_idx = s_index[i];
+      const IndexT& dst_idx = d_index[i];
+      auto src_slice = src.Slice(src_idx, src_idx + 1);
+      auto dst_slice = dst->Slice(dst_idx, dst_idx + 1);
+      auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
+      auto eigen_dst = framework::EigenVector<T>::Flatten(dst_slice);
+      eigen_dst += (eigen_src / static_cast<T>(dst_count[src_idx]));
+    }
+  } else if (pool_type == "MIN" || pool_type == "MAX") {
+    for (int i = 0; i < index_size; ++i) {
+      const IndexT& forward_src_idx = d_index[i];
+      const IndexT& forward_dst_idx = s_index[i];
+      auto input_slice = input->Slice(forward_src_idx, forward_src_idx + 1);
+      auto output_slice = output->Slice(forward_dst_idx, forward_dst_idx + 1);
+      auto eigen_input = framework::EigenVector<T>::Flatten(input_slice);
+      auto eigen_output = framework::EigenVector<T>::Flatten(output_slice);
+
+      auto src_slice = src.Slice(forward_dst_idx, forward_dst_idx + 1);
+      auto dst_slice = dst->Slice(forward_src_idx, forward_src_idx + 1);
+      auto eigen_src = framework::EigenVector<T>::Flatten(src_slice);
+      auto eigen_dst = framework::EigenVector<T>::Flatten(dst_slice);
+      eigen_dst += eigen_src * (eigen_output == eigen_input);
+    }
+  }
+}
+
+template <typename T, typename IndexT>
+void GraphSendRecvOpKernelLaunchHelper(const framework::ExecutionContext& ctx,
+                                       const Tensor& src_index) {
+  auto* X = ctx.Input<Tensor>("X");
+  auto* dst_index = ctx.Input<Tensor>("Dst_index");
+  auto* Y = ctx.Output<Tensor>("Out");
+
+  const int& index_size = src_index.dims()[0];
+
+  T* p_output = Y->mutable_data<T>(ctx.GetPlace());
+  const auto& src_dims = X->dims();
+  int64_t memset_size = 1;
+  for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
+  const size_t& memset_bytes = memset_size * sizeof(T);
+  memset(p_output, 0, memset_bytes);
+
+  if (index_size == 0) return;
+
+  const IndexT* s_index = src_index.data<IndexT>();
+  const IndexT* d_index = dst_index->data<IndexT>();
+  const std::string& pool_type = ctx.Attr<std::string>("pool_type");
+  if (pool_type == "SUM") {
+    graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvSumFunctor<T>>(
+        src_dims[0], index_size, s_index, d_index, *X, Y, pool_type);
+  } else if (pool_type == "MIN") {
+    graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvMinFunctor<T>>(
+        src_dims[0], index_size, s_index, d_index, *X, Y, pool_type);
+  } else if (pool_type == "MAX") {
+    graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvMaxFunctor<T>>(
+        src_dims[0], index_size, s_index, d_index, *X, Y, pool_type);
+  } else if (pool_type == "MEAN") {
+    auto* dst_count = ctx.Output<Tensor>("Dst_count");
+    int* p_dst_count = dst_count->mutable_data<int>(ctx.GetPlace());
+    memset(p_dst_count, 0, src_dims[0] * sizeof(int));
+    graph_send_recv_cpu_for_loop<T, IndexT, GraphSendRecvSumFunctor<T>>(
+        src_dims[0], index_size, s_index, d_index, *X, Y, pool_type,
+        p_dst_count);
+  }
+}
+
+template <typename T, typename IndexT>
+void GraphSendRecvGradOpKernelLaunchHelper(
+    const framework::ExecutionContext& ctx, const Tensor& src_index) {
+  auto* X = ctx.Input<Tensor>(framework::GradVarName("Out"));
+  auto* dst_index = ctx.Input<Tensor>("Src_index");
ctx.Input("Src_index"); + auto* Y = ctx.Output(framework::GradVarName("X")); + + const int& index_size = src_index.dims()[0]; + + T* p_output = Y->mutable_data(ctx.GetPlace()); + const auto& src_dims = X->dims(); + int64_t memset_size = 1; + for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i]; + const size_t& memset_bytes = memset_size * sizeof(T); + memset(p_output, 0, memset_bytes); + + if (index_size == 0) return; + + const IndexT* s_index = src_index.data(); + const IndexT* d_index = dst_index->data(); + + const std::string& pool_type = ctx.Attr("pool_type"); + if (pool_type == "SUM") { + graph_send_recv_cpu_for_loop_grad>( + src_dims[0], index_size, s_index, d_index, *X, Y, pool_type); + } else if (pool_type == "MEAN") { + auto* dst_count = ctx.Input("Dst_count"); + const int* s_count = dst_count->data(); + // Functor not used here. + graph_send_recv_cpu_for_loop_grad>( + src_dims[0], index_size, s_index, d_index, *X, Y, pool_type, s_count); + } else if (pool_type == "MIN" || pool_type == "MAX") { + const auto* input = ctx.Input("X"); + const auto* output = ctx.Input("Out"); + // Functor not used here. + graph_send_recv_cpu_for_loop_grad>( + src_dims[0], index_size, s_index, d_index, *X, Y, pool_type, nullptr, + input, output); + } +} + +template +class GraphSendRecvOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* src_index = ctx.Input("Src_index"); + auto index_type = src_index->type(); + + if (index_type == framework::proto::VarType::INT32) { + GraphSendRecvOpKernelLaunchHelper(ctx, *src_index); + } else if (index_type == framework::proto::VarType::INT64) { + GraphSendRecvOpKernelLaunchHelper(ctx, + *src_index); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported Src_index or Dst_index type, Expected int, int64, but " + "got %s.", + index_type)); + } + } +}; + +template +class GraphSendRecvGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* src_index = ctx.Input("Dst_index"); + auto index_type = src_index->type(); + + if (index_type == framework::proto::VarType::INT32) { + GraphSendRecvGradOpKernelLaunchHelper(ctx, + *src_index); + } else if (index_type == framework::proto::VarType::INT64) { + GraphSendRecvGradOpKernelLaunchHelper( + ctx, *src_index); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Unsupported Src_index or Dst_index type, Expected int, int64, but " + "got %s.", + index_type)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/cuda_primitives.h b/paddle/fluid/platform/cuda_primitives.h index 4708a99e8f..d443e78ed8 100644 --- a/paddle/fluid/platform/cuda_primitives.h +++ b/paddle/fluid/platform/cuda_primitives.h @@ -178,9 +178,17 @@ CUDA_ATOMIC_WRAPPER(Max, int64_t) { // Here, we check long long int must be int64_t. 
   static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                 "long long should be int64");
-  return CudaAtomicMax(
-      reinterpret_cast<unsigned long long int *>(address),  // NOLINT
-      static_cast<unsigned long long int>(val));            // NOLINT
+  long long int res = *address;  // NOLINT
+  while (val > res) {
+    long long int old = res;  // NOLINT
+    res = (long long int)atomicCAS((unsigned long long int *)address,  // NOLINT
+                                   (unsigned long long int)old,        // NOLINT
+                                   (unsigned long long int)val);       // NOLINT
+    if (res == old) {
+      break;
+    }
+  }
+  return res;
 }
 
 CUDA_ATOMIC_WRAPPER(Max, float) {
@@ -254,9 +262,17 @@ CUDA_ATOMIC_WRAPPER(Min, int64_t) {
   // Here, we check long long int must be int64_t.
   static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
                 "long long should be int64");
-  return CudaAtomicMin(
-      reinterpret_cast<unsigned long long int *>(address),  // NOLINT
-      static_cast<unsigned long long int>(val));            // NOLINT
+  long long int res = *address;  // NOLINT
+  while (val < res) {
+    long long int old = res;  // NOLINT
+    res = (long long int)atomicCAS((unsigned long long int *)address,  // NOLINT
+                                   (unsigned long long int)old,        // NOLINT
+                                   (unsigned long long int)val);       // NOLINT
+    if (res == old) {
+      break;
+    }
+  }
+  return res;
 }
 
 CUDA_ATOMIC_WRAPPER(Min, float) {
diff --git a/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py
new file mode 100644
index 0000000000..68b354775d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_graph_send_recv_op.py
@@ -0,0 +1,309 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+
+from op_test import OpTest
+
+
+class TestGraphSendRecvMaxOp(OpTest):
+    def setUp(self):
+        paddle.enable_static()
+        self.op_type = "graph_send_recv"
+        x = np.random.random((10, 20)).astype("float64")
+        index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
+        src_index = index[:, 0]
+        dst_index = index[:, 1]
+
+        self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index}
+
+        self.attrs = {'pool_type': 'MAX'}
+
+        out, self.gradient = compute_graph_send_recv_for_min_max(self.inputs,
+                                                                 self.attrs)
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', user_defined_grads=[self.gradient])
+
+
+class TestGraphSendRecvMinOp(OpTest):
+    def setUp(self):
+        paddle.enable_static()
+        self.op_type = "graph_send_recv"
+        x = np.random.random((10, 20)).astype("float64")
+        index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
+        src_index = index[:, 0]
+        dst_index = index[:, 1]
+
+        self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index}
+
+        self.attrs = {'pool_type': 'MIN'}
+
+        out, self.gradient = compute_graph_send_recv_for_min_max(self.inputs,
+                                                                 self.attrs)
+
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', user_defined_grads=[self.gradient])
+
+
+class TestGraphSendRecvSumOp(OpTest):
+    def setUp(self):
+        paddle.enable_static()
+        self.op_type = "graph_send_recv"
+        x = np.random.random((10, 20)).astype("float64")
+        index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
+        src_index = index[:, 0]
+        dst_index = index[:, 1]
+
+        self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index}
+
+        self.attrs = {'pool_type': 'SUM'}
+
+        out, _ = compute_graph_send_recv_for_sum_mean(self.inputs, self.attrs)
+
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestGraphSendRecvMeanOp(OpTest):
+    def setUp(self):
+        paddle.enable_static()
+        self.op_type = "graph_send_recv"
+        x = np.random.random((10, 20)).astype("float64")
+        index = np.random.randint(0, 10, (15, 2)).astype(np.int64)
+        src_index = index[:, 0]
+        dst_index = index[:, 1]
+
+        self.inputs = {'X': x, 'Src_index': src_index, 'Dst_index': dst_index}
+
+        self.attrs = {'pool_type': 'MEAN'}
+
+        out, dst_count = compute_graph_send_recv_for_sum_mean(self.inputs,
+                                                              self.attrs)
+
+        self.outputs = {'Out': out, 'Dst_count': dst_count}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+def compute_graph_send_recv_for_sum_mean(inputs, attributes):
+    x = inputs['X']
+    src_index = inputs['Src_index']
+    dst_index = inputs['Dst_index']
+
+    pool_type = attributes['pool_type']
+
+    gather_x = x[src_index]
+    target_shape = list(x.shape)
+    results = np.zeros(target_shape, dtype=x.dtype)
+    if pool_type == 'SUM':
+        for index, s_id in enumerate(dst_index):
+            results[s_id, :] += gather_x[index, :]
+    elif pool_type == 'MEAN':
+        count = np.zeros(target_shape[0], dtype=np.int32)
+        for index, s_id in enumerate(dst_index):
+            results[s_id, :] += gather_x[index, :]
+            count[s_id] += 1
+        results = results / count.reshape([-1, 1])
+        results[np.isnan(results)] = 0
+    else:
+        raise ValueError("Invalid pool_type, only SUM, MEAN supported!")
+
+    count = np.zeros(target_shape[0], dtype=np.int32)
+    for index, s_id in enumerate(dst_index):
+        count[s_id] += 1
+
+    return results, count
+
+
+def compute_graph_send_recv_for_min_max(inputs, attributes):
+    x = inputs['X']
+    src_index = inputs['Src_index']
+    dst_index = inputs['Dst_index']
+
+    pool_type = attributes['pool_type']
+
+    gather_x = x[src_index]
+    target_shape = list(x.shape)
+    results = np.zeros(target_shape, dtype=x.dtype)
+    gradient = np.zeros_like(x)
+
+    # Calculate forward output
+    if pool_type == "MAX":
+        first_set = set()
+        for index, s_id in enumerate(dst_index):
+            if s_id not in first_set:
+                results[s_id, :] += gather_x[index, :]
+                first_set.add(s_id)
+            else:
+                results[s_id, :] = np.maximum(results[s_id, :],
+                                              gather_x[index, :])
+    elif pool_type == "MIN":
+        first_set = set()
+        for index, s_id in enumerate(dst_index):
+            if s_id not in first_set:
+                results[s_id, :] += gather_x[index, :]
+                first_set.add(s_id)
+            else:
+                results[s_id, :] = np.minimum(results[s_id, :],
+                                              gather_x[index, :])
+    else:
+        raise ValueError("Invalid pool_type, only MAX, MIN supported!")
+
+    # Calculate backward gradient
+    index_size = len(src_index)
+    for i in range(index_size):
+        forward_src_idx = src_index[i]
+        forward_dst_idx = dst_index[i]
+        gradient[forward_src_idx] += 1 * (
+            x[forward_src_idx] == results[forward_dst_idx])
+
+    return results, gradient / results.size
+
+
+class API_GraphSendRecvOpTest(unittest.TestCase):
+    def test_static(self):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program()):
+            x = paddle.static.data(name="x", shape=[3, 3], dtype="float32")
+            src_index = paddle.static.data(name="src", shape=[4], dtype="int32")
+            dst_index = paddle.static.data(name="dst", shape=[4], dtype="int32")
+
+            res_sum = paddle.incubate.graph_send_recv(x, src_index, dst_index,
+                                                      "sum")
+            res_mean = paddle.incubate.graph_send_recv(x, src_index, dst_index,
+                                                       "mean")
+            res_max = paddle.incubate.graph_send_recv(x, src_index, dst_index,
+                                                      "max")
+            res_min = paddle.incubate.graph_send_recv(x, src_index, dst_index,
+                                                      "min")
+
+            exe = paddle.static.Executor(paddle.CPUPlace())
+            data1 = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype='float32')
+            data2 = np.array([0, 1, 2, 0], dtype="int32")
+            data3 = np.array([1, 2, 1, 0], dtype="int32")
+
+            np_sum = np.array(
+                [[0, 2, 3], [2, 8, 10], [1, 4, 5]], dtype="float32")
+            np_mean = np.array(
+                [[0, 2, 3], [1, 4, 5], [1, 4, 5]], dtype="float32")
+            np_max = np.array(
+                [[0, 2, 3], [2, 6, 7], [1, 4, 5]], dtype="float32")
+            np_min = np.array(
+                [[0, 2, 3], [0, 2, 3], [1, 4, 5]], dtype="float32")
+
+            ret = exe.run(feed={'x': data1,
+                                'src': data2,
+                                'dst': data3},
+                          fetch_list=[res_sum, res_mean, res_max, res_min])
+
+            for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret):
+                self.assertTrue(
+                    np.allclose(
+                        np_res, ret_res, atol=1e-6),
+                    "two value is\
+                    {}\n{}, check diff!".format(np_res, ret_res))
+
+    def test_dygraph(self):
+        device = paddle.CPUPlace()
+        with paddle.fluid.dygraph.guard(device):
+            x = paddle.to_tensor(
+                np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]]), dtype="float32")
+            src_index = paddle.to_tensor(np.array([0, 1, 2, 0]), dtype="int32")
+            dst_index = paddle.to_tensor(np.array([1, 2, 1, 0]), dtype="int32")
+            res_sum = paddle.incubate.graph_send_recv(x, src_index, dst_index,
+                                                      "sum")
+            res_mean = paddle.incubate.graph_send_recv(x, src_index, dst_index,
+                                                       "mean")
+            res_max = paddle.incubate.graph_send_recv(x, src_index, dst_index,
+                                                      "max")
+            res_min = paddle.incubate.graph_send_recv(x, src_index, dst_index,
+                                                      "min")
+
+            np_sum = np.array(
+                [[0, 2, 3], [2, 8, 10], [1, 4, 5]], dtype="float32")
10], [1, 4, 5]], dtype="float32") + np_mean = np.array( + [[0, 2, 3], [1, 4, 5], [1, 4, 5]], dtype="float32") + np_max = np.array( + [[0, 2, 3], [2, 6, 7], [1, 4, 5]], dtype="float32") + np_min = np.array( + [[0, 2, 3], [0, 2, 3], [1, 4, 5]], dtype="float32") + + ret = [res_sum, res_mean, res_max, res_min] + + for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): + self.assertTrue( + np.allclose( + np_res, ret_res, atol=1e-6), + "two value is\ + {}\n{}, check diff!".format(np_res, ret_res)) + + def test_int32_input(self): + device = paddle.CPUPlace() + with paddle.fluid.dygraph.guard(device): + x = paddle.to_tensor( + np.array([[0, 2, 3], [1, 4, 5], [2, 6, 6]]), dtype="int32") + src_index = paddle.to_tensor( + np.array([0, 1, 2, 0, 1]), dtype="int32") + dst_index = paddle.to_tensor( + np.array([1, 2, 1, 0, 1]), dtype="int32") + res_sum = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "sum") + res_mean = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "mean") + res_max = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "max") + res_min = paddle.incubate.graph_send_recv(x, src_index, dst_index, + "min") + + np_sum = np.array( + [[0, 2, 3], [3, 12, 14], [1, 4, 5]], dtype="int32") + np_mean = np.array([[0, 2, 3], [1, 4, 4], [1, 4, 5]], dtype="int32") + np_max = np.array([[0, 2, 3], [2, 6, 6], [1, 4, 5]], dtype="int32") + np_min = np.array([[0, 2, 3], [0, 2, 3], [1, 4, 5]], dtype="int32") + + ret = [res_sum, res_mean, res_max, res_min] + + for np_res, ret_res in zip([np_sum, np_mean, np_max, np_min], ret): + self.assertTrue( + np.allclose( + np_res, ret_res, atol=1e-6), + "two value is\ + {}\n{}, check diff!".format(np_res, ret_res)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/incubate/__init__.py b/python/paddle/incubate/__init__.py index f44e38347e..e5215cf506 100644 --- a/python/paddle/incubate/__init__.py +++ b/python/paddle/incubate/__init__.py @@ -18,6 +18,7 @@ from .checkpoint import auto_checkpoint # noqa: F401 from ..fluid.layer_helper import LayerHelper # noqa: F401 from .operators import softmax_mask_fuse_upper_triangle # noqa: F401 from .operators import softmax_mask_fuse # noqa: F401 +from .operators import graph_send_recv from .tensor import segment_sum from .tensor import segment_mean from .tensor import segment_max @@ -30,6 +31,7 @@ __all__ = [ 'ModelAverage', 'softmax_mask_fuse_upper_triangle', 'softmax_mask_fuse', + 'graph_send_recv', 'segment_sum', 'segment_mean', 'segment_max', diff --git a/python/paddle/incubate/operators/__init__.py b/python/paddle/incubate/operators/__init__.py index 9a6710d095..ecf73fb393 100644 --- a/python/paddle/incubate/operators/__init__.py +++ b/python/paddle/incubate/operators/__init__.py @@ -15,3 +15,4 @@ from .softmax_mask_fuse_upper_triangle import softmax_mask_fuse_upper_triangle # noqa: F401 from .softmax_mask_fuse import softmax_mask_fuse # noqa: F401 from .resnet_unit import ResNetUnit #noqa: F401 +from .graph_send_recv import graph_send_recv #noqa: F401 diff --git a/python/paddle/incubate/operators/graph_send_recv.py b/python/paddle/incubate/operators/graph_send_recv.py new file mode 100644 index 0000000000..9b8f542658 --- /dev/null +++ b/python/paddle/incubate/operators/graph_send_recv.py @@ -0,0 +1,108 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.framework import in_dygraph_mode
+from paddle.fluid.data_feeder import check_variable_and_dtype
+from paddle.fluid import core
+
+
+def graph_send_recv(x, src_index, dst_index, pool_type="sum", name=None):
+    r"""
+
+    Graph Learning Send_Recv combine operator.
+
+    This operator is mainly used in the Graph Learning domain, and its main purpose is to reduce
+    intermediate memory consumption in the process of message passing. Taking `x` as the input tensor,
+    we first use `src_index` to gather the corresponding data, and then use `dst_index` to update the
+    corresponding positions of the output tensor with different pooling types, like sum, mean, max, or min.
+
+    .. code-block:: text
+
+           Given:
+
+           X = [[0, 2, 3],
+                [1, 4, 5],
+                [2, 6, 7]]
+
+           src_index = [0, 1, 2, 0]
+
+           dst_index = [1, 2, 1, 0]
+
+           pool_type = "sum"
+
+           Then:
+
+           Out = [[0, 2, 3],
+                  [2, 8, 10],
+                  [1, 4, 5]]
+
+    Args:
+        x (Tensor): The input tensor, and the available data type is float32, float64, int32, int64.
+        src_index (Tensor): A 1-D tensor, and the available data type is int32, int64.
+        dst_index (Tensor): A 1-D tensor, and should have the same shape as `src_index`.
+                            The available data type is int32, int64.
+        pool_type (str): The pooling type of graph_send_recv, including `sum`, `mean`, `max`, `min`.
+                         Default value is `sum`.
+        name (str, optional): Name for the operation (optional, default is None).
+                              For more information, please refer to :ref:`api_guide_Name`.
+
+    Returns:
+        out (Tensor): The output tensor, which should have the same shape and same dtype as the input tensor `x`.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+
+            x = paddle.to_tensor([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float32")
+            indexes = paddle.to_tensor([[0, 1], [1, 2], [2, 1], [0, 0]], dtype="int32")
+            src_index = indexes[:, 0]
+            dst_index = indexes[:, 1]
+            out = paddle.incubate.graph_send_recv(x, src_index, dst_index, pool_type="sum")
+            # Outputs: [[0., 2., 3.], [2., 8., 10.], [1., 4., 5.]]
+
+    """
+
+    if pool_type not in ["sum", "mean", "max", "min"]:
+        raise ValueError(
+            "pool_type should be `sum`, `mean`, `max` or `min`, but received %s"
+            % pool_type)
+
+    if in_dygraph_mode():
+        out, tmp = core.ops.graph_send_recv(x, src_index, dst_index,
+                                            'pool_type', pool_type.upper())
+        return out
+
+    check_variable_and_dtype(x, "X", ("float32", "float64", "int32", "int64"),
+                             "graph_send_recv")
+    check_variable_and_dtype(src_index, "Src_index", ("int32", "int64"),
+                             "graph_send_recv")
+    check_variable_and_dtype(dst_index, "Dst_index", ("int32", "int64"),
+                             "graph_send_recv")
+
+    helper = LayerHelper("graph_send_recv", **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    dst_count = helper.create_variable_for_type_inference(
+        dtype="int32", stop_gradient=True)
+    helper.append_op(
+        type="graph_send_recv",
+        inputs={"X": x,
+                "Src_index": src_index,
+                "Dst_index": dst_index},
+        outputs={"Out": out,
+                 "Dst_count": dst_count},
+        attrs={"pool_type": pool_type.upper()})
+    return out
--
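
Reviewer note: the semantics of the four pool_type modes can be sanity-checked against a few lines of
NumPy. The sketch below is illustrative only and is not part of the patch; the helper name
send_recv_ref is hypothetical. It simply mirrors the gather-then-scatter reduction described in the
operator's DOC comment, including the first-write handling that the op's MIN/MAX paths implement with
existed_dst on CPU and InputResetMax/MinCUDAKernel on GPU.

    import numpy as np

    def send_recv_ref(x, src_index, dst_index, pool_type="sum"):
        # Gather rows of x by src_index, then reduce each gathered row
        # into out[dst_index[i]] with the requested reduction.
        out = np.zeros(x.shape, dtype="float64")
        count = np.zeros(x.shape[0], dtype=np.int64)
        seen = set()  # destinations already written, for max/min
        for s, d in zip(src_index, dst_index):
            if pool_type in ("sum", "mean"):
                out[d] += x[s]
                count[d] += 1
            elif d not in seen:  # first write for max/min
                out[d] = x[s]
                seen.add(d)
            elif pool_type == "max":
                out[d] = np.maximum(out[d], x[s])
            else:  # "min"
                out[d] = np.minimum(out[d], x[s])
        if pool_type == "mean":
            nz = count > 0
            out[nz] /= count[nz][:, None]
        return out

    # Matches the docstring example:
    x = np.array([[0, 2, 3], [1, 4, 5], [2, 6, 7]], dtype="float64")
    src_index = np.array([0, 1, 2, 0])
    dst_index = np.array([1, 2, 1, 0])
    print(send_recv_ref(x, src_index, dst_index, "sum"))
    # [[ 0.  2.  3.]
    #  [ 2.  8. 10.]
    #  [ 1.  4.  5.]]

As in the operator itself, destination rows that never appear in dst_index keep their zero
initialization, which is also why the mean path only divides where count is nonzero.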