From f4290a92653f8b8685958a90043585e59b83bf70 Mon Sep 17 00:00:00 2001 From: zhenyun <1500424927@qq.com> Date: Tue, 11 Jul 2023 11:18:28 +0800 Subject: [PATCH] Linear compress (#55128) * rename weight_only/llm.int8 --- paddle/phi/api/yaml/ops.yaml | 13 +- paddle/phi/infermeta/multiary.cc | 20 +- paddle/phi/infermeta/multiary.h | 5 +- paddle/phi/infermeta/unary.cc | 17 +- .../kernels/cpu/quant_for_compress_kernel.cc | 24 +- ...ul_kernel.cu => llm_int8_matmul_kernel.cu} | 10 +- .../kernels/gpu/weight_only_mat_mul_kernel.cu | 73 ----- .../kernels/gpu/weight_only_matmul_kernel.cu | 117 ++++++++ ...l_impl.h => llm_int8_matmul_kernel_impl.h} | 0 .../impl/quant_for_compress_kernel_impl.h | 260 +++++++++++++++--- ..._mul_kernel.h => llm_int8_matmul_kernel.h} | 2 +- ...l_kernel.h => weight_only_matmul_kernel.h} | 2 +- python/paddle/nn/functional/common.py | 16 +- python/paddle/nn/layer/common.py | 9 +- test/legacy_test/test_linear_compress.py | 12 + 15 files changed, 414 insertions(+), 166 deletions(-) rename paddle/phi/kernels/gpu/{llm_int8_mat_mul_kernel.cu => llm_int8_matmul_kernel.cu} (90%) delete mode 100644 paddle/phi/kernels/gpu/weight_only_mat_mul_kernel.cu create mode 100644 paddle/phi/kernels/gpu/weight_only_matmul_kernel.cu rename paddle/phi/kernels/impl/{llm_int8_mat_mul_kernel_impl.h => llm_int8_matmul_kernel_impl.h} (100%) rename paddle/phi/kernels/{llm_int8_mat_mul_kernel.h => llm_int8_matmul_kernel.h} (95%) rename paddle/phi/kernels/{weight_only_mat_mul_kernel.h => weight_only_matmul_kernel.h} (94%) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index b4082b55da5..1962836e3dd 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1367,14 +1367,14 @@ data_transform : skip_transform : out_size, size_tensor, scale_tensor -- op : llm_int8_mat_mul +- op : llm_int8_matmul args : (Tensor x, Tensor weight, Tensor weight_scale, float threshold=6.0) output : Tensor(out) infer_meta : - func : LLMInt8MatMulInferMeta + func : LLMInt8MatmulInferMeta param : [x, weight] kernel : - func : llm_int8_mat_mul + func : llm_int8_matmul data_type : x - op : log @@ -2602,14 +2602,13 @@ intermediate: warprnntgrad backward : warprnnt_grad -- op : weight_only_mat_mul +- op : weight_only_matmul args : (Tensor x, Tensor weight, Tensor weight_scale) output : Tensor(out) infer_meta : - func : WeightOnlyMatMulInferMeta - param : [x, weight] + func : WeightOnlyMatmulInferMeta kernel : - func : weight_only_mat_mul + func : weight_only_matmul data_type : x - op : weighted_sample_neighbors diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 71bbfaa333a..9248f4699a4 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -3572,7 +3572,7 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row, out_count->set_dtype(DataType::INT32); } -void LLMInt8MatMulInferMeta(const MetaTensor& x, +void LLMInt8MatmulInferMeta(const MetaTensor& x, const MetaTensor& weight, MetaTensor* out) { auto x_dims = x.dims(); @@ -3595,25 +3595,31 @@ void LLMInt8MatMulInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } -void WeightOnlyMatMulInferMeta(const MetaTensor& x, +void WeightOnlyMatmulInferMeta(const MetaTensor& x, const MetaTensor& weight, + const MetaTensor& weight_scale, MetaTensor* out) { auto x_dims = x.dims(); auto w_dims = weight.dims(); + auto n = weight_scale.dims()[0]; PADDLE_ENFORCE_EQ( w_dims.size(), 2UL, errors::InvalidArgument("The input(weight) must be a 2D Tensor.")); + PADDLE_ENFORCE_EQ( + 
weight_scale.dims().size(), + 1UL, + errors::InvalidArgument("The input(weight_scale) must be a 1D Tensor.")); PADDLE_ENFORCE_EQ( x_dims[x_dims.size() - 1], - w_dims[0], + w_dims[1], errors::InvalidArgument( - "Input(X) dim[-1] and Input(Weight) dim[0] should be euqal." - "But received Input(X) dim[-1](%s) != Input(Weight) dim[0](%s)", + "Input(X) dim[-1] and Input(Weight) dim[1] should be euqal." + "But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)", x_dims[x_dims.size() - 1], - w_dims[0])); + w_dims[1])); auto out_dims = x_dims; - out_dims[out_dims.size() - 1] = w_dims[1]; + out_dims[out_dims.size() - 1] = n; out->set_dims(out_dims); out->set_dtype(x.dtype()); } diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 67a39780aa9..a785a4c5ee8 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -690,12 +690,13 @@ void FusedMultiHeadAttentionVariableInferMeta(const MetaTensor& query, bool causal, MetaTensor* out); -void LLMInt8MatMulInferMeta(const MetaTensor& x, +void LLMInt8MatmulInferMeta(const MetaTensor& x, const MetaTensor& weight, MetaTensor* out); -void WeightOnlyMatMulInferMeta(const MetaTensor& x, +void WeightOnlyMatmulInferMeta(const MetaTensor& x, const MetaTensor& weight, + const MetaTensor& weight_scale, MetaTensor* out); void FusedRopeInferMeta(const MetaTensor& q, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index f37ae1d688c..2b6b5cc0e16 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -5086,22 +5086,17 @@ void QuantForCompressInferMeta(const MetaTensor& x, x_dims[0])); std::vector dim_scale({x_dims[1]}); std::vector dim_out; - if (layout == "weight_only") { - dim_out = std::vector({x_dims[0], x_dims[1]}); - } else if (layout == "llm.int8") { + if (bits == 8) { dim_out = std::vector({x_dims[1], x_dims[0]}); + } else if (bits == 4) { + dim_out = std::vector({x_dims[1] / 2, x_dims[0]}); } else { - phi::errors::InvalidArgument( - "The layout must be weight_only or llm.int8, but got %s", layout); + phi::errors::InvalidArgument("The bit must be 8 or 4, but got %d", bits); } out->set_dims(phi::make_ddim(dim_out)); - // TODO(lizhenyun) support weight_only int4 - if (bits == 8) { - out->set_dtype(DataType::INT8); - } else { - phi::errors::Fatal("The bits only support 8, but got[%d]", bits); - } + out->set_dtype(DataType::INT8); + scale->set_dims(phi::make_ddim(dim_scale)); scale->set_dtype(DataType::FLOAT32); } diff --git a/paddle/phi/kernels/cpu/quant_for_compress_kernel.cc b/paddle/phi/kernels/cpu/quant_for_compress_kernel.cc index a96842640b6..3d21371f4fd 100644 --- a/paddle/phi/kernels/cpu/quant_for_compress_kernel.cc +++ b/paddle/phi/kernels/cpu/quant_for_compress_kernel.cc @@ -22,7 +22,7 @@ namespace phi { -template +template void quant_compute(const DeviceContext& dev_ctx, const DenseTensor& x, DenseTensor* out, @@ -59,15 +59,15 @@ void quant_compute(const DeviceContext& dev_ctx, per_channel_scale(scale_data, x_data, m, n); - per_channel_quant(x_int_data, x_data, scale_data, m, n); + per_channel_quant(x_int_data, x_data, scale_data, m, n); if (layout == "weight_only") { - permute_B_rows_for_mixed_gemm( + permute_B_rows_for_mixed_gemm( int_processed_data, x_int_data, std::vector{m, n}, (int64_t)80); - row_major_to_column_major( + subbyte_transpose_impl( int_processed_2_data, int_processed_data, std::vector{m, n}); - interleave_column_major_tensor( + interleave_column_major_tensor( out_data, int_processed_2_data, std::vector{m, n}); - 
add_bias_and_interleave_int8s_inplace(out_data, num); + add_bias_and_interleave_inplace(out_data, num); } else if (layout == "llm.int8") { std::vector axis = {1, 0}; funcs::Transpose trans; @@ -88,9 +88,16 @@ void QuantForCompressKernel(const Context& dev_ctx, if (bits == 8) { dev_ctx.template Alloc(out); dev_ctx.template Alloc(scale); - quant_compute(dev_ctx, x, out, scale, layout); + quant_compute(dev_ctx, x, out, scale, layout); + } else if (bits == 4 && layout == "weight_only") { + dev_ctx.template Alloc(out); + dev_ctx.template Alloc(scale); + quant_compute(dev_ctx, x, out, scale, layout); } else { - phi::errors::Unimplemented("The bits only support 8, but got[%d]", bits); + phi::errors::Unimplemented( + "The bits only support 8 or weight_only 4, but got[%s] [%d]", + layout, + bits); } // VLOG(0) << "x: " << x.dtype() << x; // VLOG(0) << "out: " << out->dtype() << *out; @@ -102,5 +109,4 @@ PD_REGISTER_KERNEL(quant_for_compress, CPU, ALL_LAYOUT, phi::QuantForCompressKernel, - float, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/llm_int8_mat_mul_kernel.cu b/paddle/phi/kernels/gpu/llm_int8_matmul_kernel.cu similarity index 90% rename from paddle/phi/kernels/gpu/llm_int8_mat_mul_kernel.cu rename to paddle/phi/kernels/gpu/llm_int8_matmul_kernel.cu index 47aca6c4caf..ed68d2d1e01 100644 --- a/paddle/phi/kernels/gpu/llm_int8_mat_mul_kernel.cu +++ b/paddle/phi/kernels/gpu/llm_int8_matmul_kernel.cu @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/llm_int8_mat_mul_kernel.h" +#include "paddle/phi/kernels/llm_int8_matmul_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" #ifndef PADDLE_WITH_HIP -#include "paddle/phi/kernels/impl/llm_int8_mat_mul_kernel_impl.h" +#include "paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h" #endif namespace phi { @@ -56,7 +56,7 @@ void llm_int8_compute(const Context& dev_ctx, } template -void LLMInt8MatMulKernel(const Context& dev_ctx, +void LLMInt8MatmulKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& weight, const DenseTensor& weight_scale, @@ -68,8 +68,8 @@ void LLMInt8MatMulKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL(llm_int8_mat_mul, +PD_REGISTER_KERNEL(llm_int8_matmul, GPU, ALL_LAYOUT, - phi::LLMInt8MatMulKernel, + phi::LLMInt8MatmulKernel, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/weight_only_mat_mul_kernel.cu b/paddle/phi/kernels/gpu/weight_only_mat_mul_kernel.cu deleted file mode 100644 index d8216a2fe68..00000000000 --- a/paddle/phi/kernels/gpu/weight_only_mat_mul_kernel.cu +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/phi/kernels/weight_only_mat_mul_kernel.h" -#include "paddle/phi/backends/gpu/gpu_context.h" -#include "paddle/phi/common/datatype_traits.h" -#include "paddle/phi/core/kernel_registry.h" -#if defined(PADDLE_WITH_CUTLASS) -#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" -#endif - -namespace phi { - -template -void WeightOnlyMatMulKernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& weight, - const DenseTensor& weight_scale, - DenseTensor* out) { -#if defined(PADDLE_WITH_CUTLASS) - dev_ctx.template Alloc(out); - const auto x_dims = x.dims(); - const auto w_dims = weight.dims(); - - int k = w_dims[0]; - int n = w_dims[1]; - int m = x.numel() / k; - auto mixed_gemm_runner = - CutlassFpAIntBGemmRunner::DataType, - uint8_t>(); - int mixgemm_max_size = std::max(n, k); - DenseTensor mixgemm_workspace; - int64_t mixgemm_workspace_size_bytes = - mixed_gemm_runner.getWorkspaceSize(m, mixgemm_max_size, mixgemm_max_size); - - mixgemm_workspace.Resize({mixgemm_workspace_size_bytes}); - dev_ctx.template Alloc(&mixgemm_workspace); - char* mixgemm_workspace_data = - reinterpret_cast(mixgemm_workspace.data()); - mixed_gemm_runner.gemm( - reinterpret_cast::DataType*>( - x.data()), - reinterpret_cast(weight.data()), - reinterpret_cast(weight_scale.data()), - reinterpret_cast::DataType*>(out->data()), - m, - n, - k, - mixgemm_workspace_data, - mixgemm_workspace_size_bytes, - dev_ctx.stream()); -#else - LOG(ERROR) << "Please compile with cutlass to EnableUseCutlass()"; -#endif -} -} // namespace phi - -PD_REGISTER_KERNEL(weight_only_mat_mul, - GPU, - ALL_LAYOUT, - phi::WeightOnlyMatMulKernel, - phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/weight_only_matmul_kernel.cu b/paddle/phi/kernels/gpu/weight_only_matmul_kernel.cu new file mode 100644 index 00000000000..ad88315875b --- /dev/null +++ b/paddle/phi/kernels/gpu/weight_only_matmul_kernel.cu @@ -0,0 +1,117 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/weight_only_matmul_kernel.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/datatype_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#if defined(PADDLE_WITH_CUTLASS) +#include "paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h" +#endif + +namespace phi { + +template +void WeightOnlyMatmulKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& weight, + const DenseTensor& weight_scale, + DenseTensor* out) { +#if defined(PADDLE_WITH_CUTLASS) + dev_ctx.template Alloc(out); + const auto x_dims = x.dims(); + const auto w_dims = weight.dims(); + int n = weight_scale.dims()[0]; + int quant_bit = 0; + if (n % w_dims[0] == 0) { + quant_bit = w_dims[0] * 8 / n; + } else { + errors::InvalidArgument( + "w_dims[0] must be divisible by weight_scale.dims()[0]"); + } + + int k = w_dims[1]; + int m = x.numel() / k; + switch (quant_bit) { + case 8: { + auto mixed_gemm_runner = + CutlassFpAIntBGemmRunner::DataType, + uint8_t>(); + int mixgemm_max_size = std::max(n, k); + DenseTensor mixgemm_workspace; + int64_t mixgemm_workspace_size_bytes = mixed_gemm_runner.getWorkspaceSize( + m, mixgemm_max_size, mixgemm_max_size); + + mixgemm_workspace.Resize({mixgemm_workspace_size_bytes}); + dev_ctx.template Alloc(&mixgemm_workspace); + char* mixgemm_workspace_data = + reinterpret_cast(mixgemm_workspace.data()); + mixed_gemm_runner.gemm( + reinterpret_cast::DataType*>( + x.data()), + reinterpret_cast(weight.data()), + reinterpret_cast(weight_scale.data()), + reinterpret_cast::DataType*>( + out->data()), + m, + n, + k, + mixgemm_workspace_data, + mixgemm_workspace_size_bytes, + dev_ctx.stream()); + } break; + case 4: { + auto mixed_gemm_runner = + CutlassFpAIntBGemmRunner::DataType, + cutlass::uint4b_t>(); + int mixgemm_max_size = std::max(n, k); + DenseTensor mixgemm_workspace; + int64_t mixgemm_workspace_size_bytes = mixed_gemm_runner.getWorkspaceSize( + m, mixgemm_max_size, mixgemm_max_size); + + mixgemm_workspace.Resize({mixgemm_workspace_size_bytes}); + dev_ctx.template Alloc(&mixgemm_workspace); + char* mixgemm_workspace_data = + reinterpret_cast(mixgemm_workspace.data()); + mixed_gemm_runner.gemm( + reinterpret_cast::DataType*>( + x.data()), + reinterpret_cast(weight.data()), + reinterpret_cast(weight_scale.data()), + reinterpret_cast::DataType*>( + out->data()), + m, + n, + k, + mixgemm_workspace_data, + mixgemm_workspace_size_bytes, + dev_ctx.stream()); + } break; + default: + PADDLE_THROW(errors::Unimplemented( + "Quant_bits (%d) is not supported when gemm ", quant_bit)); + break; + } + +#else + LOG(ERROR) << "Please compile with cutlass to EnableUseCutlass()"; +#endif +} +} // namespace phi + +PD_REGISTER_KERNEL(weight_only_matmul, + GPU, + ALL_LAYOUT, + phi::WeightOnlyMatmulKernel, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/impl/llm_int8_mat_mul_kernel_impl.h b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h similarity index 100% rename from paddle/phi/kernels/impl/llm_int8_mat_mul_kernel_impl.h rename to paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h diff --git a/paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h b/paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h index 12c5a68a114..096600bac0b 100644 --- a/paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h +++ b/paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h @@ -39,51 +39,109 @@ void per_channel_scale(float* scale, const T* input, size_t m, size_t n) { } } -template 
-void per_channel_quant( - D* output, const T* input, const float* scale, size_t m, size_t n) { - for (size_t i = 0; i < m; i++) { - for (size_t j = 0; j < n; j++) { - output[i * n + j] = static_cast( - round(static_cast(input[i * n + j]) / scale[j])); +template +void per_channel_quant(int8_t* output, + const T* input, + const float* scale, + size_t num_rows, + size_t num_cols) { + size_t bytes_per_out_col = num_cols * quant_bit / 8; + for (size_t ii = 0; ii < num_rows; ++ii) { + int8_t* current_quantized_weight_row = output + ii * bytes_per_out_col; + const T* current_weight_row = input + ii * num_cols; + for (size_t jj = 0; jj < bytes_per_out_col; ++jj) { + if (quant_bit == 8) { + const float col_scale = scale[jj]; + const float weight_elt = static_cast(current_weight_row[jj]); + const float scaled_weight = round(weight_elt / col_scale); + const int8_t clipped_weight = static_cast( + std::max(-127.f, std::min(127.f, scaled_weight))); + current_quantized_weight_row[jj] = clipped_weight; + } else if (quant_bit == 4) { + // We will pack two int4 elements per iteration of the inner loop. + int8_t packed_int4s = 0; + for (int packed_idx = 0; packed_idx < 2; ++packed_idx) { + const size_t input_idx = 2 * jj + packed_idx; + if (input_idx < num_cols) { + const float col_scale = scale[input_idx]; + const float weight_elt = + static_cast(current_weight_row[input_idx]); + const float scaled_weight = round(weight_elt / col_scale); + int int_weight = static_cast(scaled_weight); + const int8_t clipped_weight = std::max(-7, std::min(7, int_weight)); + + // Kill the sign extension bits (hence 0x0F mask) then shift to + // upper bits if packing the second int4 and or the bits into the + // final result. + packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx)); + } + } + current_quantized_weight_row[jj] = packed_int4s; + } else { + phi::errors::Unimplemented("Unsupported quantization bits: %d", + quant_bit); + } } } } -void row_major_to_column_major(int8_t* col_major_tensor, - const int8_t* row_major_tensor, - const std::vector& shape) { - size_t m = shape[0]; - size_t n = shape[1]; - for (size_t i = 0; i < m * n; i++) { - size_t im = i / n; - size_t in = i % n; - col_major_tensor[in * m + im] = row_major_tensor[im * n + in]; - } -} +template +void add_bias_and_interleave_inplace(int8_t* tensor_ptr, size_t num_elts) { + const size_t num_bytes = num_elts * quant_bit / 8; + + for (size_t ii = 0; ii < num_bytes; ++ii) { + if (quant_bit == 8) { + tensor_ptr[ii] = + static_cast(static_cast(tensor_ptr[ii]) + 128); + } else { + int8_t transformed_packed_int4s = 0; + int8_t transformed_first_elt = + (int8_t(tensor_ptr[ii] << 4) >> 4) + + 8; // The double shift here is to ensure sign extension + int8_t transformed_second_elt = (tensor_ptr[ii] >> 4) + 8; -void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor_ptr, - size_t num_elts) { - int8_t* int8_tensor = reinterpret_cast(int8_tensor_ptr); - for (size_t ii = 0; ii < num_elts; ++ii) { - int8_tensor[ii] = - static_cast(static_cast(int8_tensor[ii]) + 128); + if (!(transformed_first_elt >= 0 && transformed_first_elt <= 15)) { + phi::errors::InvalidArgument( + "Illegal result for int4 transform (first elt)"); + } + if (!(transformed_second_elt >= 0 && transformed_second_elt <= 15)) { + phi::errors::InvalidArgument( + "Illegal result for int4 transform (second elt)"); + } + // We don't need to mask in these ops since everything should be in the + // range 0-15 + transformed_packed_int4s |= transformed_first_elt; + transformed_packed_int4s |= 
(transformed_second_elt << 4); + tensor_ptr[ii] = transformed_packed_int4s; + } } - // Step 2 will transform the layout of a 32-bit register in CUDA in order to - // match the int4 layout. This has no performance benefit and is purely so - // that int4 and int8 have the same layout. Pictorially, this does the - // following: bit 32 0 - // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) - // - // And it will rearrange the output 32 bit register to be the following: - // bit 32 0 - // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) - - for (size_t base = 0; base < num_elts; base += 4) { - std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); + if (quant_bit == 8) { + for (size_t base = 0; base < num_elts; base += 4) { + std::swap(tensor_ptr[base + 1], tensor_ptr[base + 2]); + } + } else { + const size_t num_registers = num_bytes / 4; + + uint32_t* register_ptr = reinterpret_cast(tensor_ptr); + for (size_t ii = 0; ii < num_registers; ++ii) { + const uint32_t current_register = register_ptr[ii]; + uint32_t transformed_register = 0; + + for (int dest_idx = 0; dest_idx < 8; ++dest_idx) { + const int src_idx = + dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; + const int src_shift = 4 * src_idx; + const int dest_shift = 4 * dest_idx; + + const uint32_t src_bits = (current_register >> src_shift) & 0xF; + transformed_register |= (src_bits << dest_shift); + } + register_ptr[ii] = transformed_register; + } } } +template void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8_t* quantized_tensor, const std::vector& shape, @@ -92,9 +150,8 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - const int BITS_PER_ELT = 8; + const int BITS_PER_ELT = quant_bit; const int K = 16 / BITS_PER_ELT; - // const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; const int ELTS_PER_REG = 32 / BITS_PER_ELT; const uint32_t* input_byte_ptr = @@ -102,7 +159,6 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); - // int MMA_SHAPE_N = 8; int B_ROWS_PER_MMA = 8 * K; const int elts_in_int32 = 32 / BITS_PER_ELT; @@ -118,15 +174,134 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int read_row = base_row + tile_read_row; const int read_col = write_col; - const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col; + const int64_t read_offset = + static_cast(read_row) * num_vec_cols + read_col; const int64_t write_offset = - int64_t(write_row) * num_vec_cols + write_col; + static_cast(write_row) * num_vec_cols + write_col; output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; } } } } +template +void subbyte_transpose_impl(int8_t* transposed_quantized_tensor, + const int8_t* quantized_tensor, + const std::vector& shape) { + const int bits_per_elt = quant_bit; + + // FT_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be + // 2-D or 3-D"); + // const size_t num_experts = 1; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; + + const size_t col_bytes = num_cols * bits_per_elt / 8; + const size_t col_bytes_trans = num_rows * bits_per_elt / 8; + // const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes; + + const uint8_t* input_byte_ptr = + reinterpret_cast(quantized_tensor); + uint8_t* output_byte_ptr = + reinterpret_cast(transposed_quantized_tensor); + + static constexpr int ELTS_PER_BYTE = 8 / quant_bit; + + static constexpr int M_TILE_L1 = 64; + static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; + uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; + + static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); + + // const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; + // const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; + + for (size_t row_tile_start = 0; row_tile_start < num_rows; + row_tile_start += M_TILE_L1) { + for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; + col_tile_start_byte += N_TILE_L1) { + const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); + const int col_limit = + std::min(col_tile_start_byte + N_TILE_L1, col_bytes); + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + const int row = row_tile_start + ii; + + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + const int col = col_tile_start_byte + jj; + + const size_t logical_src_offset = row * col_bytes + col; + + if (row < row_limit && col < col_limit) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; + } + } + } + } + + if (quant_bit == 8) { + for (int ii = 0; ii < M_TILE_L1; ++ii) { + for (int jj = ii + 1; jj < N_TILE_L1; ++jj) { + std::swap(cache_buf[ii][jj], cache_buf[jj][ii]); + } + } + } else if (quant_bit == 4) { + for (int ii = 0; ii < M_TILE_L1; ++ii) { + // Using M_TILE_L1 here is deliberate since we assume that the cache + // tile is square in the number of elements (not necessarily the + // number of bytes). 
+ for (int jj = ii + 1; jj < M_TILE_L1; ++jj) { + const int ii_byte = ii / ELTS_PER_BYTE; + const int ii_bit_offset = ii % ELTS_PER_BYTE; + + const int jj_byte = jj / ELTS_PER_BYTE; + const int jj_bit_offset = jj % ELTS_PER_BYTE; + + uint8_t src_elt = + 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); + uint8_t tgt_elt = + 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); + } + } + } else { + phi::errors::Unimplemented("Unsupported quantization bits: %d", + quant_bit); + } + + const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; + const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; + + const int row_limit_trans = + std::min(row_tile_start_trans + M_TILE_L1, num_cols); + const int col_limit_trans = + std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); + + for (int ii = 0; ii < M_TILE_L1; ++ii) { + const int row = row_tile_start_trans + ii; + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { + const int col = col_tile_start_byte_trans + jj; + + const size_t logical_tgt_offset = row * col_bytes_trans + col; + + if (row < row_limit_trans && col < col_limit_trans) { + for (int v = 0; v < VECTOR_WIDTH; ++v) { + output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; + } + } + } + } + } + } +} + +template void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const int8_t* quantized_tensor, const std::vector& shape) { @@ -134,7 +309,7 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; - const size_t BITS_PER_ELT = 8; + const size_t BITS_PER_ELT = quant_bit; const size_t elts_in_int32 = 32 / BITS_PER_ELT; const size_t rows_per_tile = 64; @@ -169,6 +344,5 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, } } } - } // namespace phi #endif // PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_ diff --git a/paddle/phi/kernels/llm_int8_mat_mul_kernel.h b/paddle/phi/kernels/llm_int8_matmul_kernel.h similarity index 95% rename from paddle/phi/kernels/llm_int8_mat_mul_kernel.h rename to paddle/phi/kernels/llm_int8_matmul_kernel.h index b42ee93629b..0d6229ea5af 100644 --- a/paddle/phi/kernels/llm_int8_mat_mul_kernel.h +++ b/paddle/phi/kernels/llm_int8_matmul_kernel.h @@ -16,7 +16,7 @@ limitations under the License. */ namespace phi { template -void LLMInt8MatMulKernel(const Context& dev_ctx, +void LLMInt8MatmulKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& weight, const DenseTensor& weight_scale, diff --git a/paddle/phi/kernels/weight_only_mat_mul_kernel.h b/paddle/phi/kernels/weight_only_matmul_kernel.h similarity index 94% rename from paddle/phi/kernels/weight_only_mat_mul_kernel.h rename to paddle/phi/kernels/weight_only_matmul_kernel.h index 99489bce378..f2f20294021 100644 --- a/paddle/phi/kernels/weight_only_mat_mul_kernel.h +++ b/paddle/phi/kernels/weight_only_matmul_kernel.h @@ -16,7 +16,7 @@ limitations under the License. 
*/ namespace phi { template -void WeightOnlyMatMulKernel(const Context& dev_ctx, +void WeightOnlyMatmulKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& weight, const DenseTensor& weight_scale, diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 3ee66817f5b..f3fc199dbcf 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1892,11 +1892,11 @@ def linear_compress( ): if in_dynamic_mode(): if algo == "llm.int8": - y = _C_ops.llm_int8_mat_mul( + y = _C_ops.llm_int8_matmul( x, weight, weight_scale, config['threshold'] ) elif algo == "weight_only": - y = _C_ops.weight_only_mat_mul(x, weight, weight_scale) + y = _C_ops.weight_only_matmul(x, weight, weight_scale) else: raise ValueError( "Unknown algo: '{}'. It can only be 'llm.int8' or 'weight_only'.".format( @@ -1915,11 +1915,19 @@ def linear_compress( if algo == "llm.int8": type = "llm_int8_matmul" - inputs = {'X': [x], 'Y': [weight], 'weight_scale': [weight_scale]} + inputs = { + 'x': [x], + 'weight': [weight], + 'weight_scale': [weight_scale], + } attrs = {'algo': algo, 'threshold': config['threshold']} elif algo == "weight_only": type = "weight_only_matmul" - inputs = {'X': [x], 'Y': [weight], 'weight_scale': [weight_scale]} + inputs = { + 'x': [x], + 'weight': [weight], + 'weight_scale': [weight_scale], + } attrs = {} else: raise ValueError( diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 1bab3855757..64caff4c169 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -301,10 +301,13 @@ class LinearCompress(Layer): weight_attr = paddle.framework.ParamAttr( initializer=paddle.nn.initializer.Assign(weight_tensor) ) + weight_shape = ( + [self.weight.shape[1], self.weight.shape[0]] + if self.bits == 8 + else [self.weight.shape[1] / 2, self.weight.shape[0]] + ) self.weight = self.create_parameter( - shape=self.weight.shape - if self.layout == 0 - else [self.weight.shape[1], self.weight.shape[0]], + shape=weight_shape, attr=weight_attr, dtype="int8", is_bias=False, diff --git a/test/legacy_test/test_linear_compress.py b/test/legacy_test/test_linear_compress.py index 99181a29728..438e42a9891 100644 --- a/test/legacy_test/test_linear_compress.py +++ b/test/legacy_test/test_linear_compress.py @@ -36,6 +36,7 @@ class LinearTestCase(unittest.TestCase): self.in_features = 64 self.out_features = 64 self.algo = "weight_only" + self.bits = 8 def setUp(self): self.config() @@ -62,6 +63,7 @@ class LinearTestCase(unittest.TestCase): self.in_features, self.out_features, bias_attr=bias_attr, + bits=8, algo=self.algo, config=self.config, ) @@ -112,5 +114,15 @@ class LinearTestCase3(LinearTestCase): self.atol = 1e-1 +class LinearTestCase4(LinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.bias = True + self.in_features = 128 + self.out_features = 64 + self.bits = 4 + + if __name__ == '__main__': unittest.main() -- GitLab
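
Reference note on the new int4 path: the 4-bit branch of per_channel_quant (quant_for_compress_kernel_impl.h above) packs two signed 4-bit values into each output byte, the even column in the low nibble and the odd column in the high nibble, before the permute / subbyte-transpose / interleave passes run. The NumPy sketch below mirrors only that packing step; the max(|w|) / 7 per-column scale is an assumption, since per_channel_scale's body is outside this diff.

import numpy as np

def per_channel_quant_int4(weight: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Pack a [rows, cols] float matrix into [rows, cols // 2] int8 bytes,
    # two signed 4-bit values per byte, mirroring the 4-bit branch of
    # per_channel_quant (packing only; no permute/transpose/interleave).
    cols = weight.shape[1]
    assert cols % 2 == 0, "odd column counts need the tail guard used in the C++ code"
    # Round to the nearest integer and clip to the signed 4-bit range [-7, 7].
    q = np.clip(np.round(weight / scale), -7, 7).astype(np.int32)
    low = q[:, 0::2] & 0x0F           # even columns -> low nibble
    high = (q[:, 1::2] & 0x0F) << 4   # odd columns  -> high nibble
    # Reinterpret the packed bytes as int8 to match the kernel's output dtype.
    return (low | high).astype(np.uint8).view(np.int8)

# Toy example with an assumed per-column scale of max(|w|) / 7, which matches
# the [-7, 7] clipping range used above.
w = np.random.randn(4, 8).astype(np.float32)
scale = np.abs(w).max(axis=0) / 7.0
packed = per_channel_quant_int4(w, scale)   # shape (4, 4)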
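
A minimal usage sketch of the int4 weight-only path, modeled on LinearTestCase4 in test_linear_compress.py above. The config dict contents are an assumption (the test's self.config is not visible in this hunk, and only the 'llm.int8' path reads 'threshold'), and running it requires a CUDA build with cutlass, since weight_only_matmul only logs an error otherwise.

import paddle

paddle.set_default_dtype('float16')
x = paddle.randn([2, 4, 128], dtype='float16')

layer = paddle.nn.LinearCompress(
    128,                         # in_features
    64,                          # out_features
    bits=4,                      # int4 weight-only path added by this PR
    algo='weight_only',          # 'llm.int8' is the other supported algo
    config={'threshold': 6.0},   # assumed; only consulted by 'llm.int8'
)
out = layer(x)                   # expected shape [2, 4, 64]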