ContextProjectionOpGpu.cu 15.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "ContextProjectionOp.h"
L
liaogang 已提交
16
#include "hl_base.h"
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32

namespace paddle {

/**
 * Forward kernel of context projection.
 *
 * Expected launch (see hl_context_projection_forward): one block per sequence,
 * selected by blockIdx.x, with a 1-D thread block whose threads stride over
 * the feature columns threadIdx.x, threadIdx.x + blockDim.x, ... < input_dim.
 *
 * sequence[s] and sequence[s + 1] are read as the first and one-past-last row
 * of sequence s. Output row t of a sequence consists of context_length slots
 * of input_dim values; slot c of row t refers to source position
 * t + context_start + c. For every source position p = i + context_start,
 * i in [0, seq_len + context_length - 1), the kernel fetches column `col`
 * either from `input` (0 <= p < seq_len) or, when `padding` is true, from the
 * padding rows of `weight` (positions outside the sequence are skipped when
 * `padding` is false). It then adds that value to every output slot that
 * refers to p.
 *
 * Results are accumulated with `+=`, so `output` must already hold valid
 * values (typically zeros) when the kernel runs.
 *
 * Offsets that scale with the global row index are computed in size_t:
 * seq_start * input_dim * context_length can exceed INT_MAX for large batches
 * and would overflow in 32-bit int arithmetic.
 */
template <bool padding>
__global__ void KeContextProjectionForward(const real* input,
                                           const int* sequence,
                                           const real* weight,
                                           real* output,
                                           int input_dim,
                                           int context_length,
                                           int context_start,
                                           int begin_pad) {
  const int block_size = blockDim.x;
  const int sequenceId = blockIdx.x;
  const int seq_start = sequence[sequenceId];
  const int seq_end = sequence[sequenceId + 1];
  const int seq_len = seq_end - seq_start;
  const int instances = seq_len + context_length - 1;

  // Number of elements in one output row: context_length slots of input_dim.
  const size_t row_stride = static_cast<size_t>(input_dim) * context_length;
  // Pointer step from (row t, slot c) to (row t + 1, slot c - 1); both refer
  // to the same source position.
  const int diag_step = (context_length - 1) * input_dim;

  output += static_cast<size_t>(seq_start) * row_stride;
  input += static_cast<size_t>(seq_start) * input_dim;

  for (int col = threadIdx.x; col < input_dim; col += block_size) {
    for (int i = 0; i < instances; i++) {
      const int pos = i + context_start;  // source position in the sequence
      real value;
      if (pos < 0) {
        if (!padding) continue;
        // NOTE(review): reads weight row i rather than pos + begin_pad; the
        // two coincide only when begin_pad == -context_start, which the
        // callers presumably guarantee — confirm.
        value = weight[static_cast<size_t>(i) * input_dim + col];
      } else if (pos >= seq_len) {
        if (!padding) continue;
        // End padding rows are stored after the begin_pad begin-padding rows.
        value = weight[static_cast<size_t>(begin_pad + pos - seq_len) *
                           input_dim +
                       col];
      } else {
        value = input[static_cast<size_t>(pos) * input_dim + col];
      }

      // outx: largest slot index referring to pos; outy: the row owning it.
      const int outx = (i - context_length) < 0 ? i : (context_length - 1);
      const int outy =
          (i - context_length) < 0 ? 0 : (i - (context_length - 1));
      real* output_r = output + static_cast<size_t>(outy) * row_stride +
                       static_cast<size_t>(outx) * input_dim;
      // Walk down-left until slot 0 is reached or the sequence ends.
      for (int j = outy; j < seq_len; j++) {
        output_r[col] += value;
        if (j - outy == outx) break;
        output_r += diag_step;
      }
    }
  }
}

X
xutianbing 已提交
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
/**
 * @brief   Context projection forward.
 *
 * @param[in]   input           input sequence.
 * @param[in]   sequence        sequence index (num_sequences + 1 row offsets).
 * @param[in]   weight          padding data; may be null, in which case
 *                              positions outside a sequence contribute nothing.
 * @param[out]  output          output sequence; values are accumulated (+=).
 * @param[in]   num_sequences    number of sequences.
 * @param[in]   input_dim        input sequence dimension.
 * @param[in]   context_length   context length.
 * @param[in]   context_start    context start.
 * @param[in]   begin_pad        number of extra timesteps added at the
 *                               beginning.
 *
 * Launches one 128-thread block per sequence on STREAM_DEFAULT. When
 * num_sequences is 0 nothing is launched: a grid with zero blocks is an
 * invalid launch configuration.
 */
void hl_context_projection_forward(const real* input,
                                   const int* sequence,
                                   const real* weight,
                                   real* output,
                                   size_t num_sequences,
                                   size_t input_dim,
                                   size_t context_length,
                                   int context_start,
                                   size_t begin_pad) {
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(output);

  // Empty batch: nothing to project, and a 0-block grid would fail to launch.
  if (num_sequences == 0) {
    return;
  }

  int block_size = 128;
  int blocks_x = num_sequences;
  int blocks_y = 1;
  dim3 threads(block_size, 1);
  dim3 grid(blocks_x, blocks_y);

  // The padding kernel variant is selected only when padding data exists.
  if (weight) {
    KeContextProjectionForward<true><<<grid, threads, 0, STREAM_DEFAULT>>>(
        input,
        sequence,
        weight,
        output,
        input_dim,
        context_length,
        context_start,
        begin_pad);
  } else {
    KeContextProjectionForward<false><<<grid, threads, 0, STREAM_DEFAULT>>>(
        input,
        sequence,
        weight,
        output,
        input_dim,
        context_length,
        context_start,
        begin_pad);
  }
  CHECK_SYNC("hl_context_projection_forward failed");
}

/**
 * GPU specialization of ContextProjectionForward.
 *
 * Unpacks the matrices into raw device pointers and dimensions and forwards
 * them to hl_context_projection_forward. An empty `weight` matrix is passed on
 * as a null pointer, which selects the non-padding kernel there.
 */
template <>
void ContextProjectionForward<DEVICE_TYPE_GPU>(GpuMatrix& output,
                                               const GpuMatrix& input,
                                               const GpuMatrix& weight,
                                               const GpuIVector& sequence,
                                               size_t context_length,
                                               int context_start,
                                               size_t begin_pad) {
  // `sequence` stores one trailing end offset, so it has one entry more than
  // there are sequences.
  const size_t num_sequences = sequence.getSize() - 1;
  const real* padding_data = weight ? weight.getData() : nullptr;

  hl_context_projection_forward(input.getData(),
                                sequence.getData(),
                                padding_data,
                                output.getData(),
                                num_sequences,
                                input.getWidth(),
                                context_length,
                                context_start,
                                begin_pad);
}

154
/**
 * Backward kernel of context projection with respect to the input.
 *
 * Expected launch (see hl_context_projection_backward_data): one block per
 * sequence, selected by blockIdx.x; threads stride over the feature columns
 * threadIdx.x, threadIdx.x + blockDim.x, ... < input_dim.
 *
 * For every position `pos` inside the sequence, the gradients of all output
 * slots that referred to `pos` in the forward pass are summed and added to
 * in_grad (read-modify-write, so in_grad must already hold valid values).
 * Positions outside the sequence (padding) are skipped here.
 */
__global__ void KeContextProjectionBackwardData(const real* out_grad,
                                                const int* sequence,
                                                real* in_grad,
                                                size_t input_dim,
                                                int context_length,
                                                int context_start) {
  const int stride = blockDim.x;
  const int first_row = sequence[blockIdx.x];
  const int last_row = sequence[blockIdx.x + 1];
  const int seq_len = last_row - first_row;
  const int num_instances = seq_len + context_length - 1;

  // Gradient rows of this sequence (context_length slots of input_dim each).
  const real* grad_rows = out_grad + first_row * input_dim * context_length;
  real* in_rows = in_grad + first_row * input_dim;

  for (int col = threadIdx.x; col < input_dim; col += stride) {
    for (int i = 0; i < num_instances; ++i) {
      const int pos = i + context_start;
      // Padding positions have no input gradient.
      if (pos < 0 || pos >= seq_len) {
        continue;
      }

      real acc = in_rows[pos * input_dim + col];

      // Start at the highest slot referring to pos, then walk down-left.
      const bool in_head = i < context_length;
      const int slot = in_head ? i : (context_length - 1);
      const int row = in_head ? 0 : (i - (context_length - 1));
      const real* g =
          grad_rows + row * input_dim * context_length + slot * input_dim;
      for (int r = row; r < seq_len; ++r) {
        acc += g[col];
        if (r - row == slot) {
          break;
        }
        g += (context_length - 1) * input_dim;
      }

      in_rows[pos * input_dim + col] = acc;
    }
  }
}

X
xutianbing 已提交
199 200 201 202 203 204 205 206 207 208 209 210
/**
 * @brief   Context projection backward data.
 *
 * @param[in]   out_grad         output gradient.
 * @param[in]   sequence         sequence index (num_sequences + 1 row offsets).
 * @param[out]  input_grad       input gradient; gradients are accumulated
 *                               into its existing values.
 * @param[in]   num_sequences    number of sequences.
 * @param[in]   input_dim        input sequence dimension.
 * @param[in]   context_length   context length.
 * @param[in]   context_start    context start.
 *
 * Launches one 128-thread block per sequence on STREAM_DEFAULT. When
 * num_sequences is 0 nothing is launched: a grid with zero blocks is an
 * invalid launch configuration.
 */
void hl_context_projection_backward_data(const real* out_grad,
                                         const int* sequence,
                                         real* input_grad,
                                         size_t num_sequences,
                                         size_t input_dim,
                                         size_t context_length,
                                         int context_start) {
  CHECK_NOTNULL(out_grad);
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(input_grad);

  // Empty batch: no gradient to propagate, and a 0-block grid would fail.
  if (num_sequences == 0) {
    return;
  }

  int block_size = 128;
  int blocks_x = num_sequences;
  int blocks_y = 1;
  dim3 threads(block_size, 1);
  dim3 grid(blocks_x, blocks_y);

  KeContextProjectionBackwardData<<<grid, threads, 0, STREAM_DEFAULT>>>(
      out_grad, sequence, input_grad, input_dim, context_length, context_start);
  CHECK_SYNC("hl_context_projection_backward_data failed");
}

/**
 * GPU specialization of ContextProjectionBackwardData.
 *
 * Unpacks the matrices into raw device pointers and dimensions and forwards
 * them to hl_context_projection_backward_data. The input dimension is taken
 * from the width of in_grad.
 */
template <>
void ContextProjectionBackwardData<DEVICE_TYPE_GPU>(const GpuMatrix& out_grad,
                                                    GpuMatrix& in_grad,
                                                    const GpuIVector& sequence,
                                                    size_t context_length,
                                                    int context_start) {
  // One trailing end offset => size - 1 sequences.
  const size_t num_sequences = sequence.getSize() - 1;
  const size_t input_dim = in_grad.getWidth();

  hl_context_projection_backward_data(out_grad.getData(),
                                      sequence.getData(),
                                      in_grad.getData(),
                                      num_sequences,
                                      input_dim,
                                      context_length,
                                      context_start);
}

L
liaogang 已提交
247
/**
 * Backward kernel of context projection with respect to the padding weight.
 *
 * Expected launch (see hl_context_projection_backward_weight): a 1-D grid of
 * total_pad * ceil(w_dim / THREADS_X) blocks, each of THREADS_X x THREADS_Y
 * threads. Block b handles padding row padId = b / ceil(w_dim / THREADS_X) and
 * the THREADS_X-wide column tile b % ceil(w_dim / THREADS_X).
 *
 * Threads sharing threadIdx.x but differing in threadIdx.y visit different
 * sequences (seqId = threadIdx.y, threadIdx.y + THREADS_Y, ...). Each thread
 * sums the out_grad slots that referred to the padding position, and the
 * per-thread partial sums are tree-reduced over y in shared memory. Thread
 * y == 0 then adds the result to w_grad. Every (padId, column) pair belongs
 * to exactly one block, so the final `+=` needs no atomics.
 *
 * THREADS_Y must be a power of two (the halving reduction relies on it), and
 * blockDim must equal (THREADS_X, THREADS_Y).
 *
 * Row offsets are computed in size_t: seq_start * w_dim * context_length can
 * exceed INT_MAX for large batches.
 */
template <int THREADS_X, int THREADS_Y>
__global__ void KeContextProjectionBackwardWeight(const real* out_grad,
                                                  const int* sequence,
                                                  real* w_grad,
                                                  int num_sequences,
                                                  int w_dim,
                                                  int context_length,
                                                  int context_start,
                                                  int begin_pad) {
  static_assert(THREADS_Y > 0 && (THREADS_Y & (THREADS_Y - 1)) == 0,
                "THREADS_Y must be a power of two");

  __shared__ real sum_s[THREADS_Y][THREADS_X];
  int pad_of_block = (w_dim + THREADS_X - 1) / THREADS_X;
  const int idx = threadIdx.x;
  const int idy = threadIdx.y;
  int padId = blockIdx.x / pad_of_block;
  int weight_idx = idx + THREADS_X * (blockIdx.x % pad_of_block);
  real value = 0;
  // Number of elements in one output row: context_length slots of w_dim.
  const size_t row_stride = static_cast<size_t>(w_dim) * context_length;

  sum_s[idy][idx] = 0.0f;
  if (weight_idx < w_dim) {
    for (int seqId = idy; seqId < num_sequences; seqId += THREADS_Y) {
      int seq_start = sequence[seqId];
      int seq_end = sequence[seqId + 1];
      int seq_len = seq_end - seq_start;
      int instanceId;

      // Map padding row padId to its instance index i (source position
      // i + context_start) within this sequence.
      if (context_start < 0) {
        if (padId + context_start < 0) {
          // Begin padding row.
          instanceId = padId;
        } else {
          // End padding row (begin_pad > 0).
          instanceId = (padId - begin_pad) + seq_len - context_start;
        }
      } else {
        if (padId + seq_len < context_start) {
          // Position never referenced by this sequence.
          continue;
        } else {
          // End padding row (begin_pad == 0).
          instanceId = padId + seq_len - context_start;
        }
      }

      // Highest slot referring to the position, and the row owning it.
      int outx =
          (instanceId - context_length) < 0 ? instanceId : (context_length - 1);
      int outy = (instanceId - context_length) < 0
                     ? 0
                     : (instanceId - (context_length - 1));
      const real* output_r = out_grad +
                             static_cast<size_t>(seq_start) * row_stride +
                             static_cast<size_t>(outy) * row_stride +
                             static_cast<size_t>(outx) * w_dim;
      // Walk down-left until slot 0 is reached or the sequence ends.
      for (int j = outy; j < seq_len; j++) {
        value += output_r[weight_idx];
        if (j - outy == outx) break;
        output_r += (context_length - 1) * w_dim;
      }
    }
    sum_s[idy][idx] = value;
  }
  __syncthreads();

  // Tree reduction over the y dimension; each step ends with a barrier, so
  // sum_s[0][idx] is complete after the loop.
  for (int stride = THREADS_Y / 2; stride > 0; stride = stride / 2) {
    if (idy < stride) {
      sum_s[idy][idx] += sum_s[idy + stride][idx];
    }
    __syncthreads();
  }

  if (idy == 0 && weight_idx < w_dim) {
    w_grad[static_cast<size_t>(padId) * w_dim + weight_idx] += sum_s[0][idx];
  }
}

X
xutianbing 已提交
322 323 324 325 326 327 328 329 330 331 332 333 334 335 336
/**
 * @brief   Context projection backward weight.
 *
 * @param[in]   out_grad         output gradient.
 * @param[in]   sequence         sequence index (num_sequences + 1 row offsets).
 * @param[out]  w_grad           weight gradient; gradients are accumulated
 *                               into its existing values.
 * @param[in]   num_sequences    number of sequences.
 * @param[in]   w_dim            input sequence dimension.
 * @param[in]   total_pad        number of extra timesteps.
 * @param[in]   context_length   context length.
 * @param[in]   context_start    context start.
 * @param[in]   begin_pad        number of extra timesteps added at the
 *                               beginning.
 *
 * Launches total_pad * ceil(w_dim / 32) blocks of 32 x 32 threads on
 * STREAM_DEFAULT. When total_pad or w_dim is 0 there is nothing to accumulate,
 * and nothing is launched (a grid with zero blocks is an invalid launch
 * configuration).
 */
void hl_context_projection_backward_weight(const real* out_grad,
                                           const int* sequence,
                                           real* w_grad,
                                           size_t num_sequences,
                                           size_t w_dim,
                                           size_t total_pad,
                                           size_t context_length,
                                           int context_start,
                                           size_t begin_pad) {
  CHECK_NOTNULL(out_grad);
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(w_grad);

  // Block shape. Also used as the kernel's template arguments, so the shared
  // memory tile and the launch configuration cannot disagree.
  constexpr int kThreadsX = 32;
  constexpr int kThreadsY = 32;
  int blocks_x = total_pad * ((w_dim + kThreadsX - 1) / kThreadsX);
  // No padding rows or zero-width weight: a 0-block grid would fail to launch.
  if (blocks_x == 0) {
    return;
  }
  dim3 threads(kThreadsX, kThreadsY);
  dim3 grid(blocks_x, 1);

  KeContextProjectionBackwardWeight<kThreadsX, kThreadsY>
      <<<grid, threads, 0, STREAM_DEFAULT>>>(out_grad,
                                             sequence,
                                             w_grad,
                                             num_sequences,
                                             w_dim,
                                             context_length,
                                             context_start,
                                             begin_pad);
  CHECK_SYNC("hl_context_projection_backward_weight failed");
}

/**
 * GPU specialization of ContextProjectionBackwardWeight.
 *
 * Unpacks the matrices into raw device pointers and dimensions and forwards
 * them to hl_context_projection_backward_weight. The weight dimension is taken
 * from the width of w_grad. Note that the argument order differs from this
 * function's own signature: the hl_* routine takes total_pad before
 * context_length.
 */
template <>
void ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(const GpuMatrix& out_grad,
                                                      GpuMatrix& w_grad,
                                                      const GpuIVector& seq_vec,
                                                      size_t context_length,
                                                      int context_start,
                                                      size_t total_pad,
                                                      size_t begin_pad) {
  // One trailing end offset => size - 1 sequences.
  const size_t num_sequences = seq_vec.getSize() - 1;
  const size_t w_dim = w_grad.getWidth();

  hl_context_projection_backward_weight(out_grad.getData(),
                                        seq_vec.getData(),
                                        w_grad.getData(),
                                        num_sequences,
                                        w_dim,
                                        total_pad,
                                        context_length,
                                        context_start,
                                        begin_pad);
}

388
/**
 * GPU specialization of ContextProjectionBackward.
 *
 * Runs the input-gradient pass when in_grad is present. It then runs the
 * padding-weight-gradient pass when padding is enabled and w_grad is present.
 * Both passes read out_grad and write disjoint outputs.
 */
template <>
void ContextProjectionBackward<DEVICE_TYPE_GPU>(const GpuMatrix& out_grad,
                                                GpuMatrix& in_grad,
                                                GpuMatrix& w_grad,
                                                const GpuIVector& sequence,
                                                size_t context_length,
                                                int context_start,
                                                size_t begin_pad,
                                                bool is_padding,
                                                size_t total_pad) {
  // Gradient with respect to the input sequence.
  if (in_grad) {
    ContextProjectionBackwardData<DEVICE_TYPE_GPU>(
        out_grad, in_grad, sequence, context_length, context_start);
  }

  // Gradient with respect to the trainable padding rows.
  const bool need_weight_grad = is_padding && w_grad;
  if (need_weight_grad) {
    ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(out_grad,
                                                     w_grad,
                                                     sequence,
                                                     context_length,
                                                     context_start,
                                                     total_pad,
                                                     begin_pad);
  }
}

413
}  // namespace paddle