/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>

#include "paddle/fluid/operators/math/sequence_padding.h"
#include "paddle/phi/backends/gpu/gpu_context.h"

namespace paddle {
namespace operators {
namespace math {

/* Copies sequence data between the packed (LoD) layout and a padded layout.
 *
 * Launch layout:
 *   - blockIdx.y selects the sequence (one block row per sequence), so
 *     gridDim.y must not exceed seq_num: each block reads
 *     seq_offsets[blockIdx.y] and seq_offsets[blockIdx.y + 1].
 *   - blockIdx.x * blockDim.y + threadIdx.y selects the time step.
 *   - threadIdx.x strides over the step_width elements of one time step.
 *   No shared memory is used.
 *
 * Type == kSeqToPad: src is the packed sequence data and dst the padded
 *   data. Time steps in [seq_len, pad_seq_len) are filled from pad_value:
 *   pad_value[0] for every element when is_constant_pad, otherwise
 *   pad_value[i] for element i of the step.
 * Type == kPadToSeq: src is the padded data and dst the packed data.
 *   Padding positions are skipped and pad_value is never read.
 *
 * Padded offset of (seq_idx, step_idx):
 *   layout == kBatchLengthWidth: (seq_idx * pad_seq_len + step_idx) * step_width
 *   any other layout:            (step_idx * seq_num + seq_idx) * step_width
 *
 * When norm_by_len is true, copied (non-padding) elements are multiplied by
 * 1 / seq_len (computed in float). Otherwise they are copied unchanged.
 */
template <typename T, CopyType Type>
__global__ void SequencePaddingKernel(T* dst,
                                      const T* src,
                                      const T* pad_value,
                                      bool is_constant_pad,
                                      const size_t* seq_offsets,
                                      const size_t seq_num,
                                      const size_t pad_seq_len,
                                      const size_t step_width,
                                      bool norm_by_len,
                                      const PadLayout layout) {
  size_t seq_idx = blockIdx.y;
  size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];

  // Widen before multiplying so the step index is computed in size_t.
  size_t step_idx = static_cast<size_t>(blockIdx.x) * blockDim.y + threadIdx.y;
  size_t seq_data_offset = (seq_offsets[seq_idx] + step_idx) * step_width;
  size_t pad_data_offset = layout == kBatchLengthWidth
                               ? (seq_idx * pad_seq_len + step_idx) * step_width
                               : (step_idx * seq_num + seq_idx) * step_width;

  T* dst_data = dst + (Type == kSeqToPad ? pad_data_offset : seq_data_offset);
  const T* src_data =
      src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset);

  if (step_idx < seq_len) {
    if (norm_by_len) {
      float scale = 1.0f / static_cast<float>(seq_len);
      for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
        dst_data[i] = static_cast<T>(scale * src_data[i]);
      }
    } else {
      // Exact copy. Multiplying by a float 1.0f would push int/int64 values
      // through float precision and round any magnitude above 2^24.
      for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
        dst_data[i] = src_data[i];
      }
    }
  } else if (step_idx < pad_seq_len && Type == kSeqToPad) {
    for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
      dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i];
    }
  }
}

/* GPU specialization: packs each sequence of a LoD tensor into a padded
 * tensor, filling the positions past each sequence's end with pad_value.
 */
template <typename T>
class PaddingLoDTensorFunctor<phi::GPUContext, T> {
 public:
  // Copies every sequence of `seq_tensor` (split by its LoD at `lod_level`)
  // into `pad_tensor`. Time steps in [sequence length, pad_seq_len) are
  // filled with `pad_value`.
  //
  // pad_value:     must hold exactly 1 element (used for every padded
  //                element) or exactly step_width elements (one padded step).
  // pad_seq_len:   padded length. -1 means "the longest sequence". It must
  //                not be smaller than the longest sequence.
  // norm_by_times: if true, copied (non-padding) elements are scaled by
  //                1 / sequence length.
  // layout:        padded layout, forwarded to the kernel.
  //
  // NOTE(review): only data<T>() is called on pad_tensor, so it presumably
  // must already be allocated with the padded shape. Confirm against callers.
  //
  // Throws InvalidArgument (via PADDLE_ENFORCE_*) on a too-short
  // pad_seq_len, a badly sized pad_value, or more sequences than one launch
  // can cover.
  void operator()(const phi::GPUContext& context,
                  const framework::LoDTensor& seq_tensor,
                  framework::LoDTensor* pad_tensor,
                  const framework::LoDTensor& pad_value,
                  int pad_seq_len = -1,
                  int lod_level = 0,
                  bool norm_by_times = false,
                  const PadLayout layout = kBatchLengthWidth) {
    const auto& seq_lod = seq_tensor.lod();
    auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
    const auto& seq_tensor_dims = seq_tensor.dims();
    const auto& pad_tensor_dims = pad_tensor->dims();
    int max_seq_len = MaximumSequenceLength(seq_offsets);
    if (pad_seq_len == -1) {
      pad_seq_len = max_seq_len;
    }
    PADDLE_ENFORCE_GE(
        pad_seq_len,
        max_seq_len,
        platform::errors::InvalidArgument(
            "The pad_seq_len must be equal to or greater than the "
            "original max sequence length. Expected %ld >= %ld, but got %ld < "
            "%ld. Please check the input value.",
            pad_seq_len,
            max_seq_len,
            pad_seq_len,
            max_seq_len));
    int step_width = seq_tensor.numel() / seq_tensor_dims[0];
    int seq_num = seq_offsets.size() - 1;

    CheckDims(seq_tensor_dims,
              pad_tensor_dims,
              seq_offsets,
              pad_seq_len,
              step_width,
              layout);
    PADDLE_ENFORCE_EQ(
        pad_value.numel() == 1 || pad_value.numel() == step_width,
        true,
        platform::errors::InvalidArgument(
            "The numel of 'pad_value' can only be 1 or be equal to "
            "the 'step_width', but got %ld != 1 and %ld. Please check the "
            "input value.",
            pad_value.numel(),
            step_width));

    // Nothing to copy. Returning here also avoids a host-side division by
    // zero (block_dim_x below is 0 when step_width == 0) and an invalid
    // zero-sized grid launch when seq_num or pad_seq_len is 0.
    if (seq_num <= 0 || pad_seq_len <= 0 || step_width <= 0) {
      return;
    }

    // The kernel maps one sequence to each blockIdx.y, and gridDim.y is
    // capped at 65535 on CUDA devices. Fail loudly rather than launch a
    // kernel that silently never runs.
    const int kMaxGridDimY = 65535;
    PADDLE_ENFORCE_EQ(
        seq_num <= kMaxGridDimY,
        true,
        platform::errors::InvalidArgument(
            "The number of sequences must not exceed %d (the CUDA gridDim.y "
            "limit), but got %d.",
            kMaxGridDimY,
            seq_num));

    const int kBlockSize = 512;

    /* At least use 32 threads to copy sequence_width elements,
     * and at least 8 elements for each thread.
     */
    size_t block_dim_x =
        std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = seq_num;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* seq_data = seq_tensor.data<T>();
    T* pad_data = pad_tensor->data<T>();
    const T* pad_value_data = pad_value.data<T>();

    paddle::framework::MixVector<size_t> mix_vector_seq_offsets(&seq_offsets);
    SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
        pad_data,
        seq_data,
        pad_value_data,
        pad_value.numel() == 1,
        mix_vector_seq_offsets.CUDAData(context.GetPlace()),
        seq_num,
        pad_seq_len,
        step_width,
        norm_by_times,
        layout);
  }
};

/* GPU specialization: copies the valid time steps of a padded tensor back
 * into a packed LoD sequence tensor.
 */
