提交 1c236a4b 编写于 作者: Y Yanan Cao 提交者: TensorFlower Gardener

Add tf_device dialect that models TensorFlow's actions of launching...

Add tf_device dialect that models TensorFlow's actions of launching computations on accelerator devices.

PiperOrigin-RevId: 262833506
上级 1149df16
......@@ -74,6 +74,30 @@ gentbl(
],
)
# Runs mlir-tblgen over the declarative op specs in ir/tf_device_ops.td to
# generate C++ declarations/definitions and reference docs for the tf_device
# dialect ops.
gentbl(
name = "tensorflow_device_ops_inc_gen",
tbl_outs = [
# Generated op class declarations, textually included by ir/tf_device.h.
(
"-gen-op-decls",
"ir/tf_device.h.inc",
),
# Generated op class definitions, textually included by ir/tf_device.cc.
(
"-gen-op-defs",
"ir/tf_device.cc.inc",
),
# Generated Markdown reference documentation for the dialect's ops.
(
"-gen-op-doc",
"g3doc/tf_device.md",
),
],
tblgen = "@local_config_mlir//:mlir-tblgen",
td_file = "ir/tf_device_ops.td",
# Extra .td files the main td_file transitively includes.
td_srcs = [
"@local_config_mlir//:include/mlir/IR/OpBase.td",
"@local_config_mlir//:include/mlir/StandardOps/Ops.td",
],
)
gentbl(
name = "tensorflow_canonicalize_inc_gen",
tbl_outs = [
......@@ -93,6 +117,7 @@ cc_library(
name = "tensorflow",
srcs = [
"ir/control_flow_ops.cc",
"ir/tf_device.cc",
"ir/tf_executor.cc",
"ir/tf_executor.cc.inc",
"ir/tf_executor.h.inc",
......@@ -111,6 +136,7 @@ cc_library(
],
hdrs = [
"ir/control_flow_ops.h",
"ir/tf_device.h",
"ir/tf_executor.h",
"ir/tf_ops.h",
"ir/tf_traits.h",
......@@ -121,6 +147,7 @@ cc_library(
includes = ["include"],
deps = [
":tensorflow_canonicalize_inc_gen",
":tensorflow_device_ops_inc_gen",
":tensorflow_executor_inc_gen",
":tensorflow_ops_inc_gen",
":tensorflow_optimize_inc_gen",
......
......@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/ir/control_flow_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
......@@ -25,5 +26,7 @@ static DialectRegistration<TFControlFlow::TFControlFlowDialect>
static DialectRegistration<TF::TensorFlowDialect> tf_ops;
static DialectRegistration<tf_executor::TensorFlowExecutorDialect>
tf_excutor_dialect;
static DialectRegistration<tf_device::TensorFlowDeviceDialect>
tf_device_dialect;
} // namespace mlir
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
namespace mlir {
namespace tf_device {
// Registers the "tf_device" dialect with the given MLIRContext and registers
// all of its tablegen-generated operations (LaunchOp, ReturnOp, LaunchFuncOp,
// per ir/tf_device_ops.td).
TensorFlowDeviceDialect::TensorFlowDeviceDialect(MLIRContext *context)
: Dialect(/*name=*/"tf_device", context) {
// GET_OP_LIST makes the generated .cc.inc expand to the comma-separated
// list of op classes expected by addOperations<>.
addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"
>();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc.inc"
} // namespace tf_device
} // namespace mlir
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the tf_device dialect: it contains operations that model
// TensorFlow's actions to launch computations on accelerator devices.
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Dialect.h" // TF:local_config_mlir
namespace mlir {
namespace tf_device {
// The TensorFlow Device dialect.
//
// This dialect contains operations to describe/launch computations on devices.
// These operations do not map 1-1 to TensorFlow ops and requires a lowering
// pass later to transform them into Compile/Run op pairs, like XlaCompile and
// XlaRun.
class TensorFlowDeviceDialect : public Dialect {
public:
// Constructs the TensorFlowDevice dialect in the given (non-null)
// MLIRContext and registers its operations with that context.
explicit TensorFlowDeviceDialect(MLIRContext *context);
};
// Declares the operations for this dialect using the generated header.
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h.inc"
} // namespace tf_device
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_DEVICE_H_
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the definition file for the TensorFlow Device Dialect.
#ifdef TF_DEVICE_DIALECT
#else
#define TF_DEVICE_DIALECT
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
//===----------------------------------------------------------------------===//
// TensorFlow Device Dialect definitions
//===----------------------------------------------------------------------===//
// Declarative definition of the tf_device dialect; mlir-tblgen uses the
// `name` as the op-name prefix ("tf_device.*") and emits the generated op
// classes into the `cppNamespace` below.
def TfDevice_Dialect : Dialect {
let name = "tf_device";
let description = [{
The TensorFlow Device dialect.
This dialect contains operations to describe/launch computations on devices.
These operations do not map 1-1 to TensorFlow ops and requires a lowering
pass later to transform them into Compile/Run op pairs, like XlaCompile and
XlaRun.
}];
let cppNamespace = "tf_device";
}
//===----------------------------------------------------------------------===//
// TensorFlow Device Dialect Ops definitions
//===----------------------------------------------------------------------===//
// Base class for the operation in this dialect.
// Base class for the operations in this dialect: binds every op to
// TfDevice_Dialect so definitions below only supply a mnemonic and traits.
class TfDevice_Op<string mnemonic, list<OpTrait> traits = []> :
Op<TfDevice_Dialect, mnemonic, traits> { }
// `tf_device.launch`: wraps a single-block region to be run on the device
// named by its `device` attribute. SingleBlockImplicitTerminator makes the
// region's block implicitly terminated by tf_device.return (ReturnOp).
def TfDevice_LaunchOp : TfDevice_Op<"launch", [SingleBlockImplicitTerminator<"ReturnOp">]> {
let summary = [{The `tf_device.launch` op captures all needed live-in values
and launches containing operations on target device.}];
// Target device, as a string attribute (e.g. a TF device name).
let arguments = (ins
StrAttr:$device
);
// Values yielded by the region's tf_device.return terminator.
let results = (outs
Variadic<AnyType>:$results
);
// Exactly one region with exactly one block.
let regions = (region SizedRegion<1>:$body);
// Convenience accessor over the generated `device()` attribute getter.
let extraClassDeclaration = [{
StringRef getDevice() { return device(); }
}];
}
// `tf_device.return`: terminator of a tf_device.launch region; forwards its
// operands as the results of the enclosing launch. HasParent restricts it to
// appear only directly inside a LaunchOp region.
def TfDevice_ReturnOp : TfDevice_Op<"return", [Terminator, HasParent<"LaunchOp">]> {
let summary = [{
The `tf_device.return` operation terminates and returns values from
`tf_device.launch` operation.
}];
// Values to forward to the parent launch op's results.
let arguments = (ins
Variadic<AnyType>:$results
);
// Extra builder for the common operand-less form, used when the terminator
// is implicitly created (see SingleBlockImplicitTerminator on LaunchOp).
let builders = [OpBuilder<
"Builder *builder, OperationState *result",
[{
build(builder, result, {});
}]>
];
// `?` drops the custom verifier; only trait-based verification applies.
let verifier = ?;
}
// `tf_device.launch_func`: calls a module-level function (by symbol) on the
// target device. Produced by outlining the body of a tf_device.launch into a
// standalone function (see the cluster outlining pass).
def TfDevice_LaunchFuncOp : TfDevice_Op<"launch_func", []> {
let summary = [{
The `tf_device.launch_func` launches a function on target device.
}];
// `device`: target device name; `func`: symbol reference to the outlined
// function; `operands`: live-in values forwarded as call arguments.
let arguments = (ins
StrAttr:$device,
SymbolRefAttr:$func,
Variadic<AnyType>:$operands);
// Results mirror the callee's results.
let results = (outs
Variadic<AnyType>:$results
);
// Convenience accessors over the generated attribute getters; getFuncType
// is implemented out-of-line (resolves the callee's FunctionType).
let extraClassDeclaration = [{
StringRef getFunc() { return func(); }
StringRef getDevice() { return device(); }
FunctionType getFuncType();
}];
}
#endif // TF_DEVICE_DIALECT
......@@ -26,6 +26,7 @@ limitations under the License.
#include "mlir/Pass/PassRegistry.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/Transforms/RegionUtils.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
......@@ -38,37 +39,37 @@ struct ClusterOutliningPass : public ModulePass<ClusterOutliningPass> {
void runOnModule() override;
};
void ReplaceLaunchReturnWithReturn(Operation* launch_return_op,
void ReplaceLaunchReturnWithReturn(tf_device::ReturnOp launch_return_op,
OpBuilder* builder) {
llvm::SmallVector<Value*, 4> operands(launch_return_op->getOperands());
builder->create<ReturnOp>(launch_return_op->getLoc(), operands);
launch_return_op->erase();
llvm::SmallVector<Value*, 4> operands(launch_return_op.getOperands());
builder->create<ReturnOp>(launch_return_op.getLoc(), operands);
launch_return_op.erase();
}
// Builds a function that outlines region attached to launch_op and inserts
// built function into given module.
FuncOp BuildFunction(StringRef device, llvm::ArrayRef<Value*> live_ins,
Operation* launch_op, ModuleManager* module_manager,
OpBuilder* builder) {
tf_device::LaunchOp launch_op,
ModuleManager* module_manager, OpBuilder* builder) {
llvm::SmallVector<Type, 4> operand_types;
operand_types.reserve(live_ins.size());
for (Value* v : live_ins) operand_types.emplace_back(v->getType());
llvm::SmallVector<Type, 4> result_types(launch_op->getResultTypes());
llvm::SmallVector<Type, 4> result_types(launch_op.getResultTypes());
auto func_type =
FunctionType::get(operand_types, result_types, builder->getContext());
std::string func_name_prefix = Twine(device, "_func").str();
FuncOp outlined_func =
FuncOp::create(launch_op->getLoc(), func_name_prefix, func_type);
FuncOp::create(launch_op.getLoc(), func_name_prefix, func_type);
// Create function body.
Block* outlined_func_block = outlined_func.addEntryBlock();
// Replace uses of live-in values within launch_op region with function
// arguments.
Region& launch_op_region = launch_op->getRegion(0);
Region& launch_op_region = launch_op.body();
for (const auto& p :
llvm::zip(live_ins, outlined_func_block->getArguments())) {
replaceAllUsesInRegionWith(std::get<0>(p), std::get<1>(p),
......@@ -83,7 +84,8 @@ FuncOp BuildFunction(StringRef device, llvm::ArrayRef<Value*> live_ins,
// Replace `tf_device.launch_return` terminator with `std.return` in function
// body.
Operation* launch_return_op = &outlined_func_block->back();
auto launch_return_op =
cast<tf_device::ReturnOp>(outlined_func_block->getTerminator());
builder->setInsertionPoint(launch_return_op);
ReplaceLaunchReturnWithReturn(launch_return_op, builder);
......@@ -91,53 +93,36 @@ FuncOp BuildFunction(StringRef device, llvm::ArrayRef<Value*> live_ins,
return outlined_func;
}
Operation* BuildLaunchFunc(const Location& loc, StringRef device, FuncOp func,
llvm::ArrayRef<Value*> live_ins,
OpBuilder* builder) {
// TODO(b/138909768): Define `tf_device.launch_func` and use its build method
// instead.
OperationState launch_func_op(loc, "tf_device.launch_func");
launch_func_op.addAttribute("device", builder->getStringAttr(device));
launch_func_op.addAttribute("func",
builder->getSymbolRefAttr(func.getName()));
launch_func_op.addTypes(func.getType().getResults());
llvm::SmallVector<Value*, 4> operands(live_ins.begin(), live_ins.end());
launch_func_op.addOperands(operands);
return builder->createOperation(launch_func_op);
}
// Outlines body of `tf_device.launch` into a function and create a
// `tf_device.launch_func` to invoke that function. `tf_device.launch` is
// removed afterwards.`
void OutlineLaunch(Operation* launch_op, ModuleManager* module_manager,
void OutlineLaunch(tf_device::LaunchOp launch_op, ModuleManager* module_manager,
OpBuilder* builder) {
llvm::SetVector<Value*> live_ins;
getUsedValuesDefinedAbove(launch_op->getRegion(0), launch_op->getRegion(0),
live_ins);
getUsedValuesDefinedAbove(launch_op.body(), launch_op.body(), live_ins);
StringRef device = launch_op->getAttrOfType<StringAttr>("device").getValue();
StringRef device = launch_op.getAttrOfType<StringAttr>("device").getValue();
FuncOp outlined_func = BuildFunction(device, live_ins.getArrayRef(),
launch_op, module_manager, builder);
builder->setInsertionPoint(launch_op);
Operation* launch_func_op =
BuildLaunchFunc(launch_op->getLoc(), device, outlined_func,
live_ins.getArrayRef(), builder);
launch_op->replaceAllUsesWith(launch_func_op);
launch_op->erase();
tf_device::LaunchFuncOp launch_func_op =
builder->create<tf_device::LaunchFuncOp>(
launch_op.getLoc(), outlined_func.getType().getResults(),
builder->getStringAttr(device),
builder->getSymbolRefAttr(outlined_func.getName()),
live_ins.getArrayRef());
launch_op.replaceAllUsesWith(launch_func_op);
launch_op.erase();
}
void ClusterOutliningPass::runOnModule() {
ModuleOp m = getModule();
ModuleManager module_manager(m);
OpBuilder builder(m.getContext());
m.walk([&](Operation* op) {
// TODO(b/138909768): Use templated Walk method instead of skipping
// operations according to their type string.
if (op->getName().getStringRef() != "tf_device.launch") return;
OutlineLaunch(op, &module_manager, &builder);
m.walk<tf_device::LaunchOp>([&](tf_device::LaunchOp launch) {
OutlineLaunch(launch, &module_manager, &builder);
});
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册