/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include "paddle/fluid/operators/math/sequence_padding.h"

namespace paddle {
namespace operators {
namespace math {

/*
 * Copies rows between a packed sequence buffer and a padded buffer.
 *
 * Launch layout, as read from the indexing below:
 *   - blockIdx.y selects the sequence; sequence_start_positions is read at
 *     [blockIdx.y] and [blockIdx.y + 1], so it needs num_sequences + 1
 *     entries.
 *   - blockIdx.x * blockDim.y + threadIdx.y selects the time step.
 *   - threadIdx.x strides over the sequence_width elements of one row.
 *
 * `padding` is indexed as [time step][sequence][width]; `sequence` is indexed
 * as [start position + time step][width].
 *
 * Template parameters:
 *   NormByTimes - when true, every copied element is multiplied by
 *                 1 / (length of its sequence).
 *   Padding     - when true, copies sequence -> padding and writes zeros to
 *                 the time steps in [length, max_sequence_length); when
 *                 false, copies padding -> sequence and writes nothing else.
 */
template <typename T, bool NormByTimes, bool Padding>
__global__ void SequencePaddingKernel(T* padding, T* sequence,
                                      const size_t* sequence_start_positions,
                                      const size_t sequence_width,
                                      const size_t max_sequence_length,
                                      const size_t num_sequences) {
  const size_t seq_id = blockIdx.y;
  const size_t first_row = sequence_start_positions[seq_id];
  const size_t num_steps = sequence_start_positions[seq_id + 1] - first_row;

  const size_t step = blockIdx.x * blockDim.y + threadIdx.y;
  const size_t padded_offset =
      (step * num_sequences + seq_id) * sequence_width;
  const size_t packed_offset = (first_row + step) * sequence_width;

  if (step < num_steps) {
    const T scale = NormByTimes ? (1.0f / static_cast<T>(num_steps)) : 1.0f;
    // The copy direction is fixed at compile time by the Padding flag.
    T* dst = Padding ? padding + padded_offset : sequence + packed_offset;
    const T* src = Padding ? sequence + packed_offset : padding + padded_offset;
    for (size_t col = threadIdx.x; col < sequence_width; col += blockDim.x) {
      dst[col] = scale * src[col];
    }
    return;
  }

  // Past the end of this sequence: only the padding direction has rows to
  // fill, and only up to the padded length.
  if (Padding && step < max_sequence_length) {
    for (size_t col = threadIdx.x; col < sequence_width; col += blockDim.x) {
      padding[padded_offset + col] = 0;
    }
  }
}

/*
 * CUDA specialization that copies the level-0 sequences of a LoDTensor into
 * a padded 3-D tensor of shape
 * [max_sequence_length, num_sequences, sequence_width].
 *
 * Parameters:
 *   context       - device context; supplies the place and the stream the
 *                   kernel is launched on.
 *   seq           - source LoDTensor; its first dimension must equal the
 *                   last absolute offset of LoD level 0 and must be > 0.
 *   padding       - destination tensor, already shaped as described above.
 *   norm_by_times - when true, each copied element is scaled by
 *                   1 / (length of its sequence).
 *
 * Shape mismatches and an empty `seq` are reported through PADDLE_ENFORCE_*.
 *
 * NOTE(review): grid.y is set to num_sequences. CUDA limits gridDim.y to
 * 65535 and the launch status is not checked here, so a larger batch would
 * fail without any error being reported - confirm callers stay below that.
 */
template <typename T>
class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::LoDTensor& seq, framework::Tensor* padding,
                  bool norm_by_times) {
    auto lod = seq.lod();
    PADDLE_ENFORCE_GT(lod.size(), 0UL,
                      "The lod of LoDTensor seq should not be null.");

    const size_t level = 0;
    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);

    auto seq_dims = seq.dims();
    PADDLE_ENFORCE_EQ(seq_dims[0],
                      static_cast<int64_t>(abs_offset_lod[level].back()),
                      "The first dimension of LoDTensor seq should be "
                      "equal to the sum of all sequences's length.");

    auto padding_dims = padding->dims();
    PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
                      "The input padding should be a 3-D Tensor of shape "
                      "[max_sequence_length, num_sequences, sequence_width].");

    int64_t max_sequence_length = MaximumSequenceLength(lod, level);
    PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
                      "The first dimension of Tensor padding should be the "
                      "maximum length of all sequences in LoDTensor seq.");

    const int64_t num_sequences = abs_offset_lod[level].size() - 1;
    PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
                      "The second dimension of Tensor padding should be the "
                      "number of sequences in LoDTensor seq.");

    // seq_dims[0] is the divisor below; reject an empty tensor with a clear
    // error instead of dividing by zero.
    PADDLE_ENFORCE_GT(seq_dims[0], static_cast<int64_t>(0),
                      "The first dimension of LoDTensor seq should be "
                      "greater than 0.");
    const int64_t sequence_width = seq.numel() / seq_dims[0];
    PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
                      "The third dimension of Tensor padding should be the "
                      "width of sequence in LoDTensor seq.");

    // With a single sequence and no scaling, row t of seq and row (t, 0) of
    // padding sit at the same element offset, so a plain copy is enough.
    if (!norm_by_times && num_sequences == 1) {
      TensorCopy(seq, context.GetPlace(), context, padding);
      // Put the 3-D shape back (presumably TensorCopy gives the destination
      // the shape of its source - confirm).
      padding->Resize(padding_dims);
      return;
    }

    // Zero-width rows carry no data. Returning here also keeps block_dim_x
    // from becoming 0, which would make the division below divide by zero.
    if (sequence_width == 0) {
      return;
    }

    const int64_t kBlockSize = 512;

    /* At least use 32 threads to copy sequence_width elements,
     * and at least 8 elements for each thread.
     */
    const int64_t threads_for_width = (sequence_width + 7) >> 3;
    const int64_t rounded_to_warp = ((threads_for_width + 31) >> 5) << 5;
    size_t block_dim_x = std::min(rounded_to_warp, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    // x covers the time steps (block_dim_y of them per block), y the
    // sequences.
    size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = num_sequences;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* seq_data = seq.data<T>();
    T* padding_data = padding->data<T>();
    // The kernel takes both buffers as non-const because the copy direction
    // is a template flag; in this direction it only reads from seq_data.
    if (norm_by_times) {
      SequencePaddingKernel<T, 1, 1><<<grid, threads, 0, context.stream()>>>(
          padding_data, const_cast<T*>(seq_data),
          abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
          max_sequence_length, num_sequences);
    } else {
      SequencePaddingKernel<T, 0, 1><<<grid, threads, 0, context.stream()>>>(
          padding_data, const_cast<T*>(seq_data),
          abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
          max_sequence_length, num_sequences);
    }
  }
};

