activation_cudnn_op.cu.cc 10.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
17
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
18 19 20 21 22 23

namespace paddle {
namespace operators {
// Short names for the framework/platform types used throughout this file.
using platform::ActivationDescriptor;
using platform::TensorDescriptor;
using platform::CUDADeviceContext;
using framework::Tensor;

// Backend-neutral names for the activation modes used below. They resolve
// to MIOpen enumerators in HIP builds and to cuDNN enumerators otherwise.
// MIOpen spells sigmoid as "LOGISTIC".
#if defined(PADDLE_WITH_HIP)
#define GPUDNN_ACTIVATION_RELU miopenActivationRELU
#define GPUDNN_ACTIVATION_CLIPPED_RELU miopenActivationCLIPPEDRELU
#define GPUDNN_ACTIVATION_SIGMOID miopenActivationLOGISTIC
#define GPUDNN_ACTIVATION_TANH miopenActivationTANH
#else  // cuDNN
#define GPUDNN_ACTIVATION_RELU CUDNN_ACTIVATION_RELU
#define GPUDNN_ACTIVATION_CLIPPED_RELU CUDNN_ACTIVATION_CLIPPED_RELU
#define GPUDNN_ACTIVATION_SIGMOID CUDNN_ACTIVATION_SIGMOID
#define GPUDNN_ACTIVATION_TANH CUDNN_ACTIVATION_TANH
#endif  // PADDLE_WITH_HIP

template <typename T>
struct CudnnActivationFunctor {
  using ELEMENT_TYPE = T;
41
#ifdef PADDLE_WITH_HIP
42 43
  CudnnActivationFunctor(const CUDADeviceContext& ctx,
                         const T& c,
44 45 46
                         const miopenActivationMode_t& m)
      : ctx_(ctx), coef_(c), mode_(m) {}
#else
47 48
  CudnnActivationFunctor(const CUDADeviceContext& ctx,
                         const T& c,
49 50
                         const cudnnActivationMode_t& m)
      : ctx_(ctx), coef_(c), mode_(m) {}
51
#endif
52 53 54 55 56
  void operator()(const Tensor& x, Tensor* out) {
    ActivationDescriptor act_desc;
    act_desc.set(mode_, coef_);
    TensorDescriptor x_desc, out_desc;
    x_desc.set(x);
57
    out_desc.set(GET_DATA_SAFELY(out, "Output", "Out", "CudnnActivation"));
58
#ifdef PADDLE_WITH_HIP
59
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenActivationForward(
60 61 62 63 64 65 66
        ctx_.cudnn_handle(),
        act_desc.desc(),
        platform::CudnnDataType<T>::kOne(),
        x_desc.desc(),
        x.data<T>(),
        platform::CudnnDataType<T>::kZero(),
        out_desc.desc(),
67 68
        out->mutable_data<T>(ctx_.GetPlace())));
#else
69
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnActivationForward(
70 71 72 73 74 75 76
        ctx_.cudnn_handle(),
        act_desc.desc(),
        platform::CudnnDataType<T>::kOne(),
        x_desc.desc(),
        x.data<T>(),
        platform::CudnnDataType<T>::kZero(),
        out_desc.desc(),
77
        out->mutable_data<T>(ctx_.GetPlace())));
78
#endif
79 80 81
  }
  const CUDADeviceContext& ctx_;
  const T coef_;
82 83 84
#ifdef PADDLE_WITH_HIP
  const miopenActivationMode_t mode_;
#else
85
  const cudnnActivationMode_t mode_;
86
#endif
87 88 89 90 91
};

