// send_ue_recv_grad_kernel.cu
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/send_ue_recv_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/graph_send_recv_funcs.h"
#include "paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h"
#include "paddle/phi/kernels/impl/graph_message_passing_impl.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {

// Computes the gradients of both X and E for the "MIN" / "MAX" reductions of
// send_ue_recv.
//
// Both gradients are written by one fused kernel launch:
// ManipulateMinMaxGradCUDAKernelForAdd for message_op "ADD" and
// ManipulateMinMaxGradCUDAKernelForMul for "MUL". Any other message_op is a
// no-op. The kernels receive the forward output `out` together with x, e and
// out_grad. Presumably they route gradient only to the messages that produced
// the selected min/max value (TODO confirm in graph_send_ue_recv_funcs.h).
//
// Launch layout: a 2-D grid. Its x dimension covers the out_len elements of
// one (x, e)-broadcast message; its y dimension covers the index_size edges.
//
// Params:
//   out_grad          gradient of the forward output.
//   x_data, e_data    forward inputs.
//   x_dims, e_dims    shapes used to derive the x/e broadcast layout.
//   s_index, d_index  per-edge source / destination indices.
//   message_op        "ADD" or "MUL".
//   reduce_op         not read here; the launch helper only routes "MIN"/"MAX"
//                     to this function.
//   index_size        number of edges.
//   x_grad, e_grad    output gradients (zero-filled by the launch helper in
//                     this file before this is called).
//   out               forward output. Despite the nullptr default it is
//                     dereferenced unconditionally, so it must be provided.
template <typename Context, typename T, typename IndexT>
void CalculateXEGradForMinMax(const Context& ctx,
                              const T* out_grad,
                              const T* x_data,
                              const T* e_data,
                              const phi::DDim& x_dims,
                              const phi::DDim& e_dims,
                              const IndexT* s_index,
                              const IndexT* d_index,
                              const std::string& message_op,
                              const std::string& reduce_op,
                              int64_t index_size,
                              T* x_grad,
                              T* e_grad,
                              const DenseTensor* out = nullptr) {
  const T* out_data = out->data<T>();
  const auto& bcast_info = phi::CalcBCastInfo(x_dims, e_dims);
  // The broadcast offset tables are only filled when x and e actually
  // broadcast. Otherwise they stay empty and the kernels are told so via
  // use_bcast.
  thrust::device_vector<int64_t> l_bcastoff, r_bcastoff;
  if (bcast_info.use_bcast) {
    CopyBCastOff(bcast_info, &l_bcastoff, &r_bcastoff);
  }

  // 2-D launch: x over the elements of one message, y over the edges.
  int64_t out_len = bcast_info.out_len;
  const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
  const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
  const int nbx = (out_len + ntx - 1) / ntx;
  const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
  const dim3 grid(nbx, nby);
  const dim3 block(ntx, nty);

  if (message_op == "ADD") {
    ManipulateMinMaxGradCUDAKernelForAdd<T, IndexT>
        <<<grid, block, 0, ctx.stream()>>>(
            x_data,
            e_data,
            out_data,
            out_grad,
            d_index,
            s_index,
            thrust::raw_pointer_cast(l_bcastoff.data()),
            thrust::raw_pointer_cast(r_bcastoff.data()),
            x_grad,
            e_grad,
            index_size,
            bcast_info.l_len,
            bcast_info.r_len,
            out_len,
            bcast_info.use_bcast);
  } else if (message_op == "MUL") {
    ManipulateMinMaxGradCUDAKernelForMul<T, IndexT>
        <<<grid, block, 0, ctx.stream()>>>(
            x_data,
            e_data,
            out_data,
            out_grad,
            d_index,
            s_index,
            thrust::raw_pointer_cast(l_bcastoff.data()),
            thrust::raw_pointer_cast(r_bcastoff.data()),
            x_grad,
            e_grad,
            index_size,
            bcast_info.l_len,
            bcast_info.r_len,
            out_len,
            bcast_info.use_bcast);
  }
}

// Allocates the zero-filled scratch buffer that receives the un-reduced X
// gradient when out_grad's feature dims differ from x's (the broadcast case).
//
// The buffer has x's row count and out_grad's feature dims. The scatter
// kernels write rows addressed by s_index (rows of x), each row being one
// out_grad-sized message. Sizing it like out_grad would be wrong whenever
// out_grad's row count differs from x's:
//   - out-of-bounds writes if x has more rows;
//   - an oversized copy into x_grad if x has fewer rows.
template <typename Context, typename T>
DenseTensor AllocSendUERecvXGradScratch(const Context& ctx,
                                        const phi::DDim& out_grad_dims,
                                        const phi::DDim& x_dims) {
  phi::DDim scratch_dims = out_grad_dims;
  scratch_dims[0] = x_dims[0];
  DenseTensor scratch;
  scratch.Resize(scratch_dims);
  ctx.template Alloc<T>(&scratch);
  phi::funcs::SetConstant<Context, T>()(ctx, &scratch, T(0));
  return scratch;
}

