未验证 提交 d066d6f9 编写于 作者: H Huihuang Zheng 提交者: GitHub

[Dy2Stat] Change layers.data to fluid.data and Test Var Created In Loop (#23103)

As the title
上级 0c30098f
......@@ -20,8 +20,8 @@ import numpy
import six
from paddle.fluid import framework
from paddle.fluid.layers import io
from paddle.fluid import core, executor
from paddle.fluid.data import data
from paddle.fluid.dygraph.dygraph_to_static import convert_to_static
__all__ = ['AutoTracer']
......@@ -170,9 +170,9 @@ class ProgramCache(object):
batch_data, numpy.ndarray
), "Input {} should be numpy.ndarray, but received {}.".format(
feed_name, type(batch_data))
feed_layer = io.data(
feed_layer = data(
name=feed_name,
shape=list(batch_data.shape[1:]),
shape=[-1] + list(batch_data.shape[1:]),
dtype=str(batch_data.dtype))
self._inputs.append(feed_layer)
......
......@@ -29,7 +29,7 @@ def to_static_variable_gast_node(name):
def create_static_variable_gast_node(name):
func_code = "{} = fluid.layers.data(name='{}', shape=[-1], dtype='float32')".format(
func_code = "{} = fluid.data(name='{}', shape=[-1], dtype='float32')".format(
name, name)
return gast.parse(func_code).body[0]
......
......@@ -50,6 +50,12 @@ def while_loop_bool_op(x):
return i
def var_create_in_for_loop(max_len):
for i in range(max_len):
ret = fluid.layers.zeros(shape=[3, 4, 5], dtype='float64')
return ret
class TestNameVisitor(unittest.TestCase):
def setUp(self):
self.loop_funcs = [while_loop_dyfunc, for_loop_dyfunc]
......@@ -119,11 +125,15 @@ class TestTransformForLoop(unittest.TestCase):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.len = 100
self._init_dyfunc()
def _init_dyfunc(self):
self.dyfunc = for_loop_dyfunc
def _run_static(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
static_func = dygraph_to_static_graph(for_loop_dyfunc)
static_func = dygraph_to_static_graph(self.dyfunc)
out = static_func(self.len)
exe = fluid.Executor(self.place)
ret = exe.run(main_program, fetch_list=out)
......@@ -131,18 +141,19 @@ class TestTransformForLoop(unittest.TestCase):
def _run_dygraph(self):
with fluid.dygraph.guard(self.place):
ret = for_loop_dyfunc(self.len)
ret = self.dyfunc(self.len)
return ret.numpy()
def test_ast_to_func(self):
static_numpy = self._run_static()
self.assertTrue(
np.allclose(
np.full(
shape=(1), fill_value=2, dtype=np.int32), static_numpy))
self._run_dygraph()
self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
class TestVarCreateInForLoop(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = var_create_in_for_loop
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册