未验证 提交 7b7e6051 编写于 作者: Z Zhaolong Xing 提交者: GitHub

[Fix BUGs]: fix multhead matmul pass's instable bug (#25123)

* fix multhead matmul's instable
test=develop

* fix multihead matmul bug
test=develop

* fix converage problem
test=develop
上级 5c8e7995
...@@ -140,7 +140,7 @@ void GraphPatternDetector::ValidateByNodeRole( ...@@ -140,7 +140,7 @@ void GraphPatternDetector::ValidateByNodeRole(
subgraphs->begin(), subgraphs->end(), subgraphs->begin(), subgraphs->end(),
[](const GraphPatternDetector::subgraph_t &subgraph) -> bool { [](const GraphPatternDetector::subgraph_t &subgraph) -> bool {
// Collect the inputs and outputs. // Collect the inputs and outputs.
std::unordered_set<Node *> ios; std::set<Node *> ios;
for (auto &item : subgraph) { for (auto &item : subgraph) {
if (!item.first->IsIntermediate()) { if (!item.first->IsIntermediate()) {
ios.insert(item.second); ios.insert(item.second);
...@@ -166,7 +166,7 @@ void GraphPatternDetector::ValidateByNodeRole( ...@@ -166,7 +166,7 @@ void GraphPatternDetector::ValidateByNodeRole(
} }
struct HitGroup { struct HitGroup {
std::unordered_map<PDNode *, Node *> roles; std::map<PDNode *, Node *> roles;
bool Match(Node *node, PDNode *pat) { bool Match(Node *node, PDNode *pat) {
if (nodes_.count(node)) { if (nodes_.count(node)) {
...@@ -184,7 +184,7 @@ struct HitGroup { ...@@ -184,7 +184,7 @@ struct HitGroup {
} }
private: private:
std::unordered_set<Node *> nodes_; std::set<Node *> nodes_;
}; };
// Tell whether Node a links to b. // Tell whether Node a links to b.
...@@ -283,7 +283,7 @@ void GraphPatternDetector::UniquePatterns( ...@@ -283,7 +283,7 @@ void GraphPatternDetector::UniquePatterns(
if (subgraphs->empty()) return; if (subgraphs->empty()) return;
std::vector<GraphPatternDetector::subgraph_t> result; std::vector<GraphPatternDetector::subgraph_t> result;
std::unordered_set<size_t> set; std::set<size_t> set;
std::hash<std::string> hasher; std::hash<std::string> hasher;
for (auto &g : *subgraphs) { for (auto &g : *subgraphs) {
// Sort the items in the sub-graph, and transform to a string key. // Sort the items in the sub-graph, and transform to a string key.
...@@ -305,7 +305,7 @@ void GraphPatternDetector::UniquePatterns( ...@@ -305,7 +305,7 @@ void GraphPatternDetector::UniquePatterns(
void GraphPatternDetector::RemoveOverlappedMatch( void GraphPatternDetector::RemoveOverlappedMatch(
std::vector<subgraph_t> *subgraphs) { std::vector<subgraph_t> *subgraphs) {
std::vector<subgraph_t> result; std::vector<subgraph_t> result;
std::unordered_set<Node *> node_set; std::set<Node *> node_set;
for (const auto &subgraph : *subgraphs) { for (const auto &subgraph : *subgraphs) {
bool valid = true; bool valid = true;
......
...@@ -231,7 +231,7 @@ class PDPattern { ...@@ -231,7 +231,7 @@ class PDPattern {
std::vector<std::unique_ptr<PDNode>> nodes_; std::vector<std::unique_ptr<PDNode>> nodes_;
std::vector<edge_t> edges_; std::vector<edge_t> edges_;
std::unordered_map<std::string, PDNode*> node_map_; std::map<std::string, PDNode*> node_map_;
static size_t id_; static size_t id_;
}; };
...@@ -263,7 +263,7 @@ class PDPattern { ...@@ -263,7 +263,7 @@ class PDPattern {
*/ */
class GraphPatternDetector { class GraphPatternDetector {
public: public:
using subgraph_t = std::unordered_map<PDNode*, Node*>; using subgraph_t = std::map<PDNode*, Node*>;
// Operate on the detected pattern. // Operate on the detected pattern.
using handle_t = using handle_t =
......
...@@ -45,13 +45,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { ...@@ -45,13 +45,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
// Create pattern. // Create pattern.
MultiHeadMatmulPattern multihead_pattern(pattern, name_scope); MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
PDNode* x = multihead_pattern();
pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
multihead_pattern(x);
// Create New OpDesc // Create New OpDesc
auto fuse_creater = [&]( auto fuse_creater = [&](
Node* x, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out, Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
Node* mul1_out, Node* mul2_out, Node* eltadd0_b, Node* eltadd1_b, Node* mul1_out, Node* mul2_out, Node* eltadd0_b, Node* eltadd1_b,
Node* eltadd2_b, Node* eltadd_qk_b, Node* reshape2, Node* eltadd2_b, Node* eltadd_qk_b, Node* reshape2,
Node* reshape2_qkv_out, Node* scale, Node* scale_out) { Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
...@@ -115,7 +112,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { ...@@ -115,7 +112,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern); // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
...@@ -185,7 +182,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { ...@@ -185,7 +182,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern); multihead_pattern);
fuse_creater(layer_norm, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0,
reshape2_qkv_out, scale, scale_out); reshape2_qkv_out, scale, scale_out);
...@@ -232,12 +229,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { ...@@ -232,12 +229,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
return fusion_count; return fusion_count;
} }
PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) { PDNode* MultiHeadMatmulPattern::operator()() {
// Create shared nodes. auto* input0 = pattern->NewNode(input0_repr());
auto* layer_norm = pattern->NewNode(layer_norm_repr()); input0->assert_is_op_input("mul");
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr());
layer_norm_out_var->assert_is_op_input("mul");
// First path with scale // First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul"); auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul");
...@@ -390,17 +384,15 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) { ...@@ -390,17 +384,15 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
transpose2_2_out_var->AsIntermediate()->assert_is_op_input( transpose2_2_out_var->AsIntermediate()->assert_is_op_input(
"matmul"); // link to matmul qkv "matmul"); // link to matmul qkv
// Link all nodes together
layer_norm->LinksFrom({x}).LinksTo({layer_norm_out_var});
// Q path // Q path
mul0->LinksFrom({layer_norm_out_var, mul0_w_var}).LinksTo({mul0_out_var}); mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var}); eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var}); reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var}); transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
scale->LinksFrom({transpose2_0_out_var}).LinksTo({scale_out_var}); scale->LinksFrom({transpose2_0_out_var}).LinksTo({scale_out_var});
// K path // K path
mul1->LinksFrom({layer_norm_out_var, mul1_w_var}).LinksTo({mul1_out_var}); mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var}); eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var}); reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var}); transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
...@@ -411,7 +403,7 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) { ...@@ -411,7 +403,7 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
.LinksTo({eltadd_qk_out_var}); .LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var}); softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
// V path // V path
mul2->LinksFrom({layer_norm_out_var, mul2_w_var}).LinksTo({mul2_out_var}); mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var}); eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var}); reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var}); transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
...@@ -434,13 +426,10 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -434,13 +426,10 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
// Create pattern. // Create pattern.
MultiHeadMatmulPattern multihead_pattern(pattern, name_scope); MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
PDNode* x = multihead_pattern();
pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
multihead_pattern(x);
// Create New OpDesc // Create New OpDesc
auto fuse_creater = [&]( auto fuse_creater = [&](
Node* layer_norm_out, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out, Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w, Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w,
Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b, Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b,
Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) { Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
...@@ -471,29 +460,20 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -471,29 +460,20 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
framework::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]}); framework::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
auto combined_bias_dims = framework::make_ddim({3, bq_tensor->dims()[0]}); auto combined_bias_dims = framework::make_ddim({3, bq_tensor->dims()[0]});
// create a new var in scope // reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
VarDesc combined_w_desc( auto* combined_w_desc = mul0_w->Var();
patterns::PDNodeName(name_scope, "multi_head_combined_weight")); combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
combined_w_desc.SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]}); combined_w_desc->SetPersistable(true);
combined_w_desc.SetDataType(wq_tensor->type());
combined_w_desc.SetLoDLevel(mul0_w->Var()->GetLoDLevel()); auto* combined_bias_desc = eltadd0_b->Var();
combined_w_desc.SetPersistable(true); combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
combined_bias_desc->SetPersistable(true);
// create a new var in scope
VarDesc combined_bias_desc( framework::LoDTensor tmp_combined_w_tensor;
patterns::PDNodeName(name_scope, "multi_head_combined_bias")); tmp_combined_w_tensor.Resize(combined_w_dims);
combined_bias_desc.SetShape({3, bq_tensor->dims()[0]}); auto* tmp_combined_w_data =
combined_bias_desc.SetDataType(bq_tensor->type()); tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());
combined_bias_desc.SetLoDLevel(eltadd0_b->Var()->GetLoDLevel());
combined_bias_desc.SetPersistable(true);
auto* combined_w_node = graph->CreateVarNode(&combined_w_desc);
auto* combined_w_tensor =
scope->Var(combined_w_node->Name())->GetMutable<LoDTensor>();
combined_w_tensor->Resize(combined_w_dims);
auto* combined_w_data =
combined_w_tensor->mutable_data<float>(platform::CPUPlace());
std::vector<float*> w_vec = {wq_data, wk_data, wv_data}; std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2]; int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
// Combine the three fc weights together. // Combine the three fc weights together.
...@@ -502,25 +482,38 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -502,25 +482,38 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
for (int k = 0; k < dims_w; k++) { for (int k = 0; k < dims_w; k++) {
int out_index = i * (3 * dims_w) + j * dims_w + k; int out_index = i * (3 * dims_w) + j * dims_w + k;
int in_index = i * dims_w + k; int in_index = i * dims_w + k;
combined_w_data[out_index] = w_vec[j][in_index]; tmp_combined_w_data[out_index] = w_vec[j][in_index];
} }
} }
} }
scope->EraseVars({mul0_w->Name(), mul1_w->Name(), mul2_w->Name()});
auto* combined_bias_node = graph->CreateVarNode(&combined_bias_desc); wq_tensor->Resize(combined_w_dims);
auto* combined_bias_tensor = auto* new_combined_w_data =
scope->Var(combined_bias_node->Name())->GetMutable<LoDTensor>(); wq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_w_data, tmp_combined_w_data,
combined_bias_tensor->Resize(combined_bias_dims); sizeof(float) * wq_tensor->numel());
auto* combined_bias_data =
combined_bias_tensor->mutable_data<float>(platform::CPUPlace()); scope->EraseVars({mul1_w->Name(), mul2_w->Name()});
framework::LoDTensor tmp_combined_bias_tensor;
tmp_combined_bias_tensor.Resize(combined_bias_dims);
auto* tmp_combined_bias_data =
tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());
size_t bias_size = bq_tensor->numel(); size_t bias_size = bq_tensor->numel();
memcpy(combined_bias_data, bq_data, sizeof(float) * bias_size); memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
memcpy(combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size); memcpy(tmp_combined_bias_data + bias_size, bk_data,
memcpy(combined_bias_data + 2 * bias_size, bv_data, sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
sizeof(float) * bias_size); sizeof(float) * bias_size);
scope->EraseVars({eltadd0_b->Name(), eltadd1_b->Name(), eltadd2_b->Name()}); bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_bias_data, tmp_combined_bias_data,
sizeof(float) * bq_tensor->numel());
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
auto reshape_desc = reshape2->Op(); auto reshape_desc = reshape2->Op();
int head_number = int head_number =
...@@ -529,9 +522,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -529,9 +522,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
OpDesc multihead_op_desc; OpDesc multihead_op_desc;
multihead_op_desc.SetType("multihead_matmul"); multihead_op_desc.SetType("multihead_matmul");
multihead_op_desc.SetInput("Input", {layer_norm_out->Name()}); multihead_op_desc.SetInput("Input", {input0->Name()});
multihead_op_desc.SetInput("W", {combined_w_node->Name()}); multihead_op_desc.SetInput("W", {mul0_w->Name()});
multihead_op_desc.SetInput("Bias", {combined_bias_node->Name()}); multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()});
multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()}); multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});
multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()}); multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
...@@ -540,9 +533,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -540,9 +533,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
auto* multihead = graph->CreateOpNode(&multihead_op_desc); auto* multihead = graph->CreateOpNode(&multihead_op_desc);
IR_NODE_LINK_TO(layer_norm_out, multihead); IR_NODE_LINK_TO(input0, multihead);
IR_NODE_LINK_TO(combined_w_node, multihead); IR_NODE_LINK_TO(mul0_w, multihead);
IR_NODE_LINK_TO(combined_bias_node, multihead); IR_NODE_LINK_TO(eltadd0_b, multihead);
IR_NODE_LINK_TO(eltadd_qk_b, multihead); IR_NODE_LINK_TO(eltadd_qk_b, multihead);
IR_NODE_LINK_TO(multihead, reshape2_qkv_out); IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
...@@ -552,9 +545,7 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -552,9 +545,7 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern); // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern); GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
...@@ -624,14 +615,13 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -624,14 +615,13 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern); multihead_pattern);
fuse_creater(layer_norm_out, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out); reshape2_0, reshape2_qkv_out, scale, scale_out);
std::unordered_set<const Node*> marked_nodes({eltadd0, std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1, eltadd1,
eltadd2, eltadd2,
eltadd0_b,
eltadd1_b, eltadd1_b,
eltadd2_b, eltadd2_b,
eltadd0_out, eltadd0_out,
...@@ -665,7 +655,6 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope, ...@@ -665,7 +655,6 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
mul0_out, mul0_out,
mul1_out, mul1_out,
mul2_out, mul2_out,
mul0_w,
mul1_w, mul1_w,
mul2_w, mul2_w,
reshape2_qkv, reshape2_qkv,
......
...@@ -29,11 +29,10 @@ struct MultiHeadMatmulPattern : public PatternBase { ...@@ -29,11 +29,10 @@ struct MultiHeadMatmulPattern : public PatternBase {
MultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope) MultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "multihead_matmul") {} : PatternBase(pattern, name_scope, "multihead_matmul") {}
PDNode* operator()(PDNode* x); PDNode* operator()();
// declare operator node's name // declare operator node's name
PATTERN_DECL_NODE(layer_norm); PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(mul0); PATTERN_DECL_NODE(mul0);
PATTERN_DECL_NODE(mul1); PATTERN_DECL_NODE(mul1);
PATTERN_DECL_NODE(mul2); PATTERN_DECL_NODE(mul2);
......
...@@ -167,7 +167,7 @@ nvinfer1::DimsExprs QkvToContextPluginDynamic::getOutputDimensions( ...@@ -167,7 +167,7 @@ nvinfer1::DimsExprs QkvToContextPluginDynamic::getOutputDimensions(
ret.nbDims = 5; ret.nbDims = 5;
ret.d[0] = inputs[0].d[0]; ret.d[0] = inputs[0].d[0];
ret.d[1] = inputs[0].d[1]; ret.d[1] = inputs[0].d[1];
ret.d[2] = expr_builder.constant(hidden_); ret.d[2] = expr_builder.constant(head_size_ * head_number_);
ret.d[3] = expr_builder.constant(1); ret.d[3] = expr_builder.constant(1);
ret.d[4] = expr_builder.constant(1); ret.d[4] = expr_builder.constant(1);
return ret; return ret;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册