From 2ca3fe5d0239729ecffcd822fafc2b96ae89d790 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Tue, 25 Aug 2020 15:50:23 +0800 Subject: [PATCH] multihead att plugin --- .../tensorrt/convert/multihead_matmul_op.cc | 152 +++++++++++++++--- .../inference/tensorrt/convert/op_converter.h | 2 + .../inference/tensorrt/convert/slice_op.cc | 7 +- .../inference/tensorrt/convert/ut_helper.h | 2 + paddle/fluid/inference/tensorrt/engine.cc | 1 + paddle/fluid/inference/tensorrt/op_teller.cc | 1 + .../inference/tensorrt/plugin/CMakeLists.txt | 1 + .../tensorrt/plugin/cast_int_plugin.cu | 85 ++++++++++ .../tensorrt/plugin/cast_int_plugin.h | 120 ++++++++++++++ .../operators/tensorrt/tensorrt_engine_op.h | 3 + 10 files changed, 349 insertions(+), 25 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index a19d91f36e2..d71a4f23374 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h" namespace paddle { @@ -30,7 +31,6 @@ class MultiheadMatMulOpConverter : public OpConverter { // Declare inputs // Shouble be a 5 dims tensor. auto* input = engine_->GetITensor(op_desc.Input("Input").front()); - auto* input_bias_qk = engine_->GetITensor(op_desc.Input("BiasQK").front()); // fc weights and fc bias auto weight_name = op_desc.Input("W").front(); @@ -50,7 +50,7 @@ class MultiheadMatMulOpConverter : public OpConverter { memcpy(weight_data_tmp.data(), weight_data, weight_t->numel() * sizeof(float)); - // (hidden, 3, all_head_size) + // (hidden, 3, all_head_size) auto weight_dims = weight_t->dims(); int hidden = weight_dims[0]; // channels_in @@ -65,36 +65,144 @@ class MultiheadMatMulOpConverter : public OpConverter { } } }; - - // transpose weight_data from m * n to n * m tranpose_weight(weight_data_tmp.data(), weight_data, m, n); - TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, - static_cast(weight_data), - static_cast(weight_t->numel())}; - - weight.dims.assign({n, m}); - TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, - static_cast(bias_data), - static_cast(bias_t->numel())}; - - auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n, - weight.get(), bias.get()); - auto* fc_out = fc_layer->getOutput(0); - // add qkv to context + int head_number = BOOST_GET_CONST(int, op_desc.GetAttr("head_number")); - int head_size = all_head_size / head_number; - float scale = BOOST_GET_CONST(float, op_desc.GetAttr("alpha")); - std::vector plugin_inputs; - plugin_inputs.push_back(fc_out); - plugin_inputs.push_back(input_bias_qk); nvinfer1::ILayer* layer = nullptr; + if (engine_->with_dynamic_shape()) { +#ifdef USE_NVINFER_PLUGIN + int head_size = hidden / head_number; + // [3, Nout, Hout, Nin, Hin] -> [Nout, 3, Hout, Nin, Hin] + auto transpose_weight_v2 = [](const float* src, float* dst, int N, + int H) { + const int HNH = H * N * H; + for (int i = 0; i < 3; ++i) { + for (int n = 0; n < N; ++n) { + for (int hnh = 0; hnh < HNH; ++hnh) 
{ + dst[n * 3 * HNH + i * HNH + hnh] = + src[i * N * HNH + n * HNH + hnh]; + } + } + } + }; + // [3, N, H] -> [N, 3, H] + auto transpose_bias_v2 = [](const float* src, float* dst, int N, int H) { + for (int i = 0; i < 3; ++i) { + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H; ++h) { + dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h]; + } + } + } + }; + memcpy(weight_data_tmp.data(), weight_data, + weight_t->numel() * sizeof(float)); + transpose_weight_v2(weight_data_tmp.data(), weight_data, head_number, + head_size); + nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, + static_cast(weight_data), + static_cast(weight_t->numel())}; + + std::vector bias_data_tmp; + bias_data_tmp.reserve(bias_t->numel()); + memcpy(bias_data_tmp.data(), bias_data, bias_t->numel() * sizeof(float)); + transpose_bias_v2(bias_data_tmp.data(), bias_data, head_number, + head_size); + nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, + static_cast(bias_data), + static_cast(bias_t->numel())}; + + nvinfer1::Permutation permutation{1, 0, 2, 3, 4}; + auto trans_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + trans_layer->setFirstTranspose(permutation); + + auto* fc_layer = TRT_ENGINE_ADD_LAYER( + engine_, FullyConnected, *trans_layer->getOutput(0), n, weight, bias); + + auto pos_tensor = engine_->GetITensor("eval_placeholder_2"); + plugin::CastIntPluginDynamic* cast_plugin = + new plugin::CastIntPluginDynamic(); + auto cast_layer = engine_->AddPluginV2(&pos_tensor, 1, cast_plugin); + + auto casted_pos_tensor = cast_layer->getOutput(0); + auto reshape_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *casted_pos_tensor); + + nvinfer1::Dims2 reshape_dim(0, 0); + nvinfer1::Permutation perm{1, 0, 2}; + reshape_layer->setFirstTranspose(perm); + reshape_layer->setReshapeDimensions(reshape_dim); + auto reduce_layer = + TRT_ENGINE_ADD_LAYER(engine_, Reduce, *reshape_layer->getOutput(0), + nvinfer1::ReduceOperation::kMAX, 1, false); + + auto creator = GetPluginRegistry()->getPluginCreator( + "CustomQKVToContextPluginDynamic", "1"); + assert(creator != nullptr); + int type = static_cast((engine_->WithFp16() == 1) + ? 
nvinfer1::DataType::kHALF + : nvinfer1::DataType::kFLOAT); + bool has_mask = true; + const std::vector fields{ + {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1}, + {"hidden_size", &hidden, nvinfer1::PluginFieldType::kINT32, 1}, + {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1}, + {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, + 1}, // no bool type + }; + nvinfer1::PluginFieldCollection* pluginPtr = + static_cast( + malloc(sizeof(*pluginPtr) + + fields.size() * + sizeof(nvinfer1::PluginField))); // remember to free + pluginPtr->nbFields = static_cast(fields.size()); + pluginPtr->fields = fields.data(); + + auto pluginObj = + creator->createPlugin("CustomQKVToContextPluginDynamic", pluginPtr); + std::vector plugin_inputs; + plugin_inputs.push_back(fc_layer->getOutput(0)); + plugin_inputs.push_back(reduce_layer->getOutput(0)); + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *pluginObj); + assert(plugin_layer != nullptr); + auto trans_r_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0)); + assert(trans_r_layer != nullptr); + trans_r_layer->setFirstTranspose(permutation); + layer = trans_r_layer; +#else + // transpose weight_data from m * n to n * m + auto* input_bias_qk = + engine_->GetITensor(op_desc.Input("BiasQK").front()); + + TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, + static_cast(weight_data), + static_cast(weight_t->numel())}; + weight.dims.assign({n, m}); + + TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, + static_cast(bias_data), + static_cast(bias_t->numel())}; + + auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n, + weight.get(), bias.get()); + auto* fc_out = fc_layer->getOutput(0); + // add qkv to context + int head_size = all_head_size / head_number; + float scale = BOOST_GET_CONST(float, op_desc.GetAttr("alpha")); + + std::vector plugin_inputs; + plugin_inputs.push_back(fc_out); + plugin_inputs.push_back(input_bias_qk); bool ban_fp16 = engine_->disable_trt_plugin_fp16(); plugin::DynamicPluginTensorRT* plugin = new plugin::QkvToContextPluginDynamic(hidden, head_number, head_size, scale, ban_fp16); layer = engine_->AddPluginV2(plugin_inputs.data(), 2, plugin); +#endif } else { PADDLE_THROW(platform::errors::Fatal( "You are running the Ernie(Bert) model in static shape mode, which " diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index f4b0f5f23d8..6359f36c998 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -173,6 +173,8 @@ class OpConverter { "optim_input_shape should be same.")); } } + std::cerr << "Declare input: " << input << std::endl; + if (input.find("stack_0.tmp_0") != std::string::npos) continue; engine->DeclareInput( input, FluidDataType2TRT( var->Proto()->type().lod_tensor().tensor().data_type()), diff --git a/paddle/fluid/inference/tensorrt/convert/slice_op.cc b/paddle/fluid/inference/tensorrt/convert/slice_op.cc index 2a76317eea1..ed75d7b1583 100644 --- a/paddle/fluid/inference/tensorrt/convert/slice_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/slice_op.cc @@ -23,8 +23,9 @@ class SliceOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { -// This OP is implemented by trt dynamic shpae plugin. 
-// Dynamic shape plugin requires TRT version greater than 6.0. + // This OP is implemented by trt dynamic shpae plugin. + // Dynamic shape plugin requires TRT version greater than 6.0. + std::cerr << "slice op converter\n" << std::endl; #if IS_TRT_VERSION_GE(6000) VLOG(4) << "convert slice op to tensorrt layer"; framework::OpDesc op_desc(op, nullptr); @@ -42,7 +43,7 @@ class SliceOpConverter : public OpConverter { if (engine_->with_dynamic_shape()) { bool ban_fp16 = engine_->disable_trt_plugin_fp16(); plugin::SlicePluginDynamic* plugin = - new plugin::SlicePluginDynamic(starts, ends, ends, ban_fp16); + new plugin::SlicePluginDynamic(starts, ends, axes, ban_fp16); layer = engine_->AddPluginV2(&input, 1, plugin); } else { PADDLE_THROW(platform::errors::Fatal( diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h index 3c48c8192f6..ed347be1cac 100644 --- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h +++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h @@ -183,6 +183,8 @@ class TRTConvertValidation { std::vector buffers(num_bindings); for (const std::string& name : input_output_names) { + // std::cerr << "Binding name: " << name << std::endl; + if (name.find("stack_0.tmp_0") != std::string::npos) continue; auto* var = scope_.FindVar(name); auto* tensor = var->GetMutable(); const int bind_index = engine_->engine()->getBindingIndex(name.c_str()); diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 03f5a751511..df4cecebaff 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -71,6 +71,7 @@ void TensorRTEngine::FreezeNetwork() { // build engine. infer_builder_->setMaxBatchSize(max_batch_); infer_builder_->setMaxWorkspaceSize(max_workspace_); + infer_builder_config_->setMaxWorkspaceSize(max_workspace_); bool enable_fp16 = (precision_ == AnalysisConfig::Precision::kHalf); #if IS_TRT_VERSION_GE(5000) if (enable_fp16) { diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 70ead9720d2..12506e7cc79 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -85,6 +85,7 @@ struct SimpleOpTypeSetTeller : public Teller { "gelu", "layer_norm", "scale", + "slice", }; }; diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index e417fcbb2ce..f02352e665d 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -2,6 +2,7 @@ nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu + cast_int_plugin.cu instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) diff --git a/paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.cu new file mode 100644 index 00000000000..ccb2f7dfe20 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.cu @@ -0,0 +1,85 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +// Dynamic Plugin below. +#if IS_TRT_VERSION_GE(6000) + +nvinfer1::DimsExprs CastIntPluginDynamic::getOutputDimensions( + int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) { + assert(output_index == 0); + return inputs[0]; +} + +bool CastIntPluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs, + int nb_outputs) { + const nvinfer1::PluginTensorDesc& in = in_out[pos]; + return (in.type == nvinfer1::DataType::kINT32); +} + +nvinfer1::DataType CastIntPluginDynamic::getOutputDataType( + int index, const nvinfer1::DataType* input_types, int nb_inputs) const { + PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument( + "The Cast Int only has one input, so the " + "index value should be 0, but get %d.", + index)); + return input_types[index]; +} + +__global__ void castIntKernel(const int64_t* input, int32_t* output, + size_t num_elements) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= num_elements) return; + output[idx] = input[idx] + 1; +} + +int CastIntPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc, + const nvinfer1::PluginTensorDesc* output_desc, + const void* const* inputs, + void* const* outputs, void* workspace, + cudaStream_t stream) { + auto input_dims = input_desc[0].dims; + auto output_dims = output_desc[0].dims; + size_t num_elements = ProductDim(input_dims); + size_t out_num_elements = ProductDim(output_dims); + + assert(input_type == + nvinfer1::DataType::kINT32); // although the input is int64_t + assert(num_elements == out_num_elements); + + const size_t num_threads = 256; + castIntKernel<<>>( + static_cast(inputs[0]), static_cast(outputs[0]), + num_elements); + + return cudaGetLastError() != cudaSuccess; +} +#endif + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h b/paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h new file mode 100644 index 00000000000..039d1494e9a --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h @@ -0,0 +1,120 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +#if IS_TRT_VERSION_GE(6000) +class CastIntPluginDynamic : public DynamicPluginTensorRT { + public: + CastIntPluginDynamic() {} + CastIntPluginDynamic(void const* serial_data, size_t serial_length) {} + + ~CastIntPluginDynamic() {} + nvinfer1::IPluginV2DynamicExt* clone() const override { + return new CastIntPluginDynamic(); + } + + const char* getPluginType() const override { return "cast_int_plugin"; } + int getNbOutputs() const override { return 1; } + int initialize() override { return 0; } + + size_t getSerializationSize() const override { return 0; } + void serialize(void* buffer) const override {} + + nvinfer1::DimsExprs getOutputDimensions( + int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* in_out, + int nb_inputs, int nb_outputs) override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nb_inputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nb_outputs) override {} + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nb_inputs, + const nvinfer1::PluginTensorDesc* outputs, + int nb_outputs) const override { + return 0; + } + + int enqueue(const nvinfer1::PluginTensorDesc* input_desc, + const nvinfer1::PluginTensorDesc* output_desc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) override; + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* input_types, + int nb_inputs) const override; + + void destroy() override { delete this; } +}; + +class CastIntPluginV2Creator : public nvinfer1::IPluginCreator { + public: + CastIntPluginV2Creator() {} + const char* getPluginName() const override { return "cast_int_plugin"; } + + const char* getPluginVersion() const override { return "1"; } + + const nvinfer1::PluginFieldCollection* getFieldNames() override { + return &field_collection_; + } + + nvinfer1::IPluginV2* createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) override { + return nullptr; + } + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) override { + auto plugin = new CastIntPluginDynamic(serial_data, serial_length); + return plugin; + } + + void setPluginNamespace(const char* lib_namespace) override { + plugin_namespace_ = lib_namespace; + } + + const char* getPluginNamespace() const override { + return plugin_namespace_.c_str(); + } + + private: + std::string plugin_namespace_; + std::string plugin_name_; + nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; + std::vector plugin_attributes_; +}; + +REGISTER_TRT_PLUGIN_V2(CastIntPluginV2Creator); +#endif + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h index cc6ee7b19ea..3a008f55c79 100644 --- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h +++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h @@ -221,11 +221,14 @@ class TensorRTEngineOp : public framework::OperatorBase { 
      num_inputs += 1;
     }
     const int num_bindings = num_inputs + Outputs("Ys").size();
+    // std::cerr << "num bindings: " << num_bindings << std::endl;
     std::vector<void *> buffers(num_bindings);
 
     // Bind input tensor to TRT.
     for (const auto &x : Inputs("Xs")) {
       if (param_names_.count(x)) continue;
+      // std::cerr << "runTRT name: " << x << std::endl;
+      if (x.find("stack_0.tmp_0") != std::string::npos) continue;
       // convert input and copy to TRT engine's buffer
       auto &t =
           inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
-- 
GitLab
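
Note (not part of the patch): in the dynamic-shape branch of multihead_matmul_op.cc above, the fused QKV weights and bias are reordered from a [3, num_heads, head_size, ...] layout to [num_heads, 3, head_size, ...] before they are handed to the CustomQKVToContextPluginDynamic plugin. The standalone sketch below mirrors the patch's transpose_bias_v2 lambda on toy dimensions so the index math can be checked in isolation; the sizes, the encoded values, and the main() driver are illustrative only and do not appear in the patch.

```cpp
// Standalone illustration of the [3, N, H] -> [N, 3, H] reordering used by
// the transpose_bias_v2 lambda in multihead_matmul_op.cc (toy sizes/values).
#include <cstdio>
#include <vector>

// Same index math as the lambda in the patch: the q/k/v slot (i) and the
// head index (n) swap positions; elements within a head (h) keep their order.
static void transpose_bias_v2(const float* src, float* dst, int N, int H) {
  for (int i = 0; i < 3; ++i) {        // q/k/v slot
    for (int n = 0; n < N; ++n) {      // head index
      for (int h = 0; h < H; ++h) {    // element within a head
        dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h];
      }
    }
  }
}

int main() {
  const int N = 2;  // num_heads (toy value)
  const int H = 3;  // head_size (toy value)
  std::vector<float> src(3 * N * H), dst(3 * N * H);

  // Encode the source position into each value: 100*slot + 10*head + elem.
  for (int i = 0; i < 3; ++i)
    for (int n = 0; n < N; ++n)
      for (int h = 0; h < H; ++h)
        src[i * N * H + n * H + h] = 100.0f * i + 10.0f * n + h;

  transpose_bias_v2(src.data(), dst.data(), N, H);

  // Print the destination layout: heads are now outermost, q/k/v slots inner.
  for (int n = 0; n < N; ++n) {
    for (int i = 0; i < 3; ++i) {
      for (int h = 0; h < H; ++h) printf("%5.0f ", dst[n * 3 * H + i * H + h]);
      printf("| ");
    }
    printf("\n");
  }
  return 0;
}
```

The weight version in the patch (transpose_weight_v2) performs the same swap of the two outer dimensions; it only uses a larger contiguous chunk per (slot, head) pair (H * N * H elements instead of H), which is why its innermost loop runs over a single flattened index.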