norm_op.cu 5.8 KB
Newer Older
1
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <algorithm>
#include <cstdint>
16
#ifdef __NVCC__
17
#include "cub/cub.cuh"
18 19 20 21 22
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
Y
Yi Wang 已提交
23
#include "paddle/fluid/operators/norm_op.h"
24

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160
namespace paddle {
namespace operators {

// Precision-matched square root for the templated kernels in this file:
// overload resolution dispatches to sqrtf for float and to sqrt for double.
__device__ __forceinline__ float square_root(float x) {
  return sqrtf(x);
}

__device__ __forceinline__ double square_root(double x) {
  return sqrt(x);
}

// Forward L2 normalization along one axis.
//
// The input is indexed as a [pre, axis_n, post] array: element (p, j, q) lives
// at offset p * axis_n * post + j * post + q. Each of the pre * post rows
// (fixed p, q; row id i = p * post + q) is normalized independently:
//   out_norm[i] = square_root(sum_j x_ij * x_ij + eps)
//   y_ij        = x_ij / out_norm[i]
//
// Launch layout: 1-D grid of any size (blocks loop over rows with a stride of
// gridDim.x), 1-D block whose size must equal BlockDim, because
// cub::BlockReduce is specialized on BlockDim while the element loops stride
// by blockDim.x. Uses only static shared memory (CUB temp storage plus one T).
template <typename T, int BlockDim>
__global__ void Normalize(const T* x, const int pre,
                          const int axis_n,  // dim in axis
                          const int post, const T eps, T* y, T* out_norm) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  // Norm of the current row, broadcast from thread 0 to the whole block.
  __shared__ T norm;

  // 64-bit arithmetic: pre * post and the element offsets below can exceed
  // INT_MAX for large tensors even though each dimension fits in an int.
  const int64_t num = static_cast<int64_t>(pre) * post;
  const int64_t row_stride = static_cast<int64_t>(post) * axis_n;
  for (int64_t i = blockIdx.x; i < num; i += gridDim.x) {
    const int64_t base = (i / post) * row_stride + (i % post);

    // Per-thread partial sum of squares over this thread's slice of the axis.
    T sum = static_cast<T>(0);
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      const T x_ij = x[base + static_cast<int64_t>(j) * post];
      sum += x_ij * x_ij;
    }
    T reduce_result = BlockReduce(temp_storage).Sum(sum);

    // The block-wide aggregate is consumed only on thread 0 (CUB only
    // guarantees the result there).
    if (threadIdx.x == 0) {
      norm = square_root(reduce_result + eps);
      out_norm[i] = norm;
    }
    // Publish `norm` to every thread; also separates this row's use of
    // temp_storage from the next row's reduction.
    __syncthreads();
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      const int64_t index = base + static_cast<int64_t>(j) * post;
      y[index] = x[index] / norm;
    }
    // All threads must finish reading `norm` before thread 0 overwrites it for
    // the next row. The loop bound depends only on blockIdx, so every thread
    // of the block reaches this barrier.
    __syncthreads();
  }
}

// GPU kernel for the `norm` op. Writes X / square_root(sum(X^2) + epsilon)
// along attribute "axis" to "Out", and the per-row denominators to "Norm".
template <typename DeviceContext, typename T>
class NormCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_x = ctx.Input<framework::Tensor>("X");
    auto* out_y = ctx.Output<framework::Tensor>("Out");
    auto* out_norm = ctx.Output<framework::Tensor>("Norm");
    const T* x = in_x->data<T>();
    T* y = out_y->mutable_data<T>(ctx.GetPlace());
    T* norm = out_norm->mutable_data<T>(ctx.GetPlace());

    auto xdim = in_x->dims();
    int axis = ctx.Attr<int>("axis");
    T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
    // A negative axis counts from the last dimension.
    if (axis < 0) axis = xdim.size() + axis;
    // NOTE(review): GetDims (declared in norm_op.h) presumably splits xdim into
    // [pre, n, post] around `axis`; this matches how Normalize indexes memory.
    int pre, n, post;
    GetDims(xdim, axis, &pre, &n, &post);

    // Nothing to normalize. Returning early also avoids launching with a
    // zero-sized grid, which CUDA rejects as an invalid configuration.
    const int64_t rows = static_cast<int64_t>(pre) * post;
    if (rows == 0) return;

    auto& dev_ctx = ctx.cuda_device_context();

    // Also used as Normalize's BlockDim template argument, so the launched
    // block size always matches the CUB specialization.
    const int block = 512;
    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    // Normalize strides over rows, so the grid may be smaller than `rows`.
    // The min is taken in 64 bits so a huge row count cannot overflow.
    const int grid = static_cast<int>(std::min<int64_t>(max_blocks, rows));
    Normalize<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
                                                              eps, y, norm);
  }
};