/*
 * CUDA specialization that copies a padded 3-D tensor of shape
 * [max_sequence_length, num_sequences, sequence_width] back into the
 * level-0 sequences of a LoDTensor.
 *
 * Parameters:
 *   context       - device context; supplies the place and the stream the
 *                   kernel is launched on.
 *   seq           - destination LoDTensor; its LoD and dims are read to
 *                   derive the layout, its first dimension must equal the
 *                   last absolute offset of LoD level 0 and must be > 0.
 *   padding       - source tensor, shaped as described above.
 *   norm_by_times - when true, each copied element is scaled by
 *                   1 / (length of its sequence).
 *
 * Shape mismatches and an empty `seq` are reported through PADDLE_ENFORCE_*.
 *
 * NOTE(review): grid.y is set to num_sequences. CUDA limits gridDim.y to
 * 65535 and the launch status is not checked here, so a larger batch would
 * fail without any error being reported - confirm callers stay below that.
 */
template <typename T>
class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  framework::LoDTensor* seq, const framework::Tensor& padding,
                  bool norm_by_times) {
    auto lod = seq->lod();
    PADDLE_ENFORCE_GT(lod.size(), 0UL,
                      "The lod of LoDTensor seq should not be null.");

    const size_t level = 0;
    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);

    auto seq_dims = seq->dims();
    PADDLE_ENFORCE_EQ(seq_dims[0],
                      static_cast<int64_t>(abs_offset_lod[level].back()),
                      "The first dimension of LoDTensor seq should be "
                      "equal to the sum of all sequences's length.");

    auto padding_dims = padding.dims();
    // Message spelling fixed ("max_sequnece_length") so it matches the one
    // used by the padding functor.
    PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
                      "The input padding should be a 3-D Tensor of shape "
                      "[max_sequence_length, num_sequences, sequence_width].");

    int64_t max_sequence_length = MaximumSequenceLength(lod, level);
    PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
                      "The first dimension of Tensor padding should be "
                      "the maximum length of all sequences in LoDTensor seq.");

    const int64_t num_sequences = abs_offset_lod[level].size() - 1;
    PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
                      "The second dimension of Tensor padding should be "
                      "the number of sequences in LoDTensor seq.");

    // seq_dims[0] is the divisor below; reject an empty tensor with a clear
    // error instead of dividing by zero.
    PADDLE_ENFORCE_GT(seq_dims[0], static_cast<int64_t>(0),
                      "The first dimension of LoDTensor seq should be "
                      "greater than 0.");
    const int64_t sequence_width = seq->numel() / seq_dims[0];
    PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
                      "The third dimension of Tensor padding should be the "
                      "width of sequence in LoDTensor seq.");

    // With a single sequence and no scaling, row (t, 0) of padding and row t
    // of seq sit at the same element offset, so a plain copy is enough.
    if (!norm_by_times && num_sequences == 1) {
      TensorCopy(padding, context.GetPlace(), context, seq);
      // Put the original shape back (presumably TensorCopy gives the
      // destination the shape of its source - confirm).
      seq->Resize(seq_dims);
      return;
    }

    // Zero-width rows carry no data. Returning here also keeps block_dim_x
    // from becoming 0, which would make the division below divide by zero.
    if (sequence_width == 0) {
      return;
    }

    const int64_t kBlockSize = 512;

    /* At least use 32 threads to copy sequence_width elements,
     * and at least 8 elements for each thread.
     */
    const int64_t threads_for_width = (sequence_width + 7) >> 3;
    const int64_t rounded_to_warp = ((threads_for_width + 31) >> 5) << 5;
    size_t block_dim_x = std::min(rounded_to_warp, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    // x covers the time steps (block_dim_y of them per block), y the
    // sequences.
    size_t grid_dim_x = (max_sequence_length + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = num_sequences;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* padding_data = padding.data<T>();
    T* seq_data = seq->data<T>();
    // The kernel takes both buffers as non-const because the copy direction
    // is a template flag; in this direction it only reads from padding_data.
    if (norm_by_times) {
      SequencePaddingKernel<T, 1, 0><<<grid, threads, 0, context.stream()>>>(
          const_cast<T*>(padding_data), seq_data,
          abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
          max_sequence_length, num_sequences);
    } else {
      SequencePaddingKernel<T, 0, 0><<<grid, threads, 0, context.stream()>>>(
          const_cast<T*>(padding_data), seq_data,
          abs_offset_lod[level].CUDAData(context.GetPlace()), sequence_width,
          max_sequence_length, num_sequences);
    }
  }
};

// Explicit instantiations of both CUDA functors. float is the only element
// type instantiated in this file.
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;

}  // namespace math
}  // namespace operators
}  // namespace paddle