From a863b32ea8441b2448e487b417e7ce596f530a44 Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Mon, 21 Feb 2022 15:24:54 +0800 Subject: [PATCH] [Dy2St]Fix cond grad error when handle tensor array (#39689) * fix cond grad error when handle tensor array * add UT --- paddle/fluid/framework/var_type_inference.h | 6 ++++++ .../operators/controlflow/conditional_block_op.cc | 14 ++++++++++++-- .../tests/unittests/dygraph_to_static/test_list.py | 6 ++++-- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/var_type_inference.h b/paddle/fluid/framework/var_type_inference.h index f649c9388f0..945b68438e1 100644 --- a/paddle/fluid/framework/var_type_inference.h +++ b/paddle/fluid/framework/var_type_inference.h @@ -69,6 +69,12 @@ class InferVarTypeContext { return op_->Inputs().at(name).size(); } + virtual size_t OutputSize(const std::string& name) const { + PADDLE_ENFORCE_NOT_NULL( + op_, platform::errors::PreconditionNotMet("op_ should not be null")); + return op_->Outputs().at(name).size(); + } + virtual const std::string& InputVarName(const std::string& name, const int index = 0) const { PADDLE_ENFORCE_NOT_NULL( diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index 31ed10a7120..6bf419c47a5 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -272,8 +272,18 @@ class ConditionalBlockGradInferVarType : public framework::VarTypeInference { // Input is {Tensor, LoDTensorArray}, we need synchronous the Input's // VarType into Input@GRAD to avoid generating {Tensor, Tensor} as // Input@GRAD. - ctx->SyncTypeAndDataType(ConditionalOp::kInputs, - framework::GradVarName(ConditionalOp::kInputs)); + auto input_size = ctx->InputSize(ConditionalOp::kInputs); + auto output_size = + ctx->OutputSize(framework::GradVarName(ConditionalOp::kInputs)); + PADDLE_ENFORCE_EQ(input_size, output_size, + platform::errors::InvalidArgument( + "input_size and output_size should be equal for " + "conditional_block_grad_op.")); + for (size_t i = 0; i < output_size; ++i) { + ctx->SyncTypeAndDataType(ConditionalOp::kInputs, + framework::GradVarName(ConditionalOp::kInputs), + i); + } } }; diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py index 567f266cd57..ba1f5ed2b3e 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py @@ -320,10 +320,12 @@ class ListWithCondNet(paddle.nn.Layer): if index > 0: res = a[0] * a[0] + y = y + 1 else: res = a[-1] * a[-1] + y = y - 1 - z = a[-1] * res + z = a[-1] * res * y[0] return z @@ -333,7 +335,7 @@ class TestListWithCondGradInferVarType(unittest.TestCase): x = paddle.to_tensor([2, 3, 4], dtype='float32') index = paddle.to_tensor([1]) res = net(x, index) - self.assertEqual(res[0], 16.) + self.assertEqual(res[0], 48.) if __name__ == '__main__': -- GitLab