segment_pooling.cu 18.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15 16
#include "paddle/phi/kernels/funcs/segment_pooling.h"

#include <algorithm>
#include <limits>

#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
24

25 26
namespace phi {
namespace funcs {
27

28
// Short alias for DenseTensor, visible throughout phi::funcs.
using Tensor = DenseTensor;
29 30

// Counts how many entries of `segment_ids` belong to each segment and
// accumulates the counts into `summed_ids`.
//
// Every iteration of CUDA_KERNEL_LOOP handles one stripe of at most
// DimTileSize consecutive ids, starting at stripe_index * DimTileSize.
//
// segment_ids        non-decreasing segment id per input row (checked below).
// summed_ids         per-segment counters; counts are added atomically for
//                    the segment that opens a stripe and for the one that
//                    closes it. NOTE(review): this relies on the buffer being
//                    zero-initialised by the caller -- confirm.
// input_length_size  number of entries in segment_ids.
// total_stripe_count number of stripes to process.
template <typename T, typename Index, int DimTileSize>
__global__ void SegmentSumIdsKernel(const Index* segment_ids,
                                    T* summed_ids,
                                    const Index input_length_size,
                                    const Index total_stripe_count) {
  CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) {
    const Index dim_index_base = stripe_index * Index(DimTileSize);
    // The last stripe may be shorter than DimTileSize.
    const Index actual_height =
        min(Index(DimTileSize), input_length_size - dim_index_base);

    Index first_segment_id = segment_ids[dim_index_base];
    // -1 is the start value when the stripe begins at row 0; otherwise the
    // id of the row just before this stripe.
    Index last_segment_id = -1;
    if (dim_index_base > 0) {
      last_segment_id = segment_ids[dim_index_base - 1];
    }
    T sum = T(0);
    for (Index j = 0; j < actual_height; j++) {
      Index current_segment_id = segment_ids[dim_index_base + j];
      // Arguments are passed as (index, value, index, value) to match the
      // format string, and widened to long long so that both int and int64_t
      // Index types print correctly.
      PADDLE_ENFORCE(current_segment_id >= last_segment_id,
                     "the segment ids should be sorted, but got "
                     "segment_ids[%lld]:%lld > segment_ids[%lld]:%lld.",
                     static_cast<long long>(dim_index_base + j - 1),
                     static_cast<long long>(last_segment_id),
                     static_cast<long long>(dim_index_base + j),
                     static_cast<long long>(current_segment_id));
      if (current_segment_id > last_segment_id) {
        // Ids skipped between two consecutive rows have no members.
        for (Index interval_id = last_segment_id + 1;
             interval_id < current_segment_id;
             ++interval_id) {
          *(summed_ids + interval_id) = 0;
        }
        // Nothing has been counted yet when j == 0.
        if (j > 0) {
          if (last_segment_id == first_segment_id) {
            // The segment that opens the stripe is flushed atomically;
            // presumably because the preceding stripe may add to the same
            // counter.
            paddle::platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
          } else {
            *(summed_ids + last_segment_id) = sum;
          }
          sum = T(0);
        }
      }
      sum += T(1);
      last_segment_id = current_segment_id;
    }
    // The segment still open at the end of the stripe is flushed atomically.
    paddle::platform::CudaAtomicAdd(summed_ids + last_segment_id, sum);
  }
}

// Accumulates, for every segment and every inner column, the sum of the input
// rows of that segment divided by the segment's entry in `summed_ids`.
//
// Each CUDA_KERNEL_LOOP iteration handles one (stripe, column) pair: up to
// DimTileSize consecutive rows of a single inner column.
//
// segment_ids        segment id per input row.
// input              read as input[row * inner_dim_size + column].
// output             written as output[segment * inner_dim_size + column].
// summed_ids         per-segment divisor.
// input_length_size  number of input rows.
// inner_dim_size     number of columns per row.
// output_length_size not used by this kernel.
// total_stripe_count number of (stripe, column) pairs to process.
template <typename T, typename Index, int DimTileSize>
__global__ void SegmentMeanKernel(const Index* segment_ids,
                                  const T* input,
                                  T* output,
                                  T* summed_ids,
                                  const Index input_length_size,
                                  const Index inner_dim_size,
                                  const Index output_length_size,
                                  const Index total_stripe_count) {
  CUDA_KERNEL_LOOP(stripe_index, total_stripe_count) {
    const Index column = stripe_index % inner_dim_size;
    const Index row_begin =
        stripe_index / inner_dim_size * Index(DimTileSize);
    const Index row_count =
        min(Index(DimTileSize), input_length_size - row_begin);

    const Index stripe_head_id = segment_ids[row_begin];
    // Id of the row preceding the stripe, or -1 when the stripe starts at 0.
    Index prev_id = -1;
    if (row_begin > 0) {
      prev_id = segment_ids[row_begin - 1];
    }

    T acc = T(0);
    for (Index r = 0; r < row_count; ++r) {
      const Index cur_id = segment_ids[row_begin + r];
      if (cur_id > prev_id) {
        // Zero the output of ids that are skipped between two rows.
        for (Index gap_id = prev_id + 1; gap_id < cur_id; ++gap_id) {
          output[gap_id * inner_dim_size + column] = T(0);
        }

        // A segment has just ended inside the stripe: flush its partial mean.
        if (r > 0) {
          const Index flush_pos = prev_id * inner_dim_size + column;
          if (prev_id == stripe_head_id) {
            paddle::platform::CudaAtomicAdd(output + flush_pos,
                                            acc / summed_ids[prev_id]);
          } else {
            output[flush_pos] = acc / summed_ids[prev_id];
          }
          acc = T(0);
        }
      }
      acc += input[(row_begin + r) * inner_dim_size + column];
      prev_id = cur_id;
    }

    // Flush whatever segment is still open at the end of the stripe.
    const Index tail_pos = prev_id * inner_dim_size + column;
    paddle::platform::CudaAtomicAdd(output + tail_pos,
                                    acc / summed_ids[prev_id]);
  }
}

// Generic segment reduction (the reduction itself is supplied by `pool`).
//
// Each CUDA_KERNEL_LOOP iteration handles one stripe as described by
// `h.calculate`: a run of `actual_height` rows starting at `dim_index_base`,
// restricted to the inner column `segment_offset`.
//
// segment_ids  non-decreasing segment id per input row (checked below).
// input        read as input[row * inner_dim_size + column].
// output       written as output[segment * inner_dim_size + column].
// h            stripe layout helper; provides total_stripe_count,
//              inner_dim_size and calculate().
// pool         reduction policy providing initial(), compute() and atomic().
template <typename T, typename Index, typename Helper, typename Pool>
__global__ void __launch_bounds__(1024, 1) SegmentOpsKernel(
    const Index* segment_ids, const T* input, T* output, Helper h, Pool pool) {
  CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) {
    Index segment_offset, dim_index_base, actual_height;
    Index inner_dim_size = h.inner_dim_size;
    h.calculate(stripe_index, &segment_offset, &dim_index_base, &actual_height);

    T minmax = pool.initial();
    Index first_segment_id = segment_ids[dim_index_base];
    // -1 is for the start value when interval_id = 0
    Index last_segment_id = -1;
    if (dim_index_base > 0) {
      last_segment_id = segment_ids[dim_index_base - 1];
    }

    for (Index j = 0; j < actual_height; j++) {
      Index current_segment_id = segment_ids[dim_index_base + j];
      // Ensure the segment_ids are sorted. Arguments are passed as
      // (index, value, index, value) to match the format string, and widened
      // to long long so that both int and int64_t Index types print
      // correctly.
      PADDLE_ENFORCE(current_segment_id >= last_segment_id,
                     "The segment ids should be sorted, but got "
                     "segment_ids[%lld]:%lld > segment_ids[%lld]:%lld.",
                     static_cast<long long>(dim_index_base + j - 1),
                     static_cast<long long>(last_segment_id),
                     static_cast<long long>(dim_index_base + j),
                     static_cast<long long>(current_segment_id));

      if (current_segment_id > last_segment_id) {
        // Reset the interval values which do not have corresponding ids.
        for (Index interval_id = last_segment_id + 1;
             interval_id < current_segment_id;
             ++interval_id) {
          *(output + interval_id * inner_dim_size + segment_offset) = T(0);
        }
        // Don't update the result when j == 0: nothing has been reduced yet.
        if (j > 0) {
          const Index output_index =
              last_segment_id * inner_dim_size + segment_offset;
          if (last_segment_id == first_segment_id) {
            // The segment that opens the stripe is merged atomically;
            // presumably because the preceding stripe may write the same
            // output element.
            pool.atomic(output + output_index, minmax);
          } else {
            *(output + output_index) = minmax;
          }
          minmax = pool.initial();
        }
      }
      pool.compute(
          input[(dim_index_base + j) * inner_dim_size + segment_offset],
          &minmax);
      last_segment_id = current_segment_id;
    }
    // The segment still open at the end of the stripe is merged atomically.
    const Index output_index =
        last_segment_id * inner_dim_size + segment_offset;
    pool.atomic(output + output_index, minmax);
  }
}

