diff --git a/paddle/fluid/operators/fused/attn_gemm_int8.h b/paddle/fluid/operators/fused/attn_gemm_int8.h index 705cb8ece418e886c88a35334b8271b924a228fc..8dc4810b1f3b92d486aaebb39663d9abc0c7208a 100644 --- a/paddle/fluid/operators/fused/attn_gemm_int8.h +++ b/paddle/fluid/operators/fused/attn_gemm_int8.h @@ -57,29 +57,29 @@ class AttnMatmulINT8 { const int quant_round_type = 1, const float quant_max_bound = 127.0, const float quant_min_bound = -127.0) { - quantize_kernel_launcher(input->data(), - input_tmp->data(), - quant_in_scale, - m_, - k_, - quant_round_type, - quant_max_bound, - quant_min_bound, - dev_ctx_.stream()); + LaunchQuantKernel(input->data(), + input_tmp->data(), + quant_in_scale, + m_, + k_, + quant_round_type, + quant_max_bound, + quant_min_bound, + dev_ctx_.stream()); helpers_[0]->GEMM(input_tmp->data(), weight->data(), output_tmp->data(), dev_ctx_.stream()); - dequantize_kernel_launcher(output_tmp->data(), - output->data(), - m_, - n_, - dev_ctx_.stream(), - gpu_config_.get(), - quant_in_scale, - dequant_out_scale->data()); + LaunchDequantKernel(output_tmp->data(), + output->data(), + m_, + n_, + dev_ctx_.stream(), + gpu_config_.get(), + quant_in_scale, + dequant_out_scale->data()); if (compute_bias_) { // bias_out = output + bias @@ -126,14 +126,14 @@ class AttnMatmulINT8 { output_tmp->data(), dev_ctx_.stream()); - dequantize_kernel_launcher(output_tmp->data(), - output->data(), - m_, - n_, - dev_ctx_.stream(), - gpu_config_.get(), - quant_in_scale, - dequant_out_scale->data()); + LaunchDequantKernel(output_tmp->data(), + output->data(), + m_, + n_, + dev_ctx_.stream(), + gpu_config_.get(), + quant_in_scale, + dequant_out_scale->data()); if (compute_bias_) { // bias_out = output + bias @@ -162,15 +162,15 @@ class AttnMatmulINT8 { const int quant_round_type = 1, const float quant_max_bound = 127.0, const float quant_min_bound = -127.0) { - quantize_kernel_launcher(input->data(), - input_tmp->data(), - quant_in_scale, - m_, - k_, - quant_round_type, - quant_max_bound, - quant_min_bound, - dev_ctx_.stream()); + LaunchQuantKernel(input->data(), + input_tmp->data(), + quant_in_scale, + m_, + k_, + quant_round_type, + quant_max_bound, + quant_min_bound, + dev_ctx_.stream()); helpers_[0]->GEMM(input_tmp->data(), weight->data(), diff --git a/paddle/fluid/operators/fused/quant_dequant_kernel.h b/paddle/fluid/operators/fused/quant_dequant_kernel.h index 164effe01d316a6c77b1cb0d78fca0f0c5dc8d3f..8e8fdc95e91b51fa295aa276ff100e5557cddded 100644 --- a/paddle/fluid/operators/fused/quant_dequant_kernel.h +++ b/paddle/fluid/operators/fused/quant_dequant_kernel.h @@ -47,14 +47,14 @@ __forceinline__ __device__ int8_t quant_helper(const T input, } template -__global__ void quantize_kernel(const T* input, - char4* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound) { +__global__ void QuantKernel(const T* input, + char4* output, + const float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound) { int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; int m_id = blockIdx.y * blockDim.y + threadIdx.y; @@ -74,36 +74,36 @@ __global__ void quantize_kernel(const T* input, } template -void quantize_kernel_launcher(const T* input, - int8_t* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound, - gpuStream_t stream) { +void LaunchQuantKernel(const T* input, + int8_t* output, + const 
float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound, + gpuStream_t stream) { // TODO(minghaoBD): optimize the kennel launch times when m==1 or n==1 dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32); dim3 block(32, 32); - quantize_kernel<<>>(input, - (char4*)output, // NOLINT - scale, - m, - n, - round_type, - max_bound, - min_bound); + QuantKernel<<>>(input, + (char4*)output, // NOLINT + scale, + m, + n, + round_type, + max_bound, + min_bound); } template -__global__ void dequantize_kernel(T* output, - const int32_t* input, - const int m, // batch size - const int n, // hidden - const float quant_in_scale, - const float* dequant_out_scale_data) { +__global__ void DequantKernel(T* output, + const int32_t* input, + const int m, // batch size + const int n, // hidden + const float quant_in_scale, + const float* dequant_out_scale_data) { int numel = m * n; int stride = blockDim.x * gridDim.x * VecSize; int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; @@ -128,15 +128,15 @@ __global__ void dequantize_kernel(T* output, } template -void dequantize_kernel_launcher(const int32_t* input, - T* output, - const int m, // m - const int n, // n - gpuStream_t stream, - GpuLaunchConfig* gpu_config, - const float quant_in_scale, - const float* dequant_out_scale_data) { - dequantize_kernel +void LaunchDequantKernel(const int32_t* input, + T* output, + const int m, // m + const int n, // n + gpuStream_t stream, + GpuLaunchConfig* gpu_config, + const float quant_in_scale, + const float* dequant_out_scale_data) { + DequantKernel <<block_per_grid, gpu_config->thread_per_block, 0, stream>>>( output, input, m, n, quant_in_scale, dequant_out_scale_data); } diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index d2add2e0a26b7f1e92fc7b290a0d4086fb652221..e0247133dab7dc8ae7a9a6bbe0b0244749c32c9e 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -523,6 +523,14 @@ func : matmul backward : matmul_grad +- op : matmul_int8 + args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false) + output : Tensor + infer_meta : + func : MatmulInt8InferMeta + kernel : + func : matmul_int8 + - op : matrix_rank args : (Tensor x, float tol, bool use_default_tol=true, bool hermitian=false) output : Tensor(out) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index de0794ada9772d3aa47398021602f1fb0545917d..719a5f2f130af75393edb79e92fc4c7efcccc5e5 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2096,6 +2096,76 @@ void MatmulInferMeta(const MetaTensor& x, out->set_layout(x.layout()); } +void MatmulInt8InferMeta(const MetaTensor& x, + const MetaTensor& y, + bool trans_x, + bool trans_y, + MetaTensor* out) { + std::vector dims_x = phi::vectorize(x.dims()); + std::vector dims_y = phi::vectorize(y.dims()); + auto ndims_x = dims_x.size(); + auto ndims_y = dims_y.size(); + PADDLE_ENFORCE_GT(ndims_x, + 0UL, + phi::errors::InvalidArgument( + "The Input(x) dims size must be greater than 0," + " but received dims size is 0. ")); + PADDLE_ENFORCE_GT(ndims_y, + 0UL, + phi::errors::InvalidArgument( + "The Input(y) dims size must be greater than 0," + " but received dims size is 0. 
")); + + bool x_broadcasted = false, y_broadcasted = false; + if (ndims_x == 1) { + dims_x.insert(dims_x.begin(), 1); + ndims_x = 2; + x_broadcasted = true; + } + + if (ndims_y == 1) { + dims_y.push_back(1); + ndims_y = 2; + y_broadcasted = true; + } + + size_t M, N; + if (trans_x) { + M = dims_x[ndims_x - 1]; + } else { + M = dims_x[ndims_x - 2]; + } + if (trans_y) { + N = dims_y[ndims_y - 2]; + } else { + N = dims_y[ndims_y - 1]; + } + + std::vector new_dims; + if (ndims_x > ndims_y) { + new_dims.assign(dims_x.begin(), dims_x.end() - 2); + } else if (ndims_x < ndims_y) { + new_dims.assign(dims_y.begin(), dims_y.end() - 2); + } else { + new_dims.reserve(ndims_x); + for (size_t i = 0; i < ndims_x - 2; ++i) { + new_dims.push_back(std::max(dims_x[i], dims_y[i])); + } + } + if (!x_broadcasted) { + new_dims.push_back(M); + } + if (!y_broadcasted) { + new_dims.push_back(N); + } + + auto ddim_out = phi::make_ddim(new_dims); + + out->set_dims(ddim_out); + out->set_dtype(phi::DataType::INT32); + out->set_layout(x.layout()); +} + void MatmulWithFlattenInferMeta(const MetaTensor& x, const MetaTensor& y, int x_num_col_dims, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 517b259f0149f6079ed4117903a273abfdc68f91..8c8604e703caf7672035a19662cff6754ec7e3e7 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -333,6 +333,12 @@ void MatmulInferMeta(const MetaTensor& x, bool trans_y, MetaTensor* out); +void MatmulInt8InferMeta(const MetaTensor& x, + const MetaTensor& y, + bool trans_x, + bool trans_y, + MetaTensor* out); + void MatmulWithFlattenInferMeta(const MetaTensor& x, const MetaTensor& y, int x_num_col_dims, diff --git a/paddle/phi/kernels/funcs/cublaslt.h b/paddle/phi/kernels/funcs/cublaslt.h index 6391b583a08db13c5a4709026d75b02f55e5663a..6278f159df075daf7350838517aea770a8252c5a 100644 --- a/paddle/phi/kernels/funcs/cublaslt.h +++ b/paddle/phi/kernels/funcs/cublaslt.h @@ -39,25 +39,16 @@ const std::map, CublasLtAlgoParam> AlgoParamCache{}; class CublasLtHelper { public: - CublasLtHelper(int m, int k, int n) - : alpha_(1), beta_(0), m_(m), k_(k), n_(n) { + CublasLtHelper(int m, int k, int n, cublasLtHandle_t handle) + : handle_(handle), alpha_(1), beta_(0), m_(m), k_(k), n_(n) { cublasStatus_t status; - // handle and matmul desc - status = dyl::cublasLtCreate(&handle_); #if CUBLAS_VER_MAJOR < 11 cudaDataType_t cudaComputeType = CUDA_R_32I; #else cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; #endif - PADDLE_ENFORCE_EQ( - status, - CUBLAS_STATUS_SUCCESS, - phi::errors::External( - "cublasLtMatrixLayoutCreate execution error" - "refer https://docs.nvidia.com/cuda/cublas/index.html to get more " - "information")); - + // matmul desc #if CUBLAS_VER_MAJOR < 11 status = dyl::cublasLtMatmulDescCreate(&matmul_desc_, cudaComputeType); #else @@ -179,7 +170,7 @@ class CublasLtHelper { } ~CublasLtHelper() {} - void GEMM(int8_t* A_dev, + void GEMM(const int8_t* A_dev, const int8_t* B_dev, int32_t* C_dev, cudaStream_t stream, @@ -226,14 +217,14 @@ class CublasLtHelper { cublasLtMatmulAlgo_t algo_; - int32_t alpha_; - int32_t beta_; + int32_t alpha_ = 1; + int32_t beta_ = 0; - int m_; - int k_; - int n_; + int m_ = 0; + int k_ = 0; + int n_ = 0; - size_t workspace_size_; + size_t workspace_size_ = 0; }; } // namespace phi diff --git a/paddle/phi/kernels/funcs/gemm_int8_helper.h b/paddle/phi/kernels/funcs/gemm_int8_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..c848518c2a1a19fa27248f2882967fcf177240af --- 
/dev/null +++ b/paddle/phi/kernels/funcs/gemm_int8_helper.h @@ -0,0 +1,114 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/kernels/funcs/cublaslt.h" + +namespace phi { + +template +class Int8GEMMHelper { + public: + Int8GEMMHelper(const phi::GPUContext &dev_ctx, + int m, + int k, + int n, + phi::DenseTensor &workspace, // NOLINT + phi::DenseTensor &input_workspace, // NOLINT + phi::DenseTensor &out_workspace, // NOLINT + int quant_round_type, + float quant_max_bound, + float quant_min_bound) + : dev_ctx_(dev_ctx), + m_(m), + k_(k), + n_(n), + quant_round_type_(quant_round_type), + quant_max_bound_(quant_max_bound), + quant_min_bound_(quant_min_bound), + workspace_(workspace), + input_workspace_(input_workspace), + out_workspace_(out_workspace) { + cublaslt_helper = std::make_unique>( + m, k, n, dev_ctx.cublaslt_handle()); + } + + void Compute(const phi::DenseTensor *input, + const phi::DenseTensor *weight, // int8, needs to be transposed + const phi::DenseTensor *dequant_out_scales, + const float quant_in_scale, + phi::DenseTensor *output, + bool quant_in = false, + bool dequant_out = false) { + phi::DenseTensor input_tmp, out_tmp; + if (quant_in) { + input_tmp = input_workspace_; + LaunchQuantKernel(input->data(), + input_tmp.data(), + quant_in_scale, + m_, + k_, + quant_round_type_, + quant_max_bound_, + quant_min_bound_, + dev_ctx_.stream()); + } else { + input_tmp = *input; + } + + if (dequant_out) { + out_tmp = out_workspace_; + } else { + out_tmp = *output; + } + + cublaslt_helper->GEMM(input_tmp.data(), + weight->data(), + out_tmp.data(), + dev_ctx_.stream(), + (void *)workspace_.data(), + workspace_.numel()); + + if (dequant_out) { + auto gpu_config = std::make_unique( + phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx_, m_ * n_, DequantKernelVecSize)); + LaunchDequantKernel(out_tmp.data(), + output->data(), + m_, + n_, + dev_ctx_.stream(), + gpu_config.get(), + quant_in_scale, + dequant_out_scales->data()); + } + } + + private: + const phi::GPUContext &dev_ctx_; + int m_; + int k_; + int n_; + int quant_round_type_; + float quant_max_bound_; + float quant_min_bound_; + phi::DenseTensor &workspace_; // char + phi::DenseTensor &input_workspace_; // int8_t + phi::DenseTensor &out_workspace_; // int32_t + + std::unique_ptr> cublaslt_helper; +}; + +} // namespace phi diff --git a/paddle/phi/kernels/funcs/quant_dequant.h b/paddle/phi/kernels/funcs/quant_dequant.h index 62bfc9cfcf1bee7e74ced08ed3fe500755a66715..f640dcc369bb715fbec30dee993114fb3b47995b 100644 --- a/paddle/phi/kernels/funcs/quant_dequant.h +++ b/paddle/phi/kernels/funcs/quant_dequant.h @@ -61,14 +61,14 @@ __forceinline__ __device__ int8_t quant_helper(const T input, } template -__global__ void quantize_kernel(const T* input, - char4* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound) { +__global__ void QuantKernel(const T* input, + char4* output, + 
const float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound) { int n_id = (blockIdx.x * blockDim.x + threadIdx.x) << 2; int m_id = blockIdx.y * blockDim.y + threadIdx.y; @@ -88,36 +88,36 @@ __global__ void quantize_kernel(const T* input, } template -void quantize_kernel_launcher(const T* input, - int8_t* output, - const float scale, - const int m, - const int n, - const int round_type, - const float max_bound, - const float min_bound, - gpuStream_t stream) { +void LaunchQuantKernel(const T* input, + int8_t* output, + const float scale, + const int m, + const int n, + const int round_type, + const float max_bound, + const float min_bound, + gpuStream_t stream) { // TODO(minghaoBD): optimize the kennel launch times when m==1 or n==1 dim3 grid((n >> 2 + 31) / 32, (m + 31) / 32); dim3 block(32, 32); - quantize_kernel<<>>(input, - (char4*)output, // NOLINT - scale, - m, - n, - round_type, - max_bound, - min_bound); + QuantKernel<<>>(input, + (char4*)output, // NOLINT + scale, + m, + n, + round_type, + max_bound, + min_bound); } template -__global__ void dequantize_kernel(T* output, - const int32_t* input, - const int m, // batch size - const int n, // hidden - const float quant_in_scale, - const float* dequant_out_scale_data) { +__global__ void DequantKernel(T* output, + const int32_t* input, + const int m, // batch size + const int n, // hidden + const float quant_in_scale, + const float* dequant_out_scale_data) { int numel = m * n; int stride = blockDim.x * gridDim.x * VecSize; int idx = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize; @@ -142,15 +142,15 @@ __global__ void dequantize_kernel(T* output, } template -void dequantize_kernel_launcher(const int32_t* input, - T* output, - const int m, // m - const int n, // n - gpuStream_t stream, - GpuLaunchConfig* gpu_config, - const float quant_in_scale, - const float* dequant_out_scale_data) { - dequantize_kernel +void LaunchDequantKernel(const int32_t* input, + T* output, + const int m, // m + const int n, // n + gpuStream_t stream, + GpuLaunchConfig* gpu_config, + const float quant_in_scale, + const float* dequant_out_scale_data) { + DequantKernel <<block_per_grid, gpu_config->thread_per_block, 0, stream>>>( output, input, m, n, quant_in_scale, dequant_out_scale_data); } diff --git a/paddle/phi/kernels/gpu/matmul_kernel.cu b/paddle/phi/kernels/gpu/matmul_kernel.cu index 590c041555f58410392ef346ba08999aa0b8bf77..c5271a4eeece6fe808ca9d82238690a3f754e446 100644 --- a/paddle/phi/kernels/gpu/matmul_kernel.cu +++ b/paddle/phi/kernels/gpu/matmul_kernel.cu @@ -30,6 +30,9 @@ PD_REGISTER_KERNEL(matmul, phi::dtype::complex, phi::dtype::complex) {} +PD_REGISTER_KERNEL( + matmul_int8, GPU, ALL_LAYOUT, phi::MatmulInt8Kernel, int8_t) {} + PD_REGISTER_KERNEL(matmul_with_flatten, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h index 00e2893558f5b6567bb7fe8d0faa4ac1c18af661..5ebbc8d2db5fb3c2bc65ce2d35b0615651c754ef 100644 --- a/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/llm_int8_matmul_kernel_impl.h @@ -667,7 +667,8 @@ void LLMGemm(const phi::GPUContext& dev_ctx, dev_ctx.Alloc(&int_out); { - auto helper = std::make_unique(m, k, n); + auto helper = + std::make_unique(m, k, n, dev_ctx.cublaslt_handle()); helper->GEMM(quant_input.data(), weight->data(), int_out.data(), diff --git a/paddle/phi/kernels/impl/matmul_kernel_impl.h b/paddle/phi/kernels/impl/matmul_kernel_impl.h 
index cac393d6360517f53dfe0549d1ac9a0de9ef383a..a77fbd961312f9daa4c27d7c294f9967ddfaa587 100644 --- a/paddle/phi/kernels/impl/matmul_kernel_impl.h +++ b/paddle/phi/kernels/impl/matmul_kernel_impl.h @@ -16,11 +16,15 @@ limitations under the License. */ #include "glog/logging.h" +#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/autotune/cache_base.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" #include "paddle/phi/kernels/funcs/complex_functors.h" +#if defined(PADDLE_WITH_CUDA) +#include "paddle/phi/kernels/funcs/cublaslt.h" +#endif #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 #include "paddle/phi/kernels/autotune/auto_tune_base.h" #endif @@ -948,6 +952,15 @@ struct MatMulDispatcher { #endif } }; + +static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx, + size_t workspace_size) { + return phi::memory_utils::Alloc( + ctx.GetPlace(), + workspace_size, + phi::Stream(reinterpret_cast(ctx.stream()))); +} + #endif // PADDLE_WITH_CUDA template @@ -964,6 +977,107 @@ void MatMulFunction(const Context& ctx, ctx, x, y, x_dims, y_dims, out, trans_x, trans_y, flag); } +template +void MatMulInt8Function(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + const std::vector& x_dims, + const std::vector& y_dims, + DenseTensor* out, + bool trans_x, + bool trans_y) { + PADDLE_ENFORCE_EQ( + x.dtype(), + DataType::INT8, + phi::errors::InvalidArgument( + "The type of input(x) used in int8 matmul must be (%s), " + "but the type of data currently contained in the " + "container is (%s).", + phi::CppTypeToDataType::Type(), + x.dtype())); + PADDLE_ENFORCE_EQ( + y.dtype(), + DataType::INT8, + phi::errors::InvalidArgument( + "The type of input(y) used in int8 matmul must be (%s), " + "but the type of data currently contained in the " + "container is (%s).", + phi::CppTypeToDataType::Type(), + y.dtype())); +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11020 + const int x_ndim = x_dims.size(); + const int y_ndim = y_dims.size(); + PADDLE_ENFORCE_EQ( + x_ndim, + 2, + phi::errors::InvalidArgument("[INT8 GEMM] The number of dims of input(x) " + "must be equal to 2 but received %d", + x_ndim)); + PADDLE_ENFORCE_EQ( + y_ndim, + 2, + phi::errors::InvalidArgument("[INT8 GEMM] The number of dims of input(y) " + "must be equal to 2 but received %d", + y_ndim)); + PADDLE_ENFORCE_EQ( + trans_x, + false, + phi::errors::InvalidArgument("[INT8 GEMM] Input(x) must not be " + "transposed to achieve better performance")); + PADDLE_ENFORCE_EQ( + trans_y, + true, + phi::errors::InvalidArgument("[INT8 GEMM] Input(y) must be transposed to " + "achieve better performance")); + + const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2]; + const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1]; + if (trans_y) { + PADDLE_ENFORCE_EQ( + y_dims[y_ndim - 1], + K, + phi::errors::InvalidArgument("Input(Y) has a wrong dim. " + "Y's dims[%d] must be equal to %d, " + "but received Y's dims[%d] is %d", + y_ndim - 1, + K, + y_ndim - 1, + y_dims[y_ndim - 1])); + } else { + PADDLE_ENFORCE_EQ( + y_dims[y_ndim - 2], + K, + phi::errors::InvalidArgument("Input(Y) has a wrong dim. " + "Y's dims[%d] must be equal to %d, " + "but received Y's dims[%d] is %d", + y_ndim - 2, + K, + y_ndim - 2, + y_dims[y_ndim - 2])); + } + const int N = trans_y ? 
y_dims[y_ndim - 2] : y_dims[y_ndim - 1]; + + size_t workspace_size = static_cast(4) * 1024 * 1024; + phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size); + + // TODO(wufeisheng): cublaslt_helper is a temp scheme for Int8 GEMM, + // and related functions need to be integrated into + // phi::funcs::MatmulWithCublasLt + auto cublaslt_helper = CublasLtHelper(M, K, N, ctx.cublaslt_handle()); + + ctx.template Alloc(out); + cublaslt_helper.GEMM(x.data(), + y.data(), + out->data(), + ctx.stream(), + workspace->ptr()); + +#else + PADDLE_THROW(phi::errors::Unimplemented( + "MatmulInt8 op needs PaddlePaddle compiled with CUDA and CUDA version >= 11.2")); +#endif +} + template void MatmulKernel(const Context& ctx, const DenseTensor& x, @@ -987,6 +1101,29 @@ void MatmulKernel(const Context& ctx, ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); } +template +void MatmulInt8Kernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& y, + bool transpose_x, + bool transpose_y, + DenseTensor* out) { + PADDLE_ENFORCE_NE( + phi::product(x.dims()), + 0, + phi::errors::InvalidArgument("The Input(X) dims size must not be equal to 0," + " but received dims size is 0. ")); + PADDLE_ENFORCE_NE( + phi::product(y.dims()), + 0, + phi::errors::InvalidArgument("The Input(Y) dims size must not be equal to 0," + " but received dims size is 0. ")); + const std::vector x_dims = vectorize(x.dims()); + const std::vector y_dims = vectorize(y.dims()); + MatMulInt8Function( + ctx, x, y, x_dims, y_dims, out, transpose_x, transpose_y); +} + template void MatmulWithFlattenKernel(const Context& dev_ctx, const DenseTensor& x, diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index a3dadaeb2533479350c821841c9ccaa4f76545fc..aff039509e5a2d4b344c7db93fa0f8a44010a783 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -157,6 +157,7 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op) list(REMOVE_ITEM TEST_OPS test_rms_norm_op) list(REMOVE_ITEM TEST_OPS test_linear_compress) + list(REMOVE_ITEM TEST_OPS test_matmul_int8_op) endif() list(REMOVE_ITEM TEST_OPS test_checkpoint_saver) diff --git a/test/legacy_test/test_matmul_int8_op.py b/test/legacy_test/test_matmul_int8_op.py new file mode 100644 index 0000000000000000000000000000000000000000..b1aa9d3286845da8c12d2e75b258a4ddac17b0f7 --- /dev/null +++ b/test/legacy_test/test_matmul_int8_op.py @@ -0,0 +1,77 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +from test_sparse_attention_op import get_cuda_version + +import paddle +from paddle.fluid import core + +paddle.disable_static() + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "MatmulInt8 requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class TestMatmulInt8(unittest.TestCase): + """ + Test matmul int8 + Only NT (Non-Transposed-A and Transposed-B) is supported + """ + + def config(self): + self.dtype = 'int8' + self.rtol = 1e-5 + self.atol = 1e-2 + self.bias = True + self.m = 8 + self.k = 64 + self.n = 64 + + def setUp(self): + self.config() + self.input_a_np = np.random.randint(-127, 127, [self.m, self.k]).astype( + 'int32' + ) + self.input_b_np = np.random.randint(-127, 127, [self.k, self.n]).astype( + 'int32' + ) + self.input_a = paddle.to_tensor(self.input_a_np, dtype=self.dtype) + self.input_b = paddle.to_tensor( + self.input_b_np.transpose((1, 0)), dtype=self.dtype + ) + + def get_reference_out(self): + out = np.dot(self.input_a_np, self.input_b_np) + return out + + def get_op_out(self): + out = paddle._C_ops.matmul_int8(self.input_a, self.input_b, False, True) + return out.numpy() + + def test_matmul_int8(self): + out_real = self.get_op_out() + out_expect = self.get_reference_out() + np.testing.assert_allclose( + out_real, out_expect, rtol=self.rtol, atol=self.atol + ) + + +if __name__ == '__main__': + unittest.main()
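Reviewer note (not part of the patch): a minimal usage sketch of the new op, mirroring the unit test above. It assumes a CUDA >= 11.2 build and a GPU with compute capability >= 8 (the same conditions the test skips on), and uses the paddle._C_ops.matmul_int8 binding added in this patch with the only supported NT layout (x not transposed, y pre-transposed).

import numpy as np
import paddle

m, k, n = 8, 64, 64
a = np.random.randint(-127, 127, [m, k])  # int8-range activations, shape [m, k]
b = np.random.randint(-127, 127, [k, n])  # int8-range weights, logical shape [k, n]

x = paddle.to_tensor(a, dtype='int8')                    # not transposed
y = paddle.to_tensor(b.transpose((1, 0)), dtype='int8')  # pre-transposed weight, shape [n, k]

# transpose_x=False, transpose_y=True is the only layout the kernel accepts.
out = paddle._C_ops.matmul_int8(x, y, False, True)       # int32 output of shape [m, n]

# Compare against the NumPy reference with the same tolerances as the test.
np.testing.assert_allclose(out.numpy(), a @ b, rtol=1e-5, atol=1e-2)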