提交 c2a5f5a8 编写于 作者: Z zhhsplendid

Temp commit

上级 5ddc395b
...@@ -37,8 +37,7 @@ class AssertTransformer(gast.NodeTransformer): ...@@ -37,8 +37,7 @@ class AssertTransformer(gast.NodeTransformer):
def visit_Assert(self, node): def visit_Assert(self, node):
convert_assert_node = gast.parse( convert_assert_node = gast.parse(
'paddle.jit.dygraph_to_static.convert_operators.convert_assert({test}, {msg})'. 'paddle.jit.dy2static.convert_assert({test}, {msg})'.format(
format(
test=ast_to_source_code(node.test), test=ast_to_source_code(node.test),
msg=ast_to_source_code(node.msg) msg=ast_to_source_code(node.msg)
if node.msg else "")).body[0].value if node.msg else "")).body[0].value
......
...@@ -70,8 +70,7 @@ class CallTransformer(gast.NodeTransformer): ...@@ -70,8 +70,7 @@ class CallTransformer(gast.NodeTransformer):
if PDB_SET in func_str: if PDB_SET in func_str:
return node return node
new_func_str = "paddle.jit.dygraph_to_static.convert_call({})".format( new_func_str = "paddle.jit.dy2static.convert_call({})".format(func_str)
func_str)
new_func_ast = gast.parse(new_func_str).body[0].value new_func_ast = gast.parse(new_func_str).body[0].value
node.func = new_func_ast node.func = new_func_ast
......
...@@ -39,7 +39,7 @@ class CastTransformer(gast.NodeTransformer): ...@@ -39,7 +39,7 @@ class CastTransformer(gast.NodeTransformer):
func_str = ast_to_source_code(node.func).strip() func_str = ast_to_source_code(node.func).strip()
if func_str in self._castable_type and len(node.args) > 0: if func_str in self._castable_type and len(node.args) > 0:
args_str = ast_to_source_code(node.args[0]).strip() args_str = ast_to_source_code(node.args[0]).strip()
new_func_str = "paddle.jit.dygraph_to_static.convert_operators.convert_var_dtype({}, '{}')".format( new_func_str = "paddle.jit.dy2static.convert_var_dtype({}, '{}')".format(
args_str, func_str) args_str, func_str)
new_node = gast.parse(new_func_str).body[0].value new_node = gast.parse(new_func_str).body[0].value
return new_node return new_node
......
...@@ -310,8 +310,8 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict, ...@@ -310,8 +310,8 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
After transformed, q and z are created in parent scope. For example, After transformed, q and z are created in parent scope. For example,
x, y = 5, 10 x, y = 5, 10
q = paddle.jit.dygraph_to_static.variable_trans_func.data_layer_not_check(name='q', shape=[-1], dtype='float32') q = paddle.jit.dy2static.data_layer_not_check(name='q', shape=[-1], dtype='float32')
z = paddle.jit.dygraph_to_static.variable_trans_func.data_layer_not_check(name='z', shape=[-1], dtype='float32') z = paddle.jit.dy2static.data_layer_not_check(name='z', shape=[-1], dtype='float32')
def true_func(x, y, q): def true_func(x, y, q):
x = x+1 x = x+1
...@@ -460,7 +460,7 @@ def create_convert_ifelse_node(return_name_ids, ...@@ -460,7 +460,7 @@ def create_convert_ifelse_node(return_name_ids,
false_func, false_func,
is_if_expr=False): is_if_expr=False):
""" """
Create `paddle.jit.dygraph_to_static.convert_operators.convert_ifelse( Create `paddle.jit.dy2static.convert_ifelse(
pred, true_fn, false_fn, true_args, false_args, return_vars)` pred, true_fn, false_fn, true_args, false_args, return_vars)`
to replace original `python if/else` statement. to replace original `python if/else` statement.
""" """
...@@ -491,7 +491,7 @@ def create_convert_ifelse_node(return_name_ids, ...@@ -491,7 +491,7 @@ def create_convert_ifelse_node(return_name_ids,
return_vars = create_name_nodes(return_name_ids) return_vars = create_name_nodes(return_name_ids)
convert_ifelse_layer = gast.parse( convert_ifelse_layer = gast.parse(
'paddle.jit.dygraph_to_static.convert_operators.convert_ifelse(' 'paddle.jit.dy2static.convert_ifelse('
'{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})'. '{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})'.
format( format(
pred=ast_to_source_code(pred), pred=ast_to_source_code(pred),
......
...@@ -189,7 +189,7 @@ class ListTransformer(gast.NodeTransformer): ...@@ -189,7 +189,7 @@ class ListTransformer(gast.NodeTransformer):
elif isinstance(slice_node, gast.Index): elif isinstance(slice_node, gast.Index):
value_code = ast_to_source_code(node.value) value_code = ast_to_source_code(node.value)
i = "paddle.cast(" \ i = "paddle.cast(" \
"x=paddle.jit.dygraph_to_static.variable_trans_func.to_static_variable({})," \ "x=paddle.jit.dy2static.to_static_variable({})," \
"dtype='int64')".format(ast_to_source_code(slice_node)) "dtype='int64')".format(ast_to_source_code(slice_node))
assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \ assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \
.format(target_name, value_code, i, target_name) .format(target_name, value_code, i, target_name)
......
...@@ -34,7 +34,7 @@ class LogicalTransformer(gast.NodeTransformer): ...@@ -34,7 +34,7 @@ class LogicalTransformer(gast.NodeTransformer):
self.generic_visit(node) self.generic_visit(node)
if isinstance(node.op, gast.Not): if isinstance(node.op, gast.Not):
arg = ast_to_source_code(node.operand) arg = ast_to_source_code(node.operand)
new_node_str = "paddle.jit.dygraph_to_static.convert_operators.convert_logical_not({})".format( new_node_str = "paddle.jit.dy2static.convert_logical_not({})".format(
arg) arg)
# NOTE: gast.parse returns Module(body=[expr(value=...)]) # NOTE: gast.parse returns Module(body=[expr(value=...)])
new_node = gast.parse(new_node_str).body[0].value new_node = gast.parse(new_node_str).body[0].value
...@@ -67,7 +67,7 @@ class LogicalTransformer(gast.NodeTransformer): ...@@ -67,7 +67,7 @@ class LogicalTransformer(gast.NodeTransformer):
nodes = [pre_logic_node] + [post_logic_node] nodes = [pre_logic_node] + [post_logic_node]
args = [ast_to_source_code(child) for child in nodes] args = [ast_to_source_code(child) for child in nodes]
new_node_str = "paddle.jit.dygraph_to_static.convert_operators.convert_logical_{}(x={}, y={})".format( new_node_str = "paddle.jit.dy2static.convert_logical_{}(x={}, y={})".format(
api_type, args[0], args[1]) api_type, args[0], args[1])
# NOTE: gast.parse return Module(body=[expr(...)]) # NOTE: gast.parse return Module(body=[expr(...)])
new_node = gast.parse(new_node_str).body[0].value new_node = gast.parse(new_node_str).body[0].value
......
...@@ -46,7 +46,7 @@ def create_while_node(condition_name, body_name, loop_var_names): ...@@ -46,7 +46,7 @@ def create_while_node(condition_name, body_name, loop_var_names):
# For example: loop_var_names = [a, b, foo.x], the type of `a` or `b` is gast.Name, # For example: loop_var_names = [a, b, foo.x], the type of `a` or `b` is gast.Name,
# but the type of `foo.x` gast.Attribute. # but the type of `foo.x` gast.Attribute.
while_func_name = "paddle.jit.dygraph_to_static.convert_operators.convert_while_loop" while_func_name = "paddle.jit.dy2static.convert_while_loop"
while_node_str = "[{}] = {}({}, {}, [{}])".format( while_node_str = "[{}] = {}({}, {}, [{}])".format(
",".join(loop_var_names), while_func_name, condition_name, body_name, ",".join(loop_var_names), while_func_name, condition_name, body_name,
",".join(loop_var_names)) ",".join(loop_var_names))
......
...@@ -51,6 +51,5 @@ class PrintTransformer(gast.NodeTransformer): ...@@ -51,6 +51,5 @@ class PrintTransformer(gast.NodeTransformer):
def _create_print_node(self, print_args): def _create_print_node(self, print_args):
convert_print_func = gast.parse( convert_print_func = gast.parse(
'paddle.jit.dygraph_to_static.convert_operators.convert_print' 'paddle.jit.dy2static.convert_print').body[0].value
).body[0].value
return gast.Call(func=convert_print_func, args=print_args, keywords=[]) return gast.Call(func=convert_print_func, args=print_args, keywords=[])
...@@ -26,7 +26,7 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysi ...@@ -26,7 +26,7 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysi
def create_convert_shape_node(var_shape_node): def create_convert_shape_node(var_shape_node):
assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript)) assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript))
convert_var_shape_func = "paddle.jit.dygraph_to_static.convert_operators.convert_var_shape" convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape"
if isinstance(var_shape_node, gast.Attribute): if isinstance(var_shape_node, gast.Attribute):
api_shape_node = gast.Call( api_shape_node = gast.Call(
......
...@@ -922,7 +922,7 @@ class ForNodeVisitor(object): ...@@ -922,7 +922,7 @@ class ForNodeVisitor(object):
else: else:
iter_var_name = ast_to_source_code(self.iter_node).strip() iter_var_name = ast_to_source_code(self.iter_node).strip()
convert_len_node_source_str = '{} = paddle.jit.dygraph_to_static.convert_operators.convert_len({})'.format( convert_len_node_source_str = '{} = paddle.jit.dy2static.convert_len({})'.format(
self.iter_var_len_name, iter_var_name) self.iter_var_len_name, iter_var_name)
convert_len_node = gast.parse(convert_len_node_source_str).body[0] convert_len_node = gast.parse(convert_len_node_source_str).body[0]
......
...@@ -74,13 +74,13 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): ...@@ -74,13 +74,13 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
def to_static_variable_gast_node(name): def to_static_variable_gast_node(name):
func_code = "{} = paddle.jit.dygraph_to_static.variable_trans_func\ func_code = "{} = paddle.jit.dy2static.to_static_variable({})".format(name,
.to_static_variable({})".format(name, name) name)
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
def create_static_variable_gast_node(name): def create_static_variable_gast_node(name):
func_code = "{} = paddle.jit.dygraph_to_static.variable_trans_func\ func_code = "{} = paddle.jit.dy2static\
.data_layer_not_check(name='{}', shape=[-1], dtype='float32')".format( .data_layer_not_check(name='{}', shape=[-1], dtype='float32')".format(
name, name) name, name)
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
......
...@@ -59,9 +59,9 @@ def dyfunc_with_if_else3(x): ...@@ -59,9 +59,9 @@ def dyfunc_with_if_else3(x):
# The var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node. # The var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node.
# The transformed code: # The transformed code:
""" """
q = paddle.jit.dygraph_to_static.variable_trans_func. q = paddle.jit.dy2static.
data_layer_not_check(name='q', shape=[-1], dtype='float32') data_layer_not_check(name='q', shape=[-1], dtype='float32')
z = paddle.jit.dygraph_to_static.variable_trans_func. z = paddle.jit.dy2static.
data_layer_not_check(name='z', shape=[-1], dtype='float32') data_layer_not_check(name='z', shape=[-1], dtype='float32')
def true_fn_0(q, x, y): def true_fn_0(q, x, y):
...@@ -77,8 +77,8 @@ def dyfunc_with_if_else3(x): ...@@ -77,8 +77,8 @@ def dyfunc_with_if_else3(x):
n = x + 3 n = x + 3
return q, x, y, z return q, x, y, z
q, x, y, z = fluid.layers.cond(fluid.layers.mean(x)[0] < 5, lambda : q, x, y, z = fluid.layers.cond(fluid.layers.mean(x)[0] < 5, lambda :
paddle.jit.dygraph_to_static.convert_call(true_fn_0)(q, x, y), paddle.jit.dy2static.convert_call(true_fn_0)(q, x, y),
lambda : paddle.jit.dygraph_to_static.convert_call(false_fn_0)(q, lambda : paddle.jit.dy2static.convert_call(false_fn_0)(q,
x, y)) x, y))
""" """
y = x + 1 y = x + 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册