/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device/gpu/gpu_types.h"
#include "paddle/fluid/platform/dynload/miopen.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/macros.h"

// MIOpen does not define a minimum batch-norm epsilon, so one is provided here
// under the cuDNN macro name (presumably so code shared with the cuDNN path
// can use the same identifier).
#define CUDNN_BN_MIN_EPSILON 1e-05

// Declares the gflags flag `cudnn_deterministic` for use in this header.
// NOTE(review): presumably defined with DEFINE_bool in another translation
// unit — confirm.
DECLARE_bool(cudnn_deterministic);

namespace paddle {
namespace platform {
/// Maps a MIOpen status code to the name of its enumerator.
///
/// @param status  status value returned by a MIOpen API call.
/// @return a pointer to a string literal (never null, never needs freeing);
///         codes not listed below yield "Unknown miopen error number".
inline const char* miopenGetErrorString(miopenStatus_t status) {
  const char* name = "Unknown miopen error number";
  switch (status) {
    case miopenStatusSuccess:
      name = "miopenStatusSuccess";
      break;
    case miopenStatusNotInitialized:
      name = "miopenStatusNotInitialized";
      break;
    case miopenStatusAllocFailed:
      name = "miopenStatusAllocFailed";
      break;
    case miopenStatusBadParm:
      name = "miopenStatusBadParm";
      break;
    case miopenStatusInternalError:
      name = "miopenStatusInternalError";
      break;
    case miopenStatusInvalidValue:
      name = "miopenStatusInvalidValue";
      break;
    case miopenStatusUnknownError:
      name = "miopenStatusUnknownError";
      break;
    case miopenStatusNotImplemented:
      name = "miopenStatusNotImplemented";
      break;
    default:
      // Keep the generic fallback message.
      break;
  }
  return name;
}

// Compatibility shim: MIOpen code does not use this macro, but code elsewhere
// fails to compile when it is not defined. Encodes a version as
// major*1000 + minor*100 + patch and compares it with CUDNN_VERSION.
// NOTE(review): CUDNN_VERSION is not defined in this header; expanding the
// macro only works if CUDNN_VERSION is provided elsewhere — confirm.
#define CUDNN_VERSION_MIN(major, minor, patch) \
  (CUDNN_VERSION >= ((major)*1000 + (minor)*100 + (patch)))

// Tensor memory layouts accepted by the helpers in this header.
// GetCudnnTensorFormat maps NCHW/NCDHW to MIOPEN_TENSOR_NCHW and NHWC/NDHWC to
// MIOPEN_TENSOR_NHWC; kNCHW_VECT_C has no mapping there and is rejected.
// NOTE(review): the original annotation said "Not use" — confirm whether any
// caller outside this header still depends on this enum.
enum class DataLayout {
  kNHWC,
  kNCHW,
  kNCDHW,
  kNDHWC,  // added by liyamei
  kNCHW_VECT_C,
};

// Pooling algorithms requested by callers. GetPoolingMode translates them to
// miopenPoolingMode_t; both max variants map to miopenPoolingMax.
enum class PoolingMode {
  kMaximum,
  kMaximumDeterministic,
  kAverageExclusive,  // -> miopenPoolingAverage
  kAverageInclusive,  // -> miopenPoolingAverageInclusive
};

// Activation functions selectable by name through StringToActivationMode.
// ScopedActivationDescriptor has no MIOpen mapping for kBandPass and rejects it.
enum class ActivationMode {
  kNone,  // activation identity ("identity")
  kSigmoid,
  kRelu,
  kRelu6,   // ReLU clipped at 6
  kReluX,   // ReLU clipped at a caller-supplied ceiling
  kTanh,
  kBandPass,
};

/// Translates a PoolingMode into the MIOpen pooling enum.
/// Both max-pooling variants share miopenPoolingMax because this helper has
/// no separate deterministic max mode to select.
/// Throws Unimplemented for any value not handled below.
inline miopenPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
  switch (mode) {
    case PoolingMode::kMaximum:
    case PoolingMode::kMaximumDeterministic:
      return miopenPoolingMax;
    case PoolingMode::kAverageExclusive:
      return miopenPoolingAverage;
    case PoolingMode::kAverageInclusive:
      return miopenPoolingAverageInclusive;
    default:
      PADDLE_THROW(
          platform::errors::Unimplemented("Unexpected MIOPEN pooling mode."));
  }
}

/// Parses an activation name into an ActivationMode.
///
/// @param str  one of "identity", "sigmoid", "relu", "relu6", "relux",
///             "tanh", "bandpass" (exact, case-sensitive match).
/// @throws Unimplemented error for any other string.
inline ActivationMode StringToActivationMode(const std::string& str) {
  struct NamedMode {
    const char* name;
    ActivationMode mode;
  };
  static const NamedMode kKnownModes[] = {
      {"identity", ActivationMode::kNone},
      {"sigmoid", ActivationMode::kSigmoid},
      {"relu", ActivationMode::kRelu},
      {"relu6", ActivationMode::kRelu6},
      {"relux", ActivationMode::kReluX},
      {"tanh", ActivationMode::kTanh},
      {"bandpass", ActivationMode::kBandPass},
  };
  for (const NamedMode& entry : kKnownModes) {
    if (str == entry.name) {
      return entry.mode;
    }
  }
  PADDLE_THROW(platform::errors::Unimplemented(
      "Unknown MIOPEN activation string: %s.", str));
}

// Compile-time traits mapping an element type to:
//   - `type`: the MIOpen data type of tensors holding that element,
//   - `ScalingParamType`: the host type of scaling factors,
//   - `BatchNormParamType`: the type used for batch-norm parameters,
//   - kOne()/kZero(): pointers to static scaling constants 1 and 0.
// Only the specializations below are defined.
template <typename T>
class CudnnDataType;

// Half-precision tensors; scaling factors and batch-norm params are float.
template <>
class CudnnDataType<float16> {
 public:
  // The scaling param type is float for HALF and FLOAT tensors
  using ScalingParamType = const float;
  using BatchNormParamType = float;
  static const miopenDataType_t type = miopenHalf;

