提交 ab9dfbce 编写于 作者: M Megvii Engine Team

test(mgb): fix tensorrt tests missing cudaSetDevice

GitOrigin-RevId: faeb6ae070082085c49d552a48d949b01f5dcc10
上级 b43fb1a9
......@@ -46,6 +46,7 @@ intl::SimpleTensorRTNetwork::SimpleTensorRTNetwork() {
std::pair<nvinfer1::IBuilder*, INetworkDefinition*>
intl::SimpleTensorRTNetwork::create_trt_network(bool has_batch_dim) {
CompNode::load("xpu0").activate();
Weights wt_filter{DataType::kFLOAT, nullptr, 0},
wt_bias{DataType::kFLOAT, nullptr, 0};
wt_filter.type = DataType::kFLOAT;
......@@ -205,6 +206,7 @@ intl::SimpleQuantizedTensorRTNetwork::SimpleQuantizedTensorRTNetwork() {
std::pair<nvinfer1::IBuilder*, INetworkDefinition*>
intl::SimpleQuantizedTensorRTNetwork::create_trt_network(
bool has_batch_dim) {
CompNode::load("xpu0").activate();
Weights wt_filter{DataType::kFLOAT, nullptr, 0},
wt_bias{DataType::kFLOAT, nullptr, 0};
wt_filter.type = DataType::kFLOAT;
......@@ -290,6 +292,7 @@ intl::ConcatConvTensorRTNetwork::ConcatConvTensorRTNetwork() {
std::pair<nvinfer1::IBuilder*, INetworkDefinition*>
intl::ConcatConvTensorRTNetwork::create_trt_network(bool has_batch_dim) {
CompNode::load("xpu0").activate();
auto builder = createInferBuilder(TensorRTOpr::Logger::instance());
#if NV_TENSOR_RT_VERSION >= 6001
nvinfer1::NetworkDefinitionCreationFlags flags;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册