未验证 提交 6cb24967 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Unify all API name in_jst import path to improve readablity (#43868)

* [Dy2Stat]Polish all API name of _jst
上级 13451615
...@@ -36,10 +36,8 @@ class AssertTransformer(gast.NodeTransformer): ...@@ -36,10 +36,8 @@ class AssertTransformer(gast.NodeTransformer):
self.visit(self.root) self.visit(self.root)
def visit_Assert(self, node): def visit_Assert(self, node):
convert_assert_node = gast.parse( convert_assert_node = gast.parse('_jst.Assert({test}, {msg})'.format(
'_jst.convert_assert({test}, {msg})'.format(
test=ast_to_source_code(node.test), test=ast_to_source_code(node.test),
msg=ast_to_source_code(node.msg) msg=ast_to_source_code(node.msg) if node.msg else "")).body[0].value
if node.msg else "")).body[0].value
return gast.Expr(value=convert_assert_node) return gast.Expr(value=convert_assert_node)
...@@ -71,7 +71,7 @@ class CallTransformer(gast.NodeTransformer): ...@@ -71,7 +71,7 @@ class CallTransformer(gast.NodeTransformer):
if PDB_SET in func_str: if PDB_SET in func_str:
return node return node
new_func_str = "_jst.convert_call({})".format(func_str) new_func_str = "_jst.Call({})".format(func_str)
new_func_ast = gast.parse(new_func_str).body[0].value new_func_ast = gast.parse(new_func_str).body[0].value
node.func = new_func_ast node.func = new_func_ast
......
...@@ -39,8 +39,7 @@ class CastTransformer(gast.NodeTransformer): ...@@ -39,8 +39,7 @@ class CastTransformer(gast.NodeTransformer):
func_str = ast_to_source_code(node.func).strip() func_str = ast_to_source_code(node.func).strip()
if func_str in self._castable_type and len(node.args) > 0: if func_str in self._castable_type and len(node.args) > 0:
args_str = ast_to_source_code(node.args[0]).strip() args_str = ast_to_source_code(node.args[0]).strip()
new_func_str = "_jst.convert_var_dtype({}, '{}')".format( new_func_str = "_jst.AsDtype({}, '{}')".format(args_str, func_str)
args_str, func_str)
new_node = gast.parse(new_func_str).body[0].value new_node = gast.parse(new_func_str).body[0].value
return new_node return new_node
......
...@@ -361,8 +361,8 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict, ...@@ -361,8 +361,8 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
After transformed, q and z are created in parent scope. For example, After transformed, q and z are created in parent scope. For example,
x, y = 5, 10 x, y = 5, 10
q = paddle.jit.dy2static.data_layer_not_check(name='q', shape=[-1], dtype='float32') q = paddle.jit.dy2static.UndefindVar('q')
z = paddle.jit.dy2static.data_layer_not_check(name='z', shape=[-1], dtype='float32') z = paddle.jit.dy2static.UndefindVar('z')
def true_func(x, y, q): def true_func(x, y, q):
x = x+1 x = x+1
...@@ -647,7 +647,7 @@ def create_convert_ifelse_node(return_name_ids, ...@@ -647,7 +647,7 @@ def create_convert_ifelse_node(return_name_ids,
false_func_source = false_func.name false_func_source = false_func.name
convert_ifelse_layer = gast.parse( convert_ifelse_layer = gast.parse(
'_jst.convert_ifelse(' '_jst.IfElse('
'{pred}, {true_fn}, {false_fn}, {get_args}, {set_args}, {return_name_ids})' '{pred}, {true_fn}, {false_fn}, {get_args}, {set_args}, {return_name_ids})'
.format( .format(
pred=ast_to_source_code(pred), pred=ast_to_source_code(pred),
......
...@@ -252,7 +252,7 @@ class ListTransformer(gast.NodeTransformer): ...@@ -252,7 +252,7 @@ class ListTransformer(gast.NodeTransformer):
# 2. pop stmt for a list or dict if len(args_str) == 1 # 2. pop stmt for a list or dict if len(args_str) == 1
# 3. pop stmt for a dict if len(args_str) == 2 # 3. pop stmt for a dict if len(args_str) == 2
if len(args_str) <= 2: if len(args_str) <= 2:
new_pop_str = "_jst.convert_pop({}, {})"\ new_pop_str = "_jst.Pop({}, {})"\
.format(target_str, ",".join(args_str)) .format(target_str, ",".join(args_str))
new_pop_node = gast.parse(new_pop_str).body[0].value new_pop_node = gast.parse(new_pop_str).body[0].value
return new_pop_node return new_pop_node
......
...@@ -43,7 +43,7 @@ class LogicalTransformer(gast.NodeTransformer): ...@@ -43,7 +43,7 @@ class LogicalTransformer(gast.NodeTransformer):
a = x > 1 and y < 1 a = x > 1 and y < 1
Transformed code: Transformed code:
a = paddle.jit.dy2static.convert_logical_and(lambda:x>1, lambda:y<1) a = _jst.And(lambda:x>1, lambda:y<1)
""" """
def __init__(self, wrapper_root): def __init__(self, wrapper_root):
...@@ -57,7 +57,7 @@ class LogicalTransformer(gast.NodeTransformer): ...@@ -57,7 +57,7 @@ class LogicalTransformer(gast.NodeTransformer):
self.generic_visit(node) self.generic_visit(node)
if isinstance(node.op, gast.Not): if isinstance(node.op, gast.Not):
arg = ast_to_source_code(node.operand) arg = ast_to_source_code(node.operand)
new_node_str = "_jst.convert_logical_not({})".format(arg) new_node_str = "_jst.Not({})".format(arg)
# NOTE: gast.parse returns Module(body=[expr(value=...)]) # NOTE: gast.parse returns Module(body=[expr(value=...)])
new_node = gast.parse(new_node_str).body[0].value new_node = gast.parse(new_node_str).body[0].value
return new_node return new_node
...@@ -66,9 +66,9 @@ class LogicalTransformer(gast.NodeTransformer): ...@@ -66,9 +66,9 @@ class LogicalTransformer(gast.NodeTransformer):
def visit_BoolOp(self, node): def visit_BoolOp(self, node):
self.generic_visit(node) self.generic_visit(node)
if isinstance(node.op, gast.And): if isinstance(node.op, gast.And):
new_node = self._create_bool_op_node(node.values, 'and') new_node = self._create_bool_op_node(node.values, 'And')
elif isinstance(node.op, gast.Or): elif isinstance(node.op, gast.Or):
new_node = self._create_bool_op_node(node.values, 'or') new_node = self._create_bool_op_node(node.values, 'Or')
else: else:
raise TypeError( raise TypeError(
"Only supports and/or syntax in control flow if statement.") "Only supports and/or syntax in control flow if statement.")
...@@ -95,7 +95,7 @@ class LogicalTransformer(gast.NodeTransformer): ...@@ -95,7 +95,7 @@ class LogicalTransformer(gast.NodeTransformer):
nodes = [pre_logic_node] + [post_logic_node] nodes = [pre_logic_node] + [post_logic_node]
args = [ast_to_source_code(child) for child in nodes] args = [ast_to_source_code(child) for child in nodes]
new_node_str = "_jst.convert_logical_{}(lambda:{}, lambda:{})".format( new_node_str = "_jst.{}(lambda:{}, lambda:{})".format(
api_type, args[0], args[1]) api_type, args[0], args[1])
# NOTE: gast.parse return Module(body=[expr(...)]) # NOTE: gast.parse return Module(body=[expr(...)])
new_node = gast.parse(new_node_str).body[0].value new_node = gast.parse(new_node_str).body[0].value
......
...@@ -28,7 +28,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name ...@@ -28,7 +28,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ForLoopTuplePreTransformer from paddle.fluid.dygraph.dygraph_to_static.utils import ForLoopTuplePreTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
__all__ = ['LoopTransformer', 'NameVisitor'] __all__ = ['LoopTransformer', 'NameVisitor']
...@@ -89,7 +89,7 @@ def create_while_nodes(condition_name, body_name, loop_var_names): ...@@ -89,7 +89,7 @@ def create_while_nodes(condition_name, body_name, loop_var_names):
else: else:
assign_loop_var_names.append(name) assign_loop_var_names.append(name)
while_func_name = "_jst.convert_while_loop" while_func_name = "_jst.While"
while_node_str = "[{}] = {}({}, {}, [{}])".format( while_node_str = "[{}] = {}({}, {}, [{}])".format(
",".join(assign_loop_var_names), while_func_name, condition_name, ",".join(assign_loop_var_names), while_func_name, condition_name,
body_name, ",".join(loop_var_names)) body_name, ",".join(loop_var_names))
...@@ -672,7 +672,7 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -672,7 +672,7 @@ class LoopTransformer(gast.NodeTransformer):
# We need to create static variable for those variables # We need to create static variable for those variables
for name in create_var_names: for name in create_var_names:
if "." not in name: if "." not in name:
new_stmts.append(create_static_variable_gast_node(name)) new_stmts.append(create_fill_constant_node(name))
# 4. append init statements # 4. append init statements
new_stmts.extend(init_stmts) new_stmts.extend(init_stmts)
...@@ -756,7 +756,7 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -756,7 +756,7 @@ class LoopTransformer(gast.NodeTransformer):
# We need to create static variable for those variables # We need to create static variable for those variables
for name in create_var_names: for name in create_var_names:
if "." not in name: if "." not in name:
new_stmts.append(create_static_variable_gast_node(name)) new_stmts.append(create_fill_constant_node(name))
condition_func_node = gast.FunctionDef( condition_func_node = gast.FunctionDef(
name=unique_name.generate(WHILE_CONDITION_PREFIX), name=unique_name.generate(WHILE_CONDITION_PREFIX),
......
...@@ -50,5 +50,5 @@ class PrintTransformer(gast.NodeTransformer): ...@@ -50,5 +50,5 @@ class PrintTransformer(gast.NodeTransformer):
return gast.Expr(value=convert_print_node) return gast.Expr(value=convert_print_node)
def _create_print_node(self, print_args): def _create_print_node(self, print_args):
convert_print_func = gast.parse('_jst.convert_print').body[0].value convert_print_func = gast.parse('_jst.Print').body[0].value
return gast.Call(func=convert_print_func, args=print_args, keywords=[]) return gast.Call(func=convert_print_func, args=print_args, keywords=[])
...@@ -14,22 +14,16 @@ ...@@ -14,22 +14,16 @@
from __future__ import print_function from __future__ import print_function
import copy
from paddle.utils import gast from paddle.utils import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import slice_is_num
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
class TensorShapeTransformer(gast.NodeTransformer): class TensorShapeTransformer(gast.NodeTransformer):
""" """
This class transforms variable.shape into Static Graph Ast. This class transforms variable.shape into Static Graph Ast.
All 'xxx.shape' will be converted int '_jst.convert_shape(x)'. All 'xxx.shape' will be converted int '_jst.Shape(x)'.
""" """
def __init__(self, wrapper_root): def __init__(self, wrapper_root):
...@@ -48,7 +42,7 @@ class TensorShapeTransformer(gast.NodeTransformer): ...@@ -48,7 +42,7 @@ class TensorShapeTransformer(gast.NodeTransformer):
# NOTE(dev): we can deal with paddle.shape in this case, but it's # NOTE(dev): we can deal with paddle.shape in this case, but it's
# not pretty to modify into 'convert_shape(paddle)(x)[0]'. # not pretty to modify into 'convert_shape(paddle)(x)[0]'.
if args != 'paddle': if args != 'paddle':
convert_shape_func = "_jst.convert_shape({})".format(args) convert_shape_func = "_jst.Shape({})".format(args)
shape_node = gast.parse(convert_shape_func).body[0].value shape_node = gast.parse(convert_shape_func).body[0].value
return shape_node return shape_node
return node return node
...@@ -1178,7 +1178,7 @@ class ForNodeVisitor(object): ...@@ -1178,7 +1178,7 @@ class ForNodeVisitor(object):
else: else:
iter_var_name = ast_to_source_code(self.iter_node).strip() iter_var_name = ast_to_source_code(self.iter_node).strip()
convert_len_node_source_str = '{} = _jst.convert_len({})'.format( convert_len_node_source_str = '{} = _jst.Len({})'.format(
self.iter_var_len_name, iter_var_name) self.iter_var_len_name, iter_var_name)
convert_len_node = gast.parse(convert_len_node_source_str).body[0] convert_len_node = gast.parse(convert_len_node_source_str).body[0]
......
...@@ -23,57 +23,11 @@ from paddle.fluid.framework import Variable ...@@ -23,57 +23,11 @@ from paddle.fluid.framework import Variable
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
__all__ = [ __all__ = [
'create_bool_as_type', 'create_fill_constant_node', 'create_bool_as_type', 'create_fill_constant_node', 'to_static_variable',
'create_static_variable_gast_node', 'data_layer_not_check', 'create_undefined_var'
'to_static_variable', 'to_static_variable_gast_node', 'create_undefined_var'
] ]
def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
"""
This function creates a Tensor on the global block. The created Tensor
doesn't check the dtype and the shape of feed data because dygraph input
data can be various-length. This API is used in translating dygraph into
static graph.
Note:
The default :code:`stop_gradient` attribute of the Tensor created by
this API is true, which means the gradient won't be passed backward
through the data Tensor. Set :code:`var.stop_gradient = False` If
user would like to pass backward gradient.
Args:
name (str): The name/alias of the Tensor, see :ref:`api_guide_Name`
for more details.
shape (list|tuple): List|Tuple of integers declaring the shape. You can
set "None" at a dimension to indicate the dimension can be of any
size. For example, it is useful to set changeable batch size as "None"
dtype (np.dtype|VarType|str, optional): The type of the data. Supported
dtype: bool, float16, float32, float64, int8, int16, int32, int64,
uint8. Default: float32
lod_level (int, optional): The LoD level of the LoDTensor. Usually users
don't have to set this value. For more details about when and how to
use LoD level, see :ref:`user_guide_lod_tensor` . Default: 0
Returns:
Tensor: The global Tensor that gives access to the data.
"""
helper = LayerHelper('data', **locals())
shape = list(shape)
for i in six.moves.range(len(shape)):
if shape[i] is None:
shape[i] = -1
return helper.create_global_variable(name=name,
shape=shape,
dtype=dtype,
type=core.VarDesc.VarType.LOD_TENSOR,
stop_gradient=True,
lod_level=lod_level,
is_data=True,
need_check_feed=False)
def create_undefined_var(name): def create_undefined_var(name):
func_code = "{} = _jst.UndefinedVar('{}')".format(name, name) func_code = "{} = _jst.UndefinedVar('{}')".format(name, name)
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
...@@ -85,18 +39,7 @@ def create_nonlocal_stmt_node(names): ...@@ -85,18 +39,7 @@ def create_nonlocal_stmt_node(names):
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
def to_static_variable_gast_node(name): def create_fill_constant_node(name, value=0):
func_code = "{} = _jst.to_static_variable({})".format(name, name)
return gast.parse(func_code).body[0]
def create_static_variable_gast_node(name):
func_code = "{} = _jst.data_layer_not_check(name='{}', shape=[-1], dtype='float32')".format(
name, unique_name.generate(name))
return gast.parse(func_code).body[0]
def create_fill_constant_node(name, value):
func_code = "{} = paddle.full(shape=[1], ".format(name) func_code = "{} = paddle.full(shape=[1], ".format(name)
if isinstance(value, bool): if isinstance(value, bool):
func_code += "dtype='bool', fill_value={}, name='{}')".format( func_code += "dtype='bool', fill_value={}, name='{}')".format(
...@@ -121,7 +64,6 @@ def to_static_variable(x): ...@@ -121,7 +64,6 @@ def to_static_variable(x):
return paddle.full(shape=[1], dtype='bool', fill_value=x) return paddle.full(shape=[1], dtype='bool', fill_value=x)
if isinstance(x, float): if isinstance(x, float):
return paddle.full(shape=[1], dtype='float64', fill_value=x) return paddle.full(shape=[1], dtype='float64', fill_value=x)
if isinstance(x, six.integer_types): if isinstance(x, six.integer_types):
return paddle.full(shape=[1], dtype='int64', fill_value=x) return paddle.full(shape=[1], dtype='int64', fill_value=x)
......
...@@ -72,10 +72,8 @@ def dyfunc_with_if_else3(x): ...@@ -72,10 +72,8 @@ def dyfunc_with_if_else3(x):
# The var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node. # The var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node.
# The transformed code: # The transformed code:
""" """
q = paddle.jit.dy2static. q = paddle.jit.dy2static.UndefinedVar('q')
data_layer_not_check(name='q', shape=[-1], dtype='float32') z = paddle.jit.dy2static.UndefinedVar('z')
z = paddle.jit.dy2static.
data_layer_not_check(name='z', shape=[-1], dtype='float32')
def true_fn_0(q, x, y): def true_fn_0(q, x, y):
x = x + 1 x = x + 1
......
...@@ -266,7 +266,7 @@ class TestDynamicToStaticCode(unittest.TestCase): ...@@ -266,7 +266,7 @@ class TestDynamicToStaticCode(unittest.TestCase):
return get_source_code(self.answer_func) return get_source_code(self.answer_func)
def _get_transformed_code(self): def _get_transformed_code(self):
transformed_func = _jst.convert_call(self.func) transformed_func = _jst.Call(self.func)
return get_source_code(transformed_func) return get_source_code(transformed_func)
def test_code(self): def test_code(self):
...@@ -289,7 +289,7 @@ class TestDynamicToStaticCode2(TestDynamicToStaticCode): ...@@ -289,7 +289,7 @@ class TestDynamicToStaticCode2(TestDynamicToStaticCode):
class StaticCode(): class StaticCode():
def func_convert_then_not_to_static(x): def func_convert_then_not_to_static(x):
y = _jst.convert_call(func_not_to_static)(x) y = _jst.Call(func_not_to_static)(x)
return y return y
self.answer_func = StaticCode.func_convert_then_not_to_static self.answer_func = StaticCode.func_convert_then_not_to_static
......
...@@ -277,6 +277,7 @@ class TestListInWhileLoop(TestListWithoutControlFlow): ...@@ -277,6 +277,7 @@ class TestListInWhileLoop(TestListWithoutControlFlow):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
if to_static: if to_static:
print(declarative(self.dygraph_func).code)
res = declarative(self.dygraph_func)(self.input, self.iter_num) res = declarative(self.dygraph_func)(self.input, self.iter_num)
else: else:
res = self.dygraph_func(self.input, self.iter_num) res = self.dygraph_func(self.input, self.iter_num)
......
...@@ -90,7 +90,7 @@ class StaticCode1(): ...@@ -90,7 +90,7 @@ class StaticCode1():
x_v = x_v + 1 x_v = x_v + 1
return x_v return x_v
_jst.convert_ifelse( _jst.IfElse(
fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, get_args_0, fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, get_args_0,
set_args_0, ('x_v', )) set_args_0, ('x_v', ))
...@@ -115,8 +115,8 @@ class StaticCode1(): ...@@ -115,8 +115,8 @@ class StaticCode1():
__return_value_0 = x_v __return_value_0 = x_v
return __return_value_0 return __return_value_0
_jst.convert_ifelse(label is not None, true_fn_1, false_fn_1, _jst.IfElse(label is not None, true_fn_1, false_fn_1, get_args_1,
get_args_1, set_args_1, ('__return_value_0', )) set_args_1, ('__return_value_0', ))
return __return_value_0 return __return_value_0
...@@ -147,7 +147,7 @@ class StaticCode2(): ...@@ -147,7 +147,7 @@ class StaticCode2():
x_v = x_v + 1 x_v = x_v + 1
return x_v return x_v
_jst.convert_ifelse( _jst.IfElse(
fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, get_args_2, fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, get_args_2,
set_args_2, ('x_v', )) set_args_2, ('x_v', ))
...@@ -172,8 +172,8 @@ class StaticCode2(): ...@@ -172,8 +172,8 @@ class StaticCode2():
__return_value_1 = x_v __return_value_1 = x_v
return __return_value_1 return __return_value_1
_jst.convert_ifelse(label is not None, true_fn_3, false_fn_3, _jst.IfElse(label is not None, true_fn_3, false_fn_3, get_args_3,
get_args_3, set_args_3, ('__return_value_1', )) set_args_3, ('__return_value_1', ))
return __return_value_1 return __return_value_1
......
...@@ -275,7 +275,6 @@ class TestTensorShapeBasic(unittest.TestCase): ...@@ -275,7 +275,6 @@ class TestTensorShapeBasic(unittest.TestCase):
self.expected_slice_op_num = 0 self.expected_slice_op_num = 0
def _compute_op_num(self, program): def _compute_op_num(self, program):
print(program)
self.op_num = sum([len(block.ops) for block in program.blocks]) self.op_num = sum([len(block.ops) for block in program.blocks])
self.shape_op_num = 0 self.shape_op_num = 0
self.slice_op_num = 0 self.slice_op_num = 0
......
...@@ -22,30 +22,6 @@ import paddle.fluid as fluid ...@@ -22,30 +22,6 @@ import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check
class TestDataLayerNotCheck(unittest.TestCase):
def test_create_none_shape(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
d = data_layer_not_check(name="d", shape=(None, -1, 3))
self.assertEqual(d.shape, (-1, -1, 3))
self.assertEqual(d.name, "d")
def test_feed_mismatch_shape(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
d = data_layer_not_check(name="d", shape=(1, 2, 3))
feed_in_data = np.random.uniform(size=[1, 2, 4]).astype(np.float32)
place = fluid.CUDAPlace(
0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
exe = fluid.Executor(place)
ret = exe.run(main_program,
feed={d.name: feed_in_data},
fetch_list=[d.name])
self.assertTrue(np.allclose(ret, feed_in_data))
class TestVariableTransFunc(unittest.TestCase): class TestVariableTransFunc(unittest.TestCase):
......
...@@ -14,25 +14,21 @@ ...@@ -14,25 +14,21 @@
from .base import saw from .base import saw
from .base import UndefinedVar from .base import UndefinedVar
from .convert_call_func import convert_call # noqa: F401 from .convert_operators import convert_logical_and as And # noqa: F401
from .convert_operators import cast_bool_if_necessary # noqa: F401 from .convert_operators import convert_var_dtype as AsDtype # noqa: F401
from .convert_operators import convert_assert # noqa: F401 from .convert_operators import convert_assert as Assert # noqa: F401
from .convert_operators import convert_ifelse # noqa: F401 from .convert_call_func import convert_call as Call # noqa: F401
from .convert_operators import convert_len # noqa: F401 from .convert_operators import convert_ifelse as IfElse # noqa: F401
from .convert_operators import convert_logical_and # noqa: F401 from .convert_operators import convert_len as Len # noqa: F401
from .convert_operators import convert_logical_not # noqa: F401 from .convert_operators import convert_logical_not as Not # noqa: F401
from .convert_operators import convert_logical_or # noqa: F401 from .convert_operators import convert_logical_or as Or # noqa: F401
from .convert_operators import convert_pop # noqa: F401 from .convert_operators import convert_pop as Pop # noqa: F401
from .convert_operators import convert_print # noqa: F401 from .convert_operators import convert_print as Print # noqa: F401
from .convert_operators import convert_shape_compare # noqa: F401 from .convert_operators import convert_shape as Shape # noqa: F401
from .convert_operators import convert_var_dtype # noqa: F401 from .convert_operators import convert_while_loop as While # noqa: F401
from .convert_operators import convert_shape # noqa: F401
from .convert_operators import convert_while_loop # noqa: F401
from .variable_trans_func import create_bool_as_type # noqa: F401 from .variable_trans_func import create_bool_as_type # noqa: F401
from .variable_trans_func import create_fill_constant_node # noqa: F401
from .variable_trans_func import create_static_variable_gast_node # noqa: F401
from .variable_trans_func import data_layer_not_check # noqa: F401
from .variable_trans_func import to_static_variable # noqa: F401 from .variable_trans_func import to_static_variable # noqa: F401
from .variable_trans_func import to_static_variable_gast_node # noqa: F401 from .convert_operators import convert_shape_compare # noqa: F401
__all__ = [] __all__ = []
...@@ -15,10 +15,6 @@ ...@@ -15,10 +15,6 @@
from __future__ import print_function from __future__ import print_function
from ...fluid.dygraph.dygraph_to_static.variable_trans_func import create_bool_as_type # noqa: F401 from ...fluid.dygraph.dygraph_to_static.variable_trans_func import create_bool_as_type # noqa: F401
from ...fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node # noqa: F401
from ...fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node # noqa: F401
from ...fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check # noqa: F401
from ...fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable # noqa: F401 from ...fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable # noqa: F401
from ...fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node # noqa: F401
__all__ = [] __all__ = []
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册