未验证 提交 f9c9d50e 编写于 作者: L liym27 提交者: GitHub

Return VarBase of ProgramTranslator.get_output instead of numpy.ndarray. test=develop (#23663)

上级 01a78323
...@@ -24,6 +24,7 @@ from collections import defaultdict ...@@ -24,6 +24,7 @@ from collections import defaultdict
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid import core, executor from paddle.fluid import core, executor
from paddle.fluid.dygraph import guard, to_variable
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
...@@ -285,6 +286,8 @@ class ProgramTranslator(object): ...@@ -285,6 +286,8 @@ class ProgramTranslator(object):
*args, **kwargs) *args, **kwargs)
if not program_cache.in_build_process: if not program_cache.in_build_process:
outputs = self.run(*args, **kwargs) outputs = self.run(*args, **kwargs)
with guard():
outputs = [to_variable(x) for x in outputs]
return outputs return outputs
def get_func(self, dygraph_func): def get_func(self, dygraph_func):
......
...@@ -50,7 +50,7 @@ class TestCacheProgram(unittest.TestCase): ...@@ -50,7 +50,7 @@ class TestCacheProgram(unittest.TestCase):
]) ])
if batch_id > 0: if batch_id > 0:
self.assertTrue( self.assertTrue(
np.allclose(prev_out[0], cur_out[0]), np.allclose(prev_out[0].numpy(), cur_out[0].numpy()),
msg='Output in previous batch is {}\n Output in current batch is \n{}' msg='Output in previous batch is {}\n Output in current batch is \n{}'
.format(prev_out, cur_out)) .format(prev_out, cur_out))
self.assertEqual(prev_ops, cur_ops) self.assertEqual(prev_ops, cur_ops)
...@@ -81,7 +81,7 @@ class TestCacheProgramWithOptimizer(unittest.TestCase): ...@@ -81,7 +81,7 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
for batch_id in range(self.batch_num): for batch_id in range(self.batch_num):
pred, avg_loss = static_net(self.data) pred, avg_loss = static_net(self.data)
loss_data.append(np.array(avg_loss)) loss_data.append(np.array(avg_loss.numpy()))
return loss_data return loss_data
......
...@@ -82,7 +82,7 @@ class TestPool2D(unittest.TestCase): ...@@ -82,7 +82,7 @@ class TestPool2D(unittest.TestCase):
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
dy_layer = self.dygraph_class() dy_layer = self.dygraph_class()
out = dy_layer(x=self.data) out = dy_layer(x=self.data)
return out[0] return out[0].numpy()
def test_static_output(self): def test_static_output(self):
dygraph_res = self.run_dygraph_mode() dygraph_res = self.run_dygraph_mode()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册