提交 2d0592a0 编写于 作者: M Mehdi Amini 提交者: TensorFlower Gardener

Remove dependency on Dialect global registration from //tensorflow/compiler/mlir/tfrt/...

PiperOrigin-RevId: 328168585
Change-Id: I702964790826638b194eb7c5b21f98492e90c727
上级 db4bab7c
......@@ -133,7 +133,6 @@ cc_library(
":hlo_utils",
":mlir_hlo_to_hlo",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration",
"//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor",
......@@ -334,6 +333,7 @@ cc_library(
":mlir_hlo_to_hlo",
"//tensorflow/compiler/jit:xla_cpu_jit",
"//tensorflow/compiler/jit:xla_gpu_jit",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
......@@ -342,6 +342,7 @@ cc_library(
"//tensorflow/core:lib",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Translation",
],
alwayslink = 1,
......
......@@ -42,13 +42,13 @@ class XlaBuilderTest : public ::testing::Test {
protected:
XlaBuilderTest()
: name_(SetupTest()),
context_(),
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_))),
builder_(&module_->getBodyRegion()),
xla_builder_(name_, builder_, module_->getLoc()) {}
xla_builder_(name_, builder_, module_->getLoc()) {
context_.loadDialect<mlir::mhlo::MhloDialect>();
}
string SetupTest() {
mlir::registerDialect<mlir::mhlo::MhloDialect>();
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}
......
......@@ -2,6 +2,6 @@
// CHECK: Opaque elements attr not supported
func @main() {
%0 = "tf.Const"() {value = opaque<"tf", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32>
%0 = "mhlo.constant"() {value = opaque<"mhlo", "0x0123456789ABCDEF"> : tensor<4xf32>} : () -> tensor<4xf32>
return
}
......@@ -69,6 +69,10 @@ namespace {
constexpr char kShardingAttr[] = "mhlo.sharding";
class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<chlo::HloClientDialect, mhlo::MhloDialect>();
}
public:
LegalizeTF() = default;
LegalizeTF(const LegalizeTF &) {}
......
......@@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/IR/AffineMap.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
......@@ -34,6 +35,8 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassOptions.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
......@@ -133,6 +136,11 @@ Status ConvertModule(std::unique_ptr<HloModule> hlo_module, ModuleOp module,
// MLIR LHLO.
class XlaHloToLhloPass
: public PassWrapper<XlaHloToLhloPass, OperationPass<ModuleOp>> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect,
mlir::lmhlo::LmhloDialect>();
}
public:
XlaHloToLhloPass() = default;
XlaHloToLhloPass(const XlaHloToLhloPass&) {}
......@@ -438,7 +446,7 @@ Status LhloDialectEmitter::Initialize() {
builder_.setInsertionPointToEnd(block);
auto return_op = builder_.create<ReturnOp>(builder_.getUnknownLoc());
builder_ = mlir::OpBuilder(return_op);
builder_ = OpBuilder(return_op);
return Status::OK();
}
......@@ -449,6 +457,9 @@ std::unique_ptr<OperationPass<ModuleOp>> createXlaHloToLhloWithXlaPass() {
Status HloToLhloModule(const BufferAssignment& assignment,
const HloModule& hlo_module, ModuleOp module) {
module.getContext()
->loadDialect<StandardOpsDialect, mhlo::MhloDialect,
lmhlo::LmhloDialect>();
HloComputation* computation = hlo_module.entry_computation();
LhloDialectEmitter emitter(assignment, *computation, module);
......@@ -462,15 +473,14 @@ Status HloToLhloModule(const BufferAssignment& assignment,
return computation->AcceptOrdered(&emitter, ordering);
}
mlir::OwningModuleRef HloTextToLhloTranslateFunction(
llvm::StringRef input, mlir::MLIRContext* context) {
OwningModuleRef HloTextToLhloTranslateFunction(llvm::StringRef input,
MLIRContext* context) {
StatusOr<std::unique_ptr<HloModule>> maybe_module =
xla::ParseAndReturnUnverifiedModule(
absl::string_view(input.data(), input.size()));
TF_CHECK_OK(maybe_module.status());
mlir::OwningModuleRef module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
OwningModuleRef module = ModuleOp::create(UnknownLoc::get(context));
TF_CHECK_OK(
ConvertModule(maybe_module.ConsumeValueOrDie(), module.get(), "Host"));
......
......@@ -64,7 +64,6 @@ inline ::testing::PolymorphicMatcher<ProtoStringMatcher> EqualsProto(
TEST(TypeToShapeTest, ConvertPrimitiveTypes) {
MLIRContext context;
context.loadAllGloballyRegisteredDialects();
Builder b(&context);
EXPECT_EQ(TypeToPrimitiveType(b.getF32Type()), PrimitiveType::F32);
......@@ -75,7 +74,6 @@ TEST(TypeToShapeTest, ConvertPrimitiveTypes) {
TEST(TypeToShapeTest, ConvertBasicTypesToTypes) {
MLIRContext context;
context.loadAllGloballyRegisteredDialects();
Builder b(&context);
EXPECT_TRUE(
......@@ -97,7 +95,6 @@ TEST(TypeToShapeTest, ConvertBasicTypesToTypes) {
TEST(TypeToShapeTest, ConvertMemRefTypeToTypes) {
MLIRContext context;
context.loadAllGloballyRegisteredDialects();
Builder b(&context);
// Memref without any affine map. Note: memory space is ignored for shape.
......
......@@ -17,8 +17,11 @@ limitations under the License.
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MemoryBuffer.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
......@@ -173,11 +176,17 @@ static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
} // namespace xla
static void RegisterInputDialects(mlir::DialectRegistry& registry) {
registry.insert<mlir::StandardOpsDialect, mlir::mhlo::MhloDialect>();
}
static mlir::TranslateFromMLIRRegistration MlirHloToHloTranslate(
"mlir-hlo-to-hlo", xla::MlirHloToHloTranslateFunction);
"mlir-hlo-to-hlo", xla::MlirHloToHloTranslateFunction,
RegisterInputDialects);
static mlir::TranslateFromMLIRRegistration MlirHloToHloTextTranslate(
"mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction);
"mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction,
RegisterInputDialects);
static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate(
"hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册