未验证 提交 806b252c 编写于 作者: L limingshu 提交者: GitHub

first commit (#46525)

上级 45df9be8
......@@ -1196,8 +1196,7 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
if (!ret) {
auto* tuner =
phi::autotune::MakeTransposeTuner<T>(TransCompute<phi::GPUContext, T>);
tuner->AddCallBack(
phi::autotune::MakeCallback<T>(SimplifyThenLaunch<phi::GPUContext, T>));
tuner->AddCallBack(SimplifyThenLaunch<phi::GPUContext, T>);
size_t key = phi::autotune::TransposeKey(
phi::vectorize(in.dims()),
......
......@@ -22,26 +22,26 @@
namespace phi {
namespace autotune {
// NOTE(review): this span reads like a diff whose +/- markers were lost; the
// `RetureType` lines appear to be the removed spelling and the `ReturnType`
// lines their replacements. Code lines are left untouched.
//
// KernelCallback stores a plain function pointer and invokes it through Run().
// T does not appear in any member shown here.
template <typename T, typename RetureType, typename... Args>
template <typename T, typename ReturnType, typename... Args>
class KernelCallback {
public:
using ReturnT = RetureType;
using FuncType = RetureType (*)(Args...);
// Return type of the wrapped function.
using ReturnT = ReturnType;
// Pointer-to-function type that this callback wraps.
using FuncType = ReturnType (*)(Args...);
// Default construction does not initialize `func`; Run() must not be called
// on an object built this way.
KernelCallback() {}
// Wraps `func_`; no null check is performed.
explicit KernelCallback(FuncType func_) : func(func_) {}
virtual ~KernelCallback() {}
RetureType Run(Args... args) { return func(args...); }
// Calls the wrapped function with `args` and returns whatever it returns.
ReturnType Run(Args... args) { return func(args...); }
private:
// The wrapped function pointer.
FuncType func;
};
// NOTE(review): old (`RetureType`) and new (`ReturnType`) versions of this
// helper appear interleaved here, as in a diff without +/- markers.
//
// Builds a KernelCallback<T, ReturnType, Args...> around the function pointer
// `cb`. T cannot be deduced from `cb` and has to be supplied explicitly by the
// caller; ReturnType and Args are deduced from the pointer's signature.
template <typename T, typename RetureType, typename... Args>
static KernelCallback<T, RetureType, Args...> MakeCallback(
RetureType (*cb)(Args...)) {
return KernelCallback<T, RetureType, Args...>(cb);
template <typename T, typename ReturnType, typename... Args>
static KernelCallback<T, ReturnType, Args...> MakeCallback(
ReturnType (*cb)(Args...)) {
// Returned by value; the callback object only holds the pointer `cb`.
return KernelCallback<T, ReturnType, Args...>(cb);
}
template <typename T, typename KernelType>
......@@ -54,10 +54,11 @@ class AutoTuneBase {
kernels_.push_back(/*default=*/kernel);
}
// NOTE(review): two versions of AddCallBack appear interleaved below (a diff
// without +/- markers). The first signature takes an already-built KernelType;
// the templated one takes a raw function pointer and wraps it itself.
//
// Appends a kernel to `kernels_` while `is_init_` is false. Once `is_init_` is
// true the call does nothing and reports nothing to the caller.
void AddCallBack(KernelType kernel) {
template <typename ReturnType, typename... Args>
void AddCallBack(ReturnType (*func)(Args...)) {
// `is_init_` is read before `mutex_` is acquired.
if (!is_init_) {
std::lock_guard<std::mutex> lock(mutex_);
kernels_.push_back(kernel);
// Wrap the raw pointer with MakeCallback<T> before storing it.
kernels_.push_back(MakeCallback<T>(func));
}
}
......@@ -142,31 +143,35 @@ class AutoTuneBase {
}
};
// NOTE(review): the `RetureType` and `ReturnType` signatures below appear to
// be the removed and added sides of a diff; only one of them opens the body.
//
// Creates an AutoTuneBase (returned by value) whose kernel type is the
// KernelCallback wrapping `func`.
template <typename T, typename RetureType, typename... Args>
static AutoTuneBase<T, KernelCallback<T, RetureType, Args...>> MakeAutoTuner(
RetureType (*func)(Args...)) {
template <typename T, typename ReturnType, typename... Args>
static AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> MakeAutoTuner(
ReturnType (*func)(Args...)) {
// Wrap the function pointer first so decltype(obj) names the kernel type.
auto obj = MakeCallback<T>(func);
return AutoTuneBase<T, decltype(obj)>(obj);
}
// NOTE(review): this class is shown with old (`KernelType`-based) and new
// (`ReturnType, Args...`-based) lines interleaved, as in a diff whose +/-
// markers were lost.
//
// Provides Instance(), which lazily creates one AutoTuneBase object per
// template instantiation and hands out a non-owning pointer to it.
template <typename T, typename KernelType>
class TransposeAutoTuner : public AutoTuneBase<T, KernelType> {
template <typename T, typename ReturnType, typename... Args>
class TransposeAutoTuner
: public AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>> {
public:
static AutoTuneBase<T, KernelType>* Instance(KernelType kernel) {
// Returns the shared tuner, creating it from `func` on the first call only;
// later calls ignore their `func` argument.
static AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>>* Instance(
ReturnType (*func)(Args...)) {
static std::once_flag transpose_init_flag_;
static std::unique_ptr<AutoTuneBase<T, KernelType>> instance_;
// Owns the tuner; the pointer returned below stays owned by this static.
static std::unique_ptr<
AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>>>
instance_;
// std::call_once runs the initializer a single time across all callers.
std::call_once(transpose_init_flag_, [&] {
instance_.reset(new AutoTuneBase<T, KernelType>(kernel));
// The object created is an AutoTuneBase, not a TransposeAutoTuner.
auto obj = MakeCallback<T>(func);
instance_.reset(new AutoTuneBase<T, decltype(obj)>(obj));
});
return instance_.get();
}
};
// NOTE(review): old and new versions of this helper appear interleaved (diff
// without +/- markers). The older body wraps `func` itself and passes the
// callback object to Instance(); the newer body passes `func` straight through.
//
// Returns the pointer produced by TransposeAutoTuner<...>::Instance for `func`.
// T has to be supplied explicitly; ReturnType and Args are deduced from `func`.
template <typename T, typename RetureType, typename... Args>
static AutoTuneBase<T, KernelCallback<T, RetureType, Args...>>*
MakeTransposeTuner(RetureType (*func)(Args...)) {
auto obj = MakeCallback<T>(func);
return TransposeAutoTuner<T, decltype(obj)>::Instance(obj);
template <typename T, typename ReturnType, typename... Args>
static AutoTuneBase<T, KernelCallback<T, ReturnType, Args...>>*
MakeTransposeTuner(ReturnType (*func)(Args...)) {
return TransposeAutoTuner<T, ReturnType, Args...>::Instance(func);
}
} // namespace autotune
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册