diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index e69a1e0e1ffb0b723f80362efad680a96533564d..717737749a96beb220d271e96051d2ce8c4addc2 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -331,6 +331,9 @@ struct Argument { // mixed precision related DECL_ARGUMENT_FIELD(model_precision, ModelPrecision, int); + DECL_ARGUMENT_FIELD(mixed_black_list, + MixedBlackList, + std::unordered_set); private: std::unordered_set valid_fields_; diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 4aeaefa3c49c3ff5ed715803b879123f799fb2b9..3c04638003cdd0c31c1e9f3aeb1cd9cf07130db6 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -87,6 +87,9 @@ void IRPassManager::CreatePasses(Argument *argument, pass->Set("with_dynamic_shape", new bool(with_dynamic_shape)); pass->Set("model_precision", new int(argument->model_precision())); + pass->Set( + "mixed_black_list", + new std::unordered_set(argument->mixed_black_list())); if (pass_name == "graph_viz_pass") { std::string optim_cache_dir = argument->optim_cache_dir(); diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 7a9c5b889d14621e2947333f07da97002bf44dca..d39eadc7cc8f19d95719a7103a8dd5a5db6aa340 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -13,26 +13,117 @@ // limitations under the License. #include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" +#include +#include +#include +#include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/graph_viz_pass.h" +#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/subgraph_detector.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/inference/analysis/helper.h" +#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/tensorrt/op_teller.h" #include "paddle/fluid/inference/utils/io_utils.h" +#include "paddle/phi/common/backend.h" +#include "paddle/phi/common/data_type.h" namespace paddle { namespace inference { namespace analysis { +namespace { + +bool IsFloat(framework::proto::VarType::Type t) { + if (t == framework::proto::VarType::FP16 || + t == framework::proto::VarType::FP32 || + t == framework::proto::VarType::FP64 || + t == framework::proto::VarType::BF16) + return true; + return false; +} + +// if in mixed model precision, we should make all tensorrt_engine's output +// floats dtype to float32 dtype. 
+void OutputProcess(framework::ir::Graph *graph, + const std::unordered_set &trt_outputs, + phi::Backend backend, + phi::DataType precision, + const std::unordered_set &blacklist) { + framework::BlockDesc *block_desc{nullptr}; + int suffix = 0; + std::unordered_map + var_to_cast_op_map; + + framework::proto::VarType::Type to_type; + if (precision == phi::DataType::FLOAT16) { + to_type = framework::proto::VarType::FP16; + } else if (precision == phi::DataType::BFLOAT16) { + to_type = framework::proto::VarType::BF16; + } else if (precision == phi::DataType::FLOAT32) { + return; + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "mixed_precision currently not supported dtype %d, we now only support " + "fp16 and bf16.", + static_cast(precision))); + } + + for (auto *op_node : framework::ir::TopologySortOperations(*graph)) { + if (!op_node->IsOp()) continue; + auto op_type = op_node->Op()->Type(); + if (op_type == "feed") block_desc = op_node->Op()->Block(); + if (op_type != "tensorrt_engine") continue; + for (auto *var_node : op_node->outputs) { + if (!trt_outputs.count(var_node)) continue; + if (!var_node->Var()->Persistable() && + IsFloat(var_node->Var()->GetDataType()) && + var_node->Var()->GetDataType() != framework::proto::VarType::FP32) { + for (auto *next_op : var_node->outputs) { + // if next_op support mixed_precision, we need to add cast op. + if (OpSupportPrecision( + phi::TransToPhiKernelName(next_op->Op()->Type()), + backend, + precision, + blacklist)) { + AddCastOp(graph, + var_node, + next_op, + framework::proto::VarType::FP32, + to_type, + &suffix, + block_desc, + &var_to_cast_op_map); + var_node->Var()->SetDataType(framework::proto::VarType::FP32); + } + } + } + } + } +} + +} // namespace using framework::ir::Node; void analysis::TensorRtSubgraphPass::ApplyImpl( framework::ir::Graph *graph) const { framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph); + + auto model_precision = + static_cast(Get("model_precision")); + if (model_precision == phi::DataType::BFLOAT16) { + LOG(WARNING) + << "Paddle-TRT not support bf16 mixed precison, just fallback."; + return; + } + auto enable_int8 = Get("enable_int8"); auto use_calib_mode = Get("use_calib_mode"); bool no_calib_int8 = enable_int8 && !(use_calib_mode); @@ -181,15 +272,25 @@ void TensorRtSubgraphPass::CreateTensorRTOp( } } + auto model_precision = + static_cast(Get("model_precision")); + auto mixed_black_list = + Get>("mixed_black_list"); + std::set output_names; std::set output_names_with_id; std::map origin_name_output_dims; + std::unordered_set trt_outputs; for (auto *x : node->outputs) { output_names.insert(x->Name()); output_names_with_id.insert(x->Name() + std::to_string(x->id())); origin_name_output_dims[x->Name()] = x->Var()->GetShape().size(); + trt_outputs.insert(x); } + OutputProcess( + graph, trt_outputs, phi::Backend::GPU, model_precision, mixed_black_list); + std::unordered_map output_name_map; std::unordered_map graph_var_map; @@ -285,6 +386,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( op_desc->SetAttr("allow_build_at_runtime", allow_build_at_runtime); op_desc->SetAttr("shape_range_info_path", shape_range_info_path); op_desc->SetAttr("use_inspector", Get("use_inspector")); + op_desc->SetAttr("model_precision", Get("model_precision")); // we record all inputs' shapes in attr to check if they are consistent // with the real inputs' shapes retrieved from scope when trt runs. 
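Whether a consumer of a tensorrt_engine output may stay in reduced precision is decided by OpSupportPrecision, which this pass pulls in from convert_to_mixed_precision.h together with AddCastOp. Below is a minimal, self-contained sketch of that gate's intent; the enum stand-ins and the HasLowPrecisionKernel stub are assumptions for illustration only, the real check also consults phi's kernel registry for the target backend.

```cpp
#include <string>
#include <unordered_set>

// Stand-ins for phi::Backend / phi::DataType, used only in this sketch.
enum class Backend { GPU };
enum class DataType { FLOAT32, FLOAT16, BFLOAT16 };

// Hypothetical registry query; the real pass asks phi whether a kernel named
// `phi_kernel` is registered for the given (backend, precision) pair.
bool HasLowPrecisionKernel(const std::string& phi_kernel,
                           Backend backend,
                           DataType precision) {
  return true;  // always true in this sketch
}

// Simplified decision mirroring the call sites above: a black-listed op never
// runs in reduced precision, everything else depends on kernel availability.
bool OpSupportPrecisionSketch(const std::string& phi_kernel,
                              Backend backend,
                              DataType precision,
                              const std::unordered_set<std::string>& blacklist) {
  if (blacklist.count(phi_kernel)) return false;
  return HasLowPrecisionKernel(phi_kernel, backend, precision);
}
```

Because FreezeNetwork (further down in this patch) forces every engine output to fp32 when the model is mixed precision, OutputProcess only has to insert an fp32-to-fp16/bf16 cast in front of consumers that pass this gate; consumers that fail it read the fp32 output directly.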
@@ -404,7 +506,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp( min_input_shape, max_input_shape, opt_input_shape, - disable_trt_plugin_fp16); + disable_trt_plugin_fp16, + static_cast(Get("model_precision"))); trt_engine->SetUseOSS(Get("use_varseqlen")); trt_engine->SetWithInterleaved(Get("with_interleaved")); trt_engine->SetTransformerPosid( diff --git a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc index 44e36647646fe44ae9f5cb9d61f5239d72270612..bc753636d2c1a175f98cd36af3c63bde55558dc3 100644 --- a/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc +++ b/paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.cc @@ -18,6 +18,7 @@ #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" @@ -379,27 +380,21 @@ void ConvertToMixedPrecision(const std::string& model_file, }; std::unordered_set weights_should_be_fp32; - for (auto* node : paddle::framework::ir::TopologySortOperations(*graph)) { - if (!node->IsOp()) continue; - auto* op_desc = node->Op(); - if (op_desc->Type() == "feed" || op_desc->Type() == "fetch") continue; - - if (op_desc->Type() == "batch_norm") { - auto vecs = op_desc->Input("Bias"); - for (auto s : vecs) { - weights_should_be_fp32.insert(s); - } - vecs = op_desc->Input("Mean"); - for (auto s : vecs) { - weights_should_be_fp32.insert(s); - } - vecs = op_desc->Input("Scale"); - for (auto s : vecs) { - weights_should_be_fp32.insert(s); - } - vecs = op_desc->Input("Variance"); - for (auto s : vecs) { - weights_should_be_fp32.insert(s); + for (auto* node : graph->Nodes()) { + if (!node->IsVar()) continue; + if (node->Var()->GetType() == + paddle::framework::proto::VarType::SELECTED_ROWS || + node->Var()->GetType() == + paddle::framework::proto::VarType::LOD_TENSOR || + node->Var()->GetType() == + paddle::framework::proto::VarType::LOD_TENSOR_ARRAY || + node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS || + node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB) { + if (node->Var()->Persistable() && + node->Var()->GetDataType() == + paddle::framework::proto::VarType::FP32) { + VLOG(2) << "weights keep to fp32: " << node->Name(); + weights_should_be_fp32.insert(node->Name()); } } } diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 75a5d9ee4f55b9971bdda7d5241a9824135dc1fa..ae90618f5207cd8100bc5460e63e9c796a2dc3ba 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -256,6 +256,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(gpu_device_id_); CP_MEMBER(memory_pool_init_size_mb_); + // Mixed related. + CP_MEMBER(mixed_black_list_); + CP_MEMBER(enable_memory_optim_); // TensorRT related. 
CP_MEMBER(use_tensorrt_); @@ -871,6 +874,7 @@ std::string AnalysisConfig::SerializeInfoCache() { ss << ipu_available_memory_proportion_; ss << ipu_enable_half_partial_; + for (auto &op : mixed_black_list_) ss << op.c_str(); return ss.str(); } @@ -1188,4 +1192,10 @@ bool AnalysisConfig::tuned_tensorrt_dynamic_shape() { bool AnalysisConfig::trt_allow_build_at_runtime() { return trt_allow_build_at_runtime_; } + +void AnalysisConfig::Exp_SetBlackListOpsForMixedModel( + const std::unordered_set &black_list) { + mixed_black_list_ = black_list; +} + } // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 62f89e300bfbda54fb6da47d5deb1dc9417c4db9..d008355e0ed5bb384e23d59460f4d6ca072f1d85 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1216,7 +1216,9 @@ void AnalysisPredictor::PrepareArgument() { argument_.SetAnalysisPasses(config_.pass_builder()->AnalysisPasses()); argument_.SetScopeNotOwned(scope_.get()); + // mixed precison. argument_.SetModelPrecision(static_cast(model_precision_)); + argument_.SetMixedBlackList(config_.mixed_black_list_); } // NOTE All the members in AnalysisConfig should be copied to Argument. diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 6de23e930836aff03c544a08ab123724f0277b91..08d0e073babc18c5691be32a5642efaa15ff098d 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -914,6 +914,14 @@ struct PD_INFER_DECL AnalysisConfig { const DistConfig& dist_config() const { return dist_config_; } + /// + /// \brief Set a list of operators that do not support mixed precision. This + /// interface is in the experimental stage and may change in the future. Note + /// that the blacklist must be the same as the model conversion blacklist. + /// + void Exp_SetBlackListOpsForMixedModel( + const std::unordered_set& black_list); + protected: // Update the config. void Update(); @@ -926,6 +934,9 @@ struct PD_INFER_DECL AnalysisConfig { mutable std::string prog_file_; mutable std::string params_file_; + // Mixed precision. + std::unordered_set mixed_black_list_; + // GPU related. 
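For reference, a minimal user-side sketch of the new experimental API, assuming the standard paddle_infer C++ entry points; the model paths and the black-listed op names are placeholders, and (as the doc comment above states) the list must match the blacklist used when the mixed-precision model was produced.

```cpp
#include <memory>

#include "paddle_inference_api.h"

std::shared_ptr<paddle_infer::Predictor> BuildPredictor() {
  paddle_infer::Config config;
  config.SetModel("./mixed_fp16_model/inference.pdmodel",    // hypothetical paths
                  "./mixed_fp16_model/inference.pdiparams");
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/512, /*device_id=*/0);
  config.EnableTensorRtEngine(/*workspace_size=*/1 << 30,
                              /*max_batch_size=*/1,
                              /*min_subgraph_size=*/3,
                              paddle_infer::PrecisionType::kHalf,
                              /*use_static=*/false,
                              /*use_calib_mode=*/false);
  // Ops listed here keep running in fp32; "softmax"/"layer_norm" are
  // illustrative names, not a recommendation.
  config.Exp_SetBlackListOpsForMixedModel({"softmax", "layer_norm"});
  return paddle_infer::CreatePredictor(config);
}
```

The copied mixed_black_list_ then travels through Argument and IRPassManager into tensorrt_subgraph_pass, as shown earlier in this patch.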
bool use_gpu_{false}; int gpu_device_id_{0}; diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 73c216290dd88e3aa8d59d4ab53172e7e8eff80a..0d918446ea92aafa524f73615224613a4bf67e15 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -160,6 +160,10 @@ const std::vector kGpuLowerPrecisionPasses{ const std::vector kTrtLowerPrecisionPasses{ // "conv_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", + "trt_map_matmul_v2_to_mul_pass", + "trt_map_matmul_v2_to_matmul_pass", + "trt_map_matmul_to_mul_pass", + "fc_fuse_pass", "tensorrt_subgraph_pass", }; diff --git a/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc b/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc index 44283f4e0d7e960b58d0ac464c1f8a4cbe1191b8..017fa8800b4589e8838e1b809c85c7c74eb2eb01 100644 --- a/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc @@ -50,22 +50,26 @@ class AffineChannelOpConverter : public OpConverter { auto* scale_v = scope.FindVar(scale_name); auto* scale_t = scale_v->GetMutable(); - float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t); + float* scale_ptr = const_cast(static_cast( + engine_->GetFp32TrtWeight(scale_name, *scale_t).get().values)); auto* bias_v = scope.FindVar(bias_name); auto* bias_t = bias_v->GetMutable(); - float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t); + float* bias_ptr = const_cast(static_cast( + engine_->GetFp32TrtWeight(bias_name, *bias_t).get().values)); // tensorrt scalend layer only support spatial dims >= 2, // so nhwc is not availabe (spatial dims == 0) const int channel_axis = engine_->with_dynamic_shape(); - TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, - static_cast(scale_ptr), - (size_t)idim.d[channel_axis]}; - TensorRTEngine::Weight bias_weights{nvinfer1::DataType::kFLOAT, - static_cast(bias_ptr), - (size_t)idim.d[channel_axis]}; + TensorRTEngine::Weight scale_weights{ + nvinfer1::DataType::kFLOAT, + static_cast(scale_ptr), + static_cast(idim.d[channel_axis])}; + TensorRTEngine::Weight bias_weights{ + nvinfer1::DataType::kFLOAT, + static_cast(bias_ptr), + static_cast(idim.d[channel_axis])}; TensorRTEngine::Weight power_weights{ nvinfer1::DataType::kFLOAT, nullptr, 0}; diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc index 244078dc344a289c58b057c17e9de1434644ec3d..c47f6d03cd5432c859a8d522587990841d6e17e3 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/phi/common/data_type.h" namespace paddle { namespace framework { @@ -48,7 +50,7 @@ void ConvertConv2d(TensorRTEngine* engine, platform::errors::NotFound("Can not find %s presistale var in scope.", filter_var_name)); auto* Y_t = Y_v->GetMutable(); - float* weight_data = nullptr; + bool enable_int8 = op_desc.HasAttr("enable_int8"); if (enable_int8) { @@ -57,7 +59,6 @@ void ConvertConv2d(TensorRTEngine* engine, engine->SetTensorDynamicRange(X, in_scale); #endif } - weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t); PADDLE_ENFORCE_EQ(Y_t->dims().size(), 4UL, @@ -104,21 +105,19 @@ void ConvertConv2d(TensorRTEngine* engine, nv_post_paddings.d[1] = paddings[3]; } - TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, - static_cast(weight_data), - static_cast(Y_t->numel())}; - float* bias_data = nullptr; - size_t bias_size = 0; + auto weight = engine->GetTrtWeight(op_desc.Input("Filter").front(), *Y_t); + + TensorRTEngine::Weight bias; + bias.SetDataType(weight.get().type); + bias.SetCount(0); + bias.SetValues(nullptr); if (op_desc.Type() == "conv2d_fusion") { auto* bias_tensor = scope.GetVar(op_desc.Input("Bias").front()); auto* bias_tensor_data = bias_tensor->GetMutable(); - bias_data = engine->GetWeightCPUData(op_desc.Input("Bias").front(), - bias_tensor_data); - bias_size = static_cast(bias_tensor_data->numel()); + bias = + engine->GetTrtWeight(op_desc.Input("Bias").front(), *bias_tensor_data); } - TensorRTEngine::Weight bias{ - nvinfer1::DataType::kFLOAT, static_cast(bias_data), bias_size}; // In conv2d_transpose and depthwise_conv2d_transpose, // output channels = filter_dims[1] * groups auto* layer = (op_desc.Type() == "conv2d_transpose" || diff --git a/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc index 7308d44bf8320b1bf47fdc75cd34d5bf8d2e77e4..4ffc8056547272bdc1752e98edc2bb2b5bb313e3 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc @@ -48,14 +48,12 @@ void ConvertConv3d(TensorRTEngine* engine, platform::errors::NotFound("Can not find %s presistale var in scope.", filter_var_name)); auto* Y_t = Y_v->GetMutable(); - float* weight_data = nullptr; bool enable_int8 = op_desc.HasAttr("enable_int8"); if (enable_int8) { float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine->SetTensorDynamicRange(X, in_scale); } - weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t); PADDLE_ENFORCE_EQ(Y_t->dims().size(), 5UL, @@ -85,14 +83,12 @@ void ConvertConv3d(TensorRTEngine* engine, nvinfer1::Dims3 nv_strides(strides[0], strides[1], strides[2]); nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1], paddings[2]); - TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, - static_cast(weight_data), - static_cast(Y_t->numel())}; + auto weight = engine->GetTrtWeight(op_desc.Input("Filter").front(), *Y_t); float* bias_data = nullptr; size_t bias_size = 0; TensorRTEngine::Weight bias{ - nvinfer1::DataType::kFLOAT, static_cast(bias_data), bias_size}; + weight.get().type, static_cast(bias_data), bias_size}; // In conv3d_transpose output channels = filter_dims[1] * groups auto* layer = (op_desc.Type() == "conv3d_transpose") ? 
fadd_layer(X, n_input * groups, nv_ksize, weight, bias) diff --git a/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc b/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc index f0a82bebc7ca92d3c34000c2a000fcff95bd5923..8cf7f6528e5950f011c7b264e0be465d05ae6a7c 100644 --- a/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc @@ -49,8 +49,6 @@ class DeformableConvOpConverter : public OpConverter { auto* filter_var = scope.FindVar(filter_name); auto* filter_tensor = filter_var->GetMutable(); - float* filter_data = engine_->GetWeightCPUData(filter_name, filter_tensor); - const int c_o = filter_tensor->dims()[0]; const int c_i = filter_tensor->dims()[1]; const int k_h = filter_tensor->dims()[2]; @@ -73,15 +71,20 @@ class DeformableConvOpConverter : public OpConverter { weights.count = filter_tensor->numel(); bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); if (with_fp16) { - auto half_filter_data = new half[filter_tensor->numel()]; - for (int i = 0; i < filter_tensor->numel(); i++) { - half_filter_data[i] = static_cast(filter_data[i]); + auto filter_weight = engine_->GetTrtWeight(filter_name, *filter_tensor); + if (filter_weight.get().type == nvinfer1::DataType::kFLOAT) { + auto half_filter_data = new half[filter_tensor->numel()]; + for (int i = 0; i < filter_tensor->numel(); i++) { + half_filter_data[i] = static_cast( + static_cast(filter_weight.get().values)[i]); + } + weights.type = nvinfer1::DataType::kHALF; + weights.values = half_filter_data; + } else if (filter_weight.get().type == nvinfer1::DataType::kHALF) { + weights = filter_weight.get(); } - weights.type = nvinfer1::DataType::kHALF; - weights.values = half_filter_data; } else { - weights.type = nvinfer1::DataType::kFLOAT; - weights.values = filter_data; + weights = engine_->GetFp32TrtWeight(filter_name, *filter_tensor).get(); } auto* deformable_conv_plugin = new plugin::DeformableConvPlugin( with_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT, diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index aff23343a078fd2732d74556d88436e4d4cfe648..365523508f5dfafc92d4a9edcfcafef4796316de 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -33,12 +33,9 @@ class ElementwiseTensorOpConverter : public OpConverter { if (Y_v) { // Y is weight auto* Y_t = Y_v->GetMutable(); - float* weight_data = - engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t); std::vector dims_y = phi::vectorize(Y_t->dims()); - TensorRTEngine::Weight y_weight{nvinfer1::DataType::kFLOAT, - static_cast(weight_data), - static_cast(Y_t->numel())}; + auto y_weight = engine_->GetTrtWeight(op_desc.Input("Y").front(), *Y_t); + nvinfer1::Dims trt_dims_y; trt_dims_y.nbDims = dims_y.size(); for (int i = 0; i < trt_dims_y.nbDims; i++) { diff --git a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc index 1a1f72388e40e35a6346b5109334bdcca8a22131..5020b9762775307ce0564d875f042f9cc958af84 100644 --- a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc @@ -10,8 +10,11 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/utils.h" +#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h" +#include "paddle/phi/core/ddim.h" namespace paddle { namespace framework { @@ -73,27 +76,39 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { // input_embs[0]: word_embedding // input_embs[1]: pos_embedding // input_embs[2]: sent_embedding - std::vector input_embs; + std::vector input_embs; std::vector emb_sizes; // get the presistable var's data - auto get_persistable_data = [&](const std::string& var_name, - framework::DDim* dims) -> float* { + auto GetWeight = [&](const std::string& var_name, + framework::DDim* dim) -> TensorRTEngine::Weight { auto* temp_var = scope.FindVar(var_name); auto* temp_tensor = temp_var->GetMutable(); - (*dims) = temp_tensor->dims(); + *dim = temp_tensor->dims(); + auto weight = engine_->GetTrtWeight(var_name, *temp_tensor); + return weight; + }; - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); - return temp_data; + auto GetFp32Weight = [&](const std::string& var_name, + framework::DDim* dim) -> TensorRTEngine::Weight { + auto* temp_var = scope.FindVar(var_name); + auto* temp_tensor = temp_var->GetMutable(); + *dim = temp_tensor->dims(); + auto weight = engine_->GetFp32TrtWeight(var_name, *temp_tensor); + return weight; }; int hidden = 0; for (int i = 0; i < input_num; i++) { framework::DDim emb_dims; - float* emb_data = get_persistable_data(emb_names[i], &emb_dims); - int64_t emb_size = phi::product(emb_dims); - input_embs.push_back(emb_data); - emb_sizes.push_back(emb_size); + TensorRTEngine::Weight weight; + if (flag_varseqlen) { + weight = GetWeight(emb_names[i], &emb_dims); + } else { + weight = GetFp32Weight(emb_names[i], &emb_dims); + } + input_embs.push_back(weight.get()); + emb_sizes.push_back(weight.get().count); PADDLE_ENFORCE_EQ( emb_dims.size(), 2, @@ -103,11 +118,15 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { } framework::DDim bias_dims, scale_dims; + TensorRTEngine::Weight bias_weight, scale_weight; + if (flag_varseqlen) { + bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims); + scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims); + } else { + bias_weight = GetFp32Weight(op_desc.Input("Bias").front(), &bias_dims); + scale_weight = GetFp32Weight(op_desc.Input("Scale").front(), &scale_dims); + } - auto* bias = - get_persistable_data(op_desc.Input("Bias").front(), &bias_dims); - auto* scale = - get_persistable_data(op_desc.Input("Scale").front(), &scale_dims); int64_t bias_size = phi::product(bias_dims); int64_t scale_size = phi::product(scale_dims); nvinfer1::ILayer* layer = nullptr; @@ -134,24 +153,24 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { "But Precision::KFloat32 is setted.")); const std::vector fields{ {"bert_embeddings_layernorm_beta", - bias, - nvinfer1::PluginFieldType::kFLOAT32, + bias_weight.get().values, + GetPluginFieldType(bias_weight.get().type), static_cast(bias_size)}, {"bert_embeddings_layernorm_gamma", - scale, - nvinfer1::PluginFieldType::kFLOAT32, + scale_weight.get().values, + GetPluginFieldType(scale_weight.get().type), static_cast(scale_size)}, {"bert_embeddings_word_embeddings", - input_embs[0], - nvinfer1::PluginFieldType::kFLOAT32, + input_embs[0].values, + GetPluginFieldType(input_embs[0].type), 
static_cast(emb_sizes[0])}, {"bert_embeddings_token_type_embeddings", - input_embs[2], - nvinfer1::PluginFieldType::kFLOAT32, + input_embs[2].values, + GetPluginFieldType(input_embs[2].type), static_cast(emb_sizes[2])}, {"bert_embeddings_position_embeddings", - input_embs[1], - nvinfer1::PluginFieldType::kFLOAT32, + input_embs[1].values, + GetPluginFieldType(input_embs[1].type), static_cast(emb_sizes[1])}, {"output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1}, }; @@ -235,15 +254,23 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon")); plugin::DynamicPluginTensorRT* plugin = nullptr; - plugin = new plugin::EmbEltwiseLayernormPluginDynamic(input_embs, - bias, - scale, - emb_sizes, - bias_size, - scale_size, - hidden, - eps, - with_fp16); + std::vector input_embs_data; + for (size_t i = 0; i < input_embs.size(); ++i) { + input_embs_data.push_back(const_cast( + static_cast(input_embs[i].values))); + } + plugin = new plugin::EmbEltwiseLayernormPluginDynamic( + input_embs_data, + const_cast( + static_cast(bias_weight.get().values)), + const_cast( + static_cast(scale_weight.get().values)), + emb_sizes, + bias_size, + scale_size, + hidden, + eps, + with_fp16); layer = engine_->AddDynamicPlugin(input_ids.data(), input_num, plugin); auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput( diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index ce6644cad4200f10a3e432469fda6c964dd7a94f..1bd9cf8712d989be5acb90747c761891e9226ab4 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -27,6 +27,16 @@ class OpDesc; namespace paddle { namespace inference { namespace tensorrt { +namespace { +template +void tranpose_weight(const T* src, T* dst, int m, int n) { + for (int i = 0; i < m; i++) { + for (int j = 0; j < n; j++) { + dst[j * m + i] = src[i * n + j]; + } + } +} +} // namespace /* * FC converter convert a MUL op in Fluid to a FC layer in TRT. @@ -156,9 +166,7 @@ class FcOpConverter : public OpConverter { op_desc.HasAttr("activation_type") ? BOOST_GET_CONST(std::string, op_desc.GetAttr("activation_type")) : ""; - // This may trigger a GPU->CPU copy, because TRT's weight can only be - // assigned from CPU memory, which can't be avoided. 
- float* weight_data = nullptr; + bool enable_int8 = op_desc.HasAttr("enable_int8"); bool support_int8 = false; if (op_desc.HasAttr("support_int8")) { @@ -173,7 +181,6 @@ class FcOpConverter : public OpConverter { } engine_->SetTensorDynamicRange(X, in_scale); } - weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t); PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL, @@ -183,13 +190,6 @@ class FcOpConverter : public OpConverter { Y_t->dims().size())); // a matrix int m = Y_t->dims()[0]; int n = Y_t->dims()[1]; - auto tranpose_weight = [](const float* src, float* dst, int m, int n) { - for (int i = 0; i < m; i++) { - for (int j = 0; j < n; j++) { - dst[j * m + i] = src[i * n + j]; - } - } - }; auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output, @@ -283,11 +283,36 @@ class FcOpConverter : public OpConverter { transpose_y = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_Y")); } int weight_w, weight_h; + auto weight = engine_->GetTrtWeight(op_desc.Input(w_name).front(), *Y_t); + if (!transpose_y) { - std::vector weight_data_tmp; - weight_data_tmp.reserve(Y_t->numel()); - memcpy(weight_data_tmp.data(), weight_data, Y_t->numel() * sizeof(float)); - tranpose_weight(weight_data_tmp.data(), weight_data, m, n); + if (weight.get().type == nvinfer1::DataType::kFLOAT) { + std::vector weight_data_tmp; + weight_data_tmp.reserve(Y_t->numel()); + memcpy(weight_data_tmp.data(), + weight.get().values, + Y_t->numel() * sizeof(float)); + tranpose_weight( + weight_data_tmp.data(), + const_cast(static_cast(weight.get().values)), + m, + n); + } else if (weight.get().type == nvinfer1::DataType::kHALF) { + std::vector weight_data_tmp; + weight_data_tmp.reserve(Y_t->numel()); + memcpy(weight_data_tmp.data(), + weight.get().values, + Y_t->numel() * sizeof(float16)); + tranpose_weight(weight_data_tmp.data(), + const_cast( + static_cast(weight.get().values)), + m, + n); + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Paddle-TRT fc convert not supporte dtype, now only support fp32 " + "and fp16.")); + } weight_w = n; weight_h = m; } else { @@ -295,22 +320,14 @@ class FcOpConverter : public OpConverter { weight_h = n; } size_t n_output = weight_w; - TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, - static_cast(weight_data), - static_cast(Y_t->numel())}; weight.dims.assign({weight_w, weight_h}); - float* bias_data = nullptr; - int bias_num = 0; + TensorRTEngine::Weight bias{weight.get().type, nullptr, 0}; if (with_bias) { auto* b_v = scope.GetVar(op_desc.Input("Bias").front()); auto* b_t = b_v->GetMutable(); - bias_data = engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t); - bias_num = b_t->numel(); + bias = engine_->GetTrtWeight(op_desc.Input("Bias").front(), *b_t); } - TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, - static_cast(bias_data), - static_cast(bias_num)}; // Running the TRT Static Shape mode: x_num_col_dims-1 if (!engine_->with_dynamic_shape()) { diff --git a/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc index f5a2026ff6fdf7c872b008bba85a18195a7f81c2..1b45264475354020e6ea4dd733f74bc5da721d28 100644 --- a/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/group_norm_op.cc @@ -12,6 +12,7 @@ limitations under the License. 
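The FC converter above now transposes the weight in place for either fp32 or fp16 data through the file-local tranpose_weight template (the original spelling is kept). A standalone check of what that helper computes, assuming a 2x3 row-major input:

```cpp
#include <cassert>
#include <vector>

// Copied from fc_op.cc above: src is an m x n row-major matrix,
// dst receives its n x m transpose.
template <typename T>
void tranpose_weight(const T* src, T* dst, int m, int n) {
  for (int i = 0; i < m; i++) {
    for (int j = 0; j < n; j++) {
      dst[j * m + i] = src[i * n + j];
    }
  }
}

int main() {
  std::vector<float> src = {1, 2, 3,
                            4, 5, 6};  // m = 2 rows, n = 3 cols
  std::vector<float> dst(src.size());
  tranpose_weight(src.data(), dst.data(), /*m=*/2, /*n=*/3);
  assert((dst == std::vector<float>{1, 4, 2, 5, 3, 6}));  // now 3 x 2
  return 0;
}
```

Note that the converter only performs this transpose when transpose_Y is false; otherwise it simply swaps weight_w and weight_h.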
*/ #include #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/engine.h" namespace paddle { namespace framework { @@ -44,30 +45,20 @@ class GroupNormOpConverter : public OpConverter { std::string bias_name = op_desc.Input("Bias").front(); // get the presistable var's data - auto get_persistable_data = [&](const std::string& var_name, - framework::DDim* dims) -> float* { + auto GetWeight = [&](const std::string& var_name, + framework::DDim* dims) -> TensorRTEngine::Weight { auto* temp_var = scope.FindVar(var_name); auto* temp_tensor = temp_var->GetMutable(); (*dims) = temp_tensor->dims(); - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); - return temp_data; + auto weight = engine_->GetTrtWeight(var_name, *temp_tensor); + return weight; }; framework::DDim scale_dims; framework::DDim bias_dims; - float* scale_data = get_persistable_data(scale_name, &scale_dims); - float* bias_data = get_persistable_data(bias_name, &bias_dims); - - int64_t scale_numel = phi::product(scale_dims); - int64_t bias_numel = phi::product(bias_dims); - - TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT, - static_cast(scale_data), - static_cast(scale_numel)}; - TensorRTEngine::Weight bias_weights{nvinfer1::DataType::kFLOAT, - static_cast(bias_data), - static_cast(bias_numel)}; + auto scale_weights = GetWeight(scale_name, &scale_dims); + auto bias_weights = GetWeight(bias_name, &bias_dims); nvinfer1::Dims scale_nv_dims; nvinfer1::Dims bias_nv_dims; diff --git a/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc index a82101e29f5719e400f79c67901be635ee9d43f7..c899f4f6e777e7c536b511539bea3afb099887d8 100644 --- a/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc @@ -49,20 +49,10 @@ class LayerNormOpConverter : public OpConverter { auto* Bias_t = Bias_v->GetMutable(); auto* Scale_t = Scale_v->GetMutable(); - std::unique_ptr bias_tensor( - new framework::LoDTensor()); - std::unique_ptr scale_tensor( - new framework::LoDTensor()); - - bias_tensor->Resize(Bias_t->dims()); - scale_tensor->Resize(Scale_t->dims()); - - platform::CPUPlace cpu_place; - paddle::framework::TensorCopySync((*Bias_t), cpu_place, &(*bias_tensor)); - paddle::framework::TensorCopySync((*Scale_t), cpu_place, &(*scale_tensor)); - - auto* bias_data = bias_tensor->mutable_data(platform::CPUPlace()); - auto* scale_data = scale_tensor->mutable_data(platform::CPUPlace()); + auto bias_weight = + engine_->GetFp32TrtWeight(op_desc.Input("Bias").front(), *Bias_t); + auto scale_weight = + engine_->GetFp32TrtWeight(op_desc.Input("Scale").front(), *Scale_t); nvinfer1::ILayer* layernorm_layer = nullptr; if (engine_->with_dynamic_shape()) { @@ -73,14 +63,15 @@ class LayerNormOpConverter : public OpConverter { std::vector mean_shape{input_num}; std::vector variance_shape{input_num}; plugin::LayerNormPluginDynamic* plugin = - new plugin::LayerNormPluginDynamic(bias_data, - bias_tensor->numel(), - scale_data, - scale_tensor->numel(), - begin_norm_axis, - eps, - mean_shape, - variance_shape); + new plugin::LayerNormPluginDynamic( + static_cast(bias_weight.get().values), + bias_weight.get().count, + static_cast(scale_weight.get().values), + scale_weight.get().count, + begin_norm_axis, + eps, + mean_shape, + variance_shape); layernorm_layer = engine_->AddDynamicPlugin(&X, 1, plugin); } else { int input_num = 1; @@ -89,23 +80,20 @@ class 
LayerNormOpConverter : public OpConverter { } std::vector mean_shape{input_num}; std::vector variance_shape{input_num}; - plugin::LayerNormPlugin* plugin = - new plugin::LayerNormPlugin(bias_data, - bias_tensor->numel(), - scale_data, - scale_tensor->numel(), - begin_norm_axis, - eps, - mean_shape, - variance_shape); + plugin::LayerNormPlugin* plugin = new plugin::LayerNormPlugin( + static_cast(bias_weight.get().values), + bias_weight.get().count, + static_cast(scale_weight.get().values), + scale_weight.get().count, + begin_norm_axis, + eps, + mean_shape, + variance_shape); layernorm_layer = engine_->AddPlugin( &X, 1, reinterpret_cast(plugin)); } auto output_name = op_desc.Output("Y").front(); - engine_->SetWeights(op_desc.Input("Bias").front(), std::move(bias_tensor)); - engine_->SetWeights(op_desc.Input("Scale").front(), - std::move(scale_tensor)); RreplenishLayerAndOutput( layernorm_layer, "layer_norm", {output_name}, test_mode); } diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index d30dc5eb35b15c2c842e6ab54b007a6810ac0071..8bc44cc6ab9d2abec864c36c28245885a9462fc2 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -48,9 +48,11 @@ class MultiheadMatMulOpConverter : public OpConverter { in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine_->SetTensorDynamicRange(input, in_scale); } - weight_data = engine_->GetWeightCPUData(weight_name, weight_t); + weight_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(weight_name, *weight_t).get().values)); - float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t); + float* bias_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(bias_name, *bias_t).get().values)); std::vector weight_data_tmp; weight_data_tmp.reserve(weight_t->numel()); memcpy( diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 8bcc926b856e2d5ead66bd2d24603b5f379a6dfd..0eb2bc0875fdfb3f6950a7f91cba55508f14963d 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -343,6 +343,8 @@ class OpConverter { FluidDataType2TRT( var->Proto()->type().lod_tensor().tensor().data_type()), Vec2TRT_Dims(var_shape, input)); + VLOG(1) << "Set trt input [" << input << "] type is " + << var->Proto()->type().lod_tensor().tensor().data_type(); } } PADDLE_ENFORCE_EQ(all_dynamic_shape_set, @@ -561,33 +563,8 @@ class OpConverter { const std::string& name) { auto* var_v = scope.FindVar(name); auto* var_t = var_v->GetMutable(); - void* trt_ptr = nullptr; - size_t trt_num = static_cast(var_t->numel()); - nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT; - if (var_t->dtype() == phi::DataType::FLOAT32) { - float* data_ptr = engine_->GetWeightCPUData(name, var_t); - trt_ptr = static_cast(data_ptr); - } else if (var_t->dtype() == phi::DataType::INT32) { - int32_t* data_ptr = engine_->GetWeightCPUData(name, var_t); - trt_ptr = static_cast(data_ptr); - trt_dtype = nvinfer1::DataType::kINT32; - } else if (var_t->dtype() == phi::DataType::INT64) { - int64_t* data_ptr = engine_->GetWeightCPUData(name, var_t); - // We must create a new framework::Tensor() - std::unique_ptr new_var_t(new framework::Tensor()); - new_var_t->Resize({var_t->numel()}); - int32_t* new_data_ptr = - new_var_t->mutable_data(platform::CPUPlace()); - for 
(size_t i = 0; i < trt_num; i++) { - new_data_ptr[i] = data_ptr[i]; - } - engine_->SetWeights(name, std::move(new_var_t)); - trt_ptr = static_cast(new_data_ptr); - trt_dtype = nvinfer1::DataType::kINT32; - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Unsupported datatype in TensorRT")); - } + auto weight = engine_->GetTrtWeight(name, *var_t); + // Now we have create weights, then we need create a itensor auto var_dims = var_t->dims(); nvinfer1::Dims trt_in_shape; @@ -603,7 +580,6 @@ class OpConverter { trt_in_shape.d[i] = trt_in_shape.d[i + 1]; } } - TensorRTEngine::Weight weight{trt_dtype, trt_ptr, trt_num}; nvinfer1::ILayer* layer = TRT_ENGINE_ADD_LAYER(engine_, Constant, trt_in_shape, weight.get()); engine_->SetITensor(name, layer->getOutput(0)); diff --git a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc index 78dd812e035dbf6a9a26e884bcb412ca7f81005f..5bfa1170fa1091138f2cd6c4264879edf92740da 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc @@ -81,7 +81,8 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { auto* temp_tensor = temp_var->GetMutable(); (*dims) = temp_tensor->dims(); - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); + auto* temp_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(var_name, *temp_tensor).get().values)); return temp_data; }; diff --git a/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc b/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc index d09df4a4f281894345bb9e6e0ae621e8e2468801..7b89b62dc8b66c465d6d2481f30b89fceb70967b 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
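The plugin-based converters in this part of the patch (preln_emb_eltwise_layernorm, preln_residual_bias, and the ones that follow) still need raw float buffers, so they call GetFp32TrtWeight and read the values back out instead of the old typed GetWeightCPUData. A condensed sketch of that access pattern; this is a converter-side fragment (engine_, var_name and var_tensor come from the surrounding OpConverter), not a standalone program.

```cpp
// Inside an OpConverter, given a persistable tensor *var_tensor named var_name:
// GetFp32TrtWeight copies the weight to CPU and widens fp16/bf16 data to fp32,
// so the returned buffer can always be treated as float.
auto weight = engine_->GetFp32TrtWeight(var_name, *var_tensor);
const float* data = static_cast<const float*>(weight.get().values);
int64_t count = weight.get().count;  // number of float elements
// Plugins that take non-const pointers still need a const_cast, as above:
float* mutable_data = const_cast<float*>(data);
```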
*/ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.h" namespace paddle { @@ -43,7 +44,8 @@ class PrelnResidualBiasOpConverter : public OpConverter { auto* temp_var = scope.FindVar(var_name); auto* temp_tensor = temp_var->GetMutable(); (*dims) = temp_tensor->dims(); - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); + auto* temp_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(var_name, *temp_tensor).get().values)); return temp_data; }; framework::DDim bias_dims, scale_dims, ele_bias_dims; diff --git a/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc index 7824a9b23dc5ec88258ff0df3c478e3e638eab95..bc9b317920755a8c680b031863ecd01ce2c899c8 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc @@ -49,7 +49,8 @@ class PrelnSkipLayerNormOpConverter : public OpConverter { auto* temp_tensor = temp_var->GetMutable(); (*dims) = temp_tensor->dims(); - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); + auto* temp_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(var_name, *temp_tensor).get().values)); return temp_data; }; diff --git a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc index 3195833c0e5701169f0a25558fa193cdf1a7a9ba..38b01eff6fb1989a49c015d0a4b9eb372b0d175c 100644 --- a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc @@ -43,28 +43,21 @@ class PReluOpConverter : public OpConverter { auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]); auto* alpha_tensor = alpha_var->GetMutable(); + auto alpha_weight = + engine_->GetFp32TrtWeight(op_desc.Input("Alpha")[0], *alpha_tensor); + platform::CPUPlace cpu_place; - std::unique_ptr alpha_tensor_temp( - new framework::LoDTensor()); - alpha_tensor_temp->Resize(alpha_tensor->dims()); - paddle::framework::TensorCopySync( - *alpha_tensor, cpu_place, alpha_tensor_temp.get()); - float* alpha_data = alpha_tensor_temp->mutable_data(cpu_place); nvinfer1::ILayer* layer = nullptr; if (engine_->with_dynamic_shape()) { plugin::PReluPluginDynamic* plugin = new plugin::PReluPluginDynamic( - alpha_data, alpha_tensor_temp->numel(), mode, data_format); + static_cast(alpha_weight.get().values), + alpha_tensor->numel(), + mode, + data_format); layer = engine_->AddDynamicPlugin(&input, input_num, plugin); } else { #if IS_TRT_VERSION_GE(7000) - float* alpha_weight_data = - engine_->GetWeightCPUData(op_desc.Input("Alpha")[0], alpha_tensor); - TensorRTEngine::Weight alpha_weight{ - nvinfer1::DataType::kFLOAT, - static_cast(alpha_weight_data), - static_cast(alpha_tensor->numel())}; - nvinfer1::Dims dims; dims.nbDims = 0; // jump batch dim @@ -83,13 +76,13 @@ class PReluOpConverter : public OpConverter { engine_, ParametricReLU, *input, *alpha_layer_output); #else plugin::PReluPlugin* plugin = new plugin::PReluPlugin( - alpha_data, alpha_tensor_temp->numel(), mode, data_format); + static_cast(alpha_weight.get().values), + alpha_tensor->numel(), + mode, + data_format); layer = engine_->AddPlugin(&input, input_num, plugin); #endif } - // keep alpha tensor to avoid release it's memory - engine_->SetWeights(op_desc.Input("Alpha")[0], - std::move(alpha_tensor_temp)); auto output_name = 
op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, "prelu", {output_name}, test_mode); diff --git a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc index 9ed72610dc1794ce1439777c692f9d5c64b343ec..cf95a4d9b55e0e8cea99491fa28fa7839b00ec72 100644 --- a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/utils.h" +#include "paddle/fluid/inference/tensorrt/engine.h" #include "paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h" namespace paddle { @@ -34,22 +36,6 @@ class SkipLayerNormOpConverter : public OpConverter { inputs.push_back(input1); inputs.push_back(input2); - auto get_persistable_data = [&](const std::string& arg_name, - framework::DDim* dims) -> float* { - std::string var_name = op_desc.Input(arg_name).front(); - auto* temp_var = scope.FindVar(var_name); - auto* temp_tensor = temp_var->GetMutable(); - (*dims) = temp_tensor->dims(); - - auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor); - return temp_data; - }; - - framework::DDim bias_dims, scale_dims; - auto* bias = get_persistable_data("Bias", &bias_dims); - auto* scale = get_persistable_data("Scale", &scale_dims); - int bias_size = phi::product(bias_dims); - int scale_size = phi::product(scale_dims); bool enable_int8 = op_desc.HasAttr("enable_int8"); nvinfer1::ILayer* layer = nullptr; @@ -57,6 +43,18 @@ class SkipLayerNormOpConverter : public OpConverter { engine_->tensorrt_transformer_posid() != "" && engine_->tensorrt_transformer_maskid() != ""; if (flag_varseqlen) { + auto GetWeight = + [&](const std::string& arg_name) -> TensorRTEngine::Weight { + std::string var_name = op_desc.Input(arg_name).front(); + auto* temp_var = scope.FindVar(var_name); + auto* temp_tensor = temp_var->GetMutable(); + auto weight = engine_->GetTrtWeight(var_name, *temp_tensor); + return weight; + }; + + auto bias_weight = GetWeight("Bias").get(); + auto scale_weight = GetWeight("Scale").get(); + if (engine_->with_interleaved()) { VLOG(4) << "fused skip_layernorm op: use_varseqlen and with_interleaved"; @@ -72,11 +70,14 @@ class SkipLayerNormOpConverter : public OpConverter { platform::errors::InvalidArgument( "fail to get creator of CustomSkipLayerNormPluginDynamic")); const std::vector fields{ - {"beta", bias, nvinfer1::PluginFieldType::kFLOAT32, bias_size}, + {"beta", + bias_weight.values, + GetPluginFieldType(bias_weight.type), + static_cast(bias_weight.count)}, { "gamma", - scale, - nvinfer1::PluginFieldType::kFLOAT32, - scale_size }}; + scale_weight.values, + GetPluginFieldType(scale_weight.type), + static_cast(scale_weight.count) }}; nvinfer1::PluginFieldCollection* pluginPtr = static_cast( malloc(sizeof(*pluginPtr) + @@ -119,8 +120,14 @@ class SkipLayerNormOpConverter : public OpConverter { const std::vector fields{ {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1}, {"ld", &ld, nvinfer1::PluginFieldType::kINT32, 1}, - {"beta", bias, nvinfer1::PluginFieldType::kFLOAT32, bias_size}, - {"gamma", scale, nvinfer1::PluginFieldType::kFLOAT32, scale_size}, + {"beta", + bias_weight.values, + GetPluginFieldType(bias_weight.type), + static_cast(bias_weight.count)}, + {"gamma", + scale_weight.values, + GetPluginFieldType(scale_weight.type), 
+ static_cast(scale_weight.count)}, }; nvinfer1::PluginFieldCollection* pluginPtr = static_cast( @@ -143,12 +150,29 @@ class SkipLayerNormOpConverter : public OpConverter { layer = plugin_layer; } } else { + auto GetFp32Weight = + [&](const std::string& arg_name) -> TensorRTEngine::Weight { + std::string var_name = op_desc.Input(arg_name).front(); + auto* temp_var = scope.FindVar(var_name); + auto* temp_tensor = temp_var->GetMutable(); + auto weight = engine_->GetFp32TrtWeight(var_name, *temp_tensor); + return weight; + }; + + auto bias_weight = GetFp32Weight("Bias").get(); + auto scale_weight = GetFp32Weight("Scale").get(); + float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon")); bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); plugin::SkipLayerNormPluginDynamic* plugin = new plugin::SkipLayerNormPluginDynamic( - bias, scale, bias_size, scale_size, eps, with_fp16); + static_cast(bias_weight.values), + static_cast(scale_weight.values), + bias_weight.count, + scale_weight.count, + eps, + with_fp16); layer = engine_->AddDynamicPlugin(inputs.data(), 2, plugin); } diff --git a/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc b/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc index 6974e5a77006e28184f4929323eefc78128db639..33801e969172a28f8861d66b115a4d43ee6cdb44 100644 --- a/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc @@ -154,7 +154,10 @@ class SparseFcOpConverter : public OpConverter { } engine_->SetTensorDynamicRange(X, in_scale); } - weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t); + weight_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(op_desc.Input(w_name).front(), *Y_t) + .get() + .values)); PADDLE_ENFORCE_EQ( Y_t->dims().size(), @@ -321,7 +324,10 @@ class SparseFcOpConverter : public OpConverter { if (with_bias) { auto* b_v = scope.GetVar(op_desc.Input("Bias").front()); auto* b_t = b_v->GetMutable(); - bias_data = engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t); + bias_data = weight_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(op_desc.Input("Bias").front(), *b_t) + .get() + .values)); bias_num = b_t->numel(); } // Running the TRT Static Shape mode: x_num_col_dims-1 diff --git a/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc index 7f54f97d34933c47f5d34d6e7fb1a45a970fb611..4a8d15ef0dbace05f59eb89dfaf79392e99a5b33 100644 --- a/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc @@ -64,9 +64,11 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")); engine_->SetTensorDynamicRange(input, in_scale); } - weight_data = engine_->GetWeightCPUData(weight_name, weight_t); + weight_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(weight_name, *weight_t).get().values)); - float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t); + float* bias_data = const_cast(static_cast( + engine_->GetFp32TrtWeight(bias_name, *bias_t).get().values)); std::vector weight_data_tmp; weight_data_tmp.reserve(weight_t->numel()); memcpy( diff --git a/paddle/fluid/inference/tensorrt/convert/utils.h b/paddle/fluid/inference/tensorrt/convert/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..1415e67fbeccdfd085489b8f5a7aa0710d1bd40d 
--- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/utils.h @@ -0,0 +1,45 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/inference/tensorrt/engine.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +inline nvinfer1::PluginFieldType GetPluginFieldType(nvinfer1::DataType type) { + switch (type) { +#if IS_TRT_VERSION_GE(7000) + case nvinfer1::DataType::kBOOL: + return nvinfer1::PluginFieldType::kCHAR; +#endif + case nvinfer1::DataType::kFLOAT: + return nvinfer1::PluginFieldType::kFLOAT32; + case nvinfer1::DataType::kHALF: + return nvinfer1::PluginFieldType::kFLOAT16; + case nvinfer1::DataType::kINT32: + return nvinfer1::PluginFieldType::kINT32; + case nvinfer1::DataType::kINT8: + return nvinfer1::PluginFieldType::kINT8; + default: + return nvinfer1::PluginFieldType::kUNKNOWN; + } +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 9fe8f67e6a6573a31883b7816d4beec79b621057..a4d373e83b35523305177adb29f1cda04b3fd0b9 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -19,15 +19,46 @@ limitations under the License. */ #include +#include "NvInferRuntimeCommon.h" #include "cuda_runtime_api.h" // NOLINT #include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/common/data_type.h" namespace paddle { namespace inference { namespace tensorrt { +void TensorRTEngine::Weight::SetDataType(phi::DataType type) { + nvinfer1::DataType nv_type; + switch (type) { + case phi::DataType::FLOAT32: + nv_type = nvinfer1::DataType::kFLOAT; + break; + case phi::DataType::FLOAT16: + nv_type = nvinfer1::DataType::kHALF; + break; + case phi::DataType::INT32: + nv_type = nvinfer1::DataType::kINT32; + break; + case phi::DataType::INT8: + nv_type = nvinfer1::DataType::kINT8; + break; +#if IS_TRT_VERSION_GE(7000) + case phi::DataType::BOOL: + nv_type = nvinfer1::DataType::kBOOL; + break; +#endif + default: + paddle::platform::errors::InvalidArgument( + "Paddle-TRT loads weighths failed, found not supported data type %s.", + type); + break; + } + w_.type = nv_type; +} + int TensorRTEngine::runtime_batch_ = 1; void TensorRTEngine::InitNetwork() { @@ -197,6 +228,18 @@ void TensorRTEngine::FreezeNetwork() { } } + // If model is mixed precision, then we should cast all float output to + // float32 precision. Otherwise, we can not confirm the output precision of + // the trt engine. 
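The new GetPluginFieldType helper in utils.h is what lets the varseqlen converters above hand a weight to a TensorRT plugin in its native precision instead of forcing an fp32 copy. A small sketch of the intended use; the MakeWeightField wrapper is illustrative and not part of the patch.

```cpp
#include <NvInfer.h>

#include "paddle/fluid/inference/tensorrt/convert/utils.h"

namespace trt = paddle::inference::tensorrt;

// Build a PluginField that carries a weight in whatever precision the engine
// returned it in (fp32, fp16, int32, ...), mirroring the skip_layernorm
// converter's {"beta", values, GetPluginFieldType(type), count} entries.
nvinfer1::PluginField MakeWeightField(const char* name,
                                      const nvinfer1::Weights& w) {
  return nvinfer1::PluginField{name,
                               w.values,
                               trt::GetPluginFieldType(w.type),
                               static_cast<int32_t>(w.count)};
}
```

With this, a converter that already holds `auto bias_weight = GetWeight("Bias").get();` can emit `MakeWeightField("beta", bias_weight)` regardless of the weight's dtype.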
+ if (model_precision_ != phi::DataType::FLOAT32) { + for (int i = 0; i < network()->getNbOutputs(); ++i) { + network()->getOutput(i)->setAllowedFormats( + static_cast( + 1 << static_cast(nvinfer1::TensorFormat::kLINEAR))); + network()->getOutput(i)->setType(nvinfer1::DataType::kFLOAT); + } + } + if (use_dla_) { if (!enable_int8 && !enable_fp16) { LOG(WARNING) << "TensorRT DLA must be used with int8 or fp16, but you " @@ -399,26 +442,126 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) { runtime_batch_ = batch_size; } -template -T *TensorRTEngine::GetWeightCPUData(const std::string &name, - framework::Tensor *weight_tensor) { - std::unique_ptr cpu_weight_tensor(new framework::Tensor()); +TensorRTEngine::Weight TensorRTEngine::GetFp32TrtWeight( + const std::string &name, const framework::Tensor &weight_tensor) { + static int name_suffix_counter = 0; + std::string name_suffix = std::to_string(name_suffix_counter); + std::string splitter = "__"; + std::string name_with_suffix = name + splitter + name_suffix; platform::CPUPlace cpu_place; - cpu_weight_tensor->Resize(weight_tensor->dims()); - paddle::framework::TensorCopySync( - *weight_tensor, cpu_place, cpu_weight_tensor.get()); - T *weight_data = cpu_weight_tensor->mutable_data(cpu_place); - SetWeights(name, std::move(cpu_weight_tensor)); - return weight_data; + PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix), + 0, + platform::errors::AlreadyExists( + "The weight named %s is set into the weight map " + "twice in TRT OP converter.", + name_with_suffix)); + weight_map[name_with_suffix].reset(new framework::Tensor()); + weight_map[name_with_suffix]->Resize(weight_tensor.dims()); + + TensorRTEngine::Weight weight; + weight.SetCount(weight_tensor.numel()); + weight.SetDataType(nvinfer1::DataType::kFLOAT); + // weight_tensor.dims().; + + // if trt not support dtype, we need to cast to fp32. 
+ if (weight_tensor.dtype() == phi::DataType::BFLOAT16) { + framework::Tensor bf16_tensor; + bf16_tensor.clear(); + paddle::framework::TensorCopySync( + weight_tensor, platform::CPUPlace(), &bf16_tensor); + weight_map[name_with_suffix]->set_type( + paddle::experimental::DataType::FLOAT32); + weight_map[name_with_suffix]->Resize(weight_tensor.dims()); + auto *fp32_data = + weight_map[name_with_suffix]->mutable_data(platform::CPUPlace()); + auto *bf16_data = bf16_tensor.mutable_data(platform::CPUPlace()); + for (int i = 0; i < weight_tensor.numel(); i++) { + fp32_data[i] = static_cast(bf16_data[i]); + } + } else if (weight_tensor.dtype() == phi::DataType::FLOAT16) { + framework::Tensor fp16_tensor; + fp16_tensor.clear(); + paddle::framework::TensorCopySync( + weight_tensor, platform::CPUPlace(), &fp16_tensor); + weight_map[name_with_suffix]->set_type( + paddle::experimental::DataType::FLOAT32); + weight_map[name_with_suffix]->Resize(weight_tensor.dims()); + auto *fp32_data = + weight_map[name_with_suffix]->mutable_data(platform::CPUPlace()); + auto *fp16_data = fp16_tensor.mutable_data(platform::CPUPlace()); + for (int i = 0; i < weight_tensor.numel(); i++) { + fp32_data[i] = static_cast(fp16_data[i]); + } + } else { + paddle::framework::TensorCopySync( + weight_tensor, cpu_place, weight_map[name_with_suffix].get()); + } + weight.SetValues(weight_map[name_with_suffix]->data()); + name_suffix_counter += 1; + return weight; } -template float *TensorRTEngine::GetWeightCPUData( - const std::string &name, framework::Tensor *weight_tensor); -template int32_t *TensorRTEngine::GetWeightCPUData( - const std::string &name, framework::Tensor *weight_tensor); +TensorRTEngine::Weight TensorRTEngine::GetTrtWeight( + const std::string &name, const framework::Tensor &weight_tensor) { + static int name_suffix_counter = 0; + std::string name_suffix = std::to_string(name_suffix_counter); + std::string splitter = "__"; + std::string name_with_suffix = name + splitter + name_suffix; + platform::CPUPlace cpu_place; + PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix), + 0, + platform::errors::AlreadyExists( + "The weight named %s is set into the weight map " + "twice in TRT OP converter.", + name_with_suffix)); + + weight_map[name_with_suffix].reset(new framework::Tensor()); + weight_map[name_with_suffix]->Resize(weight_tensor.dims()); + + TensorRTEngine::Weight weight; + weight.SetCount(weight_tensor.numel()); + + // if trt not support dtype, we need to cast to fp32. 
+  if (weight_tensor.dtype() == phi::DataType::BFLOAT16) {
+    framework::Tensor bf16_tensor;
+    bf16_tensor.clear();
+    paddle::framework::TensorCopySync(
+        weight_tensor, platform::CPUPlace(), &bf16_tensor);
+    weight_map[name_with_suffix]->set_type(
+        paddle::experimental::DataType::FLOAT32);
+    auto *fp32_data =
+        weight_map[name_with_suffix]->mutable_data<float>(platform::CPUPlace());
+    auto *bf16_data = bf16_tensor.mutable_data<bfloat16>(platform::CPUPlace());
+    for (int i = 0; i < weight_tensor.numel(); i++) {
+      fp32_data[i] = static_cast<float>(bf16_data[i]);
+    }
+    weight.SetDataType(phi::DataType::FLOAT32);
+    weight.SetValues(fp32_data);
+  } else if (weight_tensor.dtype() == phi::DataType::INT64) {
+    framework::Tensor int64_tensor;
+    int64_tensor.clear();
+    paddle::framework::TensorCopySync(
+        weight_tensor, platform::CPUPlace(), &int64_tensor);
+    weight_map[name_with_suffix]->set_type(
+        paddle::experimental::DataType::INT32);
+    auto *int32_data =
+        weight_map[name_with_suffix]->mutable_data<int32_t>(platform::CPUPlace());
+    auto *int64_data = int64_tensor.mutable_data<int64_t>(platform::CPUPlace());
+    // Narrowing copy: TRT has no int64 weights, so values are stored as int32.
+    for (int i = 0; i < weight_tensor.numel(); i++) {
+      int32_data[i] = int64_data[i];
+    }
+    weight.SetDataType(phi::DataType::INT32);
+    weight.SetValues(int32_data);
+  } else {
+    paddle::framework::TensorCopySync(
+        weight_tensor, cpu_place, weight_map[name_with_suffix].get());
+    weight.SetDataType(weight_tensor.dtype());
+    weight.SetValues(weight_map[name_with_suffix]->data());
+  }
 
-template int64_t *TensorRTEngine::GetWeightCPUData<int64_t>(
-    const std::string &name, framework::Tensor *weight_tensor);
+  name_suffix_counter += 1;
+  return weight;
+}
 
 int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; }
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index 5c2bb6e0ca07f412deea65d83fd43d0dcc341d80..73506eb8f6244d8d8a9107f3e0e8efa657982ab3 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -25,6 +25,8 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 
+#include "NvInferRuntimeCommon.h"
+#include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/inference/api/paddle_analysis_config.h"
@@ -34,6 +36,7 @@ limitations under the License.
*/ #include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h" #include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/common/data_type.h" #include "paddle/utils/any.h" namespace paddle { @@ -187,6 +190,14 @@ class TensorRTEngine { } const nvinfer1::Weights& get() { return w_; } + void SetDataType(nvinfer1::DataType type) { w_.type = type; } + + void SetDataType(phi::DataType type); + + void SetValues(const void* values) { w_.values = values; } + + void SetCount(int64_t num) { w_.count = num; } + std::vector dims; private: @@ -203,6 +214,7 @@ class TensorRTEngine { const ShapeMapType max_input_shape = {}, const ShapeMapType optim_input_shape = {}, bool disable_trt_plugin_fp16 = false, + phi::DataType model_precision = phi::DataType::FLOAT32, nvinfer1::ILogger& logger = NaiveLogger::Global()) : max_batch_(max_batch), max_workspace_(max_workspace), @@ -213,6 +225,7 @@ class TensorRTEngine { max_input_shape_(max_input_shape), optim_input_shape_(optim_input_shape), disable_trt_plugin_fp16_(disable_trt_plugin_fp16), + model_precision_(model_precision), logger_(logger) { if (min_input_shape_.size() != 0 && max_input_shape_.size() != 0 && optim_input_shape_.size() != 0) { @@ -407,6 +420,14 @@ class TensorRTEngine { quant_dynamic_range_[tensor] = range; } + // Get fp32 trt weight. If src weight is not fp32, we will cast. + Weight GetFp32TrtWeight(const std::string& name, + const framework::Tensor& weight_tensor); + + // if the src weight type is fp16, then return fp16 trt weight, etc. + Weight GetTrtWeight(const std::string& name, + const framework::Tensor& weight_tensor); + float GetTensorDynamicRange(nvinfer1::ITensor* tensor) { return quant_dynamic_range_[tensor]; } @@ -415,10 +436,6 @@ class TensorRTEngine { return quant_dynamic_range_.count(tensor); } - template - T* GetWeightCPUData(const std::string& name, - framework::Tensor* weight_tensor); - // A pointer to CPU memory is needed of the TRT weight. // Before TRT runs, fluid loads weight into GPU storage. // so we need to copy the weights from GPU to CPU in our op converter. @@ -669,6 +686,7 @@ class TensorRTEngine { ShapeMapType max_input_shape_; ShapeMapType optim_input_shape_; bool disable_trt_plugin_fp16_{false}; + phi::DataType model_precision_{phi::DataType::FLOAT32}; bool use_varseqlen_{false}; bool use_dla_{false}; int dla_core_{0}; @@ -756,6 +774,7 @@ class TRTEngineManager { const std::map> max_input_shape = {}, const std::map> optim_input_shape = {}, bool disable_trt_plugin_fp16 = false, + phi::DataType model_precision = phi::DataType::FLOAT32, nvinfer1::ILogger& logger = NaiveLogger::Global()) { auto* p = new TensorRTEngine(max_batch, max_workspace, @@ -766,6 +785,7 @@ class TRTEngineManager { max_input_shape, optim_input_shape, disable_trt_plugin_fp16, + model_precision, logger); engines_[name].reset(p); return p; diff --git a/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc index 8f20ffb5e6b8cbf4c95b226ca06bf97866cdbd5e..eae1e2baf9ad1694f0b291cb61df45dd36fb9573 100644 --- a/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc +++ b/paddle/fluid/inference/tensorrt/test_dynamic_engine.cc @@ -18,6 +18,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/phi/common/data_type.h" #if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000) #include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h" #endif @@ -66,6 +67,7 @@ class TensorRTDynamicEngineTest : public ::testing::Test { max_input_shape, optim_input_shape, false, + phi::DataType::FLOAT32, NaiveLogger::Global()); engine_->InitNetwork(); } diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index b0ac285b5d38d63a1205e589b6593298777b30a3..1cd2683796acd2781a714418d63001846cf18715 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -14,7 +14,12 @@ #pragma once +#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/place.h" #ifdef PADDLE_WITH_CUDA #include @@ -192,6 +197,7 @@ class TensorRTEngineOp : public framework::OperatorBase { std::map> min_input_shape_{}; std::map> max_input_shape_{}; std::map> opt_input_shape_{}; + phi::DataType model_precision_{phi::DataType::FLOAT32}; public: TensorRTEngineOp(const std::string &type, @@ -217,6 +223,7 @@ class TensorRTEngineOp : public framework::OperatorBase { if (use_static_engine_) { model_opt_cache_dir_ = Attr("model_opt_cache_dir"); } + model_precision_ = static_cast(Attr("model_precision")); if (HasAttr("dynamic_shape_names") && HasAttr("min_input_shape") && HasAttr("max_input_shape") && HasAttr("opt_input_shape")) { @@ -555,6 +562,7 @@ class TensorRTEngineOp : public framework::OperatorBase { #endif } runtime_batch = t_shape[0]; + VLOG(1) << "trt input [" << x << "] dtype is " << t.dtype(); auto type = framework::TransToProtoVarType(t.dtype()); if (type == framework::proto::VarType::FP32) { buffers[bind_index] = static_cast(t.data()); @@ -619,6 +627,8 @@ class TensorRTEngineOp : public framework::OperatorBase { num_bindings)); auto trt_type = engine->engine()->getBindingDataType(bind_index); // get adr and set type + VLOG(1) << "trt output [" << y << "] dtype is " + << TRT2FluidDataType(trt_type); buffers[bind_index] = static_cast( fluid_t->mutable_data(dev_place, TRT2FluidDataType(trt_type))); output_index += 1; diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc index cbe14195d41060c03641cc5133fa1038faabfbdd..8e2b162babce9c2e3de9167b923ab226623579cd 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
*/
 #include "paddle/fluid/inference/analysis/helper.h"
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
 #include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
+#include "paddle/phi/common/data_type.h"
 
 USE_NO_KERNEL_OP(tensorrt_engine);
 namespace paddle {
@@ -132,6 +133,8 @@ void DynamicShapeTest(bool allow_build_at_runtime) {
   engine_op_desc.SetAttr("min_input_shape", std::vector<int>{1, 4, 1, 1});
   engine_op_desc.SetAttr("max_input_shape", std::vector<int>{2, 4, 1, 1});
   engine_op_desc.SetAttr("opt_input_shape", std::vector<int>{2, 4, 1, 1});
+  engine_op_desc.SetAttr("model_precision",
+                         static_cast<int>(phi::DataType::FLOAT32));
   LOG(INFO) << "create engine op";
   auto engine_op = framework::OpRegistry::CreateOp(engine_op_desc);
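
Reviewer addendum (not part of the patch): a minimal sketch of how an op converter might consume the new weight helpers in place of the removed GetWeightCPUData<T>(). The converter class, the op input slot "W", and the body are hypothetical; only GetTrtWeight/GetFp32TrtWeight, Weight::get(), and the existing OpConverter plumbing (engine_, operator()) are taken from this change and the current converter API.

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// Hypothetical converter used only to illustrate the new weight helpers.
class MyFcOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope,
                  bool test_mode) override {
    framework::OpDesc op_desc(op, nullptr);
    auto weight_name = op_desc.Input("W").front();
    auto* weight_tensor =
        scope.FindVar(weight_name)->GetMutable<framework::LoDTensor>();

    // Native dtype when TRT supports it; bf16 weights come back widened to
    // fp32 and int64 weights narrowed to int32.
    TensorRTEngine::Weight w =
        engine_->GetTrtWeight(weight_name, *weight_tensor);

    // Converters that pre-process weight values on the host can instead ask
    // for a guaranteed fp32 copy.
    TensorRTEngine::Weight w_fp32 =
        engine_->GetFp32TrtWeight(weight_name, *weight_tensor);

    VLOG(3) << "trt weight " << weight_name << " count: " << w.get().count
            << ", fp32 copy count: " << w_fp32.get().count;
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

Both helpers cache the converted buffer in the engine's weight_map, so the returned nvinfer1::Weights stays valid for the whole engine build without the converter having to manage the lifetime itself.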