// Backward pass of Normalize.
//
// With s_i = x_norm[i] and the same [pre, axis_n, post] indexing as Normalize,
// each row i computes
//   x_grad_ij = (dy_ij - x_ij * (sum_k x_ik * dy_ik) / s_i^2) / s_i.
// x_norm is expected to hold the per-row norms written by the forward pass
// (square_root(sum of squares + eps)); verify against the caller.
//
// Launch layout: 1-D grid of any size (blocks loop over rows with a stride of
// gridDim.x), 1-D block whose size must equal BlockDim (CUB specialization).
// Uses only static shared memory.
template <typename T, int BlockDim>
__global__ void NormalizeGradient(const T* x, const T* x_norm, const T* y_grad,
                                  const int pre, const int axis_n,
                                  const int post, T* x_grad) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage_sum;
  // Per-row values broadcast from thread 0 to the whole block.
  __shared__ T row_sum;        // sum_k x_ik * dy_ik
  __shared__ T row_sqrt_norm;  // s_i
  __shared__ T row_norm;       // s_i^2

  // 64-bit arithmetic: pre * post and the element offsets below can exceed
  // INT_MAX for large tensors even though each dimension fits in an int.
  const int64_t num = static_cast<int64_t>(pre) * post;
  const int64_t row_stride = static_cast<int64_t>(post) * axis_n;
  for (int64_t i = blockIdx.x; i < num; i += gridDim.x) {
    const int64_t base = (i / post) * row_stride + (i % post);

    // Per-thread partial of the dot product <x_i, dy_i>.
    T sum = static_cast<T>(0);
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      const int64_t index = base + static_cast<int64_t>(j) * post;
      sum += x[index] * y_grad[index];
    }
    T reduce_result = BlockReduce(temp_storage_sum).Sum(sum);

    // The block-wide aggregate is consumed only on thread 0.
    if (threadIdx.x == 0) {
      row_sum = reduce_result;
      row_sqrt_norm = x_norm[i];
      row_norm = row_sqrt_norm * row_sqrt_norm;
    }
    // Publish the row scalars; also separates this row's use of
    // temp_storage_sum from the next row's reduction.
    __syncthreads();
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      const int64_t index = base + static_cast<int64_t>(j) * post;
      const T x_ij = x[index];
      const T dy_ij = y_grad[index];
      x_grad[index] = (dy_ij - x_ij * row_sum / row_norm) / row_sqrt_norm;
    }
    // All threads must finish reading the row scalars before thread 0
    // overwrites them for the next row. The loop bound is block-uniform.
    __syncthreads();
  }
}

// GPU gradient kernel for the `norm` op. Computes the gradient of "X" from
// "X", the forward output "Norm", and the gradient of "Out".
// The AttrType template parameter is not used by this implementation.
template <typename DeviceContext, typename T, typename AttrType = T>
class NormGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_x = ctx.Input<framework::Tensor>("X");
    auto* in_norm = ctx.Input<framework::Tensor>("Norm");
    auto* in_dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    T* dx = out_dx->mutable_data<T>(ctx.GetPlace());
    const T* x = in_x->data<T>();
    const T* x_norm = in_norm->data<T>();
    const T* dy = in_dy->data<T>();

    auto xdim = in_x->dims();
    int axis = ctx.Attr<int>("axis");
    // A negative axis counts from the last dimension.
    if (axis < 0) axis = xdim.size() + axis;
    int pre, n, post;
    GetDims(xdim, axis, &pre, &n, &post);

    // Nothing to differentiate. Returning early also avoids launching with a
    // zero-sized grid, which CUDA rejects as an invalid configuration.
    const int64_t rows = static_cast<int64_t>(pre) * post;
    if (rows == 0) return;

    auto& dev_ctx = ctx.cuda_device_context();

    // Also used as NormalizeGradient's BlockDim template argument.
    const int block = 512;
    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    // NormalizeGradient strides over rows, so the grid may be smaller than
    // `rows`. The min is taken in 64 bits so a huge row count cannot overflow.
    const int grid = static_cast<int>(std::min<int64_t>(max_blocks, rows));
    NormalizeGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
        x, x_norm, dy, pre, n, post, dx);
  }
};

}  // namespace operators
}  // namespace paddle

161
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;

// Register float and double CUDA kernels for the forward `norm` op and for its
// gradient op `norm_grad`.
REGISTER_OP_CUDA_KERNEL(norm, ops::NormCUDAKernel<CUDA, float>,
                        ops::NormCUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(norm_grad, ops::NormGradCUDAKernel<CUDA, float>,
                        ops::NormGradCUDAKernel<CUDA, double>);