p_norm_op.cu 11.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
16
#ifdef __NVCC__
17
#include "cub/cub.cuh"
18 19 20 21 22
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
G
Guoxia Wang 已提交
23
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
24
#include "paddle/fluid/operators/p_norm_op.h"
G
Guoxia Wang 已提交
25
#include "paddle/fluid/platform/float16.h"
26 27 28 29 30 31 32 33 34

namespace paddle {
namespace operators {

// Sign of `val` as an int: +1 when positive, -1 when negative, 0 otherwise
// (for IEEE floating types a NaN compares false both ways and also yields 0).
template <typename T>
__device__ __forceinline__ int sgn(T val) {
  const int is_positive = (T(0) < val) ? 1 : 0;
  const int is_negative = (val < T(0)) ? 1 : 0;
  return is_positive - is_negative;
}

// Absolute-value helpers used by the norm kernels. The float16 overload
// computes in float and narrows the result back to float16.
__device__ __forceinline__ platform::float16 inline_abs(platform::float16 x) {
  const float x_fp32 = static_cast<float>(x);
  return static_cast<platform::float16>(fabsf(x_fp32));
}
__device__ __forceinline__ float inline_abs(float x) { return fabsf(x); }
__device__ __forceinline__ double inline_abs(double x) { return fabs(x); }

// Sign helpers (+1 / 0 / -1), one overload per element type this operator is
// registered for; each simply forwards to sgn<T>.
__device__ __forceinline__ int inline_sign(platform::float16 x) {
  return sgn(x);
}
__device__ __forceinline__ int inline_sign(float x) { return sgn(x); }
__device__ __forceinline__ int inline_sign(double x) { return sgn(x); }

// Power helpers. The float16 overload evaluates base^exponent in float and
// narrows the result back to float16.
__device__ __forceinline__ platform::float16 inline_pow(
    platform::float16 base, platform::float16 exponent) {
  const float base_fp32 = static_cast<float>(base);
  const float exponent_fp32 = static_cast<float>(exponent);
  return static_cast<platform::float16>(powf(base_fp32, exponent_fp32));
}
__device__ __forceinline__ float inline_pow(float base, float exponent) {
  return powf(base, exponent);
}
__device__ __forceinline__ double inline_pow(double base, double exponent) {
  return pow(base, exponent);
}

// Computes the p-norm (sum_j |x_j|^p)^(1/p) along the middle axis of `x`,
// indexed as [pre, axis_n, post]; out_norm receives pre * post values.
//
// Launch: 1-D grid (each block strides over the pre * post outputs) and a
// 1-D block of exactly BlockDim threads, as cub::BlockReduce<MT, BlockDim>
// requires. Accumulation is done in MT (the MPTypeTrait compute type).
template <typename T, int BlockDim>
__global__ void Pnorm(const T* x, const int pre,
                      const int axis_n,  // dim in axis
                      const int post, float porder, T* out_norm) {
  using MT = typename details::MPTypeTrait<T>::Type;
  typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int num = pre * post;
  auto porder_t = static_cast<MT>(porder);
  auto porder_inv = static_cast<MT>(1.0 / porder);

  for (int i = blockIdx.x; i < num; i += gridDim.x) {
    // Offset of element (i / post, 0, i % post) in the [pre, axis_n, post] view.
    int base = (i / post) * post * axis_n + (i % post);
    MT sum = static_cast<MT>(0.0);
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      const MT x_ij = static_cast<MT>(x[base + j * post]);
      sum += inline_pow(inline_abs(x_ij), porder_t);
    }
    MT reduce_result = BlockReduce(temp_storage).Sum(sum);
    if (threadIdx.x == 0)
      out_norm[i] = static_cast<T>(inline_pow(reduce_result, porder_inv));
    // temp_storage is reused by the next iteration's BlockReduce; CUB requires
    // a block barrier before reusing it. The loop bound depends only on
    // blockIdx, so every thread of the block reaches this barrier.
    __syncthreads();
  }
}

// Counts the non-zero entries ("0-norm") along the middle axis of `x`,
// indexed as [pre, axis_n, post]; out_norm receives pre * post values.
//
// Launch: 1-D grid (each block strides over the pre * post outputs) and a
// 1-D block of exactly BlockDim threads, as cub::BlockReduce requires.
template <typename T, int BlockDim>
__global__ void ZeorNorm(const T* x, const int pre,
                         const int axis_n,  // dim in axis
                         const int post, T* out_norm) {
  using MT = typename details::MPTypeTrait<T>::Type;
  typedef cub::BlockReduce<MT, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int num = pre * post;
  for (int i = blockIdx.x; i < num; i += gridDim.x) {
    int base = (i / post) * post * axis_n + (i % post);
    MT sum = static_cast<MT>(0.0);
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      const MT x_ij = static_cast<MT>(x[base + j * post]);
      sum += static_cast<MT>(static_cast<double>(x_ij) != 0);
    }
    MT reduce_result = BlockReduce(temp_storage).Sum(sum);
    if (threadIdx.x == 0) out_norm[i] = static_cast<T>(reduce_result);
    // Barrier before temp_storage is reused by the next iteration (CUB
    // requirement); the loop bound is uniform across the block.
    __syncthreads();
  }
}

