提交 bcfb691a 编写于 作者: M Mehdi Amini 提交者: TensorFlower Gardener

Refactor `kernel-gen-opt.cc` and remove dependency on global dialect registration (NFC)

PiperOrigin-RevId: 328193995
Change-Id: I189e4c8f121df487521138914e29c1f46890f717
上级 5ce52adc
......@@ -21,7 +21,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LLVMDialect",
......@@ -65,8 +65,10 @@ tf_cc_binary(
srcs = ["tools/kernel-gen-opt/kernel-gen-opt.cc"],
visibility = ["//tensorflow/compiler/mlir/tools/kernel_gen/tests:__pkg__"],
deps = [
"//tensorflow/compiler/mlir/hlo:all_passes",
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
"//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
......
......@@ -46,6 +46,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
......@@ -246,22 +247,14 @@ Status PropagateTensorFlowABIKnowledgeToKernel(
return Status::OK();
}
void RegisterDialects() {
static bool init_once = []() {
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
return true;
}();
(void)init_once;
}
} // namespace
StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
llvm::StringRef tf_code, std::pair<int32_t, int32_t> compute_capability,
llvm::ArrayRef<uint32_t> tile_sizes, llvm::ArrayRef<uint32_t> same_shape,
llvm::ArrayRef<uint32_t> unroll_factors) {
RegisterDialects();
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
mlir::RegisterAllTensorFlowDialects(context.getDialectRegistry());
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get()));
......
......@@ -35,13 +35,3 @@ cc_library(
"@llvm-project//mlir:SideEffects",
],
)
cc_library(
name = "tf_framework_dialect_registration",
srcs = ["dialect_registration.cc"],
deps = [
":tf_framework_ops",
"@llvm-project//mlir:IR",
],
alwayslink = 1,
)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
// Static initialization for TF Framework dialect registration.
static mlir::DialectRegistration<
mlir::kernel_gen::tf_framework::TFFrameworkDialect>
tf_framework_ops;
......@@ -13,114 +13,27 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/AsmState.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/InitAllDialects.h" // from @llvm-project
#include "mlir/InitAllPasses.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Support/MlirOptMain.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h"
// NOLINTNEXTLINE
static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
llvm::cl::desc("<input file>"),
llvm::cl::init("-"));
// NOLINTNEXTLINE
static llvm::cl::opt<std::string> outputFilename(
"o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
llvm::cl::init("-"));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> splitInputFile(
"split-input-file",
llvm::cl::desc("Split the input file into pieces and process each "
"chunk independently"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> verifyDiagnostics(
"verify-diagnostics",
llvm::cl::desc("Check that emitted diagnostics match "
"expected-* lines on the corresponding line"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> verifyPasses(
"verify-each",
llvm::cl::desc("Run the verifier after each transformation pass"),
llvm::cl::init(true));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> allowUnregisteredDialects(
"allow-unregistered-dialect",
llvm::cl::desc("Allow operation with no registered dialects"),
llvm::cl::init(false));
// NOLINTNEXTLINE
static llvm::cl::opt<bool> showDialects(
"show-dialects", llvm::cl::desc("Print the list of registered dialects"),
llvm::cl::init(false));
int main(int argc, char **argv) {
mlir::registerAllDialects();
mlir::registerAllPasses();
mlir::mhlo::registerAllDialects();
mlir::mhlo::registerAllMhloPasses();
mlir::lmhlo::registerAllLmhloPasses();
mlir::kernel_gen::registerKernelGenPasses();
llvm::InitLLVM y(argc, argv);
// Register any pass manager command line options.
mlir::registerAsmPrinterCLOptions();
mlir::registerPassManagerCLOptions();
mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run");
// Parse pass names in main to ensure static initialization completed.
llvm::cl::ParseCommandLineOptions(argc, argv,
"MLIR modular optimizer driver\n");
if (showDialects) {
mlir::MLIRContext context;
context.loadAllGloballyRegisteredDialects();
llvm::outs() << "Registered Dialects:\n";
for (mlir::Dialect *dialect : context.getLoadedDialects()) {
llvm::outs() << dialect->getNamespace() << "\n";
}
return 0;
}
// Set up the input file.
std::string errorMessage;
auto file = mlir::openInputFile(inputFilename, &errorMessage);
if (!file) {
llvm::errs() << errorMessage << "\n";
return 1;
}
auto output = mlir::openOutputFile(outputFilename, &errorMessage);
if (!output) {
llvm::errs() << errorMessage << "\n";
exit(1);
}
mlir::DialectRegistry registry;
registerAllDialects(registry);
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
splitInputFile, verifyDiagnostics, verifyPasses,
allowUnregisteredDialects,
/*preloadDialectsInContext=*/true))) {
return 1;
}
// Keep the output file if the invocation of MlirOptMain was successful.
output->keep();
return 0;
mlir::registerAllDialects(registry);
mlir::mhlo::registerAllMhloDialects(registry);
mlir::RegisterAllTensorFlowDialects(registry);
registry.insert<mlir::kernel_gen::tf_framework::TFFrameworkDialect>();
return failed(
mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry));
}
......@@ -30,6 +30,7 @@ limitations under the License.
#include "mlir/Dialect/SCF/Passes.h" // from @llvm-project
#include "mlir/Dialect/SCF/Transforms.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Transforms/BufferPlacement.h" // from @llvm-project
......@@ -143,6 +144,10 @@ namespace {
class LowerToNVVMPass
: public ::mlir::PassWrapper<
LowerToNVVMPass, ::mlir::OperationPass<::mlir::gpu::GPUModuleOp>> {
void getDependentDialects(mlir::DialectRegistry& registry) const override {
registry.insert<mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect>();
}
public:
void runOnOperation() override {
::mlir::gpu::GPUModuleOp m = getOperation();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册