broadcast_impl.cu 20.0 KB
Newer Older
W
wilfChen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

17
#include <vector>
W
wilfChen 已提交
18
#include <iostream>
19

L
liubuyu 已提交
20 21
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
W
wilfChen 已提交
22

W
wilfChen 已提交
23 24
// Basic functors
// Binary predicate: true when lhs is strictly greater than rhs.
template <typename T>
struct GreaterFunc {
  __device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) {
    // The relational operator already produces the boolean result.
    return lhs > rhs;
  }
};

W
wilfChen 已提交
29
// Binary predicate: true when lhs is strictly less than rhs.
template <typename T>
struct LessFunc {
  __device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) {
    // The relational operator already produces the boolean result.
    return lhs < rhs;
  }
};

W
wilfChen 已提交
34
// Binary functor returning the smaller operand; rhs is returned when lhs is not strictly less.
template <typename T>
struct MinimumFunc {
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    if (lhs < rhs) {
      return lhs;
    }
    return rhs;
  }
};

W
wilfChen 已提交
39
// Binary functor returning the larger operand; rhs is returned when lhs is not strictly greater.
template <typename T>
struct MaximumFunc {
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    if (lhs > rhs) {
      return lhs;
    }
    return rhs;
  }
};

W
wilfChen 已提交
44
// Binary functor computing lhs raised to the power rhs via pow(); the result is converted back to T.
template <typename T>
struct PowerFunc {
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    return static_cast<T>(pow(lhs, rhs));
  }
};

49
// half specialization: the power is evaluated in float and the result converted back to half.
template <>
struct PowerFunc<half> {
  __device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
    const float base = __half2float(lhs);
    const float exponent = __half2float(rhs);
    return __float2half(pow(base, exponent));
  }
};

W
wilfChen 已提交
56 57 58 59 60 61 62 63 64 65 66 67
// half2 specialization: each lane is widened to float, raised to its own exponent,
// and the pair is packed back to half2 with the round-to-nearest (_rn) conversion.
template <>
struct PowerFunc<half2> {
  __device__ __host__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
    const float2 bases = __half22float2(lhs);
    const float2 exponents = __half22float2(rhs);
    float2 powered;
    powered.x = pow(bases.x, exponents.x);
    powered.y = pow(bases.y, exponents.y);
    return __float22half2_rn(powered);
  }
};

// Binary functor returning lhs / rhs using T's own division operator.
template <typename T>
struct RealDivFunc {
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    const T quotient = lhs / rhs;
    return quotient;
  }
};

W
wilfChen 已提交
72
// Binary functor returning lhs / rhs using T's own division operator
// (same computation as RealDivFunc in this file).
template <typename T>
struct DivFunc {
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    const T quotient = lhs / rhs;
    return quotient;
  }
};

W
wilfChen 已提交
77
// Binary functor returning the product lhs * rhs.
template <typename T>
struct MulFunc {
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    const T product = lhs * rhs;
    return product;
  }
};

W
wilfChen 已提交
82
// Binary functor returning the difference lhs - rhs.
template <typename T>
struct SubFunc {
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    const T difference = lhs - rhs;
    return difference;
  }
};

W
wilfChen 已提交
87
// Binary functor returning the sum lhs + rhs.
template <typename T>
struct AddFunc {
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    const T sum = lhs + rhs;
    return sum;
  }
};

W
wilfChen 已提交
92 93
// Floor division. Both operands are converted to float before dividing (done to fix an
// accuracy issue); the floored quotient is then converted back to T on return.
template <typename T>
struct FloorDivFunc {
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    const float numerator = static_cast<float>(lhs);
    const float denominator = static_cast<float>(rhs);
    return floorf(numerator / denominator);
  }
};

// half specialization: divide in float, floor, and let the float result convert back to half.
template <>
struct FloorDivFunc<half> {
  __device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
    const float numerator = __half2float(lhs);
    const float denominator = __half2float(rhs);
    return floorf(numerator / denominator);
  }
};

// half2 specialization: each lane is divided and floored in float, then the pair is packed
// back to half2 with the round-to-nearest (_rn) conversion.
template <>
struct FloorDivFunc<half2> {
  __device__ __host__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
    const float2 numerators = __half22float2(lhs);
    const float2 denominators = __half22float2(rhs);
    float2 floored;
    floored.x = floorf(numerators.x / denominators.x);
    floored.y = floorf(numerators.y / denominators.y);
    return __float22half2_rn(floored);
  }
};

