未验证 提交 b6e9498e 编写于 作者: Z zlsh80826 提交者: GitHub

[Paddle-TRT] Remove TensorRT deprecated API (#33654)

* add trt LT version helper

* remove deprecated nvinfer1::DimsCHW and replace it to nvinfer1::Dims3

* remove deprecated nvinfer1::DimsNCHW and replace it to nvinfer1::Dims4

* update deserialize engine

* update to createNetworkV2

* update to createNetworkV2

* update buildWithConfig and remove redundent config settings

* replace createNetwork to createNetworkV2

* fix int8

* addMatrixMultiply

* remove unnecessary const cast

* IBuilder->setInt8Calibrator() is deprecated

* auto enable fp16 when using int8

* remove the redundant line
上级 57352bc7
......@@ -45,9 +45,16 @@ class MatMulOpConverter : public OpConverter {
bool transpose_X = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_X"));
bool transpose_Y = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_Y"));
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1),
transpose_X, *const_cast<nvinfer1::ITensor*>(input2), transpose_Y);
nvinfer1::MatrixOperation matrix_operation_X =
transpose_X ? nvinfer1::MatrixOperation::kTRANSPOSE
: nvinfer1::MatrixOperation::kNONE;
nvinfer1::MatrixOperation matrix_operation_Y =
transpose_Y ? nvinfer1::MatrixOperation::kTRANSPOSE
: nvinfer1::MatrixOperation::kNONE;
auto* layer =
TRT_ENGINE_ADD_LAYER(engine_, MatrixMultiply, *input1,
matrix_operation_X, *input2, matrix_operation_Y);
float alpha = BOOST_GET_CONST(float, op_desc.GetAttr("alpha"));
auto output_name = op_desc.Output("Out")[0];
......
......@@ -57,7 +57,7 @@ class ShuffleChannelOpConverter : public OpConverter {
auto* output = layer->getOutput(0);
auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *output);
nvinfer1::DimsCHW reshape_dim2(c, h, w);
nvinfer1::Dims3 reshape_dim2(c, h, w);
reshape_layer->setReshapeDimensions(reshape_dim2);
auto output_name = op_desc.Output("Out")[0];
......
......@@ -28,12 +28,12 @@ TEST(batch_norm_op, test) {
TRTConvertValidation validator(5, parameters, scope, 1 << 15);
std::vector<int> param_shape{2};
validator.DeclInputVar("batch_norm_X", nvinfer1::DimsCHW(2, 5, 5));
validator.DeclInputVar("batch_norm_X", nvinfer1::Dims3(2, 5, 5));
validator.DeclParamVar("batch_norm_scale", param_shape);
validator.DeclParamVar("batch_norm_bias", param_shape);
validator.DeclParamVar("batch_norm_mean", param_shape);
validator.DeclParamVar("batch_norm_variance", param_shape);
validator.DeclOutputVar("batch_norm_Y", nvinfer1::DimsCHW(2, 5, 5));
validator.DeclOutputVar("batch_norm_Y", nvinfer1::Dims3(2, 5, 5));
validator.DeclOutputVar("batch_norm_save_mean", param_shape);
validator.DeclOutputVar("batch_norm_save_variance", param_shape);
......
......@@ -24,10 +24,10 @@ TEST(concat_op, test) {
std::unordered_set<std::string> parameters({""});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("concat_x1", nvinfer1::DimsCHW(10, 3, 1));
validator.DeclInputVar("concat_x2", nvinfer1::DimsCHW(3, 3, 1));
validator.DeclInputVar("concat_x3", nvinfer1::DimsCHW(7, 3, 1));
validator.DeclOutputVar("concat_out", nvinfer1::DimsCHW(20, 3, 1));
validator.DeclInputVar("concat_x1", nvinfer1::Dims3(10, 3, 1));
validator.DeclInputVar("concat_x2", nvinfer1::Dims3(3, 3, 1));
validator.DeclInputVar("concat_x3", nvinfer1::Dims3(7, 3, 1));
validator.DeclOutputVar("concat_out", nvinfer1::Dims3(20, 3, 1));
// Prepare Op description
framework::OpDesc desc;
......
......@@ -25,10 +25,9 @@ TEST(DropoutOpConverter, main) {
TRTConvertValidation validator(8, parameters, scope, 1000);
std::vector<int> tensor_shape{8, 10};
validator.DeclInputVar("dropout-X", tensor_shape,
nvinfer1::DimsCHW(10, 1, 1));
validator.DeclOutputVar("dropout-Out", nvinfer1::DimsCHW(10, 1, 1));
validator.DeclOutputVar("mask-Out", nvinfer1::DimsCHW(10, 1, 1));
validator.DeclInputVar("dropout-X", tensor_shape, nvinfer1::Dims3(10, 1, 1));
validator.DeclOutputVar("dropout-Out", nvinfer1::Dims3(10, 1, 1));
validator.DeclOutputVar("mask-Out", nvinfer1::Dims3(10, 1, 1));
// Prepare Op description
framework::OpDesc desc;
......
......@@ -24,9 +24,9 @@ TEST(elementwise_op, add_weight) {
std::unordered_set<std::string> parameters({"elementwise_add-Y"});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1 << 15);
validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3));
validator.DeclInputVar("elementwise_add-X", nvinfer1::Dims3(10, 3, 3));
validator.DeclParamVar("elementwise_add-Y", nvinfer1::Dims3(10, 1, 1));
validator.DeclOutputVar("elementwise_add-Out", nvinfer1::DimsCHW(10, 3, 3));
validator.DeclOutputVar("elementwise_add-Out", nvinfer1::Dims3(10, 3, 3));
// Prepare Op description
framework::OpDesc desc;
......@@ -50,11 +50,11 @@ TEST(elementwise_op, native) {
framework::Scope scope;
TRTConvertValidation validator(batch_size, parameters, scope, 1 << 15);
validator.DeclInputVar("elementwise_" + type + "-X",
nvinfer1::DimsCHW(10, 3, 3));
nvinfer1::Dims3(10, 3, 3));
validator.DeclInputVar("elementwise_" + type + "-Y",
nvinfer1::Dims3(10, 3, 3));
validator.DeclOutputVar("elementwise_" + type + "-Out",
nvinfer1::DimsCHW(10, 3, 3));
nvinfer1::Dims3(10, 3, 3));
// Prepare Op description
framework::OpDesc desc;
......@@ -78,11 +78,11 @@ TEST(elementwise_op, plugin) {
framework::Scope scope;
TRTConvertValidation validator(batch_size, parameters, scope, 1 << 15);
validator.DeclInputVar("elementwise_" + type + "-X",
nvinfer1::DimsCHW(10, 3, 3));
nvinfer1::Dims3(10, 3, 3));
validator.DeclInputVar("elementwise_" + type + "-Y",
nvinfer1::Dims3(10, 1, 1));
validator.DeclOutputVar("elementwise_" + type + "-Out",
nvinfer1::DimsCHW(10, 3, 3));
nvinfer1::Dims3(10, 3, 3));
// Prepare Op description
framework::OpDesc desc;
......
......@@ -24,8 +24,8 @@ TEST(leaky_relu_op, test_leaky_relu) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("leaky_relu_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclOutputVar("leaky_relu_out", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclInputVar("leaky_relu_input", nvinfer1::Dims3(3, 2, 2));
validator.DeclOutputVar("leaky_relu_out", nvinfer1::Dims3(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
......
......@@ -24,9 +24,9 @@ TEST(prelu_op, test_channel_wise) {
std::unordered_set<std::string> parameters({"prelu_alpha"});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclInputVar("prelu_input", nvinfer1::Dims3(3, 2, 2));
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(3, 1, 1));
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclOutputVar("prelu_out", nvinfer1::Dims3(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
......@@ -46,9 +46,9 @@ TEST(prelu_op, test_element_wise) {
std::unordered_set<std::string> parameters({"prelu_alpha"});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclInputVar("prelu_input", nvinfer1::Dims3(3, 2, 2));
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims4(10, 3, 2, 2));
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclOutputVar("prelu_out", nvinfer1::Dims3(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
......@@ -68,9 +68,9 @@ TEST(prelu_op, test_scalar) {
std::unordered_set<std::string> parameters({"prelu_alpha"});
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclInputVar("prelu_input", nvinfer1::Dims3(3, 2, 2));
validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(1, 1, 1));
validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclOutputVar("prelu_out", nvinfer1::Dims3(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
......
......@@ -24,8 +24,8 @@ TEST(leaky_relu_op, test_leaky_relu) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("sc_input", nvinfer1::DimsCHW(4, 2, 2));
validator.DeclOutputVar("sc_out", nvinfer1::DimsCHW(4, 2, 2));
validator.DeclInputVar("sc_input", nvinfer1::Dims3(4, 2, 2));
validator.DeclOutputVar("sc_out", nvinfer1::Dims3(4, 2, 2));
// Prepare Op description
framework::OpDesc desc;
......
......@@ -25,9 +25,8 @@ TEST(SoftMaxOpConverter, main) {
TRTConvertValidation validator(8, parameters, scope, 1000);
std::vector<int> tensor_shape{8, 10};
validator.DeclInputVar("softmax-X", tensor_shape,
nvinfer1::DimsCHW(10, 1, 1));
validator.DeclOutputVar("softmax-Out", nvinfer1::DimsCHW(10, 1, 1));
validator.DeclInputVar("softmax-X", tensor_shape, nvinfer1::Dims3(10, 1, 1));
validator.DeclOutputVar("softmax-Out", nvinfer1::Dims3(10, 1, 1));
// Prepare Op description
framework::OpDesc desc;
......
......@@ -28,7 +28,7 @@ void TensorRTSplitTest(const std::vector<int> &in_shape,
TRTConvertValidation validator(BatchSize + 1, parameters, scope, 10000);
auto make_dim = [](const std::vector<int> &shape) {
nvinfer1::DimsCHW dim;
nvinfer1::Dims3 dim;
dim.c() = shape[0];
dim.h() = shape[1];
dim.w() = shape[2];
......
......@@ -24,8 +24,8 @@ TEST(swish_op, test_swish) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("sw_input", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclOutputVar("sw_out", nvinfer1::DimsCHW(3, 2, 2));
validator.DeclInputVar("sw_input", nvinfer1::Dims3(3, 2, 2));
validator.DeclOutputVar("sw_out", nvinfer1::Dims3(3, 2, 2));
// Prepare Op description
framework::OpDesc desc;
......
......@@ -34,17 +34,15 @@ void TensorRTEngine::InitNetwork() {
infer_builder_.reset(createInferBuilder(&logger_));
if (with_dynamic_shape_) {
#if IS_TRT_VERSION_GE(6000)
infer_networkv2_.reset(infer_builder_->createNetworkV2(
infer_network_.reset(infer_builder_->createNetworkV2(
1U << static_cast<int>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)));
infer_builder_config_.reset(infer_builder_->createBuilderConfig());
infer_ptr<nvinfer1::IBuilderConfig> infer_builder_config_;
optim_profile_ = infer_builder_->createOptimizationProfile();
#endif
} else {
infer_network_.reset(infer_builder_->createNetwork());
infer_network_.reset(infer_builder_->createNetworkV2(0U));
}
infer_builder_config_.reset(infer_builder_->createBuilderConfig());
optim_profile_ = infer_builder_->createOptimizationProfile();
}
void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
......@@ -73,12 +71,12 @@ void TensorRTEngine::FreezeNetwork() {
"Call InitNetwork first to initialize network."));
// build engine.
infer_builder_->setMaxBatchSize(max_batch_);
infer_builder_->setMaxWorkspaceSize(max_workspace_);
infer_builder_config_->setMaxWorkspaceSize(max_workspace_);
bool enable_fp16 = (precision_ == AnalysisConfig::Precision::kHalf);
#if IS_TRT_VERSION_GE(5000)
if (enable_fp16) {
bool support_fp16 = infer_builder_->platformHasFastFp16();
infer_builder_->setFp16Mode(support_fp16);
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kFP16);
if (!support_fp16) {
LOG(INFO) << "You specify FP16 mode, but the hardware do not support "
"FP16 speed up, use FP32 instead.";
......@@ -86,23 +84,19 @@ void TensorRTEngine::FreezeNetwork() {
LOG(INFO) << "Run Paddle-TRT FP16 mode";
}
}
#else
if (enable_fp16)
LOG(INFO) << "Using FP16 in Paddle-TRT must ensure that the version of TRT "
"is at least 5."
"So, use FP32 to run.";
#endif
bool enable_int8 = (precision_ == AnalysisConfig::Precision::kInt8);
bool enable_int8 = (precision_ == AnalysisConfig::Precision::kInt8);
if (enable_int8) {
infer_builder_->setInt8Mode(true);
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kFP16);
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kINT8);
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kSTRICT_TYPES);
if (calibrator_) {
infer_builder_->setInt8Calibrator(calibrator_);
infer_builder_config_->setInt8Calibrator(calibrator_);
} else {
infer_builder_->setInt8Calibrator(nullptr);
infer_builder_config_->setInt8Calibrator(nullptr);
#if IS_TRT_VERSION_GE(5000)
infer_builder_->setStrictTypeConstraints(true);
for (auto &quant_range : quant_dynamic_range_) {
auto tensor = quant_range.first;
float range = quant_range.second;
......@@ -116,6 +110,7 @@ void TensorRTEngine::FreezeNetwork() {
all_t.insert(layer->getOutput(j));
}
}
for (int i = 0; i < network()->getNbInputs(); i++) {
all_t.insert(network()->getInput(i));
}
......@@ -127,6 +122,7 @@ void TensorRTEngine::FreezeNetwork() {
<< ", this might be ok when trt does not need this range";
}
}
#if IS_TRT_VERSION_GE(5122)
auto is_layer_int8 = [&](nvinfer1::ILayer *layer) -> bool {
for (int j = 0; j < layer->getNbInputs(); j++) {
......@@ -189,9 +185,9 @@ void TensorRTEngine::FreezeNetwork() {
<< infer_builder_->getNbDLACores() << ", but got "
<< dla_core_ << ", so use use 0 as default.";
}
infer_builder_->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
infer_builder_->setDLACore(dla_core_);
infer_builder_->allowGPUFallback(true);
infer_builder_config_->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
infer_builder_config_->setDLACore(dla_core_);
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
LOG(INFO) << "TensorRT DLA enabled in FreezeNetwork(), DLACore "
<< dla_core_;
}
......@@ -212,30 +208,18 @@ void TensorRTEngine::FreezeNetwork() {
Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true));
}
infer_builder_config_->addOptimizationProfile(optim_profile_);
infer_builder_config_->setMaxWorkspaceSize(max_workspace_);
if (enable_int8) {
// Due to a bug of TRT, we must set precision BuilderFlag to kFP16 before
// kINT8 here to perform INT8 inference.
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kFP16);
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kINT8);
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kSTRICT_TYPES);
if (WithFp16() && disable_trt_plugin_fp16()) {
LOG(INFO) << "NOTE: In order to achieve higher accuracy, you have "
"disabled the fp16 mode of TRT Plugin,\n"
<< "you can reopen it with "
"'config.SetDynamicShapeInfo(min_shape, max_shape, "
"opt_shape, false /*disable_trt_plugin_fp16*/)'";
}
if (WithFp16()) {
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kFP16);
if (disable_trt_plugin_fp16()) {
LOG(INFO) << "NOTE: In order to achieve higher accuracy, you have "
"disabled the fp16 mode of TRT Plugin,\n"
<< "you can reopen it with "
"'config.SetDynamicShapeInfo(min_shape, max_shape, "
"opt_shape, false /*disable_trt_plugin_fp16*/)'";
}
}
infer_engine_.reset(infer_builder_->buildEngineWithConfig(
*network(), *infer_builder_config_));
#endif
} else {
infer_engine_.reset(infer_builder_->buildCudaEngine(*network()));
}
infer_engine_.reset(infer_builder_->buildEngineWithConfig(
*network(), *infer_builder_config_));
PADDLE_ENFORCE_NOT_NULL(
infer_engine_, platform::errors::Fatal(
"Build TensorRT cuda engine failed! Please recheck "
......
......@@ -102,7 +102,7 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape, std::string input,
"trt dynamic_shape mode by SetTRTDynamicShapeInfo.",
input, ShapeStr(shape)));
}
return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]);
return nvinfer1::Dims3(shape[1], shape[2], shape[3]);
} else if (shape.size() == 3UL) {
if (shape[1] == -1 || shape[2] == -1) {
PADDLE_THROW(platform::errors::InvalidArgument(
......@@ -112,10 +112,10 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape, std::string input,
}
return nvinfer1::Dims2(shape[1], shape[2]);
}
return nvinfer1::DimsCHW(shape[1], 1, 1);
return nvinfer1::Dims3(shape[1], 1, 1);
} else {
if (shape.size() == 4UL) {
return nvinfer1::DimsNCHW(shape[0], shape[1], shape[2], shape[3]);
return nvinfer1::Dims4(shape[0], shape[1], shape[2], shape[3]);
} else if (shape.size() == 3UL) {
return nvinfer1::Dims3(shape[0], shape[1], shape[2]);
}
......@@ -277,22 +277,19 @@ class TensorRTEngine {
}
if (with_dynamic_shape_) {
#if IS_TRT_VERSION_GE(6000)
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size(),
nullptr));
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"To enable dynamic shape support, the TensorRT version should be "
"greater than 6.0.0"));
#endif
engine_serialized_data.c_str(), engine_serialized_data.size()));
} else {
#if IS_TRT_VERSION_LT(8000)
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size(),
&inference::Singleton<plugin::PluginFactoryTensorRT>::Global()));
#else
infer_engine_.reset(runtime->deserializeCudaEngine(
engine_serialized_data.c_str(), engine_serialized_data.size()));
#endif
}
PADDLE_ENFORCE_NOT_NULL(
infer_engine_,
platform::errors::Fatal(
......@@ -369,13 +366,7 @@ class TensorRTEngine {
void Execute(int batch_size, std::vector<void*>* buffers,
cudaStream_t stream = nullptr);
nvinfer1::INetworkDefinition* network() {
if (with_dynamic_shape_) {
return infer_networkv2_.get();
} else {
return infer_network_.get();
}
}
nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
ShapeMapType min_input_shape() { return min_input_shape_; }
ShapeMapType max_input_shape() { return max_input_shape_; }
......@@ -530,7 +521,6 @@ class TensorRTEngine {
// For dynamic shape
bool with_dynamic_shape_{false};
infer_ptr<nvinfer1::INetworkDefinition> infer_networkv2_;
#if IS_TRT_VERSION_GE(6000)
infer_ptr<nvinfer1::IBuilderConfig> infer_builder_config_;
nvinfer1::IOptimizationProfile* optim_profile_;
......
......@@ -31,6 +31,10 @@ namespace tensorrt {
((NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + \
NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD) >= version)
#define IS_TRT_VERSION_LT(version) \
((NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + \
NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD) < version)
#define TRT_VERSION \
NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 + \
NV_TENSORRT_PATCH * 10 + NV_TENSORRT_BUILD
......
......@@ -68,7 +68,7 @@ TEST_F(TensorRTEngineTest, add_layer) {
TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, size);
TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, size);
auto *x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
nvinfer1::DimsCHW{1, 1, 1});
nvinfer1::Dims3{1, 1, 1});
auto *fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, size,
weight.get(), bias.get());
PADDLE_ENFORCE_NOT_NULL(fc_layer,
......@@ -123,7 +123,7 @@ TEST_F(TensorRTEngineTest, add_layer_multi_dim) {
TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, 4);
TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, 2);
auto *x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
nvinfer1::DimsCHW{1, 2, 1});
nvinfer1::Dims3{1, 2, 1});
auto *fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, 2,
weight.get(), bias.get());
PADDLE_ENFORCE_NOT_NULL(fc_layer,
......
......@@ -80,7 +80,7 @@ nvinfer1::IHostMemory* CreateNetwork() {
nvinfer1::INetworkDefinition* network = builder->createNetwork();
// Add the input
auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT,
nvinfer1::DimsCHW{1, 1, 1});
nvinfer1::Dims3{1, 1, 1});
EXPECT_NE(input, nullptr);
// Add the hidden layer.
auto layer = network->addFullyConnected(*input, 1, weights.get(), bias.get());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册