// segment_pooling.cu
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/segment_pooling.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_param_config.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace operators {

// Shorthand for framework::Tensor inside paddle::operators.
using Tensor = framework::Tensor;

// Segment-mean forward kernel.
//
// Work is split into "stripes": stripe_index selects one inner column
// (segment_offset = stripe_index % inner_dim_size) and one tile of up to
// DimTileSize consecutive rows starting at dim_index_base. Every loop
// iteration runs two phases:
//   1. Only when segment_offset == 0: count the rows of each segment id seen
//      in the tile and accumulate the counts into summed_ids.
//   2. For every stripe: sum the input values of each segment in the tile and
//      accumulate sum / summed_ids[segment] into output.
//
// Parameters:
//   segment_ids         per-row segment id; read at dim_index_base - 1 (when
//                       dim_index_base > 0) and at every row of the tile.
//   input               values, indexed as row * inner_dim_size + column.
//   output              result, indexed as segment_id * inner_dim_size +
//                       column.
//   summed_ids          per-segment row count, indexed by segment id.
//   input_length_size   number of rows; bounds the height of the last tile.
//   inner_dim_size      number of columns per row.
//   output_length_size  only used as a sentinel value in phase 2.
//   total_stripe_count  trip count of the stripe loop.
//
// The kernel only adds to or overwrites entries of output and summed_ids for
// the segment ids it encounters; it never clears them. Presumably the caller
// initializes both buffers beforehand - TODO confirm.
//
// NOTE(review): the __syncthreads() between the two phases sits inside
// CUDA_KERNEL_LOOP. It is a block-level barrier, so it cannot order phase 1 of
// stripes handled by other blocks (or by other loop iterations) before phase 2
// of this stripe, although a segment's count can be contributed by several
// tiles through CudaAtomicAdd. Assuming CUDA_KERNEL_LOOP expands to a
// grid-stride for loop, the barrier is also reached only by threads that
// enter the loop body. Phase 2 may therefore divide by a partial count -
// confirm, and consider running the counting pass as a separate launch.
template <typename T, typename Index, int DimTileSize>
__global__ void SegmentMeanCustomKernel(
    const Index* segment_ids, const T* input, T* output, T* summed_ids,
    const Index input_length_size, const Index inner_dim_size,
    const Index output_length_size, const Index total_stripe_count) {
  CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) {
    // Column handled by this stripe.
    const Index segment_offset = stripe_index % inner_dim_size;
    // First row of the tile handled by this stripe.
    const Index dim_index_base =
        stripe_index / inner_dim_size * Index(DimTileSize);
    // The last tile may hold fewer than DimTileSize rows.
    const Index actual_height =
        min(Index(DimTileSize), input_length_size - dim_index_base);

    Index first_segment_id = segment_ids[dim_index_base];
    // -1 stands for "no previous row" when the tile starts at row 0.
    Index last_segment_id = -1;
    if (dim_index_base > 0) {
      last_segment_id = segment_ids[dim_index_base - 1];
    }
    // Phase 1: one stripe per tile (column 0) counts the rows per segment.
    if (segment_offset == 0) {
      T sum = T(0);
      for (Index j = 0; j < actual_height; j++) {
        Index current_segment_id = segment_ids[dim_index_base + j];
        // Note(ZHUI): following check may cause
        // cudaErrorLaunchOutOfResources.
        // PADDLE_ENFORCE(current_segment_id >= last_segment_id,
        //               "the segment ids should be sorted, but got "
        //               "segment_ids[%d]:%d > segment_ids[%d]:%d.",
        //               dim_index_base + j - 1, dim_index_base + j,
        //               last_segment_id, current_segment_id);

        // A larger id closes the run of last_segment_id; nothing is flushed
        // for the first row of the tile (j == 0).
        if (j > 0 && current_segment_id > last_segment_id) {
          if (last_segment_id == first_segment_id) {
            // The run that opens the tile may continue from the previous
            // tile, so its count is combined atomically.
            platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
          } else {
            // The run started and ended inside this tile: plain store.
            *(summed_ids + last_segment_id) = sum;
          }
          sum = T(0);
        }
        sum += T(1);
        last_segment_id = current_segment_id;
      }
      // The run still open at the end of the tile is always added atomically.
      platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
    }
    // ensure last_segment_id is the largest
    // (so the comparison below cannot fire for the first row, assuming every
    // segment id is smaller than output_length_size - TODO confirm).
    last_segment_id = output_length_size;
    __syncthreads();
    // Phase 2: accumulate the per-segment mean contribution of this stripe.
    T sum = T(0);
    for (Index j = 0; j < actual_height; j++) {
      Index current_segment_id = segment_ids[dim_index_base + j];
      if (current_segment_id > last_segment_id) {
        const Index output_index =
            last_segment_id * inner_dim_size + segment_offset;
        if (last_segment_id == first_segment_id) {
          platform::CudaAtomicAdd(output + output_index,
                                  sum / *(summed_ids + last_segment_id));
        } else {
          *(output + output_index) = sum / *(summed_ids + last_segment_id);
        }
        sum = T(0);
      }
      sum += input[(dim_index_base + j) * inner_dim_size + segment_offset];
      last_segment_id = current_segment_id;
    }
    // Flush the run that is still open at the end of the tile.
    const Index output_index =
        last_segment_id * inner_dim_size + segment_offset;
    platform::CudaAtomicAdd(output + output_index,
                            sum / *(summed_ids + last_segment_id));
  }
}

