auto_tune_base.h 5.3 KB
Newer Older
L
limingshu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <type_traits>
#include "glog/logging.h"
#include "paddle/phi/kernels/autotune/gpu_timer.h"
20
#include "paddle/phi/kernels/autotune/switch_autotune.h"
L
limingshu 已提交
21 22 23 24

namespace phi {
namespace autotune {

25
// Lightweight wrapper around a plain function pointer so that candidate
// kernels with the same signature can be stored and invoked uniformly.
//
// Template parameters:
//   T          - tag type; it is not referenced inside this class.
//   RetureType - return type of the wrapped function.
//   Args       - parameter types of the wrapped function.
template <typename T, typename RetureType, typename... Args>
class KernelCallback {
 public:
  using ReturnT = RetureType;
  using FuncType = RetureType (*)(Args...);

  // Creates an empty callback. The function pointer is null until the object
  // is assigned from a callback built with the explicit constructor.
  KernelCallback() {}

  // Wraps `func_`; the pointer is stored as-is and never owned.
  explicit KernelCallback(FuncType func_) : func(func_) {}

  virtual ~KernelCallback() {}

  // Invokes the wrapped function with `args` and returns its result.
  // Must not be called on a default-constructed (empty) callback.
  RetureType Run(Args... args) { return func(args...); }

