...
 
Commits (5)

1. python: rename package tflm_runtime to runtime (#2030)
   Ryan Kuester <kuester@bdti.com>, 2023-06-08T23:26:13+00:00
   https://gitcode.net/xusiwei1236/tflite-micro/-/commit/8913f033772ac7009b6d8a1c8817e2399d22d803
   Rename the Python package `tflm_runtime` to simply `runtime` in preparation for adding it under the new namespace package `tflite_micro`. Its full name will then be `tflite_micro.runtime`. We have kept the `tflm_runtime` target as an alias in order to stage this change. More details in http://b/286456378
   BUG=part of #1484

2. Shim tflm_runtime.py to not break existing code after #2030 (#2032)
   Advait Jain <advaitjain@users.noreply.github.com>, 2023-06-09T21:35:52+00:00
   https://gitcode.net/xusiwei1236/tflite-micro/-/commit/43541be261d311bf0fe7c421210415aab58e1f58
   BUG=http://b/286456378

3. remove unused pybind_library. (#2034)
   Advait Jain <advaitjain@users.noreply.github.com>, 2023-06-12T15:07:58+00:00
   https://gitcode.net/xusiwei1236/tflite-micro/-/commit/e39d319ff647298a3cec2325bdcb553e2a8eb53a
   Internal checks caught this unused import in the BUILD file.
   BUG=cleanup

4. Automated sync from github.com/tensorflow/tensorflow (#2036)
   TFLM-bot <tflm-github-bot@google.com>, 2023-06-12T17:31:19+00:00
   https://gitcode.net/xusiwei1236/tflite-micro/-/commit/5b7374fac4175571bb5aa28f516ac1a963033269
   BUG=automated sync from upstream
   NO_CHECK_TFLITE_FILES=automated sync from upstream

5. Add rascani@ as CODEOWNER for ci and .github directories. (#2042)
   Advait Jain <advaitjain@users.noreply.github.com>, 2023-06-12T20:19:54+00:00
   https://gitcode.net/xusiwei1236/tflite-micro/-/commit/929321c458538964a6fb57f1a6ba2c82ed28aa30
   BUG=cleanup
* @tensorflow/micro
-/.github/ @advaitjain @rockyrhodes @vamsimanchala
-/ci/ @advaitjain @rockyrhodes @vamsimanchala
+/.github/ @advaitjain @rockyrhodes @rascani
+/ci/ @advaitjain @rockyrhodes @rascani
......@@ -212,9 +212,14 @@ inline void ResizeBilinearInteger(
(input_y - (1 << 10) * y0) * (input_x - (1 << 10) * x0);
const int64_t output_20 =
output_20_ll + output_20_lu + output_20_rl + output_20_ru;
+#if TFLITE_SINGLE_ROUNDING
+      const int64_t round = 1 << 19;
+      const T interpolation = static_cast<T>((output_20 + round) >> 20);
+#else
const int64_t round = (output_20 > 0) ? (1 << 19) : -(1 << 19);
const T interpolation =
static_cast<T>((output_20 + round) / (1 << 20));
+#endif // TFLITE_SINGLE_ROUNDING
output_data[Offset(output_shape, b, y, x, c)] = interpolation;
}
}
......
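Context for the hunk above: with `TFLITE_SINGLE_ROUNDING` the kernel rounds the Q20 fixed-point accumulator `output_20` half-up with one arithmetic shift, while the legacy path rounds half away from zero using a sign-dependent offset and a truncating division. A minimal Python sketch of the arithmetic (illustrative values only, not code from the commit):

```python
# Q20 fixed point: 20 fractional bits, so 2**20 represents 1.0.

def round_single(output_20: int) -> int:
    # TFLITE_SINGLE_ROUNDING path: unconditional half-up rounding.
    # Python's >> is an arithmetic shift, matching C++ int64_t >>.
    return (output_20 + (1 << 19)) >> 20

def round_double(output_20: int) -> int:
    # Legacy path: half-away-from-zero, emulating C++ '/' which
    # truncates toward zero (Python's // floors, hence the fixup).
    q = output_20 + ((1 << 19) if output_20 > 0 else -(1 << 19))
    return -(-q // (1 << 20)) if q < 0 else q // (1 << 20)

half = (5 << 20) + (1 << 19)                     # 5.5 in Q20
print(round_single(half), round_double(half))    # 6 6
print(round_single(-half), round_double(-half))  # -5 -6: negative ties differ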
......@@ -54,7 +54,7 @@ py_binary(
"@absl_py//absl/logging",
requirement("numpy"),
requirement("tensorflow-cpu"),
"//tensorflow/lite/micro/python/interpreter/src:tflm_runtime",
"//tensorflow/lite/micro/python/interpreter/src:runtime",
],
)
......
......@@ -19,7 +19,7 @@ from absl import flags
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.platform import resource_loader
-from tflite_micro.tensorflow.lite.micro.python.interpreter.src import tflm_runtime
+from tflite_micro.tensorflow.lite.micro.python.interpreter.src import runtime
_USE_TFLITE_INTERPRETER = flags.DEFINE_bool(
'use_tflite',
......@@ -73,7 +73,7 @@ def generate_random_float_input(sample_count=1000):
# returns the prediction of the interpreter.
def get_tflm_prediction(model_path, x_values):
# Create the tflm interpreter
-  tflm_interpreter = tflm_runtime.Interpreter.from_file(model_path)
+  tflm_interpreter = runtime.Interpreter.from_file(model_path)
input_shape = np.array(tflm_interpreter.get_input_details(0).get('shape'))
......
......@@ -18,8 +18,7 @@ import numpy as np
from tensorflow.python.framework import test_util
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
-from tflite_micro.tensorflow.lite.micro.python.interpreter.src import \
-    tflm_runtime
+from tflite_micro.tensorflow.lite.micro.python.interpreter.src import runtime
from tflite_micro.tensorflow.lite.micro.examples.hello_world import evaluate
PREFIX_PATH = resource_loader.get_path_to_datafile('')
......@@ -29,7 +28,7 @@ class HelloWorldFloatModelTest(test_util.TensorFlowTestCase):
model_path = os.path.join(PREFIX_PATH, 'models/hello_world_float.tflite')
input_shape = (1, 1)
output_shape = (1, 1)
-  tflm_interpreter = tflm_runtime.Interpreter.from_file(model_path)
+  tflm_interpreter = runtime.Interpreter.from_file(model_path)
def test_compare_with_tflite(self):
x_values = evaluate.generate_random_float_input()
......@@ -47,7 +46,7 @@ class HelloWorldQuantModelTest(test_util.TensorFlowTestCase):
model_path = os.path.join(PREFIX_PATH, 'models/hello_world_int8.tflite')
input_shape = (1, 1)
output_shape = (1, 1)
-  tflm_interpreter = tflm_runtime.Interpreter.from_file(model_path)
+  tflm_interpreter = runtime.Interpreter.from_file(model_path)
def test_compare_with_tflite(self):
x_values = evaluate.generate_random_int8_input()
......
......@@ -12,6 +12,6 @@ py_binary(
"@absl_py//absl/logging",
requirement("numpy"),
requirement("tensorflow-cpu"),
"//tensorflow/lite/micro/python/interpreter/src:tflm_runtime",
"//tensorflow/lite/micro/python/interpreter/src:runtime",
],
)
......@@ -15,7 +15,7 @@ py_binary(
srcs = ["evaluate.py"],
srcs_version = "PY3",
deps = [
"//tensorflow/lite/micro/python/interpreter/src:tflm_runtime",
"//tensorflow/lite/micro/python/interpreter/src:runtime",
"@absl_py//absl:app",
],
)
......
......@@ -28,7 +28,7 @@ from absl import logging
import numpy as np
from PIL import Image
-from tflite_micro.tensorflow.lite.micro.python.interpreter.src import tflm_runtime
+from tflite_micro.tensorflow.lite.micro.python.interpreter.src import runtime
FLAGS = flags.FLAGS
......@@ -113,7 +113,7 @@ def predict(interpreter, data):
"""Use TFLM interpreter to predict a MNIST image
Args:
-    interpreter (tflm_runtime.Interpreter): the TFLM python interpreter
+    interpreter (runtime.Interpreter): the TFLM python interpreter
data (np.array): data to be predicted
Returns:
......@@ -141,7 +141,7 @@ def predict_image(interpreter, image_path):
"""Use TFLM interpreter to predict a MNIST image
Args:
-    interpreter (tflm_runtime.Interpreter): the TFLM python interpreter
+    interpreter (runtime.Interpreter): the TFLM python interpreter
    image_path (str): path for the image that needs to be tested
Returns:
......@@ -158,7 +158,7 @@ def main(_):
if not os.path.exists(FLAGS.img_path):
raise ValueError("Image file does not exist. Please check the image path.")
-  tflm_interpreter = tflm_runtime.Interpreter.from_file(FLAGS.model_path)
+  tflm_interpreter = runtime.Interpreter.from_file(FLAGS.model_path)
category_probabilities = predict_image(tflm_interpreter, FLAGS.img_path)
predicted_category = np.argmax(category_probabilities)
logging.info("Model predicts the image as %i with probability %.2f",
......
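For orientation, the renamed module slots into this example unchanged; a minimal sketch of the prediction flow above (file paths are placeholders):

```python
# Hypothetical invocation of the mnist_lstm evaluation flow from the diff
# above; "trained_lstm.tflite" and "digit.png" are placeholder paths.
import numpy as np
from tflite_micro.tensorflow.lite.micro.python.interpreter.src import runtime
from tflite_micro.tensorflow.lite.micro.examples.mnist_lstm import evaluate

tflm_interpreter = runtime.Interpreter.from_file("trained_lstm.tflite")
category_probabilities = evaluate.predict_image(tflm_interpreter, "digit.png")
predicted_category = np.argmax(category_probabilities)
print("Model predicts the image as", predicted_category,
      "with probability", category_probabilities[predicted_category])
```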
......@@ -20,7 +20,7 @@ import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
-from tflite_micro.tensorflow.lite.micro.python.interpreter.src import tflm_runtime
+from tflite_micro.tensorflow.lite.micro.python.interpreter.src import runtime
from tflite_micro.tensorflow.lite.micro.examples.mnist_lstm import evaluate
from tflite_micro.tensorflow.lite.micro.tools import requantize_flatbuffer
......@@ -33,7 +33,7 @@ class LSTMFloatModelTest(test_util.TensorFlowTestCase):
self.model_path = os.path.join(PREFIX_PATH, "trained_lstm.tflite")
self.input_shape = (1, 28, 28)
self.output_shape = (1, 10)
-    self.tflm_interpreter = tflm_runtime.Interpreter.from_file(self.model_path)
+    self.tflm_interpreter = runtime.Interpreter.from_file(self.model_path)
np.random.seed(42) #Seed the random number generator
def testInputErrHandling(self):
......@@ -95,7 +95,7 @@ class LSTMInt8ModelTest(test_util.TensorFlowTestCase):
"trained_lstm_int8.tflite")
self.input_shape = (1, 28, 28)
self.output_shape = (1, 10)
-    self.tflm_interpreter_quant = tflm_runtime.Interpreter.from_file(
+    self.tflm_interpreter_quant = runtime.Interpreter.from_file(
self.int8_model_path)
np.random.seed(42) #Seed the random number generator
......@@ -106,8 +106,7 @@ class LSTMInt8ModelTest(test_util.TensorFlowTestCase):
# Create a float model for results comparison
float_model_path = os.path.join(PREFIX_PATH, "trained_lstm.tflite")
-    tflm_interpreter_float = tflm_runtime.Interpreter.from_file(
-        float_model_path)
+    tflm_interpreter_float = runtime.Interpreter.from_file(float_model_path)
num_test = 10
for _ in range(num_test):
......@@ -163,7 +162,7 @@ class LSTMInt16ModelTest(test_util.TensorFlowTestCase):
self.int16_model = self.requantizer.model_bytearray()
self.input_shape = (1, 28, 28)
self.output_shape = (1, 10)
-    self.tflm_interpreter_quant = tflm_runtime.Interpreter.from_bytes(
+    self.tflm_interpreter_quant = runtime.Interpreter.from_bytes(
self.int16_model)
np.random.seed(42) #Seed the random number generator
......@@ -174,8 +173,7 @@ class LSTMInt16ModelTest(test_util.TensorFlowTestCase):
# Create a float model for results comparison
float_model_path = os.path.join(PREFIX_PATH, "trained_lstm.tflite")
-    tflm_interpreter_float = tflm_runtime.Interpreter.from_file(
-        float_model_path)
+    tflm_interpreter_float = runtime.Interpreter.from_file(float_model_path)
num_test = 10
for _ in range(num_test):
......
......@@ -26,6 +26,8 @@ py_test(
],
deps = [
":resource_variables_lib",
+        # TODO(b/286456378): update tflm_runtime to runtime when we are ready to
+        # remove the alias.
"//tensorflow/lite/micro/python/interpreter/src:tflm_runtime",
],
)
......@@ -17,6 +17,9 @@ import numpy as np
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tflite_micro.tensorflow.lite.micro.examples.recipes import resource_variables_lib
+# TODO(b/286456378): change tflm_runtime to runtime when all other usage has
+# been updated.
from tflite_micro.tensorflow.lite.micro.python.interpreter.src import tflm_runtime
......
......@@ -13,7 +13,7 @@ in near future by installing a PyPi package.
#### Build
The only package that needs to be included in the `BUILD` file is
-`//tensorflow/lite/micro/python/interpreter/src:tflm_runtime`. It contains all
+`//tensorflow/lite/micro/python/interpreter/src:runtime`. It contains all
the correct dependencies to build the Python interpreter.
### PyPi
......@@ -34,13 +34,13 @@ bytearray format or file format.
```
# For the Bazel workflow
-from tflite_micro.tensorflow.lite.micro.python.interpreter.src import tflm_runtime
+from tflite_micro.tensorflow.lite.micro.python.interpreter.src import runtime
# If model is a bytearray
-tflm_interpreter = tflm_runtime.Interpreter.from_bytes(model_data)
+tflm_interpreter = runtime.Interpreter.from_bytes(model_data)
# If model is a file
-tflm_interpreter = tflm_runtime.Interpreter.from_file(model_filepath)
+tflm_interpreter = runtime.Interpreter.from_file(model_filepath)
# Run inference on TFLM using an ndarray `data_x`
tflm_interpreter.set_input(data_x, 0)
......@@ -62,7 +62,7 @@ expose an evolving set of C++ APIs. The Bazel build leverages the
[pybind11_bazel extension](https://github.com/pybind/pybind11_bazel).
The most updated Python APIs can be found in
-`tensorflow/lite/micro/python/interpreter/src/tflm_runtime.py`.
+`tensorflow/lite/micro/python/interpreter/src/runtime.py`.
## Custom Ops
......@@ -116,7 +116,7 @@ in the target that calls the Python interpreter with custom ops.
For example,
```
-interpreter = tflm_runtime.Interpreter.from_file(
+interpreter = runtime.Interpreter.from_file(
model_path=model_path,
custom_op_registerers=['SomeCustomRegisterer'])
```
......@@ -152,7 +152,7 @@ will print
10016 bytes is the actual memory arena size.
-During instantiation via the class methods `tflm_runtime.Interpreter.from_file`
-or `tflm_runtime.Interpreter.from_bytes`, if `arena_size` is not explicitly
+During instantiation via the class methods `runtime.Interpreter.from_file`
+or `runtime.Interpreter.from_bytes`, if `arena_size` is not explicitly
specified, the interpreter will default to a heuristic which is 10x the model
size. This can be adjusted manually if desired.
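For example, the heuristic can be overridden by passing `arena_size` explicitly (a sketch; the path and size are placeholders):

```python
# Sketch: overriding the default 10x-model-size arena heuristic.
from tflite_micro.tensorflow.lite.micro.python.interpreter.src import runtime

interpreter = runtime.Interpreter.from_file(
    "model.tflite",        # placeholder path
    arena_size=20 * 1024)  # explicit 20 KiB tensor arena
# After set_input() and invoke(), print_allocations() reports actual usage.
```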
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("@tflm_pip_deps//:requirements.bzl", "requirement")
load(
"//tensorflow/lite/micro:build_def.bzl",
......@@ -63,10 +63,23 @@ pybind_extension(
],
)
+# tflm_runtime is deprecated, please use runtime instead.
+# TODO(b/286456378): remove once all usage is changed to the runtime target.
+py_library(
+    name = "tflm_runtime",
+    srcs = ["tflm_runtime.py"],
+    data = [":interpreter_wrapper_pybind.so"],
+    visibility = ["//visibility:public"],
+    deps = [":runtime"],
+)
py_library(
name = "runtime",
srcs = [
"runtime.py",
],
data = [
":interpreter_wrapper_pybind.so",
],
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [
......
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python package for TFLM Python Interpreter"""
import os
from tflite_micro.tensorflow.lite.micro.python.interpreter.src import interpreter_wrapper_pybind
from tflite_micro.tensorflow.lite.tools import flatbuffer_utils
class Interpreter(object):
def __init__(self, model_data, custom_op_registerers, arena_size):
if model_data is None:
raise ValueError("Model must not be None")
if not isinstance(custom_op_registerers, list) or not all(
isinstance(s, str) for s in custom_op_registerers):
raise ValueError("Custom ops registerers must be a list of strings")
# This is a heuristic to ensure that the arena is sufficiently sized.
if arena_size is None:
arena_size = len(model_data) * 10
# Some models make use of resource variables ops, get the count here
num_resource_variables = flatbuffer_utils.count_resource_variables(
model_data)
print("Number of resource variables the model uses = ",
num_resource_variables)
self._interpreter = interpreter_wrapper_pybind.InterpreterWrapper(
model_data, custom_op_registerers, arena_size, num_resource_variables)
@classmethod
def from_file(self, model_path, custom_op_registerers=[], arena_size=None):
"""Instantiates a TFLM interpreter from a model .tflite filepath.
Args:
model_path: Filepath to the .tflite model
custom_op_registerers: List of strings, each of which is the name of a
custom OP registerer
      arena_size: Tensor arena size in bytes. If not specified, it will
default to 10 times the model size.
Returns:
An Interpreter instance
"""
if model_path is None or not os.path.isfile(model_path):
raise ValueError("Invalid model file path")
with open(model_path, "rb") as f:
model_data = f.read()
return Interpreter(model_data, custom_op_registerers, arena_size)
@classmethod
def from_bytes(self, model_data, custom_op_registerers=[], arena_size=None):
"""Instantiates a TFLM interpreter from a model in byte array.
Args:
model_data: Model in byte array format
custom_op_registerers: List of strings, each of which is the name of a
custom OP registerer
      arena_size: Tensor arena size in bytes. If not specified, it will
default to 10 times the model size.
Returns:
An Interpreter instance
"""
return Interpreter(model_data, custom_op_registerers, arena_size)
def print_allocations(self):
"""Invoke the RecordingMicroAllocator to print the arena usage.
This should be called after `invoke()`.
Returns:
      This method does not return anything, but it dumps the arena
      usage to stderr.
"""
self._interpreter.PrintAllocations()
def invoke(self):
"""Invoke the TFLM interpreter to run an inference.
This should be called after `set_input()`.
Returns:
Status code of the C++ invoke function. A RuntimeError will be raised as
well upon any error.
"""
return self._interpreter.Invoke()
def reset(self):
"""Reset the model state to be what you would expect when the interpreter is first
created. i.e. after Init and Prepare is called for the very first time.
This should be called after invoke stateful model like LSTM.
Returns:
Status code of the C++ invoke function. A RuntimeError will be raised as
well upon any error.
"""
return self._interpreter.Reset()
def set_input(self, input_data, index):
"""Set input data into input tensor.
This should be called before `invoke()`.
Args:
input_data: Input data in numpy array format. The numpy array format is
chosen to be consistent with TFLite interpreter.
index: An integer between 0 and the number of input tensors (exclusive)
consistent with the order defined in the list of inputs in the .tflite
model
"""
if input_data is None:
raise ValueError("Input data must not be None")
if index is None or index < 0:
raise ValueError("Index must be a non-negative integer")
self._interpreter.SetInputTensor(input_data, index)
def get_output(self, index):
"""Get data from output tensor.
The output data correspond to the most recent `invoke()`.
Args:
index: An integer between 0 and the number of output tensors (exclusive)
consistent with the order defined in the list of outputs in the .tflite
model
Returns:
Output data in numpy array format. The numpy array format is chosen to
be consistent with TFLite interpreter.
"""
if index is None or index < 0:
raise ValueError("Index must be a non-negative integer")
return self._interpreter.GetOutputTensor(index)
def get_input_details(self, index):
"""Get input tensor information
Args:
      index (int): An integer between 0 and the number of input tensors
        (exclusive) consistent with the order defined in the list of inputs
        in the .tflite model
    Returns:
      A dictionary with details about the input tensor at the given index,
      containing
the following fields that describe the tensor:
+ `shape`: The shape of the tensor.
+ `dtype`: The numpy data type (such as `np.int32` or `np.uint8`).
+ `quantization_parameters`: A dictionary of parameters used to quantize
the tensor:
~ `scales`: List of scales (one if per-tensor quantization).
~ `zero_points`: List of zero_points (one if per-tensor quantization).
~ `quantized_dimension`: Specifies the dimension of per-axis
quantization, in the case of multiple scales/zero_points.
"""
if index is None or index < 0:
raise ValueError("Index must be a non-negative integer")
return self._interpreter.GetInputTensorDetails(index)
def get_output_details(self, index):
"""Get output tensor information
Args:
index (int): An integer between 0 and the number of output tensors
(exclusive) consistent with the order defined in the list of outputs
in the .tflite model
Returns:
      A dictionary with details about the output tensor at the given index,
      containing
the following fields that describe the tensor:
+ `shape`: The shape of the tensor.
+ `dtype`: The numpy data type (such as `np.int32` or `np.uint8`).
+ `quantization_parameters`: A dictionary of parameters used to quantize
the tensor:
~ `scales`: List of scales (one if per-tensor quantization).
~ `zero_points`: List of zero_points (one if per-tensor quantization).
~ `quantized_dimension`: Specifies the dimension of per-axis
quantization, in the case of multiple scales/zero_points.
"""
if index is None or index < 0:
raise ValueError("Index must be a non-negative integer")
return self._interpreter.GetOutputTensorDetails(index)
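Putting the API above together, an end-to-end inference looks roughly like this (a sketch; the model path and zeroed input are placeholders):

```python
# Sketch of an end-to-end inference using only the runtime.py API above.
import numpy as np
from tflite_micro.tensorflow.lite.micro.python.interpreter.src import runtime

interpreter = runtime.Interpreter.from_file("model.tflite")  # placeholder path

# Shape/type the input from the interpreter's own tensor details.
details = interpreter.get_input_details(0)
data_x = np.zeros(details["shape"], dtype=details["dtype"])

interpreter.set_input(data_x, 0)    # must precede invoke()
interpreter.invoke()                # raises RuntimeError on failure
output = interpreter.get_output(0)  # output of the most recent invoke()

interpreter.reset()  # clear internal state, e.g. between LSTM sequences
```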
-# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,199 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python package for TFLM Python Interpreter"""
import os
from tflite_micro.tensorflow.lite.micro.python.interpreter.src import interpreter_wrapper_pybind
from tflite_micro.tensorflow.lite.tools import flatbuffer_utils
class Interpreter(object):
def __init__(self, model_data, custom_op_registerers, arena_size):
if model_data is None:
raise ValueError("Model must not be None")
if not isinstance(custom_op_registerers, list) or not all(
isinstance(s, str) for s in custom_op_registerers):
raise ValueError("Custom ops registerers must be a list of strings")
# This is a heuristic to ensure that the arena is sufficiently sized.
if arena_size is None:
arena_size = len(model_data) * 10
# Some models make use of resource variables ops, get the count here
num_resource_variables = flatbuffer_utils.count_resource_variables(
model_data)
print("Number of resource variables the model uses = ",
num_resource_variables)
self._interpreter = interpreter_wrapper_pybind.InterpreterWrapper(
model_data, custom_op_registerers, arena_size, num_resource_variables)
@classmethod
def from_file(self, model_path, custom_op_registerers=[], arena_size=None):
"""Instantiates a TFLM interpreter from a model .tflite filepath.
Args:
model_path: Filepath to the .tflite model
custom_op_registerers: List of strings, each of which is the name of a
custom OP registerer
arena_size: Tensor arena size in bytes. If unused, tensor arena size will
default to 10 times the model size.
Returns:
An Interpreter instance
"""
if model_path is None or not os.path.isfile(model_path):
raise ValueError("Invalid model file path")
with open(model_path, "rb") as f:
model_data = f.read()
return Interpreter(model_data, custom_op_registerers, arena_size)
@classmethod
def from_bytes(self, model_data, custom_op_registerers=[], arena_size=None):
"""Instantiates a TFLM interpreter from a model in byte array.
Args:
model_data: Model in byte array format
custom_op_registerers: List of strings, each of which is the name of a
custom OP registerer
arena_size: Tensor arena size in bytes. If unused, tensor arena size will
default to 10 times the model size.
Returns:
An Interpreter instance
"""
return Interpreter(model_data, custom_op_registerers, arena_size)
def print_allocations(self):
"""Invoke the RecordingMicroAllocator to print the arena usage.
This should be called after `invoke()`.
Returns:
This method does not return anything, but It dumps the arena
usage to stderr.
"""
self._interpreter.PrintAllocations()
def invoke(self):
"""Invoke the TFLM interpreter to run an inference.
This should be called after `set_input()`.
Returns:
Status code of the C++ invoke function. A RuntimeError will be raised as
well upon any error.
"""
return self._interpreter.Invoke()
def reset(self):
"""Reset the model state to be what you would expect when the interpreter is first
created. i.e. after Init and Prepare is called for the very first time.
This should be called after invoke stateful model like LSTM.
Returns:
Status code of the C++ invoke function. A RuntimeError will be raised as
well upon any error.
"""
return self._interpreter.Reset()
def set_input(self, input_data, index):
"""Set input data into input tensor.
This should be called before `invoke()`.
Args:
input_data: Input data in numpy array format. The numpy array format is
chosen to be consistent with TFLite interpreter.
index: An integer between 0 and the number of input tensors (exclusive)
consistent with the order defined in the list of inputs in the .tflite
model
"""
if input_data is None:
raise ValueError("Input data must not be None")
if index is None or index < 0:
raise ValueError("Index must be a non-negative integer")
self._interpreter.SetInputTensor(input_data, index)
def get_output(self, index):
"""Get data from output tensor.
The output data correspond to the most recent `invoke()`.
Args:
index: An integer between 0 and the number of output tensors (exclusive)
consistent with the order defined in the list of outputs in the .tflite
model
Returns:
Output data in numpy array format. The numpy array format is chosen to
be consistent with TFLite interpreter.
"""
if index is None or index < 0:
raise ValueError("Index must be a non-negative integer")
return self._interpreter.GetOutputTensor(index)
def get_input_details(self, index):
"""Get input tensor information
Args:
index (int): An integer between 0 and the number of output tensors
(exclusive) consistent with the order defined in the list of outputs
in the .tflite model
Returns:
A dictionary from input index to tensor details where each item is a
dictionary with details about an input tensor. Each dictionary contains
the following fields that describe the tensor:
+ `shape`: The shape of the tensor.
+ `dtype`: The numpy data type (such as `np.int32` or `np.uint8`).
+ `quantization_parameters`: A dictionary of parameters used to quantize
the tensor:
~ `scales`: List of scales (one if per-tensor quantization).
~ `zero_points`: List of zero_points (one if per-tensor quantization).
~ `quantized_dimension`: Specifies the dimension of per-axis
quantization, in the case of multiple scales/zero_points.
"""
if index is None or index < 0:
raise ValueError("Index must be a non-negative integer")
return self._interpreter.GetInputTensorDetails(index)
def get_output_details(self, index):
"""Get output tensor information
Args:
index (int): An integer between 0 and the number of output tensors
(exclusive) consistent with the order defined in the list of outputs
in the .tflite model
Returns:
A dictionary from input index to tensor details where each item is a
dictionary with details about an input tensor. Each dictionary contains
the following fields that describe the tensor:
+ `shape`: The shape of the tensor.
+ `dtype`: The numpy data type (such as `np.int32` or `np.uint8`).
+ `quantization_parameters`: A dictionary of parameters used to quantize
the tensor:
~ `scales`: List of scales (one if per-tensor quantization).
~ `zero_points`: List of zero_points (one if per-tensor quantization).
~ `quantized_dimension`: Specifies the dimension of per-axis
quantization, in the case of multiple scales/zero_points.
"""
if index is None or index < 0:
raise ValueError("Index must be a non-negative integer")
return self._interpreter.GetOutputTensorDetails(index)
+# TODO(b/286456378): remove once all usage is switched to runtime
+from tflite_micro.tensorflow.lite.micro.python.interpreter.src.runtime import *
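The star-import shim keeps both import paths working during the staged rename; a quick sketch of the equivalence:

```python
# Both paths resolve to the same objects while the shim above is in place.
from tflite_micro.tensorflow.lite.micro.python.interpreter.src import runtime
from tflite_micro.tensorflow.lite.micro.python.interpreter.src import tflm_runtime

assert tflm_runtime.Interpreter is runtime.Interpreter  # re-exported name
```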
......@@ -15,7 +15,7 @@ py_test(
deps = [
requirement("numpy"),
requirement("tensorflow-cpu"),
"//tensorflow/lite/micro/python/interpreter/src:tflm_runtime",
"//tensorflow/lite/micro/python/interpreter/src:runtime",
"//tensorflow/lite/micro/testing:generate_test_models_lib",
],
)
......@@ -30,7 +30,7 @@ import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tflite_micro.tensorflow.lite.micro.testing import generate_test_models
-from tflite_micro.tensorflow.lite.micro.python.interpreter.src import tflm_runtime
+from tflite_micro.tensorflow.lite.micro.python.interpreter.src import runtime
class ConvModelTests(test_util.TensorFlowTestCase):
......@@ -41,11 +41,11 @@ class ConvModelTests(test_util.TensorFlowTestCase):
def testInitErrorHandling(self):
with self.assertRaisesWithPredicateMatch(ValueError,
"Invalid model file path"):
tflm_runtime.Interpreter.from_file("wrong.tflite")
runtime.Interpreter.from_file("wrong.tflite")
def testInput(self):
model_data = generate_test_models.generate_conv_model(False)
-    tflm_interpreter = tflm_runtime.Interpreter.from_bytes(model_data)
+    tflm_interpreter = runtime.Interpreter.from_bytes(model_data)
data_x = np.random.randint(-127, 127, self.input_shape, dtype=np.int8)
tflm_interpreter.set_input(data_x, 0)
......@@ -68,7 +68,7 @@ class ConvModelTests(test_util.TensorFlowTestCase):
def testInputErrorHandling(self):
model_data = generate_test_models.generate_conv_model(True, self.filename)
-    tflm_interpreter = tflm_runtime.Interpreter.from_bytes(model_data)
+    tflm_interpreter = runtime.Interpreter.from_bytes(model_data)
data_x = np.random.randint(-127, 127, self.input_shape, dtype=np.int8)
# Try to access out of bound data
......@@ -96,7 +96,7 @@ class ConvModelTests(test_util.TensorFlowTestCase):
def testOutput(self):
model_data = generate_test_models.generate_conv_model(True, self.filename)
-    tflm_interpreter = tflm_runtime.Interpreter.from_bytes(model_data)
+    tflm_interpreter = runtime.Interpreter.from_bytes(model_data)
# Initial output values are all 0
output = tflm_interpreter.get_output(0)
......@@ -121,7 +121,7 @@ class ConvModelTests(test_util.TensorFlowTestCase):
def testOutputErrorHandling(self):
model_data = generate_test_models.generate_conv_model(True, self.filename)
-    tflm_interpreter = tflm_runtime.Interpreter.from_bytes(model_data)
+    tflm_interpreter = runtime.Interpreter.from_bytes(model_data)
# Try to access out of bound data
with self.assertRaisesWithPredicateMatch(IndexError,
"Tensor is out of bound"):
......@@ -134,7 +134,7 @@ class ConvModelTests(test_util.TensorFlowTestCase):
model_data = generate_test_models.generate_conv_model(True, self.filename)
# TFLM interpreter
-    tflm_interpreter = tflm_runtime.Interpreter.from_bytes(model_data)
+    tflm_interpreter = runtime.Interpreter.from_bytes(model_data)
# TFLite interpreter
tflite_interpreter = tf.lite.Interpreter(
......@@ -169,8 +169,8 @@ class ConvModelTests(test_util.TensorFlowTestCase):
def _helperModelFromFileAndBufferEqual(self):
model_data = generate_test_models.generate_conv_model(True, self.filename)
-    file_interpreter = tflm_runtime.Interpreter.from_file(self.filename)
-    bytes_interpreter = tflm_runtime.Interpreter.from_bytes(model_data)
+    file_interpreter = runtime.Interpreter.from_file(self.filename)
+    bytes_interpreter = runtime.Interpreter.from_bytes(model_data)
num_steps = 100
for i in range(0, num_steps):
......@@ -198,7 +198,7 @@ class ConvModelTests(test_util.TensorFlowTestCase):
model_data = generate_test_models.generate_conv_model(False)
interpreters = [
-        tflm_runtime.Interpreter.from_bytes(model_data) for i in range(10)
+        runtime.Interpreter.from_bytes(model_data) for i in range(10)
]
num_steps = 100
......@@ -221,7 +221,7 @@ class ConvModelTests(test_util.TensorFlowTestCase):
pass
def _helperOutputTensorMemoryLeak(self):
-    interpreter = tflm_runtime.Interpreter.from_file(self.filename)
+    interpreter = runtime.Interpreter.from_file(self.filename)
int_ref = weakref.finalize(interpreter, self._helperNoop)
some_output = interpreter.get_output(0)
output_ref = weakref.finalize(some_output, self._helperNoop)
......@@ -250,22 +250,22 @@ class ConvModelTests(test_util.TensorFlowTestCase):
custom_op_registerers = [("wrong", "format")]
with self.assertRaisesWithPredicateMatch(ValueError,
"must be a list of strings"):
-      interpreter = tflm_runtime.Interpreter.from_bytes(
-          model_data, custom_op_registerers)
+      interpreter = runtime.Interpreter.from_bytes(model_data,
+                                                   custom_op_registerers)
custom_op_registerers = "WrongFormat"
with self.assertRaisesWithPredicateMatch(ValueError,
"must be a list of strings"):
-      interpreter = tflm_runtime.Interpreter.from_bytes(
-          model_data, custom_op_registerers)
+      interpreter = runtime.Interpreter.from_bytes(model_data,
+                                                   custom_op_registerers)
def testNonExistentCustomOps(self):
model_data = generate_test_models.generate_conv_model(False)
custom_op_registerers = ["SomeRandomOp"]
with self.assertRaisesWithPredicateMatch(
RuntimeError, "TFLM could not register custom op via SomeRandomOp"):
-      interpreter = tflm_runtime.Interpreter.from_bytes(
-          model_data, custom_op_registerers)
+      interpreter = runtime.Interpreter.from_bytes(model_data,
+                                                   custom_op_registerers)
if __name__ == "__main__":
......
......@@ -59,7 +59,7 @@ py_test(
],
deps = [
":requantize_flatbuffer",
"//tensorflow/lite/micro/python/interpreter/src:tflm_runtime",
"//tensorflow/lite/micro/python/interpreter/src:runtime",
requirement("numpy"),
requirement("tensorflow-cpu"),
],
......
......@@ -20,7 +20,7 @@ import tensorflow as tf
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tflite_micro.tensorflow.lite.micro.tools import requantize_flatbuffer
-from tflite_micro.tensorflow.lite.micro.python.interpreter.src import tflm_runtime
+from tflite_micro.tensorflow.lite.micro.python.interpreter.src import runtime
from tflite_micro.tensorflow.lite.tools import flatbuffer_utils
......@@ -92,9 +92,9 @@ class SimpleFCModelTest(test_util.TensorFlowTestCase):
int8_converted_int16_model = convert_8to16_requantizer(
keras_model, representative_dataset_gen)
-    interpreter_tfl_converted = tflm_runtime.Interpreter.from_bytes(
+    interpreter_tfl_converted = runtime.Interpreter.from_bytes(
tfl_converted_int16_model)
-    interpreter_tool_converted = tflm_runtime.Interpreter.from_bytes(
+    interpreter_tool_converted = runtime.Interpreter.from_bytes(
int8_converted_int16_model)
num_steps = 10
......
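The test's comparison pattern can be sketched standalone: load both model variants into TFLM interpreters and check that their outputs agree over random inputs (shapes, dtype, and tolerance below are placeholder assumptions, not values from the test):

```python
# Sketch of the output-comparison pattern used by SimpleFCModelTest above.
import numpy as np
from tflite_micro.tensorflow.lite.micro.python.interpreter.src import runtime

def outputs_agree(model_a_bytes, model_b_bytes, input_shape, num_steps=10):
    interp_a = runtime.Interpreter.from_bytes(model_a_bytes)
    interp_b = runtime.Interpreter.from_bytes(model_b_bytes)
    for _ in range(num_steps):
        data_x = np.random.random(input_shape).astype(np.float32)  # placeholder
        for interp in (interp_a, interp_b):
            interp.set_input(data_x, 0)
            interp.invoke()
        if not np.allclose(interp_a.get_output(0), interp_b.get_output(0),
                           atol=1e-3):  # placeholder tolerance
            return False
    return True
```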