From d43bb7f25db0eb1dc7728f6845d987c5a3b1181c Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Fri, 28 Aug 2020 14:37:59 +0800 Subject: [PATCH] convert mask with fp32/fp16 support --- .../fluid/inference/api/analysis_predictor.cc | 2 +- .../tensorrt/convert/emb_eltwise_layernorm.cc | 19 --- .../inference/tensorrt/convert/mul_op.cc | 27 ++-- .../tensorrt/convert/multihead_matmul_op.cc | 57 +++------ .../inference/tensorrt/convert/scale_op.cc | 8 -- paddle/fluid/inference/tensorrt/op_teller.cc | 4 +- .../inference/tensorrt/plugin/CMakeLists.txt | 2 +- .../tensorrt/plugin/cast_int_plugin.cu | 85 ------------- .../tensorrt/plugin/cast_int_plugin.h | 120 ------------------ .../tensorrt/plugin/convert_mask_plugin.cu | 110 ++++++++++------ .../tensorrt/plugin/convert_mask_plugin.h | 25 +++- 11 files changed, 127 insertions(+), 332 deletions(-) delete mode 100644 paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.cu delete mode 100644 paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 67e3e237bbf..cd396159f2c 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1029,7 +1029,7 @@ USE_TRT_CONVERTER(elementwise_mul_tensor); USE_TRT_CONVERTER(elementwise_max_tensor); USE_TRT_CONVERTER(elementwise_min_tensor); USE_TRT_CONVERTER(elementwise_pow_tensor); -USE_TRT_CONVERTER(mul); +USE_TRT_CONVERTER(matmul); USE_TRT_CONVERTER(conv2d); USE_TRT_CONVERTER(relu); USE_TRT_CONVERTER(sigmoid); diff --git a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc index cb789a8cd35..88384b39b0c 100644 --- a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc @@ -11,7 +11,6 @@ limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/helper.h" -#include "paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h" namespace paddle { @@ -81,24 +80,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { nvinfer1::ILayer* layer = nullptr; if (engine_->with_dynamic_shape()) { - auto pos_tensor = engine_->GetITensor("eval_placeholder_2"); - plugin::CastIntPluginDynamic* cast_plugin = - new plugin::CastIntPluginDynamic(); - auto cast_layer = engine_->AddPluginV2(&pos_tensor, 1, cast_plugin); - - auto casted_pos_tensor = cast_layer->getOutput(0); - auto reshape_layer = - TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *casted_pos_tensor); - - nvinfer1::Dims2 reshape_dim(0, 0); - nvinfer1::Permutation perm{1, 0, 2}; - reshape_layer->setFirstTranspose(perm); - reshape_layer->setReshapeDimensions(reshape_dim); - auto imask_layer = - TRT_ENGINE_ADD_LAYER(engine_, Reduce, *reshape_layer->getOutput(0), - nvinfer1::ReduceOperation::kMAX, 1, false); - engine_->SetITensor("imask_tensor", imask_layer->getOutput(0)); - plugin::DynamicPluginTensorRT* plugin = nullptr; plugin = new plugin::EmbEltwiseLayernormPluginDynamic( input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden, diff --git a/paddle/fluid/inference/tensorrt/convert/mul_op.cc b/paddle/fluid/inference/tensorrt/convert/mul_op.cc index 5b6aaad4983..69192fbe13a 100644 --- a/paddle/fluid/inference/tensorrt/convert/mul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/mul_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h" namespace paddle { namespace inference { @@ -31,17 +32,27 @@ class MulOpConverter : public OpConverter { // Declare inputs auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]); + + bool transpose_x = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_X")); + bool transpose_y = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_Y")); + +#ifdef USE_NVINFER_PLUGIN + nvinfer1::DataType type = (engine_->WithFp16() == 1) + ? nvinfer1::DataType::kHALF + : nvinfer1::DataType::kFLOAT; + plugin::ConvertMaskPluginDynamic* plugin = + new plugin::ConvertMaskPluginDynamic(type); + auto convert_mask_layer = engine_->AddPluginV2(&input1, 1, plugin); + engine_->SetITensor("qkv_plugin_mask", convert_mask_layer->getOutput(0)); +#endif + // Both the input1 and input2 do not need transpose. auto* layer = TRT_ENGINE_ADD_LAYER( - engine_, MatrixMultiply, *const_cast(input1), false, - *const_cast(input2), false); + engine_, MatrixMultiply, *const_cast(input1), + transpose_x, *const_cast(input2), transpose_y); auto output_name = op_desc.Output("Out")[0]; - engine_->SetITensor(output_name, layer->getOutput(0)); - if (test_mode) { // the test framework can not determine which is the - // output, so place the declaration inside. - engine_->DeclareOutput(output_name); - } + RreplenishLayerAndOutput(layer, "matmul", {output_name}, test_mode); } }; @@ -49,4 +60,4 @@ class MulOpConverter : public OpConverter { } // namespace inference } // namespace paddle -REGISTER_TRT_OP_CONVERTER(mul, MulOpConverter); +REGISTER_TRT_OP_CONVERTER(matmul, MulOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index a7cd569f9b4..d30f0ff1293 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -113,33 +113,10 @@ class MultiheadMatMulOpConverter : public OpConverter { static_cast(bias_data), static_cast(bias_t->numel())}; - nvinfer1::Permutation permutation{0, 1, 2, 3, 4}; - auto trans_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); - trans_layer->setFirstTranspose(permutation); - - auto* fc_layer = TRT_ENGINE_ADD_LAYER( - engine_, FullyConnected, *trans_layer->getOutput(0), n, weight, bias); - /* - auto pos_tensor = engine_->GetITensor("eval_placeholder_2"); - plugin::CastIntPluginDynamic* cast_plugin = - new plugin::CastIntPluginDynamic(); - auto cast_layer = engine_->AddPluginV2(&pos_tensor, 1, cast_plugin); - - auto casted_pos_tensor = cast_layer->getOutput(0); - auto reshape_layer = - TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *casted_pos_tensor); - - nvinfer1::Dims2 reshape_dim(0, 0); - nvinfer1::Permutation perm{1, 0, 2}; - reshape_layer->setFirstTranspose(perm); - reshape_layer->setReshapeDimensions(reshape_dim); - auto reduce_layer = - TRT_ENGINE_ADD_LAYER(engine_, Reduce, - *reshape_layer->getOutput(0), - nvinfer1::ReduceOperation::kMAX, 1, false); - */ - // auto imask_tensor = engine_->GetITensor("imask_tensor"); - auto imask_tensor = engine_->GetITensor("fused_mha_mask"); + auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n, + weight, bias); + + auto mask_tensor = engine_->GetITensor("qkv_plugin_mask"); auto creator = GetPluginRegistry()->getPluginCreator( "CustomQKVToContextPluginDynamic", "1"); @@ -154,28 +131,24 @@ class MultiheadMatMulOpConverter : public OpConverter { {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1}, {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1}, }; - nvinfer1::PluginFieldCollection* pluginPtr = + nvinfer1::PluginFieldCollection* plugin_collection = static_cast( - malloc(sizeof(*pluginPtr) + + malloc(sizeof(*plugin_collection) + fields.size() * sizeof(nvinfer1::PluginField))); // remember to free - pluginPtr->nbFields = static_cast(fields.size()); - pluginPtr->fields = fields.data(); + plugin_collection->nbFields = static_cast(fields.size()); + plugin_collection->fields = fields.data(); + + auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic", + plugin_collection); + free(plugin_collection); - auto pluginObj = - creator->createPlugin("CustomQKVToContextPluginDynamic", pluginPtr); std::vector plugin_inputs; plugin_inputs.push_back(fc_layer->getOutput(0)); - // plugin_inputs.push_back(reduce_layer->getOutput(0)); - plugin_inputs.push_back(imask_tensor); + plugin_inputs.push_back(mask_tensor); auto plugin_layer = engine_->network()->addPluginV2( - plugin_inputs.data(), plugin_inputs.size(), *pluginObj); - assert(plugin_layer != nullptr); - auto trans_r_layer = - TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0)); - assert(trans_r_layer != nullptr); - trans_r_layer->setFirstTranspose(permutation); - layer = trans_r_layer; + plugin_inputs.data(), plugin_inputs.size(), *plugin); + layer = plugin_layer; #else // transpose weight_data from m * n to n * m auto* input_bias_qk = diff --git a/paddle/fluid/inference/tensorrt/convert/scale_op.cc b/paddle/fluid/inference/tensorrt/convert/scale_op.cc index 9c34027a9e6..f9a1fe41ddc 100644 --- a/paddle/fluid/inference/tensorrt/convert/scale_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/scale_op.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" -#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h" namespace paddle { namespace inference { @@ -27,7 +26,6 @@ class ScaleOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { VLOG(3) << "convert a fluid scale op to tensorrt mul layer without bias"; - std::cerr << "Scale converter" << std::endl; framework::OpDesc op_desc(op, nullptr); // Declare inputs @@ -66,12 +64,6 @@ class ScaleOpConverter : public OpConverter { platform::errors::Fatal( "Paddle-TRT scale mode only support dimension >= 3")); - plugin::ConvertMaskPluginDynamic* plugin = - new plugin::ConvertMaskPluginDynamic(); - auto convert_mask_layer = engine_->AddPluginV2(&input, 1, plugin); - convert_mask_layer->setName("convert_mask_layer"); - engine_->SetITensor("fused_mha_mask", convert_mask_layer->getOutput(0)); - nvinfer1::IShuffleLayer* expand_layer = nullptr; nvinfer1::IShuffleLayer* squeeze_layer = nullptr; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 3b7810d363f..e321c88af41 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -43,7 +43,7 @@ struct SimpleOpTypeSetTeller : public Teller { private: // use this set for no calib int8. - std::unordered_set int8_teller_set{"mul", + std::unordered_set int8_teller_set{"matmul", "conv2d", "pool2d", "relu", @@ -59,7 +59,7 @@ struct SimpleOpTypeSetTeller : public Teller { "elementwise_mul", "conv2d_transpose"}; std::unordered_set teller_set{ - "mul", + "matmul", "conv2d", "pool2d", "relu", diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index 6ef1be3f5ab..2ad81405e5e 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -2,7 +2,7 @@ nv_library(tensorrt_plugin SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu - cast_int_plugin.cu stack_op_plugin.cu convert_mask_plugin.cu + stack_op_plugin.cu convert_mask_plugin.cu instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) diff --git a/paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.cu deleted file mode 100644 index ccb2f7dfe20..00000000000 --- a/paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.cu +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include "paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h" -#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" - -namespace paddle { -namespace inference { -namespace tensorrt { -namespace plugin { - -// Dynamic Plugin below. -#if IS_TRT_VERSION_GE(6000) - -nvinfer1::DimsExprs CastIntPluginDynamic::getOutputDimensions( - int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, - nvinfer1::IExprBuilder& expr_builder) { - assert(output_index == 0); - return inputs[0]; -} - -bool CastIntPluginDynamic::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs, - int nb_outputs) { - const nvinfer1::PluginTensorDesc& in = in_out[pos]; - return (in.type == nvinfer1::DataType::kINT32); -} - -nvinfer1::DataType CastIntPluginDynamic::getOutputDataType( - int index, const nvinfer1::DataType* input_types, int nb_inputs) const { - PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument( - "The Cast Int only has one input, so the " - "index value should be 0, but get %d.", - index)); - return input_types[index]; -} - -__global__ void castIntKernel(const int64_t* input, int32_t* output, - size_t num_elements) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= num_elements) return; - output[idx] = input[idx] + 1; -} - -int CastIntPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc, - const nvinfer1::PluginTensorDesc* output_desc, - const void* const* inputs, - void* const* outputs, void* workspace, - cudaStream_t stream) { - auto input_dims = input_desc[0].dims; - auto output_dims = output_desc[0].dims; - size_t num_elements = ProductDim(input_dims); - size_t out_num_elements = ProductDim(output_dims); - - assert(input_type == - nvinfer1::DataType::kINT32); // although the input is int64_t - assert(num_elements == out_num_elements); - - const size_t num_threads = 256; - castIntKernel<<>>( - static_cast(inputs[0]), static_cast(outputs[0]), - num_elements); - - return cudaGetLastError() != cudaSuccess; -} -#endif - -} // namespace plugin -} // namespace tensorrt -} // namespace inference -} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h b/paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h deleted file mode 100644 index 039d1494e9a..00000000000 --- a/paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include -#include -#include -#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" - -namespace paddle { -namespace inference { -namespace tensorrt { -namespace plugin { - -#if IS_TRT_VERSION_GE(6000) -class CastIntPluginDynamic : public DynamicPluginTensorRT { - public: - CastIntPluginDynamic() {} - CastIntPluginDynamic(void const* serial_data, size_t serial_length) {} - - ~CastIntPluginDynamic() {} - nvinfer1::IPluginV2DynamicExt* clone() const override { - return new CastIntPluginDynamic(); - } - - const char* getPluginType() const override { return "cast_int_plugin"; } - int getNbOutputs() const override { return 1; } - int initialize() override { return 0; } - - size_t getSerializationSize() const override { return 0; } - void serialize(void* buffer) const override {} - - nvinfer1::DimsExprs getOutputDimensions( - int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, - nvinfer1::IExprBuilder& expr_builder) override; - - bool supportsFormatCombination(int pos, - const nvinfer1::PluginTensorDesc* in_out, - int nb_inputs, int nb_outputs) override; - - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, - int nb_inputs, - const nvinfer1::DynamicPluginTensorDesc* out, - int nb_outputs) override {} - - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, - int nb_inputs, - const nvinfer1::PluginTensorDesc* outputs, - int nb_outputs) const override { - return 0; - } - - int enqueue(const nvinfer1::PluginTensorDesc* input_desc, - const nvinfer1::PluginTensorDesc* output_desc, - const void* const* inputs, void* const* outputs, void* workspace, - cudaStream_t stream) override; - nvinfer1::DataType getOutputDataType(int index, - const nvinfer1::DataType* input_types, - int nb_inputs) const override; - - void destroy() override { delete this; } -}; - -class CastIntPluginV2Creator : public nvinfer1::IPluginCreator { - public: - CastIntPluginV2Creator() {} - const char* getPluginName() const override { return "cast_int_plugin"; } - - const char* getPluginVersion() const override { return "1"; } - - const nvinfer1::PluginFieldCollection* getFieldNames() override { - return &field_collection_; - } - - nvinfer1::IPluginV2* createPlugin( - const char* name, const nvinfer1::PluginFieldCollection* fc) override { - return nullptr; - } - - nvinfer1::IPluginV2* deserializePlugin(const char* name, - const void* serial_data, - size_t serial_length) override { - auto plugin = new CastIntPluginDynamic(serial_data, serial_length); - return plugin; - } - - void setPluginNamespace(const char* lib_namespace) override { - plugin_namespace_ = lib_namespace; - } - - const char* getPluginNamespace() const override { - return plugin_namespace_.c_str(); - } - - private: - std::string plugin_namespace_; - std::string plugin_name_; - nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; - std::vector plugin_attributes_; -}; - -REGISTER_TRT_PLUGIN_V2(CastIntPluginV2Creator); -#endif - -} // namespace plugin -} // namespace tensorrt -} // namespace inference -} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu index b2c1348aa48..3a02343033c 100644 --- a/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu @@ -17,6 +17,7 @@ #include #include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h" +#include "paddle/fluid/operators/math/math_cuda_utils.h" namespace paddle { namespace inference { @@ -38,15 +39,23 @@ constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128; nvinfer1::DimsExprs ConvertMaskPluginDynamic::getOutputDimensions( int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, nvinfer1::IExprBuilder& expr_builder) { - auto cms128 = expr_builder.constant(packedMaskSize128); - auto fp16maskSize = expr_builder.operation( - nvinfer1::DimensionOperation::kPROD, *cms128, *expr_builder.constant(2)); - + assert(output_index == 0); + if (type_ == nvinfer1::DataType::kHALF) { + auto cms128 = expr_builder.constant(packedMaskSize128); + auto fp16maskSize = + expr_builder.operation(nvinfer1::DimensionOperation::kPROD, *cms128, + *expr_builder.constant(2)); + + nvinfer1::DimsExprs ret; + ret.nbDims = 2; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = fp16maskSize; + + return ret; + } nvinfer1::DimsExprs ret; - ret.nbDims = 2; + ret.nbDims = 1; ret.d[0] = inputs[0].d[0]; - ret.d[1] = fp16maskSize; - return ret; } @@ -54,22 +63,21 @@ bool ConvertMaskPluginDynamic::supportsFormatCombination( int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs, int nb_outputs) { const nvinfer1::PluginTensorDesc& desc = in_out[pos]; - /* input: [B, S, S] */ + /* input: [B, S, 1] */ /* output: [B, 2*maskSize] */ assert(nb_inputs == 1); assert(nb_outputs == 1); if (pos == 0) { - std::cerr << "desc.type: " << static_cast(desc.type) << " " - << desc.dims.nbDims << std::endl; return ((desc.type == nvinfer1::DataType::kFLOAT || desc.type == nvinfer1::DataType::kHALF) && desc.dims.nbDims == 3); } - std::cerr << "output.type: " << static_cast(desc.type) << " " - << desc.dims.nbDims << std::endl; - // return desc.type == nvinfer1::DataType::kHALF; - return true; + // return true; + /* fp16 -> fp16, fp32 -> int32 */ + if (type_ == nvinfer1::DataType::kHALF) + return desc.type == nvinfer1::DataType::kHALF; + return desc.type == nvinfer1::DataType::kINT32; } nvinfer1::DataType ConvertMaskPluginDynamic::getOutputDataType( @@ -79,16 +87,36 @@ nvinfer1::DataType ConvertMaskPluginDynamic::getOutputDataType( "The convert mask plugin only has one input, so the " "index value should be 0, but get %d.", index)); - return nvinfer1::DataType::kHALF; + if (type_ == nvinfer1::DataType::kHALF) { + return nvinfer1::DataType::kHALF; + } + return nvinfer1::DataType::kINT32; } +/* half [B, S, 1] -> int [S, B, 1] */ template -__global__ void CastToIntAndReduce(const T* input, int* output, int seq_len, +__global__ void FullMaskPreprocess(const T* input, int* output, int seq_len, int batch) { int bid = blockIdx.x; int sid = threadIdx.x; - output[sid * batch + bid] = - static_cast(input[bid * seq_len * seq_len + sid]); + output[sid * batch + bid] = static_cast(input[bid * seq_len + sid]); +} + +/* float [B, S, 1] -> int [B] */ +/* [[1. 1. 1. 0. 0.], -> [3, 4] + [1. 1. 1. 1. 0.]] */ +__global__ void IMaskPreprocess(const float* input, int* output, int seq_len, + int batch) { + float sum = 0.f; + int bid = blockIdx.x; + int sid = threadIdx.x; + float thread_data = input[bid * seq_len + sid]; + + sum = paddle::operators::math::blockReduceSum(thread_data, 0xffffffff); + + if (sid == 0) { + output[bid] = static_cast(sum); + } } __global__ void fillSBSMaskKernel(const uint32_t warps_m, @@ -159,33 +187,33 @@ int ConvertMaskPluginDynamic::enqueue( int batch = input_dims.d[0]; int seq_len = input_dims.d[1]; - assert(num_elements == out_num_elements * seq_len); - assert(seq_len <= 1024); - assert(output_desc.type == nvinfer1::DataType::kHALF); - - // temp use, should remove - int* inputMaskSB; - cudaMalloc(&inputMaskSB, batch * seq_len * sizeof(int)); + assert(seq_len == 128); - if (input_desc[0].type == nvinfer1::DataType::kFLOAT) { - CastToIntAndReduce<<>>( - static_cast(inputs[0]), inputMaskSB, seq_len, batch); + if (type_ == nvinfer1::DataType::kFLOAT) { + IMaskPreprocess<<>>( + static_cast(inputs[0]), static_cast(outputs[0]), + seq_len, batch); } else { - CastToIntAndReduce<<>>( - static_cast(inputs[0]), inputMaskSB, seq_len, batch); + int* inputMaskSB; + cudaMalloc(&inputMaskSB, batch * seq_len * sizeof(int)); + if (input_desc[0].type == nvinfer1::DataType::kFLOAT) { + FullMaskPreprocess<<>>( + static_cast(inputs[0]), inputMaskSB, seq_len, batch); + } else { + FullMaskPreprocess<<>>( + static_cast(inputs[0]), inputMaskSB, seq_len, batch); + } + size_t warps_m = 0, warps_n = 0, warps_k = 1; + if (seq_len == 128) { + warps_m = 2; + warps_n = 2; + } + + convertMask(seq_len, batch, warps_m, warps_n, warps_k, inputMaskSB, + static_cast(outputs[0]), stream); + cudaFree(inputMaskSB); } - assert(seq_len == 128); - size_t warps_m = 0, warps_n = 0, warps_k = 1; - if (seq_len == 128) { - warps_m = 2; - warps_n = 2; - } - - convertMask(seq_len, batch, warps_m, warps_n, warps_k, inputMaskSB, - static_cast(outputs[0]), stream); - - cudaFree(inputMaskSB); return cudaGetLastError() != cudaSuccess; } #endif diff --git a/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h b/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h index b8a7490117b..06c36a5b0e2 100644 --- a/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h @@ -27,20 +27,32 @@ namespace plugin { #if IS_TRT_VERSION_GE(6000) class ConvertMaskPluginDynamic : public DynamicPluginTensorRT { public: - ConvertMaskPluginDynamic() {} - ConvertMaskPluginDynamic(void const* serial_data, size_t serial_length) {} + explicit ConvertMaskPluginDynamic(nvinfer1::DataType type) : type_(type) { + assert(type == nvinfer1::DataType::kHALF || + type == nvinfer1::DataType::kFLOAT); + } + ConvertMaskPluginDynamic(void const* serial_data, size_t serial_length) { + DeserializeValue(&serial_data, &serial_length, &type_); + } ~ConvertMaskPluginDynamic() {} nvinfer1::IPluginV2DynamicExt* clone() const override { - return new ConvertMaskPluginDynamic(); + return new ConvertMaskPluginDynamic(type_); } const char* getPluginType() const override { return "convert_mask_plugin"; } int getNbOutputs() const override { return 1; } int initialize() override { return 0; } - size_t getSerializationSize() const override { return 0; } - void serialize(void* buffer) const override {} + size_t getSerializationSize() const override { + size_t serialize_size = 0; + serialize_size += SerializedSize(type_); + return serialize_size; + } + + void serialize(void* buffer) const override { + SerializeValue(&buffer, type_); + } nvinfer1::DimsExprs getOutputDimensions( int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, @@ -71,6 +83,9 @@ class ConvertMaskPluginDynamic : public DynamicPluginTensorRT { int nb_inputs) const override; void destroy() override { delete this; } + + private: + nvinfer1::DataType type_; }; class ConvertMaskPluginV2Creator : public nvinfer1::IPluginCreator { -- GitLab