Unverified · Commit fcaa64b3 authored by baoachun, committed by GitHub

add multihead_matmul trt converter test case (#36023)

* add multihead_matmul trt converter test case

* move attribute check to op_teller
Parent: 8e19d1ba
@@ -62,7 +62,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
   // BOOST_GET_CONST(bool, scale->Op()->GetAttr("bias_after_scale"));
   // create multihead
-  OpDesc multihead_op_desc;
+  OpDesc multihead_op_desc(mul0->Op()->Block());
   // create tmp tensor
   VarDesc k_var_desc(*mul1_out->Var());
@@ -847,7 +847,7 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
   int head_number =
       BOOST_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape")).at(2);
-  OpDesc multihead_op_desc;
+  OpDesc multihead_op_desc(mul0->Op()->Block());
   multihead_op_desc.SetType("multihead_matmul");
   multihead_op_desc.SetInput("Input", {input0->Name()});
@@ -1287,7 +1287,7 @@ int MultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
   int head_number =
       BOOST_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape")).at(2);
-  OpDesc multihead_op_desc;
+  OpDesc multihead_op_desc(mul0->Op()->Block());
   multihead_op_desc.SetType("multihead_matmul");
   multihead_op_desc.SetInput("Input", {input0->Name()});
......
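The change in all three fuse passes is the same one-liner: the fused multihead_matmul `OpDesc` is now constructed with the block of the ops it replaces. The reason is visible in the op_teller hunk further down, which calls `desc.Block()->FindVar(...)` to read variable shapes; an `OpDesc` built without a block leaves `Block()` null and forces the teller to reject the op. A minimal sketch of that dependency, using hypothetical stand-in types (`Block`, `Op`, `Tell`) rather than Paddle's real `BlockDesc`/`OpDesc`:

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for Paddle's BlockDesc/OpDesc, used only to
// illustrate the dependency; these are not the real framework classes.
struct VarInfo {
  std::vector<int64_t> shape;
};

struct Block {
  std::map<std::string, VarInfo> vars;
  const VarInfo* FindVar(const std::string& name) const {
    auto it = vars.find(name);
    return it == vars.end() ? nullptr : &it->second;
  }
};

struct Op {
  const Block* block = nullptr;  // set only when the pass passes it in
  std::string input;
  const Block* GetBlock() const { return block; }
};

// Teller-style check: without a block it cannot inspect shapes and must
// reject the op, mirroring the "block desc is nullptr" branch below.
bool Tell(const Op& op) {
  const Block* block = op.GetBlock();
  if (block == nullptr) {
    std::cout << "block is null, cannot analyze shapes -> reject\n";
    return false;
  }
  const VarInfo* in = block->FindVar(op.input);
  return in != nullptr && !in->shape.empty();
}

int main() {
  Block block;
  block.vars["x"] = {{1, 128, 768}};
  Op without_block{nullptr, "x"};
  Op with_block{&block, "x"};
  std::cout << Tell(without_block) << "\n";  // 0: rejected, no block
  std::cout << Tell(with_block) << "\n";     // 1: shapes are visible
}
```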
@@ -23,7 +23,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
  public:
   void operator()(const framework::proto::OpDesc& op,
                   const framework::Scope& scope, bool test_mode) override {
-#if IS_TRT_VERSION_GE(6000)
     VLOG(3) << "convert a fluid multihead_mamul op to a corresponding tensorrt "
                "network structure";
     framework::OpDesc op_desc(op, nullptr);
@@ -46,10 +45,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
     float in_scale = 0.;
     if (enable_int8) {
-      PADDLE_ENFORCE_EQ(
-          op_desc.HasAttr("Input_scale"), true,
-          platform::errors::InvalidArgument(
-              "must have input scale in multihead layers in int8 mode"));
       in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
       auto weight_scale =
           BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
@@ -181,10 +176,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
           {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1},
           {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
           {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
-          { "var_seqlen",
-            &var_seqlen,
-            nvinfer1::PluginFieldType::kINT32,
-            1 }};
+          {"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1}};
       if (qkv2context_plugin_int8) {
         fields.push_back(
             {"dq_probs", &dp_probs, nvinfer1::PluginFieldType::kFLOAT32, 1});
@@ -296,11 +288,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
     }
     RreplenishLayerAndOutput(layer, "multihead_matmul", {output_name},
                              test_mode);
-#else
-    PADDLE_THROW(platform::errors::Fatal(
-        "You are running the TRT Dynamic Shape mode, need to confirm that "
-        "your TRT version is no less than 6.0"));
-#endif
   }
 };
......
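Both deletions in the converter follow the commit's second bullet, "move attribute check to op_teller": preconditions that used to abort conversion at runtime (the `PADDLE_ENFORCE_EQ` on `Input_scale`, the `PADDLE_THROW` behind the TRT version guard) become early rejections in the teller, so an unsupported op falls back to the regular executor instead of throwing mid-conversion. A hedged sketch of the two styles in plain C++, not Paddle's actual macros:

```cpp
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

using Attrs = std::map<std::string, float>;

// Old style: the converter throws if the attribute is missing, aborting
// the whole conversion (analogous to PADDLE_ENFORCE_EQ).
void ConvertThrowing(const Attrs& attrs) {
  if (attrs.count("Input_scale") == 0) {
    throw std::runtime_error(
        "must have input scale in multihead layers in int8 mode");
  }
  // ... build the TensorRT layer ...
}

// New style: the teller rejects the op up front, so the framework can
// run the op outside TensorRT instead of failing.
bool TellInt8(const Attrs& attrs, bool enable_int8) {
  if (enable_int8 && attrs.count("Input_scale") == 0) {
    std::cout << "Multihead layers must have input scale in int8 mode.\n";
    return false;
  }
  return true;
}

int main() {
  Attrs attrs;  // deliberately missing Input_scale
  if (TellInt8(attrs, /*enable_int8=*/true)) {
    ConvertThrowing(attrs);  // never reached for this op
  } else {
    std::cout << "op rejected early; running it outside TensorRT\n";
  }
}
```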
@@ -1085,6 +1085,42 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
       VLOG(3) << "the multihead_matmul does not support static shape yet";
       return false;
     }
+    if (desc.HasAttr("enable_int8") && !desc.HasAttr("Input_scale")) {
+      VLOG(3) << "Multihead layers must have input scale in int8 mode.";
+      return false;
+    }
+
+    auto* block = desc.Block();
+    if (block == nullptr) {
+      VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                 "Developers need to check whether block_desc is passed in "
+                 "the pass.";
+      return false;
+    }
+
+    auto* input_desc = block->FindVar(desc.Input("Input").front());
+    const auto input_shape = input_desc->GetShape();
+    const auto head_number =
+        BOOST_GET_CONST(int, desc.GetAttr("head_number"));
+
+    auto* biasqk_desc = block->FindVar(desc.Input("BiasQK").front());
+    const auto biasqk_shape = biasqk_desc->GetShape();
+    // The BiasQK's shape requires to be
+    // [batch, 1, 1, length] or [batch, head, length, length].
+    bool has_same_shape = head_number == biasqk_shape[1] &&
+                          input_shape[1] == biasqk_shape[2] &&
+                          input_shape[1] == biasqk_shape[3];
+    bool is_broadcastable = biasqk_shape[1] == 1 && biasqk_shape[2] == 1 &&
+                            input_shape[1] == biasqk_shape[3];
+    if (!(has_same_shape || is_broadcastable)) {
+      VLOG(3) << "The BiasQK's shape is invalid, expect [" << input_shape[0]
+              << ", 1, 1, " << input_shape[1] << "] or [" << input_shape[0]
+              << ", " << head_number << ", " << input_shape[1] << ", "
+              << input_shape[1] << "] but [" << biasqk_shape[0] << ", "
+              << biasqk_shape[1] << ", " << biasqk_shape[2] << ", "
+              << biasqk_shape[3] << "].";
+      return false;
+    }
   }
   if (op_type == "fc") {
......
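The core of the new teller logic is the BiasQK shape predicate: the attention mask must be either broadcastable, [batch, 1, 1, length], or per-head, [batch, head_number, length, length], where length is input_shape[1]. A standalone restatement of that predicate, with probe shapes (batch = 2, length = 128, 12 heads are assumed example values, not taken from the commit):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Restates the op_teller predicate: BiasQK must be either
// [batch, 1, 1, length] (broadcastable over heads and rows) or
// [batch, head_number, length, length] (same shape as the QK product).
bool BiasQKShapeOk(const std::vector<int64_t>& input_shape,
                   const std::vector<int64_t>& biasqk_shape,
                   int64_t head_number) {
  if (biasqk_shape.size() != 4) return false;
  const int64_t length = input_shape[1];
  bool has_same_shape = head_number == biasqk_shape[1] &&
                        length == biasqk_shape[2] &&
                        length == biasqk_shape[3];
  bool is_broadcastable = biasqk_shape[1] == 1 && biasqk_shape[2] == 1 &&
                          length == biasqk_shape[3];
  return has_same_shape || is_broadcastable;
}

int main() {
  // Assumed example: batch=2, sequence length=128, hidden size=768.
  std::vector<int64_t> input = {2, 128, 768};
  std::cout << BiasQKShapeOk(input, {2, 1, 1, 128}, 12) << "\n";     // 1
  std::cout << BiasQKShapeOk(input, {2, 12, 128, 128}, 12) << "\n";  // 1
  std::cout << BiasQKShapeOk(input, {2, 12, 1, 128}, 12) << "\n";    // 0
}
```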