// Generic segment pooling forward kernel; the reduction is chosen by Pool.
//
// Each stripe (see Helper::calculate) covers one inner column and one tile of
// consecutive rows. The rows of the tile are reduced per segment id with
// pool.compute() and the partial result is merged into output.
//
// Parameters:
//   segment_ids  per-row segment id; must be sorted in non-decreasing order
//                (enforced below).
//   input        values, indexed as row * inner_dim_size + column.
//   output       result, indexed as segment_id * inner_dim_size + column.
//   h            helper providing total_stripe_count, inner_dim_size and
//                calculate().
//   pool         reduction policy providing initial(), compute() and atomic().
//
// Segment ids that are skipped between two consecutive rows get their output
// entries for this column set to 0. All other output entries are only merged
// into or overwritten; presumably the caller initializes output - TODO
// confirm.
template <typename T, typename Index, typename Helper, typename Pool>
__global__ void SegmentOpsKernel(const Index* segment_ids, const T* input,
                                 T* output, Helper h, Pool pool) {
  CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) {
    Index segment_offset, dim_index_base, actual_height;
    Index inner_dim_size = h.inner_dim_size;
    h.calculate(stripe_index, segment_offset, dim_index_base, actual_height);

    T minmax = pool.initial();
    Index first_segment_id = segment_ids[dim_index_base];
    // -1 is for the start value when interval_id = 0
    Index last_segment_id = -1;
    if (dim_index_base > 0) {
      last_segment_id = segment_ids[dim_index_base - 1];
    }

    for (Index j = 0; j < actual_height; j++) {
      Index current_segment_id = segment_ids[dim_index_base + j];
      // ensure the segment_ids is sorted.
      // The arguments follow the format order (index, value, index, value)
      // and are widened to long long so that %lld is correct for both 32-bit
      // and 64-bit Index types.
      PADDLE_ENFORCE(current_segment_id >= last_segment_id,
                     "The segment ids should be sorted, but got "
                     "segment_ids[%lld]:%lld > segment_ids[%lld]:%lld.",
                     static_cast<long long>(dim_index_base + j - 1),
                     static_cast<long long>(last_segment_id),
                     static_cast<long long>(dim_index_base + j),
                     static_cast<long long>(current_segment_id));

      if (current_segment_id > last_segment_id) {
        // reset the interval value which do not have corresponding ids.
        for (Index interval_id = last_segment_id + 1;
             interval_id < current_segment_id; ++interval_id) {
          *(output + interval_id * inner_dim_size + segment_offset) = 0;
        }
        // don't update result when j=0
        if (j > 0) {
          const Index output_index =
              last_segment_id * inner_dim_size + segment_offset;
          if (last_segment_id == first_segment_id) {
            // The run that opens the tile may continue from the previous
            // tile, so it has to be merged atomically.
            pool.atomic(output + output_index, minmax);
          } else {
            // The run started and ended inside this tile: plain store.
            *(output + output_index) = minmax;
          }
          minmax = pool.initial();
        }
      }
      pool.compute(
          input[(dim_index_base + j) * inner_dim_size + segment_offset],
          &minmax);
      last_segment_id = current_segment_id;
    }
    // The run still open at the end of the tile may continue in the next
    // tile, so it is always merged atomically.
    const Index output_index =
        last_segment_id * inner_dim_size + segment_offset;
    pool.atomic(output + output_index, minmax);
  }
}

