/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cstdint>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/sequence_ops/sequence_expand_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;

// Forward expansion kernel.
//
// Launch layout:
//   - One block per source sequence, so gridDim.x must be >= lod_size - 1.
//   - Any 3-D block shape is valid, because each dimension is covered by a
//     strided loop: threadIdx.z over repeats, threadIdx.y over rows, and
//     threadIdx.x over the x_item_length elements of a row.
//   - No shared memory is used.
//
// Block `bid` takes rows [x_lod[bid], x_lod[bid + 1]) of x_data and writes
// them (ref_lod[bid + 1] - ref_lod[bid]) times, back to back, starting at
// output row offset[bid]. Blocks with bid + 1 >= lod_size do nothing.
template <typename T>
__global__ void sequence_expand_kernel(const T* x_data, const size_t* x_lod,
                                       const size_t* ref_lod,
                                       const size_t* offset,
                                       const size_t lod_size,
                                       /* default=1,
                                          the instance length*/
                                       const int x_item_length, T* out_data) {
  const size_t bid = blockIdx.x;
  // Written as `bid + 1 >= lod_size` rather than `bid >= lod_size - 1`.
  // When lod_size == 0, `lod_size - 1` wraps around and the block would read
  // past the LoD arrays.
  if (bid + 1 >= lod_size) return;

  const int64_t x_item_count = x_lod[bid + 1] - x_lod[bid];
  const int64_t repeats = ref_lod[bid + 1] - ref_lod[bid];
  // Row and element indices are 64-bit, because row * x_item_length can
  // exceed INT_MAX for large tensors.
  const int64_t out_offset = offset[bid];
  const int64_t x_offset = x_lod[bid];
  for (int64_t tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
    for (int64_t tid_y = threadIdx.y; tid_y < x_item_count;
         tid_y += blockDim.y) {
      const int64_t out_row = out_offset + tid_z * x_item_count + tid_y;
      const int64_t x_row = x_offset + tid_y;
      for (int tid_x = threadIdx.x; tid_x < x_item_length;
           tid_x += blockDim.x) {
        out_data[out_row * x_item_length + tid_x] =
            x_data[x_row * x_item_length + tid_x];
      }
    }
  }
}

// Backward of sequence_expand_kernel. It uses the same launch layout: one
// block per source sequence and strided z/y/x loops over repeats, rows and
// row elements.
//
// Block `bid` visits the (ref_lod[bid + 1] - ref_lod[bid]) copies of sequence
// `bid` in dout_data, which start at row offset[bid]. It adds each copied row
// into rows [dx_lod[bid], dx_lod[bid + 1]) of dx_data.
//
// Several repeats hit the same dx element, so the add is atomic. dx_data is
// only accumulated into here, never cleared.
template <typename T>
__global__ void sequence_expand_grad_kernel(
    const T* dout_data, const size_t* ref_lod, const size_t* dx_lod,
    const size_t* offset, const size_t lod_size,
    /* default=1,
       the instance length*/
    const int x_item_length, T* dx_data) {
  const size_t bid = blockIdx.x;
  // Written as `bid + 1 >= lod_size` rather than `bid >= lod_size - 1`, so
  // that lod_size == 0 does not wrap around.
  if (bid + 1 >= lod_size) return;

  const int64_t x_item_count = dx_lod[bid + 1] - dx_lod[bid];
  const int64_t repeats = ref_lod[bid + 1] - ref_lod[bid];
  // Row and element indices are 64-bit, because row * x_item_length can
  // exceed INT_MAX for large tensors.
  const int64_t out_offset = offset[bid];
  const int64_t x_offset = dx_lod[bid];

  for (int64_t tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
    for (int64_t tid_y = threadIdx.y; tid_y < x_item_count;
         tid_y += blockDim.y) {
      const int64_t dout_row = out_offset + tid_z * x_item_count + tid_y;
      const int64_t dx_row = x_offset + tid_y;
      for (int tid_x = threadIdx.x; tid_x < x_item_length;
           tid_x += blockDim.x) {
        platform::CudaAtomicAdd(&dx_data[dx_row * x_item_length + tid_x],
                                dout_data[dout_row * x_item_length + tid_x]);
      }
    }
  }
}

// Computes, for each source sequence i, the output row at which its expanded
// copies begin. This is the exclusive prefix sum
//   (*out_offset)[i] = sum over k < i of
//                      (ref_lod[k + 1] - ref_lod[k]) * (x_lod[k + 1] - x_lod[k])
// for i in [0, x_lod.size()). The last entry is therefore the total number of
// output rows.
//
// *out_offset is written by index and never resized, so it must already hold
// at least x_lod.size() entries. ref_lod must have at least x_lod.size()
// entries.
void GetOutputOffset(const framework::Vector<size_t>& x_lod,
                     const framework::Vector<size_t>& ref_lod,
                     framework::Vector<size_t>* out_offset) {
  const int n = static_cast<int>(x_lod.size());
  size_t rows_so_far = 0;
  for (int i = 0; i < n; ++i) {
    (*out_offset)[i] = rows_so_far;
    if (i + 1 == n) break;  // No sequence follows the last offset.
    const size_t repeats = ref_lod[i + 1] - ref_lod[i];
    const size_t seq_len = x_lod[i + 1] - x_lod[i];
    rows_so_far += repeats * seq_len;
  }
}

