// graph_send_recv_funcs.h
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <thrust/device_vector.h>
#include <thrust/fill.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/send_u_recv_kernel.h"

namespace phi {

// SUM reduction functor: atomically adds params[in_i] into output[out_i]
// via phi::CudaAtomicAdd.
//
// The flattened offsets are taken as int64_t rather than `const IndexT&`.
// GraphSendRecvCUDAKernel forms them in int64_t (row * slice_size + column);
// binding that value to `const IndexT&` with IndexT = int32_t materialized a
// truncated int32_t temporary, so any offset above INT32_MAX addressed the
// wrong element. IndexT is kept as a template parameter so existing
// `GraphSendRecvSumCUDAFunctor<T, IndexT>` instantiations still compile.
template <typename T, typename IndexT>
struct GraphSendRecvSumCUDAFunctor {
  DEVICE inline void operator()(const T* params,
                                T* output,
                                int64_t in_i,
                                int64_t out_i) {
    phi::CudaAtomicAdd(output + out_i, params[in_i]);
  }
};

// MAX reduction functor: atomically combines params[in_i] into output[out_i]
// via phi::CudaAtomicMax.
//
// Offsets are int64_t (not `const IndexT&`) so that the 64-bit offsets built
// by GraphSendRecvCUDAKernel are not truncated to int32_t when
// IndexT = int32_t. IndexT is retained for source compatibility.
template <typename T, typename IndexT>
struct GraphSendRecvMaxCUDAFunctor {
  DEVICE inline void operator()(const T* params,
                                T* output,
                                int64_t in_i,
                                int64_t out_i) {
    phi::CudaAtomicMax(output + out_i, params[in_i]);
  }
};

// MIN reduction functor: atomically combines params[in_i] into output[out_i]
// via phi::CudaAtomicMin.
//
// Offsets are int64_t (not `const IndexT&`) so that the 64-bit offsets built
// by GraphSendRecvCUDAKernel are not truncated to int32_t when
// IndexT = int32_t. IndexT is retained for source compatibility.
template <typename T, typename IndexT>
struct GraphSendRecvMinCUDAFunctor {
  DEVICE inline void operator()(const T* params,
                                T* output,
                                int64_t in_i,
                                int64_t out_i) {
    phi::CudaAtomicMin(output + out_i, params[in_i]);
  }
};

// Generic scatter-reduce driver.
//
// Both `params` and `output` are addressed as rows of `slice_size` contiguous
// elements. For each of the index_size * slice_size flat positions, the
// position is split into an edge number and a column. The functor then
// receives the flattened offset of row src_indices[edge] in `params` and of
// row dst_indices[edge] in `output`; the functor decides how to combine them.
// Offsets are formed in int64_t so row * slice_size cannot overflow a 32-bit
// IndexT.
// NOTE(review): iteration is delegated to CUDA_KERNEL_LOOP_TYPE; presumably a
// grid-stride loop, so any grid size covers the range — confirm the macro.
template <typename T, typename IndexT, typename Functor>
__global__ void GraphSendRecvCUDAKernel(const T* params,
                                        const IndexT* src_indices,
                                        const IndexT* dst_indices,
                                        T* output,
                                        size_t index_size,
                                        size_t slice_size,
                                        Functor functor) {
  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
    const int64_t edge = i / slice_size;
    const int64_t col = i - edge * slice_size;

    const IndexT src_row = src_indices[edge];
    const IndexT dst_row = dst_indices[edge];

    int64_t src_offset = src_row * slice_size + col;
    int64_t dst_offset = dst_row * slice_size + col;
    functor(params, output, src_offset, dst_offset);
  }
}

// Post-processing for the MAX reduction.
// Every element of `output` (input_size * slice_size elements) that still
// equals std::numeric_limits<T>::lowest() is overwritten with 0.
// NOTE(review): presumes the caller pre-filled `output` with lowest(), so the
// sentinel marks destinations that received no value — confirm at call site.
// A genuine result equal to lowest() is also reset to 0.
template <typename T>
__global__ void InputResetMaxCUDAKernel(T* output,
                                        size_t input_size,
                                        size_t slice_size) {
  const T untouched = std::numeric_limits<T>::lowest();
  CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
    if (output[i] == untouched) {
      output[i] = 0;
    }
  }
}

// Post-processing for the MIN reduction.
// Every element of `output` (input_size * slice_size elements) that still
// equals std::numeric_limits<T>::max() is overwritten with 0.
// NOTE(review): presumes the caller pre-filled `output` with max(), so the
// sentinel marks destinations that received no value — confirm at call site.
// A genuine result equal to max() is also reset to 0.
template <typename T>
__global__ void InputResetMinCUDAKernel(T* output,
                                        size_t input_size,
                                        size_t slice_size) {
  const T untouched = std::numeric_limits<T>::max();
  CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
    if (output[i] == untouched) {
      output[i] = 0;
    }
  }
}

