cudnn_norm_conv.cu.h 16.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/operators/fused/cudnn_fusion_helper.h"
18
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
19 20 21 22 23 24

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
namespace dynload = platform::dynload;

// Host-side scalar type cuDNN expects for the alpha/beta scaling factors
// when operating on element type T (taken from CudnnDataType<T>).
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;

#if CUDNN_VERSION >= 8000
29 30 31

// Rounds `a` up to the nearest multiple of `b` (assumes b > 0), e.g.
// RoundUp(1, 512) == 512. Used below to pad cuDNN workspace sizes.
static size_t RoundUp(int64_t a, int64_t b) {
  const int64_t num_chunks = (a + b - 1) / b;
  return num_chunks * b;
}

template <typename T>
33 34 35 36 37 38 39
struct NormConvolutionArgs {
  NormConvolutionArgs() {
    dtype = platform::CudnnDataType<T>::type;
    format = CUDNN_TENSOR_NHWC;
    compute_type = platform::CudnnDataType<float>::type;
  }

L
Leo Chen 已提交
40
  void Set(const phi::GPUContext &ctx,
41
           const std::vector<int> &input_shape,
42
           const std::vector<int> &filter_shape,
43 44 45 46 47
           const std::vector<int> &output_shape,
           int padding,
           int stride,
           int dilation,
           int group) {
48 49 50 51 52 53 54 55
    PADDLE_ENFORCE_LT(
        ctx.GetComputeCapability(),
        90,
        phi::errors::PreconditionNotMet(
            "Expect compute compatiblity to be less than 90, but got %d. "
            "CUDNN FusedOps is no longer available on H100 and later "
            "devices.",
            ctx.GetComputeCapability()));
56
    PADDLE_ENFORCE_EQ(
57 58
        input_shape.size(),
        4U,
59
        platform::errors::InvalidArgument(
60
            "The size of input_shape is expected to 4. But received "
61
            "input_shape's size is %d, input_shape is [%s].",
62 63
            input_shape.size(),
            phi::make_ddim(input_shape)));
64
    PADDLE_ENFORCE_EQ(
65 66
        filter_shape.size(),
        4U,
67
        platform::errors::InvalidArgument(
68
            "The size of filter_shape is expected to 4. But received "
69
            "filter_shape's size is %d, filter_shape is [%s].",
70 71
            filter_shape.size(),
            phi::make_ddim(filter_shape)));
72 73 74 75 76
    PADDLE_ENFORCE_EQ(filter_shape[1] == filter_shape[2] &&
                          (filter_shape[1] == 1 || filter_shape[1] == 3),
                      true,
                      platform::errors::InvalidArgument(
                          "The filter_shape is expected to store as nhwc, and "
77
                          "h = w = 1 or 3. But received filter_shape is [%s].",
78
                          phi::make_ddim(filter_shape)));
79 80 81 82 83
    PADDLE_ENFORCE_EQ((filter_shape[0] % 32 == 0 && filter_shape[3] % 8 == 0),
                      true,
                      platform::errors::InvalidArgument(
                          "The input channel is expected to be multiple of 8, "
                          "and the output channel is expected to be multiple "
84
                          "of 32. But received input channel is %d, output "
85
                          "channel is %d.",
86 87
                          filter_shape[3],
                          filter_shape[0]));
88
    PADDLE_ENFORCE_EQ(
89 90
        output_shape.size(),
        4U,
91
        platform::errors::InvalidArgument(
92
            "The size of output_shape is expected to 4. But received "
93
            "filter_shape's size is %d, filter_shape is [%s].",
94 95
            output_shape.size(),
            phi::make_ddim(output_shape)));
96 97
    is_support = IsSupport(ctx, filter_shape, stride, dilation, group);
    PADDLE_ENFORCE_EQ(
98 99
        is_support,
        true,
100 101 102 103 104
        platform::errors::InvalidArgument(
            "Current test is only supported in the platforms with "
            "compatiblity greater than or equal to 70 and the kernel size "
            "must be equal to 1 or 3. When the kernel size is 1, "
            "the stride must be 1 if the compatiblity is equal to 70. "
105
            "Besides, the dilation and group must be equal to 1. But received "
106 107
            "compatiblity is %d, kernel size is %d, stride is %d, "
            "dilation is %d, group is %d",
108 109 110 111
            ctx.GetComputeCapability(),
            filter_shape[1],
            stride,
            dilation,
112
            group));
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134

    for (size_t i = 0; i < input_shape.size(); ++i) {
      in_dims.push_back(input_shape[i]);
    }
    for (size_t i = 0; i < filter_shape.size(); ++i) {
      filter_dims.push_back(filter_shape[i]);
    }
    paddings = {padding, padding};
    strides = {stride, stride};
    dilations = {dilation, dilation};

    in_desc.set(input_shape, format, dtype);
    filter_desc.set(filter_shape, format, dtype, group);
    out_desc.set(output_shape, format, dtype);

    int output_channel = filter_shape[0];
    std::vector<int> stats_shape = {1, 1, 1, output_channel};
    out_stats_desc.set(stats_shape, format, compute_type);

    conv_desc.set(dtype, paddings, strides, dilations, false, group);
  }

L
Leo Chen 已提交
135
  bool IsSupport(const phi::GPUContext &ctx,
136 137 138
                 const std::vector<int> &filter_shape,
                 int stride,
                 int dilation,
139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
                 int group) {
    int kernel_size = filter_shape[1];
    if (dilation != 1 || group != 1) {
      return false;
    }
    if (ctx.GetComputeCapability() == 70) {
      if ((kernel_size == 3) || ((kernel_size == 1) && (stride == 1))) {
        return true;
      }
    } else if (ctx.GetComputeCapability() > 70) {
      if ((kernel_size == 3) || (kernel_size == 1)) {
        return true;
      }
    }
    return false;
  }

156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
  cudnnDataType_t dtype;
  cudnnTensorFormat_t format;
  cudnnDataType_t compute_type;

