p_norm_op.cu 7.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
16
#ifdef __NVCC__
17
#include "cub/cub.cuh"
18 19 20 21 22
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
G
Guoxia Wang 已提交
23
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
24
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
N
Noel 已提交
25
#include "paddle/fluid/operators/fc_op.h"
26
#include "paddle/fluid/operators/p_norm_op.h"
27 28
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.h"
G
Guoxia Wang 已提交
29
#include "paddle/fluid/platform/float16.h"
30 31 32 33 34 35 36 37 38

namespace paddle {
namespace operators {

// Integer sign of `val`: +1 when positive, -1 when negative, 0 otherwise
// (this includes zero and NaN, for which both comparisons are false).
template <typename T>
__device__ __forceinline__ int sgn(T val) {
  const int is_positive = static_cast<int>(T(0) < val);
  const int is_negative = static_cast<int>(val < T(0));
  return is_positive - is_negative;
}

G
Guoxia Wang 已提交
39 40 41
// Absolute value for each element type the p_norm kernels are registered
// for. float16 is widened to float, processed, and narrowed back to half.
__device__ __forceinline__ platform::float16 inline_abs(platform::float16 x) {
  const float widened = static_cast<float>(x);
  return static_cast<platform::float16>(fabsf(widened));
}
__device__ __forceinline__ float inline_abs(float x) { return fabsf(x); }
__device__ __forceinline__ double inline_abs(double x) { return fabs(x); }

G
Guoxia Wang 已提交
45 46 47
// Integer sign (-1, 0 or +1) for each supported element type. Every overload
// forwards to sgn and lets its template argument be deduced from `x`.
__device__ __forceinline__ int inline_sign(platform::float16 x) {
  return sgn(x);
}
__device__ __forceinline__ int inline_sign(float x) { return sgn(x); }
__device__ __forceinline__ int inline_sign(double x) { return sgn(x); }

G
Guoxia Wang 已提交
51 52 53 54 55
// base^exponent for each supported element type. float16 operands are
// widened to float for the computation and the result is narrowed back.
__device__ __forceinline__ platform::float16 inline_pow(
    platform::float16 base, platform::float16 exponent) {
  const float base_f = static_cast<float>(base);
  const float exponent_f = static_cast<float>(exponent);
  return static_cast<platform::float16>(powf(base_f, exponent_f));
}
__device__ __forceinline__ float inline_pow(float base, float exponent) {
  return powf(base, exponent);
}
__device__ __forceinline__ double inline_pow(double base, double exponent) {
  return pow(base, exponent);
}

63
// Maps an element to T(1) if it is non-zero and to T(0) otherwise. The
// porder == 0 forward path sums these values to count non-zero entries.
template <typename T>
struct NonzeroFunctor {
  HOSTDEVICE explicit inline NonzeroFunctor() {}

  HOSTDEVICE inline T operator()(const T x) const {
    // The test is done in double precision. NaN compares unequal to 0, so
    // it is counted as non-zero.
    const bool nonzero = static_cast<double>(x) != 0;
    return static_cast<T>(nonzero);
  }
};
70

71
// Element-wise |x|. The +/-inf forward paths use it as the per-element
// transform before a max/min reduction.
template <typename T>
struct AbsFunctor {
  HOSTDEVICE explicit inline AbsFunctor() {}

  HOSTDEVICE inline T operator()(const T x) const {
    const auto magnitude = inline_abs(x);
    return static_cast<T>(magnitude);
  }
};
78

79
// Element-wise |x|^porder. The exponent is kept as float and converted to T
// on every call, so for float16 it is rounded to half precision first.
template <typename T>
struct UnsignedPowFunctor {
  HOSTDEVICE explicit inline UnsignedPowFunctor(float p) : porder(p) {}

  HOSTDEVICE inline T operator()(const T x) const {
    const T magnitude = inline_abs(x);
    const T exponent = static_cast<T>(porder);
    return static_cast<T>(inline_pow(magnitude, exponent));
  }

  float porder;
};

90 91 92 93 94 95 96 97 98 99
// Forward p_norm kernel. Reduces X into Out over the axes returned by
// GetReduceDim (built from the `axis` and `asvector` attributes), according
// to `porder`:
//   porder == 0     -> number of non-zero elements
//   porder == +inf  -> max |x|
//   porder == -inf  -> min |x|
//   otherwise       -> (sum |x|^porder)^(1 / porder)
template <typename DeviceContext, typename T>
class PnormCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_x = ctx.Input<framework::Tensor>("X");
    auto* out_norm = ctx.Output<framework::Tensor>("Out");
    // The returned pointers are not needed (the reduce / elementwise helpers
    // below take whole tensors), but the accessor calls are kept so that Out
    // has a T-typed buffer on this place before it is written.
    in_x->data<T>();
    out_norm->mutable_data<T>(ctx.GetPlace());

    const float porder = ctx.Attr<float>("porder");
    const bool asvector = ctx.Attr<bool>("asvector");
    std::vector<int> reduce_axis = {ctx.Attr<int>("axis")};
    reduce_axis = GetReduceDim(reduce_axis, in_x->dims().size(), asvector);

    const auto& dev_ctx = ctx.cuda_device_context();
    auto stream = dev_ctx.stream();

