// util.cu.h
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using framework::Tensor;

// Grid-stride loop over the flat index range [0, n) for a 1-D launch: each
// thread starts at its global thread id (blockIdx.x * blockDim.x + threadIdx.x)
// and advances by the total number of launched threads (blockDim.x * gridDim.x),
// so every index in [0, n) is visited exactly once regardless of grid size.
// Threads of the same block may run a different number of iterations, so the
// loop body must not contain block-wide barriers such as __syncthreads().
// `i` is declared as `int`, so `n` is expected to fit in an int.
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// Writes `num`, converted to T, into every element of x[0, fill_num).
// Launch: any 1-D grid; the grid-stride loop covers all fill_num elements.
template <typename T>
__global__ void FillConstant(T* x, int num, int fill_num) {
  // The conversion is loop-invariant, so do it once per thread.
  const T value = static_cast<T>(num);
  CUDA_1D_KERNEL_LOOP(idx, fill_num) {
    x[idx] = value;
  }
}

// Copies the slice [start, end) along H (axis == 2) or W (axis == 3) of the
// contiguous (NC_num, H, W) buffer `x` into the contiguous buffer `output`:
//   axis == 2: output is laid out as (NC_num, end - start, W)
//   axis == 3: output is laid out as (NC_num, H, end - start)
// Any other axis value leaves `output` untouched.
// Launch: any 1-D grid (grid-stride over NC_num * slice size elements).
template <typename T>
__global__ void SliceOnAxis(const T* x,
                            const int NC_num,
                            const int H,
                            const int W,
                            const int axis,
                            const int start,
                            const int end,
                            T* output) {
  const int plane_size = H * W;
  const int slice_width = end - start;
  const int other_extent = (axis == 2) ? W : H;
  const int slice_size = other_extent * slice_width;
  CUDA_1D_KERNEL_LOOP(k, NC_num * slice_size) {
    const int plane = k / slice_size;
    const int offset_in_slice = k % slice_size;
    const int base = plane * plane_size;
    switch (axis) {
      case 2:
        // Rows [start, end) are contiguous in memory: copy linearly.
        output[k] = x[base + start * W + offset_in_slice];
        break;
      case 3: {
        // Columns [start, end): step over each row of the plane.
        const int r = offset_in_slice / slice_width;
        const int c = offset_in_slice % slice_width;
        output[k] = x[base + r * W + start + c];
        break;
      }
      default:
        break;
    }
  }
}

// Element-wise max of two equally sized bands of a contiguous (NC_num, H, W)
// tensor. The band of `end - start` rows (axis == 2) or columns (axis == 3)
// starting at `start` is compared with the band of the same size starting at
// `next_ind`. The larger value is written into `output` at the position of the
// `start` band. `output` uses the same (NC_num, H, W) indexing as `input`, and
// elements outside the `start` band are not written. Other axis values are a
// no-op.
//
// Each output element depends only on the two input elements at its own
// position, so no block barrier is needed. A __syncthreads() inside the
// grid-stride loop would be reached by a varying number of threads per block
// (whenever NC_num * band size is not a multiple of the grid size). That is
// undefined behavior and can hang the kernel.
template <typename T>
__global__ void MaxOut(const T* input,
                       const int next_ind,
                       const int NC_num,
                       const int H,
                       const int W,
                       const int axis,
                       const int start,
                       const int end,
                       T* output) {
  const int HW_num = H * W;
  const int length = axis == 2 ? W : H;
  const int sliced_len = end - start;
  const int cur_HW_num = length * sliced_len;
  CUDA_1D_KERNEL_LOOP(i, NC_num * cur_HW_num) {
    const int NC_id = i / cur_HW_num;
    const int HW_id = i % cur_HW_num;
    int cur_idx = 0;
    int next_idx = 0;
    if (axis == 2) {
      // Rows are contiguous: offset both bands by the flat index in the band.
      cur_idx = NC_id * HW_num + start * W + HW_id;
      next_idx = NC_id * HW_num + next_ind * W + HW_id;
    } else if (axis == 3) {
      const int col = HW_id % sliced_len;
      const int row = HW_id / sliced_len;
      cur_idx = NC_id * HW_num + row * W + start + col;
      next_idx = NC_id * HW_num + row * W + next_ind + col;
    } else {
      continue;
    }
    const T cur = input[cur_idx];
    const T next = input[next_idx];
    output[cur_idx] = cur > next ? cur : next;
  }
}

