/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include "paddle/fluid/operators/math/sequence_padding.h"

namespace paddle {
namespace operators {
namespace math {

/* Copies data between a LoD (concatenated) sequence buffer and a padded
 * buffer.
 *
 * Thread mapping, matching the launch set up by the functors in this file
 * (grid = (grid_dim_x, seq_num), block = (block_dim_x, block_dim_y)):
 *   - blockIdx.y                              : sequence index in the batch;
 *   - blockIdx.x * blockDim.y + threadIdx.y   : time step inside it;
 *   - threadIdx.x, strided by blockDim.x      : element inside one time step.
 *
 * Padding == true  : seq_data -> padding_data. Time steps in
 *                    [seq_len, max_seq_len) are filled with padding_value.
 * Padding == false : padding_data -> seq_data. Padded time steps are skipped.
 *
 * Padded-buffer layout, selected by padding_layout:
 *   LENGTH_BATCH_WIDTH : [max_seq_len, seq_num, seq_width]
 *   otherwise          : [seq_num, max_seq_len, seq_width]
 *
 * abs_offset holds seq_num + 1 absolute sequence offsets and must be
 * readable from the device.
 *
 * All scalar parameters are taken BY VALUE. A reference parameter on a
 * __global__ function would hand the device a host address.
 */
template <typename T, bool Padding>
__global__ void SequencePaddingKernel(
    T* padding_data, T* seq_data, const size_t* abs_offset,
    const size_t seq_num, const size_t max_seq_len, const size_t seq_width,
    const PaddingLayout padding_layout, bool norm_by_times = false,
    const T padding_value = 0) {
  const size_t padding_idx = blockIdx.y;
  const size_t seq_start = abs_offset[padding_idx];
  const size_t seq_len = abs_offset[padding_idx + 1] - seq_start;

  // Widen before multiplying so the time-step index is computed in size_t.
  const size_t seq_idx =
      static_cast<size_t>(blockIdx.x) * blockDim.y + threadIdx.y;

  const size_t seq_offset = (seq_start + seq_idx) * seq_width;
  const size_t padding_offset =
      (padding_layout == LENGTH_BATCH_WIDTH)
          ? (seq_idx * seq_num + padding_idx) * seq_width
          : (padding_idx * max_seq_len + seq_idx) * seq_width;

  if (seq_idx < seq_len) {
    // Optional normalization of every copied element by the sequence length.
    const T scale = norm_by_times
                        ? static_cast<T>(1) / static_cast<T>(seq_len)
                        : static_cast<T>(1);
    if (Padding) {
      /* sequence -> padding */
      for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
        padding_data[padding_offset + i] = scale * seq_data[seq_offset + i];
      }
    } else {
      /* padding -> sequence */
      for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
        seq_data[seq_offset + i] = scale * padding_data[padding_offset + i];
      }
    }
  } else if (seq_idx < max_seq_len) {
    // Tail of a short sequence: only the padding direction writes anything.
    if (Padding) {
      /* sequence -> padding */
      for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
        padding_data[padding_offset + i] = padding_value;
      }
    }
  }
}

/* CUDA specialization: converts a LoD tensor into a padded dense tensor.
 *
 * The sequences described by level `lod_level` of seq_tensor's LoD are copied
 * into padding_tensor. The caller-provided dims of padding_tensor are read
 * and checked by ValidateShape.
 *
 * Padded layout:
 *   LENGTH_BATCH_WIDTH : [max_seq_len, seq_num, seq_width]
 *   otherwise          : [seq_num, max_seq_len, seq_width]
 *
 * Time steps past a sequence's length are filled with padding_value. When
 * norm_by_times is true, every copied element is multiplied by
 * 1 / (sequence length).
 */