// Backward kernel for index-selecting pools (MAX / MIN).
//
// For every row of the stripe's tile, the gradient of the row's segment is
// copied to in_grad wherever the input value equals the pooled output value.
// Positions that do not match are left untouched.
//
// Parameters:
//   segment_ids  per-row segment id.
//   input        forward input, indexed as row * inner_dim_size + column.
//   output       forward output, indexed as segment_id * inner_dim_size +
//                column.
//   out_grad     gradient w.r.t. output, indexed like output.
//   in_grad      gradient w.r.t. input, indexed like input.
//   h            helper providing total_stripe_count, inner_dim_size and
//                calculate().
template <typename T, typename Index, typename Helper>
__global__ void SegmentIndexGradKernel(const Index* segment_ids, const T* input,
                                       const T* output, const T* out_grad,
                                       T* in_grad, Helper h) {
  CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) {
    Index column, row_begin, row_count;
    h.calculate(stripe_index, column, row_begin, row_count);

    const Index row_end = row_begin + row_count;
    for (Index row = row_begin; row < row_end; ++row) {
      const Index in_pos = row * h.inner_dim_size + column;
      const Index out_pos = segment_ids[row] * h.inner_dim_size + column;
      // Only positions holding the pooled value receive a gradient.
      if (input[in_pos] != output[out_pos]) {
        continue;
      }
      in_grad[in_pos] = out_grad[out_pos];
    }
  }
}

// Reduction policy for segment MAX pooling.
template <class T>
class MaxPool {
 public:
  // Start value of a running maximum: -FLT_MAX converted to T.
  DEVICE inline T initial() {
    const T lowest = static_cast<T>(-FLT_MAX);
    return lowest;
  }

  // Folds x into the running maximum stored in *y.
  DEVICE inline void compute(const T& x, T* y) {
    if (!(*y > x)) {
      *y = x;
    }
  }

  // Merges val into *address via platform::CudaAtomicMax and forwards its
  // return value.
  DEVICE inline T atomic(T* address, const T val) {
    return platform::CudaAtomicMax(address, val);
  }
};

// Reduction policy for segment MIN pooling.
template <class T>
class MinPool {
 public:
  // Start value of a running minimum: FLT_MAX converted to T.
  DEVICE inline T initial() {
    const T highest = static_cast<T>(FLT_MAX);
    return highest;
  }

  // Folds x into the running minimum stored in *y.
  DEVICE inline void compute(const T& x, T* y) {
    if (!(*y < x)) {
      *y = x;
    }
  }

  // Merges val into *address via platform::CudaAtomicMin and forwards its
  // return value.
  DEVICE inline T atomic(T* address, const T val) {
    return platform::CudaAtomicMin(address, val);
  }
};

// Reduction policy for segment SUM pooling.
template <class T>
class SumPool {
 public:
  // Start value of a running sum: zero.
  DEVICE inline T initial() {
    const T zero = static_cast<T>(0);
    return zero;
  }

  // Adds x to the running sum stored in *y.
  DEVICE inline void compute(const T& x, T* y) {
    const T accumulated = *y + x;
    *y = accumulated;
  }

  // Merges val into *address via platform::CudaAtomicAdd and forwards its
  // return value.
  DEVICE inline T atomic(T* address, const T val) {
    return platform::CudaAtomicAdd(address, val);
  }
};

