提交 bf447e60 编写于 作者: Z zhhsplendid

Change fluid.dygraph dy2stat APIs to paddle.jit, test=develop

上级 c04e2e85
...@@ -37,7 +37,7 @@ class AssertTransformer(gast.NodeTransformer): ...@@ -37,7 +37,7 @@ class AssertTransformer(gast.NodeTransformer):
def visit_Assert(self, node): def visit_Assert(self, node):
convert_assert_node = gast.parse( convert_assert_node = gast.parse(
'fluid.dygraph.dygraph_to_static.convert_operators.convert_assert({test}, {msg})'. 'paddle.jit.dygraph_to_static.convert_operators.convert_assert({test}, {msg})'.
format( format(
test=ast_to_source_code(node.test), test=ast_to_source_code(node.test),
msg=ast_to_source_code(node.msg) msg=ast_to_source_code(node.msg)
......
...@@ -39,7 +39,7 @@ class CastTransformer(gast.NodeTransformer): ...@@ -39,7 +39,7 @@ class CastTransformer(gast.NodeTransformer):
func_str = ast_to_source_code(node.func).strip() func_str = ast_to_source_code(node.func).strip()
if func_str in self._castable_type and len(node.args) > 0: if func_str in self._castable_type and len(node.args) > 0:
args_str = ast_to_source_code(node.args[0]).strip() args_str = ast_to_source_code(node.args[0]).strip()
new_func_str = "fluid.dygraph.dygraph_to_static.convert_operators.convert_var_dtype({}, '{}')".format( new_func_str = "paddle.jit.dygraph_to_static.convert_operators.convert_var_dtype({}, '{}')".format(
args_str, func_str) args_str, func_str)
new_node = gast.parse(new_func_str).body[0].value new_node = gast.parse(new_func_str).body[0].value
return new_node return new_node
......
...@@ -310,8 +310,8 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict, ...@@ -310,8 +310,8 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
After transformed, q and z are created in parent scope. For example, After transformed, q and z are created in parent scope. For example,
x, y = 5, 10 x, y = 5, 10
q = fluid.dygraph.dygraph_to_static.variable_trans_func.data_layer_not_check(name='q', shape=[-1], dtype='float32') q = paddle.jit.dygraph_to_static.variable_trans_func.data_layer_not_check(name='q', shape=[-1], dtype='float32')
z = fluid.dygraph.dygraph_to_static.variable_trans_func.data_layer_not_check(name='z', shape=[-1], dtype='float32') z = paddle.jit.dygraph_to_static.variable_trans_func.data_layer_not_check(name='z', shape=[-1], dtype='float32')
def true_func(x, y, q): def true_func(x, y, q):
x = x+1 x = x+1
......
...@@ -188,8 +188,8 @@ class ListTransformer(gast.NodeTransformer): ...@@ -188,8 +188,8 @@ class ListTransformer(gast.NodeTransformer):
pass pass
elif isinstance(slice_node, gast.Index): elif isinstance(slice_node, gast.Index):
value_code = ast_to_source_code(node.value) value_code = ast_to_source_code(node.value)
i = "fluid.layers.cast(" \ i = "paddle.cast(" \
"x=fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable({})," \ "x=paddle.jit.dygraph_to_static.variable_trans_func.to_static_variable({})," \
"dtype='int64')".format(ast_to_source_code(slice_node)) "dtype='int64')".format(ast_to_source_code(slice_node))
assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \ assign_code = "{} = fluid.layers.array_write(x={}, i={}, array={})" \
.format(target_name, value_code, i, target_name) .format(target_name, value_code, i, target_name)
......
...@@ -34,7 +34,7 @@ class LogicalTransformer(gast.NodeTransformer): ...@@ -34,7 +34,7 @@ class LogicalTransformer(gast.NodeTransformer):
self.generic_visit(node) self.generic_visit(node)
if isinstance(node.op, gast.Not): if isinstance(node.op, gast.Not):
arg = ast_to_source_code(node.operand) arg = ast_to_source_code(node.operand)
new_node_str = "fluid.dygraph.dygraph_to_static.convert_operators.convert_logical_not({})".format( new_node_str = "paddle.jit.dygraph_to_static.convert_operators.convert_logical_not({})".format(
arg) arg)
# NOTE: gast.parse returns Module(body=[expr(value=...)]) # NOTE: gast.parse returns Module(body=[expr(value=...)])
new_node = gast.parse(new_node_str).body[0].value new_node = gast.parse(new_node_str).body[0].value
......
...@@ -46,7 +46,7 @@ def create_while_node(condition_name, body_name, loop_var_names): ...@@ -46,7 +46,7 @@ def create_while_node(condition_name, body_name, loop_var_names):
# For example: loop_var_names = [a, b, foo.x], the type of `a` or `b` is gast.Name, # For example: loop_var_names = [a, b, foo.x], the type of `a` or `b` is gast.Name,
# but the type of `foo.x` gast.Attribute. # but the type of `foo.x` gast.Attribute.
while_func_name = "fluid.dygraph.dygraph_to_static.convert_operators.convert_while_loop" while_func_name = "paddle.jit.dygraph_to_static.convert_operators.convert_while_loop"
while_node_str = "[{}] = {}({}, {}, [{}])".format( while_node_str = "[{}] = {}({}, {}, [{}])".format(
",".join(loop_var_names), while_func_name, condition_name, body_name, ",".join(loop_var_names), while_func_name, condition_name, body_name,
",".join(loop_var_names)) ",".join(loop_var_names))
......
...@@ -57,6 +57,6 @@ class PrintTransformer(gast.NodeTransformer): ...@@ -57,6 +57,6 @@ class PrintTransformer(gast.NodeTransformer):
def _create_print_node(self, print_args): def _create_print_node(self, print_args):
convert_print_func = gast.parse( convert_print_func = gast.parse(
'fluid.dygraph.dygraph_to_static.convert_operators.convert_print' 'paddle.jit.dygraph_to_static.convert_operators.convert_print'
).body[0].value ).body[0].value
return gast.Call(func=convert_print_func, args=print_args, keywords=[]) return gast.Call(func=convert_print_func, args=print_args, keywords=[])
...@@ -26,7 +26,7 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysi ...@@ -26,7 +26,7 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysi
def create_convert_shape_node(var_shape_node): def create_convert_shape_node(var_shape_node):
assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript)) assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript))
convert_var_shape_func = "fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape" convert_var_shape_func = "paddle.jit.dygraph_to_static.convert_operators.convert_var_shape"
if isinstance(var_shape_node, gast.Attribute): if isinstance(var_shape_node, gast.Attribute):
api_shape_node = gast.Call( api_shape_node = gast.Call(
......
...@@ -922,7 +922,7 @@ class ForNodeVisitor(object): ...@@ -922,7 +922,7 @@ class ForNodeVisitor(object):
else: else:
iter_var_name = ast_to_source_code(self.iter_node).strip() iter_var_name = ast_to_source_code(self.iter_node).strip()
convert_len_node_source_str = '{} = fluid.dygraph.dygraph_to_static.convert_operators.convert_len({})'.format( convert_len_node_source_str = '{} = paddle.jit.dygraph_to_static.convert_operators.convert_len({})'.format(
self.iter_var_len_name, iter_var_name) self.iter_var_len_name, iter_var_name)
convert_len_node = gast.parse(convert_len_node_source_str).body[0] convert_len_node = gast.parse(convert_len_node_source_str).body[0]
......
...@@ -74,20 +74,20 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): ...@@ -74,20 +74,20 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
def to_static_variable_gast_node(name): def to_static_variable_gast_node(name):
func_code = "{} = fluid.dygraph.dygraph_to_static.variable_trans_func\ func_code = "{} = paddle.jit.dygraph_to_static.variable_trans_func\
.to_static_variable({})".format(name, name) .to_static_variable({})".format(name, name)
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
def create_static_variable_gast_node(name): def create_static_variable_gast_node(name):
func_code = "{} = fluid.dygraph.dygraph_to_static.variable_trans_func\ func_code = "{} = paddle.jit.dygraph_to_static.variable_trans_func\
.data_layer_not_check(name='{}', shape=[-1], dtype='float32')".format( .data_layer_not_check(name='{}', shape=[-1], dtype='float32')".format(
name, name) name, name)
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
def create_fill_constant_node(name, value): def create_fill_constant_node(name, value):
func_code = "{} = fluid.layers.fill_constant(shape=[1], ".format(name) func_code = "{} = paddle.fill_constant(shape=[1], ".format(name)
if isinstance(value, bool): if isinstance(value, bool):
func_code += "dtype='bool', value={})".format(value) func_code += "dtype='bool', value={})".format(value)
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
......
...@@ -59,9 +59,9 @@ def dyfunc_with_if_else3(x): ...@@ -59,9 +59,9 @@ def dyfunc_with_if_else3(x):
# The var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node. # The var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node.
# The transformed code: # The transformed code:
""" """
q = fluid.dygraph.dygraph_to_static.variable_trans_func. q = paddle.jit.dygraph_to_static.variable_trans_func.
data_layer_not_check(name='q', shape=[-1], dtype='float32') data_layer_not_check(name='q', shape=[-1], dtype='float32')
z = fluid.dygraph.dygraph_to_static.variable_trans_func. z = paddle.jit.dygraph_to_static.variable_trans_func.
data_layer_not_check(name='z', shape=[-1], dtype='float32') data_layer_not_check(name='z', shape=[-1], dtype='float32')
def true_fn_0(q, x, y): def true_fn_0(q, x, y):
......
...@@ -52,7 +52,7 @@ def dyfunc_tensor_shape_4(x): ...@@ -52,7 +52,7 @@ def dyfunc_tensor_shape_4(x):
def dyfunc_tensor_shape_5(x): def dyfunc_tensor_shape_5(x):
# `res = fluid.layers.reshape(x, shape=(-1, s))` to # `res = fluid.layers.reshape(x, shape=(-1, s))` to
# `res = fluid.layers.reshape(x, shape=(-1, # `res = fluid.layers.reshape(x, shape=(-1,
# fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]))` # paddle.jit.dygraph_to_static.convert_operators.convert_var_shape(x)[0]))`
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
s = x.shape[0] s = x.shape[0]
res = fluid.layers.reshape(x, shape=(-1, s)) res = fluid.layers.reshape(x, shape=(-1, s))
...@@ -65,7 +65,7 @@ def dyfunc_with_if_1(x): ...@@ -65,7 +65,7 @@ def dyfunc_with_if_1(x):
x_shape_0 = x.shape[0] x_shape_0 = x.shape[0]
if x_shape_0 < 1: if x_shape_0 < 1:
# `res.shape[0]` is transformed into # `res.shape[0]` is transformed into
# `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(res)[0]` # `paddle.jit.dygraph_to_static.convert_operators.convert_var_shape(res)[0]`
if res.shape[0] > 1: if res.shape[0] > 1:
res = fluid.layers.fill_constant( res = fluid.layers.fill_constant(
value=2, shape=x.shape, dtype="int32") value=2, shape=x.shape, dtype="int32")
...@@ -89,7 +89,7 @@ def dyfunc_with_if_2(x): ...@@ -89,7 +89,7 @@ def dyfunc_with_if_2(x):
def dyfunc_with_for_1(x): def dyfunc_with_for_1(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
# `x.shape[0]` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]` # `x.shape[0]` is transformed into `paddle.jit.dygraph_to_static.convert_operators.convert_var_shape(x)[0]`
for i in range(x.shape[0]): for i in range(x.shape[0]):
res += 1 res += 1
return res return res
...@@ -100,7 +100,7 @@ def dyfunc_with_for_2(x): ...@@ -100,7 +100,7 @@ def dyfunc_with_for_2(x):
x_shape_0 = x.shape[0] x_shape_0 = x.shape[0]
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
# `x_shape_0` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]` # `x_shape_0` is transformed into `paddle.jit.dygraph_to_static.convert_operators.convert_var_shape(x)[0]`
for i in range(x_shape_0): for i in range(x_shape_0):
res += 1 res += 1
return res return res
...@@ -124,7 +124,7 @@ def dyfunc_with_for_3(x): ...@@ -124,7 +124,7 @@ def dyfunc_with_for_3(x):
def dyfunc_with_while_1(x): def dyfunc_with_while_1(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
# `x.shape[0]` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]` # `x.shape[0]` is transformed into `paddle.jit.dygraph_to_static.convert_operators.convert_var_shape(x)[0]`
i = 1 i = 1
while i < x.shape[0]: while i < x.shape[0]:
res += 1 res += 1
...@@ -137,7 +137,7 @@ def dyfunc_with_while_2(x): ...@@ -137,7 +137,7 @@ def dyfunc_with_while_2(x):
x_shape_0 = x.shape[0] x_shape_0 = x.shape[0]
res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32")
i = 1 i = 1
# `x_shape_0` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]` # `x_shape_0` is transformed into `paddle.jit.dygraph_to_static.convert_operators.convert_var_shape(x)[0]`
while i < x_shape_0: while i < x_shape_0:
res += 1 res += 1
i = i + 2 i = i + 2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册