Unverified commit 31e380ce, authored by Roc, committed by GitHub

[Eager] fix recompute for stop_gradient and inplace (#48471)

* fix recompute for stop_gradient and inplace

* fix ut

* update
Parent 12486712
@@ -150,6 +150,18 @@ class _HPRecomputeFunction(PyLayer):
                 tensor_inputs.append(arg)
                 ctx.tensor_indices.append(i)
                 ctx.inputs.append(None)
+                # In the new dygraph mode, a subset of the outputs may be
+                # identical to a subset of the inputs, i.e. the op runs
+                # inplace. When such an input's stop_gradient is True, an
+                # error occurs, because stop_gradient=True and inplace ops
+                # are not supported at the same time. The solution is to mark
+                # the input non_differentiable if its stop_gradient is True.
+                # Note:
+                #     If not marked non_differentiable, every output tensor's
+                #     `stop_gradient` will be reset to `False` in the C++ backend.
+                # See https://github.com/PaddlePaddle/Paddle/blob/9d62efb0e6e5373823039d9eda96cd5905426c0a/paddle/fluid/pybind/eager_py_layer.cc#L388
+                if framework.in_dygraph_mode() and state:
+                    ctx.mark_non_differentiable(arg)
             else:
                 ctx.inputs.append(arg)
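
The fix hinges on `ctx.mark_non_differentiable`. Below is a minimal sketch of the same pattern outside the recompute machinery; the `IdentityLayer` name and the final check are illustrative, not part of this patch, and it assumes a Paddle build that contains this fix:

import paddle
from paddle import framework
from paddle.autograd import PyLayer


class IdentityLayer(PyLayer):
    @staticmethod
    def forward(ctx, x):
        y = x  # output aliases the input: the inplace-style case above
        if framework.in_dygraph_mode() and x.stop_gradient:
            # Without this call, the C++ backend would reset
            # y.stop_gradient to False (see the link above).
            ctx.mark_non_differentiable(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        return dy


x = paddle.ones([2, 2])
x.stop_gradient = True
y = IdentityLayer.apply(x)
print(y.stop_gradient)  # True: preserved by mark_non_differentiable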
@@ -22,6 +22,7 @@ import paddle.distributed as dist
 import paddle.distributed.fleet as fleet
 import paddle.nn as nn
 import paddle.nn.functional as F
+from paddle import framework
 from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer
 from paddle.fluid import layers
 from paddle.fluid.dygraph.layers import Layer
@@ -88,14 +89,22 @@ class TransformerNet(Layer):
 class EmbeddingPipe(EmbeddingNet):
-    def forward(self, x):
-        return super().forward(x)
+    def forward(self, tensors):
+        if framework.in_dygraph_mode():
+            stable, x = tensors
+            return stable, super().forward(x)
+        else:
+            return super().forward(tensors)
 
 
 class TransformerNetPipe(TransformerNet):
-    def forward(self, x):
-        output = super().forward(x)
-        return output
+    def forward(self, tensors):
+        if framework.in_dygraph_mode():
+            stable, x = tensors
+            output = super().forward(x)
+            return stable, output
+        else:
+            return super().forward(tensors)
 
 
 class CriterionPipe(Layer):
@@ -103,6 +112,8 @@ class CriterionPipe(Layer):
         super().__init__()
 
     def forward(self, out, label):
+        if framework.in_dygraph_mode():
+            out = out[-1]
         loss = out.mean()
         return loss
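
Each pipeline stage receives the previous stage's output as its input, so once EmbeddingPipe emits a (stable, x) tuple, TransformerNetPipe must accept and re-emit it, and CriterionPipe must unpack it. A toy sketch of that hand-off (run_stages is a hypothetical helper, not a fleet API):

def run_stages(stages, inputs):
    # Mimics the pipeline hand-off: stage N's output becomes stage N+1's
    # input, which is why the (stable, x) tuple is threaded through every
    # forward above.
    out = inputs
    for stage in stages:
        out = stage(out)
    return out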
@@ -171,7 +182,8 @@ class TestDistPPTraning(unittest.TestCase):
             x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
             x = paddle.to_tensor(x_data)
             x.stop_gradient = True
-            loss = model.train_batch([x, x], optimizer, scheduler)
+            input_ = (x, x) if framework.in_dygraph_mode() else x
+            loss = model.train_batch([input_, x], optimizer, scheduler)
 
             # TODO(shenliang03) add utest for loss
             print("loss: ", loss)