// Computes the infinity norm max_j |x_j| along the middle axis of `x`,
// indexed as [pre, axis_n, post]; out_norm receives pre * post values.
//
// Launch: 1-D grid (each block strides over the pre * post outputs) and a
// 1-D block of exactly BlockDim threads, as cub::BlockReduce requires.
// x[base] is read unconditionally, so this assumes axis_n >= 1.
template <typename T, int BlockDim>
__global__ void InfNorm(const T* x, const int pre,
                        const int axis_n,  // dim in axis
                        const int post, T* out_norm) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int num = pre * post;
  for (int i = blockIdx.x; i < num; i += gridDim.x) {
    int base = (i / post) * post * axis_n + (i % post);
    // Seed with |x[base]|, an element of the reduced set, so threads that own
    // no element along the axis cannot change the maximum.
    T cur_max = inline_abs(x[base]);
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      T x_ij_abs = inline_abs(x[base + j * post]);
      if (cur_max < x_ij_abs) cur_max = x_ij_abs;
    }
    T reduce_result = BlockReduce(temp_storage).Reduce(cur_max, cub::Max());
    if (threadIdx.x == 0) out_norm[i] = reduce_result;
    // Barrier before temp_storage is reused by the next iteration (CUB
    // requirement); the loop bound is uniform across the block.
    __syncthreads();
  }
}

// Computes the negative-infinity "norm" min_j |x_j| along the middle axis of
// `x`, indexed as [pre, axis_n, post]; out_norm receives pre * post values.
//
// Launch: 1-D grid (each block strides over the pre * post outputs) and a
// 1-D block of exactly BlockDim threads, as cub::BlockReduce requires.
// x[base] is read unconditionally, so this assumes axis_n >= 1.
template <typename T, int BlockDim>
__global__ void NegInfNorm(const T* x, const int pre,
                           const int axis_n,  // dim in axis
                           const int post, T* out_norm) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  int num = pre * post;
  for (int i = blockIdx.x; i < num; i += gridDim.x) {
    int base = (i / post) * post * axis_n + (i % post);
    // Seed with |x[base]|, an element of the reduced set, so threads that own
    // no element along the axis cannot change the minimum.
    T cur_min = inline_abs(x[base]);
    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      T x_ij_abs = inline_abs(x[base + j * post]);
      if (cur_min > x_ij_abs) cur_min = x_ij_abs;
    }
    T reduce_result = BlockReduce(temp_storage).Reduce(cur_min, cub::Min());
    if (threadIdx.x == 0) out_norm[i] = reduce_result;
    // Barrier before temp_storage is reused by the next iteration (CUB
    // requirement); the loop bound is uniform across the block.
    __syncthreads();
  }
}

