// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_desc.h"
#else
#include "paddle/fluid/platform/cudnn_desc.h"
#endif

// Forward declaration of CUDAPlace; the kernel registrations at the bottom
// of this file refer to it as plat::CUDAPlace.
namespace paddle {
namespace platform {
struct CUDAPlace;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {
// Shorthands for the framework/platform types used throughout this file.
using framework::Tensor;
using platform::ActivationDescriptor;
using platform::TensorDescriptor;
using platform::CUDADeviceContext;

// Map backend-neutral GPUDNN_ACTIVATION_* names onto the backend-specific
// activation-mode enumerators: MIOpen under PADDLE_WITH_HIP, cuDNN otherwise.
#ifdef PADDLE_WITH_HIP
#define GPUDNN_ACTIVATION_RELU miopenActivationRELU
#define GPUDNN_ACTIVATION_CLIPPED_RELU miopenActivationCLIPPEDRELU
#define GPUDNN_ACTIVATION_SIGMOID miopenActivationLOGISTIC
#define GPUDNN_ACTIVATION_TANH miopenActivationTANH
#else
#define GPUDNN_ACTIVATION_RELU CUDNN_ACTIVATION_RELU
#define GPUDNN_ACTIVATION_CLIPPED_RELU CUDNN_ACTIVATION_CLIPPED_RELU
#define GPUDNN_ACTIVATION_SIGMOID CUDNN_ACTIVATION_SIGMOID
#define GPUDNN_ACTIVATION_TANH CUDNN_ACTIVATION_TANH
#endif

// Forward activation functor backed by cuDNN/MIOpen: writes act(x) into *out.
template <typename T>
struct CudnnActivationFunctor {
  using ELEMENT_TYPE = T;
#ifdef PADDLE_WITH_HIP
  CudnnActivationFunctor(const CUDADeviceContext& ctx, const T& c,
                         const miopenActivationMode_t& m)
      : ctx_(ctx), coef_(c), mode_(m) {}
#else
  CudnnActivationFunctor(const CUDADeviceContext& ctx, const T& c,
                         const cudnnActivationMode_t& m)
      : ctx_(ctx), coef_(c), mode_(m) {}
#endif
  // Runs the forward pass; allocates `out` on the functor's device place.
  void operator()(const Tensor& x, Tensor* out) {
    ActivationDescriptor act_desc;
    act_desc.set(mode_, coef_);
    TensorDescriptor x_desc;
    TensorDescriptor out_desc;
    x_desc.set(x);
    out_desc.set(GET_DATA_SAFELY(out, "Output", "Out", "CudnnActivation"));
    // Allocate the output buffer before handing it to the backend call.
    auto* out_data = out->mutable_data<T>(ctx_.GetPlace());
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenActivationForward(
        ctx_.cudnn_handle(), act_desc.desc(),
        platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
        platform::CudnnDataType<T>::kZero(), out_desc.desc(), out_data));
#else
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnActivationForward(
        ctx_.cudnn_handle(), act_desc.desc(),
        platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
        platform::CudnnDataType<T>::kZero(), out_desc.desc(), out_data));
#endif
  }
  const CUDADeviceContext& ctx_;  // supplies the cudnn handle and place
  const T coef_;                  // mode coefficient (e.g. 6.0 clip for relu6)
#ifdef PADDLE_WITH_HIP
  const miopenActivationMode_t mode_;
#else
  const cudnnActivationMode_t mode_;
#endif
};

// Backward activation functor backed by cuDNN/MIOpen: given the forward
// input x, the forward output out, and the upstream gradient dout, writes
// the input gradient into *dx.
template <typename T>
struct CudnnActivationGradFunctor {
  using ELEMENT_TYPE = T;
#ifdef PADDLE_WITH_HIP
  CudnnActivationGradFunctor(const CUDADeviceContext& ctx, const T& c,
                             const miopenActivationMode_t& m)
      : ctx_(ctx), coef_(c), mode_(m) {}
#else
  CudnnActivationGradFunctor(const CUDADeviceContext& ctx, const T& c,
                             const cudnnActivationMode_t& m)
      : ctx_(ctx), coef_(c), mode_(m) {}
#endif
  // FIX: `dout` used to be taken by value (`const Tensor dout`), copying the
  // Tensor object on every call; it is read-only here, so pass it by const
  // reference like the other inputs.
  void operator()(const Tensor& x, const Tensor& out, const Tensor& dout,
                  Tensor* dx) {
    ActivationDescriptor act_desc;
    act_desc.set(mode_, coef_);
    TensorDescriptor x_desc, out_desc, dout_desc, dx_desc;
    x_desc.set(x);
    out_desc.set(out);
    dout_desc.set(dout);
    dx_desc.set(GET_DATA_SAFELY(dx, "Output", "X@GRAD", "CudnnActivationGrad"));
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenActivationBackward(
        ctx_.cudnn_handle(), act_desc.desc(),
        platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
        dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(),
        platform::CudnnDataType<T>::kZero(), dx_desc.desc(),
        dx->mutable_data<T>(ctx_.GetPlace())));
#else
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnActivationBackward(
        ctx_.cudnn_handle(), act_desc.desc(),
        platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
        dout_desc.desc(), dout.data<T>(), x_desc.desc(), x.data<T>(),
        platform::CudnnDataType<T>::kZero(), dx_desc.desc(),
        dx->mutable_data<T>(ctx_.GetPlace())));
#endif
  }
  const CUDADeviceContext& ctx_;  // supplies the cudnn handle and place
  const T coef_;                  // mode coefficient (e.g. 6.0 clip for relu6)
