/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include "paddle/fluid/operators/math/sequence_padding.h"

namespace paddle {
namespace operators {
namespace math {

/* Copies time steps between a packed sequence buffer (seq_data) and a padded
 * buffer (pad_data).
 *
 * Thread mapping:
 *   blockIdx.y                             -> sequence index
 *   blockIdx.x * blockDim.y + threadIdx.y  -> time step within that sequence
 *   threadIdx.x, strided by blockDim.x     -> element within one step of
 *                                             seq_width values
 *
 * seq_offset holds absolute offsets: sequence q occupies rows
 * [seq_offset[q], seq_offset[q + 1]) of seq_data, each row seq_width wide.
 * In pad_data, step s of sequence q starts at
 *   (s * seq_num + q) * seq_width        when output_layout == kLengthBatchWidth
 *   (q * max_seq_len + s) * seq_width    otherwise.
 *
 * Padding == true : seq_data -> pad_data. Steps in [seq_len, max_seq_len) are
 *                   filled with pad_value.
 * Padding == false: pad_data -> seq_data. Padded steps are ignored.
 *
 * If norm_by_times is set, every copied value is multiplied by 1 / seq_len.
 * Note: for integral T this scale truncates (to 0 whenever seq_len > 1).
 *
 * All scalar arguments are passed by value. Kernel arguments are copied to the
 * device at launch, whereas a reference parameter would only carry the address
 * of a host-side object, which device code cannot dereference.
 */
template <typename T, bool Padding>
__global__ void SequencePaddingKernel(
    T* pad_data, T* seq_data, const size_t* seq_offset, const size_t seq_num,
    const size_t max_seq_len, const size_t seq_width, bool norm_by_times,
    const T pad_value, const OutputLayout output_layout) {
  const size_t seq_idx = blockIdx.y;
  const size_t seq_start = seq_offset[seq_idx];
  const size_t seq_len = seq_offset[seq_idx + 1] - seq_start;

  // Widen before multiplying so the step index cannot wrap in 32 bits.
  const size_t seq_step_idx =
      static_cast<size_t>(blockIdx.x) * blockDim.y + threadIdx.y;

  const size_t seq_data_offset = (seq_start + seq_step_idx) * seq_width;
  const size_t pad_data_offset =
      (output_layout == kLengthBatchWidth)
          ? (seq_step_idx * seq_num + seq_idx) * seq_width
          : (seq_idx * max_seq_len + seq_step_idx) * seq_width;

  if (seq_step_idx < seq_len) {
    // A real time step: copy (and optionally scale) one row of seq_width values.
    T scale = norm_by_times ? (1.0f / static_cast<T>(seq_len)) : 1.0f;
    if (Padding) {
      /* seq -> pad */
      for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
        pad_data[pad_data_offset + i] = scale * seq_data[seq_data_offset + i];
      }
    } else {
      /* pad -> seq */
      for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
        seq_data[seq_data_offset + i] = scale * pad_data[pad_data_offset + i];
      }
    }
  } else if (Padding && seq_step_idx < max_seq_len) {
    // A padded step past the end of this sequence: fill with pad_value.
    // Steps at or beyond max_seq_len exist only because the grid is rounded up.
    for (size_t i = threadIdx.x; i < seq_width; i += blockDim.x) {
      pad_data[pad_data_offset + i] = pad_value;
    }
  }
}

template <typename T>
class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  /* Copies the sequences defined by level `lod_level` of seq_tensor's LoD into
   * pad_tensor. The work is enqueued on context's stream.
   *
   *   context       - CUDA device context; supplies the place and the stream.
   *   seq_tensor    - packed input; sequence q occupies rows
   *                   [offset[q], offset[q + 1]) of the absolute LoD offsets.
   *   pad_tensor    - output. Only data<T>() is called here, not
   *                   mutable_data, so it presumably must already be allocated
   *                   with dims that CheckDims accepts.
   *   pad_value     - written to every step past a sequence's end.
   *   norm_by_times - if true, each copied value is multiplied by 1 / seq_len.
   *   lod_level     - LoD level that defines the sequences.
   *   output_layout - kBatchLengthWidth puts step s of sequence q at row
   *                   q * max_seq_len + s; kLengthBatchWidth at row
   *                   s * seq_num + q (rows are seq_width values wide).
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::LoDTensor& seq_tensor,
                  framework::Tensor* pad_tensor,
                  T pad_value = static_cast<T>(0), bool norm_by_times = false,
                  size_t lod_level = 0,
                  OutputLayout output_layout = kBatchLengthWidth) {
    CheckLoD(seq_tensor, lod_level);

    const auto& lod = seq_tensor.lod();
    // ToAbsOffset returns a new LoD by value. Keep it in a named local so the
    // offsets used below outlive this statement. Binding a reference straight
    // into an element of the temporary would leave that reference dangling.
    auto abs_lod = framework::ToAbsOffset(lod);
    auto& seq_offset = abs_lod[lod_level];

    auto seq_tensor_dims = seq_tensor.dims();
    auto pad_tensor_dims = pad_tensor->dims();
    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
    int64_t seq_num = seq_offset.size() - 1;
    int64_t seq_width = seq_tensor.numel() / seq_tensor_dims[0];

    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
              seq_num, seq_width, output_layout);

    // A single sequence needs no padding. With seq_num == 1 both layouts put
    // step s at row s, so without normalization a plain copy suffices.
    if (!norm_by_times && seq_num == 1) {
      TensorCopy(seq_tensor, context.GetPlace(), context, pad_tensor);
      pad_tensor->Resize(pad_tensor_dims);
      return;
    }

    // Nothing to copy or fill. Returning here also avoids launching an empty
    // grid (an invalid launch configuration). For seq_width == 0 it also
    // avoids dividing by a zero block_dim_x below.
    if (seq_num <= 0 || max_seq_len <= 0 || seq_width <= 0) {
      return;
    }

    const int64_t kBlockSize = 512;

    /* One thread per 8 elements of seq_width, rounded up to a multiple of 32
     * (so at least one warp) and capped at kBlockSize. The remaining threads
     * of the block cover further time steps along y.
     */
    size_t block_dim_x =
        std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    // grid.x covers the time steps (block_dim_y per block); grid.y covers the
    // sequences.
    // NOTE(review): CUDA limits gridDim.y to 65535, so a batch with more
    // sequences than that would fail to launch. Confirm callers stay below it.
    size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = seq_num;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* seq_data = seq_tensor.data<T>();
    T* pad_data = pad_tensor->data<T>();