template <typename T, PaddingLayout padding_layout>
class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T, padding_layout> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::LoDTensor& seq_tensor,
                  framework::Tensor* padding_tensor,
                  T padding_value = static_cast<T>(0),
                  bool norm_by_times = false, size_t lod_level = 0) {
    ValidateLoD(seq_tensor, lod_level);

    auto& lod = seq_tensor.lod();
    // ToAbsOffset returns a new LoD by value. Keep a copy of the level we
    // need: a reference into that temporary would dangle once this
    // statement ends.
    auto abs_offset = framework::ToAbsOffset(lod)[lod_level];

    auto seq_dims = seq_tensor.dims();
    auto padding_dims = padding_tensor->dims();
    int64_t max_seq_len = MaximumSequenceLength(lod, lod_level);
    const int64_t seq_num = abs_offset.size() - 1;
    const int64_t seq_width = seq_tensor.numel() / seq_dims[0];

    ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len,
                  seq_num, seq_width, padding_layout);

    // With a single sequence, max_seq_len equals its length, so the padded
    // buffer has the same memory layout as the sequence buffer: plain copy.
    if (!norm_by_times && seq_num == 1) {
      TensorCopy(seq_tensor, context.GetPlace(), context, padding_tensor);
      // Restore the padded dims (presumably TensorCopy reshaped the
      // destination to the source dims).
      padding_tensor->Resize(padding_dims);
      return;
    }

    const int64_t kBlockSize = 512;

    /* x-threads copy the elements of one time step. Use ceil(seq_width / 8)
     * of them (about 8 elements per thread), rounded up to a multiple of 32
     * and capped at kBlockSize. The remaining block capacity (y) covers
     * several time steps per block.
     */
    size_t block_dim_x =
        std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    // grid.x covers max_seq_len time steps; grid.y has one row per sequence.
    size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = seq_num;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* seq_data = seq_tensor.data<T>();
    T* padding_data = padding_tensor->data<T>();

    // Padding == true: sequence -> padding direction.
    SequencePaddingKernel<T, true><<<grid, threads, 0, context.stream()>>>(
        padding_data, const_cast<T*>(seq_data),
        abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
        seq_width, padding_layout, norm_by_times, padding_value);
  }
};

/* CUDA specialization: converts a padded dense tensor back into a LoD tensor.
 *
 * This is the inverse of the padding functor. For every sequence described
 * by level `lod_level` of seq_tensor's LoD, the first seq_len time steps are
 * copied out of padding_tensor into seq_tensor. Padded time steps are
 * ignored.
 *
 * Padded layout:
 *   LENGTH_BATCH_WIDTH : [max_seq_len, seq_num, seq_width]
 *   otherwise          : [seq_num, max_seq_len, seq_width]
 *
 * When norm_by_times is true, every copied element is multiplied by
 * 1 / (sequence length).
 *
 * seq_tensor must already carry its LoD and dims. Its storage is fetched
 * with data<T>() rather than mutable_data, so it presumably must be
 * allocated by the caller — confirm against call sites.
 */
template <typename T, PaddingLayout padding_layout>
class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T,
                                padding_layout> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  framework::LoDTensor* seq_tensor,
                  const framework::Tensor& padding_tensor,
                  bool norm_by_times = false, size_t lod_level = 0) {
    ValidateLoD(*seq_tensor, lod_level);

    auto& lod = seq_tensor->lod();
    // ToAbsOffset returns a new LoD by value. Copy the level instead of
    // binding a reference into the temporary, which would dangle.
    auto abs_offset = framework::ToAbsOffset(lod)[lod_level];

    auto seq_dims = seq_tensor->dims();
    auto padding_dims = padding_tensor.dims();
    int64_t max_seq_len = MaximumSequenceLength(lod, lod_level);
    const int64_t seq_num = abs_offset.size() - 1;
    const int64_t seq_width = seq_tensor->numel() / seq_dims[0];

    // Same shape validation as the padding direction, before any data is
    // read from padding_tensor.
    ValidateShape(seq_dims, abs_offset.back(), padding_dims, max_seq_len,
                  seq_num, seq_width, padding_layout);

    // A single sequence has identical padded and unpadded memory layouts.
    if (!norm_by_times && seq_num == 1) {
      TensorCopy(padding_tensor, context.GetPlace(), context, seq_tensor);
      // Restore the sequence dims (presumably TensorCopy reshaped the
      // destination to the padded dims).
      seq_tensor->Resize(seq_dims);
      return;
    }

    const int64_t kBlockSize = 512;

    /* x-threads copy the elements of one time step. Use ceil(seq_width / 8)
     * of them (about 8 elements per thread), rounded up to a multiple of 32
     * and capped at kBlockSize. The remaining block capacity (y) covers
     * several time steps per block.
     */
    size_t block_dim_x =
        std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    // grid.x covers max_seq_len time steps; grid.y has one row per sequence.
    size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = seq_num;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* padding_data = padding_tensor.data<T>();
    T* seq_data = seq_tensor->data<T>();

    // Padding == false: padding -> sequence direction. The previous <T, 1>
    // ran the sequence -> padding path, writing into the const input and
    // leaving seq_tensor untouched.
    SequencePaddingKernel<T, false><<<grid, threads, 0, context.stream()>>>(
        const_cast<T*>(padding_data), seq_data,
        abs_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
        seq_width, padding_layout, norm_by_times);
  }
};

// Explicit instantiations compiled into this CUDA translation unit:
// float elements, LENGTH_BATCH_WIDTH padded layout.
template class UnpaddingLoDTensorFunctor<
    platform::CUDADeviceContext, float, LENGTH_BATCH_WIDTH>;
template class PaddingLoDTensorFunctor<
    platform::CUDADeviceContext, float, LENGTH_BATCH_WIDTH>;

}  // namespace math
}  // namespace operators
}  // namespace paddle