row_conv_op.cu 14.4 KB
Newer Older
L
Luo Tao 已提交
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
S
Siddharth Goyal 已提交
2

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
S
Siddharth Goyal 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
S
Siddharth Goyal 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
S
Siddharth Goyal 已提交
14

Y
Yi Wang 已提交
15 16 17
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/row_conv_op.h"
#include "paddle/fluid/platform/cuda_helper.h"
S
Siddharth Goyal 已提交
18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using framework::Tensor;

namespace {

// Integer division of x by y rounded up; used below to size the grid so that
// ceil(input_dim / blockDim.x) blocks cover every input column.
inline int DivUp(int x, int y) {
  const int padded = x + (y - 1);
  return padded / y;
}

// Forward prop (shared memory version, for small future_context).
//
// For every sequence delimited by batch_indices (row offsets into `in`):
//   out[t, d] = sum_{w < future_context, t + w < T_seq} wt[w, d] * in[t + w, d]
// `wt` is indexed as a row-major future_context x input_dim matrix.
//
// Launch layout: x threads map to input columns d (grid.x * blockDim.x must
// cover input_dim); y threads stride over the time steps of a sequence.
// Dynamic shared memory: future_context * blockDim.x * sizeof(T) bytes, which
// caches this block's columns of the filter.
template <typename T>
__global__ void RowConvForwardSharedMemory(const T *in, const T *wt,
                                           int num_sequence, int input_dim,
                                           int future_context,
                                           const size_t *batch_indices,
                                           T *out) {
  int blx = blockDim.x;
  int bly = blockDim.y;
  int thx = threadIdx.x;
  int thy = threadIdx.y;
  int d = blockIdx.x * blx + thx;  // index along input dim

  extern __shared__ T mem[];
  T *sw = mem;

  // Stage the filter rows for this block's columns. Striding by blockDim.y
  // fills every row even when future_context > blockDim.y; columns past
  // input_dim are zero-filled so no shared slot is left uninitialized.
  for (int w = thy; w < future_context; w += bly) {
    sw[w * blx + thx] =
        (d < input_dim) ? wt[w * input_dim + d] : static_cast<T>(0);
  }
  __syncthreads();

  // Threads past input_dim only helped fill shared memory. No barrier follows,
  // so leaving early is safe.
  if (d >= input_dim) return;

  for (int i = 0; i < num_sequence; i++) {
    int start = static_cast<int>(batch_indices[i]);
    int end = static_cast<int>(batch_indices[i + 1]);
    int current_timesteps = end - start;
    for (int k = thy; k < current_timesteps; k += bly) {
      T sum = 0;
      // Look ahead at most future_context steps, never past the sequence end.
      for (int w = 0; (w < future_context) && ((k + w) < current_timesteps);
           w++) {
        sum += sw[w * blx + thx] * in[(start + k + w) * input_dim + d];
      }
      out[(start + k) * input_dim + d] = sum;
    }
  }
}

// Forward prop (naive version): reads the filter directly from global memory.
//
// Produces the same result as the shared-memory variant:
//   out[t, d] = sum_{w < future_context, t + w < T_seq} wt[w, d] * in[t + w, d]
// for every sequence delimited by batch_indices. x threads map to input
// columns, y threads stride over time steps. No dynamic shared memory.
template <typename T>
__global__ void RowConvForward(const T *in, const T *wt, int num_sequence,
                               int input_dim, int future_context,
                               const size_t *batch_indices, T *out) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= input_dim) {
    return;
  }
  const int row_step = blockDim.y;

  for (size_t seq = 0; seq < num_sequence; ++seq) {
    const int seq_begin = static_cast<int>(batch_indices[seq]);
    const int seq_len = static_cast<int>(batch_indices[seq + 1]) - seq_begin;
    for (int t = threadIdx.y; t < seq_len; t += row_step) {
      // Taps that would run past the end of the sequence are skipped.
      const int remaining = seq_len - t;
      const int taps = future_context < remaining ? future_context : remaining;
      T acc = 0;
      for (int w = 0; w < taps; ++w) {
        acc += (wt[w * input_dim + col] *
                in[(seq_begin + t + w) * input_dim + col]);
      }
      out[(seq_begin + t) * input_dim + col] = acc;
    }
  }
}

