/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <limits>

#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/segment_pooling.h"

namespace phi {
namespace funcs {

// Shorthand alias for DenseTensor.
// NOTE(review): not referenced by the code visible in this file; consider
// removing it.
typedef DenseTensor Tensor;

// Counts the rows of every segment of the sorted `segment_ids` array.
//
// summed_ids[id] receives the number of entries equal to `id` (as T). Ids
// that are skipped between two consecutive ids inside a tile are set to 0.
//
// Each CUDA_KERNEL_LOOP iteration handles one tile of up to DimTileSize
// consecutive entries, starting at stripe_index * DimTileSize. The segment
// that opens a tile (it may have started in the previous tile) and the
// segment still open when the tile ends are accumulated with CudaAtomicAdd.
// Segments that start and end strictly inside the tile are written directly.
// Traps via PADDLE_ENFORCE if the ids are not sorted in non-decreasing order.
//
// NOTE(review): the atomic adds assume summed_ids was zero-initialised by
// the caller; confirm.
template <typename T, typename Index, int DimTileSize>
__global__ void SegmentSumIdsKernel(const Index* segment_ids,
                                    T* summed_ids,
                                    const Index input_length_size,
                                    const Index total_stripe_count) {
  CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) {
    const Index dim_index_base = stripe_index * Index(DimTileSize);
    const Index actual_height =
        min(Index(DimTileSize), input_length_size - dim_index_base);

    Index first_segment_id = segment_ids[dim_index_base];
    // -1 means "no id seen yet" for the very first tile. Otherwise start from
    // the id just before this tile, so that a segment carried over from the
    // previous tile is not mistaken for a new one.
    Index last_segment_id = -1;
    if (dim_index_base > 0) {
      last_segment_id = segment_ids[dim_index_base - 1];
    }

    T sum = T(0);
    for (Index j = 0; j < actual_height; j++) {
      Index current_segment_id = segment_ids[dim_index_base + j];
      // The arguments follow the order of the format string: previous
      // position, previous id, current position, current id. They are
      // widened to long long so that %lld is valid for both int and int64_t.
      PADDLE_ENFORCE(current_segment_id >= last_segment_id,
                     "the segment ids should be sorted, but got "
                     "segment_ids[%lld]:%lld > segment_ids[%lld]:%lld.",
                     static_cast<long long>(dim_index_base + j - 1),
                     static_cast<long long>(last_segment_id),
                     static_cast<long long>(dim_index_base + j),
                     static_cast<long long>(current_segment_id));

      if (current_segment_id > last_segment_id) {
        // Ids skipped between two consecutive ids have no rows: count 0.
        for (Index interval_id = last_segment_id + 1;
             interval_id < current_segment_id;
             ++interval_id) {
          *(summed_ids + interval_id) = 0;
        }
        // Flush the segment that just ended (nothing is pending at j == 0).
        if (j > 0) {
          // The tile's first segment may also be counted by the previous
          // tile, so it is accumulated atomically. Segments fully inside the
          // tile are only touched by this iteration.
          if (last_segment_id == first_segment_id) {
            paddle::platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
          } else {
            *(summed_ids + last_segment_id) = sum;
          }
          sum = T(0);
        }
      }
      sum += T(1);
      last_segment_id = current_segment_id;
    }
    // The last segment of the tile may continue into the next tile.
    paddle::platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
  }
}

/*
 * Segment mean over a sorted `segment_ids` array.
 *
 * `input` is read as input_length_size rows of inner_dim_size elements each.
 * For every column `col`, the rows whose id is `id` are summed and the sum
 * is divided by summed_ids[id]. The result is accumulated into
 * output[id * inner_dim_size + col]. summed_ids[id] is presumably the row
 * count produced by SegmentSumIdsKernel; confirm against the caller.
 *
 * Each CUDA_KERNEL_LOOP iteration handles one column of one tile of up to
 * DimTileSize consecutive rows. The segment that opens the tile (it may have
 * started in an earlier tile) and the segment still open when the tile ends
 * are combined with CudaAtomicAdd. Segments contained entirely inside the
 * tile are stored directly. Ids skipped between two consecutive ids in a
 * tile are written as 0. `output_length_size` is not read.
 *
 * NOTE(review): the atomic adds assume `output` was zero-filled beforehand;
 * confirm.
 */
template <typename T, typename Index, int DimTileSize>
__global__ void SegmentMeanKernel(const Index* segment_ids,
                                  const T* input,
                                  T* output,
                                  T* summed_ids,
                                  const Index input_length_size,
                                  const Index inner_dim_size,
                                  const Index output_length_size,
                                  const Index total_stripe_count) {
  CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) {
    const Index col = stripe_index % inner_dim_size;
    const Index row_begin =
        (stripe_index / inner_dim_size) * Index(DimTileSize);
    const Index tile_rows =
        min(Index(DimTileSize), input_length_size - row_begin);

    const Index head_id = segment_ids[row_begin];
    // Seed with the id just before this tile (-1 for the first tile) so that
    // a segment carried over from the previous tile is not treated as new.
    Index prev_id = (row_begin > 0) ? segment_ids[row_begin - 1] : Index(-1);

    T partial = T(0);
    for (Index j = 0; j < tile_rows; ++j) {
      const Index row = row_begin + j;
      const Index cur_id = segment_ids[row];

      if (cur_id > prev_id) {
        // Ids with no rows between prev_id and cur_id get a zero mean.
        for (Index gap = prev_id + 1; gap < cur_id; ++gap) {
          output[gap * inner_dim_size + col] = T(0);
        }
        // Flush the segment that just ended (nothing is pending at j == 0).
        if (j != 0) {
          T* dst = output + prev_id * inner_dim_size + col;
          const T contribution = partial / summed_ids[prev_id];
          if (prev_id == head_id) {
            paddle::platform::CudaAtomicAdd(dst, contribution);
          } else {
            *dst = contribution;
          }
          partial = T(0);
        }
      }

      partial += input[row * inner_dim_size + col];
      prev_id = cur_id;
    }

    // The last segment of the tile may continue into the next tile.
    paddle::platform::CudaAtomicAdd(output + prev_id * inner_dim_size + col,
                                    partial / summed_ids[prev_id]);
  }
}

