Unverified commit 60546b78 authored by Tao Luo, committed by GitHub

Merge pull request #15923 from Sand3r-/mgallus/conv-residual-ut

Add Conv Residual Connection UT for Projection
@@ -44,10 +44,14 @@ struct TestIsReachable {
   using func = std::function<bool(const std::string&, const std::string&)>;
 
   auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func {
-    auto find_node = [](const std::unique_ptr<ir::Graph>& graph,
-                        const std::string& name) -> Node* {
+    auto hash = [](const Node* node) -> std::string {
+      return node->Name() + std::to_string(node->id());
+    };
+
+    auto find_node = [&](const std::unique_ptr<ir::Graph>& graph,
+                         const std::string& name) -> Node* {
       for (auto& node : GraphTraits::DFS(*graph)) {
-        if (name == node.Name()) {
+        if (name == hash(&node)) {
           return &node;
         }
       }
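Review note: the new hash lambda keys nodes by Name() plus id(), presumably because distinct graph nodes can share a name once the same variable or op type occurs more than once, so keying lookups on the bare name would conflate them. A minimal stand-alone sketch of that idea, assuming only a toy Node type with Name() and id() (not the PaddlePaddle ir::Node):

```cpp
#include <iostream>
#include <set>
#include <string>

// Toy stand-in for a graph node; real ir::Node has more state.
struct Node {
  std::string name;
  int id_;
  const std::string& Name() const { return name; }
  int id() const { return id_; }
};

int main() {
  // Same key-building scheme as the test helper: name + unique id.
  auto hash = [](const Node* node) -> std::string {
    return node->Name() + std::to_string(node->id());
  };

  Node a{"conv2d", 1}, b{"conv2d", 2};  // two distinct nodes, same name
  std::set<std::string> by_name{a.Name(), b.Name()};
  std::set<std::string> by_hash{hash(&a), hash(&b)};
  std::cout << by_name.size() << " vs " << by_hash.size() << std::endl;  // prints "1 vs 2"
  return 0;
}
```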
@@ -55,13 +59,17 @@ struct TestIsReachable {
       return nullptr;
     };
 
-    return [&](std::string from, const std::string to) -> bool {
+    // update the from and to strings to hashed equivs in loop from graph traits
+    return [&](std::string from, std::string to) -> bool {
       if (from == to) return true;
 
       std::map<std::string, bool> visited;
 
       for (auto& node : GraphTraits::DFS(*graph)) {
-        visited[node.Name()] = false;
+        auto hashed = hash(&node);
+        if (node.Name() == from) from = hashed;
+        if (node.Name() == to) to = hashed;
+        visited[hashed] = false;
      }
 
       visited[from] = true;
@@ -72,15 +80,15 @@ struct TestIsReachable {
       while (!queue.empty()) {
         auto cur = find_node(graph, queue.front());
         queue.pop_front();
         if (cur == nullptr) return false;
 
         for (auto n : cur->outputs) {
-          if (n->Name() == to) return true;
+          auto hashed_name = hash(n);
+          if (hashed_name == to) return true;
 
-          if (!visited[n->Name()]) {
-            visited[n->Name()] = true;
-            queue.push_back(n->Name());
+          if (!visited[hashed_name]) {
+            visited[hashed_name] = true;
+            queue.push_back(hashed_name);
           }
         }
       }
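Review note: the functor above is a breadth-first reachability check over each node's outputs, now keyed by the hashed names. A self-contained sketch of the same logic on a plain adjacency list (the names and structure here are illustrative, not PaddlePaddle API):

```cpp
#include <deque>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

// BFS over out-edges: returns true if `to` can be reached from `from`.
bool IsReachable(const std::map<std::string, std::vector<std::string>>& edges,
                 const std::string& from, const std::string& to) {
  if (from == to) return true;
  std::set<std::string> visited{from};
  std::deque<std::string> queue{from};
  while (!queue.empty()) {
    auto cur = queue.front();
    queue.pop_front();
    auto it = edges.find(cur);
    if (it == edges.end()) continue;
    for (const auto& next : it->second) {
      if (next == to) return true;
      if (visited.insert(next).second) queue.push_back(next);  // enqueue unvisited
    }
  }
  return false;
}

int main() {
  // Toy graph shaped like the tests: a -> conv2d -> elementwise_add -> relu.
  std::map<std::string, std::vector<std::string>> edges = {
      {"a", {"conv2d"}},
      {"conv2d", {"elementwise_add"}},
      {"elementwise_add", {"relu"}}};
  std::cout << std::boolalpha << IsReachable(edges, "a", "relu") << std::endl;  // true
  return 0;
}
```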
@@ -166,6 +174,28 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) {
   RunPassAndAssert(&prog, "a", "relu", 1);
 }
 
+TEST(ConvElementwiseAddMKLDNNFusePass,
+     ConvolutionProjectionAsYWithElementwiseAddRelu) {
+  auto prog = BuildProgramDesc({"a", "b", "c", "d", "e", "f"},
+                               {"bias", "weights", "bias2", "weights2"});
+
+  SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
+
+  // right branch
+  SetOp(&prog, "conv2d",
+        {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
+        {"Output", "c"});
+  // left branch
+  SetOp(&prog, "conv2d",
+        {{"Input", "a"}, {"Bias", "bias2"}, {"Filter", "weights2"}},
+        {"Output", "f"});
+
+  SetOp(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}}, {"Out", "d"});
+  SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
+
+  RunPassAndAssert(&prog, "a", "relu", 2);
+}
+
 TEST(ConvElementwiseAddMKLDNNFusePass,
      ConvolutionAsYWithElementwiseAddReluNoBias) {
   auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
......
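Review note: in the new ConvolutionProjectionAsYWithElementwiseAddRelu case both operands of elementwise_add are produced by conv2d ops, i.e. the shortcut branch goes through its own projection convolution rather than feeding the addition directly. A stand-alone sketch of that topology, using a hypothetical Op record with weights/bias inputs omitted (not the real ProgramDesc/SetOp API):

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for the program the test builds via SetOp.
struct Op {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

int main() {
  std::vector<Op> ops = {
      {"sigmoid", {"a"}, {"b"}},
      {"conv2d", {"b"}, {"c"}},               // right branch
      {"conv2d", {"a"}, {"f"}},               // left (projection) branch
      {"elementwise_add", {"f", "c"}, {"d"}},
      {"relu", {"d"}, {"e"}},
  };

  // Map each variable to the op type that produces it.
  std::map<std::string, std::string> producer;
  for (const auto& op : ops)
    for (const auto& out : op.outputs) producer[out] = op.type;

  // Both inputs of elementwise_add are conv2d outputs -> projection residual.
  for (const auto& op : ops)
    if (op.type == "elementwise_add")
      for (const auto& in : op.inputs) assert(producer.at(in) == "conv2d");
  return 0;
}
```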