#ifdef PADDLE_WITH_HIP
  const miopenActivationMode_t mode_;
#else
  const cudnnActivationMode_t mode_;
#endif
};

// relu forward; the coefficient is unused for this mode, so 0.0 is passed.
template <typename T>
struct CudnnReluFunctor : public CudnnActivationFunctor<T> {
  explicit CudnnReluFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationFunctor<T>(ctx, /*coef=*/0.0, GPUDNN_ACTIVATION_RELU) {}
};
// relu backward.
template <typename T>
struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
  explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(ctx, /*coef=*/0.0,
                                      GPUDNN_ACTIVATION_RELU) {}

  // The grad kernel feeds the forward *output* (Out) to this functor.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// relu6 forward: clipped relu with the clipping threshold passed as coef.
template <typename T>
struct CudnnRelu6Functor : public CudnnActivationFunctor<T> {
  explicit CudnnRelu6Functor(const CUDADeviceContext& ctx)
      : CudnnActivationFunctor<T>(ctx, /*coef=*/6.0,
                                  GPUDNN_ACTIVATION_CLIPPED_RELU) {}
};
// relu6 backward; uses the same 6.0 clipping threshold as the forward pass.
template <typename T>
struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
  explicit CudnnRelu6GradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(ctx, /*coef=*/6.0,
                                      GPUDNN_ACTIVATION_CLIPPED_RELU) {}

  // The grad kernel feeds the forward *output* (Out) to this functor.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// sigmoid forward; the coefficient is unused for this mode, so 0.0 is passed.
template <typename T>
struct CudnnSigmoidFunctor : public CudnnActivationFunctor<T> {
  explicit CudnnSigmoidFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationFunctor<T>(ctx, /*coef=*/0.0,
                                  GPUDNN_ACTIVATION_SIGMOID) {}
};
// sigmoid backward.
template <typename T>
struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
  explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(ctx, /*coef=*/0.0,
                                      GPUDNN_ACTIVATION_SIGMOID) {}

  // The grad kernel feeds the forward *output* (Out) to this functor.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// tanh forward; the coefficient is unused for this mode, so 0.0 is passed.
