未验证 提交 f9c9d50e 编写于 作者: L liym27 提交者: GitHub

Return VarBase of ProgramTranslator.get_output instead of numpy.ndarray. test=develop (#23663)

上级 01a78323
......@@ -24,6 +24,7 @@ from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid import core, executor
from paddle.fluid.dygraph import guard, to_variable
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
......@@ -285,6 +286,8 @@ class ProgramTranslator(object):
*args, **kwargs)
if not program_cache.in_build_process:
outputs = self.run(*args, **kwargs)
with guard():
outputs = [to_variable(x) for x in outputs]
return outputs
def get_func(self, dygraph_func):
......
......@@ -50,7 +50,7 @@ class TestCacheProgram(unittest.TestCase):
])
if batch_id > 0:
self.assertTrue(
np.allclose(prev_out[0], cur_out[0]),
np.allclose(prev_out[0].numpy(), cur_out[0].numpy()),
msg='Output in previous batch is {}\n Output in current batch is \n{}'
.format(prev_out, cur_out))
self.assertEqual(prev_ops, cur_ops)
......@@ -81,7 +81,7 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
for batch_id in range(self.batch_num):
pred, avg_loss = static_net(self.data)
loss_data.append(np.array(avg_loss))
loss_data.append(np.array(avg_loss.numpy()))
return loss_data
......
......@@ -82,7 +82,7 @@ class TestPool2D(unittest.TestCase):
with fluid.program_guard(main_prog, startup_prog):
dy_layer = self.dygraph_class()
out = dy_layer(x=self.data)
return out[0]
return out[0].numpy()
def test_static_output(self):
dygraph_res = self.run_dygraph_mode()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册