未验证 提交 da258964 编写于 作者: R ronnywang 提交者: GitHub

[CustomPass] add support for outputting the intermediate variables (#55728)

* add support for outputting the intermediate variables

* fix fuse_resnet_unit
上级 43fcd01b
...@@ -220,6 +220,16 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) { ...@@ -220,6 +220,16 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
}); });
} }
} }
// The output of the pattern must be marked as AsOutput.
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) {
PDNode* var_pdnode = pattern->RetrieveNode(var_map.pattern_var());
PADDLE_ENFORCE_NOT_NULL(
var_pdnode,
platform::errors::NotFound("Not found the var %s in the pattern.",
var_map.pattern_var()));
var_pdnode->AsOutput();
}
} }
// There are some duplicate patterns. // There are some duplicate patterns.
...@@ -314,10 +324,10 @@ GraphPatternDetector::handle_t GetGenerateRewrite( ...@@ -314,10 +324,10 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
} }
// `var_node_maps` record the mapping of variable to the pattern // `var_node_maps` record the mapping of variable to the pattern
// subgraph. // subgraph.
std::map<std::string, Node*> var_node_maps; std::map<std::string, std::vector<Node*>> var_node_maps;
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) { for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) {
Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var())); Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var()));
var_node_maps.insert({var_map.replace_var(), node}); var_node_maps[var_map.replace_var()].emplace_back(node);
} }
// Traverse all operators to create subgraph. // Traverse all operators to create subgraph.
for (int index = 0; index < pass_desc.replace_size(); ++index) { for (int index = 0; index < pass_desc.replace_size(); ++index) {
...@@ -330,17 +340,13 @@ GraphPatternDetector::handle_t GetGenerateRewrite( ...@@ -330,17 +340,13 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
std::vector<std::string> arguments; std::vector<std::string> arguments;
for (const std::string& argument : var.arguments()) { for (const std::string& argument : var.arguments()) {
// The input may be mapped on the operator of pattern subgraph. // The input may be mapped on the operator of pattern subgraph.
Node* node = nullptr; if (var_node_maps[argument].size() == 0) {
auto iter = var_node_maps.find(argument);
if (var_node_maps.end() == iter) {
VarDesc var_desc(patterns::UniqueKey(argument)); VarDesc var_desc(patterns::UniqueKey(argument));
node = graph->CreateVarNode(&var_desc); var_node_maps[argument].emplace_back(
var_node_maps.insert({argument, node}); graph->CreateVarNode(&var_desc));
} else {
node = iter->second;
} }
in_nodes.push_back(node); in_nodes.push_back(var_node_maps[argument][0]);
arguments.push_back(node->Name()); arguments.push_back(var_node_maps[argument][0]->Name());
} }
op_desc.SetInput(var.parameter(), arguments); op_desc.SetInput(var.parameter(), arguments);
} }
...@@ -349,22 +355,20 @@ GraphPatternDetector::handle_t GetGenerateRewrite( ...@@ -349,22 +355,20 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
std::vector<std::string> arguments; std::vector<std::string> arguments;
for (const std::string& argument : var.arguments()) { for (const std::string& argument : var.arguments()) {
// The output may be mapped on the operator of pattern subgraph. // The output may be mapped on the operator of pattern subgraph.
Node* node = nullptr; if (var_node_maps[argument].size() == 0) {
auto iter = var_node_maps.find(argument);
if (var_node_maps.end() == iter) {
VarDesc var_desc(patterns::UniqueKey(argument)); VarDesc var_desc(patterns::UniqueKey(argument));
node = graph->CreateVarNode(&var_desc); var_node_maps[argument].emplace_back(
var_node_maps.insert({argument, node}); graph->CreateVarNode(&var_desc));
}
if (in_nodes.end() == std::find(in_nodes.begin(),
in_nodes.end(),
var_node_maps[argument][0])) {
out_nodes.push_back(var_node_maps[argument][0]);
} else { } else {
if (in_nodes.end() == out_nodes.push_back(
std::find(in_nodes.begin(), in_nodes.end(), iter->second)) { graph->CreateVarNode(var_node_maps[argument][0]->Var()));
node = iter->second;
} else {
node = graph->CreateVarNode(iter->second->Var());
}
} }
out_nodes.push_back(node); arguments.push_back(var_node_maps[argument][0]->Name());
arguments.push_back(node->Name());
} }
op_desc.SetOutput(var.parameter(), arguments); op_desc.SetOutput(var.parameter(), arguments);
} }
...@@ -413,7 +417,40 @@ GraphPatternDetector::handle_t GetGenerateRewrite( ...@@ -413,7 +417,40 @@ GraphPatternDetector::handle_t GetGenerateRewrite(
remove_nodes.emplace(subgraph.at(pdnode.get())); remove_nodes.emplace(subgraph.at(pdnode.get()));
} }
for (auto iter : var_node_maps) { for (auto iter : var_node_maps) {
remove_nodes.erase(iter.second); for (auto& node : iter.second) {
remove_nodes.erase(node);
}
}
GraphSafeRemoveNodes(graph, remove_nodes);
  // Replace the redundant node by the first node in var_node_maps.
remove_nodes.clear();
for (auto iter : var_node_maps) {
auto var_node = iter.second[0];
for (size_t i = 1; i < iter.second.size(); ++i) {
auto replaced_var_node = iter.second[i];
for (auto op_node : replaced_var_node->outputs) {
auto index = std::find(op_node->inputs.begin(),
op_node->inputs.end(),
replaced_var_node) -
op_node->inputs.begin();
op_node->inputs[index] = var_node;
auto& input_name_maps = *op_node->Op()->MutableInputs();
for (auto& item : input_name_maps) {
auto iter = std::find(item.second.begin(),
item.second.end(),
replaced_var_node->Name());
if (iter != item.second.end()) {
item.second[iter - item.second.begin()] = var_node->Name();
input_name_maps[item.first] = item.second;
break;
}
}
op_node->Op()->Flush();
}
remove_nodes.emplace(replaced_var_node);
}
} }
GraphSafeRemoveNodes(graph, remove_nodes); GraphSafeRemoveNodes(graph, remove_nodes);
}; };
......
...@@ -90,7 +90,7 @@ def fuse_resnet_unit(): ...@@ -90,7 +90,7 @@ def fuse_resnet_unit():
varZ, varZ,
): ):
bnX = pattern_conv_bn(x, filterX, scaleX, biasX, meanX, varX) bnX = pattern_conv_bn(x, filterX, scaleX, biasX, meanX, varX)
bnZ = pattern_conv_bn(x, filterZ, scaleZ, biasZ, meanZ, varZ) bnZ = pattern_conv_bn(z, filterZ, scaleZ, biasZ, meanZ, varZ)
ewadd = ir.PassDesc.OP.elementwise_add( ewadd = ir.PassDesc.OP.elementwise_add(
X=bnX.Output("Y"), Y=bnZ.Output("Y") X=bnX.Output("Y"), Y=bnZ.Output("Y")
) )
......
...@@ -314,6 +314,36 @@ class PassDesc: ...@@ -314,6 +314,36 @@ class PassDesc:
return attr return attr
class OpHelper: class OpHelper:
def _to_readable_code(self, skip_op_callstack=True):
assert isinstance(
skip_op_callstack, bool
), "skip_op_callstack parameter's type is error, expect bool, received {}".format(
type(skip_op_callstack)
)
outputs_str = "{"
outputs_str += ", ".join(
[f"{k}={v}" for k, v in self._outputs.items()]
)
outputs_str += "}"
inputs_str = "{"
inputs_str += ", ".join(
[f"{k}={v}" for k, v in self._inputs.items()]
)
inputs_str += "}"
attrs_str = "{"
attrs_str += ", ".join([f"{k}={v}" for k, v in self._attrs.items()])
attrs_str += "}"
op_str = "{outputs} = {op_type}(inputs={inputs}, {attrs})".format(
outputs=outputs_str,
op_type=self._type,
inputs=inputs_str,
attrs=attrs_str,
)
return op_str
def __init__(self, type=None): def __init__(self, type=None):
self._type = type self._type = type
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册