From 0a51098a715fd648d8ff9cf95f48f5db6ba3ebf9 Mon Sep 17 00:00:00 2001 From: Pei Yang Date: Mon, 6 Jan 2020 19:58:24 +0800 Subject: [PATCH] Add TRT support for BERT (#21135) * add gelu plugin * align trt bert with gpu * add support for fused fc with relu, * add unittest for bert trt --- .../fluid/inference/api/analysis_predictor.cc | 3 + .../inference/api/paddle_pass_builder.cc | 11 +- .../inference/tensorrt/convert/CMakeLists.txt | 4 +- .../fluid/inference/tensorrt/convert/fc_op.cc | 84 ++++++- .../inference/tensorrt/convert/gelu_op.cc | 61 +++++ .../tensorrt/convert/layer_norm_op.cc | 108 +++++++++ .../tensorrt/convert/multihead_matmul_op.cc | 213 ++++++++++++++++++ .../inference/tensorrt/convert/op_converter.h | 21 +- paddle/fluid/inference/tensorrt/op_teller.cc | 5 +- .../inference/tensorrt/plugin/CMakeLists.txt | 4 +- .../tensorrt/plugin/gelu_op_plugin.cu | 76 +++++++ .../tensorrt/plugin/gelu_op_plugin.h | 72 ++++++ .../tensorrt/plugin/layer_norm_op_plugin.cu | 84 +++++++ .../tensorrt/plugin/layer_norm_op_plugin.h | 110 +++++++++ .../fluid/inference/tests/api/CMakeLists.txt | 3 + .../inference/tests/api/trt_bert_test.cc | 85 +++++++ paddle/fluid/operators/layer_norm_op.cu | 28 ++- paddle/fluid/operators/layer_norm_op.h | 13 ++ 18 files changed, 962 insertions(+), 23 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/gelu_op.cc create mode 100644 paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc create mode 100644 paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc create mode 100644 paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h create mode 100644 paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h create mode 100644 paddle/fluid/inference/tests/api/trt_bert_test.cc diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 9873d069b4..f7c12a5cc8 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -938,6 +938,9 @@ USE_TRT_CONVERTER(conv2d_transpose); USE_TRT_CONVERTER(leaky_relu); USE_TRT_CONVERTER(shuffle_channel); USE_TRT_CONVERTER(swish); +USE_TRT_CONVERTER(layer_norm); +USE_TRT_CONVERTER(gelu); +USE_TRT_CONVERTER(multihead_matmul); #endif #if PADDLE_WITH_ANAKIN diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index ce2da47531..49f637d96b 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -76,10 +76,13 @@ const std::vector kTRTSubgraphPasses({ "shuffle_channel_detect_pass", // "quant_conv2d_dequant_fuse_pass", // "delete_quant_dequant_op_pass", // - "conv_bn_fuse_pass", // - "fc_fuse_pass", // - "tensorrt_subgraph_pass", // - "conv_bn_fuse_pass", // + // "fc_fuse_pass", // + "simplify_with_basic_ops_pass", // + "multihead_matmul_fuse_pass", // + "conv_bn_fuse_pass", // + "fc_fuse_pass", // + "tensorrt_subgraph_pass", // + "conv_bn_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be // guaranteed at least v7 "conv_elementwise_add_act_fuse_pass", // diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index b63b75f789..e212388cb9 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ 
b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -2,8 +2,8 @@ nv_library(tensorrt_converter SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc - pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc -shuffle_channel_op.cc swish_op.cc + pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc + shuffle_channel_op.cc swish_op.cc DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index ec21bb5534..e82357047c 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -44,7 +44,6 @@ void ReorderCKtoKC(TensorRTEngine::Weight& iweights, // NOLINT static_cast(const_cast(oweights->get().values)), ostrides); } - /* * FC converter convert a MUL op in Fluid to a FC layer in TRT. */ @@ -63,7 +62,6 @@ class FcOpConverter : public OpConverter { w_name = "W"; i_name = "Input"; } - // Declare inputs auto* X = engine_->GetITensor(op_desc.Input(i_name).front()); @@ -71,6 +69,16 @@ class FcOpConverter : public OpConverter { auto* Y_v = scope.FindVar(op_desc.Input(w_name).front()); PADDLE_ENFORCE_NOT_NULL(Y_v); auto* Y_t = Y_v->GetMutable(); + const int x_num_col_dims = + op_desc.HasAttr("x_num_col_dims") + ? boost::get(op_desc.GetAttr("x_num_col_dims")) + : (op_desc.HasAttr("in_num_col_dims") + ? boost::get(op_desc.GetAttr("in_num_col_dims")) + : 1); + const std::string activation_type = + op_desc.HasAttr("activation_type") + ? boost::get(op_desc.GetAttr("activation_type")) + : ""; // This may trigger a GPU->CPU copy, because TRT's weight can only be // assigned from CPU memory, which can't be avoided. float* weight_data = nullptr; @@ -128,14 +136,76 @@ class FcOpConverter : public OpConverter { static_cast(bias_data), static_cast(bias_num)}; - auto* layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, - *const_cast(X), - n_output, tmp_weight.get(), bias.get()); + // in order to handle situations in NLP models(input dims < 3, + // x_num_col_dims != 1, etc.), reshape input to perform FC correctly. + auto* reshape_itensor = X; + int input_dims = X->getDimensions().nbDims; + auto input_d = X->getDimensions().d; + int reshape_dim3[3] = {0}; + int reshape_dim4[4] = {0}; + PADDLE_ENFORCE_EQ( + x_num_col_dims == 1 || x_num_col_dims == 2, true, + platform::errors::InvalidArgument( + "Wrong x_num_col_dims param of op mul. Paddle-TRT FC converter " + "expects x_num_col_dims is either 1 or 2, but got %d", + x_num_col_dims)); + PADDLE_ENFORCE_LE(x_num_col_dims, input_dims, + platform::errors::InvalidArgument( + "Params and input dims mismatch. Paddle-TRT FC " + "converter expects x_num_col_dims <= input dims")); + if (x_num_col_dims == 1) { + if (input_dims == 4) { + PADDLE_ENFORCE_EQ( + input_d[3], 1, + platform::errors::InvalidArgument( + "Invalid dimensions. 
When x_num_col_dims equals to 1 and input " + "dims equals to 4, the last dim of input must be 1, but got %d", + input_d[3])); + } + for (int i = 0; i < 3; i++) { + if (i < input_dims) { + reshape_dim3[i] = input_d[i]; + } else { + reshape_dim3[i] = 1; + } + } + nvinfer1::Dims3 reshape_dim(reshape_dim3[0], reshape_dim3[1], + reshape_dim3[2]); + auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X); + reshape_layer->setReshapeDimensions(reshape_dim); + reshape_itensor = reshape_layer->getOutput(0); + } else { + PADDLE_ENFORCE_NE(input_dims, 1, + platform::errors::InvalidArgument( + "Invalid dimensions. When x_num_col_dims equals to " + "2, input_dims should not be 1")); + for (int i = 0; i < 4; i++) { + if (i < input_dims) { + reshape_dim4[i] = input_d[i]; + } else { + reshape_dim4[i] = 1; + } + } + nvinfer1::Dims4 reshape_dim(reshape_dim4[0], reshape_dim4[1], + reshape_dim4[2], reshape_dim4[3]); + auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X); + reshape_layer->setReshapeDimensions(reshape_dim); + reshape_itensor = reshape_layer->getOutput(0); + } + auto* fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *reshape_itensor, + n_output, tmp_weight.get(), bias.get()); engine_->SetWeights(op_desc.Input(w_name).front(), std::move(tmp)); auto output_name = op_desc.Output("Out").front(); - - RreplenishLayerAndOutput(layer, "fc", {output_name}, test_mode); + if (activation_type == "relu") { + nvinfer1::IActivationLayer* relu_layer = + TRT_ENGINE_ADD_LAYER(engine_, Activation, *(fc_layer->getOutput(0)), + nvinfer1::ActivationType::kRELU); + RreplenishLayerAndOutput(relu_layer, "fc", {output_name}, test_mode); + } else { + RreplenishLayerAndOutput(fc_layer, "fc", {output_name}, test_mode); + } } }; diff --git a/paddle/fluid/inference/tensorrt/convert/gelu_op.cc b/paddle/fluid/inference/tensorrt/convert/gelu_op.cc new file mode 100644 index 0000000000..b72cded3fd --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/gelu_op.cc @@ -0,0 +1,61 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * Gelu converter from fluid to tensorRT. 
+ */ +class GeluOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert fluid gelu op to tensorrt gelu layer"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + int input_num = op_desc.Input("X").size(); + PADDLE_ENFORCE_EQ(input_num, 1, + platform::errors::InvalidArgument( + "gelu op has only 1 input, but got %d", input_num)); + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + // Get output + size_t output_num = op_desc.Output("Out").size(); + PADDLE_ENFORCE_EQ(output_num, 1, + platform::errors::InvalidArgument( + "gelu op has only 1 output, but got %d", output_num)); + // Get input shape and volume + nvinfer1::Dims input_shape = input->getDimensions(); + size_t input_volume = 1; + for (int i = 0; i < input_shape.nbDims; i++) { + input_volume *= input_shape.d[i]; + } + plugin::GeluPlugin* plugin = new plugin::GeluPlugin(input_volume); + nvinfer1::IPluginLayer* layer = + engine_->AddPlugin(&input, input_num, plugin); + auto output_name = op_desc.Output("Out")[0]; + RreplenishLayerAndOutput(layer, "gelu", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(gelu, GeluOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc new file mode 100644 index 0000000000..7d714da439 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc @@ -0,0 +1,108 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ +#include "paddle/fluid/operators/layer_norm_op.h" +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class LayerNormOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert a fluid layer_norm op to tensorrt layer_norm plugin"; + framework::OpDesc op_desc(op, nullptr); + PADDLE_ENFORCE_EQ( + op_desc.Input("X").size(), 1, + platform::errors::InvalidArgument( + "input of layer_norm op converter should be 1, got %d", + op_desc.Input("X").size())); + PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1, + platform::errors::InvalidArgument( + "Bias of layer_norm op converter should be 1, got %d", + op_desc.Input("Bias").size())); // Bias is a weight + PADDLE_ENFORCE_EQ( + op_desc.Input("Scale").size(), 1, + platform::errors::InvalidArgument( + "Scale of layer_norm op converter should be 1, got %d", + op_desc.Input("Scale").size())); // Scale is a weight + PADDLE_ENFORCE_EQ( + op_desc.Output("Y").size(), 1, + platform::errors::InvalidArgument( + "output of layer_norm op converter should be 1, got %d", + op_desc.Input("Y").size())); + + auto* X = engine_->GetITensor(op_desc.Input("X").front()); + auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front()); + auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front()); + const int begin_norm_axis = + op_desc.HasAttr("begin_norm_axis") + ? boost::get(op_desc.GetAttr("begin_norm_axis")) + : 1; + const float eps = op_desc.HasAttr("epsilon") + ? boost::get(op_desc.GetAttr("epsilon")) + : 1e-5f; + PADDLE_ENFORCE_NOT_NULL( + Bias_v, platform::errors::InvalidArgument( + "Input(Bias) of layer_norm should not be null.")); + PADDLE_ENFORCE_NOT_NULL( + Scale_v, platform::errors::InvalidArgument( + "Input(Scale) of layer_norm should not be null.")); + + auto* Bias_t = Bias_v->GetMutable(); + auto* Scale_t = Scale_v->GetMutable(); + + int input_num = 1; + for (int i = 0; i < X->getDimensions().nbDims; i++) { + input_num *= X->getDimensions().d[i]; + } + std::vector mean_shape{input_num}; + std::vector variance_shape{input_num}; + + std::unique_ptr bias_tensor( + new framework::LoDTensor()); + std::unique_ptr scale_tensor( + new framework::LoDTensor()); + + bias_tensor->Resize(Bias_t->dims()); + scale_tensor->Resize(Scale_t->dims()); + + platform::CPUPlace cpu_place; + TensorCopySync((*Bias_t), cpu_place, &(*bias_tensor)); + TensorCopySync((*Scale_t), cpu_place, &(*scale_tensor)); + + auto* bias_data = bias_tensor->mutable_data(platform::CPUPlace()); + auto* scale_data = scale_tensor->mutable_data(platform::CPUPlace()); + + plugin::LayerNormPlugin* plugin = new plugin::LayerNormPlugin( + bias_data, bias_tensor->numel(), scale_data, scale_tensor->numel(), + begin_norm_axis, eps, mean_shape, variance_shape); + nvinfer1::IPluginLayer* layernorm_layer = engine_->AddPlugin(&X, 1, plugin); + + auto output_name = op_desc.Output("Y").front(); + engine_->SetWeights(op_desc.Input("Bias").front(), std::move(bias_tensor)); + engine_->SetWeights(op_desc.Input("Scale").front(), + std::move(scale_tensor)); + RreplenishLayerAndOutput(layernorm_layer, "layer_norm", {output_name}, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(layer_norm, LayerNormOpConverter); diff --git 
a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc new file mode 100644 index 0000000000..93c5981709 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -0,0 +1,213 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class MultiheadMatMulOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a fluid multihead_mamul op to a corresponding tensorrt " + "network structure"; + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* Q = engine_->GetITensor(op_desc.Input("Q").front()); + auto* K = engine_->GetITensor(op_desc.Input("K").front()); + auto* V = engine_->GetITensor(op_desc.Input("V").front()); + auto* BiasQ = scope.FindVar(op_desc.Input("BiasQ").front()); + auto* BiasK = scope.FindVar(op_desc.Input("BiasK").front()); + auto* BiasV = scope.FindVar(op_desc.Input("BiasV").front()); + auto* BiasQK = engine_->GetITensor(op_desc.Input("BiasQK").front()); + PADDLE_ENFORCE_EQ(op_desc.Input("Q").size(), 1, + platform::errors::InvalidArgument( + "size of input Q of multihead_matmul should be 1")); + PADDLE_ENFORCE_EQ(op_desc.Input("K").size(), 1, + platform::errors::InvalidArgument( + "size of input K of multihead_matmul should be 1")); + PADDLE_ENFORCE_EQ(op_desc.Input("V").size(), 1, + platform::errors::InvalidArgument( + "size of input V of multihead_matmul should be 1")); + PADDLE_ENFORCE_EQ( + op_desc.Input("BiasQK").size(), 1, + platform::errors::InvalidArgument( + "size of input BiasQK of multihead_matmul should be 1")); + PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1, + platform::errors::InvalidArgument( + "size of output of multihead_matmul should be 1")); + PADDLE_ENFORCE_NOT_NULL( + BiasQ, platform::errors::InvalidArgument( + "param BiasQ of multihead_matmul should not be null")); + PADDLE_ENFORCE_NOT_NULL( + BiasK, platform::errors::InvalidArgument( + "param BiasK of multihead_matmul should not be null")); + PADDLE_ENFORCE_NOT_NULL( + BiasV, platform::errors::InvalidArgument( + "param BiasV of multihead_matmul should not be null")); + PADDLE_ENFORCE_EQ( + BiasQK->getDimensions().nbDims, 3, + platform::errors::InvalidArgument( + "dims size of input BiasQK of multihead_matmul should be 3")); + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("alpha"), true, + platform::errors::PreconditionNotMet( + "attribute alpha of multihead_matmul should not be empty")); + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("head_number"), true, + platform::errors::PreconditionNotMet( + "attribute head_number of multihead_matmul should not be empty")); + + // Declare attributes + const bool transpose_q = + op_desc.HasAttr("transpose_Q") + ? 
boost::get(op_desc.GetAttr("transpose_Q")) + : false; + const bool transpose_k = + op_desc.HasAttr("transpose_K") + ? boost::get(op_desc.GetAttr("transpose_K")) + : true; + const bool transpose_v = + op_desc.HasAttr("transpose_V") + ? boost::get(op_desc.GetAttr("transpose_V")) + : false; + const float alpha = boost::get(op_desc.GetAttr("alpha")); + const int head_number = boost::get(op_desc.GetAttr("head_number")); + + nvinfer1::Dims q_shape = Q->getDimensions(); + int seq_len = q_shape.d[0]; + int size_per_head = q_shape.d[1] / head_number; + std::string alpha_name = op_desc.Output("Out")[0] + "_alpha"; + framework::DDim alpha_dim = framework::make_ddim({1}); + std::unique_ptr alpha_t(new framework::LoDTensor()); + alpha_t->Resize(alpha_dim); + float* alpha_data = alpha_t->mutable_data(platform::CPUPlace()); + alpha_data[0] = alpha; + + TensorRTEngine::Weight scale{nvinfer1::DataType::kFLOAT, + static_cast(alpha_data), 1}; + TensorRTEngine::Weight shift{nvinfer1::DataType::kFLOAT, nullptr, 0}; + TensorRTEngine::Weight power{nvinfer1::DataType::kFLOAT, nullptr, 0}; + + auto* bias_q_t = BiasQ->GetMutable(); + auto* bias_k_t = BiasK->GetMutable(); + auto* bias_v_t = BiasV->GetMutable(); + float* bias_q_cpu_data = engine_->GetWeightCPUData( + op_desc.Input("BiasQ").front(), bias_q_t, false); + float* bias_k_cpu_data = engine_->GetWeightCPUData( + op_desc.Input("BiasK").front(), bias_k_t, false); + float* bias_v_cpu_data = engine_->GetWeightCPUData( + op_desc.Input("BiasV").front(), bias_v_t, false); + std::unique_ptr bias_q_tensor( + new framework::LoDTensor()); + std::unique_ptr bias_k_tensor( + new framework::LoDTensor()); + std::unique_ptr bias_v_tensor( + new framework::LoDTensor()); + bias_q_tensor->Resize(bias_q_t->dims()); + bias_k_tensor->Resize(bias_k_t->dims()); + bias_v_tensor->Resize(bias_v_t->dims()); + platform::CPUPlace cpu_place; + TensorCopySync((*bias_q_t), cpu_place, bias_q_tensor.get()); + TensorCopySync((*bias_k_t), cpu_place, bias_k_tensor.get()); + TensorCopySync((*bias_v_t), cpu_place, bias_v_tensor.get()); + + TensorRTEngine::Weight scale_weights_q{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + TensorRTEngine::Weight shift_weights_q{ + nvinfer1::DataType::kFLOAT, static_cast(bias_q_cpu_data), + bias_q_tensor->memory_size() / sizeof(float)}; + TensorRTEngine::Weight power_weights_q{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + TensorRTEngine::Weight scale_weights_k{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + TensorRTEngine::Weight shift_weights_k{ + nvinfer1::DataType::kFLOAT, static_cast(bias_k_cpu_data), + bias_k_tensor->memory_size() / sizeof(float)}; + TensorRTEngine::Weight power_weights_k{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + TensorRTEngine::Weight scale_weights_v{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + TensorRTEngine::Weight shift_weights_v{ + nvinfer1::DataType::kFLOAT, static_cast(bias_v_cpu_data), + bias_v_tensor->memory_size() / sizeof(float)}; + TensorRTEngine::Weight power_weights_v{nvinfer1::DataType::kFLOAT, nullptr, + 0}; + + auto* q_eltadd_layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *Q, nvinfer1::ScaleMode::kCHANNEL, + shift_weights_q.get(), scale_weights_q.get(), power_weights_q.get()); + auto* k_eltadd_layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *K, nvinfer1::ScaleMode::kCHANNEL, + shift_weights_k.get(), scale_weights_k.get(), power_weights_k.get()); + auto* v_eltadd_layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *V, nvinfer1::ScaleMode::kCHANNEL, + shift_weights_v.get(), scale_weights_v.get(), power_weights_v.get()); + auto* 
v_transpose_reshape_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(v_eltadd_layer->getOutput(0))); + auto* q_transpose_reshape_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(q_eltadd_layer->getOutput(0))); + auto* k_transpose_reshape_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(k_eltadd_layer->getOutput(0))); + + nvinfer1::Dims3 head_reshape_dim(seq_len, head_number, size_per_head); + v_transpose_reshape_layer->setReshapeDimensions(head_reshape_dim); + v_transpose_reshape_layer->setSecondTranspose({1, 0, 2}); + q_transpose_reshape_layer->setReshapeDimensions(head_reshape_dim); + q_transpose_reshape_layer->setSecondTranspose({1, 0, 2}); + k_transpose_reshape_layer->setReshapeDimensions(head_reshape_dim); + k_transpose_reshape_layer->setSecondTranspose({1, 0, 2}); + + auto* q_scale_layer = TRT_ENGINE_ADD_LAYER( + engine_, Scale, *(q_transpose_reshape_layer->getOutput(0)), + nvinfer1::ScaleMode::kUNIFORM, shift.get(), scale.get(), power.get()); + auto* qk_matmul_layer = TRT_ENGINE_ADD_LAYER( + engine_, MatrixMultiply, *(q_scale_layer->getOutput(0)), transpose_q, + *(k_transpose_reshape_layer->getOutput(0)), transpose_k); + auto* qk_eltadd_layer = TRT_ENGINE_ADD_LAYER( + engine_, ElementWise, *BiasQK, *(qk_matmul_layer->getOutput(0)), + nvinfer1::ElementWiseOperation::kSUM); + auto* softmax_layer = TRT_ENGINE_ADD_LAYER( + engine_, SoftMax, *(qk_eltadd_layer->getOutput(0))); + softmax_layer->setAxes(4); + auto* qkv_matmul_layer = TRT_ENGINE_ADD_LAYER( + engine_, MatrixMultiply, *(softmax_layer->getOutput(0)), false, + *(v_transpose_reshape_layer->getOutput(0)), transpose_v); + auto* qkv_transpose_reshape_layer = TRT_ENGINE_ADD_LAYER( + engine_, Shuffle, *(qkv_matmul_layer->getOutput(0))); + nvinfer1::Dims2 qkv_reshape_dim(seq_len, head_number * size_per_head); + qkv_transpose_reshape_layer->setFirstTranspose({1, 0, 2}); + qkv_transpose_reshape_layer->setReshapeDimensions(qkv_reshape_dim); + + engine_->SetWeights(alpha_name, std::move(alpha_t)); + engine_->SetWeights(op_desc.Input("BiasQ").front(), + std::move(bias_q_tensor)); + engine_->SetWeights(op_desc.Input("BiasK").front(), + std::move(bias_k_tensor)); + engine_->SetWeights(op_desc.Input("BiasV").front(), + std::move(bias_v_tensor)); + + auto output_name = op_desc.Output("Out").front(); + RreplenishLayerAndOutput(qkv_transpose_reshape_layer, "multihead_matmul", + {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(multihead_matmul, MultiheadMatMulOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 3a2deae360..ca5e1b8a74 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -43,18 +43,27 @@ TRT_DT FluidDataType2TRT(FluidDT type) { default: return TRT_DT::kINT32; } - PADDLE_THROW("unkown type"); + PADDLE_THROW(platform::errors::InvalidArgument( + "unknown fluid datatype in TRT op converter")); return TRT_DT::kINT32; } -nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape) { +nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape, + std::string input) { PADDLE_ENFORCE_GT(shape.size(), 1UL, - "TensorRT' tensor input requires at least 2 dimensions"); + platform::errors::InvalidArgument( + "TensorRT's tensor input requires at least 2 " + "dimensions, but input %s has %d dims.", + input, shape.size())); PADDLE_ENFORCE_LE(shape.size(), 4UL, - "TensorRT' tensor input requires at 
most 4 dimensions"); - PADDLE_ENFORCE(shape.size() == 4UL || shape.size() == 2UL); + platform::errors::InvalidArgument( + "TensorRT's tensor input requires at most 4 " + "dimensions, but input %s has %d dims.", + input, shape.size())); if (shape.size() == 4UL) return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]); + else if (shape.size() == 3UL) + return nvinfer1::Dims2(shape[1], shape[2]); return nvinfer1::DimsCHW(shape[1], 1, 1); } @@ -162,7 +171,7 @@ class OpConverter { engine->DeclareInput( input, FluidDataType2TRT( var->Proto()->type().lod_tensor().tensor().data_type()), - Vec2TRT_Dims(var_shape)); + Vec2TRT_Dims(var_shape, input)); } framework::proto::BlockDesc* block_proto = block_desc->Proto(); ConvertBlock(*block_proto, parameters, scope, engine); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index aa9f330bc7..462c6fb497 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -52,7 +52,10 @@ struct SimpleOpTypeSetTeller : public Teller { "fc", "shuffle_channel", "swish", - "split"}}; + "split", + "gelu", + "layer_norm", + "multihead_matmul"}}; }; bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) { diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index b505fa4662..83efecc0bf 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -1,5 +1,5 @@ nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu - prelu_op_plugin.cu trt_plugin_factory.cc - pool_op_plugin.cu swish_op_plugin.cu + prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu + pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu DEPS enforce tensorrt_engine prelu) diff --git a/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu new file mode 100644 index 0000000000..b31691f9cb --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.cu @@ -0,0 +1,76 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include "paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +// constants for approximating the normal cdf +constexpr float A = 1.41421356237309504; // sqrt(2) + +GeluPlugin* CreateGeluPluginDeserialize(const void* buffer, size_t length) { + return new GeluPlugin(buffer, length); +} +REGISTER_TRT_PLUGIN("gelu plugin", CreateGeluPluginDeserialize); + +nvinfer1::Dims GeluPlugin::getOutputDimensions(int index, + const nvinfer1::Dims* in_dims, + int nb_inputs) { + assert(nb_inputs == 1); + assert(index < this->getNbOutputs()); + nvinfer1::Dims const& input_dims = in_dims[0]; + nvinfer1::Dims output_dims = input_dims; + return output_dims; +} + +template +__global__ void geluKernel(const T a, int n, const T* input, T* output) { + const int idx = blockIdx.x * TPB + threadIdx.x; + if (idx < n) { + const T in = input[idx]; + const T cdf = 0.5 * (1.0 + erf(in * 0.5 * a)); + output[idx] = in * cdf; + } +} + +int computeGelu(cudaStream_t stream, int n, const float* input, float* output) { + constexpr int blockSize = 256; + const int gridSize = (n + blockSize - 1) / blockSize; + geluKernel<<>>(A, n, input, + output); + cudaError_t error = cudaGetLastError(); + if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error); + return 0; +} + +int GeluPlugin::enqueue(int batchSize, const void* const* inputs, + void** outputs, void*, cudaStream_t stream) { + int status = -1; + const float* input = static_cast(inputs[0]); + float* output = static_cast(outputs[0]); + status = computeGelu(stream, input_volume_ * batchSize, input, output); + return status; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h new file mode 100644 index 0000000000..7c9aeed5f5 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/gelu_op_plugin.h @@ -0,0 +1,72 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class GeluPlugin : public PluginTensorRT { + protected: + size_t getSerializationSize() override { + return getBaseSerializationSize() + SerializedSize(getPluginType()) + + SerializedSize(input_volume_); + } + + // TRT will call this func to serialize the configuration of TRT + // It should not be called by users. 
+ void serialize(void *buffer) override { + SerializeValue(&buffer, getPluginType()); + serializeBase(buffer); + SerializeValue(&buffer, input_volume_); + } + + public: + explicit GeluPlugin(size_t input_volume) : input_volume_(input_volume) {} + + // It was used for tensorrt deserialization. + // It should not be called by users. + GeluPlugin(void const *serialData, size_t serialLength) { + deserializeBase(serialData, serialLength); + DeserializeValue(&serialData, &serialLength, &input_volume_); + } + + ~GeluPlugin() {} + + int initialize() override { return 0; } + + GeluPlugin *clone() const override { return new GeluPlugin(input_volume_); } + + const char *getPluginType() const override { return "gelu_plugin"; } + int getNbOutputs() const override { return 1; } + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, + int nbInputDims) override; + int enqueue(int batchSize, const void *const *inputs, void **outputs, + void *workspace, cudaStream_t stream) override; + + private: + size_t input_volume_; +}; + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu new file mode 100644 index 0000000000..7c905a245a --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu @@ -0,0 +1,84 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include "glog/logging.h" +#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" +#include "paddle/fluid/operators/layer_norm_op.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +LayerNormPlugin *CreateLayerNormPluginDeserialize(const void *buffer, + size_t length) { + return new LayerNormPlugin(buffer, length); +} +REGISTER_TRT_PLUGIN("layer_norm_plugin", CreateLayerNormPluginDeserialize); + +int LayerNormPlugin::initialize() { return 0; } + +nvinfer1::Dims LayerNormPlugin::getOutputDimensions( + int index, const nvinfer1::Dims *inputDims, int nbInputs) { + assert(nbInputs == 1); + assert(index < this->getNbOutputs()); + nvinfer1::Dims const &input_dims = inputDims[0]; + nvinfer1::Dims output_dims = input_dims; + return output_dims; +} + +int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs, + void **outputs, void *workspace, + cudaStream_t stream) { + const auto &input_dims = this->getInputDims(0); + const float *input = reinterpret_cast(inputs[0]); + float *output = reinterpret_cast(outputs)[0]; + int begin_norm_axis = begin_norm_axis_; + float eps = eps_; + int c = input_dims.d[begin_norm_axis - 1]; + + scale_t.Resize(framework::make_ddim({c})); + bias_t.Resize(framework::make_ddim({c})); + mean_t.Resize(framework::make_ddim(mean_shape_)); + variance_t.Resize(framework::make_ddim(variance_shape_)); + int device_id; + cudaGetDevice(&device_id); + float *scale_d = scale_t.mutable_data(platform::CUDAPlace(device_id)); + float *bias_d = bias_t.mutable_data(platform::CUDAPlace(device_id)); + float *mean_d = mean_t.mutable_data(platform::CUDAPlace(device_id)); + float *variance_d = + variance_t.mutable_data(platform::CUDAPlace(device_id)); + cudaMemcpyAsync(scale_d, scale_.data(), sizeof(float) * c, + cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * c, + cudaMemcpyHostToDevice, stream); + std::vector input_shape; + input_shape.push_back(batch_size); + for (int i = 0; i < input_dims.nbDims; i++) { + input_shape.push_back(input_dims.d[i]); + } + paddle::operators::LayerNormDirectCUDAFunctor layer_norm; + layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d, + variance_d, begin_norm_axis, eps); + return cudaGetLastError() != cudaSuccess; +} + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h new file mode 100644 index 0000000000..050ef3b77d --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h @@ -0,0 +1,110 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/inference/tensorrt/engine.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class LayerNormPlugin : public PluginTensorRT { + std::vector bias_; + std::vector scale_; + framework::Tensor scale_t; + framework::Tensor bias_t; + framework::Tensor mean_t; + framework::Tensor variance_t; + int begin_norm_axis_; + float eps_; + std::vector mean_shape_; + std::vector variance_shape_; + + protected: + size_t getSerializationSize() override { + return getBaseSerializationSize() + SerializedSize(bias_) + + SerializedSize(scale_) + SerializedSize(begin_norm_axis_) + + SerializedSize(eps_) + SerializedSize(mean_shape_) + + SerializedSize(variance_shape_) + SerializedSize(getPluginType()); + } + + // TRT will call this func when we need to serialize the configuration of + // tensorrt. + // It should not be called by users. + void serialize(void *buffer) override { + SerializeValue(&buffer, getPluginType()); + serializeBase(buffer); + SerializeValue(&buffer, bias_); + SerializeValue(&buffer, scale_); + SerializeValue(&buffer, begin_norm_axis_); + SerializeValue(&buffer, eps_); + SerializeValue(&buffer, mean_shape_); + SerializeValue(&buffer, variance_shape_); + } + + public: + LayerNormPlugin(const float *bias, const int bias_num, const float *scale, + const int scale_num, int begin_norm_axis, float eps, + std::vector mean_shape, + std::vector variance_shape) + : begin_norm_axis_(begin_norm_axis), + eps_(eps), + mean_shape_(mean_shape), + variance_shape_(variance_shape) { + bias_.resize(bias_num); + scale_.resize(scale_num); + std::copy(bias, bias + bias_num, bias_.data()); + std::copy(scale, scale + scale_num, scale_.data()); + } + + // It was used for tensorrt deserialization. + // It should not be called by users. 
+ LayerNormPlugin(void const *serialData, size_t serialLength) { + deserializeBase(serialData, serialLength); + DeserializeValue(&serialData, &serialLength, &bias_); + DeserializeValue(&serialData, &serialLength, &scale_); + DeserializeValue(&serialData, &serialLength, &begin_norm_axis_); + DeserializeValue(&serialData, &serialLength, &eps_); + DeserializeValue(&serialData, &serialLength, &mean_shape_); + DeserializeValue(&serialData, &serialLength, &variance_shape_); + } + ~LayerNormPlugin() {} + int initialize() override; + + LayerNormPlugin *clone() const override { + return new LayerNormPlugin(bias_.data(), bias_.size(), scale_.data(), + scale_.size(), begin_norm_axis_, eps_, + mean_shape_, variance_shape_); + } + + const char *getPluginType() const override { return "layer_norm_plugin"; } + int getNbOutputs() const override { return 1; } + nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims *inputs, + int nbInputDims) override; + int enqueue(int batchSize, const void *const *inputs, void **outputs, + void *workspace, cudaStream_t stream) override; +}; + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 45ba4fbe5a..d58606be0b 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -330,6 +330,9 @@ if(WITH_GPU AND TENSORRT_FOUND) inference_analysis_test(trt_resnext_test SRCS trt_resnext_test.cc EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models) + inference_analysis_test(trt_bert_test SRCS trt_bert_test.cc + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} + ARGS --infer_model=${BERT_INSTALL_DIR}/model) inference_analysis_test(trt_fc_prelu_test SRCS trt_fc_prelu_test.cc EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models) diff --git a/paddle/fluid/inference/tests/api/trt_bert_test.cc b/paddle/fluid/inference/tests/api/trt_bert_test.cc new file mode 100644 index 0000000000..818c0bfc0a --- /dev/null +++ b/paddle/fluid/inference/tests/api/trt_bert_test.cc @@ -0,0 +1,85 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include +#include +#include + +#include "paddle/fluid/inference/tests/api/trt_test_helper.h" + +namespace paddle { +namespace inference { + +TEST(TensorRT, split_converter) { + AnalysisConfig config; + int batch_size = 1; + config.SetModel(FLAGS_infer_model); + config.EnableUseGpu(1200, 0); + config.SwitchUseFeedFetchOps(false); + config.EnableTensorRtEngine(1 << 30, batch_size, 10, + AnalysisConfig::Precision::kFloat32, false, + false); + auto predictor = CreatePaddlePredictor(config); + int64_t i0[128] = { + 96, 54, 78, 37, 106, 35, 122, 33, 95, 63, 81, 60, 65, 68, 45, 96, + 117, 61, 43, 15, 12, 64, 91, 100, 90, 74, 99, 23, 22, 91, 83, 13, + 28, 71, 59, 15, 40, 26, 66, 18, 31, 87, 85, 11, 55, 67, 28, 126, + 7, 89, 39, 67, 88, 29, 66, 38, 98, 1, 66, 38, 95, 56, 48, 95, + 9, 38, 90, 82, 101, 6, 75, 46, 42, 89, 98, 12, 6, 101, 82, 55, + 81, 113, 33, 91, 44, 73, 41, 39, 12, 113, 13, 86, 36, 91, 53, 68, + 103, 67, 65, 92, 27, 76, 24, 107, 54, 94, 63, 10, 15, 32, 91, 45, + 37, 126, 49, 118, 73, 127, 122, 119, 28, 96, 92, 79, 21, 90, 11, 40}; + int64_t i1[128] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, + 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, + 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, + 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, + 120, 121, 122, 123, 124, 125, 126, 127}; + int64_t i2[128] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + float i3[128 * 128] = {0.0}; + int64_t i4[1] = {0}; + + auto input_names = predictor->GetInputNames(); + + auto input_t0 = predictor->GetInputTensor(input_names[0]); + input_t0->Reshape({batch_size, 128, 1}); + input_t0->copy_from_cpu(i0); + auto input_t1 = predictor->GetInputTensor(input_names[1]); + input_t1->Reshape({batch_size, 128, 1}); + input_t1->copy_from_cpu(i1); + auto input_t2 = predictor->GetInputTensor(input_names[2]); + input_t2->Reshape({batch_size, 128, 1}); + input_t2->copy_from_cpu(i2); + auto input_t3 = predictor->GetInputTensor(input_names[3]); + input_t3->Reshape({batch_size, 128, 128}); + input_t3->copy_from_cpu(i3); + auto input_t4 = predictor->GetInputTensor(input_names[4]); + input_t4->Reshape({batch_size, 1}); + input_t4->copy_from_cpu(i4); + + ASSERT_TRUE(predictor->ZeroCopyRun()); +} + +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/operators/layer_norm_op.cu b/paddle/fluid/operators/layer_norm_op.cu index 22343d7724..e42a1f4803 100644 --- a/paddle/fluid/operators/layer_norm_op.cu +++ b/paddle/fluid/operators/layer_norm_op.cu @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include +#include +#include +#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/operators/layer_norm_op.h" namespace paddle { @@ -427,6 +430,29 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale, } } +template +void LayerNormDirectCUDAFunctor::operator()(cudaStream_t stream, + const T *input, + std::vector input_shape, + const T *bias, const T *scale, + T *output, T *mean, T *variance, + int begin_norm_axis, float eps) { + const auto x_dims = framework::make_ddim(input_shape); + auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); + int batch_size = static_cast(matrix_dim[0]); + int feature_size = static_cast(matrix_dim[1]); + switch (GetDesiredBlockDim(feature_size)) { + FIXED_BLOCK_DIM_CASE( + LayerNormForward<<>>( + input, scale, bias, output, mean, variance, eps, feature_size)); + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "Product from begin_norm_axis to end in layer_norm must be larger " + "than 1")); + break; + } +} + template class LayerNormKernel : public framework::OpKernel { @@ -512,7 +538,7 @@ class LayerNormGradKernel batch_size, feature_size, stream); } }; - +template class LayerNormDirectCUDAFunctor; #undef FIXED_BLOCK_DIM_CASE_BASE #undef FIXED_BLOCK_DIM_CASE } // namespace operators diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h index 5907d1d727..89f5dccc1a 100644 --- a/paddle/fluid/operators/layer_norm_op.h +++ b/paddle/fluid/operators/layer_norm_op.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" @@ -151,6 +153,17 @@ using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; using DataLayout = framework::DataLayout; +#ifdef PADDLE_WITH_CUDA +template +class LayerNormDirectCUDAFunctor { + public: + void operator()(cudaStream_t stream, const T* input, + std::vector input_shape, const T* bias, const T* scale, + T* output, T* mean, T* variance, int begin_norm_axis, + float eps); +}; +#endif + template class LayerNormKernel : public framework::OpKernel { public: -- GitLab