// Forward CUDA kernel for p_norm. Reads X and the "porder", "axis" and
// "asvector" attributes, then launches the 0 / +inf / -inf / general p-norm
// reduction kernel that writes Out.
template <typename DeviceContext, typename T>
class PnormCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_x = ctx.Input<framework::Tensor>("X");
    auto* out_norm = ctx.Output<framework::Tensor>("Out");
    const T* x = in_x->data<T>();
    T* norm = out_norm->mutable_data<T>(ctx.GetPlace());

    auto xdim = in_x->dims();
    float porder = ctx.Attr<float>("porder");
    int axis = ctx.Attr<int>("axis");
    bool asvector = ctx.Attr<bool>("asvector");
    if (axis < 0) axis = xdim.size() + axis;
    int pre, n, post;
    GetDims(xdim, axis, &pre, &n, &post, asvector);

    // No outputs to compute. Launching with grid == 0 would also be an
    // invalid launch configuration.
    const int num_outputs = pre * post;
    if (num_outputs == 0) return;

    auto& dev_ctx = ctx.cuda_device_context();

#ifdef __HIPCC__
    const int block = 256;
#else
    const int block = 512;
#endif

    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    int grid = std::min(max_blocks, num_outputs);
    if (porder == 0) {
      ZeorNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
                                                               norm);
    } else if (porder == INFINITY) {
      InfNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
                                                              norm);
    } else if (porder == -INFINITY) {
      NegInfNorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n,
                                                                 post, norm);
    } else {
      Pnorm<T, block><<<grid, block, 0, dev_ctx.stream()>>>(x, pre, n, post,
                                                            porder, norm);
    }
  }
};

// Gradient of the general p-norm along the middle axis of `x`, indexed as
// [pre, axis_n, post]. For every element j of output slot i:
//   dx_j = |x_j|^(p-1) / (norm_i^(p-1) + eps) * dy_i * sign(x_j)
// computed in MT (the MPTypeTrait compute type).
//
// Launch: 1-D grid (each block strides over the pre * post outputs), 1-D
// block of any size.
template <typename T, int BlockDim>
__global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad,
                              const float porder, const int pre,
                              const int axis_n, const int post, const T eps,
                              T* x_grad) {
  using MT = typename details::MPTypeTrait<T>::Type;
  int num = pre * post;
  auto porder_grad = static_cast<MT>(porder - 1.0f);
  // Per-output norm and upstream gradient, staged by thread 0 and read by the
  // whole block.
  __shared__ MT pnorm_i;
  __shared__ MT yout_i;

  for (int i = blockIdx.x; i < num; i += gridDim.x) {
    auto base = (i / post) * post * axis_n + (i % post);

    if (threadIdx.x == 0) {
      pnorm_i = static_cast<MT>(x_norm[i]);
      yout_i = static_cast<MT>(y_grad[i]);
    }
    __syncthreads();

    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      int index = base + j * post;
      const MT x_ij = static_cast<MT>(inline_abs(x[index]));
      x_grad[index] = static_cast<T>(
          inline_pow(x_ij, porder_grad) /
          (inline_pow(pnorm_i, porder_grad) + static_cast<MT>(eps)) * yout_i *
          static_cast<MT>(inline_sign(x[index])));
    }
    // Thread 0 overwrites pnorm_i / yout_i at the top of the next iteration,
    // so wait until every thread has finished reading them. The loop bound is
    // uniform across the block, so all threads reach this barrier.
    __syncthreads();
  }
}