// Compute input gradient (shared memory version, for small future_context).
//
// Within each sequence delimited by batch_indices:
//   din[t, d] = sum_{w < future_context, w <= t} wt[w, d] * dout[t - w, d]
// `wt` is indexed as a row-major future_context x input_dim matrix.
//
// Launch layout: x threads map to input columns d, y threads stride over time
// steps. Dynamic shared memory: future_context * blockDim.x * sizeof(T) bytes,
// holding this block's columns of the filter.
template <typename T>
__global__ void RowConvGradInputSharedMemory(const T *dout, const T *wt,
                                             int num_sequence, int input_dim,
                                             int future_context,
                                             const size_t *batch_indices,
                                             T *din) {
  int blx = blockDim.x;
  int bly = blockDim.y;
  int thx = threadIdx.x;
  int thy = threadIdx.y;
  int d = blockIdx.x * blx + thx;  // index along input dim

  extern __shared__ T mem[];
  T *sw = mem;
  // Stage the filter slice; the blockDim.y stride covers every row even when
  // future_context > blockDim.y. Out-of-range columns are zero-filled.
  for (int w = thy; w < future_context; w += bly) {
    sw[w * blx + thx] =
        (d < input_dim) ? wt[w * input_dim + d] : static_cast<T>(0);
  }
  __syncthreads();

  // Threads past input_dim only helped fill shared memory. No barrier follows,
  // so leaving early is safe.
  if (d >= input_dim) return;

  for (int i = 0; i < num_sequence; i++) {
    int start = static_cast<int>(batch_indices[i]);
    int end = static_cast<int>(batch_indices[i + 1]);
    int current_timesteps = end - start;
    for (int k = thy; k < current_timesteps; k += bly) {
      T sum = 0;
      // Only steps at or before k inside this sequence contribute.
      for (int w = 0; (w < future_context) && ((k - w) >= 0); w++) {
        sum += sw[w * blx + thx] * dout[(k + start - w) * input_dim + d];
      }
      din[(k + start) * input_dim + d] = sum;
    }
  }
}

// Compute input gradient (naive version): filter read from global memory.
//
// Within each sequence delimited by batch_indices:
//   din[t, d] = sum_{w < future_context, w <= t} wt[w, d] * dout[t - w, d]
// x threads map to input columns, y threads stride over time steps. No dynamic
// shared memory.
template <typename T>
__global__ void RowConvGradInput(const T *dout, const T *wt, int num_sequence,
                                 int input_dim, int future_context,
                                 const size_t *batch_indices, T *din) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= input_dim) {
    return;
  }
  const int row_step = blockDim.y;

  for (int seq = 0; seq < num_sequence; ++seq) {
    const int seq_begin = static_cast<int>(batch_indices[seq]);
    const int seq_len = static_cast<int>(batch_indices[seq + 1]) - seq_begin;
    for (int t = threadIdx.y; t < seq_len; t += row_step) {
      // Steps 0..t of the sequence are the only valid look-back positions.
      const int taps = future_context < t + 1 ? future_context : t + 1;
      T acc = 0;
      for (int w = 0; w < taps; ++w) {
        acc += (wt[w * input_dim + col] *
                dout[(t + seq_begin - w) * input_dim + col]);
      }
      din[(t + seq_begin) * input_dim + col] = acc;
    }
  }
}