// Generic segment reduction (sum / max / min, selected by `Pool`) over a
// sorted `segment_ids` array.
//
// `input` is viewed as rows of h.inner_dim_size elements, and row r belongs
// to segment segment_ids[r]. For every column c, the rows of a segment are
// folded with pool.compute, starting from pool.initial(). The result goes to
// output[id * h.inner_dim_size + c].
//
// Each CUDA_KERNEL_LOOP iteration handles one column of one tile of rows, as
// laid out by h.calculate. The segment that opens a tile and the segment
// still open at its end may be shared with neighbouring tiles, so they are
// merged with pool.atomic. Segments fully inside the tile are stored
// directly. Ids skipped between two consecutive ids in a tile are set to 0.
// Traps via PADDLE_ENFORCE if the ids are not sorted in non-decreasing order.
//
// NOTE(review): merging through pool.atomic assumes `output` was pre-filled
// with a value neutral for the reduction; confirm in the caller.
template <typename T, typename Index, typename Helper, typename Pool>
__global__ void __launch_bounds__(1024, 1) SegmentOpsKernel(
    const Index* segment_ids, const T* input, T* output, Helper h, Pool pool) {
  CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) {
    Index segment_offset, dim_index_base, actual_height;
    Index inner_dim_size = h.inner_dim_size;
    h.calculate(stripe_index, &segment_offset, &dim_index_base, &actual_height);

    T minmax = pool.initial();
    Index first_segment_id = segment_ids[dim_index_base];
    // -1 is for the start value when interval_id = 0
    Index last_segment_id = -1;
    if (dim_index_base > 0) {
      last_segment_id = segment_ids[dim_index_base - 1];
    }

    for (Index j = 0; j < actual_height; j++) {
      Index current_segment_id = segment_ids[dim_index_base + j];
      // Ensure segment_ids is sorted. The arguments follow the order of the
      // format string (previous position, previous id, current position,
      // current id). They are widened to long long so that %lld matches both
      // int and int64_t Index.
      PADDLE_ENFORCE(current_segment_id >= last_segment_id,
                     "The segment ids should be sorted, but got "
                     "segment_ids[%lld]:%lld > segment_ids[%lld]:%lld.",
                     static_cast<long long>(dim_index_base + j - 1),
                     static_cast<long long>(last_segment_id),
                     static_cast<long long>(dim_index_base + j),
                     static_cast<long long>(current_segment_id));

      if (current_segment_id > last_segment_id) {
        // Reset the ids that have no corresponding rows.
        for (Index interval_id = last_segment_id + 1;
             interval_id < current_segment_id;
             ++interval_id) {
          *(output + interval_id * inner_dim_size + segment_offset) = T(0);
        }
        // There is no finished segment to store yet when j == 0.
        if (j > 0) {
          const Index output_index =
              last_segment_id * inner_dim_size + segment_offset;
          // The tile's first segment may be shared with the previous tile.
          if (last_segment_id == first_segment_id) {
            pool.atomic(output + output_index, minmax);
          } else {
            *(output + output_index) = minmax;
          }
          minmax = pool.initial();
        }
      }
      pool.compute(
          input[(dim_index_base + j) * inner_dim_size + segment_offset],
          &minmax);
      last_segment_id = current_segment_id;
    }
    // The last segment of the tile may continue into the next tile.
    const Index output_index =
        last_segment_id * inner_dim_size + segment_offset;
    pool.atomic(output + output_index, minmax);
  }
}