W
wilfChen 已提交
118
// Device-only binary functor: returns -rhs when lhs is negative, otherwise rhs unchanged.
// NOTE(review): presumably lhs is the forward input of Abs and rhs the incoming gradient - confirm with callers.
template <typename T>
struct AbsGradFunc {
  __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
    const T zero = 0.0;
    if (lhs < zero) {
      return -rhs;
    }
    return rhs;
  }
};

126
// half2 specialization of AbsGradFunc (device-only).
// The sign test is made independently for each lane: a lane of rhs is negated only when the
// matching lane of lhs is negative. Comparing whole half2 values with `<` yields a single bool
// that is true only if BOTH lanes compare true, which would give wrong results for pairs whose
// lanes have different signs, so the lanes are handled in float instead.
// Widening half to float and narrowing the (possibly negated) value back is exact.
template <>
struct AbsGradFunc<half2> {
  __device__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
    const float2 inputs = __half22float2(lhs);
    float2 grads = __half22float2(rhs);
    grads.x = inputs.x < 0.0f ? -grads.x : grads.x;
    grads.y = inputs.y < 0.0f ? -grads.y : grads.y;
    return __float22half2_rn(grads);
  }
};

W
wilfChen 已提交
134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204
// Element-wise comparison
// Grid-stride kernel: for every i in [0, nums) stores Func()(x0[i], x1[i]) into y[i].
// x0, x1 and y are each indexed up to nums - 1.
template <typename T, typename Func>
__global__ void ElewiseCmpKernel(const int nums, const T *x0, const T *x1, bool *y) {
  Func compare;
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  while (idx < nums) {
    y[idx] = compare(x0[idx], x1[idx]);
    idx += blockDim.x * gridDim.x;
  }
}

// Launches the element-wise comparison kernel selected by `op` on `stream`.
//   nums   - number of elements to compare; nothing is launched when nums <= 0
//   op     - BROADCAST_TYPE_GREATER or BROADCAST_TYPE_LESS; any other value is silently ignored
//   x0, x1 - input arrays passed straight to the kernel
//   y      - output array of bool results
// The launch is asynchronous; no error checking or synchronization is done here.
template <typename T>
void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream) {
  // A grid with zero blocks is an invalid launch configuration, so empty input is a no-op.
  if (nums <= 0) {
    return;
  }
  const int threads = 256;
  // Ceil-division written so that it cannot overflow for nums close to INT_MAX.
  const int blocks = (nums - 1) / threads + 1;
  switch (op) {
    case BROADCAST_TYPE_GREATER:
      ElewiseCmpKernel<T, GreaterFunc<T>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
      break;
    case BROADCAST_TYPE_LESS:
      ElewiseCmpKernel<T, LessFunc<T>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
      break;
    default:
      break;
  }
}

// Explicit instantiations of ElewiseCmp for float, half and int element types.
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const float *x0, const float *x1, bool *y,
                         cudaStream_t stream);
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const half *x0, const half *x1, bool *y,
                         cudaStream_t stream);
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int *x0, const int *x1, bool *y,
                         cudaStream_t stream);

// Element-wise arithmetic
// Grid-stride kernel: for every i in [0, nums) stores Func()(x0[i], x1[i]) into y[i].
// x0, x1 and y are each indexed up to nums - 1.
template <typename T, typename Func>
__global__ void ElewiseArithKernel(const int nums, const T *x0, const T *x1, T *y) {
  Func apply;
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  while (idx < nums) {
    y[idx] = apply(x0[idx], x1[idx]);
    idx += blockDim.x * gridDim.x;
  }
}

