提交 956c1b52 编写于 作者: J Jian Cai 提交者: TensorFlower Gardener

Refactor MlirBridgePass::Run to observe pass state before running MLIR ridge.

PiperOrigin-RevId: 481258105
上级 f1e2961e
......@@ -165,9 +165,8 @@ MlirOptimizationPassState MlirBridgePass::GetPassState(
const FunctionLibraryDefinition& function_library) const {
// Skip MLIR TF XLA Bridge if no TPU devices found and the non TPU graph is
// not qualified.
if (device_set && !HasTPUDevice(*device_set)) {
return EnableNonTpuBridge(graph) ? MlirOptimizationPassState::Enabled
: MlirOptimizationPassState::Disabled;
if (device_set && !HasTPUDevice(*device_set) && !EnableNonTpuBridge(graph)) {
return MlirOptimizationPassState::Disabled;
}
// We set `uses_uninitialized_resource_args` to false here because the first
......@@ -218,17 +217,14 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto,
// Check if there are TPU devices or TPU ops. If not, then check if the
// non TPU graph is qualified to run TF XLA Bridge.
// This check needs to precede GetPassState for instrumentation purposes.
if (!HasTPUDevicesAndOps(module)) {
if (EnableNonTpuBridge(graph)) {
VLOG(1) << "No TPU devices or TPU ops found, "
<< "this non TPU graph is qualified to run MLIR TF XLA Bridge";
return mlir::TF::RunTFXLABridge(module, VLOG_IS_ON(1));
} else {
VLOG(1) << " Skipping MLIR TF XLA Bridge,"
<< " no TPU devices or TPU ops found, and this non TPU graph"
<< " is not qualified to run MLIR TF XLA Bridge.";
return OkStatus();
}
bool is_qualified_for_tpu_bridge = HasTPUDevicesAndOps(module),
is_qualified_for_non_tpu_bridge = false;
if (!is_qualified_for_tpu_bridge)
is_qualified_for_non_tpu_bridge = EnableNonTpuBridge(graph);
if (!is_qualified_for_tpu_bridge && !is_qualified_for_non_tpu_bridge) {
VLOG(1)
<< "Skipping MLIR TF XLA Bridge, no qualified devices or ops found.";
return OkStatus();
}
// Set device_set to nullptr here as the device specific checks are performed
......@@ -239,23 +235,25 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto,
function_library);
if (pass_state == MlirOptimizationPassState::Disabled) {
// Currently the logging for handling the disabled case is in GetPassState
// because it is called directly before run() and run() will not be called
// if the pass is disabled. This logic is here defenseively in case the
// calling pass logic changes.
// GetPassState is called before run() and run() will only be called if the
// pass is not disabled. However, the graph may have been updated between
// when the pass state was originally calculated and now, so this check is
// required to reflect any possible changes.
VLOG(1) << "MlirBridgePass is disabled and will not run.";
return OkStatus();
}
bool fallback_enabled = false;
if (pass_state == MlirOptimizationPassState::FallbackEnabled)
fallback_enabled = true;
VLOG(1) << "Running MLIR TPU Bridge";
mlir_bridge_gauge_v2->GetCell()->Set(true);
return mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1),
fallback_enabled);
if (is_qualified_for_tpu_bridge) {
bool fallback_enabled = false;
if (pass_state == MlirOptimizationPassState::FallbackEnabled)
fallback_enabled = true;
VLOG(1) << "Running MLIR TPU Bridge";
mlir_bridge_gauge_v2->GetCell()->Set(true);
return mlir::TFTPU::TPUBridge(module, /*enable_logging=*/VLOG_IS_ON(1),
fallback_enabled);
}
VLOG(1) << "Running MLIR non-TPU Bridge";
return mlir::TF::RunTFXLABridge(module, VLOG_IS_ON(1));
}
MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState(
......@@ -323,10 +321,10 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
// Set device_set to nullptr here as the device specific checks are performed
// based on the devices in the module.
if (pass_state == MlirOptimizationPassState::Disabled) {
// Currently the logging for handling the disabled case is in GetPassState
// because it is called directly before run() and run() will not be called
// if the pass is disabled. This logic is here defenseively in case the
// calling pass logic changes.
// GetPassState is called before run() and run() will only be called if the
// pass is not disabled. However, the graph may have been updated between
// when the pass state was originally calculated and now, so this check is
// required to reflect any possible changes.
VLOG(1) << "Skipping MLIR TPU Bridge V1 Compat, session flag not enabled";
mlir_bridge_gauge_v1->GetCell()->Set(false);
return OkStatus();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册