未验证 提交 83580ee6 编写于 作者: W WeiXin 提交者: GitHub

use 'paddle.framework.set_grad_enabled' in pylayer (#32355)

上级 feb2e476
...@@ -176,6 +176,7 @@ class CPyLayer(object): ...@@ -176,6 +176,7 @@ class CPyLayer(object):
class PyLayerBackward(PyLayerContext): class PyLayerBackward(PyLayerContext):
def backward(self, *args, **kwargs): def backward(self, *args, **kwargs):
with paddle.fluid.dygraph.guard():
with paddle.fluid.dygraph.no_grad(): with paddle.fluid.dygraph.no_grad():
return self._forward_cls.backward(*args, **kwargs) return self._forward_cls.backward(*args, **kwargs)
......
...@@ -283,20 +283,54 @@ class TestPyLayer(unittest.TestCase): ...@@ -283,20 +283,54 @@ class TestPyLayer(unittest.TestCase):
class cus_tanh(PyLayer): class cus_tanh(PyLayer):
@staticmethod @staticmethod
def forward(ctx, x): def forward(ctx, x):
return x.mean() return x
@staticmethod @staticmethod
def backward(ctx, dy): def backward(ctx, dy):
return dy return dy
class Layer(paddle.nn.Layer):
def __init__(self):
super(Layer, self).__init__()
def forward(self, data):
data = paddle.nn.functional.relu(data)
z = paddle.tanh(data)
z = cus_tanh.apply(data)
return z.mean()
for i in range(2): for i in range(2):
data = paddle.ones([2, 3], dtype="float64") / (i + 1) data = paddle.ones([2, 3], dtype="float64") / (i + 1)
data.stop_gradient = False data.stop_gradient = False
layer = Layer()
z = layer(data)
z.backward()
self.assertTrue(data.grad is not None)
def test_backward_in_backward(self):
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
temp = x.detach()
ctx.inputs = temp
return x.mean()
@staticmethod
def backward(ctx, dy):
with paddle.set_grad_enabled(True):
temp = ctx.inputs
temp.stop_gradient = False
z = paddle.tanh(temp)
z.backward()
self.assertTrue(temp.grad is not None)
return paddle.to_tensor(temp.grad)
for i in range(2):
data = paddle.ones([2, 3], dtype="float32") / (i + 1)
data.stop_gradient = False
data = paddle.nn.functional.relu(data) data = paddle.nn.functional.relu(data)
z = paddle.tanh(data) z = paddle.tanh(data)
z = cus_tanh.apply(data) z = cus_tanh.apply(data)
z.backward()
self.assertTrue(data.grad is not None)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册