// Describes how a [input_length_size x inner_dim_size] input is cut into
// stripes: one stripe is one inner column combined with one tile of up to
// DimTileSize consecutive rows.
//
// Members:
//   input_total_size    total number of input elements.
//   input_length_size   number of rows (length of the segment id array).
//   output_length_size  number of output rows.
//   inner_dim_size      elements per row: input_total_size / input_length_size.
//   total_stripe_count  inner_dim_size * number of row tiles.
//   DimTileSize         rows per tile.
template <class T>
class ArrangeHelper {
 public:
  const T input_total_size;
  const T input_length_size;
  const T output_length_size;
  T inner_dim_size;
  T total_stripe_count;
  const T DimTileSize = 8;

  // a: total element count of the input, b: number of input rows,
  // c: number of output rows.
  ArrangeHelper(T a, T b, T c)
      : input_total_size(a),
        input_length_size(b),
        output_length_size(c),
        inner_dim_size(0),
        total_stripe_count(0) {
    // Without rows there is nothing to tile; leaving both derived sizes at
    // zero avoids the integer division by zero below.
    if (input_length_size <= 0) {
      return;
    }
    // Number of row tiles, rounded up so a partial last tile is included.
    T input_outer_dim_num_stripe =
        (input_length_size + DimTileSize - 1) / DimTileSize;
    inner_dim_size = input_total_size / input_length_size;
    total_stripe_count = inner_dim_size * input_outer_dim_num_stripe;
  }

  // Decomposes stripe_index into:
  //   segment_offset  the inner column of the stripe,
  //   dim_index_base  the first row of the stripe's tile,
  //   actual_height   the number of rows in that tile (the last tile may be
  //                   shorter than DimTileSize).
  DEVICE inline void calculate(T stripe_index, T& segment_offset,
                               T& dim_index_base, T& actual_height) const {
    segment_offset = stripe_index % inner_dim_size;
    dim_index_base = stripe_index / inner_dim_size * DimTileSize;
    actual_height = min(DimTileSize, input_length_size - dim_index_base);
  }
};

// Launches SegmentIndexGradKernel for the MAX and MIN pool types on the
// stream of ctx.
//
// Parameters:
//   ctx          device context providing the launch configuration and stream.
//   input        forward input tensor.
//   segment_ids  segment id tensor; its first dimension is the row count.
//   output       forward output tensor; its first dimension is the number of
//                output rows.
//   out_grad     gradient w.r.t. output.
//   in_grad      gradient w.r.t. input, written by the kernel.
//   pooltype     "MAX" or "MIN"; any other value raises InvalidArgument.
template <typename T, typename Index>
void SegmentPoolCUDAGradFunctor(const platform::CUDADeviceContext& ctx,
                                const framework::Tensor& input,
                                const framework::Tensor& segment_ids,
                                const framework::Tensor& output,
                                const framework::Tensor& out_grad,
                                framework::Tensor* in_grad,
                                const std::string pooltype = "SUM") {
  using Helper = ArrangeHelper<Index>;
  Helper helper(input.numel(), segment_ids.dims()[0], output.dims()[0]);
  auto launch = platform::GetGpuLaunchConfig1D(ctx, helper.total_stripe_count);

  const bool selects_by_index = (pooltype == "MAX") || (pooltype == "MIN");
  if (!selects_by_index) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Unsupported segment pooling grad operation, Only MAX, MIN "
        "available, but got %s.",
        pooltype));
  }

  SegmentIndexGradKernel<T, Index, Helper><<<
      launch.block_per_grid.x, launch.thread_per_block.x, 0, ctx.stream()>>>(
      segment_ids.data<Index>(), input.data<T>(), output.data<T>(),
      out_grad.data<T>(), in_grad->data<T>(), helper);
}

