// sequence2batch.cu (2.9 KB), originally committed by dangqingqing.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/sequence2batch.h"

namespace paddle {
namespace operators {
namespace math {

// Copies whole rows of a row-major `height x width` matrix between `src` and
// `dst`, permuting them through `index`.
//
// Work distribution: blockIdx.x and threadIdx.y together select a starting row
// (blockIdx.x + threadIdx.y * GridDimX); each thread then advances by
// BlockDimY * GridDimX rows per iteration. Within a row, threadIdx.x walks the
// columns with stride BlockDimX. The template parameters must therefore match
// the launch configuration: blockDim == (BlockDimX, BlockDimY) and
// gridDim == (GridDimX, 1).
//
// For every row id in [0, height):
//   is_src_index == true : dst row id        <- src row index[id]   (gather)
//   is_src_index == false: dst row index[id] <- src row id          (scatter)
//
// `index` is dereferenced on the device, so it must be device-accessible and
// hold at least `height` entries.
template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void CopyMatrixRowsKernel(const T* src, T* dst, const size_t* index,
                                     int64_t height, int64_t width,
                                     bool is_src_index) {
  const int64_t row_stride = static_cast<int64_t>(BlockDimY) * GridDimX;
  for (int64_t id = blockIdx.x + static_cast<int64_t>(threadIdx.y) * GridDimX;
       id < height; id += row_stride) {
    // Keep row numbers 64-bit: index[] holds size_t, and narrowing it to int
    // would silently truncate rows >= 2^31.
    const int64_t src_idx = is_src_index ? static_cast<int64_t>(index[id]) : id;
    const int64_t dst_idx = is_src_index ? id : static_cast<int64_t>(index[id]);
    const T* src_data = src + src_idx * width;
    T* dst_data = dst + dst_idx * width;
    for (int64_t i = threadIdx.x; i < width; i += BlockDimX) {
      dst_data[i] = src_data[i];
    }
  }
}

// GPU specialization of CopyMatrixRowsFunctor. It copies rows from the rank-2
// `src` into the rank-2 `dst` according to `index`. CopyMatrixRowsKernel
// documents the gather/scatter meaning of `is_src_index`. The copy is enqueued
// on the stream of `context`, so it is asynchronous with respect to the host.
template <typename T>
class CopyMatrixRowsFunctor<platform::GPUPlace, T> {
 public:
  // context      - downcast to platform::CUDADeviceContext; its stream is used.
  // src, dst     - rank-2 tensors with equal widths (dims()[1]).
  // index        - device pointer with at least dst.dims()[0] entries.
  // is_src_index - true: index[i] names the src row copied into dst row i;
  //                false: index[i] names the dst row receiving src row i.
  void operator()(const platform::DeviceContext& context,
                  const framework::LoDTensor& src, const size_t* index,
                  framework::LoDTensor& dst, bool is_src_index) {
    auto src_dims = src.dims();
    auto dst_dims = dst.dims();
    PADDLE_ENFORCE_EQ(src_dims.size(), 2,
                      "The src must be matrix with rank 2.");
    PADDLE_ENFORCE_EQ(dst_dims.size(), 2,
                      "The dst must be matrix with rank 2.");
    PADDLE_ENFORCE_EQ(src_dims[1], dst_dims[1],
                      "The width of src and dst must be same.");
    auto height = dst_dims[0];
    auto width = dst_dims[1];
    if (!is_src_index) {
      // In scatter mode the kernel reads src rows 0..height-1 directly, so
      // src must have at least as many rows as dst; otherwise it reads OOB.
      PADDLE_ENFORCE_GE(src_dims[0], height,
                        "The src must have at least as many rows as dst.");
    }
    auto* src_data = src.data<T>();
    auto* dst_data = dst.data<T>();

    // Single source of truth for the launch shape: the kernel's template
    // arguments must agree with the actual block/grid dimensions.
    constexpr int kBlockDimX = 128;
    constexpr int kBlockDimY = 8;
    constexpr int kGridDimX = 8;
    dim3 threads(kBlockDimX, kBlockDimY);
    dim3 grid(kGridDimX, 1);

    // CUDADeviceContext derives from DeviceContext; static_cast is the
    // correct (offset-adjusting) downcast, unlike reinterpret_cast.
    auto stream =
        static_cast<const platform::CUDADeviceContext&>(context).stream();
    CopyMatrixRowsKernel<T, kBlockDimX, kBlockDimY, kGridDimX><<<
        grid, threads, 0, stream>>>(src_data, dst_data, index, height, width,
                                    is_src_index);
    // Surface launch-configuration errors immediately; faults during
    // execution will appear at the next synchronizing call on the stream.
    PADDLE_ENFORCE(cudaGetLastError(),
                   "Failed to launch CopyMatrixRowsKernel.");
  }
};

// Explicit instantiations of the GPU functors for the supported element types
// (float and double).
template class CopyMatrixRowsFunctor<platform::GPUPlace, float>;
template class CopyMatrixRowsFunctor<platform::GPUPlace, double>;

template class LoDTensor2BatchFunctor<platform::GPUPlace, float>;
template class LoDTensor2BatchFunctor<platform::GPUPlace, double>;
template class Batch2LoDTensorFunctor<platform::GPUPlace, float>;
template class Batch2LoDTensorFunctor<platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle