未验证 提交 a0f43889 编写于 作者: L limingshu 提交者: GitHub

Transpose optimization for AlphaFold2 (#45230)

* first commit

* fix bugs according to ci

* add some changes

* change file name into function.cu.h

* remove const_cast
上级 30f4ef7f
......@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/funcs/transpose_functor.cu.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace paddle {
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/funcs/transpose_functor.cu.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
namespace paddle {
......
......@@ -243,5 +243,106 @@ struct BroadcastDimsSimplifier {
}
};
// Simplify the input dims and permute dims if possible.
struct DimsSimplifier {
public:
explicit DimsSimplifier(const int rank,
const int64_t numel,
const std::vector<int32_t> &perm,
const std::vector<int64_t> &dims)
: perm_(rank), src_dims_(rank), count_(numel) {
SimplifyPermAndDims(rank, dims, perm);
perm_.resize(rank_);
src_dims_.resize(rank_);
dst_dims_.resize(rank_);
if (!is_seq_perm_) {
for (auto i = 0; i < rank_; ++i) {
dst_dims_[i] = src_dims_[perm_[i]];
}
} else {
dst_dims_[0] = numel;
src_dims_[0] = numel;
}
}
~DimsSimplifier() = default;
const int &GetRank() const { return rank_; }
const int64_t &GetCount() const { return count_; }
const std::vector<int> &GetPerm() const { return perm_; }
const std::vector<int64_t> &GetSrcDims() const { return src_dims_; }
const std::vector<int64_t> &GetDstDims() const { return dst_dims_; }
private:
int rank_{1};
int64_t count_{0};
bool is_seq_perm_{true};
std::vector<int> perm_;
std::vector<int64_t> src_dims_;
std::vector<int64_t> dst_dims_;
void SimplifyPermAndDims(const int rank,
const std::vector<int64_t> &in_dims,
const std::vector<int32_t> &perm) {
int start_perm_idx = 0;
int valid_dim_idx = 0;
int valid_map[phi::DDim::kMaxRank];
int64_t combined_dims[phi::DDim::kMaxRank];
// Merge consecutive dims to the fist one dim and
// leave original dim to be 1. Example below :
// perm: [2, 3, 0, 1], origin_dims : [4, 8, 2, 5]
// new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1]
while (start_perm_idx < rank) {
const int start_dim_idx = perm[start_perm_idx];
combined_dims[start_dim_idx] = in_dims[start_dim_idx];
int end_perm_idx = start_perm_idx + 1;
while (end_perm_idx < rank &&
perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) {
const int end_dim_idx = perm[end_perm_idx];
combined_dims[start_dim_idx] *= in_dims[end_dim_idx];
combined_dims[end_dim_idx] = 1;
end_perm_idx += 1;
}
start_perm_idx = end_perm_idx;
}
// Reorder combined dims and marked useless dim as -1.
// for example, if combined dims is [32, 1, 10, 1],
// valid_map is [0, -1, 1, -1] and generate simplified
// dims as [32, 10]
for (auto i = 0; i < rank; ++i) {
const int dim_val = combined_dims[i];
if (dim_val == 1) {
valid_map[i] = -1;
} else {
valid_map[i] = valid_dim_idx;
src_dims_[valid_dim_idx] = dim_val;
valid_dim_idx += 1;
}
}
if (valid_dim_idx == 0) {
src_dims_[0] = 1;
perm_[0] = 0;
return;
}
// Acquire simplified perm with help of combined dims
// and original perm, finally simplified perm is [1, 0]
int perm_idx = 0;
for (auto i = 0; i < rank; ++i) {
const int mapped = valid_map[perm[i]];
if (mapped >= 0) {
perm_[perm_idx] = mapped;
is_seq_perm_ &= (mapped == perm_idx);
perm_idx += 1;
}
}
rank_ = is_seq_perm_ ? 1 : valid_dim_idx;
}
};
} // namespace funcs
} // namespace phi
......@@ -27,161 +27,115 @@ enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 };
enum PermuteType {
kCopy = 1,
kTranspose = 2,
kVecPermute = 3,
kGeneralPermute = 4
kSwapTranspose = 2,
kGeneralTranspose = 3,
kVecPermute = 4,
kGeneralPermute = 5
};
constexpr int kBlockRows = 16;
constexpr int kTileSize = 32;
constexpr int kShareCol = (kTileSize + 1);
#define GETTILESIZE(LEN_, ALIGN_) \
((LEN_ + (ALIGN_ - 1)) & ~(ALIGN_ - 1)) / ALIGN_
// Simplify the input dims and permute dims if possible.
template <typename T>
class TranposeTypeClassifier {
struct PermTypeClassifier {
public:
TranposeTypeClassifier(const int sm_count,
const size_t rank,
const int64_t numel,
const std::vector<int32_t>& perm,
const std::vector<int64_t>& dims,
const T* src,
T* dst)
: perm_(rank), src_dims(rank) {
SimplifyPermAndDims(rank, dims, perm);
if (rank_ > 1) {
vec_size_ = GetPermVecSize(sm_count, src, dst);
}
perm_.resize(rank_);
src_dims.resize(rank_);
dst_dims.resize(rank_);
for (auto i = 0; i < rank_; ++i) {
dst_dims[i] = src_dims[perm_[i]];
}
}
int GetRank() const { return rank_; }
int GetVecSize() const { return vec_size_; }
PermuteType GetPermType() const { return type_; }
std::vector<int> GetPerm() const { return perm_; }
std::vector<int64_t> GetSrcDims() const { return src_dims; }
std::vector<int64_t> GetDstDims() const { return dst_dims; }
private:
int rank_{1};
int vec_size_{1};
std::vector<int> perm_;
std::vector<int64_t> src_dims;
std::vector<int64_t> dst_dims;
PermuteType type_{kCopy};
void SimplifyPermAndDims(const size_t rank,
const std::vector<int64_t>& in_dims,
const std::vector<int32_t>& perm) {
int64_t combined_dims[phi::DDim::kMaxRank];
int valid_map[phi::DDim::kMaxRank];
// Merge consecutive dims to the fist one dim and
// leave original dim to be 1. Example below :
// perm: [2, 3, 0, 1], origin_dims : [4, 8, 2, 5]
// new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1]
int start_perm_idx = 0;
while (start_perm_idx < rank) {
const int start_dim_idx = perm[start_perm_idx];
combined_dims[start_dim_idx] = in_dims[start_dim_idx];
int end_perm_idx = start_perm_idx + 1;
while (end_perm_idx < rank &&
perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) {
const int end_dim_idx = perm[end_perm_idx];
combined_dims[start_dim_idx] *= in_dims[end_dim_idx];
combined_dims[end_dim_idx] = 1;
end_perm_idx += 1;
explicit PermTypeClassifier(const int sm_count,
const int rank,
const std::vector<int32_t>& perm,
const std::vector<int64_t>& dims,
const T* src,
T* dst) {
if (rank == 1) {
type_ = PermuteType::kCopy;
} else {
constexpr int64_t dim_limitation = 65536;
const int dst_vec_size = phi::GetVectorizedSize<T>(dst);
// While the last dim is fixed, there is chance for vectorized IO.
const int last_idx = rank - 1;
if (perm[last_idx] == last_idx) {
type_ = PermuteType::kVecPermute;
vec_size_ = GetDimVecSize(dst_vec_size, dims[last_idx], src, false);
return;
}
start_perm_idx = end_perm_idx;
}
// Reorder combined dims and marked useless dim as -1.
// for example, if combined dims is [32, 1, 10, 1],
// valid_map is [0, -1, 1, -1] and generate simplified
// dims as [32, 10]
int valid_dim_idx = 0;
bool sequential_flag = false;
for (auto i = 0; i < rank; ++i) {
const int src_dim = combined_dims[i];
if (src_dim == 1) {
valid_map[i] = -1;
} else {
sequential_flag = true;
valid_map[i] = valid_dim_idx;
src_dims[valid_dim_idx] = src_dim;
valid_dim_idx += 1;
// Permute at last 2 dims, namely transpose.
if ((rank == 2 && perm[1] == 0 && perm[0] == 1) ||
(rank == 3 && perm[2] == 1 && perm[1] == 2)) {
int64_t channel = rank == 2 ? 1 : dims[0];
// Currently, transpose kernel cannot cover the case that channel
// dimension is more than 65536 which is the limitation of dim3 setting.
// This special case will be covered by extended transpose kernel later.
if (channel < dim_limitation) {
type_ = PermuteType::kGeneralTranspose;
num_rows_tile_ = GETTILESIZE(dims[rank - 2], kTileSize);
int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src);
int tile_size =
channel * num_rows_tile_ * GETTILESIZE(dims[last_idx], kTileSize);
vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
} else {
type_ = PermuteType::kGeneralPermute;
}
return;
}
}
if (valid_dim_idx == 0) {
src_dims[0] = 1;
perm_[0] = 0;
return;
} else if (valid_dim_idx == 1) {
type_ = PermuteType::kCopy;
}
// Acquire simplified perm with help of combined dims
// and original perm, finally simplified perm is [1, 0]
int perm_idx = 0;
for (auto i = 0; i < rank; ++i) {
const int mapped = valid_map[perm[i]];
if (mapped >= 0) {
perm_[perm_idx] = mapped;
perm_idx += 1;
// Permute at first dim and third dim.
if (rank == 3 && perm[2] == 0 && perm[1] == 1) {
// Currently, transpose kernel cannot cover the case that channel
// dimension is more than 65536 which is the limitation of dim3 setting.
// This special case will be covered by extended transpose kernel later.
if (dims[1] < dim_limitation) {
type_ = PermuteType::kSwapTranspose;
num_rows_tile_ = GETTILESIZE(dims[0], kTileSize);
int dim_vec_size = GetDimVecSize(dst_vec_size, dims[last_idx], src);
int tile_size =
dims[1] * num_rows_tile_ * GETTILESIZE(dims[2], kTileSize);
vec_size_ = tile_size < sm_count ? 1 : dim_vec_size;
} else {
type_ = PermuteType::kGeneralPermute;
}
return;
}
vec_size_ = dst_vec_size;
}
rank_ = valid_dim_idx;
}
int GetPermVecSize(const int sm_count, const T* src, T* dst) {
// For gerneal_permute kernel, there is good chance for
// vectorized write.
type_ = PermuteType::kGeneralPermute;
int vec_size = phi::GetVectorizedSize<T>(dst);
// While the last dim is fixed, there is good chance for
// both vectorized read and write.
if (perm_[rank_ - 1] == rank_ - 1) {
int tmp_size = std::min(vec_size, phi::GetVectorizedSize<T>(src));
tmp_size = GetDimVesSize(tmp_size, src_dims[rank_ - 1]);
if (tmp_size > 1) {
type_ = kVecPermute;
vec_size = tmp_size;
}
}
~PermTypeClassifier() = default;
// Once only transpose at the last 2 dims, there is good
// chance for vectorized read.
if ((rank_ == 2 && perm_[1] == 0 && perm_[0] == 1) ||
(rank_ == 3 && perm_[2] == 1 && perm_[1] == 2)) {
type_ = PermuteType::kTranspose;
int tmp_vec = std::min(vec_size, phi::GetVectorizedSize<T>(src));
// With bytes limitation of shared_memory, the VecSize shall be
// restricted for the type whose byte-size is less than 8 (double).
vec_size =
sizeof(T) > 8 ? 1 : GetDimVesSize(tmp_vec, src_dims[rank_ - 1]);
}
return vec_size;
}
int GetVecSize() const { return vec_size_; }
int GetRowsTile() const { return num_rows_tile_; }
PermuteType GetPermType() const { return type_; }
private:
int vec_size_{1};
int64_t num_rows_tile_{0};
PermuteType type_{kGeneralPermute};
// To find if highest common divisor and make it as vec_size.
int GetDimVesSize(const int vec_size, const size_t target_dim) {
int GetDimVecSize(const int dst_vec_size,
const int64_t target_dim,
const T* src,
bool use_share_mem = true) {
const int vec_size = std::min(dst_vec_size, phi::GetVectorizedSize<T>(src));
int dim_vec_size = 1;
for (auto size = vec_size; size > 0; size /= 2) {
for (int size = vec_size; size > 0; size /= 2) {
if (target_dim % size == 0) {
dim_vec_size = size;
break;
}
}
return dim_vec_size;
if (use_share_mem) {
// By bytes limitation of shared_memory.
return (sizeof(T) > sizeof(float) ? 1 : dim_vec_size);
} else {
return dim_vec_size;
}
}
};
......
......@@ -21,7 +21,7 @@
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/transpose_functor.cu.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/impl/transpose_grad_kernel_impl.h"
namespace phi {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册