Unverified · Commit 71a63f0a authored by limingshu, committed by GitHub

Transpose optimization with assistance of Chengdu Supercomputing Center and auto_tune operation (#42704)
Parent e74f287b
......@@ -17,8 +17,12 @@ limitations under the License. */
#include "paddle/fluid/framework/gpu_utils.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
#include "paddle/phi/kernels/autotune/cache.h"
#include "paddle/phi/kernels/copy_kernel.h"
namespace paddle {
namespace operators {
......@@ -656,13 +660,437 @@ struct TransposeSimple {
}
};
template <int N, typename T>
class IdxHelper {
public:
IdxHelper() {}
explicit IdxHelper(const T* dims) {
for (int i = N - 1; i >= 0; --i) {
stride_[i] = i < (N - 1) ? dims[i + 1] * stride_[i + 1] : 1;
}
}
__device__ inline T GetStride(int idx) const { return stride_[idx]; }
__device__ inline void GetIndexFromOffset(T offset, T* index) const {
T remaining = offset;
#pragma unroll
for (int i = 0; i < N - 1; ++i) {
const T idx = remaining / stride_[i];
remaining -= idx * stride_[i];
index[i] = idx;
}
index[N - 1] = remaining;
}
private:
T stride_[N];
};
template <int N>
class IdxHelper<N, uint32_t> {
public:
IdxHelper() {}
explicit IdxHelper(const uint32_t* dims) {
for (int i = N - 1; i >= 0; --i) {
uint32_t value = i < (N - 1) ? dims[i + 1] * stride_[i + 1] : 1;
divmoder_[i] = paddle::platform::FastDivMod(value);
stride_[i] = value;
}
}
__device__ inline uint32_t GetStride(int idx) const { return stride_[idx]; }
__device__ inline void GetIndexFromOffset(uint32_t offset,
uint32_t* index) const {
uint32_t remaining = offset;
#pragma unroll
for (int i = 0; i < N - 1; ++i) {
uint32_t idx = divmoder_[i].Div(remaining);
index[i] = idx;
remaining -= idx * stride_[i];
}
index[N - 1] = remaining;
}
private:
uint32_t stride_[N];
paddle::platform::FastDivMod divmoder_[N];
};
// Transform index between memory offset and shape coordinate.
template <typename T, int N>
class IdxAndOffsetHelper {
public:
IdxAndOffsetHelper() {}
~IdxAndOffsetHelper() = default;
explicit IdxAndOffsetHelper(const T* dims) {
index_helper = IdxHelper<N, T>(dims);
}
template <typename U>
explicit IdxAndOffsetHelper(const U* dims) {
T temp_dims[N];
for (int i = 0; i < N; ++i) {
temp_dims[i] = static_cast<T>(dims[i]);
}
index_helper = IdxHelper<N, T>(temp_dims);
}
__device__ inline T IndexToOffset(const T* index) const {
T offset = 0;
#pragma unroll
for (int i = 0; i < N - 1; ++i) {
offset += index[i] * index_helper.GetStride(i);
}
offset += index[N - 1];
return offset;
}
__device__ inline void OffsetToIndex(T offset, T* index) const {
index_helper.GetIndexFromOffset(offset, index);
}
private:
IdxHelper<N, T> index_helper;
};
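For intuition (an illustration only, not part of the commit): for a row-major shape [2, 3, 4], the strides computed by IdxHelper are [12, 4, 1], so offset 17 maps to coordinate (1, 1, 1) and back. The sketch below re-runs the same arithmetic on the host; the shape and offset are assumed example values.

#include <cstdio>

// Host-side re-run of the stride arithmetic used by IdxAndOffsetHelper,
// for illustration only; shape [2, 3, 4] and offset 17 are assumed values.
int main() {
  const int dims[3] = {2, 3, 4};
  int stride[3];
  stride[2] = 1;
  for (int i = 1; i >= 0; --i) stride[i] = dims[i + 1] * stride[i + 1];
  // stride == {12, 4, 1}

  // OffsetToIndex: offset 17 -> coordinate (1, 1, 1).
  int index[3], remaining = 17;
  for (int i = 0; i < 2; ++i) {
    index[i] = remaining / stride[i];
    remaining -= index[i] * stride[i];
  }
  index[2] = remaining;

  // IndexToOffset: coordinate (1, 1, 1) -> offset 17.
  int offset = index[0] * stride[0] + index[1] * stride[1] + index[2];
  printf("index = (%d, %d, %d), offset = %d\n",
         index[0], index[1], index[2], offset);
  return 0;
}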
template <size_t Rank, typename IndexT>
struct PermuteParams {
public:
IdxAndOffsetHelper<IndexT, Rank> src_index_helper;
IdxAndOffsetHelper<IndexT, Rank> dst_index_helper;
int perm[Rank]{};
explicit PermuteParams(const std::vector<size_t>& dims,
const std::vector<int>& perm_) {
size_t dst_dims[Rank];
for (size_t i = 0; i < Rank; ++i) {
dst_dims[i] = dims[perm_[i]];
perm[i] = perm_[i];
}
dst_index_helper = IdxAndOffsetHelper<IndexT, Rank>(dst_dims);
src_index_helper = IdxAndOffsetHelper<IndexT, Rank>(dims.data());
}
};
// A specialized kernel for the target case, where both vectorized read and write are supported.
template <typename T, typename IndexT, int VecSize, int Rank>
__global__ void VectorizedPermuteKernel(PermuteParams<Rank, IndexT> params,
const size_t count,
const T* __restrict__ src_data,
T* dst_data) {
using VecT = phi::AlignedVector<T, VecSize>;
IndexT src_index[Rank];
IndexT dst_index[Rank];
const VecT* __restrict__ src =
reinterpret_cast<const VecT* __restrict__>(src_data);
VecT* dst = reinterpret_cast<VecT*>(dst_data);
IndexT tid = blockIdx.x * blockDim.x + threadIdx.x;
for (IndexT i = tid; i < count; i += blockDim.x * gridDim.x) {
params.dst_index_helper.OffsetToIndex(i, dst_index);
#pragma unroll
for (int j = 0; j < Rank; ++j) {
src_index[params.perm[j]] = dst_index[j];
}
IndexT src_offset = params.src_index_helper.IndexToOffset(src_index);
dst[i] = src[src_offset];
}
}
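The vectorization pattern used above, in generic form (a stand-in sketch, not phi::AlignedVector itself; Vec4 and CopyVectorized are assumed names): packing VecSize elements into one aligned struct lets a single assignment compile into a wide 128-bit load and store.

// Generic illustration of vectorized memory access; not part of the Paddle
// code base.
struct alignas(16) Vec4 {
  float val[4];
};

void CopyVectorized(const float* __restrict__ src, float* __restrict__ dst,
                    int count /* multiple of 4 */) {
  const Vec4* vec_src = reinterpret_cast<const Vec4*>(src);
  Vec4* vec_dst = reinterpret_cast<Vec4*>(dst);
  for (int i = 0; i < count / 4; ++i) {
    vec_dst[i] = vec_src[i];  // one 16-byte load and one 16-byte store
  }
}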
// A general kernel for the normal case; only vectorized write is supported.
template <typename T, typename IndexT, int VecSize, int Rank>
__global__ void GeneralPermuteKernel(PermuteParams<Rank, IndexT> params,
const T* __restrict__ src, T* dst,
const size_t main_cnt,
const size_t tail_cnt,
const size_t offset) {
using VecT = phi::AlignedVector<T, VecSize>;
VecT* vec_dst = reinterpret_cast<VecT*>(dst);
IndexT src_index[VecSize][Rank];
IndexT dst_index[VecSize][Rank];
// Cache perm in shared memory so that both load paths below avoid re-reading it.
__shared__ int perm[Rank];
if (threadIdx.x < Rank) {
perm[threadIdx.x] = params.perm[threadIdx.x];
}
__syncthreads();
// Vectorized load data.
IndexT tid = blockIdx.x * blockDim.x + threadIdx.x;
for (IndexT idx = tid; idx < main_cnt; idx += blockDim.x * gridDim.x) {
VecT vec_data;
IndexT vec_idx = idx * VecSize;
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
params.dst_index_helper.OffsetToIndex(vec_idx + i, dst_index[i]);
#pragma unroll
for (int j = 0; j < Rank; ++j) {
src_index[i][perm[j]] = dst_index[i][j];
}
IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[i]);
vec_data[i] = src[src_offset];
}
vec_dst[idx] = vec_data;
}
// Scalar load for the tail elements.
if (tid < tail_cnt) {
IndexT idx = tid + offset;
params.dst_index_helper.OffsetToIndex(idx, dst_index[0]);
#pragma unroll
for (int j = 0; j < Rank; ++j) {
src_index[0][perm[j]] = dst_index[0][j];
}
IndexT src_offset = params.src_index_helper.IndexToOffset(src_index[0]);
dst[idx] = src[src_offset];
}
}
// A general permute method that directly finds the dst data
// coordinate in the source data.
template <typename T, typename IndexT, int VecSize, int Rank>
inline void LaunchPermuteKernel(const phi::GPUContext& ctx, const IndexT count,
const PermuteType perm_type,
const std::vector<size_t>& dims,
const std::vector<int>& perm, const T* src,
T* dst) {
size_t main_count = count / VecSize;
auto params = PermuteParams<Rank, IndexT>(dims, perm);
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, main_count);
if (perm_type == PermuteType::kNormalPermute) {
size_t tail_count = count - main_count * VecSize;
size_t offset = count - tail_count;
GeneralPermuteKernel<
T, IndexT, VecSize,
Rank><<<config.GetGridSize(), config.GetBlockSize(), 0, ctx.stream()>>>(
params, src, dst, main_count, tail_count, offset);
} else {
VectorizedPermuteKernel<
T, IndexT, VecSize,
Rank><<<config.GetGridSize(), config.GetBlockSize(), 0, ctx.stream()>>>(
params, main_count, src, dst);
}
}
template <typename T, typename IndexT, int VecSize>
inline void LaunchPermuteRankDispatch(const phi::GPUContext& ctx,
const IndexT count,
const PermuteType perm_type,
const std::vector<size_t>& dims,
const std::vector<int>& perm,
const T* src, T* dst) {
#define CALL_DISPATCH_RANK(rank) \
case rank: { \
LaunchPermuteKernel<T, IndexT, VecSize, rank>(ctx, count, perm_type, dims, \
perm, src, dst); \
break; \
}
switch (dims.size()) {
CALL_DISPATCH_RANK(1);
CALL_DISPATCH_RANK(2);
CALL_DISPATCH_RANK(3);
CALL_DISPATCH_RANK(4);
CALL_DISPATCH_RANK(5);
CALL_DISPATCH_RANK(6);
CALL_DISPATCH_RANK(7);
CALL_DISPATCH_RANK(8);
CALL_DISPATCH_RANK(9);
}
#undef CALL_DISPATCH_RANK
}
// Aims at transposing the last 2 dimensions. Adapted from
// https://developer.nvidia.com/blog/efficient-matrix-transpose-cuda-cc/
template <typename T, typename IndexT, int VecSize>
__global__ void BatchTransposeKernel(const T* __restrict__ src_data,
T* dst_data, IndexT rows, IndexT cols) {
using VecT = phi::AlignedVector<T, VecSize>;
__shared__ VecT tile[kTileSize][kShareCol];
T* single_tile = reinterpret_cast<T*>(tile);
IndexT col_in_matrix = blockIdx.x * kTileSize + threadIdx.x;
IndexT offset = blockIdx.z * rows * cols;
// Vectorized load data from src into shared memory. [rows, cols]
const VecT* __restrict__ src =
reinterpret_cast<const VecT* __restrict__>(src_data);
for (IndexT tile_y = threadIdx.y; tile_y < kTileSize; tile_y += kBlockRows) {
IndexT row_in_matrix = tile_y + blockIdx.y * kTileSize;
if (col_in_matrix < cols && row_in_matrix < rows) {
tile[tile_y][threadIdx.x] =
src[offset + row_in_matrix * cols + col_in_matrix];
}
}
// Scalar store from shared memory into dst,
// where dst_cols = rows, dst_rows = cols, i.e. [cols * VecSize, rows].
col_in_matrix = blockIdx.y * kTileSize + threadIdx.x;
offset = offset * VecSize + col_in_matrix;
IndexT tile_x_idx = threadIdx.x * (kShareCol * VecSize);
__syncthreads();
for (IndexT tile_y = threadIdx.y; tile_y < kTileSize; tile_y += kBlockRows) {
IndexT row_in_matrix = tile_y + blockIdx.x * kTileSize;
IndexT dst_idx = offset + row_in_matrix * VecSize * rows;
IndexT tile_idx = tile_x_idx + tile_y * VecSize;
if (col_in_matrix < /*dst_cols=*/rows &&
row_in_matrix < /*dst_rows=*/cols) {
#pragma unroll
for (auto i = 0; i < VecSize; ++i) {
dst_data[dst_idx + i * rows] = single_tile[tile_idx + i];
}
}
}
}
// Due to the shared-memory size limit, vectorization (VecSize > 1) is only
// enabled for types whose byte size is no greater than 8.
template <typename T, typename IndexT, int Size,
int VecSize = (sizeof(T) > 8 ? 1 : Size)>
inline void LaunchTransposeKernel(const phi::GPUContext& ctx,
const std::vector<size_t>& dims, const T* src,
T* dst) {
auto rank = dims.size();
IndexT num_batches = (rank == 2) ? 1 : dims[0];
IndexT rows = dims[rank - 2];
IndexT cols = dims[rank - 1];
IndexT num_tile_rows = (rows + kTileSize - 1) / kTileSize;
IndexT num_tile_cols = (cols + kTileSize - 1) / kTileSize;
dim3 blocks(num_tile_cols, num_tile_rows, num_batches);
dim3 threads(kTileSize, kBlockRows, 1);
BatchTransposeKernel<T, IndexT,
VecSize><<<blocks, threads, 0, ctx.stream()>>>(
src, dst, rows, cols);
}
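For a concrete sense of the launch geometry (an illustration with assumed values, not part of the commit): with dims = [3, 100, 64], kTileSize = 32 and kBlockRows = 16, the launcher above builds a grid of 2 x 4 x 3 blocks of 32 x 16 threads, so each block transposes one 32 x 32 tile and each thread covers kTileSize / kBlockRows = 2 rows of that tile.

#include <cstdio>

// Reproduces the launch-shape arithmetic of LaunchTransposeKernel for an
// assumed input of dims = [3, 100, 64]; illustration only.
int main() {
  const int kTileSize = 32, kBlockRows = 16;
  const int num_batches = 3, rows = 100, cols = 64;
  const int num_tile_rows = (rows + kTileSize - 1) / kTileSize;  // 4
  const int num_tile_cols = (cols + kTileSize - 1) / kTileSize;  // 2
  printf("grid = (%d, %d, %d), block = (%d, %d, 1)\n",
         num_tile_cols, num_tile_rows, num_batches, kTileSize, kBlockRows);
  return 0;
}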
template <typename T, typename IndexT>
inline void LaunchWithDispatchVecSize(const phi::GPUContext& ctx,
const int vec_size,
const PermuteType perm_type,
const std::vector<size_t>& dims,
const std::vector<int>& perm,
const T* src, T* dst, IndexT count) {
#define CALL_DISPATCH_VEC_SIZE(vec_size) \
case vec_size: { \
if (perm_type == PermuteType::kTranspose) { \
LaunchTransposeKernel<T, IndexT, vec_size>(ctx, dims, src, dst); \
} else { \
LaunchPermuteRankDispatch<T, IndexT, vec_size>(ctx, count, perm_type, \
dims, perm, src, dst); \
} \
break; \
}
switch (vec_size) {
CALL_DISPATCH_VEC_SIZE(1);
CALL_DISPATCH_VEC_SIZE(2);
CALL_DISPATCH_VEC_SIZE(4);
default: {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));
break;
}
}
#undef CALL_DISPATCH_VEC_SIZE
}
template <typename T>
void TransposeGPUKernelDriver(const phi::GPUContext& dev_ctx, const int ndims,
inline void LaunchWithDispatchIndex(const phi::GPUContext& ctx,
const size_t count, const int vec_size,
const PermuteType perm_type,
const std::vector<size_t>& dims,
const std::vector<int>& perm, const T* src,
T* dst) {
if (count < std::numeric_limits<uint32_t>::max()) {
LaunchWithDispatchVecSize<T, uint32_t>(ctx, vec_size, perm_type, dims, perm,
src, dst,
static_cast<uint32_t>(count));
} else {
int64_t cnt = static_cast<int64_t>(count);
LaunchWithDispatchVecSize<T, int64_t>(ctx, vec_size, perm_type, dims, perm,
src, dst,
static_cast<int64_t>(count));
}
}
template <typename DeviceContext, typename T>
inline void SimplifyThenLaunch(const int rank, const DeviceContext& ctx,
const Tensor& in, Tensor* out,
const std::vector<int32_t>& perm) {
int sm_count = ctx.GetSMCount();
auto src_dims = phi::vectorize<size_t>(in.dims());
auto simplifier = DimsSimplifier<T>(sm_count, rank, perm, src_dims,
in.data<T>(), out->data<T>());
if (simplifier.GetPermType() == PermuteType::kCopy) {
// If perm is [0,1,2,3], then just operate a DtoD copy.
phi::Copy(ctx, in, ctx.GetPlace(), false, out);
} else {
LaunchWithDispatchIndex<T>(
ctx, simplifier.GetCount(), simplifier.GetVecSize(),
simplifier.GetPermType(), simplifier.GetDims(), simplifier.GetPerm(),
in.data<T>(), out->data<T>());
}
}
template <typename T>
size_t GetTransposeKey(const int rank, const Tensor& in,
const std::vector<int32_t>& perm) {
auto in_shape = phi::vectorize(in.dims());
return phi::autotune::GetKey(
in_shape, perm, rank, paddle::experimental::CppTypeToDataType<T>::Type());
}
template <typename T>
void TransposeGPUKernelDriver(const phi::GPUContext& dev_ctx, const int rank,
const Tensor& in,
const std::vector<int32_t>& perm, Tensor* out) {
PADDLE_ENFORCE_LT(
rank, phi::DDim::kMaxRank,
platform::errors::OutOfRange(
"The maximum dimension rank of "
"tensor is expected to be less than %d, but here is %d.",
phi::DDim::kMaxRank, rank));
auto ret = TransposeSimple<T>::run(dev_ctx, in, perm, out);
if (!ret) {
TransCompute<phi::GPUContext, T>(ndims, dev_ctx, in, out, perm);
auto* tuner = phi::autotune::MakeTransposeTuner<T>(
SimplifyThenLaunch<phi::GPUContext, T>);
if (!tuner->IsInit()) {
tuner->AddCallBack(
phi::autotune::MakeCallback<T>(TransCompute<phi::GPUContext, T>));
tuner->Finalize();
}
auto key = GetTransposeKey<T>(rank, in, perm);
auto& cache = phi::autotune::AutoTuneCache::Instance().GetTranspose();
if (cache.Find(key)) {
auto index = cache.Get(key);
tuner->RunBestKernel(index, rank, dev_ctx, in, out, perm);
} else {
// All available kernels have already run while picking the best kernel,
// so there may be no need for another RunBestKernel here.
auto index = tuner->PickBestKernel(dev_ctx, rank, dev_ctx, in, out, perm);
cache.Set(key, index);
}
}
}
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
......@@ -60,5 +61,182 @@ inline void TransCompute(const int dim, const DeviceContext& dev_ctx,
}
}
enum PermuteType {
kCopy = 1,
kTranspose = 2,
kVecPermute = 3,
kNormalPermute = 4
};
constexpr int kBlockRows = 16;
constexpr int kTileSize = 32;
// Pad one extra column to avoid shared-memory bank conflicts.
constexpr int kShareCol = kTileSize + 1;
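Why the extra column helps (an illustration only, assuming the common configuration of 32 shared-memory banks of 4-byte words and a tile of plain floats; the real tile stores VecT elements, but the same reasoning applies): with a row width of 32, every element of a tile column lands in the same bank, so reading a column is a 32-way conflict; padding the row to 33 elements shifts each row by one bank and spreads the column across all banks.

#include <cstdio>

// Bank index of tile[row][0] for 4-byte elements, assuming 32 banks.
// Width 32 -> banks 0 0 0 0 ... (conflict); width 33 -> banks 0 1 2 3 ...
int main() {
  const int kBanks = 32;
  for (int width : {32, 33}) {
    printf("row width %d, column-0 banks:", width);
    for (int row = 0; row < 4; ++row) {
      printf(" %d", (row * width) % kBanks);
    }
    printf(" ...\n");
  }
  return 0;
}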
// Simplify the input dims and permute dims if possible.
template <typename T>
class DimsSimplifier {
public:
explicit DimsSimplifier(const int sm_count, const int rank,
const std::vector<int32_t>& perm,
const std::vector<size_t>& dims, const T* src, T* dst)
: perm_(rank), dims_(rank) {
SimplifyPermAndDims(rank, dims, perm);
count_ = std::accumulate(dims.begin(), dims.end(), size_t{1},
std::multiplies<size_t>());
if (rank_ > 1) {
vec_size_ = GetPermVecSize(sm_count, src, dst);
perm_.resize(rank_);
dims_.resize(rank_);
}
}
size_t GetCount() const { return count_; }
int GetVecSize() const { return vec_size_; }
PermuteType GetPermType() const { return type_; }
std::vector<int> GetPerm() const { return perm_; }
std::vector<size_t> GetDims() const { return dims_; }
private:
size_t rank_{1};
size_t count_{0};
int vec_size_{1};
std::vector<int> perm_;
std::vector<size_t> dims_;
PermuteType type_{kCopy};
void SimplifyPermAndDims(const size_t rank,
const std::vector<size_t>& in_dims,
const std::vector<int32_t>& perm) {
size_t combined_dims[phi::DDim::kMaxRank];
int valid_map[phi::DDim::kMaxRank];
// Merge each run of consecutive dims into its first dim,
// and set the merged-away dims to 1. Example:
// perm: [2, 3, 0, 1], origin_dims: [4, 8, 2, 5]
// new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1]
size_t start_perm_idx = 0;
while (start_perm_idx < rank) {
const size_t start_dim_idx = perm[start_perm_idx];
combined_dims[start_dim_idx] = in_dims[start_dim_idx];
size_t end_perm_idx = start_perm_idx + 1;
while (end_perm_idx < rank &&
perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) {
const size_t end_dim_idx = perm[end_perm_idx];
combined_dims[start_dim_idx] *= in_dims[end_dim_idx];
combined_dims[end_dim_idx] = 1;
end_perm_idx += 1;
}
start_perm_idx = end_perm_idx;
}
// Reorder the combined dims and mark useless dims as -1.
// For example, if the combined dims are [32, 1, 10, 1],
// valid_map is [0, -1, 1, -1] and the simplified
// dims become [32, 10].
size_t valid_dim_idx = 0;
bool sequential_flag = false;
for (size_t i = 0; i < rank; ++i) {
const int src_dim = combined_dims[i];
if (src_dim == 1) {
valid_map[i] = -1;
} else {
sequential_flag = true;
valid_map[i] = valid_dim_idx;
dims_[valid_dim_idx] = src_dim;
valid_dim_idx += 1;
}
}
if (valid_dim_idx == 0) {
dims_[0] = 1;
perm_[0] = 0;
return;
} else if (valid_dim_idx == 1) {
type_ = PermuteType::kCopy;
}
// Acquire the simplified perm with the help of the combined dims
// and the original perm; for the example above the simplified perm is [1, 0].
size_t perm_idx = 0;
for (size_t i = 0; i < rank; ++i) {
const int mapped = valid_map[perm[i]];
if (mapped >= 0) {
perm_[perm_idx] = mapped;
perm_idx += 1;
}
}
rank_ = valid_dim_idx;
}
int GetPermVecSize(const int sm_count, const T* src, T* dst) {
// For the general permute kernel, there is a good chance of
// vectorized write.
int vec_size = phi::GetVectorizedSize<T>(dst);
type_ = PermuteType::kNormalPermute;
// When the last dim stays in place, there is a good chance of
// both vectorized read and write.
if (perm_[rank_ - 1] == rank_ - 1) {
int tmp_size = std::min(vec_size, phi::GetVectorizedSize<T>(src));
tmp_size = GetDimVesSize(tmp_size, dims_[rank_ - 1]);
if (tmp_size > 1) {
type_ = kVecPermute;
vec_size = tmp_size;
// For stride calculation of src_data index.
dims_[rank_ - 1] /= vec_size;
}
}
// When only the last 2 dims are transposed, there is a good
// chance of vectorized read.
if ((rank_ == 2 && perm_[1] == 0 && perm_[0] == 1) ||
(rank_ == 3 && perm_[2] == 1 && perm_[1] == 2)) {
type_ = PermuteType::kTranspose;
// According to performance tests, keeping more SMs busy simultaneously
// matters more than vectorized load/store, so the config is set accordingly.
constexpr int threads = kTileSize * kTileSize;
int blocks = count_ / threads;
if (blocks < sm_count) {
vec_size = 1;
} else {
int tmp_vec = std::min(vec_size, phi::GetVectorizedSize<T>(src));
// Due to the shared-memory size limit, vectorization is only enabled
// for types no larger than 8 bytes (double).
int type_vec =
sizeof(T) > 8 ? 1 : GetDimVesSize(tmp_vec, dims_[rank_ - 1]);
for (int i = type_vec; i > 0; i /= 2) {
if (blocks / i >= sm_count) {
break;
}
// When blocks is smaller than sm_count, tests showed that decreasing
// vec_size to bring blocks closer to sm_count gains performance.
vec_size = i;
}
}
dims_[rank_ - 1] /= vec_size;
count_ /= vec_size;
}
return vec_size;
}
// Find the largest power-of-2 divisor of target_dim that does not exceed
// vec_size, e.g. GetDimVesSize(4, 10) returns 2.
int GetDimVesSize(const int vec_size, const size_t target_dim) {
int dim_vec_size = 1;
for (auto size = vec_size; size > 0; size /= 2) {
if (target_dim % size == 0) {
dim_vec_size = size;
break;
}
}
return dim_vec_size;
}
};
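Putting the pieces of DimsSimplifier together on the example from the comments above (a worked trace, not extra code in the commit):

// in_dims = [4, 8, 2, 5], perm = [2, 3, 0, 1]
//   -> combined_dims = [32, 1, 10, 1], valid_map = [0, -1, 1, -1]
//   -> dims_ = [32, 10], perm_ = [1, 0], rank_ = 2
// Since rank_ == 2 and perm_ == [1, 0], GetPermVecSize() classifies the case
// as PermuteType::kTranspose, and LaunchTransposeKernel handles it.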
} // namespace operators
} // namespace paddle
......@@ -59,8 +59,8 @@ struct FastDivMod {
return result;
}
int32_t divisor;
int32_t shift_val;
uint32_t divisor;
uint32_t multiplier;
};
......
......@@ -14,6 +14,7 @@
#pragma once
#include <mutex>
#include <type_traits>
#include "glog/logging.h"
......@@ -23,7 +24,7 @@
namespace phi {
namespace autotune {
template <typename RetureType, typename... Args>
template <typename T, typename RetureType, typename... Args>
class KernelCallback {
public:
using ReturnT = RetureType;
......@@ -33,71 +34,126 @@ class KernelCallback {
explicit KernelCallback(FuncType func_) : func(func_) {}
virtual ~KernelCallback() {}
RetureType Call(Args... args) { return func(args...); }
RetureType Run(Args... args) { return func(args...); }
private:
FuncType func;
};
template <typename RetureType, typename... Args>
static KernelCallback<RetureType, Args...> MakeCallback(
template <typename T, typename RetureType, typename... Args>
static KernelCallback<T, RetureType, Args...> MakeCallback(
RetureType (*cb)(Args...)) {
return KernelCallback<RetureType, Args...>(cb);
return KernelCallback<T, RetureType, Args...>(cb);
}
template <typename KernelType>
template <typename T, typename KernelType>
class AutoTuneBase {
public:
AutoTuneBase() {}
virtual ~AutoTuneBase() {}
explicit AutoTuneBase(KernelType kernel) : default_kernel_(kernel) {
explicit AutoTuneBase(KernelType kernel) { kernels_.push_back(kernel); }
template <typename Type>
void AddCallBack(Type kernel) {
static_assert(std::is_same<Type, KernelType>::value,
"Type must be the same");
kernels_.push_back(kernel);
}
template <typename T>
void AddCallBack(T kernel) {
static_assert(std::is_same<T, KernelType>::value, "Type must be the same");
kernels_.push_back(kernel);
template <typename... Args>
void RunBestKernel(const int idx, Args&&... args) {
kernels_[idx].Run(args...);
}
template <typename... Args>
void RunDefaultKernel(Args&&... args) {
kernels_[0].Run(args...);
}
template <typename Context, typename... Args>
KernelType PickBestKernel(const Context& ctx, Args&&... args) {
int PickBestKernel(const Context& ctx, Args&&... args) {
PADDLE_ENFORCE_GT(
kernels_.size(),
0,
paddle::platform::errors::InvalidArgument(
"kernel num must be greater than 0, now is %d", kernels_.size()));
int idx = 0;
phi::GpuTimer timer;
int best_idx = 0;
float min_time = std::numeric_limits<float>::max();
// Time cost test established in the default stream.
for (int i = 0; i < kernels_.size(); ++i) {
ctx.Wait();
timer.Start(0);
kernels_[i].Call(args...);
timer.Stop(0);
auto time = timer.ElapsedTime();
VLOG(3) << "kernel[" << i << "]: time cost is " << time;
auto time = RunAndMeasureKernel<Context>(ctx, i, args...);
if (time < min_time) {
min_time = time;
idx = i;
best_idx = i;
}
}
VLOG(3) << "best kernel idx is " << idx;
return kernels_[idx];
VLOG(3) << "best kernel idx is " << best_idx;
return best_idx;
}
bool IsInit() { return is_init_; }
void Finalize() { is_init_ = true; }
private:
KernelType default_kernel_;
bool is_init_{false};
std::vector<KernelType> kernels_;
template <typename Context, typename... Args>
float RunAndMeasureKernel(const Context& ctx, const int idx, Args&&... args) {
phi::GpuTimer timer;
float time_cost = 0;
const auto& stream = ctx.stream();
// Treat the 1st run as warm-up; judge the result by
// the sum of the 2nd and 3rd runs.
constexpr int repeats = 3;
ctx.Wait();
for (int i = 0; i < repeats; ++i) {
timer.Start(stream);
kernels_[idx].Run(args...);
timer.Stop(stream);
auto time = timer.ElapsedTime();
if (i > 0) {
time_cost += time;
}
VLOG(3) << "kernel[" << idx << "][" << i << "th time cost is " << time;
}
return time_cost;
}
};
template <typename RetureType, typename... Args>
static AutoTuneBase<KernelCallback<RetureType, Args...>> MakeAutoTuner(
template <typename T, typename RetureType, typename... Args>
static AutoTuneBase<T, KernelCallback<T, RetureType, Args...>> MakeAutoTuner(
RetureType (*func)(Args...)) {
auto obj = MakeCallback(func);
return AutoTuneBase<decltype(obj)>(obj);
auto obj = MakeCallback<T>(func);
return AutoTuneBase<T, decltype(obj)>(obj);
}
template <typename T, typename KernelType>
class TransposeAutoTuner : public AutoTuneBase<T, KernelType> {
public:
static AutoTuneBase<T, KernelType>* Instance(KernelType kernel) {
static std::unique_ptr<AutoTuneBase<T, KernelType>> instance_;
std::call_once(init_flag_, [&] {
instance_.reset(new AutoTuneBase<T, KernelType>(kernel));
});
return instance_.get();
}
private:
static std::once_flag init_flag_;
};
template <typename T, typename KernelType>
std::once_flag TransposeAutoTuner<T, KernelType>::init_flag_;
template <typename T, typename RetureType, typename... Args>
static AutoTuneBase<T, KernelCallback<T, RetureType, Args...>>*
MakeTransposeTuner(RetureType (*func)(Args...)) {
auto obj = MakeCallback<T>(func);
return TransposeAutoTuner<T, decltype(obj)>::Instance(obj);
}
} // namespace autotune
......
......@@ -74,7 +74,7 @@ float Algo(const phi::GPUContext& ctx,
}
TEST(AutoTune, sum) {
int64_t N = 1 << 22;
int64_t N = 1 << 20;
size_t blocks = 512;
size_t threads = 256;
size_t size = sizeof(float) * N;
......@@ -119,35 +119,35 @@ TEST(AutoTune, sum) {
// 1. Test call_back.
VLOG(3) << ">>> [CallBack]: Test case.";
auto callback1 = tune::MakeCallback(Algo<4>);
auto callback2 = tune::MakeCallback(Algo<2>);
auto callback3 = tune::MakeCallback(Algo<1>);
auto callback1 = tune::MakeCallback<float>(Algo<4>);
auto callback2 = tune::MakeCallback<float>(Algo<2>);
auto callback3 = tune::MakeCallback<float>(Algo<1>);
std::vector<decltype(callback1)> callbacks{callback1, callback2, callback3};
for (int i = 0; i < callbacks.size(); ++i) {
dev_ctx->Wait();
phi::GpuTimer timer;
timer.Start(0);
callbacks[i].Call(*dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks);
callbacks[i].Run(*dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks);
timer.Stop(0);
VLOG(3) << "kernel[" << i << "]: time cost is " << timer.ElapsedTime();
}
// 2. Test call_back tune.
VLOG(3) << ">>> [AutoTune]: Test case.";
auto tuner = tune::MakeAutoTuner(Algo<4>);
tuner.AddCallBack(tune::MakeCallback(Algo<2>));
tuner.AddCallBack(tune::MakeCallback(Algo<1>));
auto tuner = tune::MakeAutoTuner<float>(Algo<4>);
tuner.AddCallBack(tune::MakeCallback<float>(Algo<2>));
tuner.AddCallBack(tune::MakeCallback<float>(Algo<1>));
/* The 1st ctx is used for ctx.Wait();
   the 2nd is just a parameter of the callback. */
auto best_call_back = tuner.PickBestKernel(
auto best_index = tuner.PickBestKernel(
*dev_ctx, *dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks);
best_call_back.Call(*dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks);
dev_ctx->Wait();
phi::GpuTimer timer;
timer.Start(0);
best_call_back.Call(*dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks);
tuner.RunBestKernel(
best_index, *dev_ctx, *d_in1.get(), d_in2.get(), N, threads, blocks);
timer.Stop(0);
VLOG(3) << "Best CallBackKernel time cost is " << timer.ElapsedTime();
#endif
......
......@@ -134,7 +134,8 @@ enum class AlgorithmType {
kConvForward = 1,
kConvBackwardData = 2,
kConvBackwardFilter = 3,
kAlgorithmCount = 4
kTranspose = 4,
kAlgorithmCount = 5
};
// AlgorithmsConfigKey -> AlgorithmsID
......@@ -165,6 +166,8 @@ class AutoTuneCache {
return Get(AlgorithmType::kConvBackwardFilter);
}
AlgorithmsCacheMap& GetTranspose() { return Get(AlgorithmType::kTranspose); }
void Clean() {
for (auto& v : auto_tune_map_) {
v.second.Clean();
......