未验证 提交 758fccfe 编写于 作者: Z Zhang Jun 提交者: GitHub

[inference][trt] dynamic shape support for Instance norm (#47998)

* instance norm support dynamic shape
* update unittest
上级 1b1d6d3f
...@@ -74,10 +74,16 @@ class InstanceNormOpConverter : public OpConverter { ...@@ -74,10 +74,16 @@ class InstanceNormOpConverter : public OpConverter {
bias_v.push_back(bias_d[i]); bias_v.push_back(bias_d[i]);
} }
plugin::InstanceNormPlugin* plugin = nvinfer1::IPluginV2* plugin = nullptr;
new plugin::InstanceNormPlugin(eps, scale_v, bias_v); if (engine_->with_dynamic_shape()) {
plugin->getPluginType(); plugin = new plugin::InstanceNormPluginDynamic(eps, scale_v, bias_v);
auto* layer = engine_->AddPlugin(&input, 1, plugin); } else {
plugin = new plugin::InstanceNormPlugin(eps, scale_v, bias_v);
}
std::vector<nvinfer1::ITensor*> instance_norm_inputs{input};
auto* layer = engine_->network()->addPluginV2(
instance_norm_inputs.data(), instance_norm_inputs.size(), *plugin);
auto output_name = op_desc.Output("Y")[0]; auto output_name = op_desc.Output("Y")[0];
RreplenishLayerAndOutput(layer, "instance_norm", {output_name}, test_mode); RreplenishLayerAndOutput(layer, "instance_norm", {output_name}, test_mode);
......
...@@ -1501,10 +1501,6 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -1501,10 +1501,6 @@ struct SimpleOpTypeSetTeller : public Teller {
} }
if (op_type == "instance_norm") { if (op_type == "instance_norm") {
if (with_dynamic_shape) {
VLOG(3) << "trt instance_norm op does not support dynamic shape ";
return false;
}
if (desc.Input("X").size() != 1) { if (desc.Input("X").size() != 1) {
VLOG(3) << "input of instance_norm op converter should be 1, got " VLOG(3) << "input of instance_norm op converter should be 1, got "
<< desc.Input("X").size(); << desc.Input("X").size();
......
...@@ -131,6 +131,115 @@ int InstanceNormPlugin::enqueue(int batch_size, ...@@ -131,6 +131,115 @@ int InstanceNormPlugin::enqueue(int batch_size,
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
} }
int InstanceNormPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
nvinfer1::DimsExprs InstanceNormPluginDynamic::getOutputDimensions(
int index,
const nvinfer1::DimsExprs *inputs,
int nbInputs,
nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
assert(nbInputs == 1);
assert(index < this->getNbOutputs());
nvinfer1::DimsExprs output(inputs[0]);
return output;
}
bool InstanceNormPluginDynamic::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc *inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT {
assert(inOut && pos < (nbInputs + nbOutputs));
assert(pos == 0 || pos == 1);
return ((inOut[pos].type == nvinfer1::DataType::kFLOAT ||
inOut[pos].type == nvinfer1::DataType::kHALF) &&
(inOut[pos].format == nvinfer1::PluginFormat::kLINEAR) &&
inOut[pos].type == inOut[0].type);
}
int InstanceNormPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream) TRT_NOEXCEPT {
nvinfer1::Dims input_dims = inputDesc[0].dims;
int n = input_dims.d[0];
int c = input_dims.d[1];
int h = input_dims.d[2];
int w = input_dims.d[3];
scale_t.Resize(phi::make_ddim({n, c}));
bias_t.Resize(phi::make_ddim({n, c}));
int device_id;
cudaGetDevice(&device_id);
float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
for (int i = 0; i < n; i++) {
cudaMemcpyAsync(scale_d + i * c,
scale_.data(),
sizeof(float) * c,
cudaMemcpyHostToDevice,
stream);
cudaMemcpyAsync(bias_d + i * c,
bias_.data(),
sizeof(float) * c,
cudaMemcpyHostToDevice,
stream);
}
platform::dynload::cudnnSetTensor4dDescriptor(
b_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, n * c, 1, 1);
cudnnDataType_t cudnn_dtype;
auto data_type = inputDesc[0].type;
convert_trt2cudnn_dtype(data_type, &cudnn_dtype);
platform::dynload::cudnnSetTensor4dDescriptor(
x_desc_, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w);
platform::dynload::cudnnSetTensor4dDescriptor(
y_desc_, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w);
float alpha = 1;
float beta = 0;
platform::dynload::cudnnSetStream(handle_, stream);
void const *x_ptr = inputs[0];
void *y_ptr = outputs[0];
platform::dynload::cudnnBatchNormalizationForwardTraining(
handle_,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT,
&alpha,
&beta,
x_desc_,
x_ptr,
y_desc_,
y_ptr,
b_desc_,
scale_d,
bias_d,
1.,
nullptr,
nullptr,
eps_,
nullptr,
nullptr);
return cudaGetLastError() != cudaSuccess;
}
nvinfer1::DataType InstanceNormPluginDynamic::getOutputDataType(
int index,
const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT {
assert(inputTypes && nbInputs > 0 && index == 0);
return inputTypes[0];
}
void InstanceNormPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT {}
} // namespace plugin } // namespace plugin
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
......
...@@ -99,7 +99,7 @@ class InstanceNormPlugin : public PluginTensorRT { ...@@ -99,7 +99,7 @@ class InstanceNormPlugin : public PluginTensorRT {
} }
const char *getPluginType() const TRT_NOEXCEPT override { const char *getPluginType() const TRT_NOEXCEPT override {
return "instance_norm_plugin"; return "instance_norm";
} }
int getNbOutputs() const TRT_NOEXCEPT override { return 1; } int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
nvinfer1::Dims getOutputDimensions(int index, nvinfer1::Dims getOutputDimensions(int index,
...@@ -125,7 +125,7 @@ class InstanceNormPlugin : public PluginTensorRT { ...@@ -125,7 +125,7 @@ class InstanceNormPlugin : public PluginTensorRT {
class InstanceNormPluginCreator : public TensorRTPluginCreator { class InstanceNormPluginCreator : public TensorRTPluginCreator {
public: public:
const char *getPluginName() const TRT_NOEXCEPT override { const char *getPluginName() const TRT_NOEXCEPT override {
return "instance_norm_plugin"; return "instance_norm";
} }
const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; } const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
...@@ -137,7 +137,137 @@ class InstanceNormPluginCreator : public TensorRTPluginCreator { ...@@ -137,7 +137,137 @@ class InstanceNormPluginCreator : public TensorRTPluginCreator {
return new InstanceNormPlugin(serial_data, serial_length); return new InstanceNormPlugin(serial_data, serial_length);
} }
}; };
class InstanceNormPluginDynamic : public DynamicPluginTensorRT {
private:
float eps_;
std::vector<float> scale_;
std::vector<float> bias_;
phi::DenseTensor scale_t;
phi::DenseTensor bias_t;
cudnnHandle_t handle_;
cudnnTensorDescriptor_t x_desc_, y_desc_, b_desc_;
public:
size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(eps_) + SerializedSize(scale_) +
SerializedSize(bias_);
}
// TRT will call this func when we need to serialize the configuration of
// tensorrt.
// It should not be called by users.
void serialize(void *buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, eps_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, bias_);
}
explicit InstanceNormPluginDynamic(const float eps,
const std::vector<float> scale,
const std::vector<float> bias)
: eps_(eps), scale_(scale), bias_(bias) {
PADDLE_ENFORCE_EQ(scale.size(),
bias.size(),
platform::errors::InvalidArgument(
"The instanceNorm's scale and bias should be the "
"same size. Got scale size = %d, but bias size = %d",
scale.size(),
bias.size()));
platform::dynload::cudnnCreate(&handle_);
platform::dynload::cudnnCreateTensorDescriptor(&x_desc_);
platform::dynload::cudnnCreateTensorDescriptor(&y_desc_);
platform::dynload::cudnnCreateTensorDescriptor(&b_desc_);
}
// It was used for tensorrt deserialization.
// It should not be called by users.
InstanceNormPluginDynamic(void const *serialData, size_t serialLength) {
DeserializeValue(&serialData, &serialLength, &eps_);
DeserializeValue(&serialData, &serialLength, &scale_);
DeserializeValue(&serialData, &serialLength, &bias_);
platform::dynload::cudnnCreate(&handle_);
platform::dynload::cudnnCreateTensorDescriptor(&x_desc_);
platform::dynload::cudnnCreateTensorDescriptor(&y_desc_);
platform::dynload::cudnnCreateTensorDescriptor(&b_desc_);
}
~InstanceNormPluginDynamic() {
platform::dynload::cudnnDestroy(handle_);
platform::dynload::cudnnDestroyTensorDescriptor(x_desc_);
platform::dynload::cudnnDestroyTensorDescriptor(y_desc_);
platform::dynload::cudnnDestroyTensorDescriptor(b_desc_);
}
int initialize() TRT_NOEXCEPT override;
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override {
return new InstanceNormPluginDynamic(eps_, scale_, bias_);
}
const char *getPluginType() const TRT_NOEXCEPT override {
return "instance_norm_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs *inputs,
int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) // NOLINT
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc *inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType *inputTypes,
int nbInputs) const
TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override { delete this; }
};
class InstanceNormPluginDynamicCreator : public TensorRTPluginCreator {
public:
const char *getPluginName() const TRT_NOEXCEPT override {
return "instance_norm_dynamic";
}
const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
nvinfer1::IPluginV2 *deserializePlugin(const char *name,
const void *serial_data,
size_t serial_length)
TRT_NOEXCEPT override {
return new InstanceNormPluginDynamic(serial_data, serial_length);
}
};
REGISTER_TRT_PLUGIN_V2(InstanceNormPluginCreator); REGISTER_TRT_PLUGIN_V2(InstanceNormPluginCreator);
REGISTER_TRT_PLUGIN_V2(InstanceNormPluginDynamicCreator);
} // namespace plugin } // namespace plugin
} // namespace tensorrt } // namespace tensorrt
......
...@@ -50,7 +50,13 @@ class TrtConvertInstanceNormTest(TrtLayerAutoScanTest): ...@@ -50,7 +50,13 @@ class TrtConvertInstanceNormTest(TrtLayerAutoScanTest):
[batch, 16, 32, 64], [batch, 16, 32, 64],
]: ]:
self.in_dim = len(shape_input) self.in_dim = len(shape_input)
for epsilon in [0.0005, -1, 1]: for epsilon in [
0.0005,
-1,
1,
0.000009999999747378752,
0.00001,
]:
dics = [{"epsilon": epsilon}] dics = [{"epsilon": epsilon}]
ops_config = [ ops_config = [
{ {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册