From ff7e3590de6abf2e94c2997d4fe14fe185264068 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 16 Feb 2022 14:07:57 +0800 Subject: [PATCH] Add ConditionalBlockGradInferVarType (#39585) --- .../controlflow/conditional_block_op.cc | 15 +++++++++- .../unittests/dygraph_to_static/test_list.py | 30 +++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index 68720e70b0..a044506cef 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -265,6 +265,18 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase { } }; +class ConditionalBlockGradInferVarType : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + // NOTE(Aurelius84): VarType of Output is LoDTensor by default. In case of + // Input is {Tensor, LoDTensorArray}, we need synchronous the Input's + // VarType into Input@GRAD to avoid generating {Tensor, Tensor} as + // Input@GRAD. + ctx->SyncTypeAndDataType(ConditionalOp::kInputs, + framework::GradVarName(ConditionalOp::kInputs)); + } +}; + template class ConditionalBlockGradMaker : public framework::SingleGradOpMaker { public: @@ -300,4 +312,5 @@ REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp, ops::ConditionalBlockOpProtoMaker, ops::ConditionalBlockGradMaker); REGISTER_OPERATOR(conditional_block_grad, ops::ConditionalBlockGradOp, - ops::ConditionalBlockGradInferShape); + ops::ConditionalBlockGradInferShape, + ops::ConditionalBlockGradInferVarType); diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py index 57386bd00c..567f266cd5 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py @@ -306,5 +306,35 @@ class TestListInForLoopWithSubscript(TestListWithoutControlFlow): self.input = np.random.random((3, 4)).astype('float32') +class ListWithCondNet(paddle.nn.Layer): + def __init__(self): + super(ListWithCondNet, self).__init__() + + @paddle.jit.to_static + def forward(self, x, index): + y = paddle.nn.functional.relu(x) + a = [] + + for i in y: + a.append(i) + + if index > 0: + res = a[0] * a[0] + else: + res = a[-1] * a[-1] + + z = a[-1] * res + return z + + +class TestListWithCondGradInferVarType(unittest.TestCase): + def test_to_static(self): + net = ListWithCondNet() + x = paddle.to_tensor([2, 3, 4], dtype='float32') + index = paddle.to_tensor([1]) + res = net(x, index) + self.assertEqual(res[0], 16.) + + if __name__ == '__main__': unittest.main() -- GitLab