// Gradient of MAX/MIN segment pooling.
//
// `input` is viewed as rows of h.inner_dim_size elements. Element (row, col)
// receives out_grad[segment_ids[row], col] when it equals the pooled value
// output[segment_ids[row], col]. Ties are not broken: every matching element
// gets the full upstream gradient. Non-matching in_grad elements are not
// written. NOTE(review): presumably in_grad is zero-filled by the caller;
// confirm.
//
// Work is split into column-by-tile units via h.calculate, one unit per
// CUDA_KERNEL_LOOP iteration.
template <typename T, typename Index, typename Helper>
__global__ void SegmentIndexGradKernel(const Index* segment_ids,
                                       const T* input,
                                       const T* output,
                                       const T* out_grad,
                                       T* in_grad,
                                       Helper h) {
  CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) {
    Index col, row_begin, tile_rows;
    h.calculate(stripe_index, &col, &row_begin, &tile_rows);
    const auto row_width = h.inner_dim_size;

    for (Index j = 0; j < tile_rows; ++j) {
      const Index row = row_begin + j;
      const Index in_idx = row * row_width + col;
      const Index out_idx = segment_ids[row] * row_width + col;
      if (input[in_idx] == output[out_idx]) {
        in_grad[in_idx] = out_grad[out_idx];
      }
    }
  }
}

// Max reduction policy used by SegmentOpsKernel.
template <class T>
class MaxPool {
 public:
  // Identity element for max: the lowest finite value representable in T.
  // The previous static_cast<T>(-FLT_MAX) is out of range for the int and
  // int64_t instantiations in this file, which is undefined behaviour, and
  // it is not the identity element for double.
  DEVICE inline T initial() { return std::numeric_limits<T>::lowest(); }
  // Folds x into the running value: *y = max(*y, x).
  DEVICE inline void compute(const T& x, T* y) { *y = *y > x ? *y : x; }
  // Merges val into *address via paddle::platform::CudaAtomicMax and
  // returns its result.
  DEVICE inline T atomic(T* address, const T val) {
    return paddle::platform::CudaAtomicMax(address, val);
  }
};

// Min reduction policy used by SegmentOpsKernel.
template <class T>
class MinPool {
 public:
  // Identity element for min: the largest finite value representable in T.
  // The previous static_cast<T>(FLT_MAX) is out of range for the int and
  // int64_t instantiations in this file, which is undefined behaviour, and
  // it is not the identity element for double.
  DEVICE inline T initial() { return std::numeric_limits<T>::max(); }
  // Folds x into the running value: *y = min(*y, x).
  DEVICE inline void compute(const T& x, T* y) { *y = *y < x ? *y : x; }
  // Merges val into *address via paddle::platform::CudaAtomicMin and
  // returns its result.
  DEVICE inline T atomic(T* address, const T val) {
    return paddle::platform::CudaAtomicMin(address, val);
  }
};

// Sum reduction policy used by SegmentOpsKernel: starts from zero,
// accumulates with +, and merges partial sums through
// paddle::platform::CudaAtomicAdd.
template <class T>
class SumPool {
 public:
  DEVICE inline T initial() { return T(0); }
  DEVICE inline void compute(const T& x, T* y) { *y += x; }
  DEVICE inline T atomic(T* address, const T val) {
    return paddle::platform::CudaAtomicAdd(address, val);
  }
};