// Divides every row of x in place by the matching entry of y:
//   x[i * dim + j] /= y[i]   for 0 <= i < len, 0 <= j < dim.
//
// Blocks iterate over rows (stride gridDim.x) and the threads of a block
// iterate over the columns of a row (stride blockDim.x), so any 1-D launch
// configuration covers the whole buffer.
//
// Parameters:
//   x    buffer of len * dim elements, updated in place.
//   y    len divisors, one per row.
//   len  number of rows.
//   dim  number of elements per row.
template <typename T>
__global__ void SimpleDiv(T* x, const T* y, const int len, const int dim) {
  for (int i = blockIdx.x; i < len; i += gridDim.x) {
    // Every thread reads the divisor itself. Sharing it through a __shared__
    // variable would need a second barrier at the end of the iteration;
    // without one, thread 0 could overwrite the value for the next row while
    // other threads are still dividing the current one.
    const T y_i = y[i];
    // 64-bit offset: i * dim can exceed the range of int for large buffers.
    const int64_t base = static_cast<int64_t>(i) * dim;
    for (int j = threadIdx.x; j < dim; j += blockDim.x) {
      x[base + j] /= y_i;
    }
  }
}

// First pass of MEAN pooling: counts the rows of every segment id.
//
// Each loop iteration handles one tile of up to DimTileSize consecutive rows
// and accumulates the per-segment row counts of that tile into summed_ids.
// The kernel only adds to or overwrites the entries of the segment ids it
// encounters; presumably the caller initializes summed_ids - TODO confirm.
//
// Parameters:
//   segment_ids        per-row segment id.
//   summed_ids         per-segment row count, indexed by segment id.
//   input_length_size  number of rows; bounds the height of the last tile.
//   total_tile_count   number of row tiles.
template <typename T, typename Index, int DimTileSize>
__global__ void SegmentSumIdsKernel(const Index* segment_ids, T* summed_ids,
                                    const Index input_length_size,
                                    const Index total_tile_count) {
  CUDA_KERNEL_LOOP(tile_index, total_tile_count) {
    const Index dim_index_base = tile_index * Index(DimTileSize);
    const Index actual_height =
        min(Index(DimTileSize), input_length_size - dim_index_base);

    const Index first_segment_id = segment_ids[dim_index_base];
    Index last_segment_id = first_segment_id;
    T sum = T(0);
    for (Index j = 0; j < actual_height; j++) {
      const Index current_segment_id = segment_ids[dim_index_base + j];
      // A larger id closes the run of last_segment_id.
      if (j > 0 && current_segment_id > last_segment_id) {
        if (last_segment_id == first_segment_id) {
          // The run that opens the tile may continue from the previous tile,
          // so its count is combined atomically.
          platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
        } else {
          // The run started and ended inside this tile: plain store.
          *(summed_ids + last_segment_id) = sum;
        }
        sum = T(0);
      }
      sum += T(1);
      last_segment_id = current_segment_id;
    }
    // The run still open at the end of the tile may continue in the next
    // tile, so it is always added atomically.
    platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
  }
}

// Second pass of MEAN pooling: accumulates sum / count per segment.
//
// Each loop iteration handles one stripe (one inner column combined with one
// tile of up to DimTileSize rows). summed_ids must already hold the final
// per-segment row counts, i.e. the counting kernel has to be complete before
// this kernel starts.
//
// Parameters:
//   segment_ids         per-row segment id.
//   input               values, indexed as row * inner_dim_size + column.
//   output              result, indexed as segment_id * inner_dim_size +
//                       column.
//   summed_ids          final per-segment row counts.
//   input_length_size   number of rows; bounds the height of the last tile.
//   inner_dim_size      number of columns per row.
//   total_stripe_count  trip count of the stripe loop.
template <typename T, typename Index, int DimTileSize>
__global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
                                  T* output, const T* summed_ids,
                                  const Index input_length_size,
                                  const Index inner_dim_size,
                                  const Index total_stripe_count) {
  CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) {
    const Index segment_offset = stripe_index % inner_dim_size;
    const Index dim_index_base =
        stripe_index / inner_dim_size * Index(DimTileSize);
    const Index actual_height =
        min(Index(DimTileSize), input_length_size - dim_index_base);

    const Index first_segment_id = segment_ids[dim_index_base];
    Index last_segment_id = first_segment_id;
    T sum = T(0);
    for (Index j = 0; j < actual_height; j++) {
      const Index current_segment_id = segment_ids[dim_index_base + j];
      if (j > 0 && current_segment_id > last_segment_id) {
        const Index output_index =
            last_segment_id * inner_dim_size + segment_offset;
        const T mean_part = sum / summed_ids[last_segment_id];
        if (last_segment_id == first_segment_id) {
          // The run that opens the tile may continue from the previous tile.
          platform::CudaAtomicAdd(output + output_index, mean_part);
        } else {
          *(output + output_index) = mean_part;
        }
        sum = T(0);
      }
      sum += input[(dim_index_base + j) * inner_dim_size + segment_offset];
      last_segment_id = current_segment_id;
    }
    // Flush the run that is still open at the end of the tile.
    const Index output_index =
        last_segment_id * inner_dim_size + segment_offset;
    platform::CudaAtomicAdd(output + output_index,
                            sum / summed_ids[last_segment_id]);
  }
}

// CUDA forward functor for segment pooling.
//
// Parameters:
//   ctx          device context providing the launch configuration and stream.
//   input        input tensor; its rows are grouped by segment_ids.
//   segment_ids  segment id tensor; its first dimension is the row count.
//   output       output tensor; its first dimension is the number of segments.
//   summed_ids   per-segment row counts; required for "MEAN", unused
//                otherwise.
//   pooltype     one of "MEAN", "SUM", "MAX", "MIN"; any other value raises
//                InvalidArgument.
template <typename T, typename IndexT>
class SegmentPoolFunctor<platform::CUDADeviceContext, T, IndexT> {
 public:
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& input,
                  const framework::Tensor& segment_ids,
                  framework::Tensor* output,
                  framework::Tensor* summed_ids = nullptr,
                  const std::string pooltype = "SUM") {
    auto h = ArrangeHelper<IndexT>(input.numel(), segment_ids.dims()[0],
                                   output->dims()[0]);
    auto config = platform::GetGpuLaunchConfig1D(ctx, h.total_stripe_count);
    if (pooltype == "MEAN") {
      if (summed_ids == nullptr) {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "The summed_ids tensor must be provided for the MEAN segment "
            "pooling operation, but got nullptr."));
      }
      // Rows per tile; has to match the tile height used by ArrangeHelper.
      constexpr int kDimTileSize = 8;
      const IndexT tile_count =
          (h.input_length_size + kDimTileSize - 1) / kDimTileSize;
      auto count_config = platform::GetGpuLaunchConfig1D(ctx, tile_count);
      // The counts have to be final before any mean is formed. A barrier
      // inside a single kernel cannot guarantee that across blocks, so the
      // two passes are separate launches; launches on the same stream run in
      // order.
      SegmentSumIdsKernel<T, IndexT, kDimTileSize><<<
          count_config.block_per_grid.x, count_config.thread_per_block.x, 0,
          ctx.stream()>>>(segment_ids.data<IndexT>(), summed_ids->data<T>(),
                          h.input_length_size, tile_count);
      SegmentMeanKernel<T, IndexT, kDimTileSize><<<
          config.block_per_grid.x, config.thread_per_block.x, 0,
          ctx.stream()>>>(segment_ids.data<IndexT>(), input.data<T>(),
                          output->data<T>(), summed_ids->data<T>(),
                          h.input_length_size, h.inner_dim_size,
                          h.total_stripe_count);
    } else if (pooltype == "SUM") {
      SumPool<T> pool;
      SegmentOpsKernel<
          T, IndexT, ArrangeHelper<IndexT>,
          SumPool<T>><<<config.block_per_grid.x, config.thread_per_block.x, 0,
                        ctx.stream()>>>(segment_ids.data<IndexT>(),
                                        input.data<T>(), output->data<T>(), h,
                                        pool);
    } else if (pooltype == "MAX") {
      MaxPool<T> pool;
      SegmentOpsKernel<
          T, IndexT, ArrangeHelper<IndexT>,
          MaxPool<T>><<<config.block_per_grid.x, config.thread_per_block.x, 0,
                        ctx.stream()>>>(segment_ids.data<IndexT>(),
                                        input.data<T>(), output->data<T>(), h,
                                        pool);
    } else if (pooltype == "MIN") {
      MinPool<T> pool;
      SegmentOpsKernel<
          T, IndexT, ArrangeHelper<IndexT>,
          MinPool<T>><<<config.block_per_grid.x, config.thread_per_block.x, 0,
                        ctx.stream()>>>(segment_ids.data<IndexT>(),
                                        input.data<T>(), output->data<T>(), h,
                                        pool);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN "
          "available, but got %s.",
          pooltype));
    }
  }
};

