Unverified commit 6934ac79, authored by xiaoxiaohehe001, committed by GitHub

[Paddle Inference] Support two inputs of multihead attention named qk_multihead. (#52455)

* Support two inputs of multihead attention named qk_multihead
Parent 01247e33
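
For orientation, a minimal NumPy sketch (illustration only, not the code added by this commit) of the pattern the new pass fuses: Q and K are projected from the same input tensor while V is projected from a second input, which is what the qk_multihead naming refers to. Shapes and the 1/sqrt(head_size) scale (about 0.17677) follow the unit test at the end of this diff (hidden = 256, 8 heads of size 32).

# Illustrative NumPy sketch of the qk_multihead attention pattern
# (shapes assumed from the unit test below; not the Paddle implementation).
import numpy as np

def qk_multihead_attention(input_qk, input_v, wq, wk, wv, bq, bk, bv,
                           head_number=8, head_size=32):
    batch, length, hidden = input_qk.shape

    def split_heads(x):  # [B, L, hidden] -> [B, heads, L, head_size]
        return x.reshape(batch, length, head_number, head_size).transpose(0, 2, 1, 3)

    q = split_heads(input_qk @ wq + bq)   # Q and K share the first input
    k = split_heads(input_qk @ wk + bk)
    v = split_heads(input_v @ wv + bv)    # V comes from the second input
    scores = (q @ k.transpose(0, 1, 3, 2)) / np.sqrt(head_size)  # scale ~= 0.17677
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)                   # softmax
    out = probs @ v                                              # [B, heads, L, head_size]
    return out.transpose(0, 2, 1, 3).reshape(batch, length, hidden)
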
@@ -134,6 +134,7 @@ if(WITH_TENSORRT)
pass_library(trt_multihead_matmul_fuse_pass inference)
pass_library(trt_flash_multihead_matmul_fuse_pass inference)
pass_library(trt_cross_multihead_matmul_fuse_pass inference)
pass_library(trt_qk_multihead_matmul_fuse_pass inference)
pass_library(trt_skip_layernorm_fuse_pass inference)
pass_library(merge_layernorm_fuse_pass inference)
pass_library(preln_skip_layernorm_fuse_pass inference)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct TrtQKMultiHeadMatmulPattern : public PatternBase {
TrtQKMultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "qk_multihead_matmul") {}
PDNode* operator()();
// declare operator node's name
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(input1);
PATTERN_DECL_NODE(mul0);
PATTERN_DECL_NODE(mul1);
PATTERN_DECL_NODE(mul2);
PATTERN_DECL_NODE(mul0_w);
PATTERN_DECL_NODE(mul1_w);
PATTERN_DECL_NODE(mul2_w);
PATTERN_DECL_NODE(mul0_out);
PATTERN_DECL_NODE(mul1_out);
PATTERN_DECL_NODE(mul2_out);
PATTERN_DECL_NODE(elementwise0);
PATTERN_DECL_NODE(elementwise1);
PATTERN_DECL_NODE(elementwise2);
PATTERN_DECL_NODE(elementwise0_w);
PATTERN_DECL_NODE(elementwise1_w);
PATTERN_DECL_NODE(elementwise2_w);
PATTERN_DECL_NODE(elementwise0_out);
PATTERN_DECL_NODE(elementwise1_out);
PATTERN_DECL_NODE(elementwise2_out);
PATTERN_DECL_NODE(scale);
PATTERN_DECL_NODE(scale_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(transpose2_qkv_out);
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
};
} // namespace patterns
class TrtQkMultiHeadMatmulFusePass : public FusePassBase {
public:
virtual ~TrtQkMultiHeadMatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"trt_qk_multihead_matmul_fuse"};
private:
int BuildQkFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -2569,6 +2569,7 @@ USE_TRT_CONVERTER(preln_groupnorm_act)
#if IS_TRT_VERSION_GE(8522)
USE_TRT_CONVERTER(flash_multihead_matmul)
USE_TRT_CONVERTER(cross_multihead_matmul)
USE_TRT_CONVERTER(qk_multihead_matmul)
#endif
#if IS_TRT_VERSION_GE(8510)
USE_TRT_CONVERTER(grid_sampler)
......
@@ -108,6 +108,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"trt_flash_multihead_matmul_fuse_pass",  //
"trt_cross_multihead_matmul_fuse_pass",  //
"vit_attention_fuse_pass",  //
"trt_qk_multihead_matmul_fuse_pass",  //
"layernorm_shift_partition_fuse_pass",  //
"merge_layernorm_fuse_pass",  //
#if !defined _WIN32
......
@@ -28,6 +28,7 @@ list(
multihead_matmul_roformer_op.cc
flash_multihead_matmul_op.cc
cross_multihead_matmul_op.cc
qk_multihead_matmul_op.cc
grid_sampler_op.cc
shuffle_channel_op.cc
fill_any_like_op.cc
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See
the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class QkMultiheadMatMulOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(3) << "convert a qk_multihead_mamul op to a corresponding tensorrt "
"network structure";
framework::OpDesc op_desc(op, nullptr);
auto* input_qk = engine_->GetITensor(op_desc.Input("Input_qk").front());
auto* input_v = engine_->GetITensor(op_desc.Input("Input_v").front());
auto output_name = op_desc.Output("Out")[0];
/* ------------------ weight_qk -------------------------*/
auto weight_qk_name = op_desc.Input("W_qk").front();
auto* weight_qk_v = scope.FindVar(weight_qk_name);
auto* weight_qk_t = weight_qk_v->GetMutable<phi::DenseTensor>();
float* weight_qk_data = nullptr;
weight_qk_data = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(weight_qk_name, *weight_qk_t).get().values));
const auto& weight_qk_dims =
weight_qk_t->dims();  // [hidden_in_qk, 2, hidden_out_qk]
int hidden_in_qk = weight_qk_dims[0];
int num_qk = weight_qk_dims[1];
int hidden_out_qk = weight_qk_dims[2];
int head_number_qk = PADDLE_GET_CONST(int, op_desc.GetAttr("head_number"));
int head_size_qk = hidden_out_qk / head_number_qk;
int n_qk = num_qk * hidden_out_qk;
// [hidden_in, 2, head_number, head_size]
// -> [head_number, 2, head_size, hidden_in]
auto transpose_weight_qk = [](const float* src,
float* dst,
int two,
int head_number,
int head_size,
int hidden_in) {
for (int hn = 0; hn < head_number; hn++) {
for (int t = 0; t < two; t++) {
for (int hs = 0; hs < head_size; hs++) {
for (int hi = 0; hi < hidden_in; hi++) {
int out_index = hn * two * head_size * hidden_in +
t * head_size * hidden_in + hs * hidden_in + hi;
int in_index = hi * two * head_number * head_size +
t * head_number * head_size + hn * head_size + hs;
dst[out_index] = src[in_index];
}
}
}
}
};
std::vector<float> weight_qk_data_tmp;
weight_qk_data_tmp.resize(weight_qk_t->numel());
memcpy(weight_qk_data_tmp.data(),
weight_qk_data,
weight_qk_t->numel() * sizeof(float));
transpose_weight_qk(weight_qk_data_tmp.data(),
weight_qk_data,
num_qk,
head_number_qk,
head_size_qk,
hidden_in_qk);
/* ------------------ bias_qk -------------------------*/
auto bias_qk_name = op_desc.Input("B_qk").front();
auto* bias_qk_v = scope.FindVar(bias_qk_name);
auto* bias_qk_t = bias_qk_v->GetMutable<phi::DenseTensor>();
float* bias_qk_data = nullptr;
bias_qk_data = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(bias_qk_name, *bias_qk_t).get().values));
// [2, head_number, head_size] -> [head_number, 2, head_size]
auto transpose_bias_qk = [](const float* src, float* dst, int N, int H) {
for (int i = 0; i < 2; ++i) {
for (int n = 0; n < N; ++n) {
for (int h = 0; h < H; ++h) {
dst[n * 2 * H + i * H + h] = src[i * N * H + n * H + h];
}
}
}
};
std::vector<float> bias_qk_data_tmp;
bias_qk_data_tmp.resize(bias_qk_t->numel());
memcpy(bias_qk_data_tmp.data(),
bias_qk_data,
bias_qk_t->numel() * sizeof(float));
transpose_bias_qk(
bias_qk_data_tmp.data(), bias_qk_data, head_number_qk, head_size_qk);
auto weight_qk_shape = nvinfer1::Dims3{1, n_qk, hidden_in_qk};
auto* weight_qk_tensor =
AddConstantLayer(weight_qk_data, weight_qk_shape, " ");
auto bias_qk_shape = nvinfer1::Dims3{1, 1, n_qk};
auto* bias_qk_tensor = AddConstantLayer(bias_qk_data, bias_qk_shape, " ");
nvinfer1::ITensor* input_qk_shape_tensor = Shape(input_qk);
nvinfer1::ILayer* fc_qk_layer = nullptr;
nvinfer1::ILayer* merge_qk_element_layer = nullptr;
nvinfer1::MatrixOperation matrix_operation_X =
nvinfer1::MatrixOperation::kNONE;
nvinfer1::MatrixOperation matrix_operation_Y =
nvinfer1::MatrixOperation::kTRANSPOSE;
fc_qk_layer = TRT_ENGINE_ADD_LAYER(engine_,
MatrixMultiply,
*input_qk,
matrix_operation_X,
*weight_qk_tensor,
matrix_operation_Y);
fc_qk_layer->setName(
("qk_attention_matrix_multiply(Output: " + output_name + ")").c_str());
// add qk ElementWiseLayer layer
nvinfer1::ElementWiseOperation elementwise_operation =
nvinfer1::ElementWiseOperation::kSUM;
merge_qk_element_layer = TRT_ENGINE_ADD_LAYER(engine_,
ElementWise,
*fc_qk_layer->getOutput(0),
*bias_qk_tensor,
elementwise_operation);
merge_qk_element_layer->setName(
("multihead_mamul_fc_qk(Output: " + output_name + ")").c_str());
auto* reshape_after_fc_qk_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle, *merge_qk_element_layer->getOutput(0));
std::vector<nvinfer1::ITensor*> mha_input_qk_tensor_shape;
for (int i = 0; i < 5; i++) {
mha_input_qk_tensor_shape.push_back(Add1DConstantLayer(1));
}
mha_input_qk_tensor_shape[0] =
GetEleTensorOfShape(input_qk_shape_tensor, 0);
mha_input_qk_tensor_shape[1] =
GetEleTensorOfShape(input_qk_shape_tensor, 1);
mha_input_qk_tensor_shape[2] = Add1DConstantLayer(head_number_qk);
mha_input_qk_tensor_shape[3] = Add1DConstantLayer(2);
mha_input_qk_tensor_shape[4] = Add1DConstantLayer(head_size_qk);
reshape_after_fc_qk_layer->setInput(1, *Concat(mha_input_qk_tensor_shape));
reshape_after_fc_qk_layer->setName(
("shuffle_after_fc_qk_multihead_matmul(Output: " + output_name + ")")
.c_str());
/* ------------------ weight_v -------------------------*/
auto weight_v_name = op_desc.Input("W_v").front();
auto* weight_v_v = scope.FindVar(weight_v_name);
auto* weight_v_t = weight_v_v->GetMutable<phi::DenseTensor>();
float* weight_v_data = nullptr;
weight_v_data = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(weight_v_name, *weight_v_t).get().values));
int n_v = hidden_out_qk;
// [hidden_in, head_number, head_size]
// -> [head_number, head_size, hidden_in]
auto transpose_weight_v = [](const float* src,
float* dst,
int head_number,
int head_size,
int hidden_in) {
for (int hn = 0; hn < head_number; hn++) {
for (int hs = 0; hs < head_size; hs++) {
for (int hi = 0; hi < hidden_in; hi++) {
int out_index = hn * head_size * hidden_in + hs * hidden_in + hi;
int in_index = hi * head_number * head_size + hn * head_size + hs;
dst[out_index] = src[in_index];
}
}
}
};
std::vector<float> weight_v_data_tmp;
weight_v_data_tmp.resize(weight_v_t->numel());
memcpy(weight_v_data_tmp.data(),
weight_v_data,
weight_v_t->numel() * sizeof(float));
transpose_weight_v(weight_v_data_tmp.data(),
weight_v_data,
head_number_qk,
head_size_qk,
hidden_in_qk);
/* ------------------ bias_v -------------------------*/
auto bias_v_name = op_desc.Input("B_v").front();
auto* bias_v_v = scope.FindVar(bias_v_name);
auto* bias_v_t = bias_v_v->GetMutable<phi::DenseTensor>();
float* bias_v_data = nullptr;
bias_v_data = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(bias_v_name, *bias_v_t).get().values));
auto weight_v_shape = nvinfer1::Dims3{1, n_v, hidden_in_qk};
auto* weight_v_tensor =
AddConstantLayer(weight_v_data, weight_v_shape, " ");
auto bias_v_shape = nvinfer1::Dims3{1, 1, n_v};
auto* bias_v_tensor = AddConstantLayer(bias_v_data, bias_v_shape, " ");
nvinfer1::ITensor* input_v_shape_tensor = Shape(input_v);
nvinfer1::ILayer* fc_v_layer = nullptr;
nvinfer1::ILayer* merge_v_element_layer = nullptr;
fc_v_layer = TRT_ENGINE_ADD_LAYER(engine_,
MatrixMultiply,
*input_v,
matrix_operation_X,
*weight_v_tensor,
matrix_operation_Y);
fc_v_layer->setName(
("v_attention_matrix_multiply(Output: " + output_name + ")").c_str());
// add v ElementWiseLayer layer
merge_v_element_layer = TRT_ENGINE_ADD_LAYER(engine_,
ElementWise,
*fc_v_layer->getOutput(0),
*bias_v_tensor,
elementwise_operation);
merge_v_element_layer->setName(
("multihead_mamul_fc_v(Output: " + output_name + ")").c_str());
// add shuffle for fc layer
auto* reshape_after_fc_v_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle, *merge_v_element_layer->getOutput(0));
std::vector<nvinfer1::ITensor*> mha_input_v_tensor_shape;
for (int i = 0; i < 5; i++) {
mha_input_v_tensor_shape.push_back(Add1DConstantLayer(1));
}
mha_input_v_tensor_shape[0] = GetEleTensorOfShape(input_v_shape_tensor, 0);
mha_input_v_tensor_shape[1] = GetEleTensorOfShape(input_v_shape_tensor, 1);
mha_input_v_tensor_shape[2] = Add1DConstantLayer(head_number_qk);
mha_input_v_tensor_shape[3] = Add1DConstantLayer(1);
mha_input_v_tensor_shape[4] = Add1DConstantLayer(head_size_qk);
reshape_after_fc_v_layer->setInput(1, *Concat(mha_input_v_tensor_shape));
reshape_after_fc_v_layer->setName(
("shuffle_after_fc_v_multihead_matmul(Output: " + output_name + ")")
.c_str());
std::vector<nvinfer1::ITensor*> mha_input_tensor_vector{
reshape_after_fc_qk_layer->getOutput(0),
reshape_after_fc_v_layer->getOutput(0)};
nvinfer1::ITensor* mha_input_tensor = Concat(mha_input_tensor_vector, 3);
auto creator = GetPluginRegistry()->getPluginCreator("fMHA_V2", "1");
assert(creator != nullptr);
std::vector<nvinfer1::PluginField> fields{};
nvinfer1::PluginFieldCollection* plugin_collection =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_collection) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
plugin_collection->nbFields = static_cast<int>(fields.size());
plugin_collection->fields = fields.data();
auto plugin = creator->createPlugin("fMHA_V2", plugin_collection);
free(plugin_collection);
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(mha_input_tensor);
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
// add shuffle
nvinfer1::ITensor* batch_tensor =
GetEleTensorOfShape(input_qk_shape_tensor, 0);
nvinfer1::ITensor* length_tensor =
GetEleTensorOfShape(input_qk_shape_tensor, 1);
auto* reshape_after_mha_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0));
std::vector<nvinfer1::ITensor*> reshape_tensor;
reshape_tensor.push_back(batch_tensor);
reshape_tensor.push_back(length_tensor);
reshape_tensor.push_back(Add1DConstantLayer(-1));
reshape_after_mha_layer->setInput(1, *Concat(reshape_tensor));
reshape_after_mha_layer->setName(
("shuffle_last_multihead_matmul(Output: " + output_name + ")").c_str());
nvinfer1::ILayer* layer = nullptr;
layer = reshape_after_mha_layer;
RreplenishLayerAndOutput(
layer, "qk_multihead_matmul", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(qk_multihead_matmul, QkMultiheadMatMulOpConverter);
@@ -70,6 +70,8 @@ struct SimpleOpTypeSetTeller : public Teller {
int8_teller_set.insert("flash_multihead_matmul");
teller_set.insert("cross_multihead_matmul");
int8_teller_set.insert("cross_multihead_matmul");
teller_set.insert("qk_multihead_matmul");
int8_teller_set.insert("qk_multihead_matmul");
#endif
#if IS_TRT_VERSION_GE(8200)
teller_set.insert("round");
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from typing import List
import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import SkipReasons, TrtLayerAutoScanTest
import paddle.inference as paddle_infer
class TrtConvertQkAttentionTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
ver = paddle_infer.get_trt_compile_version()
if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8520:
return False
return True
def sample_program_configs(self):
def generate_input1(batch, length):
return np.random.rand(batch, length, 256).astype(np.float32) / 10
def generate_input2(batch, length):
return np.random.rand(batch, length, 256).astype(np.float32) / 10
def generate_weight_q():
return np.random.rand(256, 256).astype(np.float32) / 10
def generate_weight_k():
return np.random.rand(256, 256).astype(np.float32) / 10
def generate_weight_v():
return np.random.rand(256, 256).astype(np.float32) / 10
def generate_bias_q():
return np.random.rand(256).astype(np.float32) / 10
def generate_bias_k():
return np.random.rand(256).astype(np.float32) / 10
def generate_bias_v():
return np.random.rand(256).astype(np.float32) / 10
for batch in [1, 2]:
self.batch = batch
for length in [300, 400]:
ops_config = [
# q
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["input_data1"],
"Y": ["matmul_q_weight"],
},
"op_outputs": {"Out": ["matmul_q_output"]},
"op_attrs": {"trans_x": False, "trans_y": False},
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["matmul_q_output"],
"Y": ["bias_q"],
},
"op_outputs": {"Out": ["elementwise_q_output"]},
"op_attrs": {
"Scale_out": 1.0,
"Scale_x": 1.0,
"Scale_y": 1.0,
"axis": 2,
},
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["elementwise_q_output"],
},
"op_outputs": {
"Out": ["reshape_q_output"],
"XShape": ["reshape_q_output_xshape"],
},
"op_attrs": {"shape": [0, 0, 8, 32]},
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape_q_output"]},
"op_outputs": {
"Out": ["transpose_q_output"],
"XShape": ["transpose_q_output_xshape"],
},
"op_attrs": {
"axis": [0, 2, 1, 3],
"data_format": "AnyLayout",
},
},
# k
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["input_data1"],
"Y": ["matmul_k_weight"],
},
"op_outputs": {"Out": ["matmul_k_output"]},
"op_attrs": {"trans_x": False, "trans_y": False},
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["matmul_k_output"],
"Y": ["bias_k"],
},
"op_outputs": {"Out": ["elementwise_k_output"]},
"op_attrs": {
"Scale_out": 1.0,
"Scale_x": 1.0,
"Scale_y": 1.0,
"axis": 2,
},
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["elementwise_k_output"],
},
"op_outputs": {
"Out": ["reshape_k_output"],
"XShape": ["reshape_k_output_xshape"],
},
"op_attrs": {"shape": [0, 0, 8, 32]},
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape_k_output"]},
"op_outputs": {
"Out": ["transpose_k_output"],
"XShape": ["transpose_k_output_xshape"],
},
"op_attrs": {
"axis": [0, 2, 1, 3],
"data_format": "AnyLayout",
},
},
# V
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["input_data2"],
"Y": ["matmul_v_weight"],
},
"op_outputs": {"Out": ["matmul_v_output"]},
"op_attrs": {"trans_x": False, "trans_y": False},
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["matmul_v_output"],
"Y": ["bias_v"],
},
"op_outputs": {"Out": ["elementwise_v_output"]},
"op_attrs": {
"Scale_out": 1.0,
"Scale_x": 1.0,
"Scale_y": 1.0,
"axis": 2,
},
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["elementwise_v_output"],
},
"op_outputs": {
"Out": ["reshape_v_output"],
"XShape": ["reshape_v_output_xshape"],
},
"op_attrs": {"shape": [0, 0, 8, 32]},
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape_v_output"]},
"op_outputs": {
"Out": ["transpose_v_output"],
"XShape": ["transpose_v_output_xshape"],
},
"op_attrs": {
"axis": [0, 2, 1, 3],
"data_format": "AnyLayout",
},
},
# matmul1+matmul2
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["transpose_q_output"],
"Y": ["transpose_k_output"],
},
"op_outputs": {"Out": ["matmul1_output"]},
"op_attrs": {"trans_x": False, "trans_y": True},
},
{
"op_type": "scale",
"op_inputs": {
"X": ["matmul1_output"],
},
"op_outputs": {"Out": ["scale_output"]},
"op_attrs": {
"scale": 0.17677,
"bias": 0.0,
"bias_after_scale": True,
},
},
{
"op_type": "softmax",
"op_inputs": {"X": ["scale_output"]},
"op_outputs": {"Out": ["softmax_output"]},
"op_attrs": {
"axis": -1,
"data_format": "AnyLayout",
},
},
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["softmax_output"],
"Y": ["transpose_v_output"],
},
"op_outputs": {"Out": ["matmul2_output"]},
"op_attrs": {"trans_x": False, "trans_y": False},
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["matmul2_output"]},
"op_outputs": {
"Out": ["transpose_output"],
"XShape": ["transpose_output_xshape"],
},
"op_attrs": {
"axis": [0, 2, 1, 3],
"data_format": "AnyLayout",
},
},
{
"op_type": "reshape2",
"op_inputs": {"X": ["transpose_output"]},
"op_outputs": {
"Out": ["reshape_output"],
"XShape": ["reshape_output_xshape"],
},
"op_attrs": {"shape": [0, 0, 256]},
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"matmul_q_weight": TensorConfig(
data_gen=partial(generate_weight_q)
),
"matmul_k_weight": TensorConfig(
data_gen=partial(generate_weight_k)
),
"matmul_v_weight": TensorConfig(
data_gen=partial(generate_weight_v)
),
"bias_q": TensorConfig(
data_gen=partial(generate_bias_q)
),
"bias_k": TensorConfig(
data_gen=partial(generate_bias_k)
),
"bias_v": TensorConfig(
data_gen=partial(generate_bias_v)
),
},
inputs={
"input_data1": TensorConfig(
data_gen=partial(generate_input1, batch, length)
),
"input_data2": TensorConfig(
data_gen=partial(generate_input2, batch, length)
),
},
outputs=["reshape_output"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
# The last dim of input1 and input2 should be static.
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 300, 256],
"input_data2": [1, 300, 256],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [4, 1200, 256],
"input_data2": [4, 1200, 256],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [1, 300, 256],
"input_data2": [1, 300, 256],
}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
self.trt_param.workspace_size = 2013265920
yield self.create_inference_config(), (1, 3), (1e-5, 1e-5)
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 3), (1e-3, 1e-3)
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
self.trt_param.workspace_size = 2013265920
yield self.create_inference_config(), (1, 3), (1e-5, 1e-4)
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 3), (1e-2, 1e-3)
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if self.dynamic_shape.min_input_shape == {}:
return True
return False
self.add_skip_case(
teller1,
SkipReasons.TRT_NOT_IMPLEMENTED,
"The qk attention trt oss plugin do not support static shape yet",
)
def teller2(program_config, predictor_config):
if self.trt_param.precision == paddle_infer.PrecisionType.Float32:
return True
return False
self.add_skip_case(
teller2,
SkipReasons.TRT_NOT_IMPLEMENTED,
"The qk attention trt oss plugin do not support fp32 yet",
)
def teller3(program_config, predictor_config):
if self.trt_param.precision == paddle_infer.PrecisionType.Int8:
return True
return False
self.add_skip_case(
teller3,
SkipReasons.TRT_NOT_IMPLEMENTED,
"The qk attention trt oss plugin do not support int8 yet.",
)
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()
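
For context, a hedged usage sketch (not part of this commit) of how the new pass and converter would typically be exercised outside the test harness: enable the TensorRT subgraph engine with FP16 precision and dynamic shapes, since the fMHA_V2 plugin path covers neither static shape, FP32, nor INT8 (see the skip cases above). The model paths are placeholders; the tensor names and shape ranges are taken from the test.

# Hedged sketch: enabling TensorRT FP16 with dynamic shapes so that
# trt_qk_multihead_matmul_fuse_pass and the qk_multihead_matmul converter can apply.
# "model.pdmodel" / "model.pdiparams" are hypothetical paths.
import paddle.inference as paddle_infer

config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
config.enable_use_gpu(256, 0)
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=1,
    min_subgraph_size=3,
    precision_mode=paddle_infer.PrecisionType.Half,
    use_static=False,
    use_calib_mode=False,
)
config.set_trt_dynamic_shape_info(
    {"input_data1": [1, 300, 256], "input_data2": [1, 300, 256]},    # min
    {"input_data1": [4, 1200, 256], "input_data2": [4, 1200, 256]},  # max
    {"input_data1": [1, 300, 256], "input_data2": [1, 300, 256]},    # opt
)
predictor = paddle_infer.create_predictor(config)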