// Expands `x` into `out` with device-to-device copies on the context's
// stream. When do_copy is false, it only counts the work such an expansion
// would need.
//
// Source sequence i covers rows [x_lod[i - 1], x_lod[i]) of x. It is written
// (ref_lod[i] - ref_lod[i - 1]) times, back to back, into out. Each row holds
// x.numel() / x.dims()[0] elements.
//
// Returns:
//   - do_copy == false: the number of rows a copy-based expansion would move
//     (the sum of repeats * sequence length). The caller uses this count to
//     choose between this path and the kernel.
//   - do_copy == true: 0.
//
// Only out->data<T>() is read here, never mutable_data, so `out` is
// presumably allocated by the caller.
template <typename T>
static int ExpandByMemoryCopy(const platform::CUDADeviceContext& context,
                              const LoDTensor& x, LoDTensor* out,
                              const framework::Vector<size_t>& x_lod,
                              const framework::Vector<size_t>& ref_lod,
                              bool do_copy) {
  // An x with no rows has nothing to expand. Returning early also avoids the
  // integer division by zero when deriving the row width below.
  const int64_t x_rows = x.dims()[0];
  if (x_rows == 0) return 0;
  const int64_t x_item_length = x.numel() / x_rows;

  auto out_data = out->data<T>();
  auto x_data = x.data<T>();
  auto& gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());

  int out_offset = 0;  // Output sequences emitted so far.
  int num_copys = 0;
  for (size_t i = 1; i < ref_lod.size(); ++i) {
    const int repeat_num = ref_lod[i] - ref_lod[i - 1];
    const int x_start = x_lod[i - 1];
    const int x_seq_len = static_cast<int>(x_lod[i]) - x_start;
    if (repeat_num > 0) {
      if (do_copy) {
        // Find the first output row of this group. With a one-level output
        // LoD, the LoD gives it directly. Otherwise the sequence counter is
        // used, which is a valid row index only if every source sequence is
        // a single row.
        // NOTE(review): presumably this is the case where x has no LoD;
        // confirm in the caller.
        int64_t out_start = out_offset;
        if (out->lod().size() == 1) {
          out_start = out->lod()[0][out_offset];
        }
        // The rows of one sequence are contiguous in x and in each output
        // copy. One copy per repeat therefore replaces x_seq_len single-row
        // copies.
        if (x_seq_len > 0) {
          const size_t seq_bytes = sizeof(T) *
                                   static_cast<size_t>(x_seq_len) *
                                   static_cast<size_t>(x_item_length);
          const T* src =
              x_data + static_cast<int64_t>(x_start) * x_item_length;
          for (int j = 0; j < repeat_num; ++j) {
            const int64_t dst_row =
                out_start + static_cast<int64_t>(j) * x_seq_len;
            memory::Copy(gpu_place, out_data + dst_row * x_item_length,
                         gpu_place, src, seq_bytes, context.stream());
          }
        }
      } else {
        num_copys += repeat_num * x_seq_len;
      }
    }
    out_offset += repeat_num;
  }
  return num_copys;
}

