Unverified commit c4d5ec66, authored by FormlessUnit, committed by GitHub

add linear_compress API (#54140)

* add linear_compress API
Parent 989f3dde
@@ -112,11 +112,13 @@ class AssignValueKernel : public framework::OpKernel<T> {
      break;
    case framework::proto::VarType::INT64:
      value_name = "int64_values";
    case framework::proto::VarType::INT8:
      value_name = "int8_values";
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type(code %d) for AssignValue operator, only "
          "supports bool, int32, float32, int8 and int64.",
          dtype));
      break;
  }
...
@@ -1347,6 +1347,16 @@
  data_transform :
    skip_transform : out_size, size_tensor, scale_tensor

- op : llm_int8_mat_mul
  args : (Tensor x, Tensor weight, Tensor weight_scale, float threshold=6.0)
  output : Tensor(out)
  infer_meta :
    func : LLMInt8MatMulInferMeta
    param : [x, weight]
  kernel :
    func : llm_int8_mat_mul
    data_type : x

- op : log
  args : (Tensor x)
  output : Tensor
@@ -1896,6 +1906,15 @@
    func : qr
  backward : qr_grad

- op : quant_for_compress
  args : (Tensor x, int bits = 8, str layout = "weight_only")
  output : Tensor(out), Tensor(scale)
  infer_meta :
    func : QuantForCompressInferMeta
  kernel :
    func : quant_for_compress
    data_type: x

- op : real
  args : (Tensor x)
  output : Tensor (out)
@@ -2563,6 +2582,16 @@
  intermediate: warprnntgrad
  backward : warprnnt_grad

- op : weight_only_mat_mul
  args : (Tensor x, Tensor weight, Tensor weight_scale)
  output : Tensor(out)
  infer_meta :
    func : WeightOnlyMatMulInferMeta
    param : [x, weight]
  kernel :
    func : weight_only_mat_mul
    data_type : x

- op : weighted_sample_neighbors
  args : (Tensor row, Tensor colptr, Tensor edge_weight, Tensor input_nodes, Tensor eids, int sample_size, bool return_eids)
  output : Tensor(out_neighbors), Tensor(out_count), Tensor(out_eids)
...
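For orientation, with Paddle's standard operator code generation the three yaml entries above are expected to produce C++ API declarations roughly like the following. This is a sketch inferred from the declared args/outputs, not text from the diff; treat the exact signatures as assumptions.

// Hypothetical generated signatures (sketch; assumes Paddle's usual op codegen).
PADDLE_API Tensor llm_int8_mat_mul(const Tensor& x, const Tensor& weight, const Tensor& weight_scale, float threshold = 6.0);
PADDLE_API std::tuple<Tensor, Tensor> quant_for_compress(const Tensor& x, int bits = 8, const std::string& layout = "weight_only");
PADDLE_API Tensor weight_only_mat_mul(const Tensor& x, const Tensor& weight, const Tensor& weight_scale);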
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#pragma once
namespace phi {
template <typename T>
struct PDDataTypeTraits {
using DataType = T;
};
template <>
struct PDDataTypeTraits<phi::dtype::float16> {
// Since LayerNormDirectCUDAFunctor register half type, we need to convert
// phi::float16 to half.
using DataType = half;
};
#ifdef PADDLE_CUDA_BF16
template <>
class PDDataTypeTraits<phi::dtype::bfloat16> {
public:
using DataType = __nv_bfloat16;
};
#endif
} // namespace phi
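The trait above is typically used to reinterpret phi tensor data as the matching CUDA builtin type before calling a kernel or a CUDA math library. A minimal sketch, where LaunchWithCudaType is a hypothetical helper name:

// Sketch: map a phi element type to its CUDA counterpart and reinterpret the
// pointer accordingly (LaunchWithCudaType is illustrative, not from the diff).
template <typename T>
void LaunchWithCudaType(const T* x_data, size_t n) {
  using CudaT = typename phi::PDDataTypeTraits<T>::DataType;
  const CudaT* cuda_x = reinterpret_cast<const CudaT*>(x_data);
  // ... hand cuda_x to a kernel or cuBLASLt call that expects CudaT ...
  (void)cuda_x;
  (void)n;
}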
@@ -3572,5 +3572,51 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
  out_count->set_dtype(DataType::INT32);
}
void LLMInt8MatMulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
MetaTensor* out) {
auto x_dims = x.dims();
auto w_dims = weight.dims();
PADDLE_ENFORCE_EQ(
w_dims.size(),
2UL,
errors::InvalidArgument("The input(weight) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
x_dims[x_dims.size() - 1],
w_dims[1],
errors::InvalidArgument(
"Input(X) dim[-1] and Input(Weight) dim[1] should be euqal."
"But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)",
x_dims[x_dims.size() - 1],
w_dims[1]));
auto out_dims = x_dims;
out_dims[out_dims.size() - 1] = w_dims[0];
out->set_dims(out_dims);
out->set_dtype(x.dtype());
}
void WeightOnlyMatMulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
MetaTensor* out) {
auto x_dims = x.dims();
auto w_dims = weight.dims();
PADDLE_ENFORCE_EQ(
w_dims.size(),
2UL,
errors::InvalidArgument("The input(weight) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
x_dims[x_dims.size() - 1],
w_dims[0],
errors::InvalidArgument(
"Input(X) dim[-1] and Input(Weight) dim[0] should be euqal."
"But received Input(X) dim[-1](%s) != Input(Weight) dim[0](%s)",
x_dims[x_dims.size() - 1],
w_dims[0]));
auto out_dims = x_dims;
out_dims[out_dims.size() - 1] = w_dims[1];
out->set_dims(out_dims);
out->set_dtype(x.dtype());
}
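// Illustrative sketch (not part of this commit): the two routines above encode
// different weight storage conventions. llm_int8_mat_mul expects weight as
// [n, k] (x dim[-1] must equal w_dims[1]; out dim[-1] becomes w_dims[0]), while
// weight_only_mat_mul expects weight as [k, n] (x dim[-1] must equal w_dims[0];
// out dim[-1] becomes w_dims[1]). The hypothetical helper below just restates
// that contract.
inline std::vector<int64_t> InferCompressedMatMulOutShape(
    const std::vector<int64_t>& x_shape,
    const std::vector<int64_t>& w_shape,
    bool weight_is_transposed /* true for llm_int8_mat_mul's [n, k] weight */) {
  std::vector<int64_t> out_shape = x_shape;
  out_shape.back() = weight_is_transposed ? w_shape[0] : w_shape[1];
  return out_shape;
}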
} // namespace phi

PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
@@ -673,6 +673,31 @@ void MoeInferMeta(const MetaTensor& x,
                  const std::string& act_type,
                  MetaTensor* out);
void FusedMultiHeadAttentionInferMeta(const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
const MetaTensor& mask,
float scale,
bool causal,
MetaTensor* out);
void FusedMultiHeadAttentionVariableInferMeta(const MetaTensor& query,
const MetaTensor& key,
const MetaTensor& value,
const MetaTensor& seq_lens,
const MetaTensor& mask,
float scale,
bool causal,
MetaTensor* out);
void LLMInt8MatMulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
MetaTensor* out);
void WeightOnlyMatMulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
MetaTensor* out);
void FusedRopeInferMeta(const MetaTensor& q,
                        const MetaTensor& k,
                        const MetaTensor& v,
...
@@ -5062,6 +5062,51 @@ void CheckNumericsInferMeta(const MetaTensor& tensor,
  values->set_dims(phi::make_ddim({3}));
}
void QuantForCompressInferMeta(const MetaTensor& x,
int bits,
const std::string& layout,
MetaTensor* out,
MetaTensor* scale) {
auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(
x_dims.size(),
2UL,
phi::errors::InvalidArgument(
"The x tensor of quant op must be 2D, but got[%d]", x_dims.size()));
PADDLE_ENFORCE_GE(
x_dims[0],
64,
phi::errors::OutOfRange("The first dimension of input is out of range "
"(expected at least 64, but got %ld).",
x_dims[0]));
PADDLE_ENFORCE_EQ(
x_dims[0] % 64,
0,
phi::errors::InvalidArgument(
"The first dimension of input must be divisible by 64, but got[%d]",
x_dims[0]));
std::vector<int64_t> dim_scale({x_dims[1]});
std::vector<int64_t> dim_out;
if (layout == "weight_only") {
dim_out = std::vector<int64_t>({x_dims[0], x_dims[1]});
} else if (layout == "llm.int8") {
dim_out = std::vector<int64_t>({x_dims[1], x_dims[0]});
} else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "The layout must be weight_only or llm.int8, but got %s.", layout));
}
out->set_dims(phi::make_ddim(dim_out));
// TODO(lizhenyun) support weight_only int4
if (bits == 8) {
out->set_dtype(DataType::INT8);
} else {
    PADDLE_THROW(
        phi::errors::Fatal("Only bits == 8 is supported, but got %d.", bits));
}
scale->set_dims(phi::make_ddim(dim_scale));
scale->set_dtype(DataType::FLOAT32);
}
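// Illustrative sketch (not part of this commit): for a 2-D input x of shape
// [k, n] with k >= 64 and k % 64 == 0, the routine above produces an int8 out
// of shape [k, n] for layout "weight_only" and [n, k] for layout "llm.int8",
// plus a float32 per-channel scale of shape [n]. A compact restatement with a
// hypothetical helper:
inline void QuantForCompressExpectedShapes(int64_t k,
                                           int64_t n,
                                           const std::string& layout,
                                           std::vector<int64_t>* out_shape,
                                           std::vector<int64_t>* scale_shape) {
  *out_shape = (layout == "weight_only") ? std::vector<int64_t>{k, n}
                                         : std::vector<int64_t>{n, k};
  *scale_shape = {n};
}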
} // namespace phi

PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
@@ -724,4 +724,10 @@ void UnStackInferMeta(const MetaTensor& x,
                      int num,
                      std::vector<MetaTensor*> outs);
void QuantForCompressInferMeta(const MetaTensor& x,
int bits,
const std::string& layout,
MetaTensor* out,
MetaTensor* scale);
} // namespace phi
@@ -132,6 +132,7 @@ PD_REGISTER_KERNEL(assign_value,
                   bool,
                   int,
                   float,
                   int8_t,
                   int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -158,6 +159,7 @@ PD_REGISTER_KERNEL(assign_value,
                   bool,
                   int,
                   float,
                   int8_t,
                   int64_t) {}
#endif
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/quant_for_compress_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h"
namespace phi {
template <typename DeviceContext, typename T, typename D>
void quant_compute(const DeviceContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out,
DenseTensor* scale,
const std::string& layout) {
const auto x_dims = x.dims();
PADDLE_ENFORCE_EQ(
x_dims.size(),
2,
phi::errors::InvalidArgument(
"the x tensor of quant op must be 2D, but got[%d]", x_dims.size()));
size_t m = x_dims[0];
size_t n = x_dims[1];
int64_t num = x.numel();
DDim dims = {num};
const T* x_data = x.data<T>();
D* out_data = out->data<D>();
float* scale_data = scale->data<float>();
DenseTensor x_int(out->type());
x_int.Resize({static_cast<int64_t>(m), static_cast<int64_t>(n)});
dev_ctx.template Alloc<D>(&x_int);
D* x_int_data = x_int.data<D>();
DenseTensor int_processed(out->type());
int_processed.Resize(dims);
dev_ctx.template Alloc<D>(&int_processed);
D* int_processed_data = int_processed.data<D>();
DenseTensor int_processed_2(out->type());
int_processed_2.Resize(out->dims());
dev_ctx.template Alloc<D>(&int_processed_2);
D* int_processed_2_data = int_processed_2.data<D>();
per_channel_scale(scale_data, x_data, m, n);
per_channel_quant(x_int_data, x_data, scale_data, m, n);
if (layout == "weight_only") {
permute_B_rows_for_mixed_gemm(
int_processed_data, x_int_data, std::vector<size_t>{m, n}, (int64_t)80);
row_major_to_column_major(
int_processed_2_data, int_processed_data, std::vector<size_t>{m, n});
interleave_column_major_tensor(
out_data, int_processed_2_data, std::vector<size_t>{m, n});
add_bias_and_interleave_int8s_inplace(out_data, num);
} else if (layout == "llm.int8") {
std::vector<int> axis = {1, 0};
funcs::Transpose<DeviceContext, int8_t, 2> trans;
trans(dev_ctx, x_int, out, axis);
} else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "The layout must be weight_only or llm.int8, but got %s.", layout));
}
}
template <typename T, typename Context>
void QuantForCompressKernel(const Context& dev_ctx,
const DenseTensor& x,
int bits,
const std::string& layout,
DenseTensor* out,
DenseTensor* scale) {
if (bits == 8) {
dev_ctx.template Alloc<int8_t>(out);
dev_ctx.template Alloc<float>(scale);
quant_compute<Context, T, int8_t>(dev_ctx, x, out, scale, layout);
} else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Only bits == 8 is supported, but got %d.", bits));
  }
}
} // namespace phi
PD_REGISTER_KERNEL(quant_for_compress,
CPU,
ALL_LAYOUT,
phi::QuantForCompressKernel,
float,
phi::dtype::float16) {}
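per_channel_scale and per_channel_quant come from quant_for_compress_kernel_impl.h, which is not shown in this diff. As a rough reference, column-wise (per output channel) symmetric int8 quantization of an [m, n] weight generally looks like the sketch below; this is an assumption about those helpers, not code from the commit.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Reference sketch of column-wise symmetric int8 quantization: one float scale
// per column j, values mapped to round(w / scale) and clamped to [-127, 127].
void PerChannelQuantReference(
    const float* w, int64_t m, int64_t n, float* scale, int8_t* w_int8) {
  for (int64_t j = 0; j < n; ++j) {
    float max_abs = 1e-8f;
    for (int64_t i = 0; i < m; ++i) {
      max_abs = std::max(max_abs, std::fabs(w[i * n + j]));
    }
    scale[j] = max_abs / 127.0f;
    for (int64_t i = 0; i < m; ++i) {
      float q = std::round(w[i * n + j] / scale[j]);
      w_int8[i * n + j] =
          static_cast<int8_t>(std::max(-127.0f, std::min(127.0f, q)));
    }
  }
}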
@@ -87,6 +87,7 @@ PD_REGISTER_KERNEL(transpose,
                   double,
                   int32_t,
                   int64_t,
                   int8_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
...
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include "paddle/phi/backends/dynload/cublasLt.h"
#include "paddle/phi/core/dense_tensor.h"
namespace dyl = phi::dynload;
namespace phi {
struct CublasLtAlgoParam {
int algoId;
int swizzle;
int customOption;
int tile;
int splitK_val;
int reductionScheme;
int stages;
size_t workspace_size;
};
const std::map<std::tuple<int, int, int>, CublasLtAlgoParam> AlgoParamCache{};
class CublasLtHelper {
public:
CublasLtHelper(int m, int k, int n)
: alpha_(1), beta_(0), m_(m), k_(k), n_(n) {
cublasStatus_t status;
// handle and matmul desc
status = dyl::cublasLtCreate(&handle_);
#if CUBLAS_VER_MAJOR < 11
cudaDataType_t cudaComputeType = CUDA_R_32I;
#else
cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I;
#endif
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
#if CUBLAS_VER_MAJOR < 11
status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType);
#else
status = dyl::cublasLtMatmulDescCreate(
&matmul_desc_, cudaComputeType, CUDA_R_32I);
#endif
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatmulDescCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
cublasOperation_t op_transpose = CUBLAS_OP_T;
status = dyl::cublasLtMatmulDescSetAttribute(matmul_desc_,
CUBLASLT_MATMUL_DESC_TRANSA,
&op_transpose,
sizeof(op_transpose));
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatmulDescSetAttribute execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
// matrix desc
status = dyl::cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
status = dyl::cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
status = dyl::cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatrixLayoutCreate execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
#if CUDA_VERSION >= 11020
int algoId = 21;
int swizzle = 0;
int customOption = 0;
int tile = 15;
int splitK_val = 0;
int reductionScheme = 0;
int stages = 23;
workspace_size_ = 0;
if (m >= 128) {
tile = 20;
stages = 17;
}
std::tuple<int, int, int> key(m_, k_, n_);
if (AlgoParamCache.count(key) != 0) {
auto value = AlgoParamCache.at(key);
algoId = value.algoId;
swizzle = value.swizzle;
customOption = value.customOption;
tile = value.tile;
splitK_val = value.splitK_val;
reductionScheme = value.reductionScheme;
stages = value.stages;
workspace_size_ = value.workspace_size;
}
dyl::cublasLtMatmulAlgoInit(handle_,
cudaComputeType,
CUDA_R_32I,
CUDA_R_8I,
CUDA_R_8I,
CUDA_R_32I,
CUDA_R_32I,
algoId,
&algo_);
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo_,
CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION,
&(customOption),
sizeof(customOption));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo_, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile));
dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo_,
CUBLASLT_ALGO_CONFIG_SPLITK_NUM,
&(splitK_val),
sizeof(splitK_val));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo_,
CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING,
&(swizzle),
sizeof(swizzle));
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo_,
CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME,
&(reductionScheme),
sizeof(int));
#if CUDA_VERSION >= 11000
dyl::cublasLtMatmulAlgoConfigSetAttribute(
&algo_, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages));
#endif
#endif
}
~CublasLtHelper() {}
void GEMM(int8_t* A_dev,
const int8_t* B_dev,
int32_t* C_dev,
cudaStream_t stream,
void* workspace = nullptr) {
cublasStatus_t status;
status = dyl::cublasLtMatmul(handle_,
matmul_desc_,
&alpha_,
B_dev,
B_desc_,
A_dev,
A_desc_,
&beta_,
C_dev,
C_desc_,
C_dev,
C_desc_,
#if CUDA_VERSION >= 11020
&algo_,
workspace,
workspace_size_,
#else
nullptr,
nullptr,
0,
#endif
stream);
PADDLE_ENFORCE_EQ(
status,
CUBLAS_STATUS_SUCCESS,
phi::errors::External(
"cublasLtMatmul execution error"
"refer https://docs.nvidia.com/cuda/cublas/index.html to get more "
"information"));
}
private:
cublasLtHandle_t handle_;
cublasLtMatmulDesc_t matmul_desc_;
cublasLtMatrixLayout_t A_desc_;
cublasLtMatrixLayout_t B_desc_;
cublasLtMatrixLayout_t C_desc_;
cublasLtMatmulAlgo_t algo_;
int32_t alpha_;
int32_t beta_;
int m_;
int k_;
int n_;
size_t workspace_size_;
};
} // namespace phi
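CublasLtHelper wraps an int8 x int8 -> int32 cublasLt matmul for a fixed (m, k, n); the weight operand is consumed transposed because CUBLASLT_MATMUL_DESC_TRANSA is set on the descriptor. A usage sketch follows; the include path is assumed and the device buffers are taken as already allocated and filled.

#include <cuda_runtime.h>

#include "paddle/phi/kernels/funcs/cublaslt.h"  // assumed location of CublasLtHelper

// Sketch: run one int8 GEMM on the given stream. A_dev holds [m, k] row-major
// int8 activations, B_dev the int8 weight laid out for the transposed
// descriptor, and C_dev receives [m, n] int32 accumulators.
void Int8GemmSketch(int8_t* A_dev, const int8_t* B_dev, int32_t* C_dev,
                    int m, int k, int n, cudaStream_t stream) {
  phi::CublasLtHelper helper(m, k, n);
  helper.GEMM(A_dev, B_dev, C_dev, stream);
}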
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/transform.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace phi {
using backends::gpu::GpuLaunchConfig;
constexpr int DequantKernelVecSize = 4;
template <typename T>
inline HOSTDEVICE T roundWithTiesToEven(T x) {
T xLower = floor(x);
T xUpper = ceil(x);
// x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to
// even.
T dLower = x - xLower;
T dUpper = xUpper - x;
return static_cast<T>(
(dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper)
? xLower
: xUpper);
}
template <typename T>
__forceinline__ __device__ int8_t quant_helper(const T input,
const float scale,
const int round_type,
const float max_bound,
const float min_bound) {
float quant_value = max_bound * scale * static_cast<float>(input);
if (round_type == 0) {
quant_value = static_cast<float>(roundWithTiesToEven(quant_value));
} else {
quant_value = static_cast<float>(round(quant_value));
}
quant_value = quant_value > max_bound ? max_bound : quant_value;
quant_value = quant_value < min_bound ? min_bound : quant_value;
return static_cast<int8_t>(quant_value);
}
template <typename T>
__global__ void quantize_kernel(const T* input,
char4* output,
const float scale,
const int m,
const int n,
const int round_type,
const float max_bound,
const float min_bound) {
int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2;
int m_id = blockIdx.y * blockDim.y + threadIdx.y;
bool check = ((m_id < m) && (n_id < n));
if (check) {
char4 tmp;
tmp.x = quant_helper(
input[m_id * n + n_id], scale, round_type, max_bound, min_bound);
tmp.y = quant_helper(
input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound);
tmp.z = quant_helper(
input[m_id * n + n_id + 2], scale, round_type, max_bound, min_bound);
tmp.w = quant_helper(
input[m_id * n + n_id + 3], scale, round_type, max_bound, min_bound);
output[(m_id * n + n_id) >> 2] = tmp;
}
}
template <typename T>
void quantize_kernel_launcher(const T* input,
int8_t* output,
const float scale,
const int m,
const int n,
const int round_type,
const float max_bound,
const float min_bound,
gpuStream_t stream) {
  // TODO(minghaoBD): optimize the kernel launch times when m==1 or n==1
  dim3 grid(((n >> 2) + 31) / 32, (m + 31) / 32);
dim3 block(32, 32);
quantize_kernel<<<grid, block, 0, stream>>>(input,
(char4*)output, // NOLINT
scale,
m,
n,
round_type,
max_bound,
min_bound);
}
template <typename T, int VecSize>
__global__ void dequantize_kernel(T* output,
const int32_t* input,
const int m, // batch size
const int n, // hidden
const float quant_in_scale,
const float* dequant_out_scale_data) {
int numel = m * n;
int stride = blockDim.x * gridDim.x * VecSize;
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
int col_id = idx % n;
phi::AlignedVector<int32_t, VecSize> in_vec;
phi::AlignedVector<float, VecSize> out_scale_vec;
phi::AlignedVector<T, VecSize> out_vec;
for (; idx < numel; idx += stride) {
phi::Load<int32_t, VecSize>(input + idx, &in_vec);
phi::Load<float, VecSize>(dequant_out_scale_data + col_id, &out_scale_vec);
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
out_vec[i] =
static_cast<T>(static_cast<float>(in_vec[i]) * out_scale_vec[i]);
}
phi::Store<T, VecSize>(out_vec, output + idx);
}
}
template <typename T>
void dequantize_kernel_launcher(const int32_t* input,
T* output,
const int m, // m
const int n, // n
gpuStream_t stream,
GpuLaunchConfig* gpu_config,
const float quant_in_scale,
const float* dequant_out_scale_data) {
dequantize_kernel<T, DequantKernelVecSize>
<<<gpu_config->block_per_grid, gpu_config->thread_per_block, 0, stream>>>(
output, input, m, n, quant_in_scale, dequant_out_scale_data);
}
} // namespace phi
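With round_type == 0 the quantization path above rounds half-to-even before clamping to the int8 bounds. A host-side reference of that arithmetic, useful for spot-checking values (illustrative only):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Host reference of quant_helper with round-ties-to-even (round_type == 0).
int8_t QuantRef(float input, float scale,
                float max_bound = 127.0f, float min_bound = -127.0f) {
  float q = max_bound * scale * input;
  q = std::nearbyint(q);  // ties-to-even under the default rounding mode
  q = std::fmin(std::fmax(q, min_bound), max_bound);
  return static_cast<int8_t>(q);
}

int main() {
  const float scale = 1.0f / 127.0f;
  std::printf("%d %d\n",
              QuantRef(0.5f, scale),   // 127 * scale * 0.5 = 0.5 -> 0 (tie to even)
              QuantRef(1.0f, scale));  // 1.0 -> 1
  return 0;
}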
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Templates exposing architecture support for multiply-add operations
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace arch {
// Tag which triggers an MMA that first dequantizes the interleaved B operand
// to the type of operand A.
struct OpMultiplyAddDequantizeInterleavedBToA;
} // namespace arch
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda_runtime_api.h>
#include "cutlass/device_kernel.h"
#include "paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h"
namespace phi {
template <typename GemmKernel>
inline int compute_occupancy_for_kernel() {
int smem_size = static_cast<int>(sizeof(typename GemmKernel::SharedStorage));
if (smem_size > (48 << 10)) {
cudaError_t status =
cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (status == cudaError::cudaErrorInvalidValue) {
// Clear the error bit since we can ignore this.
// This should mean that smem_size >
// cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an
// occupancy of 0. This will cause the heuristic to ignore this
// configuration.
status = cudaGetLastError();
return 0;
}
check_cuda_error(status);
}
int max_active_blocks = -1;
check_cuda_error(
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks,
cutlass::Kernel<GemmKernel>,
GemmKernel::kThreadCount,
smem_size));
return max_active_blocks;
}
} // namespace phi
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
// define scaling mode
enum class QuantMode {
PerTensorQuant,
PerTokenQuant,
PerChannelQuant,
PerTokenChannelQuant
};
} // namespace epilogue
} // namespace cutlass
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Functor performing linear combination with a maximum operation used by
epilogues.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/half.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace thread {
/////////////////////////////////////////////////////////////////////////////////////////////////
__forceinline__ __device__ float copysignf_pos(float a, float b) {
float r;
r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
return r;
}
__forceinline__ __device__ float tanh_opt(float x) {
#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
const float exp_val = -1.f * fabs(2 * x);
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#else
return fast_tanh(x);
#endif
}
/////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct GELU_taylor<float> {
static const bool kIsHeavy = true;
CUTLASS_DEVICE
float operator()(float const& z) const {
float k0 = static_cast<float>(0.7978845608028654);
float k1 = static_cast<float>(0.044715);
return static_cast<float>(
cutlass::constants::half<float>() * z *
(cutlass::constants::one<float>() +
tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
}
using Params = LinearCombinationGenericParams<float>;
CUTLASS_DEVICE
float operator()(float const& scalar, Params const& params_) const {
return this->operator()(scalar);
}
};
} // namespace thread
} // namespace epilogue
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
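GELU_taylor<float> above computes the tanh approximation of GELU, 0.5 * z * (1 + tanh(k0 * z * (1 + k1 * z * z))) with k0 = sqrt(2/pi) and k1 = 0.044715. A small host reference for sanity checks (illustrative only):

#include <cmath>
#include <cstdio>

// Reference of the tanh-based GELU approximation used by GELU_taylor<float>.
float GeluTanhRef(float z) {
  const float k0 = 0.7978845608028654f;  // sqrt(2 / pi)
  const float k1 = 0.044715f;
  return 0.5f * z * (1.0f + std::tanh(k0 * z * (1.0f + k1 * z * z)));
}

int main() {
  std::printf("gelu(1.0) ~= %f\n", GeluTanhRef(1.0f));  // about 0.8412
  return 0;
}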
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one
scaling factor per row, and one per column.
original file:
3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "../epilogue_quant_helper.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/numeric_conversion.h"
namespace cutlass {
namespace epilogue {
namespace threadblock {
template <typename ThreadblockShape_,
int ThreadCount,
typename ScaleTileIterator_,
typename OutputTileIterator_,
typename ElementAccumulator_,
typename ElementCompute_,
typename ElementwiseFunctor_,
bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
public:
using ThreadblockShape = ThreadblockShape_;
static int const kThreadCount = ThreadCount;
using ScaleTileIterator = ScaleTileIterator_;
using OutputTileIterator = OutputTileIterator_;
using ElementwiseFunctor = ElementwiseFunctor_;
static int const kIterations = OutputTileIterator::kIterations;
static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;
using ElementOutput = typename OutputTileIterator::Element;
using LayoutOutput = cutlass::layout::RowMajor;
using ElementAccumulator = ElementAccumulator_;
using AlphaScaleElementType = typename ScaleTileIterator::Element;
using ElementCompute = ElementCompute_;
using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
using OutputVector = Array<ElementOutput, kElementsPerAccess>;
static int const kThreadsPerRow =
OutputTileIterator::ThreadMap::Detail::kAccessWidth;
static bool const kHasMultiStepsInRow =
(OutputTileIterator::ThreadMap::Iterations::kColumn > 1);
/// Argument structure
struct Arguments {
typename ElementwiseFunctor::Params elementwise;
int64_t batch_stride_alpha;
int64_t batch_stride_C;
int64_t batch_stride_D;
//
// Methods
//
Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}
explicit Arguments(typename ElementwiseFunctor::Params elementwise_)
: elementwise(elementwise_),
batch_stride_alpha(0),
batch_stride_C(0),
batch_stride_D(0) {}
Arguments(typename ElementwiseFunctor::Params elementwise_,
int64_t batch_stride_alpha_,
int64_t batch_stride_C_,
int64_t batch_stride_D_)
: elementwise(elementwise_),
batch_stride_alpha(batch_stride_alpha_),
batch_stride_C(batch_stride_C_),
batch_stride_D(batch_stride_D_) {}
};
struct Params {
typename ElementwiseFunctor::Params elementwise;
int64_t batch_stride_alpha;
int64_t batch_stride_C;
int64_t batch_stride_D;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() {}
CUTLASS_HOST_DEVICE
explicit Params(Arguments const& args)
: elementwise(args.elementwise),
batch_stride_alpha(args.batch_stride_alpha),
batch_stride_C(args.batch_stride_C),
batch_stride_D(args.batch_stride_D) {}
};
/// Shared storage
struct SharedStorage {};
private:
Params const& params_;
SharedStorage& shared_storage_;
MatrixCoord extent_;
MatrixCoord extent_real_;
ElementwiseFunctor elementwise_;
const bool per_token_quant_;
const bool per_channel_quant_;
AlphaScaleElementType* ptr_alpha_row_;
AlphaScaleElementType* ptr_alpha_col_;
ScaleTileIterator iterator_alpha_col_;
OutputTileIterator iterator_C_;
OutputTileIterator iterator_D_;
AlphaScaleElementType element_alpha_row_ = 1.0f;
AlphaScaleElementType element_alpha_col_ = 1.0f;
typename ScaleTileIterator::Fragment fragment_alpha_col_;
typename OutputTileIterator::Fragment fragment_C_;
typename OutputTileIterator::Fragment fragment_D_;
ElementAccumulator beta_;
int column_offset_;
MatrixCoord thread_offset_;
public:
CUTLASS_DEVICE
EpilogueVisitorPerRowPerCol(
Params const& params,
SharedStorage& shared_storage, // NOLINT
cutlass::MatrixCoord const& problem_size,
int thread_idx,
int warp_idx,
int lane_idx,
typename ScaleTileIterator::Params params_alpha_col,
typename OutputTileIterator::Params params_C,
typename OutputTileIterator::Params params_D,
QuantMode quant_mode,
AlphaScaleElementType* ptr_alpha_row,
AlphaScaleElementType* ptr_alpha_col,
typename OutputTileIterator::Element* ptr_C,
typename OutputTileIterator::Element* ptr_D,
cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0,
0),
int column_offset = 0,
cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0,
0))
: params_(params),
shared_storage_(shared_storage),
extent_(problem_size),
elementwise_(params.elementwise),
per_token_quant_(quant_mode == QuantMode::PerTokenQuant ||
quant_mode == QuantMode::PerTokenChannelQuant),
per_channel_quant_(quant_mode == QuantMode::PerChannelQuant ||
quant_mode == QuantMode::PerTokenChannelQuant),
ptr_alpha_row_(ptr_alpha_row),
ptr_alpha_col_(ptr_alpha_col),
iterator_alpha_col_(params_alpha_col,
ptr_alpha_col,
problem_size,
thread_idx,
threadblock_offset),
iterator_C_(
params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
iterator_D_(
params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
extent_real_(problem_size_real) {
beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr
: params.elementwise.beta);
if (beta_ == ElementAccumulator()) {
iterator_C_.clear_mask();
}
}
/// Helper to indicate split-K behavior
CUTLASS_DEVICE
void set_k_partition(
int split_k_index, ///< Index of this threadblock within split-K
///< partitioned scheme
int split_k_slices) { ///< Total number of split-K slices
}
/// Called to set the batch index
CUTLASS_DEVICE
void set_batch_index(int batch_idx) {
iterator_alpha_col_.add_pointer_offset(batch_idx *
params_.batch_stride_alpha);
iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
}
/// Called at the start of the epilogue just before iterating over accumulator
/// slices
CUTLASS_DEVICE
void begin_epilogue() {
if (per_channel_quant_) {
iterator_alpha_col_.load(fragment_alpha_col_);
} else if (ptr_alpha_col_ != nullptr) {
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
element_alpha_col_, ptr_alpha_col_, true);
}
if (!per_token_quant_ && ptr_alpha_row_ != nullptr) {
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
element_alpha_row_, ptr_alpha_row_, true);
}
}
/// Called at the start of one step before starting accumulator exchange
CUTLASS_DEVICE
void begin_step(int step_idx) {
fragment_D_.clear();
fragment_C_.clear();
if (elementwise_.kScale !=
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) {
iterator_C_.load(fragment_C_);
++iterator_C_;
}
// load alpha_row in begin_step only when per token(row) scaling is used
if (per_token_quant_) {
int thread_offset_row =
iterator_D_.thread_start_row() +
OutputTileIterator::ThreadMap::iteration_offset(0).row();
// element_alpha_row_ = ptr_alpha_row_[thread_offset_row];
arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
element_alpha_row_,
ptr_alpha_row_ + thread_offset_row,
thread_offset_row < extent_.row());
}
}
/// Called at the start of a row
CUTLASS_DEVICE
void begin_row(int row_idx) {
// Clear accumulators for max and sum when starting a whole row
}
/// Called after accumulators have been exchanged for each accumulator vector
CUTLASS_DEVICE
void visit(int iter_idx,
int row_idx,
int column_idx,
int frag_idx,
AccumulatorFragment const& accum) {
NumericArrayConverter<ElementCompute,
ElementAccumulator,
kElementsPerAccess>
source_converter;
ComputeFragment result = source_converter(accum);
if (per_channel_quant_) {
ComputeFragment alpha_col =
reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[frag_idx];
result = per_token_channel_scale_accumulator_(
result, alpha_col, element_alpha_row_);
} else {
result = per_token_scale_accumulator_(
result, element_alpha_col_, element_alpha_row_);
}
// Convert to the output
NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess>
output_converter;
OutputVector& output =
reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
output = output_converter(result);
}
/// Called after all accumulator elements have been visited
CUTLASS_DEVICE
void end_step(int step_idx) {
iterator_D_.store(fragment_D_);
++iterator_D_;
}
/// Called after all steps have been completed
CUTLASS_DEVICE
void end_epilogue() {}
private:
CUTLASS_DEVICE
ComputeFragment per_token_channel_scale_accumulator_(
ComputeFragment const& accum,
ComputeFragment const& scale_col,
AlphaScaleElementType const& scale_row) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i) {
result[i] = accum[i] * (scale_col[i] * scale_row);
}
return result;
}
CUTLASS_DEVICE
ComputeFragment per_token_scale_accumulator_(
ComputeFragment const& accum,
AlphaScaleElementType const& scale_col,
AlphaScaleElementType const& scale_row) {
ComputeFragment result;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < ComputeFragment::kElements; ++i) {
result[i] = accum[i] * (scale_col * scale_row);
}
return result;
}
};
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
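EpilogueVisitorPerRowPerCol dequantizes the int32 accumulators with one scale per row (token) and one per column (channel): out[i][j] = accum[i][j] * row_scale[i] * col_scale[j], where either factor collapses to a single scalar for the per-tensor modes. A plain reference loop for that epilogue math (a sketch, not from the diff):

#include <cstdint>
#include <vector>

// Reference of the per-row / per-column scaling applied by the visitor above.
std::vector<float> DequantEpilogueRef(const std::vector<int32_t>& accum,
                                      const std::vector<float>& row_scale,
                                      const std::vector<float>& col_scale,
                                      int m, int n) {
  std::vector<float> out(static_cast<size_t>(m) * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      out[i * n + j] =
          static_cast<float>(accum[i * n + j]) * row_scale[i] * col_scale[j];
    }
  }
  return out;
}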
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory
to match canonical tensor layouts in global memory. Epilogues support
conversion and reduction operations.
original file:
3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
#include "cutlass/layout/permute.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace epilogue {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
namespace detail {
/// Partial specialization for half <= int32_t x 8 epilogues avoids shared
/// memory bank conflicts.
template <typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename ThreadMap>
struct DefaultIteratorsTensorOp<cutlass::half_t,
int32_t,
8,
ThreadblockShape,
WarpShape,
InstructionShape,
ThreadMap> {
using WarpTileIterator =
cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape,
InstructionShape,
int32_t,
layout::RowMajor>;
using SharedLoadIterator =
cutlass::epilogue::threadblock::SharedLoadIterator<ThreadMap, int32_t>;
static int const kFragmentsPerIteration = 1;
};
/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared
/// memory bank conflicts.
#ifdef PADDLE_CUDA_BF16
template <typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename ThreadMap>
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t,
int32_t,
8,
ThreadblockShape,
WarpShape,
InstructionShape,
ThreadMap> {
using WarpTileIterator =
cutlass::epilogue::warp::TileIteratorTensorOp<WarpShape,
InstructionShape,
int32_t,
layout::RowMajor>;
using SharedLoadIterator =
cutlass::epilogue::threadblock::SharedLoadIterator<ThreadMap, int32_t>;
static int const kFragmentsPerIteration = 1;
};
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace detail
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tile iterator used to load output tile from shared memory in epilogue.
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_  ///< Thread map (concept: OutputTileThreadMap)
>
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8> {
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using Element = int32_t;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kAlignment =
ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;
static int const kThreads = ThreadMap::kThreads;
/// Fragment object
using Fragment =
Array<Element,
ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow *
ThreadMap::Iterations::kGroup *
ThreadMap::Iterations::kCluster *
ThreadMap::kElementsPerAccess>;
/// Memory access size
using AccessType =
AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;
/// Vector type used for SMEM loads
using LoadType = AlignedArray<Element,
const_min(128 / sizeof_bits<Element>::value,
ThreadMap::kElementsPerAccess),
const_min(16, kAlignment)>;
static int const kLoadsPerAccess =
AccessType::kElements / LoadType::kElements;
private:
//
// Data members
//
/// Byte-level pointer
LoadType const* pointers_[kLoadsPerAccess];
/// Stride along adjacent rows in units of LoadType
int stride_;
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
SharedLoadIteratorMixed(TensorRef ref, int thread_idx)
: stride_((ref.stride(0) / LoadType::kElements)) {
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);
// Initialize pointers
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i) {
pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());
int col_idx =
(thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
int bank_offset = (col_idx * static_cast<int>(sizeof(LoadType)) / 128) %
kLoadsPerAccess;
col_idx += (bank_offset + i) % kLoadsPerAccess;
pointers_[i] += thread_offset.row() * stride_ + col_idx;
}
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i) {
pointers_[i] += pointer_offset / LoadType::kElements;
}
}
CUTLASS_DEVICE
void add_tile_offset(TensorCoord const& offset) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kLoadsPerAccess; ++i) {
pointers_[i] += offset.row() * Shape::kRow * stride_ +
offset.column() * Shape::kColumn / LoadType::kElements;
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_pointer_offset(Fragment& frag, // NOLINT
Index pointer_offset) const {
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster;
++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ +
group * ThreadMap::Delta::kGroup * stride_ +
cluster * ThreadMap::Delta::kCluster * stride_ +
pointer_offset / LoadType::kElements;
int frag_row_idx =
(row + ThreadMap::Iterations::kRow *
(group + ThreadMap::Iterations::kGroup * cluster));
LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn;
++column) {
int frag_idx =
frag_row_idx * ThreadMap::Iterations::kColumn + column;
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < kLoadsPerAccess; ++v) {
int vector_idx = (column * ThreadMap::Delta::kColumn /
kElementsPerAccess * kLoadsPerAccess);
LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;
frag_ptr[frag_idx * kLoadsPerAccess + v] =
memory_pointer[vector_idx];
}
}
}
}
}
}
/// Loads a fragment
CUTLASS_DEVICE
void load(Fragment& frag) const { // NOLINT
load_with_pointer_offset(frag, 0);
}
};
} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/**
* @file epilogue_helpers.h
*
* This file includes types for the epilogues. The empty structs exist so we can
* signal to template code the type of epilogue we want to run, and let the
* underlying code specify the details such as element types, accumulator type
* and elements per vector access.
*
*/
#pragma once
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "epilogue/thread/ft_fused_activations.h"
namespace phi {
struct EpilogueOpBiasSilu {};
struct EpilogueOpBiasReLU {};
struct EpilogueOpBiasFtGelu {};
struct EpilogueOpBias {};
struct EpilogueOpNoBias {};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator,
typename Op>
struct Epilogue {};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator>
struct Epilogue<ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
EpilogueOpBiasSilu> {
using Op = cutlass::epilogue::thread::LinearCombinationSilu<
ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator>
struct Epilogue<ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
EpilogueOpBiasReLU> {
using Op = cutlass::epilogue::thread::LinearCombinationRelu<
ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator>
struct Epilogue<ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
EpilogueOpBiasFtGelu> {
using Op = cutlass::epilogue::thread::LinearCombinationGeneric<
cutlass::epilogue::thread::GELU_taylor,
ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::NoBetaScaling,
cutlass::FloatRoundStyle::round_to_nearest,
true>;
};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator>
struct Epilogue<ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
EpilogueOpBias> {
using Op = cutlass::epilogue::thread::LinearCombination<
ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::NoBetaScaling>;
};
template <typename ElementType,
int ElementsPerVectorAccess,
typename ElementAccumulator>
struct Epilogue<ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
EpilogueOpNoBias> {
using Op = cutlass::epilogue::thread::LinearCombination<
ElementType,
ElementsPerVectorAccess,
ElementAccumulator,
ElementAccumulator,
cutlass::epilogue::thread::ScaleType::Default>;
};
} // namespace phi
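The Epilogue trait above resolves at compile time to a cutlass thread-level epilogue functor. A usage sketch with illustrative template arguments:

// Sketch: fp16 output, 8 elements per vector access, float accumulators,
// bias + ReLU fused in the epilogue.
using ExampleEpilogueOp =
    typename phi::Epilogue<cutlass::half_t,
                           8,
                           float,
                           phi::EpilogueOpBiasReLU>::Op;
// ExampleEpilogueOp is cutlass::epilogue::thread::LinearCombinationRelu<...>
// with NoBetaScaling and can be plugged into a cutlass GEMM kernel definition.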
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace phi {
// Note: The shapes are in the format MxNxK. The K shape of the runtime config
// MUST match the K shape in the kernel layout details when doing weight-only
// quantization.
enum class CutlassTileConfig {
  // Signals that we should run heuristics to choose a config
  Undefined,
  // Signals that we should run heuristics to choose a config
ChooseWithHeuristic,
  // SIMT config
CtaShape128x128x8_WarpShape64x64x8,
// TensorCore configs CTA_N = 128, CTA_K = 64
// Warp configs for M=32
CtaShape32x128x64_WarpShape32x32x64,
// Warp configs for M=64
CtaShape64x128x64_WarpShape32x64x64,
CtaShape64x128x64_WarpShape64x32x64,
// Warp configs for M=128
CtaShape128x128x64_WarpShape64x32x64,
CtaShape128x128x64_WarpShape128x32x64,
// configs for large M in encoder
CtaShape128x256x64_WarpShape64x64x64,
CtaShape256x128x64_WarpShape64x64x64
};
enum class SplitKStyle {
NO_SPLIT_K,
SPLIT_K_SERIAL,
// SPLIT_K_PARALLEL // Not supported yet
};
struct CutlassGemmConfig {
CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
int split_k_factor = -1;
int stages = -1;
};
} // namespace phi
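CutlassGemmConfig bundles the knobs the mixed-GEMM dispatch needs: a CTA/warp tile shape, a split-K strategy, and a pipeline stage count. A construction sketch with illustrative values (not a recommendation taken from this diff):

// Sketch: a fixed Ampere-style config with split-K disabled.
inline phi::CutlassGemmConfig MakeExampleGemmConfig() {
  phi::CutlassGemmConfig config;
  config.tile_config =
      phi::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64;
  config.split_k_style = phi::SplitKStyle::NO_SPLIT_K;
  config.split_k_factor = 1;
  config.stages = 3;
  return config;
}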
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
namespace cutlass {
namespace gemm {
namespace kernel {
template <typename TypeA, typename TypeB, typename arch, typename Enable = void>
struct MixedGemmArchTraits {};
template <typename arch>
struct MixedGemmArchTraits<float, float, arch> {
static constexpr int Stages = 2;
using OperatorClass = cutlass::arch::OpClassSimt;
using AccType = float;
using LayoutB = cutlass::layout::RowMajor;
static constexpr int ElementsPerAccessA = 1;
static constexpr int ElementsPerAccessB = 1;
static constexpr int ElementsPerAccessC = 1;
static constexpr int ThreadblockK = 8;
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// ========================= Volta Traits ===========================
// Volta will always dequantize after the global memory load.
// This will instantiate any HMMA tensorcore kernels for Volta.
// Note that Volta does not have native bfloat16 support, so weights and
// activations will be cast to fp16; compute happens in fp16 and the result is
// converted back for bf16 output.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<
TypeA,
TypeB,
cutlass::arch::Sm70,
typename cutlass::platform::enable_if<
cutlass::platform::is_same<TypeA, cutlass::half_t>::value ||
cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
private:
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm70>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA =
128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC =
128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
using Operator = typename LayoutDetails::Operator;
};
// ======================= Turing Traits ==============================
// Note that Turing does not have native bfloat16 support, so weights and
// activations will be cast to fp16; compute happens in fp16 and the result is
// converted back for bf16 output.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<
TypeA,
TypeB,
cutlass::arch::Sm75,
typename cutlass::platform::enable_if<
cutlass::platform::is_same<TypeA, cutlass::half_t>::value ||
cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
private:
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm75>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA =
128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC =
128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using Operator = typename LayoutDetails::Operator;
};
// ======================= Ampere Traits ==============================
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<
TypeA,
TypeB,
cutlass::arch::Sm80,
typename cutlass::platform::enable_if<
cutlass::platform::is_same<TypeA, cutlass::half_t>::value ||
cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type> {
private:
using LayoutDetails = LayoutDetailsB<TypeB, cutlass::arch::Sm80>;
public:
static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using AccType = float;
using LayoutB = typename LayoutDetails::Layout;
static constexpr int ElementsPerAccessA =
128 / cutlass::sizeof_bits<TypeA>::value;
static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
static constexpr int ElementsPerAccessC =
128 / cutlass::sizeof_bits<TypeA>::value;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
using Operator = typename LayoutDetails::Operator;
};
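// Illustrative check only (not from the original header): for fp16 activations
// with int8 weights on Ampere, the traits above resolve as follows. The values
// are derived from LayoutDetailsB and sizeof_bits as defined above; the alias
// name is made up for illustration.
using ExampleSm80Int8Traits =
    MixedGemmArchTraits<cutlass::half_t, uint8_t, cutlass::arch::Sm80>;
static_assert(ExampleSm80Int8Traits::ThreadblockK == 64,
              "int8 weight-only kernels on Sm80 expect ThreadblockK == 64");
static_assert(ExampleSm80Int8Traits::ElementsPerAccessA == 8,
              "fp16 activations load 128 bits, i.e. 8 elements, per access");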
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Template for a pipelined GEMM kernel. Does not compute batching or
support split-K.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/semaphore.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {
/////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate
typename Epilogue_, ///! Epilogue
typename ThreadblockSwizzle_, ///! Threadblock swizzling function
typename KernelArch, ///! The Architecture this kernel is compiled
/// for. Used since SIMT kernels lose top-level
/// arch.
bool SplitKSerial ///! If true, code supporting split-K via serial
/// reduction is enabled.
>
struct GemmFpAIntB {
using Mma = Mma_;
using Epilogue = Epilogue_;
using EpilogueOutputOp = typename Epilogue::OutputOp;
using ThreadblockSwizzle = ThreadblockSwizzle_;
static bool const kSplitKSerial = SplitKSerial;
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
  using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Mma::LayoutC;
using ElementScale = float;
static ComplexTransform const kTransformA = Mma::kTransformA;
  static ComplexTransform const kTransformB = Mma::kTransformB;
// Type definitions about the mainloop.
using Operator = typename Mma::Operator;
using OperatorClass = typename Mma::Operator::OperatorClass;
using ThreadblockShape = typename Mma::Shape;
using WarpShape = typename Mma::Operator::Shape;
using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
using ArchTag = typename Mma::ArchTag;
static int const kStages = Mma::kStages;
static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
static int const kAlignmentC =
Epilogue::OutputTileIterator::kElementsPerAccess;
/// Warp count (concept: GemmShape)
using WarpCount = typename Mma::WarpCount;
static int const kThreadCount = 32 * WarpCount::kCount;
static int const kSplitKAlignment = const_max(
128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);
static constexpr int kInterleave =
Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
/// Parameters structure
struct Arguments : UniversalArgumentsBase {
GemmUniversalMode mode = GemmUniversalMode::kGemm;
cutlass::gemm::GemmCoord problem_size;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::TensorRef ref_B;
typename Mma::IteratorScale::TensorRef ref_scale;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
// Control serial split-k
int batch_count;
typename EpilogueOutputOp::Params output_op;
// For gather+scatter operations
int const* gather_A_indices;
int const* gather_B_indices;
int const* scatter_D_indices;
// Included so we can use Gemm Universal
int batch_stride_D = 0;
//
// Methods
//
CUTLASS_HOST_DEVICE
Arguments() {}
CUTLASS_HOST_DEVICE
Arguments(cutlass::gemm::GemmCoord const& problem_size,
typename Mma::IteratorA::TensorRef ref_A,
typename Mma::IteratorB::TensorRef ref_B,
typename Mma::IteratorScale::TensorRef ref_scale,
typename Epilogue::OutputTileIterator::TensorRef ref_C,
typename Epilogue::OutputTileIterator::TensorRef ref_D,
int serial_split_k_factor,
typename EpilogueOutputOp::Params output_op =
typename EpilogueOutputOp::Params(),
int const* gather_A_indices = nullptr,
int const* gather_B_indices = nullptr,
int const* scatter_D_indices = nullptr)
: UniversalArgumentsBase(mode,
problem_size,
/*serial_split_k_factor=*/1,
/*batch_stride_D=*/0),
ref_A(ref_A),
ref_B(ref_B),
ref_scale(ref_scale),
ref_C(ref_C),
ref_D(ref_D),
output_op(output_op),
gather_A_indices(gather_A_indices),
gather_B_indices(gather_B_indices),
scatter_D_indices(scatter_D_indices) {}
};
/// Parameters structure
struct Params : UniversalParamsBase<ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC> {
using ParamsBase = UniversalParamsBase<ThreadblockSwizzle,
ThreadblockShape,
ElementA,
ElementB,
ElementC>;
typename Mma::IteratorA::Params params_A;
typename Mma::IteratorB::Params params_B;
typename Mma::IteratorScale::Params params_scale;
typename Epilogue::OutputTileIterator::Params params_C;
typename Epilogue::OutputTileIterator::Params params_D;
typename Mma::IteratorA::TensorRef ref_A;
typename Mma::IteratorB::TensorRef ref_B;
typename Mma::IteratorScale::TensorRef ref_scale;
typename Epilogue::OutputTileIterator::TensorRef ref_C;
typename Epilogue::OutputTileIterator::TensorRef ref_D;
typename EpilogueOutputOp::Params output_op;
// For gather+scatter operations
int const* gather_A_indices;
int const* gather_B_indices;
int const* scatter_D_indices;
//
// Methods
//
CUTLASS_HOST_DEVICE
Params() = default;
CUTLASS_HOST_DEVICE
Params(Arguments const& args, int device_sms, int sm_occupancy)
: ParamsBase(args, device_sms, sm_occupancy),
params_A(args.ref_A.layout()),
ref_A(args.ref_A),
params_B(args.ref_B.layout()),
ref_B(args.ref_B),
params_scale(args.ref_scale.layout()),
ref_scale(args.ref_scale),
params_C(args.ref_C.layout()),
ref_C(args.ref_C),
params_D(args.ref_D.layout()),
ref_D(args.ref_D),
output_op(args.output_op),
gather_A_indices(args.gather_A_indices),
gather_B_indices(args.gather_B_indices),
scatter_D_indices(args.scatter_D_indices) {}
};
/// Shared memory storage structure
union SharedStorage {
typename Mma::SharedStorage main_loop;
typename Epilogue::SharedStorage epilogue;
};
//
// Methods
//
CUTLASS_HOST_DEVICE
GemmFpAIntB() {}
/// Determines whether kernel satisfies alignment
CUTLASS_HOST_DEVICE
static Status can_implement(Arguments const& args) {
static int const kAlignmentA =
(platform::is_same<typename Mma::IteratorA::Layout,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Mma::IteratorA::Layout,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Mma::IteratorA::AccessType::kElements;
static int const kAlignmentB =
(platform::is_same<typename Mma::IteratorB::Layout,
layout::RowMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Mma::IteratorB::Layout,
layout::RowMajorInterleaved<64>>::value)
? 64
: Mma::IteratorB::AccessType::kElements;
static int const kAlignmentScale =
Mma::IteratorScale::AccessType::kElements;
static int const kAlignmentC =
(platform::is_same<typename Epilogue::OutputTileIterator::Layout,
layout::ColumnMajorInterleaved<32>>::value)
? 32
: (platform::is_same<typename Epilogue::OutputTileIterator::Layout,
layout::ColumnMajorInterleaved<64>>::value)
? 64
: Epilogue::OutputTileIterator::kElementsPerAccess;
if (!TensorRef_aligned(args.ref_A, kAlignmentA)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_B, kAlignmentB)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_C, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
if (!TensorRef_aligned(args.ref_D, kAlignmentC)) {
return Status::kErrorMisalignedOperand;
}
return Status::kSuccess;
}
static size_t get_extra_workspace_size(
Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) {
return 0;
}
CUTLASS_DEVICE
static void invoke(Params const& params,
SharedStorage& shared_storage) { // NOLINT
GemmFpAIntB op;
op(params, shared_storage);
}
  // The dummy template parameter is not used and exists so that we can compile
  // this code with a C++ standard earlier than C++17. Prior to C++17, fully
  // specialized templates HAD to exist in a namespace.
template <bool B, typename dummy = void>
struct KernelRunner {
CUTLASS_DEVICE
static void run_kernel(Params const& params,
SharedStorage& shared_storage) { // NOLINT
CUTLASS_NOT_IMPLEMENTED();
}
};
template <typename dummy>
struct KernelRunner<true, dummy> {
CUTLASS_DEVICE
static void run_kernel(Params const& params,
SharedStorage& shared_storage) { // NOLINT
using LayoutB = typename Mma::IteratorB::Layout;
static_assert(
platform::is_same<LayoutB, layout::RowMajor>::value &&
kInterleave == 1 ||
platform::is_same<LayoutB, layout::ColumnMajor>::value &&
kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
return;
}
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size,
};
cutlass::MatrixCoord tb_offset_B{
threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
cutlass::MatrixCoord tb_offset_scale{
0, threadblock_tile_offset.n() * Mma::Shape::kN};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k =
min(params.problem_size.k(),
(threadblock_tile_offset.k() + 1) * params.gemm_k_size);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations =
(problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) /
Mma::Shape::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
params.params_A,
params.ref_A.data(),
{params.problem_size.m(), problem_size_k},
thread_idx,
tb_offset_A,
params.gather_A_indices);
typename Mma::IteratorB iterator_B(
params.params_B,
params.ref_B.data(),
{problem_size_k * kInterleave, params.problem_size.n() / kInterleave},
thread_idx,
tb_offset_B,
params.gather_B_indices);
typename Mma::IteratorScale iterator_scale(params.params_scale,
params.ref_scale.data(),
{1, params.problem_size.n()},
thread_idx,
tb_offset_scale);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations > 0) {
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations,
accumulators,
iterator_A,
iterator_B,
iterator_scale,
accumulators);
}
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset =
threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() +
threadblock_tile_offset.n() * params.grid_tiled_shape.m();
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial
// synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is
// currently updating
output_op.set_k_partition(threadblock_tile_offset.k(),
params.grid_tiled_shape.k());
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params.params_C,
params.ref_C.data(),
params.problem_size.mn(),
thread_idx,
threadblock_offset,
params.scatter_D_indices);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params.params_D,
params.ref_D.data(),
params.problem_size.mn(),
thread_idx,
threadblock_offset,
params.scatter_D_indices);
Epilogue epilogue(
shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator
// construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
// For subsequent threadblocks, the source matrix is held in the 'D'
// tensor.
if (threadblock_tile_offset.k()) {
iterator_C = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1) {
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) {
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
} else {
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
}
};
/*
To improve compilation speed, we do not compile the device operator if the
CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel
operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params,
SharedStorage& shared_storage) { // NOLINT
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
// static constexpr bool compile_needed = platform::is_same<KernelArch,
// arch::Sm70>::value; KernelRunner<compile_needed>::run_kernel(params,
// shared_storage);
CUTLASS_NOT_IMPLEMENTED();
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
// static constexpr bool compile_needed = platform::is_same<KernelArch,
// arch::Sm75>::value; KernelRunner<compile_needed>::run_kernel(params,
// shared_storage);
CUTLASS_NOT_IMPLEMENTED();
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
static constexpr bool compile_needed =
platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
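// Illustrative host-side flow (a sketch under assumed concrete template
// arguments; not code from this patch). A driver would typically instantiate
// the kernel, validate alignment, build Params, and launch a __global__ wrapper
// that forwards to invoke():
//
//   using GemmKernel = GemmFpAIntB<Mma, Epilogue, ThreadblockSwizzle,
//                                  cutlass::arch::Sm80, /*SplitKSerial=*/true>;
//   typename GemmKernel::Arguments args(problem_size, ref_A, ref_B, ref_scale,
//                                       ref_C, ref_D,
//                                       /*serial_split_k_factor=*/1);
//   if (GemmKernel::can_implement(args) == cutlass::Status::kSuccess) {
//     typename GemmKernel::Params params(args, device_sms, sm_occupancy);
//     // <<<grid, block, smem>>> launch of a wrapper that calls
//     // GemmKernel::invoke(params, shared_storage)
//   }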
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
This file exists so that we use the same weight layout for MoE grouped gemm
and regular gemm when the weight is quantized. The preprocessing code reads
this template to know how to organize the quantized weight matrices to be
consumed by CUTLASS.
  Note that for int4, ThreadblockK MUST be 64.
*/
#pragma once
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/platform/platform.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h"
namespace cutlass {
namespace gemm {
namespace kernel {
template <typename TypeB, typename Arch, typename Enable = void>
struct LayoutDetailsB {};
// Volta specializations. Volta will dequantize before STS, so we need a
// different operator.
template <typename TypeB>
struct LayoutDetailsB<TypeB, arch::Sm70> {
static constexpr int ThreadblockK = 64;
using Layout = layout::RowMajor;
static constexpr int ElementsPerAccess = 8;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specializations for Turing+ when B is FP16. These are currently only used for
// MoE networks.
template <typename Arch>
struct LayoutDetailsB<
half_t,
Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
static constexpr int ThreadblockK = 64;
using Layout = layout::RowMajor;
static constexpr int ElementsPerAccess =
128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename Arch>
struct LayoutDetailsB<
bfloat16_t,
Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
static constexpr int ThreadblockK = 64;
using Layout = layout::RowMajor;
static constexpr int ElementsPerAccess =
128 / cutlass::sizeof_bits<bfloat16_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
// Specializations for Turing+ when B is quantized. These can use the operator
// OpMultiplyAddDequantizeInterleavedBToA, which signals that we want to
// dequantize after loading from smem.
template <typename Arch>
struct LayoutDetailsB<
uint8_t,
Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
static constexpr int ThreadblockK = 64;
private:
static constexpr int ElementsPerCacheLine =
128 * 8 / sizeof_bits<uint8_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
public:
using Layout =
layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
static constexpr int ElementsPerAccess =
128 / cutlass::sizeof_bits<uint8_t>::value;
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
template <typename Arch>
struct LayoutDetailsB<
uint4b_t,
Arch,
typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type> {
static constexpr int ThreadblockK = 64;
private:
static constexpr int ElementsPerCacheLine =
128 * 8 / sizeof_bits<uint4b_t>::value;
static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK;
public:
using Layout =
layout::ColumnMajorTileInterleave<ThreadblockK, ColumnsInterleaved>;
static constexpr int ElementsPerAccess =
128 / cutlass::sizeof_bits<uint4b_t>::value;
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
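// Illustrative checks only (not part of the original header): with the
// cache-line arithmetic above, int8 weights interleave 2 columns per tile and
// int4 weights interleave 4, both with ThreadblockK = 64. These asserts assume
// ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved> as declared in
// tile_interleaved_layout.h.
static_assert(
    platform::is_same<LayoutDetailsB<uint8_t, arch::Sm80>::Layout,
                      layout::ColumnMajorTileInterleave<64, 2>>::value,
    "int8 B tiles are expected to interleave 2 columns");
static_assert(
    platform::is_same<LayoutDetailsB<uint4b_t, arch::Sm80>::Layout,
                      layout::ColumnMajorTileInterleave<64, 4>>::value,
    "int4 B tiles are expected to interleave 4 columns");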
} // namespace kernel
} // namespace gemm
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/interleaved_numeric_conversion.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
// We need to distinguish the two cases here because we want Volta support. It
// is too much effort to write the shared memory iterators Volta would need to
// dequantize after LDS, so we allow converters both after the LDG (for Volta)
// and after the LDS (for Turing+). See the illustrative summary after the
// SetConverters specializations below.
template <
/// Iterator for B matrix in global memory
typename IteratorB,
/// Warp level Mma
typename MmaOperator,
/// Math operation perform by warp level operator
typename MathOperator>
struct SetConverters {};
// Dequantize after LDG, so set transforms accordingly
template <
/// Iterator for B matrix in global memory
typename IteratorB,
/// Mma Policy
typename MmaOperator>
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd> {
using TransformAfterLDG = FastInterleavedAndBiasedNumericArrayConverter<
typename MmaOperator::ArchMmaOperator::ElementB,
typename IteratorB::Element,
IteratorB::Fragment::kElements>;
using TransformAfterLDS =
NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
typename MmaOperator::ArchMmaOperator::ElementB,
MmaOperator::FragmentB::kElements>;
};
// Dequantize after LDS, so set transforms accordingly
template <
/// Iterator for B matrix in global memory
typename IteratorB,
/// Mma Policy
typename MmaOperator>
struct SetConverters<IteratorB,
MmaOperator,
arch::OpMultiplyAddDequantizeInterleavedBToA> {
using TransformAfterLDG =
NumericArrayConverter<typename IteratorB::Element,
typename IteratorB::Element,
IteratorB::Fragment::kElements>;
using TransformAfterLDS = FastInterleavedAndBiasedNumericArrayConverter<
typename MmaOperator::ArchMmaOperator::ElementB,
typename TransformAfterLDG::result_type::Element,
MmaOperator::FragmentB::kElements>;
};
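// Illustrative summary (not from the original header) of which transform does
// the real work in each case above:
//
//   SetConverters<IteratorB, MmaOp, arch::OpMultiplyAdd>
//     TransformAfterLDG : FastInterleavedAndBiasedNumericArrayConverter
//                         (dequantize int8/int4 -> fp16 right after the global
//                          load, as Volta requires)
//     TransformAfterLDS : NumericArrayConverter over identical types (identity)
//
//   SetConverters<IteratorB, MmaOp, arch::OpMultiplyAddDequantizeInterleavedBToA>
//     TransformAfterLDG : NumericArrayConverter over identical types (identity)
//     TransformAfterLDS : FastInterleavedAndBiasedNumericArrayConverter
//                         (dequantize after the shared memory load, Turing+)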
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for A matrix operand
typename ElementA_,
/// Layout type for A matrix operand
typename LayoutA_,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Element type for B matrix operand
typename ElementB_,
/// Layout type for B matrix operand
typename LayoutB_,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale_,
/// Layout for the scale operand
typename LayoutScale_,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator_,
/// Layout type for C and D matrix operands
typename LayoutC_,
/// Operator class tag
typename OperatorClass_,
/// Tag indicating architecture to tune for
typename ArchTag_,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape_,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape_,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape_,
/// Number of stages used in the pipelined mainloop
int Stages,
/// Operation performed by GEMM
typename Operator_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
///
typename Enable = void>
struct DqMma;
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
template <
/// Type for elementA
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Stages in GEMM
int kStages,
///
typename Operator,
///
SharedMemoryClearOption SharedMemoryClear>
struct DqMma<ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementScale,
LayoutScale,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear,
typename platform::enable_if<(ArchTag::kMinComputeCapability >=
80)>::type> {
static_assert(platform::is_same<ElementA, half_t>::value ||
platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(
platform::is_same<Operator,
arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
static_assert(platform::is_same<ElementB, uint8_t>::value ||
platform::is_same<ElementB, uint4b_t>::value,
"Element B must be int8 or uint4");
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
  // MmaCore does not depend on stages, so pass in at least 3 here so that the
  // mma multistage pieces are created.
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
std::max(kStages, 3),
Operator,
false,
CacheOpA,
CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
LayoutA,
1,
ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
LayoutB,
0,
ThreadMapB,
AccessTypeB>;
// ThreadMap for scale iterator
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
MmaCore::Shape::kN / kAlignmentScale,
kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaCore::Shape::kN>,
ElementScale,
LayoutScale,
0,
IteratorScaleThreadMap,
kAlignmentScale>;
using SmemIteratorScale = IteratorScale;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<
ElementA,
ElementB,
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<
typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA,
IteratorB,
typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB,
IteratorScale,
SmemIteratorScale,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
kStages,
Converter,
SharedMemoryClear>;
};
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Stages in GEMM
int kStages,
///
typename Operator,
///
SharedMemoryClearOption SharedMemoryClear,
///
int RowsPerTile,
///
int ColumnsInterleaved>
struct DqMma<ElementA,
LayoutA,
kAlignmentA,
ElementB,
layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>,
kAlignmentB,
ElementScale,
LayoutScale,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear,
typename platform::enable_if<(ArchTag::kMinComputeCapability >=
80)>::type> {
static_assert(platform::is_same<ElementA, half_t>::value ||
platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(
platform::is_same<Operator,
arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
static_assert(platform::is_same<ElementB, uint8_t>::value ||
platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
  // MmaCore does not depend on stages, so pass in at least 3 here so that the
  // mma multistage pieces are created.
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
ElementA,
LayoutA,
ElementB,
layout::ColumnMajor,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
std::max(kStages, 3),
Operator,
false,
CacheOpA,
CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
LayoutA,
1,
ThreadMapA,
AccessTypeA>;
private:
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
using OriginalWarpArrangement =
typename OriginalThreadMap::Detail::WarpThreadArrangement;
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
using GmemIteratorShape =
MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved,
MmaCore::Shape::kN / ColumnsInterleaved>;
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<GmemIteratorShape::kRow,
GmemIteratorShape::kColumn>,
OriginalThreadMap::kThreads,
layout::PitchLinearShape<
OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
public:
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
GmemIteratorShape,
ElementB,
layout::ColumnMajor,
0,
GmemThreadMapB,
AccessTypeB>;
// ThreadMap for scale iterator
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
MmaCore::Shape::kN / kAlignmentScale,
kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaCore::Shape::kN>,
ElementScale,
LayoutScale,
0,
IteratorScaleThreadMap,
kAlignmentScale>;
using SmemIteratorScale = IteratorScale;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<
ElementA,
ElementB,
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<
typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA,
IteratorB,
typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB,
IteratorScale,
SmemIteratorScale,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
kStages,
Converter,
SharedMemoryClear>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "cutlass/gemm/threadblock/default_mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/default_mma_tensor_op.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DqMma<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementScale,
LayoutScale,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator,
SharedMemoryClearOption::kNone,
typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {
static_assert(platform::is_same<ElementA, half_t>::value ||
platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementB, uint8_t>::value ||
platform::is_same<ElementB, uint4b_t>::value,
"Element B must be int8 or uint4");
static constexpr bool DqAfterLDG =
platform::is_same<arch::OpMultiplyAdd, Operator>::value;
static constexpr bool arch_has_bf16_mma =
ArchTag::kMinComputeCapability >= 80;
using MmaCoreElementA =
typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
using MmaCoreElementB = typename platform::
conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
MmaCoreElementA,
LayoutA,
MmaCoreElementB,
LayoutB,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
2,
Operator>;
// Define iterators over tiles from the A operand
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
ElementA,
LayoutA,
1,
typename MmaCore::IteratorThreadMapA,
kAlignmentA>;
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
ElementB,
LayoutB,
0,
typename MmaCore::IteratorThreadMapB,
kAlignmentB>;
// ThreadMap for scale iterator
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
MmaCore::Shape::kN / kAlignmentScale,
kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaCore::Shape::kN>,
ElementScale,
LayoutScale,
0,
IteratorScaleThreadMap,
kAlignmentScale>;
using SmemScaleType = typename platform::
conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
using SmemIteratorScale =
cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaCore::Shape::kN>,
SmemScaleType,
LayoutScale,
0,
IteratorScaleThreadMap,
kAlignmentScale>;
using Converters =
SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<
typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
IteratorB,
typename MmaCore::SmemIteratorB,
IteratorScale,
SmemIteratorScale,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
typename Converters::TransformAfterLDG,
typename Converters::TransformAfterLDS>;
};
// Specialization to handle column major interleave B
template <
/// Type for element A
typename ElementA,
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Type for element B
typename ElementB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for the input scale
typename ElementScale,
/// Layout for the scale operand
typename LayoutScale,
/// Access granularity of Scales in unit of elements
int kAlignmentScale,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Operator class tag
typename OperatorClass,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int RowsPerTile,
///
int ColumnsInterleaved>
struct DqMma<
ElementA,
LayoutA,
kAlignmentA,
ElementB,
layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>,
kAlignmentB,
ElementScale,
LayoutScale,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator,
SharedMemoryClearOption::kNone,
typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> {
static_assert(platform::is_same<ElementA, half_t>::value ||
platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementB, uint8_t>::value ||
platform::is_same<ElementB, uint4b_t>::value,
"Element B must be int8 or uint4");
static constexpr bool DqAfterLDG =
platform::is_same<arch::OpMultiplyAdd, Operator>::value;
static constexpr bool arch_has_bf16_mma =
ArchTag::kMinComputeCapability >= 80;
using MmaCoreElementA =
typename platform::conditional<arch_has_bf16_mma, ElementA, half_t>::type;
using MmaCoreElementB = typename platform::
conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
MmaCoreElementA,
LayoutA,
MmaCoreElementB,
layout::ColumnMajor,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
2,
Operator>;
// Define iterators over tiles from the A operand
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
ElementA,
LayoutA,
1,
typename MmaCore::IteratorThreadMapA,
kAlignmentA>;
private:
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
using OriginalWarpArrangement =
typename OriginalThreadMap::Detail::WarpThreadArrangement;
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
using GmemIteratorShape =
MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved,
MmaCore::Shape::kN / ColumnsInterleaved>;
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<GmemIteratorShape::kRow,
GmemIteratorShape::kColumn>,
OriginalThreadMap::kThreads,
layout::PitchLinearShape<
OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
public:
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
GmemIteratorShape,
ElementB,
layout::ColumnMajor,
0,
GmemThreadMapB,
kAlignmentB>;
// ThreadMap for scale iterator
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
MmaCore::Shape::kN / kAlignmentScale,
kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaCore::Shape::kN>,
ElementScale,
LayoutScale,
0,
IteratorScaleThreadMap,
kAlignmentScale>;
using SmemScaleType = typename platform::
conditional<arch_has_bf16_mma, ElementScale, half_t>::type;
using SmemIteratorScale =
cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaCore::Shape::kN>,
SmemScaleType,
LayoutScale,
0,
IteratorScaleThreadMap,
kAlignmentScale>;
using Converters =
SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<
typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
IteratorB,
typename MmaCore::SmemIteratorB,
IteratorScale,
SmemIteratorScale,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
typename Converters::TransformAfterLDG,
typename Converters::TransformAfterLDS>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma_bf16.h"
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16
/// activation & int8 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<float>::value;
using Mma = DqMma<half_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
float,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16
/// activation & int4 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<float>::value;
using Mma = DqMma<half_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
float,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<float>::value;
using Mma = DqMma<half_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
float,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16
/// activation & int4 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Tag indicating architecture to tune for
typename ArchTag,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
///
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<float>::value;
using Mma = DqMma<half_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
float,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
// fp16 x fp16 specialization on Ampere that uses mma multistage even for the
// 2-stage case. This helps avoid register spills on large tiles when there is
// not enough shared memory for 3+ stages.
template <
/// Layout type for A matrix operand
typename LayoutA,
/// Access granularity of A matrix in units of elements
int kAlignmentA,
/// Layout type for B matrix operand
typename LayoutB,
/// Access granularity of B matrix in units of elements
int kAlignmentB,
/// Element type for internal accumulation
typename ElementAccumulator,
/// Threadblock-level tile size (concept: GemmShape)
typename ThreadblockShape,
/// Warp-level tile size (concept: GemmShape)
typename WarpShape,
/// Instruction-level tile size (concept: GemmShape)
typename InstructionShape,
/// Operation performed by GEMM
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear,
/// Gather operand A by using an index array
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<half_t,
LayoutA,
kAlignmentA,
half_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator,
false,
SharedMemoryClear,
GatherA,
GatherB> {
// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
half_t,
LayoutA,
half_t,
LayoutB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
3,
Operator>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
half_t,
LayoutA,
1,
ThreadMapA,
AccessTypeA,
GatherA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
half_t,
LayoutB,
0,
ThreadMapB,
AccessTypeB,
GatherB>;
// Define the threadblock-scoped multistage matrix multiply
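// Note: MmaCore above is built with 3 stages only to select the multistage
// components; the pipeline itself is instantiated with Stages = 2 (last
// template argument below).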
using ThreadblockMma =
cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA,
IteratorB,
typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
2>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Template for a double-buffered threadblock-scoped GEMM kernel.
*/
#pragma once
#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_base.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
// SFINAE trick so I can keep the same loop code for Volta and dispatch to the
// correct warp-level mma. On Volta, all data is stored to shared memory as
// FP16.
template <typename WarpMma, int kExpansionFactor = 1>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, // NOLINT
typename WarpMma::FragmentC& D, // NOLINT
typename WarpMma::FragmentA const& A,
typename WarpMma::FragmentB const& B,
typename WarpMma::FragmentC const& C,
const int warp_tileB_k_offset) {
warp_mma(D, A, B, C);
}
template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
CUTLASS_DEVICE void run_warp_mma(
WarpMma& warp_mma, // NOLINT
typename WarpMma::FragmentC& D, // NOLINT
typename WarpMma::TransformedFragmentA const& A,
typename WarpMma::TransformedFragmentB const& B,
typename WarpMma::FragmentC const& C,
const int warp_tileB_k_offset) {
warp_mma(D, A, B, C, warp_tileB_k_offset);
}
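// The first overload (default kExpansionFactor = 1) is picked for warp MMAs
// that do not define kExpansionFactor (e.g. the Volta path) and ignores
// warp_tileB_k_offset; the second is enabled via SFINAE only when the warp MMA
// exposes kExpansionFactor and needs the B sub-tile offset.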
////////////////////////////////////////////////////////////////////////////////
/// Base structure for threadblock-scoped GEMMs that dequantize the B operand
/// from a narrow integer type, staging the per-column scales in shared memory
/// alongside the A and B tiles.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Policy describing tuning details (concept: MmaPolicy)
typename Policy_,
/// The type of the scales
typename ElementScale_,
/// Number of stages,
int Stages,
/// Used for partial specialization
typename Enable = bool>
class DqMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
///< Type of the scale to be loaded
using ElementScale = ElementScale_;
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
static constexpr int kNumKIterationsPerWarpBLoad =
Operator::IteratorB::InstructionShape::kRow /
Operator::InstructionShape::kK;
static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), "");
static constexpr int kWarpGemmIterationsForB =
kWarpGemmIterations / kNumKIterationsPerWarpBLoad;
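// Illustrative example (values are assumptions, not fixed by this class): with
// a 16x8x16 fp16 tensor-op instruction and int4 weights, one B load covers
// K = 32, so kNumKIterationsPerWarpBLoad = 2 and each loaded B fragment feeds
// two warp-level MMAs.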
/// Number of stages
static int const kStages = Stages;
/// Tensor reference to the A operand
using TensorRefA =
TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
/// Tensor reference to the B operand
using TensorRefB =
TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
class SharedStorage {
public:
//
// Type definitions
//
/// Shape of the A matrix operand in shared memory
using ShapeA =
MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
public:
//
// Data members
//
/// Buffer for A operand
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
/// Buffer to hold scales for threadblock
AlignedBuffer<ElementScale, Shape::kN> operand_scale;
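// This buffer is the only addition over the stock MmaBase shared storage: one
// dequantization scale per column of the threadblock tile's N dimension.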
public:
//
// Methods
//
/// Returns a layout object for the A matrix
CUTLASS_DEVICE
static typename Operator::LayoutA LayoutA() {
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
}
/// Returns a layout object for the B matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutB LayoutB() {
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
}
/// Returns a TensorRef to the A operand
CUTLASS_HOST_DEVICE
TensorRefA operand_A_ref() {
return TensorRefA{operand_A.data(), LayoutA()};
}
/// Returns a TensorRef to the B operand
CUTLASS_HOST_DEVICE
TensorRefB operand_B_ref() {
return TensorRefB{operand_B.data(), LayoutB()};
}
};
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage& shared_storage, // NOLINT
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*! \file
\brief Default warp-level GEMM operators selected by data type, size, and
layouts of operands.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h"
namespace cutlass {
namespace gemm {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial specialization of DefaultMmaTensorOp for the
/// OpMultiplyAddDequantizeInterleavedBToA operator
template <
/// Shape of the warp-level matrix product (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix product instruction (concept: GemmShape)
typename InstructionShape_,
/// Data type of A elements,
typename ElementA,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA,
/// Data type of B elements
typename ElementB,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB,
/// Element type of C matrix
typename ElementC,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Number of partitions along K dimension
int PartitionsK,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<WarpShape_,
InstructionShape_,
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
arch::OpMultiplyAddDequantizeInterleavedBToA,
PartitionsK,
AccumulatorsInRowMajor> {
private:
// Shape for computing the FP16s
using ComputeInstructionShape = InstructionShape_;
// Chosen so we get K=16 for int8 and K=32 for int4.
static constexpr int LoadInstructionK =
8 * sizeof_bits<ElementA>::value / sizeof_bits<ElementB>::value;
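// e.g. for fp16 A (16 bits): 8 * 16 / 8 = 16 with uint8_t B, and
// 8 * 16 / 4 = 32 with uint4b_t B.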
// Shape for loading the narrow data type from shared memory
using LoadInstructionShape =
GemmShape<InstructionShape_::kM, InstructionShape_::kN, LoadInstructionK>;
public:
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::Mma<InstructionShape_,
32,
ElementA,
cutlass::layout::RowMajor,
ElementA,
cutlass::layout::ColumnMajor,
ElementC,
cutlass::layout::RowMajor,
arch::OpMultiplyAdd>,
cutlass::MatrixShape<1, 1>>;
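// Note that the underlying arch::Mma is declared with ElementA for both
// operands: B fragments are converted to the activation type in registers
// before the tensor-core mma is issued (see MmaTensorOpComputeBWithF16 below).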
// Define the warp-level tensor op
using Type =
cutlass::gemm::warp::MmaTensorOpComputeBWithF16<WarpShape_,
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
Policy,
LoadInstructionShape,
PartitionsK,
AccumulatorsInRowMajor>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/pitch_linear_coord.h"
namespace cutlass {
namespace layout {
template <int RowsPerTile, int ColumnsInterleaved>
class ColumnMajorTileInterleave {
static constexpr int kRowsPerTile = RowsPerTile;
static constexpr int kColumnsInterleaved = ColumnsInterleaved;
};
template <class T>
struct IsColumnMajorTileInterleave {
static constexpr bool value = false;
};
template <int U, int V>
struct IsColumnMajorTileInterleave<ColumnMajorTileInterleave<U, V>> {
static constexpr bool value = true;
};
} // namespace layout
} // namespace cutlass
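// Illustrative usage of the trait above (the <4, 2> parameters are arbitrary
// example values, not taken from this patch):
//   static_assert(cutlass::layout::IsColumnMajorTileInterleave<
//                     cutlass::layout::ColumnMajorTileInterleave<4, 2>>::value, "");
//   static_assert(!cutlass::layout::IsColumnMajorTileInterleave<
//                     cutlass::layout::RowMajor>::value, "");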
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
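// Computes out = x * dequantize(weight, weight_scale); `threshold` presumably
// marks activation outliers handled in half precision, as in the LLM.int8()
// scheme the kernel is named after.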
template <typename T, typename Context>
void LLMInt8MatMulKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const DenseTensor& weight_scale,
const float threshold,
DenseTensor* out);
} // namespace phi
This diff has been collapsed.
This diff has been collapsed.
...@@ -381,6 +381,7 @@ class LayerHelperBase:
and dtype != core.VarDesc.VarType.FP64
and dtype != core.VarDesc.VarType.FP16
and dtype != core.VarDesc.VarType.BF16
and dtype != core.VarDesc.VarType.INT8
):
raise TypeError(
"Can not create parameter with default initializer when dtype is not ['float16', 'float32', 'float64', 'bfloat16'] type. Set default_initializer to fit the parameter dtype!"
...@@ -392,6 +393,7 @@ class LayerHelperBase:
'float64',
'bfloat16',
'float',
'int8',
]:
raise TypeError(
"Can not create parameter with default initializer when dtype is not ['float16', 'float32', 'float64', 'bfloat16', 'float'] type. Set default_initializer to fit the parameter dtype!"
...
...@@ -71,7 +71,7 @@ from .layer.common import AlphaDropout  # noqa: F401
from .layer.common import Unfold  # noqa: F401
from .layer.common import Fold  # noqa: F401
from .layer.common import Unflatten  # noqa: F401
from .layer.common import LinearCompress  # noqa: F401
from .layer.pooling import AvgPool1D  # noqa: F401
from .layer.pooling import AvgPool2D  # noqa: F401
from .layer.pooling import AvgPool3D  # noqa: F401
...
This diff has been collapsed.
This diff has been collapsed.
This diff has been collapsed.