// Backward pass for index-selecting pools: every input element whose value
// equals the pooled output of its segment receives that segment's output
// gradient. Elements that do not match are left untouched in `in_grad`.
//
// segment_ids  segment id per input row.
// input        forward input, indexed as [row * inner_dim_size + column].
// output       forward output, indexed as [segment * inner_dim_size + column].
// out_grad     gradient w.r.t. output, same indexing as `output`.
// in_grad      gradient w.r.t. input, same indexing as `input`.
// h            stripe layout helper (total_stripe_count, inner_dim_size,
//              calculate()).
template <typename T, typename Index, typename Helper>
__global__ void SegmentIndexGradKernel(const Index* segment_ids,
                                       const T* input,
                                       const T* output,
                                       const T* out_grad,
                                       T* in_grad,
                                       Helper h) {
  CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) {
    Index column, row_begin, row_count;
    h.calculate(stripe_index, &column, &row_begin, &row_count);

    for (Index step = 0; step < row_count; ++step) {
      const Index row = row_begin + step;
      const Index segment = segment_ids[row];
      const Index in_pos = row * h.inner_dim_size + column;
      const Index out_pos = segment * h.inner_dim_size + column;
      // Only elements equal to the pooled value take the gradient.
      if (input[in_pos] == output[out_pos]) {
        in_grad[in_pos] = out_grad[out_pos];
      }
    }
  }
}

