Unverified commit e5ad3859, authored by ZhangDY-6483, committed by GitHub

Memory Efficient Attention (#51867)

* first version, notest

* return final rst, notest

* use infinity() instead of max

* ut structure

* start up of ut

* generate lse

* update

* add depense

* reconstruct cmake

* move file

* add memory efficient attention and fix blasimpl

* update

* update cmake

* add namespace

* update cmake

* use .cu

* update for pad3d

* bug fix

* bug fix

* update

* bug fix

* update enforce

* add test case

* merge the lse pad

* fix kernel_fn of backward

* fix PADDLE_ENFORCE_EQ and phi_api

* fix PADDLE_ENFORCE

* fix PADDLE_ENFORCE

* rerun coverage

* fix memory efficient attention test

* rerun ci

* add cuda version condition

* add cuda version condition

* delete WIP test

* replace PADDLE_ENFORCE

* edit the namespace of datatype in multiple.cc

* rerun

* rerun

---------
Co-authored-by: liuyuang <liuyuang@baidu.com>
Parent 40fea722
@@ -96,7 +96,7 @@ endfunction()
# Function for selecting GPU arch flags for nvcc based on CUDA_ARCH_NAME
# Usage:
# select_nvcc_arch_flags(out_variable)
function(select_nvcc_arch_flags out_variable out_arch_bin)
# List of arch names
set(archs_names
"Kepler"
@@ -244,6 +244,9 @@ function(select_nvcc_arch_flags out_variable)
set(${out_variable}_real_archs
${nvcc_real_archs}
PARENT_SCOPE)
set(${out_arch_bin}
${cuda_arch_bin}
PARENT_SCOPE)
endfunction()
message(STATUS "CUDA detected: " ${CMAKE_CUDA_COMPILER_VERSION})
@@ -273,7 +276,7 @@ add_definitions("-DCUDA_VERSION_MINOR=\"${CUDA_VERSION_MINOR}\"")
add_definitions("-DCUDA_TOOLKIT_ROOT_DIR=\"${CUDA_TOOLKIT_ROOT_DIR}\"")
# setting nvcc arch flags
select_nvcc_arch_flags(NVCC_FLAGS_EXTRA NVCC_ARCH_BIN)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS_EXTRA}")
message(STATUS "NVCC_FLAGS_EXTRA: ${NVCC_FLAGS_EXTRA}")
......
@@ -963,6 +963,17 @@
kernel :
func : maxout_grad
- backward_op : memory_efficient_attention_grad
forward : memory_efficient_attention (Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor causal_diagonal, Tensor seqlen_k, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale, bool is_test) -> Tensor(output), Tensor(logsumexp), Tensor(seed_and_offset)
args : (Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor output, Tensor logsumexp, Tensor seed_and_offset, Tensor output_grad, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale)
output : Tensor(query_grad), Tensor(key_grad), Tensor(value_grad), Tensor(bias_grad)
infer_meta :
func : MemoryEfficientAttentionGradInferMeta
kernel :
func : memory_efficient_attention_grad
data_type : output_grad
optional : bias, cu_seqlens_q, cu_seqlens_k
- backward_op : meshgrid_grad
forward : meshgrid (Tensor[] inputs) -> Tensor[](outputs)
args : (Tensor[] inputs, Tensor[] outputs_grad)
......
@@ -981,6 +981,17 @@
func : maxout
backward : maxout_grad
- op : memory_efficient_attention
args : (Tensor query, Tensor key, Tensor value, Tensor bias, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor causal_diagonal, Tensor seqlen_k, Scalar max_seqlen_q, Scalar max_seqlen_k, bool causal, double dropout_p, float scale, bool is_test)
output : Tensor(output), Tensor(logsumexp), Tensor(seed_and_offset)
infer_meta :
func : MemoryEfficientAttentionInferMeta
kernel :
func : memory_efficient_attention
data_type : query
optional : bias, cu_seqlens_q, cu_seqlens_k, causal_diagonal, seqlen_k
backward : memory_efficient_attention_grad
- op : meshgrid
args : (Tensor[] inputs)
output : Tensor[]{inputs.size()}
......
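The forward entry above declares the fused attention operator. For reference, below is a naive, host-side C++ sketch of what it computes for a single batch and head, with bias, causal masking, dropout, and variable-length inputs omitted; the CUTLASS kernels added later in this commit produce the same result tile by tile without materializing the full score matrix. All sizes and values in the sketch are illustrative assumptions, not values taken from this commit.
// Naive reference: output = softmax(Q * K^T * scale) * V, and
// logsumexp = max(scores) + log(sum(exp(scores - max))) per query row.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int M = 2, N = 3, K = 2, Kv = 2;         // illustrative sizes
  std::vector<float> Q = {1, 0, 0, 1};           // [M, K]
  std::vector<float> Kmat = {1, 0, 0, 1, 1, 1};  // [N, K]
  std::vector<float> V = {1, 2, 3, 4, 5, 6};     // [N, Kv]
  const float scale = 1.0f / std::sqrt(static_cast<float>(K));

  std::vector<float> out(M * Kv, 0.0f), lse(M, 0.0f);
  for (int i = 0; i < M; ++i) {
    std::vector<float> s(N, 0.0f);
    float mx = -INFINITY;
    for (int j = 0; j < N; ++j) {
      for (int k = 0; k < K; ++k) s[j] += Q[i * K + k] * Kmat[j * K + k];
      s[j] *= scale;
      mx = std::max(mx, s[j]);
    }
    float sum = 0.0f;
    for (int j = 0; j < N; ++j) sum += std::exp(s[j] - mx);
    lse[i] = mx + std::log(sum);  // per-row logsumexp of the scores
    for (int j = 0; j < N; ++j) {
      const float p = std::exp(s[j] - mx) / sum;
      for (int k = 0; k < Kv; ++k) out[i * Kv + k] += p * V[j * Kv + k];
    }
  }
  std::printf("out[0][0]=%f lse[0]=%f\n", out[0], lse[0]);
  return 0;
}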
@@ -1052,4 +1052,89 @@ void IndexAddGradInferMeta(const MetaTensor& index,
}
}
void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
const MetaTensor& bias,
const MetaTensor& cu_seqlens_q,
const MetaTensor& cu_seqlens_k,
const MetaTensor& output,
const MetaTensor& logsumexp,
const MetaTensor& seed_and_offset,
const MetaTensor& output_grad,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
MetaTensor* query_grad,
MetaTensor* key_grad,
MetaTensor* value_grad,
MetaTensor* bias_grad) {
PADDLE_ENFORCE_EQ(
output_grad.dims().size(),
4,
phi::errors::InvalidArgument("output_grad should be a 4-D tensor. "
"But received output_grad dimension(%s)",
output_grad.dims().size()));
PADDLE_ENFORCE_EQ(
output.dims().size(),
4,
phi::errors::InvalidArgument("output should be a 4-D tensor. "
"But received output dimension(%s)",
output.dims().size()));
const int64_t query_batch_size = query.dims()[0];
const int64_t query_seq_length = query.dims()[1];
const int64_t query_num_head = query.dims()[2];
const int64_t query_head_size = query.dims()[3];
const int64_t key_batch_size = key.dims()[0];
const int64_t key_seq_length = key.dims()[1];
const int64_t key_num_head = key.dims()[2];
const int64_t key_head_size = key.dims()[3];
const int64_t value_batch_size = value.dims()[0];
const int64_t value_seq_length = value.dims()[1];
const int64_t value_num_head = value.dims()[2];
const int64_t value_head_size = value.dims()[3];
std::vector<int64_t> query_grad_dims(
{query_batch_size, query_seq_length, query_num_head, query_head_size});
std::vector<int64_t> key_grad_dims(
{key_batch_size, key_seq_length, key_num_head, key_head_size});
std::vector<int64_t> value_grad_dims(
{value_batch_size, value_seq_length, value_num_head, value_head_size});
query_grad->set_dims(phi::make_ddim(query_grad_dims));
query_grad->share_lod(query);
query_grad->set_dtype(query.dtype());
query_grad->set_layout(query.layout());
key_grad->set_dims(phi::make_ddim(key_grad_dims));
key_grad->share_lod(key);
key_grad->set_dtype(key.dtype());
key_grad->set_layout(key.layout());
value_grad->set_dims(phi::make_ddim(value_grad_dims));
value_grad->share_lod(value);
value_grad->set_dtype(value.dtype());
value_grad->set_layout(value.layout());
if (bias) {
const int64_t bias_batch_size = bias.dims()[0];
const int64_t bias_seq_length = bias.dims()[1];
const int64_t bias_num_head = bias.dims()[2];
const int64_t bias_head_size = bias.dims()[3];
std::vector<int64_t> bias_grad_dims(
{bias_batch_size, bias_seq_length, bias_num_head, bias_head_size});
bias_grad->set_dims(phi::make_ddim(bias_grad_dims));
bias_grad->share_lod(bias);
bias_grad->set_dtype(bias.dtype());
bias_grad->set_layout(bias.layout());
}
}
} // namespace phi
@@ -418,4 +418,24 @@ void IndexAddGradInferMeta(const MetaTensor& index,
MetaTensor* x_grad,
MetaTensor* add_tensor_grad);
void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
const MetaTensor& bias,
const MetaTensor& cu_seqlens_q,
const MetaTensor& cu_seqlens_k,
const MetaTensor& output,
const MetaTensor& logsumexp,
const MetaTensor& seed_and_offset,
const MetaTensor& output_grad,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
MetaTensor* query_grad,
MetaTensor* key_grad,
MetaTensor* value_grad,
MetaTensor* bias_grad);
} // namespace phi
@@ -3124,6 +3124,94 @@ void MoeInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}
void MemoryEfficientAttentionInferMeta(const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
const MetaTensor& bias,
const MetaTensor& cu_seqlens_q,
const MetaTensor& cu_seqlens_k,
const MetaTensor& causal_diagonal,
const MetaTensor& seqlen_k,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
const bool is_test,
MetaTensor* output,
MetaTensor* logsumexp,
MetaTensor* seed_and_offset) {
PADDLE_ENFORCE_EQ(
query.dims().size(),
4,
phi::errors::InvalidArgument("Query should be a 4-D tensor. "
"But received Query dimension(%s)",
query.dims().size()));
PADDLE_ENFORCE_EQ(
key.dims().size(),
4,
phi::errors::InvalidArgument("Key should be a 4-D tensor. "
"But received Key dimension(%s)",
key.dims().size()));
PADDLE_ENFORCE_EQ(
value.dims().size(),
4,
phi::errors::InvalidArgument("Value should be a 4-D tensor. "
"But received Value dimension(%s)",
value.dims().size()));
const int64_t query_batch_size = query.dims()[0];
const int64_t query_seq_length = query.dims()[1];
const int64_t query_num_head = query.dims()[2];
const int64_t query_head_size = query.dims()[3];
const int64_t key_batch_size = key.dims()[0];
const int64_t key_seq_length = key.dims()[1];
const int64_t key_num_head = key.dims()[2];
const int64_t key_head_size = key.dims()[3];
const int64_t value_batch_size = value.dims()[0];
const int64_t value_seq_length = value.dims()[1];
const int64_t value_num_head = value.dims()[2];
const int64_t value_head_size = value.dims()[3];
PADDLE_ENFORCE_EQ(((query_batch_size == key_batch_size) &&
(key_batch_size == value_batch_size)),
true,
phi::errors::InvalidArgument(
"The batchsize of Query, Key, Value should be equal."));
PADDLE_ENFORCE_EQ(
((query_num_head == key_num_head) && (key_num_head == value_num_head)),
true,
phi::errors::InvalidArgument(
"The head number of Query, Key, Value should be equal."));
PADDLE_ENFORCE_EQ(query_head_size == key_head_size,
true,
phi::errors::InvalidArgument(
"The head size of Query, Key should be equal."));
PADDLE_ENFORCE_EQ(key_seq_length == value_seq_length,
true,
phi::errors::InvalidArgument(
"The seq length of Key, Value should be equal."));
std::vector<int64_t> out_dims(
{query_batch_size, query_seq_length, query_num_head, value_head_size});
std::vector<int64_t> logsumexp_dims({query_num_head, query_batch_size});
std::vector<int64_t> seed_and_offset_dims({2});
output->set_dims(phi::make_ddim(out_dims));
output->share_lod(query);
output->set_dtype(query.dtype());
output->set_layout(query.layout());
logsumexp->set_dims(phi::make_ddim(logsumexp_dims));
logsumexp->set_dtype(phi::DataType::FLOAT32);
seed_and_offset->set_dims(phi::make_ddim(seed_and_offset_dims));
seed_and_offset->set_dtype(phi::DataType::INT64);
}
} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
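As a concrete illustration of the shape rules encoded by MemoryEfficientAttentionInferMeta above: with query/key/value laid out as [batch, seq_len, num_head, head_size], the output takes its last dimension from value, logsumexp is declared as [num_head, batch] (the kernel later resizes it), and seed_and_offset always holds two int64 values. The sketch below is standalone C++ with assumed example sizes, not Paddle code.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Assumed example shapes: query [B, M, H, K], value [B, N, H, Kv].
  const std::vector<int64_t> query = {2, 128, 8, 64};
  const std::vector<int64_t> value = {2, 256, 8, 32};

  // Shapes implied by the InferMeta above.
  const std::vector<int64_t> out = {query[0], query[1], query[2], value[3]};
  const std::vector<int64_t> logsumexp = {query[2], query[0]};
  const std::vector<int64_t> seed_and_offset = {2};

  std::printf("output = [%lld, %lld, %lld, %lld]\n",
              static_cast<long long>(out[0]), static_cast<long long>(out[1]),
              static_cast<long long>(out[2]), static_cast<long long>(out[3]));
  std::printf("logsumexp = [%lld, %lld], seed_and_offset = [%lld]\n",
              static_cast<long long>(logsumexp[0]),
              static_cast<long long>(logsumexp[1]),
              static_cast<long long>(seed_and_offset[0]));
  // prints: output = [2, 128, 8, 32]
  return 0;
}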
@@ -587,4 +587,22 @@ void MoeInferMeta(const MetaTensor& x,
const std::string& act_type,
MetaTensor* out);
void MemoryEfficientAttentionInferMeta(const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
const MetaTensor& bias,
const MetaTensor& cu_seqlens_q,
const MetaTensor& cu_seqlens_k,
const MetaTensor& causal_diagonal,
const MetaTensor& seqlen_k,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
const bool is_test,
MetaTensor* output,
MetaTensor* logsumexp,
MetaTensor* seed_and_offset);
} // namespace phi
@@ -125,8 +125,15 @@ if(WITH_CUTLASS)
COMMAND ${PYTHON_EXECUTABLE} "conv2d_bias_residual.py"
WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass/conv2d")
execute_process(
COMMAND
${PYTHON_EXECUTABLE}
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/generate_kernels.py
--cuda_arch "${NVCC_ARCH_BIN}")
file(GLOB cutlass_cu "fusion/cutlass/conv2d/generated/*.cu"
"fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu"
"fusion/cutlass/memory_efficient_attention/autogen/impl/*.cu")
add_definitions("-DPADDLE_WITH_MEMORY_EFFICIENT_ATTENTION")
list(APPEND kernel_cu ${cutlass_cu})
endif()
......
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
#include <cuda_runtime_api.h> // NOLINT
#include "cuda.h" // NOLINT
#include "paddle/phi/backends/dynload/cublasLt.h" #include "paddle/phi/backends/dynload/cublasLt.h"
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ref:
// https://github.com/facebookresearch/xformers/blob/b6be33aecb5297f3f994568cf29e194a75e47667/xformers/ops/fmha/common.py#L102
#pragma once
#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/pad3d_kernel.h"
namespace phi {
namespace funcs {
using phi::PADDLE_CUDA_NUM_THREADS;
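// Overwrite every element whose index within the last dimension is at least
// out_second_dim with +infinity, emulating a sliced-then-padded view of the
// buffer in place.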
template <typename T>
__global__ void ViewSliceHelper(T* data,
int stride,
int in_last_dim,
int out_second_dim) {
CUDA_KERNEL_LOOP_TYPE(i, stride * in_last_dim, int64_t) {
if (i % in_last_dim >= out_second_dim) {
*(data + i) = std::numeric_limits<T>::infinity();
}
}
}
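// Pad the last dimension of the 3-D lse tensor with +infinity up to a
// multiple of pad_to. When force_pad_inf is set, entries beyond
// out_second_dim are first sliced away (or overwritten with +infinity in
// place when no padding is required) before the padding is applied.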
template <typename T>
phi::DenseTensor get_pad_lse(const phi::GPUContext& dev_ctx,
phi::DenseTensor* lse,
int out_second_dim,
int pad_to,
const std::string& data_format = "NCHW",
bool force_pad_inf = false) {
int pad_amount = (pad_to - (lse->dims()[2] % pad_to)) % pad_to;
PADDLE_ENFORCE_EQ(
lse->dims().size(),
3,
phi::errors::InvalidArgument("The lse should be a 3d tensor"));
PADDLE_ENFORCE_EQ(
(data_format == "NCHW" || data_format == "NHWC"),
true,
phi::errors::InvalidArgument("The data_format should be NCHW or NHWC"));
std::string pad3d_data_format = data_format == "NCHW" ? "NCDHW" : "NDHWC";
if (pad_amount > 0) {
phi::DenseTensor tmp = *lse;
if (force_pad_inf) {
tmp = phi::funcs::Slice<T, phi::GPUContext>(
dev_ctx, *lse, {2}, {0}, {out_second_dim});
pad_amount = (pad_to - (tmp.dims()[2] % pad_to)) % pad_to;
}
tmp.Resize({tmp.dims()[0], tmp.dims()[1], tmp.dims()[2], 1, 1});
phi::DenseTensor out;
out.Resize({1, 1, 1, 1, 1});
phi::Pad3dKernel<T, phi::GPUContext>(dev_ctx,
tmp,
{0, 0, 0, 0, 0, pad_amount},
"constant",
std::numeric_limits<T>::infinity(),
pad3d_data_format,
&out);
out.Resize({out.dims()[0], out.dims()[1], out.dims()[2]});
return out;
} else if (force_pad_inf && out_second_dim != lse->dims()[2]) {
auto in_dim = lse->dims();
auto in_data = lse->template data<T>();
int stride = in_dim[0] * in_dim[1];
int block = PADDLE_CUDA_NUM_THREADS;
int64_t n = lse->numel();
dim3 grid = dim3((n + block - 1) / block);
phi::backends::gpu::LimitGridDim(dev_ctx, &grid);
ViewSliceHelper<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, stride, in_dim[2], out_second_dim);
return *lse;
}
// lse is already aligned and needs no in-place masking: return it unchanged.
return *lse;
}
} // namespace funcs
} // namespace phi
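A minimal standalone sketch of the padding arithmetic used by get_pad_lse above; the lengths are arbitrary example values.
#include <cstdio>

int pad_amount(int last_dim, int pad_to) {
  // Same formula as get_pad_lse: distance to the next multiple of pad_to,
  // and 0 when last_dim is already a multiple.
  return (pad_to - (last_dim % pad_to)) % pad_to;
}

int main() {
  std::printf("%d\n", pad_amount(100, 32));  // 28 -> padded length 128
  std::printf("%d\n", pad_amount(128, 32));  // 0  -> already aligned
  return 0;
}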
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/autogen/memory_efficient_attention.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename T, typename Context>
void MemoryEfficientAttentionForwardKernel(
const Context& ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& cu_seqlens_q,
const paddle::optional<DenseTensor>& cu_seqlens_k,
const paddle::optional<DenseTensor>& causal_diagonal,
const paddle::optional<DenseTensor>& seqlen_k,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
const bool is_test,
DenseTensor* output,
DenseTensor* logsumexp,
DenseTensor* seed_and_offset) {
int compute_capacity = ctx.GetComputeCapability();
const auto max_shmem =
getMaximumSharedMemoryPerBlockKb(compute_capacity) * 1024;
bool kernel_launched = false;
auto max_seqlen_q_num = max_seqlen_q.to<uint64_t>();
auto max_seqlen_k_num = max_seqlen_k.to<uint64_t>();
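// Called once per code-generated kernel configuration by
// dispatch_cutlass_forward below; the first configuration that supports the
// current inputs performs the launch, later invocations return early.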
auto launchKernel = [&](auto k_, auto kernel_fn) {
using KernelType = decltype(k_);
bool is_launched = kernel_launched;
if (is_launched) {
return;
}
using scalar_t = typename KernelType::scalar_t;
bool use_dropout = (dropout_p != 0);
if (!KernelType::kSupportsDropout && use_dropout) {
VLOG(3) << "run in to use dropout" << use_dropout;
return;
}
if (!KernelType::kSupportsBias && bias) {
VLOG(3) << "run in to bias";
return;
}
const auto& v_dims = value.dims();
if (KernelType::kSingleValueIteration &&
KernelType::kKeysPerBlock < v_dims[3]) {
VLOG(3) << "run in to value dim" << v_dims;
return;
}
const auto& k_dims = key.dims();
const auto& q_dims = query.dims();
int64_t max_seqlen_q_tmp, max_seqlen_k_tmp;
if (cu_seqlens_q) {
max_seqlen_q_tmp = max_seqlen_q_num;
max_seqlen_k_tmp = 0; // Will be set inside the kernel
} else {
max_seqlen_q_tmp = q_dims[1];
max_seqlen_k_tmp = k_dims[1];
}
VLOG(3) << "max_seqlen_q_tmp " << max_seqlen_q_tmp;
if ((q_dims[3] % KernelType::kAlignmentQ) ||
(k_dims[3] % KernelType::kAlignmentK) ||
(v_dims[3] % KernelType::kAlignmentV)) {
VLOG(3) << "run in to query dim" << q_dims;
VLOG(3) << "run in to key dim" << k_dims;
return;
}
size_t smem_bytes = sizeof(typename KernelType::SharedStorage);
if (smem_bytes > max_shmem) {
VLOG(3) << "run in to shmem" << smem_bytes << " " << max_shmem;
return;
}
kernel_launched = true;
VLOG(3) << "launching";
output->Resize({q_dims[0], q_dims[1], q_dims[2], v_dims[3]});
constexpr int64_t kAlignLSE = KernelType::kAlignLSE;
phi::Dim<3> logsumexp_dims;
logsumexp_dims[0] =
cu_seqlens_q ? cu_seqlens_q.get().dims()[0] - 1 : q_dims[0];
logsumexp_dims[1] = q_dims[2];
logsumexp_dims[2] =
is_test ? 0 : (max_seqlen_q_tmp + kAlignLSE - 1) / kAlignLSE;
logsumexp_dims[2] *= kAlignLSE;
logsumexp->Resize(logsumexp_dims);
ctx.template Alloc<float>(logsumexp);
VLOG(3) << "logsumexp dims" << logsumexp_dims;
VLOG(3) << "logsumexp" << logsumexp;
VLOG(3) << "kAlignLSE" << kAlignLSE;
typename KernelType::Params p;
p.query_ptr = SafeGetTensorPtr<scalar_t>(query);
p.key_ptr = SafeGetTensorPtr<scalar_t>(key);
p.value_ptr = SafeGetTensorPtr<scalar_t>(value);
p.logsumexp_ptr = is_test ? nullptr : logsumexp->data<float>();
VLOG(3) << "logsumexp_ptr" << p.logsumexp_ptr;
DenseTensor out_accum;
if (KernelType::kNeedsOutputAccumulatorBuffer) {
out_accum.Resize(output->dims());
p.output_accum_ptr =
SafeAllocTensor<typename KernelType::output_accum_t, Context>(
ctx, &out_accum);
VLOG(3) << "output_accum_ptr " << p.output_accum_ptr;
} else {
p.output_accum_ptr = nullptr;
}
p.output_ptr =
SafeAllocTensor<typename KernelType::output_t, Context>(ctx, output);
VLOG(3) << "output_ptr " << p.output_ptr;
if (cu_seqlens_q) {
p.seqstart_q_ptr = SafeGetTensorPtr<int32_t>(cu_seqlens_q);
p.seqstart_k_ptr = SafeGetTensorPtr<int32_t>(cu_seqlens_k);
VLOG(3) << "seqstart_q_ptr " << p.seqstart_q_ptr;
} else {
p.seqstart_q_ptr = nullptr;
p.seqstart_k_ptr = nullptr;
}
p.num_heads = q_dims[2];
p.head_dim = q_dims[3];
p.head_dim_value = v_dims[3];
p.num_queries = max_seqlen_q_tmp;
p.num_keys = max_seqlen_k_tmp;
p.num_batches = cu_seqlens_q ? cu_seqlens_q.get().dims()[0] - 1 : q_dims[0];
p.causal = causal;
if (causal_diagonal) {
p.causal_diagonal_ptr = SafeGetTensorPtr<int32_t>(causal_diagonal);
} else {
p.causal_diagonal_ptr = nullptr;
}
VLOG(3) << "causal_diagonal_ptr " << p.causal_diagonal_ptr;
p.seqlen_k_ptr = nullptr;
if (seqlen_k) {
p.seqlen_k_ptr = SafeGetTensorPtr<int32_t>(seqlen_k);
} else {
p.seqlen_k_ptr = nullptr;
}
VLOG(3) << "seqlen_k_ptr " << p.seqlen_k_ptr;
if (scale < 0) {
p.scale = static_cast<float>(1.0 / std::sqrt(p.head_dim));
} else {
p.scale = scale;
}
VLOG(3) << "scale " << p.scale;
p.q_strideB = DimStride(query.dims(), 0);
p.k_strideB = DimStride(key.dims(), 0);
p.v_strideB = DimStride(value.dims(), 0);
p.q_strideM = DimStride(query.dims(), 1);
p.k_strideM = DimStride(key.dims(), 1);
p.v_strideM = DimStride(value.dims(), 1);
p.q_strideH = DimStride(query.dims(), 2);
p.k_strideH = DimStride(key.dims(), 2);
p.v_strideH = DimStride(value.dims(), 2);
p.o_strideM = DimStride(output->dims(), 1);
if (bias) {
p.attn_bias_ptr = SafeGetTensorPtr<scalar_t>(bias);
p.bias_strideB = q_dims[2] * q_dims[1] * k_dims[1];
p.bias_strideH = q_dims[1] * k_dims[1];
p.bias_strideM = k_dims[1];
} else {
p.attn_bias_ptr = nullptr;
}
VLOG(3) << "attn_bias_ptr " << p.attn_bias_ptr;
VLOG(3) << "bias_strideB " << p.bias_strideB;
VLOG(3) << "bias_strideH " << p.bias_strideH;
VLOG(3) << "bias_strideM " << p.bias_strideM;
phi::Dim<1> seed_dims;
seed_dims[0] = 2;
seed_and_offset->Resize(seed_dims);
ctx.template HostAlloc<int64_t>(seed_and_offset);
int64_t* seed_and_offset_ptr = SafeGetTensorPtr<int64_t>(seed_and_offset);
auto gen = ctx.GetGenerator();
uint64_t inc = query.dims()[0] * query.dims()[2] * 32;
auto seed_offset_pair = gen->IncrementOffset(inc);
auto seed = (seed_offset_pair.first);
auto offset = (seed_offset_pair.second);
seed_and_offset_ptr[0] = (int64_t)seed;
seed_and_offset_ptr[1] = (int64_t)offset;
VLOG(3) << "seed and offset: " << seed << " " << offset << " "
<< seed_and_offset_ptr;
p.use_dropout = use_dropout;
if (use_dropout) {
p.seed = seed;
p.offset = offset;
p.dropout_prob = dropout_p;
} else {
p.dropout_prob = 0.0;
}
if (smem_bytes > 0xc000) {
const void* kernel_fn_void_ptr =
reinterpret_cast<const void*>(reinterpret_cast<uintptr_t>(kernel_fn));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaFuncSetAttribute(kernel_fn_void_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_bytes));
}
KernelType::check_supported(p);
VLOG(3) << "Kernel launched with func : " << typeid(kernel_fn).name()
<< " block dim " << p.getBlocksGrid() << " thread dim "
<< p.getThreadsGrid();
kernel_fn<<<p.getBlocksGrid(),
p.getThreadsGrid(),
smem_bytes,
ctx.stream()>>>(p);
};
dispatch_cutlass_forward<T>(ctx, launchKernel);
PADDLE_ENFORCE_EQ(kernel_launched,
true,
paddle::platform::errors::InvalidArgument(
"the memory efficient attention kernel was not launched: "
"no generated kernel configuration supports the given inputs"));
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(
memory_efficient_attention,
GPU,
ALL_LAYOUT,
phi::fusion::cutlass_internal::MemoryEfficientAttentionForwardKernel,
float,
phi::dtype::bfloat16,
phi::dtype::float16) {}
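Two scalar rules in the kernel above are easy to check in isolation: a negative scale means "fall back to 1/sqrt(head_dim)", and the logsumexp length is rounded up to a multiple of KernelType::kAlignLSE. The standalone sketch below reproduces only that arithmetic; the head size, alignment, and sequence length are assumed example values, not constants taken from the generated kernels.
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t head_dim = 64;
  float scale = -1.0f;  // negative requests the default scaling
  if (scale < 0) {
    scale = static_cast<float>(1.0 / std::sqrt(static_cast<double>(head_dim)));
  }

  const int64_t kAlignLSE = 32;      // example value; the real one comes from KernelType
  const int64_t max_seqlen_q = 100;  // example sequence length
  const int64_t lse_len =
      (max_seqlen_q + kAlignLSE - 1) / kAlignLSE * kAlignLSE;

  std::printf("scale=%f lse_len=%lld\n",
              scale, static_cast<long long>(lse_len));
  // prints: scale=0.125000 lse_len=128
  return 0;
}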
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <float.h>
#include <stdio.h>
#include <cmath>
////////////////////////////////////////////////////////////////////////////////
// Debugging functions
////////////////////////////////////////////////////////////////////////////////
// Nans & inf detection
#define NANCHECK(frag) \
{ \
for (int _i = 0; _i < frag.size(); ++_i) { \
assert(std::isfinite(static_cast<float>(frag[_i]))); \
assert(!std::isnan(static_cast<float>(frag[_i]))); \
} \
}
// Print on the first thread of the first block
#if 1
#define PRINT_WARP_ID 0
#define PRINT_LANE_ID 0
#define PRINT_B0_T0(msg, ...) \
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \
threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
threadIdx.z == 0) { \
printf(msg "\n", ##__VA_ARGS__); \
}
#define PRINT_T0(msg, ...) \
if (threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \
threadIdx.z == 0) { \
printf(msg "\n", ##__VA_ARGS__); \
}
#define PRINT_TX_LX(msg, ...) \
for (int bx = 0; bx < gridDim.x; ++bx) { \
for (int by = 0; by < gridDim.y; ++by) { \
for (int bz = 0; bz < gridDim.z; ++bz) { \
for (int tx = 0; tx < blockDim.x; ++tx) { \
for (int ty = 0; ty < blockDim.y; ++ty) { \
for (int tz = 0; tz < blockDim.z; ++tz) { \
__syncthreads(); \
if (blockIdx.x == bx && blockIdx.y == by && blockIdx.z == bz && \
threadIdx.x == tx && threadIdx.y == ty && \
threadIdx.z == tz) { \
printf("[%d,%d,%d][%d,%d,%d]" msg "\n", \
bx, \
by, \
bz, \
tx, \
ty, \
tz, \
##__VA_ARGS__); \
} \
} \
} \
} \
} \
} \
}
#else
#define PRINT_B0_T0
#define PRINT_TX_LX
#endif
struct __string_view {
char const* data;
std::size_t size;
};
#if __cplusplus >= 201402L
template <class T>
constexpr __string_view __get_type_name() {
char const* p = __PRETTY_FUNCTION__;
while (*p++ != '=')
; // NOLINT
for (; *p == ' '; ++p)
; // NOLINT
char const* p2 = p;
int count = 1;
for (;; ++p2) {
switch (*p2) {
case '[':
++count;
break;
case ']':
--count;
if (!count) return {p, std::size_t(p2 - p)};
}
}
return {};
}
#else
template <class T>
constexpr __string_view __get_type_name() {
return {"unsupported", 11};
}
#endif
// Print a given array
#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \
PRINT_T0_L0("%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \
name, \
static_cast<int>(start), \
static_cast<int>(start + 8), \
static_cast<float>(accum[start + 0]), \
static_cast<float>(accum[start + 1]), \
static_cast<float>(accum[start + 2]), \
static_cast<float>(accum[start + 3]), \
static_cast<float>(accum[start + 4]), \
static_cast<float>(accum[start + 5]), \
static_cast<float>(accum[start + 6]), \
static_cast<float>(accum[start + 7]));
#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0)
#define PRINT_FRAG_T0_L0(name, frag) \
{ \
auto typeStr = __get_type_name<decltype(frag)>(); \
PRINT_T0_L0("printing %s (%s)", name, typeStr.data); \
for (int _start = 0; _start < frag.size(); _start += 8) { \
PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \
} \
/*__syncthreads(); NANCHECK(frag); */ \
}
#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \
{ \
PRINT_T0_L0("printing %s (len=%d)", name, static_cast<int>(length)); \
for (int _start = 0; _start < length; _start += incr) { \
PRINT_ACCUM8_T0_L0_START(" ", array, _start); \
} \
}
#define PRINT_ARRAY_T0_L0(name, array, length) \
PRINT_ARRAY_T0_L0_INCR(name, array, length, 8)
// Print a 4x4 matrix
#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \
PRINT_T0_L0( \
"%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, " \
"%f, %f, %f\n %f, %f, %f, %f", \
name, \
static_cast<int>(start_x), \
static_cast<int>(start_x + 4), \
static_cast<int>(start_y), \
static_cast<int>(start_y + 4), \
static_cast<float>(ref.at({start_x + 0, start_y + 0})), \
static_cast<float>(ref.at({start_x + 0, start_y + 1})), \
static_cast<float>(ref.at({start_x + 0, start_y + 2})), \
static_cast<float>(ref.at({start_x + 0, start_y + 3})), \
static_cast<float>(ref.at({start_x + 1, start_y + 0})), \
static_cast<float>(ref.at({start_x + 1, start_y + 1})), \
static_cast<float>(ref.at({start_x + 1, start_y + 2})), \
static_cast<float>(ref.at({start_x + 1, start_y + 3})), \
static_cast<float>(ref.at({start_x + 2, start_y + 0})), \
static_cast<float>(ref.at({start_x + 2, start_y + 1})), \
static_cast<float>(ref.at({start_x + 2, start_y + 2})), \
static_cast<float>(ref.at({start_x + 2, start_y + 3})), \
static_cast<float>(ref.at({start_x + 3, start_y + 0})), \
static_cast<float>(ref.at({start_x + 3, start_y + 1})), \
static_cast<float>(ref.at({start_x + 3, start_y + 2})), \
static_cast<float>(ref.at({start_x + 3, start_y + 3})));
#define PRINT_TENSOR4x4_T0_L0(name, ref) \
PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0)
#define PRINT_PROBLEM_SIZE(name, ps) \
PRINT_T0_L0("%s.problem_size: {.m=%d, .n=%d, .k=%d}", \
name, \
static_cast<int>(ps.m()), \
static_cast<int>(ps.n()), \
static_cast<int>(ps.k()))
template <typename LambdaIterator, typename LaneOffsetT, typename AccumT>
CUTLASS_DEVICE void print_warp_accum(AccumT accum,
LaneOffsetT lane_offset,
int32_t num_rows,
int32_t num_cols) {
bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&
threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0;
for (int row = 0; row < num_rows; ++row) {
for (int col = 0; col < num_cols; ++col) {
if (col % 32 == 0) {
if (is_main) {
printf("\nmat[%3d, %3d:%3d]", row, col, col + 32);
}
__syncthreads();
}
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
[&](int accum_m, int accum_n, int idx) {
if (row == accum_m && col == accum_n &&
(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)) {
printf(" %6.1f", static_cast<float>(accum[idx]));
}
},
[&](int accum_m) {});
__syncthreads();
}
if (is_main) {
printf("\n");
}
}
}
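As a usage note, __get_type_name above parses __PRETTY_FUNCTION__ (a GCC/Clang extension) to recover a printable type name for the PRINT_FRAG_T0_L0 macro. A small host-side sketch of calling it directly, assuming the definitions above are in scope, could look like this:
#include <cstdio>

int main() {
  const __string_view n = __get_type_name<float>();
  std::printf("%.*s\n", static_cast<int>(n.size), n.data);  // expected: float
  return 0;
}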
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
File copied from "cutlass/epilogue/threadblock/epilogue.h"
then modified to:
(1) load 2 source fragments at the same time (pipelining)
(2) support reading from a different dtype
(3) pass the row id to the OutputOp if it takes it
(see MemoryEfficientAttentionNormalize)
Note that in general the fragment passed to the OutputOp could
span multiple rows but it does not happen with the configurations we have
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
template <typename Op>
struct ApplyEpilogueOp {
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum,
typename Op::FragmentOutput const& source) {
return output_op(accum, source);
}
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum) {
return output_op(accum);
}
};
////////////////////////////////////////////////////////////////////////////////
/// Epilogue operator
template <typename Shape_, ///< Shape of threadblock tile (concept: GemmShape)
typename WarpMmaOperator_, ///< Warp-level MMA operator (concept:
///< gemm::warp::MmaTensorOp)
int PartitionsK, ///< Number of partitions of the K dimension
typename OutputTileIterator_, ///< Tile iterator writing output
///< tensors
typename AccumulatorFragmentIterator_, ///< Fragment iterator
///< selecting accumulators
typename WarpTileIterator_, ///< Warp-scoped tile iterator writing
///< accumulators to SMEM
typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator
///< loading from SMEM
typename OutputOp_, ///< Output operator
typename Padding_, ///< Padding added to SMEM allocation to avoid
///< bank conflicts (concept: MatrixShape)
int FragmentsPerPartition =
1, ///< Used to coarsten the epilogue granularity
int IterationsUnroll = ///< Used to reduce binary size when epilogue
///< op is large
(!IsEpilogueFunctorHeavy<OutputOp_>::value),
typename OutputTileSourceIterator_ =
OutputTileIterator_ ///< Tile iterator reading tensors
>
class EpiloguePipelined : public EpilogueBase<Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition> {
public:
using Base = EpilogueBase<Shape_,
typename WarpMmaOperator_::Shape,
PartitionsK,
AccumulatorFragmentIterator_,
WarpTileIterator_,
Padding_,
FragmentsPerPartition>;
using Shape = Shape_;
using WarpMmaOperator = WarpMmaOperator_;
static int const kPartitionsK = PartitionsK;
using OutputTileIterator = OutputTileIterator_;
using OutputTileSourceIterator = OutputTileSourceIterator_;
using AccumulatorFragmentIterator = AccumulatorFragmentIterator_;
using WarpTileIterator = WarpTileIterator_;
using SharedLoadIterator = SharedLoadIterator_;
using OutputOp = OutputOp_;
using Padding = Padding_;
using Layout = layout::RowMajor;
using LongIndex = typename Layout::LongIndex;
/// The complete warp-level accumulator tile
using AccumulatorTile = typename Base::AccumulatorTile;
/// Accumulator element
using ElementAccumulator = typename WarpTileIterator::Element;
/// Output element
using ElementOutput = typename OutputTileIterator::Element;
using ElementSource = typename OutputTileSourceIterator::Element;
/// Output access size
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
/// Tensor reference to destination tensor
using TensorRef = typename OutputTileIterator::TensorRef;
/// Tensor reference to sync tensor
using SyncTensorRef =
typename cutlass::TensorRef<int, cutlass::layout::PackedVectorLayout>;
/// Const tensor reference to source tensor
using ConstTensorRef = typename OutputTileIterator::ConstTensorRef;
/// Array type used to output
using OutputAccessType = Array<typename OutputTileIterator::Element,
OutputTileIterator::kElementsPerAccess>;
using SourceAccessType = Array<typename OutputTileSourceIterator::Element,
OutputTileSourceIterator::kElementsPerAccess>;
/// Array type used by output functor
using AccumulatorAccessType = Array<typename WarpTileIterator::Element,
OutputTileIterator::kElementsPerAccess>;
/// Number of warps
using WarpCount = typename Base::WarpCount;
static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1
? Base::kFragmentsPerIteration
: kPartitionsK;
static int constexpr kSmemPointerOffset =
Base::SharedStorage::StorageShape::kCount / kSmemTiles;
public:
static_assert(
OutputTileSourceIterator::Fragment::kElements ==
OutputTileIterator::Fragment::kElements,
"Mismatch between input tile and output tile iterator (kElements)");
static_assert(
OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations,
"Mismatch between input tile and output tile iterator (kIterations)");
static_assert(
SharedLoadIterator::Fragment::kElements ==
OutputTileIterator::Fragment::kElements,
"Mismatch between shared load iterator and output tile iterator.");
static_assert(OutputTileIterator::kElementsPerAccess,
"OutputTileIterator::kElementsPerAccess must not be zero.");
static_assert(!(OutputTileIterator::Fragment::kElements %
OutputTileIterator::kElementsPerAccess),
"Divisibility");
private:
/// Loads fragment from shared memory aligned with output tensor
SharedLoadIterator shared_load_iterator_;
public:
/// Constructor
CUTLASS_DEVICE
EpiloguePipelined(typename Base::SharedStorage&
shared_storage, ///< Shared storage object //NOLINT
int thread_idx, ///< ID of a thread within the threadblock
int warp_idx, ///< ID of warp within threadblock
int lane_idx ///< Id of thread within warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
shared_load_iterator_(shared_storage.reference(), thread_idx) {}
/// Streams the result to global memory
CUTLASS_DEVICE
void operator()(
OutputOp const& output_op, ///< Output operator
OutputTileIterator
destination_iterator, ///< Tile iterator for destination
AccumulatorTile const&
accumulators, ///< Complete warp-level accumulator tile
OutputTileSourceIterator
source_iterator) { ///< Threadblock tile coordinate in GEMM (in units
///< of threadblock tiles)
if (!output_op.is_source_needed()) {
compute_source_not_needed_(output_op, destination_iterator, accumulators);
} else {
compute_source_needed_(
output_op, destination_iterator, accumulators, source_iterator);
}
}
CUTLASS_DEVICE
void operator()(OutputOp const& output_op, ///< Output operator
OutputTileIterator
destination_iterator, ///< Tile iterator for destination
AccumulatorTile const&
accumulators) { ///< Complete warp-level accumulator tile
compute_source_not_needed_(output_op, destination_iterator, accumulators);
}
private:
template <class Seq>
struct acc2smem_source_not_needed;
template <size_t... Seq>
struct acc2smem_source_not_needed<cutlass::index_sequence<Seq...>> {
template <int Advance>
CUTLASS_DEVICE static void helper(
AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator& warp_tile_iterator) { // NOLINT
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
warp_tile_iterator.store(accum_fragment);
if (p < Base::kFragmentsPerIteration - 1) {
warp_tile_iterator.add_pointer_offset(kSmemPointerOffset);
}
}
if (Base::kFragmentsPerIteration > 1) {
warp_tile_iterator.add_pointer_offset(
kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
}
}
CUTLASS_DEVICE
static void push(
size_t pos,
AccumulatorFragmentIterator const& iterator_begin, // NOLINT
WarpTileIterator& warp_tile_iterator) { // NOLINT
int dummy[] = {(pos == (Seq * Base::kFragmentsPerIteration)) &&
(helper<Seq * Base::kFragmentsPerIteration>(
iterator_begin, warp_tile_iterator),
0)...};
CUTLASS_UNUSED(dummy[0]);
}
};
static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1,
"One of these must be exactly 1.");
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_not_needed_(
OutputOp const& output_op, ///< Output operator
OutputTileIterator
destination_iterator, ///< Tile iterator for destination
AccumulatorTile const&
accumulators ///< Complete warp-level accumulator tile
) {
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / \
Base::kFragmentsPerIteration \
: 1)
for (int iter = 0; iter < OutputTileIterator::kIterations;
iter += Base::kFragmentsPerIteration) {
//
// Convert and store fragment
//
__syncthreads();
acc2smem_source_not_needed<cutlass::make_index_sequence<
OutputTileIterator::kIterations / Base::kFragmentsPerIteration>>::
push(iter, accum_fragment_iterator, this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < Base::kFragmentsPerIteration; ++p) {
typename SharedLoadIterator::Fragment
aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
if (p < Base::kFragmentsPerIteration - 1) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
} else if (kPartitionsK > 1) {
plus<typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(
aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) *
kSmemPointerOffset);
}
//
// Compute the output result
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_source_not_needed_(
destination_iterator.thread_start_row(),
output_fragment,
output_op,
aligned_accum_fragment[0]);
//
// Store the final result
//
destination_iterator.store(output_fragment);
++destination_iterator;
}
if (Base::kFragmentsPerIteration > 1) {
shared_load_iterator_.add_pointer_offset(
kSmemPointerOffset * (1 - Base::kFragmentsPerIteration));
}
}
}
template <class Seq>
struct acc2smem_source_needed;
template <size_t... Seq>
struct acc2smem_source_needed<cutlass::index_sequence<Seq...>> {
template <int Advance>
CUTLASS_DEVICE static void helper(
AccumulatorFragmentIterator accum_fragment_iterator,
WarpTileIterator& warp_tile_iterator) { // NOLINT
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Advance; i++) {
++accum_fragment_iterator;
}
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
warp_tile_iterator.store(accum_fragment);
}
CUTLASS_DEVICE
static void push(
size_t pos,
AccumulatorFragmentIterator const& iterator_begin, // NOLINT
WarpTileIterator& warp_tile_iterator) { // NOLINT
int dummy[] = {(pos == Seq) &&
(helper<Seq>(iterator_begin, warp_tile_iterator), 0)...};
}
};
/// Streams the result to global memory
CUTLASS_DEVICE
void compute_source_needed_(
OutputOp const& output_op, ///< Output operator
OutputTileIterator
destination_iterator, ///< Tile iterator for destination
AccumulatorTile const&
accumulators, ///< Complete warp-level accumulator tile
OutputTileSourceIterator
source_iterator ///< Threadblock tile coordinate in GEMM (in units of
///< threadblock tiles)
) {
typename OutputTileSourceIterator::Fragment source_fragment[2];
source_fragment[0].clear();
source_iterator.load(source_fragment[0]);
++source_iterator;
source_fragment[1].clear();
//
// Iterator over warp-level accumulator fragment
//
AccumulatorFragmentIterator accum_fragment_iterator(accumulators);
//
// Iterate over accumulator tile
//
#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1)
for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) {
if (iter > 0) {
__syncthreads();
}
//
// Load the source for next iteration (pipelining)
//
if (iter + 1 < OutputTileIterator::kIterations) {
source_iterator.load(source_fragment[(iter + 1) % 2]);
}
++source_iterator;
acc2smem_source_needed<cutlass::make_index_sequence<
OutputTileIterator::kIterations>>::push(iter,
accum_fragment_iterator,
this->warp_tile_iterator_);
__syncthreads();
//
// Load fragments from shared memory
//
typename SharedLoadIterator::Fragment
aligned_accum_fragment[kPartitionsK];
shared_load_iterator_.load(aligned_accum_fragment[0]);
// If the number of k-slices is > 1 - perform a reduction amongst the
// k-slices
if (kPartitionsK > 1) {
plus<typename SharedLoadIterator::Fragment> add_fragments;
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator_.add_pointer_offset(kSmemPointerOffset);
shared_load_iterator_.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0],
aligned_accum_fragment[i]);
}
shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) *
kSmemPointerOffset);
}
//
// Compute the output result
//
typename OutputTileIterator::Fragment output_fragment;
apply_output_operator_(destination_iterator.thread_start_row(),
output_fragment,
output_op,
aligned_accum_fragment[0],
source_fragment[iter % 2]);
//
// Store the final result
//
destination_iterator.store(output_fragment);
++destination_iterator;
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_(
int begin_row,
typename OutputTileIterator::Fragment& output_fragment, // NOLINT
OutputOp const& output_op, ///< Output operator
typename SharedLoadIterator::Fragment const& aligned_accum_fragment,
typename OutputTileSourceIterator::Fragment const& source_fragment) {
OutputAccessType* output_frag_ptr =
reinterpret_cast<OutputAccessType*>(&output_fragment);
AccumulatorAccessType const* compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);
SourceAccessType const* source_frag_ptr =
reinterpret_cast<SourceAccessType const*>(&source_fragment);
int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
output_op,
begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
compute_frag_ptr[i],
source_frag_ptr[i]);
}
}
/// Helper to invoke the output functor over each vector of output
CUTLASS_DEVICE
void apply_output_operator_source_not_needed_(
int begin_row,
typename OutputTileIterator::Fragment& output_fragment, // NOLINT
OutputOp const& output_op, ///< Output operator
typename SharedLoadIterator::Fragment const& aligned_accum_fragment) {
OutputAccessType* output_frag_ptr =
reinterpret_cast<OutputAccessType*>(&output_fragment);
AccumulatorAccessType const* compute_frag_ptr =
reinterpret_cast<AccumulatorAccessType const*>(&aligned_accum_fragment);
int const kOutputOpIterations = OutputTileIterator::Fragment::kElements /
OutputTileIterator::kElementsPerAccess;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kOutputOpIterations; ++i) {
// Call the output operator
output_frag_ptr[i] = ApplyEpilogueOp<OutputOp>::apply(
output_op,
begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess),
compute_frag_ptr[i]);
}
}
// This should be constexpr, but it's only supported on c++14
static int CUTLASS_HOST_DEVICE getRowOffset(int i) {
using ThreadMap = typename OutputTileIterator::ThreadMap;
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
int frag_row_idx =
(row + ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
int frag_idx =
ThreadMap::kElementsPerAccess *
(frag_row_idx * ThreadMap::Iterations::kColumn + column);
if (i < frag_idx + ThreadMap::kElementsPerAccess) {
return row_offset;
}
}
}
}
}
return -1;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory
to match canonical tensor layouts in global memory. Epilogues support
conversion and reduction operations.
This is a copy of cutlass/epilogue/threadblock/epilogue.h that can
handle "row_id" as a first argument, as uses it to get the corresponding
`m_prime` / `s_prime` to rescale the output.
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#else
#include <assert.h>
#endif
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/functional.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
#include "cutlass/epilogue/threadblock/epilogue_base.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "./epilogue_pipelined.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies a linear combination operator to an array of elements.
// output <- alpha * accumulator + beta * source
// with:
// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise)
// beta = alpha * m_prime (renormalize the output when the max changes)
// source is the current output
template <typename ElementOutput_, /// < Data type used to store tensors
typename ElementSource_, // < Data type for source (usually matches
// `ElementOutput`)
int Count, ///< Number of elements computed per operation.
///< Usually it is 128/sizeof_bits<ElementOutput_>,
///< but we use 64 or 32 sometimes when there are not
///< enough data to store
typename ElementAccumulator_, ///< Accumulator data type
typename ElementCompute_, ///< Data type used to compute linear
///< combination
bool isFirst,
bool isLast,
typename FragmentAlphaBeta_,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
class MemoryEfficientAttentionNormalize {
public:
using ElementOutput = ElementOutput_;
using ElementSource = ElementSource_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
static int const kCount = Count;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentSource = Array<ElementSource, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;
using FragmentAlphaBeta = FragmentAlphaBeta_;
static FloatRoundStyle const kRound = Round;
private:
//
// Data members
//
FragmentAlphaBeta const& s_prime_;
FragmentAlphaBeta const& m_prime_;
public:
/// Constructs the function object, possibly loading from pointers in host
/// memory
CUTLASS_HOST_DEVICE
MemoryEfficientAttentionNormalize(FragmentAlphaBeta const& s_prime,
FragmentAlphaBeta const& m_prime)
: s_prime_(s_prime), m_prime_(m_prime) {}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const { return !isFirst; }
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {}
/// Computes linear scaling: D = alpha * accumulator + beta * source
CUTLASS_HOST_DEVICE
FragmentOutput operator()(int row,
FragmentAccumulator const& accumulator,
FragmentSource const& source) const {
assert(!isFirst);
// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementSource, kCount, Round>
source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
// Perform binary operations
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
ElementCompute beta = alpha * m_prime_[row];
intermediate = mul_add_source(beta, converted_source); // X = beta * C
intermediate = mul_add_accumulator(
alpha, converted_accumulator, intermediate); // D = alpha * Accum + X
return destination_converter(intermediate);
}
/// Computes linear scaling: D = alpha * accumulator
CUTLASS_HOST_DEVICE
FragmentOutput operator()(int row,
FragmentAccumulator const& accumulator) const {
assert(isFirst);
// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round>
accumulator_converter;
// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round>
destination_converter;
ComputeFragment converted_accumulator = accumulator_converter(accumulator);
ComputeFragment intermediate;
multiplies<ComputeFragment> mul_accumulator;
ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1;
intermediate = mul_accumulator(
alpha, converted_accumulator); // D = alpha * Accum
return destination_converter(intermediate);
}
};
} // namespace thread
namespace threadblock {
template <typename EO,
typename ES,
int Count,
typename EA,
typename EC,
bool F,
bool L,
typename FAB,
FloatRoundStyle R>
struct ApplyEpilogueOp<thread::MemoryEfficientAttentionNormalize<EO,
ES,
Count,
EA,
EC,
F,
L,
FAB,
R>> {
using Op = thread::
MemoryEfficientAttentionNormalize<EO, ES, Count, EA, EC, F, L, FAB, R>;
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum,
typename Op::FragmentSource const& source) {
return output_op(row_id, accum, source);
}
static CUTLASS_DEVICE typename Op::FragmentOutput apply(
Op const& output_op,
int row_id,
typename Op::FragmentAccumulator const& accum) {
return output_op(row_id, accum);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
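// Illustrative host-side sketch (not part of the kernels above): the per-row
// rescaling that MemoryEfficientAttentionNormalize performs, written as plain
// scalar C++. `rescale_row` is an assumed name; s_prime stands for the running
// softmax denominator and m_prime for the correction factor applied when the
// running row max changes. The scalar loop stands in for the fragment math.
#include <cstddef>
#include <vector>

std::vector<float> rescale_row(const std::vector<float>& accum,   // fresh partial result for this row
                               const std::vector<float>& source,  // output already written for this row
                               float s_prime,
                               float m_prime,
                               bool is_last) {
  const float alpha = is_last ? 1.0f / s_prime : 1.0f;  // normalize only on the last partial tile
  const float beta = alpha * m_prime;                   // renormalize what was accumulated so far
  std::vector<float> out(accum.size());
  for (std::size_t i = 0; i < accum.size(); ++i) {
    out[i] = alpha * accum[i] + beta * source[i];       // D = alpha * Accum + beta * C
  }
  return out;
}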
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination operations used by epilogues.
*/
#pragma once
#include <cuda_fp16.h>
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
template <typename Element, int ElementsPerAccess>
struct ArrayExponential {
CUTLASS_HOST_DEVICE
Array<Element, ElementsPerAccess> operator()(
Array<Element, ElementsPerAccess> const& input) const {
Array<Element, ElementsPerAccess> result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ElementsPerAccess; ++i) {
result[i] = expf(input[i]);
}
return result;
}
};
template <int ElementsPerAccess>
struct ArrayExponential<half_t, ElementsPerAccess> {
CUTLASS_DEVICE
Array<half_t, ElementsPerAccess> operator()(
Array<half_t, ElementsPerAccess> const& input) const {
Array<half_t, ElementsPerAccess> result;
int const kVectorCount = ElementsPerAccess / 2;
__half2 const* input_ptr =
reinterpret_cast<__half2 const*>(input.raw_data());
__half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data());
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kVectorCount; ++i) {
res_ptr[i] = h2exp(input_ptr[i]);
}
return result;
}
};
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Applies:
/// output <- (input - lse).exp()
template <typename ElementOutput_, // output
typename ElementLSE_, // accumulator from LSE
typename ElementAccumulator_, // accumulator from matmul
typename ElementCompute_, // intermediate compute (and exp
// calculation)
int ElementsPerAccess>
class ApplyLogSumExp {
public:
using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;
using ElementLSE = ElementLSE_;
static int const kElementsPerAccess = ElementsPerAccess;
static int const kCount = kElementsPerAccess;
static const ScaleType::Kind kScale =
cutlass::epilogue::thread::ScaleType::NoBetaScaling;
using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
using FragmentLSE = Array<ElementLSE, kElementsPerAccess>;
using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h
public:
//
// Methods
//
CUTLASS_HOST_DEVICE
ApplyLogSumExp() {}
/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const { return true; }
/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {}
CUTLASS_HOST_DEVICE
FragmentOutput operator()(FragmentAccumulator const& AB,
FragmentLSE const& scale_unused,
// bias used as LSE
FragmentLSE const& bias) const {
FragmentCompute frag_AB = NumericArrayConverter<ElementCompute,
ElementAccumulator,
kElementsPerAccess>()(AB);
FragmentCompute frag_lse_compute =
NumericArrayConverter<ElementCompute, ElementLSE, kElementsPerAccess>()(
bias);
FragmentCompute frag_compute;
minus<FragmentCompute> minus_lse;
detail::ArrayExponential<ElementCompute, kElementsPerAccess> apply_exp;
frag_compute = minus_lse(frag_AB, frag_lse_compute);
frag_compute = apply_exp(frag_compute);
return NumericArrayConverter<ElementOutput,
ElementCompute,
kElementsPerAccess>()(frag_compute);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
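// Illustrative scalar sketch (not part of the kernels above) of what
// ApplyLogSumExp computes per element: exp(score - lse), which yields the
// normalized softmax row without a separate division. `softmax_from_lse` is an
// assumed name; the kernel receives lse precomputed and broadcast per row.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> softmax_from_lse(const std::vector<float>& scores) {
  if (scores.empty()) return {};
  // Stable logsumexp, as stored by the attention forward pass.
  float row_max = scores[0];
  for (float s : scores) row_max = std::max(row_max, s);
  float sum = 0.0f;
  for (float s : scores) sum += std::exp(s - row_max);
  const float lse = row_max + std::log(sum);
  // The epilogue then only needs output <- (input - lse).exp() per element.
  std::vector<float> probs(scores.size());
  for (std::size_t i = 0; i < scores.size(); ++i) probs[i] = std::exp(scores[i] - lse);
  return probs;
}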
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include "./custom_mma_multistage.h"
#include "./custom_mma_pipelined.h"
#include "cutlass/gemm/threadblock/mma_multistage.h"
#include "cutlass/gemm/threadblock/mma_pipelined.h"
template <typename Mma, int kMaxK>
struct MakeCustomMma;
template <typename Shape,
typename IteratorA,
typename SmemIteratorA,
cutlass::arch::CacheOperation::Kind CacheOpA,
typename IteratorB,
typename SmemIteratorB,
cutlass::arch::CacheOperation::Kind CacheOpB,
typename ElementC,
typename LayoutC,
typename Policy,
int Stages,
cutlass::gemm::SharedMemoryClearOption SharedMemoryClear,
int kMaxK>
struct MakeCustomMma<
cutlass::gemm::threadblock::MmaMultistage<Shape,
IteratorA,
SmemIteratorA,
CacheOpA,
IteratorB,
SmemIteratorB,
CacheOpB,
ElementC,
LayoutC,
Policy,
Stages,
SharedMemoryClear>,
kMaxK> {
// Reduce the number of stages if we don't need that many
static int constexpr kStages =
kMaxK == cutlass::platform::numeric_limits<int>::max()
? Stages
: cutlass::const_min(Stages,
(kMaxK + static_cast<int>(Shape::kK) - 1) /
static_cast<int>(Shape::kK));
using Mma = cutlass::gemm::threadblock::CustomMmaMultistage<Shape,
IteratorA,
SmemIteratorA,
CacheOpA,
IteratorB,
SmemIteratorB,
CacheOpB,
ElementC,
LayoutC,
Policy,
kStages,
SharedMemoryClear,
kMaxK>;
};
template <typename Shape,
typename IteratorA,
typename SmemIteratorA,
typename IteratorB,
typename SmemIteratorB,
typename ElementC,
typename LayoutC,
typename Policy,
int kMaxK>
struct MakeCustomMma<cutlass::gemm::threadblock::MmaPipelined<Shape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
Policy>,
kMaxK> {
using Mma = cutlass::gemm::threadblock::CustomMmaPipelined<Shape,
IteratorA,
SmemIteratorA,
IteratorB,
SmemIteratorB,
ElementC,
LayoutC,
Policy>;
};
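// Constexpr sketch (illustrative only) of the stage-count reduction performed
// by MakeCustomMma above: when the K extent is bounded by kMaxK, keeping more
// pipeline stages than ceil(kMaxK / ThreadblockShape::kK) buys nothing, so the
// multistage mainloop is instantiated with the smaller count. `reduced_stages`
// is an assumed helper name.
#include <limits>

constexpr int reduced_stages(int requested_stages, int k_max, int tile_k) {
  return k_max == std::numeric_limits<int>::max()
             ? requested_stages
             : (requested_stages < (k_max + tile_k - 1) / tile_k
                    ? requested_stages
                    : (k_max + tile_k - 1) / tile_k);
}

// E.g. a K tile of 64 with kMaxK = 128 needs at most 2 stages even if 3 were
// requested, while an unbounded K keeps the requested count.
static_assert(reduced_stages(3, 128, 64) == 2, "two K tiles cover kMaxK");
static_assert(reduced_stages(3, std::numeric_limits<int>::max(), 64) == 3,
              "unbounded K keeps all stages");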
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class CustomMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
/// Number of stages
static int const kStages = Stages;
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
template <typename Element, typename OperandShape, typename OperandLayout>
struct OperandSharedStorage {
AlignedBuffer<Element, OperandShape::kCount> buffer;
using TensorRef = TensorRef<Element, OperandLayout>;
CUTLASS_DEVICE
static OperandLayout Layout() {
return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn});
}
/// Returns a TensorRef to the operand
CUTLASS_HOST_DEVICE
TensorRef ref() { return TensorRef{buffer.data(), Layout()}; }
};
/// Shape of the A matrix operand in shared memory
using ShapeA =
MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
using SharedStorageA = OperandSharedStorage<typename Operator::ElementA,
ShapeA,
typename Operator::LayoutA>;
using SharedStorageB = OperandSharedStorage<typename Operator::ElementB,
ShapeB,
typename Operator::LayoutB>;
using TensorRefA = typename SharedStorageA::TensorRef;
using TensorRefB = typename SharedStorageB::TensorRef;
struct SharedStorage {
/// Buffer for A operand
SharedStorageA operand_A;
/// Buffer for B operand
SharedStorageB operand_B;
};
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
CustomMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorageA& shared_storageA, // NOLINT
SharedStorageB& shared_storageB, // NOLINT
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_A_(shared_storageA.ref(), lane_idx),
warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
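// Back-of-the-envelope sketch (illustrative only) of the shared-memory
// footprint implied by CustomMmaBase::ShapeA / ShapeB above, ignoring the
// SmemPadding terms. `smem_bytes` is an assumed helper, not part of the class.
#include <cstddef>

constexpr std::size_t smem_bytes(int m, int n, int k, int stages,
                                 std::size_t elem_bytes) {
  const std::size_t a = std::size_t(m) * k * stages * elem_bytes;  // ShapeA buffer
  const std::size_t b = std::size_t(k) * stages * n * elem_bytes;  // ShapeB buffer
  return a + b;
}

// A 64x64 threadblock tile with kK = 32, two stages and fp16 operands uses
// roughly 16 KiB of shared memory for the A and B ring buffers.
static_assert(smem_bytes(64, 64, 32, 2, 2) == 16384, "64x64x32, 2 stages, fp16");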
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/cache_operation.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "./custom_mma_base.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Upper bound on the K dimension
int kMaxK = cutlass::platform::numeric_limits<int>::max(),
/// Used for partial specialization
typename Enable = bool>
class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
public:
///< Base class
using Base = CustomMmaBase<Shape_, Policy_, Stages>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
//
// Dependent types
//
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
/// Internal structure exposed for introspection.
struct Detail {
static_assert(Base::kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA =
IteratorA::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one stage of operand B
static int const AsyncCopyIterationsPerStageB =
IteratorB::ThreadMap::Iterations::kCount;
/// Number of stages
static int const kStages = Stages;
/// Number of cp.async instructions to load one group of operand A
static int const kAccessesPerGroupA =
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) /
Base::kWarpGemmIterations;
/// Number of cp.async instructions to load one group of operand B
static int const kAccessesPerGroupB =
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) /
Base::kWarpGemmIterations;
};
static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages;
static constexpr int kNumStagesConcurrentLoad =
kSmemContainsEntireMat ? Stages : Stages - 1;
private:
using WarpLoadedFragmentA = typename Operator::FragmentA;
using WarpLoadedFragmentB = typename Operator::FragmentB;
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
private:
//
// Data members
//
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
bool prologue_done_;
// Set to `True` to ensure the accumulator will be zero outside the GEMM
// footprint
bool zero_outside_bounds_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
CustomMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorageA& shared_storageA, // NOLINT
typename Base::SharedStorageB& shared_storageB, // NOLINT
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storageA.ref(), thread_idx),
smem_iterator_B_(shared_storageB.ref(), thread_idx),
prologue_done_(false),
zero_outside_bounds_(false) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
CustomMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& st, // NOLINT
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: CustomMmaMultistage(
st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx) {}
CUTLASS_DEVICE
void set_prologue_done(bool value) { prologue_done_ = value; }
CUTLASS_DEVICE
void set_zero_outside_bounds(bool value) { zero_outside_bounds_ = value; }
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(
typename Base::SharedStorage& shared_storage, // NOLINT
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
int thread_idx,
int problem_size_k) {
prologue<kLoadA, kLoadB>(shared_storage.operand_A,
shared_storage.operand_B,
iterator_A,
iterator_B,
thread_idx,
problem_size_k);
}
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(
typename Base::SharedStorageA& shared_storageA, // NOLINT
typename Base::SharedStorageB& shared_storageB, // NOLINT
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
int thread_idx,
int problem_size_k) {
SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx);
SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx);
int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK;
_prologue<kLoadA, kLoadB>(
iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B);
}
CUTLASS_DEVICE
void copy_tiles_and_advance(IteratorA& iterator_A, // NOLINT
IteratorB& iterator_B, // NOLINT
int group_start_A = 0,
int group_start_B = 0) {
iterator_A.set_iteration_index(group_start_A *
IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
typename IteratorA::AccessType* dst_ptr =
reinterpret_cast<typename IteratorA::AccessType*>(
this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_A.get();
if (zero_outside_bounds_ ||
SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
}
++iterator_A;
}
++this->smem_iterator_A_;
}
}
iterator_B.set_iteration_index(group_start_B *
IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
typename IteratorB::AccessType* dst_ptr =
reinterpret_cast<typename IteratorB::AccessType*>(
this->smem_iterator_B_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
if (zero_outside_bounds_ ||
SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
}
++iterator_B;
}
++this->smem_iterator_B_;
}
}
}
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void _prologue(
IteratorA& iterator_A, // NOLINT
IteratorB& iterator_B, // NOLINT
int32_t& gemm_k_iterations, // NOLINT
SmemIteratorA& smem_iterator_A_, // NOLINT
SmemIteratorB& smem_iterator_B_) { // NOLINT
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < kNumStagesConcurrentLoad;
++stage, --gemm_k_iterations) {
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_A.set_iteration_index(0);
smem_iterator_A_.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType* dst_ptr =
reinterpret_cast<typename IteratorA::AccessType*>(
smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
if (kLoadA) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
}
++iterator_A;
}
++smem_iterator_A_;
}
iterator_B.set_iteration_index(0);
smem_iterator_B_.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType* dst_ptr =
reinterpret_cast<typename IteratorB::AccessType*>(
smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
if (kLoadB) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
}
++iterator_B;
}
++smem_iterator_B_;
}
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
smem_iterator_A_.add_tile_offset({0, 1});
smem_iterator_B_.add_tile_offset({1, 0});
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations,
///< destination accumulator tile
FragmentC& accum, // NOLINT
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< initial value of accumulator
FragmentC const& src_accum) {
//
// Prologue
//
if (!prologue_done_) {
_prologue<true, true>(iterator_A,
iterator_B,
gemm_k_iterations,
smem_iterator_A_,
smem_iterator_B_);
} else if (!kSmemContainsEntireMat) {
_prologue<false, false>(iterator_A,
iterator_B,
gemm_k_iterations,
smem_iterator_A_,
smem_iterator_B_);
} else {
gemm_k_iterations -= kNumStagesConcurrentLoad;
}
// Perform accumulation in the 'd' output operand
accum = src_accum;
//
// Clear the remaining tiles of SMEM. This is a functional requirement for
// some kernels so that all accumulator elements outside the GEMM footprint
// are zero.
//
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
/// Iterator to write threadblock-scoped tile of A operand to shared
/// memory
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
typename IteratorA::AccessType zero_A;
zero_A.clear();
last_smem_iterator_A.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType* dst_ptr =
reinterpret_cast<typename IteratorA::AccessType*>(
last_smem_iterator_A.get());
*dst_ptr = zero_A;
++last_smem_iterator_A;
}
/// Iterator to write threadblock-scoped tile of B operand to shared
/// memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
zero_B.clear();
last_smem_iterator_B.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType* dst_ptr =
reinterpret_cast<typename IteratorB::AccessType*>(
last_smem_iterator_B.get());
*dst_ptr = zero_B;
++last_smem_iterator_B;
}
}
// Waits until kStages-2 stages have committed.
cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpLoadedFragmentA warp_loaded_frag_A[2];
WarpLoadedFragmentB warp_loaded_frag_B[2];
WarpTransformedFragmentA warp_transformed_frag_A[2];
WarpTransformedFragmentB warp_transformed_frag_B[2];
Operator warp_mma;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
warp_mma.transform(warp_transformed_frag_A[0],
warp_transformed_frag_B[0],
warp_loaded_frag_A[0],
warp_loaded_frag_B[0]);
// tf32x3 kernels use staging accumulation. warp_mma uses a temporary
// accumulator and this temporary accumulator is added to the final
// accumulator once in every mainloop iteration.
plus<FragmentC> plus_accum;
FragmentC tmp_accum;
if (platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddFastF32>::value ||
platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
tmp_accum.clear();
}
//
// Mainloop
//
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) {
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
// In case of a non-circular buffer ("kSmemContainsEntireMat")
// make sure we don't load out of bounds data.
if (!kSmemContainsEntireMat ||
gemm_k_iterations > (-kNumStagesConcurrentLoad) ||
warp_mma_k < Base::kWarpGemmIterations - 1) {
this->warp_tile_iterator_A_.load(
warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(
warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
}
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
if (warp_mma_k > 0)
warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B[warp_mma_k % 2],
warp_loaded_frag_A[warp_mma_k % 2],
warp_loaded_frag_B[warp_mma_k % 2]);
if (platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddFastF32>::value ||
platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
warp_mma(tmp_accum,
warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B[warp_mma_k % 2],
tmp_accum);
if (warp_mma_k == 0) {
accum = plus_accum(accum, tmp_accum);
tmp_accum.clear();
}
} else {
warp_mma(accum,
warp_transformed_frag_A[warp_mma_k % 2],
warp_transformed_frag_B[warp_mma_k % 2],
accum);
}
// Issue global->shared copies for this stage
if (!kSmemContainsEntireMat &&
warp_mma_k < Base::kWarpGemmIterations - 1) {
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A,
iterator_B,
group_start_iteration_A,
group_start_iteration_B);
}
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
if (!kSmemContainsEntireMat) {
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A =
(warp_mma_k + 1) * Detail::kAccessesPerGroupA;
group_start_iteration_B =
(warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A,
iterator_B,
group_start_iteration_A,
group_start_iteration_B);
}
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Waits until kStages-2 stages have committed.
cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
__syncthreads();
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == (Base::kStages - 1)) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
} else {
++smem_write_stage_idx;
}
if (!kSmemContainsEntireMat &&
smem_read_stage_idx == (Base::kStages - 1)) {
this->warp_tile_iterator_A_.add_tile_offset(
{0,
-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations,
0});
smem_read_stage_idx = 0;
} else {
++smem_read_stage_idx;
}
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
}
// Do any conversions feeding the first stage at the end of the loop so
// we can start right away on mma instructions
if (warp_mma_k + 1 == Base::kWarpGemmIterations)
warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2],
warp_transformed_frag_B[(warp_mma_k + 1) % 2],
warp_loaded_frag_A[(warp_mma_k + 1) % 2],
warp_loaded_frag_B[(warp_mma_k + 1) % 2]);
}
}
if (platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddFastF32>::value ||
platform::is_same<typename Operator::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
accum = plus_accum(accum, tmp_accum);
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
// commit and drain all pending and predicated LDGSTS pnz from the GEMM
// mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
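// Stripped-down CUDA sketch (illustrative only, not the mainloop above) of the
// software pipelining CustomMmaMultistage implements: the copy for stage t+1 is
// issued asynchronously while stage t is consumed. It uses the public
// __pipeline_* primitives instead of the cutlass::arch::cp_async wrappers, a
// two-slot ring buffer, and a trivial copy-out in place of the warp-level MMA;
// the kernel name, tile size and blockDim.x == 128 assumption are illustrative.
#include <cuda_pipeline_primitives.h>

__global__ void two_stage_pipeline(const float4* __restrict__ in,
                                   float4* __restrict__ out,
                                   int num_tiles) {
  __shared__ float4 smem[2][128];
  const int tid = threadIdx.x;

  // Prologue: issue the async copy for tile 0 before entering the mainloop.
  __pipeline_memcpy_async(&smem[0][tid], &in[tid], sizeof(float4));
  __pipeline_commit();

  for (int t = 0; t < num_tiles; ++t) {
    const int cur = t & 1;
    const int nxt = (t + 1) & 1;
    if (t + 1 < num_tiles) {  // issue the next stage's copy early
      __pipeline_memcpy_async(&smem[nxt][tid], &in[(t + 1) * 128 + tid],
                              sizeof(float4));
      __pipeline_commit();
    }
    // Keep at most one stage in flight; this mirrors
    // cp_async_wait<kNumStagesConcurrentLoad - 1>() followed by __syncthreads().
    __pipeline_wait_prior(t + 1 < num_tiles ? 1 : 0);
    __syncthreads();
    out[t * 128 + tid] = smem[cur][tid];  // stand-in for the warp-level MMA work
    __syncthreads();                      // the stage buffer may now be reused
  }
}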
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "./custom_mma_base.h"
#include "cutlass/gemm/gemm.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Transformation applied to A operand
typename TransformA_ =
NumericArrayConverter<typename SmemIteratorA_::Element,
typename IteratorA_::Element,
IteratorA_::Fragment::kElements>,
///
/// Transformation applied to B operand
typename TransformB_ =
NumericArrayConverter<typename SmemIteratorB_::Element,
typename IteratorB_::Element,
IteratorB_::Fragment::kElements>,
/// Used for partial specialization
typename Enable = bool>
class CustomMmaPipelined : public CustomMmaBase<Shape_, Policy_, 2> {
public:
///< Base class
using Base = CustomMmaBase<Shape_, Policy_, 2>;
using Shape =
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA =
IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB =
IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// statically assert kStages for MmaPipelined is two (double-buffered pipeline)
static_assert((Base::kStages == 2),
"MmaPipelined requires kStages set to value 2");
static bool const kSmemContainsEntireMat = false;
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
CustomMmaPipelined(typename Base::SharedStorageA& shared_storageA, // NOLINT
typename Base::SharedStorageB& shared_storageB, // NOLINT
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storageA.ref(), thread_idx),
smem_iterator_B_(shared_storageB.ref(), thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
CustomMmaPipelined(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& st, // NOLINT
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: CustomMmaPipelined(
st.operand_A, st.operand_B, thread_idx, warp_idx, lane_idx) {}
CUTLASS_DEVICE
void set_prologue_done(bool value) {
// NOT IMPLEMENTED FOR PIPELINED
}
CUTLASS_DEVICE
void set_zero_outside_bounds(bool value) {
// NOT NEEDED FOR PIPELINED
// shared memory will always be zero-filled
}
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(
typename Base::SharedStorage& shared_storage, // NOLINT
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
int thread_idx,
int problem_size_k) {
prologue<kLoadA, kLoadB>(shared_storage.operand_A,
shared_storage.operand_B,
iterator_A,
iterator_B,
thread_idx,
problem_size_k);
}
template <bool kLoadA = true, bool kLoadB = true>
CUTLASS_DEVICE static void prologue(
typename Base::SharedStorageA& shared_storageA, // NOLINT
typename Base::SharedStorageB& shared_storageB, // NOLINT
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
int thread_idx,
int problem_size_k) {
// NOT IMPLEMENTED FOR PIPELINED
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile //NOLINT
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
FragmentC const& src_accum, ///< source accumulator tile
TransformA transform_A =
TransformA(), ///< transformation applied to A fragment
TransformB transform_B =
TransformB()) { ///< transformation applied to B fragment
//
// Prologue
//
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
tb_frag_A.clear();
tb_frag_B.clear();
// The last kblock is loaded in the prologue
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transform_A(tb_frag_A));
this->smem_iterator_B_.store(transform_B(tb_frag_B));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER*
// issuing shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations) {
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
// Write fragments to shared memory
this->smem_iterator_A_.store(transform_A(tb_frag_A));
this->smem_iterator_B_.store(transform_B(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == 1) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
} else {
this->warp_tile_iterator_A_.add_tile_offset(
{0,
-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations,
0});
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
if (warp_mma_k == 0) {
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
warp_mma(accum,
warp_frag_A[warp_mma_k % 2],
warp_frag_B[warp_mma_k % 2],
accum);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
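// Plain C++ sketch (illustrative only) of the two-buffer ping-pong that
// CustomMmaPipelined falls back to when cp.async is unavailable: the next tile
// is fetched while the current one is consumed, and the write slot simply
// toggles every iteration (the `smem_write_stage_idx ^= 1` above). The tile
// type and `pipelined_sum` are assumed for the sketch.
#include <array>
#include <cstddef>
#include <vector>

float pipelined_sum(const std::vector<std::array<float, 4>>& tiles) {
  if (tiles.empty()) return 0.0f;
  std::array<float, 4> staged = tiles[0];  // prologue: the first tile is loaded up front
  float acc = 0.0f;
  for (std::size_t t = 0; t < tiles.size(); ++t) {
    std::array<float, 4> next{};
    if (t + 1 < tiles.size()) next = tiles[t + 1];  // fetch the next tile early
    for (float v : staged) acc += v;                // consume the current tile
    staged = next;                                  // toggle buffers
  }
  return acc;
}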
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/*! \file
\brief Cutlass provides helper template functions to figure out the right
data structures to instantiate to run a GEMM with various parameters (see
`cutlass/gemm/threadblock/default_mma.h`). However, due to template
instantiation priority rules, it will only create an MmaMultistage with
kStages=3 (otherwise it creates an MmaPipelined - which is not compatible with
FastF32). kStages=3 uses too much shared memory and we want to use kStages=2,
so we just copy-pasted some code from the `default_mma.h` and
`default_mma_core.h` files and wrapped this template to allow our use case.
This is really only for the FastF32 case - i.e. using TensorCores with fp32.
*/
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
template <
/// Element type for A matrix operand
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Layout type for C and D matrix operand
typename LayoutC,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Number of stages used in the pipelined mainloop
int Stages,
    /// Operation performed by GEMM
typename Operator,
typename Enable_ = void>
struct FindDefaultMma {
static constexpr bool AccumulatorsInRowMajor = false;
static constexpr SharedMemoryClearOption SharedMemoryClear =
SharedMemoryClearOption::kNone;
using DefaultMma =
cutlass::gemm::threadblock::DefaultMma<ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementAccumulator,
LayoutC,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
Stages,
Operator,
AccumulatorsInRowMajor,
SharedMemoryClear>;
};
/// Specialization for sm80 / FastF32 / multistage with kStages=2
template <typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
int kStages,
typename Operator>
struct FindDefaultMma<
ElementA_,
LayoutA_,
kAlignmentA,
ElementB_,
LayoutB_,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> {
using LayoutC = layout::RowMajor;
using OperatorClass = arch::OpClassTensorOp;
using ArchTag = arch::Sm80;
using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma<ElementA_,
LayoutA_,
kAlignmentA,
ElementB_,
LayoutB_,
kAlignmentB,
ElementAccumulator,
LayoutC,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
3,
Operator>;
struct DefaultMma : DefaultMma_ {
using MmaCore_ = typename DefaultMma_::MmaCore;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage<
typename MmaCore_::Shape,
typename DefaultMma_::IteratorA,
typename MmaCore_::SmemIteratorA,
MmaCore_::kCacheOpA,
typename DefaultMma_::IteratorB,
typename MmaCore_::SmemIteratorB,
MmaCore_::kCacheOpB,
ElementAccumulator,
LayoutC,
typename MmaCore_::MmaPolicy,
kStages>;
};
};
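// Illustrative use only (hypothetical template arguments, not taken from a
// caller in this file): the wrapper is consumed through its nested DefaultMma,
// whose ThreadblockMma is the threadblock-scoped MMA to instantiate, e.g.
//
//   using ThreadblockMma = typename FindDefaultMma<
//       float, cutlass::layout::RowMajor, 4,
//       float, cutlass::layout::ColumnMajor, 4,
//       float, cutlass::layout::RowMajor,
//       cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
//       cutlass::gemm::GemmShape<128, 128, 32>,
//       cutlass::gemm::GemmShape<64, 64, 32>,
//       cutlass::gemm::GemmShape<16, 8, 8>,
//       2, cutlass::arch::OpMultiplyAddFastF32>::DefaultMma::ThreadblockMma;
//
// With kStages=2 and fp32 TensorCores this picks the specialization above and
// yields a two-stage MmaMultistage.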
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include "cutlass/functional.h"
#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
#include "cutlass/matrix_shape.h"
/*
TensorCores have different accumulator layouts.
  This file provides a class to easily map the accumulator's
  i-th element to the corresponding matrix row/col.
*/
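// Illustrative use only (hypothetical lambdas; a warp size of 32 is assumed):
// given the warp-level accumulator tile iterator type `T` of an MMA, the
// matching iterator maps every element of the accumulator fragment to its
// logical (row, column) coordinate:
//
//   using Iter =
//       typename DefaultMmaAccumLambdaIterator<T, float, 32>::Iterator;
//   auto lane_offset = Iter::get_lane_offset(lane_id, warp_id, tile_offset);
//   Iter::iterateRows(
//       lane_offset,
//       [&](int accum_m) { /* called once before each row */ },
//       [&](int accum_m, int accum_n, int idx) { /* frag[idx] is (m, n) */ },
//       [&](int accum_m) { /* called once after each row */ });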
template <typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSm80 {
static_assert(cutlass::platform::is_same<typename T::Layout,
cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
using Policy = typename T::Policy;
using InstructionShape = typename T::InstructionShape;
using OpDelta = typename T::OpDelta;
using Shape = typename T::Shape;
static int const kElementsPerAccess = InstructionShape::kN / 4;
static int const kRowsPerTile = 8;
static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile;
static cutlass::MatrixCoord CUTLASS_DEVICE
get_lane_offset(int8_t lane_id,
int8_t warp_id,
typename T::TensorCoord const& tile_offset) {
int quad = (lane_id >> 2);
int lane_in_quad = (lane_id & 3);
return cutlass::MatrixCoord(quad + tile_offset.row() * Shape::kRow,
lane_in_quad * kElementsPerAccess +
tile_offset.column() * Shape::kColumn);
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(
cutlass::MatrixCoord& lane_offset, // NOLINT
FA beginRow,
FB op,
FC endRow) {
// See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < kAccumulatorRows; ++row) {
int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow +
row * kRowsPerTile + lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
int mma_accum_start = kAccumulatorRows * kElementsPerAccess *
(mma_n * Policy::MmaIterations::kRow + mma_m);
CUTLASS_PRAGMA_UNROLL
for (int col = 0; col < kElementsPerAccess; ++col) {
int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn +
col + lane_offset.column();
int idx = mma_accum_start + row * kElementsPerAccess + col;
op(accum_m, accum_n, idx);
}
}
endRow(accum_m);
}
}
}
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id,
DT& myValue, // NOLINT
F fn) { // NOLINT
// In each warp, 4 threads will work on the same row
// - the ones with the same `quad`
auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1);
myValue = fn(myValue, otherV);
otherV = __shfl_xor_sync(0xffffffff, myValue, 2);
myValue = fn(myValue, otherV);
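    // After the xor-1 and xor-2 exchanges every lane of the quad holds the
    // fully reduced value; only the first lane of each quad reports ownership.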
int lane_in_quad = (lane_id & 3);
return lane_in_quad == 0;
}
};
template <typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSm70 {
static_assert(cutlass::platform::is_same<typename T::Layout,
cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
using Policy = typename T::Policy;
using InstructionShape = typename T::InstructionShape;
using OpDelta = typename T::OpDelta;
using Shape = typename T::Shape;
using Element = accum_t;
static int const kElementsPerPartial = 4;
using EleShapePerPatial = typename cutlass::platform::conditional<
cutlass::platform::is_same<Element, float>::value,
cutlass::MatrixShape<2, 2>,
cutlass::MatrixShape<1, 4>>::type;
static int const kElementsPerMma = 8;
static int const kAccumulatorPatials = 2;
using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>;
static cutlass::MatrixCoord CUTLASS_DEVICE
get_lane_offset(int8_t lane_id,
int8_t warp_id,
typename T::TensorCoord const& tile_offset) {
int quad = (lane_id >> 2);
int lane_in_quad = (lane_id & 3);
int accum_m, accum_n;
if (cutlass::platform::is_same<Element, float>::value) {
// (quad[2],quad[0])+lane_in_quad[0]
accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
// (quad[1])+lane_in_quad[1]
accum_n =
((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials +
(lane_in_quad & 2);
} else {
accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 +
lane_in_quad; // (quad[2],quad[0])
accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials;
}
return cutlass::MatrixCoord(
accum_m + tile_offset.row() * Shape::kRow,
accum_n + tile_offset.column() * Shape::kColumn);
}
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id,
DT& myValue, // NOLINT
F fn) { // NOLINT
static_assert(cutlass::platform::is_same<Element, float>::value,
"update to support non-float accum");
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
// T0 & T2 share same line within a quad
auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1);
myValue = fn(myValue, otherV);
// quad 0 and quad 2 are on the same lines
otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3);
myValue = fn(myValue, otherV);
return (lane_id & ((1 << 1) | (1 << 3))) == 0;
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(
cutlass::MatrixCoord& lane_offset, // NOLINT
FA beginRow,
FB op,
FC endRow) {
CUTLASS_PRAGMA_UNROLL
for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) {
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < EleShapePerPatial::kRow; ++m) {
int accum_m = tile_m * Policy::InterleavedTile::kRow +
mma_m * QuadShapePerPatialMma::kRow + m * 2 +
lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn;
++tile_n) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn;
++mma_n) {
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < kAccumulatorPatials; ++p) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < EleShapePerPatial::kColumn; ++n) {
int mma_accum_start =
(((tile_n * Policy::TileIterations::kRow + tile_m) *
Policy::MmaIterations::kColumn +
mma_n) *
Policy::MmaIterations::kRow +
mma_m) *
kElementsPerMma;
int accum_n = tile_n * Policy::InterleavedTile::kColumn +
mma_n * QuadShapePerPatialMma::kColumn +
p * Policy::InterleavedTile::kColumn / 2 + n +
lane_offset.column();
int idx = mma_accum_start + p * kElementsPerPartial +
m * EleShapePerPatial::kColumn + n;
op(accum_m, accum_n, idx);
}
}
}
}
endRow(accum_m);
}
}
}
}
};
template <typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSimt {
using Policy = typename T::Policy;
using Iterations = typename T::Iterations;
using Element = typename T::Element;
using Delta = typename T::Delta;
using Shape = typename T::Shape;
static_assert(cutlass::platform::is_same<typename T::Layout,
cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id,
DT& myValue, // NOLINT
F fn) { // NOLINT
CUTLASS_PRAGMA_UNROLL
for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) {
auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit);
myValue = fn(myValue, otherV);
}
return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0;
}
template <typename FA, typename FB, typename FC>
CUTLASS_DEVICE static void iterateRows(
cutlass::MatrixCoord& lane_offset, // NOLINT
FA beginRow,
FB op,
FC endRow) {
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
int accum_m = mma_m * Delta::kRow + m + lane_offset.row();
beginRow(accum_m);
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
int accum_n =
mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN +
lane_offset.column();
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
int idx =
n + Policy::LaneMmaShape::kN *
(mma_n + Iterations::kColumn *
(m + mma_m * Policy::LaneMmaShape::kM));
op(accum_m, accum_n + n, idx);
}
}
endRow(accum_m);
}
}
}
static cutlass::MatrixCoord CUTLASS_DEVICE
get_lane_offset(int8_t lane_id,
int8_t warp_id,
typename T::TensorCoord const& tile_offset) {
static_assert(cutlass::platform::is_same<
typename Policy::LaneLayout,
cutlass::layout::RowMajorInterleaved<1>>::value,
"");
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
cutlass::MatrixCoord lane_offset =
lane_layout.inverse(lane_id) *
cutlass::MatrixCoord(Policy::LaneMmaShape::kM,
Policy::LaneMmaShape::kN);
return lane_offset +
tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn);
}
};
template <typename T, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator;
// Simt
template <typename S, typename P, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::MmaSimtTileIterator<S,
cutlass::gemm::Operand::kC,
accum_t,
cutlass::layout::RowMajor,
P,
1,
1>,
accum_t,
kWarpSize> {
using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator<
S,
cutlass::gemm::Operand::kC,
accum_t,
cutlass::layout::RowMajor,
P,
1,
1>;
using Iterator = AccumLambdaIteratorSimt<WarpIterator, accum_t, kWarpSize>;
};
// TensorOp - Volta
template <typename S1, typename S2, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
cutlass::MatrixShape<1, 1>>,
accum_t,
kWarpSize> {
using WarpIterator =
typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
cutlass::MatrixShape<1, 1>>;
using Iterator = AccumLambdaIteratorSm70<WarpIterator, accum_t, kWarpSize>;
};
// TensorOp - Sm75+
template <typename S1,
typename S2,
typename S3,
typename accum_t,
int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
S3>,
accum_t,
kWarpSize> {
using WarpIterator =
typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
S1,
accum_t,
cutlass::layout::RowMajor,
S2,
S3>;
using Iterator = AccumLambdaIteratorSm80<WarpIterator, accum_t, kWarpSize>;
};
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/platform/platform.h"
#include "cutlass/transform/threadblock/vector_iterator.h"
#include "../epilogue/epilogue_thread_apply_logsumexp.h"
#include "../gemm/mma_accum_lambda_iterator.h"
#include "../gemm_kernel_utils.h"
#include "../iterators/make_residual_last.h"
#include "../iterators/transpose_warp_iterator.h"
#include "../iterators/warp_iterator_from_smem.h"
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/gemm/threadblock/mma_multistage.h"
#include "cutlass/gemm/threadblock/mma_pipelined.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
/// Shared storage object needed by accumulator
/// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h
template <typename Shape_,
typename Element_,
typename Layout_,
typename Padding_>
class AccumulatorSharedStorage {
public:
//
// Type definitions
//
using Shape = Shape_;
using Element = Element_;
using Layout = Layout_;
using Padding = Padding_;
/// Tensor reference to the accumulator
using TensorRefAccum = cutlass::TensorRef<Element, Layout>;
/// Shape of the accumulator matrix in shared memory
using ShapeAccum = cutlass::MatrixShape<Shape::kM + Padding::kRow,
Shape::kN + Padding::kColumn>;
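  // (Padding enlarges the shared-memory footprint of the tile; it is
  //  typically chosen to avoid bank conflicts when the accumulator is
  //  read back as operand A of the second GEMM.)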
public:
//
// Data members
//
/// Buffer for accumulator
cutlass::AlignedBuffer<Element, ShapeAccum::kCount> accum;
public:
//
// Methods
//
/// Returns a layout object for the Accum matrix
CUTLASS_DEVICE
static Layout LayoutAccum() {
return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn});
}
/// Returns a TensorRef to the Accumulator
CUTLASS_HOST_DEVICE
TensorRefAccum accum_ref() {
return TensorRefAccum{accum.data(), LayoutAccum()};
}
};
////////////////////////////////////////////////////////////////////////////////
// Taken from
// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h
////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
// Maximum value for K
int kMaxK,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class MmaBaseFromSharedMemory {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
using WarpCount1 = WarpCount;
  /// Number of warp-level GEMM operations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
static int const kWarpGemmIterations1 = kWarpGemmIterations;
/// Number of stages
static int const kStages = Stages;
/// If this is true, we fill the entire shmem buffer at start
/// and don't need to iterate through it in a circular fashion
static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages;
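  /// (for example, with Shape::kK == 32 and kStages == 3, any kMaxK <= 96
  /// fits entirely in the staged buffer)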
/// Tensor reference to the A operand
using TensorRefA =
TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
/// Tensor reference to the B operand
using TensorRefB =
TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
class SharedStorage {
public:
//
// Type definitions
//
/// Shape of the B matrix operand in shared memory
using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
public:
//
// Data members
//
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
public:
//
// Methods
//
/// Returns a layout object for the B matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutB LayoutB() {
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
}
/// Returns a TensorRef to the B operand
CUTLASS_HOST_DEVICE
TensorRefB operand_B_ref() {
return TensorRefB{operand_B.data(), LayoutB()};
}
};
protected:
//
// Data members
//
// /// Iterator to load a warp-scoped tile of A operand from shared memory
// typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
MmaBaseFromSharedMemory(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage& shared_storage, // NOLINT
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
};
namespace { // NOLINT
// has necessary trait compliance with WarpIteratorFromSmem but doesn't do
// anything, can be default initialized, and uses fragment that takes up
// (almost) no space. this warp iterator is selected at compile time when
// elementwise on-the-fly scaling for operand A is disabled, in which case
// operations related to loading scale factors for operand A get wiped out by
// the compiler.
template <typename TensorRef>
class NoOpWarpIteratorScale {
public:
// in pipelined+multistage MMA implementations we keep an array of fragments.
// if we aren't using scaling we don't want to waste registers on fragments
// of scale elements, so ideally this would be sized 0.
// using size 1 is kind of a hack to get around arrays of zero-sized objects
// not being allowed. the compiler is probably smart enough to wipe it out
// anyways.
using Fragment = cutlass::Array<char, 1>;
CUTLASS_HOST_DEVICE
NoOpWarpIteratorScale() {}
CUTLASS_HOST_DEVICE
NoOpWarpIteratorScale(TensorRef const&, int) {}
CUTLASS_HOST_DEVICE
NoOpWarpIteratorScale& add_tile_offset(
typename TensorRef::TensorCoord const&) {
return *this;
}
CUTLASS_HOST_DEVICE
NoOpWarpIteratorScale& operator++() { return *this; }
CUTLASS_DEVICE
void load(Fragment&) const {}
};
// if scaling is enabled, performs fragment elementwise multiplication between
// fragment and its scaling factor.
template <typename Fragment, typename FragmentScale, bool ScalingEnabled>
class FragmentElementwiseScaler;
// specialization for scaling being enabled.
template <typename Fragment, typename FragmentScale>
class FragmentElementwiseScaler<Fragment, FragmentScale, true> {
public:
// cast scale_frag to correct type then apply elementwise to fragment
CUTLASS_DEVICE
static Fragment apply(Fragment frag, FragmentScale const& scale_frag) {
Fragment converted_scale_frag =
cutlass::NumericArrayConverter<typename Fragment::Element,
typename FragmentScale::Element,
FragmentScale::kElements>()(scale_frag);
return cutlass::multiplies<Fragment>()(frag, converted_scale_frag);
}
};
// specialization for scaling being disabled. doesn't do anything and should
// just get wiped out by the compiler.
template <typename Fragment, typename FragmentScale>
class FragmentElementwiseScaler<Fragment, FragmentScale, false> {
public:
CUTLASS_DEVICE
static Fragment apply(Fragment frag, FragmentScale const&) { return frag; }
};
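// The mainloops below use this as
//   FragmentAScaler::apply(warp_frag_A[k % 2], warp_frag_A_scale[k % 2]);
// when ScalingEnabled is false the call simply returns `frag`, so the scale
// fragments and their loads should be elided by the compiler.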
} // namespace
////////////////////////////////////////////////////////////////////////////////
// Taken from
// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h
////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
// BEGIN smem
/// Iterates over the intermediate accumulator tile in shared memory
typename WarpIteratorA,
/// whether or not to perform elementwise multiplication of A
// by another matrix (A_scale) that is also kept in shared memory prior
// to matmul A @ B
bool ScaleOperandA_,
// Accumulator type
typename AccumulatorSharedStorage,
// END smem
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Data type of accumulator matrix
typename ElementC_,
    /// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Transformation applied to B operand
typename TransformB_ =
NumericArrayConverter<typename SmemIteratorB_::Element,
typename IteratorB_::Element,
IteratorB_::Fragment::kElements>,
/// Used for partial specialization
typename Enable = bool>
class MmaPipelinedFromSharedMemory
: public MmaBaseFromSharedMemory<Shape_,
AccumulatorSharedStorage::Shape::kN,
Policy_,
2> {
public:
///< Base class
using Base = MmaBaseFromSharedMemory<Shape_,
AccumulatorSharedStorage::Shape::kN,
Policy_,
2>;
using Shape =
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
static constexpr bool ScaleOperandA = ScaleOperandA_;
///< loads fragments of A_scale from shared memory if operand A scaling is
///< enabled. otherwise no-op.
using WarpIteratorAScale = typename cutlass::platform::conditional<
ScaleOperandA,
WarpIteratorA,
NoOpWarpIteratorScale<typename WarpIteratorA::TensorRef>>::type;
using IteratorB =
IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using SmemIteratorB = SmemIteratorB_;
using TransformB = TransformB_;
//
// Dependent types
//
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
  // statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
static_assert((Base::kStages == 2),
"MmaPipelined requires kStages set to value 2");
private:
using WarpFragmentA = typename Operator::FragmentA;
/// fragment type of OperandA elementwise scaling matrix. (almost) empty
/// if operand A scaling is disabled.
using WarpFragmentAScale = typename WarpIteratorAScale::Fragment;
using WarpFragmentB = typename Operator::FragmentB;
/// applies scaling factor to operand A fragment if operand A scaling is
/// enabled. otherwise no-op.
using FragmentAScaler = FragmentElementwiseScaler<WarpFragmentA,
WarpFragmentAScale,
ScaleOperandA>;
protected:
// /// Iterator to write threadblock-scoped tile of A operand to shared memory
// SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to load a warp-scoped tile of A operand from intermediate
/// accumulator tile
WarpIteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of A_scale from intermediate
/// accumulator tile (only used if ScaleOperandA_ is true)
WarpIteratorAScale warp_tile_iterator_A_scale_;
public:
/// constructor for MMA with operand A scaling enabled.
CUTLASS_DEVICE
MmaPipelinedFromSharedMemory(
// shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage, // NOLINT
// warp iterator over A tile held in shared memory
WarpIteratorA warp_iter_a,
// warp iterator over A_scale tile held in shared memory
WarpIteratorAScale warp_iter_a_scale,
int thread_idx,
int warp_idx,
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A_(warp_iter_a),
warp_tile_iterator_A_scale_(warp_iter_a_scale),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_A_scale_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
}
/// Construct from tensor references
CUTLASS_DEVICE
MmaPipelinedFromSharedMemory(
typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by
///< threadblock-scoped GEMM
AccumulatorSharedStorage& accumulator_shared_storage, // NOLINT
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx, ///< ID of each thread within a warp
int problem_size_0_n)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A_(accumulator_shared_storage.accum_ref(), lane_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
}
// For API compatibility with MmaMultistageFromSharedMemory
  // but not supported as it worsens perf: older GPUs (< sm80) don't
  // support async transfers and have to waste registers
CUTLASS_DEVICE
void set_prologue_done(bool value) {}
CUTLASS_DEVICE
static void prologue(typename Base::SharedStorage& shared_storage, // NOLINT
IteratorB iterator_B1,
int thread_idx,
int problem_size_0_n) {}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile //NOLINT
// IteratorA iterator_A, ///< iterator over A
// operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
FragmentC const& src_accum, ///< source accumulator tile
// TransformA transform_A = TransformA(), ///< transformation
// applied to A fragment
TransformB transform_B =
TransformB()) { ///< transformation applied to B fragment
//
// Prologue
//
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentB tb_frag_B;
tb_frag_B.clear();
    // The last kblock is loaded in the prologue
iterator_B.set_residual_tile(gemm_k_iterations == 1);
iterator_B.load(tb_frag_B);
++iterator_B;
this->smem_iterator_B_.store(transform_B(tb_frag_B));
++this->smem_iterator_B_;
__syncthreads();
// remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op
// if scaling is disabled.
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentAScale warp_frag_A_scale[2];
WarpFragmentB warp_frag_B[2];
warp_frag_A[0].clear();
warp_frag_A_scale[0].clear();
warp_frag_B[0].clear();
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_A_scale_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
int smem_write_stage_idx = 1;
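    // kStages == 2, so the shared-memory tile of B is double-buffered:
    // the B tile for iteration k+1 is staged into one half while the warps
    // consume the tile already written for iteration k.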
// Avoid reading out of bounds
iterator_B.set_residual_tile(gemm_k_iterations == 2);
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER*
    // issuing shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations) {
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
bool hasNext = true;
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
// Write fragments to shared memory
this->smem_iterator_B_.store(transform_B(tb_frag_B));
__syncthreads();
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the
        // circular buffer in shared memory. Note: iterator A is not reset,
        // as we are continuing our iteration at this point.
if (smem_write_stage_idx == 1) {
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
} else {
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations,
0});
}
smem_write_stage_idx ^= 1;
hasNext = gemm_k_iterations > 1;
}
// Only read the next if we need to
if (hasNext) {
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_A_scale_.load(
warp_frag_A_scale[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_A_scale_;
++this->warp_tile_iterator_B_;
if (warp_mma_k == 0) {
iterator_B.load(tb_frag_B);
++iterator_B;
// Avoid reading out of bounds if this was the last loop iteration
iterator_B.set_residual_tile(gemm_k_iterations == 3);
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
}
warp_mma(accum,
FragmentAScaler::apply(warp_frag_A[warp_mma_k % 2],
warp_frag_A_scale[warp_mma_k % 2]),
warp_frag_B[warp_mma_k % 2],
accum);
}
}
}
};
////////////////////////////////////////////////////////////////////////////////
// Taken from
// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h
////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape1_,
/// Iterates over the intermediate accumulator tile in shared memory
typename WarpIteratorA1_,
/// whether or not to perform elementwise multiplication of A
// by another matrix (A_scale) that is also kept in shared memory prior
// to matmul A @ B
bool ScaleOperandA_,
// Accumulator type
typename AccumulatorSharedStorage,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB1_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB1_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB1,
/// Data type of accumulator matrix
typename ElementC_,
    /// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy1_,
/// Number of stages,
int Stages_,
int kMaxK_,
/// Used for partial specialization
typename Enable = bool>
class MmaMultistageFromSharedMemory
: public MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_> {
public:
///< Base class
using Base = MmaBaseFromSharedMemory<Shape1_, kMaxK_, Policy1_, Stages_>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape1 = Shape1_;
///< Iterates over tiles of B operand in global memory
using IteratorB1 = IteratorB1_;
using IteratorB = IteratorB1;
///< Policy describing tuning details
using Policy1 = Policy1_;
using SmemIteratorB1 = SmemIteratorB1_;
using WarpIteratorA1 =
WarpIteratorA1_; ///< Iterates over the intermediate
///< accumulator tile in shared memory
static constexpr bool ScaleOperandA = ScaleOperandA_;
///< warp level iterator over A_scale matrix tile kept in shared memory.
///< if elementwise A scaling is disabled then everything this does is no-op.
using WarpIteratorAScale = typename cutlass::platform::conditional<
ScaleOperandA,
WarpIteratorA1,
NoOpWarpIteratorScale<typename WarpIteratorA1::TensorRef>>::type;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1;
static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB;
//
// Dependent types
//
/// Fragment of accumulator tile
using FragmentC1 = typename Policy1::Operator::FragmentC;
using FragmentC = FragmentC1;
/// Warp-level Mma
using Operator1 = typename Policy1::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
/// Complex transform on B operand
static ComplexTransform const kTransformB1 = Operator1::kTransformB;
/// Internal structure exposed for introspection.
struct Detail {
static_assert(Base::kWarpGemmIterations1 > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
/// Number of cp.async instructions to load one stage of operand B
static int const TBLDGSTSIterationsB1 =
IteratorB1::ThreadMap::Iterations::kCount;
    /// Number of cp.async instructions to load one group of operand B
static int const kAccessesPerGroupB1 =
(TBLDGSTSIterationsB1 + Base::kWarpGemmIterations1 - 1) /
Base::kWarpGemmIterations1;
};
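  // If the whole B operand fits in shared memory, every stage can be filled
  // in the prologue; otherwise one stage is left to be filled during the
  // mainloop so that cp.async copies overlap with the warp-level math.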
static constexpr int kNumStagesConcurrentLoad =
kSmemContainsEntireB ? Base::kStages : Base::kStages - 1;
private:
using WarpLoadedFragmentA1 = typename Operator1::FragmentA;
/// fragment of OperandA scale matrix. if operand A scaling is disabled this
/// is (almost) empty.
using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment;
using WarpLoadedFragmentB1 = typename Operator1::FragmentB;
using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA;
using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB;
/// applies elementwise scaling to fragment of A. if operand A scaling is
/// disabled this is a no-op.
using FragmentAScaler = FragmentElementwiseScaler<WarpLoadedFragmentA1,
WarpLoadedFragmentA1Scale,
ScaleOperandA>;
private:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A1 operand from intermediate
/// accumulator tile
WarpIteratorA1 warp_tile_iterator_A1_;
  /// Iterator to load a warp-scoped tile of A1_scale operand from shared
  /// memory; if operand A scaling is disabled, everything this does is a no-op.
WarpIteratorAScale warp_tile_iterator_A1_scale_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB1 smem_iterator_B1_;
bool prologue_done_;
public:
/// constructor for MMA with operand A scaling enabled.
CUTLASS_DEVICE
MmaMultistageFromSharedMemory(
// shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage, // NOLINT
// warp level iterator over operand A tile kept in shared memory
WarpIteratorA1 warp_tile_iterator_A1,
// warp level iterator over operand A elementwise scale tile kept in
// shared memory.
WarpIteratorAScale warp_tile_iterator_A1_scale,
int thread_idx,
int warp_idx,
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A1_(warp_tile_iterator_A1),
warp_tile_iterator_A1_scale_(warp_tile_iterator_A1_scale),
smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
prologue_done_(false) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn_1 =
warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
// Add per-warp offsets in units of warp-level tiles
warp_tile_iterator_A1_.add_tile_offset(
{warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
warp_tile_iterator_A1_scale_.add_tile_offset(
{warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});
}
/// Construct from tensor references
CUTLASS_DEVICE
MmaMultistageFromSharedMemory(
typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by
///< threadblock-scoped GEMM
AccumulatorSharedStorage& accumulator_shared_storage, // NOLINT
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx,
///< GEMM0 N is used for accumulator extent
int problem_size_0_n)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_tile_iterator_A1_(accumulator_shared_storage.accum_ref(),
lane_idx),
smem_iterator_B1_(shared_storage.operand_B_ref(), thread_idx),
prologue_done_(false) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn_1 =
warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN);
int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN);
int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM;
int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM;
// Add per-warp offsets in units of warp-level tiles
warp_tile_iterator_A1_.add_tile_offset(
{warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1});
}
CUTLASS_DEVICE
void set_prologue_done(bool value) { prologue_done_ = value; }
CUTLASS_DEVICE
static void prologue(typename Base::SharedStorage& shared_storage, // NOLINT
IteratorB iterator_B1,
int thread_idx,
int problem_size_0_n) {
SmemIteratorB1 smem_iterator_B1(shared_storage.operand_B_ref(), thread_idx);
_prologue(iterator_B1,
(problem_size_0_n + Base::Shape::kK - 1) / Base::Shape::kK,
smem_iterator_B1);
}
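  // A caller may issue the B global->shared copies ahead of time through the
  // static prologue() above and then call set_prologue_done(true); in that
  // case operator() skips re-issuing the copies and only advances the
  // iterators to restore their state.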
CUTLASS_DEVICE
void copy_tiles_and_advance_1(IteratorB1& iterator_B1, // NOLINT
int group_start_B1 = 0) {
iterator_B1.set_iteration_index(group_start_B1 *
IteratorB1::kAccessesPerVector);
this->smem_iterator_B1_.set_iteration_index(group_start_B1);
// LDGSTS for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) {
if (group_start_B1 + j < Detail::TBLDGSTSIterationsB1) {
typename IteratorB1::AccessType* dst_ptr =
reinterpret_cast<typename IteratorB1::AccessType*>(
this->smem_iterator_B1_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB1::Element>::value *
IteratorB1::ThreadMap::kElementsPerAccess /
IteratorB1::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B1.get();
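          // Asynchronous global->shared copy; lanes whose predicate
          // (iterator_B1.valid()) is false zero-fill the destination instead.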
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
dst_ptr + v, gmem_ptr, iterator_B1.valid());
++iterator_B1;
}
++this->smem_iterator_B1_;
}
}
}
CUTLASS_DEVICE
static void _prologue(IteratorB& iterator_B1, // NOLINT
int32_t gemm_k_iterations_1,
SmemIteratorB1& smem_iterator_B1_) { // NOLINT
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < kNumStagesConcurrentLoad;
++stage, --gemm_k_iterations_1) {
iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1);
iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
iterator_B1.set_iteration_index(0);
smem_iterator_B1_.set_iteration_index(0);
// LDGSTS for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::TBLDGSTSIterationsB1; ++j) {
typename IteratorB1::AccessType* dst_ptr =
reinterpret_cast<typename IteratorB1::AccessType*>(
smem_iterator_B1_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorB1::Element>::value *
IteratorB1::ThreadMap::kElementsPerAccess /
IteratorB1::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB1>(
dst_ptr + v, iterator_B1.get(), iterator_B1.valid());
++iterator_B1;
}
++smem_iterator_B1_;
}
// Move to the next stage
iterator_B1.add_tile_offset({1, 0});
smem_iterator_B1_.add_tile_offset({1, 0});
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1);
iterator_B1.clear_mask(gemm_k_iterations_1 == 0);
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations_1_,
///< destination accumulator tile
FragmentC1& accum, // NOLINT
///< iterator over B1 operand in global memory
IteratorB1 iterator_B1,
///< initial value of accumulator
FragmentC1 const& src_accum) {
// 2nd Gemm
//
// Prologue
//
// Perform accumulation in the 'd' output operand
accum = src_accum;
if (!prologue_done_) {
_prologue(iterator_B1, gemm_k_iterations_1_, smem_iterator_B1_);
} else if (!kSmemContainsEntireB) {
      // Restore the iterators' increments
int gemm_k_iterations_1 = gemm_k_iterations_1_;
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < kNumStagesConcurrentLoad;
++stage, --gemm_k_iterations_1) {
iterator_B1.set_iteration_index(0);
this->smem_iterator_B1_.set_iteration_index(0);
// LDGSTS for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::TBLDGSTSIterationsB1; ++j) {
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) {
++iterator_B1;
}
++this->smem_iterator_B1_;
}
iterator_B1.add_tile_offset({1, 0});
this->smem_iterator_B1_.add_tile_offset({1, 0});
}
iterator_B1.set_residual_tile(gemm_k_iterations_1 <= 1);
iterator_B1.clear_mask(gemm_k_iterations_1 <= 0);
}
// DEPBAR+SYNC
cutlass::arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
__syncthreads();
// remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty
// if scaling is disabled.
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpLoadedFragmentA1 warp_loaded_frag_A1[2];
WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2];
WarpLoadedFragmentB1 warp_loaded_frag_B1[2];
WarpTransformedFragmentA1 warp_transformed_frag_A1[2];
WarpTransformedFragmentB1 warp_transformed_frag_B1[2];
Operator1 warp_mma1;
warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]);
++warp_tile_iterator_A1_;
warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]);
++warp_tile_iterator_A1_scale_;
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]);
++this->warp_tile_iterator_B_;
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
warp_mma1.transform(warp_transformed_frag_A1[0],
warp_transformed_frag_B1[0],
FragmentAScaler::apply(warp_loaded_frag_A1[0],
warp_loaded_frag_A1_scale[0]),
warp_loaded_frag_B1[0]);
// tf32x3 kernels use staging accumulation. warp_mma uses a temporary
// accumulator and this temporary accumulator is added to the final
// accumulator once in every mainloop iteration.
plus<FragmentC1> plus_accum;
FragmentC1 tmp_accum;
if (platform::is_same<typename Operator1::MathOperator,
arch::OpMultiplyAddFastF32>::value ||
platform::is_same<typename Operator1::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
tmp_accum.clear();
}
//
// Mainloop
//
CUTLASS_PRAGMA_UNROLL
for (int gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1);
gemm_k_iterations_1 > (-Base::kStages + 1);
gemm_k_iterations_1--) {
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1;
++warp_mma_k) {
// Load warp-level tile from accumulator fragment (A)
// or shared memory (operand B)
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_mma_k + 1) % Base::kWarpGemmIterations1);
        // skip warp tile loading for the last kgroup (we are out of the buffer)
if (gemm_k_iterations_1 > (-Base::kStages + 2) ||
warp_mma_k < Base::kWarpGemmIterations1 - 1) {
warp_tile_iterator_A1_.load(
warp_loaded_frag_A1[(warp_mma_k + 1) % 2]);
warp_tile_iterator_A1_scale_.load(
warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_B_.load(
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
}
++warp_tile_iterator_A1_;
++warp_tile_iterator_A1_scale_;
++this->warp_tile_iterator_B_;
if (warp_mma_k > 0)
warp_mma1.transform(
warp_transformed_frag_A1[warp_mma_k % 2],
warp_transformed_frag_B1[warp_mma_k % 2],
FragmentAScaler::apply(warp_loaded_frag_A1[warp_mma_k % 2],
warp_loaded_frag_A1_scale[warp_mma_k % 2]),
warp_loaded_frag_B1[warp_mma_k % 2]);
if (platform::is_same<typename Operator1::MathOperator,
arch::OpMultiplyAddFastF32>::value ||
platform::is_same<typename Operator1::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
warp_mma1(tmp_accum,
warp_transformed_frag_A1[warp_mma_k % 2],
warp_transformed_frag_B1[warp_mma_k % 2],
tmp_accum);
if (warp_mma_k == 0) {
accum = plus_accum(accum, tmp_accum);
tmp_accum.clear();
}
} else {
warp_mma1(accum,
warp_transformed_frag_A1[warp_mma_k % 2],
warp_transformed_frag_B1[warp_mma_k % 2],
accum);
}
        // Issue global->shared copies for this stage
if (warp_mma_k < Base::kWarpGemmIterations1 - 1) {
int group_start_iteration_B1;
group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1;
if (!kSmemContainsEntireB) {
copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
}
}
if (warp_mma_k + 2 == Base::kWarpGemmIterations1) {
int group_start_iteration_B1;
group_start_iteration_B1 =
(warp_mma_k + 1) * Detail::kAccessesPerGroupB1;
if (!kSmemContainsEntireB) {
copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1);
}
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Waits until kStages-2 stages have committed.
arch::cp_async_wait<kNumStagesConcurrentLoad - 1>();
__syncthreads();
// Move to the next stage
iterator_B1.add_tile_offset({1, 0});
this->smem_iterator_B1_.add_tile_offset({1, 0});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (!kSmemContainsEntireB) {
if (smem_write_stage_idx == (Base::kStages - 1)) {
this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
} else {
++smem_write_stage_idx;
}
if (smem_read_stage_idx == (Base::kStages - 1)) {
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy1::kPartitionsK *
Base::kWarpGemmIterations1,
0});
smem_read_stage_idx = 0;
} else {
++smem_read_stage_idx;
}
}
iterator_B1.set_residual_tile(gemm_k_iterations_1 == 2);
iterator_B1.clear_mask(gemm_k_iterations_1 == 1);
}
// Do any conversions feeding the first stage at the end of the loop so
// we can start right away on mma instructions
if (warp_mma_k + 1 == Base::kWarpGemmIterations1)
warp_mma1.transform(
warp_transformed_frag_A1[(warp_mma_k + 1) % 2],
warp_transformed_frag_B1[(warp_mma_k + 1) % 2],
FragmentAScaler::apply(
warp_loaded_frag_A1[(warp_mma_k + 1) % 2],
warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]),
warp_loaded_frag_B1[(warp_mma_k + 1) % 2]);
}
}
if (platform::is_same<typename Operator1::MathOperator,
arch::OpMultiplyAddFastF32>::value ||
platform::is_same<typename Operator1::MathOperator,
arch::OpMultiplyAddComplexFastF32>::value) {
accum = plus_accum(accum, tmp_accum);
}
}
};
template <typename WarpShape,
typename InstructionShape,
typename RegularWarpIterator,
typename Policy,
typename Enable = void>
struct DefaultWarpIteratorAFromSharedMemory {};
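// The specializations below select the warp-level iterator over an operand A
// tile held in shared memory based on the MMA instruction shape: 16x8x8
// TensorOp (Ampere), 16x16x4 TensorOp (Volta) and 1x1x1 (Simt).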
// TensorOp - Ampere half
template <typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
RegularWarpIterator,
Policy,
typename platform::enable_if<(
sizeof_bits<typename RegularWarpIterator::Element>::value == 16 &&
Policy::Operator::Policy::OpDelta::kRow == 1)>::type> {
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpShape = cutlass::MatrixShape<32, 32>;
using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem<
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element>;
};
// TensorOp - Ampere f32
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<
WarpShape,
cutlass::gemm::GemmShape<16, 8, 8>,
RegularWarpIterator,
Policy,
typename platform::enable_if<(
sizeof_bits<typename RegularWarpIterator::Element>::value != 16 ||
Policy::Operator::Policy::OpDelta::kRow != 1)>::type> {
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpIterator =
cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator<
cutlass::MatrixShape<WarpShape::kM, WarpShape::kK>,
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::layout::RowMajor,
cutlass::MatrixShape<InstructionShape::kM, InstructionShape::kK>,
OpDelta::kRow,
kWarpSize>;
};
// TensorOp - Volta
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<WarpShape,
cutlass::gemm::GemmShape<16, 16, 4>,
RegularWarpIterator,
Policy> {
using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>;
static constexpr auto kWarpSize = 32;
using OpDelta = typename Policy::Operator::Policy::OpDelta;
using WarpIterator =
cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator<
cutlass::MatrixShape<32, 32>, // MatrixShape<WarpShape::kM,
// WarpShape::kK>,
cutlass::gemm::Operand::kA,
typename RegularWarpIterator::Element,
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>,
cutlass::MatrixShape<16, 4>,
OpDelta::kRow,
kWarpSize>;
};
// Simt
template <typename WarpShape, typename RegularWarpIterator, typename Policy>
struct DefaultWarpIteratorAFromSharedMemory<WarpShape,
cutlass::gemm::GemmShape<1, 1, 1>,
RegularWarpIterator,
Policy> {
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
static constexpr auto kWarpSize = 32;
// We just use the same iterator, as we reproduced the same shared-memory
// schema. Just modify it to handle non-complete tiles.
using WarpIterator = RegularWarpIterator;
};
// Converts a "regular" Mma into their counterpart from shared memory
template <typename Mma_,
typename AccumulatorSharedStorage,
/// whether or not to apply elementwise multiplication of operand A by
/// another matrix in shared memory before usage in A @ B
bool kScaleOperandA,
bool kTransposeA = false>
struct DefaultMmaFromSharedMemory;
// Mma pipelined
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Transformation applied to A operand
typename TransformA_,
/// Transformation applied to B operand
typename TransformB_,
typename AccumulatorSharedStorage_,
/// whether or not to apply elementwise multiplication of operand A by
/// another matrix in shared memory before usage in A @ B
bool kScaleOperandA,
bool kTransposeA>
struct DefaultMmaFromSharedMemory<MmaPipelined<Shape_,
IteratorA_,
SmemIteratorA_,
IteratorB_,
SmemIteratorB_,
ElementC_,
LayoutC_,
Policy_,
TransformA_,
TransformB_>,
AccumulatorSharedStorage_,
kScaleOperandA,
kTransposeA> {
static constexpr int kWarpSize = 32;
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
using RegularMma = MmaPipelined<Shape_,
IteratorA_,
SmemIteratorA_,
IteratorB_,
SmemIteratorB_,
ElementC_,
LayoutC_,
Policy_,
TransformA_,
TransformB_>;
using WarpShape = typename Policy_::Operator::Shape;
using InstructionShape = typename Policy_::Operator::InstructionShape;
using ArchMmaOperator = typename Policy_::Operator;
static constexpr bool kIsTransposedA = false;
using WarpIteratorA = typename DefaultWarpIteratorAFromSharedMemory<
WarpShape,
InstructionShape,
typename RegularMma::Operator::IteratorA,
Policy_>::WarpIterator;
using IteratorB =
typename cutlass::transform::threadblock::MakeIteratorResidualLast<
IteratorB_>::Iterator;
using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory<
Shape_,
WarpIteratorA,
kScaleOperandA,
AccumulatorSharedStorage_,
IteratorB,
SmemIteratorB_,
ElementC_,
LayoutC_,
Policy_>;
};
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Data type of accumulator matrix
typename ElementC_,
/// Data type of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
typename AccumulatorSharedStorage_,
/// whether or not to apply elementwise multiplication of operand A by
/// another matrix in shared memory before usage in A @ B
bool kScaleOperandA,
bool kTransposeA>
struct DefaultMmaFromSharedMemory<MmaMultistage<Shape_,
IteratorA_,
SmemIteratorA_,
CacheOpA,
IteratorB_,
SmemIteratorB_,
CacheOpB,
ElementC_,
LayoutC_,
Policy_,
Stages,
SharedMemoryClear>,
AccumulatorSharedStorage_,
kScaleOperandA,
kTransposeA> {
static constexpr int kWarpSize = 32;
using RegularMma = MmaMultistage<Shape_,
IteratorA_,
SmemIteratorA_,
CacheOpA,
IteratorB_,
SmemIteratorB_,
CacheOpB,
ElementC_,
LayoutC_,
Policy_,
Stages,
SharedMemoryClear>;
using WarpShape = typename Policy_::Operator::Shape;
using InstructionShape = typename Policy_::Operator::InstructionShape;
using WarpIteratorA_ = typename DefaultWarpIteratorAFromSharedMemory<
WarpShape,
InstructionShape,
typename RegularMma::Operator::IteratorA,
Policy_>::WarpIterator;
using WarpIteratorTranspose = TransposeWarpIterator<WarpIteratorA_>;
static constexpr bool kIsTransposedA =
WarpIteratorTranspose::kSupportsTranspose && kTransposeA;
using WarpIteratorA =
typename platform::conditional<kIsTransposedA,
typename WarpIteratorTranspose::Iterator,
WarpIteratorA_>::type;
static int constexpr kMaxK = kIsTransposedA
? AccumulatorSharedStorage_::Shape::kM
: AccumulatorSharedStorage_::Shape::kN;
// Reduce the number of stages if we don't need that many
static int constexpr kStagesMax =
(kMaxK + static_cast<int>(Shape_::kK) - 1) / static_cast<int>(Shape_::kK);
static int constexpr kStages = cutlass::const_min(Stages, kStagesMax);
using IteratorB =
typename cutlass::transform::threadblock::MakeIteratorResidualLast<
IteratorB_>::Iterator;
using Mma =
typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory<
Shape_,
WarpIteratorA,
kScaleOperandA,
AccumulatorSharedStorage_,
IteratorB,
SmemIteratorB_,
RegularMma::kCacheOpB,
ElementC_,
LayoutC_,
Policy_,
kStages,
kMaxK>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename IteratorC,
typename Operator,
typename scalar_t,
typename WarpShape_,
typename ThreadblockShape_>
struct B2bGemm;
// Tensor Cores >= Sm75 specialization (Ampere ...)
template < /// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Element type
typename Element_,
/// Layout of operand in memory
typename Layout_,
/// Shape of one matrix product operation (concept: MatrixShape)
typename InstructionShape_,
/// Interval between adjacent *MMA instructions (in units of MMA
/// instructions, concept: MatrixShape)
typename OpDelta_,
typename Operator,
typename scalar_t,
typename WarpShape_,
typename ThreadblockShape_>
struct B2bGemm<
cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<Shape_,
Element_,
Layout_,
InstructionShape_,
OpDelta_>,
Operator,
scalar_t,
WarpShape_,
ThreadblockShape_> {
using IteratorC =
typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<
Shape_,
Element_,
Layout_,
InstructionShape_,
OpDelta_>;
using FragmentC = typename IteratorC::Fragment;
using InstructionShape = InstructionShape_;
using WarpShape = WarpShape_;
using ThreadblockShape = ThreadblockShape_;
using accum_t = Element_;
using lse_scalar_t = float;
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
// Iterator to load accumulators (results of matmul in registers)
using FragmentIteratorAccumulator =
cutlass::epilogue::warp::FragmentIteratorTensorOp<
WarpShape,
InstructionShape,
accum_t,
typename Operator::Policy::Operator::FragmentC,
cutlass::layout::RowMajor>;
// Iterator to store to shared-memory
using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp<
WarpShape,
InstructionShape,
scalar_t, // accum_t,
SmemAccumulatorLayout>;
using AccumulatorSharedStorage =
cutlass::gemm::threadblock::AccumulatorSharedStorage<
ThreadblockShape,
typename SmemIteratorD0::Element,
typename SmemIteratorD0::TensorLayout,
typename SmemIteratorD0::Padding>;
// We need to provide an operation for the epilogue. Let's create an
// operation that does nothing (ScaleType::Nothing), just converts
// from accum_t (float) -> scalar_t (can be half)
using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination<
typename SmemIteratorD0::Element, // ElementOutput
FragmentIteratorAccumulator::Fragment::kElements,
accum_t, // ElementAccumulator
typename SmemIteratorD0::Element, // ElementCompute
cutlass::epilogue::thread::ScaleType::Nothing>;
using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator<
SmemIteratorD0,
FragmentIteratorAccumulator,
SmemIteratorD0, // ScaleBiasIterator - not used
OutputOpNoOp>;
// Epilogue 2: with LSE (for backwards pass)
static int const kElementsPerAccess = 2; // TODO(xformers): Why 2?
using IteratorAccumulatorLSE =
cutlass::transform::threadblock::VectorIterator<
cutlass::transform::threadblock::PredicatedVectorAccessIterator<
// Shape
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kN>,
// WarpShape
cutlass::MatrixShape<WarpShape::kM, WarpShape::kN>,
lse_scalar_t,
cutlass::layout::RowMajor,
kElementsPerAccess>>;
using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp<
scalar_t, // ElementOutput_
lse_scalar_t, // ElementLSE_
accum_t, // ElementAccumulator_
accum_t, // ElementCompute_
128 / cutlass::sizeof_bits<scalar_t>::value
// FragmentIteratorAccumulator::Fragment::kElements
// InstructionShape::kM * InstructionShape::kN / 32
>;
using EpilogueWithLSE =
cutlass::epilogue::threadblock::EpilogueSmemAccumulator<
SmemIteratorD0,
FragmentIteratorAccumulator,
IteratorAccumulatorLSE,
EpilogueOpApplyLSE>;
static void CUTLASS_DEVICE
accumToSmem(AccumulatorSharedStorage& shared_storage, // NOLINT
FragmentC const& accum, // NOLINT
int lane_id,
cutlass::MatrixCoord const& tile_coords) {
SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id);
smem_iterator_attn.add_tile_offset(
tile_coords *
cutlass::MatrixCoord{SmemIteratorD0::TileIterations::kRow,
SmemIteratorD0::TileIterations::kColumn});
Epilogue epilogue;
epilogue(OutputOpNoOp({}), smem_iterator_attn, accum);
}
static void CUTLASS_DEVICE
accumApplyLSEToSmem(AccumulatorSharedStorage& shared_storage, // NOLINT
FragmentC& accum, // NOLINT
lse_scalar_t const* lse,
int32_t lse_extents,
int thread_id,
int warp_id,
int lane_id,
cutlass::MatrixCoord const& tile_coords) {
constexpr int32_t kAlignLSE = 32;
IteratorAccumulatorLSE iterator_lse(
lse,
{(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE},
thread_id,
warp_id,
cutlass::MatrixCoord{0, 0} // offset
);
SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id);
smem_iterator_attn.add_tile_offset(
tile_coords *
cutlass::MatrixCoord{SmemIteratorD0::TileIterations::kRow,
SmemIteratorD0::TileIterations::kColumn});
EpilogueWithLSE epilogue;
EpilogueOpApplyLSE minus_lse_exp({});
epilogue(minus_lse_exp,
smem_iterator_attn,
accum,
// scale - unused
iterator_lse,
// bias
iterator_lse);
}
};
// Volta Specialization
// only supported for f16
template <typename Operator, typename WarpShape_, typename ThreadblockShape_>
struct B2bGemm<cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
cutlass::MatrixShape<32, 32>,
float,
cutlass::layout::RowMajor,
cutlass::gemm::GemmShape<16, 16, 4>,
cutlass::MatrixShape<1, 1>>,
Operator,
cutlass::half_t,
WarpShape_,
ThreadblockShape_> {
using IteratorC =
cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<
cutlass::MatrixShape<32, 32>,
float,
cutlass::layout::RowMajor,
cutlass::gemm::GemmShape<16, 16, 4>,
cutlass::MatrixShape<1, 1>>;
using scalar_t = cutlass::half_t;
using accum_t = IteratorC::Element;
using WarpShape = WarpShape_;
using ThreadblockShape = ThreadblockShape_;
using FragmentC = IteratorC::Fragment;
using lse_scalar_t = float;
using SmemAccumulatorLayout = cutlass::layout::RowMajor;
using SmemIteratorD0 = cutlass::epilogue::warp::TileIteratorVoltaTensorOp<
WarpShape,
cutlass::gemm::GemmShape<32, 32, 4>,
scalar_t,
SmemAccumulatorLayout>;
  // Storage in shared-memory for Q.Kt
using AccumulatorSharedStorage =
cutlass::gemm::threadblock::AccumulatorSharedStorage<
ThreadblockShape,
scalar_t,
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<
16,
32>, // typename SmemIteratorD0::TensorLayout,
cutlass::MatrixShape<0, 0> // Padding
>;
using OutputLayout =
cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>;
using TensorRef = cutlass::TensorRef<scalar_t, OutputLayout>;
using Policy = typename IteratorC::Policy;
using Element = accum_t;
  // These are MmaVoltaTensorOpAccumulatorTileIterator private fields;
  // copy their values here.
static int const kElementsPerPartial = 4;
using EleShapePerPatial = typename cutlass::platform::conditional<
cutlass::platform::is_same<Element, float>::value,
cutlass::MatrixShape<2, 2>,
cutlass::MatrixShape<1, 4>>::type;
static int const kElementsPerMma = 8;
static int const kAccumulatorPatials = 2;
using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>;
static void CUTLASS_DEVICE
accumToSmem(AccumulatorSharedStorage& shared_storage, // NOLINT
FragmentC const& accum,
int lane_id,
cutlass::MatrixCoord const& tile_coords) {
// ctor - from MmaVoltaTensorOpAccumulatorTileIterator
TensorRef ref_(shared_storage.accum_ref());
int quad = (lane_id >> 2);
int lane_in_quad = (lane_id & 3);
int accum_m, accum_n;
if (cutlass::platform::is_same<Element, float>::value) {
// (quad[2],quad[0])+lane_in_quad[0]
accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
// (quad[1])+lane_in_quad[1]
accum_n =
((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials +
(lane_in_quad & 2);
} else {
accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 +
lane_in_quad; // (quad[2],quad[0])
accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials;
}
cutlass::MatrixCoord lane_offset(accum_m, accum_n);
// Tile offset
ref_.add_coord_offset(tile_coords *
cutlass::MatrixCoord({IteratorC::Shape::kRow,
IteratorC::Shape::kColumn}));
using AccessType = cutlass::Array<scalar_t, EleShapePerPatial::kColumn>;
// store - from MmaVoltaTensorOpAccumulatorTileIterator
CUTLASS_PRAGMA_UNROLL
for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) {
CUTLASS_PRAGMA_UNROLL
for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
int mma_accum_start =
(((tile_n * Policy::TileIterations::kRow + tile_m) *
Policy::MmaIterations::kColumn +
mma_n) *
Policy::MmaIterations::kRow +
mma_m) *
kElementsPerMma;
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < kAccumulatorPatials; ++p) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < EleShapePerPatial::kRow; ++m) {
int accum_m = tile_m * Policy::InterleavedTile::kRow +
mma_m * QuadShapePerPatialMma::kRow + m * 2;
int accum_n = tile_n * Policy::InterleavedTile::kColumn +
mma_n * QuadShapePerPatialMma::kColumn +
p * Policy::InterleavedTile::kColumn / 2;
int r = (accum_m + lane_offset.row());
AccessType to_store;
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < EleShapePerPatial::kColumn; ++n) {
int idx = mma_accum_start + p * kElementsPerPartial +
m * EleShapePerPatial::kColumn + n;
int c = (accum_n + n + lane_offset.column());
to_store[n] = scalar_t(accum[idx]);
}
int c = (accum_n + lane_offset.column());
assert(r < 32);
assert(c < 32);
*reinterpret_cast<AccessType*>(ref_.data() +
ref_.offset({r, c})) = to_store;
}
}
}
}
}
}
}
static void CUTLASS_DEVICE
accumApplyLSEToSmem(AccumulatorSharedStorage& shared_storage, // NOLINT
typename IteratorC::Fragment& accum, // NOLINT
lse_scalar_t const* lse,
int lse_extent,
int thread_id,
int warp_id,
int lane_id,
cutlass::MatrixCoord const& tile_coords) {
// Non-optimized way to apply LSE to registers
// NOTE: accum is attn.T
// TODO(xformers): Optimize for each architecture
static constexpr int WarpSize = 32;
using AccumLambdaIterator =
typename DefaultMmaAccumLambdaIterator<IteratorC, accum_t, WarpSize>::
Iterator;
auto lane_offset =
AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords);
cutlass::Array<lse_scalar_t, IteratorC::Fragment::kElements> lse_prefetched;
lse_prefetched.clear();
int rowIdx = 0;
int colIdx = 0;
AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
++rowIdx;
colIdx = 0;
},
[&](int accum_m, int accum_n, int idx) {
if (rowIdx == 1) {
lse_prefetched[colIdx] =
accum_n < lse_extent
? lse[accum_n]
: platform::numeric_limits<accum_t>::infinity();
}
accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
++colIdx;
},
[&](int accum_m) {});
accumToSmem(shared_storage, accum, lane_id, tile_coords);
}
};
// Simt Specialization
// for f32 on Sm70-Sm75 and f16/f32 below
template <typename Operator,
typename OperatorPolicy,
typename scalar_t,
typename WarpShape_,
typename ThreadblockShape_>
struct B2bGemm<
cutlass::gemm::warp::MmaSimtTileIterator<cutlass::MatrixShape<32, 32>,
cutlass::gemm::Operand::kC,
float,
cutlass::layout::RowMajor,
OperatorPolicy,
1,
1>,
Operator,
scalar_t,
WarpShape_,
ThreadblockShape_> {
using IteratorC =
cutlass::gemm::warp::MmaSimtTileIterator<cutlass::MatrixShape<32, 32>,
cutlass::gemm::Operand::kC,
float,
cutlass::layout::RowMajor,
OperatorPolicy,
1,
1>;
using accum_t = typename IteratorC::Element;
using WarpShape = WarpShape_;
using ThreadblockShape = ThreadblockShape_;
using FragmentC = typename IteratorC::Fragment;
using lse_scalar_t = float;
// Storage in shared-memory for Q.Kt
using AccumulatorSharedStorage =
cutlass::gemm::threadblock::AccumulatorSharedStorage<
ThreadblockShape,
scalar_t,
cutlass::layout::ColumnMajor,
cutlass::MatrixShape<0, 0> // Padding
>;
static void CUTLASS_DEVICE
accumToSmem(AccumulatorSharedStorage& shared_storage, // NOLINT
FragmentC const& accum, // NOLINT
int lane_id,
cutlass::MatrixCoord const& tile_coords) {
using Policy = typename IteratorC::Policy;
using Element = typename IteratorC::Element;
using Iterations = typename IteratorC::Iterations;
using Delta = typename IteratorC::Delta;
auto ref_ = shared_storage.accum_ref();
// ctor - MmaSimtTileIterator
// compute offset based on thread ID and lane layout
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
MatrixCoord lane_offset =
lane_layout.inverse(lane_id) *
MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);
ref_.add_coord_offset(lane_offset);
// Tile offset
ref_.add_coord_offset(tile_coords *
cutlass::MatrixCoord({IteratorC::Shape::kRow,
IteratorC::Shape::kColumn}));
// store - MmaSimtTileIterator
CUTLASS_PRAGMA_UNROLL
for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
CUTLASS_PRAGMA_UNROLL
for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
int r =
Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) +
m;
int c = mma_n * Delta::kColumn + n;
int idx =
n + Policy::LaneMmaShape::kN *
(mma_n + Iterations::kColumn *
(m + mma_m * Policy::LaneMmaShape::kM));
ref_.at({r, c}) = scalar_t(accum[idx]);
}
}
}
}
}
static void CUTLASS_DEVICE
accumApplyLSEToSmem(AccumulatorSharedStorage& shared_storage, // NOLINT
typename IteratorC::Fragment& accum, // NOLINT
lse_scalar_t const* lse,
int lse_extent,
int thread_id,
int warp_id,
int lane_id,
cutlass::MatrixCoord const& tile_coords) {
// Non-optimized way to apply LSE to registers
// NOTE: accum is attn.T
// TODO(xformers): Optimize for each architecture
static constexpr int WarpSize = 32;
using AccumLambdaIterator =
typename DefaultMmaAccumLambdaIterator<IteratorC, accum_t, WarpSize>::
Iterator;
auto lane_offset =
AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords);
cutlass::Array<lse_scalar_t, IteratorC::Fragment::kElements> lse_prefetched;
lse_prefetched.clear();
int rowIdx = 0;
int colIdx = 0;
AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
++rowIdx;
colIdx = 0;
},
[&](int accum_m, int accum_n, int idx) {
if (rowIdx == 1) {
lse_prefetched[colIdx] =
accum_n < lse_extent
? lse[accum_n]
: platform::numeric_limits<accum_t>::infinity();
}
accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
++colIdx;
},
[&](int accum_m) {});
accumToSmem(shared_storage, accum, lane_id, tile_coords);
}
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include "cutlass/arch/mma.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/phi/core/enforce.h"
////////////////////////////////////////////////////////////////////////////////
// Some helper functions
////////////////////////////////////////////////////////////////////////////////
#define DISPATCH_TYPES(tensor, func) \
{ \
if (query.scalar_type() == at::ScalarType::Float) { \
using scalar_t = float; \
func(); \
} else if (query.scalar_type() == at::ScalarType::Half) { \
using scalar_t = cutlass::half_t; \
func(); \
} else if (query.scalar_type() == at::ScalarType::BFloat16) { \
using scalar_t = cutlass::bfloat16_t; \
func(); \
} else { \
PADDLE_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \
} \
}
#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
{ \
if (BOOL_V) { \
constexpr bool BOOL_NAME = true; \
F(); \
} else { \
constexpr bool BOOL_NAME = false; \
F(); \
} \
}
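// Illustrative usage only (a sketch; `p.causal`, `kIsCausal` and `run_kernel`
// are hypothetical names): the macro exposes a constexpr bool to the functor,
// e.g.
//   DISPATCH_BOOL(p.causal, kIsCausal, ([&] { run_kernel<kIsCausal>(p); }));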
#define DISPATCH_ARCHTAG(CC, func) \
{ \
if (CC >= 80) { \
using ArchTag = cutlass::arch::Sm80; \
func(); \
} else if (CC >= 75) { \
using ArchTag = cutlass::arch::Sm75; \
func(); \
} else if (CC >= 70) { \
using ArchTag = cutlass::arch::Sm70; \
func(); \
} else if (CC >= 50) { \
using ArchTag = cutlass::arch::Sm50; \
func(); \
} else { \
PADDLE_CHECK( \
false, \
"Your device is too old. We require compute capability >= 50"); \
} \
}
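// Illustrative usage only (a sketch; `cc`, `params` and `launch_attention` are
// hypothetical names): inside the functor, `ArchTag` is the cutlass::arch tag
// matching the compute capability, e.g.
//   DISPATCH_ARCHTAG(cc, ([&] { launch_attention<ArchTag>(params); }));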
#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \
PADDLE_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
PADDLE_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
  PADDLE_CHECK(TENSOR.is_contiguous(), #TENSOR " must be contiguous");
#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \
PADDLE_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
PADDLE_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
PADDLE_CHECK(TENSOR.stride(-1) == 1, \
#TENSOR ": last dimension must be contiguous");
#if defined(__CUDACC_RTC__)
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
return false; \
}
#define PADDLE_CHECK(COND, ERR) \
if (!(COND)) { \
return false; \
}
#else
#include <iostream>
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
std::cerr << #PTR " is not correctly aligned\n"; \
return false; \
}
#define PADDLE_CHECK(COND, ERR) \
if (!(COND)) { \
std::cerr << #COND " failed\n"; \
return false; \
}
#endif
#define ASSIGN_CHECK_OVERFLOW(A, B) \
{ \
A = B; \
PADDLE_CHECK(B < std::numeric_limits<decltype(A)>::max(), \
#B " overflows"); \
}
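// Illustrative usage only (a sketch; `p.num_queries` and `q_dim` are
// hypothetical names): narrow a wider extent into a smaller field and fail
// the setup if it does not fit, e.g.
//   ASSIGN_CHECK_OVERFLOW(p.num_queries, q_dim);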
namespace gemm_kernel_utils {
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
return (n + m - 1) / m;
}
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) {
return ((n + m - 1) / m) * m;
}
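// For example, ceil_div(10, 4) == 3 and align_up(10, 4) == 12 (both assume a
// positive divisor/alignment).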
inline int32_t getMaximumSharedMemoryPerBlockKb(int cc) {
// from:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications-technical-specifications-per-compute-capability
switch (cc) {
case 50:
case 52:
case 53:
case 60:
case 61:
case 62:
return 64;
case 70:
case 72:
return 96;
case 75:
return 64;
case 80:
return 163;
case 86:
return 99;
case 87:
return 163;
case 89:
return 99;
case 90:
return 227;
default:
return 0;
}
}
////////////////////////////////////////////////////////////////////////////////
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
// TODO(xformers): Maybe we could rely on Cutlass's DefaultGemm templates
////////////////////////////////////////////////////////////////////////////////
// Fallback to Simt (FMA on cuda cores) if not in a special case below
template <typename ArchTag, typename scalar_t_, typename Enable = void>
struct DefaultGemmType {
static constexpr int ThreadK = 8;
static constexpr int WarpK = 8;
static constexpr int kMinimumAlignment = 1;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using OpClass = cutlass::arch::OpClassSimt;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specialization for tensorcores with f32
template <typename ArchTag>
struct DefaultGemmType<ArchTag,
float,
typename cutlass::platform::enable_if<
ArchTag::kMinComputeCapability >= 80>::type> {
static constexpr int ThreadK = 32;
static constexpr int WarpK = 32;
static constexpr int kMinimumAlignment = 4;
using OpClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = cutlass::arch::OpMultiplyAddFastF32;
};
// Specialization for tensorcores with f16/bf16 - Sm75+
template <typename ArchTag, typename scalar_t>
struct DefaultGemmType<ArchTag,
scalar_t,
typename cutlass::platform::enable_if<
ArchTag::kMinComputeCapability >= 75 &&
cutlass::sizeof_bits<scalar_t>::value == 16>::type> {
static constexpr int ThreadK = 32;
static constexpr int WarpK = 32;
static constexpr int kMinimumAlignment = 4;
using OpClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specialization for tensorcores with f16 - Volta
template <>
struct DefaultGemmType<cutlass::arch::Sm70, cutlass::half_t, void> {
static constexpr int ThreadK = 32;
static constexpr int WarpK = 32;
static constexpr int kMinimumAlignment = 2;
using OpClass = cutlass::arch::OpClassTensorOp;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Enables writing
// `auto x = kCondition ? fa(arg) : fb(arg)`
// when `fa` and `fb` have different types
template <bool kVal, typename TA, typename TB>
struct call_conditional;
template <typename TA, typename TB>
struct call_conditional<true, TA, TB> {
template <typename Arg>
static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
-> decltype(ta(arg)) {
return ta(arg);
}
};
template <typename TA, typename TB>
struct call_conditional<false, TA, TB> {
template <typename Arg>
static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
-> decltype(tb(arg)) {
return tb(arg);
}
};
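// Illustrative usage only (a sketch; the two lambdas are hypothetical): select
// between callables whose return types differ, resolved at compile time:
//   auto iter = call_conditional<kIsAligned,
//                                decltype(make_aligned),
//                                decltype(make_fallback)>::apply(
//       make_aligned, make_fallback, thread_id);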
////////////////////////////////////////////////////////////////////////////////
// Mark a variable as warp-uniform - enables some compiler optimizations
// The cheapest way to do it is just to broadcast it from lane 0
////////////////////////////////////////////////////////////////////////////////
CUTLASS_DEVICE int32_t warp_uniform(int32_t value) {
return (int32_t)__shfl_sync(0xffffffff, (unsigned)value, 0);
}
template <typename T>
CUTLASS_DEVICE T* warp_uniform(T* ptr) {
struct {
union {
T* ptr;
uint32_t asInt[2];
};
} p;
p.ptr = ptr;
p.asInt[0] = warp_uniform(p.asInt[0]);
p.asInt[1] = warp_uniform(p.asInt[1]);
return p.ptr;
}
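// Illustrative usage only (a sketch; the variable names are hypothetical):
//   int32_t q_start = warp_uniform(query_start_row);
//   scalar_t* q_ptr = warp_uniform(p.query_ptr);
// After the broadcast every lane holds lane 0's value, so the compiler may
// treat it as warp-uniform.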
} // namespace gemm_kernel_utils
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# Generates combinations of kernels - implementations and registry
# Kernels are ordered (see `sort_index`), and when dispatching,
# we select the first kernel in the list that supports the inputs
import argparse
import collections
import itertools
import os
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple, TypeVar
MAX_ARCH = 90
ENABLE_MACRO = "PADDLE_WITH_MEMORY_EFFICIENT_ATTENTION"
def convert_to_arch_list(arch):
arch = arch.lower().strip()
if arch == "all":
return [50, 70, 75, 80]
arch = [int(s.strip()) for s in arch.split(' ') if s.strip()]
arch = list(set(arch))
arch.sort()
for each_arch in arch:
assert each_arch < MAX_ARCH
return arch
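# For example (illustrative): convert_to_arch_list("80 75 75") returns
# [75, 80], while "all" maps to [50, 70, 75, 80].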
def parse_args():
parser = argparse.ArgumentParser(
        description="Arguments for generating the memory efficient attention kernels."
)
parser.add_argument(
"--dst_path",
type=str,
default=str(Path(__file__).parent),
help="The destination path to save the generated files.",
)
parser.add_argument(
"--cuda_arch",
type=convert_to_arch_list,
default=convert_to_arch_list("All"),
help="The CUDA architecture to be generated.",
)
return parser.parse_args()
args = parse_args()
DTYPES = {
"f32": "float",
"f16": "cutlass::half_t",
"bf16": "cutlass::bfloat16_t",
}
SM = args.cuda_arch
KERNEL_IMPL_TEMPLATE = """__global__ void __launch_bounds__(
{CPP_CLASS}::kNumThreads,
{CPP_CLASS}::kMinBlocksPerSm)
{NAME}(typename {CPP_CLASS}::Params p) {{
#ifdef __CUDA_ARCH__
#if __CUDA_ARCH__ >= {SM}0
#if __CUDA_ARCH__ < {SM_MAX}0
if (!p.advance_to_block()) {{
return;
}}
{CPP_CLASS}::attention_kernel(p);
return;
#endif
#endif
printf(
"FATAL: kernel `{NAME}` is for sm{SM}-sm{SM_MAX}, but was built for sm%d\\n",
int(__CUDA_ARCH__ + 0) / 10);
#endif
}}
"""
@dataclass(order=True)
class FwdKernel:
sort_index: Tuple[int, ...] = field(init=False, repr=False)
aligned: bool
dtype: str
sm_range: Tuple[int, int]
q: int
k: int
single_value_iter: bool
supports_dropout: bool = True
supports_bias: bool = True
dispatch_cond: Optional[str] = None
def __post_init__(self) -> None:
# Set kernel selection priority
# The lowest value that matches inputs
# will be selected
self.sort_index = (
# First select aligned kernel
0 if self.aligned else 1,
# Then keep output in RF
0 if self.single_value_iter else 1,
self.k,
# Prefer kernels without dropout/bias if available
1 if self.supports_dropout else 0,
1 if self.supports_bias else 0,
)
@property
def _aligned_suffix(self) -> str:
return "aligned" if self.aligned else "notaligned"
@property
def name(self) -> str:
acc = "rf" if self.single_value_iter else "gmem"
return f"fmha_cutlassF_{self.dtype}_{self._aligned_suffix}_{self.q}x{self.k}_{acc}_sm{self.sm_range[0]}"
@property
def cpp_class(self) -> str:
template_args = ", ".join(
[
DTYPES[self.dtype],
f"cutlass::arch::Sm{self.sm_range[0]}",
"true" if self.aligned else "false",
str(self.q),
str(self.k),
"true" if self.single_value_iter else "false",
"true" if self.supports_dropout else "false",
"true" if self.supports_bias else "false",
]
)
return f"AttentionKernel<{template_args}>"
@property
def impl_group(self) -> str:
# Maps to file which will contain the implementation
return f"{self.dtype}_{self._aligned_suffix}"
@property
def cpp_impl(self) -> str:
return KERNEL_IMPL_TEMPLATE.format(
CPP_CLASS=self.cpp_class,
NAME=self.name,
SM=self.sm_range[0],
SM_MAX=self.sm_range[1],
)
@classmethod
def get_all(cls) -> List["FwdKernel"]:
kernels: List[FwdKernel] = []
for aligned, dtype, (sm, sm_max) in itertools.product(
[True, False], DTYPES.keys(), zip(SM, SM[1:] + [MAX_ARCH])
):
# Remove some kernels we don't use
if dtype == "bf16" and sm < 80:
continue
if not aligned and sm >= 80:
continue
for q, k, single_value_iter in [
(32, 128, True),
(32, 128, False),
(64, 64, True),
]:
kernels.append(
cls(
aligned=aligned,
dtype=dtype,
sm_range=(sm, sm_max),
q=q,
k=k,
single_value_iter=single_value_iter,
)
)
return kernels
@dataclass(order=True)
class BwdKernel:
sort_index: Tuple[int, ...] = field(init=False, repr=False)
sm_range: Tuple[int, int]
dtype: str
aligned: bool
apply_dropout: bool
preload_mmas: bool
block_i: int
block_j: int
max_k: int
dispatch_cond: Optional[str] = None
def __post_init__(self) -> None:
# Set kernel selection priority
# The lowest value that matches inputs
# will be selected
self.sort_index = (
# First select aligned kernel
0 if self.aligned else 1,
# Take a kernel without dropout if possible
1 if self.apply_dropout else 0,
# Then take the smallest maxK
self.max_k,
# .. and the highest block_i
-self.block_i,
)
@property
def _aligned_suffix(self) -> str:
return "aligned" if self.aligned else "notaligned"
@property
def name(self) -> str:
dropout_suffix = "_dropout" if self.apply_dropout else ""
return (
f"fmha_cutlassB_{self.dtype}_{self._aligned_suffix}"
f"_{self.block_i}x{self.block_j}_k{self.max_k}{dropout_suffix}_sm{self.sm_range[0]}"
)
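    # The property above yields names such as (illustrative)
    # "fmha_cutlassB_f16_aligned_64x64_k64_dropout_sm80" for an aligned f16
    # kernel with block_i=64, block_j=64, max_k=64 and dropout on Sm80.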
@property
def cpp_class(self) -> str:
template_args = ", ".join(
[
f"cutlass::arch::Sm{self.sm_range[0]}",
DTYPES[self.dtype],
"true" if self.aligned else "false",
"true" if self.apply_dropout else "false",
"true" if self.preload_mmas else "false",
str(self.block_i),
str(self.block_j),
str(self.max_k),
]
)
return f"AttentionBackwardKernel<{template_args}>"
@property
def impl_group(self) -> str:
# Maps to file which will contain the implementation
dropout_suffix = "_dropout" if self.apply_dropout else ""
return (
f"{self.dtype}_{self._aligned_suffix}_k{self.max_k}{dropout_suffix}"
)
@property
def cpp_impl(self) -> str:
return KERNEL_IMPL_TEMPLATE.format(
CPP_CLASS=self.cpp_class,
NAME=self.name,
SM=self.sm_range[0],
SM_MAX=self.sm_range[1],
)
@classmethod
def get_all(cls) -> List["BwdKernel"]:
kernels: List[BwdKernel] = []
for (
aligned,
dtype,
(sm, sm_max),
apply_dropout,
max_k,
) in itertools.product(
[True, False],
DTYPES.keys(),
zip(SM, SM[1:] + [MAX_ARCH]),
[True, False],
[32, 64, 128, 2**16],
):
if dtype == "bf16" and sm < 80:
continue
if not aligned and sm >= 80:
continue
is_half = dtype in ["bf16", "f16"]
bi_values = [64]
            # Some architectures have more shmem and can use 128.
            # We still need a fallback to 64 for GPUs with less shmem
            # (Sm75, Sm86, ...).
if sm >= 80 or (sm >= 70 and is_half):
if max_k > 64:
bi_values.append(128)
for bi in bi_values:
output_in_rf = is_half and max_k <= bi
preload_mmas = is_half and sm >= 80 and output_in_rf
bj = 128 if (preload_mmas and max_k > 64) else 64
kernels.append(
cls(
aligned=aligned,
dtype=dtype,
sm_range=(sm, sm_max),
apply_dropout=apply_dropout,
preload_mmas=preload_mmas,
block_i=bi,
block_j=bj,
max_k=max_k,
)
)
# Add some specialized kernels for stable diffusion BW (K=80)
# This is the only kernel that can keep the outputs on RF on
# Sm86/Sm89, so it's much faster than the 64x64 one
for dtype in ["f16", "bf16"]:
if max(args.cuda_arch) < 80:
continue
kernels.append(
cls(
aligned=True,
dtype=dtype,
sm_range=(80, MAX_ARCH),
apply_dropout=False,
preload_mmas=True,
block_i=128,
block_j=64,
max_k=96,
# Sm80 has a faster kernel for this case
dispatch_cond="cc == 86 || cc == 89",
)
)
return kernels
T = TypeVar("T", FwdKernel, BwdKernel)
def write_decl_impl(
kernels: List[T], family_name: str, impl_file: str, enable_def: str
) -> None:
cpp_file_header = """// This file is auto-generated. See "generate_kernels.py"
"""
kernels.sort()
implfile_to_kernels: Dict[str, List[T]] = collections.defaultdict(list)
cat_to_kernels: Dict[
Tuple[str, int, int], List[T]
] = collections.defaultdict(list)
dispatch_all = ""
declarations = cpp_file_header + "#pragma once\n"
declarations += f"#ifdef {enable_def}\n"
declarations += f"""#include "{impl_file}"\n"""
declarations += "namespace phi {\n"
# Declaration of kernel functions
for k in kernels:
implfile_to_kernels[k.impl_group].append(k)
cat_to_kernels[(k.dtype, k.sm_range[0], k.sm_range[1])].append(k)
for (cat_dt, cat_sm, cat_sm_max), kernels in cat_to_kernels.items():
declarations += f"// ======== {cat_dt} / sm{cat_sm} ========\n"
declarations += "\n".join(
k.cpp_impl.split("{")[0].rstrip() + ";" for k in kernels
)
dispatch_category_fn = f"dispatch_{family_name}_{cat_dt}_sm{cat_sm}"
declarations += f"\n\ntemplate <typename T> void {dispatch_category_fn}(T cb, int cc) {{\n"
for k in kernels:
_call = f"cb({k.cpp_class}(), {k.name});\n"
if k.dispatch_cond is not None:
_call = f"if ({k.dispatch_cond}) {_call}"
declarations += f" {_call}"
declarations += "}\n\n"
dispatch_all += f"""
if (std::is_same<DT, {DTYPES[cat_dt]}>::value && {cat_sm} <= cc && cc < {cat_sm_max}) {{
{dispatch_category_fn}(cb, cc);
}}"""
declarations += f"""
template <typename PaddleT, typename T>
void dispatch_{family_name}(const ::phi::GPUContext &ctx, T cb) {{
auto cc = ctx.GetComputeCapability();
using DT = typename ::phi::CutlassTrait<PaddleT>::Type;
{dispatch_all}
}}
"""
declarations += "} // namespace phi\n"
declarations += f"#endif // {enable_def}\n"
autogen_dir = Path(args.dst_path) / "autogen"
os.makedirs(autogen_dir, exist_ok=True)
declaration_path = autogen_dir / f"{family_name}.h"
declaration_path.write_text(declarations)
for f, f_kernels in implfile_to_kernels.items():
impl_cu = cpp_file_header
impl_cu += f"#ifdef {enable_def}\n"
impl_cu += f"""#include "{impl_file}"\n"""
impl_cu += "namespace phi {\n"
for k in f_kernels:
impl_cu += k.cpp_impl
impl_cu += "} // namespace phi\n"
impl_cu += f"#endif // {enable_def}\n"
impl_path = autogen_dir / "impl"
os.makedirs(impl_path, exist_ok=True)
(impl_path / f"{family_name}_{f}.cu").write_text(impl_cu)
def write_main_header(forward_impl, backward_impl):
main_header_content = '''
#pragma once
#ifdef %s
#include "%s"
#include "%s"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace phi {
template <typename T>
struct CutlassTrait {
using Type = T;
};
template <>
struct CutlassTrait<dtype::float16> {
using Type = cutlass::half_t;
};
template <>
struct CutlassTrait<dtype::bfloat16> {
using Type = cutlass::bfloat16_t;
};
template <typename T>
struct ToPhiDTypeTrait {
private:
using NonConstT = typename std::remove_const<T>::type;
static constexpr bool kIsFP16 = std::is_same<NonConstT, cutlass::half_t>::value;
static constexpr bool kIsBF16 = std::is_same<NonConstT, cutlass::bfloat16_t>::value;
public:
using Type = typename std::conditional<kIsFP16, dtype::float16,
typename std::conditional<kIsBF16, dtype::bfloat16, NonConstT>::type>::type;
};
template <typename T>
T *SafeGetTensorPtr(const DenseTensor &t) {
using PDT = typename ToPhiDTypeTrait<T>::Type;
return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(t.template data<PDT>()));
}
template <typename T>
T *SafeGetTensorPtr(const DenseTensor *t) {
return t ? SafeGetTensorPtr<T>(*t) : nullptr;
}
template <typename T>
T *SafeGetTensorPtr(const paddle::optional<DenseTensor> &t) {
return t ? SafeGetTensorPtr<T>(t.get()) : nullptr;
}
template <typename T, typename Context>
T *SafeAllocTensor(const Context &ctx, DenseTensor *t) {
using PDT = typename ToPhiDTypeTrait<T>::Type;
void *ptr = ctx.template Alloc<PDT>(t);
return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(ptr));
}
inline int64_t DimStride(const phi::DDim &dims, int n) {
int rank = dims.size();
if (n < 0) {
n += rank;
}
int64_t stride = 1;
for (int i = n+1; i < rank; ++i) {
stride *= dims[i];
}
return stride;
}
} // namespace phi
#include "./cutlass_forward.h"
#include "./cutlass_backward.h"
#endif
''' % (
ENABLE_MACRO,
forward_impl,
backward_impl,
)
path = Path(args.dst_path) / "autogen"
os.makedirs(path, exist_ok=True)
path = Path(path) / "memory_efficient_attention.h"
path.write_text(main_header_content)
if os.path.exists(Path(args.dst_path) / "autogen"):
shutil.rmtree(Path(args.dst_path) / "autogen")
forward_impl = "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/kernel_forward.h"
backward_impl = "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/kernel_backward.h"
write_main_header(forward_impl, backward_impl)
write_decl_impl(
FwdKernel.get_all(),
"cutlass_forward",
impl_file=forward_impl,
enable_def=ENABLE_MACRO,
)
write_decl_impl(
BwdKernel.get_all(),
"cutlass_backward",
impl_file=backward_impl,
enable_def=ENABLE_MACRO,
)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue iterator that supports prefetching
Mostly copied from "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/threadblock/output_tile_thread_map.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Tile iterator used to load and store output tile from global memory in
/// epilogue.
///
/// Satisfies: ReadableTileIterator | PredicatedTileIterator |
/// ForwardTileIterator
///
template <typename ThreadMap_,  ///< Thread map (concept: OutputTileThreadMap)
typename Element_, ///< Element data type
bool ScatterD = false, ///< Scatter D operand or not
bool UseCUDAStore = false>
class PredicatedTileIteratorPrefetch {
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using Element = Element_;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kThreads = ThreadMap::kThreads;
static int const kIterations = ThreadMap::Count::kTile;
static_assert(ThreadMap::Iterations::kRow > 0,
"ThreadMap::Iterations::kRow must be > 0");
static_assert(ThreadMap::Iterations::kGroup > 0,
"ThreadMap::Iterations::kGroup must be > 0");
static_assert(ThreadMap::Iterations::kCluster > 0,
"ThreadMap::Iterations::kCluster must be > 0");
static_assert(ThreadMap::Iterations::kColumn > 0,
"ThreadMap::Iterations::kColumn must be > 0");
/// Fragment object
using Fragment =
Array<Element,
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow *
ThreadMap::Iterations::kGroup *
ThreadMap::Iterations::kCluster *
ThreadMap::kElementsPerAccess>;
/// Memory access size
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
//
// Parameters struct
//
/// Uses a non-template class
struct Params : PredicatedTileIteratorParams {
using Base = PredicatedTileIteratorParams;
CUTLASS_HOST_DEVICE
Params() {}
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: PredicatedTileIteratorParams(
layout.stride(0) * static_cast<int>(sizeof(AccessType)) /
kElementsPerAccess,
make_OutputTileThreadMapDesc<ThreadMap>()) {}
CUTLASS_HOST_DEVICE
Params(Base const& base) : Base(base) {} // NOLINT
};
/// Mask object
struct Mask {
static int const kCount = ThreadMap::Iterations::kColumn;
/// Predicate state
bool predicates[kCount];
//
// Mask
//
CUTLASS_HOST_DEVICE
Mask() { enable(); }
///< Efficiently disables all accesses guarded by mask
CUTLASS_HOST_DEVICE void clear() {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
predicates[i] = false;
}
}
    ///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable() {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
predicates[i] = true;
}
}
};
private:
//
// Data members
//
/// Parameters structure containing reference and precomputed state.
PredicatedTileIteratorParams params_;
/// Byte-level pointer
uint8_t* byte_pointer_;
/// Array of boolean values to contain steady-state predicates
Mask mask_;
/// Extent of the matrix tile in rows
Index extent_row_;
/// Extent of the matrix tile in rows
Index extent_column_;
/// A thread's starting row position (assuming steady-state predicates have
/// been computed)
Index thread_start_row_;
/// A thread's starting column
Index thread_start_column_;
/// Internal state counter
int state_[3];
/// Scatter indices
int const* indices_;
//
// Static asserts about internal strides
//
static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8,
"Expected 64b strides");
private:
//
// Methods
//
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
PredicatedTileIteratorPrefetch(PredicatedTileIteratorParams const& params,
Element* pointer,
TensorCoord extent,
int thread_idx,
TensorCoord threadblock_offset = TensorCoord(),
int const* indices = nullptr)
: params_(params), indices_(indices) {
TensorCoord thread_offset =
ThreadMap::initial_offset(thread_idx) + threadblock_offset;
extent_row_ = extent.row();
extent_column_ = extent.column();
thread_start_row_ = thread_offset.row();
thread_start_column_ = thread_offset.column();
// Initialize predicates
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
mask_.predicates[c] = ((thread_offset.column() +
ThreadMap::Delta::kColumn * c) < extent.column());
}
// Null pointer performs no accesses
if (!pointer) {
mask_.clear();
}
if (ScatterD && !indices) {
mask_.clear();
}
// Initialize pointer
byte_pointer_ = reinterpret_cast<uint8_t*>(pointer) +
LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
LongIndex(thread_offset.column()) * sizeof(AccessType) /
kElementsPerAccess;
if (ScatterD) {
byte_pointer_ = reinterpret_cast<uint8_t*>(pointer) +
LongIndex(thread_offset.column()) * sizeof(AccessType) /
kElementsPerAccess;
}
// Initialize internal state counter
state_[0] = state_[1] = state_[2] = 0;
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
CUTLASS_DEVICE
void prefetch_all() {
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < kIterations; ++iter) {
prefetch();
++(*this);
}
}
CUTLASS_DEVICE
void prefetch() {
uint8_t* byte_pointer = byte_pointer_;
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
// on windows using unsigned long here gives the error
// error: asm operand type size(4) does not match
// type/size implied by constraint 'l'
uint64_t addr = (uint64_t)((
void*)&memory_pointer[column * ThreadMap::Delta::kColumn /
kElementsPerAccess]);
asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr));
}
if (row + 1 < ThreadMap::Iterations::kRow) {
if (!ScatterD) {
byte_pointer += params_.increment_row;
}
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment& frag, // NOLINT
int64_t byte_offset) const { // NOLINT
uint8_t* byte_pointer = byte_pointer_;
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer + byte_offset);
if (ScatterD && row_guard) {
assert(indices_);
memory_pointer = reinterpret_cast<AccessType*>(
byte_pointer + byte_offset +
LongIndex(indices_[row_offset + thread_start_row_]) *
LongIndex(params_.stride));
}
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
bool guard = row_guard && mask_.predicates[column];
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
column],
(void*)&memory_pointer[column * // NOLINT
ThreadMap::Delta::kColumn / // NOLINT
kElementsPerAccess],
guard);
}
if (row + 1 < ThreadMap::Iterations::kRow) {
if (!ScatterD) {
byte_pointer += params_.increment_row;
}
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } // NOLINT
/// Stores a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const {
uint8_t* byte_pointer = byte_pointer_;
AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer + byte_offset);
if (ScatterD && row_guard) {
assert(indices_);
memory_pointer = reinterpret_cast<AccessType*>(
byte_pointer + byte_offset +
LongIndex(indices_[row_offset + thread_start_row_]) *
LongIndex(params_.stride));
}
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
bool guard = row_guard && mask_.predicates[column];
if (UseCUDAStore) {
if (guard) {
memory_pointer[column * ThreadMap::Delta::kColumn /
kElementsPerAccess] =
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
column];
}
} else {
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
column],
(void*)&memory_pointer[column * // NOLINT
ThreadMap::Delta::kColumn / // NOLINT
kElementsPerAccess],
guard);
}
}
if (row + 1 < ThreadMap::Iterations::kRow) {
if (!ScatterD) {
byte_pointer += params_.increment_row;
}
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); }
/// Loads a fragment from memory
CUTLASS_DEVICE
void downsample_load_with_byte_offset(Fragment& frag, // NOLINT
int64_t byte_offset,
int convolution_P,
int convolution_Q,
int add_P,
int add_Q,
int problem_N) const {
uint8_t* byte_pointer = byte_pointer_;
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
int output_row = row_offset + thread_start_row_;
int output_N = output_row / (convolution_P * convolution_Q);
int output_PQ = output_row % (convolution_P * convolution_Q);
int output_P = output_PQ / convolution_Q;
int output_Q = output_PQ % convolution_Q;
int input_row = output_N * 2 * convolution_P * 2 * convolution_Q +
(2 * output_P + add_P) * 2 * convolution_Q +
2 * output_Q + add_Q;
int64_t byte_offset =
(input_row - output_row) * problem_N * sizeof(float);
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer + byte_offset);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
bool guard = row_guard && mask_.predicates[column];
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
column],
(void*)&memory_pointer[column * // NOLINT
ThreadMap::Delta::kColumn / // NOLINT
kElementsPerAccess],
guard);
}
if (row + 1 < ThreadMap::Iterations::kRow) {
byte_pointer += params_.increment_row;
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void upsample_load_with_byte_offset(Fragment& frag, // NOLINT
int64_t byte_offset,
int convolution_P,
int convolution_Q,
int add_P,
int add_Q,
int problem_N) const {
uint8_t* byte_pointer = byte_pointer_;
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow +
group * ThreadMap::Delta::kGroup +
cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
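          // Decode the flattened output row into (N, P, Q) coordinates and
          // map it back to the corresponding row of the 2x smaller
          // (P/2 x Q/2) input feature map; add_P/add_Q bias the rounding of
          // the division by two and are zeroed on the last output row/column
          // so the mapped input row stays in bounds.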
int output_row = row_offset + thread_start_row_;
int output_N = output_row / (convolution_P * convolution_Q);
int output_PQ = output_row % (convolution_P * convolution_Q);
int output_P = output_PQ / convolution_Q;
int output_Q = output_PQ % convolution_Q;
int row_add_P = add_P;
int row_add_Q = add_Q;
if (output_P > convolution_P - 2) row_add_P = 0;
if (output_Q > convolution_Q - 2) row_add_Q = 0;
int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) +
((output_P + row_add_P) / 2) * (convolution_Q / 2) +
(output_Q + row_add_Q) / 2;
int64_t byte_offset =
(input_row - output_row) * problem_N * sizeof(float);
AccessType* memory_pointer =
reinterpret_cast<AccessType*>(byte_pointer + byte_offset);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
bool guard = row_guard && mask_.predicates[column];
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
column],
(void*)&memory_pointer[column * // NOLINT
ThreadMap::Delta::kColumn / // NOLINT
kElementsPerAccess],
guard);
}
if (row + 1 < ThreadMap::Iterations::kRow) {
byte_pointer += params_.increment_row;
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
CUTLASS_DEVICE
MatrixCoord thread_start() const {
return MatrixCoord(thread_start_row_, thread_start_column_);
}
/// Need to get the thread start row from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_row() const { return thread_start_row_; }
  /// Need to get the thread start column from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_column() const { return thread_start_column_; }
/// Extent of the matrix in rows
CUTLASS_DEVICE
Index extent_row() const { return extent_row_; }
/// Extent of the matrix in columns
CUTLASS_DEVICE
Index extent_column() const { return extent_column_; }
/// Advances to the next position to load or store
CUTLASS_HOST_DEVICE
PredicatedTileIteratorPrefetch& operator++() {
++state_[0];
if (!ScatterD) {
byte_pointer_ += params_.advance_row;
}
thread_start_row_ += ThreadMap::Shape::kRow;
if (state_[0] == ThreadMap::Count::kRow) {
state_[0] = 0;
++state_[1];
byte_pointer_ += params_.advance_group;
thread_start_row_ += (ThreadMap::Shape::kGroup - 1) *
ThreadMap::Shape::kRow * ThreadMap::Count::kRow;
if (state_[1] == ThreadMap::Count::kGroup) {
state_[1] = 0;
++state_[2];
byte_pointer_ += params_.advance_cluster;
thread_start_row_ += ThreadMap::Count::kGroup *
ThreadMap::Shape::kGroup * ThreadMap::Count::kRow *
ThreadMap::Shape::kRow;
if (state_[2] == ThreadMap::Count::kCluster) {
state_[2] = 0;
byte_pointer_ += params_.advance_tile;
}
}
}
return *this;
}
///< Efficiently disables all accesses guarded by mask
CUTLASS_DEVICE void clear_mask() { mask_.clear(); }
///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable_mask() { mask_.enable(); }
  ///< Gets the mask
CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } // NOLINT
///< Sets the mask
CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; }
};
template <typename IT>
struct MakePrefetchableIterator {
using Iterator = PredicatedTileIteratorPrefetch<typename IT::ThreadMap,
typename IT::Element>;
};
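// Illustrative rebinding only (the epilogue iterator type `EpilogueTileIt` is
// a hypothetical placeholder, not defined in this header):
//
//   using PrefetchTileIt =
//       MakePrefetchableIterator<EpilogueTileIt>::Iterator;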
///////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include "./predicated_tile_access_iterator_residual_last.h"
#include "./predicated_tile_iterator_residual_last.h"
namespace cutlass {
namespace transform {
namespace threadblock {
template <typename BaseIterator>
struct MakeIteratorResidualLast;
template <typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
typename ThreadMap,
int AccessSize,
bool Gather>
struct MakeIteratorResidualLast<PredicatedTileIterator<Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessSize,
Gather>> {
using Iterator = PredicatedTileIteratorResidualLast<Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessSize,
Gather>;
};
template <typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
typename ThreadMap,
typename AccessType,
bool Gather>
struct MakeIteratorResidualLast<PredicatedTileAccessIterator<Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessType,
Gather>> {
using Iterator = PredicatedTileAccessIteratorResidualLast<Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessType,
Gather>;
};
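// Illustrative rebinding only (the concrete tile iterator type `SrcTileIt` is
// a hypothetical placeholder for a PredicatedTileIterator or
// PredicatedTileAccessIterator instantiation):
//
//   using SrcResidualLastIt =
//       MakeIteratorResidualLast<SrcTileIt>::Iterator;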
} // namespace threadblock
} // namespace transform
} // namespace cutlass
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
    \brief Templates calculating the address and predicates for loading tiles
    from pitch-linear rank=2 tensors.
    This iterator uses masks to guard out-of-bounds accesses. The residual
    tile this iterator visits may be partial; the remaining tiles are
    complete. So the predicates only need to be computed twice: once for the
    residual tile and once for the full tiles, which can all share the same
    predicates.
A precomputed "Params" object minimizes the amount of state that must be
stored in registers, and integer addition is used to advance the pointer
through memory.
*/
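//
// A rough consumption sketch (illustrative only; `consume_tile` is a
// placeholder for guarded accesses through it.get() / it.valid(), and the
// point at which the residual tile is visited is decided by the calling
// kernel, not by this header):
//
//   Iterator it(params, ptr, extent, thread_id, tb_offset);
//   consume_tile(it);             // a full tile: steady-state predicates
//   ++it;
//   it.set_residual_tile(true);   // swap in the partial-tile predicate mask
//   consume_tile(it);             // the residual (partial) tile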
#pragma once
#include "cutlass/array.h"
#include "cutlass/coord.h"
#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/predicate_vector.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/tensor_view.h"
#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h"
////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace transform {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// PredicatedTileAccessIteratorResidualLast
///
template <typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
typename ThreadMap,
typename AccessType,
bool Gather = false>
class PredicatedTileAccessIteratorResidualLast;
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileAccessIteratorResidualLast for pitch-linear
/// data.
///
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
typename AccessType_,
bool Gather>
class PredicatedTileAccessIteratorResidualLast<Shape_,
Element_,
layout::PitchLinear,
AdvanceRank,
ThreadMap_,
AccessType_,
Gather> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may along advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
using Layout = layout::PitchLinear;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using AccessType = AccessType_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
using UnderlyingPredicates =
PredicatedTileAccessIteratorPredicates<Shape,
Element,
Layout,
AdvanceRank,
ThreadMap,
AccessType>;
static int const kAccessesPerVector =
ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the "
"access type.");
using Mask = typename UnderlyingPredicates::Mask;
/// Uses a non-template class
struct Params : PredicatedTileAccessIteratorParams {
using Base = PredicatedTileAccessIteratorParams;
// Default ctor
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: Base(layout.stride(0),
MakePredicatedTileAccessIteratorDesc<Shape,
Element,
Layout,
kAdvanceRank,
ThreadMap>()()) {}
CUTLASS_HOST_DEVICE
Params(Base const& base) : Base(base) {} // NOLINT
};
private:
/// Internal pointer type permits fast address arithmetic
using BytePointer = char*;
private:
//
// Data members
//
UnderlyingPredicates the_predicates;
Mask residual_tile_mask;
/// Parameters object with precomputed internal state
Params const& params_;
/// Internal pointer to first access of tile
BytePointer pointer_;
  /// Below is used only when Gather is turned on. We record strided_offset
  /// and contiguous_offset separately and compute the final offset as
  ///
/// offset = contiguous_offset + indices[strided_offset]
///
/// Gather indices
int const* indices_;
Index gather_offset_strided;
private:
/// Computes predicates based on internally tracked per-thread offset.
CUTLASS_DEVICE
void compute_predicates_(
/// Extent of the matrix window
TensorCoord extent,
/// optionally, simplify predicate calculation during 'steady state' phase
bool is_steady_state = false) {
the_predicates.compute_predicates_(extent, is_steady_state);
}
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
/// Precomputed parameters object
Params const& params,
/// Pointer to start of tensor
Pointer pointer,
/// Extent of tensor
TensorCoord extent,
/// ID of each participating thread
int thread_id,
/// Initial offset of threadblock
TensorCoord const& threadblock_offset,
/// Gather indices
int const* indices = nullptr)
: params_(params),
pointer_(reinterpret_cast<BytePointer>(
const_cast<NonConstPointer>(pointer))),
the_predicates(extent),
indices_(indices) {
the_predicates.set_predicates(thread_id, threadblock_offset);
the_predicates.get_mask(residual_tile_mask);
// Working around a weird compiler bug happening on P100 for the backward.
    // Observed together: the_predicates.predicates_[0] = 14 (instead of 15)
    // while residual_tile_mask[0] = 15 (correct).
    //
    // Adding prints where the value is calculated (in `compute_predicates_`)
    // sometimes makes the bug disappear. The consequence of the bug is that
    // some elements of a tensor are skipped, leading to wrong results.
// Setting `compute_predicates_`'s second argument (`is_steady_state`) to
// true also seems to get rid of the bug - at the cost of twice as many
// comparisons.
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700)
constexpr bool kWorkAroundCompilerBug = false;
#else
constexpr bool kWorkAroundCompilerBug = true;
#endif
the_predicates.compute_predicates_(extent, true && !kWorkAroundCompilerBug);
// update internal pointers
Layout layout(params_.stride_);
if (!Gather) {
add_pointer_offset(layout(the_predicates.thread_offset_));
} else {
gather_offset_strided = the_predicates.thread_offset_.strided();
add_pointer_offset(
layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0)));
}
}
/// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
/// Precomputed parameters object
Params const& params,
/// Pointer to start of tensor
Pointer pointer,
/// Extent of tensor
TensorCoord extent,
///< ID of each participating thread
int thread_id)
: PredicatedTileAccessIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(int index) {
the_predicates.set_iteration_index(index);
}
CUTLASS_HOST_DEVICE
void set_residual_tile(bool is_residual_tile) {
if (is_residual_tile) {
the_predicates.set_mask(residual_tile_mask);
}
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
pointer_ += sizeof_bits<Element>::value * pointer_offset / 8;
}
/// Advances an iterator along logical dimensions of matrix in units of whole
/// tiles
CUTLASS_DEVICE
void add_tile_offset(TensorCoord const& tile_offset) {
if (!Gather) {
if (kAdvanceRank) {
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided());
pointer_ += Shape::kContiguous * tile_offset.contiguous();
} else {
pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous());
pointer_ += Shape::kStrided * tile_offset.strided();
}
} else {
add_pointer_offset(Shape::kContiguous * tile_offset.contiguous());
gather_offset_strided += Shape::kStrided * tile_offset.strided();
}
}
/// Returns a pointer
CUTLASS_HOST_DEVICE
AccessType* get() const {
if (Gather) {
assert(indices_);
if (!valid()) {
return nullptr;
}
LongIndex contiguous_offset = the_predicates.iteration_contiguous_ *
(ThreadMap::Delta::kContiguous *
sizeof_bits<Element>::value / 8) +
the_predicates.iteration_vector_;
int strided_index =
gather_offset_strided +
the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided;
LongIndex strided_offset = indices_[strided_index] *
LongIndex(params_.stride_) *
sizeof_bits<Element>::value / 8;
return reinterpret_cast<AccessType*>(pointer_ + contiguous_offset +
strided_offset);
}
return reinterpret_cast<AccessType*>(pointer_ +
the_predicates.iteration_contiguous_ *
(ThreadMap::Delta::kContiguous *
sizeof_bits<Element>::value) /
8) +
the_predicates.iteration_vector_;
}
/// Increment and return an instance to self.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast& operator++() {
the_predicates.operator++();
++the_predicates.iteration_vector_;
if (the_predicates.iteration_vector_ < kAccessesPerVector) {
return *this; // NOLINT
}
the_predicates.iteration_vector_ = 0;
++the_predicates.iteration_contiguous_;
if (the_predicates.iteration_contiguous_ <
ThreadMap::Iterations::kContiguous) {
return *this;
}
// Enter here only if (iteration_contiguous_ ==
// ThreadMap::Iteration::kContiguous)
the_predicates.iteration_contiguous_ = 0;
++the_predicates.iteration_strided_;
if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) {
if (!Gather) {
pointer_ += params_.inc_strided_;
}
return *this;
}
// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
// which means we enter the next tile.
the_predicates.iteration_strided_ = 0;
if (!Gather) {
// advance to next tile
pointer_ += params_.inc_next_;
// now return to start tile - if the iterator is subsequently advanced,
// this subtraction as well as the subsequent integer addition are both
// elided by the compiler.
pointer_ -= params_.inc_advance_;
}
return *this;
}
/// Increment and return an instance to self.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast operator++(int) {
PredicatedTileAccessIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); }
  /// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { the_predicates.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { the_predicates.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { the_predicates.get_mask(mask); } // NOLINT
/// Returns whether access is valid or not
CUTLASS_HOST_DEVICE
bool valid() const { return the_predicates.valid(); }
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major
/// data.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
typename AccessType_,
bool Gather>
class PredicatedTileAccessIteratorResidualLast<Shape_,
Element_,
layout::ColumnMajor,
AdvanceRank,
ThreadMap_,
AccessType_,
Gather> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may along advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
using Layout = layout::ColumnMajor;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using AccessType = AccessType_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast<
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
Element,
layout::PitchLinear,
(kAdvanceRank == 0 ? 0 : 1),
ThreadMap,
AccessType,
Gather>;
/// Predicate vector stores mask to guard accesses
using Mask = typename UnderlyingIterator::Mask;
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
/// Parameters object is precomputed state and is host-constructible
class Params {
private:
friend PredicatedTileAccessIteratorResidualLast;
/// Parameters object
typename UnderlyingIterator::Params params_;
public:
/// Default ctor
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: params_(layout::PitchLinear(layout.stride(0))){}; // NOLINT
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(typename UnderlyingIterator::Params::Base const& base) // NOLINT
: params_(base) {}
};
private:
//
// Data members
//
/// Underlying pitch-linear tile iterator
UnderlyingIterator iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
///< Precomputed parameters object
Params const& params,
///< Pointer to start of tensor
Pointer pointer,
///< Extent of tensor
TensorCoord extent,
///< ID of each participating thread
int thread_id,
///< Initial offset of threadblock
TensorCoord const& threadblock_offset,
int const* indices =
nullptr ///< gather/scatter indices, note no support for
///< gather/scatter at this specialization
)
: iterator_(params.params_,
pointer,
layout::PitchLinearCoord(extent.row(), extent.column()),
thread_id,
layout::PitchLinearCoord(threadblock_offset.row(),
threadblock_offset.column()),
indices) {}
/// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileAccessIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); }
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
iterator_.add_pointer_offset(pointer_offset);
}
/// Advances an iterator along logical dimensions of matrix in units of whole
/// tiles
CUTLASS_HOST_DEVICE
void add_tile_offset(TensorCoord const& tile_offset) {
iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
}
/// Returns a pointer
CUTLASS_HOST_DEVICE
AccessType* get() const {
return reinterpret_cast<AccessType*>(iterator_.get());
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast& operator++() {
++iterator_;
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast operator++(int) {
PredicatedTileAccessIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
  /// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { iterator_.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { iterator_.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { iterator_.get_mask(mask); } // NOLINT
/// Returns whether access is valid or not
CUTLASS_HOST_DEVICE
bool valid() { return iterator_.valid(); }
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major
/// data.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
typename AccessType_,
bool Gather>
class PredicatedTileAccessIteratorResidualLast<Shape_,
Element_,
layout::RowMajor,
AdvanceRank,
ThreadMap_,
AccessType_,
Gather> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may along advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
using Layout = layout::RowMajor;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using AccessType = AccessType_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast<
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
Element,
layout::PitchLinear,
(kAdvanceRank == 0 ? 1 : 0),
ThreadMap,
AccessType,
Gather>;
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
/// Predicate vector stores mask to guard accesses
using Mask = typename UnderlyingIterator::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
private:
friend PredicatedTileAccessIteratorResidualLast;
/// Parameters object
typename UnderlyingIterator::Params params_;
public:
/// Default ctor
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: params_(layout::PitchLinear(layout.stride(0))){}; // NOLINT
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(typename UnderlyingIterator::Params::Base const& base) // NOLINT
: params_(base) {}
};
private:
//
// Data members
//
/// Underlying pitch-linear tile iterator
UnderlyingIterator iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
///< Precomputed parameters object
Params const& params,
///< Pointer to start of tensor
Pointer pointer,
///< Extent of tensor
TensorCoord extent,
///< ID of each participating thread
int thread_id,
///< Initial offset of threadblock
TensorCoord const& threadblock_offset,
/// Gather indices
int const* indices = nullptr)
: iterator_(params.params_,
pointer,
layout::PitchLinearCoord(extent.column(), extent.row()),
thread_id,
layout::PitchLinearCoord(threadblock_offset.column(),
threadblock_offset.row()),
indices) {}
/// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileAccessIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); }
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
iterator_.add_pointer_offset(pointer_offset);
}
/// Advances an iterator along logical dimensions of matrix in units of whole
/// tiles
CUTLASS_HOST_DEVICE
void add_tile_offset(TensorCoord const& tile_offset) {
iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
}
/// Returns a pointer
CUTLASS_HOST_DEVICE
AccessType* get() const {
return reinterpret_cast<AccessType*>(iterator_.get());
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast& operator++() {
++iterator_;
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast operator++(int) {
PredicatedTileAccessIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
  /// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { iterator_.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { iterator_.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { iterator_.get_mask(mask); } // NOLINT
/// Returns whether access is valid or not
CUTLASS_HOST_DEVICE
bool valid() { return iterator_.valid(); }
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2
/// data.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
typename AccessType_>
class PredicatedTileAccessIteratorResidualLast<Shape_,
Element_,
layout::AffineRankN<2>,
AdvanceRank,
ThreadMap_,
AccessType_,
false> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may along advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
using Layout = layout::AffineRankN<2>;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using AccessType = AccessType_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
using UnderlyingPredicates =
PredicatedTileAccessIteratorPredicates<Shape,
Element,
layout::PitchLinear,
AdvanceRank,
ThreadMap,
AccessType>;
static int const kAccessesPerVector =
ThreadMap::kElementsPerAccess / AccessType::kElements;
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements),
"Vectors implied by the thread map must be divisible by the "
"access type.");
/// Predicate vector stores mask to guard accesses
using Mask = typename UnderlyingPredicates::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
public:
friend PredicatedTileAccessIteratorResidualLast;
private:
/// stride of pitch-linear layout (units of Element)
Coord<Layout::kStrideRank, Layout::LongIndex> stride_;
    /// amount (in bytes) to increment pointer to move to next access along
    /// contiguous dimension
    LongIndex inc_contiguous_;
    /// amount (in bytes) to increment pointer from first access of current
    /// contiguous dimension to first access of next one.
    LongIndex inc_strided_;
    /// amount (in bytes) to increment pointer from last access of current
    /// contiguous dimension to first access of next one.
    LongIndex inc_next_strided_;
    /// amount (in bytes) to increment pointer from last access to first access
    /// of next tile
    LongIndex inc_next_;
    /// amount (in bytes) to increment pointer from first access of current tile
    /// to first access of next tile
    LongIndex inc_advance_;
public:
// Default ctor
CUTLASS_HOST_DEVICE
    Params()
        : stride_(0),
          inc_contiguous_(0),
          inc_strided_(0),
          inc_next_strided_(0),
          inc_next_(0),
          inc_advance_(0) {}
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: stride_({layout.stride(0), layout.stride(1)}) {
inc_contiguous_ =
(LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) *
sizeof_bits<Element>::value / 8;
inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) *
sizeof_bits<Element>::value / 8;
inc_next_strided_ =
inc_strided_ -
LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_;
if (kAdvanceRank) {
// advance along strided dimension
inc_advance_ = Shape::kStrided * LongIndex(stride_[1]) *
sizeof_bits<Element>::value / 8;
} else {
// advance along contiguous dimension
inc_advance_ =
Shape::kContiguous * stride_[0] * sizeof_bits<Element>::value / 8;
}
inc_next_ =
inc_advance_ -
LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ -
LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_;
}; // NOLINT
};
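  // Worked example of the increments computed above (hypothetical values,
  // for illustration only): with a 32-bit Element, stride_ = {1, 128},
  // ThreadMap::Delta::kContiguous = 8 and ThreadMap::Delta::kStrided = 4,
  //   inc_contiguous_ = (1 * 8)   * 32 / 8 = 32   bytes
  //   inc_strided_    = (128 * 4) * 32 / 8 = 2048 bytes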
private:
/// Internal pointer type permits fast address arithmetic
using BytePointer = char*;
//
// Data members
//
/// Parameters object with precomputed internal state
Params const& params_;
/// Internal pointer to first access of tile
BytePointer pointer_;
UnderlyingPredicates the_predicates;
Mask residual_tile_mask;
private:
/// Computes predicates based on internally tracked per-thread offset.
CUTLASS_DEVICE
void compute_predicates_(
/// Extent of the matrix window
TensorCoord extent,
/// optionally, simplify predicate calculation during 'steady state' phase
bool is_steady_state = false) {
the_predicates.compute_predicates_(extent, is_steady_state);
}
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
///< Precomputed parameters object
Params const& params,
///< Pointer to start of tensor
Pointer pointer,
///< Extent of tensor
TensorCoord extent,
///< ID of each participating thread
int thread_id,
///< Initial offset of threadblock
TensorCoord const& threadblock_offset,
int const* indices =
nullptr ///< gather/scatter indices, note no support for
///< gather/scatter at this specialization
)
: params_(params),
pointer_(reinterpret_cast<BytePointer>(
const_cast<NonConstPointer>(pointer))),
the_predicates(extent) {
the_predicates.set_predicates(thread_id, threadblock_offset);
// update internal pointers
Layout layout(params_.stride_);
add_pointer_offset(layout(the_predicates.thread_offset_));
}
/// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileAccessIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(int index) {
the_predicates.set_iteration_index(index);
}
CUTLASS_HOST_DEVICE
void set_residual_tile(bool is_residual_tile) {
if (is_residual_tile) {
the_predicates.set_mask(residual_tile_mask);
}
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
pointer_ += sizeof_bits<Element>::value * pointer_offset / 8;
}
/// Advances an iterator along logical dimensions of matrix in units of whole
/// tiles
CUTLASS_HOST_DEVICE
void add_tile_offset(TensorCoord const& tile_offset) {
if (kAdvanceRank) {
pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]);
pointer_ += Shape::kContiguous * tile_offset[0];
} else {
pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]);
pointer_ += Shape::kStrided * tile_offset[1];
}
}
/// Returns a pointer
CUTLASS_HOST_DEVICE
AccessType* get() const {
return reinterpret_cast<AccessType*>(pointer_) +
the_predicates.iteration_vector_;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast& operator++() {
the_predicates.operator++();
++the_predicates.iteration_vector_;
if (the_predicates.iteration_vector_ < kAccessesPerVector) {
return *this;
}
the_predicates.iteration_vector_ = 0;
++the_predicates.iteration_contiguous_;
if (the_predicates.iteration_contiguous_ <
ThreadMap::Iterations::kContiguous) {
pointer_ += params_.inc_contiguous_;
return *this;
}
// Enter here only if (iteration_contiguous_ ==
// ThreadMap::Iteration::kContiguous)
the_predicates.iteration_contiguous_ = 0;
++the_predicates.iteration_strided_;
if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) {
pointer_ += params_.inc_next_strided_;
return *this;
}
// Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided)
// which means we enter the next tile.
the_predicates.iteration_strided_ = 0;
// advance to next tile
pointer_ += params_.inc_next_;
// now return to start tile - if the iterator is subsequently advanced, this
// subtraction as well as the subsequent integer addition are both elided by
// the compiler.
pointer_ -= params_.inc_advance_;
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast operator++(int) {
PredicatedTileAccessIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { the_predicates.clear_mask(enable); }
  /// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { the_predicates.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { the_predicates.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { the_predicates.get_mask(mask); } // NOLINT
/// Returns whether access is valid or not
CUTLASS_HOST_DEVICE
bool valid() { return the_predicates.valid(); }
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2
/// column-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
typename AccessType_>
class PredicatedTileAccessIteratorResidualLast<Shape_,
Element_,
layout::AffineRank2ColumnMajor,
AdvanceRank,
ThreadMap_,
AccessType_,
false> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may along advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
using Layout = layout::AffineRank2ColumnMajor;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using AccessType = AccessType_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
// Map to the underlying AffineRankN<2> layout
using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast<
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
Element,
layout::AffineRankN<2>,
(kAdvanceRank == 0 ? 0 : 1),
ThreadMap,
AccessType>;
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
/// Predicate vector stores mask to guard accesses
using Mask = typename UnderlyingIterator::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
private:
friend PredicatedTileAccessIteratorResidualLast;
/// Parameters object
typename UnderlyingIterator::Params params_;
public:
/// Default ctor
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given an AffineRankN<2> tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: params_(layout::AffineRankN<2>(layout.stride(0),
layout.stride(1))){}; // NOLINT
};
private:
//
// Data members
//
/// Underlying AffineRankN<2> tile iterator
UnderlyingIterator iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
///< Precomputed parameters object
Params const& params,
///< Pointer to start of tensor
Pointer pointer,
///< Extent of tensor
TensorCoord extent,
///< ID of each participating thread
int thread_id,
///< Initial offset of threadblock
TensorCoord const& threadblock_offset,
int const* indices =
nullptr ///< gather/scatter indices, note no support for
///< gather/scatter at this specialization
)
: iterator_(params.params_,
pointer,
layout::PitchLinearCoord(extent.row(), extent.column()),
thread_id,
layout::PitchLinearCoord(threadblock_offset.row(),
threadblock_offset.column())) {}
/// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileAccessIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); }
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
iterator_.add_pointer_offset(pointer_offset);
}
/// Advances an iterator along logical dimensions of matrix in units of whole
/// tiles
CUTLASS_HOST_DEVICE
void add_tile_offset(TensorCoord const& tile_offset) {
iterator_.add_tile_offset(
make_Coord(tile_offset.row(), tile_offset.column()));
}
/// Returns a pointer
CUTLASS_HOST_DEVICE
AccessType* get() const {
return reinterpret_cast<AccessType*>(iterator_.get());
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast& operator++() {
++iterator_;
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast operator++(int) {
PredicatedTileAccessIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
  /// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { iterator_.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { iterator_.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { iterator_.get_mask(mask); } // NOLINT
/// Returns whether access is valid or not
CUTLASS_HOST_DEVICE
bool valid() { return iterator_.valid(); }
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank-2
/// row-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
typename AccessType_>
class PredicatedTileAccessIteratorResidualLast<Shape_,
Element_,
layout::AffineRank2RowMajor,
AdvanceRank,
ThreadMap_,
AccessType_,
false> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may along advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
using Layout = layout::AffineRank2RowMajor;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using AccessType = AccessType_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
// Map to the underlying AffineRankN<2> layout
using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast<
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
Element,
layout::AffineRankN<2>,
(kAdvanceRank == 0 ? 1 : 0),
ThreadMap,
AccessType>;
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
/// Predicate vector stores mask to guard accesses
using Mask = typename UnderlyingIterator::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
private:
friend PredicatedTileAccessIteratorResidualLast;
/// Parameters object
typename UnderlyingIterator::Params params_;
public:
/// Default ctor
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given an AffineRankN<2> tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: params_(layout::AffineRankN<2>(layout.stride(1),
layout.stride(0))){}; // NOLINT
};
private:
//
// Data members
//
/// Underlying AffineRankN<2> tile iterator
UnderlyingIterator iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
///< Precomputed parameters object
Params const& params,
///< Pointer to start of tensor
Pointer pointer,
///< Extent of tensor
TensorCoord extent,
///< ID of each participating thread
int thread_id,
///< Initial offset of threadblock
TensorCoord const& threadblock_offset,
int const* indices =
nullptr ///< gather/scatter indices, note no support for
///< gather/scatter at this specialization
)
: iterator_(params.params_,
pointer,
layout::PitchLinearCoord(extent.column(), extent.row()),
thread_id,
layout::PitchLinearCoord(threadblock_offset.column(),
threadblock_offset.row())) {}
/// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileAccessIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); }
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
iterator_.add_pointer_offset(pointer_offset);
}
/// Advances an iterator along logical dimensions of matrix in units of whole
/// tiles
CUTLASS_HOST_DEVICE
void add_tile_offset(TensorCoord const& tile_offset) {
iterator_.add_tile_offset(
make_Coord(tile_offset.column(), tile_offset.row()));
}
/// Returns a pointer
CUTLASS_HOST_DEVICE
AccessType* get() const {
return reinterpret_cast<AccessType*>(iterator_.get());
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast& operator++() {
++iterator_;
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast operator++(int) {
PredicatedTileAccessIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
  /// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { iterator_.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { iterator_.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { iterator_.get_mask(mask); } // NOLINT
/// Returns whether access is valid or not
CUTLASS_HOST_DEVICE
bool valid() { return iterator_.valid(); }
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major
/// interleaved data. It is mapped to the congruous layout.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
typename AccessType_,
int InterleavedK>
class PredicatedTileAccessIteratorResidualLast<
Shape_,
Element_,
layout::ColumnMajorInterleaved<InterleavedK>,
AdvanceRank,
ThreadMap_,
AccessType_,
false> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may along advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
static int const kInterleavedK = InterleavedK;
using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using AccessType = AccessType_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast<
layout::PitchLinearShape<Shape::kRow * kInterleavedK,
Shape::kColumn / kInterleavedK>,
Element,
layout::PitchLinear,
(kAdvanceRank == 0 ? 0 : 1),
ThreadMap,
AccessType>;
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
/// Predicate vector stores mask to guard accesses
using Mask = typename UnderlyingIterator::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
private:
friend PredicatedTileAccessIteratorResidualLast;
/// Parameters object
typename UnderlyingIterator::Params params_;
public:
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: params_(layout::PitchLinear(layout.stride(0))) {}
CUTLASS_HOST_DEVICE
Params(typename UnderlyingIterator::Params::Base const& base) // NOLINT
: params_(base) {}
};
private:
//
// Data members
//
/// Underlying pitch-linear tile iterator
UnderlyingIterator iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
/// Precomputed parameters object
Params const& params,
/// Pointer to start of tensor
Pointer pointer,
/// Extent of tensor
TensorCoord extent,
/// ID of each participating thread
int thread_id,
/// Initial offset of threadblock
TensorCoord const& threadblock_offset,
int const* indices =
nullptr ///< gather/scatter indices, note no support for
///< gather/scatter at this specialization
)
: iterator_(params.params_,
pointer,
layout::PitchLinearCoord(extent.row() * kInterleavedK,
extent.column() / kInterleavedK),
thread_id,
layout::PitchLinearCoord(
threadblock_offset.row() * kInterleavedK,
threadblock_offset.column() / kInterleavedK)) {}
/// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileAccessIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); }
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
iterator_.add_pointer_offset(pointer_offset);
}
/// Advances an iterator along logical dimensions of matrix in units of whole
/// tiles
CUTLASS_HOST_DEVICE
void add_tile_offset(TensorCoord const& tile_offset) {
iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()});
}
/// Returns a pointer
CUTLASS_HOST_DEVICE
AccessType* get() const {
return reinterpret_cast<AccessType*>(iterator_.get());
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast& operator++() {
++iterator_;
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast operator++(int) {
PredicatedTileAccessIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
  /// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { iterator_.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { iterator_.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { iterator_.get_mask(mask); } // NOLINT
/// Returns whether access is valid or not
CUTLASS_HOST_DEVICE
bool valid() { return iterator_.valid(); }
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major
/// interleaved data.
/// It is mapped to the congruous layout.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
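///
/// As a worked sketch of the mapping (InterleavedK = 32 and an extent of
/// 128 rows x 64 columns are assumed example values): the constructor below
/// forwards to the underlying pitch-linear iterator with
///
///   contiguous extent = extent.column() * kInterleavedK =  64 * 32 = 2048
///   strided extent    = extent.row()    / kInterleavedK = 128 / 32 = 4
///
/// so the pitch-linear predicate machinery is reused unchanged.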
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
typename AccessType_,
int InterleavedK>
class PredicatedTileAccessIteratorResidualLast<
Shape_,
Element_,
layout::RowMajorInterleaved<InterleavedK>,
AdvanceRank,
ThreadMap_,
AccessType_,
false> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may along advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
static int const kInterleavedK = InterleavedK;
using Layout = layout::RowMajorInterleaved<kInterleavedK>;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using AccessType = AccessType_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast<
layout::PitchLinearShape<Shape::kColumn * kInterleavedK,
Shape::kRow / kInterleavedK>,
Element,
layout::PitchLinear,
(kAdvanceRank == 0 ? 1 : 0),
ThreadMap,
AccessType>;
static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector;
/// Predicate vector stores mask to guard accesses
using Mask = typename UnderlyingIterator::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
private:
friend PredicatedTileAccessIteratorResidualLast;
/// Parameters object
typename UnderlyingIterator::Params params_;
public:
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: params_(layout::PitchLinear(layout.stride(0))) {}
CUTLASS_HOST_DEVICE
Params(typename UnderlyingIterator::Params::Base const& base) // NOLINT
: params_(base) {}
};
private:
//
// Data members
//
/// Underlying pitch-linear tile iterator
UnderlyingIterator iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
/// Precomputed parameters object
Params const& params,
/// Pointer to start of tensor
Pointer pointer,
/// Extent of tensor
TensorCoord extent,
/// ID of each participating thread
int thread_id,
/// Initial offset of threadblock
TensorCoord const& threadblock_offset,
int const* indices =
nullptr ///< gather/scatter indices, note no support for
///< gather/scatter at this specialization
)
: iterator_(params.params_,
pointer,
layout::PitchLinearCoord(extent.column() * kInterleavedK,
extent.row() / kInterleavedK),
thread_id,
layout::PitchLinearCoord(
threadblock_offset.column() * kInterleavedK,
threadblock_offset.row() / kInterleavedK)) {}
/// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileAccessIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Overrides the internal iteration index
CUTLASS_HOST_DEVICE
void set_iteration_index(int index) { iterator_.set_iteration_index(index); }
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); }
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
iterator_.add_pointer_offset(pointer_offset);
}
/// Advances an iterator along logical dimensions of matrix in units of whole
/// tiles
CUTLASS_HOST_DEVICE
void add_tile_offset(TensorCoord const& tile_offset) {
iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()});
}
/// Returns a pointer
CUTLASS_HOST_DEVICE
AccessType* get() const {
return reinterpret_cast<AccessType*>(iterator_.get());
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast& operator++() {
++iterator_;
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileAccessIteratorResidualLast operator++(int) {
PredicatedTileAccessIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
  /// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { iterator_.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { iterator_.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { iterator_.get_mask(mask); } // NOLINT
/// Returns whether access is valid or not
CUTLASS_HOST_DEVICE
bool valid() { return iterator_.valid(); }
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace transform
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing loading of tiles from pitch-linear rank=2
tensors.
This iterator uses masks to guard out-of-bounds accesses. The first tile
   this iterator visits may be partial; the remaining tiles are complete, so
   the predicates only need to be computed twice: once before the first tile
   and once for the remaining full tiles, which can share the same predicates.
A precomputed "Params" object minimizes the amount of state that must be
stored in registers, and integer addition is used to advance the pointer
through memory.
*/
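//
// Illustrative sketch (the numbers are assumptions, not taken from any kernel
// in this file): with a tile covering 128 elements of the strided dimension
// and a tensor extent of 300 elements, the iteration visits ceil(300 / 128) =
// 3 tiles. Only the partial tile (300 - 2 * 128 = 44 valid elements) needs its
// own guard predicates; the two full tiles share a second, fully enabled mask,
// so the predicates are computed exactly twice, as described above.
//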
#pragma once
#include "cutlass/arch/memory.h"
#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace transform {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// PredicatedTileIteratorResidualLast
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
/// Regular tile iterator using a precomputed control structure to minimize
/// register liveness and integer arithmetic.
///
/// Layout is assumed to be invariant at the time the precomputed "Params"
/// object is constructed.
///
/// Base pointer and tensor extents may be specified at the time the iterator is
/// constructed. Subsequently, they are assumed to be immutable.
///
/// Adding a logical coordinate offset may be performed at the time the iterator
/// is constructed. Subsequent additions to logical coordinate offset may be
/// performed but are relatively expensive.
///
/// Visitation order is intended to first visit a "residual" tile that may be
/// partially full in both the advance dimension and the steady-state dimension.
/// This is assumed to be the last tile in the iteration sequence. Advancing an
/// iterator that has just been constructed moves to the first tile that is full
/// in the advance dimension and recomputes predicates. Subsequent accesses may
/// be performed without updating internal predicates and are efficient in terms
/// of live register state and pointer arithmetic instructions.
///
/// To be efficient, this assumes the iterator will be dereferenced and advanced
/// at least once outside any looping structure to minimize integer arithmetic.
///
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to
/// dereferencing the iterator.
///
///
/// Example:
///
/// An efficient pipeline structure may be constructed as follows:
///
// template <typename Iterator>
// __global__ void kernel(
// typename Iterator::Params params,
// typename Iterator::Element *ptr,
// TensorCoord extent) {
//
// typename Iterator::Fragment fragment;
//
// TensorCoord threadblock_offset(0, 0);
//
//     Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offset);
//
//
// fragment = *iter; // load "residue" tile first
// ++iter; // advance to first "steady state" tile and update
// internal masks
//
//
// #pragma unroll
// for (int i = Remaining - 1; i >= 0; --i) {
//
// f(fragment);
//
// if (!i) {
// iter.clear_mask(); // light-weight operation to clear masks -
// subsequent loads become NO-OPs.
// }
//
// fragment = *iter; // load tile during "steady state" phase
// ++iter; // advance to next tile - lightweight due to
// steady-state masks
// }
// }
//
// void host(TensorView<Element, 2, layout::PitchLinear> view) {
//
// using Iterator =
// transform::threadblock::PredicatedTileIteratorResidualLast;
//
// typename Iterator::Params params(view.layout());
//
//     kernel<Iterator>(params, view.data(), view.extent());
// }
///
///
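///
/// For the "residual last" variants defined in this file, one possible usage
/// pattern (a sketch only; the exact bookkeeping depends on the consuming
/// mainloop, and `k_iterations` / `f` are hypothetical names) defers the
/// partial tile to the final loop iteration:
///
//   typename Iterator::Fragment fragment;
//
//   iter.set_residual_tile(k_iterations == 1); // residual predicates apply
//                                              // only to the last load
//   iter.clear_mask(k_iterations == 0);        // nothing to load at all
//
//   for (; k_iterations > 0; --k_iterations) {
//     iter.load(fragment);                       // guarded, possibly partial
//     ++iter;                                    // advance to the next tile
//     iter.set_residual_tile(k_iterations == 2); // next load is the residual
//     iter.clear_mask(k_iterations == 1);        // or past the end: disable
//     f(fragment);
//   }
///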
template <typename Shape,
typename Element,
typename Layout,
int AdvanceRank,
typename ThreadMap,
int AccessSize = ThreadMap::kElementsPerAccess,
bool Gather = false>
class PredicatedTileIteratorResidualLast;
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
int AccessSize,
bool Gather>
class PredicatedTileIteratorResidualLast<Shape_,
Element_,
layout::PitchLinear,
AdvanceRank,
ThreadMap_,
AccessSize,
Gather> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
using Layout = layout::PitchLinear;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
/// Type used for internal memory accesses
using AccessType =
AlignedArray<Element,
AccessSize,
(AccessSize * sizeof_bits<Element>::value / 8)>;
/// Underlying iterator to compute the addresses
using TileAccessIterator =
PredicatedTileAccessIteratorResidualLast<Shape,
Element,
Layout,
kAdvanceRank,
ThreadMap,
AccessType,
Gather>;
static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
/// Fragment object to be loaded or stored
using Fragment = cutlass::Array<Element,
ThreadMap::Iterations::kCount *
ThreadMap::kElementsPerAccess>;
/// Predicate vector stores mask to guard accesses
using Mask = typename TileAccessIterator::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
public:
using Base = typename TileAccessIterator::Params::Base;
friend PredicatedTileIteratorResidualLast;
private:
/// Parameters object
typename TileAccessIterator::Params params_;
public:
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) : params_(layout) {} // NOLINT
CUTLASS_HOST_DEVICE
Params() {}
CUTLASS_HOST_DEVICE
Params(Base const& base) : params_(base) {} // NOLINT
};
private:
/// Internal pointer type permits fast address arithmetic
using BytePointer = char*;
private:
//
// Data members
//
/// Data member to the tile access iterator
TileAccessIterator address_iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
/// Precomputed parameters object
Params const& params,
/// Pointer to start of tensor
Pointer pointer,
/// Extent of tensor
TensorCoord extent,
/// ID of each participating thread
int thread_id,
/// Initial offset of threadblock
TensorCoord const& threadblock_offset,
/// Gather indices
int const* indices = nullptr)
: address_iterator_(params.params_,
pointer,
extent,
thread_id,
threadblock_offset,
indices) {}
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
address_iterator_.add_pointer_offset(pointer_offset);
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast& operator++() {
if (kAdvanceRank)
address_iterator_.add_tile_offset({0, 1});
else
address_iterator_.add_tile_offset({1, 0});
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast operator++(int) {
PredicatedTileIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) {
address_iterator_.set_residual_tile(enable);
}
  /// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { address_iterator_.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { address_iterator_.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { address_iterator_.get_mask(mask); } // NOLINT
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment& frag, // NOLINT
Index pointer_offset) { // NOLINT
load_with_byte_offset(frag,
pointer_offset * sizeof_bits<Element>::value / 8);
}
CUTLASS_DEVICE
void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { // NOLINT
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
int idx = v + kAccessesPerVector *
(c + s * ThreadMap::Iterations::kContiguous);
address_iterator_.set_iteration_index(idx);
char const* byte_ptr =
reinterpret_cast<char const*>(address_iterator_.get()) +
byte_offset;
AccessType const* access_ptr = // NOLINT
reinterpret_cast<AccessType const*>(byte_ptr);
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
frag_ptr[idx], access_ptr, address_iterator_.valid());
++address_iterator_;
}
}
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment& frag) { load_with_byte_offset(frag, 0); } // NOLINT
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
store_with_byte_offset(frag,
pointer_offset * sizeof_bits<Element>::value / 8);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
address_iterator_.set_iteration_index(0);
AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
int idx = v + kAccessesPerVector *
(c + s * ThreadMap::Iterations::kContiguous);
char* byte_ptr =
reinterpret_cast<char*>(address_iterator_.get()) + byte_offset;
AccessType* access_ptr = reinterpret_cast<AccessType*>(byte_ptr);
if (address_iterator_.valid()) {
*access_ptr = frag_ptr[idx];
}
++address_iterator_;
}
}
}
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store(Fragment const& frag) { store_with_byte_offset(frag, 0); }
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileIteratorResidualLast for column-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
int AccessSize,
bool Gather>
class PredicatedTileIteratorResidualLast<Shape_,
Element_,
layout::ColumnMajor,
AdvanceRank,
ThreadMap_,
AccessSize,
Gather> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may along advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
using Layout = layout::ColumnMajor;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
using UnderlyingIterator = PredicatedTileIteratorResidualLast<
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
Element,
layout::PitchLinear,
(kAdvanceRank == 0 ? 0 : 1),
ThreadMap,
AccessSize,
Gather>;
using AccessType = typename UnderlyingIterator::AccessType;
/// Fragment object to be loaded or stored
using Fragment = cutlass::Array<Element,
ThreadMap::Iterations::kCount *
ThreadMap::kElementsPerAccess>;
/// Predicate vector stores mask to guard accesses
using Mask = typename UnderlyingIterator::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
private:
friend PredicatedTileIteratorResidualLast;
/// Parameters object
typename UnderlyingIterator::Params params_;
public:
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: params_(layout::PitchLinear(layout.stride(0))) {}
CUTLASS_HOST_DEVICE
Params(typename UnderlyingIterator::Params::Base const& base) // NOLINT
: params_(base) {}
};
private:
//
// Data members
//
/// Underlying pitch-linear tile iterator
UnderlyingIterator iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id, ///< ID of each participating thread
TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
int const* indices =
nullptr ///< gather/scatter indices, note no support for
///< gather/scatter at this specialization
)
: iterator_(params.params_,
pointer,
layout::PitchLinearCoord(extent.row(), extent.column()),
thread_id,
layout::PitchLinearCoord(threadblock_offset.row(),
threadblock_offset.column()),
indices) {}
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
iterator_.add_pointer_offset(pointer_offset);
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast& operator++() {
++iterator_;
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast operator++(int) {
PredicatedTileIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); }
  /// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { iterator_.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { iterator_.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { iterator_.get_mask(mask); } // NOLINT
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment& frag, // NOLINT
Index pointer_offset) { // NOLINT
iterator_.load_with_pointer_offset(frag, pointer_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { // NOLINT
iterator_.load_with_byte_offset(frag, byte_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } // NOLINT
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
iterator_.store_with_pointer_offset(frag, pointer_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
iterator_.store_with_byte_offset(frag, byte_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); }
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileIteratorResidualLast for row-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
int AccessSize,
bool Gather>
class PredicatedTileIteratorResidualLast<Shape_,
Element_,
layout::RowMajor,
AdvanceRank,
ThreadMap_,
AccessSize,
Gather> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may along advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
using Layout = layout::RowMajor;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
using UnderlyingIterator = PredicatedTileIteratorResidualLast<
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
Element,
layout::PitchLinear,
(kAdvanceRank == 0 ? 1 : 0),
ThreadMap,
AccessSize,
Gather>;
using AccessType = typename UnderlyingIterator::AccessType;
/// Fragment object to be loaded or stored
using Fragment = cutlass::Array<Element,
ThreadMap::Iterations::kCount *
ThreadMap::kElementsPerAccess>;
/// Predicate vector stores mask to guard accesses
using Mask = typename UnderlyingIterator::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
private:
friend PredicatedTileIteratorResidualLast;
/// Parameters object
typename UnderlyingIterator::Params params_;
public:
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: params_(layout::PitchLinear(layout.stride(0))) {}
CUTLASS_HOST_DEVICE
Params(typename UnderlyingIterator::Params::Base const& base) // NOLINT
: params_(base) {}
};
private:
//
// Data members
//
/// Underlying pitch-linear tile iterator
UnderlyingIterator iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id, ///< ID of each participating thread
TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
int const* indices = nullptr ///< Gather indices
)
: iterator_(params.params_,
pointer,
layout::PitchLinearCoord(extent.column(), extent.row()),
thread_id,
layout::PitchLinearCoord(threadblock_offset.column(),
threadblock_offset.row()),
indices) {}
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
iterator_.add_pointer_offset(pointer_offset);
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast& operator++() {
++iterator_;
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast operator++(int) {
PredicatedTileIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); }
  /// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { iterator_.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { iterator_.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { iterator_.get_mask(mask); } // NOLINT
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment& frag, // NOLINT
Index pointer_offset) { // NOLINT
iterator_.load_with_pointer_offset(frag, pointer_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { // NOLINT
iterator_.load_with_byte_offset(frag, byte_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } // NOLINT
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
iterator_.store_with_pointer_offset(frag, pointer_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
iterator_.store_with_byte_offset(frag, byte_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); }
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
int AccessSize>
class PredicatedTileIteratorResidualLast<Shape_,
Element_,
layout::AffineRankN<2>,
AdvanceRank,
ThreadMap_,
AccessSize,
false> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
using Layout = layout::AffineRankN<2>;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
/// Type used for internal memory accesses
using AccessType =
AlignedArray<Element,
AccessSize,
(AccessSize * sizeof_bits<Element>::value / 8)>;
/// Underlying iterator to compute the addresses
using TileAccessIterator =
PredicatedTileAccessIteratorResidualLast<Shape,
Element,
Layout,
kAdvanceRank,
ThreadMap,
AccessType>;
static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector;
/// Fragment object to be loaded or stored
using Fragment = cutlass::Array<Element,
ThreadMap::Iterations::kCount *
ThreadMap::kElementsPerAccess>;
/// Predicate vector stores mask to guard accesses
using Mask = typename TileAccessIterator::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
public:
friend PredicatedTileIteratorResidualLast;
private:
/// Parameters object
typename TileAccessIterator::Params params_;
public:
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) : params_(layout) {} // NOLINT
CUTLASS_HOST_DEVICE
Params() {}
};
private:
/// Internal pointer type permits fast address arithmetic
using BytePointer = char*;
private:
//
// Data members
//
/// Data member to the tile access iterator
TileAccessIterator address_iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
/// Precomputed parameters object
Params const& params,
/// Pointer to start of tensor
Pointer pointer,
/// Extent of tensor
TensorCoord extent,
/// ID of each participating thread
int thread_id,
/// Initial offset of threadblock
TensorCoord const& threadblock_offset,
int const* indices =
nullptr ///< gather/scatter indices, note no support for
///< gather/scatter at this specialization
)
: address_iterator_(
params.params_, pointer, extent, thread_id, threadblock_offset) {}
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
address_iterator_.add_pointer_offset(pointer_offset);
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast& operator++() {
if (kAdvanceRank)
address_iterator_.add_tile_offset(make_Coord(0, 1));
else
address_iterator_.add_tile_offset(make_Coord(1, 0));
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast operator++(int) {
PredicatedTileIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { address_iterator_.clear_mask(enable); }
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) {
address_iterator_.set_residual_tile(enable);
}
  /// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { address_iterator_.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { address_iterator_.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { address_iterator_.get_mask(mask); } // NOLINT
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment& frag, // NOLINT
Index pointer_offset) { // NOLINT
load_with_byte_offset(frag,
pointer_offset * sizeof_bits<Element>::value / 8);
}
CUTLASS_DEVICE
void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { // NOLINT
AccessType* frag_ptr = reinterpret_cast<AccessType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
int idx = v + kAccessesPerVector *
(c + s * ThreadMap::Iterations::kContiguous);
address_iterator_.set_iteration_index(idx);
char const* byte_ptr =
reinterpret_cast<char const*>(address_iterator_.get()) +
byte_offset;
AccessType const* access_ptr =
reinterpret_cast<AccessType const*>(byte_ptr);
cutlass::arch::global_load<AccessType, sizeof(AccessType)>(
frag_ptr[idx], access_ptr, address_iterator_.valid());
++address_iterator_;
}
}
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment& frag) { load_with_byte_offset(frag, 0); } // NOLINT
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
store_with_byte_offset(frag,
pointer_offset * sizeof_bits<Element>::value / 8);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
address_iterator_.set_iteration_index(0);
AccessType const* frag_ptr = reinterpret_cast<AccessType const*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) {
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) {
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kAccessesPerVector; ++v) {
int idx = v + kAccessesPerVector *
(c + s * ThreadMap::Iterations::kContiguous);
char* byte_ptr =
reinterpret_cast<char*>(address_iterator_.get()) + byte_offset;
AccessType* access_ptr = reinterpret_cast<AccessType*>(byte_ptr);
if (address_iterator_.valid()) {
*access_ptr = frag_ptr[idx];
}
++address_iterator_;
}
}
}
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store(Fragment const& frag) { store_with_byte_offset(frag, 0); }
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2
/// column-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
int AccessSize>
class PredicatedTileIteratorResidualLast<Shape_,
Element_,
layout::AffineRank2ColumnMajor,
AdvanceRank,
ThreadMap_,
AccessSize,
false> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may along advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
using Layout = layout::AffineRank2ColumnMajor;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
// Map to the underlying AffineRankN<2> layout
using UnderlyingIterator = PredicatedTileIteratorResidualLast<
layout::PitchLinearShape<Shape::kRow, Shape::kColumn>,
Element,
layout::AffineRankN<2>,
(kAdvanceRank == 0 ? 0 : 1),
ThreadMap,
AccessSize>;
using AccessType = typename UnderlyingIterator::AccessType;
/// Fragment object to be loaded or stored
using Fragment = cutlass::Array<Element,
ThreadMap::Iterations::kCount *
ThreadMap::kElementsPerAccess>;
/// Predicate vector stores mask to guard accesses
using Mask = typename UnderlyingIterator::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
private:
friend PredicatedTileIteratorResidualLast;
/// Parameters object
typename UnderlyingIterator::Params params_;
public:
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given an AffineRankN<2> tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {}
};
private:
//
// Data members
//
/// Underlying AffineRankN<2> tile iterator
UnderlyingIterator iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id, ///< ID of each participating thread
TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
int const* indices =
nullptr ///< gather/scatter indices, note no support for
///< gather/scatter at this specialization
)
: iterator_(params.params_,
pointer,
layout::PitchLinearCoord(extent.row(), extent.column()),
thread_id,
layout::PitchLinearCoord(threadblock_offset.row(),
threadblock_offset.column())) {}
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
iterator_.add_pointer_offset(pointer_offset);
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast& operator++() {
++iterator_;
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast operator++(int) {
PredicatedTileIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); }
  /// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { iterator_.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { iterator_.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { iterator_.get_mask(mask); } // NOLINT
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment& frag, // NOLINT
Index pointer_offset) { // NOLINT
iterator_.load_with_pointer_offset(frag, pointer_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { // NOLINT
iterator_.load_with_byte_offset(frag, byte_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } // NOLINT
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
iterator_.store_with_pointer_offset(frag, pointer_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
iterator_.store_with_byte_offset(frag, byte_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); }
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2
/// row-major data.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
int AccessSize>
class PredicatedTileIteratorResidualLast<Shape_,
Element_,
layout::AffineRank2RowMajor,
AdvanceRank,
ThreadMap_,
AccessSize,
false> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may along advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
using Layout = layout::AffineRank2RowMajor;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
// Map to the underlying AffineRankN<2> layout
using UnderlyingIterator = PredicatedTileIteratorResidualLast<
layout::PitchLinearShape<Shape::kColumn, Shape::kRow>,
Element,
layout::AffineRankN<2>,
(kAdvanceRank == 0 ? 1 : 0),
ThreadMap,
AccessSize>;
using AccessType = typename UnderlyingIterator::AccessType;
/// Fragment object to be loaded or stored
using Fragment = cutlass::Array<Element,
ThreadMap::Iterations::kCount *
ThreadMap::kElementsPerAccess>;
/// Predicate vector stores mask to guard accesses
using Mask = typename UnderlyingIterator::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
private:
friend PredicatedTileIteratorResidualLast;
/// Parameters object
typename UnderlyingIterator::Params params_;
public:
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given an AffineRankN<2> tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {}
};
private:
//
// Data members
//
/// Underlying AffineRankN<2> tile iterator
UnderlyingIterator iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id, ///< ID of each participating thread
TensorCoord const& threadblock_offset, ///< Initial offset of threadblock
int const* indices =
nullptr ///< gather/scatter indices, note no support for
///< gather/scatter at this specialization
)
: iterator_(params.params_,
pointer,
layout::PitchLinearCoord(extent.column(), extent.row()),
thread_id,
layout::PitchLinearCoord(threadblock_offset.column(),
threadblock_offset.row())) {}
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
iterator_.add_pointer_offset(pointer_offset);
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast& operator++() {
++iterator_;
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast operator++(int) {
PredicatedTileIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); }
  /// Enables the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { iterator_.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { iterator_.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { iterator_.get_mask(mask); } // NOLINT
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment& frag, // NOLINT
Index pointer_offset) { // NOLINT
iterator_.load_with_pointer_offset(frag, pointer_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { // NOLINT
iterator_.load_with_byte_offset(frag, byte_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } // NOLINT
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
iterator_.store_with_pointer_offset(frag, pointer_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) {
iterator_.store_with_byte_offset(frag, byte_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); }
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileIteratorResidualLast for column-major
/// interleaved data. It is mapped to the congruous layout.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
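///
/// As a worked sketch of the mapping (InterleavedK = 32 and an extent of
/// 64 rows x 128 columns are assumed example values): the constructor below
/// forwards to the underlying pitch-linear iterator with
///
///   contiguous extent = extent.row()    * kInterleavedK =  64 * 32 = 2048
///   strided extent    = extent.column() / kInterleavedK = 128 / 32 = 4
///
/// so loads and stores go through the pitch-linear specialization unchanged.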
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
int AccessSize,
int InterleavedK>
class PredicatedTileIteratorResidualLast<
Shape_,
Element_,
layout::ColumnMajorInterleaved<InterleavedK>,
AdvanceRank,
ThreadMap_,
AccessSize,
false> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may along advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
static int const kInterleavedK = InterleavedK;
using Layout = layout::ColumnMajorInterleaved<kInterleavedK>;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
using UnderlyingIterator = PredicatedTileIteratorResidualLast<
layout::PitchLinearShape<Shape::kRow * kInterleavedK,
Shape::kColumn / kInterleavedK>,
Element,
layout::PitchLinear,
(kAdvanceRank == 0 ? 0 : 1),
ThreadMap,
AccessSize>;
using AccessType = typename UnderlyingIterator::AccessType;
/// Fragment object to be loaded or stored
using Fragment = cutlass::Array<Element,
ThreadMap::Iterations::kCount *
ThreadMap::kElementsPerAccess>;
/// Predicate vector stores mask to guard accesses
using Mask = typename UnderlyingIterator::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
private:
friend PredicatedTileIteratorResidualLast;
/// Parameters object
typename UnderlyingIterator::Params params_;
public:
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: params_(layout::PitchLinear(layout.stride(0))) {}
CUTLASS_HOST_DEVICE
Params(typename UnderlyingIterator::Params::Base const& base) // NOLINT
: params_(base) {}
};
private:
//
// Data members
//
/// Underlying pitch-linear tile iterator
UnderlyingIterator iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
/// Precomputed parameters object
Params const& params,
/// Pointer to start of tensor
Pointer pointer,
/// Extent of tensor
TensorCoord extent,
/// ID of each participating thread
int thread_id,
/// Initial offset of threadblock
TensorCoord const& threadblock_offset,
int const* indices =
          nullptr  ///< gather/scatter indices; note that gather/scatter is
                   ///< not supported in this specialization
)
: iterator_(params.params_,
pointer,
layout::PitchLinearCoord(extent.row() * kInterleavedK,
extent.column() / kInterleavedK),
thread_id,
layout::PitchLinearCoord(
threadblock_offset.row() * kInterleavedK,
threadblock_offset.column() / kInterleavedK)) {}
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
iterator_.add_pointer_offset(pointer_offset);
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast& operator++() {
++iterator_;
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast operator++(int) {
PredicatedTileIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { iterator_.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { iterator_.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { iterator_.get_mask(mask); } // NOLINT
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment& frag, // NOLINT
Index pointer_offset) { // NOLINT
iterator_.load_with_pointer_offset(frag, pointer_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } // NOLINT
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
iterator_.store_with_pointer_offset(frag, pointer_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); }
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization of PredicatedTileIteratorResidualLast for interleaved-32
/// data. It is mapped to the congruous layout.
///
/// Satisfies: ForwardTileIteratorConcept |
/// ReadableContiguousTileIteratorConcept |
/// WriteableContiguousTileIteratorConcept |
/// MaskedTileIteratorConcept
///
template <typename Shape_,
typename Element_,
int AdvanceRank,
typename ThreadMap_,
int AccessSize,
int InterleavedK>
class PredicatedTileIteratorResidualLast<
Shape_,
Element_,
layout::RowMajorInterleaved<InterleavedK>,
AdvanceRank,
ThreadMap_,
AccessSize,
false> {
public:
static_assert(
AdvanceRank == 0 || AdvanceRank == 1,
"Specialization for pitch-linear iterator may along advance along the "
"contiguous(rank=0) or strided(rank=1) dimension.");
using Shape = Shape_;
using Element = Element_;
static int const kInterleavedK = InterleavedK;
using Layout = layout::RowMajorInterleaved<kInterleavedK>;
static int const kAdvanceRank = AdvanceRank;
using ThreadMap = ThreadMap_;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorRef = TensorRef<Element, Layout>;
using TensorView = TensorView<Element, Layout>;
using TensorCoord = typename Layout::TensorCoord;
using Pointer = Element*;
using NonConstPointer = typename platform::remove_const<Element>::type*;
using UnderlyingIterator = PredicatedTileIteratorResidualLast<
layout::PitchLinearShape<Shape::kColumn * kInterleavedK,
Shape::kRow / kInterleavedK>,
Element,
layout::PitchLinear,
(kAdvanceRank == 0 ? 1 : 0),
ThreadMap,
AccessSize>;
using AccessType = typename UnderlyingIterator::AccessType;
/// Fragment object to be loaded or stored
using Fragment = cutlass::Array<Element,
ThreadMap::Iterations::kCount *
ThreadMap::kElementsPerAccess>;
/// Predicate vector stores mask to guard accesses
using Mask = typename UnderlyingIterator::Mask;
/// Parameters object is precomputed state and is host-constructible
class Params {
private:
friend PredicatedTileIteratorResidualLast;
/// Parameters object
typename UnderlyingIterator::Params params_;
public:
CUTLASS_HOST_DEVICE
Params() {}
/// Construct the Params object given a pitch-linear tensor's layout
CUTLASS_HOST_DEVICE
Params(Layout const& layout) // NOLINT
: params_(layout::PitchLinear(layout.stride(0))) {}
CUTLASS_HOST_DEVICE
Params(typename UnderlyingIterator::Params::Base const& base) // NOLINT
: params_(base) {}
};
private:
//
// Data members
//
/// Underlying pitch-linear tile iterator
UnderlyingIterator iterator_;
public:
/// Constructs a TileIterator from its precomputed state, threadblock offset,
/// and thread ID
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
/// Precomputed parameters object
Params const& params,
/// Pointer to start of tensor
Pointer pointer,
/// Extent of tensor
TensorCoord extent,
/// ID of each participating thread
int thread_id,
/// Initial offset of threadblock
TensorCoord const& threadblock_offset,
int const* indices =
          nullptr  ///< gather/scatter indices; note that gather/scatter is
                   ///< not supported in this specialization
)
: iterator_(params.params_,
pointer,
layout::PitchLinearCoord(extent.column() * kInterleavedK,
extent.row() / kInterleavedK),
thread_id,
layout::PitchLinearCoord(
threadblock_offset.column() * kInterleavedK,
threadblock_offset.row() / kInterleavedK)) {}
/// Construct a PredicatedTileIteratorResidualLast with zero threadblock
/// offset
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast(
Params const& params, ///< Precomputed parameters object
Pointer pointer, ///< Pointer to start of tensor
TensorCoord extent, ///< Extent of tensor
int thread_id ///< ID of each participating thread
)
: PredicatedTileIteratorResidualLast(
params, pointer, extent, thread_id, make_Coord(0, 0)) {}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
iterator_.add_pointer_offset(pointer_offset);
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast& operator++() {
++iterator_;
return *this;
}
/// Advances to the next tile in memory.
///
/// The first time this method is called, predicates are updated, and the
/// iterator's internal pointer is reverted to the first "steady state" tile.
/// Subsequent calls are lightweight and must only update the internal
/// pointer.
CUTLASS_HOST_DEVICE
PredicatedTileIteratorResidualLast operator++(int) {
PredicatedTileIteratorResidualLast self(*this);
operator++();
return self;
}
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void clear_mask(bool enable = true) { iterator_.clear_mask(enable); }
CUTLASS_HOST_DEVICE
void set_residual_tile(bool enable) { iterator_.set_residual_tile(enable); }
/// Clears the predicate set efficiently
CUTLASS_HOST_DEVICE
void enable_mask() { iterator_.enable_mask(); }
/// Sets the predicate mask, overriding value stored in predicate iterator
CUTLASS_HOST_DEVICE
void set_mask(Mask const& mask) { iterator_.set_mask(mask); }
/// Gets the mask
CUTLASS_HOST_DEVICE
void get_mask(Mask& mask) { iterator_.get_mask(mask); } // NOLINT
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment& frag, // NOLINT
Index pointer_offset) { // NOLINT
iterator_.load_with_pointer_offset(frag, pointer_offset);
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment& frag) { load_with_pointer_offset(frag, 0); } // NOLINT
/// Store a fragment to memory
CUTLASS_DEVICE
void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) {
iterator_.store_with_pointer_offset(frag, pointer_offset);
}
/// Store a fragment to memory
CUTLASS_DEVICE
void store(Fragment const& frag) { store_with_pointer_offset(frag, 0); }
};
////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace transform
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include "./warp_iterator_from_smem.h"
template <typename WarpIterator>
struct TransposeWarpIterator {
using Iterator = char;
static bool constexpr kSupportsTranspose = false;
};
template <
/// Operand identity
cutlass::gemm::Operand Operand,
/// Data type of A elements
typename Element,
bool kTranspose>
struct TransposeWarpIterator<
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, kTranspose>> {
using Iterator =
cutlass::gemm::warp::WarpIteratorFromSmem<Operand, Element, !kTranspose>;
static bool constexpr kSupportsTranspose = true;
};
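/*
  Illustrative usage (a sketch, not part of the build): given a warp iterator
  that reads operand A from shared memory, the transposed variant and whether
  transposition is supported can be queried through this trait:

    using IterA = cutlass::gemm::warp::WarpIteratorFromSmem<
        cutlass::gemm::Operand::kA, cutlass::half_t, false>;   // kTranspose = false
    using IterAT = typename TransposeWarpIterator<IterA>::Iterator;  // kTranspose flipped
    static_assert(TransposeWarpIterator<IterA>::kSupportsTranspose, "");

  For iterator types without a specialization, the primary template reports
  kSupportsTranspose == false and exposes a dummy `char` Iterator.
*/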
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
  \brief Inspired by "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h".
  Loads tiles of GEMM operands from a RowMajor shared-memory layout into
  registers for use by A100 TensorCores.
  The differences with "mma_tensor_op_tile_access_iterator.h" are that:
  (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly
      faster)
  (2) We support transposing the operand (eg reading `A.transpose()` when the
      shared memory holds `A`)
  This is only implemented for the specific shapes that are of interest to us.
*/
#pragma once
#include <cutlass/gemm/gemm.h>
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
template <
/// Operand identity
Operand Operand_,
/// Data type of A elements
typename Element_,
bool kTranspose = false>
class WarpIteratorFromSmem {
public:
/// Shape of tile to load (concept: MatrixShape)
using Shape = cutlass::MatrixShape<32, 32>;
/// Operand tag
static Operand const kOperand = Operand_;
/// Basic check
static_assert(kOperand == Operand::kA || kOperand == Operand::kB,
"WarpIteratorFromSmem may only be instantiated for A or B "
"operands to warp-level Mma.");
/// Element type
using Element = Element_;
static_assert(sizeof_bits<Element>::value == 16, "Only supported for half");
/// Layout of source tile
using Layout = cutlass::layout::RowMajor;
/// Shape of one matrix product operation (concept: MatrixShape)
using InstructionShape = cutlass::MatrixShape<16, 8>;
/// Delta between *MMA operations (in units of *MMA operations, concept:
/// MatrixShape)
static int const kOpDelta = 1;
/// Number of participating threads
static int const kThreads = 32;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<Element, Layout>;
/// Index type
using Index = typename TensorRef::Index;
/// Long Index type
using LongIndex = typename TensorRef::LongIndex;
/// Coordinate for an element in the tensor
using TensorCoord = typename TensorRef::TensorCoord;
/// Number of elements accessed per Shared Memory load
static int const kElementsPerAccess =
(sizeof_bits<Element>::value >= 32 ? 1
: 32 / sizeof_bits<Element>::value);
using InstructionCount =
MatrixShape<Shape::kRow / InstructionShape::kRow,
Shape::kColumn / InstructionShape::kColumn>;
static int const kIterations = (kOperand == Operand::kA)
? InstructionCount::kColumn
: InstructionCount::kRow;
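  // Worked example (for illustration only): with Shape = 32x32 and
  // InstructionShape = 16x8, InstructionCount is <32/16, 32/8> = <2, 4>,
  // so kIterations is InstructionCount::kColumn = 4 when loading operand A
  // and InstructionCount::kRow = 2 when loading operand B.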
public:
//
// Derived quantities
//
/// Fragment object holding a thread's part of a tile
using Fragment =
Array<Element,
(kOperand == Operand::kA)
? (Shape::kRow* InstructionShape::kColumn / kThreads)
: (Shape::kColumn* InstructionShape::kRow / kThreads)>;
/// Memory access type
// using AccessType = AlignedArray<Element, kElementsPerAccess>;
using AccessType = Array<unsigned, 4>;
static int constexpr kWarpShapeDivisibleInner =
(kOperand == Operand::kA ? InstructionShape::kColumn
: InstructionShape::kRow);
static int constexpr kAccessesInner =
(kWarpShapeDivisibleInner / kElementsPerAccess) / 4;
static int const kTilesPerInstruction = InstructionShape::kRow / 8;
private:
/// Underlying tensor reference
TensorRef ref_;
/// Origin
MatrixCoord origin_;
/// Iterations in a tile
int iterations_;
public:
/// Constructor from TensorRef
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem(TensorRef const& ref, int lane_id)
: WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {}
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id)
: ref_(ref), iterations_(0) {
int ldsm_vec_num = (lane_id >> 3);
if (kOperand == Operand::kA) {
origin_ = MatrixCoord(lane_id % 8, 0);
static_assert(
InstructionCount::kRow * kAccessesInner * kTilesPerInstruction == 4,
"");
CUTLASS_PRAGMA_UNROLL
for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow;
++inst_m_idx) {
CUTLASS_PRAGMA_UNROLL
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
CUTLASS_PRAGMA_UNROLL
for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction;
++access_m_idx) {
int access_idx =
access_m_idx + kTilesPerInstruction *
(inner_idx + kAccessesInner * inst_m_idx);
MatrixCoord offset(
access_m_idx * 8 + inst_m_idx * InstructionShape::kRow,
inner_idx * 4 * kElementsPerAccess);
if (access_idx == ldsm_vec_num) {
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
origin_ += offset;
}
}
}
}
} else {
origin_ = MatrixCoord(0, lane_id % 8);
static_assert(InstructionCount::kColumn * kAccessesInner == 4, "");
CUTLASS_PRAGMA_UNROLL
for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn;
++inst_n_idx) {
CUTLASS_PRAGMA_UNROLL
for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) {
int access_idx = inner_idx + kAccessesInner * inst_n_idx;
MatrixCoord offset(inner_idx * 4 * kElementsPerAccess,
inst_n_idx * 8);
if (access_idx == ldsm_vec_num) {
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
origin_ += offset;
}
}
}
}
ref_.add_coord_offset(origin_);
}
/// Advances an iterator along logical dimensions of matrix in units of whole
/// tiles
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) {
TensorCoord coord_offset(tile_offset.row() * Shape::kRow,
tile_offset.column() * Shape::kColumn);
if (kTranspose) {
coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()};
}
origin_ += coord_offset;
ref_.add_coord_offset(coord_offset);
return *this;
}
/// Advances the iterator along the advance dimension
CUTLASS_DEVICE
void advance() {
if (kOperand == Operand::kA) {
add_tile_offset({0, 1});
} else {
add_tile_offset({1, 0});
}
iterations_ = 0;
}
  /// Increases the iteration index within a tile
CUTLASS_HOST_DEVICE
WarpIteratorFromSmem& operator++() {
iterations_++;
if (iterations_ >= kIterations) advance();
return *this;
}
/// Loads a fragment from memory at the location pointed to by the iterator.
CUTLASS_DEVICE
void load(Fragment& frag) const { // NOLINT
AccessType* access_ptr = reinterpret_cast<AccessType*>(&frag);
using LoadLayout = typename platform::
conditional<kTranspose, layout::ColumnMajor, layout::RowMajor>::type;
MatrixCoord offset;
if (kOperand == Operand::kA) {
offset = MatrixCoord(0, iterations_ * InstructionShape::kColumn);
} else {
offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0);
}
if (kTranspose) {
offset = MatrixCoord(offset.column(), offset.row());
}
cutlass::arch::ldsm<LoadLayout, 4>(access_ptr[0],
ref_.data() + ref_.offset(offset));
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <cmath>
#include <type_traits>
#include <vector>
#include <cuda_fp16.h> //NOLINT
#include <curand_kernel.h> //NOLINT
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/fast_math.h"
#include "cutlass/functional.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/platform/platform.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/vector_iterator.h"
#include <inttypes.h> //NOLINT
#include "./debug_utils.h"
#include "./gemm_kernel_utils.h"
#include "epilogue/epilogue_pipelined.h"
#include "gemm/custom_mma.h"
#include "gemm/find_default_mma.h"
#include "gemm/mma_accum_lambda_iterator.h"
#include "gemm/mma_from_smem.h"
#include "iterators/epilogue_predicated_tile_iterator.h"
#include "transform/tile_smem_loader.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/phi/core/enforce.h"
namespace phi {
using namespace gemm_kernel_utils; // NOLINT
namespace { // NOLINT
template <typename FragmentType, int32_t kNumThreads>
struct GmemTile {
  /*
    Helper functions to efficiently store/load register-file (RF) fragments
    to/from gmem.
    GEMM accumulators have a particular format on A100, and
    it takes some compute/shared-memory to rearrange them to
    a RowMajor or ColumnMajor format in global memory through
    an Epilogue. The same complexity goes for loading back into RF.
    This class loads/stores RF fragments as they are, and can be used for
    efficient accumulation across gemms, for instance:
```
GmemTile tile;
for (int i = 0; i < N; ++i) {
// ...
Fragment accum;
if (i == 0) {
accum.clear();
} else {
tile.load(accum);
}
mma(accum, ...);
if (i < N-1) {
// Store for next GEMM
tile.store(accum);
} else {
// Store in tensor (eg RowMajor)
epilogue(accum);
}
// ...
}
```
*/
// 128bits per thread
using AccessType = cutlass::Array<float, 4>;
static constexpr int32_t kBytes = sizeof(AccessType);
static constexpr int32_t kStride = kNumThreads * AccessType::kElements;
static constexpr int32_t kNumIters =
FragmentType::kElements / AccessType::kElements;
static constexpr int32_t kElementsStored =
kNumThreads * FragmentType::kElements;
static_assert(FragmentType::kElements % AccessType::kElements == 0,
"fragment not aligned on 128 bits");
float* ptr;
CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) { // NOLINT
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNumIters; ++i) {
AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(
ptr + thread_id * AccessType::kElements + i * kStride);
AccessType sub_fragment;
cutlass::arch::global_load<AccessType, kBytes>(
sub_fragment, gmem_ptr, true);
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < AccessType::kElements; ++j) {
fragment[i * AccessType::kElements + j] = sub_fragment[j];
}
}
}
CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNumIters; ++i) {
AccessType* __restrict__ gmem_ptr = reinterpret_cast<AccessType*>(
ptr + thread_id * AccessType::kElements + i * kStride);
AccessType sub_fragment;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < AccessType::kElements; ++j) {
sub_fragment[j] = fragment[i * AccessType::kElements + j];
}
cutlass::arch::global_store<AccessType, kBytes>(
sub_fragment, gmem_ptr, true);
}
}
};
template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSm() {
bool is_half = !std::is_same<scalar_t, float>::value;
if (Arch::kMinComputeCapability >= 80) {
return is_half ? 12 : 8;
}
return 8;
}
} // namespace
template <
// which arch we target (eg `cutlass::arch::Sm80`)
typename ArchTag_,
// input/output type
typename scalar_t_,
// run optimized kernel because memory accesses will be aligned
bool kIsAligned_,
// use dropout if enabled
bool kApplyDropout_,
// when doing a GEMM, preload the next one (uses more shmem)
bool kPreloadMmas_,
// block dimensions
int kBlockSizeI_,
int kBlockSizeJ_,
// upperbound on `max(value.shape[-1], query.shape[-1])`
int kMaxK_ = std::numeric_limits<int>::max()>
struct AttentionBackwardKernel {
using scalar_t = scalar_t_;
using output_t = scalar_t;
using output_accum_t = float;
using lse_scalar_t = float;
using accum_t = float;
using ArchTag = ArchTag_;
static constexpr bool kIsAligned = kIsAligned_;
static constexpr bool kApplyDropout = kApplyDropout_;
static constexpr bool kPreloadMmas = kPreloadMmas_;
static constexpr int kBlockSizeI = kBlockSizeI_;
static constexpr int kBlockSizeJ = kBlockSizeJ_;
static constexpr int kMaxK = kMaxK_;
struct Params {
// Input tensors
scalar_t* query_ptr; // [Mq, nH, K]
scalar_t* key_ptr; // [Mk, nH, K]
scalar_t* value_ptr; // [Mk, nH, Kv]
scalar_t* bias_ptr = nullptr;
lse_scalar_t* logsumexp_ptr; // [nH, Mq]
scalar_t* output_ptr; // [Mq, nH, Kv]
scalar_t* grad_output_ptr; // [Mq, nH, Kv]
accum_t* delta_ptr; // [nH, Mq]
int32_t* cu_seqlens_q_ptr = nullptr;
int32_t* cu_seqlens_k_ptr = nullptr;
// Output tensors
output_t* grad_query_ptr; // [Mq, nH, K]
output_t* grad_key_ptr; // [Mk, nH, K]
output_t* grad_value_ptr; // [Mk, nH, Kv]
output_t* grad_bias_ptr = nullptr;
// Accumulators
union {
output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv]
output_accum_t* workspace_gk;
};
output_accum_t* workspace_gv;
output_accum_t* workspace_gq;
// Scale
accum_t scale;
// Dimensions/strides
int32_t head_dim;
int32_t head_dim_value;
int32_t num_queries;
int32_t num_keys;
int32_t num_heads;
bool causal;
int32_t q_strideM;
int32_t k_strideM;
int32_t v_strideM;
int32_t bias_strideM = 0;
int32_t gO_strideM;
int32_t gB_strideM;
int8_t gQKV_strideM_multiplier; // 3 for packed, 1 otherwise
// dropout
uint64_t seed;
uint64_t offset;
// RNG sequence offset based on batch_id and head_id
unsigned long long dropout_batch_head_rng_offset; // NOLINT
float dropout_prob;
CUTLASS_HOST_DEVICE int32_t o_strideM() const {
return head_dim_value * num_heads;
}
CUTLASS_HOST_DEVICE int32_t gQ_strideM() const {
return gQKV_strideM_multiplier * num_heads * head_dim;
}
CUTLASS_HOST_DEVICE int32_t gK_strideM() const {
return gQKV_strideM_multiplier * num_heads * head_dim;
}
CUTLASS_HOST_DEVICE int32_t gV_strideM() const {
return gQKV_strideM_multiplier * num_heads * head_dim_value;
}
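    // Worked example (illustration only, hypothetical sizes): for an unpacked
    // layout (gQKV_strideM_multiplier == 1) with num_heads = 8 and
    // head_dim = head_dim_value = 64, gQ_strideM(), gK_strideM() and
    // gV_strideM() all evaluate to 1 * 8 * 64 = 512 elements; for a packed
    // QKV layout (multiplier == 3) they become 3 * 8 * 64 = 1536 elements.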
// Everything below is only used in `advance_to_block`
// and shouldn't use registers
int64_t o_strideH;
int32_t q_strideH;
int32_t k_strideH;
int32_t v_strideH;
int32_t bias_strideH = 0;
int64_t o_strideB;
int64_t q_strideB;
int64_t k_strideB;
int64_t v_strideB;
int64_t bias_strideB = 0;
int64_t lse_strideB;
int64_t lse_strideH;
int64_t delta_strideB;
int64_t delta_strideH;
int32_t num_batches;
int64_t gO_strideB;
int64_t gQ_strideB;
int64_t gK_strideB;
int64_t gV_strideB;
int64_t gB_strideB;
int64_t gO_strideH;
int64_t gQ_strideH;
int64_t gK_strideH;
int64_t gV_strideH;
int64_t gB_strideH;
CUTLASS_DEVICE bool advance_to_block() {
int64_t batch_id = blockIdx.z;
int32_t head_id = blockIdx.y;
if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) {
assert(workspace_size() == 0 || workspace != nullptr);
workspace += (batch_id * num_heads + head_id) * workspace_strideBH();
workspace = warp_uniform(workspace);
workspace_gv = workspace + workspace_elements_gk();
workspace_gq = workspace_gv + workspace_elements_gv();
} else {
workspace = nullptr;
}
// Advance pointers that depend on the total concatenated
// number of queries, as `num_queries` is modified in the block
// below
dropout_batch_head_rng_offset =
batch_id * (num_heads * num_queries * num_keys) +
head_id * (num_queries * num_keys);
logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH;
if (cu_seqlens_q_ptr != nullptr) {
assert(cu_seqlens_k_ptr != nullptr);
cu_seqlens_q_ptr += batch_id;
cu_seqlens_k_ptr += batch_id;
int32_t q_start = cu_seqlens_q_ptr[0];
int32_t k_start = cu_seqlens_k_ptr[0];
int64_t q_next_start = cu_seqlens_q_ptr[1];
int64_t k_next_start = cu_seqlens_k_ptr[1];
assert(q_next_start - q_start <= num_queries);
assert(k_next_start - k_start <= num_keys);
num_queries = q_next_start - q_start;
num_keys = k_next_start - k_start;
// Jump manually
batch_id = 0;
query_ptr += q_start * q_strideM;
key_ptr += k_start * k_strideM;
value_ptr += k_start * v_strideM;
assert(bias_ptr == nullptr);
assert(grad_bias_ptr == nullptr);
output_ptr += q_start * o_strideM();
grad_output_ptr += q_start * gO_strideM;
delta_ptr += q_start;
grad_query_ptr += q_start * gQ_strideM();
grad_key_ptr += k_start * gK_strideM();
grad_value_ptr += k_start * gV_strideM();
}
query_ptr += batch_id * q_strideB + head_id * q_strideH;
key_ptr += batch_id * k_strideB + head_id * k_strideH;
value_ptr += batch_id * v_strideB + head_id * v_strideH;
if (bias_ptr != nullptr) {
bias_ptr += batch_id * bias_strideB + head_id * bias_strideH;
}
output_ptr += batch_id * o_strideB + head_id * o_strideH;
grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH;
delta_ptr += batch_id * delta_strideB + head_id * delta_strideH;
grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH;
grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH;
grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH;
if (grad_bias_ptr != nullptr) {
grad_bias_ptr += batch_id * gB_strideB + head_id * gB_strideH;
}
head_dim = warp_uniform(head_dim);
head_dim_value = warp_uniform(head_dim_value);
num_queries = warp_uniform(num_queries);
num_keys = warp_uniform(num_keys);
num_heads = warp_uniform(num_heads);
gO_strideM = warp_uniform(gO_strideM);
gQKV_strideM_multiplier = warp_uniform(gQKV_strideM_multiplier);
q_strideM = warp_uniform(q_strideM);
k_strideM = warp_uniform(k_strideM);
v_strideM = warp_uniform(v_strideM);
query_ptr = warp_uniform(query_ptr);
key_ptr = warp_uniform(key_ptr);
value_ptr = warp_uniform(value_ptr);
bias_ptr = warp_uniform(bias_ptr);
logsumexp_ptr = warp_uniform(logsumexp_ptr);
output_ptr = warp_uniform(output_ptr);
grad_output_ptr = warp_uniform(grad_output_ptr);
delta_ptr = warp_uniform(delta_ptr);
grad_query_ptr = warp_uniform(grad_query_ptr);
grad_key_ptr = warp_uniform(grad_key_ptr);
grad_value_ptr = warp_uniform(grad_value_ptr);
grad_bias_ptr = warp_uniform(grad_bias_ptr);
#if 0
PRINT_T0("[b:%d h:%d] dp[0]:%f Q:%f K:%f V:%f LSE:%f",
int(blockIdx.z), int(blockIdx.y), //NOLINT
float(delta_ptr[0]), //NOLINT
float(query_ptr[0]), float(key_ptr[0]), float(value_ptr[0]), //NOLINT
float(logsumexp_ptr[0]) //NOLINT
)
#endif
return true;
}
__host__ dim3 getBlocksGrid() const {
return dim3(1, num_heads, num_batches);
}
__host__ dim3 getThreadsGrid() const {
return dim3(kWarpSize, kNumWarpsPerBlock, 1);
}
CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const {
if (!kNeedsAccumGradK) {
return 0;
}
return align_up(num_keys, (int32_t)kBlockSizeJ) *
align_up(head_dim, (int32_t)kBlockSizeI);
}
CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const {
if (!kNeedsAccumGradV) {
return 0;
}
return align_up(num_keys, (int32_t)kBlockSizeJ) *
align_up(head_dim_value, (int32_t)kBlockSizeI);
}
CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const {
if (!kNeedsAccumGradQ) {
return 0;
}
if (num_keys <= kBlockSizeJ) {
return 0;
}
return align_up(num_queries, (int32_t)kBlockSizeI) *
align_up(head_dim, (int32_t)kBlockSizeJ);
}
CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const {
// Aligned on 128bits
return align_up(workspace_elements_gk() + workspace_elements_gv() +
workspace_elements_gq(),
int64_t(4));
}
CUTLASS_HOST_DEVICE int64_t workspace_size() const {
// Returns size of buffer we need to run this kernel
return num_batches * num_heads * workspace_strideBH() * sizeof(float);
}
};
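  /*
    Host-side sketch (illustrative only; the real allocation and launch live in
    the Paddle glue code that instantiates this struct): when accumulation is
    needed, the float workspace can be sized with workspace_size() and handed
    to the kernel through Params::workspace before launch:

      Params p;
      // ... fill shapes, strides and tensor pointers ...
      float* workspace = nullptr;
      if (p.workspace_size() > 0) {
        cudaMalloc(&workspace, p.workspace_size());
      }
      p.workspace = workspace;  // offset per (batch, head) in advance_to_block()
  */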
static constexpr int64_t kWarpSize = 32;
// If this is true, we store and accumulate dK/dV in RF
  // rather than going back to gmem every time
static constexpr bool kIsHalf = cutlass::sizeof_bits<scalar_t>::value <= 16;
static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI;
static_assert(!kPreloadMmas ||
(kIsHalf && ArchTag::kMinComputeCapability >= 80 &&
kOutputInRF),
"preload MMA not supported");
static constexpr bool kPrologueQK = kPreloadMmas;
static constexpr bool kPrologueGV = kPreloadMmas;
static constexpr bool kPrologueDOV = kPreloadMmas;
static constexpr bool kPrologueGQ = kPreloadMmas;
static constexpr bool kPrologueGK = kPreloadMmas;
static constexpr int64_t kNumWarpsPerBlock =
(kBlockSizeI * kBlockSizeJ) / (32 * 32);
// Compute delta for the f16 kernels
// TODO(xformers): Figure out why it's slower on the f32 kernels
// (something due to RF pressure?)
// TODO(xformers): Remove condition on `kOutputInRF` - this is needed to work
// around a compiler bug on V100, not exactly sure why but I spent
// too much time on this already. Reproducible with
// (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance
static constexpr bool kKernelComputesDelta =
kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70);
static constexpr bool kNeedsAccumGradQ =
!std::is_same<output_accum_t, output_t>::value;
static constexpr bool kNeedsAccumGradK =
!kOutputInRF && !std::is_same<output_accum_t, output_t>::value;
static constexpr bool kNeedsAccumGradV =
!kOutputInRF && !std::is_same<output_accum_t, output_t>::value;
// Launch bounds
static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock;
static constexpr int64_t kMinBlocksPerSm =
getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
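  // Worked example (illustration only, hypothetical instantiation): with
  // kBlockSizeI = kBlockSizeJ = 64, kNumWarpsPerBlock = 64 * 64 / 1024 = 4 and
  // kNumThreads = 128; for half-precision inputs on Sm80, getWarpsPerSm()
  // returns 12, giving kMinBlocksPerSm = 12 / 4 = 3.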
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<
typename GemmType::OpClass,
ArchTag,
scalar_t,
scalar_t,
scalar_t, // ElementC
accum_t // ElementAccumulator
>;
static constexpr auto kOptimalAlignement =
std::max(DefaultConfig::kAlignmentA, DefaultConfig::kAlignmentB);
static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment;
struct MatmulQK {
/*
attn_T = k_j @ q_i.transpose(-2, -1) # matmul
attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2,
-1)).exp() # epilogue
with attn_T.shape = (kBlockSizeJ, kBlockSizeI)
*/
using ThreadblockShape =
cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma<
scalar_t, // ElementA
cutlass::layout::RowMajor, // LayoutA
kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment,
scalar_t, // ElementB
cutlass::layout::ColumnMajor, // LayoutB
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
accum_t, // ElementC
cutlass::layout::RowMajor, // LayoutC
typename GemmType::OpClass,
ArchTag,
ThreadblockShape,
WarpShape,
typename GemmType::InstructionShape,
DefaultConfig::kStages,
typename GemmType::Operator,
false, // AccumulatorsInRowMajor = false,
cutlass::gemm::SharedMemoryClearOption::kNone>;
using MmaCore = typename DefaultMma::MmaCore;
using Mma =
typename MakeCustomMma<typename DefaultMma::ThreadblockMma, kMaxK>::Mma;
// used for efficient load of bias tile (Bij) from global memory to shared
// memory
using BiasLoader = TileSmemLoader<
scalar_t,
// Bij is applied to transposed attn matrix tile (Pij.T). Bij is loaded
// row-major but needs to have transposed shape so we get the same
// elements.
cutlass::MatrixShape<ThreadblockShape::kN, ThreadblockShape::kM>,
MmaCore::kThreads,
// input restriction: kv_len has to be a multiple of this value
128 / cutlass::sizeof_bits<scalar_t>::value>;
// Epilogue to store to shared-memory in a format that we can use later for
// the second matmul
using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
typename Mma::Operator::IteratorC,
typename Mma::Operator,
scalar_t,
WarpShape,
ThreadblockShape>;
using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
typename Mma::Operator::IteratorC,
accum_t,
kWarpSize>::Iterator;
using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
};
struct MatmulGradV {
/*
grad_v[j_start:j_end] += attn_T @ do_i # matmul
Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K)
(we might need to iterate multiple times on K)
*/
using ThreadblockShape =
cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using InstructionShape = typename GemmType::InstructionShape;
using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
scalar_t, // ElementA,
cutlass::layout::RowMajor, // LayoutA,
DefaultConfig::kAlignmentA,
scalar_t, // ElementB,
cutlass::layout::RowMajor, // LayoutB,
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
output_t,
cutlass::layout::RowMajor, // LayoutC,
accum_t,
typename GemmType::OpClass,
ArchTag,
ThreadblockShape,
WarpShape,
typename GemmType::InstructionShape,
typename DefaultConfig::EpilogueOutputOp,
void, // ThreadblockSwizzle - not used
DefaultConfig::kStages,
false, // SplitKSerial
typename GemmType::Operator>;
// if dropout:
// for computing dVj += (Pij.T * Zij) @ dOi
// Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of
// Pij.T are loaded in. The reason we do it this way is because Pij.T and
// Zij are reused in later steps, while Pij_dropped.T is only needed in
// this step. computing Pij_dropped.T on the fly allows us to avoid
// keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the
// same time.
// if no dropout:
// for computing dVj += Pij.T @ dOi
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MatmulQK::AccumulatorSharedStorage,
kApplyDropout>; // kScaleOperandA
using Mma = typename DefaultMmaFromSmem::Mma;
using WarpIteratorA = typename DefaultMmaFromSmem::WarpIteratorA;
using IteratorB = typename Mma::IteratorB;
using WarpCount = typename Mma::WarpCount;
// Epilogue
using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
using DefaultEpilogue = typename DefaultGemm::Epilogue;
using OutputTileIterator =
typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
typename DefaultEpilogue::OutputTileIterator>::Iterator;
using AccumTileGmem = GmemTile<typename Mma::FragmentC, kNumThreads>;
};
struct MatmulDOIVJ {
/*
doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul
tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue?
*/
using ThreadblockShape =
cutlass::gemm::GemmShape<kBlockSizeI, kBlockSizeJ, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using ElementC = output_t;
using ElementAccum = accum_t;
// no-op output op - epilogue just stores result to global memory
using BiasGradEpilogueOutputOp =
typename cutlass::epilogue::thread::LinearCombination<
ElementC,
DefaultConfig::EpilogueOutputOp::kCount,
typename DefaultConfig::EpilogueOutputOp::ElementAccumulator,
typename DefaultConfig::EpilogueOutputOp::ElementCompute,
cutlass::epilogue::thread::ScaleType::Nothing>;
using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm<
scalar_t, // ElementA
cutlass::layout::RowMajor, // LayoutA
kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment,
scalar_t, // ElementB
cutlass::layout::ColumnMajor, // LayoutB
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
ElementC, // ElementC
cutlass::layout::RowMajor, // LayoutC
ElementAccum, // ElementAccumulator
typename GemmType::OpClass,
ArchTag,
ThreadblockShape,
WarpShape,
typename GemmType::InstructionShape,
BiasGradEpilogueOutputOp, // EpilogueOutputOp
void, // ThreadblockSwizzle (not used)
// multiple preloads, dropout Zij tile, and 3 stages push us over shared
// memory capacity on A100. set a ceiling on number of stages to save
// shared memory if dropout is in use.
kPreloadMmas && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64)
? cutlass::const_min(2, DefaultConfig::kStages)
: DefaultConfig::kStages, // Stages
false, // SplitKSerial
typename GemmType::Operator,
cutlass::gemm::SharedMemoryClearOption::kNone>;
using Mma = typename MakeCustomMma<typename DefaultGemm::Mma, kMaxK>::Mma;
// epilogue used to write bias gradient, which is just the output of this
// matmul with some operations applied to the fragment
using BiasGradEpilogue = typename DefaultGemm::Epilogue;
// Epilogue to store to shared-memory in a format that we can use later for
// the second matmul
using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
typename Mma::Operator::IteratorC,
typename Mma::Operator,
scalar_t,
WarpShape,
ThreadblockShape>;
using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
};
struct MatmulGradQ {
// grad_q <- tmp @ k_j
using ThreadblockShape =
cutlass::gemm::GemmShape<kBlockSizeI, kBlockSizeJ, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using InstructionShape = typename GemmType::InstructionShape;
using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
scalar_t, // ElementA,
cutlass::layout::RowMajor, // LayoutA,
DefaultConfig::kAlignmentA,
scalar_t, // ElementB,
cutlass::layout::RowMajor, // LayoutB,
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
output_t,
cutlass::layout::RowMajor, // LayoutC,
accum_t,
typename GemmType::OpClass,
ArchTag,
ThreadblockShape,
WarpShape,
typename GemmType::InstructionShape,
typename DefaultConfig::EpilogueOutputOp,
void, // ThreadblockSwizzle - not used
DefaultConfig::kStages,
false, // SplitKSerial
typename GemmType::Operator>;
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MatmulDOIVJ::AccumulatorSharedStorage,
false>; // kScaleOperandA
using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;
using WarpCount = typename Mma::WarpCount;
// Epilogue
using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
using DefaultEpilogue = typename DefaultGemm::Epilogue;
using OutputTileIterator =
typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
typename DefaultEpilogue::OutputTileIterator>::Iterator;
using AccumTileGmem = GmemTile<typename Mma::FragmentC, kNumThreads>;
};
struct MatmulGradK {
// grad_k <- tmp.transpose(-2, -1) @ q_i
using ThreadblockShape =
cutlass::gemm::GemmShape<kBlockSizeJ, kBlockSizeI, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using InstructionShape = typename GemmType::InstructionShape;
using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
scalar_t, // ElementA,
cutlass::layout::RowMajor, // LayoutA,
DefaultConfig::kAlignmentA,
scalar_t, // ElementB,
cutlass::layout::RowMajor, // LayoutB,
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment,
output_t,
cutlass::layout::RowMajor, // LayoutC,
accum_t,
typename GemmType::OpClass,
ArchTag,
ThreadblockShape,
WarpShape,
typename GemmType::InstructionShape,
typename DefaultConfig::EpilogueOutputOp,
void, // ThreadblockSwizzle - not used
DefaultConfig::kStages,
false, // SplitKSerial
typename GemmType::Operator>;
using DefaultMmaFromSmemN =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MatmulQK::AccumulatorSharedStorage,
false>; // kScaleOperandA
using DefaultMmaFromSmemT =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MatmulDOIVJ::AccumulatorSharedStorage,
false, // kScaleOperandA
kPreloadMmas>; // kTransposeA
using DefaultMmaFromSmem = typename cutlass::platform::conditional<
DefaultMmaFromSmemT::kIsTransposedA,
DefaultMmaFromSmemT,
DefaultMmaFromSmemN>::type;
using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;
using WarpCount = typename Mma::WarpCount;
// Epilogue
using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp;
using DefaultEpilogue = typename DefaultGemm::Epilogue;
using OutputTileIterator =
typename cutlass::epilogue::threadblock::MakePrefetchableIterator<
typename DefaultEpilogue::OutputTileIterator>::Iterator;
using AccumTileGmem = GmemTile<typename Mma::FragmentC, kNumThreads>;
};
  // Shared storage for keeping the Zij matrix. Not needed if we aren't using
  // dropout, in which case we use an empty array to save shared memory.
using ZijSharedStorage = typename cutlass::platform::conditional<
kApplyDropout,
typename MatmulQK::AccumulatorSharedStorage,
// dummy shared storage object that takes up no space.
typename cutlass::gemm::threadblock::AccumulatorSharedStorage<
#ifdef _WIN32
// windows builds throw the error:
// "type containing an unknown-size array is not allowed"
// if we try to make Zij shared storage zero-sized.
// To get around this just make it sized 1 on windows.
typename cutlass::gemm::GemmShape<1, 1, 0>,
#else
typename cutlass::gemm::GemmShape<0, 0, 0>,
#endif
typename MatmulQK::AccumulatorSharedStorage::Element,
typename MatmulQK::AccumulatorSharedStorage::Layout,
typename cutlass::MatrixShape<0, 0>>>::type;
// See https://fburl.com/gsheet/l5bltspl
// for an illustration of how smem is used
struct SharedStoragePrologue {
struct {
cutlass::Array<accum_t, kBlockSizeI> di; // (do_i * o_i).sum(-1)
typename MatmulQK::Mma::SharedStorageA mm_qk_k;
} persistent;
union {
struct {
// p1 - after Q.K / dV / dO.V
union {
// 1. efficient load of bias tile Bij, which is then applied to Pij
typename MatmulQK::BiasLoader::SmemTile bias;
// 4. store Pij. it is needed:
// - in dVj += (Pij.T * Zij) @ dOi
// - in dSij = Pij * (dPij - Di)
// 6. dVj += (Pij.T * Zij) @ dOi
// 10. write to fragment
typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
};
// 5. store Zij. it is needed:
// - to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij
// are loaded for the computation of dVj.
// - to compute dPij = (dOi @ Vj.T) * Zij
// 6. used in dVj += (Pij.T * Zij) @ dOi
// 9. used in dPij = dPij_dropped * Zij
ZijSharedStorage zij;
union {
// 2. prologue for dVj
// 6. workspace for dVj += (Pij.T * Zij) @ dOi
typename MatmulGradV::Mma::SharedStorage mm_gradV;
// 7. dVj epilogue
typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue;
};
// 3. prologue for dPij_dropped
// 8. used in dPij_dropped = dOi @ Vj.T
typename MatmulDOIVJ::Mma::SharedStorage mm_doivj;
} p1;
struct {
// p2 - dQ
union {
typename MatmulQK::AccumulatorSharedStorage
tmpT_shared_storage; // (from p1)
typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
};
typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload)
typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload)
union {
// store dB = dSij to global memory
typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue;
typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue;
};
} p2;
struct {
// p3 - after last iteration on dQ's epilogue / dK
union {
typename MatmulQK::AccumulatorSharedStorage
tmpT_shared_storage; // (from p1)
typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
};
typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload)
typename MatmulGradQ::DefaultEpilogue::SharedStorage
gradQ_epilogue_lastIter;
typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue;
} p3;
struct {
// p4 - after last iteration on dK's epilogue / preload next K.Q_t
typename MatmulQK::Mma::SharedStorageB mm_qk_q;
// If we reach end of current key, dump RF->gmem with "final" epilogues
typename MatmulGradK::DefaultEpilogue::SharedStorage
gradK_epilogue_final;
typename MatmulGradV::DefaultEpilogue::SharedStorage
gradV_epilogue_final;
} p4;
};
static void print_size() {
// Field size
#define FSZ(f) int((sizeof(((SharedStoragePrologue*)0)->f))) // NOLINT
printf("Total smem: %d bytes\n",
int(sizeof(SharedStoragePrologue))); // NOLINT
printf(" persistent: %db\n", FSZ(persistent));
printf(" mm_qk_k: %db\n", FSZ(persistent.mm_qk_k));
printf(" p1: %db\n", FSZ(p1));
printf(" bias: %db\n", FSZ(p1.bias));
printf(" attn_shared_storage: %db\n", FSZ(p1.attn_shared_storage));
printf(" zij: %db\n", FSZ(p1.zij));
printf(" mm_gradV: %db\n", FSZ(p1.mm_gradV));
printf(" gradV_epilogue: %db\n", FSZ(p1.gradV_epilogue));
printf(" mm_doivj: %db\n", FSZ(p1.mm_doivj));
printf(" p2: %db\n", FSZ(p2));
printf(" tmpT_shared_storage: %db\n", FSZ(p2.tmpT_shared_storage));
printf(" tmp_shared_storage: %db\n", FSZ(p2.tmp_shared_storage));
printf(" mm_gradK: %db\n", FSZ(p2.mm_gradK));
printf(" mm_gradQ: %db\n", FSZ(p2.mm_gradQ));
printf(" gradB_epilogue: %db\n", FSZ(p2.gradB_epilogue));
printf(" gradQ_epilogue: %db\n", FSZ(p2.gradQ_epilogue));
printf(" p3: %db\n", FSZ(p3));
printf(" tmpT_shared_storage: %db\n", FSZ(p3.tmpT_shared_storage));
printf(" p4: %db\n", FSZ(p4));
printf(" mm_qk_q: %db\n", FSZ(p4.mm_qk_q));
printf(" gradK_epilogue_final: %db\n", FSZ(p4.gradK_epilogue_final));
printf(" gradV_epilogue_final: %db\n", FSZ(p4.gradV_epilogue_final));
}
// ===========================================
#define FIELD(INSIDE_STRUCT, FIELDNAME) \
CUTLASS_DEVICE auto& FIELDNAME() { return INSIDE_STRUCT.FIELDNAME; }
FIELD(persistent, di)
FIELD(persistent, mm_qk_k)
FIELD(p1, bias)
FIELD(p1, attn_shared_storage)
FIELD(p1, zij)
FIELD(p1, mm_gradV)
FIELD(p1, gradV_epilogue)
FIELD(p1, mm_doivj)
FIELD(p2, mm_gradK)
FIELD(p2, mm_gradQ)
FIELD(p2, gradB_epilogue)
FIELD(p2, gradQ_epilogue)
FIELD(p2, tmp_shared_storage)
FIELD(p3, tmpT_shared_storage)
FIELD(p3, gradQ_epilogue_lastIter)
FIELD(p3, gradK_epilogue)
FIELD(p4, mm_qk_q)
FIELD(p4, gradK_epilogue_final)
FIELD(p4, gradV_epilogue_final)
};
struct SharedStorageNoPrologue {
struct {
cutlass::Array<accum_t, kBlockSizeI> di; // (do_i * o_i).sum(-1)
} persistent;
union {
struct {
// p1 - Q.K matmul
typename MatmulQK::Mma::SharedStorageA mm_qk_k;
typename MatmulQK::Mma::SharedStorageB mm_qk_q;
} p1;
struct {
// p2 - compute gradV
union {
// 1. efficient load of bias tile Bij, which is then applied to Pij
typename MatmulQK::BiasLoader::SmemTile bias;
// 2. store Pij to shared memory. it is needed:
// - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi
// - in next step where it is used in dSij = Pij * (dPij - Di)
typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
};
// 3. store Zij. it is needed:
      // - in this step, where it is used to compute Pij_dropped = Pij * Zij
      //   on the fly as fragments of Pij are loaded for the computation of
      //   dVj.
// - later to compute dPij = (dOi @ Vj.T) * Zij
ZijSharedStorage zij;
union {
typename MatmulGradV::Mma::SharedStorage mm_gradV;
typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue;
};
} p2;
struct {
// p3 - DO.V matmul
union {
// first compute dPij = (dOi @ Vj.T) * Zij
// and dSij = Pij * (dPij - Di)
struct {
// (from p2) - Pij for computing dSij = Pij * (dPij - Di)
typename MatmulQK::AccumulatorSharedStorage attn_shared_storage;
// (from p2) - Zij for computing dPij = dPij_dropped * Zij
ZijSharedStorage zij;
// matmul to compute dOiVj
typename MatmulDOIVJ::Mma::SharedStorage mm_doivj;
};
// then store dB = dSij to global memory
typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue;
};
} p3;
struct {
// p4 - compute gradQ
typename MatmulQK::AccumulatorSharedStorage
tmpT_shared_storage; // (from p2)
typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
union {
typename MatmulGradQ::Mma::SharedStorage mm_gradQ;
typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue;
typename MatmulGradQ::DefaultEpilogue::SharedStorage
gradQ_epilogue_lastIter;
};
} p4;
struct {
// p5 - compute gradK
typename MatmulQK::AccumulatorSharedStorage
tmpT_shared_storage; // (from p2)
typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage;
union {
typename MatmulGradK::Mma::SharedStorage mm_gradK;
typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue;
};
} p5;
struct {
// p6 - store RF accumulated into gmem
typename MatmulGradK::DefaultEpilogue::SharedStorage
gradK_epilogue_final;
typename MatmulGradV::DefaultEpilogue::SharedStorage
gradV_epilogue_final;
} p6;
};
static void print_size() {
#define FIELD_SIZEOF(f) \
int((sizeof(((SharedStorageNoPrologue*)0)->f))) // NOLINT
printf("Total smem: %d bytes\n",
int(sizeof(SharedStorageNoPrologue))); // NOLINT
printf(" persistent: %db\n", FIELD_SIZEOF(persistent));
printf(" p1: %db\n", FIELD_SIZEOF(p1));
printf(" p2: %db\n", FIELD_SIZEOF(p2));
printf(" p3: %db\n", FIELD_SIZEOF(p3));
printf(" p4: %db\n", FIELD_SIZEOF(p4));
printf(" p5: %db\n", FIELD_SIZEOF(p5));
printf(" p6: %db\n", FIELD_SIZEOF(p6));
}
// ===========================================
#define FIELD(INSIDE_STRUCT, FIELDNAME) \
CUTLASS_DEVICE auto& FIELDNAME() { return INSIDE_STRUCT.FIELDNAME; }
FIELD(persistent, di)
FIELD(p1, mm_qk_k)
FIELD(p1, mm_qk_q)
FIELD(p2, bias)
FIELD(p2, attn_shared_storage)
FIELD(p2, zij)
FIELD(p2, mm_gradV)
FIELD(p2, gradV_epilogue)
FIELD(p3, mm_doivj)
FIELD(p3, gradB_epilogue)
FIELD(p4, tmpT_shared_storage)
FIELD(p4, tmp_shared_storage)
FIELD(p4, mm_gradQ)
FIELD(p4, gradQ_epilogue)
FIELD(p4, gradQ_epilogue_lastIter)
FIELD(p5, mm_gradK)
FIELD(p5, gradK_epilogue)
FIELD(p6, gradK_epilogue_final)
FIELD(p6, gradV_epilogue_final)
};
using SharedStorage =
typename std::conditional<kPreloadMmas,
SharedStoragePrologue,
SharedStorageNoPrologue>::type;
struct OutputFragments {
typename MatmulGradV::Mma::FragmentC gradV;
typename MatmulGradK::Mma::FragmentC gradK;
CUTLASS_DEVICE void clear() {
gradV.clear();
gradK.clear();
}
};
static bool __host__ check_supported(Params const& p) {
CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment);
CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment);
PADDLE_ENFORCE_EQ(p.lse_strideH % 8,
0,
paddle::platform::errors::InvalidArgument(
"LSE is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.lse_strideB % 8,
0,
paddle::platform::errors::InvalidArgument(
"LSE is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.q_strideH % kMinimumAlignment,
0,
paddle::platform::errors::InvalidArgument(
"query is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.k_strideH % kMinimumAlignment,
0,
paddle::platform::errors::InvalidArgument(
"key is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.v_strideH % kMinimumAlignment,
0,
paddle::platform::errors::InvalidArgument(
"value is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.bias_strideB % kMinimumAlignment,
0,
paddle::platform::errors::InvalidArgument(
"attn_bias is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.bias_strideH % kMinimumAlignment,
0,
paddle::platform::errors::InvalidArgument(
"attn_bias is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.bias_strideM % kMinimumAlignment,
0,
paddle::platform::errors::InvalidArgument(
"attn_bias is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.cu_seqlens_q_ptr && p.bias_ptr,
false,
paddle::platform::errors::InvalidArgument(
"CuSeqlen + bias not implemented yet"));
return true;
}
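  /*
    Host-side sketch (illustrative only; the actual launch is performed by the
    Paddle glue code): callers are expected to validate the problem with
    check_supported() and then launch with the grids exposed by Params and
    enough dynamic shared memory for SharedStorage. The kernel wrapper name,
    `stream`, and `p` below are hypothetical:

      check_supported(p);
      int smem_bytes = sizeof(SharedStorage);
      // attention_backward_kernel would be a __global__ wrapper that calls
      // AttentionBackwardKernel<...>::attention_kernel(p).
      attention_backward_kernel<<<p.getBlocksGrid(), p.getThreadsGrid(),
                                  smem_bytes, stream>>>(p);
  */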
static CUTLASS_DEVICE void attention_kernel(Params const& p) {
extern __shared__ char smem_buffer[];
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); // NOLINT
if (kPrologueQK) {
prologueQkNextIteration<true>(shared_storage, p, 0, 0);
}
// Computes (dO*out).sum(-1) and writes it to `p.delta_ptr`
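// i.e. delta[q] = sum_k grad_output[q, k] * output[q, k]. This is the Di
// term reused further down in the softmax backward: dSij = Pij * (dPij - Di).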
if (kKernelComputesDelta) {
constexpr int kOptimalElements =
128 / cutlass::sizeof_bits<scalar_t>::value;
if (p.head_dim_value % kOptimalElements == 0) {
for (int query_start = 0; query_start < p.num_queries;
query_start += kBlockSizeI) {
computeDelta<kOptimalElements>(p, query_start);
}
} else {
for (int query_start = 0; query_start < p.num_queries;
query_start += kBlockSizeI) {
computeDelta<1>(p, query_start);
}
}
__syncthreads();
}
OutputFragments output_frags;
curandStatePhilox4_32_10_t rng_state_init;
if (kApplyDropout) {
// each element of the attention matrix P with shape
// (batch_sz, n_heads, n_queries, n_keys) is associated with a single
// offset in RNG sequence. we initialize the RNG state with offset that
// starts at the beginning of a (n_queries, n_keys) matrix for this
// block's batch_id and head_id
// initializing rng state is very expensive, so we run once per kernel,
// rather than once per iteration. each iteration takes a copy of the
// initialized RNG state and offsets it as needed.
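// concretely, element (q, k) of this (batch, head) slice maps to Philox
// offset q * num_keys + k on top of the kernel-level offset set here
// (see the skipahead() call further down).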
curand_init(p.seed,
0,
p.offset + p.dropout_batch_head_rng_offset,
&rng_state_init);
}
int32_t key_start = 0;
int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ;
for (; key_start < key_end; key_start += kBlockSizeJ) {
output_frags.clear();
int32_t query_start = getQueryStart(p, key_start);
int32_t query_end = query_start + (p.num_queries - query_start) /
kBlockSizeI * kBlockSizeI;
for (; query_start < query_end; query_start += kBlockSizeI) {
processBlockIJ<true>(shared_storage,
output_frags,
p,
query_start,
key_start,
rng_state_init);
}
// last (partial) query
if (query_start < p.num_queries) {
processBlockIJ<false>(shared_storage,
output_frags,
p,
query_start,
key_start,
rng_state_init);
}
if (kOutputInRF) {
writeFragsToGmem<true>(shared_storage, output_frags, p, key_start);
} else if (getQueryStart(p, key_start) >= p.num_queries) {
zfillGradKV<true>(p, key_start);
}
__syncthreads();
}
// Last (partial) key
if (key_start != p.num_keys) {
output_frags.clear();
int32_t query_start = getQueryStart(p, key_start);
for (; query_start < p.num_queries; query_start += kBlockSizeI) {
processBlockIJ<false>(shared_storage,
output_frags,
p,
query_start,
key_start,
rng_state_init);
}
if (kOutputInRF) {
writeFragsToGmem<false>(shared_storage, output_frags, p, key_start);
} else if (getQueryStart(p, key_start) >= p.num_queries) {
zfillGradKV<false>(p, key_start);
}
}
}
static CUTLASS_DEVICE void loadDi(
cutlass::Array<accum_t, kBlockSizeI>& di, // NOLINT
Params const& p, // NOLINT
int32_t query_start) {
int32_t thread_id = threadIdx.x + threadIdx.y * blockDim.x;
if (thread_id < kBlockSizeI) {
accum_t di_rf = accum_t(0);
if (query_start + thread_id < p.num_queries) {
di_rf = p.delta_ptr[query_start + thread_id];
}
di[thread_id] = di_rf;
}
}
template <bool skipBoundsChecks>
static CUTLASS_DEVICE void zfillGradKV(Params const& p, int32_t key_start) {
constexpr int kThreadsPerKey = 8;
constexpr int kParallelKeys = kNumThreads / kThreadsPerKey;
static_assert(kBlockSizeJ % kParallelKeys == 0, "");
// This function is not really optimized, but should rarely be used
// It's only used when some keys are "useless" and don't attend to
// any query, due to causal masking
int lane_id = get_lane_id();
int thread_id = get_thread_id();
int k_shift = lane_id % kThreadsPerKey;
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) {
int key = key_start + j + (thread_id / kThreadsPerKey);
if (!skipBoundsChecks && key >= p.num_keys) {
continue;
}
auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM();
auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM();
for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) {
gv_ptr[k] = scalar_t(0);
}
for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) {
gk_ptr[k] = scalar_t(0);
}
}
}
template <bool skipBoundsChecks>
static CUTLASS_DEVICE void processBlockIJ(
SharedStorage& shared_storage, // NOLINT
OutputFragments& output_frags, // NOLINT
Params const& p, // NOLINT
int32_t query_start,
int32_t key_start,
const curandStatePhilox4_32_10_t& curand_state_init) {
cutlass::MatrixCoord no_offset{0, 0};
accum_t scale = p.scale;
int16_t thread_id = threadIdx.x + threadIdx.y * blockDim.x;
int8_t warp_id = warp_uniform(threadIdx.y);
int8_t lane_id = threadIdx.x;
bool isFirstQuery =
query_start == 0 || (p.causal && query_start <= key_start);
int32_t next_query, next_key;
incrIteration(p, query_start, key_start, next_query, next_key);
bool isLastQuery = next_key != key_start;
__syncthreads();
loadDi(shared_storage.di(), p, query_start);
int32_t num_queries_in_block =
skipBoundsChecks ? MatmulQK::Mma::Shape::kN
: std::min((int32_t)MatmulQK::Mma::Shape::kN,
p.num_queries - query_start);
int32_t num_keys_in_block =
skipBoundsChecks ? MatmulQK::Mma::Shape::kM
: std::min((int32_t)MatmulQK::Mma::Shape::kM,
p.num_keys - key_start);
auto prologueGradV = [&](int col) {
typename MatmulGradV::Mma::IteratorB iterator_dO(
{int32_t(p.gO_strideM)},
p.grad_output_ptr + query_start * p.gO_strideM + col,
{num_queries_in_block, p.head_dim_value - col},
thread_id,
no_offset);
MatmulGradV::Mma::prologue(shared_storage.mm_gradV(),
iterator_dO,
thread_id,
num_queries_in_block);
};
auto prologueGradQ = [&](int col) {
typename MatmulGradQ::Mma::IteratorB iterator_K(
{int32_t(p.k_strideM)},
p.key_ptr + key_start * p.k_strideM + col,
{num_keys_in_block, p.head_dim - col},
thread_id,
no_offset);
MatmulGradQ::Mma::prologue(
shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block);
};
auto prologueGradK = [&](int col) {
typename MatmulGradK::Mma::IteratorB iterator_Q(
{int32_t(p.q_strideM)},
p.query_ptr + query_start * p.q_strideM + col,
{num_queries_in_block, p.head_dim - col},
thread_id,
no_offset);
MatmulGradK::Mma::prologue(shared_storage.mm_gradK(),
iterator_Q,
thread_id,
num_queries_in_block);
};
auto prologueDOV = [&]() {
typename MatmulDOIVJ::Mma::IteratorA iterator_A(
{int32_t(p.gO_strideM)},
p.grad_output_ptr + query_start * p.gO_strideM,
{num_queries_in_block, p.head_dim_value},
thread_id,
no_offset);
typename MatmulDOIVJ::Mma::IteratorB iterator_B(
{int32_t(p.v_strideM)},
p.value_ptr + key_start * p.v_strideM,
{p.head_dim_value, num_keys_in_block},
thread_id,
no_offset);
MatmulDOIVJ::Mma::prologue(shared_storage.mm_doivj(),
iterator_A,
iterator_B,
thread_id,
p.head_dim_value);
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// MatmulQK
/////////////////////////////////////////////////////////////////////////////////////////////////
{
using Mma = typename MatmulQK::Mma;
cutlass::gemm::GemmCoord problem_size(num_keys_in_block,
num_queries_in_block,
p.head_dim // k
);
// k_j
typename Mma::IteratorA iterator_A({int32_t(p.k_strideM)},
p.key_ptr + key_start * p.k_strideM,
{problem_size.m(), problem_size.k()},
thread_id,
no_offset);
// q_i.transpose(-2, -1)
typename Mma::IteratorB iterator_B(
{int32_t(p.q_strideM)},
p.query_ptr + query_start * p.q_strideM,
{problem_size.k(), problem_size.n()},
thread_id,
no_offset);
Mma mma(shared_storage.mm_qk_k(),
shared_storage.mm_qk_q(),
thread_id,
warp_id,
lane_id);
typename Mma::FragmentC accum;
accum.clear();
auto gemm_k_iterations =
(problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma.set_prologue_done(kPrologueQK);
mma.set_zero_outside_bounds(!skipBoundsChecks);
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
accum = cutlass::multiplies<typename Mma::FragmentC>()(scale, accum);
// Epilogue: add LSE + exp and store that to our shared memory buffer
// shmem <- (matmul_result -
// logsumexp[i_start:i_end].unsqueeze(1)).exp()
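// i.e. this recomputes Pij = exp(Sij - LSE_i), the attention probabilities
// from the forward pass, so the full attention matrix never has to be
// stored between forward and backward.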
int warp_idx_mn_0 =
warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN);
auto output_tile_coords =
cutlass::MatrixCoord{warp_idx_mn_0 % Mma::Base::WarpCount::kM,
warp_idx_mn_0 / Mma::Base::WarpCount::kM};
// apply bias if applicable
if (p.bias_ptr != nullptr) {
// load bias tile Bij into shared memory
typename MatmulQK::BiasLoader::GmemTileIterator bias_iter(
{cutlass::layout::RowMajor(p.bias_strideM)},
p.bias_ptr + query_start * p.bias_strideM + key_start,
{num_queries_in_block, num_keys_in_block},
thread_id);
cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> bias_tensor_ref(
shared_storage.bias().data(),
cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM));
typename MatmulQK::BiasLoader::SmemTileIterator smem_tile_iter(
bias_tensor_ref, thread_id);
MatmulQK::BiasLoader::load(bias_iter, smem_tile_iter);
// Pij += Bij, where Pij is in register fragment and Bij is in shmem
auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
lane_id, warp_id, output_tile_coords);
MatmulQK::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_n) {},
[&](int accum_m, int accum_n, int idx) {
// remember we are transposed
if (skipBoundsChecks || (accum_n < num_queries_in_block &&
accum_m < num_keys_in_block)) {
accum[idx] += bias_tensor_ref.at({accum_n, accum_m});
}
},
[&](int accum_n) {});
}
// Apply mask
if (p.causal) {
auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset(
lane_id, warp_id, output_tile_coords);
MatmulQK::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
[&](int accum_m, int accum_n, int idx) {
// (don't forget we are transposed!)
if (accum_m > accum_n + query_start - key_start) {
accum[idx] = -std::numeric_limits<accum_t>::infinity();
}
},
[&](int accum_m) {});
}
__syncthreads();
if (kPrologueGV) {
prologueGradV(0);
}
if (kPrologueDOV) {
prologueDOV();
}
MatmulQK::B2bGemm::accumApplyLSEToSmem(
shared_storage.attn_shared_storage(),
accum,
p.logsumexp_ptr + query_start,
problem_size.n(),
thread_id,
warp_id,
lane_id,
output_tile_coords);
// if we are using dropout, compute Zij, writing it to shared memory.
// each element of Zij is:
// - 0 with probability dropout_p
// - 1 / (1 - dropout_p) with probability 1 - dropout_p
if (kApplyDropout) {
auto zij = shared_storage.zij().accum_ref();
// each thread generates a contiguous sequence of elements in Zij, all
// in the same row. the reason they have to come from the same row is
// that sampling random numbers from a contiguous random number sequence
// is much more efficient than jumping around, and the linear offset of
// each element of Z (the global matrix) maps to an offset in a random
// number sequence. for Z, the end of a row and the beginning of the
// next have adjacent offsets, but for Zij (tile of global matrix), this
// is not necessarily the case.
const int num_threads = blockDim.x * blockDim.y * blockDim.z;
const int threads_per_row = cutlass::fast_min(
num_threads / num_queries_in_block, num_keys_in_block);
const int elts_per_thread = cutlass::round_nearest(
cutlass::ceil_div(num_keys_in_block, threads_per_row), 4);
const int thread_i = thread_id / threads_per_row;
const int thread_start_j =
(thread_id % threads_per_row) * elts_per_thread;
if (thread_i < num_queries_in_block &&
thread_start_j < num_keys_in_block) {
curandStatePhilox4_32_10_t curand_state = curand_state_init;
skipahead((query_start + thread_i) * p.num_keys +
(key_start + thread_start_j),
&curand_state);
const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
// generate elements of Zij, 4 elements at a time
for (int zij_start_col_idx = thread_start_j;
zij_start_col_idx <
cutlass::fast_min(thread_start_j + elts_per_thread,
num_keys_in_block);
zij_start_col_idx += 4) {
const float4 rand_uniform_quad = curand_uniform4(&curand_state);
CUTLASS_PRAGMA_UNROLL
for (int quad_idx = 0; quad_idx < 4; ++quad_idx) {
// we'll write Zij transposed since attention is also transposed
// during the matmul to compute dV.
zij.at({zij_start_col_idx + quad_idx, thread_i}) =
static_cast<scalar_t>(
dropout_scale *
((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob));
}
}
}
}
__syncthreads();
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// GradV matmul
//
// grad_v[j_start:j_end] += attn_T @ do_i
/////////////////////////////////////////////////////////////////////////////////////////////////
for (int col = 0; col < (kOutputInRF ? 1 : p.head_dim_value);
col += MatmulGradV::ThreadblockShape::kN) {
using Mma = typename MatmulGradV::Mma;
using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;
cutlass::gemm::GemmCoord problem_size(
num_keys_in_block, p.head_dim_value - col, num_queries_in_block);
auto createEpilogueIter = [&]() {
return typename MatmulGradV::OutputTileIterator(
typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()},
p.grad_value_ptr + key_start * p.gV_strideM() + col,
{num_keys_in_block, p.head_dim_value - col},
thread_id);
};
typename Mma::IteratorB iterator_B(
{int32_t(p.gO_strideM)},
p.grad_output_ptr + query_start * p.gO_strideM + col,
{num_queries_in_block, p.head_dim_value - col},
thread_id,
no_offset);
// if dropout: dVj += (Pij.T * Zij) @ dOi
// otherwise: dVj += Pij.T @ dOi
Mma mma(shared_storage.mm_gradV(),
// operand A: Pij
typename MatmulGradV::WarpIteratorA(
shared_storage.attn_shared_storage().accum_ref(), lane_id),
// if we're using dropout, operand A is Pij_dropped = Pij * Zij
// which is computed on the fly as fragments of Pij are loaded in
typename Mma::WarpIteratorAScale(shared_storage.zij().accum_ref(),
lane_id),
thread_id,
warp_id,
lane_id);
int storage_id = col / MatmulGradV::ThreadblockShape::kN;
AccumTileGmem gmem_tile{p.workspace_gv +
storage_id * AccumTileGmem::kElementsStored};
if (!kOutputInRF) {
if (isFirstQuery || !kNeedsAccumGradV) {
output_frags.gradV.clear();
} else {
gmem_tile.load(output_frags.gradV, thread_id);
}
}
mma.set_prologue_done(kPrologueGV);
auto gemm_k_iterations =
(problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
__syncthreads();
mma(gemm_k_iterations,
output_frags.gradV,
iterator_B,
output_frags.gradV);
__syncthreads();
if (kPrologueGV &&
col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) {
prologueGradV(col + MatmulGradV::ThreadblockShape::kN);
}
if (!kOutputInRF) {
if (kNeedsAccumGradV && !isLastQuery) {
gmem_tile.store(output_frags.gradV, thread_id);
} else {
accumulateInGmem<MatmulGradV>(shared_storage.gradV_epilogue(),
output_frags.gradV,
createEpilogueIter(),
isFirstQuery || kNeedsAccumGradV);
}
}
}
__syncthreads();
/////////////////////////////////////////////////////////////////////////////////////////////////
// MatmulDOIVJ
/////////////////////////////////////////////////////////////////////////////////////////////////
{
using Mma = typename MatmulDOIVJ::Mma;
// do_i
typename Mma::IteratorA iterator_A(
{int32_t(p.gO_strideM)},
p.grad_output_ptr + query_start * p.gO_strideM,
{num_queries_in_block, p.head_dim_value},
thread_id,
no_offset);
// v_j.transpose(-2, -1)
typename Mma::IteratorB iterator_B({int32_t(p.v_strideM)},
p.value_ptr + key_start * p.v_strideM,
{p.head_dim_value, num_keys_in_block},
thread_id,
no_offset);
Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id);
mma.set_prologue_done(kPrologueDOV);
mma.set_zero_outside_bounds(!skipBoundsChecks);
typename Mma::FragmentC accum;
accum.clear();
auto gemm_k_iterations =
(p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
__syncthreads();
if (kPrologueGQ) {
prologueGradQ(0);
}
if (kPrologueGK) {
prologueGradK(0);
}
int warp_idx_mn_0 =
warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN);
auto output_tile_coords =
cutlass::MatrixCoord{warp_idx_mn_0 % Mma::Base::WarpCount::kM,
warp_idx_mn_0 / Mma::Base::WarpCount::kM};
// TODO(xformers): This must be terribly inefficient. There must be a
// better way.
// tmp [RF] <- (accum [RF] - Di [smem]) * attn_T.T [smem]
// attn_shared_storage [smem] <- tmp.T
// tmp_shared_storage [smem] <- tmp
{
using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
typename Mma::Operator::IteratorC,
typename MatmulDOIVJ::ElementAccum,
kWarpSize>::Iterator;
auto lane_offset = LambdaIterator::get_lane_offset(
lane_id, warp_id, output_tile_coords);
// if dropout was used, compute dPij = dPij_dropped * Zij
// Zij was written to shared memory earlier, and the elementwise
// multiplication occurs on a fragment of dPij_dropped
if (kApplyDropout) {
const auto zij = shared_storage.zij().accum_ref();
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
[&](int accum_m, int accum_n, int idx) {
const int global_query_idx = query_start + accum_m;
const int global_key_idx = key_start + accum_n;
if (skipBoundsChecks || (global_query_idx < p.num_queries &&
global_key_idx < p.num_keys)) {
accum[idx] *= zij.at({accum_n, accum_m});
}
},
[&](int accum_m) {});
}
auto attn_T = shared_storage.attn_shared_storage().accum_ref();
accum_t current_di;
typename Mma::FragmentC fragment_attn, fragment_di;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
current_di = shared_storage.di()[accum_m];
}, // NOLINT
[&](int accum_m, int accum_n, int idx) { // NOLINT
// TODO(xformers): Otherwise we can get nans as we
// might have infs here (only seen on f16 tho)
if (skipBoundsChecks || (accum_m < num_queries_in_block &&
accum_n < num_keys_in_block)) {
fragment_attn[idx] = attn_T.at({accum_n, accum_m});
} else {
fragment_attn[idx] = 0;
}
fragment_di[idx] = current_di;
},
[&](int accum_m) {});
// dSij = (dPij - Di) * Pij
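// Note: Di = rowsum(dOi * Oi) = rowsum(dPij * Pij) summed over all keys,
// which is the standard softmax-backward identity; using Di avoids having
// to accumulate dP row sums across key blocks.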
accum = (accum - fragment_di) * fragment_attn;
// store bias gradient tile dBij to global memory,
// where dBij = dSij = Pij * (dPij - Di)
if (p.grad_bias_ptr != nullptr) {
typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator
output_iter(
typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator::
Params{p.gB_strideM},
// grad_bias_ptr is offset to point at beginning of
// matrix of shape (queries, keys) for a given
// (batch_id, head_id) the pointer arithmetic here produces
// a pointer to the start of the current tile within that
// matrix
p.grad_bias_ptr + query_start * p.gB_strideM + key_start,
{num_queries_in_block, num_keys_in_block},
thread_id);
// no-op epilogue operator - just casting and storing contents of
// accum to global memory
typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op({1, 1});
typename MatmulDOIVJ::BiasGradEpilogue epilogue(
shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id);
epilogue(output_op, output_iter, accum, output_iter);
}
accum = accum * scale;
__syncthreads();
if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) {
auto tmpT = shared_storage.tmpT_shared_storage().accum_ref();
// attn <- attn_T.T
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
[&](int accum_m, int accum_n, int idx) {
tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]);
},
[&](int accum_m) {});
}
}
MatmulDOIVJ::B2bGemm::accumToSmem(shared_storage.tmp_shared_storage(),
accum,
lane_id,
output_tile_coords);
__syncthreads();
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// GradQ matmul
//
// grad_q[i_start:i_end] += tmp @ k_j
/////////////////////////////////////////////////////////////////////////////////////////////////
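// At this point `tmp` (tmp_shared_storage) holds dSij * scale (the scale
// was applied above, right before accumToSmem), so this loop accumulates
// dQi += dSij @ Kj * scale, tiled over columns of head_dim.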
for (int col = 0; col < p.head_dim;
col += MatmulGradQ::ThreadblockShape::kN) {
using Mma = typename MatmulGradQ::Mma;
using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;
cutlass::gemm::GemmCoord problem_size(
num_queries_in_block,
false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col,
num_keys_in_block);
// k_j
typename Mma::IteratorB iterator_B(
{int32_t(p.k_strideM)},
p.key_ptr + key_start * p.k_strideM + col,
{problem_size.k(), problem_size.n()},
thread_id,
no_offset);
auto a = shared_storage.tmp_shared_storage().accum_ref();
Mma mma(shared_storage.mm_gradQ(),
shared_storage.tmp_shared_storage(),
thread_id,
warp_id,
lane_id,
problem_size.k());
typename Mma::FragmentC accum;
bool isFirst = key_start == 0;
int col_id = col / MatmulGradQ::ThreadblockShape::kN;
int storage_id =
(col_id +
query_start / kBlockSizeI *
ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN));
AccumTileGmem gmem_tile{p.workspace_gq +
storage_id * AccumTileGmem::kElementsStored};
if (isFirst || !kNeedsAccumGradQ) {
accum.clear();
} else {
gmem_tile.load(accum, thread_id);
}
auto gemm_k_iterations =
(problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
__syncthreads();
mma.set_prologue_done(kPrologueGQ);
mma(gemm_k_iterations, accum, iterator_B, accum);
__syncthreads();
bool isLastColumn = col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim;
if (kPrologueGQ && !isLastColumn) {
prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN);
}
// Output results
int32_t next_query, next_key;
incrIteration(p, p.num_queries, key_start, next_query, next_key);
bool isLast =
(p.causal && next_query > query_start) || next_key >= p.num_keys;
if (kNeedsAccumGradQ && !isLast) {
gmem_tile.store(accum, thread_id);
} else {
typename MatmulGradQ::OutputTileIterator output_it(
typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()},
p.grad_query_ptr + query_start * p.gQ_strideM() + col,
{problem_size.m(), problem_size.n()},
thread_id);
accumulateInGmem<MatmulGradQ>(
isLastColumn ? shared_storage.gradQ_epilogue_lastIter()
: shared_storage.gradQ_epilogue(),
accum,
output_it,
isFirst || kNeedsAccumGradQ);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////
// GradK matmul
//
// grad_k[j_start:j_end] += tmp.transpose(-2, -1) @ q_i
/////////////////////////////////////////////////////////////////////////////////////////////////
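// Same (scaled) dSij tile, used transposed: dKj += dSij.T @ Qi * scale.
// When the MMA cannot consume the operand transposed in-place, the explicit
// transposed copy written to tmpT_shared_storage above is used instead
// (see kIsTransposedA below).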
for (int col = 0; col < (kOutputInRF ? 1 : p.head_dim);
col += MatmulGradK::ThreadblockShape::kN) {
using Mma = typename MatmulGradK::Mma;
using AccumTileGmem = typename MatmulGradQ::AccumTileGmem;
cutlass::gemm::GemmCoord problem_size(
num_keys_in_block,
false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col,
num_queries_in_block);
auto createEpilogueIter = [&]() {
return typename MatmulGradK::OutputTileIterator(
typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()},
p.grad_key_ptr + key_start * p.gK_strideM() + col,
{num_keys_in_block,
false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col},
thread_id);
};
// q_i
typename Mma::IteratorB iterator_B(
{int32_t(p.q_strideM)},
p.query_ptr + query_start * p.q_strideM + col,
{problem_size.k(), problem_size.n()},
thread_id,
no_offset);
auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); };
auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); };
// this is basically:
// opA = kIsTransposedA ? getTmp() : getTmpT();
bool constexpr kIsTransposedA =
MatmulGradK::DefaultMmaFromSmem::kIsTransposedA;
auto& opA =
*call_conditional<kIsTransposedA,
decltype(getTmp),
decltype(getTmpT)>::apply(getTmp, getTmpT, 0);
Mma mma(shared_storage.mm_gradK(),
opA,
thread_id,
warp_id,
lane_id,
problem_size.k());
int storage_id = col / MatmulGradK::ThreadblockShape::kN;
AccumTileGmem gmem_tile{p.workspace_gk +
storage_id * AccumTileGmem::kElementsStored};
if (!kOutputInRF) {
if (isFirstQuery || !kNeedsAccumGradK) {
output_frags.gradK.clear();
} else {
gmem_tile.load(output_frags.gradK, thread_id);
}
}
mma.set_prologue_done(kPrologueGK);
auto gemm_k_iterations =
(problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
__syncthreads();
mma(gemm_k_iterations,
output_frags.gradK,
iterator_B,
output_frags.gradK);
__syncthreads();
bool isLastColumn = col + MatmulGradK::ThreadblockShape::kN >= p.head_dim;
if (kPrologueGK && !isLastColumn) {
prologueGradK(col + MatmulGradK::ThreadblockShape::kN);
}
if (kPrologueQK && isLastColumn) {
int32_t next_query, next_key;
incrIteration(p, query_start, key_start, next_query, next_key);
DISPATCH_BOOL(next_key != key_start, kForceReloadK, ([&]() {
prologueQkNextIteration<kForceReloadK>(
shared_storage, p, next_query, next_key);
}));
}
// Output results
if (!kOutputInRF) {
if (kNeedsAccumGradK && !isLastQuery) {
gmem_tile.store(output_frags.gradK, thread_id);
} else {
accumulateInGmem<MatmulGradK>(
isLastColumn ? shared_storage.gradK_epilogue_final()
: shared_storage.gradK_epilogue(),
output_frags.gradK,
createEpilogueIter(),
isFirstQuery || kNeedsAccumGradK);
}
}
}
}
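// With causal masking, key j only receives gradient from queries i >= j,
// so for a given key block the query loop can start at the block-aligned
// key offset instead of 0.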
static CUTLASS_DEVICE int32_t getQueryStart(Params const& p,
int32_t key_start) {
if (p.causal) {
return (key_start / kBlockSizeI) * kBlockSizeI;
}
return 0;
}
static CUTLASS_DEVICE void incrIteration(Params const& p, // NOLINT
int32_t query_start,
int32_t key_start,
int32_t& next_query, // NOLINT
int32_t& next_key) { // NOLINT
next_query = query_start + kBlockSizeI;
next_key = key_start;
if (next_query >= p.num_queries) {
next_key = key_start + kBlockSizeJ;
next_query = getQueryStart(p, next_key);
}
}
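// Preloads the K/Q tiles of the next (query, key) block into shared memory,
// so that their global-memory loads overlap with the tail of the current
// iteration. Only used when kPrologueQK is enabled.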
template <bool kForceReloadK>
static CUTLASS_DEVICE void prologueQkNextIteration(
SharedStorage& shared_storage, // NOLINT
Params const& p, // NOLINT
int32_t query_start,
int32_t key_start) {
if (query_start >= p.num_queries || key_start >= p.num_keys) {
return;
}
static constexpr bool kReloadK =
kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat;
auto thread_id = get_thread_id();
typename MatmulQK::Mma::IteratorA iterator_A(
{int32_t(p.k_strideM)},
p.key_ptr + key_start * p.k_strideM,
{p.num_keys - key_start, p.head_dim},
thread_id,
cutlass::MatrixCoord{0, 0});
typename MatmulQK::Mma::IteratorB iterator_B(
{int32_t(p.q_strideM)},
p.query_ptr + query_start * p.q_strideM,
{p.head_dim, p.num_queries - query_start},
thread_id,
cutlass::MatrixCoord{0, 0});
MatmulQK::Mma::prologue<kReloadK, true>(shared_storage.mm_qk_k(),
shared_storage.mm_qk_q(),
iterator_A,
iterator_B,
thread_id,
p.head_dim);
}
template <bool skipBoundsChecks>
static CUTLASS_DEVICE void writeFragsToGmem(
SharedStorage& shared_storage, // NOLINT
OutputFragments& output_frags, // NOLINT
Params const& p, // NOLINT
int32_t key_start) {
int32_t num_keys_in_block =
skipBoundsChecks ? MatmulQK::Mma::Shape::kM
: std::min((int32_t)MatmulQK::Mma::Shape::kM,
p.num_keys - key_start);
typename MatmulGradV::OutputTileIterator outputV_it(
typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()},
p.grad_value_ptr + key_start * p.gV_strideM(),
{num_keys_in_block, p.head_dim_value},
get_thread_id());
accumulateInGmem<MatmulGradV>(shared_storage.gradV_epilogue_final(),
output_frags.gradV,
outputV_it,
true);
typename MatmulGradK::OutputTileIterator outputK_it(
typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()},
p.grad_key_ptr + key_start * p.gK_strideM(),
{num_keys_in_block,
false ? MatmulGradK::ThreadblockShape::kN : p.head_dim},
get_thread_id());
accumulateInGmem<MatmulGradK>(shared_storage.gradK_epilogue_final(),
output_frags.gradK,
outputK_it,
true);
}
template <typename MatmulT>
static CUTLASS_DEVICE void accumulateInGmem(
typename MatmulT::DefaultEpilogue::SharedStorage&
epilogue_smem, // NOLINT
typename MatmulT::Mma::FragmentC const& accum, // NOLINT
typename MatmulT::OutputTileIterator output_it,
bool first) {
using DefaultEpilogue = typename MatmulT::DefaultEpilogue;
using DefaultOutputOp = typename MatmulT::DefaultOutputOp;
using Mma = typename MatmulT::Mma;
DISPATCH_BOOL(
first, kIsFirst, ([&]() {
static constexpr auto ScaleType =
kIsFirst ? cutlass::epilogue::thread::ScaleType::Nothing
: cutlass::epilogue::thread::ScaleType::NoBetaScaling;
using EpilogueOutputOp =
typename cutlass::epilogue::thread::LinearCombination<
typename DefaultOutputOp::ElementOutput,
DefaultOutputOp::kCount,
typename DefaultOutputOp::ElementAccumulator,
typename DefaultOutputOp::ElementCompute,
ScaleType>;
using Epilogue =
typename cutlass::epilogue::threadblock::EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename MatmulT::OutputTileIterator,
typename DefaultEpilogue::AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true // IterationsUnroll
>;
EpilogueOutputOp rescale({1, 1});
Epilogue epilogue(
epilogue_smem, get_thread_id(), get_warp_id(), get_lane_id());
epilogue(rescale, output_it, accum, output_it);
}));
}
template <int kElementsPerAccess>
static CUTLASS_DEVICE void computeDelta(Params const& p,
int32_t query_start) {
// Each thread computes one value for Delta
// Depending on warp configuration, we might have multiple
// threads of the same warp working on the same row
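// Concretely: delta[q] = sum over head_dim_value of
// output[q, :] * grad_output[q, :], computed with vectorized loads (128-bit
// where the head dim allows), a 2-stage software pipeline, and a final
// __shfl_xor_sync reduction across the threads sharing a row.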
using AccessType = cutlass::Array<scalar_t, kElementsPerAccess>;
static_assert(kNumThreads >= kBlockSizeI, "");
static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI;
int16_t thread_id = get_thread_id();
int16_t laneFirstCol =
kElementsPerAccess * (get_lane_id() % kNumThreadsPerLine);
int16_t laneRow = thread_id / kNumThreadsPerLine;
bool rowPred = (query_start + laneRow) < p.num_queries;
bool pred = rowPred;
// on windows, previous syntax __restrict__ AccessType*
// resulted in error: "restrict" is not allowed
const AccessType* __restrict__ grad_output_ptr =
reinterpret_cast<const AccessType*>(
p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM +
laneFirstCol);
const AccessType* __restrict__ output_ptr =
reinterpret_cast<const AccessType*>(
p.output_ptr + (query_start + laneRow) * p.o_strideM() +
laneFirstCol);
static constexpr int64_t kMaxIters =
kMaxK / (kElementsPerAccess * kNumThreadsPerLine);
constexpr int kPipelineStages = 2;
accum_t delta_value = accum_t(0);
using GlobalLoad =
cutlass::arch::global_load<AccessType, sizeof(AccessType)>;
AccessType frag_grad_output[kPipelineStages];
AccessType frag_output[kPipelineStages];
auto loadAndIncrement = [&](int ld_pos, bool is_valid) {
frag_grad_output[ld_pos].clear();
frag_output[ld_pos].clear();
GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid);
GlobalLoad(frag_output[ld_pos], output_ptr, is_valid);
grad_output_ptr += kNumThreadsPerLine;
output_ptr += kNumThreadsPerLine;
};
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < kPipelineStages - 1; ++iter) {
int ld_pos = iter % kPipelineStages;
pred = pred && (laneFirstCol + iter * kElementsPerAccess *
kNumThreadsPerLine) < p.head_dim_value;
loadAndIncrement(ld_pos, pred);
}
auto columnIteration = [&](int iter) {
// Load for next iter
int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages;
pred = pred &&
(laneFirstCol + (iter + kPipelineStages - 1) * kElementsPerAccess *
kNumThreadsPerLine) < p.head_dim_value;
loadAndIncrement(ld_pos, pred);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < AccessType::kElements; ++i) {
delta_value += accum_t(frag_output[iter % kPipelineStages][i]) *
accum_t(frag_grad_output[iter % kPipelineStages][i]);
}
};
// If we have a small lower-bound for K, we can unroll the loop
if (kMaxK <= 256) {
CUTLASS_PRAGMA_UNROLL
for (int iter = 0; iter < kMaxIters; ++iter) {
columnIteration(iter);
}
} else {
int num_iters =
ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) *
(kElementsPerAccess * kNumThreadsPerLine);
for (int iter = 0; iter < num_iters; ++iter) {
columnIteration(iter);
}
}
// Reduce between workers
static_assert(kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 ||
kNumThreadsPerLine == 4,
"");
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < kNumThreadsPerLine; i *= 2) {
delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i);
}
// Store in gmem
if (rowPred) {
p.delta_ptr[query_start + laneRow] = delta_value;
}
}
static CUTLASS_DEVICE int8_t get_lane_id() { return threadIdx.x; }
static CUTLASS_DEVICE int8_t get_warp_id() { return threadIdx.y; }
static CUTLASS_DEVICE int16_t get_thread_id() {
return threadIdx.x + threadIdx.y * blockDim.x;
}
};
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_backward_batched_impl(typename AK::Params p) {
if (!p.advance_to_block()) {
return;
}
AK::attention_kernel(p);
}
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_backward_batched(typename AK::Params params);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <curand_kernel.h>
#include <cmath>
#include <vector>
#include "cutlass/bfloat16.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/vector.h"
#include "cutlass/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/platform/platform.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include <inttypes.h> //NOLINT
#include "./debug_utils.h"
#include "./gemm_kernel_utils.h"
#include "epilogue/epilogue_pipelined.h"
#include "epilogue/epilogue_rescale_output.h"
#include "gemm/find_default_mma.h"
#include "gemm/mma_from_smem.h"
#include "transform/tile_smem_loader.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/phi/core/enforce.h"
// namespace phi {
using namespace gemm_kernel_utils; // NOLINT
namespace { // NOLINT
template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSm() {
return (Arch::kMinComputeCapability >= 80 &&
!cutlass::platform::is_same<scalar_t, float>::value
? 16
: 12);
}
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
// source: https://stackoverflow.com/a/51549250
return (value >= 0) ? __int_as_float(atomicMax(
(int*)addr, __float_as_int(value))) // NOLINT
: __uint_as_float(atomicMin((unsigned int*)addr,
__float_as_uint(value)));
}
} // namespace
template <
// The datatype of Q/K/V
typename scalar_t_,
// Architecture we are targeting (eg `cutlass::arch::Sm80`)
typename ArchTag,
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
bool isAligned_,
int kQueriesPerBlock,
int kKeysPerBlock_,
bool kSingleValueIteration_, // = `value.shape[-1] <= kKeysPerBlock`
// This is quite a bit slower on V100 for some reason
// Set to false if you know at compile-time you will never need dropout
bool kSupportsDropout_ = true,
bool kSupportsBias_ = true>
struct AttentionKernel {
using scalar_t = scalar_t_;
using accum_t = float;
using lse_scalar_t = float;
using output_t = scalar_t;
// Accumulator between 2 iterations
// Using `accum_t` improves perf on f16 at the cost of
// numerical errors
using output_accum_t = accum_t;
static constexpr bool kSupportsDropout = kSupportsDropout_;
static constexpr bool kSupportsBias = kSupportsBias_;
static constexpr int kKeysPerBlock = kKeysPerBlock_;
static constexpr bool kIsAligned = isAligned_;
static constexpr bool kSingleValueIteration = kSingleValueIteration_;
static constexpr int32_t kAlignLSE = 32; // block size of backward
static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
cutlass::sizeof_bits<scalar_t>::value == 16;
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
static constexpr bool kNeedsOutputAccumulatorBuffer =
!kKeepOutputInRF &&
!cutlass::platform::is_same<output_accum_t, output_t>::value;
static_assert(kQueriesPerBlock % 32 == 0, "");
static_assert(kKeysPerBlock % 32 == 0, "");
static constexpr int kNumWarpsPerBlock =
kQueriesPerBlock * kKeysPerBlock / (32 * 32);
static constexpr int kWarpSize = 32;
// Launch bounds
static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
static constexpr int kMinBlocksPerSm =
getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
struct Params {
// Input tensors
scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys]
int32_t* seqstart_q_ptr = nullptr;
int32_t* seqstart_k_ptr = nullptr;
int32_t* causal_diagonal_ptr = nullptr;
int32_t* seqlen_k_ptr = nullptr;
uint32_t causal_diagonal_offset = 0;
// Output tensors
output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
output_accum_t*
output_accum_ptr; // [num_queries, num_heads, head_dim_value]
lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
// Scale
accum_t scale;
// Dimensions/strides
int32_t head_dim;
int32_t head_dim_value;
int32_t num_queries;
int32_t num_keys;
bool causal;
int32_t q_strideM;
int32_t k_strideM;
int32_t v_strideM;
int32_t bias_strideM = 0;
int32_t o_strideM = 0;
// Everything below is only used in `advance_to_block`
// and shouldn't use registers
int32_t q_strideH;
int32_t k_strideH;
int32_t v_strideH;
int32_t bias_strideH = 0;
int64_t q_strideB;
int64_t k_strideB;
int64_t v_strideB;
int32_t bias_strideB = 0;
int32_t num_batches;
int32_t num_heads;
// dropout
bool use_dropout;
unsigned long long dropout_batch_head_rng_offset; // NOLINT
float dropout_prob;
uint64_t seed;
uint64_t offset;
// Moves pointers to what we should process
// Returns "false" if there is no work to do
CUTLASS_DEVICE bool advance_to_block() {
auto batch_id = blockIdx.z;
auto head_id = blockIdx.y;
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
if (kSupportsDropout) {
dropout_batch_head_rng_offset =
batch_id * num_heads * num_queries * num_keys +
head_id * num_queries * num_keys;
}
int64_t q_start, k_start;
// Advance to current batch - in case of different sequence lengths
if (seqstart_q_ptr != nullptr) {
assert(seqstart_k_ptr != nullptr);
seqstart_q_ptr += batch_id;
q_start = seqstart_q_ptr[0];
int64_t q_next_start = seqstart_q_ptr[1];
int64_t k_end;
seqstart_k_ptr += batch_id;
if (seqlen_k_ptr) {
k_start = seqstart_k_ptr[0];
k_end = k_start + seqlen_k_ptr[batch_id];
} else {
k_start = seqstart_k_ptr[0];
k_end = seqstart_k_ptr[1];
}
num_queries = q_next_start - q_start;
num_keys = k_end - k_start;
if (query_start >= num_queries) {
return false;
}
} else {
query_ptr += batch_id * q_strideB;
key_ptr += batch_id * k_strideB;
value_ptr += batch_id * v_strideB;
output_ptr += int64_t(batch_id * num_queries) * o_strideM;
if (output_accum_ptr != nullptr) {
output_accum_ptr +=
int64_t(batch_id * num_queries) * (head_dim_value * num_heads);
}
q_start = 0;
k_start = 0;
}
// Advance to the current batch / head / query_start
query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
key_ptr += k_start * k_strideM + head_id * k_strideH;
value_ptr += k_start * v_strideM + head_id * v_strideH;
output_ptr +=
int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value;
if (kSupportsBias && attn_bias_ptr != nullptr) {
attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH);
}
if (output_accum_ptr != nullptr) {
output_accum_ptr +=
int64_t(q_start + query_start) * (head_dim_value * num_heads) +
head_id * head_dim_value;
} else {
// Accumulate directly in the destination buffer (eg for f32)
output_accum_ptr = (accum_t*)output_ptr; // NOLINT
}
if (logsumexp_ptr != nullptr) {
// lse[batch_id, head_id, query_start]
logsumexp_ptr +=
batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
}
if (causal_diagonal_ptr) {
causal_diagonal_offset = causal_diagonal_ptr[batch_id];
}
num_queries -= query_start;
if (causal) {
// the bottom row of the current block is query_start + kQueriesPerBlock
// the last active key is then query_start + causal_diagonal_offset +
// kQueriesPerBlock so num_keys is the min between actual num_keys and
// this to avoid extra computations
num_keys = cutlass::fast_min(
int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock),
num_keys);
}
num_batches = 0; // no longer used after
// If num_queries == 1 and there is only one key head, we're wasting
// 15/16th of the tensor core compute. In that case:
// - we only launch kernels for head_id % kQueriesPerBlock == 0
// - we iterate over heads instead of queries (strideM = strideH)
if (num_queries == 1 && k_strideH == 0) {
if (head_id % kQueriesPerBlock != 0) return false;
q_strideM = q_strideH;
num_queries = num_heads;
num_heads = 1; // unused but here for intent
// remove causal since n_query = 1
// otherwise, offset would change with head !
causal = false;
o_strideM = head_dim_value;
}
// Make sure the compiler knows these variables are the same on all
// the threads of the warp.
query_ptr = warp_uniform(query_ptr);
key_ptr = warp_uniform(key_ptr);
value_ptr = warp_uniform(value_ptr);
if (kSupportsBias) {
attn_bias_ptr = warp_uniform(attn_bias_ptr);
}
output_ptr = warp_uniform(output_ptr);
output_accum_ptr = warp_uniform(output_accum_ptr);
logsumexp_ptr = warp_uniform(logsumexp_ptr);
num_queries = warp_uniform(num_queries);
num_keys = warp_uniform(num_keys);
num_heads = warp_uniform(num_heads);
head_dim = warp_uniform(head_dim);
head_dim_value = warp_uniform(head_dim_value);
o_strideM = warp_uniform(o_strideM);
return true;
}
__host__ dim3 getBlocksGrid() const {
return dim3(ceil_div(num_queries, (int32_t)kQueriesPerBlock),
num_heads,
num_batches);
}
__host__ dim3 getThreadsGrid() const {
return dim3(kWarpSize, kNumWarpsPerBlock, 1);
}
};
struct MM0 {
/*
In this first matmul, we compute a block of `Q @ K.T`.
While the calculation result is still hot in registers, we update
`mi`, `m_prime`, `s_prime` in shared-memory, and then store this value
into a shared-memory buffer ("AccumulatorSharedStorage") that is used later as
operand A for the second matmul (see MM1)
*/
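// mi holds the running row-wise max, m_prime the max from the previous
// iteration, and s_prime the running softmax denominator; iterative_softmax
// (further down) rewrites accum[i] as exp(accum[i] - mi) and rescales the
// output accumulator accordingly.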
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
using OpClass = typename GemmType::OpClass;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
ArchTag,
scalar_t,
scalar_t,
scalar_t, // ElementC
accum_t // ElementAccumulator
>;
static constexpr int kAlignmentA =
kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
static constexpr int kAlignmentB =
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
using ThreadblockShape = cutlass::gemm::
GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
scalar_t, // ElementA,
cutlass::layout::RowMajor, // LayoutA,
kAlignmentA,
scalar_t, // ElementB,
cutlass::layout::ColumnMajor, // LayoutB,
kAlignmentB,
accum_t,
cutlass::layout::RowMajor, // LayoutC,
OpClass,
ArchTag, // ArchTag
ThreadblockShape, // ThreadblockShape
WarpShape, // WarpShape
typename GemmType::InstructionShape, // InstructionShape
DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but
// that uses too much smem
typename GemmType::Operator // Operator
>::DefaultMma;
using MmaCore = typename DefaultMma::MmaCore;
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator<
typename Mma::Operator::IteratorC,
accum_t,
kWarpSize>::Iterator;
static_assert(MmaCore::WarpCount::kM * MmaCore::WarpCount::kN *
MmaCore::WarpCount::kK ==
kNumWarpsPerBlock,
"");
// used for efficient load of bias tile Bij from global to shared memory
using BiasLoader =
TileSmemLoader<scalar_t,
cutlass::MatrixShape<kQueriesPerBlock, kKeysPerBlock>,
MmaCore::kThreads,
// input restriction: kv_len has to be a multiple of this
// value
128 / cutlass::sizeof_bits<scalar_t>::value>;
// Epilogue to store to shared-memory in a format that we can use later for
// the second matmul
using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
typename Mma::Operator::IteratorC,
typename Mma::Operator,
scalar_t,
WarpShape,
ThreadblockShape>;
using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
};
struct MM1 {
/**
Second matmul: perform `attn @ V` where `attn` is the attention (not
normalized) and stored in shared memory
*/
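// Roughly: accum_o accumulates attn @ V across key blocks; it is rescaled
// as the running softmax statistics (m_prime, s_prime) are updated and is
// divided by s_prime before being written out through the epilogue.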
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
using OpClass = typename GemmType::OpClass;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
ArchTag,
scalar_t,
scalar_t,
output_accum_t, // ElementC
accum_t // ElementAccumulator
>;
static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem
static constexpr int kAlignmentB =
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
using ThreadblockShape = cutlass::gemm::
GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using InstructionShape = typename GemmType::InstructionShape;
using LayoutB = cutlass::layout::RowMajor;
using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
scalar_t, // ElementA,
cutlass::layout::RowMajor, // LayoutA,
kAlignmentA,
scalar_t, // ElementB,
LayoutB, // LayoutB,
kAlignmentB,
output_accum_t,
cutlass::layout::RowMajor, // LayoutC,
accum_t,
OpClass,
ArchTag,
ThreadblockShape,
WarpShape,
typename GemmType::InstructionShape,
typename DefaultConfig::EpilogueOutputOp,
void, // ThreadblockSwizzle - not used
DefaultConfig::kStages,
false, // SplitKSerial
typename GemmType::Operator>;
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MM0::AccumulatorSharedStorage,
false>; // kScaleOperandA
using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;
using WarpCount = typename Mma::WarpCount;
static_assert(WarpCount::kM * WarpCount::kN * WarpCount::kK ==
kNumWarpsPerBlock,
"");
using DefaultEpilogue = typename DefaultGemm::Epilogue;
using OutputTileIterator =
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_t>;
using OutputTileIteratorAccum =
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_accum_t>;
struct SharedStorageMM1 {
typename Mma::SharedStorage mm;
};
};
static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
static constexpr int64_t kAlignmentK = MM0::kAlignmentB;
static constexpr int64_t kAlignmentV = 1;
// Shared storage - depends on kernel params
struct ScalingCoefs {
cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
cutlass::Array<accum_t, kQueriesPerBlock> mi;
};
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
union {
typename MM0::BiasLoader::SmemTile bias;
typename MM0::AccumulatorSharedStorage si;
};
typename MM1::SharedStorageMM1 mm1;
};
union {
typename MM0::Mma::SharedStorage mm0;
SharedStorageAfterMM0 after_mm0;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
epilogue_shared_storage() {
return epilogue;
}
};
struct SharedStorageEpilogueInLoop : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
union {
typename MM0::BiasLoader::SmemTile bias;
typename MM0::AccumulatorSharedStorage si;
};
typename MM1::SharedStorageMM1 mm1;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
union {
typename MM0::Mma::SharedStorage mm0;
SharedStorageAfterMM0 after_mm0;
};
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
epilogue_shared_storage() {
return after_mm0.epilogue;
}
};
using SharedStorage = typename cutlass::platform::conditional<
kSingleValueIteration || kKeepOutputInRF,
SharedStorageEpilogueAtEnd,
SharedStorageEpilogueInLoop>::type;
static bool __host__ check_supported(Params const& p) {
CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
if (kSupportsBias) {
CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
PADDLE_ENFORCE_EQ(p.bias_strideB % kAlignmentQ,
0,
paddle::platform::errors::InvalidArgument(
"attn_bias is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.bias_strideH % kAlignmentQ,
0,
paddle::platform::errors::InvalidArgument(
"attn_bias is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.bias_strideM % kAlignmentQ,
0,
paddle::platform::errors::InvalidArgument(
"attn_bias is not correctly aligned"));
}
PADDLE_ENFORCE_EQ(p.q_strideM % kAlignmentQ,
0,
paddle::platform::errors::InvalidArgument(
"query is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.k_strideM % kAlignmentK,
0,
paddle::platform::errors::InvalidArgument(
"key is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.v_strideM % kAlignmentV,
0,
paddle::platform::errors::InvalidArgument(
"value is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.q_strideH % kAlignmentQ,
0,
paddle::platform::errors::InvalidArgument(
"query is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.k_strideH % kAlignmentK,
0,
paddle::platform::errors::InvalidArgument(
"key is not correctly aligned"));
PADDLE_ENFORCE_EQ(p.v_strideH % kAlignmentV,
0,
paddle::platform::errors::InvalidArgument(
"value is not correctly aligned"));
return true;
}
static void CUTLASS_DEVICE attention_kernel(Params& p) { // NOLINT
// In this block, we will only ever:
// - read query[query_start:query_end, :]
// - write to output[query_start:query_end, :]
extern __shared__ char smem_buffer[];
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); // NOLINT
auto& m_prime = shared_storage.m_prime;
auto& s_prime = shared_storage.s_prime;
auto& si = shared_storage.after_mm0.si;
auto& mi = shared_storage.mi;
const uint32_t query_start = blockIdx.x * kQueriesPerBlock;
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = accum_t(0);
m_prime[thread_id()] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
}
typename MM1::Mma::FragmentC accum_o;
accum_o.clear();
auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
using OutputTileIterator = typename MM1::OutputTileIterator;
return OutputTileIterator(
typename OutputTileIterator::Params{(int32_t)p.o_strideM},
p.output_ptr,
typename OutputTileIterator::TensorCoord{p.num_queries,
p.head_dim_value},
thread_id(),
{0, col});
};
auto createOutputAccumIter = [&](int col) ->
typename MM1::OutputTileIteratorAccum {
using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
return OutputTileIteratorAccum(
typename OutputTileIteratorAccum::Params{
(int32_t)(p.head_dim_value * p.num_heads)},
p.output_accum_ptr,
typename OutputTileIteratorAccum::TensorCoord{p.num_queries,
p.head_dim_value},
thread_id(),
{0, col});
};
curandStatePhilox4_32_10_t curand_state_init;
if (kSupportsDropout && p.use_dropout) {
// each element of the attention matrix P with shape
// (batch_sz, n_heads, n_queries, n_keys) is associated with a single
// offset in RNG sequence. we initialize the RNG state with offset that
// starts at the beginning of a (n_queries, n_keys) matrix for this
// block's batch_id and head_id
// initializing rng state is very expensive, so we run once per kernel,
// rather than once per iteration. each iteration takes a copy of the
// initialized RNG state and offsets it as needed.
curand_init(p.seed,
0,
p.offset + p.dropout_batch_head_rng_offset,
&curand_state_init);
}
// Iterate through keys
for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
iter_key_start += kKeysPerBlock) {
int32_t problem_size_0_m =
cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
int32_t problem_size_0_n = cutlass::fast_min(int32_t(kKeysPerBlock),
p.num_keys - iter_key_start);
int32_t const& problem_size_0_k = p.head_dim;
int32_t const& problem_size_1_n = p.head_dim_value;
int32_t const& problem_size_1_k = problem_size_0_n;
auto prologueV = [&](int blockN) {
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
p.value_ptr + iter_key_start * p.v_strideM,
{problem_size_1_k, problem_size_1_n},
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
MM1::Mma::prologue(shared_storage.after_mm0.mm1.mm,
iterator_V,
thread_id(),
problem_size_1_k);
};
__syncthreads(); // Need to have shared memory initialized, and `m_prime`
// updated from end of prev iter
//
// MATMUL: Q.K_t
//
// Computes the block-matrix product of:
// (a) query[query_start:query_end, :]
// with
// (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
// and stores that into `shared_storage.si`
//
// Compute threadblock location
cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};
cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * MM0::Mma::Shape::kM,
tb_tile_offset.k()};
cutlass::MatrixCoord tb_offset_B{
tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN};
// Construct iterators to A and B operands
typename MM0::IteratorA iterator_A(
typename MM0::IteratorA::Params(
typename MM0::MmaCore::LayoutA(p.q_strideM)),
p.query_ptr,
{problem_size_0_m, problem_size_0_k},
thread_id(),
tb_offset_A);
typename MM0::IteratorB iterator_B(
typename MM0::IteratorB::Params(
typename MM0::MmaCore::LayoutB(p.k_strideM)),
p.key_ptr + iter_key_start * p.k_strideM,
{problem_size_0_k, problem_size_0_n},
thread_id(),
tb_offset_B);
auto my_warp_id = warp_id();
auto my_lane_id = lane_id();
// Construct thread-scoped matrix multiply
typename MM0::Mma mma(
shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);
typename MM0::Mma::FragmentC accum;
accum.clear();
auto gemm_k_iterations =
(problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
__syncthreads();
if (kPreloadV) {
prologueV(0);
}
typename MM0::Mma::Operator::IteratorC::TensorCoord
iteratorC_tile_offset = {
(tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
(my_warp_id % MM0::Mma::WarpCount::kM),
(tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
(my_warp_id / MM0::Mma::WarpCount::kM)};
// multiply by scaling factor
if (kSupportsBias) {
accum =
cutlass::multiplies<typename MM0::Mma::FragmentC>()(p.scale, accum);
}
// apply attention bias if applicable
if (kSupportsBias && p.attn_bias_ptr != nullptr) {
// load bias tile Bij into shared memory
typename MM0::BiasLoader::GmemTileIterator bias_iter(
{cutlass::layout::RowMajor(p.bias_strideM)},
// attn_bias_pointer points to matrix of size (n_queries, n_keys)
// for the relevant batch_id and head_id
p.attn_bias_ptr + query_start * p.bias_strideM + iter_key_start,
{problem_size_0_m, problem_size_0_n},
thread_id());
cutlass::TensorRef<scalar_t, cutlass::layout::RowMajor> bias_tensor_ref(
shared_storage.after_mm0.bias.data(),
cutlass::layout::RowMajor(MM0::ThreadblockShape::kN));
typename MM0::BiasLoader::SmemTileIterator smem_tile_iter(
bias_tensor_ref, thread_id());
MM0::BiasLoader::load(bias_iter, smem_tile_iter);
// Pij += Bij, Pij is in register fragment and Bij is in shared memory
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {},
[&](int accum_m, int accum_n, int idx) {
if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) {
accum[idx] += bias_tensor_ref.at({accum_m, accum_n});
}
},
[&](int accum_m) {});
}
// Mask out the last block if causal.
// This is only needed if the upper-right corner of the current query/key
// block intersects the mask.
// The coordinates of the upper-right corner of the current block are
//   y = query_start, x = min(iter_key_start + kKeysPerBlock, num_keys).
// The first masked element is x = y + offset -> query_start + offset.
// There is an intersection (and we need to mask) if
//   min(iter_key_start + kKeysPerBlock, num_keys) >= query_start + offset
if (p.causal &&
cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >=
(query_start + p.causal_diagonal_offset)) {
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
int32_t last_col;
MM0::AccumLambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
// last absolute col is (last absolute query + offset)
// last local col is (last absolute query + offset -
// iter_key_start)
last_col = query_start + accum_m + p.causal_diagonal_offset -
iter_key_start;
},
[&](int accum_m, int accum_n, int idx) {
if (accum_n > last_col) {
accum[idx] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
}
},
[&](int accum_m) {});
}
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
p.num_keys - iter_key_start >= kKeysPerBlock,
kFullColumns,
([&] {
// Update `mi` from accum stored in registers
// Also does accum[i] <- exp(accum[i] - mi)
iterative_softmax<typename MM0::Mma::Operator::IteratorC,
kFullColumns,
kIsFirst,
kKeepOutputInRF>(
accum_o,
accum,
mi,
m_prime,
s_prime,
lane_id(),
thread_id(),
warp_id(),
p.num_keys - iter_key_start,
iteratorC_tile_offset,
kSupportsBias ? 1.0f : p.scale);
}));
}));
// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id % (MM0::Mma::Base::WarpCount::kM *
MM0::Mma::Base::WarpCount::kN);
auto output_tile_coords =
cutlass::MatrixCoord{warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};
MM0::B2bGemm::accumToSmem(
shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);
__syncthreads();
// apply dropout (if applicable) after we've written Pij to smem.
// dropout is applied by multiplying each element of Pij by:
// - 0 with probability dropout_p
// - 1 / (1 - dropout_p) with probability 1 - dropout_p
//
// for backward purposes we want to be able to map each element of the
// attention matrix to the same random uniform number as the one we used
// in forward, without needing to use the same iteration order or having
    // to store the dropout matrix. It's possible to do this in registers, but
    // it ends up being very slow: each thread owns noncontiguous strips of
    // the Pij tile, so we would have to skip around a lot and generate one
    // random number at a time.
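    // Concretely, element (row, col) of the global attention matrix S consumes
    // the random number at offset row * num_keys + col of the Philox sequence;
    // the skipahead() call below jumps to
    //   (query_start + thread_i) * num_keys + (iter_key_start + thread_start_j)
    // for the first element handled by this thread.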
if (kSupportsDropout && p.use_dropout) {
auto si = shared_storage.after_mm0.si.accum_ref();
// each thread handles a contiguous sequence of elements from Sij, all
// coming from the same row. the reason they have to come from the same
      // row is that sampling random numbers from a contiguous random
// number sequence is much more efficient than jumping around, and the
// linear offset of each element of S (the global matrix) maps to an
// offset in a random number sequence. for S, the end of a row and the
// beginning of the next have adjacent offsets, but for Sij, this is not
// necessarily the case.
const int num_threads = blockDim.x * blockDim.y * blockDim.z;
const int threads_per_row =
cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n);
const int elts_per_thread = cutlass::round_nearest(
cutlass::ceil_div(problem_size_0_n, threads_per_row), 4);
const int thread_i = thread_id() / threads_per_row;
const int thread_start_j =
(thread_id() % threads_per_row) * elts_per_thread;
if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) {
curandStatePhilox4_32_10_t curand_state = curand_state_init;
skipahead(static_cast<unsigned long long>( // NOLINT
(query_start + thread_i) * p.num_keys +
(iter_key_start + thread_start_j)),
&curand_state);
const float dropout_scale = 1.0 / (1.0 - p.dropout_prob);
// apply dropout scaling to elements this thread is responsible for,
// in chunks of 4
for (int sij_start_col_idx = thread_start_j;
sij_start_col_idx <
cutlass::fast_min(thread_start_j + elts_per_thread,
problem_size_0_n);
sij_start_col_idx += 4) {
const float4 rand_uniform_quad = curand_uniform4(&curand_state);
CUTLASS_PRAGMA_UNROLL
for (int quad_idx = 0; quad_idx < 4; ++quad_idx) {
si.at({thread_i, sij_start_col_idx + quad_idx}) *=
static_cast<scalar_t>(
dropout_scale *
((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob));
}
}
}
__syncthreads(); // p.use_dropout should have same value kernel-wide
}
//
// MATMUL: Attn . V
// Run the matmul `attn @ V` for a block of attn and V.
// `attn` is read from shared memory (in `shared_storage_si`)
// `V` is read from global memory (with iterator_B)
//
const int64_t nBlockN =
kSingleValueIteration ? 1
: ceil_div((int64_t)problem_size_1_n,
int64_t(MM1::ThreadblockShape::kN));
for (int blockN = 0; blockN < nBlockN; ++blockN) {
int gemm_k_iterations =
(problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add and store it in accum
// (in registers)
if (!kPreloadV) {
__syncthreads(); // we share shmem between mma and epilogue
}
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
p.value_ptr + iter_key_start * p.v_strideM,
{problem_size_1_k, problem_size_1_n},
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
typename MM1::Mma mma_pv(shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.si,
(int)thread_id(), // NOLINT
(int)warp_id(), // NOLINT
(int)lane_id(), // NOLINT
(int)problem_size_1_k); // NOLINT
mma_pv.set_prologue_done(kPreloadV);
if (!kKeepOutputInRF) {
accum_o.clear();
}
mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
__syncthreads();
if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
prologueV(blockN + 1);
}
if (!kKeepOutputInRF) {
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
(iter_key_start + kKeysPerBlock) >= p.num_keys,
kIsLast,
([&] {
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp =
typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp = typename cutlass::epilogue::
thread::MemoryEfficientAttentionNormalize<
typename cutlass::platform::conditional<
kIsLast,
output_t,
output_accum_t>::type,
output_accum_t,
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator,
ElementCompute,
kIsFirst,
kIsLast,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue = typename cutlass::epilogue::threadblock::
EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename cutlass::platform::conditional<
kIsLast,
typename MM1::OutputTileIterator,
typename MM1::OutputTileIteratorAccum>::type,
typename DefaultEpilogue::
AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true, // IterationsUnroll
typename MM1::OutputTileIteratorAccum // Read
// iterator
>;
int col = blockN * MM1::Mma::Shape::kN;
auto source_iter = createOutputAccumIter(col);
auto dest_iter =
call_conditional<kIsLast,
decltype(createOutputIter),
decltype(createOutputAccumIter)>::
apply(
createOutputIter, createOutputAccumIter, col);
EpilogueOutputOp rescale(s_prime, m_prime);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
epilogue(rescale, dest_iter, accum_o, source_iter);
}));
}));
if (!kSingleValueIteration) {
__syncthreads();
}
}
}
__syncthreads(); // we modify `m_prime` after
}
if (kKeepOutputInRF) {
constexpr bool kIsFirst = true;
constexpr bool kIsLast = true;
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp =
typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
output_t, // output
output_accum_t, // source
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator, // accum
output_accum_t, // compute
kIsFirst,
kIsLast,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue =
typename cutlass::epilogue::threadblock::EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename MM1::OutputTileIterator, // destination
typename DefaultEpilogue::AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true, // IterationsUnroll
typename MM1::OutputTileIteratorAccum // source tile
>;
auto dest_iter = createOutputIter(0);
EpilogueOutputOp rescale(s_prime, m_prime);
Epilogue epilogue(shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
epilogue(rescale, dest_iter, accum_o);
}
// 7. Calculate logsumexp
  // To make the backward easier, we pad logsumexp with `inf`:
  // this avoids a few bound checks, and is not more expensive during fwd.
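  // The value stored for row q is mi[q] + log(s_prime[q]), i.e. the
  // logsumexp of that row of the scaled attention scores.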
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
if (thread_id() < p.num_queries) {
p.logsumexp_ptr[thread_id()] =
accum_t(mi[thread_id()]) +
cutlass::fast_log(accum_t(s_prime[thread_id()]));
} else if (thread_id() < lse_dim) {
p.logsumexp_ptr[thread_id()] =
cutlass::platform::numeric_limits<accum_t>::infinity();
}
}
}
template <typename WarpIteratorC,
bool kFullColumns,
bool kIsFirst,
bool kKeepOutputInRF>
CUTLASS_DEVICE static void iterative_softmax(
typename WarpIteratorC::Fragment& frag_o, // output so far //NOLINT
typename WarpIteratorC::Fragment& frag, // NOLINT
cutlass::Array<accum_t, kQueriesPerBlock>& mi, // NOLINT
cutlass::Array<accum_t, kQueriesPerBlock>& m_prime, // NOLINT
cutlass::Array<accum_t, kQueriesPerBlock>& s_prime, // NOLINT
int8_t lane_id,
int8_t thread_id,
int8_t warp_id,
int16_t max_col,
typename WarpIteratorC::TensorCoord const& tile_offset,
float scaling) {
/* Iterates on the accumulator and corresponding position on result matrix
(1) Update `mi[r]` to the max value of the row `r`
(2) In a second iteration do the following:
(a) accum <- exp(accum - mi)
(b) m_prime <- exp(m_prime - mi)
(c) s_prime <- s_prime * m_prime + sum(accum)
All of this is done on registers, before we store all of this
on shared memory for the next matmul with Value.
*/
using Fragment = typename WarpIteratorC::Fragment;
using LambdaIterator =
typename DefaultMmaAccumLambdaIterator<WarpIteratorC,
accum_t,
kWarpSize>::Iterator;
// Convert to `accum_t` (rather than double)
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
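    // exp(x) is evaluated as exp2f(x * log2(e)) throughout this function, so
    // the fast hardware exp2 unit is used instead of expf.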
if (!kIsFirst) {
if (thread_id < kQueriesPerBlock) {
m_prime[thread_id] = mi[thread_id];
}
__syncthreads();
}
auto lane_offset =
LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
// First update `mi` to the max per-row
{
accum_t max;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) {
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
max = cutlass::fast_max(max, frag[idx]);
}
},
[&](int accum_m) {
// Having 4x atomicMax seems faster than reduce within warp
// first...
atomicMaxFloat(&mi[accum_m], max * scaling);
});
}
frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);
// Make sure we all share the update values for `mi`
__syncthreads();
if (thread_id < kQueriesPerBlock) {
auto m_prime_exp = exp2f(kLog2e * (m_prime[thread_id] - mi[thread_id]));
m_prime[thread_id] = m_prime_exp;
s_prime[thread_id] *= m_prime_exp;
}
__syncthreads(); // Update output fragments
if (kKeepOutputInRF && !kIsFirst) {
accum_t mp;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mp = m_prime[accum_m]; },
[&](int accum_m, int accum_n, int idx) { frag_o[idx] *= mp; },
[&](int accum_m) {});
__syncthreads();
}
// Update accum_m, accum_n, ...
{
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { mi_row = kLog2e * mi[accum_m]; },
[&](int accum_m, int accum_n, int idx) {
frag[idx] = (kFullColumns || accum_n < max_col)
? exp2f(frag[idx] - mi_row)
: accum_t(0.0);
},
[&](int accum_m) {});
LambdaIterator::iterateRows(
lane_offset,
[&](int accum_m) { total_row = 0.0; },
[&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
[&](int accum_m) {
if (LambdaIterator::reduceSameRow(
lane_id, total_row, [](accum_t a, accum_t b) {
return a + b;
})) {
atomicAdd(&s_prime[accum_m], total_row);
}
});
}
}
static CUTLASS_DEVICE int8_t lane_id() { return threadIdx.x; }
static CUTLASS_DEVICE int8_t warp_id() { return threadIdx.y; }
static CUTLASS_DEVICE int16_t thread_id() {
return threadIdx.x + threadIdx.y * blockDim.x;
}
};
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_batched_impl(typename AK::Params p) {
if (!p.advance_to_block()) {
return;
}
AK::attention_kernel(p);
}
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_batched(typename AK::Params params);
// } // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This source code is licensed under the BSD license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <cutlass/cutlass.h>
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/numeric_types.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
template <typename scalar_t, // scalar type
typename ThreadblockTileShape, // size of tile to load
int Threads, // number of participating threads
int ElementsPerAccess> // thread access width in elements
class TileSmemLoader {
public:
using SmemTile =
cutlass::AlignedBuffer<scalar_t, ThreadblockTileShape::kCount>;
using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
cutlass::layout::PitchLinearShape<
ThreadblockTileShape::kColumn, // contiguous
ThreadblockTileShape::kRow>, // strided
Threads, // Threads
ElementsPerAccess>; // ElementsPerAccess
using GmemTileIterator =
cutlass::transform::threadblock::PredicatedTileIterator<
ThreadblockTileShape, // Shape
scalar_t, // Element
cutlass::layout::RowMajor, // Layout
0, // AdvanceRank
ThreadMap>; // ThreadMap
using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator<
ThreadblockTileShape, // Shape
scalar_t, // Element
cutlass::layout::RowMajor, // Layout
0, // AdvanceRank
ThreadMap>; // ThreadMap
using Fragment = typename GmemTileIterator::Fragment;
/// load a tile from global memory into shared memory
CUTLASS_DEVICE
static void load(GmemTileIterator tile_load_iter,
SmemTileIterator tile_store_iter) {
Fragment tb_frag;
tb_frag.clear();
tile_load_iter.load(tb_frag);
tile_store_iter.store(tb_frag);
__syncthreads();
}
};
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/phi/api/include/tensor_operants.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/autogen/memory_efficient_attention.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/cum_kernel.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/funcs/get_pad_lse.cu.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename T, typename Context>
void MemoryEfficientAttentionBackwardKernel(
const Context& ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& cu_seqlens_q,
const paddle::optional<DenseTensor>& cu_seqlens_k,
const DenseTensor& output,
const DenseTensor& logsumexp,
const DenseTensor& seed_and_offset,
const DenseTensor& output_grad,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
DenseTensor* query_grad,
DenseTensor* key_grad,
DenseTensor* value_grad,
DenseTensor* bias_grad) {
bool kernel_launched = false;
auto launchKernel = [&](auto k_, auto kernel_fn) {
// ndim
PADDLE_ENFORCE_EQ(
query.dims().size(),
output_grad.dims().size(),
paddle::platform::errors::InvalidArgument(
"The size of query's dimensions "
"should be euqal to output grad. But received query's "
"dimensions = %d, output grad's dimensions = %d.",
query.dims().size(),
output_grad.dims().size()));
PADDLE_ENFORCE_EQ(query.dims().size(),
key.dims().size(),
paddle::platform::errors::InvalidArgument(
"The size of query's dimensions "
"should be euqal to key. But received query's "
"dimensions = %d, key's dimensions = %d.",
query.dims().size(),
key.dims().size()));
PADDLE_ENFORCE_EQ(query.dims().size(),
value.dims().size(),
paddle::platform::errors::InvalidArgument(
"The size of query's dimensions "
"should be euqal to value. But received query's "
"dimensions = %d, value's dimensions = %d.",
query.dims().size(),
                      value.dims().size()));
PADDLE_ENFORCE_EQ(query.dims().size(),
4,
paddle::platform::errors::InvalidArgument(
"The size of query's dimensions "
"dim size of query is illegal. Expected dimension "
"size=4. Received %d.",
query.dims().size()));
// batch size
PADDLE_ENFORCE_EQ(
query.dims()[0],
output_grad.dims()[0],
paddle::platform::errors::InvalidArgument(
"The batch size of query's dimensions "
"should be euqal to output grad. But received query's "
"batch size = %d, output grad's batch size = %d.",
query.dims()[0],
output_grad.dims()[0]));
PADDLE_ENFORCE_EQ(query.dims()[0],
key.dims()[0],
paddle::platform::errors::InvalidArgument(
"The batch size of query's dimensions "
"should be euqal to key. But received query's "
"batch size = %d, key's batch size = %d.",
query.dims()[0],
key.dims()[0]));
PADDLE_ENFORCE_EQ(query.dims()[0],
value.dims()[0],
paddle::platform::errors::InvalidArgument(
"The batch size of query's dimensions "
"should be euqal to value. But received query's "
"batch size = %d, value's batch size = %d.",
query.dims()[0],
value.dims()[0]));
// seqlen
PADDLE_ENFORCE_EQ(
key.dims()[1],
value.dims()[1],
paddle::platform::errors::InvalidArgument(
"The sequence length of key"
"should be euqal to value. But received key's sequence length = "
"%d, value's sequence length = %d.",
key.dims()[1],
value.dims()[1]));
PADDLE_ENFORCE_EQ(query.dims()[1],
output_grad.dims()[1],
paddle::platform::errors::InvalidArgument(
"The sequence length of query"
"should be euqal to output grad. But received "
"query's sequence length = "
"%d, output grad's sequence length = %d.",
query.dims()[1],
output_grad.dims()[1]));
// Num heads
PADDLE_ENFORCE_EQ(
query.dims()[2],
key.dims()[2],
paddle::platform::errors::InvalidArgument(
"The head number of query"
"should be euqal to key. But received query's head number = "
"%d, key's head number = %d.",
query.dims()[2],
key.dims()[2]));
PADDLE_ENFORCE_EQ(
query.dims()[2],
value.dims()[2],
paddle::platform::errors::InvalidArgument(
"The head number of query"
"should be euqal to value. But received query's head number = "
"%d, value's head number = %d.",
query.dims()[2],
value.dims()[2]));
PADDLE_ENFORCE_EQ(query.dims()[2],
output_grad.dims()[2],
paddle::platform::errors::InvalidArgument(
"The head number of query"
"should be euqal to output grad. But received "
"query's head number = "
"%d, output grad's head number = %d.",
query.dims()[2],
output_grad.dims()[2]));
// Embedding per head
PADDLE_ENFORCE_EQ(
query.dims()[3],
key.dims()[3],
paddle::platform::errors::InvalidArgument(
"The head size of query"
"should be euqal to key. But received query's head size = "
"%d, key's head size = %d.",
query.dims()[3],
key.dims()[3]));
PADDLE_ENFORCE_EQ(
value.dims()[3],
output_grad.dims()[3],
paddle::platform::errors::InvalidArgument(
"The head size of value"
"should be euqal to output grad. But received value's head size = "
"%d, output grad's head size = %d.",
value.dims()[3],
output_grad.dims()[3]));
if (cu_seqlens_q) {
PADDLE_ENFORCE_EQ((cu_seqlens_q && bias),
false,
paddle::platform::errors::InvalidArgument(
"cu_seqlens_q or bias should be None"));
PADDLE_ENFORCE_EQ(
(cu_seqlens_k && cu_seqlens_q),
true,
paddle::platform::errors::InvalidArgument(
"cu_seqlens_q and cu_seqlens_k should be same condition"));
} else {
PADDLE_ENFORCE_EQ(
(cu_seqlens_k || cu_seqlens_q),
false,
paddle::platform::errors::InvalidArgument(
"cu_seqlens_q and cu_seqlens_k should be same condition"));
}
const auto& k_dims = key.dims();
const auto& q_dims = query.dims();
const auto& v_dims = value.dims();
int64_t max_seqlen_q_tmp, max_seqlen_k_tmp;
if (cu_seqlens_q) {
PADDLE_ENFORCE_EQ(cu_seqlens_q.get().dtype(),
DataType::INT32,
paddle::platform::errors::InvalidArgument(
"data type of cu_seqlens_q should be INT32"));
PADDLE_ENFORCE_EQ(cu_seqlens_k.get().dtype(),
DataType::INT32,
paddle::platform::errors::InvalidArgument(
"data type of cu_seqlens_k should be INT32"));
PADDLE_ENFORCE_EQ(cu_seqlens_q.get().dims().size(),
1,
paddle::platform::errors::InvalidArgument(
"dims of cu_seqlens_q should be one"));
PADDLE_ENFORCE_EQ(cu_seqlens_k.get().dims().size(),
1,
paddle::platform::errors::InvalidArgument(
"dims of cu_seqlens_k should be one"));
max_seqlen_q_tmp = max_seqlen_q.to<int64_t>();
max_seqlen_k_tmp = max_seqlen_k.to<int64_t>();
VLOG(3) << "max_seqlen_q_tmp" << max_seqlen_q_tmp;
VLOG(3) << "max_seqlen_k_tmp" << max_seqlen_k_tmp;
PADDLE_ENFORCE_EQ(cu_seqlens_q.get().dims()[0],
cu_seqlens_k.get().dims()[0],
paddle::platform::errors::InvalidArgument(
"The first dimension of cu_seqlens_q"
"should be euqal to cu_seqlens_q."));
PADDLE_ENFORCE_EQ(
q_dims[0],
1,
paddle::platform::errors::InvalidArgument(
"The batch number of query"
"should be one. But received batch number of query = %d.",
q_dims[0]));
PADDLE_ENFORCE_LT(0,
max_seqlen_q_tmp,
paddle::platform::errors::InvalidArgument(
"The max sequence length of query"
"should more than zero. But received the max "
"sequence length of query = %d.",
max_seqlen_q_tmp));
PADDLE_ENFORCE_LT(0,
max_seqlen_k_tmp,
paddle::platform::errors::InvalidArgument(
"The max sequence length of key"
"should more than zero. But received the max "
"sequence length of key = %d.",
max_seqlen_k_tmp));
PADDLE_ENFORCE_LE(max_seqlen_q_tmp,
q_dims[1],
paddle::platform::errors::InvalidArgument(
"The max sequence length of query"
"should larger than sequence length of query. But "
"received the max sequence length of query = %d,"
"the sequence length of query = %d",
max_seqlen_q_tmp,
q_dims[1]));
PADDLE_ENFORCE_LE(max_seqlen_k_tmp,
k_dims[1],
paddle::platform::errors::InvalidArgument(
"The max sequence length of key"
"should larger than sequence length of key. But "
"received the max sequence length of key = %d,"
"the sequence length of key = %d",
max_seqlen_k_tmp,
k_dims[1]));
} else {
max_seqlen_q_tmp = q_dims[1];
max_seqlen_k_tmp = k_dims[1];
}
VLOG(3) << "max_seqlen_q_tmp has been set " << max_seqlen_q_tmp
<< " max_seqlen_k_tmp " << max_seqlen_k_tmp;
auto use_dropout = dropout_p != 0.0;
const auto maxK = std::max(q_dims[3], v_dims[3]);
int compute_capacity = ctx.GetComputeCapability();
const auto max_shmem =
getMaximumSharedMemoryPerBlockKb(compute_capacity) * 1024;
using KernelType = decltype(k_);
using scalar_t = typename KernelType::scalar_t;
if (kernel_launched) {
return;
}
// Check if this kernel is compatible
if (KernelType::kMaxK < maxK) {
return;
}
// Dropout must be supported if we need it
if (use_dropout && !KernelType::kApplyDropout) {
return;
}
// Alignment
if ((q_dims[3] % KernelType::kMinimumAlignment) ||
(k_dims[3] % KernelType::kMinimumAlignment) ||
(v_dims[3] % KernelType::kMinimumAlignment)) {
return;
}
// Uses too much shmem
size_t smem_bytes = sizeof(typename KernelType::SharedStorage);
if (smem_bytes > max_shmem) {
return;
}
VLOG(3) << "smem has been set " << smem_bytes << " " << max_shmem;
kernel_launched = true;
DenseTensor delta;
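    // delta has shape (batch, num_heads, num_queries) with
    //   delta[b, h, m] = sum_d output_grad[b, m, h, d] * output[b, m, h, d];
    // it is either filled inside the kernel (kKernelComputesDelta) or
    // precomputed below via multiply + sum over the last axis + transpose.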
if (KernelType::kKernelComputesDelta) {
phi::EmptyKernel<float, Context>(
ctx,
{output.dims()[0], output.dims()[2], output.dims()[1]},
output.dtype(),
&delta);
} else {
DenseTensor output_grad_tmp =
output_grad.dtype() == DataType::FLOAT32
? output_grad
: phi::Cast<T, Context>(ctx, output_grad, DataType::FLOAT32);
DenseTensor output_tmp =
output.dtype() == DataType::FLOAT32
? output
: phi::Cast<T, Context>(ctx, output, DataType::FLOAT32);
DenseTensor delta_mul =
phi::Multiply<float, Context>(ctx, output_grad_tmp, output_tmp);
DenseTensor delta_sum;
phi::EmptyKernel<float, Context>(
ctx,
{delta_mul.dims()[0], delta_mul.dims()[1], delta_mul.dims()[2]},
DataType::FLOAT32,
&delta_sum);
phi::SumKernel<float, Context>(
ctx, delta_mul, {-1}, delta_mul.dtype(), false, &delta_sum);
phi::EmptyKernel<float, Context>(
ctx,
{delta_mul.dims()[0], delta_mul.dims()[2], delta_mul.dims()[1]},
DataType::FLOAT32,
&delta);
phi::TransposeKernel<float, Context>(ctx, delta_sum, {0, 2, 1}, &delta);
}
VLOG(3) << "p.output" << output.dtype();
VLOG(3) << "p.output_grad" << output_grad.dtype();
PADDLE_ENFORCE_EQ(
delta.dims()[0],
query.dims()[0],
paddle::platform::errors::InvalidArgument(
"The first dimension of delta"
"should be euqal to query. But received delta's first dimension = "
"%d, query's first dimension = %d.",
delta.dims()[0],
query.dims()[0]));
PADDLE_ENFORCE_EQ(delta.dims()[1],
query.dims()[2],
paddle::platform::errors::InvalidArgument(
"The second dimension of delta"
"should be euqal to third dimension query. But "
"received delta's second dimension = "
"%d, query's third dimension = %d.",
delta.dims()[1],
query.dims()[2]));
PADDLE_ENFORCE_EQ(delta.dims()[2],
query.dims()[1],
paddle::platform::errors::InvalidArgument(
"The third dimension of delta"
"should be euqal to second dimension query. But "
"received delta's third dimension = "
"%d, query's second dimension = %d.",
delta.dims()[2],
query.dims()[1]));
VLOG(3) << "delta has been set" << delta.data();
typename KernelType::Params p;
p.query_ptr = SafeGetTensorPtr<scalar_t>(query);
p.key_ptr = SafeGetTensorPtr<scalar_t>(key);
p.value_ptr = SafeGetTensorPtr<scalar_t>(value);
bool force_pad_inf = (compute_capacity == 75);
const std::string data_format = "NCHW";
DenseTensor padded_lse =
phi::funcs::get_pad_lse<float>(ctx,
const_cast<DenseTensor*>(&logsumexp),
static_cast<int>(output.dims()[1]),
32,
data_format,
force_pad_inf);
p.logsumexp_ptr = SafeGetTensorPtr<float>(padded_lse);
VLOG(3) << "logsumexp_ptr" << p.logsumexp_ptr;
p.output_ptr = SafeGetTensorPtr<scalar_t>(output);
p.grad_output_ptr = SafeGetTensorPtr<scalar_t>(output_grad);
p.grad_query_ptr = SafeAllocTensor<scalar_t, Context>(ctx, query_grad);
p.grad_key_ptr = SafeAllocTensor<scalar_t, Context>(ctx, key_grad);
p.grad_value_ptr = SafeAllocTensor<scalar_t, Context>(ctx, value_grad);
p.delta_ptr = SafeGetTensorPtr<float>(delta);
p.head_dim = q_dims[3];
p.head_dim_value = v_dims[3];
p.num_queries = max_seqlen_q_tmp;
p.num_keys = max_seqlen_k_tmp;
p.num_batches = cu_seqlens_q ? cu_seqlens_q.get().dims()[0] - 1 : q_dims[0];
p.num_heads = q_dims[2];
p.causal = causal;
if (scale < 0) {
p.scale = static_cast<float>(1.0 / std::sqrt(p.head_dim));
} else {
p.scale = scale;
}
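    // A negative scale means "use the default"; the Python wrapper passes
    // -1.0 when scale is None, which resolves to 1/sqrt(head_dim) above.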
VLOG(3) << "p.scale" << p.scale;
if (cu_seqlens_q) {
p.cu_seqlens_q_ptr = SafeGetTensorPtr<int32_t>(cu_seqlens_q);
p.cu_seqlens_k_ptr = SafeGetTensorPtr<int32_t>(cu_seqlens_k);
VLOG(3) << "p.cu_seqlens_q_ptr" << p.cu_seqlens_q_ptr;
}
p.lse_strideH = DimStride(logsumexp.dims(), 1);
p.lse_strideB = DimStride(logsumexp.dims(), 0);
VLOG(3) << "p.lse_strideH " << p.lse_strideH;
p.gO_strideH = DimStride(output_grad.dims(), 2);
p.gO_strideM = DimStride(output_grad.dims(), 1);
p.gO_strideB = DimStride(output_grad.dims(), 0);
p.o_strideH = DimStride(output.dims(), 2);
p.o_strideB = DimStride(output.dims(), 0);
p.gQ_strideH = DimStride(query_grad->dims(), 2);
p.gK_strideH = DimStride(key_grad->dims(), 2);
p.gV_strideH = DimStride(value_grad->dims(), 2);
p.gQ_strideB = DimStride(query_grad->dims(), 0);
p.gK_strideB = DimStride(key_grad->dims(), 0);
p.gV_strideB = DimStride(value_grad->dims(), 0);
p.gQKV_strideM_multiplier = 1;
PADDLE_ENFORCE_EQ(q_dims[2] * q_dims[3],
DimStride(query_grad->dims(), 1),
paddle::platform::errors::InvalidArgument(
"The strideM of grad query"
"should be euqal to the first dimension size of "
"query grad's stride"));
PADDLE_ENFORCE_EQ(k_dims[2] * k_dims[3],
DimStride(key_grad->dims(), 1),
paddle::platform::errors::InvalidArgument(
"The strideM of grad key"
"should be euqal to the first dimension size of key "
"grad's stride"));
PADDLE_ENFORCE_EQ(v_dims[2] * v_dims[3],
DimStride(value_grad->dims(), 1),
paddle::platform::errors::InvalidArgument(
"The strideM of grad value"
"should be euqal to the first dimension size of "
"value grad's stride"));
p.q_strideB = DimStride(query.dims(), 0);
p.k_strideB = DimStride(key.dims(), 0);
p.v_strideB = DimStride(value.dims(), 0);
p.q_strideM = DimStride(query.dims(), 1);
p.k_strideM = DimStride(key.dims(), 1);
p.v_strideM = DimStride(value.dims(), 1);
p.q_strideH = DimStride(query.dims(), 2);
p.k_strideH = DimStride(key.dims(), 2);
p.v_strideH = DimStride(value.dims(), 2);
p.delta_strideH = DimStride(delta.dims(), 1);
p.delta_strideB = DimStride(delta.dims(), 0);
if (bias) {
p.bias_ptr = SafeGetTensorPtr<scalar_t>(bias);
p.bias_strideB = q_dims[2] * q_dims[1] * k_dims[1];
p.bias_strideH = q_dims[1] * k_dims[1];
p.bias_strideM = k_dims[1];
VLOG(3) << "p.bias_ptr" << p.bias_ptr;
if (bias_grad) {
p.grad_bias_ptr = SafeAllocTensor<scalar_t, Context>(ctx, bias_grad);
p.gB_strideB = q_dims[2] * q_dims[1] * k_dims[1];
p.gB_strideH = q_dims[1] * k_dims[1];
p.gB_strideM = k_dims[1];
VLOG(3) << "p.grad_bias_ptr" << p.grad_bias_ptr;
} else {
p.grad_bias_ptr = nullptr;
}
} else {
p.bias_ptr = nullptr;
p.grad_bias_ptr = nullptr;
}
if (dropout_p != 0) {
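      // seed and offset restore the Philox RNG state recorded by the forward
      // kernel (returned as seed_and_offset), so the backward pass regenerates
      // exactly the same dropout mask.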
int64_t* seed_and_offset_ptr = SafeGetTensorPtr<int64_t>(seed_and_offset);
p.seed = (uint64_t)seed_and_offset_ptr[0];
p.offset = (uint64_t)seed_and_offset_ptr[1];
p.dropout_prob = dropout_p;
VLOG(3) << "seed_and_offset_ptr " << seed_and_offset_ptr;
VLOG(3) << "p.seed " << p.seed << " " << p.offset;
VLOG(3) << "p.dropout_prob " << p.dropout_prob;
}
int64_t size_bytes = p.workspace_size();
paddle::memory::AllocationPtr temp_workspace{nullptr};
VLOG(3) << "size_bytes " << size_bytes;
temp_workspace = paddle::memory::Alloc(
ctx.GetPlace(),
size_bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
if (size_bytes) {
p.workspace = reinterpret_cast<typename KernelType::output_accum_t*>(
temp_workspace->ptr());
VLOG(3) << "p.workspace" << p.workspace;
}
VLOG(3) << "temp_workspace has been set";
if (smem_bytes > 0xc000) {
const void* kernel_fn_void_ptr =
reinterpret_cast<const void*>(reinterpret_cast<uintptr_t>(kernel_fn));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaFuncSetAttribute(kernel_fn_void_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_bytes));
}
KernelType::check_supported(p);
VLOG(3) << "Kernel launched with func : " << typeid(kernel_fn).name()
<< " block dim " << p.getBlocksGrid() << " thread dim "
<< p.getThreadsGrid();
kernel_fn<<<p.getBlocksGrid(),
p.getThreadsGrid(),
smem_bytes,
ctx.stream()>>>(p);
};
dispatch_cutlass_backward<T>(ctx, launchKernel);
PADDLE_ENFORCE_EQ(kernel_launched,
true,
paddle::platform::errors::InvalidArgument(
"the kernel should not be launched"));
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(
memory_efficient_attention_grad,
GPU,
ALL_LAYOUT,
phi::fusion::cutlass_internal::MemoryEfficientAttentionBackwardKernel,
float,
phi::dtype::bfloat16,
phi::dtype::float16) {
kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MemoryEfficientAttentionBackwardKernel(
const Context& ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& cu_seqlens_q,
const paddle::optional<DenseTensor>& cu_seqlens_k,
const DenseTensor& output,
const DenseTensor& logsumexp,
const DenseTensor& seed_and_offset,
const DenseTensor& output_grad,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
DenseTensor* query_grad,
DenseTensor* key_grad,
DenseTensor* value_grad,
DenseTensor* bias_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MemoryEfficientAttentionForwardKernel(
const Context& ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& cu_seqlens_q,
const paddle::optional<DenseTensor>& cu_seqlens_k,
const paddle::optional<DenseTensor>& causal_diagonal,
const paddle::optional<DenseTensor>& seqlen_k,
const Scalar& max_seqlen_q,
const Scalar& max_seqlen_k,
const bool causal,
const double dropout_p,
const float scale,
const bool is_test,
DenseTensor* output,
DenseTensor* logsumexp,
DenseTensor* seed_and_offset);
} // namespace phi
...@@ -1117,6 +1117,7 @@ set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_split_program PROPERTIES TIMEOUT 120)
set_tests_properties(test_graph_send_ue_recv_op PROPERTIES TIMEOUT 60)
set_tests_properties(test_graph_send_uv_op PROPERTIES TIMEOUT 60)
if(WITH_DISTRIBUTE
  AND WITH_GPU
  AND WITH_NCCL)
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import re
import unittest
from typing import List, Sequence, Tuple
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.incubate.nn.attn_bias as ab
import paddle.nn.functional as F
from paddle.incubate.nn.memory_efficient_attention import (
memory_efficient_attention,
)
paddle.seed(2023)
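# get_cuda_version() parses `nvcc --version`: e.g. 'release 11.7' maps to
# 11 * 1000 + 70 = 11070, so the skip condition below compares against
# 11030 (CUDA 11.3).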
def get_cuda_version():
result = os.popen("nvcc --version").read()
regex = r'release (\S+),'
match = re.search(regex, result)
if match:
num = str(match.group(1))
integer, decimal = num.split('.')
return int(integer) * 1000 + int(float(decimal) * 10)
else:
return -1
def create_attn_bias(
bias_type,
batch_size: int,
num_heads: int,
q_len: int,
kv_len: int,
tdtype,
pdtype,
requires_grad: bool,
fmt: str,
):
if bias_type is None or isinstance(None, bias_type):
return None
r = random.Random(
"-".join(map(str, [batch_size, q_len, kv_len, tdtype, fmt]))
)
if bias_type is paddle.Tensor:
if fmt == "BMK":
batch_size *= num_heads
num_heads = 1
attn_bias = (
paddle.randn((batch_size, num_heads, 1, kv_len), dtype=pdtype) * 3
)
attn_bias = attn_bias.expand([batch_size, num_heads, q_len, kv_len])
if requires_grad:
attn_bias.stop_gradient = False
return attn_bias
if bias_type is ab.LowerTriangularMask:
return ab.LowerTriangularMask()
if bias_type in [
ab.BlockDiagonalMask,
ab.BlockDiagonalCausalMask,
]:
# This bias is not supported in BMK format
assert fmt == "BMHK"
block_diag = ab.BlockDiagonalMask.from_seqlens(
*_rand_seqlens(r, batch_size, q_len, kv_len)
)
if bias_type is ab.BlockDiagonalCausalMask:
block_diag = block_diag.make_causal()
return block_diag
raise AssertionError(f"Unsupported bias type: {bias_type}")
def _rand_seqlens(
r: random.Random, bs: int, q_len: int, kv_len: int
) -> Tuple[Sequence[int], Sequence[int]]:
q_len *= bs
kv_len *= bs
seqlens_q: List[int] = []
seqlens_k: List[int] = []
step_q = [max(1, q_len // 10), max(2, q_len // 2)]
step_k = [max(1, kv_len // 10), max(2, kv_len // 2)]
while sum(seqlens_q) < q_len and sum(seqlens_k) < kv_len:
seqlens_q.append(r.randrange(*step_q))
seqlens_k.append(r.randrange(*step_k))
seqlens_q[-1] = q_len - sum(seqlens_q[:-1])
seqlens_k[-1] = kv_len - sum(seqlens_k[:-1])
return seqlens_q, seqlens_k
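# attention_naive is the reference implementation: it transposes q/k/v from
# (B, M, H, K) to (B, H, M, K), computes softmax(scale * Q @ K^T + bias),
# applies upscale-in-train dropout under the given seed, multiplies by V, and
# transposes the result back to (B, M, H, K).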
def attention_naive(q, k, v, attn_bias, dropout_prob, scale, seed):
qt = paddle.transpose(q, [0, 2, 1, 3])
kt = paddle.transpose(k, [0, 2, 1, 3])
vt = paddle.transpose(v, [0, 2, 1, 3])
scale = 1.0 / np.sqrt(q.shape[-1])
s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2]))
s = paddle.scale(s, scale)
if attn_bias is None:
dropout_input = F.softmax(s)
elif isinstance(
attn_bias,
(
ab.LowerTriangularMask,
ab.BlockDiagonalMask,
ab.BlockDiagonalCausalMask,
),
):
bias = attn_bias.materialize(
(q.shape[0], q.shape[2], q.shape[1], k.shape[1]), q.dtype
)
dropout_input = F.softmax(s + bias)
elif isinstance(attn_bias, paddle.Tensor):
dropout_input = F.softmax(s + attn_bias)
paddle.seed(seed)
dropout_output = F.dropout(
x=dropout_input,
p=dropout_prob,
training=True,
mode="upscale_in_train",
)
o = paddle.matmul(dropout_output, vt)
return paddle.transpose(o, [0, 2, 1, 3])
@unittest.skipIf(
not core.is_compiled_with_cuda() or get_cuda_version() < 11030,
"core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
)
class TestMemEffAttentionAPI(unittest.TestCase):
def setUp(self):
self.name = "MemEffAPI_fp32"
self.place = paddle.CUDAPlace(0)
self.shape = (1, 128, 8, 16)
self.dtype = 'float32'
self.dropout = 0.0
self.training = True
self.attention_bias = None
self.scale = 1.0 / np.sqrt(self.shape[-1])
self.seed = 2023
def test_all(self):
print(
f"Test All case shape {self.shape} dtype {self.dtype} name {self.name}"
)
paddle.disable_static()
query = np.random.random(self.shape)
q = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
q_ = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
key = np.random.random(self.shape)
k = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
k_ = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
value = np.random.random(self.shape)
v = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
v_ = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
q.stop_gradient = False
k.stop_gradient = False
v.stop_gradient = False
q_.stop_gradient = False
k_.stop_gradient = False
v_.stop_gradient = False
out_ = attention_naive(
q_, k_, v_, self.attention_bias, self.dropout, self.scale, self.seed
)
paddle.seed(self.seed)
out = memory_efficient_attention(
q,
k,
v,
self.attention_bias,
self.dropout,
self.scale,
self.training,
)
np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03)
grad_out = paddle.ones_like(q)
out.backward(grad_out)
out_.backward(grad_out)
np.testing.assert_allclose(
q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03
)
class TestMemEffAPIDtypeFp16(TestMemEffAttentionAPI):
def setUp(self):
self.name = "MemEffAPI_fp16"
self.place = paddle.CUDAPlace(0)
self.shape = (1, 32, 128, 128)
self.dtype = paddle.float16
self.dropout = 0.0
self.attention_bias = None
self.training = True
self.scale = 1.0 / np.sqrt(self.shape[-1])
self.seed = 2023
class TestMemEffAPIShape0(TestMemEffAttentionAPI):
def setUp(self):
self.name = "MemEffAPI_fp32"
self.place = paddle.CUDAPlace(0)
self.shape = (1, 32, 128, 32)
self.dtype = paddle.float32
self.dropout = 0.0
self.attention_bias = None
self.training = True
self.scale = 1.0 / np.sqrt(self.shape[-1])
self.seed = 2023
class TestMemEffAPIShape1(TestMemEffAttentionAPI):
def setUp(self):
self.name = "MemEffAPI_fp32"
self.place = paddle.CUDAPlace(0)
self.shape = (1, 32, 16, 16)
self.dtype = paddle.float32
self.dropout = 0.0
self.attention_bias = None
self.training = True
self.scale = 1.0 / np.sqrt(self.shape[-1])
self.seed = 2023
class TestMemEffAPIShape2(TestMemEffAttentionAPI):
def setUp(self):
self.name = "MemEffAPI_fp32"
self.place = paddle.CUDAPlace(0)
self.shape = (1, 32, 8, 8)
self.dtype = paddle.float32
self.dropout = 0.0
self.attention_bias = None
self.training = True
self.scale = 1.0 / np.sqrt(self.shape[-1])
self.seed = 2023
class TestMemEffAPIShape3(TestMemEffAttentionAPI):
def setUp(self):
self.name = "MemEffAPI_fp32"
self.place = paddle.CUDAPlace(0)
self.shape = (16, 32, 128, 128)
self.dtype = paddle.float32
self.dropout = 0.0
self.attention_bias = None
self.training = True
self.scale = 1.0 / np.sqrt(self.shape[-1])
self.seed = 2023
class TestMemEffAPIMask0(TestMemEffAttentionAPI):
def setUp(self):
self.name = "MemEffAPI_fp32_BlockDiagonalMask"
self.place = paddle.CUDAPlace(0)
self.shape = (1, 32, 128, 128)
self.dtype = paddle.float32
self.dropout = 0.0
self.attention_bias = create_attn_bias(
ab.BlockDiagonalMask,
self.shape[0],
self.shape[2],
self.shape[1],
self.shape[1],
"float32",
self.dtype,
False,
"BMHK",
)
self.training = True
self.scale = 1.0 / np.sqrt(self.shape[-1])
self.seed = 2023
class TestMemEffAPIMask1(TestMemEffAttentionAPI):
def setUp(self):
self.name = "MemEffAPI_fp32_BlockDiagonalCausalMask"
self.place = paddle.CUDAPlace(0)
self.shape = (1, 32, 128, 128)
self.dtype = paddle.float32
self.dropout = 0.0
self.attention_bias = create_attn_bias(
ab.BlockDiagonalCausalMask,
self.shape[0],
self.shape[2],
self.shape[1],
self.shape[1],
"float32",
self.dtype,
False,
"BMHK",
)
self.training = True
self.scale = 1.0 / np.sqrt(self.shape[-1])
self.seed = 2023
class TestMemEffAPIMask2(TestMemEffAttentionAPI):
def setUp(self):
self.name = "MemEffAPI_fp32_LowerTriangularMask"
self.place = paddle.CUDAPlace(0)
self.shape = (1, 32, 128, 128)
self.dtype = paddle.float32
self.dropout = 0.0
self.attention_bias = create_attn_bias(
ab.LowerTriangularMask,
self.shape[0],
self.shape[2],
self.shape[1],
self.shape[1],
"float32",
self.dtype,
False,
"BMHK",
)
self.training = True
self.scale = 1.0 / np.sqrt(self.shape[-1])
self.seed = 2023
class TestMemEffAPIMask3(TestMemEffAttentionAPI):
def setUp(self):
self.name = "MemEffAPI_fp32_AnyTensor"
self.place = paddle.CUDAPlace(0)
self.shape = (1, 32, 128, 128)
self.dtype = paddle.float32
self.dropout = 0.0
self.attention_bias = (
paddle.randn(
(self.shape[0], self.shape[2], 1, self.shape[1]),
dtype=self.dtype,
)
* 3
)
self.attention_bias = self.attention_bias.expand(
[self.shape[0], self.shape[2], self.shape[1], self.shape[1]]
)
self.attention_bias.stop_gradient = False
self.training = True
self.scale = 1.0 / np.sqrt(self.shape[-1])
self.seed = 2023
if __name__ == '__main__':
unittest.main()
...@@ -20,6 +20,7 @@ from .fused_transformer import fused_bias_dropout_residual_layer_norm
from .fused_ec_moe import fused_ec_moe
from .fused_dropout_add import fused_dropout_add
__all__ = [
    'fused_multi_head_attention',
    'fused_feedforward',
...
...@@ -20,6 +20,9 @@
# LICENSE file in the root directory of this source tree.
import paddle
from paddle import _C_ops
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
from .attn_bias import (
    BlockDiagonalCausalMask,
...@@ -65,7 +68,7 @@ def _get_tensor_bias(attn_bias):
def memory_efficient_attention(
    query, key, value, attn_bias=None, p=0.0, scale=None, training=True
):
    assert type(attn_bias) in SUPPORTED_ATTN_BIAS_TYPES
    causal = isinstance(
...@@ -76,9 +79,10 @@ def memory_efficient_attention(
            BlockDiagonalCausalWithOffsetPaddedKeysMask,
        ),
    )
    seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k = _get_seqlen_info(
        attn_bias
    )
    # NOTE: compute_logsumexp = training
    causal_diagonal = (
        attn_bias.causal_diagonal
        if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
...@@ -89,5 +93,60 @@ def memory_efficient_attention(
        if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
        else None
    )
    if scale is None:
        scale = -1.0
    bias = _get_tensor_bias(attn_bias)
is_test = not training
if in_dygraph_mode():
output, logsumexp, seed_and_offset = _C_ops.memory_efficient_attention(
query,
key,
value,
bias,
seqstart_q,
seqstart_k,
causal_diagonal,
seqlen_k,
max_seqlen_q,
max_seqlen_k,
causal,
p,
scale,
is_test,
)
return output
helper = LayerHelper('memory_efficient_attention', **locals())
output = helper.create_variable_for_type_inference(dtype=query.dtype)
logsumexp = helper.create_variable_for_type_inference(dtype='float')
seed_and_offset = helper.create_variable_for_type_inference(dtype='int32')
helper.append_op(
type='memory_efficient_attention',
inputs={
'query': query,
'key': key,
'value': value,
'bias': bias,
"cu_seqlens_q": seqstart_q,
"cu_seqlens_k": seqstart_k,
"causal_diagonal": causal_diagonal,
"seqlen_k": seqlen_k,
},
        attrs={
"max_seqlen_q": max_seqlen_q,
"max_seqlen_k": max_seqlen_k,
"causal": causal,
"dropout_p": p,
"scale": scale,
"is_test": is_test,
},
outputs={
'output': output,
'logsumexp': logsumexp,
"seed_and_offset": seed_and_offset,
},
)
return output
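# A minimal usage sketch (hypothetical shapes, laid out as
# (batch, seq_len, num_heads, head_dim)):
#   q = paddle.randn([2, 128, 8, 64], dtype='float32')
#   k = paddle.randn([2, 128, 8, 64], dtype='float32')
#   v = paddle.randn([2, 128, 8, 64], dtype='float32')
#   out = memory_efficient_attention(q, k, v, attn_bias=None, p=0.0,
#                                    scale=None, training=True)
# out has the same shape as q; scale=None falls back to 1/sqrt(head_dim)
# inside the kernel.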