/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include <stdio.h>
#include <algorithm>
#include "paddle/fluid/operators/sequence_expand_op.h"
#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;

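// Expands each sequence of x according to ref_lod: sequence i, holding
// x_lod[i + 1] - x_lod[i] instances, is copied ref_lod[i + 1] - ref_lod[i]
// times into out. For example, x_lod = {0, 2, 3} with ref_lod = {0, 3, 5}
// repeats sequence 0 (2 instances) 3 times and sequence 1 (1 instance)
// 2 times, producing 2 * 3 + 1 * 2 = 8 output instances.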
template <typename T>
__global__ void sequence_expand_kernel(const T* x_data, const size_t* x_lod,
                                       const size_t* ref_lod,
                                       const size_t lod_size,
                                       /* the length of one instance
                                          (default 1) */
                                       const int x_item_length, T* out_data) {
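  // Each thread redundantly computes the same prefix sum of per-sequence
  // output offsets in shared memory; note this assumes lod_size <= N.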
  constexpr int N = 1024;
  __shared__ int mem[N];
  int offset = 0;
  for (int i = 0; i < lod_size; ++i) {
    mem[i] = offset;
    if (i < lod_size - 1) {
      offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
    }
  }
  __syncthreads();

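  // One thread block handles one source sequence; the 3-D thread index tiles
  // the (repeat, instance, element-within-instance) copy loops below.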
  int bid = blockIdx.x;
  if (bid >= lod_size - 1) return;

  int x_item_count = x_lod[bid + 1] - x_lod[bid];
  int repeats = ref_lod[bid + 1] - ref_lod[bid];
  int out_offset = mem[bid];
  int x_offset = x_lod[bid];
  for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
    for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
      for (int tid_x = threadIdx.x; tid_x < x_item_length;
           tid_x += blockDim.x) {
        out_data[(out_offset + tid_z * x_item_count + tid_y) * x_item_length +
                 tid_x] = x_data[(x_offset + tid_y) * x_item_length + tid_x];
      }
    }
  }
}

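// Backward pass of the expansion above: the gradient of every repeated copy
// is scattered back and summed into its source element of dx.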
template <typename T>
__global__ void sequence_expand_grad_kernel(const T* dout_data,
                                            const size_t* ref_lod,
                                            const size_t* dx_lod,
                                            const size_t lod_size,
                                            /* the length of one instance
                                               (default 1) */
                                            const int x_item_length,
                                            T* dx_data) {
  // TODO(dzhwinter) : too many atomicAdd
  // use shared memory to reduce memory visits
  constexpr int N = 1024;
  __shared__ int mem[N];
  int offset = 0;
  for (int i = 0; i < lod_size; ++i) {
    mem[i] = offset;
    if (i < lod_size - 1) {
      offset += (ref_lod[i + 1] - ref_lod[i]) * (dx_lod[i + 1] - dx_lod[i]);
    }
  }
  __syncthreads();

  int bid = blockIdx.x;
  if (bid >= lod_size - 1) return;
  int x_item_count = dx_lod[bid + 1] - dx_lod[bid];
  int repeats = ref_lod[bid + 1] - ref_lod[bid];
  int out_offset = mem[bid];
  int x_offset = dx_lod[bid];

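  // All repeats (tid_z) of an instance read distinct out positions but write
  // the same dx slot, so the accumulation below must be atomic.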
  for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
    for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
      for (int tid_x = threadIdx.x; tid_x < x_item_length;
           tid_x += blockDim.x) {
        platform::CudaAtomicAdd(
            &dx_data[(x_offset + tid_y) * x_item_length + tid_x],
            dout_data[(out_offset + tid_z * x_item_count + tid_y) *
                          x_item_length +
                      tid_x]);
      }
    }
  }
}

template <typename T>
struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
  void operator()(
      const platform::CUDADeviceContext& context, const LoDTensor& x,
      const framework::Vector<size_t>& x_lod,   /*expand source lod*/
      const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
      LoDTensor* out) {
    int x_item_length = x.numel() / x.dims()[0];
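    // Launch one block per sequence; inside the kernel, block dims (x, y, z)
    // tile the element, instance, and repeat loops respectively.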
    // Cap thread_x at 32: taking std::max here let a long ref_lod push the
    // block past the 1024-thread limit and drive thread_z down to 0.
    int thread_x = std::min(static_cast<int>(ref_lod.size()), 32);
    int thread_y = std::max(1024 / thread_x, 16);
    int thread_z = std::min(1024 / thread_x / thread_y, 16);
    int block_x = static_cast<int>(ref_lod.size());
    dim3 block_size(thread_x, thread_y, thread_z);
    dim3 grid_size(block_x, 1);

    sequence_expand_kernel<<<grid_size, block_size, 0, context.stream()>>>(
        x.data<T>(), x_lod.CUDAData(context.GetPlace()),
        ref_lod.CUDAData(context.GetPlace()), x_lod.size(), x_item_length,
        out->mutable_data<T>(context.GetPlace()));
  }
};

template <typename T>
struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const LoDTensor& dout,
                  const framework::Vector<size_t>& x_lod, /*expand source lod*/
                  const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
                  LoDTensor* dx) {
    int x_item_length = framework::product(dx->dims()) / dx->dims()[0];

    // Same capped launch configuration as the forward functor (see above).
    int thread_x = std::min(static_cast<int>(ref_lod.size()), 32);
    int thread_y = std::max(1024 / thread_x, 16);
    int thread_z = std::min(1024 / thread_x / thread_y, 16);
    int block_x = static_cast<int>(ref_lod.size());
    dim3 block_size(thread_x, thread_y, thread_z);
    dim3 grid_size(block_x, 1);
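    // dx is only accumulated into via atomicAdd, so it is assumed to have
    // been zero-initialized by the caller before this launch.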
    sequence_expand_grad_kernel<<<grid_size, block_size, 0, context.stream()>>>(
        dout.data<T>(), ref_lod.CUDAData(context.GetPlace()),
        x_lod.CUDAData(context.GetPlace()), ref_lod.size(), x_item_length,
        dx->mutable_data<T>(context.GetPlace()));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    sequence_expand,
    ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, int>,
    ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
    sequence_expand_grad,
    ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext,
                                  int64_t>);