diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index d8b4fe5bcef0d7e97188dcf3cfcf8e25710b972d..fd9953de1e2b1e5c078ad1a2124d42d1efae58ab 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -132,7 +132,6 @@ tf_cc_binary( ":tf_mlir_opt_main", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", - "//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration", "//tensorflow/compiler/mlir/xla:all_xla_passes_for_testing", ], ) @@ -169,3 +168,5 @@ filegroup( name = "litfiles", srcs = glob(["runlit*py"]), ) + +exports_files(["run_lit.sh"]) diff --git a/tensorflow/compiler/mlir/glob_lit_test.bzl b/tensorflow/compiler/mlir/glob_lit_test.bzl index edbf3663a89c5c717736e74ade24ad72c0817d0f..1fa57babdae78f7a8f4abddcf62e25b11e619f17 100644 --- a/tensorflow/compiler/mlir/glob_lit_test.bzl +++ b/tensorflow/compiler/mlir/glob_lit_test.bzl @@ -43,10 +43,10 @@ def _run_lit_test(name, data, size, tags, driver, features, exec_properties): and specifying a default driver will abort the tests. features: [str], list of extra features to enable. """ - if driver != _default_driver: - fail("There is no present support for custom drivers. Please omit" + - " the driver parameter when running this test. If you require" + - " custom driver support, please file an issue to request it.") + + # Remove the default_driver from the data: it does not exist as a file and is + # just a placeholder from the copybara rewrite. + data = [d for d in data if d != _default_driver] # Disable tests on windows for now, to enable testing rest of all xla and mlir. native.py_test( diff --git a/tensorflow/compiler/mlir/runlit.cfg.py b/tensorflow/compiler/mlir/runlit.cfg.py index 45c8dce8422139151eabe4a69be08a16af141dc9..f9870183b88280c3828df3ec7dd76d9d46768eec 100644 --- a/tensorflow/compiler/mlir/runlit.cfg.py +++ b/tensorflow/compiler/mlir/runlit.cfg.py @@ -74,7 +74,7 @@ tool_names = [ 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate', 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile', 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir', - 'kernel-gen-opt', 'xla-thunks-opt' + 'kernel-gen-opt', 'xla-thunks-opt', 'tfjs-opt' ] tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/tensorflow/compiler/mlir/tfjs/BUILD b/tensorflow/compiler/mlir/tfjs/BUILD index 7d3091f921fd026ef5f616ddf12e8a0fe266a2f1..4db960085ec3ed80ab0bf2bf15334131a1082223 100644 --- a/tensorflow/compiler/mlir/tfjs/BUILD +++ b/tensorflow/compiler/mlir/tfjs/BUILD @@ -68,17 +68,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "tensorflow_js_dialect_registration", - srcs = [ - "ir/dialect_registration.cc", - ], - deps = [ - ":tensorflow_js", - ], - alwayslink = 1, -) - gentbl( name = "tfjs_optimize_inc_gen", tbl_outs = [ @@ -107,7 +96,6 @@ cc_library( ], deps = [ ":tensorflow_js", - ":tensorflow_js_dialect_registration", "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -149,7 +137,6 @@ cc_library( ], deps = [ ":tensorflow_js", - ":tensorflow_js_dialect_registration", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", "//tensorflow/compiler/mlir/tensorflow:export_utils", @@ -236,3 +223,20 @@ tf_cc_binary( "@llvm-project//mlir:Support", ], ) + +tf_cc_binary( + name = "tfjs-opt", + srcs = [ + "tfjs_opt.cc", + ], + deps = [ + ":tensorflow_js", + ":tensorflow_js_passes", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:StandardOps", + ], +) diff --git a/tensorflow/compiler/mlir/tfjs/ir/dialect_registration.cc b/tensorflow/compiler/mlir/tfjs/ir/dialect_registration.cc deleted file mode 100644 index 44ce38469b844ce216ec1792c84fc88c81036e29..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/mlir/tfjs/ir/dialect_registration.cc +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h" - -// Static initialization for TensorFlow.js op registration. -static mlir::DialectRegistration tfjs_ops; diff --git a/tensorflow/compiler/mlir/tfjs/tests/BUILD b/tensorflow/compiler/mlir/tfjs/tests/BUILD index a4ebc9979911e270e703de83ca31fa633bf68e2c..5789480c3ba09552dbc899bf95ccdb89ae9db095 100644 --- a/tensorflow/compiler/mlir/tfjs/tests/BUILD +++ b/tensorflow/compiler/mlir/tfjs/tests/BUILD @@ -3,8 +3,11 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") package(licenses = ["notice"]) glob_lit_tests( - data = [":test_utilities"], - driver = "@llvm-project//mlir:run_lit.sh", + data = [ + ":test_utilities", + "@llvm-project//mlir:run_lit.sh", + ], + driver = "//tensorflow/compiler/mlir:run_lit.sh", test_file_exts = ["mlir"], ) @@ -13,7 +16,7 @@ filegroup( name = "test_utilities", testonly = True, data = [ - "//tensorflow/compiler/mlir:tf-opt", + "//tensorflow/compiler/mlir/tfjs:tfjs-opt", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", ], diff --git a/tensorflow/compiler/mlir/tfjs/tests/ops.mlir b/tensorflow/compiler/mlir/tfjs/tests/ops.mlir index 0b7210118dfd00b37f278ecb667df60971ed12e0..602f34657a0801736732a3359fdaf1be3addaff3 100644 --- a/tensorflow/compiler/mlir/tfjs/tests/ops.mlir +++ b/tensorflow/compiler/mlir/tfjs/tests/ops.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s +// RUN: tfjs-opt -split-input-file -verify-diagnostics -tfl-runtime-verify %s | FileCheck %s // ----- diff --git a/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir b/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir index 5f046dc5a8a28de9edb195f4e7fd8b881fd841ff..f4464ddd01dee3c5201a072192e1fba644bfb49c 100644 --- a/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/tfjs/tests/optimize.mlir @@ -1,5 +1,5 @@ // Run optimize pass only and check the results. -// RUN: tf-opt %s -tfjs-optimize | FileCheck %s +// RUN: tfjs-opt %s -tfjs-optimize | FileCheck %s // CHECK-LABEL: prelu_fusion func @prelu_fusion(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { diff --git a/tensorflow/compiler/mlir/tfjs/tfjs_opt.cc b/tensorflow/compiler/mlir/tfjs/tfjs_opt.cc new file mode 100644 index 0000000000000000000000000000000000000000..c60131282958cb258c42031e9938169c881c7a24 --- /dev/null +++ b/tensorflow/compiler/mlir/tfjs/tfjs_opt.cc @@ -0,0 +1,33 @@ +/* Copyright 2020 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Support/MlirOptMain.h" // from @llvm-project +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tfjs/ir/tfjs_ops.h" + +int main(int argc, char **argv) { + tensorflow::InitMlir y(&argc, &argv); + + mlir::registerAllPasses(); + + mlir::DialectRegistry registry; + registry.insert(); + registry.insert(); + registry.insert(); + return failed(mlir::MlirOptMain(argc, argv, "TF JS pass driver\n", registry)); +} diff --git a/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc b/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc index c03a68471bc79fbe1da9a1f23700c3db8b9c44c2..a3678f7d1544a84b496f80c748f142e5bcaa217c 100644 --- a/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tfjs/transforms/optimize.cc @@ -37,6 +37,9 @@ namespace { // Optimize TFJS operations in functions. struct Optimize : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const final { + registry.insert(); + } void runOnFunction() override; }; diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc index 915fb91a8df829c23fa29aa7ea24a31d388bb83e..e735a3c7b8c527fd18eaeb7d2740cc731b61d0eb 100644 --- a/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc +++ b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc @@ -125,7 +125,6 @@ int main(int argc, char** argv) { "TF GraphDef to TFJS JSON converter\n"); MLIRContext context; - context.loadAllGloballyRegisteredDialects(); llvm::SourceMgr source_mgr; mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context);