Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
10cab63f
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
10cab63f
编写于
12月 10, 2018
作者:
T
Tong Shen
提交者:
TensorFlower Gardener
12月 10, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Outside compilation in "If" and "While".
PiperOrigin-RevId: 224933587
上级
0d822c01
变更
9
展开全部
隐藏空白更改
内联
并排
Showing
9 changed file
with
1239 addition
and
175 deletion
+1239
-175
tensorflow/compiler/jit/BUILD
tensorflow/compiler/jit/BUILD
+3
-0
tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+121
-80
tensorflow/compiler/jit/extract_outside_compilation_pass.cc
tensorflow/compiler/jit/extract_outside_compilation_pass.cc
+711
-49
tensorflow/compiler/jit/extract_outside_compilation_pass.h
tensorflow/compiler/jit/extract_outside_compilation_pass.h
+3
-2
tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
...low/compiler/jit/extract_outside_compilation_pass_test.cc
+372
-37
tensorflow/compiler/tf2xla/kernels/while_op.cc
tensorflow/compiler/tf2xla/kernels/while_op.cc
+15
-7
tensorflow/compiler/tf2xla/side_effect_util.cc
tensorflow/compiler/tf2xla/side_effect_util.cc
+2
-0
tensorflow/compiler/tf2xla/side_effect_util.h
tensorflow/compiler/tf2xla/side_effect_util.h
+3
-0
tensorflow/compiler/tf2xla/tf2xla_util.cc
tensorflow/compiler/tf2xla/tf2xla_util.cc
+9
-0
未找到文件。
tensorflow/compiler/jit/BUILD
浏览文件 @
10cab63f
...
...
@@ -515,6 +515,7 @@ cc_library(
"//tensorflow/compiler/jit/ops:xla_ops"
,
"//tensorflow/compiler/tf2xla:dump_graph"
,
"//tensorflow/compiler/tf2xla:resource_operation_table"
,
"//tensorflow/compiler/tf2xla:side_effect_util"
,
"//tensorflow/compiler/tf2xla:tf2xla_util"
,
"//tensorflow/compiler/tf2xla:xla_compiler"
,
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops"
,
...
...
@@ -613,6 +614,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops"
,
"//tensorflow/cc:cc_ops_internal"
,
"//tensorflow/cc:function_ops"
,
"//tensorflow/cc:functional_ops"
,
"//tensorflow/cc:ops"
,
"//tensorflow/cc:resource_variable_ops"
,
"//tensorflow/cc:scope"
,
...
...
@@ -625,6 +627,7 @@ tf_cc_test(
"//tensorflow/compiler/tf2xla/cc:xla_ops"
,
"//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops"
,
"//tensorflow/compiler/tf2xla/kernels:xla_ops"
,
"//tensorflow/compiler/xla:test"
,
"//tensorflow/core:core_cpu"
,
"//tensorflow/core:framework"
,
"//tensorflow/core:framework_internal"
,
...
...
tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
浏览文件 @
10cab63f
此差异已折叠。
点击以展开。
tensorflow/compiler/jit/extract_outside_compilation_pass.cc
浏览文件 @
10cab63f
此差异已折叠。
点击以展开。
tensorflow/compiler/jit/extract_outside_compilation_pass.h
浏览文件 @
10cab63f
...
...
@@ -88,9 +88,10 @@ Status ExtractOutsideCompilationForFunction(
const
string
&
xla_cluster_attr_name
,
const
string
&
outside_compilation_attr_name
,
const
string
&
xla_cluster_name
,
const
NameAttrList
&
func_name_attrs
,
const
string
&
new_func_name
,
const
string
&
host_graph_func_name
,
const
std
::
map
<
string
,
int
>&
host_compute_core
,
FunctionLibraryDefinition
*
fld
,
std
::
unique_ptr
<
Graph
>*
host_graph
,
std
::
vector
<
string
>*
shape_inference_graphs
,
bool
*
has_outside_compilation
);
FunctionLibraryDefinition
*
fld
,
std
::
vector
<
string
>*
shape_inference_graphs
,
bool
*
has_outside_compilation
);
// Rewrites XLA computation in `clusters` to replace outside compilation nodes
// with XlaHostCompute, and moves those outside compilations into `g`. If shapes
...
...
tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc
浏览文件 @
10cab63f
此差异已折叠。
点击以展开。
tensorflow/compiler/tf2xla/kernels/while_op.cc
浏览文件 @
10cab63f
...
...
@@ -41,8 +41,7 @@ Status MakeXlaCompilerArgumentsFromInputs(
*
has_uninitialized_vars
=
false
;
*
has_tensor_arrays
=
false
;
for
(
int
i
=
0
;
i
<
ctx
->
num_inputs
();
++
i
)
{
VLOG
(
2
)
<<
" Input "
<<
i
<<
" type: "
<<
DataTypeString
(
ctx
->
input_type
(
i
))
VLOG
(
2
)
<<
" Input "
<<
i
<<
" type: "
<<
DataTypeString
(
ctx
->
input_type
(
i
))
<<
" shape: "
<<
ctx
->
InputShape
(
i
).
DebugString
();
XlaCompiler
::
Argument
&
arg
=
(
*
args
)[
i
];
DataType
type
=
ctx
->
input_type
(
i
);
...
...
@@ -233,13 +232,22 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
xla
::
ShapeUtil
::
HumanString
(
body_input_shape
),
" vs. "
,
xla
::
ShapeUtil
::
HumanString
(
body
.
xla_output_shape
)));
xla
::
Shape
expected_cond_output_shape
=
xla
::
ShapeUtil
::
MakeTupleShape
(
{
xla
::
ShapeUtil
::
MakeShape
(
xla
::
PRED
,
{})});
xla
::
Shape
expected_cond_output_shape_without_side_effect
=
xla
::
ShapeUtil
::
MakeTupleShape
(
{
xla
::
ShapeUtil
::
MakeShape
(
xla
::
PRED
,
{})});
xla
::
Shape
expected_cond_output_shape_with_side_effect
=
xla
::
ShapeUtil
::
MakeTupleShape
({
xla
::
ShapeUtil
::
MakeShape
(
xla
::
PRED
,
{}),
xla
::
ShapeUtil
::
MakeTokenShape
()});
OP_REQUIRES
(
ctx
,
xla
::
ShapeUtil
::
Compatible
(
cond
.
xla_output_shape
,
expected_cond_output_shape
),
xla
::
ShapeUtil
::
Compatible
(
cond
.
xla_output_shape
,
expected_cond_output_shape_without_side_effect
)
||
xla
::
ShapeUtil
::
Compatible
(
cond
.
xla_output_shape
,
expected_cond_output_shape_with_side_effect
),
errors
::
InvalidArgument
(
"Output shape of loop condition should be (pred[]), got: "
,
"Output shape of loop condition should be (pred[]) or "
"(pred[], token[]), got: "
,
xla
::
ShapeUtil
::
HumanString
(
cond
.
xla_output_shape
)));
int
num_inputs
=
body
.
input_mapping
.
size
();
...
...
tensorflow/compiler/tf2xla/side_effect_util.cc
浏览文件 @
10cab63f
...
...
@@ -24,6 +24,8 @@ const char kXlaTokenInputNodesAttrName[] = "_xla_token_input_nodes";
const
char
kXlaTokenArgNodeName
[]
=
"_xla_token_arg_node"
;
const
char
kXlaHasHostTransferAttrName
[]
=
"_xla_has_host_transfer"
;
std
::
set
<
std
::
string
>
CalculateTokenInputsForOutputToken
(
const
Graph
&
g
)
{
std
::
set
<
std
::
string
>
results
;
Node
*
first_side_effecting_node_on_path
=
nullptr
;
...
...
tensorflow/compiler/tf2xla/side_effect_util.h
浏览文件 @
10cab63f
...
...
@@ -35,6 +35,9 @@ extern const char kXlaTokenInputNodesAttrName[];
// node has side-effect dependency on current graph's token input.
extern
const
char
kXlaTokenArgNodeName
[];
// This node have XlaRecvAtHost/XlaSendFromHost in its associated functions.
extern
const
char
kXlaHasHostTransferAttrName
[];
// Calculates side-effect dependencies for the graph's token output.
// Returns a set of node names representing these dependencies.
std
::
set
<
std
::
string
>
CalculateTokenInputsForOutputToken
(
const
Graph
&
g
);
...
...
tensorflow/compiler/tf2xla/tf2xla_util.cc
浏览文件 @
10cab63f
...
...
@@ -557,6 +557,12 @@ bool HasAssociatedFunction(const NodeDef& node_def,
return
true
;
}
if
(
node_def
.
op
()
==
"XlaHostCompute"
)
{
// XlaHostCompute has "shape_inference_graph" func attr, but that's not
// related to graph execution.
return
false
;
}
for
(
const
auto
&
iter
:
node_def
.
attr
())
{
if
(
iter
.
second
.
has_func
())
{
return
true
;
...
...
@@ -578,6 +584,9 @@ std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
// This is a SymbolicGradient op.
AttrValueMap
attrs
(
node
.
attrs
().
begin
(),
node
.
attrs
().
end
());
results
.
emplace_back
(
AssociatedFunctionInfo
::
SymbolicGradient
(
op
,
attrs
));
}
else
if
(
node
.
type_string
()
==
"XlaHostCompute"
)
{
// XlaHostCompute has "shape_inference_graph" func attr, but that's not
// related to graph execution.
}
else
{
// Collect all function attrs for the node.
for
(
auto
&
iter
:
node
.
attrs
())
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录