// CUDA backward functor for segment pooling.
//
// Parameters:
//   context     device context providing the place, stream and launch
//               configuration.
//   input       forward input tensor.
//   output      forward output tensor.
//   out_grad    gradient w.r.t. output.
//   segments    segment id tensor.
//   in_grad     gradient w.r.t. input, written by this functor.
//   summed_ids  per-segment row counts; required for "MEAN", unused otherwise.
//   pooltype    one of "MEAN", "SUM", "MAX", "MIN"; any other value raises
//               InvalidArgument.
template <typename T, typename IndexT>
class SegmentPoolGradFunctor<platform::CUDADeviceContext, T, IndexT> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& output,
                  const framework::Tensor& out_grad,
                  const framework::Tensor& segments, framework::Tensor* in_grad,
                  const framework::Tensor* summed_ids = nullptr,
                  const std::string pooltype = "SUM") {
    if (pooltype == "MAX" || pooltype == "MIN") {
      SegmentPoolCUDAGradFunctor<T, IndexT>(context, input, segments, output,
                                            out_grad, in_grad, pooltype);
    } else if (pooltype == "MEAN") {
      // summed_ids defaults to nullptr but is dereferenced below, so reject a
      // missing tensor with a clear error instead of crashing.
      if (summed_ids == nullptr) {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "The summed_ids tensor must be provided for the MEAN segment "
            "pooling grad operation, but got nullptr."));
      }
      // mean_grad = out_grad with every row divided by its segment's row
      // count; the scaled rows are then gathered back to the input rows.
      // NOTE(review): the buffer is allocated with input.dims() although it
      // receives a copy of out_grad and is divided using output's shape -
      // confirm that TensorCopy resizes it accordingly.
      framework::Tensor mean_grad;
      mean_grad.mutable_data<T>(input.dims(), context.GetPlace());
      framework::TensorCopy(out_grad, context.GetPlace(), context, &mean_grad);
      int len = output.dims()[0];
      int dim = output.numel() / len;
      auto config = platform::GetGpuLaunchConfig1D(context, len);
      SimpleDiv<T><<<config.block_per_grid.x, config.thread_per_block.x, 0,
                     context.stream()>>>(mean_grad.data<T>(),
                                         summed_ids->data<T>(), len, dim);
      GPUGather<T, IndexT>(context, mean_grad, segments, in_grad);
    } else if (pooltype == "SUM") {
      // Every input row receives the gradient row of its segment.
      GPUGather<T, IndexT>(context, out_grad, segments, in_grad);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN "
          "available, but got %s.",
          pooltype));
    }
  }
};

// Explicit instantiations of the CUDA forward and backward functors for every
// supported combination of value type (float, double) and segment id type
// (int, int64_t).
using CUDA = paddle::platform::CUDADeviceContext;
template class SegmentPoolFunctor<CUDA, float, int>;
template class SegmentPoolFunctor<CUDA, float, int64_t>;
template class SegmentPoolFunctor<CUDA, double, int>;
template class SegmentPoolFunctor<CUDA, double, int64_t>;
template class SegmentPoolGradFunctor<CUDA, float, int>;
template class SegmentPoolGradFunctor<CUDA, float, int64_t>;
template class SegmentPoolGradFunctor<CUDA, double, int>;
template class SegmentPoolGradFunctor<CUDA, double, int64_t>;

}  // namespace operators
}  // namespace paddle