未验证 提交 57e4411a 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat] Support to modify value of buffer tensor (#28328)

* [Dy2stat] Support to modify value of buffer tensor

* remove "defaultTest"

* fix name confliction
上级 d9b5f126
......@@ -1023,13 +1023,20 @@ class Layer(core.Layer):
self._non_persistable_buffer_names_set.add(name)
_buffers[name] = value
elif _buffers is not None and name in _buffers:
if value is not None:
# Note(Aurelius84): In Dy2stat, the value of the Buffer may be modified in
# decorated function, such as `self.buffer = new_tensor`. So we update its
# value via `assign`.
if type(value) == framework.Variable:
from paddle import assign
assign(value, _buffers[name])
elif value is not None:
raise TypeError(
"assignment to buffers '{}' should be of type core.VarBase or None, but got '{}'"
.format(name, type(value).__name__))
# Assigning None will remove the buffer, but if re-assign a new varBase to it,
# it will be remarked as a buffer with same `persistable` attribute.
_buffers[name] = None
else:
# Assigning None will remove the buffer, but if re-assign a new varBase to it,
# it will be remarked as a buffer with same `persistable` attribute.
_buffers[name] = None
else:
object.__setattr__(self, name, value)
......
......@@ -15,9 +15,11 @@
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import ParamBase
from paddle.jit import ProgramTranslator
class L1(fluid.Layer):
......@@ -288,5 +290,46 @@ class TestBuffer(unittest.TestCase):
self.assertTrue(np.array_equal(var1.numpy(), var2.numpy()))
class BufferNetWithModification(paddle.nn.Layer):
def __init__(self, shape):
super(BufferNetWithModification, self).__init__()
self.buffer1 = paddle.zeros(shape, 'int32')
self.buffer2 = paddle.zeros(shape, 'int32')
@paddle.jit.to_static
def forward(self, x):
self.buffer1 += x
self.buffer2 = self.buffer1 + x
out = self.buffer1 + self.buffer2
return out
class TestModifiedBuffer(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.prog_trans = ProgramTranslator()
self.shape = [10, 16]
def _run(self, to_static=False):
self.prog_trans.enable(to_static)
x = paddle.ones([1], 'int32')
net = BufferNetWithModification(self.shape)
out = net(x)
return out, net.buffer1, net.buffer2
def test_modified(self):
dy_outs = self._run(False)
st_outs = self._run(True)
for i in range(len(dy_outs)):
self.assertTrue(
np.array_equal(dy_outs[i].numpy(), st_outs[i].numpy()))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册