提交 d15c612f 编写于 作者: A Andy Ly 提交者: TensorFlower Gardener

[Grappler] Migrate FrameView to use utils::GraphView/utils::MutableGraphView.

PiperOrigin-RevId: 251659253
上级 cd09510f
......@@ -498,6 +498,7 @@ cc_library(
"//tensorflow/core/grappler:devices",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/costs:virtual_placer",
......@@ -703,6 +704,7 @@ cc_library(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:frame",
"//tensorflow/core/grappler/utils:traversal",
......@@ -724,6 +726,7 @@ tf_cuda_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
"//tensorflow/core/grappler/utils:graph_view",
"//tensorflow/core/grappler/utils:grappler_test",
],
)
......@@ -883,6 +886,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:frame",
],
......
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include <deque>
#include <unordered_set>
......@@ -28,7 +30,7 @@ limitations under the License.
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
......
......@@ -17,8 +17,10 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_
#include <unordered_set>
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
......
......@@ -14,12 +14,14 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
......@@ -104,26 +106,42 @@ TEST_F(LoopOptimizerTest, Basic) {
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
{ // Original graph.
GraphView view(&graph);
Status status;
utils::GraphView view(&graph, &status);
TF_ASSERT_OK(status);
FrameView frames;
TF_EXPECT_OK(frames.InferFromGraphView(view));
EXPECT_EQ(frames.num_frames(), 1);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 1);
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).back(), 0);
ASSERT_EQ(frames.Frames(*view.GetNode("VariantAdd")).size(), 1);
EXPECT_EQ(frames.Frames(*view.GetNode("VariantAdd")).back(), 0);
const auto* invariant_add_node = view.GetNode("InvariantAdd");
ASSERT_NE(invariant_add_node, nullptr);
const auto* invariant_add_node_def = invariant_add_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
const auto* variant_add_node = view.GetNode("VariantAdd");
ASSERT_NE(variant_add_node, nullptr);
const auto* variant_add_node_def = variant_add_node->node();
ASSERT_EQ(frames.Frames(*variant_add_node_def).size(), 1);
EXPECT_EQ(frames.Frames(*variant_add_node_def).back(), 0);
}
{ // Optimized graph.
GraphView view(&output);
Status status;
utils::GraphView view(&output, &status);
TF_ASSERT_OK(status);
FrameView frames;
TF_EXPECT_OK(frames.InferFromGraphView(view));
EXPECT_EQ(frames.num_frames(), 1);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 0);
ASSERT_EQ(frames.Frames(*view.GetNode("VariantAdd")).size(), 1);
EXPECT_EQ(frames.Frames(*view.GetNode("VariantAdd")).back(), 0);
const auto* invariant_add_node = view.GetNode("InvariantAdd");
ASSERT_NE(invariant_add_node, nullptr);
const auto* invariant_add_node_def = invariant_add_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 0);
const auto* variant_add_node = view.GetNode("VariantAdd");
ASSERT_NE(variant_add_node, nullptr);
const auto* variant_add_node_def = variant_add_node->node();
ASSERT_EQ(frames.Frames(*variant_add_node_def).size(), 1);
EXPECT_EQ(frames.Frames(*variant_add_node_def).back(), 0);
}
}
......@@ -155,25 +173,41 @@ TEST_F(LoopOptimizerTest, Const) {
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
{ // Original graph.
GraphView view(&graph);
Status status;
utils::GraphView view(&graph, &status);
TF_ASSERT_OK(status);
FrameView frames;
TF_EXPECT_OK(frames.InferFromGraphView(view));
EXPECT_EQ(frames.num_frames(), 1);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 1);
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).back(), 0);
ASSERT_EQ(frames.Frames(*view.GetNode("Const")).size(), 1);
EXPECT_EQ(frames.Frames(*view.GetNode("Const")).back(), 0);
const auto* invariant_add_node = view.GetNode("InvariantAdd");
ASSERT_NE(invariant_add_node, nullptr);
const auto* invariant_add_node_def = invariant_add_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
const auto* const_node = view.GetNode("Const");
ASSERT_NE(const_node, nullptr);
const auto* const_node_node_def = const_node->node();
ASSERT_EQ(frames.Frames(*const_node_node_def).size(), 1);
EXPECT_EQ(frames.Frames(*const_node_node_def).back(), 0);
}
{ // Optimized graph.
GraphView view(&output);
Status status;
utils::GraphView view(&output, &status);
TF_ASSERT_OK(status);
FrameView frames;
TF_EXPECT_OK(frames.InferFromGraphView(view));
EXPECT_EQ(frames.num_frames(), 1);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 0);
ASSERT_EQ(frames.Frames(*view.GetNode("Const")).size(), 0);
const auto* invariant_add_node = view.GetNode("InvariantAdd");
ASSERT_NE(invariant_add_node, nullptr);
const auto* invariant_add_node_def = invariant_add_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 0);
const auto* const_node = view.GetNode("Const");
ASSERT_NE(const_node, nullptr);
const auto* const_node_node_def = const_node->node();
ASSERT_EQ(frames.Frames(*const_node_node_def).size(), 0);
}
}
......@@ -206,23 +240,33 @@ TEST_F(LoopOptimizerTest, ControlOutput) {
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
{ // Original graph.
GraphView view(&graph);
Status status;
utils::GraphView view(&graph, &status);
TF_ASSERT_OK(status);
FrameView frames;
TF_EXPECT_OK(frames.InferFromGraphView(view));
EXPECT_EQ(frames.num_frames(), 1);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 1);
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).back(), 0);
const auto* invariant_add_node = view.GetNode("InvariantAdd");
ASSERT_NE(invariant_add_node, nullptr);
const auto* invariant_add_node_def = invariant_add_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
}
{ // Optimized graph.
GraphView view(&output);
Status status;
utils::GraphView view(&output, &status);
TF_ASSERT_OK(status);
FrameView frames;
TF_EXPECT_OK(frames.InferFromGraphView(view));
EXPECT_EQ(frames.num_frames(), 1);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 1);
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).back(), 0);
const auto* invariant_add_node = view.GetNode("InvariantAdd");
ASSERT_NE(invariant_add_node, nullptr);
const auto* invariant_add_node_def = invariant_add_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
}
}
......@@ -270,30 +314,52 @@ TEST_F(LoopOptimizerTest, NestedLoop1) {
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
{ // Original graph.
GraphView view(&graph);
Status status;
utils::GraphView view(&graph, &status);
TF_ASSERT_OK(status);
FrameView frames;
TF_EXPECT_OK(frames.InferFromGraphView(view));
EXPECT_EQ(frames.num_frames(), 2);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 2);
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).back(), 1);
ASSERT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).size(), 2);
EXPECT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).back(), 1);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 1);
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).back(), 0);
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
ASSERT_NE(invariant_add_2_node, nullptr);
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
const auto* variant_add_2_node = view.GetNode("VariantAdd2");
ASSERT_NE(variant_add_2_node, nullptr);
const auto* variant_add_2_node_def = variant_add_2_node->node();
ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
const auto* invariant_add_node = view.GetNode("InvariantAdd");
ASSERT_NE(invariant_add_node, nullptr);
const auto* invariant_add_node_def = invariant_add_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 1);
EXPECT_EQ(frames.Frames(*invariant_add_node_def).back(), 0);
}
{ // Optimized graph.
GraphView view(&output);
Status status;
utils::GraphView view(&output, &status);
TF_ASSERT_OK(status);
FrameView frames;
TF_EXPECT_OK(frames.InferFromGraphView(view));
EXPECT_EQ(frames.num_frames(), 2);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 1);
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).back(), 0);
ASSERT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).size(), 2);
EXPECT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).back(), 1);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd")).size(), 0);
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
ASSERT_NE(invariant_add_2_node, nullptr);
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 1);
EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 0);
const auto* variant_add_2_node = view.GetNode("VariantAdd2");
ASSERT_NE(variant_add_2_node, nullptr);
const auto* variant_add_2_node_def = variant_add_2_node->node();
ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
const auto* invariant_add_node = view.GetNode("InvariantAdd");
ASSERT_NE(invariant_add_node, nullptr);
const auto* invariant_add_node_def = invariant_add_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_node_def).size(), 0);
}
}
......@@ -341,26 +407,42 @@ TEST_F(LoopOptimizerTest, NestedLoop2) {
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
{ // Original graph.
GraphView view(&graph);
Status status;
utils::GraphView view(&graph, &status);
TF_ASSERT_OK(status);
FrameView frames;
TF_EXPECT_OK(frames.InferFromGraphView(view));
EXPECT_EQ(frames.num_frames(), 2);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 2);
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).back(), 1);
ASSERT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).size(), 2);
EXPECT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).back(), 1);
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
ASSERT_NE(invariant_add_2_node, nullptr);
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
const auto* variant_add_2_node = view.GetNode("VariantAdd2");
ASSERT_NE(variant_add_2_node, nullptr);
const auto* variant_add_2_node_def = variant_add_2_node->node();
ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
}
{ // Optimized graph.
GraphView view(&output);
Status status;
utils::GraphView view(&output, &status);
TF_ASSERT_OK(status);
FrameView frames;
TF_EXPECT_OK(frames.InferFromGraphView(view));
EXPECT_EQ(frames.num_frames(), 2);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 0);
ASSERT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).size(), 2);
EXPECT_EQ(frames.Frames(*view.GetNode("VariantAdd2")).back(), 1);
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
ASSERT_NE(invariant_add_2_node, nullptr);
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 0);
const auto* variant_add_2_node = view.GetNode("VariantAdd2");
ASSERT_NE(variant_add_2_node, nullptr);
const auto* variant_add_2_node_def = variant_add_2_node->node();
ASSERT_EQ(frames.Frames(*variant_add_2_node_def).size(), 2);
EXPECT_EQ(frames.Frames(*variant_add_2_node_def).back(), 1);
}
}
......@@ -408,27 +490,43 @@ TEST_F(LoopOptimizerTest, NestedLoopConst1) {
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
{ // Original graph.
GraphView view(&graph);
Status status;
utils::GraphView view(&graph, &status);
TF_ASSERT_OK(status);
FrameView frames;
TF_EXPECT_OK(frames.InferFromGraphView(view));
EXPECT_EQ(frames.num_frames(), 2);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 2);
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).back(), 1);
ASSERT_EQ(frames.Frames(*view.GetNode("Const2")).size(), 2);
EXPECT_EQ(frames.Frames(*view.GetNode("Const2")).back(), 1);
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
ASSERT_NE(invariant_add_2_node, nullptr);
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
const auto* const_2_node = view.GetNode("Const2");
ASSERT_NE(const_2_node, nullptr);
const auto* const_2_node_def = const_2_node->node();
ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 2);
EXPECT_EQ(frames.Frames(*const_2_node_def).back(), 1);
}
{ // Optimized graph.
GraphView view(&output);
Status status;
utils::GraphView view(&output, &status);
TF_ASSERT_OK(status);
FrameView frames;
TF_EXPECT_OK(frames.InferFromGraphView(view));
EXPECT_EQ(frames.num_frames(), 2);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 1);
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).back(), 0);
ASSERT_EQ(frames.Frames(*view.GetNode("Const2")).size(), 1);
EXPECT_EQ(frames.Frames(*view.GetNode("Const2")).back(), 0);
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
ASSERT_NE(invariant_add_2_node, nullptr);
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 1);
EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 0);
const auto* const_2_node = view.GetNode("Const2");
ASSERT_NE(const_2_node, nullptr);
const auto* const_2_node_def = const_2_node->node();
ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 1);
EXPECT_EQ(frames.Frames(*const_2_node_def).back(), 0);
}
}
......@@ -476,25 +574,41 @@ TEST_F(LoopOptimizerTest, NestedLoopConst2) {
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
{ // Original graph.
GraphView view(&graph);
Status status;
utils::GraphView view(&graph, &status);
TF_ASSERT_OK(status);
FrameView frames;
TF_EXPECT_OK(frames.InferFromGraphView(view));
EXPECT_EQ(frames.num_frames(), 2);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 2);
EXPECT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).back(), 1);
ASSERT_EQ(frames.Frames(*view.GetNode("Const2")).size(), 2);
EXPECT_EQ(frames.Frames(*view.GetNode("Const2")).back(), 1);
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
ASSERT_NE(invariant_add_2_node, nullptr);
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 2);
EXPECT_EQ(frames.Frames(*invariant_add_2_node_def).back(), 1);
const auto* const_2_node = view.GetNode("Const2");
ASSERT_NE(const_2_node, nullptr);
const auto* const_2_node_def = const_2_node->node();
ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 2);
EXPECT_EQ(frames.Frames(*const_2_node_def).back(), 1);
}
{ // Optimized graph.
GraphView view(&output);
Status status;
utils::GraphView view(&output, &status);
TF_ASSERT_OK(status);
FrameView frames;
TF_EXPECT_OK(frames.InferFromGraphView(view));
EXPECT_EQ(frames.num_frames(), 2);
ASSERT_EQ(frames.Frames(*view.GetNode("InvariantAdd2")).size(), 0);
ASSERT_EQ(frames.Frames(*view.GetNode("Const2")).size(), 0);
const auto* invariant_add_2_node = view.GetNode("InvariantAdd2");
ASSERT_NE(invariant_add_2_node, nullptr);
const auto* invariant_add_2_node_def = invariant_add_2_node->node();
ASSERT_EQ(frames.Frames(*invariant_add_2_node_def).size(), 0);
const auto* const_2_node = view.GetNode("Const2");
ASSERT_NE(const_2_node, nullptr);
const auto* const_2_node_def = const_2_node->node();
ASSERT_EQ(frames.Frames(*const_2_node_def).size(), 0);
}
}
......
......@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
......
......@@ -78,10 +78,10 @@ cc_library(
hdrs = ["frame.h"],
visibility = ["//visibility:public"],
deps = [
":graph_view",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:op_types",
"@com_google_absl//absl/container:flat_hash_map",
],
......@@ -93,6 +93,8 @@ tf_cc_test(
srcs = ["frame_test.cc"],
deps = [
":frame",
":graph_view",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
......
......@@ -14,10 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/frame.h"
#include <deque>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/lib/core/errors.h"
......@@ -26,101 +27,134 @@ namespace grappler {
namespace {} // namespace
Status FrameView::InferFromGraphView(const GraphView& graph_view) {
template <typename GraphViewT>
inline Status FrameView::InferFromGraphViewT(const GraphViewT& graph_view) {
if (is_inferred_) {
return errors::Internal("FrameView was already inferred from the graph");
}
is_inferred_ = true;
std::deque<const NodeDef*> ready_nodes;
std::deque<int> ready_node_indices;
// All nodes without inputs are automatically added to the ready queue.
for (const NodeDef& node : graph_view.graph()->node()) {
if (node.input_size() == 0) {
ready_nodes.push_back(&node);
node_to_frames_[&node] = node_has_no_frames_;
for (const auto& node : graph_view.GetNodes()) {
if (node.NumRegularFanins() + node.NumControllingFanins() == 0) {
ready_node_indices.push_back(node.node_index());
node_to_frames_[node.node()] = node_has_no_frames_;
}
}
const auto* graph = graph_view.graph();
// We assign unique int id to each frame, and use this map to track what
// frames we've already seen in the graph.
absl::flat_hash_map<string, int> frame_name_to_id;
while (!ready_nodes.empty()) {
const NodeDef* ready_node = ready_nodes.front();
absl::flat_hash_set<GraphView::InputPort> fanouts =
graph_view.GetFanouts(*ready_node, /*include_controlled_nodes=*/true);
auto process_fanout = [this, graph](
absl::flat_hash_map<string, int>* frame_name_to_id,
std::deque<int>* ready_node_indices,
const NodeDef* ready_node, int fanout_node_index) {
const NodeDef* fanout_node = &graph->node(fanout_node_index);
if (!node_to_frames_.contains(fanout_node)) {
// If we have never seen this node before, we add all frames from the
// incoming node (and pop/push frames if coming from Exit/Enter nodes).
std::vector<int> frame_ids = node_to_frames_[ready_node];
if (IsExit(*ready_node)) {
frame_ids.pop_back();
}
for (const GraphView::InputPort& fanout : fanouts) {
if (node_to_frames_.count(fanout.node) < 1) {
// If we have never seen this node before, we add all frames from the
// incoming node (and pop/push frames if coming from Exit/Enter nodes).
std::vector<int> frame_ids = node_to_frames_[ready_node];
if (IsEnter(*fanout_node)) {
const AttrValue* frame_name_attr =
AttrSlice(*fanout_node).Find("frame_name");
if (IsExit(*ready_node)) {
frame_ids.pop_back();
if (!frame_name_attr) {
return errors::InvalidArgument(
"Missing frame name for the Enter node: ",
SummarizeNodeDef(*fanout_node));
}
if (IsEnter(*fanout.node)) {
const AttrValue* frame_name_attr =
AttrSlice(*fanout.node).Find("frame_name");
const string& frame_name = frame_name_attr->s();
int frame_id;
if (!frame_name_attr) {
return errors::InvalidArgument(
"Missing frame name for the Enter node: ",
SummarizeNodeDef(*fanout.node));
}
absl::string_view frame_name = frame_name_attr->s();
int frame_id;
if (frame_name_to_id.count(frame_name)) {
frame_id = frame_name_to_id[frame_name];
} else {
frame_id = static_cast<int>(frame_name_to_id.size());
frame_name_to_id[frame_name] = frame_id;
}
frame_ids.push_back(frame_id);
if (frame_name_to_id->contains(frame_name)) {
frame_id = (*frame_name_to_id)[frame_name];
} else {
frame_id = static_cast<int>(frame_name_to_id->size());
(*frame_name_to_id)[frame_name] = frame_id;
}
ready_nodes.push_back(fanout.node);
node_to_frames_[fanout.node] = std::move(frame_ids);
frame_ids.push_back(frame_id);
}
} else {
// If we've already seen this node before, we need to make sure that
// graph is correct and same nodes doesn't have incoming edges with
// conflicting frames (all inputs must be produces in the same frame).
ready_node_indices->push_back(fanout_node_index);
node_to_frames_[fanout_node] = std::move(frame_ids);
std::vector<int> frame_ids_fanout = node_to_frames_[fanout.node];
std::vector<int> frame_ids_node = node_to_frames_[ready_node];
} else {
// If we've already seen this node before, we need to make sure that graph
// is correct and same nodes doesn't have incoming edges with conflicting
// frames (all inputs must be produces in the same frame).
if (IsEnter(*fanout.node)) {
frame_ids_fanout.pop_back();
}
if (IsExit(*ready_node)) {
frame_ids_node.pop_back();
}
std::vector<int> frame_ids_fanout = node_to_frames_[fanout_node];
std::vector<int> frame_ids_node = node_to_frames_[ready_node];
if (frame_ids_node != frame_ids_fanout) {
return errors::InvalidArgument(
"Invalid graph: Frame ids for node ", ready_node->name(),
" does not match frame ids for it's fanout ",
fanout.node->name());
}
if (IsEnter(*fanout_node)) {
frame_ids_fanout.pop_back();
}
if (IsExit(*ready_node)) {
frame_ids_node.pop_back();
}
if (frame_ids_node != frame_ids_fanout) {
return errors::InvalidArgument(
"Invalid graph: Frame ids for node ", ready_node->name(),
" does not match frame ids for it's fanout ", fanout_node->name());
}
}
return Status::OK();
};
while (!ready_node_indices.empty()) {
const int ready_node_index = ready_node_indices.front();
ready_node_indices.pop_front();
const auto* ready_node_view = graph_view.GetNode(ready_node_index);
const NodeDef* ready_node_def = ready_node_view->node();
for (const auto& regular_fanouts_port_i :
ready_node_view->GetRegularFanouts()) {
for (const auto& regular_fanout : regular_fanouts_port_i) {
TF_RETURN_IF_ERROR(process_fanout(&frame_name_to_id,
&ready_node_indices, ready_node_def,
regular_fanout.node_index()));
}
}
ready_nodes.pop_front();
for (const auto& controlled_fanout :
ready_node_view->GetControlledFanouts()) {
TF_RETURN_IF_ERROR(process_fanout(&frame_name_to_id, &ready_node_indices,
ready_node_def,
controlled_fanout.node_index()));
}
}
num_frames_ = static_cast<int>(frame_name_to_id.size());
return Status::OK();
}
Status FrameView::InferFromGraphView(const utils::GraphView& graph_view) {
return InferFromGraphViewT(graph_view);
}
Status FrameView::InferFromGraphView(
const utils::MutableGraphView& graph_view) {
return InferFromGraphViewT(graph_view);
}
Status FrameView::InferFromGraph(const GraphDef& graph) {
return InferFromGraphView(GraphView(&graph));
Status status;
utils::GraphView graph_view(&graph, &status);
TF_RETURN_IF_ERROR(status);
return InferFromGraphViewT(graph_view);
}
const std::vector<int>& FrameView::Frames(const NodeDef& node) const {
......
......@@ -16,10 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
#define TENSORFLOW_CORE_GRAPPLER_UTILS_FRAME_H_
#include <unordered_map>
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
......@@ -40,7 +39,10 @@ class FrameView {
// Infers nodes execution frames from the GraphView. Returns an error if
// called multiple times.
Status InferFromGraphView(const GraphView& graph_view);
Status InferFromGraphView(const utils::GraphView& graph_view);
// Infers nodes execution frames from the MutableGraphView. Returns an error
// if called multiple times.
Status InferFromGraphView(const utils::MutableGraphView& graph_view);
// Infers nodes execution by constructing temporary GraphView and passing it
// to InferFromGraphView.
Status InferFromGraph(const GraphDef& graph);
......@@ -56,6 +58,9 @@ class FrameView {
bool is_inferred() const { return is_inferred_; }
private:
template <typename GraphViewT>
inline Status InferFromGraphViewT(const GraphViewT& graph_view);
bool is_inferred_; // true if it was inferred from the graph
int num_frames_; // number of frames present in a graph
absl::flat_hash_map<const NodeDef*, std::vector<int>> node_to_frames_;
......
......@@ -14,8 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/frame.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/utils/graph_view.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
......@@ -23,19 +26,23 @@ namespace tensorflow {
namespace grappler {
namespace {
using GraphTypes =
::testing::Types<GraphDef, utils::GraphView, utils::MutableGraphView>;
template <typename T>
class FrameViewTest : public ::testing::Test {
protected:
static NodeDef CreateNode(const string& name,
const std::vector<string>& inputs) {
NodeDef CreateNode(const string& name, const std::vector<string>& inputs) {
return CreateNode(name, "", "", inputs);
}
static NodeDef CreateNode(const string& name, const string& op,
const std::vector<string>& inputs) {
NodeDef CreateNode(const string& name, const string& op,
const std::vector<string>& inputs) {
return CreateNode(name, op, "", inputs);
}
static NodeDef CreateNode(const string& name, const string& op,
const string& frame,
const std::vector<string>& inputs) {
NodeDef CreateNode(const string& name, const string& op, const string& frame,
const std::vector<string>& inputs) {
NodeDef node;
node.set_name(name);
if (!op.empty()) {
......@@ -53,30 +60,56 @@ class FrameViewTest : public ::testing::Test {
}
};
TEST_F(FrameViewTest, NestedLoop) {
TYPED_TEST_SUITE(FrameViewTest, GraphTypes);
template <typename T>
void InferFromGraph(FrameView* frame_view, GraphDef* graph, bool valid) {
Status status;
T graph_view(graph, &status);
TF_ASSERT_OK(status);
status = frame_view->InferFromGraphView(graph_view);
if (valid) {
TF_ASSERT_OK(status);
} else {
ASSERT_FALSE(status.ok());
}
}
template <>
void InferFromGraph<GraphDef>(FrameView* frame_view, GraphDef* graph,
bool valid) {
Status status = frame_view->InferFromGraph(*graph);
if (valid) {
TF_ASSERT_OK(status);
} else {
ASSERT_FALSE(status.ok());
}
}
TYPED_TEST(FrameViewTest, NestedLoop) {
GraphDef graph;
// Create a two-level nested loop
*graph.add_node() = CreateNode("0", {});
*graph.add_node() = CreateNode("1", "Enter", "while/context1", {"0"});
*graph.add_node() = CreateNode("2", {"1"});
*graph.add_node() = CreateNode("3", "Merge", {"2", "14"});
*graph.add_node() = CreateNode("4", {"3"});
*graph.add_node() = CreateNode("5", "Switch", {"4"});
*graph.add_node() = CreateNode("6", {"5"});
*graph.add_node() = CreateNode("7", "Enter", "while/context2", {"6"});
*graph.add_node() = CreateNode("8", {"7"});
*graph.add_node() = CreateNode("9", "Merge", {"8", "12"});
*graph.add_node() = CreateNode("10", {"9"});
*graph.add_node() = CreateNode("11", "Switch", {"10"});
*graph.add_node() = CreateNode("12", "NextIteration", {"11"});
*graph.add_node() = CreateNode("13", "Exit", {"11"});
*graph.add_node() = CreateNode("14", "NextIteration", {"13"});
*graph.add_node() = CreateNode("15", {"5"});
*graph.add_node() = CreateNode("16", "Exit", {"15"});
*graph.add_node() = CreateNode("17", {"16"});
*graph.add_node() = this->CreateNode("0", {});
*graph.add_node() = this->CreateNode("1", "Enter", "while/context1", {"0"});
*graph.add_node() = this->CreateNode("2", {"1"});
*graph.add_node() = this->CreateNode("3", "Merge", {"2", "14"});
*graph.add_node() = this->CreateNode("4", {"3"});
*graph.add_node() = this->CreateNode("5", "Switch", {"4"});
*graph.add_node() = this->CreateNode("6", {"5"});
*graph.add_node() = this->CreateNode("7", "Enter", "while/context2", {"6"});
*graph.add_node() = this->CreateNode("8", {"7"});
*graph.add_node() = this->CreateNode("9", "Merge", {"8", "12"});
*graph.add_node() = this->CreateNode("10", {"9"});
*graph.add_node() = this->CreateNode("11", "Switch", {"10"});
*graph.add_node() = this->CreateNode("12", "NextIteration", {"11"});
*graph.add_node() = this->CreateNode("13", "Exit", {"11"});
*graph.add_node() = this->CreateNode("14", "NextIteration", {"13"});
*graph.add_node() = this->CreateNode("15", {"5"});
*graph.add_node() = this->CreateNode("16", "Exit", {"15"});
*graph.add_node() = this->CreateNode("17", {"16"});
FrameView frame_view;
ASSERT_TRUE(frame_view.InferFromGraph(graph).ok());
InferFromGraph<TypeParam>(&frame_view, &graph, /*valid=*/true);
std::unordered_map<string, std::vector<int>> expected = {
{"0", {}}, {"1", {0}}, {"2", {0}}, {"3", {0}},
......@@ -93,15 +126,16 @@ TEST_F(FrameViewTest, NestedLoop) {
}
}
TEST_F(FrameViewTest, MultipleInputsToEnter) {
TYPED_TEST(FrameViewTest, MultipleInputsToEnter) {
GraphDef graph;
*graph.add_node() = CreateNode("0", {});
*graph.add_node() = CreateNode("1", {});
*graph.add_node() = CreateNode("2", "Enter", "while/context", {"0", "1"});
*graph.add_node() = CreateNode("3", "Exit", {"2"});
*graph.add_node() = this->CreateNode("0", {});
*graph.add_node() = this->CreateNode("1", {});
*graph.add_node() =
this->CreateNode("2", "Enter", "while/context", {"0", "1"});
*graph.add_node() = this->CreateNode("3", "Exit", {"2"});
FrameView frame_view;
ASSERT_TRUE(frame_view.InferFromGraph(graph).ok());
InferFromGraph<TypeParam>(&frame_view, &graph, /*valid=*/true);
std::unordered_map<string, std::vector<int>> expected = {
{"0", {}}, {"1", {}}, {"2", {0}}, {"3", {0}}};
......@@ -114,16 +148,16 @@ TEST_F(FrameViewTest, MultipleInputsToEnter) {
}
}
TEST_F(FrameViewTest, ExitOutput) {
TYPED_TEST(FrameViewTest, ExitOutput) {
GraphDef graph;
*graph.add_node() = CreateNode("0", {});
*graph.add_node() = CreateNode("1", "Enter", "while/context", {"0"});
*graph.add_node() = CreateNode("2", "Exit", {"1"});
*graph.add_node() = CreateNode("3", {});
*graph.add_node() = CreateNode("4", {"2", "3"});
*graph.add_node() = this->CreateNode("0", {});
*graph.add_node() = this->CreateNode("1", "Enter", "while/context", {"0"});
*graph.add_node() = this->CreateNode("2", "Exit", {"1"});
*graph.add_node() = this->CreateNode("3", {});
*graph.add_node() = this->CreateNode("4", {"2", "3"});
FrameView frame_view;
ASSERT_TRUE(frame_view.InferFromGraph(graph).ok());
InferFromGraph<TypeParam>(&frame_view, &graph, /*valid=*/true);
std::unordered_map<string, std::vector<int>> expected = {
{"0", {}}, {"1", {0}}, {"2", {0}}, {"3", {}}, {"4", {}}};
......@@ -136,21 +170,21 @@ TEST_F(FrameViewTest, ExitOutput) {
}
}
TEST_F(FrameViewTest, MultipleEnterNodes) {
TYPED_TEST(FrameViewTest, MultipleEnterNodes) {
GraphDef graph;
*graph.add_node() = CreateNode("0", {});
*graph.add_node() = CreateNode("1", "Enter", "while/context", {"0"});
*graph.add_node() = CreateNode("2", {"1"});
*graph.add_node() = CreateNode("5", {});
*graph.add_node() = CreateNode("4", "Enter", "while/context", {"5"});
*graph.add_node() = CreateNode("3", {"4", "2"});
*graph.add_node() = CreateNode("6", "Merge", {"3", "8"});
*graph.add_node() = CreateNode("7", "Switch", {"6"});
*graph.add_node() = CreateNode("8", "NextIteration", {"7"});
*graph.add_node() = CreateNode("9", "Exit", {"7"});
*graph.add_node() = this->CreateNode("0", {});
*graph.add_node() = this->CreateNode("1", "Enter", "while/context", {"0"});
*graph.add_node() = this->CreateNode("2", {"1"});
*graph.add_node() = this->CreateNode("5", {});
*graph.add_node() = this->CreateNode("4", "Enter", "while/context", {"5"});
*graph.add_node() = this->CreateNode("3", {"4", "2"});
*graph.add_node() = this->CreateNode("6", "Merge", {"3", "8"});
*graph.add_node() = this->CreateNode("7", "Switch", {"6"});
*graph.add_node() = this->CreateNode("8", "NextIteration", {"7"});
*graph.add_node() = this->CreateNode("9", "Exit", {"7"});
FrameView frame_view;
ASSERT_TRUE(frame_view.InferFromGraph(graph).ok());
InferFromGraph<TypeParam>(&frame_view, &graph, /*valid=*/true);
std::unordered_map<string, std::vector<int>> expected = {
{"0", {}}, {"1", {0}}, {"2", {0}}, {"3", {0}}, {"4", {0}},
......@@ -164,15 +198,15 @@ TEST_F(FrameViewTest, MultipleEnterNodes) {
}
}
TEST_F(FrameViewTest, ConflictingFrames) {
TYPED_TEST(FrameViewTest, ConflictingFrames) {
GraphDef graph;
*graph.add_node() = CreateNode("0", {});
*graph.add_node() = CreateNode("1", "Enter", "while/context1", {"0"});
*graph.add_node() = CreateNode("2", "Enter", "while/context2", {"1"});
*graph.add_node() = CreateNode("3", {"1", "2"});
*graph.add_node() = this->CreateNode("0", {});
*graph.add_node() = this->CreateNode("1", "Enter", "while/context1", {"0"});
*graph.add_node() = this->CreateNode("2", "Enter", "while/context2", {"1"});
*graph.add_node() = this->CreateNode("3", {"1", "2"});
FrameView frame_view;
ASSERT_FALSE(frame_view.InferFromGraph(graph).ok());
InferFromGraph<TypeParam>(&frame_view, &graph, /*valid=*/false);
}
} // namespace
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册