未验证 提交 8e9eea7f 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2Static ] Release the condition when convert_len() return a variable (#44527)

* release the convert_len()

* fix bugs

* fix bugs

* fix start error.

* when shape==0, we should also call paddle.shape
上级 f1e3d795
...@@ -414,6 +414,8 @@ def convert_len(var): ...@@ -414,6 +414,8 @@ def convert_len(var):
# Note: Length of var may be known ahead of time in dygraph, # Note: Length of var may be known ahead of time in dygraph,
# but it probably represents batch size which can be variant. # but it probably represents batch size which can be variant.
# so we return a variable dynamically inferred from var.shape. # so we return a variable dynamically inferred from var.shape.
if var.shape[0] > 0 and var.type == core.VarDesc.VarType.LOD_TENSOR:
return var.shape[0]
return nn.shape(var)[0] return nn.shape(var)[0]
elif var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY: elif var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
return control_flow.array_length(var) return control_flow.array_length(var)
...@@ -447,7 +449,10 @@ class VariableTuple: ...@@ -447,7 +449,10 @@ class VariableTuple:
def __init__(self, var, start=0): def __init__(self, var, start=0):
self.var = var self.var = var
self.len = convert_len(var) self.len = convert_len(var)
if isinstance(self.len, Variable):
self.rag = paddle_range(start, start + self.len, 1, paddle.int64) self.rag = paddle_range(start, start + self.len, 1, paddle.int64)
else:
self.rag = range(start, start + self.len)
def __getitem__(self, idx): def __getitem__(self, idx):
return self.rag[idx], self.var[idx] return self.rag[idx], self.var[idx]
......
...@@ -17,12 +17,14 @@ from __future__ import print_function ...@@ -17,12 +17,14 @@ from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph import declarative from paddle.fluid.dygraph import declarative
from paddle.fluid.dygraph.dygraph_to_static import convert_call from paddle.fluid.dygraph.dygraph_to_static import convert_call
SEED = 2020 SEED = 2020
np.random.seed(SEED) np.random.seed(SEED)
paddle.enable_static()
def len_with_tensor(x): def len_with_tensor(x):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册