Created by: Aurelius84
### PR types
Bug fixes

### PR changes
APIs

### Describe
Fix unused variables by filling their gradients with zeros.
`x1`, `x2` and `x3` are output variables of the same op, but only `x1` is used to compute the gradient, so `x2@GRAD` and `x3@GRAD` should be zero. Before this PR they were not initialized as expected; this PR fills them with zeros (see the sketch after the example for the expected semantics).

Example code:
```python
import unittest

import numpy as np
import paddle.fluid as fluid


class TestGradientWithPrune(unittest.TestCase):
    def test_prune(self):
        x = fluid.data(name='x', shape=[3], dtype='float32')
        x.stop_gradient = False
        # x1, x2, x3 come from the same split op, but only x1 is used
        # to compute y.
        x1, x2, x3 = fluid.layers.split(x, dim=0, num_or_sections=3)
        y = x1 * 2
        x1_grad = fluid.gradients(y, x)

        exe = fluid.Executor(fluid.CPUPlace())
        main = fluid.default_main_program()
        exe.run(fluid.default_startup_program())
        out = exe.run(main,
                      feed={'x': np.ones([3]).astype('float32')},
                      fetch_list=[x1_grad])
        # Gradients of the unused slices x2 and x3 must be zeros.
        self.assertTrue(np.array_equal(out[0], [2., 0., 0.]))
```
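
For clarity, here is a minimal NumPy sketch of the semantics the test asserts. The helper `split_grad` is hypothetical and not part of fluid; it only illustrates that the gradient flowing back through `split` is the concatenation of each output's gradient, with zero-filled sections for outputs that never reach the loss.

```python
import numpy as np

# Hypothetical helper, not a fluid API: illustrates the expected
# backward semantics of split. Each forward output contributes its
# gradient; outputs that never reach the loss contribute zeros.
def split_grad(out_grads, section_len):
    parts = []
    for g in out_grads:
        if g is None:
            # Unused output: its gradient must be zero-filled,
            # not left uninitialized (the bug this PR fixes).
            parts.append(np.zeros(section_len, dtype='float32'))
        else:
            parts.append(g)
    return np.concatenate(parts)

# y = x1 * 2 gives dy/dx1 = 2; x2 and x3 are unused.
x_grad = split_grad([np.array([2.], dtype='float32'), None, None], 1)
assert np.array_equal(x_grad, [2., 0., 0.])
```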