// Folds one line of a contiguous (NC_num, H, W) tensor into a running max.
// - axis == 2: row `index` is read, giving len = W values per (N, C) plane.
// - axis == 3: column `index` is read, giving len = H values per plane.
// For each j in [0, NC_num * len), if the value read is strictly greater than
// max_val[j], then max_val[j] takes that value and max_ind[j] is set to
// `index`. For any other axis value the value compared is T(0).
//
// Each j touches only max_val[j] and max_ind[j], so no block barrier is
// needed. A __syncthreads() inside the grid-stride loop would be reached by a
// varying number of threads per block, which is undefined behavior.
template <typename T>
__global__ void UpdateMaxInfo(const T* input,
                              const int NC_num,
                              const int H,
                              const int W,
                              const int axis,
                              const int index,
                              T* max_val,
                              int* max_ind) {
  const int length = axis == 2 ? W : H;
  const int HW_num = H * W;
  CUDA_1D_KERNEL_LOOP(i, NC_num * length) {
    const int NC_id = i / length;
    const int length_id = i % length;
    T val = static_cast<T>(0.);
    if (axis == 2) {
      val = input[NC_id * HW_num + index * W + length_id];
    } else if (axis == 3) {
      val = input[NC_id * HW_num + length_id * W + index];
    }
    if (val > max_val[i]) {
      max_val[i] = val;
      max_ind[i] = index;
    }
  }
}

// For each j in [0, NC_num * len), where len = W for axis == 2 and H for
// axis == 3: atomically adds the input value on row (axis 2) or column
// (axis 3) `start` into `output`. The target is row/column max_ind[j], at the
// same position along the other dimension. Both `input` and `output` use
// contiguous (NC_num, H, W) indexing. Other axis values are a no-op.
//
// No __syncthreads() is used. Threads do not read `output`, and a barrier
// inside the grid-stride loop would be reached by a varying number of threads
// per block, which is undefined behavior.
template <typename T>
__global__ void ScatterAddOnAxis(const T* input,
                                 const int start,
                                 const int* max_ind,
                                 const int NC_num,
                                 const int H,
                                 const int W,
                                 const int axis,
                                 T* output) {
  const int length = axis == 2 ? W : H;
  const int HW_num = H * W;
  CUDA_1D_KERNEL_LOOP(i, NC_num * length) {
    const int NC_id = i / length;
    const int length_id = i % length;
    const int id_ = max_ind[i];
    if (axis == 2) {
      platform::CudaAtomicAdd(output + NC_id * HW_num + id_ * W + length_id,
                              input[NC_id * HW_num + start * W + length_id]);
    } else if (axis == 3) {
      platform::CudaAtomicAdd(output + NC_id * HW_num + length_id * W + id_,
                              input[NC_id * HW_num + length_id * W + start]);
    }
  }
}

