diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
index e6903d05d384db97c3969e7ac5dcc8795b58ddc0..a7cd569f9b469a02be0d84c6f93f91b3ef4c0a98 100644
--- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
@@ -138,7 +138,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
             *reshape_layer->getOutput(0), nvinfer1::ReduceOperation::kMAX, 1,
             false); */
-        auto imask_tensor = engine_->GetITensor("imask_tensor");
+        // auto imask_tensor = engine_->GetITensor("imask_tensor");
+        auto imask_tensor = engine_->GetITensor("fused_mha_mask");
 
         auto creator = GetPluginRegistry()->getPluginCreator(
             "CustomQKVToContextPluginDynamic", "1");
diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h
index 6359f36c998891b5e2393d584adb1e0fdb7c6baa..f4b0f5f23d8fda064c29534b56868beae79f65c0 100644
--- a/paddle/fluid/inference/tensorrt/convert/op_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -173,8 +173,6 @@ class OpConverter {
                   "optim_input_shape should be same."));
         }
       }
-      std::cerr << "Declare input: " << input << std::endl;
-      if (input.find("stack_0.tmp_0") != std::string::npos) continue;
       engine->DeclareInput(
           input, FluidDataType2TRT(
                      var->Proto()->type().lod_tensor().tensor().data_type()),
diff --git a/paddle/fluid/inference/tensorrt/convert/scale_op.cc b/paddle/fluid/inference/tensorrt/convert/scale_op.cc
index f9a1fe41ddc046aad8cc3a5397453b0f68c1a112..9c34027a9e62f0896dc75f1e7a8e98aca1fed2ed 100644
--- a/paddle/fluid/inference/tensorrt/convert/scale_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/scale_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h"
 
 namespace paddle {
 namespace inference {
@@ -26,6 +27,7 @@ class ScaleOpConverter : public OpConverter {
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope, bool test_mode) override {
     VLOG(3) << "convert a fluid scale op to tensorrt mul layer without bias";
+    std::cerr << "Scale converter" << std::endl;
 
     framework::OpDesc op_desc(op, nullptr);
     // Declare inputs
@@ -64,6 +66,12 @@ class ScaleOpConverter : public OpConverter {
                     platform::errors::Fatal(
                         "Paddle-TRT scale mode only support dimension >= 3"));
 
+    plugin::ConvertMaskPluginDynamic* plugin =
+        new plugin::ConvertMaskPluginDynamic();
+    auto convert_mask_layer = engine_->AddPluginV2(&input, 1, plugin);
+    convert_mask_layer->setName("convert_mask_layer");
+    engine_->SetITensor("fused_mha_mask", convert_mask_layer->getOutput(0));
+
     nvinfer1::IShuffleLayer* expand_layer = nullptr;
     nvinfer1::IShuffleLayer* squeeze_layer = nullptr;
diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h
index ed347be1cac198acce943dc4b48dd5fec41ccab0..3c48c8192f6b06e5a0ba005738383b46bc550ecb 100644
--- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h
+++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h
@@ -183,8 +183,6 @@ class TRTConvertValidation {
     std::vector<void*> buffers(num_bindings);
 
     for (const std::string& name : input_output_names) {
-      // std::cerr << "Binding name: " << name << std::endl;
-      if (name.find("stack_0.tmp_0") != std::string::npos) continue;
       auto* var = scope_.FindVar(name);
       auto* tensor = var->GetMutable<framework::LoDTensor>();
       const int bind_index = engine_->engine()->getBindingIndex(name.c_str());
diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
index ff95f8fcc6f0866de320607420723e00e61f1d7e..6ef1be3f5abe9bee233e27ec8617d75fe11c4154 100644
--- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -2,7 +2,7 @@ nv_library(tensorrt_plugin
   SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
   prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
   pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
-  cast_int_plugin.cu stack_op_plugin.cu
+  cast_int_plugin.cu stack_op_plugin.cu convert_mask_plugin.cu
  instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu qkv_to_context_plugin.cu
  skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
  DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
diff --git a/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b2c1348aa4823f64a0be8f02581aac2bc868b200
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu
@@ -0,0 +1,196 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cuda_fp16.h>
+#include <cassert>
+#include <iostream>
+#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h"
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+// Dynamic Plugin below.
+#if IS_TRT_VERSION_GE(6000)
+
+/* This plugin currently converts the matmul output [B, S, S]
+to the mask with the bertQKV fused_multihead_attention format */
+
+constexpr size_t threadsPerCta128 = 2 * 2 * 32;
+
+constexpr size_t xmmasM128 = 4;
+
+constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128;
+
+nvinfer1::DimsExprs ConvertMaskPluginDynamic::getOutputDimensions(
+    int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
+    nvinfer1::IExprBuilder& expr_builder) {
+  auto cms128 = expr_builder.constant(packedMaskSize128);
+  auto fp16maskSize = expr_builder.operation(
+      nvinfer1::DimensionOperation::kPROD, *cms128, *expr_builder.constant(2));
+
+  nvinfer1::DimsExprs ret;
+  ret.nbDims = 2;
+  ret.d[0] = inputs[0].d[0];
+  ret.d[1] = fp16maskSize;
+
+  return ret;
+}
+
+bool ConvertMaskPluginDynamic::supportsFormatCombination(
+    int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
+    int nb_outputs) {
+  const nvinfer1::PluginTensorDesc& desc = in_out[pos];
+  /* input: [B, S, S] */
+  /* output: [B, 2*maskSize] */
+  assert(nb_inputs == 1);
+  assert(nb_outputs == 1);
+
+  if (pos == 0) {
+    std::cerr << "desc.type: " << static_cast<int>(desc.type) << " "
+              << desc.dims.nbDims << std::endl;
+    return ((desc.type == nvinfer1::DataType::kFLOAT ||
+             desc.type == nvinfer1::DataType::kHALF) &&
+            desc.dims.nbDims == 3);
+  }
+  std::cerr << "output.type: " << static_cast<int>(desc.type) << " "
+            << desc.dims.nbDims << std::endl;
+  // return desc.type == nvinfer1::DataType::kHALF;
+  return true;
+}
+
+nvinfer1::DataType ConvertMaskPluginDynamic::getOutputDataType(
+    int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
+  PADDLE_ENFORCE_EQ(index, 0,
+                    platform::errors::InvalidArgument(
+                        "The convert mask plugin only has one output, so the "
+                        "index value should be 0, but got %d.",
+                        index));
+  return nvinfer1::DataType::kHALF;
+}
+
+template <typename T>
+__global__ void CastToIntAndReduce(const T* input, int* output, int seq_len,
+                                   int batch) {
+  int bid = blockIdx.x;
+  int sid = threadIdx.x;
+  output[sid * batch + bid] =
+      static_cast<int>(input[bid * seq_len * seq_len + sid]);
+}
+
+__global__ void fillSBSMaskKernel(const uint32_t warps_m,
+                                  const uint32_t warps_n, const uint32_t S,
+                                  const int* inputMaskSB,
+                                  uint32_t* inputMaskX) {
+  extern __shared__ int shm_mask[];  // S mask elements of this batch
+
+  const size_t xmmas_n = (S + 16 * warps_n - 1) / (16 * warps_n);
+  const uint32_t threads_per_cta = blockDim.x;
+  const uint32_t xmmas_m = gridDim.x;
+  const uint32_t B = gridDim.y;
+
+  const uint32_t mi = blockIdx.x;
+  const uint32_t bi = blockIdx.y;
+  const uint32_t tidx = threadIdx.x;
+
+  const size_t warp = tidx / 32;
+  const size_t warp_m = warp % warps_m;
+  const size_t warp_n = warp / warps_m;
+  const size_t lane = tidx % 32;
+  const size_t col = warp_n * 16 + lane % 4 * 2;
+
+  // load the mask corresponding to one batch
+  for (uint32_t si = tidx; si < S; si += threads_per_cta) {
+    // not coalesced to conform to current input format: SxB
+    shm_mask[si] = inputMaskSB[si * B + bi];
+  }
+  __syncthreads();
+
+  uint32_t mask = 0u;
+
+  for (size_t ni = 0; ni < xmmas_n; ++ni) {
+    const int offset = ni * 16 * warps_n + col;
+    mask |= (shm_mask[offset + 0] == 1.f ? 1u : 0u) << (8 * ni + 0);
+    mask |= (shm_mask[offset + 1] == 1.f ? 1u : 0u) << (8 * ni + 1);
+    mask |= (shm_mask[offset + 0] == 1.f ? 1u : 0u) << (8 * ni + 2);
+    mask |= (shm_mask[offset + 1] == 1.f ? 1u : 0u) << (8 * ni + 3);
+    mask |= (shm_mask[offset + 8] == 1.f ? 1u : 0u) << (8 * ni + 4);
+    mask |= (shm_mask[offset + 9] == 1.f ? 1u : 0u) << (8 * ni + 5);
+    mask |= (shm_mask[offset + 8] == 1.f ? 1u : 0u) << (8 * ni + 6);
+    mask |= (shm_mask[offset + 9] == 1.f ? 1u : 0u) << (8 * ni + 7);
+  }
+
+  inputMaskX[(bi * xmmas_m + mi) * threads_per_cta + tidx] = mask;
+}
+
+void convertMask(const uint32_t S, const uint32_t B, const uint32_t warps_m,
+                 const uint32_t warps_n, const uint32_t warps_k,
+                 const int* inputMaskSB, uint32_t* inputMaskX,
+                 cudaStream_t stream) {
+  const size_t xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m);
+
+  const size_t threads_per_cta = warps_m * warps_n * warps_k * 32;
+  dim3 grid(xmmas_m, B);
+  fillSBSMaskKernel<<<grid, threads_per_cta, S * sizeof(int), stream>>>(
+      warps_m, warps_n, S, inputMaskSB, inputMaskX);
+}
+
+int ConvertMaskPluginDynamic::enqueue(
+    const nvinfer1::PluginTensorDesc* input_desc,
+    const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
+    void* const* outputs, void* workspace, cudaStream_t stream) {
+  auto input_dims = input_desc[0].dims;
+  auto output_dims = output_desc[0].dims;
+  size_t num_elements = ProductDim(input_dims);
+  size_t out_num_elements = ProductDim(output_dims);
+  int batch = input_dims.d[0];
+  int seq_len = input_dims.d[1];
+
+  assert(num_elements == out_num_elements * seq_len);
+  assert(seq_len <= 1024);
+  assert(output_desc[0].type == nvinfer1::DataType::kHALF);
+
+  // temporary buffer, should be removed (allocate via the TRT workspace)
+  int* inputMaskSB;
+  cudaMalloc(&inputMaskSB, batch * seq_len * sizeof(int));
+
+  if (input_desc[0].type == nvinfer1::DataType::kFLOAT) {
+    CastToIntAndReduce<float><<<batch, seq_len, 0, stream>>>(
+        static_cast<const float*>(inputs[0]), inputMaskSB, seq_len, batch);
+  } else {
+    CastToIntAndReduce<half><<<batch, seq_len, 0, stream>>>(
+        static_cast<const half*>(inputs[0]), inputMaskSB, seq_len, batch);
+  }
+
+  assert(seq_len == 128);
+  size_t warps_m = 0, warps_n = 0, warps_k = 1;
+  if (seq_len == 128) {
+    warps_m = 2;
+    warps_n = 2;
+  }
+
+  convertMask(seq_len, batch, warps_m, warps_n, warps_k, inputMaskSB,
+              static_cast<uint32_t*>(outputs[0]), stream);
+
+  cudaFree(inputMaskSB);
+  return cudaGetLastError() != cudaSuccess;
+}
+#endif
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h b/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h
new file mode 100644
index 0000000000000000000000000000000000000000..b8a7490117b4e0a29ba1c0be254ebd98f4124efe
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h
@@ -0,0 +1,120 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <vector>
+#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+#if IS_TRT_VERSION_GE(6000)
+class ConvertMaskPluginDynamic : public DynamicPluginTensorRT {
+ public:
+  ConvertMaskPluginDynamic() {}
+  ConvertMaskPluginDynamic(void const* serial_data, size_t serial_length) {}
+
+  ~ConvertMaskPluginDynamic() {}
+  nvinfer1::IPluginV2DynamicExt* clone() const override {
+    return new ConvertMaskPluginDynamic();
+  }
+
+  const char* getPluginType() const override { return "convert_mask_plugin"; }
+  int getNbOutputs() const override { return 1; }
+  int initialize() override { return 0; }
+
+  size_t getSerializationSize() const override { return 0; }
+  void serialize(void* buffer) const override {}
+
+  nvinfer1::DimsExprs getOutputDimensions(
+      int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
+      nvinfer1::IExprBuilder& expr_builder) override;
+
+  bool supportsFormatCombination(int pos,
+                                 const nvinfer1::PluginTensorDesc* in_out,
+                                 int nb_inputs, int nb_outputs) override;
+
+  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
+                       int nb_inputs,
+                       const nvinfer1::DynamicPluginTensorDesc* out,
+                       int nb_outputs) override {}
+
+  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
+                          int nb_inputs,
+                          const nvinfer1::PluginTensorDesc* outputs,
+                          int nb_outputs) const override {
+    return 0;
+  }
+
+  int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
+              const nvinfer1::PluginTensorDesc* output_desc,
+              const void* const* inputs, void* const* outputs, void* workspace,
+              cudaStream_t stream) override;
+  nvinfer1::DataType getOutputDataType(int index,
+                                       const nvinfer1::DataType* input_types,
+                                       int nb_inputs) const override;
+
+  void destroy() override { delete this; }
+};
+
+class ConvertMaskPluginV2Creator : public nvinfer1::IPluginCreator {
+ public:
+  ConvertMaskPluginV2Creator() {}
+  const char* getPluginName() const override { return "convert_mask_plugin"; }
+
+  const char* getPluginVersion() const override { return "1"; }
+
+  const nvinfer1::PluginFieldCollection* getFieldNames() override {
+    return &field_collection_;
+  }
+
+  nvinfer1::IPluginV2* createPlugin(
+      const char* name, const nvinfer1::PluginFieldCollection* fc) override {
+    return nullptr;
+  }
+
+  nvinfer1::IPluginV2* deserializePlugin(const char* name,
+                                         const void* serial_data,
+                                         size_t serial_length) override {
+    auto plugin = new ConvertMaskPluginDynamic(serial_data, serial_length);
+    return plugin;
+  }
+
+  void setPluginNamespace(const char* lib_namespace) override {
+    plugin_namespace_ = lib_namespace;
+  }
+
+  const char* getPluginNamespace() const override {
+    return plugin_namespace_.c_str();
+  }
+
+ private:
+  std::string plugin_namespace_;
+  std::string plugin_name_;
+  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
+  std::vector<nvinfer1::PluginField> plugin_attributes_;
+};
+
+REGISTER_TRT_PLUGIN_V2(ConvertMaskPluginV2Creator);
+#endif
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
index 3a008f55c7987d273114f4e5bd0dac196f57cfa9..11b48e0ea91a4c1c5799afaba7367934e331d45c 100644
--- a/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
+++ b/paddle/fluid/operators/tensorrt/tensorrt_engine_op.h
@@ -227,8 +227,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
     // Bind input tensor to TRT.
     for (const auto &x : Inputs("Xs")) {
       if (param_names_.count(x)) continue;
-      // std::cerr << "runTRT name: " << x << std::endl;
-      if (x.find("stack_0.tmp_0") != std::string::npos) continue;
       // convert input and copy to TRT engine's buffer
       auto &t =
           inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
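
Reviewer note: the packed-mask layout produced by fillSBSMaskKernel is easiest to check on the host. Below is a minimal CPU sketch of the same bit-packing for the only case the plugin currently supports (S = 128, warps_m = warps_n = 2, warps_k = 1), restricted to a single batch. PackMaskReference and the all-ones smoke test are hypothetical scaffolding written for this note, not part of the patch; they assume the batch's 0/1 mask row has already been extracted the way CastToIntAndReduce does.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical host-side mirror of fillSBSMaskKernel for one batch (bi = 0).
// Returns xmmas_m * threads_per_cta = 4 * 128 = 512 words, i.e.
// packedMaskSize128: half the size of the [B, 2 * packedMaskSize128] kHALF
// output tensor (512 uint32_t words viewed as 1024 half values).
std::vector<uint32_t> PackMaskReference(const std::vector<int>& mask) {
  const uint32_t S = 128, warps_m = 2, warps_n = 2, warps_k = 1;
  const uint32_t threads_per_cta = warps_m * warps_n * warps_k * 32;  // 128
  const uint32_t xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m);   // 4
  const uint32_t xmmas_n = (S + 16 * warps_n - 1) / (16 * warps_n);   // 4
  std::vector<uint32_t> packed(xmmas_m * threads_per_cta, 0u);        // 512
  for (uint32_t mi = 0; mi < xmmas_m; ++mi) {
    for (uint32_t tidx = 0; tidx < threads_per_cta; ++tidx) {
      // Same thread -> column mapping as the kernel (warp_m does not affect
      // the packed bits, only which copy of them a thread owns).
      const uint32_t warp_n = (tidx / 32) / warps_m;
      const uint32_t lane = tidx % 32;
      const uint32_t col = warp_n * 16 + lane % 4 * 2;
      uint32_t m = 0u;
      for (uint32_t ni = 0; ni < xmmas_n; ++ni) {
        const uint32_t offset = ni * 16 * warps_n + col;
        // Each (col, col+1, col+8, col+9) bit is stored twice, matching the
        // two rows of an MMA fragment; same eight |= lines as the kernel.
        m |= (mask[offset + 0] == 1 ? 1u : 0u) << (8 * ni + 0);
        m |= (mask[offset + 1] == 1 ? 1u : 0u) << (8 * ni + 1);
        m |= (mask[offset + 0] == 1 ? 1u : 0u) << (8 * ni + 2);
        m |= (mask[offset + 1] == 1 ? 1u : 0u) << (8 * ni + 3);
        m |= (mask[offset + 8] == 1 ? 1u : 0u) << (8 * ni + 4);
        m |= (mask[offset + 9] == 1 ? 1u : 0u) << (8 * ni + 5);
        m |= (mask[offset + 8] == 1 ? 1u : 0u) << (8 * ni + 6);
        m |= (mask[offset + 9] == 1 ? 1u : 0u) << (8 * ni + 7);
      }
      // Kernel stores to (bi * xmmas_m + mi) * threads_per_cta + tidx.
      packed[mi * threads_per_cta + tidx] = m;
    }
  }
  return packed;
}

int main() {
  // All-ones mask: each of the four ni groups sets eight bits, so every
  // packed word must come back as 0xffffffff.
  std::vector<int> mask(128, 1);
  for (uint32_t w : PackMaskReference(mask)) assert(w == 0xffffffffu);
  return 0;
}

Comparing this reference against the plugin output copied back after enqueue (seq_len 128) is a quick way to validate the device kernel end to end.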