提交 af8c7131 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: CorrectGraphEdges refactored

上级 3e033087
...@@ -20,51 +20,33 @@ ...@@ -20,51 +20,33 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
namespace patterns { namespace {
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
template <typename IT, typename FindFunc, typename ReplaceFunc>
static void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) {
if (s == e) return;
auto it = std::find_if(s, e, f);
if (it != e) {
r(*it);
}
it++;
ReplaceAllOccurances(it, e, f, r);
}
static void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
for (auto& node : GraphTraits::DFS(*graph)) { for (auto& node : GraphTraits::DFS(*graph)) {
auto same = std::find_if(std::begin(node.inputs), std::end(node.inputs), auto from_in_inputs =
[from](Node* n) { return n == from; }); std::find(std::begin(node.inputs), std::end(node.inputs), from);
if (same != std::end(node.inputs)) { if (from_in_inputs != std::end(node.inputs)) {
IR_NODE_LINK_TO(to, (&node)); IR_NODE_LINK_TO(to, (&node));
auto inputs = node.Op()->Inputs(); auto inputs = node.Op()->Inputs();
using input_type = VariableNameMap::value_type; using input_type = VariableNameMap::value_type;
ReplaceAllOccurances( std::for_each(std::begin(inputs), std::end(inputs),
std::begin(inputs), std::end(inputs), [from, to, &node](const input_type& i) -> void {
[from](const input_type& i) -> bool { auto param_names = i.second;
auto params = i.second; auto pi = std::find(std::begin(param_names),
auto pi = std::end(param_names), from->Name());
std::find_if(std::begin(params), std::end(params),
std::bind(std::equal_to<std::string>(), if (pi != std::end(param_names)) {
from->Name(), std::placeholders::_1));
return pi != std::end(params);
},
[to, &node](const input_type& i) {
node.Op()->SetInput(i.first, {to->Name()}); node.Op()->SetInput(i.first, {to->Name()});
}
}); });
} }
} }
} }
} // namespace patterns } // namespace
using graph_ptr = std::unique_ptr<ir::Graph>; using graph_ptr = std::unique_ptr<ir::Graph>;
graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
...@@ -116,7 +98,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { ...@@ -116,7 +98,7 @@ graph_ptr ConvElementwiseAddMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op); IR_NODE_LINK_TO(elementwise_add_x, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, conv_output); IR_NODE_LINK_TO(fused_conv_op, conv_output);
patterns::CorrectGraphEdges(g, elementwise_add_out, conv_output); CorrectGraphEdges(g, elementwise_add_out, conv_output);
GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op}); GraphSafeRemoveNodes(g, {elementwise_add_out, conv_op, elementwise_add_op});
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册