/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cfloat>  // for FLT_MAX, used by MaxPoolFunctor
#include <string>

#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/macros.h"
#include "paddle/phi/core/mixed_vector.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sequence_pooling.h"

namespace phi {
namespace funcs {

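// MaxPoolFunctor: for each feature position (tid), takes the maximum over the
// rows [start, end) of one sequence and records the winning row in `index`
// for the backward pass. Empty sequences (start == end) are filled with
// pad_value.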
template <typename T>
struct MaxPoolFunctor {
  HOSTDEVICE void operator()(const T* input,
                             const T pad_value,
                             const size_t start,
                             const size_t end,
                             const size_t item_dim,
                             T* output,
                             int* index) {
    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      T max_val = static_cast<T>(-FLT_MAX);
      int max_index = -1;
      if (start == end) {
        output[tid] = pad_value;
        index[tid] = -1;
      } else {
        for (int i = start; i < end; ++i) {
          if (max_val < input[item_dim * i + tid]) {
            max_val = input[item_dim * i + tid];
            max_index = i;
          }
        }
        output[tid] = max_val;
        index[tid] = max_index;
      }
    }
  }
};

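// AvgPoolFunctor: averages each feature over the rows [start, end) of one
// sequence; empty sequences are filled with pad_value. `index` is unused.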
template <typename T>
struct AvgPoolFunctor {
  HOSTDEVICE void operator()(const T* input,
                             const T pad_value,
                             const size_t start,
                             const size_t end,
                             const size_t item_dim,
                             T* output,
                             int* index) {
    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      if (start == end) {
        output[tid] = pad_value;
      } else {
        T val = static_cast<T>(0);
        for (int i = start; i < end; ++i) {
          val += input[item_dim * i + tid];
        }
        // start and end come from the lod offsets and start != end here,
        // so the divisor (end - start) is never zero.
        output[tid] = val / static_cast<T>(end - start);
      }
    }
  }
};

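// SumPoolFunctor: sums each feature over the rows [start, end) of one
// sequence; empty sequences are filled with pad_value. `index` is unused.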
template <typename T>
struct SumPoolFunctor {
  HOSTDEVICE void operator()(const T* input,
                             const T pad_value,
                             const size_t start,
                             const size_t end,
                             const size_t item_dim,
                             T* output,
                             int* index) {
    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      if (start == end) {
        output[tid] = pad_value;
      } else {
        T val = static_cast<T>(0);
        for (int i = start; i < end; ++i) {
          val += input[item_dim * i + tid];
        }
        output[tid] = val;
      }
    }
  }
};

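// SqrtPoolFunctor: sums each feature over [start, end) and scales the result
// by 1 / sqrt(sequence length); empty sequences are filled with pad_value.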
template <typename T>
struct SqrtPoolFunctor {
  HOSTDEVICE void operator()(const T* input,
                             const T pad_value,
                             const size_t start,
                             const size_t end,
                             const size_t item_dim,
                             T* output,
                             int* index) {
    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      if (start == end) {
        output[tid] = pad_value;
      } else {
        T val = static_cast<T>(0);
        for (int i = start; i < end; ++i) {
          val += input[item_dim * i + tid];
        }
        // start and end come from the lod offsets and start != end here,
        // so the divisor (end - start) is never zero.
        output[tid] = val / sqrt(end - start);
      }
    }
  }
};

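// LastPoolFunctor: copies the last row (end - 1) of the sequence for each
// feature; empty sequences are filled with pad_value.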
template <typename T>
struct LastPoolFunctor {
  HOSTDEVICE void operator()(const T* input,
                             const T pad_value,
                             const size_t start,
                             const size_t end,
                             const size_t item_dim,
                             T* output,
                             int* index) {
    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      if (start == end) {
        output[tid] = pad_value;
      } else {
        output[tid] = input[item_dim * (end - 1) + tid];
      }
    }
  }
};

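// FirstPoolFunctor: copies the first row (start) of the sequence for each
// feature; empty sequences are filled with pad_value.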
template <typename T>
struct FirstPoolFunctor {
  HOSTDEVICE void operator()(const T* input,
                             const T pad_value,
                             const size_t start,
                             const size_t end,
                             const size_t item_dim,
                             T* output,
                             int* index) {
    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      if (start == end) {
        output[tid] = pad_value;
      } else {
        output[tid] = input[item_dim * start + tid];
      }
    }
  }
};

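// sequence_pool_kernel: one CUDA block handles one sequence. The block id
// selects the lod segment [lod[bid], lod[bid + 1]), and threads stride over
// the item_dim features while `op` performs the actual pooling.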
template <typename T, typename Range_OP>
__global__ void sequence_pool_kernel(Range_OP op,
                                     const T* input,
                                     const T pad_value,
                                     const size_t* lod,
                                     const size_t lod_size,
                                     const size_t item_dim,
                                     T* output,
                                     int* index) {
  int bid = blockIdx.x;
  if (bid >= lod_size - 1) return;
  size_t start = lod[bid];
  size_t end = lod[bid + 1];
  int* index_offset = nullptr;
  if (index != nullptr) {
    index_offset = &index[bid * item_dim];
  }
  op(input,
     pad_value,
     start,
     end,
     item_dim,
     &output[bid * item_dim],
     index_offset);
}

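// SequencePoolFunctor<phi::GPUContext, T>: GPU forward pass. Launches one
// block of 1024 threads per sequence and dispatches on `pooltype`; only the
// "MAX" path uses `index` (the argmax rows needed by the gradient).
//
// Illustrative call site (a sketch only; `dev_ctx`, `x`, `out`, and
// `max_index` are hypothetical caller-prepared objects, not defined here):
//   phi::funcs::SequencePoolFunctor<phi::GPUContext, float> pool;
//   pool(dev_ctx, "MAX", /*pad_value=*/0.0f, x, &out, /*is_test=*/false,
//        &max_index);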
template <typename T>
class SequencePoolFunctor<phi::GPUContext, T> {
 public:
  void operator()(const phi::GPUContext& context,
                  const std::string pooltype,
                  T pad_value,
                  const phi::DenseTensor& input,
                  phi::DenseTensor* output,
                  bool is_test,
                  phi::DenseTensor* index = nullptr) {
    auto lod_level = input.lod().size();
    auto& lod = input.lod()[lod_level - 1];
    const size_t item_dim = output->numel() / output->dims()[0];
    dim3 threads(1024, 1);
    dim3 grid(std::max(static_cast<int>(lod.size()) - 1, 1), 1);
    phi::MixVector<size_t> mix_vector(&lod);
    if (pooltype == "MAX") {
      sequence_pool_kernel<T, MaxPoolFunctor<T>>
          <<<grid, threads, 0, context.stream()>>>(
              MaxPoolFunctor<T>(),
              input.data<T>(),
              pad_value,
              mix_vector.CUDAData(context.GetPlace()),
              lod.size(),
              item_dim,
              context.template Alloc<T>(output),
              index->data<int>());
    } else if (pooltype == "AVERAGE") {
      sequence_pool_kernel<T, AvgPoolFunctor<T>>
          <<<grid, threads, 0, context.stream()>>>(
              AvgPoolFunctor<T>(),
              input.data<T>(),
              pad_value,
              mix_vector.CUDAData(context.GetPlace()),
              lod.size(),
              item_dim,
              context.template Alloc<T>(output),
              nullptr);
    } else if (pooltype == "SUM") {
      sequence_pool_kernel<T, SumPoolFunctor<T>>
          <<<grid, threads, 0, context.stream()>>>(
              SumPoolFunctor<T>(),
              input.data<T>(),
              pad_value,
              mix_vector.CUDAData(context.GetPlace()),
              lod.size(),
              item_dim,
              context.template Alloc<T>(output),
              nullptr);
    } else if (pooltype == "SQRT") {
      sequence_pool_kernel<T, SqrtPoolFunctor<T>>
          <<<grid, threads, 0, context.stream()>>>(
              SqrtPoolFunctor<T>(),
              input.data<T>(),
              pad_value,
              mix_vector.CUDAData(context.GetPlace()),
              lod.size(),
              item_dim,
              context.template Alloc<T>(output),
              nullptr);
    } else if (pooltype == "LAST") {
      sequence_pool_kernel<T, LastPoolFunctor<T>>
          <<<grid, threads, 0, context.stream()>>>(
              LastPoolFunctor<T>(),
              input.data<T>(),
              pad_value,
              mix_vector.CUDAData(context.GetPlace()),
              lod.size(),
              item_dim,
              context.template Alloc<T>(output),
              nullptr);
    } else if (pooltype == "FIRST") {
      sequence_pool_kernel<T, FirstPoolFunctor<T>>
          <<<grid, threads, 0, context.stream()>>>(
              FirstPoolFunctor<T>(),
              input.data<T>(),
              pad_value,
              mix_vector.CUDAData(context.GetPlace()),
              lod.size(),
              item_dim,
              context.template Alloc<T>(output),
              nullptr);
    } else {
      PADDLE_THROW(errors::InvalidArgument(
          "unsupported pooling pooltype: %s. Only support \"MAX\", "
          "\"AVERAGE\", \"SUM\", \"SQRT\", \"LAST\" and \"FIRST\"",
          pooltype));
    }
  }
};

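// MaxPoolGradFunctor: routes each output gradient to the row recorded in
// `index` during the forward pass; all other rows receive zero.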
template <typename T>
struct MaxPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad,
                             const size_t start,
                             const size_t end,
                             const size_t item_dim,
                             T* in_grad,
                             const int* index) {
    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      for (int i = start; i < end; ++i) {
        if (i == index[tid]) {
          in_grad[item_dim * i + tid] = out_grad[tid];
        } else {
          in_grad[item_dim * i + tid] = static_cast<T>(0);
        }
      }
    }
  }
};

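// AvgPoolGradFunctor: spreads each output gradient evenly over the sequence,
// i.e. every row receives out_grad / (end - start).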
template <typename T>
struct AvgPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad,
                             const size_t start,
                             const size_t end,
                             const size_t item_dim,
                             T* in_grad,
                             const int* index) {
    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      for (int i = start; i < end; ++i) {
        in_grad[item_dim * i + tid] = out_grad[tid] / (end - start);
      }
    }
  }
};

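// SumPoolGradFunctor: copies each output gradient to every row of the
// sequence.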
template <typename T>
struct SumPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad,
                             const size_t start,
                             const size_t end,
                             const size_t item_dim,
                             T* in_grad,
                             const int* index) {
    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      for (int i = start; i < end; ++i) {
        in_grad[item_dim * i + tid] = out_grad[tid];
      }
    }
  }
};

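// SqrtPoolGradFunctor: scales each output gradient by 1 / sqrt(sequence
// length) before writing it to every row.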
template <typename T>
struct SqrtPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad,
                             const size_t start,
                             const size_t end,
                             const size_t item_dim,
                             T* in_grad,
                             const int* index) {
    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      for (int i = start; i < end; ++i) {
        in_grad[item_dim * i + tid] =
            out_grad[tid] / (sqrt(static_cast<T>(end - start)));
      }
    }
  }
};

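// LastPoolGradFunctor: the gradient flows only to the last row (end - 1);
// all other rows of the sequence receive zero.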
template <typename T>
struct LastPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad,
                             const size_t start,
                             const size_t end,
                             const size_t item_dim,
                             T* in_grad,
                             const int* index) {
    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      for (int i = start; i < end; ++i) {
        if (i == end - 1) {
          in_grad[item_dim * i + tid] = out_grad[tid];
        } else {
          in_grad[item_dim * i + tid] = static_cast<T>(0);
        }
      }
    }
  }
};

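// FirstPoolGradFunctor: the gradient flows only to the first row (start);
// all other rows of the sequence receive zero.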
template <typename T>
struct FirstPoolGradFunctor {
  HOSTDEVICE void operator()(const T* out_grad,
                             const size_t start,
                             const size_t end,
                             const size_t item_dim,
                             T* in_grad,
                             const int* index) {
    for (int tid = threadIdx.x; tid < item_dim; tid += blockDim.x) {
      for (int i = start; i < end; ++i) {
        if (i == start) {
          in_grad[item_dim * i + tid] = out_grad[tid];
        } else {
          in_grad[item_dim * i + tid] = static_cast<T>(0);
        }
      }
    }
  }
};

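// sequence_pool_grad_kernel: mirror of sequence_pool_kernel for the backward
// pass; one block per sequence, with `op` scattering out_grad into in_grad.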
template <typename T, typename Range_OP>
__global__ void sequence_pool_grad_kernel(Range_OP op,
                                          const T* out_grad,
                                          const size_t* lod,
                                          const size_t lod_size,
                                          const size_t item_dim,
                                          T* in_grad,
                                          const int* index) {
  int bid = blockIdx.x;
  if (bid >= lod_size - 1) return;
  size_t start = lod[bid];
  size_t end = lod[bid + 1];
  const int* index_offset = nullptr;
  if (index != nullptr) {
    index_offset = &index[bid * item_dim];
  }
  op(&out_grad[bid * item_dim], start, end, item_dim, in_grad, index_offset);
}

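// SequencePoolGradFunctor<phi::GPUContext, T>: GPU backward pass. Uses the
// same one-block-per-sequence launch as the forward functor; `index` is only
// required for "MAX" pooling.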
template <typename T>
class SequencePoolGradFunctor<phi::GPUContext, T> {
 public:
  void operator()(const phi::GPUContext& context,
                  const std::string pooltype,
                  const phi::DenseTensor& out_grad,
                  phi::DenseTensor* in_grad,
                  /* max pool has index */
                  const phi::DenseTensor* index = nullptr) {
    auto lod_level = in_grad->lod().size();
    auto& lod = in_grad->lod()[lod_level - 1];
    const size_t item_dim = in_grad->numel() / in_grad->dims()[0];
    dim3 threads(1024, 1);
    dim3 grid(std::max(static_cast<int>(lod.size()) - 1, 1), 1);
    phi::MixVector<size_t> mix_vector(&lod);
    if (pooltype == "MAX") {
      sequence_pool_grad_kernel<T, MaxPoolGradFunctor<T>>
          <<<grid, threads, 0, context.stream()>>>(
              MaxPoolGradFunctor<T>(),
              out_grad.data<T>(),
              mix_vector.CUDAData(context.GetPlace()),
              lod.size(),
              item_dim,
              context.template Alloc<T>(in_grad),
              index->data<int>());
    } else if (pooltype == "AVERAGE") {
      sequence_pool_grad_kernel<T, AvgPoolGradFunctor<T>>
          <<<grid, threads, 0, context.stream()>>>(
              AvgPoolGradFunctor<T>(),
              out_grad.data<T>(),
              mix_vector.CUDAData(context.GetPlace()),
              lod.size(),
              item_dim,
              context.template Alloc<T>(in_grad),
              nullptr);
    } else if (pooltype == "SUM") {
      sequence_pool_grad_kernel<T, SumPoolGradFunctor<T>>
          <<<grid, threads, 0, context.stream()>>>(
              SumPoolGradFunctor<T>(),
              out_grad.data<T>(),
              mix_vector.CUDAData(context.GetPlace()),
              lod.size(),
              item_dim,
              context.template Alloc<T>(in_grad),
              nullptr);
    } else if (pooltype == "SQRT") {
      sequence_pool_grad_kernel<T, SqrtPoolGradFunctor<T>>
          <<<grid, threads, 0, context.stream()>>>(
              SqrtPoolGradFunctor<T>(),
              out_grad.data<T>(),
              mix_vector.CUDAData(context.GetPlace()),
              lod.size(),
              item_dim,
              context.template Alloc<T>(in_grad),
              nullptr);
    } else if (pooltype == "LAST") {
      sequence_pool_grad_kernel<T, LastPoolGradFunctor<T>>
          <<<grid, threads, 0, context.stream()>>>(
              LastPoolGradFunctor<T>(),
              out_grad.data<T>(),
              mix_vector.CUDAData(context.GetPlace()),
              lod.size(),
              item_dim,
              context.template Alloc<T>(in_grad),
              nullptr);
    } else if (pooltype == "FIRST") {
      sequence_pool_grad_kernel<T, FirstPoolGradFunctor<T>>
          <<<grid, threads, 0, context.stream()>>>(
              FirstPoolGradFunctor<T>(),
              out_grad.data<T>(),
              mix_vector.CUDAData(context.GetPlace()),
              lod.size(),
              item_dim,
              context.template Alloc<T>(in_grad),
              nullptr);

    } else {
      PADDLE_THROW(errors::InvalidArgument(
          "unsupported pooling pooltype: %s. Only support \"MAX\", "
          "\"AVERAGE\", \"SUM\", \"SQRT\", \"LAST\" and \"FIRST\"",
          pooltype));
    }
  }
};

// Explicit instantiations of the sequence pooling forward/backward functors
// for float and double on the GPU.
template class SequencePoolFunctor<phi::GPUContext, float>;
template class SequencePoolFunctor<phi::GPUContext, double>;
template class SequencePoolGradFunctor<phi::GPUContext, float>;
template class SequencePoolGradFunctor<phi::GPUContext, double>;

}  // namespace funcs
}  // namespace phi