// Compute W gradient (small future_context version).
//
// For every sequence delimited by batch_indices, adds into dfilter (+=):
//   dfilter[w, d] += sum_t in[t, d] * dout[t - w, d]   (0 <= t - w, t < T_seq)
// Each block owns blockDim.x consecutive input columns. Time is processed in
// chunks of block_x steps that are transposed through shared memory, so that
// each warp can reduce along time with warp shuffles.
//
// Written for blockDim.x == blockDim.y == block_x == block_y == 32 (the
// reduction starts at offset 16 and the tiles are indexed as square) and
// future_context <= 32 (the look-back halo is filled by threads with
// thy < future_context - 1).
// Dynamic shared memory, in elements of T:
//     block_y * block_x                          sh_in
//   + block_y * (block_x + future_context - 1)   sh_dout (incl. look-back halo)
//   + future_context * block_y                   sh_dfilter
template <typename T>
__global__ void RowConvGradFilterImproved(const T *in, const T *dout,
                                          int num_sequence, int input_dim,
                                          int future_context, int block_x,
                                          int block_y,
                                          const size_t *batch_indices,
                                          T *dfilter) {
  int blx = blockDim.x;
  int bly = blockDim.y;
  int thx = threadIdx.x;
  int thy = threadIdx.y;
  int gx = blockIdx.x * blx;
  int d = gx + thx;  // index along input dim

  extern __shared__ T mem[];

  int xdim_sh_in = block_y;
  int xdim_sh_dout = block_y;
  int ydim_sh_in = block_x;
  // future_context - 1 extra slots per row hold the dout steps preceding the
  // current chunk.
  int ydim_sh_dout = block_x + future_context - 1;
  int ydim_sh_dfilter = block_y;

  T *sh_in = mem;
  T *sh_dout = &mem[xdim_sh_in * ydim_sh_in];
  T *sh_dfilter = &mem[xdim_sh_in * ydim_sh_in + xdim_sh_dout * ydim_sh_dout];

  // Zero the per-block accumulator. Striding by blockDim.y clears every row
  // regardless of how future_context compares with blockDim.y.
  for (int w = thy; w < future_context; w += bly) {
    sh_dfilter[w * ydim_sh_dfilter + thx] = static_cast<T>(0);
  }
  __syncthreads();

  for (int i = 0; i < num_sequence; i++) {
    int start = static_cast<int>(batch_indices[i]);
    int end = static_cast<int>(batch_indices[i + 1]);
    int current_timesteps = end - start;
    // Round up to whole chunks so every thread runs the same number of
    // iterations, which the __syncthreads() calls below require.
    int scaled_cur_steps =
        ((current_timesteps + block_x - 1) / block_x) * block_x;

    for (int k = thy; k < scaled_cur_steps; k += block_x) {
      int pos = start + k;
      // Row thx = input column, column thy = time step within the chunk;
      // zero-padded past the sequence end or input_dim.
      sh_in[thx * ydim_sh_in + thy] =
          (d < input_dim && pos < end) ? in[pos * input_dim + d] : T(0);
      sh_dout[thx * ydim_sh_dout + thy + future_context - 1] =
          (d < input_dim && pos < end) ? dout[pos * input_dim + d] : T(0);
      __syncthreads();

      // Fill the look-back halo (the future_context - 1 steps before the
      // chunk), zero before the sequence start.
      if (thy < future_context - 1) {
        int pos_offset = pos - future_context + 1;
        sh_dout[thx * ydim_sh_dout + thy] =
            (d < input_dim && pos_offset >= start)
                ? dout[pos_offset * input_dim + d]
                : T(0);
      }
      __syncthreads();

      for (int w = 0; w < future_context; w++) {
        // Transposed read: lane thx holds step (chunk start + thx) of column
        // gx + thy, multiplied by dout of the step w positions earlier.
        T val = sh_in[thy * ydim_sh_in + thx] *
                sh_dout[thy * ydim_sh_dout + thx + future_context - 1 - w];
        __syncthreads();

        // Warp-wide sum over the chunk's time steps; lane 0 ends up with the
        // total. Every lane reaches this point (no early exits, uniform trip
        // counts), so the full mask is correct. Mask-less __shfl_down is
        // unavailable for sm_70+ from CUDA 9 on.
        for (int offset = 16; offset > 0; offset /= 2) {  // blockDim.x is 32.
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9
          val += __shfl_down_sync(0xffffffff, val, offset);
#else
          val += __shfl_down(val, offset);
#endif
        }
        __syncthreads();

        if (thx == 0) {
          sh_dfilter[w * ydim_sh_dfilter + thy] += val;
        }
        __syncthreads();
      }
    }
  }
  // Flush this block's columns of the accumulated gradient to global memory.
  for (int w = thy; (w < future_context) && (d < input_dim); w += bly) {
    dfilter[w * input_dim + d] += sh_dfilter[w * ydim_sh_dfilter + thx];
  }
}