  std::vector<int64_t> in_dims;
  std::vector<int64_t> filter_dims;
  std::vector<int> strides;
  std::vector<int> paddings;
  std::vector<int> dilations;

  platform::TensorDescriptor in_desc;
  platform::FilterDescriptor filter_desc;
  platform::TensorDescriptor out_desc;
  platform::TensorDescriptor out_stats_desc;
  platform::ConvolutionDescriptor conv_desc;
171 172

  bool is_support;
173 174 175 176
};

template <typename T>
class CudnnNormConvolution {
177
 public:
L
Leo Chen 已提交
178
  CudnnNormConvolution(const phi::GPUContext &ctx,
179 180
                       const std::vector<int> &input_shape,
                       const std::vector<int> &filter_shape,
181 182 183 184
                       const std::vector<int> &output_shape,
                       const int &padding,
                       const int &stride,
                       const int &dilation,
185
                       const int &group) {
186 187 188 189 190 191 192 193
    args_.Set(ctx,
              input_shape,
              filter_shape,
              output_shape,
              padding,
              stride,
              dilation,
              group);
194
  }
195
  ~CudnnNormConvolution() {}
196

L
Leo Chen 已提交
197
  void Forward(const phi::GPUContext &ctx,
198 199 200 201
               const Tensor &input,
               const Tensor &filter,
               Tensor *output,
               Tensor *sum,
202
               Tensor *sum_of_squares) {
203 204 205 206 207 208 209
    auto cudnn_handle = ctx.cudnn_handle();

    CudnnFusionOp *fwd_op = GetForwardOp(ctx);
    size_t workspace_size = RoundUp(
        static_cast<int64_t>(fwd_op->GetWorkspaceSizeInBytes(cudnn_handle)),
        512);

210 211
    // Set variant_param
    // input ptr
212 213
    T *input_ptr = const_cast<T *>(input.data<T>());
    T *filter_ptr = const_cast<T *>(filter.data<T>());
214 215 216 217 218
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WDATA, filter_ptr);
    fwd_op->SetOpVariantParamAttrPtr(
        CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);

219
    // output ptr
220 221 222 223 224
    T *output_ptr = ctx.template Alloc<T>(output, output->numel() * sizeof(T));
    float *sum_ptr =
        ctx.template Alloc<float>(sum, sum->numel() * sizeof(float));
    float *sum_of_squares_ptr = ctx.template Alloc<float>(
        sum_of_squares, sum_of_squares->numel() * sizeof(float));
225 226 227 228 229
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YDATA, output_ptr);
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSUM, sum_ptr);
    fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_YSQSUM, sum_of_squares_ptr);

    ctx.cudnn_workspace_handle().RunFunc(
230 231
        [&](void *workspace_ptr) {
          // workspace ptr
232
          fwd_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE, workspace_ptr);
233
          // fused op execute
234
          fwd_op->Execute(cudnn_handle);
235
        },
236
        workspace_size);
237 238
  }

239
 private:
