From 8e9eea7f188bee9a8a568fdf8b4e2a7027b77c23 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Wed, 3 Aug 2022 11:25:52 +0800 Subject: [PATCH] [ Dy2Static ] Release the condition when convert_len() return a variable (#44527) * release the convert_len() * fix bugs * fix bugs * fix start error. * when shape==0, we should also call paddle.shape --- .../fluid/dygraph/dygraph_to_static/convert_operators.py | 7 ++++++- .../fluid/tests/unittests/dygraph_to_static/test_len.py | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index e0b46fe2341..b78b5957393 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -414,6 +414,8 @@ def convert_len(var): # Note: Length of var may be known ahead of time in dygraph, # but it probably represents batch size which can be variant. # so we return a variable dynamically inferred from var.shape. + if var.shape[0] > 0 and var.type == core.VarDesc.VarType.LOD_TENSOR: + return var.shape[0] return nn.shape(var)[0] elif var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY: return control_flow.array_length(var) @@ -447,7 +449,10 @@ class VariableTuple: def __init__(self, var, start=0): self.var = var self.len = convert_len(var) - self.rag = paddle_range(start, start + self.len, 1, paddle.int64) + if isinstance(self.len, Variable): + self.rag = paddle_range(start, start + self.len, 1, paddle.int64) + else: + self.rag = range(start, start + self.len) def __getitem__(self, idx): return self.rag[idx], self.var[idx] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_len.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_len.py index 28f79b57b6b..386c3a1bd7b 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_len.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_len.py @@ -17,12 +17,14 @@ from __future__ import print_function import unittest import numpy as np +import paddle import paddle.fluid as fluid from paddle.fluid.dygraph import declarative from paddle.fluid.dygraph.dygraph_to_static import convert_call SEED = 2020 np.random.seed(SEED) +paddle.enable_static() def len_with_tensor(x): -- GitLab