Unverified commit e49c17d2, authored by Wangzheee, committed by GitHub

[Paddle Inference] Enhance the shape check of trt_embedding_eltwise_layernorm_fuse_pass,… (#54861)

* Enhance the shape check of trt_embedding_eltwise_layernorm_fuse_pass, embedding_eltwise_layernorm_fuse_pass
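In essence, both passes now refuse to fuse the lookup_table + elementwise_add + layer_norm pattern unless every ids tensor has exactly the same rank and dimensions as the first one; otherwise they log a VLOG(3) message and leave the graph untouched. A minimal standalone sketch of that condition (the helper name is illustrative, not part of the patch):

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative helper: true only if every ids tensor shares the rank and the
// dimension values of the first one, mirroring the condition the passes check
// before emitting fused_embedding_eltwise_layernorm.
bool IdsShapesAllEqual(const std::vector<std::vector<int64_t>>& ids_shapes) {
  if (ids_shapes.empty()) return true;
  const auto& ref = ids_shapes.front();
  for (const auto& shape : ids_shapes) {
    if (shape.size() != ref.size()) return false;  // rank mismatch
    for (std::size_t j = 0; j < shape.size(); ++j) {
      if (shape[j] != ref[j]) return false;  // dimension mismatch
    }
  }
  return true;
}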
Parent f8d02146
@@ -307,11 +307,44 @@ int EmbeddingEltwiseLayerNormFusePass::BuildFusion(
    std::vector<std::string> ids;
    std::vector<std::string> embs;
+   auto ids0_shape = start_pattern_in_nodes[i][0].first->Var()->GetShape();
+   bool flag = true;
    for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
+     auto ids_shape = start_pattern_in_nodes[i][iter].first->Var()->GetShape();
+     if (ids_shape.size() != ids0_shape.size()) {
+       VLOG(3) << "Shape check failed, ids'rank are not all equal, stop "
+                  "embedding_eltwise_layernorm_fuse_pass.";
+       flag = false;
+     } else {
+       for (size_t j = 0; j < ids_shape.size(); ++j) {
+         if (ids_shape[j] != ids0_shape[j]) {
+           VLOG(3)
+               << "Shape check failed, ids.shape[i] are not all equal, stop "
+                  "embedding_eltwise_layernorm_fuse_pass.";
+           flag = false;
+         }
+       }
+     }
      ids.push_back(start_pattern_in_nodes[i][iter].first->Name());
      embs.push_back(start_pattern_in_nodes[i][iter].second->Name());
    }
    for (size_t iter = 0; iter < js.size(); ++iter) {
+     auto ids_shape = inner_pattern_ins[js[iter]].first->Var()->GetShape();
+     if (ids_shape.size() != ids0_shape.size()) {
+       VLOG(3) << "Shape check failed, ids'rank are not all equal, stop "
+                  "embedding_eltwise_layernorm_fuse_pass.";
+       flag = false;
+     } else {
+       for (size_t j = 0; j < ids_shape.size(); ++j) {
+         if (ids_shape[j] != ids0_shape[j]) {
+           VLOG(3)
+               << "Shape check failed, ids.shape[i] are not all equal, stop "
+                  "embedding_eltwise_layernorm_fuse_pass.";
+           flag = false;
+         }
+       }
+     }
      ids.push_back(inner_pattern_ins[js[iter]].first->Name());
      embs.push_back(inner_pattern_ins[js[iter]].second->Name());
    }
@@ -322,66 +355,70 @@ int EmbeddingEltwiseLayerNormFusePass::BuildFusion(
                 "inputs with lookup_table_v2";
      return fusion_count;
    }
+   if (flag) {
      OpDesc new_op_desc;
      new_op_desc.SetType("fused_embedding_eltwise_layernorm");
      new_op_desc.SetInput("Ids", ids);
      new_op_desc.SetInput("Embs", embs);
      new_op_desc.SetInput("WordId", {ids[0]});
      new_op_desc.SetInput("PosId", {ids[1]});
      if (ids.size() > 2) {
        new_op_desc.SetInput("SentId", {ids[2]});
      }
      new_op_desc.SetInput("WordEmbedding", {embs[0]});
      new_op_desc.SetInput("PosEmbedding", {embs[1]});
      if (embs.size() > 2) {
        new_op_desc.SetInput("SentEmbedding", {embs[2]});
      }
      new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
      new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
      new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
      new_op_desc.SetAttr("epsilon",
                          end_patter_layernorms[k]->Op()->GetAttr("epsilon"));
      if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) {
        new_op_desc.SetAttr("enable_int8", true);
        new_op_desc.SetAttr(
            "out_threshold",
            end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
      }
      auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
      for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
        IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
                        embedding_eltwise_layernorm);
        IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
                        embedding_eltwise_layernorm);
      }
      for (size_t iter = 0; iter < js.size(); ++iter) {
        IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
                        embedding_eltwise_layernorm);
        IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
                        embedding_eltwise_layernorm);
      }
      IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm);
      IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm);
      IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]);
      // Remove unneeded nodes.
      std::unordered_set<const Node*> marked_nodes;
      marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
                          start_pattern_remove_nodes[i].end());
      marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
                          end_pattern_remove_nodes[k].end());
      for (size_t iter = 0; iter < js.size(); ++iter) {
        marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
                            inner_pattern_remove_nodes[js[iter]].end());
      }
      GraphSafeRemoveNodes(graph, marked_nodes);
      ++fusion_count;
+   } else {
+     VLOG(3) << "Shape check failed, stop "
+                "embedding_eltwise_layernorm_fuse_pass.";
+   }
  }
  return fusion_count;
...
@@ -311,68 +311,105 @@ int TrtEmbeddingEltwiseLayerNormFusePass::BuildFusion(
    std::vector<std::string> ids;
    std::vector<std::string> embs;
+   auto ids0_shape = start_pattern_in_nodes[i][0].first->Var()->GetShape();
+   bool flag = true;
    for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
+     auto ids_shape = start_pattern_in_nodes[i][iter].first->Var()->GetShape();
+     if (ids_shape.size() != ids0_shape.size()) {
+       VLOG(3) << "Shape check failed, ids'rank are not all equal, stop "
+                  "trt_embedding_eltwise_layernorm_fuse_pass.";
+       flag = false;
+     } else {
+       for (size_t j = 0; j < ids_shape.size(); ++j) {
+         if (ids_shape[j] != ids0_shape[j]) {
+           VLOG(3)
+               << "Shape check failed, ids.shape[i] are not all equal, stop "
+                  "trt_embedding_eltwise_layernorm_fuse_pass.";
+           flag = false;
+         }
+       }
+     }
      ids.push_back(start_pattern_in_nodes[i][iter].first->Name());
      embs.push_back(start_pattern_in_nodes[i][iter].second->Name());
    }
    for (size_t iter = 0; iter < js.size(); ++iter) {
+     auto ids_shape = inner_pattern_ins[js[iter]].first->Var()->GetShape();
+     if (ids_shape.size() != ids0_shape.size()) {
+       VLOG(3) << "Shape check failed, ids'rank are not all equal, stop "
+                  "trt_embedding_eltwise_layernorm_fuse_pass.";
+       flag = false;
+     } else {
+       for (size_t j = 0; j < ids_shape.size(); ++j) {
+         if (ids_shape[j] != ids0_shape[j]) {
+           VLOG(3)
+               << "Shape check failed, ids.shape[i] are not all equal, stop "
+                  "trt_embedding_eltwise_layernorm_fuse_pass.";
+           flag = false;
+         }
+       }
+     }
      ids.push_back(inner_pattern_ins[js[iter]].first->Name());
      embs.push_back(inner_pattern_ins[js[iter]].second->Name());
    }
+   if (flag) {
      OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block());
      new_op_desc.SetType("fused_embedding_eltwise_layernorm");
      new_op_desc.SetInput("Ids", ids);
      new_op_desc.SetInput("Embs", embs);
      if (use_varseqlen && pos_id != "" && mask_id != "") {
        new_op_desc.SetInput("PosId", {pos_id});
        new_op_desc.SetInput("MaskId", {mask_id});
      }
      new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
      new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
      new_op_desc.SetOutput("Out", {end_pattern_out[k]->Name()});
      new_op_desc.SetAttr("epsilon",
                          end_patter_layernorms[k]->Op()->GetAttr("epsilon"));
      if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) {
        new_op_desc.SetAttr("enable_int8", true);
        new_op_desc.SetAttr(
            "out_threshold",
            end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
      }
      auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
      for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
        IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
                        embedding_eltwise_layernorm);
        IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
                        embedding_eltwise_layernorm);
      }
      for (size_t iter = 0; iter < js.size(); ++iter) {
        IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
                        embedding_eltwise_layernorm);
        IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
                        embedding_eltwise_layernorm);
      }
      IR_NODE_LINK_TO(end_pattern_biases[k], embedding_eltwise_layernorm);
      IR_NODE_LINK_TO(end_pattern_scales[k], embedding_eltwise_layernorm);
      IR_NODE_LINK_TO(embedding_eltwise_layernorm, end_pattern_out[k]);
      // Remove unneeded nodes.
      std::unordered_set<const Node*> marked_nodes;
      marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
                          start_pattern_remove_nodes[i].end());
      marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
                          end_pattern_remove_nodes[k].end());
      for (size_t iter = 0; iter < js.size(); ++iter) {
        marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
                            inner_pattern_remove_nodes[js[iter]].end());
      }
      GraphSafeRemoveNodes(graph, marked_nodes);
      ++fusion_count;
+   } else {
+     VLOG(3) << "Shape check failed, stop "
+                "trt_embedding_eltwise_layernorm_fuse_pass.";
+   }
  }
  return fusion_count;
}
...
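If the stricter check has to be worked around while debugging a model (for example, to compare fused and unfused outputs), the fuse pass can simply be dropped from the inference pass list. A minimal sketch, assuming the standard paddle_infer C++ API; the model file names are placeholders:

#include "paddle_inference_api.h"

int main() {
  // Placeholder model files; substitute a real inference model.
  paddle_infer::Config config("model.pdmodel", "model.pdiparams");
  config.EnableUseGpu(100, 0);  // 100 MB initial GPU memory, device 0
  // Drop the fuse pass so lookup_table + elementwise_add + layer_norm stay
  // unfused in the optimized graph.
  config.pass_builder()->DeletePass("embedding_eltwise_layernorm_fuse_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor ? 0 : 1;
}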
@@ -210,8 +210,8 @@ if(WITH_GPU AND TENSORRT_FOUND)
  set_tests_properties(test_merge_layernorm_fuse_pass PROPERTIES TIMEOUT 180)
  set_tests_properties(test_skip_merge_layernorm_fuse_pass PROPERTIES TIMEOUT
                       180)
- set_tests_properties(test_emb_eltwise_layernorm_fuse_pass PROPERTIES TIMEOUT
-                      120)
+ set_tests_properties(test_trt_emb_eltwise_layernorm_fuse_pass
+                      PROPERTIES TIMEOUT 180)
  set_tests_properties(test_fc_fuse_pass PROPERTIES TIMEOUT 240)
  set_tests_properties(test_reverse_roll_fuse_pass PROPERTIES TIMEOUT 120)
...
@@ -17,11 +17,9 @@ from functools import partial
 import hypothesis.strategies as st
 import numpy as np
-from auto_scan_test import IgnoreReasons, PassAutoScanTest
+from auto_scan_test import PassAutoScanTest
 from program_config import OpConfig, ProgramConfig, TensorConfig
-
-import paddle.inference as paddle_infer


 class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
     r'''

@@ -43,48 +41,18 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
    '''

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
-        # is_sparse is only support False
-        if program_config.ops[0].attrs['is_sparse']:
-            return False
-        # is_distributed only support False
-        if program_config.ops[0].attrs['is_distributed']:
-            return False
-        # axis only support -1 and the last dim.
-        if program_config.ops[3].attrs['axis'] not in [-1, 2]:
-            return False
-        if not (
-            program_config.ops[5].attrs['epsilon'] >= 0
-            and program_config.ops[5].attrs['epsilon'] <= 0.001
-        ):
-            return False
-        if program_config.ops[5].attrs['begin_norm_axis'] != 2:
-            return False
-        # input check
-        if (
-            program_config.weights['embedding_weight1'].shape[1]
-            != program_config.weights['layer_norm_scale'].shape[0]
-        ):
-            return False
        return True

    def sample_program_config(self, draw):
-        is_sparse = draw(st.booleans())
-        is_distributed = draw(st.booleans())
-        padding_idx = draw(st.integers())
-        axis = draw(st.integers(min_value=-4, max_value=4))
+        padding_idx = -1
+        axis = -1
        op_type = draw(st.sampled_from(['lookup_table', 'lookup_table_v2']))
-        epsilon = draw(st.floats(min_value=0, max_value=0.001))
+        epsilon = draw(st.floats(min_value=0.0001, max_value=0.001))
        # begin_norm_axis has to be 2
        begin_norm_axis = 2
        batch_size = draw(st.integers(min_value=1, max_value=4))
-        input_dim = draw(st.sampled_from([32, 64]))
-        weight_size = draw(st.sampled_from([[64, 64], [64, 32]]))
+        input_dim = 128
+        weight_size = [64, 384]

        def generate_input(attrs):
            if attrs[0]['op_type'] == 'lookup_table':

@@ -102,23 +70,22 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
        def generate_weight1(attrs):
            # set embedding weight by attrs
-            return np.random.random(attrs['weight_size']).astype(np.float32)
+            return np.random.uniform(0.1, 0.1, attrs['weight_size']).astype(
+                np.float32
+            )

        def generate_weight2(attrs):
-            # set layernorm weight by attrs
-            if attrs[2]['begin_norm_axis'] == 1:
-                return np.random.random(
-                    attrs[3]['input_dim'] * attrs[3]['weight_size'][1]
-                ).astype(np.float32)
-            else:
-                return np.random.random(attrs[3]['weight_size'][1]).astype(
-                    np.float32
-                )
+            return np.random.uniform(1, 1.1, attrs[3]['weight_size'][1]).astype(
+                np.float32
+            )
+
+        def generate_weight3(attrs):
+            return np.random.uniform(
+                0.001, 0.005, attrs[3]['weight_size'][1]
+            ).astype(np.float32)

        attrs = [
            {
-                'is_sparse': is_sparse,
-                'is_distributed': is_distributed,
                'padding_idx': padding_idx,
                'op_type': op_type,
            },

@@ -136,8 +103,6 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
            inputs={"Ids": ["input_data1"], "W": ["embedding_weight1"]},
            outputs={"Out": ["embedding_output1"]},
            attrs={
-                'is_sparse': attrs[0]['is_sparse'],
-                'is_distributed': attrs[0]['is_distributed'],
                'padding_idx': attrs[0]['padding_idx'],
            },
        )

@@ -146,8 +111,6 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
            inputs={"Ids": ["input_data2"], "W": ["embedding_weight2"]},
            outputs={"Out": ["embedding_output2"]},
            attrs={
-                'is_sparse': attrs[0]['is_sparse'],
-                'is_distributed': attrs[0]['is_distributed'],
                'padding_idx': attrs[0]['padding_idx'],
            },
        )

@@ -156,8 +119,6 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
            inputs={"Ids": ["input_data3"], "W": ["embedding_weight3"]},
            outputs={"Out": ["embedding_output3"]},
            attrs={
-                'is_sparse': attrs[0]['is_sparse'],
-                'is_distributed': attrs[0]['is_distributed'],
                'padding_idx': attrs[0]['padding_idx'],
            },
        )

@@ -210,7 +171,7 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
                    data_gen=partial(generate_weight1, attrs[3])
                ),
                "layer_norm_bias": TensorConfig(
-                    data_gen=partial(generate_weight2, attrs)
+                    data_gen=partial(generate_weight3, attrs)
                ),
                "layer_norm_scale": TensorConfig(
                    data_gen=partial(generate_weight2, attrs)
@@ -236,81 +197,244 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
        # only used in gpu passes and trt passes.
        config = self.create_inference_config(use_gpu=True)
        yield config, ['fused_embedding_eltwise_layernorm'], (1e-5, 1e-5)
-        # trt static_shape
-        config = self.create_trt_inference_config()
-        config.enable_tensorrt_engine(
-            max_batch_size=4,
-            workspace_size=102400,
-            min_subgraph_size=0,
-            precision_mode=paddle_infer.PrecisionType.Half,
-            use_static=False,
-            use_calib_mode=False,
-        )
-        yield config, ['fused_embedding_eltwise_layernorm'], (1e-5, 1e-5)
-        # trt dynamic_shape
-        config = self.create_trt_inference_config()
-        config.enable_tensorrt_engine(
-            max_batch_size=4,
-            workspace_size=102400,
-            min_subgraph_size=0,
-            precision_mode=paddle_infer.PrecisionType.Half,
-            use_static=False,
-            use_calib_mode=False,
-        )
-        if program_config.ops[0].type == 'lookup_table':
-            config.set_trt_dynamic_shape_info(
-                {
-                    "input_data1": [1, 4, 1],
-                    "input_data2": [1, 4, 1],
-                    "input_data3": [1, 4, 1],
-                },
-                {
-                    "input_data1": [4, 512, 1],
-                    "input_data2": [4, 512, 1],
-                    "input_data3": [4, 512, 1],
-                },
-                {
-                    "input_data1": [2, 128, 1],
-                    "input_data2": [2, 128, 1],
-                    "input_data3": [2, 128, 1],
-                },
-            )
-        else:
-            config.set_trt_dynamic_shape_info(
-                {
-                    "input_data1": [1, 4],
-                    "input_data2": [1, 4],
-                    "input_data3": [1, 4],
-                },
-                {
-                    "input_data1": [4, 512],
-                    "input_data2": [4, 512],
-                    "input_data3": [4, 512],
-                },
-                {
-                    "input_data1": [2, 128],
-                    "input_data2": [2, 128],
-                    "input_data3": [2, 128],
-                },
-            )
-        yield config, ['fused_embedding_eltwise_layernorm'], (1e-5, 1e-5)
-
-    def add_ignore_pass_case(self):
-        def teller1(program_config, predictor_config):
-            if (
-                program_config.ops[3].attrs['axis'] in [-1, 2]
-                and program_config.ops[5].attrs['begin_norm_axis'] == 2
-                and program_config.weights['embedding_weight1'].shape
-                in [(64, 32), (64, 64)]
-            ):
-                return True
-            return False
-
-        self.add_ignore_check_case(
-            teller1,
-            IgnoreReasons.PASS_ACCURACY_ERROR,
-            "The pass output has diff in a specific case. We need to fix it as soon as possible.",
-        )
+
+    def add_ignore_pass_case(self):
+        pass
+
+    def test(self):
+        # this fuse need to fix, now there's no program can ran successfully
+        self.run_and_statis(
+            quant=False,
+            max_examples=50,
+            passes=["embedding_eltwise_layernorm_fuse_pass"],
+            min_success_num=0,
+        )
+
+
+class TestEmbeddingEltwiseLayerNormFusePassNoBroadcast(PassAutoScanTest):
+    r'''
+    in_var1  emb_var   in_var2  emb_var   in_var3  emb_var    in_var   emb_var
+       |        |         |        |         |        |          |        |
+      lookup_table       lookup_table       lookup_table   ...  lookup_table
+            |                  |                  |                  |
+         lkt_var            lkt_var            lkt_var            lkt_var
+             \                /                   |        ...       |
+              elementwise_add                     |                  |
+                      \                          /                   |
+                           elementwise_add                           |
+                                  |                                  |
+                               elt_var                              /
+                                   \                               /
+                                        elementwise_add
+                                               |
+                                          layer_norm
+    '''
+
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        return True
+
+    def sample_program_config(self, draw):
+        padding_idx = 0
+        axis = -1
+        op_type = draw(st.sampled_from(['lookup_table', 'lookup_table_v2']))
+        epsilon = 0.0001
+        # begin_norm_axis has to be 2
+        begin_norm_axis = 2
+        batch_size = 4
+        input_dim = [128, 128, 1]
+        weight_size = [64, 384]
+
+        def generate_input1(attrs):
+            if attrs[0]['op_type'] == 'lookup_table':
+                return np.random.randint(
+                    0,
+                    attrs[3]['weight_size'][0],
+                    size=(attrs[3]['batch_size'], attrs[3]['input_dim'][0], 1),
+                ).astype(np.int64)
+            else:
+                return np.random.randint(
+                    0,
+                    attrs[3]['weight_size'][0],
+                    size=(attrs[3]['batch_size'], attrs[3]['input_dim'][0]),
+                ).astype(np.int64)
+
+        def generate_input2(attrs):
+            if attrs[0]['op_type'] == 'lookup_table':
+                return np.random.randint(
+                    0,
+                    attrs[3]['weight_size'][0],
+                    size=(attrs[3]['batch_size'], attrs[3]['input_dim'][1], 1),
+                ).astype(np.int64)
+            else:
+                return np.random.randint(
+                    0,
+                    attrs[3]['weight_size'][0],
+                    size=(attrs[3]['batch_size'], attrs[3]['input_dim'][1]),
+                ).astype(np.int64)
+
+        def generate_input3(attrs):
+            if attrs[0]['op_type'] == 'lookup_table':
+                return np.random.randint(
+                    0,
+                    attrs[3]['weight_size'][0],
+                    size=(attrs[3]['batch_size'], attrs[3]['input_dim'][2], 1),
+                ).astype(np.int64)
+            else:
+                return np.random.randint(
+                    0,
+                    attrs[3]['weight_size'][0],
+                    size=(attrs[3]['batch_size'], attrs[3]['input_dim'][2]),
+                ).astype(np.int64)
+
+        def generate_weight1(attrs):
+            # set embedding weight by attrs
+            return np.random.uniform(0.1, 0.1, attrs['weight_size']).astype(
+                np.float32
+            )
+
+        def generate_weight2(attrs):
+            return np.random.uniform(1, 1.1, attrs[3]['weight_size'][1]).astype(
+                np.float32
+            )
+
+        def generate_weight3(attrs):
+            return np.random.uniform(
+                0.001, 0.005, attrs[3]['weight_size'][1]
+            ).astype(np.float32)
+
+        attrs = [
+            {
+                'padding_idx': padding_idx,
+                'op_type': op_type,
+            },
+            {'axis': axis},
+            {'begin_norm_axis': begin_norm_axis, 'epsilon': epsilon},
+            {
+                'batch_size': batch_size,
+                'input_dim': input_dim,
+                'weight_size': weight_size,
+            },
+        ]
+
+        emb_op1 = OpConfig(
+            type=attrs[0]['op_type'],
+            inputs={"Ids": ["input_data1"], "W": ["embedding_weight1"]},
+            outputs={"Out": ["embedding_output1"]},
+            attrs={
+                'padding_idx': attrs[0]['padding_idx'],
+            },
+        )
+        emb_op2 = OpConfig(
+            type=attrs[0]['op_type'],
+            inputs={"Ids": ["input_data2"], "W": ["embedding_weight2"]},
+            outputs={"Out": ["embedding_output2"]},
+            attrs={
+                'padding_idx': attrs[0]['padding_idx'],
+            },
+        )
+        emb_op3 = OpConfig(
+            type=attrs[0]['op_type'],
+            inputs={"Ids": ["input_data3"], "W": ["embedding_weight3"]},
+            outputs={"Out": ["embedding_output3"]},
+            attrs={
+                'padding_idx': attrs[0]['padding_idx'],
+            },
+        )
+        add_op1 = OpConfig(
+            type='elementwise_add',
+            inputs={
+                "X": [emb_op2.outputs["Out"][0]],
+                "Y": [emb_op3.outputs["Out"][0]],
+            },
+            outputs={"Out": ["elementwise_add_output1"]},
+            attrs={"axis": attrs[1]['axis']},
+        )
+        add_op2 = OpConfig(
+            type='elementwise_add',
+            inputs={
+                "X": [add_op1.outputs["Out"][0]],
+                "Y": [emb_op1.outputs["Out"][0]],
+            },
+            outputs={"Out": ["elementwise_add_output2"]},
+            attrs={"axis": attrs[1]['axis']},
+        )
+        layer_norm_op = OpConfig(
+            type='layer_norm',
+            inputs={
+                "X": [add_op2.outputs["Out"][0]],
+                "Bias": ["layer_norm_bias"],
+                "Scale": ["layer_norm_scale"],
+            },
+            outputs={
+                "Y": ["layer_norm_output1"],
+                "Mean": ["layer_norm_output2"],
+                "Variance": ["layer_norm_output3"],
+            },
+            attrs={
+                'begin_norm_axis': attrs[2]['begin_norm_axis'],
+                'epsilon': attrs[2]['epsilon'],
+            },
+        )
+
+        program_config = ProgramConfig(
+            ops=[emb_op1, emb_op2, emb_op3, add_op1, add_op2, layer_norm_op],
+            weights={
+                "embedding_weight1": TensorConfig(
+                    data_gen=partial(generate_weight1, attrs[3])
+                ),
+                "embedding_weight2": TensorConfig(
+                    data_gen=partial(generate_weight1, attrs[3])
+                ),
+                "embedding_weight3": TensorConfig(
+                    data_gen=partial(generate_weight1, attrs[3])
+                ),
+                "layer_norm_bias": TensorConfig(
+                    data_gen=partial(generate_weight3, attrs)
+                ),
+                "layer_norm_scale": TensorConfig(
+                    data_gen=partial(generate_weight2, attrs)
+                ),
+            },
+            inputs={
+                "input_data1": TensorConfig(
+                    data_gen=partial(generate_input1, attrs)
+                ),
+                "input_data2": TensorConfig(
+                    data_gen=partial(generate_input2, attrs)
+                ),
+                "input_data3": TensorConfig(
+                    data_gen=partial(generate_input3, attrs)
+                ),
+            },
+            outputs=["layer_norm_output1"],
+        )
+
+        return program_config
+
+    def sample_predictor_configs(self, program_config):
+        # only used in gpu passes and trt passes.
+        config = self.create_inference_config(use_gpu=True)
+        if program_config.ops[0].type == 'lookup_table':
+            yield config, [
+                'lookup_table',
+                'lookup_table',
+                'lookup_table',
+                'elementwise_add',
+                'elementwise_add',
+                'layer_norm',
+            ], (1e-5, 1e-5)
+        else:
+            yield config, [
+                'lookup_table_v2',
+                'lookup_table_v2',
+                'lookup_table_v2',
+                'elementwise_add',
+                'elementwise_add',
+                'layer_norm',
+            ], (1e-5, 1e-5)
+
+    def add_ignore_pass_case(self):
+        pass

    def test(self):
        # this fuse need to fix, now there's no program can ran successfully
...