// Sums `scratch` over the broadcast axes `reduce_idx` (keeping dims) and copies
// the result into x_grad.
//
// The copy is issued on ctx.stream(), so it is stream-ordered after phi::Sum
// and does not rely on the default stream's implicit synchronization.
template <typename Context, typename T>
void ReduceSendUERecvXGradScratch(const Context& ctx,
                                  const DenseTensor& scratch,
                                  const std::vector<int64_t>& reduce_idx,
                                  T* x_grad) {
  DenseTensor reduced = phi::Sum<T, Context>(ctx,
                                             scratch,
                                             phi::IntArray(reduce_idx),
                                             phi::CppTypeToDataType<T>::Type(),
                                             true);
  const size_t bytes = static_cast<size_t>(reduced.numel()) * sizeof(T);
#ifdef PADDLE_WITH_HIP
  hipMemcpyAsync(x_grad,
                 reduced.data<T>(),
                 bytes,
                 hipMemcpyDeviceToDevice,
                 ctx.stream());
#else
  cudaMemcpyAsync(x_grad,
                  reduced.data<T>(),
                  bytes,
                  cudaMemcpyDeviceToDevice,
                  ctx.stream());
#endif
}

// Computes the gradient w.r.t. X for the "SUM" / "MEAN" reductions.
//
// out_grad rows (addressed by d_index) are scattered back onto x rows
// (addressed by s_index):
//   ADD: the message gradient is out_grad itself (MEAN also uses dst_count).
//   MUL: the message gradient is out_grad combined with e (GraphSendUERecv
//        kernels with MultiplyFunctor, or the MEAN "ForMulX" kernel).
// When out_grad's feature dims differ from x's (x was broadcast in the
// forward pass), results go to a scratch buffer first. That buffer is then
// summed over the broadcast axes and copied into x_grad. Otherwise x_grad is
// written directly; the launch helper in this file zero-fills it beforehand.
//
// Unsupported message_op / reduce_op combinations leave x_grad untouched.
// dst_count must be non-null for "MEAN". x_data and out are not read here.
template <typename Context, typename T, typename IndexT>
void CalculateXGrad(const Context& ctx,
                    const T* out_grad,
                    const T* x_data,
                    const T* e_data,
                    const phi::DDim& out_grad_dims,
                    const phi::DDim& x_dims,
                    const phi::DDim& e_dims,
                    const IndexT* s_index,
                    const IndexT* d_index,
                    const std::string& message_op,
                    const std::string& reduce_op,
                    int64_t index_size,
                    int64_t slice_size,
                    T* x_grad,
                    const DenseTensor& out_grad_tensor,
                    const DenseTensor* dst_count = nullptr,
                    const DenseTensor* out = nullptr) {
  const bool is_add = message_op == "ADD";
  const bool is_mul = message_op == "MUL";
  const bool is_sum = reduce_op == "SUM";
  const bool is_mean = reduce_op == "MEAN";
  if (!(is_add || is_mul) || !(is_sum || is_mean)) return;

  // `reduce` is set when out_grad and x differ in some feature dim; the
  // differing axes are collected in reduce_idx.
  std::vector<int64_t> reduce_idx;
  const bool reduce = ReduceGrad(out_grad_dims, x_dims, reduce_idx);

  DenseTensor scratch;
  T* target = x_grad;
  if (reduce) {
    scratch = AllocSendUERecvXGradScratch<Context, T>(
        ctx, out_grad_tensor.dims(), x_dims);
    target = scratch.data<T>();
  }

  const int* s_count = is_mean ? dst_count->data<int>() : nullptr;

  if (is_add) {
    // Elements per scattered row: the broadcast message length when x was
    // broadcast, otherwise the product of x's feature dims.
    const int64_t row_len =
        reduce ? phi::CalcBCastInfo(out_grad_dims, e_dims).out_len
               : slice_size;
    // 1-D launch sized to row_len * index_size, capped at the device's max
    // grid x-dimension. The kernels presumably loop over any remainder
    // (TODO confirm they are grid-stride).
    const int block = 1024;
    const int64_t n = row_len * index_size;
    const int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
    const int64_t grid_tmp = (n + block - 1) / block;
    const int64_t grid = grid_tmp < max_grid_dimx ? grid_tmp : max_grid_dimx;

    if (is_sum) {
      GraphSendRecvSumCUDAFunctor<T, IndexT> functor;
      GraphSendRecvCUDAKernel<T,
                              IndexT,
                              GraphSendRecvSumCUDAFunctor<T, IndexT>>
          <<<grid, block, 0, ctx.stream()>>>(out_grad,
                                             d_index,
                                             s_index,
                                             target,
                                             index_size,
                                             row_len,
                                             functor);
    } else {
      ManipulateMeanGradCUDAKernel<T, IndexT>
          <<<grid, block, 0, ctx.stream()>>>(out_grad,
                                             d_index,
                                             s_index,
                                             target,
                                             index_size,
                                             row_len,
                                             s_count);
    }
  } else {
    // MUL needs the out_grad/e broadcast layout. It uses a 2-D launch: x over
    // the elements of one message, y over the edges.
    const auto bcast_info = phi::CalcBCastInfo(out_grad_dims, e_dims);
    thrust::device_vector<int64_t> l_bcastoff, r_bcastoff;
    if (bcast_info.use_bcast) {
      CopyBCastOff(bcast_info, &l_bcastoff, &r_bcastoff);
    }
    const int64_t out_len = bcast_info.out_len;
    const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
    const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
    const int nbx = (out_len + ntx - 1) / ntx;
    const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
    const dim3 grid(nbx, nby);
    const dim3 block(ntx, nty);

    if (is_sum) {
      funcs::MultiplyFunctor<T> mul_functor;
      GraphSendUERecvSumCUDAFunctor<T> sum_functor;
      GraphSendUERecvCUDAKernel<T,
                                IndexT,
                                GraphSendUERecvSumCUDAFunctor<T>,
                                funcs::MultiplyFunctor<T>>
          <<<grid, block, 0, ctx.stream()>>>(
              out_grad,
              e_data,
              d_index,
              s_index,
              thrust::raw_pointer_cast(l_bcastoff.data()),
              thrust::raw_pointer_cast(r_bcastoff.data()),
              target,
              index_size,
              bcast_info.l_len,
              bcast_info.r_len,
              out_len,
              bcast_info.use_bcast,
              mul_functor,
              sum_functor);
    } else {
      ManipulateMeanGradCUDAKernelForMulX<T, IndexT>
          <<<grid, block, 0, ctx.stream()>>>(
              out_grad,
              e_data,
              d_index,
              s_index,
              s_count,
              thrust::raw_pointer_cast(l_bcastoff.data()),
              thrust::raw_pointer_cast(r_bcastoff.data()),
              target,
              index_size,
              bcast_info.l_len,
              bcast_info.r_len,
              out_len,
              bcast_info.use_bcast);
    }
  }

  if (reduce) {
    ReduceSendUERecvXGradScratch<Context, T>(ctx, scratch, reduce_idx, x_grad);
  }
}