  // Pointer to a function-local static constant equal to 1.
  static ScalingParamType* kOne() {
    static ScalingParamType one_value = 1.0f;
    return &one_value;
  }
  // Pointer to a function-local static constant equal to 0.
  static ScalingParamType* kZero() {
    static ScalingParamType zero_value = 0.0f;
    return &zero_value;
  }
};

// Single-precision tensors; everything is float.
template <>
class CudnnDataType<float> {
 public:
  using ScalingParamType = const float;
  using BatchNormParamType = float;
  static const miopenDataType_t type = miopenFloat;

  // Pointer to a function-local static constant equal to 1.
  static ScalingParamType* kOne() {
    static ScalingParamType one_value = 1.0f;
    return &one_value;
  }
  // Pointer to a function-local static constant equal to 0.
  static ScalingParamType* kZero() {
    static ScalingParamType zero_value = 0.0f;
    return &zero_value;
  }
};

/// Maps a DataLayout to the MIOpen tensor format with the same channel
/// position. The 3-D layouts (NCDHW / NDHWC) reuse the channels-first /
/// channels-last 2-D formats. Throws Unimplemented for layouts with no
/// equivalent (kNCHW_VECT_C).
inline miopenTensorFormat_t GetCudnnTensorFormat(const DataLayout& order) {
  switch (order) {
    case DataLayout::kNCHW:
    case DataLayout::kNCDHW:
      return MIOPEN_TENSOR_NCHW;
    case DataLayout::kNHWC:
    case DataLayout::kNDHWC:
      return MIOPEN_TENSOR_NHWC;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "MIOPEN has no equivalent dataLayout for input order."));
  }
  return MIOPEN_TENSOR_NCHW;
}

