Unverified · Commit 6c9df13d authored by limingshu, committed by GitHub

Divide elementwise case from BroadcastKernel and refine transpose autotune (#33051)

* First Commit.

* add some codes

* add elementwise loader

* fix code styles

* merge with develop

* add some changes both in elementwise and transpose

* add init operation in broadcast kernel.

* change codes according to pr suggestions about transpose file

* fix error for op-benchmark ci

* fix according to ci
Parent f0dab193
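
In short, the change classifies every BroadcastKernel launch by how its inputs have to be loaded, so the pure elementwise case (all inputs already match the output's numel) can skip broadcast index computation entirely. The standalone C++ sketch below is illustrative only: LoadType and classify are hypothetical names, while the real logic lives in phi::funcs::LoaderTypeClassifier further down in this diff.

// Illustrative sketch of the load-type classification introduced by this PR.
#include <cstdint>
#include <iostream>
#include <vector>

enum class LoadType { kMixed = 1, kBroadcast = 2, kElementwise = 3 };

// If every input already has the output's element count, no broadcast index
// computation is needed and a vectorized elementwise loader can be used.
LoadType classify(const std::vector<int64_t>& in_numels, int64_t out_numel) {
  int broadcast_num = 0;
  for (int64_t n : in_numels) {
    if (n != out_numel) ++broadcast_num;
  }
  if (broadcast_num == 0) return LoadType::kElementwise;  // new fast path
  // When more than half of the inputs broadcast, favor the broadcast loader.
  if (broadcast_num > static_cast<int>(in_numels.size()) / 2) {
    return LoadType::kBroadcast;
  }
  return LoadType::kMixed;
}

int main() {
  // Adding a [32, 128] tensor and a [1, 128] tensor: one input broadcasts.
  std::cout << static_cast<int>(classify({32 * 128, 128}, 32 * 128)) << "\n";
  // Same-shaped inputs take the elementwise fast path.
  std::cout << static_cast<int>(classify({4096, 4096}, 4096)) << "\n";
  return 0;
}

In the real kernel the elementwise path also enables fully vectorized loads, as the BroadcastDataLoader specializations below show.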
......@@ -16,7 +16,6 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/kernels/funcs/transpose_functor.h"
namespace paddle {
namespace operators {
......
......@@ -20,7 +20,6 @@ limitations under the License. */
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/transpose_functor.h"
namespace paddle {
namespace operators {
......
......@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/phi/kernels/funcs/transpose_functor.h"
namespace paddle {
namespace operators {
......
......@@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/transpose_functor.h"
namespace paddle {
namespace operators {
......
......@@ -123,7 +123,7 @@ class AutoTuneBase {
float RunAndMeasureKernel(const Context& ctx, const int idx, Args&&... args) {
// Regard the 1st run as warmup and judge the comparison result by the time
// cost of the remaining cycles.
constexpr int repeats = 4;
constexpr int repeats = 6;
phi::GpuTimer timer;
float time_cost = 0;
const auto& stream = ctx.stream();
......
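
The hunk above bumps the autotune measurement from 4 to 6 repeats while still treating the first run as warmup. Below is a minimal CUDA sketch of that warmup-plus-repeat timing pattern; it is illustrative only, and dummy_kernel / RunAndMeasure are hypothetical stand-ins rather than Paddle APIs.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] = x[i] * 2.0f + 1.0f;
}

// Launch `repeats` times, discard the first (warmup) run, and return the
// average time of the remaining runs in milliseconds.
float RunAndMeasure(float* d_x, int n, int repeats) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  float total_ms = 0.0f;
  for (int r = 0; r < repeats; ++r) {
    cudaEventRecord(start);
    dummy_kernel<<<(n + 255) / 256, 256>>>(d_x, n);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    if (r > 0) total_ms += ms;  // the first run only warms up caches
  }
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return total_ms / (repeats - 1);
}

int main() {
  const int n = 1 << 20;
  float* d_x = nullptr;
  cudaMalloc(&d_x, n * sizeof(float));
  cudaMemset(d_x, 0, n * sizeof(float));
  printf("avg cost: %f ms\n", RunAndMeasure(d_x, n, /*repeats=*/6));
  cudaFree(d_x);
  return 0;
}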
......@@ -29,12 +29,23 @@ namespace funcs {
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
template <typename InT, typename OutT>
int GetVecsize(const std::vector<const DenseTensor *> &ins,
enum BroadcastLoadType { kMixed = 1, kBroadcast = 2, kElementwise = 3 };
template <typename InT, typename OutT, int Arity>
struct LoaderTypeClassifier {
public:
int64_t numel{0};
int vec_size{1};
int broadcast_num{0};
bool all_elementwise{true};
phi::Array<int, Arity> use_broadcast;
phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
LoaderTypeClassifier() {}
LoaderTypeClassifier(const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs) {
int in_vec_size = 4;
int out_vec_size = 4;
if (outs->size() > 1) {
int out_vec_size =
std::min(4, phi::GetVectorizedSize<OutT>((*outs)[0]->data<OutT>()));
for (auto i = 1; i < outs->size(); ++i) {
PADDLE_ENFORCE_EQ(
(*outs)[i]->dims(),
......@@ -46,25 +57,33 @@ int GetVecsize(const std::vector<const DenseTensor *> &ins,
out_vec_size = std::min(
phi::GetVectorizedSize<OutT>((*outs)[i]->data<OutT>()), out_vec_size);
}
numel = (*outs)[0]->numel();
for (int i = 0; i < Arity; ++i) {
auto in_data = ins[i]->data<InT>();
ins_data[i] = (const _ptr_ InT *)(in_data);
bool is_same_dim = ins[i]->numel() == numel;
if (is_same_dim) {
use_broadcast[i] = false;
auto temp_size = phi::GetVectorizedSize<InT>(in_data);
in_vec_size = std::min(temp_size, in_vec_size);
} else {
out_vec_size = phi::GetVectorizedSize<OutT>((*outs)[0]->data<OutT>());
use_broadcast[i] = true;
broadcast_num++;
}
for (auto *in : ins) {
auto temp_size = phi::GetVectorizedSize<InT>(in->data<InT>());
in_vec_size = in->dims() == (*outs)[0]->dims()
? std::min(temp_size, in_vec_size)
: in_vec_size;
all_elementwise &= is_same_dim;
}
vec_size = std::min(out_vec_size, in_vec_size);
}
return std::min(out_vec_size, in_vec_size);
}
private:
int in_vec_size{4};
};
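// Illustrative example of the classifier above: for Arity == 2 with ins[0] of
// shape [32, 128] and ins[1] of shape [1, 128] against a [32, 128] output, the
// constructor ends with use_broadcast = {false, true}, broadcast_num == 1 and
// all_elementwise == false; only when every input matches the output's numel
// does all_elementwise stay true, which later selects the elementwise loader.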
#ifndef PADDLE_WITH_XPU_KP
template <typename T,
int VecSize,
int Arity,
bool IsBoundary,
bool is_all_broadcast>
// Common broadcast/elementwise Loader.
template <typename T, int VecSize, int Arity, bool IsBoundary, int LoadType>
struct BroadcastDataLoader {
__device__ __forceinline__ void operator()(
T args[Arity][VecSize],
......@@ -88,8 +107,63 @@ struct BroadcastDataLoader {
}
};
// Scalar elementwise Loader with consideration of IsBoundary.
template <typename T, int VecSize, int Arity>
struct BroadcastDataLoader<T, VecSize, Arity, true, kElementwise> {
__device__ __forceinline__ void operator()(
T args[Arity][VecSize],
const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
const phi::Array<int, Arity> &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
int thread_offset = threadIdx.x * VecSize + block_offset;
#pragma unroll
for (int i = 0; i < Arity; ++i) {
#pragma unroll
for (int idx = 0; idx < VecSize; ++idx) {
args[i][idx] = static_cast<T>(1);
int index = thread_offset + idx;
if (index < numel) {
args[i][idx] = ins[i][index];
}
}
}
}
};
// Vectorized elementwise Loader without consideration of IsBoundary.
template <typename T, int VecSize, int Arity>
struct BroadcastDataLoader<T, VecSize, Arity, false, kElementwise> {
__device__ __forceinline__ void operator()(
T args[Arity][VecSize],
const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
const phi::Array<int, Arity> &use_broadcast,
const int block_offset,
const int num,
const uint32_t numel) {
using VecType = phi::kps::details::VectorType<T, VecSize>;
VecType vec_temp[Arity];
int thread_offset = threadIdx.x + blockIdx.x * blockDim.x;
#pragma unroll
for (int i = 0; i < Arity; ++i) {
const VecType *__restrict__ vec_input =
reinterpret_cast<const VecType *__restrict__>(ins[i]);
vec_temp[i] = vec_input[thread_offset];
#pragma unroll
for (int idx = 0; idx < VecSize; ++idx) {
args[i][idx] = vec_temp[i].val[idx];
}
}
}
};
// Common broadcast data loader.
template <typename T, int VecSize, int Arity, bool IsBoundary>
struct BroadcastDataLoader<T, VecSize, Arity, IsBoundary, true> {
struct BroadcastDataLoader<T, VecSize, Arity, IsBoundary, kBroadcast> {
__device__ __forceinline__ void operator()(
T args[Arity][VecSize],
const phi::Array<const _ptr_ T *__restrict__, Arity> &ins,
......@@ -146,7 +220,7 @@ template <typename InT,
int NumOuts,
int VecSize,
bool IsBoundary,
bool IsAllBroadcast = false>
int LoadType>
__device__ void VectorizedBroadcastKernelImpl(
const phi::Array<const _ptr_ InT *__restrict__, Arity> &ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
......@@ -172,7 +246,7 @@ __device__ void VectorizedBroadcastKernelImpl(
}
}
#else
BroadcastDataLoader<InT, VecSize, Arity, IsBoundary, IsAllBroadcast>()(
BroadcastDataLoader<InT, VecSize, Arity, IsBoundary, LoadType>()(
args, ins, configs, use_broadcast, block_offset, num, numel);
#endif
......@@ -196,7 +270,7 @@ template <typename Functor,
int Arity,
int NumOuts,
int VecSize,
bool IsAllBroadcast>
int LoadType>
__global__ void VectorizedBroadcastKernel(
phi::Array<const _ptr_ InT *__restrict__, Arity> ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
......@@ -218,7 +292,7 @@ __global__ void VectorizedBroadcastKernel(
NumOuts,
VecSize,
false,
IsAllBroadcast>(ins,
LoadType>(ins,
outs,
use_broadcast,
numel,
......@@ -237,7 +311,7 @@ __global__ void VectorizedBroadcastKernel(
NumOuts,
VecSize,
true,
IsAllBroadcast>(ins,
LoadType>(ins,
outs,
use_broadcast,
numel,
......@@ -257,7 +331,7 @@ __global__ void VectorizedBroadcastKernel(
NumOuts,
VecSize,
false,
IsAllBroadcast>(ins,
LoadType>(ins,
outs,
use_broadcast,
numel,
......@@ -274,7 +348,7 @@ __global__ void VectorizedBroadcastKernel(
NumOuts,
VecSize,
true,
IsAllBroadcast>(ins,
LoadType>(ins,
outs,
use_broadcast,
numel,
......@@ -289,7 +363,7 @@ __global__ void VectorizedBroadcastKernel(
template <typename InT,
typename OutT,
typename Functor,
typename Func,
int Arity,
int NumOuts,
int VecSize>
......@@ -297,29 +371,16 @@ void LaunchBroadcastKernel(
const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
Functor func,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs) {
int broadcast_num = 0;
int numel = (*outs)[0]->numel();
phi::Array<int, Arity> use_broadcast;
phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
Func func,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
const LoaderTypeClassifier<InT, OutT, Arity> &loader_classifier) {
phi::Array<_ptr_ OutT *, NumOuts> outs_data;
for (int i = 0; i < NumOuts; ++i) {
outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
}
for (int i = 0; i < Arity; ++i) {
if (ins[i]->numel() != numel) {
broadcast_num++;
use_broadcast[i] = true;
} else {
use_broadcast[i] = false;
}
ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>());
}
#ifdef PADDLE_WITH_XPU_KP
int numel = (*outs)[0]->numel();
const int threads = 64;
const int blocks = 8;
int read_lens = configs[0].buf_len;
......@@ -327,10 +388,10 @@ void LaunchBroadcastKernel(
int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
int tail_tid = numel % (read_lens * threads);
VectorizedBroadcastKernel<Functor, InT, OutT, Arity, NumOuts, VecSize, false>
<<<blocks, threads, 0, stream>>>(ins_data,
VectorizedBroadcastKernel<Func, InT, OutT, Arity, NumOuts, VecSize, false>
<<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
outs_data,
use_broadcast,
loader_classifier.use_broadcast,
numel,
configs,
main_offset,
......@@ -338,49 +399,54 @@ void LaunchBroadcastKernel(
read_lens,
func);
#else
const auto &numel = loader_classifier.numel;
auto gpu_config =
phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
int read_lens = VecSize;
auto stream = ctx.stream();
auto threads = gpu_config.thread_per_block;
auto threads = gpu_config.GetBlockSize();
auto blocks = gpu_config.block_per_grid;
int main_offset = (numel / (read_lens * gpu_config.GetBlockSize())) *
read_lens * gpu_config.GetBlockSize();
int tail_tid = numel % (read_lens * gpu_config.GetBlockSize());
int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
int tail_tid = numel % (VecSize * threads);
if (broadcast_num > (Arity >> 1)) {
VectorizedBroadcastKernel<Functor,
if (loader_classifier.all_elementwise) {
VectorizedBroadcastKernel<Func,
InT,
OutT,
Arity,
NumOuts,
VecSize,
(Arity > 1)>
<<<blocks, threads, 0, stream>>>(ins_data,
kElementwise>
<<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
outs_data,
use_broadcast,
loader_classifier.use_broadcast,
numel,
configs,
main_offset,
tail_tid,
read_lens,
VecSize,
func);
} else {
VectorizedBroadcastKernel<Functor,
InT,
OutT,
Arity,
NumOuts,
} else if (loader_classifier.broadcast_num > (Arity >> 1)) {
constexpr BroadcastLoadType type_ = (Arity > 1) ? kBroadcast : kMixed;
VectorizedBroadcastKernel<Func, InT, OutT, Arity, NumOuts, VecSize, type_>
<<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
outs_data,
loader_classifier.use_broadcast,
numel,
configs,
main_offset,
tail_tid,
VecSize,
false>
<<<blocks, threads, 0, stream>>>(ins_data,
func);
} else {
VectorizedBroadcastKernel<Func, InT, OutT, Arity, NumOuts, VecSize, kMixed>
<<<blocks, threads, 0, stream>>>(loader_classifier.ins_data,
outs_data,
use_broadcast,
loader_classifier.use_broadcast,
numel,
configs,
main_offset,
tail_tid,
read_lens,
VecSize,
func);
}
#endif
......@@ -777,21 +843,6 @@ struct LaunchBroadcastKernelWithInt64IndexHelper<InT,
};
#endif
template <typename T>
static std::string ReversedVectorToString(const std::vector<T> &reversed_v) {
std::stringstream ss;
bool is_last = true;
for (int i = reversed_v.size() - 1; i >= 0; --i) {
if (is_last) {
ss << reversed_v[i];
is_last = false;
} else {
ss << ", " << reversed_v[i];
}
}
return ss.str();
}
template <ElementwiseType ET,
typename InT,
typename OutT,
......@@ -839,8 +890,8 @@ void BroadcastKernelForDifferentVecSize(
kEnabledInt64IndexKernel &&
(*outs)[0]->numel() >= std::numeric_limits<int32_t>::max();
if (use_int64_index_kernel) {
int vec_size = GetVecsize<InT, OutT>(ins, outs);
switch (vec_size) {
auto loader_classifier = LoaderTypeClassifier<InT, OutT, kArity>(ins, outs);
switch (loader_classifier.vec_size) {
case VecSizeL: {
LaunchBroadcastKernelWithInt64IndexHelper<InT,
OutT,
......@@ -882,7 +933,7 @@ void BroadcastKernelForDifferentVecSize(
}
default: {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported vectorized size: %d!", vec_size));
"Unsupported vectorized size: %d!", loader_classifier.vec_size));
break;
}
}
......@@ -890,30 +941,21 @@ void BroadcastKernelForDifferentVecSize(
}
#endif
// mergedim and get vec_size
const auto dims_simplifier =
BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
if (VLOG_IS_ON(6)) {
for (size_t i = 0; i < ins.size(); ++i) {
VLOG(6) << "input i=" << i << ": origin_dims={" << ins[i]->dims()
<< "}, simplied_dims={"
<< ReversedVectorToString<int64_t>(dims_simplifier.in_dims[i])
<< "}";
}
VLOG(6) << "output: origin_dims={" << (*outs)[0]->dims()
<< "}, simplied_dims={"
<< ReversedVectorToString<int64_t>(dims_simplifier.out_dims) << "}";
}
phi::Array<kps::details::BroadcastConfig, kArity> configs;
// get vec_size
#ifdef PADDLE_WITH_XPU_KP
PADDLE_ENFORCE_EQ(
ins.size(),
2,
phi::errors::InvalidArgument(
"XPU only support inputs is 2, but received %d", ins.size()));
auto loader_classifier = LoaderTypeClassifier<InT, OutT, kArity>();
const auto dims_simplifier =
BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
if (VLOG_IS_ON(6)) {
DimsSimplifiedLogger<int64_t>::Log(
ins, outs, dims_simplifier, "XPU Broadcast");
}
configs[0] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
dims_simplifier.in_dims[0],
dims_simplifier.in_dims[1],
......@@ -926,8 +968,16 @@ void BroadcastKernelForDifferentVecSize(
bool is_optimize = configs[0].cmp_type != type;
int vec_size = is_optimize ? VecSizeL : VecSizeM;
#else
auto loader_classifier = LoaderTypeClassifier<InT, OutT, kArity>(ins, outs);
if (!loader_classifier.all_elementwise) {
const auto dims_simplifier =
BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
if (VLOG_IS_ON(6)) {
DimsSimplifiedLogger<int64_t>::Log(
ins, outs, dims_simplifier, "GPU Broadcast");
}
for (int i = 0; i < kArity; ++i) {
// Get the broadcast config. If the data shape is [m, n], then set
// data_dim = {n, m}, e.g. out's shape is [3, 45, 1], so out_dims = {1, 45, 3}.
// if (ins[i]->numel() != (*outs)[0]->numel()) {
......@@ -937,28 +987,27 @@ void BroadcastKernelForDifferentVecSize(
dims_simplifier.rank);
}
}
int vec_size = GetVecsize<InT, OutT>(ins, outs);
}
#endif
switch (vec_size) {
switch (loader_classifier.vec_size) {
case VecSizeL: {
LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeL>(
ctx, ins, outs, func, configs);
ctx, ins, outs, func, configs, loader_classifier);
break;
}
case VecSizeM: {
LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeM>(
ctx, ins, outs, func, configs);
ctx, ins, outs, func, configs, loader_classifier);
break;
}
case VecSizeS: {
LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeS>(
ctx, ins, outs, func, configs);
ctx, ins, outs, func, configs, loader_classifier);
break;
}
default: {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported vectorized size: %d!", vec_size));
"Unsupported vectorized size: %d!", loader_classifier.vec_size));
break;
}
}
......@@ -1037,7 +1086,6 @@ void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
dev_ctx, x, y, axis, InverseFunctor(), z);
}
}
#endif
} // namespace funcs
......
......@@ -34,18 +34,6 @@ struct BroadcastDimsSimplifier {
BroadcastDimsSimplifier(const std::vector<const DenseTensor *> &ins,
const phi::DDim &dims,
int axis) {
if (!NeedBroadcast(ins, dims)) {
int64_t numel = phi::product(dims);
rank = 1;
N = ins.size();
out_dims = DimVector{numel};
in_dims.resize(N);
for (int64_t i = 0; i < N; ++i) {
in_dims[i] = DimVector{numel};
}
return;
}
N = std::max(static_cast<int>(ins.size()), 2);
in_dims.resize(N);
rank = dims.size();
......@@ -112,18 +100,6 @@ struct BroadcastDimsSimplifier {
}
private:
bool NeedBroadcast(const std::vector<const DenseTensor *> &ins,
const phi::DDim &dims) {
bool no_broadcast_flag = true;
for (auto *in : ins) {
no_broadcast_flag &= ins[0]->dims() == in->dims();
}
if (ins.size() > 0) {
no_broadcast_flag &= dims == ins[0]->dims();
}
return !no_broadcast_flag;
}
// To compensate for the missing dimensions of input tensors according to axis.
void ExtendInputDimensions(int N, int axis) {
for (auto &in_dim : in_dims) {
......@@ -244,9 +220,9 @@ struct BroadcastDimsSimplifier {
};
// Simplify the input dims and permute dims if possible.
struct DimsSimplifier {
struct PermuteDimsSimplifier {
public:
explicit DimsSimplifier(const int rank,
PermuteDimsSimplifier(const int rank,
const int64_t numel,
const std::vector<int32_t> &perm,
const std::vector<int64_t> &dims)
......@@ -255,7 +231,7 @@ struct DimsSimplifier {
perm_.resize(rank_);
src_dims_.resize(rank_);
dst_dims_.resize(rank_);
if (!is_seq_perm_) {
if (!is_sequential_perm_) {
for (auto i = 0; i < rank_; ++i) {
dst_dims_[i] = src_dims_[perm_[i]];
}
......@@ -265,7 +241,7 @@ struct DimsSimplifier {
}
}
~DimsSimplifier() = default;
~PermuteDimsSimplifier() = default;
const int &GetRank() const { return rank_; }
const int64_t &GetCount() const { return count_; }
......@@ -276,8 +252,8 @@ struct DimsSimplifier {
private:
int rank_{1};
int64_t count_{0};
bool is_seq_perm_{true};
std::vector<int> perm_;
bool is_sequential_perm_{true};
std::vector<int64_t> src_dims_;
std::vector<int64_t> dst_dims_;
......@@ -336,11 +312,44 @@ struct DimsSimplifier {
const int mapped = valid_map[perm[i]];
if (mapped >= 0) {
perm_[perm_idx] = mapped;
is_seq_perm_ &= (mapped == perm_idx);
is_sequential_perm_ &= (mapped == perm_idx);
perm_idx += 1;
}
}
rank_ = is_seq_perm_ ? 1 : valid_dim_idx;
rank_ = is_sequential_perm_ ? 1 : valid_dim_idx;
}
};
template <typename T>
struct DimsSimplifiedLogger {
public:
static void Log(const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
const BroadcastDimsSimplifier &dims_simplifier,
const std::string &op_name) {
VLOG(6) << op_name << "'s dims after simplification are shown below:";
for (size_t i = 0; i < ins.size(); ++i) {
VLOG(6) << "input i=" << i << ": origin_dims={" << ins[i]->dims()
<< "}, simplied_dims={"
<< ReversedVectorToString(dims_simplifier.in_dims[i]) << "}";
}
VLOG(6) << "output: origin_dims={" << (*outs)[0]->dims()
<< "}, simplied_dims={"
<< ReversedVectorToString(dims_simplifier.out_dims) << "}";
}
static std::string ReversedVectorToString(const std::vector<T> &reversed_v) {
std::stringstream ss;
bool is_last = true;
for (int i = reversed_v.size() - 1; i >= 0; --i) {
if (is_last) {
ss << reversed_v[i];
is_last = false;
} else {
ss << ", " << reversed_v[i];
}
}
return ss.str();
}
};
......
......@@ -19,8 +19,9 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/dims_simplifier.h"
#include "paddle/phi/kernels/funcs/transpose_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/primitive/datamover_primitives.h"
namespace phi {
......@@ -705,23 +706,23 @@ inline void CombineTransposeDim3(const DDim& shape,
template <typename T>
struct TransposeSimple {
static bool Impl(const phi::GPUContext& ctx,
static bool Run(const phi::GPUContext& ctx,
const phi::DenseTensor& in,
const std::vector<int32_t> perm,
const std::vector<int32_t>& perm,
phi::DenseTensor* out,
const int64_t numel) {
if (numel >= std::numeric_limits<int32_t>::max()) {
return Run<int64_t>(ctx, in, perm, out);
return RunImpl<int64_t>(ctx, in, perm, out);
} else {
return Run<int32_t>(ctx, in, perm, out);
return RunImpl<int32_t>(ctx, in, perm, out);
}
}
private:
template <typename IndexType = int32_t>
static bool Run(const phi::GPUContext& ctx,
static bool RunImpl(const phi::GPUContext& ctx,
const phi::DenseTensor& in,
const std::vector<int32_t> perm,
const std::vector<int32_t>& perm,
phi::DenseTensor* out) {
// First reduce the dimensions of the input tensor if possible.
auto in_data = in.data<T>();
......@@ -752,13 +753,128 @@ struct TransposeSimple {
}
};
template <typename IndexT, int N>
enum PermuteType {
kCopy = 1,
kSwapTranspose = 2,
kGeneralTranspose = 3,
kVecPermute = 4,
kGeneralPermute = 5
};
constexpr int kBlockRows = 16;
constexpr int kTileSize = 32;
constexpr int kShareCol = (kTileSize + 1);
#define GET_TILE_SIZE(LEN_, ALIGN_) \
((LEN_ + (ALIGN_ - 1)) & ~(ALIGN_ - 1)) / ALIGN_
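// For a power-of-two ALIGN_, this is a ceiling division: LEN_ is rounded up to
// the next multiple of ALIGN_ and then divided, e.g. GET_TILE_SIZE(45, 32) == 2.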
template <typename T>
struct PermTypeClassifier {
public:
PermTypeClassifier(const int sm_count,
const int rank,
const std::vector<int32_t>& perm,
const std::vector<int64_t>& dims,
const T* src,
T* dst) {
if (rank == 1) {
type_ = PermuteType::kCopy;
} else {
// Limit of a single dimension of the CUDA grid.
constexpr int64_t dim_limitation = 65536;
int dst_vec_size = phi::GetVectorizedSize<T>(dst);
// While the last dim is kept in place, there is a chance for vectorized IO.
const int last_idx = rank - 1;
if (perm[last_idx] == last_idx) {
type_ = PermuteType::kVecPermute;
vec_size_ = GetDimVecSize(dst_vec_size, dims[last_idx], src, false);
return;
}
// Permute at last 2 dims, namely transpose.
if ((rank == 2 && perm[1] == 0) ||
(rank == 3 && perm[2] == 1 && perm[1] == 2)) {
int64_t channel = rank == 2 ? 1 : dims[0];
// Currently, the transpose kernel cannot cover the case where the channel
// dimension exceeds 65536, which is the limit of a dim3 grid dimension.
// This special case will be covered by an extended transpose kernel later.
if (channel < dim_limitation) {
type_ = PermuteType::kGeneralTranspose;
num_rows_tile_ = GET_TILE_SIZE(dims[rank - 2], kTileSize);
int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src);
int tile_size = channel * num_rows_tile_ *
GET_TILE_SIZE(dims[last_idx], kTileSize);
vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
} else {
type_ = PermuteType::kGeneralPermute;
}
return;
}
// Permute at first dim and third dim.
if (rank == 3 && perm[2] == 0 && perm[1] == 1) {
// Currently, the transpose kernel cannot cover the case where the channel
// dimension exceeds 65536, which is the limit of a dim3 grid dimension.
// This special case will be covered by an extended transpose kernel later.
if (dims[1] < dim_limitation) {
type_ = PermuteType::kSwapTranspose;
num_rows_tile_ = GET_TILE_SIZE(dims[0], kTileSize);
int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src);
int tile_size =
dims[1] * num_rows_tile_ * GET_TILE_SIZE(dims[2], kTileSize);
vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
} else {
type_ = PermuteType::kGeneralPermute;
}
return;
}
vec_size_ = dst_vec_size;
}
}
~PermTypeClassifier() = default;
int GetVecSize() const { return vec_size_; }
int GetRowsTile() const { return num_rows_tile_; }
PermuteType GetPermType() const { return type_; }
private:
int vec_size_{1};
int64_t num_rows_tile_{0};
PermuteType type_{kGeneralPermute};
// Find the highest common divisor and use it as vec_size.
int GetDimVecSize(const int dst_vec_size,
const int64_t target_dim,
const T* src,
bool use_share_mem = true) {
int vec_size = std::min(dst_vec_size, phi::GetVectorizedSize<T>(src));
int dim_vec_size = 1;
for (int size = vec_size; size > 0; size /= 2) {
if (target_dim % size == 0) {
dim_vec_size = size;
break;
}
}
if (use_share_mem) {
// By bytes limitation of shared_memory.
return (sizeof(T) > sizeof(float) ? 1 : dim_vec_size);
} else {
return dim_vec_size;
}
}
};
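// Examples of PermTypeClassifier's outcome above (after dims/perm
// simplification, and subject to the 65536 grid-dimension limit noted there):
//   rank == 1                                            -> kCopy
//   perm = {1, 0, 2} keeps the last dim in place         -> kVecPermute
//   perm = {0, 2, 1} transposes the trailing two dims    -> kGeneralTranspose
//   perm = {2, 1, 0} swaps the first and third dims      -> kSwapTranspose
//   any other 3-D pattern, e.g. perm = {1, 2, 0}         -> kGeneralPermute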
template <typename IndexT, int Rank>
class IdxHelper {
public:
IdxHelper() {}
explicit IdxHelper(const IndexT* dims) {
for (int i = N - 1; i >= 0; --i) {
stride_[i] = i < (N - 1) ? dims[i + 1] * stride_[i + 1] : 1;
for (int i = Rank - 1; i >= 0; --i) {
stride_[i] = i < (Rank - 1) ? dims[i + 1] * stride_[i + 1] : 1;
}
}
......@@ -770,25 +886,25 @@ class IdxHelper {
IndexT* index) const {
IndexT remaining = offset;
#pragma unroll
for (int i = 0; i < N - 1; ++i) {
for (int i = 0; i < Rank - 1; ++i) {
const IndexT idx = remaining / stride_[i];
remaining -= idx * stride_[i];
index[i] = idx;
}
index[N - 1] = remaining;
index[Rank - 1] = remaining;
}
private:
IndexT stride_[N];
IndexT stride_[Rank];
};
template <int N>
class IdxHelper<uint32_t, N> {
template <int Rank>
class IdxHelper<uint32_t, Rank> {
public:
IdxHelper() {}
explicit IdxHelper(const uint32_t* dims) {
for (int i = N - 1; i >= 0; --i) {
uint32_t value = i < (N - 1) ? dims[i + 1] * stride_[i + 1] : 1;
for (int i = Rank - 1; i >= 0; --i) {
uint32_t value = i < (Rank - 1) ? dims[i + 1] * stride_[i + 1] : 1;
divmoder_[i] = phi::kps::details::FastDivMod(value);
stride_[i] = value;
}
......@@ -802,35 +918,35 @@ class IdxHelper<uint32_t, N> {
uint32_t* index) const {
uint32_t remaining = offset;
#pragma unroll
for (int i = 0; i < N - 1; ++i) {
for (int i = 0; i < Rank - 1; ++i) {
uint32_t idx = divmoder_[i].Div(remaining);
index[i] = idx;
remaining -= idx * stride_[i];
}
index[N - 1] = remaining;
index[Rank - 1] = remaining;
}
private:
uint32_t stride_[N];
phi::kps::details::FastDivMod divmoder_[N];
uint32_t stride_[Rank];
phi::kps::details::FastDivMod divmoder_[Rank];
};
// Transform index between memory offset and shape coordinate.
template <typename IndexT, int N>
template <typename IndexT, int Rank>
class IdxAndOffsetHelper {
public:
IdxAndOffsetHelper() {}
explicit IdxAndOffsetHelper(const IndexT* dims) {
index_helper = IdxHelper<IndexT, N>(dims);
index_helper = IdxHelper<IndexT, Rank>(dims);
}
__device__ __forceinline__ IndexT IndexToOffset(const IndexT* index) const {
IndexT offset = 0;
#pragma unroll
for (int i = 0; i < N - 1; ++i) {
for (int i = 0; i < Rank - 1; ++i) {
offset += index[i] * index_helper.GetStride(i);
}
offset += index[N - 1];
offset += index[Rank - 1];
return offset;
}
......@@ -840,7 +956,7 @@ class IdxAndOffsetHelper {
}
private:
IdxHelper<IndexT, N> index_helper;
IdxHelper<IndexT, Rank> index_helper;
};
template <typename IndexT, int Rank>
......@@ -1173,7 +1289,7 @@ struct TransposeLauncher {
T* dst) {
constexpr int ReadSize = sizeof(T) > sizeof(float) ? 1 : VecSize;
const IndexT cols = dims[rank - 1] / VecSize;
const IndexT n_cols_tile = GETTILESIZE(cols, kTileSize);
const IndexT n_cols_tile = GET_TILE_SIZE(cols, kTileSize);
if (perm_type == PermuteType::kGeneralTranspose) {
IndexT chs = (rank == 2) ? 1 : dims[0];
......@@ -1229,65 +1345,48 @@ struct TransposeLauncher {
vec_write = is_vec_write ? kVecRow : 1;
}
IndexT n_rows_tile = is_vec_write
? GETTILESIZE(rows, (kTileSize * vec_write))
? GET_TILE_SIZE(rows, (kTileSize * vec_write))
: num_rows_tile;
return n_rows_tile;
}
};
template <typename T, typename IndexT>
struct PermuteDispatch {
public:
PermuteDispatch(const phi::GPUContext& ctx,
inline void PermuteDispatch(const phi::GPUContext& ctx,
const IndexT& count,
PermTypeClassifier<T>* cls_ptr,
const std::vector<int64_t>& dims,
const std::vector<int32_t>& perm,
const IndexT count,
const T* src,
T* dst)
: dims_(dims), cls_(cls_ptr) {
rank_ = dims_.size();
type_ = cls_->GetPermType();
KernelTypeDispatch(ctx, count, perm, src, dst);
}
~PermuteDispatch() {}
private:
int rank_{0};
std::vector<int64_t> dims_;
PermTypeClassifier<T>* cls_;
PermuteType type_{kGeneralPermute};
void KernelTypeDispatch(const phi::GPUContext& ctx,
const IndexT& count,
const std::vector<int32_t>& perm,
const T* src,
T* dst) {
int rank = dims.size();
PermuteType type = cls_ptr->GetPermType();
#define TRANSPOSE_DISPATCH_VEC_SIZE(size) \
case size: { \
TransposeLauncher<T, IndexT, size>()( \
ctx, rank_, type_, dims_, cls_->GetRowsTile(), src, dst); \
ctx, rank, type, dims, cls_ptr->GetRowsTile(), src, dst); \
break; \
}
#define PERMUTE_DISPATCH_VEC_SIZE(size) \
case size: { \
PermuteLauncher<T, IndexT, size>()( \
ctx, rank_, count, type_, dims_, perm, src, dst); \
ctx, rank, count, type, dims, perm, src, dst); \
break; \
}
switch (type_) {
switch (type) {
case kSwapTranspose:
case kGeneralTranspose:
switch (cls_->GetVecSize()) {
switch (cls_ptr->GetVecSize()) {
TRANSPOSE_DISPATCH_VEC_SIZE(1);
TRANSPOSE_DISPATCH_VEC_SIZE(2);
TRANSPOSE_DISPATCH_VEC_SIZE(4);
}
break;
default:
switch (cls_->GetVecSize()) {
switch (cls_ptr->GetVecSize()) {
PERMUTE_DISPATCH_VEC_SIZE(1);
PERMUTE_DISPATCH_VEC_SIZE(2);
PERMUTE_DISPATCH_VEC_SIZE(4);
......@@ -1296,15 +1395,15 @@ struct PermuteDispatch {
}
#undef TRANSPOSE_DISPATCH_VEC_SIZE
#undef PERMUTE_DISPATCH_VEC_SIZE
}
};
}
template <typename T>
inline void PermuteAndTranspose(const phi::GPUContext& ctx,
inline void PermuteAndTranspose(
const phi::GPUContext& ctx,
const int& rank,
const phi::DenseTensor& in,
phi::DenseTensor* out,
const DimsSimplifier& simplifier) {
const phi::funcs::PermuteDimsSimplifier& simplifier) {
T* dst_data = out->data<T>();
const T* src_data = in.data<T>();
const auto count = simplifier.GetCount();
......@@ -1324,18 +1423,18 @@ inline void PermuteAndTranspose(const phi::GPUContext& ctx,
} else {
if (count < std::numeric_limits<uint32_t>::max()) {
PermuteDispatch<T, uint32_t>(ctx,
static_cast<uint32_t>(count),
&classifier,
simplifier.GetSrcDims(),
simplifier.GetPerm(),
static_cast<uint32_t>(count),
src_data,
dst_data);
} else {
PermuteDispatch<T, int64_t>(ctx,
static_cast<int64_t>(count),
&classifier,
simplifier.GetSrcDims(),
simplifier.GetPerm(),
static_cast<int64_t>(count),
src_data,
dst_data);
}
......@@ -1343,12 +1442,13 @@ inline void PermuteAndTranspose(const phi::GPUContext& ctx,
}
template <typename T>
inline void PermuteWithEigen(const phi::GPUContext& ctx,
inline void PermuteWithEigen(
const phi::GPUContext& ctx,
const int& rank,
const phi::DenseTensor& in,
phi::DenseTensor* out,
const DimsSimplifier& simplifier) {
const bool not_same_dims = simplifier.GetRank() != rank;
const phi::funcs::PermuteDimsSimplifier& simplifier) {
bool not_same_dims = simplifier.GetRank() != rank;
if (not_same_dims) {
phi::DDim dst_dims = out->dims();
phi::DenseTensor temp_in;
......@@ -1373,10 +1473,10 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
phi::DenseTensor* out) {
const int rank = perm.size();
int64_t numel = in.numel();
bool ret = TransposeSimple<T>::Impl(ctx, in, perm, out, numel);
bool ret = TransposeSimple<T>::Run(ctx, in, perm, out, numel);
if (!ret) {
auto simplifier =
DimsSimplifier(rank, numel, perm, phi::vectorize<int64_t>(in.dims()));
auto simplifier = phi::funcs::PermuteDimsSimplifier(
rank, numel, perm, phi::vectorize<int64_t>(in.dims()));
auto* tuner = phi::autotune::MakeTransposeTuner<T>(PermuteWithEigen<T>);
tuner->AddCallBack(PermuteAndTranspose<T>);
......
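
For context, the driver above registers PermuteWithEigen as the default and PermuteAndTranspose as an additional callback of the transpose tuner. The standalone C++ sketch below shows the general two-candidate autotune pattern; SimpleTuner and its methods are hypothetical names, not the phi::autotune API.

#include <chrono>
#include <functional>
#include <iostream>
#include <vector>

struct SimpleTuner {
  using Fn = std::function<void()>;
  std::vector<Fn> candidates;
  void AddCallBack(Fn fn) { candidates.push_back(std::move(fn)); }

  // Time every candidate once and then run the fastest one. A real tuner
  // would also warm up each candidate and cache `best` per problem signature.
  void Run() {
    size_t best = 0;
    double best_ms = 1e30;
    for (size_t i = 0; i < candidates.size(); ++i) {
      auto t0 = std::chrono::steady_clock::now();
      candidates[i]();
      auto t1 = std::chrono::steady_clock::now();
      double ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
      if (ms < best_ms) {
        best_ms = ms;
        best = i;
      }
    }
    candidates[best]();
  }
};

int main() {
  SimpleTuner tuner;
  tuner.AddCallBack([] { std::cout << "eigen-based permute\n"; });
  tuner.AddCallBack([] { std::cout << "tiled transpose kernel\n"; });
  tuner.Run();
  return 0;
}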
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace phi {
namespace funcs {
enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 };
enum PermuteType {
kCopy = 1,
kSwapTranspose = 2,
kGeneralTranspose = 3,
kVecPermute = 4,
kGeneralPermute = 5
};
constexpr int kBlockRows = 16;
constexpr int kTileSize = 32;
constexpr int kShareCol = (kTileSize + 1);
#define GETTILESIZE(LEN_, ALIGN_) \
((LEN_ + (ALIGN_ - 1)) & ~(ALIGN_ - 1)) / ALIGN_
template <typename T>
struct PermTypeClassifier {
public:
explicit PermTypeClassifier(const int sm_count,
const int rank,
const std::vector<int32_t>& perm,
const std::vector<int64_t>& dims,
const T* src,
T* dst) {
if (rank == 1) {
type_ = PermuteType::kCopy;
} else {
constexpr int64_t dim_limitation = 65536;
const int dst_vec_size = phi::GetVectorizedSize<T>(dst);
// While the last dim is kept in place, there is a chance for vectorized IO.
const int last_idx = rank - 1;
if (perm[last_idx] == last_idx) {
type_ = PermuteType::kVecPermute;
vec_size_ = GetDimVecSize(dst_vec_size, dims[last_idx], src, false);
return;
}
// Permute at last 2 dims, namely transpose.
if ((rank == 2 && perm[1] == 0 && perm[0] == 1) ||
(rank == 3 && perm[2] == 1 && perm[1] == 2)) {
int64_t channel = rank == 2 ? 1 : dims[0];
// Currently, the transpose kernel cannot cover the case where the channel
// dimension exceeds 65536, which is the limit of a dim3 grid dimension.
// This special case will be covered by an extended transpose kernel later.
if (channel < dim_limitation) {
type_ = PermuteType::kGeneralTranspose;
num_rows_tile_ = GETTILESIZE(dims[rank - 2], kTileSize);
int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src);
int tile_size =
channel * num_rows_tile_ * GETTILESIZE(dims[last_idx], kTileSize);
vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
} else {
type_ = PermuteType::kGeneralPermute;
}
return;
}
// Permute at first dim and third dim.
if (rank == 3 && perm[2] == 0 && perm[1] == 1) {
// Currently, the transpose kernel cannot cover the case where the channel
// dimension exceeds 65536, which is the limit of a dim3 grid dimension.
// This special case will be covered by an extended transpose kernel later.
if (dims[1] < dim_limitation) {
type_ = PermuteType::kSwapTranspose;
num_rows_tile_ = GETTILESIZE(dims[0], kTileSize);
int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src);
int tile_size =
dims[1] * num_rows_tile_ * GETTILESIZE(dims[2], kTileSize);
vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
} else {
type_ = PermuteType::kGeneralPermute;
}
return;
}
vec_size_ = dst_vec_size;
}
}
~PermTypeClassifier() = default;
int GetVecSize() const { return vec_size_; }
int GetRowsTile() const { return num_rows_tile_; }
PermuteType GetPermType() const { return type_; }
private:
int vec_size_{1};
int64_t num_rows_tile_{0};
PermuteType type_{kGeneralPermute};
// Find the highest common divisor and use it as vec_size.
int GetDimVecSize(const int dst_vec_size,
const int64_t target_dim,
const T* src,
bool use_share_mem = true) {
const int vec_size = std::min(dst_vec_size, phi::GetVectorizedSize<T>(src));
int dim_vec_size = 1;
for (int size = vec_size; size > 0; size /= 2) {
if (target_dim % size == 0) {
dim_vec_size = size;
break;
}
}
if (use_share_mem) {
// By bytes limitation of shared_memory.
return (sizeof(T) > sizeof(float) ? 1 : dim_vec_size);
} else {
return dim_vec_size;
}
}
};
} // namespace funcs
} // namespace phi