From 6bd2d2b1cb5fa2e350adb4c9b291b48054257be5 Mon Sep 17 00:00:00 2001 From: wawltor Date: Tue, 8 Mar 2022 10:29:59 +0800 Subject: [PATCH] [Phi] move the graph_send_recv op to the phi (#40092) * [Phi] transfer old kernel to pten kernel for the graph_send_recv op * update the code for the definition of graph_send_recv * fix the gradient problem for graph_send_recv * fix the compile problem * update the enforce message for the windows * update the code for the compiler * update compiler problem for the windows * update the code for windows * fix some format problem --- paddle/fluid/operators/graph_send_recv_op.cc | 12 +- paddle/fluid/operators/graph_send_recv_op.cu | 419 ------------------ paddle/fluid/operators/graph_send_recv_op.h | 291 ------------ .../phi/kernels/cpu/graph_send_recv_funcs.h | 80 ++++ .../cpu/graph_send_recv_grad_kernel.cc | 172 +++++++ .../phi/kernels/cpu/graph_send_recv_kernel.cc | 153 +++++++ .../phi/kernels/gpu/graph_send_recv_funcs.h | 171 +++++++ .../gpu/graph_send_recv_grad_kernel.cu | 148 +++++++ .../phi/kernels/gpu/graph_send_recv_kernel.cu | 179 ++++++++ .../phi/kernels/graph_send_recv_grad_kernel.h | 33 ++ paddle/phi/kernels/graph_send_recv_kernel.h | 31 ++ paddle/phi/ops/compat/graph_send_recv_sig.cc | 31 ++ 12 files changed, 999 insertions(+), 721 deletions(-) delete mode 100644 paddle/fluid/operators/graph_send_recv_op.cu delete mode 100644 paddle/fluid/operators/graph_send_recv_op.h create mode 100644 paddle/phi/kernels/cpu/graph_send_recv_funcs.h create mode 100644 paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/graph_send_recv_kernel.cc create mode 100644 paddle/phi/kernels/gpu/graph_send_recv_funcs.h create mode 100644 paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/graph_send_recv_kernel.cu create mode 100644 paddle/phi/kernels/graph_send_recv_grad_kernel.h create mode 100644 paddle/phi/kernels/graph_send_recv_kernel.h create mode 100644 
paddle/phi/ops/compat/graph_send_recv_sig.cc diff --git a/paddle/fluid/operators/graph_send_recv_op.cc b/paddle/fluid/operators/graph_send_recv_op.cc index 6af8388d9e..b759345eda 100644 --- a/paddle/fluid/operators/graph_send_recv_op.cc +++ b/paddle/fluid/operators/graph_send_recv_op.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/graph_send_recv_op.h" +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { @@ -171,13 +171,3 @@ REGISTER_OPERATOR(graph_send_recv, ops::GraphSendRecvOP, ops::GraphSendRecvGradOpMaker, ops::GraphSendRecvGradOpMaker); REGISTER_OPERATOR(graph_send_recv_grad, ops::GraphSendRecvGradOp); -REGISTER_OP_CPU_KERNEL(graph_send_recv, ops::GraphSendRecvOpKernel, - ops::GraphSendRecvOpKernel, - ops::GraphSendRecvOpKernel, - ops::GraphSendRecvOpKernel); - -REGISTER_OP_CPU_KERNEL(graph_send_recv_grad, - ops::GraphSendRecvGradOpKernel, - ops::GraphSendRecvGradOpKernel, - ops::GraphSendRecvGradOpKernel, - ops::GraphSendRecvGradOpKernel); diff --git a/paddle/fluid/operators/graph_send_recv_op.cu b/paddle/fluid/operators/graph_send_recv_op.cu deleted file mode 100644 index f43d31814a..0000000000 --- a/paddle/fluid/operators/graph_send_recv_op.cu +++ /dev/null @@ -1,419 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/operators/graph_send_recv_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" -#include "paddle/fluid/platform/place.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -struct GraphSendRecvSumCUDAFunctor { - DEVICE inline void operator()(const T* params, T* output, const IndexT& in_i, - const IndexT& out_i) { - paddle::platform::CudaAtomicAdd(output + out_i, *(params + in_i)); - } -}; - -template -struct GraphSendRecvMaxCUDAFunctor { - DEVICE inline void operator()(const T* params, T* output, const IndexT& in_i, - const IndexT& out_i) { - paddle::platform::CudaAtomicMax(output + out_i, *(params + in_i)); - } -}; - -template -struct GraphSendRecvMinCUDAFunctor { - DEVICE inline void operator()(const T* params, T* output, const IndexT& in_i, - const IndexT& out_i) { - paddle::platform::CudaAtomicMin(output + out_i, *(params + in_i)); - } -}; - -template -__global__ void GraphSendRecvCUDAKernel(const T* params, - const IndexT* src_indices, - const IndexT* dst_indices, T* output, - size_t index_size, size_t slice_size, - Functor functor) { - CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) { - int64_t indices_i = i / slice_size; - int64_t slice_i = i - indices_i * slice_size; - IndexT src_i = src_indices[indices_i]; - IndexT dst_i = dst_indices[indices_i]; - int64_t in_i = src_i * slice_size + slice_i; - int64_t out_i = dst_i * slice_size + slice_i; - functor(params, output, in_i, out_i); - } -} - -// For max -template -__global__ void InputResetMaxCUDAKernel(T* output, size_t input_size, - size_t slice_size) { - CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) { - if (*(output + i) == 
std::numeric_limits::min()) { - *(output + i) = 0; - } - } -} - -// For min -template -__global__ void InputResetMinCUDAKernel(T* output, size_t input_size, - size_t slice_size) { - CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) { - if (*(output + i) == std::numeric_limits::max()) { - *(output + i) = 0; - } - } -} - -// Get dst_count -template -__global__ void ComputeCountCUDAKernel(int* count, const IndexT* dst_indices, - size_t index_size) { - CUDA_KERNEL_LOOP_TYPE(i, index_size, int64_t) { - IndexT dst_i = dst_indices[i]; - paddle::platform::CudaAtomicAdd(count + dst_i, 1); - } -} - -// For forward mean -template -__global__ void ManipulateMeanCUDAKernel(T* output, int* count, - size_t input_size, size_t slice_size) { - CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) { - int64_t c_index = i / slice_size; - if (*(count + c_index) > 1) { - *(output + i) = *(output + i) / *(count + c_index); - } - } -} - -// For backward mean -template -__global__ void ManipulateMeanGradCUDAKernel( - const T* params, const IndexT* src_indices, const IndexT* dst_indices, - T* output, size_t index_size, size_t slice_size, const int* dst_count) { - CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) { - int64_t indices_i = i / slice_size; - int64_t slice_i = i - indices_i * slice_size; - IndexT src_i = src_indices[indices_i]; - IndexT dst_i = dst_indices[indices_i]; - int64_t in_i = src_i * slice_size + slice_i; - int64_t out_i = dst_i * slice_size + slice_i; - paddle::platform::CudaAtomicAdd(output + out_i, - *(params + in_i) / dst_count[src_i]); - } -} - -// For backward min and max -template -__global__ void ManipulateMinMaxGradCUDAKernel( - const T* params, const IndexT* src_indices, const IndexT* dst_indices, - T* output, size_t index_size, size_t slice_size, const T* ptr_input, - const T* ptr_output) { - CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) { - int64_t indices_i = i / slice_size; - int64_t slice_i = i - indices_i * slice_size; 
- IndexT src_i = src_indices[indices_i]; - IndexT dst_i = dst_indices[indices_i]; - int64_t in_i = src_i * slice_size + slice_i; - int64_t out_i = dst_i * slice_size + slice_i; - paddle::platform::CudaAtomicAdd( - output + out_i, - *(params + in_i) * (*(ptr_input + out_i) == *(ptr_output + in_i))); - } -} - -template -void GraphSendRecvOpCUDAKernelLaunchHelper( - const framework::ExecutionContext& ctx, const Tensor& src_index, - const Tensor& dst_index) { - auto* X = ctx.Input("X"); - auto* Y = ctx.Output("Out"); - std::string pool_type = ctx.Attr("pool_type"); - - const int& index_size = src_index.dims()[0]; - - T* p_output = Y->mutable_data(ctx.GetPlace()); - const auto& src_dims = X->dims(); - int64_t memset_size = 1; - for (int i = 0; i < src_dims.size(); ++i) { - memset_size *= src_dims[i]; - } - const size_t& memset_bytes = memset_size * sizeof(T); - if (pool_type == "SUM" || pool_type == "MEAN") { -#ifdef PADDLE_WITH_HIP - hipMemset(p_output, 0, memset_bytes); -#else - cudaMemset(p_output, 0, memset_bytes); -#endif - } else if (pool_type == "MAX") { - thrust::device_ptr p_output_ptr(p_output); - thrust::fill(thrust::device, p_output_ptr, p_output_ptr + memset_size, - std::numeric_limits::min()); - } else if (pool_type == "MIN") { - thrust::device_ptr p_output_ptr(p_output); - thrust::fill(thrust::device, p_output_ptr, p_output_ptr + memset_size, - std::numeric_limits::max()); - } - - if (index_size == 0) return; - - int64_t slice_size = 1; - for (int i = 1; i < src_dims.size(); ++i) { - slice_size *= src_dims[i]; - } - const T* p_src = X->data(); - const IndexT* s_index = src_index.data(); - const IndexT* d_index = dst_index.data(); - -#ifdef PADDLE_WITH_HIP - int block = 256; -#else - int block = 1024; -#endif - int64_t n = slice_size * index_size; - const auto& dev_ctx = ctx.cuda_device_context(); - int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0]; - int64_t grid_tmp = (n + block - 1) / block; - int64_t grid = grid_tmp < max_grid_dimx ? 
grid_tmp : max_grid_dimx; - int64_t input_size = src_dims[0]; - if (pool_type == "SUM") { - GraphSendRecvSumCUDAFunctor functor; - GraphSendRecvCUDAKernel><<< - grid, block, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>(p_src, s_index, d_index, p_output, - index_size, slice_size, functor); - } else if (pool_type == "MAX") { - GraphSendRecvMaxCUDAFunctor functor; - GraphSendRecvCUDAKernel><<< - grid, block, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>(p_src, s_index, d_index, p_output, - index_size, slice_size, functor); - - int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block; - int64_t grid_max = - grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx; - InputResetMaxCUDAKernel< - T><<( - ctx.device_context()) - .stream()>>>(p_output, input_size, slice_size); - } else if (pool_type == "MIN") { - GraphSendRecvMinCUDAFunctor functor; - GraphSendRecvCUDAKernel><<< - grid, block, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>(p_src, s_index, d_index, p_output, - index_size, slice_size, functor); - - int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block; - int64_t grid_min = - grid_min_tmp < max_grid_dimx ? 
grid_min_tmp : max_grid_dimx; - InputResetMinCUDAKernel< - T><<( - ctx.device_context()) - .stream()>>>(p_output, input_size, slice_size); - } else if (pool_type == "MEAN") { - GraphSendRecvSumCUDAFunctor functor; - GraphSendRecvCUDAKernel><<< - grid, block, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>(p_src, s_index, d_index, p_output, - index_size, slice_size, functor); - - auto* dst_count = ctx.Output("Dst_count"); - int* p_dst_count = dst_count->mutable_data(ctx.GetPlace()); - -#ifdef PADDLE_WITH_HIP - hipMemset(p_dst_count, 0, input_size * sizeof(int)); -#else - cudaMemset(p_dst_count, 0, input_size * sizeof(int)); -#endif - - int64_t grid_count = (index_size + block - 1) / block; - ComputeCountCUDAKernel< - T, IndexT><<( - ctx.device_context()) - .stream()>>>(p_dst_count, d_index, index_size); - - int64_t grid_mean_tmp = (input_size * slice_size + block - 1) / block; - int64_t grid_mean = - grid_mean_tmp < max_grid_dimx ? grid_mean_tmp : max_grid_dimx; - ManipulateMeanCUDAKernel< - T><<( - ctx.device_context()) - .stream()>>>(p_output, p_dst_count, input_size, slice_size); - } -} - -template -void GraphSendRecvGradOpCUDAKernelLaunchHelper( - const framework::ExecutionContext& ctx, const Tensor& src_index, - const Tensor& dst_index) { - auto* X = ctx.Input(framework::GradVarName("Out")); - auto* Y = ctx.Output(framework::GradVarName("X")); - std::string pool_type = ctx.Attr("pool_type"); - - const int& index_size = src_index.dims()[0]; - - T* p_output = Y->mutable_data(ctx.GetPlace()); - const auto& src_dims = X->dims(); - int64_t memset_size = 1; - for (int i = 0; i < src_dims.size(); ++i) { - memset_size *= src_dims[i]; - } - const size_t& memset_bytes = memset_size * sizeof(T); - -#ifdef PADDLE_WITH_HIP - hipMemset(p_output, 0, memset_bytes); -#else - cudaMemset(p_output, 0, memset_bytes); -#endif - - if (index_size == 0) return; - - int64_t slice_size = 1; - for (int i = 1; i < src_dims.size(); ++i) { - slice_size *= src_dims[i]; - } - const 
T* p_src = X->data(); - const IndexT* s_index = src_index.data(); - const IndexT* d_index = dst_index.data(); - -#ifdef PADDLE_WITH_HIP - int block = 256; -#else - int block = 1024; -#endif - int64_t n = slice_size * index_size; - const auto& dev_ctx = ctx.cuda_device_context(); - int64_t max_grid_dimx = dev_ctx.GetCUDAMaxGridDimSize()[0]; - int64_t grid_tmp = (n + block - 1) / block; - int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; - int64_t input_size = src_dims[0]; - if (pool_type == "SUM") { - GraphSendRecvSumCUDAFunctor functor; - GraphSendRecvCUDAKernel><<< - grid, block, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>(p_src, s_index, d_index, p_output, - index_size, slice_size, functor); - } else if (pool_type == "MEAN") { - auto* dst_count = ctx.Input("Dst_count"); - const int* s_count = dst_count->data(); - ManipulateMeanGradCUDAKernel<<< - grid, block, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>(p_src, s_index, d_index, p_output, - index_size, slice_size, s_count); - } else if (pool_type == "MAX" || pool_type == "MIN") { - auto* input = ctx.Input("X"); - auto* output = ctx.Input("Out"); - const T* ptr_input = input->data(); - const T* ptr_output = output->data(); - ManipulateMinMaxGradCUDAKernel<<< - grid, block, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>(p_src, s_index, d_index, p_output, - index_size, slice_size, ptr_input, - ptr_output); - } -} - -template -class GraphSendRecvOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* src_index = ctx.Input("Src_index"); - auto* dst_index = ctx.Input("Dst_index"); - auto index_type = framework::TransToProtoVarType(src_index->dtype()); - - if (index_type == framework::proto::VarType::INT32) { - GraphSendRecvOpCUDAKernelLaunchHelper( - ctx, *src_index, *dst_index); - } else if (index_type == framework::proto::VarType::INT64) { - GraphSendRecvOpCUDAKernelLaunchHelper( 
- ctx, *src_index, *dst_index); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupported Src_index or Dst_index dtype, expected int, int64, but " - "got %s.", - index_type)); - } - } -}; - -template -class GraphSendRecvGradOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* src_index = ctx.Input("Dst_index"); - auto* dst_index = ctx.Input("Src_index"); - auto index_type = framework::TransToProtoVarType(src_index->dtype()); - - if (index_type == framework::proto::VarType::INT32) { - GraphSendRecvGradOpCUDAKernelLaunchHelper( - ctx, *src_index, *dst_index); - } else if (index_type == framework::proto::VarType::INT64) { - GraphSendRecvGradOpCUDAKernelLaunchHelper( - ctx, *src_index, *dst_index); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupported Src_index or Dst_index dtype, expected int, int64, but " - "got %s.", - index_type)); - } - } -}; - -} // namespace operators -} // namespace paddle - -using CUDA = paddle::platform::CUDADeviceContext; -namespace ops = paddle::operators; - -REGISTER_OP_CUDA_KERNEL(graph_send_recv, - ops::GraphSendRecvOpCUDAKernel, - ops::GraphSendRecvOpCUDAKernel, - ops::GraphSendRecvOpCUDAKernel, - ops::GraphSendRecvOpCUDAKernel); - -REGISTER_OP_CUDA_KERNEL(graph_send_recv_grad, - ops::GraphSendRecvGradOpCUDAKernel, - ops::GraphSendRecvGradOpCUDAKernel, - ops::GraphSendRecvGradOpCUDAKernel, - ops::GraphSendRecvGradOpCUDAKernel); diff --git a/paddle/fluid/operators/graph_send_recv_op.h b/paddle/fluid/operators/graph_send_recv_op.h deleted file mode 100644 index 8d8111e0ee..0000000000 --- a/paddle/fluid/operators/graph_send_recv_op.h +++ /dev/null @@ -1,291 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/place.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -struct GraphSendRecvSumFunctor { - void operator()(const bool& first_flag, const Tensor& src_slice, - Tensor* dst_slice) { - auto eigen_src = framework::EigenVector::Flatten(src_slice); - auto eigen_dst = framework::EigenVector::Flatten(*dst_slice); - eigen_dst += eigen_src; - } -}; - -template -struct GraphSendRecvMinFunctor { - void operator()(const bool& first_flag, const Tensor& src_slice, - Tensor* dst_slice) { - auto eigen_src = framework::EigenVector::Flatten(src_slice); - auto eigen_dst = framework::EigenVector::Flatten(*dst_slice); - if (first_flag) { - eigen_dst += eigen_src; - } else { - eigen_dst = eigen_dst.cwiseMin(eigen_src); - } - } -}; - -template -struct GraphSendRecvMaxFunctor { - void operator()(const int& first_flag, const Tensor& src_slice, - Tensor* dst_slice) { - auto eigen_src = framework::EigenVector::Flatten(src_slice); - auto eigen_dst = framework::EigenVector::Flatten(*dst_slice); - if (first_flag) { - eigen_dst += eigen_src; - } else { - eigen_dst = eigen_dst.cwiseMax(eigen_src); - } - } -}; - -template -void elementwise_inner_operation(const Tensor& src, Tensor* dst, - const IndexT& src_index, - const IndexT& dst_index, - const bool& first_flag, Functor functor) { - auto src_slice = src.Slice(src_index, src_index + 1); - auto dst_slice = 
dst->Slice(dst_index, dst_index + 1); - - functor(first_flag, src_slice, &dst_slice); -} - -template -void graph_send_recv_cpu_for_loop(const int& input_size, const int& index_size, - const IndexT* s_index, const IndexT* d_index, - const Tensor& src, Tensor* dst, - const std::string& pool_type, - int* dst_count = nullptr) { - Functor functor; - if (pool_type == "SUM") { - for (int i = 0; i < index_size; ++i) { - const IndexT& src_idx = s_index[i]; - const IndexT& dst_idx = d_index[i]; - elementwise_inner_operation(src, dst, src_idx, - dst_idx, false, functor); - } - } else if (pool_type == "MEAN") { - for (int i = 0; i < index_size; ++i) { - const IndexT& src_idx = s_index[i]; - const IndexT& dst_idx = d_index[i]; - elementwise_inner_operation(src, dst, src_idx, - dst_idx, false, functor); - } - for (int i = 0; i < index_size; ++i) { - IndexT dst_idx = d_index[i]; - *(dst_count + dst_idx) += 1; - } - for (int i = 0; i < input_size; ++i) { - if (*(dst_count + i) == 0) continue; - auto dst_slice = dst->Slice(i, i + 1); - auto eigen_dst = framework::EigenVector::Flatten(dst_slice); - eigen_dst = eigen_dst / static_cast(*(dst_count + i)); - } - } else if (pool_type == "MIN" || pool_type == "MAX") { - std::set existed_dst; - for (int i = 0; i < index_size; ++i) { - const IndexT& src_idx = s_index[i]; - const IndexT& dst_idx = d_index[i]; - bool in_set = existed_dst.find(dst_idx) != existed_dst.end(); - if (!in_set) { - elementwise_inner_operation(src, dst, src_idx, - dst_idx, true, functor); - existed_dst.emplace(dst_idx); - } else { - elementwise_inner_operation( - src, dst, src_idx, dst_idx, false, functor); - } - } - } -} - -template -void graph_send_recv_cpu_for_loop_grad( - const int& input_size, const int& index_size, const IndexT* s_index, - const IndexT* d_index, const Tensor& src, Tensor* dst, - const std::string& pool_type, const int* dst_count = nullptr, - const Tensor* input = nullptr, const Tensor* output = nullptr) { - if (pool_type == "SUM") { - Functor 
functor; - for (int i = 0; i < index_size; ++i) { - const IndexT& src_idx = s_index[i]; - const IndexT& dst_idx = d_index[i]; - elementwise_inner_operation(src, dst, src_idx, - dst_idx, false, functor); - } - } else if (pool_type == "MEAN") { - for (int i = 0; i < index_size; ++i) { - const IndexT& src_idx = s_index[i]; - const IndexT& dst_idx = d_index[i]; - auto src_slice = src.Slice(src_idx, src_idx + 1); - auto dst_slice = dst->Slice(dst_idx, dst_idx + 1); - auto eigen_src = framework::EigenVector::Flatten(src_slice); - auto eigen_dst = framework::EigenVector::Flatten(dst_slice); - eigen_dst += (eigen_src / static_cast(dst_count[src_idx])); - } - } else if (pool_type == "MIN" || pool_type == "MAX") { - for (int i = 0; i < index_size; ++i) { - const IndexT& forward_src_idx = d_index[i]; - const IndexT& forward_dst_idx = s_index[i]; - auto input_slice = input->Slice(forward_src_idx, forward_src_idx + 1); - auto output_slice = output->Slice(forward_dst_idx, forward_dst_idx + 1); - auto eigen_input = framework::EigenVector::Flatten(input_slice); - auto eigen_output = framework::EigenVector::Flatten(output_slice); - - auto src_slice = src.Slice(forward_dst_idx, forward_dst_idx + 1); - auto dst_slice = dst->Slice(forward_src_idx, forward_src_idx + 1); - auto eigen_src = framework::EigenVector::Flatten(src_slice); - auto eigen_dst = framework::EigenVector::Flatten(dst_slice); - eigen_dst += eigen_src * (eigen_output == eigen_input); - } - } -} - -template -void GraphSendRecvOpKernelLaunchHelper(const framework::ExecutionContext& ctx, - const Tensor& src_index) { - auto* X = ctx.Input("X"); - auto* dst_index = ctx.Input("Dst_index"); - auto* Y = ctx.Output("Out"); - - const int& index_size = src_index.dims()[0]; - - T* p_output = Y->mutable_data(ctx.GetPlace()); - const auto& src_dims = X->dims(); - int64_t memset_size = 1; - for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i]; - const size_t& memset_bytes = memset_size * sizeof(T); - memset(p_output, 
0, memset_bytes); - - if (index_size == 0) return; - - const IndexT* s_index = src_index.data(); - const IndexT* d_index = dst_index->data(); - const std::string& pool_type = ctx.Attr("pool_type"); - if (pool_type == "SUM") { - graph_send_recv_cpu_for_loop>( - src_dims[0], index_size, s_index, d_index, *X, Y, pool_type); - } else if (pool_type == "MIN") { - graph_send_recv_cpu_for_loop>( - src_dims[0], index_size, s_index, d_index, *X, Y, pool_type); - } else if (pool_type == "MAX") { - graph_send_recv_cpu_for_loop>( - src_dims[0], index_size, s_index, d_index, *X, Y, pool_type); - } else if (pool_type == "MEAN") { - auto* dst_count = ctx.Output("Dst_count"); - int* p_dst_count = dst_count->mutable_data(ctx.GetPlace()); - memset(p_dst_count, 0, src_dims[0] * sizeof(int)); - graph_send_recv_cpu_for_loop>( - src_dims[0], index_size, s_index, d_index, *X, Y, pool_type, - p_dst_count); - } -} - -template -void GraphSendRecvGradOpKernelLaunchHelper( - const framework::ExecutionContext& ctx, const Tensor& src_index) { - auto* X = ctx.Input(framework::GradVarName("Out")); - auto* dst_index = ctx.Input("Src_index"); - auto* Y = ctx.Output(framework::GradVarName("X")); - - const int& index_size = src_index.dims()[0]; - - T* p_output = Y->mutable_data(ctx.GetPlace()); - const auto& src_dims = X->dims(); - int64_t memset_size = 1; - for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i]; - const size_t& memset_bytes = memset_size * sizeof(T); - memset(p_output, 0, memset_bytes); - - if (index_size == 0) return; - - const IndexT* s_index = src_index.data(); - const IndexT* d_index = dst_index->data(); - - const std::string& pool_type = ctx.Attr("pool_type"); - if (pool_type == "SUM") { - graph_send_recv_cpu_for_loop_grad>( - src_dims[0], index_size, s_index, d_index, *X, Y, pool_type); - } else if (pool_type == "MEAN") { - auto* dst_count = ctx.Input("Dst_count"); - const int* s_count = dst_count->data(); - // Functor not used here. 
- graph_send_recv_cpu_for_loop_grad>( - src_dims[0], index_size, s_index, d_index, *X, Y, pool_type, s_count); - } else if (pool_type == "MIN" || pool_type == "MAX") { - const auto* input = ctx.Input("X"); - const auto* output = ctx.Input("Out"); - // Functor not used here. - graph_send_recv_cpu_for_loop_grad>( - src_dims[0], index_size, s_index, d_index, *X, Y, pool_type, nullptr, - input, output); - } -} - -template -class GraphSendRecvOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* src_index = ctx.Input("Src_index"); - auto index_type = framework::TransToProtoVarType(src_index->dtype()); - - if (index_type == framework::proto::VarType::INT32) { - GraphSendRecvOpKernelLaunchHelper(ctx, *src_index); - } else if (index_type == framework::proto::VarType::INT64) { - GraphSendRecvOpKernelLaunchHelper(ctx, - *src_index); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupported Src_index or Dst_index type, Expected int, int64, but " - "got %s.", - index_type)); - } - } -}; - -template -class GraphSendRecvGradOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* src_index = ctx.Input("Dst_index"); - auto index_type = framework::TransToProtoVarType(src_index->dtype()); - - if (index_type == framework::proto::VarType::INT32) { - GraphSendRecvGradOpKernelLaunchHelper(ctx, - *src_index); - } else if (index_type == framework::proto::VarType::INT64) { - GraphSendRecvGradOpKernelLaunchHelper( - ctx, *src_index); - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupported Src_index or Dst_index type, Expected int, int64, but " - "got %s.", - index_type)); - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/kernels/cpu/graph_send_recv_funcs.h b/paddle/phi/kernels/cpu/graph_send_recv_funcs.h new file mode 100644 index 0000000000..df6d9c87be --- /dev/null 
+++ b/paddle/phi/kernels/cpu/graph_send_recv_funcs.h @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +struct GraphSendRecvSumFunctor { + void operator()(const bool& first_flag, + const DenseTensor& src_slice, + DenseTensor* dst_slice) { + auto eigen_src = phi::EigenVector::Flatten(src_slice); + auto eigen_dst = phi::EigenVector::Flatten(*dst_slice); + eigen_dst += eigen_src; + } +}; + +template +struct GraphSendRecvMinFunctor { + void operator()(const bool& first_flag, + const DenseTensor& src_slice, + DenseTensor* dst_slice) { + auto eigen_src = phi::EigenVector::Flatten(src_slice); + auto eigen_dst = phi::EigenVector::Flatten(*dst_slice); + if (first_flag) { + eigen_dst += eigen_src; + } else { + eigen_dst = eigen_dst.cwiseMin(eigen_src); + } + } +}; + +template +struct GraphSendRecvMaxFunctor { + void operator()(const int& first_flag, + const DenseTensor& src_slice, + DenseTensor* dst_slice) { + auto eigen_src = phi::EigenVector::Flatten(src_slice); + auto eigen_dst = phi::EigenVector::Flatten(*dst_slice); + if (first_flag) { + eigen_dst += eigen_src; + } else { + eigen_dst = eigen_dst.cwiseMax(eigen_src); + } + } +}; + 
+template +void ElementwiseInnerOperation(const DenseTensor& src, + DenseTensor* dst, + const IndexT& src_index, + const IndexT& dst_index, + const bool& first_flag, + Functor functor) { + auto src_slice = src.Slice(src_index, src_index + 1); + auto dst_slice = dst->Slice(dst_index, dst_index + 1); + + functor(first_flag, src_slice, &dst_slice); +} + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc new file mode 100644 index 0000000000..8538461b1b --- /dev/null +++ b/paddle/phi/kernels/cpu/graph_send_recv_grad_kernel.cc @@ -0,0 +1,172 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/phi/kernels/graph_send_recv_grad_kernel.h"
+#include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h"
+
+// NOTE(review): in this copy of the patch every angle-bracket span is missing
+// (targets of the two bare #include directives below, template parameter
+// lists, explicit template arguments). Restore them from the upstream commit
+// before applying the patch.
+#include
+#include
+
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+// CPU backward loop for graph_send_recv.
+//
+// Walks the `index_size` index pairs (s_index[i], d_index[i]) and accumulates
+// rows of `src` into rows of `dst` according to `pool_type`:
+//   - "SUM":  applies `Functor` through ElementwiseInnerOperation on the pair.
+//   - "MEAN": dst[d_index[i]] += src[s_index[i]] / dst_count[s_index[i]].
+//   - "MIN"/"MAX": dst[d_index[i]] += src[s_index[i]] masked element-wise by
+//     (output[s_index[i]] == input[d_index[i]]).
+// Any other `pool_type` is a no-op.
+//
+// `input_size` is not read by the body. `dst_count` is dereferenced
+// unconditionally in the MEAN branch, and `input` / `output` in the MIN/MAX
+// branch, so the caller must supply them for those pool types.
+template
+void GraphSendRecvCpuGradLoop(const int& input_size,
+                              const int& index_size,
+                              const IndexT* s_index,
+                              const IndexT* d_index,
+                              const DenseTensor& src,
+                              DenseTensor* dst,
+                              const std::string& pool_type,
+                              const int* dst_count = nullptr,
+                              const DenseTensor* input = nullptr,
+                              const DenseTensor* output = nullptr) {
+  if (pool_type == "SUM") {
+    Functor functor;
+    for (int i = 0; i < index_size; ++i) {
+      const IndexT& src_idx = s_index[i];
+      const IndexT& dst_idx = d_index[i];
+      ElementwiseInnerOperation(
+          src, dst, src_idx, dst_idx, false, functor);
+    }
+  } else if (pool_type == "MEAN") {
+    for (int i = 0; i < index_size; ++i) {
+      const IndexT& src_idx = s_index[i];
+      const IndexT& dst_idx = d_index[i];
+      // One-row views of the gradient source and destination.
+      auto src_slice = src.Slice(src_idx, src_idx + 1);
+      auto dst_slice = dst->Slice(dst_idx, dst_idx + 1);
+      auto eigen_src = phi::EigenVector::Flatten(src_slice);
+      auto eigen_dst = phi::EigenVector::Flatten(dst_slice);
+      // Scale by the count recorded for the row the gradient is read from.
+      eigen_dst += (eigen_src / static_cast(dst_count[src_idx]));
+    }
+  } else if (pool_type == "MIN" || pool_type == "MAX") {
+    for (int i = 0; i < index_size; ++i) {
+      // The caller passes the index arrays swapped, so d_index holds the
+      // forward-pass source rows and s_index the forward-pass destination rows.
+      const IndexT& forward_src_idx = d_index[i];
+      const IndexT& forward_dst_idx = s_index[i];
+      auto input_slice = input->Slice(forward_src_idx, forward_src_idx + 1);
+      auto output_slice = output->Slice(forward_dst_idx, forward_dst_idx + 1);
+      auto eigen_input = phi::EigenVector::Flatten(input_slice);
+      auto eigen_output = phi::EigenVector::Flatten(output_slice);
+
+      auto src_slice = src.Slice(forward_dst_idx, forward_dst_idx + 1);
+      auto dst_slice = dst->Slice(forward_src_idx, forward_src_idx + 1);
+      auto eigen_src = phi::EigenVector::Flatten(src_slice);
+      auto eigen_dst = phi::EigenVector::Flatten(dst_slice);
+      // Gradient only reaches elements whose input equals the pooled output.
+      eigen_dst += eigen_src * (eigen_output == eigen_input);
+    }
+  }
+}
+
+// Allocates and zero-fills `x_grad`, then runs GraphSendRecvCpuGradLoop for the
+// requested `pool_type`.
+//
+// The index arrays are handed to the loop swapped (dst_index as the loop's
+// source indices, src_index as its destination indices). `dst_count` is
+// dereferenced for "MEAN"; `x` and `out` are forwarded for "MIN"/"MAX".
+// Returns right after the zero-fill when dst_index has no entries.
+template
+void GraphSendRecvGradOpKernelLaunchHelper(
+    const Context& ctx,
+    const DenseTensor& out_grad,
+    const DenseTensor& src_index,
+    const DenseTensor& dst_index,
+    const std::string& pool_type,
+    DenseTensor* x_grad,
+    const DenseTensor* dst_count = nullptr,
+    const DenseTensor* x = nullptr,
+    const DenseTensor* out = nullptr) {
+  const int& index_size = dst_index.dims()[0];
+
+  ctx.template Alloc(x_grad);
+  T* p_output = x_grad->data();
+  // The byte count is derived from out_grad's dims; this assumes x_grad holds
+  // at least as many elements as out_grad -- TODO confirm.
+  const auto& src_dims = out_grad.dims();
+  int64_t memset_size = 1;
+  for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
+  const size_t& memset_bytes = memset_size * sizeof(T);
+  memset(p_output, 0, memset_bytes);
+
+  if (index_size == 0) return;
+
+  const IndexT* s_index = src_index.data();
+  const IndexT* d_index = dst_index.data();
+
+  if (pool_type == "SUM") {
+    GraphSendRecvCpuGradLoop>(
+        src_dims[0], index_size, d_index, s_index, out_grad, x_grad, pool_type);
+  } else if (pool_type == "MEAN") {
+    const int* s_count = dst_count->data();
+    // Functor not used here.
+    GraphSendRecvCpuGradLoop>(src_dims[0],
+                              index_size,
+                              d_index,
+                              s_index,
+                              out_grad,
+                              x_grad,
+                              pool_type,
+                              s_count);
+  } else if (pool_type == "MIN" || pool_type == "MAX") {
+    // Functor not used here.
+    GraphSendRecvCpuGradLoop>(src_dims[0],
+                              index_size,
+                              d_index,
+                              s_index,
+                              out_grad,
+                              x_grad,
+                              pool_type,
+                              nullptr,
+                              x,
+                              out);
+  }
+}
+
+// CPU entry point of the graph_send_recv backward kernel.
+//
+// Dispatches on the dtype of `src_index` (INT32 or INT64) and forwards the
+// optional tensors as raw pointers obtained with get_ptr(). For any other
+// index dtype nothing is executed, so `x_grad` is left unallocated.
+template
+void GraphSendRecvGradKernel(const Context& ctx,
+                             const DenseTensor& out_grad,
+                             paddle::optional x,
+                             paddle::optional out,
+                             const DenseTensor& src_index,
+                             const DenseTensor& dst_index,
+                             paddle::optional dst_count,
+                             const std::string& pool_type,
+                             DenseTensor* x_grad) {
+  auto index_type = src_index.dtype();
+  if (index_type == phi::DataType::INT32) {
+    GraphSendRecvGradOpKernelLaunchHelper(
+        ctx,
+        out_grad,
+        src_index,
+        dst_index,
+        pool_type,
+        x_grad,
+        dst_count.get_ptr(),
+        x.get_ptr(),
+        out.get_ptr());
+  } else if (index_type == phi::DataType::INT64) {
+    GraphSendRecvGradOpKernelLaunchHelper(
+        ctx,
+        out_grad,
+        src_index,
+        dst_index,
+        pool_type,
+        x_grad,
+        dst_count.get_ptr(),
+        x.get_ptr(),
+        out.get_ptr());
+  }
+}
+
+}  // namespace phi
+
+// Registers the CPU backward kernel for float, double, int and int64_t.
+PD_REGISTER_KERNEL(graph_send_recv_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::GraphSendRecvGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc
new file mode 100644
index 0000000000..fecbd4b1d7
--- /dev/null
+++ b/paddle/phi/kernels/cpu/graph_send_recv_kernel.cc
@@ -0,0 +1,153 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/graph_send_recv_kernel.h"
+#include "paddle/phi/kernels/cpu/graph_send_recv_funcs.h"
+
+// NOTE(review): in this copy of the patch every angle-bracket span is missing
+// (targets of the three bare #include directives below, template parameter
+// lists, explicit template arguments). Restore them from the upstream commit
+// before applying the patch.
+#include
+#include
+#include
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/hostdevice.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+// CPU forward loop for graph_send_recv.
+//
+// For each of the `index_size` pairs (s_index[i], d_index[i]) combines row
+// s_index[i] of `src` into row d_index[i] of `dst` with `Functor`, via
+// ElementwiseInnerOperation:
+//   - "SUM":  every pair is applied with the boolean flag set to false.
+//   - "MEAN": pairs are applied as for SUM, then occurrences of each
+//     destination row are counted into `dst_count`, and every row in
+//     [0, input_size) with a non-zero count is divided by that count.
+//   - "MIN"/"MAX": the first pair seen for a destination row is applied with
+//     the flag set to true, later pairs with false (presumably the flag asks
+//     the helper to initialise the row instead of combining -- confirm against
+//     ElementwiseInnerOperation in cpu/graph_send_recv_funcs.h).
+// Any other `pool_type` is a no-op. `dst_count` is written only for "MEAN" and
+// must then point to at least `input_size` zero-initialised ints.
+template
+void GraphSendRecvCpuLoop(const int& input_size,
+                          const int& index_size,
+                          const IndexT* s_index,
+                          const IndexT* d_index,
+                          const DenseTensor& src,
+                          DenseTensor* dst,
+                          const std::string& pool_type,
+                          int* dst_count = nullptr) {
+  Functor functor;
+  if (pool_type == "SUM") {
+    for (int i = 0; i < index_size; ++i) {
+      const IndexT& src_idx = s_index[i];
+      const IndexT& dst_idx = d_index[i];
+      ElementwiseInnerOperation(
+          src, dst, src_idx, dst_idx, false, functor);
+    }
+  } else if (pool_type == "MEAN") {
+    // Pass 1: accumulate the rows.
+    for (int i = 0; i < index_size; ++i) {
+      const IndexT& src_idx = s_index[i];
+      const IndexT& dst_idx = d_index[i];
+      ElementwiseInnerOperation(
+          src, dst, src_idx, dst_idx, false, functor);
+    }
+    // Pass 2: count how many pairs target each destination row.
+    for (int i = 0; i < index_size; ++i) {
+      IndexT dst_idx = d_index[i];
+      *(dst_count + dst_idx) += 1;
+    }
+    // Pass 3: turn the accumulated rows into means; untouched rows are skipped.
+    for (int i = 0; i < input_size; ++i) {
+      if (*(dst_count + i) == 0) continue;
+      auto dst_slice = dst->Slice(i, i + 1);
+      auto eigen_dst = phi::EigenVector::Flatten(dst_slice);
+      eigen_dst = eigen_dst / static_cast(*(dst_count + i));
+    }
+  } else if (pool_type == "MIN" || pool_type == "MAX") {
+    // Destination rows that have already received a value.
+    std::set existed_dst;
+    for (int i = 0; i < index_size; ++i) {
+      const IndexT& src_idx = s_index[i];
+      const IndexT& dst_idx = d_index[i];
+      bool in_set = existed_dst.find(dst_idx) != existed_dst.end();
+      if (!in_set) {
+        ElementwiseInnerOperation(
+            src, dst, src_idx, dst_idx, true, functor);
+        existed_dst.emplace(dst_idx);
+      } else {
+        ElementwiseInnerOperation(
+            src, dst, src_idx, dst_idx, false, functor);
+      }
+    }
+  }
+}
+
+// Allocates and zero-fills `out`, then runs GraphSendRecvCpuLoop with the
+// functor matching `pool_type`.
+//
+// For "MEAN" it also allocates `dst_count` and zeroes its first x.dims()[0]
+// ints before the loop. Returns right after the zero-fill of `out` when
+// src_index has no entries; in that case `dst_count` is not allocated.
+template
+void GraphSendRecvOpKernelLaunchHelper(const Context& ctx,
+                                       const DenseTensor& x,
+                                       const DenseTensor& src_index,
+                                       const DenseTensor& dst_index,
+                                       const std::string& pool_type,
+                                       DenseTensor* out,
+                                       DenseTensor* dst_count = nullptr) {
+  const int& index_size = src_index.dims()[0];
+
+  ctx.template Alloc(out);
+  T* p_output = out->data();
+  // The byte count is derived from x's dims; this assumes `out` holds at least
+  // as many elements as `x` -- TODO confirm.
+  const auto& src_dims = x.dims();
+  int64_t memset_size = 1;
+  for (int i = 0; i < src_dims.size(); ++i) memset_size *= src_dims[i];
+  const size_t& memset_bytes = memset_size * sizeof(T);
+  memset(p_output, 0, memset_bytes);
+
+  if (index_size == 0) return;
+
+  const IndexT* s_index = src_index.data();
+  const IndexT* d_index = dst_index.data();
+  if (pool_type == "SUM") {
+    GraphSendRecvCpuLoop>(
+        src_dims[0], index_size, s_index, d_index, x, out, pool_type);
+  } else if (pool_type == "MIN") {
+    GraphSendRecvCpuLoop>(
+        src_dims[0], index_size, s_index, d_index, x, out, pool_type);
+  } else if (pool_type == "MAX") {
+    GraphSendRecvCpuLoop>(
+        src_dims[0], index_size, s_index, d_index, x, out, pool_type);
+  } else if (pool_type == "MEAN") {
+    ctx.template Alloc(dst_count);
+    int* p_dst_count = dst_count->data();
+    memset(p_dst_count, 0, src_dims[0] * sizeof(int));
+    GraphSendRecvCpuLoop>(src_dims[0],
+                          index_size,
+                          s_index,
+                          d_index,
+                          x,
+                          out,
+                          pool_type,
+                          p_dst_count);
+  }
+}
+
+// CPU entry point of the graph_send_recv forward kernel.
+//
+// Dispatches on the dtype of `src_index` (INT32 or INT64). For any other index
+// dtype nothing is executed, so `out` is left unallocated.
+template
+void GraphSendRecvKernel(const Context& ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& src_index,
+                         const DenseTensor& dst_index,
+                         const std::string& pool_type,
+                         DenseTensor* out,
+                         DenseTensor* dst_count) {
+  auto index_type = src_index.dtype();
+  if (index_type == phi::DataType::INT32) {
+    GraphSendRecvOpKernelLaunchHelper(
+        ctx, x, src_index, dst_index, pool_type, out, dst_count);
+  } else if (index_type == phi::DataType::INT64) {
+    GraphSendRecvOpKernelLaunchHelper(
+        ctx, x, src_index, dst_index, pool_type, out, dst_count);
+  }
+}
+
+}  // namespace phi
+
+// Registers the CPU forward kernel for float, double, int and int64_t.
+PD_REGISTER_KERNEL(graph_send_recv,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::GraphSendRecvKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/gpu/graph_send_recv_funcs.h
b/paddle/phi/kernels/gpu/graph_send_recv_funcs.h new file mode 100644 index 0000000000..1eab521170 --- /dev/null +++ b/paddle/phi/kernels/gpu/graph_send_recv_funcs.h @@ -0,0 +1,171 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/kernels/graph_send_recv_kernel.h" + +#include +#include +#include +#include + +#include "paddle/fluid/platform/device/gpu/gpu_primitives.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/hostdevice.h" + +namespace phi { + +template +struct GraphSendRecvSumCUDAFunctor { + DEVICE inline void operator()(const T* params, + T* output, + const IndexT& in_i, + const IndexT& out_i) { + paddle::platform::CudaAtomicAdd(output + out_i, *(params + in_i)); + } +}; + +template +struct GraphSendRecvMaxCUDAFunctor { + DEVICE inline void operator()(const T* params, + T* output, + const IndexT& in_i, + const IndexT& out_i) { + paddle::platform::CudaAtomicMax(output + out_i, *(params + in_i)); + } +}; + +template +struct GraphSendRecvMinCUDAFunctor { + DEVICE inline void operator()(const T* params, + T* output, + const IndexT& in_i, + const IndexT& out_i) { + paddle::platform::CudaAtomicMin(output + out_i, *(params + in_i)); + } +}; + +template +__global__ void GraphSendRecvCUDAKernel(const T* params, + const IndexT* src_indices, + const IndexT* dst_indices, + T* output, + size_t index_size, + size_t slice_size, + 
Functor functor) { + CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) { + int64_t indices_i = i / slice_size; + int64_t slice_i = i - indices_i * slice_size; + IndexT src_i = src_indices[indices_i]; + IndexT dst_i = dst_indices[indices_i]; + int64_t in_i = src_i * slice_size + slice_i; + int64_t out_i = dst_i * slice_size + slice_i; + functor(params, output, in_i, out_i); + } +} + +// For max +template +__global__ void InputResetMaxCUDAKernel(T* output, + size_t input_size, + size_t slice_size) { + CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) { + if (*(output + i) == std::numeric_limits::min()) { + *(output + i) = 0; + } + } +} + +// For min +template +__global__ void InputResetMinCUDAKernel(T* output, + size_t input_size, + size_t slice_size) { + CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) { + if (*(output + i) == std::numeric_limits::max()) { + *(output + i) = 0; + } + } +} + +// Get dst_count +template +__global__ void ComputeCountCUDAKernel(int32_t* count, + const IndexT* dst_indices, + size_t index_size) { + CUDA_KERNEL_LOOP_TYPE(i, index_size, int64_t) { + IndexT dst_i = dst_indices[i]; + paddle::platform::CudaAtomicAdd(count + dst_i, 1); + } +} + +// For forward mean +template +__global__ void ManipulateMeanCUDAKernel(T* output, + int32_t* count, + size_t input_size, + size_t slice_size) { + CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) { + int64_t c_index = i / slice_size; + if (*(count + c_index) > 1) { + *(output + i) = *(output + i) / *(count + c_index); + } + } +} + +// For backward mean +template +__global__ void ManipulateMeanGradCUDAKernel(const T* params, + const IndexT* src_indices, + const IndexT* dst_indices, + T* output, + size_t index_size, + size_t slice_size, + const int32_t* dst_count) { + CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) { + int64_t indices_i = i / slice_size; + int64_t slice_i = i - indices_i * slice_size; + IndexT src_i = src_indices[indices_i]; + IndexT dst_i = 
dst_indices[indices_i]; + int64_t in_i = src_i * slice_size + slice_i; + int64_t out_i = dst_i * slice_size + slice_i; + paddle::platform::CudaAtomicAdd(output + out_i, + *(params + in_i) / dst_count[src_i]); + } +} + +// For backward min and max +template +__global__ void ManipulateMinMaxGradCUDAKernel(const T* params, + const IndexT* src_indices, + const IndexT* dst_indices, + T* output, + size_t index_size, + size_t slice_size, + const T* ptr_input, + const T* ptr_output) { + CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) { + int64_t indices_i = i / slice_size; + int64_t slice_i = i - indices_i * slice_size; + IndexT src_i = src_indices[indices_i]; + IndexT dst_i = dst_indices[indices_i]; + int64_t in_i = src_i * slice_size + slice_i; + int64_t out_i = dst_i * slice_size + slice_i; + paddle::platform::CudaAtomicAdd( + output + out_i, + *(params + in_i) * (*(ptr_input + out_i) == *(ptr_output + in_i))); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu new file mode 100644 index 0000000000..75692966b4 --- /dev/null +++ b/paddle/phi/kernels/gpu/graph_send_recv_grad_kernel.cu @@ -0,0 +1,148 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h" +#include "paddle/phi/kernels/graph_send_recv_grad_kernel.h" + +#include +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void GraphSendRecvGradOpCUDAKernelLaunchHelper( + const Context& ctx, + const DenseTensor& out_grad, + const DenseTensor& src_index, + const DenseTensor& dst_index, + const std::string& pool_type, + DenseTensor* x_grad, + const DenseTensor* dst_count = nullptr, + const DenseTensor* x = nullptr, + const DenseTensor* out = nullptr) { + const int& index_size = dst_index.dims()[0]; + + ctx.template Alloc(x_grad); + T* p_output = x_grad->data(); + + const auto& src_dims = out_grad.dims(); + int64_t memset_size = 1; + for (int i = 0; i < src_dims.size(); ++i) { + memset_size *= src_dims[i]; + } + const size_t& memset_bytes = memset_size * sizeof(T); + +#ifdef PADDLE_WITH_HIP + hipMemset(p_output, 0, memset_bytes); +#else + cudaMemset(p_output, 0, memset_bytes); +#endif + + if (index_size == 0) return; + + int64_t slice_size = 1; + for (int i = 1; i < src_dims.size(); ++i) { + slice_size *= src_dims[i]; + } + const T* p_src = out_grad.data(); + const IndexT* s_index = src_index.data(); + const IndexT* d_index = dst_index.data(); + +#ifdef PADDLE_WITH_HIP + int block = 256; +#else + int block = 1024; +#endif + int64_t n = slice_size * index_size; + int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0]; + int64_t grid_tmp = (n + block - 1) / block; + int64_t grid = grid_tmp < max_grid_dimx ? 
grid_tmp : max_grid_dimx; + int64_t input_size = src_dims[0]; + if (pool_type == "SUM") { + GraphSendRecvSumCUDAFunctor functor; + GraphSendRecvCUDAKernel< + T, + IndexT, + GraphSendRecvSumCUDAFunctor><<>>( + p_src, d_index, s_index, p_output, index_size, slice_size, functor); + } else if (pool_type == "MEAN") { + const int32_t* s_count = dst_count->data(); + ManipulateMeanGradCUDAKernel<<>>( + p_src, d_index, s_index, p_output, index_size, slice_size, s_count); + } else if (pool_type == "MAX" || pool_type == "MIN") { + const T* ptr_input = x->data(); + const T* ptr_output = out->data(); + ManipulateMinMaxGradCUDAKernel<<>>( + p_src, + d_index, + s_index, + p_output, + index_size, + slice_size, + ptr_input, + ptr_output); + } +} + +template +void GraphSendRecvGradKernel(const Context& ctx, + const DenseTensor& out_grad, + paddle::optional x, + paddle::optional out, + const DenseTensor& src_index, + const DenseTensor& dst_index, + paddle::optional dst_count, + const std::string& pool_type, + DenseTensor* x_grad) { + auto index_type = src_index.dtype(); + if (index_type == phi::DataType::INT32) { + GraphSendRecvGradOpCUDAKernelLaunchHelper( + ctx, + out_grad, + src_index, + dst_index, + pool_type, + x_grad, + dst_count.get_ptr(), + x.get_ptr(), + out.get_ptr()); + } else if (index_type == phi::DataType::INT64) { + GraphSendRecvGradOpCUDAKernelLaunchHelper( + ctx, + out_grad, + src_index, + dst_index, + pool_type, + x_grad, + dst_count.get_ptr(), + x.get_ptr(), + out.get_ptr()); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(graph_send_recv_grad, + GPU, + ALL_LAYOUT, + phi::GraphSendRecvGradKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu new file mode 100644 index 0000000000..fab306f831 --- /dev/null +++ b/paddle/phi/kernels/gpu/graph_send_recv_kernel.cu @@ -0,0 +1,179 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h" +#include "paddle/phi/kernels/graph_send_recv_kernel.h" + +#include +#include +#include +#include + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx, + const DenseTensor& x, + const DenseTensor& src_index, + const DenseTensor& dst_index, + const std::string& pool_type, + DenseTensor* out, + DenseTensor* dst_count = nullptr) { + const int& index_size = src_index.dims()[0]; + ctx.template Alloc(out); + T* p_output = out->data(); + const auto& src_dims = x.dims(); + int64_t memset_size = 1; + for (int i = 0; i < src_dims.size(); ++i) { + memset_size *= src_dims[i]; + } + const size_t& memset_bytes = memset_size * sizeof(T); + if (pool_type == "SUM" || pool_type == "MEAN") { +#ifdef PADDLE_WITH_HIP + hipMemset(p_output, 0, memset_bytes); +#else + cudaMemset(p_output, 0, memset_bytes); +#endif + } else if (pool_type == "MAX") { + thrust::device_ptr p_output_ptr(p_output); + thrust::fill(thrust::device, + p_output_ptr, + p_output_ptr + memset_size, + std::numeric_limits::min()); + } else if (pool_type == "MIN") { + thrust::device_ptr p_output_ptr(p_output); + thrust::fill(thrust::device, + p_output_ptr, + p_output_ptr + memset_size, + std::numeric_limits::max()); + } + + if 
(index_size == 0) return; + + int64_t slice_size = 1; + for (int i = 1; i < src_dims.size(); ++i) { + slice_size *= src_dims[i]; + } + const T* p_src = x.data(); + const IndexT* s_index = src_index.data(); + const IndexT* d_index = dst_index.data(); + +#ifdef PADDLE_WITH_HIP + int block = 256; +#else + int block = 1024; +#endif + int64_t n = slice_size * index_size; + int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0]; + int64_t grid_tmp = (n + block - 1) / block; + int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx; + int64_t input_size = src_dims[0]; + if (pool_type == "SUM") { + GraphSendRecvSumCUDAFunctor functor; + GraphSendRecvCUDAKernel< + T, + IndexT, + GraphSendRecvSumCUDAFunctor><<>>( + p_src, s_index, d_index, p_output, index_size, slice_size, functor); + } else if (pool_type == "MAX") { + GraphSendRecvMaxCUDAFunctor functor; + GraphSendRecvCUDAKernel< + T, + IndexT, + GraphSendRecvMaxCUDAFunctor><<>>( + p_src, s_index, d_index, p_output, index_size, slice_size, functor); + + int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block; + int64_t grid_max = + grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx; + InputResetMaxCUDAKernel<<>>( + p_output, input_size, slice_size); + } else if (pool_type == "MIN") { + GraphSendRecvMinCUDAFunctor functor; + GraphSendRecvCUDAKernel< + T, + IndexT, + GraphSendRecvMinCUDAFunctor><<>>( + p_src, s_index, d_index, p_output, index_size, slice_size, functor); + + int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block; + int64_t grid_min = + grid_min_tmp < max_grid_dimx ? 
grid_min_tmp : max_grid_dimx; + InputResetMinCUDAKernel<<>>( + p_output, input_size, slice_size); + } else if (pool_type == "MEAN") { + GraphSendRecvSumCUDAFunctor functor; + GraphSendRecvCUDAKernel< + T, + IndexT, + GraphSendRecvSumCUDAFunctor><<>>( + p_src, s_index, d_index, p_output, index_size, slice_size, functor); + + ctx.template Alloc(dst_count); + int32_t* p_dst_count = dst_count->data(); + +#ifdef PADDLE_WITH_HIP + hipMemset(p_dst_count, 0, input_size * sizeof(int)); +#else + cudaMemset(p_dst_count, 0, input_size * sizeof(int)); +#endif + + int64_t grid_count = (index_size + block - 1) / block; + ComputeCountCUDAKernel<<>>( + p_dst_count, d_index, index_size); + + int64_t grid_mean_tmp = (input_size * slice_size + block - 1) / block; + int64_t grid_mean = + grid_mean_tmp < max_grid_dimx ? grid_mean_tmp : max_grid_dimx; + ManipulateMeanCUDAKernel<<>>( + p_output, p_dst_count, input_size, slice_size); + } +} + +template +void GraphSendRecvKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& src_index, + const DenseTensor& dst_index, + const std::string& pool_type, + DenseTensor* out, + DenseTensor* dst_count) { + auto index_type = src_index.dtype(); + if (index_type == phi::DataType::INT32) { + GraphSendRecvOpCUDAKernelLaunchHelper( + ctx, x, src_index, dst_index, pool_type, out, dst_count); + } else if (index_type == phi::DataType::INT64) { + GraphSendRecvOpCUDAKernelLaunchHelper( + ctx, x, src_index, dst_index, pool_type, out, dst_count); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(graph_send_recv, + GPU, + ALL_LAYOUT, + phi::GraphSendRecvKernel, + float, + double, + int, + int64_t) {} diff --git a/paddle/phi/kernels/graph_send_recv_grad_kernel.h b/paddle/phi/kernels/graph_send_recv_grad_kernel.h new file mode 100644 index 0000000000..d163e6e278 --- /dev/null +++ b/paddle/phi/kernels/graph_send_recv_grad_kernel.h @@ -0,0 +1,33 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+// NOTE(review): the bare #include below and the template parameter list /
+// paddle::optional arguments further down lost their angle-bracket text in
+// this copy of the patch; restore them from the upstream commit.
+#include
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/utils/optional.h"
+
+namespace phi {
+
+// Backward kernel of graph_send_recv; this header only declares it, the
+// definitions are provided per backend.
+//
+//   ctx        device context used for allocation and execution
+//   out_grad   gradient with respect to the forward output
+//   x, out     optional forward input / output
+//   src_index, dst_index  index tensors of the forward pass
+//   dst_count  optional per-row counts produced by the forward pass
+//   pool_type  pooling mode string
+//   x_grad     output: gradient with respect to the forward input
+template
+void GraphSendRecvGradKernel(const Context& ctx,
+                             const DenseTensor& out_grad,
+                             paddle::optional x,
+                             paddle::optional out,
+                             const DenseTensor& src_index,
+                             const DenseTensor& dst_index,
+                             paddle::optional dst_count,
+                             const std::string& pool_type,
+                             DenseTensor* x_grad);
+}  // namespace phi
diff --git a/paddle/phi/kernels/graph_send_recv_kernel.h b/paddle/phi/kernels/graph_send_recv_kernel.h
new file mode 100644
index 0000000000..95dbdc4443
--- /dev/null
+++ b/paddle/phi/kernels/graph_send_recv_kernel.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+// NOTE(review): the bare #include below and the template parameter list
+// further down lost their angle-bracket text in this copy of the patch;
+// restore them from the upstream commit.
+#include
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+// Forward kernel of graph_send_recv; this header only declares it, the
+// definitions are provided per backend.
+//
+//   ctx        device context used for allocation and execution
+//   x          input tensor
+//   src_index, dst_index  index tensors selecting source / destination rows
+//   pool_type  pooling mode string
+//   out        output: pooled result
+//   dst_count  output: per-row counts
+template
+void GraphSendRecvKernel(const Context& ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& src_index,
+                         const DenseTensor& dst_index,
+                         const std::string& pool_type,
+                         DenseTensor* out,
+                         DenseTensor* dst_count);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/graph_send_recv_sig.cc b/paddle/phi/ops/compat/graph_send_recv_sig.cc
new file mode 100644
index 0000000000..dacb8b25a8
--- /dev/null
+++ b/paddle/phi/ops/compat/graph_send_recv_sig.cc
@@ -0,0 +1,31 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+// Builds the kernel signature for the "graph_send_recv_grad" kernel:
+//   inputs     GradVarName("Out"), "X", "Out", "Src_index", "Dst_index",
+//              "Dst_count"
+//   attributes "pool_type"
+//   outputs    GradVarName("X")
+// `ctx` is not consulted; the same signature is returned for every call.
+KernelSignature GraphSendRecvGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "graph_send_recv_grad",
+      {GradVarName("Out"), "X", "Out", "Src_index", "Dst_index", "Dst_count"},
+      {"pool_type"},
+      {GradVarName("X")});
+}
+
+}  // namespace phi
+
+// Registers the mapping function above for graph_send_recv_grad.
+PD_REGISTER_ARG_MAPPING_FN(graph_send_recv_grad,
+                           phi::GraphSendRecvGradOpArgumentMapping);
-- GitLab