未验证 提交 87cba48b 编写于 作者: L limingshu 提交者: GitHub

Performance fix for broadcast kernel [Part2] (#40051)

* first commit

* merged with develop

* merged with develop

* fix bugs in merging sequential one-dims
上级 429b5b5b
...@@ -54,20 +54,20 @@ HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) { ...@@ -54,20 +54,20 @@ HOSTDEVICE inline void Store(const AlignedVector<T, Size>& vec, T* addr) {
template <typename T> template <typename T>
int GetVectorizedSize(const T* pointer) { int GetVectorizedSize(const T* pointer) {
constexpr int max_load_bits = 128; constexpr int max_load_bits = 128;
int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T); constexpr int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T);
uint64_t address = reinterpret_cast<uint64_t>(pointer); uint64_t address = reinterpret_cast<uint64_t>(pointer);
constexpr int vec8 = std::alignment_of<AlignedVector<T, 8>>::value; // NOLINT constexpr int vec8 = std::alignment_of<AlignedVector<T, 8>>::value; // NOLINT
constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value; // NOLINT constexpr int vec4 = std::alignment_of<AlignedVector<T, 4>>::value; // NOLINT
constexpr int vec2 = std::alignment_of<AlignedVector<T, 2>>::value; // NOLINT constexpr int vec2 = std::alignment_of<AlignedVector<T, 2>>::value; // NOLINT
if (address % vec8 == 0) {
/* /*
* Currently, decide to deal with no more than 4 data once while adopting * Currently, decide to deal with no more than 4 data once while adopting
* vectorization load/store, if performance test shows that dealing with * vectorization load/store, if performance test shows that dealing with
* 8 data once in vectorization load/store does get optimized, return code * 8 data once in vectorization load/store does get optimized, code below
* below can be changed into " return std::min(8, valid_vec_size); " . * can begin with :
*/ if (address % vec8 == 0) {
return std::min(4, valid_vec_size); return std::min(4, valid_vec_size);
} else if (address % vec4 == 0) { */
if (address % vec4 == 0) {
return std::min(4, valid_vec_size); return std::min(4, valid_vec_size);
} else if (address % vec2 == 0) { } else if (address % vec2 == 0) {
return std::min(2, valid_vec_size); return std::min(2, valid_vec_size);
......
...@@ -125,7 +125,7 @@ struct DimensionsTransform { ...@@ -125,7 +125,7 @@ struct DimensionsTransform {
// To judge whether the shape of any input tensor contains sequential // To judge whether the shape of any input tensor contains sequential
// 1-value-dimensions, and measure the length of that run. // 1-value-dimensions, and measure the length of that run.
int GetSequentialOneDimLength(int *swap_index) { bool FindSequentialOneDim(int *swap_index) {
int index = 0; int index = 0;
int max_one_length = 0; int max_one_length = 0;
for (int j = 0; j < N; ++j) { for (int j = 0; j < N; ++j) {
...@@ -144,16 +144,16 @@ struct DimensionsTransform { ...@@ -144,16 +144,16 @@ struct DimensionsTransform {
} }
} }
} }
max_one_length =
seq_one_length > max_one_length ? seq_one_length : max_one_length;
index = seq_one_length > max_one_length ? j : index; index = seq_one_length > max_one_length ? j : index;
max_one_length = std::max(seq_one_length, max_one_length);
} }
if (max_one_length > 1) { bool has_seq_one = max_one_length > 1;
if (has_seq_one) {
std::swap(in_dims[0], in_dims[index]); std::swap(in_dims[0], in_dims[index]);
*swap_index = index; *swap_index = index;
} }
return max_one_length; return has_seq_one;
} }
public: public:
...@@ -214,8 +214,8 @@ struct DimensionsTransform { ...@@ -214,8 +214,8 @@ struct DimensionsTransform {
} }
}; };
int swap_idx = 0; int swap_idx = 0;
int max_one_length = GetSequentialOneDimLength(&swap_idx); bool has_seq_one = FindSequentialOneDim(&swap_idx);
if (max_one_length > 1) { if (has_seq_one) {
merge_ptr = merge_sequential_one_dims; merge_ptr = merge_sequential_one_dims;
MergeDimensions<MergeFunctor>(merge_ptr, N); MergeDimensions<MergeFunctor>(merge_ptr, N);
std::swap(in_dims[swap_idx], in_dims[0]); std::swap(in_dims[swap_idx], in_dims[0]);
...@@ -223,13 +223,13 @@ struct DimensionsTransform { ...@@ -223,13 +223,13 @@ struct DimensionsTransform {
} }
}; };
template <typename InT, typename OutT, int NumOuts = 1> template <typename InT, typename OutT>
int GetVecsize(const std::vector<const DenseTensor *> &ins, int GetVecsize(const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs) { std::vector<DenseTensor *> *outs) {
int in_vec_size = 4; int in_vec_size = 4;
int out_vec_size = 4; int out_vec_size = 4;
if (NumOuts > 1) { if (outs->size() > 1) {
for (int i = 0; i < NumOuts; ++i) { for (auto i = 1; i < outs->size(); ++i) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
(*outs)[i]->dims(), (*outs)[i]->dims(),
(*outs)[0]->dims(), (*outs)[0]->dims(),
...@@ -295,7 +295,7 @@ __device__ void VectorizedBroadcastKernelImpl( ...@@ -295,7 +295,7 @@ __device__ void VectorizedBroadcastKernelImpl(
__simd__ ConditionalT<OutT, NumOuts> result[VecSize]; __simd__ ConditionalT<OutT, NumOuts> result[VecSize];
#pragma unroll #pragma unroll
for (int i = 0; i < Arity; i++) { for (int i = 0; i < Arity; ++i) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f), read_lens); kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f), read_lens);
LoadData<InT, VecSize, IsBoundary>(args[i], LoadData<InT, VecSize, IsBoundary>(args[i],
ins[i], ins[i],
...@@ -433,7 +433,7 @@ void LaunchBroadcastKernel( ...@@ -433,7 +433,7 @@ void LaunchBroadcastKernel(
outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i])); outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
} }
for (int i = 0; i < Arity; i++) { for (int i = 0; i < Arity; ++i) {
use_broadcast[i] = (ins[i]->numel() != numel); use_broadcast[i] = (ins[i]->numel() != numel);
ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>()); ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>());
} }
...@@ -532,7 +532,7 @@ void BroadcastKernelForDifferentVecSize( ...@@ -532,7 +532,7 @@ void BroadcastKernelForDifferentVecSize(
bool is_optimize = configs[0].cmp_type != type; bool is_optimize = configs[0].cmp_type != type;
int vec_size = is_optimize ? VecSizeL : VecSizeM; int vec_size = is_optimize ? VecSizeL : VecSizeM;
#else #else
for (int i = 0; i < kArity; i++) { for (int i = 0; i < kArity; ++i) {
// get the broadcast config, // get the broadcast config,
// if data shape is[m, n], then you should set data_dim = {n, m} // if data shape is[m, n], then you should set data_dim = {n, m}
// eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3} // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
...@@ -541,7 +541,7 @@ void BroadcastKernelForDifferentVecSize( ...@@ -541,7 +541,7 @@ void BroadcastKernelForDifferentVecSize(
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size); merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
} }
} }
int vec_size = GetVecsize<InT, OutT, NumOuts>(ins, outs); int vec_size = GetVecsize<InT, OutT>(ins, outs);
#endif #endif
switch (vec_size) { switch (vec_size) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册