未验证 提交 7c5dca9f 编写于 作者: L limingshu 提交者: GitHub

add_autotune_kernel_tool (#40658)

* for 1st time interface combine.

* modification with kernel factory

* first auto_tune version.

* first version.

* basic version

* add warm up step.

* a debug version.

* optimize the functionality of class auto_tuner.

* add some quotes for optimized auto_tuner class.

* add some quotes for optimized auto_tuner class.

* add namespace.

* modification according to the advices

* replace fluid header with phi header.

* replace fluid header with phi header.
上级 dea24544
if (WITH_GPU)
nv_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest)
nv_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest)
elseif (WITH_ROCM)
hip_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest)
hip_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest)
endif()
cc_test(cache_test SRCS cache_test.cc DEPS gtest)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include "glog/logging.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/autotune/gpu_timer.h"
namespace phi {
namespace autotune {
template <typename RetureType, typename... Args>
class KernelCallback {
public:
using ReturnT = RetureType;
using FuncType = RetureType (*)(Args...);
KernelCallback() {}
explicit KernelCallback(FuncType func_) : func(func_) {}
virtual ~KernelCallback() {}
RetureType Call(Args... args) { return func(args...); }
private:
FuncType func;
};
template <typename RetureType, typename... Args>
static KernelCallback<RetureType, Args...> MakeCallback(
RetureType (*cb)(Args...)) {
return KernelCallback<RetureType, Args...>(cb);
}
template <typename KernelType>
class AutoTuneBase {
public:
AutoTuneBase() {}
virtual ~AutoTuneBase() {}
explicit AutoTuneBase(KernelType kernel) : default_kernel_(kernel) {
kernels_.push_back(kernel);
}
template <typename T>
void AddCallBack(T kernel) {
static_assert(std::is_same<T, KernelType>::value, "Type must be the same");
kernels_.push_back(kernel);
}
template <typename Context, typename... Args>
KernelType PickBestKernel(const Context& ctx, Args&&... args) {
PADDLE_ENFORCE_GT(
kernels_.size(),
0,
paddle::platform::errors::InvalidArgument(
"kernel num must be greater than 0, now is %d", kernels_.size()));
int idx = 0;
phi::GpuTimer timer;
float min_time = std::numeric_limits<float>::max();
for (int i = 0; i < kernels_.size(); ++i) {
ctx.Wait();
timer.Start(0);
kernels_[i].Call(args...);
timer.Stop(0);
auto time = timer.ElapsedTime();
VLOG(3) << "kernel[" << i << "]: time cost is " << time;
if (time < min_time) {
min_time = time;
idx = i;
}
}
VLOG(3) << "best kernel idx is " << idx;
return kernels_[idx];
}
private:
KernelType default_kernel_;
std::vector<KernelType> kernels_;
};
template <typename RetureType, typename... Args>
static AutoTuneBase<KernelCallback<RetureType, Args...>> MakeAutoTuner(
RetureType (*func)(Args...)) {
auto obj = MakeCallback(func);
return AutoTuneBase<decltype(obj)>(obj);
}
} // namespace autotune
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "glog/logging.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
namespace tune = phi::autotune;
template <typename T, int VecSize>
__global__ void VecSumTest(const T* x, T* y, int N) {
#ifdef __HIPCC__
int idx = hipBlockDim_x * hipBlockIdx_x + hipThreadIdx_x;
#else
int idx = blockDim.x * blockIdx.x + threadIdx.x;
#endif
using LoadT = phi::AlignedVector<T, VecSize>;
for (int i = idx * VecSize; i < N; i += blockDim.x * gridDim.x * VecSize) {
LoadT x_vec;
LoadT y_vec;
phi::Load<T, VecSize>(&x[i], &x_vec);
phi::Load<T, VecSize>(&y[i], &y_vec);
#pragma unroll
for (int j = 0; j < VecSize; j++) {
y_vec[j] = x_vec[j] + y_vec[j];
}
phi::Store<T, VecSize>(y_vec, &y[i]);
}
}
template <int Vecsize>
float Algo(const phi::GPUContext& ctx,
const phi::DenseTensor& d_in,
phi::DenseTensor* d_out,
size_t N,
size_t threads,
size_t blocks) {
const float* d_in_data = d_in.data<float>();
float* d_out_data = d_out->data<float>();
#ifdef __HIPCC__
hipLaunchKernelGGL(HIP_KERNEL_NAME(VecSumTest<float, Vecsize>),
dim3(blocks),
dim3(threads),
0,
0,
d_in_data,
d_out_data,
N);
#else
VLOG(3) << "Vecsize is " << Vecsize;
VecSumTest<float, Vecsize><<<blocks, threads, 0, ctx.stream()>>>(
d_in_data, d_out_data, N);
#endif
return Vecsize;
}
TEST(AutoTune, sum) {
int64_t N = 1 << 22;
size_t blocks = 512;
size_t threads = 256;
size_t size = sizeof(float) * N;
const auto alloc_cpu =
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
auto in1 = std::make_shared<phi::DenseTensor>(
alloc_cpu.get(),
phi::DenseTensorMeta(
phi::DataType::FLOAT32, phi::make_ddim({N}), phi::DataLayout::NCHW));
auto in2 = std::make_shared<phi::DenseTensor>(
alloc_cpu.get(),
phi::DenseTensorMeta(
phi::DataType::FLOAT32, phi::make_ddim({N}), phi::DataLayout::NCHW));
float* in1_data = in1->data<float>();
float* in2_data = in2->data<float>();
for (size_t i = 0; i < N; i++) {
in1_data[i] = 1.0f;
in2_data[i] = 2.0f;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
const auto alloc_cuda =
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CUDAPlace());
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto place = paddle::platform::CUDAPlace();
auto* dev_ctx = static_cast<const phi::GPUContext*>(pool.GetByPlace(place));
auto stream = dev_ctx->stream();
auto d_in1 = std::make_shared<phi::DenseTensor>(
alloc_cuda.get(),
phi::DenseTensorMeta(
phi::DataType::FLOAT32, phi::make_ddim({N}), phi::DataLayout::NCHW));
auto d_in2 = std::make_shared<phi::DenseTensor>(
alloc_cuda.get(),
phi::DenseTensorMeta(
phi::DataType::FLOAT32, phi::make_ddim({N}), phi::DataLayout::NCHW));
phi::Copy(*dev_ctx, *in1.get(), phi::GPUPlace(), false, d_in1.get());
phi::Copy(*dev_ctx, *in2.get(), phi::GPUPlace(), false, d_in2.get());
// 1. Test call_back.
VLOG(3) << ">>> [CallBack]: Test case.";
auto callback1 = tune::MakeCallback(Algo<4>);
auto callback2 = tune::MakeCallback(Algo<2>);
auto callback3 = tune::MakeCallback(Algo<1>);
std::vector<decltype(callback1)> callbacks{callback1, callback2, callback3};
for (int i = 0; i < callbacks.size(); ++i) {
dev_ctx->Wait();
phi::GpuTimer timer;
timer.Start(0);
callbacks[i].Call(*dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks);
timer.Stop(0);
VLOG(3) << "kernel[" << i << "]: time cost is " << timer.ElapsedTime();
}
// 2. Test call_back tune.
VLOG(3) << ">>> [AutoTune]: Test case.";
auto tuner = tune::MakeAutoTuner(Algo<4>);
tuner.AddCallBack(tune::MakeCallback(Algo<2>));
tuner.AddCallBack(tune::MakeCallback(Algo<1>));
/* The 1st ctx works for ctx.Wait(),
the 2nd is just the param of call_back. */
auto best_call_back = tuner.PickBestKernel(
*dev_ctx, *dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks);
best_call_back.Call(*dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks);
dev_ctx->Wait();
phi::GpuTimer timer;
timer.Start(0);
best_call_back.Call(*dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks);
timer.Stop(0);
VLOG(3) << "Best CallBackKernel time cost is " << timer.ElapsedTime();
#endif
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册