From 20b38cfa89dfb20570a889bc25dabf7fcce58f63 Mon Sep 17 00:00:00 2001 From: minghaoBD <79566150+minghaoBD@users.noreply.github.com> Date: Thu, 9 Jun 2022 19:31:37 +0800 Subject: [PATCH] [sparse inference] Supporting 2:4 sparse inference (#43179) --- CMakeLists.txt | 1 + cmake/external/cusparselt.cmake | 61 ++ cmake/inference_lib.cmake | 8 + cmake/third_party.cmake | 5 + paddle/fluid/framework/ir/CMakeLists.txt | 10 + .../framework/ir/dense_fc_to_sparse_pass.cc | 148 +++ .../framework/ir/dense_fc_to_sparse_pass.h | 61 ++ .../ir/dense_fc_to_sparse_pass_tester.cc | 107 ++ .../dense_multihead_matmul_to_sparse_pass.cc | 174 ++++ .../dense_multihead_matmul_to_sparse_pass.h | 61 ++ ..._multihead_matmul_to_sparse_pass_tester.cc | 152 +++ .../fluid/inference/api/analysis_predictor.cc | 4 + .../inference/api/paddle_pass_builder.cc | 6 +- .../fluid/inference/tensorrt/CMakeLists.txt | 7 +- .../inference/tensorrt/convert/CMakeLists.txt | 127 +-- .../tensorrt/convert/sparse_fc_op.cc | 371 +++++++ .../convert/sparse_multihead_matmul_op.cc | 441 +++++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 16 + .../inference/tensorrt/plugin/CMakeLists.txt | 63 +- .../inference/tensorrt/plugin/spmm_plugin.cu | 923 ++++++++++++++++++ .../inference/tensorrt/plugin/spmm_plugin.h | 158 +++ .../inference/tensorrt/test_dynamic_engine.cc | 175 ++++ paddle/fluid/platform/dynload/CMakeLists.txt | 4 + paddle/fluid/platform/dynload/cusparseLt.cc | 29 + paddle/fluid/platform/dynload/cusparseLt.h | 60 ++ .../fluid/platform/dynload/dynamic_loader.cc | 4 + .../fluid/platform/dynload/dynamic_loader.h | 1 + paddle/phi/backends/dynload/CMakeLists.txt | 4 + paddle/phi/backends/dynload/cusparseLt.cc | 28 + paddle/phi/backends/dynload/cusparseLt.h | 78 ++ paddle/phi/backends/dynload/dynamic_loader.cc | 15 + paddle/phi/backends/dynload/dynamic_loader.h | 1 + 32 files changed, 3212 insertions(+), 91 deletions(-) create mode 100644 cmake/external/cusparselt.cmake create mode 100644 paddle/fluid/framework/ir/dense_fc_to_sparse_pass.cc create mode 100644 paddle/fluid/framework/ir/dense_fc_to_sparse_pass.h create mode 100644 paddle/fluid/framework/ir/dense_fc_to_sparse_pass_tester.cc create mode 100644 paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.cc create mode 100644 paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.h create mode 100644 paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass_tester.cc create mode 100644 paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc create mode 100644 paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc create mode 100644 paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h create mode 100644 paddle/fluid/inference/tensorrt/test_dynamic_engine.cc create mode 100644 paddle/fluid/platform/dynload/cusparseLt.cc create mode 100644 paddle/fluid/platform/dynload/cusparseLt.h create mode 100644 paddle/phi/backends/dynload/cusparseLt.cc create mode 100644 paddle/phi/backends/dynload/cusparseLt.h diff --git a/CMakeLists.txt b/CMakeLists.txt index ba438a74718..a3e0b64e97b 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,6 +60,7 @@ option(WITH_IPU "Compile PaddlePaddle with Graphcore IPU" OFF) option(WITH_ASCEND_CL "Compile PaddlePaddle with ASCEND CL" ${WITH_ASCEND}) option(WITH_ASCEND_CXX11 "Compile PaddlePaddle with ASCEND and CXX11 ABI" OFF) option(WITH_ONNXRUNTIME "Compile PaddlePaddle with ONNXRUNTIME" OFF) +option(WITH_CUSPARSELT "Compile PaddlePaddle with 
CUSPARSELT" OFF) # Note(zhouwei): It use option above, so put here include(init) include(generic) # simplify cmake module diff --git a/cmake/external/cusparselt.cmake b/cmake/external/cusparselt.cmake new file mode 100644 index 00000000000..8ab1275cb62 --- /dev/null +++ b/cmake/external/cusparselt.cmake @@ -0,0 +1,61 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if(NOT (WITH_CUSPARSELT AND WITH_TENSORRT)) + return() +endif() + +if(WITH_ARM OR WIN32) + message(SEND_ERROR "The current sparselt support linux only") + return() +endif() + +include(ExternalProject) + +set(CUSPARSELT_PROJECT "extern_cusparselt") +set(CUSPARSELT_P "https://developer.download.nvidia.com/compute") +set(CUSPARSELT_F "libcusparse_lt-linux-x86_64-0.2.0.1.tar.gz") +set(CUSPARSELT_URL + "${CUSPARSELT_P}/libcusparse-lt/0.2.0/local_installers/${CUSPARSELT_F}" + CACHE STRING "" FORCE) +set(CUSPARSELT_PREFIX_DIR ${THIRD_PARTY_PATH}/cusparselt) +set(CUSPARSELT_INSTALL_DIR ${THIRD_PARTY_PATH}/install/cusparselt) +set(CUSPARSELT_INC_DIR + "${CUSPARSELT_INSTALL_DIR}/include" + CACHE PATH "sparselt include directory." FORCE) +set(CUSPARSELT_LIB_DIR + "${CUSPARSELT_INSTALL_DIR}/lib64" + CACHE PATH "sparselt lib directory." 
FORCE) +set_directory_properties(PROPERTIES CLEAN_NO_CUSTOM 1) +include_directories(${CUSPARSELT_INC_DIR}) + +ExternalProject_Add( + ${CUSPARSELT_PROJECT} + ${EXTERNAL_PROJECT_LOG_ARGS} + URL ${CUSPARSELT_URL} + PREFIX ${CUSPARSELT_PREFIX_DIR} + DOWNLOAD_NO_PROGRESS 1 + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND + ${CMAKE_COMMAND} -E copy_directory + ${CUSPARSELT_PREFIX_DIR}/src/extern_cusparselt/lib64 ${CUSPARSELT_LIB_DIR} + && ${CMAKE_COMMAND} -E copy_directory + ${CUSPARSELT_PREFIX_DIR}/src/extern_cusparselt/include ${CUSPARSELT_INC_DIR} + UPDATE_COMMAND "") + +add_library(cusparselt INTERFACE) +add_dependencies(cusparselt ${CUSPARSELT_PROJECT}) +set(CUSPARSELT_FOUND ON) +add_definitions(-DPADDLE_WITH_CUSPARSELT) diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index 14ae8efb5b4..a8e3696418b 100644 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -108,6 +108,14 @@ function(copy_part_of_thrid_party TARGET DST) SRCS ${CBLAS_INSTALL_DIR}/lib ${CBLAS_INSTALL_DIR}/include DSTS ${dst_dir} ${dst_dir}) endif() + + if(WITH_SPARSELT) + set(dst_dir "${DST}/third_party/install/cusparselt") + copy( + ${TARGET} + SRCS ${CUSPARSELT_INC_DIR} ${CUSPARSELT_LIB_DIR} + DSTS ${dst_dir} ${dst_dir}) + endif() endif() if(WITH_MKLDNN) diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 2004241ab1a..96132f4af57 100755 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -496,4 +496,9 @@ if(WITH_IPU) list(APPEND third_party_deps extern_poplar) endif() +if(WITH_CUSPARSELT) + include(external/cusparselt) # download, build, install cusparselt + list(APPEND third_party_deps extern_cusparselt) +endif() + add_custom_target(third_party ALL DEPENDS ${third_party_deps}) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 374b5490d5d..daafc4d1c02 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -156,6 +156,8 @@ pass_library(add_support_int8_pass inference) pass_library(matmul_scale_fuse_pass inference) pass_library(gpu_cpu_map_matmul_to_mul_pass inference) pass_library(mixed_precision_configure_pass inference) +pass_library(dense_fc_to_sparse_pass inference) +pass_library(dense_multihead_matmul_to_sparse_pass inference) pass_library(generate_pass DEPS pass_desc_proto) target_link_libraries(generate_pass pass_desc_proto) @@ -379,6 +381,14 @@ if(NOT WIN32) test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass) + cc_test( + test_dense_fc_to_sparse_pass_cc + SRCS dense_fc_to_sparse_pass_tester.cc + DEPS fc_fuse_pass dense_fc_to_sparse_pass framework_proto) + cc_test( + test_dense_multihead_matmul_to_sparse_pass + SRCS dense_multihead_matmul_to_sparse_pass_tester.cc + DEPS multihead_matmul_fuse_pass dense_multihead_matmul_to_sparse_pass) endif() if(WITH_MKLDNN) cc_test( diff --git a/paddle/fluid/framework/ir/dense_fc_to_sparse_pass.cc b/paddle/fluid/framework/ir/dense_fc_to_sparse_pass.cc new file mode 100644 index 00000000000..f1a8d63c722 --- /dev/null +++ b/paddle/fluid/framework/ir/dense_fc_to_sparse_pass.cc @@ -0,0 +1,148 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/dense_fc_to_sparse_pass.h" + +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +PDNode *patterns::DenseFC::operator()() { + auto *fc = pattern->NewNode(fc_repr())->assert_is_op("fc"); + // Input + auto *fc_input = pattern->NewNode(fc_input_repr()) + ->AsInput() + ->assert_is_op_input("fc", "Input"); + // Filter + auto *fc_weights = pattern->NewNode(fc_weights_repr()) + ->AsInput() + ->assert_is_op_input("fc", "W"); + // Bias + auto *fc_bias = pattern->NewNode(fc_bias_repr()) + ->AsInput() + ->assert_is_op_input("fc", "Bias"); + // Output + auto *fc_out = pattern->NewNode(fc_out_repr()) + ->AsOutput() + ->assert_is_op_output("fc", "Out") + ->assert_is_only_output_of_op("fc"); + + fc->LinksFrom({fc_input, fc_weights, fc_bias}).LinksTo({fc_out}); + + return fc_out; +} +} // namespace patterns + +DenseFCToSparsePass::DenseFCToSparsePass() { + AddOpCompat(OpCompat("fc")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); +} + +void DenseFCToSparsePass::ApplyImpl(Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + + std::string name_scope = "dense_fc_to_sparse_pass"; + FusePassBase::Init(name_scope, graph); + GraphPatternDetector gpd; + + patterns::DenseFC dense_fc_pattern(gpd.mutable_pattern(), + "dense_fc_replace_pass"); + dense_fc_pattern(); + int found_dense_fc_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "Replace dense fc with sparse_fc."; + + /* if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + }*/ + + GET_IR_NODE_FROM_SUBGRAPH(fc_out, fc_out, dense_fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc, fc, dense_fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_input, fc_input, dense_fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_weights, fc_weights, dense_fc_pattern); + GET_IR_NODE_FROM_SUBGRAPH(fc_bias, fc_bias, dense_fc_pattern); + + auto *fc_op = fc->Op(); + auto w_name = fc_op->Input("W")[0]; + // recognize sparse op by name + if (w_name.find("sparse_2_4") != w_name.npos) { + // fake op + OpDesc desc(fc_op->Block()); + desc.SetType("sparse_fc"); + desc.SetInput("Input", {fc_input->Name()}); + desc.SetInput("W", {fc_weights->Name()}); + desc.SetInput("Bias", {fc_bias->Name()}); + desc.SetOutput("Out", {fc_out->Name()}); + + // copy all attr + if (fc_op->HasAttr("x_num_col_dims")) { + desc.SetAttr("x_num_col_dims", fc_op->GetAttr("x_num_col_dims")); + } + if (fc_op->HasAttr("in_num_col_dims")) { + desc.SetAttr("in_num_col_dims", fc_op->GetAttr("in_num_col_dims")); + } + desc.SetAttr("activation_type", fc_op->GetAttr("activation_type")); + if (fc_op->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", fc_op->GetAttr("enable_int8")); + } + if 
(fc_op->HasAttr("Input_scale")) { + desc.SetAttr("Input_scale", fc_op->GetAttr("Input_scale")); + } + if (fc_op->HasAttr("support_int8")) { + desc.SetAttr("support_int8", fc_op->GetAttr("support_int8")); + } + if (fc_op->HasAttr("out_threshold")) { + desc.SetAttr("out_threshold", fc_op->GetAttr("out_threshold")); + } + desc.Flush(); + GraphSafeRemoveNodes(g, {fc}); + auto sparse_fc_node = g->CreateOpNode(&desc); + + IR_NODE_LINK_TO(fc_input, sparse_fc_node); + IR_NODE_LINK_TO(fc_weights, sparse_fc_node); + IR_NODE_LINK_TO(fc_bias, sparse_fc_node); + IR_NODE_LINK_TO(sparse_fc_node, fc_out); + found_dense_fc_count++; + } + }; + + gpd(graph, handler); + AddStatis(found_dense_fc_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(dense_fc_to_sparse_pass, + paddle::framework::ir::DenseFCToSparsePass); diff --git a/paddle/fluid/framework/ir/dense_fc_to_sparse_pass.h b/paddle/fluid/framework/ir/dense_fc_to_sparse_pass.h new file mode 100644 index 00000000000..18c91bf49c7 --- /dev/null +++ b/paddle/fluid/framework/ir/dense_fc_to_sparse_pass.h @@ -0,0 +1,61 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/inference/api/paddle_analysis_config.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct DenseFC : public PatternBase { + DenseFC(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "dense_fc") {} + + PDNode* operator()(); + + // declare operator node's name + PATTERN_DECL_NODE(fc); + PATTERN_DECL_NODE(fc_out); + PATTERN_DECL_NODE(fc_input); + PATTERN_DECL_NODE(fc_weights); + PATTERN_DECL_NODE(fc_bias); +}; +} // namespace patterns + +/** + * Replace dense op with sparse op + */ +class Graph; + +class DenseFCToSparsePass : public FusePassBase { + public: + DenseFCToSparsePass(); + + protected: + void ApplyImpl(ir::Graph* graph) const override; + + const std::string name_scope_{"dense_fc_to_sparse_pass"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/dense_fc_to_sparse_pass_tester.cc b/paddle/fluid/framework/ir/dense_fc_to_sparse_pass_tester.cc new file mode 100644 index 00000000000..cb10c84b1d7 --- /dev/null +++ b/paddle/fluid/framework/ir/dense_fc_to_sparse_pass_tester.cc @@ -0,0 +1,107 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+
+#include "paddle/fluid/framework/ir/dense_fc_to_sparse_pass.h"
+#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void AddVarToScope(Scope* param_scope, const std::string& name,
+                   const DDim& dims) {
+  auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
+  tensor->Resize(dims);
+  tensor->mutable_data<float>(platform::CPUPlace());
+}
+
+Scope* CreateParamScope() {
+  auto param_scope = new Scope();
+  AddVarToScope(param_scope, "conv2d_filters_0", {});
+  AddVarToScope(param_scope, "conv2d_bias_0", {});
+  AddVarToScope(param_scope, "weights_0_sparse_2_4", {});
+  AddVarToScope(param_scope, "weights_1", {});
+  AddVarToScope(param_scope, "bias_1", {});
+  AddVarToScope(param_scope, "bias_2", {});
+  return param_scope;
+}
+
+TEST(FCFusePass, basic) {
+  // inputs operator output
+  // --------------------------------------------------------
+  // (a, filters_0 bias_0) conv2d -> conv2d_out
+  // conv2d_out relu -> relu_out_0
+  // (relu_out_0, weights_0_sparse_2_4) mul -> mul_out_0
+  // (mul_out_0, bias_1) elementwise_add -> add_out_0
+  // add_out_0 relu -> relu_out_1
+  // (relu_out_1, weights_1) mul -> mul_out_1
+  // (mul_out_1, bias_2) elementwise_add -> add_out_1
+  Layers layers;
+  auto* a = layers.data("a");
+  auto* filters_0 = layers.data("conv2d_filters_0", {}, true);
+  auto* bias_0 = layers.data("conv2d_bias_0", {}, true);
+  auto* conv2d_out = layers.conv2d(a, filters_0, bias_0, false);
+  auto* relu_out_0 = layers.relu(conv2d_out);
+  auto* weights_0 = layers.data("weights_0_sparse_2_4", {5, 4}, true);
+  auto* mul_out_0 = layers.mul(relu_out_0, weights_0);
+  auto* bias_1 = layers.data("bias_1", {4}, true);
+  auto* add_out_0 = layers.elementwise_add(mul_out_0, bias_1, nullptr, 1);
+  auto* relu_out_1 = layers.relu(add_out_0);
+  auto* weights_1 = layers.data("weights_1", {8, 9}, true);
+  auto* mul_out_1 = layers.mul(relu_out_1, weights_1);
+  auto* bias_2 = layers.data("bias_2", {1, 9}, true);
+  auto* add_out_1 = layers.elementwise_add(mul_out_1, bias_2, nullptr, 1);
+  VLOG(4) << add_out_1;
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  auto fuse_pass = PassRegistry::Instance().Get("fc_fuse_pass");
+  auto sparse_pass = PassRegistry::Instance().Get("dense_fc_to_sparse_pass");
+  fuse_pass->Set("use_gpu", new bool(true));
+  sparse_pass->Set("use_gpu", new bool(true));
+  graph->Set("__param_scope__", CreateParamScope());
+  int num_nodes_before = graph->Nodes().size();
+  int num_mul_nodes_before = GetNumOpNodes(graph, "mul");
+  VLOG(3) << DebugString(graph);
+
+  graph.reset(fuse_pass->Apply(graph.release()));
+  graph.reset(sparse_pass->Apply(graph.release()));
+  int num_nodes_after = graph->Nodes().size();
+  int num_fc_nodes_after = GetNumOpNodes(graph, "fc");
+  int num_sparse_fc_nodes_after = GetNumOpNodes(graph, "sparse_fc");
+  VLOG(3) << DebugString(graph);
+
+  PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 6,
+                    platform::errors::InvalidArgument(
+                        "num_nodes_before=%d, num_nodes_after=%d.",
+                        num_nodes_before, num_nodes_after));
+ PADDLE_ENFORCE_EQ(num_fc_nodes_after, 1, + platform::errors::InvalidArgument("num_fc_nodes_after=%d.", + num_fc_nodes_after)); + PADDLE_ENFORCE_EQ( + num_mul_nodes_before, num_fc_nodes_after + num_sparse_fc_nodes_after, + platform::errors::InvalidArgument( + "num_mul_nodes_before=%d, num_fc_nodes_after=%d + " + "num_sparse_fc_nodes_after=%d.", + num_mul_nodes_before, num_fc_nodes_after, num_sparse_fc_nodes_after)); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(fc_fuse_pass); +USE_PASS(dense_fc_to_sparse_pass); diff --git a/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.cc b/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.cc new file mode 100644 index 00000000000..2aae5030b5d --- /dev/null +++ b/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.cc @@ -0,0 +1,174 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.h" + +#include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { +PDNode *patterns::DenseMultiheadMatmul::operator()() { + auto *multihead_matmul = pattern->NewNode(multihead_matmul_repr()) + ->assert_is_op("multihead_matmul"); + // Input + auto *multihead_matmul_input = + pattern->NewNode(multihead_matmul_input_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "Input"); + // Filter + auto *multihead_matmul_weights = + pattern->NewNode(multihead_matmul_weights_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "W"); + // Bias + auto *multihead_matmul_bias = + pattern->NewNode(multihead_matmul_bias_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "Bias"); + // BiasQK + auto *multihead_matmul_biasqk = + pattern->NewNode(multihead_matmul_biasqk_repr()) + ->AsInput() + ->assert_is_op_input("multihead_matmul", "BiasQK"); + // Output + auto *multihead_matmul_out = + pattern->NewNode(multihead_matmul_out_repr()) + ->AsOutput() + ->assert_is_op_output("multihead_matmul", "Out") + ->assert_is_only_output_of_op("multihead_matmul"); + + multihead_matmul + ->LinksFrom({multihead_matmul_input, multihead_matmul_weights, + multihead_matmul_bias, multihead_matmul_biasqk}) + .LinksTo({multihead_matmul_out}); + + return multihead_matmul_out; +} +} // namespace patterns +DenseMultiheadMatmulToSparsePass::DenseMultiheadMatmulToSparsePass() { + AddOpCompat(OpCompat("multihead_matmul")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("W") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .End() + .AddInput("BiasQK") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); +} + +void DenseMultiheadMatmulToSparsePass::ApplyImpl(Graph *graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, 
platform::errors::InvalidArgument("Graph cannot be nullptr.")); + + std::string name_scope = "dense_multihead_matmul_to_sparse_pass"; + FusePassBase::Init(name_scope, graph); + GraphPatternDetector gpd; + + patterns::DenseMultiheadMatmul multihead_matmul_pattern( + gpd.mutable_pattern(), "dense_multihead_matmul_replace_pass"); + multihead_matmul_pattern(); + int found_multihead_matmul_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "Replace dense multihead matmul with sparse multihead matmul."; + + /* if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + }*/ + + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_out, multihead_matmul_out, + multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul, multihead_matmul, + multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_input, multihead_matmul_input, + multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_weights, + multihead_matmul_weights, + multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_bias, multihead_matmul_bias, + multihead_matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(multihead_matmul_biasqk, multihead_matmul_biasqk, + multihead_matmul_pattern); + + auto *multihead_matmul_op = multihead_matmul->Op(); + auto w_name = multihead_matmul_op->Input("W")[0]; + // recognize sparse op by name + if (w_name.find("sparse_2_4") != w_name.npos) { + // fake op + OpDesc desc(multihead_matmul_op->Block()); + desc.SetType("sparse_multihead_matmul"); + desc.SetInput("Input", {multihead_matmul_input->Name()}); + desc.SetInput("W", {multihead_matmul_weights->Name()}); + desc.SetInput("Bias", {multihead_matmul_bias->Name()}); + desc.SetInput("BiasQK", {multihead_matmul_biasqk->Name()}); + desc.SetOutput("Out", {multihead_matmul_out->Name()}); + + // copy all attr + desc.SetAttr("alpha", multihead_matmul_op->GetAttr("alpha")); + desc.SetAttr("head_number", multihead_matmul_op->GetAttr("head_number")); + if (multihead_matmul_op->HasAttr("Input_scale")) { + desc.SetAttr("Input_scale", + multihead_matmul_op->GetAttr("Input_scale")); + } + if (multihead_matmul_op->HasAttr("fc_out_threshold")) { + desc.SetAttr("fc_out_threshold", + multihead_matmul_op->GetAttr("fc_out_threshold")); + } + if (multihead_matmul_op->HasAttr("qkv2context_plugin_int8")) { + desc.SetAttr("qkv2context_plugin_int8", + multihead_matmul_op->GetAttr("qkv2context_plugin_int8")); + } + if (multihead_matmul_op->HasAttr("dp_probs")) { + desc.SetAttr("dp_probs", multihead_matmul_op->GetAttr("dp_probs")); + } + if (multihead_matmul_op->HasAttr("out_threshold")) { + desc.SetAttr("out_threshold", + multihead_matmul_op->GetAttr("out_threshold")); + } + desc.Flush(); + GraphSafeRemoveNodes(g, {multihead_matmul}); + auto sparse_multihead_matmul_node = g->CreateOpNode(&desc); + + IR_NODE_LINK_TO(multihead_matmul_input, sparse_multihead_matmul_node); + IR_NODE_LINK_TO(multihead_matmul_weights, sparse_multihead_matmul_node); + IR_NODE_LINK_TO(multihead_matmul_bias, sparse_multihead_matmul_node); + IR_NODE_LINK_TO(multihead_matmul_biasqk, sparse_multihead_matmul_node); + IR_NODE_LINK_TO(sparse_multihead_matmul_node, multihead_matmul_out); + found_multihead_matmul_count++; + } + }; + + gpd(graph, handler); + AddStatis(found_multihead_matmul_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(dense_multihead_matmul_to_sparse_pass, + paddle::framework::ir::DenseMultiheadMatmulToSparsePass); diff --git 
a/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.h b/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.h new file mode 100644 index 00000000000..fa0716255b5 --- /dev/null +++ b/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.h @@ -0,0 +1,61 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/inference/api/paddle_analysis_config.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct DenseMultiheadMatmul : public PatternBase { + DenseMultiheadMatmul(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "dense_multihead_matmul") {} + + PDNode* operator()(); + + // declare operator node's name + PATTERN_DECL_NODE(multihead_matmul); + PATTERN_DECL_NODE(multihead_matmul_out); + PATTERN_DECL_NODE(multihead_matmul_input); + PATTERN_DECL_NODE(multihead_matmul_weights); + PATTERN_DECL_NODE(multihead_matmul_bias); + PATTERN_DECL_NODE(multihead_matmul_biasqk); +}; +} // namespace patterns +/** + * Replace dense multihead_matmul op with sparse multihead_matmul op + */ +class Graph; + +class DenseMultiheadMatmulToSparsePass : public FusePassBase { + public: + DenseMultiheadMatmulToSparsePass(); + + protected: + void ApplyImpl(ir::Graph* graph) const override; + + const std::string name_scope_{"dense_multihead_matmul_to_sparse_pass"}; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass_tester.cc b/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass_tester.cc new file mode 100644 index 00000000000..3989d3d11db --- /dev/null +++ b/paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass_tester.cc @@ -0,0 +1,152 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include + +#include "paddle/fluid/framework/ir/dense_multihead_matmul_to_sparse_pass.h" // NOLINT +#include "paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h" // NOLINT +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { + +void AddVarToScope(Scope* param_scope, const std::string& name, + const DDim& dims) { + auto* tensor = param_scope->Var(name)->GetMutable(); + tensor->Resize(dims); + tensor->mutable_data(platform::CPUPlace()); +} + +Scope* CreateParamScope() { + auto param_scope = new Scope(); + AddVarToScope(param_scope, "weights0_sparse_2_4", {768, 768}); + AddVarToScope(param_scope, "weights1_sparse_2_4", {768, 768}); + AddVarToScope(param_scope, "weights2_sparse_2_4", {768, 768}); + + AddVarToScope(param_scope, "bias_0", {768}); + AddVarToScope(param_scope, "bias_1", {768}); + AddVarToScope(param_scope, "bias_2", {768}); + AddVarToScope(param_scope, "biasqk", {768}); + AddVarToScope(param_scope, "weightsl", {768, 768}); + return param_scope; +} + +TEST(DenseMultiHeadMatmulToSparsePass, basic) { + // inputs operator output + // -------------------------------------------------------------------- + // (x) layer_norm -> layer_norm_out + // (layer_norm_out, weights_0_sparse_2_4) mul -> mul_out0 + // (layer_norm_out, weights_1_sparse_2_4) mul -> mul_out1 + // (layer_norm_out, weights_2_sparse_2_4) mul -> mul_out2 + // (mul_out0, bias_0) elementweise_add -> eltadd_0 + // (mul_out1, bias_1) elementweise_add -> eltadd_1 + // (mul_out2, bias_2) elementweise_add -> eltadd_2 + // (eltadd_0) reshape2 -> reshape_0 + // (eltadd_1) reshape2 -> reshape_1 + // (eltadd_2) reshape2 -> reshape_2 + // (reshape_0) transpose2 -> transpose_0 + // (reshape_1) transpose2 -> transpose_1 + // (reshape_2) transpose2 -> transpose_2 + // (transpose_0) scale -> scale_0 + // (scale_0, transpose_1) matmul -> matmul_qk + // (matmul_qk, bias_qk) elementwise_add -> eltadd_qk + // (eltadd_qk) softmax -> softmax_qk + // (softmax_qk, transpose_2) matmul -> matmul_qkv + // (matmul_qkv) transpose -> transpose_qkv + // (transpose_qkv) reshape -> reshape_qkv + // (reshape_qkv) mul -> mul_qkv + Layers layers; + auto* x = layers.data("x", {1, 128, 768}); + auto out = layers.layer_norm(x); + auto* layer_out = out[0]; + + auto* weights_0 = layers.data("weights0_sparse_2_4", {768, 768}, true); + auto* weights_1 = layers.data("weights1_sparse_2_4", {768, 768}, true); + auto* weights_2 = layers.data("weights2_sparse_2_4", {768, 768}, true); + + auto* mul_out_0 = layers.mul(layer_out, weights_0, nullptr, 2); + auto* mul_out_1 = layers.mul(layer_out, weights_1, nullptr, 2); + auto* mul_out_2 = layers.mul(layer_out, weights_2, nullptr, 2); + + auto* b0 = layers.data("bias_0", {768}, true); + auto* b1 = layers.data("bias_1", {768}, true); + auto* b2 = layers.data("bias_2", {768}, true); + + auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0, nullptr, 2); + auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1, nullptr, 2); + auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2, nullptr, 2); + + std::vector shape = {1, 128, 12, 64}; + auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true); + auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true); + auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true); + + std::vector axis = {0, 2, 1, 3}; + auto* transpose_0 = layers.transpose2(reshape_0, axis, true); + auto* transpose_1 = 
layers.transpose2(reshape_1, axis, true); + auto* transpose_2 = layers.transpose2(reshape_2, axis, true); + + auto* scale_0 = layers.scale(transpose_0, 0.125, 0, false); + auto* matmul_qk = layers.matmul(scale_0, transpose_1, nullptr, false, true); + + auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true); + auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk); + auto* softmax_qk = layers.softmax(elementwise_qk, -1); + + auto* matmul_qkv = layers.matmul(softmax_qk, transpose_2); + + auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true); + auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 768}, true); + auto* weights_l = layers.data("weightsl", {768, 768}, true); + layers.mul(reshape_qkv_out, weights_l, nullptr, 2); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + graph->Set("__param_scope__", CreateParamScope()); + + auto fuse_pass = + PassRegistry::Instance().Get("multihead_matmul_fuse_pass_v2"); + auto sparse_pass = + PassRegistry::Instance().Get("dense_multihead_matmul_to_sparse_pass"); + + if (fuse_pass.get() == nullptr || sparse_pass.get() == nullptr) + LOG(INFO) << "asdfasdf"; + int num_nodes_before = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + graph.reset(fuse_pass->Apply(graph.release())); + graph.reset(sparse_pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + int num_fused_nodes_after = GetNumOpNodes(graph, "sparse_multihead_matmul"); + VLOG(3) << DebugString(graph); + + PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 39, + platform::errors::InvalidArgument( + "After the multihead_matmul pass and sparse pass, The " + "node num in graph " + "should be %d, but the result is %d", + num_nodes_before - 39, num_nodes_after)); + PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1, + platform::errors::InvalidArgument( + "After the multihead_matmul pass and sparse pass, " + "there should be one " + "sparse_multihead_matmul op, but the result is %d", + num_fused_nodes_after)); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(multihead_matmul_fuse_pass); +USE_PASS(multihead_matmul_fuse_pass_v2); +USE_PASS(dense_multihead_matmul_to_sparse_pass); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 7f30b80224e..0645af611b9 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1960,6 +1960,10 @@ USE_TRT_CONVERTER(strided_slice) USE_TRT_CONVERTER(transformer_input_convert) USE_TRT_CONVERTER(recover_padding) USE_TRT_CONVERTER(remove_padding) +#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000) +USE_TRT_CONVERTER(sparse_fc) +USE_TRT_CONVERTER(sparse_multihead_matmul) +#endif #endif namespace paddle_infer { diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 9e5b76db4ac..96129018d01 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -115,8 +115,10 @@ const std::vector kTRTSubgraphPasses({ "remove_padding_recover_padding_pass", // "delete_remove_padding_recover_padding_pass", // // "yolo_box_fuse_pass", // - "tensorrt_subgraph_pass", // - "conv_bn_fuse_pass", // + "dense_fc_to_sparse_pass", // + "dense_multihead_matmul_to_sparse_pass", // + "tensorrt_subgraph_pass", // + "conv_bn_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be // guaranteed at 
least v7 // cudnn8.0 has memory leak problem in conv + eltwise + act, so we diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt index abd00ef9de6..0f1350459ef 100644 --- a/paddle/fluid/inference/tensorrt/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt @@ -1,4 +1,5 @@ -# Compiling with WITH_PYTHON=ON and WITH_TENSORRT=ON failed on windows. Temporarily add paddle_inference_api dependency to solve the problem +# Compiling with WITH_PYTHON=ON and WITH_TENSORRT=ON failed on windows. +# Temporarily add paddle_inference_api dependency to solve the problem if(WIN32) nv_library( tensorrt_engine @@ -21,7 +22,7 @@ nv_test( DEPS dynload_cuda device_context dynamic_loader) nv_test( test_tensorrt_engine - SRCS test_engine.cc - DEPS dynload_cuda tensorrt_engine) + SRCS test_engine.cc test_dynamic_engine.cc + DEPS dynload_cuda tensorrt_engine tensorrt_plugin) add_subdirectory(plugin) add_subdirectory(convert) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index b27a584de2b..2c9ba428215 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,65 +1,74 @@ # Add TRT tests +list( + APPEND + CONVERT_FILES + matmul_op.cc + conv2d_op.cc + fc_op.cc + pool2d_op.cc + elementwise_op.cc + batch_norm_op.cc + activation_op.cc + unary_op.cc + softmax_op.cc + concat_op.cc + dropout_op.cc + group_norm_op.cc + pad_op.cc + split_op.cc + prelu_op.cc + leaky_relu_op.cc + gelu_op.cc + layer_norm_op.cc + multihead_matmul_op.cc + shuffle_channel_op.cc + swish_op.cc + instance_norm_op.cc + stack_op.cc + transpose_op.cc + flatten_op.cc + flatten_contiguous_range_op.cc + emb_eltwise_layernorm.cc + skip_layernorm.cc + scale_op.cc + slice_op.cc + hard_sigmoid_op.cc + hard_swish_op.cc + clip_op.cc + gather_op.cc + anchor_generator_op.cc + yolo_box_op.cc + yolo_box_head_op.cc + arg_max_op.cc + roi_align_op.cc + affine_channel_op.cc + multiclass_nms_op.cc + multiclass_nms3_op.cc + nearest_interp_op.cc + reshape_op.cc + reduce_op.cc + gather_nd_op.cc + tile_op.cc + conv3d_op.cc + mish_op.cc + nearest_interp_v2_op.cc + pool3d_op.cc + deformable_conv_op.cc + preln_emb_eltwise_layernorm.cc + strided_slice_op.cc + preln_skip_layernorm.cc + roll_op.cc + transformer_input_convert_op.cc + remove_padding_op.cc + recover_padding_op.cc) + +if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) + list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc) +endif() + nv_library( tensorrt_converter - SRCS matmul_op.cc - conv2d_op.cc - fc_op.cc - pool2d_op.cc - elementwise_op.cc - batch_norm_op.cc - activation_op.cc - unary_op.cc - softmax_op.cc - concat_op.cc - dropout_op.cc - group_norm_op.cc - pad_op.cc - split_op.cc - prelu_op.cc - leaky_relu_op.cc - gelu_op.cc - layer_norm_op.cc - multihead_matmul_op.cc - shuffle_channel_op.cc - swish_op.cc - instance_norm_op.cc - stack_op.cc - transpose_op.cc - flatten_op.cc - flatten_contiguous_range_op.cc - emb_eltwise_layernorm.cc - skip_layernorm.cc - scale_op.cc - slice_op.cc - hard_sigmoid_op.cc - hard_swish_op.cc - clip_op.cc - gather_op.cc - anchor_generator_op.cc - yolo_box_op.cc - yolo_box_head_op.cc - arg_max_op.cc - roi_align_op.cc - affine_channel_op.cc - multiclass_nms_op.cc - multiclass_nms3_op.cc - nearest_interp_op.cc - reshape_op.cc - reduce_op.cc - gather_nd_op.cc - tile_op.cc - conv3d_op.cc - mish_op.cc - nearest_interp_v2_op.cc - 
pool3d_op.cc - deformable_conv_op.cc - preln_emb_eltwise_layernorm.cc - strided_slice_op.cc - preln_skip_layernorm.cc - roll_op.cc - transformer_input_convert_op.cc - remove_padding_op.cc - recover_padding_op.cc + SRCS ${CONVERT_FILES} DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) diff --git a/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc b/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc new file mode 100644 index 00000000000..de9fd62300f --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc @@ -0,0 +1,371 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * FC converter convert a sparse_fc op to a sparse_fc plugin in TRT. + */ +class SparseFcOpConverter : public OpConverter { + public: + nvinfer1::ILayer* reshape_before_fc(nvinfer1::ITensor* before_fc, + nvinfer1::Dims x_dim, int x_num_col_dims, + std::string output_name) { + // add shuffle before fc + nvinfer1::Dims reshape_before_fc_dim; + reshape_before_fc_dim.nbDims = x_num_col_dims + 3; + // padding shape "* x q x 1 x 1" + for (int i = 0; i < reshape_before_fc_dim.nbDims; i++) { + reshape_before_fc_dim.d[i] = 1; + } + for (int i = 0; i < x_dim.nbDims; i++) { + if (i < x_num_col_dims) { + reshape_before_fc_dim.d[i] = 0; + } else { + if (x_dim.d[i] < 0) { + reshape_before_fc_dim.d[x_num_col_dims] = -1; + break; + } + reshape_before_fc_dim.d[x_num_col_dims] *= x_dim.d[i]; + } + } + auto* reshape_before_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *before_fc); + reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + reshape_before_fc_layer->setName( + ("sparse_fc_op_reshape_before_fc: Shuffle (Output: " + output_name + + ")") + .c_str()); + return reshape_before_fc_layer; + } + + nvinfer1::ILayer* reshape_after_fc(nvinfer1::ITensor* after_fc, + nvinfer1::Dims x_dim, int x_num_col_dims) { + // add shuffle after fc + nvinfer1::Dims reshape_after_fc_dim; + reshape_after_fc_dim.nbDims = x_num_col_dims + 1; + for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) { + reshape_after_fc_dim.d[i] = 0; + } + auto* reshape_after_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *after_fc); + reshape_after_fc_layer->setReshapeDimensions(reshape_after_fc_dim); + return reshape_after_fc_layer; + } + + plugin::SpmmPluginDynamic* new_spmm_plugin(TensorRTEngine::Weight* weight, + TensorRTEngine::Weight* bias, + const std::string& activation_type, + nvinfer1::DataType type, + int outdim) { + plugin::SpmmPluginDynamic::Activation act = + plugin::SpmmPluginDynamic::Activation::kNone; + if (activation_type == "relu") { + act = 
plugin::SpmmPluginDynamic::Activation::kRelu; + } else if (activation_type == "gelu") { + act = plugin::SpmmPluginDynamic::Activation::kGelu; + } else if (activation_type != "") { + PADDLE_THROW(paddle::platform::errors::Fatal("unknown activation_type %s", + activation_type.c_str())); + } + return new plugin::SpmmPluginDynamic("CustomSpmmPluginDynamic", type, + outdim, weight->get(), bias->get(), + act); + } + + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a sparse_fc op to tensorrt sparse_fc plugin"; + framework::OpDesc op_desc(op, nullptr); + auto output_name = op_desc.Output("Out").front(); + auto input_names = op_desc.InputNames(); + bool with_bias = input_names.size() >= 3; + std::string w_name = "Y"; + std::string i_name = "X"; + if (with_bias) { + w_name = "W"; + i_name = "Input"; + } + // Declare inputs + auto* X = engine_->GetITensor(op_desc.Input(i_name).front()); + auto x_dim = X->getDimensions(); + // Declare weights + auto* Y_v = scope.FindVar(op_desc.Input(w_name).front()); + PADDLE_ENFORCE_NOT_NULL( + Y_v, + platform::errors::NotFound( + "Can not find %s presistale var of sparse_fc in scope.", w_name)); + auto* Y_t = Y_v->GetMutable(); + int x_num_col_dims = + op_desc.HasAttr("x_num_col_dims") + ? BOOST_GET_CONST(int, op_desc.GetAttr("x_num_col_dims")) + : (op_desc.HasAttr("in_num_col_dims") + ? BOOST_GET_CONST(int, op_desc.GetAttr("in_num_col_dims")) + : 1); + const std::string activation_type = + op_desc.HasAttr("activation_type") + ? BOOST_GET_CONST(std::string, op_desc.GetAttr("activation_type")) + : ""; + float* weight_data = nullptr; + bool enable_int8 = op_desc.HasAttr("enable_int8"); + bool support_int8 = false; + if (op_desc.HasAttr("support_int8")) { + support_int8 = BOOST_GET_CONST(bool, op_desc.GetAttr("support_int8")); + } + float in_scale = 0; + if (enable_int8 || support_int8) { + if (enable_int8) { + in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); + } else { + // attr X is generated by add_support_int8_pass + in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X")); + } + engine_->SetTensorDynamicRange(X, in_scale); + } + weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t); + + PADDLE_ENFORCE_EQ( + Y_t->dims().size(), 2UL, + platform::errors::InvalidArgument( + "The sparse_fc's weight should be a matrix with 2 dims, but " + "it's %d-dimensional.", + Y_t->dims().size())); // a matrix + int m = Y_t->dims()[0]; + int n = Y_t->dims()[1]; + auto tranpose_weight = [](const float* src, float* dst, int m, int n) { + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + dst[j * m + i] = src[i * n + j]; + } + } + }; + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output, + TensorRTEngine::Weight& weight, + TensorRTEngine::Weight& bias) { + if (enable_int8 || support_int8) { + // add conv1x1 layer + nvinfer1::DimsHW nv_ksize(1, 1); + auto* fc_layer_int8 = + TRT_ENGINE_ADD_LAYER(engine_, Convolution, *X, n_output, nv_ksize, + weight.get(), bias.get()); + if (activation_type == "relu") { + fc_layer_int8->setName( + ("ernie_fc_op_int8: Convolution (Output: " + output_name + ")") + .c_str()); + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("out_threshold"), true, + platform::errors::InvalidArgument( + "must have out threshold in fc layers in int8 mode")); + float out_scale = 0; + if (enable_int8) { + out_scale = + BOOST_GET_CONST(float, 
op_desc.GetAttr("out_threshold")); + } else { + out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out")); + } + engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0), + out_scale); + nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER( + engine_, Activation, *(fc_layer_int8->getOutput(0)), + nvinfer1::ActivationType::kRELU); + RreplenishLayerAndOutput(relu_layer_int8, "relu_after_ernie_fc_int8", + {output_name}, test_mode); + } else { + RreplenishLayerAndOutput(fc_layer_int8, + "ernie_fc_op_int8: Convolution", + {output_name}, test_mode); + } + } else { + // add fc layer + auto* fc_layer_float = TRT_ENGINE_ADD_LAYER( + engine_, FullyConnected, *X, n_output, weight.get(), bias.get()); + if (activation_type == "relu") { + fc_layer_float->setName( + ("ernie_fc_op_float: (Output: " + output_name + ")").c_str()); + nvinfer1::IActivationLayer* relu_layer_float = TRT_ENGINE_ADD_LAYER( + engine_, Activation, *(fc_layer_float->getOutput(0)), + nvinfer1::ActivationType::kRELU); + RreplenishLayerAndOutput(relu_layer_float, + "relu_after_ernie_fc_float", {output_name}, + test_mode); + } else { + RreplenishLayerAndOutput(fc_layer_float, "ernie_fc_op_float", + {output_name}, test_mode); + } + } + }; + auto regist_sparse_fc = [&](nvinfer1::ITensor* inputs, int n_output, + TensorRTEngine::Weight* weight, + TensorRTEngine::Weight* bias) { + if (enable_int8 || support_int8) { + // add conv layer + float out_scale = 0; + if (enable_int8) { + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("out_threshold"), true, + platform::errors::InvalidArgument( + "must have out threshold in sparse_fc layers in int8 mode")); + out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + } else { + out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out")); + } + plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( + weight, bias, activation_type, nvinfer1::DataType::kINT8, n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(inputs); + auto fc_layer_int8 = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + fc_layer_int8->setName( + ("sparse_fc_op_int8: (Output: " + output_name + ")").c_str()); + engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0), out_scale); + auto* fc_after_reshape_int8 = reshape_after_fc( + fc_layer_int8->getOutput(0), x_dim, x_num_col_dims); + + RreplenishLayerAndOutput(fc_after_reshape_int8, + "sparse_fc_op_int8_reshape_after_fc: Shuffle", + {output_name}, test_mode); + } else { + plugin::SpmmPluginDynamic* plugin = new_spmm_plugin( + weight, bias, activation_type, + with_fp16 ? 
nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(inputs); + auto fc_layer_float = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + fc_layer_float->setName( + ("sparse_fc_op_float: FullyConnected (Output: " + output_name + ")") + .c_str()); + auto* fc_after_reshape_float = reshape_after_fc( + fc_layer_float->getOutput(0), x_dim, x_num_col_dims); + + RreplenishLayerAndOutput(fc_after_reshape_float, + "shuffle_after_sparse_fc", {output_name}, + test_mode); + } + }; + + bool transpose_y = false; + if (op_desc.HasAttr("transpose_Y")) { + transpose_y = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_Y")); + } + int weight_w, weight_h; + if (!transpose_y) { + std::vector weight_data_tmp; + weight_data_tmp.reserve(Y_t->numel()); + memcpy(weight_data_tmp.data(), weight_data, Y_t->numel() * sizeof(float)); + tranpose_weight(weight_data_tmp.data(), weight_data, m, n); + weight_w = n; + weight_h = m; + } else { + weight_w = m; + weight_h = n; + } + size_t n_output = weight_w; + float* bias_data = nullptr; + int bias_num = 0; + if (with_bias) { + auto* b_v = scope.GetVar(op_desc.Input("Bias").front()); + auto* b_t = b_v->GetMutable(); + bias_data = engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t); + bias_num = b_t->numel(); + } + // Running the TRT Static Shape mode: x_num_col_dims-1 + if (!engine_->with_dynamic_shape()) { + x_num_col_dims--; + } + // If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can + // not add Shuffle layer in ernie's multihead. + // Sparse inference doesn't support variable length for now. + if (x_dim.nbDims == 4 && x_num_col_dims == 1) { + TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, + static_cast(weight_data), + static_cast(Y_t->numel())}; + weight.dims.assign({weight_w, weight_h}); + TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, + static_cast(bias_data), + static_cast(bias_num)}; + regist_fc(X, n_output, weight, bias); + } else { // need reshape input before and after fc + PADDLE_ENFORCE_GT( + x_dim.nbDims, x_num_col_dims, + platform::errors::InvalidArgument( + "Params and input dims mismatch. Paddle-TRT FC " + "converter expects x_dim.nbDims > x_num_col_dims, but " + "x_dim.nbDims : %d, x_num_col_dims : %d.", + x_dim.nbDims, x_num_col_dims)); + half* half_data = nullptr; + void* w_data = nullptr; + if (with_fp16) { + half_data = new half[Y_t->numel()]; + for (int i = 0; i < Y_t->numel(); i++) { + half_data[i] = static_cast(weight_data[i]); + } + w_data = static_cast(half_data); + } else { + w_data = static_cast(weight_data); + } + TensorRTEngine::Weight weight{ + with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + w_data, static_cast(Y_t->numel())}; + weight.dims.assign({weight_w, weight_h}); + void* b_data = nullptr; + if (with_bias) { + half* half_bias_data = nullptr; + if (with_fp16) { + half_bias_data = new half[bias_num]; + for (int i = 0; i < bias_num; i++) { + half_bias_data[i] = static_cast(bias_data[i]); + } + b_data = static_cast(half_bias_data); + } else { + b_data = static_cast(bias_data); + } + } + TensorRTEngine::Weight bias{ + with_fp16 ? 
nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + b_data, static_cast(bias_num)}; + + auto* reshape_before_fc_layer = + reshape_before_fc(X, x_dim, x_num_col_dims, output_name); + auto* reshape_itensor = reshape_before_fc_layer->getOutput(0); + if (enable_int8 || support_int8) { + engine_->SetTensorDynamicRange(reshape_itensor, in_scale); + } + regist_sparse_fc(reshape_itensor, n_output, &weight, &bias); + } + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(sparse_fc, SparseFcOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc new file mode 100644 index 00000000000..3de8fad0206 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc @@ -0,0 +1,441 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See +the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class SparseMultiheadMatMulOpConverter : public OpConverter { + public: + plugin::SpmmPluginDynamic* new_spmm_plugin(TensorRTEngine::Weight* weight, + TensorRTEngine::Weight* bias, + nvinfer1::DataType type, + int outdim) { + plugin::SpmmPluginDynamic::Activation act = + plugin::SpmmPluginDynamic::Activation::kNone; + return new plugin::SpmmPluginDynamic("CustomSpmmPluginDynamic", type, + outdim, weight->get(), bias->get(), + act); + } + + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a fluid sparse_multihead_matmul op to a corresponding " + "tensorrt " + "network structure"; + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* input = engine_->GetITensor(op_desc.Input("Input").front()); + + // fc weights and fc bias + auto weight_name = op_desc.Input("W").front(); + auto bias_name = op_desc.Input("Bias").front(); + + auto* weight_v = scope.FindVar(weight_name); + auto* weight_t = weight_v->GetMutable(); + + auto* bias_v = scope.FindVar(bias_name); + auto* bias_t = bias_v->GetMutable(); + + float* weight_data = nullptr; + bool qkv2context_plugin_int8 = op_desc.HasAttr("qkv2context_plugin_int8"); + float in_scale = 0.; + + if (op_desc.HasAttr("Input_scale")) { + in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); + engine_->SetTensorDynamicRange(input, in_scale); + } + weight_data = engine_->GetWeightCPUData(weight_name, weight_t); + + float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t); + std::vector weight_data_tmp; + weight_data_tmp.reserve(weight_t->numel()); + memcpy(weight_data_tmp.data(), weight_data, + weight_t->numel() * sizeof(float)); + + // (hidden_in, 3, hidden_out) + const auto& weight_dims = weight_t->dims(); + + int 
hidden_in = weight_dims[0]; // channels_in + int three = weight_dims[1]; // channels_out + int hidden_out = weight_dims[2]; // channels_out + int m = hidden_in; + int n = three * hidden_out; + auto tranpose_weight = [](const float* src, float* dst, int m, int n) { + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + dst[j * m + i] = src[i * n + j]; + } + } + }; + tranpose_weight(weight_data_tmp.data(), weight_data, m, n); + + int head_number = BOOST_GET_CONST(int, op_desc.GetAttr("head_number")); + bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); + + nvinfer1::ILayer* layer = nullptr; + auto output_name = op_desc.Output("Out")[0]; + bool flag_varseqlen = engine_->use_varseqlen() && + engine_->tensorrt_transformer_posid() != "" && + engine_->tensorrt_transformer_maskid() != ""; + if (engine_->with_dynamic_shape()) { + if (flag_varseqlen) { + if (engine_->precision() == AnalysisConfig::Precision::kFloat32) { + PADDLE_THROW(platform::errors::Fatal( + "use use_varseqlen must be int8 or half, not float32.")); + } + nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, + static_cast(weight_data), + static_cast(weight_t->numel())}; + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, + static_cast(bias_data), + static_cast(bias_t->numel())}; + if (engine_->with_interleaved()) { + VLOG(4) << "fused multihead_matmul op: use_varseqlen and " + "with_interleaved"; + if (!op_desc.HasAttr("Input_scale")) { + PADDLE_THROW( + platform::errors::Fatal("use with_interleaved must be int8.")); + } + nvinfer1::ILayer* fc_layer = nullptr; + float dp_probs = 1.0 / 127.0; + nvinfer1::DimsHW nv_ksize(1, 1); + fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n, + nv_ksize, weight, bias); + fc_layer->setName( + ("Multihead: Convolution/FullyConnected: (Output: " + + output_name + ")") + .c_str()); + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("fc_out_threshold"), true, + platform::errors::InvalidArgument( + "must have out_threshold in multihead layers in int8 mode")); + float out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold")); + engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); + if (qkv2context_plugin_int8) { + dp_probs = + BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0; + } + auto creator = GetPluginRegistry()->getPluginCreator( + "CustomQKVToContextPluginDynamic", "3"); + assert(creator != nullptr); + std::vector fields{ + {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, + 1}, + {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, + 1}}; + if (qkv2context_plugin_int8) { + fields.push_back({"dq_probs", &dp_probs, + nvinfer1::PluginFieldType::kFLOAT32, 1}); + } + nvinfer1::PluginFieldCollection* plugin_collection = + static_cast(malloc( + sizeof(*plugin_collection) + + fields.size() * + sizeof(nvinfer1::PluginField))); // remember to free + plugin_collection->nbFields = static_cast(fields.size()); + plugin_collection->fields = fields.data(); + + auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic", + plugin_collection); + free(plugin_collection); + + std::vector plugin_inputs; + plugin_inputs.emplace_back(fc_layer->getOutput(0)); + if (engine_->Has("ernie_pos_name")) { + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->Get("ernie_pos_name"))); + } else { + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->network() + ->getInput(2) + ->getName())); // cu_seqlens, eval_placeholder_2 + } + auto max_seqlen_tensor = + 
engine_->GetITensor(engine_->network()->getInput(3)->getName()); + engine_->SetTensorDynamicRange(max_seqlen_tensor, 1.0f); + auto* shuffle_layer = TRT_ENGINE_ADD_LAYER( + engine_, Shuffle, + *const_cast(max_seqlen_tensor)); + nvinfer1::Dims shape_dim; + shape_dim.nbDims = 1; + shape_dim.d[0] = -1; + shuffle_layer->setReshapeDimensions(shape_dim); + engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f); + plugin_inputs.emplace_back( + shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3 + shuffle_layer->setName( + ("Multihead: Shuffle: (Output: " + output_name + ")").c_str()); + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + layer = plugin_layer; + } else { + int head_size = hidden_out / head_number; + // [3, head_number, head_size, hidden_in] -> [head_number, 3, + // head_size, + // hidden_in] + auto transpose_weight_v2 = [](const float* src, float* dst, int three, + int head_number, int head_size, + int hidden_in) { + const int HH = head_size * hidden_in; + for (int i = 0; i < three; ++i) { + for (int n = 0; n < head_number; ++n) { + for (int hh = 0; hh < HH; ++hh) { + dst[n * three * HH + i * HH + hh] = + src[i * head_number * HH + n * HH + hh]; + } + } + } + }; + // [3, head_number, head_size] -> [head_number, 3, head_size] + auto transpose_bias_v2 = [](const float* src, float* dst, int N, + int H) { + for (int i = 0; i < 3; ++i) { + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H; ++h) { + dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h]; + } + } + } + }; + memcpy(weight_data_tmp.data(), weight_data, + weight_t->numel() * sizeof(float)); + transpose_weight_v2(weight_data_tmp.data(), weight_data, three, + head_number, head_size, hidden_in); + + std::vector bias_data_tmp; + bias_data_tmp.reserve(bias_t->numel()); + memcpy(bias_data_tmp.data(), bias_data, + bias_t->numel() * sizeof(float)); + transpose_bias_v2(bias_data_tmp.data(), bias_data, head_number, + head_size); + + nvinfer1::ILayer* fc_layer = nullptr; + float dp_probs = 1.0 / 127.0; + if (op_desc.HasAttr("Input_scale")) { + nvinfer1::DimsHW nv_ksize(1, 1); + fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n, + nv_ksize, weight, bias); + } else { + fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n, + weight, bias); + } + + if (op_desc.HasAttr("fc_out_threshold")) { + PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), true, + platform::errors::InvalidArgument( + "must have out threshold in multihead layers " + "in int8 mode")); + float out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold")); + engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); + if (qkv2context_plugin_int8) { + dp_probs = + BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0; + } + } + auto creator = GetPluginRegistry()->getPluginCreator( + "CustomQKVToContextPluginDynamic", "2"); + assert(creator != nullptr); + int type = static_cast(nvinfer1::DataType::kHALF); + if (qkv2context_plugin_int8 && + (engine_->precision() == AnalysisConfig::Precision::kInt8)) { + type = static_cast(nvinfer1::DataType::kINT8); + } + bool has_mask = true; + int var_seqlen = 1; + std::vector fields{ + {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1}, + {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, + 1}, + {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1}, + {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1}, + {"var_seqlen", &var_seqlen, 
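/*
 * transpose_weight_v2 / transpose_bias_v2 above regroup the fused QKV
 * parameters from [3, head_number, head_size, hidden_in] to
 * [head_number, 3, head_size, hidden_in], so each head's Q, K and V blocks
 * end up adjacent in memory for the fused varseqlen QKV kernel. A small
 * worked example with head_number = 2, treating every
 * (head_size * hidden_in) block as one chunk:
 *
 *   src: [Q0 Q1 | K0 K1 | V0 V1]   // 3 x head_number chunks
 *   dst: [Q0 K0 V0 | Q1 K1 V1]     // head_number x 3 chunks
 *
 *   // index math used above, with HH = head_size * hidden_in:
 *   // dst[n * 3 * HH + i * HH + hh] = src[i * head_number * HH + n * HH + hh]
 */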
nvinfer1::PluginFieldType::kINT32, + 1}}; + if (qkv2context_plugin_int8) { + fields.push_back({"dq_probs", &dp_probs, + nvinfer1::PluginFieldType::kFLOAT32, 1}); + } + nvinfer1::PluginFieldCollection* plugin_collection = + static_cast(malloc( + sizeof(*plugin_collection) + + fields.size() * + sizeof(nvinfer1::PluginField))); // remember to free + plugin_collection->nbFields = static_cast(fields.size()); + plugin_collection->fields = fields.data(); + + auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic", + plugin_collection); + free(plugin_collection); + + std::vector plugin_inputs; + plugin_inputs.emplace_back(fc_layer->getOutput(0)); + plugin_inputs.emplace_back(engine_->GetITensor("qkv_plugin_mask")); + plugin_inputs.emplace_back(engine_->GetITensor("pos_id")); + + auto max_seqlen_tensor = engine_->GetITensor("mask_id"); + auto* shuffle_layer = TRT_ENGINE_ADD_LAYER( + engine_, Shuffle, + *const_cast(max_seqlen_tensor)); + nvinfer1::Dims shape_dim; + shape_dim.nbDims = 1; + shape_dim.d[0] = -1; + shuffle_layer->setReshapeDimensions(shape_dim); + engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f); + plugin_inputs.emplace_back( + shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3 + + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + layer = plugin_layer; + } + } else { + PADDLE_ENFORCE_EQ( + input->getDimensions().nbDims, 3, + platform::errors::InvalidArgument( + "The Input dim of the SparseMultiheadMatMul should be 3, " + "but it's (%d) now.", + input->getDimensions().nbDims)); + // transpose weight_data from m * n to n * m + auto* input_bias_qk = + engine_->GetITensor(op_desc.Input("BiasQK").front()); + + half* half_data = nullptr; + void* w_data = nullptr; + if (with_fp16) { + half_data = new half[weight_t->numel()]; + for (int i = 0; i < weight_t->numel(); i++) { + half_data[i] = static_cast(weight_data[i]); + } + w_data = static_cast(half_data); + } else { + w_data = static_cast(weight_data); + } + + TensorRTEngine::Weight weight{ + with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + static_cast(w_data), static_cast(weight_t->numel())}; + weight.dims.assign({n, m}); + + half* half_bias_data = nullptr; + void* b_data = nullptr; + if (with_fp16) { + half_bias_data = new half[bias_t->numel()]; + for (int i = 0; i < bias_t->numel(); i++) { + half_bias_data[i] = static_cast(bias_data[i]); + } + b_data = static_cast(half_bias_data); + } else { + b_data = static_cast(bias_data); + } + + TensorRTEngine::Weight bias{ + with_fp16 ? 
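/*
 * In FP16 mode the converter makes host-side half copies of the FC weight
 * and bias before wrapping them in TensorRTEngine::Weight, mirroring the
 * element-wise cast used just above. A minimal sketch of that conversion,
 * assuming CUDA's half type:
 *
 *   std::vector<half> to_half(const float* src, int64_t n) {
 *     std::vector<half> dst(n);
 *     for (int64_t i = 0; i < n; ++i) dst[i] = static_cast<half>(src[i]);
 *     return dst;
 *   }
 */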
nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, + b_data, static_cast(bias_t->numel())}; + + // add shuffle before fc + nvinfer1::Dims reshape_before_fc_dim; + reshape_before_fc_dim.nbDims = 5; + reshape_before_fc_dim.d[0] = 0; + reshape_before_fc_dim.d[1] = 0; + reshape_before_fc_dim.d[2] = 0; + reshape_before_fc_dim.d[3] = 1; + reshape_before_fc_dim.d[4] = 1; + auto* reshape_before_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + if (op_desc.HasAttr("Input_scale")) { + engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0), + in_scale); + } + reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + reshape_before_fc_layer->setName( + ("shuffle_before_sparse_multihead_mamul(Output: " + output_name + + ")") + .c_str()); + + // add layer fc + nvinfer1::ILayer* fc_layer = nullptr; + if (op_desc.HasAttr("Input_scale")) { + plugin::SpmmPluginDynamic* plugin = + new_spmm_plugin(&weight, &bias, nvinfer1::DataType::kINT8, n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(reshape_before_fc_layer->getOutput(0)); + fc_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + } else { + plugin::SpmmPluginDynamic* plugin = + new_spmm_plugin(&weight, &bias, + with_fp16 ? nvinfer1::DataType::kHALF + : nvinfer1::DataType::kFLOAT, + n); + std::vector plugin_inputs; + plugin_inputs.emplace_back(reshape_before_fc_layer->getOutput(0)); + fc_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + } + + if (op_desc.HasAttr("fc_out_threshold")) { + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("fc_out_threshold"), true, + platform::errors::InvalidArgument( + "must have out threshold in multihead layers in int8 mode")); + float out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold")); + engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); + } + fc_layer->setName( + ("sparse_multihead_mamul_fc(Output: " + output_name + ")").c_str()); + + // no need to add shuffle after fc, just change it in + // QkvToContextPluginDynamic + + // add qkv to context + int head_size = hidden_out / head_number; + float scale = BOOST_GET_CONST(float, op_desc.GetAttr("alpha")); + + std::vector plugin_inputs; + plugin_inputs.push_back(fc_layer->getOutput(0)); + plugin_inputs.push_back(input_bias_qk); + + if (engine_->precision() == AnalysisConfig::Precision::kInt8) { + with_fp16 = true; + } + plugin::DynamicPluginTensorRT* plugin = + new plugin::QkvToContextPluginDynamic(hidden_in, head_number, + head_size, scale, with_fp16); + layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin); + } + } else { + PADDLE_THROW(platform::errors::Fatal( + "You are running the Ernie(Bert) model in static shape mode, which " + "is not supported for the time being.\n" + "You can use the config.SetTRTDynamicShapeInfo(...) 
interface to set " + "the shape information to run the dynamic shape mode.")); + } + RreplenishLayerAndOutput(layer, "multihead_matmul", {output_name}, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(sparse_multihead_matmul, + SparseMultiheadMatMulOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index dc7c77bc66a..57ac400dada 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -46,6 +46,12 @@ struct SimpleOpTypeSetTeller : public Teller { teller_set.insert("reshape2"); int8_teller_set.insert("reshape"); int8_teller_set.insert("reshape2"); +#endif +#if IS_TRT_VERSION_GE(8000) + teller_set.insert("sparse_fc"); + int8_teller_set.insert("sparse_fc"); + teller_set.insert("sparse_multihead_matmul"); + int8_teller_set.insert("sparse_multihead_matmul"); #endif } @@ -1753,6 +1759,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } } +#if IS_TRT_VERSION_GE(8000) + if (op_type == "sparse_fc" || op_type == "sparse_multihead_matmul") { + if (!with_dynamic_shape) { + VLOG(3) << "the sparse_fc and sparse_multihead_matmul does not support " + "static shape yet"; + return false; + } + } +#endif + if ((*teller)(op_type, desc, use_no_calib_int8)) return true; } diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 0377c82838b..5ee70ee8241 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -1,32 +1,41 @@ +list( + APPEND + TRT_FILES + trt_plugin.cc + split_op_plugin.cu + elementwise_op_plugin.cu + prelu_op_plugin.cu + gelu_op_plugin.cu + pool_op_plugin.cu + swish_op_plugin.cu + layer_norm_op_plugin.cu + instance_norm_op_plugin.cu + emb_eltwise_layernorm_plugin.cu + qkv_to_context_plugin.cu + skip_layernorm_op_plugin.cu + slice_op_plugin.cu + hard_swish_op_plugin.cu + stack_op_plugin.cu + anchor_generator_op_plugin.cu + yolo_box_op_plugin.cu + yolo_box_head_op_plugin.cu + roi_align_op_plugin.cu + gather_nd_op_plugin.cu + mish_op_plugin.cu + pool3d_op_plugin.cu + deformable_conv_op_plugin.cu + matmul_op_int8_plugin.cu + transformer_input_convert_plugin.cu + remove_padding_plugin.cu + recover_padding_plugin.cu) + +if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) + list(APPEND TRT_FILES spmm_plugin.cu) +endif() + nv_library( tensorrt_plugin - SRCS trt_plugin.cc - split_op_plugin.cu - elementwise_op_plugin.cu - prelu_op_plugin.cu - gelu_op_plugin.cu - pool_op_plugin.cu - swish_op_plugin.cu - layer_norm_op_plugin.cu - instance_norm_op_plugin.cu - emb_eltwise_layernorm_plugin.cu - qkv_to_context_plugin.cu - skip_layernorm_op_plugin.cu - slice_op_plugin.cu - hard_swish_op_plugin.cu - stack_op_plugin.cu - anchor_generator_op_plugin.cu - yolo_box_op_plugin.cu - yolo_box_head_op_plugin.cu - roi_align_op_plugin.cu - gather_nd_op_plugin.cu - mish_op_plugin.cu - pool3d_op_plugin.cu - deformable_conv_op_plugin.cu - matmul_op_int8_plugin.cu - transformer_input_convert_plugin.cu - remove_padding_plugin.cu - recover_padding_plugin.cu + SRCS ${TRT_FILES} DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) nv_test( diff --git a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu new file mode 100644 index 00000000000..4058d6564fc --- /dev/null +++ 
b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu @@ -0,0 +1,923 @@ +/* +Copyright (c) 2022, PaddlePaddle Authors, NVIDIA CORPORATION. All Rights +Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See +the License for the specific language governing permissions and +limitations under the License. +*/ +#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +nvinfer1::PluginFieldCollection SpmmPluginDynamicCreator::field_collection_{}; +std::vector SpmmPluginDynamicCreator::plugin_attr_; + +inline int getElementSize(nvinfer1::DataType type) { + switch (type) { + case nvinfer1::DataType::kFLOAT: + return 4; + case nvinfer1::DataType::kHALF: + return 2; + case nvinfer1::DataType::kINT8: + return 1; + default: + PADDLE_THROW(paddle::platform::errors::Fatal( + "getElementSize only supports [FLOAT|HALF|INT8]")); + } +} + +inline cudaDataType_t convertTrtType(nvinfer1::DataType type) { + switch (type) { + case nvinfer1::DataType::kFLOAT: + return CUDA_R_32F; + case nvinfer1::DataType::kHALF: + return CUDA_R_16F; + case nvinfer1::DataType::kINT8: + return CUDA_R_8I; + default: + PADDLE_THROW(paddle::platform::errors::Fatal( + "getElementSize only supports [FLOAT|HALF|INT8]")); + } +} + +inline void deserialize_value_size(void const** buffer, size_t* buffer_size, + void* value, size_t value_size) { + PADDLE_ENFORCE_GE( + *buffer_size, value_size, + platform::errors::InvalidArgument("buffer_size must >= value_size")); + memcpy(value, *buffer, value_size); + reinterpret_cast(*buffer) += value_size; + *buffer_size -= value_size; +} + +inline float round_scale(float x) { return std::floor(x + 0.5f); } + +inline void cudaFreeFunc(void* p) { + if (p) { + cudaFree(p); + } +} + +inline void convertAndCopy(const nvinfer1::Weights& src, + nvinfer1::DataType type, void* dest) { + PADDLE_ENFORCE_EQ(src.type == nvinfer1::DataType::kFLOAT || + src.type == nvinfer1::DataType::kHALF, + true, + platform::errors::InvalidArgument( + "convertAndCopy only supports src type [FLOAT|HALF]")); + PADDLE_ENFORCE_EQ( + type == nvinfer1::DataType::kFLOAT || type == nvinfer1::DataType::kHALF, + true, + platform::errors::InvalidArgument( + "convertAndCopy only supports src type [FLOAT|HALF]")); + + if (type == nvinfer1::DataType::kFLOAT) { + if (src.type == nvinfer1::DataType::kFLOAT) { + std::copy_n(static_cast(src.values), src.count, + static_cast(dest)); + } else { + for (int i = 0; i < src.count; ++i) { + static_cast(dest)[i] = + static_cast(static_cast(src.values)[i]); + } + } + } else { + if (src.type == nvinfer1::DataType::kHALF) { + std::copy_n(static_cast(src.values), src.count, + static_cast<__half*>(dest)); + } else { + for (int i = 0; i < src.count; ++i) { + static_cast<__half*>(dest)[i] = + static_cast<__half>(static_cast(src.values)[i]); + } + } + } +} + +SpmmPluginDynamic::cusparseLtContext::cusparseLtContext() { + paddle::platform::dynload::cusparseLtInit(&handle); +} + +SpmmPluginDynamic::cusparseLtContext::~cusparseLtContext() { + paddle::platform::dynload::cusparseLtDestroy(&handle); +} + +void 
SpmmPluginDynamic::cusparseLtContext::init( + int m, int n, int k, cudaDataType_t type, void* bias_ptr, + SpmmPluginDynamic::Activation activation) { + /* + 1. Init matrix descriptors (matA, matB, matC) + 2. Init matrix multiplication descriptor (matmul) + 3. Set activation and bias attribute of matmul + 4. Init algorithm selection descriptor (alg_sel) + 5. Init plan descriptor (plan) + */ + PADDLE_ENFORCE_EQ( + is_initialized, false, + platform::errors::InvalidArgument( + "Descriptor should be destroyed before calling create")); + constexpr int alignment = 16; + cusparseComputeType compute_type; + switch (type) { + case CUDA_R_32F: + compute_type = CUSPARSE_COMPUTE_TF32; + break; + case CUDA_R_16F: + compute_type = CUSPARSE_COMPUTE_16F; + break; + case CUDA_R_8I: + compute_type = CUSPARSE_COMPUTE_32I; + break; + default: + PADDLE_THROW(paddle::platform::errors::Fatal( + "cusparLtContext only supports data type" + "[CUDA_R_32F|CUDA_R_16F|CUDA_R_8I]")); + } + paddle::platform::dynload::cusparseLtDenseDescriptorInit( + &handle, &matA, m, k, k, alignment, type, CUSPARSE_ORDER_ROW); + paddle::platform::dynload::cusparseLtStructuredDescriptorInit( + &handle, &matB, n, k, k, alignment, type, CUSPARSE_ORDER_ROW, + CUSPARSELT_SPARSITY_50_PERCENT); + paddle::platform::dynload::cusparseLtDenseDescriptorInit( + &handle, &matC, m, n, n, alignment, type, CUSPARSE_ORDER_ROW); + paddle::platform::dynload::cusparseLtMatmulDescriptorInit( + &handle, &matmul, CUSPARSE_OPERATION_NON_TRANSPOSE, + CUSPARSE_OPERATION_TRANSPOSE, &matA, &matB, &matC, &matC, compute_type); + if (activation == SpmmPluginDynamic::Activation::kRelu) { + int true_value = 1; + float relu_upper_bound = std::numeric_limits::max(); + float relu_threshold = 0.0f; + paddle::platform::dynload::cusparseLtMatmulDescSetAttribute( + &handle, &matmul, CUSPARSELT_MATMUL_ACTIVATION_RELU, &true_value, + sizeof(true_value)); + paddle::platform::dynload::cusparseLtMatmulDescSetAttribute( + &handle, &matmul, CUSPARSELT_MATMUL_ACTIVATION_RELU_UPPERBOUND, + &relu_upper_bound, sizeof(relu_upper_bound)); + paddle::platform::dynload::cusparseLtMatmulDescSetAttribute( + &handle, &matmul, CUSPARSELT_MATMUL_ACTIVATION_RELU_THRESHOLD, + &relu_threshold, sizeof(relu_threshold)); + } else if (activation == SpmmPluginDynamic::Activation::kGelu) { + int true_value = 1; + paddle::platform::dynload::cusparseLtMatmulDescSetAttribute( + &handle, &matmul, CUSPARSELT_MATMUL_ACTIVATION_GELU, &true_value, + sizeof(true_value)); + } else { + PADDLE_ENFORCE_EQ( + activation, SpmmPluginDynamic::Activation::kNone, + platform::errors::InvalidArgument("Received unknown activation")); + } + if (bias_ptr != nullptr) { + paddle::platform::dynload::cusparseLtMatmulDescSetAttribute( + &handle, &matmul, CUSPARSELT_MATMUL_BIAS_POINTER, &bias_ptr, + sizeof(bias_ptr)); + } + paddle::platform::dynload::cusparseLtMatmulAlgSelectionInit( + &handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT); + int alg = 0; + paddle::platform::dynload::cusparseLtMatmulAlgSetAttribute( + &handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg, sizeof(alg)); + paddle::platform::dynload::cusparseLtMatmulGetWorkspace(&handle, &alg_sel, + &workspace_size); + paddle::platform::dynload::cusparseLtMatmulPlanInit(&handle, &plan, &matmul, + &alg_sel, workspace_size); + is_initialized = true; +} + +void SpmmPluginDynamic::cusparseLtContext::setAlgo(int alg) { + PADDLE_ENFORCE_EQ( + is_initialized, true, + platform::errors::InvalidArgument( + "Descriptor should be initialized before setting algorithm")); + 
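/*
 * init() above configures the cuSPARSELt matmul as C(m, n) = A(m, k) x
 * B(n, k)^T, where A is the dense activation, B is the 2:4 structured (50%
 * sparse) weight, and bias/activation are attached as matmul-descriptor
 * attributes. setAlgo() here rebuilds the plan for a specific algorithm id,
 * which is how the id found by cusparseLtMatmulSearch() at configure time is
 * replayed on deserialized engines. A rough sketch of the overall call order
 * this plugin relies on (arguments elided):
 *
 *   cusparseLtInit(&handle);
 *   cusparseLtDenseDescriptorInit(...);        // matA: m x k
 *   cusparseLtStructuredDescriptorInit(...);   // matB: n x k, 50% sparsity
 *   cusparseLtDenseDescriptorInit(...);        // matC: m x n
 *   cusparseLtMatmulDescriptorInit(...);       // opA = N, opB = T
 *   cusparseLtMatmulAlgSelectionInit(...);     // CUSPARSELT_MATMUL_ALG_DEFAULT
 *   cusparseLtMatmulPlanInit(...);
 *   // at run time:
 *   cusparseLtMatmul(&handle, &plan, &alpha, A, B_compressed, &beta, C, C,
 *                    workspace, &stream, 1);
 */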
paddle::platform::dynload::cusparseLtMatmulAlgSetAttribute( + &handle, &alg_sel, CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg, sizeof(alg)); + paddle::platform::dynload::cusparseLtMatmulGetWorkspace(&handle, &alg_sel, + &workspace_size); + paddle::platform::dynload::cusparseLtMatmulPlanDestroy(&plan); + paddle::platform::dynload::cusparseLtMatmulPlanInit(&handle, &plan, &matmul, + &alg_sel, workspace_size); +} + +void SpmmPluginDynamic::cusparseLtContext::destroy() { + PADDLE_ENFORCE_EQ(is_initialized, true, + platform::errors::InvalidArgument( + "cusparseLtContext is destroy before init")); + paddle::platform::dynload::cusparseLtMatmulPlanDestroy(&plan); + paddle::platform::dynload::cusparseLtMatDescriptorDestroy(&matC); + paddle::platform::dynload::cusparseLtMatDescriptorDestroy(&matB); + paddle::platform::dynload::cusparseLtMatDescriptorDestroy(&matA); + is_initialized = false; +} + +void SpmmPluginDynamic::cusparseLtContext::compressMatB( + int n, int k, cudaDataType_t type, void* src, void** dest, + size_t* compressed_size) { + PADDLE_ENFORCE_EQ( + is_initialized, false, + platform::errors::InvalidArgument( + "cusparseLtContext should not initialized before compressMatB")); + PADDLE_ENFORCE_EQ(*dest, nullptr, + platform::errors::InvalidArgument( + "before compressMatB *dest must be nullptr")); + constexpr int alignment = 16; + paddle::platform::dynload::cusparseLtStructuredDescriptorInit( + &handle, &matB, n, k, k, alignment, type, CUSPARSE_ORDER_ROW, + CUSPARSELT_SPARSITY_50_PERCENT); + + paddle::platform::dynload::cusparseLtSpMMACompressedSize2(&handle, &matB, + compressed_size); + cudaMalloc(dest, *compressed_size); + paddle::platform::dynload::cusparseLtSpMMACompress2( + &handle, &matB, 0, CUSPARSE_OPERATION_TRANSPOSE, src, *dest, nullptr); + paddle::platform::dynload::cusparseLtMatDescriptorDestroy(&matB); +} + +// Constructor for new plugin +SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, + const nvinfer1::DataType precision, + const int out_dim, + const nvinfer1::Weights& weight, + const nvinfer1::Weights& bias, + Activation activation) + : layer_name_(layer_name), + precision_(precision), + out_dim_(out_dim), + k_(0), + m_max_(0), + is_configured_(false), + optim_alg_(0), + weight_scale_(1.0f), + weight_compressed_(nullptr), + weight_compressed_dev_(nullptr), + weight_compressed_dev_global_(nullptr), + compressed_size_(0), + has_bias_(false), + bias_(nullptr), + bias_dev_(nullptr), + activation_(activation) { + /* + 1. Convert weight precision (on host) + 2. (Int8) Calculate scale and scale the weight (on host) + 3. Copy weight to device + 4. Compress the weight (on device) + 5. Reset the shared_ptr "weight_compressed_dev_global_" to the compressed + weight + 6. Copy the compressed weight to host + 7. Convert bias precision and copy (on host) + */ + precision_size_ = getElementSize(precision); + element_size_ = + (precision_ == nvinfer1::DataType::kINT8 ? 
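/*
 * For INT8 the weight is kept as float on the host (element_size_ stays
 * sizeof(float) just below) and is quantized symmetrically per tensor before
 * the device copy: scale = max|w| / 127 and each element is rounded via
 * round_scale(). A minimal sketch of that quantization, assuming a plain
 * float buffer `w` of length `count`:
 *
 *   float max_abs = 0.0f;
 *   for (int64_t i = 0; i < count; ++i)
 *     max_abs = std::max(max_abs, std::abs(w[i]));
 *   const float scale = max_abs / 127.0f;
 *   std::vector<int8_t> q(count);
 *   for (int64_t i = 0; i < count; ++i)
 *     q[i] = static_cast<int8_t>(std::floor(w[i] / scale + 0.5f));
 */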
4 : precision_size_); + + PADDLE_ENFORCE_EQ( + weight.count % out_dim, 0, + platform::errors::InvalidArgument( + "The size of weight should be divided by output dimension.")); + k_ = weight.count / out_dim; + PADDLE_ENFORCE_EQ( + weight.type == nvinfer1::DataType::kFLOAT || + weight.type == nvinfer1::DataType::kHALF, + true, + platform::errors::InvalidArgument( + "SpmmPluginDynamic only supports weight of type [FLOAT|HALF]")); + nvinfer1::DataType weight_type; + if (precision_ == nvinfer1::DataType::kINT8) { + weight_type = nvinfer1::DataType::kFLOAT; + } else { + weight_type = precision_; + } + std::vector weight_host(element_size_ * out_dim_ * k_); + convertAndCopy(weight, weight_type, weight_host.data()); + void* weight_dev{nullptr}; + cudaMalloc(reinterpret_cast(&weight_dev), + precision_size_ * out_dim_ * k_); + if (precision == nvinfer1::DataType::kINT8) { + float max_weight{0.0f}; + for (int i = 0; i < weight.count; ++i) { + float local_abs = + std::abs(reinterpret_cast(weight_host.data())[i]); + max_weight = std::max(max_weight, local_abs); + } + weight_scale_ = max_weight / 127.0f; + std::vector scale_buffer(weight.count); + for (int i = 0; i < weight.count; ++i) { + scale_buffer[i] = static_cast( + round_scale(reinterpret_cast(weight_host.data())[i] / + weight_scale_)); + } + cudaMemcpy(weight_dev, scale_buffer.data(), precision_size_ * weight.count, + cudaMemcpyHostToDevice); + } else { + cudaMemcpy(weight_dev, weight_host.data(), precision_size_ * weight.count, + cudaMemcpyHostToDevice); + } + spmm_context_.compressMatB(out_dim_, k_, convertTrtType(precision_), + weight_dev, &weight_compressed_dev_, + &compressed_size_); + weight_compressed_ = new char[compressed_size_]; + weight_compressed_dev_global_.reset(weight_compressed_dev_, cudaFreeFunc); + cudaMemcpy(weight_compressed_, weight_compressed_dev_global_.get(), + compressed_size_, cudaMemcpyDeviceToHost); + has_bias_ = (bias.count != 0); + if (has_bias_) { + if (bias.count != out_dim) { + PADDLE_THROW(paddle::platform::errors::Fatal( + "The dimension of bias should be equal to output dimension")); + } + if (precision_ == nvinfer1::DataType::kHALF) { + bias_ = new half[out_dim_]; + convertAndCopy(bias, nvinfer1::DataType::kHALF, bias_); + } else { + bias_ = new float[out_dim_]; + convertAndCopy(bias, nvinfer1::DataType::kFLOAT, bias_); + } + } + + cudaFree(weight_dev); +} + +// Constructor for clone +SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name, + const nvinfer1::DataType precision, + const int out_dim, const int k, + const void* weight_compressed, + size_t compressed_size, const void* bias, + bool is_configured, const int m_max, + const int optim_alg, Activation activation) + : layer_name_(layer_name), + precision_(precision), + out_dim_(out_dim), + k_(k), + m_max_(m_max), + is_configured_(is_configured), + optim_alg_(optim_alg), + weight_scale_(1.0f), + weight_compressed_(nullptr), + weight_compressed_dev_global_(nullptr), + compressed_size_(compressed_size), + has_bias_(false), + bias_(nullptr), + bias_dev_(nullptr), + activation_(activation) { + /* + 1. Copy the compressed weight (on host) + 2. Copy the bias (on host) + 3. (Configured) Copy the bias to device + 4. (Configured) Init cuSPARSELt descriptors + */ + precision_size_ = getElementSize(precision); + element_size_ = + (precision_ == nvinfer1::DataType::kINT8 ? 
4 : precision_size_); + // Each plugin has a copy of compressed weight on host, while sharing the + // compressed weights on device using std::shared_ptr + weight_compressed_ = new char[compressed_size]; + std::copy_n(static_cast(weight_compressed), compressed_size, + static_cast(weight_compressed_)); + + has_bias_ = (bias != nullptr); + if (has_bias_) { + // Each plugin has a copy of bias + bias_ = new float[out_dim_]; + std::copy_n(static_cast(bias), sizeof(float) * out_dim_, + static_cast(bias_)); + if (is_configured_) { + cudaMalloc(reinterpret_cast(&bias_dev_), + sizeof(float) * out_dim_); + cudaMemcpy(bias_dev_, bias_, sizeof(float) * out_dim_, + cudaMemcpyHostToDevice); + } + } + + if (is_configured_) { + cudaDataType_t dataType = convertTrtType(precision_); + spmm_context_.init(m_max_, out_dim_, k_, dataType, bias_dev_, activation_); + spmm_context_.setAlgo(optim_alg_); + } +} + +SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data, + size_t length) + : layer_name_(name), + weight_compressed_(nullptr), + weight_compressed_dev_(nullptr), + weight_compressed_dev_global_(nullptr), + bias_(nullptr), + bias_dev_(nullptr) { + DeserializeValue(&data, &length, &precision_); + DeserializeValue(&data, &length, &precision_size_); + DeserializeValue(&data, &length, &element_size_); + DeserializeValue(&data, &length, &out_dim_); + DeserializeValue(&data, &length, &k_); + DeserializeValue(&data, &length, &m_max_); + DeserializeValue(&data, &length, &is_configured_); + DeserializeValue(&data, &length, &optim_alg_); + DeserializeValue(&data, &length, &weight_scale_); + DeserializeValue(&data, &length, &compressed_size_); + DeserializeValue(&data, &length, &has_bias_); + DeserializeValue(&data, &length, &activation_); + + PADDLE_ENFORCE_EQ(is_configured_, true, + platform::errors::InvalidArgument( + "Deserialize data should be configured")); + weight_compressed_ = new char[compressed_size_]; + deserialize_value_size(&data, &length, weight_compressed_, compressed_size_); + cudaMalloc(reinterpret_cast(&weight_compressed_dev_), + compressed_size_); + cudaMemcpy(weight_compressed_dev_, weight_compressed_, compressed_size_, + cudaMemcpyHostToDevice); + weight_compressed_dev_global_.reset(weight_compressed_dev_, cudaFreeFunc); + + if (has_bias_) { + bias_ = new float[out_dim_]; + deserialize_value_size(&data, &length, bias_, sizeof(float) * out_dim_); + cudaMalloc(reinterpret_cast(&bias_dev_), sizeof(float) * out_dim_); + cudaMemcpy(bias_dev_, bias_, sizeof(float) * out_dim_, + cudaMemcpyHostToDevice); + } + + if (is_configured_) { + cudaDataType_t dataType = convertTrtType(precision_); + spmm_context_.init(m_max_, out_dim_, k_, dataType, bias_dev_, activation_); + spmm_context_.setAlgo(optim_alg_); + } +} + +nvinfer1::IPluginV2DynamicExt* SpmmPluginDynamic::clone() const noexcept { + try { + auto* p = + new SpmmPluginDynamic(layer_name_, precision_, out_dim_, k_, + weight_compressed_, compressed_size_, bias_, + is_configured_, m_max_, optim_alg_, activation_); + p->weight_scale_ = weight_scale_; + p->weight_compressed_dev_global_ = weight_compressed_dev_global_; + p->setPluginNamespace(namespace_.c_str()); + return p; + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + } + return nullptr; +} + +nvinfer1::DimsExprs SpmmPluginDynamic::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept { + int nbDims = inputs[0].nbDims; + try { + PADDLE_ENFORCE_EQ(nbInputs, 1, + 
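/*
 * getOutputDimensions keeps the input layout and only swaps the channel
 * dimension for out_dim_: a 5-D input (B, S, k, 1, 1) maps to
 * (B, S, out_dim_, 1, 1) and a 4-D input (B*S, k, 1, 1) maps to
 * (B*S, out_dim_, 1, 1). The trailing 1 x 1 dims come from the Shuffle layer
 * the converter inserts before this plugin to reshape the FC input into a
 * pseudo-NCHW tensor.
 */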
platform::errors::InvalidArgument( + "SpmmPluginDynamic's nbInputs is invalid")); + PADDLE_ENFORCE_EQ(outputIndex, 0, + platform::errors::InvalidArgument( + "SpmmPluginDynamic's outputIndex is invalid")); + if (nbDims == 5) { + int nbDims = inputs[0].nbDims; + PADDLE_ENFORCE_EQ( + inputs[0].d[3]->getConstantValue(), 1, + platform::errors::InvalidArgument("now the input d[3] should be 1")); + PADDLE_ENFORCE_EQ( + inputs[0].d[4]->getConstantValue(), 1, + platform::errors::InvalidArgument("now the input d[4] should be 1")); + nvinfer1::DimsExprs ret; + ret.nbDims = nbDims; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = exprBuilder.constant(out_dim_); + ret.d[3] = exprBuilder.constant(1); + ret.d[4] = exprBuilder.constant(1); + return ret; + } else if (nbDims == 4) { + int nbDims = inputs[0].nbDims; + PADDLE_ENFORCE_EQ( + inputs[0].d[2]->getConstantValue(), 1, + platform::errors::InvalidArgument("now the input d[2] should be 1")); + PADDLE_ENFORCE_EQ( + inputs[0].d[3]->getConstantValue(), 1, + platform::errors::InvalidArgument("now the input d[3] should be 1")); + nvinfer1::DimsExprs ret; + ret.nbDims = nbDims; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = exprBuilder.constant(out_dim_); + ret.d[2] = exprBuilder.constant(1); + ret.d[3] = exprBuilder.constant(1); + + return ret; + } else { + PADDLE_THROW(paddle::platform::errors::Fatal("nbDims should be 4 or 5")); + } + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + } + return nvinfer1::DimsExprs{}; +} + +bool SpmmPluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, + int nbOutputs) noexcept { + PADDLE_ENFORCE_EQ(nbInputs, 1, + platform::errors::InvalidArgument( + "SpmmPluginDynamic's nbInputs should be 1")); + PADDLE_ENFORCE_EQ(nbOutputs, 1, + platform::errors::InvalidArgument( + "SpmmPluginDynamic's nbOutputs should be 1")); + + const nvinfer1::PluginTensorDesc& in = inOut[pos]; + if (pos == 0) { + return (in.type == precision_) && + (in.format == nvinfer1::TensorFormat::kLINEAR); + } + const nvinfer1::PluginTensorDesc& prev = inOut[pos - 1]; + + return in.type == prev.type && in.format == prev.format; +} + +void SpmmPluginDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) noexcept { + /* + The following steps are executed if not configured. + 1. (INT8) Scale the bias (on host) + 2. Copy the bias to device + 3. 
Search the optimal algorithm + */ + try { + PADDLE_ENFORCE_EQ(nbInputs, 1, + platform::errors::InvalidArgument( + "SpmmPluginDynamic's nbInputs should be 1")); + PADDLE_ENFORCE_EQ(nbOutputs, 1, + platform::errors::InvalidArgument( + "SpmmPluginDynamic's nbOutputs should be 1")); + PADDLE_ENFORCE_EQ(precision_, inputs[0].desc.type, + platform::errors::InvalidArgument( + "precision_ should be equal to inputs[0].desc.type")); + const auto& inDims0 = inputs[0].desc.dims; + if (inDims0.nbDims == 5) { + PADDLE_ENFORCE_EQ( + inDims0.nbDims, 5, + platform::errors::InvalidArgument("inDims0.nbDims should be 5")); + PADDLE_ENFORCE_EQ(k_, inDims0.d[2], + platform::errors::InvalidArgument( + "inDims0.d[2] should be equals to k")); + PADDLE_ENFORCE_EQ( + inDims0.d[3], 1, + platform::errors::InvalidArgument("inDims0.d[3] should be 1")); + PADDLE_ENFORCE_EQ( + inDims0.d[4], 1, + platform::errors::InvalidArgument("inDims0.d[4] should be 1")); + const int BS = inputs->max.d[0]; + const int Seq = inputs->max.d[1]; + m_max_ = BS * Seq; + } else if (inDims0.nbDims == 4) { + PADDLE_ENFORCE_EQ( + inDims0.nbDims, 4, + platform::errors::InvalidArgument("inDims0.nbDims should be 4")); + PADDLE_ENFORCE_EQ(k_, inDims0.d[1], + platform::errors::InvalidArgument( + "inDims0.d[1] should be equals to k")); + PADDLE_ENFORCE_EQ( + inDims0.d[2], 1, + platform::errors::InvalidArgument("inDims0.d[2] should be 1")); + PADDLE_ENFORCE_EQ( + inDims0.d[3], 1, + platform::errors::InvalidArgument("inDims0.d[3] should be 1")); + const int BS_Seq = inputs->max.d[0]; + m_max_ = BS_Seq; + } + if (is_configured_) { + return; + } + + if (has_bias_) { + if (inputs->desc.type == nvinfer1::DataType::kINT8) { + for (int i = 0; i < out_dim_; ++i) { + static_cast(bias_)[i] = + static_cast(bias_)[i] / outputs->desc.scale; + } + } + cudaMalloc(reinterpret_cast(&bias_dev_), + sizeof(float) * out_dim_); + cudaMemcpy(bias_dev_, bias_, sizeof(float) * out_dim_, + cudaMemcpyHostToDevice); + } + cudaDataType_t dataType = convertTrtType(precision_); + spmm_context_.init(m_max_, out_dim_, k_, dataType, bias_dev_, activation_); + + void* dA; + void* dC; + void* d_workspace; + float alpha{1.0f}; + float beta{0.0f}; + if (precision_ == nvinfer1::DataType::kINT8) { + alpha = inputs->desc.scale * weight_scale_ / outputs->desc.scale; + } + cudaMalloc(reinterpret_cast(&dA), m_max_ * k_ * sizeof(dataType)); + cudaMalloc(reinterpret_cast(&dC), + m_max_ * out_dim_ * sizeof(dataType)); + cudaMalloc(reinterpret_cast(&d_workspace), + spmm_context_.workspace_size); + paddle::platform::dynload::cusparseLtMatmulSearch( + &spmm_context_.handle, &spmm_context_.plan, &alpha, dA, + weight_compressed_dev_global_.get(), &beta, dC, dC, d_workspace, + nullptr, 0); + paddle::platform::dynload::cusparseLtMatmulAlgGetAttribute( + &spmm_context_.handle, &spmm_context_.alg_sel, + CUSPARSELT_MATMUL_ALG_CONFIG_ID, &optim_alg_, sizeof(optim_alg_)); + cudaFree(dA); + cudaFree(dC); + cudaFree(d_workspace); + + is_configured_ = true; + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + } +} + +size_t SpmmPluginDynamic::getWorkspaceSize( + const nvinfer1::PluginTensorDesc* inputs, int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept { + return spmm_context_.workspace_size; +} + +int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, + void* workSpace, cudaStream_t stream) noexcept { + try { + 
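/*
 * configurePlugin() above does the expensive one-time work: the bias is
 * rescaled for INT8 output and copied to the device, the cuSPARSELt plan is
 * initialized for the maximum problem size (m_max_ = max BS * Seq), and
 * cusparseLtMatmulSearch() is run once on scratch buffers to pick the best
 * kernel, whose id is stored in optim_alg_ and serialized with the engine.
 * enqueue() only dispatches on the I/O data type and replays
 * cusparseLtMatmul() with the prepared plan, the compressed weight and the
 * TensorRT-provided workspace. For INT8, alpha = in_scale * weight_scale /
 * out_scale folds the three quantization scales into the GEMM (e.g. with
 * in_scale = 0.02, weight_scale = 0.01, out_scale = 0.04, alpha = 0.005).
 */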
PADDLE_ENFORCE_EQ(is_configured_, true, + platform::errors::InvalidArgument( + "The plugin is not configured before enqueue")); + if (inputDesc->dims.nbDims == 5) { + PADDLE_ENFORCE_EQ( + k_, inputDesc->dims.d[2], + platform::errors::InvalidArgument("k_ == inputDesc->dims.d[2]")); + } else if (inputDesc->dims.nbDims == 4) { + PADDLE_ENFORCE_EQ( + k_, inputDesc->dims.d[1], + platform::errors::InvalidArgument("k_ == inputDesc->dims.d[1]")); + } + float alpha = 1.0f; + float beta = 0.0f; + if (inputDesc->type == nvinfer1::DataType::kFLOAT) { + const auto* const input = static_cast(inputs[0]); + auto* output = static_cast(outputs[0]); + auto* weight_compressed_dev_p_ = weight_compressed_dev_global_.get(); + cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul( + &spmm_context_.handle, &spmm_context_.plan, &alpha, input, + weight_compressed_dev_p_, &beta, output, output, workSpace, &stream, + 1); + return status != CUSPARSE_STATUS_SUCCESS; + } else if (inputDesc->type == nvinfer1::DataType::kHALF) { + const auto* const input = static_cast(inputs[0]); + auto* output = static_cast(outputs[0]); + auto* weight_compressed_dev_p_ = weight_compressed_dev_global_.get(); + cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul( + &spmm_context_.handle, &spmm_context_.plan, &alpha, input, + weight_compressed_dev_p_, &beta, output, output, workSpace, &stream, + 1); + return status != CUSPARSE_STATUS_SUCCESS; + } else if (inputDesc->type == nvinfer1::DataType::kINT8) { + alpha = inputDesc->scale * weight_scale_ / outputDesc->scale; + const auto* const input = static_cast(inputs[0]); + auto* output = static_cast(outputs[0]); + auto* weight_compressed_dev_p_ = weight_compressed_dev_global_.get(); + cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul( + &spmm_context_.handle, &spmm_context_.plan, &alpha, input, + weight_compressed_dev_p_, &beta, output, output, workSpace, &stream, + 1); + return status != CUSPARSE_STATUS_SUCCESS; + } else { + PADDLE_THROW(paddle::platform::errors::Fatal( + "Unsupported type error, expected [kHALF,kFLOAT], but received %d", + static_cast(precision_))); + } + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + } + return -1; +} + +nvinfer1::DataType SpmmPluginDynamic::getOutputDataType( + int index, const nvinfer1::DataType* inputTypes, + int nbInputs) const noexcept { + PADDLE_ENFORCE_EQ(index, 0, + platform::errors::InvalidArgument( + "SpmmPluginDynamic's index should be 0")); + PADDLE_ENFORCE_EQ(nbInputs, 1, + platform::errors::InvalidArgument( + "SpmmPluginDynamic's nbInputs should be 1")); + PADDLE_ENFORCE_EQ(inputTypes[0] == nvinfer1::DataType::kFLOAT || + inputTypes[0] == nvinfer1::DataType::kHALF || + inputTypes[0] == nvinfer1::DataType::kINT8, + true, + platform::errors::InvalidArgument( + "SpmmPluginDynamic is not support this format now")); + + return inputTypes[0]; +} + +const char* SpmmPluginDynamic::getPluginType() const noexcept { + return "SpmmPluginDynamic"; +} + +const char* SpmmPluginDynamic::getPluginVersion() const noexcept { return "1"; } + +int SpmmPluginDynamic::getNbOutputs() const noexcept { return 1; } + +int SpmmPluginDynamic::initialize() noexcept { return 0; } + +void SpmmPluginDynamic::terminate() noexcept {} + +size_t SpmmPluginDynamic::getSerializationSize() const noexcept { + return compressed_size_ + (has_bias_ ? 
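/*
 * The serialized blob written by serialize() below is laid out as: the POD
 * fields (precision_, precision_size_, element_size_, out_dim_, k_, m_max_,
 * is_configured_, optim_alg_, weight_scale_, compressed_size_, has_bias_,
 * activation_), followed by the compressed weight bytes and, if present, the
 * float bias. The deserialization constructor reads them back in exactly
 * that order, which is why getSerializationSize() sums the same terms here.
 */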
sizeof(float) * out_dim_ : 0) + + sizeof(precision_) + sizeof(precision_size_) + sizeof(element_size_) + + sizeof(out_dim_) + sizeof(k_) + sizeof(m_max_) + + sizeof(is_configured_) + sizeof(optim_alg_) + sizeof(weight_scale_) + + sizeof(compressed_size_) + sizeof(has_bias_) + sizeof(activation_); +} + +void SpmmPluginDynamic::serialize(void* buffer) const noexcept { + SerializeValue(&buffer, precision_); + SerializeValue(&buffer, precision_size_); + SerializeValue(&buffer, element_size_); + SerializeValue(&buffer, out_dim_); + SerializeValue(&buffer, k_); + SerializeValue(&buffer, m_max_); + SerializeValue(&buffer, is_configured_); + SerializeValue(&buffer, optim_alg_); + SerializeValue(&buffer, weight_scale_); + SerializeValue(&buffer, compressed_size_); + SerializeValue(&buffer, has_bias_); + SerializeValue(&buffer, activation_); + char* d = static_cast(buffer); + std::copy_n(static_cast(weight_compressed_), compressed_size_, + d); + if (has_bias_) { + d += compressed_size_; + std::copy_n(static_cast(bias_), out_dim_ * sizeof(float), d); + } +} + +void SpmmPluginDynamic::destroy() noexcept { + delete[] reinterpret_cast(weight_compressed_); + if (has_bias_) { + cudaFree(bias_dev_); + } + if (is_configured_) { + spmm_context_.destroy(); + } + delete this; +} + +void SpmmPluginDynamic::setPluginNamespace(const char* libNamespace) noexcept { + try { + namespace_ = libNamespace; + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + } +} + +const char* SpmmPluginDynamic::getPluginNamespace() const noexcept { + return namespace_.c_str(); +} + +inline nvinfer1::DataType fieldTypeToDataType( + const nvinfer1::PluginFieldType ftype) { + switch (ftype) { + case nvinfer1::PluginFieldType::kFLOAT32: + return nvinfer1::DataType::kFLOAT; + case nvinfer1::PluginFieldType::kFLOAT16: + return nvinfer1::DataType::kHALF; + case nvinfer1::PluginFieldType::kINT32: + return nvinfer1::DataType::kINT32; + case nvinfer1::PluginFieldType::kINT8: + return nvinfer1::DataType::kINT8; + default: + PADDLE_THROW(paddle::platform::errors::Fatal( + "No corresponding datatype for plugin field type")); + } +} + +SpmmPluginDynamicCreator::SpmmPluginDynamicCreator() { + plugin_attr_.emplace_back(nvinfer1::PluginField( + "type_id", nullptr, nvinfer1::PluginFieldType::kINT32, 1)); + plugin_attr_.emplace_back(nvinfer1::PluginField( + "out_dim", nullptr, nvinfer1::PluginFieldType::kINT32, 1)); + plugin_attr_.emplace_back(nvinfer1::PluginField( + "weight", nullptr, nvinfer1::PluginFieldType::kFLOAT32, 1)); + plugin_attr_.emplace_back(nvinfer1::PluginField( + "bias", nullptr, nvinfer1::PluginFieldType::kFLOAT32, 1)); + plugin_attr_.emplace_back(nvinfer1::PluginField( + "activation_id", nullptr, nvinfer1::PluginFieldType::kINT8, 1)); + + field_collection_.nbFields = plugin_attr_.size(); + field_collection_.fields = plugin_attr_.data(); +} + +const char* SpmmPluginDynamicCreator::getPluginName() const noexcept { + return "SpmmPluginDynamic"; +} + +const char* SpmmPluginDynamicCreator::getPluginVersion() const noexcept { + return "1"; +} + +const nvinfer1::PluginFieldCollection* +SpmmPluginDynamicCreator::getFieldNames() noexcept { + return &field_collection_; +} + +nvinfer1::IPluginV2* SpmmPluginDynamicCreator::createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept { + try { + int type_id = -1; + int out_dim = 0; + nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, nullptr, 0ll}; + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0ll}; + int activation_id = -1; + + 
for (int i = 0; i < fc->nbFields; i++) { + std::string field_name(fc->fields[i].name); + if (field_name.compare("type_id") == 0) { + type_id = static_cast(fc->fields[i].data)[0]; + } else if (field_name.compare("out_dim") == 0) { + out_dim = static_cast(fc->fields[i].data)[0]; + } else if (field_name.compare("weight") == 0) { + weight.type = fieldTypeToDataType(fc->fields[i].type); + weight.values = fc->fields[i].data; + weight.count = fc->fields[i].length; + } else if (field_name.compare("bias") == 0) { + bias.type = fieldTypeToDataType(fc->fields[i].type); + bias.values = fc->fields[i].data; + bias.count = fc->fields[i].length; + } else if (field_name.compare("activation_id") == 0) { + activation_id = static_cast(fc->fields[i].data)[0]; + } else { + PADDLE_THROW(paddle::platform::errors::Fatal("Unsupport plugin field")); + } + } + + PADDLE_ENFORCE_NE( + type_id, -1, + platform::errors::InvalidArgument( + "SpmmPluginDynamicCreator's type_id should not be -1")); + PADDLE_ENFORCE_NE( + out_dim, 0, + platform::errors::InvalidArgument( + "SpmmPluginDynamicCreator's out_dim should not be 0")); + PADDLE_ENFORCE_NE( + weight.count, 0, + platform::errors::InvalidArgument( + "SpmmPluginDynamicCreator's weight size should not be 0")); + PADDLE_ENFORCE_NE( + activation_id, -1, + platform::errors::InvalidArgument( + "SpmmPluginDynamicCreator's activation_id should not be -1")); + nvinfer1::DataType type = static_cast(type_id); + SpmmPluginDynamic::Activation activation = + static_cast(activation_id); + return new SpmmPluginDynamic(name, type, out_dim, weight, bias, activation); + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + } + return nullptr; +} + +nvinfer1::IPluginV2* SpmmPluginDynamicCreator::deserializePlugin( + const char* name, const void* serialData, size_t serialLength) noexcept { + // This object will be deleted when the network is destroyed, which will + // call SpmmPluginDynamic::destroy() + try { + return new SpmmPluginDynamic(name, serialData, serialLength); + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + } + return nullptr; +} + +void SpmmPluginDynamicCreator::setPluginNamespace( + const char* libNamespace) noexcept { + try { + namespace_ = libNamespace; + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + } +} + +const char* SpmmPluginDynamicCreator::getPluginNamespace() const noexcept { + return namespace_.c_str(); +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h new file mode 100644 index 00000000000..60c3773f930 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h @@ -0,0 +1,158 @@ +/* Copyright (c) 2022, PaddlePaddle Authors, NVIDIA CORPORATION. All rights +reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#pragma once + +#include +#include +#include +#include +#include + +#include "NvInfer.h" +#include "NvInferPlugin.h" +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/platform/dynload/cusparseLt.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class SpmmPluginDynamic : public nvinfer1::IPluginV2DynamicExt { + public: + enum class Activation { kNone, kRelu, kGelu }; + SpmmPluginDynamic(const std::string& name, const nvinfer1::DataType precision, + const int out_dim, const nvinfer1::Weights& weight, + const nvinfer1::Weights& bias, Activation activation); + // The second constructor is for clone member function + SpmmPluginDynamic(const std::string& name, const nvinfer1::DataType precision, + const int out_dim, const int k, const void* weight, + size_t compressed_size, const void* bias, + bool is_configured, const int m_max, const int optim_alg, + Activation activation); + SpmmPluginDynamic(const std::string name, const void* data, size_t length); + SpmmPluginDynamic() = delete; + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, int nbOutputs) noexcept override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) noexcept override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const noexcept override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) noexcept override; + + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const noexcept override; + const char* getPluginType() const noexcept override; + const char* getPluginVersion() const noexcept override; + int getNbOutputs() const noexcept override; + int initialize() noexcept override; + void terminate() noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + void setPluginNamespace(const char* pluginNamespace) noexcept override; + const char* getPluginNamespace() const noexcept override; + + private: + struct cusparseLtContext { + cusparseLtHandle_t handle; + cusparseLtMatDescriptor_t matA; + cusparseLtMatDescriptor_t matB; + cusparseLtMatDescriptor_t matC; + cusparseLtMatmulDescriptor_t matmul; + cusparseLtMatmulAlgSelection_t alg_sel; + cusparseLtMatmulPlan_t plan; + cusparseLtContext(); + ~cusparseLtContext(); + size_t workspace_size{0}; + bool is_initialized{false}; + int activation{0}; + float relu_upper_bound{0}; + float relu_threshold{0}; + void init(int m, int n, int k, cudaDataType_t type, void* bias_ptr, + SpmmPluginDynamic::Activation activation); + void setAlgo(int id); + void destroy(); + void compressMatB(int n, int k, cudaDataType_t type, void* src, void** dest, + size_t* compressed_size); + }; // struct SpmmPluginDynamic::cusparseLtContext + const std::string layer_name_; + std::string namespace_; + 
nvinfer1::DataType precision_; + size_t precision_size_; + size_t + element_size_; // size of weight (float if INT8 or FLOAT; half if HALF) + int out_dim_; + int k_; + int m_max_; + bool is_configured_; // already get m, scale bias, and search the optim alg + // or not + int optim_alg_; // the index of optimal algorithm + float weight_scale_; // record the weight scale from constructor + void* weight_compressed_; // host compressed weight + void* weight_compressed_dev_; // device compressed weight + std::shared_ptr + weight_compressed_dev_global_; // shared pointer to the + // device compressed weight + size_t compressed_size_; // size of compressed weight + bool has_bias_; // there is bias or not + void* bias_; // host bias + void* bias_dev_; // device bias + Activation activation_; // record the activation type + cusparseLtContext spmm_context_; +}; // class SpmmPluginDynamic + +class SpmmPluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + SpmmPluginDynamicCreator(); + const char* getPluginName() const noexcept override; + const char* getPluginVersion() const noexcept override; + const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + nvinfer1::IPluginV2* createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) noexcept override; + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) noexcept override; + void setPluginNamespace(const char* pluginNamespace) noexcept override; + const char* getPluginNamespace() const noexcept override; + + private: + static nvinfer1::PluginFieldCollection field_collection_; + static std::vector plugin_attr_; + std::string namespace_; +}; // class SpmmPluginDynamicCreator + +REGISTER_TRT_PLUGIN_V2(SpmmPluginDynamicCreator); +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc new file mode 100644 index 00000000000..4f0d7fb1e9e --- /dev/null +++ b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc @@ -0,0 +1,175 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include + +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/engine.h" +#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000) +#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h" +#endif +#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/common/float16.h" + +using float16 = phi::dtype::float16; +namespace paddle { +namespace inference { +namespace tensorrt { + +class TensorRTDynamicEngineTest : public ::testing::Test { + protected: + void SetUp() override { + ctx_ = new platform::CUDADeviceContext(platform::CUDAPlace(0)); + ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(platform::CUDAPlace(0), ctx_->stream()) + .get()); + ctx_->SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + ctx_->SetZeroAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetZeroAllocator(platform::CUDAPlace(0)) + .get()); + ctx_->SetPinnedAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CUDAPinnedPlace()) + .get()); + ctx_->PartialInitWithAllocator(); + + std::map> min_input_shape = { + {"input", {16, 32, 1, 1}}}; + std::map> max_input_shape = { + {"input", {16, 32, 1, 1}}}; + std::map> optim_input_shape = { + {"input", {16, 32, 1, 1}}}; + + engine_ = + new TensorRTEngine(16, 1 << 10, AnalysisConfig::Precision::kHalf, + nullptr, 0, min_input_shape, max_input_shape, + optim_input_shape, false, NaiveLogger::Global()); + engine_->InitNetwork(); + } + + void TearDown() override { + if (engine_) { + delete engine_; + engine_ = nullptr; + } + } + + void PrepareInputOutput(const std::vector &input, + std::vector output_shape) { + paddle::framework::TensorFromVector(input, *ctx_, &input_); + output_.Resize(phi::make_ddim(output_shape)); + } + + void GetOutput(std::vector *output) { + paddle::framework::TensorToVector(output_, *ctx_, output); + } + + protected: + framework::Tensor input_; + framework::Tensor output_; + TensorRTEngine *engine_; + platform::CUDADeviceContext *ctx_; +}; + +TEST_F(TensorRTDynamicEngineTest, test_spmm) { + // Weight in CPU memory. 
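// The raw_weight filled in below is deliberately 2:4 sparse: every group of
// four consecutive values has exactly two non-zeros ({1, 0, 0, 4} or
// {0, 2, 3, 0}), which is the pattern cuSPARSELt's 50% structured-sparsity
// compression requires, and the expected outputs follow directly (e.g.
// y[0] = 8 * (1 * 1 + 4 * 4) + bias 0 = 136, matching the ASSERTs below).
// A small sketch of a checker for the 2:4 property, assuming a float buffer
// whose length is a multiple of 4:
//
//   bool is_2_4_sparse(const float* w, int64_t n) {
//     for (int64_t g = 0; g + 4 <= n; g += 4) {
//       int nonzero = 0;
//       for (int j = 0; j < 4; ++j) nonzero += (w[g + j] != 0.0f ? 1 : 0);
//       if (nonzero > 2) return false;
//     }
//     return true;
//   }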
+#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000) + float16 raw_weight[512]; + for (int i = 0; i < 128; i++) { + if (i % 16 <= 7) { + raw_weight[4 * i] = float16(1.0); + raw_weight[4 * i + 1] = float16(0.0); + raw_weight[4 * i + 2] = float16(0.0); + raw_weight[4 * i + 3] = float16(4.0); + } else { + raw_weight[4 * i] = float16(0.0); + raw_weight[4 * i + 1] = float16(2.0); + raw_weight[4 * i + 2] = float16(3.0); + raw_weight[4 * i + 3] = float16(0.0); + } + } + float16 raw_bias[16] = {float16(0), float16(1), float16(0), float16(2), + float16(0), float16(3), float16(0), float16(4), + float16(0), float16(5), float16(0), float16(6), + float16(0), float16(7), float16(0), float16(8)}; + std::vector buffers(2); // TRT binded inputs + TensorRTEngine::Weight weight(nvinfer1::DataType::kHALF, raw_weight, 512); + TensorRTEngine::Weight bias(nvinfer1::DataType::kHALF, raw_bias, 16); + std::cout << "with_dynamic_shape: " << engine_->with_dynamic_shape() + << std::endl; + auto *x = engine_->DeclareInput("input", nvinfer1::DataType::kHALF, + nvinfer1::Dims4{-1, 32, 1, 1}); + + plugin::SpmmPluginDynamic::Activation act = + plugin::SpmmPluginDynamic::Activation::kNone; + + plugin::SpmmPluginDynamic *plugin = new plugin::SpmmPluginDynamic( + "CustomSpmmPluginDynamic", nvinfer1::DataType::kHALF, 16, weight.get(), + bias.get(), act); + std::vector plugin_inputs; + plugin_inputs.emplace_back(x); + auto fc_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + + LOG(INFO) << "create weights"; + PADDLE_ENFORCE_NOT_NULL(fc_layer, platform::errors::InvalidArgument( + "TRT SPMM layer building failed.")); + + engine_->DeclareOutput(fc_layer, 0, "y"); + engine_->FreezeNetwork(); + ASSERT_EQ(engine_->engine()->getNbBindings(), 2); + + std::vector x_v(512); + for (int i = 0; i < 128; i++) { + x_v[4 * i] = float16(1.0); + x_v[4 * i + 1] = float16(2.0); + x_v[4 * i + 2] = float16(3.0); + x_v[4 * i + 3] = float16(4.0); + } + + std::vector y_cpu; + PrepareInputOutput(x_v, {16, 16}); + + auto *x_v_gpu_data = input_.mutable_data(ctx_->GetPlace()); + auto *y_gpu_data = output_.mutable_data(ctx_->GetPlace()); + + buffers[0] = reinterpret_cast(x_v_gpu_data); + buffers[1] = reinterpret_cast(y_gpu_data); + + engine_->Execute(16, &buffers, ctx_->stream()); + LOG(INFO) << "to get output"; + GetOutput(&y_cpu); + + auto dims = engine_->GetITensor("y")->getDimensions(); + ASSERT_EQ(dims.nbDims, 4); + ASSERT_EQ(dims.d[1], 16); + ASSERT_EQ(y_cpu[0], 136); + + ASSERT_EQ(y_cpu[1], 105); + ASSERT_EQ(y_cpu[32], 136); + ASSERT_EQ(y_cpu[64], 136); + ASSERT_EQ(y_cpu[96], 136); +#endif + return; +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index bba0ad35e02..fa67961f02c 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -42,6 +42,10 @@ if(TENSORRT_FOUND) list(APPEND CUDA_SRCS tensorrt.cc) endif() +if(CUSPARSELT_FOUND) + list(APPEND CUDA_SRCS cusparseLt.cc) +endif() + configure_file(cupti_lib_path.h.in ${CMAKE_CURRENT_BINARY_DIR}/cupti_lib_path.h) if(CUPTI_FOUND) list(APPEND CUDA_SRCS cupti.cc) diff --git a/paddle/fluid/platform/dynload/cusparseLt.cc b/paddle/fluid/platform/dynload/cusparseLt.cc new file mode 100644 index 00000000000..ae2aec012b7 --- /dev/null +++ b/paddle/fluid/platform/dynload/cusparseLt.cc @@ -0,0 +1,29 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/platform/dynload/cusparseLt.h"
+
+namespace paddle {
+namespace platform {
+namespace dynload {
+
+#define DEFINE_WRAP(__name) DynLoad__##__name __name
+
+#ifdef CUSPARSELT_ROUTINE_EACH
+CUSPARSELT_ROUTINE_EACH(DEFINE_WRAP);
+#endif
+
+}  // namespace dynload
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/dynload/cusparseLt.h b/paddle/fluid/platform/dynload/cusparseLt.h
new file mode 100644
index 00000000000..feb13ec63c1
--- /dev/null
+++ b/paddle/fluid/platform/dynload/cusparseLt.h
@@ -0,0 +1,60 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#pragma once
+
+#include <cuda.h>
+#include <cusparseLt.h>
+
+#include <mutex>  // NOLINT
+
+#include "paddle/phi/backends/dynload/cusparseLt.h"
+
+namespace paddle {
+namespace platform {
+namespace dynload {
+
+#define PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP(__name) \
+  using DynLoad__##__name = phi::dynload::DynLoad__##__name;  \
+  extern DynLoad__##__name __name
+
+#if defined(PADDLE_WITH_CUDA)
+#if CUDA_VERSION >= 11020
+#define CUSPARSELT_ROUTINE_EACH(__macro)       \
+  __macro(cusparseLtInit);                     \
+  __macro(cusparseLtDestroy);                  \
+  __macro(cusparseLtDenseDescriptorInit);      \
+  __macro(cusparseLtStructuredDescriptorInit); \
+  __macro(cusparseLtMatmulDescriptorInit);     \
+  __macro(cusparseLtMatmulDescSetAttribute);   \
+  __macro(cusparseLtMatmulAlgSelectionInit);   \
+  __macro(cusparseLtMatmulAlgSetAttribute);    \
+  __macro(cusparseLtMatmulGetWorkspace);       \
+  __macro(cusparseLtMatmulPlanInit);           \
+  __macro(cusparseLtMatDescriptorDestroy);     \
+  __macro(cusparseLtSpMMACompressedSize2);     \
+  __macro(cusparseLtSpMMACompress2);           \
+  __macro(cusparseLtMatmulSearch);             \
+  __macro(cusparseLtMatmulAlgGetAttribute);    \
+  __macro(cusparseLtMatmulPlanDestroy);        \
+  __macro(cusparseLtMatmul);                   \
+  __macro(cusparseGetErrorString);
+
+CUSPARSELT_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP);
+#endif
+#endif
+
+#undef PLATFORM_DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP
+}  // namespace dynload
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc
index 2f24e1b87da..b64bf81dc0d 100644
--- a/paddle/fluid/platform/dynload/dynamic_loader.cc
+++ b/paddle/fluid/platform/dynload/dynamic_loader.cc
@@ -72,6 +72,10 @@ void* GetCUFFTDsoHandle() { return phi::dynload::GetCUFFTDsoHandle(); }
 
 void* GetMKLRTDsoHandle() { return phi::dynload::GetMKLRTDsoHandle(); }
 
+void* GetCusparseLtDsoHandle() {
+  return phi::dynload::GetCusparseLtDsoHandle();
+}
+
 }  // namespace dynload
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/dynload/dynamic_loader.h b/paddle/fluid/platform/dynload/dynamic_loader.h
index ca60cd76a59..50714dfb302 100644
--- a/paddle/fluid/platform/dynload/dynamic_loader.h
+++ b/paddle/fluid/platform/dynload/dynamic_loader.h
@@ -46,6 +46,7 @@ void* GetNvtxDsoHandle();
 void* GetCUFFTDsoHandle();
 void* GetMKLRTDsoHandle();
 void* GetROCFFTDsoHandle();
+void* GetCusparseLtDsoHandle();
 
 void SetPaddleLibPath(const std::string&);
 }  // namespace dynload
diff --git a/paddle/phi/backends/dynload/CMakeLists.txt b/paddle/phi/backends/dynload/CMakeLists.txt
index 91dbafe0cd3..408d524dca7 100644
--- a/paddle/phi/backends/dynload/CMakeLists.txt
+++ b/paddle/phi/backends/dynload/CMakeLists.txt
@@ -42,6 +42,10 @@ if(TENSORRT_FOUND)
   list(APPEND CUDA_SRCS tensorrt.cc)
 endif()
 
+if(CUSPARSELT_FOUND)
+  list(APPEND CUDA_SRCS cusparseLt.cc)
+endif()
+
 configure_file(cupti_lib_path.h.in ${CMAKE_CURRENT_BINARY_DIR}/cupti_lib_path.h)
 if(CUPTI_FOUND)
   list(APPEND CUDA_SRCS cupti.cc)
diff --git a/paddle/phi/backends/dynload/cusparseLt.cc b/paddle/phi/backends/dynload/cusparseLt.cc
new file mode 100644
index 00000000000..9025a1b82ca
--- /dev/null
+++ b/paddle/phi/backends/dynload/cusparseLt.cc
@@ -0,0 +1,28 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/backends/dynload/cusparseLt.h"
+
+namespace phi {
+namespace dynload {
+
+std::once_flag cusparselt_dso_flag;
+void *cusparselt_dso_handle = nullptr;
+
+#define DEFINE_WRAP(__name) DynLoad__##__name __name
+
+CUSPARSELT_ROUTINE_EACH(DEFINE_WRAP);
+
+}  // namespace dynload
+}  // namespace phi
diff --git a/paddle/phi/backends/dynload/cusparseLt.h b/paddle/phi/backends/dynload/cusparseLt.h
new file mode 100644
index 00000000000..8eecefab5e4
--- /dev/null
+++ b/paddle/phi/backends/dynload/cusparseLt.h
@@ -0,0 +1,78 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+#pragma once
+
+#include <cuda.h>
+#include <cusparseLt.h>
+
+#include <mutex>  // NOLINT
+
+#include "paddle/phi/backends/dynload/dynamic_loader.h"
+#include "paddle/phi/backends/dynload/port.h"
+
+namespace phi {
+namespace dynload {
+
+extern std::once_flag cusparselt_dso_flag;
+extern void *cusparselt_dso_handle;
+
+/**
+ * The following macro definition can generate structs
+ * (for each function) to dynamically load cusparseLt routines
+ * via operator overloading.
+ *
+ * note: default dynamic linked libs
+ */
+#define DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP(__name)                    \
+  struct DynLoad__##__name {                                            \
+    template <typename... Args>                                         \
+    cusparseStatus_t operator()(Args... args) {                         \
+      using cusparseltFunc = decltype(&::__name);                       \
+      std::call_once(cusparselt_dso_flag, []() {                        \
+        cusparselt_dso_handle = phi::dynload::GetCusparseLtDsoHandle(); \
+      });                                                               \
+      static void *p_##__name = dlsym(cusparselt_dso_handle, #__name);  \
+      return reinterpret_cast<cusparseltFunc>(p_##__name)(args...);     \
+    }                                                                   \
+  };                                                                    \
+  extern DynLoad__##__name __name
+#if defined(PADDLE_WITH_CUDA)
+#if CUDA_VERSION >= 11020
+#define CUSPARSELT_ROUTINE_EACH(__macro)       \
+  __macro(cusparseLtInit);                     \
+  __macro(cusparseLtDestroy);                  \
+  __macro(cusparseLtDenseDescriptorInit);      \
+  __macro(cusparseLtStructuredDescriptorInit); \
+  __macro(cusparseLtMatmulDescriptorInit);     \
+  __macro(cusparseLtMatmulDescSetAttribute);   \
+  __macro(cusparseLtMatmulAlgSelectionInit);   \
+  __macro(cusparseLtMatmulAlgSetAttribute);    \
+  __macro(cusparseLtMatmulGetWorkspace);       \
+  __macro(cusparseLtMatmulPlanInit);           \
+  __macro(cusparseLtMatDescriptorDestroy);     \
+  __macro(cusparseLtSpMMACompressedSize2);     \
+  __macro(cusparseLtSpMMACompress2);           \
+  __macro(cusparseLtMatmulSearch);             \
+  __macro(cusparseLtMatmulAlgGetAttribute);    \
+  __macro(cusparseLtMatmulPlanDestroy);        \
+  __macro(cusparseLtMatmul);                   \
+  __macro(cusparseGetErrorString);
+
+CUSPARSELT_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP);
+#endif
+#endif
+
+#undef DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP
+}  // namespace dynload
+}  // namespace phi
diff --git a/paddle/phi/backends/dynload/dynamic_loader.cc b/paddle/phi/backends/dynload/dynamic_loader.cc
index 2f35e22a18f..36a78695959 100644
--- a/paddle/phi/backends/dynload/dynamic_loader.cc
+++ b/paddle/phi/backends/dynload/dynamic_loader.cc
@@ -76,6 +76,8 @@ DEFINE_string(mkl_dir,
 
 DEFINE_string(op_dir, "", "Specify path for loading user-defined op library.");
 
+DEFINE_string(cusparselt_dir, "", "Specify path for loading libcusparseLt.so.");
+
 #ifdef PADDLE_WITH_HIP
 
 DEFINE_string(miopen_dir,
@@ -578,5 +580,18 @@ void* GetMKLRTDsoHandle() {
 #endif
 }
 
+void* GetCusparseLtDsoHandle() {
+// cusparseLt APIs are available since CUDA 11.2
+#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11020
+  return GetDsoHandleFromSearchPath(FLAGS_cusparselt_dir, "libcusparseLt.so");
+#else
+  std::string warning_msg(
+      "Your CUDA version is lower than 11.2, so cusparseLt is not supported. "
+      "If you want to use cusparseLt, please upgrade CUDA and rebuild "
+      "PaddlePaddle.");
+  return nullptr;
+#endif
+}
+
 }  // namespace dynload
 }  // namespace phi
diff --git a/paddle/phi/backends/dynload/dynamic_loader.h b/paddle/phi/backends/dynload/dynamic_loader.h
index 942a635b649..642535fc50c 100644
--- a/paddle/phi/backends/dynload/dynamic_loader.h
+++ b/paddle/phi/backends/dynload/dynamic_loader.h
@@ -45,6 +45,7 @@ void* GetNvtxDsoHandle();
 void* GetCUFFTDsoHandle();
 void* GetMKLRTDsoHandle();
 void* GetROCFFTDsoHandle();
+void* GetCusparseLtDsoHandle();
 
 void SetPaddleLibPath(const std::string&);
 
-- 
GitLab
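Reviewer note: DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP above follows Paddle's usual lazy-loading idiom: the first call to a wrapped symbol opens the DSO under std::call_once, resolves the symbol with dlsym, caches the function pointer in a function-local static, and forwards the call arguments. The standalone C++ sketch below only illustrates that idiom outside the patch; it loads libm.so.6 and resolves cos as stand-ins so it can run without libcusparseLt.so, and none of its names come from this change.

// Standalone sketch of the lazy dlopen/dlsym wrapper pattern used by
// DECLARE_DYNAMIC_LOAD_CUSPARSELT_WRAP. Library and symbol names here
// (libm.so.6 / cos) are illustrative stand-ins, not part of the patch.
#include <dlfcn.h>

#include <iostream>
#include <mutex>

namespace demo {

std::once_flag dso_flag;
void *dso_handle = nullptr;

// Mirrors DynLoad__<name>: open the DSO once, resolve the symbol once,
// then forward every call through the cached function pointer.
#define DECLARE_DYNAMIC_LOAD_WRAP(__name, __ret, ...)                      \
  struct DynLoad__##__name {                                               \
    template <typename... Args>                                            \
    __ret operator()(Args... args) {                                       \
      using FuncType = __ret (*)(__VA_ARGS__);                             \
      std::call_once(dso_flag,                                             \
                     [] { dso_handle = dlopen("libm.so.6", RTLD_LAZY); }); \
      static void *p_##__name = dlsym(dso_handle, #__name);                \
      return reinterpret_cast<FuncType>(p_##__name)(args...);              \
    }                                                                      \
  };                                                                       \
  DynLoad__##__name __name

DECLARE_DYNAMIC_LOAD_WRAP(cos, double, double);

}  // namespace demo

int main() {
  // The first call triggers dlopen/dlsym; later calls reuse the cached pointer.
  std::cout << demo::cos(0.0) << std::endl;  // prints 1
  return 0;
}

Compile on Linux with g++ sketch.cc -ldl (the -ldl flag is optional on glibc 2.34+); error handling is omitted for brevity, exactly as in the wrapper above.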