// Computes the gradient w.r.t. E (the edge-feature input) for the "SUM" /
// "MEAN" reductions.
//
// All cases share one 2-D launch. The grid's x dimension covers the out_len
// elements of one (x, e)-broadcast message; its y dimension covers the
// index_size edges. Kernels per case:
//   SUM  + ADD: ManipulateSumGradCUDAKernelForAddE  (given out_grad, d_index)
//   SUM  + MUL: ManipulateSumGradCUDAKernelForMulE  (also given x, s_index)
//   MEAN + ADD: ManipulateMeanGradCUDAKernelForAddE (also given dst_count)
//   MEAN + MUL: ManipulateMeanGradCUDAKernelForMulE
// Other combinations are no-ops.
//
// dst_count is dereferenced for "MEAN" and must then be non-null. e_grad is
// zero-filled by the launch helper in this file before this is called.
template <typename Context, typename T, typename IndexT>
void CalculateEGrad(const Context& ctx,
                    const T* out_grad,
                    const T* x_data,
                    const T* e_data,
                    const phi::DDim& x_dims,
                    const phi::DDim& e_dims,
                    const IndexT* s_index,
                    const IndexT* d_index,
                    const std::string& message_op,
                    const std::string& reduce_op,
                    int64_t index_size,
                    T* e_grad,
                    const DenseTensor* dst_count = nullptr) {
  const auto& bcast_info = phi::CalcBCastInfo(x_dims, e_dims);
  // Offset tables are only filled when x and e actually broadcast; the
  // kernels are told via use_bcast.
  thrust::device_vector<int64_t> l_bcastoff, r_bcastoff;
  if (bcast_info.use_bcast) {
    CopyBCastOff(bcast_info, &l_bcastoff, &r_bcastoff);
  }
  int64_t out_len = bcast_info.out_len;
  const int ntx = FindNumThreads(out_len, ctx.GetMaxThreadsPerBlock());
  const int nty = ctx.GetMaxThreadsPerBlock() / ntx;
  const int nbx = (out_len + ntx - 1) / ntx;
  const int nby = FindNumBlocks('y', (index_size + nty - 1) / nty);
  const dim3 grid(nbx, nby);
  const dim3 block(ntx, nty);
  if (reduce_op == "SUM") {
    if (message_op == "ADD") {
      ManipulateSumGradCUDAKernelForAddE<T, IndexT>
          <<<grid, block, 0, ctx.stream()>>>(
              out_grad,
              d_index,
              thrust::raw_pointer_cast(r_bcastoff.data()),
              e_grad,
              index_size,
              bcast_info.r_len,
              out_len,
              bcast_info.use_bcast);
    } else if (message_op == "MUL") {
      ManipulateSumGradCUDAKernelForMulE<T, IndexT>
          <<<grid, block, 0, ctx.stream()>>>(
              x_data,
              out_grad,
              s_index,
              d_index,
              thrust::raw_pointer_cast(l_bcastoff.data()),
              thrust::raw_pointer_cast(r_bcastoff.data()),
              e_grad,
              index_size,
              bcast_info.l_len,
              bcast_info.r_len,
              out_len,
              bcast_info.use_bcast);
    }
  } else if (reduce_op == "MEAN") {
    const int* s_count = dst_count->data<int>();
    if (message_op == "ADD") {
      ManipulateMeanGradCUDAKernelForAddE<T, IndexT>
          <<<grid, block, 0, ctx.stream()>>>(
              out_grad,
              d_index,
              s_count,
              thrust::raw_pointer_cast(r_bcastoff.data()),
              e_grad,
              index_size,
              bcast_info.r_len,
              out_len,
              bcast_info.use_bcast);
    } else if (message_op == "MUL") {
      ManipulateMeanGradCUDAKernelForMulE<T, IndexT>
          <<<grid, block, 0, ctx.stream()>>>(
              x_data,
              out_grad,
              s_index,
              d_index,
              s_count,
              thrust::raw_pointer_cast(l_bcastoff.data()),
              thrust::raw_pointer_cast(r_bcastoff.data()),
              e_grad,
              index_size,
              bcast_info.l_len,
              bcast_info.r_len,
              out_len,
              bcast_info.use_bcast);
    }
  }
}