// Gradient of the +/- infinity norm along the middle axis of `x`, indexed as
// [pre, axis_n, post]. Elements whose |x_j| equals the stored norm receive
// sign(x_j) * dy_i (every tied element gets the full value); all other
// elements receive 0.
//
// Launch: 1-D grid (each block strides over the pre * post outputs), 1-D
// block of any size.
template <typename T, int BlockDim>
__global__ void InfNormGradient(const T* x, const T* x_norm, const T* y_grad,
                                const int pre, const int axis_n, const int post,
                                T* x_grad) {
  int num = pre * post;
  // Per-output norm and upstream gradient, staged by thread 0 and read by the
  // whole block.
  __shared__ T pnorm_i;
  __shared__ T yout_i;
  for (int i = blockIdx.x; i < num; i += gridDim.x) {
    auto base = (i / post) * post * axis_n + (i % post);
    if (threadIdx.x == 0) {
      pnorm_i = x_norm[i];
      yout_i = y_grad[i];
    }
    __syncthreads();

    for (int j = threadIdx.x; j < axis_n; j += blockDim.x) {
      int index = base + j * post;
      const T x_ij = inline_abs(x[index]);
      if (x_ij == pnorm_i) {
        x_grad[index] = static_cast<T>(inline_sign(x[index])) * yout_i;
      } else {
        x_grad[index] = static_cast<T>(0);
      }
    }
    // Thread 0 overwrites pnorm_i / yout_i at the top of the next iteration,
    // so wait until every thread has finished reading them. The loop bound is
    // uniform across the block, so all threads reach this barrier.
    __syncthreads();
  }
}

// Backward CUDA kernel for p_norm. Writes X@GRAD from X, the forward output
// Out and Out@GRAD, dispatching on the "porder" attribute:
//   porder == 0     -> X@GRAD is filled with zeros
//   porder == +/-inf -> InfNormGradient
//   otherwise       -> PnormGradient (uses the "epsilon" attribute)
template <typename DeviceContext, typename T, typename AttrType = T>
class PnormGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_x = ctx.Input<framework::Tensor>("X");
    auto* in_norm = ctx.Input<framework::Tensor>("Out");
    auto* in_norm_dy =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    T* dx = out_dx->mutable_data<T>(ctx.GetPlace());
    const T* x = in_x->data<T>();
    const T* x_norm = in_norm->data<T>();
    const T* norm_dy = in_norm_dy->data<T>();

    auto xdim = in_x->dims();
    float porder = ctx.Attr<float>("porder");
    T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
    int axis = ctx.Attr<int>("axis");
    bool asvector = ctx.Attr<bool>("asvector");
    if (axis < 0) axis = xdim.size() + axis;
    int pre, n, post;
    GetDims(xdim, axis, &pre, &n, &post, asvector);

    // No reduced outputs means there is nothing to write (assumes GetDims
    // factors X as pre * n * post). A grid of 0 blocks would also be an
    // invalid launch configuration.
    const int num_outputs = pre * post;
    if (num_outputs == 0) return;

    auto& dev_ctx = ctx.cuda_device_context();

#ifdef __HIPCC__
    const int block = 256;
#else
    const int block = 512;
#endif

    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
    const int max_blocks = std::max(max_threads / block, 1);
    int grid = std::min(max_blocks, num_outputs);
    if (porder == 0) {
      // The forward 0-"norm" counts non-zeros, which is piecewise constant,
      // so its gradient is zero everywhere.
      math::SetConstant<DeviceContext, T> set_zero;
      set_zero(ctx.template device_context<DeviceContext>(), out_dx,
               static_cast<T>(0));
    } else if (porder == INFINITY || porder == -INFINITY) {
      InfNormGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
          x, x_norm, norm_dy, pre, n, post, dx);
    } else {
      PnormGradient<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
          x, x_norm, norm_dy, porder, pre, n, post, eps, dx);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;

// Forward p_norm kernels for float16, float and double inputs.
REGISTER_OP_CUDA_KERNEL(
    p_norm,
    ops::PnormCUDAKernel<CUDA, paddle::platform::float16>,
    ops::PnormCUDAKernel<CUDA, float>,
    ops::PnormCUDAKernel<CUDA, double>);

// Backward p_norm kernels for the same element types.
REGISTER_OP_CUDA_KERNEL(
    p_norm_grad,
    ops::PnormGradCUDAKernel<CUDA, paddle::platform::float16>,
    ops::PnormGradCUDAKernel<CUDA, float>,
    ops::PnormGradCUDAKernel<CUDA, double>);