提交 d41d5121 编写于 作者: M Mehdi Amini 提交者: TensorFlower Gardener

Update third_party/tensorflow/compiler/mlir/tensorflow/utils/... to not depend...

Update third_party/tensorflow/compiler/mlir/tensorflow/utils/... to not depend on the global Dialect Registry (NFC)

PiperOrigin-RevId: 328171679
Change-Id: I3a4e40a04cb14b9c3d53239b31d3d642bb97daac
上级 116792db
......@@ -17,10 +17,13 @@ limitations under the License.
#define MLIR_HLO_DIALECT_MHLO_IR_REGISTER_H_
namespace mlir {
class DialectRegistry;
namespace mhlo {
void registerAllDialects();
// Add chlo, mhlo, lmhlo dialects to the provided registry.
void registerAllMhloDialects(DialectRegistry &registry);
}
} // namespace mlir
......
......@@ -31,3 +31,11 @@ void mlir::mhlo::registerAllDialects() {
// Dependent dialects
}
void mlir::mhlo::registerAllMhloDialects(mlir::DialectRegistry &registry) {
// clang-format off
registry.insert<mlir::chlo::HloClientDialect,
mlir::lmhlo::LmhloDialect,
mlir::mhlo::MhloDialect>();
// clang-format on
}
......@@ -1512,6 +1512,7 @@ COMPILE_MLIR_UTIL_DEPS = [
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"//tensorflow/compiler/mlir/hlo:hlo",
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
"//tensorflow/compiler/mlir/hlo:sink_constants_to_control_flow",
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
"//tensorflow/compiler/mlir/xla:type_to_shape",
......
......@@ -36,7 +36,9 @@ limitations under the License.
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/Passes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
......@@ -276,16 +278,9 @@ Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
return Status::OK();
}
static void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
mlir::registerDialect<mlir::shape::ShapeDialect>();
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
mlir::registerDialect<mlir::mhlo::MhloDialect>();
return true;
}();
(void)init_once;
static void RegisterDialects(mlir::DialectRegistry& registry) {
mlir::RegisterAllTensorFlowDialects(registry);
mlir::mhlo::registerAllMhloDialects(registry);
}
} // namespace
......@@ -418,9 +413,8 @@ Status CompileSerializedMlirToXlaHlo(
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result,
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
RegisterDialects();
mlir::MLIRContext mlir_context;
mlir_context.loadAllGloballyRegisteredDialects();
RegisterDialects(mlir_context.getDialectRegistry());
mlir::OwningModuleRef mlir_module;
TF_RETURN_IF_ERROR(
......@@ -507,10 +501,8 @@ Status CompileGraphToXlaHlo(
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
XlaCompilationResult* compilation_result,
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
RegisterDialects();
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
RegisterDialects(context.getDialectRegistry());
GraphImportConfig config;
config.graph_as_function = true;
// Disable shape inference during import as some TensorFlow op fails during
......
......@@ -20,6 +20,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
......@@ -33,17 +34,13 @@ limitations under the License.
namespace tensorflow {
namespace {
static void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
return true;
}();
(void)init_once;
static void RegisterDialects(mlir::MLIRContext &context) {
context.loadDialect<mlir::TF::TensorFlowDialect>();
}
TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
RegisterDialects(context);
mlir::Builder b(&context);
PartialTensorShape output_shape =
......@@ -53,7 +50,7 @@ TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) {
TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
RegisterDialects(context);
mlir::Builder b(&context);
PartialTensorShape output_shape = ConvertTypeToTensorShape(
......@@ -63,7 +60,7 @@ TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) {
TEST(ConvertTypeToTensorTypeTest, FullyDefinedRankedTensorType) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
RegisterDialects(context);
mlir::Builder b(&context);
PartialTensorShape output_shape = ConvertTypeToTensorShape(
......@@ -80,8 +77,8 @@ TEST(ConvertTypeToTensorTypeTest, ScalarTensorType) {
}
TEST(ConvertTypeToTensorTypeTest, ConvertStringTensor) {
RegisterDialects();
mlir::MLIRContext context;
RegisterDialects(context);
mlir::Builder b(&context);
// Create the sample tensor to convert.
......@@ -126,9 +123,8 @@ class ConvertTensorTest : public ::testing::Test {
};
TEST_F(ConvertTensorTest, Simple) {
RegisterDialects();
mlir::MLIRContext context;
RegisterDialects(context);
ASSERT_NO_FATAL_FAILURE(VerifyConversion<Eigen::half>(
{Eigen::half(1.0)}, DT_HALF, mlir::FloatType::getF16(&context)));
ASSERT_NO_FATAL_FAILURE(
......
......@@ -36,7 +36,6 @@ std::string ConvertToMlirString(const std::vector<int64_t>& dims,
}
mlir::MLIRContext context;
mlir::Builder b(&context);
context.loadAllGloballyRegisteredDialects();
auto status_or = ConvertToMlirTensorType(shape, dtype, &b);
std::string buf;
llvm::raw_string_ostream os(buf);
......
......@@ -60,7 +60,6 @@ class FakeDevice : public Device {
TEST(DeviceUtilTest, AddDeviceToOp) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
......@@ -102,7 +101,6 @@ TEST(DeviceUtilTest, AddDeviceToOp) {
TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
......@@ -112,7 +110,6 @@ TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) {
TEST(DeviceUtilTest, GetDevicesFromOpNoDevicesAttribute) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
......
......@@ -66,7 +66,6 @@ Status DumpTextualIRToFile(const MlirDumpConfig& config, const Graph& graph,
WritableFile* file) {
WritableFileRawStream os(std::move(file));
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module;
if (flib_def) {
flib_def = &graph.flib_def();
......
......@@ -28,7 +28,6 @@ namespace {
TEST(DumpMlirModuleTest, NoEnvPrefix) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
unsetenv("TF_DUMP_GRAPH_PREFIX");
......@@ -39,7 +38,6 @@ TEST(DumpMlirModuleTest, NoEnvPrefix) {
TEST(DumpMlirModuleTest, LogInfo) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
setenv("TF_DUMP_GRAPH_PREFIX", "-", 1);
......@@ -50,7 +48,6 @@ TEST(DumpMlirModuleTest, LogInfo) {
TEST(DumpMlirModuleTest, Valid) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
setenv("TF_DUMP_GRAPH_PREFIX", testing::TmpDir().c_str(), 1);
......
......@@ -29,7 +29,6 @@ using testing::HasSubstr;
TEST(ErrorUtilTest, StatusScopedDiagnosticHandler) {
MLIRContext context;
context.loadAllGloballyRegisteredDialects();
auto id = Identifier::get("test.cc", &context);
auto loc = FileLineColLoc::get(id, 0, 0, &context);
......
......@@ -602,7 +602,6 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) {
TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::Builder builder(&context);
auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3});
auto status_or_device_coodinates =
......@@ -616,7 +615,6 @@ TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) {
TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::Builder builder(&context);
auto device_assignment_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0});
auto status_or_device_coodinates =
......@@ -627,9 +625,8 @@ TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
}
TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::OpBuilder builder(module_ref->getBodyRegion());
......@@ -644,8 +641,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
}
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailModelParallelism) {
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::MLIRContext context;
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::OpBuilder builder(module_ref->getBodyRegion());
......@@ -665,8 +662,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailModelParallelism) {
}
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) {
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::MLIRContext context;
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::OpBuilder builder(module_ref->getBodyRegion());
......@@ -685,8 +682,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) {
}
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) {
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::MLIRContext context;
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::OpBuilder builder(module_ref->getBodyRegion());
......@@ -705,8 +702,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) {
}
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) {
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::MLIRContext context;
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::OpBuilder builder(module_ref->getBodyRegion());
......@@ -728,8 +725,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) {
}
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) {
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::MLIRContext context;
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::OpBuilder builder(module_ref->getBodyRegion());
......@@ -753,8 +750,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) {
}
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) {
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::MLIRContext context;
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::OpBuilder builder(module_ref->getBodyRegion());
......@@ -780,8 +777,8 @@ TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) {
}
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceNotReplicated) {
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::MLIRContext context;
context.loadDialect<mlir::tf_device::TensorFlowDeviceDialect>();
mlir::OwningModuleRef module_ref =
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
mlir::OpBuilder builder(module_ref->getBodyRegion());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册