diff --git a/paddle/fluid/operators/assign_value_op.h b/paddle/fluid/operators/assign_value_op.h index 236e797ec8603f4b51c78dbf40f2dc5f56545e31..2522fa580f758a3f75483e30b5ae196f18f5c248 100644 --- a/paddle/fluid/operators/assign_value_op.h +++ b/paddle/fluid/operators/assign_value_op.h @@ -112,11 +112,13 @@ class AssignValueKernel : public framework::OpKernel { break; case framework::proto::VarType::INT64: value_name = "int64_values"; + case framework::proto::VarType::INT8: + value_name = "int8_values"; break; default: PADDLE_THROW(platform::errors::Unimplemented( "Unsupported data type(code %d) for AssignValue operator, only " - "supports bool, int32, float32 and int64.", + "supports bool, int32, float32, int8 and int64.", dtype)); break; } diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index c437fb984baa6b6f27ac0636f52cf8021ee07a20..2c5d7bc256537023fe4c232a459f9460f1940cb0 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1347,6 +1347,16 @@ data_transform : skip_transform : out_size, size_tensor, scale_tensor +- op : llm_int8_mat_mul + args : (Tensor x, Tensor weight, Tensor weight_scale, float threshold=6.0) + output : Tensor(out) + infer_meta : + func : LLMInt8MatMulInferMeta + param : [x, weight] + kernel : + func : llm_int8_mat_mul + data_type : x + - op : log args : (Tensor x) output : Tensor @@ -1896,6 +1906,15 @@ func : qr backward : qr_grad +- op : quant_for_compress + args : (Tensor x, int bits = 8, str layout = "weight_only") + output : Tensor(out), Tensor(scale) + infer_meta : + func : QuantForCompressInferMeta + kernel : + func : quant_for_compress + data_type: x + - op : real args : (Tensor x) output : Tensor (out) @@ -2563,6 +2582,16 @@ intermediate: warprnntgrad backward : warprnnt_grad +- op : weight_only_mat_mul + args : (Tensor x, Tensor weight, Tensor weight_scale) + output : Tensor(out) + infer_meta : + func : WeightOnlyMatMulInferMeta + param : [x, weight] + kernel : + func : weight_only_mat_mul + data_type : x + - op : weighted_sample_neighbors args : (Tensor row, Tensor colptr, Tensor edge_weight, Tensor input_nodes, Tensor eids, int sample_size, bool return_eids) output : Tensor(out_neighbors), Tensor(out_count), Tensor(out_eids) diff --git a/paddle/phi/common/datatype_traits.h b/paddle/phi/common/datatype_traits.h new file mode 100644 index 0000000000000000000000000000000000000000..12540909260e711c8e3236c4035e9dedc5a8cbb4 --- /dev/null +++ b/paddle/phi/common/datatype_traits.h @@ -0,0 +1,43 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/float16.h" + +#pragma once + +namespace phi { + +template +struct PDDataTypeTraits { + using DataType = T; +}; + +template <> +struct PDDataTypeTraits { + // Since LayerNormDirectCUDAFunctor register half type, we need to convert + // phi::float16 to half. 
+ using DataType = half; +}; + +#ifdef PADDLE_CUDA_BF16 +template <> +class PDDataTypeTraits { + public: + using DataType = __nv_bfloat16; +}; +#endif + +} // namespace phi diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 46b90d5d42baef22b8ff2772466e81e32afd1fd8..71bbfaa333a0acb732f5d07f8c248583e4f06754 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -3572,5 +3572,51 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row, out_count->set_dtype(DataType::INT32); } +void LLMInt8MatMulInferMeta(const MetaTensor& x, + const MetaTensor& weight, + MetaTensor* out) { + auto x_dims = x.dims(); + auto w_dims = weight.dims(); + PADDLE_ENFORCE_EQ( + w_dims.size(), + 2UL, + errors::InvalidArgument("The input(weight) must be a 2D Tensor.")); + PADDLE_ENFORCE_EQ( + x_dims[x_dims.size() - 1], + w_dims[1], + errors::InvalidArgument( + "Input(X) dim[-1] and Input(Weight) dim[1] should be euqal." + "But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)", + x_dims[x_dims.size() - 1], + w_dims[1])); + auto out_dims = x_dims; + out_dims[out_dims.size() - 1] = w_dims[0]; + out->set_dims(out_dims); + out->set_dtype(x.dtype()); +} + +void WeightOnlyMatMulInferMeta(const MetaTensor& x, + const MetaTensor& weight, + MetaTensor* out) { + auto x_dims = x.dims(); + auto w_dims = weight.dims(); + PADDLE_ENFORCE_EQ( + w_dims.size(), + 2UL, + errors::InvalidArgument("The input(weight) must be a 2D Tensor.")); + PADDLE_ENFORCE_EQ( + x_dims[x_dims.size() - 1], + w_dims[0], + errors::InvalidArgument( + "Input(X) dim[-1] and Input(Weight) dim[0] should be euqal." + "But received Input(X) dim[-1](%s) != Input(Weight) dim[0](%s)", + x_dims[x_dims.size() - 1], + w_dims[0])); + auto out_dims = x_dims; + out_dims[out_dims.size() - 1] = w_dims[1]; + out->set_dims(out_dims); + out->set_dtype(x.dtype()); +} + } // namespace phi PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index d0cc876c840ac8bb3e99cab2b7a6f8a466510b55..67a39780aa9c203fc309b63398f5e0bb7dc56296 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -673,6 +673,31 @@ void MoeInferMeta(const MetaTensor& x, const std::string& act_type, MetaTensor* out); +void FusedMultiHeadAttentionInferMeta(const MetaTensor& query, + const MetaTensor& key, + const MetaTensor& value, + const MetaTensor& mask, + float scale, + bool causal, + MetaTensor* out); + +void FusedMultiHeadAttentionVariableInferMeta(const MetaTensor& query, + const MetaTensor& key, + const MetaTensor& value, + const MetaTensor& seq_lens, + const MetaTensor& mask, + float scale, + bool causal, + MetaTensor* out); + +void LLMInt8MatMulInferMeta(const MetaTensor& x, + const MetaTensor& weight, + MetaTensor* out); + +void WeightOnlyMatMulInferMeta(const MetaTensor& x, + const MetaTensor& weight, + MetaTensor* out); + void FusedRopeInferMeta(const MetaTensor& q, const MetaTensor& k, const MetaTensor& v, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 5ae5232e1b0506645c324cbd6dc1ecbe8ef380e5..d590f0a875d716602a12bd070a2aedd884d81716 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -5062,6 +5062,51 @@ void CheckNumericsInferMeta(const MetaTensor& tensor, values->set_dims(phi::make_ddim({3})); } +void QuantForCompressInferMeta(const MetaTensor& x, + int bits, + const std::string& layout, + MetaTensor* out, + MetaTensor* scale) { 
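A note on the two matmul variants registered in ops.yaml: the infer-meta functions above encode different weight layouts. llm_int8_mat_mul expects a transposed [n, k] weight (x dim[-1] must equal weight dim[1], output dim[-1] becomes weight dim[0]), while weight_only_mat_mul expects a [k, n] weight (x dim[-1] must equal weight dim[0], output dim[-1] becomes weight dim[1]). A standalone sketch of those shape rules, illustrative only and not part of the patch:

#include <cassert>
#include <cstdint>
#include <vector>

// llm_int8_mat_mul: weight is stored transposed as [n, k], so out = [..., n].
std::vector<int64_t> LLMInt8OutShape(const std::vector<int64_t>& x_dims,   // [..., k]
                                     const std::vector<int64_t>& w_dims) { // [n, k]
  assert(x_dims.back() == w_dims[1]);
  std::vector<int64_t> out = x_dims;
  out.back() = w_dims[0];
  return out;
}

// weight_only_mat_mul: weight keeps the [k, n] layout, so out = [..., n].
std::vector<int64_t> WeightOnlyOutShape(const std::vector<int64_t>& x_dims,   // [..., k]
                                        const std::vector<int64_t>& w_dims) { // [k, n]
  assert(x_dims.back() == w_dims[0]);
  std::vector<int64_t> out = x_dims;
  out.back() = w_dims[1];
  return out;
}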
+ auto x_dims = x.dims(); + PADDLE_ENFORCE_EQ( + x_dims.size(), + 2UL, + phi::errors::InvalidArgument( + "The x tensor of quant op must be 2D, but got[%d]", x_dims.size())); + PADDLE_ENFORCE_GE( + x_dims[0], + 64, + phi::errors::OutOfRange("The first dimension of input is out of range " + "(expected at least 64, but got %ld).", + x_dims[0])); + PADDLE_ENFORCE_EQ( + x_dims[0] % 64, + 0, + phi::errors::InvalidArgument( + "The first dimension of input must be divisible by 64, but got[%d]", + x_dims[0])); + std::vector dim_scale({x_dims[1]}); + std::vector dim_out; + if (layout == "weight_only") { + dim_out = std::vector({x_dims[0], x_dims[1]}); + } else if (layout == "llm.int8") { + dim_out = std::vector({x_dims[1], x_dims[0]}); + } else { + phi::errors::InvalidArgument( + "The layout must be weight_only or llm.int8, but got %s", layout); + } + out->set_dims(phi::make_ddim(dim_out)); + + // TODO(lizhenyun) support weight_only int4 + if (bits == 8) { + out->set_dtype(DataType::INT8); + } else { + phi::errors::Fatal("The bits only support 8, but got[%d]", bits); + } + scale->set_dims(phi::make_ddim(dim_scale)); + scale->set_dtype(DataType::FLOAT32); +} + } // namespace phi PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 9f2118344b2895e8b8f7063be94f36a540f29137..f0257fb75a22e6a1f81686c773a6ca99a555fccb 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -724,4 +724,10 @@ void UnStackInferMeta(const MetaTensor& x, int num, std::vector outs); +void QuantForCompressInferMeta(const MetaTensor& x, + int bits, + const std::string& layout, + MetaTensor* out, + MetaTensor* scale); + } // namespace phi diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc index 7a6e8d392da1df5f9b9792e336a336cb3f1bd561..db30ec738961932272251448eea9e1ddea52a888 100644 --- a/paddle/phi/kernels/assign_kernel.cc +++ b/paddle/phi/kernels/assign_kernel.cc @@ -132,6 +132,7 @@ PD_REGISTER_KERNEL(assign_value, bool, int, float, + int8_t, int64_t) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -158,6 +159,7 @@ PD_REGISTER_KERNEL(assign_value, bool, int, float, + int8_t, int64_t) {} #endif diff --git a/paddle/phi/kernels/cpu/quant_for_compress_kernel.cc b/paddle/phi/kernels/cpu/quant_for_compress_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..a96842640b695dce121b125e9c680d3bae71c672 --- /dev/null +++ b/paddle/phi/kernels/cpu/quant_for_compress_kernel.cc @@ -0,0 +1,106 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
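QuantForCompressInferMeta above pins down the quantized weight layouts: the 2-D input must have its first dimension (the reduction dimension for the weight-only matmul) at least 64 and a multiple of 64; layout "weight_only" keeps the [k, n] shape, layout "llm.int8" stores the transpose [n, k], and both produce a float32 scale of length n (one per output channel); only bits == 8 is handled. Note that the else branch only constructs a phi::errors object without PADDLE_THROW, so an unknown layout currently falls through silently. A minimal host-side reference of the per-channel scale/quant step, assuming symmetric max-abs scaling to int8 (the real per_channel_scale/per_channel_quant helpers live in quant_for_compress_kernel_impl.h):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Reference per-channel symmetric quantization: one scale per column (output
// channel). The /127 convention is an assumption about the helper's contract.
void RefPerChannelQuant(const float* w, int64_t k, int64_t n,
                        float* scale /* [n] */, int8_t* q /* [k, n] */) {
  for (int64_t j = 0; j < n; ++j) {
    float max_abs = 0.f;
    for (int64_t i = 0; i < k; ++i)
      max_abs = std::max(max_abs, std::fabs(w[i * n + j]));
    scale[j] = max_abs > 0.f ? max_abs / 127.f : 1.f;
  }
  for (int64_t i = 0; i < k; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      float v = std::round(w[i * n + j] / scale[j]);
      q[i * n + j] = static_cast<int8_t>(std::max(-127.f, std::min(127.f, v)));
    }
  }
}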
+ +#include "paddle/phi/kernels/quant_for_compress_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/common_shape.h" +#include "paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h" + +namespace phi { + +template +void quant_compute(const DeviceContext& dev_ctx, + const DenseTensor& x, + DenseTensor* out, + DenseTensor* scale, + const std::string& layout) { + const auto x_dims = x.dims(); + PADDLE_ENFORCE_EQ( + x_dims.size(), + 2, + phi::errors::InvalidArgument( + "the x tensor of quant op must be 2D, but got[%d]", x_dims.size())); + size_t m = x_dims[0]; + size_t n = x_dims[1]; + int64_t num = x.numel(); + DDim dims = {num}; + const T* x_data = x.data(); + D* out_data = out->data(); + float* scale_data = scale->data(); + + DenseTensor x_int(out->type()); + x_int.Resize({static_cast(m), static_cast(n)}); + dev_ctx.template Alloc(&x_int); + D* x_int_data = x_int.data(); + + DenseTensor int_processed(out->type()); + int_processed.Resize(dims); + dev_ctx.template Alloc(&int_processed); + + D* int_processed_data = int_processed.data(); + DenseTensor int_processed_2(out->type()); + int_processed_2.Resize(out->dims()); + dev_ctx.template Alloc(&int_processed_2); + D* int_processed_2_data = int_processed_2.data(); + + per_channel_scale(scale_data, x_data, m, n); + + per_channel_quant(x_int_data, x_data, scale_data, m, n); + if (layout == "weight_only") { + permute_B_rows_for_mixed_gemm( + int_processed_data, x_int_data, std::vector{m, n}, (int64_t)80); + row_major_to_column_major( + int_processed_2_data, int_processed_data, std::vector{m, n}); + interleave_column_major_tensor( + out_data, int_processed_2_data, std::vector{m, n}); + add_bias_and_interleave_int8s_inplace(out_data, num); + } else if (layout == "llm.int8") { + std::vector axis = {1, 0}; + funcs::Transpose trans; + trans(dev_ctx, x_int, out, axis); + } else { + phi::errors::InvalidArgument( + "The layout must be weight_only or llm.int8, but got %s", layout); + } +} + +template +void QuantForCompressKernel(const Context& dev_ctx, + const DenseTensor& x, + int bits, + const std::string& layout, + DenseTensor* out, + DenseTensor* scale) { + if (bits == 8) { + dev_ctx.template Alloc(out); + dev_ctx.template Alloc(scale); + quant_compute(dev_ctx, x, out, scale, layout); + } else { + phi::errors::Unimplemented("The bits only support 8, but got[%d]", bits); + } + // VLOG(0) << "x: " << x.dtype() << x; + // VLOG(0) << "out: " << out->dtype() << *out; +} + +} // namespace phi + +PD_REGISTER_KERNEL(quant_for_compress, + CPU, + ALL_LAYOUT, + phi::QuantForCompressKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/cpu/transpose_kernel.cc b/paddle/phi/kernels/cpu/transpose_kernel.cc index 62d7513d6a458b2c9879c4dda9c7773f3cfefe7d..c43c7afc5db7a13b3c6347de75cdc49ca43e596f 100644 --- a/paddle/phi/kernels/cpu/transpose_kernel.cc +++ b/paddle/phi/kernels/cpu/transpose_kernel.cc @@ -87,6 +87,7 @@ PD_REGISTER_KERNEL(transpose, double, int32_t, int64_t, + int8_t, phi::dtype::float16, phi::dtype::bfloat16, phi::dtype::complex, diff --git a/paddle/phi/kernels/funcs/cublaslt.h b/paddle/phi/kernels/funcs/cublaslt.h new file mode 100644 index 0000000000000000000000000000000000000000..6391b583a08db13c5a4709026d75b02f55e5663a --- /dev/null +++ b/paddle/phi/kernels/funcs/cublaslt.h @@ -0,0 +1,239 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. 
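In quant_compute above, the two layouts diverge after per-channel quantization: the "weight_only" branch repacks the int8 matrix for the CUTLASS mixed-dtype GEMM (permute_B_rows_for_mixed_gemm, row-major to column-major conversion, interleave_column_major_tensor, add_bias_and_interleave_int8s_inplace), while the "llm.int8" branch simply transposes it into the [n, k] order consumed by the cuBLASLt int8 path. As in the infer-meta, the else/bits error branches build error objects without throwing them. A naive equivalent of the llm.int8 transpose, illustrative and separate from the patch:

#include <cstdint>

// dst is the [n, k] transpose of the quantized [k, n] weight.
void TransposeInt8(const int8_t* src, int8_t* dst, int64_t k, int64_t n) {
  for (int64_t i = 0; i < k; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      dst[j * k + i] = src[i * n + j];
    }
  }
}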
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "paddle/phi/backends/dynload/cublasLt.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace dyl = phi::dynload; + +namespace phi { + +struct CublasLtAlgoParam { + int algoId; + int swizzle; + int customOption; + int tile; + int splitK_val; + int reductionScheme; + int stages; + size_t workspace_size; +}; + +const std::map, CublasLtAlgoParam> AlgoParamCache{}; + +class CublasLtHelper { + public: + CublasLtHelper(int m, int k, int n) + : alpha_(1), beta_(0), m_(m), k_(k), n_(n) { + cublasStatus_t status; + // handle and matmul desc + status = dyl::cublasLtCreate(&handle_); +#if CUBLAS_VER_MAJOR < 11 + cudaDataType_t cudaComputeType = CUDA_R_32I; +#else + cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; +#endif + + PADDLE_ENFORCE_EQ( + status, + CUBLAS_STATUS_SUCCESS, + phi::errors::External( + "cublasLtMatrixLayoutCreate execution error" + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " + "information")); + +#if CUBLAS_VER_MAJOR < 11 + status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType); +#else + status = dyl::cublasLtMatmulDescCreate( + &matmul_desc_, cudaComputeType, CUDA_R_32I); +#endif + + PADDLE_ENFORCE_EQ( + status, + CUBLAS_STATUS_SUCCESS, + phi::errors::External( + "cublasLtMatmulDescCreate execution error" + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " + "information")); + cublasOperation_t op_transpose = CUBLAS_OP_T; + status = dyl::cublasLtMatmulDescSetAttribute(matmul_desc_, + CUBLASLT_MATMUL_DESC_TRANSA, + &op_transpose, + sizeof(op_transpose)); + PADDLE_ENFORCE_EQ( + status, + CUBLAS_STATUS_SUCCESS, + phi::errors::External( + "cublasLtMatmulDescSetAttribute execution error" + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " + "information")); + + // matrix desc + status = dyl::cublasLtMatrixLayoutCreate(&B_desc_, CUDA_R_8I, k, n, k); + PADDLE_ENFORCE_EQ( + status, + CUBLAS_STATUS_SUCCESS, + phi::errors::External( + "cublasLtMatrixLayoutCreate execution error" + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " + "information")); + + status = dyl::cublasLtMatrixLayoutCreate(&A_desc_, CUDA_R_8I, k, m, k); + PADDLE_ENFORCE_EQ( + status, + CUBLAS_STATUS_SUCCESS, + phi::errors::External( + "cublasLtMatrixLayoutCreate execution error" + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " + "information")); + + status = dyl::cublasLtMatrixLayoutCreate(&C_desc_, CUDA_R_32I, n, m, n); + PADDLE_ENFORCE_EQ( + status, + CUBLAS_STATUS_SUCCESS, + phi::errors::External( + "cublasLtMatrixLayoutCreate execution error" + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " + "information")); + +#if CUDA_VERSION >= 11020 + + int algoId = 21; + int swizzle = 0; + int customOption = 0; + int tile = 15; + int splitK_val = 0; + int reductionScheme = 0; + int stages = 23; + workspace_size_ = 0; + if (m >= 128) { + tile = 20; + stages = 17; + } + + 
std::tuple key(m_, k_, n_); + if (AlgoParamCache.count(key) != 0) { + auto value = AlgoParamCache.at(key); + algoId = value.algoId; + swizzle = value.swizzle; + customOption = value.customOption; + tile = value.tile; + splitK_val = value.splitK_val; + reductionScheme = value.reductionScheme; + stages = value.stages; + workspace_size_ = value.workspace_size; + } + + dyl::cublasLtMatmulAlgoInit(handle_, + cudaComputeType, + CUDA_R_32I, + CUDA_R_8I, + CUDA_R_8I, + CUDA_R_32I, + CUDA_R_32I, + algoId, + &algo_); + dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo_, + CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, + &(customOption), + sizeof(customOption)); + dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo_, CUBLASLT_ALGO_CONFIG_TILE_ID, &(tile), sizeof(tile)); + dyl::cublasLtMatmulAlgoConfigSetAttribute(&algo_, + CUBLASLT_ALGO_CONFIG_SPLITK_NUM, + &(splitK_val), + sizeof(splitK_val)); + dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo_, + CUBLASLT_ALGO_CONFIG_CTA_SWIZZLING, + &(swizzle), + sizeof(swizzle)); + dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo_, + CUBLASLT_ALGO_CONFIG_REDUCTION_SCHEME, + &(reductionScheme), + sizeof(int)); +#if CUDA_VERSION >= 11000 + dyl::cublasLtMatmulAlgoConfigSetAttribute( + &algo_, CUBLASLT_ALGO_CONFIG_STAGES_ID, &(stages), sizeof(stages)); +#endif +#endif + } + ~CublasLtHelper() {} + + void GEMM(int8_t* A_dev, + const int8_t* B_dev, + int32_t* C_dev, + cudaStream_t stream, + void* workspace = nullptr) { + cublasStatus_t status; + + status = dyl::cublasLtMatmul(handle_, + matmul_desc_, + &alpha_, + B_dev, + B_desc_, + A_dev, + A_desc_, + &beta_, + C_dev, + C_desc_, + C_dev, + C_desc_, +#if CUDA_VERSION >= 11020 + &algo_, + workspace, + workspace_size_, +#else + nullptr, + nullptr, + 0, +#endif + stream); + PADDLE_ENFORCE_EQ( + status, + CUBLAS_STATUS_SUCCESS, + phi::errors::External( + "cublasLtMatmul execution error" + "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " + "information")); + } + + private: + cublasLtHandle_t handle_; + cublasLtMatmulDesc_t matmul_desc_; + cublasLtMatrixLayout_t A_desc_; + cublasLtMatrixLayout_t B_desc_; + cublasLtMatrixLayout_t C_desc_; + + cublasLtMatmulAlgo_t algo_; + + int32_t alpha_; + int32_t beta_; + + int m_; + int k_; + int n_; + + size_t workspace_size_; +}; + +} // namespace phi diff --git a/paddle/phi/kernels/funcs/quant_dequant.h b/paddle/phi/kernels/funcs/quant_dequant.h new file mode 100644 index 0000000000000000000000000000000000000000..62bfc9cfcf1bee7e74ced08ed3fe500755a66715 --- /dev/null +++ b/paddle/phi/kernels/funcs/quant_dequant.h @@ -0,0 +1,158 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
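CublasLtHelper above configures an int8 x int8 -> int32 cuBLASLt matmul: compute type CUBLAS_COMPUTE_32I (CUDA_R_32I before cuBLAS 11), CUBLASLT_MATMUL_DESC_TRANSA on the weight operand, A/B layout descriptors of k x m and k x n with leading dimension k, and an n x m int32 C layout. GEMM() passes the weight buffer as cuBLASLt's transposed "A" operand and the activation as "B", so the column-major result corresponds to the row-major out[m, n] = x[m, k] * weight[n, k]^T; tuned algorithm parameters can be looked up per (m, k, n) in AlgoParamCache. An illustrative call sequence (buffer names are placeholders for pre-allocated device memory):

#include <cuda_runtime.h>
#include "paddle/phi/kernels/funcs/cublaslt.h"

// x_dev: [m, k] int8 activation, w_dev: [n, k] int8 weight,
// out_dev: [m, n] int32 result; all assumed to be device buffers.
void Int8GemmExample(int8_t* x_dev, const int8_t* w_dev, int32_t* out_dev,
                     int m, int k, int n, cudaStream_t stream) {
  phi::CublasLtHelper helper(m, k, n);
  helper.GEMM(x_dev, w_dev, out_dev, stream);
}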
*/ + +#pragma once + +#include +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/common/transform.h" +#include "paddle/phi/core/hostdevice.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" + +namespace phi { + +using backends::gpu::GpuLaunchConfig; + +constexpr int DequantKernelVecSize = 4; + +template +inline HOSTDEVICE T roundWithTiesToEven(T x) { + T xLower = floor(x); + T xUpper = ceil(x); + // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to + // even. + T dLower = x - xLower; + T dUpper = xUpper - x; + return static_cast( + (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper) + ? xLower + : xUpper); +} + +template +__forceinline__ __device__ int8_t quant_helper(const T input, + const float scale, + const int round_type, + const float max_bound, + const float min_bound) { + float quant_value = max_bound * scale * static_cast(input); + + if (round_type == 0) { + quant_value = static_cast(roundWithTiesToEven(quant_value)); + } else { + quant_value = static_cast(round(quant_value)); + } + quant_value = quant_value > max_bound ? max_bound : quant_value; + quant_value = quant_value < min_bound ? min_bound : quant_value; + return static_cast(quant_value); +} + +template +__global__ void quantize_kernel(const T* input, + char4* output, + const float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound) { + int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; + int m_id = blockIdx.y * blockDim.y + threadIdx.y; + + bool check = ((m_id < m) && (n_id < n)); + if (check) { + char4 tmp; + tmp.x = quant_helper( + input[m_id * n + n_id], scale, round_type, max_bound, min_bound); + tmp.y = quant_helper( + input[m_id * n + n_id + 1], scale, round_type, max_bound, min_bound); + tmp.z = quant_helper( + input[m_id * n + n_id + 2], scale, round_type, max_bound, min_bound); + tmp.w = quant_helper( + input[m_id * n + n_id + 3], scale, round_type, max_bound, min_bound); + output[(m_id * n + n_id) >> 2] = tmp; + } +} + +template +void quantize_kernel_launcher(const T* input, + int8_t* output, + const float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound, + gpuStream_t stream) { + // TODO(minghaoBD): optimize the kennel launch times when m==1 or n==1 + dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32); + dim3 block(32, 32); + + quantize_kernel<<>>(input, + (char4*)output, // NOLINT + scale, + m, + n, + round_type, + max_bound, + min_bound); +} + +template +__global__ void dequantize_kernel(T* output, + const int32_t* input, + const int m, // batch size + const int n, // hidden + const float quant_in_scale, + const float* dequant_out_scale_data) { + int numel = m * n; + int stride = blockDim.x * gridDim.x * VecSize; + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; + int col_id = idx % n; + + phi::AlignedVector in_vec; + phi::AlignedVector out_scale_vec; + phi::AlignedVector out_vec; + + for (; idx < numel; idx += stride) { + phi::Load(input + idx, &in_vec); + phi::Load(dequant_out_scale_data + col_id, &out_scale_vec); + +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + out_vec[i] = + static_cast(static_cast(in_vec[i]) * out_scale_vec[i]); + } + + phi::Store(out_vec, output + idx); + } +} + +template +void dequantize_kernel_launcher(const int32_t* input, + T* output, + const int m, // m + const int n, // n + 
gpuStream_t stream, + GpuLaunchConfig* gpu_config, + const float quant_in_scale, + const float* dequant_out_scale_data) { + dequantize_kernel + <<block_per_grid, gpu_config->thread_per_block, 0, stream>>>( + output, input, m, n, quant_in_scale, dequant_out_scale_data); +} + +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h new file mode 100644 index 0000000000000000000000000000000000000000..151153e4b297daf5f3fa6969ab07297064711846 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/*! \file + \brief Templates exposing architecture support for multiply-add operations +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +// Tag which triggers MMA which will trigger +struct OpMultiplyAddDequantizeInterleavedBToA; + +} // namespace arch +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h new file mode 100644 index 0000000000000000000000000000000000000000..2c96d39b6beabb97a34ac835ca3601610a140f76 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h @@ -0,0 +1,53 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
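The helpers above implement the activation quant/dequant around the int32 GEMM: quant_helper computes clamp(round(input * scale * max_bound), min_bound, max_bound), using round-half-to-even when round_type == 0, and quantize_kernel writes char4 vectors, so it appears to assume the column count is a multiple of 4; dequantize_kernel multiplies each int32 output by a per-column scale (quant_in_scale appears unused in the loop shown). Host-side references of both steps, illustrative only:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Mirror of quant_helper with round_type == 0; std::nearbyint rounds
// ties-to-even in the default FP environment, like roundWithTiesToEven.
int8_t RefQuant(float x, float scale, float max_bound = 127.f,
                float min_bound = -127.f) {
  float v = std::nearbyint(max_bound * scale * x);
  return static_cast<int8_t>(std::min(max_bound, std::max(min_bound, v)));
}

// Mirror of dequantize_kernel: one dequant scale per output column.
void RefDequant(const int32_t* in, float* out, int m, int n,
                const float* col_scale) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      out[i * n + j] = static_cast<float>(in[i * n + j]) * col_scale[j];
    }
  }
}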
*/ + +#pragma once + +#include + +#include "cutlass/device_kernel.h" +#include "paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h" + +namespace phi { +template +inline int compute_occupancy_for_kernel() { + int smem_size = static_cast(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size > (48 << 10)) { + cudaError_t status = + cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (status == cudaError::cudaErrorInvalidValue) { + // Clear the error bit since we can ignore this. + // This should mean that smem_size > + // cudaDevAttrMaxSharedMemoryPerBlockOptin. In that case, we return an + // occupancy of 0. This will cause the heuristic to ignore this + // configuration. + status = cudaGetLastError(); + return 0; + } + check_cuda_error(status); + } + + int max_active_blocks = -1; + check_cuda_error( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, + cutlass::Kernel, + GemmKernel::kThreadCount, + smem_size)); + + return max_active_blocks; +} +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/epilogue_quant_helper.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/epilogue_quant_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..ced7ae1c1ad4169a8825d495d13a163b4053d8e6 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/epilogue_quant_helper.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
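compute_occupancy_for_kernel above asks the CUDA runtime how many thread blocks of a CUTLASS kernel fit per SM for a given shared-memory footprint; kernels needing more than 48 KB of dynamic shared memory must first opt in via cudaFuncSetAttribute, and an occupancy of 0 tells the tile-config heuristic to skip that configuration. The same query for an ordinary kernel looks like the sketch below (my_kernel and the sizes are placeholders):

#include <cuda_runtime.h>

__global__ void my_kernel(float* p) { p[threadIdx.x] = 0.f; }

int MaxActiveBlocksPerSM(int block_size, size_t dyn_smem_bytes) {
  if (dyn_smem_bytes > (48 << 10)) {
    // Opt in to more than 48KB of dynamic shared memory (Volta and newer).
    cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                         static_cast<int>(dyn_smem_bytes));
  }
  int max_blocks = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks, my_kernel,
                                                block_size, dyn_smem_bytes);
  return max_blocks;
}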
*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { + +// define scaling mode +enum class QuantMode { + PerTensorQuant, + PerTokenQuant, + PerChannelQuant, + PerTokenChannelQuant +}; + +} // namespace epilogue +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/thread/ft_fused_activations.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/thread/ft_fused_activations.h new file mode 100644 index 0000000000000000000000000000000000000000..a2dab5e740e72f6b1b479abed83b7381e852b944 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/thread/ft_fused_activations.h @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +/*! \file + \brief Functor performing linear combination with a maximum operation used by + epilogues. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/half.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +__forceinline__ __device__ float copysignf_pos(float a, float b) { + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__forceinline__ __device__ float tanh_opt(float x) { +#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) + const float exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#else + return fast_tanh(x); +#endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +template <> +struct GELU_taylor { + static const bool kIsHeavy = true; + CUTLASS_DEVICE + float operator()(float const& z) const { + float k0 = static_cast(0.7978845608028654); + float k1 = static_cast(0.044715); + + return static_cast( + cutlass::constants::half() * z * + (cutlass::constants::one() + + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_DEVICE + float operator()(float const& scalar, Params const& params_) const { + return this->operator()(scalar); + } +}; + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h new file mode 100644 index 0000000000000000000000000000000000000000..b3b14a94371ac1319567219bc081daefaeb60b69 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h @@ -0,0 +1,359 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
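QuantMode above distinguishes per-tensor, per-token (row), per-channel (column) and combined scaling, and the GELU_taylor specialization for float is the usual tanh-based GELU approximation with k0 = sqrt(2/pi) and k1 = 0.044715 (tanh_opt falls back to an exp-based tanh on older compilers and pre-SM75 architectures). A host reference of the activation, for checking only:

#include <cmath>

// gelu(z) ~= 0.5 * z * (1 + tanh(sqrt(2/pi) * (z + 0.044715 * z^3)))
float RefFastGelu(float z) {
  const float k0 = 0.7978845608028654f;  // sqrt(2/pi)
  const float k1 = 0.044715f;
  return 0.5f * z * (1.f + std::tanh(k0 * (z + k1 * z * z * z)));
}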
+See the License for the specific language governing permissions and +limitations under the License. */ + +/*! \file + \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one + scaling factor per row, and one per column. + + original file: + 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "../epilogue_quant_helper.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_conversion.h" + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +class EpilogueVisitorPerRowPerCol { + public: + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using ScaleTileIterator = ScaleTileIterator_; + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + using ElementOutput = typename OutputTileIterator::Element; + using LayoutOutput = cutlass::layout::RowMajor; + using ElementAccumulator = ElementAccumulator_; + + using AlphaScaleElementType = typename ScaleTileIterator::Element; + + using ElementCompute = ElementCompute_; + using AccumulatorFragment = Array; + using ComputeFragment = Array; + using OutputVector = Array; + + static int const kThreadsPerRow = + OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static bool const kHasMultiStepsInRow = + (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); + + /// Argument structure + struct Arguments { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {} + + explicit Arguments(typename ElementwiseFunctor::Params elementwise_) + : elementwise(elementwise_), + batch_stride_alpha(0), + batch_stride_C(0), + batch_stride_D(0) {} + + Arguments(typename ElementwiseFunctor::Params elementwise_, + int64_t batch_stride_alpha_, + int64_t batch_stride_C_, + int64_t batch_stride_D_) + : elementwise(elementwise_), + batch_stride_alpha(batch_stride_alpha_), + batch_stride_C(batch_stride_C_), + batch_stride_D(batch_stride_D_) {} + }; + + struct Params { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + explicitParams(Arguments const& args) + : elementwise(args.elementwise), + batch_stride_alpha(args.batch_stride_alpha), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D) {} + }; + + /// Shared storage + struct SharedStorage {}; + + private: + Params const& params_; + SharedStorage& shared_storage_; + MatrixCoord extent_; + MatrixCoord extent_real_; + ElementwiseFunctor elementwise_; + + const bool per_token_quant_; + const bool per_channel_quant_; + + AlphaScaleElementType* ptr_alpha_row_; + AlphaScaleElementType* ptr_alpha_col_; + ScaleTileIterator iterator_alpha_col_; + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + + AlphaScaleElementType element_alpha_row_ = 1.0f; + AlphaScaleElementType element_alpha_col_ = 1.0f; + 
typename ScaleTileIterator::Fragment fragment_alpha_col_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator beta_; + + int column_offset_; + + MatrixCoord thread_offset_; + + public: + CUTLASS_DEVICE + EpilogueVisitorPerRowPerCol( + Params const& params, + SharedStorage& shared_storage, // NOLINT + cutlass::MatrixCoord const& problem_size, + int thread_idx, + int warp_idx, + int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, + typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, + QuantMode quant_mode, + AlphaScaleElementType* ptr_alpha_row, + AlphaScaleElementType* ptr_alpha_col, + typename OutputTileIterator::Element* ptr_C, + typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, + 0), + int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, + 0)) + : params_(params), + shared_storage_(shared_storage), + extent_(problem_size), + elementwise_(params.elementwise), + per_token_quant_(quant_mode == QuantMode::PerTokenQuant || + quant_mode == QuantMode::PerTokenChannelQuant), + per_channel_quant_(quant_mode == QuantMode::PerChannelQuant || + quant_mode == QuantMode::PerTokenChannelQuant), + ptr_alpha_row_(ptr_alpha_row), + ptr_alpha_col_(ptr_alpha_col), + iterator_alpha_col_(params_alpha_col, + ptr_alpha_col, + problem_size, + thread_idx, + threadblock_offset), + iterator_C_( + params_C, ptr_C, problem_size, thread_idx, threadblock_offset), + iterator_D_( + params_D, ptr_D, problem_size, thread_idx, threadblock_offset), + extent_real_(problem_size_real) { + beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr + : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) { + iterator_C_.clear_mask(); + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition( + int split_k_index, ///< Index of this threadblock within split-K + ///< partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + iterator_alpha_col_.add_pointer_offset(batch_idx * + params_.batch_stride_alpha); + iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); + iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + } + + /// Called at the start of the epilogue just before iterating over accumulator + /// slices + CUTLASS_DEVICE + void begin_epilogue() { + if (per_channel_quant_) { + iterator_alpha_col_.load(fragment_alpha_col_); + } else if (ptr_alpha_col_ != nullptr) { + arch::global_load( + element_alpha_col_, ptr_alpha_col_, true); + } + + if (!per_token_quant_ && ptr_alpha_row_ != nullptr) { + arch::global_load( + element_alpha_row_, ptr_alpha_row_, true); + } + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_D_.clear(); + fragment_C_.clear(); + + if (elementwise_.kScale != + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + + // load alpha_row in begin_step only when per token(row) scaling is used + if (per_token_quant_) { + int thread_offset_row = + iterator_D_.thread_start_row() + + OutputTileIterator::ThreadMap::iteration_offset(0).row(); + + // element_alpha_row_ = ptr_alpha_row_[thread_offset_row]; + arch::global_load( + 
element_alpha_row_, + ptr_alpha_row_ + thread_offset_row, + thread_offset_row < extent_.row()); + } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + // Clear accumulators for max and sum when starting a whole row + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit(int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorFragment const& accum) { + NumericArrayConverter + source_converter; + + ComputeFragment result = source_converter(accum); + if (per_channel_quant_) { + ComputeFragment alpha_col = + reinterpret_cast(&fragment_alpha_col_)[frag_idx]; + result = per_token_channel_scale_accumulator_( + result, alpha_col, element_alpha_row_); + } else { + result = per_token_scale_accumulator_( + result, element_alpha_col_, element_alpha_row_); + } + + // Convert to the output + NumericArrayConverter + output_converter; + OutputVector& output = + reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() {} + + private: + CUTLASS_DEVICE + ComputeFragment per_token_channel_scale_accumulator_( + ComputeFragment const& accum, + ComputeFragment const& scale_col, + AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col[i] * scale_row); + } + + return result; + } + + CUTLASS_DEVICE + ComputeFragment per_token_scale_accumulator_( + ComputeFragment const& accum, + AlphaScaleElementType const& scale_col, + AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col * scale_row); + } + + return result; + } +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h new file mode 100644 index 0000000000000000000000000000000000000000..9d0cb644b236eb17ddfd74f398729db893bf1fa2 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -0,0 +1,308 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
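EpilogueVisitorPerRowPerCol converts each int32 accumulator fragment to float, multiplies it by the per-row (token) and/or per-column (channel) scales selected by QuantMode (loaded in begin_epilogue/begin_step), and only then converts to the output element type. A scalar reference of that scaling, illustrative and separate from the patch:

#include <cstdint>

enum class QuantMode { PerTensorQuant, PerTokenQuant, PerChannelQuant, PerTokenChannelQuant };

// out[i][j] = acc[i][j] * row_scale[i or 0] * col_scale[j or 0]
void RefEpilogueScale(const int32_t* acc, float* out, int m, int n,
                      const float* row_scale, const float* col_scale, QuantMode mode) {
  const bool per_token =
      mode == QuantMode::PerTokenQuant || mode == QuantMode::PerTokenChannelQuant;
  const bool per_channel =
      mode == QuantMode::PerChannelQuant || mode == QuantMode::PerTokenChannelQuant;
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      out[i * n + j] = static_cast<float>(acc[i * n + j]) *
                       row_scale[per_token ? i : 0] * col_scale[per_channel ? j : 0];
    }
  }
}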
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory + to match canonical tensor layouts in global memory. Epilogues support + conversion and reduction operations. + + original file: + 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h + +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/platform/platform.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/epilogue/thread/linear_combination_gelu.h" +#include "cutlass/epilogue/thread/linear_combination_hardswish.h" +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_relu0.h" +#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" + +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/epilogue/thread/reduction_op.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" + +#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" +#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Partial specialization for half <= int32_t x 8 epilogues avoids shared +/// memory bank conflicts. +template +struct DefaultIteratorsTensorOp { + using WarpTileIterator = + cutlass::epilogue::warp::TileIteratorTensorOp; + + using SharedLoadIterator = + cutlass::epilogue::threadblock::SharedLoadIterator; + + static int const kFragmentsPerIteration = 1; +}; + +/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared +/// memory bank conflicts. 
+#ifdef PADDLE_CUDA_BF16 +template +struct DefaultIteratorsTensorOp { + using WarpTileIterator = + cutlass::epilogue::warp::TileIteratorTensorOp; + + using SharedLoadIterator = + cutlass::epilogue::threadblock::SharedLoadIterator; + + static int const kFragmentsPerIteration = 1; +}; +#endif +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. +/// +/// Satisfies: ReadableTileIterator +/// +template +class SharedLoadIteratorMixed { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = int32_t; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = + ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = + Array; + + /// Memory access size + using AccessType = + AlignedArray; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray::value, + ThreadMap::kElementsPerAccess), + const_min(16, kAlignment)>; + + static int const kLoadsPerAccess = + AccessType::kElements / LoadType::kElements; + + private: + // + // Data members + // + + /// Byte-level pointer + LoadType const* pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed(TensorRef ref, int thread_idx) + : stride_((ref.stride(0) / LoadType::kElements)) { + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] = reinterpret_cast(ref.data()); + + int col_idx = + (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; + int bank_offset = (col_idx * static_cast(sizeof(LoadType)) / 128) % + kLoadsPerAccess; + + col_idx += (bank_offset + i) % kLoadsPerAccess; + + pointers_[i] += thread_offset.row() * stride_ + col_idx; + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += pointer_offset / LoadType::kElements; + } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& offset) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) { + pointers_[i] += offset.row() * Shape::kRow * stride_ + + offset.column() * Shape::kColumn / LoadType::kElements; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, // NOLINT + Index pointer_offset) const { + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ + + group * ThreadMap::Delta::kGroup * 
stride_ + + cluster * ThreadMap::Delta::kCluster * stride_ + + pointer_offset / LoadType::kElements; + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + int frag_idx = + frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) { + int vector_idx = (column * ThreadMap::Delta::kColumn / + kElementsPerAccess * kLoadsPerAccess); + + LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; + + frag_ptr[frag_idx * kLoadsPerAccess + v] = + memory_pointer[vector_idx]; + } + } + } + } + } + } + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment& frag) const { // NOLINT + load_with_pointer_offset(frag, 0); + } +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue_helpers.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..f8c1a84f76555b5cfa167026e09df0136febebdf --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue_helpers.h @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/** + * @file epilogue_helpers.h + * + * This file includes types for the epilogues. The empty structs exist so we can + * signal to template code the type of epilogue we want to run, and let the + * underlying code specify the details such as element types, accumulator type + * and elements per vector access. 
+ * + */ + +#pragma once + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" +#include "epilogue/thread/ft_fused_activations.h" + +namespace phi { +struct EpilogueOpBiasSilu {}; + +struct EpilogueOpBiasReLU {}; + +struct EpilogueOpBiasFtGelu {}; + +struct EpilogueOpBias {}; + +struct EpilogueOpNoBias {}; + +template +struct Epilogue {}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationSilu< + ElementType, + ElementsPerVectorAccess, + ElementAccumulator, + ElementAccumulator, + cutlass::epilogue::thread::ScaleType::NoBetaScaling>; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationRelu< + ElementType, + ElementsPerVectorAccess, + ElementAccumulator, + ElementAccumulator, + cutlass::epilogue::thread::ScaleType::NoBetaScaling>; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombinationGeneric< + cutlass::epilogue::thread::GELU_taylor, + ElementType, + ElementsPerVectorAccess, + ElementAccumulator, + ElementAccumulator, + cutlass::epilogue::thread::ScaleType::NoBetaScaling, + cutlass::FloatRoundStyle::round_to_nearest, + true>; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination< + ElementType, + ElementsPerVectorAccess, + ElementAccumulator, + ElementAccumulator, + cutlass::epilogue::thread::ScaleType::NoBetaScaling>; +}; + +template +struct Epilogue { + using Op = cutlass::epilogue::thread::LinearCombination< + ElementType, + ElementsPerVectorAccess, + ElementAccumulator, + ElementAccumulator, + cutlass::epilogue::thread::ScaleType::Default>; +}; +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h new file mode 100644 index 0000000000000000000000000000000000000000..1726bc9054ddabcd8da8d99dd8e07ed0ad0f016f --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h @@ -0,0 +1,60 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace phi { +// Note: The shapes are in the format MxNxK. The K shape of the runtime config +// MUST match the K shape +// in the kernel layout details when doing weight only quantization. 
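The DefaultIteratorsTensorOp and SharedLoadIteratorMixed specializations above let fp16/bf16 outputs be produced straight from int32 tensor-core accumulators, swizzling the shared-memory load pointers across 128-byte regions to avoid bank conflicts, and epilogue_helpers.h maps the empty tag structs (EpilogueOpBiasSilu, EpilogueOpBiasReLU, EpilogueOpBiasFtGelu, EpilogueOpBias, EpilogueOpNoBias) to concrete CUTLASS LinearCombination functors. An example of resolving a tag, assuming the parameter order ElementType, ElementsPerVectorAccess, ElementAccumulator, tag (the template parameter lists are elided above):

#include "cutlass/half.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue_helpers.h"

// Assumed instantiation: half output, 8 elements per 128-bit access,
// float accumulator, fused FT-GELU epilogue.
using FtGeluEpilogue =
    typename phi::Epilogue<cutlass::half_t, 8, float, phi::EpilogueOpBiasFtGelu>::Op;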
+enum class CutlassTileConfig { + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // SiMT config + CtaShape128x128x8_WarpShape64x64x8, + + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, + + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, + + // Warp configs for M=128 + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape128x32x64, + + // configs for large M in encoder + CtaShape128x256x64_WarpShape64x64x64, + CtaShape256x128x64_WarpShape64x64x64 +}; + +enum class SplitKStyle { + NO_SPLIT_K, + SPLIT_K_SERIAL, + // SPLIT_K_PARALLEL // Not supported yet +}; + +struct CutlassGemmConfig { + CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; +}; +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h new file mode 100644 index 0000000000000000000000000000000000000000..5fc621b0b8bf696bf38d2a1891a056da33b776bb --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -0,0 +1,149 @@ + + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct MixedGemmArchTraits {}; + +template +struct MixedGemmArchTraits { + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::RowMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// // ========================= Volta Traits =========================== +// // Volta will always dequantize after the global memory load. +// // This will instantiate any HMMA tensorcore kernels for Volta. +// // Note that volta does not have native bfloat support so weights and +// activations will be casted to fp16 +// // and compute will happen in fp16 then will be converted for bf16 output. 
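To make the knobs above concrete, here is a minimal sketch (not part of this diff; the function name ExampleAmpereConfig is hypothetical) of the kind of config a heuristic would return. Note that the K extent encoded in the tile name (64 here) must match the ThreadblockK of the kernel's layout details when the weight is quantized, as the note above requires.

#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h"

inline phi::CutlassGemmConfig ExampleAmpereConfig() {
  phi::CutlassGemmConfig config;
  config.tile_config = phi::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64;
  config.split_k_style = phi::SplitKStyle::SPLIT_K_SERIAL;
  config.split_k_factor = 2;  // reduce the K dimension in two serial slices
  config.stages = 3;          // mainloop pipeline depth for the multistage kernel
  return config;
}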
+template +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm70, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Turing Traits ============================== +// Note that turing does not have native bfloat support so weights and +// activations will be casted to fp16 and compute will happen in fp16 then will +// be converted for bf16 output. +template +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm75, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ampere Traits ============================== +template +struct MixedGemmArchTraits< + TypeA, + TypeB, + cutlass::arch::Sm80, + typename cutlass::platform::enable_if< + cutlass::platform::is_same::value || + cutlass::platform::is_same::value>::type> { + private: + using LayoutDetails = LayoutDetailsB; + + public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = + 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = + 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = typename LayoutDetails::Operator; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..b6d880d4de5d1725a6ff1fed66d8013564432273 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -0,0 +1,532 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or + support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmFpAIntB { + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Element; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = float; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; + + // Type definitions about the mainloop. 
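As a cross-reference to default_fpA_intB_traits.h above: the kernel derives its access widths, instruction shape, and B layout from MixedGemmArchTraits rather than hard-coding them. A small illustrative query for fp16 activations with int8 weights on Ampere (the alias names Sm80Traits etc. are mine):

#include "cutlass/half.h"
#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"

using Sm80Traits = cutlass::gemm::kernel::
    MixedGemmArchTraits<cutlass::half_t, uint8_t, cutlass::arch::Sm80>;

static_assert(Sm80Traits::ElementsPerAccessA == 8,
              "fp16 A operand is loaded with 128-bit (8-element) accesses");
using Sm80Instruction = Sm80Traits::InstructionShape;  // GemmShape<16, 8, 16>
using Sm80LayoutB = Sm80Traits::LayoutB;  // column-major tile-interleaved for int8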
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = + Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, 128 / sizeof_bits::value); + + static constexpr int kInterleave = + Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + + /// Parameters structure + struct Arguments : UniversalArgumentsBase { + GemmUniversalMode mode = GemmUniversalMode::kGemm; + + cutlass::gemm::GemmCoord problem_size; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + + // Control serial split-k + int batch_count; + + typename EpilogueOutputOp::Params output_op; + + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // Included so we can use Gemm Universal + int batch_stride_D = 0; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Arguments() {} + + CUTLASS_HOST_DEVICE + Arguments(cutlass::gemm::GemmCoord const& problem_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorScale::TensorRef ref_scale, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + int serial_split_k_factor, + typename EpilogueOutputOp::Params output_op = + typename EpilogueOutputOp::Params(), + int const* gather_A_indices = nullptr, + int const* gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) + : UniversalArgumentsBase(mode, + problem_size, + /*serial_split_k_factor=*/1, + /*batch_stride_D=*/0), + ref_A(ref_A), + ref_B(ref_B), + ref_scale(ref_scale), + ref_C(ref_C), + ref_D(ref_D), + output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) {} + }; + + /// Parameters structure + struct Params : UniversalParamsBase { + using ParamsBase = UniversalParamsBase; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorScale::Params params_scale; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename EpilogueOutputOp::Params output_op; + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() = default; + + CUTLASS_HOST_DEVICE + 
Params(Arguments const& args, int device_sms, int sm_occupancy) + : ParamsBase(args, device_sms, sm_occupancy), + params_A(args.ref_A.layout()), + ref_A(args.ref_A), + params_B(args.ref_B.layout()), + ref_B(args.ref_B), + params_scale(args.ref_scale.layout()), + ref_scale(args.ref_scale), + params_C(args.ref_C.layout()), + ref_C(args.ref_C), + params_D(args.ref_D.layout()), + ref_D(args.ref_D), + output_op(args.output_op), + gather_A_indices(args.gather_A_indices), + gather_B_indices(args.gather_B_indices), + scatter_D_indices(args.scatter_D_indices) {} + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntB() {} + + /// Determines whether kernel satisfies alignment + CUTLASS_HOST_DEVICE + static Status can_implement(Arguments const& args) { + static int const kAlignmentA = + (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = + (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + + static int const kAlignmentScale = + Mma::IteratorScale::AccessType::kElements; + + static int const kAlignmentC = + (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(args.ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + static size_t get_extra_workspace_size( + Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + + CUTLASS_DEVICE + static void invoke(Params const& params, + SharedStorage& shared_storage) { // NOLINT + GemmFpAIntB op; + op(params, shared_storage); + } + + // The dummy template parameter is not used and exists so that we can compile + // this code using a standard earlier than C++17. 
Prior to C++17, fully + // specialized templates HAD to exists in a namespace + template + struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const& params, + SharedStorage& shared_storage) { // NOLINT + CUTLASS_NOT_IMPLEMENTED(); + } + }; + + template + struct KernelRunner { + CUTLASS_DEVICE + static void run_kernel(Params const& params, + SharedStorage& shared_storage) { // NOLINT + using LayoutB = typename Mma::IteratorB::Layout; + static_assert( + platform::is_same::value && + kInterleave == 1 || + platform::is_same::value && + kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{ + threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + + cutlass::MatrixCoord tb_offset_scale{ + 0, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = + min(params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = + (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / + Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, + params.ref_B.data(), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, + thread_idx, + tb_offset_B, + params.gather_B_indices); + + typename Mma::IteratorScale iterator_scale(params.params_scale, + params.ref_scale.data(), + {1, params.problem_size.n()}, + thread_idx, + tb_offset_scale); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
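The gemm_k_iterations computation above is easier to see with numbers. A minimal constexpr sketch (the values are illustrative, not taken from this patch): with K = 4096 split into two serial slices and a 64-wide K tile, each slice runs 32 mainloop iterations.

constexpr int kProblemK = 4096;
constexpr int kSlices = 2;                     // serial split-K factor
constexpr int kGemmKSize = kProblemK / kSlices;  // 2048, assumed already aligned
constexpr int kTileK = 64;                     // Mma::Shape::kK
// Second slice: threadblock_tile_offset.k() == 1
constexpr int kOffsetK = 1 * kGemmKSize;                          // 2048
constexpr int kSliceEnd = (1 + 1) * kGemmKSize;                   // min(4096, 4096)
constexpr int kIters = (kSliceEnd - kOffsetK + kTileK - 1) / kTileK;
static_assert(kIters == 32, "each split-K slice covers 2048 elements of K");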
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + iterator_scale, + accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial + // synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is + // currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), + params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + params.ref_C.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + params.ref_D.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices); + + Epilogue epilogue( + shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator + // construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' + // tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } + }; + + /* + To improve compilation speed, we do not compile the device operator if the + CUDA_ARCH does not correspond to the ArchTag of the cutlass kernel + operator. 
+ */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, + SharedStorage& shared_storage) { // NOLINT +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750) + // static constexpr bool compile_needed = platform::is_same::value; KernelRunner::run_kernel(params, + // shared_storage); + CUTLASS_NOT_IMPLEMENTED(); + +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + // static constexpr bool compile_needed = platform::is_same::value; KernelRunner::run_kernel(params, + // shared_storage); + CUTLASS_NOT_IMPLEMENTED(); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + static constexpr bool compile_needed = + platform::is_same::value; + KernelRunner::run_kernel(params, shared_storage); +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h new file mode 100644 index 0000000000000000000000000000000000000000..4aa34c479d5c0ccd7cda9b46387804521cc18ecd --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -0,0 +1,125 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/* + This file exists so that we use the same weight layout for MoE grouped gemm + and regular gemm when the weight is quantized. The preprocessing code reads + this template to know how to organize the quantized weight matrices to be + consumed by CUTLASS. + + Note that for int4, ThreadBlockK MUST be 64. + + */ + +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/platform/platform.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h" + +namespace cutlass { +namespace gemm { +namespace kernel { + +template +struct LayoutDetailsB {}; + +// // Volta specialiations. Volta will dequantize before STS, so we need a +// different operator +template +struct LayoutDetailsB { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = 8; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specializations for Turing+ when B is FP16. These are currently only used for +// MoE networks. 
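The quantized-B specializations that follow pack ColumnsInterleaved logical columns of a K tile into a single 128-byte cache line. A short sketch of that arithmetic (illustrative only; the helper below is not part of this file):

constexpr int kThreadblockK = 64;
constexpr int ElementsPerCacheLine(int element_bits) {
  return 128 * 8 / element_bits;  // elements held by one 128-byte line
}
static_assert(ElementsPerCacheLine(8) / kThreadblockK == 2,
              "int8 weights: 2 columns interleaved per 64-deep K tile");
static_assert(ElementsPerCacheLine(4) / kThreadblockK == 4,
              "int4 weights: 4 columns interleaved per 64-deep K tile");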
+template +struct LayoutDetailsB< + half_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB< + bfloat16_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = 64; + using Layout = layout::RowMajor; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specializations for Turing+ when B is quantized. These can use the operator +// OpMultiplyAddDequantizeInterleavedBToA, which signals that we want to +// dequantize after loading from smem. +template +struct LayoutDetailsB< + uint8_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = 64; + + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = + layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template +struct LayoutDetailsB< + uint4b_t, + Arch, + typename platform::enable_if= 75>::type> { + static constexpr int ThreadblockK = 64; + + private: + static constexpr int ElementsPerCacheLine = + 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + + public: + using Layout = + layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = + 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..b4b450d7060d919bdcf39bce93eca57528e9a12f --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma.h @@ -0,0 +1,124 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/interleaved_numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { +//////////////////////////////////////////////////////////////////////////////// + +// We need to distinguish here, since we want volta support. It is too much +// effort to write shared memory iterators that are probably needed for volta to +// function properly. 
As a result, we allow converters both after the LDG (for +// volta) and after the LDS for Turing+. +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Warp level Mma + typename MmaOperator, + /// Math operation perform by warp level operator + typename MathOperator> +struct SetConverters {}; + +// Dequantize after LDG, so set transforms accordingly +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters { + using TransformAfterLDG = FastInterleavedAndBiasedNumericArrayConverter< + typename MmaOperator::ArchMmaOperator::ElementB, + typename IteratorB::Element, + IteratorB::Fragment::kElements>; + + using TransformAfterLDS = + NumericArrayConverter; +}; + +// Dequantize after LDS, so set transforms accordingly + +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters { + using TransformAfterLDG = + NumericArrayConverter; + + using TransformAfterLDS = FastInterleavedAndBiasedNumericArrayConverter< + typename MmaOperator::ArchMmaOperator::ElementB, + typename TransformAfterLDG::result_type::Element, + MmaOperator::FragmentB::kElements>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale_, + /// Layout for the scale operand + typename LayoutScale_, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// + typename Enable = void> +struct DqMma; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..d349a09e967d042e49532b7b55f9b8475c1e4ea7 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -0,0 +1,393 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for elementA + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// + typename Operator, + /// + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= + 80)>::type> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert( + platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || + platform::is_same::value, + "Element B must be int8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma + // multistage pieces are created + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, + LayoutB, + 0, + ThreadMapB, + AccessTypeB>; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + MmaCore::Shape::kN / kAlignmentScale, + kAlignmentScale>; + + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape<1, MmaCore::Shape::kN>, + ElementScale, + LayoutScale, + 0, + IteratorScaleThreadMap, + kAlignmentScale>; + + using SmemIteratorScale = IteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter< + ElementA, + ElementB, + MmaCore::MmaPolicy::Operator::FragmentB::kElements>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, + IteratorB, + typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, + IteratorScale, + SmemIteratorScale, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + kStages, + Converter, + SharedMemoryClear>; +}; + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// + typename Operator, + /// + SharedMemoryClearOption SharedMemoryClear, + /// + int RowsPerTile, + /// + int ColumnsInterleaved> +struct DqMma, + kAlignmentB, + ElementScale, + LayoutScale, + kAlignmentScale, + ElementAccumulator, + layout::RowMajor, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator, + SharedMemoryClear, + typename 
platform::enable_if<(ArchTag::kMinComputeCapability >= + 80)>::type> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert( + platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || + platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma + // multistage pieces are created + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + ThreadMapA, + AccessTypeA>; + + private: + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = + typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = + MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + OriginalThreadMap::kThreads, + layout::PitchLinearShape< + OriginalWarpArrangement::kContiguous * ColumnsInterleaved, + OriginalWarpArrangement::kStrided / ColumnsInterleaved>, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + GmemIteratorShape, + ElementB, + layout::ColumnMajor, + 0, + GmemThreadMapB, + AccessTypeB>; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + MmaCore::Shape::kN / kAlignmentScale, + kAlignmentScale>; + + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape<1, MmaCore::Shape::kN>, + ElementScale, + LayoutScale, + 0, + IteratorScaleThreadMap, + kAlignmentScale>; + + using SmemIteratorScale = IteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter< + ElementA, + ElementB, + MmaCore::MmaPolicy::Operator::FragmentB::kElements>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, + IteratorB, + typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, + IteratorScale, + SmemIteratorScale, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + kStages, 
+ Converter, + SharedMemoryClear>; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..c641690d02dde1395cb14bc527bef09a2ede8455 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -0,0 +1,360 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DqMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementScale, + LayoutScale, + kAlignmentScale, + ElementAccumulator, + layout::RowMajor, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + 2, + Operator, + SharedMemoryClearOption::kNone, + typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Element A must be fp16 or bf16"); + + 
static_assert(platform::is_same::value || + platform::is_same::value, + "Element B must be int8 or uint4"); + + static constexpr bool DqAfterLDG = + platform::is_same::value; + static constexpr bool arch_has_bf16_mma = + ArchTag::kMinComputeCapability >= 80; + using MmaCoreElementA = + typename platform::conditional::type; + using MmaCoreElementB = typename platform:: + conditional::type; + + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + typename MmaCore::IteratorThreadMapA, + kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, + LayoutB, + 0, + typename MmaCore::IteratorThreadMapB, + kAlignmentB>; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + MmaCore::Shape::kN / kAlignmentScale, + kAlignmentScale>; + + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape<1, MmaCore::Shape::kN>, + ElementScale, + LayoutScale, + 0, + IteratorScaleThreadMap, + kAlignmentScale>; + + using SmemScaleType = typename platform:: + conditional::type; + using SmemIteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape<1, MmaCore::Shape::kN>, + SmemScaleType, + LayoutScale, + 0, + IteratorScaleThreadMap, + kAlignmentScale>; + + using Converters = + SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + IteratorB, + typename MmaCore::SmemIteratorB, + IteratorScale, + SmemIteratorScale, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + typename Converters::TransformAfterLDG, + typename Converters::TransformAfterLDS>; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int RowsPerTile, + /// + int ColumnsInterleaved> +struct DqMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + layout::ColumnMajorTileInterleave, + kAlignmentB, + ElementScale, + 
LayoutScale, + kAlignmentScale, + ElementAccumulator, + layout::RowMajor, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + 2, + Operator, + SharedMemoryClearOption::kNone, + typename platform::enable_if<(ArchTag::kMinComputeCapability < 80)>::type> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || + platform::is_same::value, + "Element B must be int8 or uint4"); + + static constexpr bool DqAfterLDG = + platform::is_same::value; + static constexpr bool arch_has_bf16_mma = + ArchTag::kMinComputeCapability >= 80; + using MmaCoreElementA = + typename platform::conditional::type; + using MmaCoreElementB = typename platform:: + conditional::type; + + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, + LayoutA, + 1, + typename MmaCore::IteratorThreadMapA, + kAlignmentA>; + + private: + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = + typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape = + MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, + OriginalThreadMap::kThreads, + layout::PitchLinearShape< + OriginalWarpArrangement::kContiguous * ColumnsInterleaved, + OriginalWarpArrangement::kStrided / ColumnsInterleaved>, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + + public: + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + GmemIteratorShape, + ElementB, + layout::ColumnMajor, + 0, + GmemThreadMapB, + kAlignmentB>; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + MmaCore::Shape::kN / kAlignmentScale, + kAlignmentScale>; + + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape<1, MmaCore::Shape::kN>, + ElementScale, + LayoutScale, + 0, + IteratorScaleThreadMap, + kAlignmentScale>; + + using SmemScaleType = typename platform:: + conditional::type; + using SmemIteratorScale = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape<1, MmaCore::Shape::kN>, + SmemScaleType, + LayoutScale, + 0, + IteratorScaleThreadMap, + kAlignmentScale>; + + using Converters = + SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined< + typename MmaCore::Shape, + IteratorA, + typename MmaCore::SmemIteratorA, + IteratorB, + typename MmaCore::SmemIteratorB, + IteratorScale, + SmemIteratorScale, + ElementAccumulator, + layout::RowMajor, + typename MmaCore::MmaPolicy, + typename Converters::TransformAfterLDG, + typename Converters::TransformAfterLDS>; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git 
a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..5f58189a2f6df949f686852ce46aa1f7588f3de0 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma.h @@ -0,0 +1,443 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma_bf16.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), fp16 +/// activation & int8 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 +/// activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// 
Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 +/// activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix 
multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps +// avoid reg spills on large tile when not enough shared mem is present to do 3+ +// stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + half_t, + LayoutA, + 1, + ThreadMapA, + AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + half_t, + LayoutB, + 0, + ThreadMapB, + AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma_bf16.h new file mode 100644 index 0000000000000000000000000000000000000000..be85b18ae30915bd96e6990164a817a2a6aedba9 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -0,0 +1,550 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 +/// activation & bf16 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + private: + // Conversions only needed pre-ampere. This will trigger mma pipeline, so we + // convert before STS. + static constexpr bool arch_has_bf16_mma = + ArchTag::kMinComputeCapability >= 80; + using MmaElementA = typename platform:: + conditional::type; + using MmaElementB = typename platform:: + conditional::type; + + public: + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutA, + 1, + typename MmaCore::IteratorThreadMapA, + kAlignmentA, + GatherA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutB, + 0, + typename MmaCore::IteratorThreadMapB, + kAlignmentB, + GatherB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaPipelined; +}; + +// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. 
Helps +// avoid reg spills on large tile when not enough shared mem is present to do 3+ +// stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma { + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutA, + 1, + ThreadMapA, + AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + bfloat16_t, + LayoutB, + 0, + ThreadMapB, + AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = + cutlass::gemm::threadblock::MmaMultistage; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 +/// activation & int8 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass 
TensorOp), bf16 +/// activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 +/// activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + 
/// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma { + private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + + public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_base.h new file mode 100644 index 0000000000000000000000000000000000000000..7932921d69434e408b02172ec12729cd112cf82d --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_base.h @@ -0,0 +1,237 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// +// SFINAE trick so I can keep the same loop code for Volta and dispatch to the +// correct warp level mma. On volta, all data is stored to shared memory as +// FP16. 
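// A minimal, self-contained sketch of the overload-resolution trick described
// above and used by run_warp_mma below (illustrative only: PlainWarpMma and
// ConvertingWarpMma are made-up stand-ins, not the CUTLASS warp-level
// operators this file works with). The caller always passes the extra k
// offset; the overload whose fragment typedefs exist for the given warp MMA
// is the one that survives template substitution, so the same mainloop source
// drives both operator flavours.
#include <iostream>

struct PlainWarpMma {
  struct FragmentA {};
  struct FragmentB {};
  struct FragmentC {};
  void operator()(FragmentC&, FragmentA const&, FragmentB const&,
                  FragmentC const&) const {
    std::cout << "plain warp mma, k offset ignored\n";
  }
};

struct ConvertingWarpMma {
  struct TransformedFragmentA {};
  struct TransformedFragmentB {};
  struct FragmentC {};
  void operator()(FragmentC&, TransformedFragmentA const&,
                  TransformedFragmentB const&, FragmentC const&,
                  int k_offset) const {
    std::cout << "converting warp mma, k offset " << k_offset << "\n";
  }
};

// Selected when WarpMma exposes plain FragmentA / FragmentB.
template <typename WarpMma>
void run_warp_mma_sketch(WarpMma& mma, typename WarpMma::FragmentC& D,
                         typename WarpMma::FragmentA const& A,
                         typename WarpMma::FragmentB const& B,
                         typename WarpMma::FragmentC const& C,
                         int /*warp_tileB_k_offset*/) {
  mma(D, A, B, C);
}

// Selected when WarpMma exposes TransformedFragmentA / TransformedFragmentB
// and needs the k offset to index the wider, pre-conversion B fragment.
template <typename WarpMma>
void run_warp_mma_sketch(WarpMma& mma, typename WarpMma::FragmentC& D,
                         typename WarpMma::TransformedFragmentA const& A,
                         typename WarpMma::TransformedFragmentB const& B,
                         typename WarpMma::FragmentC const& C,
                         int warp_tileB_k_offset) {
  mma(D, A, B, C, warp_tileB_k_offset);
}

int main() {
  PlainWarpMma plain;
  PlainWarpMma::FragmentC d0, c0;
  PlainWarpMma::FragmentA a0;
  PlainWarpMma::FragmentB b0;
  run_warp_mma_sketch(plain, d0, a0, b0, c0, 0);  // resolves to first overload

  ConvertingWarpMma conv;
  ConvertingWarpMma::FragmentC d1, c1;
  ConvertingWarpMma::TransformedFragmentA a1;
  ConvertingWarpMma::TransformedFragmentB b1;
  run_warp_mma_sketch(conv, d1, a1, b1, c1, 1);   // resolves to second overload
  return 0;
}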
+template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, // NOLINT + typename WarpMma::FragmentC& D, // NOLINT + typename WarpMma::FragmentA const& A, + typename WarpMma::FragmentB const& B, + typename WarpMma::FragmentC const& C, + const int warp_tileB_k_offset) { + warp_mma(D, A, B, C); +} + +template +CUTLASS_DEVICE void run_warp_mma( + WarpMma& warp_mma, // NOLINT + typename WarpMma::FragmentC& D, // NOLINT + typename WarpMma::TransformedFragmentA const& A, + typename WarpMma::TransformedFragmentB const& B, + typename WarpMma::FragmentC const& C, + const int warp_tileB_k_offset) { + warp_mma(D, A, B, C, warp_tileB_k_offset); +} +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// The type of the scales + typename ElementScale_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class DqMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + ///< Type of the scale to be loaded + using ElementScale = ElementScale_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + static constexpr int kNumKIterationsPerWarpBLoad = + Operator::IteratorB::InstructionShape::kRow / + Operator::InstructionShape::kK; + + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = + kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = + TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = + MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_scale; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + 
/// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, // NOLINT + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..c2e5d85371d55d6f0c06bb5904ec68f553c21c24 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -0,0 +1,638 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
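// Illustrative sketch of the per-fragment dequantization that the mainloop
// below performs after each B tile is loaded from shared memory: quantized
// int8 weights are converted to the compute type and multiplied by one scale
// per output column before being handed to the tensor-op MMA. The names and
// the plain host loop are assumptions made for clarity, not the CUTLASS
// fragment types or the warp::MmaTensorOpDequantizer used in this file.
#include <cstdint>
#include <vector>

std::vector<float> dequantize_b_fragment(const std::vector<int8_t>& frag_b,
                                         const std::vector<float>& column_scale,
                                         size_t rows) {
  // frag_b is assumed column-major with `rows` elements per output column,
  // so frag_b.size() == rows * column_scale.size().
  std::vector<float> out(frag_b.size());
  for (size_t c = 0; c < column_scale.size(); ++c) {
    for (size_t r = 0; r < rows; ++r) {
      out[c * rows + r] =
          static_cast<float>(frag_b[c * rows + r]) * column_scale[c];
    }
  }
  return out;
}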
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = bool> +class DqMmaMultistage : public DqMmaBase { + public: + ///< Base class + using Base = + DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + // + // Dependent types + // + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = + warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
+ struct Detail { + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared + /// memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, // NOLINT + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_( + {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), + shared_storage.operand_scale.data(), + {1, Shape::kN}, + thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( 
+ {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA& iterator_A, // NOLINT + IteratorB& iterator_B, // NOLINT + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, // NOLINT + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group + // as + // the first load of A. 
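    // The scale tile is fetched from global memory only once per threadblock:
    // the per-column scales are reused by every K stage of the mainloop, so
    // they are loaded and stored to shared memory here, ahead of the
    // multi-stage prologue for the A and B operands below.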
+ FragmentScale tb_frag_scales; + tb_frag_scales.clear(); + iterator_scale.load(tb_frag_scales); + this->smem_iterator_scale_.store(tb_frag_scales); + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. 
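    // The prologue above issued and committed kStages - 1 cp.async groups,
    // one per stage. Waiting until at most kStages - 2 of them remain in
    // flight guarantees that the first stage has fully landed in shared
    // memory before any warp reads it; for example, with kStages = 3 two
    // groups were issued and we wait for one of them to complete.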
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + warp_dequantizer_.load(warp_frag_scales); + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + const int warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + run_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B, + accum, + warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + } + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..7f0417d5b9520249f79b45ad666227d0013a680a --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h @@ -0,0 +1,408 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/interleaved_numeric_conversion.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// Used for partial specialization + typename Enable = bool> +class DqMmaPipelined : public DqMmaBase { + public: + ///< Base class + using Base = + DqMmaBase; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = + IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded 
from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = warp::MmaTensorOpDequantizer< + Operator, + typename Base::WarpGemm, + Operand::kB, + typename SmemIteratorScale::Fragment::Element, + LayoutScale, + 32, + typename Operator::FragmentA::Element>; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered + // pipeline) + static_assert((Base::kStages == 2), + "DqMmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave = + layout::IsColumnMajorTileInterleave< + typename LayoutDetailsForB::Layout>::value; + static_assert(!RequiresTileInterleave || + (RequiresTileInterleave && + (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared + /// memory + SmemIteratorScale smem_iterator_scale_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined(typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal + ///< use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + warp_dequantizer_( + {shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / + Base::WarpCount::kM, + lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_scale_(LayoutScale(Shape::kN), + shared_storage.operand_scale.data(), + {1, Shape::kN}, + thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, 
Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile // NOLINT + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale + iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) { ///< source accumulator tile + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using TransformA = NumericArrayConverter; + + using TransformScale = + NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and + // weights in GMEM and want to issue HMMA on architectures older than + // Ampere. We will convert to FP16 before STS. + TransformA transformA; + TransformScale transformScale; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + FragmentScale tb_frag_scales; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + WarpFragmentScale warp_frag_scales; + + tb_frag_A.clear(); + tb_frag_B.clear(); + tb_frag_scales.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + iterator_scale.load(tb_frag_scales); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + warp_dequantizer_.load(warp_frag_scales); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
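        // At the last warp-level k-group of each k-block, the A/B fragments
        // fetched from global memory during this iteration are written into
        // the other half of the double-buffered (kStages == 2) shared-memory
        // tile, and either the write-side or the read-side iterators are
        // rewound so both sides keep ping-ponging between the two buffers.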
+ + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterationsForB, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % + Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + const int warp_tileB_k_compute_offset = + warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + const int warp_tileB_k_load_offset = + warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate + // the load for the next fragment. + if (warp_tileB_k_compute_offset == + Base::kNumKIterationsPerWarpBLoad - 1) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load( + warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B = + lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + run_warp_mma(warp_mma, + accum, + warp_frag_A[warp_mma_k % 2], + converted_frag_B, + accum, + warp_tileB_k_compute_offset); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/default_mma_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..4651d97b49c3597257fbf6889caa919bf19e3d42 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -0,0 +1,115 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +/*! \file + \brief Default warp-level GEMM operators selected by data type, size, and + layouts of operands. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/arch/mma.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements, + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp { + private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; + + // Chosen so we get K=16 for int8 and K=32 for int4. + static constexpr int LoadInstructionK = + 8 * sizeof_bits::value / sizeof_bits::value; + + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = + GemmShape; + + public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1>>; + + // Define the warp-level tensor op + using Type = + cutlass::gemm::warp::MmaTensorOpComputeBWithF16; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h new file mode 100644 index 0000000000000000000000000000000000000000..a554fab610ac92d8c96a242758a39c30c00cec11 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -0,0 +1,304 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +/*! 
\file + \brief Templates implementing warp-level matrix multiply-accumulate + operations targeting Tensor Cores. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Instruction shape to override shared memory iterators with + typename SharedMemoryInstructionShape_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool> +class MmaTensorOpComputeBWithF16 { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert( + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value && + ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports underlying HMMA"); + + static_assert(platform::is_same::value || + (platform::is_same::value && + ArchTag::kMinComputeCapability >= 80), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert(SharedMemoryInstructionShape::kM == InstructionShape::kM, + "M dimension of compute instruction must match load"); + static_assert(SharedMemoryInstructionShape::kN == InstructionShape::kN, + "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = + SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + public: + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kA, + ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, + Operand::kB, + ElementB, + LayoutB, + MatrixShape, + Policy::OpDelta::kRow, + kThreadCount, + kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = + MmaTensorOpAccumulatorTileIterator, + ElementC, + 
LayoutC, + typename ArchMmaOperator::Shape, + typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = + MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / + ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / + ArchMmaOperator::Shape::kN>; + + public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + + public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, // NOLINT + TransformedFragmentA const& A, + TransformedFragmentB const& B, + FragmentC const& C, + const int warp_tileB_k_offset) const { + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + static_assert( + TransformedFragmentB::kElements == + MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column " + "iteration AND for the expanded K dim of B"); + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = + warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h new file mode 100644 index 0000000000000000000000000000000000000000..99de23e7f45041a89ca7884abf1b4b3591861a45 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -0,0 +1,707 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations + targeting Tensor Cores. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Number of threads participating in one matrix operation + int Threads, + /// + /// Data type of out elements + typename Element_out_, + typename Enable = void> +class MmaTensorOpDequantizer; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +#ifdef PADDLE_CUDA_BF16 +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_> +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + bfloat16_t, + layout::RowMajor, + 32, + bfloat16_t, + typename platform::enable_if< + MmaOperator_::ArchTag::kMinComputeCapability >= 80 && + platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = + MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = bfloat16_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = + Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + const int warp_idx_n, + const int lane_idx) { + const int warp_offset = warp_idx_n * Shape::kN; + const int quad = lane_idx / 4; + const int thread_offset = warp_offset + quad; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + const FragmentScale& scale_frag) { + // Slow path not implemented here on purpose. 
If we need to do HMMA on older + // arch, scale conversion should happen before scales are stored to shared + // memory and we should use the fp16 dequantizer. This will avoid numerous + // conversion instructions in GEMM main loop. + arch::device_breakpoint(); + } + + private: + ElementScale const* pointer_; +}; +#endif + +// Specialization for Turing & Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_> +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + half_t, + layout::RowMajor, + 32, + half_t, + typename platform::enable_if< + MmaOperator_::ArchTag::kMinComputeCapability >= 75 && + platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = + MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = + Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + const int warp_idx_n, + const int lane_idx) { + const int warp_offset = warp_idx_n * Shape::kN; + const int quad = lane_idx / 4; + const int thread_offset = warp_offset + quad; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + const FragmentScale& scale_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = + Array; + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == + FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + ExpandedMmaOperandB* operand_frag_ptr = + reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = + mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + + private: + ElementScale const* pointer_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Volta A x RowMajor B tensorOp, for 32x32x4 interleaved +// gemm +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_> +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + 
half_t, + layout::RowMajor, + 32, + half_t, + typename platform::enable_if< + platform::is_same::value && + platform::is_same::value>::type> { + public: + static_assert(platform::is_same>::value, + ""); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; + + /// Warp mma shape + using Shape = Shape_; + + // Fragment to hold scale data to apply to B before mma + // Each 32x32x4 matmul uses 8 elements from B. + static constexpr int ColsPerMmaTile = 32; + static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; + using FragmentScale = Array; + using AccessType = Array; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + const int warp_idx_n, + const int lane_idx) { + const int warp_offset = warp_idx_n * Shape::kN; + const int base_col = lane_idx & 0xF8; + const int thread_offset = warp_offset + base_col; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { // NOLINT + AccessType* scale_frag_ptr = reinterpret_cast(&scale_frag); + + CUTLASS_PRAGMA_UNROLL + for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) { + // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. + scale_frag_ptr[tile_iter] = *reinterpret_cast( + pointer_ + ColsPerMmaTile * tile_iter); + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, // NOLINT + const FragmentScale& scale_frag) { + static_assert( + FragmentScale::kElements == FragmentDequantizedOperand::kElements, ""); + + multiplies mul_op; + operand_frag = mul_op(operand_frag, scale_frag); + } + + private: + ElementScale const* pointer_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Volta A x ColumnMajor B tensorOp, for 32x32x4 interleaved +// gemm +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_> +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + half_t, + layout::RowMajor, + 32, + typename platform::enable_if< + platform::is_same::value && + platform::is_same::value>::type> { + public: + static_assert(platform::is_same>::value, + ""); + + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; + + /// Warp mma shape + using Shape = Shape_; + + // Fragment to hold scale data to apply to B before mma + // Each 32x32x4 matmul uses 8 elements from B. 
+ static constexpr int ColsPerMmaTile = 32; + static constexpr int TileNIterations = Shape::kN / ColsPerMmaTile; + using FragmentScale = Array; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + const int warp_idx_n, + const int lane_idx) { + const int warp_offset = warp_idx_n * Shape::kN; + const int base_col = lane_idx & 0xF8 + lane_idx % 4; + const int thread_offset = warp_offset + base_col; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { // NOLINT + CUTLASS_PRAGMA_UNROLL + for (int tile_iter = 0; tile_iter < TileNIterations; ++tile_iter) { + // We jump by 32 here since volta does <32x32x4> super mmas inside a warp. + // For col major B, each thread will jump 4 cols to get its next value + // inside of the super mma. + CUTLASS_PRAGMA_UNROLL + for (int mma_iter = 0; mma_iter < 2; ++mma_iter) { + scale_frag[tile_iter * 2 + mma_iter] = + pointer_[ColsPerMmaTile * tile_iter + 4 * mma_iter]; + } + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, // NOLINT + const FragmentScale& scale_frag) { + using MmaOperandB = typename ArchMmaOperator::FragmentB; + static constexpr int total_n_mmas = 2 * TileNIterations; + static_assert(MmaOperandB::kElements * total_n_mmas == + FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + MmaOperandB* operand_frag_ptr = + reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < total_n_mmas; ++mma_n_iter) { + operand_frag_ptr[mma_n_iter] = + mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + + private: + ElementScale const* pointer_; +}; + +// Specialization for Turing & Ampere when Scale type is float and output type +// is half_t. +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_> +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + float, + layout::RowMajor, + 32, + half_t, + typename platform::enable_if< + MmaOperator_::ArchTag::kMinComputeCapability >= 75 && + platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. 
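+  // As a rough worked example (assuming the usual Ampere fp16 tensor-core
+  // instruction shape m16n8k16): int8 weights are loaded from shared memory
+  // with a K extent of 16, giving a ratio of 16 / 16 = 1, while int4 weights
+  // are loaded with a K extent of 32, giving 32 / 16 = 2. Each loaded B
+  // fragment therefore covers kExpansionFactor compute instructions along K.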
+ static constexpr int kExpansionFactor = + MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the output + using ElementType = half_t; + + // using ElementScale = float; + using ElementScale = float; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = + Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + const int warp_idx_n, + const int lane_idx) { + const int warp_offset = warp_idx_n * Shape::kN; + const int quad = lane_idx / 4; + const int thread_offset = warp_offset + quad; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + const FragmentScale& scale_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = + Array; + + using ComputeFrag = + Array; + + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == + FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + ExpandedMmaOperandB* operand_frag_ptr = + reinterpret_cast(&operand_frag); + + NumericArrayConverter + source_converter; + NumericArrayConverter + output_converter; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + ComputeFrag convert_frag = source_converter(operand_frag_ptr[mma_n_iter]); + convert_frag = mul_op(convert_frag, scale_frag[mma_n_iter]); + operand_frag_ptr[mma_n_iter] = output_converter(convert_frag); + } + } + + private: + ElementScale const* pointer_; +}; + +// Specialization for Turing & Ampere when Scale type is float and output type +// is bfloat16. +#ifdef PADDLE_CUDA_BF16 +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_> +class MmaTensorOpDequantizer< + MmaOperator_, + Shape_, + Operand::kB, + float, + layout::RowMajor, + 32, + bfloat16_t, + typename platform::enable_if< + MmaOperator_::ArchTag::kMinComputeCapability >= 75 && + platform::is_same::value>::type> { + public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. 
+ static constexpr int kExpansionFactor = + MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the output + using ElementType = bfloat16_t; + + // using ElementScale = float; + using ElementScale = float; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = + Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = + Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, + const int warp_idx_n, + const int lane_idx) { + const int warp_offset = warp_idx_n * Shape::kN; + const int quad = lane_idx / 4; + const int thread_offset = warp_offset + quad; + pointer_ = smem_scales.data() + thread_offset; + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + scale_frag[mma_n_iter] = pointer_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, + const FragmentScale& scale_frag) { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = + Array; + + using ComputeFrag = + Array; + + static_assert( + ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn == + FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + ExpandedMmaOperandB* operand_frag_ptr = + reinterpret_cast(&operand_frag); + + NumericArrayConverter + source_converter; + NumericArrayConverter + output_converter; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; + ++mma_n_iter) { + ComputeFrag convert_frag = source_converter(operand_frag_ptr[mma_n_iter]); + convert_frag = mul_op(convert_frag, scale_frag[mma_n_iter]); + operand_frag_ptr[mma_n_iter] = output_converter(convert_frag); + } + } + + private: + ElementScale const* pointer_; +}; +#endif +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/interleaved_numeric_conversion.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/interleaved_numeric_conversion.h new file mode 100644 index 0000000000000000000000000000000000000000..a8cc829c588b6de8448bb5e293052794c657c075 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/interleaved_numeric_conversion.h @@ -0,0 +1,438 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/*! + \file + \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t + interleaved in a register +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/numeric_types.h" + +namespace cutlass { + +// This converter is meant to be used with data interleaved in a 32-bit register +// where the even elements are in the low bits and the odd elemeents are in the +// high bits of the register. In addition, it assumes elements were originally +// signed and had a bias of 2**(b-1) added (where b is the number of bits in the +// type) to make all numbers unsigned. This converter will uninterleave the data +// and subtract the bias while converting to the result type. +template +struct FastInterleavedAndBiasedNumericArrayConverter {}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[1]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + // Lastly, we subtract 1152 from our constructed number using fp16 math to + // get our signed integer as fp16. 
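+    // As a concrete example of the trick: a source byte u (the original
+    // signed value plus the 128 bias) is permuted into the low byte beneath a
+    // 0x64 high byte, encoding the fp16 value 1024 + u; subtracting 0x6480
+    // (= 1152 = 1024 + 128) leaves u - 128, i.e. the original signed value.
+    // For instance, u = 123 (originally -5) becomes 0x647B = 1147.0, and
+    // 1147 - 1152 = -5.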
+ static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* bf16_result_ptr = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + // Subtract out fp32_base + 128 to make the unsigned integer signed. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; + } + + // Truncate the fp32 representation and pack up as bfloat16s. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], + fp32_intermediates_casted[2 * ii + 1], + 0x7632); + } +#else + // Disable this on architectures older than Ampere since they lack hardware + // for bf16 mma. If one wishes to use HMMA on older hardware, they should + // Convert directly to FP16 using FP16 converters. 
+ result.clear(); // Suppress compiler warning + arch::device_breakpoint(); +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is + // thanks to the register packing format and the fact that we force our + // integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput + // in order to convert elt_23 and elt_67 to fp16 without having to shift + // them to the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide + // RAW dependency if we issue immediately before required. + const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), + "n"(BOTTOM_MASK), + "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit + // float2half instructions if I use the half2 ctor. In this case, I chose + // performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. 
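+    // Sketch of the arithmetic: a stored low nibble u = x + 8 (x is the
+    // original signed int4) OR'd into 0x6400 encodes the fp16 value 1024 + u,
+    // so subtracting 1032 recovers x. A high nibble instead lands on mantissa
+    // bits 4..7, encoding 1024 + 16 * u; an fma by 1/16 with addend -72
+    // yields 64 + u - 72 = x.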
+ static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + static constexpr uint32_t NEG_72 = 0xd480d480; + + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[2]) + : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[3]) + : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter { + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, + // so we must loop. No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) { + i4s >>= sizeof_bits::value; + // (i4s & 0x000f000f) | 0x43004300 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. 
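+    // Sketch of the arithmetic: each stored nibble u = x + 8 OR'd into 0x4300
+    // encodes the bf16 value 128 + u, so the fma below with multiplier 1.0
+    // and addend -136 returns 128 + u - 136 = x, the original signed value.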
+ CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias + // subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[ii]) + : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } +#else + // Disable this on architectures older than Ampere since they lack hardware + // for bf16 mma. If one wishes to use HMMA on older hardware, they should + // Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter { + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) { return convert(s); } +}; + +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h new file mode 100644 index 0000000000000000000000000000000000000000..ce306836ae38f6737ada925d695fa48f2b8b2099 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_extensions/tile_interleaved_layout.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/pitch_linear_coord.h" + +namespace cutlass { +namespace layout { + +template +class ColumnMajorTileInterleave { + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; +}; + +template +struct IsColumnMajorTileInterleave { + static constexpr bool value = false; +}; + +template +struct IsColumnMajorTileInterleave> { + static constexpr bool value = true; +}; + +} // namespace layout +} // namespace cutlass diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h new file mode 100644 index 0000000000000000000000000000000000000000..a7095f5164225afb82f5dd8b2ff4dffef0f5560a --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h @@ -0,0 +1,244 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h" + +namespace phi { + +struct TileShape { + int m; + int n; +}; + +static TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) { + switch (tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + return TileShape{32, 128}; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + return TileShape{64, 128}; + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + return TileShape{128, 128}; + // TODO(wangbojun) check the Tile Shape here, + // {256, 128} have better performance than 128, 128 + case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + return TileShape{256, 128}; + // TODO(wangbojun) CtaShape256x128x64_WarpShape64x64x64 is not a + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + return TileShape{256, 128}; + default: + throw std::runtime_error( + "[FT Error][get_grid_shape_for_config] Invalid config"); + } +} + +static bool is_valid_split_k_factor(const int64_t m, + const int64_t n, + const int64_t k, + const TileShape tile_shape, + const int split_k_factor, + const size_t workspace_bytes, + const bool is_weight_only) { + // All tile sizes have a k_tile of 64. 
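+  // Illustrative example: for a weight-only GEMM with k = 4096,
+  // split_k_factor = 3 is rejected (4096 % 3 != 0), while split_k_factor = 4
+  // leaves 1024 K elements per split, still a multiple of the 64-wide k tile;
+  // any split_k_factor > 1 additionally needs
+  // sizeof(int) * ctas_in_m_dim * ctas_in_n_dim bytes of workspace.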
+ static constexpr int k_tile = 64; + + // For weight-only quant, we need k and k_elements_per_split to be a multiple + // of cta_k + if (is_weight_only) { + if ((k % k_tile) != 0) { + return false; + } + + if ((k % split_k_factor) != 0) { + return false; + } + + const int k_elements_per_split = k / split_k_factor; + if ((k_elements_per_split % k_tile) != 0) { + return false; + } + } + + // Check that the workspace has sufficient space for this split-k factor + const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + const int required_ws_bytes = + split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + + if (required_ws_bytes > workspace_bytes) { + return false; + } + + return true; +} + +static std::vector get_candidate_tiles( + const bool is_weight_only, + const bool is_weight_only_encoder, + const bool simt_configs_only) { + std::vector simt_configs{ + CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; + + std::vector square_configs{ + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64, + }; + + std::vector quant_B_configs{ + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, + }; + std::vector encoder_quant_B_configs{ + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64 + // CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64 + }; + const std::vector allowed_quant_B_configs = + is_weight_only_encoder ? encoder_quant_B_configs : quant_B_configs; + const std::vector allowed_configs = + is_weight_only ? allowed_quant_B_configs : square_configs; + return simt_configs_only ? simt_configs : allowed_configs; +} + +static std::vector get_candidate_configs( + int sm, + const bool is_weight_only, + const bool is_weight_only_encoder, + const bool simt_configs_only) { + std::vector tiles = get_candidate_tiles( + is_weight_only, is_weight_only_encoder, simt_configs_only); + + std::vector candidate_configs; + const int min_stages = 2; + const int max_stages = sm >= 80 ? 4 : 2; + + for (const auto& tile_config : tiles) { + for (int stages = min_stages; stages <= max_stages; ++stages) { + CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages}; + candidate_configs.push_back(config); + } + } + + return candidate_configs; +} + +static CutlassGemmConfig estimate_best_config_from_occupancies( + const std::vector& candidate_configs, + const std::vector& occupancies, + const int64_t m, + const int64_t n, + const int64_t k, + const int64_t num_experts, + const int split_k_limit, + const size_t workspace_bytes, + const int multi_processor_count, + const int is_weight_only) { + if (occupancies.size() != candidate_configs.size()) { + throw std::runtime_error( + "[FT Error][estimate_best_config_from_occupancies] occpancies and " + "candidate configs vectors must have equal length."); + } + + CutlassGemmConfig best_config; + // Score will be [0, 1]. The objective is to minimize this score. + // It represents the fraction of SM resources unused in the last wave. + float config_score = 1.0f; + int config_waves = INT_MAX; + int current_m_tile = 0; + + const int max_split_k = n >= multi_processor_count * 256 ? 
1 : split_k_limit; + for (int ii = 0; ii < candidate_configs.size(); ++ii) { + CutlassGemmConfig candidate_config = candidate_configs[ii]; + TileShape tile_shape = + get_cta_shape_for_config(candidate_config.tile_config); + int occupancy = occupancies[ii]; + if (occupancy == 0) { + continue; + } + + // Keep small tile sizes when possible. + if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && + m < current_m_tile && current_m_tile < tile_shape.m) { + continue; + } + + const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + + for (int split_k_factor = 1; split_k_factor <= max_split_k; + ++split_k_factor) { + if (is_valid_split_k_factor(m, + n, + k, + tile_shape, + split_k_factor, + workspace_bytes, + is_weight_only)) { + const int ctas_per_wave = occupancy * multi_processor_count; + const int ctas_for_problem = + ctas_in_m_dim * ctas_in_n_dim * split_k_factor; + + const int num_waves_total = + (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; + const float num_waves_fractional = + ctas_for_problem / static_cast(ctas_per_wave); + const float current_score = + static_cast(num_waves_total) - num_waves_fractional; + + const float score_slack = 0.1f; + if (current_score < config_score || + ((config_waves > num_waves_total) && + (current_score < config_score + score_slack))) { + config_score = current_score; + config_waves = num_waves_total; + SplitKStyle split_style = split_k_factor > 1 + ? SplitKStyle::SPLIT_K_SERIAL + : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig{candidate_config.tile_config, + split_style, + split_k_factor, + candidate_config.stages}; + current_m_tile = tile_shape.m; + } else if (current_score == config_score && + (best_config.stages < candidate_config.stages || + split_k_factor < best_config.split_k_factor || + current_m_tile < tile_shape.m)) { + // Prefer deeper pipeline or smaller split-k + SplitKStyle split_style = split_k_factor > 1 + ? SplitKStyle::SPLIT_K_SERIAL + : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig{candidate_config.tile_config, + split_style, + split_k_factor, + candidate_config.stages}; + current_m_tile = tile_shape.m; + config_waves = num_waves_total; + } + } + } + } + + if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) { + throw std::runtime_error( + "[FT Error] Heurisitc failed to find a valid config."); + } + + return best_config; +} +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..657973a57dfdaf3e1a92c81c0e9208d4639309e8 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h @@ -0,0 +1,142 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h" +// #include "src/fastertransformer/utils/allocator.h" +#include "cuda_runtime_api.h" // NOLINT + +namespace phi { + +/* + This runner only supports: + T in {half, __nv_bfloat} WeightType in {uint8_t, cutlass::uint4b_t} + + Activations, biases, scales and outputs are all assumed to be row-major. + + However, it is assumed that B is in a special format governed by + cutlass_extensions/gemm/kernel/mixed_gemm_B_layout. In this case, B must be + preprocessed using the cutlass weight only quant preprocessors. The weight + preprocessor will instantiate the layout and preprocess based on the + instantiation, so layout changes should only require modifications to + mix_gemm_B_layout.h. +*/ + +template +class CutlassFpAIntBGemmRunner { + public: + CutlassFpAIntBGemmRunner(); + ~CutlassFpAIntBGemmRunner(); + + void gemm(const T* A, + const WeightType* B, + const float* weight_scales, + T* C, + int m, + int n, + int k, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); + + void gemm_bias_act(const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + std::string activation_type, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); + + // Returns desired workspace size in bytes. + int getWorkspaceSize(const int m, const int n, const int k); + + private: + template + void dispatch_to_arch(const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + CutlassGemmConfig gemm_config, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream, + int* occupancy = nullptr); + + template + void run_gemm(const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); + + private: + static constexpr int split_k_limit = 7; + + int sm_; + int multi_processor_count_; +}; + +// This allocation is present to help with compiling with other structures in +// FT. It will throw an error in all functions because this runner assumes the +// weight type and the activation type are different. We allow empty classes to +// be created, but any calls to gemm or gemm_bias_act will throw an error. 
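+// For the primary template above, a minimal (hypothetical) call sequence
+// looks like the following, assuming device buffers for the fp16 activation,
+// the preprocessed int8 weight and its float scales are already populated:
+//
+//   CutlassFpAIntBGemmRunner<half, uint8_t> runner;
+//   const int ws_bytes = runner.getWorkspaceSize(m, n, k);
+//   // ... allocate `ws_bytes` bytes of device memory as `workspace` ...
+//   runner.gemm(d_activation, d_weight, d_weight_scales, d_output,
+//               m, n, k, workspace, ws_bytes, stream);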
+template +class CutlassFpAIntBGemmRunner { + public: + CutlassFpAIntBGemmRunner() = default; + ~CutlassFpAIntBGemmRunner() = default; + + void gemm(const float* A, + const WeightType* B, + const float* weight_scales, + float* C, + int m, + int n, + int k, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); + + void gemm_bias_act(const float* A, + const WeightType* B, + const float* weight_scales, + const float* biases, + float* C, + int m, + int n, + int k, + std::string activation_type, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream); + + int getWorkspaceSize(const int m, const int n, const int k); +}; +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h new file mode 100644 index 0000000000000000000000000000000000000000..d25374b2ae43d3ac988711de39021be446141366 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h @@ -0,0 +1,814 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma once + +#include "cutlass/gemm/device/gemm_universal_base.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/compute_occupancy.h" + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/epilogue_helpers.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/ft_gemm_configs.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_extensions/gemm/threadblock/default_mma.h" + +#pragma GCC diagnostic pop + +#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/cutlass_heuristic.h" +#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h" +#include "paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h" +namespace phi { + +template +void generic_mixed_gemm_kernelLauncher(const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + CutlassGemmConfig gemm_config, + char* workspace, + size_t workspace_bytes, + cudaStream_t stream, + int* occupancy = nullptr) { + static_assert(cutlass::platform::is_same::value || +#ifdef PADDLE_CUDA_BF16 + cutlass::platform::is_same::value || +#endif + cutlass::platform::is_same::value, + "Specialized for bfloat16, half, float"); + static_assert( + cutlass::platform::is_same::value || + cutlass::platform::is_same::value || + cutlass::platform::is_same::value, + ""); + + // The cutlass type for the input elements. This is needed to convert to + // cutlass::half_t if necessary. 
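+  // For example, when T is the fp16 type (half), ElementType below resolves
+  // to cutlass::half_t; float is used as-is.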
+ using ElementType_ = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::half_t, + T>::type; + using ElementType = ElementType_; + + using CutlassWeightType_ = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::half_t, + WeightType>::type; + using CutlassWeightType = CutlassWeightType_; + + // We need separate config for each architecture since we will target + // different tensorcore instructions. For float, we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel:: + MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + using EpilogueOp = typename Epilogue::Op; + + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< + ElementType, + cutlass::layout::RowMajor, + MixedGemmArchTraits::ElementsPerAccessA, + CutlassWeightType, + typename MixedGemmArchTraits::LayoutB, + MixedGemmArchTraits::ElementsPerAccessB, + ElementType, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + arch, + ThreadblockShape, + WarpShape, + typename MixedGemmArchTraits::InstructionShape, + EpilogueOp, + typename cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + Stages, + true, + typename MixedGemmArchTraits::Operator>::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::GemmFpAIntB< + typename GemmKernel_::Mma, + typename GemmKernel_::Epilogue, + typename GemmKernel_::ThreadblockSwizzle, + arch, // Ensure top level arch is used for dispatch + GemmKernel_::kSplitKSerial>; + + if (occupancy != nullptr) { + *occupancy = compute_occupancy_for_kernel(); + return; + } + + using Gemm = cutlass::gemm::device::GemmUniversalBase; + + const int ldb = + cutlass::platform::is_same::value + ? n + : k * GemmKernel::kInterleave; + + typename Gemm::Arguments args( + {m, n, k}, + {reinterpret_cast(const_cast(A)), k}, + {reinterpret_cast(const_cast(B)), ldb}, + {reinterpret_cast(const_cast(weight_scales)), 0}, + {reinterpret_cast(const_cast(biases)), 0}, + {reinterpret_cast(C), n}, + gemm_config.split_k_factor, + {ElementAccumulator(1.f), ElementAccumulator(0.f)}); + + // This assertion is enabled because because for the column interleaved + // layout, K MUST be a multiple of threadblockK. The reason for this is that + // the default pitchlinear iterators are used to handle walking over the + // interleaved matrix. The way masking in handled in these do not map to the + // interleaved layout. We need to write our own predicated iterator in order + // to relax this limitation. + if (GemmKernel::kInterleave > 1 && ((k % MixedGemmArchTraits::ThreadblockK) || + ((k / gemm_config.split_k_factor) % + MixedGemmArchTraits::ThreadblockK))) { + throw std::runtime_error( + "Temp assertion: k must be multiple of threadblockK"); + } + + Gemm gemm; + if (gemm.get_workspace_size(args) > workspace_bytes) { + VLOG(1) << "Requested split-k but workspace size insufficient. Falling " + "back to non-split-k implementation."; + // If requested split-k factor will require more workspace bytes, revert to + // standard gemm. + args.batch_count = 1; + } + + auto can_implement = gemm.can_implement(args); + if (can_implement != cutlass::Status::kSuccess) { + std::string err_msg = + "fpA_intB cutlass kernel will fail for params. 
Error: " + + std::string(cutlassGetStatusString(can_implement)); + throw std::runtime_error("[fpA_intB Runner] " + err_msg); + } + + auto init_status = gemm.initialize(args, workspace, stream); + if (init_status != cutlass::Status::kSuccess) { + std::string err_msg = + "Failed to initialize cutlass fpA_intB gemm. Error: " + + std::string(cutlassGetStatusString(init_status)); + throw std::runtime_error("[fpA_intB Runner] " + err_msg); + } + + auto run_status = gemm.run(stream); + if (run_status != cutlass::Status::kSuccess) { + std::string err_msg = "Failed to run cutlass fpA_intB gemm. Error: " + + std::string(cutlassGetStatusString(run_status)); + throw std::runtime_error("[fpA_intB Runner] " + err_msg); + } +} + +template +struct dispatch_stages { + static void dispatch(const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + CutlassGemmConfig gemm_config, + char* workspace, + size_t workspace_bytes, + cudaStream_t stream, + int* occupancy = nullptr) { + std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + + std::to_string(arch::kMinComputeCapability) + + " with stages set to " + std::to_string(Stages); + throw std::runtime_error("[dispatch_stages::dispatch] " + err_msg); + } +}; + +template +struct dispatch_stages { + static void dispatch(const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + CutlassGemmConfig gemm_config, + char* workspace, + size_t workspace_bytes, + cudaStream_t stream, + int* occupancy = nullptr) { + generic_mixed_gemm_kernelLauncher(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + } +}; + +template +struct dispatch_stages 2)>::type> { + static void dispatch(const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + CutlassGemmConfig gemm_config, + char* workspace, + size_t workspace_bytes, + cudaStream_t stream, + int* occupancy = nullptr) { + generic_mixed_gemm_kernelLauncher(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + } +}; + +template +void dispatch_gemm_config(const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + CutlassGemmConfig gemm_config, + char* workspace, + size_t workspace_bytes, + cudaStream_t stream, + int* occupancy = nullptr) { + switch (gemm_config.stages) { + case 2: + using DispatcherStages2 = dispatch_stages; + DispatcherStages2::dispatch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case 3: + using DispatcherStages3 = dispatch_stages; + DispatcherStages3::dispatch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case 4: + using DispatcherStages4 = dispatch_stages; + DispatcherStages4::dispatch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + default: + std::string err_msg = "dispatch_gemm_config does not support stages " + + std::to_string(gemm_config.stages); + throw std::runtime_error("[dispatch_gemm_config] " + err_msg); + break; + } +} + +template +void dispatch_gemm_to_cutlass(const T* A, + const WeightType* B, + const float* weight_scales, 
+ const T* biases, + T* C, + int m, + int n, + int k, + char* workspace, + size_t workspace_bytes, + CutlassGemmConfig gemm_config, + cudaStream_t stream, + int* occupancy = nullptr) { + // Note that SIMT configs are omitted here since they are not supported for + // fpA_intB. We also only instantiate configs here where threadblockShapeM == + // warpShapeM since those usually perform the best for mixed type gemms. + switch (gemm_config.tile_config) { + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<32, 32, 64>>( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 32, 64>>( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<128, 32, 64>>( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + // config for M_16000_N_12288_K_6144 in encoder + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 64, 64>>( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + dispatch_gemm_config, + cutlass::gemm::GemmShape<64, 64, 64>>( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + gemm_config, + workspace, + workspace_bytes, + stream, + occupancy); + break; + case CutlassTileConfig::Undefined: + throw std::runtime_error( + "[fpA_intB][dispatch_gemm_to_cutlass] gemm config " + "undefined."); + break; + case CutlassTileConfig::ChooseWithHeuristic: + throw std::runtime_error( + "[fpA_intB][dispatch_gemm_to_cutlass] gemm config should " + "have already been set by heuristic."); + break; + default: + throw std::runtime_error( + "[fpA_intB][dispatch_gemm_to_cutlass] Config is invalid " + "for mixed type GEMM."); + break; + } +} + +template +CutlassFpAIntBGemmRunner::CutlassFpAIntBGemmRunner() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + sm_ = getSMVersion(); + check_cuda_error(cudaDeviceGetAttribute( + &multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); +} + +template +CutlassFpAIntBGemmRunner::~CutlassFpAIntBGemmRunner() {} + +template +template +void CutlassFpAIntBGemmRunner::dispatch_to_arch( + const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + CutlassGemmConfig gemm_config, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream, + int* occupancy) { + // if (sm_ >= 70 && sm_ < 75) { + // dispatch_gemm_to_cutlass( + // A, B, weight_scales, biases, C, m, n, k, workspace_ptr, + // workspace_bytes, gemm_config, stream, occupancy); + // } + // else if (sm_ >= 75 && sm_ < 80) { + // dispatch_gemm_to_cutlass( + // A, B, weight_scales, biases, C, m, n, k, workspace_ptr, + // workspace_bytes, gemm_config, stream, occupancy); + // } + // else + if (sm_ >= 80 && sm_ < 90) { + dispatch_gemm_to_cutlass( + A, + B, + weight_scales, + biases, + C, + m, + n, + k, + workspace_ptr, + workspace_bytes, + gemm_config, + stream, + occupancy); 
+ } else { + throw std::runtime_error( + "[CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported " + "for CUTLASS mixed type GEMM"); + } +} + +template +template +void CutlassFpAIntBGemmRunner::run_gemm( + const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) { + static constexpr bool is_weight_only = !std::is_same::value; + const bool is_weight_only_encoder = m >= 512 ? true : false; + std::vector candidate_configs = + get_candidate_configs(sm_, is_weight_only, is_weight_only_encoder, false); + std::vector occupancies(candidate_configs.size()); + + for (size_t ii = 0; ii < candidate_configs.size(); ++ii) { + dispatch_to_arch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + candidate_configs[ii], + workspace_ptr, + workspace_bytes, + stream, + &occupancies[ii]); + } + // Standard GEMM, so 1 "expert". We use the same function for MoE and regular + // FFN. + static constexpr int num_experts = 1; + CutlassGemmConfig chosen_config = + estimate_best_config_from_occupancies(candidate_configs, + occupancies, + m, + n, + k, + num_experts, + split_k_limit, + workspace_bytes, + multi_processor_count_, + is_weight_only); + + dispatch_to_arch(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + chosen_config, + workspace_ptr, + workspace_bytes, + stream); +} + +template +void CutlassFpAIntBGemmRunner::gemm_bias_act( + const T* A, + const WeightType* B, + const float* weight_scales, + const T* biases, + T* C, + int m, + int n, + int k, + std::string activation_type, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) { + if (activation_type == "gelu") { + run_gemm(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + workspace_ptr, + workspace_bytes, + stream); + } else if (activation_type == "relu") { + run_gemm(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + workspace_ptr, + workspace_bytes, + stream); + } else if (activation_type == "none") { + run_gemm(A, + B, + weight_scales, + biases, + C, + m, + n, + k, + workspace_ptr, + workspace_bytes, + stream); + } else { + throw std::runtime_error(("Invalid activation type.")); + } +} + +template +void CutlassFpAIntBGemmRunner::gemm(const T* A, + const WeightType* B, + const float* weight_scales, + T* C, + int m, + int n, + int k, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) { + run_gemm(A, + B, + weight_scales, + nullptr, + C, + m, + n, + k, + workspace_ptr, + workspace_bytes, + stream); +} + +template +int CutlassFpAIntBGemmRunner::getWorkspaceSize(const int m, + const int n, + const int k) { + // sizes for each config, which would launch the maximum number of blocks + const int max_grid_m = (m + 31) / 32; + const int max_grid_n = (n + 127) / 128; + // We need 4 bytes per block in the worst case. We launch split_k_limit in z + // dim. 
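// Hedged worked example with a hypothetical problem size: for m = 16,
// n = 12288, k = 6144 this gives max_grid_m = (16 + 31) / 32 = 1 and
// max_grid_n = (12288 + 127) / 128 = 96, so the return below yields
// 1 * 96 * split_k_limit * 4 = 1 * 96 * 7 * 4 = 2688 bytes of workspace.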
+ return max_grid_m * max_grid_n * split_k_limit * 4; +} + +// =============================== Specialization T == WeightType +// ======================================= +template +void CutlassFpAIntBGemmRunner::gemm_bias_act( + const float* A, + const WeightType* B, + const float* weight_scales, + const float* biases, + float* C, + int m, + int n, + int k, + std::string activation_type, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) { + throw std::runtime_error( + ("Attempting to run mixed gemm bias act when the types are the same is " + "an error.")); +} + +template +void CutlassFpAIntBGemmRunner::gemm( + const float* A, + const WeightType* B, + const float* weight_scales, + float* C, + int m, + int n, + int k, + char* workspace_ptr, + const size_t workspace_bytes, + cudaStream_t stream) { + throw std::runtime_error(( + "Attempting to run mixed gemm when the types are the same is an error.")); +} + +template +int CutlassFpAIntBGemmRunner::getWorkspaceSize(const int m, + const int n, + const int k) { + return 0; +} + +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h b/paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..6d2fa3ed337437276b3d5e627671149481059110 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/utils/cuda_utils.h @@ -0,0 +1,492 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#ifdef SPARSITY_ENABLED +#include +#endif + +namespace phi { + +#define MAX_CONFIG_NUM 20 +#define COL32_ 32 +// workspace for cublas gemm : 32MB +#define CUBLAS_WORKSPACE_SIZE 33554432 + +typedef struct __align__(4) { + half x, y, z, w; +} +half4; + +/* **************************** type definition ***************************** */ + +enum CublasDataType { + FLOAT_DATATYPE = 0, + HALF_DATATYPE = 1, + BFLOAT16_DATATYPE = 2, + INT8_DATATYPE = 3, + FP8_DATATYPE = 4 +}; + +// enum FtCudaDataType { FP32 = 0, FP16 = 1, BF16 = 2, INT8 = 3, FP8 = 4 }; + +// enum class OperationType { FP32, FP16, BF16, INT8, FP8 }; + +/* **************************** debug tools ********************************* */ +static const char* _cudaGetErrorEnum(cudaError_t error) { + return cudaGetErrorString(error); +} + +static const char* _cudaGetErrorEnum(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return ""; +} + +template +void check(T result, + char const* const func, + const char* const file, + int const line) { + if (result) { + throw std::runtime_error(std::string("[ERROR] CUDA runtime error: ") + + (_cudaGetErrorEnum(result)) + " " + file + ":" + + std::to_string(line) + " \n"); + } +} + +#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) +#define check_cuda_error_2(val, file, line) check((val), #val, file, line) + +inline void syncAndCheck(const char* const file, int const line) { + // When FT_DEBUG_LEVEL=DEBUG, must check error + static char* level_name = std::getenv("FT_DEBUG_LEVEL"); + if (level_name != nullptr) { + static std::string level = std::string(level_name); + if (level == "DEBUG") { + cudaDeviceSynchronize(); + cudaError_t result = cudaGetLastError(); + if (result) { + throw std::runtime_error(std::string("[ERROR] CUDA runtime error: ") + + (_cudaGetErrorEnum(result)) + " " + file + + ":" + std::to_string(line) + " \n"); + } + VLOG(2) << "run syncAndCheck at " << file << ":" << line; + } + } + +#ifndef NDEBUG + cudaDeviceSynchronize(); + cudaError_t result = cudaGetLastError(); + if (result) { + throw std::runtime_error(std::string("[ERROR] CUDA runtime error: ") + + (_cudaGetErrorEnum(result)) + " " + file + ":" + + std::to_string(line) + " \n"); + } +#endif +} + +#define sync_check_cuda_error() syncAndCheck(__FILE__, __LINE__) + +#define checkCUDNN(expression) \ + { \ + cudnnStatus_t status = (expression); \ + if (status != CUDNN_STATUS_SUCCESS) { \ + std::cerr << "Error on file " << __FILE__ << " line " << __LINE__ \ + << ": " << cudnnGetErrorString(status); \ + std::exit(EXIT_FAILURE); \ + } \ + } + +template +void print_to_file(const T* result, + const int size, + const char* file, + cudaStream_t stream = 0, + 
std::ios::openmode open_mode = std::ios::out); + +template +void print_abs_mean(const T* buf, + uint size, + cudaStream_t stream, + std::string name = ""); + +template +void print_to_screen(const T* result, const int size); + +template +void check_max_val(const T* result, const int size); + +template +void check_abs_mean_val(const T* result, const int size); + +#define PRINT_FUNC_NAME_() \ + do { \ + VLOG(2) << "[CALL] " << __FUNCTION__ << " "; \ + } while (0) + +[[noreturn]] inline void throwRuntimeError(const char* const file, + int const line, + std::string const& info = "") { + throw std::runtime_error(std::string("[ERROR] ") + info + + " Assertion fail: " + file + ":" + + std::to_string(line) + " \n"); +} + +inline void myAssert(bool result, + const char* const file, + int const line, + std::string const& info = "") { + if (!result) { + throwRuntimeError(file, line, info); + } +} + +#define FT_CHECK(val) myAssert(val, __FILE__, __LINE__) +#define FT_CHECK_WITH_INFO(val, info) \ + do { \ + bool is_valid_val = (val); \ + if (!is_valid_val) { \ + paddle::operators::myAssert(is_valid_val, __FILE__, __LINE__, (info)); \ + } \ + } while (0) + +#define FT_THROW(info) throwRuntimeError(__FILE__, __LINE__, info) + +#ifdef SPARSITY_ENABLED +#define CHECK_CUSPARSE(func) \ + { \ + cusparseStatus_t status = (func); \ + if (status != CUSPARSE_STATUS_SUCCESS) { \ + throw std::runtime_error( \ + std::string("[ERROR] CUSPARSE API failed at line ") + \ + std::to_string(__LINE__) + " in file " + __FILE__ + ": " + \ + cusparseGetErrorString(status) + " " + std::to_string(status)); \ + } \ + } +#endif + +/*************Time Handling**************/ +class CudaTimer { + private: + cudaEvent_t event_start_; + cudaEvent_t event_stop_; + cudaStream_t stream_; + + public: + explicit CudaTimer(cudaStream_t stream = 0) { stream_ = stream; } + void start() { + check_cuda_error(cudaEventCreate(&event_start_)); + check_cuda_error(cudaEventCreate(&event_stop_)); + check_cuda_error(cudaEventRecord(event_start_, stream_)); + } + float stop() { + float time; + check_cuda_error(cudaEventRecord(event_stop_, stream_)); + check_cuda_error(cudaEventSynchronize(event_stop_)); + check_cuda_error(cudaEventElapsedTime(&time, event_start_, event_stop_)); + check_cuda_error(cudaEventDestroy(event_start_)); + check_cuda_error(cudaEventDestroy(event_stop_)); + return time; + } + ~CudaTimer() {} +}; + +static double diffTime(timeval start, timeval end) { + return (end.tv_sec - start.tv_sec) * 1000 + + (end.tv_usec - start.tv_usec) * 0.001; +} + +/* ***************************** common utils ****************************** */ + +inline void print_mem_usage(std::string time = "after allocation") { + size_t free_bytes, total_bytes; + check_cuda_error(cudaMemGetInfo(&free_bytes, &total_bytes)); + float free = static_cast(free_bytes) / 1024.0 / 1024.0 / 1024.0; + float total = static_cast(total_bytes) / 1024.0 / 1024.0 / 1024.0; + float used = total - free; + printf("%-20s: free: %5.2f GB, total: %5.2f GB, used: %5.2f GB\n", + time.c_str(), + free, + total, + used); +} + +inline int getSMVersion() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + check_cuda_error(cudaDeviceGetAttribute( + &sm_major, cudaDevAttrComputeCapabilityMajor, device)); + check_cuda_error(cudaDeviceGetAttribute( + &sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +inline int getMaxSharedMemoryPerBlock() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); 
+ int max_shared_memory_size = 0; + check_cuda_error(cudaDeviceGetAttribute( + &max_shared_memory_size, cudaDevAttrMaxSharedMemoryPerBlock, device)); + return max_shared_memory_size; +} + +inline std::string getDeviceName() { + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + cudaDeviceProp props; + check_cuda_error(cudaGetDeviceProperties(&props, device)); + return std::string(props.name); +} + +inline int div_up(int a, int n) { return (a + n - 1) / n; } + +cudaError_t getSetDevice(int i_device, int* o_device = NULL); + +inline int getDevice() { + int current_dev_id = 0; + check_cuda_error(cudaGetDevice(¤t_dev_id)); + return current_dev_id; +} + +inline int getDeviceCount() { + int count = 0; + check_cuda_error(cudaGetDeviceCount(&count)); + return count; +} + +template +CublasDataType getCublasDataType() { + if (std::is_same::value) { + return HALF_DATATYPE; + } else if (std::is_same::value) { + return FLOAT_DATATYPE; + } else { + FT_CHECK(false); + return FLOAT_DATATYPE; + } +} + +template +cudaDataType_t getCudaDataType() { + if (std::is_same::value) { + return CUDA_R_16F; + } else if (std::is_same::value) { + return CUDA_R_32F; + } else { + FT_CHECK(false); + return CUDA_R_32F; + } +} + +template +struct getTypeFromCudaDataType { + using Type = float; +}; + +template <> +struct getTypeFromCudaDataType { + using Type = half; +}; + +template +struct packed_type; +template <> +struct packed_type { + using type = float; +}; +template <> +struct packed_type { + using type = half2; +}; + +template +struct num_elems; +template <> +struct num_elems { + static constexpr int value = 1; +}; +template <> +struct num_elems { + static constexpr int value = 2; +}; +template <> +struct num_elems { + static constexpr int value = 4; +}; +template <> +struct num_elems { + static constexpr int value = 1; +}; +template <> +struct num_elems { + static constexpr int value = 2; +}; + +template +struct packed_as; +template +struct packed_as { + using type = T; +}; +template <> +struct packed_as { + using type = half2; +}; +template <> +struct packed_as { + using type = float2; +}; +template <> +struct packed_as { + using type = int16_t; +}; +template <> +struct packed_as { + using type = int2; +}; +template <> +struct packed_as { + using type = half; +}; + +inline __device__ float2 operator*(float2 a, float2 b) { + return make_float2(a.x * b.x, a.y * b.y); +} +inline __device__ float2 operator*(float2 a, float b) { + return make_float2(a.x * b, a.y * b); +} + +template +void compareTwoTensor(const T1* pred, + const T2* ref, + const int size, + const int print_size = 0, + const std::string filename = "") { + T1* h_pred = new T1[size]; + T2* h_ref = new T2[size]; + check_cuda_error( + cudaMemcpy(h_pred, pred, size * sizeof(T1), cudaMemcpyDeviceToHost)); + check_cuda_error( + cudaMemcpy(h_ref, ref, size * sizeof(T2), cudaMemcpyDeviceToHost)); + + FILE* fd = nullptr; + if (filename != "") { + fd = fopen(filename.c_str(), "w"); + fprintf(fd, + "| %10s | %10s | %10s | %10s | \n", + "pred", + "ref", + "abs_diff", + "rel_diff(%)"); + } + + if (print_size > 0) { + VLOG(2) << " id | pred | ref |abs diff | rel diff (%) |"; + } + float mean_abs_diff = 0.0f; + float mean_rel_diff = 0.0f; + int count = 0; + for (int i = 0; i < size; i++) { + if (i < print_size) { + VLOG(2) << i << " | " << static_cast(h_pred[i]) << " | " + << static_cast(h_ref[i]) << " | " + << (abs(static_cast(h_pred[i]) - + static_cast(h_ref[i]))) + << " | " + << (abs(static_cast(h_pred[i]) - + static_cast(h_ref[i])) / + 
(abs(static_cast(h_ref[i])) + 1e-6f) * 100.f) + << " | "; + } + if (static_cast(h_pred[i]) == 0) { + continue; + } + count += 1; + mean_abs_diff += + abs(static_cast(h_pred[i]) - static_cast(h_ref[i])); + mean_rel_diff += + abs(static_cast(h_pred[i]) - static_cast(h_ref[i])) / + (abs(static_cast(h_ref[i])) + 1e-6f) * 100.f; + + if (fd != nullptr) { + fprintf( + fd, + "| %10.5f | %10.5f | %10.5f | %11.5f |\n", + static_cast(h_pred[i]), + static_cast(h_ref[i]), + abs(static_cast(h_pred[i]) - static_cast(h_ref[i])), + abs(static_cast(h_pred[i]) - static_cast(h_ref[i])) / + (abs(static_cast(h_ref[i])) + 1e-6f) * 100.f); + } + } + mean_abs_diff = mean_abs_diff / static_cast(count); + mean_rel_diff = mean_rel_diff / static_cast(count); + VLOG(2) << "mean_abs_diff: " << mean_abs_diff + << ", mean_rel_diff: " << mean_rel_diff; + + if (fd != nullptr) { + fprintf(fd, + "mean_abs_diff: % 6.4f, mean_rel_diff: % 6.4f (%%)", + mean_abs_diff, + mean_rel_diff); + fclose(fd); + } + delete[] h_pred; + delete[] h_ref; +} + +/* ************************** end of common utils ************************** */ +} // namespace phi diff --git a/paddle/phi/kernels/gpu/llm_int8_mat_mul_kernel.cu b/paddle/phi/kernels/gpu/llm_int8_mat_mul_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..47aca6c4cafef0bb1ceceb0c3210f7372a012618 --- /dev/null +++ b/paddle/phi/kernels/gpu/llm_int8_mat_mul_kernel.cu @@ -0,0 +1,75 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/llm_int8_mat_mul_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#ifndef PADDLE_WITH_HIP +#include "paddle/phi/kernels/impl/llm_int8_mat_mul_kernel_impl.h" +#endif + +namespace phi { + +template +void llm_int8_compute(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& weight, + const DenseTensor& weight_scale, + const float threshold, + DenseTensor* out) { +#if defined(PADDLE_WITH_HIP) + LOG(ERROR) << "Please compile with cublaslt, ROCM platform isn't support it"; +#else + DenseTensor cublaslt_workspace; + cublaslt_workspace.Resize({{3000000}}); + dev_ctx.template Alloc(&cublaslt_workspace); + const auto x_dims = x.dims(); + const auto w_dims = weight.dims(); + int k = w_dims[1]; + int n = w_dims[0]; + int m = x.numel() / k; + // mk * transpose(nk) = mn + llm_int8::LLMGemm(dev_ctx, + &weight, + &x, + &weight_scale, + threshold, + out, + &cublaslt_workspace, + "llm_int8_mat_mul", + m, + k, + n); +#endif +} + +template +void LLMInt8MatMulKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& weight, + const DenseTensor& weight_scale, + const float threshold, + DenseTensor* out) { + dev_ctx.template Alloc(out); + llm_int8_compute( + dev_ctx, x, weight, weight_scale, threshold, out); +} +} // namespace phi + +PD_REGISTER_KERNEL(llm_int8_mat_mul, + GPU, + ALL_LAYOUT, + phi::LLMInt8MatMulKernel, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/weight_only_mat_mul_kernel.cu b/paddle/phi/kernels/gpu/weight_only_mat_mul_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..d8216a2fe68aa3930efa6f40973e01e050ad565a --- /dev/null +++ b/paddle/phi/kernels/gpu/weight_only_mat_mul_kernel.cu @@ -0,0 +1,73 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/weight_only_mat_mul_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/datatype_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#if defined(PADDLE_WITH_CUTLASS) +#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" +#endif + +namespace phi { + +template +void WeightOnlyMatMulKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& weight, + const DenseTensor& weight_scale, + DenseTensor* out) { +#if defined(PADDLE_WITH_CUTLASS) + dev_ctx.template Alloc(out); + const auto x_dims = x.dims(); + const auto w_dims = weight.dims(); + + int k = w_dims[0]; + int n = w_dims[1]; + int m = x.numel() / k; + auto mixed_gemm_runner = + CutlassFpAIntBGemmRunner::DataType, + uint8_t>(); + int mixgemm_max_size = std::max(n, k); + DenseTensor mixgemm_workspace; + int64_t mixgemm_workspace_size_bytes = + mixed_gemm_runner.getWorkspaceSize(m, mixgemm_max_size, mixgemm_max_size); + + mixgemm_workspace.Resize({mixgemm_workspace_size_bytes}); + dev_ctx.template Alloc(&mixgemm_workspace); + char* mixgemm_workspace_data = + reinterpret_cast(mixgemm_workspace.data()); + mixed_gemm_runner.gemm( + reinterpret_cast::DataType*>( + x.data()), + reinterpret_cast(weight.data()), + reinterpret_cast(weight_scale.data()), + reinterpret_cast::DataType*>(out->data()), + m, + n, + k, + mixgemm_workspace_data, + mixgemm_workspace_size_bytes, + dev_ctx.stream()); +#else + LOG(ERROR) << "Please compile with cutlass to EnableUseCutlass()"; +#endif +} +} // namespace phi + +PD_REGISTER_KERNEL(weight_only_mat_mul, + GPU, + ALL_LAYOUT, + phi::WeightOnlyMatMulKernel, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/impl/llm_int8_mat_mul_kernel_impl.h b/paddle/phi/kernels/impl/llm_int8_mat_mul_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..6847289fa825c502815e463cb8a68be62b559454 --- /dev/null +++ b/paddle/phi/kernels/impl/llm_int8_mat_mul_kernel_impl.h @@ -0,0 +1,714 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include "paddle/phi/common/datatype_traits.h" +#include "paddle/phi/kernels/funcs/cublaslt.h" +#include "paddle/phi/kernels/funcs/quant_dequant.h" + +#pragma once + +namespace phi { + +namespace llm_int8 { +constexpr int32_t WARP_SIZE = 32; +constexpr int32_t HALF_WARP = 16; +constexpr float QUANT_MAX_BOUND = 127.0; +constexpr float QUANT_MIN_BOUND = -127.0; +constexpr int32_t kBlockSize = 256; +constexpr int32_t kNumWaves = 16; + +inline cudaError_t GetGridSize(int64_t n, int* num_blocks) { + int dev; + { + cudaError_t err = cudaGetDevice(&dev); + if (err != cudaSuccess) { + return err; + } + } + int sm_count; + { + cudaError_t err = + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); + if (err != cudaSuccess) { + return err; + } + } + int tpm; + { + cudaError_t err = cudaDeviceGetAttribute( + &tpm, cudaDevAttrMaxThreadsPerMultiProcessor, dev); + if (err != cudaSuccess) { + return err; + } + } + *num_blocks = + std::max(1, + std::min((n + kBlockSize - 1) / kBlockSize, + sm_count * tpm / kBlockSize * kNumWaves)); + return cudaSuccess; +} + +template +inline cudaError_t GetMaxOccupancyBlocks(Func func, + int64_t block_size, + size_t dynamic_smem_size, + int64_t max_blocks, + int* num_blocks) { + int dev; + { + cudaError_t err = cudaGetDevice(&dev); + if (err != cudaSuccess) { + return err; + } + } + int sm_count; + { + cudaError_t err = + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev); + if (err != cudaSuccess) { + return err; + } + } + int max_active_blocks; + { + cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, func, block_size, dynamic_smem_size); + } + *num_blocks = std::max( + 1, + std::min(max_blocks, sm_count * max_active_blocks * kNumWaves)); + return cudaSuccess; +} + +template +struct MaxFunc { + __device__ T operator()(T a, T b) { return max(a, b); } +}; + +template <> +struct MaxFunc { + __device__ half operator()(half a, half b) { +#if __CUDA_ARCH__ >= 800 + return __hmax(a, b); +#else + return max(static_cast(a), static_cast(b)); +#endif + } +}; + +#ifdef PADDLE_CUDA_BF16 +template <> +struct MaxFunc<__nv_bfloat16> { + __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b) { +#if __CUDA_ARCH__ >= 800 + return __hmax(a, b); +#else + return max(static_cast(a), static_cast(b)); +#endif + } +}; +#endif + +template +struct AbsFunc { + __device__ T operator()(T x) { return abs(x); } +}; + +template <> +struct AbsFunc { + __device__ half operator()(half x) { +#if __CUDA_ARCH__ >= 800 + return __habs(x); +#else + return abs(static_cast(x)); +#endif + } +}; + +#ifdef PADDLE_CUDA_BF16 +template <> +struct AbsFunc<__nv_bfloat16> { + __device__ __nv_bfloat16 operator()(__nv_bfloat16 x) { +#if __CUDA_ARCH__ >= 800 + return __habs(x); +#else + return abs(static_cast(x)); +#endif + } +}; +#endif +template +struct QuantFunc { + HOSTDEVICE int8_t operator()(T x, float inverse_range) { + float tmp = static_cast(x) * QUANT_MAX_BOUND * inverse_range; + tmp = round(tmp); + if (tmp > QUANT_MAX_BOUND) + tmp = QUANT_MAX_BOUND; + else if (tmp < QUANT_MIN_BOUND) + tmp = QUANT_MIN_BOUND; + return static_cast(tmp); + } +}; + +template +struct DequantFunc { + HOSTDEVICE T operator()(int8_t x, T scale) { + return static_cast(static_cast(x) * static_cast(scale)); + } + HOSTDEVICE T operator()(int32_t x, T input_range, T weight_scale) { + return static_cast(static_cast(x) * + static_cast(input_range) * + static_cast(weight_scale) / (127.0f)); + } + HOSTDEVICE T operator()(int8_t x, 
float scale) { + return static_cast(static_cast(x) * static_cast(scale)); + } + HOSTDEVICE T operator()(int32_t x, float input_range, float weight_scale) { + return static_cast(static_cast(x) * + static_cast(input_range) * + static_cast(weight_scale) / (127.0f)); + } +}; + +template +__inline__ __device__ T LocalReduceMax(Vec& vec) { // NOLINT + T local_max = static_cast(0.0); +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + local_max = vec[i] > local_max ? vec[i] : local_max; + } + return local_max; +} + +template +__inline__ __device__ T WarpReduceAbsMax(T val, unsigned lane_mask) { +#pragma unroll + for (int mask = HALF_WARP; mask > 0; mask >>= 1) { + val = MaxFunc()(val, __shfl_xor_sync(lane_mask, val, mask, WARP_SIZE)); + } + return val; +} + +template +__inline__ __device__ T BlockReduceAbsMax(T val, unsigned mask) { + static __shared__ T smem[WARP_SIZE]; + int32_t lane_id = threadIdx.x & 0x1f; + int32_t warp_id = threadIdx.x >> 5; + val = WarpReduceAbsMax(val, mask); + if (lane_id == 0) { + smem[warp_id] = val; + } + __syncthreads(); + T abs_max_val = (threadIdx.x < blockDim.x / WARP_SIZE) ? smem[threadIdx.x] + : static_cast(0.0f); + abs_max_val = WarpReduceAbsMax(abs_max_val, mask); + return abs_max_val; +} + +template +__global__ void ReduceAbsMaxKernel(const T* x, + const float threshold, + const int32_t rows, + const int32_t cols, + float* row_ranges, + int32_t* outlier_idx) { + using InVec = phi::AlignedVector; + using ComputeVec = phi::AlignedVector; + + InVec in_vec; + ComputeVec abs_max_vec; +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + abs_max_vec[i] = 0.0f; + } + + ComputeType local_max_val = static_cast(0.0f); + for (int row_idx = blockIdx.x; row_idx < rows; row_idx += gridDim.x) { + for (int col_idx = threadIdx.x * VecSize; col_idx < cols; + col_idx += blockDim.x * VecSize) { + int32_t linear_index = row_idx * cols + col_idx; + phi::Load(x + linear_index, &in_vec); +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + in_vec[i] = AbsFunc()(in_vec[i]); + if (in_vec[i] > static_cast(threshold)) { + int32_t index = col_idx + i; + int32_t int_index = index / 32; + int32_t inner_index = index % 32; + atomicOr(outlier_idx + int_index, (1 << inner_index)); + in_vec[i] = 0.0; + } + abs_max_vec[i] = MaxFunc()( + abs_max_vec[i], static_cast(in_vec[i])); + } + } + local_max_val = + LocalReduceMax(abs_max_vec); + ComputeType tmp_max_val = + BlockReduceAbsMax(local_max_val, 0xffffffff); + if (threadIdx.x == 0) { + row_ranges[row_idx] = tmp_max_val; + } + } +} + +template +__global__ void QuantActKernel(const T* x, + const int32_t elem_cnt, + const int32_t cols, + const float* row_ranges, + const int32_t* outlier_idx, + int8_t* quant_x) { + using InVec = phi::AlignedVector; + using OutVec = phi::AlignedVector; + + InVec in_vec; + OutVec out_vec; + + for (int linear_index = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; + linear_index < elem_cnt; + linear_index += gridDim.x * blockDim.x * VecSize) { + int row_idx = linear_index / cols; + int col_idx = + linear_index - row_idx * cols; // equal to linear_index % cols + phi::Load(x + linear_index, &in_vec); + int32_t local_outlier_idx = outlier_idx[col_idx / 32]; + float scale = 1.0f / row_ranges[row_idx]; +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + int32_t index = linear_index + i; + if (local_outlier_idx & (1 << (index % 32))) { + out_vec[i] = 0; + } else { + out_vec[i] = QuantFunc()(in_vec[i], scale); + } + } + phi::Store(out_vec, quant_x + linear_index); + } +} + +template +__global__ void Fill(T* 
input, T value, int64_t num) { + phi::AlignedVector in_vec; + int stride = blockDim.x * gridDim.x * VecSize; + int base_idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; + + for (int idx = base_idx; idx < num; idx += stride) { +#pragma unroll + for (int j = 0; j < VecSize; ++j) { + in_vec[j] = value; + } + phi::Store(in_vec, input + idx); + } +} + +template +__global__ void SplitKernel(const T* x, + const int8_t* weight, + const float* weight_scale, + const int32_t* outlier_idx, + T* sub_x, + T* sub_weight, + int m, + int k, + int n, + int num_outlier_idx, + int kfp_num, + int sub_x_elem_cnt, + int sub_w_elem_cnt, + int elem_cnt) { + extern __shared__ int32_t k_ids_shm[]; + int32_t cnt = 0; + + if (threadIdx.x == 0) { +#pragma unroll + for (int i = 0; i < kfp_num; ++i) { + k_ids_shm[i] = -1; + } + + for (int i = 0; i < num_outlier_idx; ++i) { + int32_t outlier_id = outlier_idx[i]; + if (outlier_id == 0) continue; + for (int j = 0; j < 32; ++j) { + if (outlier_id & (1 << j)) { + k_ids_shm[cnt++] = i * 32 + j; + } + } + } + } + + __syncthreads(); + + for (int linear_idx = blockIdx.x * blockDim.x + threadIdx.x; + linear_idx < elem_cnt; + linear_idx += blockDim.x * gridDim.x) { + int32_t row_idx = linear_idx / kfp_num; // n + int32_t col_idx = linear_idx % kfp_num; // k + int32_t k_id = k_ids_shm[col_idx]; + if (k_id == -1) continue; + if (linear_idx < sub_x_elem_cnt) { + sub_x[row_idx * kfp_num + col_idx] = x[row_idx * k + k_id]; + } + + if (linear_idx < sub_w_elem_cnt) { + constexpr int32_t k_permute_const = 8; + int32_t k_mod_16 = k_id % 16; + int32_t temp_k_expr_1 = k_mod_16 - k_mod_16 / 8 * 8; + int32_t temp_k_expr_2 = k_mod_16 / 8; + int32_t permute_kk = temp_k_expr_1 + temp_k_expr_2 + + (temp_k_expr_2 + 1) % 2 * k_mod_16 * 2 / 2 + + temp_k_expr_1 * temp_k_expr_2 + k_id / 16 * 16; + int32_t permute_index = permute_kk % 64 + permute_kk / 64 * 128 + + 64 * (row_idx % 2) + k * 2 * (row_idx / 2); + int8_t shifted_weight = static_cast( + static_cast(weight[permute_index]) - 128); + sub_weight[row_idx * kfp_num + col_idx] = + DequantFunc()(shifted_weight, weight_scale[row_idx]); + } + } +} + +__global__ static void UpdateOutlier(int32_t* outlier_idx, int32_t* total_num) { + constexpr int IntSize = 32; + + int32_t outlier_val = outlier_idx[threadIdx.x]; +#pragma unroll + for (int i = 0; i < IntSize; i++) { + while (outlier_val) { + outlier_val = outlier_val & (outlier_val - 1); + // ++kfp_num; + atomicAdd(total_num, 1); + } + } +} + +// Input: x:dequantized_fp16:[m, n], x_fp16:T:[m, n], input_range:T:[m], +// weight_scale:T:[n] Outpuy: y:T:[m, n] +template +__global__ void DequantActivationMergeKernel(const T* x, + const T* x_fp, + T* y, + const int32_t elem_cnt) { + using FpVec = phi::AlignedVector; + + FpVec x_fp_vec; + FpVec out_vec; + FpVec x_vec; + + for (int linear_idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; + linear_idx < elem_cnt; + linear_idx += gridDim.x * blockDim.x * VecSize) { + phi::Load(x_fp + linear_idx, &x_fp_vec); + phi::Load(x + linear_idx, &x_vec); + +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + out_vec[i] = x_fp_vec[i] + (x_vec[i] / static_cast(127.0f)); + } + phi::Store(out_vec, y + linear_idx); + } +} + +// Input: x:int32:[m, n], x_fp16:T:[m, n], input_range:T:[m], weight_scale:T:[n] +// Outpuy: y:T:[m, n] + +template +__global__ void DequantMergeKernel(const int32_t* x, + const T* x_fp, + const float* input_range, + const float* weight_scale, + T* y, + int m, + int n) { + using FpVec = phi::AlignedVector; + using IntVec = phi::AlignedVector; 
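// Each output element is recovered as
//   y[row][col] = x_fp[row][col]
//               + x[row][col] * input_range[row] * weight_scale[col] / 127.0f,
// i.e. the int32 GEMM result is dequantized with the per-row activation range
// and the per-column weight scale, then merged with the fp16 partial result.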
+ + FpVec x_fp_vec; + FpVec out_vec; + IntVec x_vec; + + for (int row_idx = blockIdx.x; row_idx < m; row_idx += gridDim.x) { + for (int col_idx = threadIdx.x * VecSize; col_idx < n; + col_idx += blockDim.x * VecSize) { + int linear_idx = row_idx * n + col_idx; + phi::Load(x_fp + linear_idx, &x_fp_vec); + phi::Load(x + linear_idx, &x_vec); +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + T dequant_x_fp = DequantFunc()( + x_vec[i], input_range[row_idx], weight_scale[col_idx + i]); + out_vec[i] = x_fp_vec[i] + dequant_x_fp; + } + phi::Store(out_vec, y + linear_idx); + } + } +} + +template +void LaunchFillKernel(T* input, + T value, + int64_t num, + backends::gpu::GpuLaunchConfig* gpu_config, + gpuStream_t stream) { + constexpr int VecSize = 16 / sizeof(T); + Fill + <<block_per_grid, gpu_config->thread_per_block, 0, stream>>>( + input, value, num); +} + +template +void LaunchReduceAbsMaxQuantKernel(const T* x, + const float threshold, + const int32_t rows, + const int32_t cols, + float* row_ranges, + int32_t* outlier_idx, + int8_t* quant_x, + gpuStream_t stream) { + constexpr int VecSize = 16 / sizeof(T); + + using DataT = typename PDDataTypeTraits::DataType; + using ComputeType = float; + + int32_t reduce_kernel_num_blocks; + PADDLE_ENFORCE_GPU_SUCCESS( + GetMaxOccupancyBlocks(ReduceAbsMaxKernel, + kBlockSize, + 0, + rows, + &reduce_kernel_num_blocks)); + assert((cols % VecSize == 0)); + + ReduceAbsMaxKernel + <<>>( + reinterpret_cast(x), + threshold, + rows, + cols, + row_ranges, + outlier_idx); + + const int32_t elem_cnt = rows * cols; + const int32_t vectorized_elem_cnt = elem_cnt / VecSize; + int32_t quant_kernel_num_blocks; + PADDLE_ENFORCE_GPU_SUCCESS( + GetGridSize(vectorized_elem_cnt, &quant_kernel_num_blocks)); + QuantActKernel + <<>>( + reinterpret_cast(x), + elem_cnt, + cols, + row_ranges, + outlier_idx, + quant_x); +} + +template +void LaunchSplitKernel(const T* x, + const int8_t* weight, + const float* weight_scale, + const int32_t* outlier_idx, + T* sub_x, + T* sub_weight, + int m, + int k, + int n, + int kfp_num, + gpuStream_t stream) { + int max_row = m > n ? 
m : n; + const int elem_cnt = max_row * kfp_num; + int num_blocks = 1; + PADDLE_ENFORCE_GPU_SUCCESS(GetGridSize(elem_cnt, &num_blocks)); + int64_t num_outlier_idx = (k + 31) / 32; + + const int32_t sub_x_elem_cnt = m * kfp_num; + const int32_t sub_w_elem_cnt = n * kfp_num; + + using DataT = typename PDDataTypeTraits::DataType; + SplitKernel + <<>>( + reinterpret_cast(x), + weight, + weight_scale, + outlier_idx, + reinterpret_cast(sub_x), + reinterpret_cast(sub_weight), + m, + k, + n, + num_outlier_idx, + kfp_num, + sub_x_elem_cnt, + sub_w_elem_cnt, + elem_cnt); +} + +template +void LaunchDequantMergeKernel(const int32_t* x, + const T* x_fp, + const float* input_range, + const float* weight_scale, + T* y, + int m, + int n, + gpuStream_t stream) { + constexpr int NumThreads = 256; + constexpr int VecSize = 16 / sizeof(T); + + using DataT = typename PDDataTypeTraits::DataType; + + DequantMergeKernel<<>>( + x, + reinterpret_cast(x_fp), + reinterpret_cast(input_range), + reinterpret_cast(weight_scale), + reinterpret_cast(y), + m, + n); +} + +template +void LLMGemm(const phi::GPUContext& dev_ctx, + const phi::DenseTensor* weight, + const phi::DenseTensor* input, + const phi::DenseTensor* weight_scale, + const float threshold, + phi::DenseTensor* output, + phi::DenseTensor* workspace, + std::string name, + int m, + int k, + int n) { + // absmax, quant, outlier + int64_t num_outlier_idx = (k + 31) / 32; + phi::DenseTensor row_ranges, outlier_idx, quant_input; + row_ranges.Resize({m}); + outlier_idx.Resize({num_outlier_idx}); + quant_input.Resize({m, k}); + dev_ctx.Alloc(&row_ranges); + dev_ctx.Alloc(&outlier_idx); + dev_ctx.Alloc(&quant_input); + + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(outlier_idx.data(), + 0, + num_outlier_idx * sizeof(int32_t), + dev_ctx.stream())); + LaunchReduceAbsMaxQuantKernel(input->data(), + threshold, + m, + k, + row_ranges.data(), + outlier_idx.data(), + quant_input.data(), + dev_ctx.stream()); + int32_t kfp_num = 0; + phi::DenseTensor kfp_num_tensor; + kfp_num_tensor.Resize({1}); + dev_ctx.Alloc(&kfp_num_tensor); + + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + kfp_num_tensor.data(), 0, sizeof(int32_t), dev_ctx.stream())); + UpdateOutlier<<<1, num_outlier_idx, 0, dev_ctx.stream()>>>( + outlier_idx.data(), kfp_num_tensor.data()); + cudaMemcpy(&kfp_num, + kfp_num_tensor.data(), + sizeof(int32_t), + cudaMemcpyDeviceToHost); + + phi::DenseTensor sub_out; + sub_out.Resize({m, n}); + dev_ctx.Alloc(&sub_out); + if (kfp_num != 0) { + phi::DenseTensor sub_input, sub_weight; + sub_input.Resize({m, kfp_num}); + sub_weight.Resize({n, kfp_num}); + + dev_ctx.Alloc(&sub_input); + dev_ctx.Alloc(&sub_weight); + + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(sub_input.data(), + 0, + sub_input.numel() * sizeof(T), + dev_ctx.stream())); + + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(sub_weight.data(), + 0, + sub_weight.numel() * sizeof(T), + dev_ctx.stream())); + + LaunchSplitKernel(input->data(), + weight->data(), + weight_scale->data(), + outlier_idx.data(), + sub_input.data(), + sub_weight.data(), + m, + k, + n, + kfp_num, + dev_ctx.stream()); + + CBLAS_TRANSPOSE transA = CblasNoTrans; + CBLAS_TRANSPOSE transB = CblasTrans; + T alpha = static_cast(1.0); + T beta = static_cast(0.0); + + // (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out) + auto blas = phi::funcs::GetBlas(dev_ctx); + blas.GEMM(transA, + transB, + m, + n, + kfp_num, + alpha, + sub_input.data(), + sub_weight.data(), + beta, + sub_out.data()); + + // PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); 
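// At this point sub_out holds the fp16 contribution of the outlier
// (above-threshold) activation columns: sub_input keeps their original fp16
// values and sub_weight the matching weight entries dequantized back to fp16,
// so this part of the product is computed in full precision.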
+ } else { + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + sub_out.data(), 0, sub_out.numel() * sizeof(T), dev_ctx.stream())); + } + + phi::DenseTensor int_out; + int_out.Resize({m, n}); + dev_ctx.Alloc(&int_out); + + { + auto helper = std::make_unique(m, k, n); + helper->GEMM(quant_input.data(), + weight->data(), + int_out.data(), + dev_ctx.stream(), + (void*)workspace->data()); + } + // PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); + + LaunchDequantMergeKernel(int_out.data(), + sub_out.data(), + row_ranges.data(), + weight_scale->data(), + output->data(), + m, + n, + dev_ctx.stream()); + // PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); +} + +} // namespace llm_int8 +} // namespace phi diff --git a/paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h b/paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..12c5a68a1144fdf9e47afb498375f46a06bd44be --- /dev/null +++ b/paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h @@ -0,0 +1,174 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_ +#define PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_ +#include +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/common_shape.h" + +namespace phi { + +template +inline T xabs(const T x) { + return x < static_cast(0.0) ? -x : x; +} + +template +void per_channel_scale(float* scale, const T* input, size_t m, size_t n) { + for (size_t i = 0; i < n; ++i) { + T max = input[i]; + for (size_t j = 0; j < m; ++j) { + max = xabs(input[j * n + i]) > max ? xabs(input[j * n + i]) : max; + } + scale[i] = static_cast(max) / 127.0; + } +} + +template +void per_channel_quant( + D* output, const T* input, const float* scale, size_t m, size_t n) { + for (size_t i = 0; i < m; i++) { + for (size_t j = 0; j < n; j++) { + output[i * n + j] = static_cast( + round(static_cast(input[i * n + j]) / scale[j])); + } + } +} + +void row_major_to_column_major(int8_t* col_major_tensor, + const int8_t* row_major_tensor, + const std::vector& shape) { + size_t m = shape[0]; + size_t n = shape[1]; + for (size_t i = 0; i < m * n; i++) { + size_t im = i / n; + size_t in = i % n; + col_major_tensor[in * m + im] = row_major_tensor[im * n + in]; + } +} + +void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor_ptr, + size_t num_elts) { + int8_t* int8_tensor = reinterpret_cast(int8_tensor_ptr); + for (size_t ii = 0; ii < num_elts; ++ii) { + int8_tensor[ii] = + static_cast(static_cast(int8_tensor[ii]) + 128); + } + // Step 2 will transform the layout of a 32-bit register in CUDA in order to + // match the int4 layout. This has no performance benefit and is purely so + // that int4 and int8 have the same layout. 
Pictorially, this does the + // following: bit 32 0 + // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) + + for (size_t base = 0; base < num_elts; base += 4) { + std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); + } +} + +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, + const int8_t* quantized_tensor, + const std::vector& shape, + const int64_t arch_version) { + // We only want to run this step for weight only quant. + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + const int BITS_PER_ELT = 8; + const int K = 16 / BITS_PER_ELT; + // const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; + const int ELTS_PER_REG = 32 / BITS_PER_ELT; + + const uint32_t* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(permuted_quantized_tensor); + + // int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; + const int elts_in_int32 = 32 / BITS_PER_ELT; + + const int num_vec_cols = num_cols / elts_in_int32; + + // The code is written as below so it works for both int8 and packed int4. + for (size_t base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) { + for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) { + for (int write_col = 0; write_col < num_vec_cols; ++write_col) { + const int write_row = base_row + tile_row; + const int tile_read_row = 8 * (((tile_row % ELTS_PER_REG) / 2)) + + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); + const int read_row = base_row + tile_read_row; + const int read_col = write_col; + + const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col; + const int64_t write_offset = + int64_t(write_row) * num_vec_cols + write_col; + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } +} + +void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, + const int8_t* quantized_tensor, + const std::vector& shape) { + // We only want to run this step for weight only quant. + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; + + const size_t BITS_PER_ELT = 8; + const size_t elts_in_int32 = 32 / BITS_PER_ELT; + + const size_t rows_per_tile = 64; + + const uint32_t* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = + reinterpret_cast(interleaved_quantized_tensor); + + const size_t num_vec_rows = num_rows / elts_in_int32; + const size_t vec_rows_per_tile = rows_per_tile / elts_in_int32; + const size_t interleave = 2; + for (size_t read_col = 0; read_col < num_cols; ++read_col) { + const size_t write_col = read_col / interleave; + for (size_t base_vec_row = 0; base_vec_row < num_vec_rows; + base_vec_row += vec_rows_per_tile) { + for (size_t vec_read_row = base_vec_row; + vec_read_row < + std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); + ++vec_read_row) { + const size_t vec_write_row = + interleave * base_vec_row + + vec_rows_per_tile * (read_col % interleave) + + vec_read_row % vec_rows_per_tile; + + const size_t read_offset = + size_t(read_col) * num_vec_rows + vec_read_row; + const size_t write_offset = + size_t(write_col) * num_vec_rows * interleave + vec_write_row; + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } +} + +} // namespace phi +#endif // PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_ diff --git a/paddle/phi/kernels/llm_int8_mat_mul_kernel.h b/paddle/phi/kernels/llm_int8_mat_mul_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..b42ee93629b0a7a37892426bc1dbaf5216e8794b --- /dev/null +++ b/paddle/phi/kernels/llm_int8_mat_mul_kernel.h @@ -0,0 +1,25 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void LLMInt8MatMulKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& weight, + const DenseTensor& weight_scale, + const float threshold, + DenseTensor* out); +} // namespace phi diff --git a/paddle/phi/kernels/quant_for_compress_kernel.h b/paddle/phi/kernels/quant_for_compress_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..474589f60234c59fac98e0756f338085cd1adf28 --- /dev/null +++ b/paddle/phi/kernels/quant_for_compress_kernel.h @@ -0,0 +1,25 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
diff --git a/paddle/phi/kernels/llm_int8_mat_mul_kernel.h b/paddle/phi/kernels/llm_int8_mat_mul_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..b42ee93629b0a7a37892426bc1dbaf5216e8794b
--- /dev/null
+++ b/paddle/phi/kernels/llm_int8_mat_mul_kernel.h
@@ -0,0 +1,25 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void LLMInt8MatMulKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& weight,
+                         const DenseTensor& weight_scale,
+                         const float threshold,
+                         DenseTensor* out);
+}  // namespace phi
diff --git a/paddle/phi/kernels/quant_for_compress_kernel.h b/paddle/phi/kernels/quant_for_compress_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..474589f60234c59fac98e0756f338085cd1adf28
--- /dev/null
+++ b/paddle/phi/kernels/quant_for_compress_kernel.h
@@ -0,0 +1,25 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void QuantForCompressKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            int bits,
+                            const std::string& layout,
+                            DenseTensor* out,
+                            DenseTensor* scale);
+}  // namespace phi
diff --git a/paddle/phi/kernels/weight_only_mat_mul_kernel.h b/paddle/phi/kernels/weight_only_mat_mul_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..99489bce378c6e4f252371cb30457dceda2e2b90
--- /dev/null
+++ b/paddle/phi/kernels/weight_only_mat_mul_kernel.h
@@ -0,0 +1,24 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void WeightOnlyMatMulKernel(const Context& dev_ctx,
+                            const DenseTensor& x,
+                            const DenseTensor& weight,
+                            const DenseTensor& weight_scale,
+                            DenseTensor* out);
+}  // namespace phi
diff --git a/python/paddle/fluid/layer_helper_base.py b/python/paddle/fluid/layer_helper_base.py
index 0579bc98563b2d40879c13e93f00f02105e87136..ae631bf69f38d60e6a67785592e1975f57b73e2f 100644
--- a/python/paddle/fluid/layer_helper_base.py
+++ b/python/paddle/fluid/layer_helper_base.py
@@ -381,6 +381,7 @@ class LayerHelperBase:
                     and dtype != core.VarDesc.VarType.FP64
                     and dtype != core.VarDesc.VarType.FP16
                     and dtype != core.VarDesc.VarType.BF16
+                    and dtype != core.VarDesc.VarType.INT8
                 ):
                     raise TypeError(
                         "Can not create parameter with default initializer when dtype is not ['float16', 'float32', 'float64', 'bfloat16'] type. Set default_initializer to fit the parameter dtype!"
@@ -392,6 +393,7 @@ class LayerHelperBase:
                     'float64',
                     'bfloat16',
                     'float',
+                    'int8',
                 ]:
                     raise TypeError(
                         "Can not create parameter with default initializer when dtype is not ['float16', 'float32', 'float64', 'bfloat16', 'float'] type. Set default_initializer to fit the parameter dtype!"
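Taken together with the assign_value and initializer changes elsewhere in this patch, the int8 allowance above lets a layer hold an already-quantized weight as an int8 Parameter. Below is a minimal sketch under that assumption; `Int8WeightHolder` is a made-up name used only for illustration and is not part of the patch.

```python
import numpy as np
import paddle
from paddle import nn


class Int8WeightHolder(nn.Layer):
    def __init__(self, int8_weight):
        super().__init__()
        # With this patch, Assign (NumpyArrayInitializer) can emit
        # "int8_values" and the layer helper accepts int8 parameters, so a
        # pre-quantized weight can be stored directly as an int8 Parameter.
        attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.Assign(int8_weight)
        )
        self.w = self.create_parameter(
            shape=list(int8_weight.shape), attr=attr, dtype="int8", is_bias=False
        )


layer = Int8WeightHolder(np.ones((4, 4), dtype="int8"))
print(layer.w.dtype)  # paddle.int8
```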
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index 2cad4950fde50b83c9269cd6679c286e2e182867..477a3b7c57b4e646ce145dcaef49fa55bee006e9 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -71,7 +71,7 @@ from .layer.common import AlphaDropout  # noqa: F401
 from .layer.common import Unfold  # noqa: F401
 from .layer.common import Fold  # noqa: F401
 from .layer.common import Unflatten  # noqa: F401
-
+from .layer.common import LinearCompress  # noqa: F401
 from .layer.pooling import AvgPool1D  # noqa: F401
 from .layer.pooling import AvgPool2D  # noqa: F401
 from .layer.pooling import AvgPool3D  # noqa: F401
diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 448a533b82d5c057559e16f72782e4798d8164a0..ec5ee96e3cc91641df7d9b3e0f1a779d58e04131 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -62,6 +62,8 @@ from .common import interpolate  # noqa: F401
 from .common import upsample  # noqa: F401
 from .common import bilinear  # noqa: F401
 from .common import class_center_sample  # noqa: F401
+from .common import quant_for_compress  # noqa: F401
+from .common import linear_compress  # noqa: F401
 from .conv import conv1d  # noqa: F401
 from .conv import conv1d_transpose  # noqa: F401
 from .common import linear  # noqa: F401
diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py
index d788cc46d0a68e0321d9cf3e86b3c9c79bb73c4e..3ee66817f5ba602fbac0a19b5436b45e144ccb99 100644
--- a/python/paddle/nn/functional/common.py
+++ b/python/paddle/nn/functional/common.py
@@ -1876,6 +1876,78 @@ def linear(x, weight, bias=None, name=None):
     return res
 
 
+def quant_for_compress(x, bits=8, layout="weight_only"):
+    return _C_ops.quant_for_compress(x, bits, layout)
+
+
+def linear_compress(
+    x,
+    weight,
+    weight_scale,
+    bias=None,
+    bits=8,
+    algo="llm.int8",
+    name=None,
+    config=None,
+):
+    if in_dynamic_mode():
+        if algo == "llm.int8":
+            y = _C_ops.llm_int8_mat_mul(
+                x, weight, weight_scale, config['threshold']
+            )
+        elif algo == "weight_only":
+            y = _C_ops.weight_only_mat_mul(x, weight, weight_scale)
+        else:
+            raise ValueError(
+                "Unknown algo: '{}'. It can only be 'llm.int8' or 'weight_only'.".format(
+                    algo
+                )
+            )
+        if bias is not None:
+            y = paddle.add(y, bias)
+        return y
+    else:
+        helper = LayerHelper('linear_compress', **locals())
+        dtype = x.dtype
+
+        check_variable_and_dtype(x, 'x', ['float16'], 'linear_compress')
+        check_dtype(dtype, 'dtype', ['float16'], 'linear_compress')
+
+        if algo == "llm.int8":
+            type = "llm_int8_matmul"
+            inputs = {'X': [x], 'Y': [weight], 'weight_scale': [weight_scale]}
+            attrs = {'algo': algo, 'threshold': config['threshold']}
+        elif algo == "weight_only":
+            type = "weight_only_matmul"
+            inputs = {'X': [x], 'Y': [weight], 'weight_scale': [weight_scale]}
+            attrs = {}
+        else:
+            raise ValueError(
+                "Unknown algo: '{}'. It can only be 'llm.int8' or 'weight_only'.".format(
+                    algo
+                )
+            )
+        tmp = helper.create_variable_for_type_inference(dtype)
+
+        helper.append_op(
+            type=type,
+            inputs=inputs,
+            outputs={'Out': tmp},
+            attrs=attrs,
+        )
+        if bias is not None:
+            res = helper.create_variable_for_type_inference(dtype)
+            helper.append_op(
+                type='elementwise_add',
+                inputs={'X': [tmp], 'Y': [bias]},
+                outputs={'Out': [res]},
+                attrs={'axis': -1},
+            )
+        else:
+            res = tmp
+        return res
+
+
 def label_smooth(label, prior_dist=None, epsilon=0.1, name=None):
     r"""
     Label smoothing is a mechanism to regularize the classifier layer and is called
diff --git a/python/paddle/nn/initializer/assign.py b/python/paddle/nn/initializer/assign.py
index 3e50fe9a6f6c26da957b6be2ac020c989b06af20..aaa198ec46942aea0a7628d47daa12b5a7d9adfb 100644
--- a/python/paddle/nn/initializer/assign.py
+++ b/python/paddle/nn/initializer/assign.py
@@ -81,6 +81,12 @@ class NumpyArrayInitializer(Initializer):
         elif out_dtype == core.VarDesc.VarType.INT32:
             value_name = "int32_values"
             values = [int(v) for v in np_value.flat]
+        elif (
+            out_dtype == core.VarDesc.VarType.INT8
+            or out_dtype == core.VarDesc.VarType.UINT8
+        ):
+            value_name = "int8_values"
+            values = [int(v) for v in np_value.flat]
         else:
             raise ValueError("Unsupported dtype %s", self._value.dtype)
         if self._value.size > 1024 * 1024 * 1024:
diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py
index 938c4b5a1ab48dda46ed140d717b40074c5a6ee0..1bab3855757c15187ebb7896a3329028ed594d4f 100644
--- a/python/paddle/nn/layer/common.py
+++ b/python/paddle/nn/layer/common.py
@@ -183,6 +183,167 @@ class Linear(Layer):
         )
 
 
+class LinearCompress(Layer):
+    r"""
+
+    Fully-connected linear transformation layer whose weight is compressed
+    (quantized to int8) before the matrix multiplication. For each input
+    :math:`X` , the equation is:
+
+    .. math::
+
+        Out = XW + b
+
+    where :math:`W` is the weight and :math:`b` is the bias.
+
+    Linear layer takes only one multi-dimensional tensor as input with the
+    shape :math:`[batch\_size, *, in\_features]` , where :math:`*` means any
+    number of additional dimensions. It multiplies input tensor with the weight
+    (a 2-D tensor of shape :math:`[in\_features, out\_features]` ) and produces
+    an output tensor of shape :math:`[batch\_size, *, out\_features]` .
+    If :math:`bias\_attr` is not False, the bias (a 1-D tensor of
+    shape :math:`[out\_features]` ) will be created and added to the output.
+
+    Parameters:
+        in_features (int): The number of input units.
+        out_features (int): The number of output units.
+        weight_attr (ParamAttr, optional): The attribute for the weight of this layer.
+            The default value is None. If the Initializer of the
+            param_attr is not set, the parameter is initialized with Xavier.
+            For detailed information, please refer to paddle.ParamAttr.
+        bias_attr (ParamAttr|bool, optional): The attribute for the bias of this layer.
+            If it is set to False, no bias will be added to the output.
+            If it is set to None or one kind of ParamAttr, a bias parameter will
+            be created according to ParamAttr. For detailed information, please refer
+            to paddle.ParamAttr. The default value is None and the bias will be
+            initialized to zero.
+        name (str, optional): Normally there is no need for user to set this parameter.
+            For detailed information, please refer to :ref:`api_guide_Name` .
+        bits (int, optional): The number of bits used to quantize the weight when
+            ``algo`` is 'weight_only'; currently only 8 is supported. Default: 8.
+        algo (str, optional): The compression algorithm to use; it must be
+            'weight_only' or 'llm.int8'. Default: 'weight_only'.
+        config (dict, optional): The extra configuration for the compression algorithm.
+            For 'llm.int8', it should be set as {'threshold': 6.0}. Default: {'threshold': 6.0}.
+
+    Attribute:
+        **weight** (Parameter): the learnable weight of this layer.
+
+        **bias** (Parameter): the learnable bias of this layer.
+
+    Shape:
+        - input: Multi-dimensional tensor with shape :math:`[batch\_size, *, in\_features]` . Its data type should be float16.
+        - output: Multi-dimensional tensor with shape :math:`[batch\_size, *, out\_features]` . The data type is the same as the input.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            # Define the linear layer.
+            paddle.set_default_dtype('float16')
+            weight_attr = paddle.ParamAttr(
+                name="weight",
+                initializer=paddle.nn.initializer.Constant(value=0.5))
+            bias_attr = paddle.ParamAttr(
+                name="bias",
+                initializer=paddle.nn.initializer.Constant(value=1.0))
+            linear = paddle.nn.LinearCompress(128, 64, weight_attr=weight_attr, bias_attr=bias_attr, bits=8, algo='weight_only')
+            x = paddle.randn((3, 128), dtype="float16")
+            y = linear(x)
+    """
+
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        weight_attr=None,
+        bias_attr=None,
+        name=None,
+        bits=8,
+        algo="weight_only",
+        config={'threshold': 6.0},
+    ):
+        super().__init__()
+        self._dtype = self._helper.get_default_dtype()
+        self._weight_attr = weight_attr
+        self._bias_attr = bias_attr
+        self.weight = self.create_parameter(
+            shape=[in_features, out_features],
+            attr=self._weight_attr,
+            dtype=self._dtype,
+            is_bias=False,
+        )
+        self.bias = self.create_parameter(
+            shape=[out_features],
+            attr=self._bias_attr,
+            dtype=self._dtype,
+            is_bias=True,
+        )
+        self.weight_scale = self.create_parameter(
+            shape=[out_features],
+            attr=None,
+            dtype=self._dtype,
+            is_bias=False,
+        )
+        self.is_weight_quanted = False
+        self.name = name
+        self.bits = bits
+        self.layout = algo
+        self.algo = algo
+        self.config = config
+
+    def forward(self, input):
+        if in_dynamic_mode():
+            if not self.is_weight_quanted:
+                weight_tensor, weight_scale_tensor = F.quant_for_compress(
+                    self.weight, self.bits, self.layout
+                )
+                weight_attr = paddle.framework.ParamAttr(
+                    initializer=paddle.nn.initializer.Assign(weight_tensor)
+                )
+                self.weight = self.create_parameter(
+                    shape=self.weight.shape
+                    if self.layout == 0
+                    else [self.weight.shape[1], self.weight.shape[0]],
+                    attr=weight_attr,
+                    dtype="int8",
+                    is_bias=False,
+                )
+                weight_scale_attr = paddle.framework.ParamAttr(
+                    initializer=paddle.nn.initializer.Assign(
+                        weight_scale_tensor
+                    )
+                )
+                self.weight_scale = self.create_parameter(
+                    shape=self.weight_scale.shape,
+                    attr=weight_scale_attr,
+                    dtype="float32",
+                    is_bias=False,
+                )
+                self.is_weight_quanted = True
+        out = F.linear_compress(
+            x=input,
+            weight=self.weight,
+            weight_scale=self.weight_scale,
+            bias=self.bias,
+            bits=self.bits,
+            algo=self.algo,
+            name=self.name,
+            config=self.config,
+        )
+        return out
+
+    def extra_repr(self):
+        name_str = f', name={self.name}' if self.name else ''
+        return 'in_features={}, out_features={}, dtype={}{}, algo={}'.format(
+            self.weight.shape[0],
+            self.weight.shape[1],
+            self._dtype,
+            name_str,
+            self.algo,
+        )
+
+
 class Upsample(Layer):
     """
     This op resizes a batch of images.
diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt
index 468d7eb966cd72587819f36cc2c52b47ee026e17..2af0e79d9690e7f4a8e8d954f8dc73c0b83bceb7 100644
--- a/test/legacy_test/CMakeLists.txt
+++ b/test/legacy_test/CMakeLists.txt
@@ -86,6 +86,7 @@ endif()
 
 list(REMOVE_ITEM TEST_OPS test_audio_logmel_feature test_audio_mel_feature)
 list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
+list(REMOVE_ITEM TEST_OPS test_linear_compress)
 list(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op)
 list(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op)
 list(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass)
@@ -153,6 +154,7 @@ if(WIN32)
   list(REMOVE_ITEM TEST_OPS test_trt_convert_preln_residual_bias)
   list(REMOVE_ITEM TEST_OPS test_fused_multi_transformer_int8_op)
   list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
+  list(REMOVE_ITEM TEST_OPS test_linear_compress)
 endif()
 
 list(REMOVE_ITEM TEST_OPS test_checkpoint_saver)
diff --git a/test/legacy_test/test_linear_compress.py b/test/legacy_test/test_linear_compress.py
new file mode 100644
index 0000000000000000000000000000000000000000..99181a29728a36415673fc5177c189cbaf843685
--- /dev/null
+++ b/test/legacy_test/test_linear_compress.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle import fluid
+from paddle.fluid.framework import default_main_program
+from paddle.framework import set_default_dtype
+
+np.random.seed(123)
+paddle.seed(123)
+default_main_program().random_seed = 42
+paddle.disable_static()
+
+
+class LinearTestCase(unittest.TestCase):
+    def config(self):
+        self.dtype = 'float16'
+        self.rtol = 1e-5
+        self.atol = 1e-2
+        self.bias = True
+        self.in_features = 64
+        self.out_features = 64
+        self.algo = "weight_only"
+
+    def setUp(self):
+        self.config()
+        input = np.random.random((2, 4, self.in_features))
+        self.input = paddle.to_tensor(input, dtype=self.dtype)
+        if self.bias:
+            bias_attr = fluid.ParamAttr(
+                learning_rate=0.8,
+                trainable=False,
+                regularizer=None,
+                initializer=paddle.nn.initializer.Constant(value=1.0),
+            )
+        else:
+            bias_attr = None
+        set_default_dtype(self.dtype)
+        self.linear = paddle.nn.Linear(
+            self.in_features, self.out_features, bias_attr=bias_attr
+        )
+        if self.algo == "llm.int8":
+            self.compress_config = {"threshold": 6.0}
+        else:
+            self.compress_config = None
+        self.linear_compress = paddle.nn.LinearCompress(
+            self.in_features,
+            self.out_features,
+            bias_attr=bias_attr,
+            algo=self.algo,
+            config=self.compress_config,
+        )
+        self.linear_compress(self.input)
+
+    def get_linear_out(self):
+        out = self.linear(self.input)
+        return out.numpy()
+
+    def get_linear_compress_out(self):
+        out = self.linear_compress(self.input)
+        return out.numpy()
+
+    def test_linear_compress(self):
+        out_real = self.get_linear_compress_out()
+        out_expect = self.get_linear_out()
+        np.testing.assert_allclose(
+            out_real, out_expect, rtol=self.rtol, atol=self.atol
+        )
+
+
+class LinearTestCase1(LinearTestCase):
+    def config(self):
+        super().config()
+        self.dtype = 'float16'
+        self.bias = True
+        self.in_features = 128
+        self.out_features = 64
+
+
+class LinearTestCase2(LinearTestCase):
+    def config(self):
+        super().config()
+        self.dtype = 'float16'
+        self.bias = False
+        self.in_features = 64
+        self.out_features = 64
+
+
+class LinearTestCase3(LinearTestCase):
+    def config(self):
+        super().config()
+        self.dtype = 'float16'
+        self.bias = False
+        self.in_features = 64
+        self.out_features = 64
+        self.algo = "llm.int8"
+        self.atol = 1e-1
+
+
+if __name__ == '__main__':
+    unittest.main()
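For completeness, an end-to-end usage sketch that mirrors the docstring example and the test above; it assumes a GPU build in which the new quant_for_compress and weight_only_mat_mul kernels are registered, and is not part of the patch.

```python
import paddle

# LinearCompress quantizes its fp16 weight to int8 (plus per-channel scales)
# on the first forward call, then dispatches to the weight-only matmul kernel.
paddle.set_default_dtype('float16')
layer = paddle.nn.LinearCompress(64, 64, bits=8, algo='weight_only')
x = paddle.randn((2, 4, 64), dtype='float16')
y = layer(x)
print(y.shape)  # [2, 4, 64]
```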