template <typename T>
class UnpaddingLoDTensorFunctor<phi::GPUContext, T> {
 public:
  // Copies the non-padding time steps of `pad_tensor` into `seq_tensor`.
  // Sequence boundaries come from the LoD of `seq_tensor` itself at
  // `lod_level`, so that LoD must already be set.
  //
  // NOTE(review): only data<T>() is called on seq_tensor, so its buffer
  // presumably must already be allocated. Confirm against callers.
  //
  // pad_seq_len:   length of the padded time dimension. -1 means "the
  //                longest sequence". It must not be smaller than the
  //                longest sequence. Otherwise the kernel would read outside
  //                a sequence's padded slot and leave tail steps unwritten.
  // norm_by_times: if true, copied elements are scaled by 1 / sequence length.
  // layout:        padded layout, forwarded to the kernel.
  //
  // Throws InvalidArgument (via PADDLE_ENFORCE_*) on a too-short
  // pad_seq_len or more sequences than one launch can cover.
  void operator()(const phi::GPUContext& context,
                  const framework::LoDTensor& pad_tensor,
                  framework::LoDTensor* seq_tensor,
                  int pad_seq_len = -1,
                  int lod_level = 0,
                  bool norm_by_times = false,
                  const PadLayout layout = kBatchLengthWidth) {
    auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
    const auto& seq_tensor_dims = seq_tensor->dims();
    const auto& pad_tensor_dims = pad_tensor.dims();
    int max_seq_len = MaximumSequenceLength(seq_offsets);
    if (pad_seq_len == -1) {
      pad_seq_len = max_seq_len;
    }
    // Same precondition the padding functor enforces.
    PADDLE_ENFORCE_GE(
        pad_seq_len,
        max_seq_len,
        platform::errors::InvalidArgument(
            "The pad_seq_len must be equal to or greater than the "
            "original max sequence length. Expected %ld >= %ld, but got %ld < "
            "%ld. Please check the input value.",
            pad_seq_len,
            max_seq_len,
            pad_seq_len,
            max_seq_len));
    int step_width = seq_tensor->numel() / seq_tensor_dims[0];
    int seq_num = seq_offsets.size() - 1;

    CheckDims(seq_tensor_dims,
              pad_tensor_dims,
              seq_offsets,
              pad_seq_len,
              step_width,
              layout);

    // Nothing to copy. Returning here also avoids a host-side division by
    // zero (block_dim_x below is 0 when step_width == 0) and an invalid
    // zero-sized grid launch when seq_num or pad_seq_len is 0.
    if (seq_num <= 0 || pad_seq_len <= 0 || step_width <= 0) {
      return;
    }

    // The kernel maps one sequence to each blockIdx.y, and gridDim.y is
    // capped at 65535 on CUDA devices. Fail loudly rather than launch a
    // kernel that silently never runs.
    const int kMaxGridDimY = 65535;
    PADDLE_ENFORCE_EQ(
        seq_num <= kMaxGridDimY,
        true,
        platform::errors::InvalidArgument(
            "The number of sequences must not exceed %d (the CUDA gridDim.y "
            "limit), but got %d.",
            kMaxGridDimY,
            seq_num));

    const int kBlockSize = 512;

    /* At least use 32 threads to copy sequence_width elements,
     * and at least 8 elements for each thread.
     */
    size_t block_dim_x =
        std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = seq_num;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* pad_data = pad_tensor.data<T>();
    T* seq_data = seq_tensor->data<T>();

    paddle::framework::MixVector<size_t> mixv_seq_offsets(&seq_offsets);
    // pad_value is unused when unpadding (the kernel only reads it for
    // kSeqToPad), hence nullptr / false.
    SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
        seq_data,
        pad_data,
        nullptr,
        false,
        mixv_seq_offsets.CUDAData(context.GetPlace()),
        seq_num,
        pad_seq_len,
        step_width,
        norm_by_times,
        layout);
  }
};

// Explicitly instantiate both GPU functors (padding and unpadding) for each
// element type this translation unit supports.
#define INSTANTIATE_GPU_SEQUENCE_PADDING_FUNCTORS(T)          \
  template class PaddingLoDTensorFunctor<phi::GPUContext, T>; \
  template class UnpaddingLoDTensorFunctor<phi::GPUContext, T>

INSTANTIATE_GPU_SEQUENCE_PADDING_FUNCTORS(int);
INSTANTIATE_GPU_SEQUENCE_PADDING_FUNCTORS(int64_t);
INSTANTIATE_GPU_SEQUENCE_PADDING_FUNCTORS(float);
INSTANTIATE_GPU_SEQUENCE_PADDING_FUNCTORS(double);

#undef INSTANTIATE_GPU_SEQUENCE_PADDING_FUNCTORS

}  // namespace math
}  // namespace operators
}  // namespace paddle