// Reduction policy for segment-wise maximum.
template <class T>
class MaxPool {
 public:
  // Identity element for max: the lowest finite value representable in T.
  // std::numeric_limits is used instead of casting -FLT_MAX, which is out of
  // range for integer T and is not the lowest value for double.
  DEVICE inline T initial() { return std::numeric_limits<T>::lowest(); }
  // Folds x into the running maximum *y.
  DEVICE inline void compute(const T& x, T* y) { *y = *y > x ? *y : x; }
  // Merges val into *address via CudaAtomicMax and forwards its return value.
  DEVICE inline T atomic(T* address, const T val) {
    return paddle::platform::CudaAtomicMax(address, val);
  }
};

// Reduction policy for segment-wise minimum.
template <class T>
class MinPool {
 public:
  // Identity element for min: the largest finite value representable in T.
  // std::numeric_limits is used instead of casting FLT_MAX, which is out of
  // range for integer T and is not the largest value for double.
  DEVICE inline T initial() { return std::numeric_limits<T>::max(); }
  // Folds x into the running minimum *y.
  DEVICE inline void compute(const T& x, T* y) { *y = *y < x ? *y : x; }
  // Merges val into *address via CudaAtomicMin and forwards its return value.
  DEVICE inline T atomic(T* address, const T val) {
    return paddle::platform::CudaAtomicMin(address, val);
  }
};

// Reduction policy for segment-wise sum.
template <class T>
class SumPool {
 public:
  // Additive identity.
  DEVICE inline T initial() { return static_cast<T>(0); }

  // Adds x to the running total stored in *y.
  DEVICE inline void compute(const T& x, T* y) {
    const T total = *y + x;
    *y = total;
  }

  // Adds val to *address via CudaAtomicAdd and forwards its return value.
  DEVICE inline T atomic(T* address, const T val) {
    return paddle::platform::CudaAtomicAdd(address, val);
  }
};