template <typename T>
struct CudnnTanhFunctor : public CudnnActivationFunctor<T> {
  explicit CudnnTanhFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationFunctor<T>(ctx, /*coef=*/0.0, GPUDNN_ACTIVATION_TANH) {}
};
// tanh backward.
template <typename T>
struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor<T> {
  explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(ctx, /*coef=*/0.0,
                                      GPUDNN_ACTIVATION_TANH) {}

  // The grad kernel feeds the forward *output* (Out) to this functor.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

// Forward op kernel: pulls X/Out from the execution context, allocates the
// output, and runs the cuDNN/MIOpen activation functor.
template <typename Functor>
class CudnnActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& context) const override {
    const framework::Tensor* x = nullptr;
    framework::Tensor* out = nullptr;
    ExtractActivationTensor(context, &x, &out);
    // Allocate the output buffer before handing it to the functor.
    out->mutable_data<T>(context.GetPlace());
    auto& dev_ctx = context.template device_context<CUDADeviceContext>();
    Functor activation(dev_ctx);
    activation(GET_DATA_SAFELY(x, "Input", "X", "CudnnActivation"), out);
  }
};

// Backward op kernel: only functors whose forward dependency is Out are
// accepted (enforced at compile time via the static_assert below).
template <typename Functor>
class CudnnActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;

  void Compute(const framework::ExecutionContext& context) const override {
    static_assert(Functor::FwdDeps() == kDepOut, "Forward deps must be Out.");

    const framework::Tensor* x = nullptr;
    const framework::Tensor* out = nullptr;
    const framework::Tensor* dout = nullptr;
    framework::Tensor* dx = nullptr;
    ExtractActivationGradTensor<Functor::FwdDeps()>(context, &x, &out, &dout,
                                                    &dx);
    // Allocate the input-gradient buffer before handing it to the functor.
    dx->mutable_data<T>(context.GetPlace());
    auto& dev_ctx = context.template device_context<CUDADeviceContext>();
    Functor activation(dev_ctx);
    activation(
        GET_DATA_SAFELY(x, "Input", "X", "CudnnActivationGrad"),
        GET_DATA_SAFELY(out, "Input", "Out", "CudnnActivationGrad"),
        GET_DATA_SAFELY(dout, "Input", "Out@GRAD", "CudnnActivationGrad"), dx);
  }
};

}  // namespace operators
}  // namespace paddle

namespace plat = paddle::platform;
namespace ops = paddle::operators;

// Expands `__macro` once per (op name, forward functor, grad functor) triple
// that has a cuDNN/MIOpen implementation in this file.
#define FOR_EACH_CUDNN_OP_FUNCTOR(__macro)                        \
  __macro(relu, CudnnReluFunctor, CudnnReluGradFunctor);          \
  __macro(relu6, CudnnRelu6Functor, CudnnRelu6GradFunctor);       \
  __macro(sigmoid, CudnnSigmoidFunctor, CudnnSigmoidGradFunctor); \
  __macro(tanh, CudnnTanhFunctor, CudnnTanhGradFunctor)

#ifdef PADDLE_WITH_HIP
// HIP/MIOpen build: register float kernels only.
#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace,                    \
                     ops::CudnnActivationKernel<ops::functor<float>>);    \
  REGISTER_OP_KERNEL(                                                     \
      act_type##_grad, CUDNN, plat::CUDAPlace,                            \
      ops::CudnnActivationGradKernel<ops::grad_functor<float>>);
#else
// CUDA/cuDNN build: register both float and double kernels.
#define REGISTER_ACTIVATION_CUDNN_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_KERNEL(act_type, CUDNN, plat::CUDAPlace,                    \
                     ops::CudnnActivationKernel<ops::functor<float>>,     \
                     ops::CudnnActivationKernel<ops::functor<double>>);   \
  REGISTER_OP_KERNEL(                                                     \
      act_type##_grad, CUDNN, plat::CUDAPlace,                            \
      ops::CudnnActivationGradKernel<ops::grad_functor<float>>,           \
      ops::CudnnActivationGradKernel<ops::grad_functor<double>>);
#endif

FOR_EACH_CUDNN_OP_FUNCTOR(REGISTER_ACTIVATION_CUDNN_KERNEL);