未验证 提交 ebc58548 编写于 作者: J JingZhuangzhuang 提交者: GitHub

Support trt engine auto build in runtime for dynamic shape (#52162)

上级 3de2206c
......@@ -199,7 +199,7 @@ void IRPassManager::CreatePasses(Argument *argument,
optim_cache_dir));
}
pass->Set("model_opt_cache_dir", new std::string(optim_cache_dir));
} else if (use_static_engine || enable_int8) {
} else if (use_static_engine || enable_int8 || with_dynamic_shape) {
std::string model_opt_cache_dir =
argument->Has("model_dir")
? argument->model_dir()
......
......@@ -14,6 +14,7 @@
// limitations under the License.
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
#include <fcntl.h>
#include <cstddef>
#include <string>
#include <unordered_set>
......@@ -349,10 +350,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
Get<std::map<std::string, std::vector<int>>>("optim_shape_tensor");
auto allow_build_at_runtime = Get<bool>("trt_allow_build_at_runtime");
auto with_dynamic_shape = Get<bool>("with_dynamic_shape");
auto shape_range_info_path = Get<std::string>("trt_shape_range_info_path");
auto trt_tuned_dynamic_shape = Get<bool>("trt_tuned_dynamic_shape");
int max_batch_size = Get<int>("max_batch_size");
if (trt_tuned_dynamic_shape) {
if (!shape_range_info_path.empty()) {
VLOG(1) << "trt dynamic_shape deserialize from " << shape_range_info_path;
inference::DeserializeShapeRangeInfo(shape_range_info_path,
&min_input_shape,
......@@ -361,6 +364,24 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
&min_shape_tensor,
&max_shape_tensor,
&opt_shape_tensor);
} else {
shape_range_info_path =
Get<std::string>("model_opt_cache_dir") + "shape_range_info.pbtxt";
if (open(shape_range_info_path.c_str(), O_RDONLY) != -1) {
VLOG(1) << "trt dynamic_shape deserialize from "
<< shape_range_info_path;
inference::DeserializeShapeRangeInfo(shape_range_info_path,
&min_input_shape,
&max_input_shape,
&opt_input_shape,
&min_shape_tensor,
&max_shape_tensor,
&opt_shape_tensor);
} else {
int fd = open(shape_range_info_path.c_str(), O_RDONLY | O_CREAT);
close(fd);
}
}
}
// The following procedure is used to rename all the intermediate
......@@ -447,6 +468,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
op_desc->SetAttr("shape_range_info_path", shape_range_info_path);
op_desc->SetAttr("use_inspector", Get<bool>("use_inspector"));
op_desc->SetAttr("model_precision", Get<int>("model_precision"));
op_desc->SetAttr("with_dynamic_shape", with_dynamic_shape);
// we record all inputs' shapes in attr to check if they are consistent
// with the real inputs' shapes retrieved from scope when trt runs.
......@@ -563,6 +585,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
precision_mode,
calibrator.get(),
Get<int>("gpu_device_id"),
with_dynamic_shape,
min_input_shape,
max_input_shape,
opt_input_shape,
......@@ -607,6 +630,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
}
}
// If with_dynamic_shape is configured,but min_input_shape is empty,
// create trt engine in runtime instead of in pass.
if (with_dynamic_shape && min_input_shape.empty()) {
return;
}
// the following code will NOT run in following situation:
// 1. calibraion mode (generate trt int8 calibraiton table data)
// 2. already load serialized trt engine info.
......
......@@ -644,7 +644,8 @@ struct PD_INFER_DECL AnalysisConfig {
/// mode.
/// \param allow_build_at_runtime allow build trt engine at runtime.
///
void EnableTunedTensorRtDynamicShape(const std::string& shape_range_info_path,
void EnableTunedTensorRtDynamicShape(
const std::string& shape_range_info_path = "",
bool allow_build_at_runtime = true);
///
......
......@@ -177,6 +177,7 @@ TEST(CustomPluginCreater, DynamicShapePlugin) {
AnalysisConfig::Precision::kFloat32,
nullptr,
0,
true,
min_input_shape,
max_input_shape,
optim_input_shape));
......
......@@ -224,6 +224,7 @@ class TensorRTEngine {
AnalysisConfig::Precision precision = AnalysisConfig::Precision::kFloat32,
TRTInt8Calibrator* calibrator = nullptr,
int device_id = 0,
bool with_dynamic_shape = false,
const ShapeMapType min_input_shape = {},
const ShapeMapType max_input_shape = {},
const ShapeMapType optim_input_shape = {},
......@@ -238,6 +239,7 @@ class TensorRTEngine {
precision_(precision),
calibrator_(calibrator),
device_id_(device_id),
with_dynamic_shape_(with_dynamic_shape),
min_input_shape_(min_input_shape),
max_input_shape_(max_input_shape),
optim_input_shape_(optim_input_shape),
......@@ -247,31 +249,6 @@ class TensorRTEngine {
disable_trt_plugin_fp16_(disable_trt_plugin_fp16),
model_precision_(model_precision),
logger_(logger) {
if (min_input_shape_.size() != 0 && max_input_shape_.size() != 0 &&
optim_input_shape_.size() != 0) {
PADDLE_ENFORCE_EQ(
min_input_shape_.size(),
max_input_shape_.size(),
platform::errors::InvalidArgument(
"The min_input_shape_'s size(%d) should be equal to the "
"size(%d) of max_input_shape_",
min_input_shape_.size(),
max_input_shape_.size()));
PADDLE_ENFORCE_EQ(
min_input_shape_.size(),
optim_input_shape_.size(),
platform::errors::InvalidArgument(
"The min_input_shape_'s size(%d) should be equal to the "
"size(%d) of optim_input_shape_",
min_input_shape_.size(),
optim_input_shape_.size()));
#if IS_TRT_VERSION_GE(6000)
with_dynamic_shape_ = true;
#else
LOG(WARNING) << "Using dynamic shape of TRT need ensure that the TRT "
"version should be at least 6.";
#endif
}
dy::initLibNvInferPlugins(&logger, "");
}
......@@ -477,17 +454,27 @@ class TensorRTEngine {
ShapeMapType optim_shape_tensor() { return optim_shape_tensor_; }
bool AdjustDynamicShapeRange(const ShapeMapType& runtime_input_shape,
std::vector<std::string>* changed) {
const ShapeMapType& runtime_shape_tensor,
std::vector<std::string>* changed,
std::vector<std::string>* tensor_changed) {
bool ret = false;
changed->clear();
tensor_changed->clear();
for (const auto& it : runtime_input_shape) {
auto name = it.first;
auto input_shape = it.second;
PADDLE_ENFORCE_EQ(
min_input_shape_.count(name),
true,
platform::errors::InvalidArgument(
"TRT dynamic_shape min_input_shape %s not found.", name));
bool min_change = false;
bool max_change = false;
std::vector<int> bak_min_shape;
std::vector<int> bak_max_shape;
if (!min_input_shape_.count(name)) {
min_input_shape_[name] = input_shape;
max_input_shape_[name] = input_shape;
optim_input_shape_[name] = input_shape;
min_change = true;
max_change = true;
ret = true;
} else {
PADDLE_ENFORCE_EQ(min_input_shape_[name].size(),
input_shape.size(),
platform::errors::InvalidArgument(
......@@ -499,10 +486,9 @@ class TensorRTEngine {
min_input_shape_[name].size(),
name,
input_shape.size()));
auto bak_min_shape = min_input_shape_[name];
auto bak_max_shape = max_input_shape_[name];
bool min_change = false;
bool max_change = false;
bak_min_shape = min_input_shape_[name];
bak_max_shape = max_input_shape_[name];
for (size_t d = 0; d < input_shape.size(); ++d) {
if (input_shape[d] < min_input_shape_[name][d]) {
ret = true;
......@@ -515,17 +501,69 @@ class TensorRTEngine {
max_input_shape_[name][d] = input_shape[d];
}
}
}
if (min_change)
LOG(INFO) << "refactor shape range: " << name << ", min_shape from "
<< Vec2Str(bak_min_shape) << " to "
LOG(INFO) << "refactor tensor shape range: " << name
<< ", min_shape from " << Vec2Str(bak_min_shape) << " to "
<< Vec2Str(min_input_shape_[name]);
if (max_change)
LOG(INFO) << "refactor shape range: " << name << ", max_shape from "
<< Vec2Str(bak_max_shape) << " to "
LOG(INFO) << "refactor tensor shape range: " << name
<< ", max_shape from " << Vec2Str(bak_max_shape) << " to "
<< Vec2Str(max_input_shape_[name]);
if (min_change || max_change) changed->push_back(name);
}
for (const auto& it : runtime_shape_tensor) {
auto name = it.first;
auto shape_tensor = it.second;
bool min_change = false;
bool max_change = false;
std::vector<int> bak_min_shape;
std::vector<int> bak_max_shape;
if (!min_shape_tensor_.count(name)) {
min_shape_tensor_[name] = shape_tensor;
max_shape_tensor_[name] = shape_tensor;
optim_shape_tensor_[name] = shape_tensor;
min_change = true;
max_change = true;
ret = true;
} else {
PADDLE_ENFORCE_EQ(min_shape_tensor_[name].size(),
shape_tensor.size(),
platform::errors::InvalidArgument(
"TRT dynamic_shape min_shape_tensor %s size not "
"equal, the min_shape_tensor[%s].size()=%d"
", but the runtime_shape_tensor[%s].size()=%d.",
name,
name,
min_shape_tensor_[name].size(),
name,
shape_tensor.size()));
bak_min_shape = min_shape_tensor_[name];
bak_max_shape = max_shape_tensor_[name];
for (size_t d = 0; d < shape_tensor.size(); ++d) {
if (shape_tensor[d] < min_shape_tensor_[name][d]) {
ret = true;
min_change = true;
min_shape_tensor_[name][d] = shape_tensor[d];
}
if (shape_tensor[d] > max_shape_tensor_[name][d]) {
ret = true;
max_change = true;
max_shape_tensor_[name][d] = shape_tensor[d];
}
}
}
if (min_change)
LOG(INFO) << "refactor shape tensor range: " << name
<< ", min_shape from " << Vec2Str(bak_min_shape) << " to "
<< Vec2Str(min_shape_tensor_[name]);
if (max_change)
LOG(INFO) << "refactor shape tensor range: " << name
<< ", max_shape from " << Vec2Str(bak_max_shape) << " to "
<< Vec2Str(max_shape_tensor_[name]);
if (min_change || max_change) tensor_changed->push_back(name);
}
return ret;
}
......@@ -670,6 +708,7 @@ class TensorRTEngine {
int max_profile_num_{1};
int cur_profile_num_{0};
std::unordered_map<PredictorID, int> profile_index_;
bool with_dynamic_shape_{false};
ShapeMapType min_input_shape_;
ShapeMapType max_input_shape_;
ShapeMapType optim_input_shape_;
......@@ -706,9 +745,6 @@ class TensorRTEngine {
std::unordered_map<std::string, paddle::any> attrs_;
std::unordered_map<std::string, std::function<void(void)>> attr_dels_;
// For dynamic shape
bool with_dynamic_shape_{false};
#if IS_TRT_VERSION_GE(6000)
int binding_num_;
infer_ptr<nvinfer1::IBuilderConfig> infer_builder_config_;
......@@ -772,6 +808,7 @@ class TRTEngineManager {
AnalysisConfig::Precision precision = AnalysisConfig::Precision::kFloat32,
TRTInt8Calibrator* calibrator = nullptr,
int device_id = 0,
bool with_dynamic_shape = false,
const std::map<std::string, std::vector<int>> min_input_shape = {},
const std::map<std::string, std::vector<int>> max_input_shape = {},
const std::map<std::string, std::vector<int>> optim_input_shape = {},
......@@ -786,6 +823,7 @@ class TRTEngineManager {
precision,
calibrator,
device_id,
with_dynamic_shape,
min_input_shape,
max_input_shape,
optim_input_shape,
......
......@@ -190,6 +190,10 @@ inline void PrintITensorShape(nvinfer1::ITensor* X) {
template <typename T>
inline std::string Vec2Str(const std::vector<T>& vec) {
std::ostringstream os;
if (vec.empty()) {
os << "()";
return os.str();
}
os << "(";
for (size_t i = 0; i < vec.size() - 1; ++i) {
os << vec[i] << ",";
......
......@@ -70,6 +70,7 @@ class TensorRTDynamicShapeValueEngineTest : public ::testing::Test {
AnalysisConfig::Precision::kFloat32,
nullptr,
0,
true,
min_input_shape,
max_input_shape,
optim_input_shape,
......@@ -196,6 +197,7 @@ class TensorRTDynamicEngineTest : public ::testing::Test {
AnalysisConfig::Precision::kHalf,
nullptr,
0,
true,
min_input_shape,
max_input_shape,
optim_input_shape,
......@@ -373,6 +375,7 @@ class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
AnalysisConfig::Precision::kFloat32,
nullptr,
0,
true,
min_input_shape,
max_input_shape,
optim_input_shape,
......@@ -581,6 +584,7 @@ class TensorRTDynamicTestFusedTokenPruneHalf : public ::testing::Test {
AnalysisConfig::Precision::kHalf,
nullptr,
0,
true,
min_input_shape,
max_input_shape,
optim_input_shape,
......@@ -783,6 +787,7 @@ class TensorRTDynamicShapeGNTest : public ::testing::Test {
AnalysisConfig::Precision::kInt8,
nullptr,
0,
true,
min_input_shape,
max_input_shape,
optim_input_shape,
......
......@@ -277,13 +277,18 @@ void UpdateShapeRangeInfo(
const std::map<std::string, std::vector<int32_t>> &min_shape,
const std::map<std::string, std::vector<int32_t>> &max_shape,
const std::map<std::string, std::vector<int32_t>> &opt_shape,
const std::vector<std::string> &names) {
const std::map<std::string, std::vector<int32_t>> &min_value,
const std::map<std::string, std::vector<int32_t>> &max_value,
const std::map<std::string, std::vector<int32_t>> &opt_value,
const std::vector<std::string> &names,
const std::vector<std::string> &tensor_names) {
paddle::inference::proto::ShapeRangeInfos shape_range_infos;
DeserializeShapeRangeInfo(path, &shape_range_infos);
for (const auto &name : names) {
bool has_name = false;
for (int i = 0; i < shape_range_infos.shape_range_info_size(); ++i) {
auto *info = shape_range_infos.mutable_shape_range_info(i);
for (const auto &name : names) {
if (info->name() == name) {
info->clear_min_shape();
info->clear_max_shape();
......@@ -294,9 +299,50 @@ void UpdateShapeRangeInfo(
info->add_max_shape(max_shape.at(name)[j]);
for (size_t j = 0; j < opt_shape.at(name).size(); ++j)
info->add_opt_shape(opt_shape.at(name)[j]);
has_name = true;
break;
}
}
if (!has_name) {
auto *info = shape_range_infos.add_shape_range_info();
info->set_name(name);
for (size_t j = 0; j < min_shape.at(name).size(); ++j)
info->add_min_shape(min_shape.at(name)[j]);
for (size_t j = 0; j < max_shape.at(name).size(); ++j)
info->add_max_shape(max_shape.at(name)[j]);
for (size_t j = 0; j < opt_shape.at(name).size(); ++j)
info->add_opt_shape(opt_shape.at(name)[j]);
}
}
for (const auto &name : tensor_names) {
bool has_name = false;
for (int i = 0; i < shape_range_infos.shape_range_info_size(); ++i) {
auto *info = shape_range_infos.mutable_shape_range_info(i);
if (info->name() == name) {
info->clear_min_value();
info->clear_max_value();
info->clear_opt_value();
for (size_t j = 0; j < min_value.at(name).size(); ++j)
info->add_min_value(min_value.at(name)[j]);
for (size_t j = 0; j < max_value.at(name).size(); ++j)
info->add_max_value(max_value.at(name)[j]);
for (size_t j = 0; j < opt_value.at(name).size(); ++j)
info->add_opt_value(opt_value.at(name)[j]);
has_name = true;
break;
}
}
if (!has_name) {
auto *info = shape_range_infos.add_shape_range_info();
info->set_name(name);
for (size_t j = 0; j < min_value.at(name).size(); ++j)
info->add_min_value(min_value.at(name)[j]);
for (size_t j = 0; j < max_value.at(name).size(); ++j)
info->add_max_value(max_value.at(name)[j]);
for (size_t j = 0; j < opt_value.at(name).size(); ++j)
info->add_opt_value(opt_value.at(name)[j]);
}
}
inference::SerializeShapeRangeInfo(path, shape_range_infos);
......
......@@ -63,6 +63,10 @@ void UpdateShapeRangeInfo(
const std::map<std::string, std::vector<int32_t>>& min_shape,
const std::map<std::string, std::vector<int32_t>>& max_shape,
const std::map<std::string, std::vector<int32_t>>& opt_shape,
const std::vector<std::string>& names);
const std::map<std::string, std::vector<int32_t>>& min_value,
const std::map<std::string, std::vector<int32_t>>& max_value,
const std::map<std::string, std::vector<int32_t>>& opt_value,
const std::vector<std::string>& names,
const std::vector<std::string>& tensor_names);
} // namespace inference
} // namespace paddle
......@@ -133,8 +133,15 @@ TEST(shape_info_io, read_and_write) {
min_shape.insert(std::make_pair("test1", std::vector<int32_t>{1, 3, 56, 56}));
std::vector<std::string> names{"test1"};
paddle::inference::UpdateShapeRangeInfo(
path, min_shape, max_shape, opt_shape, names);
paddle::inference::UpdateShapeRangeInfo(path,
min_shape,
max_shape,
opt_shape,
min_value,
max_value,
opt_value,
names,
names);
ASSERT_THROW(paddle::inference::DeserializeShapeRangeInfo("no_exists_file",
&min_shape,
......
......@@ -20,6 +20,7 @@
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/funcs/data_type_transform.h"
#ifdef PADDLE_WITH_CUDA
#include <memory>
#include <string>
......@@ -188,6 +189,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
int predictor_id_;
int device_id_;
bool allow_build_at_runtime_{false};
bool with_dynamic_shape_{false};
std::string shape_range_info_path_;
std::string model_opt_cache_dir_;
bool use_static_engine_;
......@@ -216,6 +218,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
predictor_id_ = Attr<int>("predictor_id");
shape_range_info_path_ = Attr<std::string>("shape_range_info_path");
allow_build_at_runtime_ = Attr<bool>("allow_build_at_runtime");
with_dynamic_shape_ = Attr<bool>("with_dynamic_shape");
use_static_engine_ = Attr<bool>("use_static_engine");
if (use_static_engine_) {
model_opt_cache_dir_ = Attr<std::string>("model_opt_cache_dir");
......@@ -329,8 +332,9 @@ class TensorRTEngineOp : public framework::OperatorBase {
}
auto *trt_engine = GetEngine(scope, dev_place);
if (trt_engine->with_dynamic_shape()) {
// get runtime input shapes.
// get runtime input shapes and shape tensors.
std::map<std::string, std::vector<int32_t>> runtime_input_shape;
std::map<std::string, std::vector<int32_t>> runtime_shape_tensor;
for (auto name : runtime_input_names_) {
auto &t =
inference::analysis::GetFromScope<phi::DenseTensor>(scope, name);
......@@ -338,8 +342,59 @@ class TensorRTEngineOp : public framework::OperatorBase {
<< t.dims() << ")";
auto t_shape = phi::vectorize<int32_t>(t.dims());
runtime_input_shape.insert(std::make_pair(name, t_shape));
// We need collect value range for shape tensor for Paddle-TRT's use.
// To be noticed, this method to identify all shape tensors is based on
// assumption that all shape tensors in the model have numbers <= 7.
// This is a simple method to identify all shape tensors with some
// mistakes, but it doesn't matter.
auto is_shape_tensor = t.numel() <= 7 && t.numel() >= 1;
if (trt_engine->engine()) {
auto *engine = trt_engine->engine();
is_shape_tensor =
engine->isShapeBinding(engine->getBindingIndex(name.c_str()));
}
if ((t.dtype() == phi::DataType::INT32 ||
t.dtype() == phi::DataType::INT64) &&
is_shape_tensor) {
std::vector<int> int32_host(t.numel());
paddle::platform::DeviceContextPool &pool =
paddle::platform::DeviceContextPool::Instance();
if (platform::is_cpu_place(t.place())) {
auto &int32_tensor = t;
if (t.dtype() == phi::DataType::INT64) {
auto *cpu_ctx = pool.Get(platform::CPUPlace());
int32_tensor = phi::funcs::TransDataType(
reinterpret_cast<const phi::CPUContext &>(*cpu_ctx),
t,
DataType::INT32);
}
paddle::memory::Copy(platform::CPUPlace(),
int32_host.data(),
platform::CPUPlace(),
int32_tensor.data<int>(),
int32_tensor.numel() * sizeof(int));
} else if (platform::is_gpu_place(t.place())) {
#if defined(PADDLE_WITH_CUDA)
auto *dev_ctx = pool.Get(t.place());
auto &int32_tensor = t;
if (t.dtype() == phi::DataType::INT64) {
int32_tensor = phi::funcs::TransDataType(
reinterpret_cast<const phi::GPUContext &>(*dev_ctx),
t,
DataType::INT32);
}
paddle::memory::Copy(platform::CPUPlace(),
int32_host.data(),
int32_tensor.place(),
int32_tensor.data<int>(),
int32_tensor.numel() * sizeof(int),
nullptr);
#endif
}
runtime_shape_tensor[name] = int32_host;
}
}
if (!allow_build_at_runtime_) {
std::map<std::string, std::vector<int>> min_input_shape =
trt_engine->min_input_shape();
......@@ -364,12 +419,18 @@ class TensorRTEngineOp : public framework::OperatorBase {
} else {
// compare runtime_input_shape and trt_engine dynamic shapes.
std::vector<std::string> shape_changed_name;
bool is_adjusted = trt_engine->AdjustDynamicShapeRange(
runtime_input_shape, &shape_changed_name);
std::vector<std::string> tensor_changed_name;
bool is_adjusted =
trt_engine->AdjustDynamicShapeRange(runtime_input_shape,
runtime_shape_tensor,
&shape_changed_name,
&tensor_changed_name);
if (is_adjusted) {
LOG(INFO) << "Adjust dynamic shape range, rebuild trt engine!";
if (trt_engine->engine()) {
trt_engine->ResetContext();
trt_engine->ClearTensorMap();
}
auto *anc = scope.parent();
while (anc && anc->parent()) {
anc = anc->parent();
......@@ -384,7 +445,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
trt_engine->min_input_shape(),
trt_engine->max_input_shape(),
trt_engine->optim_input_shape(),
shape_changed_name);
trt_engine->min_shape_tensor(),
trt_engine->max_shape_tensor(),
trt_engine->optim_shape_tensor(),
shape_changed_name,
tensor_changed_name);
}
if (use_static_engine_) {
......@@ -452,6 +517,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
precision_mode_,
calib_res->calib_.get(),
dev_place.device,
with_dynamic_shape_,
min_input_shape,
max_input_shape,
opt_input_shape,
......@@ -766,6 +832,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
precision_mode_,
calibrator_.get(),
device_id_,
with_dynamic_shape_,
min_input_shape_,
max_input_shape_,
opt_input_shape_);
......
......@@ -140,6 +140,7 @@ void DynamicShapeTest(bool allow_build_at_runtime) {
engine_op_desc.SetAttr("use_static_engine", true);
engine_op_desc.SetAttr("dynamic_shape_names", std::vector<std::string>{"x"});
engine_op_desc.SetAttr("dynamic_shape_lens", std::vector<int>{4});
engine_op_desc.SetAttr("with_dynamic_shape", true);
engine_op_desc.SetAttr("min_input_shape", std::vector<int>{1, 1, 1, 1});
engine_op_desc.SetAttr("max_input_shape", std::vector<int>{16, 16, 16, 16});
engine_op_desc.SetAttr("opt_input_shape", std::vector<int>{2, 4, 4, 4});
......
......@@ -861,7 +861,9 @@ void BindAnalysisConfig(py::module *m) {
.def("shape_range_info_collected",
&AnalysisConfig::shape_range_info_collected)
.def("enable_tuned_tensorrt_dynamic_shape",
&AnalysisConfig::EnableTunedTensorRtDynamicShape)
&AnalysisConfig::EnableTunedTensorRtDynamicShape,
py::arg("shape_range_info_path") = "",
py::arg("allow_build_at_runtime") = true)
.def("tuned_tensorrt_dynamic_shape",
&AnalysisConfig::tuned_tensorrt_dynamic_shape)
.def("trt_allow_build_at_runtime",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册