Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
f83f6b9e
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
f83f6b9e
编写于
9月 08, 2017
作者:
C
Chris Leary
提交者:
TensorFlower Gardener
9月 08, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XLA] Handle higher-order HLOs (e.g. While) in CallInliner and test.
PiperOrigin-RevId: 168029345
上级
8988ae36
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
98 addition
and
24 deletion
+98
-24
tensorflow/compiler/xla/service/call_inliner.cc
tensorflow/compiler/xla/service/call_inliner.cc
+59
-24
tensorflow/compiler/xla/service/call_inliner_test.cc
tensorflow/compiler/xla/service/call_inliner_test.cc
+39
-0
未找到文件。
tensorflow/compiler/xla/service/call_inliner.cc
浏览文件 @
f83f6b9e
...
...
@@ -20,30 +20,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
namespace
xla
{
StatusOr
<
bool
>
CallInliner
::
Run
(
HloModule
*
module
)
{
std
::
deque
<
HloInstruction
*>
work_queue
;
// Seed the work queue with call instructions from the main computation.
TF_RETURN_IF_ERROR
(
module
->
entry_computation
()
->
Accept
([
&
](
HloInstruction
*
hlo
)
{
if
(
hlo
->
opcode
()
==
HloOpcode
::
kCall
)
{
work_queue
.
push_back
(
hlo
);
}
return
Status
::
OK
();
}));
VLOG
(
1
)
<<
"Work queue seeded with "
<<
work_queue
.
size
()
<<
" entries."
;
bool
mutated
=
false
;
while
(
!
work_queue
.
empty
())
{
mutated
=
true
;
HloInstruction
*
call
=
work_queue
.
front
();
work_queue
.
pop_front
();
TF_RETURN_IF_ERROR
(
ReplaceWithInlinedBody
(
call
,
&
work_queue
));
}
return
mutated
;
}
namespace
{
// Traverses the callee computation, inlining cloned nodes into the caller
// computation and connecting them to producers/consumers appropriately.
...
...
@@ -141,6 +118,64 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
std
::
deque
<
HloInstruction
*>*
work_queue_
;
};
}
// namespace
StatusOr
<
bool
>
CallInliner
::
Run
(
HloModule
*
module
)
{
std
::
deque
<
HloInstruction
*>
work_queue
;
tensorflow
::
gtl
::
FlatSet
<
HloComputation
*>
seen
;
auto
scan_computation
=
[
&
work_queue
,
&
seen
](
HloComputation
*
computation
)
->
Status
{
if
(
!
seen
.
insert
(
computation
).
second
)
{
return
Status
::
OK
();
// Already seen.
}
return
computation
->
Accept
([
&
](
HloInstruction
*
hlo
)
{
if
(
!
hlo
->
called_computations
().
empty
())
{
work_queue
.
push_back
(
hlo
);
}
return
Status
::
OK
();
});
};
// Seed the work queue with call instructions from the main computation.
TF_RETURN_IF_ERROR
(
scan_computation
(
module
->
entry_computation
()));
VLOG
(
1
)
<<
"Work queue seeded with "
<<
work_queue
.
size
()
<<
" entries."
;
bool
mutated
=
false
;
while
(
!
work_queue
.
empty
())
{
HloInstruction
*
caller
=
work_queue
.
front
();
work_queue
.
pop_front
();
switch
(
caller
->
opcode
())
{
case
HloOpcode
::
kCall
:
mutated
=
true
;
TF_RETURN_IF_ERROR
(
ReplaceWithInlinedBody
(
caller
,
&
work_queue
));
break
;
case
HloOpcode
::
kWhile
:
TF_RETURN_IF_ERROR
(
scan_computation
(
caller
->
while_condition
()));
TF_RETURN_IF_ERROR
(
scan_computation
(
caller
->
while_body
()));
break
;
case
HloOpcode
::
kSelectAndScatter
:
TF_RETURN_IF_ERROR
(
scan_computation
(
caller
->
select
()));
TF_RETURN_IF_ERROR
(
scan_computation
(
caller
->
scatter
()));
break
;
case
HloOpcode
::
kMap
:
case
HloOpcode
::
kReduceWindow
:
case
HloOpcode
::
kReduce
:
TF_RETURN_IF_ERROR
(
scan_computation
(
caller
->
to_apply
()));
break
;
case
HloOpcode
::
kFusion
:
// Fusion nodes don't represent true calls, but instead delimit a
// boundary for the backend-specific fusion capabilities.
break
;
default:
return
Unimplemented
(
"Unknown higher-order HLO opcode: %s"
,
caller
->
ToString
().
c_str
());
}
}
return
mutated
;
}
Status
CallInliner
::
ReplaceWithInlinedBody
(
HloInstruction
*
call
,
std
::
deque
<
HloInstruction
*>*
work_queue
)
{
TF_RET_CHECK
(
call
->
opcode
()
==
HloOpcode
::
kCall
);
...
...
tensorflow/compiler/xla/service/call_inliner_test.cc
浏览文件 @
f83f6b9e
...
...
@@ -73,5 +73,44 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
EXPECT_EQ
(
prior
->
literal
().
GetFirstElement
<
float
>
(),
24
);
}
// Tests for referential transparency (a function that calls a function that
// returns false should be identical to just returning false).
TEST_F
(
CallInlinerTest
,
CallsWithinWhileBodiesAreInlined
)
{
const
Shape
pred
=
ShapeUtil
::
MakeShape
(
PRED
,
{});
auto
module
=
CreateNewModule
();
// Create a lambda that calls a function that returns the false predicate.
// Note we also use this lambda twice by reference, just to make the test a
// little trickier.
HloComputation
::
Builder
just_false
(
TestName
()
+
".false"
);
just_false
.
AddInstruction
(
HloInstruction
::
CreateConstant
(
Literal
::
CreateR0
<
bool
>
(
false
)));
HloComputation
*
false_computation
=
module
->
AddEmbeddedComputation
(
just_false
.
Build
());
HloComputation
::
Builder
call_false_builder
(
TestName
()
+
".call_false"
);
call_false_builder
.
AddInstruction
(
HloInstruction
::
CreateCall
(
pred
,
{},
false_computation
));
HloComputation
*
call_false
=
module
->
AddEmbeddedComputation
(
call_false_builder
.
Build
());
HloComputation
::
Builder
outer
(
TestName
()
+
".outer"
);
HloInstruction
*
init_value
=
outer
.
AddInstruction
(
HloInstruction
::
CreateConstant
(
Literal
::
CreateR0
<
bool
>
(
false
)));
outer
.
AddInstruction
(
HloInstruction
::
CreateWhile
(
pred
,
call_false
,
call_false
,
init_value
));
auto
computation
=
module
->
AddEntryComputation
(
outer
.
Build
());
CallInliner
call_inliner
;
TF_ASSERT_OK_AND_ASSIGN
(
bool
mutated
,
call_inliner
.
Run
(
module
.
get
()));
ASSERT_TRUE
(
mutated
);
EXPECT_THAT
(
computation
->
root_instruction
()
->
while_condition
()
->
root_instruction
(),
op
::
Constant
());
EXPECT_THAT
(
computation
->
root_instruction
()
->
while_body
()
->
root_instruction
(),
op
::
Constant
());
}
}
// namespace
}
// namespace xla
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录