// send_ue_recv_kernel.cu
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/send_ue_recv_kernel.h"

#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/fill.h>

#include <algorithm>
#include <limits>
#include <string>
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"

namespace phi {

// Zero-fills `bytes` bytes at device pointer `data`. The memset is enqueued on
// ctx.stream(), so it is ordered before the kernels this file launches on the
// same stream; its status is checked instead of being discarded.
template <typename Context>
static void ZeroFillOnStream(const Context& ctx, void* data, size_t bytes) {
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(hipMemsetAsync(data, 0, bytes, ctx.stream()));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(data, 0, bytes, ctx.stream()));
#endif
}

// Sets `count` elements starting at device pointer `data` to `value`, using a
// thrust execution policy bound to ctx.stream() rather than the default
// stream.
template <typename Context, typename T>
static void FillOnStream(const Context& ctx, T* data, int64_t count, T value) {
  thrust::device_ptr<T> begin(data);
#ifdef PADDLE_WITH_HIP
  thrust::fill(thrust::hip::par.on(ctx.stream()), begin, begin + count, value);
#else
  thrust::fill(
      thrust::cuda::par.on(ctx.stream()), begin, begin + count, value);
#endif
}

// Launches GraphSendUERecvCUDAKernel once on ctx.stream() with the given
// reduce functor. The per-edge compute functor is funcs::AddFunctor when
// `message_op` is "ADD" and funcs::MultiplyFunctor otherwise; callers must
// have already restricted `message_op` to "ADD" or "MUL".
template <typename Context,
          typename T,
          typename IndexT,
          typename ReduceFunctor>
static void LaunchSendUERecvMessageKernel(const Context& ctx,
                                          const dim3& grid,
                                          const dim3& block,
                                          const std::string& message_op,
                                          const T* x_data,
                                          const T* e_data,
                                          const IndexT* s_index,
                                          const IndexT* d_index,
                                          int64_t* x_bcastoff,
                                          int64_t* e_bcastoff,
                                          T* out_data,
                                          int64_t index_size,
                                          int64_t x_len,
                                          int64_t e_len,
                                          int64_t out_len,
                                          bool use_bcast,
                                          ReduceFunctor reduce_functor) {
  if (message_op == "ADD") {
    funcs::AddFunctor<T> add_functor;
    GraphSendUERecvCUDAKernel<T, IndexT, ReduceFunctor, funcs::AddFunctor<T>>
        <<<grid, block, 0, ctx.stream()>>>(x_data,
                                           e_data,
                                           s_index,
                                           d_index,
                                           x_bcastoff,
                                           e_bcastoff,
                                           out_data,
                                           index_size,
                                           x_len,
                                           e_len,
                                           out_len,
                                           use_bcast,
                                           add_functor,
                                           reduce_functor);
  } else {
    funcs::MultiplyFunctor<T> mul_functor;
    GraphSendUERecvCUDAKernel<T,
                              IndexT,
                              ReduceFunctor,
                              funcs::MultiplyFunctor<T>>
        <<<grid, block, 0, ctx.stream()>>>(x_data,
                                           e_data,
                                           s_index,
                                           d_index,
                                           x_bcastoff,
                                           e_bcastoff,
                                           out_data,
                                           index_size,
                                           x_len,
                                           e_len,
                                           out_len,
                                           use_bcast,
                                           mul_functor,
                                           reduce_functor);
  }
}

