// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <limits>
#include <memory>
#include <mutex>
#include <type_traits>
#include <utility>
#include <vector>

#include "glog/logging.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/autotune/gpu_timer.h"

namespace phi {
namespace autotune {

// Wraps a raw kernel function pointer so that several interchangeable
// implementations sharing one signature can be stored and invoked
// uniformly by the auto-tuner.
template <typename T, typename ReturnType, typename... Args>
class KernelCallback {
 public:
  using ReturnT = ReturnType;
  using FuncType = ReturnType (*)(Args...);

  KernelCallback() {}
  explicit KernelCallback(FuncType func_) : func(func_) {}
  virtual ~KernelCallback() {}

  // Invokes the wrapped kernel with the given arguments and returns
  // whatever the kernel returns.
  ReturnType Run(Args... args) { return func(args...); }

 private:
  FuncType func;
};

43 44
template <typename T, typename RetureType, typename... Args>
static KernelCallback<T, RetureType, Args...> MakeCallback(
L
limingshu 已提交
45
    RetureType (*cb)(Args...)) {
46
  return KernelCallback<T, RetureType, Args...>(cb);
L
limingshu 已提交
47 48
}

49
template <typename T, typename KernelType>
L
limingshu 已提交
50 51 52 53
class AutoTuneBase {
 public:
  AutoTuneBase() {}
  virtual ~AutoTuneBase() {}
54 55 56 57 58 59
  explicit AutoTuneBase(KernelType kernel) { kernels_.push_back(kernel); }

  template <typename Type>
  void AddCallBack(Type kernel) {
    static_assert(std::is_same<Type, KernelType>::value,
                  "Type must be the same");
L
limingshu 已提交
60 61 62
    kernels_.push_back(kernel);
  }

63 64 65 66 67 68 69 70
  template <typename... Args>
  void RunBestKernel(const int idx, Args&&... args) {
    kernels_[idx].Run(args...);
  }

  template <typename... Args>
  void RunDefaultKernel(Args&&... args) {
    kernels_[0].Run(args...);
L
limingshu 已提交
71 72 73
  }

  template <typename Context, typename... Args>
74
  int PickBestKernel(const Context& ctx, Args&&... args) {
L
limingshu 已提交
75 76 77 78 79
    PADDLE_ENFORCE_GT(
        kernels_.size(),
        0,
        paddle::platform::errors::InvalidArgument(
            "kernel num must be greater than 0, now is %d", kernels_.size()));
80
    int best_idx = 0;
L
limingshu 已提交
81 82
    float min_time = std::numeric_limits<float>::max();

83
    // Time cost test estabulished in default stream.
L
limingshu 已提交
84
    for (int i = 0; i < kernels_.size(); ++i) {
85
      auto time = RunAndMeasureKernel<Context>(ctx, i, args...);
L
limingshu 已提交
86 87
      if (time < min_time) {
        min_time = time;
88
        best_idx = i;
L
limingshu 已提交
89 90
      }
    }
91 92
    VLOG(3) << "best kernel idx is " << best_idx;
    return best_idx;
L
limingshu 已提交
93 94
  }

95 96 97
  bool IsInit() { return is_init_; }
  void Finalize() { is_init_ = true; }

L
limingshu 已提交
98
 private:
99
  bool is_init_{false};
L
limingshu 已提交
100
  std::vector<KernelType> kernels_;
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124

  template <typename Context, typename... Args>
  float RunAndMeasureKernel(const Context& ctx, const int idx, Args&&... args) {
    phi::GpuTimer timer;
    float time_cost = 0;
    const auto& stream = ctx.stream();

    // Treat 1st run as warm up. Judge the result with
    // the sum of 2nd and 3rd run.
    constexpr int repeats = 3;

    ctx.Wait();
    for (int i = 0; i < repeats; ++i) {
      timer.Start(stream);
      kernels_[idx].Run(args...);
      timer.Stop(stream);
      auto time = timer.ElapsedTime();
      if (i > 0) {
        time_cost += time;
      }
      VLOG(3) << "kernel[" << idx << "][" << i << "th time cost is " << time;
    }
    return time_cost;
  }
L
limingshu 已提交
125 126
};

127 128
template <typename T, typename RetureType, typename... Args>
static AutoTuneBase<T, KernelCallback<T, RetureType, Args...>> MakeAutoTuner(
L
limingshu 已提交
129
    RetureType (*func)(Args...)) {
130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
  auto obj = MakeCallback<T>(func);
  return AutoTuneBase<T, decltype(obj)>(obj);
}

// Process-wide singleton auto-tuner dedicated to transpose kernels, one
// instance per (T, KernelType) instantiation.
// NOTE: only the kernel passed on the *first* call to Instance() is
// registered; the argument is ignored on every subsequent call.
template <typename T, typename KernelType>
class TransposeAutoTuner : public AutoTuneBase<T, KernelType> {
 public:
  static AutoTuneBase<T, KernelType>* Instance(KernelType kernel) {
    static std::unique_ptr<AutoTuneBase<T, KernelType>> instance_;
    // call_once makes the lazy construction safe under concurrent callers.
    std::call_once(init_flag_, [&] {
      instance_.reset(new AutoTuneBase<T, KernelType>(kernel));
    });
    return instance_.get();
  }

 private:
  static std::once_flag init_flag_;
};

// Out-of-class definition of the per-instantiation once flag.
template <typename T, typename KernelType>
std::once_flag TransposeAutoTuner<T, KernelType>::init_flag_;

// Returns the singleton transpose tuner keyed by the deduced callback type,
// creating it from `func` on first use.
template <typename T, typename RetureType, typename... Args>
static AutoTuneBase<T, KernelCallback<T, RetureType, Args...>>*
    MakeTransposeTuner(RetureType (*func)(Args...)) {
  auto callback = MakeCallback<T>(func);
  return TransposeAutoTuner<T, decltype(callback)>::Instance(callback);
}

}  // namespace autotune
}  // namespace phi