diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc
index 68720e70b09ad6098da4fd59c50bbb89a56c9dc7..a044506cef4bb480d30bc87f3b556560a4d61064 100644
--- a/paddle/fluid/operators/controlflow/conditional_block_op.cc
+++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc
@@ -265,6 +265,18 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase {
   }
 };
 
+class ConditionalBlockGradInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(framework::InferVarTypeContext *ctx) const override {
+    // NOTE(Aurelius84): The VarType of Output is LoDTensor by default. If the
+    // Input is {Tensor, LoDTensorArray}, we need to synchronize the Input's
+    // VarType into Input@GRAD to avoid generating {Tensor, Tensor} as
+    // Input@GRAD.
+    ctx->SyncTypeAndDataType(ConditionalOp::kInputs,
+                             framework::GradVarName(ConditionalOp::kInputs));
+  }
+};
+
 template <typename T>
 class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> {
  public:
@@ -300,4 +312,5 @@ REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp,
                   ops::ConditionalBlockOpProtoMaker,
                   ops::ConditionalBlockGradMaker<paddle::framework::OpDesc>);
 REGISTER_OPERATOR(conditional_block_grad, ops::ConditionalBlockGradOp,
-                  ops::ConditionalBlockGradInferShape);
+                  ops::ConditionalBlockGradInferShape,
+                  ops::ConditionalBlockGradInferVarType);
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py
index 57386bd00c9f39a9c00c6f24b79cc226bf6e27dd..567f266cd57b1eb4d16602b9bf7e1ee95d56bf19 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py
@@ -306,5 +306,35 @@ class TestListInForLoopWithSubscript(TestListWithoutControlFlow):
         self.input = np.random.random((3, 4)).astype('float32')
 
 
+class ListWithCondNet(paddle.nn.Layer):
+    def __init__(self):
+        super(ListWithCondNet, self).__init__()
+
+    @paddle.jit.to_static
+    def forward(self, x, index):
+        y = paddle.nn.functional.relu(x)
+        a = []
+
+        for i in y:
+            a.append(i)
+
+        if index > 0:
+            res = a[0] * a[0]
+        else:
+            res = a[-1] * a[-1]
+
+        z = a[-1] * res
+        return z
+
+
+class TestListWithCondGradInferVarType(unittest.TestCase):
+    def test_to_static(self):
+        net = ListWithCondNet()
+        x = paddle.to_tensor([2, 3, 4], dtype='float32')
+        index = paddle.to_tensor([1])
+        res = net(x, index)
+        self.assertEqual(res[0], 16.)
+
+
 if __name__ == '__main__':
     unittest.main()