// Describes how an input of `input_length_size` rows by `inner_dim_size`
// columns is cut into stripes of at most DimTileSize rows of one column each,
// and maps a flat stripe index back to its (column, first row, height).
template <class T>
class ArrangeHelper {
 public:
  const T input_total_size;    // total number of input elements
  const T input_length_size;   // number of input rows (segment ids)
  const T output_length_size;  // number of output rows
  T inner_dim_size;            // elements per row
  T total_stripe_count;        // inner_dim_size * number of row stripes
  const T DimTileSize = 8;     // maximum rows per stripe

  // a: total input element count, b: input row count, c: output row count.
  // When b is not positive there is nothing to process, so inner_dim_size
  // and total_stripe_count are left at 0 instead of dividing by zero.
  ArrangeHelper(T a, T b, T c)
      : input_total_size(a),
        input_length_size(b),
        output_length_size(c),
        inner_dim_size(0),
        total_stripe_count(0) {
    if (input_length_size > 0) {
      // Number of stripes needed to cover all rows, rounded up.
      const T input_outer_dim_num_stripe =
          (input_length_size + DimTileSize - 1) / DimTileSize;
      inner_dim_size = input_total_size / input_length_size;
      total_stripe_count = inner_dim_size * input_outer_dim_num_stripe;
    }
  }

  // Decomposes `stripe_index` into:
  //   *segment_offset  column inside a row,
  //   *dim_index_base  first row of the stripe,
  //   *actual_height   number of rows in the stripe (the last stripe may be
  //                    shorter than DimTileSize).
  DEVICE inline void calculate(T stripe_index,
                               T* segment_offset,
                               T* dim_index_base,
                               T* actual_height) const {
    *segment_offset = stripe_index % inner_dim_size;
    *dim_index_base = stripe_index / inner_dim_size * DimTileSize;
    *actual_height = min(DimTileSize, input_length_size - *dim_index_base);
  }
};

// Launches the backward kernel for the "MAX" and "MIN" pool types on the
// stream of `ctx`; any other pool type raises InvalidArgument.
//
// input / output   forward input and forward result.
// segment_ids      segment id per input row.
// out_grad         gradient w.r.t. `output`.
// in_grad          gradient w.r.t. `input`; only matching elements are
//                  written by the kernel.
// pooltype         pooling mode name.
template <typename T, typename Index>
void SegmentPoolCUDAGradFunctor(const phi::GPUContext& ctx,
                                const DenseTensor& input,
                                const DenseTensor& segment_ids,
                                const DenseTensor& output,
                                const DenseTensor& out_grad,
                                DenseTensor* in_grad,
                                const std::string pooltype = "SUM") {
  using Layout = ArrangeHelper<Index>;

  auto h = Layout(input.numel(), segment_ids.dims()[0], output.dims()[0]);
  auto config =
      phi::backends::gpu::GetGpuLaunchConfig1D(ctx, h.total_stripe_count);

  const bool selects_by_index = (pooltype == "MAX") || (pooltype == "MIN");
  if (!selects_by_index) {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Unsupported segment pooling grad operation, Only MAX, MIN "
        "available, but got %s.",
        pooltype));
  } else {
    SegmentIndexGradKernel<T, Index, Layout>
        <<<config.block_per_grid.x,
           config.thread_per_block.x,
           0,
           ctx.stream()>>>(segment_ids.data<Index>(),
                           input.data<T>(),
                           output.data<T>(),
                           out_grad.data<T>(),
                           in_grad->data<T>(),
                           h);
  }
}

// Divides every row of `x` in place by the matching entry of `y`:
//   x[i * dim + j] /= y[i]   for 0 <= i < len, 0 <= j < dim.
//
// Rows are distributed over blocks (row i is handled by block i % gridDim.x)
// and the columns of a row over the threads of that block.
//
// Each thread reads the divisor into a register. The previous version staged
// it in a single __shared__ variable that was overwritten on the next row
// without a barrier after the column loop, so slower threads could divide by
// the next row's value.
template <typename T>
__global__ void SimpleDiv(T* x, const T* y, const int len, const int dim) {
  for (int i = blockIdx.x; i < len; i += gridDim.x) {
    const T divisor = y[i];
    // 64-bit offset so that len * dim beyond INT_MAX does not overflow.
    T* row = x + static_cast<int64_t>(i) * dim;
    for (int j = threadIdx.x; j < dim; j += blockDim.x) {
      row[j] /= divisor;
    }
  }
}

