// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/quant_for_compress_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/impl/quant_for_compress_kernel_impl.h"

namespace phi {

// Quantizes a 2-D weight tensor of type T into D (int8) values plus a
// per-channel float scale, then rearranges the result for the requested
// layout ("weight_only" mixed-GEMM layout or "llm.int8" transposed layout).
template <typename DeviceContext, typename T, typename D>
void quant_compute(const DeviceContext& dev_ctx,
                   const DenseTensor& x,
                   DenseTensor* out,
                   DenseTensor* scale,
                   const std::string& layout) {
  const auto x_dims = x.dims();
  PADDLE_ENFORCE_EQ(
      x_dims.size(),
      2,
      phi::errors::InvalidArgument(
          "The x tensor of quant op must be 2D, but got %d dimensions.",
          x_dims.size()));
  size_t m = x_dims[0];
  size_t n = x_dims[1];
  int64_t num = x.numel();
  DDim dims = {num};
  const T* x_data = x.data<T>();
  D* out_data = out->data<D>();
  float* scale_data = scale->data<float>();

  // Buffer holding the raw per-channel quantized values before any
  // layout transformation.
  DenseTensor x_int(out->type());
  x_int.Resize({static_cast<int64_t>(m), static_cast<int64_t>(n)});
  dev_ctx.template Alloc<D>(&x_int);
  D* x_int_data = x_int.data<D>();

  // Scratch buffers for the layout transformations below.
  DenseTensor int_processed(out->type());
  int_processed.Resize(dims);
  dev_ctx.template Alloc<D>(&int_processed);
  D* int_processed_data = int_processed.data<D>();

  DenseTensor int_processed_2(out->type());
  int_processed_2.Resize(out->dims());
  dev_ctx.template Alloc<D>(&int_processed_2);
  D* int_processed_2_data = int_processed_2.data<D>();

  // Compute per-channel (column-wise) scales, then quantize x with them.
  per_channel_scale(scale_data, x_data, m, n);
  per_channel_quant(x_int_data, x_data, scale_data, m, n);

  if (layout == "weight_only") {
    // Permute, transpose and interleave the quantized weights into the
    // layout expected by the mixed-precision GEMM kernels (arch 80).
    permute_B_rows_for_mixed_gemm(
        int_processed_data, x_int_data, std::vector<size_t>{m, n}, (int64_t)80);
    row_major_to_column_major(
        int_processed_2_data, int_processed_data, std::vector<size_t>{m, n});
    interleave_column_major_tensor(
        out_data, int_processed_2_data, std::vector<size_t>{m, n});
    add_bias_and_interleave_int8s_inplace(out_data, num);
  } else if (layout == "llm.int8") {
    // llm.int8 only needs the quantized weights transposed.
    std::vector<int> axis = {1, 0};
    funcs::Transpose<DeviceContext, D, 2> trans;
    trans(dev_ctx, x_int, out, axis);
  } else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "The layout must be weight_only or llm.int8, but got %s.", layout));
  }
}

template <typename T, typename Context>
void QuantForCompressKernel(const Context& dev_ctx,
                            const DenseTensor& x,
                            int bits,
                            const std::string& layout,
                            DenseTensor* out,
                            DenseTensor* scale) {
  if (bits == 8) {
    dev_ctx.template Alloc<int8_t>(out);
    dev_ctx.template Alloc<float>(scale);
    quant_compute<Context, T, int8_t>(dev_ctx, x, out, scale, layout);
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "Only 8-bit quantization is supported, but got %d bits.", bits));
  }
  // VLOG(0) << "x: " << x.dtype() << x;
  // VLOG(0) << "out: " << out->dtype() << *out;
}

}  // namespace phi

PD_REGISTER_KERNEL(quant_for_compress,
                   CPU,
                   ALL_LAYOUT,
                   phi::QuantForCompressKernel,
                   float,
                   phi::dtype::float16) {}