Unverified · Commit c4d5ec66 authored by FormlessUnit, committed by GitHub

add linear_compress API (#54140)

* add linear_compress API
Parent 989f3dde
...@@ -112,11 +112,13 @@ class AssignValueKernel : public framework::OpKernel<T> {
break;
case framework::proto::VarType::INT64:
value_name = "int64_values";
break;
case framework::proto::VarType::INT8:
value_name = "int8_values";
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type(code %d) for AssignValue operator, only "
"supports bool, int32, float32, int8 and int64.",
dtype));
break;
}
......
...@@ -1347,6 +1347,16 @@
data_transform :
skip_transform : out_size, size_tensor, scale_tensor
- op : llm_int8_mat_mul
args : (Tensor x, Tensor weight, Tensor weight_scale, float threshold=6.0)
output : Tensor(out)
infer_meta :
func : LLMInt8MatMulInferMeta
param : [x, weight]
kernel :
func : llm_int8_mat_mul
data_type : x
- op : log
args : (Tensor x)
output : Tensor
...@@ -1896,6 +1906,15 @@
func : qr
backward : qr_grad
- op : quant_for_compress
args : (Tensor x, int bits = 8, str layout = "weight_only")
output : Tensor(out), Tensor(scale)
infer_meta :
func : QuantForCompressInferMeta
kernel :
func : quant_for_compress
data_type : x
- op : real
args : (Tensor x)
output : Tensor (out)
...@@ -2563,6 +2582,16 @@
intermediate: warprnntgrad
backward : warprnnt_grad
- op : weight_only_mat_mul
args : (Tensor x, Tensor weight, Tensor weight_scale)
output : Tensor(out)
infer_meta :
func : WeightOnlyMatMulInferMeta
param : [x, weight]
kernel :
func : weight_only_mat_mul
data_type : x
- op : weighted_sample_neighbors
args : (Tensor row, Tensor colptr, Tensor edge_weight, Tensor input_nodes, Tensor eids, int sample_size, bool return_eids)
output : Tensor(out_neighbors), Tensor(out_count), Tensor(out_eids)
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#pragma once
namespace phi {
template <typename T>
struct PDDataTypeTraits {
using DataType = T;
};
template <>
struct PDDataTypeTraits<phi::dtype::float16> {
// Since LayerNormDirectCUDAFunctor registers the half type, we need to
// convert phi::float16 to half.
using DataType = half;
};
#ifdef PADDLE_CUDA_BF16
template <>
struct PDDataTypeTraits<phi::dtype::bfloat16> {
using DataType = __nv_bfloat16;
};
#endif
} // namespace phi
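A one-line usage sketch for the trait above (the alias name is illustrative, not part of this patch):

// Maps a phi dtype to the matching CUDA native type before calling CUDA APIs:
// using CudaHalf = phi::PDDataTypeTraits<phi::dtype::float16>::DataType;  // == half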
...@@ -3572,5 +3572,51 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
out_count->set_dtype(DataType::INT32);
}
void LLMInt8MatMulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
MetaTensor* out) {
auto x_dims = x.dims();
auto w_dims = weight.dims();
PADDLE_ENFORCE_EQ(
w_dims.size(),
2UL,
errors::InvalidArgument("The input(weight) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
x_dims[x_dims.size() - 1],
w_dims[1],
errors::InvalidArgument(
"Input(X) dim[-1] and Input(Weight) dim[1] should be euqal."
"But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)",
x_dims[x_dims.size() - 1],
w_dims[1]));
auto out_dims = x_dims;
out_dims[out_dims.size() - 1] = w_dims[0];
out->set_dims(out_dims);
out->set_dtype(x.dtype());
}
void WeightOnlyMatMulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
MetaTensor* out) {
auto x_dims = x.dims();
auto w_dims = weight.dims();
PADDLE_ENFORCE_EQ(
w_dims.size(),
2UL,
errors::InvalidArgument("The input(weight) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
x_dims[x_dims.size() - 1],
w_dims[0],
errors::InvalidArgument(
"Input(X) dim[-1] and Input(Weight) dim[0] should be euqal."
"But received Input(X) dim[-1](%s) != Input(Weight) dim[0](%s)",
x_dims[x_dims.size() - 1],
w_dims[0]));
auto out_dims = x_dims;
out_dims[out_dims.size() - 1] = w_dims[1];
out->set_dims(out_dims);
out->set_dtype(x.dtype());
}
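A shape sketch for the two InferMeta functions above (concrete numbers are illustrative only): llm_int8_mat_mul stores the quantized weight as [n, k], so x's last dim must equal w_dims[1], while weight_only_mat_mul stores it as [k, n], so x's last dim must equal w_dims[0].

// Example shapes (illustrative):
//   llm_int8_mat_mul    : x = {2, 16, 4096}, weight = {12288, 4096} -> out = {2, 16, 12288}
//   weight_only_mat_mul : x = {2, 16, 4096}, weight = {4096, 12288} -> out = {2, 16, 12288}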
} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
...@@ -673,6 +673,31 @@ void MoeInferMeta(const MetaTensor& x,
const std::string& act_type,
MetaTensor* out);
void FusedMultiHeadAttentionInferMeta(const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
const MetaTensor& mask,
float scale,
bool causal,
MetaTensor* out);
void FusedMultiHeadAttentionVariableInferMeta(const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
const MetaTensor& seq_lens,
const MetaTensor& mask,
float scale,
bool causal,
MetaTensor* out);
void LLMInt8MatMulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
MetaTensor* out);
void WeightOnlyMatMulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
MetaTensor* out);
void FusedRopeInferMeta(const MetaTensor& q,
const MetaTensor& k,
const MetaTensor& v,
......
...@@ -5062,6 +5062,51 @@ void CheckNumericsInferMeta(const MetaTensor& tensor,
values->set_dims(phi::make_ddim({3}));
}
void QuantForCompressInferMeta(const MetaTensor& x,
int bits,
const std::string& layout,
MetaTensor* out,
MetaTensor* scale) {
auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(
x_dims.size(),
2UL,
phi::errors::InvalidArgument(
"The x tensor of quant op must be 2D, but got[%d]", x_dims.size()));
PADDLE_ENFORCE_GE(
x_dims[0],
64,
phi::errors::OutOfRange("The first dimension of input is out of range "
"(expected at least 64, but got %ld).",
x_dims[0]));
PADDLE_ENFORCE_EQ(
x_dims[0] % 64,
0,
phi::errors::InvalidArgument(
"The first dimension of input must be divisible by 64, but got[%d]",
x_dims[0]));
std::vector<int64_t> dim_scale({x_dims[1]});
std::vector<int64_t> dim_out;
if (layout == "weight_only") {
dim_out = std::vector<int64_t>({x_dims[0], x_dims[1]});
} else if (layout == "llm.int8") {
dim_out = std::vector<int64_t>({x_dims[1], x_dims[0]});
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The layout must be weight_only or llm.int8, but got %s", layout));
}
out->set_dims(phi::make_ddim(dim_out));
// TODO(lizhenyun) support weight_only int4
if (bits == 8) {
out->set_dtype(DataType::INT8);
} else {
PADDLE_THROW(phi::errors::Fatal(
"Only 8-bit quantization is supported, but got [%d] bits.", bits));
}
scale->set_dims(phi::make_ddim(dim_scale));
scale->set_dtype(DataType::FLOAT32);
}
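For reference, a sketch of the output shapes produced by QuantForCompressInferMeta under the two layouts (numbers are illustrative):

// x = {4096, 12288}, bits = 8
//   layout == "weight_only": out = {4096, 12288} (int8), scale = {12288} (float32)
//   layout == "llm.int8"   : out = {12288, 4096} (int8), scale = {12288} (float32)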
} // namespace phi
PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
...@@ -724,4 +724,10 @@ void UnStackInferMeta(const MetaTensor& x,
int num,
std::vector<MetaTensor*> outs);
void QuantForCompressInferMeta(const MetaTensor& x,
int bits,
const std::string& layout,
MetaTensor* out,
MetaTensor* scale);
} // namespace phi
...@@ -132,6 +132,7 @@ PD_REGISTER_KERNEL(assign_value,
bool,
int,
float,
int8_t,
int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
...@@ -158,6 +159,7 @@ PD_REGISTER_KERNEL(assign_value,
bool,
int,
float,
int8_t,
int64_t) {}
#endif
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/quant_for_compress_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h"
namespace phi {
template <typename DeviceContext, typename T, typename D>
void quant_compute(const DeviceContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out,
DenseTensor* scale,
const std::string& layout) {
const auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(
x_dims.size(),
2,
phi::errors::InvalidArgument(
"the x tensor of quant op must be 2D, but got[%d]", x_dims.size()));
size_t m = x_dims[0];
size_t n = x_dims[1];
int64_t num = x.numel();
DDim dims = {num};
const T* x_data = x.data<T>();
D* out_data = out->data<D>();
float* scale_data = scale->data<float>();
DenseTensor x_int(out->type());
x_int.Resize({static_cast<int64_t>(m), static_cast<int64_t>(n)});
dev_ctx.template Alloc<D>(&x_int);
D* x_int_data = x_int.data<D>();
DenseTensor int_processed(out->type());
int_processed.Resize(dims);
dev_ctx.template Alloc<D>(&int_processed);
D* int_processed_data = int_processed.data<D>();
DenseTensor int_processed_2(out->type());
int_processed_2.Resize(out->dims());
dev_ctx.template Alloc<D>(&int_processed_2);
D* int_processed_2_data = int_processed_2.data<D>();
per_channel_scale(scale_data, x_data, m, n);
per_channel_quant(x_int_data, x_data, scale_data, m, n);
if (layout == "weight_only") {
permute_B_rows_for_mixed_gemm(
int_processed_data, x_int_data, std::vector<size_t>{m, n}, (int64_t)80);
row_major_to_column_major(
int_processed_2_data, int_processed_data, std::vector<size_t>{m, n});
interleave_column_major_tensor(
out_data, int_processed_2_data, std::vector<size_t>{m, n});
add_bias_and_interleave_int8s_inplace(out_data, num);
} else if (layout == "llm.int8") {
std::vector<int> axis = {1, 0};
funcs::Transpose<DeviceContext, int8_t, 2> trans;
trans(dev_ctx, x_int, out, axis);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The layout must be weight_only or llm.int8, but got %s", layout));
}
}
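The per_channel_scale / per_channel_quant helpers used above live in quant_for_compress_kernel_impl.h, which is not part of this hunk. As a rough standalone sketch (not the actual implementation; the names, rounding, and zero-column guard are assumptions), symmetric per-channel int8 quantization of an m x n row-major matrix looks like this:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Sketch only: one float scale per column (output channel), then elementwise rounding.
template <typename T>
void per_channel_scale_sketch(float* scale, const T* x, size_t m, size_t n) {
  for (size_t j = 0; j < n; ++j) {
    float max_abs = 0.f;
    for (size_t i = 0; i < m; ++i) {
      max_abs = std::max(max_abs, std::fabs(static_cast<float>(x[i * n + j])));
    }
    scale[j] = max_abs > 0.f ? max_abs / 127.f : 1.f;  // avoid dividing by zero
  }
}

template <typename T>
void per_channel_quant_sketch(
    int8_t* q, const T* x, const float* scale, size_t m, size_t n) {
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < n; ++j) {
      q[i * n + j] = static_cast<int8_t>(
          std::lround(static_cast<float>(x[i * n + j]) / scale[j]));
    }
  }
}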
template <typename T, typename Context>
void QuantForCompressKernel(const Context& dev_ctx,
const DenseTensor& x,
int bits,
const std::string& layout,
DenseTensor* out,
DenseTensor* scale) {
if (bits == 8) {
dev_ctx.template Alloc<int8_t>(out);
dev_ctx.template Alloc<float>(scale);
quant_compute<Context, T, int8_t>(dev_ctx, x, out, scale, layout);
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Only 8-bit quantization is supported, but got [%d] bits.", bits));
}
// VLOG(0) << "x: " << x.dtype() << x;
// VLOG(0) << "out: " << out->dtype() << *out;
}
} // namespace phi
PD_REGISTER_KERNEL(quant_for_compress,
CPU,
ALL_LAYOUT,
phi::QuantForCompressKernel,
float,
phi::dtype::float16) {}
...@@ -87,6 +87,7 @@ PD_REGISTER_KERNEL(transpose,
double,
int32_t,
int64_t,
int8_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include "paddle/phi/backends/dynload/cublasLt.h"
#include "paddle/phi/core/dense_tensor.h"
namespace dyl = phi::dynload;
namespace phi {
struct CublasLtAlgoParam {
int algoId;
int swizzle;
int customOption;
int tile;
int splitK_val;
int reductionScheme;
int stages;
size_t workspace_size;
};
const std::map<std::tuple<int, int, int>, CublasLtAlgoParam> AlgoParamCache{};
class CublasLtHelper {
public:
CublasLtHelper(int m, int k, int n)
: alpha_(1), beta_(0), m_(m), k_(k), n_(n) {
cublasStatus_t status;
// handle and matmul desc
status = dyl::cublasLtCreate(&handle_);
#if CUBLAS_VER_MAJOR < 11
cudaDataType_t cudaComputeType = CUDA_R_32I;
#else
cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
#endif
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtCreate execution error, "
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
#if CUBLAS_VER_MAJOR < 11
status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType);
#else
status = dyl::cublasLtMatmulDescCreate(
&matmul_desc_, cudaComputeType, CUDA_R_32I);
#endif
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatmulDescCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
cublasOperation_t op_transpose = CUBLAS_OP_T;
status = dyl::cublasLtMatmulDescSetAttribute(matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op_transpose,
sizeof(op_transpose));
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatmulDescSetAttribute execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
// matrix desc
status = dyl::cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
status = dyl::cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
status = dyl::cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
#if CUDA_VERSION >= 11020
int algoId = 21;
int swizzle = 0;
int customOption = 0;
int tile = 15;
int splitK_val = 0;
int reductionScheme = 0;
int stages = 23;
workspace_size_ = 0;
if (m >= 128) {
tile = 20;
stages = 17;
}
std::tuple<int, int, int> key(m_, k_, n_);
if (AlgoParamCache.count(key) != 0) {
auto value = AlgoParamCache.at(key);
algoId = value.algoId;
swizzle = value.swizzle;
customOption = value.customOption;
tile = value.tile;
splitK_val = value.splitK_val;
reductionScheme = value.reductionScheme;
stages = value.stages;
workspace_size_ = value.workspace_size;
}
dyl::cublasLtMatmulAlgoInit(handle_,
cudaComputeType,
CUDA_R_32I,
CUDA_R_8I,
CUDA_R_8I,
CUDA_R_32I,
CUDA_R_32I,
algoId,
&algo_);
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo_,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION,
&(customOption),
sizeof(customOption));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo_, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile));
dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo_,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
&(splitK_val),
sizeof(splitK_val));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo_,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING,
&(swizzle),
sizeof(swizzle));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo_,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
&(reductionScheme),
sizeof(int));
#if CUDA_VERSION >= 11000
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo_, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
#endif
#endif
}
~CublasLtHelper() {}
void GEMM(int8_t* A_dev,
const int8_t* B_dev,
int32_t* C_dev,
cudaStream_t stream,
void* workspace = nullptr) {
cublasStatus_t status;
status = dyl::cublasLtMatmul(handle_,
matmul_desc_,
&alpha_,
B_dev,
B_desc_,
A_dev,
A_desc_,
&beta_,
C_dev,
C_desc_,
C_dev,
C_desc_,
#if CUDA_VERSION >= 11020
&algo_,
workspace,
workspace_size_,
#else
nullptr,
nullptr,
0,
#endif
stream);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatmul execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
}
private:
cublasLtHandle_t handle_;
cublasLtMatmulDesc_t matmul_desc_;
cublasLtMatrixLayout_t A_desc_;
cublasLtMatrixLayout_t B_desc_;
cublasLtMatrixLayout_t C_desc_;
cublasLtMatmulAlgo_t algo_;
int32_t alpha_;
int32_t beta_;
int m_;
int k_;
int n_;
size_t workspace_size_;
};
} // namespace phi
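A minimal usage sketch for the helper above (function and buffer names are illustrative; it assumes int8 A of shape m x k, int8 B, and int32 C already resident on the device, plus a valid cudaStream_t):

// Illustrative call pattern only, not part of the diff.
void RunInt8Gemm(int m, int k, int n,
                 int8_t* a_dev, const int8_t* b_dev, int32_t* c_dev,
                 cudaStream_t stream) {
  phi::CublasLtHelper helper(m, k, n);
  helper.GEMM(a_dev, b_dev, c_dev, stream);  // no extra workspace supplied
}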
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/transform.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace phi {
using backends::gpu::GpuLaunchConfig;
constexpr int DequantKernelVecSize = 4;
template <typename T>
inline HOSTDEVICE T roundWithTiesToEven(T x) {
T xLower = floor(x);
T xUpper = ceil(x);
// x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to
// even.
T dLower = x - xLower;
T dUpper = xUpper - x;
return static_cast<T>(
(dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper)
? xLower
: xUpper);
}
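A few worked values for the round-to-nearest-even helper above:

// roundWithTiesToEven(2.3f) == 2.0f  (nearest)
// roundWithTiesToEven(2.5f) == 2.0f  (tie broken toward the even bound)
// roundWithTiesToEven(3.5f) == 4.0f  (tie broken toward the even bound)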
template <typename T>
__forceinline__ __device__ int8_t quant_helper(const T input,
const float scale,
const int round_type,
const float max_bound,
const float min_bound) {
float quant_value = max_bound * scale * static_cast<float>(input);
if (round_type == 0) {
quant_value = static_cast<float>(roundWithTiesToEven(quant_value));
} else {
quant_value = static_cast<float>(round(quant_value));
}
quant_value = quant_value > max_bound ? max_bound : quant_value;
quant_value = quant_value < min_bound ? min_bound : quant_value;
return static_cast<int8_t>(quant_value);
}
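Numeric sketch for quant_helper, assuming the caller passes scale = 1 / max_abs(input) and the usual int8 bounds (this convention is an assumption, not stated in the diff):

// With max_bound = 127.f, min_bound = -127.f and scale = 1.f / max_abs:
//   input == +max_abs -> 127, input == -max_abs -> -127, input == 0 -> 0;
//   anything that rounds outside [-127, 127] is clipped to the bound.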
template <typename T>
__global__ void quantize_kernel(const T* input,
char4* output,
const float scale,
const int m,
const int n,
const int round_type,
const float max_bound,
const float min_bound) {
int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2;
int m_id = blockIdx.y * blockDim.y + threadIdx.y;
bool check = ((m_id < m) && (n_id < n));
if (check) {
char4 tmp;
tmp.x = quant_helper(
input[m_id * n + n_id], scale, round_type, max_bound, min_bound);
tmp.y = quant_helper(
input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound);
tmp.z = quant_helper(
input[m_id * n + n_id + 2], scale, round_type, max_bound, min_bound);
tmp.w = quant_helper(
input[m_id * n + n_id + 3], scale, round_type, max_bound, min_bound);
output[(m_id * n + n_id) >> 2] = tmp;
}
}
template <typename T>
void quantize_kernel_launcher(const T* input,
int8_t* output,
const float scale,
const int m,
const int n,
const int round_type,
const float max_bound,
const float min_bound,
gpuStream_t stream) {
// TODO(minghaoBD): optimize the kernel launch times when m==1 or n==1
dim3 grid(((n >> 2) + 31) / 32, (m + 31) / 32);
dim3 block(32, 32);
quantize_kernel<<<grid, block, 0, stream>>>(input,
(char4*)output, // NOLINT
scale,
m,
n,
round_type,
max_bound,
min_bound);
}
template <typename T, int VecSize>
__global__ void dequantize_kernel(T* output,
const int32_t* input,
const int m, // batch size
const int n, // hidden
const float quant_in_scale,
const float* dequant_out_scale_data) {
int numel = m * n;
int stride = blockDim.x * gridDim.x * VecSize;
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
int col_id = idx % n;
phi::AlignedVector<int32_t, VecSize> in_vec;
phi::AlignedVector<float, VecSize> out_scale_vec;
phi::AlignedVector<T, VecSize> out_vec;
for (; idx < numel; idx += stride) {
phi::Load<int32_t, VecSize>(input + idx, &in_vec);
phi::Load<float, VecSize>(dequant_out_scale_data + col_id, &out_scale_vec);
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
out_vec[i] =
static_cast<T>(static_cast<float>(in_vec[i]) * out_scale_vec[i]);
}
phi::Store<T, VecSize>(out_vec, output + idx);
}
}
template <typename T>
void dequantize_kernel_launcher(const int32_t* input,
T* output,
const int m, // m
const int n, // n
gpuStream_t stream,
GpuLaunchConfig* gpu_config,
const float quant_in_scale,
const float* dequant_out_scale_data) {
dequantize_kernel<T, DequantKernelVecSize>
<<<gpu_config->block_per_grid, gpu_config->thread_per_block, 0, stream>>>(
output, input, m, n, quant_in_scale, dequant_out_scale_data);
}
} // namespace phi
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Templates exposing architecture support for multiply-add operations
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace arch {
// Tag which triggers an MMA that dequantizes the interleaved B operand before
// multiplying with A.
struct OpMultiplyAddDequantizeInterleavedBToA;
} // namespace arch
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda_runtime_api.h>
#include "cutlass/device_kernel.h"
#include "paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h"
namespace phi {
template <typename GemmKernel>
inline int compute_occupancy_for_kernel() {
int smem_size = static_cast<int>(sizeof(typename GemmKernel::SharedStorage));
if (smem_size > (48 << 10)) {
cudaError_t status =
cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (status == cudaError::cudaErrorInvalidValue) {
// Clear the error bit since we can ignore this.
// This should mean that smem_size >
// cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an
// occupancy of 0. This will cause the heuristic to ignore this
// configuration.
status = cudaGetLastError();
return 0;
}
check_cuda_error(status);
}
int max_active_blocks = -1;
check_cuda_error(
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks,
cutlass::Kernel<GemmKernel>,
GemmKernel::kThreadCount,
smem_size));
return max_active_blocks;
}
} // namespace phi
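Typical call site for the occupancy helper above (the GemmKernel type name is illustrative):

// int max_active_blocks_per_sm = phi::compute_occupancy_for_kernel<MyGemmKernel>();
// A result of 0 tells the tile-selection heuristic to skip this configuration.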
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
// define scaling mode
enum class QuantMode {
PerTensorQuant,
PerTokenQuant,
PerChannelQuant,
PerTokenChannelQuant
};
} // namespace epilogue
} // namespace cutlass
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Functor performing linear combination with a maximum operation used by
epilogues.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/half.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
__forceinline__ __device__ float copysignf_pos(float a, float b) {
float r;
r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
return r;
}
__forceinline__ __device__ float tanh_opt(float x) {
#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
const float exp_val = -1.f * fabs(2 * x);
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#else
return fast_tanh(x);
#endif
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct GELU_taylor<float> {
static const bool kIsHeavy = true;
CUTLASS_DEVICE
float operator()(float const& z) const {
float k0 = static_cast<float>(0.7978845608028654);
float k1 = static_cast<float>(0.044715);
return static_cast<float>(
cutlass::constants::half<float>() * z *
(cutlass::constants::one<float>() +
tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
}
using Params = LinearCombinationGenericParams<float>;
CUTLASS_DEVICE
float operator()(float const& scalar, Params const& params_) const {
return this->operator()(scalar);
}
};
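The constants above are the standard tanh approximation of GELU: k0 = 0.7978845608... is sqrt(2/pi) and k1 = 0.044715, i.e.

// gelu(z) ~ 0.5 * z * (1 + tanh(sqrt(2/pi) * (z + 0.044715 * z^3)))
// note: k0 * z * (1 + k1 * z * z) == sqrt(2/pi) * (z + 0.044715 * z^3)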
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one
scaling factor per row, and one per column.
original file:
3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "../epilogue_quant_helper.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/numeric_conversion.h"
namespace cutlass {
namespace epilogue {
namespace threadblock {
template <typename ThreadblockShape_,
int ThreadCount,
typename ScaleTileIterator_,
typename OutputTileIterator_,
typename ElementAccumulator_,
typename ElementCompute_,
typename ElementwiseFunctor_,
bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
public:
using ThreadblockShape = ThreadblockShape_;
static int const kThreadCount = ThreadCount;
using ScaleTileIterator = ScaleTileIterator_;
using OutputTileIterator = OutputTileIterator_;
using ElementwiseFunctor = ElementwiseFunctor_;
static int const kIterations = OutputTileIterator::kIterations;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ElementOutput = typename OutputTileIterator::Element;
using LayoutOutput = cutlass::layout::RowMajor;
using ElementAccumulator = ElementAccumulator_;
using AlphaScaleElementType = typename ScaleTileIterator::Element;
using ElementCompute = ElementCompute_;
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
static int const kThreadsPerRow =
OutputTileIterator::ThreadMap::Detail::kAccessWidth;
static bool const kHasMultiStepsInRow =
(OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
/// Argument structure
struct Arguments {
typename ElementwiseFunctor::Params elementwise;
int64_t batch_stride_alpha;
int64_t batch_stride_C;
int64_t batch_stride_D;
//
// Methods
//
Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
explicit Arguments(typename ElementwiseFunctor::Params elementwise_)
: elementwise(elementwise_),
batch_stride_alpha(0),
batch_stride_C(0),
batch_stride_D(0) {}
Arguments(typename ElementwiseFunctor::Params elementwise_,
int64_t batch_stride_alpha_,
int64_t batch_stride_C_,
int64_t batch_stride_D_)
: elementwise(elementwise_),
batch_stride_alpha(batch_stride_alpha_),
batch_stride_C(batch_stride_C_),
batch_stride_D(batch_stride_D_) {}
};
struct Params {
typename ElementwiseFunctor::Params elementwise;
int64_t batch_stride_alpha;
int64_t batch_stride_C;
int64_t batch_stride_D;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() {}
CUTLASS_HOST_DEVICE
explicit Params(Arguments const& args)
: elementwise(args.elementwise),
batch_stride_alpha(args.batch_stride_alpha),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D) {}
};
/// Shared storage
struct SharedStorage {};
private:
Params const& params_;
SharedStorage& shared_storage_;
MatrixCoord extent_;
MatrixCoord extent_real_;
ElementwiseFunctor elementwise_;
const bool per_token_quant_;
const bool per_channel_quant_;
AlphaScaleElementType* ptr_alpha_row_;
AlphaScaleElementType* ptr_alpha_col_;
ScaleTileIterator iterator_alpha_col_;
OutputTileIterator iterator_C_;
OutputTileIterator iterator_D_;
AlphaScaleElementType element_alpha_row_ = 1.0f;
AlphaScaleElementType element_alpha_col_ = 1.0f;
typename ScaleTileIterator::Fragment fragment_alpha_col_;
typename OutputTileIterator::Fragment fragment_C_;
typename OutputTileIterator::Fragment fragment_D_;
ElementAccumulator beta_;
int column_offset_;
MatrixCoord thread_offset_;
public:
CUTLASS_DEVICE
EpilogueVisitorPerRowPerCol(
Params const& params,
SharedStorage& shared_storage, // NOLINT
cutlass::MatrixCoord const& problem_size,
int thread_idx,
int warp_idx,
int lane_idx,
typename ScaleTileIterator::Params params_alpha_col,
typename OutputTileIterator::Params params_C,
typename OutputTileIterator::Params params_D,
QuantMode quant_mode,
AlphaScaleElementType* ptr_alpha_row,
AlphaScaleElementType* ptr_alpha_col,
typename OutputTileIterator::Element* ptr_C,
typename OutputTileIterator::Element* ptr_D,
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0,
0),
int column_offset = 0,
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0,
0))
: params_(params),
shared_storage_(shared_storage),
extent_(problem_size),
elementwise_(params.elementwise),
per_token_quant_(quant_mode == QuantMode::PerTokenQuant ||
quant_mode == QuantMode::PerTokenChannelQuant),
per_channel_quant_(quant_mode == QuantMode::PerChannelQuant ||
quant_mode == QuantMode::PerTokenChannelQuant),
ptr_alpha_row_(ptr_alpha_row),
ptr_alpha_col_(ptr_alpha_col),
iterator_alpha_col_(params_alpha_col,
ptr_alpha_col,
problem_size,
thread_idx,
threadblock_offset),
iterator_C_(
params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
iterator_D_(
params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
extent_real_(problem_size_real) {
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr
: params.elementwise.beta);
if (beta_ == ElementAccumulator()) {
iterator_C_.clear_mask();
}
}
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(
int split_k_index, ///< Index of this threadblock within split-K
///< partitioned scheme
int split_k_slices) { ///< Total number of split-K slices
}
/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
iterator_alpha_col_.add_pointer_offset(batch_idx *
params_.batch_stride_alpha);
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
}
/// Called at the start of the epilogue just before iterating over accumulator
/// slices
CUTLASS_DEVICE
void begin_epilogue() {
if (per_channel_quant_) {
iterator_alpha_col_.load(fragment_alpha_col_);
} else if (ptr_alpha_col_ != nullptr) {
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
element_alpha_col_, ptr_alpha_col_, true);
}
if (!per_token_quant_ && ptr_alpha_row_ != nullptr) {
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
element_alpha_row_, ptr_alpha_row_, true);
}
}
/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_D_.clear();
fragment_C_.clear();
if (elementwise_.kScale !=
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
iterator_C_.load(fragment_C_);
++iterator_C_;
}
// load alpha_row in begin_step only when per token(row) scaling is used
if (per_token_quant_) {
int thread_offset_row =
iterator_D_.thread_start_row() +
OutputTileIterator::ThreadMap::iteration_offset(0).row();
// element_alpha_row_ = ptr_alpha_row_[thread_offset_row];
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
element_alpha_row_,
ptr_alpha_row_ + thread_offset_row,
thread_offset_row < extent_.row());
}
}
/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx) {
// Clear accumulators for max and sum when starting a whole row
}
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorFragment const& accum) {
NumericArrayConverter<ElementCompute,
ElementAccumulator,
kElementsPerAccess>
source_converter;
ComputeFragment result = source_converter(accum);
if (per_channel_quant_) {
ComputeFragment alpha_col =
reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[frag_idx];
result = per_token_channel_scale_accumulator_(
result, alpha_col, element_alpha_row_);
} else {
result = per_token_scale_accumulator_(
result, element_alpha_col_, element_alpha_row_);
}
// Convert to the output
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess>
output_converter;
OutputVector& output =
reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
output = output_converter(result);
}
/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx) {
iterator_D_.store(fragment_D_);
++iterator_D_;
}
/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {}
private:
CUTLASS_DEVICE
ComputeFragment per_token_channel_scale_accumulator_(
ComputeFragment const& accum,
ComputeFragment const& scale_col,
AlphaScaleElementType const& scale_row) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i) {
result[i] = accum[i] * (scale_col[i] * scale_row);
}
return result;
}
CUTLASS_DEVICE
ComputeFragment per_token_scale_accumulator_(
ComputeFragment const& accum,
AlphaScaleElementType const& scale_col,
AlphaScaleElementType const& scale_row) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i) {
result[i] = accum[i] * (scale_col * scale_row);
}
return result;
}
};
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
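Net effect of the visitor above for one output element (a sketch, not code from the diff):

// D[i][j] = ElementOutput(accum[i][j] * alpha_col * alpha_row)
// alpha_col is indexed per column only in the per-channel modes and alpha_row per row
// only in the per-token modes; otherwise the corresponding scalar scale (default 1.0f)
// is broadcast.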
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory
to match canonical tensor layouts in global memory. Epilogues support
conversion and reduction operations.
original file:
3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
#include "cutlass/layout/permute.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
namespace detail {
/// Partial specialization for half <= int32_t x 8 epilogues avoids shared
/// memory bank conflicts.
template <typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename ThreadMap>
struct DefaultIteratorsTensorOp<cutlass::half_t,
int32_t,
8,
ThreadblockShape,
WarpShape,
InstructionShape,
ThreadMap> {
using WarpTileIterator =
cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape,
InstructionShape,
int32_t,
layout::RowMajor>;
using SharedLoadIterator =
cutlass::epilogue::threadblock::SharedLoadIterator<ThreadMap, int32_t>;
static int const kFragmentsPerIteration = 1;
};
/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared
/// memory bank conflicts.
#ifdef PADDLE_CUDA_BF16
template <typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename ThreadMap>
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t,
int32_t,
8,
ThreadblockShape,
WarpShape,
InstructionShape,
ThreadMap> {
using WarpTileIterator =
cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape,
InstructionShape,
int32_t,
layout::RowMajor>;
using SharedLoadIterator =
cutlass::epilogue::threadblock::SharedLoadIterator<ThreadMap, int32_t>;
static int const kFragmentsPerIteration = 1;
};
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tile iterator used to load output tile from shared memory in epilogue.
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
>
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using Element = int32_t;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kAlignment =
ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;
static int const kThreads = ThreadMap::kThreads;
/// Fragment object
using Fragment =
Array<Element,
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow *
ThreadMap::Iterations::kGroup *
ThreadMap::Iterations::kCluster *
ThreadMap::kElementsPerAccess>;
/// Memory access size
using AccessType =
AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;
/// Vector type used for SMEM loads
using LoadType = AlignedArray<Element,
const_min(128 / sizeof_bits<Element>::value,
ThreadMap::kElementsPerAccess),
const_min(16, kAlignment)>;
static int const kLoadsPerAccess =
AccessType::kElements / LoadType::kElements;
private:
//
// Data members
//
/// Byte-level pointer
LoadType const* pointers_[kLoadsPerAccess];
/// Stride along adjacent rows in units of LoadType
int stride_;
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
SharedLoadIteratorMixed(TensorRef ref, int thread_idx)
: stride_((ref.stride(0) / LoadType::kElements)) {
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
// Initialize pointers
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i) {
pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());
int col_idx =
(thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
int bank_offset = (col_idx * static_cast<int>(sizeof(LoadType)) / 128) %
kLoadsPerAccess;
col_idx += (bank_offset + i) % kLoadsPerAccess;
pointers_[i] += thread_offset.row() * stride_ + col_idx;
}
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i) {
pointers_[i] += pointer_offset / LoadType::kElements;
}
}
CUTLASS_DEVICE
void add_tile_offset(TensorCoord const& offset) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i) {
pointers_[i] += offset.row() * Shape::kRow * stride_ +
offset.column() * Shape::kColumn / LoadType::kElements;
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment& frag, // NOLINT
Index pointer_offset) const {
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ +
group * ThreadMap::Delta::kGroup * stride_ +
cluster * ThreadMap::Delta::kCluster * stride_ +
pointer_offset / LoadType::kElements;
int frag_row_idx =
(row + ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
int frag_idx =
frag_row_idx * ThreadMap::Iterations::kColumn + column;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kLoadsPerAccess; ++v) {
int vector_idx = (column * ThreadMap::Delta::kColumn /
kElementsPerAccess * kLoadsPerAccess);
LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;
frag_ptr[frag_idx * kLoadsPerAccess + v] =
memory_pointer[vector_idx];
}
}
}
}
}
}
/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment& frag) const { // NOLINT
load_with_pointer_offset(frag, 0);
}
};
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/**
* @file epilogue_helpers.h
*
* This file includes types for the epilogues. The empty structs exist so we can
* signal to template code the type of epilogue we want to run, and let the
* underlying code specify the details such as element types, accumulator type
* and elements per vector access.
*
*/
#pragma once
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "epilogue/thread/ft_fused_activations.h"
namespace phi {
struct EpilogueOpBiasSilu {};
struct EpilogueOpBiasReLU {};
struct EpilogueOpBiasFtGelu {};
struct EpilogueOpBias {};
struct EpilogueOpNoBias {};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator,
typename Op>
struct Epilogue {};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator>
struct Epilogue<ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
EpilogueOpBiasSilu> {
using Op = cutlass::epilogue::thread::LinearCombinationSilu<
ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator>
struct Epilogue<ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
EpilogueOpBiasReLU> {
using Op = cutlass::epilogue::thread::LinearCombinationRelu<
ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator>
struct Epilogue<ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
EpilogueOpBiasFtGelu> {
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
cutlass::epilogue::thread::GELU_taylor,
ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::NoBetaScaling,
cutlass::FloatRoundStyle::round_to_nearest,
true>;
};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator>
struct Epilogue<ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
EpilogueOpBias> {
using Op = cutlass::epilogue::thread::LinearCombination<
ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator>
struct Epilogue<ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
EpilogueOpNoBias> {
using Op = cutlass::epilogue::thread::LinearCombination<
ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::Default>;
};
} // namespace phi
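A sketch of how the tag structs above are consumed; the alias names and template arguments are illustrative, not taken from this diff:

// Selecting a concrete cutlass epilogue functor via a tag:
// using GeluEpilogueOp =
//     phi::Epilogue<cutlass::half_t, 8, int32_t, phi::EpilogueOpBiasFtGelu>::Op;
// using DefaultEpilogueOp =
//     phi::Epilogue<cutlass::half_t, 8, int32_t, phi::EpilogueOpNoBias>::Op;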
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace phi {
// Note: The shapes are in the format MxNxK. The K shape of the runtime config
// MUST match the K shape
// in the kernel layout details when doing weight only quantization.
enum class CutlassTileConfig {
// Signals that we should run heuristics to choose a config
Undefined,
// Signals that we should run heuristics to choose a config
ChooseWithHeuristic,
// SIMT config
CtaShape128x128x8_WarpShape64x64x8,
// TensorCore configs CTA_N = 128, CTA_K = 64
// Warp configs for M=32
CtaShape32x128x64_WarpShape32x32x64,
// Warp configs for M=64
CtaShape64x128x64_WarpShape32x64x64,
CtaShape64x128x64_WarpShape64x32x64,
// Warp configs for M=128
CtaShape128x128x64_WarpShape64x32x64,
CtaShape128x128x64_WarpShape128x32x64,
// configs for large M in encoder
CtaShape128x256x64_WarpShape64x64x64,
CtaShape256x128x64_WarpShape64x64x64
};
enum class SplitKStyle {
NO_SPLIT_K,
SPLIT_K_SERIAL,
// SPLIT_K_PARALLEL // Not supported yet
};
struct CutlassGemmConfig {
CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
int split_k_factor = -1;
int stages = -1;
};
} // namespace phi
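An illustrative CutlassGemmConfig instance pinning a specific choice instead of the heuristic (the values are assumptions, not defaults from this diff):

// phi::CutlassGemmConfig config;
// config.tile_config    = phi::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64;
// config.split_k_style  = phi::SplitKStyle::NO_SPLIT_K;
// config.split_k_factor = 1;
// config.stages         = 3;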
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
namespace cutlass {
namespace gemm {
namespace kernel {
template <typename TypeA, typename TypeB, typename arch, typename Enable = void>
struct MixedGemmArchTraits {};
template <typename arch>
struct MixedGemmArchTraits<float, float, arch> {
static constexpr int Stages = 2;
using OperatorClass = cutlass::arch::OpClassSimt;
using AccType = float;
using LayoutB = cutlass::layout::RowMajor;
static constexpr int ElementsPerAccessA = 1;
static constexpr int ElementsPerAccessB = 1;
static constexpr int ElementsPerAccessC = 1;
static constexpr int ThreadblockK = 8;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// ========================= Volta Traits ===========================
// Volta will always dequantize after the global memory load.
// This will instantiate any HMMA tensorcore kernels for Volta.
// Note that Volta does not have native bfloat16 support, so weights and
// activations will be cast to fp16; compute happens in fp16 and the result is
// then converted for bf16 output.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<
TypeA,
TypeB,
cutlass::arch::Sm70,
typename cutlass::platform::enable_if<
cutlass::platform::is_same<TypeA, cutlass::half_t>::value ||
cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
private:
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm70>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA =
128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC =
128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
using Operator = typename LayoutDetails::Operator;
};
// ======================= Turing Traits ==============================
// Note that Turing does not have native bfloat16 support, so weights and
// activations will be cast to fp16; compute happens in fp16 and the result is
// then converted for bf16 output.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<
TypeA,
TypeB,
cutlass::arch::Sm75,
typename cutlass::platform::enable_if<
cutlass::platform::is_same<TypeA, cutlass::half_t>::value ||
cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
private:
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm75>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA =
128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC =
128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = typename LayoutDetails::Operator;
};
// ======================= Ampere Traits ==============================
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<
TypeA,
TypeB,
cutlass::arch::Sm80,
typename cutlass::platform::enable_if<
cutlass::platform::is_same<TypeA, cutlass::half_t>::value ||
cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
private:
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm80>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA =
128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC =
128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using Operator = typename LayoutDetails::Operator;
};
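// Compile-time sanity sketch (an illustrative addition, assuming the Ampere
// traits above): for fp16 activations with 4-bit weights, A is accessed 128
// bits (8 halves) at a time, and the K tile inherits ThreadblockK = 64 from
// the quantized-B layout, as required for int4 weights.
static_assert(MixedGemmArchTraits<cutlass::half_t,
                                  cutlass::uint4b_t,
                                  cutlass::arch::Sm80>::ElementsPerAccessA == 8,
              "unexpected A access width");
static_assert(MixedGemmArchTraits<cutlass::half_t,
                                  cutlass::uint4b_t,
                                  cutlass::arch::Sm80>::ThreadblockK == 64,
              "int4 weights require ThreadblockK == 64");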
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
    \brief Template for a pipelined GEMM kernel with fp16/bf16 activations and
   quantized integer weights. Does not compute batching; split-K via serial
   reduction is supported when SplitKSerial is enabled.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
typename KernelArch, ///! The Architecture this kernel is compiled
/// for. Used since SIMT kernels lose top-level
/// arch.
bool SplitKSerial ///! If true, code supporting split-K via serial
/// reduction is enabled.
>
struct GemmFpAIntB {
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
  using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Mma::LayoutC;
using ElementScale = float;
static ComplexTransform const kTransformA = Mma::kTransformA;
  static ComplexTransform const kTransformB = Mma::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC =
Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
static int const kSplitKAlignment = const_max(
128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
static constexpr int kInterleave =
Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
/// Parameters structure
struct Arguments : UniversalArgumentsBase {
GemmUniversalMode mode = GemmUniversalMode::kGemm;
cutlass::gemm::GemmCoord problem_size;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::TensorRef ref_B;
typename Mma::IteratorScale::TensorRef ref_scale;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
// Control serial split-k
int batch_count;
typename EpilogueOutputOp::Params output_op;
// For gather+scatter operations
int const* gather_A_indices;
int const* gather_B_indices;
int const* scatter_D_indices;
// Included so we can use Gemm Universal
int batch_stride_D = 0;
//
// Methods
//
CUTLASS_HOST_DEVICE
Arguments() {}
CUTLASS_HOST_DEVICE
Arguments(cutlass::gemm::GemmCoord const& problem_size,
typename Mma::IteratorA::TensorRef ref_A,
typename Mma::IteratorB::TensorRef ref_B,
typename Mma::IteratorScale::TensorRef ref_scale,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D,
int serial_split_k_factor,
typename EpilogueOutputOp::Params output_op =
typename EpilogueOutputOp::Params(),
int const* gather_A_indices = nullptr,
int const* gather_B_indices = nullptr,
int const* scatter_D_indices = nullptr)
: UniversalArgumentsBase(mode,
problem_size,
/*serial_split_k_factor=*/1,
/*batch_stride_D=*/0),
ref_A(ref_A),
ref_B(ref_B),
ref_scale(ref_scale),
ref_C(ref_C),
ref_D(ref_D),
output_op(output_op),
gather_A_indices(gather_A_indices),
gather_B_indices(gather_B_indices),
scatter_D_indices(scatter_D_indices) {}
};
/// Parameters structure
struct Params : UniversalParamsBase<ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC> {
using ParamsBase = UniversalParamsBase<ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename Mma::IteratorScale::Params params_scale;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::Params params_D;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::TensorRef ref_B;
typename Mma::IteratorScale::TensorRef ref_scale;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
typename EpilogueOutputOp::Params output_op;
// For gather+scatter operations
int const* gather_A_indices;
int const* gather_B_indices;
int const* scatter_D_indices;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() = default;
CUTLASS_HOST_DEVICE
Params(Arguments const& args, int device_sms, int sm_occupancy)
: ParamsBase(args, device_sms, sm_occupancy),
params_A(args.ref_A.layout()),
ref_A(args.ref_A),
params_B(args.ref_B.layout()),
ref_B(args.ref_B),
params_scale(args.ref_scale.layout()),
ref_scale(args.ref_scale),
params_C(args.ref_C.layout()),
ref_C(args.ref_C),
params_D(args.ref_D.layout()),
ref_D(args.ref_D),
output_op(args.output_op),
gather_A_indices(args.gather_A_indices),
gather_B_indices(args.gather_B_indices),
scatter_D_indices(args.scatter_D_indices) {}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
//
// Methods
//
CUTLASS_HOST_DEVICE
GemmFpAIntB() {}
/// Determines whether kernel satisfies alignment
CUTLASS_HOST_DEVICE
static Status can_implement(Arguments const& args) {
static int const kAlignmentA =
(platform::is_same<typename Mma::IteratorA::Layout,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Mma::IteratorA::Layout,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB =
(platform::is_same<typename Mma::IteratorB::Layout,
layout::RowMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Mma::IteratorB::Layout,
layout::RowMajorInterleaved<64>>::value)
? 64
: Mma::IteratorB::AccessType::kElements;
static int const kAlignmentScale =
Mma::IteratorScale::AccessType::kElements;
static int const kAlignmentC =
(platform::is_same<typename Epilogue::OutputTileIterator::Layout,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Epilogue::OutputTileIterator::kElementsPerAccess;
if (!TensorRef_aligned(args.ref_A, kAlignmentA)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_B, kAlignmentB)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_C, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_D, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
return Status::kSuccess;
}
static size_t get_extra_workspace_size(
Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) {
return 0;
}
CUTLASS_DEVICE
static void invoke(Params const& params,
SharedStorage& shared_storage) { // NOLINT
GemmFpAIntB op;
op(params, shared_storage);
}
  // The dummy template parameter is not used and exists so that we can compile
  // this code with a standard earlier than C++17. Prior to C++17, fully
  // specialized templates had to exist at namespace scope.
template <bool B, typename dummy = void>
struct KernelRunner {
CUTLASS_DEVICE
static void run_kernel(Params const& params,
SharedStorage& shared_storage) { // NOLINT
CUTLASS_NOT_IMPLEMENTED();
}
};
template <typename dummy>
struct KernelRunner<true, dummy> {
CUTLASS_DEVICE
static void run_kernel(Params const& params,
SharedStorage& shared_storage) { // NOLINT
using LayoutB = typename Mma::IteratorB::Layout;
static_assert(
platform::is_same<LayoutB, layout::RowMajor>::value &&
kInterleave == 1 ||
platform::is_same<LayoutB, layout::ColumnMajor>::value &&
kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size,
};
cutlass::MatrixCoord tb_offset_B{
threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
cutlass::MatrixCoord tb_offset_scale{
0, threadblock_tile_offset.n() * Mma::Shape::kN};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k =
min(params.problem_size.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations =
(problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) /
Mma::Shape::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A,
params.ref_A.data(),
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A,
params.gather_A_indices);
typename Mma::IteratorB iterator_B(
params.params_B,
params.ref_B.data(),
{problem_size_k * kInterleave, params.problem_size.n() / kInterleave},
thread_idx,
tb_offset_B,
params.gather_B_indices);
typename Mma::IteratorScale iterator_scale(params.params_scale,
params.ref_scale.data(),
{1, params.problem_size.n()},
thread_idx,
tb_offset_scale);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations > 0) {
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
iterator_scale,
accumulators);
}
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() +
threadblock_tile_offset.n() * params.grid_tiled_shape.m();
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial
// synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is
// currently updating
output_op.set_k_partition(threadblock_tile_offset.k(),
params.grid_tiled_shape.k());
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
params.ref_C.data(),
params.problem_size.mn(),
thread_idx,
threadblock_offset,
params.scatter_D_indices);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
params.ref_D.data(),
params.problem_size.mn(),
thread_idx,
threadblock_offset,
params.scatter_D_indices);
Epilogue epilogue(
shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator
// construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D'
// tensor.
if (threadblock_tile_offset.k()) {
iterator_C = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
} else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
}
};
/*
To improve compilation speed, we do not compile the device operator if the
CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel
operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params,
SharedStorage& shared_storage) { // NOLINT
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
// static constexpr bool compile_needed = platform::is_same<KernelArch,
// arch::Sm70>::value; KernelRunner<compile_needed>::run_kernel(params,
// shared_storage);
CUTLASS_NOT_IMPLEMENTED();
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
// static constexpr bool compile_needed = platform::is_same<KernelArch,
// arch::Sm75>::value; KernelRunner<compile_needed>::run_kernel(params,
// shared_storage);
CUTLASS_NOT_IMPLEMENTED();
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
static constexpr bool compile_needed =
platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
This file exists so that we use the same weight layout for MoE grouped gemm
and regular gemm when the weight is quantized. The preprocessing code reads
this template to know how to organize the quantized weight matrices to be
consumed by CUTLASS.
Note that for int4, ThreadBlockK MUST be 64.
*/
#pragma once
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/platform/platform.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h"
namespace cutlass {
namespace gemm {
namespace kernel {
template <typename TypeB, typename Arch, typename Enable = void>
struct LayoutDetailsB {};
// Volta specializations. Volta will dequantize before STS, so we need a
// different operator.
template <typename TypeB>
struct LayoutDetailsB<TypeB, arch::Sm70> {
static constexpr int ThreadblockK = 64;
using Layout = layout::RowMajor;
static constexpr int ElementsPerAccess = 8;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specializations for Turing+ when B is FP16. These are currently only used for
// MoE networks.
template <typename Arch>
struct LayoutDetailsB<
half_t,
Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
static constexpr int ThreadblockK = 64;
using Layout = layout::RowMajor;
static constexpr int ElementsPerAccess =
128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename Arch>
struct LayoutDetailsB<
bfloat16_t,
Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
static constexpr int ThreadblockK = 64;
using Layout = layout::RowMajor;
static constexpr int ElementsPerAccess =
128 / cutlass::sizeof_bits<bfloat16_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specializations for Turing+ when B is quantized. These can use the operator
// OpMultiplyAddDequantizeInterleavedBToA, which signals that we want to
// dequantize after loading from smem.
template <typename Arch>
struct LayoutDetailsB<
uint8_t,
Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
static constexpr int ThreadblockK = 64;
private:
static constexpr int ElementsPerCacheLine =
128 * 8 / sizeof_bits<uint8_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
public:
using Layout =
layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
static constexpr int ElementsPerAccess =
128 / cutlass::sizeof_bits<uint8_t>::value;
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
template <typename Arch>
struct LayoutDetailsB<
uint4b_t,
Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
static constexpr int ThreadblockK = 64;
private:
static constexpr int ElementsPerCacheLine =
128 * 8 / sizeof_bits<uint4b_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
public:
using Layout =
layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
static constexpr int ElementsPerAccess =
128 / cutlass::sizeof_bits<uint4b_t>::value;
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
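// Compile-time sanity sketch (an illustrative addition, assuming Sm80): a
// 128-bit access carries 16 uint8 or 32 uint4 weight elements, and both
// quantized layouts keep ThreadblockK = 64 as noted at the top of this file.
static_assert(LayoutDetailsB<uint8_t, arch::Sm80>::ThreadblockK == 64, "");
static_assert(LayoutDetailsB<uint8_t, arch::Sm80>::ElementsPerAccess == 16, "");
static_assert(LayoutDetailsB<uint4b_t, arch::Sm80>::ThreadblockK == 64, "");
static_assert(LayoutDetailsB<uint4b_t, arch::Sm80>::ElementsPerAccess == 32, "");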
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/interleaved_numeric_conversion.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
// We need to distinguish the two cases here because we want Volta support. It
// is too much effort to write the shared-memory iterators that Volta would
// likely need to function properly, so we allow converters both after the LDG
// (for Volta) and after the LDS (for Turing+).
template <
/// Iterator for B matrix in global memory
typename IteratorB,
/// Warp level Mma
typename MmaOperator,
/// Math operation perform by warp level operator
typename MathOperator>
struct SetConverters {};
// Dequantize after LDG, so set transforms accordingly
template <
/// Iterator for B matrix in global memory
typename IteratorB,
/// Mma Policy
typename MmaOperator>
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd> {
using TransformAfterLDG = FastInterleavedAndBiasedNumericArrayConverter<
typename MmaOperator::ArchMmaOperator::ElementB,
typename IteratorB::Element,
IteratorB::Fragment::kElements>;
using TransformAfterLDS =
NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
typename MmaOperator::ArchMmaOperator::ElementB,
MmaOperator::FragmentB::kElements>;
};
// Dequantize after LDS, so set transforms accordingly
template <
/// Iterator for B matrix in global memory
typename IteratorB,
/// Mma Policy
typename MmaOperator>
struct SetConverters<IteratorB,
MmaOperator,
arch::OpMultiplyAddDequantizeInterleavedBToA> {
using TransformAfterLDG =
NumericArrayConverter<typename IteratorB::Element,
typename IteratorB::Element,
IteratorB::Fragment::kElements>;
using TransformAfterLDS = FastInterleavedAndBiasedNumericArrayConverter<
typename MmaOperator::ArchMmaOperator::ElementB,
typename TransformAfterLDG::result_type::Element,
MmaOperator::FragmentB::kElements>;
};
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale_,
/// Layout for the scale operand
typename LayoutScale_,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
///
typename Enable = void>
struct DqMma;
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
template <
/// Type for elementA
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Stages in GEMM
int kStages,
///
typename Operator,
///
SharedMemoryClearOption SharedMemoryClear>
struct DqMma<ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementScale,
LayoutScale,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear,
typename platform::enable_if<(ArchTag::kMinComputeCapability >=
80)>::type> {
static_assert(platform::is_same<ElementA, half_t>::value ||
platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(
platform::is_same<Operator,
arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
static_assert(platform::is_same<ElementB, uint8_t>::value ||
platform::is_same<ElementB, uint4b_t>::value,
"Element B must be int8 or uint4");
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
  // Mma core does not depend on stages, so pass in at least 3 here so that the
  // mma multistage pieces are created
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
std::max(kStages, 3),
Operator,
false,
CacheOpA,
CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
LayoutA,
1,
ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
LayoutB,
0,
ThreadMapB,
AccessTypeB>;
// ThreadMap for scale iterator
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
MmaCore::Shape::kN / kAlignmentScale,
kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaCore::Shape::kN>,
ElementScale,
LayoutScale,
0,
IteratorScaleThreadMap,
kAlignmentScale>;
using SmemIteratorScale = IteratorScale;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<
ElementA,
ElementB,
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<
typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA,
IteratorB,
typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB,
IteratorScale,
SmemIteratorScale,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
kStages,
Converter,
SharedMemoryClear>;
};
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Stages in GEMM
int kStages,
///
typename Operator,
///
SharedMemoryClearOption SharedMemoryClear,
///
int RowsPerTile,
///
int ColumnsInterleaved>
struct DqMma<ElementA,
LayoutA,
kAlignmentA,
ElementB,
layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>,
kAlignmentB,
ElementScale,
LayoutScale,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear,
typename platform::enable_if<(ArchTag::kMinComputeCapability >=
80)>::type> {
static_assert(platform::is_same<ElementA, half_t>::value ||
platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(
platform::is_same<Operator,
arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
static_assert(platform::is_same<ElementB, uint8_t>::value ||
platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
  // Mma core does not depend on stages, so pass in at least 3 here so that the
  // mma multistage pieces are created
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
ElementA,
LayoutA,
ElementB,
layout::ColumnMajor,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
std::max(kStages, 3),
Operator,
false,
CacheOpA,
CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
LayoutA,
1,
ThreadMapA,
AccessTypeA>;
private:
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
using OriginalWarpArrangement =
typename OriginalThreadMap::Detail::WarpThreadArrangement;
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
using GmemIteratorShape =
MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved,
MmaCore::Shape::kN / ColumnsInterleaved>;
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<GmemIteratorShape::kRow,
GmemIteratorShape::kColumn>,
OriginalThreadMap::kThreads,
layout::PitchLinearShape<
OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
public:
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
GmemIteratorShape,
ElementB,
layout::ColumnMajor,
0,
GmemThreadMapB,
AccessTypeB>;
// ThreadMap for scale iterator
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
MmaCore::Shape::kN / kAlignmentScale,
kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaCore::Shape::kN>,
ElementScale,
LayoutScale,
0,
IteratorScaleThreadMap,
kAlignmentScale>;
using SmemIteratorScale = IteratorScale;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<
ElementA,
ElementB,
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<
typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA,
IteratorB,
typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB,
IteratorScale,
SmemIteratorScale,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
kStages,
Converter,
SharedMemoryClear>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DqMma<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementScale,
LayoutScale,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator,
SharedMemoryClearOption::kNone,
typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {
static_assert(platform::is_same<ElementA, half_t>::value ||
platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementB, uint8_t>::value ||
platform::is_same<ElementB, uint4b_t>::value,
"Element B must be int8 or uint4");
static constexpr bool DqAfterLDG =
platform::is_same<arch::OpMultiplyAdd, Operator>::value;
static constexpr bool arch_has_bf16_mma =
ArchTag::kMinComputeCapability >= 80;
using MmaCoreElementA =
typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
using MmaCoreElementB = typename platform::
conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
MmaCoreElementA,
LayoutA,
MmaCoreElementB,
LayoutB,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
2,
Operator>;
// Define iterators over tiles from the A operand
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
ElementA,
LayoutA,
1,
typename MmaCore::IteratorThreadMapA,
kAlignmentA>;
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
ElementB,
LayoutB,
0,
typename MmaCore::IteratorThreadMapB,
kAlignmentB>;
// ThreadMap for scale iterator
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
MmaCore::Shape::kN / kAlignmentScale,
kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaCore::Shape::kN>,
ElementScale,
LayoutScale,
0,
IteratorScaleThreadMap,
kAlignmentScale>;
using SmemScaleType = typename platform::
conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
using SmemIteratorScale =
cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaCore::Shape::kN>,
SmemScaleType,
LayoutScale,
0,
IteratorScaleThreadMap,
kAlignmentScale>;
using Converters =
SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<
typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
IteratorB,
typename MmaCore::SmemIteratorB,
IteratorScale,
SmemIteratorScale,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
typename Converters::TransformAfterLDG,
typename Converters::TransformAfterLDS>;
};
// Specialization to handle column major interleave B
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int RowsPerTile,
///
int ColumnsInterleaved>
struct DqMma<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>,
kAlignmentB,
ElementScale,
LayoutScale,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator,
SharedMemoryClearOption::kNone,
typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {
static_assert(platform::is_same<ElementA, half_t>::value ||
platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementB, uint8_t>::value ||
platform::is_same<ElementB, uint4b_t>::value,
"Element B must be int8 or uint4");
static constexpr bool DqAfterLDG =
platform::is_same<arch::OpMultiplyAdd, Operator>::value;
static constexpr bool arch_has_bf16_mma =
ArchTag::kMinComputeCapability >= 80;
using MmaCoreElementA =
typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
using MmaCoreElementB = typename platform::
conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
MmaCoreElementA,
LayoutA,
MmaCoreElementB,
layout::ColumnMajor,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
2,
Operator>;
// Define iterators over tiles from the A operand
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
ElementA,
LayoutA,
1,
typename MmaCore::IteratorThreadMapA,
kAlignmentA>;
private:
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
using OriginalWarpArrangement =
typename OriginalThreadMap::Detail::WarpThreadArrangement;
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
using GmemIteratorShape =
MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved,
MmaCore::Shape::kN / ColumnsInterleaved>;
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<GmemIteratorShape::kRow,
GmemIteratorShape::kColumn>,
OriginalThreadMap::kThreads,
layout::PitchLinearShape<
OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
public:
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
GmemIteratorShape,
ElementB,
layout::ColumnMajor,
0,
GmemThreadMapB,
kAlignmentB>;
// ThreadMap for scale iterator
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
MmaCore::Shape::kN / kAlignmentScale,
kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaCore::Shape::kN>,
ElementScale,
LayoutScale,
0,
IteratorScaleThreadMap,
kAlignmentScale>;
using SmemScaleType = typename platform::
conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
using SmemIteratorScale =
cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaCore::Shape::kN>,
SmemScaleType,
LayoutScale,
0,
IteratorScaleThreadMap,
kAlignmentScale>;
using Converters =
SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<
typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
IteratorB,
typename MmaCore::SmemIteratorB,
IteratorScale,
SmemIteratorScale,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
typename Converters::TransformAfterLDG,
typename Converters::TransformAfterLDS>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16
/// activation & int8 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<float>::value;
using Mma = DqMma<half_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
float,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16
/// activation & int4 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<float>::value;
using Mma = DqMma<half_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
float,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<float>::value;
using Mma = DqMma<half_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
float,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16
/// activation & int4 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Number of pipeline stages
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<float>::value;
using Mma = DqMma<half_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
float,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
// fp16 x fp16 specialization on Ampere that uses mma multistage for the 2-stage
// case. This helps avoid register spills on large tiles when there is not enough
// shared memory to run 3+ stages.
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<half_t,
LayoutA,
kAlignmentA,
half_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator,
false,
SharedMemoryClear,
GatherA,
GatherB> {
// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
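// (The ThreadblockMma defined at the bottom of this struct still runs with only
// 2 stages; the 3 here only selects the multistage MmaCore components.)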
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
half_t,
LayoutA,
half_t,
LayoutB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
3,
Operator>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
half_t,
LayoutA,
1,
ThreadMapA,
AccessTypeA,
GatherA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
half_t,
LayoutB,
0,
ThreadMapB,
AccessTypeB,
GatherB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma =
cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA,
IteratorB,
typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
2>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & bf16 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<bfloat16_t,
LayoutA,
kAlignmentA,
bfloat16_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator,
false,
SharedMemoryClear,
GatherA,
GatherB> {
private:
// Conversions are only needed pre-Ampere. This path triggers the mma pipeline,
// so we convert before the STS.
static constexpr bool arch_has_bf16_mma =
ArchTag::kMinComputeCapability >= 80;
using MmaElementA = typename platform::
conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
using MmaElementB = typename platform::
conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
public:
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
MmaElementA,
LayoutA,
MmaElementB,
LayoutB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
2,
Operator>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
bfloat16_t,
LayoutA,
1,
typename MmaCore::IteratorThreadMapA,
kAlignmentA,
GatherA>;
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
bfloat16_t,
LayoutB,
0,
typename MmaCore::IteratorThreadMapB,
kAlignmentB,
GatherB>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma =
cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
IteratorB,
typename MmaCore::SmemIteratorB,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy>;
};
// bf16 x bf16 specialization on Ampere that uses mma multistage for the 2-stage
// case. This helps avoid register spills on large tiles when there is not enough
// shared memory to run 3+ stages.
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<bfloat16_t,
LayoutA,
kAlignmentA,
bfloat16_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator,
false,
SharedMemoryClear,
GatherA,
GatherB> {
// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
bfloat16_t,
LayoutA,
bfloat16_t,
LayoutB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
3,
Operator>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
bfloat16_t,
LayoutA,
1,
ThreadMapA,
AccessTypeA,
GatherA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
bfloat16_t,
LayoutB,
0,
ThreadMapB,
AccessTypeB,
GatherB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma =
cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA,
IteratorB,
typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
2>;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & int8 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<float>::value;
using Mma = DqMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
float,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & int4 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<float>::value;
using Mma = DqMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
float,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Number of pipeline stages
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<float>::value;
using Mma = DqMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
float,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & int4 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Number of pipeline stages
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<float>::value;
using Mma = DqMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
float,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
// SFINAE trick so the same loop code can be kept for Volta while dispatching to
// the correct warp-level mma. On Volta, all data is stored to shared memory as
// FP16.
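// The first overload below is selected for warp MMAs that consume plain
// FragmentA/FragmentB; the second participates only when WarpMma defines
// kExpansionFactor (and transformed fragments) and forwards the B k-offset.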
template <typename WarpMma, int kExpansionFactor = 1>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, // NOLINT
typename WarpMma::FragmentC& D, // NOLINT
typename WarpMma::FragmentA const& A,
typename WarpMma::FragmentB const& B,
typename WarpMma::FragmentC const& C,
const int warp_tileB_k_offset) {
warp_mma(D, A, B, C);
}
template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
CUTLASS_DEVICE void run_warp_mma(
WarpMma& warp_mma, // NOLINT
typename WarpMma::FragmentC& D, // NOLINT
typename WarpMma::TransformedFragmentA const& A,
typename WarpMma::TransformedFragmentB const& B,
typename WarpMma::FragmentC const& C,
const int warp_tileB_k_offset) {
warp_mma(D, A, B, C, warp_tileB_k_offset);
}
////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// The type of the scales
typename ElementScale_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class DqMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
///< Type of the scale to be loaded
using ElementScale = ElementScale_;
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
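// A single warp-level load of operand B spans more K elements than one math
// instruction, so each loaded B fragment is reused for several consecutive
// warp MMA K-steps; the two constants below account for that reuse.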
static constexpr int kNumKIterationsPerWarpBLoad =
Operator::IteratorB::InstructionShape::kRow /
Operator::InstructionShape::kK;
static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), "");
static constexpr int kWarpGemmIterationsForB =
kWarpGemmIterations / kNumKIterationsPerWarpBLoad;
/// Number of stages
static int const kStages = Stages;
/// Tensor reference to the A operand
using TensorRefA =
TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
/// Tensor reference to the B operand
using TensorRefB =
TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
class SharedStorage {
public:
//
// Type definitions
//
/// Shape of the A matrix operand in shared memory
using ShapeA =
MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
public:
//
// Data members
//
/// Buffer for A operand
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
/// Buffer to hold scales for threadblock
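/// (one ElementScale per output column of the threadblock tile, Shape::kN
/// values in total)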
AlignedBuffer<ElementScale, Shape::kN> operand_scale;
public:
//
// Methods
//
/// Returns a layout object for the A matrix
CUTLASS_DEVICE
static typename Operator::LayoutA LayoutA() {
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
}
/// Returns a layout object for the B matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutB LayoutB() {
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
}
/// Returns a TensorRef to the A operand
CUTLASS_HOST_DEVICE
TensorRefA operand_A_ref() {
return TensorRefA{operand_A.data(), LayoutA()};
}
/// Returns a TensorRef to the B operand
CUTLASS_HOST_DEVICE
TensorRefB operand_B_ref() {
return TensorRefB{operand_B.data(), LayoutB()};
}
};
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage& shared_storage, // NOLINT
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/interleaved_numeric_conversion.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Cache operation for operand A
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Cache operation for operand B
cutlass::arch::CacheOperation::Kind CacheOpB,
/// Iterates over tiles of the scale operand in global memory
typename IteratorScale_,
/// Iterates over tiles of the scale operand in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Number of stages,
int Stages,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Used for partial specialization
typename Enable = bool>
class DqMmaMultistage : public DqMmaBase<Shape_,
Policy_,
typename IteratorScale_::Element,
Stages> {
public:
///< Base class
using Base =
DqMmaBase<Shape_, Policy_, typename IteratorScale_::Element, Stages>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
using IteratorA = IteratorA_;
///< Iterates over tiles of B operand in global memory
using IteratorB = IteratorB_;
///< Data type of accumulator matrix
using ElementC = ElementC_;
///< Layout of accumulator matrix
using LayoutC = LayoutC_;
///< Policy describing tuning details
using Policy = Policy_;
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
using TransformBAfterLDS = TransformBAfterLDS_;
//
// Dependent types
//
/// Fragment of operand Scale loaded from global memory;
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
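/// Warp-scoped dequantizer: applies the per-column scales to the converted B
/// fragments in registers right before each warp-level MMA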
using Dequantizer =
warp::MmaTensorOpDequantizer<Operator,
typename Base::WarpGemm,
Operand::kB,
ElementScale,
LayoutScale,
32,
typename Operator::FragmentA::Element>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
/// Internal structure exposed for introspection.
struct Detail {
static_assert(Base::kWarpGemmIterations > 1,
"The pipelined structure requires at least two warp-level "
"GEMM operations.");
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA =
IteratorA::ThreadMap::Iterations::kCount;
/// Number of cp.async instructions to load one stage of operand B
static int const AsyncCopyIterationsPerStageB =
IteratorB::ThreadMap::Iterations::kCount;
/// Number of stages
static int const kStages = Stages;
/// Number of cp.async instructions to load on group of operand A
static int const kAccessesPerGroupA =
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) /
Base::kWarpGemmIterations;
/// Number of cp.async instructions to load on group of operand B
static int const kAccessesPerGroupB =
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) /
Base::kWarpGemmIterations;
};
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave =
layout::IsColumnMajorTileInterleave<
typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave ||
(RequiresTileInterleave &&
(Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
private:
//
// Data members
//
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale operand to shared
/// memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaMultistage(
///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage, // NOLINT
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_dequantizer_(
{shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) /
Base::WarpCount::kM,
lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
smem_iterator_scale_(LayoutScale(Shape::kN),
shared_storage.operand_scale.data(),
{1, Shape::kN},
thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
CUTLASS_DEVICE
void copy_tiles_and_advance(IteratorA& iterator_A, // NOLINT
IteratorB& iterator_B, // NOLINT
int group_start_A = 0,
int group_start_B = 0) {
iterator_A.set_iteration_index(group_start_A *
IteratorA::kAccessesPerVector);
this->smem_iterator_A_.set_iteration_index(group_start_A);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) {
typename IteratorA::AccessType* dst_ptr =
reinterpret_cast<typename IteratorA::AccessType*>(
this->smem_iterator_A_.get());
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_A.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
dst_ptr + v, gmem_ptr, iterator_A.valid());
}
++iterator_A;
}
++this->smem_iterator_A_;
}
}
iterator_B.set_iteration_index(group_start_B *
IteratorB::kAccessesPerVector);
this->smem_iterator_B_.set_iteration_index(group_start_B);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) {
typename IteratorB::AccessType* dst_ptr =
reinterpret_cast<typename IteratorB::AccessType*>(
this->smem_iterator_B_.get());
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
} else {
cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
dst_ptr + v, gmem_ptr, iterator_B.valid());
}
++iterator_B;
}
++this->smem_iterator_B_;
}
}
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
int gemm_k_iterations,
///< destination accumulator tile
FragmentC& accum, // NOLINT
///< iterator over A operand in global memory
IteratorA iterator_A,
///< iterator over B operand in global memory
IteratorB iterator_B,
///< iterator over scale operand in global memory
IteratorScale iterator_scale,
///< initial value of accumulator
FragmentC const& src_accum) {
//
// Prologue
//
TransformBAfterLDS lds_converter;
// NOTE - switch to ldg.sts
// Issue this first, so cp.async.commit_group will commit this load as well.
// Note: we do not commit here and this load will commit in the same group as
// the first load of A.
FragmentScale tb_frag_scales;
tb_frag_scales.clear();
iterator_scale.load(tb_frag_scales);
this->smem_iterator_scale_.store(tb_frag_scales);
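// The scales do not vary along K, so they are loaded and staged to shared
// memory once per threadblock, ahead of the multi-stage A/B prologue below.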
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
iterator_A.set_iteration_index(0);
this->smem_iterator_A_.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType* dst_ptr =
reinterpret_cast<typename IteratorA::AccessType*>(
this->smem_iterator_A_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
int src_bytes = (iterator_A.valid() ? kSrcBytes : 0);
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
++iterator_A;
}
++this->smem_iterator_A_;
}
iterator_B.set_iteration_index(0);
this->smem_iterator_B_.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType* dst_ptr =
reinterpret_cast<typename IteratorB::AccessType*>(
this->smem_iterator_B_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
int const kSrcBytes =
sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
++iterator_B;
}
++this->smem_iterator_B_;
}
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Defines the boundary of a stage of cp.async.
cutlass::arch::cp_async_fence();
}
// Perform accumulation in the 'd' output operand
accum = src_accum;
//
// Clear the remaining tiles of SMEM. This is a functional requirement for
// some kernels so that all accumulator elements outside the GEMM footprint
// are zero.
//
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
/// Iterator to write threadblock-scoped tile of A operand to shared
/// memory
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
typename IteratorA::AccessType zero_A;
zero_A.clear();
last_smem_iterator_A.set_iteration_index(0);
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType* dst_ptr =
reinterpret_cast<typename IteratorA::AccessType*>(
last_smem_iterator_A.get());
*dst_ptr = zero_A;
++last_smem_iterator_A;
}
/// Iterator to write threadblock-scoped tile of B operand to shared
/// memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
zero_B.clear();
last_smem_iterator_B.set_iteration_index(0);
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType* dst_ptr =
reinterpret_cast<typename IteratorB::AccessType*>(
last_smem_iterator_B.get());
*dst_ptr = zero_B;
++last_smem_iterator_B;
}
}
// Waits until kStages-2 stages have committed.
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
typename Dequantizer::FragmentScale warp_frag_scales;
Operator warp_mma;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
warp_dequantizer_.load(warp_frag_scales);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
int smem_write_stage_idx = Base::kStages - 1;
int smem_read_stage_idx = 0;
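// The prologue filled stages 0 .. kStages - 2, so the next write targets the
// last stage (kStages - 1) while warp-level reads start from stage 0; both
// indices wrap around the circular shared-memory buffer as the mainloop runs.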
//
// Mainloop
//
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-Base::kStages + 1);) {
//
// Loop over GEMM K dimension
//
// Computes a warp-level GEMM on data held in shared memory
// Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to the k offset if
// this is the last group.
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
const int warp_tileB_k_compute_offset =
warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
const int warp_tileB_k_load_offset =
warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
if (warp_tileB_k_compute_offset ==
Base::kNumKIterationsPerWarpBLoad - 1) {
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(
warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
typename TransformBAfterLDS::result_type converted_frag_B =
lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
run_warp_mma(warp_mma,
accum,
warp_frag_A[warp_mma_k % 2],
converted_frag_B,
accum,
warp_tileB_k_compute_offset);
// Issue global->shared copies for this stage
if (warp_mma_k < Base::kWarpGemmIterations - 1) {
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A,
iterator_B,
group_start_iteration_A,
group_start_iteration_B);
}
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
int group_start_iteration_A, group_start_iteration_B;
group_start_iteration_A =
(warp_mma_k + 1) * Detail::kAccessesPerGroupA;
group_start_iteration_B =
(warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance(iterator_A,
iterator_B,
group_start_iteration_A,
group_start_iteration_B);
// Inserts a memory fence between stages of cp.async instructions.
cutlass::arch::cp_async_fence();
// Waits until kStages-2 stages have committed.
arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
// Move to the next stage
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
this->smem_iterator_A_.add_tile_offset({0, 1});
this->smem_iterator_B_.add_tile_offset({1, 0});
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == (Base::kStages - 1)) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
smem_write_stage_idx = 0;
} else {
++smem_write_stage_idx;
}
if (smem_read_stage_idx == (Base::kStages - 1)) {
this->warp_tile_iterator_A_.add_tile_offset(
{0,
-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterationsForB,
0});
smem_read_stage_idx = 0;
} else {
++smem_read_stage_idx;
}
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
}
}
}
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
// commit and drain all pending and predicated LDGSTS pnz from the GEMM
// mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_base.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/interleaved_numeric_conversion.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorB_,
/// Iterates over tiles of the scale operand in global memory
typename IteratorScale_,
/// Iterates over tiles of the scale operand in shared memory
typename SmemIteratorScale_,
/// Data type of accumulator matrix
typename ElementC_,
/// Layout of accumulator matrix
typename LayoutC_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// Converter for B matrix applied immediately after the LDG (before STS)
typename TransformBAfterLDG_,
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// Used for partial specialization
typename Enable = bool>
class DqMmaPipelined : public DqMmaBase<Shape_,
Policy_,
typename SmemIteratorScale_::Element,
2> {
public:
///< Base class
using Base =
DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2>;
using Shape =
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA =
IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB =
IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
using TransformBAfterLDG = TransformBAfterLDG_;
using TransformBAfterLDS = TransformBAfterLDS_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of operand Scale loaded from global memory;
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
using Dequantizer = warp::MmaTensorOpDequantizer<
Operator,
typename Base::WarpGemm,
Operand::kB,
typename SmemIteratorScale::Fragment::Element,
LayoutScale,
32,
typename Operator::FragmentA::Element>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// statically assert kStages for DqMmaPipelined is two (double-buffered
// pipeline)
static_assert((Base::kStages == 2),
"DqMmaPipelined requires kStages set to value 2");
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave =
layout::IsColumnMajorTileInterleave<
typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave ||
(RequiresTileInterleave &&
(Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale operand to shared
/// memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaPipelined(typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal
///< use by threadblock-scoped GEMM
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_dequantizer_(
{shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) /
Base::WarpCount::kM,
lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
smem_iterator_scale_(LayoutScale(Shape::kN),
shared_storage.operand_scale.data(),
{1, Shape::kN},
thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile // NOLINT
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
IteratorScale
iterator_scale, ///< iterator over scale operand in global memory
FragmentC const& src_accum) { ///< source accumulator tile
//
// Prologue
//
TransformBAfterLDG ldg_converter;
TransformBAfterLDS lds_converter;
using TransformA = NumericArrayConverter<typename WarpFragmentA::Element,
typename FragmentA::Element,
FragmentA::kElements>;
using TransformScale =
NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
typename FragmentScale::Element,
FragmentScale::kElements>;
// These transforms are mainly to handle the case where we have bfloat
// activations and weights in GMEM and want to issue HMMA on architectures
// older than Ampere. We convert to FP16 before STS.
TransformA transformA;
TransformScale transformScale;
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
FragmentScale tb_frag_scales;
using WarpFragmentScale = typename Dequantizer::FragmentScale;
WarpFragmentScale warp_frag_scales;
tb_frag_A.clear();
tb_frag_B.clear();
tb_frag_scales.clear();
// The last kblock is loaded in the prolog
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
iterator_scale.load(tb_frag_scales);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
this->smem_iterator_scale_.store(transformScale(tb_frag_scales));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
__syncthreads();
warp_dequantizer_.load(warp_frag_scales);
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
int smem_write_stage_idx = 1;
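// Double-buffered shared memory: stage 0 was filled in the prologue, so the
// next stores go to stage 1; the index is toggled with ^= 1 after each K tile.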
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER*
// issuing shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations) {
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to the k offset if
// this is the last group.
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
// Write fragments to shared memory
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == 1) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
} else {
this->warp_tile_iterator_A_.add_tile_offset(
{0,
-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterationsForB,
0});
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
const int warp_tileB_k_compute_offset =
warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
const int warp_tileB_k_load_offset =
warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
// We are just about to finish computing on a fragment of B, so initiate
// the load for the next fragment.
if (warp_tileB_k_compute_offset ==
Base::kNumKIterationsPerWarpBLoad - 1) {
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(
warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
if (warp_mma_k == 0) {
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
typename TransformBAfterLDS::result_type converted_frag_B =
lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
run_warp_mma(warp_mma,
accum,
warp_frag_A[warp_mma_k % 2],
converted_frag_B,
accum,
warp_tileB_k_compute_offset);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
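// Minimal sketch of the two-stage shared-memory pipeline used by the mainloop
// above (illustrative only, not referenced by the kernel): the threadblock
// iterators write one stage of the circular buffer while the warp iterators read
// the other, `smem_write_stage_idx` is flipped with XOR every k-tile, and the
// negative add_tile_offset calls rewind whichever iterator just ran off the end
// of the two-stage buffer.
inline constexpr int stage_written_at_mainloop_iteration(int k) {
  // The prologue fills stage 0, so inside the mainloop iteration 0 writes stage
  // 1, iteration 1 writes stage 0, and so on.
  return (k % 2 == 0) ? 1 : 0;
}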
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Default warp-level GEMM operators selected by data type, size, and
layouts of operands.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
namespace cutlass {
namespace gemm {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for m-by-n-by-kgroup
template <
/// Shape of one matrix production operation (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix production operation (concept: GemmShape)
typename InstructionShape_,
/// Data type of A elements,
typename ElementA,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA,
/// Data type of B elements
typename ElementB,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB,
/// Element type of C matrix
typename ElementC,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Number of partitions along K dimension
int PartitionsK,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<WarpShape_,
InstructionShape_,
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
arch::OpMultiplyAddDequantizeInterleavedBToA,
PartitionsK,
AccumulatorsInRowMajor> {
private:
// Shape for computing the FP16s
using ComputeInstructionShape = InstructionShape_;
// Chosen so we get K=16 for int8 and K=32 for int4.
static constexpr int LoadInstructionK =
8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value;
// Shape for loading the narrow data type from shared memory
using LoadInstructionShape =
GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;
public:
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::Mma<InstructionShape_,
32,
ElementA,
cutlass::layout::RowMajor,
ElementA,
cutlass::layout::ColumnMajor,
ElementC,
cutlass::layout::RowMajor,
arch::OpMultiplyAdd>,
cutlass::MatrixShape<1, 1>>;
// Define the warp-level tensor op
using Type =
cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_,
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
Policy,
LoadInstructionShape,
PartitionsK,
AccumulatorsInRowMajor>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
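// Illustrative check of LoadInstructionK above (not part of the original header):
// for fp16/bf16 A (16 bits), the ratio 8 * sizeof_bits<A> / sizeof_bits<B> is 16
// for an int8 B and 32 for an int4 B, matching the "K=16 for int8 and K=32 for
// int4" comment. Plain integers stand in for cutlass::sizeof_bits here.
static_assert(8 * 16 / 8 == 16, "int8 B selects a K=16 load instruction shape");
static_assert(8 * 16 / 4 == 32, "int4 B selects a K=32 load instruction shape");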
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Templates implementing warp-level matrix multiply-accumulate
operations targeting Tensor Cores.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/platform/platform.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting Tensor Cores, converting
/// the B operand to the A operand's data type before the warp-level mma.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Data type of A elements
typename ElementA_,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA_,
/// Data type of B elements
typename ElementB_,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB_,
/// Element type of C matrix
typename ElementC_,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC_,
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
typename Policy_,
/// Instruction shape to override shared memory iterators with
typename SharedMemoryInstructionShape_,
/// Number of partitions along K dimension
int PartitionsK_ = 1,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor = false,
/// Used for partial specialization
typename Enable = bool>
class MmaTensorOpComputeBWithF16 {
public:
/// Shape of warp-level matrix operation (concept: GemmShape)
using Shape = Shape_;
/// Data type of multiplicand A
using ElementA = ElementA_;
/// Layout of multiplicand A
using LayoutA = LayoutA_;
/// Data type of multiplicand B
using ElementB = ElementB_;
/// Layout of multiplicand B
using LayoutB = LayoutB_;
/// Data type of accumulator matrix C
using ElementC = ElementC_;
/// Layout of accumulator matrix C
using LayoutC = LayoutC_;
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
using Policy = Policy_;
/// Underlying matrix multiply operator (concept: arch::Mma)
using ArchMmaOperator = typename Policy::Operator;
/// Indicates math operator
using MathOperator = typename ArchMmaOperator::Operator;
/// Architecture tag from underlying instruction
using ArchTag = typename ArchMmaOperator::ArchTag;
static_assert(
(platform::is_same<typename ArchMmaOperator::ElementA, half_t>::value &&
platform::is_same<typename ArchMmaOperator::ElementB, half_t>::value) ||
(platform::is_same<typename ArchMmaOperator::ElementA,
bfloat16_t>::value &&
platform::is_same<typename ArchMmaOperator::ElementB,
bfloat16_t>::value &&
ArchTag::kMinComputeCapability >= 80),
"MmaTensorOpCvtBToA only supports underlying HMMA");
static_assert(platform::is_same<ElementA, half_t>::value ||
(platform::is_same<ElementA, bfloat16_t>::value &&
ArchTag::kMinComputeCapability >= 80),
"MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+");
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;
/// Shape of underlying instruction
using InstructionShape = typename ArchMmaOperator::Shape;
/// Instruction shape to override shared memory iterators with
using SharedMemoryInstructionShape = SharedMemoryInstructionShape_;
static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM,
"M dimension of compute instruction must match load");
static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN,
"N dimension of compute instruction must match load");
static constexpr int kExpansionFactor =
SharedMemoryInstructionShape::kK / InstructionShape::kK;
static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), "");
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
/// Complex transform on B operand
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Number of threads participating in warp-level matrix product
static int const kThreadCount = 32;
/// Number of partitions along K dimension
static int const kPartitionsK = PartitionsK_;
public:
/// Iterates over the A operand in memory
using IteratorA = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kM, Shape::kK>,
Operand::kA,
ElementA,
LayoutA,
MatrixShape<InstructionShape::kM, InstructionShape::kK>,
Policy::OpDelta::kRow,
kThreadCount,
kPartitionsK>;
/// Storage for A tile
using FragmentA = typename IteratorA::Fragment;
/// Storage for transformed A tile
using TransformedFragmentA =
Array<typename ArchMmaOperator::ElementA, FragmentA::kElements>;
/// Iterates over the B operand in memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kK, Shape::kN>,
Operand::kB,
ElementB,
LayoutB,
MatrixShape<SharedMemoryInstructionShape::kK, InstructionShape::kN>,
Policy::OpDelta::kRow,
kThreadCount,
kPartitionsK>;
/// Storage for B tile
using FragmentB = typename IteratorB::Fragment;
/// Storage for transformed B tile
using TransformedFragmentB =
Array<typename ArchMmaOperator::ElementB, FragmentB::kElements>;
/// Iterates over the C operand in memory
using IteratorC =
MmaTensorOpAccumulatorTileIterator<MatrixShape<Shape::kM, Shape::kN>,
ElementC,
LayoutC,
typename ArchMmaOperator::Shape,
typename Policy::OpDelta>;
/// Storage for C tile
using FragmentC = typename IteratorC::Fragment;
/// Number of mma operations performed
using MmaIterations =
MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) /
ArchMmaOperator::Shape::kM,
(Shape::kN + ArchMmaOperator::Shape::kN - 1) /
ArchMmaOperator::Shape::kN>;
public:
/// Underlying matrix multiply operator (concept: arch::Mma)
ArchMmaOperator mma;
public:
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
MmaTensorOpComputeBWithF16() {}
/// Performs a warp-level matrix multiply-accumulate operation
CUTLASS_DEVICE
void operator()(FragmentC& D, // NOLINT
TransformedFragmentA const& A,
TransformedFragmentB const& B,
FragmentC const& C,
const int warp_tileB_k_offset) const {
using MmaOperandA = typename ArchMmaOperator::FragmentA;
using MmaOperandB = typename ArchMmaOperator::FragmentB;
using MmaOperandC = typename ArchMmaOperator::FragmentC;
static_assert(
TransformedFragmentB::kElements ==
MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn,
"Each thread should have a pack of mma registers for each column "
"iteration AND for the expanded K dim of B");
D = C;
MmaOperandA const* ptr_A = reinterpret_cast<MmaOperandA const*>(&A);
MmaOperandB const* ptr_B = reinterpret_cast<MmaOperandB const*>(&B);
MmaOperandC* ptr_D = reinterpret_cast<MmaOperandC*>(&D);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
// Serpentine visitation order maximizing reuse of Rb
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n) {
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m) {
int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m);
int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n;
if (AccumulatorsInRowMajor) { // matrix B is reordered
mma(ptr_D[n + m_serpentine * MmaIterations::kColumn],
ptr_A[m_serpentine],
ptr_B[n_offsetB],
ptr_D[n + m_serpentine * MmaIterations::kColumn]);
} else {
mma(ptr_D[m_serpentine + n * MmaIterations::kRow],
ptr_A[m_serpentine],
ptr_B[n_offsetB],
ptr_D[m_serpentine + n * MmaIterations::kRow]);
}
}
}
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// Serpentine visitation order maximizing reuse of Ra
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n) {
int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
int n_serpentine_offsetB =
warp_tileB_k_offset + kExpansionFactor * n_serpentine;
if (AccumulatorsInRowMajor) { // matrix B is reordered
mma(ptr_D[n_serpentine + m * MmaIterations::kColumn],
ptr_A[m],
ptr_B[n_serpentine_offsetB],
ptr_D[n_serpentine + m * MmaIterations::kColumn]);
} else {
mma(ptr_D[m + n_serpentine * MmaIterations::kRow],
ptr_A[m],
ptr_B[n_serpentine_offsetB],
ptr_D[m + n_serpentine * MmaIterations::kRow]);
}
}
}
#else
assert(0);
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
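// Illustrative check of kExpansionFactor above (not part of the original header):
// if shared memory is read with a K=32 instruction shape while the tensor-core
// mma computes with K=16, each loaded B pack covers two compute steps, and
// warp_tileB_k_offset (0 or 1) selects which half feeds the current mma.
static_assert(32 / 16 == 2, "one K=32 load feeds two K=16 mma steps");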
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Defines iterators used by warp-level matrix multiply operations
targeting Tensor Cores.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
#include "cutlass/tensor_ref.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/layout/tensor.h"
#include "cutlass/functional.h"
#include "cutlass/platform/platform.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
////////////////////////////////////////////////////////////////////////////////
template <
/// Matrix multiply operator
typename MmaOperator_,
/// Size of the matrix to load (concept: MatrixShape)
typename Shape_,
/// Operand identity
Operand Operand,
/// Data type of Scale elements
typename Element_,
/// Layout of operand
typename Layout_,
/// Number of threads participating in one matrix operation
int Threads,
    /// Data type of output elements
typename Element_out_,
typename Enable = void>
class MmaTensorOpDequantizer;
////////////////////////////////////////////////////////////////////////////////
// Bfloat specialization for Ampere
#ifdef PADDLE_CUDA_BF16
template <
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename MmaOperator_,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename Shape_>
class MmaTensorOpDequantizer<
MmaOperator_,
Shape_,
Operand::kB,
bfloat16_t,
layout::RowMajor,
32,
bfloat16_t,
typename platform::enable_if<
MmaOperator_::ArchTag::kMinComputeCapability >= 80 &&
platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB,
layout::ColumnMajor>::value>::type> {
public:
/// Mma Operator
using MmaOperator = MmaOperator_;
  // The architecture-specific mma operator being used
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
// Mma Instruction Shape
using InstructionShape = typename ArchMmaOperator::Shape;
// This is the ratio of the load instruction vs the compute instruction.
static constexpr int kExpansionFactor =
MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
/// Type of the scales
using ElementScale = bfloat16_t;
/// Fragment to hold B data before Mma
using FragmentDequantizedOperand =
Array<ElementScale, MmaOperator::FragmentB::kElements>;
// Fragment to hold scale data to apply to B before mma
  // We need one bf16 scale per matrix iteration in the N dimension
static constexpr int kColsPerMmaPerThread = 1;
using FragmentScale =
Array<ElementScale,
kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
/// Warp mma shape
using Shape = Shape_;
/// Layout of the scales in shared memory
using Layout = layout::RowMajor;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<ElementScale, Layout>;
CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales,
const int warp_idx_n,
const int lane_idx) {
const int warp_offset = warp_idx_n * Shape::kN;
const int quad = lane_idx / 4;
const int thread_offset = warp_offset + quad;
pointer_ = smem_scales.data() + thread_offset;
}
CUTLASS_DEVICE
void load(FragmentScale& scale_frag) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn;
++mma_n_iter) {
scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
}
}
CUTLASS_DEVICE
void dequantize(FragmentDequantizedOperand& operand_frag,
const FragmentScale& scale_frag) {
// Slow path not implemented here on purpose. If we need to do HMMA on older
// arch, scale conversion should happen before scales are stored to shared
// memory and we should use the fp16 dequantizer. This will avoid numerous
// conversion instructions in GEMM main loop.
arch::device_breakpoint();
}
private:
ElementScale const* pointer_;
};
#endif
// Specialization for Turing & Ampere
template <
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename MmaOperator_,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename Shape_>
class MmaTensorOpDequantizer<
MmaOperator_,
Shape_,
Operand::kB,
half_t,
layout::RowMajor,
32,
half_t,
typename platform::enable_if<
MmaOperator_::ArchTag::kMinComputeCapability >= 75 &&
platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB,
layout::ColumnMajor>::value>::type> {
public:
/// Mma Operator
using MmaOperator = MmaOperator_;
  // The architecture-specific mma operator being used
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
// Mma Instruction Shape
using InstructionShape = typename ArchMmaOperator::Shape;
// This is the ratio of the load instruction vs the compute instruction.
static constexpr int kExpansionFactor =
MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
/// Type of the scales
using ElementScale = half_t;
/// Fragment to hold B data before Mma
using FragmentDequantizedOperand =
Array<ElementScale, MmaOperator::FragmentB::kElements>;
// Fragment to hold scale data to apply to B before mma
// We need 1 fp16 per matrix iteration in the N dimension
static constexpr int kColsPerMmaPerThread = 1;
using FragmentScale =
Array<ElementScale,
kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
/// Warp mma shape
using Shape = Shape_;
/// Layout of the scales in shared memory
using Layout = layout::RowMajor;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<ElementScale, Layout>;
CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales,
const int warp_idx_n,
const int lane_idx) {
const int warp_offset = warp_idx_n * Shape::kN;
const int quad = lane_idx / 4;
const int thread_offset = warp_offset + quad;
pointer_ = smem_scales.data() + thread_offset;
}
CUTLASS_DEVICE
void load(FragmentScale& scale_frag) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn;
++mma_n_iter) {
scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
}
}
CUTLASS_DEVICE
void dequantize(FragmentDequantizedOperand& operand_frag,
const FragmentScale& scale_frag) {
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB =
Array<typename _MmaOperandB::Element,
kExpansionFactor * _MmaOperandB::kElements>;
static_assert(
ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn ==
FragmentDequantizedOperand::kElements,
"");
multiplies<ExpandedMmaOperandB> mul_op;
ExpandedMmaOperandB* operand_frag_ptr =
reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn;
++mma_n_iter) {
operand_frag_ptr[mma_n_iter] =
mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
}
}
private:
ElementScale const* pointer_;
};
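// Scalar reference for the dequantize() above (illustrative only; plain floats
// and raw pointers stand in for the fragment types): every element belonging to
// column iteration n of the B fragment is multiplied by that column's single
// scale value.
inline void dequantize_reference_sketch(float* operand_frag,
                                        const float* scale_frag,
                                        int columns,
                                        int elems_per_column) {
  for (int n = 0; n < columns; ++n) {
    for (int e = 0; e < elems_per_column; ++e) {
      operand_frag[n * elems_per_column + e] *= scale_frag[n];
    }
  }
}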
////////////////////////////////////////////////////////////////////////////////
// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved
// gemm
template <
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename MmaOperator_,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename Shape_>
class MmaTensorOpDequantizer<
MmaOperator_,
Shape_,
Operand::kB,
half_t,
layout::RowMajor,
32,
half_t,
typename platform::enable_if<
platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value &&
platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB,
layout::RowMajor>::value>::type> {
public:
static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape,
GemmShape<32, 32, 4>>::value,
"");
/// Mma Operator
using MmaOperator = MmaOperator_;
  // The architecture-specific mma operator being used
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
// Mma Instruction Shape
using InstructionShape = typename ArchMmaOperator::Shape;
/// Type of the scales
using ElementScale = half_t;
/// Fragment to hold B data before Mma
using FragmentDequantizedOperand =
Array<ElementScale, MmaOperator::FragmentB::kElements>;
/// Warp mma shape
using Shape = Shape_;
// Fragment to hold scale data to apply to B before mma
// Each 32x32x4 matmul uses 8 elements from B.
static constexpr int ColsPerMmaTile = 32;
static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
using FragmentScale = Array<ElementScale, TileNIterations * 8>;
using AccessType = Array<ElementScale, 8>;
/// Layout of the scales in shared memory
using Layout = layout::RowMajor;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<ElementScale, Layout>;
CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales,
const int warp_idx_n,
const int lane_idx) {
const int warp_offset = warp_idx_n * Shape::kN;
const int base_col = lane_idx & 0xF8;
const int thread_offset = warp_offset + base_col;
pointer_ = smem_scales.data() + thread_offset;
}
CUTLASS_DEVICE
void load(FragmentScale& scale_frag) { // NOLINT
AccessType* scale_frag_ptr = reinterpret_cast<AccessType*>(&scale_frag);
CUTLASS_PRAGMA_UNROLL
for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) {
// We jump by 32 here since volta does <32x32x4> super mmas inside a warp.
scale_frag_ptr[tile_iter] = *reinterpret_cast<AccessType const*>(
pointer_ + ColsPerMmaTile * tile_iter);
}
}
CUTLASS_DEVICE
void dequantize(FragmentDequantizedOperand& operand_frag, // NOLINT
const FragmentScale& scale_frag) {
static_assert(
FragmentScale::kElements == FragmentDequantizedOperand::kElements, "");
multiplies<FragmentDequantizedOperand> mul_op;
operand_frag = mul_op(operand_frag, scale_frag);
}
private:
ElementScale const* pointer_;
};
////////////////////////////////////////////////////////////////////////////////
// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved
// gemm
template <
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename MmaOperator_,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename Shape_>
class MmaTensorOpDequantizer<
MmaOperator_,
Shape_,
Operand::kB,
half_t,
layout::RowMajor,
    32,
    half_t,
    typename platform::enable_if<
platform::is_same<typename MmaOperator_::ArchTag, arch::Sm70>::value &&
platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB,
layout::ColumnMajor>::value>::type> {
public:
static_assert(platform::is_same<typename MmaOperator_::InterleavedTileShape,
GemmShape<32, 32, 4>>::value,
"");
/// Mma Operator
using MmaOperator = MmaOperator_;
  // The architecture-specific mma operator being used
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
// Mma Instruction Shape
using InstructionShape = typename ArchMmaOperator::Shape;
/// Type of the scales
using ElementScale = half_t;
/// Fragment to hold B data before Mma
using FragmentDequantizedOperand =
Array<ElementScale, MmaOperator::FragmentB::kElements>;
/// Warp mma shape
using Shape = Shape_;
// Fragment to hold scale data to apply to B before mma
// Each 32x32x4 matmul uses 8 elements from B.
static constexpr int ColsPerMmaTile = 32;
static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile;
using FragmentScale = Array<ElementScale, TileNIterations * 2>;
/// Layout of the scales in shared memory
using Layout = layout::RowMajor;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<ElementScale, Layout>;
CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales,
const int warp_idx_n,
const int lane_idx) {
const int warp_offset = warp_idx_n * Shape::kN;
    const int base_col = (lane_idx & 0xF8) + lane_idx % 4;
const int thread_offset = warp_offset + base_col;
pointer_ = smem_scales.data() + thread_offset;
}
CUTLASS_DEVICE
void load(FragmentScale& scale_frag) { // NOLINT
CUTLASS_PRAGMA_UNROLL
for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) {
// We jump by 32 here since volta does <32x32x4> super mmas inside a warp.
// For col major B, each thread will jump 4 cols to get its next value
// inside of the super mma.
CUTLASS_PRAGMA_UNROLL
for (int mma_iter = 0; mma_iter < 2; ++mma_iter) {
scale_frag[tile_iter * 2 + mma_iter] =
pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter];
}
}
}
CUTLASS_DEVICE
void dequantize(FragmentDequantizedOperand& operand_frag, // NOLINT
const FragmentScale& scale_frag) {
using MmaOperandB = typename ArchMmaOperator::FragmentB;
static constexpr int total_n_mmas = 2 * TileNIterations;
static_assert(MmaOperandB::kElements * total_n_mmas ==
FragmentDequantizedOperand::kElements,
"");
multiplies<MmaOperandB> mul_op;
MmaOperandB* operand_frag_ptr =
reinterpret_cast<MmaOperandB*>(&operand_frag);
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) {
operand_frag_ptr[mma_n_iter] =
mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]);
}
}
private:
ElementScale const* pointer_;
};
// Specialization for Turing & Ampere when Scale type is float and output type
// is half_t.
template <
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename MmaOperator_,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename Shape_>
class MmaTensorOpDequantizer<
MmaOperator_,
Shape_,
Operand::kB,
float,
layout::RowMajor,
32,
half_t,
typename platform::enable_if<
MmaOperator_::ArchTag::kMinComputeCapability >= 75 &&
platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB,
layout::ColumnMajor>::value>::type> {
public:
/// Mma Operator
using MmaOperator = MmaOperator_;
// The architecture specific mma ooperator being used
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
// Mma Instruction Shape
using InstructionShape = typename ArchMmaOperator::Shape;
// This is the ratio of the load instruction vs the compute instruction.
static constexpr int kExpansionFactor =
MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
/// Type of the output
using ElementType = half_t;
  using ElementScale = float;
/// Fragment to hold B data before Mma
using FragmentDequantizedOperand =
Array<ElementType, MmaOperator::FragmentB::kElements>;
// Fragment to hold scale data to apply to B before mma
  // We need one float scale per matrix iteration in the N dimension
static constexpr int kColsPerMmaPerThread = 1;
using FragmentScale =
Array<ElementScale,
kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
/// Warp mma shape
using Shape = Shape_;
/// Layout of the scales in shared memory
using Layout = layout::RowMajor;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<ElementScale, Layout>;
CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales,
const int warp_idx_n,
const int lane_idx) {
const int warp_offset = warp_idx_n * Shape::kN;
const int quad = lane_idx / 4;
const int thread_offset = warp_offset + quad;
pointer_ = smem_scales.data() + thread_offset;
}
CUTLASS_DEVICE
void load(FragmentScale& scale_frag) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn;
++mma_n_iter) {
scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
}
}
CUTLASS_DEVICE
void dequantize(FragmentDequantizedOperand& operand_frag,
const FragmentScale& scale_frag) {
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB =
Array<typename _MmaOperandB::Element,
kExpansionFactor * _MmaOperandB::kElements>;
using ComputeFrag =
Array<ElementScale, kExpansionFactor * _MmaOperandB::kElements>;
static_assert(
ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn ==
FragmentDequantizedOperand::kElements,
"");
multiplies<ComputeFrag> mul_op;
ExpandedMmaOperandB* operand_frag_ptr =
reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
NumericArrayConverter<ElementScale,
ElementType,
kExpansionFactor * _MmaOperandB::kElements>
source_converter;
NumericArrayConverter<ElementType,
ElementScale,
kExpansionFactor * _MmaOperandB::kElements>
output_converter;
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn;
++mma_n_iter) {
ComputeFrag convert_frag = source_converter(operand_frag_ptr[mma_n_iter]);
convert_frag = mul_op(convert_frag, scale_frag[mma_n_iter]);
operand_frag_ptr[mma_n_iter] = output_converter(convert_frag);
}
}
private:
ElementScale const* pointer_;
};
// Specialization for Turing & Ampere when Scale type is float and output type
// is bfloat16.
#ifdef PADDLE_CUDA_BF16
template <
/// Underlying matrix multiply operator (concept: MmaTensorOp)
typename MmaOperator_,
/// Shape of the warp level matrix multiply (concept: GemmShape)
typename Shape_>
class MmaTensorOpDequantizer<
MmaOperator_,
Shape_,
Operand::kB,
float,
layout::RowMajor,
32,
bfloat16_t,
typename platform::enable_if<
MmaOperator_::ArchTag::kMinComputeCapability >= 75 &&
platform::is_same<typename MmaOperator_::ArchMmaOperator::LayoutB,
layout::ColumnMajor>::value>::type> {
public:
/// Mma Operator
using MmaOperator = MmaOperator_;
  // The architecture-specific mma operator being used
using ArchMmaOperator = typename MmaOperator::ArchMmaOperator;
// Mma Instruction Shape
using InstructionShape = typename ArchMmaOperator::Shape;
// This is the ratio of the load instruction vs the compute instruction.
static constexpr int kExpansionFactor =
MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK;
/// Type of the output
using ElementType = bfloat16_t;
  using ElementScale = float;
/// Fragment to hold B data before Mma
using FragmentDequantizedOperand =
Array<ElementType, MmaOperator::FragmentB::kElements>;
// Fragment to hold scale data to apply to B before mma
  // We need one float scale per matrix iteration in the N dimension
static constexpr int kColsPerMmaPerThread = 1;
using FragmentScale =
Array<ElementScale,
kColsPerMmaPerThread * MmaOperator::MmaIterations::kColumn>;
/// Warp mma shape
using Shape = Shape_;
/// Layout of the scales in shared memory
using Layout = layout::RowMajor;
/// TensorRef type for loading element from a tensor
using TensorRef = TensorRef<ElementScale, Layout>;
CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales,
const int warp_idx_n,
const int lane_idx) {
const int warp_offset = warp_idx_n * Shape::kN;
const int quad = lane_idx / 4;
const int thread_offset = warp_offset + quad;
pointer_ = smem_scales.data() + thread_offset;
}
CUTLASS_DEVICE
void load(FragmentScale& scale_frag) {
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn;
++mma_n_iter) {
scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN];
}
}
CUTLASS_DEVICE
void dequantize(FragmentDequantizedOperand& operand_frag,
const FragmentScale& scale_frag) {
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB =
Array<typename _MmaOperandB::Element,
kExpansionFactor * _MmaOperandB::kElements>;
using ComputeFrag =
Array<ElementScale, kExpansionFactor * _MmaOperandB::kElements>;
static_assert(
ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn ==
FragmentDequantizedOperand::kElements,
"");
multiplies<ComputeFrag> mul_op;
ExpandedMmaOperandB* operand_frag_ptr =
reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
NumericArrayConverter<ElementScale,
ElementType,
kExpansionFactor * _MmaOperandB::kElements>
source_converter;
NumericArrayConverter<ElementType,
ElementScale,
kExpansionFactor * _MmaOperandB::kElements>
output_converter;
CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn;
++mma_n_iter) {
ComputeFrag convert_frag = source_converter(operand_frag_ptr[mma_n_iter]);
convert_frag = mul_op(convert_frag, scale_frag[mma_n_iter]);
operand_frag_ptr[mma_n_iter] = output_converter(convert_frag);
}
}
private:
ElementScale const* pointer_;
};
#endif
////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*!
\file
\brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t
interleaved in a register
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/array.h"
#include "cutlass/half.h"
#include "cutlass/numeric_types.h"
namespace cutlass {
// This converter is meant to be used with data interleaved in a 32-bit register
// where the even elements are in the low bits and the odd elements are in the
// high bits of the register. In addition, it assumes elements were originally
// signed and had a bias of 2**(b-1) added (where b is the number of bits in the
// type) to make all numbers unsigned. This converter will uninterleave the data
// and subtract the bias while converting to the result type.
template <typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter {};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4> {
using result_type = Array<half_t, 4>;
using source_type = Array<uint8_t, 4>;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
result_type result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
: "=r"(h[0])
: "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01));
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
: "=r"(h[1])
: "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23));
// Lastly, we subtract 1152 from our constructed number using fp16 math to
// get our signed integer as fp16.
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[0])
: "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM));
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[1])
: "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM));
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) { return convert(s); }
};
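// Host-style sanity check for the constants above (illustrative only; plain float
// arithmetic stands in for the packed f16x2 ops): prmt builds halves with the bit
// pattern 0x64XX, which encodes 1024 + XX, and subtracting 0x6480 (1152.0 in
// fp16) leaves XX - 128, i.e. the original signed int8 value.
inline bool int8_to_half_magic_numbers_hold() {
  for (int b = 0; b < 256; ++b) {
    const float constructed = 1024.0f + static_cast<float>(b);  // fp16 0x64 | b
    if (constructed - 1152.0f != static_cast<float>(b - 128)) {
      return false;
    }
  }
  return true;
}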
template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N> {
static constexpr int VEC_WIDTH = 4;
  static_assert(!(N % VEC_WIDTH), "N must be a multiple of 4.");
using result_type = Array<half_t, N>;
using source_type = Array<uint8_t, N>;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type,
scalar_source_type,
VEC_WIDTH>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / VEC_WIDTH; ++i) {
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) { return convert(s); }
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, 4> {
using result_type = Array<bfloat16_t, 4>;
using source_type = Array<uint8_t, 4>;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&result);
uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);
static constexpr uint32_t fp32_base = 0x4B000000;
float fp32_intermediates[4];
// Construct FP32s, bfloat does not have enough mantissa for IADD trick
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653);
// Subtract out fp32_base + 128 to make the unsigned integer signed.
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < 4; ++ii) {
fp32_intermediates[ii] -= 8388736.f;
}
// Truncate the fp32 representation and pack up as bfloat16s.
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < 2; ++ii) {
bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0],
fp32_intermediates_casted[2 * ii + 1],
0x7632);
}
#else
// Disable this on architectures older than Ampere since they lack hardware
// for bf16 mma. If one wishes to use HMMA on older hardware, they should
    // convert directly to FP16 using the FP16 converters.
result.clear(); // Suppress compiler warning
arch::device_breakpoint();
#endif
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) { return convert(s); }
};
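// Illustrative note on the constants above (not used by the converter):
// fp32_base 0x4B000000 is the bit pattern of 2^23 = 8388608.0f, so byte-permuting
// a raw byte b into the low mantissa bits yields 8388608 + b, and subtracting
// 8388736.f (= 2^23 + 128) removes both the base and the +128 bias exactly.
static_assert(8388608 == (1 << 23), "fp32_base encodes 2^23");
static_assert((8388608 + 200) - 8388736 == 200 - 128, "bias removal for b = 200");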
template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint8_t, N> {
static constexpr int VEC_WIDTH = 4;
  static_assert(!(N % VEC_WIDTH), "N must be a multiple of 4.");
using result_type = Array<bfloat16_t, N>;
using source_type = Array<uint8_t, N>;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type,
scalar_source_type,
VEC_WIDTH>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / VEC_WIDTH; ++i) {
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) { return convert(s); }
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, 8> {
using result_type = Array<half_t, 8>;
using source_type = Array<uint4b_t, 8>;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
result_type result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is
// thanks to the register packing format and the fact that we force our
// integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput
// in order to convert elt_23 and elt_67 to fp16 without having to shift
    // them to the bottom bits beforehand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide
// RAW dependency if we issue immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile(
"lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile(
"lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s),
"n"(BOTTOM_MASK),
"n"(I4s_TO_F16s_MAGIC_NUM),
"n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile(
"lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit
// float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
static constexpr uint32_t NEG_72 = 0xd480d480;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[0])
: "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(h[1])
: "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n"
: "=r"(h[2])
: "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(h[3])
: "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72));
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) { return convert(s); }
};
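// Illustrative check of the magic numbers above (not used by the converter): a
// low nibble x sits in an fp16 encoding 1024 + x, so subtracting 1032 leaves
// x - 8; a high nibble contributes 16 * x, so multiplying by 1/16 and subtracting
// 72 gives (1024 + 16 * x) / 16 - 72 = x - 8 as well, which is why fma can
// convert elt_23 and elt_67 without an extra shift.
static_assert(1024 + 15 - 1032 == 15 - 8, "low-nibble bias removal for x = 15");
static_assert((1024 + 16 * 15) / 16 - 72 == 15 - 8,
              "high-nibble bias removal for x = 15");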
template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint4b_t, N> {
static constexpr int VEC_WIDTH = 8;
  static_assert(!(N % VEC_WIDTH), "N must be a multiple of 8.");
using result_type = Array<half_t, N>;
using source_type = Array<uint4b_t, N>;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type,
scalar_source_type,
VEC_WIDTH>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / VEC_WIDTH; ++i) {
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) { return convert(s); }
};
template <>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, 8> {
using result_type = Array<bfloat16_t, 8>;
using source_type = Array<uint4b_t, 8>;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
result_type result;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const source_i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;
// We don't have enough mantissa to remove as much shift overhead as FP16,
// so we must loop. No shift needed for first item.
uint32_t i4s = source_i4s;
asm volatile(
"lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
CUTLASS_PRAGMA_UNROLL
for (int ii = 1; ii < result_type::kElements / 2; ++ii) {
i4s >>= sizeof_bits<typename source_type::Element>::value;
// (i4s & 0x000f000f) | 0x43004300
asm volatile(
"lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[ii])
: "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
}
// This is the BF16 {-136, -136} represented as an integer.
static constexpr uint32_t BF16_BIAS = 0xC308C308;
static constexpr uint32_t BF16_ONE = 0x3F803F80;
// Finally, we construct the output numbers.
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < result_type::kElements / 2; ++ii) {
// Since this section is for Ampere+, we use bf16 fma to do the bias
// subtraction
asm("fma.rn.bf16x2 %0, %1, %2, %3;\n"
: "=r"(h[ii])
: "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS));
}
#else
// Disable this on architectures older than Ampere since they lack hardware
// for bf16 mma. If one wishes to use HMMA on older hardware, they should
    // convert directly to FP16 using the FP16 converters.
arch::device_breakpoint();
result.clear(); // Suppress compiler warning.
#endif
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) { return convert(s); }
};
template <int N>
struct FastInterleavedAndBiasedNumericArrayConverter<bfloat16_t, uint4b_t, N> {
static constexpr int VEC_WIDTH = 8;
  static_assert(!(N % VEC_WIDTH), "N must be a multiple of 8.");
using result_type = Array<bfloat16_t, N>;
using source_type = Array<uint4b_t, N>;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
using scalar_result_type = typename result_type::Element;
using scalar_source_type = typename source_type::Element;
FastInterleavedAndBiasedNumericArrayConverter<scalar_result_type,
scalar_source_type,
VEC_WIDTH>
convert_vector_;
result_type result;
using vec_result = Array<scalar_result_type, VEC_WIDTH>;
using vec_source = Array<scalar_source_type, VEC_WIDTH>;
vec_result* result_ptr = reinterpret_cast<vec_result*>(&result);
vec_source const* source_ptr = reinterpret_cast<vec_source const*>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / VEC_WIDTH; ++i) {
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) { return convert(s); }
};
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/pitch_linear_coord.h"
namespace cutlass {
namespace layout {
template <int RowsPerTile, int ColumnsInterleaved>
class ColumnMajorTileInterleave {
 public:
  static constexpr int kRowsPerTile = RowsPerTile;
static constexpr int kColumnsInterleaved = ColumnsInterleaved;
};
template <class T>
struct IsColumnMajorTileInterleave {
static constexpr bool value = false;
};
template <int U, int V>
struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>> {
static constexpr bool value = true;
};
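// Illustrative use of the trait above (not part of the original header): the
// detection is purely compile-time, so it is typically consumed in static_asserts
// or enable_if conditions. `NotTileInterleavedExample` is a hypothetical stand-in
// for an ordinary layout type.
struct NotTileInterleavedExample {};
static_assert(
    IsColumnMajorTileInterleave<ColumnMajorTileInterleave<4, 2>>::value,
    "tile-interleaved layouts are detected by the trait");
static_assert(!IsColumnMajorTileInterleave<NotTileInterleavedExample>::value,
              "other layout types are not");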
} // namespace layout
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h"
namespace phi {
struct TileShape {
int m;
int n;
};
static TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) {
switch (tile_config) {
case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
return TileShape{32, 128};
case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64:
case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
return TileShape{64, 128};
case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8:
case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64:
case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
return TileShape{128, 128};
    // TODO(wangbojun) check the tile shape here;
    // {256, 128} has better performance than {128, 128}
case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64:
return TileShape{256, 128};
// TODO(wangbojun) CtaShape256x128x64_WarpShape64x64x64 is not a
case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64:
return TileShape{256, 128};
default:
throw std::runtime_error(
"[FT Error][get_grid_shape_for_config] Invalid config");
}
}
static bool is_valid_split_k_factor(const int64_t m,
const int64_t n,
const int64_t k,
const TileShape tile_shape,
const int split_k_factor,
const size_t workspace_bytes,
const bool is_weight_only) {
// All tile sizes have a k_tile of 64.
static constexpr int k_tile = 64;
  // For weight-only quant, we need k and k_elements_per_split to be a multiple
  // of the CTA k-tile (k_tile).
if (is_weight_only) {
if ((k % k_tile) != 0) {
return false;
}
if ((k % split_k_factor) != 0) {
return false;
}
const int k_elements_per_split = k / split_k_factor;
if ((k_elements_per_split % k_tile) != 0) {
return false;
}
}
// Check that the workspace has sufficient space for this split-k factor
const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
const int required_ws_bytes =
split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
if (required_ws_bytes > workspace_bytes) {
return false;
}
return true;
}
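// Worked example for the workspace check above (illustrative numbers): m = 512
// and n = 4096 with a 128x128 tile give 4 * 32 = 128 CTA tiles, so serial split-k
// with any factor > 1 needs 128 * sizeof(int) = 512 bytes for the inter-CTA
// counters.
static_assert(((512 + 127) / 128) * ((4096 + 127) / 128) * sizeof(int) == 512,
              "workspace bytes for the 512 x 4096 example");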
static std::vector<CutlassTileConfig> get_candidate_tiles(
const bool is_weight_only,
const bool is_weight_only_encoder,
const bool simt_configs_only) {
std::vector<CutlassTileConfig> simt_configs{
CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8};
std::vector<CutlassTileConfig> square_configs{
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64,
};
std::vector<CutlassTileConfig> quant_B_configs{
CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64,
CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64,
CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64,
};
std::vector<CutlassTileConfig> encoder_quant_B_configs{
CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64
// CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64
};
const std::vector<CutlassTileConfig> allowed_quant_B_configs =
is_weight_only_encoder ? encoder_quant_B_configs : quant_B_configs;
const std::vector<CutlassTileConfig> allowed_configs =
is_weight_only ? allowed_quant_B_configs : square_configs;
return simt_configs_only ? simt_configs : allowed_configs;
}
static std::vector<CutlassGemmConfig> get_candidate_configs(
int sm,
const bool is_weight_only,
const bool is_weight_only_encoder,
const bool simt_configs_only) {
std::vector<CutlassTileConfig> tiles = get_candidate_tiles(
is_weight_only, is_weight_only_encoder, simt_configs_only);
std::vector<CutlassGemmConfig> candidate_configs;
const int min_stages = 2;
const int max_stages = sm >= 80 ? 4 : 2;
for (const auto& tile_config : tiles) {
for (int stages = min_stages; stages <= max_stages; ++stages) {
CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
candidate_configs.push_back(config);
}
}
return candidate_configs;
}
static CutlassGemmConfig estimate_best_config_from_occupancies(
const std::vector<CutlassGemmConfig>& candidate_configs,
const std::vector<int>& occupancies,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t num_experts,
const int split_k_limit,
const size_t workspace_bytes,
const int multi_processor_count,
const int is_weight_only) {
if (occupancies.size() != candidate_configs.size()) {
throw std::runtime_error(
"[FT Error][estimate_best_config_from_occupancies] occpancies and "
"candidate configs vectors must have equal length.");
}
CutlassGemmConfig best_config;
// Score will be [0, 1]. The objective is to minimize this score.
// It represents the fraction of SM resources unused in the last wave.
float config_score = 1.0f;
int config_waves = INT_MAX;
int current_m_tile = 0;
const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
for (int ii = 0; ii < candidate_configs.size(); ++ii) {
CutlassGemmConfig candidate_config = candidate_configs[ii];
TileShape tile_shape =
get_cta_shape_for_config(candidate_config.tile_config);
int occupancy = occupancies[ii];
if (occupancy == 0) {
continue;
}
// Keep small tile sizes when possible.
if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic &&
m < current_m_tile && current_m_tile < tile_shape.m) {
continue;
}
const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
for (int split_k_factor = 1; split_k_factor <= max_split_k;
++split_k_factor) {
if (is_valid_split_k_factor(m,
n,
k,
tile_shape,
split_k_factor,
workspace_bytes,
is_weight_only)) {
const int ctas_per_wave = occupancy * multi_processor_count;
const int ctas_for_problem =
ctas_in_m_dim * ctas_in_n_dim * split_k_factor;
const int num_waves_total =
(ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
const float num_waves_fractional =
ctas_for_problem / static_cast<float>(ctas_per_wave);
const float current_score =
static_cast<float>(num_waves_total) - num_waves_fractional;
const float score_slack = 0.1f;
if (current_score < config_score ||
((config_waves > num_waves_total) &&
(current_score < config_score + score_slack))) {
config_score = current_score;
config_waves = num_waves_total;
SplitKStyle split_style = split_k_factor > 1
? SplitKStyle::SPLIT_K_SERIAL
: SplitKStyle::NO_SPLIT_K;
best_config = CutlassGemmConfig{candidate_config.tile_config,
split_style,
split_k_factor,
candidate_config.stages};
current_m_tile = tile_shape.m;
} else if (current_score == config_score &&
(best_config.stages < candidate_config.stages ||
split_k_factor < best_config.split_k_factor ||
current_m_tile < tile_shape.m)) {
// Prefer deeper pipeline or smaller split-k
SplitKStyle split_style = split_k_factor > 1
? SplitKStyle::SPLIT_K_SERIAL
: SplitKStyle::NO_SPLIT_K;
best_config = CutlassGemmConfig{candidate_config.tile_config,
split_style,
split_k_factor,
candidate_config.stages};
current_m_tile = tile_shape.m;
config_waves = num_waves_total;
}
}
}
}
if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) {
throw std::runtime_error(
"[FT Error] Heurisitc failed to find a valid config.");
}
return best_config;
}
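// Worked example of the scoring above (illustrative numbers, not measured on
// any particular device): with m = n = 4096, a 128x128 tile, occupancy 2 and
// 108 SMs, ctas_for_problem = 32 * 32 * 1 = 1024 and
// ctas_per_wave = 2 * 108 = 216, so num_waves_total = ceil(1024 / 216) = 5 and
// num_waves_fractional ~= 4.74. The score 5 - 4.74 = 0.26 is the fraction of
// the final wave left idle; a config with a smaller remainder (or fewer waves
// at a similar score) wins the comparison.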
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h"
// #include "src/fastertransformer/utils/allocator.h"
#include "cuda_runtime_api.h" // NOLINT
namespace phi {
/*
This runner only supports:
T in {half, __nv_bfloat16} WeightType in {uint8_t, cutlass::uint4b_t}
Activations, biases, scales and outputs are all assumed to be row-major.
However, it is assumed that B is in a special format governed by
cutlass_extensions/gemm/kernel/mixed_gemm_B_layout. In this case, B must be
preprocessed using the cutlass weight only quant preprocessors. The weight
preprocessor will instantiate the layout and preprocess based on the
instantiation, so layout changes should only require modifications to
mixed_gemm_B_layout.h.
*/
template <typename T, typename WeightType>
class CutlassFpAIntBGemmRunner {
public:
CutlassFpAIntBGemmRunner();
~CutlassFpAIntBGemmRunner();
void gemm(const T* A,
const WeightType* B,
const float* weight_scales,
T* C,
int m,
int n,
int k,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream);
void gemm_bias_act(const T* A,
const WeightType* B,
const float* weight_scales,
const T* biases,
T* C,
int m,
int n,
int k,
std::string activation_type,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream);
// Returns desired workspace size in bytes.
int getWorkspaceSize(const int m, const int n, const int k);
private:
template <typename EpilogueTag>
void dispatch_to_arch(const T* A,
const WeightType* B,
const float* weight_scales,
const T* biases,
T* C,
int m,
int n,
int k,
CutlassGemmConfig gemm_config,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream,
int* occupancy = nullptr);
template <typename EpilogueTag>
void run_gemm(const T* A,
const WeightType* B,
const float* weight_scales,
const T* biases,
T* C,
int m,
int n,
int k,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream);
private:
static constexpr int split_k_limit = 7;
int sm_;
int multi_processor_count_;
};
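// Minimal usage sketch (illustrative only; the surrounding buffers and the
// weight preprocessing step are assumptions, not part of this header):
//
//   CutlassFpAIntBGemmRunner<half, uint8_t> runner;
//   int ws_bytes = runner.getWorkspaceSize(m, n, k);
//   // workspace: device buffer of at least ws_bytes bytes.
//   // B: int8 weights already preprocessed into the mixed_gemm_B_layout
//   //    format; scales: per-channel fp32 scales of length n.
//   runner.gemm(A, B, scales, C, m, n, k, workspace, ws_bytes, stream);
//
// gemm_bias_act additionally takes a bias pointer and one of the activation
// strings handled below ("gelu", "relu", "none").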
// This specialization exists only so this file compiles alongside the other
// FT-derived structures. The runner assumes the weight type and the activation
// type are different, so an empty object may be created, but any call to gemm
// or gemm_bias_act will throw an error.
template <typename WeightType>
class CutlassFpAIntBGemmRunner<float, WeightType> {
public:
CutlassFpAIntBGemmRunner() = default;
~CutlassFpAIntBGemmRunner() = default;
void gemm(const float* A,
const WeightType* B,
const float* weight_scales,
float* C,
int m,
int n,
int k,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream);
void gemm_bias_act(const float* A,
const WeightType* B,
const float* weight_scales,
const float* biases,
float* C,
int m,
int n,
int k,
std::string activation_type,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream);
int getWorkspaceSize(const int m, const int n, const int k);
};
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma once
#include "cutlass/gemm/device/gemm_universal_base.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue_helpers.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma.h"
#pragma GCC diagnostic pop
#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
#include "paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h"
namespace phi {
template <typename T,
typename WeightType,
typename arch,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape,
int Stages>
void generic_mixed_gemm_kernelLauncher(const T* A,
const WeightType* B,
const float* weight_scales,
const T* biases,
T* C,
int m,
int n,
int k,
CutlassGemmConfig gemm_config,
char* workspace,
size_t workspace_bytes,
cudaStream_t stream,
int* occupancy = nullptr) {
static_assert(cutlass::platform::is_same<T, half>::value ||
#ifdef PADDLE_CUDA_BF16
cutlass::platform::is_same<T, __nv_bfloat16>::value ||
#endif
cutlass::platform::is_same<T, float>::value,
"Specialized for bfloat16, half, float");
static_assert(
cutlass::platform::is_same<T, WeightType>::value ||
cutlass::platform::is_same<WeightType, uint8_t>::value ||
cutlass::platform::is_same<WeightType, cutlass::uint4b_t>::value,
"");
// The cutlass type for the input elements. This is needed to convert to
// cutlass::half_t if necessary.
using ElementType_ = typename cutlass::platform::conditional<
cutlass::platform::is_same<T, half>::value,
cutlass::half_t,
T>::type;
using ElementType = ElementType_;
using CutlassWeightType_ = typename cutlass::platform::conditional<
cutlass::platform::is_same<WeightType, half>::value,
cutlass::half_t,
WeightType>::type;
using CutlassWeightType = CutlassWeightType_;
// We need separate config for each architecture since we will target
// different tensorcore instructions. For float, we do not target TCs.
using MixedGemmArchTraits = cutlass::gemm::kernel::
MixedGemmArchTraits<ElementType, CutlassWeightType, arch>;
using ElementAccumulator = typename MixedGemmArchTraits::AccType;
using EpilogueOp = typename Epilogue<ElementType,
MixedGemmArchTraits::ElementsPerAccessC,
ElementAccumulator,
EpilogueTag>::Op;
using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm<
ElementType,
cutlass::layout::RowMajor,
MixedGemmArchTraits::ElementsPerAccessA,
CutlassWeightType,
typename MixedGemmArchTraits::LayoutB,
MixedGemmArchTraits::ElementsPerAccessB,
ElementType,
cutlass::layout::RowMajor,
ElementAccumulator,
cutlass::arch::OpClassTensorOp,
arch,
ThreadblockShape,
WarpShape,
typename MixedGemmArchTraits::InstructionShape,
EpilogueOp,
typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
Stages,
true,
typename MixedGemmArchTraits::Operator>::GemmKernel;
using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB<
typename GemmKernel_::Mma,
typename GemmKernel_::Epilogue,
typename GemmKernel_::ThreadblockSwizzle,
arch, // Ensure top level arch is used for dispatch
GemmKernel_::kSplitKSerial>;
if (occupancy != nullptr) {
*occupancy = compute_occupancy_for_kernel<GemmKernel>();
return;
}
using Gemm = cutlass::gemm::device::GemmUniversalBase<GemmKernel>;
const int ldb =
cutlass::platform::is_same<cutlass::layout::RowMajor,
typename MixedGemmArchTraits::LayoutB>::value
? n
: k * GemmKernel::kInterleave;
typename Gemm::Arguments args(
{m, n, k},
{reinterpret_cast<ElementType*>(const_cast<T*>(A)), k},
{reinterpret_cast<CutlassWeightType*>(const_cast<WeightType*>(B)), ldb},
{reinterpret_cast<float*>(const_cast<float*>(weight_scales)), 0},
{reinterpret_cast<ElementType*>(const_cast<T*>(biases)), 0},
{reinterpret_cast<ElementType*>(C), n},
gemm_config.split_k_factor,
{ElementAccumulator(1.f), ElementAccumulator(0.f)});
  // This assertion is enabled because, for the column-interleaved layout, K
  // MUST be a multiple of threadblockK. The reason is that the default
  // pitch-linear iterators are used to walk over the interleaved matrix, and
  // the way masking is handled in them does not map to the interleaved layout.
  // We would need to write our own predicated iterator in order to relax this
  // limitation.
if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) ||
((k / gemm_config.split_k_factor) %
MixedGemmArchTraits::ThreadblockK))) {
throw std::runtime_error(
"Temp assertion: k must be multiple of threadblockK");
}
Gemm gemm;
if (gemm.get_workspace_size(args) > workspace_bytes) {
VLOG(1) << "Requested split-k but workspace size insufficient. Falling "
"back to non-split-k implementation.";
// If requested split-k factor will require more workspace bytes, revert to
// standard gemm.
args.batch_count = 1;
}
auto can_implement = gemm.can_implement(args);
if (can_implement != cutlass::Status::kSuccess) {
std::string err_msg =
"fpA_intB cutlass kernel will fail for params. Error: " +
std::string(cutlassGetStatusString(can_implement));
throw std::runtime_error("[fpA_intB Runner] " + err_msg);
}
auto init_status = gemm.initialize(args, workspace, stream);
if (init_status != cutlass::Status::kSuccess) {
std::string err_msg =
"Failed to initialize cutlass fpA_intB gemm. Error: " +
std::string(cutlassGetStatusString(init_status));
throw std::runtime_error("[fpA_intB Runner] " + err_msg);
}
auto run_status = gemm.run(stream);
if (run_status != cutlass::Status::kSuccess) {
std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " +
std::string(cutlassGetStatusString(run_status));
throw std::runtime_error("[fpA_intB Runner] " + err_msg);
}
}
template <typename T,
typename WeightType,
typename arch,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape,
int Stages,
typename Enable = void>
struct dispatch_stages {
static void dispatch(const T* A,
const WeightType* B,
const float* weight_scales,
const T* biases,
T* C,
int m,
int n,
int k,
CutlassGemmConfig gemm_config,
char* workspace,
size_t workspace_bytes,
cudaStream_t stream,
int* occupancy = nullptr) {
std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " +
std::to_string(arch::kMinComputeCapability) +
" with stages set to " + std::to_string(Stages);
throw std::runtime_error("[dispatch_stages::dispatch] " + err_msg);
}
};
template <typename T,
typename WeightType,
typename arch,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape>
struct dispatch_stages<T,
WeightType,
arch,
EpilogueTag,
ThreadblockShape,
WarpShape,
2> {
static void dispatch(const T* A,
const WeightType* B,
const float* weight_scales,
const T* biases,
T* C,
int m,
int n,
int k,
CutlassGemmConfig gemm_config,
char* workspace,
size_t workspace_bytes,
cudaStream_t stream,
int* occupancy = nullptr) {
generic_mixed_gemm_kernelLauncher<T,
WeightType,
arch,
EpilogueTag,
ThreadblockShape,
WarpShape,
2>(A,
B,
weight_scales,
biases,
C,
m,
n,
k,
gemm_config,
workspace,
workspace_bytes,
stream,
occupancy);
}
};
template <typename T,
typename WeightType,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape,
int Stages>
struct dispatch_stages<T,
WeightType,
cutlass::arch::Sm80,
EpilogueTag,
ThreadblockShape,
WarpShape,
Stages,
typename std::enable_if<(Stages > 2)>::type> {
static void dispatch(const T* A,
const WeightType* B,
const float* weight_scales,
const T* biases,
T* C,
int m,
int n,
int k,
CutlassGemmConfig gemm_config,
char* workspace,
size_t workspace_bytes,
cudaStream_t stream,
int* occupancy = nullptr) {
generic_mixed_gemm_kernelLauncher<T,
WeightType,
cutlass::arch::Sm80,
EpilogueTag,
ThreadblockShape,
WarpShape,
Stages>(A,
B,
weight_scales,
biases,
C,
m,
n,
k,
gemm_config,
workspace,
workspace_bytes,
stream,
occupancy);
}
};
template <typename T,
typename WeightType,
typename arch,
typename EpilogueTag,
typename ThreadblockShape,
typename WarpShape>
void dispatch_gemm_config(const T* A,
const WeightType* B,
const float* weight_scales,
const T* biases,
T* C,
int m,
int n,
int k,
CutlassGemmConfig gemm_config,
char* workspace,
size_t workspace_bytes,
cudaStream_t stream,
int* occupancy = nullptr) {
switch (gemm_config.stages) {
case 2:
using DispatcherStages2 = dispatch_stages<T,
WeightType,
arch,
EpilogueTag,
ThreadblockShape,
WarpShape,
2>;
DispatcherStages2::dispatch(A,
B,
weight_scales,
biases,
C,
m,
n,
k,
gemm_config,
workspace,
workspace_bytes,
stream,
occupancy);
break;
case 3:
using DispatcherStages3 = dispatch_stages<T,
WeightType,
arch,
EpilogueTag,
ThreadblockShape,
WarpShape,
3>;
DispatcherStages3::dispatch(A,
B,
weight_scales,
biases,
C,
m,
n,
k,
gemm_config,
workspace,
workspace_bytes,
stream,
occupancy);
break;
case 4:
using DispatcherStages4 = dispatch_stages<T,
WeightType,
arch,
EpilogueTag,
ThreadblockShape,
WarpShape,
4>;
DispatcherStages4::dispatch(A,
B,
weight_scales,
biases,
C,
m,
n,
k,
gemm_config,
workspace,
workspace_bytes,
stream,
occupancy);
break;
default:
std::string err_msg = "dispatch_gemm_config does not support stages " +
std::to_string(gemm_config.stages);
throw std::runtime_error("[dispatch_gemm_config] " + err_msg);
break;
}
}
template <typename T, typename WeightType, typename arch, typename EpilogueTag>
void dispatch_gemm_to_cutlass(const T* A,
const WeightType* B,
const float* weight_scales,
const T* biases,
T* C,
int m,
int n,
int k,
char* workspace,
size_t workspace_bytes,
CutlassGemmConfig gemm_config,
cudaStream_t stream,
int* occupancy = nullptr) {
// Note that SIMT configs are omitted here since they are not supported for
// fpA_intB. We also only instantiate configs here where threadblockShapeM ==
// warpShapeM since those usually perform the best for mixed type gemms.
switch (gemm_config.tile_config) {
case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatch_gemm_config<T,
WeightType,
arch,
EpilogueTag,
cutlass::gemm::GemmShape<32, 128, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(
A,
B,
weight_scales,
biases,
C,
m,
n,
k,
gemm_config,
workspace,
workspace_bytes,
stream,
occupancy);
break;
case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
dispatch_gemm_config<T,
WeightType,
arch,
EpilogueTag,
cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>>(
A,
B,
weight_scales,
biases,
C,
m,
n,
k,
gemm_config,
workspace,
workspace_bytes,
stream,
occupancy);
break;
case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
dispatch_gemm_config<T,
WeightType,
arch,
EpilogueTag,
cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<128, 32, 64>>(
A,
B,
weight_scales,
biases,
C,
m,
n,
k,
gemm_config,
workspace,
workspace_bytes,
stream,
occupancy);
break;
// config for M_16000_N_12288_K_6144 in encoder
case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64:
dispatch_gemm_config<T,
WeightType,
arch,
EpilogueTag,
cutlass::gemm::GemmShape<256, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 64>>(
A,
B,
weight_scales,
biases,
C,
m,
n,
k,
gemm_config,
workspace,
workspace_bytes,
stream,
occupancy);
break;
case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64:
dispatch_gemm_config<T,
WeightType,
arch,
EpilogueTag,
cutlass::gemm::GemmShape<128, 256, 64>,
cutlass::gemm::GemmShape<64, 64, 64>>(
A,
B,
weight_scales,
biases,
C,
m,
n,
k,
gemm_config,
workspace,
workspace_bytes,
stream,
occupancy);
break;
case CutlassTileConfig::Undefined:
throw std::runtime_error(
"[fpA_intB][dispatch_gemm_to_cutlass] gemm config "
"undefined.");
break;
case CutlassTileConfig::ChooseWithHeuristic:
throw std::runtime_error(
"[fpA_intB][dispatch_gemm_to_cutlass] gemm config should "
"have already been set by heuristic.");
break;
default:
throw std::runtime_error(
"[fpA_intB][dispatch_gemm_to_cutlass] Config is invalid "
"for mixed type GEMM.");
break;
}
}
template <typename T, typename WeightType>
CutlassFpAIntBGemmRunner<T, WeightType>::CutlassFpAIntBGemmRunner() {
int device{-1};
check_cuda_error(cudaGetDevice(&device));
sm_ = getSMVersion();
check_cuda_error(cudaDeviceGetAttribute(
&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
}
template <typename T, typename WeightType>
CutlassFpAIntBGemmRunner<T, WeightType>::~CutlassFpAIntBGemmRunner() {}
template <typename T, typename WeightType>
template <typename EpilogueTag>
void CutlassFpAIntBGemmRunner<T, WeightType>::dispatch_to_arch<EpilogueTag>(
const T* A,
const WeightType* B,
const float* weight_scales,
const T* biases,
T* C,
int m,
int n,
int k,
CutlassGemmConfig gemm_config,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream,
int* occupancy) {
// if (sm_ >= 70 && sm_ < 75) {
// dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm70,
// EpilogueTag>(
// A, B, weight_scales, biases, C, m, n, k, workspace_ptr,
// workspace_bytes, gemm_config, stream, occupancy);
// }
// else if (sm_ >= 75 && sm_ < 80) {
// dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm75,
// EpilogueTag>(
// A, B, weight_scales, biases, C, m, n, k, workspace_ptr,
// workspace_bytes, gemm_config, stream, occupancy);
// }
// else
if (sm_ >= 80 && sm_ < 90) {
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm80, EpilogueTag>(
A,
B,
weight_scales,
biases,
C,
m,
n,
k,
workspace_ptr,
workspace_bytes,
gemm_config,
stream,
occupancy);
} else {
throw std::runtime_error(
"[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported "
"for CUTLASS mixed type GEMM");
}
}
template <typename T, typename WeightType>
template <typename EpilogueTag>
void CutlassFpAIntBGemmRunner<T, WeightType>::run_gemm<EpilogueTag>(
const T* A,
const WeightType* B,
const float* weight_scales,
const T* biases,
T* C,
int m,
int n,
int k,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream) {
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
  const bool is_weight_only_encoder = m >= 512;
std::vector<CutlassGemmConfig> candidate_configs =
get_candidate_configs(sm_, is_weight_only, is_weight_only_encoder, false);
std::vector<int> occupancies(candidate_configs.size());
for (size_t ii = 0; ii < candidate_configs.size(); ++ii) {
dispatch_to_arch<EpilogueTag>(A,
B,
weight_scales,
biases,
C,
m,
n,
k,
candidate_configs[ii],
workspace_ptr,
workspace_bytes,
stream,
&occupancies[ii]);
}
// Standard GEMM, so 1 "expert". We use the same function for MoE and regular
// FFN.
static constexpr int num_experts = 1;
CutlassGemmConfig chosen_config =
estimate_best_config_from_occupancies(candidate_configs,
occupancies,
m,
n,
k,
num_experts,
split_k_limit,
workspace_bytes,
multi_processor_count_,
is_weight_only);
dispatch_to_arch<EpilogueTag>(A,
B,
weight_scales,
biases,
C,
m,
n,
k,
chosen_config,
workspace_ptr,
workspace_bytes,
stream);
}
template <typename T, typename WeightType>
void CutlassFpAIntBGemmRunner<T, WeightType>::gemm_bias_act(
const T* A,
const WeightType* B,
const float* weight_scales,
const T* biases,
T* C,
int m,
int n,
int k,
std::string activation_type,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream) {
if (activation_type == "gelu") {
run_gemm<EpilogueOpBiasFtGelu>(A,
B,
weight_scales,
biases,
C,
m,
n,
k,
workspace_ptr,
workspace_bytes,
stream);
} else if (activation_type == "relu") {
run_gemm<EpilogueOpBiasReLU>(A,
B,
weight_scales,
biases,
C,
m,
n,
k,
workspace_ptr,
workspace_bytes,
stream);
} else if (activation_type == "none") {
run_gemm<EpilogueOpBias>(A,
B,
weight_scales,
biases,
C,
m,
n,
k,
workspace_ptr,
workspace_bytes,
stream);
} else {
throw std::runtime_error(("Invalid activation type."));
}
}
template <typename T, typename WeightType>
void CutlassFpAIntBGemmRunner<T, WeightType>::gemm(const T* A,
const WeightType* B,
const float* weight_scales,
T* C,
int m,
int n,
int k,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream) {
run_gemm<EpilogueOpNoBias>(A,
B,
weight_scales,
nullptr,
C,
m,
n,
k,
workspace_ptr,
workspace_bytes,
stream);
}
template <typename T, typename WeightType>
int CutlassFpAIntBGemmRunner<T, WeightType>::getWorkspaceSize(const int m,
const int n,
const int k) {
  // These are the minimum tile sizes used by any config, which give the
  // largest possible grid and therefore the worst-case number of blocks.
const int max_grid_m = (m + 31) / 32;
const int max_grid_n = (n + 127) / 128;
// We need 4 bytes per block in the worst case. We launch split_k_limit in z
// dim.
return max_grid_m * max_grid_n * split_k_limit * 4;
}
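// Example of the sizing above (assumed shapes, for illustration only): for
// m = 512 and n = 12288 the grid is at most (512 + 31) / 32 = 16 tiles in m
// and (12288 + 127) / 128 = 96 tiles in n, so the returned workspace is
// 16 * 96 * 7 * 4 = 43008 bytes, i.e. one 4-byte counter per block at the
// maximum split-k factor.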
// ========== Specialization for float activations (mixed-type gemm unsupported) ==========
template <typename WeightType>
void CutlassFpAIntBGemmRunner<float, WeightType>::gemm_bias_act(
const float* A,
const WeightType* B,
const float* weight_scales,
const float* biases,
float* C,
int m,
int n,
int k,
std::string activation_type,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream) {
throw std::runtime_error(
("Attempting to run mixed gemm bias act when the types are the same is "
"an error."));
}
template <typename WeightType>
void CutlassFpAIntBGemmRunner<float, WeightType>::gemm(
const float* A,
const WeightType* B,
const float* weight_scales,
float* C,
int m,
int n,
int k,
char* workspace_ptr,
const size_t workspace_bytes,
cudaStream_t stream) {
throw std::runtime_error((
"Attempting to run mixed gemm when the types are the same is an error."));
}
template <typename WeightType>
int CutlassFpAIntBGemmRunner<float, WeightType>::getWorkspaceSize(const int m,
const int n,
const int k) {
return 0;
}
} // namespace phi
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#ifdef SPARSITY_ENABLED
#include <cusparseLt.h>
#endif
namespace phi {
#define MAX_CONFIG_NUM 20
#define COL32_ 32
// workspace for cublas gemm : 32MB
#define CUBLAS_WORKSPACE_SIZE 33554432
typedef struct __align__(4) {
half x, y, z, w;
}
half4;
/* **************************** type definition ***************************** */
enum CublasDataType {
FLOAT_DATATYPE = 0,
HALF_DATATYPE = 1,
BFLOAT16_DATATYPE = 2,
INT8_DATATYPE = 3,
FP8_DATATYPE = 4
};
// enum FtCudaDataType { FP32 = 0, FP16 = 1, BF16 = 2, INT8 = 3, FP8 = 4 };
// enum class OperationType { FP32, FP16, BF16, INT8, FP8 };
/* **************************** debug tools ********************************* */
static const char* _cudaGetErrorEnum(cudaError_t error) {
return cudaGetErrorString(error);
}
static const char* _cudaGetErrorEnum(cublasStatus_t error) {
switch (error) {
case CUBLAS_STATUS_SUCCESS:
return "CUBLAS_STATUS_SUCCESS";
case CUBLAS_STATUS_NOT_INITIALIZED:
return "CUBLAS_STATUS_NOT_INITIALIZED";
case CUBLAS_STATUS_ALLOC_FAILED:
return "CUBLAS_STATUS_ALLOC_FAILED";
case CUBLAS_STATUS_INVALID_VALUE:
return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH:
return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_MAPPING_ERROR:
return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED:
return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR";
case CUBLAS_STATUS_NOT_SUPPORTED:
return "CUBLAS_STATUS_NOT_SUPPORTED";
case CUBLAS_STATUS_LICENSE_ERROR:
return "CUBLAS_STATUS_LICENSE_ERROR";
}
return "<unknown>";
}
template <typename T>
void check(T result,
char const* const func,
const char* const file,
int const line) {
if (result) {
throw std::runtime_error(std::string("[ERROR] CUDA runtime error: ") +
(_cudaGetErrorEnum(result)) + " " + file + ":" +
std::to_string(line) + " \n");
}
}
#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__)
#define check_cuda_error_2(val, file, line) check((val), #val, file, line)
inline void syncAndCheck(const char* const file, int const line) {
  // When FT_DEBUG_LEVEL=DEBUG, always synchronize and check for errors.
static char* level_name = std::getenv("FT_DEBUG_LEVEL");
if (level_name != nullptr) {
static std::string level = std::string(level_name);
if (level == "DEBUG") {
cudaDeviceSynchronize();
cudaError_t result = cudaGetLastError();
if (result) {
throw std::runtime_error(std::string("[ERROR] CUDA runtime error: ") +
(_cudaGetErrorEnum(result)) + " " + file +
":" + std::to_string(line) + " \n");
}
VLOG(2) << "run syncAndCheck at " << file << ":" << line;
}
}
#ifndef NDEBUG
cudaDeviceSynchronize();
cudaError_t result = cudaGetLastError();
if (result) {
throw std::runtime_error(std::string("[ERROR] CUDA runtime error: ") +
(_cudaGetErrorEnum(result)) + " " + file + ":" +
std::to_string(line) + " \n");
}
#endif
}
#define sync_check_cuda_error() syncAndCheck(__FILE__, __LINE__)
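// Usage note: in release (NDEBUG) builds sync_check_cuda_error() is a no-op
// unless the FT_DEBUG_LEVEL environment variable is set to DEBUG (e.g.
// `FT_DEBUG_LEVEL=DEBUG ./your_binary`, an illustrative invocation), in which
// case every call site synchronizes the device and surfaces the last CUDA
// error immediately, at a significant performance cost. Debug builds always
// synchronize.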
#define checkCUDNN(expression) \
{ \
cudnnStatus_t status = (expression); \
if (status != CUDNN_STATUS_SUCCESS) { \
std::cerr << "Error on file " << __FILE__ << " line " << __LINE__ \
<< ": " << cudnnGetErrorString(status); \
std::exit(EXIT_FAILURE); \
} \
}
template <typename T>
void print_to_file(const T* result,
const int size,
const char* file,
cudaStream_t stream = 0,
std::ios::openmode open_mode = std::ios::out);
template <typename T>
void print_abs_mean(const T* buf,
uint size,
cudaStream_t stream,
std::string name = "");
template <typename T>
void print_to_screen(const T* result, const int size);
template <typename T>
void check_max_val(const T* result, const int size);
template <typename T>
void check_abs_mean_val(const T* result, const int size);
#define PRINT_FUNC_NAME_() \
do { \
VLOG(2) << "[CALL] " << __FUNCTION__ << " "; \
} while (0)
[[noreturn]] inline void throwRuntimeError(const char* const file,
int const line,
std::string const& info = "") {
throw std::runtime_error(std::string("[ERROR] ") + info +
" Assertion fail: " + file + ":" +
std::to_string(line) + " \n");
}
inline void myAssert(bool result,
const char* const file,
int const line,
std::string const& info = "") {
if (!result) {
throwRuntimeError(file, line, info);
}
}
#define FT_CHECK(val) myAssert(val, __FILE__, __LINE__)
#define FT_CHECK_WITH_INFO(val, info) \
do { \
bool is_valid_val = (val); \
if (!is_valid_val) { \
      phi::myAssert(is_valid_val, __FILE__, __LINE__, (info));              \
} \
} while (0)
#define FT_THROW(info) throwRuntimeError(__FILE__, __LINE__, info)
#ifdef SPARSITY_ENABLED
#define CHECK_CUSPARSE(func) \
{ \
cusparseStatus_t status = (func); \
if (status != CUSPARSE_STATUS_SUCCESS) { \
throw std::runtime_error( \
std::string("[ERROR] CUSPARSE API failed at line ") + \
std::to_string(__LINE__) + " in file " + __FILE__ + ": " + \
cusparseGetErrorString(status) + " " + std::to_string(status)); \
} \
}
#endif
/*************Time Handling**************/
class CudaTimer {
private:
cudaEvent_t event_start_;
cudaEvent_t event_stop_;
cudaStream_t stream_;
public:
explicit CudaTimer(cudaStream_t stream = 0) { stream_ = stream; }
void start() {
check_cuda_error(cudaEventCreate(&event_start_));
check_cuda_error(cudaEventCreate(&event_stop_));
check_cuda_error(cudaEventRecord(event_start_, stream_));
}
float stop() {
float time;
check_cuda_error(cudaEventRecord(event_stop_, stream_));
check_cuda_error(cudaEventSynchronize(event_stop_));
check_cuda_error(cudaEventElapsedTime(&time, event_start_, event_stop_));
check_cuda_error(cudaEventDestroy(event_start_));
check_cuda_error(cudaEventDestroy(event_stop_));
return time;
}
~CudaTimer() {}
};
static double diffTime(timeval start, timeval end) {
return (end.tv_sec - start.tv_sec) * 1000 +
(end.tv_usec - start.tv_usec) * 0.001;
}
/* ***************************** common utils ****************************** */
inline void print_mem_usage(std::string time = "after allocation") {
size_t free_bytes, total_bytes;
check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes));
float free = static_cast<float>(free_bytes) / 1024.0 / 1024.0 / 1024.0;
float total = static_cast<float>(total_bytes) / 1024.0 / 1024.0 / 1024.0;
float used = total - free;
printf("%-20s: free: %5.2f GB, total: %5.2f GB, used: %5.2f GB\n",
time.c_str(),
free,
total,
used);
}
inline int getSMVersion() {
int device{-1};
check_cuda_error(cudaGetDevice(&device));
int sm_major = 0;
int sm_minor = 0;
check_cuda_error(cudaDeviceGetAttribute(
&sm_major, cudaDevAttrComputeCapabilityMajor, device));
check_cuda_error(cudaDeviceGetAttribute(
&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
return sm_major * 10 + sm_minor;
}
inline int getMaxSharedMemoryPerBlock() {
int device{-1};
check_cuda_error(cudaGetDevice(&device));
int max_shared_memory_size = 0;
check_cuda_error(cudaDeviceGetAttribute(
&max_shared_memory_size, cudaDevAttrMaxSharedMemoryPerBlock, device));
return max_shared_memory_size;
}
inline std::string getDeviceName() {
int device{-1};
check_cuda_error(cudaGetDevice(&device));
cudaDeviceProp props;
check_cuda_error(cudaGetDeviceProperties(&props, device));
return std::string(props.name);
}
inline int div_up(int a, int n) { return (a + n - 1) / n; }
cudaError_t getSetDevice(int i_device, int* o_device = NULL);
inline int getDevice() {
int current_dev_id = 0;
check_cuda_error(cudaGetDevice(&current_dev_id));
return current_dev_id;
}
inline int getDeviceCount() {
int count = 0;
check_cuda_error(cudaGetDeviceCount(&count));
return count;
}
template <typename T>
CublasDataType getCublasDataType() {
if (std::is_same<T, half>::value) {
return HALF_DATATYPE;
} else if (std::is_same<T, float>::value) {
return FLOAT_DATATYPE;
} else {
FT_CHECK(false);
return FLOAT_DATATYPE;
}
}
template <typename T>
cudaDataType_t getCudaDataType() {
if (std::is_same<T, half>::value) {
return CUDA_R_16F;
} else if (std::is_same<T, float>::value) {
return CUDA_R_32F;
} else {
FT_CHECK(false);
return CUDA_R_32F;
}
}
template <CublasDataType T>
struct getTypeFromCudaDataType {
using Type = float;
};
template <>
struct getTypeFromCudaDataType<HALF_DATATYPE> {
using Type = half;
};
template <typename T>
struct packed_type;
template <>
struct packed_type<float> {
using type = float;
};
template <>
struct packed_type<half> {
using type = half2;
};
template <typename T>
struct num_elems;
template <>
struct num_elems<float> {
static constexpr int value = 1;
};
template <>
struct num_elems<float2> {
static constexpr int value = 2;
};
template <>
struct num_elems<float4> {
static constexpr int value = 4;
};
template <>
struct num_elems<half> {
static constexpr int value = 1;
};
template <>
struct num_elems<half2> {
static constexpr int value = 2;
};
template <typename T, int num>
struct packed_as;
template <typename T>
struct packed_as<T, 1> {
using type = T;
};
template <>
struct packed_as<half, 2> {
using type = half2;
};
template <>
struct packed_as<float, 2> {
using type = float2;
};
template <>
struct packed_as<int8_t, 2> {
using type = int16_t;
};
template <>
struct packed_as<int32_t, 2> {
using type = int2;
};
template <>
struct packed_as<half2, 1> {
using type = half;
};
inline __device__ float2 operator*(float2 a, float2 b) {
return make_float2(a.x * b.x, a.y * b.y);
}
inline __device__ float2 operator*(float2 a, float b) {
return make_float2(a.x * b, a.y * b);
}
template <typename T1, typename T2>
void compareTwoTensor(const T1* pred,
const T2* ref,
const int size,
const int print_size = 0,
const std::string filename = "") {
T1* h_pred = new T1[size];
T2* h_ref = new T2[size];
check_cuda_error(
cudaMemcpy(h_pred, pred, size * sizeof(T1), cudaMemcpyDeviceToHost));
check_cuda_error(
cudaMemcpy(h_ref, ref, size * sizeof(T2), cudaMemcpyDeviceToHost));
FILE* fd = nullptr;
if (filename != "") {
fd = fopen(filename.c_str(), "w");
fprintf(fd,
"| %10s | %10s | %10s | %10s | \n",
"pred",
"ref",
"abs_diff",
"rel_diff(%)");
}
if (print_size > 0) {
VLOG(2) << " id | pred | ref |abs diff | rel diff (%) |";
}
float mean_abs_diff = 0.0f;
float mean_rel_diff = 0.0f;
int count = 0;
for (int i = 0; i < size; i++) {
if (i < print_size) {
VLOG(2) << i << " | " << static_cast<float>(h_pred[i]) << " | "
<< static_cast<float>(h_ref[i]) << " | "
<< (abs(static_cast<float>(h_pred[i]) -
static_cast<float>(h_ref[i])))
<< " | "
<< (abs(static_cast<float>(h_pred[i]) -
static_cast<float>(h_ref[i])) /
(abs(static_cast<float>(h_ref[i])) + 1e-6f) * 100.f)
<< " | ";
}
if (static_cast<float>(h_pred[i]) == 0) {
continue;
}
count += 1;
mean_abs_diff +=
abs(static_cast<float>(h_pred[i]) - static_cast<float>(h_ref[i]));
mean_rel_diff +=
abs(static_cast<float>(h_pred[i]) - static_cast<float>(h_ref[i])) /
(abs(static_cast<float>(h_ref[i])) + 1e-6f) * 100.f;
if (fd != nullptr) {
fprintf(
fd,
"| %10.5f | %10.5f | %10.5f | %11.5f |\n",
static_cast<float>(h_pred[i]),
static_cast<float>(h_ref[i]),
abs(static_cast<float>(h_pred[i]) - static_cast<float>(h_ref[i])),
abs(static_cast<float>(h_pred[i]) - static_cast<float>(h_ref[i])) /
(abs(static_cast<float>(h_ref[i])) + 1e-6f) * 100.f);
}
}
mean_abs_diff = mean_abs_diff / static_cast<float>(count);
mean_rel_diff = mean_rel_diff / static_cast<float>(count);
VLOG(2) << "mean_abs_diff: " << mean_abs_diff
<< ", mean_rel_diff: " << mean_rel_diff;
if (fd != nullptr) {
fprintf(fd,
"mean_abs_diff: % 6.4f, mean_rel_diff: % 6.4f (%%)",
mean_abs_diff,
mean_rel_diff);
fclose(fd);
}
delete[] h_pred;
delete[] h_ref;
}
/* ************************** end of common utils ************************** */
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/llm_int8_mat_mul_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#ifndef PADDLE_WITH_HIP
#include "paddle/phi/kernels/impl/llm_int8_mat_mul_kernel_impl.h"
#endif
namespace phi {
template <typename T, typename Context>
void llm_int8_compute(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const DenseTensor& weight_scale,
const float threshold,
DenseTensor* out) {
#if defined(PADDLE_WITH_HIP)
LOG(ERROR) << "Please compile with cublaslt, ROCM platform isn't support it";
#else
DenseTensor cublaslt_workspace;
cublaslt_workspace.Resize({{3000000}});
dev_ctx.template Alloc<int8_t>(&cublaslt_workspace);
const auto x_dims = x.dims();
const auto w_dims = weight.dims();
int k = w_dims[1];
int n = w_dims[0];
int m = x.numel() / k;
// mk * transpose(nk) = mn
llm_int8::LLMGemm<T>(dev_ctx,
&weight,
&x,
&weight_scale,
threshold,
out,
&cublaslt_workspace,
"llm_int8_mat_mul",
m,
k,
n);
#endif
}
template <typename T, typename Context>
void LLMInt8MatMulKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const DenseTensor& weight_scale,
const float threshold,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
llm_int8_compute<T, Context>(
dev_ctx, x, weight, weight_scale, threshold, out);
}
} // namespace phi
PD_REGISTER_KERNEL(llm_int8_mat_mul,
GPU,
ALL_LAYOUT,
phi::LLMInt8MatMulKernel,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/weight_only_mat_mul_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/datatype_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#if defined(PADDLE_WITH_CUTLASS)
#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h"
#endif
namespace phi {
template <typename T, typename Context>
void WeightOnlyMatMulKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const DenseTensor& weight_scale,
DenseTensor* out) {
#if defined(PADDLE_WITH_CUTLASS)
dev_ctx.template Alloc<T>(out);
const auto x_dims = x.dims();
const auto w_dims = weight.dims();
int k = w_dims[0];
int n = w_dims[1];
int m = x.numel() / k;
auto mixed_gemm_runner =
CutlassFpAIntBGemmRunner<typename PDDataTypeTraits<T>::DataType,
uint8_t>();
int mixgemm_max_size = std::max(n, k);
DenseTensor mixgemm_workspace;
int64_t mixgemm_workspace_size_bytes =
mixed_gemm_runner.getWorkspaceSize(m, mixgemm_max_size, mixgemm_max_size);
mixgemm_workspace.Resize({mixgemm_workspace_size_bytes});
dev_ctx.template Alloc<uint8_t>(&mixgemm_workspace);
char* mixgemm_workspace_data =
reinterpret_cast<char*>(mixgemm_workspace.data<uint8_t>());
mixed_gemm_runner.gemm(
reinterpret_cast<const typename PDDataTypeTraits<T>::DataType*>(
x.data<T>()),
reinterpret_cast<const uint8_t*>(weight.data<int8_t>()),
reinterpret_cast<const float*>(weight_scale.data<float>()),
reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(out->data<T>()),
m,
n,
k,
mixgemm_workspace_data,
mixgemm_workspace_size_bytes,
dev_ctx.stream());
#else
LOG(ERROR) << "Please compile with cutlass to EnableUseCutlass()";
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(weight_only_mat_mul,
GPU,
ALL_LAYOUT,
phi::WeightOnlyMatMulKernel,
phi::dtype::float16) {}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <vector>
#include "paddle/phi/common/datatype_traits.h"
#include "paddle/phi/kernels/funcs/cublaslt.h"
#include "paddle/phi/kernels/funcs/quant_dequant.h"
#pragma once
namespace phi {
namespace llm_int8 {
constexpr int32_t WARP_SIZE = 32;
constexpr int32_t HALF_WARP = 16;
constexpr float QUANT_MAX_BOUND = 127.0;
constexpr float QUANT_MIN_BOUND = -127.0;
constexpr int32_t kBlockSize = 256;
constexpr int32_t kNumWaves = 16;
inline cudaError_t GetGridSize(int64_t n, int* num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
}
int sm_count;
{
cudaError_t err =
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) {
return err;
}
}
int tpm;
{
cudaError_t err = cudaDeviceGetAttribute(
&tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev);
if (err != cudaSuccess) {
return err;
}
}
*num_blocks =
std::max<int>(1,
std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
sm_count * tpm / kBlockSize * kNumWaves));
return cudaSuccess;
}
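// Illustrative sizing (the device numbers are assumptions): with
// kBlockSize = 256, kNumWaves = 16, 108 SMs and 2048 resident threads per SM,
// the cap is 108 * 2048 / 256 * 16 = 13824 blocks, so n = 1,000,000 elements
// yields min(ceil(1e6 / 256), 13824) = 3907 blocks.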
template <class Func>
inline cudaError_t GetMaxOccupancyBlocks(Func func,
int64_t block_size,
size_t dynamic_smem_size,
int64_t max_blocks,
int* num_blocks) {
int dev;
{
cudaError_t err = cudaGetDevice(&dev);
if (err != cudaSuccess) {
return err;
}
}
int sm_count;
{
cudaError_t err =
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
if (err != cudaSuccess) {
return err;
}
}
int max_active_blocks;
{
    cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_active_blocks, func, block_size, dynamic_smem_size);
    if (err != cudaSuccess) {
      return err;
    }
  }
*num_blocks = std::max<int>(
1,
std::min<int64_t>(max_blocks, sm_count * max_active_blocks * kNumWaves));
return cudaSuccess;
}
template <typename T>
struct MaxFunc {
__device__ T operator()(T a, T b) { return max(a, b); }
};
template <>
struct MaxFunc<half> {
__device__ half operator()(half a, half b) {
#if __CUDA_ARCH__ >= 800
return __hmax(a, b);
#else
return max(static_cast<float>(a), static_cast<float>(b));
#endif
}
};
#ifdef PADDLE_CUDA_BF16
template <>
struct MaxFunc<__nv_bfloat16> {
__device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) {
#if __CUDA_ARCH__ >= 800
return __hmax(a, b);
#else
return max(static_cast<float>(a), static_cast<float>(b));
#endif
}
};
#endif
template <typename T>
struct AbsFunc {
__device__ T operator()(T x) { return abs(x); }
};
template <>
struct AbsFunc<half> {
__device__ half operator()(half x) {
#if __CUDA_ARCH__ >= 800
return __habs(x);
#else
return abs(static_cast<float>(x));
#endif
}
};
#ifdef PADDLE_CUDA_BF16
template <>
struct AbsFunc<__nv_bfloat16> {
__device__ __nv_bfloat16 operator()(__nv_bfloat16 x) {
#if __CUDA_ARCH__ >= 800
return __habs(x);
#else
return abs(static_cast<float>(x));
#endif
}
};
#endif
template <typename T>
struct QuantFunc {
HOSTDEVICE int8_t operator()(T x, float inverse_range) {
float tmp = static_cast<float>(x) * QUANT_MAX_BOUND * inverse_range;
tmp = round(tmp);
if (tmp > QUANT_MAX_BOUND)
tmp = QUANT_MAX_BOUND;
else if (tmp < QUANT_MIN_BOUND)
tmp = QUANT_MIN_BOUND;
return static_cast<int8_t>(tmp);
}
};
template <typename T>
struct DequantFunc {
HOSTDEVICE T operator()(int8_t x, T scale) {
return static_cast<T>(static_cast<float>(x) * static_cast<float>(scale));
}
HOSTDEVICE T operator()(int32_t x, T input_range, T weight_scale) {
return static_cast<T>(static_cast<float>(x) *
static_cast<float>(input_range) *
static_cast<float>(weight_scale) / (127.0f));
}
HOSTDEVICE T operator()(int8_t x, float scale) {
return static_cast<T>(static_cast<float>(x) * static_cast<float>(scale));
}
HOSTDEVICE T operator()(int32_t x, float input_range, float weight_scale) {
return static_cast<T>(static_cast<float>(x) *
static_cast<float>(input_range) *
static_cast<float>(weight_scale) / (127.0f));
}
};
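// Round-trip example of the two functors (values chosen for illustration):
// with a row absolute max of 2.0, inverse_range = 0.5, so an activation
// x = 0.5 quantizes to round(0.5 * 127 * 0.5) = 32. After the int8 GEMM,
// DequantFunc multiplies the int32 accumulator by
// input_range * weight_scale / 127, which cancels the 127 / input_range
// factor applied here and the 1 / weight_scale factor applied when the
// weights were quantized.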
template <typename T, typename Vec, int VecSize>
__inline__ __device__ T LocalReduceMax(Vec& vec) { // NOLINT
T local_max = static_cast<T>(0.0);
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
local_max = vec[i] > local_max ? vec[i] : local_max;
}
return local_max;
}
template <typename T>
__inline__ __device__ T WarpReduceAbsMax(T val, unsigned lane_mask) {
#pragma unroll
for (int mask = HALF_WARP; mask > 0; mask >>= 1) {
val = MaxFunc<T>()(val, __shfl_xor_sync(lane_mask, val, mask, WARP_SIZE));
}
return val;
}
template <typename T>
__inline__ __device__ T BlockReduceAbsMax(T val, unsigned mask) {
static __shared__ T smem[WARP_SIZE];
int32_t lane_id = threadIdx.x & 0x1f;
int32_t warp_id = threadIdx.x >> 5;
val = WarpReduceAbsMax(val, mask);
if (lane_id == 0) {
smem[warp_id] = val;
}
__syncthreads();
T abs_max_val = (threadIdx.x < blockDim.x / WARP_SIZE) ? smem[threadIdx.x]
: static_cast<T>(0.0f);
abs_max_val = WarpReduceAbsMax(abs_max_val, mask);
return abs_max_val;
}
template <typename T, typename ComputeType, int VecSize>
__global__ void ReduceAbsMaxKernel(const T* x,
const float threshold,
const int32_t rows,
const int32_t cols,
float* row_ranges,
int32_t* outlier_idx) {
using InVec = phi::AlignedVector<T, VecSize>;
using ComputeVec = phi::AlignedVector<ComputeType, VecSize>;
InVec in_vec;
ComputeVec abs_max_vec;
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
abs_max_vec[i] = 0.0f;
}
ComputeType local_max_val = static_cast<ComputeType>(0.0f);
for (int row_idx = blockIdx.x; row_idx < rows; row_idx += gridDim.x) {
for (int col_idx = threadIdx.x * VecSize; col_idx < cols;
col_idx += blockDim.x * VecSize) {
int32_t linear_index = row_idx * cols + col_idx;
phi::Load<T, VecSize>(x + linear_index, &in_vec);
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
in_vec[i] = AbsFunc<T>()(in_vec[i]);
if (in_vec[i] > static_cast<T>(threshold)) {
int32_t index = col_idx + i;
int32_t int_index = index / 32;
int32_t inner_index = index % 32;
atomicOr(outlier_idx + int_index, (1 << inner_index));
in_vec[i] = 0.0;
}
abs_max_vec[i] = MaxFunc<ComputeType>()(
abs_max_vec[i], static_cast<ComputeType>(in_vec[i]));
}
}
local_max_val =
LocalReduceMax<ComputeType, ComputeVec, VecSize>(abs_max_vec);
ComputeType tmp_max_val =
BlockReduceAbsMax<ComputeType>(local_max_val, 0xffffffff);
if (threadIdx.x == 0) {
row_ranges[row_idx] = tmp_max_val;
}
}
}
template <typename T, int VecSize>
__global__ void QuantActKernel(const T* x,
const int32_t elem_cnt,
const int32_t cols,
const float* row_ranges,
const int32_t* outlier_idx,
int8_t* quant_x) {
using InVec = phi::AlignedVector<T, VecSize>;
using OutVec = phi::AlignedVector<int8_t, VecSize>;
InVec in_vec;
OutVec out_vec;
for (int linear_index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
linear_index < elem_cnt;
linear_index += gridDim.x * blockDim.x * VecSize) {
int row_idx = linear_index / cols;
int col_idx =
linear_index - row_idx * cols; // equal to linear_index % cols
phi::Load<T, VecSize>(x + linear_index, &in_vec);
int32_t local_outlier_idx = outlier_idx[col_idx / 32];
float scale = 1.0f / row_ranges[row_idx];
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
int32_t index = linear_index + i;
if (local_outlier_idx & (1 << (index % 32))) {
out_vec[i] = 0;
} else {
out_vec[i] = QuantFunc<T>()(in_vec[i], scale);
}
}
phi::Store(out_vec, quant_x + linear_index);
}
}
template <typename T, int VecSize>
__global__ void Fill(T* input, T value, int64_t num) {
phi::AlignedVector<T, VecSize> in_vec;
int stride = blockDim.x * gridDim.x * VecSize;
int base_idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
for (int idx = base_idx; idx < num; idx += stride) {
#pragma unroll
for (int j = 0; j < VecSize; ++j) {
in_vec[j] = value;
}
phi::Store(in_vec, input + idx);
}
}
template <typename T>
__global__ void SplitKernel(const T* x,
const int8_t* weight,
const float* weight_scale,
const int32_t* outlier_idx,
T* sub_x,
T* sub_weight,
int m,
int k,
int n,
int num_outlier_idx,
int kfp_num,
int sub_x_elem_cnt,
int sub_w_elem_cnt,
int elem_cnt) {
extern __shared__ int32_t k_ids_shm[];
int32_t cnt = 0;
if (threadIdx.x == 0) {
#pragma unroll
for (int i = 0; i < kfp_num; ++i) {
k_ids_shm[i] = -1;
}
for (int i = 0; i < num_outlier_idx; ++i) {
int32_t outlier_id = outlier_idx[i];
if (outlier_id == 0) continue;
for (int j = 0; j < 32; ++j) {
if (outlier_id & (1 << j)) {
k_ids_shm[cnt++] = i * 32 + j;
}
}
}
}
__syncthreads();
for (int linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
linear_idx < elem_cnt;
linear_idx += blockDim.x * gridDim.x) {
int32_t row_idx = linear_idx / kfp_num; // n
int32_t col_idx = linear_idx % kfp_num; // k
int32_t k_id = k_ids_shm[col_idx];
if (k_id == -1) continue;
if (linear_idx < sub_x_elem_cnt) {
sub_x[row_idx * kfp_num + col_idx] = x[row_idx * k + k_id];
}
if (linear_idx < sub_w_elem_cnt) {
constexpr int32_t k_permute_const = 8;
int32_t k_mod_16 = k_id % 16;
int32_t temp_k_expr_1 = k_mod_16 - k_mod_16 / 8 * 8;
int32_t temp_k_expr_2 = k_mod_16 / 8;
int32_t permute_kk = temp_k_expr_1 + temp_k_expr_2 +
(temp_k_expr_2 + 1) % 2 * k_mod_16 * 2 / 2 +
temp_k_expr_1 * temp_k_expr_2 + k_id / 16 * 16;
int32_t permute_index = permute_kk % 64 + permute_kk / 64 * 128 +
64 * (row_idx % 2) + k * 2 * (row_idx / 2);
int8_t shifted_weight = static_cast<int8_t>(
static_cast<int32_t>(weight[permute_index]) - 128);
sub_weight[row_idx * kfp_num + col_idx] =
DequantFunc<T>()(shifted_weight, weight_scale[row_idx]);
}
}
}
__global__ static void UpdateOutlier(int32_t* outlier_idx, int32_t* total_num) {
  // Each thread owns one 32-bit word of the outlier bitmap; count its set
  // bits (Brian Kernighan's trick) and accumulate the total number of
  // outlier columns.
  int32_t outlier_val = outlier_idx[threadIdx.x];
  while (outlier_val) {
    outlier_val = outlier_val & (outlier_val - 1);
    atomicAdd(total_num, 1);
  }
}
// Input: x:dequantized fp16:[m, n], x_fp:T:[m, n]
// Output: y:T:[m, n]
template <typename T, int VecSize>
__global__ void DequantActivationMergeKernel(const T* x,
const T* x_fp,
T* y,
const int32_t elem_cnt) {
using FpVec = phi::AlignedVector<T, VecSize>;
FpVec x_fp_vec;
FpVec out_vec;
FpVec x_vec;
for (int linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
linear_idx < elem_cnt;
linear_idx += gridDim.x * blockDim.x * VecSize) {
phi::Load(x_fp + linear_idx, &x_fp_vec);
phi::Load(x + linear_idx, &x_vec);
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
out_vec[i] = x_fp_vec[i] + (x_vec[i] / static_cast<T>(127.0f));
}
phi::Store(out_vec, y + linear_idx);
}
}
// Input: x:int32:[m, n], x_fp:T:[m, n], input_range:float:[m],
// weight_scale:float:[n]
// Output: y:T:[m, n]
template <typename T, int VecSize>
__global__ void DequantMergeKernel(const int32_t* x,
const T* x_fp,
const float* input_range,
const float* weight_scale,
T* y,
int m,
int n) {
using FpVec = phi::AlignedVector<T, VecSize>;
using IntVec = phi::AlignedVector<int32_t, VecSize>;
FpVec x_fp_vec;
FpVec out_vec;
IntVec x_vec;
for (int row_idx = blockIdx.x; row_idx < m; row_idx += gridDim.x) {
for (int col_idx = threadIdx.x * VecSize; col_idx < n;
col_idx += blockDim.x * VecSize) {
int linear_idx = row_idx * n + col_idx;
phi::Load(x_fp + linear_idx, &x_fp_vec);
phi::Load(x + linear_idx, &x_vec);
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
T dequant_x_fp = DequantFunc<T>()(
x_vec[i], input_range[row_idx], weight_scale[col_idx + i]);
out_vec[i] = x_fp_vec[i] + dequant_x_fp;
}
phi::Store(out_vec, y + linear_idx);
}
}
}
template <typename T>
void LaunchFillKernel(T* input,
T value,
int64_t num,
backends::gpu::GpuLaunchConfig* gpu_config,
gpuStream_t stream) {
constexpr int VecSize = 16 / sizeof(T);
Fill<T, VecSize>
<<<gpu_config->block_per_grid, gpu_config->thread_per_block, 0, stream>>>(
input, value, num);
}
template <typename T>
void LaunchReduceAbsMaxQuantKernel(const T* x,
const float threshold,
const int32_t rows,
const int32_t cols,
float* row_ranges,
int32_t* outlier_idx,
int8_t* quant_x,
gpuStream_t stream) {
constexpr int VecSize = 16 / sizeof(T);
using DataT = typename PDDataTypeTraits<T>::DataType;
using ComputeType = float;
int32_t reduce_kernel_num_blocks;
PADDLE_ENFORCE_GPU_SUCCESS(
GetMaxOccupancyBlocks(ReduceAbsMaxKernel<DataT, ComputeType, VecSize>,
kBlockSize,
0,
rows,
&reduce_kernel_num_blocks));
assert((cols % VecSize == 0));
ReduceAbsMaxKernel<DataT, ComputeType, VecSize>
<<<reduce_kernel_num_blocks, kBlockSize, 0, stream>>>(
reinterpret_cast<const DataT*>(x),
threshold,
rows,
cols,
row_ranges,
outlier_idx);
const int32_t elem_cnt = rows * cols;
const int32_t vectorized_elem_cnt = elem_cnt / VecSize;
int32_t quant_kernel_num_blocks;
PADDLE_ENFORCE_GPU_SUCCESS(
GetGridSize(vectorized_elem_cnt, &quant_kernel_num_blocks));
QuantActKernel<DataT, VecSize>
<<<quant_kernel_num_blocks, kBlockSize, 0, stream>>>(
reinterpret_cast<const DataT*>(x),
elem_cnt,
cols,
row_ranges,
outlier_idx,
quant_x);
}
template <typename T>
void LaunchSplitKernel(const T* x,
const int8_t* weight,
const float* weight_scale,
const int32_t* outlier_idx,
T* sub_x,
T* sub_weight,
int m,
int k,
int n,
int kfp_num,
gpuStream_t stream) {
int max_row = m > n ? m : n;
const int elem_cnt = max_row * kfp_num;
int num_blocks = 1;
PADDLE_ENFORCE_GPU_SUCCESS(GetGridSize(elem_cnt, &num_blocks));
int64_t num_outlier_idx = (k + 31) / 32;
const int32_t sub_x_elem_cnt = m * kfp_num;
const int32_t sub_w_elem_cnt = n * kfp_num;
using DataT = typename PDDataTypeTraits<T>::DataType;
SplitKernel<DataT>
<<<num_blocks, kBlockSize, kfp_num * sizeof(int32_t), stream>>>(
reinterpret_cast<const DataT*>(x),
weight,
weight_scale,
outlier_idx,
reinterpret_cast<DataT*>(sub_x),
reinterpret_cast<DataT*>(sub_weight),
m,
k,
n,
num_outlier_idx,
kfp_num,
sub_x_elem_cnt,
sub_w_elem_cnt,
elem_cnt);
}
template <typename T>
void LaunchDequantMergeKernel(const int32_t* x,
const T* x_fp,
const float* input_range,
const float* weight_scale,
T* y,
int m,
int n,
gpuStream_t stream) {
constexpr int NumThreads = 256;
constexpr int VecSize = 16 / sizeof(T);
using DataT = typename PDDataTypeTraits<T>::DataType;
DequantMergeKernel<DataT, VecSize><<<m, NumThreads, 0, stream>>>(
x,
reinterpret_cast<const DataT*>(x_fp),
reinterpret_cast<const float*>(input_range),
reinterpret_cast<const float*>(weight_scale),
reinterpret_cast<DataT*>(y),
m,
n);
}
template <typename T>
void LLMGemm(const phi::GPUContext& dev_ctx,
const phi::DenseTensor* weight,
const phi::DenseTensor* input,
const phi::DenseTensor* weight_scale,
const float threshold,
phi::DenseTensor* output,
phi::DenseTensor* workspace,
std::string name,
int m,
int k,
int n) {
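  // High-level flow of the LLM.int8-style GEMM below:
  //   1. per-row absmax reduction on the activations, marking columns whose
  //      magnitude exceeds `threshold` as outliers and quantizing the rest
  //      to int8;
  //   2. counting the outlier columns (kfp_num) and, if any exist, gathering
  //      the corresponding activation columns / dequantized weight rows and
  //      running them through a fp16 GEMM;
  //   3. an int8 GEMM over the quantized activations and weights via
  //      cuBLASLt;
  //   4. dequantizing the int32 result and merging it with the fp16 outlier
  //      partial result into `output`.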
// absmax, quant, outlier
int64_t num_outlier_idx = (k + 31) / 32;
phi::DenseTensor row_ranges, outlier_idx, quant_input;
row_ranges.Resize({m});
outlier_idx.Resize({num_outlier_idx});
quant_input.Resize({m, k});
dev_ctx.Alloc<float>(&row_ranges);
dev_ctx.Alloc<int32_t>(&outlier_idx);
dev_ctx.Alloc<int8_t>(&quant_input);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(outlier_idx.data<int32_t>(),
0,
num_outlier_idx * sizeof(int32_t),
dev_ctx.stream()));
LaunchReduceAbsMaxQuantKernel(input->data<T>(),
threshold,
m,
k,
row_ranges.data<float>(),
outlier_idx.data<int32_t>(),
quant_input.data<int8_t>(),
dev_ctx.stream());
int32_t kfp_num = 0;
phi::DenseTensor kfp_num_tensor;
kfp_num_tensor.Resize({1});
dev_ctx.Alloc<int32_t>(&kfp_num_tensor);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
kfp_num_tensor.data<int32_t>(), 0, sizeof(int32_t), dev_ctx.stream()));
UpdateOutlier<<<1, num_outlier_idx, 0, dev_ctx.stream()>>>(
outlier_idx.data<int32_t>(), kfp_num_tensor.data<int32_t>());
cudaMemcpy(&kfp_num,
kfp_num_tensor.data<int32_t>(),
sizeof(int32_t),
cudaMemcpyDeviceToHost);
phi::DenseTensor sub_out;
sub_out.Resize({m, n});
dev_ctx.Alloc<T>(&sub_out);
if (kfp_num != 0) {
phi::DenseTensor sub_input, sub_weight;
sub_input.Resize({m, kfp_num});
sub_weight.Resize({n, kfp_num});
dev_ctx.Alloc<T>(&sub_input);
dev_ctx.Alloc<T>(&sub_weight);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(sub_input.data<T>(),
0,
sub_input.numel() * sizeof(T),
dev_ctx.stream()));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(sub_weight.data<T>(),
0,
sub_weight.numel() * sizeof(T),
dev_ctx.stream()));
LaunchSplitKernel(input->data<T>(),
weight->data<int8_t>(),
weight_scale->data<float>(),
outlier_idx.data<int32_t>(),
sub_input.data<T>(),
sub_weight.data<T>(),
m,
k,
n,
kfp_num,
dev_ctx.stream());
CBLAS_TRANSPOSE transA = CblasNoTrans;
CBLAS_TRANSPOSE transB = CblasTrans;
T alpha = static_cast<T>(1.0);
T beta = static_cast<T>(0.0);
// (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
auto blas = phi::funcs::GetBlas<phi::GPUContext, T>(dev_ctx);
blas.GEMM(transA,
transB,
m,
n,
kfp_num,
alpha,
sub_input.data<T>(),
sub_weight.data<T>(),
beta,
sub_out.data<T>());
// PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
} else {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
sub_out.data<T>(), 0, sub_out.numel() * sizeof(T), dev_ctx.stream()));
}
phi::DenseTensor int_out;
int_out.Resize({m, n});
dev_ctx.Alloc<int32_t>(&int_out);
{
auto helper = std::make_unique<CublasLtHelper>(m, k, n);
helper->GEMM(quant_input.data<int8_t>(),
weight->data<int8_t>(),
int_out.data<int32_t>(),
dev_ctx.stream(),
(void*)workspace->data<int8_t>());
}
// PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
LaunchDequantMergeKernel<T>(int_out.data<int32_t>(),
sub_out.data<T>(),
row_ranges.data<float>(),
weight_scale->data<float>(),
output->data<T>(),
m,
n,
dev_ctx.stream());
// PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
}
} // namespace llm_int8
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_
#define PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_
#include <iostream>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
namespace phi {
template <typename T>
inline T xabs(const T x) {
return x < static_cast<T>(0.0) ? -x : x;
}
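// Symmetric per-channel (column-wise) int8 quantization helpers:
// scale[j] = max_i |input[i][j]| / 127 and quant[i][j] = round(input[i][j] / scale[j]).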
template <typename T>
void per_channel_scale(float* scale, const T* input, size_t m, size_t n) {
for (size_t i = 0; i < n; ++i) {
T max = input[i];
for (size_t j = 0; j < m; ++j) {
max = xabs(input[j * n + i]) > max ? xabs(input[j * n + i]) : max;
}
scale[i] = static_cast<float>(max) / 127.0;
}
}
template <typename T, typename D>
void per_channel_quant(
D* output, const T* input, const float* scale, size_t m, size_t n) {
for (size_t i = 0; i < m; i++) {
for (size_t j = 0; j < n; j++) {
output[i * n + j] = static_cast<D>(
round(static_cast<float>(input[i * n + j]) / scale[j]));
}
}
}
void row_major_to_column_major(int8_t* col_major_tensor,
const int8_t* row_major_tensor,
const std::vector<size_t>& shape) {
size_t m = shape[0];
size_t n = shape[1];
for (size_t i = 0; i < m * n; i++) {
size_t im = i / n;
size_t in = i % n;
col_major_tensor[in * m + im] = row_major_tensor[im * n + in];
}
}
void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor_ptr,
size_t num_elts) {
int8_t* int8_tensor = reinterpret_cast<int8_t*>(int8_tensor_ptr);
for (size_t ii = 0; ii < num_elts; ++ii) {
int8_tensor[ii] =
static_cast<int8_t>(static_cast<int>(int8_tensor[ii]) + 128);
}
// Step 2 will transform the layout of a 32-bit register in CUDA in order to
// match the int4 layout. This has no performance benefit and is purely so
// that int4 and int8 have the same layout. Pictorially, this does the
// following: bit 32 0
// [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits)
//
// And it will rearrange the output 32 bit register to be the following:
// bit 32 0
// [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
for (size_t base = 0; base < num_elts; base += 4) {
std::swap(int8_tensor[base + 1], int8_tensor[base + 2]);
}
}
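// Example (int8 path): the bias step maps each signed byte v to (v + 128) mod 256
// re-interpreted as int8 (on the usual two's-complement targets), and the swap
// step reorders every group of four bytes {e0, e1, e2, e3} -> {e0, e2, e1, e3}
// in memory, which is the [elt_3 elt_1 elt_2 elt_0] register layout described above.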
void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
const int8_t* quantized_tensor,
const std::vector<size_t>& shape,
const int64_t arch_version) {
// We only want to run this step for weight only quant.
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
const int BITS_PER_ELT = 8;
const int K = 16 / BITS_PER_ELT;
// const int ELTS_PER_BYTE = 8 / BITS_PER_ELT;
const int ELTS_PER_REG = 32 / BITS_PER_ELT;
const uint32_t* input_byte_ptr =
reinterpret_cast<const uint32_t*>(quantized_tensor);
uint32_t* output_byte_ptr =
reinterpret_cast<uint32_t*>(permuted_quantized_tensor);
// int MMA_SHAPE_N = 8;
int B_ROWS_PER_MMA = 8 * K;
const int elts_in_int32 = 32 / BITS_PER_ELT;
const int num_vec_cols = num_cols / elts_in_int32;
// The code is written as below so it works for both int8 and packed int4.
for (size_t base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) {
for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) {
for (int write_col = 0; write_col < num_vec_cols; ++write_col) {
const int write_row = base_row + tile_row;
const int tile_read_row = 8 * (((tile_row % ELTS_PER_REG) / 2)) +
tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
const int read_row = base_row + tile_read_row;
const int read_col = write_col;
const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col;
const int64_t write_offset =
int64_t(write_row) * num_vec_cols + write_col;
output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
}
}
}
}
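// For the int8 case (BITS_PER_ELT = 8, ELTS_PER_REG = 4, B_ROWS_PER_MMA = 16),
// the formula above reads the rows of each 16-row tile in the order
// {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15} while writing rows 0..15.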
void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor,
const int8_t* quantized_tensor,
const std::vector<size_t>& shape) {
// We only want to run this step for weight only quant.
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
const size_t BITS_PER_ELT = 8;
const size_t elts_in_int32 = 32 / BITS_PER_ELT;
const size_t rows_per_tile = 64;
const uint32_t* input_byte_ptr =
reinterpret_cast<const uint32_t*>(quantized_tensor);
uint32_t* output_byte_ptr =
reinterpret_cast<uint32_t*>(interleaved_quantized_tensor);
const size_t num_vec_rows = num_rows / elts_in_int32;
const size_t vec_rows_per_tile = rows_per_tile / elts_in_int32;
const size_t interleave = 2;
for (size_t read_col = 0; read_col < num_cols; ++read_col) {
const size_t write_col = read_col / interleave;
for (size_t base_vec_row = 0; base_vec_row < num_vec_rows;
base_vec_row += vec_rows_per_tile) {
for (size_t vec_read_row = base_vec_row;
vec_read_row <
std::min(num_vec_rows, base_vec_row + vec_rows_per_tile);
++vec_read_row) {
const size_t vec_write_row =
interleave * base_vec_row +
vec_rows_per_tile * (read_col % interleave) +
vec_read_row % vec_rows_per_tile;
const size_t read_offset =
size_t(read_col) * num_vec_rows + vec_read_row;
const size_t write_offset =
size_t(write_col) * num_vec_rows * interleave + vec_write_row;
output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
}
}
}
}
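// With interleave = 2 and rows_per_tile = 64, each pair of adjacent columns is
// packed into one output column: for every 64-row tile, the tile of column 2c is
// written first, immediately followed by the same tile rows of column 2c+1.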
} // namespace phi
#endif // PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LLMInt8MatMulKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const DenseTensor& weight_scale,
const float threshold,
DenseTensor* out);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void QuantForCompressKernel(const Context& dev_ctx,
const DenseTensor& x,
int bits,
const std::string& layout,
DenseTensor* out,
DenseTensor* scale);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void WeightOnlyMatMulKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const DenseTensor& weight_scale,
DenseTensor* out);
} // namespace phi
...@@ -381,6 +381,7 @@ class LayerHelperBase:
and dtype != core.VarDesc.VarType.FP64
and dtype != core.VarDesc.VarType.FP16
and dtype != core.VarDesc.VarType.BF16
and dtype != core.VarDesc.VarType.INT8
):
raise TypeError(
"Can not create parameter with default initializer when dtype is not ['float16', 'float32', 'float64', 'bfloat16'] type. Set default_initializer to fit the parameter dtype!"
...@@ -392,6 +393,7 @@ class LayerHelperBase:
'float64',
'bfloat16',
'float',
'int8',
]:
raise TypeError(
"Can not create parameter with default initializer when dtype is not ['float16', 'float32', 'float64', 'bfloat16', 'float'] type. Set default_initializer to fit the parameter dtype!"
...
...@@ -71,7 +71,7 @@ from .layer.common import AlphaDropout # noqa: F401
from .layer.common import Unfold # noqa: F401
from .layer.common import Fold # noqa: F401
from .layer.common import Unflatten # noqa: F401
from .layer.common import LinearCompress # noqa: F401
from .layer.pooling import AvgPool1D # noqa: F401
from .layer.pooling import AvgPool2D # noqa: F401
from .layer.pooling import AvgPool3D # noqa: F401
...
...@@ -62,6 +62,8 @@ from .common import interpolate # noqa: F401
from .common import upsample # noqa: F401
from .common import bilinear # noqa: F401
from .common import class_center_sample # noqa: F401
from .common import quant_for_compress # noqa: F401
from .common import linear_compress # noqa: F401
from .conv import conv1d # noqa: F401
from .conv import conv1d_transpose # noqa: F401
from .common import linear # noqa: F401
...
...@@ -1876,6 +1876,78 @@ def linear(x, weight, bias=None, name=None):
return res
def quant_for_compress(x, bits=8, layout="weight_only"):
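    """
    Quantize a floating-point weight tensor for later use with linear_compress.

    Returns a tuple (out, scale): the int8-quantized weight, laid out for the
    kernel selected by `layout` ("weight_only" or "llm.int8"), and the
    per-channel scale tensor used to dequantize it.
    """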
return _C_ops.quant_for_compress(x, bits, layout)
def linear_compress(
x,
weight,
weight_scale,
bias=None,
bits=8,
algo="llm.int8",
name=None,
config=None,
):
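    """
    Apply a linear transformation using an int8-quantized weight.

    `weight` and `weight_scale` are expected to be produced by
    `quant_for_compress`. `algo` selects the kernel: 'llm.int8' (which reads
    config['threshold'] for outlier detection) or 'weight_only'. The
    static-graph path only accepts float16 activations.
    """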
if in_dynamic_mode():
if algo == "llm.int8":
y = _C_ops.llm_int8_mat_mul(
x, weight, weight_scale, config['threshold']
)
elif algo == "weight_only":
y = _C_ops.weight_only_mat_mul(x, weight, weight_scale)
else:
raise ValueError(
"Unknown algo: '{}'. It can only be 'llm.int8' or 'weight_only'.".format(
algo
)
)
if bias is not None:
y = paddle.add(y, bias)
return y
else:
helper = LayerHelper('linear_compress', **locals())
dtype = x.dtype
check_variable_and_dtype(x, 'x', ['float16'], 'linear_compress')
check_dtype(dtype, 'dtype', ['float16'], 'linear_compress')
if algo == "llm.int8":
type = "llm_int8_matmul"
inputs = {'X': [x], 'Y': [weight], 'weight_scale': [weight_scale]}
attrs = {'algo': algo, 'threshold': config['threshold']}
elif algo == "weight_only":
type = "weight_only_matmul"
inputs = {'X': [x], 'Y': [weight], 'weight_scale': [weight_scale]}
attrs = {}
else:
raise ValueError(
"Unknown algo: '{}'. It can only be 'llm.int8' or 'weight_only'.".format(
algo
)
)
tmp = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type=type,
inputs=inputs,
outputs={'Out': tmp},
attrs=attrs,
)
if bias is not None:
res = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='elementwise_add',
inputs={'X': [tmp], 'Y': [bias]},
outputs={'Out': [res]},
attrs={'axis': -1},
)
else:
res = tmp
return res
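# A minimal dynamic-graph usage sketch of the two helpers above (the tensor
# shapes and values here are illustrative only):
#
#   x = paddle.randn([2, 64], dtype='float16')
#   w = paddle.randn([64, 64], dtype='float16')
#   qw, w_scale = quant_for_compress(w, bits=8, layout='weight_only')
#   y = linear_compress(x, qw, w_scale, bits=8, algo='weight_only')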
def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
r"""
Label smoothing is a mechanism to regularize the classifier layer and is called
...
...@@ -81,6 +81,12 @@ class NumpyArrayInitializer(Initializer):
elif out_dtype == core.VarDesc.VarType.INT32:
value_name = "int32_values"
values = [int(v) for v in np_value.flat]
elif (
out_dtype == core.VarDesc.VarType.INT8
or out_dtype == core.VarDesc.VarType.UINT8
):
value_name = "int8_values"
values = [int(v) for v in np_value.flat]
else:
raise ValueError("Unsupported dtype %s", self._value.dtype)
if self._value.size > 1024 * 1024 * 1024:
...
...@@ -183,6 +183,167 @@ class Linear(Layer):
)
class LinearCompress(Layer):
r"""
Fully-connected linear transformation layer. For each input :math:`X` ,
the equation is:
.. math::
Out = XW + b
where :math:`W` is the weight and :math:`b` is the bias.
Linear layer takes only one multi-dimensional tensor as input with the
shape :math:`[batch\_size, *, in\_features]` , where :math:`*` means any
number of additional dimensions. It multiplies input tensor with the weight
(a 2-D tensor of shape :math:`[in\_features, out\_features]` ) and produces
an output tensor of shape :math:`[batch\_size, *, out\_features]` .
If :math:`bias\_attr` is not False, the bias (a 1-D tensor of
shape :math:`[out\_features]` ) will be created and added to the output.
Parameters:
in_features (int): The number of input units.
out_features (int): The number of output units.
weight_attr (ParamAttr, optional): The attribute for the weight of this layer.
The default value is None. If the Initializer of the
param_attr is not set, the parameter is initialized with Xavier.
For detailed information, please refer to paddle.ParamAttr.
bias_attr (ParamAttr|bool, optional): The attribute for the bias of this layer.
If it is set to False, no bias will be added to the output.
If it is set to None or one kind of ParamAttr, a bias parameter will
be created according to ParamAttr. For detailed information, please refer
to paddle.ParamAttr. The default value is None and the bias will be
initialized to zero.
name (str, optional): Normally there is no need for user to set this parameter.
For detailed information, please refer to :ref:`api_guide_Name` .
bits (int, optional): The number of bits used to quantize the weight when algo is 'weight_only';
currently it must be 8. Default: 8.
algo (str, optional): The compression algorithm, which must be either 'weight_only'
or 'llm.int8'. Default: 'weight_only'.
config (dict, optional): The parameter config for the compression algorithm.
For 'llm.int8' it should be set as {'threshold': 6.0}. Default: {'threshold': 6.0}.
Attribute:
**weight** (Parameter): the learnable weight of this layer.
**bias** (Parameter): the learnable bias of this layer.
Shape:
- input: Multi-dimensional tensor with shape :math:`[batch\_size, *, in\_features]` . Its data type must be float16.
- output: Multi-dimensional tensor with shape :math:`[batch\_size, *, out\_features]` . The data type is the same as the input.
Examples:
.. code-block:: python
import paddle
# Define the linear layer.
paddle.set_default_dtype('float16')
weight_attr = paddle.ParamAttr(
name="weight",
initializer=paddle.nn.initializer.Constant(value=0.5))
bias_attr = paddle.ParamAttr(
name="bias",
initializer=paddle.nn.initializer.Constant(value=1.0))
linear = paddle.nn.LinearCompress(128, 64, weight_attr=weight_attr, bias_attr=bias_attr, bits=8, algo='weight_only')
x = paddle.randn((3, 128), dtype="float16")
y = linear(x)
"""
def __init__(
self,
in_features,
out_features,
weight_attr=None,
bias_attr=None,
name=None,
bits=8,
algo="weight_only",
config={'threshold': 6.0},
):
super().__init__()
self._dtype = self._helper.get_default_dtype()
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self.weight = self.create_parameter(
shape=[in_features, out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False,
)
self.bias = self.create_parameter(
shape=[out_features],
attr=self._bias_attr,
dtype=self._dtype,
is_bias=True,
)
self.weight_scale = self.create_parameter(
shape=[out_features],
attr=None,
dtype=self._dtype,
is_bias=False,
)
self.is_weight_quanted = False
self.name = name
self.bits = bits
self.layout = algo
self.algo = algo
self.config = config
def forward(self, input):
if in_dynamic_mode():
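            # On the first dynamic-graph call, quantize the fp16 weight once and
            # replace `self.weight`/`self.weight_scale` with the int8 weight and
            # its per-channel scales; later calls reuse the cached parameters.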
if not self.is_weight_quanted:
weight_tensor, weight_scale_tensor = F.quant_for_compress(
self.weight, self.bits, self.layout
)
weight_attr = paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(weight_tensor)
)
self.weight = self.create_parameter(
shape=self.weight.shape
if self.layout == 0
else [self.weight.shape[1], self.weight.shape[0]],
attr=weight_attr,
dtype="int8",
is_bias=False,
)
weight_scale_attr = paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(
weight_scale_tensor
)
)
self.weight_scale = self.create_parameter(
shape=self.weight_scale.shape,
attr=weight_scale_attr,
dtype="float32",
is_bias=False,
)
self.is_weight_quanted = True
out = F.linear_compress(
x=input,
weight=self.weight,
weight_scale=self.weight_scale,
bias=self.bias,
bits=self.bits,
algo=self.algo,
name=self.name,
config=self.config,
)
return out
def extra_repr(self):
name_str = f', name={self.name}' if self.name else ''
return 'in_features={}, out_features={}, dtype={}{}, algo={}'.format(
self.weight.shape[0],
self.weight.shape[1],
self._dtype,
name_str,
self.algo,
)
class Upsample(Layer):
"""
This op resizes a batch of images.
...
...@@ -86,6 +86,7 @@ endif()
list(REMOVE_ITEM TEST_OPS test_audio_logmel_feature test_audio_mel_feature)
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
list(REMOVE_ITEM TEST_OPS test_linear_compress)
list(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op)
list(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op)
list(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass)
...@@ -153,6 +154,7 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias)
list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
list(REMOVE_ITEM TEST_OPS test_linear_compress)
endif()
list(REMOVE_ITEM TEST_OPS test_checkpoint_saver)
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle import fluid
from paddle.fluid.framework import default_main_program
from paddle.framework import set_default_dtype
np.random.seed(123)
paddle.seed(123)
default_main_program().random_seed = 42
paddle.disable_static()
class LinearTestCase(unittest.TestCase):
def config(self):
self.dtype = 'float16'
self.rtol = 1e-5
self.atol = 1e-2
self.bias = True
self.in_features = 64
self.out_features = 64
self.algo = "weight_only"
def setUp(self):
self.config()
input = np.random.random((2, 4, self.in_features))
self.input = paddle.to_tensor(input, dtype=self.dtype)
if self.bias:
bias_attr = fluid.ParamAttr(
learning_rate=0.8,
trainable=False,
regularizer=None,
initializer=paddle.nn.initializer.Constant(value=1.0),
)
else:
bias_attr = None
set_default_dtype(self.dtype)
self.linear = paddle.nn.Linear(
self.in_features, self.out_features, bias_attr=bias_attr
)
if self.algo == "llm.int8":
self.config = {"threshold": 6.0}
else:
self.config = None
self.linear_compress = paddle.nn.LinearCompress(
self.in_features,
self.out_features,
bias_attr=bias_attr,
algo=self.algo,
config=self.config,
)
self.linear_compress(self.input)
def get_linear_out(self):
out = self.linear(self.input)
return out.numpy()
def get_linear_compress_out(self):
out = self.linear_compress(self.input)
return out.numpy()
def test_linear_compress(self):
out_real = self.get_linear_compress_out()
out_expect = self.get_linear_out()
np.testing.assert_allclose(
out_real, out_expect, rtol=self.rtol, atol=self.atol
)
class LinearTestCase1(LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = True
self.in_features = 128
self.out_features = 64
class LinearTestCase2(LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = False
self.in_features = 64
self.out_features = 64
class LinearTestCase3(LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = False
self.in_features = 64
self.out_features = 64
self.algo = "llm.int8"
self.atol = 1e-1
if __name__ == '__main__':
unittest.main()