activation_cudnn_op.cu.cc 9.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
17
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
18

W
wanghuancoder 已提交
19 20 21 22 23 24
// Forward declaration of paddle::platform::CUDAPlace.
// NOTE(review): the kernel-registration macros at the bottom of this file
// name plat::CUDAPlace; the complete type presumably arrives through the
// included headers, so this declaration looks redundant — confirm before
// removing it.
namespace paddle {
namespace platform {
struct CUDAPlace;
}  // namespace platform
}  // namespace paddle

25 26 27 28 29 30 31
namespace paddle {
namespace operators {
// Short local names for the framework/platform types used by the functors
// and kernels in this file.
using framework::Tensor;
using platform::ActivationDescriptor;
using platform::CUDADeviceContext;
using platform::TensorDescriptor;

32 33 34 35 36 37 38 39 40 41 42 43
// Backend-neutral names for the activation modes used below: cuDNN enum
// values on CUDA builds, MIOpen enum values on HIP builds. MIOpen's name for
// the sigmoid mode is LOGISTIC.
#ifndef PADDLE_WITH_HIP
#define GPUDNN_ACTIVATION_RELU CUDNN_ACTIVATION_RELU
#define GPUDNN_ACTIVATION_CLIPPED_RELU CUDNN_ACTIVATION_CLIPPED_RELU
#define GPUDNN_ACTIVATION_SIGMOID CUDNN_ACTIVATION_SIGMOID
#define GPUDNN_ACTIVATION_TANH CUDNN_ACTIVATION_TANH
#else
#define GPUDNN_ACTIVATION_RELU miopenActivationRELU
#define GPUDNN_ACTIVATION_CLIPPED_RELU miopenActivationCLIPPEDRELU
#define GPUDNN_ACTIVATION_SIGMOID miopenActivationLOGISTIC
#define GPUDNN_ACTIVATION_TANH miopenActivationTANH
#endif

44 45 46
template <typename T>
struct CudnnActivationFunctor {
  using ELEMENT_TYPE = T;
47 48 49 50 51
#ifdef PADDLE_WITH_HIP
  CudnnActivationFunctor(const CUDADeviceContext& ctx, const T& c,
                         const miopenActivationMode_t& m)
      : ctx_(ctx), coef_(c), mode_(m) {}
#else
52 53 54
  CudnnActivationFunctor(const CUDADeviceContext& ctx, const T& c,
                         const cudnnActivationMode_t& m)
      : ctx_(ctx), coef_(c), mode_(m) {}
55
#endif
56 57 58 59 60
  void operator()(const Tensor& x, Tensor* out) {
    ActivationDescriptor act_desc;
    act_desc.set(mode_, coef_);
    TensorDescriptor x_desc, out_desc;
    x_desc.set(x);
61
    out_desc.set(GET_DATA_SAFELY(out, "Output", "Out", "CudnnActivation"));
62
#ifdef PADDLE_WITH_HIP
63
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenActivationForward(
64 65 66 67 68
        ctx_.cudnn_handle(), act_desc.desc(),
        platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
        platform::CudnnDataType<T>::kZero(), out_desc.desc(),
        out->mutable_data<T>(ctx_.GetPlace())));
#else
69
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnActivationForward(
70 71 72 73
        ctx_.cudnn_handle(), act_desc.desc(),
        platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
        platform::CudnnDataType<T>::kZero(), out_desc.desc(),
        out->mutable_data<T>(ctx_.GetPlace())));
74
#endif
75 76 77
  }
  const CUDADeviceContext& ctx_;
  const T coef_;
78 79 80
#ifdef PADDLE_WITH_HIP
  const miopenActivationMode_t mode_;
#else
81
  const cudnnActivationMode_t mode_;
82
#endif
83 84 85 86 87
};