// GPU implementation of segment pooling.
//
// operator() reduces the rows of `input` that share a segment id into one row
// of `output`, using the reduction selected by `pooltype` ("SUM", "MEAN",
// "MAX" or "MIN"). All kernels are launched on the stream of `ctx`.
// NOTE(review): the kernels merge partial results into `output` (and, for
// MEAN, into `summed_ids`) with atomics, so the caller is presumably expected
// to pre-initialise those buffers -- confirm against the calling kernel.
template <typename T, typename IndexT>
class SegmentPoolFunctor<phi::GPUContext, T, IndexT> {
 private:
  // Launches SegmentOpsKernel with a default-constructed `Pool` policy.
  template <typename Pool, typename Config>
  static void LaunchSegmentOps(const phi::GPUContext& ctx,
                               const DenseTensor& input,
                               const DenseTensor& segment_ids,
                               DenseTensor* output,
                               const ArrangeHelper<IndexT>& h,
                               const Config& config) {
    Pool pool;
    SegmentOpsKernel<T, IndexT, ArrangeHelper<IndexT>, Pool>
        <<<config.block_per_grid.x,
           config.thread_per_block.x,
           0,
           ctx.stream()>>>(segment_ids.data<IndexT>(),
                           input.data<T>(),
                           output->data<T>(),
                           h,
                           pool);
  }

 public:
  // ctx          GPU context supplying the stream and launch configuration.
  // input        values to pool.
  // segment_ids  segment id per input row.
  // output       pooled result.
  // summed_ids   per-segment element counts; required when pooltype is
  //              "MEAN", unused otherwise.
  // pooltype     "SUM", "MEAN", "MAX" or "MIN"; anything else raises
  //              InvalidArgument.
  void operator()(const phi::GPUContext& ctx,
                  const DenseTensor& input,
                  const DenseTensor& segment_ids,
                  DenseTensor* output,
                  DenseTensor* summed_ids = nullptr,
                  const std::string pooltype = "SUM") {
    if (pooltype == "MEAN") {
      if (summed_ids == nullptr) {
        PADDLE_THROW(phi::errors::InvalidArgument(
            "The summed_ids tensor is required by MEAN segment pooling, but "
            "got nullptr."));
      }
      // Sum the segment id num first. The stripe count is computed in 64-bit
      // integer arithmetic; doing it in T would round for floating-point T
      // once the id count exceeds the mantissa range.
      constexpr int64_t kDimTileSize = 8;
      const int64_t input_length_size = segment_ids.numel();
      const int64_t total_stripe_count =
          (input_length_size + kDimTileSize - 1) / kDimTileSize;
      auto config =
          phi::backends::gpu::GetGpuLaunchConfig1D(ctx, total_stripe_count);
      SegmentSumIdsKernel<T, IndexT, IndexT(8)><<<config.block_per_grid.x,
                                                  config.thread_per_block.x,
                                                  0,
                                                  ctx.stream()>>>(
          segment_ids.data<IndexT>(),
          summed_ids->data<T>(),
          input_length_size,
          total_stripe_count);
    }

    auto h = ArrangeHelper<IndexT>(
        input.numel(), segment_ids.dims()[0], output->dims()[0]);
    auto config =
        phi::backends::gpu::GetGpuLaunchConfig1D(ctx, h.total_stripe_count);
    if (pooltype == "MEAN") {
      SegmentMeanKernel<T, IndexT, IndexT(8)><<<config.block_per_grid.x,
                                                config.thread_per_block.x,
                                                0,
                                                ctx.stream()>>>(
          segment_ids.data<IndexT>(),
          input.data<T>(),
          output->data<T>(),
          summed_ids->data<T>(),
          h.input_length_size,
          h.inner_dim_size,
          h.output_length_size,
          h.total_stripe_count);
    } else if (pooltype == "SUM") {
      LaunchSegmentOps<SumPool<T>>(ctx, input, segment_ids, output, h, config);
    } else if (pooltype == "MAX") {
      LaunchSegmentOps<MaxPool<T>>(ctx, input, segment_ids, output, h, config);
    } else if (pooltype == "MIN") {
      LaunchSegmentOps<MinPool<T>>(ctx, input, segment_ids, output, h, config);
    } else {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN "
          "available, but got %s.",
          pooltype));
    }
  }
};

