未验证 提交 f4290a92 编写于 作者: FormlessUnit's avatar FormlessUnit 提交者: GitHub

Linear compress (#55128)

* rename weight_only/llm.int8
上级 ab46b14c
...@@ -1367,14 +1367,14 @@ ...@@ -1367,14 +1367,14 @@
data_transform : data_transform :
skip_transform : out_size, size_tensor, scale_tensor skip_transform : out_size, size_tensor, scale_tensor
- op : llm_int8_mat_mul - op : llm_int8_matmul
args : (Tensor x, Tensor weight, Tensor weight_scale, float threshold=6.0) args : (Tensor x, Tensor weight, Tensor weight_scale, float threshold=6.0)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : LLMInt8MatMulInferMeta func : LLMInt8MatmulInferMeta
param : [x, weight] param : [x, weight]
kernel : kernel :
func : llm_int8_mat_mul func : llm_int8_matmul
data_type : x data_type : x
- op : log - op : log
...@@ -2602,14 +2602,13 @@ ...@@ -2602,14 +2602,13 @@
intermediate: warprnntgrad intermediate: warprnntgrad
backward : warprnnt_grad backward : warprnnt_grad
- op : weight_only_mat_mul - op : weight_only_matmul
args : (Tensor x, Tensor weight, Tensor weight_scale) args : (Tensor x, Tensor weight, Tensor weight_scale)
output : Tensor(out) output : Tensor(out)
infer_meta : infer_meta :
func : WeightOnlyMatMulInferMeta func : WeightOnlyMatmulInferMeta
param : [x, weight]
kernel : kernel :
func : weight_only_mat_mul func : weight_only_matmul
data_type : x data_type : x
- op : weighted_sample_neighbors - op : weighted_sample_neighbors
......
...@@ -3572,7 +3572,7 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row, ...@@ -3572,7 +3572,7 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
out_count->set_dtype(DataType::INT32); out_count->set_dtype(DataType::INT32);
} }
void LLMInt8MatMulInferMeta(const MetaTensor& x, void LLMInt8MatmulInferMeta(const MetaTensor& x,
const MetaTensor& weight, const MetaTensor& weight,
MetaTensor* out) { MetaTensor* out) {
auto x_dims = x.dims(); auto x_dims = x.dims();
...@@ -3595,25 +3595,31 @@ void LLMInt8MatMulInferMeta(const MetaTensor& x, ...@@ -3595,25 +3595,31 @@ void LLMInt8MatMulInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
} }
void WeightOnlyMatMulInferMeta(const MetaTensor& x, void WeightOnlyMatmulInferMeta(const MetaTensor& x,
const MetaTensor& weight, const MetaTensor& weight,
const MetaTensor& weight_scale,
MetaTensor* out) { MetaTensor* out) {
auto x_dims = x.dims(); auto x_dims = x.dims();
auto w_dims = weight.dims(); auto w_dims = weight.dims();
auto n = weight_scale.dims()[0];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
w_dims.size(), w_dims.size(),
2UL, 2UL,
errors::InvalidArgument("The input(weight) must be a 2D Tensor.")); errors::InvalidArgument("The input(weight) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
weight_scale.dims().size(),
1UL,
errors::InvalidArgument("The input(weight_scale) must be a 1D Tensor."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dims[x_dims.size() - 1], x_dims[x_dims.size() - 1],
w_dims[0], w_dims[1],
errors::InvalidArgument( errors::InvalidArgument(
"Input(X) dim[-1] and Input(Weight) dim[0] should be euqal." "Input(X) dim[-1] and Input(Weight) dim[1] should be euqal."
"But received Input(X) dim[-1](%s) != Input(Weight) dim[0](%s)", "But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)",
x_dims[x_dims.size() - 1], x_dims[x_dims.size() - 1],
w_dims[0])); w_dims[1]));
auto out_dims = x_dims; auto out_dims = x_dims;
out_dims[out_dims.size() - 1] = w_dims[1]; out_dims[out_dims.size() - 1] = n;
out->set_dims(out_dims); out->set_dims(out_dims);
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
} }
......
...@@ -690,12 +690,13 @@ void FusedMultiHeadAttentionVariableInferMeta(const MetaTensor& query, ...@@ -690,12 +690,13 @@ void FusedMultiHeadAttentionVariableInferMeta(const MetaTensor& query,
bool causal, bool causal,
MetaTensor* out); MetaTensor* out);
void LLMInt8MatMulInferMeta(const MetaTensor& x, void LLMInt8MatmulInferMeta(const MetaTensor& x,
const MetaTensor& weight, const MetaTensor& weight,
MetaTensor* out); MetaTensor* out);
void WeightOnlyMatMulInferMeta(const MetaTensor& x, void WeightOnlyMatmulInferMeta(const MetaTensor& x,
const MetaTensor& weight, const MetaTensor& weight,
const MetaTensor& weight_scale,
MetaTensor* out); MetaTensor* out);
void FusedRopeInferMeta(const MetaTensor& q, void FusedRopeInferMeta(const MetaTensor& q,
......
...@@ -5086,22 +5086,17 @@ void QuantForCompressInferMeta(const MetaTensor& x, ...@@ -5086,22 +5086,17 @@ void QuantForCompressInferMeta(const MetaTensor& x,
x_dims[0])); x_dims[0]));
std::vector<int64_t> dim_scale({x_dims[1]}); std::vector<int64_t> dim_scale({x_dims[1]});
std::vector<int64_t> dim_out; std::vector<int64_t> dim_out;
if (layout == "weight_only") { if (bits == 8) {
dim_out = std::vector<int64_t>({x_dims[0], x_dims[1]});
} else if (layout == "llm.int8") {
dim_out = std::vector<int64_t>({x_dims[1], x_dims[0]}); dim_out = std::vector<int64_t>({x_dims[1], x_dims[0]});
} else if (bits == 4) {
dim_out = std::vector<int64_t>({x_dims[1] / 2, x_dims[0]});
} else { } else {
phi::errors::InvalidArgument( phi::errors::InvalidArgument("The bit must be 8 or 4, but got %d", bits);
"The layout must be weight_only or llm.int8, but got %s", layout);
} }
out->set_dims(phi::make_ddim(dim_out)); out->set_dims(phi::make_ddim(dim_out));
// TODO(lizhenyun) support weight_only int4
if (bits == 8) {
out->set_dtype(DataType::INT8); out->set_dtype(DataType::INT8);
} else {
phi::errors::Fatal("The bits only support 8, but got[%d]", bits);
}
scale->set_dims(phi::make_ddim(dim_scale)); scale->set_dims(phi::make_ddim(dim_scale));
scale->set_dtype(DataType::FLOAT32); scale->set_dtype(DataType::FLOAT32);
} }
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
namespace phi { namespace phi {
template <typename DeviceContext, typename T, typename D> template <typename DeviceContext, typename T, typename D, int bits>
void quant_compute(const DeviceContext& dev_ctx, void quant_compute(const DeviceContext& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
DenseTensor* out, DenseTensor* out,
...@@ -59,15 +59,15 @@ void quant_compute(const DeviceContext& dev_ctx, ...@@ -59,15 +59,15 @@ void quant_compute(const DeviceContext& dev_ctx,
per_channel_scale(scale_data, x_data, m, n); per_channel_scale(scale_data, x_data, m, n);
per_channel_quant(x_int_data, x_data, scale_data, m, n); per_channel_quant<T, bits>(x_int_data, x_data, scale_data, m, n);
if (layout == "weight_only") { if (layout == "weight_only") {
permute_B_rows_for_mixed_gemm( permute_B_rows_for_mixed_gemm<bits>(
int_processed_data, x_int_data, std::vector<size_t>{m, n}, (int64_t)80); int_processed_data, x_int_data, std::vector<size_t>{m, n}, (int64_t)80);
row_major_to_column_major( subbyte_transpose_impl<bits>(
int_processed_2_data, int_processed_data, std::vector<size_t>{m, n}); int_processed_2_data, int_processed_data, std::vector<size_t>{m, n});
interleave_column_major_tensor( interleave_column_major_tensor<bits>(
out_data, int_processed_2_data, std::vector<size_t>{m, n}); out_data, int_processed_2_data, std::vector<size_t>{m, n});
add_bias_and_interleave_int8s_inplace(out_data, num); add_bias_and_interleave_inplace<bits>(out_data, num);
} else if (layout == "llm.int8") { } else if (layout == "llm.int8") {
std::vector<int> axis = {1, 0}; std::vector<int> axis = {1, 0};
funcs::Transpose<DeviceContext, int8_t, 2> trans; funcs::Transpose<DeviceContext, int8_t, 2> trans;
...@@ -88,9 +88,16 @@ void QuantForCompressKernel(const Context& dev_ctx, ...@@ -88,9 +88,16 @@ void QuantForCompressKernel(const Context& dev_ctx,
if (bits == 8) { if (bits == 8) {
dev_ctx.template Alloc<int8_t>(out); dev_ctx.template Alloc<int8_t>(out);
dev_ctx.template Alloc<float>(scale); dev_ctx.template Alloc<float>(scale);
quant_compute<Context, T, int8_t>(dev_ctx, x, out, scale, layout); quant_compute<Context, T, int8_t, 8>(dev_ctx, x, out, scale, layout);
} else if (bits == 4 && layout == "weight_only") {
dev_ctx.template Alloc<int8_t>(out);
dev_ctx.template Alloc<float>(scale);
quant_compute<Context, T, int8_t, 4>(dev_ctx, x, out, scale, layout);
} else { } else {
phi::errors::Unimplemented("The bits only support 8, but got[%d]", bits); phi::errors::Unimplemented(
"The bits only support 8 or weight_only 4, but got[%s] [%d]",
layout,
bits);
} }
// VLOG(0) << "x: " << x.dtype() << x; // VLOG(0) << "x: " << x.dtype() << x;
// VLOG(0) << "out: " << out->dtype() << *out; // VLOG(0) << "out: " << out->dtype() << *out;
...@@ -102,5 +109,4 @@ PD_REGISTER_KERNEL(quant_for_compress, ...@@ -102,5 +109,4 @@ PD_REGISTER_KERNEL(quant_for_compress,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::QuantForCompressKernel, phi::QuantForCompressKernel,
float,
phi::dtype::float16) {} phi::dtype::float16) {}
...@@ -12,12 +12,12 @@ ...@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/llm_int8_mat_mul_kernel.h" #include "paddle/phi/kernels/llm_int8_matmul_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#ifndef PADDLE_WITH_HIP #ifndef PADDLE_WITH_HIP
#include "paddle/phi/kernels/impl/llm_int8_mat_mul_kernel_impl.h" #include "paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h"
#endif #endif
namespace phi { namespace phi {
...@@ -56,7 +56,7 @@ void llm_int8_compute(const Context& dev_ctx, ...@@ -56,7 +56,7 @@ void llm_int8_compute(const Context& dev_ctx,
} }
template <typename T, typename Context> template <typename T, typename Context>
void LLMInt8MatMulKernel(const Context& dev_ctx, void LLMInt8MatmulKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& weight, const DenseTensor& weight,
const DenseTensor& weight_scale, const DenseTensor& weight_scale,
...@@ -68,8 +68,8 @@ void LLMInt8MatMulKernel(const Context& dev_ctx, ...@@ -68,8 +68,8 @@ void LLMInt8MatMulKernel(const Context& dev_ctx,
} }
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(llm_int8_mat_mul, PD_REGISTER_KERNEL(llm_int8_matmul,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::LLMInt8MatMulKernel, phi::LLMInt8MatmulKernel,
phi::dtype::float16) {} phi::dtype::float16) {}
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/phi/kernels/weight_only_mat_mul_kernel.h" #include "paddle/phi/kernels/weight_only_matmul_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/datatype_traits.h" #include "paddle/phi/common/datatype_traits.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void WeightOnlyMatMulKernel(const Context& dev_ctx, void WeightOnlyMatmulKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& weight, const DenseTensor& weight,
const DenseTensor& weight_scale, const DenseTensor& weight_scale,
...@@ -32,17 +32,26 @@ void WeightOnlyMatMulKernel(const Context& dev_ctx, ...@@ -32,17 +32,26 @@ void WeightOnlyMatMulKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
const auto x_dims = x.dims(); const auto x_dims = x.dims();
const auto w_dims = weight.dims(); const auto w_dims = weight.dims();
int n = weight_scale.dims()[0];
int quant_bit = 0;
if (n % w_dims[0] == 0) {
quant_bit = w_dims[0] * 8 / n;
} else {
errors::InvalidArgument(
"w_dims[0] must be divisible by weight_scale.dims()[0]");
}
int k = w_dims[0]; int k = w_dims[1];
int n = w_dims[1];
int m = x.numel() / k; int m = x.numel() / k;
switch (quant_bit) {
case 8: {
auto mixed_gemm_runner = auto mixed_gemm_runner =
CutlassFpAIntBGemmRunner<typename PDDataTypeTraits<T>::DataType, CutlassFpAIntBGemmRunner<typename PDDataTypeTraits<T>::DataType,
uint8_t>(); uint8_t>();
int mixgemm_max_size = std::max(n, k); int mixgemm_max_size = std::max(n, k);
DenseTensor mixgemm_workspace; DenseTensor mixgemm_workspace;
int64_t mixgemm_workspace_size_bytes = int64_t mixgemm_workspace_size_bytes = mixed_gemm_runner.getWorkspaceSize(
mixed_gemm_runner.getWorkspaceSize(m, mixgemm_max_size, mixgemm_max_size); m, mixgemm_max_size, mixgemm_max_size);
mixgemm_workspace.Resize({mixgemm_workspace_size_bytes}); mixgemm_workspace.Resize({mixgemm_workspace_size_bytes});
dev_ctx.template Alloc<uint8_t>(&mixgemm_workspace); dev_ctx.template Alloc<uint8_t>(&mixgemm_workspace);
...@@ -53,21 +62,56 @@ void WeightOnlyMatMulKernel(const Context& dev_ctx, ...@@ -53,21 +62,56 @@ void WeightOnlyMatMulKernel(const Context& dev_ctx,
x.data<T>()), x.data<T>()),
reinterpret_cast<const uint8_t*>(weight.data<int8_t>()), reinterpret_cast<const uint8_t*>(weight.data<int8_t>()),
reinterpret_cast<const float*>(weight_scale.data<float>()), reinterpret_cast<const float*>(weight_scale.data<float>()),
reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(out->data<T>()), reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(
out->data<T>()),
m, m,
n, n,
k, k,
mixgemm_workspace_data, mixgemm_workspace_data,
mixgemm_workspace_size_bytes, mixgemm_workspace_size_bytes,
dev_ctx.stream()); dev_ctx.stream());
} break;
case 4: {
auto mixed_gemm_runner =
CutlassFpAIntBGemmRunner<typename PDDataTypeTraits<T>::DataType,
cutlass::uint4b_t>();
int mixgemm_max_size = std::max(n, k);
DenseTensor mixgemm_workspace;
int64_t mixgemm_workspace_size_bytes = mixed_gemm_runner.getWorkspaceSize(
m, mixgemm_max_size, mixgemm_max_size);
mixgemm_workspace.Resize({mixgemm_workspace_size_bytes});
dev_ctx.template Alloc<uint8_t>(&mixgemm_workspace);
char* mixgemm_workspace_data =
reinterpret_cast<char*>(mixgemm_workspace.data<uint8_t>());
mixed_gemm_runner.gemm(
reinterpret_cast<const typename PDDataTypeTraits<T>::DataType*>(
x.data<T>()),
reinterpret_cast<const cutlass::uint4b_t*>(weight.data<int8_t>()),
reinterpret_cast<const float*>(weight_scale.data<float>()),
reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(
out->data<T>()),
m,
n,
k,
mixgemm_workspace_data,
mixgemm_workspace_size_bytes,
dev_ctx.stream());
} break;
default:
PADDLE_THROW(errors::Unimplemented(
"Quant_bits (%d) is not supported when gemm ", quant_bit));
break;
}
#else #else
LOG(ERROR) << "Please compile with cutlass to EnableUseCutlass()"; LOG(ERROR) << "Please compile with cutlass to EnableUseCutlass()";
#endif #endif
} }
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(weight_only_mat_mul, PD_REGISTER_KERNEL(weight_only_matmul,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::WeightOnlyMatMulKernel, phi::WeightOnlyMatmulKernel,
phi::dtype::float16) {} phi::dtype::float16) {}
...@@ -39,51 +39,109 @@ void per_channel_scale(float* scale, const T* input, size_t m, size_t n) { ...@@ -39,51 +39,109 @@ void per_channel_scale(float* scale, const T* input, size_t m, size_t n) {
} }
} }
template <typename T, typename D> template <typename T, int quant_bit = 8>
void per_channel_quant( void per_channel_quant(int8_t* output,
D* output, const T* input, const float* scale, size_t m, size_t n) { const T* input,
for (size_t i = 0; i < m; i++) { const float* scale,
for (size_t j = 0; j < n; j++) { size_t num_rows,
output[i * n + j] = static_cast<D>( size_t num_cols) {
round(static_cast<float>(input[i * n + j]) / scale[j])); size_t bytes_per_out_col = num_cols * quant_bit / 8;
for (size_t ii = 0; ii < num_rows; ++ii) {
int8_t* current_quantized_weight_row = output + ii * bytes_per_out_col;
const T* current_weight_row = input + ii * num_cols;
for (size_t jj = 0; jj < bytes_per_out_col; ++jj) {
if (quant_bit == 8) {
const float col_scale = scale[jj];
const float weight_elt = static_cast<float>(current_weight_row[jj]);
const float scaled_weight = round(weight_elt / col_scale);
const int8_t clipped_weight = static_cast<int8_t>(
std::max(-127.f, std::min(127.f, scaled_weight)));
current_quantized_weight_row[jj] = clipped_weight;
} else if (quant_bit == 4) {
// We will pack two int4 elements per iteration of the inner loop.
int8_t packed_int4s = 0;
for (int packed_idx = 0; packed_idx < 2; ++packed_idx) {
const size_t input_idx = 2 * jj + packed_idx;
if (input_idx < num_cols) {
const float col_scale = scale[input_idx];
const float weight_elt =
static_cast<float>(current_weight_row[input_idx]);
const float scaled_weight = round(weight_elt / col_scale);
int int_weight = static_cast<int>(scaled_weight);
const int8_t clipped_weight = std::max(-7, std::min(7, int_weight));
// Kill the sign extension bits (hence 0x0F mask) then shift to
// upper bits if packing the second int4 and or the bits into the
// final result.
packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx));
}
}
current_quantized_weight_row[jj] = packed_int4s;
} else {
phi::errors::Unimplemented("Unsupported quantization bits: %d",
quant_bit);
} }
} }
}
void row_major_to_column_major(int8_t* col_major_tensor,
const int8_t* row_major_tensor,
const std::vector<size_t>& shape) {
size_t m = shape[0];
size_t n = shape[1];
for (size_t i = 0; i < m * n; i++) {
size_t im = i / n;
size_t in = i % n;
col_major_tensor[in * m + im] = row_major_tensor[im * n + in];
} }
} }
void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor_ptr, template <int quant_bit = 8>
size_t num_elts) { void add_bias_and_interleave_inplace(int8_t* tensor_ptr, size_t num_elts) {
int8_t* int8_tensor = reinterpret_cast<int8_t*>(int8_tensor_ptr); const size_t num_bytes = num_elts * quant_bit / 8;
for (size_t ii = 0; ii < num_elts; ++ii) {
int8_tensor[ii] =
static_cast<int8_t>(static_cast<int>(int8_tensor[ii]) + 128);
}
// Step 2 will transform the layout of a 32-bit register in CUDA in order to
// match the int4 layout. This has no performance benefit and is purely so
// that int4 and int8 have the same layout. Pictorially, this does the
// following: bit 32 0
// [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits)
//
// And it will rearrange the output 32 bit register to be the following:
// bit 32 0
// [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
for (size_t ii = 0; ii < num_bytes; ++ii) {
if (quant_bit == 8) {
tensor_ptr[ii] =
static_cast<int8_t>(static_cast<int>(tensor_ptr[ii]) + 128);
} else {
int8_t transformed_packed_int4s = 0;
int8_t transformed_first_elt =
(int8_t(tensor_ptr[ii] << 4) >> 4) +
8; // The double shift here is to ensure sign extension
int8_t transformed_second_elt = (tensor_ptr[ii] >> 4) + 8;
if (!(transformed_first_elt >= 0 && transformed_first_elt <= 15)) {
phi::errors::InvalidArgument(
"Illegal result for int4 transform (first elt)");
}
if (!(transformed_second_elt >= 0 && transformed_second_elt <= 15)) {
phi::errors::InvalidArgument(
"Illegal result for int4 transform (second elt)");
}
// We don't need to mask in these ops since everything should be in the
// range 0-15
transformed_packed_int4s |= transformed_first_elt;
transformed_packed_int4s |= (transformed_second_elt << 4);
tensor_ptr[ii] = transformed_packed_int4s;
}
}
if (quant_bit == 8) {
for (size_t base = 0; base < num_elts; base += 4) { for (size_t base = 0; base < num_elts; base += 4) {
std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); std::swap(tensor_ptr[base + 1], tensor_ptr[base + 2]);
}
} else {
const size_t num_registers = num_bytes / 4;
uint32_t* register_ptr = reinterpret_cast<uint32_t*>(tensor_ptr);
for (size_t ii = 0; ii < num_registers; ++ii) {
const uint32_t current_register = register_ptr[ii];
uint32_t transformed_register = 0;
for (int dest_idx = 0; dest_idx < 8; ++dest_idx) {
const int src_idx =
dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
const int src_shift = 4 * src_idx;
const int dest_shift = 4 * dest_idx;
const uint32_t src_bits = (current_register >> src_shift) & 0xF;
transformed_register |= (src_bits << dest_shift);
}
register_ptr[ii] = transformed_register;
}
} }
} }
template <int quant_bit>
void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
const int8_t* quantized_tensor, const int8_t* quantized_tensor,
const std::vector<size_t>& shape, const std::vector<size_t>& shape,
...@@ -92,9 +150,8 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, ...@@ -92,9 +150,8 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
const int BITS_PER_ELT = 8; const int BITS_PER_ELT = quant_bit;
const int K = 16 / BITS_PER_ELT; const int K = 16 / BITS_PER_ELT;
// const int ELTS_PER_BYTE = 8 / BITS_PER_ELT;
const int ELTS_PER_REG = 32 / BITS_PER_ELT; const int ELTS_PER_REG = 32 / BITS_PER_ELT;
const uint32_t* input_byte_ptr = const uint32_t* input_byte_ptr =
...@@ -102,7 +159,6 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, ...@@ -102,7 +159,6 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
uint32_t* output_byte_ptr = uint32_t* output_byte_ptr =
reinterpret_cast<uint32_t*>(permuted_quantized_tensor); reinterpret_cast<uint32_t*>(permuted_quantized_tensor);
// int MMA_SHAPE_N = 8;
int B_ROWS_PER_MMA = 8 * K; int B_ROWS_PER_MMA = 8 * K;
const int elts_in_int32 = 32 / BITS_PER_ELT; const int elts_in_int32 = 32 / BITS_PER_ELT;
...@@ -118,15 +174,134 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, ...@@ -118,15 +174,134 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
const int read_row = base_row + tile_read_row; const int read_row = base_row + tile_read_row;
const int read_col = write_col; const int read_col = write_col;
const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col; const int64_t read_offset =
static_cast<int64_t>(read_row) * num_vec_cols + read_col;
const int64_t write_offset = const int64_t write_offset =
int64_t(write_row) * num_vec_cols + write_col; static_cast<int64_t>(write_row) * num_vec_cols + write_col;
output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
} }
} }
} }
} }
template <int quant_bit>
void subbyte_transpose_impl(int8_t* transposed_quantized_tensor,
const int8_t* quantized_tensor,
const std::vector<size_t>& shape) {
const int bits_per_elt = quant_bit;
// FT_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be
// 2-D or 3-D");
// const size_t num_experts = 1;
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
const size_t col_bytes = num_cols * bits_per_elt / 8;
const size_t col_bytes_trans = num_rows * bits_per_elt / 8;
// const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes;
const uint8_t* input_byte_ptr =
reinterpret_cast<const uint8_t*>(quantized_tensor);
uint8_t* output_byte_ptr =
reinterpret_cast<uint8_t*>(transposed_quantized_tensor);
static constexpr int ELTS_PER_BYTE = 8 / quant_bit;
static constexpr int M_TILE_L1 = 64;
static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE;
uint8_t cache_buf[M_TILE_L1][N_TILE_L1];
static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1);
// const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1;
// const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1;
for (size_t row_tile_start = 0; row_tile_start < num_rows;
row_tile_start += M_TILE_L1) {
for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes;
col_tile_start_byte += N_TILE_L1) {
const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
const int col_limit =
std::min(col_tile_start_byte + N_TILE_L1, col_bytes);
for (int ii = 0; ii < M_TILE_L1; ++ii) {
const int row = row_tile_start + ii;
for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
const int col = col_tile_start_byte + jj;
const size_t logical_src_offset = row * col_bytes + col;
if (row < row_limit && col < col_limit) {
for (int v = 0; v < VECTOR_WIDTH; ++v) {
cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v];
}
}
}
}
if (quant_bit == 8) {
for (int ii = 0; ii < M_TILE_L1; ++ii) {
for (int jj = ii + 1; jj < N_TILE_L1; ++jj) {
std::swap(cache_buf[ii][jj], cache_buf[jj][ii]);
}
}
} else if (quant_bit == 4) {
for (int ii = 0; ii < M_TILE_L1; ++ii) {
// Using M_TILE_L1 here is deliberate since we assume that the cache
// tile is square in the number of elements (not necessarily the
// number of bytes).
for (int jj = ii + 1; jj < M_TILE_L1; ++jj) {
const int ii_byte = ii / ELTS_PER_BYTE;
const int ii_bit_offset = ii % ELTS_PER_BYTE;
const int jj_byte = jj / ELTS_PER_BYTE;
const int jj_bit_offset = jj % ELTS_PER_BYTE;
uint8_t src_elt =
0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset));
uint8_t tgt_elt =
0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset));
cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset));
cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset));
cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset));
cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset));
}
}
} else {
phi::errors::Unimplemented("Unsupported quantization bits: %d",
quant_bit);
}
const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE;
const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE;
const int row_limit_trans =
std::min(row_tile_start_trans + M_TILE_L1, num_cols);
const int col_limit_trans =
std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);
for (int ii = 0; ii < M_TILE_L1; ++ii) {
const int row = row_tile_start_trans + ii;
for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
const int col = col_tile_start_byte_trans + jj;
const size_t logical_tgt_offset = row * col_bytes_trans + col;
if (row < row_limit_trans && col < col_limit_trans) {
for (int v = 0; v < VECTOR_WIDTH; ++v) {
output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v];
}
}
}
}
}
}
}
template <int quant_bit>
void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor,
const int8_t* quantized_tensor, const int8_t* quantized_tensor,
const std::vector<size_t>& shape) { const std::vector<size_t>& shape) {
...@@ -134,7 +309,7 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, ...@@ -134,7 +309,7 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor,
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
const size_t BITS_PER_ELT = 8; const size_t BITS_PER_ELT = quant_bit;
const size_t elts_in_int32 = 32 / BITS_PER_ELT; const size_t elts_in_int32 = 32 / BITS_PER_ELT;
const size_t rows_per_tile = 64; const size_t rows_per_tile = 64;
...@@ -169,6 +344,5 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, ...@@ -169,6 +344,5 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor,
} }
} }
} }
} // namespace phi } // namespace phi
#endif // PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_ #endif // PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void LLMInt8MatMulKernel(const Context& dev_ctx, void LLMInt8MatmulKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& weight, const DenseTensor& weight,
const DenseTensor& weight_scale, const DenseTensor& weight_scale,
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
void WeightOnlyMatMulKernel(const Context& dev_ctx, void WeightOnlyMatmulKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& weight, const DenseTensor& weight,
const DenseTensor& weight_scale, const DenseTensor& weight_scale,
......
...@@ -1892,11 +1892,11 @@ def linear_compress( ...@@ -1892,11 +1892,11 @@ def linear_compress(
): ):
if in_dynamic_mode(): if in_dynamic_mode():
if algo == "llm.int8": if algo == "llm.int8":
y = _C_ops.llm_int8_mat_mul( y = _C_ops.llm_int8_matmul(
x, weight, weight_scale, config['threshold'] x, weight, weight_scale, config['threshold']
) )
elif algo == "weight_only": elif algo == "weight_only":
y = _C_ops.weight_only_mat_mul(x, weight, weight_scale) y = _C_ops.weight_only_matmul(x, weight, weight_scale)
else: else:
raise ValueError( raise ValueError(
"Unknown algo: '{}'. It can only be 'llm.int8' or 'weight_only'.".format( "Unknown algo: '{}'. It can only be 'llm.int8' or 'weight_only'.".format(
...@@ -1915,11 +1915,19 @@ def linear_compress( ...@@ -1915,11 +1915,19 @@ def linear_compress(
if algo == "llm.int8": if algo == "llm.int8":
type = "llm_int8_matmul" type = "llm_int8_matmul"
inputs = {'X': [x], 'Y': [weight], 'weight_scale': [weight_scale]} inputs = {
'x': [x],
'weight': [weight],
'weight_scale': [weight_scale],
}
attrs = {'algo': algo, 'threshold': config['threshold']} attrs = {'algo': algo, 'threshold': config['threshold']}
elif algo == "weight_only": elif algo == "weight_only":
type = "weight_only_matmul" type = "weight_only_matmul"
inputs = {'X': [x], 'Y': [weight], 'weight_scale': [weight_scale]} inputs = {
'x': [x],
'weight': [weight],
'weight_scale': [weight_scale],
}
attrs = {} attrs = {}
else: else:
raise ValueError( raise ValueError(
......
...@@ -301,10 +301,13 @@ class LinearCompress(Layer): ...@@ -301,10 +301,13 @@ class LinearCompress(Layer):
weight_attr = paddle.framework.ParamAttr( weight_attr = paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(weight_tensor) initializer=paddle.nn.initializer.Assign(weight_tensor)
) )
weight_shape = (
[self.weight.shape[1], self.weight.shape[0]]
if self.bits == 8
else [self.weight.shape[1] / 2, self.weight.shape[0]]
)
self.weight = self.create_parameter( self.weight = self.create_parameter(
shape=self.weight.shape shape=weight_shape,
if self.layout == 0
else [self.weight.shape[1], self.weight.shape[0]],
attr=weight_attr, attr=weight_attr,
dtype="int8", dtype="int8",
is_bias=False, is_bias=False,
......
...@@ -36,6 +36,7 @@ class LinearTestCase(unittest.TestCase): ...@@ -36,6 +36,7 @@ class LinearTestCase(unittest.TestCase):
self.in_features = 64 self.in_features = 64
self.out_features = 64 self.out_features = 64
self.algo = "weight_only" self.algo = "weight_only"
self.bits = 8
def setUp(self): def setUp(self):
self.config() self.config()
...@@ -62,6 +63,7 @@ class LinearTestCase(unittest.TestCase): ...@@ -62,6 +63,7 @@ class LinearTestCase(unittest.TestCase):
self.in_features, self.in_features,
self.out_features, self.out_features,
bias_attr=bias_attr, bias_attr=bias_attr,
bits=8,
algo=self.algo, algo=self.algo,
config=self.config, config=self.config,
) )
...@@ -112,5 +114,15 @@ class LinearTestCase3(LinearTestCase): ...@@ -112,5 +114,15 @@ class LinearTestCase3(LinearTestCase):
self.atol = 1e-1 self.atol = 1e-1
class LinearTestCase4(LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = True
self.in_features = 128
self.out_features = 64
self.bits = 4
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册