Unverified Commit 0f3fb562 authored by Z zhangxin81, committed by GitHub

[search && paddle inference]add roformer pass&&plugin novarlen version (#47523)

* add roformer pass&&plugin(novarlen)
Parent 8164b97a
......@@ -105,6 +105,7 @@ pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base)
pass_library(multihead_matmul_fuse_pass inference)
pass_library(multihead_matmul_roformer_fuse_pass inference)
pass_library(fused_multi_transformer_encoder_pass inference)
pass_library(fused_multi_transformer_decoder_pass inference)
pass_library(fuse_multi_transformer_layer_pass inference)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
/*
* \brief Fuse the subgraph representing multihead attention part of roformer
* into multihead_matmul_roformer op.
*
 * \note The following graph describes the matched subgraph:
*
* x - input data
* cos - input data of cos mat
* sin - input data of sin mat
* ele_add - elementwise_add
* ele_mul - elementwise_mul
*
* x
* / | \
* / | \
* / | \
* | | |
* | | |
* mul mul mul
* | | |
* ele_add ele_add ele_add
* | | |
* reshape2 reshape2 reshape2
* | | |
* transpose2 transpose2 transpose2
* | / \ / \
* | | | | |
* | | cos split | sin split
* | | / | | / |
* | ele_mul concat ele_mul concat
* | | | | |
* | \ / \ /
* | ele_add ele_add
* | | |
* | | scale
* | | |
* | \ /
* | matmul
* | |
* | ele_add
* \ |
* \ softmax
* \ |
* \ /
 *                                            matmul
*
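 *  A sketch of the rotary transform this subgraph encodes (inferred from the
 *  graph above, per attention head):
 *      q' = q * cos + rotate_half(q) * sin
 *      k' = k * cos + rotate_half(k) * sin
 *  where rotate_half is realized by the split + concat pair on the last
 *  dimension, and only one of the two branches carries the extra scale op.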
*/
struct MultiHeadMatmulRoformerPattern : public PatternBase {
MultiHeadMatmulRoformerPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "multihead_matmul_roformer") {}
PDNode* operator()();
// declare operator node's name
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(input_cos);
PATTERN_DECL_NODE(input_sin);
PATTERN_DECL_NODE(mul0);
PATTERN_DECL_NODE(mul1);
PATTERN_DECL_NODE(mul2);
PATTERN_DECL_NODE(mul0_w);
PATTERN_DECL_NODE(mul1_w);
PATTERN_DECL_NODE(mul2_w);
PATTERN_DECL_NODE(mul0_out);
PATTERN_DECL_NODE(mul1_out);
PATTERN_DECL_NODE(mul2_out);
PATTERN_DECL_NODE(eltadd0); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd1_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd2_b); // ELEMENTWISE_ADD
PATTERN_DECL_NODE(eltadd0_out);
PATTERN_DECL_NODE(eltadd1_out);
PATTERN_DECL_NODE(eltadd2_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(scale);
PATTERN_DECL_NODE(scale_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(transpose2_qkv_out);
PATTERN_DECL_NODE(eltmul_cos_q);
PATTERN_DECL_NODE(eltmul_cos_q_out);
PATTERN_DECL_NODE(eltmul_sin_q);
PATTERN_DECL_NODE(eltmul_sin_q_out);
PATTERN_DECL_NODE(eltmul_cos_k);
PATTERN_DECL_NODE(eltmul_cos_k_out);
PATTERN_DECL_NODE(eltmul_sin_k);
PATTERN_DECL_NODE(eltmul_sin_k_out);
PATTERN_DECL_NODE(split_q);
PATTERN_DECL_NODE(split_q_out);
PATTERN_DECL_NODE(concat_q);
PATTERN_DECL_NODE(concat_q_out);
PATTERN_DECL_NODE(split_k);
PATTERN_DECL_NODE(split_k_out);
PATTERN_DECL_NODE(concat_k);
PATTERN_DECL_NODE(concat_k_out);
PATTERN_DECL_NODE(eltadd_q);
PATTERN_DECL_NODE(eltadd_q_out);
PATTERN_DECL_NODE(eltadd_k);
PATTERN_DECL_NODE(eltadd_k_out);
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(eltadd_qk);
PATTERN_DECL_NODE(eltadd_qk_b);
PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
};
} // namespace patterns
class MultiHeadMatmulRoformerFusePass : public FusePassBase {
public:
MultiHeadMatmulRoformerFusePass();
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"multihead_matmul_roformer_fuse"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -2265,6 +2265,7 @@ USE_TRT_CONVERTER(instance_norm);
USE_TRT_CONVERTER(layer_norm);
USE_TRT_CONVERTER(gelu);
USE_TRT_CONVERTER(multihead_matmul);
USE_TRT_CONVERTER(multihead_matmul_roformer);
USE_TRT_CONVERTER(skip_layernorm);
USE_TRT_CONVERTER(slice);
USE_TRT_CONVERTER(scale);
......
......@@ -106,6 +106,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"delete_c_identity_op_pass", //
"trt_multihead_matmul_fuse_pass_v2", //
"trt_multihead_matmul_fuse_pass_v3", //
"multihead_matmul_roformer_fuse_pass", //
"constant_folding_pass", //
"vit_attention_fuse_pass", //
"trt_skip_layernorm_fuse_pass", //
......
......@@ -23,6 +23,7 @@ list(
gelu_op.cc
layer_norm_op.cc
multihead_matmul_op.cc
multihead_matmul_roformer_op.cc
shuffle_channel_op.cc
swish_op.cc
silu_op.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See
the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class MultiheadMatMulRoformerOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(3) << "convert a fluid multihead_mamul_roformer op to a corresponding "
"tensorrt "
"network structure";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("Input").front());
auto* input_cos = engine_->GetITensor(op_desc.Input("Input_cos").front());
auto* input_sin = engine_->GetITensor(op_desc.Input("Input_sin").front());
// fc weights and fc bias
auto weight_name = op_desc.Input("W").front();
auto bias_name = op_desc.Input("Bias").front();
auto* weight_v = scope.FindVar(weight_name);
auto* weight_t = weight_v->GetMutable<phi::DenseTensor>();
auto* bias_v = scope.FindVar(bias_name);
auto* bias_t = bias_v->GetMutable<phi::DenseTensor>();
float* weight_data = nullptr;
float in_scale = 0.;
if (op_desc.HasAttr("Input_scale")) {
in_scale = PADDLE_GET_CONST(float, op_desc.GetAttr("Input_scale"));
engine_->SetTensorDynamicRange(input, in_scale);
}
weight_data = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(weight_name, *weight_t).get().values));
float* bias_data = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(bias_name, *bias_t).get().values));
std::vector<float> weight_data_tmp;
weight_data_tmp.reserve(weight_t->numel());
memcpy(
weight_data_tmp.data(), weight_data, weight_t->numel() * sizeof(float));
// (hidden_in, 3, hidden_out)
auto weight_dims = weight_t->dims();
int hidden_in = weight_dims[0]; // channels_in
    int three = weight_dims[1];       // 3 (fused q, k, v)
int hidden_out = weight_dims[2]; // channels_out
int m = hidden_in;
int n = three * hidden_out;
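      // TRT expects the fc weight in (n, m) = (3 * hidden_out, hidden_in)
      // layout, so the (hidden_in, 3 * hidden_out) weight is transposed below.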
auto tranpose_weight = [](const float* src, float* dst, int m, int n) {
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
dst[j * m + i] = src[i * n + j];
}
}
};
tranpose_weight(weight_data_tmp.data(), weight_data, m, n);
int head_number = PADDLE_GET_CONST(int, op_desc.GetAttr("head_number"));
nvinfer1::ILayer* layer = nullptr;
auto output_name = op_desc.Output("Out")[0];
bool flag_varseqlen = engine_->use_varseqlen() &&
engine_->tensorrt_transformer_posid() != "" &&
engine_->tensorrt_transformer_maskid() != "";
if (engine_->with_dynamic_shape()) {
if (flag_varseqlen) {
PADDLE_THROW(
platform::errors::Fatal("roformer not support varseqlen yet"));
} else {
PADDLE_ENFORCE_EQ(
input->getDimensions().nbDims,
3,
platform::errors::InvalidArgument(
"The Input dim of the MultiheadMatMul should be 3, "
"but it's (%d) now.",
input->getDimensions().nbDims));
// transpose weight_data from m * n to n * m
auto* input_bias_qk =
engine_->GetITensor(op_desc.Input("BiasQK").front());
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<size_t>(weight_t->numel())};
weight.dims.assign({n, m});
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_data),
static_cast<size_t>(bias_t->numel())};
// add shuffle before fc
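      // Reshape (B, S, 3 * N * H) to (B, S, 3 * N * H, 1, 1); a 0 in the
      // reshape dims tells the TRT shuffle layer to copy that input dim.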
nvinfer1::Dims reshape_before_fc_dim;
reshape_before_fc_dim.nbDims = 5;
reshape_before_fc_dim.d[0] = 0;
reshape_before_fc_dim.d[1] = 0;
reshape_before_fc_dim.d[2] = 0;
reshape_before_fc_dim.d[3] = 1;
reshape_before_fc_dim.d[4] = 1;
auto* reshape_before_fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
if (op_desc.HasAttr("Input_scale")) {
engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0),
in_scale);
}
reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim);
reshape_before_fc_layer->setName(
("shuffle_before_multihead_matmul(Output: " + output_name + ")")
.c_str());
// add layer fc
nvinfer1::ILayer* fc_layer = nullptr;
if (op_desc.HasAttr("Input_scale")) {
nvinfer1::DimsHW nv_ksize(1, 1);
fc_layer =
TRT_ENGINE_ADD_LAYER(engine_,
Convolution,
*reshape_before_fc_layer->getOutput(0),
n,
nv_ksize,
weight.get(),
bias.get());
} else {
fc_layer =
TRT_ENGINE_ADD_LAYER(engine_,
FullyConnected,
*reshape_before_fc_layer->getOutput(0),
n,
weight.get(),
bias.get());
}
if (op_desc.HasAttr("fc_out_threshold")) {
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("fc_out_threshold"),
true,
platform::errors::InvalidArgument(
"must have out threshold in multihead layers in int8 mode"));
float out_scale =
PADDLE_GET_CONST(float, op_desc.GetAttr("fc_out_threshold"));
engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale);
}
fc_layer->setName(
("multihead_matmul_fc(Output: " + output_name + ")").c_str());
// no need to add shuffle after fc, just change it in
// QkvToContextPluginDynamic
// add qkv to context
int head_size = hidden_out / head_number;
float scale = PADDLE_GET_CONST(float, op_desc.GetAttr("alpha"));
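      // Plugin input order: fused qkv fc output, cos table, sin table, bias_qk.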
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.push_back(fc_layer->getOutput(0));
plugin_inputs.push_back(input_cos);
plugin_inputs.push_back(input_sin);
plugin_inputs.push_back(input_bias_qk);
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
with_fp16 = true;
}
plugin::DynamicPluginTensorRT* plugin =
new plugin::MultiheadMatmulRoformerPlugin(
hidden_in, head_number, head_size, scale, with_fp16);
layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 4, plugin);
}
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static shape mode, which "
"is not supported for the time being.\n"
"You can use the config.SetTRTDynamicShapeInfo(...) interface to set "
"the shape information to run the dynamic shape mode."));
}
RreplenishLayerAndOutput(
layer, "multihead_matmul_roformer", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(multihead_matmul_roformer,
MultiheadMatMulRoformerOpConverter);
......@@ -1723,6 +1723,58 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}
if (op_type == "multihead_matmul_roformer") {
if (!with_dynamic_shape) {
VLOG(3) << "the multihead_matmul_roformer does not support static "
"shape yet";
return false;
}
if (desc.HasAttr("enable_int8") && !desc.HasAttr("Input_scale")) {
VLOG(3) << "Multihead layers must have input scale in int8 mode.";
return false;
}
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
auto* input_desc = block->FindVar(desc.Input("Input").front());
const auto input_shape = input_desc->GetShape();
const auto head_number =
PADDLE_GET_CONST(int, desc.GetAttr("head_number"));
auto inputs = desc.Inputs();
      bool has_bias_qk = (inputs.find("BiasQK") != inputs.end());
if (has_bias_qk) {
auto* biasqk_desc = block->FindVar(desc.Input("BiasQK").front());
const auto biasqk_shape = biasqk_desc->GetShape();
// The BiasQK's shape requires to be
// [batch, 1, 1, length] or [batch, head, length, length].
bool has_same_shape = head_number == biasqk_shape[1] &&
input_shape[1] == biasqk_shape[2] &&
input_shape[1] == biasqk_shape[3];
bool is_broadcastable = biasqk_shape[1] == 1 && biasqk_shape[2] == 1 &&
input_shape[1] == biasqk_shape[3];
if (!(has_same_shape || is_broadcastable)) {
VLOG(3) << "The BiasQK's shape is invalid, expect [" << input_shape[0]
<< ", 1, 1, " << input_shape[1] << "] or [" << input_shape[0]
<< ", " << head_number << ", " << input_shape[1] << ", "
<< input_shape[1] << "] but [" << biasqk_shape[0] << ", "
<< biasqk_shape[1] << ", " << biasqk_shape[2] << ", "
<< biasqk_shape[3] << "].";
return false;
}
} else {
#if !IS_TRT_VERSION_GE(8000)
VLOG(3) << "The version of TRT must be greater than 8000";
return false;
#endif
}
}
if (op_type == "fc") {
auto* block = desc.Block();
if (block == nullptr) {
......@@ -2271,6 +2323,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"clip",
"fused_embedding_eltwise_layernorm",
"multihead_matmul",
"multihead_matmul_roformer",
"skip_layernorm",
"slice",
"strided_slice",
......@@ -2394,6 +2447,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"clip",
"fused_embedding_eltwise_layernorm",
"multihead_matmul",
"multihead_matmul_roformer",
"skip_layernorm",
"slice",
"strided_slice",
......
......@@ -25,6 +25,7 @@ list(
pool3d_op_plugin.cu
deformable_conv_op_plugin.cu
matmul_op_int8_plugin.cu
multihead_matmul_roformer_plugin.cu
transformer_input_convert_plugin.cu
remove_padding_plugin.cu
recover_padding_plugin.cu
......
......@@ -18,6 +18,7 @@
#include <cub/cub.cuh>
#include "cublas_v2.h"
#include "paddle/fluid/platform/device_context.h"
using kv_float = cub::KeyValuePair<float, float>;
using kv_half = cub::KeyValuePair<half, half>;
......@@ -144,3 +145,154 @@ __device__ inline void layerNorm(const kvp<R>& threadData,
output[idx] = g * (val - mu) * rsigma + b;
}
}
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Helper Functions for multihead related plugins
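// Transpose (B, N, S, H) -> (B, S, N, H): one block per (batch, head, seq)
// position, one thread per element of size_per_head.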
template <typename T>
__global__ void transpose(T *src,
T *dst,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head) {
int batch_id = blockIdx.x / (head_num * seq_len);
int seq_id = blockIdx.x % seq_len;
int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len;
dst[batch_id * (head_num * seq_len * size_per_head) +
seq_id * head_num * size_per_head + head_id * size_per_head +
threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x];
}
template <typename T>
__global__ void TransposeQkvKernel(const int H, const T *input, T *output) {
// Input: BxSx3xNxH
// Bias: 3xSxB
// Output: 3xBxNxSxH
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z;
const int N = blockDim.y;
const int S = gridDim.x;
const int B = gridDim.y;
const int NH = N * H;
const int NHS = NH * S;
const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3;
const int out_offset = s * H + n * S * H + b * NHS + m * NHS * B;
const int i = threadIdx.x;
output[out_offset + i] = input[in_offset + i];
}
inline void TransposeQKV(const int batch,
const int seq_len,
const int head_size,
const int head_num,
const float *input,
float *output,
cudaStream_t stream) {
int scratch_size = batch * head_num * seq_len * seq_len;
const dim3 grid(seq_len, batch, 3);
if (head_size % 4 == 0 && scratch_size % 4 == 0) {
const int h = head_size / 4;
const float4 *input4 = reinterpret_cast<const float4 *>(input);
float4 *output4 = reinterpret_cast<float4 *>(output);
const dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
1024 * 4));
TransposeQkvKernel<float4><<<grid, block, 0, stream>>>(h, input4, output4);
} else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
const int h = head_size / 2;
const float2 *input2 = reinterpret_cast<const float2 *>(input);
float2 *output2 = reinterpret_cast<float2 *>(output);
const dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
1024 * 2));
TransposeQkvKernel<float2><<<grid, block, 0, stream>>>(h, input2, output2);
} else {
const dim3 block(head_size, head_num, 1);
// limit head_size * head_num to max block size(1024).
PADDLE_ENFORCE_LE(head_size * head_num,
1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
1024));
TransposeQkvKernel<float>
<<<grid, block, 0, stream>>>(head_size, input, output);
}
}
inline void TransposeQKV(const int batch,
const int seq_len,
const int head_size,
const int head_num,
const half *input,
half *output,
cudaStream_t stream) {
int scratch_size = batch * head_num * seq_len * seq_len;
const dim3 grid(seq_len, batch, 3);
if (head_size % 8 == 0 && scratch_size % 8 == 0) {
int h = head_size / 8;
const int4 *input4 = reinterpret_cast<const int4 *>(input);
int4 *output4 = reinterpret_cast<int4 *>(output);
dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
1024 * 8));
TransposeQkvKernel<int4><<<grid, block, 0, stream>>>(h, input4, output4);
} else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
const int h = head_size / 2;
const half2 *input2 = reinterpret_cast<const half2 *>(input);
half2 *output2 = reinterpret_cast<half2 *>(output);
const dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
1024 * 2));
TransposeQkvKernel<half2><<<grid, block, 0, stream>>>(h, input2, output2);
} else {
const dim3 block(head_size, head_num, 1);
// limit head_size * head_num to max block size(1024).
PADDLE_ENFORCE_LE(head_size * head_num,
1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
1024));
TransposeQkvKernel<half>
<<<grid, block, 0, stream>>>(head_size, input, output);
}
}
}
}
}
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.h"
#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh> // NOLINT
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
int MultiheadMatmulRoformerPlugin::initialize() TRT_NOEXCEPT { return 0; }
nvinfer1::DimsExprs MultiheadMatmulRoformerPlugin::getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs *inputs,
int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
  // input[0], qkv fc output, (B, S, 3 * N * H, 1, 1)
  // input[1], cos table,     (B, head_num, seq_len, head_size)
  // input[2], sin table,     (B, head_num, seq_len, head_size)
  // input[3], bias_qk,       (B, head_num, seq_len, seq_len)
  // output, (B, seq_len, hidden)
PADDLE_ENFORCE_EQ(output_index,
0,
platform::errors::InvalidArgument(
"There is only one output of the EmbEltwiseLayernorm, "
"so the index should be zero,"
"but it's (%d)",
output_index));
PADDLE_ENFORCE_EQ(
nb_inputs,
4,
platform::errors::InvalidArgument(
"The Input of the EmbEltwiseLayernorm should be 3, but we found "
"it has (%d) inputs",
nb_inputs));
nvinfer1::DimsExprs ret;
ret.nbDims = 3;
ret.d[0] = inputs[0].d[0];
ret.d[1] = inputs[0].d[1];
ret.d[2] = expr_builder.constant(head_size_ * head_number_);
return ret;
}
bool MultiheadMatmulRoformerPlugin::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc *in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_NOT_NULL(
in_out,
platform::errors::InvalidArgument(
"The input of swish plugin shoule not be nullptr."));
PADDLE_ENFORCE_LT(
pos,
nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos,
nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
} else {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
if (pos == 1) {
return in.type == prev.type && in.format == prev.format;
}
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType MultiheadMatmulRoformerPlugin::getOutputDataType(
int index,
const nvinfer1::DataType *input_types,
int nb_inputs) const TRT_NOEXCEPT {
PADDLE_ENFORCE_EQ(
index,
0,
platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
return input_types[0];
}
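// Elementwise in-place scaling of the first n elements; used to apply the
// attention scale to the query tensor in the fp16 path.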
template <typename T>
__global__ void apply_scale(T *data, T scale, int n) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) {
data[tid] = data[tid] * scale;
}
#endif
}
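// Rotary position embedding: output = input1 * inputact + intput2 * rot(inputact),
// where rot() reads the element lastdim / 2 away (wrapping) along the last dim.
// Called with input1 = cos table and intput2 = sin table.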
template <typename T>
__global__ void RotrayKernel(const T *inputact,
const T *input1,
const T *intput2,
T *output,
const int nElement,
const int lastdim) {
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index >= nElement) return;
T left_elemul_out = input1[index] * inputact[index];
int col = index % lastdim;
int half_lastdim = lastdim / 2;
const int right_index = index - col + (col + half_lastdim) % lastdim;
output[index] = left_elemul_out + intput2[index] * inputact[right_index];
}
inline int round_up(int seq_len, int multiple = 32) {
PADDLE_ENFORCE_GT(
multiple,
0,
platform::errors::InvalidArgument(
"multiple should be a positive number,but it's (%d)", multiple));
return ((seq_len + multiple - 1) / multiple) * multiple;
}
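// Broadcast a (batch, seq_len) mask row to (batch, head_num, seq_len, seq_len):
// one block per (batch, head, row), one thread per column.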
template <typename T>
__global__ void broadcast(const T *src,
T *dst,
const int seq_len,
const int head_num) {
int batch_id = blockIdx.x / (head_num * seq_len);
int dst_offset = blockIdx.x * seq_len;
if (threadIdx.x < seq_len) {
dst[threadIdx.x + dst_offset] = src[threadIdx.x + batch_id * seq_len];
}
}
int MultiheadMatmulRoformerPlugin::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc,
const void *const *inputs,
void *const *outputs,
void *workspace,
cudaStream_t stream) TRT_NOEXCEPT {
auto input_dims = input_desc[0].dims;
int input_num = ProductDim(input_dims);
// input[0], (B, S, 3 * N * H, 1, 1)
int batch = input_dims.d[0];
int seq_len = input_dims.d[1];
phi::DenseTensor multihead_temp_tensor;
// masks
int scratch_size = batch * head_number_ * seq_len * seq_len * 1;
int device_id;
cudaGetDevice(&device_id);
multihead_temp_tensor.Resize({scratch_size + input_num});
// for roformer
phi::DenseTensor temp_roformer_tensor;
temp_roformer_tensor.Resize({input_num});
auto input_type = input_desc[0].type;
if (input_type == nvinfer1::DataType::kFLOAT) {
VLOG(1) << "TRT Plugin DataType selected. RoformerQkvToContext-->fp32";
auto *multihead_temp_data = multihead_temp_tensor.mutable_data<float>(
platform::CUDAPlace(device_id));
auto *temp_roformer_data =
temp_roformer_tensor.mutable_data<float>( // NOLINT
platform::CUDAPlace(device_id));
auto *tmp_roformer_ptr = reinterpret_cast<float *>(temp_roformer_data);
auto *qkptr = multihead_temp_data;
auto *tptr = multihead_temp_data + scratch_size;
const float *input0_data = static_cast<const float *>(inputs[0]);
// fit to [batch, head_num, length, length] + [batch, 1, 1, length]
phi::DenseTensor temp_qk_bias_tensor;
float *qk_bias = const_cast<float *>(static_cast<const float *>(inputs[3]));
if (ProductDim(input_desc[3].dims) == (batch * seq_len)) {
temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
auto *temp_qk_bias = temp_qk_bias_tensor.mutable_data<float>(
platform::CUDAPlace(device_id));
int grid = batch * head_number_ * seq_len;
int block = round_up(seq_len);
broadcast<<<grid, block, 0, stream>>>(
static_cast<const float *>(inputs[3]),
temp_qk_bias,
seq_len,
head_number_);
qk_bias = temp_qk_bias;
}
const float *input3_data = static_cast<const float *>(qk_bias);
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransposeQKV(
batch, seq_len, head_size_, head_number_, input0_data, tptr, stream);
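    // Keep a copy of the transposed qkv so the rotary kernels can read the
    // original q/k values while writing the rotated results back into tptr.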
cudaMemcpy(tmp_roformer_ptr, // dst
tptr, // src
input_num * sizeof(float),
cudaMemcpyDeviceToDevice);
int n_q = seq_len * head_number_ * head_size_ * batch;
constexpr int threads = 128;
int blocks = (n_q + threads - 1) / threads;
const float *input_cos_data = static_cast<const float *>(inputs[1]);
const float *input_sin_data = static_cast<const float *>(inputs[2]);
RotrayKernel<<<blocks, threads, 0, stream>>>(tmp_roformer_ptr,
input_cos_data,
input_sin_data,
tptr,
n_q,
head_size_); // q
RotrayKernel<<<blocks, threads, 0, stream>>>(tmp_roformer_ptr + n_q,
input_cos_data,
input_sin_data,
tptr + n_q,
n_q,
head_size_); // k
auto *device_ctx = static_cast<phi::GPUContext *>(
platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(device_id)));
const phi::GPUContext &dev_ctx = *device_ctx;
operators::math::MultiHeadGPUComputeFunctor<float> multihead_compute_func;
multihead_compute_func(dev_ctx,
batch,
seq_len,
head_number_,
head_size_,
qkptr,
input3_data,
false,
tptr,
scale_,
static_cast<float>(0.0));
int grid = batch * head_number_ * seq_len;
int block = head_size_;
float *output = static_cast<float *>(outputs[0]);
transpose<float><<<grid, block, 0, stream>>>(
tptr, output, batch, seq_len, head_number_, head_size_);
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
VLOG(1) << "TRT Plugin DataType selected. QkvToContext-->fp16";
auto *multihead_temp_data =
multihead_temp_tensor.mutable_data<int16_t>( // NOLINT
platform::CUDAPlace(device_id));
auto *temp_roformer_data =
temp_roformer_tensor.mutable_data<int16_t>( // NOLINT
platform::CUDAPlace(device_id));
half *tmp_roformer_ptr = reinterpret_cast<half *>(temp_roformer_data);
half *qkptr = reinterpret_cast<half *>(multihead_temp_data);
half *tptr = qkptr + scratch_size;
const half *input0_data = static_cast<const half *>(inputs[0]);
// fit to [batch, head_num, length, length] + [batch, 1, 1, length]
phi::DenseTensor temp_qk_bias_tensor;
half *qk_bias = const_cast<half *>(static_cast<const half *>(inputs[3]));
if (ProductDim(input_desc[3].dims) == (batch * seq_len)) {
temp_qk_bias_tensor.Resize({batch, head_number_, seq_len, seq_len});
auto *temp_qk_bias =
reinterpret_cast<half *>(temp_qk_bias_tensor.mutable_data<int16_t>(
platform::CUDAPlace(device_id)));
int grid = batch * head_number_ * seq_len;
int block = round_up(seq_len);
broadcast<<<grid, block, 0, stream>>>(
static_cast<const half *>(inputs[3]),
temp_qk_bias,
seq_len,
head_number_);
qk_bias = temp_qk_bias;
}
const half *input3_data = static_cast<const half *>(qk_bias);
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransposeQKV(
batch, seq_len, head_size_, head_number_, input0_data, tptr, stream);
cudaMemcpy(tmp_roformer_ptr,
tptr,
input_num * sizeof(half),
cudaMemcpyDeviceToDevice);
auto *device_ctx = static_cast<phi::GPUContext *>(
platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(device_id)));
int n_q = seq_len * head_number_ * head_size_ * batch;
constexpr int threads = 128;
int blocks = (n_q + threads - 1) / threads;
const half *input_cos_data = static_cast<const half *>(inputs[1]);
const half *input_sin_data = static_cast<const half *>(inputs[2]);
RotrayKernel<<<blocks, threads, 0, stream>>>(tmp_roformer_ptr,
input_cos_data,
input_sin_data,
tptr,
n_q,
head_size_); // q
RotrayKernel<<<blocks, threads, 0, stream>>>(tmp_roformer_ptr + n_q,
input_cos_data,
input_sin_data,
tptr + n_q,
n_q,
head_size_); // k
apply_scale<<<blocks, threads, 0, stream>>>(
tptr, static_cast<half>(scale_), n_q);
const phi::GPUContext &dev_ctx = *device_ctx;
operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
multihead_compute_func(dev_ctx,
batch,
seq_len,
head_number_,
head_size_,
qkptr,
input3_data,
false,
tptr,
half(1.),
half(0.0));
int grid = batch * head_number_ * seq_len;
int block = head_size_;
half *output = static_cast<half *>(outputs[0]);
transpose<half><<<grid, block, 0, stream>>>(
tptr, output, batch, seq_len, head_number_, head_size_);
#else
PADDLE_THROW(platform::errors::Fatal(
"The Ernie(Bert) TensorRT Plugin should be "
"complied with CUDA version >= 10.0 when running with fp16. "
"Please recomplie it or try to use fp32 by set "
"config.SetTRTDynamicShapeInfo(min_input_shape, "
"max_input_shape, opt_input_shape, true"));
#endif
} else {
PADDLE_THROW(platform::errors::Fatal(
"The QKV TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class MultiheadMatmulRoformerPlugin : public DynamicPluginTensorRT {
public:
explicit MultiheadMatmulRoformerPlugin(
int hidden, int head_number, int head_size, float scale, bool with_fp16)
: hidden_(hidden),
head_number_(head_number),
head_size_(head_size),
scale_(scale) {
with_fp16_ = with_fp16;
}
MultiheadMatmulRoformerPlugin(void const* serial_data, size_t serial_length) {
DeserializeValue(&serial_data, &serial_length, &hidden_);
DeserializeValue(&serial_data, &serial_length, &head_number_);
DeserializeValue(&serial_data, &serial_length, &head_size_);
DeserializeValue(&serial_data, &serial_length, &scale_);
DeserializeValue(&serial_data, &serial_length, &with_fp16_);
}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new MultiheadMatmulRoformerPlugin(
hidden_, head_number_, head_size_, scale_, with_fp16_);
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "multihead_matmul_roformer_plugin";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
int initialize() TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(hidden_) + SerializedSize(head_number_) +
SerializedSize(head_size_) + SerializedSize(scale_) +
SerializedSize(with_fp16_);
}
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, hidden_);
SerializeValue(&buffer, head_number_);
SerializeValue(&buffer, head_size_);
SerializeValue(&buffer, scale_);
SerializeValue(&buffer, with_fp16_);
}
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nb_inputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nb_outputs) TRT_NOEXCEPT override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nb_inputs,
const nvinfer1::PluginTensorDesc* outputs,
int nb_outputs) const TRT_NOEXCEPT override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const
TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override { delete this; }
private:
int hidden_;
int head_number_;
int head_size_;
float scale_;
};
class MultiheadMatmulRoformerPluginCreator : public nvinfer1::IPluginCreator {
public:
MultiheadMatmulRoformerPluginCreator() {}
const char* getPluginName() const TRT_NOEXCEPT override {
return "multihead_matmul_roformer_plugin";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(const char* name,
const nvinfer1::PluginFieldCollection* fc)
TRT_NOEXCEPT override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length)
TRT_NOEXCEPT override {
auto plugin = new MultiheadMatmulRoformerPlugin(serial_data, serial_length);
return plugin;
}
void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
plugin_namespace_ = lib_namespace;
}
const char* getPluginNamespace() const TRT_NOEXCEPT override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_;
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(MultiheadMatmulRoformerPluginCreator);
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -21,6 +21,7 @@
#include "glog/logging.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
......@@ -35,21 +36,6 @@ namespace plugin {
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
template <typename T>
__global__ void transpose(T *src,
T *dst,
const int batch_size,
const int seq_len,
const int head_num,
const int size_per_head) {
int batch_id = blockIdx.x / (head_num * seq_len);
int seq_id = blockIdx.x % seq_len;
int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len;
dst[batch_id * (head_num * seq_len * size_per_head) +
seq_id * head_num * size_per_head + head_id * size_per_head +
threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x];
}
inline int round_up(int seq_len, int multiple = 32) {
PADDLE_ENFORCE_GT(
multiple,
......@@ -115,133 +101,6 @@ __global__ void transpose_qkv_unpadding(const T *src,
seq_id * size_per_head + threadIdx.x];
}
template <typename T>
__global__ void TransposeQkvKernel(const int H, const T *input, T *output) {
// Input: BxSx3xNxH
// Bias: 3xSxB
// Output: 3xBxNxSxH
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z;
const int N = blockDim.y;
const int S = gridDim.x;
const int B = gridDim.y;
const int NH = N * H;
const int NHS = NH * S;
const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3;
const int out_offset = s * H + n * S * H + b * NHS + m * NHS * B;
const int i = threadIdx.x;
output[out_offset + i] = input[in_offset + i];
}
inline void TransposeQKV(const int batch,
const int seq_len,
const int head_size,
const int head_num,
const float *input,
float *output,
cudaStream_t stream) {
int scratch_size = batch * head_num * seq_len * seq_len;
const dim3 grid(seq_len, batch, 3);
if (head_size % 4 == 0 && scratch_size % 4 == 0) {
const int h = head_size / 4;
const float4 *input4 = reinterpret_cast<const float4 *>(input);
float4 *output4 = reinterpret_cast<float4 *>(output);
const dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
1024 * 4));
TransposeQkvKernel<float4><<<grid, block, 0, stream>>>(h, input4, output4);
} else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
const int h = head_size / 2;
const float2 *input2 = reinterpret_cast<const float2 *>(input);
float2 *output2 = reinterpret_cast<float2 *>(output);
const dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
1024 * 2));
TransposeQkvKernel<float2><<<grid, block, 0, stream>>>(h, input2, output2);
} else {
const dim3 block(head_size, head_num, 1);
// limit head_size * head_num to max block size(1024).
PADDLE_ENFORCE_LE(head_size * head_num,
1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
1024));
TransposeQkvKernel<float>
<<<grid, block, 0, stream>>>(head_size, input, output);
}
}
inline void TransposeQKV(const int batch,
const int seq_len,
const int head_size,
const int head_num,
const half *input,
half *output,
cudaStream_t stream) {
int scratch_size = batch * head_num * seq_len * seq_len;
const dim3 grid(seq_len, batch, 3);
if (head_size % 8 == 0 && scratch_size % 8 == 0) {
int h = head_size / 8;
const int4 *input4 = reinterpret_cast<const int4 *>(input);
int4 *output4 = reinterpret_cast<int4 *>(output);
dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
1024 * 8));
TransposeQkvKernel<int4><<<grid, block, 0, stream>>>(h, input4, output4);
} else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
const int h = head_size / 2;
const half2 *input2 = reinterpret_cast<const half2 *>(input);
half2 *output2 = reinterpret_cast<half2 *>(output);
const dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num,
1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
1024 * 2));
TransposeQkvKernel<half2><<<grid, block, 0, stream>>>(h, input2, output2);
} else {
const dim3 block(head_size, head_num, 1);
// limit head_size * head_num to max block size(1024).
PADDLE_ENFORCE_LE(head_size * head_num,
1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num,
head_size,
1024));
TransposeQkvKernel<half>
<<<grid, block, 0, stream>>>(head_size, input, output);
}
}
int QkvToContextPluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
nvinfer1::DimsExprs QkvToContextPluginDynamic::getOutputDimensions(
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import paddle.inference as paddle_infer
import numpy as np
from functools import partial
import unittest
class TestMultiheadMatmulRoformerFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
# trt
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=8,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False,
)
config.set_trt_dynamic_shape_info(
{
"mul_x": [1, 1, 768],
"eltadd_qk_b_var": [1, 12, 1, 1],
"cos_input": [1, 12, 1, 64],
"sin_input": [1, 12, 1, 64],
},
{
"mul_x": [1, 128, 768],
"eltadd_qk_b_var": [1, 12, 128, 128],
"cos_input": [1, 12, 128, 64],
"sin_input": [1, 12, 128, 64],
},
{
"mul_x": [1, 128, 768],
"eltadd_qk_b_var": [1, 12, 128, 128],
"cos_input": [1, 12, 128, 64],
"sin_input": [1, 12, 128, 64],
},
)
yield config, ["multihead_matmul_roformer", "matmul"], (1e-2, 1e-3)
def sample_program_config(self, draw):
def generate_mul_input():
return (
np.random.random([1, 128, 768]).astype(np.float32) - 0.5
) / 100.0
def generate_elewise_input():
return (
np.random.random([1, 12, 128, 128]).astype(np.float32)
) / 100.0
def generate_cos_input():
return np.random.random([1, 12, 128, 64]).astype(np.float32) - 0.5
def generate_sin_input():
return np.random.random([1, 12, 128, 64]).astype(np.float32) - 0.5
def generate_weight1():
return (
np.random.random((768, 768)).astype(np.float32) - 0.5
) / 100.0
def generate_weight2():
return (np.random.random(768).astype(np.float32) - 0.5) / 100.0
mul_0 = OpConfig(
"matmul",
inputs={"X": ["mul_x"], "Y": ["mul_0_w"]},
outputs={"Out": ["mul_0_out"]},
alpha=1.0,
transpose_X=False,
transpose_Y=False,
)
mul_1 = OpConfig(
"matmul",
inputs={"X": ["mul_x"], "Y": ["mul_1_w"]},
outputs={"Out": ["mul_1_out"]},
alpha=1.0,
transpose_X=False,
transpose_Y=False,
)
mul_2 = OpConfig(
"matmul",
inputs={"X": ["mul_x"], "Y": ["mul_2_w"]},
outputs={"Out": ["mul_2_out"]},
alpha=1.0,
transpose_X=False,
transpose_Y=False,
)
ele_0 = OpConfig(
"elementwise_add",
inputs={"X": [mul_0.outputs["Out"][0]], "Y": ["ele_0_w"]},
outputs={"Out": ["ele_0_out"]},
axis=-1,
)
ele_1 = OpConfig(
"elementwise_add",
inputs={"X": [mul_1.outputs["Out"][0]], "Y": ["ele_1_w"]},
outputs={"Out": ["ele_1_out"]},
axis=-1,
)
ele_2 = OpConfig(
"elementwise_add",
inputs={"X": [mul_2.outputs["Out"][0]], "Y": ["ele_2_w"]},
outputs={"Out": ["ele_2_out"]},
axis=-1,
)
reshape_0 = OpConfig(
"reshape2",
inputs={"X": [ele_0.outputs["Out"][0]]},
outputs={"Out": ["reshape_0_out"], "XShape": ["reshape_0_Xout"]},
shape=(1, 128, 12, 64),
)
reshape_1 = OpConfig(
"reshape2",
inputs={"X": [ele_1.outputs["Out"][0]]},
outputs={"Out": ["reshape_1_out"], "XShape": ["reshape_1_Xout"]},
shape=(1, 128, 12, 64),
)
reshape_2 = OpConfig(
"reshape2",
inputs={"X": [ele_2.outputs["Out"][0]]},
outputs={"Out": ["reshape_2_out"], "XShape": ["reshape_2_Xout"]},
shape=(1, 128, 12, 64),
)
transpose_0 = OpConfig(
"transpose2",
inputs={"X": [reshape_0.outputs["Out"][0]]},
outputs={"Out": ["transpose_0_out"]},
axis=(0, 2, 1, 3),
)
transpose_1 = OpConfig(
"transpose2",
inputs={"X": [reshape_1.outputs["Out"][0]]},
outputs={"Out": ["transpose_1_out"]},
axis=(0, 2, 1, 3),
)
transpose_2 = OpConfig(
"transpose2",
inputs={"X": [reshape_2.outputs["Out"][0]]},
outputs={"Out": ["transpose_2_out"]},
axis=(0, 2, 1, 3),
)
# roformer part
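        # Each q/k branch below computes x * cos + rotate_half(x) * sin, where
        # rotate_half is the split along the last dim followed by a concat in
        # swapped order; only the q branch is followed by a scale op.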
        # q branch, with scale op
ele_mul_q_0 = OpConfig(
"elementwise_mul", # without split && concat
inputs={"X": [transpose_0.outputs["Out"][0]], "Y": ["cos_input"]},
outputs={"Out": ["ele_mul_q_0_out"]},
axis=-1,
)
split_q_0 = OpConfig(
"split",
inputs={"X": [transpose_0.outputs["Out"][0]]},
outputs={"Out": ["split_q_0_out_0", "split_q_0_out_1"]},
axis=3,
num=2,
)
concat_q_0 = OpConfig(
"concat",
inputs={
"X": [split_q_0.outputs["Out"][1], split_q_0.outputs["Out"][0]]
},
outputs={"Out": ["concat_q_0_out"]},
axis=-1,
)
ele_mul_q_1 = OpConfig(
"elementwise_mul", # without split && concat
inputs={"X": [concat_q_0.outputs["Out"][0]], "Y": ["sin_input"]},
outputs={"Out": ["ele_mul_q_1_out"]},
axis=-1,
)
ele_add_q_0 = OpConfig(
"elementwise_add",
inputs={
"X": [ele_mul_q_0.outputs["Out"][0]],
"Y": [ele_mul_q_1.outputs["Out"][0]],
},
outputs={"Out": ["ele_add_q_0_out"]},
axis=-1,
)
scale_0 = OpConfig(
"scale",
inputs={"X": [ele_add_q_0.outputs["Out"][0]]},
outputs={"Out": ["scale_0_out"]},
scale=0.1961161345243454,
bias=0,
)
        # k branch, which has no scale op
ele_mul_k_0 = OpConfig(
"elementwise_mul", # without split && concat
inputs={"X": [transpose_1.outputs["Out"][0]], "Y": ["cos_input"]},
outputs={"Out": ["ele_mul_k_0_out"]},
axis=-1,
)
split_k_0 = OpConfig(
"split",
inputs={"X": [transpose_1.outputs["Out"][0]]},
outputs={"Out": ["split_k_0_out_0", "split_k_0_out_1"]},
axis=3,
num=2,
)
concat_k_0 = OpConfig(
"concat",
inputs={
"X": [split_k_0.outputs["Out"][1], split_k_0.outputs["Out"][0]]
},
outputs={"Out": ["concat_k_0_out"]},
axis=-1,
)
ele_mul_k_1 = OpConfig(
"elementwise_mul", # with split && concat
inputs={"X": [concat_k_0.outputs["Out"][0]], "Y": ["sin_input"]},
outputs={"Out": ["ele_mul_k_1_out"]},
axis=-1,
)
ele_add_k_0 = OpConfig(
"elementwise_add",
inputs={
"X": [ele_mul_k_0.outputs["Out"][0]],
"Y": [ele_mul_k_1.outputs["Out"][0]],
},
outputs={"Out": ["ele_add_k_0_out"]},
axis=-1,
)
matmul_0 = OpConfig(
"matmul",
inputs={
"X": [scale_0.outputs["Out"][0]],
"Y": [ele_add_k_0.outputs["Out"][0]],
},
outputs={"Out": ["matmul_0_out"]},
alpha=1.0,
transpose_X=False,
transpose_Y=True,
fused_reshape_Out=[],
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_Out=[],
fused_transpose_X=[],
fused_transpose_Y=[],
)
ele_3 = OpConfig(
"elementwise_add",
inputs={
"X": [matmul_0.outputs["Out"][0]],
"Y": ["eltadd_qk_b_var"],
},
outputs={"Out": ["ele_3_out"]},
axis=-1,
)
softmax_op = OpConfig(
"softmax",
inputs={"X": [ele_3.outputs["Out"][0]]},
outputs={"Out": ["softmax_out"]},
axis=3,
is_test=True,
)
matmul_1 = OpConfig(
"matmul",
inputs={
"X": [softmax_op.outputs["Out"][0]],
"Y": [transpose_2.outputs["Out"][0]],
},
outputs={"Out": ["matmul_1_out"]},
alpha=1.0,
transpose_X=False,
transpose_Y=False,
)
transpose_3 = OpConfig(
"transpose2",
inputs={"X": [matmul_1.outputs["Out"][0]]},
outputs={"Out": ["transpose_3_out"]},
axis=(0, 2, 1, 3),
)
reshape_3 = OpConfig(
"reshape2",
inputs={"X": [transpose_3.outputs["Out"][0]]},
outputs={"Out": ["reshape_3_out"], "XShape": ["reshape_3_Xout"]},
shape=(1, 128, 768),
)
mul_3 = OpConfig(
"matmul",
inputs={"X": [reshape_3.outputs["Out"][0]], "Y": ["mul_3_w"]},
outputs={"Out": ["mul_3_out"]},
alpha=1.0,
transpose_X=False,
transpose_Y=False,
fused_reshape_Out=[],
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_Out=[],
fused_transpose_X=[],
fused_transpose_Y=[],
)
ops = [
mul_0,
mul_1,
mul_2,
ele_0,
ele_1,
ele_2,
reshape_0,
reshape_1,
reshape_2,
transpose_0,
transpose_1,
transpose_2,
ele_mul_q_0,
split_q_0,
concat_q_0,
ele_mul_q_1,
ele_add_q_0,
ele_mul_k_0,
split_k_0,
concat_k_0,
ele_mul_k_1,
ele_add_k_0,
scale_0,
matmul_0,
ele_3,
softmax_op,
matmul_1,
transpose_3,
reshape_3,
mul_3,
]
program_config = ProgramConfig(
ops=ops,
inputs={
"mul_x": TensorConfig(data_gen=partial(generate_mul_input)),
"eltadd_qk_b_var": TensorConfig(
data_gen=partial(generate_elewise_input)
),
"cos_input": TensorConfig(data_gen=partial(generate_cos_input)),
"sin_input": TensorConfig(data_gen=partial(generate_sin_input)),
},
weights={ # generate_weight1
"mul_0_w": TensorConfig(data_gen=partial(generate_weight1)),
"mul_1_w": TensorConfig(data_gen=partial(generate_weight1)),
"mul_2_w": TensorConfig(data_gen=partial(generate_weight1)),
"mul_3_w": TensorConfig(data_gen=partial(generate_weight1)),
"ele_0_w": TensorConfig(data_gen=partial(generate_weight2)),
"ele_1_w": TensorConfig(data_gen=partial(generate_weight2)),
"ele_2_w": TensorConfig(data_gen=partial(generate_weight2)),
},
outputs=[ops[-1].outputs["Out"][0]],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=100,
min_success_num=1,
passes=["multihead_matmul_roformer_fuse_pass"],
)
if __name__ == "__main__":
unittest.main()