diff --git a/parl/core/fluid/agent.py b/parl/core/fluid/agent.py
index 12ff53551cccfef1f6ce60f069589b4958515269..8972443c453e75e022751cee707d9bbaeda649df 100644
--- a/parl/core/fluid/agent.py
+++ b/parl/core/fluid/agent.py
@@ -212,6 +212,10 @@ class Agent(AgentBase):
 
         if program is None:
             program = self.learn_program
+        # Programs returned by parl.compile() keep the original Program in
+        # `_init_program`; load_params is given that one, not the wrapper.
+        if isinstance(program, fluid.compiler.CompiledProgram):
+            program = program._init_program
         dirname = '/'.join(save_path.split('/')[:-1])
         filename = save_path.split('/')[-1]
         fluid.io.load_params(
diff --git a/parl/core/fluid/plutils/compiler.py b/parl/core/fluid/plutils/compiler.py
index 846343f7a3a0138f4f0c25af623d08124b6030dd..da689414f1a8a3568ef000635c3fe3be6ae168aa 100644
--- a/parl/core/fluid/plutils/compiler.py
+++ b/parl/core/fluid/plutils/compiler.py
@@ -40,7 +40,12 @@ def compile(program, loss=None):
     build_strategy = fluid.BuildStrategy()
     build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
 
-    return fluid.compiler.CompiledProgram(program).with_data_parallel(
-        loss_name=loss_name,
-        exec_strategy=exec_strategy,
-        build_strategy=build_strategy)
+    compiled_program = fluid.compiler.CompiledProgram(
+        program).with_data_parallel(
+            loss_name=loss_name,
+            exec_strategy=exec_strategy,
+            build_strategy=build_strategy)
+    # Keep a handle on the uncompiled Program so that Agent.restore() can
+    # hand it to fluid.io.load_params.
+    compiled_program._init_program = program
+    return compiled_program
diff --git a/parl/core/fluid/tests/agent_base_test_.py b/parl/core/fluid/tests/agent_base_test_.py
index 2e3ff1a904c5e8bd7efe259452b502ecf369490b..cd8ca7d06f72ae99c51f12e639d9b0de1080ba7f 100644
--- a/parl/core/fluid/tests/agent_base_test_.py
+++ b/parl/core/fluid/tests/agent_base_test_.py
@@ -116,6 +116,23 @@ class AgentBaseTest(unittest.TestCase):
         current_output = another_agent.predict(obs)
         np.testing.assert_equal(current_output, previous_output)
 
+    def test_compiled_restore(self):
+        """Check save/restore when learn_program comes from parl.compile."""
+        agent = TestAgent(self.algorithm)
+        agent.learn_program = parl.compile(agent.learn_program)
+        obs = np.random.random([3, 10]).astype('float32')
+        previous_output = agent.predict(obs)
+        save_path1 = './model.ckpt'
+        agent.save(save_path1)
+        agent.restore(save_path1)
+
+        # a new agent instance
+        another_agent = TestAgent(self.algorithm)
+        another_agent.learn_program = parl.compile(another_agent.learn_program)
+        another_agent.restore(save_path1)
+        current_output = another_agent.predict(obs)
+        np.testing.assert_equal(current_output, previous_output)
+
 
 if __name__ == '__main__':
     unittest.main()