// Describes how the segment-pooling input is split into work items.
//
// The input is viewed as input_length_size rows of inner_dim_size elements.
// One "stripe" is one column of one tile of up to DimTileSize consecutive
// rows, so there are inner_dim_size * ceil(rows / DimTileSize) stripes.
// Constructor arguments (as passed by the callers in this file):
//   a: total number of input elements
//   b: number of input rows (segment_ids.dims()[0])
//   c: number of output rows
template <class T>
class ArrangeHelper {
 public:
  const T input_total_size;
  const T input_length_size;
  const T output_length_size;
  T inner_dim_size;
  T total_stripe_count;
  const T DimTileSize = 8;

  ArrangeHelper(T a, T b, T c)
      : input_total_size(a), input_length_size(b), output_length_size(c) {
    const T input_outer_dim_num_stripe =
        (input_length_size + DimTileSize - 1) / DimTileSize;
    // With no rows there is nothing to split. Guarding here avoids a host
    // integer division by zero; the result is zero stripes.
    inner_dim_size =
        input_length_size > 0 ? input_total_size / input_length_size : T(0);
    total_stripe_count = inner_dim_size * input_outer_dim_num_stripe;
  }

  // Maps a stripe index to its column (segment_offset), its first row
  // (dim_index_base) and the number of rows it covers (actual_height, which
  // is shorter than DimTileSize only for the last tile).
  DEVICE inline void calculate(T stripe_index,
                               T* segment_offset,
                               T* dim_index_base,
                               T* actual_height) {
    *segment_offset = stripe_index % inner_dim_size;
    *dim_index_base = stripe_index / inner_dim_size * DimTileSize;
    *actual_height = min(DimTileSize, input_length_size - *dim_index_base);
  }
};

// Backward pass for MAX / MIN segment pooling.
//
// Launches SegmentIndexGradKernel, which copies out_grad of a segment into
// in_grad at every input element equal to that segment's pooled output
// value. Only "MAX" and "MIN" are accepted; any other pooltype raises
// InvalidArgument.
// NOTE(review): in_grad elements that do not match are not written here;
// presumably the caller zero-fills in_grad. Confirm.
template <typename T, typename Index>
void SegmentPoolCUDAGradFunctor(const phi::GPUContext& ctx,
                                const DenseTensor& input,
                                const DenseTensor& segment_ids,
                                const DenseTensor& output,
                                const DenseTensor& out_grad,
                                DenseTensor* in_grad,
                                const std::string pooltype = "SUM") {
  if (pooltype != "MAX" && pooltype != "MIN") {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Unsupported segment pooling grad operation, Only MAX, MIN "
        "available, but got %s.",
        pooltype));
  }
  // An empty input has no gradient to scatter. Returning early also avoids
  // dividing by a zero row count when building the helper and launching a
  // kernel with no work.
  if (input.numel() == 0 || segment_ids.numel() == 0) {
    return;
  }

  auto h = ArrangeHelper<Index>(
      input.numel(), segment_ids.dims()[0], output.dims()[0]);
  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(ctx, h.total_stripe_count);
  SegmentIndexGradKernel<T, Index, ArrangeHelper<Index>>
      <<<config.block_per_grid.x,
         config.thread_per_block.x,
         0,
         ctx.stream()>>>(segment_ids.data<Index>(),
                         input.data<T>(),
                         output.data<T>(),
                         out_grad.data<T>(),
                         in_grad->data<T>(),
                         h);
}

// Divides row i of x (len rows of dim elements, row-major) by y[i].
//
// Rows are distributed over blocks (blockIdx.x, stepping by gridDim.x). The
// elements of a row are distributed over the threads of the block.
template <typename T>
__global__ void SimpleDiv(T* x, const T* y, const int len, const int dim) {
  for (int i = blockIdx.x; i < len; i += gridDim.x) {
    // Every thread loads the divisor into a register. The previous version
    // staged it in a __shared__ variable with only one barrier per row, so
    // thread 0 could overwrite it for the next row while other threads of
    // the block were still dividing the current row by it (a data race).
    const T divisor = y[i];
    // Compute the row offset in 64 bits; i * dim can overflow int.
    T* row = x + static_cast<int64_t>(i) * dim;
    for (int j = threadIdx.x; j < dim; j += blockDim.x) {
      row[j] /= divisor;
    }
  }
}

