Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
c51eb496
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c51eb496
编写于
7月 16, 2019
作者:
T
Tong Shen
提交者:
TensorFlower Gardener
7月 16, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Pass attributes when lowering functional If/While.
PiperOrigin-RevId: 258460978
上级
ca7acecc
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
35 addition
and
23 deletion
+35
-23
tensorflow/core/common_runtime/lower_if_op.cc
tensorflow/core/common_runtime/lower_if_op.cc
+17
-10
tensorflow/core/common_runtime/lower_while_op.cc
tensorflow/core/common_runtime/lower_while_op.cc
+18
-13
未找到文件。
tensorflow/core/common_runtime/lower_if_op.cc
浏览文件 @
c51eb496
...
...
@@ -38,11 +38,12 @@ class CondBuilder {
enum
Branch
{
kElseBranch
=
0
,
kThenBranch
=
1
};
// Create a CondBuilder to create the lowered form of `if_op` with then and
// else functions named `then_fn_name` and `else_fn_name` respectively in the
// `graph`. The functions should be available in `flib`.
CondBuilder
(
Node
*
if_op
,
const
string
&
then_fn_name
,
const
string
&
else_fn_name
,
const
FunctionLibraryDefinition
&
flib
,
bool
keep_node_fetchable
,
Graph
*
graph
);
// else functions `then_fn` and `else_fn` respectively in the `graph`. The
// functions should be available in `flib`.
CondBuilder
(
Node
*
if_op
,
const
NameAttrList
&
then_fn
,
const
NameAttrList
&
else_fn
,
const
FunctionLibraryDefinition
&
flib
,
bool
keep_node_fetchable
,
Graph
*
graph
);
// Constructs the basic conditional control flow using switch and merge nodes.
Status
CreatePivotNodes
();
...
...
@@ -103,8 +104,8 @@ class CondBuilder {
NodeBuilder
else_call_builder_
;
};
CondBuilder
::
CondBuilder
(
Node
*
if_op
,
const
string
&
then_fn_name
,
const
string
&
else_fn_name
,
CondBuilder
::
CondBuilder
(
Node
*
if_op
,
const
NameAttrList
&
then_fn
,
const
NameAttrList
&
else_fn
,
const
FunctionLibraryDefinition
&
flib
,
bool
keep_node_fetchable
,
Graph
*
graph
)
:
if_op_
(
if_op
),
...
...
@@ -113,15 +114,21 @@ CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
name_
(
if_op
->
name
()),
keep_node_fetchable_
(
keep_node_fetchable
),
debug_info_
(
*
if_op_
),
then_call_builder_
(
NewName
(
"then"
),
then_fn
_name
,
graph
->
op_registry
(),
then_call_builder_
(
NewName
(
"then"
),
then_fn
.
name
()
,
graph
->
op_registry
(),
&
debug_info_
),
else_call_builder_
(
NewName
(
"else"
),
else_fn
_name
,
graph
->
op_registry
(),
else_call_builder_
(
NewName
(
"else"
),
else_fn
.
name
()
,
graph
->
op_registry
(),
&
debug_info_
)
{
TF_CHECK_OK
(
if_op_
->
input_tensor
(
0
,
&
pred_
));
then_call_builder_
.
Device
(
if_op_
->
requested_device
());
then_call_builder_
.
Attr
(
kLowerAsMultiDeviceFunctionAttr
,
true
);
for
(
const
auto
&
i
:
then_fn
.
attr
())
{
then_call_builder_
.
Attr
(
i
.
first
,
i
.
second
);
}
else_call_builder_
.
Device
(
if_op_
->
requested_device
());
else_call_builder_
.
Attr
(
kLowerAsMultiDeviceFunctionAttr
,
true
);
for
(
const
auto
&
i
:
else_fn
.
attr
())
{
else_call_builder_
.
Attr
(
i
.
first
,
i
.
second
);
}
}
Status
CondBuilder
::
CreatePivotNodes
()
{
...
...
@@ -279,7 +286,7 @@ Status RewriteIfNode(Node* n, Graph* g, const FunctionLibraryDefinition& flib,
return
errors
::
InvalidArgument
(
"Else branch function missing"
);
}
CondBuilder
cb
(
n
,
then_attr
->
func
()
.
name
(),
else_attr
->
func
().
name
(),
flib
,
CondBuilder
cb
(
n
,
then_attr
->
func
()
,
else_attr
->
func
(),
flib
,
keep_node_fetchable
,
g
);
TF_RETURN_IF_ERROR
(
cb
.
CreatePivotNodes
());
TF_RETURN_IF_ERROR
(
cb
.
AddInputs
());
...
...
tensorflow/core/common_runtime/lower_while_op.cc
浏览文件 @
c51eb496
...
...
@@ -56,13 +56,12 @@ constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
// consumer
class
LowerWhileHelper
{
public:
static
Status
Run
(
Node
*
while_op
,
const
string
&
cond_fn_name
,
const
string
&
body_fn_name
,
int
parallel_iterations
,
static
Status
Run
(
Node
*
while_op
,
const
NameAttrList
&
cond_fn
,
const
NameAttrList
&
body_fn
,
int
parallel_iterations
,
Graph
*
graph
,
const
FunctionLibraryDefinition
&
flib
,
bool
keep_node_fetchable
)
{
LowerWhileHelper
helper
(
while_op
,
cond_fn_name
,
body_fn_name
,
parallel_iterations
,
graph
,
flib
,
keep_node_fetchable
);
LowerWhileHelper
helper
(
while_op
,
cond_fn
,
body_fn
,
parallel_iterations
,
graph
,
flib
,
keep_node_fetchable
);
return
helper
.
RunInternal
();
}
...
...
@@ -70,8 +69,8 @@ class LowerWhileHelper {
// Create a LowerWhileHelper to create the lowering of While op that has cond
// and body functions named `cond_fn_name` and `body_fn_name` respectively in
// the given graph.
LowerWhileHelper
(
Node
*
while_op
,
const
string
&
cond_fn_name
,
const
string
&
body_fn_name
,
int
parallel_iterations
,
LowerWhileHelper
(
Node
*
while_op
,
const
NameAttrList
&
cond_fn
,
const
NameAttrList
&
body_fn
,
int
parallel_iterations
,
Graph
*
graph
,
const
FunctionLibraryDefinition
&
flib
,
bool
keep_node_fetchable
);
...
...
@@ -157,8 +156,8 @@ class LowerWhileHelper {
size_t
num_loop_inputs_
;
};
LowerWhileHelper
::
LowerWhileHelper
(
Node
*
while_op
,
const
string
&
cond_fn_name
,
const
string
&
body_fn_name
,
LowerWhileHelper
::
LowerWhileHelper
(
Node
*
while_op
,
const
NameAttrList
&
cond_fn
,
const
NameAttrList
&
body_fn
,
int
parallel_iterations
,
Graph
*
graph
,
const
FunctionLibraryDefinition
&
flib
,
bool
keep_node_fetchable
)
...
...
@@ -169,13 +168,19 @@ LowerWhileHelper::LowerWhileHelper(Node* while_op, const string& cond_fn_name,
parallel_iterations_
(
parallel_iterations
),
keep_node_fetchable_
(
keep_node_fetchable
),
debug_info_
(
*
while_op_
),
cond_call_builder_
(
NewName
(
"cond"
),
cond_fn
_name
,
graph
->
op_registry
(),
cond_call_builder_
(
NewName
(
"cond"
),
cond_fn
.
name
()
,
graph
->
op_registry
(),
&
debug_info_
),
body_call_builder_
(
NewName
(
"body"
),
body_fn
_name
,
graph
->
op_registry
(),
body_call_builder_
(
NewName
(
"body"
),
body_fn
.
name
()
,
graph
->
op_registry
(),
&
debug_info_
),
num_loop_inputs_
(
while_op_
->
num_inputs
())
{
cond_call_builder_
.
Attr
(
kLowerAsMultiDeviceFunctionAttr
,
true
);
for
(
const
auto
&
i
:
cond_fn
.
attr
())
{
cond_call_builder_
.
Attr
(
i
.
first
,
i
.
second
);
}
body_call_builder_
.
Attr
(
kLowerAsMultiDeviceFunctionAttr
,
true
);
for
(
const
auto
&
i
:
body_fn
.
attr
())
{
body_call_builder_
.
Attr
(
i
.
first
,
i
.
second
);
}
// We intentionally `resize` instead of `reserve` space in `enter_nodes_`
// because we need to set it's elements out of order in `CreateEnterNodes`.
enter_nodes_
.
resize
(
num_loop_inputs_
);
...
...
@@ -432,8 +437,8 @@ Status RewriteWhileNode(Node* n, Graph* g,
}
TF_RETURN_IF_ERROR
(
LowerWhileHelper
::
Run
(
n
,
cond_attr
->
func
()
.
name
(),
body_attr
->
func
().
name
()
,
parallel_iterations_attr
->
i
(),
g
,
flib
,
keep_node_fetchable
));
n
,
cond_attr
->
func
()
,
body_attr
->
func
(),
parallel_iterations_attr
->
i
(),
g
,
flib
,
keep_node_fetchable
));
g
->
RemoveNode
(
n
);
return
Status
::
OK
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录