Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
4d810848
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
4d810848
编写于
12月 13, 2018
作者:
T
Tong Shen
提交者:
TensorFlower Gardener
12月 13, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Handle outside compilation at beginning/end of TPU computation.
PiperOrigin-RevId: 225396866
上级
150b4c8e
变更
8
展开全部
隐藏空白更改
内联
并排
Showing
8 changed file
with
330 addition
and
1196 deletion
+330
-1196
tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+288
-132
tensorflow/compiler/jit/encapsulate_util.cc
tensorflow/compiler/jit/encapsulate_util.cc
+5
-589
tensorflow/compiler/jit/encapsulate_util.h
tensorflow/compiler/jit/encapsulate_util.h
+3
-90
tensorflow/compiler/jit/encapsulate_util_test.cc
tensorflow/compiler/jit/encapsulate_util_test.cc
+4
-342
tensorflow/compiler/jit/extract_outside_compilation_pass.cc
tensorflow/compiler/jit/extract_outside_compilation_pass.cc
+9
-19
tensorflow/compiler/jit/shape_inference.cc
tensorflow/compiler/jit/shape_inference.cc
+9
-1
tensorflow/contrib/tpu/python/tpu/tpu.py
tensorflow/contrib/tpu/python/tpu/tpu.py
+9
-1
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+3
-22
未找到文件。
tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
浏览文件 @
4d810848
此差异已折叠。
点击以展开。
tensorflow/compiler/jit/encapsulate_util.cc
浏览文件 @
4d810848
此差异已折叠。
点击以展开。
tensorflow/compiler/jit/encapsulate_util.h
浏览文件 @
4d810848
...
...
@@ -27,51 +27,13 @@ namespace tensorflow {
// a list of PartialTensorShape objects.
extern
const
char
kXlaInferredShapesAttrName
[];
// Infer output shapes for outside compilation nodes which have output data
// edges to XLA computation nodes. These shapes will be used later by XLA
// compiler as output shapes of the outside compilation's XlaHostCompute op.
// XLA computation nodes will be mark by attr `xla_computation_attr_name`;
// outside compilation nodes will be marked by both attr
// `xla_computation_attr_name` and `outside_compilation_attr_name`.
//
// Those outside compilation nodes will be marked with attribute
// `kXlaInferredShapesAttrName`.
// Infers output shapes for all nodes in graph `g`. The output shapes will be
// stored in node attribute `kXlaInferredShapesAttrName`.
//
// We have to perform shape inference before encapsulation because after
// encapsulation, some nodes will be encapsulated into function call, and shape
// inference does not handle function call at the moment.
Status
PerformStaticShapeInferenceBeforeEncapsulation
(
Graph
*
g
,
const
string
&
xla_computation_attr_name
,
const
string
&
outside_compilation_attr_name
);
// Attribute indicating that some ops in other XLA computation has control
// dependency on this node. Attribute value will be a list of string (XLA
// computation names).
extern
const
char
kXlaConnectedToOtherXlaComputationAttrName
[];
// Attribute indicating that this node has control dependency on some ops in
// other XLA computation. Attribute value will be a list of string (XLA
// computation names).
extern
const
char
kXlaConnectedFromOtherXlaComputationAttrName
[];
// Attribute indicating that this node has control dependencies on some other
// nodes. Attribute value will be a list of string (node names).
extern
const
char
kXlaControlDependenciesAttrName
[];
// Attribute indicating that this is an Identity node added to act as a bridge
// between different XLA computations. Attribute value will be string (source
// node name).
extern
const
char
kBridgeSourceNodeAttrName
[];
// Attribute indicating that this is an Placeholder node added to act as a
// temporary input node for an outside compilation node. Attribute value will be
// string (original input node name).
extern
const
char
kOutsideCompilationToHostOriginalNodeAttrName
[];
// Attribute indicating that this is an Placeholder node added to act as a
// temporary input node for an outside compilation node. Attribute value will be
// int (src_output for original edge).
extern
const
char
kOutsideCompilationToHostSrcOutputAttrName
[];
Status
PerformStaticShapeInferenceBeforeEncapsulation
(
Graph
*
g
);
// Attribute indicating that some ops in this node's XLA computation has control
// dependency on this node. Attribute value will always be "true".
...
...
@@ -81,16 +43,6 @@ extern const char kXlaConnectedToXlaComputationAttrName[];
// this node's XLA computation. Attribute value will always be "true".
extern
const
char
kXlaConnectedFromXlaComputationAttrName
[];
// Attribute indicating that this is an Placeholder node added to act as a
// temporary input node for an host node. Attribute value will be string
// (original input node name).
extern
const
char
kHostToOutsideCompilationOriginalNodeAttrName
[];
// Attribute indicating that this is an Placeholder node added to act as a
// temporary input node for a host node. Attribute value will be int (src_output
// for original edge).
extern
const
char
kHostToOutsideCompilationSrcOutputAttrName
[];
// Attribute indicating that this is an Placeholder node added to act as a
// temporary input node for an outside compilation node. Attribute value will be
// string (original input node name).
...
...
@@ -106,27 +58,6 @@ extern const char kOutsideCompilationSrcOutputAttrName[];
// (node names).
extern
const
char
kXlaControlDependenciesWithinXlaClusterAttrName
[];
// Preprocesses edges between different XLA clusters for encapsulation. It will
// perform the following operations in order:
//
// 1a. For control edges between outside compilation and another XLA
// computation, add attr "kXlaConnected{From, To}OtherXlaComputationAttrName
// = XLA computation node name" to the outside compilation node.
// 1b. For control edges between different outside compilations (in different
// XLA computations), remove the edge and add attr
// "kXlaControlDependenciesAttrName = src node name" to dst node.
// 1c. For control edges between outside compilation and host computation,
// remove the edge and add attr "kXlaControlDependenciesAttrName = src node
// name" to dst node.
// 2. For data edges between different XLA computations, if either src or dst
// is outside compilation, add an Identity node in between the edge. The
// identity node will have attr kBridgeSourceNodeAttrName.
// 3. For data edges between outside compilation and host computation, remove
// the edge and create a Placeholder node as dst node's input.
Status
PreprocessForEncapsulation
(
Graph
*
g
,
const
string
&
xla_computation_attr_name
,
const
string
&
outside_compilation_attr_name
);
// Information for XLA computation.
struct
XlaClusterInfo
{
// Add an explicitly-defined default constructor for this class.
...
...
@@ -158,24 +89,6 @@ struct XlaClusterInfo {
const
std
::
map
<
string
,
int
>
host_compute_core
;
};
// Postprocesses edges between different XLA clusters for encapsulation. This
// function reverts what `PreprocessForEncapsulation` did. It will perform the
// following operations in order:
//
// 1. Remove Placeholder nodes between outside compilation and host computation
// (created in `PreprocessForEncapsulation` step 3).
// 2. Remove Identity nodes created in `PreprocessForEncapsulation` step 2.
// 3a. Reconnect control edges between outside compilation and another XLA
// computation (marked by `PreprocessForEncapsulation` step 1a).
// 3b. Reconnect control edges between different outside compilations (marked by
// `PreprocessForEncapsulation` step 1b).
// 3c. Reconnect control edges between outside compilation and host computation
// (marked by `PreprocessForEncapsulation` step 1c).
Status
PostprocessForEncapsulation
(
Graph
*
g
,
const
string
&
xla_computation_attr_name
,
const
string
&
outside_compilation_attr_name
,
const
std
::
unordered_map
<
string
,
XlaClusterInfo
>&
clusters
);
// Preprocesses edges within the same XLA cluster. It will perform the following
// operations in order:
//
...
...
tensorflow/compiler/jit/encapsulate_util_test.cc
浏览文件 @
4d810848
...
...
@@ -38,24 +38,11 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) {
Graph
g
(
OpRegistry
::
Global
());
TF_CHECK_OK
(
s
.
ToGraph
(
&
g
));
// "add" node is outside compilation node, "identity" node is XLA node.
auto
node_index
=
g
.
BuildNodeNameIndex
();
Node
*
add_node
=
node_index
[
"add"
],
*
identity_node
=
node_index
[
"identity"
];
add_node
->
AddAttr
(
"_xla"
,
"cluster"
);
add_node
->
AddAttr
(
"_oc"
,
"cluster"
);
identity_node
->
AddAttr
(
"_xla"
,
"cluster"
);
TF_CHECK_OK
(
PerformStaticShapeInferenceBeforeEncapsulation
(
&
g
,
"_xla"
,
"_oc"
));
TF_CHECK_OK
(
PerformStaticShapeInferenceBeforeEncapsulation
(
&
g
));
// Check that only "add" node now has _xla_inferred_shapes attr.
std
::
vector
<
Node
*>
nodes_with_inferred_shape
;
for
(
Node
*
n
:
g
.
nodes
())
{
if
(
HasNodeAttr
(
n
->
def
(),
kXlaInferredShapesAttrName
))
{
nodes_with_inferred_shape
.
push_back
(
n
);
}
}
EXPECT_EQ
(
nodes_with_inferred_shape
.
size
(),
1
);
EXPECT_EQ
(
nodes_with_inferred_shape
[
0
],
add_node
);
// Check that "add" node now has _xla_inferred_shapes attr.
auto
node_index
=
g
.
BuildNodeNameIndex
();
Node
*
add_node
=
node_index
[
"add"
];
std
::
vector
<
PartialTensorShape
>
output_shapes
;
TF_CHECK_OK
(
GetNodeAttr
(
add_node
->
attrs
(),
kXlaInferredShapesAttrName
,
&
output_shapes
));
...
...
@@ -66,329 +53,4 @@ TEST(PerformStaticShapeInferenceBeforeEncapsulationTest, Basic) {
EXPECT_EQ
(
shape_proto
.
dim
(
0
).
size
(),
2
);
}
TEST
(
PreprocessForEncapsulationTest
,
ControlEdges
)
{
// Build the graph:
// "const_0" and "const_1" in host computation
// "add" = "const_0" + "const_1" in XLA computation 0
// "identity0" = "add" in XLA computation 0 & outside compilation 0
// "identity1" = "identity0" in XLA computation 0
// "identity2" = "identity1" in host computation
// "identity3" = "identity2" in XLA computation 1
// "identity4" = "identity3" in XLA computation 1 & outside compilation 1
// "identity5" = "identity4" in XLA computation 1
// "identity6" = "identity5" in host computation
tensorflow
::
Scope
s
=
tensorflow
::
Scope
::
NewRootScope
();
Output
const_0
=
ops
::
Const
(
s
.
WithOpName
(
"const_0"
),
1
,
{});
Output
const_1
=
ops
::
Const
(
s
.
WithOpName
(
"const_1"
),
2
,
{});
Output
add
=
ops
::
Add
(
s
.
WithOpName
(
"add"
),
const_0
,
const_1
);
Output
identity0
=
ops
::
Identity
(
s
.
WithOpName
(
"identity0"
),
add
);
Output
identity1
=
ops
::
Identity
(
s
.
WithOpName
(
"identity1"
),
identity0
);
Output
identity2
=
ops
::
Identity
(
s
.
WithOpName
(
"identity2"
),
identity1
);
Output
identity3
=
ops
::
Identity
(
s
.
WithOpName
(
"identity3"
),
identity2
);
Output
identity4
=
ops
::
Identity
(
s
.
WithOpName
(
"identity4"
),
identity3
);
Output
identity5
=
ops
::
Identity
(
s
.
WithOpName
(
"identity5"
),
identity4
);
Graph
g
(
OpRegistry
::
Global
());
TF_CHECK_OK
(
s
.
ToGraph
(
&
g
));
auto
node_index
=
g
.
BuildNodeNameIndex
();
// Set XLA computation/outside compilation attr, and add control edges.
Node
*
const0_node
=
node_index
[
"const_0"
],
*
add_node
=
node_index
[
"add"
],
*
identity0_node
=
node_index
[
"identity0"
],
*
identity1_node
=
node_index
[
"identity1"
],
*
identity2_node
=
node_index
[
"identity2"
],
*
identity3_node
=
node_index
[
"identity3"
],
*
identity4_node
=
node_index
[
"identity4"
],
*
identity5_node
=
node_index
[
"identity5"
];
add_node
->
AddAttr
(
"_xla"
,
"0"
);
identity0_node
->
AddAttr
(
"_xla"
,
"0"
);
identity0_node
->
AddAttr
(
"_oc"
,
"0"
);
identity1_node
->
AddAttr
(
"_xla"
,
"0"
);
identity3_node
->
AddAttr
(
"_xla"
,
"1"
);
identity4_node
->
AddAttr
(
"_xla"
,
"1"
);
identity4_node
->
AddAttr
(
"_oc"
,
"0"
);
identity5_node
->
AddAttr
(
"_xla"
,
"1"
);
// Case 1a: control edges between outside compilation and another XLA
// computation.
g
.
AddControlEdge
(
identity0_node
,
identity3_node
);
g
.
AddControlEdge
(
identity1_node
,
identity4_node
);
// Case 1b: control edges between different outside compilations.
g
.
AddControlEdge
(
identity0_node
,
identity4_node
);
// Case 1c: control edges between outside compilation and host computation.
g
.
AddControlEdge
(
const0_node
,
identity0_node
);
g
.
AddControlEdge
(
identity0_node
,
identity2_node
);
TF_CHECK_OK
(
PreprocessForEncapsulation
(
&
g
,
"_xla"
,
"_oc"
));
// Case 1a: add attr "_xla_control_deps_{from/to} = XLA computation node name"
// to the outside compilation node.
std
::
vector
<
string
>
attr
;
TF_CHECK_OK
(
GetNodeAttr
(
identity0_node
->
def
(),
kXlaConnectedToOtherXlaComputationAttrName
,
&
attr
));
EXPECT_EQ
(
attr
.
size
(),
1
);
EXPECT_EQ
(
attr
[
0
],
"1"
);
attr
.
clear
();
TF_CHECK_OK
(
GetNodeAttr
(
identity4_node
->
def
(),
kXlaConnectedFromOtherXlaComputationAttrName
,
&
attr
));
EXPECT_EQ
(
attr
.
size
(),
1
);
EXPECT_EQ
(
attr
[
0
],
"0"
);
// Case 1b: add attr "_xla_control_deps = src node name" to dst node.
attr
.
clear
();
TF_CHECK_OK
(
GetNodeAttr
(
identity4_node
->
def
(),
kXlaControlDependenciesAttrName
,
&
attr
));
EXPECT_EQ
(
attr
.
size
(),
1
);
EXPECT_EQ
(
attr
[
0
],
"identity0"
);
// Case 1c: add attr "_xla_control_deps = src node name" to dst node.
attr
.
clear
();
TF_CHECK_OK
(
GetNodeAttr
(
identity0_node
->
def
(),
kXlaControlDependenciesAttrName
,
&
attr
));
EXPECT_EQ
(
attr
.
size
(),
1
);
EXPECT_EQ
(
attr
[
0
],
"const_0"
);
attr
.
clear
();
TF_CHECK_OK
(
GetNodeAttr
(
identity2_node
->
def
(),
kXlaControlDependenciesAttrName
,
&
attr
));
EXPECT_EQ
(
attr
.
size
(),
1
);
EXPECT_EQ
(
attr
[
0
],
"identity0"
);
}
TEST
(
PreprocessForEncapsulationTest
,
DataEdges
)
{
// Build the graph:
// "const_0" and "const_1" in host computation
// "identityn0" = ("const_0", "const_1") in host computation 0
// "add0" = "const_0" + "const_1" in XLA computation 0
// "add1" = "add0" + "const_0" in XLA computation 0 & outside compilation 0
// "identity0" = "add1" in XLA computation 0
// "add2" = "add1" + "identity0" in host computation
// "add3" = "add1" + "add2" in XLA computation 1
// "add4" = "identity0" + "add2" in XLA computation 1 & outside compilation 0
// "add5" = "identityn0"[0] + "identityn0"[1] in XLA computation 1 &
// outside compilation 0
// "identityn1" = ("identityn0"[0], "identityn0"[1]) in XLA computation 1 &
// outside compilation 0
// "identity1" = "add4" in XLA computation 1
// "identity2" = "identity1" in host computation
tensorflow
::
Scope
s
=
tensorflow
::
Scope
::
NewRootScope
();
Output
const_0
=
ops
::
Const
(
s
.
WithOpName
(
"const_0"
),
1
,
{});
Output
const_1
=
ops
::
Const
(
s
.
WithOpName
(
"const_1"
),
2
,
{});
auto
identityn0
=
ops
::
IdentityN
(
s
.
WithOpName
(
"identityn_0"
),
{
const_0
,
const_1
});
Output
add0
=
ops
::
Add
(
s
.
WithOpName
(
"add0"
),
const_0
,
const_1
);
Output
add1
=
ops
::
Add
(
s
.
WithOpName
(
"add1"
),
add0
,
const_0
);
Output
identity0
=
ops
::
Identity
(
s
.
WithOpName
(
"identity0"
),
add1
);
Output
add2
=
ops
::
Add
(
s
.
WithOpName
(
"add2"
),
add1
,
identity0
);
Output
add3
=
ops
::
Add
(
s
.
WithOpName
(
"add3"
),
add1
,
add2
);
Output
add4
=
ops
::
Add
(
s
.
WithOpName
(
"add4"
),
identity0
,
add2
);
Output
add5
=
ops
::
Add
(
s
.
WithOpName
(
"add5"
),
identityn0
[
0
],
identityn0
[
1
]);
auto
identityn1
=
ops
::
IdentityN
(
s
.
WithOpName
(
"identityn_1"
),
{
identityn0
[
0
],
identityn0
[
1
]});
Output
identity1
=
ops
::
Identity
(
s
.
WithOpName
(
"identity1"
),
add4
);
Output
identity2
=
ops
::
Identity
(
s
.
WithOpName
(
"identity2"
),
add4
);
Graph
g
(
OpRegistry
::
Global
());
TF_CHECK_OK
(
s
.
ToGraph
(
&
g
));
auto
node_index
=
g
.
BuildNodeNameIndex
();
// Set XLA computation/outside compilation attr.
Node
*
add0_node
=
node_index
[
"add0"
],
*
add1_node
=
node_index
[
"add1"
],
*
identity0_node
=
node_index
[
"identity0"
],
*
add3_node
=
node_index
[
"add3"
],
*
add4_node
=
node_index
[
"add4"
],
*
add5_node
=
node_index
[
"add5"
],
*
identityn1_node
=
node_index
[
"identityn_1"
],
*
identity1_node
=
node_index
[
"identity1"
];
add0_node
->
AddAttr
(
"_xla"
,
"0"
);
add1_node
->
AddAttr
(
"_xla"
,
"0"
);
add1_node
->
AddAttr
(
"_oc"
,
"0"
);
identity0_node
->
AddAttr
(
"_xla"
,
"0"
);
add3_node
->
AddAttr
(
"_xla"
,
"1"
);
add4_node
->
AddAttr
(
"_xla"
,
"1"
);
add4_node
->
AddAttr
(
"_oc"
,
"0"
);
add5_node
->
AddAttr
(
"_xla"
,
"1"
);
add5_node
->
AddAttr
(
"_oc"
,
"0"
);
identityn1_node
->
AddAttr
(
"_xla"
,
"1"
);
identityn1_node
->
AddAttr
(
"_oc"
,
"0"
);
identity1_node
->
AddAttr
(
"_xla"
,
"1"
);
TF_CHECK_OK
(
PreprocessForEncapsulation
(
&
g
,
"_xla"
,
"_oc"
));
// Check input nodes for related data edges.
node_index
=
g
.
BuildNodeNameIndex
();
// Step 2: add an Identity node between different XLA computations.
Node
*
bridge_add1_add3
=
node_index
[
"bridge_add1_add3"
];
EXPECT_NE
(
bridge_add1_add3
,
nullptr
);
string
str
;
TF_CHECK_OK
(
GetNodeAttr
(
bridge_add1_add3
->
attrs
(),
kBridgeSourceNodeAttrName
,
&
str
));
EXPECT_EQ
(
str
,
"add1"
);
Node
*
bridge_identity0_add4
=
node_index
[
"bridge_identity0_add4"
];
EXPECT_NE
(
bridge_identity0_add4
,
nullptr
);
// Step 3: add placeholder for edges between host computation and outside
// compilation.
EXPECT_EQ
(
bridge_add1_add3
->
def
().
input
(
0
),
"add1_oc_to_host_placeholder_0"
);
Node
*
add1_oc_to_host_placeholder
=
node_index
[
"add1_oc_to_host_placeholder_0"
];
TF_CHECK_OK
(
GetNodeAttr
(
add1_oc_to_host_placeholder
->
attrs
(),
kOutsideCompilationToHostOriginalNodeAttrName
,
&
str
));
EXPECT_EQ
(
str
,
"add1"
);
int
i
;
TF_CHECK_OK
(
GetNodeAttr
(
add1_oc_to_host_placeholder
->
attrs
(),
kOutsideCompilationToHostSrcOutputAttrName
,
&
i
));
EXPECT_EQ
(
i
,
0
);
add4_node
=
node_index
[
"add4"
];
ASSERT_NE
(
add4_node
,
nullptr
);
EXPECT_EQ
(
add4_node
->
def
().
input
(
0
),
"bridge_identity0_add4_host_to_oc_placeholder_0"
);
Node
*
identity0_host_to_oc_placeholder
=
node_index
[
"bridge_identity0_add4_host_to_oc_placeholder_0"
];
TF_CHECK_OK
(
GetNodeAttr
(
identity0_host_to_oc_placeholder
->
attrs
(),
kHostToOutsideCompilationOriginalNodeAttrName
,
&
str
));
EXPECT_EQ
(
str
,
"bridge_identity0_add4"
);
TF_CHECK_OK
(
GetNodeAttr
(
identity0_host_to_oc_placeholder
->
attrs
(),
kHostToOutsideCompilationSrcOutputAttrName
,
&
i
));
EXPECT_EQ
(
i
,
0
);
// Check different placeholder nodes are created for different src_output.
Node
*
placeholder0
=
node_index
[
"identityn_0_host_to_oc_placeholder_0"
],
*
placeholder1
=
node_index
[
"identityn_0_host_to_oc_placeholder_1"
];
EXPECT_NE
(
placeholder0
,
nullptr
);
EXPECT_NE
(
placeholder1
,
nullptr
);
// Check we only have 2 placeholder nodes created for "identityn_0".
int
placeholder_count
=
0
;
for
(
Node
*
n
:
g
.
nodes
())
{
if
(
HasNodeAttr
(
n
->
def
(),
kHostToOutsideCompilationOriginalNodeAttrName
))
{
string
attr
;
TF_CHECK_OK
(
GetNodeAttr
(
n
->
attrs
(),
kHostToOutsideCompilationOriginalNodeAttrName
,
&
attr
));
if
(
attr
==
"identityn_0"
)
{
++
placeholder_count
;
}
}
}
EXPECT_EQ
(
placeholder_count
,
2
);
}
TEST
(
PostprocessForEncapsulationTest
,
ControlEdges
)
{
// Build the graph:
// "const0"
// "identity0" = "const0" (XLA computation 0)
// "identity1" = "identity0"
// "identity2" = "identity1" (XLA computation 1)
// "identity3" = "identity2"
tensorflow
::
Scope
s
=
tensorflow
::
Scope
::
NewRootScope
();
Output
const0
=
ops
::
Const
(
s
.
WithOpName
(
"const0"
),
1
,
{});
Output
identity0
=
ops
::
Identity
(
s
.
WithOpName
(
"identity0"
),
const0
);
Output
identity1
=
ops
::
Identity
(
s
.
WithOpName
(
"identity1"
),
identity0
);
Output
identity2
=
ops
::
Identity
(
s
.
WithOpName
(
"identity2"
),
identity1
);
Output
identity3
=
ops
::
Identity
(
s
.
WithOpName
(
"identity3"
),
identity2
);
Graph
g
(
OpRegistry
::
Global
());
TF_CHECK_OK
(
s
.
ToGraph
(
&
g
));
auto
node_index
=
g
.
BuildNodeNameIndex
();
// Set XLA computation/outside compilation attr, and add control edges.
Node
*
const0_node
=
node_index
[
"const0"
],
*
identity0_node
=
node_index
[
"identity0"
],
*
identity1_node
=
node_index
[
"identity1"
],
*
identity2_node
=
node_index
[
"identity2"
],
*
identity3_node
=
node_index
[
"identity3"
];
identity1_node
->
AddAttr
(
kXlaConnectedFromOtherXlaComputationAttrName
,
std
::
vector
<
string
>
{
"0"
});
identity1_node
->
AddAttr
(
kXlaConnectedToOtherXlaComputationAttrName
,
std
::
vector
<
string
>
{
"1"
});
identity3_node
->
AddAttr
(
kXlaControlDependenciesAttrName
,
std
::
vector
<
string
>
{
"const0"
,
"identity1"
});
std
::
unordered_map
<
string
,
XlaClusterInfo
>
clusters
;
clusters
[
"0"
].
node
=
identity0_node
;
clusters
[
"1"
].
node
=
identity2_node
;
TF_CHECK_OK
(
PostprocessForEncapsulation
(
&
g
,
"_xla"
,
"_oc"
,
clusters
));
// Case 3a: we have control edge identity0 -> identity1, and identity1 ->
// identity2.
bool
edge_identity0_identity1
=
false
,
edge_identity1_identity2
=
false
;
for
(
const
Edge
*
e
:
g
.
edges
())
{
if
(
!
e
->
IsControlEdge
())
{
continue
;
}
if
(
e
->
src
()
==
identity0_node
&&
e
->
dst
()
==
identity1_node
)
{
edge_identity0_identity1
=
true
;
}
else
if
(
e
->
src
()
==
identity1_node
&&
e
->
dst
()
==
identity2_node
)
{
edge_identity1_identity2
=
true
;
}
}
EXPECT_TRUE
(
edge_identity0_identity1
);
EXPECT_TRUE
(
edge_identity1_identity2
);
// Case 3b: we have control edge const0 -> identity3, and identity1 ->
// identity3.
bool
edge_const0_identity3
=
false
,
edge_identity1_identity3
=
false
;
for
(
const
Edge
*
e
:
g
.
edges
())
{
if
(
!
e
->
IsControlEdge
())
{
continue
;
}
if
(
e
->
src
()
==
const0_node
&&
e
->
dst
()
==
identity3_node
)
{
edge_const0_identity3
=
true
;
}
else
if
(
e
->
src
()
==
identity1_node
&&
e
->
dst
()
==
identity3_node
)
{
edge_identity1_identity3
=
true
;
}
}
EXPECT_TRUE
(
edge_const0_identity3
);
EXPECT_TRUE
(
edge_identity1_identity3
);
}
TEST
(
PostprocessForEncapsulationTest
,
DataEdges
)
{
// Build the graph:
// "const0" in outside compilation "0"
// "placeholder0" (for "const0") in host computation
// "add0" = "placeholder0" + "placeholder0" in host computation
// "placeholder1" (for "add0") in outside compilation 1
// "add1" = "placeholder1" + "placeholder1" in outside compilation 1
//
// "bridge" = "placeholder0" in host computation
// "placeholder2" (for "bridge") in outside compilation 1
// "add2" = "placeholder2" + "placeholder2" in outside compilation 1
tensorflow
::
Scope
s
=
tensorflow
::
Scope
::
NewRootScope
();
Output
const0
=
ops
::
Const
(
s
.
WithOpName
(
"const0"
),
1
,
{});
Output
placeholder0
=
ops
::
Placeholder
(
s
.
WithOpName
(
"placeholder0"
),
DT_INT32
);
Output
add0
=
ops
::
Add
(
s
.
WithOpName
(
"add0"
),
placeholder0
,
placeholder0
);
Output
placeholder1
=
ops
::
Placeholder
(
s
.
WithOpName
(
"placeholder1"
),
DT_INT32
);
Output
add1
=
ops
::
Add
(
s
.
WithOpName
(
"add1"
),
placeholder1
,
placeholder1
);
Output
bridge
=
ops
::
Identity
(
s
.
WithOpName
(
"bridge"
),
placeholder0
);
Output
placeholder2
=
ops
::
Placeholder
(
s
.
WithOpName
(
"placeholder2"
),
DT_INT32
);
Output
add2
=
ops
::
Add
(
s
.
WithOpName
(
"add2"
),
placeholder2
,
placeholder2
);
Graph
g
(
OpRegistry
::
Global
());
TF_CHECK_OK
(
s
.
ToGraph
(
&
g
));
auto
node_index
=
g
.
BuildNodeNameIndex
();
// Set related attributes.
Node
*
placeholder0_node
=
node_index
[
"placeholder0"
];
placeholder0_node
->
AddAttr
(
kOutsideCompilationToHostOriginalNodeAttrName
,
"const0"
);
placeholder0_node
->
AddAttr
(
kOutsideCompilationToHostSrcOutputAttrName
,
0
);
Node
*
placeholder1_node
=
node_index
[
"placeholder1"
];
placeholder1_node
->
AddAttr
(
kHostToOutsideCompilationOriginalNodeAttrName
,
"add0"
);
placeholder1_node
->
AddAttr
(
kHostToOutsideCompilationSrcOutputAttrName
,
0
);
Node
*
bridge_node
=
node_index
[
"bridge"
];
bridge_node
->
AddAttr
(
kBridgeSourceNodeAttrName
,
"const0"
);
Node
*
placeholder2_node
=
node_index
[
"placeholder2"
];
placeholder2_node
->
AddAttr
(
kHostToOutsideCompilationOriginalNodeAttrName
,
"bridge"
);
placeholder2_node
->
AddAttr
(
kHostToOutsideCompilationSrcOutputAttrName
,
0
);
std
::
unordered_map
<
string
,
XlaClusterInfo
>
clusters
;
TF_CHECK_OK
(
PostprocessForEncapsulation
(
&
g
,
"_xla"
,
"_oc"
,
clusters
));
// Result graph should be:
// "add0" = "const0" + "const0"
// "add1" = "add0" + "add0"
// "add2" = "const0" + "const0"
node_index
=
g
.
BuildNodeNameIndex
();
EXPECT_EQ
(
node_index
.
size
(),
6
);
EXPECT_EQ
(
node_index
[
"add0"
]
->
def
().
input
(
0
),
"const0:0"
);
EXPECT_EQ
(
node_index
[
"add0"
]
->
def
().
input
(
1
),
"const0:0"
);
EXPECT_EQ
(
node_index
[
"add1"
]
->
def
().
input
(
0
),
"add0:0"
);
EXPECT_EQ
(
node_index
[
"add1"
]
->
def
().
input
(
1
),
"add0:0"
);
EXPECT_EQ
(
node_index
[
"add2"
]
->
def
().
input
(
0
),
"const0:0"
);
EXPECT_EQ
(
node_index
[
"add2"
]
->
def
().
input
(
1
),
"const0:0"
);
}
}
// namespace tensorflow
tensorflow/compiler/jit/extract_outside_compilation_pass.cc
浏览文件 @
4d810848
...
...
@@ -634,17 +634,14 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph,
return
s
;
}
// Rewrites shape inference graph for outside compilation.
// 1. If the outside compilation is a "top-level" one (not in a function of any
// If/While/etc.), this shape inference graph might have host computation to
// outside compilation placeholder nodes, which will cause shape inference to
// fail. However, those nodes are not in `host_graph` any more (because we
// have executed `PostprocessForEncapsultion`). In this case, we clear the
// graph, and copy SendFromHost with all its predecessors from `host_graph`.
// This case is detected by whether the SendFromHost node exists in
// `host_graph` as well.
// 2. Remove control edges, and prune nodes that are not useful for shape
// inference.
// Rewrites shape inference graph for outside compilation:
// 1) If XlaSendFromHost also exists in `host_graph`, copy nodes from
// `host_graph`. Because we might still have outside compilation to outside
// compilation placeholder nodes in shape inference graph, which will prevent
// us from inferring XlaSendFromHost shape. But in `host_graph`, we already
// removed those placeholder nodes.
// 2) Remove control edges.
// 3) Prune nodes that are not useful for shape inference.
Status
RewriteShapeInferenceGraph
(
const
string
&
shape_inference_graph_name
,
Graph
*
host_graph
,
FunctionLibraryDefinition
*
fld
)
{
...
...
@@ -744,6 +741,7 @@ Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name,
g
->
RemoveEdge
(
e
);
}
}
// Nodes that are not reverse reachable from SendFromHost are not useful for
// shape inference. Prune them.
PruneForReverseReachability
(
g
,
...
...
@@ -1581,14 +1579,6 @@ Status ExtractOutsideCompilation(
TF_RETURN_IF_ERROR
(
fld
->
RemoveFunction
(
host_graph_func_name
));
}
if
(
VLOG_IS_ON
(
4
))
{
dump_graph
::
DumpGraphToFile
(
"extract_outside_compilation_expanded"
,
*
g
,
fld
);
}
TF_RETURN_IF_ERROR
(
PostprocessForEncapsulation
(
g
,
xla_cluster_attr_name
,
outside_compilation_attr_name
,
clusters
));
for
(
auto
shape_inference_graph_name
:
shape_inference_graphs
)
{
TF_RETURN_IF_ERROR
(
RewriteShapeInferenceGraph
(
shape_inference_graph_name
,
g
,
fld
));
...
...
tensorflow/compiler/jit/shape_inference.cc
浏览文件 @
4d810848
...
...
@@ -53,7 +53,15 @@ Status PropagateShapes(const Graph& graph,
// shapes, even if no shape function is registered for a node.
Status
status
=
shape_refiner
->
AddNode
(
n
);
if
(
!
status
.
ok
())
{
VLOG
(
1
)
<<
"Shape inference failed for node: "
<<
status
;
VLOG
(
1
)
<<
"Shape inference failed for node "
<<
n
->
name
()
<<
": "
<<
status
;
}
else
{
shape_inference
::
InferenceContext
*
context
=
shape_refiner
->
GetContext
(
n
);
for
(
int
i
=
0
;
i
<
n
->
num_outputs
();
i
++
)
{
shape_inference
::
ShapeHandle
handle
=
context
->
output
(
i
);
VLOG
(
4
)
<<
"Output "
<<
i
<<
" for node "
<<
n
->
name
()
<<
": "
<<
context
->
DebugString
(
handle
);
}
}
if
(
n
->
type_string
()
==
"_Arg"
)
{
...
...
tensorflow/contrib/tpu/python/tpu/tpu.py
浏览文件 @
4d810848
...
...
@@ -646,6 +646,10 @@ def split_compile_and_replicate(computation,
array_ops
.
identity
(
x
,
name
=
"replicated_input_{}"
.
format
(
i
))
for
i
,
x
in
enumerate
(
computation_inputs
)
]
for
i
in
computation_inputs
:
# pylint: disable=protected-access
i
.
op
.
_set_attr
(
"_tpu_input_identity"
,
attr_value_pb2
.
AttrValue
(
b
=
True
))
# pylint: enable=protected-access
# If there is an infeed queue, adds the dequeued values to the
# computation's inputs.
...
...
@@ -726,7 +730,11 @@ def split_compile_and_replicate(computation,
new_output_tensors
=
[]
for
t
in
output_tensors
:
with
ops
.
device
(
t
.
device
if
t
.
device
else
core
(
0
)):
new_output_tensors
.
append
(
array_ops
.
identity
(
t
))
o
=
array_ops
.
identity
(
t
)
# pylint: disable=protected-access
o
.
op
.
_set_attr
(
"_tpu_output_identity"
,
attr_value_pb2
.
AttrValue
(
b
=
True
))
# pylint: enable=protected-access
new_output_tensors
.
append
(
o
)
output_tensors
=
new_output_tensors
context
.
ExitResult
(
output_tensors
)
finally
:
...
...
tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
浏览文件 @
4d810848
...
...
@@ -2280,7 +2280,7 @@ class TPUEstimator(estimator_lib.Estimator):
(
k
,
_export_output_to_tensors
(
v
))
for
k
,
v
in
six
.
iteritems
(
estimator_spec
.
export_outputs
))
tensors
=
nest
.
flatten
(
tensors_dict
)
tpu_tensors
=
[
t
for
t
in
tensors
if
_is_tpu_tensor
(
t
)
]
tpu_tensors
=
[
t
for
t
in
tensors
if
t
is
not
None
]
# We cannot return anything other than `tpu_tensors` here so we capture
# the rest for later use.
...
...
@@ -2294,18 +2294,10 @@ class TPUEstimator(estimator_lib.Estimator):
# `tpu_tensors_on_cpu`.
new_tensors
=
[]
for
t
in
tensors
:
if
_is_tpu_tensor
(
t
):
new_tensors
.
append
(
tpu_tensors_on_cpu
.
pop
(
0
))
elif
t
is
None
:
if
t
is
None
:
new_tensors
.
append
(
None
)
else
:
# Only fetching `tpu_tensors_on_cpu` does not trigger
# TPU computation and blocks, so we add the control dependency here.
control_inputs
=
(
tpu_tensors_on_cpu
if
_is_iterable
(
tpu_tensors_on_cpu
)
else
(
tpu_tensors_on_cpu
,))
with
ops
.
control_dependencies
(
control_inputs
):
new_tensors
.
append
(
array_ops
.
identity
(
t
))
new_tensors
.
append
(
tpu_tensors_on_cpu
.
pop
(
0
))
# Reconstruct `tensors_dict`.
new_tensors_dict
=
nest
.
pack_sequence_as
(
tensors_dict
,
new_tensors
)
...
...
@@ -2798,17 +2790,6 @@ class TPUEstimator(estimator_lib.Estimator):
return
_model_fn
def
_is_tpu_tensor
(
tensor
):
if
not
isinstance
(
tensor
,
ops
.
Tensor
):
return
False
try
:
tensor
.
op
.
get_attr
(
tpu
.
_OUTSIDE_COMPILATION_ATTR
)
# pylint: disable=protected-access
except
ValueError
:
return
True
else
:
return
False
def
_export_output_to_tensors
(
export_output
):
"""Get a list of `Tensors` used in `export_output`.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录