Unverified commit a48b8e2c, authored by Wang Bojun and committed by GitHub

add oss flash fmha and fmhca support (#49438)

* add fmha_flashattention oss plugin
Parent 650a0836
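The converters added in this change require FP16 and dynamic shapes, so the new fuse passes only take effect when the TensorRT subgraph engine is configured accordingly. A minimal configuration sketch (the model paths, shapes, and the tensor name input_data1 are illustrative placeholders, not part of this change):

# Minimal sketch: enable Paddle-TRT in FP16 with dynamic shapes so the
# trt_flash/cross_multihead_matmul_fuse_pass and the fMHA/fMHCA oss plugins can apply.
# Model paths and tensor names below are placeholders.
import paddle.inference as paddle_infer

config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
config.enable_use_gpu(256, 0)
config.enable_tensorrt_engine(
    workspace_size=1 << 30,
    max_batch_size=1,
    min_subgraph_size=3,
    precision_mode=paddle_infer.PrecisionType.Half,  # the oss plugins are fp16-only
    use_static=False,
    use_calib_mode=False,
)
# Dynamic shape is required; the last (hidden) dim must stay fixed.
config.set_trt_dynamic_shape_info(
    {"input_data1": [1, 4096, 320]},
    {"input_data1": [8, 4096, 320]},
    {"input_data1": [2, 4096, 320]},
)
predictor = paddle_infer.create_predictor(config)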
@@ -130,6 +130,8 @@ target_link_libraries(generate_pass pass_desc_proto)
if(WITH_TENSORRT)
  pass_library(trt_map_matmul_to_mul_pass inference)
  pass_library(trt_multihead_matmul_fuse_pass inference)
  pass_library(trt_flash_multihead_matmul_fuse_pass inference)
  pass_library(trt_cross_multihead_matmul_fuse_pass inference)
  pass_library(trt_skip_layernorm_fuse_pass inference)
  pass_library(merge_layernorm_fuse_pass inference)
  pass_library(preln_skip_layernorm_fuse_pass inference)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct TrtCrossMultiHeadMatmulPattern : public PatternBase {
TrtCrossMultiHeadMatmulPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "cross_multihead_matmul") {}
PDNode* operator()();
// declare operator node's name
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(input1);
PATTERN_DECL_NODE(mul0);
PATTERN_DECL_NODE(mul1);
PATTERN_DECL_NODE(mul2);
PATTERN_DECL_NODE(mul0_w);
PATTERN_DECL_NODE(mul1_w);
PATTERN_DECL_NODE(mul2_w);
PATTERN_DECL_NODE(mul0_out);
PATTERN_DECL_NODE(mul1_out);
PATTERN_DECL_NODE(mul2_out);
PATTERN_DECL_NODE(scale);
PATTERN_DECL_NODE(scale_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(transpose2_qkv_out);
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
};
} // namespace patterns
class TrtCrossMultiHeadMatmulFusePass : public FusePassBase {
public:
TrtCrossMultiHeadMatmulFusePass();
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"trt_cross_multihead_matmul_fuse"};
private:
int BuildCrossFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct TrtFlashMultiHeadMatmulPattern : public PatternBase {
TrtFlashMultiHeadMatmulPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "flash_multihead_matmul") {}
PDNode* operator()();
// declare operator node's name
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(mul0);
PATTERN_DECL_NODE(mul1);
PATTERN_DECL_NODE(mul2);
PATTERN_DECL_NODE(mul0_w);
PATTERN_DECL_NODE(mul1_w);
PATTERN_DECL_NODE(mul2_w);
PATTERN_DECL_NODE(mul0_out);
PATTERN_DECL_NODE(mul1_out);
PATTERN_DECL_NODE(mul2_out);
PATTERN_DECL_NODE(scale);
PATTERN_DECL_NODE(scale_out);
PATTERN_DECL_NODE(reshape2_0);
PATTERN_DECL_NODE(reshape2_1);
PATTERN_DECL_NODE(reshape2_2);
PATTERN_DECL_NODE(reshape2_qkv);
PATTERN_DECL_NODE(reshape2_0_out);
PATTERN_DECL_NODE(reshape2_1_out);
PATTERN_DECL_NODE(reshape2_2_out);
PATTERN_DECL_NODE(reshape2_qkv_out);
PATTERN_DECL_NODE(transpose2_0);
PATTERN_DECL_NODE(transpose2_1);
PATTERN_DECL_NODE(transpose2_2);
PATTERN_DECL_NODE(transpose2_qkv);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(transpose2_qkv_out);
PATTERN_DECL_NODE(matmul_qk);
PATTERN_DECL_NODE(matmul_qk_out);
PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out);
PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out);
};
} // namespace patterns
class TrtFlashMultiHeadMatmulFusePass : public FusePassBase {
public:
TrtFlashMultiHeadMatmulFusePass();
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"trt_flash_multihead_matmul_fuse"};
private:
int BuildFlashFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
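Both headers declare the same scaled-dot-product attention pattern; the flash variant matches Q, K and V projected from a single input, while the cross variant matches Q projected from input0 and K, V from input1. A rough NumPy sketch of the computation these subgraphs implement (head_number = 8 and head_size = 40 as in the unit tests below, so the matched scale is 1/sqrt(40) ≈ 0.1581):

# Rough NumPy sketch of the attention subgraph these patterns fuse; shapes and
# sizes are illustrative (head_number=8, head_size=40 as in the unit tests below).
import numpy as np

def fused_attention(x_q, x_kv, w_q, w_k, w_v, head_number=8, head_size=40):
    # mul0/mul1/mul2: matmul_v2 projections (the flash pattern uses x_q is x_kv)
    q, k, v = x_q @ w_q, x_kv @ w_k, x_kv @ w_v
    b, sq, _ = q.shape
    sk = k.shape[1]
    # reshape2_* + transpose2_*: [b, s, h*d] -> [b, h, s, d]
    q = q.reshape(b, sq, head_number, head_size).transpose(0, 2, 1, 3)
    k = k.reshape(b, sk, head_number, head_size).transpose(0, 2, 1, 3)
    v = v.reshape(b, sk, head_number, head_size).transpose(0, 2, 1, 3)
    # matmul_qk + scale (1/sqrt(head_size) ~= 0.1581 for head_size=40) + softmax_qk
    logits = (q @ k.transpose(0, 1, 3, 2)) / np.sqrt(head_size)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    # matmul_qkv + transpose2_qkv + reshape2_qkv: back to [b, s, h*d]
    return (probs @ v).transpose(0, 2, 1, 3).reshape(b, sq, head_number * head_size)

x = np.random.rand(2, 64, 320).astype(np.float32)
w = np.random.rand(320, 320).astype(np.float32)
print(fused_attention(x, x, w, w, w).shape)  # (2, 64, 320): self-attention (flash) case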
@@ -631,6 +631,7 @@ PDNode* TrtMultiHeadMatmulV3Pattern::operator()() {
  return transpose2_2_out_var;
}
}  // namespace patterns
void TrtMultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
@@ -1667,6 +1668,7 @@ REGISTER_PASS(trt_multihead_matmul_fuse_pass_v2,
              paddle::framework::ir::TrtMultiHeadMatmulV2FusePass);
REGISTER_PASS(trt_multihead_matmul_fuse_pass_v3,
              paddle::framework::ir::TrtMultiHeadMatmulV3FusePass);
REGISTER_PASS_CAPABILITY(trt_multihead_matmul_fuse_pass_v2)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
@@ -1677,6 +1679,7 @@ REGISTER_PASS_CAPABILITY(trt_multihead_matmul_fuse_pass_v2)
            .EQ("scale", 0)
            .LE("matmul", 1)
            .EQ("softmax", 0));
REGISTER_PASS_CAPABILITY(trt_multihead_matmul_fuse_pass_v3)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
......
@@ -2427,6 +2427,10 @@ USE_TRT_CONVERTER(expand_v2)
USE_TRT_CONVERTER(take_along_axis)
USE_TRT_CONVERTER(skip_groupnorm_act)
USE_TRT_CONVERTER(preln_groupnorm_act)
#if IS_TRT_VERSION_GE(8522)
USE_TRT_CONVERTER(flash_multihead_matmul)
USE_TRT_CONVERTER(cross_multihead_matmul)
#endif
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
USE_TRT_CONVERTER(sparse_fc)
USE_TRT_CONVERTER(sparse_multihead_matmul)
......
@@ -19,6 +19,7 @@
#ifdef PADDLE_WITH_HIP
#include <miopen/miopen.h>
#endif
#include <glog/logging.h>
#include <algorithm>
@@ -103,6 +104,8 @@ const std::vector<std::string> kTRTSubgraphPasses({
"trt_multihead_matmul_fuse_pass_v3",      //
"multihead_matmul_roformer_fuse_pass",    //
"constant_folding_pass",                  //
"trt_flash_multihead_matmul_fuse_pass", //
"trt_cross_multihead_matmul_fuse_pass", //
"vit_attention_fuse_pass", // "vit_attention_fuse_pass", //
#if defined _WIN32 // Windows CI is TensorRT7.0. Remove this after upgrading. #if defined _WIN32 // Windows CI is TensorRT7.0. Remove this after upgrading.
#else #else
......
@@ -25,6 +25,8 @@ list(
layer_norm_op.cc
multihead_matmul_op.cc
multihead_matmul_roformer_op.cc
flash_multihead_matmul_op.cc
cross_multihead_matmul_op.cc
shuffle_channel_op.cc
fill_any_like_op.cc
where_op.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See
the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class CrossMultiheadMatMulOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(3) << "convert a cross_multihead_mamul op to a corresponding tensorrt "
"network structure";
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
with_fp16 = true;
}
PADDLE_ENFORCE_EQ(
with_fp16,
true,
platform::errors::Unimplemented(
"Trt cross attention oss plugin only support fp16 mode yet."));
framework::OpDesc op_desc(op, nullptr);
auto* input_q = engine_->GetITensor(op_desc.Input("Input_q").front());
auto* input_kv = engine_->GetITensor(op_desc.Input("Input_kv").front());
// auto input_dims = input->getDimensions();
auto output_name = op_desc.Output("Out")[0];
auto weight_q_name = op_desc.Input("W_q").front();
auto* weight_q_v = scope.FindVar(weight_q_name);
auto* weight_q_t = weight_q_v->GetMutable<phi::DenseTensor>();
float* weight_q_data = nullptr;
weight_q_data = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(weight_q_name, *weight_q_t).get().values));
const auto& weight_q_dims = weight_q_t->dims();
int hidden_in_q = weight_q_dims[0];
int hidden_out_q = weight_q_dims[1];
int head_number_q = PADDLE_GET_CONST(int, op_desc.GetAttr("head_number"));
int head_size_q = hidden_out_q / head_number_q;
int n_q = hidden_out_q;
auto transpose_weight_q = [](const float* src,
float* dst,
int head_number,
int head_size,
int hidden_in) {
for (int hn = 0; hn < head_number; hn++) {
for (int hs = 0; hs < head_size; hs++) {
for (int hi = 0; hi < hidden_in; hi++) {
int out_index = hn * head_size * hidden_in + hs * hidden_in + hi;
int in_index = hi * head_number * head_size + hn * head_size + hs;
dst[out_index] = src[in_index];
}
}
}
};
std::vector<float> weight_q_data_tmp;
weight_q_data_tmp.resize(weight_q_t->numel());
memcpy(weight_q_data_tmp.data(),
weight_q_data,
weight_q_t->numel() * sizeof(float));
transpose_weight_q(weight_q_data_tmp.data(),
weight_q_data,
head_number_q,
head_size_q,
hidden_in_q);
nvinfer1::Weights weight_q{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_q_data),
static_cast<int32_t>(weight_q_t->numel())};
nvinfer1::Weights bias_q{};
// add shuffle for FullyConnected layer
std::vector<nvinfer1::ITensor*> reshape_before_fc_q_shape_tensor;
nvinfer1::ITensor* input_q_shape_tensor = Shape(input_q);
for (int i = 0; i < 5; i++) {
reshape_before_fc_q_shape_tensor.push_back(Add1DConstantLayer(1));
}
for (int i = 0; i < 3; i++) {
reshape_before_fc_q_shape_tensor[i] =
GetEleTensorOfShape(input_q_shape_tensor, i);
}
auto* reshape_before_fc_q_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input_q);
reshape_before_fc_q_layer->setInput(
1, *Concat(reshape_before_fc_q_shape_tensor));
reshape_before_fc_q_layer->setName(
("shuffle_before_fc_q_multihead_matmul(Output: " + output_name + ")")
.c_str());
nvinfer1::ILayer* fc_q_layer = nullptr;
fc_q_layer = TRT_ENGINE_ADD_LAYER(engine_,
FullyConnected,
*reshape_before_fc_q_layer->getOutput(0),
n_q,
weight_q,
bias_q);
fc_q_layer->setName(
("multihead_mamul_fc_q(Output: " + output_name + ")").c_str());
// add shuffle for fc layer
auto* reshape_after_fc_q_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *fc_q_layer->getOutput(0));
std::vector<nvinfer1::ITensor*> mha_input_q_tensor_shape;
for (int i = 0; i < 4; i++) {
mha_input_q_tensor_shape.push_back(Add1DConstantLayer(1));
}
mha_input_q_tensor_shape[0] = GetEleTensorOfShape(input_q_shape_tensor, 0);
mha_input_q_tensor_shape[1] = GetEleTensorOfShape(input_q_shape_tensor, 1);
mha_input_q_tensor_shape[2] = Add1DConstantLayer(head_number_q);
mha_input_q_tensor_shape[3] = Add1DConstantLayer(head_size_q);
reshape_after_fc_q_layer->setInput(1, *Concat(mha_input_q_tensor_shape));
reshape_after_fc_q_layer->setName(
("shuffle_after_fc_q_multihead_matmul(Output: " + output_name + ")")
.c_str());
auto weight_kv_name = op_desc.Input("W_kv").front();
auto* weight_kv_v = scope.FindVar(weight_kv_name);
auto* weight_kv_t = weight_kv_v->GetMutable<phi::DenseTensor>();
float* weight_kv_data = nullptr;
weight_kv_data = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(weight_kv_name, *weight_kv_t).get().values));
// (hidden_in, 2, hidden_out)
const auto& weight_kv_dims = weight_kv_t->dims();
int hidden_in = weight_kv_dims[0]; // channels_in
int two = weight_kv_dims[1]; // two
int hidden_out = weight_kv_dims[2]; // channels_out
int head_number = PADDLE_GET_CONST(int, op_desc.GetAttr("head_number"));
int head_size = hidden_out / head_number;
int n = two * hidden_out;
nvinfer1::ILayer* layer = nullptr;
// [hidden_in, 2, head_number, head_size]
// -> [head_number, 2, head_size, hidden_in]
auto transpose_weight = [](const float* src,
float* dst,
int two,
int head_number,
int head_size,
int hidden_in) {
for (int hn = 0; hn < head_number; hn++) {
for (int t = 0; t < two; t++) {
for (int hs = 0; hs < head_size; hs++) {
for (int hi = 0; hi < hidden_in; hi++) {
int out_index = hn * two * head_size * hidden_in +
t * head_size * hidden_in + hs * hidden_in + hi;
int in_index = hi * two * head_number * head_size +
t * head_number * head_size + hn * head_size + hs;
dst[out_index] = src[in_index];
}
}
}
}
};
std::vector<float> weight_kv_data_tmp;
weight_kv_data_tmp.resize(weight_kv_t->numel());
memcpy(weight_kv_data_tmp.data(),
weight_kv_data,
weight_kv_t->numel() * sizeof(float));
transpose_weight(weight_kv_data_tmp.data(),
weight_kv_data,
two,
head_number,
head_size,
hidden_in);
nvinfer1::Weights weight_kv{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_kv_data),
static_cast<int32_t>(weight_kv_t->numel())};
nvinfer1::Weights bias_kv{};
// add shuffle for FullyConnected layer
std::vector<nvinfer1::ITensor*> reshape_before_fc_shape_tensor;
nvinfer1::ITensor* input_shape_tensor = Shape(input_kv);
for (int i = 0; i < 5; i++) {
reshape_before_fc_shape_tensor.push_back(Add1DConstantLayer(1));
}
for (int i = 0; i < 3; i++) {
reshape_before_fc_shape_tensor[i] =
GetEleTensorOfShape(input_shape_tensor, i);
}
auto* reshape_before_fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input_kv);
reshape_before_fc_layer->setInput(1,
*Concat(reshape_before_fc_shape_tensor));
reshape_before_fc_layer->setName(
("shuffle_before_fc_multihead_matmul(Output: " + output_name + ")")
.c_str());
nvinfer1::ILayer* fc_layer = nullptr;
fc_layer = TRT_ENGINE_ADD_LAYER(engine_,
FullyConnected,
*reshape_before_fc_layer->getOutput(0),
n,
weight_kv,
bias_kv);
fc_layer->setName(
("multihead_mamul_fc(Output: " + output_name + ")").c_str());
// add shuffle for fc layer
auto* reshape_after_fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *fc_layer->getOutput(0));
std::vector<nvinfer1::ITensor*> mha_input_tensor_shape;
for (int i = 0; i < 5; i++) {
mha_input_tensor_shape.push_back(Add1DConstantLayer(1));
}
mha_input_tensor_shape[0] = GetEleTensorOfShape(input_shape_tensor, 0);
mha_input_tensor_shape[1] = GetEleTensorOfShape(input_shape_tensor, 1);
mha_input_tensor_shape[2] = Add1DConstantLayer(head_number);
mha_input_tensor_shape[3] = Add1DConstantLayer(2);
mha_input_tensor_shape[4] = Add1DConstantLayer(head_size);
reshape_after_fc_layer->setInput(1, *Concat(mha_input_tensor_shape));
reshape_after_fc_layer->setName(
("shuffle_after_fc_multihead_matmul(Output: " + output_name + ")")
.c_str());
auto creator = GetPluginRegistry()->getPluginCreator("fMHCA", "1");
assert(creator != nullptr);
std::vector<nvinfer1::PluginField> fields{};
nvinfer1::PluginFieldCollection* plugin_collection =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_collection) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
plugin_collection->nbFields = static_cast<int>(fields.size());
plugin_collection->fields = fields.data();
auto plugin = creator->createPlugin("fMHA_V2", plugin_collection);
free(plugin_collection);
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(reshape_after_fc_q_layer->getOutput(0));
plugin_inputs.emplace_back(reshape_after_fc_layer->getOutput(0));
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
// add shuffle
nvinfer1::ITensor* batch_tensor =
GetEleTensorOfShape(input_q_shape_tensor, 0);
nvinfer1::ITensor* length_tensor =
GetEleTensorOfShape(input_q_shape_tensor, 1);
auto* reshape_after_mha_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0));
std::vector<nvinfer1::ITensor*> reshape_tensor;
reshape_tensor.push_back(batch_tensor);
reshape_tensor.push_back(length_tensor);
reshape_tensor.push_back(Add1DConstantLayer(-1));
reshape_after_mha_layer->setInput(1, *Concat(reshape_tensor));
reshape_after_mha_layer->setName(
("shuffle_last_multihead_matmul(Output: " + output_name + ")").c_str());
// return
layer = reshape_after_mha_layer;
RreplenishLayerAndOutput(
layer, "cross_multihead_matmul", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(cross_multihead_matmul,
CrossMultiheadMatMulOpConverter);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See
the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class FlashMultiheadMatMulOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(3) << "convert a flash_multihead_mamul op to a corresponding tensorrt "
"network structure";
bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
with_fp16 = true;
}
PADDLE_ENFORCE_EQ(
with_fp16,
true,
platform::errors::Unimplemented(
"Trt flash attention oss plugin only support fp16 mode yet."));
framework::OpDesc op_desc(op, nullptr);
auto* input = engine_->GetITensor(op_desc.Input("Input").front());
auto weight_name = op_desc.Input("W").front();
auto* weight_v = scope.FindVar(weight_name);
auto* weight_t = weight_v->GetMutable<phi::DenseTensor>();
float* weight_data = nullptr;
weight_data = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(weight_name, *weight_t).get().values));
// (hidden_in, 3, hidden_out)
const auto& weight_dims = weight_t->dims();
int hidden_in = weight_dims[0]; // channels_in
int three = weight_dims[1]; // three
int hidden_out = weight_dims[2]; // channels_out
int head_number = PADDLE_GET_CONST(int, op_desc.GetAttr("head_number"));
int head_size = hidden_out / head_number;
int n = three * hidden_out;
nvinfer1::ILayer* layer = nullptr;
auto output_name = op_desc.Output("Out")[0];
// [hidden_in, 3, head_number, head_size]
// -> [head_number, 3, head_size, hidden_in]
auto transpose_weight = [](const float* src,
float* dst,
int three,
int head_number,
int head_size,
int hidden_in) {
for (int hn = 0; hn < head_number; hn++) {
for (int t = 0; t < three; t++) {
for (int hs = 0; hs < head_size; hs++) {
for (int hi = 0; hi < hidden_in; hi++) {
int out_index = hn * three * head_size * hidden_in +
t * head_size * hidden_in + hs * hidden_in + hi;
int in_index = hi * three * head_number * head_size +
t * head_number * head_size + hn * head_size + hs;
dst[out_index] = src[in_index];
}
}
}
}
};
std::vector<float> weight_data_tmp;
weight_data_tmp.resize(weight_t->numel());
memcpy(
weight_data_tmp.data(), weight_data, weight_t->numel() * sizeof(float));
transpose_weight(weight_data_tmp.data(),
weight_data,
three,
head_number,
head_size,
hidden_in);
nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<int32_t>(weight_t->numel())};
nvinfer1::Weights bias{};
// add shuffle for FullyConnected layer
std::vector<nvinfer1::ITensor*> reshape_before_fc_shape_tensor;
nvinfer1::ITensor* input_shape_tensor = Shape(input);
for (int i = 0; i < 5; i++) {
reshape_before_fc_shape_tensor.push_back(Add1DConstantLayer(1));
}
for (int i = 0; i < 3; i++) {
reshape_before_fc_shape_tensor[i] =
GetEleTensorOfShape(input_shape_tensor, i);
}
auto* reshape_before_fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
reshape_before_fc_layer->setInput(1,
*Concat(reshape_before_fc_shape_tensor));
reshape_before_fc_layer->setName(
("shuffle_before_fc_multihead_matmul(Output: " + output_name + ")")
.c_str());
nvinfer1::ILayer* fc_layer = nullptr;
fc_layer = TRT_ENGINE_ADD_LAYER(engine_,
FullyConnected,
*reshape_before_fc_layer->getOutput(0),
n,
weight,
bias);
fc_layer->setName(
("multihead_mamul_fc(Output: " + output_name + ")").c_str());
// add shuffle for fc layer
auto* reshape_after_fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *fc_layer->getOutput(0));
std::vector<nvinfer1::ITensor*> mha_input_tensor_shape;
for (int i = 0; i < 5; i++) {
mha_input_tensor_shape.push_back(Add1DConstantLayer(1));
}
mha_input_tensor_shape[0] = GetEleTensorOfShape(input_shape_tensor, 0);
mha_input_tensor_shape[1] = GetEleTensorOfShape(input_shape_tensor, 1);
mha_input_tensor_shape[2] = Add1DConstantLayer(head_number);
mha_input_tensor_shape[3] = Add1DConstantLayer(3);
mha_input_tensor_shape[4] = Add1DConstantLayer(head_size);
reshape_after_fc_layer->setInput(1, *Concat(mha_input_tensor_shape));
reshape_after_fc_layer->setName(
("shuffle_after_fc_multihead_matmul(Output: " + output_name + ")")
.c_str());
auto creator = GetPluginRegistry()->getPluginCreator("fMHA_V2", "1");
assert(creator != nullptr);
std::vector<nvinfer1::PluginField> fields{};
nvinfer1::PluginFieldCollection* plugin_collection =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_collection) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
plugin_collection->nbFields = static_cast<int>(fields.size());
plugin_collection->fields = fields.data();
auto plugin = creator->createPlugin("fMHA_V2", plugin_collection);
free(plugin_collection);
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(reshape_after_fc_layer->getOutput(0));
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
// add shuffle
nvinfer1::ITensor* batch_tensor =
GetEleTensorOfShape(input_shape_tensor, 0);
nvinfer1::ITensor* length_tensor =
GetEleTensorOfShape(input_shape_tensor, 1);
auto* reshape_after_mha_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0));
std::vector<nvinfer1::ITensor*> reshape_tensor;
reshape_tensor.push_back(batch_tensor);
reshape_tensor.push_back(length_tensor);
reshape_tensor.push_back(Add1DConstantLayer(-1));
reshape_after_mha_layer->setInput(1, *Concat(reshape_tensor));
reshape_after_mha_layer->setName(
("shuffle_last_multihead_matmul(Output: " + output_name + ")").c_str());
// return
layer = reshape_after_mha_layer;
RreplenishLayerAndOutput(
layer, "flash_multihead_matmul", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(flash_multihead_matmul,
FlashMultiheadMatMulOpConverter);
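The transpose_weight helpers above rearrange the projection weight from a flattened [hidden_in, three, head_number, head_size] layout into [head_number, three, head_size, hidden_in] (with two in place of three for the KV weight in the cross converter). A small NumPy check of that index arithmetic, using made-up sizes:

# Sanity check (illustrative sizes): the C++ transpose_weight loop is equivalent to
# reshape(hidden_in, three, head_number, head_size).transpose(2, 1, 3, 0).
import numpy as np

hidden_in, three, head_number, head_size = 6, 3, 2, 4
src = np.arange(hidden_in * three * head_number * head_size, dtype=np.float32)

dst = np.empty_like(src)
for hn in range(head_number):
    for t in range(three):
        for hs in range(head_size):
            for hi in range(hidden_in):
                out_index = (hn * three * head_size * hidden_in
                             + t * head_size * hidden_in + hs * hidden_in + hi)
                in_index = (hi * three * head_number * head_size
                            + t * head_number * head_size + hn * head_size + hs)
                dst[out_index] = src[in_index]

expected = (src.reshape(hidden_in, three, head_number, head_size)
               .transpose(2, 1, 3, 0).ravel())
assert np.array_equal(dst, expected)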
@@ -66,6 +66,12 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set.insert("sparse_multihead_matmul");
int8_teller_set.insert("sparse_multihead_matmul");
#endif
#if IS_TRT_VERSION_GE(8522)
teller_set.insert("flash_multihead_matmul");
int8_teller_set.insert("flash_multihead_matmul");
teller_set.insert("cross_multihead_matmul");
int8_teller_set.insert("cross_multihead_matmul");
#endif
#if IS_TRT_VERSION_GE(8200)
teller_set.insert("round");
int8_teller_set.insert("round");
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from typing import List
import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import SkipReasons, TrtLayerAutoScanTest
import paddle.inference as paddle_infer
class TrtConvertCrossMultiHeadMatmulTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
ver = paddle_infer.get_trt_compile_version()
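# ver is (major, minor, patch); require TensorRT >= 8.5.2, i.e. 8*1000 + 5*100 + 2*10 = 8520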
if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8520:
return False
return True
def sample_program_configs(self):
def generate_input1(batch, dim1):
return np.random.random((batch, dim1, 320)).astype(np.float32) / 10
def generate_input2(batch, dim2):
return np.random.random((batch, dim2, 768)).astype(np.float32) / 10
def generate_weight1():
return np.random.random((320, 320)).astype(np.float32) / 10
def generate_weight2():
return np.random.random((768, 320)).astype(np.float32) / 10
for batch in [1, 2]:
self.batch = batch
for reshape_shape in [[0, 0, 8, 40]]:
for dim1 in [4096]:
for dim2 in [768]:
dics = [
{"trans_x": False, "trans_y": False},
{"shape": reshape_shape},
{"axis": [0, 2, 1, 3]},
{"trans_x": False, "trans_y": False},
{"shape": reshape_shape},
{"axis": [0, 2, 1, 3]},
{"trans_x": False, "trans_y": False},
{"shape": reshape_shape},
{"axis": [0, 2, 1, 3]},
{
"trans_x": False,
"trans_y": True,
},
{
"scale": 0.15811388194561005,
"bias": 0.0,
"bias_after_scale": True,
},
{"axis": -1, "is_test": True},
{"trans_x": False, "trans_y": False},
{"axis": [0, 2, 1, 3]},
{"shape": [0, 0, 320]},
]
ops_config = [
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["input_data1"],
"Y": ["mul1_weight"],
},
"op_outputs": {"Out": ["mul1_output"]},
"op_attrs": dics[0],
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["mul1_output"],
},
"op_outputs": {
"Out": ["reshape21_output"],
"XShape": ["reshape21_output_xshape"],
},
"op_attrs": dics[1],
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape21_output"]},
"op_outputs": {
"Out": ["transpose21_output"],
"XShape": ["transpose21_output_xshape"],
},
"op_attrs": dics[2],
},
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["input_data2"],
"Y": ["mul2_weight"],
},
"op_outputs": {"Out": ["mul2_output"]},
"op_attrs": dics[3],
},
{
"op_type": "reshape2",
"op_inputs": {"X": ["mul2_output"]},
"op_outputs": {
"Out": ["reshape22_output"],
"XShape": ["reshape22_output_xshape"],
},
"op_attrs": dics[4],
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape22_output"]},
"op_outputs": {
"Out": ["transpose22_output"],
"XShape": ["transpose22_output_xshape"],
},
"op_attrs": dics[5],
},
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["input_data2"],
"Y": ["mul3_weight"],
},
"op_outputs": {"Out": ["mul3_output"]},
"op_attrs": dics[6],
},
{
"op_type": "reshape2",
"op_inputs": {"X": ["mul3_output"]},
"op_outputs": {
"Out": ["reshape23_output"],
"XShape": ["reshape23_output_xshape"],
},
"op_attrs": dics[7],
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape23_output"]},
"op_outputs": {
"Out": ["transpose23_output"],
"XShape": ["transpose23_output_xshape"],
},
"op_attrs": dics[8],
},
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["transpose21_output"],
"Y": ["transpose22_output"],
},
"op_outputs": {"Out": ["matmul1_output"]},
"op_attrs": dics[9],
},
{
"op_type": "scale",
"op_inputs": {
"X": ["matmul1_output"],
},
"op_outputs": {"Out": ["scale_output"]},
"op_attrs": dics[10],
},
{
"op_type": "softmax",
"op_inputs": {"X": ["scale_output"]},
"op_outputs": {"Out": ["softmax_output"]},
"op_attrs": dics[11],
},
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["softmax_output"],
"Y": ["transpose23_output"],
},
"op_outputs": {"Out": ["matmul2_output"]},
"op_attrs": dics[12],
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["matmul2_output"]},
"op_outputs": {
"Out": ["transpose24_output"],
"XShape": ["transpose24_output_xshape"],
},
"op_attrs": dics[13],
},
{
"op_type": "reshape2",
"op_inputs": {"X": ["transpose24_output"]},
"op_outputs": {
"Out": ["reshape24_output"],
"XShape": ["reshape24_output_xshape"],
},
"op_attrs": dics[14],
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"mul1_weight": TensorConfig(
data_gen=partial(generate_weight1)
),
"mul2_weight": TensorConfig(
data_gen=partial(generate_weight2)
),
"mul3_weight": TensorConfig(
data_gen=partial(generate_weight2)
),
},
inputs={
"input_data1": TensorConfig(
data_gen=partial(
generate_input1, batch, dim1
)
),
"input_data2": TensorConfig(
data_gen=partial(
generate_input2, batch, dim2
)
),
},
outputs=["reshape24_output"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
# The last dim of input1 and input2 should be static.
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 4096, 320],
"input_data2": [1, 77, 768],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [8, 4096, 320],
"input_data2": [8, 77, 768],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [2, 4096, 320],
"input_data2": [2, 77, 768],
}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
self.trt_param.workspace_size = 2013265920
yield self.create_inference_config(), (1, 4), (1e-5, 1e-5)
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 4), (1e-2, 1e-3)
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
self.trt_param.workspace_size = 2013265920
yield self.create_inference_config(), (1, 3), (1e-5, 1e-4)
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 3), (1e-2, 1e-3)
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if self.dynamic_shape.min_input_shape == {}:
return True
return False
self.add_skip_case(
teller1,
SkipReasons.TRT_NOT_IMPLEMENTED,
"TThe cross attention trt oss plugin do not support static shape yet",
)
def teller2(program_config, predictor_config):
if self.trt_param.precision == paddle_infer.PrecisionType.Float32:
return True
return False
self.add_skip_case(
teller2,
SkipReasons.TRT_NOT_IMPLEMENTED,
"The cross attention trt oss plugin do not support fp32 yet",
)
def teller3(program_config, predictor_config):
if self.trt_param.precision == paddle_infer.PrecisionType.Int8:
return True
return False
self.add_skip_case(
teller3,
SkipReasons.TRT_NOT_IMPLEMENTED,
"The cross attention trt oss plugin do not support int8 yet.",
)
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from typing import List
import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import SkipReasons, TrtLayerAutoScanTest
import paddle.inference as paddle_infer
class TrtConvertFlashMultiHeadMatmulTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
ver = paddle_infer.get_trt_compile_version()
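# ver is (major, minor, patch); require TensorRT >= 8.5.2, i.e. 8*1000 + 5*100 + 2*10 = 8520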
if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8520:
return False
return True
def sample_program_configs(self):
def generate_input1(batch, dim1):
return np.random.rand(batch, dim1, 320).astype(np.float32) / 10
def generate_weight1():
return np.random.rand(320, 320).astype(np.float32) / 10
for batch in [1, 2]:
self.batch = batch
for reshape_shape in [[0, 0, 8, 40]]:
for dim1 in [4096]:
dics = [
{"trans_x": False, "trans_y": False}, # 0,matmul_v2_q
{"shape": reshape_shape}, # 1,reshape_q
{
"axis": [0, 2, 1, 3],
"data_format": "AnyLayout",
}, # 2,trans_q
{"trans_x": False, "trans_y": False}, # 3,matmul_v2_k
{"shape": reshape_shape}, # 4,reshape_k
{
"axis": [0, 2, 1, 3],
"data_format": "AnyLayout",
}, # 5,trans_k
{"trans_x": False, "trans_y": False}, # 6,matmul_v2_q
{"shape": reshape_shape}, # 7,reshape_q
{
"axis": [0, 2, 1, 3],
"data_format": "AnyLayout",
}, # 8,trans_q
{ # 9,matmul_qk
"trans_x": False,
"trans_y": True,
},
{ # 10,scale
"scale": 0.15811388194561005,
"bias": 0.0,
"bias_after_scale": True,
},
{"axis": -1, "is_test": True}, # 11,softmax
{"trans_x": False, "trans_y": False}, # 12,matmul_qkv
{
"axis": [0, 2, 1, 3],
"data_format": "AnyLayout",
}, # 13,trans_qkv
{"shape": [0, 0, 320]}, # 14,reshape_qkv
]
ops_config = [
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["input_data1"],
"Y": ["mul1_weight"],
},
"op_outputs": {"Out": ["mul1_output"]},
"op_attrs": dics[0],
},
{
"op_type": "reshape2",
"op_inputs": {
"X": ["mul1_output"],
},
"op_outputs": {
"Out": ["reshape21_output"],
"XShape": ["reshape21_output_xshape"],
},
"op_attrs": dics[1],
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape21_output"]},
"op_outputs": {
"Out": ["transpose21_output"],
"XShape": ["transpose21_output_xshape"],
},
"op_attrs": dics[2],
},
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["input_data1"],
"Y": ["mul2_weight"],
},
"op_outputs": {"Out": ["mul2_output"]},
"op_attrs": dics[3],
},
{
"op_type": "reshape2",
"op_inputs": {"X": ["mul2_output"]},
"op_outputs": {
"Out": ["reshape22_output"],
"XShape": ["reshape22_output_xshape"],
},
"op_attrs": dics[4],
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape22_output"]},
"op_outputs": {
"Out": ["transpose22_output"],
"XShape": ["transpose22_output_xshape"],
},
"op_attrs": dics[5],
},
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["input_data1"],
"Y": ["mul3_weight"],
},
"op_outputs": {"Out": ["mul3_output"]},
"op_attrs": dics[6],
},
{
"op_type": "reshape2",
"op_inputs": {"X": ["mul3_output"]},
"op_outputs": {
"Out": ["reshape23_output"],
"XShape": ["reshape23_output_xshape"],
},
"op_attrs": dics[7],
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["reshape23_output"]},
"op_outputs": {
"Out": ["transpose23_output"],
"XShape": ["transpose23_output_xshape"],
},
"op_attrs": dics[8],
},
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["transpose21_output"],
"Y": ["transpose22_output"],
},
"op_outputs": {"Out": ["matmul1_output"]},
"op_attrs": dics[9],
},
{
"op_type": "scale",
"op_inputs": {
"X": ["matmul1_output"],
},
"op_outputs": {"Out": ["scale_output"]},
"op_attrs": dics[10],
},
{
"op_type": "softmax",
"op_inputs": {"X": ["scale_output"]},
"op_outputs": {"Out": ["softmax_output"]},
"op_attrs": dics[11],
},
{
"op_type": "matmul_v2",
"op_inputs": {
"X": ["softmax_output"],
"Y": ["transpose23_output"],
},
"op_outputs": {"Out": ["matmul2_output"]},
"op_attrs": dics[12],
},
{
"op_type": "transpose2",
"op_inputs": {"X": ["matmul2_output"]},
"op_outputs": {
"Out": ["transpose24_output"],
"XShape": ["transpose24_output_xshape"],
},
"op_attrs": dics[13],
},
{
"op_type": "reshape2",
"op_inputs": {"X": ["transpose24_output"]},
"op_outputs": {
"Out": ["reshape24_output"],
"XShape": ["reshape24_output_xshape"],
},
"op_attrs": dics[14],
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"mul1_weight": TensorConfig(
data_gen=partial(generate_weight1)
),
"mul2_weight": TensorConfig(
data_gen=partial(generate_weight1)
),
"mul3_weight": TensorConfig(
data_gen=partial(generate_weight1)
),
},
inputs={
"input_data1": TensorConfig(
data_gen=partial(generate_input1, batch, dim1)
)
},
outputs=["reshape24_output"],
)
yield program_config
def sample_predictor_configs(
self, program_config
) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
# The last dim of input1 should be static.
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 4096, 320],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [16, 4096, 320],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [2, 4096, 320],
}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
self.trt_param.workspace_size = 2013265920
yield self.create_inference_config(), (1, 2), (1e-5, 1e-5)
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 2), (1e-3, 1e-3)
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
self.trt_param.workspace_size = 2013265920
yield self.create_inference_config(), (1, 2), (1e-5, 1e-4)
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (1, 2), (1e-2, 1e-3)
def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if self.dynamic_shape.min_input_shape == {}:
return True
return False
self.add_skip_case(
teller1,
SkipReasons.TRT_NOT_IMPLEMENTED,
"TThe flash attention trt oss plugin do not support static shape yet",
)
def teller2(program_config, predictor_config):
if self.trt_param.precision == paddle_infer.PrecisionType.Float32:
return True
return False
self.add_skip_case(
teller2,
SkipReasons.TRT_NOT_IMPLEMENTED,
"The flash attention trt oss plugin do not support fp32 yet",
)
def teller3(program_config, predictor_config):
if self.trt_param.precision == paddle_infer.PrecisionType.Int8:
return True
return False
self.add_skip_case(
teller3,
SkipReasons.TRT_NOT_IMPLEMENTED,
"The flash attention trt oss plugin do not support int8 yet.",
)
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()