未验证 提交 d066d6f9 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2Stat] Change layers.data to fluid.data and Test Var Created In Loop (#23103)

As the title
上级 0c30098f
...@@ -20,8 +20,8 @@ import numpy ...@@ -20,8 +20,8 @@ import numpy
import six import six
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid.layers import io
from paddle.fluid import core, executor from paddle.fluid import core, executor
from paddle.fluid.data import data
from paddle.fluid.dygraph.dygraph_to_static import convert_to_static from paddle.fluid.dygraph.dygraph_to_static import convert_to_static
__all__ = ['AutoTracer'] __all__ = ['AutoTracer']
...@@ -170,9 +170,9 @@ class ProgramCache(object): ...@@ -170,9 +170,9 @@ class ProgramCache(object):
batch_data, numpy.ndarray batch_data, numpy.ndarray
), "Input {} should be numpy.ndarray, but received {}.".format( ), "Input {} should be numpy.ndarray, but received {}.".format(
feed_name, type(batch_data)) feed_name, type(batch_data))
feed_layer = io.data( feed_layer = data(
name=feed_name, name=feed_name,
shape=list(batch_data.shape[1:]), shape=[-1] + list(batch_data.shape[1:]),
dtype=str(batch_data.dtype)) dtype=str(batch_data.dtype))
self._inputs.append(feed_layer) self._inputs.append(feed_layer)
......
...@@ -29,7 +29,7 @@ def to_static_variable_gast_node(name): ...@@ -29,7 +29,7 @@ def to_static_variable_gast_node(name):
def create_static_variable_gast_node(name): def create_static_variable_gast_node(name):
func_code = "{} = fluid.layers.data(name='{}', shape=[-1], dtype='float32')".format( func_code = "{} = fluid.data(name='{}', shape=[-1], dtype='float32')".format(
name, name) name, name)
return gast.parse(func_code).body[0] return gast.parse(func_code).body[0]
......
...@@ -50,6 +50,12 @@ def while_loop_bool_op(x): ...@@ -50,6 +50,12 @@ def while_loop_bool_op(x):
return i return i
def var_create_in_for_loop(max_len):
for i in range(max_len):
ret = fluid.layers.zeros(shape=[3, 4, 5], dtype='float64')
return ret
class TestNameVisitor(unittest.TestCase): class TestNameVisitor(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop_funcs = [while_loop_dyfunc, for_loop_dyfunc] self.loop_funcs = [while_loop_dyfunc, for_loop_dyfunc]
...@@ -119,11 +125,15 @@ class TestTransformForLoop(unittest.TestCase): ...@@ -119,11 +125,15 @@ class TestTransformForLoop(unittest.TestCase):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace() ) else fluid.CPUPlace()
self.len = 100 self.len = 100
self._init_dyfunc()
def _init_dyfunc(self):
self.dyfunc = for_loop_dyfunc
def _run_static(self): def _run_static(self):
main_program = fluid.Program() main_program = fluid.Program()
with fluid.program_guard(main_program): with fluid.program_guard(main_program):
static_func = dygraph_to_static_graph(for_loop_dyfunc) static_func = dygraph_to_static_graph(self.dyfunc)
out = static_func(self.len) out = static_func(self.len)
exe = fluid.Executor(self.place) exe = fluid.Executor(self.place)
ret = exe.run(main_program, fetch_list=out) ret = exe.run(main_program, fetch_list=out)
...@@ -131,18 +141,19 @@ class TestTransformForLoop(unittest.TestCase): ...@@ -131,18 +141,19 @@ class TestTransformForLoop(unittest.TestCase):
def _run_dygraph(self): def _run_dygraph(self):
with fluid.dygraph.guard(self.place): with fluid.dygraph.guard(self.place):
ret = for_loop_dyfunc(self.len) ret = self.dyfunc(self.len)
return ret.numpy() return ret.numpy()
def test_ast_to_func(self): def test_ast_to_func(self):
static_numpy = self._run_static() static_numpy = self._run_static()
self.assertTrue(
np.allclose(
np.full(
shape=(1), fill_value=2, dtype=np.int32), static_numpy))
self._run_dygraph() self._run_dygraph()
self.assertTrue(np.allclose(self._run_dygraph(), self._run_static())) self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
class TestVarCreateInForLoop(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = var_create_in_for_loop
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册