CosSimOpGpu.cu 7.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "hl_base.h"
16
#include "hl_device_functions.cuh"
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
#include "CosSimOp.h"

namespace paddle {

/**
 * Cosine-similarity forward kernel; each thread block handles one row.
 *
 * The block with index blockIdx.y pairs row blockIdx.y of input1 with the
 * same row of input2 when input2_height > 1, or with the first row of input2
 * otherwise. Thread 0 then stores
 *
 *   output[row] = scale * <x, y> / (sqrt(<x, x>) * sqrt(<y, y>))
 *
 * Launch requirements that follow from the code:
 *  - blockDim.x has to equal block_size: the shared arrays have block_size
 *    slots and the column loop strides by block_size;
 *  - the halving reduction folds every slot only when block_size is a power
 *    of two;
 *  - the row is taken from blockIdx.y.
 *
 * A zero norm is not guarded against; the division is performed as written.
 *
 * @param output         one result per row, indexed by blockIdx.y
 * @param input1         first operand, rows of `width` elements
 * @param input2         second operand, either one row or one row per block
 * @param width          number of elements per row
 * @param input1_height  not read by the kernel body
 * @param input2_height  values > 1 select a per-row second operand
 * @param scale          multiplier applied to the similarity
 */
template<int block_size>
__global__ void KeCosSim(real* output,
                         const real* input1,
                         const real* input2,
                         int width,
                         int input1_height,
                         int input2_height,
                         real scale) {
  const int row = blockIdx.y;
  const int tid = threadIdx.x;

  __shared__ real xx[block_size];
  __shared__ real yy[block_size];
  __shared__ real xy[block_size];

  // Widen before multiplying: row * width in plain int overflows as soon as
  // the matrix holds more than INT_MAX elements.
  const long long row_offset = static_cast<long long>(row) * width;
  const real* x_row = input1 + row_offset;
  const real* y_row = (input2_height > 1) ? input2 + row_offset : input2;

  // Each thread accumulates its strided share of the row in registers and
  // publishes the partial sums to shared memory once.
  real sum_xx = 0.0;
  real sum_yy = 0.0;
  real sum_xy = 0.0;
  for (int index = tid; index < width; index += block_size) {
    const real x = x_row[index];
    const real y = y_row[index];
    sum_xx += x * x;
    sum_yy += y * y;
    sum_xy += x * y;
  }
  xx[tid] = sum_xx;
  yy[tid] = sum_yy;
  xy[tid] = sum_xy;
  __syncthreads();

  // Pairwise halving reduction; the barrier stays outside the branch so that
  // every thread of the block reaches it.
  for (int s = block_size / 2; s > 0; s >>= 1) {
    if (tid < s) {
      xx[tid] += xx[tid + s];
      yy[tid] += yy[tid + s];
      xy[tid] += xy[tid + s];
    }
    __syncthreads();
  }

  if (tid == 0) {
    output[row] = scale * xy[0] / (sqrt(xx[0]) * sqrt(yy[0]));
  }
}

/**
 * Host-side launcher for KeCosSim.
 *
 * Launches one block of 256 threads per row of input1 (grid = (1,
 * input1_height)) on STREAM_DEFAULT, then invokes CHECK_SYNC with the label
 * "hlCossim failed".
 *
 * NOTE(review): the row count is used directly as grid.y, so it is bounded by
 * the device's maximum grid.y dimension — confirm expected batch sizes.
 *
 * @param output         device buffer receiving one value per row
 * @param input1         first operand (device pointer)
 * @param input2         second operand (device pointer)
 * @param width          elements per row
 * @param input1_height  number of rows in input1; also the grid height
 * @param input2_height  number of rows in input2
 * @param scale          multiplier forwarded to the kernel
 */
void hlCossim(real* output,
              const real* input1,
              const real* input2,
              size_t width,
              size_t input1_height,
              size_t input2_height,
              real scale) {
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(input1);
  CHECK_NOTNULL(input2);
  const int block_size = 256;
  dim3 threads(block_size, 1);
  dim3 grid(1, input1_height);

  // KeCosSim declares its dimensions as int; spell out the narrowing from
  // size_t instead of relying on an implicit conversion at the launch site.
  KeCosSim<block_size><<<grid, threads, 0, STREAM_DEFAULT>>>
    (output, input1, input2,
     static_cast<int>(width),
     static_cast<int>(input1_height),
     static_cast<int>(input2_height),
     scale);
  CHECK_SYNC("hlCossim failed");
}

/**
 * GPU specialization of CosSimForward.
 *
 * Verifies that all three matrices have data and are flagged as GPU matrices,
 * then forwards their raw pointers and dimensions to hlCossim. The row width
 * is taken from in1_mat; the two heights come from in1_mat and in2_mat.
 *
 * @param out_mat  destination matrix
 * @param in1_mat  first input matrix
 * @param in2_mat  second input matrix
 * @param scale    multiplier forwarded to hlCossim
 */
template <>
void CosSimForward<DEVICE_TYPE_GPU>(GpuMatrix& out_mat,
                                    const GpuMatrix& in1_mat,
                                    const GpuMatrix& in2_mat,
                                    real scale) {
  CHECK(out_mat.getData() && in1_mat.getData() && in2_mat.getData());
  // The output is checked as well, mirroring CosSimBackward, which validates
  // every matrix it touches.
  CHECK(out_mat.useGpu_ && in1_mat.useGpu_ && in2_mat.useGpu_)
      << "Matrix type are not GPU";

  const size_t dim = in1_mat.getWidth();
  real* out = out_mat.getData();
  const real* x = in1_mat.getData();
  const real* y = in2_mat.getData();
  hlCossim(out, x, y, dim, in1_mat.getHeight(), in2_mat.getHeight(), scale);
}

103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
/**
 * Cosine-similarity backward kernel; each thread block handles one row.
 *
 * The block with index blockIdx.y first reduces <x, x>, <y, y> and <x, y>
 * for its row and then accumulates into the input gradients:
 *
 *  - if <x, y> == 0:
 *      grad_x += scale * grad[row] * y / (sqrt(<x, x>) * sqrt(<y, y>))
 *      grad_y += scale * grad[row] * x / (sqrt(<x, x>) * sqrt(<y, y>))
 *  - otherwise:
 *      grad_x += output[row] * grad[row] * (y / <x, y> - x / <x, x>)
 *      grad_y += output[row] * grad[row] * (x / <x, y> - y / <y, y>)
 *
 * When input2_height > 1 each block owns its own row of prev_out_y and
 * prev_grad_y. Otherwise all blocks share the first row, and the update of
 * prev_grad_y goes through paddle::paddleAtomicAdd.
 *
 * Launch requirements that follow from the code:
 *  - blockDim.x has to equal block_size;
 *  - the halving reduction folds every slot only when block_size is a power
 *    of two;
 *  - the row is taken from blockIdx.y.
 *
 * NOTE(review): grad[row] and output[row] are read once per block, which
 * assumes they do not alias the gradient buffers being written — confirm.
 *
 * @param grad           upstream gradient, one value per row
 * @param output         forward result, one value per row
 * @param prev_out_x     first forward input
 * @param prev_out_y     second forward input
 * @param prev_grad_x    gradient accumulator for the first input
 * @param prev_grad_y    gradient accumulator for the second input
 * @param width          elements per row
 * @param input1_height  not read by the kernel body
 * @param input2_height  values > 1 select per-row second operand/gradient
 * @param scale          multiplier used in the <x, y> == 0 branch
 */
template<int block_size>
__global__ void KeCosSimDerivative(const real* grad,
                                   const real* output,
                                   const real* prev_out_x,
                                   const real* prev_out_y,
                                   real* prev_grad_x,
                                   real* prev_grad_y,
                                   size_t width,
                                   size_t input1_height,
                                   size_t input2_height,
                                   real scale) {
  const int row = blockIdx.y;
  const int tid = threadIdx.x;

  __shared__ real xx[block_size];
  __shared__ real yy[block_size];
  __shared__ real xy[block_size];

  const bool per_row_y = input2_height > 1;
  const size_t row_offset = row * width;
  prev_out_x += row_offset;
  prev_grad_x += row_offset;
  if (per_row_y) {
    prev_out_y += row_offset;
    prev_grad_y += row_offset;
  }

  // Accumulate this thread's strided share in registers. The index is size_t
  // so that it cannot overflow before reaching a size_t width.
  real sum_xx = 0.0;
  real sum_yy = 0.0;
  real sum_xy = 0.0;
  for (size_t index = tid; index < width; index += block_size) {
    const real x = prev_out_x[index];
    const real y = prev_out_y[index];
    sum_xx += x * x;
    sum_yy += y * y;
    sum_xy += x * y;
  }
  xx[tid] = sum_xx;
  yy[tid] = sum_yy;
  xy[tid] = sum_xy;
  __syncthreads();

  // Pairwise halving reduction; the barrier stays outside the branch so that
  // every thread of the block reaches it.
  for (int s = block_size / 2; s > 0; s >>= 1) {
    if (tid < s) {
      xx[tid] += xx[tid + s];
      yy[tid] += yy[tid + s];
      xy[tid] += xy[tid + s];
    }
    __syncthreads();
  }

  // Loop-invariant per-row value, loaded once instead of once per element.
  const real row_grad = grad[row];

  if (xy[0] == 0) {
    // With a zero dot product the general formula would divide by <x, y>;
    // this branch uses the form that only needs the two norms.
    const real reciprocal = 1.0 / (sqrt(xx[0]) * sqrt(yy[0]));
    const real coeff = scale * row_grad;
    for (size_t index = tid; index < width; index += block_size) {
      const real x = prev_out_x[index];
      const real y = prev_out_y[index];
      prev_grad_x[index] += coeff * y * reciprocal;
      if (per_row_y) {
        prev_grad_y[index] += coeff * x * reciprocal;
      } else {
        // Shared gradient row: several blocks may update the same element.
        paddle::paddleAtomicAdd(prev_grad_y + index, coeff * x * reciprocal);
      }
    }
  } else {
    const real reciprocalXY = 1.0 / xy[0];
    const real reciprocalSquareSumX = 1.0 / xx[0];
    const real reciprocalSquareSumY = 1.0 / yy[0];
    const real coeff = output[row] * row_grad;
    for (size_t index = tid; index < width; index += block_size) {
      const real x = prev_out_x[index];
      const real y = prev_out_y[index];
      prev_grad_x[index] += coeff *
        (y * reciprocalXY - x * reciprocalSquareSumX);
      if (per_row_y) {
        prev_grad_y[index] += coeff *
          (x * reciprocalXY - y * reciprocalSquareSumY);
      } else {
        // Shared gradient row: several blocks may update the same element.
        paddle::paddleAtomicAdd(prev_grad_y + index, coeff *
          (x * reciprocalXY - y * reciprocalSquareSumY));
      }
    }
  }
}

/**
 * Host-side launcher for KeCosSimDerivative.
 *
 * Launches one block of 256 threads per row of the first input (grid = (1,
 * input1_height)) on STREAM_DEFAULT, then invokes CHECK_SYNC with a label
 * naming this function.
 *
 * NOTE(review): the row count is used directly as grid.y, so it is bounded by
 * the device's maximum grid.y dimension — confirm expected batch sizes.
 *
 * @param grad           upstream gradient, one value per row
 * @param output         forward result, one value per row
 * @param prev_out_x     first forward input
 * @param prev_out_y     second forward input
 * @param prev_grad_x    gradient accumulator for the first input
 * @param prev_grad_y    gradient accumulator for the second input
 * @param width          elements per row
 * @param input1_height  number of rows in the first input; also the grid height
 * @param input2_height  number of rows in the second input
 * @param scale          multiplier forwarded to the kernel
 */
void hlCossimDerivative(const real* grad,
                        const real* output,
                        const real* prev_out_x,
                        const real* prev_out_y,
                        real* prev_grad_x,
                        real* prev_grad_y,
                        size_t width,
                        size_t input1_height,
                        size_t input2_height,
                        real scale) {
  CHECK_NOTNULL(grad);
  CHECK_NOTNULL(output);
  CHECK_NOTNULL(prev_out_x);
  CHECK_NOTNULL(prev_out_y);
  CHECK_NOTNULL(prev_grad_x);
  CHECK_NOTNULL(prev_grad_y);
  const int block_size = 256;
  dim3 threads(block_size, 1);
  dim3 grid(1, input1_height);
  KeCosSimDerivative<block_size><<<grid, threads, 0, STREAM_DEFAULT>>>
    (grad, output, prev_out_x, prev_out_y, prev_grad_x, prev_grad_y, width,
        input1_height, input2_height, scale);
  // The label now spells the function name correctly so that a failure can be
  // matched to this launcher when searching logs.
  CHECK_SYNC("hlCossimDerivative failed");
}

/**
 * GPU specialization of CosSimBackward.
 *
 * Verifies that all six matrices have data and are flagged as GPU matrices,
 * then forwards their raw pointers and dimensions to hlCossimDerivative. The
 * row width is taken from in1_val; the two heights come from in1_val and
 * in2_val.
 *
 * @param out_grad  gradient with respect to the forward output
 * @param out_val   forward output values
 * @param in1_val   first forward input
 * @param in2_val   second forward input
 * @param in1_grad  gradient buffer for the first input
 * @param in2_grad  gradient buffer for the second input
 * @param scale     multiplier forwarded to hlCossimDerivative
 */
template <>
void CosSimBackward<DEVICE_TYPE_GPU>(const GpuMatrix& out_grad,
                                     const GpuMatrix& out_val,
                                     const GpuMatrix& in1_val,
                                     const GpuMatrix& in2_val,
                                     GpuMatrix& in1_grad,
                                     GpuMatrix& in2_grad,
                                     real scale) {
  CHECK(out_grad.getData() && out_val.getData() && in1_val.getData() &&
        in2_val.getData() && in1_grad.getData() && in2_grad.getData());
  CHECK(out_grad.useGpu_ && out_val.useGpu_ && in1_val.useGpu_
        && in2_val.useGpu_ && in1_grad.useGpu_ && in2_grad.useGpu_)
        << "Matrix types are not equally GPU";

  const size_t dim = in1_val.getWidth();
  const real* grad = out_grad.getData();
  const real* out = out_val.getData();
  const real* prev_out_x = in1_val.getData();
  const real* prev_out_y = in2_val.getData();
  real* prev_grad_x = in1_grad.getData();
  real* prev_grad_y = in2_grad.getData();
  hlCossimDerivative(grad,
                     out,
                     prev_out_x,
                     prev_out_y,
                     prev_grad_x,
                     prev_grad_y,
                     dim,
                     in1_val.getHeight(),
                     in2_val.getHeight(),
                     scale);
}

241
}  // namespace paddle
反馈
建议
客服 返回
顶部