未验证 提交 b76c2dc4 编写于 作者: G gaoziyuan 提交者: GitHub

fix trt layer output type (#51140)

上级 ec51485f
......@@ -44,16 +44,20 @@ class CastOpConverter : public OpConverter {
switch (out_dtype) {
case 0: // BOOL = 0
layer->setOutputType(0, nvinfer1::DataType::kBOOL);
layer->getOutput(0)->setType(nvinfer1::DataType::kBOOL);
break;
case 2: // INT32 = 2
case 3: // INT64 = 3 there is no int64 in tensorrt subgraph
layer->setOutputType(0, nvinfer1::DataType::kINT32);
layer->getOutput(0)->setType(nvinfer1::DataType::kINT32);
break;
case 4: // FP16 = 4
layer->setOutputType(0, nvinfer1::DataType::kHALF);
layer->getOutput(0)->setType(nvinfer1::DataType::kHALF);
break;
case 5: // FP32 = 5
layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT);
break;
default:
LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册