/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include "paddle/fluid/operators/math/sequence_padding.h"

namespace paddle {
namespace operators {
namespace math {

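// Copies one time step (a row of step_width elements) per (blockIdx.x, threadIdx.y)
// pair between the unpadded sequence layout and the padded layout; blockIdx.y selects
// the sequence and threadIdx.x strides over the elements of the step. The copy
// direction (kSeqToPad or kPadToSeq) is fixed at compile time by the CopyType
// template parameter.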
template <typename T, CopyType Type>
__global__ void SequencePaddingKernel(
    T* dst, const T* src, const T* pad_value, bool is_constant_pad,
    const size_t* seq_offsets, const size_t seq_num, const size_t pad_seq_len,
    const size_t step_width, bool norm_by_len, bool norm_by_batchsize,
    bool norm_by_total_logits_len, int total_logits_len,
    const PadLayout layout) {
  size_t seq_idx = blockIdx.y;
  size_t seq_len = seq_offsets[seq_idx + 1] - seq_offsets[seq_idx];

  size_t step_idx = blockIdx.x * blockDim.y + threadIdx.y;
  size_t seq_data_offset = (seq_offsets[seq_idx] + step_idx) * step_width;
  size_t pad_data_offset = layout == kBatchLengthWidth
                               ? (seq_idx * pad_seq_len + step_idx) * step_width
                               : (step_idx * seq_num + seq_idx) * step_width;

  T* dst_data = dst + (Type == kSeqToPad ? pad_data_offset : seq_data_offset);
  const T* src_data =
      src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset);

  if (step_idx < seq_len) {
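    // Valid (in-sequence) step: copy the row, optionally scaled by one of the
    // normalization factors (total logits length, batch size, or sequence length).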
    float scale = 1.0f;
    if (norm_by_total_logits_len) {
      scale = 1.0f / static_cast<float>(total_logits_len);
    } else if (norm_by_batchsize) {
      scale = 1.0f / static_cast<float>(seq_num);
    } else if (norm_by_len) {
      scale = 1.0f / static_cast<float>(seq_len);
    }

    for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
      dst_data[i] = scale * src_data[i];
    }
  } else if (step_idx < pad_seq_len && Type == kSeqToPad) {
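    // Padding region (seq-to-pad only): fill the remaining steps with pad_value,
    // broadcasting a single scalar when is_constant_pad is true.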
    for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
      dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i];
    }
  }
}

template <typename T>
class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::LoDTensor& seq_tensor,
                  framework::LoDTensor* pad_tensor,
                  const framework::LoDTensor& pad_value, int pad_seq_len = -1,
                  int lod_level = 0, bool norm_by_times = false,
                  bool norm_by_batchsize = false,
                  bool norm_by_total_logits_len = false,
                  const PadLayout layout = kBatchLengthWidth) {
    auto seq_lod = seq_tensor.lod();
    const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
    const auto& seq_tensor_dims = seq_tensor.dims();
    const auto& pad_tensor_dims = pad_tensor->dims();
    int max_seq_len = MaximumSequenceLength(seq_offsets);
    if (pad_seq_len == -1) {
      pad_seq_len = max_seq_len;
    }
    PADDLE_ENFORCE_GE(
        pad_seq_len, max_seq_len,
        platform::errors::InvalidArgument(
            "The pad_seq_len must be equal to or greater than the "
            "original max sequence length. Expected %ld >= %ld, but got %ld < "
            "%ld. Please check the input value.",
            pad_seq_len, max_seq_len, pad_seq_len, max_seq_len));
    int step_width = seq_tensor.numel() / seq_tensor_dims[0];
    int seq_num = seq_offsets.size() - 1;

    CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
              step_width, layout);
    PADDLE_ENFORCE_EQ(
        pad_value.numel() == 1 || pad_value.numel() == step_width, true,
        platform::errors::InvalidArgument(
            "The numel of 'pad_value' can only be 1 or be equal to "
            "the 'step_width', but got %ld != 1 and %ld. Please check the "
            "input value.",
            pad_value.numel(), step_width));

    const int kBlockSize = 512;

    /* At least use 32 threads to copy step_width elements,
     * and at least 8 elements for each thread.
     */
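    // i.e. block_dim_x = min(ceil(step_width / 8) rounded up to a multiple of 32,
    // kBlockSize).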
    size_t block_dim_x =
        std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = seq_num;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* seq_data = seq_tensor.data<T>();
    T* pad_data = pad_tensor->data<T>();
    const T* pad_value_data = pad_value.data<T>();

    SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
        pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
        seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
        step_width, norm_by_times, false, false, 0, layout);
  }
};

template <typename T>
class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::LoDTensor& pad_tensor,
                  framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
                  int lod_level = 0, bool norm_by_times = false,
                  bool norm_by_batchsize = false,
                  bool norm_by_total_logits_len = false,
                  const PadLayout layout = kBatchLengthWidth) {
    auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
    const auto& seq_tensor_dims = seq_tensor->dims();
    const auto& pad_tensor_dims = pad_tensor.dims();
    int max_seq_len = MaximumSequenceLength(seq_offsets);
    if (pad_seq_len == -1) {
      pad_seq_len = max_seq_len;
    }
    int total_logits_len = TotalSequenceLength(seq_offsets);
    int step_width = seq_tensor->numel() / seq_tensor_dims[0];
    int seq_num = seq_offsets.size() - 1;

    CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
              step_width, layout);
    /*
    if (!norm_by_times && seq_num == 1UL && pad_seq_len == max_seq_len) {
      TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor);
      seq_tensor->Resize(seq_tensor_dims);
      return;
    }
    */

    const int kBlockSize = 512;

    /* At least use 32 threads to copy step_width elements,
     * and at least 8 elements for each thread.
     */
    size_t block_dim_x =
        std::min(((((step_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    size_t grid_dim_x = (pad_seq_len + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = seq_num;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* pad_data = pad_tensor.data<T>();
    T* seq_data = seq_tensor->data<T>();

    SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
        seq_data, pad_data, nullptr, false,
        seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
        step_width, norm_by_times, norm_by_batchsize, norm_by_total_logits_len,
        total_logits_len, layout);
  }
};

template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;

template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle