未验证 提交 806b252c 编写于 作者: L limingshu 提交者: GitHub

first commit (#46525)

上级 45df9be8
...@@ -1196,8 +1196,7 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx, ...@@ -1196,8 +1196,7 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
if (!ret) { if (!ret) {
auto* tuner = auto* tuner =
phi::autotune::MakeTransposeTuner<T>(TransCompute<phi::GPUContext, T>); phi::autotune::MakeTransposeTuner<T>(TransCompute<phi::GPUContext, T>);
tuner->AddCallBack( tuner->AddCallBack(SimplifyThenLaunch<phi::GPUContext, T>);
phi::autotune::MakeCallback<T>(SimplifyThenLaunch<phi::GPUContext, T>));
size_t key = phi::autotune::TransposeKey( size_t key = phi::autotune::TransposeKey(
phi::vectorize(in.dims()), phi::vectorize(in.dims()),
......
...@@ -22,26 +22,26 @@ ...@@ -22,26 +22,26 @@
namespace phi { namespace phi {
namespace autotune { namespace autotune {
template <typename T, typename RetureType, typename... Args> template <typename T, typename ReturnType, typename... Args>
class KernelCallback { class KernelCallback {
public: public:
using ReturnT = RetureType; using ReturnT = ReturnType;
using FuncType = RetureType (*)(Args...); using FuncType = ReturnType (*)(Args...);
KernelCallback() {} KernelCallback() {}
explicit KernelCallback(FuncType func_) : func(func_) {} explicit KernelCallback(FuncType func_) : func(func_) {}
virtual ~KernelCallback() {} virtual ~KernelCallback() {}
RetureType Run(Args... args) { return func(args...); } ReturnType Run(Args... args) { return func(args...); }
private: private:
FuncType func; FuncType func;
}; };
template <typename T, typename RetureType, typename... Args> template <typename T, typename ReturnType, typename... Args>
static KernelCallback<T, RetureType, Args...> MakeCallback( static KernelCallback<T, ReturnType, Args...> MakeCallback(
RetureType (*cb)(Args...)) { ReturnType (*cb)(Args...)) {
return KernelCallback<T, RetureType, Args...>(cb); return KernelCallback<T, ReturnType, Args...>(cb);
} }
template <typename T, typename KernelType> template <typename T, typename KernelType>
...@@ -54,10 +54,11 @@ class AutoTuneBase { ...@@ -54,10 +54,11 @@ class AutoTuneBase {
kernels_.push_back(/*default=*/kernel); kernels_.push_back(/*default=*/kernel);
} }
void AddCallBack(KernelType kernel) { template <typename ReturnType, typename... Args>
void AddCallBack(ReturnType (*func)(Args...)) {
if (!is_init_) { if (!is_init_) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
kernels_.push_back(kernel); kernels_.push_back(MakeCallback<T>(func));
} }
} }
...@@ -142,31 +143,35 @@ class AutoTuneBase { ...@@ -142,31 +143,35 @@ class AutoTuneBase {
} }
}; };
template <typename T, typename RetureType, typename... Args> template <typename T, typename ReturnType, typename... Args>
static AutoTuneBase<T, KernelCallback<T, RetureType, Args...>> MakeAutoTuner( static AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> MakeAutoTuner(
RetureType (*func)(Args...)) { ReturnType (*func)(Args...)) {
auto obj = MakeCallback<T>(func); auto obj = MakeCallback<T>(func);
return AutoTuneBase<T, decltype(obj)>(obj); return AutoTuneBase<T, decltype(obj)>(obj);
} }
template <typename T, typename KernelType> template <typename T, typename ReturnType, typename... Args>
class TransposeAutoTuner : public AutoTuneBase<T, KernelType> { class TransposeAutoTuner
: public AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> {
public: public:
static AutoTuneBase<T, KernelType>* Instance(KernelType kernel) { static AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>>* Instance(
ReturnType (*func)(Args...)) {
static std::once_flag transpose_init_flag_; static std::once_flag transpose_init_flag_;
static std::unique_ptr<AutoTuneBase<T, KernelType>> instance_; static std::unique_ptr<
AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>>>
instance_;
std::call_once(transpose_init_flag_, [&] { std::call_once(transpose_init_flag_, [&] {
instance_.reset(new AutoTuneBase<T, KernelType>(kernel)); auto obj = MakeCallback<T>(func);
instance_.reset(new AutoTuneBase<T, decltype(obj)>(obj));
}); });
return instance_.get(); return instance_.get();
} }
}; };
template <typename T, typename RetureType, typename... Args> template <typename T, typename ReturnType, typename... Args>
static AutoTuneBase<T, KernelCallback<T, RetureType, Args...>>* static AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>>*
MakeTransposeTuner(RetureType (*func)(Args...)) { MakeTransposeTuner(ReturnType (*func)(Args...)) {
auto obj = MakeCallback<T>(func); return TransposeAutoTuner<T, ReturnType, Args...>::Instance(func);
return TransposeAutoTuner<T, decltype(obj)>::Instance(obj);
} }
} // namespace autotune } // namespace autotune
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册