From 5054efedd11b93fb34d289ec54a20f01e3156ea6 Mon Sep 17 00:00:00 2001 From: LI Yunxiang <39279048+Banmahhhh@users.noreply.github.com> Date: Mon, 23 Dec 2019 15:44:09 +0800 Subject: [PATCH] fix compiled_program restore (#192) --- parl/core/fluid/agent.py | 2 ++ parl/core/fluid/plutils/compiler.py | 11 +++++++---- parl/core/fluid/tests/agent_base_test_.py | 16 ++++++++++++++++ 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/parl/core/fluid/agent.py b/parl/core/fluid/agent.py index 12ff535..8972443 100644 --- a/parl/core/fluid/agent.py +++ b/parl/core/fluid/agent.py @@ -212,6 +212,8 @@ class Agent(AgentBase): if program is None: program = self.learn_program + if type(program) is fluid.compiler.CompiledProgram: + program = program._init_program dirname = '/'.join(save_path.split('/')[:-1]) filename = save_path.split('/')[-1] fluid.io.load_params( diff --git a/parl/core/fluid/plutils/compiler.py b/parl/core/fluid/plutils/compiler.py index 846343f..da68941 100644 --- a/parl/core/fluid/plutils/compiler.py +++ b/parl/core/fluid/plutils/compiler.py @@ -40,7 +40,10 @@ def compile(program, loss=None): build_strategy = fluid.BuildStrategy() build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce - return fluid.compiler.CompiledProgram(program).with_data_parallel( - loss_name=loss_name, - exec_strategy=exec_strategy, - build_strategy=build_strategy) + compiled_program = fluid.compiler.CompiledProgram( + program).with_data_parallel( + loss_name=loss_name, + exec_strategy=exec_strategy, + build_strategy=build_strategy) + compiled_program._init_program = program + return compiled_program diff --git a/parl/core/fluid/tests/agent_base_test_.py b/parl/core/fluid/tests/agent_base_test_.py index 2e3ff1a..cd8ca7d 100644 --- a/parl/core/fluid/tests/agent_base_test_.py +++ b/parl/core/fluid/tests/agent_base_test_.py @@ -116,6 +116,22 @@ class AgentBaseTest(unittest.TestCase): current_output = another_agent.predict(obs) np.testing.assert_equal(current_output, previous_output) + def test_compiled_restore(self): + agent = TestAgent(self.algorithm) + agent.learn_program = parl.compile(agent.learn_program) + obs = np.random.random([3, 10]).astype('float32') + previous_output = agent.predict(obs) + save_path1 = './model.ckpt' + agent.save(save_path1) + agent.restore(save_path1) + + # a new agent instance + another_agent = TestAgent(self.algorithm) + another_agent.learn_program = parl.compile(another_agent.learn_program) + another_agent.restore(save_path1) + current_output = another_agent.predict(obs) + np.testing.assert_equal(current_output, previous_output) + if __name__ == '__main__': unittest.main() -- GitLab