提交 0b752d92 编写于 作者: M Meghna Natraj 提交者: TensorFlower Gardener

Refactor keras dependency code

PiperOrigin-RevId: 339954904
Change-Id: Id9f6717da5c32bff185a10a37e9682be64cc6501
上级 431b12d9
......@@ -149,7 +149,6 @@ py_library(
"//tensorflow/lite/experimental/examples/lstm:tflite_lstm_ops",
"//tensorflow/lite/experimental/microfrontend:audio_microfrontend_py",
"//tensorflow/lite/experimental/tensorboard:ops_util",
"//tensorflow/lite/python/keras/saving:saving_utils",
"//tensorflow/lite/python/optimize:calibrator",
"//tensorflow/python:graph_util",
"//tensorflow/python/keras",
......@@ -236,6 +235,7 @@ py_library(
":op_hint",
":schema_py",
":schema_util",
"//tensorflow/lite/python:tflite_keras_util",
"//tensorflow/lite/toco:toco_flags_proto_py",
"//tensorflow/python:convert_to_constants",
"//tensorflow/python:dtypes",
......@@ -278,6 +278,18 @@ py_test(
],
)
py_library(
name = "tflite_keras_util",
srcs = [
"tflite_keras_util.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
"//tensorflow/python/eager:def_function",
],
)
py_library(
name = "wrap_toco",
srcs = [
......
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
py_library(
name = "saving_utils",
srcs = [
"saving_utils.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
"//tensorflow/python/eager:def_function",
],
)
......@@ -50,7 +50,6 @@ from tensorflow.lite.python.convert import toco_convert_protos # pylint: disabl
from tensorflow.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
from tensorflow.lite.python.interpreter import Interpreter # pylint: disable=unused-import
from tensorflow.lite.python.interpreter import load_delegate # pylint: disable=unused-import
from tensorflow.lite.python.keras.saving import saving_utils as _keras_saving_utils
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import
from tensorflow.lite.python.op_hint import is_ophint_converted as _is_ophint_converted
from tensorflow.lite.python.op_hint import OpHint # pylint: disable=unused-import
......@@ -63,9 +62,11 @@ from tensorflow.lite.python.util import get_grappler_config as _get_grappler_con
from tensorflow.lite.python.util import get_tensor_name as _get_tensor_name
from tensorflow.lite.python.util import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
from tensorflow.lite.python.util import is_frozen_graph as _is_frozen_graph
from tensorflow.lite.python.util import model_input_signature as _model_input_signature
from tensorflow.lite.python.util import modify_model_io_type as _modify_model_io_type
from tensorflow.lite.python.util import run_graph_optimizations as _run_graph_optimizations
from tensorflow.lite.python.util import set_tensor_shapes as _set_tensor_shapes
from tensorflow.lite.python.util import trace_model_call as _trace_model_call
from tensorflow.python.client import session as _session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function as _def_function
......@@ -839,12 +840,11 @@ class TFLiteKerasModelConverterV2(TFLiteConverterBaseV2):
# Pass `keep_original_batch_size=True` will ensure that we get an input
# signature including the batch dimension specified by the user.
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
input_signature = _keras_saving_utils.model_input_signature(
input_signature = _model_input_signature(
self._keras_model, keep_original_batch_size=True)
# TODO(b/169898786): Use the Keras public API when TFLite moves out of TF
func = _keras_saving_utils.trace_model_call(
self._keras_model, input_signature)
func = _trace_model_call(self._keras_model, input_signature)
concrete_func = func.get_concrete_function()
self._funcs = [concrete_func]
......@@ -1468,7 +1468,7 @@ class TFLiteKerasModelConverter(TFLiteConverterBaseV1):
keras_model = keras_deps.get_load_model_function()(
model_file, custom_objects)
function = _keras_saving_utils.trace_model_call(keras_model)
function = _trace_model_call(keras_model)
concrete_func = function.get_concrete_function()
frozen_func = _convert_to_constants.convert_variables_to_constants_v2(
......
......@@ -13,7 +13,13 @@
# limitations under the License.
# ==============================================================================
"""Utility functions for TensorFlow models."""
"""Keras functions required by TensorFlow Lite.
The functions defined in this library have been copied over from Keras in order
to remove the dependency from TensorFlow Lite to Keras. The functions which
could not be copied over are accessed using the dependecy inversion principle.
(for details, refer to tensorflow/python/util/keras_deps.py).
"""
from __future__ import absolute_import
from __future__ import division
......
......@@ -33,6 +33,7 @@ from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import schema_util
from tensorflow.lite.python import tflite_keras_util as _tflite_keras_util
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.lite.python.op_hint import find_all_hinted_output_nodes
from tensorflow.lite.toco import types_pb2 as _types_pb2
......@@ -44,6 +45,10 @@ from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training.saver import export_meta_graph as _export_meta_graph
# Keras functions used by TFLite
model_input_signature = _tflite_keras_util.model_input_signature
trace_model_call = _tflite_keras_util.trace_model_call
# Map of tf.dtypes to TFLite types_flag_pb2.
_MAP_TF_TO_TFLITE_TYPES = {
dtypes.float32: _types_pb2.FLOAT,
......
......@@ -417,6 +417,8 @@ def call_context():
return call_ctx
# Inject the call_context function to keras_deps to remove the dependency
# from TFLite to Keras.
keras_deps.register_call_context_function(call_context)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册