 private:
  // Initialised to nullptr so a default-constructed callback never holds an
  // indeterminate pointer value.
  FuncType func{nullptr};
};

41 42
// Builds a KernelCallback for the free function `cb`. `T` has to be given
// explicitly by the caller; the return and parameter types are deduced from
// the function pointer.
template <typename T, typename RetureType, typename... Args>
static KernelCallback<T, RetureType, Args...> MakeCallback(
    RetureType (*cb)(Args...)) {
  using Callback = KernelCallback<T, RetureType, Args...>;
  return Callback(cb);
}

47
template <typename T, typename KernelType>
L
limingshu 已提交
48 49 50 51
class AutoTuneBase {
 public:
  AutoTuneBase() {}
  virtual ~AutoTuneBase() {}
52

53 54
  explicit AutoTuneBase(KernelType kernel) {
    kernels_.push_back(/*default=*/kernel);
L
limingshu 已提交
55 56
  }

57 58 59 60 61
  void AddCallBack(KernelType kernel) {
    if (!is_init_) {
      std::lock_guard<std::mutex> lock(mutex_);
      kernels_.push_back(kernel);
    }
62 63
  }

64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
  template <typename Context, typename... Args>
  void Run(const Context& ctx,
           const AlgorithmType& algo,
           const size_t key,
           Args&&... args) {
    PADDLE_ENFORCE_GT(
        kernels_.size(),
        0,
        paddle::platform::errors::InvalidArgument(
            "kernel num must be greater than 0, now is %d", kernels_.size()));
    is_init_ = true;

    auto& cache = AutoTuneCache::Instance().Get(algo);
    if (cache.Find(key)) {
      auto best_idx = cache.Get(key);
      kernels_[best_idx].Run(args...);
    } else {
      bool use_autotune = AutoTuneStatus::Instance().UseAutoTune();
      if (use_autotune) {
        // All avaliable kernels have ran while picking the best kernel,
        // so there may be no need for another kernel run.
        auto best_idx = PickBestKernel(ctx, args...);
        cache.Set(key, best_idx);
      } else {
        kernels_[0].Run(args...);
      }
    }
L
limingshu 已提交
91 92
  }

93 94 95 96 97
 private:
  bool is_init_{false};
  std::vector<KernelType> kernels_;
  mutable std::mutex mutex_;

L
limingshu 已提交
98
  template <typename Context, typename... Args>
99 100
  size_t PickBestKernel(const Context& ctx, Args&&... args) {
    std::lock_guard<std::mutex> lock(mutex_);
L
limingshu 已提交
101 102 103 104 105
    PADDLE_ENFORCE_GT(
        kernels_.size(),
        0,
        paddle::platform::errors::InvalidArgument(
            "kernel num must be greater than 0, now is %d", kernels_.size()));
106
    size_t best_idx = 0;
L
limingshu 已提交
107 108
    float min_time = std::numeric_limits<float>::max();

109
    // Time cost test estabulished in default stream.
L
limingshu 已提交
110
    for (int i = 0; i < kernels_.size(); ++i) {
111
      auto time = RunAndMeasureKernel<Context>(ctx, i, args...);
L
limingshu 已提交
112 113
      if (time < min_time) {
        min_time = time;
114
        best_idx = i;
L
limingshu 已提交
115 116
      }
    }
117 118
    VLOG(3) << "best kernel idx is " << best_idx;
    return best_idx;
L
limingshu 已提交
119 120
  }

121 122
  template <typename Context, typename... Args>
  float RunAndMeasureKernel(const Context& ctx, const int idx, Args&&... args) {
123 124 125
    // Regard 1st run as warmup. Judge the result by the time cost of rest run
    // cycles.
    constexpr int repeats = 3;
126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142
    phi::GpuTimer timer;
    float time_cost = 0;
    const auto& stream = ctx.stream();

    ctx.Wait();
    for (int i = 0; i < repeats; ++i) {
      timer.Start(stream);
      kernels_[idx].Run(args...);
      timer.Stop(stream);
      auto time = timer.ElapsedTime();
      if (i > 0) {
        time_cost += time;
      }
      VLOG(3) << "kernel[" << idx << "][" << i << "th time cost is " << time;
    }
    return time_cost;
  }
L
limingshu 已提交
143 144
};

145 146
// Creates an AutoTuneBase whose default (index 0) candidate wraps `func`.
// `T` has to be given explicitly; the remaining template arguments are
// deduced from the function pointer.
//
// NOTE(review): AutoTuneBase declares a std::mutex member and a destructor,
// so it is neither copyable nor movable; returning it by value appears to
// rely on C++17 guaranteed copy elision -- confirm the build standard.
template <typename T, typename RetureType, typename... Args>
static AutoTuneBase<T, KernelCallback<T, RetureType, Args...>> MakeAutoTuner(
    RetureType (*func)(Args...)) {
  using Callback = KernelCallback<T, RetureType, Args...>;
  using Tuner = AutoTuneBase<T, Callback>;
  return Tuner(MakeCallback<T>(func));
}

// Per-(T, KernelType) access point to a single shared AutoTuneBase instance.
template <typename T, typename KernelType>
class TransposeAutoTuner : public AutoTuneBase<T, KernelType> {
 public:
  // Returns the shared tuner. It is constructed on the first call, with that
  // call's `kernel` as the default candidate; `kernel` is ignored on every
  // later call. The returned pointer is owned by this function's static
  // storage and must not be deleted by the caller.
  static AutoTuneBase<T, KernelType>* Instance(KernelType kernel) {
    using Base = AutoTuneBase<T, KernelType>;
    static std::unique_ptr<Base> instance_;
    // call_once runs the creator at most once across all callers.
    const auto create = [&kernel]() { instance_.reset(new Base(kernel)); };
    std::call_once(init_flag_, create);
    return instance_.get();
  }

 private:
  // Guards the one-time construction performed in Instance().
  static std::once_flag init_flag_;
};

// Out-of-class definition of the per-instantiation once flag.
template <typename T, typename KernelType>
std::once_flag TransposeAutoTuner<T, KernelType>::init_flag_;

// Returns the shared transpose tuner for callbacks with the signature of
// `func`. `func` becomes the default candidate only if this call is the one
// that creates the shared instance; see TransposeAutoTuner::Instance().
template <typename T, typename RetureType, typename... Args>
static AutoTuneBase<T, KernelCallback<T, RetureType, Args...>>*
MakeTransposeTuner(RetureType (*func)(Args...)) {
  using Callback = KernelCallback<T, RetureType, Args...>;
  return TransposeAutoTuner<T, Callback>::Instance(MakeCallback<T>(func));
}

}  // namespace autotune
}  // namespace phi