From 6a87f61bfc58887e31ef7d53e3b517eb8ff95b83 Mon Sep 17 00:00:00 2001 From: ReeseWang Date: Sun, 19 Jul 2020 22:55:45 +0800 Subject: [PATCH] add trt stack op, test=develop --- .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/tensorrt/convert/stack_op.cc | 81 +++++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 56 +++---- .../tensorrt/plugin/stack_op_plugin.cu | 154 ++++++++++++++++++ .../tensorrt/plugin/stack_op_plugin.h | 95 +++++++++++ 5 files changed, 359 insertions(+), 28 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/stack_op.cc create mode 100644 paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu create mode 100644 paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index df4b0079c79..af915acc0e4 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1035,4 +1035,5 @@ USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm); USE_TRT_CONVERTER(skip_layernorm); USE_TRT_CONVERTER(slice); USE_TRT_CONVERTER(scale); +USE_TRT_CONVERTER(stack); #endif diff --git a/paddle/fluid/inference/tensorrt/convert/stack_op.cc b/paddle/fluid/inference/tensorrt/convert/stack_op.cc new file mode 100644 index 00000000000..c01b8704935 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/stack_op.cc @@ -0,0 +1,81 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h" + +#include + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * Stack converter from fluid to tensorRT. + */ +class StackOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert fluid stack op to tensorrt stack layer"; + + framework::OpDesc op_desc(op, nullptr); + auto input = op_desc.Input("X"); + int input_num = input.size(); + nvinfer1::ITensor** inputs = + (nvinfer1::ITensor**)malloc(input_num * sizeof(nvinfer1::ITensor*)); + + for (int i = 0; i < input_num; ++i) { + inputs[i] = engine_->GetITensor(input[i]); + } + + auto idim = inputs[0]->getDimensions(); + std::cerr << "Stack input: " << idim.nbDims << " " << idim.d[0] << " " + << idim.d[1] << " " << idim.d[2] << std::endl; + + int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis")); + if (axis < 0) { + axis = axis + inputs[0]->getDimensions().nbDims + 1; + } + + nvinfer1::ILayer* layer = nullptr; + if (engine_->with_dynamic_shape()) { +#if IS_TRT_VERSION_GE(6000) + plugin::StackPluginDynamic* plugin = + new plugin::StackPluginDynamic(axis, input_num); + layer = engine_->AddPluginV2(inputs, input_num, plugin); + assert(layer != nullptr); +#else + PADDLE_THROW(platform::errors::Fatal( + "You are running the TRT Dynamic Shape mode, need to confirm that " + "your TRT version is no less than 6.0")); +#endif + } else { + PADDLE_THROW(platform::errors::Fatal( + "You are running the Ernie(Bert) model in static" + "shape mode, which is not supported for the time being.\n" + "You can use the config.SetTRTDynamicShapeInfo(...) interface" + " to set the shape information to run the dynamic shape mode.")); + } + auto output_name = op_desc.Output("Y").front(); + RreplenishLayerAndOutput(layer, "stack", {output_name}, test_mode); + free(inputs); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(stack, StackOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index a7bb7c8c4fc..5f4398895b3 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -55,34 +55,34 @@ struct SimpleOpTypeSetTeller : public Teller { "fc", "relu6", "concat"}; - std::unordered_set teller_set{ - "mul", - "conv2d", - "pool2d", - "relu", - "softmax", - "sigmoid", - "hard_swish", - "depthwise_conv2d", - "batch_norm", - "concat", - "tanh", - "pad", - "elementwise_add", - "elementwise_mul", - "dropout", - "prelu", - "conv2d_transpose", - "leaky_relu", - "fc", - "shuffle_channel", - "swish", - "split", - "instance_norm", - "gelu", - "layer_norm", - "scale", - }; + std::unordered_set teller_set{"mul", + "conv2d", + "pool2d", + "relu", + "softmax", + "sigmoid", + "hard_swish", + "depthwise_conv2d", + "batch_norm", + "concat", + "tanh", + "pad", + "elementwise_add", + "elementwise_mul", + "dropout", + "prelu", + "conv2d_transpose", + "leaky_relu", + "fc", + "shuffle_channel", + "swish", + "split", + "instance_norm", + "gelu", + "layer_norm", + "scale", + "slice", + "stack"}; }; bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc, diff --git a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu new file mode 100644 index 00000000000..f707a2cad19 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu @@ -0,0 +1,154 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +// Dynamic Plugin below. +#if IS_TRT_VERSION_GE(6000) +size_t StackPluginDynamic::getSerializationSize() const { return 0; } + +void StackPluginDynamic::serialize(void* buffer) const {} + +nvinfer1::DimsExprs StackPluginDynamic::getOutputDimensions( + int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) { + nvinfer1::DimsExprs output(inputs[0]); + output.nbDims = inputs[0].nbDims + 1; + + for (int i = inputs[0].nbDims; i > axis_; --i) { + output.d[i] = inputs[0].d[i - 1]; + } + output.d[axis_] = expr_builder.constant(nb_inputs); + return output; +} + +bool StackPluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs, + int nb_outputs) { + PADDLE_ENFORCE_NOT_NULL( + in_out, platform::errors::InvalidArgument( + "The input of stack plugin should not be nullptr.")); + + PADDLE_ENFORCE_LT( + pos, nb_inputs + nb_outputs, + platform::errors::InvalidArgument("The pos(%d) should be less than the " + "num(%d) of the input and the output.", + pos, nb_inputs + nb_outputs)); + + const nvinfer1::PluginTensorDesc& in = in_out[pos]; + if (pos == 0) { +#ifdef SUPPORTS_CUDA_FP16 + return (in.type == nvinfer1::DataType::kFLOAT || + in.type == nvinfer1::DataType::kHALF) && + (in.format == nvinfer1::TensorFormat::kLINEAR); +#else + return (in.type == nvinfer1::DataType::kFLOAT) && + (in.format == nvinfer1::TensorFormat::kLINEAR); +#endif + } + const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1]; + // output + return in.type == prev.type && in.format == prev.format; +} + +nvinfer1::DataType StackPluginDynamic::getOutputDataType( + int index, const nvinfer1::DataType* input_types, int nb_inputs) const { + PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument( + "The index should be equal to 0")); + return input_types[0]; +} + +template +__global__ void StackKernel(const T* const* input, T* output, int num_stack, + int base_unit) { + int stack_id = blockIdx.x; + int lead_id = blockIdx.y; + + for (int i = threadIdx.x; i < base_unit; i += blockDim.x) { + output[lead_id * num_stack * base_unit + stack_id * base_unit + i] = + input[stack_id][lead_id * base_unit + i]; + } +} + +int StackPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc, + const nvinfer1::PluginTensorDesc* output_desc, + const void* const* inputs, void* const* outputs, + void* workspace, cudaStream_t stream) { + auto input_dims = input_desc[0].dims; // (batch, seq, seq) + auto out_dims = output_desc[0].dims; // (batch, num_head, seq, seq) + auto out_num_dims = out_dims.nbDims; + + int base_unit = 1; + for (int i = axis_ + 1; i < out_num_dims; ++i) { + PADDLE_ENFORCE_GT(out_dims.d[i], 0, + platform::errors::InvalidArgument( + "Input dimensions should be greater than 0")); + base_unit *= out_dims.d[i]; + } + + int lead_unit = 1; + for (int i = 0; i < axis_; ++i) { + PADDLE_ENFORCE_GT(out_dims.d[i], 0, + platform::errors::InvalidArgument( + "Input dimensions should be greater than 0")); + lead_unit *= out_dims.d[i]; + } + + cudaMemcpyAsync(reinterpret_cast(in_ptr_gpu_), + reinterpret_cast(inputs), + sizeof(void*) * out_dims.d[axis_], cudaMemcpyHostToDevice, + stream); + + int num_stacks = out_dims.d[axis_]; + dim3 num_blocks(num_stacks, lead_unit); + int num_threads = 256; + auto infer_type = input_desc[0].type; + + if (infer_type == nvinfer1::DataType::kFLOAT) { + float* output = static_cast(outputs[0]); + StackKernel<<>>( + reinterpret_cast(in_ptr_gpu_), output, num_stacks, + base_unit); + } else if (infer_type == nvinfer1::DataType::kHALF) { +#ifdef SUPPORTS_CUDA_FP16 + __half* output = static_cast<__half*>(outputs[0]); + StackKernel<__half><<>>( + reinterpret_cast(in_ptr_gpu_), output, num_stacks, + base_unit); +#else + PADDLE_THROW(platform::errors::Fatal( + "The cuda archs you specific should greater than 600.")); +#endif + } else { + PADDLE_THROW( + platform::errors::Fatal("The Stack TRT Plugin's input type only " + "support float or half currently.")); + } + return cudaGetLastError() != cudaSuccess; +} +#endif + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h new file mode 100644 index 00000000000..113eda42d35 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h @@ -0,0 +1,95 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +#if IS_TRT_VERSION_GE(6000) +class StackPluginDynamic : public DynamicPluginTensorRT { + public: + StackPluginDynamic(int axis, int num_stack) + : axis_(axis), num_stack_(num_stack) { + int device_id; + cudaGetDevice(&device_id); + in_ptr_tensor_.Resize({num_stack}); + in_ptr_gpu_ = + in_ptr_tensor_.mutable_data(platform::CUDAPlace(device_id)); + } + StackPluginDynamic(void const* serialData, size_t serialLength) {} + + ~StackPluginDynamic() {} + nvinfer1::IPluginV2DynamicExt* clone() const override { + return new StackPluginDynamic(axis_, num_stack_); + } + + const char* getPluginType() const override { return "stack_plugin"; } + int getNbOutputs() const override { return 1; } + int initialize() override { return 0; } + + size_t getSerializationSize() const override; + void serialize(void* buffer) const override; + + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) override; + + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, int nbOutputs) override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) override {} + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const override { + return 0; + } + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) override; + + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const override; + + void destroy() override { delete this; } + + private: + int axis_; + int num_stack_; + framework::Tensor in_ptr_tensor_; + int64_t* in_ptr_gpu_; +}; +#endif + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle -- GitLab