// Compute weight(filter) gradient (general version, any future_context).
//
// For every sequence delimited by batch_indices, adds into dfilter (+=):
//   dfilter[w, d] += sum_t in[t, d] * dout[t - w, d]   (0 <= t - w, t < T_seq)
// Time is processed in chunks of block_x steps. Each chunk of `in`, and a copy
// of `dout` shifted by w, is transposed through shared memory so that each
// warp can reduce along time with warp shuffles.
//
// Written for blockDim.x == blockDim.y == block_x == block_y == 32 (the
// reduction starts at offset 16 and the tiles are indexed as square).
// Dynamic shared memory: 2 * block_x * block_y * sizeof(T) bytes.
template <typename T>
__global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
                                  int input_dim, int future_context,
                                  int block_x, int block_y,
                                  const size_t *batch_indices, T *dfilter) {
  int blx = blockDim.x;
  int thx = threadIdx.x;
  int thy = threadIdx.y;
  int gx = blockIdx.x * blx;
  int d = gx + thx;  // index along input dim
  extern __shared__ T mem[];
  T *sh_in = mem;                        // block_x * block_y tile of `in`
  T *sh_dout = &mem[block_x * block_y];  // block_x * block_y tile of `dout`

  for (int i = 0; i < num_sequence; i++) {
    int start = static_cast<int>(batch_indices[i]);
    int end = static_cast<int>(batch_indices[i + 1]);
    int current_timesteps = end - start;
    // Round up to whole chunks so every thread runs the same number of
    // iterations, which the __syncthreads() calls below require.
    int scaled_cur_steps =
        ((current_timesteps + block_x - 1) / block_x) * block_x;

    for (int k = thy; k < scaled_cur_steps; k += block_x) {
      int pos = start + k;
      // Typed zero avoids promoting the ternary to double (and keeps the
      // kernel instantiable for non-float T).
      sh_in[thx * block_y + thy] = (d < input_dim && pos < end)
                                       ? in[pos * input_dim + d]
                                       : static_cast<T>(0);
      __syncthreads();

      for (int w = 0; w < future_context; w++) {
        // dout shifted back by w steps; zero outside the current sequence.
        sh_dout[thx * block_y + thy] =
            (d < input_dim && (k - w) >= 0 && (k - w) < current_timesteps)
                ? dout[(pos - w) * input_dim + d]
                : static_cast<T>(0);
        __syncthreads();

        // Transposed read: lane thx holds step (chunk start + thx) of column
        // gx + thy.
        T val = sh_in[thy * block_y + thx] * sh_dout[thy * block_y + thx];
        __syncthreads();

        // Warp-wide sum over the chunk's time steps; lane 0 ends up with the
        // total. All 32 lanes participate, so the full mask is correct.
        // Mask-less __shfl_down is unavailable for sm_70+ from CUDA 9 on.
        for (int offset = 16; offset > 0; offset /= 2) {  // blockDim.x is 32.
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 9
          val += __shfl_down_sync(0xffffffff, val, offset);
#else
          val += __shfl_down(val, offset);
#endif
        }
        __syncthreads();

        if (thx == 0 && (gx + thy) < input_dim) {
          dfilter[w * input_dim + gx + thy] += val;
        }
      }
    }
  }
}

}  // namespace

template <typename T>
class RowConvKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Runs the row-convolution forward pass on the context's CUDA stream.
  // Both variants use 32x32 blocks tiled over input_dim. The shared-memory
  // variant is chosen when the filter has at most 32 rows (future_context);
  // otherwise the global-memory variant is launched.
  void Compute(const framework::ExecutionContext &context) const override {
    auto *X = context.Input<LoDTensor>("X");
    auto *Filter = context.Input<Tensor>("Filter");
    auto *Out = context.Output<LoDTensor>("Out");

    const T *in = X->data<T>();
    const T *weight = Filter->data<T>();
    T *out = Out->mutable_data<T>(context.GetPlace());

    // Local copy of the first LoD level: sequence start/end offsets into X.
    auto batch_indices = X->lod()[0];
    int input_dim = X->dims()[1];
    int num_sequence = batch_indices.size() - 1;
    int future_context = Filter->dims()[0];
    size_t *idx = batch_indices.CUDAMutableData(context.GetPlace());
    auto stream = context.cuda_device_context().stream();

    const dim3 block_dim(32, 32);
    const dim3 grid_dim(DivUp(input_dim, block_dim.x), 1);

    if (future_context > 32) {
      RowConvForward<T><<<grid_dim, block_dim, 0, stream>>>(
          in, weight, num_sequence, input_dim, future_context, idx, out);
      return;
    }

    // One filter row per step of future context, blockDim.x columns each.
    const int mem_per_block = future_context * block_dim.x * sizeof(T);
    RowConvForwardSharedMemory<T><<<grid_dim, block_dim, mem_per_block,
                                    stream>>>(
        in, weight, num_sequence, input_dim, future_context, idx, out);
  }
};