// GPU implementation of the segment pooling backward pass.
//
//   "SUM"         gathers rows of out_grad by segment id into in_grad.
//   "MEAN"        copies out_grad, divides each row by the matching entry of
//                 summed_ids, then gathers by segment id into in_grad.
//   "MAX"/"MIN"   delegates to SegmentPoolCUDAGradFunctor.
// Any other pooltype raises InvalidArgument.
template <typename T, typename IndexT>
class SegmentPoolGradFunctor<phi::GPUContext, T, IndexT> {
 public:
  void operator()(const phi::GPUContext& dev_ctx,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& out_grad,
                  const DenseTensor& segments,
                  DenseTensor* in_grad,
                  paddle::optional<const DenseTensor&> summed_ids,
                  const std::string pooltype = "SUM") {
    if (pooltype == "SUM") {
      phi::funcs::GPUGather<T, IndexT>(dev_ctx, out_grad, segments, in_grad);
      return;
    }

    if (pooltype == "MAX" || pooltype == "MIN") {
      SegmentPoolCUDAGradFunctor<T, IndexT>(
          dev_ctx, input, segments, output, out_grad, in_grad, pooltype);
      return;
    }

    if (pooltype == "MEAN") {
      // Scratch copy of out_grad that is scaled in place.
      DenseTensor scaled_grad;
      scaled_grad.Resize(input.dims());
      dev_ctx.template Alloc<T>(&scaled_grad);
      paddle::framework::TensorCopy(
          out_grad, dev_ctx.GetPlace(), dev_ctx, &scaled_grad);

      // One divisor per output row, `row_width` elements per row.
      int row_count = output.dims()[0];
      int row_width = output.numel() / row_count;
      auto config =
          phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, row_count);
      SimpleDiv<T><<<config.block_per_grid.x,
                     config.thread_per_block.x,
                     0,
                     dev_ctx.stream()>>>(scaled_grad.data<T>(),
                                         summed_ids->data<T>(),
                                         row_count,
                                         row_width);
      phi::funcs::GPUGather<T, IndexT>(
          dev_ctx, scaled_grad, segments, in_grad);
      return;
    }

    PADDLE_THROW(phi::errors::InvalidArgument(
        "Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN "
        "available, but got %s.",
        pooltype));
  }
};

451 452 453 454 455
// Explicit instantiations of the GPU functors for every supported
// (value type, index type) pair, grouped by value type.
using GPU = phi::GPUContext;

// float
template class SegmentPoolFunctor<GPU, float, int>;
template class SegmentPoolFunctor<GPU, float, int64_t>;
template class SegmentPoolGradFunctor<GPU, float, int>;
template class SegmentPoolGradFunctor<GPU, float, int64_t>;

// double
template class SegmentPoolFunctor<GPU, double, int>;
template class SegmentPoolFunctor<GPU, double, int64_t>;
template class SegmentPoolGradFunctor<GPU, double, int>;
template class SegmentPoolGradFunctor<GPU, double, int64_t>;

// int
template class SegmentPoolFunctor<GPU, int, int>;
template class SegmentPoolFunctor<GPU, int, int64_t>;
template class SegmentPoolGradFunctor<GPU, int, int>;
template class SegmentPoolGradFunctor<GPU, int, int64_t>;

// int64_t
template class SegmentPoolFunctor<GPU, int64_t, int>;
template class SegmentPoolFunctor<GPU, int64_t, int64_t>;
template class SegmentPoolGradFunctor<GPU, int64_t, int>;
template class SegmentPoolGradFunctor<GPU, int64_t, int64_t>;
469 470 471

}  // namespace funcs
}  // namespace phi