// Host-side driver for send_ue_recv on GPU. Judging by the kernels it calls,
// each edge i combines x[src_index[i]] with e via `message_op` ("ADD" or
// "MUL") and reduces the result into row dst_index[i] of `out` via
// `reduce_op` ("SUM", "MEAN", "MAX" or "MIN"); the per-edge details live in
// GraphSendUERecvCUDAKernel.
//
//   out_size  : number of rows of `out`; a value <= 0 means x.dims()[0].
//   out       : first dim set to that row count (other dims kept from its
//               current dims), allocated, and fully initialized here.
//   dst_count : required when reduce_op == "MEAN"; resized to [rows], zeroed,
//               then filled by ComputeCountCUDAKernel (presumably per-row
//               message counts).
//
// Fills and kernel launches here are enqueued on ctx.stream(). Throws
// InvalidArgument for an unknown reduce_op/message_op.
template <typename Context, typename T, typename IndexT>
void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx,
                                             const DenseTensor& x,
                                             const DenseTensor& e,
                                             const DenseTensor& src_index,
                                             const DenseTensor& dst_index,
                                             const std::string& message_op,
                                             const std::string& reduce_op,
                                             int64_t out_size,
                                             DenseTensor* out,
                                             DenseTensor* dst_count = nullptr) {
  const bool is_sum = reduce_op == "SUM";
  const bool is_mean = reduce_op == "MEAN";
  const bool is_max = reduce_op == "MAX";
  const bool is_min = reduce_op == "MIN";
  if (!(is_sum || is_mean || is_max || is_min)) {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "reduce_op should be SUM, MEAN, MAX or MIN, but received %s.",
        reduce_op));
  }
  if (message_op != "ADD" && message_op != "MUL") {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "message_op should be ADD or MUL, but received %s.", message_op));
  }

  // Number of edges. Kept as int64_t: narrowing to int would truncate edge
  // counts above INT_MAX.
  const int64_t index_size = src_index.dims()[0];

  // Row count shared by `out` and (for MEAN) `dst_count`.
  const int64_t out_rows = out_size <= 0 ? x.dims()[0] : out_size;
  std::vector<int64_t> out_shape = phi::vectorize(out->dims());
  out_shape[0] = out_rows;
  out->Resize(phi::make_ddim(out_shape));
  ctx.template Alloc<T>(out);
  T* out_data = out->data<T>();
  const int64_t out_numel = out->numel();

  // SUM/MEAN accumulate from 0; MAX/MIN start from the reduction's sentinel,
  // and the InputResetMax/MinCUDAKernel calls below are assumed to turn
  // untouched sentinels back into 0 (see graph_send_recv_funcs.h). With no
  // edges those reset kernels never run, so zero-fill directly instead of
  // leaving lowest()/max() in the output.
  if (is_sum || is_mean || index_size == 0) {
    ZeroFillOnStream(ctx, out_data, static_cast<size_t>(out_numel) * sizeof(T));
  } else if (is_max) {
    FillOnStream(ctx, out_data, out_numel, std::numeric_limits<T>::lowest());
  } else {
    FillOnStream(ctx, out_data, out_numel, std::numeric_limits<T>::max());
  }

  // MEAN also produces dst_count; allocate and zero it before the empty-edge
  // early return so that output is always initialized.
  int* dst_count_data = nullptr;
  if (is_mean) {
    PADDLE_ENFORCE_NOT_NULL(
        dst_count,
        phi::errors::InvalidArgument(
            "dst_count must not be null when reduce_op is MEAN."));
    dst_count->Resize({out_rows});
    ctx.template Alloc<int>(dst_count);
    dst_count_data = dst_count->data<int>();
    ZeroFillOnStream(
        ctx, dst_count_data, static_cast<size_t>(out_rows) * sizeof(int));
  }

  if (index_size == 0) return;

  const auto& bcast_info = phi::CalcBCastInfo(x.dims(), e.dims());
  thrust::device_vector<int64_t> x_bcastoff, e_bcastoff;
  if (bcast_info.use_bcast) {
    CopyBCastOff(bcast_info, &x_bcastoff, &e_bcastoff);
  }
  int64_t* x_off = thrust::raw_pointer_cast(x_bcastoff.data());
  int64_t* e_off = thrust::raw_pointer_cast(e_bcastoff.data());

  // 2-D launch for the message kernel: x covers the out_len features of one
  // edge, y covers edges.
  const int64_t out_len = bcast_info.out_len;
  const int max_threads = ctx.GetMaxThreadsPerBlock();
  const int ntx = FindNumThreads(out_len, max_threads);
  const int nty = max_threads / ntx;
  const int nbx = (out_len + ntx - 1) / ntx;
  const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
  const dim3 grid(nbx, nby);
  const dim3 block(ntx, nty);

  const T* x_data = x.data<T>();
  const T* e_data = e.data<T>();
  const IndexT* s_index = src_index.data<IndexT>();
  const IndexT* d_index = dst_index.data<IndexT>();

  // 1-D launch for the post-processing kernels over out_rows * out_len
  // elements, clamped to the device's maximum x grid dimension.
  constexpr int kBlock = 1024;
  const int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
  const int64_t row_grid =
      std::min((out_rows * out_len + kBlock - 1) / kBlock, max_grid_dimx);

  if (is_sum || is_mean) {
    GraphSendUERecvSumCUDAFunctor<T> sum_functor;
    LaunchSendUERecvMessageKernel<Context, T, IndexT>(ctx,
                                                      grid,
                                                      block,
                                                      message_op,
                                                      x_data,
                                                      e_data,
                                                      s_index,
                                                      d_index,
                                                      x_off,
                                                      e_off,
                                                      out_data,
                                                      index_size,
                                                      bcast_info.l_len,
                                                      bcast_info.r_len,
                                                      out_len,
                                                      bcast_info.use_bcast,
                                                      sum_functor);
    if (is_mean) {
      const int64_t count_grid = (index_size + kBlock - 1) / kBlock;
      ComputeCountCUDAKernel<T, IndexT>
          <<<count_grid, kBlock, 0, ctx.stream()>>>(
              dst_count_data, d_index, index_size);
      ManipulateMeanCUDAKernel<T><<<row_grid, kBlock, 0, ctx.stream()>>>(
          out_data, dst_count_data, out_rows, out_len);
    }
  } else if (is_max) {
    GraphSendUERecvMaxCUDAFunctor<T> max_functor;
    LaunchSendUERecvMessageKernel<Context, T, IndexT>(ctx,
                                                      grid,
                                                      block,
                                                      message_op,
                                                      x_data,
                                                      e_data,
                                                      s_index,
                                                      d_index,
                                                      x_off,
                                                      e_off,
                                                      out_data,
                                                      index_size,
                                                      bcast_info.l_len,
                                                      bcast_info.r_len,
                                                      out_len,
                                                      bcast_info.use_bcast,
                                                      max_functor);
    InputResetMaxCUDAKernel<T>
        <<<row_grid, kBlock, 0, ctx.stream()>>>(out_data, out_rows, out_len);
  } else {
    GraphSendUERecvMinCUDAFunctor<T> min_functor;
    LaunchSendUERecvMessageKernel<Context, T, IndexT>(ctx,
                                                      grid,
                                                      block,
                                                      message_op,
                                                      x_data,
                                                      e_data,
                                                      s_index,
                                                      d_index,
                                                      x_off,
                                                      e_off,
                                                      out_data,
                                                      index_size,
                                                      bcast_info.l_len,
                                                      bcast_info.r_len,
                                                      out_len,
                                                      bcast_info.use_bcast,
                                                      min_functor);
    InputResetMinCUDAKernel<T>
        <<<row_grid, kBlock, 0, ctx.stream()>>>(out_data, out_rows, out_len);
  }
}

