diff --git a/paddle/fluid/operators/transpose_op.cu.h b/paddle/fluid/operators/transpose_op.cu.h index f7c4597d437568913b5a68cb6a6ec1e5e63258b2..486c568b42b8a2823fbf3a1205990e35315b25e0 100644 --- a/paddle/fluid/operators/transpose_op.cu.h +++ b/paddle/fluid/operators/transpose_op.cu.h @@ -1196,8 +1196,7 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx, if (!ret) { auto* tuner = phi::autotune::MakeTransposeTuner(TransCompute); - tuner->AddCallBack( - phi::autotune::MakeCallback(SimplifyThenLaunch)); + tuner->AddCallBack(SimplifyThenLaunch); size_t key = phi::autotune::TransposeKey( phi::vectorize(in.dims()), diff --git a/paddle/phi/kernels/autotune/auto_tune_base.h b/paddle/phi/kernels/autotune/auto_tune_base.h index 1f79855107e764593b446a510c6a54efeacc0a32..ff97b2a1f48f4bf046fe0e8b4728a321d1a62336 100644 --- a/paddle/phi/kernels/autotune/auto_tune_base.h +++ b/paddle/phi/kernels/autotune/auto_tune_base.h @@ -22,26 +22,26 @@ namespace phi { namespace autotune { -template +template class KernelCallback { public: - using ReturnT = RetureType; - using FuncType = RetureType (*)(Args...); + using ReturnT = ReturnType; + using FuncType = ReturnType (*)(Args...); KernelCallback() {} explicit KernelCallback(FuncType func_) : func(func_) {} virtual ~KernelCallback() {} - RetureType Run(Args... args) { return func(args...); } + ReturnType Run(Args... args) { return func(args...); } private: FuncType func; }; -template -static KernelCallback MakeCallback( - RetureType (*cb)(Args...)) { - return KernelCallback(cb); +template +static KernelCallback MakeCallback( + ReturnType (*cb)(Args...)) { + return KernelCallback(cb); } template @@ -54,10 +54,11 @@ class AutoTuneBase { kernels_.push_back(/*default=*/kernel); } - void AddCallBack(KernelType kernel) { + template + void AddCallBack(ReturnType (*func)(Args...)) { if (!is_init_) { std::lock_guard lock(mutex_); - kernels_.push_back(kernel); + kernels_.push_back(MakeCallback(func)); } } @@ -142,31 +143,35 @@ class AutoTuneBase { } }; -template -static AutoTuneBase> MakeAutoTuner( - RetureType (*func)(Args...)) { +template +static AutoTuneBase> MakeAutoTuner( + ReturnType (*func)(Args...)) { auto obj = MakeCallback(func); return AutoTuneBase(obj); } -template -class TransposeAutoTuner : public AutoTuneBase { +template +class TransposeAutoTuner + : public AutoTuneBase> { public: - static AutoTuneBase* Instance(KernelType kernel) { + static AutoTuneBase>* Instance( + ReturnType (*func)(Args...)) { static std::once_flag transpose_init_flag_; - static std::unique_ptr> instance_; + static std::unique_ptr< + AutoTuneBase>> + instance_; std::call_once(transpose_init_flag_, [&] { - instance_.reset(new AutoTuneBase(kernel)); + auto obj = MakeCallback(func); + instance_.reset(new AutoTuneBase(obj)); }); return instance_.get(); } }; -template -static AutoTuneBase>* -MakeTransposeTuner(RetureType (*func)(Args...)) { - auto obj = MakeCallback(func); - return TransposeAutoTuner::Instance(obj); +template +static AutoTuneBase>* +MakeTransposeTuner(ReturnType (*func)(Args...)) { + return TransposeAutoTuner::Instance(func); } } // namespace autotune