/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
namespace dynload = platform::dynload;

template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;

#if CUDNN_VERSION >= 8000

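// Rounds `a` up to the nearest multiple of `b`; used below to align the
// requested cuDNN workspace sizes to 512 bytes.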
static size_t RoundUp(int64_t a, int64_t b) { return (a + b - 1) / b * b; }

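// Collects the shapes, conv parameters and cuDNN descriptors shared by the
// forward and backward fused norm-conv ops. The tensor layout is fixed to
// NHWC and the per-channel statistics are computed in float regardless of T.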
template <typename T>
struct NormConvolutionArgs {
  NormConvolutionArgs() {
    dtype = platform::CudnnDataType<T>::type;
    format = CUDNN_TENSOR_NHWC;
    compute_type = platform::CudnnDataType<float>::type;
  }

  void Set(const platform::CUDADeviceContext &ctx,
           const std::vector<int> &input_shape,
           const std::vector<int> &filter_shape,
           const std::vector<int> &output_shape, int padding, int stride,
           int dilation, int group) {
    PADDLE_ENFORCE_EQ(
        input_shape.size(), 4U,
        platform::errors::InvalidArgument(
            "The size of input_shape is expected to 4. But recieved "
            "input_shape's size is %d, input_shape is [%s].",
            input_shape.size(), framework::make_ddim(input_shape)));
    PADDLE_ENFORCE_EQ(
        filter_shape.size(), 4U,
        platform::errors::InvalidArgument(
            "The size of filter_shape is expected to 4. But recieved "
            "filter_shape's size is %d, filter_shape is [%s].",
            filter_shape.size(), framework::make_ddim(filter_shape)));
    PADDLE_ENFORCE_EQ(filter_shape[1] == filter_shape[2] &&
                          (filter_shape[1] == 1 || filter_shape[1] == 3),
                      true,
                      platform::errors::InvalidArgument(
                          "The filter_shape is expected to store as nhwc, and "
                          "h = w = 1 or 3. But recieved filter_shape is [%s].",
                          framework::make_ddim(filter_shape)));
    PADDLE_ENFORCE_EQ((filter_shape[0] % 32 == 0 && filter_shape[3] % 8 == 0),
                      true,
                      platform::errors::InvalidArgument(
                          "The input channel is expected to be multiple of 8, "
                          "and the output channel is expected to be multiple "
                          "of 32. But recieved input channel is %d, output "
                          "channel is %d.",
                          filter_shape[3], filter_shape[0]));
    PADDLE_ENFORCE_EQ(
        output_shape.size(), 4U,
        platform::errors::InvalidArgument(
            "The size of output_shape is expected to 4. But recieved "
            "filter_shape's size is %d, filter_shape is [%s].",
            output_shape.size(), framework::make_ddim(output_shape)));
    is_support = IsSupport(ctx, filter_shape, stride, dilation, group);
    PADDLE_ENFORCE_EQ(
        is_support, true,
        platform::errors::InvalidArgument(
            "Current test is only supported in the platforms with "
            "compatiblity greater than or equal to 70 and the kernel size "
            "must be equal to 1 or 3. When the kernel size is 1, "
            "the stride must be 1 if the compatiblity is equal to 70. "
            "Besides, the dilation and group must be equal to 1. But recieved "
            "compatiblity is %d, kernel size is %d, stride is %d, "
            "dilation is %d, group is %d",
            ctx.GetComputeCapability(), filter_shape[1], stride, dilation,
            group));

    for (size_t i = 0; i < input_shape.size(); ++i) {
      in_dims.push_back(input_shape[i]);
    }
    for (size_t i = 0; i < filter_shape.size(); ++i) {
      filter_dims.push_back(filter_shape[i]);
    }
    paddings = {padding, padding};
    strides = {stride, stride};
    dilations = {dilation, dilation};

    in_desc.set(input_shape, format, dtype);
    filter_desc.set(filter_shape, format, dtype, group);
    out_desc.set(output_shape, format, dtype);

    int output_channel = filter_shape[0];
    std::vector<int> stats_shape = {1, 1, 1, output_channel};
    out_stats_desc.set(stats_shape, format, compute_type);

    conv_desc.set(dtype, paddings, strides, dilations, false, group);
  }

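  // The fused kernels are usable only when dilation == 1 and group == 1, the
  // kernel size is 1 or 3, and the device has compute capability >= 70. On
  // compute capability 70, a 1x1 kernel additionally requires stride == 1.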
  bool IsSupport(const platform::CUDADeviceContext &ctx,
                 const std::vector<int> &filter_shape, int stride, int dilation,
                 int group) {
    int kernel_size = filter_shape[1];
    if (dilation != 1 || group != 1) {
      return false;
    }
    if (ctx.GetComputeCapability() == 70) {
      if ((kernel_size == 3) || ((kernel_size == 1) && (stride == 1))) {
        return true;
      }
    } else if (ctx.GetComputeCapability() > 70) {
      if ((kernel_size == 3) || (kernel_size == 1)) {
        return true;
      }
    }
    return false;
  }

  cudnnDataType_t dtype;
  cudnnTensorFormat_t format;
  cudnnDataType_t compute_type;

  std::vector<int64_t> in_dims;
  std::vector<int64_t> filter_dims;
  std::vector<int> strides;
  std::vector<int> paddings;
  std::vector<int> dilations;

  platform::TensorDescriptor in_desc;
  platform::FilterDescriptor filter_desc;
  platform::TensorDescriptor out_desc;
  platform::TensorDescriptor out_stats_desc;
  platform::ConvolutionDescriptor conv_desc;

  bool is_support;
};

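// Fused forward op: computes output = conv(input, filter) together with the
// per-output-channel sum and sum of squares of the result (the statistics
// later consumed by batch norm), via the cuDNN fused op
// CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS.
//
// Illustrative usage sketch (variable names are hypothetical; the context,
// tensors and shape vectors are assumed to be prepared by the caller):
//
//   CudnnNormConvolution<platform::float16> conv_op(
//       ctx, input_shape, filter_shape, output_shape, /*padding=*/1,
//       /*stride=*/1, /*dilation=*/1, /*group=*/1);
//   conv_op.Forward(ctx, input, filter, &output, &sum, &sum_of_squares);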
template <typename T>
class CudnnNormConvolution {
 public:
  CudnnNormConvolution(const platform::CUDADeviceContext &ctx,
                       const std::vector<int> &input_shape,
                       const std::vector<int> &filter_shape,
                       const std::vector<int> &output_shape, const int &padding,
                       const int &stride, const int &dilation,
                       const int &group) {
    args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride,
              dilation, group);
  }
  ~CudnnNormConvolution() {}

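  // Runs the fused forward pass. `output` holds the convolution result, while
  // `sum` and `sum_of_squares` are float tensors of shape
  // [1, 1, 1, output_channel] holding the per-channel statistics.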
  void Forward(const platform::CUDADeviceContext &ctx, const Tensor &input,
               const Tensor &filter, Tensor *output, Tensor *sum,
               Tensor *sum_of_squares) {
    auto cudnn_handle = ctx.cudnn_handle();
    auto place = ctx.GetPlace();

    CudnnFusionOp *fwd_op = GetForwardOp(ctx);
    size_t workspace_size = RoundUp(
        static_cast<int64_t>(fwd_op->GetWorkspaceSizeInBytes(cudnn_handle)),
        512);

    // Set variant_param
    // input ptr
    T *input_ptr = const_cast<T *>(input.data<T>());
    T *filter_ptr = const_cast<T *>(filter.data<T>());
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WDATA, filter_ptr);
    fwd_op->SetOpVariantParamAttrPtr(
        CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);

    // output ptr
    T *output_ptr = output->mutable_data<T>(place);
    float *sum_ptr = sum->mutable_data<float>(place);
    float *sum_of_squares_ptr = sum_of_squares->mutable_data<float>(place);
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, output_ptr);
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);

    ctx.cudnn_workspace_handle().RunFunc(
        [&](void *workspace_ptr) {
          // workspace ptr
          fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr);
          // fused op execute
          fwd_op->Execute(cudnn_handle);
        },
        workspace_size);
  }

 private:
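  // Builds the fused forward op, or fetches it from the global
  // CudnnFusionOpCache; the cached plan is keyed on the input/filter dims,
  // strides, paddings, dilations and data type.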
  CudnnFusionOp *GetForwardOp(const platform::CUDADeviceContext &ctx) {
    framework::AlgorithmsCache<CudnnFusionOp *> &cache =
        *(CudnnFusionOpCache::Instance().GetForward());

    CudnnFusionOp *fwd_op = cache.GetAlgorithm(
        args_.in_dims, args_.filter_dims, args_.strides, args_.paddings,
        args_.dilations, 0, static_cast<int64_t>(args_.dtype), [&]() {
          CudnnFusionOp *fwd_op =
              new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS);

          // Set constant_param
          fwd_op->SetOpConstParamAttr(
              {CUDNN_PARAM_XDATA_PLACEHOLDER, CUDNN_PARAM_WDATA_PLACEHOLDER,
               CUDNN_PARAM_YDATA_PLACEHOLDER},
              CUDNN_PTR_16B_ALIGNED);
          fwd_op->SetOpConstParamAttr(
              {CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER},
              CUDNN_PTR_16B_ALIGNED);

          // conv desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC,
                                      args_.conv_desc.desc());
          // input desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
          // filter desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_WDESC,
                                      args_.filter_desc.desc());
          // output desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_YDESC, args_.out_desc.desc());
          // output_stats desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_YSTATS_DESC,
                                      args_.out_stats_desc.desc());
          // batch_norm mode
          fwd_op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
                                      CUDNN_BATCHNORM_SPATIAL_PERSISTENT);

          // Make cudnn fused ops plan
          fwd_op->GetWorkspaceSizeInBytes(ctx.cudnn_handle());
          return fwd_op;
        });
    return fwd_op;
  }

 private:
  NormConvolutionArgs<T> args_;
};

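// Fused backward op: the filter gradient is computed with the cuDNN fused op
// CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD, while the data gradient falls back
// to a plain cudnnConvolutionBackwardData call using
// CUDNN_CONVOLUTION_BWD_DATA_ALGO_1. Construction mirrors
// CudnnNormConvolution and takes the same shapes and conv parameters.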
template <typename T>
class CudnnNormConvolutionGrad {
 public:
  CudnnNormConvolutionGrad(const platform::CUDADeviceContext &ctx,
                           const std::vector<int> &input_shape,
                           const std::vector<int> &filter_shape,
                           const std::vector<int> &output_shape,
                           const int &padding, const int &stride,
                           const int &dilation, const int &group) {
    args_.Set(ctx, input_shape, filter_shape, output_shape, padding, stride,
              dilation, group);
    dgrad_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
  }
  ~CudnnNormConvolutionGrad() {}

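  // Computes filter_grad and/or input_grad from output_grad; either output
  // pointer may be null, in which case that gradient is skipped. When
  // `use_addto` is true, the data gradient is accumulated into `input_grad`
  // (beta = 1) instead of overwriting it.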
  void Backward(const platform::CUDADeviceContext &ctx, const Tensor &input,
                const Tensor &filter, const Tensor &output_grad,
                Tensor *input_grad, Tensor *filter_grad,
                bool use_addto = false) {
    auto place = ctx.GetPlace();
    T *input_ptr = const_cast<T *>(input.data<T>());
    T *filter_ptr = const_cast<T *>(filter.data<T>());
    T *output_grad_ptr = const_cast<T *>(output_grad.data<T>());

    if (filter_grad) {
      T *filter_grad_ptr = filter_grad->mutable_data<T>(place);
      BackwardFilter(ctx, output_grad_ptr, input_ptr, filter_grad_ptr);
    }
    if (input_grad) {
      T *input_grad_ptr = input_grad->mutable_data<T>(place);
      BackwardData(ctx, output_grad_ptr, filter_ptr, input_grad_ptr, use_addto);
    }
  }

 private:
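  // Filter gradient (wgrad) through the cached fused op.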
  void BackwardFilter(const platform::CUDADeviceContext &ctx,
                      T *output_grad_ptr, T *input_ptr, T *filter_grad_ptr) {
    auto cudnn_handle = ctx.cudnn_handle();

    CudnnFusionOp *wgrad_op = GetBackwardFilterOp(ctx);
    size_t workspace_size = RoundUp(
        static_cast<int64_t>(wgrad_op->GetWorkspaceSizeInBytes(cudnn_handle)),
        512);

    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, output_grad_ptr);
    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_DWDATA, filter_grad_ptr);
    wgrad_op->SetOpVariantParamAttrPtr(
        CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);

    ctx.cudnn_workspace_handle().RunFunc(
        [&](void *workspace_ptr) {
          // workspace ptr
          wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
                                             workspace_ptr);
          // fused op execute
          wgrad_op->Execute(cudnn_handle);
        },
        workspace_size);
  }

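  // Data gradient (dgrad) through cudnnConvolutionBackwardData; `use_addto`
  // selects beta = 1 so the result is added onto the existing input_grad.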
  void BackwardData(const platform::CUDADeviceContext &ctx, T *output_grad_ptr,
                    T *filter_ptr, T *input_grad_ptr, bool use_addto = false) {
    auto cudnn_handle = ctx.cudnn_handle();
    size_t workspace_size = GetWorkspaceSizeBwdData(ctx);

    // Convolution dgrad followed optionally by batchnorm dgrad
    ScalingParamType<T> alpha = 1.0f;
    ScalingParamType<T> beta = use_addto ? 1.0f : 0.0f;
    ctx.cudnn_workspace_handle().RunFunc(
        [&](void *cudnn_workspace_ptr) {
          PADDLE_ENFORCE_CUDA_SUCCESS(
              platform::dynload::cudnnConvolutionBackwardData(
                  cudnn_handle, &alpha, args_.filter_desc.desc(), filter_ptr,
                  args_.out_desc.desc(), output_grad_ptr,
                  args_.conv_desc.desc(), dgrad_algo_, cudnn_workspace_ptr,
                  workspace_size, &beta, args_.in_desc.desc(), input_grad_ptr));
        },
        workspace_size);
  }

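  // Builds the fused wgrad op, or fetches it from the cache, mirroring
  // GetForwardOp above.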
  CudnnFusionOp *GetBackwardFilterOp(const platform::CUDADeviceContext &ctx) {
    framework::AlgorithmsCache<CudnnFusionOp *> &cache =
        *(CudnnFusionOpCache::Instance().GetBackward());

    CudnnFusionOp *wgrad_op = cache.GetAlgorithm(
        args_.in_dims, args_.filter_dims, args_.strides, args_.paddings,
        args_.dilations, 0, static_cast<int64_t>(args_.dtype), [&]() {
          CudnnFusionOp *wgrad_op =
              new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD);

          wgrad_op->SetOpConstParamAttr(
              {CUDNN_PARAM_DYDATA_PLACEHOLDER, CUDNN_PARAM_XDATA_PLACEHOLDER,
               CUDNN_PARAM_DWDATA_PLACEHOLDER},
              CUDNN_PTR_16B_ALIGNED);

          // conv desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC,
                                        args_.conv_desc.desc());
          // input desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_XDESC,
                                        args_.in_desc.desc());
          // filter desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_DWDESC,
                                        args_.filter_desc.desc());
          // output desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_DYDESC,
                                        args_.out_desc.desc());
          wgrad_op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
                                        CUDNN_BATCHNORM_SPATIAL_PERSISTENT);

          // Make cudnn fused ops plan
          wgrad_op->GetWorkspaceSizeInBytes(ctx.cudnn_handle());
          return wgrad_op;
        });
    return wgrad_op;
  }

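  // Queries the workspace required by the dgrad algorithm and rounds it up to
  // a multiple of 512 bytes.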
  size_t GetWorkspaceSizeBwdData(const platform::CUDADeviceContext &ctx) {
    size_t workspace_size = 0U;
    auto handle = ctx.cudnn_handle();
    PADDLE_ENFORCE_CUDA_SUCCESS(
        platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
            handle, args_.filter_desc.desc(), args_.out_desc.desc(),
            args_.conv_desc.desc(), args_.in_desc.desc(), dgrad_algo_,
            &workspace_size));
    return RoundUp(workspace_size, 512);
  }

 private:
  NormConvolutionArgs<T> args_;
  cudnnConvolutionBwdDataAlgo_t dgrad_algo_;
};

#endif
}  // namespace operators
}  // namespace paddle