未验证 提交 7b7e6051 编写于 作者: Z Zhaolong Xing 提交者: GitHub

[Fix BUGs]: fix multhead matmul pass's instable bug (#25123)

* fix multhead matmul's instable
test=develop

* fix multihead matmul bug
test=develop

* fix converage problem
test=develop
上级 5c8e7995
......@@ -140,7 +140,7 @@ void GraphPatternDetector::ValidateByNodeRole(
subgraphs->begin(), subgraphs->end(),
[](const GraphPatternDetector::subgraph_t &subgraph) -> bool {
// Collect the inputs and outputs.
std::unordered_set<Node *> ios;
std::set<Node *> ios;
for (auto &item : subgraph) {
if (!item.first->IsIntermediate()) {
ios.insert(item.second);
......@@ -166,7 +166,7 @@ void GraphPatternDetector::ValidateByNodeRole(
}
struct HitGroup {
std::unordered_map<PDNode *, Node *> roles;
std::map<PDNode *, Node *> roles;
bool Match(Node *node, PDNode *pat) {
if (nodes_.count(node)) {
......@@ -184,7 +184,7 @@ struct HitGroup {
}
private:
std::unordered_set<Node *> nodes_;
std::set<Node *> nodes_;
};
// Tell whether Node a links to b.
......@@ -283,7 +283,7 @@ void GraphPatternDetector::UniquePatterns(
if (subgraphs->empty()) return;
std::vector<GraphPatternDetector::subgraph_t> result;
std::unordered_set<size_t> set;
std::set<size_t> set;
std::hash<std::string> hasher;
for (auto &g : *subgraphs) {
// Sort the items in the sub-graph, and transform to a string key.
......@@ -305,7 +305,7 @@ void GraphPatternDetector::UniquePatterns(
void GraphPatternDetector::RemoveOverlappedMatch(
std::vector<subgraph_t> *subgraphs) {
std::vector<subgraph_t> result;
std::unordered_set<Node *> node_set;
std::set<Node *> node_set;
for (const auto &subgraph : *subgraphs) {
bool valid = true;
......
......@@ -231,7 +231,7 @@ class PDPattern {
std::vector<std::unique_ptr<PDNode>> nodes_;
std::vector<edge_t> edges_;
std::unordered_map<std::string, PDNode*> node_map_;
std::map<std::string, PDNode*> node_map_;
static size_t id_;
};
......@@ -263,7 +263,7 @@ class PDPattern {
*/
class GraphPatternDetector {
public:
using subgraph_t = std::unordered_map<PDNode*, Node*>;
using subgraph_t = std::map<PDNode*, Node*>;
// Operate on the detected pattern.
using handle_t =
......
......@@ -45,13 +45,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
// Create pattern.
MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
PDNode* x =
pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
multihead_pattern(x);
multihead_pattern();
// Create New OpDesc
auto fuse_creater = [&](
Node* x, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
Node* mul1_out, Node* mul2_out, Node* eltadd0_b, Node* eltadd1_b,
Node* eltadd2_b, Node* eltadd_qk_b, Node* reshape2,
Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
......@@ -115,7 +112,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
......@@ -185,7 +182,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern);
fuse_creater(layer_norm, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b, reshape2_0,
reshape2_qkv_out, scale, scale_out);
......@@ -232,12 +229,9 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
return fusion_count;
}
PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
// Create shared nodes.
auto* layer_norm = pattern->NewNode(layer_norm_repr());
auto* layer_norm_out_var = pattern->NewNode(layer_norm_out_repr());
layer_norm_out_var->assert_is_op_input("mul");
PDNode* MultiHeadMatmulPattern::operator()() {
auto* input0 = pattern->NewNode(input0_repr());
input0->assert_is_op_input("mul");
// First path with scale
auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("mul");
......@@ -390,17 +384,15 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
transpose2_2_out_var->AsIntermediate()->assert_is_op_input(
"matmul"); // link to matmul qkv
// Link all nodes together
layer_norm->LinksFrom({x}).LinksTo({layer_norm_out_var});
// Q path
mul0->LinksFrom({layer_norm_out_var, mul0_w_var}).LinksTo({mul0_out_var});
mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
scale->LinksFrom({transpose2_0_out_var}).LinksTo({scale_out_var});
// K path
mul1->LinksFrom({layer_norm_out_var, mul1_w_var}).LinksTo({mul1_out_var});
mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
......@@ -411,7 +403,7 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
.LinksTo({eltadd_qk_out_var});
softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
// V path
mul2->LinksFrom({layer_norm_out_var, mul2_w_var}).LinksTo({mul2_out_var});
mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
......@@ -434,13 +426,10 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
// Create pattern.
MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
PDNode* x =
pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
multihead_pattern(x);
multihead_pattern();
// Create New OpDesc
auto fuse_creater = [&](
Node* layer_norm_out, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
Node* input0, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w,
Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b,
Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
......@@ -471,29 +460,20 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
framework::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
auto combined_bias_dims = framework::make_ddim({3, bq_tensor->dims()[0]});
// create a new var in scope
VarDesc combined_w_desc(
patterns::PDNodeName(name_scope, "multi_head_combined_weight"));
combined_w_desc.SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
combined_w_desc.SetDataType(wq_tensor->type());
combined_w_desc.SetLoDLevel(mul0_w->Var()->GetLoDLevel());
combined_w_desc.SetPersistable(true);
// create a new var in scope
VarDesc combined_bias_desc(
patterns::PDNodeName(name_scope, "multi_head_combined_bias"));
combined_bias_desc.SetShape({3, bq_tensor->dims()[0]});
combined_bias_desc.SetDataType(bq_tensor->type());
combined_bias_desc.SetLoDLevel(eltadd0_b->Var()->GetLoDLevel());
combined_bias_desc.SetPersistable(true);
auto* combined_w_node = graph->CreateVarNode(&combined_w_desc);
auto* combined_w_tensor =
scope->Var(combined_w_node->Name())->GetMutable<LoDTensor>();
combined_w_tensor->Resize(combined_w_dims);
auto* combined_w_data =
combined_w_tensor->mutable_data<float>(platform::CPUPlace());
// reuse the mul0_w and eltadd_0_b nodes for the combined nodes.
auto* combined_w_desc = mul0_w->Var();
combined_w_desc->SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
combined_w_desc->SetPersistable(true);
auto* combined_bias_desc = eltadd0_b->Var();
combined_bias_desc->SetShape({3, bq_tensor->dims()[0]});
combined_bias_desc->SetPersistable(true);
framework::LoDTensor tmp_combined_w_tensor;
tmp_combined_w_tensor.Resize(combined_w_dims);
auto* tmp_combined_w_data =
tmp_combined_w_tensor.mutable_data<float>(platform::CPUPlace());
std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
// Combine the three fc weights together.
......@@ -502,25 +482,38 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
for (int k = 0; k < dims_w; k++) {
int out_index = i * (3 * dims_w) + j * dims_w + k;
int in_index = i * dims_w + k;
combined_w_data[out_index] = w_vec[j][in_index];
tmp_combined_w_data[out_index] = w_vec[j][in_index];
}
}
}
scope->EraseVars({mul0_w->Name(), mul1_w->Name(), mul2_w->Name()});
auto* combined_bias_node = graph->CreateVarNode(&combined_bias_desc);
auto* combined_bias_tensor =
scope->Var(combined_bias_node->Name())->GetMutable<LoDTensor>();
combined_bias_tensor->Resize(combined_bias_dims);
auto* combined_bias_data =
combined_bias_tensor->mutable_data<float>(platform::CPUPlace());
wq_tensor->Resize(combined_w_dims);
auto* new_combined_w_data =
wq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_w_data, tmp_combined_w_data,
sizeof(float) * wq_tensor->numel());
scope->EraseVars({mul1_w->Name(), mul2_w->Name()});
framework::LoDTensor tmp_combined_bias_tensor;
tmp_combined_bias_tensor.Resize(combined_bias_dims);
auto* tmp_combined_bias_data =
tmp_combined_bias_tensor.mutable_data<float>(platform::CPUPlace());
size_t bias_size = bq_tensor->numel();
memcpy(combined_bias_data, bq_data, sizeof(float) * bias_size);
memcpy(combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
memcpy(combined_bias_data + 2 * bias_size, bv_data,
memcpy(tmp_combined_bias_data, bq_data, sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + bias_size, bk_data,
sizeof(float) * bias_size);
memcpy(tmp_combined_bias_data + 2 * bias_size, bv_data,
sizeof(float) * bias_size);
scope->EraseVars({eltadd0_b->Name(), eltadd1_b->Name(), eltadd2_b->Name()});
bq_tensor->Resize(combined_bias_dims);
auto* new_combined_bias_data =
bq_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_combined_bias_data, tmp_combined_bias_data,
sizeof(float) * bq_tensor->numel());
scope->EraseVars({eltadd1_b->Name(), eltadd2_b->Name()});
auto reshape_desc = reshape2->Op();
int head_number =
......@@ -529,9 +522,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
OpDesc multihead_op_desc;
multihead_op_desc.SetType("multihead_matmul");
multihead_op_desc.SetInput("Input", {layer_norm_out->Name()});
multihead_op_desc.SetInput("W", {combined_w_node->Name()});
multihead_op_desc.SetInput("Bias", {combined_bias_node->Name()});
multihead_op_desc.SetInput("Input", {input0->Name()});
multihead_op_desc.SetInput("W", {mul0_w->Name()});
multihead_op_desc.SetInput("Bias", {eltadd0_b->Name()});
multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});
multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
......@@ -540,9 +533,9 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
auto* multihead = graph->CreateOpNode(&multihead_op_desc);
IR_NODE_LINK_TO(layer_norm_out, multihead);
IR_NODE_LINK_TO(combined_w_node, multihead);
IR_NODE_LINK_TO(combined_bias_node, multihead);
IR_NODE_LINK_TO(input0, multihead);
IR_NODE_LINK_TO(mul0_w, multihead);
IR_NODE_LINK_TO(eltadd0_b, multihead);
IR_NODE_LINK_TO(eltadd_qk_b, multihead);
IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
......@@ -552,9 +545,7 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
......@@ -624,14 +615,13 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern);
fuse_creater(layer_norm_out, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b,
eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out);
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
reshape2_0, reshape2_qkv_out, scale, scale_out);
std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1,
eltadd2,
eltadd0_b,
eltadd1_b,
eltadd2_b,
eltadd0_out,
......@@ -665,7 +655,6 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
mul0_out,
mul1_out,
mul2_out,
mul0_w,
mul1_w,
mul2_w,
reshape2_qkv,
......
......@@ -29,11 +29,10 @@ struct MultiHeadMatmulPattern : public PatternBase {
MultiHeadMatmulPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "multihead_matmul") {}
PDNode* operator()(PDNode* x);
PDNode* operator()();
// declare operator node's name
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(input0);
PATTERN_DECL_NODE(mul0);
PATTERN_DECL_NODE(mul1);
PATTERN_DECL_NODE(mul2);
......
......@@ -167,7 +167,7 @@ nvinfer1::DimsExprs QkvToContextPluginDynamic::getOutputDimensions(
ret.nbDims = 5;
ret.d[0] = inputs[0].d[0];
ret.d[1] = inputs[0].d[1];
ret.d[2] = expr_builder.constant(hidden_);
ret.d[2] = expr_builder.constant(head_size_ * head_number_);
ret.d[3] = expr_builder.constant(1);
ret.d[4] = expr_builder.constant(1);
return ret;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册