class ScopedTensorDescriptor {
 public:
  ScopedTensorDescriptor() {
179
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreateTensorDescriptor(&desc_));
180 181
  }
  ~ScopedTensorDescriptor() PADDLE_MAY_THROW {
182
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenDestroyTensorDescriptor(desc_));
183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
  }

  inline miopenTensorDescriptor_t descriptor(const miopenTensorFormat_t format,
                                             const miopenDataType_t type,
                                             const std::vector<int>& dims,
                                             const int groups = 1) {
    // the format is not used now, will add later
    std::vector<int> strides(dims.size());
    strides[dims.size() - 1] = 1;
    for (int i = dims.size() - 2; i >= 0; i--) {
      strides[i] = dims[i + 1] * strides[i + 1];
    }
    // Update tensor descriptor dims setting if groups > 1
    // NOTE: Here, Assume using NCHW or NCDHW order
    std::vector<int> dims_with_group(dims.begin(), dims.end());
    if (groups > 1) {
      dims_with_group[1] = dims_with_group[1] / groups;
    }

    // MIOPEN ONLY support data layout of NCHW
    PADDLE_ENFORCE_EQ(format, MIOPEN_TENSOR_NCHW,
                      platform::errors::InvalidArgument(
                          "format should ONLY be NCHW in MIOPEN."));
    if (dims.size() == 4) {
207
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetTensorDescriptor(
208 209 210 211
          desc_, type, dims_with_group.size(),
          const_cast<int*>(dims_with_group.data()),
          const_cast<int*>(strides.data())));
    } else if (dims.size() == 5) {
212
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetTensorDescriptor(
213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230
          desc_, type, dims_with_group.size(),
          const_cast<int*>(dims_with_group.data()),
          const_cast<int*>(strides.data())));
    }
    return desc_;
  }

  template <typename T>
  inline miopenTensorDescriptor_t descriptor(const DataLayout& order,
                                             const std::vector<int>& dims,
                                             const int groups = 1) {
    return descriptor(GetCudnnTensorFormat(order), CudnnDataType<T>::type, dims,
                      groups);
  }

  inline miopenTensorDescriptor_t descriptor(const miopenDataType_t miopen_type,
                                             const std::vector<int>& dim,
                                             const std::vector<int>& stride) {
231
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetTensorDescriptor(
232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
        desc_, miopen_type, dim.size(), const_cast<int*>(dim.data()),
        const_cast<int*>(stride.data())));
    return desc_;
  }

  template <typename T>
  inline miopenTensorDescriptor_t descriptor(const std::vector<int>& dim,
                                             const std::vector<int>& stride) {
    return descriptor(CudnnDataType<T>::type, dim, stride);
  }

  inline miopenTensorDescriptor_t desc() { return desc_; }

 private:
  miopenTensorDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedTensorDescriptor);
};

class ScopedDropoutDescriptor {
 public:
  ScopedDropoutDescriptor() {
253
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreateDropoutDescriptor(&desc_));
254 255
  }
  ~ScopedDropoutDescriptor() PADDLE_MAY_THROW {
256
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenDestroyDropoutDescriptor(desc_));
257 258 259 260 261 262 263 264 265
  }

  inline miopenDropoutDescriptor_t descriptor(const miopenHandle_t& handle,
                                              const platform::Place& place,
                                              bool initialized,
                                              float dropout_prob_,
                                              framework::Tensor* dropout_state_,
                                              int seed, size_t state_size) {
    if (dropout_state_ == nullptr) {  // for no dropout or test
266
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetDropoutDescriptor(
267 268 269 270 271 272
          desc_, handle, 0 /* dropout */, nullptr, 0 /* state_size */,
          0 /* seed */, false, false, MIOPEN_RNG_PSEUDO_XORWOW));
      return desc_;
    }
    auto* dropout_state_data = dropout_state_->data<uint8_t>();
    if (!initialized) {
273
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetDropoutDescriptor(
274 275 276 277 278
          desc_, handle, dropout_prob_, dropout_state_data, state_size, seed,
          false, false, MIOPEN_RNG_PSEUDO_XORWOW));
    } else {
      auto dropout_state_dims = dropout_state_->dims();
      state_size = dropout_state_dims[0];
279
      PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenRestoreDropoutDescriptor(
280 281 282 283 284 285 286 287 288 289 290 291 292 293 294
          desc_, handle, dropout_prob_, dropout_state_data, state_size, 0,
          false, false, MIOPEN_RNG_PSEUDO_XORWOW));
    }
    return desc_;
  }
  inline miopenDropoutDescriptor_t desc() { return desc_; }

 private:
  miopenDropoutDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedDropoutDescriptor);
};

class ScopedRNNDescriptor {
 public:
  ScopedRNNDescriptor() {
295
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreateRNNDescriptor(&desc_));
296 297
  }
  ~ScopedRNNDescriptor() PADDLE_MAY_THROW {
298
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenDestroyRNNDescriptor(desc_));
299 300 301 302 303 304 305 306 307 308 309 310
  }

  inline miopenRNNDescriptor_t desc() { return desc_; }

 private:
  miopenRNNDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedRNNDescriptor);
};

