graph_send_recv_kernel.cu
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/graph_send_recv_kernel.h"

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/fill.h>
#include <algorithm>
#include <limits>
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

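// Launch helper for one (T, IndexT) instantiation: gathers rows of x at
// src_index, reduces them into out at dst_index according to pool_type
// ("SUM", "MEAN", "MAX" or "MIN"), and, for "MEAN", also fills dst_count with
// the number of messages received per output row.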
template <typename Context, typename T, typename IndexT>
void GraphSendRecvOpCUDAKernelLaunchHelper(const Context& ctx,
                                           const DenseTensor& x,
                                           const DenseTensor& src_index,
                                           const DenseTensor& dst_index,
                                           const std::string& pool_type,
                                           int64_t out_size,
                                           DenseTensor* out,
                                           DenseTensor* dst_count = nullptr) {
  const int index_size = src_index.dims()[0];
  ctx.template Alloc<T>(out);
  T* p_output = out->data<T>();
  const auto& src_dims = x.dims();
  int64_t memset_size = 1;
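  // out_size <= 0 keeps x's shape for the output; a positive out_size replaces
  // the first dimension. memset_size is the total number of output elements.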
  if (out_size <= 0) {
    for (int i = 0; i < src_dims.size(); ++i) {
      memset_size *= src_dims[i];
    }
  } else {
    memset_size = out_size;
    for (int i = 1; i < src_dims.size(); ++i) {
      memset_size *= src_dims[i];
    }
  }
  const size_t memset_bytes = memset_size * sizeof(T);
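  // Initialize the output with the reduction identity: zeros for SUM/MEAN, the
  // type's min/max for MAX/MIN. Rows that end up receiving no message are
  // cleaned up later by the InputReset*CUDAKernel launches.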
  if (pool_type == "SUM" || pool_type == "MEAN") {
#ifdef PADDLE_WITH_HIP
    hipMemset(p_output, 0, memset_bytes);
#else
    cudaMemset(p_output, 0, memset_bytes);
#endif
  } else if (pool_type == "MAX") {
    thrust::device_ptr<T> p_output_ptr(p_output);
    thrust::fill(thrust::device,
                 p_output_ptr,
                 p_output_ptr + memset_size,
                 std::numeric_limits<T>::min());
  } else if (pool_type == "MIN") {
    thrust::device_ptr<T> p_output_ptr(p_output);
    thrust::fill(thrust::device,
                 p_output_ptr,
                 p_output_ptr + memset_size,
                 std::numeric_limits<T>::max());
  }

  if (index_size == 0) return;
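  // No edges to process; the initialized output is the final result.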

  int64_t slice_size = 1;
  for (int i = 1; i < src_dims.size(); ++i) {
    slice_size *= src_dims[i];
  }
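  // slice_size: number of feature elements per row of x.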
  const T* p_src = x.data<T>();
  const IndexT* s_index = src_index.data<IndexT>();
  const IndexT* d_index = dst_index.data<IndexT>();

#ifdef PADDLE_WITH_HIP
  int block = 256;
#else
  int block = 1024;
#endif
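  // One thread per (edge, feature) element, with the grid capped at the
  // device's maximum x-dimension.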
  int64_t n = slice_size * index_size;
  int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
  int64_t grid_tmp = (n + block - 1) / block;
  int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;
  int64_t input_size = src_dims[0];
  if (pool_type == "SUM") {
    GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
    GraphSendRecvCUDAKernel<
        T,
        IndexT,
        GraphSendRecvSumCUDAFunctor<T,
                                    IndexT>><<<grid, block, 0, ctx.stream()>>>(
        p_src, s_index, d_index, p_output, index_size, slice_size, functor);
  } else if (pool_type == "MAX") {
    GraphSendRecvMaxCUDAFunctor<T, IndexT> functor;
    GraphSendRecvCUDAKernel<
        T,
        IndexT,
        GraphSendRecvMaxCUDAFunctor<T,
                                    IndexT>><<<grid, block, 0, ctx.stream()>>>(
        p_src, s_index, d_index, p_output, index_size, slice_size, functor);

    if (out_size > 0) {
      input_size = out_size;
    }
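    // Output rows that received no message still hold the initialization
    // sentinel after the reduction; reset them before returning.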
    int64_t grid_max_tmp = (input_size * slice_size + block - 1) / block;
    int64_t grid_max =
        grid_max_tmp < max_grid_dimx ? grid_max_tmp : max_grid_dimx;
    InputResetMaxCUDAKernel<T><<<grid_max, block, 0, ctx.stream()>>>(
        p_output, input_size, slice_size);
  } else if (pool_type == "MIN") {
    GraphSendRecvMinCUDAFunctor<T, IndexT> functor;
    GraphSendRecvCUDAKernel<
        T,
        IndexT,
        GraphSendRecvMinCUDAFunctor<T,
                                    IndexT>><<<grid, block, 0, ctx.stream()>>>(
        p_src, s_index, d_index, p_output, index_size, slice_size, functor);

    if (out_size > 0) {
      input_size = out_size;
    }
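    // As with MAX, reset output rows that received no message and still hold
    // the initialization sentinel.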
    int64_t grid_min_tmp = (input_size * slice_size + block - 1) / block;
    int64_t grid_min =
        grid_min_tmp < max_grid_dimx ? grid_min_tmp : max_grid_dimx;
    InputResetMinCUDAKernel<T><<<grid_min, block, 0, ctx.stream()>>>(
        p_output, input_size, slice_size);
  } else if (pool_type == "MEAN") {
    GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
    GraphSendRecvCUDAKernel<
        T,
        IndexT,
        GraphSendRecvSumCUDAFunctor<T,
                                    IndexT>><<<grid, block, 0, ctx.stream()>>>(
        p_src, s_index, d_index, p_output, index_size, slice_size, functor);

    ctx.template Alloc<int32_t>(dst_count);
    int32_t* p_dst_count = dst_count->data<int32_t>();
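    // dst_count[i] will hold the number of messages reduced into output row i.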
    if (out_size > 0) {
      input_size = out_size;
    }

#ifdef PADDLE_WITH_HIP
    hipMemset(p_dst_count, 0, input_size * sizeof(int));
#else
    cudaMemset(p_dst_count, 0, input_size * sizeof(int));
#endif
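    // Tally per-row message counts from d_index, then divide each summed
    // output row by its count to produce the mean.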

    int64_t grid_count = (index_size + block - 1) / block;
    ComputeCountCUDAKernel<T, IndexT><<<grid_count, block, 0, ctx.stream()>>>(
        p_dst_count, d_index, index_size);

    int64_t grid_mean_tmp = (input_size * slice_size + block - 1) / block;
    int64_t grid_mean =
        grid_mean_tmp < max_grid_dimx ? grid_mean_tmp : max_grid_dimx;
    ManipulateMeanCUDAKernel<T><<<grid_mean, block, 0, ctx.stream()>>>(
        p_output, p_dst_count, input_size, slice_size);
  }
}

template <typename T, typename Context>
void GraphSendRecvKernel(const Context& ctx,
                         const DenseTensor& x,
                         const DenseTensor& src_index,
                         const DenseTensor& dst_index,
                         const std::string& pool_type,
                         int64_t out_size,
                         DenseTensor* out,
                         DenseTensor* dst_count) {
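  // Dispatch on the index dtype; src_index and dst_index are assumed to share
  // the same integer type (int32 or int64).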
  auto index_type = src_index.dtype();
  if (index_type == phi::DataType::INT32) {
    GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int32_t>(
        ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
  } else if (index_type == phi::DataType::INT64) {
    GraphSendRecvOpCUDAKernelLaunchHelper<Context, T, int64_t>(
        ctx, x, src_index, dst_index, pool_type, out_size, out, dst_count);
  }
}

}  // namespace phi

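// Register the GPU kernel for float, double, int32 and int64 feature types.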
PD_REGISTER_KERNEL(graph_send_recv,
                   GPU,
                   ALL_LAYOUT,
                   phi::GraphSendRecvKernel,
                   float,
                   double,
                   int,
                   int64_t) {}