// Host-side dispatcher (overloads the kernel name): launches the element-wise arithmetic kernel
// with the functor matching `op` on `stream`.
//   nums   - number of elements of type T to process; nothing is launched when nums <= 0
//   op     - arithmetic operation; values without a case below are silently ignored
//   x0, x1 - input arrays passed straight to the kernel
//   y      - output array
// The launch is asynchronous; no error checking or synchronization is done here.
template <typename T>
void ElewiseArithKernel(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
  // A grid with zero blocks is an invalid launch configuration, so empty input is a no-op.
  if (nums <= 0) {
    return;
  }
  const int threads = 256;
  // Ceil-division written so that it cannot overflow for nums close to INT_MAX.
  const int blocks = (nums - 1) / threads + 1;
  switch (op) {
    case BROADCAST_TYPE_MINIMUM:
      ElewiseArithKernel<T, MinimumFunc<T>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
      break;
    case BROADCAST_TYPE_MAXIMUM:
      ElewiseArithKernel<T, MaximumFunc<T>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
      break;
    case BROADCAST_TYPE_POWER:
      ElewiseArithKernel<T, PowerFunc<T>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
      break;
    case BROADCAST_TYPE_REALDIV:
      ElewiseArithKernel<T, RealDivFunc<T>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
      break;
    case BROADCAST_TYPE_MUL:
      ElewiseArithKernel<T, MulFunc<T>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
      break;
    case BROADCAST_TYPE_SUB:
      ElewiseArithKernel<T, SubFunc<T>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
      break;
    case BROADCAST_TYPE_ADD:
      ElewiseArithKernel<T, AddFunc<T>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
      break;
    case BROADCAST_TYPE_FLOORDIV:
      ElewiseArithKernel<T, FloorDivFunc<T>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
      break;
    case BROADCAST_TYPE_ABSGRAD:
      ElewiseArithKernel<T, AbsGradFunc<T>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
      break;
    case BROADCAST_TYPE_DIV:
      ElewiseArithKernel<T, DivFunc<T>><<<blocks, threads, 0, stream>>>(nums, x0, x1, y);
      break;
    default:
      break;
  }
}

// Public entry point for element-wise arithmetic: forwards all arguments unchanged to the
// host-side ElewiseArithKernel dispatcher for element type T.
template <typename T>
void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
  ElewiseArithKernel<T>(nums, op, x0, x1, y, stream);
}

// half specialization of ElewiseArith.
// When the element count is even and the operation is not MINIMUM / MAXIMUM / ABSGRAD, the
// buffers are reinterpreted as half2 and nums / 2 packed pairs are processed instead.
// Those three operations stay on the scalar half path because `<` / `>` on half2 return true
// only if both lanes compare true.
// NOTE(review): the half2 path assumes x0, x1 and y are suitably aligned for half2 access - confirm.
template <>
void ElewiseArith(const int &nums, enum BroadcastOpType op, const half *x0, const half *x1, half *y,
                  cudaStream_t stream) {
  const bool even_count = (nums % 2 == 0);
  const bool needs_scalar_compare =
    (op == BROADCAST_TYPE_MINIMUM || op == BROADCAST_TYPE_MAXIMUM || op == BROADCAST_TYPE_ABSGRAD);
  if (!even_count || needs_scalar_compare) {
    ElewiseArithKernel(nums, op, x0, x1, y, stream);
    return;
  }

  const half2 *x0_pairs = reinterpret_cast<const half2 *>(x0);
  const half2 *x1_pairs = reinterpret_cast<const half2 *>(x1);
  half2 *y_pairs = reinterpret_cast<half2 *>(y);
  ElewiseArithKernel<half2>(nums / 2, op, x0_pairs, x1_pairs, y_pairs, stream);
}

// Explicit instantiations of ElewiseArith for float, half and int element types.
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const float *x0, const float *x1, float *y,
                           cudaStream_t stream);
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const half *x0, const half *x1, half *y,
                           cudaStream_t stream);
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int *x0, const int *x1, int *y,
                           cudaStream_t stream);

// Broadcast comparison
// Maps an output coordinate onto an input axis of extent `dim`: an axis of extent 1 always
// reads position 0, any other axis uses the coordinate unchanged.
__device__ __forceinline__ int Index(const int &index, const int &dim) {
  if (dim == 1) {
    return 0;
  }
  return index;
}

