未验证 提交 8e9eea7f 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2Static ] Release the condition when convert_len() return a variable (#44527)

* release the convert_len()

* fix bugs

* fix bugs

* fix start error.

* when shape==0, we should also call paddle.shape
上级 f1e3d795
......@@ -414,6 +414,8 @@ def convert_len(var):
# Note: Length of var may be known ahead of time in dygraph,
# but it probably represents batch size which can be variant.
# so we return a variable dynamically inferred from var.shape.
if var.shape[0] > 0 and var.type == core.VarDesc.VarType.LOD_TENSOR:
return var.shape[0]
return nn.shape(var)[0]
elif var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
return control_flow.array_length(var)
......@@ -447,7 +449,10 @@ class VariableTuple:
def __init__(self, var, start=0):
self.var = var
self.len = convert_len(var)
if isinstance(self.len, Variable):
self.rag = paddle_range(start, start + self.len, 1, paddle.int64)
else:
self.rag = range(start, start + self.len)
def __getitem__(self, idx):
return self.rag[idx], self.var[idx]
......
......@@ -17,12 +17,14 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import declarative
from paddle.fluid.dygraph.dygraph_to_static import convert_call
SEED = 2020
np.random.seed(SEED)
paddle.enable_static()
def len_with_tensor(x):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册