未验证 提交 f4290a92 编写于 作者: FormlessUnit's avatar FormlessUnit 提交者: GitHub

Linear compress (#55128)

* rename weight_only/llm.int8
上级 ab46b14c
......@@ -1367,14 +1367,14 @@
data_transform :
skip_transform : out_size, size_tensor, scale_tensor
- op : llm_int8_mat_mul
- op : llm_int8_matmul
args : (Tensor x, Tensor weight, Tensor weight_scale, float threshold=6.0)
output : Tensor(out)
infer_meta :
func : LLMInt8MatMulInferMeta
func : LLMInt8MatmulInferMeta
param : [x, weight]
kernel :
func : llm_int8_mat_mul
func : llm_int8_matmul
data_type : x
- op : log
......@@ -2602,14 +2602,13 @@
intermediate: warprnntgrad
backward : warprnnt_grad
- op : weight_only_mat_mul
- op : weight_only_matmul
args : (Tensor x, Tensor weight, Tensor weight_scale)
output : Tensor(out)
infer_meta :
func : WeightOnlyMatMulInferMeta
param : [x, weight]
func : WeightOnlyMatmulInferMeta
kernel :
func : weight_only_mat_mul
func : weight_only_matmul
data_type : x
- op : weighted_sample_neighbors
......
......@@ -3572,7 +3572,7 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
out_count->set_dtype(DataType::INT32);
}
void LLMInt8MatMulInferMeta(const MetaTensor& x,
void LLMInt8MatmulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
MetaTensor* out) {
auto x_dims = x.dims();
......@@ -3595,25 +3595,31 @@ void LLMInt8MatMulInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void WeightOnlyMatMulInferMeta(const MetaTensor& x,
void WeightOnlyMatmulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& weight_scale,
MetaTensor* out) {
auto x_dims = x.dims();
auto w_dims = weight.dims();
auto n = weight_scale.dims()[0];
PADDLE_ENFORCE_EQ(
w_dims.size(),
2UL,
errors::InvalidArgument("The input(weight) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
weight_scale.dims().size(),
1UL,
errors::InvalidArgument("The input(weight_scale) must be a 1D Tensor."));
PADDLE_ENFORCE_EQ(
x_dims[x_dims.size() - 1],
w_dims[0],
w_dims[1],
errors::InvalidArgument(
"Input(X) dim[-1] and Input(Weight) dim[0] should be euqal."
"But received Input(X) dim[-1](%s) != Input(Weight) dim[0](%s)",
"Input(X) dim[-1] and Input(Weight) dim[1] should be euqal."
"But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)",
x_dims[x_dims.size() - 1],
w_dims[0]));
w_dims[1]));
auto out_dims = x_dims;
out_dims[out_dims.size() - 1] = w_dims[1];
out_dims[out_dims.size() - 1] = n;
out->set_dims(out_dims);
out->set_dtype(x.dtype());
}
......
......@@ -690,12 +690,13 @@ void FusedMultiHeadAttentionVariableInferMeta(const MetaTensor& query,
bool causal,
MetaTensor* out);
void LLMInt8MatMulInferMeta(const MetaTensor& x,
void LLMInt8MatmulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
MetaTensor* out);
void WeightOnlyMatMulInferMeta(const MetaTensor& x,
void WeightOnlyMatmulInferMeta(const MetaTensor& x,
const MetaTensor& weight,
const MetaTensor& weight_scale,
MetaTensor* out);
void FusedRopeInferMeta(const MetaTensor& q,
......
......@@ -5086,22 +5086,17 @@ void QuantForCompressInferMeta(const MetaTensor& x,
x_dims[0]));
std::vector<int64_t> dim_scale({x_dims[1]});
std::vector<int64_t> dim_out;
if (layout == "weight_only") {
dim_out = std::vector<int64_t>({x_dims[0], x_dims[1]});
} else if (layout == "llm.int8") {
if (bits == 8) {
dim_out = std::vector<int64_t>({x_dims[1], x_dims[0]});
} else if (bits == 4) {
dim_out = std::vector<int64_t>({x_dims[1] / 2, x_dims[0]});
} else {
phi::errors::InvalidArgument(
"The layout must be weight_only or llm.int8, but got %s", layout);
phi::errors::InvalidArgument("The bit must be 8 or 4, but got %d", bits);
}
out->set_dims(phi::make_ddim(dim_out));
// TODO(lizhenyun) support weight_only int4
if (bits == 8) {
out->set_dtype(DataType::INT8);
} else {
phi::errors::Fatal("The bits only support 8, but got[%d]", bits);
}
scale->set_dims(phi::make_ddim(dim_scale));
scale->set_dtype(DataType::FLOAT32);
}
......
......@@ -22,7 +22,7 @@
namespace phi {
template <typename DeviceContext, typename T, typename D>
template <typename DeviceContext, typename T, typename D, int bits>
void quant_compute(const DeviceContext& dev_ctx,
const DenseTensor& x,
DenseTensor* out,
......@@ -59,15 +59,15 @@ void quant_compute(const DeviceContext& dev_ctx,
per_channel_scale(scale_data, x_data, m, n);
per_channel_quant(x_int_data, x_data, scale_data, m, n);
per_channel_quant<T, bits>(x_int_data, x_data, scale_data, m, n);
if (layout == "weight_only") {
permute_B_rows_for_mixed_gemm(
permute_B_rows_for_mixed_gemm<bits>(
int_processed_data, x_int_data, std::vector<size_t>{m, n}, (int64_t)80);
row_major_to_column_major(
subbyte_transpose_impl<bits>(
int_processed_2_data, int_processed_data, std::vector<size_t>{m, n});
interleave_column_major_tensor(
interleave_column_major_tensor<bits>(
out_data, int_processed_2_data, std::vector<size_t>{m, n});
add_bias_and_interleave_int8s_inplace(out_data, num);
add_bias_and_interleave_inplace<bits>(out_data, num);
} else if (layout == "llm.int8") {
std::vector<int> axis = {1, 0};
funcs::Transpose<DeviceContext, int8_t, 2> trans;
......@@ -88,9 +88,16 @@ void QuantForCompressKernel(const Context& dev_ctx,
if (bits == 8) {
dev_ctx.template Alloc<int8_t>(out);
dev_ctx.template Alloc<float>(scale);
quant_compute<Context, T, int8_t>(dev_ctx, x, out, scale, layout);
quant_compute<Context, T, int8_t, 8>(dev_ctx, x, out, scale, layout);
} else if (bits == 4 && layout == "weight_only") {
dev_ctx.template Alloc<int8_t>(out);
dev_ctx.template Alloc<float>(scale);
quant_compute<Context, T, int8_t, 4>(dev_ctx, x, out, scale, layout);
} else {
phi::errors::Unimplemented("The bits only support 8, but got[%d]", bits);
phi::errors::Unimplemented(
"The bits only support 8 or weight_only 4, but got[%s] [%d]",
layout,
bits);
}
// VLOG(0) << "x: " << x.dtype() << x;
// VLOG(0) << "out: " << out->dtype() << *out;
......@@ -102,5 +109,4 @@ PD_REGISTER_KERNEL(quant_for_compress,
CPU,
ALL_LAYOUT,
phi::QuantForCompressKernel,
float,
phi::dtype::float16) {}
......@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/llm_int8_mat_mul_kernel.h"
#include "paddle/phi/kernels/llm_int8_matmul_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#ifndef PADDLE_WITH_HIP
#include "paddle/phi/kernels/impl/llm_int8_mat_mul_kernel_impl.h"
#include "paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h"
#endif
namespace phi {
......@@ -56,7 +56,7 @@ void llm_int8_compute(const Context& dev_ctx,
}
template <typename T, typename Context>
void LLMInt8MatMulKernel(const Context& dev_ctx,
void LLMInt8MatmulKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const DenseTensor& weight_scale,
......@@ -68,8 +68,8 @@ void LLMInt8MatMulKernel(const Context& dev_ctx,
}
} // namespace phi
PD_REGISTER_KERNEL(llm_int8_mat_mul,
PD_REGISTER_KERNEL(llm_int8_matmul,
GPU,
ALL_LAYOUT,
phi::LLMInt8MatMulKernel,
phi::LLMInt8MatmulKernel,
phi::dtype::float16) {}
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/weight_only_mat_mul_kernel.h"
#include "paddle/phi/kernels/weight_only_matmul_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/datatype_traits.h"
#include "paddle/phi/core/kernel_registry.h"
......@@ -23,7 +23,7 @@
namespace phi {
template <typename T, typename Context>
void WeightOnlyMatMulKernel(const Context& dev_ctx,
void WeightOnlyMatmulKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const DenseTensor& weight_scale,
......@@ -32,17 +32,26 @@ void WeightOnlyMatMulKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(out);
const auto x_dims = x.dims();
const auto w_dims = weight.dims();
int n = weight_scale.dims()[0];
int quant_bit = 0;
if (n % w_dims[0] == 0) {
quant_bit = w_dims[0] * 8 / n;
} else {
errors::InvalidArgument(
"w_dims[0] must be divisible by weight_scale.dims()[0]");
}
int k = w_dims[0];
int n = w_dims[1];
int k = w_dims[1];
int m = x.numel() / k;
switch (quant_bit) {
case 8: {
auto mixed_gemm_runner =
CutlassFpAIntBGemmRunner<typename PDDataTypeTraits<T>::DataType,
uint8_t>();
int mixgemm_max_size = std::max(n, k);
DenseTensor mixgemm_workspace;
int64_t mixgemm_workspace_size_bytes =
mixed_gemm_runner.getWorkspaceSize(m, mixgemm_max_size, mixgemm_max_size);
int64_t mixgemm_workspace_size_bytes = mixed_gemm_runner.getWorkspaceSize(
m, mixgemm_max_size, mixgemm_max_size);
mixgemm_workspace.Resize({mixgemm_workspace_size_bytes});
dev_ctx.template Alloc<uint8_t>(&mixgemm_workspace);
......@@ -53,21 +62,56 @@ void WeightOnlyMatMulKernel(const Context& dev_ctx,
x.data<T>()),
reinterpret_cast<const uint8_t*>(weight.data<int8_t>()),
reinterpret_cast<const float*>(weight_scale.data<float>()),
reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(out->data<T>()),
reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(
out->data<T>()),
m,
n,
k,
mixgemm_workspace_data,
mixgemm_workspace_size_bytes,
dev_ctx.stream());
} break;
case 4: {
auto mixed_gemm_runner =
CutlassFpAIntBGemmRunner<typename PDDataTypeTraits<T>::DataType,
cutlass::uint4b_t>();
int mixgemm_max_size = std::max(n, k);
DenseTensor mixgemm_workspace;
int64_t mixgemm_workspace_size_bytes = mixed_gemm_runner.getWorkspaceSize(
m, mixgemm_max_size, mixgemm_max_size);
mixgemm_workspace.Resize({mixgemm_workspace_size_bytes});
dev_ctx.template Alloc<uint8_t>(&mixgemm_workspace);
char* mixgemm_workspace_data =
reinterpret_cast<char*>(mixgemm_workspace.data<uint8_t>());
mixed_gemm_runner.gemm(
reinterpret_cast<const typename PDDataTypeTraits<T>::DataType*>(
x.data<T>()),
reinterpret_cast<const cutlass::uint4b_t*>(weight.data<int8_t>()),
reinterpret_cast<const float*>(weight_scale.data<float>()),
reinterpret_cast<typename PDDataTypeTraits<T>::DataType*>(
out->data<T>()),
m,
n,
k,
mixgemm_workspace_data,
mixgemm_workspace_size_bytes,
dev_ctx.stream());
} break;
default:
PADDLE_THROW(errors::Unimplemented(
"Quant_bits (%d) is not supported when gemm ", quant_bit));
break;
}
#else
LOG(ERROR) << "Please compile with cutlass to EnableUseCutlass()";
#endif
}
} // namespace phi
PD_REGISTER_KERNEL(weight_only_mat_mul,
PD_REGISTER_KERNEL(weight_only_matmul,
GPU,
ALL_LAYOUT,
phi::WeightOnlyMatMulKernel,
phi::WeightOnlyMatmulKernel,
phi::dtype::float16) {}
......@@ -39,51 +39,109 @@ void per_channel_scale(float* scale, const T* input, size_t m, size_t n) {
}
}
template <typename T, typename D>
void per_channel_quant(
D* output, const T* input, const float* scale, size_t m, size_t n) {
for (size_t i = 0; i < m; i++) {
for (size_t j = 0; j < n; j++) {
output[i * n + j] = static_cast<D>(
round(static_cast<float>(input[i * n + j]) / scale[j]));
template <typename T, int quant_bit = 8>
void per_channel_quant(int8_t* output,
const T* input,
const float* scale,
size_t num_rows,
size_t num_cols) {
size_t bytes_per_out_col = num_cols * quant_bit / 8;
for (size_t ii = 0; ii < num_rows; ++ii) {
int8_t* current_quantized_weight_row = output + ii * bytes_per_out_col;
const T* current_weight_row = input + ii * num_cols;
for (size_t jj = 0; jj < bytes_per_out_col; ++jj) {
if (quant_bit == 8) {
const float col_scale = scale[jj];
const float weight_elt = static_cast<float>(current_weight_row[jj]);
const float scaled_weight = round(weight_elt / col_scale);
const int8_t clipped_weight = static_cast<int8_t>(
std::max(-127.f, std::min(127.f, scaled_weight)));
current_quantized_weight_row[jj] = clipped_weight;
} else if (quant_bit == 4) {
// We will pack two int4 elements per iteration of the inner loop.
int8_t packed_int4s = 0;
for (int packed_idx = 0; packed_idx < 2; ++packed_idx) {
const size_t input_idx = 2 * jj + packed_idx;
if (input_idx < num_cols) {
const float col_scale = scale[input_idx];
const float weight_elt =
static_cast<float>(current_weight_row[input_idx]);
const float scaled_weight = round(weight_elt / col_scale);
int int_weight = static_cast<int>(scaled_weight);
const int8_t clipped_weight = std::max(-7, std::min(7, int_weight));
// Kill the sign extension bits (hence 0x0F mask) then shift to
// upper bits if packing the second int4 and or the bits into the
// final result.
packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx));
}
}
current_quantized_weight_row[jj] = packed_int4s;
} else {
phi::errors::Unimplemented("Unsupported quantization bits: %d",
quant_bit);
}
}
}
void row_major_to_column_major(int8_t* col_major_tensor,
const int8_t* row_major_tensor,
const std::vector<size_t>& shape) {
size_t m = shape[0];
size_t n = shape[1];
for (size_t i = 0; i < m * n; i++) {
size_t im = i / n;
size_t in = i % n;
col_major_tensor[in * m + im] = row_major_tensor[im * n + in];
}
}
void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor_ptr,
size_t num_elts) {
int8_t* int8_tensor = reinterpret_cast<int8_t*>(int8_tensor_ptr);
for (size_t ii = 0; ii < num_elts; ++ii) {
int8_tensor[ii] =
static_cast<int8_t>(static_cast<int>(int8_tensor[ii]) + 128);
}
// Step 2 will transform the layout of a 32-bit register in CUDA in order to
// match the int4 layout. This has no performance benefit and is purely so
// that int4 and int8 have the same layout. Pictorially, this does the
// following: bit 32 0
// [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits)
//
// And it will rearrange the output 32 bit register to be the following:
// bit 32 0
// [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits)
template <int quant_bit = 8>
void add_bias_and_interleave_inplace(int8_t* tensor_ptr, size_t num_elts) {
const size_t num_bytes = num_elts * quant_bit / 8;
for (size_t ii = 0; ii < num_bytes; ++ii) {
if (quant_bit == 8) {
tensor_ptr[ii] =
static_cast<int8_t>(static_cast<int>(tensor_ptr[ii]) + 128);
} else {
int8_t transformed_packed_int4s = 0;
int8_t transformed_first_elt =
(int8_t(tensor_ptr[ii] << 4) >> 4) +
8; // The double shift here is to ensure sign extension
int8_t transformed_second_elt = (tensor_ptr[ii] >> 4) + 8;
if (!(transformed_first_elt >= 0 && transformed_first_elt <= 15)) {
phi::errors::InvalidArgument(
"Illegal result for int4 transform (first elt)");
}
if (!(transformed_second_elt >= 0 && transformed_second_elt <= 15)) {
phi::errors::InvalidArgument(
"Illegal result for int4 transform (second elt)");
}
// We don't need to mask in these ops since everything should be in the
// range 0-15
transformed_packed_int4s |= transformed_first_elt;
transformed_packed_int4s |= (transformed_second_elt << 4);
tensor_ptr[ii] = transformed_packed_int4s;
}
}
if (quant_bit == 8) {
for (size_t base = 0; base < num_elts; base += 4) {
std::swap(int8_tensor[base + 1], int8_tensor[base + 2]);
std::swap(tensor_ptr[base + 1], tensor_ptr[base + 2]);
}
} else {
const size_t num_registers = num_bytes / 4;
uint32_t* register_ptr = reinterpret_cast<uint32_t*>(tensor_ptr);
for (size_t ii = 0; ii < num_registers; ++ii) {
const uint32_t current_register = register_ptr[ii];
uint32_t transformed_register = 0;
for (int dest_idx = 0; dest_idx < 8; ++dest_idx) {
const int src_idx =
dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
const int src_shift = 4 * src_idx;
const int dest_shift = 4 * dest_idx;
const uint32_t src_bits = (current_register >> src_shift) & 0xF;
transformed_register |= (src_bits << dest_shift);
}
register_ptr[ii] = transformed_register;
}
}
}
template <int quant_bit>
void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
const int8_t* quantized_tensor,
const std::vector<size_t>& shape,
......@@ -92,9 +150,8 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
const int BITS_PER_ELT = 8;
const int BITS_PER_ELT = quant_bit;
const int K = 16 / BITS_PER_ELT;
// const int ELTS_PER_BYTE = 8 / BITS_PER_ELT;
const int ELTS_PER_REG = 32 / BITS_PER_ELT;
const uint32_t* input_byte_ptr =
......@@ -102,7 +159,6 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
uint32_t* output_byte_ptr =
reinterpret_cast<uint32_t*>(permuted_quantized_tensor);
// int MMA_SHAPE_N = 8;
int B_ROWS_PER_MMA = 8 * K;
const int elts_in_int32 = 32 / BITS_PER_ELT;
......@@ -118,15 +174,134 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor,
const int read_row = base_row + tile_read_row;
const int read_col = write_col;
const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col;
const int64_t read_offset =
static_cast<int64_t>(read_row) * num_vec_cols + read_col;
const int64_t write_offset =
int64_t(write_row) * num_vec_cols + write_col;
static_cast<int64_t>(write_row) * num_vec_cols + write_col;
output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
}
}
}
}
template <int quant_bit>
void subbyte_transpose_impl(int8_t* transposed_quantized_tensor,
const int8_t* quantized_tensor,
const std::vector<size_t>& shape) {
const int bits_per_elt = quant_bit;
// FT_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be
// 2-D or 3-D");
// const size_t num_experts = 1;
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
const size_t col_bytes = num_cols * bits_per_elt / 8;
const size_t col_bytes_trans = num_rows * bits_per_elt / 8;
// const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes;
const uint8_t* input_byte_ptr =
reinterpret_cast<const uint8_t*>(quantized_tensor);
uint8_t* output_byte_ptr =
reinterpret_cast<uint8_t*>(transposed_quantized_tensor);
static constexpr int ELTS_PER_BYTE = 8 / quant_bit;
static constexpr int M_TILE_L1 = 64;
static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE;
uint8_t cache_buf[M_TILE_L1][N_TILE_L1];
static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1);
// const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1;
// const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1;
for (size_t row_tile_start = 0; row_tile_start < num_rows;
row_tile_start += M_TILE_L1) {
for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes;
col_tile_start_byte += N_TILE_L1) {
const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
const int col_limit =
std::min(col_tile_start_byte + N_TILE_L1, col_bytes);
for (int ii = 0; ii < M_TILE_L1; ++ii) {
const int row = row_tile_start + ii;
for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
const int col = col_tile_start_byte + jj;
const size_t logical_src_offset = row * col_bytes + col;
if (row < row_limit && col < col_limit) {
for (int v = 0; v < VECTOR_WIDTH; ++v) {
cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v];
}
}
}
}
if (quant_bit == 8) {
for (int ii = 0; ii < M_TILE_L1; ++ii) {
for (int jj = ii + 1; jj < N_TILE_L1; ++jj) {
std::swap(cache_buf[ii][jj], cache_buf[jj][ii]);
}
}
} else if (quant_bit == 4) {
for (int ii = 0; ii < M_TILE_L1; ++ii) {
// Using M_TILE_L1 here is deliberate since we assume that the cache
// tile is square in the number of elements (not necessarily the
// number of bytes).
for (int jj = ii + 1; jj < M_TILE_L1; ++jj) {
const int ii_byte = ii / ELTS_PER_BYTE;
const int ii_bit_offset = ii % ELTS_PER_BYTE;
const int jj_byte = jj / ELTS_PER_BYTE;
const int jj_bit_offset = jj % ELTS_PER_BYTE;
uint8_t src_elt =
0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset));
uint8_t tgt_elt =
0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset));
cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset));
cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset));
cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset));
cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset));
}
}
} else {
phi::errors::Unimplemented("Unsupported quantization bits: %d",
quant_bit);
}
const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE;
const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE;
const int row_limit_trans =
std::min(row_tile_start_trans + M_TILE_L1, num_cols);
const int col_limit_trans =
std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);
for (int ii = 0; ii < M_TILE_L1; ++ii) {
const int row = row_tile_start_trans + ii;
for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) {
const int col = col_tile_start_byte_trans + jj;
const size_t logical_tgt_offset = row * col_bytes_trans + col;
if (row < row_limit_trans && col < col_limit_trans) {
for (int v = 0; v < VECTOR_WIDTH; ++v) {
output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v];
}
}
}
}
}
}
}
template <int quant_bit>
void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor,
const int8_t* quantized_tensor,
const std::vector<size_t>& shape) {
......@@ -134,7 +309,7 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor,
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
const size_t BITS_PER_ELT = 8;
const size_t BITS_PER_ELT = quant_bit;
const size_t elts_in_int32 = 32 / BITS_PER_ELT;
const size_t rows_per_tile = 64;
......@@ -169,6 +344,5 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor,
}
}
}
} // namespace phi
#endif // PADDLE_PHI_KERNELS_IMPL_QUANT_FOR_COMPRESS_KERNEL_IMPL_H_
......@@ -16,7 +16,7 @@ limitations under the License. */
namespace phi {
template <typename T, typename Context>
void LLMInt8MatMulKernel(const Context& dev_ctx,
void LLMInt8MatmulKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const DenseTensor& weight_scale,
......
......@@ -16,7 +16,7 @@ limitations under the License. */
namespace phi {
template <typename T, typename Context>
void WeightOnlyMatMulKernel(const Context& dev_ctx,
void WeightOnlyMatmulKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& weight,
const DenseTensor& weight_scale,
......
......@@ -1892,11 +1892,11 @@ def linear_compress(
):
if in_dynamic_mode():
if algo == "llm.int8":
y = _C_ops.llm_int8_mat_mul(
y = _C_ops.llm_int8_matmul(
x, weight, weight_scale, config['threshold']
)
elif algo == "weight_only":
y = _C_ops.weight_only_mat_mul(x, weight, weight_scale)
y = _C_ops.weight_only_matmul(x, weight, weight_scale)
else:
raise ValueError(
"Unknown algo: '{}'. It can only be 'llm.int8' or 'weight_only'.".format(
......@@ -1915,11 +1915,19 @@ def linear_compress(
if algo == "llm.int8":
type = "llm_int8_matmul"
inputs = {'X': [x], 'Y': [weight], 'weight_scale': [weight_scale]}
inputs = {
'x': [x],
'weight': [weight],
'weight_scale': [weight_scale],
}
attrs = {'algo': algo, 'threshold': config['threshold']}
elif algo == "weight_only":
type = "weight_only_matmul"
inputs = {'X': [x], 'Y': [weight], 'weight_scale': [weight_scale]}
inputs = {
'x': [x],
'weight': [weight],
'weight_scale': [weight_scale],
}
attrs = {}
else:
raise ValueError(
......
......@@ -301,10 +301,13 @@ class LinearCompress(Layer):
weight_attr = paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(weight_tensor)
)
weight_shape = (
[self.weight.shape[1], self.weight.shape[0]]
if self.bits == 8
else [self.weight.shape[1] / 2, self.weight.shape[0]]
)
self.weight = self.create_parameter(
shape=self.weight.shape
if self.layout == 0
else [self.weight.shape[1], self.weight.shape[0]],
shape=weight_shape,
attr=weight_attr,
dtype="int8",
is_bias=False,
......
......@@ -36,6 +36,7 @@ class LinearTestCase(unittest.TestCase):
self.in_features = 64
self.out_features = 64
self.algo = "weight_only"
self.bits = 8
def setUp(self):
self.config()
......@@ -62,6 +63,7 @@ class LinearTestCase(unittest.TestCase):
self.in_features,
self.out_features,
bias_attr=bias_attr,
bits=8,
algo=self.algo,
config=self.config,
)
......@@ -112,5 +114,15 @@ class LinearTestCase3(LinearTestCase):
self.atol = 1e-1
class LinearTestCase4(LinearTestCase):
def config(self):
super().config()
self.dtype = 'float16'
self.bias = True
self.in_features = 128
self.out_features = 64
self.bits = 4
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册