W
wilfChen 已提交
224 225 226 227 228
// Grid-stride kernel comparing two inputs over a 7-dimensional output.
//   l0..l6 - extents of x0;  r0..r6 - extents of x1;  d0..d6 - extents of the output y
// Each flat output position is split into 7 coordinates (d6 varies fastest); every coordinate
// is mapped through Index() so that input axes of extent 1 always read position 0, and
// Func()(x0[...], x1[...]) is written to y[pos].
template <typename T, typename Func>
__global__ void BroadcastCmpKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
                                   const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
                                   const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
                                   const int d4, const int d5, const int d6, const T *x0, const T *x1, bool *y) {
  const size_t total = d0 * d1 * d2 * d3 * d4 * d5 * d6;
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < total; pos += blockDim.x * gridDim.x) {
    // Peel coordinates off the flat position, fastest-varying axis first.
    size_t rest = pos;
    const int o = rest % d6;
    rest /= d6;
    const int n = rest % d5;
    rest /= d5;
    const int m = rest % d4;
    rest /= d4;
    const int l = rest % d3;
    rest /= d3;
    const int k = rest % d2;
    rest /= d2;
    const int j = rest % d1;
    rest /= d1;
    const int i = rest % d0;

    // Flat offset into x0, accumulated axis by axis.
    int l_index = Index(i, l0);
    l_index = l_index * l1 + Index(j, l1);
    l_index = l_index * l2 + Index(k, l2);
    l_index = l_index * l3 + Index(l, l3);
    l_index = l_index * l4 + Index(m, l4);
    l_index = l_index * l5 + Index(n, l5);
    l_index = l_index * l6 + Index(o, l6);

    // Flat offset into x1, accumulated the same way.
    int r_index = Index(i, r0);
    r_index = r_index * r1 + Index(j, r1);
    r_index = r_index * r2 + Index(k, r2);
    r_index = r_index * r3 + Index(l, r3);
    r_index = r_index * r4 + Index(m, r4);
    r_index = r_index * r5 + Index(n, r5);
    r_index = r_index * r6 + Index(o, r6);

    y[pos] = Func()(x0[l_index], x1[r_index]);
  }
}

W
wilfChen 已提交
257 258 259 260 261 262 263 264
// Private helper: launches BroadcastCmpKernel<T, Func> with enough 256-thread blocks to cover
// `size` output elements. Reads elements [0]..[6] of each dims vector.
template <typename T, typename Func>
static void LaunchBroadcastCmp(const int size, const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
                               const std::vector<int> &y_dims, const T *x0, const T *x1, bool *y,
                               cudaStream_t stream) {
  const int threads = 256;
  const int blocks = (size - 1) / threads + 1;
  BroadcastCmpKernel<T, Func><<<blocks, threads, 0, stream>>>(
    x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
    x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
    y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
}

// Launches the broadcasting comparison kernel selected by `op` on `stream`.
//   x0_dims, x1_dims, y_dims - extents of the two inputs and of the output; elements [0]..[6]
//                              of each vector are read, so each must hold at least 7 entries
//   op                       - BROADCAST_TYPE_GREATER or BROADCAST_TYPE_LESS; other values are ignored
//   x0, x1                   - input arrays passed straight to the kernel
//   y                        - output array of bool results
// Nothing is launched when the product of y_dims is not positive. The launch is asynchronous;
// no error checking or synchronization is done here.
template <typename T>
void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
                  enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream) {
  int size = 1;
  for (auto d : y_dims) {
    size *= d;
  }
  // A grid with zero blocks is an invalid launch configuration, so an empty output is a no-op.
  if (size <= 0) {
    return;
  }

  switch (op) {
    case BROADCAST_TYPE_GREATER:
      return LaunchBroadcastCmp<T, GreaterFunc<T>>(size, x0_dims, x1_dims, y_dims, x0, x1, y, stream);
    case BROADCAST_TYPE_LESS:
      return LaunchBroadcastCmp<T, LessFunc<T>>(size, x0_dims, x1_dims, y_dims, x0, x1, y, stream);
    default:
      break;
  }
}

W
wilfChen 已提交
281 282 283 284 285 286 287 288 289
// Explicit instantiations of BroadcastCmp for float, half and int element types.
template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
                           const std::vector<int> &y_dims, enum BroadcastOpType op, const float *x0, const float *x1,
                           bool *y, cudaStream_t stream);
template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
                           const std::vector<int> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
                           bool *y, cudaStream_t stream);
template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
                           const std::vector<int> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
                           bool *y, cudaStream_t stream);
W
wilfChen 已提交
290

W
wilfChen 已提交
291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321
// Broadcast arithmetic
// Grid-stride kernel applying a binary functor over a 7-dimensional output.
//   l0..l6 - extents of x0;  r0..r6 - extents of x1;  d0..d6 - extents of the output y
// Each flat output position is split into 7 coordinates (d6 varies fastest); every coordinate
// is mapped through Index() so that input axes of extent 1 always read position 0, and
// Func()(x0[...], x1[...]) is written to y[pos].
template <typename T, typename Func>
__global__ void BroadcastArithKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
                                     const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
                                     const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
                                     const int d4, const int d5, const int d6, const T *x0, const T *x1, T *y) {
  const size_t total = d0 * d1 * d2 * d3 * d4 * d5 * d6;
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < total; pos += blockDim.x * gridDim.x) {
    // Peel coordinates off the flat position, fastest-varying axis first.
    size_t rest = pos;
    const int o = rest % d6;
    rest /= d6;
    const int n = rest % d5;
    rest /= d5;
    const int m = rest % d4;
    rest /= d4;
    const int l = rest % d3;
    rest /= d3;
    const int k = rest % d2;
    rest /= d2;
    const int j = rest % d1;
    rest /= d1;
    const int i = rest % d0;

    // Flat offset into x0, accumulated axis by axis.
    int l_index = Index(i, l0);
    l_index = l_index * l1 + Index(j, l1);
    l_index = l_index * l2 + Index(k, l2);
    l_index = l_index * l3 + Index(l, l3);
    l_index = l_index * l4 + Index(m, l4);
    l_index = l_index * l5 + Index(n, l5);
    l_index = l_index * l6 + Index(o, l6);

    // Flat offset into x1, accumulated the same way.
    int r_index = Index(i, r0);
    r_index = r_index * r1 + Index(j, r1);
    r_index = r_index * r2 + Index(k, r2);
    r_index = r_index * r3 + Index(l, r3);
    r_index = r_index * r4 + Index(m, r4);
    r_index = r_index * r5 + Index(n, r5);
    r_index = r_index * r6 + Index(o, r6);

    y[pos] = Func()(x0[l_index], x1[r_index]);
  }
}

W
wilfChen 已提交
325 326 327 328 329 330 331
// Private helper: launches BroadcastArithKernel<T, Func> with enough 256-thread blocks to cover
// `size` output elements. Reads elements [0]..[6] of each dims vector.
template <typename T, typename Func>
static void LaunchBroadcastArith(const int size, const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
                                 const std::vector<int> &y_dims, const T *x0, const T *x1, T *y,
                                 cudaStream_t stream) {
  const int threads = 256;
  const int blocks = (size - 1) / threads + 1;
  BroadcastArithKernel<T, Func><<<blocks, threads, 0, stream>>>(
    x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
    x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
    y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
}

// Launches the broadcasting arithmetic kernel selected by `op` on `stream`.
//   x0_dims, x1_dims, y_dims - extents of the two inputs and of the output; elements [0]..[6]
//                              of each vector are read, so each must hold at least 7 entries
//   op                       - arithmetic operation; values without a case below are ignored
//   x0, x1                   - input arrays passed straight to the kernel
//   y                        - output array
// Nothing is launched when the product of y_dims is not positive. The launch is asynchronous;
// no error checking or synchronization is done here.
template <typename T>
void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
                    enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
  int size = 1;
  for (auto d : y_dims) {
    size *= d;
  }
  // A grid with zero blocks is an invalid launch configuration, so an empty output is a no-op.
  if (size <= 0) {
    return;
  }

  switch (op) {
    case BROADCAST_TYPE_MAXIMUM:
      return LaunchBroadcastArith<T, MaximumFunc<T>>(size, x0_dims, x1_dims, y_dims, x0, x1, y, stream);
    case BROADCAST_TYPE_MINIMUM:
      return LaunchBroadcastArith<T, MinimumFunc<T>>(size, x0_dims, x1_dims, y_dims, x0, x1, y, stream);
    case BROADCAST_TYPE_POWER:
      return LaunchBroadcastArith<T, PowerFunc<T>>(size, x0_dims, x1_dims, y_dims, x0, x1, y, stream);
    case BROADCAST_TYPE_REALDIV:
      return LaunchBroadcastArith<T, RealDivFunc<T>>(size, x0_dims, x1_dims, y_dims, x0, x1, y, stream);
    case BROADCAST_TYPE_MUL:
      return LaunchBroadcastArith<T, MulFunc<T>>(size, x0_dims, x1_dims, y_dims, x0, x1, y, stream);
    case BROADCAST_TYPE_SUB:
      return LaunchBroadcastArith<T, SubFunc<T>>(size, x0_dims, x1_dims, y_dims, x0, x1, y, stream);
    case BROADCAST_TYPE_ADD:
      return LaunchBroadcastArith<T, AddFunc<T>>(size, x0_dims, x1_dims, y_dims, x0, x1, y, stream);
    case BROADCAST_TYPE_FLOORDIV:
      return LaunchBroadcastArith<T, FloorDivFunc<T>>(size, x0_dims, x1_dims, y_dims, x0, x1, y, stream);
    case BROADCAST_TYPE_ABSGRAD:
      return LaunchBroadcastArith<T, AbsGradFunc<T>>(size, x0_dims, x1_dims, y_dims, x0, x1, y, stream);
    case BROADCAST_TYPE_DIV:
      return LaunchBroadcastArith<T, DivFunc<T>>(size, x0_dims, x1_dims, y_dims, x0, x1, y, stream);
    default:
      break;
  }
}

W
wilfChen 已提交
388 389 390 391 392 393 394 395 396
// Explicit instantiations of BroadcastArith for float, half and int element types.
template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
                             const std::vector<int> &y_dims, enum BroadcastOpType op, const float *x0, const float *x1,
                             float *y, cudaStream_t stream);
template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
                             const std::vector<int> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
                             half *y, cudaStream_t stream);
template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
                             const std::vector<int> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
                             int *y, cudaStream_t stream);
W
wilfChen 已提交
397

W
wilfChen 已提交
398
// BroadcastTo
// Grid-stride kernel copying a 4-dimensional input into a 4-dimensional output.
//   i0..i3 - extents of the input;  o0..o3 - extents of the output
// Each flat output position is split into 4 coordinates (o3 varies fastest); every coordinate
// is mapped through Index() so that input axes of extent 1 always read position 0.
template <typename T>
__global__ void BroadcastToKernel(const int i0, const int i1, const int i2, const int i3, const int o0, const int o1,
                                  const int o2, const int o3, const T *input_addr, T *output_addr) {
  const size_t total = o0 * o1 * o2 * o3;
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < total; pos += blockDim.x * gridDim.x) {
    // Peel coordinates off the flat position, fastest-varying axis first.
    size_t rest = pos;
    const int l = rest % o3;
    rest /= o3;
    const int k = rest % o2;
    rest /= o2;
    const int j = rest % o1;
    rest /= o1;
    const int i = rest % o0;

    // Flat offset into the input, accumulated axis by axis.
    int input_idx = Index(i, i0);
    input_idx = input_idx * i1 + Index(j, i1);
    input_idx = input_idx * i2 + Index(k, i2);
    input_idx = input_idx * i3 + Index(l, i3);

    output_addr[pos] = input_addr[input_idx];
  }
}

// Launches BroadcastToKernel on `stream` for an output of o0 * o1 * o2 * o3 elements.
//   i0..i3      - extents of the input
//   o0..o3      - extents of the output
//   input_addr  - source array passed straight to the kernel
//   output_addr - destination array
// The launch configuration comes from the GET_BLOCKS / GET_THREADS macros
// (presumably provided by cuda_common.h - confirm). The launch is asynchronous.
template <typename T>
void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
                 const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream) {
  const int output_count = o0 * o1 * o2 * o3;
  BroadcastToKernel<T><<<GET_BLOCKS(output_count), GET_THREADS, 0, stream>>>(i0, i1, i2, i3, o0, o1, o2, o3,
                                                                           input_addr, output_addr);
}

// Explicit instantiations of BroadcastTo for float and half element types.
template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
                          const int &o2, const int &o3, const float *input_addr, float *output_addr,
                          cudaStream_t stream);
template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
                          const int &o2, const int &o3, const half *input_addr, half *output_addr, cudaStream_t stream);