    // The kernel uses one signature for both directions. With Padding == true
    // it only reads seq_data, so casting away const is safe.
    SequencePaddingKernel<T, true><<<grid, threads, 0, context.stream()>>>(
        pad_data, const_cast<T*>(seq_data),
        seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
        seq_width, norm_by_times, pad_value, output_layout);
  }
};

template <typename T>
class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  /* Copies the real (non-padded) steps of pad_tensor back into seq_tensor,
   * using level `lod_level` of seq_tensor's LoD to locate each sequence. The
   * work is enqueued on context's stream.
   *
   *   context       - CUDA device context; supplies the place and the stream.
   *   seq_tensor    - output; its LoD defines the sequences. Only data<T>() is
   *                   called here, not mutable_data, so it presumably must
   *                   already be allocated.
   *   pad_tensor    - padded input; steps past a sequence's end are ignored.
   *   norm_by_times - if true, each copied value is multiplied by 1 / seq_len.
   *   lod_level     - LoD level that defines the sequences.
   *   output_layout - layout of pad_tensor: kBatchLengthWidth reads step s of
   *                   sequence q from row q * max_seq_len + s;
   *                   kLengthBatchWidth from row s * seq_num + q.
   */
  void operator()(const platform::CUDADeviceContext& context,
                  framework::LoDTensor* seq_tensor,
                  const framework::Tensor& pad_tensor,
                  bool norm_by_times = false, size_t lod_level = 0,
                  OutputLayout output_layout = kBatchLengthWidth) {
    CheckLoD(*seq_tensor, lod_level);

    const auto& lod = seq_tensor->lod();
    // ToAbsOffset returns a new LoD by value. Keep it in a named local so the
    // offsets used below outlive this statement. Binding a reference straight
    // into an element of the temporary would leave that reference dangling.
    auto abs_lod = framework::ToAbsOffset(lod);
    auto& seq_offset = abs_lod[lod_level];

    auto seq_tensor_dims = seq_tensor->dims();
    auto pad_tensor_dims = pad_tensor.dims();
    int64_t max_seq_len = MaximumSequenceLength(seq_offset);
    int64_t seq_num = seq_offset.size() - 1;
    int64_t seq_width = seq_tensor->numel() / seq_tensor_dims[0];

    CheckDims(seq_tensor_dims, seq_offset.back(), pad_tensor_dims, max_seq_len,
              seq_num, seq_width, output_layout);

    // A single sequence has no padded steps. With seq_num == 1 both layouts put
    // step s at row s, so without normalization a plain copy suffices.
    if (!norm_by_times && seq_num == 1) {
      TensorCopy(pad_tensor, context.GetPlace(), context, seq_tensor);
      seq_tensor->Resize(seq_tensor_dims);
      return;
    }

    // Nothing to copy. Returning here also avoids launching an empty grid (an
    // invalid launch configuration). For seq_width == 0 it also avoids
    // dividing by a zero block_dim_x below.
    if (seq_num <= 0 || max_seq_len <= 0 || seq_width <= 0) {
      return;
    }

    const int64_t kBlockSize = 512;

    /* One thread per 8 elements of seq_width, rounded up to a multiple of 32
     * (so at least one warp) and capped at kBlockSize. The remaining threads
     * of the block cover further time steps along y.
     */
    size_t block_dim_x =
        std::min(((((seq_width + 7) >> 3) + 31) >> 5) << 5, kBlockSize);
    size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    // grid.x covers the time steps (block_dim_y per block); grid.y covers the
    // sequences.
    // NOTE(review): CUDA limits gridDim.y to 65535, so a batch with more
    // sequences than that would fail to launch. Confirm callers stay below it.
    size_t grid_dim_x = (max_seq_len + block_dim_y - 1) / block_dim_y;
    size_t grid_dim_y = seq_num;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* pad_data = pad_tensor.data<T>();
    T* seq_data = seq_tensor->data<T>();

    // The kernel uses one signature for both directions. With Padding == false
    // it only reads pad_data, so casting away const is safe. pad_value is
    // unused in this direction.
    SequencePaddingKernel<T, false><<<grid, threads, 0, context.stream()>>>(
        const_cast<T*>(pad_data), seq_data,
        seq_offset.CUDAData(context.GetPlace()), seq_num, max_seq_len,
        seq_width, norm_by_times, static_cast<T>(0), output_layout);
  }
};

// Explicit instantiations of the CUDA padding/unpadding functors. They are
// grouped by element type so each Padding/Unpadding pair stays together.

// int
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int>;

// int64_t
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, int64_t>;

// float
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;

// double
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle