sequence_padding.cu 7.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <algorithm>
Y
Yi Wang 已提交
16
#include "paddle/fluid/operators/math/sequence_padding.h"
Y
Yiqun Liu 已提交
17 18 19 20 21

namespace paddle {
namespace operators {
namespace math {

F
fengjiayi 已提交
22
/*
 * Copies data between a LoD (sequence) tensor and a padded tensor.
 *
 * Type == kSeqToPad: dst is the padded tensor and src the sequence tensor.
 *   Steps at or beyond a sequence's length (up to pad_seq_len) are filled
 *   from pad_value.
 * Type == kPadToSeq: dst is the sequence tensor and src the padded tensor.
 *   Only the first seq_len steps of every sequence are copied back.
 *
 * Expected launch layout (as configured by the functors in this file):
 *   blockIdx.y              -> sequence index (gridDim.y == seq_num)
 *   blockIdx.x, threadIdx.y -> time step within the sequence
 *   threadIdx.x             -> strides over the step_width elements of a step
 *
 * seq_offsets holds seq_num + 1 offsets; sequence k occupies steps
 * [seq_offsets[k], seq_offsets[k + 1]) of the sequence tensor.
 * Padded addressing depends on layout:
 *   kBatchLengthWidth -> [seq_num, pad_seq_len, step_width]
 *   otherwise         -> [pad_seq_len, seq_num, step_width]
 * pad_value holds a single element when is_constant_pad is true, otherwise
 * step_width elements. When norm_by_len is true, copied values are scaled by
 * 1 / seq_len.
 */
template <typename T, CopyType Type>
__global__ void SequencePaddingKernel(
    T* dst, const T* src, const T* pad_value, bool is_constant_pad,
    const size_t* seq_offsets, const size_t seq_num, const size_t pad_seq_len,
    const size_t step_width, bool norm_by_len, const PadLayout layout) {
  const size_t seq_idx = blockIdx.y;
  const size_t seq_begin = seq_offsets[seq_idx];
  const size_t seq_len = seq_offsets[seq_idx + 1] - seq_begin;

  // Widen before multiplying so the step index is computed in size_t.
  const size_t step_idx =
      static_cast<size_t>(blockIdx.x) * blockDim.y + threadIdx.y;

  // The last x-block may overhang pad_seq_len; such threads have no row in
  // the padded tensor. For valid inputs (seq_len <= pad_seq_len) this does not
  // skip any step the original logic would have touched.
  if (step_idx >= pad_seq_len) {
    return;
  }

  const size_t seq_data_offset = (seq_begin + step_idx) * step_width;
  const size_t pad_data_offset =
      layout == kBatchLengthWidth
          ? (seq_idx * pad_seq_len + step_idx) * step_width
          : (step_idx * seq_num + seq_idx) * step_width;

  T* dst_data = dst + (Type == kSeqToPad ? pad_data_offset : seq_data_offset);
  const T* src_data =
      src + (Type == kSeqToPad ? seq_data_offset : pad_data_offset);

  if (step_idx < seq_len) {
    if (norm_by_len) {
      const float scale = 1.0f / static_cast<float>(seq_len);
      for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
        dst_data[i] = static_cast<T>(scale * src_data[i]);
      }
    } else {
      // Exact copy. Multiplying by a float 1.0f would convert int / int64_t
      // elements to float and round any value whose magnitude exceeds 2^24.
      for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
        dst_data[i] = src_data[i];
      }
    }
  } else if (Type == kSeqToPad) {
    // Padding region of this sequence: fill from pad_value.
    for (size_t i = threadIdx.x; i < step_width; i += blockDim.x) {
      dst_data[i] = is_constant_pad ? pad_value[0] : pad_value[i];
    }
  }
}

Y
yangyaming 已提交
52 53
/*
 * CUDA specialization of PaddingLoDTensorFunctor.
 *
 * Copies the sequences of `seq_tensor` (split by LoD level `lod_level`, taken
 * as absolute offsets) into `pad_tensor`. Every step past a sequence's length,
 * up to `pad_seq_len`, is filled with `pad_value`.
 *
 * Parameters:
 *   context       - device context; the kernel runs on context.stream().
 *   seq_tensor    - input LoD tensor.
 *   pad_tensor    - output padded tensor. Its buffer is read via data<T>(), so
 *                   it is presumably allocated by the caller (confirm).
 *   pad_value     - 1 element (broadcast to every padded element) or
 *                   step_width elements (one per element of a padded step).
 *   pad_seq_len   - padded length; -1 means "longest sequence".
 *   norm_by_times - if true, copied values are scaled by 1 / seq_len.
 *   layout        - kBatchLengthWidth or the length-major layout.
 *
 * Raises InvalidArgument (via PADDLE_ENFORCE) when pad_seq_len is shorter than
 * the longest sequence or pad_value has an unsupported number of elements.
 * CheckDims performs further shape validation.
 */
template <typename T>
class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::LoDTensor& seq_tensor,
                  framework::LoDTensor* pad_tensor,
                  const framework::LoDTensor& pad_value, int pad_seq_len = -1,
                  int lod_level = 0, bool norm_by_times = false,
                  const PadLayout layout = kBatchLengthWidth) {
    const auto& seq_lod = seq_tensor.lod();
    const auto seq_offsets = framework::ToAbsOffset(seq_lod)[lod_level];
    const auto& seq_tensor_dims = seq_tensor.dims();
    const auto& pad_tensor_dims = pad_tensor->dims();
    int max_seq_len = MaximumSequenceLength(seq_offsets);
    if (pad_seq_len == -1) {
      pad_seq_len = max_seq_len;
    }
    PADDLE_ENFORCE_GE(
        pad_seq_len, max_seq_len,
        platform::errors::InvalidArgument(
            "The pad_seq_len must be equal to or greater than the "
            "original max sequence length. Expected %ld >= %ld, but got %ld < "
            "%ld. Please check the input value.",
            pad_seq_len, max_seq_len, pad_seq_len, max_seq_len));

    // Elements per time step = product of all dims after the first. Computed
    // directly instead of numel() / dims[0], which divides by zero when every
    // sequence is empty (dims[0] == 0).
    int64_t step_width = 1;
    for (int i = 1; i < seq_tensor_dims.size(); ++i) {
      step_width *= seq_tensor_dims[i];
    }
    int seq_num = seq_offsets.size() - 1;

    CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
              step_width, layout);
    PADDLE_ENFORCE_EQ(
        pad_value.numel() == 1 || pad_value.numel() == step_width, true,
        platform::errors::InvalidArgument(
            "The numel of 'pad_value' can only be 1 or be equal to "
            "the 'step_width', but got %ld != 1 and %ld. Please check the "
            "input value.",
            pad_value.numel(), step_width));

    // The padded tensor has no elements: nothing to write. Returning here also
    // avoids a zero-sized grid/block, which CUDA rejects as an invalid launch
    // configuration, and a division by zero when sizing the block below.
    if (seq_num <= 0 || pad_seq_len <= 0 || step_width <= 0) {
      return;
    }

    constexpr int64_t kBlockSize = 512;
    // x-dimension: at least one warp (32 threads) per step, sized so that each
    // thread copies about 8 elements, capped at kBlockSize threads.
    const int64_t threads_for_width = ((step_width + 7) / 8 + 31) / 32 * 32;
    const size_t block_dim_x =
        static_cast<size_t>(std::min(threads_for_width, kBlockSize));
    // y-dimension: remaining threads of the block each handle one time step.
    const size_t block_dim_y = static_cast<size_t>(kBlockSize) / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    // x covers pad_seq_len steps; y has one block row per sequence.
    // NOTE(review): CUDA limits gridDim.y to 65535, so seq_num above that
    // cannot be launched with this configuration.
    const size_t grid_dim_x =
        (static_cast<size_t>(pad_seq_len) + block_dim_y - 1) / block_dim_y;
    const size_t grid_dim_y = static_cast<size_t>(seq_num);
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* seq_data = seq_tensor.data<T>();
    T* pad_data = pad_tensor->data<T>();
    const T* pad_value_data = pad_value.data<T>();

    SequencePaddingKernel<T, kSeqToPad><<<grid, threads, 0, context.stream()>>>(
        pad_data, seq_data, pad_value_data, pad_value.numel() == 1,
        seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
        step_width, norm_by_times, layout);
  }
};