// Builds dst_count: a histogram of destination indices.
// For each of the index_size entries of dst_indices, count[dst] is
// atomically incremented by 1.
// NOTE(review): assumes `count` is zero-initialized by the caller.
// NOTE(review): template parameter T is unused here; presumably kept so call
// sites can instantiate it like the other <T, IndexT> kernels — confirm
// before removing.
template <typename T, typename IndexT>
__global__ void ComputeCountCUDAKernel(int32_t* count,
                                       const IndexT* dst_indices,
                                       size_t index_size) {
  CUDA_KERNEL_LOOP_TYPE(i, index_size, int64_t) {
    phi::CudaAtomicAdd(&count[dst_indices[i]], 1);
  }
}

// Forward MEAN: turns per-destination sums into means.
// `output` is treated as input_size rows of slice_size elements. Every
// element of row r is divided by count[r]. Rows whose count is <= 1 are left
// untouched: dividing by 1 is a no-op, and skipping 0 avoids a division by
// zero for destinations that received nothing.
//
// `count` is only read, so it is declared `const int32_t*`, matching
// ManipulateMeanGradCUDAKernel's `const int32_t* dst_count`. Callers passing
// `int32_t*` still compile via the implicit qualification conversion. The
// count is loaded once per element instead of twice.
template <typename T>
__global__ void ManipulateMeanCUDAKernel(T* output,
                                         const int32_t* count,
                                         size_t input_size,
                                         size_t slice_size) {
  CUDA_KERNEL_LOOP_TYPE(i, input_size * slice_size, int64_t) {
    const int32_t row_count = count[i / slice_size];
    if (row_count > 1) {
      output[i] = output[i] / static_cast<T>(row_count);
    }
  }
}

// Backward MEAN.
// For every (edge, column) pair, params[src row] / dst_count[src_indices[edge]]
// is atomically added into output[dst row]. Rows are slice_size elements wide
// and offsets are formed in int64_t.
// NOTE(review): the divisor is looked up by the *src* index although the
// buffer is named dst_count. That is only right if the caller passes the
// forward pass's dst_index as `src_indices` (roles swapped for the backward
// pass) — confirm at the call site.
template <typename T, typename IndexT>
__global__ void ManipulateMeanGradCUDAKernel(const T* params,
                                             const IndexT* src_indices,
                                             const IndexT* dst_indices,
                                             T* output,
                                             size_t index_size,
                                             size_t slice_size,
                                             const int32_t* dst_count) {
  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
    const int64_t edge = i / slice_size;
    const int64_t col = i - edge * slice_size;

    const IndexT src_row = src_indices[edge];
    const IndexT dst_row = dst_indices[edge];
    const T divisor = static_cast<T>(dst_count[src_row]);

    const int64_t src_offset = src_row * slice_size + col;
    const int64_t dst_offset = dst_row * slice_size + col;
    phi::CudaAtomicAdd(output + dst_offset, params[src_offset] / divisor);
  }
}

// Backward MIN / MAX: routes the upstream gradient only to inputs that equal
// the reduced value.
// For every (edge, column) pair, params[in_i] is atomically added into
// output[out_i] when ptr_input[out_i] == ptr_output[in_i]. All tied inputs
// receive the full gradient.
// NOTE(review): indexing ptr_input by out_i and ptr_output by in_i implies the
// caller passes the forward dst_index as `src_indices` and the forward
// src_index as `dst_indices` — confirm at the call site.
//
// A branch is used instead of `grad * T(mask)`. Non-selected elements then
// issue no atomic at all, since their contribution is zero. An inf upstream
// gradient also no longer turns into NaN (inf * 0) in inputs that did not win
// the reduction.
template <typename T, typename IndexT>
__global__ void ManipulateMinMaxGradCUDAKernel(const T* params,
                                               const IndexT* src_indices,
                                               const IndexT* dst_indices,
                                               T* output,
                                               size_t index_size,
                                               size_t slice_size,
                                               const T* ptr_input,
                                               const T* ptr_output) {
  CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size, int64_t) {
    int64_t indices_i = i / slice_size;
    int64_t slice_i = i - indices_i * slice_size;
    IndexT src_i = src_indices[indices_i];
    IndexT dst_i = dst_indices[indices_i];
    int64_t in_i = src_i * slice_size + slice_i;
    int64_t out_i = dst_i * slice_size + slice_i;
    // Only inputs that produced the reduced value receive gradient.
    if (ptr_input[out_i] == ptr_output[in_i]) {
      phi::CudaAtomicAdd(output + out_i, params[in_i]);
    }
  }
}

}  // namespace phi