Created by: Aurelius84
PR types
Bug fixes
PR changes
Others
Describe
Add descriptor cache for StaticLayer.
Why this is needed
In StaticLayer, we implement the __get__ method to make it a descriptor (see this). This allows us to easily retrieve the instance from a decorated bound method, such as self.forward.
class A:
    @decorator
    def foo(self, x):
        return x

a = A()
a.foo(10)
If decorator is a class that implements __get__ (that is, a descriptor), accessing a.foo first calls its __get__ method. A common implementation is:
def __get__(self, instance, owner):
    # instance is `a` as above
    # owner is `A` as above
    do_something_useful()
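To make the lookup order concrete, here is a minimal, self-contained sketch of a class-based decorator that is also a descriptor. The class name `decorator` mirrors the snippet above, but its body is an assumption for illustration and is not StaticLayer's implementation.

```python
import functools


class decorator:
    """Hypothetical class-based decorator that also acts as a descriptor."""

    def __init__(self, func):
        self._func = func

    def __get__(self, instance, owner):
        # instance is `a` (or None when accessed through the class),
        # owner is `A`.
        if instance is None:
            return self
        # Bind the wrapped function to the instance, like a normal method.
        return functools.partial(self._func, instance)


class A:
    @decorator
    def foo(self, x):
        return x


a = A()
# Attribute access a.foo triggers A.__dict__["foo"].__get__(a, A) first,
# and only then is the returned callable invoked with 10.
result = a.foo(10)
print(result)
```

Because __get__ receives the instance, the decorator can recover `a` without the caller passing it explicitly.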
But why do we need self._descriptor_cache here?
Consider the following example, as it behaved before this PR:
class Net(Layer):
    ....
    @to_static
    def forward(self, x):
        return x*2

net_1 = Net()
net_2 = Net()
isinstance(net_1.forward, StaticLayer)  # True
isinstance(net_2.forward, StaticLayer)  # True
id(net_1.forward) == id(net_2.forward)  # True !!
id(net_1.forward) == id(Net.forward)  # True !!
That is the problem: different instances of the same class share a single StaticLayer instance, which is risky for some operations.
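One way the sharing can go wrong is sketched below without Paddle. `SharedStatic` is a hypothetical stand-in whose __get__ records the latest instance and returns itself. This is an assumption about the failure mode, not a copy of the pre-PR StaticLayer code.

```python
class SharedStatic:
    """Hypothetical stand-in for a descriptor shared by all instances."""

    def __init__(self, func):
        self._func = func
        self._instance = None

    def __get__(self, instance, owner):
        # Remember the most recent instance and return the one shared object.
        self._instance = instance
        return self

    def __call__(self, x):
        return self._func(self._instance, x)


class Net:
    def __init__(self, scale):
        self.scale = scale

    @SharedStatic
    def forward(self, x):
        return x * self.scale


net_1 = Net(2)
net_2 = Net(3)

f1 = net_1.forward     # intended to be bound to net_1
f2 = net_2.forward     # the same object, now rebound to net_2
same_object = f1 is f2
stale_result = f1(10)  # silently uses net_2.scale instead of net_1.scale
print(same_object, stale_result)
```

The sketch compares with `is` on references that are kept alive, which avoids relying on the id() of temporary objects.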
Solution
We introduce a weakref.WeakKeyDictionary() as _descriptor_cache to hold a separate StaticLayer instance for each Layer instance.
def __get__(self, instance, owner):
    if instance not in self._descriptor_cache:
        self._descriptor_cache[instance] = self._clone()  # create a new StaticLayer
    return self._descriptor_cache[instance]
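The caching idea can be tried outside Paddle with the following sketch. `CachedStatic`, its `_clone`, and the `_instance_ref` attribute are hypothetical names chosen for illustration. The `instance is None` branch is also an assumption, added because None cannot be used as a key in a WeakKeyDictionary.

```python
import weakref


class CachedStatic:
    """Hypothetical stand-in for StaticLayer with a per-instance cache."""

    def __init__(self, func):
        self._func = func
        self._instance_ref = None
        self._descriptor_cache = weakref.WeakKeyDictionary()

    def _clone(self):
        # Assumed behaviour: a fresh wrapper around the same function.
        return CachedStatic(self._func)

    def __get__(self, instance, owner):
        if instance is None:
            # Accessed through the class (Net.forward): nothing to bind.
            return self
        if instance not in self._descriptor_cache:
            clone = self._clone()
            # Hold the instance weakly so the clone does not keep it alive.
            clone._instance_ref = weakref.ref(instance)
            self._descriptor_cache[instance] = clone
        return self._descriptor_cache[instance]

    def __call__(self, x):
        return self._func(self._instance_ref(), x)


class Net:
    def __init__(self, scale):
        self.scale = scale

    @CachedStatic
    def forward(self, x):
        return x * self.scale


net_1 = Net(2)
net_2 = Net(3)
f1 = net_1.forward
f2 = net_2.forward
print(f1 is f2, f1(10), f2(10))
```

Each instance now gets its own wrapper, and repeated access on the same instance returns the cached wrapper instead of creating a new one.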
The benefit of choosing weakref.WeakKeyDictionary() over a plain dict is that when net is garbage collected, its entry is automatically removed from self._descriptor_cache, which saves memory.
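The difference from a plain dict can be observed directly. In this sketch, `Net` is an empty hypothetical placeholder for a Layer, and the string values stand in for cloned StaticLayer objects.

```python
import gc
import weakref


class Net:
    """Hypothetical placeholder for a Layer instance."""


weak_cache = weakref.WeakKeyDictionary()
strong_cache = {}

net_a = Net()
net_b = Net()
weak_cache[net_a] = "StaticLayer clone for net_a"
strong_cache[net_b] = "StaticLayer clone for net_b"

# Drop the only external references to both instances.
del net_a, net_b
gc.collect()

weak_size = len(weak_cache)      # the entry disappeared together with net_a
strong_size = len(strong_cache)  # the dict itself still keeps net_b alive
print(weak_size, strong_size)
```

Note that a WeakKeyDictionary holds its values strongly, so an entry is only dropped if the cached value does not itself keep a strong reference to the key.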
After this PR:
id(net_1.forward) == id(net_2.forward) # False
id(net_1.forward) == id(Net.forward) # False