/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include "paddle/fluid/operators/math/sequence_padding.h"
#include "paddle/phi/backends/gpu/gpu_context.h"

namespace paddle {
namespace operators {
namespace math {

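// Copies one batch of sequences between the compact LoD layout and the padded
// layout. blockIdx.y selects the sequence, blockIdx.x * blockDim.y +
// threadIdx.y selects the time step, and threadIdx.x strides over the
// step_width elements of that step. For kSeqToPad, steps beyond the sequence
// length are filled from pad_value; with norm_by_len, copied values are
// scaled by 1 / seq_len.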
template <typename T, CopyType Type>
__global__ void SequencePaddingKernel(
    T* dst, const T* src, const T* pad_value, bool is_constant_pad,
    const size_t* seq_offsets, const size_t seq_num, const size_t pad_seq_len,
    const size_t step_width, bool norm_by_len, const PadLayout layout) {
  size_t seq_idx = blockIdx.y;
  size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];

  size_t step_idx = blockIdx.x * blockDim.y + threadIdx.y;
  size_t seq_data_offset = (seq_offsets[seq_idx] + step_idx) * step_width;
  size_t pad_data_offset = layout == kBatchLengthWidth
                               ? (seq_idx * pad_seq_len + step_idx) * step_width
                               : (step_idx * seq_num + seq_idx) * step_width;

  T* dst_data = dst + (Type == kSeqToPad ? pad_data_offset : seq_data_offset);
  const T* src_data =
      src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset);

  if (step_idx < seq_len) {
    float scale = norm_by_len ? (1.0f / static_cast<float>(seq_len)) : 1.0f;
    for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
      dst_data[i] = scale * src_data[i];
    }
  } else if (step_idx < pad_seq_len && Type == kSeqToPad) {
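    // This step lies beyond the sequence's own length, so fill it with the
    // padding value: a single scalar if is_constant_pad, otherwise one value
    // per element of the step.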
    for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
      dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i];
    }
  }
}

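// Pads a LoDTensor of shape [total_steps, step_width] into a dense tensor
// that is batch-major ([seq_num, pad_seq_len, step_width]) for
// kBatchLengthWidth, or time-major ([pad_seq_len, seq_num, step_width])
// otherwise, filling the tail of shorter sequences from pad_value.
//
// Minimal usage sketch (illustrative only; `ctx`, `seq`, `padded`, and
// `value` are assumed to be prepared by the caller):
//
//   PaddingLoDTensorFunctor<platform::CUDADeviceContext, float> pad;
//   pad(ctx, seq, &padded, value, /*pad_seq_len=*/-1, /*lod_level=*/0,
//       /*norm_by_times=*/false, kBatchLengthWidth);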
template <typename T>
class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::LoDTensor& seq_tensor,
                  framework::LoDTensor* pad_tensor,
                  const framework::LoDTensor& pad_value, int pad_seq_len = -1,
                  int lod_level = 0, bool norm_by_times = false,
                  const PadLayout layout = kBatchLengthWidth) {
    auto seq_lod = seq_tensor.lod();
    auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
    const auto& seq_tensor_dims = seq_tensor.dims();
    const auto& pad_tensor_dims = pad_tensor->dims();
    int max_seq_len = MaximumSequenceLength(seq_offsets);
    if (pad_seq_len == -1) {
      pad_seq_len = max_seq_len;
    }
    PADDLE_ENFORCE_GE(
        pad_seq_len, max_seq_len,
        platform::errors::InvalidArgument(
            "The pad_seq_len must be equal to or greater than the "
            "original max sequence length. Expected %ld >= %ld, but got %ld < "
            "%ld. Please check the input value.",
            pad_seq_len, max_seq_len, pad_seq_len, max_seq_len));
    int step_width = seq_tensor.numel() / seq_tensor_dims[0];
    int seq_num = seq_offsets.size() - 1;

    CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
              step_width, layout);
    PADDLE_ENFORCE_EQ(
        pad_value.numel() == 1 || pad_value.numel() == step_width, true,
        platform::errors::InvalidArgument(
            "The numel of 'pad_value' can only be 1 or be equal to "
            "the 'step_width', but got %ld != 1 and %ld. Please check the "
            "input value.",
            pad_value.numel(), step_width));

    const int kBlockSize = 512;

    /* Use at least 32 threads to copy the step_width elements of one step,
     * and give each thread at least 8 elements to copy.
     */
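    // Rounds ceil(step_width / 8) up to a multiple of 32 (one warp), capped
    // at kBlockSize.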
    size_t block_dim_x =
        std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

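    // grid.x spans the pad_seq_len time steps (block_dim_y steps per block);
    // grid.y assigns one row of blocks to each sequence.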
    size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = seq_num;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* seq_data = seq_tensor.data<T>();
    T* pad_data = pad_tensor->data<T>();
    const T* pad_value_data = pad_value.data<T>();

    paddle::framework::MixVector<size_t> mix_vector_seq_offsets(&seq_offsets);
    SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
        pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
        mix_vector_seq_offsets.CUDAData(context.GetPlace()), seq_num,
        pad_seq_len, step_width, norm_by_times, layout);
  }
};

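// Same as the PaddingLoDTensorFunctor specialization above, but for the
// phi::GPUContext device context.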
template <typename T>
class PaddingLoDTensorFunctor<phi::GPUContext, T> {
 public:
  void operator()(const phi::GPUContext& context,
                  const framework::LoDTensor& seq_tensor,
                  framework::LoDTensor* pad_tensor,
                  const framework::LoDTensor& pad_value, int pad_seq_len = -1,
                  int lod_level = 0, bool norm_by_times = false,
                  const PadLayout layout = kBatchLengthWidth) {
    auto seq_lod = seq_tensor.lod();
    auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
    const auto& seq_tensor_dims = seq_tensor.dims();
    const auto& pad_tensor_dims = pad_tensor->dims();
    int max_seq_len = MaximumSequenceLength(seq_offsets);
    if (pad_seq_len == -1) {
      pad_seq_len = max_seq_len;
    }
    PADDLE_ENFORCE_GE(
        pad_seq_len, max_seq_len,
        platform::errors::InvalidArgument(
            "The pad_seq_len must be equal to or greater than the "
            "original max sequence length. Expected %ld >= %ld, but got %ld < "
            "%ld. Please check the input value.",
            pad_seq_len, max_seq_len, pad_seq_len, max_seq_len));
    int step_width = seq_tensor.numel() / seq_tensor_dims[0];
    int seq_num = seq_offsets.size() - 1;

    CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
              step_width, layout);
    PADDLE_ENFORCE_EQ(
        pad_value.numel() == 1 || pad_value.numel() == step_width, true,
        platform::errors::InvalidArgument(
            "The numel of 'pad_value' can only be 1 or be equal to "
            "the 'step_width', but got %ld != 1 and %ld. Please check the "
            "input value.",
            pad_value.numel(), step_width));

    const int kBlockSize = 512;

    /* Use at least 32 threads to copy the step_width elements of one step,
     * and give each thread at least 8 elements to copy.
     */
    size_t block_dim_x =
        std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = seq_num;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* seq_data = seq_tensor.data<T>();
    T* pad_data = pad_tensor->data<T>();
    const T* pad_value_data = pad_value.data<T>();

    paddle::framework::MixVector<size_t> mix_vector_seq_offsets(&seq_offsets);
    SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
        pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
        mix_vector_seq_offsets.CUDAData(context.GetPlace()), seq_num,
        pad_seq_len, step_width, norm_by_times, layout);
  }
};

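// Copies a padded tensor back into a LoDTensor: for each sequence only the
// first seq_len time steps of the padded input are read, so the padding is
// dropped. With norm_by_times, values are scaled by 1 / seq_len.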
template <typename T>
class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::LoDTensor& pad_tensor,
                  framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
                  int lod_level = 0, bool norm_by_times = false,
                  const PadLayout layout = kBatchLengthWidth) {
    auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
    const auto& seq_tensor_dims = seq_tensor->dims();
    const auto& pad_tensor_dims = pad_tensor.dims();
    int max_seq_len = MaximumSequenceLength(seq_offsets);
    if (pad_seq_len == -1) {
      pad_seq_len = max_seq_len;
    }
    int step_width = seq_tensor->numel() / seq_tensor_dims[0];
    int seq_num = seq_offsets.size() - 1;

    CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
              step_width, layout);
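    // A direct TensorCopy fast path for the single-sequence case, currently
    // disabled: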
    /*
    if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) {
      paddle::framework::TensorCopy(pad_tensor, context.GetPlace(), context,
    seq_tensor);
      seq_tensor->Resize(seq_tensor_dims);
      return;
    }
    */

    const int kBlockSize = 512;

    /* Use at least 32 threads to copy the step_width elements of one step,
     * and give each thread at least 8 elements to copy.
     */
    size_t block_dim_x =
        std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = seq_num;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* pad_data = pad_tensor.data<T>();
    T* seq_data = seq_tensor->data<T>();

    paddle::framework::MixVector<size_t> mixv_seq_offsets(&seq_offsets);
    SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
        seq_data, pad_data, nullptr, false,
        mixv_seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
        step_width, norm_by_times, layout);
  }
};

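// Same as the UnpaddingLoDTensorFunctor specialization above, but for the
// phi::GPUContext device context.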
template <typename T>
class UnpaddingLoDTensorFunctor<phi::GPUContext, T> {
 public:
  void operator()(const phi::GPUContext& context,
                  const framework::LoDTensor& pad_tensor,
                  framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
                  int lod_level = 0, bool norm_by_times = false,
                  const PadLayout layout = kBatchLengthWidth) {
    auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
    const auto& seq_tensor_dims = seq_tensor->dims();
    const auto& pad_tensor_dims = pad_tensor.dims();
    int max_seq_len = MaximumSequenceLength(seq_offsets);
    if (pad_seq_len == -1) {
      pad_seq_len = max_seq_len;
    }
    int step_width = seq_tensor->numel() / seq_tensor_dims[0];
    int seq_num = seq_offsets.size() - 1;

    CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
              step_width, layout);
    /*
    if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) {
      paddle::framework::TensorCopy(pad_tensor, context.GetPlace(), context,
    seq_tensor);
      seq_tensor->Resize(seq_tensor_dims);
      return;
    }
    */

    const int kBlockSize = 512;

    /* Use at least 32 threads to copy the step_width elements of one step,
     * and give each thread at least 8 elements to copy.
     */
    size_t block_dim_x =
        std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = seq_num;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* pad_data = pad_tensor.data<T>();
    T* seq_data = seq_tensor->data<T>();

    paddle::framework::MixVector<size_t> mixv_seq_offsets(&seq_offsets);
    SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
        seq_data, pad_data, nullptr, false,
        mixv_seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
        step_width, norm_by_times, layout);
  }
};

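// Explicit instantiations for the supported element types.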
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;

template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;

template class PaddingLoDTensorFunctor<phi::GPUContext, int>;
template class PaddingLoDTensorFunctor<phi::GPUContext, int64_t>;
template class PaddingLoDTensorFunctor<phi::GPUContext, float>;
template class PaddingLoDTensorFunctor<phi::GPUContext, double>;

template class UnpaddingLoDTensorFunctor<phi::GPUContext, int>;
template class UnpaddingLoDTensorFunctor<phi::GPUContext, int64_t>;
template class UnpaddingLoDTensorFunctor<phi::GPUContext, float>;
template class UnpaddingLoDTensorFunctor<phi::GPUContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle