未验证 提交 94132190 编写于 作者: W WangZhen 提交者: GitHub

[Dy2St]Fix recurrent op eager deletion pass error in dy2st (#47105)

* Fix recurrent op eager deletion pass error in dy2st

* Polish code

* Refine error message
上级 1e1c7275
......@@ -36,6 +36,29 @@ void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const {
std::unordered_map<size_t, OpAndGradOpPair> target_ops =
DeviceIdToRecurrentAndRecurrentGradOp(*graph);
if (graph->IsConstructedByPartialProgram()) {
PADDLE_ENFORCE_LE(target_ops.size(),
1,
platform::errors::InvalidArgument(
"Unsupported multi devices if graph is constructed "
"with partial program."));
size_t scope_idx = 0;
auto &recur_ops = target_ops[scope_idx].first;
auto &recur_grad_ops = target_ops[scope_idx].second;
auto all_ops = graph->OriginProgram().Block(0).AllOps();
if (recur_ops.empty()) {
operators::AppendOpVariantByOpName(
all_ops, std::string("recurrent"), &recur_ops);
} else if (recur_grad_ops.empty()) {
operators::AppendOpVariantByOpName(
all_ops, std::string("recurrent_grad"), &recur_grad_ops);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"One of recur_ops or recur_grad_ops should be empty."));
}
}
for (auto &entry : target_ops) {
// Prepare safe eager deletion on different devices because the garbage
// collection may be different across devices
......
......@@ -81,5 +81,22 @@ void AppendOpVariantByOpName(const std::vector<framework::OpDesc *> &op_descs,
}
}
void AppendOpVariantByOpName(
const std::vector<framework::OpDesc *> &op_descs,
const std::string &candidate_op_name,
std::unordered_set<OpVariant, OpVariant::Hasher> *result_ops) {
PADDLE_ENFORCE_NOT_NULL(
result_ops,
platform::errors::Unavailable("result_ops should not be a null_ptr."));
for (auto *op_desc : op_descs) {
PADDLE_ENFORCE_NOT_NULL(
op_desc,
platform::errors::Unavailable("op_desc should not be a null_ptr."));
if (op_desc->Type() == candidate_op_name) {
result_ops->emplace(op_desc);
}
}
}
} // namespace operators
} // namespace paddle
......@@ -78,5 +78,10 @@ void AppendOpVariantByOpName(const std::vector<framework::OpDesc *> &op_descs,
const std::string &candidate_op_name,
std::vector<OpVariant> *result_ops);
void AppendOpVariantByOpName(
const std::vector<framework::OpDesc *> &op_descs,
const std::string &candidate_op_name,
std::unordered_set<OpVariant, OpVariant::Hasher> *result_ops);
} // namespace operators
} // namespace paddle
......@@ -690,5 +690,45 @@ class EagerDeletionFarwardOnlyRnnAndBackwardRnnTest(
np.testing.assert_allclose(pd_output, py_output, rtol=0.01)
class RecurrentNet(paddle.nn.Layer):
def __init__(self):
super(RecurrentNet, self).__init__()
self.cell = paddle.nn.SimpleRNNCell(16, 32)
self.rnn = paddle.nn.RNN(self.cell)
def forward(self, inputs, prev_h):
outputs, final_states = self.rnn(inputs, prev_h)
return outputs, final_states
class TestDy2StRecurrentOpBackward(unittest.TestCase):
def setUp(self):
paddle.disable_static()
paddle.seed(100)
def tearDown(self):
paddle.enable_static()
def test_recurrent_backward(self):
net = RecurrentNet()
inputs = paddle.rand((4, 23, 16))
inputs.stop_gradient = False
prev_h = paddle.randn((4, 32))
prev_h.stop_gradient = False
outputs, final_states = net(inputs, prev_h)
outputs.backward()
dy_grad = inputs.gradient()
inputs.clear_gradient()
net = paddle.jit.to_static(net)
outputs, final_states = net(inputs, prev_h)
outputs.backward()
st_grad = inputs.gradient()
np.testing.assert_allclose(dy_grad, st_grad)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册