template <typename T>
struct CudnnActivationGradFunctor {
  using ELEMENT_TYPE = T;
88 89 90 91 92
#ifdef PADDLE_WITH_HIP
  CudnnActivationGradFunctor(const CUDADeviceContext& ctx, const T& c,
                             const miopenActivationMode_t& m)
      : ctx_(ctx), coef_(c), mode_(m) {}
#else
93 94 95
  CudnnActivationGradFunctor(const CUDADeviceContext& ctx, const T& c,
                             const cudnnActivationMode_t& m)
      : ctx_(ctx), coef_(c), mode_(m) {}
96
#endif
97 98 99 100 101 102 103 104
  void operator()(const Tensor& x, const Tensor& out, const Tensor dout,
                  Tensor* dx) {
    ActivationDescriptor act_desc;
    act_desc.set(mode_, coef_);
    TensorDescriptor x_desc, out_desc, dout_desc, dx_desc;
    x_desc.set(x);
    out_desc.set(out);
    dout_desc.set(dout);
105
    dx_desc.set(GET_DATA_SAFELY(dx, "Output", "X@GRAD", "CudnnActivationGrad"));
106
#ifdef PADDLE_WITH_HIP
107
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenActivationBackward(
108 109 110 111 112 113
        ctx_.cudnn_handle(), act_desc.desc(),
        platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
        dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(),
        platform::CudnnDataType<T>::kZero(), dx_desc.desc(),
        dx->mutable_data<T>(ctx_.GetPlace())));
#else
114
    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnActivationBackward(
115 116 117 118 119
        ctx_.cudnn_handle(), act_desc.desc(),
        platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
        dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(),
        platform::CudnnDataType<T>::kZero(), dx_desc.desc(),
        dx->mutable_data<T>(ctx_.GetPlace())));
120
#endif
121 122 123
  }
  const CUDADeviceContext& ctx_;
  const T coef_;
124 125 126
#ifdef PADDLE_WITH_HIP
  const miopenActivationMode_t mode_;
#else
127
  const cudnnActivationMode_t mode_;
128
#endif
129 130 131 132 133
};

// ReLU forward: the backend's plain ReLU mode with a zero coefficient.
template <typename T>
struct CudnnReluFunctor : CudnnActivationFunctor<T> {
  explicit CudnnReluFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationFunctor<T>(ctx, static_cast<T>(0.0),
                                  GPUDNN_ACTIVATION_RELU) {}
};

// ReLU backward: same mode and coefficient. It declares a dependency on the
// forward output only (kDepOut).
template <typename T>
struct CudnnReluGradFunctor : CudnnActivationGradFunctor<T> {
  explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(ctx, static_cast<T>(0.0),
                                      GPUDNN_ACTIVATION_RELU) {}

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// ReLU6 forward: the backend's clipped-ReLU mode with a fixed coefficient
// of 6 (presumably the clipping ceiling, per the cuDNN/MIOpen clipped-ReLU
// definition).
// NOTE(review): these constructors receive only the device context, so any
// per-op threshold attribute is not consulted by this kernel. Confirm that
// relu6's threshold is always 6 whenever the CUDNN kernel is selected.
template <typename T>
struct CudnnRelu6Functor : CudnnActivationFunctor<T> {
  explicit CudnnRelu6Functor(const CUDADeviceContext& ctx)
      : CudnnActivationFunctor<T>(ctx, static_cast<T>(6.0),
                                  GPUDNN_ACTIVATION_CLIPPED_RELU) {}
};

// ReLU6 backward: same mode and coefficient. It declares a dependency on the
// forward output only (kDepOut).
template <typename T>
struct CudnnRelu6GradFunctor : CudnnActivationGradFunctor<T> {
  explicit CudnnRelu6GradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(ctx, static_cast<T>(6.0),
                                      GPUDNN_ACTIVATION_CLIPPED_RELU) {}

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// Sigmoid forward: the backend's sigmoid mode (LOGISTIC on MIOpen) with a
// zero coefficient.
template <typename T>
struct CudnnSigmoidFunctor : CudnnActivationFunctor<T> {
  explicit CudnnSigmoidFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationFunctor<T>(ctx, static_cast<T>(0.0),
                                  GPUDNN_ACTIVATION_SIGMOID) {}
};

// Sigmoid backward: same mode and coefficient. It declares a dependency on
// the forward output only (kDepOut).
template <typename T>
struct CudnnSigmoidGradFunctor : CudnnActivationGradFunctor<T> {
  explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(ctx, static_cast<T>(0.0),
                                      GPUDNN_ACTIVATION_SIGMOID) {}

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// Tanh forward: the backend's tanh mode with a zero coefficient.
template <typename T>
struct CudnnTanhFunctor : CudnnActivationFunctor<T> {
  explicit CudnnTanhFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationFunctor<T>(ctx, static_cast<T>(0.0),
                                  GPUDNN_ACTIVATION_TANH) {}
};

// Tanh backward: same mode and coefficient. It declares a dependency on the
// forward output only (kDepOut).
template <typename T>
struct CudnnTanhGradFunctor : CudnnActivationGradFunctor<T> {
  explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(ctx, static_cast<T>(0.0),
                                      GPUDNN_ACTIVATION_TANH) {}

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// Forward op kernel that runs Functor (one of the Cudnn*Functor types above)
// on the op's X input and writes the Out output.
template <typename Functor>
class CudnnActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& context) const override {
    const framework::Tensor* X = nullptr;
    framework::Tensor* Out = nullptr;
    ExtractActivationTensor(context, &X, &Out);
    // Validate Out before its first dereference. Otherwise a missing output
    // would crash on the mutable_data call below instead of being reported
    // through GET_DATA_SAFELY (the same check the functor performs later).
    static_cast<void>(
        GET_DATA_SAFELY(Out, "Output", "Out", "CudnnActivation"));
    Out->mutable_data<T>(context.GetPlace());
    auto& dev_ctx = context.template device_context<CUDADeviceContext>();
    Functor functor(dev_ctx);
    functor(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivation"), Out);
  }
};

// Backward op kernel that runs the grad Functor to produce X@GRAD from X, Out
// and Out@GRAD.
template <typename Functor>
class CudnnActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& context) const override {
    // The grad functors feed the forward output into the backend's backward
    // call, so only functors that declare a dependency on Out are accepted.
    static_assert(Functor::FwdDeps() == kDepOut, "Forward deps must be Out.");