Y
yangyaming 已提交
114 115
/*
 * CUDA specialization of UnpaddingLoDTensorFunctor.
 *
 * Copies the first seq_len steps of every sequence from `pad_tensor` back into
 * `seq_tensor`. The sequence boundaries come from seq_tensor's own LoD at
 * level `lod_level`, converted to absolute offsets.
 *
 * Parameters:
 *   context       - device context; the kernel runs on context.stream().
 *   pad_tensor    - input padded tensor.
 *   seq_tensor    - output LoD tensor. Its LoD and dims must already describe
 *                   the target sequences; its buffer is read via data<T>(), so
 *                   it is presumably allocated by the caller (confirm).
 *   pad_seq_len   - padded length; -1 means "longest sequence".
 *   norm_by_times - if true, copied values are scaled by 1 / seq_len.
 *   layout        - kBatchLengthWidth or the length-major layout.
 *
 * Raises InvalidArgument (via PADDLE_ENFORCE) when pad_seq_len is shorter than
 * the longest sequence. CheckDims performs further shape validation.
 */
template <typename T>
class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::LoDTensor& pad_tensor,
                  framework::LoDTensor* seq_tensor, int pad_seq_len = -1,
                  int lod_level = 0, bool norm_by_times = false,
                  const PadLayout layout = kBatchLengthWidth) {
    auto seq_offsets = framework::ToAbsOffset(seq_tensor->lod())[lod_level];
    const auto& seq_tensor_dims = seq_tensor->dims();
    const auto& pad_tensor_dims = pad_tensor.dims();
    int max_seq_len = MaximumSequenceLength(seq_offsets);
    if (pad_seq_len == -1) {
      pad_seq_len = max_seq_len;
    }
    // Same contract as the padding functor. A shorter pad_seq_len would leave
    // the tail of longer sequences unwritten (the grid only spans pad_seq_len
    // steps) and would read past each padded row.
    PADDLE_ENFORCE_GE(
        pad_seq_len, max_seq_len,
        platform::errors::InvalidArgument(
            "The pad_seq_len must be equal to or greater than the "
            "original max sequence length. Expected %ld >= %ld, but got %ld < "
            "%ld. Please check the input value.",
            pad_seq_len, max_seq_len, pad_seq_len, max_seq_len));

    // Elements per time step = product of all dims after the first. Computed
    // directly instead of numel() / dims[0], which divides by zero when every
    // sequence is empty (dims[0] == 0).
    int64_t step_width = 1;
    for (int i = 1; i < seq_tensor_dims.size(); ++i) {
      step_width *= seq_tensor_dims[i];
    }
    int seq_num = seq_offsets.size() - 1;

    CheckDims(seq_tensor_dims, pad_tensor_dims, seq_offsets, pad_seq_len,
              step_width, layout);

    // Nothing to copy. Returning here also avoids a zero-sized grid/block,
    // which CUDA rejects as an invalid launch configuration, and a division by
    // zero when sizing the block below.
    if (seq_num <= 0 || pad_seq_len <= 0 || step_width <= 0) {
      return;
    }

    constexpr int64_t kBlockSize = 512;
    // x-dimension: at least one warp (32 threads) per step, sized so that each
    // thread copies about 8 elements, capped at kBlockSize threads.
    const int64_t threads_for_width = ((step_width + 7) / 8 + 31) / 32 * 32;
    const size_t block_dim_x =
        static_cast<size_t>(std::min(threads_for_width, kBlockSize));
    // y-dimension: remaining threads of the block each handle one time step.
    const size_t block_dim_y = static_cast<size_t>(kBlockSize) / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    // x covers pad_seq_len steps; y has one block row per sequence.
    // NOTE(review): CUDA limits gridDim.y to 65535, so seq_num above that
    // cannot be launched with this configuration.
    const size_t grid_dim_x =
        (static_cast<size_t>(pad_seq_len) + block_dim_y - 1) / block_dim_y;
    const size_t grid_dim_y = static_cast<size_t>(seq_num);
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* pad_data = pad_tensor.data<T>();
    T* seq_data = seq_tensor->data<T>();

    SequencePaddingKernel<T, kPadToSeq><<<grid, threads, 0, context.stream()>>>(
        seq_data, pad_data, nullptr, false,
        seq_offsets.CUDAData(context.GetPlace()), seq_num, pad_seq_len,
        step_width, norm_by_times, layout);
  }
};

Y
yangyaming 已提交
167 168 169 170 171 172 173 174 175
// Explicit instantiations: emit the CUDA padding and unpadding functors for
// each supported element type, grouped per type.
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;

template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;

template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;

template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;
Y
Yiqun Liu 已提交
176 177 178 179

}  // namespace math
}  // namespace operators
}  // namespace paddle