未验证 提交 519df32e 编写于 作者: W Wilber 提交者: GitHub

cherry-pick 34040 (#34228)

上级 a456a1be
...@@ -215,7 +215,7 @@ const char* AnchorGeneratorPlugin::getPluginNamespace() const { ...@@ -215,7 +215,7 @@ const char* AnchorGeneratorPlugin::getPluginNamespace() const {
nvinfer1::DataType AnchorGeneratorPlugin::getOutputDataType( nvinfer1::DataType AnchorGeneratorPlugin::getOutputDataType(
int index, const nvinfer1::DataType* input_type, int nb_inputs) const { int index, const nvinfer1::DataType* input_type, int nb_inputs) const {
return data_type_; return input_type[0];
} }
bool AnchorGeneratorPlugin::isOutputBroadcastAcrossBatch( bool AnchorGeneratorPlugin::isOutputBroadcastAcrossBatch(
...@@ -456,7 +456,7 @@ int AnchorGeneratorPluginDynamic::enqueue( ...@@ -456,7 +456,7 @@ int AnchorGeneratorPluginDynamic::enqueue(
nvinfer1::DataType AnchorGeneratorPluginDynamic::getOutputDataType( nvinfer1::DataType AnchorGeneratorPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const { int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
return data_type_; return inputTypes[0];
} }
const char* AnchorGeneratorPluginDynamic::getPluginType() const { const char* AnchorGeneratorPluginDynamic::getPluginType() const {
......
...@@ -304,7 +304,7 @@ int RoiAlignPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, ...@@ -304,7 +304,7 @@ int RoiAlignPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
nvinfer1::DataType RoiAlignPluginDynamic::getOutputDataType( nvinfer1::DataType RoiAlignPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* inputTypes, int nbInputs) const { int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
return data_type_; return inputTypes[0];
} }
const char* RoiAlignPluginDynamic::getPluginType() const { const char* RoiAlignPluginDynamic::getPluginType() const {
......
...@@ -295,7 +295,7 @@ const char* YoloBoxPlugin::getPluginNamespace() const { ...@@ -295,7 +295,7 @@ const char* YoloBoxPlugin::getPluginNamespace() const {
nvinfer1::DataType YoloBoxPlugin::getOutputDataType( nvinfer1::DataType YoloBoxPlugin::getOutputDataType(
int index, const nvinfer1::DataType* input_type, int nb_inputs) const { int index, const nvinfer1::DataType* input_type, int nb_inputs) const {
return data_type_; return input_type[0];
} }
bool YoloBoxPlugin::isOutputBroadcastAcrossBatch(int output_index, bool YoloBoxPlugin::isOutputBroadcastAcrossBatch(int output_index,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册