未验证 提交 033d736d 编写于 作者: Z zlsh80826 提交者: GitHub

fix output data type selection (#34040)

上级 0a9ad8d7
...@@ -219,7 +219,7 @@ const char* AnchorGeneratorPlugin::getPluginNamespace() const { ...@@ -219,7 +219,7 @@ const char* AnchorGeneratorPlugin::getPluginNamespace() const {
nvinfer1::DataType AnchorGeneratorPlugin::getOutputDataType( nvinfer1::DataType AnchorGeneratorPlugin::getOutputDataType(
int index, const nvinfer1::DataType* input_type, int nb_inputs) const { int index, const nvinfer1::DataType* input_type, int nb_inputs) const {
return data_type_; return input_type[0];
} }
bool AnchorGeneratorPlugin::isOutputBroadcastAcrossBatch( bool AnchorGeneratorPlugin::isOutputBroadcastAcrossBatch(
...@@ -460,7 +460,7 @@ int AnchorGeneratorPluginDynamic::enqueue( ...@@ -460,7 +460,7 @@ int AnchorGeneratorPluginDynamic::enqueue(
nvinfer1::DataType AnchorGeneratorPluginDynamic::getOutputDataType( nvinfer1::DataType AnchorGeneratorPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const { int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
return data_type_; return inputTypes[0];
} }
const char* AnchorGeneratorPluginDynamic::getPluginType() const { const char* AnchorGeneratorPluginDynamic::getPluginType() const {
......
...@@ -304,7 +304,7 @@ int RoiAlignPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, ...@@ -304,7 +304,7 @@ int RoiAlignPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
nvinfer1::DataType RoiAlignPluginDynamic::getOutputDataType( nvinfer1::DataType RoiAlignPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const { int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
return data_type_; return inputTypes[0];
} }
const char* RoiAlignPluginDynamic::getPluginType() const { const char* RoiAlignPluginDynamic::getPluginType() const {
......
...@@ -299,7 +299,7 @@ const char* YoloBoxPlugin::getPluginNamespace() const { ...@@ -299,7 +299,7 @@ const char* YoloBoxPlugin::getPluginNamespace() const {
nvinfer1::DataType YoloBoxPlugin::getOutputDataType( nvinfer1::DataType YoloBoxPlugin::getOutputDataType(
int index, const nvinfer1::DataType* input_type, int nb_inputs) const { int index, const nvinfer1::DataType* input_type, int nb_inputs) const {
return data_type_; return input_type[0];
} }
bool YoloBoxPlugin::isOutputBroadcastAcrossBatch(int output_index, bool YoloBoxPlugin::isOutputBroadcastAcrossBatch(int output_index,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册