template <typename T>
class RowConvGradKernel<platform::CUDADeviceContext, T>
    : public framework::OpKernel<T> {
 public:
  // Runs the row-convolution backward pass on the context's CUDA stream.
  // Computes whichever of the Filter and X gradients were requested.
  // Variant choice follows the forward kernel: shared-memory kernels when
  // future_context <= 32, otherwise the general ones. All launches use 32x32
  // blocks tiled over input_dim.
  void Compute(const framework::ExecutionContext &context) const override {
    auto *X = context.Input<LoDTensor>("X");
    auto *Filter = context.Input<Tensor>("Filter");
    auto *dOut = context.Input<LoDTensor>(framework::GradVarName("Out"));
    const T *in = X->data<T>();
    const T *weights = Filter->data<T>();
    const T *dout = dOut->data<T>();

    Tensor *dX = context.Output<LoDTensor>(framework::GradVarName("X"));
    Tensor *dFilter = context.Output<Tensor>(framework::GradVarName("Filter"));

    // Local copy of the first LoD level: sequence start/end offsets into X.
    auto batch_indices = X->lod()[0];
    int input_dim = X->dims()[1];
    int num_sequence = batch_indices.size() - 1;
    int future_context = Filter->dims()[0];
    size_t *idx = batch_indices.CUDAMutableData(context.GetPlace());

    auto &device_ctx = context.cuda_device_context();
    auto stream = device_ctx.stream();
    math::SetConstant<platform::CUDADeviceContext, T> zero;

    const dim3 block_dim(32, 32);
    const dim3 grid_dim(DivUp(input_dim, block_dim.x), 1);
    const int block_x = block_dim.x;
    const int block_y = block_dim.y;
    const bool small_context = future_context <= 32;

    if (dFilter) {
      T *dfilter = dFilter->mutable_data<T>(context.GetPlace());
      // Both filter-gradient kernels accumulate with +=, so start from zero.
      zero(device_ctx, dFilter, static_cast<T>(0.0));

      if (small_context) {
        // sh_in + sh_dout (with future_context - 1 look-back slots per row)
        // + sh_dfilter.
        const int mem_per_block =
            (block_y * block_x + block_y * (block_x + future_context - 1) +
             future_context * block_y) *
            sizeof(T);
        RowConvGradFilterImproved<T><<<grid_dim, block_dim, mem_per_block,
                                       stream>>>(
            in, dout, num_sequence, input_dim, future_context, block_x, block_y,
            idx, dfilter);
      } else {
        // Two block_x x block_y tiles: one for `in`, one for shifted `dout`.
        const int mem_per_block = (block_x * block_y * 2) * sizeof(T);
        RowConvGradFilter<T><<<grid_dim, block_dim, mem_per_block, stream>>>(
            in, dout, num_sequence, input_dim, future_context, block_x, block_y,
            idx, dfilter);
      }
    }

    if (dX) {
      T *din = dX->mutable_data<T>(context.GetPlace());
      if (small_context) {
        const int mem_per_block = (future_context * block_dim.x) * sizeof(T);
        RowConvGradInputSharedMemory<T><<<grid_dim, block_dim, mem_per_block,
                                          stream>>>(
            dout, weights, num_sequence, input_dim, future_context, idx, din);
      } else {
        RowConvGradInput<T><<<grid_dim, block_dim, 0, stream>>>(
            dout, weights, num_sequence, input_dim, future_context, idx, din);
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
Q
QI JUN 已提交
406 407 408 409 410
// Register the CUDA implementations of row_conv and row_conv_grad.
// Only float kernels are instantiated here.
REGISTER_OP_CUDA_KERNEL(
    row_conv, ops::RowConvKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
    row_conv_grad,
    ops::RowConvGradKernel<paddle::platform::CUDADeviceContext, float>);