L
Leo Chen 已提交
240
  CudnnFusionOp *GetForwardOp(const phi::GPUContext &ctx) {
241 242 243 244
    framework::AlgorithmsCache<CudnnFusionOp *> &cache =
        *(CudnnFusionOpCache::Instance().GetForward());

    CudnnFusionOp *fwd_op = cache.GetAlgorithm(
245 246 247 248 249 250 251 252
        args_.in_dims,
        args_.filter_dims,
        args_.strides,
        args_.paddings,
        args_.dilations,
        0,
        static_cast<int64_t>(args_.dtype),
        [&]() {
253 254 255 256
          CudnnFusionOp *fwd_op =
              new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_CONV_BNSTATS);

          // Set constant_param
257 258 259 260
          fwd_op->SetOpConstParamAttr({CUDNN_PARAM_XDATA_PLACEHOLDER,
                                       CUDNN_PARAM_WDATA_PLACEHOLDER,
                                       CUDNN_PARAM_YDATA_PLACEHOLDER},
                                      CUDNN_PTR_16B_ALIGNED);
261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287
          fwd_op->SetOpConstParamAttr(
              {CUDNN_PARAM_YSUM_PLACEHOLDER, CUDNN_PARAM_YSQSUM_PLACEHOLDER},
              CUDNN_PTR_16B_ALIGNED);

          // conv desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC,
                                      args_.conv_desc.desc());
          // input desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_XDESC, args_.in_desc.desc());
          // filter desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_WDESC,
                                      args_.filter_desc.desc());
          // output desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_YDESC, args_.out_desc.desc());
          // output_stats desc
          fwd_op->SetOpConstParamDesc(CUDNN_PARAM_YSTATS_DESC,
                                      args_.out_stats_desc.desc());
          // batch_norm mode
          fwd_op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
                                      CUDNN_BATCHNORM_SPATIAL_PERSISTENT);

          // Make cudnn fused ops plan
          fwd_op->GetWorkspaceSizeInBytes(ctx.cudnn_handle());
          return fwd_op;
        });
    return fwd_op;
  }
288 289

 private:
290 291
  NormConvolutionArgs<T> args_;
};
292

293 294 295
template <typename T>
class CudnnNormConvolutionGrad {
 public:
L
Leo Chen 已提交
296
  CudnnNormConvolutionGrad(const phi::GPUContext &ctx,
297 298 299
                           const std::vector<int> &input_shape,
                           const std::vector<int> &filter_shape,
                           const std::vector<int> &output_shape,
300 301 302 303 304 305 306 307 308 309 310 311
                           const int &padding,
                           const int &stride,
                           const int &dilation,
                           const int &group) {
    args_.Set(ctx,
              input_shape,
              filter_shape,
              output_shape,
              padding,
              stride,
              dilation,
              group);
312 313 314
    dgrad_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
  }
  ~CudnnNormConvolutionGrad() {}
315

L
Leo Chen 已提交
316
  void Backward(const phi::GPUContext &ctx,
317 318 319 320 321
                const Tensor &input,
                const Tensor &filter,
                const Tensor &output_grad,
                Tensor *input_grad,
                Tensor *filter_grad,
322 323 324 325 326 327
                bool use_addto = false) {
    T *input_ptr = const_cast<T *>(input.data<T>());
    T *filter_ptr = const_cast<T *>(filter.data<T>());
    T *output_grad_ptr = const_cast<T *>(output_grad.data<T>());

    if (filter_grad) {
328 329
      T *filter_grad_ptr =
          ctx.template Alloc<T>(filter_grad, filter_grad->numel() * sizeof(T));
330
      BackwardFilter(ctx, output_grad_ptr, input_ptr, filter_grad_ptr);
331
    }
332
    if (input_grad) {
333 334
      T *input_grad_ptr =
          ctx.template Alloc<T>(input_grad, input_grad->numel() * sizeof(T));
335
      BackwardData(ctx, output_grad_ptr, filter_ptr, input_grad_ptr, use_addto);
336 337
    }
  }
338

339
 private:
L
Leo Chen 已提交
340
  void BackwardFilter(const phi::GPUContext &ctx,
341 342 343
                      T *output_grad_ptr,
                      T *input_ptr,
                      T *filter_grad_ptr) {
344
    auto cudnn_handle = ctx.cudnn_handle();
345

346 347 348 349
    CudnnFusionOp *wgrad_op = GetBackwardFilterOp(ctx);
    size_t workspace_size = RoundUp(
        static_cast<int64_t>(wgrad_op->GetWorkspaceSizeInBytes(cudnn_handle)),
        512);
350

351 352 353 354 355
    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_XDATA, input_ptr);
    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_DYDATA, output_grad_ptr);
    wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_DWDATA, filter_grad_ptr);
    wgrad_op->SetOpVariantParamAttrPtr(
        CUDNN_SCALAR_SIZE_T_WORKSPACE_SIZE_IN_BYTES, &workspace_size);
