/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
namespace dynload = platform::dynload;

template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;

#if CUDNN_VERSION >= 8000

// Round up a to the nearest multiple of b, e.g. RoundUp(700, 512) == 1024.
static size_t RoundUp(int64_t a, int64_t b) { return (a + b - 1) / b * b; }

// Holds the shapes and cuDNN descriptors shared by the fused forward and
// backward norm-convolution ops below.
template <typename T>
struct NormConvolutionArgs {
  NormConvolutionArgs() {
    dtype = platform::CudnnDataType<T>::type;
    format = CUDNN_TENSOR_NHWC;
    compute_type = platform::CudnnDataType<float>::type;
  }

  void Set(const std::vector<int> &input_shape,
           const std::vector<int> &filter_shape,
           const std::vector<int> &output_shape, int padding, int stride,
           int dilation, int group) {
    PADDLE_ENFORCE_EQ(
        input_shape.size(), 4U,
        platform::errors::InvalidArgument(
            "The size of input_shape is expected to be 4. But received "
            "input_shape's size is %d, input_shape is [%s].",
            input_shape.size(), framework::make_ddim(input_shape)));
    PADDLE_ENFORCE_EQ(
        filter_shape.size(), 4U,
        platform::errors::InvalidArgument(
            "The size of filter_shape is expected to be 4. But received "
            "filter_shape's size is %d, filter_shape is [%s].",
            filter_shape.size(), framework::make_ddim(filter_shape)));
    PADDLE_ENFORCE_EQ(filter_shape[1] == filter_shape[2] &&
                          (filter_shape[1] == 1 || filter_shape[1] == 3),
                      true,
                      platform::errors::InvalidArgument(
                          "The filter_shape is expected to be stored as nhwc, "
                          "with h = w = 1 or 3. But received filter_shape is "
                          "[%s].",
                          framework::make_ddim(filter_shape)));
    PADDLE_ENFORCE_EQ(
        output_shape.size(), 4U,
        platform::errors::InvalidArgument(
            "The size of output_shape is expected to be 4. But received "
            "output_shape's size is %d, output_shape is [%s].",
            output_shape.size(), framework::make_ddim(output_shape)));

    for (size_t i = 0; i < input_shape.size(); ++i) {
      in_dims.push_back(input_shape[i]);
    }
    for (size_t i = 0; i < filter_shape.size(); ++i) {
      filter_dims.push_back(filter_shape[i]);
    }
    paddings = {padding, padding};
    strides = {stride, stride};
    dilations = {dilation, dilation};

    in_desc.set(input_shape, format, dtype);
    filter_desc.set(filter_shape, format, dtype, group);
    out_desc.set(output_shape, format, dtype);

    int output_channel = filter_shape[0];
    std::vector<int> stats_shape = {1, 1, 1, output_channel};
    out_stats_desc.set(stats_shape, format, compute_type);

    conv_desc.set(dtype, paddings, strides, dilations, false, group);
  }

  cudnnDataType_t dtype;
  cudnnTensorFormat_t format;
  cudnnDataType_t compute_type;

  std::vector<int64_t> in_dims;
  std::vector<int64_t> filter_dims;
  std::vector<int> strides;
  std::vector<int> paddings;
  std::vector<int> dilations;

  platform::TensorDescriptor in_desc;
  platform::FilterDescriptor filter_desc;
  platform::TensorDescriptor out_desc;
  platform::TensorDescriptor out_stats_desc;
  platform::ConvolutionDescriptor conv_desc;
};
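
// For example (hypothetical values), a 1x1 convolution over a 32x56x56x64
// NHWC input producing 64 output channels would be described as:
//
//   NormConvolutionArgs<platform::float16> args;
//   args.Set(/*input_shape=*/{32, 56, 56, 64},
//            /*filter_shape=*/{64, 1, 1, 64},
//            /*output_shape=*/{32, 56, 56, 64},
//            /*padding=*/0, /*stride=*/1, /*dilation=*/1, /*group=*/1);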

// Fused NHWC convolution (cuDNN 8.0+) that also produces the per-channel sum
// and sum of squares of the output, as required by a following batch norm.
template <typename T>
class CudnnNormConvolution {
 public:
  CudnnNormConvolution(const platform::CUDADeviceContext &ctx,
                       const std::vector<int> &input_shape,
                       const std::vector<int> &filter_shape,
                       const std::vector<int> &output_shape, const int &padding,
                       const int &stride, const int &dilation,
                       const int &group) {
    args_.Set(input_shape, filter_shape, output_shape, padding, stride,
              dilation, group);
  }
  ~CudnnNormConvolution() {}

  void Forward(const platform::CUDADeviceContext &ctx, T *input_ptr,
               T *filter_ptr, T *output_ptr, float *sum_ptr,
               float *sum_of_squares_ptr) {
    auto cudnn_handle = ctx.cudnn_handle();

    CudnnFusionOp *fwd_op = GetForwardOp(ctx);
    size_t workspace_size = RoundUp(
        static_cast<int64_t>(fwd_op->GetWorkspaceSizeInBytes(cudnn_handle)),
        512);

    // Set variant_param
    // input ptr
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WDATA, filter_ptr);
    fwd_op->SetOpVariantParamAttrPtr(
        CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);

    // output ptr
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, output_ptr);
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);

    ctx.cudnn_workspace_handle().RunFunc(
        [&](void *workspace_ptr) {
          // workspace ptr
          fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr);
          // fused op execute
          fwd_op->Execute(cudnn_handle);
        },
        workspace_size);
  }

 private:
  // Builds the fused CONV + BNSTATS op plan, or fetches a previously built
  // plan from the global forward cache keyed by shapes, conv params, and
  // dtype.
  CudnnFusionOp *GetForwardOp(const platform::CUDADeviceContext &ctx) {
    framework::AlgorithmsCache<CudnnFusionOp *> &cache =
        *(CudnnFusionOpCache::Instance().GetForward());

    CudnnFusionOp *fwd_op = cache.GetAlgorithm(
        args_.in_dims, args_.filter_dims, args_.strides, args_.paddings,
        args_.dilations, 0, static_cast<int64_t>(args_.dtype), [&]() {
          CudnnFusionOp *fwd_op =
              new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS);

          // Set constant_param
          fwd_op->SetOpConstParamAttr(
              {CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_WDATA_PLACEHOLDER,
               CUDNN_PARAM_YDATA_PLACEHOLDER},
              CUDNN_PTR_16B_ALIGNED);
          fwd_op->SetOpConstParamAttr(
              {CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER},
              CUDNN_PTR_16B_ALIGNED);

          // conv desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC,
                                      args_.conv_desc.desc());
          // input desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
          // filter desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_WDESC,
                                      args_.filter_desc.desc());
          // output desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_YDESC, args_.out_desc.desc());
          // output_stats desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_YSTATS_DESC,
                                      args_.out_stats_desc.desc());
          // batch_norm mode
          fwd_op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
                                      CUDNN_BATCHNORM_SPATIAL_PERSISTENT);

          // Make cudnn fused ops plan
          fwd_op->GetWorkspaceSizeInBytes(ctx.cudnn_handle());
          return fwd_op;
        });
    return fwd_op;
  }

 private:
  NormConvolutionArgs<T> args_;
};
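
// A minimal usage sketch, kept as a comment; the shapes, device pointers, and
// context below are hypothetical and only illustrate the call sequence:
//
//   std::vector<int> in_shape = {n, h, w, c};      // NHWC layout
//   std::vector<int> filter_shape = {k, 3, 3, c};  // k filters of size 3x3
//   std::vector<int> out_shape = {n, h, w, k};
//   CudnnNormConvolution<platform::float16> conv_op(
//       ctx, in_shape, filter_shape, out_shape,
//       /*padding=*/1, /*stride=*/1, /*dilation=*/1, /*group=*/1);
//   conv_op.Forward(ctx, input_ptr, filter_ptr, output_ptr, sum_ptr,
//                   sum_of_squares_ptr);
//
// Forward writes the convolution result to output_ptr, and the per-channel
// sum and sum of squares (consumed by the following batch norm) to sum_ptr
// and sum_of_squares_ptr.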

// Gradients for CudnnNormConvolution: a fused wgrad op computes the filter
// gradient, while plain cuDNN dgrad computes the input gradient.
template <typename T>
class CudnnNormConvolutionGrad {
 public:
  CudnnNormConvolutionGrad(const platform::CUDADeviceContext &ctx,
                           const std::vector<int> &input_shape,
                           const std::vector<int> &filter_shape,
                           const std::vector<int> &output_shape,
                           const int &padding, const int &stride,
                           const int &dilation, const int &group) {
    args_.Set(input_shape, filter_shape, output_shape, padding, stride,
              dilation, group);
    dgrad_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
  }
  ~CudnnNormConvolutionGrad() {}

  void Backward(const platform::CUDADeviceContext &ctx, T *input_ptr,
                T *output_grad_ptr, T *filter_ptr, T *input_grad_ptr,
                T *filter_grad_ptr, bool use_addto = false) {
    if (filter_grad_ptr) {
      BackwardFilter(ctx, input_ptr, output_grad_ptr, filter_ptr,
                     filter_grad_ptr);
    }
    if (input_grad_ptr) {
      BackwardData(ctx, input_ptr, output_grad_ptr, filter_ptr, input_grad_ptr,
                   use_addto);
    }
  }

 private:
  void BackwardFilter(const platform::CUDADeviceContext &ctx, T *input_ptr,
                      T *output_grad_ptr, T *filter_ptr, T *filter_grad_ptr) {
    auto cudnn_handle = ctx.cudnn_handle();

    CudnnFusionOp *wgrad_op = GetBackwardFilterOp(ctx);
    size_t workspace_size = RoundUp(
        static_cast<int64_t>(wgrad_op->GetWorkspaceSizeInBytes(cudnn_handle)),
        512);

    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, output_grad_ptr);
    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_DWDATA, filter_grad_ptr);
    wgrad_op->SetOpVariantParamAttrPtr(
        CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);

    ctx.cudnn_workspace_handle().RunFunc(
        [&](void *workspace_ptr) {
          // workspace ptr
          wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
                                             workspace_ptr);
          // fused op execute
          wgrad_op->Execute(cudnn_handle);
        },
        workspace_size);
  }

  void BackwardData(const platform::CUDADeviceContext &ctx, T *input_ptr,
                    T *output_grad_ptr, T *filter_ptr, T *input_grad_ptr,
                    bool use_addto = false) {
    auto cudnn_handle = ctx.cudnn_handle();
    size_t workspace_size = GetWorkspaceSizeBwdData(ctx);

    // Convolution backward data (dgrad). With use_addto, the gradient is
    // accumulated into input_grad_ptr (beta = 1) instead of overwriting it.
    ScalingParamType<T> alpha = 1.0f;
    ScalingParamType<T> beta = use_addto ? 1.0f : 0.0f;
    ctx.cudnn_workspace_handle().RunFunc(
        [&](void *cudnn_workspace_ptr) {
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnConvolutionBackwardData(
                  cudnn_handle, &alpha, args_.filter_desc.desc(), filter_ptr,
                  args_.out_desc.desc(), output_grad_ptr,
                  args_.conv_desc.desc(), dgrad_algo_, cudnn_workspace_ptr,
                  workspace_size, &beta, args_.in_desc.desc(), input_grad_ptr));
        },
        workspace_size);
  }

  // Builds the fused wgrad op plan, or fetches it from the global backward
  // cache keyed by shapes, conv params, and dtype.
  CudnnFusionOp *GetBackwardFilterOp(const platform::CUDADeviceContext &ctx) {
    framework::AlgorithmsCache<CudnnFusionOp *> &cache =
        *(CudnnFusionOpCache::Instance().GetBackward());

    CudnnFusionOp *wgrad_op = cache.GetAlgorithm(
        args_.in_dims, args_.filter_dims, args_.strides, args_.paddings,
        args_.dilations, 0, static_cast<int64_t>(args_.dtype), [&]() {
          CudnnFusionOp *wgrad_op =
              new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD);

          wgrad_op->SetOpConstParamAttr(
              {CUDNN_PARAM_DYDATA_PLACEHOLDER, CUDNN_PARAM_XDATA_PLACEHOLDER,
               CUDNN_PARAM_DWDATA_PLACEHOLDER},
              CUDNN_PTR_16B_ALIGNED);

          // conv desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC,
                                        args_.conv_desc.desc());
          // input desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_XDESC,
                                        args_.in_desc.desc());
          // filter desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_DWDESC,
                                        args_.filter_desc.desc());
          // output desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_DYDESC,
                                        args_.out_desc.desc());
          wgrad_op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
                                        CUDNN_BATCHNORM_SPATIAL_PERSISTENT);

          // Make cudnn fused ops plan
          wgrad_op->GetWorkspaceSizeInBytes(ctx.cudnn_handle());
          return wgrad_op;
        });
    return wgrad_op;
  }

  // Queries cuDNN for the dgrad workspace size, rounded up to a multiple of
  // 512 bytes.
  size_t GetWorkspaceSizeBwdData(const platform::CUDADeviceContext &ctx) {
    size_t workspace_size = 0U;
    auto handle = ctx.cudnn_handle();
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
            handle, args_.filter_desc.desc(), args_.out_desc.desc(),
            args_.conv_desc.desc(), args_.in_desc.desc(), dgrad_algo_,
            &workspace_size));
    return RoundUp(workspace_size, 512);
  }

 private:
  NormConvolutionArgs<T> args_;
  cudnnConvolutionBwdDataAlgo_t dgrad_algo_;
};
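
// A minimal usage sketch for the backward pass (hypothetical pointers and
// context, mirroring the forward example above):
//
//   CudnnNormConvolutionGrad<platform::float16> conv_grad_op(
//       ctx, in_shape, filter_shape, out_shape,
//       /*padding=*/1, /*stride=*/1, /*dilation=*/1, /*group=*/1);
//   conv_grad_op.Backward(ctx, input_ptr, output_grad_ptr, filter_ptr,
//                         input_grad_ptr, filter_grad_ptr,
//                         /*use_addto=*/false);
//
// Passing nullptr for input_grad_ptr or filter_grad_ptr skips that gradient;
// use_addto = true accumulates the data gradient into input_grad_ptr instead
// of overwriting it.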

#endif
}  // namespace operators
}  // namespace paddle