class ScopedFilterDescriptor {
 public:
  ScopedFilterDescriptor() {
311
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreateTensorDescriptor(&desc_));
312 313
  }
  ~ScopedFilterDescriptor() PADDLE_MAY_THROW {
314
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenDestroyTensorDescriptor(desc_));
315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334
  }

  inline miopenTensorDescriptor_t descriptor(const miopenTensorFormat_t format,
                                             const miopenDataType_t type,
                                             const std::vector<int>& kernel,
                                             const int groups = 1) {
    // filter layout: MCHW(MCDHW), where M is the number of
    // output image channels, C is the number of input image channels,
    // D is the depth of the filter, H is the height of the filter, and W is the
    // width of the filter.
    std::vector<int> kernel_with_group(kernel.begin(), kernel.end());
    if (groups > 1) {
      kernel_with_group[0] /= groups;
      // NOTE: input filter(C) of the filter is already asserted to be C/groups.
    }
    std::vector<int> stride_dim(kernel_with_group.size());
    stride_dim.push_back(1);
    for (int k = kernel_with_group.size() - 2; k >= 0; k--) {
      stride_dim[k] = stride_dim[k + 1] * kernel_with_group[k + 1];
    }
335
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetTensorDescriptor(
336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359
        desc_, type, kernel_with_group.size(),
        const_cast<int*>(kernel_with_group.data()),
        const_cast<int*>(stride_dim.data())));
    return desc_;
  }

  template <typename T>
  inline miopenTensorDescriptor_t descriptor(const DataLayout& order,
                                             const std::vector<int>& kernel,
                                             const int groups = 1) {
    return descriptor(GetCudnnTensorFormat(order), CudnnDataType<T>::type,
                      kernel, groups);
  }

  inline miopenTensorDescriptor_t desc() { return desc_; }

 private:
  miopenTensorDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedFilterDescriptor);
};

class ScopedConvolutionDescriptor {
 public:
  ScopedConvolutionDescriptor() {
360
    PADDLE_ENFORCE_GPU_SUCCESS(
361 362 363
        dynload::miopenCreateConvolutionDescriptor(&desc_));
  }
  ~ScopedConvolutionDescriptor() PADDLE_MAY_THROW {
364
    PADDLE_ENFORCE_GPU_SUCCESS(
365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381
        dynload::miopenDestroyConvolutionDescriptor(desc_));
  }

  inline miopenConvolutionDescriptor_t descriptor(
      miopenDataType_t type, const std::vector<int>& pads,
      const std::vector<int>& strides, const std::vector<int>& dilations) {
    PADDLE_ENFORCE_EQ(pads.size(), strides.size(),
                      platform::errors::InvalidArgument(
                          "The size of pads and strides should be equal. But "
                          "received size of pads is %d, size of strides is %d.",
                          pads.size(), strides.size()));
    PADDLE_ENFORCE_EQ(
        pads.size(), dilations.size(),
        platform::errors::InvalidArgument(
            "The size of pads and dilations should be equal. But received size "
            "of pads is %d, size of dilations is %d.",
            pads.size(), dilations.size()));
382
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenInitConvolutionNdDescriptor(
383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403
        desc_, pads.size(), const_cast<int*>(pads.data()),
        const_cast<int*>(strides.data()), const_cast<int*>(dilations.data()),
        miopenConvolution));
    return desc_;
  }

  template <typename T>
  inline miopenConvolutionDescriptor_t descriptor(
      const std::vector<int>& pads, const std::vector<int>& strides,
      const std::vector<int>& dilations) {
    return descriptor(CudnnDataType<T>::type, pads, strides, dilations);
  }

 private:
  miopenConvolutionDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
};

