Unverified commit e5ad3859, authored by ZhangDY-6483, committed by GitHub

Memory Efficient Attention (#51867)

* first version, notest

* return final rst, notest

* use infinity() instead of max

* ut structure

* start up of ut

* generate lse

* update

* add depense

* reconstruct cmake

* move file

* add memory efficient attention and fix blasimpl

* update

* update cmake

* add namespace

* update cmake

* use .cu

* update for pad3d

* bug fix

* bug fix

* update

* bug fix

* update enforce

* add test case

* merge the lse pad

* fix kernel_fn of backward

* fix PADDLE_ENFORCE_EQ and phi_api

* fix PADDLE_ENFORCE

* fix PADDLE_ENFORCE

* rerun coverage

* fix memory efficient attention test

* rerun ci

* add cuda version condition

* add cuda version condition

* delete WIP test

* replace PADDLE_ENFORCE

* edit the namespace of datatype in multiple.cc

* rerun

* rerun

---------
Co-authored-by: liuyuang <liuyuang@baidu.com>
Parent 40fea722
......@@ -96,7 +96,7 @@ endfunction()
# Function for selecting GPU arch flags for nvcc based on CUDA_ARCH_NAME
# Usage:
# select_nvcc_arch_flags(out_variable)
function(select_nvcc_arch_flags out_variable)
function(select_nvcc_arch_flags out_variable out_arch_bin)
# List of arch names
set(archs_names
"Kepler"
......@@ -244,6 +244,9 @@ function(select_nvcc_arch_flags out_variable)
set(${out_variable}_real_archs
${nvcc_real_archs}
PARENT_SCOPE)
set(${out_arch_bin}
${cuda_arch_bin}
PARENT_SCOPE)
endfunction()
message(STATUS "CUDA detected: " ${CMAKE_CUDA_COMPILER_VERSION})
......@@ -273,7 +276,7 @@ add_definitions("-DCUDA_VERSION_MINOR=\"${CUDA_VERSION_MINOR}\"")
add_definitions("-DCUDA_TOOLKIT_ROOT_DIR=\"${CUDA_TOOLKIT_ROOT_DIR}\"")
# setting nvcc arch flags
select_nvcc_arch_flags(NVCC_FLAGS_EXTRA)
select_nvcc_arch_flags(NVCC_FLAGS_EXTRA NVCC_ARCH_BIN)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS_EXTRA}")
message(STATUS "NVCC_FLAGS_EXTRA: ${NVCC_FLAGS_EXTRA}")
......
......@@ -963,6 +963,17 @@
kernel :
func : maxout_grad
- backward_op : memory_efficient_attention_grad
forward : memory_efficient_attention (Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor causal_diagonal, Tensor seqlen_k, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale, bool is_test) -> Tensor(output), Tensor(logsumexp), Tensor(seed_and_offset)
args : (Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor output, Tensor logsumexp, Tensor seed_and_offset, Tensor output_grad, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale)
output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad), Tensor(bias_grad)
infer_meta :
func : MemoryEfficientAttentionGradInferMeta
kernel :
func : memory_efficient_attention_grad
data_type : output_grad
optional : bias, cu_seqlens_q, cu_seqlens_k
- backward_op : meshgrid_grad
forward : meshgrid (Tensor[] inputs) -> Tensor[](outputs)
args : (Tensor[] inputs, Tensor[] outputs_grad)
......
......@@ -981,6 +981,17 @@
func : maxout
backward : maxout_grad
- op : memory_efficient_attention
args : (Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor causal_diagonal, Tensor seqlen_k, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale, bool is_test)
output : Tensor(output), Tensor(logsumexp), Tensor(seed_and_offset)
infer_meta :
func : MemoryEfficientAttentionInferMeta
kernel :
func : memory_efficient_attention
data_type : query
optional : bias, cu_seqlens_q, cu_seqlens_k, causal_diagonal, seqlen_k
backward : memory_efficient_attention_grad
- op : meshgrid
args : (Tensor[] inputs)
output : Tensor[]{inputs.size()}
......
......@@ -1052,4 +1052,89 @@ void IndexAddGradInferMeta(const MetaTensor& index,
}
}
void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
const MetaTensor& bias,
const MetaTensor& cu_seqlens_q,
const MetaTensor& cu_seqlens_k,
const MetaTensor& output,
const MetaTensor& logsumexp,
const MetaTensor& seed_and_offset,
const MetaTensor& output_grad,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
MetaTensor* query_grad,
MetaTensor* key_grad,
MetaTensor* value_grad,
MetaTensor* bias_grad) {
PADDLE_ENFORCE_EQ(
output_grad.dims().size(),
4,
phi::errors::InvalidArgument("output_grad should be a 4-D tensor, "
"but received output_grad dimension(%s)",
output_grad.dims().size()));
PADDLE_ENFORCE_EQ(
output.dims().size(),
4,
phi::errors::InvalidArgument("output should be a 4-D tensor, "
"but received output dimension(%s)",
output.dims().size()));
const int64_t query_batch_size = query.dims()[0];
const int64_t query_seq_length = query.dims()[1];
const int64_t query_num_head = query.dims()[2];
const int64_t query_head_size = query.dims()[3];
const int64_t key_batch_size = key.dims()[0];
const int64_t key_seq_length = key.dims()[1];
const int64_t key_num_head = key.dims()[2];
const int64_t key_head_size = key.dims()[3];
const int64_t value_batch_size = value.dims()[0];
const int64_t value_seq_length = value.dims()[1];
const int64_t value_num_head = value.dims()[2];
const int64_t value_head_size = value.dims()[3];
std::vector<int64_t> query_grad_dims(
{query_batch_size, query_seq_length, query_num_head, query_head_size});
std::vector<int64_t> key_grad_dims(
{key_batch_size, key_seq_length, key_num_head, key_head_size});
std::vector<int64_t> value_grad_dims(
{value_batch_size, value_seq_length, value_num_head, value_head_size});
query_grad->set_dims(phi::make_ddim(query_grad_dims));
query_grad->share_lod(query);
query_grad->set_dtype(query.dtype());
query_grad->set_layout(query.layout());
key_grad->set_dims(phi::make_ddim(key_grad_dims));
key_grad->share_lod(key);
key_grad->set_dtype(key.dtype());
key_grad->set_layout(key.layout());
value_grad->set_dims(phi::make_ddim(value_grad_dims));
value_grad->share_lod(value);
value_grad->set_dtype(value.dtype());
value_grad->set_layout(value.layout());
if (bias) {
const int64_t bias_batch_size = bias.dims()[0];
const int64_t bias_seq_length = bias.dims()[1];
const int64_t bias_num_head = bias.dims()[2];
const int64_t bias_head_size = bias.dims()[3];
std::vector<int64_t> bias_grad_dims(
{bias_batch_size, bias_seq_length, bias_num_head, bias_head_size});
bias_grad->set_dims(phi::make_ddim(bias_grad_dims));
bias_grad->share_lod(bias);
bias_grad->set_dtype(bias.dtype());
bias_grad->set_layout(bias.layout());
}
}
} // namespace phi
......@@ -418,4 +418,24 @@ void IndexAddGradInferMeta(const MetaTensor& index,
MetaTensor* x_grad,
MetaTensor* add_tensor_grad);
void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
const MetaTensor& bias,
const MetaTensor& cu_seqlens_q,
const MetaTensor& cu_seqlens_k,
const MetaTensor& output,
const MetaTensor& logsumexp,
const MetaTensor& seed_and_offset,
const MetaTensor& output_grad,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
MetaTensor* query_grad,
MetaTensor* key_grad,
MetaTensor* value_grad,
MetaTensor* bias_grad);
} // namespace phi
......@@ -3124,6 +3124,94 @@ void MoeInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}
} // namespace phi
void MemoryEfficientAttentionInferMeta(const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
const MetaTensor& bias,
const MetaTensor& cu_seqlens_q,
const MetaTensor& cu_seqlens_k,
const MetaTensor& causal_diagonal,
const MetaTensor& seqlen_k,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
const bool is_test,
MetaTensor* output,
MetaTensor* logsumexp,
MetaTensor* seed_and_offset) {
PADDLE_ENFORCE_EQ(
query.dims().size(),
4,
phi::errors::InvalidArgument("Query should be a 4-D tensor, "
"but received Query dimension(%s)",
query.dims().size()));
PADDLE_ENFORCE_EQ(
key.dims().size(),
4,
phi::errors::InvalidArgument("Key should be a 4-D tensor, "
"but received Key dimension(%s)",
key.dims().size()));
PADDLE_ENFORCE_EQ(
value.dims().size(),
4,
phi::errors::InvalidArgument("Value should be a 4-D tensor, "
"but received Value dimension(%s)",
value.dims().size()));
const int64_t query_batch_size = query.dims()[0];
const int64_t query_seq_length = query.dims()[1];
const int64_t query_num_head = query.dims()[2];
const int64_t query_head_size = query.dims()[3];
const int64_t key_batch_size = key.dims()[0];
const int64_t key_seq_length = key.dims()[1];
const int64_t key_num_head = key.dims()[2];
const int64_t key_head_size = key.dims()[3];
const int64_t value_batch_size = value.dims()[0];
const int64_t value_seq_length = value.dims()[1];
const int64_t value_num_head = value.dims()[2];
const int64_t value_head_size = value.dims()[3];
PADDLE_ENFORCE_EQ(((query_batch_size == key_batch_size) &&
(key_batch_size == value_batch_size)),
true,
phi::errors::InvalidArgument(
"The batchsize of Query, Key, Value should be equal."));
PADDLE_ENFORCE_EQ(
((query_num_head == key_num_head) && (key_num_head == value_num_head)),
true,
phi::errors::InvalidArgument(
"The head number of Query, Key, Value should be equal."));
PADDLE_ENFORCE_EQ(query_head_size == key_head_size,
true,
phi::errors::InvalidArgument(
"The head size of Query, Key should be equal."));
PADDLE_ENFORCE_EQ(key_seq_length == value_seq_length,
true,
phi::errors::InvalidArgument(
"The seq length of Key, Value should be equal."));
std::vector<int64_t> out_dims(
{query_batch_size, query_seq_length, query_num_head, value_head_size});
std::vector<int64_t> logsumexp_dims({query_num_head, query_batch_size});
std::vector<int64_t> seed_and_offset_dims({2});
output->set_dims(phi::make_ddim(out_dims));
output->share_lod(query);
output->set_dtype(query.dtype());
output->set_layout(query.layout());
logsumexp->set_dims(phi::make_ddim(logsumexp_dims));
logsumexp->set_dtype(phi::DataType::FLOAT32);
seed_and_offset->set_dims(phi::make_ddim(seed_and_offset_dims));
seed_and_offset->set_dtype(phi::DataType::INT64);
}
} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
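For reference, the shape contract enforced by MemoryEfficientAttentionInferMeta above reduces to simple arithmetic on the 4-D [batch, seq_len, num_heads, head_size] inputs. The sketch below is illustrative only and is not part of the patch; the helper names are made up.

// Illustrative sketch (not part of the patch): restates the shape rules that
// MemoryEfficientAttentionInferMeta enforces for 4-D [B, M, H, K] inputs.
#include <array>
#include <cassert>
#include <cstdint>
using Dims4 = std::array<int64_t, 4>;  // {batch, seq_len, num_heads, head_size}
Dims4 InferAttentionOutputDims(const Dims4& q, const Dims4& k, const Dims4& v) {
  assert(q[0] == k[0] && k[0] == v[0]);  // equal batch size
  assert(q[2] == k[2] && k[2] == v[2]);  // equal head count
  assert(q[3] == k[3]);                  // Q and K share head size
  assert(k[1] == v[1]);                  // K and V share sequence length
  // output: [query_batch, query_seq_len, num_heads, value_head_size]
  return {q[0], q[1], q[2], v[3]};
}
int main() {
  Dims4 out = InferAttentionOutputDims({2, 128, 8, 64}, {2, 256, 8, 64}, {2, 256, 8, 32});
  assert((out == Dims4{2, 128, 8, 32}));
  return 0;
}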
......@@ -587,4 +587,22 @@ void MoeInferMeta(const MetaTensor& x,
const std::string& act_type,
MetaTensor* out);
void MemoryEfficientAttentionInferMeta(const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
const MetaTensor& bias,
const MetaTensor& cu_seqlens_q,
const MetaTensor& cu_seqlens_k,
const MetaTensor& causal_diagonal,
const MetaTensor& seqlen_k,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
const bool is_test,
MetaTensor* output,
MetaTensor* logsumexp,
MetaTensor* seed_and_offset);
} // namespace phi
......@@ -125,8 +125,15 @@ if(WITH_CUTLASS)
COMMAND ${PYTHON_EXECUTABLE} "conv2d_bias_residual.py"
WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass/conv2d")
execute_process(
COMMAND
${PYTHON_EXECUTABLE}
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_kernels.py
--cuda_arch "${NVCC_ARCH_BIN}")
file(GLOB cutlass_cu "fusion/cutlass/conv2d/generated/*.cu"
"fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu")
"fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu"
"fusion/cutlass/memory_efficient_attention/autogen/impl/*.cu")
add_definitions("-DPADDLE_WITH_MEMORY_EFFICIENT_ATTENTION")
list(APPEND kernel_cu ${cutlass_cu})
endif()
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
#include <cuda_runtime_api.h> // NOLINT
#include "cuda.h" // NOLINT
#include "paddle/phi/backends/dynload/cublasLt.h"
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ref:
// https://github.com/facebookresearch/xformers/blob/b6be33aecb5297f3f994568cf29e194a75e47667/xformers/ops/fmha/common.py#L102
#pragma once
#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/pad3d_kernel.h"
namespace phi {
namespace funcs {
using phi::PADDLE_CUDA_NUM_THREADS;
template <typename T>
__global__ void ViewSliceHelper(T* data,
int stride,
int in_last_dim,
int out_second_dim) {
CUDA_KERNEL_LOOP_TYPE(i, stride * in_last_dim, int64_t) {
if (i % in_last_dim >= out_second_dim) {
*(data + i) = std::numeric_limits<T>::infinity();
}
}
}
template <typename T>
phi::DenseTensor get_pad_lse(const phi::GPUContext& dev_ctx,
phi::DenseTensor* lse,
int out_second_dim,
int pad_to,
const std::string& data_format = "NCHW",
bool force_pad_inf = false) {
int pad_amount = (pad_to - (lse->dims()[2] % pad_to)) % pad_to;
PADDLE_ENFORCE_EQ(
lse->dims().size(),
3,
phi::errors::InvalidArgument("The lse should be a 3d tensor"));
PADDLE_ENFORCE_EQ(
(data_format == "NCHW" || data_format == "NHWC"),
true,
phi::errors::InvalidArgument("The data_format should be NCHW or NHWC"));
std::string pad3d_data_format = data_format == "NCHW" ? "NCDHW" : "NDHWC";
if (pad_amount > 0) {
phi::DenseTensor tmp = *lse;
if (force_pad_inf) {
tmp = phi::funcs::Slice<T, phi::GPUContext>(
dev_ctx, *lse, {2}, {0}, {out_second_dim});
pad_amount = (pad_to - (tmp.dims()[2] % pad_to)) % pad_to;
}
tmp.Resize({tmp.dims()[0], tmp.dims()[1], tmp.dims()[2], 1, 1});
phi::DenseTensor out;
out.Resize({1, 1, 1, 1, 1});
phi::Pad3dKernel<T, phi::GPUContext>(dev_ctx,
tmp,
{0, 0, 0, 0, 0, pad_amount},
"constant",
std::numeric_limits<T>::infinity(),
pad3d_data_format,
&out);
out.Resize({out.dims()[0], out.dims()[1], out.dims()[2]});
return out;
} else if (force_pad_inf && out_second_dim != lse->dims()[2]) {
auto in_dim = lse->dims();
auto in_data = lse->template data<T>();
int stride = in_dim[0] * in_dim[1];
int block = PADDLE_CUDA_NUM_THREADS;
int64_t n = lse->numel();
dim3 grid = dim3((n + block - 1) / block);
phi::backends::gpu::LimitGridDim(dev_ctx, &grid);
ViewSliceHelper<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, stride, in_dim[2], out_second_dim);
return *lse;
}
// Fall through: no padding and no in-place masking needed.
return *lse;
}
} // namespace funcs
} // namespace phi
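The padding amount in get_pad_lse follows the usual round-up-to-multiple formula, pad_amount = (pad_to - (len % pad_to)) % pad_to, with the padded positions filled with +infinity. A minimal, standalone illustration of that arithmetic follows; it is not part of the patch.

// Illustrative sketch (not part of the patch): the round-up-to-multiple
// arithmetic used by get_pad_lse when padding the last dimension of lse.
#include <cassert>
#include <cstdint>
int64_t PadAmount(int64_t len, int64_t pad_to) {
  // 0 when len is already a multiple of pad_to, otherwise the gap to the
  // next multiple.
  return (pad_to - (len % pad_to)) % pad_to;
}
int main() {
  assert(PadAmount(96, 32) == 0);    // already aligned
  assert(PadAmount(100, 32) == 28);  // padded up to 128
  assert(PadAmount(1, 32) == 31);
  return 0;
}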
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/autogen/memory_efficient_attention.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename T, typename Context>
void MemoryEfficientAttentionForwardKernel(
const Context& ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& cu_seqlens_q,
const paddle::optional<DenseTensor>& cu_seqlens_k,
const paddle::optional<DenseTensor>& causal_diagonal,
const paddle::optional<DenseTensor>& seqlen_k,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
const bool is_test,
DenseTensor* output,
DenseTensor* logsumexp,
DenseTensor* seed_and_offset) {
int compute_capacity = ctx.GetComputeCapability();
const auto max_shmem =
getMaximumSharedMemoryPerBlockKb(compute_capacity) * 1024;
bool kernel_launched = false;
auto max_seqlen_q_num = max_seqlen_q.to<uint64_t>();
auto max_seqlen_k_num = max_seqlen_k.to<uint64_t>();
auto launchKernel = [&](auto k_, auto kernel_fn) {
using KernelType = decltype(k_);
bool is_launched = kernel_launched;
if (is_launched) {
return;
}
using scalar_t = typename KernelType::scalar_t;
bool use_dropout = (dropout_p != 0);
if (!KernelType::kSupportsDropout && use_dropout) {
VLOG(3) << "run in to use dropout" << use_dropout;
return;
}
if (!KernelType::kSupportsBias && bias) {
VLOG(3) << "run in to bias";
return;
}
const auto& v_dims = value.dims();
if (KernelType::kSingleValueIteration &&
KernelType::kKeysPerBlock < v_dims[3]) {
VLOG(3) << "run in to value dim" << v_dims;
return;
}
const auto& k_dims = key.dims();
const auto& q_dims = query.dims();
int64_t max_seqlen_q_tmp, max_seqlen_k_tmp;
if (cu_seqlens_q) {
max_seqlen_q_tmp = max_seqlen_q_num;
max_seqlen_k_tmp = 0; // Will be set inside the kernel
} else {
max_seqlen_q_tmp = q_dims[1];
max_seqlen_k_tmp = k_dims[1];
}
VLOG(3) << "max_seqlen_q_tmp " << max_seqlen_q_tmp;
if ((q_dims[3] % KernelType::kAlignmentQ) ||
(k_dims[3] % KernelType::kAlignmentK) ||
(v_dims[3] % KernelType::kAlignmentV)) {
VLOG(3) << "run in to query dim" << q_dims;
VLOG(3) << "run in to key dim" << k_dims;
return;
}
size_t smem_bytes = sizeof(typename KernelType::SharedStorage);
if (smem_bytes > max_shmem) {
VLOG(3) << "run in to shmem" << smem_bytes << " " << max_shmem;
return;
}
kernel_launched = true;
VLOG(3) << "launching";
output->Resize({q_dims[0], q_dims[1], q_dims[2], v_dims[3]});
constexpr int64_t kAlignLSE = KernelType::kAlignLSE;
phi::Dim<3> logsumexp_dims;
logsumexp_dims[0] =
cu_seqlens_q ? cu_seqlens_q.get().dims()[0] - 1 : q_dims[0];
logsumexp_dims[1] = q_dims[2];
logsumexp_dims[2] =
is_test ? 0 : (max_seqlen_q_tmp + kAlignLSE - 1) / kAlignLSE;
logsumexp_dims[2] *= kAlignLSE;
logsumexp->Resize(logsumexp_dims);
ctx.template Alloc<float>(logsumexp);
VLOG(3) << "logsumexp dims" << logsumexp_dims;
VLOG(3) << "logsumexp" << logsumexp;
VLOG(3) << "kAlignLSE" << kAlignLSE;
typename KernelType::Params p;
p.query_ptr = SafeGetTensorPtr<scalar_t>(query);
p.key_ptr = SafeGetTensorPtr<scalar_t>(key);
p.value_ptr = SafeGetTensorPtr<scalar_t>(value);
p.logsumexp_ptr = is_test ? nullptr : logsumexp->data<float>();
VLOG(3) << "logsumexp_ptr" << p.logsumexp_ptr;
DenseTensor out_accum;
if (KernelType::kNeedsOutputAccumulatorBuffer) {
out_accum.Resize(output->dims());
p.output_accum_ptr =
SafeAllocTensor<typename KernelType::output_accum_t, Context>(
ctx, &out_accum);
VLOG(3) << "output_accum_ptr " << p.output_accum_ptr;
} else {
p.output_accum_ptr = nullptr;
}
p.output_ptr =
SafeAllocTensor<typename KernelType::output_t, Context>(ctx, output);
VLOG(3) << "output_ptr " << p.output_ptr;
if (cu_seqlens_q) {
p.seqstart_q_ptr = SafeGetTensorPtr<int32_t>(cu_seqlens_q);
p.seqstart_k_ptr = SafeGetTensorPtr<int32_t>(cu_seqlens_k);
VLOG(3) << "seqstart_q_ptr " << p.seqstart_q_ptr;
} else {
p.seqstart_q_ptr = nullptr;
p.seqstart_k_ptr = nullptr;
}
p.num_heads = q_dims[2];
p.head_dim = q_dims[3];
p.head_dim_value = v_dims[3];
p.num_queries = max_seqlen_q_tmp;
p.num_keys = max_seqlen_k_tmp;
p.num_batches = cu_seqlens_q ? cu_seqlens_q.get().dims()[0] - 1 : q_dims[0];
p.causal = causal;
if (causal_diagonal) {
p.causal_diagonal_ptr = SafeGetTensorPtr<int32_t>(causal_diagonal);
} else {
p.causal_diagonal_ptr = nullptr;
}
VLOG(3) << "causal_diagonal_ptr " << p.causal_diagonal_ptr;
p.seqlen_k_ptr = nullptr;
if (seqlen_k) {
p.seqlen_k_ptr = SafeGetTensorPtr<int32_t>(seqlen_k);
} else {
p.seqlen_k_ptr = nullptr;
}
VLOG(3) << "seqlen_k_ptr " << p.seqlen_k_ptr;
if (scale < 0) {
p.scale = static_cast<float>(1.0 / std::sqrt(p.head_dim));
} else {
p.scale = scale;
}
VLOG(3) << "scale " << p.scale;
p.q_strideB = DimStride(query.dims(), 0);
p.k_strideB = DimStride(key.dims(), 0);
p.v_strideB = DimStride(value.dims(), 0);
p.q_strideM = DimStride(query.dims(), 1);
p.k_strideM = DimStride(key.dims(), 1);
p.v_strideM = DimStride(value.dims(), 1);
p.q_strideH = DimStride(query.dims(), 2);
p.k_strideH = DimStride(key.dims(), 2);
p.v_strideH = DimStride(value.dims(), 2);
p.o_strideM = DimStride(output->dims(), 1);
if (bias) {
p.attn_bias_ptr = SafeGetTensorPtr<scalar_t>(bias);
p.bias_strideB = q_dims[2] * q_dims[1] * k_dims[1];
p.bias_strideH = q_dims[1] * k_dims[1];
p.bias_strideM = k_dims[1];
} else {
p.attn_bias_ptr = nullptr;
}
VLOG(3) << "attn_bias_ptr " << p.attn_bias_ptr;
VLOG(3) << "bias_strideB " << p.bias_strideB;
VLOG(3) << "bias_strideH " << p.bias_strideH;
VLOG(3) << "bias_strideM " << p.bias_strideM;
phi::Dim<1> seed_dims;
seed_dims[0] = 2;
seed_and_offset->Resize(seed_dims);
ctx.template HostAlloc<int64_t>(seed_and_offset);
int64_t* seed_and_offset_ptr = SafeGetTensorPtr<int64_t>(seed_and_offset);
auto gen = ctx.GetGenerator();
uint64_t inc = query.dims()[0] * query.dims()[2] * 32;
auto seed_offset_pair = gen->IncrementOffset(inc);
auto seed = (seed_offset_pair.first);
auto offset = (seed_offset_pair.second);
seed_and_offset_ptr[0] = (int64_t)seed;
seed_and_offset_ptr[1] = (int64_t)offset;
VLOG(3) << "seed and offset: " << seed << " " << offset << " "
<< seed_and_offset_ptr;
p.use_dropout = use_dropout;
if (use_dropout) {
p.seed = seed;
p.offset = offset;
p.dropout_prob = dropout_p;
} else {
p.dropout_prob = 0.0;
}
if (smem_bytes > 0xc000) {
const void* kernel_fn_void_ptr =
reinterpret_cast<const void*>(reinterpret_cast<uintptr_t>(kernel_fn));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaFuncSetAttribute(kernel_fn_void_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_bytes));
}
KernelType::check_supported(p);
VLOG(3) << "Kernel launched with func : " << typeid(kernel_fn).name()
<< " block dim " << p.getBlocksGrid() << " thread dim "
<< p.getThreadsGrid();
kernel_fn<<<p.getBlocksGrid(),
p.getThreadsGrid(),
smem_bytes,
ctx.stream()>>>(p);
};
dispatch_cutlass_forward<T>(ctx, launchKernel);
PADDLE_ENFORCE_EQ(kernel_launched,
true,
paddle::platform::errors::InvalidArgument(
"the memory efficient attention kernel was not launched; "
"no compiled kernel configuration supports the given inputs"));
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(
memory_efficient_attention,
GPU,
ALL_LAYOUT,
phi::fusion::cutlass_internal::MemoryEfficientAttentionForwardKernel,
float,
phi::dtype::bfloat16,
phi::dtype::float16) {}
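One detail worth calling out from the launcher above: the last logsumexp dimension is rounded up to a multiple of KernelType::kAlignLSE, and left at 0 in inference mode (is_test == true). The following sketch only restates that round-up and is not part of the patch; kAlignLSE here stands in for the kernel's compile-time constant.

// Illustrative sketch (not part of the patch): how the launcher sizes the
// last logsumexp dimension.
#include <cassert>
#include <cstdint>
int64_t LogsumexpLastDim(int64_t max_seqlen_q, int64_t kAlignLSE, bool is_test) {
  if (is_test) return 0;
  return (max_seqlen_q + kAlignLSE - 1) / kAlignLSE * kAlignLSE;
}
int main() {
  assert(LogsumexpLastDim(100, 32, /*is_test=*/false) == 128);
  assert(LogsumexpLastDim(128, 32, /*is_test=*/false) == 128);
  assert(LogsumexpLastDim(100, 32, /*is_test=*/true) == 0);
  return 0;
}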
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <float.h>
#include <stdio.h>
#include <cmath>
////////////////////////////////////////////////////////////////////////////////
// Debugging functions
////////////////////////////////////////////////////////////////////////////////
// Nans & inf detection
#define NANCHECK(frag) \
{ \
for (int _i = 0; _i < frag.size(); ++_i) { \
assert(std::isfinite(static_cast<float>(frag[_i]))); \
assert(!std::isnan(static_cast<float>(frag[_i]))); \
} \
}
// Print on the first thread of the first block
#if 1
#define PRINT_WARP_ID 0
#define PRINT_LANE_ID 0
#define PRINT_B0_T0(msg, ...) \
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \
threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
threadIdx.z == 0) { \
printf(msg "\n", ##__VA_ARGS__); \
}
#define PRINT_T0(msg, ...) \
if (threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
threadIdx.z == 0) { \
printf(msg "\n", ##__VA_ARGS__); \
}
#define PRINT_TX_LX(msg, ...) \
for (int bx = 0; bx < gridDim.x; ++bx) { \
for (int by = 0; by < gridDim.y; ++by) { \
for (int bz = 0; bz < gridDim.z; ++bz) { \
for (int tx = 0; tx < blockDim.x; ++tx) { \
for (int ty = 0; ty < blockDim.y; ++ty) { \
for (int tz = 0; tz < blockDim.z; ++tz) { \
__syncthreads(); \
if (blockIdx.x == bx && blockIdx.y == by && blockIdx.z == bz && \
threadIdx.x == tx && threadIdx.y == ty && \
threadIdx.z == tz) { \
printf("[%d,%d,%d][%d,%d,%d]" msg "\n", \
bx, \
by, \
bz, \
tx, \
ty, \
tz, \
##__VA_ARGS__); \
} \
} \
} \
} \
} \
} \
}
#else
#define PRINT_B0_T0
#define PRINT_TX_LX
#endif
struct __string_view {
char const* data;
std::size_t size;
};
#if __cplusplus >= 201402L
template <class T>
constexpr __string_view __get_type_name() {
char const* p = __PRETTY_FUNCTION__;
while (*p++ != '=')
; // NOLINT
for (; *p == ' '; ++p)
; // NOLINT
char const* p2 = p;
int count = 1;
for (;; ++p2) {
switch (*p2) {
case '[':
++count;
break;
case ']':
--count;
if (!count) return {p, std::size_t(p2 - p)};
}
}
return {};
}
#else
template <class T>
constexpr __string_view __get_type_name() {
return {"unsupported", 11};
}
#endif
// Print a given array
#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \
PRINT_T0_L0("%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \
name, \
static_cast<int>(start), \
static_cast<int>(start + 8), \
static_cast<float>(accum[start + 0]), \
static_cast<float>(accum[start + 1]), \
static_cast<float>(accum[start + 2]), \
static_cast<float>(accum[start + 3]), \
static_cast<float>(accum[start + 4]), \
static_cast<float>(accum[start + 5]), \
static_cast<float>(accum[start + 6]), \
static_cast<float>(accum[start + 7]));
#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0)
#define PRINT_FRAG_T0_L0(name, frag) \
{ \
auto typeStr = __get_type_name<decltype(frag)>(); \
PRINT_T0_L0("printing %s (%s)", name, typeStr.data); \
for (int _start = 0; _start < frag.size(); _start += 8) { \
PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \
} \
/*__syncthreads(); NANCHECK(frag); */ \
}
#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \
{ \
PRINT_T0_L0("printing %s (len=%d)", name, static_cast<int>(length)); \
for (int _start = 0; _start < length; _start += incr) { \
PRINT_ACCUM8_T0_L0_START(" ", array, _start); \
} \
}
#define PRINT_ARRAY_T0_L0(name, array, length) \
PRINT_ARRAY_T0_L0_INCR(name, array, length, 8)
// Print a 4x4 matrix
#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \
PRINT_T0_L0( \
"%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, " \
"%f, %f, %f\n %f, %f, %f, %f", \
name, \
static_cast<int>(start_x), \
static_cast<int>(start_x + 4), \
static_cast<int>(start_y), \
static_cast<int>(start_y + 4), \
static_cast<float>(ref.at({start_x + 0, start_y + 0})), \
static_cast<float>(ref.at({start_x + 0, start_y + 1})), \
static_cast<float>(ref.at({start_x + 0, start_y + 2})), \
static_cast<float>(ref.at({start_x + 0, start_y + 3})), \
static_cast<float>(ref.at({start_x + 1, start_y + 0})), \
static_cast<float>(ref.at({start_x + 1, start_y + 1})), \
static_cast<float>(ref.at({start_x + 1, start_y + 2})), \
static_cast<float>(ref.at({start_x + 1, start_y + 3})), \
static_cast<float>(ref.at({start_x + 2, start_y + 0})), \
static_cast<float>(ref.at({start_x + 2, start_y + 1})), \
static_cast<float>(ref.at({start_x + 2, start_y + 2})), \
static_cast<float>(ref.at({start_x + 2, start_y + 3})), \
static_cast<float>(ref.at({start_x + 3, start_y + 0})), \
static_cast<float>(ref.at({start_x + 3, start_y + 1})), \
static_cast<float>(ref.at({start_x + 3, start_y + 2})), \
static_cast<float>(ref.at({start_x + 3, start_y + 3})));
#define PRINT_TENSOR4x4_T0_L0(name, ref) \
PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0)
#define PRINT_PROBLEM_SIZE(name, ps) \
PRINT_T0_L0("%s.problem_size: {.m=%d, .n=%d, .k=%d}", \
name, \
static_cast<int>(ps.m()), \
static_cast<int>(ps.n()), \
static_cast<int>(ps.k()))
template <typename LambdaIterator, typename LaneOffsetT, typename AccumT>
CUTLASS_DEVICE void print_warp_accum(AccumT accum,
LaneOffsetT lane_offset,
int32_t num_rows,
int32_t num_cols) {
bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&
threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0;
for (int row = 0; row < num_rows; ++row) {
for (int col = 0; col < num_cols; ++col) {
if (col % 32 == 0) {
if (is_main) {
printf("\nmat[%3d, %3d:%3d]", row, col, col + 32);
}
__syncthreads();
}
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
[&](int accum_m, int accum_n, int idx) {
if (row == accum_m && col == accum_n &&
(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)) {
printf(" %6.1f", static_cast<float>(accum[idx]));
}
},
[&](int accum_m) {});
__syncthreads();
}
if (is_main) {
printf("\n");
}
}
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory
to match canonical tensor layouts in global memory. Epilogues support
conversion and reduction operations.
This is a copy of cutlass/epilogue/threadblock/epilogue.h that can
handle "row_id" as a first argument, as uses it to get the corresponding
`m_prime` / `s_prime` to rescale the output.
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "./epilogue_pipelined.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator to an array of elements.
// output <- alpha * accumulator + beta * source
// with:
// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise)
// beta = alpha * m_prime (renormalize the output when the max changes)
// source is the current output
template <typename ElementOutput_, /// < Data type used to store tensors
typename ElementSource_, // < Data type for source (usually matches
// `ElementOutput`)
int Count, ///< Number of elements computed per operation.
///< Usually it is 128/sizeof_bits<ElementOutput_>,
///< but we use 64 or 32 sometimes when there are not
///< enough data to store
typename ElementAccumulator_, ///< Accumulator data type
typename ElementCompute_, ///< Data type used to compute linear
///< combination
bool isFirst,
bool isLast,
typename FragmentAlphaBeta_,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class MemoryEfficientAttentionNormalize {
public:
using ElementOutput = ElementOutput_;
using ElementSource = ElementSource_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kCount = Count;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentSource = Array<ElementSource, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;
using FragmentAlphaBeta = FragmentAlphaBeta_;
static FloatRoundStyle const kRound = Round;
private:
//
// Data members
//
FragmentAlphaBeta const& s_prime_;
FragmentAlphaBeta const& m_prime_;
public:
/// Constructs the function object, possibly loading from pointers in host
/// memory
CUTLASS_HOST_DEVICE
MemoryEfficientAttentionNormalize(FragmentAlphaBeta const& s_prime,
FragmentAlphaBeta const& m_prime)
: s_prime_(s_prime), m_prime_(m_prime) {}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const { return !isFirst; }
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(int row,
FragmentAccumulator const& accumulator,
FragmentSource const& source) const {
assert(!isFirst);
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementSource, kCount, Round>
source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
ElementCompute beta = alpha * m_prime_[row];
intermediate = mul_add_source(beta, converted_source); // X = beta * C
intermediate = mul_add_accumulator(
alpha, converted_accumulator, intermediate); // D = alpha * Accum + X
return destination_converter(intermediate);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(int row,
FragmentAccumulator const& accumulator) const {
assert(isFirst);
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_accumulator;
ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
intermediate = mul_accumulator(
alpha, converted_accumulator); // X = alpha * C + uniform
return destination_converter(intermediate);
}
};
} // namespace thread
namespace threadblock {
template <typename EO,
typename ES,
int Count,
typename EA,
typename EC,
bool F,
bool L,
typename FAB,
FloatRoundStyle R>
struct ApplyEpilogueOp<thread::MemoryEfficientAttentionNormalize<EO,
ES,
Count,
EA,
EC,
F,
L,
FAB,
R>> {
using Op = thread::
MemoryEfficientAttentionNormalize<EO, ES, Count, EA, EC, F, L, FAB, R>;
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum,
typename Op::FragmentSource const& source) {
return output_op(row_id, accum, source);
}
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum) {
return output_op(row_id, accum);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
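The functor above performs a per-row affine rescale of the accumulator fragment. Stripped of the fragment and converter machinery, the scalar math is as follows. This is an illustrative sketch, not the CUTLASS implementation, and the function name is made up.

// Illustrative sketch (not part of the patch): the per-row scalar math behind
// MemoryEfficientAttentionNormalize. alpha normalizes by s_prime on the last
// partial result; beta re-scales the previously written output when the
// running row maximum changed.
#include <cassert>
#include <cmath>
float NormalizeOutput(float accum, float source, float s_prime, float m_prime,
                      bool is_first, bool is_last) {
  const float alpha = is_last ? 1.0f / s_prime : 1.0f;
  if (is_first) {
    return alpha * accum;  // no previous output to blend in
  }
  const float beta = alpha * m_prime;
  return alpha * accum + beta * source;  // D = alpha * Accum + beta * C
}
int main() {
  // With s_prime == 1 and m_prime == 1 the final iteration is a plain sum.
  assert(std::fabs(NormalizeOutput(2.0f, 3.0f, 1.0f, 1.0f, false, true) - 5.0f) < 1e-6f);
  return 0;
}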
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination operations used by epilogues.
*/
#pragma once
#include <cuda_fp16.h>
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
template <typename Element, int ElementsPerAccess>
struct ArrayExponential {
CUTLASS_HOST_DEVICE
Array<Element, ElementsPerAccess> operator()(
Array<Element, ElementsPerAccess> const& input) const {
Array<Element, ElementsPerAccess> result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ElementsPerAccess; ++i) {
result[i] = expf(input[i]);
}
return result;
}
};
template <int ElementsPerAccess>
struct ArrayExponential<half_t, ElementsPerAccess> {
CUTLASS_DEVICE
Array<half_t, ElementsPerAccess> operator()(
Array<half_t, ElementsPerAccess> const& input) const {
Array<half_t, ElementsPerAccess> result;
int const kVectorCount = ElementsPerAccess / 2;
__half2 const* input_ptr =
reinterpret_cast<__half2 const*>(input.raw_data());
__half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data());
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kVectorCount; ++i) {
res_ptr[i] = h2exp(input_ptr[i]);
}
return result;
}
};
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies:
/// output <- (input - lse).exp()
template <typename ElementOutput_, // output
typename ElementLSE_, // accumulator from LSE
typename ElementAccumulator_, // accumulator from matmul
typename ElementCompute_, // intermediate compute (and exp
// calculation)
int ElementsPerAccess>
class ApplyLogSumExp {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using ElementLSE = ElementLSE_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;
static const ScaleType::Kind kScale =
cutlass::epilogue::thread::ScaleType::NoBetaScaling;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
using FragmentLSE = Array<ElementLSE, kElementsPerAccess>;
using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h
public:
//
// Methods
//
CUTLASS_HOST_DEVICE
ApplyLogSumExp() {}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const { return true; }
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {}
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const& AB,
FragmentLSE const& scale_unused,
// bias used as LSE
FragmentLSE const& bias) const {
FragmentCompute frag_AB = NumericArrayConverter<ElementCompute,
ElementAccumulator,
kElementsPerAccess>()(AB);
FragmentCompute frag_lse_compute =
NumericArrayConverter<ElementCompute, ElementLSE, kElementsPerAccess>()(
bias);
FragmentCompute frag_compute;
minus<FragmentCompute> minus_lse;
detail::ArrayExponential<ElementCompute, kElementsPerAccess> apply_exp;
frag_compute = minus_lse(frag_AB, frag_lse_compute);
frag_compute = apply_exp(frag_compute);
return NumericArrayConverter<ElementOutput,
ElementCompute,
kElementsPerAccess>()(frag_compute);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
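ApplyLogSumExp reduces element-wise to output = exp(input - lse), i.e. the exponentiation step of a numerically stable softmax in which lse is the per-row log-sum-exp. A scalar sketch, illustrative only:

// Illustrative sketch (not part of the patch): the element-wise math behind
// ApplyLogSumExp. Subtracting the per-row log-sum-exp before exponentiating
// yields the normalized softmax value without overflow.
#include <cassert>
#include <cmath>
float ApplyLogSumExpScalar(float accum, float lse) {
  return std::exp(accum - lse);
}
int main() {
  // For a row with logits {0, 0}, lse = log(2) and each probability is 0.5.
  const float lse = std::log(2.0f);
  assert(std::fabs(ApplyLogSumExpScalar(0.0f, lse) - 0.5f) < 1e-6f);
  return 0;
}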
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include "./custom_mma_multistage.h"
#include "./custom_mma_pipelined.h"
#include "cutlass/gemm/threadblock/mma_multistage.h"
#include "cutlass/gemm/threadblock/mma_pipelined.h"
template <typename Mma, int kMaxK>
struct MakeCustomMma;
template <typename Shape,
typename IteratorA,
typename SmemIteratorA,
cutlass::arch::CacheOperation::Kind CacheOpA,
typename IteratorB,
typename SmemIteratorB,
cutlass::arch::CacheOperation::Kind CacheOpB,
typename ElementC,
typename LayoutC,
typename Policy,
int Stages,
cutlass::gemm::SharedMemoryClearOption SharedMemoryClear,
int kMaxK>
struct MakeCustomMma<
cutlass::gemm::threadblock::MmaMultistage<Shape,
IteratorA,
SmemIteratorA,
CacheOpA,
IteratorB,
SmemIteratorB,
CacheOpB,
ElementC,
LayoutC,
Policy,
Stages,
SharedMemoryClear>,
kMaxK> {
// Reduce the number of stages if we don't need that many
static int constexpr kStages =
kMaxK == cutlass::platform::numeric_limits<int>::max()
? Stages
: cutlass::const_min(Stages,
(kMaxK + static_cast<int>(Shape::kK) - 1) /
static_cast<int>(Shape::kK));
using Mma = cutlass::gemm::threadblock::CustomMmaMultistage<Shape,
IteratorA,
SmemIteratorA,
CacheOpA,
IteratorB,
SmemIteratorB,
CacheOpB,
ElementC,
LayoutC,
Policy,
kStages,
SharedMemoryClear,
kMaxK>;
};
template <typename Shape,
typename IteratorA,
typename SmemIteratorA,
typename IteratorB,
typename SmemIteratorB,
typename ElementC,
typename LayoutC,
typename Policy,
int kMaxK>
struct MakeCustomMma<cutlass::gemm::threadblock::MmaPipelined<Shape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
Policy>,
kMaxK> {
using Mma = cutlass::gemm::threadblock::CustomMmaPipelined<Shape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
Policy>;
};
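The only non-trivial logic in MakeCustomMma is the stage-count reduction for the multistage path: when the K extent of the problem is bounded by kMaxK, there is no benefit in software-pipelining more stages than ceil(kMaxK / Shape::kK). An illustrative restatement of that computation, not part of the patch:

// Illustrative sketch (not part of the patch): the stage-count cap applied by
// MakeCustomMma. An unbounded kMaxK keeps the original stage count; otherwise
// stages are capped at ceil(kMaxK / tile_k).
#include <algorithm>
#include <cassert>
#include <limits>
constexpr int ReducedStages(int stages, int k_max, int tile_k) {
  return k_max == std::numeric_limits<int>::max()
             ? stages
             : std::min(stages, (k_max + tile_k - 1) / tile_k);
}
int main() {
  assert(ReducedStages(5, std::numeric_limits<int>::max(), 32) == 5);
  assert(ReducedStages(5, 64, 32) == 2);  // only two K tiles are ever needed
  assert(ReducedStages(5, 65, 32) == 3);
  return 0;
}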
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class CustomMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
/// Number of stages
static int const kStages = Stages;
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
template <typename Element, typename OperandShape, typename OperandLayout>
struct OperandSharedStorage {
AlignedBuffer<Element, OperandShape::kCount> buffer;
using TensorRef = TensorRef<Element, OperandLayout>;
CUTLASS_DEVICE
static OperandLayout Layout() {
return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn});
}
/// Returns a TensorRef to the operand
CUTLASS_HOST_DEVICE
TensorRef ref() { return TensorRef{buffer.data(), Layout()}; }
};
/// Shape of the A matrix operand in shared memory
using ShapeA =
MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
using SharedStorageA = OperandSharedStorage<typename Operator::ElementA,
ShapeA,
typename Operator::LayoutA>;
using SharedStorageB = OperandSharedStorage<typename Operator::ElementB,
ShapeB,
typename Operator::LayoutB>;
using TensorRefA = typename SharedStorageA::TensorRef;
using TensorRefB = typename SharedStorageB::TensorRef;
struct SharedStorage {
/// Buffer for A operand
SharedStorageA operand_A;
/// Buffer for B operand
SharedStorageB operand_B;
};
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
CustomMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorageA& shared_storageA, // NOLINT
SharedStorageB& shared_storageB, // NOLINT
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_A_(shared_storageA.ref(), lane_idx),
warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
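CustomMmaBase mirrors cutlass::gemm::threadblock::MmaBase: the threadblock tile Shape is split evenly into warp-level tiles, and each warp iterates over WarpGemm::kK-sized slices of the K dimension. The sketch below only restates the warp-count arithmetic; the tile sizes are purely illustrative and not taken from the patch.

// Illustrative sketch (not part of the patch): the warp decomposition used by
// CustomMmaBase. The threadblock tile is divided evenly among warp tiles.
#include <cassert>
struct Shape3 { int m, n, k; };
Shape3 WarpCount(const Shape3& cta_tile, const Shape3& warp_tile) {
  return {cta_tile.m / warp_tile.m, cta_tile.n / warp_tile.n,
          cta_tile.k / warp_tile.k};
}
int main() {
  // Example sizes only: a 128x128x32 threadblock tile with 64x64x32 warp
  // tiles gives a 2x2x1 warp grid.
  Shape3 wc = WarpCount({128, 128, 32}, {64, 64, 32});
  assert(wc.m == 2 && wc.n == 2 && wc.k == 1);
  return 0;
}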
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "./custom_mma_base.h"
#include "cutlass/gemm/gemm.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Transformation applied to A operand
typename TransformA_ =
NumericArrayConverter<typename SmemIteratorA_::Element,
typename IteratorA_::Element,
IteratorA_::Fragment::kElements>,
///
/// Transformation applied to B operand
typename TransformB_ =
NumericArrayConverter<typename SmemIteratorB_::Element,
typename IteratorB_::Element,
IteratorB_::Fragment::kElements>,
/// Used for partial specialization
typename Enable = bool>
class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
public:
///< Base class
using Base = CustomMmaBase<Shape_, Policy_, 2>;
using Shape =
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA =
IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB =
IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
  // statically assert kStages for MmaPipelined is two (double-buffered pipeline)
static_assert((Base::kStages == 2),
"MmaPipelined requires kStages set to value 2");
static bool const kSmemContainsEntireMat = false;
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
CustomMmaPipelined(typename Base::SharedStorageA& shared_storageA, // NOLINT
typename Base::SharedStorageB& shared_storageB, // NOLINT
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storageA.ref(), thread_idx),
smem_iterator_B_(shared_storageB.ref(), thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
CustomMmaPipelined(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& st, // NOLINT
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: CustomMmaPipelined(
st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx) {}
  CUTLASS_DEVICE
  bool set_prologue_done(bool value) {
    // NOT IMPLEMENTED FOR PIPELINED
    return false;
  }
  CUTLASS_DEVICE
  bool set_zero_outside_bounds(bool value) {
    // NOT NEEDED FOR PIPELINED
    // shared memory will always be zero-filled
    return false;
  }
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(
typename Base::SharedStorage& shared_storage, // NOLINT
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
int thread_idx,
int problem_size_k) {
prologue<kLoadA, kLoadB>(shared_storage.operand_A,
shared_storage.operand_B,
iterator_A,
iterator_B,
thread_idx,
problem_size_k);
}
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(
typename Base::SharedStorageA& shared_storageA, // NOLINT
typename Base::SharedStorageB& shared_storageB, // NOLINT
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
int thread_idx,
int problem_size_k) {
// NOT IMPLEMENTED FOR PIPELINED
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile //NOLINT
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
FragmentC const& src_accum, ///< source accumulator tile
TransformA transform_A =
TransformA(), ///< transformation applied to A fragment
TransformB transform_B =
TransformB()) { ///< transformation applied to B fragment
//
// Prologue
//
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
tb_frag_A.clear();
tb_frag_B.clear();
    // The last k-block is loaded in the prologue
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transform_A(tb_frag_A));
this->smem_iterator_B_.store(transform_B(tb_frag_B));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
    // Issue loads during the first warp-level matrix multiply-add *AFTER*
    // issuing shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations) {
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
      // Load warp-level tiles from shared memory, wrapping to the k offset
      // if this is the last group.
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
// Write fragments to shared memory
this->smem_iterator_A_.store(transform_A(tb_frag_A));
this->smem_iterator_B_.store(transform_B(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == 1) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
} else {
this->warp_tile_iterator_A_.add_tile_offset(
{0,
-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations,
0});
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
if (warp_mma_k == 0) {
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
warp_mma(accum,
warp_frag_A[warp_mma_k % 2],
warp_frag_B[warp_mma_k % 2],
accum);
}
}
}
};
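// Illustrative usage sketch (comment only, not part of this header): assuming a
// concrete instantiation `Mma = CustomMmaPipelined<...>` and shared storage,
// iterators and thread/warp/lane indices provided by the enclosing kernel, the
// mainloop is typically driven as follows.
//
//   Mma mma(shared_storage.operand_A, shared_storage.operand_B,
//           thread_idx, warp_idx, lane_idx);
//   typename Mma::FragmentC accum;
//   accum.clear();
//   mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);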
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/*! \file
    \brief CUTLASS provides helper template functions to figure out the right
    data structures to instantiate to run a GEMM with various parameters (see
    `cutlass/gemm/threadblock/default_mma.h`). However, due to template
    instantiation priority rules, it will only create an MmaMultistage with
    kStages=3 (otherwise it creates an MmaPipelined, which is not compatible
    with FastF32). kStages=3 uses too much shared memory and we want to use
    kStages=2, so we copied some code from the `default_mma.h` and
    `default_mma_core.h` files and wrapped this template to allow our use case.
    This is really only for the FastF32 case, i.e. using TensorCores with fp32.
*/
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Layout type for C and D matrix operand
typename LayoutC,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Number of stages used in the pipelined mainloop
int Stages,
    /// Operation performed by GEMM
typename Operator,
typename Enable_ = void>
struct FindDefaultMma {
static constexpr bool AccumulatorsInRowMajor = false;
static constexpr SharedMemoryClearOption SharedMemoryClear =
SharedMemoryClearOption::kNone;
using DefaultMma =
cutlass::gemm::threadblock::DefaultMma<ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementAccumulator,
LayoutC,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
Stages,
Operator,
AccumulatorsInRowMajor,
SharedMemoryClear>;
};
/// Specialization for sm80 / FastF32 / multistage with kStages=2
template <typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
int kStages,
typename Operator>
struct FindDefaultMma<
ElementA_,
LayoutA_,
kAlignmentA,
ElementB_,
LayoutB_,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> {
using LayoutC = layout::RowMajor;
using OperatorClass = arch::OpClassTensorOp;
using ArchTag = arch::Sm80;
using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma<ElementA_,
LayoutA_,
kAlignmentA,
ElementB_,
LayoutB_,
kAlignmentB,
ElementAccumulator,
LayoutC,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
3,
Operator>;
struct DefaultMma : DefaultMma_ {
using MmaCore_ = typename DefaultMma_::MmaCore;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<
typename MmaCore_::Shape,
typename DefaultMma_::IteratorA,
typename MmaCore_::SmemIteratorA,
MmaCore_::kCacheOpA,
typename DefaultMma_::IteratorB,
typename MmaCore_::SmemIteratorB,
MmaCore_::kCacheOpB,
ElementAccumulator,
LayoutC,
typename MmaCore_::MmaPolicy,
kStages>;
};
};
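// Illustrative usage sketch (comment only): given concrete element/layout/shape
// types chosen by the caller, the wrapper is used the same way as
// cutlass::gemm::threadblock::DefaultMma. The template arguments below are
// placeholders for types defined elsewhere in the attention kernel.
//
//   using Mma = typename FindDefaultMma<ElementA, LayoutA, kAlignA,
//                                       ElementB, LayoutB, kAlignB,
//                                       ElementAccum, layout::RowMajor,
//                                       arch::OpClassTensorOp, arch::Sm80,
//                                       ThreadblockShape, WarpShape,
//                                       InstructionShape, /*Stages=*/2,
//                                       Operator>::DefaultMma::ThreadblockMma;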
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include "cutlass/functional.h"
#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
#include "cutlass/matrix_shape.h"
/*
  TensorCores have different accumulator layouts.
  This file provides iterator classes that map the i-th element of an
  accumulator fragment to the corresponding matrix row/column.
*/
template <typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSm80 {
static_assert(cutlass::platform::is_same<typename T::Layout,
cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
using Policy = typename T::Policy;
using InstructionShape = typename T::InstructionShape;
using OpDelta = typename T::OpDelta;
using Shape = typename T::Shape;
static int const kElementsPerAccess = InstructionShape::kN / 4;
static int const kRowsPerTile = 8;
static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile;
static cutlass::MatrixCoord CUTLASS_DEVICE
get_lane_offset(int8_t lane_id,
int8_t warp_id,
typename T::TensorCoord const& tile_offset) {
int quad = (lane_id >> 2);
int lane_in_quad = (lane_id & 3);
return cutlass::MatrixCoord(quad + tile_offset.row() * Shape::kRow,
lane_in_quad * kElementsPerAccess +
tile_offset.column() * Shape::kColumn);
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(
cutlass::MatrixCoord& lane_offset, // NOLINT
FA beginRow,
FB op,
FC endRow) {
// See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < kAccumulatorRows; ++row) {
int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow +
row * kRowsPerTile + lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
int mma_accum_start = kAccumulatorRows * kElementsPerAccess *
(mma_n * Policy::MmaIterations::kRow + mma_m);
CUTLASS_PRAGMA_UNROLL
for (int col = 0; col < kElementsPerAccess; ++col) {
int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn +
col + lane_offset.column();
int idx = mma_accum_start + row * kElementsPerAccess + col;
op(accum_m, accum_n, idx);
}
}
endRow(accum_m);
}
}
}
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id,
DT& myValue, // NOLINT
F fn) { // NOLINT
// In each warp, 4 threads will work on the same row
// - the ones with the same `quad`
auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1);
myValue = fn(myValue, otherV);
otherV = __shfl_xor_sync(0xffffffff, myValue, 2);
myValue = fn(myValue, otherV);
int lane_in_quad = (lane_id & 3);
return lane_in_quad == 0;
}
};
template <typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSm70 {
static_assert(cutlass::platform::is_same<typename T::Layout,
cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
using Policy = typename T::Policy;
using InstructionShape = typename T::InstructionShape;
using OpDelta = typename T::OpDelta;
using Shape = typename T::Shape;
using Element = accum_t;
static int const kElementsPerPartial = 4;
using EleShapePerPatial = typename cutlass::platform::conditional<
cutlass::platform::is_same<Element, float>::value,
cutlass::MatrixShape<2, 2>,
cutlass::MatrixShape<1, 4>>::type;
static int const kElementsPerMma = 8;
static int const kAccumulatorPatials = 2;
using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>;
static cutlass::MatrixCoord CUTLASS_DEVICE
get_lane_offset(int8_t lane_id,
int8_t warp_id,
typename T::TensorCoord const& tile_offset) {
int quad = (lane_id >> 2);
int lane_in_quad = (lane_id & 3);
int accum_m, accum_n;
if (cutlass::platform::is_same<Element, float>::value) {
// (quad[2],quad[0])+lane_in_quad[0]
accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
// (quad[1])+lane_in_quad[1]
accum_n =
((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials +
(lane_in_quad & 2);
} else {
accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 +
lane_in_quad; // (quad[2],quad[0])
accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials;
}
return cutlass::MatrixCoord(
accum_m + tile_offset.row() * Shape::kRow,
accum_n + tile_offset.column() * Shape::kColumn);
}
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id,
DT& myValue, // NOLINT
F fn) { // NOLINT
static_assert(cutlass::platform::is_same<Element, float>::value,
"update to support non-float accum");
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
// T0 & T2 share same line within a quad
auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1);
myValue = fn(myValue, otherV);
// quad 0 and quad 2 are on the same lines
otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3);
myValue = fn(myValue, otherV);
return (lane_id & ((1 << 1) | (1 << 3))) == 0;
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(
cutlass::MatrixCoord& lane_offset, // NOLINT
FA beginRow,
FB op,
FC endRow) {
CUTLASS_PRAGMA_UNROLL
for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) {
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < EleShapePerPatial::kRow; ++m) {
int accum_m = tile_m * Policy::InterleavedTile::kRow +
mma_m * QuadShapePerPatialMma::kRow + m * 2 +
lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn;
++tile_n) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn;
++mma_n) {
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < kAccumulatorPatials; ++p) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < EleShapePerPatial::kColumn; ++n) {
int mma_accum_start =
(((tile_n * Policy::TileIterations::kRow + tile_m) *
Policy::MmaIterations::kColumn +
mma_n) *
Policy::MmaIterations::kRow +
mma_m) *
kElementsPerMma;
int accum_n = tile_n * Policy::InterleavedTile::kColumn +
mma_n * QuadShapePerPatialMma::kColumn +
p * Policy::InterleavedTile::kColumn / 2 + n +
lane_offset.column();
int idx = mma_accum_start + p * kElementsPerPartial +
m * EleShapePerPatial::kColumn + n;
op(accum_m, accum_n, idx);
}
}
}
}
endRow(accum_m);
}
}
}
}
};
template <typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSimt {
using Policy = typename T::Policy;
using Iterations = typename T::Iterations;
using Element = typename T::Element;
using Delta = typename T::Delta;
using Shape = typename T::Shape;
static_assert(cutlass::platform::is_same<typename T::Layout,
cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id,
DT& myValue, // NOLINT
F fn) { // NOLINT
CUTLASS_PRAGMA_UNROLL
for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) {
auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit);
myValue = fn(myValue, otherV);
}
return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0;
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(
cutlass::MatrixCoord& lane_offset, // NOLINT
FA beginRow,
FB op,
FC endRow) {
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
int accum_m = mma_m * Delta::kRow + m + lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
int accum_n =
mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN +
lane_offset.column();
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
int idx =
n + Policy::LaneMmaShape::kN *
(mma_n + Iterations::kColumn *
(m + mma_m * Policy::LaneMmaShape::kM));
op(accum_m, accum_n + n, idx);
}
}
endRow(accum_m);
}
}
}
static cutlass::MatrixCoord CUTLASS_DEVICE
get_lane_offset(int8_t lane_id,
int8_t warp_id,
typename T::TensorCoord const& tile_offset) {
static_assert(cutlass::platform::is_same<
typename Policy::LaneLayout,
cutlass::layout::RowMajorInterleaved<1>>::value,
"");
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
cutlass::MatrixCoord lane_offset =
lane_layout.inverse(lane_id) *
cutlass::MatrixCoord(Policy::LaneMmaShape::kM,
Policy::LaneMmaShape::kN);
return lane_offset +
tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn);
}
};
template <typename T, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator;
// Simt
template <typename S, typename P, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::MmaSimtTileIterator<S,
cutlass::gemm::Operand::kC,
accum_t,
cutlass::layout::RowMajor,
P,
1,
1>,
accum_t,
kWarpSize> {
using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator<
S,
cutlass::gemm::Operand::kC,
accum_t,
cutlass::layout::RowMajor,
P,
1,
1>;
using Iterator = AccumLambdaIteratorSimt<WarpIterator, accum_t, kWarpSize>;
};
// TensorOp - Volta
template <typename S1, typename S2, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
cutlass::MatrixShape<1, 1>>,
accum_t,
kWarpSize> {
using WarpIterator =
typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
cutlass::MatrixShape<1, 1>>;
using Iterator = AccumLambdaIteratorSm70<WarpIterator, accum_t, kWarpSize>;
};
// TensorOp - Sm75+
template <typename S1,
typename S2,
typename S3,
typename accum_t,
int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
S3>,
accum_t,
kWarpSize> {
using WarpIterator =
typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
S3>;
using Iterator = AccumLambdaIteratorSm80<WarpIterator, accum_t, kWarpSize>;
};
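// Illustrative usage sketch (comment only): given a warp-level accumulator tile
// iterator type `WarpIteratorC` and an accumulator fragment `accum`, a per-row
// traversal of the accumulator could be written roughly as below. All names
// here are placeholders supplied by the calling kernel.
//
//   using Iter = typename DefaultMmaAccumLambdaIterator<WarpIteratorC, accum_t,
//                                                       kWarpSize>::Iterator;
//   auto lane_offset = Iter::get_lane_offset(lane_id, warp_id, tile_offset);
//   Iter::iterateRows(
//       lane_offset,
//       [&](int accum_m) { /* begin row accum_m */ },
//       [&](int accum_m, int accum_n, int idx) { /* visit accum[idx] */ },
//       [&](int accum_m) { /* end row, e.g. reduce with Iter::reduceSameRow */ });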
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include "cutlass/arch/mma.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/phi/core/enforce.h"
////////////////////////////////////////////////////////////////////////////////
// Some helper functions
////////////////////////////////////////////////////////////////////////////////
#define DISPATCH_TYPES(tensor, func) \
{ \
if (query.scalar_type() == at::ScalarType::Float) { \
using scalar_t = float; \
func(); \
} else if (query.scalar_type() == at::ScalarType::Half) { \
using scalar_t = cutlass::half_t; \
func(); \
} else if (query.scalar_type() == at::ScalarType::BFloat16) { \
using scalar_t = cutlass::bfloat16_t; \
func(); \
} else { \
PADDLE_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \
} \
}
#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
{ \
if (BOOL_V) { \
constexpr bool BOOL_NAME = true; \
F(); \
} else { \
constexpr bool BOOL_NAME = false; \
F(); \
} \
}
#define DISPATCH_ARCHTAG(CC, func) \
{ \
if (CC >= 80) { \
using ArchTag = cutlass::arch::Sm80; \
func(); \
} else if (CC >= 75) { \
using ArchTag = cutlass::arch::Sm75; \
func(); \
} else if (CC >= 70) { \
using ArchTag = cutlass::arch::Sm70; \
func(); \
} else if (CC >= 50) { \
using ArchTag = cutlass::arch::Sm50; \
func(); \
} else { \
PADDLE_CHECK( \
false, \
"Your device is too old. We require compute capability >= 50"); \
} \
}
#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \
PADDLE_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
PADDLE_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
  PADDLE_CHECK(TENSOR.is_contiguous(), #TENSOR " must be contiguous");
#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \
PADDLE_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
PADDLE_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
PADDLE_CHECK(TENSOR.stride(-1) == 1, \
#TENSOR ": last dimension must be contiguous");
#if defined(__CUDACC_RTC__)
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
return false; \
}
#define PADDLE_CHECK(COND, ERR) \
if (!(COND)) { \
return false; \
}
#else
#include <iostream>
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
std::cerr << #PTR " is not correctly aligned\n"; \
return false; \
}
#define PADDLE_CHECK(COND, ERR) \
if (!(COND)) { \
std::cerr << #COND " failed\n"; \
return false; \
}
#endif
#define ASSIGN_CHECK_OVERFLOW(A, B) \
{ \
A = B; \
PADDLE_CHECK(B < std::numeric_limits<decltype(A)>::max(), \
#B " overflows"); \
}
namespace gemm_kernel_utils {
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
return (n + m - 1) / m;
}
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) {
return ((n + m - 1) / m) * m;
}
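// Illustrative values (comment only): ceil_div(10, 4) == 3 and
// align_up(10, 4) == 12.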
inline int32_t getMaximumSharedMemoryPerBlockKb(int cc) {
// from:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability
switch (cc) {
case 50:
case 52:
case 53:
case 60:
case 61:
case 62:
return 64;
case 70:
case 72:
return 96;
case 75:
return 64;
case 80:
return 163;
case 86:
return 99;
case 87:
return 163;
case 89:
return 99;
case 90:
return 227;
default:
return 0;
}
}
////////////////////////////////////////////////////////////////////////////////
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
// TODO(xformers): Maybe we could rely on Cutlass's DefaultGemm templates
////////////////////////////////////////////////////////////////////////////////
// Fall back to SIMT (FMA on CUDA cores) if none of the specializations below applies
template <typename ArchTag, typename scalar_t_, typename Enable = void>
struct DefaultGemmType {
static constexpr int ThreadK = 8;
static constexpr int WarpK = 8;
static constexpr int kMinimumAlignment = 1;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using OpClass = cutlass::arch::OpClassSimt;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specialization for tensorcores with f32
template <typename ArchTag>
struct DefaultGemmType<ArchTag,
float,
typename cutlass::platform::enable_if<
ArchTag::kMinComputeCapability >= 80>::type> {
static constexpr int ThreadK = 32;
static constexpr int WarpK = 32;
static constexpr int kMinimumAlignment = 4;
using OpClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = cutlass::arch::OpMultiplyAddFastF32;
};
// Specialization for tensorcores with f16/bf16 - Sm75+
template <typename ArchTag, typename scalar_t>
struct DefaultGemmType<ArchTag,
scalar_t,
typename cutlass::platform::enable_if<
ArchTag::kMinComputeCapability >= 75 &&
cutlass::sizeof_bits<scalar_t>::value == 16>::type> {
static constexpr int ThreadK = 32;
static constexpr int WarpK = 32;
static constexpr int kMinimumAlignment = 4;
using OpClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specialization for tensorcores with f16 - Volta
template <>
struct DefaultGemmType<cutlass::arch::Sm70, cutlass::half_t, void> {
static constexpr int ThreadK = 32;
static constexpr int WarpK = 32;
static constexpr int kMinimumAlignment = 2;
using OpClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Enables writing
// `auto x = kCondition ? fa(arg) : fb(arg)`
// even when `fa` and `fb` return different types
template <bool kVal, typename TA, typename TB>
struct call_conditional;
template <typename TA, typename TB>
struct call_conditional<true, TA, TB> {
template <typename Arg>
static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
-> decltype(ta(arg)) {
return ta(arg);
}
};
template <typename TA, typename TB>
struct call_conditional<false, TA, TB> {
template <typename Arg>
static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
-> decltype(tb(arg)) {
return tb(arg);
}
};
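// Illustrative usage sketch (comment only): `fa`, `fb` and `arg` below are
// placeholder callables/arguments from the calling code; the branch is
// resolved at compile time based on `kCondition`.
//
//   auto x = call_conditional<kCondition, decltype(fa), decltype(fb)>::apply(
//       fa, fb, arg);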
////////////////////////////////////////////////////////////////////////////////
// Mark a variable as warp-uniform - enables some compiler optimizations
// The cheapest way to do it is just to broadcast it from lane 0
////////////////////////////////////////////////////////////////////////////////
CUTLASS_DEVICE int32_t warp_uniform(int32_t value) {
return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
}
template <typename T>
CUTLASS_DEVICE T* warp_uniform(T* ptr) {
struct {
union {
T* ptr;
uint32_t asInt[2];
};
} p;
p.ptr = ptr;
p.asInt[0] = warp_uniform(p.asInt[0]);
p.asInt[1] = warp_uniform(p.asInt[1]);
return p.ptr;
}
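// Illustrative usage (comment only): a value or pointer that is known to be
// equal across all lanes of a warp can be re-broadcast to help the compiler,
// e.g. `ptr = warp_uniform(ptr);`, where `ptr` stands for any warp-uniform
// pointer in the calling kernel.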
} // namespace gemm_kernel_utils
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Generates combinations of kernels: implementations and registry.
# Kernels are ordered (see `sort_index`), and when dispatching,
# we select the first kernel in the list that supports the inputs.
import argparse
import collections
import itertools
import os
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple, TypeVar
MAX_ARCH = 90
ENABLE_MACRO = "PADDLE_WITH_MEMORY_EFFICIENT_ATTENTION"
def convert_to_arch_list(arch):
arch = arch.lower().strip()
if arch == "all":
return [50, 70, 75, 80]
arch = [int(s.strip()) for s in arch.split(' ') if s.strip()]
arch = list(set(arch))
arch.sort()
for each_arch in arch:
assert each_arch < MAX_ARCH
return arch
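# Illustrative examples (comments only, not executed):
#   convert_to_arch_list("All")   -> [50, 70, 75, 80]
#   convert_to_arch_list("75 80") -> [75, 80]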
def parse_args():
parser = argparse.ArgumentParser(
description="The argument for generating the memory efficient kernels."
)
parser.add_argument(
"--dst_path",
type=str,
default=str(Path(__file__).parent),
help="The destination path to save the generated files.",
)
parser.add_argument(
"--cuda_arch",
type=convert_to_arch_list,
default=convert_to_arch_list("All"),
help="The CUDA architecture to be generated.",
)
return parser.parse_args()
args = parse_args()
DTYPES = {
"f32": "float",
"f16": "cutlass::half_t",
"bf16": "cutlass::bfloat16_t",
}
SM = args.cuda_arch
KERNEL_IMPL_TEMPLATE = """__global__ void __launch_bounds__(
{CPP_CLASS}::kNumThreads,
{CPP_CLASS}::kMinBlocksPerSm)
{NAME}(typename {CPP_CLASS}::Params p) {{
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ >= {SM}0
#if __CUDA_ARCH__ < {SM_MAX}0
if (!p.advance_to_block()) {{
return;
}}
{CPP_CLASS}::attention_kernel(p);
return;
#endif
#endif
printf(
"FATAL: kernel `{NAME}` is for sm{SM}-sm{SM_MAX}, but was built for sm%d\\n",
int(__CUDA_ARCH__ + 0) / 10);
#endif
}}
"""
@dataclass(order=True)
class FwdKernel:
sort_index: Tuple[int, ...] = field(init=False, repr=False)
aligned: bool
dtype: str
sm_range: Tuple[int, int]
q: int
k: int
single_value_iter: bool
supports_dropout: bool = True
supports_bias: bool = True
dispatch_cond: Optional[str] = None
def __post_init__(self) -> None:
# Set kernel selection priority
# The lowest value that matches inputs
# will be selected
self.sort_index = (
# First select aligned kernel
0 if self.aligned else 1,
# Then keep output in RF
0 if self.single_value_iter else 1,
self.k,
# Prefer kernels without dropout/bias if available
1 if self.supports_dropout else 0,
1 if self.supports_bias else 0,
)
@property
def _aligned_suffix(self) -> str:
return "aligned" if self.aligned else "notaligned"
@property
def name(self) -> str:
acc = "rf" if self.single_value_iter else "gmem"
return f"fmha_cutlassF_{self.dtype}_{self._aligned_suffix}_{self.q}x{self.k}_{acc}_sm{self.sm_range[0]}"
@property
def cpp_class(self) -> str:
template_args = ", ".join(
[
DTYPES[self.dtype],
f"cutlass::arch::Sm{self.sm_range[0]}",
"true" if self.aligned else "false",
str(self.q),
str(self.k),
"true" if self.single_value_iter else "false",
"true" if self.supports_dropout else "false",
"true" if self.supports_bias else "false",
]
)
return f"AttentionKernel<{template_args}>"
@property
def impl_group(self) -> str:
# Maps to file which will contain the implementation
return f"{self.dtype}_{self._aligned_suffix}"
@property
def cpp_impl(self) -> str:
return KERNEL_IMPL_TEMPLATE.format(
CPP_CLASS=self.cpp_class,
NAME=self.name,
SM=self.sm_range[0],
SM_MAX=self.sm_range[1],
)
@classmethod
def get_all(cls) -> List["FwdKernel"]:
kernels: List[FwdKernel] = []
for aligned, dtype, (sm, sm_max) in itertools.product(
[True, False], DTYPES.keys(), zip(SM, SM[1:] + [MAX_ARCH])
):
# Remove some kernels we don't use
if dtype == "bf16" and sm < 80:
continue
if not aligned and sm >= 80:
continue
for q, k, single_value_iter in [
(32, 128, True),
(32, 128, False),
(64, 64, True),
]:
kernels.append(
cls(
aligned=aligned,
dtype=dtype,
sm_range=(sm, sm_max),
q=q,
k=k,
single_value_iter=single_value_iter,
)
)
return kernels
@dataclass(order=True)
class BwdKernel:
sort_index: Tuple[int, ...] = field(init=False, repr=False)
sm_range: Tuple[int, int]
dtype: str
aligned: bool
apply_dropout: bool
preload_mmas: bool
block_i: int
block_j: int
max_k: int
dispatch_cond: Optional[str] = None
def __post_init__(self) -> None:
# Set kernel selection priority
# The lowest value that matches inputs
# will be selected
self.sort_index = (
# First select aligned kernel
0 if self.aligned else 1,
# Take a kernel without dropout if possible
1 if self.apply_dropout else 0,
# Then take the smallest maxK
self.max_k,
# .. and the highest block_i
-self.block_i,
)
@property
def _aligned_suffix(self) -> str:
return "aligned" if self.aligned else "notaligned"
@property
def name(self) -> str:
dropout_suffix = "_dropout" if self.apply_dropout else ""
return (
f"fmha_cutlassB_{self.dtype}_{self._aligned_suffix}"
f"_{self.block_i}x{self.block_j}_k{self.max_k}{dropout_suffix}_sm{self.sm_range[0]}"
)
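    # Illustrative example (comment only): with dtype="f16", aligned=True,
    # block_i=64, block_j=64, max_k=64, apply_dropout=False, sm_range=(80, 90),
    # `name` evaluates to "fmha_cutlassB_f16_aligned_64x64_k64_sm80".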
@property
def cpp_class(self) -> str:
template_args = ", ".join(
[
f"cutlass::arch::Sm{self.sm_range[0]}",
DTYPES[self.dtype],
"true" if self.aligned else "false",
"true" if self.apply_dropout else "false",
"true" if self.preload_mmas else "false",
str(self.block_i),
str(self.block_j),
str(self.max_k),
]
)
return f"AttentionBackwardKernel<{template_args}>"
@property
def impl_group(self) -> str:
# Maps to file which will contain the implementation
dropout_suffix = "_dropout" if self.apply_dropout else ""
return (
f"{self.dtype}_{self._aligned_suffix}_k{self.max_k}{dropout_suffix}"
)
@property
def cpp_impl(self) -> str:
return KERNEL_IMPL_TEMPLATE.format(
CPP_CLASS=self.cpp_class,
NAME=self.name,
SM=self.sm_range[0],
SM_MAX=self.sm_range[1],
)
@classmethod
def get_all(cls) -> List["BwdKernel"]:
kernels: List[BwdKernel] = []
for (
aligned,
dtype,
(sm, sm_max),
apply_dropout,
max_k,
) in itertools.product(
[True, False],
DTYPES.keys(),
zip(SM, SM[1:] + [MAX_ARCH]),
[True, False],
[32, 64, 128, 2**16],
):
if dtype == "bf16" and sm < 80:
continue
if not aligned and sm >= 80:
continue
is_half = dtype in ["bf16", "f16"]
bi_values = [64]
            # Some architectures have more shared memory and can use block_i=128.
            # We still need to fall back to 64 for GPUs with less shared memory
            # (Sm75, Sm86, ...).
if sm >= 80 or (sm >= 70 and is_half):
if max_k > 64:
bi_values.append(128)
for bi in bi_values:
output_in_rf = is_half and max_k <= bi
preload_mmas = is_half and sm >= 80 and output_in_rf
bj = 128 if (preload_mmas and max_k > 64) else 64
kernels.append(
cls(
aligned=aligned,
dtype=dtype,
sm_range=(sm, sm_max),
apply_dropout=apply_dropout,
preload_mmas=preload_mmas,
block_i=bi,
block_j=bj,
max_k=max_k,
)
)
        # Add some specialized kernels for stable diffusion BW (head dim K=80,
        # covered by max_k=96). This is the only kernel that can keep the
        # outputs in the register file (RF) on Sm86/Sm89, so it is much faster
        # than the 64x64 one.
for dtype in ["f16", "bf16"]:
if max(args.cuda_arch) < 80:
continue
kernels.append(
cls(
aligned=True,
dtype=dtype,
sm_range=(80, MAX_ARCH),
apply_dropout=False,
preload_mmas=True,
block_i=128,
block_j=64,
max_k=96,
# Sm80 has a faster kernel for this case
dispatch_cond="cc == 86 || cc == 89",
)
)
return kernels
T = TypeVar("T", FwdKernel, BwdKernel)
def write_decl_impl(
kernels: List[T], family_name: str, impl_file: str, enable_def: str
) -> None:
cpp_file_header = """// This file is auto-generated. See "generate_kernels.py"
"""
kernels.sort()
implfile_to_kernels: Dict[str, List[T]] = collections.defaultdict(list)
cat_to_kernels: Dict[
Tuple[str, int, int], List[T]
] = collections.defaultdict(list)
dispatch_all = ""
declarations = cpp_file_header + "#pragma once\n"
declarations += f"#ifdef {enable_def}\n"
declarations += f"""#include "{impl_file}"\n"""
declarations += "namespace phi {\n"
# Declaration of kernel functions
for k in kernels:
implfile_to_kernels[k.impl_group].append(k)
cat_to_kernels[(k.dtype, k.sm_range[0], k.sm_range[1])].append(k)
for (cat_dt, cat_sm, cat_sm_max), kernels in cat_to_kernels.items():
declarations += f"// ======== {cat_dt} / sm{cat_sm} ========\n"
declarations += "\n".join(
k.cpp_impl.split("{")[0].rstrip() + ";" for k in kernels
)
dispatch_category_fn = f"dispatch_{family_name}_{cat_dt}_sm{cat_sm}"
declarations += f"\n\ntemplate <typename T> void {dispatch_category_fn}(T cb, int cc) {{\n"
for k in kernels:
_call = f"cb({k.cpp_class}(), {k.name});\n"
if k.dispatch_cond is not None:
_call = f"if ({k.dispatch_cond}) {_call}"
declarations += f" {_call}"
declarations += "}\n\n"
dispatch_all += f"""
if (std::is_same<DT, {DTYPES[cat_dt]}>::value && {cat_sm} <= cc && cc < {cat_sm_max}) {{
{dispatch_category_fn}(cb, cc);
}}"""
declarations += f"""
template <typename PaddleT, typename T>
void dispatch_{family_name}(const ::phi::GPUContext &ctx, T cb) {{
auto cc = ctx.GetComputeCapability();
using DT = typename ::phi::CutlassTrait<PaddleT>::Type;
{dispatch_all}
}}
"""
declarations += "} // namespace phi\n"
declarations += f"#endif // {enable_def}\n"
autogen_dir = Path(args.dst_path) / "autogen"
os.makedirs(autogen_dir, exist_ok=True)
declaration_path = autogen_dir / f"{family_name}.h"
declaration_path.write_text(declarations)
for f, f_kernels in implfile_to_kernels.items():
impl_cu = cpp_file_header
impl_cu += f"#ifdef {enable_def}\n"
impl_cu += f"""#include "{impl_file}"\n"""
impl_cu += "namespace phi {\n"
for k in f_kernels:
impl_cu += k.cpp_impl
impl_cu += "} // namespace phi\n"
impl_cu += f"#endif // {enable_def}\n"
impl_path = autogen_dir / "impl"
os.makedirs(impl_path, exist_ok=True)
(impl_path / f"{family_name}_{f}.cu").write_text(impl_cu)
def write_main_header(forward_impl, backward_impl):
main_header_content = '''
#pragma once
#ifdef %s
#include "%s"
#include "%s"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace phi {
template <typename T>
struct CutlassTrait {
using Type = T;
};
template <>
struct CutlassTrait<dtype::float16> {
using Type = cutlass::half_t;
};
template <>
struct CutlassTrait<dtype::bfloat16> {
using Type = cutlass::bfloat16_t;
};
template <typename T>
struct ToPhiDTypeTrait {
private:
using NonConstT = typename std::remove_const<T>::type;
static constexpr bool kIsFP16 = std::is_same<NonConstT, cutlass::half_t>::value;
static constexpr bool kIsBF16 = std::is_same<NonConstT, cutlass::bfloat16_t>::value;
public:
using Type = typename std::conditional<kIsFP16, dtype::float16,
typename std::conditional<kIsBF16, dtype::bfloat16, NonConstT>::type>::type;
};
template <typename T>
T *SafeGetTensorPtr(const DenseTensor &t) {
using PDT = typename ToPhiDTypeTrait<T>::Type;
return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(t.template data<PDT>()));
}
template <typename T>
T *SafeGetTensorPtr(const DenseTensor *t) {
return t ? SafeGetTensorPtr<T>(*t) : nullptr;
}
template <typename T>
T *SafeGetTensorPtr(const paddle::optional<DenseTensor> &t) {
return t ? SafeGetTensorPtr<T>(t.get()) : nullptr;
}
template <typename T, typename Context>
T *SafeAllocTensor(const Context &ctx, DenseTensor *t) {
using PDT = typename ToPhiDTypeTrait<T>::Type;
void *ptr = ctx.template Alloc<PDT>(t);
return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(ptr));
}
inline int64_t DimStride(const phi::DDim &dims, int n) {
int rank = dims.size();
if (n < 0) {
n += rank;
}
int64_t stride = 1;
for (int i = n+1; i < rank; ++i) {
stride *= dims[i];
}
return stride;
}
} // namespace phi
#include "./cutlass_forward.h"
#include "./cutlass_backward.h"
#endif
''' % (
ENABLE_MACRO,
forward_impl,
backward_impl,
)
path = Path(args.dst_path) / "autogen"
os.makedirs(path, exist_ok=True)
path = Path(path) / "memory_efficient_attention.h"
path.write_text(main_header_content)
if os.path.exists(Path(args.dst_path) / "autogen"):
shutil.rmtree(Path(args.dst_path) / "autogen")
forward_impl = "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/kernel_forward.h"
backward_impl = "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/kernel_backward.h"
write_main_header(forward_impl, backward_impl)
write_decl_impl(
FwdKernel.get_all(),
"cutlass_forward",
impl_file=forward_impl,
enable_def=ENABLE_MACRO,
)
write_decl_impl(
BwdKernel.get_all(),
"cutlass_backward",
impl_file=backward_impl,
enable_def=ENABLE_MACRO,
)
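# Illustrative invocation (comment only; the path below is a hypothetical
# example). The script is normally driven by the build system, but it can also
# be run directly:
#   python generate_kernels.py --cuda_arch "75 80" --dst_path ./my_output_dir
# which regenerates autogen/cutlass_forward.h, autogen/cutlass_backward.h,
# autogen/memory_efficient_attention.h and the per-group .cu files under
# autogen/impl/.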
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include "./predicated_tile_access_iterator_residual_last.h"
#include "./predicated_tile_iterator_residual_last.h"
namespace cutlass {
namespace transform {
namespace threadblock {
template <typename BaseIterator>
struct MakeIteratorResidualLast;
template <typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
typename ThreadMap,
int AccessSize,
bool Gather>
struct MakeIteratorResidualLast<PredicatedTileIterator<Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessSize,
Gather>> {
using Iterator = PredicatedTileIteratorResidualLast<Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessSize,
Gather>;
};
template <typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
typename ThreadMap,
typename AccessType,
bool Gather>
struct MakeIteratorResidualLast<PredicatedTileAccessIterator<Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessType,
Gather>> {
using Iterator = PredicatedTileAccessIteratorResidualLast<Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessType,
Gather>;
};
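// Illustrative usage sketch (comment only): given an existing predicated
// iterator type `Iterator` chosen by the caller, the "residual last" variant
// is obtained as
//
//   using IteratorResidualLast =
//       typename MakeIteratorResidualLast<Iterator>::Iterator;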
} // namespace threadblock
} // namespace transform
} // namespace cutlass
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include "./warp_iterator_from_smem.h"
template <typename WarpIterator>
struct TransposeWarpIterator {
using Iterator = char;
static bool constexpr kSupportsTranspose = false;
};
template <
/// Operand identity
cutlass::gemm::Operand Operand,
/// Data type of A elements
typename Element,
bool kTranspose>
struct TransposeWarpIterator<
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, kTranspose>> {
using Iterator =
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, !kTranspose>;
static bool constexpr kSupportsTranspose = true;
};
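// Illustrative usage sketch (comment only): for a warp iterator type `It`
// supplied by the caller, code can check at compile time whether a transposed
// variant exists, e.g.
//
//   static_assert(TransposeWarpIterator<It>::kSupportsTranspose,
//                 "transpose not supported for this iterator");
//   using TransposedIt = typename TransposeWarpIterator<It>::Iterator;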
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <cutlass/cutlass.h>
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/numeric_types.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
template <typename scalar_t, // scalar type
typename ThreadblockTileShape, // size of tile to load
int Threads, // number of participating threads
int ElementsPerAccess> // thread access width in elements
class TileSmemLoader {
public:
using SmemTile =
cutlass::AlignedBuffer<scalar_t, ThreadblockTileShape::kCount>;
using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
cutlass::layout::PitchLinearShape<
ThreadblockTileShape::kColumn, // contiguous
ThreadblockTileShape::kRow>, // strided
Threads, // Threads
ElementsPerAccess>; // ElementsPerAccess
using GmemTileIterator =
cutlass::transform::threadblock::PredicatedTileIterator<
ThreadblockTileShape, // Shape
scalar_t, // Element
cutlass::layout::RowMajor, // Layout
0, // AdvanceRank
ThreadMap>; // ThreadMap
using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator<
ThreadblockTileShape, // Shape
scalar_t, // Element
cutlass::layout::RowMajor, // Layout
0, // AdvanceRank
ThreadMap>; // ThreadMap
using Fragment = typename GmemTileIterator::Fragment;
/// load a tile from global memory into shared memory
CUTLASS_DEVICE
static void load(GmemTileIterator tile_load_iter,
SmemTileIterator tile_store_iter) {
Fragment tb_frag;
tb_frag.clear();
tile_load_iter.load(tb_frag);
tile_store_iter.store(tb_frag);
__syncthreads();
}
};
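// Illustrative usage sketch (comment only): assuming a concrete instantiation
// `Loader = TileSmemLoader<scalar_t, ThreadblockTileShape, kThreads,
// kElementsPerAccess>` and iterators constructed by the calling kernel, a tile
// is staged into shared memory as
//
//   typename Loader::GmemTileIterator gmem_it(params, gmem_ptr, extent,
//                                             thread_id);
//   typename Loader::SmemTileIterator smem_it(smem_ref, thread_id);
//   Loader::load(gmem_it, smem_it);
//
// The constructor arguments above are placeholders; see the CUTLASS iterator
// definitions for the exact signatures.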
......@@ -1117,6 +1117,7 @@ set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_split_program PROPERTIES TIMEOUT 120)
set_tests_properties(test_graph_send_ue_recv_op PROPERTIES TIMEOUT 60)
set_tests_properties(test_graph_send_uv_op PROPERTIES TIMEOUT 60)
if(WITH_DISTRIBUTE
AND WITH_GPU
AND WITH_NCCL)
......
......@@ -20,6 +20,7 @@ from .fused_transformer import fused_bias_dropout_residual_layer_norm
from .fused_ec_moe import fused_ec_moe
from .fused_dropout_add import fused_dropout_add
__all__ = [
'fused_multi_head_attention',
'fused_feedforward',
......