// Allocates and zero-fills x_grad / e_grad, then dispatches by reduce_op:
//   "SUM" / "MEAN" -> CalculateXGrad + CalculateEGrad
//   "MIN" / "MAX"  -> CalculateXEGradForMinMax
// With no edges (index_size == 0) both gradients stay all-zero.
//
// `dst_count` is required for "MEAN" and `out` for "MIN"/"MAX"; both are
// forwarded unchanged.
template <typename Context, typename T, typename IndexT>
void GraphSendUERecvGradOpCUDAKernelLaunchHelper(
    const Context& ctx,
    const DenseTensor& out_grad,
    const DenseTensor& x,
    const DenseTensor& e,
    const DenseTensor& src_index,
    const DenseTensor& dst_index,
    const std::string& message_op,
    const std::string& reduce_op,
    DenseTensor* x_grad,
    DenseTensor* e_grad,
    const DenseTensor* dst_count = nullptr,
    const DenseTensor* out = nullptr) {
  // Keep the full int64_t edge count. Binding it to an int would silently
  // truncate graphs with more than INT_MAX edges before it reaches the
  // int64_t-typed helpers.
  const int64_t index_size = dst_index.dims()[0];

  ctx.template Alloc<T>(x_grad);
  T* x_grad_data = x_grad->data<T>();
  ctx.template Alloc<T>(e_grad);
  T* e_grad_data = e_grad->data<T>();
  const auto& x_dims = x.dims();
  const auto& e_dims = e.dims();

  // Number of elements in one row of x (product of all dims but the first).
  int64_t slice_size = 1;
  for (int i = 1; i < x_dims.size(); ++i) {
    slice_size *= x_dims[i];
  }

  // The gradient kernels accumulate into these buffers, so they must start at
  // zero. Zero them on ctx.stream(), the stream the kernels below launch on,
  // so the fill is stream-ordered before them. Plain cudaMemset/hipMemset are
  // issued on the default stream instead.
  const size_t memset_bytes_x = static_cast<size_t>(x.numel()) * sizeof(T);
  const size_t memset_bytes_e = static_cast<size_t>(e.numel()) * sizeof(T);
#ifdef PADDLE_WITH_HIP
  hipMemsetAsync(x_grad_data, 0, memset_bytes_x, ctx.stream());
  hipMemsetAsync(e_grad_data, 0, memset_bytes_e, ctx.stream());
#else
  cudaMemsetAsync(x_grad_data, 0, memset_bytes_x, ctx.stream());
  cudaMemsetAsync(e_grad_data, 0, memset_bytes_e, ctx.stream());
#endif

  if (index_size == 0) return;

  const T* out_grad_data = out_grad.data<T>();
  const T* x_data = x.data<T>();
  const T* e_data = e.data<T>();
  const IndexT* s_index = src_index.data<IndexT>();
  const IndexT* d_index = dst_index.data<IndexT>();

  if (reduce_op == "SUM" || reduce_op == "MEAN") {
    CalculateXGrad<Context, T, IndexT>(ctx,
                                       out_grad_data,
                                       x_data,
                                       e_data,
                                       out_grad.dims(),
                                       x_dims,
                                       e_dims,
                                       s_index,
                                       d_index,
                                       message_op,
                                       reduce_op,
                                       index_size,
                                       slice_size,
                                       x_grad_data,
                                       out_grad,
                                       dst_count,
                                       out);
    CalculateEGrad<Context, T, IndexT>(ctx,
                                       out_grad_data,
                                       x_data,
                                       e_data,
                                       x_dims,
                                       e_dims,
                                       s_index,
                                       d_index,
                                       message_op,
                                       reduce_op,
                                       index_size,
                                       e_grad_data,
                                       dst_count);
  } else if (reduce_op == "MIN" || reduce_op == "MAX") {
    CalculateXEGradForMinMax<Context, T, IndexT>(ctx,
                                                 out_grad_data,
                                                 x_data,
                                                 e_data,
                                                 x_dims,
                                                 e_dims,
                                                 s_index,
                                                 d_index,
                                                 message_op,
                                                 reduce_op,
                                                 index_size,
                                                 x_grad_data,
                                                 e_grad_data,
                                                 out);
  }
}