// GPU entry point for send_ue_recv: dispatches on the index dtype (int32 or
// int64) of src_index. An empty `out_size` array is treated like a
// non-positive size (use x.dims()[0] rows). Throws InvalidArgument for any
// other index dtype instead of silently leaving the outputs unset.
template <typename T, typename Context>
void SendUERecvKernel(const Context& ctx,
                      const DenseTensor& x,
                      const DenseTensor& y,
                      const DenseTensor& src_index,
                      const DenseTensor& dst_index,
                      const std::string& message_op,
                      const std::string& reduce_op,
                      const IntArray& out_size,
                      DenseTensor* out,
                      DenseTensor* dst_count) {
  const auto& out_size_data = out_size.GetData();
  const int64_t out_rows = out_size_data.empty() ? 0 : out_size_data[0];
  const auto index_type = src_index.dtype();
  if (index_type == phi::DataType::INT32) {
    GraphSendUERecvOpCUDAKernelLaunchHelper<Context, T, int32_t>(ctx,
                                                                 x,
                                                                 y,
                                                                 src_index,
                                                                 dst_index,
                                                                 message_op,
                                                                 reduce_op,
                                                                 out_rows,
                                                                 out,
                                                                 dst_count);
  } else if (index_type == phi::DataType::INT64) {
    GraphSendUERecvOpCUDAKernelLaunchHelper<Context, T, int64_t>(ctx,
                                                                 x,
                                                                 y,
                                                                 src_index,
                                                                 dst_index,
                                                                 message_op,
                                                                 reduce_op,
                                                                 out_rows,
                                                                 out,
                                                                 dst_count);
  } else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "The index dtype of send_ue_recv should be int32 or int64."));
  }
}

}  // namespace phi

// Registers phi::SendUERecvKernel for GPU / all layouts with float, double,
// int, int64_t and float16 element types. Output 1 (dst_count) is always
// int32, independent of the element type T.
PD_REGISTER_KERNEL(send_ue_recv,
                   GPU,
                   ALL_LAYOUT,
                   phi::SendUERecvKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16) {
  kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
}