template <typename T>
struct CudnnActivationGradFunctor {
  using ELEMENT_TYPE = T;
92
#ifdef PADDLE_WITH_HIP
93 94
  CudnnActivationGradFunctor(const CUDADeviceContext& ctx,
                             const T& c,
95 96 97
                             const miopenActivationMode_t& m)
      : ctx_(ctx), coef_(c), mode_(m) {}
#else
98 99
  CudnnActivationGradFunctor(const CUDADeviceContext& ctx,
                             const T& c,
100 101
                             const cudnnActivationMode_t& m)
      : ctx_(ctx), coef_(c), mode_(m) {}
102
#endif
103 104 105
  void operator()(const Tensor& x,
                  const Tensor& out,
                  const Tensor dout,
106 107 108 109 110 111 112
                  Tensor* dx) {
    ActivationDescriptor act_desc;
    act_desc.set(mode_, coef_);
    TensorDescriptor x_desc, out_desc, dout_desc, dx_desc;
    x_desc.set(x);
    out_desc.set(out);
    dout_desc.set(dout);
113
    dx_desc.set(GET_DATA_SAFELY(dx, "Output", "X@GRAD", "CudnnActivationGrad"));
114
#ifdef PADDLE_WITH_HIP
115
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenActivationBackward(
116 117 118 119 120 121 122 123 124 125 126
        ctx_.cudnn_handle(),
        act_desc.desc(),
        platform::CudnnDataType<T>::kOne(),
        out_desc.desc(),
        out.data<T>(),
        dout_desc.desc(),
        dout.data<T>(),
        x_desc.desc(),
        x.data<T>(),
        platform::CudnnDataType<T>::kZero(),
        dx_desc.desc(),
127 128
        dx->mutable_data<T>(ctx_.GetPlace())));
#else
129
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnActivationBackward(
130 131 132 133 134 135 136 137 138 139 140
        ctx_.cudnn_handle(),
        act_desc.desc(),
        platform::CudnnDataType<T>::kOne(),
        out_desc.desc(),
        out.data<T>(),
        dout_desc.desc(),
        dout.data<T>(),
        x_desc.desc(),
        x.data<T>(),
        platform::CudnnDataType<T>::kZero(),
        dx_desc.desc(),
141
        dx->mutable_data<T>(ctx_.GetPlace())));
142
#endif
143 144 145
  }
  const CUDADeviceContext& ctx_;
  const T coef_;
146 147 148
#ifdef PADDLE_WITH_HIP
  const miopenActivationMode_t mode_;
#else
149
  const cudnnActivationMode_t mode_;
150
#endif
151 152 153 154 155
};

// ReLU forward/backward through the library's plain RELU mode, configured
// with a zero coefficient.
template <typename T>
struct CudnnReluFunctor : public CudnnActivationFunctor<T> {
  explicit CudnnReluFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationFunctor<T>(
            ctx, static_cast<T>(0.0), GPUDNN_ACTIVATION_RELU) {}
};

template <typename T>
struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
  explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(
            ctx, static_cast<T>(0.0), GPUDNN_ACTIVATION_RELU) {}

  // The gradient kernel needs the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// ReLU6 forward/backward through the library's clipped-ReLU mode, with a
// clipping coefficient of 6.
template <typename T>
struct CudnnRelu6Functor : public CudnnActivationFunctor<T> {
  explicit CudnnRelu6Functor(const CUDADeviceContext& ctx)
      : CudnnActivationFunctor<T>(
            ctx, static_cast<T>(6.0), GPUDNN_ACTIVATION_CLIPPED_RELU) {}
};

template <typename T>
struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
  explicit CudnnRelu6GradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(
            ctx, static_cast<T>(6.0), GPUDNN_ACTIVATION_CLIPPED_RELU) {}

  // The gradient kernel needs the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Sigmoid forward/backward through the library's sigmoid mode (MIOpen:
// LOGISTIC), configured with a zero coefficient.
template <typename T>
struct CudnnSigmoidFunctor : public CudnnActivationFunctor<T> {
  explicit CudnnSigmoidFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationFunctor<T>(
            ctx, static_cast<T>(0.0), GPUDNN_ACTIVATION_SIGMOID) {}
};

template <typename T>
struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
  explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(
            ctx, static_cast<T>(0.0), GPUDNN_ACTIVATION_SIGMOID) {}

  // The gradient kernel needs the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Tanh forward/backward through the library's tanh mode, configured with a
// zero coefficient.
template <typename T>
struct CudnnTanhFunctor : public CudnnActivationFunctor<T> {
  explicit CudnnTanhFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationFunctor<T>(
            ctx, static_cast<T>(0.0), GPUDNN_ACTIVATION_TANH) {}
};

template <typename T>
struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor<T> {
  explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(
            ctx, static_cast<T>(0.0), GPUDNN_ACTIVATION_TANH) {}

  // The gradient kernel needs the forward output Out.
  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};

