提交 04e8759e 编写于 作者: A Andy Ly 提交者: TensorFlower Gardener

[Grappler] Add helper functions to GraphView.

PiperOrigin-RevId: 225109110
上级 ae244e6d
......@@ -111,32 +111,37 @@ class GraphViewInternal {
GraphDefT* graph() const { return graph_; }
// Find a node by name or return `nullptr` if it's not in a graph view.
// Finds a node by name or return `nullptr` if it's not in the graph view.
NodeDefT* GetNode(absl::string_view node_name) const {
return gtl::FindWithDefault(nodes_, node_name, nullptr);
}
// Get the specified input port. Note that the special '-1' port_id can be
// Checks if a node by name is in the graph view.
bool HasNode(absl::string_view node_name) const {
return GetNode(node_name) != nullptr;
}
// Gets the specified input port. Note that the special '-1' port_id can be
// used to access the controlling nodes (i.e. the nodes connected to node_name
// through an incoming control dependency).
InputPort GetInputPort(absl::string_view node_name, int port_id) const {
return InputPort(GetNode(node_name), port_id);
}
// Get the specified output port. Note that the special '-1' port_id can be
// Gets the specified output port. Note that the special '-1' port_id can be
// used to access the controlled nodes (i.e. the nodes connected to node_name
// through an outgoing control dependency).
OutputPort GetOutputPort(absl::string_view node_name, int port_id) const {
return OutputPort(GetNode(node_name), port_id);
}
// Get the input (resp. output) port(s) in the immediate fanout (resp. fanin)
// of an output (resp. input) port.
// Gets the input port(s) in the immediate fanout of an output port.
const absl::flat_hash_set<InputPort>& GetFanout(
const OutputPort& port) const {
return gtl::FindWithDefault(fanouts_, port, fanout_not_found_value_);
}
// Gets the output port(s) in the immediate fanin of an input port.
absl::flat_hash_set<OutputPort> GetFanin(const InputPort& port) const {
if (port.port_id >= 0) return {GetRegularFanin(port)};
......@@ -162,9 +167,22 @@ class GraphViewInternal {
return GetOutputPort(tensor_id.node(), tensor_id.index());
}
// Get all the input (resp. output) ports in the immediate fanout (resp
// fanin) of a node. Include the controlling nodes iff
// include_controlling_nodes is true.
// Checks if a tensor id is a fanin of the node.
bool HasFanin(const NodeDef& node, const TensorId& fanin) const {
if (fanin.index() < -1) {
return false;
}
string fanin_string = TensorIdToString(fanin);
for (int i = 0; i < node.input_size(); ++i) {
if (node.input(i) == fanin_string) {
return true;
}
}
return false;
}
// Gets all the input ports in the immediate fanout of a node. Include the
// controlled nodes iff include_controlled_nodes is true.
absl::flat_hash_set<InputPort> GetFanouts(
const NodeDef& node, bool include_controlled_nodes) const {
absl::flat_hash_set<InputPort> result;
......@@ -185,6 +203,8 @@ class GraphViewInternal {
return result;
}
// Gets all the output ports in the immediate fanin of a node. Include the
// controlling nodes iff include_controlling_nodes is true.
absl::flat_hash_set<OutputPort> GetFanins(
const NodeDef& node, bool include_controlling_nodes) const {
absl::flat_hash_set<OutputPort> result;
......@@ -198,7 +218,7 @@ class GraphViewInternal {
return result;
}
// Get the number of ports in the immediate fanin of a node. Count the
// Gets the number of ports in the immediate fanin of a node. Count the
// controlling nodes iff include_controlling_nodes is true.
int NumFanins(const NodeDef& node, bool include_controlling_nodes) const {
int count = 0;
......@@ -211,14 +231,14 @@ class GraphViewInternal {
return count;
}
// Get the number of ports in the immediate fanout of a node. Count the
// controlling nodes iff include_controlling_nodes is true.
int NumFanouts(const NodeDef& node, bool include_controlling_nodes) const {
// Gets the number of ports in the immediate fanout of a node. Count the
// controlled nodes iff include_controlled_nodes is true.
int NumFanouts(const NodeDef& node, bool include_controlled_nodes) const {
int count = 0;
OutputPort port;
port.node = const_cast<NodeDefT*>(&node);
const int first_port_id = include_controlling_nodes ? -1 : 0;
const int first_port_id = include_controlled_nodes ? -1 : 0;
const int last_port_id =
gtl::FindWithDefault(max_regular_output_port_, port.node, -1);
......@@ -231,8 +251,8 @@ class GraphViewInternal {
return count;
}
// Get all the edges in the immediate fanout (resp fanin) of a node.
// Include the control edges iff include_controlling_edges is true.
// Gets all the edges in the immediate fanout of a node. Include the
// controlled edges iff include_controlled_edges is true.
absl::flat_hash_set<Edge> GetFanoutEdges(
const NodeDef& node, bool include_controlled_edges) const {
absl::flat_hash_set<Edge> result;
......@@ -248,14 +268,16 @@ class GraphViewInternal {
auto it = fanouts_.find(port);
if (it != fanouts_.end()) {
for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) {
result.emplace(/*src*/ OutputPort(const_cast<NodeDefT*>(&node), i),
/*dst*/ *itr);
result.emplace(/*src=*/OutputPort(const_cast<NodeDefT*>(&node), i),
/*dst=*/*itr);
}
}
}
return result;
}
// Gets all the edges in the immediate fanin of a node. Include the
// controlling edges iff include_controlling_edges is true.
absl::flat_hash_set<Edge> GetFaninEdges(
const NodeDef& node, bool include_controlling_edges) const {
absl::flat_hash_set<Edge> result;
......@@ -265,8 +287,8 @@ class GraphViewInternal {
auto it = nodes_.find(tensor_id.node());
if (it != nodes_.end()) {
result.emplace(/*src*/ OutputPort(it->second, tensor_id.index()),
/*dst*/ InputPort(const_cast<NodeDefT*>(&node), i));
result.emplace(/*src=*/OutputPort(it->second, tensor_id.index()),
/*dst=*/InputPort(const_cast<NodeDefT*>(&node), i));
}
}
return result;
......
......@@ -230,6 +230,40 @@ TEST_F(GraphViewTest, ControlDependencies) {
EXPECT_EQ(0, (*fanin.begin()).port_id);
}
TEST_F(GraphViewTest, HasNode) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphView graph(&item.graph);
EXPECT_EQ(true, graph.HasNode("a"));
EXPECT_EQ(false, graph.HasNode("b"));
}
TEST_F(GraphViewTest, HasFanin) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
Output b = ops::Square(s.WithOpName("b"), {a});
Output c = ops::Sqrt(s.WithOpName("c"), {b});
Output d = ops::AddN(s.WithOpName("d").WithControlDependencies(a), {b, c});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphView graph(&item.graph);
const NodeDef* d_node = graph.GetNode("d");
EXPECT_NE(nullptr, d_node);
EXPECT_EQ(true, graph.HasFanin(*d_node, {"a", Graph::kControlSlot}));
EXPECT_EQ(false, graph.HasFanin(*d_node, {"a", 0}));
EXPECT_EQ(true, graph.HasFanin(*d_node, {"b", 0}));
EXPECT_EQ(false, graph.HasFanin(*d_node, {"b", Graph::kControlSlot}));
EXPECT_EQ(true, graph.HasFanin(*d_node, {"c", 0}));
EXPECT_EQ(false, graph.HasFanin(*d_node, {"c", Graph::kControlSlot}));
}
} // namespace
} // namespace grappler
} // namespace tensorflow
......@@ -144,11 +144,16 @@ void NodeMap::UpdateOutput(const string& node_name,
outputs.insert(nodes_[NodeName(new_output_name)]);
}
string TensorIdToString(const TensorId& tensor_id) {
return tensor_id.index() == 0 ? string(tensor_id.node())
: tensor_id.ToString();
}
bool IsSameInput(const string& name1, const string& name2) {
if (name1 == name2) return true;
TensorId tensor1 = ParseTensorName(name1);
TensorId tensor2 = ParseTensorName(name2);
return tensor1.node() == tensor2.node() && tensor1.index() == tensor2.index();
return tensor1 == tensor2;
}
bool IsControlInput(const string& name) {
......
......@@ -100,6 +100,10 @@ class SetVector {
std::vector<T> vector_;
};
// Returns formatted string from TensorId specific to grappler. Specifically,
// for the 0 port (first output), only the node name is returned.
string TensorIdToString(const TensorId& tensor_id);
// True iff 'name' refers to a control inputs, i.e. a node name prefixed with
// the ^ character.
bool IsControlInput(const string& name);
......
......@@ -464,6 +464,13 @@ TEST_F(UtilsTest, SetTensorValueBFloat16IntMin) {
Tensor(bfloat16(std::numeric_limits<int>::min())), t);
}
TEST_F(UtilsTest, TensorIdToString) {
EXPECT_EQ("^foo", TensorIdToString({"foo", -1}));
EXPECT_EQ("foo", TensorIdToString({"foo", 0}));
EXPECT_EQ("foo:1", TensorIdToString({"foo", 1}));
EXPECT_EQ("foo:2", TensorIdToString({"foo", 2}));
}
} // namespace
} // namespace grappler
} // namespace tensorflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册