diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 6fa531c573daa7bdc6b2c043ae29e8363c7bcaa8..10786c662072cce6fbea642464c4aff3de552ae1 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -1023,13 +1023,20 @@ class Layer(core.Layer): self._non_persistable_buffer_names_set.add(name) _buffers[name] = value elif _buffers is not None and name in _buffers: - if value is not None: + # Note(Aurelius84): In Dy2stat, the value of the Buffer may be modified in + # decorated function, such as `self.buffer = new_tensor`. So we update its + # value via `assign`. + if type(value) == framework.Variable: + from paddle import assign + assign(value, _buffers[name]) + elif value is not None: raise TypeError( "assignment to buffers '{}' should be of type core.VarBase or None, but got '{}'" .format(name, type(value).__name__)) - # Assigning None will remove the buffer, but if re-assign a new varBase to it, - # it will be remarked as a buffer with same `persistable` attribute. - _buffers[name] = None + else: + # Assigning None will remove the buffer, but if re-assign a new varBase to it, + # it will be remarked as a buffer with same `persistable` attribute. + _buffers[name] = None else: object.__setattr__(self, name, value) diff --git a/python/paddle/fluid/tests/unittests/test_base_layer.py b/python/paddle/fluid/tests/unittests/test_base_layer.py index 875f6211a7fbd98463d98dff91d93cc1b431fc86..31879dae0dad06d75a1ad5c6b6780ef2c3d2b93b 100644 --- a/python/paddle/fluid/tests/unittests/test_base_layer.py +++ b/python/paddle/fluid/tests/unittests/test_base_layer.py @@ -15,9 +15,11 @@ import unittest import numpy as np +import paddle import paddle.fluid as fluid from paddle.fluid.dygraph import to_variable from paddle.fluid.framework import ParamBase +from paddle.jit import ProgramTranslator class L1(fluid.Layer): @@ -288,5 +290,46 @@ class TestBuffer(unittest.TestCase): self.assertTrue(np.array_equal(var1.numpy(), var2.numpy())) +class BufferNetWithModification(paddle.nn.Layer): + def __init__(self, shape): + super(BufferNetWithModification, self).__init__() + + self.buffer1 = paddle.zeros(shape, 'int32') + self.buffer2 = paddle.zeros(shape, 'int32') + + @paddle.jit.to_static + def forward(self, x): + self.buffer1 += x + self.buffer2 = self.buffer1 + x + + out = self.buffer1 + self.buffer2 + + return out + + +class TestModifiedBuffer(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.prog_trans = ProgramTranslator() + self.shape = [10, 16] + + def _run(self, to_static=False): + self.prog_trans.enable(to_static) + + x = paddle.ones([1], 'int32') + net = BufferNetWithModification(self.shape) + out = net(x) + + return out, net.buffer1, net.buffer2 + + def test_modified(self): + dy_outs = self._run(False) + st_outs = self._run(True) + + for i in range(len(dy_outs)): + self.assertTrue( + np.array_equal(dy_outs[i].numpy(), st_outs[i].numpy())) + + if __name__ == '__main__': unittest.main()