// GPU backward kernel of send_ue_recv.
//
// Dispatches on src_index's dtype (INT32 or INT64). dst_index is read with the
// same index type, so it must share src_index's dtype. For any other index
// dtype nothing is computed and x_grad / y_grad are left unallocated.
//
// `y` is the edge-feature input (named `e` in the helper routines) and y_grad
// its gradient. The optional `out` is needed for "MIN"/"MAX" and `dst_count`
// for "MEAN".
template <typename T, typename Context>
void SendUERecvGradKernel(const Context& ctx,
                          const DenseTensor& x,
                          const DenseTensor& y,
                          const DenseTensor& src_index,
                          const DenseTensor& dst_index,
                          const paddle::optional<DenseTensor>& out,
                          const paddle::optional<DenseTensor>& dst_count,
                          const DenseTensor& out_grad,
                          const std::string& message_op,
                          const std::string& reduce_op,
                          DenseTensor* x_grad,
                          DenseTensor* y_grad) {
  auto index_type = src_index.dtype();
  if (index_type == phi::DataType::INT32) {
    GraphSendUERecvGradOpCUDAKernelLaunchHelper<Context, T, int32_t>(
        ctx,
        out_grad,
        x,
        y,
        src_index,
        dst_index,
        message_op,
        reduce_op,
        x_grad,
        y_grad,
        dst_count.get_ptr(),
        out.get_ptr());
  } else if (index_type == phi::DataType::INT64) {
    GraphSendUERecvGradOpCUDAKernelLaunchHelper<Context, T, int64_t>(
        ctx,
        out_grad,
        x,
        y,
        src_index,
        dst_index,
        message_op,
        reduce_op,
        x_grad,
        y_grad,
        dst_count.get_ptr(),
        out.get_ptr());
  }
}

}  // namespace phi

// Registers phi::SendUERecvGradKernel as the GPU implementation of
// send_ue_recv_grad for all layouts and the listed value types. The index
// dtype (int32 / int64) is dispatched at runtime inside the kernel.
PD_REGISTER_KERNEL(send_ue_recv_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::SendUERecvGradKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16) {}