diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index 27bae7a71ea192ac08e4e87cb7bcdb8b84e29dc8..dc21f25da66034ce13f3ce313d6e95c2f2e01798 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -194,6 +194,7 @@ struct Argument { DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine, bool); DECL_ARGUMENT_FIELD(tensorrt_use_calib_mode, TensorRtUseCalibMode, bool); + DECL_ARGUMENT_FIELD(tensorrt_use_oss, TensorRtUseOSS, bool); DECL_ARGUMENT_FIELD(lite_passes_filter, LitePassesFilter, std::vector); diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index cd8d86d72938417112e17e86e5cc6dd12254a8d1..7017cab5e3ab6e3aa708fcb94a7482eafa48e3ae 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -93,6 +93,7 @@ void IRPassManager::CreatePasses(Argument *argument, bool use_calib_mode = argument->tensorrt_use_calib_mode(); pass->Set("enable_int8", new bool(enable_int8)); pass->Set("use_calib_mode", new bool(use_calib_mode)); + pass->Set("use_oss", new bool(argument->tensorrt_use_oss())); pass->Set("precision_mode", new AnalysisConfig::Precision(precision_mode)); diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 7ef072277fb7f1f13c14b38d64cea6d1f4584b76..53e7ba0a9a9036c0b75714cb26630bec6af63e00 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -114,11 +114,20 @@ void TensorRtSubgraphPass::CreateTensorRTOp( block_desc.Proto()->set_idx(0); LOG(INFO) << "--- detect a sub-graph with " << subgraph.size() << " nodes"; + bool has_fused_embedding_eltwise_layernorm = false; + bool has_multihead_matmul = false; for (auto *node 
: subgraph) { auto *new_block_op = new_block->AppendOp(); auto *op = block_desc.AppendOp(); *new_block_op->Proto() = *node->Op()->Proto(); *op->Proto() = *node->Op()->Proto(); + if (!has_fused_embedding_eltwise_layernorm && + op->Type() == "fused_embedding_eltwise_layernorm") { + has_fused_embedding_eltwise_layernorm = true; + } + if (!has_multihead_matmul && op->Type() == "multihead_matmul") { + has_multihead_matmul = true; + } } // Then, we will use the input_names_with_id and output_names_with_id to @@ -300,6 +309,9 @@ void TensorRtSubgraphPass::CreateTensorRTOp( precision_mode, calibrator.get(), Get("gpu_device_id"), min_input_shape, max_input_shape, opt_input_shape, disable_trt_plugin_fp16); + trt_engine->SetUseOSS(Get("use_oss")); + trt_engine->SetWithErnie(has_multihead_matmul && + has_fused_embedding_eltwise_layernorm); bool need_serialize = (use_static_engine && !load_from_memory); if (need_serialize) { diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 39c5cbff1f4513026e23ea81e6e56806f7c84332..9fd312de7e2b82e7deb3d454a13745b48e84d29c 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -122,6 +122,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(tensorrt_precision_mode_); CP_MEMBER(trt_use_static_engine_); CP_MEMBER(trt_use_calib_mode_); + CP_MEMBER(trt_use_oss_); // MKLDNN related. CP_MEMBER(use_mkldnn_); CP_MEMBER(mkldnn_enabled_op_types_); @@ -258,6 +259,8 @@ void AnalysisConfig::SetTRTDynamicShapeInfo( disable_trt_plugin_fp16_ = disable_trt_plugin_fp16; } +void AnalysisConfig::EnableTensorRtOSS() { trt_use_oss_ = true; } + // TODO(Superjomn) refactor this, buggy. 
void AnalysisConfig::Update() { auto info = SerializeInfoCache(); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index dc927576af12706a22544eed7f3436747dfe4486..12855d706c2d59aafee205ae3bde79022f908a97 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -437,6 +437,7 @@ void AnalysisPredictor::PrepareArgument() { argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_); argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_); argument_.SetTensorRtUseCalibMode(config_.trt_use_calib_mode_); + argument_.SetTensorRtUseOSS(config_.trt_use_oss_); argument_.SetMinInputShape(config_.min_input_shape_); argument_.SetMaxInputShape(config_.max_input_shape_); argument_.SetOptimInputShape(config_.optim_input_shape_); @@ -953,7 +954,7 @@ USE_TRT_CONVERTER(elementwise_mul_tensor); USE_TRT_CONVERTER(elementwise_max_tensor); USE_TRT_CONVERTER(elementwise_min_tensor); USE_TRT_CONVERTER(elementwise_pow_tensor); -USE_TRT_CONVERTER(mul); +USE_TRT_CONVERTER(matmul); USE_TRT_CONVERTER(conv2d); USE_TRT_CONVERTER(relu); USE_TRT_CONVERTER(sigmoid); diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 39346414a8a0d62903f56280c638ca89eac833b0..e1f787490f937c9b41acff25d50d3fc5a76d0a75 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -312,6 +312,22 @@ struct AnalysisConfig { std::map> max_input_shape, std::map> optim_input_shape, bool disable_trt_plugin_fp16 = false); + + /// + /// \brief Replace some TensorRT plugins with their TensorRT OSS + /// (https://github.com/NVIDIA/TensorRT) counterparts, which may make + /// inference faster for some models. + /// Requires a libnvinfer_plugin.so + /// newer than V7.2.1.
+ /// + void EnableTensorRtOSS(); + /// + /// \brief A boolean state telling whether to use the TensorRT OSS. + /// + /// \return bool Whether to use the TensorRT OSS. + /// + bool tensorrt_oss_enabled() const { return trt_use_oss_; } + /// /// \brief Turn on the usage of Lite sub-graph engine. /// @@ -531,6 +547,7 @@ struct AnalysisConfig { Precision tensorrt_precision_mode_{Precision::kFloat32}; bool trt_use_static_engine_{false}; bool trt_use_calib_mode_{true}; + bool trt_use_oss_{false}; std::map> min_input_shape_{}; std::map> max_input_shape_{}; std::map> optim_input_shape_{}; diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 39d02909abd1f1d96f73cc9f3e3ea9d26a1f5c72..e20d017cdf9d61e4d5e9c26ee2cfd30c15df95dd 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -1,6 +1,6 @@ # Add TRT tests nv_library(tensorrt_converter - SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc + SRCS matmul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc index 60670b41aba2126bfb359b680089f510e51025d8..05ff62e8f7e7aa4fbad8aa59ac304e93a9c06f91 100644 --- a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc @@ -40,6 +40,9 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { input_ids.push_back(engine_->GetITensor(id_names[i])); } + // input_embs[0]: word_embedding + // input_embs[1]: pos_embedding + // input_embs[2]: sent_embedding
std::vector input_embs; std::vector emb_sizes; @@ -76,15 +79,97 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { get_persistable_data(op_desc.Input("Scale").front(), &scale_dims); int64_t bias_size = framework::product(bias_dims); int64_t scale_size = framework::product(scale_dims); - float eps = boost::get(op_desc.GetAttr("epsilon")); nvinfer1::ILayer* layer = nullptr; if (engine_->with_dynamic_shape()) { - auto use_fp16 = engine_->WithFp16(); - auto plugin = new plugin::EmbEltwiseLayernormPluginDynamic( - input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden, - eps, use_fp16); - layer = engine_->AddPluginV2(input_ids.data(), input_num, plugin); + if (engine_->use_oss()) { + int output_fp16 = static_cast((engine_->WithFp16() == 1) ? 1 : 0); + PADDLE_ENFORCE_EQ( + output_fp16, 1, + platform::errors::InvalidArgument( + "Only Precision::KHalf(fp16) is supported when infering " + "ernie(bert) model with config.EnableTensorRtOSS(). " + "But Precision::KFloat32 is setted.")); + const std::vector fields{ + {"bert_embeddings_layernorm_beta", bias, + nvinfer1::PluginFieldType::kFLOAT32, + static_cast(bias_size)}, + {"bert_embeddings_layernorm_gamma", scale, + nvinfer1::PluginFieldType::kFLOAT32, + static_cast(scale_size)}, + {"bert_embeddings_word_embeddings", input_embs[0], + nvinfer1::PluginFieldType::kFLOAT32, + static_cast(emb_sizes[0])}, + {"bert_embeddings_token_type_embeddings", input_embs[2], + nvinfer1::PluginFieldType::kFLOAT32, + static_cast(emb_sizes[2])}, + {"bert_embeddings_position_embeddings", input_embs[1], + nvinfer1::PluginFieldType::kFLOAT32, + static_cast(emb_sizes[1])}, + {"output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1}, + }; + + // Resolve the creator before allocating below, so that a missing plugin + // is reported as an error instead of a null-pointer dereference. + auto creator = GetPluginRegistry()->getPluginCreator( + "CustomEmbLayerNormPluginDynamic", "2"); + PADDLE_ENFORCE_NOT_NULL( + creator, + platform::errors::NotFound( + "TensorRT plugin CustomEmbLayerNormPluginDynamic (version 2) is " + "not registered, please check libnvinfer_plugin.so.")); + + // remember to free + nvinfer1::PluginFieldCollection* plugin_ptr = + static_cast( + malloc(sizeof(*plugin_ptr) + + fields.size() * sizeof(nvinfer1::PluginField))); + plugin_ptr->nbFields = static_cast(fields.size()); + plugin_ptr->fields = fields.data(); + + std::vector 
plugin_inputs; + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->network()->getInput(0)->getName())); // word_embedding, + // eval_placeholder_0 + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->network()->getInput(1)->getName())); // sent_embedding, + // eval_placeholder_1 + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->network()->getInput(2)->getName())); // cu_seqlens, + // eval_placeholder_2 + auto max_seqlen_tensor = + engine_->GetITensor(engine_->network()->getInput(3)->getName()); + auto* shuffle_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor); + nvinfer1::Dims shape_dim; + shape_dim.nbDims = 1; + shape_dim.d[0] = -1; + shuffle_layer->setReshapeDimensions(shape_dim); + plugin_inputs.emplace_back( + shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3 + + auto plugin_obj = creator->createPlugin( + "CustomEmbLayerNormPluginDynamic", plugin_ptr); + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin_obj); + layer = plugin_layer; + free(plugin_ptr); + auto output_name = op_desc.Output("Out")[0]; + RreplenishLayerAndOutput(layer, "emb_eltwise_layernorm", + {output_name, std::string("qkv_plugin_mask")}, + test_mode); + } else { + bool use_fp16 = engine_->WithFp16(); + float eps = boost::get(op_desc.GetAttr("epsilon")); + plugin::DynamicPluginTensorRT* plugin = nullptr; + plugin = new plugin::EmbEltwiseLayernormPluginDynamic( + input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden, + eps, use_fp16); + layer = engine_->AddPluginV2(input_ids.data(), input_num, plugin); + auto output_name = op_desc.Output("Out")[0]; + RreplenishLayerAndOutput(layer, "emb_eltwise_layernorm", {output_name}, + test_mode); + } } else { PADDLE_THROW(platform::errors::Fatal( "You are running the Ernie(Bert) model in static" @@ -93,9 +178,6 @@ class 
EmbEltwiseLayerNormOpConverter : public OpConverter { " to set the shape information to run the dynamic shape mode.")); } - auto output_name = op_desc.Output("Out")[0]; - RreplenishLayerAndOutput(layer, "emb_eltwise_layernorm", {output_name}, - test_mode); #else PADDLE_THROW(platform::errors::Fatal( "You are running the TRT Dynamic Shape mode, need to confirm that " diff --git a/paddle/fluid/inference/tensorrt/convert/matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/matmul_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..737573dc55193034e2ff09021066cae2fa1c6aff --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/matmul_op.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace framework { +class Scope; +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * MatMulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights. 
+ */ +class MatMulOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a fluid matmul op to tensorrt mul layer without bias"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); + auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]); + + bool transpose_X = boost::get(op_desc.GetAttr("transpose_X")); + bool transpose_Y = boost::get(op_desc.GetAttr("transpose_Y")); + + auto* layer = TRT_ENGINE_ADD_LAYER(engine_, MatrixMultiply, *input1, + transpose_X, *input2, transpose_Y); + + float alpha = boost::get(op_desc.GetAttr("alpha")); + auto output_name = op_desc.Output("Out")[0]; + if (fabs(alpha - 1.0) < std::numeric_limits::epsilon()) { + engine_->SetITensor(output_name, layer->getOutput(0)); + } else { + auto create_weights = [&](float data, const std::string& type) -> float* { + std::unique_ptr tmp_tensor(new framework::Tensor()); + tmp_tensor->Resize({1}); + auto* tmp_data = tmp_tensor->mutable_data(platform::CPUPlace()); + tmp_data[0] = data; + engine_->SetWeights(output_name + "_add_scale_op_" + type, + std::move(tmp_tensor)); + return tmp_data; + }; + float* alpha_data = create_weights(alpha, "alpha"); + float* shift_data = create_weights(0.0, "shift"); + float* power_data = create_weights(1.0, "power"); + TensorRTEngine::Weight nv_alpha{nvinfer1::DataType::kFLOAT, + static_cast(alpha_data), 1}; + TensorRTEngine::Weight nv_shift{nvinfer1::DataType::kFLOAT, + static_cast(shift_data), 1}; + TensorRTEngine::Weight nv_power{nvinfer1::DataType::kFLOAT, + static_cast(power_data), 1}; + auto* scale_layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *layer->getOutput(0), nvinfer1::ScaleMode::kUNIFORM, + nv_shift.get(), nv_alpha.get(), nv_power.get()); + engine_->SetITensor(output_name, scale_layer->getOutput(0)); + } + if (test_mode) { // the test framework can not 
determine which is the + // output, so place the declaration inside. + engine_->DeclareOutput(output_name); + } + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(matmul, MatMulOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/mul_op.cc b/paddle/fluid/inference/tensorrt/convert/mul_op.cc deleted file mode 100644 index 5b6aaad49833cedbd8d1ee0ec5d24c7f983190e6..0000000000000000000000000000000000000000 --- a/paddle/fluid/inference/tensorrt/convert/mul_op.cc +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" - -namespace paddle { -namespace inference { -namespace tensorrt { - -/* - * MulOp, IMatrixMultiplyLayer in TRT. This Layer doesn't has weights. - */ -class MulOpConverter : public OpConverter { - public: - void operator()(const framework::proto::OpDesc& op, - const framework::Scope& scope, bool test_mode) override { - VLOG(3) << "convert a fluid mul op to tensorrt mul layer without bias"; - - framework::OpDesc op_desc(op, nullptr); - // Declare inputs - auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); - auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]); - // Both the input1 and input2 do not need transpose. 
- auto* layer = TRT_ENGINE_ADD_LAYER( - engine_, MatrixMultiply, *const_cast(input1), false, - *const_cast(input2), false); - - auto output_name = op_desc.Output("Out")[0]; - engine_->SetITensor(output_name, layer->getOutput(0)); - if (test_mode) { // the test framework can not determine which is the - // output, so place the declaration inside. - engine_->DeclareOutput(output_name); - } - } -}; - -} // namespace tensorrt -} // namespace inference -} // namespace paddle - -REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index 91c0c56ec580c7900f84580fa760ab8e906deb74..820be425e844901130f2ada92c3e47845fc7afbd 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -30,7 +30,6 @@ class MultiheadMatMulOpConverter : public OpConverter { // Declare inputs // Shouble be a 5 dims tensor. 
auto* input = engine_->GetITensor(op_desc.Input("Input").front()); - auto* input_bias_qk = engine_->GetITensor(op_desc.Input("BiasQK").front()); // fc weights and fc bias auto weight_name = op_desc.Input("W").front(); @@ -50,7 +49,7 @@ class MultiheadMatMulOpConverter : public OpConverter { memcpy(weight_data_tmp.data(), weight_data, weight_t->numel() * sizeof(float)); - // (hidden, 3, all_head_size) + // (hidden, 3, all_head_size) auto weight_dims = weight_t->dims(); int hidden = weight_dims[0]; // channels_in @@ -65,36 +64,146 @@ class MultiheadMatMulOpConverter : public OpConverter { } } }; - - // transpose weight_data from m * n to n * m tranpose_weight(weight_data_tmp.data(), weight_data, m, n); - TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, - static_cast(weight_data), - static_cast(weight_t->numel())}; - - weight.dims.assign({n, m}); - TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, - static_cast(bias_data), - static_cast(bias_t->numel())}; - - auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n, - weight.get(), bias.get()); - auto* fc_out = fc_layer->getOutput(0); - // add qkv to context + int head_number = boost::get(op_desc.GetAttr("head_number")); - int head_size = all_head_size / head_number; - float scale = boost::get(op_desc.GetAttr("alpha")); - std::vector plugin_inputs; - plugin_inputs.push_back(fc_out); - plugin_inputs.push_back(input_bias_qk); nvinfer1::ILayer* layer = nullptr; + if (engine_->with_dynamic_shape()) { - bool ban_fp16 = engine_->disable_trt_plugin_fp16(); - plugin::DynamicPluginTensorRT* plugin = - new plugin::QkvToContextPluginDynamic(hidden, head_number, head_size, - scale, ban_fp16); - layer = engine_->AddPluginV2(plugin_inputs.data(), 2, plugin); + if (engine_->use_oss()) { + int head_size = hidden / head_number; + // [3, Nout, Hout, Nin, Hin] -> [Nout, 3, Hout, Nin, Hin] + auto transpose_weight_v2 = [](const float* src, float* dst, int N, + int H) { + const int HNH = H * N * H; + for (int 
i = 0; i < 3; ++i) { + for (int n = 0; n < N; ++n) { + for (int hnh = 0; hnh < HNH; ++hnh) { + dst[n * 3 * HNH + i * HNH + hnh] = + src[i * N * HNH + n * HNH + hnh]; + } + } + } + }; + // [3, N, H] -> [N, 3, H] + auto transpose_bias_v2 = [](const float* src, float* dst, int N, + int H) { + for (int i = 0; i < 3; ++i) { + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H; ++h) { + dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h]; + } + } + } + }; + memcpy(weight_data_tmp.data(), weight_data, + weight_t->numel() * sizeof(float)); + transpose_weight_v2(weight_data_tmp.data(), weight_data, head_number, + head_size); + nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, + static_cast(weight_data), + static_cast(weight_t->numel())}; + + // resize(), not reserve(): the memcpy below writes numel() floats, which + // is only defined behaviour inside the vector's size. + std::vector bias_data_tmp; + bias_data_tmp.resize(bias_t->numel()); + memcpy(bias_data_tmp.data(), bias_data, + bias_t->numel() * sizeof(float)); + transpose_bias_v2(bias_data_tmp.data(), bias_data, head_number, + head_size); + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, + static_cast(bias_data), + static_cast(bias_t->numel())}; + + auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, + n, weight, bias); + + auto mask_tensor = engine_->GetITensor("qkv_plugin_mask"); + + auto creator = GetPluginRegistry()->getPluginCreator( + "CustomQKVToContextPluginDynamic", "2"); + // assert() is compiled out under NDEBUG; report a missing plugin instead. + PADDLE_ENFORCE_NOT_NULL( + creator, + platform::errors::NotFound( + "TensorRT plugin CustomQKVToContextPluginDynamic (version 2) is " + "not registered, please check libnvinfer_plugin.so.")); + int type = static_cast((engine_->WithFp16() == 1) + ? 
nvinfer1::DataType::kHALF + : nvinfer1::DataType::kFLOAT); + bool has_mask = true; + int var_seqlen = 1; + const std::vector fields{ + {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1}, + {"hidden_size", &hidden, nvinfer1::PluginFieldType::kINT32, 1}, + {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1}, + {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1}, + {"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1}, + }; + nvinfer1::PluginFieldCollection* plugin_collection = + static_cast( + malloc(sizeof(*plugin_collection) + + fields.size() * + sizeof(nvinfer1::PluginField))); // remember to free + plugin_collection->nbFields = static_cast(fields.size()); + plugin_collection->fields = fields.data(); + + auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic", + plugin_collection); + free(plugin_collection); + + std::vector plugin_inputs; + plugin_inputs.emplace_back(fc_layer->getOutput(0)); + plugin_inputs.emplace_back(mask_tensor); + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->network()->getInput(2)->getName())); // cu_seqlens, + // eval_placeholder_2 + auto max_seqlen_tensor = + engine_->GetITensor(engine_->network()->getInput(3)->getName()); + auto* shuffle_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor); + nvinfer1::Dims shape_dim; + shape_dim.nbDims = 1; + shape_dim.d[0] = -1; + shuffle_layer->setReshapeDimensions(shape_dim); + plugin_inputs.emplace_back( + shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3 + + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin); + layer = plugin_layer; + } else { + // transpose weight_data from m * n to n * m + auto* input_bias_qk = + engine_->GetITensor(op_desc.Input("BiasQK").front()); + + TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, + static_cast(weight_data), + static_cast(weight_t->numel())}; + weight.dims.assign({n, m}); + + TensorRTEngine::Weight 
bias{nvinfer1::DataType::kFLOAT, + static_cast(bias_data), + static_cast(bias_t->numel())}; + + auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, + n, weight.get(), bias.get()); + auto* fc_out = fc_layer->getOutput(0); + // add qkv to context + int head_size = all_head_size / head_number; + float scale = boost::get(op_desc.GetAttr("alpha")); + + std::vector plugin_inputs; + plugin_inputs.push_back(fc_out); + plugin_inputs.push_back(input_bias_qk); + bool ban_fp16 = engine_->disable_trt_plugin_fp16(); + plugin::DynamicPluginTensorRT* plugin = + new plugin::QkvToContextPluginDynamic(hidden, head_number, + head_size, scale, ban_fp16); + layer = engine_->AddPluginV2(plugin_inputs.data(), 2, plugin); + } } else { PADDLE_THROW(platform::errors::Fatal( "You are running the Ernie(Bert) model in static shape mode, which " diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 181190f1dce28a9189a0f4d067b7084780134914..f9f4f43a92cfcfa2f209a431b6eb47a0291a333b 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -145,6 +145,7 @@ class OpConverter { const std::unordered_set& parameters, const std::vector& outputs, TensorRTEngine* engine) { engine->InitNetwork(); + bool all_dynamic_shape_set = true; for (auto& input : inputs) { if (parameters.count(input)) continue; auto* var = block_desc->FindVar(input); @@ -158,6 +159,13 @@ class OpConverter { auto max_input_shape = engine->max_input_shape()[input]; auto optim_input_shape = engine->optim_input_shape()[input]; size_t ranks = min_input_shape.size(); + if (ranks == 0) { + all_dynamic_shape_set = false; + LOG(INFO) << "trt input [" << input.c_str() + << "] dynamic shape info not set, please check and retry."; + // check other input + continue; + } std::vector input_shape; input_shape.push_back(-1); for (size_t i = 1; i < ranks; i++) { @@ -184,6 +192,10 @@ 
class OpConverter { Vec2TRT_Dims(var_shape, input)); } } + PADDLE_ENFORCE_EQ(all_dynamic_shape_set, true, + platform::errors::InvalidArgument( + "some trt inputs dynamic shape info not set, " + "check the INFO log above for more details.")); framework::proto::BlockDesc* block_proto = block_desc->Proto(); ConvertBlock(*block_proto, parameters, scope, engine); for (auto& output : outputs) { diff --git a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc index 9d416b8cc5cd342594032957431f951b813ae46d..802e979045c884afcc2ec56067679b79992d804e 100644 --- a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc @@ -47,17 +47,62 @@ class SkipLayerNormOpConverter : public OpConverter { framework::DDim bias_dims, scale_dims; auto* bias = get_persistable_data("Bias", &bias_dims); auto* scale = get_persistable_data("Scale", &scale_dims); - float eps = boost::get(op_desc.GetAttr("epsilon")); int bias_size = framework::product(bias_dims); int scale_size = framework::product(scale_dims); nvinfer1::ILayer* layer = nullptr; if (engine_->with_dynamic_shape()) { - bool ban_fp16 = engine_->disable_trt_plugin_fp16(); - plugin::SkipLayerNormPluginDynamic* plugin = - new plugin::SkipLayerNormPluginDynamic(bias, scale, bias_size, - scale_size, eps, ban_fp16); - layer = engine_->AddPluginV2(inputs.data(), 2, plugin); + if (engine_->use_oss()) { + auto creator = GetPluginRegistry()->getPluginCreator( + "CustomSkipLayerNormPluginDynamic", "2"); + // assert() is compiled out under NDEBUG; report a missing plugin instead. + PADDLE_ENFORCE_NOT_NULL( + creator, + platform::errors::NotFound( + "TensorRT plugin CustomSkipLayerNormPluginDynamic (version 2) is " + "not registered, please check libnvinfer_plugin.so.")); + int type = static_cast((engine_->WithFp16() == 1) + ? 
nvinfer1::DataType::kHALF + : nvinfer1::DataType::kFLOAT); + int ld = input1->getDimensions().d[2]; // hidden dimension + PADDLE_ENFORCE_GT( + ld, 0, + platform::errors::InvalidArgument( + "The hidden dimension of skip_layernorm's input must be " + "positive, but received %d.", + ld)); + + const std::vector fields{ + {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1}, + {"ld", &ld, nvinfer1::PluginFieldType::kINT32, 1}, + {"beta", bias, nvinfer1::PluginFieldType::kFLOAT32, bias_size}, + {"gamma", scale, nvinfer1::PluginFieldType::kFLOAT32, scale_size}, + }; + nvinfer1::PluginFieldCollection* pluginPtr = + static_cast( + malloc(sizeof(*pluginPtr) + + fields.size() * + sizeof(nvinfer1::PluginField))); // remember to free + pluginPtr->nbFields = static_cast(fields.size()); + pluginPtr->fields = fields.data(); + + auto pluginObj = creator->createPlugin( + "CustomSkipLayerNormPluginDynamic", pluginPtr); + auto plugin_layer = engine_->network()->addPluginV2( + inputs.data(), inputs.size(), *pluginObj); + // Release the field collection malloc'ed above. + free(pluginPtr); + + assert(plugin_layer != nullptr); + layer = plugin_layer; + } else { + float eps = boost::get(op_desc.GetAttr("epsilon")); + bool ban_fp16 = engine_->disable_trt_plugin_fp16(); + plugin::SkipLayerNormPluginDynamic* plugin = + new plugin::SkipLayerNormPluginDynamic(bias, scale, bias_size, + scale_size, eps, ban_fp16); + layer = engine_->AddPluginV2(inputs.data(), 2, plugin); + } } else { PADDLE_THROW(platform::errors::Fatal( "You are running the Ernie(Bert) model in static" diff --git a/paddle/fluid/inference/tensorrt/convert/slice_op.cc b/paddle/fluid/inference/tensorrt/convert/slice_op.cc index dee4439e7d166a1ccb61539e47eefdbd6c2846cd..16e070c754ff995d4b06bf4124d196671e01ab73 100644 --- a/paddle/fluid/inference/tensorrt/convert/slice_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/slice_op.cc @@ -14,6 +14,7 @@ limitations under the License. 
*/ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h" namespace paddle { namespace inference { @@ -77,16 +78,31 @@ class SliceOpConverter : public OpConverter { nvinfer1::ILayer* layer = nullptr; if (engine_->with_dynamic_shape()) { + if (engine_->use_oss() && engine_->with_ernie()) { + std::vector plugin_inputs; + // plugin_inputs.emplace_back(trans_layer->getOutput(0)); + plugin_inputs.emplace_back(input); + plugin_inputs.emplace_back(engine_->GetITensor( + engine_->network()->getInput(2)->getName())); // cu_seqlens, + // eval_placeholder_2 + + // bool ban_fp16 = engine_->disable_trt_plugin_fp16(); + plugin::SpecialSlicePluginDynamic* plugin = + new plugin::SpecialSlicePluginDynamic(); + layer = engine_->AddPluginV2(plugin_inputs.data(), plugin_inputs.size(), + plugin); + } else { #if IS_TRT_VERSION_GE(6000) - bool ban_fp16 = engine_->disable_trt_plugin_fp16(); - plugin::SlicePluginDynamic* plugin = - new plugin::SlicePluginDynamic(starts, ends, axes, ban_fp16); - layer = engine_->AddPluginV2(&input, 1, plugin); + bool ban_fp16 = engine_->disable_trt_plugin_fp16(); + plugin::SlicePluginDynamic* plugin = + new plugin::SlicePluginDynamic(starts, ends, axes, ban_fp16); + layer = engine_->AddPluginV2(&input, 1, plugin); #else - PADDLE_THROW(platform::errors::Fatal( - "You are running the TRT Dynamic Shape mode, need to confirm that " - "your TRT version is no less than 6.0")); + PADDLE_THROW(platform::errors::Fatal( + "You are running the TRT Dynamic Shape mode, need to confirm that " + "your TRT version is no less than 6.0")); #endif + } } else { bool ban_fp16 = engine_->disable_trt_plugin_fp16(); plugin::SlicePlugin* plugin = diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index fdd71b0d884004c84e2ee15eea522c64ff943dd9..2c80f29476b0085c4bb4600f52fbf3b2d90fdc78 
100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -60,9 +60,9 @@ TRT_DT FluidDataType2TRT(FluidDT type) { template nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape, std::string input, bool with_dynamic_shape = false) { - PADDLE_ENFORCE_GT(shape.size(), 1UL, + PADDLE_ENFORCE_GT(shape.size(), 0UL, platform::errors::InvalidArgument( - "TensorRT's tensor input requires at least 2 " + "TensorRT's tensor input requires at least 1 " "dimensions, but input %s has %d dims.", input, shape.size())); PADDLE_ENFORCE_LE(shape.size(), 4UL, @@ -83,7 +83,12 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape, std::string input, } else if (shape.size() == 3UL) { return nvinfer1::Dims3(shape[0], shape[1], shape[2]); } - return nvinfer1::Dims4(shape[0], shape[1], 1, 1); + nvinfer1::Dims dims; + dims.nbDims = shape.size(); + for (size_t i = 0; i < shape.size(); i++) { + dims.d[i] = shape[i]; + } + return dims; } } } // NOLINT @@ -157,6 +162,7 @@ class TensorRTEngine { "version should be at least 6."; #endif } + dy::initLibNvInferPlugins(&logger, ""); } ~TensorRTEngine() {} @@ -260,6 +266,9 @@ class TensorRTEngine { suffix_counter += 1; } + void SetUseOSS(bool use_oss) { use_oss_ = use_oss; } + void SetWithErnie(bool with_ernie) { with_ernie_ = with_ernie; } + void ClearWeights() { for (auto& weight_pair : weight_map) { weight_pair.second.reset(nullptr); @@ -287,6 +296,8 @@ class TensorRTEngine { ShapeMapType min_input_shape() { return min_input_shape_; } ShapeMapType max_input_shape() { return max_input_shape_; } ShapeMapType optim_input_shape() { return optim_input_shape_; } + bool use_oss() { return use_oss_; } + bool with_ernie() { return with_ernie_; } bool disable_trt_plugin_fp16() { return disable_trt_plugin_fp16_; } bool with_dynamic_shape() { return with_dynamic_shape_; } @@ -322,6 +333,8 @@ class TensorRTEngine { ShapeMapType max_input_shape_; ShapeMapType optim_input_shape_; bool disable_trt_plugin_fp16_{false}; + 
bool use_oss_{false}; + bool with_ernie_{false}; nvinfer1::ILogger& logger_; // max data size for the buffers. diff --git a/paddle/fluid/inference/tensorrt/helper.h b/paddle/fluid/inference/tensorrt/helper.h index 55a57caf9a0d6eb44399ceb8064b613afb955d47..971f99e69197226bb7d7b26135f0b667f8ebdf30 100644 --- a/paddle/fluid/inference/tensorrt/helper.h +++ b/paddle/fluid/inference/tensorrt/helper.h @@ -56,9 +56,11 @@ static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger* logger) { return static_cast( dy::createInferRuntime_INTERNAL(logger, NV_TENSORRT_VERSION)); } -static nvinfer1::IPluginRegistry* getPluginRegistry() { +#if IS_TRT_VERSION_GE(6000) +static nvinfer1::IPluginRegistry* GetPluginRegistry() { return static_cast(dy::getPluginRegistry()); } +#endif // A logger for create TensorRT infer builder. class NaiveLogger : public nvinfer1::ILogger { diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index d5e15616df79d86d6481119ac0854899675de14f..ba130ec2380d4ef2c81fbdc0c5f0d0823b0f653c 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -13,6 +13,8 @@ // limitations under the License. 
#include "paddle/fluid/inference/tensorrt/op_teller.h" +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/var_desc.h" namespace paddle { namespace inference { @@ -61,6 +63,7 @@ struct SimpleOpTypeSetTeller : public Teller { "conv2d_transpose"}; std::unordered_set teller_set{ "mul", + "matmul", "conv2d", "pool2d", "relu", @@ -106,6 +109,21 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc, boost::get>(desc.GetAttr("paddings")); if (paddings.size() > 2) return false; } + if (op_type == "matmul") { + auto* block = desc.Block(); + for (auto& param_name : desc.Inputs()) { + for (auto& var_name : param_name.second) { + auto* var_desc = block->FindVar(var_name); + const auto shape = var_desc->GetShape(); + if (shape.size() < 3) { + VLOG(1) + << "matmul op dims < 3 not supported in tensorrt, but got dims " + << shape.size() << ", so jump it."; + return false; + } + } + } + } if ((*teller)(op_type, desc, use_no_calib_int8)) return true; } return false; diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 98afdbe254a4b0a086d4a4aa88096a06c40138d1..e37beb3b8e5c3680eda481009699091dcc1ee7a3 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -4,5 +4,5 @@ nv_library(tensorrt_plugin pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu - hard_swish_op_plugin.cu stack_op_plugin.cu + hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu index f1e11b6fba1f1556e2a8a2aaaca1223aaef76b03..860f1039d5e10290d84d1761bc7337e49fa210eb 100644 
--- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.cu @@ -80,6 +80,12 @@ int PReluPlugin::enqueue(int batch_size, const void *const *inputs, #if IS_TRT_VERSION_GE(6000) +void PReluPluginDynamic::terminate() { + if (p_gpu_weight_) { + cudaFree(p_gpu_weight_); + } +} + int PReluPluginDynamic::initialize() { cudaMalloc(&p_gpu_weight_, sizeof(float) * weight_.size()); cudaMemcpy(p_gpu_weight_, weight_.data(), weight_.size() * sizeof(float), diff --git a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h index 4756ca2e0225795edc3bd3112b21e3b628ad5c0b..3126366c5fdd8bb69a78cea11f5778c45de738ec 100644 --- a/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h @@ -102,12 +102,15 @@ class PReluPluginDynamic : public DynamicPluginTensorRT { } ~PReluPluginDynamic() { cudaFree(p_gpu_weight_); } nvinfer1::IPluginV2DynamicExt* clone() const override { - return new PReluPluginDynamic(weight_.data(), weight_.size(), mode_); + auto ptr = new PReluPluginDynamic(weight_.data(), weight_.size(), mode_); + ptr->p_gpu_weight_ = p_gpu_weight_; + return ptr; } const char* getPluginType() const override { return "prelu_plugin"; } int getNbOutputs() const override { return 1; } int initialize() override; + void terminate() override; size_t getSerializationSize() const override; void serialize(void* buffer) const override; diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h index 8fe1edc4bf0321b054322a27f0c16819bc023ed8..24cd8e0368182ae597e48765bc0167ca1eca6bd3 100644 --- a/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h @@ -51,8 +51,11 @@ class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT { } 
nvinfer1::IPluginV2DynamicExt* clone() const override { - return new SkipLayerNormPluginDynamic( + auto ptr = new SkipLayerNormPluginDynamic( bias_.data(), scale_.data(), bias_size_, scale_size_, eps_, ban_fp16_); + ptr->bias_gpu_ = bias_gpu_; + ptr->scale_gpu_ = bias_gpu_; + return ptr; } const char* getPluginType() const override { return "skip_layernorm_plugin"; } diff --git a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu new file mode 100644 index 0000000000000000000000000000000000000000..ed0a530439f0a1f9ea6c45810725de56a88a8411 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu @@ -0,0 +1,177 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +#if IS_TRT_VERSION_GE(6000) +SpecialSlicePluginDynamic::SpecialSlicePluginDynamic() {} + +SpecialSlicePluginDynamic::SpecialSlicePluginDynamic(void const* serial_data, + size_t serial_length) {} + +SpecialSlicePluginDynamic::~SpecialSlicePluginDynamic() {} + +nvinfer1::IPluginV2DynamicExt* SpecialSlicePluginDynamic::clone() const { + return new SpecialSlicePluginDynamic(); +} + +const char* SpecialSlicePluginDynamic::getPluginType() const { + return "special_slice_plugin"; +} + +int SpecialSlicePluginDynamic::getNbOutputs() const { return 1; } + +int SpecialSlicePluginDynamic::initialize() { return 0; } + +size_t SpecialSlicePluginDynamic::getSerializationSize() const { + size_t serialize_size = 0; + return serialize_size; +} + +void SpecialSlicePluginDynamic::serialize(void* buffer) const {} + +nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions( + int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) { + nvinfer1::DimsExprs output(inputs[0]); + auto one = expr_builder.constant(1); + output.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB, + *inputs[1].d[0], *one); + + return output; +} + +void SpecialSlicePluginDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {} + +size_t SpecialSlicePluginDynamic::getWorkspaceSize( + const nvinfer1::PluginTensorDesc* inputs, int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const { + return 0; +} + +void SpecialSlicePluginDynamic::destroy() { delete this; } + +void SpecialSlicePluginDynamic::terminate() {} + +bool 
SpecialSlicePluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc* desc, int nb_inputs, + int nb_outputs) { + if (pos == 0) // slice tensor + return (desc[pos].type == nvinfer1::DataType::kHALF && + desc[pos].format == + nvinfer1::TensorFormat::kLINEAR); // || desc[pos].type == + // nvinfer1::DataType::kFLOAT); + + if (pos == 1) // cu_seqlen + return (desc[pos].type == nvinfer1::DataType::kINT32 && + desc[pos].format == nvinfer1::TensorFormat::kLINEAR); + + return (desc[pos].type == nvinfer1::DataType::kHALF && + desc[pos].format == + nvinfer1::TensorFormat::kLINEAR); // || desc[pos].type == + // nvinfer1::DataType::kFLOAT); +} + +nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType( + int index, const nvinfer1::DataType* input_types, int nb_inputs) const { + PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument( + "The index should be equal to 0")); + return input_types[0]; +} + +template +__global__ void SpecialSliceKernel(const T* slice_input, + const int32_t* cu_seqlens, T* output) { + const int hidden = blockDim.x; + const int batch = blockIdx.x; + + output[batch * hidden + threadIdx.x] = + slice_input[cu_seqlens[batch] * hidden + threadIdx.x]; +} + +int SpecialSlicePluginDynamic::enqueue( + const nvinfer1::PluginTensorDesc* input_desc, + const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs, + void* const* outputs, void* workspace, cudaStream_t stream) { + auto input_dims = input_desc[0].dims; // (sum(S), 768, 1, 1) + auto out_dims = output_desc[0].dims; // (batch, 768, 1, 1) + + assert(input_desc[0].type == nvinfer1::DataType::kHALF); + + const int32_t hidden = input_dims.d[1]; + const int num_blocks = out_dims.d[0]; // batch size + const int num_threads = hidden; + + const half* slice_input = static_cast(inputs[0]); + const int32_t* cu_seqlens = static_cast(inputs[1]); + half* output = static_cast(outputs[0]); + + SpecialSliceKernel<<>>( + slice_input, cu_seqlens, output); + + return 
cudaGetLastError() != cudaSuccess; +} + +SpecialSlicePluginDynamicCreator::SpecialSlicePluginDynamicCreator() {} + +const char* SpecialSlicePluginDynamicCreator::getPluginName() const { + return "special_slice_plugin"; +} + +const char* SpecialSlicePluginDynamicCreator::getPluginVersion() const { + return "1"; +} + +const nvinfer1::PluginFieldCollection* +SpecialSlicePluginDynamicCreator::getFieldNames() { + return &field_collection_; +} + +nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) { + return new SpecialSlicePluginDynamic(); +} + +nvinfer1::IPluginV2* SpecialSlicePluginDynamicCreator::deserializePlugin( + const char* name, const void* serial_data, size_t serial_length) { + auto plugin = new SpecialSlicePluginDynamic(serial_data, serial_length); + return plugin; +} + +void SpecialSlicePluginDynamicCreator::setPluginNamespace( + const char* lib_namespace) { + plugin_namespace_ = lib_namespace; +} + +const char* SpecialSlicePluginDynamicCreator::getPluginNamespace() const { + return plugin_namespace_.c_str(); +} + +#endif + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h new file mode 100644 index 0000000000000000000000000000000000000000..438d9e9465c52a8c0288928c1aa3cb79d0371080 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h @@ -0,0 +1,96 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +#if IS_TRT_VERSION_GE(6000) +class SpecialSlicePluginDynamic : public DynamicPluginTensorRT { + public: + SpecialSlicePluginDynamic(); + SpecialSlicePluginDynamic(void const* serial_data, size_t serial_length); + ~SpecialSlicePluginDynamic(); + nvinfer1::IPluginV2DynamicExt* clone() const override; + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, int nbOutputs) override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) override; + + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const override; + + const char* getPluginType() const override; + int getNbOutputs() const 
override; + int initialize() override; + void terminate() override; + size_t getSerializationSize() const override; + void serialize(void* buffer) const override; + void destroy() override; + + private: + int axis_; + int num_stack_; +}; + +class SpecialSlicePluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + SpecialSlicePluginDynamicCreator(); + const char* getPluginName() const override; + const char* getPluginVersion() const override; + const nvinfer1::PluginFieldCollection* getFieldNames() override; + nvinfer1::IPluginV2* createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) override; + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) override; + void setPluginNamespace(const char* lib_namespace) override; + const char* getPluginNamespace() const override; + + private: + std::string plugin_namespace_; + nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; + std::vector plugin_attributes_; +}; +REGISTER_TRT_PLUGIN_V2(SpecialSlicePluginDynamicCreator); +#endif + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h index 33eec618ff62bd04e368d63839b7ee669f7f9519..528adacb27c9897420a5115a93c88c246c0d78d8 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h @@ -178,13 +178,12 @@ class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt { std::string name_space_; std::string plugin_base_; }; -#endif template class TrtPluginRegistrarV2 { public: TrtPluginRegistrarV2() { - static auto func_ptr = getPluginRegistry(); + static auto func_ptr = GetPluginRegistry(); if (func_ptr != nullptr) { func_ptr->registerCreator(creator, ""); } @@ -198,6 +197,8 @@ class TrtPluginRegistrarV2 { static 
paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2 \ plugin_registrar_##name {} +#endif + } // namespace plugin } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc index 07f646c962ea2089301fe4265442435f247e3be8..3d84264319a6fa8ba4363cf31425603489207e06 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_deserialize_test.cc @@ -123,17 +123,17 @@ void trt_ernie(bool with_fp16, std::vector result) { {"read_file_0.tmp_0", min_shape}, {"read_file_0.tmp_1", min_shape}, {"read_file_0.tmp_2", min_shape}, - {"matmul_0.tmp_0", {batch, min_seq_len, min_seq_len}}}; + {"read_file_0.tmp_4", min_shape}}; std::map> max_input_shape = { {"read_file_0.tmp_0", max_shape}, {"read_file_0.tmp_1", max_shape}, {"read_file_0.tmp_2", max_shape}, - {"matmul_0.tmp_0", {batch, max_seq_len, max_seq_len}}}; + {"read_file_0.tmp_4", max_shape}}; std::map> opt_input_shape = { {"read_file_0.tmp_0", opt_shape}, {"read_file_0.tmp_1", opt_shape}, {"read_file_0.tmp_2", opt_shape}, - {"matmul_0.tmp_0", {batch, opt_seq_len, opt_seq_len}}}; + {"read_file_0.tmp_4", opt_shape}}; auto precision = AnalysisConfig::Precision::kFloat32; if (with_fp16) { diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc index 8c4ada280cce2b47f3a6b3220cec42a8458715d0..25ad6e6105aae7eff4c0af707439c6b586f81315 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_ernie_test.cc @@ -86,16 +86,16 @@ void run(const AnalysisConfig& config, std::vector* out_data) { void trt_ernie(bool with_fp16, std::vector result) { AnalysisConfig config; std::string model_dir = FLAGS_infer_model; - 
SetConfig(&config, model_dir, true /* use_gpu */); + SetConfig(&config, model_dir, true); config.SwitchUseFeedFetchOps(false); - int batch = 1; + int batch = 32; int min_seq_len = 1; int max_seq_len = 128; int opt_seq_len = 128; - std::vector min_shape = {batch, min_seq_len, 1}; + std::vector min_shape = {1, min_seq_len, 1}; std::vector max_shape = {batch, max_seq_len, 1}; std::vector opt_shape = {batch, opt_seq_len, 1}; // Set the input's min, max, opt shape @@ -103,17 +103,17 @@ void trt_ernie(bool with_fp16, std::vector result) { {"read_file_0.tmp_0", min_shape}, {"read_file_0.tmp_1", min_shape}, {"read_file_0.tmp_2", min_shape}, - {"matmul_0.tmp_0", {batch, min_seq_len, min_seq_len}}}; + {"read_file_0.tmp_4", min_shape}}; std::map> max_input_shape = { {"read_file_0.tmp_0", max_shape}, {"read_file_0.tmp_1", max_shape}, {"read_file_0.tmp_2", max_shape}, - {"matmul_0.tmp_0", {batch, max_seq_len, max_seq_len}}}; + {"read_file_0.tmp_4", max_shape}}; std::map> opt_input_shape = { {"read_file_0.tmp_0", opt_shape}, {"read_file_0.tmp_1", opt_shape}, {"read_file_0.tmp_2", opt_shape}, - {"matmul_0.tmp_0", {batch, opt_seq_len, opt_seq_len}}}; + {"read_file_0.tmp_4", opt_shape}}; auto precision = AnalysisConfig::Precision::kFloat32; if (with_fp16) { @@ -124,6 +124,7 @@ void trt_ernie(bool with_fp16, std::vector result) { opt_input_shape); std::vector out_data; run(config, &out_data); + for (size_t i = 0; i < out_data.size(); i++) { EXPECT_NEAR(result[i], out_data[i], 1e-6); } diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index 0e606c466b5bca6f6b7192cc57a5b0df83bfedf0..27ab74dd2d1c93b38cd4cdcf4169154d1f66136a 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -263,9 +263,11 @@ class TensorRTEngineOp : public framework::OperatorBase { buffers[bind_index] = static_cast(t.data()); } else if (type == 
framework::proto::VarType::INT64) { buffers[bind_index] = static_cast(t.data()); + } else if (type == framework::proto::VarType::INT32) { + buffers[bind_index] = static_cast(t.data()); } else { PADDLE_THROW(platform::errors::Fatal( - "The TRT Engine OP only support float and int64_t input.")); + "The TRT Engine OP only support float/int32_t/int64_t input.")); } } diff --git a/paddle/fluid/platform/dynload/tensorrt.cc b/paddle/fluid/platform/dynload/tensorrt.cc index c9c3a9456b736ee1afb2efbe9bf092e2ae298372..8ddc9e982bab8cdcb80ce1b27ab6c024e8c6d5ef 100644 --- a/paddle/fluid/platform/dynload/tensorrt.cc +++ b/paddle/fluid/platform/dynload/tensorrt.cc @@ -22,19 +22,15 @@ namespace dynload { std::once_flag tensorrt_dso_flag; void* tensorrt_dso_handle; +std::once_flag tensorrt_plugin_dso_flag; +void* tensorrt_plugin_dso_handle; + #define DEFINE_WRAP(__name) DynLoad__##__name __name TENSORRT_RAND_ROUTINE_EACH(DEFINE_WRAP); +TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DEFINE_WRAP); -void* GetTensorRtHandle() { -#if defined(__APPLE__) || defined(__OSX__) - std::string dso_name = "libnvinfer.dylib"; -#elif defined(_WIN32) - std::string dso_name = "nvinfer.dll"; -#else - std::string dso_name = "libnvinfer.so"; -#endif - +void* GetDsoHandle(const std::string& dso_name) { #if !defined(_WIN32) int dynload_flags = RTLD_LAZY | RTLD_LOCAL; #else @@ -65,10 +61,31 @@ void* GetTensorRtHandle() { #endif // !_WIN32 std::cerr << string::Sprintf(error_msg, dso_name, errorno); } - return dso_handle; } +void* GetTensorRtHandle() { +#if defined(__APPLE__) || defined(__OSX__) + std::string dso_name = "libnvinfer.dylib"; +#elif defined(_WIN32) + std::string dso_name = "nvinfer.dll"; +#else + std::string dso_name = "libnvinfer.so"; +#endif + return GetDsoHandle(dso_name); +} + +void* GetTensorRtPluginHandle() { +#if defined(__APPLE__) || defined(__OSX__) + std::string dso_name = "libnvinfer_plugin.dylib"; +#elif defined(_WIN32) + std::string dso_name = "nvinfer_plugin.dll"; +#else + std::string dso_name 
= "libnvinfer_plugin.so"; +#endif + return GetDsoHandle(dso_name); +} + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/tensorrt.h b/paddle/fluid/platform/dynload/tensorrt.h index 34ad1e74588805b20543ba68869bf4e466b5911f..6e3d9132c4f2218a558562a1e08ec5ef136017a8 100644 --- a/paddle/fluid/platform/dynload/tensorrt.h +++ b/paddle/fluid/platform/dynload/tensorrt.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once #include +#include #if !defined(_WIN32) #include #endif @@ -32,11 +33,14 @@ void* GetTensorRtHandle(); extern std::once_flag tensorrt_dso_flag; extern void* tensorrt_dso_handle; +void* GetTensorRtPluginHandle(); +extern std::once_flag tensorrt_plugin_dso_flag; +extern void* tensorrt_plugin_dso_handle; + #define DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP(__name) \ struct DynLoad__##__name { \ template \ auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) { \ - using tensorrt_func = decltype(&::__name); \ std::call_once(tensorrt_dso_flag, []() { \ tensorrt_dso_handle = paddle::platform::dynload::GetTensorRtHandle(); \ }); \ @@ -44,17 +48,50 @@ extern void* tensorrt_dso_handle; if (p_##__name == nullptr) { \ return nullptr; \ } \ + using tensorrt_func = decltype(&::__name); \ return reinterpret_cast(p_##__name)(args...); \ } \ }; \ extern DynLoad__##__name __name +#define DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> DECLARE_TYPE(__name, args...) 
{ \ + std::call_once(tensorrt_plugin_dso_flag, []() { \ + tensorrt_plugin_dso_handle = \ + paddle::platform::dynload::GetTensorRtPluginHandle(); \ + }); \ + static void* p_##__name = dlsym(tensorrt_plugin_dso_handle, #__name); \ + PADDLE_ENFORCE_NOT_NULL(p_##__name, \ + platform::errors::Unavailable( \ + "Load tensorrt plugin %s failed", #__name)); \ + using tensorrt_plugin_func = decltype(&::__name); \ + return reinterpret_cast(p_##__name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name + +#ifdef NV_TENSORRT_MAJOR + +#if (NV_TENSORRT_MAJOR >= 6) #define TENSORRT_RAND_ROUTINE_EACH(__macro) \ __macro(createInferBuilder_INTERNAL); \ __macro(createInferRuntime_INTERNAL); \ __macro(getPluginRegistry); +#else +#define TENSORRT_RAND_ROUTINE_EACH(__macro) \ + __macro(createInferBuilder_INTERNAL); \ + __macro(createInferRuntime_INTERNAL); +#endif + +#define TENSORRT_PLUGIN_RAND_ROUTINE_EACH(__macro) \ + __macro(initLibNvInferPlugins); TENSORRT_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_WRAP) +TENSORRT_PLUGIN_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_TENSORRT_PLUGIN_WRAP) + +#endif // end of NV_TENSORRT_MAJOR } // namespace dynload } // namespace platform diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index a7d5b36bfc8d20e2fbbb34c15465d14668b65f95..7d77ed80cb47e69cfe57087120136b7b2331163c 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -423,6 +423,8 @@ void BindAnalysisConfig(py::module *m) { py::arg("optim_input_shape") = std::map>({}), py::arg("disable_trt_plugin_fp16") = false) + .def("enable_tensorrt_oss", &AnalysisConfig::EnableTensorRtOSS) + .def("tensorrt_oss_enabled", &AnalysisConfig::tensorrt_oss_enabled) .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled) .def("enable_lite_engine", &AnalysisConfig::EnableLiteEngine, py::arg("zero_copy") = false, diff --git a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py 
b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py index 0b8ea1f9392e9264c2abf1827e39582be92988cb..c7fd7995118dbfdadc0a11934c7f171b1880a732 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py @@ -20,6 +20,7 @@ import random import unittest import numpy as np +import paddle import paddle.fluid as fluid import paddle.fluid.core as core from paddle.fluid.core import PaddleTensor @@ -152,6 +153,8 @@ class InferencePassTest(unittest.TestCase): format(device)) for out, analysis_output in zip(outs, analysis_outputs): + out = np.array(out) + self.assertTrue( np.allclose( np.array(out), analysis_output, atol=atol), @@ -169,6 +172,8 @@ class InferencePassTest(unittest.TestCase): "The number of outputs is different between GPU and TensorRT. ") for out, tensorrt_output in zip(outs, tensorrt_outputs): + out = np.array(out) + self.assertTrue( np.allclose( np.array(out), tensorrt_output, atol=atol), diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..1899b68b371d2335743b5e2d79232ee30309c2d8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py @@ -0,0 +1,106 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import AnalysisConfig + + +class TensorRTMatMulDims2Test(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data(name="data", shape=[24, 24], dtype="float32") + matmul_out = fluid.layers.matmul( + x=data, + y=data, + transpose_x=self.transpose_x, + transpose_y=self.transpose_y, + alpha=self.alpha) + out = fluid.layers.batch_norm(matmul_out, is_test=True) + + self.feeds = {"data": np.ones([24, 24]).astype("float32"), } + self.enable_trt = True + self.trt_parameters = TensorRTMatMulDims2Test.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def set_params(self): + self.transpose_x = True + self.transpose_y = True + self.alpha = 2.0 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + + +class TensorRTMatMulTest(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 6, 24, 24], dtype="float32") + matmul_out = fluid.layers.matmul( + x=data, + y=data, + transpose_x=self.transpose_x, + transpose_y=self.transpose_y, + alpha=self.alpha) + out = fluid.layers.batch_norm(matmul_out, is_test=True) + + self.feeds = {"data": np.ones([1, 6, 24, 24]).astype("float32"), } + self.enable_trt = True + self.trt_parameters = TensorRTMatMulTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def set_params(self): + self.transpose_x = False + self.transpose_y = False + self.alpha = 1.0 + + def 
test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + + +class TensorRTMatMulTransposeXTest(TensorRTMatMulTest): + def set_params(self): + self.transpose_x = True + self.transpose_y = False + self.alpha = 1.0 + + +class TensorRTMatMulTransposeYTest(TensorRTMatMulTest): + def set_params(self): + self.transpose_x = False + self.transpose_y = True + self.alpha = 1.0 + + +class TensorRTMatMulScaleTest(TensorRTMatMulTest): + def set_params(self): + self.transpose_x = False + self.transpose_y = False + self.alpha = 2.0 + + +if __name__ == "__main__": + unittest.main()