diff --git a/paddle/fluid/inference/tensorrt/convert/matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/matmul_op.cc index 7b017900a02c90dd11d543985879f2443794b4a5..b2e76b9a0e61bd2a079c4b3d9f8edb3065a27ef0 100644 --- a/paddle/fluid/inference/tensorrt/convert/matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/matmul_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/matmul_op_int8_plugin.h" namespace paddle { namespace framework { @@ -35,16 +36,26 @@ class MatMulOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { - VLOG(3) << "convert a fluid matmul op to tensorrt mul layer without bias"; - + VLOG(3) << "convert a fluid matmul op to tensorrt matmul layer "; framework::OpDesc op_desc(op, nullptr); + nvinfer1::ILayer* layer = nullptr; + // Declare inputs auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]); + nvinfer1::Dims dims_x = input1->getDimensions(); + nvinfer1::Dims dims_y = input2->getDimensions(); + bool transpose_X = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_X")); bool transpose_Y = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_Y")); + auto output_name = op_desc.Output("Out")[0]; + float alpha = 1; + if (op_desc.HasAttr("alpha")) { + float alpha_tem = BOOST_GET_CONST(float, op_desc.GetAttr("alpha")); + alpha = alpha_tem; + } nvinfer1::MatrixOperation matrix_operation_X = transpose_X ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE; @@ -52,82 +63,122 @@ class MatMulOpConverter : public OpConverter { transpose_Y ? nvinfer1::MatrixOperation::kTRANSPOSE : nvinfer1::MatrixOperation::kNONE; - auto* layer = - TRT_ENGINE_ADD_LAYER(engine_, MatrixMultiply, *input1, - matrix_operation_X, *input2, matrix_operation_Y); - - float alpha = BOOST_GET_CONST(float, op_desc.GetAttr("alpha")); - auto output_name = op_desc.Output("Out")[0]; - if (fabs(alpha - 1.0) < std::numeric_limits::epsilon()) { - engine_->SetITensor(output_name, layer->getOutput(0)); - } else { - // IScaleLayer requires the input must have at least - // three dimensions in static shape mode and at least - // four dimensions in dynamic shape mode. 
- auto* matmul_out = layer->getOutput(0); - nvinfer1::Dims out_shape = matmul_out->getDimensions(); - const int out_dims = out_shape.nbDims; - bool need_change_dim = false; - + if (op_desc.HasAttr("support_int8") && + engine_->precision() == AnalysisConfig::Precision::kInt8) { if (engine_->with_dynamic_shape()) { - if (out_dims == 3) { - need_change_dim = true; - } + VLOG(3) << "Convert a fluid matmul_op_int8_dynamic to TensorRT " + "MatmulPluginLayer"; + plugin::MatmulPluginDynamic* plugin = + new plugin::MatmulPluginDynamic(transpose_X, transpose_Y, alpha); + std::vector inputs{input1, input2}; + layer = engine_->AddDynamicPlugin(inputs.data(), inputs.size(), plugin); + RreplenishLayerAndOutput(layer, "matmul_op_int8_dynamic", {output_name}, + test_mode); } else { - if (out_dims == 2) { - need_change_dim = true; - } + VLOG(3) << "Convert a fluid matmul_op_int8_static to TensorRT " + "MatmulPluginLayer"; + plugin::MatmulPlugin* plugin = new plugin::MatmulPlugin( + dims_x, dims_y, transpose_X, transpose_Y, alpha); + std::vector inputs{input1, input2}; + layer = engine_->AddPluginV2IOExt(inputs.data(), inputs.size(), plugin); + RreplenishLayerAndOutput(layer, "matmul_op_int8_static", {output_name}, + test_mode); } - - if (need_change_dim) { - nvinfer1::Dims reshape_dim; - reshape_dim.nbDims = out_dims + 1; - reshape_dim.d[out_dims] = 1; - for (int i = 0; i < out_dims; i++) { - reshape_dim.d[i] = out_shape.d[i]; + } else { + VLOG(3) << "Convert a fluid matmul_op_float to TensorRT "; + layer = + TRT_ENGINE_ADD_LAYER(engine_, MatrixMultiply, *input1, + matrix_operation_X, *input2, matrix_operation_Y); + if (alpha == 1) { + RreplenishLayerAndOutput(layer, "matmul_op_float_no_alpha", + {output_name}, test_mode); + } else { + layer->setName( + ("matmul_op_float_has_alpha: MatrixMultiplyLayer (Output: " + + output_name + ")") + .c_str()); + // IScaleLayer requires the input must have at least + // three dimensions in static shape mode and at least + // four dimensions in dynamic shape mode. 
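+          // To satisfy this, when the MatMul output has fewer dimensions a
+          // trailing dimension of size 1 is appended before the scale layer
+          // and the original shape is restored afterwards.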
+ auto* matmul_out = layer->getOutput(0); + nvinfer1::Dims out_shape = matmul_out->getDimensions(); + const int out_dims = out_shape.nbDims; + bool need_change_dim = false; + + if (engine_->with_dynamic_shape()) { + if (out_dims == 3) { + need_change_dim = true; + } + } else { + if (out_dims == 2) { + need_change_dim = true; + } } - auto* reshape_layer = - TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *matmul_out); - reshape_layer->setReshapeDimensions(reshape_dim); - matmul_out = reshape_layer->getOutput(0); - } + if (need_change_dim) { + nvinfer1::Dims reshape_dim; + reshape_dim.nbDims = out_dims + 1; + reshape_dim.d[out_dims] = 1; + for (int i = 0; i < out_dims; i++) { + reshape_dim.d[i] = out_shape.d[i]; + } + + auto* reshape_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *matmul_out); + reshape_layer->setReshapeDimensions(reshape_dim); + matmul_out = reshape_layer->getOutput(0); + reshape_layer->setName(("matmul_op_float_has_alpha_reshape_before: " + "ShuffleLayer (Output: " + + output_name + ")") + .c_str()); + } - auto create_weights = [&](float data, const std::string& type) -> float* { - std::unique_ptr tmp_tensor(new framework::Tensor()); - tmp_tensor->Resize({1}); - auto* tmp_data = tmp_tensor->mutable_data(platform::CPUPlace()); - tmp_data[0] = data; - engine_->SetWeights(output_name + "_add_scale_op_" + type, - std::move(tmp_tensor)); - return tmp_data; - }; - float* alpha_data = create_weights(alpha, "alpha"); - float* shift_data = create_weights(0.0, "shift"); - float* power_data = create_weights(1.0, "power"); - TensorRTEngine::Weight nv_alpha{nvinfer1::DataType::kFLOAT, - static_cast(alpha_data), 1}; - TensorRTEngine::Weight nv_shift{nvinfer1::DataType::kFLOAT, - static_cast(shift_data), 1}; - TensorRTEngine::Weight nv_power{nvinfer1::DataType::kFLOAT, - static_cast(power_data), 1}; - auto* scale_layer = TRT_ENGINE_ADD_LAYER( - engine_, Scale, *matmul_out, nvinfer1::ScaleMode::kUNIFORM, - nv_shift.get(), nv_alpha.get(), nv_power.get()); - auto* scale_out = scale_layer->getOutput(0); - - if (need_change_dim) { - auto* reshape_layer = - TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *scale_out); - reshape_layer->setReshapeDimensions(out_shape); - scale_out = reshape_layer->getOutput(0); + auto create_weights = [&](float data, + const std::string& type) -> float* { + std::unique_ptr tmp_tensor( + new framework::Tensor()); + tmp_tensor->Resize({1}); + auto* tmp_data = + tmp_tensor->mutable_data(platform::CPUPlace()); + tmp_data[0] = data; + engine_->SetWeights(output_name + "_add_scale_op_" + type, + std::move(tmp_tensor)); + return tmp_data; + }; + float* alpha_data = create_weights(alpha, "alpha"); + float* shift_data = create_weights(0.0, "shift"); + float* power_data = create_weights(1.0, "power"); + TensorRTEngine::Weight nv_alpha{nvinfer1::DataType::kFLOAT, + static_cast(alpha_data), 1}; + TensorRTEngine::Weight nv_shift{nvinfer1::DataType::kFLOAT, + static_cast(shift_data), 1}; + TensorRTEngine::Weight nv_power{nvinfer1::DataType::kFLOAT, + static_cast(power_data), 1}; + auto* scale_layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *matmul_out, nvinfer1::ScaleMode::kUNIFORM, + nv_shift.get(), nv_alpha.get(), nv_power.get()); + auto* scale_out = scale_layer->getOutput(0); + scale_layer->setName( + ("matmul_op_float_has_alpha: ScaleLayer (Output: " + output_name + + ")") + .c_str()); + + if (need_change_dim) { + auto* reshape_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *scale_out); + reshape_layer->setReshapeDimensions(out_shape); + scale_out = reshape_layer->getOutput(0); + 
reshape_layer->setName(("matmul_op_float_has_alpha_reshape_after: " + "ShuffleLayer (Output: " + + output_name + ")") + .c_str()); + } + engine_->SetITensor(output_name, scale_out); + if (test_mode) { // the test framework can not determine which is the + // output, so place the declaration inside. + engine_->DeclareOutput(output_name); + } } - - engine_->SetITensor(output_name, scale_out); - } - if (test_mode) { // the test framework can not determine which is the - // output, so place the declaration inside. - engine_->DeclareOutput(output_name); } } }; diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 3eece7e500e687715f68d9ae158dd0bec449de67..be6984d0f76b5005cbac66bc5d6b630e6b93343c 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -12,7 +12,8 @@ nv_library(tensorrt_plugin mish_op_plugin.cu pool3d_op_plugin.cu deformable_conv_op_plugin.cu - DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) + matmul_op_int8_plugin.cu + DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_plugin) diff --git a/paddle/fluid/inference/tensorrt/plugin/matmul_op_int8_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/matmul_op_int8_plugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..551b9c6c72f1f420542bfa2a64999afabe053894 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/matmul_op_int8_plugin.cu @@ -0,0 +1,1031 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/inference/tensorrt/plugin/matmul_op_int8_plugin.h" + +namespace plf = paddle::platform; +namespace dyl = paddle::platform::dynload; +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { +float zero = 0; +void Ltgemm_int8_linear( + cublasLtHandle_t ltHandle, const int8_t* A, cublasLtMatrixLayout_t Adesc, + int8_t* Atransform, cublasLtMatrixLayout_t AtransformDesc, bool transA_, + const int8_t* B, cublasLtMatrixLayout_t Bdesc, int8_t* Btransform, + cublasLtMatrixLayout_t BtransformDesc, bool transB_, int8_t* C, + cublasLtMatrixLayout_t Cdesc, int8_t* Ctransform, + cublasLtMatrixLayout_t CtransformDesc, + cublasLtMatrixTransformDesc_t transformDescT, + cublasLtMatrixTransformDesc_t transformDescN, + cublasLtMatmulDesc_t matmulDesc, void* alpha_scale, void* alpha_zero, + void* alpha_one, void* workspace, cudaStream_t stream) { + if (transA_) { + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransform( + ltHandle, transformDescT, alpha_one, A, Adesc, alpha_zero, nullptr, + nullptr, Atransform, AtransformDesc, stream)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransform( + ltHandle, transformDescN, alpha_one, A, Adesc, alpha_zero, nullptr, + nullptr, Atransform, AtransformDesc, stream)); + } + + if (transB_) { + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransform( + ltHandle, transformDescN, alpha_one, B, Bdesc, alpha_zero, nullptr, + nullptr, Btransform, BtransformDesc, stream)); + } else { + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransform( + ltHandle, transformDescT, alpha_one, B, Bdesc, alpha_zero, nullptr, + nullptr, Btransform, BtransformDesc, stream)); + } + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmul( + ltHandle, matmulDesc, alpha_scale, Atransform, AtransformDesc, Btransform, + BtransformDesc, nullptr, Ctransform, CtransformDesc, Ctransform, + CtransformDesc, nullptr, workspace, 0, stream)); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransform( + ltHandle, transformDescN, alpha_one, Ctransform, CtransformDesc, + alpha_zero, nullptr, nullptr, C, Cdesc, stream)); +} + +void Ltgemm_fp32_linear(cublasLtHandle_t ltHandle, const float* A, + cublasLtMatrixLayout_t Adesc, const float* B, + cublasLtMatrixLayout_t Bdesc, float* C, + cublasLtMatrixLayout_t Cdesc, + cublasLtMatmulDesc_t matmulDesc, void* alpha_scale, + void* alpha_zero, void* workspace, + cudaStream_t stream) { + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmul( + ltHandle, matmulDesc, alpha_scale, A, Adesc, B, Bdesc, alpha_zero, C, + Cdesc, C, Cdesc, nullptr, workspace, 0, stream)); +} + +void Ltgemm_fp16_linear(cublasLtHandle_t ltHandle, const half* A, + cublasLtMatrixLayout_t Adesc, const half* B, + cublasLtMatrixLayout_t Bdesc, half* C, + cublasLtMatrixLayout_t Cdesc, + cublasLtMatmulDesc_t matmulDesc, void* alpha_scale, + void* alpha_zero, void* workspace, + cudaStream_t stream) { + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmul( + ltHandle, matmulDesc, alpha_scale, A, Adesc, B, Bdesc, alpha_zero, C, + Cdesc, C, Cdesc, nullptr, workspace, 0, stream)); +} + +nvinfer1::DataType MatmulPlugin::getOutputDataType( + int index, const nvinfer1::DataType* input_types, + int nb_inputs) const TRT_NOEXCEPT { + return input_types[0]; +} + +nvinfer1::Dims MatmulPlugin::getOutputDimensions( + int index, const nvinfer1::Dims* input_dims, int num_inputs) TRT_NOEXCEPT { + if (transB_) { + m_ = dims_x_.d[dims_x_.nbDims - 1]; + k_ = dims_x_.d[dims_x_.nbDims - 2]; + } else { + m_ = dims_x_.d[dims_x_.nbDims - 2]; + k_ = 
dims_x_.d[dims_x_.nbDims - 1]; + } + if (transA_) { + n_ = dims_y_.d[dims_y_.nbDims - 2]; + } else { + n_ = dims_y_.d[dims_y_.nbDims - 1]; + } + + batch_ = 1; + for (int i = 0; i < dims_x_.nbDims - 2; i++) { + batch_ *= dims_x_.d[i]; + } + nvinfer1::Dims output_dims; + output_dims.nbDims = dims_x_.nbDims; + for (int i = 0; i < output_dims.nbDims - 2; i++) { + output_dims.d[i] = dims_x_.d[i]; + } + output_dims.d[output_dims.nbDims - 2] = m_; + output_dims.d[output_dims.nbDims - 1] = n_; + + return output_dims; +} + +bool MatmulPlugin::supportsFormatCombination( + int32_t pos, nvinfer1::PluginTensorDesc const* inOut, int32_t nbInputs, + int32_t nbOutputs) const TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(nbInputs, 2, + platform::errors::InvalidArgument("Must have 2 inputs, " + "but got %d input(s). ", + nbInputs)); + PADDLE_ENFORCE_EQ(nbOutputs, getNbOutputs(), + platform::errors::InvalidArgument("Must have 1 output, " + "but got %d output(s). ", + nbOutputs)); + if (pos == 0) { + return (inOut[pos].type == nvinfer1::DataType::kHALF || + inOut[pos].type == nvinfer1::DataType::kFLOAT || + inOut[pos].type == nvinfer1::DataType::kINT8) && + inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; + } else { + return inOut[pos].type == inOut[0].type && + inOut[pos].format == inOut[0].format; + } +} + +void MatmulPlugin::configurePlugin(const nvinfer1::PluginTensorDesc* inputs, + int32_t nbInputs, + const nvinfer1::PluginTensorDesc* out, + int32_t nbOutputs) TRT_NOEXCEPT { + float inscale_0 = inputs[0].scale; + float inscale_1 = inputs[1].scale; + float outscale = out[0].scale; + type_ = inputs[0].type; + int64_t stridea = k_ * n_; + int64_t strideb = k_ * m_; + int64_t stridec = m_ * n_; + + cublasOperation_t AopTranspose, BopTranspose; + if (transA_) { + AopTranspose = CUBLAS_OP_T; + } else { + AopTranspose = CUBLAS_OP_N; + } + if (transB_) { + BopTranspose = CUBLAS_OP_T; + } else { + BopTranspose = CUBLAS_OP_N; + } + + if (type_ == nvinfer1::DataType::kINT8) { + cudaDataType_t cudadataTypeIO = CUDA_R_8I; + cudaDataType_t cudaDataTypeS = CUDA_R_32F; +#if CUBLAS_VER_MAJOR < 11 + cudaDataType_t cudaComputeType = CUDA_R_32I; +#else + cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; +#endif + cublasLtOrder_t COL32 = CUBLASLT_ORDER_COL32; + cublasLtOrder_t COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C; + + int const ldatransform = 32 * n_; + int const ldbtransform = 32 * ((m_ + 8 - 1) / 8 * 8); + int const ldctransform = 32 * n_; + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc( + (void**)&Atransform_, + sizeof(int8_t) * ((k_ + 32 - 1) / 32 * 32) / 32 * ldatransform)); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc( + (void**)&Btransform_, + sizeof(int8_t) * ((k_ + 32 - 1) / 32 * 32) / 32 * ldbtransform)); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc( + (void**)&Ctransform_, + sizeof(int8_t) * ((m_ + 32 - 1) / 32 * 32) / 32 * ldctransform)); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &Adesc_, cudadataTypeIO, AopTranspose == CUBLAS_OP_N ? n_ : k_, + AopTranspose == CUBLAS_OP_N ? k_ : n_, + AopTranspose == CUBLAS_OP_N ? 
n_ : k_)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc_, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc_, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch_), sizeof(batch_))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc_, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(stridea), + sizeof(stridea))); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &Bdesc_, cudadataTypeIO, BopTranspose == CUBLAS_OP_N ? k_ : m_, + BopTranspose == CUBLAS_OP_N ? m_ : k_, + BopTranspose == CUBLAS_OP_N ? k_ : m_)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc_, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc_, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch_), sizeof(batch_))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc_, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(strideb), + sizeof(strideb))); + + PADDLE_ENFORCE_CUDA_SUCCESS( + dyl::cublasLtMatrixLayoutCreate(&Cdesc_, cudadataTypeIO, n_, m_, n_)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc_, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc_, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch_), sizeof(batch_))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc_, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(stridec), + sizeof(stridec))); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &AtransformDesc_, cudadataTypeIO, n_, k_, ldatransform)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + AtransformDesc_, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + AtransformDesc_, CUBLASLT_MATRIX_LAYOUT_ORDER, &COL32, sizeof(COL32))); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &BtransformDesc_, cudadataTypeIO, m_, k_, ldbtransform)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + BtransformDesc_, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + BtransformDesc_, CUBLASLT_MATRIX_LAYOUT_ORDER, &COL4_4R2_8C, + sizeof(COL4_4R2_8C))); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &CtransformDesc_, cudadataTypeIO, n_, m_, ldctransform)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + CtransformDesc_, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + CtransformDesc_, CUBLASLT_MATRIX_LAYOUT_ORDER, &COL32, sizeof(COL32))); + + cublasOperation_t Transpose = CUBLAS_OP_T; + cublasLtPointerMode_t transform_model = CUBLASLT_POINTER_MODE_DEVICE; + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransformDescCreate( + &transformDescT_, cudaDataTypeS)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransformDescSetAttribute( + transformDescT_, CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE, + &cudaDataTypeS, sizeof(cudaDataTypeS))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransformDescSetAttribute( + transformDescT_, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &Transpose, + 
sizeof(Transpose))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransformDescSetAttribute( + transformDescT_, CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, + &transform_model, sizeof(transform_model))); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransformDescCreate( + &transformDescN_, cudaDataTypeS)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransformDescSetAttribute( + transformDescN_, CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE, + &cudaDataTypeS, sizeof(cudaDataTypeS))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransformDescSetAttribute( + transformDescN_, CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, + &transform_model, sizeof(transform_model))); + + cublasOperation_t ATranspose = CUBLAS_OP_N, BTranspose = CUBLAS_OP_T; + cublasLtPointerMode_t matmul_model = + CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + +#if CUBLAS_VER_MAJOR < 11 + PADDLE_ENFORCE_CUDA_SUCCESS( + dyl::cublasLtMatmulDescCreate(&matmulDesc_, cudaComputeType)); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescCreate( + &matmulDesc_, cudaComputeType, cudaDataTypeS)); +#endif + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc_, CUBLASLT_MATMUL_DESC_TRANSA, &ATranspose, + sizeof(ATranspose))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc_, CUBLASLT_MATMUL_DESC_TRANSB, &BTranspose, + sizeof(BTranspose))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc_, CUBLASLT_MATMUL_DESC_POINTER_MODE, &matmul_model, + sizeof(matmul_model))); + + float alpha_tem[n_]; + for (int i = 0; i < n_; i++) { + alpha_tem[i] = alpha_ * inscale_0 * inscale_1 / outscale; + } + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMalloc((void**)&alpha_scale_, n_ * sizeof(float))); + cudaMemcpyAsync(alpha_scale_, alpha_tem, n_ * sizeof(float), + cudaMemcpyHostToDevice); + float zero_tem = zero; + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMalloc((void**)&alpha_zero_, sizeof(float))); + cudaMemcpyAsync(alpha_zero_, &zero_tem, sizeof(float), + cudaMemcpyHostToDevice); + float one_tem = 1; + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc((void**)&alpha_one_, sizeof(float))); + cudaMemcpyAsync(alpha_one_, &one_tem, sizeof(float), + cudaMemcpyHostToDevice); + } else if (type_ == nvinfer1::DataType::kHALF) { + cudaDataType_t cudadataTypeIO = CUDA_R_16F; + cudaDataType_t cudaDataTypeS = CUDA_R_16F; +#if CUBLAS_VER_MAJOR < 11 + cudaDataType_t cudaComputeType = CUDA_R_16F; +#else + cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_16F; +#endif + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &Adesc_, cudadataTypeIO, AopTranspose == CUBLAS_OP_N ? n_ : k_, + AopTranspose == CUBLAS_OP_N ? k_ : n_, + AopTranspose == CUBLAS_OP_N ? n_ : k_)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc_, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc_, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch_), sizeof(batch_))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc_, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(stridea), + sizeof(stridea))); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &Bdesc_, cudadataTypeIO, BopTranspose == CUBLAS_OP_N ? k_ : m_, + BopTranspose == CUBLAS_OP_N ? m_ : k_, + BopTranspose == CUBLAS_OP_N ? 
k_ : m_)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc_, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc_, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch_), sizeof(batch_))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc_, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(strideb), + sizeof(strideb))); + + PADDLE_ENFORCE_CUDA_SUCCESS( + dyl::cublasLtMatrixLayoutCreate(&Cdesc_, cudadataTypeIO, n_, m_, n_)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc_, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc_, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch_), sizeof(batch_))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc_, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(stridec), + sizeof(stridec))); + + cublasLtPointerMode_t matmul_model = CUBLASLT_POINTER_MODE_DEVICE; + +#if CUBLAS_VER_MAJOR < 11 + PADDLE_ENFORCE_CUDA_SUCCESS( + dyl::cublasLtMatmulDescCreate(&matmulDesc_, cudaComputeType)); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescCreate( + &matmulDesc_, cudaComputeType, cudaDataTypeS)); +#endif + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc_, CUBLASLT_MATMUL_DESC_TRANSA, &AopTranspose, + sizeof(AopTranspose))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc_, CUBLASLT_MATMUL_DESC_TRANSB, &BopTranspose, + sizeof(BopTranspose))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc_, CUBLASLT_MATMUL_DESC_POINTER_MODE, &matmul_model, + sizeof(matmul_model))); + + half alpha_tem = static_cast(alpha_); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMalloc((void**)&alpha_scale_, sizeof(half))); + cudaMemcpyAsync(alpha_scale_, &alpha_tem, sizeof(half), + cudaMemcpyHostToDevice); + half zero_tem = static_cast(zero); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc((void**)&alpha_zero_, sizeof(half))); + cudaMemcpyAsync(alpha_zero_, &zero_tem, sizeof(half), + cudaMemcpyHostToDevice); + } else { + cudaDataType_t cudadataTypeIO = CUDA_R_32F; + cudaDataType_t cudaDataTypeS = CUDA_R_32F; +#if CUBLAS_VER_MAJOR < 11 + cudaDataType_t cudaComputeType = CUDA_R_32F; +#else + cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32F_FAST_16F; +#endif + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &Adesc_, cudadataTypeIO, AopTranspose == CUBLAS_OP_N ? n_ : k_, + AopTranspose == CUBLAS_OP_N ? k_ : n_, + AopTranspose == CUBLAS_OP_N ? n_ : k_)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc_, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc_, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch_), sizeof(batch_))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc_, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(stridea), + sizeof(stridea))); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &Bdesc_, cudadataTypeIO, BopTranspose == CUBLAS_OP_N ? k_ : m_, + BopTranspose == CUBLAS_OP_N ? m_ : k_, + BopTranspose == CUBLAS_OP_N ? 
k_ : m_)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc_, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc_, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch_), sizeof(batch_))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc_, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(strideb), + sizeof(strideb))); + + PADDLE_ENFORCE_CUDA_SUCCESS( + dyl::cublasLtMatrixLayoutCreate(&Cdesc_, cudadataTypeIO, n_, m_, n_)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc_, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc_, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch_), sizeof(batch_))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc_, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(stridec), + sizeof(stridec))); + + cublasLtPointerMode_t matmul_model = CUBLASLT_POINTER_MODE_DEVICE; + +#if CUBLAS_VER_MAJOR < 11 + PADDLE_ENFORCE_CUDA_SUCCESS( + dyl::cublasLtMatmulDescCreate(&matmulDesc_, cudaComputeType)); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescCreate( + &matmulDesc_, cudaComputeType, cudaDataTypeS)); +#endif + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc_, CUBLASLT_MATMUL_DESC_TRANSA, &AopTranspose, + sizeof(AopTranspose))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc_, CUBLASLT_MATMUL_DESC_TRANSB, &BopTranspose, + sizeof(BopTranspose))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc_, CUBLASLT_MATMUL_DESC_POINTER_MODE, &matmul_model, + sizeof(matmul_model))); + + float alpha_tem = alpha_; + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMalloc((void**)&alpha_scale_, sizeof(float))); + cudaMemcpyAsync(alpha_scale_, &alpha_tem, sizeof(float), + cudaMemcpyHostToDevice); + float zero_tem = zero; + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMalloc((void**)&alpha_zero_, sizeof(float))); + cudaMemcpyAsync(alpha_zero_, &zero_tem, sizeof(float), + cudaMemcpyHostToDevice); + } +} + +void MatmulPlugin::attachToContext( + cudnnContext* cudnnContext, cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT { + dyl::cublasLtCreate(&cublas_); +} + +void MatmulPlugin::detachFromContext() TRT_NOEXCEPT { + dyl::cublasLtDestroy(cublas_); +} + +// When tensorrt engine freed ,there is "double free" ERROR. 
TODO@Wangzheee +void MatmulPlugin::terminate() TRT_NOEXCEPT { + /* + if(alpha_scale_){ + cudaFree((void *)alpha_scale_); + alpha_scale_ = nullptr; + } + if(alpha_zero_){ + cudaFree((void *)alpha_zero_); + alpha_zero_ = nullptr; + } + if(alpha_one_){ + cudaFree((void *)alpha_one_); + alpha_one_ = nullptr; + } + if(Atransform_){ + cudaFree((void *)Atransform_); + Atransform_ = nullptr; + } + if(Btransform_){ + cudaFree((void *)Btransform_); + Btransform_ = nullptr; + } + if(Ctransform_){ + cudaFree((void *)Ctransform_); + Ctransform_ = nullptr; + } */ +} + +int MatmulPlugin::enqueue(int batchSize, const void* const* inputs, +#if IS_TRT_VERSION_LT(8000) + void** outputs, void* workspace, + cudaStream_t stream) { +#else + void* const* outputs, void* workspace, + cudaStream_t stream) TRT_NOEXCEPT { +#endif + if (type_ == nvinfer1::DataType::kINT8) { + const int8_t* B = static_cast(inputs[0]); + const int8_t* A = static_cast(inputs[1]); + int8_t* C = static_cast(outputs[0]); + Ltgemm_int8_linear( + cublas_, A, Adesc_, Atransform_, AtransformDesc_, transA_, B, Bdesc_, + Btransform_, BtransformDesc_, transB_, C, Cdesc_, Ctransform_, + CtransformDesc_, transformDescT_, transformDescN_, matmulDesc_, + alpha_scale_, alpha_zero_, alpha_one_, workspace, stream); + } else if (type_ == nvinfer1::DataType::kFLOAT) { + const float* B = static_cast(inputs[0]); + const float* A = static_cast(inputs[1]); + float* C = static_cast(outputs[0]); + Ltgemm_fp32_linear(cublas_, A, Adesc_, B, Bdesc_, C, Cdesc_, matmulDesc_, + alpha_scale_, alpha_zero_, workspace, stream); + } else if (type_ == nvinfer1::DataType::kHALF) { + const half* B = static_cast(inputs[0]); + const half* A = static_cast(inputs[1]); + half* C = static_cast(outputs[0]); + Ltgemm_fp16_linear(cublas_, A, Adesc_, B, Bdesc_, C, Cdesc_, matmulDesc_, + alpha_scale_, alpha_zero_, workspace, stream); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "VarMessageToVarType:Unsupported type")); + } + return cudaGetLastError() != cudaSuccess; +} + +nvinfer1::DataType MatmulPluginDynamic::getOutputDataType( + int index, const nvinfer1::DataType* input_types, + int nb_inputs) const TRT_NOEXCEPT { + return input_types[0]; +} + +nvinfer1::DimsExprs MatmulPluginDynamic::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT { + nvinfer1::DimsExprs output_dims(inputs[0]); + if (transB_) { + output_dims.d[output_dims.nbDims - 2] = inputs[0].d[inputs[0].nbDims - 1]; + } else { + output_dims.d[output_dims.nbDims - 2] = inputs[0].d[inputs[0].nbDims - 2]; + } + if (transA_) { + output_dims.d[output_dims.nbDims - 1] = inputs[1].d[inputs[1].nbDims - 2]; + } else { + output_dims.d[output_dims.nbDims - 1] = inputs[1].d[inputs[1].nbDims - 1]; + } + return output_dims; +} + +bool MatmulPluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, + int nbOutputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(nbInputs, 2, + platform::errors::InvalidArgument("Must have 2 inputs, " + "but got %d input(s). ", + nbInputs)); + PADDLE_ENFORCE_EQ(nbOutputs, getNbOutputs(), + platform::errors::InvalidArgument("Must have 1 output, " + "but got %d output(s). 
", + nbOutputs)); + if (pos == 0) { + return (inOut[pos].type == nvinfer1::DataType::kHALF || + inOut[pos].type == nvinfer1::DataType::kFLOAT || + inOut[pos].type == nvinfer1::DataType::kINT8) && + inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; + } else { + return inOut[pos].type == inOut[0].type && + inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; + } +} + +void MatmulPluginDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT { + float inscale_0 = inputs[0].desc.scale; + float inscale_1 = inputs[1].desc.scale; + float outscale = outputs[0].desc.scale; + type_ = inputs[0].desc.type; + uint64_t m_max, n_max, k_max; + if (transB_) { + m_max = inputs[0].max.d[inputs[0].max.nbDims - 1]; + k_max = inputs[0].max.d[inputs[0].max.nbDims - 2]; + } else { + m_max = inputs[0].max.d[inputs[0].max.nbDims - 2]; + k_max = inputs[0].max.d[inputs[0].max.nbDims - 1]; + } + if (transA_) { + n_max = inputs[1].max.d[inputs[1].max.nbDims - 2]; + } else { + n_max = inputs[1].max.d[inputs[1].max.nbDims - 1]; + } + + int const ldatransform = 32 * n_max; + int const ldbtransform = 32 * ((m_max + 8 - 1) / 8 * 8); + int const ldctransform = 32 * n_max; + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc( + (void**)&Atransform_, + sizeof(int8_t) * ((k_max + 32 - 1) / 32 * 32) / 32 * ldatransform)); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc( + (void**)&Btransform_, + sizeof(int8_t) * ((k_max + 32 - 1) / 32 * 32) / 32 * ldbtransform)); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc( + (void**)&Ctransform_, + sizeof(int8_t) * ((m_max + 32 - 1) / 32 * 32) / 32 * ldctransform)); + + if (type_ == nvinfer1::DataType::kINT8) { + float alpha_tem[n_max]; + for (int i = 0; i < n_max; i++) { + alpha_tem[i] = alpha_ * inscale_0 * inscale_1 / outscale; + } + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMalloc((void**)&alpha_scale_, n_max * sizeof(float))); + cudaMemcpyAsync(alpha_scale_, alpha_tem, n_max * sizeof(float), + cudaMemcpyHostToDevice); + float zero_tem = zero; + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMalloc((void**)&alpha_zero_, sizeof(float))); + cudaMemcpyAsync(alpha_zero_, &zero_tem, sizeof(float), + cudaMemcpyHostToDevice); + float one_tem = 1; + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc((void**)&alpha_one_, sizeof(float))); + cudaMemcpyAsync(alpha_one_, &one_tem, sizeof(float), + cudaMemcpyHostToDevice); + } else if (type_ == nvinfer1::DataType::kHALF) { + half alpha_tem = static_cast(alpha_); + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMalloc((void**)&alpha_scale_, sizeof(half))); + cudaMemcpyAsync(alpha_scale_, &alpha_tem, sizeof(half), + cudaMemcpyHostToDevice); + half zero_tem = static_cast(zero); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc((void**)&alpha_zero_, sizeof(half))); + cudaMemcpyAsync(alpha_zero_, &zero_tem, sizeof(half), + cudaMemcpyHostToDevice); + } else { + float alpha_tem = alpha_; + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMalloc((void**)&alpha_scale_, sizeof(float))); + cudaMemcpyAsync(alpha_scale_, &alpha_tem, sizeof(float), + cudaMemcpyHostToDevice); + float zero_tem = zero; + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMalloc((void**)&alpha_zero_, sizeof(float))); + cudaMemcpyAsync(alpha_zero_, &zero_tem, sizeof(float), + cudaMemcpyHostToDevice); + } +} + +void MatmulPluginDynamic::attachToContext( + cudnnContext* cudnnContext, cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT { + dyl::cublasLtCreate(&cublas_); +} + +void MatmulPluginDynamic::detachFromContext() TRT_NOEXCEPT { + 
dyl::cublasLtDestroy(cublas_); +} + +// When tensorrt engine freed ,there is "double free" ERROR. TODO@Wangzheee +void MatmulPluginDynamic::terminate() TRT_NOEXCEPT { + /*if(alpha_scale_){ + cudaFree((void *)alpha_scale_); + alpha_scale_ = nullptr; + } + if(alpha_zero_){ + cudaFree((void *)alpha_zero_); + alpha_zero_ = nullptr; + } + if(alpha_one_){ + cudaFree((void *)alpha_one_); + alpha_one_ = nullptr; + } + if(Atransform_){ + cudaFree((void *)Atransform_); + Atransform_ = nullptr; + } + if(Btransform_){ + cudaFree((void *)Btransform_); + Btransform_ = nullptr; + } + if(Ctransform_){ + cudaFree((void *)Ctransform_); + Ctransform_ = nullptr; + } */ +} + +int MatmulPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, void* workspace, + cudaStream_t stream) TRT_NOEXCEPT { + const auto Input0Desc = inputDesc[0]; + const auto Input1Desc = inputDesc[1]; + uint64_t m, n, k; + if (transB_) { + m = Input0Desc.dims.d[Input0Desc.dims.nbDims - 1]; + k = Input0Desc.dims.d[Input0Desc.dims.nbDims - 2]; + } else { + m = Input0Desc.dims.d[Input0Desc.dims.nbDims - 2]; + k = Input0Desc.dims.d[Input0Desc.dims.nbDims - 1]; + } + if (transA_) { + n = Input1Desc.dims.d[Input1Desc.dims.nbDims - 2]; + } else { + n = Input1Desc.dims.d[Input1Desc.dims.nbDims - 1]; + } + + int batch = 1; + for (int i = 0; i < Input0Desc.dims.nbDims - 2; i++) { + batch *= Input0Desc.dims.d[i]; + } + int const ldatransform = 32 * n; + int const ldbtransform = 32 * ((m + 8 - 1) / 8 * 8); + int const ldctransform = 32 * n; + + int64_t stridea = k * n; + int64_t strideb = k * m; + int64_t stridec = m * n; + + cublasOperation_t AopTranspose, BopTranspose; + if (transA_) { + AopTranspose = CUBLAS_OP_T; + } else { + AopTranspose = CUBLAS_OP_N; + } + if (transB_) { + BopTranspose = CUBLAS_OP_T; + } else { + BopTranspose = CUBLAS_OP_N; + } + + cublasLtMatrixLayout_t Adesc{nullptr}, Bdesc{nullptr}, Cdesc{nullptr}; + cublasLtMatmulDesc_t matmulDesc{nullptr}; + cublasLtMatrixLayout_t AtransformDesc{nullptr}, BtransformDesc{nullptr}, + CtransformDesc{nullptr}; + int8_t *Atransform{nullptr}, *Btransform{nullptr}, *Ctransform{nullptr}; + cublasLtMatrixTransformDesc_t transformDescT{nullptr}, + transformDescN{nullptr}; + if (type_ == nvinfer1::DataType::kINT8) { + cudaDataType_t cudadataTypeIO = CUDA_R_8I; + cudaDataType_t cudaDataTypeS = CUDA_R_32F; +#if CUBLAS_VER_MAJOR < 11 + cudaDataType_t cudaComputeType = CUDA_R_32I; +#else + cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32I; +#endif + cublasLtOrder_t COL32 = CUBLASLT_ORDER_COL32; + cublasLtOrder_t COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C; + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &Adesc, cudadataTypeIO, AopTranspose == CUBLAS_OP_N ? n : k, + AopTranspose == CUBLAS_OP_N ? k : n, + AopTranspose == CUBLAS_OP_N ? n : k)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch), sizeof(batch))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(stridea), + sizeof(stridea))); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &Bdesc, cudadataTypeIO, BopTranspose == CUBLAS_OP_N ? k : m, + BopTranspose == CUBLAS_OP_N ? 
m : k, + BopTranspose == CUBLAS_OP_N ? k : m)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch), sizeof(batch))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(strideb), + sizeof(strideb))); + + PADDLE_ENFORCE_CUDA_SUCCESS( + dyl::cublasLtMatrixLayoutCreate(&Cdesc, cudadataTypeIO, n, m, n)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch), sizeof(batch))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(stridec), + sizeof(stridec))); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &AtransformDesc, cudadataTypeIO, n, k, ldatransform)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + AtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &COL32, sizeof(COL32))); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &BtransformDesc, cudadataTypeIO, m, k, ldbtransform)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + BtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &COL4_4R2_8C, + sizeof(COL4_4R2_8C))); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &CtransformDesc, cudadataTypeIO, n, m, ldctransform)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + CtransformDesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &COL32, sizeof(COL32))); + + cublasOperation_t Transpose = CUBLAS_OP_T; + cublasLtPointerMode_t transform_model = CUBLASLT_POINTER_MODE_DEVICE; + PADDLE_ENFORCE_CUDA_SUCCESS( + dyl::cublasLtMatrixTransformDescCreate(&transformDescT, cudaDataTypeS)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransformDescSetAttribute( + transformDescT, CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE, + &cudaDataTypeS, sizeof(cudaDataTypeS))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransformDescSetAttribute( + transformDescT, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &Transpose, + sizeof(Transpose))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransformDescSetAttribute( + transformDescT, CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, + &transform_model, sizeof(transform_model))); + + PADDLE_ENFORCE_CUDA_SUCCESS( + dyl::cublasLtMatrixTransformDescCreate(&transformDescN, cudaDataTypeS)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransformDescSetAttribute( + transformDescN, CUBLASLT_MATRIX_TRANSFORM_DESC_SCALE_TYPE, + &cudaDataTypeS, sizeof(cudaDataTypeS))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixTransformDescSetAttribute( + transformDescN, CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, + &transform_model, 
sizeof(transform_model))); + + cublasOperation_t ATranspose = CUBLAS_OP_N, BTranspose = CUBLAS_OP_T; + cublasLtPointerMode_t matmul_model = + CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + +#if CUBLAS_VER_MAJOR < 11 + PADDLE_ENFORCE_CUDA_SUCCESS( + dyl::cublasLtMatmulDescCreate(&matmulDesc, cudaComputeType)); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescCreate( + &matmulDesc, cudaComputeType, cudaDataTypeS)); +#endif + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &ATranspose, + sizeof(ATranspose))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &BTranspose, + sizeof(BTranspose))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &matmul_model, + sizeof(matmul_model))); + + const int8_t* B = static_cast(inputs[0]); + const int8_t* A = static_cast(inputs[1]); + int8_t* C = static_cast(outputs[0]); + Ltgemm_int8_linear(cublas_, A, Adesc, Atransform_, AtransformDesc, transA_, + B, Bdesc, Btransform_, BtransformDesc, transB_, C, Cdesc, + Ctransform_, CtransformDesc, transformDescT, + transformDescN, matmulDesc, alpha_scale_, alpha_zero_, + alpha_one_, workspace, stream); + } else if (type_ == nvinfer1::DataType::kHALF) { + cudaDataType_t cudadataTypeIO = CUDA_R_16F; + cudaDataType_t cudaDataTypeS = CUDA_R_16F; +#if CUBLAS_VER_MAJOR < 11 + cudaDataType_t cudaComputeType = CUDA_R_16F; +#else + cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_16F; +#endif + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &Adesc, cudadataTypeIO, AopTranspose == CUBLAS_OP_N ? n : k, + AopTranspose == CUBLAS_OP_N ? k : n, + AopTranspose == CUBLAS_OP_N ? n : k)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch), sizeof(batch))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(stridea), + sizeof(stridea))); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &Bdesc, cudadataTypeIO, BopTranspose == CUBLAS_OP_N ? k : m, + BopTranspose == CUBLAS_OP_N ? m : k, + BopTranspose == CUBLAS_OP_N ? 
k : m)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch), sizeof(batch))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(strideb), + sizeof(strideb))); + + PADDLE_ENFORCE_CUDA_SUCCESS( + dyl::cublasLtMatrixLayoutCreate(&Cdesc, cudadataTypeIO, n, m, n)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch), sizeof(batch))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(stridec), + sizeof(stridec))); + + cublasLtPointerMode_t matmul_model = CUBLASLT_POINTER_MODE_DEVICE; + +#if CUBLAS_VER_MAJOR < 11 + PADDLE_ENFORCE_CUDA_SUCCESS( + dyl::cublasLtMatmulDescCreate(&matmulDesc, cudaComputeType)); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescCreate( + &matmulDesc, cudaComputeType, cudaDataTypeS)); +#endif + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &AopTranspose, + sizeof(AopTranspose))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &BopTranspose, + sizeof(BopTranspose))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &matmul_model, + sizeof(matmul_model))); + + const half* B = static_cast(inputs[0]); + const half* A = static_cast(inputs[1]); + half* C = static_cast(outputs[0]); + Ltgemm_fp16_linear(cublas_, A, Adesc, B, Bdesc, C, Cdesc, matmulDesc, + alpha_scale_, alpha_zero_, workspace, stream); + } else { + cudaDataType_t cudadataTypeIO = CUDA_R_32F; + cudaDataType_t cudaDataTypeS = CUDA_R_32F; +#if CUBLAS_VER_MAJOR < 11 + cudaDataType_t cudaComputeType = CUDA_R_32F; +#else + cublasComputeType_t cudaComputeType = CUBLAS_COMPUTE_32F_FAST_16F; +#endif + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &Adesc, cudadataTypeIO, AopTranspose == CUBLAS_OP_N ? n : k, + AopTranspose == CUBLAS_OP_N ? k : n, + AopTranspose == CUBLAS_OP_N ? n : k)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch), sizeof(batch))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Adesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(stridea), + sizeof(stridea))); + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutCreate( + &Bdesc, cudadataTypeIO, BopTranspose == CUBLAS_OP_N ? k : m, + BopTranspose == CUBLAS_OP_N ? m : k, + BopTranspose == CUBLAS_OP_N ? 
k : m)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch), sizeof(batch))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Bdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(strideb), + sizeof(strideb))); + + PADDLE_ENFORCE_CUDA_SUCCESS( + dyl::cublasLtMatrixLayoutCreate(&Cdesc, cudadataTypeIO, n, m, n)); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc, CUBLASLT_MATRIX_LAYOUT_TYPE, &cudadataTypeIO, + sizeof(cudadataTypeIO))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &(batch), sizeof(batch))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatrixLayoutSetAttribute( + Cdesc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &(stridec), + sizeof(stridec))); + + cublasLtPointerMode_t matmul_model = CUBLASLT_POINTER_MODE_DEVICE; + +#if CUBLAS_VER_MAJOR < 11 + PADDLE_ENFORCE_CUDA_SUCCESS( + dyl::cublasLtMatmulDescCreate(&matmulDesc, cudaComputeType)); +#else + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescCreate( + &matmulDesc, cudaComputeType, cudaDataTypeS)); +#endif + + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &AopTranspose, + sizeof(AopTranspose))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &BopTranspose, + sizeof(BopTranspose))); + PADDLE_ENFORCE_CUDA_SUCCESS(dyl::cublasLtMatmulDescSetAttribute( + matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &matmul_model, + sizeof(matmul_model))); + + const float* B = static_cast(inputs[0]); + const float* A = static_cast(inputs[1]); + float* C = static_cast(outputs[0]); + Ltgemm_fp32_linear(cublas_, A, Adesc, B, Bdesc, C, Cdesc, matmulDesc, + alpha_scale_, alpha_zero_, workspace, stream); + } + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/matmul_op_int8_plugin.h b/paddle/fluid/inference/tensorrt/plugin/matmul_op_int8_plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..be8f1c418fc7faa4b72f9f0a0a076ac69376996e --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/matmul_op_int8_plugin.h @@ -0,0 +1,432 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include + +#include +#include +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/platform/dynload/cublasLt.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class MatmulPlugin : public nvinfer1::IPluginV2IOExt { + public: + MatmulPlugin(nvinfer1::Dims const& dims_x, nvinfer1::Dims const& dims_y, + bool transA, bool transB, float alpha) + : dims_x_(dims_x), + dims_y_(dims_y), + transB_(transA), + transA_(transB), + alpha_(alpha) {} + + MatmulPlugin(void const* serial_data, size_t serial_length) { + DeserializeValue(&serial_data, &serial_length, &dims_x_); + DeserializeValue(&serial_data, &serial_length, &dims_y_); + DeserializeValue(&serial_data, &serial_length, &transB_); + DeserializeValue(&serial_data, &serial_length, &transA_); + DeserializeValue(&serial_data, &serial_length, &alpha_); + DeserializeValue(&serial_data, &serial_length, &alpha_scale_); + DeserializeValue(&serial_data, &serial_length, &alpha_one_); + DeserializeValue(&serial_data, &serial_length, &alpha_zero_); + DeserializeValue(&serial_data, &serial_length, &batch_); + DeserializeValue(&serial_data, &serial_length, &k_); + DeserializeValue(&serial_data, &serial_length, &m_); + DeserializeValue(&serial_data, &serial_length, &n_); + DeserializeValue(&serial_data, &serial_length, &cublas_); + DeserializeValue(&serial_data, &serial_length, &type_); + DeserializeValue(&serial_data, &serial_length, &Adesc_); + DeserializeValue(&serial_data, &serial_length, &Bdesc_); + DeserializeValue(&serial_data, &serial_length, &Cdesc_); + DeserializeValue(&serial_data, &serial_length, &AtransformDesc_); + DeserializeValue(&serial_data, &serial_length, &BtransformDesc_); + DeserializeValue(&serial_data, &serial_length, &CtransformDesc_); + DeserializeValue(&serial_data, &serial_length, &Atransform_); + DeserializeValue(&serial_data, &serial_length, &Btransform_); + DeserializeValue(&serial_data, &serial_length, &Ctransform_); + DeserializeValue(&serial_data, &serial_length, &transformDescT_); + DeserializeValue(&serial_data, &serial_length, &transformDescN_); + DeserializeValue(&serial_data, &serial_length, &matmulDesc_); + } + + virtual bool isOutputBroadcastAcrossBatch( + int32_t output_index, const bool* input_is_broadcasted, + int32_t nb_inputs) const TRT_NOEXCEPT { + return false; + } + + virtual bool canBroadcastInputAcrossBatch(int32_t input_index) const + TRT_NOEXCEPT { + return false; + } + + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + + size_t getWorkspaceSize(int) const TRT_NOEXCEPT override { return 0; } + + void setPluginNamespace(const char* plugin_namespace) TRT_NOEXCEPT override { + name_space_ = plugin_namespace; + } + + nvinfer1::IPluginV2IOExt* clone() const TRT_NOEXCEPT override { + MatmulPlugin* ptr = + new MatmulPlugin(dims_x_, dims_y_, transB_, transA_, alpha_); + ptr->setPluginNamespace(this->getPluginNamespace()); + ptr->batch_ = batch_; + ptr->k_ = k_; + ptr->m_ = m_; + ptr->n_ = n_; + ptr->alpha_scale_ = alpha_scale_; + ptr->alpha_one_ = alpha_one_; + ptr->alpha_zero_ = alpha_zero_; + ptr->cublas_ = cublas_; + ptr->type_ = type_; + ptr->Adesc_ = Adesc_; + ptr->Bdesc_ = Bdesc_; + ptr->Cdesc_ = Cdesc_; + ptr->AtransformDesc_ = AtransformDesc_; + ptr->BtransformDesc_ = BtransformDesc_; + ptr->CtransformDesc_ = CtransformDesc_; + ptr->Atransform_ = Atransform_; + ptr->Btransform_ = Btransform_; + ptr->Ctransform_ = Ctransform_; + 
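+    // clone() shares the cuBLASLt descriptors and device-side buffers with
+    // the original plugin by copying the raw handles rather than re-creating them.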
ptr->transformDescT_ = transformDescT_; + ptr->transformDescN_ = transformDescN_; + ptr->matmulDesc_ = matmulDesc_; + return ptr; + } + + const char* getPluginNamespace() const TRT_NOEXCEPT override { + return name_space_.c_str(); + } + + const char* getPluginType() const TRT_NOEXCEPT override { + return "matmul_int8_plugin"; + } + + nvinfer1::DataType getOutputDataType( + int index, const nvinfer1::DataType* input_types, + int nb_inputs) const TRT_NOEXCEPT override; + + int getNbOutputs() const TRT_NOEXCEPT override { return 1; } + + nvinfer1::Dims getOutputDimensions(int index, + const nvinfer1::Dims* input_dims, + int num_inputs) TRT_NOEXCEPT override; + + bool supportsFormatCombination(int32_t pos, + nvinfer1::PluginTensorDesc const* inOut, + int32_t nbInputs, + int32_t nbOutputs) const TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::PluginTensorDesc* in, int32_t nbInputs, + const nvinfer1::PluginTensorDesc* out, + int32_t nbOutputs) TRT_NOEXCEPT override; + + /* + bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format) + const TRT_NOEXCEPT override; + */ + int initialize() TRT_NOEXCEPT { return 0; } + void terminate() TRT_NOEXCEPT; + +#if IS_TRT_VERSION_LT(8000) + int enqueue(int batch_size, const void* const* inputs, void** outputs, +#else + int enqueue(int batch_size, const void* const* inputs, void* const* outputs, +#endif + void* workspace, cudaStream_t stream) TRT_NOEXCEPT override; + + void destroy() TRT_NOEXCEPT override { delete this; } + void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) + TRT_NOEXCEPT override; + void detachFromContext() TRT_NOEXCEPT override; + + protected: + nvinfer1::Dims dims_x_; + nvinfer1::Dims dims_y_; + bool transB_; + bool transA_; + float alpha_; + void *alpha_scale_{nullptr}, *alpha_one_{nullptr}, *alpha_zero_{nullptr}; + int batch_; + uint64_t k_; + uint64_t m_; + uint64_t n_; + cublasLtHandle_t cublas_{nullptr}; + nvinfer1::DataType type_; + cublasLtMatrixLayout_t Adesc_{nullptr}, Bdesc_{nullptr}, Cdesc_{nullptr}; + cublasLtMatrixLayout_t AtransformDesc_{nullptr}, BtransformDesc_{nullptr}, + CtransformDesc_{nullptr}; + int8_t *Atransform_{nullptr}, *Btransform_{nullptr}, *Ctransform_{nullptr}; + cublasLtMatrixTransformDesc_t transformDescT_{nullptr}, + transformDescN_{nullptr}; + cublasLtMatmulDesc_t matmulDesc_{nullptr}; + std::string name_space_; + + size_t getSerializationSize() const TRT_NOEXCEPT override { + return SerializedSize(dims_x_) + SerializedSize(dims_y_) + + SerializedSize(transB_) + SerializedSize(transA_) + + SerializedSize(alpha_) + SerializedSize(alpha_scale_) + + SerializedSize(alpha_one_) + SerializedSize(alpha_zero_) + + SerializedSize(batch_) + SerializedSize(k_) + SerializedSize(m_) + + SerializedSize(n_) + SerializedSize(cublas_) + + SerializedSize(type_) + SerializedSize(Adesc_) + + SerializedSize(Bdesc_) + SerializedSize(Cdesc_) + + SerializedSize(AtransformDesc_) + SerializedSize(BtransformDesc_) + + SerializedSize(CtransformDesc_) + SerializedSize(Atransform_) + + SerializedSize(Btransform_) + SerializedSize(Ctransform_) + + SerializedSize(transformDescT_) + SerializedSize(transformDescN_) + + SerializedSize(matmulDesc_); + } + + void serialize(void* buffer) const TRT_NOEXCEPT override { + SerializeValue(&buffer, dims_x_); + SerializeValue(&buffer, dims_y_); + SerializeValue(&buffer, transB_); + SerializeValue(&buffer, transA_); + SerializeValue(&buffer, alpha_); + SerializeValue(&buffer, alpha_scale_); + 
+    SerializeValue(&buffer, alpha_one_);
+    SerializeValue(&buffer, alpha_zero_);
+    SerializeValue(&buffer, batch_);
+    SerializeValue(&buffer, k_);
+    SerializeValue(&buffer, m_);
+    SerializeValue(&buffer, n_);
+    SerializeValue(&buffer, cublas_);
+    SerializeValue(&buffer, type_);
+    SerializeValue(&buffer, Adesc_);
+    SerializeValue(&buffer, Bdesc_);
+    SerializeValue(&buffer, Cdesc_);
+    SerializeValue(&buffer, AtransformDesc_);
+    SerializeValue(&buffer, BtransformDesc_);
+    SerializeValue(&buffer, CtransformDesc_);
+    SerializeValue(&buffer, Atransform_);
+    SerializeValue(&buffer, Btransform_);
+    SerializeValue(&buffer, Ctransform_);
+    SerializeValue(&buffer, transformDescT_);
+    SerializeValue(&buffer, transformDescN_);
+    SerializeValue(&buffer, matmulDesc_);
+  }
+};
+
+class MatmulPluginCreator : public nvinfer1::IPluginCreator {
+ public:
+  MatmulPluginCreator() {}
+  const char* getPluginName() const TRT_NOEXCEPT override {
+    return "matmul_int8_plugin";
+  }
+  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
+
+  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
+    return &field_collection_;
+  }
+
+  nvinfer1::IPluginV2IOExt* createPlugin(
+      const char* name,
+      const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override {
+    return nullptr;
+  }
+
+  nvinfer1::IPluginV2IOExt* deserializePlugin(
+      const char* name, void const* serial_data,
+      size_t serial_length) TRT_NOEXCEPT override {
+    MatmulPlugin* obj = new MatmulPlugin(serial_data, serial_length);
+    obj->setPluginNamespace(name);
+    return obj;
+  }
+
+  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
+    plugin_namespace_ = lib_namespace;
+  }
+
+  const char* getPluginNamespace() const TRT_NOEXCEPT override {
+    return plugin_namespace_.c_str();
+  }
+
+ private:
+  std::string plugin_namespace_;
+  std::string plugin_name_;
+  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
+  std::vector<nvinfer1::PluginField> plugin_attributes_;
+};
+REGISTER_TRT_PLUGIN_V2(MatmulPluginCreator);
+
+#if IS_TRT_VERSION_GE(6000)
+class MatmulPluginDynamic : public DynamicPluginTensorRT {
+ public:
+  MatmulPluginDynamic(bool transA, bool transB, float alpha)
+      : transB_(transA), transA_(transB), alpha_(alpha) {}
+
+  MatmulPluginDynamic(void const* serial_data, size_t serial_length) {
+    DeserializeValue(&serial_data, &serial_length, &transB_);
+    DeserializeValue(&serial_data, &serial_length, &transA_);
+    DeserializeValue(&serial_data, &serial_length, &alpha_);
+    DeserializeValue(&serial_data, &serial_length, &alpha_scale_);
+    DeserializeValue(&serial_data, &serial_length, &alpha_one_);
+    DeserializeValue(&serial_data, &serial_length, &alpha_zero_);
+    DeserializeValue(&serial_data, &serial_length, &cublas_);
+    DeserializeValue(&serial_data, &serial_length, &Atransform_);
+    DeserializeValue(&serial_data, &serial_length, &Btransform_);
+    DeserializeValue(&serial_data, &serial_length, &Ctransform_);
+    DeserializeValue(&serial_data, &serial_length, &type_);
+  }
+
+  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
+    MatmulPluginDynamic* ptr =
+        new MatmulPluginDynamic(transB_, transA_, alpha_);
+    ptr->setPluginNamespace(this->getPluginNamespace());
+    ptr->alpha_scale_ = alpha_scale_;
+    ptr->alpha_one_ = alpha_one_;
+    ptr->alpha_zero_ = alpha_zero_;
+    ptr->cublas_ = cublas_;
+    ptr->Atransform_ = Atransform_;
+    ptr->Btransform_ = Btransform_;
+    ptr->Ctransform_ = Ctransform_;
+    ptr->type_ = type_;
+    return ptr;
+  }
+
+  const char* getPluginType() const TRT_NOEXCEPT override {
+    return "matmul_int8_dynamic_plugin";
+  }
+
+  int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
+
+  int initialize() TRT_NOEXCEPT { return 0; }
+  void terminate() TRT_NOEXCEPT;
+  nvinfer1::DimsExprs getOutputDimensions(
+      int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
+      nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT override;
+
+  bool supportsFormatCombination(int pos,
+                                 const nvinfer1::PluginTensorDesc* inOut,
+                                 int nbInputs,
+                                 int nbOutputs) TRT_NOEXCEPT override;
+
+  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
+                       int nbInputs,
+                       const nvinfer1::DynamicPluginTensorDesc* outputs,
+                       int nbOutputs) TRT_NOEXCEPT override;
+
+  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
+                          int nbInputs,
+                          const nvinfer1::PluginTensorDesc* outputs,
+                          int nbOutputs) const TRT_NOEXCEPT override {
+    return 0;
+  }
+
+  void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext,
+                       nvinfer1::IGpuAllocator* gpuAllocator)
+      TRT_NOEXCEPT override;
+
+  void detachFromContext() TRT_NOEXCEPT override;
+
+  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+              const nvinfer1::PluginTensorDesc* outputDesc,
+              const void* const* inputs, void* const* outputs, void* workspace,
+              cudaStream_t stream) TRT_NOEXCEPT override;
+  nvinfer1::DataType getOutputDataType(
+      int index, const nvinfer1::DataType* inputTypes,
+      int nbInputs) const TRT_NOEXCEPT override;
+
+  void destroy() TRT_NOEXCEPT override { delete this; }
+
+ protected:
+  bool transB_;
+  bool transA_;
+  float alpha_;
+  void *alpha_scale_{nullptr}, *alpha_one_{nullptr}, *alpha_zero_{nullptr};
+  cublasLtHandle_t cublas_{nullptr};
+  nvinfer1::DataType type_;
+  int8_t *Atransform_{nullptr}, *Btransform_{nullptr}, *Ctransform_{nullptr};
+  std::string name_space_;
+
+  size_t getSerializationSize() const TRT_NOEXCEPT override {
+    return SerializedSize(transB_) + SerializedSize(transA_) +
+           SerializedSize(alpha_) + SerializedSize(alpha_scale_) +
+           SerializedSize(alpha_one_) + SerializedSize(alpha_zero_) +
+           SerializedSize(Atransform_) + SerializedSize(Btransform_) +
+           SerializedSize(Ctransform_) + SerializedSize(cublas_) +
+           SerializedSize(type_);
+  }
+
+  void serialize(void* buffer) const TRT_NOEXCEPT override {
+    SerializeValue(&buffer, transB_);
+    SerializeValue(&buffer, transA_);
+    SerializeValue(&buffer, alpha_);
+    SerializeValue(&buffer, alpha_scale_);
+    SerializeValue(&buffer, alpha_one_);
+    SerializeValue(&buffer, alpha_zero_);
+    SerializeValue(&buffer, Atransform_);
+    SerializeValue(&buffer, Btransform_);
+    SerializeValue(&buffer, Ctransform_);
+    SerializeValue(&buffer, cublas_);
+    SerializeValue(&buffer, type_);
+  }
+};
+
+class MatmulPluginDynamicCreator : public nvinfer1::IPluginCreator {
+ public:
+  MatmulPluginDynamicCreator() {}
+  const char* getPluginName() const TRT_NOEXCEPT override {
+    return "matmul_int8_dynamic_plugin";
+  }
+  const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
+
+  const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
+    return &field_collection_;
+  }
+
+  nvinfer1::IPluginV2* createPlugin(const char* name,
+                                    const nvinfer1::PluginFieldCollection* fc)
+      TRT_NOEXCEPT override {
+    return nullptr;
+  }
+
+  nvinfer1::IPluginV2* deserializePlugin(
+      const char* name, void const* serial_data,
+      size_t serial_length) TRT_NOEXCEPT override {
+    MatmulPluginDynamic* obj =
+        new MatmulPluginDynamic(serial_data, serial_length);
+    obj->setPluginNamespace(name);
+    return obj;
+  }
+
+  void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
+    plugin_namespace_ = lib_namespace;
+  }
+
+  const char* getPluginNamespace() const TRT_NOEXCEPT override {
+    return plugin_namespace_.c_str();
+  }
+
+ private:
+  std::string plugin_namespace_;
+  std::string plugin_name_;
+  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
+  std::vector<nvinfer1::PluginField> plugin_attributes_;
+};
+REGISTER_TRT_PLUGIN_V2(MatmulPluginDynamicCreator);
+#endif
+} // namespace plugin
+} // namespace tensorrt
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt
index b396caf54a45a83a0acb51110663464a5bb84641..2be58376e309949e6d5818c67c96d04973fa85fa 100644
--- a/paddle/fluid/platform/dynload/CMakeLists.txt
+++ b/paddle/fluid/platform/dynload/CMakeLists.txt
@@ -1,6 +1,6 @@
 cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
 
-list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc cusparse.cc nvtx.cc cufft.cc)
+list(APPEND CUDA_SRCS cublas.cc cublasLt.cc cudnn.cc curand.cc cusolver.cc cusparse.cc nvtx.cc cufft.cc)
 
 if (NOT WITH_NV_JETSON)
   list(APPEND CUDA_SRCS nvjpeg.cc)
diff --git a/paddle/fluid/platform/dynload/cublasLt.cc b/paddle/fluid/platform/dynload/cublasLt.cc
new file mode 100644
index 0000000000000000000000000000000000000000..78f952985c8117c6832be0af2c657dc6a9502d41
--- /dev/null
+++ b/paddle/fluid/platform/dynload/cublasLt.cc
@@ -0,0 +1,29 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/platform/dynload/cublasLt.h"
+
+namespace paddle {
+namespace platform {
+namespace dynload {
+std::once_flag cublasLt_dso_flag;
+void *cublasLt_dso_handle = nullptr;
+
+#define DEFINE_WRAP(__name) DynLoad__##__name __name
+
+CUBLASLT_BLAS_ROUTINE_EACH(DEFINE_WRAP);
+
+} // namespace dynload
+} // namespace platform
+} // namespace paddle
diff --git a/paddle/fluid/platform/dynload/cublasLt.h b/paddle/fluid/platform/dynload/cublasLt.h
new file mode 100644
index 0000000000000000000000000000000000000000..f4e04c94e04c615dce496ff0c95064b6326880f7
--- /dev/null
+++ b/paddle/fluid/platform/dynload/cublasLt.h
@@ -0,0 +1,78 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
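// [Reviewer note, not part of this patch] This header pairs with cublasLt.cc
// above: the macro below declares one DynLoad__<name> functor per cuBLASLt
// entry point, cublasLt.cc defines the matching objects, and the library is
// resolved lazily through dlopen/dlsym on first use. Illustrative call site
// (assumed usage, not taken from this PR):
//
//   #include "paddle/fluid/platform/dynload/cublasLt.h"
//
//   cublasLtHandle_t handle = nullptr;
//   paddle::platform::dynload::cublasLtCreate(&handle);
//   // ... create descriptors and call paddle::platform::dynload::cublasLtMatmul ...
//   paddle::platform::dynload::cublasLtDestroy(handle);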
+
+#pragma once
+
+#include
+#include
+#include <mutex>  // NOLINT
+#include
+
+#include "paddle/fluid/platform/dynload/dynamic_loader.h"
+#include "paddle/fluid/platform/port.h"
+
+namespace paddle {
+namespace platform {
+namespace dynload {
+
+extern std::once_flag cublasLt_dso_flag;
+extern void *cublasLt_dso_handle;
+
+/**
+ * The following macro definition can generate structs
+ * (for each function) to dynamically load the cublasLt routines
+ * via operator overloading.
+ *
+ * note: default dynamic linked libs
+ */
+#define DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP(__name)                          \
+  struct DynLoad__##__name {                                                \
+    template <typename... Args>                                             \
+    inline auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \
+      using cublasLt_func =                                                 \
+          decltype(::__name(std::declval<Args>()...)) (*)(Args...);         \
+      std::call_once(cublasLt_dso_flag, []() {                              \
+        cublasLt_dso_handle =                                               \
+            paddle::platform::dynload::GetCublasLtDsoHandle();              \
+      });                                                                   \
+      static void *p_##__name = dlsym(cublasLt_dso_handle, #__name);        \
+      return reinterpret_cast<cublasLt_func>(p_##__name)(args...);          \
+    }                                                                       \
+  };                                                                        \
+  extern DynLoad__##__name __name
+
+// APIs available after CUDA 10.1
+// #if CUDA_VERSION >= 10100
+#define CUBLASLT_BLAS_ROUTINE_EACH(__macro)     \
+  __macro(cublasLtCreate);                      \
+  __macro(cublasLtDestroy);                     \
+  __macro(cublasLtMatmul);                      \
+  __macro(cublasLtMatmulDescCreate);            \
+  __macro(cublasLtMatmulDescDestroy);           \
+  __macro(cublasLtMatmulDescSetAttribute);      \
+  __macro(cublasLtMatrixLayoutCreate);          \
+  __macro(cublasLtMatrixLayoutDestroy);         \
+  __macro(cublasLtMatrixLayoutSetAttribute);    \
+  __macro(cublasLtMatrixTransform);             \
+  __macro(cublasLtMatrixTransformDescCreate);   \
+  __macro(cublasLtMatrixTransformDescDestroy);  \
+  __macro(cublasLtMatrixTransformDescSetAttribute);
+
+CUBLASLT_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP)
+// #endif
+
+#undef DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP
+} // namespace dynload
+} // namespace platform
+} // namespace paddle
diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc
index 544c1c194d996991f9c69fed4f781cbceba31d42..905f1aea887ab8ef4d971f7697d366dd8c89b8d7 100644
--- a/paddle/fluid/platform/dynload/dynamic_loader.cc
+++ b/paddle/fluid/platform/dynload/dynamic_loader.cc
@@ -30,10 +30,11 @@
 DEFINE_string(cudnn_dir, "",
               "/usr/local/cudnn/lib. If empty [default], dlopen "
               "will search cudnn from LD_LIBRARY_PATH");
-DEFINE_string(cuda_dir, "",
-              "Specify path for loading cuda library, such as libcublas, "
-              "libcurand, libcusolver. For instance, /usr/local/cuda/lib64. "
-              "If default, dlopen will search cuda from LD_LIBRARY_PATH");
+DEFINE_string(
+    cuda_dir, "",
+    "Specify path for loading cuda library, such as libcublas, libcublasLt, "
+    "libcurand, libcusolver. For instance, /usr/local/cuda/lib64. "
+    "If default, dlopen will search cuda from LD_LIBRARY_PATH");
 
 DEFINE_string(nccl_dir, "",
               "Specify path for loading nccl library, such as libnccl.so. "
@@ -308,6 +309,19 @@ void* GetCublasDsoHandle() {
 #endif
 }
 
+void* GetCublasLtDsoHandle() {
+// APIs available after CUDA 10.1
+#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10100
+  return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcublasLt.so");
+#else
+  std::string warning_msg(
+      "Your CUDA_VERSION is less than 10.1, so CublasLt is not supported. "
+      "If you want to use CublasLt, please upgrade CUDA and rebuild "
+      "PaddlePaddle.");
+  return nullptr;
+#endif
+}
+
 void* GetCUDNNDsoHandle() {
 #if defined(__APPLE__) || defined(__OSX__)
   std::string mac_warn_meg(
diff --git a/paddle/fluid/platform/dynload/dynamic_loader.h b/paddle/fluid/platform/dynload/dynamic_loader.h
index 1a66f4b979207e6947c2026fd44acd0a3b7c8e62..ca60cd76a59e109b3e5891e52fcecd1ffcd1a723 100644
--- a/paddle/fluid/platform/dynload/dynamic_loader.h
+++ b/paddle/fluid/platform/dynload/dynamic_loader.h
@@ -26,6 +26,7 @@ namespace dynload {
 #endif
 
 void* GetCublasDsoHandle();
+void* GetCublasLtDsoHandle();
 void* GetCUDNNDsoHandle();
 void* GetCUPTIDsoHandle();
 void* GetCurandDsoHandle();
diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh
index c238512ca8e635ce641c8e38e7b9961f53c70647..b55fce5befdda7e579c5b98df82d2ba5096206bf 100755
--- a/paddle/scripts/paddle_build.sh
+++ b/paddle/scripts/paddle_build.sh
@@ -1997,10 +1997,12 @@ function gen_dockerfile() {
     DOCKERFILE_GPU_ENV=""
     DOCKERFILE_CUDNN_DSO=""
     DOCKERFILE_CUBLAS_DSO=""
+    DOCKERFILE_CUBLASLT_DSO=""
     if [[ ${WITH_GPU:-OFF} == 'ON' ]]; then
         DOCKERFILE_GPU_ENV="ENV LD_LIBRARY_PATH /usr/lib/x86_64-linux-gnu:\${LD_LIBRARY_PATH}"
         DOCKERFILE_CUDNN_DSO="RUN ln -sf /usr/lib/x86_64-linux-gnu/libcudnn.so.${CUDNN_MAJOR} /usr/lib/x86_64-linux-gnu/libcudnn.so"
         DOCKERFILE_CUBLAS_DSO="RUN ln -sf /usr/local/cuda/targets/x86_64-linux/lib/libcublas.so.${CUDA_MAJOR} /usr/lib/x86_64-linux-gnu/libcublas.so"
+        DOCKERFILE_CUBLASLT_DSO="RUN ln -sf /usr/local/cuda/targets/x86_64-linux/lib/libcublasLt.so /usr/lib/x86_64-linux-gnu/libcublasLt.so"
     fi
     cat <> ${PADDLE_ROOT}/build/Dockerfile <