Commit 4d810848 authored by Tong Shen, committed by TensorFlower Gardener

Handle outside compilation at beginning/end of TPU computation.

PiperOrigin-RevId: 225396866
Parent 150b4c8e
......@@ -27,51 +27,13 @@ namespace tensorflow {
// a list of PartialTensorShape objects.
extern const char kXlaInferredShapesAttrName[];
// Infer output shapes for outside compilation nodes which have output data
// edges to XLA computation nodes. These shapes will be used later by the XLA
// compiler as output shapes of the outside compilation's XlaHostCompute op.
// XLA computation nodes will be marked by attr `xla_computation_attr_name`;
// outside compilation nodes will be marked by both attrs
// `xla_computation_attr_name` and `outside_compilation_attr_name`.
//
// Those outside compilation nodes will be marked with attribute
// `kXlaInferredShapesAttrName`.
// Infers output shapes for all nodes in graph `g`. The output shapes will be
// stored in node attribute `kXlaInferredShapesAttrName`.
//
// We have to perform shape inference before encapsulation because after
// encapsulation, some nodes will be encapsulated into function calls, and
// shape inference does not handle function calls at the moment.
Status PerformStaticShapeInferenceBeforeEncapsulation(
Graph* g, const string& xla_computation_attr_name,
const string& outside_compilation_attr_name);
// Attribute indicating that some ops in other XLA computations have control
// dependencies on this node. Attribute value will be a list of strings (XLA
// computation names).
extern const char kXlaConnectedToOtherXlaComputationAttrName[];
// Attribute indicating that this node has control dependencies on some ops in
// other XLA computations. Attribute value will be a list of strings (XLA
// computation names).
extern const char kXlaConnectedFromOtherXlaComputationAttrName[];
// Attribute indicating that this node has control dependencies on some other
// nodes. Attribute value will be a list of strings (node names).
extern const char kXlaControlDependenciesAttrName[];
// Attribute indicating that this is an Identity node added to act as a bridge
// between different XLA computations. Attribute value will be a string (source
// node name).
extern const char kBridgeSourceNodeAttrName[];
// Attribute indicating that this is a Placeholder node added to act as a
// temporary input node for an outside compilation node. Attribute value will
// be a string (original input node name).
extern const char kOutsideCompilationToHostOriginalNodeAttrName[];
// Attribute indicating that this is a Placeholder node added to act as a
// temporary input node for an outside compilation node. Attribute value will
// be an int (src_output for the original edge).
extern const char kOutsideCompilationToHostSrcOutputAttrName[];
Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g);
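// A minimal usage sketch of the simplified single-argument API (hypothetical;
// it mirrors the test further below and assumes the graph has an "add" node):
//
//   Graph g(OpRegistry::Global());
//   TF_CHECK_OK(s.ToGraph(&g));  // `s` is a tensorflow::Scope with some ops.
//   TF_CHECK_OK(PerformStaticShapeInferenceBeforeEncapsulation(&g));
//   std::vector<PartialTensorShape> shapes;
//   Node* add = g.BuildNodeNameIndex()["add"];
//   TF_CHECK_OK(GetNodeAttr(add->attrs(), kXlaInferredShapesAttrName, &shapes));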
// Attribute indicating that some ops in this node's XLA computation have
// control dependencies on this node. Attribute value will always be "true".
......@@ -81,16 +43,6 @@ extern const char kXlaConnectedToXlaComputationAttrName[];
// this node's XLA computation. Attribute value will always be "true".
extern const char kXlaConnectedFromXlaComputationAttrName[];
// Attribute indicating that this is a Placeholder node added to act as a
// temporary input node for a host node. Attribute value will be a string
// (original input node name).
extern const char kHostToOutsideCompilationOriginalNodeAttrName[];
// Attribute indicating that this is a Placeholder node added to act as a
// temporary input node for a host node. Attribute value will be an int
// (src_output for the original edge).
extern const char kHostToOutsideCompilationSrcOutputAttrName[];
// Attribute indicating that this is a Placeholder node added to act as a
// temporary input node for an outside compilation node. Attribute value will
// be a string (original input node name).
......@@ -106,27 +58,6 @@ extern const char kOutsideCompilationSrcOutputAttrName[];
// (node names).
extern const char kXlaControlDependenciesWithinXlaClusterAttrName[];
// Preprocesses edges between different XLA clusters for encapsulation. It will
// perform the following operations in order:
//
// 1a. For control edges between outside compilation and another XLA
// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName
// = XLA computation node name" to the outside compilation node.
// 1b. For control edges between different outside compilations (in different
// XLA computations), remove the edge and add attr
// "kXlaControlDependenciesAttrName = src node name" to dst node.
// 1c. For control edges between outside compilation and host computation,
// remove the edge and add attr "kXlaControlDependenciesAttrName = src node
// name" to dst node.
// 2. For data edges between different XLA computations, if either src or dst
//    is outside compilation, insert an Identity node into the edge. The
//    Identity node will have attr kBridgeSourceNodeAttrName.
// 3. For data edges between outside compilation and host computation, remove
// the edge and create a Placeholder node as dst node's input.
Status PreprocessForEncapsulation(Graph* g,
const string& xla_computation_attr_name,
const string& outside_compilation_attr_name);
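// For intuition, step 2 amounts to splicing an Identity node into a data
// edge. A hypothetical sketch, where `src`, `src_output`, `dst`, and
// `dst_input` describe the original edge (the node name is illustrative):
//
//   Node* bridge = nullptr;
//   TF_RETURN_IF_ERROR(NodeBuilder(src->name() + "_bridge", "Identity")
//                          .Input(src, src_output)
//                          .Attr(kBridgeSourceNodeAttrName, src->name())
//                          .Finalize(g, &bridge));
//   TF_RETURN_IF_ERROR(g->UpdateEdge(bridge, 0, dst, dst_input));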
// Information for XLA computation.
struct XlaClusterInfo {
// Add an explicitly-defined default constructor for this class.
......@@ -158,24 +89,6 @@ struct XlaClusterInfo {
const std::map<string, int> host_compute_core;
};
// Postprocesses edges between different XLA clusters for encapsulation. This
// function reverts what `PreprocessForEncapsulation` did. It will perform the
// following operations in order:
//
// 1. Remove Placeholder nodes between outside compilation and host computation
// (created in `PreprocessForEncapsulation` step 3).
// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2.
// 3a. Reconnect control edges between outside compilation and another XLA
// computation (marked by `PreprocessForEncapsulation` step 1a).
// 3b. Reconnect control edges between different outside compilations (marked by
// `PreprocessForEncapsulation` step 1b).
// 3c. Reconnect control edges between outside compilation and host computation
// (marked by `PreprocessForEncapsulation` step 1c).
Status PostprocessForEncapsulation(
Graph* g, const string& xla_computation_attr_name,
const string& outside_compilation_attr_name,
const std::unordered_map<string, XlaClusterInfo>& clusters);
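// For example, step 3a conceptually does the following for each node `n`
// marked in `PreprocessForEncapsulation` step 1a (a sketch, not the exact
// implementation):
//
//   std::vector<string> computations;
//   TF_RETURN_IF_ERROR(GetNodeAttr(
//       n->def(), kXlaConnectedFromOtherXlaComputationAttrName, &computations));
//   for (const string& cluster_name : computations) {
//     g->AddControlEdge(clusters.at(cluster_name).node, n);
//   }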
// Preprocesses edges within the same XLA cluster. It will perform the following
// operations in order:
//
......
......@@ -38,24 +38,11 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) {
Graph g(OpRegistry::Global());
TF_CHECK_OK(s.ToGraph(&g));
// "add" node is outside compilation node, "identity" node is XLA node.
auto node_index = g.BuildNodeNameIndex();
Node *add_node = node_index["add"], *identity_node = node_index["identity"];
add_node->AddAttr("_xla", "cluster");
add_node->AddAttr("_oc", "cluster");
identity_node->AddAttr("_xla", "cluster");
TF_CHECK_OK(
PerformStaticShapeInferenceBeforeEncapsulation(&g, "_xla", "_oc"));
TF_CHECK_OK(PerformStaticShapeInferenceBeforeEncapsulation(&g));
// Check that only "add" node now has _xla_inferred_shapes attr.
std::vector<Node *> nodes_with_inferred_shape;
for (Node *n : g.nodes()) {
if (HasNodeAttr(n->def(), kXlaInferredShapesAttrName)) {
nodes_with_inferred_shape.push_back(n);
}
}
EXPECT_EQ(nodes_with_inferred_shape.size(), 1);
EXPECT_EQ(nodes_with_inferred_shape[0], add_node);
// Check that "add" node now has _xla_inferred_shapes attr.
auto node_index = g.BuildNodeNameIndex();
Node *add_node = node_index["add"];
std::vector<PartialTensorShape> output_shapes;
TF_CHECK_OK(GetNodeAttr(add_node->attrs(), kXlaInferredShapesAttrName,
&output_shapes));
......@@ -66,329 +53,4 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) {
EXPECT_EQ(shape_proto.dim(0).size(), 2);
}
TEST(PreprocessForEncapsulationTest, ControlEdges) {
// Build the graph:
// "const_0" and "const_1" in host computation
// "add" = "const_0" + "const_1" in XLA computation 0
// "identity0" = "add" in XLA computation 0 & outside compilation 0
// "identity1" = "identity0" in XLA computation 0
// "identity2" = "identity1" in host computation
// "identity3" = "identity2" in XLA computation 1
// "identity4" = "identity3" in XLA computation 1 & outside compilation 1
// "identity5" = "identity4" in XLA computation 1
// "identity6" = "identity5" in host computation
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
Output add = ops::Add(s.WithOpName("add"), const_0, const_1);
Output identity0 = ops::Identity(s.WithOpName("identity0"), add);
Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
Output identity4 = ops::Identity(s.WithOpName("identity4"), identity3);
Output identity5 = ops::Identity(s.WithOpName("identity5"), identity4);
Graph g(OpRegistry::Global());
TF_CHECK_OK(s.ToGraph(&g));
auto node_index = g.BuildNodeNameIndex();
// Set XLA computation/outside compilation attr, and add control edges.
Node *const0_node = node_index["const_0"], *add_node = node_index["add"],
*identity0_node = node_index["identity0"],
*identity1_node = node_index["identity1"],
*identity2_node = node_index["identity2"],
*identity3_node = node_index["identity3"],
*identity4_node = node_index["identity4"],
*identity5_node = node_index["identity5"];
add_node->AddAttr("_xla", "0");
identity0_node->AddAttr("_xla", "0");
identity0_node->AddAttr("_oc", "0");
identity1_node->AddAttr("_xla", "0");
identity3_node->AddAttr("_xla", "1");
identity4_node->AddAttr("_xla", "1");
identity4_node->AddAttr("_oc", "0");
identity5_node->AddAttr("_xla", "1");
// Case 1a: control edges between outside compilation and another XLA
// computation.
g.AddControlEdge(identity0_node, identity3_node);
g.AddControlEdge(identity1_node, identity4_node);
// Case 1b: control edges between different outside compilations.
g.AddControlEdge(identity0_node, identity4_node);
// Case 1c: control edges between outside compilation and host computation.
g.AddControlEdge(const0_node, identity0_node);
g.AddControlEdge(identity0_node, identity2_node);
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
// Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name"
// to the outside compilation node.
std::vector<string> attr;
TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
kXlaConnectedToOtherXlaComputationAttrName, &attr));
EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "1");
attr.clear();
TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
kXlaConnectedFromOtherXlaComputationAttrName, &attr));
EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "0");
// Case 1b: add attr "_xla_control_deps = src node name" to dst node.
attr.clear();
TF_CHECK_OK(GetNodeAttr(identity4_node->def(),
kXlaControlDependenciesAttrName, &attr));
EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "identity0");
// Case 1c: add attr "_xla_control_deps = src node name" to dst node.
attr.clear();
TF_CHECK_OK(GetNodeAttr(identity0_node->def(),
kXlaControlDependenciesAttrName, &attr));
EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "const_0");
attr.clear();
TF_CHECK_OK(GetNodeAttr(identity2_node->def(),
kXlaControlDependenciesAttrName, &attr));
EXPECT_EQ(attr.size(), 1);
EXPECT_EQ(attr[0], "identity0");
}
TEST(PreprocessForEncapsulationTest, DataEdges) {
// Build the graph:
// "const_0" and "const_1" in host computation
// "identityn0" = ("const_0", "const_1") in host computation 0
// "add0" = "const_0" + "const_1" in XLA computation 0
// "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0
// "identity0" = "add1" in XLA computation 0
// "add2" = "add1" + "identity0" in host computation
// "add3" = "add1" + "add2" in XLA computation 1
// "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0
// "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 &
// outside compilation 0
// "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 &
// outside compilation 0
// "identity1" = "add4" in XLA computation 1
// "identity2" = "identity1" in host computation
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const_0 = ops::Const(s.WithOpName("const_0"), 1, {});
Output const_1 = ops::Const(s.WithOpName("const_1"), 2, {});
auto identityn0 =
ops::IdentityN(s.WithOpName("identityn_0"), {const_0, const_1});
Output add0 = ops::Add(s.WithOpName("add0"), const_0, const_1);
Output add1 = ops::Add(s.WithOpName("add1"), add0, const_0);
Output identity0 = ops::Identity(s.WithOpName("identity0"), add1);
Output add2 = ops::Add(s.WithOpName("add2"), add1, identity0);
Output add3 = ops::Add(s.WithOpName("add3"), add1, add2);
Output add4 = ops::Add(s.WithOpName("add4"), identity0, add2);
Output add5 = ops::Add(s.WithOpName("add5"), identityn0[0], identityn0[1]);
auto identityn1 = ops::IdentityN(s.WithOpName("identityn_1"),
{identityn0[0], identityn0[1]});
Output identity1 = ops::Identity(s.WithOpName("identity1"), add4);
Output identity2 = ops::Identity(s.WithOpName("identity2"), add4);
Graph g(OpRegistry::Global());
TF_CHECK_OK(s.ToGraph(&g));
auto node_index = g.BuildNodeNameIndex();
// Set XLA computation/outside compilation attr.
Node *add0_node = node_index["add0"], *add1_node = node_index["add1"],
*identity0_node = node_index["identity0"],
*add3_node = node_index["add3"], *add4_node = node_index["add4"],
*add5_node = node_index["add5"],
*identityn1_node = node_index["identityn_1"],
*identity1_node = node_index["identity1"];
add0_node->AddAttr("_xla", "0");
add1_node->AddAttr("_xla", "0");
add1_node->AddAttr("_oc", "0");
identity0_node->AddAttr("_xla", "0");
add3_node->AddAttr("_xla", "1");
add4_node->AddAttr("_xla", "1");
add4_node->AddAttr("_oc", "0");
add5_node->AddAttr("_xla", "1");
add5_node->AddAttr("_oc", "0");
identityn1_node->AddAttr("_xla", "1");
identityn1_node->AddAttr("_oc", "0");
identity1_node->AddAttr("_xla", "1");
TF_CHECK_OK(PreprocessForEncapsulation(&g, "_xla", "_oc"));
// Check input nodes for related data edges.
node_index = g.BuildNodeNameIndex();
// Step 2: add an Identity node between different XLA computations.
Node *bridge_add1_add3 = node_index["bridge_add1_add3"];
EXPECT_NE(bridge_add1_add3, nullptr);
string str;
TF_CHECK_OK(
GetNodeAttr(bridge_add1_add3->attrs(), kBridgeSourceNodeAttrName, &str));
EXPECT_EQ(str, "add1");
Node *bridge_identity0_add4 = node_index["bridge_identity0_add4"];
EXPECT_NE(bridge_identity0_add4, nullptr);
// Step 3: add placeholder for edges between host computation and outside
// compilation.
EXPECT_EQ(bridge_add1_add3->def().input(0), "add1_oc_to_host_placeholder_0");
Node *add1_oc_to_host_placeholder =
node_index["add1_oc_to_host_placeholder_0"];
TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
kOutsideCompilationToHostOriginalNodeAttrName, &str));
EXPECT_EQ(str, "add1");
int i;
TF_CHECK_OK(GetNodeAttr(add1_oc_to_host_placeholder->attrs(),
kOutsideCompilationToHostSrcOutputAttrName, &i));
EXPECT_EQ(i, 0);
add4_node = node_index["add4"];
ASSERT_NE(add4_node, nullptr);
EXPECT_EQ(add4_node->def().input(0),
"bridge_identity0_add4_host_to_oc_placeholder_0");
Node *identity0_host_to_oc_placeholder =
node_index["bridge_identity0_add4_host_to_oc_placeholder_0"];
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
kHostToOutsideCompilationOriginalNodeAttrName, &str));
EXPECT_EQ(str, "bridge_identity0_add4");
TF_CHECK_OK(GetNodeAttr(identity0_host_to_oc_placeholder->attrs(),
kHostToOutsideCompilationSrcOutputAttrName, &i));
EXPECT_EQ(i, 0);
// Check different placeholder nodes are created for different src_output.
Node *placeholder0 = node_index["identityn_0_host_to_oc_placeholder_0"],
*placeholder1 = node_index["identityn_0_host_to_oc_placeholder_1"];
EXPECT_NE(placeholder0, nullptr);
EXPECT_NE(placeholder1, nullptr);
// Check we only have 2 placeholder nodes created for "identityn_0".
int placeholder_count = 0;
for (Node *n : g.nodes()) {
if (HasNodeAttr(n->def(), kHostToOutsideCompilationOriginalNodeAttrName)) {
string attr;
TF_CHECK_OK(GetNodeAttr(
n->attrs(), kHostToOutsideCompilationOriginalNodeAttrName, &attr));
if (attr == "identityn_0") {
++placeholder_count;
}
}
}
EXPECT_EQ(placeholder_count, 2);
}
TEST(PostprocessForEncapsulationTest, ControlEdges) {
// Build the graph:
// "const0"
// "identity0" = "const0" (XLA computation 0)
// "identity1" = "identity0"
// "identity2" = "identity1" (XLA computation 1)
// "identity3" = "identity2"
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const0 = ops::Const(s.WithOpName("const0"), 1, {});
Output identity0 = ops::Identity(s.WithOpName("identity0"), const0);
Output identity1 = ops::Identity(s.WithOpName("identity1"), identity0);
Output identity2 = ops::Identity(s.WithOpName("identity2"), identity1);
Output identity3 = ops::Identity(s.WithOpName("identity3"), identity2);
Graph g(OpRegistry::Global());
TF_CHECK_OK(s.ToGraph(&g));
auto node_index = g.BuildNodeNameIndex();
// Set XLA computation/outside compilation attr, and add control edges.
Node *const0_node = node_index["const0"],
*identity0_node = node_index["identity0"],
*identity1_node = node_index["identity1"],
*identity2_node = node_index["identity2"],
*identity3_node = node_index["identity3"];
identity1_node->AddAttr(kXlaConnectedFromOtherXlaComputationAttrName,
std::vector<string>{"0"});
identity1_node->AddAttr(kXlaConnectedToOtherXlaComputationAttrName,
std::vector<string>{"1"});
identity3_node->AddAttr(kXlaControlDependenciesAttrName,
std::vector<string>{"const0", "identity1"});
std::unordered_map<string, XlaClusterInfo> clusters;
clusters["0"].node = identity0_node;
clusters["1"].node = identity2_node;
TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters));
// Case 3a: we have control edge identity0 -> identity1, and identity1 ->
// identity2.
bool edge_identity0_identity1 = false, edge_identity1_identity2 = false;
for (const Edge *e : g.edges()) {
if (!e->IsControlEdge()) {
continue;
}
if (e->src() == identity0_node && e->dst() == identity1_node) {
edge_identity0_identity1 = true;
} else if (e->src() == identity1_node && e->dst() == identity2_node) {
edge_identity1_identity2 = true;
}
}
EXPECT_TRUE(edge_identity0_identity1);
EXPECT_TRUE(edge_identity1_identity2);
// Case 3b: we have control edge const0 -> identity3, and identity1 ->
// identity3.
bool edge_const0_identity3 = false, edge_identity1_identity3 = false;
for (const Edge *e : g.edges()) {
if (!e->IsControlEdge()) {
continue;
}
if (e->src() == const0_node && e->dst() == identity3_node) {
edge_const0_identity3 = true;
} else if (e->src() == identity1_node && e->dst() == identity3_node) {
edge_identity1_identity3 = true;
}
}
EXPECT_TRUE(edge_const0_identity3);
EXPECT_TRUE(edge_identity1_identity3);
}
TEST(PostprocessForEncapsulationTest, DataEdges) {
// Build the graph:
// "const0" in outside compilation "0"
// "placeholder0" (for "const0") in host computation
// "add0" = "placeholder0" + "placeholder0" in host computation
// "placeholder1" (for "add0") in outside compilation 1
// "add1" = "placeholder1" + "placeholder1" in outside compilation 1
//
// "bridge" = "placeholder0" in host computation
// "placeholder2" (for "bridge") in outside compilation 1
// "add2" = "placeholder2" + "placeholder2" in outside compilation 1
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output const0 = ops::Const(s.WithOpName("const0"), 1, {});
Output placeholder0 =
ops::Placeholder(s.WithOpName("placeholder0"), DT_INT32);
Output add0 = ops::Add(s.WithOpName("add0"), placeholder0, placeholder0);
Output placeholder1 =
ops::Placeholder(s.WithOpName("placeholder1"), DT_INT32);
Output add1 = ops::Add(s.WithOpName("add1"), placeholder1, placeholder1);
Output bridge = ops::Identity(s.WithOpName("bridge"), placeholder0);
Output placeholder2 =
ops::Placeholder(s.WithOpName("placeholder2"), DT_INT32);
Output add2 = ops::Add(s.WithOpName("add2"), placeholder2, placeholder2);
Graph g(OpRegistry::Global());
TF_CHECK_OK(s.ToGraph(&g));
auto node_index = g.BuildNodeNameIndex();
// Set related attributes.
Node *placeholder0_node = node_index["placeholder0"];
placeholder0_node->AddAttr(kOutsideCompilationToHostOriginalNodeAttrName,
"const0");
placeholder0_node->AddAttr(kOutsideCompilationToHostSrcOutputAttrName, 0);
Node *placeholder1_node = node_index["placeholder1"];
placeholder1_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName,
"add0");
placeholder1_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0);
Node *bridge_node = node_index["bridge"];
bridge_node->AddAttr(kBridgeSourceNodeAttrName, "const0");
Node *placeholder2_node = node_index["placeholder2"];
placeholder2_node->AddAttr(kHostToOutsideCompilationOriginalNodeAttrName,
"bridge");
placeholder2_node->AddAttr(kHostToOutsideCompilationSrcOutputAttrName, 0);
std::unordered_map<string, XlaClusterInfo> clusters;
TF_CHECK_OK(PostprocessForEncapsulation(&g, "_xla", "_oc", clusters));
// Result graph should be:
// "add0" = "const0" + "const0"
// "add1" = "add0" + "add0"
// "add2" = "const0" + "const0"
node_index = g.BuildNodeNameIndex();
EXPECT_EQ(node_index.size(), 6);
EXPECT_EQ(node_index["add0"]->def().input(0), "const0:0");
EXPECT_EQ(node_index["add0"]->def().input(1), "const0:0");
EXPECT_EQ(node_index["add1"]->def().input(0), "add0:0");
EXPECT_EQ(node_index["add1"]->def().input(1), "add0:0");
EXPECT_EQ(node_index["add2"]->def().input(0), "const0:0");
EXPECT_EQ(node_index["add2"]->def().input(1), "const0:0");
}
} // namespace tensorflow
......@@ -634,17 +634,14 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
return s;
}
// Rewrites shape inference graph for outside compilation.
// 1. If the outside compilation is a "top-level" one (not in a function of any
// If/While/etc.), this shape inference graph might have host computation to
// outside compilation placeholder nodes, which will cause shape inference to
// fail. However, those nodes are not in `host_graph` any more (because we
// have executed `PostprocessForEncapsulation`). In this case, we clear the
// graph, and copy SendFromHost with all its predecessors from `host_graph`.
// This case is detected by whether the SendFromHost node exists in
// `host_graph` as well.
// 2. Remove control edges, and prune nodes that are not useful for shape
// inference.
// Rewrites shape inference graph for outside compilation:
// 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from
//    `host_graph`. The shape inference graph might still contain outside
//    compilation to outside compilation placeholder nodes, which would
//    prevent us from inferring the XlaSendFromHost shape, but those
//    placeholder nodes have already been removed from `host_graph`.
// 2) Remove control edges.
// 3) Prune nodes that are not useful for shape inference.
Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
Graph* host_graph,
FunctionLibraryDefinition* fld) {
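// The step 1) check is by node name; a minimal sketch, assuming
// `send_from_host` is the XlaSendFromHost node of the shape inference graph:
//
//   auto host_index = host_graph->BuildNodeNameIndex();
//   bool send_node_in_host_graph =
//       host_index.find(send_from_host->name()) != host_index.end();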
......@@ -744,6 +741,7 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
g->RemoveEdge(e);
}
}
// Nodes that are not reverse reachable from SendFromHost are not useful for
// shape inference. Prune them.
PruneForReverseReachability(g,
......@@ -1581,14 +1579,6 @@ Status ExtractOutsideCompilation(
TF_RETURN_IF_ERROR(fld->RemoveFunction(host_graph_func_name));
}
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile("extract_outside_compilation_expanded", *g,
fld);
}
TF_RETURN_IF_ERROR(PostprocessForEncapsulation(
g, xla_cluster_attr_name, outside_compilation_attr_name, clusters));
for (auto shape_inference_graph_name : shape_inference_graphs) {
TF_RETURN_IF_ERROR(
RewriteShapeInferenceGraph(shape_inference_graph_name, g, fld));
......
......@@ -53,7 +53,15 @@ Status PropagateShapes(const Graph& graph,
// shapes, even if no shape function is registered for a node.
Status status = shape_refiner->AddNode(n);
if (!status.ok()) {
VLOG(1) << "Shape inference failed for node: " << status;
VLOG(1) << "Shape inference failed for node " << n->name() << ": "
<< status;
} else {
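// On success, log every inferred output shape at high verbosity.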
shape_inference::InferenceContext* context = shape_refiner->GetContext(n);
for (int i = 0; i < n->num_outputs(); i++) {
shape_inference::ShapeHandle handle = context->output(i);
VLOG(4) << "Output " << i << " for node " << n->name() << ": "
<< context->DebugString(handle);
}
}
if (n->type_string() == "_Arg") {
......
......@@ -646,6 +646,10 @@ def split_compile_and_replicate(computation,
array_ops.identity(x, name="replicated_input_{}".format(i))
for i, x in enumerate(computation_inputs)
]
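# Tag each per-replica input identity created above with the
# "_tpu_input_identity" attribute.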
for i in computation_inputs:
# pylint: disable=protected-access
i.op._set_attr("_tpu_input_identity", attr_value_pb2.AttrValue(b=True))
# pylint: enable=protected-access
# If there is an infeed queue, adds the dequeued values to the
# computation's inputs.
......@@ -726,7 +730,11 @@ def split_compile_and_replicate(computation,
new_output_tensors = []
for t in output_tensors:
with ops.device(t.device if t.device else core(0)):
new_output_tensors.append(array_ops.identity(t))
o = array_ops.identity(t)
# pylint: disable=protected-access
o.op._set_attr("_tpu_output_identity", attr_value_pb2.AttrValue(b=True))
# pylint: enable=protected-access
new_output_tensors.append(o)
output_tensors = new_output_tensors
context.ExitResult(output_tensors)
finally:
......
......@@ -2280,7 +2280,7 @@ class TPUEstimator(estimator_lib.Estimator):
(k, _export_output_to_tensors(v))
for k, v in six.iteritems(estimator_spec.export_outputs))
tensors = nest.flatten(tensors_dict)
tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)]
tpu_tensors = [t for t in tensors if t is not None]
# We cannot return anything other than `tpu_tensors` here so we capture
# the rest for later use.
......@@ -2294,18 +2294,10 @@ class TPUEstimator(estimator_lib.Estimator):
# `tpu_tensors_on_cpu`.
new_tensors = []
for t in tensors:
if _is_tpu_tensor(t):
new_tensors.append(tpu_tensors_on_cpu.pop(0))
elif t is None:
if t is None:
new_tensors.append(None)
else:
# Fetching only `tpu_tensors_on_cpu` does not trigger the TPU
# computation and would block, so we add the control dependency here.
control_inputs = (
tpu_tensors_on_cpu if _is_iterable(tpu_tensors_on_cpu) else
(tpu_tensors_on_cpu,))
with ops.control_dependencies(control_inputs):
new_tensors.append(array_ops.identity(t))
new_tensors.append(tpu_tensors_on_cpu.pop(0))
# Reconstruct `tensors_dict`.
new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors)
......@@ -2798,17 +2790,6 @@ class TPUEstimator(estimator_lib.Estimator):
return _model_fn
def _is_tpu_tensor(tensor):
if not isinstance(tensor, ops.Tensor):
return False
try:
tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) # pylint: disable=protected-access
except ValueError:
return True
else:
return False
def _export_output_to_tensors(export_output):
"""Get a list of `Tensors` used in `export_output`.
......