cos_sim_op.cu 2.5 KB
Newer Older
X
Xinghai Sun 已提交
1 2
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

L
Luo Tao 已提交
3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
X
Xinghai Sun 已提交
6

L
Luo Tao 已提交
7
    http://www.apache.org/licenses/LICENSE-2.0
X
Xinghai Sun 已提交
8

L
Luo Tao 已提交
9 10 11 12 13
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
X
Xinghai Sun 已提交
14 15 16

#define EIGEN_USE_GPU
#include "paddle/operators/cos_sim_op.h"
17
#include "paddle/platform/cuda_helper.h"
X
Xinghai Sun 已提交
18

C
refine  
chengduoZH 已提交
19 20 21 22
namespace paddle {
namespace operators {

/**
 * Accumulates into dy the cos_sim gradient contributions of all rows of x.
 *
 * Each thread processes whole rows. For a row r and every column i in
 * [0, cols) it atomically adds
 *
 *   dz[r] * (x[r * cols + i] / (x_norm[r] * y_norm[0])
 *            - z[r] * y[i] / (y_norm[0] * y_norm[0]))
 *
 * to dy[i]. Only y[0 .. cols) and y_norm[0] are read, i.e. one y vector is
 * applied to every row of x.
 *
 * Launch layout: one-dimensional thread blocks (blockDim.x threads); the
 * blocks may be arranged along gridDim.x, gridDim.y, or both. Block
 * coordinates are flattened before computing the row index, so each row is
 * visited exactly once regardless of which grid axis carries the blocks.
 * blockDim.y/z and gridDim.z are not taken into account.
 *
 * The kernel only adds into dy; it never clears it.
 * NOTE(review): dy is presumably zero-filled by the caller before launch --
 * confirm at the call site.
 *
 * @param x_norm  per-row norms of x, indexed by row
 * @param y_norm  norm of y; only element 0 is read
 * @param x       row-major rows x cols input
 * @param y       y vector, cols elements read
 * @param z       forward result, indexed by row
 * @param dz      upstream gradient, indexed by row
 * @param rows    number of rows in x
 * @param cols    number of columns in x (and elements of y / dy)
 * @param dy      output accumulator, cols elements
 */
template <typename T>
__global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x,
                               const T* y, const T* z, const T* dz,
                               const size_t rows, const size_t cols, T* dy) {
  // Flatten the (x, y) block coordinates. Using blockIdx.x alone makes every
  // block along gridDim.y repeat the same rows and over-accumulate dy.
  const size_t block_id =
      static_cast<size_t>(blockIdx.y) * gridDim.x + blockIdx.x;
  // Total number of threads in the grid == stride of the row loop.
  const size_t grid_size =
      static_cast<size_t>(gridDim.x) * gridDim.y * blockDim.x;

  // Row-independent terms: computed once per thread instead of once per row.
  const T y_norm_data = y_norm[0];
  const T reciprocal_y_norm_square = 1 / (y_norm_data * y_norm_data);

  // size_t index: matches the type of `rows` and cannot wrap negative.
  for (size_t row_id = block_id * blockDim.x + threadIdx.x; row_id < rows;
       row_id += grid_size) {
    const T reciprocal_xy_norm_prod = 1 / (x_norm[row_id] * y_norm_data);
    const T dz_data = dz[row_id];
    const T z_data = z[row_id];
    const T* x_data = x + cols * row_id;

    for (size_t i = 0; i < cols; ++i) {
      const T dy_data = dz_data * (x_data[i] * reciprocal_xy_norm_prod -
                                   z_data * y[i] * reciprocal_y_norm_square);
      // Threads working on different rows all target the same dy[i], hence
      // the atomic accumulation.
      platform::CudaAtomicAdd(dy + i, dy_data);
    }
  }
}
C
refine  
chengduoZH 已提交
45

C
chengduoZH 已提交
46 47 48 49 50 51 52 53 54 55 56 57
/**
 * CUDA specialization of CosSimDyFunctor: launches CosSimDyKernel on the
 * stream of the given device context.
 *
 * All pointer arguments are forwarded unchanged to the kernel, so they must
 * be dereferenceable from device code. The launch is asynchronous with
 * respect to the host; no synchronization is performed here.
 *
 * @param ctx     device context whose stream() receives the launch
 * @param x_norm  per-row norms of x
 * @param y_norm  norm of y
 * @param x       rows x cols input
 * @param y       y vector
 * @param z       forward result, one value per row
 * @param dz      upstream gradient, one value per row
 * @param rows    number of rows in x
 * @param cols    number of columns in x
 * @param dy      gradient accumulator written by the kernel
 */
template <typename T>
struct CosSimDyFunctor<platform::CUDADeviceContext, T> {
  inline void operator()(const platform::CUDADeviceContext& ctx,
                         const T* x_norm, const T* y_norm, const T* x,
                         const T* y, const T* z, const T* dz, const size_t rows,
                         const size_t cols, T* dy) const {
    // No rows means no contributions; a zero-sized grid would also be an
    // invalid launch configuration.
    if (rows == 0) return;

    const int block_size = 512;
    dim3 threads(block_size, 1);
    // One thread per row, blocks laid out along x. CosSimDyKernel derives its
    // row index from blockIdx.x / gridDim.x, so the block count has to be in
    // the x component; putting it in y makes every block redo all rows and
    // scales dy by the number of blocks.
    dim3 grid((rows + block_size - 1) / block_size, 1);
    CosSimDyKernel<T><<<grid, threads, 0, ctx.stream()>>>(
        x_norm, y_norm, x, y, z, dz, rows, cols, dy);
  }
};

}  // namespace operators
}  // namespace paddle

X
Xinghai Sun 已提交
63
// Short aliases keep the registration lines readable.
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the float CUDA kernel for the forward op.
REGISTER_OP_CUDA_KERNEL(cos_sim,
                        ops::CosSimKernel<plat::CUDADeviceContext, float>);

// Register the float CUDA kernel for the gradient op.
REGISTER_OP_CUDA_KERNEL(cos_sim_grad,
                        ops::CosSimGradKernel<plat::CUDADeviceContext, float>);