ContextProjectionOpGpu.cu 14.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_base.h"
16
#include "ContextProjectionOp.h"
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75

namespace paddle {

/**
 * Forward kernel of context projection.
 *
 * Launch layout: one block per sequence (blockIdx.x selects the sequence);
 * the threads of a block stride over the input_dim columns, so any
 * blockDim.x works. No shared memory is used.
 *
 * The output block of a sequence has seq_len rows of
 * context_length * input_dim values; slot s of row r occupies columns
 * [s * input_dim, (s + 1) * input_dim). For every instance
 * inst in [0, seq_len + context_length - 1), the source row is input row
 * (inst + context_start) when that position lies inside the sequence.
 * Otherwise it is a row of `weight` (padding == true), or the instance is
 * skipped (padding == false). The source row is added (+=) into every
 * (row r, slot s) with r + s == inst.
 */
template <bool padding>
__global__ void KeContextProjectionForward(const real* input,
                                           const int* sequence,
                                           const real* weight,
                                           real* output,
                                           int input_dim,
                                           int context_length,
                                           int context_start,
                                           int begin_pad) {
  const int seq_id = blockIdx.x;
  const int seq_begin = sequence[seq_id];
  const int seq_len = sequence[seq_id + 1] - seq_begin;
  const int num_instances = seq_len + context_length - 1;
  const int stride = blockDim.x;
  // Offset from (row r, slot s) to (row r + 1, slot s - 1).
  const int diag_step = (context_length - 1) * input_dim;

  const real* in_seq = input + seq_begin * input_dim;
  real* out_seq = output + seq_begin * input_dim * context_length;

  for (int col = threadIdx.x; col < input_dim; col += stride) {
    for (int inst = 0; inst < num_instances; ++inst) {
      const int pos = inst + context_start;
      real src = 0;
      if (pos >= 0 && pos < seq_len) {
        src = in_seq[pos * input_dim + col];
      } else if (!padding) {
        continue;
      } else if (pos < 0) {
        src = weight[inst * input_dim + col];
      } else {
        src = weight[(begin_pad + pos - seq_len) * input_dim + col];
      }

      // First (row, slot) pair on the diagonal row + slot == inst.
      const bool head = inst < context_length;
      const int slot = head ? inst : (context_length - 1);
      const int row = head ? 0 : (inst - (context_length - 1));

      real* dst =
          out_seq + row * input_dim * context_length + slot * input_dim;
      for (int r = row; r < seq_len; ++r) {
        dst[col] += src;
        if (r - row == slot) break;
        dst += diag_step;
      }
    }
  }
}

/**
 * @brief   Context projection forward.
 *
 * Launches KeContextProjectionForward on STREAM_DEFAULT with one block of
 * 128 threads per sequence. The kernel adds into `output` (+=), so the
 * caller is responsible for its initial contents.
 *
 * @param[in]   input            input sequence.
 * @param[in]   sequence         sequence start offsets; entries i and i + 1
 *                               delimit sequence i (num_sequences + 1 entries).
 * @param[in]   weight           padding data, or nullptr to disable padding
 *                               (positions outside a sequence are then
 *                               skipped).
 * @param[out]  output           output sequence.
 * @param[in]   num_sequences    number of sequences; 0 is a no-op.
 * @param[in]   input_dim        input sequence dimension.
 * @param[in]   context_length   context length.
 * @param[in]   context_start    context start.
 * @param[in]   begin_pad        number of extra timesteps added at the
 *                               beginning.
 */
void hl_context_projection_forward(const real* input,
                                   const int* sequence,
                                   const real* weight,
                                   real* output,
                                   size_t num_sequences,
                                   size_t input_dim,
                                   size_t context_length,
                                   int context_start,
                                   size_t begin_pad) {
  CHECK_NOTNULL(input);
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(output);

  // A zero-sized grid is an invalid launch configuration; an empty batch has
  // nothing to project, so return before launching.
  if (num_sequences == 0) {
    return;
  }

  const int block_size = 128;
  dim3 threads(block_size, 1);
  dim3 grid(static_cast<unsigned int>(num_sequences), 1);

  // Padding is a template parameter so the no-weight instantiation never
  // dereferences `weight`.
  if (weight) {
    KeContextProjectionForward<true><<< grid, threads, 0, STREAM_DEFAULT >>>
      (input, sequence, weight, output, input_dim,
       context_length, context_start, begin_pad);
  } else {
    KeContextProjectionForward<false><<< grid, threads, 0, STREAM_DEFAULT >>>
      (input, sequence, weight, output, input_dim,
       context_length, context_start, begin_pad);
  }
  CHECK_SYNC("hl_context_projection_forward failed");
}

/**
 * GPU specialization of ContextProjectionForward: adapts the matrix/vector
 * arguments to hl_context_projection_forward. Padding is used only when
 * `weight` tests true; otherwise a null weight pointer is passed.
 */
template <>
void ContextProjectionForward<DEVICE_TYPE_GPU>(GpuMatrix& output,
                                               const GpuMatrix& input,
                                               const GpuMatrix& weight,
                                               const GpuIVector& sequence,
                                               size_t context_length,
                                               int context_start,
                                               size_t begin_pad) {
  // `sequence` holds one more offset than there are sequences; guard an
  // empty vector so the count cannot wrap around to a huge size_t.
  const size_t num_sequences =
      sequence.getSize() > 0 ? sequence.getSize() - 1 : 0;
  hl_context_projection_forward(input.getData(),
                                sequence.getData(),
                                weight ? weight.getData() : nullptr,
                                output.getData(),
                                num_sequences,
                                input.getWidth(),
                                context_length,
                                context_start,
                                begin_pad);
}

/**
 * Backward-data kernel of context projection.
 *
 * Launch layout: one block per sequence (blockIdx.x selects the sequence);
 * the threads of a block stride over the input_dim columns. No shared memory
 * is used.
 *
 * Walks the instances inst in [0, seq_len + context_length - 1). For each
 * instance whose position pos = inst + context_start lies inside the
 * sequence, it adds into in_grad row pos the out_grad entries
 * (row r, slot s) with r + s == inst. Instances that fall on padding are
 * skipped.
 */
__global__ void KeContextProjectionBackwardData(real* out_grad,
                                                const int* sequence,
                                                real* in_grad,
                                                int input_dim,
                                                int context_length,
                                                int context_start) {
  const int seq_id = blockIdx.x;
  const int seq_begin = sequence[seq_id];
  const int seq_len = sequence[seq_id + 1] - seq_begin;
  const int num_instances = seq_len + context_length - 1;
  const int stride = blockDim.x;
  // Offset from (row r, slot s) to (row r + 1, slot s - 1).
  const int diag_step = (context_length - 1) * input_dim;

  const real* grad_seq = out_grad + seq_begin * input_dim * context_length;
  real* in_grad_seq = in_grad + seq_begin * input_dim;

  for (int col = threadIdx.x; col < input_dim; col += stride) {
    for (int inst = 0; inst < num_instances; ++inst) {
      const int pos = inst + context_start;
      if (pos < 0 || pos >= seq_len) continue;

      real* target = in_grad_seq + pos * input_dim + col;
      real acc = *target;

      // First (row, slot) pair on the diagonal row + slot == inst.
      const bool head = inst < context_length;
      const int slot = head ? inst : (context_length - 1);
      const int row = head ? 0 : (inst - (context_length - 1));

      const real* src =
          grad_seq + row * input_dim * context_length + slot * input_dim;
      for (int r = row; r < seq_len; ++r) {
        acc += src[col];
        if (r - row == slot) break;
        src += diag_step;
      }
      *target = acc;
    }
  }
}

/**
 * @brief   Context projection backward data.
 *
 * Launches KeContextProjectionBackwardData on STREAM_DEFAULT with one block
 * of 128 threads per sequence. The kernel adds into `input_grad`.
 *
 * @param[in]   out_grad         output gradient.
 * @param[in]   sequence         sequence start offsets (num_sequences + 1
 *                               entries).
 * @param[out]  input_grad       input gradient.
 * @param[in]   num_sequences    number of sequences; 0 is a no-op.
 * @param[in]   input_dim        input sequence dimension.
 * @param[in]   context_length   context length.
 * @param[in]   context_start    context start.
 */
void hl_context_projection_backward_data(real* out_grad,
                                         const int* sequence,
                                         real* input_grad,
                                         size_t num_sequences,
                                         size_t input_dim,
                                         size_t context_length,
                                         int context_start) {
  CHECK_NOTNULL(out_grad);
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(input_grad);

  // A zero-sized grid is an invalid launch configuration; an empty batch has
  // no gradient to propagate.
  if (num_sequences == 0) {
    return;
  }

  const int block_size = 128;
  dim3 threads(block_size, 1);
  dim3 grid(static_cast<unsigned int>(num_sequences), 1);
  KeContextProjectionBackwardData<<< grid, threads, 0, STREAM_DEFAULT >>>
    (out_grad, sequence, input_grad, input_dim, context_length, context_start);
  CHECK_SYNC("hl_context_projection_backward_data failed");
}

/**
 * GPU specialization of ContextProjectionBackwardData: adapts the
 * matrix/vector arguments to hl_context_projection_backward_data.
 */
template <>
void ContextProjectionBackwardData<DEVICE_TYPE_GPU>(GpuMatrix& out_grad,
                                                    GpuMatrix& in_grad,
                                                    const GpuIVector& sequence,
                                                    size_t context_length,
                                                    int context_start) {
  // `sequence` holds one more offset than there are sequences; guard an
  // empty vector so the count cannot wrap around to a huge size_t.
  const size_t num_sequences =
      sequence.getSize() > 0 ? sequence.getSize() - 1 : 0;
  hl_context_projection_backward_data(out_grad.getData(),
                                      sequence.getData(),
                                      in_grad.getData(),
                                      num_sequences,
                                      in_grad.getWidth(),
                                      context_length,
                                      context_start);
}

/**
 * Backward-weight kernel of context projection: accumulates the gradient of
 * the padding rows of the weight.
 *
 * Launch layout (as configured by the host wrapper): blockDim must equal
 * (THREADS_X, THREADS_Y), and gridDim.x = total_pad * ceil(w_dim / THREADS_X).
 * Block b handles padding row padId = b / ceil(w_dim / THREADS_X) and one
 * THREADS_X-wide column chunk of it. threadIdx.y strides over the sequences.
 * The per-thread partial sums are tree-reduced over y in the static shared
 * array sum_s, and thread y == 0 adds the total into w_grad (+=, not
 * overwrite).
 */
template<int THREADS_X, int THREADS_Y>
__global__ void KeContextProjectionBackwardWeight(real* out_grad,
                                                  const int* sequence,
                                                  real* w_grad,
                                                  int num_sequences,
                                                  int w_dim,
                                                  int context_length,
                                                  int context_start,
                                                  int begin_pad) {
  // The halving reduction below only folds every row of sum_s when THREADS_Y
  // is a power of two; any other value would silently drop partial sums.
  static_assert(THREADS_Y > 0 && (THREADS_Y & (THREADS_Y - 1)) == 0,
                "THREADS_Y must be a power of two");

  __shared__ real sum_s[THREADS_Y][THREADS_X];
  int pad_of_block = (w_dim + THREADS_X - 1) / THREADS_X;
  const int idx = threadIdx.x;
  const int idy = threadIdx.y;
  int padId = blockIdx.x / pad_of_block;
  int weight_idx = idx + THREADS_X * (blockIdx.x % pad_of_block);
  int instanceId;
  real value = 0;
  real* output_r;

  sum_s[idy][idx] = 0.0f;
  if (weight_idx < w_dim) {
    for (int seqId = idy; seqId < num_sequences; seqId += THREADS_Y) {
      int seq_start = sequence[seqId];
      int seq_end = sequence[seqId+1];
      output_r = out_grad + seq_start * w_dim * context_length;

      // Map padding row padId back to the instance index of this sequence
      // that reads it (inverse of the weight indexing in the forward kernel).
      if (context_start < 0) {
        if (padId + context_start < 0) {
          // Leading padding row.
          instanceId = padId;
        } else {
          // Trailing padding row; begin_pad > 0.
          instanceId = (padId - begin_pad) +
            (seq_end - seq_start) - context_start;
        }
      } else {
        if (padId + (seq_end - seq_start) < context_start) {
          // Would map to a negative instance: unused by this sequence.
          continue;
        } else {
          // begin_pad == 0: every padding row is trailing.
          instanceId = padId + (seq_end - seq_start) - context_start;
        }
      }

      // Sum every out_grad (row, slot) with row + slot == instanceId.
      int outx = (instanceId - context_length) < 0 ?
                 instanceId : (context_length - 1);
      int outy = (instanceId - context_length) < 0 ?
                 0 : (instanceId - (context_length - 1));
      output_r += outy * w_dim * context_length + outx * w_dim;
      for (int j = outy; j < seq_end - seq_start; j++) {
        value += output_r[weight_idx];
        if (j - outy == outx) break;
        output_r += (context_length - 1) * w_dim;
      }
    }
    sum_s[idy][idx] = value;
  }
  __syncthreads();

  // Tree reduction over threadIdx.y. The barrier closing each step (or the
  // one above when THREADS_Y == 1) makes sum_s[0] visible before it is read.
  for (int stride = THREADS_Y / 2; stride > 0; stride = stride / 2) {
    if (idy < stride) {
      sum_s[idy][idx] += sum_s[idy + stride][idx];
    }
    __syncthreads();
  }

  if (weight_idx < w_dim) {
    if (idy == 0) {
      w_grad[padId * w_dim + weight_idx] += sum_s[0][idx];
    }
  }
}

/**
 * @brief   Context projection backward weight.
 *
 * Launches KeContextProjectionBackwardWeight on STREAM_DEFAULT with 32x32
 * thread blocks, one block per (padding row, 32-column chunk) pair. The
 * kernel adds into `w_grad`.
 *
 * @param[in]   out_grad         output gradient.
 * @param[in]   sequence         sequence start offsets (num_sequences + 1
 *                               entries).
 * @param[out]  w_grad           weight gradient.
 * @param[in]   num_sequences    number of sequences.
 * @param[in]   w_dim            weight width (columns of w_grad).
 * @param[in]   total_pad        number of extra timesteps (rows of w_grad
 *                               that are updated); 0 is a no-op.
 * @param[in]   context_length   context length.
 * @param[in]   context_start    context start.
 * @param[in]   begin_pad        number of extra timesteps added at the
 *                               beginning.
 */
void hl_context_projection_backward_weight(real* out_grad,
                                           const int* sequence,
                                           real* w_grad,
                                           size_t num_sequences,
                                           size_t w_dim,
                                           size_t total_pad,
                                           size_t context_length,
                                           int context_start,
                                           size_t begin_pad) {
  CHECK_NOTNULL(out_grad);
  CHECK_NOTNULL(sequence);
  CHECK_NOTNULL(w_grad);

  // Single source of truth for the block shape: the kernel sizes its shared
  // reduction buffer from its template arguments and is only correct when
  // blockDim matches them exactly.
  constexpr int kThreadsX = 32;
  constexpr int kThreadsY = 32;

  const size_t chunks_per_pad = (w_dim + kThreadsX - 1) / kThreadsX;
  const size_t blocks_x = total_pad * chunks_per_pad;
  // No padding rows (or a zero-width weight) means there is nothing to
  // compute, and a zero-sized grid is an invalid launch configuration.
  if (blocks_x == 0) {
    return;
  }

  dim3 threads(kThreadsX, kThreadsY);
  dim3 grid(static_cast<unsigned int>(blocks_x), 1);

  KeContextProjectionBackwardWeight<kThreadsX, kThreadsY>
    <<< grid, threads, 0, STREAM_DEFAULT >>>
    (out_grad, sequence, w_grad, num_sequences, w_dim,
     context_length, context_start, begin_pad);
  CHECK_SYNC("hl_context_projection_backward_weight failed");
}

/**
 * GPU specialization of ContextProjectionBackwardWeight: adapts the
 * matrix/vector arguments to hl_context_projection_backward_weight.
 */
template <>
void ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(
        GpuMatrix& out_grad,
        GpuMatrix& w_grad,
        const GpuIVector& seq_vec,
        size_t context_length,
        int context_start,
        size_t total_pad,
        size_t begin_pad) {
  // `seq_vec` holds one more offset than there are sequences; guard an empty
  // vector so the count cannot wrap around to a huge size_t.
  const size_t num_sequences =
      seq_vec.getSize() > 0 ? seq_vec.getSize() - 1 : 0;
  hl_context_projection_backward_weight(out_grad.getData(),
                                        seq_vec.getData(),
                                        w_grad.getData(),
                                        num_sequences,
                                        w_grad.getWidth(),
                                        total_pad,
                                        context_length,
                                        context_start,
                                        begin_pad);
}

/**
 * GPU specialization of ContextProjectionBackward.
 *
 * Computes the input gradient when `in_grad` tests true. Computes the
 * padding-weight gradient when `is_padding` is set and `w_grad` tests true.
 */
template <>
void ContextProjectionBackward<DEVICE_TYPE_GPU>(GpuMatrix& out_grad,
                                                GpuMatrix& in_grad,
                                                GpuMatrix& w_grad,
                                                const GpuIVector& sequence,
                                                size_t context_length,
                                                int context_start,
                                                size_t begin_pad,
                                                bool is_padding,
                                                size_t total_pad) {
  if (in_grad) {
    ContextProjectionBackwardData<DEVICE_TYPE_GPU>(out_grad,
                                                   in_grad,
                                                   sequence,
                                                   context_length,
                                                   context_start);
  }
  if (is_padding && w_grad) {
    ContextProjectionBackwardWeight<DEVICE_TYPE_GPU>(out_grad,
                                                     w_grad,
                                                     sequence,
                                                     context_length,
                                                     context_start,
                                                     total_pad,
                                                     begin_pad);
  }
}

}  // namespace paddle