Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wux_labs
Tensorflow
提交
956c1b52
T
Tensorflow
项目概览
wux_labs
/
Tensorflow
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
956c1b52
编写于
10月 14, 2022
作者:
J
Jian Cai
提交者:
TensorFlower Gardener
10月 14, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor MlirBridgePass::Run to observe pass state before running MLIR ridge.
PiperOrigin-RevId: 481258105
上级
f1e2961e
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
29 addition
and
31 deletion
+29
-31
tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
+29
-31
未找到文件。
tensorflow/compiler/tf2xla/mlir_bridge_pass.cc
浏览文件 @
956c1b52
...
...
@@ -165,9 +165,8 @@ MlirOptimizationPassState MlirBridgePass::GetPassState(
const
FunctionLibraryDefinition
&
function_library
)
const
{
// Skip MLIR TF XLA Bridge if no TPU devices found and the non TPU graph is
// not qualified.
if
(
device_set
&&
!
HasTPUDevice
(
*
device_set
))
{
return
EnableNonTpuBridge
(
graph
)
?
MlirOptimizationPassState
::
Enabled
:
MlirOptimizationPassState
::
Disabled
;
if
(
device_set
&&
!
HasTPUDevice
(
*
device_set
)
&&
!
EnableNonTpuBridge
(
graph
))
{
return
MlirOptimizationPassState
::
Disabled
;
}
// We set `uses_uninitialized_resource_args` to false here because the first
...
...
@@ -218,17 +217,14 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto,
// Check if there are TPU devices or TPU ops. If not, then check if the
// non TPU graph is qualified to run TF XLA Bridge.
// This check needs to precede GetPassState for instrumentation purposes.
if
(
!
HasTPUDevicesAndOps
(
module
))
{
if
(
EnableNonTpuBridge
(
graph
))
{
VLOG
(
1
)
<<
"No TPU devices or TPU ops found, "
<<
"this non TPU graph is qualified to run MLIR TF XLA Bridge"
;
return
mlir
::
TF
::
RunTFXLABridge
(
module
,
VLOG_IS_ON
(
1
));
}
else
{
VLOG
(
1
)
<<
" Skipping MLIR TF XLA Bridge,"
<<
" no TPU devices or TPU ops found, and this non TPU graph"
<<
" is not qualified to run MLIR TF XLA Bridge."
;
return
OkStatus
();
}
bool
is_qualified_for_tpu_bridge
=
HasTPUDevicesAndOps
(
module
),
is_qualified_for_non_tpu_bridge
=
false
;
if
(
!
is_qualified_for_tpu_bridge
)
is_qualified_for_non_tpu_bridge
=
EnableNonTpuBridge
(
graph
);
if
(
!
is_qualified_for_tpu_bridge
&&
!
is_qualified_for_non_tpu_bridge
)
{
VLOG
(
1
)
<<
"Skipping MLIR TF XLA Bridge, no qualified devices or ops found."
;
return
OkStatus
();
}
// Set device_set to nullptr here as the device specific checks are performed
...
...
@@ -239,23 +235,25 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto,
function_library
);
if
(
pass_state
==
MlirOptimizationPassState
::
Disabled
)
{
//
Currently the logging for handling the disabled case is in GetPassStat
e
//
because it is called directly before run() and run() will not be called
//
if the pass is disabled. This logic is here defenseively in case the
//
calling pass logic
changes.
//
GetPassState is called before run() and run() will only be called if th
e
//
pass is not disabled. However, the graph may have been updated between
//
when the pass state was originally calculated and now, so this check is
//
required to reflect any possible
changes.
VLOG
(
1
)
<<
"MlirBridgePass is disabled and will not run."
;
return
OkStatus
();
}
bool
fallback_enabled
=
false
;
if
(
pass_state
==
MlirOptimizationPassState
::
FallbackEnabled
)
fallback_enabled
=
true
;
VLOG
(
1
)
<<
"Running MLIR TPU Bridge"
;
mlir_bridge_gauge_v2
->
GetCell
()
->
Set
(
true
);
return
mlir
::
TFTPU
::
TPUBridge
(
module
,
/*enable_logging=*/
VLOG_IS_ON
(
1
),
fallback_enabled
);
if
(
is_qualified_for_tpu_bridge
)
{
bool
fallback_enabled
=
false
;
if
(
pass_state
==
MlirOptimizationPassState
::
FallbackEnabled
)
fallback_enabled
=
true
;
VLOG
(
1
)
<<
"Running MLIR TPU Bridge"
;
mlir_bridge_gauge_v2
->
GetCell
()
->
Set
(
true
);
return
mlir
::
TFTPU
::
TPUBridge
(
module
,
/*enable_logging=*/
VLOG_IS_ON
(
1
),
fallback_enabled
);
}
VLOG
(
1
)
<<
"Running MLIR non-TPU Bridge"
;
return
mlir
::
TF
::
RunTFXLABridge
(
module
,
VLOG_IS_ON
(
1
));
}
MlirOptimizationPassState
MlirBridgeV1CompatPass
::
GetPassState
(
...
...
@@ -323,10 +321,10 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
// Set device_set to nullptr here as the device specific checks are performed
// based on the devices in the module.
if
(
pass_state
==
MlirOptimizationPassState
::
Disabled
)
{
//
Currently the logging for handling the disabled case is in GetPassStat
e
//
because it is called directly before run() and run() will not be called
//
if the pass is disabled. This logic is here defenseively in case the
//
calling pass logic
changes.
//
GetPassState is called before run() and run() will only be called if th
e
//
pass is not disabled. However, the graph may have been updated between
//
when the pass state was originally calculated and now, so this check is
//
required to reflect any possible
changes.
VLOG
(
1
)
<<
"Skipping MLIR TPU Bridge V1 Compat, session flag not enabled"
;
mlir_bridge_gauge_v1
->
GetCell
()
->
Set
(
false
);
return
OkStatus
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录