// Forward segment pooling on GPU.
//
// Reduces the rows of `input` that share a segment id into the matching row
// of `output`. pooltype selects the reduction: "SUM", "MEAN", "MAX" or
// "MIN". segment_ids must be sorted; the SUM/MAX/MIN kernel and the MEAN id
// counting kernel trap otherwise.
// For "MEAN", `summed_ids` must be supplied: it first receives the number of
// rows per segment, which SegmentMeanKernel then uses as the divisor.
// Throws InvalidArgument for an unknown pooltype, or for "MEAN" without
// summed_ids.
// NOTE(review): the kernels accumulate atomically into `output` and
// `summed_ids`, so both are presumably pre-initialised by the caller.
// Confirm.
template <typename T, typename IndexT>
class SegmentPoolFunctor<phi::GPUContext, T, IndexT> {
 public:
  void operator()(const phi::GPUContext& ctx,
                  const DenseTensor& input,
                  const DenseTensor& segment_ids,
                  DenseTensor* output,
                  DenseTensor* summed_ids = nullptr,
                  const std::string pooltype = "SUM") {
    if (pooltype != "MEAN" && pooltype != "SUM" && pooltype != "MAX" &&
        pooltype != "MIN") {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN "
          "available, but got %s.",
          pooltype));
    }
    // No segment ids means no rows to pool. Returning here also avoids the
    // division by a zero row count when building ArrangeHelper.
    if (segment_ids.numel() == 0) {
      return;
    }

    // Tile height used by the MEAN kernels. It must equal
    // ArrangeHelper::DimTileSize (8), because SegmentMeanKernel re-derives
    // its tile layout from the helper's inner_dim_size and
    // total_stripe_count.
    constexpr int kDimTileSize = 8;

    if (pooltype == "MEAN") {
      if (summed_ids == nullptr) {
        PADDLE_THROW(phi::errors::InvalidArgument(
            "summed_ids must not be null when pooltype is MEAN."));
      }
      // Count the rows of every segment first. The ceil-division is done in
      // IndexT. Doing it in T (as before) made it a floating-point division
      // for float/double, which loses precision above 2^24 ids and could
      // yield an extra tile that reads past the end of segment_ids.
      const IndexT ids_length = segment_ids.numel();
      const IndexT ids_stripes =
          (ids_length + kDimTileSize - 1) / kDimTileSize;
      auto ids_config =
          phi::backends::gpu::GetGpuLaunchConfig1D(ctx, ids_stripes);
      SegmentSumIdsKernel<T, IndexT, kDimTileSize>
          <<<ids_config.block_per_grid.x,
             ids_config.thread_per_block.x,
             0,
             ctx.stream()>>>(segment_ids.data<IndexT>(),
                             summed_ids->data<T>(),
                             ids_length,
                             ids_stripes);
    }

    auto h = ArrangeHelper<IndexT>(
        input.numel(), segment_ids.dims()[0], output->dims()[0]);
    // Zero stripes means the rows have no elements (presumably output is
    // empty as well), so there is nothing left to launch.
    if (h.total_stripe_count == 0) {
      return;
    }
    auto config =
        phi::backends::gpu::GetGpuLaunchConfig1D(ctx, h.total_stripe_count);

    if (pooltype == "MEAN") {
      SegmentMeanKernel<T, IndexT, kDimTileSize>
          <<<config.block_per_grid.x,
             config.thread_per_block.x,
             0,
             ctx.stream()>>>(segment_ids.data<IndexT>(),
                             input.data<T>(),
                             output->data<T>(),
                             summed_ids->data<T>(),
                             h.input_length_size,
                             h.inner_dim_size,
                             h.output_length_size,
                             h.total_stripe_count);
    } else if (pooltype == "SUM") {
      SumPool<T> pool;
      SegmentOpsKernel<T, IndexT, ArrangeHelper<IndexT>, SumPool<T>>
          <<<config.block_per_grid.x,
             config.thread_per_block.x,
             0,
             ctx.stream()>>>(segment_ids.data<IndexT>(),
                             input.data<T>(),
                             output->data<T>(),
                             h,
                             pool);
    } else if (pooltype == "MAX") {
      MaxPool<T> pool;
      SegmentOpsKernel<T, IndexT, ArrangeHelper<IndexT>, MaxPool<T>>
          <<<config.block_per_grid.x,
             config.thread_per_block.x,
             0,
             ctx.stream()>>>(segment_ids.data<IndexT>(),
                             input.data<T>(),
                             output->data<T>(),
                             h,
                             pool);
    } else {
      // pooltype == "MIN"; validated at the top of this function.
      MinPool<T> pool;
      SegmentOpsKernel<T, IndexT, ArrangeHelper<IndexT>, MinPool<T>>
          <<<config.block_per_grid.x,
             config.thread_per_block.x,
             0,
             ctx.stream()>>>(segment_ids.data<IndexT>(),
                             input.data<T>(),
                             output->data<T>(),
                             h,
                             pool);
    }
  }
};

