From e5bd7eb82eca1eeb83a742e48eea0dd1d284fbab Mon Sep 17 00:00:00 2001
From: Shang Zhizhou
Date: Wed, 16 Jun 2021 10:29:03 +0800
Subject: [PATCH] Add trt layer norm dynamic (#33448)

* 1, remove layernorm dynamic fp16; 2, let reshape out in dynamic shape (#33535)
---
 .../tensorrt/convert/layer_norm_op.cc         |  38 +++--
 paddle/fluid/inference/tensorrt/op_teller.cc  |   2 +-
 .../tensorrt/plugin/layer_norm_op_plugin.cu   | 109 ++++++++++++-
 .../tensorrt/plugin/layer_norm_op_plugin.h    | 149 +++++++++++++++++-
 paddle/fluid/pybind/inference_api.cc          |   1 +
 .../ir/inference/inference_pass_test.py       |   5 +-
 .../ir/inference/test_trt_subgraph_pass.py    |  55 +++++++
 7 files changed, 336 insertions(+), 23 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc
index 0b97b5d87a..de5d3110e1 100644
--- a/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc
@@ -46,13 +46,6 @@ class LayerNormOpConverter : public OpConverter {
   auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
   auto* Scale_t = Scale_v->GetMutable<framework::LoDTensor>();
 
-  int input_num = 1;
-  for (int i = 0; i < X->getDimensions().nbDims; i++) {
-    input_num *= X->getDimensions().d[i];
-  }
-  std::vector<int64_t> mean_shape{input_num};
-  std::vector<int64_t> variance_shape{input_num};
-
   std::unique_ptr<framework::LoDTensor> bias_tensor(
       new framework::LoDTensor());
   std::unique_ptr<framework::LoDTensor> scale_tensor(
@@ -68,10 +61,33 @@ class LayerNormOpConverter : public OpConverter {
   auto* bias_data = bias_tensor->mutable_data<float>(platform::CPUPlace());
   auto* scale_data = scale_tensor->mutable_data<float>(platform::CPUPlace());
 
-  plugin::LayerNormPlugin* plugin = new plugin::LayerNormPlugin(
-      bias_data, bias_tensor->numel(), scale_data, scale_tensor->numel(),
-      begin_norm_axis, eps, mean_shape, variance_shape);
-  nvinfer1::IPluginLayer* layernorm_layer = engine_->AddPlugin(&X, 1, plugin);
+  nvinfer1::ILayer* layernorm_layer = nullptr;
+  if (engine_->with_dynamic_shape()) {
+    int input_num = 1;
+    for (int i = begin_norm_axis; i < X->getDimensions().nbDims; i++) {
+      input_num *= X->getDimensions().d[i];
+    }
+    std::vector<int64_t> mean_shape{input_num};
+    std::vector<int64_t> variance_shape{input_num};
+    plugin::LayerNormPluginDynamic* plugin =
+        new plugin::LayerNormPluginDynamic(bias_data, bias_tensor->numel(),
+                                           scale_data, scale_tensor->numel(),
+                                           begin_norm_axis, eps, mean_shape,
+                                           variance_shape);
+    layernorm_layer = engine_->AddDynamicPlugin(&X, 1, plugin);
+  } else {
+    int input_num = 1;
+    for (int i = begin_norm_axis - 1; i < X->getDimensions().nbDims; i++) {
+      input_num *= X->getDimensions().d[i];
+    }
+    std::vector<int64_t> mean_shape{input_num};
+    std::vector<int64_t> variance_shape{input_num};
+    plugin::LayerNormPlugin* plugin = new plugin::LayerNormPlugin(
+        bias_data, bias_tensor->numel(), scale_data, scale_tensor->numel(),
+        begin_norm_axis, eps, mean_shape, variance_shape);
+    layernorm_layer = engine_->AddPlugin(
+        &X, 1, reinterpret_cast<plugin::PluginTensorRT*>(plugin));
+  }
 
   auto output_name = op_desc.Output("Y").front();
   engine_->SetWeights(op_desc.Input("Bias").front(), std::move(bias_tensor));
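Note on the converter change above: in static-shape mode TensorRT dimensions exclude the batch axis, so the reduction over the normalized dimensions starts at begin_norm_axis - 1, while in dynamic-shape mode the batch axis is part of nvinfer1::Dims and the reduction starts at begin_norm_axis. A minimal standalone sketch of that index shift (normalized_size is a hypothetical helper, not part of the patch):

#include <vector>

// Product of the extents that layer_norm normalizes over. `dims` is the
// shape as TensorRT reports it in the given mode; in static-shape mode the
// batch dimension is already stripped, hence the -1 shift.
int normalized_size(const std::vector<int>& dims, int begin_norm_axis,
                    bool with_dynamic_shape) {
  int start = with_dynamic_shape ? begin_norm_axis : begin_norm_axis - 1;
  int num = 1;
  for (int i = start; i < static_cast<int>(dims.size()); ++i) {
    num *= dims[i];
  }
  return num;
}

// Example: begin_norm_axis = 2 with an [N, 3, 64, 64] input. Dynamic mode
// sees dims = {N, 3, 64, 64} and multiplies 64 * 64; static mode sees
// dims = {3, 64, 64} and multiplies the same extents starting at index 1.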
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 07dc1a0684..44611d1d59 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -700,7 +700,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
     }
 
     if (op_type == "reshape" || op_type == "reshape2") {
-      if (!desc.HasAttr("shape") || with_dynamic_shape) {
+      if (!desc.HasAttr("shape")) {
        return false;
        // Paddle-TRT does not support the input tensors: Shape and ShapeTensor
      } else if (desc.Input("Shape").size() >= 1 ||
diff --git a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
index 8af036a0e8..f9341613a0 100644
--- a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
@@ -57,8 +57,18 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
     input_shape.push_back(input_dims.d[i]);
   }
   const auto input_ddim = framework::make_ddim(input_shape);
-  auto matrix_dim = framework::flatten_to_2d(input_ddim, begin_norm_axis - 1);
+  auto matrix_dim = framework::flatten_to_2d(input_ddim, begin_norm_axis);
   int feature_size = static_cast<int>(matrix_dim[1]);
+  PADDLE_ENFORCE_EQ(feature_size, scale_.size(),
+                    platform::errors::InvalidArgument(
+                        "scale's size should be equal to the feature_size, "
+                        "but got feature_size:%d, scale's size:%d.",
+                        feature_size, scale_.size()));
+  PADDLE_ENFORCE_EQ(feature_size, bias_.size(),
+                    platform::errors::InvalidArgument(
+                        "bias's size should be equal to the feature_size, "
+                        "but got feature_size:%d, bias's size:%d.",
+                        feature_size, bias_.size()));
 
   scale_t.Resize(framework::make_ddim({feature_size}));
   bias_t.Resize(framework::make_ddim({feature_size}));
@@ -82,6 +92,103 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
   return cudaGetLastError() != cudaSuccess;
 }
 
+nvinfer1::DimsExprs LayerNormPluginDynamic::getOutputDimensions(
+    int output_index, const nvinfer1::DimsExprs *inputDims, int nb_inputs,
+    nvinfer1::IExprBuilder &expr_builder) {
+  return inputDims[0];
+}
+
+bool LayerNormPluginDynamic::supportsFormatCombination(
+    int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
+    int nb_outputs) {
+  PADDLE_ENFORCE_NOT_NULL(
+      in_out, platform::errors::InvalidArgument(
+                  "The input of layernorm plugin should not be nullptr."));
+  PADDLE_ENFORCE_LT(
+      pos, nb_inputs + nb_outputs,
+      platform::errors::InvalidArgument("The pos(%d) should be less than the "
+                                        "num(%d) of the input and the output.",
+                                        pos, nb_inputs + nb_outputs));
+  const nvinfer1::PluginTensorDesc &in = in_out[pos];
+  if (pos == 0) {
+    // TODO(Shangzhizhou) FP16 support
+    return (in.type == nvinfer1::DataType::kFLOAT) &&
+           (in.format == nvinfer1::TensorFormat::kLINEAR);
+  }
+  const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
+  // output
+  return in.type == prev.type && in.format == prev.format;
+}
+
+nvinfer1::DataType LayerNormPluginDynamic::getOutputDataType(
+    int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
+  PADDLE_ENFORCE_EQ(index, 0,
+                    platform::errors::InvalidArgument(
+                        "The LayerNormPlugin only has one input, so the "
+                        "index value should be 0, but got %d.",
+                        index));
+  return input_types[0];
+}
+
+int LayerNormPluginDynamic::enqueue(
+    const nvinfer1::PluginTensorDesc *input_desc,
+    const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
+    void *const *outputs, void *workspace, cudaStream_t stream) {
+  const auto &input_dims = input_desc[0].dims;
+  int begin_norm_axis = begin_norm_axis_;
+  float eps = eps_;
+
+  std::vector<int> input_shape;
+  for (int i = 0; i < input_dims.nbDims; i++) {
+    input_shape.push_back(input_dims.d[i]);
+  }
+  const auto input_ddim = framework::make_ddim(input_shape);
+  auto matrix_dim = framework::flatten_to_2d(input_ddim, begin_norm_axis);
+  int feature_size = static_cast<int>(matrix_dim[1]);
+  PADDLE_ENFORCE_EQ(feature_size, scale_.size(),
+                    platform::errors::InvalidArgument(
+                        "scale's size should be equal to the feature_size, "
+                        "but got feature_size:%d, scale's size:%d.",
+                        feature_size, scale_.size()));
+  PADDLE_ENFORCE_EQ(feature_size, bias_.size(),
+                    platform::errors::InvalidArgument(
+                        "bias's size should be equal to the feature_size, "
+                        "but got feature_size:%d, bias's size:%d.",
+                        feature_size, bias_.size()));
+  int device_id;
+  cudaGetDevice(&device_id);
+  auto input_type = input_desc[0].type;
+  if (input_type == nvinfer1::DataType::kFLOAT) {
+    VLOG(1) << "TRT Plugin DataType selected. LayerNorm-->fp32";
+    const float *input = reinterpret_cast<const float *>(inputs[0]);
+    float *output = static_cast<float *>(outputs[0]);
+    scale_t.Resize(framework::make_ddim({feature_size}));
+    bias_t.Resize(framework::make_ddim({feature_size}));
+    mean_t.Resize(framework::make_ddim(mean_shape_));
+    variance_t.Resize(framework::make_ddim(variance_shape_));
+
+    float *scale_d =
+        scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
+    float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
+    float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
+    float *variance_d =
+        variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
+
+    cudaMemcpyAsync(scale_d, scale_.data(), sizeof(float) * feature_size,
+                    cudaMemcpyHostToDevice, stream);
+    cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size,
+                    cudaMemcpyHostToDevice, stream);
+
+    paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm;
+    layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
+               variance_d, begin_norm_axis, eps);
+  } else {
+    PADDLE_THROW(platform::errors::Fatal(
+        "The LayerNorm TRT Plugin's input type should be float."));
+  }
+  return cudaGetLastError() != cudaSuccess;
+}
+
 }  // namespace plugin
 }  // namespace tensorrt
 }  // namespace inference
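The enqueue paths above rely on framework::flatten_to_2d, which collapses an N-D shape into an (outer, inner) matrix at the given axis; the plugin then normalizes over the inner extent, and the new PADDLE_ENFORCE_EQ checks require scale_ and bias_ to hold exactly that many elements. A minimal standalone sketch of the flattening semantics assumed here (flatten_to_2d_sketch is a hypothetical stand-in, not the Paddle implementation):

#include <cstdint>
#include <utility>
#include <vector>

// Split `dims` at `axis`: everything before the axis multiplies into the
// outer (batch-like) extent, everything from the axis on into the inner
// (normalized) extent.
std::pair<int64_t, int64_t> flatten_to_2d_sketch(
    const std::vector<int64_t>& dims, int axis) {
  int64_t outer = 1, inner = 1;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    (i < axis ? outer : inner) *= dims[i];
  }
  return {outer, inner};  // feature_size corresponds to `inner`
}

// Example: dims = {1, 3, 64, 64} with axis = 2 yields {3, 4096}, so the
// checks above require scale_.size() == bias_.size() == 4096.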
diff --git a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h
index 050ef3b77d..9c4c31b61e 100644
--- a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h
@@ -50,7 +50,7 @@ class LayerNormPlugin : public PluginTensorRT {
   // TRT will call this func when we need to serialize the configuration of
   // tensorrt.
   // It should not be called by users.
-  void serialize(void *buffer) override {
+  void serialize(void* buffer) override {
     SerializeValue(&buffer, getPluginType());
     serializeBase(buffer);
     SerializeValue(&buffer, bias_);
@@ -62,7 +62,7 @@ class LayerNormPlugin : public PluginTensorRT {
   }
 
  public:
-  LayerNormPlugin(const float *bias, const int bias_num, const float *scale,
+  LayerNormPlugin(const float* bias, const int bias_num, const float* scale,
                   const int scale_num, int begin_norm_axis, float eps,
                   std::vector<int64_t> mean_shape,
                   std::vector<int64_t> variance_shape)
@@ -78,7 +78,7 @@ class LayerNormPlugin : public PluginTensorRT {
 
   // It was used for tensorrt deserialization.
   // It should not be called by users.
-  LayerNormPlugin(void const *serialData, size_t serialLength) {
+  LayerNormPlugin(void const* serialData, size_t serialLength) {
     deserializeBase(serialData, serialLength);
     DeserializeValue(&serialData, &serialLength, &bias_);
     DeserializeValue(&serialData, &serialLength, &scale_);
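The SerializeValue/DeserializeValue helpers used in this header are order-sensitive: a deserializing constructor must read fields back in exactly the order serialize() wrote them (for the dynamic plugin added below: bias, scale, begin_norm_axis, eps, mean_shape, variance_shape). A minimal sketch of the same mirrored-order pattern with plain memcpy and a hypothetical two-field payload:

#include <cstring>

struct Payload { float eps; int axis; };

// Writer: the field order defines the wire format.
void write_payload(void* buffer, const Payload& p) {
  char* dst = static_cast<char*>(buffer);
  std::memcpy(dst, &p.eps, sizeof(p.eps));                    // field 1
  std::memcpy(dst + sizeof(p.eps), &p.axis, sizeof(p.axis));  // field 2
}

// Reader: must consume the fields in the same order the writer emitted them.
Payload read_payload(const void* buffer) {
  Payload p;
  const char* src = static_cast<const char*>(buffer);
  std::memcpy(&p.eps, src, sizeof(p.eps));
  std::memcpy(&p.axis, src + sizeof(p.eps), sizeof(p.axis));
  return p;
}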
@@ -90,20 +90,153 @@ class LayerNormPlugin : public PluginTensorRT {
   }
   ~LayerNormPlugin() {}
   int initialize() override;
 
-  LayerNormPlugin *clone() const override {
+  LayerNormPlugin* clone() const override {
     return new LayerNormPlugin(bias_.data(), bias_.size(), scale_.data(),
                                scale_.size(), begin_norm_axis_, eps_,
                                mean_shape_, variance_shape_);
   }
 
-  const char *getPluginType() const override { return "layer_norm_plugin"; }
+  const char* getPluginType() const override { return "layer_norm_plugin"; }
   int getNbOutputs() const override { return 1; }
-  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs,
+  nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs,
                                      int nbInputDims) override;
-  int enqueue(int batchSize, const void *const *inputs, void **outputs,
-              void *workspace, cudaStream_t stream) override;
+  int enqueue(int batchSize, const void* const* inputs, void** outputs,
+              void* workspace, cudaStream_t stream) override;
 };
 
+class LayerNormPluginDynamic : public DynamicPluginTensorRT {
+ public:
+  LayerNormPluginDynamic(const float* bias, const int bias_num,
+                         const float* scale, const int scale_num,
+                         int begin_norm_axis, float eps,
+                         std::vector<int64_t> mean_shape,
+                         std::vector<int64_t> variance_shape)
+      : begin_norm_axis_(begin_norm_axis),
+        eps_(eps),
+        mean_shape_(mean_shape),
+        variance_shape_(variance_shape) {
+    bias_.resize(bias_num);
+    scale_.resize(scale_num);
+    std::copy(bias, bias + bias_num, bias_.data());
+    std::copy(scale, scale + scale_num, scale_.data());
+  }
+
+  LayerNormPluginDynamic(void const* serialData, size_t serialLength) {
+    DeserializeValue(&serialData, &serialLength, &bias_);
+    DeserializeValue(&serialData, &serialLength, &scale_);
+    DeserializeValue(&serialData, &serialLength, &begin_norm_axis_);
+    DeserializeValue(&serialData, &serialLength, &eps_);
+    DeserializeValue(&serialData, &serialLength, &mean_shape_);
+    DeserializeValue(&serialData, &serialLength, &variance_shape_);
+  }
+  nvinfer1::IPluginV2DynamicExt* clone() const override {
+    return new LayerNormPluginDynamic(bias_.data(), bias_.size(),
+                                      scale_.data(), scale_.size(),
+                                      begin_norm_axis_, eps_, mean_shape_,
+                                      variance_shape_);
+  }
+
+  const char* getPluginType() const override { return "layernorm_plugin"; }
+  int getNbOutputs() const override { return 1; }
+  int initialize() override { return 0; }
+
+  size_t getSerializationSize() const override {
+    return SerializedSize(bias_) + SerializedSize(scale_) +
+           SerializedSize(begin_norm_axis_) + SerializedSize(eps_) +
+           SerializedSize(mean_shape_) + SerializedSize(variance_shape_);
+  }
+
+  void serialize(void* buffer) const override {
+    SerializeValue(&buffer, bias_);
+    SerializeValue(&buffer, scale_);
+    SerializeValue(&buffer, begin_norm_axis_);
+    SerializeValue(&buffer, eps_);
+    SerializeValue(&buffer, mean_shape_);
+    SerializeValue(&buffer, variance_shape_);
+  }
+
+  nvinfer1::DimsExprs getOutputDimensions(
+      int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
+      nvinfer1::IExprBuilder& expr_builder) override;
+
+  bool supportsFormatCombination(int pos,
+                                 const nvinfer1::PluginTensorDesc* inOut,
+                                 int nbInputs, int nbOutputs) override;
+
+  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
+                       int nbInputs,
+                       const nvinfer1::DynamicPluginTensorDesc* out,
+                       int nbOutputs) override {}
+
+  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
+                          int nbInputs,
+                          const nvinfer1::PluginTensorDesc* outputs,
+                          int nbOutputs) const override {
+    return 0;
+  }
+
+  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+              const nvinfer1::PluginTensorDesc* outputDesc,
+              const void* const* inputs, void* const* outputs, void* workspace,
+              cudaStream_t stream) override;
+  nvinfer1::DataType getOutputDataType(int index,
+                                       const nvinfer1::DataType* inputTypes,
+                                       int nbInputs) const override;
+
+  void destroy() override { delete this; }
+
+ private:
+  std::vector<float> bias_;
+  std::vector<float> scale_;
+  framework::Tensor scale_t;
+  framework::Tensor bias_t;
+  framework::Tensor mean_t;
+  framework::Tensor variance_t;
+  int begin_norm_axis_;
+  float eps_;
+  std::vector<int64_t> mean_shape_;
+  std::vector<int64_t> variance_shape_;
+};
+
+class LayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
+ public:
+  LayerNormPluginDynamicCreator() {}
+  const char* getPluginName() const override { return "layernorm_plugin"; }
+
+  const char* getPluginVersion() const override { return "1"; }
+
+  const nvinfer1::PluginFieldCollection* getFieldNames() override {
+    return &field_collection_;
+  }
+
+  nvinfer1::IPluginV2* createPlugin(
+      const char* name, const nvinfer1::PluginFieldCollection* fc) override {
+    return nullptr;
+  }
+
+  nvinfer1::IPluginV2* deserializePlugin(const char* name,
+                                         const void* serial_data,
+                                         size_t serial_length) override {
+    auto plugin = new LayerNormPluginDynamic(serial_data, serial_length);
+    return plugin;
+  }
+
+  void setPluginNamespace(const char* lib_namespace) override {
+    plugin_namespace_ = lib_namespace;
+  }
+
+  const char* getPluginNamespace() const override {
+    return plugin_namespace_.c_str();
+  }
+
+ private:
+  std::string plugin_namespace_;
+  std::string plugin_name_;
+  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
+  std::vector<nvinfer1::PluginField> plugin_attributes_;
+};
+
+REGISTER_TRT_PLUGIN_V2(LayerNormPluginDynamicCreator);
+
 }  // namespace plugin
 }  // namespace tensorrt
 }  // namespace inference
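The creator registered above is what lets TensorRT rebuild the plugin from a serialized engine: at deserialization time the runtime looks up the creator whose getPluginName()/getPluginVersion() match the recorded values, so the "layernorm_plugin"/"1" strings must agree between the plugin and its creator (and since createPlugin returns nullptr, deserializePlugin is the only wired-up construction path). A minimal sketch of that lookup, assuming the library containing the REGISTER_TRT_PLUGIN_V2 registration has been loaded:

#include <NvInfer.h>

// Recreate a LayerNormPluginDynamic from a serialized blob via the registry.
nvinfer1::IPluginV2* rebuild_layernorm_plugin(const void* blob, size_t len) {
  nvinfer1::IPluginCreator* creator =
      getPluginRegistry()->getPluginCreator("layernorm_plugin", "1");
  if (creator == nullptr) return nullptr;  // registration did not run
  return creator->deserializePlugin("layernorm_plugin", blob, len);
}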
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 8a5ad5852a..b2572e5aa4 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -511,6 +511,7 @@ void BindAnalysisConfig(py::module *m) {
            py::arg("disable_trt_plugin_fp16") = false)
       .def("enable_tensorrt_oss", &AnalysisConfig::EnableTensorRtOSS)
       .def("tensorrt_oss_enabled", &AnalysisConfig::tensorrt_oss_enabled)
+      .def("exp_disable_tensorrt_ops", &AnalysisConfig::Exp_DisableTensorRtOPs)
       .def("enable_tensorrt_dla", &AnalysisConfig::EnableTensorRtDLA,
            py::arg("dla_core") = 0)
       .def("tensorrt_dla_enabled", &AnalysisConfig::tensorrt_dla_enabled)
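The test-harness change below threads a relative tolerance through to np.allclose, whose elementwise acceptance criterion is |actual - expected| <= atol + rtol * |expected|; the FP16 layer-norm test added further below relies on this by passing rtol=0.01. A one-line sketch of that criterion:

#include <cmath>

// Mirrors numpy.allclose's per-element test (without its NaN/inf handling).
bool allclose_elem(float actual, float expected, float rtol, float atol) {
  return std::fabs(actual - expected) <= atol + rtol * std::fabs(expected);
}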
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py
index 010086bfbb..e3c21eaa78 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py
@@ -160,7 +160,8 @@ class InferencePassTest(unittest.TestCase):
                                  use_gpu,
                                  atol=1e-5,
                                  flatten=False,
-                                 quant=False):
+                                 quant=False,
+                                 rtol=1e-5):
         '''
         Check whether calculating on CPU and GPU, enable TensorRT or disable
         TensorRT, enable MKLDNN or disable MKLDNN
@@ -260,7 +261,7 @@ class InferencePassTest(unittest.TestCase):
 
             self.assertTrue(
                 np.allclose(
-                    out, tensorrt_output, atol=atol),
+                    out, tensorrt_output, rtol=rtol, atol=atol),
                 "Output has diff between GPU and TensorRT. ")
 
         # Check whether the mkldnn results and the CPU results are the same.
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py
index bdcdeee8dc..25d0173ef5 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py
@@ -367,6 +367,61 @@ class TensorRTSubgraphPassLayerNormTest(InferencePassTest):
                 PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
 
 
+class TensorRTSubgraphPassLayerNormDynamicTest(InferencePassTest):
+    def setUp(self):
+        self.set_params()
+        with fluid.program_guard(self.main_program, self.startup_program):
+            data = fluid.data(
+                name="data", shape=[-1, 3, 64, 64], dtype="float32")
+            out = fluid.layers.layer_norm(
+                data, begin_norm_axis=self.begin_norm_axis)
+        self.feeds = {
+            "data": np.random.random([1, 3, 64, 64]).astype("float32"),
+        }
+        self.set_trt_params()
+        self.fetch_list = [out]
+
+    def set_trt_params(self):
+        self.enable_trt = True
+        self.trt_parameters = TensorRTSubgraphPassLayerNormDynamicTest.TensorRTParam(
+            1 << 30, 32, 0, self.precision, self.serialize, False)
+        self.dynamic_shape_params = TensorRTSubgraphPassLayerNormDynamicTest.DynamicShapeParam(
+            {
+                'data': [1, 3, 64, 64],
+            }, {'data': [8, 8, 64, 64], }, {'data': [4, 4, 64, 64], }, False)
+
+    def set_params(self):
+        self.begin_norm_axis = 2
+        self.precision = AnalysisConfig.Precision.Float32
+        self.serialize = True
+
+    def test_check_output(self):
+        if os.path.exists(self.path + "_opt_cache"):
+            shutil.rmtree(self.path + "_opt_cache")
+        if core.is_compiled_with_cuda():
+            use_gpu = True
+            self.check_output_with_option(use_gpu)
+            self.assertTrue(
+                PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
+
+
+class TensorRTSubgraphPassLayerNormDynamicFP16Test(
+        TensorRTSubgraphPassLayerNormDynamicTest):
+    def set_params(self):
+        self.begin_norm_axis = 2
+        self.precision = AnalysisConfig.Precision.Half
+        self.serialize = True
+
+    def test_check_output(self):
+        if os.path.exists(self.path + "_opt_cache"):
+            shutil.rmtree(self.path + "_opt_cache")
+        if core.is_compiled_with_cuda():
+            use_gpu = True
+            self.check_output_with_option(use_gpu, atol=0.01, rtol=0.01)
+            self.assertTrue(
+                PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
+
+
 class TensorRTSubgraphPassLayerNormBeginNormAxis2Test(
         TensorRTSubgraphPassLayerNormTest):
     def set_params(self):
-- 
GitLab
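For completeness, a minimal sketch of exercising the same features through the C++ inference API that the pybind change above exposes to Python; the model path and header name are placeholders that depend on the installed inference package, and disabling layer_norm is shown only to illustrate the new exp_disable_tensorrt_ops hook:

#include <map>
#include <string>
#include <vector>

#include "paddle_inference_api.h"  // header name varies by distribution

void configure_trt_dynamic_layernorm() {
  paddle::AnalysisConfig config;
  config.SetModel("model_dir");        // placeholder model directory
  config.EnableUseGpu(100 /*MB*/, 0);  // initial GPU pool size, device id
  config.EnableTensorRtEngine(1 << 30, /*max_batch_size=*/1,
                              /*min_subgraph_size=*/3,
                              paddle::AnalysisConfig::Precision::kFloat32,
                              /*use_static=*/true, /*use_calib_mode=*/false);
  // Dynamic-shape profile, mirroring the ranges used by the unit test above.
  std::map<std::string, std::vector<int>> min_shape{{"data", {1, 3, 64, 64}}};
  std::map<std::string, std::vector<int>> max_shape{{"data", {8, 8, 64, 64}}};
  std::map<std::string, std::vector<int>> opt_shape{{"data", {4, 4, 64, 64}}};
  config.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape,
                                /*disable_trt_plugin_fp16=*/false);
  // The new hook: keep selected ops out of the TensorRT subgraph.
  config.Exp_DisableTensorRtOPs({"layer_norm"});
}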