// Op kernel that runs a forward activation `Functor` (a functor deriving from
// CudnnActivationFunctor) on input X and writes output Out.
template <typename Functor>
class CudnnActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& context) const override {
    const framework::Tensor* X = nullptr;
    framework::Tensor* Out = nullptr;
    ExtractActivationTensor(context, &X, &Out);

    // Null-check Out before its first dereference, so a missing output
    // raises an enforce error instead of crashing inside mutable_data.
    framework::Tensor& out =
        GET_DATA_SAFELY(Out, "Output", "Out", "CudnnActivation");
    // Allocate the output up front. The functor builds its descriptor from
    // `out` before it calls mutable_data itself.
    out.mutable_data<T>(context.GetPlace());

    auto& dev_ctx = context.template device_context<CUDADeviceContext>();
    Functor functor(dev_ctx);
    functor(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivation"), &out);
  }
};

// Op kernel that runs a backward activation `Functor` (a functor deriving
// from CudnnActivationGradFunctor) to compute X@GRAD from X, Out and Out@GRAD.
template <typename Functor>
class CudnnActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& context) const override {
    // This kernel only supports gradient functors that depend on Out.
    static_assert(Functor::FwdDeps() == ActBwdOpFwdDeps::kDepOut,
                  "Forward deps must be Out.");

    const framework::Tensor* X = nullptr;
    const framework::Tensor* Out = nullptr;
    const framework::Tensor* dOut = nullptr;
    framework::Tensor* dX = nullptr;
    ExtractActivationGradTensor<Functor::FwdDeps()>(
        context, &X, &Out, &dOut, &dX);

    // Null-check dX before its first dereference, so a missing gradient
    // output raises an enforce error instead of crashing inside mutable_data.
    framework::Tensor& dx =
        GET_DATA_SAFELY(dX, "Output", "X@GRAD", "CudnnActivationGrad");
    // Allocate dX before any input is read.
    // NOTE(review): with kDepOut, X may alias dX as an in-place placeholder;
    // confirm against ExtractActivationGradTensor before reordering.
    dx.mutable_data<T>(context.GetPlace());

    auto& dev_ctx = context.template device_context<CUDADeviceContext>();
    Functor functor(dev_ctx);
    functor(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivationGrad"),
            GET_DATA_SAFELY(Out, "Input", "Out", "CudnnActivationGrad"),
            GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "CudnnActivationGrad"),
            &dx);
  }
};

}  // namespace operators
}  // namespace paddle

// Namespace aliases used by the kernel-registration macros below.
namespace plat = paddle::platform;
namespace ops = paddle::operators;

// X-macro listing every activation that has a cuDNN/MIOpen kernel. Each
// entry is (op type, forward functor, backward functor). The last entry has
// no trailing ';' because the call site supplies it.
#define FOR_EACH_CUDNN_OP_FUNCTOR(OP_MACRO)                        \
  OP_MACRO(relu, CudnnReluFunctor, CudnnReluGradFunctor);          \
  OP_MACRO(relu6, CudnnRelu6Functor, CudnnRelu6GradFunctor);       \
  OP_MACRO(sigmoid, CudnnSigmoidFunctor, CudnnSigmoidGradFunctor); \
  OP_MACRO(tanh, CudnnTanhFunctor, CudnnTanhGradFunctor)

// Registers the CUDNN-library forward kernel `act_type` and its backward
// kernel `act_type_grad` on CUDAPlace. HIP/MIOpen builds register float only;
// CUDA/cuDNN builds register both float and double.
#ifdef PADDLE_WITH_HIP
#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace,                    \
                     ops::CudnnActivationKernel<ops::functor<float>>);    \
  REGISTER_OP_KERNEL(                                                     \
      act_type##_grad, CUDNN, plat::CUDAPlace,                            \
      ops::CudnnActivationGradKernel<ops::grad_functor<float>>);
#else
#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace,                    \
                     ops::CudnnActivationKernel<ops::functor<float>>,     \
                     ops::CudnnActivationKernel<ops::functor<double>>);   \
  REGISTER_OP_KERNEL(                                                     \
      act_type##_grad, CUDNN, plat::CUDAPlace,                            \
      ops::CudnnActivationGradKernel<ops::grad_functor<float>>,           \
      ops::CudnnActivationGradKernel<ops::grad_functor<double>>);
#endif

FOR_EACH_CUDNN_OP_FUNCTOR(REGISTER_ACTIVATION_CUDNN_KERNEL);