未验证 提交 88504892 编写于 作者: W WangZhen 提交者: GitHub

[Dy2St]Refine covnert_var_shape in dy2st (#43348)

* Refine covnert_var_shape in dy2st

* Fix UT error

* Return paddle.shape only when has_negative and fix test_tensor_shape error

* Fix test_jit_save_load error
上级 abc5d0c4
...@@ -313,8 +313,7 @@ def convert_var_shape(x, idx=None, in_control_flow=False): ...@@ -313,8 +313,7 @@ def convert_var_shape(x, idx=None, in_control_flow=False):
# # Assume x.shape=[3, -1] in static mode # # Assume x.shape=[3, -1] in static mode
# y = paddle.reshape(x, shape=[1, x.shape[1]]) # y = paddle.reshape(x, shape=[1, x.shape[1]])
# ``` # ```
if isinstance(x, Variable) and (in_control_flow if isinstance(x, Variable) and has_negative(x.shape, idx):
or has_negative(x.shape, idx)):
return nn.shape(x) if idx is None else nn.shape(x)[idx] return nn.shape(x) if idx is None else nn.shape(x)[idx]
else: else:
return list(x.shape) if idx is None else x.shape[idx] return list(x.shape) if idx is None else x.shape[idx]
......
...@@ -409,9 +409,9 @@ class TestTensorShapeInIf2(TestTensorShapeBasic): ...@@ -409,9 +409,9 @@ class TestTensorShapeInIf2(TestTensorShapeBasic):
self.dygraph_func = dyfunc_with_if_2 self.dygraph_func = dyfunc_with_if_2
def _set_expected_op_num(self): def _set_expected_op_num(self):
self.expected_op_num = 14 self.expected_op_num = 2
self.expected_shape_op_num = 2 self.expected_shape_op_num = 0
self.expected_slice_op_num = 1 self.expected_slice_op_num = 0
# 3. Tests with control flow for loop # 3. Tests with control flow for loop
...@@ -421,9 +421,9 @@ class TestTensorShapeInFor1(TestTensorShapeBasic): ...@@ -421,9 +421,9 @@ class TestTensorShapeInFor1(TestTensorShapeBasic):
self.dygraph_func = dyfunc_with_for_1 self.dygraph_func = dyfunc_with_for_1
def _set_expected_op_num(self): def _set_expected_op_num(self):
self.expected_op_num = 22 self.expected_op_num = 7
self.expected_shape_op_num = 3 self.expected_shape_op_num = 0
self.expected_slice_op_num = 3 self.expected_slice_op_num = 0
class TestTensorShapeInFor2(TestTensorShapeInFor1): class TestTensorShapeInFor2(TestTensorShapeInFor1):
...@@ -443,9 +443,9 @@ class TestTensorShapeInFor3(TestTensorShapeInFor1): ...@@ -443,9 +443,9 @@ class TestTensorShapeInFor3(TestTensorShapeInFor1):
self.dygraph_func = dyfunc_with_for_3 self.dygraph_func = dyfunc_with_for_3
def _set_expected_op_num(self): def _set_expected_op_num(self):
self.expected_op_num = 25 self.expected_op_num = 3
self.expected_shape_op_num = 6 self.expected_shape_op_num = 0
self.expected_slice_op_num = 3 self.expected_slice_op_num = 0
# 4. Tests with control flow while loop # 4. Tests with control flow while loop
...@@ -454,6 +454,11 @@ class TestTensorShapeInWhile1(TestTensorShapeInFor1): ...@@ -454,6 +454,11 @@ class TestTensorShapeInWhile1(TestTensorShapeInFor1):
def init_test_func(self): def init_test_func(self):
self.dygraph_func = dyfunc_with_while_1 self.dygraph_func = dyfunc_with_while_1
def _set_expected_op_num(self):
self.expected_op_num = 4
self.expected_shape_op_num = 0
self.expected_slice_op_num = 0
class TestTensorShapeInWhile2(TestTensorShapeInFor1): class TestTensorShapeInWhile2(TestTensorShapeInFor1):
......
...@@ -1070,7 +1070,7 @@ class LayerSaved(paddle.nn.Layer): ...@@ -1070,7 +1070,7 @@ class LayerSaved(paddle.nn.Layer):
def forward(self, x): def forward(self, x):
y = self._linear_0(x) y = self._linear_0(x)
# Multiple blocks # Multiple blocks
if x.shape[0] == 1: if paddle.shape(x)[0] == 1:
y = self._linear_1_0(y) y = self._linear_1_0(y)
else: else:
y += self._linear_1_1(y + self._scale) y += self._linear_1_1(y + self._scale)
...@@ -1097,7 +1097,7 @@ class LayerLoadFinetune(paddle.nn.Layer): ...@@ -1097,7 +1097,7 @@ class LayerLoadFinetune(paddle.nn.Layer):
y = self._linear_0(x) y = self._linear_0(x)
y = self._load_l1(y) y = self._load_l1(y)
# Multiple blocks # Multiple blocks
if x.shape[0] == 1: if paddle.shape(x)[0] == 1:
y = self._linear_1_0(y) y = self._linear_1_0(y)
y = self._load_l1(y) y = self._load_l1(y)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册