    const framework::Tensor* X = nullptr;
    const framework::Tensor* Out = nullptr;
    const framework::Tensor* dOut = nullptr;
    framework::Tensor* dX = nullptr;
    ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
                                                    &dX);

    // Validate dX before its first dereference, so a missing output is
    // reported through GET_DATA_SAFELY instead of crashing below.
    static_cast<void>(
        GET_DATA_SAFELY(dX, "Output", "X@GRAD", "CudnnActivationGrad"));
    // Allocate dX before the functor runs. NOTE(review): with a kDepOut-only
    // dependency the extraction helper appears to hand back dX in place of X,
    // in which case X must already be allocated when the functor reads it —
    // confirm against ExtractActivationGradTensor.
    dX->mutable_data<T>(context.GetPlace());
    auto& dev_ctx = context.template device_context<CUDADeviceContext>();
    Functor functor(dev_ctx);
    functor(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivationGrad"),
            GET_DATA_SAFELY(Out, "Input", "Out", "CudnnActivationGrad"),
            GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "CudnnActivationGrad"),
            dX);
  }
};

}  // namespace operators
}  // namespace paddle

// Short aliases used by the kernel-registration macros below.
namespace ops = paddle::operators;
namespace plat = paddle::platform;

T
Tao Luo 已提交
229 230 231 232
// Invokes MACRO_(op_name, forward_functor, grad_functor) once for every
// activation that has a CUDNN-library kernel in this file, in this order:
// relu, relu6, sigmoid, tanh. Invocations are separated by ';' and the last
// one is left for the caller to terminate.
#define FOR_EACH_CUDNN_OP_FUNCTOR(MACRO_)                        \
  MACRO_(relu, CudnnReluFunctor, CudnnReluGradFunctor);          \
  MACRO_(relu6, CudnnRelu6Functor, CudnnRelu6GradFunctor);       \
  MACRO_(sigmoid, CudnnSigmoidFunctor, CudnnSigmoidGradFunctor); \
  MACRO_(tanh, CudnnTanhFunctor, CudnnTanhGradFunctor)

235 236 237 238 239 240 241 242
#ifdef PADDLE_WITH_HIP
#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace,                    \
                     ops::CudnnActivationKernel<ops::functor<float>>);    \
  REGISTER_OP_KERNEL(                                                     \
      act_type##_grad, CUDNN, plat::CUDAPlace,                            \
      ops::CudnnActivationGradKernel<ops::grad_functor<float>>);
#else
243 244 245 246 247 248 249 250
#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace,                    \
                     ops::CudnnActivationKernel<ops::functor<float>>,     \
                     ops::CudnnActivationKernel<ops::functor<double>>);   \
  REGISTER_OP_KERNEL(                                                     \
      act_type##_grad, CUDNN, plat::CUDAPlace,                            \
      ops::CudnnActivationGradKernel<ops::grad_functor<float>>,           \
      ops::CudnnActivationGradKernel<ops::grad_functor<double>>);
251
#endif
252 253

FOR_EACH_CUDNN_OP_FUNCTOR(REGISTER_ACTIVATION_CUDNN_KERNEL);