template <typename T>
struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
  // Writes into `out` the sequences of `x` (delimited by x_lod). Sequence i is
  // repeated ref_lod[i + 1] - ref_lod[i] times.
  //
  // Expansions that move fewer than 5 rows in total use direct
  // device-to-device copies. Larger ones launch sequence_expand_kernel on the
  // context's stream with a grid of ref_lod.size() blocks.
  void operator()(
      const platform::CUDADeviceContext& context, const LoDTensor& x,
      const framework::Vector<size_t>& x_lod,   /*expand source lod*/
      const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
      LoDTensor* out) {
    // Dry run: count how many rows a copy-based expansion would move.
    const int row_copies =
        ExpandByMemoryCopy<T>(context, x, out, x_lod, ref_lod, false);
    // For only a few rows, direct copies may beat a kernel launch. The
    // cutoff of 5 is a heuristic that may need deeper analysis.
    if (row_copies < 5) {
      ExpandByMemoryCopy<T>(context, x, out, x_lod, ref_lod, true);
      return;
    }

    const int item_len = x.numel() / x.dims()[0];
    const size_t n_x = x_lod.size();
    const size_t n_ref = ref_lod.size();

    // Host staging buffer laid out as [output row offsets | x_lod | ref_lod].
    // A single CUDAData() call then provides all three device arrays.
    framework::Vector<size_t> packed(n_x * 2 + n_ref);
    GetOutputOffset(x_lod, ref_lod, &packed);
    for (size_t k = 0; k < n_x; ++k) packed[n_x + k] = x_lod[k];
    for (size_t k = 0; k < n_ref; ++k) packed[2 * n_x + k] = ref_lod[k];

    const size_t* dev_offsets = packed.CUDAData(context.GetPlace());
    const size_t* dev_x_lod = dev_offsets + n_x;
    const size_t* dev_ref_lod = dev_offsets + 2 * n_x;

    // Block shape: 16 to 32 threads in x, 16 in y, and z chosen so the block
    // totals at most 1024 threads.
    const int tx = std::min(32, std::max(static_cast<int>(n_ref), 16));
    const int ty = 16;
    const int tz = 1024 / tx / ty;
    const dim3 threads(tx, ty, tz);
    const dim3 blocks(static_cast<int>(n_ref), 1);

    sequence_expand_kernel<<<blocks, threads, 0, context.stream()>>>(
        x.data<T>(), dev_x_lod, dev_ref_lod, dev_offsets, n_x, item_len,
        out->mutable_data<T>(context.GetPlace()));
  }
};

template <typename T>
struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
  // Backward of the expansion. Every repeated copy of source sequence i in
  // `dout` is added into the rows of sequence i in `dx`.
  //
  // The kernel only accumulates into dx with atomic adds, and nothing here
  // clears dx, so dx is expected to be zero-filled by the caller.
  // NOTE(review): confirm this in SequenceExpandGradKernel.
  //
  // The kernel indexes x_lod and the offsets up to ref_lod.size() - 1, so
  // x_lod and ref_lod are assumed to have the same number of entries.
  void operator()(const platform::CUDADeviceContext& context,
                  const LoDTensor& dout,
                  const framework::Vector<size_t>& x_lod, /*expand source lod*/
                  const framework::Vector<size_t>& ref_lod, /*expand based lod*/
                  LoDTensor* dx) {
    T* dx_data = dx->mutable_data<T>(context.GetPlace());

    // Nothing to do when there are no sequences (fewer than two offsets in
    // either LoD) or when the gradient has no rows. Returning early also
    // avoids:
    //   - a zero-sized grid, which is an invalid launch configuration;
    //   - host reads past ref_lod inside GetOutputOffset;
    //   - a division by zero below.
    const int64_t dx_rows = dx->dims()[0];
    if (x_lod.size() < 2 || ref_lod.size() < 2 || dx_rows == 0) return;
    const int x_item_length = framework::product(dx->dims()) / dx_rows;

    // Pack [output row offsets | x_lod | ref_lod] into one Vector, as the
    // forward functor does. A single CUDAData() call then stages all three
    // device arrays instead of three separate ones.
    const size_t x_lod_size = x_lod.size();
    const size_t ref_lod_size = ref_lod.size();
    framework::Vector<size_t> packed(x_lod_size * 2 + ref_lod_size);
    GetOutputOffset(x_lod, ref_lod, &packed);
    for (size_t i = 0; i < x_lod_size; ++i) {
      packed[x_lod_size + i] = x_lod[i];
    }
    for (size_t i = 0; i < ref_lod_size; ++i) {
      packed[2 * x_lod_size + i] = ref_lod[i];
    }
    const size_t* offset_data = packed.CUDAData(context.GetPlace());
    const size_t* x_lod_data = offset_data + x_lod_size;
    const size_t* ref_lod_data = offset_data + 2 * x_lod_size;

    const int thread_x =
        std::min(32, std::max(static_cast<int>(ref_lod_size), 16));
    const int thread_y = 16;
    const int thread_z = 1024 / thread_x / thread_y;
    dim3 block_size(thread_x, thread_y, thread_z);
    dim3 grid_size(static_cast<int>(ref_lod_size), 1);

    sequence_expand_grad_kernel<<<grid_size, block_size, 0, context.stream()>>>(
        dout.data<T>(), ref_lod_data, x_lod_data, offset_data, ref_lod_size,
        x_item_length, dx_data);
  }
};

}  // namespace operators
}  // namespace paddle

// CUDA kernel registrations for the forward and gradient ops.
namespace ops = paddle::operators;

// Forward op sequence_expand: CUDA kernels for element types float, double,
// int and int64_t.
REGISTER_OP_CUDA_KERNEL(
    sequence_expand,
    ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, int>,
    ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, int64_t>);

// Gradient op sequence_expand_grad: the same element types as the forward op.
REGISTER_OP_CUDA_KERNEL(
    sequence_expand_grad,
    ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, double>,
    ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, int>,
    ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext,
                                  int64_t>);