// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"

namespace paddle {
namespace operators {

namespace kps = paddle::operators::kernel_primitives;

struct DimensionsTransform {
  using DimVector = std::vector<int64_t>;
  typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
                               int, int);
  int64_t dim_size;
  DimVector out_dims;
  std::vector<DimVector> in_dims;

 private:
  // Extend each input tensor's dimensions to dim_size, aligning the
  // lower-rank shapes with the output according to the input variable 'axis'.
  void InputDimensionsExtend(int N, int axis) {
    for (auto &in_dim : in_dims) {
      int64_t in_idx = 0;
      if (in_dim.size() < dim_size) {
        DimVector tmp_dim(dim_size, 1);
        do {
          if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
            tmp_dim[axis] = in_dim[in_idx];
            in_idx++;
            axis++;
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %d-th dimension of input tensor is expected to be equal "
                "to the %d-th dimension of output tensor %d or to be 1, but "
                "received %d.",
                in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx]));
          }
        } while (in_idx < in_dim.size());
        in_dim.resize(dim_size);
        std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
      } else {
        do {
          if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
            in_idx++;
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %d-th dimension of input tensor is expected to be equal "
                "to the %d-th dimension of output tensor %d or to be 1, but "
                "received %d.",
                in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx]));
          }
        } while (in_idx < dim_size);
      }
      std::reverse(in_dim.begin(), in_dim.end());
    }
    std::reverse(out_dims.begin(), out_dims.end());
  }
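
  // Note (illustrative): the shapes are reversed above so that index 0 is the
  // innermost (fastest-varying) dimension, matching the reversed layout that
  // kps::details::BroadcastConfig expects (see the comment in LaunchKernel).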

  template <typename MergeFunctor>
  __inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
    auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
      (*vec)[m_idx - 1] =
          std::accumulate(vec->begin() + l_idx, vec->begin() + m_idx, 1,
                          std::multiplies<int64_t>());
      vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
    };

    int64_t i = 0;
    while (i < dim_size) {
      int cnt = 0;
      int low_idx = i;
      bool equal = true;
      do {
        merge_func(equal, in_dims, out_dims, i, N);
        if (equal) {
          i++;
          cnt++;
        } else {
          break;
        }
      } while (i < dim_size);

      if (cnt > 1) {
        for (auto &in_dim : in_dims) {
          VectorReorganise(&in_dim, low_idx, i);
        }
        VectorReorganise(&out_dims, low_idx, i);
        dim_size -= --cnt;
        i -= cnt;
      } else if (cnt < 1) {
        i++;
      }
    }
  }

 public:
  explicit DimensionsTransform(
      const std::vector<const framework::Tensor *> &ins,
      const framework::DDim &dims, int axis) {
    const int N = ins.size();
    dim_size = dims.size();
    out_dims = framework::vectorize<int64_t>(dims);
    in_dims.resize(N);
    for (int j = 0; j < N; ++j) {
      in_dims[j] = framework::vectorize<int64_t>(ins[j]->dims());
    }
    InputDimensionsExtend(N, axis);

    auto merge_sequential_dims = [](bool &equal,
                                    std::vector<DimVector> &in_dims,
                                    DimVector &out, int i, int num) {
      for (int j = 1; j < num; ++j) {
        // Accumulate with &= so a single mismatch keeps 'equal' false.
        equal &= (in_dims[0][i] == in_dims[j][i]);
      }
    };
    auto merge_sequential_one_dims = [](bool &equal,
                                        std::vector<DimVector> &in_dims,
                                        DimVector &out, int i, int num) {
      equal = in_dims[0][i] == 1;
      if (equal) {
        for (int j = 1; j < num; ++j) {
          // Accumulate with &= so every input must match the output dim.
          equal &= in_dims[j][i] == out[i];
        }
      }
    };
    // Merge the dimensions of the input tensors when consecutive equal
    // dimensions appear.
    MergeFunctor merge_ptr = merge_sequential_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);

    int min_idx = 0;
    int min_val = std::accumulate(in_dims[0].begin(), in_dims[0].end(), 1,
                                  std::multiplies<int64_t>());
    for (int j = 1; j < N; ++j) {
      int temp = std::accumulate(in_dims[j].begin(), in_dims[j].end(), 1,
                                 std::multiplies<int64_t>());
      min_val = min_val > temp ? temp : min_val;
      min_idx = min_val == temp ? j : min_idx;
    }
    std::swap(in_dims[0], in_dims[min_idx]);

    // Merge the dimensions of the input tensors when consecutive 1-valued
    // dimensions appear.
    merge_ptr = merge_sequential_one_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);
    std::swap(in_dims[min_idx], in_dims[0]);
  }
};
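
// A worked example of DimensionsTransform (illustrative, not from the
// original file): with input shapes {[20, 3, 4], [3, 4]}, output shape
// [20, 3, 4] and axis = 1, InputDimensionsExtend pads the second input to
// [1, 3, 4] and reverses all shapes, giving out_dims = {4, 3, 20}.
// MergeDimensions then fuses the two leading dimensions, which are equal
// across the inputs, leaving dim_size = 2, out_dims = {12, 20} and
// in_dims = {{12, 20}, {12, 1}}.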

template <typename T, int VecSize, int Rank, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
    T *dst, const T *__restrict__ src, uint32_t block_offset,
    const kps::details::BroadcastConfig<Rank> &config, int numel, int num,
    bool need_broadcast) {
  // numel: total number of elements in the output
  // num: number of elements handled in this pass
  if (need_broadcast) {
    kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(dst, src, block_offset,
                                                        config, numel);
  } else {
    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
  }
}

template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
          int Rank, bool IsBoundary = false>
__device__ void DealSegment(
    const framework::Array<const InT *__restrict__, Arity> &ins, OutT *out,
    const framework::Array<bool, Arity> &use_broadcast, uint32_t numel,
    const framework::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
    int num, Functor func) {
  InT args[Arity][VecSize];
  OutT result[VecSize];

  int block_offset = blockIdx.x * blockDim.x * VecSize;

#pragma unroll
  for (int i = 0; i < Arity; i++) {
    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
    LoadData<InT, VecSize, Rank, IsBoundary>(args[i], ins[i], block_offset,
                                             configs[i], numel, num,
                                             use_broadcast[i]);
  }

  const bool kCallElementwiseAny =
      platform::FunctionTraits<Functor>::has_pointer_args;
  ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity,
                             kCallElementwiseAny>()(func, args, result);
  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + block_offset, result,
                                                  num);
}
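
// DealSegment stages VecSize elements per input into the per-thread register
// tile args[Arity][VecSize], applies func through ElementwisePrimitiveCaller,
// and writes the VecSize results back; IsBoundary = true enables range checks
// for the partially filled tail block.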

template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
          int Rank>
__global__ void BroadcastKernel(
    framework::Array<const InT *__restrict__, Arity> ins, OutT *out,
    framework::Array<bool, Arity> use_broadcast, uint32_t numel,
    framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
    int main_tid, int tail_tid, Functor func) {
  int block_offset = blockIdx.x * blockDim.x * VecSize;
  // data offset of this block
  if (blockIdx.x < main_tid) {
    int num = blockDim.x * VecSize;  // blockIdx.x < main_tid
    DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, false>(
        ins, out, use_broadcast, numel, configs, num, func);
  } else {  // remainder
    int num = tail_tid;
    DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, true>(
        ins, out, use_broadcast, numel, configs, num, func);
  }
}

template <typename InT, typename OutT, typename Functor, int Arity, int VecSize,
          int Rank>
void LaunchKernel(const platform::CUDADeviceContext &ctx,
                  const std::vector<const framework::Tensor *> &ins,
                  framework::Tensor *out, Functor func,
                  DimensionsTransform merge_dims) {
  int numel = out->numel();
  const int threads = 256;
  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;

  int main_tid = numel / (VecSize * threads);
  int tail_tid = numel % (VecSize * threads);
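  // Worked numbers (illustrative): with numel = 1000, VecSize = 4 and
  // threads = 256, a full block covers 4 * 256 = 1024 elements, so
  // main_tid = 0 and tail_tid = 1000; the single launched block then takes
  // the IsBoundary = true path in BroadcastKernel.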
  auto stream = ctx.stream();
  OutT *out_data = out->data<OutT>();

  framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
  framework::Array<bool, Arity> use_broadcast;
  framework::Array<const InT *__restrict__, Arity> ins_data;

  for (int i = 0; i < Arity; i++) {
    use_broadcast[i] = (ins[i]->numel() != numel);
    ins_data[i] = ins[i]->data<InT>();
    if (use_broadcast[i]) {
      // Build the broadcast config. Dimensions are stored in reverse order:
      // if the data shape is [m, n], data_dim should be set to {n, m}.
      // E.g. if out's shape is [3, 45, 1], then out_dims = {1, 45, 3}.
      configs[i] = kps::details::BroadcastConfig<Rank>(
          merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
    }
  }

  BroadcastKernel<InT, OutT, Functor, Arity, VecSize,
                  Rank><<<blocks, threads, 0, stream>>>(
      ins_data, out_data, use_broadcast, numel, configs, main_tid, tail_tid,
      func);
}

template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
void LaunchBroadcastKernelForDifferentVecSize(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
    int axis, Functor func) {
  const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);

#define CALL_BROADCAST_FOR_DIM_SIZE(rank)                                     \
  case rank: {                                                                \
    LaunchKernel<InT, OutT, Functor, Arity, VecSize, rank>(ctx, ins, out,     \
                                                           func, merge_dims); \
  } break;

  switch (merge_dims.dim_size) {
    CALL_BROADCAST_FOR_DIM_SIZE(1);
    CALL_BROADCAST_FOR_DIM_SIZE(2);
    CALL_BROADCAST_FOR_DIM_SIZE(3);
    CALL_BROADCAST_FOR_DIM_SIZE(4);
    CALL_BROADCAST_FOR_DIM_SIZE(5);
    CALL_BROADCAST_FOR_DIM_SIZE(6);
    CALL_BROADCAST_FOR_DIM_SIZE(7);
    CALL_BROADCAST_FOR_DIM_SIZE(8);
    default: {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The maximum dimension of input tensor is expected to be less than "
          "%d, but received %d.\n",
          framework::DDim::kMaxRank, merge_dims.dim_size));
    }
  }
#undef CALL_BROADCAST_FOR_DIM_SIZE
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchBroadcastElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  using Traits = platform::FunctionTraits<Functor>;
  const int kArity =
      Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
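  // Illustrative note (assumes ElementwiseType encodes the input count, e.g.
  // kBinary == 2): a binary functor OutT f(InT, InT) has Traits::arity == 2.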
  PADDLE_ENFORCE_EQ(ins.size(), kArity,
                    platform::errors::InvalidArgument(
                        "The number of inputs is expected to be equal to the "
                        "arity of functor. But recieved: the number of inputs "
                        "is %d, the arity of functor is %d.",
                        ins.size(), kArity));
  PADDLE_ENFORCE_EQ(kArity, 2,
                    platform::errors::InvalidArgument(
                        "Currently, only binary broadcast is supported and "
                        "verified, but received %d.",
                        kArity));

  int in_vec_size = 4;
  framework::Tensor *out = (*outs)[0];
  for (auto *in : ins) {
    auto temp_size = platform::GetVectorizedSize<InT>(in->data<InT>());
    in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
                                            : in_vec_size;
  }
  int out_vec_size = platform::GetVectorizedSize<OutT>(out->data<OutT>());
  int vec_size = std::min(out_vec_size, in_vec_size);
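  // Illustrative (assumes typical allocator alignment): for float tensors
  // whose data pointers are 16-byte aligned, GetVectorizedSize returns 4, so
  // each thread moves four elements per memory transaction; weaker alignment
  // lowers vec_size to 2 or 1.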

  switch (vec_size) {
    case 4: {
      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 4>(
          ctx, ins, out, axis, func);
      break;
    }
    case 2: {
      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 2>(
          ctx, ins, out, axis, func);
      break;
    }
    case 1: {
      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 1>(
          ctx, ins, out, axis, func);
      break;
    }
    default: {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported vectorized size: %d !", vec_size));
      break;
    }
  }
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchElementwiseCudaKernel(
    const platform::CUDADeviceContext &cuda_ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  std::vector<int> dims_size;
  bool no_broadcast_flag = true;
  for (auto *in : ins) {
    // Accumulate with &= so one mismatching input cannot be overwritten by a
    // later matching one.
    no_broadcast_flag &= ins[0]->dims() == in->dims();
    dims_size.emplace_back(in->dims().size());
  }

  if (no_broadcast_flag) {
    LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
                                                       func);
  } else {
    axis = axis == -1
               ? *std::max_element(dims_size.begin(), dims_size.end()) -
                     *std::min_element(dims_size.begin(), dims_size.end())
               : axis;
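    // E.g. (illustrative): for input ranks 3 and 2, axis = -1 resolves to
    // 3 - 2 = 1, aligning the lower-rank tensor starting at dimension 1.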
    LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
                                                        axis, func);
  }
}

}  // namespace operators
}  // namespace paddle