diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index f7c12a5cc84fa83a3a17dbfab180e39fb6b46a7a..7d4d44219c884342a348c76383ac456ef9f45095 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -938,6 +938,7 @@ USE_TRT_CONVERTER(conv2d_transpose);
 USE_TRT_CONVERTER(leaky_relu);
 USE_TRT_CONVERTER(shuffle_channel);
 USE_TRT_CONVERTER(swish);
+USE_TRT_CONVERTER(instance_norm);
 USE_TRT_CONVERTER(layer_norm);
 USE_TRT_CONVERTER(gelu);
 USE_TRT_CONVERTER(multihead_matmul);
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index e212388cb9ec13c3c0a050c814193cf5a6bd3107..ab9fffffe6413014937f5d00236360efc89cb905 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -3,7 +3,7 @@ nv_library(tensorrt_converter
   SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
   batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
   pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
-  shuffle_channel_op.cc swish_op.cc
+  shuffle_channel_op.cc swish_op.cc instance_norm_op.cc
   DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
 
 nv_test(test_op_converter SRCS test_op_converter.cc DEPS
diff --git a/paddle/fluid/inference/tensorrt/convert/instance_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/instance_norm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2b6cbb904afac5f3038629a1f658b34e09853e4f
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/instance_norm_op.cc
@@ -0,0 +1,75 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class InstanceNormOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    VLOG(4) << "convert fluid instance_norm op to tensorrt instance_norm "
+               "plugin layer";
+
+    framework::OpDesc op_desc(op, nullptr);
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+
+    float eps = boost::get<float>(op_desc.GetAttr("epsilon"));
+
+    auto* scale_var = scope.FindVar(op_desc.Input("Scale")[0]);
+    auto* bias_var = scope.FindVar(op_desc.Input("Bias")[0]);
+    PADDLE_ENFORCE_NOT_NULL(
+        scale_var,
+        platform::errors::InvalidArgument(
+            "Input [Scale] of instance_norm op converter should not be null"));
+    PADDLE_ENFORCE_NOT_NULL(
+        bias_var,
+        platform::errors::InvalidArgument(
+            "Input [Bias] of instance_norm op converter should not be null"));
+    auto* scale_tensor = scale_var->GetMutable<framework::LoDTensor>();
+    auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
+    PADDLE_ENFORCE_EQ(
+        scale_tensor->numel(), bias_tensor->numel(),
+        platform::errors::InvalidArgument(
+            "Num of input [Scale] and [Bias] of instance_norm op converter "
+            "should be equal. Got Scale num = %ld, but Bias num = %ld",
+            scale_tensor->numel(), bias_tensor->numel()));
+    auto* scale_d = scale_tensor->data<float>();
+    auto* bias_d = bias_tensor->data<float>();
+
+    std::vector<float> scale_v;
+    std::vector<float> bias_v;
+    for (int i = 0; i < scale_tensor->numel(); i++) {
+      scale_v.push_back(scale_d[i]);
+      bias_v.push_back(bias_d[i]);
+    }
+
+    plugin::InstanceNormPlugin* plugin =
+        new plugin::InstanceNormPlugin(eps, scale_v, bias_v);
+    nvinfer1::IPluginLayer* layer = engine_->AddPlugin(&input, 1, plugin);
+
+    auto output_name = op_desc.Output("Y")[0];
+    RreplenishLayerAndOutput(layer, "instance_norm", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(instance_norm, InstanceNormOpConverter);
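For context (not part of the patch): instance normalization normalizes each (sample, channel) slice of an NCHW tensor with that slice's own mean and variance, then applies the per-channel Scale/Bias that the converter above reads from the scope. A minimal CPU sketch of those semantics — names and layout here are illustrative, not Paddle API — looks like:

```cpp
#include <cmath>
#include <vector>

// Reference semantics of instance_norm on an NCHW float tensor.
// y must be pre-sized to n * c * h * w, like x.
void InstanceNormRef(const std::vector<float> &x, std::vector<float> *y,
                     const std::vector<float> &scale,
                     const std::vector<float> &bias, int n, int c, int h,
                     int w, float eps) {
  const int hw = h * w;
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < c; ++j) {
      const float *xp = x.data() + (i * c + j) * hw;
      float *yp = y->data() + (i * c + j) * hw;
      // Statistics are computed per (sample, channel) slice; this is what
      // distinguishes instance norm from batch norm.
      float mean = 0.f, var = 0.f;
      for (int k = 0; k < hw; ++k) mean += xp[k];
      mean /= hw;
      for (int k = 0; k < hw; ++k) var += (xp[k] - mean) * (xp[k] - mean);
      var /= hw;
      const float inv_std = 1.f / std::sqrt(var + eps);
      for (int k = 0; k < hw; ++k)
        yp[k] = scale[j] * (xp[k] - mean) * inv_std + bias[j];
    }
  }
}
```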
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 462c6fb497081f47c7c5bacc2298f9b8d3d46824..39cdf5ba1af4b3d0fbfee3802bd614619a4b9580 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -32,30 +32,33 @@ struct SimpleOpTypeSetTeller : public Teller {
   }
 
  private:
-  std::unordered_set<std::string> teller_set{{"mul",
-                                              "conv2d",
-                                              "pool2d",
-                                              "relu",
-                                              "softmax",
-                                              "sigmoid",
-                                              "depthwise_conv2d",
-                                              "batch_norm",
-                                              "concat",
-                                              "tanh",
-                                              "pad",
-                                              "elementwise_add",
-                                              "elementwise_mul",
-                                              "dropout",
-                                              "prelu",
-                                              "conv2d_transpose",
-                                              "leaky_relu",
-                                              "fc",
-                                              "shuffle_channel",
-                                              "swish",
-                                              "split",
-                                              "gelu",
-                                              "layer_norm",
-                                              "multihead_matmul"}};
+  std::unordered_set<std::string> teller_set{{
+      "mul",
+      "conv2d",
+      "pool2d",
+      "relu",
+      "softmax",
+      "sigmoid",
+      "depthwise_conv2d",
+      "batch_norm",
+      "concat",
+      "tanh",
+      "pad",
+      "elementwise_add",
+      "elementwise_mul",
+      "dropout",
+      "prelu",
+      "conv2d_transpose",
+      "leaky_relu",
+      "fc",
+      "shuffle_channel",
+      "swish",
+      "split",
+      "instance_norm",
+      "gelu",
+      "layer_norm",
+      "multihead_matmul",
+  }};
 };
 
 bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) {
diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
index 83efecc0bf92568fb692f024694b2d79fe942310..68c59685fadc766f0cfc0ac2ac811a6c6e6f85b3 100644
--- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -1,5 +1,5 @@
 nv_library(tensorrt_plugin
   SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
   prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
-  pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
-  DEPS enforce tensorrt_engine prelu)
+  pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu instance_norm_op_plugin.cu
+  DEPS enforce tensorrt_engine prelu tensor)
diff --git a/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.cu
new file mode 100644
index 0000000000000000000000000000000000000000..835dc4ac30e0b52e39dca11756dac3f391ca2846
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.cu
@@ -0,0 +1,119 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+#include <cassert>
+#include <vector>
+#include "glog/logging.h"
+#include "paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
+#include "paddle/fluid/platform/cudnn_helper.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype,
+                                      cudnnDataType_t *cudnn_dtype) {
+  switch (trt_dtype) {
+    case nvinfer1::DataType::kFLOAT:
+      *cudnn_dtype = CUDNN_DATA_FLOAT;
+      break;
+    case nvinfer1::DataType::kHALF:
+      *cudnn_dtype = CUDNN_DATA_HALF;
+      break;
+    default:
+      return CUDNN_STATUS_BAD_PARAM;
+  }
+  return CUDNN_STATUS_SUCCESS;
+}
+
+InstanceNormPlugin *CreateInstanceNormPluginDeserialize(const void *buffer,
+                                                        size_t length) {
+  return new InstanceNormPlugin(buffer, length);
+}
+REGISTER_TRT_PLUGIN("instance_norm_plugin",
+                    CreateInstanceNormPluginDeserialize);
+
+int InstanceNormPlugin::initialize() {
+  platform::dynload::cudnnCreate(&handle_);
+  platform::dynload::cudnnCreateTensorDescriptor(&x_desc_);
+  platform::dynload::cudnnCreateTensorDescriptor(&y_desc_);
+  platform::dynload::cudnnCreateTensorDescriptor(&b_desc_);
+  return 0;
+}
+
+nvinfer1::Dims InstanceNormPlugin::getOutputDimensions(
+    int index, const nvinfer1::Dims *inputDims, int nbInputs) {
+  assert(nbInputs == 1);
+  assert(index < this->getNbOutputs());
+  nvinfer1::Dims const &input_dims = inputDims[0];
+  nvinfer1::Dims output_dims = input_dims;
+  return output_dims;
+}
+
+int InstanceNormPlugin::enqueue(int batch_size, const void *const *inputs,
+                                void **outputs, void *workspace,
+                                cudaStream_t stream) {
+  const auto &input_dims = this->getInputDims(0);
+
+  PADDLE_ENFORCE_EQ(input_dims.nbDims, 3,
+                    platform::errors::InvalidArgument(
+                        "Input dims should be 3 (excluding the batch "
+                        "dimension), got %d",
+                        input_dims.nbDims));
+  int n = batch_size;
+  int c = input_dims.d[0];
+  int h = input_dims.d[1];
+  int w = input_dims.d[2];
+
+  scale_t.Resize(framework::make_ddim({batch_size, c}));
+  bias_t.Resize(framework::make_ddim({batch_size, c}));
+  int device_id;
+  cudaGetDevice(&device_id);
+  float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
+  float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
+
+  // Tile the C-element scale/bias once per sample; see the note below.
+  for (int i = 0; i < batch_size; i++) {
+    cudaMemcpyAsync(scale_d + i * c, scale_.data(), sizeof(float) * c,
+                    cudaMemcpyHostToDevice, stream);
+    cudaMemcpyAsync(bias_d + i * c, bias_.data(), sizeof(float) * c,
+                    cudaMemcpyHostToDevice, stream);
+  }
+  platform::dynload::cudnnSetTensor4dDescriptor(
+      b_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, n * c, 1, 1);
+
+  cudnnDataType_t cudnn_dtype;
+  nvinfer1::DataType data_type = getDataType();
+  convert_trt2cudnn_dtype(data_type, &cudnn_dtype);
+  platform::dynload::cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW,
+                                                cudnn_dtype, 1, n * c, h, w);
+  platform::dynload::cudnnSetTensor4dDescriptor(y_desc_, CUDNN_TENSOR_NCHW,
+                                                cudnn_dtype, 1, n * c, h, w);
+  float alpha = 1;
+  float beta = 0;
+  platform::dynload::cudnnSetStream(handle_, stream);
+
+  void const *x_ptr = inputs[0];
+  void *y_ptr = outputs[0];
+  platform::dynload::cudnnBatchNormalizationForwardTraining(
+      handle_, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, &alpha, &beta, x_desc_,
+      x_ptr, y_desc_, y_ptr, b_desc_, scale_d, bias_d, 1., nullptr, nullptr,
+      eps_, nullptr, nullptr);
+  // enqueue must return an int status; report any pending CUDA error.
+  return cudaGetLastError() != cudaSuccess;
+}
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
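Note on the enqueue above: cuDNN batch norm in training mode computes one mean/variance per channel from the current input, so viewing the (N, C, H, W) input as (1, N*C, H, W) yields one statistic per (sample, channel) slice — exactly instance norm. That is why x_desc_/y_desc_/b_desc_ are built with batch 1 and N*C channels, and why the cudaMemcpyAsync loop repeats the C-element Scale/Bias once per sample. The same tiling in isolation, as a hypothetical helper that is not part of the patch:

```cpp
#include <vector>

// Repeat the per-channel parameters once per sample so that channel index
// (i * c + j) in the folded (1, N*C, H, W) layout picks up channel j's value.
std::vector<float> TileAcrossBatch(const std::vector<float> &v, int n) {
  std::vector<float> out;
  out.reserve(v.size() * n);
  for (int i = 0; i < n; ++i) out.insert(out.end(), v.begin(), v.end());
  return out;
}
```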
diff --git a/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h
new file mode 100644
index 0000000000000000000000000000000000000000..3f5023f369510b5d89a9896605e77a2b152bdb39
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/instance_norm_op_plugin.h
@@ -0,0 +1,103 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/inference/tensorrt/engine.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+class InstanceNormPlugin : public PluginTensorRT {
+ private:
+  float eps_;
+  std::vector<float> scale_;
+  std::vector<float> bias_;
+
+  framework::Tensor scale_t;
+  framework::Tensor bias_t;
+  cudnnHandle_t handle_;
+  cudnnTensorDescriptor_t x_desc_, y_desc_, b_desc_;
+
+ protected:
+  size_t getSerializationSize() override {
+    return getBaseSerializationSize() + SerializedSize(eps_) +
+           SerializedSize(scale_) + SerializedSize(bias_);
+  }
+
+  // TRT calls this method when it serializes the engine.
+  // It should not be called by users.
+  void serialize(void *buffer) override {
+    SerializeValue(&buffer, getPluginType());
+    serializeBase(buffer);
+    SerializeValue(&buffer, eps_);
+    SerializeValue(&buffer, scale_);
+    SerializeValue(&buffer, bias_);
+  }
+
+ public:
+  explicit InstanceNormPlugin(const float eps, const std::vector<float> scale,
+                              const std::vector<float> bias)
+      : eps_(eps), scale_(scale), bias_(bias) {
+    PADDLE_ENFORCE_EQ(scale.size(), bias.size(),
+                      platform::errors::InvalidArgument(
+                          "The instance norm scale and bias should have the "
+                          "same size. Got scale size = %d, but bias size = %d",
+                          scale.size(), bias.size()));
+  }
+
+  // Used for TensorRT deserialization.
+  // It should not be called by users.
+  InstanceNormPlugin(void const *serialData, size_t serialLength) {
+    deserializeBase(serialData, serialLength);
+    DeserializeValue(&serialData, &serialLength, &eps_);
+    DeserializeValue(&serialData, &serialLength, &scale_);
+    DeserializeValue(&serialData, &serialLength, &bias_);
+  }
+
+  ~InstanceNormPlugin() {}
+  int initialize() override;
+
+  InstanceNormPlugin *clone() const override {
+    return new InstanceNormPlugin(eps_, scale_, bias_);
+  }
+
+  const char *getPluginType() const override { return "instance_norm_plugin"; }
+  int getNbOutputs() const override { return 1; }
+  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
+                                     int nbInputDims) override;
+  int enqueue(int batchSize, const void *const *inputs, void **outputs,
+              void *workspace, cudaStream_t stream) override;
+
+  bool supportsFormat(nvinfer1::DataType type,
+                      nvinfer1::PluginFormat format) const override {
+    return ((type == nvinfer1::DataType::kFLOAT ||
+             type == nvinfer1::DataType::kHALF) &&
+            (format == nvinfer1::PluginFormat::kNCHW));
+  }
+};
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
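One detail worth flagging in the header above: serialize() and the deserializing constructor share no schema, only field order (plugin type, base fields, eps_, scale_, bias_), so the two member lists must be kept in sync by hand. A simplified, self-contained sketch of that order-dependent pattern — Append/Read are illustrative stand-ins, not Paddle's actual SerializeValue/DeserializeValue wire format:

```cpp
#include <cstring>
#include <vector>

// Writer: append the raw bytes of a trivially copyable value.
template <typename T>
void Append(std::vector<char> *buf, const T &v) {
  const char *p = reinterpret_cast<const char *>(&v);
  buf->insert(buf->end(), p, p + sizeof(T));
}

// Reader: consume the same bytes in the same order, advancing the cursor
// just as DeserializeValue advances serialData/serialLength.
template <typename T>
void Read(const char **cursor, T *v) {
  std::memcpy(v, *cursor, sizeof(T));
  *cursor += sizeof(T);
}
```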
diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index d58606be0bf27a4f240fe2d8acb137bc469d34d8..bbcd4f2136b2e9555315dca8f3cf6bea012448c7 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -321,6 +321,10 @@ if(WITH_GPU AND TENSORRT_FOUND)
   if (NOT EXISTS ${TEST_SPLIT_CONVERTER_MODEL})
     inference_download_and_uncompress(${TEST_SPLIT_CONVERTER_MODEL} ${INFERENCE_URL}/tensorrt_test "split_converter.tgz")
   endif()
+  set(TEST_INSTANCE_NORM_MODEL "${TRT_MODEL_INSTALL_DIR}/trt_instance_norm_test")
+  if (NOT EXISTS ${TEST_INSTANCE_NORM_MODEL})
+    inference_download_and_uncompress(${TEST_INSTANCE_NORM_MODEL} ${INFERENCE_URL}/tensorrt_test "instance_norm.tgz")
+  endif()
   inference_analysis_test(trt_mobilenet_test SRCS trt_mobilenet_test.cc
           EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
           ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
@@ -342,6 +346,9 @@ if(WITH_GPU AND TENSORRT_FOUND)
   inference_analysis_test(trt_split_converter_test SRCS trt_split_converter_test.cc
           EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
           ARGS --infer_model=${TEST_SPLIT_CONVERTER_MODEL}/)
+  inference_analysis_test(trt_instance_norm_test SRCS trt_instance_norm_converter_test.cc
+          EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
+          ARGS --infer_model=${TEST_INSTANCE_NORM_MODEL}/)
   inference_analysis_test(test_analyzer_capi_gpu SRCS analyzer_capi_gpu_tester.cc
           EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
           ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
diff --git a/paddle/fluid/inference/tests/api/trt_instance_norm_converter_test.cc b/paddle/fluid/inference/tests/api/trt_instance_norm_converter_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..759c7b260f0096db6dd59e7694c957bd1147fa5c
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/trt_instance_norm_converter_test.cc
@@ -0,0 +1,50 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gflags/gflags.h>
+#include <glog/logging.h>
+#include <gtest/gtest.h>
+
+#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
+
+namespace paddle {
+namespace inference {
+
+TEST(TensorRT, instance_norm) {
+  std::string model_dir = FLAGS_infer_model + "/instance_norm";
+  AnalysisConfig config;
+  int batch_size = 4;
+  config.EnableUseGpu(100, 0);
+  config.SetModel(model_dir);
+  config.SwitchUseFeedFetchOps(false);
+  config.EnableTensorRtEngine(1 << 20, batch_size, 0,
+                              AnalysisConfig::Precision::kFloat32, false);
+
+  auto predictor = CreatePaddlePredictor(config);
+
+  int length = 4;
+  int input_num = batch_size * length;
+  float *input = new float[input_num];
+  // memset writes bytes, not floats; fill the buffer with 1.0f explicitly.
+  for (int i = 0; i < input_num; ++i) input[i] = 1.0f;
+
+  auto input_names = predictor->GetInputNames();
+  auto input_t = predictor->GetInputTensor(input_names[0]);
+  input_t->Reshape({batch_size, length});
+  input_t->copy_from_cpu(input);
+
+  ASSERT_TRUE(predictor->ZeroCopyRun());
+}
+
+}  // namespace inference
+}  // namespace paddle
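Since the plugin's supportsFormat also accepts nvinfer1::DataType::kHALF, a natural follow-up test (not included in this patch) would build the engine at FP16; only the precision argument in the test above changes:

```cpp
// Hypothetical FP16 variant of the engine setup in the test above.
config.EnableTensorRtEngine(1 << 20, batch_size, 0,
                            AnalysisConfig::Precision::kHalf, false);
```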