未验证 提交 e5cf75d8 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle Inference] General optimization for no_varlen multihead (#48469)

* general optimization for no_varlen multihead
上级 aa892113
...@@ -439,9 +439,6 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const { ...@@ -439,9 +439,6 @@ void RemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph* graph) const {
"remove_padding pass."; "remove_padding pass.";
return; return;
} }
fc_op->Op()->RemoveAttr("in_num_col_dims");
fc_op->Op()->SetAttr("in_num_col_dims", 1);
insert_remove_padding_op(fc_input, fc_op); insert_remove_padding_op(fc_input, fc_op);
insert_recover_padding_op(fc_op, fc_op->outputs[0]); insert_recover_padding_op(fc_op, fc_op->outputs[0]);
found_subgraph_count++; found_subgraph_count++;
......
...@@ -441,14 +441,14 @@ void TrtEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const { ...@@ -441,14 +441,14 @@ void TrtEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
std::string mask_id = Get<std::string>("tensorrt_transformer_maskid"); std::string mask_id = Get<std::string>("tensorrt_transformer_maskid");
if ((use_varseqlen && pos_id != "" && mask_id != "") || if ((use_varseqlen && pos_id != "" && mask_id != "") ||
(!use_varseqlen && pos_id == "" && mask_id == "")) { (!use_varseqlen && pos_id == "")) {
VLOG(3) << "start trt_embedding_eltwise_layernorm_fuse_pass"; VLOG(3) << "start trt_embedding_eltwise_layernorm_fuse_pass";
} else { } else {
PADDLE_THROW( PADDLE_THROW(
platform::errors::Fatal("Use transformer'varseqlen need config: " platform::errors::Fatal("Use transformer'varseqlen need config: "
"use_varseqlen, set pos_id, set " "use_varseqlen, set pos_id, set "
"mask_id. Or not use varseqlen, do not set " "mask_id. Or not use varseqlen, do not set "
"pos_id, set mask_id. Please " "pos_id. Please "
"reconfig")); "reconfig"));
} }
graph->Set(kEmbEltwiseLayernormPass, new bool(true)); graph->Set(kEmbEltwiseLayernormPass, new bool(true));
......
...@@ -1637,14 +1637,14 @@ void TrtMultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const { ...@@ -1637,14 +1637,14 @@ void TrtMultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
"preln_embedding_eltwise_layernorm_fuse_" "preln_embedding_eltwise_layernorm_fuse_"
"pass. please use no_varseqlen")); "pass. please use no_varseqlen"));
} }
} else if (!use_varseqlen && pos_id == "" && mask_id == "") { } else if (!use_varseqlen && pos_id == "") {
VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass"; VLOG(3) << "start no_varseqlen_trt_multihead_matmul_fuse_pass";
} else { } else {
PADDLE_THROW( PADDLE_THROW(
platform::errors::Fatal("Use transformer'varseqlen need config: " platform::errors::Fatal("Use transformer'varseqlen need config: "
"use_varseqlen, set pos_id, set " "use_varseqlen, set pos_id, set "
"mask_id. Or not use varseqlen, do not set " "mask_id. Or not use varseqlen, do not set "
"pos_id, set mask_id. Please " "pos_id. Please "
"reconfig")); "reconfig"));
} }
graph->Set(kMultiheadMatmulPass, new bool(true)); graph->Set(kMultiheadMatmulPass, new bool(true));
......
...@@ -207,14 +207,14 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { ...@@ -207,14 +207,14 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
"trt_embedding_eltwise_layernorm_fuse_pass, " "trt_embedding_eltwise_layernorm_fuse_pass, "
"trt_multihead_matmul_fuse_pass. please use no_varseqlen")); "trt_multihead_matmul_fuse_pass. please use no_varseqlen"));
} }
} else if (!use_varseqlen && pos_id == "" && mask_id == "") { } else if (!use_varseqlen && pos_id == "") {
VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass"; VLOG(3) << "start no_varseqlen trt_skip_layernorm_fuse_pass";
} else { } else {
PADDLE_THROW( PADDLE_THROW(
platform::errors::Fatal("Use transformer'varseqlen need config: " platform::errors::Fatal("Use transformer'varseqlen need config: "
"use_varseqlen, set pos_id, set " "use_varseqlen, set pos_id, set "
"mask_id. Or not use varseqlen, do not set " "mask_id. Or not use varseqlen, do not set "
"pos_id, set mask_id. Please " "pos_id. Please "
"reconfig")); "reconfig"));
} }
} }
......
...@@ -332,7 +332,7 @@ class FcOpConverter : public OpConverter { ...@@ -332,7 +332,7 @@ class FcOpConverter : public OpConverter {
} }
// If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can // If use tensorrt'oss, the x_dim and x_num_col_dims need change, and can
// not add Shuffle layer in ernie's multihead. // not add Shuffle layer in ernie's multihead.
if (x_dim.nbDims == 4 && x_num_col_dims == 1) { if (x_dim.nbDims == 4 && x_dim.d[2] == 1 && x_dim.d[3] == 1) {
if (enable_int8 || support_int8) { if (enable_int8 || support_int8) {
// add conv1x1 layer // add conv1x1 layer
nvinfer1::DimsHW nv_ksize(1, 1); nvinfer1::DimsHW nv_ksize(1, 1);
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -87,7 +88,7 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -87,7 +88,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
engine_->tensorrt_transformer_posid() != "" && engine_->tensorrt_transformer_posid() != "" &&
engine_->tensorrt_transformer_maskid() != ""; engine_->tensorrt_transformer_maskid() != "";
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
if (flag_varseqlen) { if (engine_->tensorrt_transformer_maskid() != "") {
if (engine_->precision() == AnalysisConfig::Precision::kFloat32) { if (engine_->precision() == AnalysisConfig::Precision::kFloat32) {
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"use use_varseqlen must be int8 or half, not float32.")); "use use_varseqlen must be int8 or half, not float32."));
...@@ -98,8 +99,100 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -98,8 +99,100 @@ class MultiheadMatMulOpConverter : public OpConverter {
nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_data), static_cast<void*>(bias_data),
static_cast<int32_t>(bias_t->numel())}; static_cast<int32_t>(bias_t->numel())};
auto max_seqlen_tensor = engine_->GetITensor("max_seqlen_tensor");
auto pos_id_tensor = engine_->GetITensor("pos_id"); nvinfer1::ITensor* mask_tensor;
nvinfer1::ITensor* pos_id_tensor;
nvinfer1::ITensor* max_seqlen_tensor;
auto* new_input = input;
if (flag_varseqlen) {
mask_tensor = engine_->GetITensor("qkv_plugin_mask");
pos_id_tensor = engine_->GetITensor("pos_id");
max_seqlen_tensor = engine_->GetITensor("max_seqlen_tensor");
} else {
auto* bias_qk_tensor =
engine_->GetITensor(op_desc.Input("BiasQK").front());
auto bias_qk_dims = bias_qk_tensor->getDimensions();
PADDLE_ENFORCE_EQ(bias_qk_dims.nbDims,
4,
platform::errors::InvalidArgument(
"The rank of Multihead Matmul'BiasQK must be "
"4, but got rank is %d.",
bias_qk_dims.nbDims));
nvinfer1::Dims start_dims = bias_qk_dims;
start_dims.d[0] = 0;
start_dims.d[1] = 0;
start_dims.d[2] = 0;
start_dims.d[3] = 0;
nvinfer1::Dims size_dims = bias_qk_dims;
nvinfer1::Dims step_dims = bias_qk_dims;
step_dims.d[0] = 1;
step_dims.d[1] = 1;
step_dims.d[2] = 1;
step_dims.d[3] = 1;
auto* shape_tensor = Shape(bias_qk_tensor);
// (b,n,m,m) -> (b,1,m,1)
std::vector<nvinfer1::ITensor*> size_vec_tensor;
size_vec_tensor.push_back(GetEleTensorOfShape(shape_tensor, 0));
size_vec_tensor.push_back(Add1DConstantLayer(1));
size_vec_tensor.push_back(GetEleTensorOfShape(shape_tensor, 2));
size_vec_tensor.push_back(Add1DConstantLayer(1));
auto* size_tensor = Concat(size_vec_tensor);
auto* slice_layer = TRT_ENGINE_ADD_LAYER(engine_,
Slice,
*bias_qk_tensor,
start_dims,
size_dims,
step_dims);
slice_layer->setInput(2, *size_tensor);
// half -> bool
auto* cast_layer_0 = TRT_ENGINE_ADD_LAYER(
engine_, Identity, *slice_layer->getOutput(0));
cast_layer_0->setOutputType(0, nvinfer1::DataType::kBOOL);
// bool kNOT
auto* not_layer =
TRT_ENGINE_ADD_LAYER(engine_,
Unary,
*cast_layer_0->getOutput(0),
nvinfer1::UnaryOperation::kNOT);
// bool -> int32
auto* cast_layer_1 =
TRT_ENGINE_ADD_LAYER(engine_, Identity, *not_layer->getOutput(0));
cast_layer_1->setOutputType(0, nvinfer1::DataType::kINT32);
// Calculate the number of 1 : (b,1,m,1) -> (b)
uint32_t reduce_dim_0 = 0;
reduce_dim_0 |= 1 << 1; // 00000000000000000000000000000010
reduce_dim_0 |= 1 << 2; // 00000000000000000000000000000110
reduce_dim_0 |= 1 << 3; // 00000000000000000000000000001110
bool keep_dim = false;
nvinfer1::ReduceOperation reduce_type =
nvinfer1::ReduceOperation::kSUM;
auto* reduce_sum_layer =
TRT_ENGINE_ADD_LAYER(engine_,
Reduce,
*cast_layer_1->getOutput(0),
reduce_type,
reduce_dim_0,
keep_dim);
std::vector<nvinfer1::ITensor*> inputs_transformer;
inputs_transformer.emplace_back(input);
inputs_transformer.emplace_back(
reduce_sum_layer->getOutput(0)); // (b,m)
plugin::TransformerInputConvertPlugin* plugin =
new plugin::TransformerInputConvertPlugin();
nvinfer1::ILayer* transformer_input_layer = engine_->AddDynamicPlugin(
inputs_transformer.data(), inputs_transformer.size(), plugin);
new_input = transformer_input_layer->getOutput(0);
mask_tensor = transformer_input_layer->getOutput(1);
pos_id_tensor = transformer_input_layer->getOutput(2);
max_seqlen_tensor = transformer_input_layer->getOutput(3);
}
if (engine_->with_interleaved()) { if (engine_->with_interleaved()) {
VLOG(4) << "fused multihead_matmul op: use_varseqlen and " VLOG(4) << "fused multihead_matmul op: use_varseqlen and "
"with_interleaved"; "with_interleaved";
...@@ -111,7 +204,7 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -111,7 +204,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
float dp_probs = 1.0 / 127.0; float dp_probs = 1.0 / 127.0;
nvinfer1::DimsHW nv_ksize(1, 1); nvinfer1::DimsHW nv_ksize(1, 1);
fc_layer = TRT_ENGINE_ADD_LAYER( fc_layer = TRT_ENGINE_ADD_LAYER(
engine_, Convolution, *input, n, nv_ksize, weight, bias); engine_, Convolution, *new_input, n, nv_ksize, weight, bias);
fc_layer->setName( fc_layer->setName(
("Multihead: Convolution/FullyConnected: (Output: " + ("Multihead: Convolution/FullyConnected: (Output: " +
output_name + ")") output_name + ")")
...@@ -220,10 +313,10 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -220,10 +313,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
if (op_desc.HasAttr("Input_scale")) { if (op_desc.HasAttr("Input_scale")) {
nvinfer1::DimsHW nv_ksize(1, 1); nvinfer1::DimsHW nv_ksize(1, 1);
fc_layer = TRT_ENGINE_ADD_LAYER( fc_layer = TRT_ENGINE_ADD_LAYER(
engine_, Convolution, *input, n, nv_ksize, weight, bias); engine_, Convolution, *new_input, n, nv_ksize, weight, bias);
} else { } else {
fc_layer = TRT_ENGINE_ADD_LAYER( fc_layer = TRT_ENGINE_ADD_LAYER(
engine_, FullyConnected, *input, n, weight, bias); engine_, FullyConnected, *new_input, n, weight, bias);
} }
if (op_desc.HasAttr("fc_out_threshold")) { if (op_desc.HasAttr("fc_out_threshold")) {
...@@ -282,14 +375,28 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -282,14 +375,28 @@ class MultiheadMatMulOpConverter : public OpConverter {
std::vector<nvinfer1::ITensor*> plugin_inputs; std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(fc_layer->getOutput(0)); plugin_inputs.emplace_back(fc_layer->getOutput(0));
plugin_inputs.emplace_back(engine_->GetITensor("qkv_plugin_mask")); plugin_inputs.emplace_back(mask_tensor);
plugin_inputs.emplace_back(pos_id_tensor); plugin_inputs.emplace_back(pos_id_tensor);
plugin_inputs.emplace_back( plugin_inputs.emplace_back(
max_seqlen_tensor); // max_seqlen, eval_placeholder_3 max_seqlen_tensor); // max_seqlen, eval_placeholder_3
auto plugin_layer = engine_->network()->addPluginV2( auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin); plugin_inputs.data(), plugin_inputs.size(), *plugin);
layer = plugin_layer;
// recover no_varlen output
if (!flag_varseqlen) {
std::vector<nvinfer1::ITensor*> output_transformer;
output_transformer.emplace_back(plugin_layer->getOutput(0));
output_transformer.emplace_back(input);
output_transformer.emplace_back(pos_id_tensor);
plugin::TransformerOutputConvertPlugin* plugin =
new plugin::TransformerOutputConvertPlugin();
nvinfer1::ILayer* transformer_output_layer =
engine_->AddDynamicPlugin(output_transformer.data(),
output_transformer.size(),
plugin);
layer = transformer_output_layer;
}
} }
} else { } else {
if (input_dims.d[1] <= 384 && !bias_qk_attr && if (input_dims.d[1] <= 384 && !bias_qk_attr &&
......
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
......
...@@ -26,7 +26,7 @@ list( ...@@ -26,7 +26,7 @@ list(
deformable_conv_op_plugin.cu deformable_conv_op_plugin.cu
matmul_op_int8_plugin.cu matmul_op_int8_plugin.cu
multihead_matmul_roformer_plugin.cu multihead_matmul_roformer_plugin.cu
transformer_input_convert_plugin.cu transformer_input_output_convert_plugin.cu
remove_padding_plugin.cu remove_padding_plugin.cu
recover_padding_plugin.cu recover_padding_plugin.cu
c_allreduce_op_plugin.cu c_allreduce_op_plugin.cu
......
...@@ -105,7 +105,6 @@ int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, ...@@ -105,7 +105,6 @@ int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
void* const* outputs, void* const* outputs,
void* workspace, void* workspace,
cudaStream_t stream) TRT_NOEXCEPT { cudaStream_t stream) TRT_NOEXCEPT {
const auto input_desc = inputDesc[0];
const half* input0 = static_cast<const half*>(inputs[0]); const half* input0 = static_cast<const half*>(inputs[0]);
const int32_t* input1 = const int32_t* input1 =
static_cast<const int32_t*>(inputs[1]); // pos_id_tensor static_cast<const int32_t*>(inputs[1]); // pos_id_tensor
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_convert_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
__global__ void TransformerInputConvertKernel(const int64_t* input,
int32_t* output0) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
__shared__ int32_t shared_data;
if (threadIdx.x == static_cast<int>(input[tid])) {
atomicAdd(&shared_data, 1);
}
output0[0] = 0;
output0[blockIdx.x + 1] = shared_data;
__syncthreads();
for (int i = 0; i < blockDim.x; ++i) {
output0[i + 1] += output0[i];
}
}
nvinfer1::DataType TransformerInputConvertPlugin::getOutputDataType(
int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const TRT_NOEXCEPT {
return nvinfer1::DataType::kINT32;
}
nvinfer1::DimsExprs TransformerInputConvertPlugin::getOutputDimensions(
int outputIndex,
const nvinfer1::DimsExprs* inputs,
int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
nvinfer1::DimsExprs output_dims{};
output_dims.nbDims = 1;
if (outputIndex == 0) { // PosId
const auto* one = exprBuilder.constant(1);
output_dims.d[0] = exprBuilder.operation(
nvinfer1::DimensionOperation::kSUM, *inputs[0].d[0], *one);
} else { // MaxSeqlen
output_dims.d[0] = inputs[0].d[1];
}
return output_dims;
}
bool TransformerInputConvertPlugin::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(nbInputs,
1,
platform::errors::InvalidArgument("Must have 1 inputs, "
"but got %d input(s). ",
nbInputs));
PADDLE_ENFORCE_EQ(nbOutputs,
getNbOutputs(),
platform::errors::InvalidArgument("Must have 2 output, "
"but got %d output(s). ",
nbOutputs));
if (pos == 0) { // input
return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
} else { // output0, output1
return inOut[pos].type == nvinfer1::DataType::kINT32 &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
}
}
void TransformerInputConvertPlugin::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs,
int nbOutputs) TRT_NOEXCEPT {}
void TransformerInputConvertPlugin::attachToContext(
cudnnContext* cudnnContext,
cublasContext* cublasContext,
nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {}
void TransformerInputConvertPlugin::detachFromContext() TRT_NOEXCEPT {}
void TransformerInputConvertPlugin::terminate() TRT_NOEXCEPT {}
int TransformerInputConvertPlugin::enqueue(
const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT {
const auto input_desc = inputDesc[0];
const int64_t* input = static_cast<const int64_t*>(inputs[0]);
int32_t* output0 = static_cast<int32_t*>(outputs[0]); // PosId
// int32_t* output1 = static_cast<int32_t*>(outputs[1]); // MaxSeqlen
const int32_t num_blocks = input_desc.dims.d[0]; // batchs
const int32_t num_threads = input_desc.dims.d[1]; // max sequnce length
TransformerInputConvertKernel<<<num_blocks, num_threads, 0, stream>>>(
input, output0);
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/plugin/transformer_input_output_convert_plugin.h"
#include "cub/cub.cuh"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
__global__ void remove_padding_kernel(const half* input0,
const int32_t* input1,
half* output) {
int word_id = blockIdx.x * gridDim.y + blockIdx.y;
int32_t seqence_length = input1[blockIdx.x + 1] - input1[blockIdx.x];
if (blockIdx.y < seqence_length) {
output[(input1[blockIdx.x] + blockIdx.y) * gridDim.z * blockDim.x +
blockIdx.z * blockDim.x + threadIdx.x] =
input0[word_id * gridDim.z * blockDim.x + blockIdx.z * blockDim.x +
threadIdx.x];
}
}
__global__ void recover_padding_kernel(const half* input0,
const int32_t* input1,
half* output) {
int word_id = blockIdx.x * gridDim.y + blockIdx.y;
int32_t seqence_length = input1[blockIdx.x + 1] - input1[blockIdx.x];
if (blockIdx.y < seqence_length) {
output[word_id * gridDim.z * blockDim.x + blockIdx.z * blockDim.x +
threadIdx.x] =
input0[(input1[blockIdx.x] + blockIdx.y) * gridDim.z * blockDim.x +
blockIdx.z * blockDim.x + threadIdx.x];
} else {
output[word_id * gridDim.z * blockDim.x + blockIdx.z * blockDim.x +
threadIdx.x] = 0;
}
}
nvinfer1::DataType TransformerInputConvertPlugin::getOutputDataType(
int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const TRT_NOEXCEPT {
if (index == 0) { // new input
return nvinfer1::DataType::kHALF;
} else if (index == 1) { // mask
return nvinfer1::DataType::kHALF;
} else if (index == 2) { // pos id
return nvinfer1::DataType::kINT32;
} else if (index == 3) { // max_seqlen_tensor
return nvinfer1::DataType::kHALF;
}
}
nvinfer1::DimsExprs TransformerInputConvertPlugin::getOutputDimensions(
int outputIndex,
const nvinfer1::DimsExprs* inputs,
int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
constexpr size_t threadsPerCta384 = 1 * 8 * 32;
constexpr size_t xmmasM384 = 24;
constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384;
int32_t maskSize_ = packedMaskSize384;
auto maskSize = exprBuilder.constant(maskSize_);
auto fp16maskSize = exprBuilder.operation(
nvinfer1::DimensionOperation::kPROD, *maskSize, *exprBuilder.constant(2));
auto one = exprBuilder.constant(1);
auto B = inputs[0].d[0];
auto MaxLength = inputs[0].d[1];
auto Hidden = inputs[0].d[2];
nvinfer1::DimsExprs output_dims;
if (outputIndex == 0) { // new input
output_dims.nbDims = 4;
output_dims.d[0] = exprBuilder.operation(
nvinfer1::DimensionOperation::kPROD, *B, *MaxLength);
output_dims.d[1] = Hidden;
output_dims.d[2] = exprBuilder.constant(1);
output_dims.d[3] = exprBuilder.constant(1);
} else if (outputIndex == 1) { // mask
output_dims.nbDims = 2;
output_dims.d[0] = B;
output_dims.d[1] = fp16maskSize;
} else if (outputIndex == 2) { // pos id
output_dims.nbDims = 1;
output_dims.d[0] =
exprBuilder.operation(nvinfer1::DimensionOperation::kSUM, *B, *one);
} else if (outputIndex == 3) { // max_seqlen_tensor
output_dims.nbDims = 1;
output_dims.d[0] = MaxLength;
}
return output_dims;
}
bool TransformerInputConvertPlugin::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(nbInputs,
2,
platform::errors::InvalidArgument(
"TransformerInputConvertPlugin must have 2 inputs, "
"but got %d input(s). ",
nbInputs));
PADDLE_ENFORCE_EQ(nbOutputs,
4,
platform::errors::InvalidArgument(
"TransformerInputConvertPlugin must have 4 outputs, "
"but got %d output(s). ",
nbOutputs));
if (pos == 0) { // input
return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR &&
inOut[pos].type == nvinfer1::DataType::kHALF;
} else if (pos == 1) { // reducesum_qk_bias
return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR &&
inOut[pos].type == nvinfer1::DataType::kINT32;
} else if (pos == 2) { // new input
return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR &&
inOut[pos].type == nvinfer1::DataType::kHALF;
} else if (pos == 3) { // mask
return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR &&
inOut[pos].type == nvinfer1::DataType::kHALF;
} else if (pos == 4) { // pos id
return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR &&
inOut[pos].type == nvinfer1::DataType::kINT32;
} else if (pos == 5) { // max_seqlen_tensor
return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR &&
inOut[pos].type == nvinfer1::DataType::kHALF;
}
}
void TransformerInputConvertPlugin::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs,
int nbOutputs) TRT_NOEXCEPT {}
void TransformerInputConvertPlugin::attachToContext(
cudnnContext* cudnnContext,
cublasContext* cublasContext,
nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {}
void TransformerInputConvertPlugin::detachFromContext() TRT_NOEXCEPT {}
void TransformerInputConvertPlugin::terminate() TRT_NOEXCEPT {}
int TransformerInputConvertPlugin::enqueue(
const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT {
// input(no_varlen), reducesum_qk_bias, input(varlen), mask, pos_id,
// max_seqlen_tensor
const half* input0 = static_cast<const half*>(inputs[0]); // input(no_varlen)
const int32_t* input1 =
static_cast<const int32_t*>(inputs[1]); // reducesum_qk_bias
half* output0 = static_cast<half*>(outputs[0]); // input(varlen)
int32_t* output2 = static_cast<int32_t*>(outputs[2]); // pos_id
const auto input0_desc = inputDesc[0];
const int32_t B = input0_desc.dims.d[0]; // batchs
const int32_t MaxLength = input0_desc.dims.d[1]; // max token length
const int32_t HiddenSize = input0_desc.dims.d[2]; // hidden size
// Determine temporary device storage requirements
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceScan::ExclusiveSum(
d_temp_storage, temp_storage_bytes, input1, output2, B + 1);
// Allocate temporary storage
cudaMalloc(&d_temp_storage, temp_storage_bytes);
// Run exclusive prefix sum
cub::DeviceScan::ExclusiveSum(
d_temp_storage, temp_storage_bytes, input1, output2, B + 1);
const int32_t vector_length = HiddenSize;
int32_t num_threads;
if (vector_length < 1024) {
num_threads = vector_length;
} else {
if (vector_length % 512 == 0) {
num_threads = 512;
} else if (vector_length % 256 == 0) {
num_threads = 256;
} else if (vector_length % 128 == 0) {
num_threads = 128;
} else if (vector_length % 64 == 0) {
num_threads = 64;
} else if (vector_length % 32 == 0) {
num_threads = 32;
} else if (vector_length % 16 == 0) {
num_threads = 16;
} else if (vector_length % 8 == 0) {
num_threads = 8;
} else if (vector_length % 4 == 0) {
num_threads = 4;
} else if (vector_length % 2 == 0) {
num_threads = 2;
} else {
num_threads = 1;
}
}
const dim3 num_blocks(
B,
MaxLength,
vector_length /
num_threads); // batchs, max sequnce length, input0.dims.d[2]/*
remove_padding_kernel<<<num_blocks, num_threads, 0, stream>>>(
input0, output2, output0); // input(no_varlen), pos_id, input(varlen)
return cudaGetLastError() != cudaSuccess;
}
nvinfer1::DataType TransformerOutputConvertPlugin::getOutputDataType(
int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const TRT_NOEXCEPT {
if (index == 0) {
return nvinfer1::DataType::kHALF;
}
}
nvinfer1::DimsExprs TransformerOutputConvertPlugin::getOutputDimensions(
int outputIndex,
const nvinfer1::DimsExprs* inputs,
int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
nvinfer1::DimsExprs output_dims;
if (outputIndex == 0) {
output_dims = inputs[1];
}
return output_dims;
}
bool TransformerOutputConvertPlugin::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(nbInputs,
3,
platform::errors::InvalidArgument(
"TransformerOutputConvertPlugin must have 3 inputs, "
"but got %d input(s). ",
nbInputs));
PADDLE_ENFORCE_EQ(nbOutputs,
1,
platform::errors::InvalidArgument(
"TransformerOutputConvertPlugin must have 1 output, "
"but got %d output(s). ",
nbOutputs));
if (pos == 0) { // qkv plugin output(varlen)
return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR &&
inOut[pos].type == nvinfer1::DataType::kHALF;
} else if (pos == 1) { // qkv plugin input(no_varlen)
return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR &&
inOut[pos].type == nvinfer1::DataType::kHALF;
} else if (pos == 2) { // pos id
return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR &&
inOut[pos].type == nvinfer1::DataType::kINT32;
} else if (pos == 3) { // qkv plugin output(no_varlen)
return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR &&
inOut[pos].type == nvinfer1::DataType::kHALF;
}
}
void TransformerOutputConvertPlugin::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs,
int nbOutputs) TRT_NOEXCEPT {}
void TransformerOutputConvertPlugin::attachToContext(
cudnnContext* cudnnContext,
cublasContext* cublasContext,
nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {}
void TransformerOutputConvertPlugin::detachFromContext() TRT_NOEXCEPT {}
void TransformerOutputConvertPlugin::terminate() TRT_NOEXCEPT {}
int TransformerOutputConvertPlugin::enqueue(
const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT {
const half* input0 =
static_cast<const half*>(inputs[0]); // qkv plugin output(varlen)
const half* input1 =
static_cast<const half*>(inputs[1]); // qkv plugin input(no_varlen)
const int32_t* input2 = static_cast<const int32_t*>(inputs[2]); // pos id
half* output =
static_cast<half*>(outputs[0]); // qkv plugin output(no_varlen)
const auto input1_desc = inputDesc[1];
const int32_t B = input1_desc.dims.d[0]; // batchs
const int32_t MaxLength = input1_desc.dims.d[1]; // max token length
const int32_t HiddenSize = input1_desc.dims.d[2]; // hidden size
const int32_t vector_length = HiddenSize;
int32_t num_threads;
if (vector_length < 1024) {
num_threads = vector_length;
} else {
if (vector_length % 512 == 0) {
num_threads = 512;
} else if (vector_length % 256 == 0) {
num_threads = 256;
} else if (vector_length % 128 == 0) {
num_threads = 128;
} else if (vector_length % 64 == 0) {
num_threads = 64;
} else if (vector_length % 32 == 0) {
num_threads = 32;
} else if (vector_length % 16 == 0) {
num_threads = 16;
} else if (vector_length % 8 == 0) {
num_threads = 8;
} else if (vector_length % 4 == 0) {
num_threads = 4;
} else if (vector_length % 2 == 0) {
num_threads = 2;
} else {
num_threads = 1;
}
}
const dim3 num_blocks(
B,
MaxLength,
vector_length / num_threads); // batchs, max sequnce length
// (mask_id.dims.d[1]),
// input.dims.d[1]/*
recover_padding_kernel<<<num_blocks, num_threads, 0, stream>>>(
input0, input2, output);
return cudaGetLastError() != cudaSuccess;
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
...@@ -40,14 +40,14 @@ class TransformerInputConvertPlugin : public DynamicPluginTensorRT { ...@@ -40,14 +40,14 @@ class TransformerInputConvertPlugin : public DynamicPluginTensorRT {
return "transformer_input_convert_plugin"; return "transformer_input_convert_plugin";
} }
int getNbOutputs() const TRT_NOEXCEPT override { return 2; } int getNbOutputs() const TRT_NOEXCEPT override { return 4; }
int initialize() TRT_NOEXCEPT { return 0; } int initialize() TRT_NOEXCEPT { return 0; }
void terminate() TRT_NOEXCEPT; void terminate() TRT_NOEXCEPT;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs getOutputDimensions(int outputIndex,
const nvinfer1::DimsExprs* inputs, const nvinfer1::DimsExprs* inputs,
int nbInputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) nvinfer1::IExprBuilder& exprBuilder) // NOLINT
TRT_NOEXCEPT override; TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, bool supportsFormatCombination(int pos,
...@@ -134,7 +134,120 @@ class TransformerInputConvertPluginCreator : public nvinfer1::IPluginCreator { ...@@ -134,7 +134,120 @@ class TransformerInputConvertPluginCreator : public nvinfer1::IPluginCreator {
std::string plugin_name_; std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
}; };
class TransformerOutputConvertPlugin : public DynamicPluginTensorRT {
public:
TransformerOutputConvertPlugin() {}
TransformerOutputConvertPlugin(void const* serial_data,
size_t serial_length) {}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
TransformerOutputConvertPlugin* ptr = new TransformerOutputConvertPlugin();
return ptr;
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "transformer_output_convert_plugin";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
int initialize() TRT_NOEXCEPT { return 0; }
void terminate() TRT_NOEXCEPT;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex,
const nvinfer1::DimsExprs* inputs,
int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) // NOLINT
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* outputs,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const TRT_NOEXCEPT override {
return 0;
}
void attachToContext(cudnnContext* cudnnContext,
cublasContext* cublasContext,
nvinfer1::IGpuAllocator* gpuAllocator)
TRT_NOEXCEPT override;
void detachFromContext() TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const
TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override { delete this; }
protected:
size_t getSerializationSize() const TRT_NOEXCEPT override { return 0; }
void serialize(void* buffer) const TRT_NOEXCEPT override {}
};
class TransformerOutputConvertPluginCreator : public nvinfer1::IPluginCreator {
public:
TransformerOutputConvertPluginCreator() {}
const char* getPluginName() const TRT_NOEXCEPT override {
return "transformer_output_convert_plugin";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* plugin_field)
TRT_NOEXCEPT override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
void const* serial_data,
size_t serial_length)
TRT_NOEXCEPT override {
TransformerOutputConvertPlugin* obj =
new TransformerOutputConvertPlugin(serial_data, serial_length);
obj->setPluginNamespace(name);
return obj;
}
void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
plugin_namespace_ = lib_namespace;
}
const char* getPluginNamespace() const TRT_NOEXCEPT override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
};
REGISTER_TRT_PLUGIN_V2(TransformerInputConvertPluginCreator); REGISTER_TRT_PLUGIN_V2(TransformerInputConvertPluginCreator);
REGISTER_TRT_PLUGIN_V2(TransformerOutputConvertPluginCreator);
} // namespace plugin } // namespace plugin
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册