Commit db78382d — authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Create initial skeleton for the Flatbuffer -> MLIR importer

Add a flatbuffer_importer.cc that registers a translation from TFLite
Flatbuffer to MLIR and incorporate it into the flatbuffer_translate
tool.

The translator does not yet perform any translation, but only
validates that the input file contains a FlatBuffer and prints its
version number and the names and input tensor IDs of each subgraph.

The tests don't actually include the expected correct output, but
instead simply make sure that the initial code, which only calls the
flatbuffer parser and prints some simple information, functions
correctly.

PiperOrigin-RevId: 258431902
Parent commit: 584a64fa
......@@ -403,9 +403,11 @@ cc_library(
cc_library(
name = "flatbuffer_translate_lib",
srcs = [
"flatbuffer_import.cc",
"flatbuffer_translate.cc",
],
hdrs = [
"flatbuffer_import.h",
"flatbuffer_translate.h",
],
deps = [
......
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
#include <iostream>
#include <string>
#include "absl/strings/string_view.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MemoryBuffer.h"
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/Support/FileUtilities.h" // TF:local_config_mlir
#include "mlir/Translation.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"
using mlir::Location;
using mlir::MLIRContext;
using mlir::OwningModuleRef;
namespace tflite {

// Converts a TFLite FlatBuffer held in `buffer` into an MLIR module.
// `buffer` must stay alive for the duration of the call. On failure an error
// is emitted at `base_loc` and nullptr is returned.
//
// NOTE: Translation is not implemented yet. This skeleton only verifies the
// FlatBuffer, prints basic model information (consumed by the lit tests via
// FileCheck), and returns an empty module.
OwningModuleRef FlatBufferToMlir(absl::string_view buffer, MLIRContext* context,
                                 Location base_loc) {
  // Verify that the input is a well-formed TFLite FlatBuffer before unpacking.
  auto model_ptr =
      FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length());
  if (nullptr == model_ptr) {
    emitError(base_loc, "Couldn't parse flatbuffer");
    return nullptr;
  }

  // Unpack into the object API so fields are easy to inspect.
  std::unique_ptr<ModelT> model(model_ptr->GetModel()->UnPack());

  // Print simple model information; the lit tests CHECK this output to make
  // sure the parser ran correctly.
  std::cout << "Model version: " << model->version << std::endl;
  for (auto& subgraph : model->subgraphs) {
    std::cout << "Subgraph name: " << subgraph->name << std::endl;
    for (auto& input : subgraph->inputs) {
      std::cout << " Subgraph input: " << input << std::endl;
    }
    for (auto& output : subgraph->outputs) {
      std::cout << " Subgraph output: " << output << std::endl;
    }
  }

  // Return an empty module until real translation is implemented.
  // (The previously-constructed mlir::Builder local was unused and has been
  // removed.)
  return OwningModuleRef(mlir::ModuleOp::create(base_loc));
}

}  // namespace tflite
// Reads the FlatBuffer file named by `filename` and hands its contents to
// FlatBufferToMlir. Diagnostics are attached to a synthetic file location so
// errors point at the input file.
static OwningModuleRef FlatBufferFileToMlirTrans(llvm::StringRef filename,
                                                 MLIRContext* context) {
  auto file_loc = mlir::FileLineColLoc::get(filename, 0, 0, context);

  std::string error_message;
  auto file_buffer = mlir::openInputFile(filename, &error_message);
  if (!file_buffer) {
    emitError(file_loc, error_message);
    return nullptr;
  }

  absl::string_view contents(file_buffer->getBufferStart(),
                             file_buffer->getBufferSize());
  return tflite::FlatBufferToMlir(contents, context, file_loc);
}
// Registers the "tflite-flatbuffer-to-mlir" translation so flatbuffer_translate
// (and other mlir-translate-style drivers) can invoke FlatBufferFileToMlirTrans.
// Registration happens as a side effect of static initialization.
static mlir::TranslateToMLIRRegistration FlatBufferFileToMlirTransReg(
    "tflite-flatbuffer-to-mlir", FlatBufferFileToMlirTrans);
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
#include "absl/strings/string_view.h"
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
namespace tflite {

// Converts a TFLite FlatBuffer stored in `buffer` into an MLIR module.
// `buffer` must remain valid for the duration of the call. `base_loc` is the
// location attached to the resulting module and to any diagnostics.
// Returns a module owned by the caller, or nullptr on failure; more specific
// errors are emitted via `context`.
mlir::OwningModuleRef FlatBufferToMlir(absl::string_view buffer,
                                       mlir::MLIRContext* context,
                                       mlir::Location base_loc);

}  // namespace tflite
#endif // TENSORFLOW_COMPILER_MLIR_LITE_FLATBUFFER_IMPORT_H_
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
licenses(["notice"])
# Creates one lit test target per .mlir file in this directory, executed with
# the upstream MLIR lit driver and the utilities in :test_utilities.
glob_lit_tests(
    data = [":test_utilities"],
    driver = "@local_config_mlir//:run_lit.sh",
    test_file_exts = ["mlir"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
    name = "test_utilities",
    testonly = True,
    data = [
        # The translation tool under test.
        "//tensorflow/compiler/mlir/lite:flatbuffer_translate",
        # FileCheck verifies the tool's printed output against CHECK lines.
        "@llvm//:FileCheck",
    ],
)
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s

// Round-trips this module through a TFLite FlatBuffer and back, then checks
// that the (skeleton) importer prints the expected model version, subgraph
// name, and the subgraph's input/output tensor indices.
func @main(tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>):
  // CHECK: Model version: 3
  // CHECK-NEXT: Subgraph name: main
  // CHECK-NEXT: Subgraph input: 0
  // CHECK-NEXT: Subgraph output: 6
  %0 = "tfl.pseudo_input" (%arg0) : (tensor<4xf32>) -> tensor<4xf32> loc("Input")
  %1 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const")
  %2 = "tfl.squared_difference"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference")
  %3 = "tfl.mul"(%0, %2) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul")
  %4 = "tfl.div"(%3, %2) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("div")
  %5 = "tfl.exp"(%4) : (tensor<4xf32>) -> tensor<4xf32> loc("exp")
  %6 = "tfl.neg"(%5) : (tensor<4xf32>) -> tensor<4xf32> loc("neg")
  return %6 : tensor<4xf32>
}
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck --dump-input-on-failure %s

// Same round-trip check as the f32 test, but with integer tensors and a
// std.constant in the mix; verifies the importer's printed model summary.
func @main(tensor<3x2xi32>) -> tensor<3x2xi32> {
^bb0(%arg0: tensor<3x2xi32>):
  // CHECK: Model version: 3
  // CHECK-NEXT: Subgraph name: main
  // CHECK-NEXT: Subgraph input: 0
  // CHECK-NEXT: Subgraph output: 4
  %0 = "tfl.pseudo_input" (%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> loc("Input")
  %1 = "tfl.pseudo_const" () {value = dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> loc("Const")
  %2 = "tfl.sub" (%0, %1) {fused_activation_function = "RELU6"} : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> loc("sub")
  %3 = "std.constant" () {value = dense<10> : tensor<i32>} : () -> tensor<i32> loc("Const2")
  %4 = "tfl.add" (%3, %2) {fused_activation_function = "NONE"} : (tensor<i32>, tensor<3x2xi32>) -> tensor<3x2xi32> loc("add")
  return %4 : tensor<3x2xi32>
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册