class ScopedPoolingDescriptor {
 public:
  ScopedPoolingDescriptor() {
404
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreatePoolingDescriptor(&desc_));
405 406
  }
  ~ScopedPoolingDescriptor() PADDLE_MAY_THROW {
407
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenDestroyPoolingDescriptor(desc_));
408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424
  }

  inline miopenPoolingDescriptor_t descriptor(const PoolingMode& mode,
                                              const std::vector<int>& kernel,
                                              const std::vector<int>& pads,
                                              const std::vector<int>& strides) {
    PADDLE_ENFORCE_EQ(kernel.size(), pads.size(),
                      platform::errors::InvalidArgument(
                          "The size of kernel and pads should be equal. But "
                          "received size of kernel is %d, size of pads is %d.",
                          kernel.size(), pads.size()));
    PADDLE_ENFORCE_EQ(
        kernel.size(), strides.size(),
        platform::errors::InvalidArgument(
            "The size of kernel and strides should be equal. But "
            "received size of kernel is %d, size of strides is %d.",
            kernel.size(), strides.size()));
425
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetNdPoolingDescriptor(
R
ronnywang 已提交
426 427 428
        desc_, GetPoolingMode(mode), kernel.size(),
        const_cast<int*>(kernel.data()), const_cast<int*>(pads.data()),
        const_cast<int*>(strides.data())));
429 430 431 432 433 434 435 436 437 438 439
    return desc_;
  }

 private:
  miopenPoolingDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
};

class ScopedActivationDescriptor {
 public:
  ScopedActivationDescriptor() {
440
    PADDLE_ENFORCE_GPU_SUCCESS(
441 442 443
        dynload::miopenCreateActivationDescriptor(&desc_));
  }
  ~ScopedActivationDescriptor() PADDLE_MAY_THROW {
444
    PADDLE_ENFORCE_GPU_SUCCESS(
445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479
        dynload::miopenDestroyActivationDescriptor(desc_));
  }

  template <typename T>
  inline miopenActivationDescriptor_t descriptor(
      const std::string& act, double value_max = static_cast<double>(0.)) {
    double relu_ceiling = 0.0;
    ActivationMode activation_mode = StringToActivationMode(act);
    miopenActivationMode_t mode;
    switch (activation_mode) {
      case ActivationMode::kNone:
        mode = miopenActivationPASTHRU;
        break;
      case ActivationMode::kRelu6:
        relu_ceiling = 6.0;
        mode = miopenActivationCLIPPEDRELU;
        break;
      case ActivationMode::kReluX:
        relu_ceiling = value_max;
        mode = miopenActivationCLIPPEDRELU;
        break;
      case ActivationMode::kRelu:
        mode = miopenActivationRELU;
        break;
      case ActivationMode::kSigmoid:
        mode = miopenActivationLOGISTIC;
        break;
      case ActivationMode::kTanh:
        mode = miopenActivationTANH;
        break;
      default:
        PADDLE_THROW(platform::errors::Unimplemented(
            "Unrecognized MIOPEN activation mode: %d.",
            static_cast<int>(activation_mode)));
    }
480
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetActivationDescriptor(
481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504
        desc_, mode, relu_ceiling, 0.0, 0.0));
    return desc_;
  }

 private:
  miopenActivationDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedActivationDescriptor);
};

// Returns true when the op's "use_cudnn" attribute is set and the op runs on a
// GPU place. In HIP builds, additionally requires the device context's
// cudnn_handle() to be non-null.
inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
  const bool requested = ctx.Attr<bool>("use_cudnn");
  const bool on_gpu = paddle::platform::is_gpu_place(ctx.GetPlace());
  if (!requested || !on_gpu) {
    return false;
  }
#ifdef PADDLE_WITH_HIP
  auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
  return dev_ctx.cudnn_handle() != nullptr;
#else
  return true;
#endif
}

class ScopedCTCLossDescriptor {
 public:
  ScopedCTCLossDescriptor() {
505
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreateCTCLossDescriptor(&desc_));
506 507
  }
  ~ScopedCTCLossDescriptor() PADDLE_MAY_THROW {
508
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenDestroyCTCLossDescriptor(desc_));
509 510 511 512
  }

  template <typename T>
  inline miopenCTCLossDescriptor_t descriptor() {
513
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetCTCLossDescriptor(
514 515 516 517 518 519 520 521 522 523 524
        desc_, CudnnDataType<T>::type, 0, false));
    return desc_;
  }

 private:
  miopenCTCLossDescriptor_t desc_;
  DISABLE_COPY_AND_ASSIGN(ScopedCTCLossDescriptor);
};

}  // namespace platform
}  // namespace paddle