From f9c9d50e7e7da194fca6a23f3a094a8a31194992 Mon Sep 17 00:00:00 2001
From: liym27 <33742067+liym27@users.noreply.github.com>
Date: Fri, 10 Apr 2020 11:21:19 +0800
Subject: [PATCH] Return VarBase of ProgramTranslator.get_output instead of numpy.ndarray. test=develop (#23663)

---
 .../fluid/dygraph/dygraph_to_static/program_translator.py   | 3 +++
 .../tests/unittests/dygraph_to_static/test_cache_program.py | 4 ++--
 .../tests/unittests/dygraph_to_static/test_fetch_feed.py    | 2 +-
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
index 9682fceb08f..62f2e50c0d4 100644
--- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
+++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
@@ -24,6 +24,7 @@ from collections import defaultdict
 
 from paddle.fluid import framework
 from paddle.fluid import core, executor
+from paddle.fluid.dygraph import guard, to_variable
 from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
 from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
 from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
@@ -285,6 +286,8 @@ class ProgramTranslator(object):
                                                                 *args, **kwargs)
         if not program_cache.in_build_process:
             outputs = self.run(*args, **kwargs)
+            with guard():
+                outputs = [to_variable(x) for x in outputs]
         return outputs
 
     def get_func(self, dygraph_func):
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py
index 3b7cdae78ad..0e8b7d787a0 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py
@@ -50,7 +50,7 @@ class TestCacheProgram(unittest.TestCase):
             ])
             if batch_id > 0:
                 self.assertTrue(
-                    np.allclose(prev_out[0], cur_out[0]),
+                    np.allclose(prev_out[0].numpy(), cur_out[0].numpy()),
                     msg='Output in previous batch is {}\n Output in current batch is \n{}'
                     .format(prev_out, cur_out))
                 self.assertEqual(prev_ops, cur_ops)
@@ -81,7 +81,7 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
 
             for batch_id in range(self.batch_num):
                 pred, avg_loss = static_net(self.data)
-                loss_data.append(np.array(avg_loss))
+                loss_data.append(np.array(avg_loss.numpy()))
 
         return loss_data
 
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
index 3c2e34aec3e..cc0c4cfe8f6 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
@@ -82,7 +82,7 @@ class TestPool2D(unittest.TestCase):
         with fluid.program_guard(main_prog, startup_prog):
             dy_layer = self.dygraph_class()
             out = dy_layer(x=self.data)
-            return out[0]
+            return out[0].numpy()
 
     def test_static_output(self):
         dygraph_res = self.run_dygraph_mode()
--
GitLab
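
Note (added for clarity): with this patch, ProgramTranslator.get_output wraps the fetched results in dygraph VarBase objects instead of returning numpy.ndarray, which is why the updated tests call .numpy() before comparing values. The snippet below is a minimal usage sketch under the Paddle 1.x fluid API of this commit; the toy function simple_func, its input shape, and the direct get_output call are illustrative assumptions, not part of the patch.

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator

# Toy dygraph function used only for illustration (assumption, not from the patch).
def simple_func(x):
    return fluid.layers.relu(x)

prog_trans = ProgramTranslator()
data = np.random.random([2, 3]).astype('float32')

# get_output builds (or reuses) the cached static program and runs it; after this
# patch the fetched outputs come back as VarBase instances rather than numpy arrays.
outputs = prog_trans.get_output(simple_func, data)

# Convert back to numpy for comparison or logging, mirroring the test changes above.
print(outputs[0].numpy())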