提交 1708ab2e 编写于 作者: M Megvii Engine Team

feat(mgb): add tensorrt runtime dynamic batch testcase

GitOrigin-RevId: 36372437ff4ece327331641687c5f0c146e664ec
上级 87c845fd
......@@ -483,6 +483,78 @@ std::pair<nvinfer1::IBuilder*, INetworkDefinition*> intl::ReshapeConcatTensorRTN
return std::make_pair(builder, network);
}
#if NV_TENSOR_RT_VERSION >= 6001
intl::DynamicShapeTensorRTNetwork::DynamicShapeTensorRTNetwork(
size_t n, size_t c, size_t h, size_t w) {
host_x = gen({n, c, h, w});
host_w1 = gen({32, 23, 3, 3});
host_b1 = gen({1, 32, 1, 1});
graph = ComputingGraph::make();
x = Host2DeviceCopy::make(*graph, host_x);
auto w1 = Host2DeviceCopy::make(*graph, host_w1),
b1 = Host2DeviceCopy::make(*graph, host_b1),
y01 = opr::Convolution::make(x, w1);
y1 = y01 + b1;
}
TensorRTUniquePtr<ICudaEngine> intl::DynamicShapeTensorRTNetwork::create_trt_network() {
CompNode::load("xpu0").activate();
Weights wt_filter_1{DataType::kFLOAT, nullptr, 0},
wt_bias_1{DataType::kFLOAT, nullptr, 0};
wt_filter_1.type = DataType::kFLOAT;
wt_bias_1.type = DataType::kFLOAT;
wt_filter_1.values = host_w1->raw_ptr();
wt_bias_1.values = host_b1->raw_ptr();
wt_filter_1.count = host_w1->shape().total_nr_elems();
wt_bias_1.count = host_b1->shape().total_nr_elems();
auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
auto network = builder->createNetworkV2(
1 << static_cast<int>(
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
nvinfer1::ITensor* data;
data = network->addInput("data", DataType::kFLOAT, Dims4{-1, 23, -1, -1});
nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
nvinfer1::IOptimizationProfile* profile = builder->createOptimizationProfile();
profile->setDimensions(
"data", nvinfer1::OptProfileSelector::kMIN, Dims4(3, 23, 16, 16));
profile->setDimensions(
"data", nvinfer1::OptProfileSelector::kOPT, Dims4(4, 23, 24, 24));
profile->setDimensions(
"data", nvinfer1::OptProfileSelector::kMAX, Dims4(5, 23, 28, 28));
config->addOptimizationProfile(profile);
{
nvinfer1::TensorFormats formats =
1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
data->setAllowedFormats(formats);
}
mgb_assert(data != nullptr, "data is invalid");
auto conv1 =
network->addConvolution(*data, 32, DimsHW{3, 3}, wt_filter_1, wt_bias_1);
mgb_assert(conv1 != nullptr, "conv1 is invalid");
conv1->setStride(DimsHW{1, 1});
conv1->getOutput(0)->setName("prob1");
network->markOutput(*conv1->getOutput(0));
{
nvinfer1::TensorFormats formats =
1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR);
conv1->getOutput(0)->setAllowedFormats(formats);
}
TensorRTUniquePtr<ICudaEngine> cuda_engine{
builder->buildEngineWithConfig(*network, *config)};
return cuda_engine;
}
#endif
#pragma GCC diagnostic pop
#endif // MGB_ENABLE_TENSOR_RT
......
......@@ -104,6 +104,19 @@ struct ReshapeConcatTensorRTNetwork {
bool has_batch_dim);
};
#if NV_TENSOR_RT_VERSION >= 6001
struct DynamicShapeTensorRTNetwork {
HostTensorGenerator<> gen;
std::shared_ptr<HostTensorND> host_x, host_w1, host_b1;
std::shared_ptr<ComputingGraph> graph;
SymbolVar x, y1;
DynamicShapeTensorRTNetwork(size_t n, size_t c, size_t h, size_t w);
TensorRTUniquePtr<ICudaEngine> create_trt_network();
};
#endif
} // namespace intl
} // namespace opr
} // namespace mgb
......
......@@ -307,6 +307,32 @@ TEST(TestOprTensorRT, ICudaEngine) {
func->execute();
}
#if NV_TENSOR_RT_VERSION >= 6001
TEST(TestOprTensorRT, RuntimeDynamicShape) {
REQUIRE_GPU(1);
intl::DynamicShapeTensorRTNetwork net1{5, 23, 26, 26}, net2{4, 23, 24, 24};
auto make_trt = [](intl::DynamicShapeTensorRTNetwork& net) {
TensorRTUniquePtr<ICudaEngine> cuda_engine = net.create_trt_network();
TensorRTUniquePtr<IHostMemory> mem{cuda_engine->serialize(), {}};
return TensorRTRuntimeOpr::make(mem->data(), mem->size(), {net.x});
};
HostTensorND host_z1, host_z2;
auto y1 = make_trt(net1);
auto func1 = net1.graph->compile(
{make_callback_copy(net1.y1, host_z1), make_callback_copy(y1[0], host_z2)});
func1->execute();
MGB_ASSERT_TENSOR_NEAR(host_z1, host_z2, 5e-4);
auto y2 = make_trt(net2);
auto func2 = net2.graph->compile(
{make_callback_copy(net2.y1, host_z1), make_callback_copy(y2[0], host_z2)});
func2->execute();
MGB_ASSERT_TENSOR_NEAR(host_z1, host_z2, 5e-4);
}
#endif
#endif // MGB_ENABLE_TENSOR_RT
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册