diff --git a/paddle/fluid/framework/var_type_inference.h b/paddle/fluid/framework/var_type_inference.h index f649c9388f0f6518dc4f8a587f5c9f9c01451373..945b68438e1e702e7b2e6498a26b0a107c6640da 100644 --- a/paddle/fluid/framework/var_type_inference.h +++ b/paddle/fluid/framework/var_type_inference.h @@ -69,6 +69,12 @@ class InferVarTypeContext { return op_->Inputs().at(name).size(); } + virtual size_t OutputSize(const std::string& name) const { + PADDLE_ENFORCE_NOT_NULL( + op_, platform::errors::PreconditionNotMet("op_ should not be null")); + return op_->Outputs().at(name).size(); + } + virtual const std::string& InputVarName(const std::string& name, const int index = 0) const { PADDLE_ENFORCE_NOT_NULL( diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.cc b/paddle/fluid/operators/controlflow/conditional_block_op.cc index 31ed10a71201c666c72e23853fdf925a42a80fb3..6bf419c47a5669b87c0b47d48259362a66a23239 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.cc +++ b/paddle/fluid/operators/controlflow/conditional_block_op.cc @@ -272,8 +272,18 @@ class ConditionalBlockGradInferVarType : public framework::VarTypeInference { // Input is {Tensor, LoDTensorArray}, we need synchronous the Input's // VarType into Input@GRAD to avoid generating {Tensor, Tensor} as // Input@GRAD. - ctx->SyncTypeAndDataType(ConditionalOp::kInputs, - framework::GradVarName(ConditionalOp::kInputs)); + auto input_size = ctx->InputSize(ConditionalOp::kInputs); + auto output_size = + ctx->OutputSize(framework::GradVarName(ConditionalOp::kInputs)); + PADDLE_ENFORCE_EQ(input_size, output_size, + platform::errors::InvalidArgument( + "input_size and output_size should be equal for " + "conditional_block_grad_op.")); + for (size_t i = 0; i < output_size; ++i) { + ctx->SyncTypeAndDataType(ConditionalOp::kInputs, + framework::GradVarName(ConditionalOp::kInputs), + i); + } } }; diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py index 567f266cd57b1eb4d16602b9bf7e1ee95d56bf19..ba1f5ed2b3ead7dd2be5526e18ebc82540b7ea4e 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py @@ -320,10 +320,12 @@ class ListWithCondNet(paddle.nn.Layer): if index > 0: res = a[0] * a[0] + y = y + 1 else: res = a[-1] * a[-1] + y = y - 1 - z = a[-1] * res + z = a[-1] * res * y[0] return z @@ -333,7 +335,7 @@ class TestListWithCondGradInferVarType(unittest.TestCase): x = paddle.to_tensor([2, 3, 4], dtype='float32') index = paddle.to_tensor([1]) res = net(x, index) - self.assertEqual(res[0], 16.) + self.assertEqual(res[0], 48.) if __name__ == '__main__':