Unverified Commit 339c34e6 authored by wenbin, committed by GitHub

dynamic shape clone (#38520)

* dynamic shape clone supported
Parent ebc72ac2
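
For context, below is a minimal usage sketch of what this change enables, mirroring the TestDynamicClone test added in this commit. The header name and model paths are placeholders and are not part of this commit: a predictor configured with TensorRT dynamic shape can now be cloned, and Clone() calls NaiveExecutor::ResetTrtOps so the clone's TRT engines are rebuilt with an additional optimization profile.

#include <map>
#include <string>
#include <vector>
#include "paddle_inference_api.h"  // header name/path is an assumption

int main() {
  paddle::AnalysisConfig config;
  config.EnableUseGpu(100, 0);
  // Placeholder model files; any TRT-convertible model works.
  config.SetModel("/path/to/model", "/path/to/params");
  config.SwitchUseFeedFetchOps(false);
  config.EnableTensorRtEngine(1 << 30, 1, 1,
                              paddle::AnalysisConfig::Precision::kFloat32,
                              false, false);
  // Dynamic-shape ranges for an input named "image" (example values).
  std::map<std::string, std::vector<int>> min_shape{{"image", {1, 1, 3, 3}}};
  std::map<std::string, std::vector<int>> max_shape{{"image", {1, 1, 10, 10}}};
  std::map<std::string, std::vector<int>> opt_shape{{"image", {1, 1, 3, 3}}};
  config.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);

  auto predictor = paddle::CreatePaddlePredictor(config);
  // New in this commit: cloning a dynamic-shape TRT predictor. The clone gets
  // its own optimization profile, so both predictors can run independently.
  auto predictor2 = predictor->Clone();

  std::vector<float> input(1 * 1 * 3 * 3, 0.f);
  auto input_names = predictor2->GetInputNames();
  auto input_t = predictor2->GetInputTensor(input_names[0]);
  input_t->Reshape({1, 1, 3, 3});
  input_t->copy_from_cpu(input.data());
  predictor2->ZeroCopyRun();
  return 0;
}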
...@@ -275,7 +275,11 @@ cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor)
cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog)
cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor)
if (TENSORRT_FOUND)
cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry denormal device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper tensorrt_engine_op)
else()
cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry denormal device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
endif(TENSORRT_FOUND)
cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector op_registry while_op_helper recurrent_op_helper conditional_block_op_helper)
if(WITH_DISTRIBUTE)
......
...@@ -20,6 +20,9 @@
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#if PADDLE_WITH_TENSORRT
#include "paddle/fluid/operators/tensorrt/tensorrt_engine_op.h"
#endif
namespace paddle {
namespace framework {
...@@ -132,5 +135,38 @@ NaiveExecutor::~NaiveExecutor() {
#endif
}
void NaiveExecutor::ResetTrtOps(int num) {
#if PADDLE_WITH_TENSORRT
for (auto &op : ops_) {
if (op->Type() == "tensorrt_engine") {
operators::TensorRTEngineOp *trtop =
dynamic_cast<operators::TensorRTEngineOp *>(op.get());
if (!trtop) return;
std::string engine_key = trtop->Attr<std::string>("engine_key");
int engine_predictor_id = trtop->Attr<int>("predictor_id");
std::string engine_name =
engine_key + std::to_string(engine_predictor_id);
operators::TensorRTEngine *trt_engine =
paddle::inference::Singleton<
inference::tensorrt::TRTEngineManager>::Global()
.Get(engine_name);
if (trt_engine->with_dynamic_shape()) {
LOG(INFO) << "rebuild trt engine, this may cost a lot of time!";
trt_engine->ResetContext();
trt_engine->ClearTensorMap();
trt_engine->SetProfileNum(num);
auto *anc = scope_->parent();
while (anc && anc->parent()) {
anc = anc->parent();
}
if (anc == nullptr) {
anc = scope_;
}
trtop->PrepareTRTEngine(*anc, trt_engine);
}
}
}
#endif
}
} // namespace framework
} // namespace paddle
...@@ -63,6 +63,8 @@ class NaiveExecutor {
void CleanFeedFetchOps();
void ResetTrtOps(int num);
protected:
void CreateOps(const ProgramDesc& desc, int block_id,
bool with_feed_fetch_ops);
......
...@@ -56,8 +56,17 @@ void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
// Because there exists the case that new parameter variables are not added to
// the program in the analysis pass.
bool reserve_cpu_weights = false;
if (argument->tensorrt_allow_build_at_runtime_valid() &&
argument->tensorrt_allow_build_at_runtime()) {
bool with_dynamic_shape = false;
if (argument->Has("max_input_shape") && argument->Has("min_input_shape") &&
argument->Has("optim_input_shape")) {
with_dynamic_shape = (argument->max_input_shape().size() > 0 &&
argument->min_input_shape().size() > 0 &&
argument->optim_input_shape().size() > 0);
}
with_dynamic_shape =
with_dynamic_shape || (argument->Has("tensorrt_tuned_dynamic_shape") &&
argument->tensorrt_tuned_dynamic_shape());
if (with_dynamic_shape) {
reserve_cpu_weights = true;
}
for (auto &var_name : all_vars) {
......
...@@ -1344,6 +1344,7 @@ std::unique_ptr<PaddlePredictor> AnalysisPredictor::Clone() {
std::lock_guard<std::mutex> lk(clone_mutex_);
auto *x = new AnalysisPredictor(config_);
x->Init(scope_, inference_program_);
x->executor_->ResetTrtOps(++x->clone_num_);
return std::unique_ptr<PaddlePredictor>(x);
}
......
...@@ -435,6 +435,7 @@ class AnalysisPredictor : public PaddlePredictor {
bool status_is_cloned_{false};
std::map<std::string, std::vector<std::vector<int32_t>>> shape_info_;
int clone_num_{1};
};
} // namespace paddle
...@@ -42,7 +42,10 @@ void TensorRTEngine::InitNetwork() {
}
infer_builder_config_.reset(infer_builder_->createBuilderConfig());
// optim_profile_ = infer_builder_->createOptimizationProfile();
optim_profiles_.resize(max_profile_num_);
for (int i = 0; i < max_profile_num_; i++)
optim_profiles_[i] = infer_builder_->createOptimizationProfile();
}
void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
...@@ -199,35 +202,38 @@ void TensorRTEngine::FreezeNetwork() {
if (with_dynamic_shape_) {
#if IS_TRT_VERSION_GE(6000)
LOG(INFO) << "Run Paddle-TRT Dynamic Shape mode.";
for (int i = 0; i < max_profile_num_; i++) {
for (auto &input : min_input_shape_) {
#if IS_TRT_VERSION_LT(7000)
// trt6 will check all_of input > 0
if (!(std::all_of(input.second.begin(), input.second.end(),
[](int x) { return x > 0; }) &&
std::all_of(max_input_shape_[input.first].begin(),
max_input_shape_[input.first].end(),
[](int x) { return x > 0; }) &&
std::all_of(optim_input_shape_[input.first].begin(),
optim_input_shape_[input.first].end(),
[](int x) { return x > 0; }))) {
continue;
}
#endif
VLOG(4) << "TRT dynamic_shape set " << input.first
<< " min: " << Vec2Str(input.second)
<< ", max: " << Vec2Str(max_input_shape_[input.first])
<< ", opt: " << Vec2Str(optim_input_shape_[input.first]);
optim_profile_->setDimensions(
input.first.c_str(), nvinfer1::OptProfileSelector::kMIN,
Vec2TRT_Dims(input.second, input.first, true));
optim_profile_->setDimensions(
input.first.c_str(), nvinfer1::OptProfileSelector::kMAX,
Vec2TRT_Dims(max_input_shape_[input.first], input.first, true));
optim_profile_->setDimensions(
input.first.c_str(), nvinfer1::OptProfileSelector::kOPT,
Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true));
infer_builder_config_->addOptimizationProfile(optim_profile_);
optim_profiles_[i]->setDimensions(
input.first.c_str(), nvinfer1::OptProfileSelector::kMIN,
Vec2TRT_Dims(input.second, input.first, true));
optim_profiles_[i]->setDimensions(
input.first.c_str(), nvinfer1::OptProfileSelector::kMAX,
Vec2TRT_Dims(max_input_shape_[input.first], input.first, true));
optim_profiles_[i]->setDimensions(
input.first.c_str(), nvinfer1::OptProfileSelector::kOPT,
Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true));
}
infer_builder_config_->addOptimizationProfile(optim_profiles_[i]);
}
if (WithFp16() && disable_trt_plugin_fp16()) {
LOG(INFO) << "NOTE: In order to achieve higher accuracy, you have "
"disabled the fp16 mode of TRT Plugin,\n"
...@@ -237,7 +243,6 @@ void TensorRTEngine::FreezeNetwork() {
}
#endif
}
#if IS_TRT_VERSION_GE(8200)
infer_builder_config_->setProfilingVerbosity(
nvinfer1::ProfilingVerbosity::kDETAILED);
...@@ -260,6 +265,13 @@ void TensorRTEngine::FreezeNetwork() {
"Build TensorRT cuda engine failed! Please recheck "
"you configurations related to paddle-TensorRT."));
binding_num_ = infer_engine_->getNbBindings();
// reset status for dynamic shape clone
if (max_profile_num_ > 1) {
infer_context_.clear();
cur_profile_num_ = 0;
}
GetEngineInfo();
}
......
...@@ -253,10 +253,38 @@ class TensorRTEngine {
infer_engine_,
platform::errors::InvalidArgument(
"You should build engine first and then set the context."));
// We may see trt warning: Profile 0 has been chosen by another
// IExecutionContext...
// It's ok. We will set it later.
infer_context_[tid].reset(infer_engine_->createExecutionContext());
if (with_dynamic_shape_) {
// need new profile if it's not the first
if (cur_profile_num_ > 0) {
infer_context_[tid]->setOptimizationProfile(cur_profile_num_);
}
profile_index_[tid] = cur_profile_num_;
++cur_profile_num_;
}
}
return infer_context_[tid].get();
}
int GetProfileIndex() {
if (max_profile_num_ > 1) {
std::unique_lock<std::mutex> lock(mutex_);
const std::thread::id tid = std::this_thread::get_id();
return profile_index_[tid];
} else {
return 0;
}
}
int GetBindingsOffset() {
return (binding_num_ / max_profile_num_) * GetProfileIndex();
}
int GetNbBindings() { return binding_num_; }
void ResetContext() {
std::unique_lock<std::mutex> lock(mutex_);
const std::thread::id tid = std::this_thread::get_id();
...@@ -322,6 +350,7 @@ class TensorRTEngine {
"generating serialization file and doing inference are "
"consistent."));
binding_num_ = infer_engine_->getNbBindings();
GetEngineInfo();
}
...@@ -540,6 +569,7 @@ class TensorRTEngine {
}
}
void SetProfileNum(int num) { max_profile_num_ = num; }
void GetEngineInfo() {
#if IS_TRT_VERSION_GE(8200)
std::unique_ptr<nvinfer1::IEngineInspector> infer_inspector(
...@@ -571,6 +601,9 @@ class TensorRTEngine {
int batch_size_{-1};
int device_id_;
int max_profile_num_{1};
int cur_profile_num_{0};
std::unordered_map<std::thread::id, int> profile_index_;
ShapeMapType min_input_shape_;
ShapeMapType max_input_shape_;
ShapeMapType optim_input_shape_;
...@@ -614,8 +647,9 @@ class TensorRTEngine {
// For dynamic shape
bool with_dynamic_shape_{false};
#if IS_TRT_VERSION_GE(6000)
int binding_num_;
infer_ptr<nvinfer1::IBuilderConfig> infer_builder_config_;
nvinfer1::IOptimizationProfile* optim_profile_;
std::vector<nvinfer1::IOptimizationProfile*> optim_profiles_;
std::vector<std::unique_ptr<plugin::DynamicPluginTensorRT>> owned_pluginv2_;
#endif
std::mutex mutex_;
......
...@@ -207,6 +207,87 @@ void TestTunedDynamic() {
check_func(test_predictor.get());
}
void TestDynamicClone(bool with_dynamic = true, bool delete_cache = true,
bool delete_conv_bn = false) {
std::string model_dir =
FLAGS_infer_model + "/conv_bn_swish_split_gelu/conv_bn_swish_split_gelu";
std::string opt_cache_dir = model_dir + "/my_cache";
if (delete_cache) {
delete_cache_files(opt_cache_dir);
}
AnalysisConfig config;
config.EnableUseGpu(100, 0);
std::string buffer_prog, buffer_param;
ReadBinaryFile(model_dir + "/model", &buffer_prog);
ReadBinaryFile(model_dir + "/params", &buffer_param);
config.SetModelBuffer(&buffer_prog[0], buffer_prog.size(), &buffer_param[0],
buffer_param.size());
config.SetOptimCacheDir(opt_cache_dir);
config.SwitchUseFeedFetchOps(false);
// Set the input's min, max, opt shape
config.EnableTensorRtEngine(
1 << 30, 1, 1, AnalysisConfig::Precision::kFloat32, false, false);
if (delete_conv_bn) {
config.pass_builder()->DeletePass("conv_bn_fuse_pass");
}
if (with_dynamic) {
std::map<std::string, std::vector<int>> min_input_shape = {
{"image", {1, 1, 3, 3}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"image", {1, 1, 10, 10}}};
std::map<std::string, std::vector<int>> opt_input_shape = {
{"image", {1, 1, 3, 3}}};
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
}
auto predictor = CreatePaddlePredictor(config);
auto input_names = predictor->GetInputNames();
int channels = 1;
int height = 3;
int width = 3;
int input_num = channels * height * width * 1;
float *input = new float[input_num];
memset(input, 0, input_num * sizeof(float));
auto input_t = predictor->GetInputTensor(input_names[0]);
input_t->Reshape({1, channels, height, width});
input_t->copy_from_cpu(input);
ASSERT_TRUE(predictor->ZeroCopyRun());
std::vector<float> out_data;
auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputTensor(output_names[0]);
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
out_data.resize(out_num);
output_t->copy_to_cpu(out_data.data());
auto predictor2 = predictor->Clone();
auto input_t2 = predictor2->GetInputTensor(input_names[0]);
input_t2->Reshape({1, channels, height, width});
input_t2->copy_from_cpu(input);
ASSERT_TRUE(predictor2->ZeroCopyRun());
std::vector<float> out_data2;
auto output_t2 = predictor2->GetOutputTensor(output_names[0]);
std::vector<int> output_shape2 = output_t2->shape();
int out_num2 = std::accumulate(output_shape2.begin(), output_shape2.end(), 1,
std::multiplies<int>());
out_data2.resize(out_num2);
output_t2->copy_to_cpu(out_data2.data());
ASSERT_TRUE(out_data2.size() == out_data.size());
for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(out_data2[i], out_data[i], 1e-5);
}
}
TEST(AnalysisPredictor, trt_dynamic) { TestDynamic(true); }
TEST(AnalysisPredictor, trt_static) { TestDynamic(false); }
TEST(AnalysisPredictor, trt_memory_serialize) {
...@@ -218,6 +299,7 @@ TEST(AnalysisPredictor, trt_memory_serialize) {
TEST(AnalysisPredictor, trt_dynamic2) { TestDynamic2(); }
TEST(AnalysisPredictor, trt_tuned_dynamic) { TestTunedDynamic(); }
TEST(AnalysisPredictor, trt_dynamic_clone) { TestDynamicClone(); }
} // namespace inference
} // namespace paddle
...@@ -250,6 +250,23 @@ class TensorRTEngineOp : public framework::OperatorBase {
}
}
void PrepareTRTEngine(const framework::Scope &scope,
TensorRTEngine *engine) const {
LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time.";
framework::proto::BlockDesc block_proto;
block_proto.ParseFromString(Attr<std::string>("subgraph"));
framework::BlockDesc block_desc(nullptr, &block_proto);
std::vector<std::string> inputs = Inputs("Xs");
std::vector<std::string> outputs =
Attr<std::vector<std::string>>("output_name_mapping");
inference::Singleton<inference::tensorrt::OpConverter>::Global()
.ConvertBlockToTRTEngine(&block_desc, scope, inputs, param_names_,
outputs, engine);
}
protected:
void RunNativeImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
...@@ -414,8 +431,19 @@ class TensorRTEngineOp : public framework::OperatorBase {
int num_inputs = 0;
num_inputs += runtime_input_names_.size();
// const int num_bindings = num_inputs + Outputs("Ys").size();
// std::vector<void *> buffers(num_bindings);
// This method returns the total over all profiles.
const int num_bindings = engine->GetNbBindings();
std::vector<void *> buffers(num_bindings, nullptr);
int binding_offset = 0;
nvinfer1::IExecutionContext *trt_context = nullptr;
if (engine->with_dynamic_shape()) {
// Initialize context and get offset by profile index
trt_context = engine->context();
binding_offset = engine->GetBindingsOffset();
}
// Bind input tensor to TRT.
for (const auto &x : runtime_input_names_) {
...@@ -430,7 +458,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
t.ShareDataWith(out);
}
auto t_shape = framework::vectorize<int64_t>(t.dims());
// const int bind_index = engine->engine()->getBindingIndex(x.c_str());
// Get index of profile 0 first, then plus binding offset
const int bind_index =
engine->engine()->getBindingIndex(x.c_str()) + binding_offset;
PADDLE_ENFORCE_LT(
bind_index, num_bindings,
platform::errors::InvalidArgument(
...@@ -474,7 +505,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
}
} else {
#if IS_TRT_VERSION_GE(6000)
auto *trt_context = engine->context();
trt_context->setBindingDimensions(
bind_index, inference::tensorrt::Vec2TRT_Dims(t_shape, x, true));
#endif
...@@ -500,7 +530,8 @@ class TensorRTEngineOp : public framework::OperatorBase {
VLOG(4) << "TensorRT Engine Op Outputs:";
for (const auto &y : Outputs("Ys")) {
const int bind_index =
engine->engine()->getBindingIndex(output_maps[output_index].c_str()) +
binding_offset;
std::vector<int> ddim;
if (!engine->with_dynamic_shape()) {
...@@ -511,7 +542,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
}
} else {
#if IS_TRT_VERSION_GE(6000)
auto *trt_context = engine->context();
auto dims = trt_context->getBindingDimensions(bind_index);
int nb_dims = dims.nbDims;
for (; nb_dims > 0; nb_dims--) {
...@@ -583,23 +613,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
}
return trt_engine_;
}
void PrepareTRTEngine(const framework::Scope &scope,
TensorRTEngine *engine) const {
LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
"kernel etc). This process may cost a lot of time.";
framework::proto::BlockDesc block_proto;
block_proto.ParseFromString(Attr<std::string>("subgraph"));
framework::BlockDesc block_desc(nullptr, &block_proto);
std::vector<std::string> inputs = Inputs("Xs");
std::vector<std::string> outputs =
Attr<std::vector<std::string>>("output_name_mapping");
inference::Singleton<inference::tensorrt::OpConverter>::Global()
.ConvertBlockToTRTEngine(&block_desc, scope, inputs, param_names_,
outputs, engine);
}
};
} // namespace operators
......