356

357 358 359 360 361 362 363 364 365
    ctx.cudnn_workspace_handle().RunFunc(
        [&](void *workspace_ptr) {
          // workspace ptr
          wgrad_op->SetOpVariantParamAttrPtr(CUDNN_PTR_WORKSPACE,
                                             workspace_ptr);
          // fused op execute
          wgrad_op->Execute(cudnn_handle);
        },
        workspace_size);
366 367
  }

L
Leo Chen 已提交
368
  void BackwardData(const phi::GPUContext &ctx,
369 370 371 372
                    T *output_grad_ptr,
                    T *filter_ptr,
                    T *input_grad_ptr,
                    bool use_addto = false) {
373 374 375 376 377 378 379 380
    auto cudnn_handle = ctx.cudnn_handle();
    size_t workspace_size = GetWorkspaceSizeBwdData(ctx);

    // Convolution dgrad followed optionally by batchnorm dgrad
    ScalingParamType<T> alpha = 1.0f;
    ScalingParamType<T> beta = use_addto ? 1.0f : 0.0f;
    ctx.cudnn_workspace_handle().RunFunc(
        [&](void *cudnn_workspace_ptr) {
381
          PADDLE_ENFORCE_GPU_SUCCESS(
382
              platform::dynload::cudnnConvolutionBackwardData(
383 384 385 386 387 388 389 390 391 392 393 394 395
                  cudnn_handle,
                  &alpha,
                  args_.filter_desc.desc(),
                  filter_ptr,
                  args_.out_desc.desc(),
                  output_grad_ptr,
                  args_.conv_desc.desc(),
                  dgrad_algo_,
                  cudnn_workspace_ptr,
                  workspace_size,
                  &beta,
                  args_.in_desc.desc(),
                  input_grad_ptr));
396 397
        },
        workspace_size);
398 399
  }

L
Leo Chen 已提交
400
  CudnnFusionOp *GetBackwardFilterOp(const phi::GPUContext &ctx) {
401 402 403 404
    framework::AlgorithmsCache<CudnnFusionOp *> &cache =
        *(CudnnFusionOpCache::Instance().GetBackward());

    CudnnFusionOp *wgrad_op = cache.GetAlgorithm(
405 406 407 408 409 410 411 412
        args_.in_dims,
        args_.filter_dims,
        args_.strides,
        args_.paddings,
        args_.dilations,
        0,
        static_cast<int64_t>(args_.dtype),
        [&]() {
413 414 415
          CudnnFusionOp *wgrad_op =
              new CudnnFusionOp(CUDNN_FUSED_SCALE_BIAS_ACTIVATION_WGRAD);

416 417 418 419
          wgrad_op->SetOpConstParamAttr({CUDNN_PARAM_DYDATA_PLACEHOLDER,
                                         CUDNN_PARAM_XDATA_PLACEHOLDER,
                                         CUDNN_PARAM_DWDATA_PLACEHOLDER},
                                        CUDNN_PTR_16B_ALIGNED);
420

421 422 423 424 425 426 427 428 429 430 431 432 433 434
          // conv desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_CONV_DESC,
                                        args_.conv_desc.desc());
          // input desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_XDESC,
                                        args_.in_desc.desc());
          // filter desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_DWDESC,
                                        args_.filter_desc.desc());
          // output desc
          wgrad_op->SetOpConstParamDesc(CUDNN_PARAM_DYDESC,
                                        args_.out_desc.desc());
          wgrad_op->SetOpConstParamAttr(CUDNN_PARAM_BN_MODE,
                                        CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
435

436 437 438 439 440 441 442
          // Make cudnn fused ops plan
          wgrad_op->GetWorkspaceSizeInBytes(ctx.cudnn_handle());
          return wgrad_op;
        });
    return wgrad_op;
  }

L
Leo Chen 已提交
443
  size_t GetWorkspaceSizeBwdData(const phi::GPUContext &ctx) {
444 445
    size_t workspace_size = 0U;
    auto handle = ctx.cudnn_handle();
446
    PADDLE_ENFORCE_GPU_SUCCESS(
447
        platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
448 449 450 451 452 453
            handle,
            args_.filter_desc.desc(),
            args_.out_desc.desc(),
            args_.conv_desc.desc(),
            args_.in_desc.desc(),
            dgrad_algo_,
454 455 456 457 458 459 460
            &workspace_size));
    return RoundUp(workspace_size, 512);
  }

 private:
  NormConvolutionArgs<T> args_;
  cudnnConvolutionBwdDataAlgo_t dgrad_algo_;
461
};
462

463 464 465
#endif
}  // namespace operators
}  // namespace paddle