// Backward pass of segment pooling on GPU.
//
//   SUM      : in_grad rows are gathered from out_grad by segment id.
//   MEAN     : as SUM, after dividing row r of a copy of out_grad by
//              summed_ids[r]; summed_ids must therefore be provided.
//   MAX / MIN: delegated to SegmentPoolCUDAGradFunctor.
// Any other pooltype, or MEAN without summed_ids, raises InvalidArgument.
template <typename T, typename IndexT>
class SegmentPoolGradFunctor<phi::GPUContext, T, IndexT> {
 public:
  void operator()(const phi::GPUContext& dev_ctx,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& out_grad,
                  const DenseTensor& segments,
                  DenseTensor* in_grad,
                  const paddle::optional<DenseTensor>& summed_ids,
                  const std::string pooltype = "SUM") {
    if (pooltype == "MAX" || pooltype == "MIN") {
      SegmentPoolCUDAGradFunctor<T, IndexT>(
          dev_ctx, input, segments, output, out_grad, in_grad, pooltype);
    } else if (pooltype == "MEAN") {
      if (!summed_ids) {
        PADDLE_THROW(phi::errors::InvalidArgument(
            "summed_ids must be provided when pooltype is MEAN."));
      }
      // mean_grad holds a copy of out_grad, so allocate it with out_grad's
      // shape rather than input's.
      DenseTensor mean_grad;
      mean_grad.Resize(out_grad.dims());
      dev_ctx.template Alloc<T>(&mean_grad);
      paddle::framework::TensorCopy(
          out_grad, dev_ctx.GetPlace(), dev_ctx, &mean_grad);

      const int len = output.dims()[0];
      // Skip the division when the output has no rows. This also avoids an
      // integer division by zero when computing the row width.
      if (len > 0) {
        const int dim = output.numel() / len;
        auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, len);
        SimpleDiv<T><<<config.block_per_grid.x,
                       config.thread_per_block.x,
                       0,
                       dev_ctx.stream()>>>(
            mean_grad.data<T>(), summed_ids->data<T>(), len, dim);
      }
      phi::funcs::GPUGather<T, IndexT>(dev_ctx, mean_grad, segments, in_grad);
    } else if (pooltype == "SUM") {
      phi::funcs::GPUGather<T, IndexT>(dev_ctx, out_grad, segments, in_grad);
    } else {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN "
          "available, but got %s.",
          pooltype));
    }
  }
};

using GPU = phi::GPUContext;

// Explicit instantiations of the forward and backward functors for every
// supported (data type, segment-id type) combination.
template class SegmentPoolFunctor<GPU, float, int>;
template class SegmentPoolGradFunctor<GPU, float, int>;
template class SegmentPoolFunctor<GPU, float, int64_t>;
template class SegmentPoolGradFunctor<GPU, float, int64_t>;

template class SegmentPoolFunctor<GPU, double, int>;
template class SegmentPoolGradFunctor<GPU, double, int>;
template class SegmentPoolFunctor<GPU, double, int64_t>;
template class SegmentPoolGradFunctor<GPU, double, int64_t>;

template class SegmentPoolFunctor<GPU, int, int>;
template class SegmentPoolGradFunctor<GPU, int, int>;
template class SegmentPoolFunctor<GPU, int, int64_t>;
template class SegmentPoolGradFunctor<GPU, int, int64_t>;

template class SegmentPoolFunctor<GPU, int64_t, int>;
template class SegmentPoolGradFunctor<GPU, int64_t, int>;
template class SegmentPoolFunctor<GPU, int64_t, int64_t>;
template class SegmentPoolGradFunctor<GPU, int64_t, int64_t>;

}  // namespace funcs
}  // namespace phi