未验证 提交 31ed9a5e 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat] Use Paddle2.0 api paddle.tensor.array_* (#30156)

上级 ad55f609
...@@ -126,7 +126,7 @@ class ListTransformer(gast.NodeTransformer): ...@@ -126,7 +126,7 @@ class ListTransformer(gast.NodeTransformer):
i = "paddle.cast(" \ i = "paddle.cast(" \
"x=paddle.jit.dy2static.to_static_variable({})," \ "x=paddle.jit.dy2static.to_static_variable({})," \
"dtype='int64')".format(ast_to_source_code(slice_node)) "dtype='int64')".format(ast_to_source_code(slice_node))
assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \ assign_code = "{} = paddle.tensor.array_write(x={}, i={}, array={})" \
.format(target_name, value_code, i, target_name) .format(target_name, value_code, i, target_name)
assign_node = gast.parse(assign_code).body[0] assign_node = gast.parse(assign_code).body[0]
return assign_node return assign_node
...@@ -168,7 +168,7 @@ class ListTransformer(gast.NodeTransformer): ...@@ -168,7 +168,7 @@ class ListTransformer(gast.NodeTransformer):
# return False # return False
# if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set: # if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
# return False # return False
# # TODO: Consider that `arg` may be a gast.Call about Paddle Api. eg: list_a.append(fluid.layers.reshape(x)) # # TODO: Consider that `arg` may be a gast.Call about Paddle Api. eg: list_a.append(paddle.reshape(x))
# # else: # # else:
# # return True # # return True
self.list_name_to_updated[value_name.strip()] = True self.list_name_to_updated[value_name.strip()] = True
...@@ -187,7 +187,7 @@ class ListTransformer(gast.NodeTransformer): ...@@ -187,7 +187,7 @@ class ListTransformer(gast.NodeTransformer):
def _create_tensor_array(self): def _create_tensor_array(self):
# Although `dtype='float32'`, other types such as `int32` can also be supported # Although `dtype='float32'`, other types such as `int32` can also be supported
func_code = "fluid.layers.create_array(dtype='float32')" func_code = "paddle.tensor.create_array(dtype='float32')"
func_node = gast.parse(func_code).body[0].value func_node = gast.parse(func_code).body[0].value
return func_node return func_node
...@@ -195,8 +195,8 @@ class ListTransformer(gast.NodeTransformer): ...@@ -195,8 +195,8 @@ class ListTransformer(gast.NodeTransformer):
assert isinstance(node, gast.Call) assert isinstance(node, gast.Call)
array = astor.to_source(gast.gast_to_ast(node.func.value)) array = astor.to_source(gast.gast_to_ast(node.func.value))
x = astor.to_source(gast.gast_to_ast(node.args[0])) x = astor.to_source(gast.gast_to_ast(node.args[0]))
i = "fluid.layers.array_length({})".format(array) i = "paddle.tensor.array_length({})".format(array)
func_code = "fluid.layers.array_write(x={}, i={}, array={})".format( func_code = "paddle.tensor.array_write(x={}, i={}, array={})".format(
x, i, array) x, i, array)
return gast.parse(func_code).body[0].value return gast.parse(func_code).body[0].value
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册