未验证 提交 88504892 编写于 作者: W WangZhen 提交者: GitHub

[Dy2St]Refine covnert_var_shape in dy2st (#43348)

* Refine covnert_var_shape in dy2st

* Fix UT error

* Return paddle.shape only when has_negative and fix test_tensor_shape error

* Fix test_jit_save_load error
上级 abc5d0c4
......@@ -313,8 +313,7 @@ def convert_var_shape(x, idx=None, in_control_flow=False):
# # Assume x.shape=[3, -1] in static mode
# y = paddle.reshape(x, shape=[1, x.shape[1]])
# ```
if isinstance(x, Variable) and (in_control_flow
or has_negative(x.shape, idx)):
if isinstance(x, Variable) and has_negative(x.shape, idx):
return nn.shape(x) if idx is None else nn.shape(x)[idx]
else:
return list(x.shape) if idx is None else x.shape[idx]
......
......@@ -409,9 +409,9 @@ class TestTensorShapeInIf2(TestTensorShapeBasic):
self.dygraph_func = dyfunc_with_if_2
def _set_expected_op_num(self):
self.expected_op_num = 14
self.expected_shape_op_num = 2
self.expected_slice_op_num = 1
self.expected_op_num = 2
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
# 3. Tests with control flow for loop
......@@ -421,9 +421,9 @@ class TestTensorShapeInFor1(TestTensorShapeBasic):
self.dygraph_func = dyfunc_with_for_1
def _set_expected_op_num(self):
self.expected_op_num = 22
self.expected_shape_op_num = 3
self.expected_slice_op_num = 3
self.expected_op_num = 7
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
class TestTensorShapeInFor2(TestTensorShapeInFor1):
......@@ -443,9 +443,9 @@ class TestTensorShapeInFor3(TestTensorShapeInFor1):
self.dygraph_func = dyfunc_with_for_3
def _set_expected_op_num(self):
self.expected_op_num = 25
self.expected_shape_op_num = 6
self.expected_slice_op_num = 3
self.expected_op_num = 3
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
# 4. Tests with control flow while loop
......@@ -454,6 +454,11 @@ class TestTensorShapeInWhile1(TestTensorShapeInFor1):
def init_test_func(self):
self.dygraph_func = dyfunc_with_while_1
def _set_expected_op_num(self):
self.expected_op_num = 4
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
class TestTensorShapeInWhile2(TestTensorShapeInFor1):
......
......@@ -1070,7 +1070,7 @@ class LayerSaved(paddle.nn.Layer):
def forward(self, x):
y = self._linear_0(x)
# Multiple blocks
if x.shape[0] == 1:
if paddle.shape(x)[0] == 1:
y = self._linear_1_0(y)
else:
y += self._linear_1_1(y + self._scale)
......@@ -1097,7 +1097,7 @@ class LayerLoadFinetune(paddle.nn.Layer):
y = self._linear_0(x)
y = self._load_l1(y)
# Multiple blocks
if x.shape[0] == 1:
if paddle.shape(x)[0] == 1:
y = self._linear_1_0(y)
y = self._load_l1(y)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册