// Running (cumulative) max along H (axis == 2) or W (axis == 3) of a
// contiguous (NC_num, H, W) tensor. The scan runs from index 0 upward, or from
// the last index downward when `reverse` is true.
//
// Outputs, with len = W for axis 2 and H for axis 3:
//   max_val[j], max_ind[j] (j in [0, NC_num * len)): the maximum value seen on
//       line j and the index along the scanned axis where it was found, after
//       the full scan.
//   max_map[loc] (every input position loc): the index along the scanned axis
//       of the running max at that step of the scan.
// Ties keep the earlier index in scan order, because only a strictly greater
// value replaces the running max.
//
// The scan over the axis is sequential inside each thread. The grid-stride
// mapping j -> thread is the same at every step, so each thread only ever
// touches its own j and no block barrier is needed. A __syncthreads() inside
// the grid-stride loop would be reached by a varying number of threads per
// block, which is undefined behavior.
// The direction-aware loop condition makes an empty scanned axis a no-op
// instead of a non-terminating loop.
template <typename T>
__global__ void GetMaxInfo(const T* input,
                           const int NC_num,
                           const int H,
                           const int W,
                           const int axis,
                           const bool reverse,
                           T* max_val,
                           int* max_ind,
                           int* max_map) {
  const int end = axis == 2 ? H : W;  // extent of the scanned axis
  const int len = axis == 2 ? W : H;  // lines per (N, C) plane
  const int s = reverse ? end - 1 : 0;
  const int e = reverse ? -1 : end;  // exclusive stop
  const int step = reverse ? -1 : 1;
  for (int i = s; reverse ? i > e : i < e; i += step) {
    const bool first_step = (i == s);
    CUDA_1D_KERNEL_LOOP(j, NC_num * len) {
      const int NC_id = j / len;
      const int len_id = j % len;
      int loc = 0;
      if (axis == 2) {
        loc = NC_id * H * W + i * W + len_id;
      } else if (axis == 3) {
        loc = NC_id * H * W + len_id * W + i;
      }
      const T val = input[loc];
      // The first step seeds the running max; later steps replace it only
      // with a strictly greater value.
      if (first_step || val > max_val[j]) {
        max_val[j] = val;
        max_ind[j] = i;
      }
      max_map[loc] = max_ind[j];
    }
  }
}

// Gather step over a contiguous (NC_num, H, W) tensor. For every position i,
// reads the input value at the index max_map[i] along H (axis == 2) or, for
// any other axis value, along W:
//   axis == 2: output[n, h, w] = input[n, max_map[n, h, w], w]
//   otherwise: output[n, h, w] = input[n, h, max_map[n, h, w]]
// Launch: any 1-D grid (grid-stride over NC_num * H * W elements).
template <typename T>
__global__ void ScatterAddFw(const T* input,
                             const int* max_map,
                             const int NC_num,
                             const int H,
                             const int W,
                             const int axis,
                             T* output) {
  const int plane_size = H * W;
  CUDA_1D_KERNEL_LOOP(idx, NC_num * plane_size) {
    const int src = max_map[idx];
    const int plane_base = (idx / plane_size) * plane_size;
    int src_index;
    if (axis != 2) {
      const int row = (idx % plane_size) / W;
      src_index = plane_base + row * W + src;
    } else {
      const int col = idx % W;
      src_index = plane_base + src * W + col;
    }
    output[idx] = input[src_index];
  }
}

// Scatter counterpart of ScatterAddFw over a contiguous (NC_num, H, W) tensor.
// It uses the same index mapping but accumulates instead of gathering:
//   axis == 2: output[n, max_map[n, h, w], w] += input[n, h, w]
//   otherwise: output[n, h, max_map[n, h, w]] += input[n, h, w]
// Launch: any 1-D grid (grid-stride over NC_num * H * W elements).
template <typename T>
__global__ void ScatterAddBw(const T* input,
                             const int* max_map,
                             const int NC_num,
                             const int H,
                             const int W,
                             const int axis,
                             T* output) {
  const int plane_size = H * W;
  CUDA_1D_KERNEL_LOOP(idx, NC_num * plane_size) {
    const int target = max_map[idx];
    const int plane_base = (idx / plane_size) * plane_size;
    const int dst = (axis == 2)
                        ? plane_base + target * W + idx % W
                        : plane_base + ((idx % plane_size) / W) * W + target;
    // Different source positions can share a destination (max_map values
    // repeat), so accumulation must be atomic.
    platform::CudaAtomicAdd(output + dst, input[idx]);
  }
}

}  // namespace operators
}  // namespace paddle