提交 d38a8fe0 编写于 作者: I Igor Ganichev 提交者: TensorFlower Gardener

Take eager tensor caches outside of EagerContext

This CL should have no behavior changes. It moves the tensor caches
that were fields of eager Context into a global map indexed by context id.

This CL is needed so that eager Context does not have a references
to EagerTensors. Future changes will add a reference from EagerTensor to
eager Context. This will mimick the ownership structure of corresponding C++
objects and ensure that Python will delete the context only after all
the tensors have been deleted. The latter will allow us to simplify EagerContext
destruction and remove ref counting from it.

PiperOrigin-RevId: 258454245
上级 9133bcb2
......@@ -144,25 +144,20 @@ class FunctionCallOptions(object):
"proto or None. got: {}".format(type(config)))
class _ThreadLocalData(threading.local):
"""Thread local storage for the eager context."""
# Map from context_id (an int) to _TensorCaches.
# Dicts are thread safe in CPython.
# TODO(iga): Remove this once TensorCaches are moved to C++.
_tensor_caches_map = {}
class _TensorCaches(threading.local):
"""Thread local tensor caches."""
def __init__(self):
super(_ThreadLocalData, self).__init__()
self.device_spec = _starting_device_spec
self.device_name = ""
self.mode = default_execution_mode
self.is_eager = default_execution_mode == EAGER_MODE
self.scope_name = ""
self.summary_writer = None
self.summary_recording = None
self.summary_recording_distribution_strategy = True
self.summary_step = None
super(_TensorCaches, self).__init__()
self.scalar_cache = {}
self._ones_rank_cache = None
self._zeros_cache = None
self.execution_mode = SYNC
self.function_call_options = None
@property
def ones_rank_cache(self):
......@@ -177,6 +172,24 @@ class _ThreadLocalData(threading.local):
return self._zeros_cache
class _ThreadLocalData(threading.local):
"""Thread local storage for the eager context."""
def __init__(self):
super(_ThreadLocalData, self).__init__()
self.device_spec = _starting_device_spec
self.device_name = ""
self.mode = default_execution_mode
self.is_eager = default_execution_mode == EAGER_MODE
self.scope_name = ""
self.summary_writer = None
self.summary_recording = None
self.summary_recording_distribution_strategy = True
self.summary_step = None
self.execution_mode = SYNC
self.function_call_options = None
ContextSwitch = collections.namedtuple(
"ContextSwitch", ["is_building_function", "enter_context_fn",
"device_stack"])
......@@ -277,6 +290,33 @@ class PhysicalDevice(
pass
class _AtomicCounter(object):
"""A simple atomic counter."""
def __init__(self):
self._value = 0
self._lock = threading.Lock()
def increment_and_get(self):
with self._lock:
self._value += 1
return self._value
_context_id_counter = _AtomicCounter()
class _TensorCacheDeleter(object):
"""Deletes tensor caches for a given context."""
def __init__(self, context_id):
self._context_id = context_id
def __del__(self):
if self._context_id in _tensor_caches_map:
del _tensor_caches_map[self._context_id]
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
# TODO(agarwal): consider keeping the corresponding Graph here.
class Context(object):
......@@ -327,6 +367,12 @@ class Context(object):
Raises:
ValueError: If execution_mode is not valid.
"""
# This _id is used only to index the tensor caches.
# TODO(iga): Remove this when tensor caches are moved to C++.
self._id = _context_id_counter.increment_and_get()
self._tensor_cache_deleter = _TensorCacheDeleter(self._id)
_tensor_caches_map[self._id] = _TensorCaches()
self._config = config
self._thread_local_data = _ThreadLocalData()
self._context_switches = _ContextSwitchStack(self.executing_eagerly())
......@@ -604,15 +650,15 @@ class Context(object):
def scalar_cache(self):
"""Per-device cache for scalars."""
return self._thread_local_data.scalar_cache
return _tensor_caches_map[self._id].scalar_cache
def ones_rank_cache(self):
"""Per-device cache for scalars."""
return self._thread_local_data.ones_rank_cache
return _tensor_caches_map[self._id].ones_rank_cache
def zeros_cache(self):
"""Per-device cache for scalars."""
return self._thread_local_data.zeros_cache
return _tensor_caches_map[self._id].zeros_cache
@property
def scope_name(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册