From 04e8759ee2416baac1f31f6a27cb49a8b6051e19 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Tue, 11 Dec 2018 18:36:46 -0800 Subject: [PATCH] [Grappler] Add helper functions to GraphView. PiperOrigin-RevId: 225109110 --- tensorflow/core/grappler/graph_view.h | 60 ++++++++++++++------- tensorflow/core/grappler/graph_view_test.cc | 34 ++++++++++++ tensorflow/core/grappler/utils.cc | 7 ++- tensorflow/core/grappler/utils.h | 4 ++ tensorflow/core/grappler/utils_test.cc | 7 +++ 5 files changed, 92 insertions(+), 20 deletions(-) diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index 0a47b225658..16156d0f204 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -111,32 +111,37 @@ class GraphViewInternal { GraphDefT* graph() const { return graph_; } - // Find a node by name or return `nullptr` if it's not in a graph view. + // Finds a node by name or return `nullptr` if it's not in the graph view. NodeDefT* GetNode(absl::string_view node_name) const { return gtl::FindWithDefault(nodes_, node_name, nullptr); } - // Get the specified input port. Note that the special '-1' port_id can be + // Checks if a node by name is in the graph view. + bool HasNode(absl::string_view node_name) const { + return GetNode(node_name) != nullptr; + } + + // Gets the specified input port. Note that the special '-1' port_id can be // used to access the controlling nodes (i.e. the nodes connected to node_name // through an incoming control dependency). InputPort GetInputPort(absl::string_view node_name, int port_id) const { return InputPort(GetNode(node_name), port_id); } - // Get the specified output port. Note that the special '-1' port_id can be + // Gets the specified output port. Note that the special '-1' port_id can be // used to access the controlled nodes (i.e. the nodes connected to node_name // through an outgoing control dependency). OutputPort GetOutputPort(absl::string_view node_name, int port_id) const { return OutputPort(GetNode(node_name), port_id); } - // Get the input (resp. output) port(s) in the immediate fanout (resp. fanin) - // of an output (resp. input) port. + // Gets the input port(s) in the immediate fanout of an output port. const absl::flat_hash_set& GetFanout( const OutputPort& port) const { return gtl::FindWithDefault(fanouts_, port, fanout_not_found_value_); } + // Gets the output port(s) in the immediate fanin of an input port. absl::flat_hash_set GetFanin(const InputPort& port) const { if (port.port_id >= 0) return {GetRegularFanin(port)}; @@ -162,9 +167,22 @@ class GraphViewInternal { return GetOutputPort(tensor_id.node(), tensor_id.index()); } - // Get all the input (resp. output) ports in the immediate fanout (resp - // fanin) of a node. Include the controlling nodes iff - // include_controlling_nodes is true. + // Checks if a tensor id is a fanin of the node. + bool HasFanin(const NodeDef& node, const TensorId& fanin) const { + if (fanin.index() < -1) { + return false; + } + string fanin_string = TensorIdToString(fanin); + for (int i = 0; i < node.input_size(); ++i) { + if (node.input(i) == fanin_string) { + return true; + } + } + return false; + } + + // Gets all the input ports in the immediate fanout of a node. Include the + // controlled nodes iff include_controlled_nodes is true. absl::flat_hash_set GetFanouts( const NodeDef& node, bool include_controlled_nodes) const { absl::flat_hash_set result; @@ -185,6 +203,8 @@ class GraphViewInternal { return result; } + // Gets all the output ports in the immediate fanin of a node. Include the + // controlling nodes iff include_controlling_nodes is true. absl::flat_hash_set GetFanins( const NodeDef& node, bool include_controlling_nodes) const { absl::flat_hash_set result; @@ -198,7 +218,7 @@ class GraphViewInternal { return result; } - // Get the number of ports in the immediate fanin of a node. Count the + // Gets the number of ports in the immediate fanin of a node. Count the // controlling nodes iff include_controlling_nodes is true. int NumFanins(const NodeDef& node, bool include_controlling_nodes) const { int count = 0; @@ -211,14 +231,14 @@ class GraphViewInternal { return count; } - // Get the number of ports in the immediate fanout of a node. Count the - // controlling nodes iff include_controlling_nodes is true. - int NumFanouts(const NodeDef& node, bool include_controlling_nodes) const { + // Gets the number of ports in the immediate fanout of a node. Count the + // controlled nodes iff include_controlled_nodes is true. + int NumFanouts(const NodeDef& node, bool include_controlled_nodes) const { int count = 0; OutputPort port; port.node = const_cast(&node); - const int first_port_id = include_controlling_nodes ? -1 : 0; + const int first_port_id = include_controlled_nodes ? -1 : 0; const int last_port_id = gtl::FindWithDefault(max_regular_output_port_, port.node, -1); @@ -231,8 +251,8 @@ class GraphViewInternal { return count; } - // Get all the edges in the immediate fanout (resp fanin) of a node. - // Include the control edges iff include_controlling_edges is true. + // Gets all the edges in the immediate fanout of a node. Include the + // controlled edges iff include_controlled_edges is true. absl::flat_hash_set GetFanoutEdges( const NodeDef& node, bool include_controlled_edges) const { absl::flat_hash_set result; @@ -248,14 +268,16 @@ class GraphViewInternal { auto it = fanouts_.find(port); if (it != fanouts_.end()) { for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) { - result.emplace(/*src*/ OutputPort(const_cast(&node), i), - /*dst*/ *itr); + result.emplace(/*src=*/OutputPort(const_cast(&node), i), + /*dst=*/*itr); } } } return result; } + // Gets all the edges in the immediate fanin of a node. Include the + // controlling edges iff include_controlling_edges is true. absl::flat_hash_set GetFaninEdges( const NodeDef& node, bool include_controlling_edges) const { absl::flat_hash_set result; @@ -265,8 +287,8 @@ class GraphViewInternal { auto it = nodes_.find(tensor_id.node()); if (it != nodes_.end()) { - result.emplace(/*src*/ OutputPort(it->second, tensor_id.index()), - /*dst*/ InputPort(const_cast(&node), i)); + result.emplace(/*src=*/OutputPort(it->second, tensor_id.index()), + /*dst=*/InputPort(const_cast(&node), i)); } } return result; diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc index cbf859a4a99..404dcd30c12 100644 --- a/tensorflow/core/grappler/graph_view_test.cc +++ b/tensorflow/core/grappler/graph_view_test.cc @@ -230,6 +230,40 @@ TEST_F(GraphViewTest, ControlDependencies) { EXPECT_EQ(0, (*fanin.begin()).port_id); } +TEST_F(GraphViewTest, HasNode) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphView graph(&item.graph); + + EXPECT_EQ(true, graph.HasNode("a")); + EXPECT_EQ(false, graph.HasNode("b")); +} + +TEST_F(GraphViewTest, HasFanin) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); + Output b = ops::Square(s.WithOpName("b"), {a}); + Output c = ops::Sqrt(s.WithOpName("c"), {b}); + Output d = ops::AddN(s.WithOpName("d").WithControlDependencies(a), {b, c}); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphView graph(&item.graph); + + const NodeDef* d_node = graph.GetNode("d"); + EXPECT_NE(nullptr, d_node); + + EXPECT_EQ(true, graph.HasFanin(*d_node, {"a", Graph::kControlSlot})); + EXPECT_EQ(false, graph.HasFanin(*d_node, {"a", 0})); + EXPECT_EQ(true, graph.HasFanin(*d_node, {"b", 0})); + EXPECT_EQ(false, graph.HasFanin(*d_node, {"b", Graph::kControlSlot})); + EXPECT_EQ(true, graph.HasFanin(*d_node, {"c", 0})); + EXPECT_EQ(false, graph.HasFanin(*d_node, {"c", Graph::kControlSlot})); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 29775442629..90ad04cf47b 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -144,11 +144,16 @@ void NodeMap::UpdateOutput(const string& node_name, outputs.insert(nodes_[NodeName(new_output_name)]); } +string TensorIdToString(const TensorId& tensor_id) { + return tensor_id.index() == 0 ? string(tensor_id.node()) + : tensor_id.ToString(); +} + bool IsSameInput(const string& name1, const string& name2) { if (name1 == name2) return true; TensorId tensor1 = ParseTensorName(name1); TensorId tensor2 = ParseTensorName(name2); - return tensor1.node() == tensor2.node() && tensor1.index() == tensor2.index(); + return tensor1 == tensor2; } bool IsControlInput(const string& name) { diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index b1e2d4e9cb5..89a87af323a 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -100,6 +100,10 @@ class SetVector { std::vector vector_; }; +// Returns formatted string from TensorId specific to grappler. Specifically, +// for the 0 port (first output), only the node name is returned. +string TensorIdToString(const TensorId& tensor_id); + // True iff 'name' refers to a control inputs, i.e. a node name prefixed with // the ^ character. bool IsControlInput(const string& name); diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index e993391b51b..f5ae39867ac 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -464,6 +464,13 @@ TEST_F(UtilsTest, SetTensorValueBFloat16IntMin) { Tensor(bfloat16(std::numeric_limits::min())), t); } +TEST_F(UtilsTest, TensorIdToString) { + EXPECT_EQ("^foo", TensorIdToString({"foo", -1})); + EXPECT_EQ("foo", TensorIdToString({"foo", 0})); + EXPECT_EQ("foo:1", TensorIdToString({"foo", 1})); + EXPECT_EQ("foo:2", TensorIdToString({"foo", 2})); +} + } // namespace } // namespace grappler } // namespace tensorflow -- GitLab