    if (porder == 0) {
      TensorReduceImpl<T, T, kps::AddFunctor, NonzeroFunctor<T>>(
          dev_ctx, *in_x, out_norm, NonzeroFunctor<T>(), reduce_axis, stream);
      return;
    }
    if (porder == INFINITY) {
      TensorReduceImpl<T, T, kps::MaxFunctor, AbsFunctor<T>>(
          dev_ctx, *in_x, out_norm, AbsFunctor<T>(), reduce_axis, stream);
      return;
    }
    if (porder == -INFINITY) {
      TensorReduceImpl<T, T, kps::MinFunctor, AbsFunctor<T>>(
          dev_ctx, *in_x, out_norm, AbsFunctor<T>(), reduce_axis, stream);
      return;
    }

    // General p, step 1: Out <- sum |x|^p over the reduced axes.
    TensorReduceImpl<T, T, kps::AddFunctor, UnsignedPowFunctor<T>>(
        dev_ctx, *in_x, out_norm, UnsignedPowFunctor<T>(porder), reduce_axis,
        stream);

    // Step 2: Out <- |Out|^(1/p), applied in place (Out is input and output).
    std::vector<const framework::Tensor*> ins = {out_norm};
    std::vector<framework::Tensor*> outs = {out_norm};
    const auto& cuda_ctx =
        ctx.template device_context<platform::CUDADeviceContext>();
    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
        cuda_ctx, ins, &outs, UnsignedPowFunctor<T>(1. / porder));
  }
};

135 136 137 138 139 140
// Gradient for the +inf / -inf norms. An element of dx receives
// dy * sign(x) where |x| equals the broadcast reduced value y, and 0
// elsewhere. If several elements tie with y, every one of them receives dy.
template <typename T>
struct AbsMaxAndMinGradFunctor {
  template <typename DeviceContext, typename X, typename Y, typename DX,
            typename DY, typename Dim>
  void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
                  const Dim& dim, int size) {
    // `size` belongs to the shared reduce-grad functor signature; unused here.
    // Factor order is kept as written so floating-point rounding is unchanged.
    dx->device(place) =
        dy->broadcast(dim) * x->sign() *
        (x->abs() == y->broadcast(dim)).template cast<T>();
  }
};
145

146
// Gradient of the general p-norm y = (sum |x|^p)^(1/p):
//   dx = |x|^(p-1) * sign(x) * dy * y^-(p-1)
// The constructor stores p - 1 (cast to T), the only exponent the formula
// needs.
// NOTE(review): y == 0 is not guarded. For p > 1, an all-zero slice gives
// y^-(p-1) == inf while |x|^(p-1) == 0, so dx there becomes NaN.
template <typename T>
struct PNormGradFunctor {
  HOSTDEVICE explicit inline PNormGradFunctor(float p)
      : porder(static_cast<T>(p - 1.)) {}

  template <typename DeviceContext, typename X, typename Y, typename DX,
            typename DY, typename Dim>
  void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy,
                  const Dim& dim, int size) {
    // Term order matches the formula above; keep it to preserve rounding.
    dx->device(place) = x->abs().pow(porder) * x->sign() *
                        dy->broadcast(dim) *
                        y->pow(-porder).broadcast(dim);
  }

  // Holds p - 1, not p.
  T porder;
};
161

162 163 164 165 166 167 168 169 170 171 172 173 174 175
// Backward p_norm kernel. Computes dX from X, Out (the forward norm) and
// dOut:
//   porder == 0      -> dX is filled with zeros
//   porder == +/-inf -> AbsMaxAndMinGradFunctor
//   otherwise        -> PNormGradFunctor
// NOTE(review): the `asvector` attribute is not read here. Reducing over
// every dimension is inferred from Out having exactly one element.
template <typename DeviceContext, typename T, typename AttrType = T>
class PnormGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_x = ctx.Input<framework::Tensor>("X");
    auto* in_norm = ctx.Input<framework::Tensor>("Out");
    auto* in_norm_dy =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    // dX needs a T-typed buffer on this place before any branch writes it.
    out_dx->mutable_data<T>(ctx.GetPlace());

    const float porder = ctx.Attr<float>("porder");
    int axis = ctx.Attr<int>("axis");
    if (axis < 0) {
      // Normalize a negative axis against the rank of X.
      axis += in_x->dims().size();
    }
    const std::vector<int> dims = {axis};
    // A one-element Out means every dimension was reduced.
    const bool reduce_all = in_norm->numel() == 1;

    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    if (porder == 0) {
      // The count of non-zeros is piecewise constant, so its gradient is 0.
      pten::funcs::SetConstant<DeviceContext, T> set_zero;
      set_zero(dev_ctx, out_dx, static_cast<T>(0));
      return;
    }
    if (porder == INFINITY || porder == -INFINITY) {
      AbsMaxAndMinGradFunctor<T> functor;
      LaunchReduceGradKernel<DeviceContext, T, AbsMaxAndMinGradFunctor<T>>(
          ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
      return;
    }
    PNormGradFunctor<T> functor(porder);
    LaunchReduceGradKernel<DeviceContext, T, PNormGradFunctor<T>>(
        ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;

// Forward and backward p_norm kernels, each registered on the CUDA device
// context for float16, float and double (in that order).
REGISTER_OP_CUDA_KERNEL(
    p_norm, ops::PnormCUDAKernel<CUDA, paddle::platform::float16>,
    ops::PnormCUDAKernel<CUDA, float>, ops::PnormCUDAKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(
    p_norm_grad, ops::PnormGradCUDAKernel<CUDA, paddle::platform::float16>,
    ops::PnormGradCUDAKernel<CUDA, float>,
    ops::PnormGradCUDAKernel<CUDA, double>);