提交 7f5c8a95 编写于 作者: T Tomasz Patejko

MKLDNN conv + elementwise_add fusion: arguments are replaced for many parameters in operator

上级 5996bd39
......@@ -109,9 +109,23 @@ void LinkNodes(Node* from, Node* to) {
to->inputs.push_back(from);
}
template<typename IT, typename FindFunc, typename ReplaceFunc>
void ReplaceAllOccurances(IT s, IT e, FindFunc f, ReplaceFunc r) {
if (s == e)
return;
auto it = std::find_if(s, e, f);
if (it != e) {
r(*it);
}
it++;
ReplaceAllOccurances(it, e, f, r);
}
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
for (auto& node : GraphTraits::DFS(*graph)) {
std::vector<Node*> to_remove;
auto same = std::find_if(std::begin(node.inputs),
std::end(node.inputs),
[from](Node* n) { return n == from; });
......@@ -121,15 +135,19 @@ void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
auto inputs = node.Op()->Inputs();
std::for_each(std::begin(inputs), std::end(inputs),
[from, to](const std::pair<std::string, std::vector<std::string>>& i) -> void {
auto params = i.second;
std::remove_if(std::begin(params), std::end(params),
std::bind(std::equal_to<std::string>(), from->Name(), std::placeholders::_1));
params.push_back(to->Name());
});
using input_type = VariableNameMap::value_type;
ReplaceAllOccurances(std::begin(inputs), std::end(inputs),
[from](const input_type& i) -> bool {
auto params = i.second;
auto pi = std::find_if(std::begin(params), std::end(params),
std::bind(std::equal_to<std::string>(),
from->Name(), std::placeholders::_1));
return pi != std::end(params);
},
[to, &node](const input_type& i) {
node.Op()->SetInput(i.first, {to->Name()});
});
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册