agent_base_test.py 4.4 KB
Newer Older
H
Hongsheng Zeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import unittest
from paddle import fluid
B
Bo Zhou 已提交
18
from parl import layers
B
Bo Zhou 已提交
19 20
import parl
import os
H
Hongsheng Zeng 已提交
21 22


B
Bo Zhou 已提交
23
class TestModel(parl.Model):
    """Minimal two-layer fully-connected model used as a test fixture."""

    def __init__(self):
        # Layers are created in the same order in which policy() applies them.
        self.fc1 = layers.fc(size=256)
        self.fc2 = layers.fc(size=1)

    def policy(self, obs):
        """Pass ``obs`` through ``fc1`` and then ``fc2`` and return the result."""
        return self.fc2(self.fc1(obs))


B
Bo Zhou 已提交
34
class TestAlgorithm(parl.Algorithm):
    """Thin algorithm wrapper around a model, used as a test fixture."""

    def __init__(self, model):
        self.model = model

    def predict(self, obs):
        """Return the model's policy output for ``obs``."""
        return self.model.policy(obs)

    def learn(self, obs, label):
        """Return the mean squared error between the policy output and ``label``.

        Args:
            obs: observation input passed to ``self.model.policy``.
            label: target the policy output is compared against.

        Returns:
            The squared-error cost reduced to its mean.
        """
        pred_output = self.model.policy(obs)
        # The cost must compare the prediction with the label; comparing the
        # raw observation would leave the model output unused.
        cost = layers.square_error_cost(pred_output, label)
        cost = fluid.layers.reduce_mean(cost)
        return cost
H
Hongsheng Zeng 已提交
46

B
Bo Zhou 已提交
47 48

class TestAgent(parl.Agent):
    """Agent fixture holding one prediction program and one learning program."""

    def __init__(self, algorithm):
        super(TestAgent, self).__init__(algorithm)

    def build_program(self):
        """Create both programs and record the prediction fetch target."""
        self.predict_program = fluid.Program()
        self.learn_program = fluid.Program()

        with fluid.program_guard(self.predict_program):
            predict_obs = layers.data(name='obs', shape=[10], dtype='float32')
            prediction = self.alg.predict(predict_obs)
        self.predict_output = [prediction]

        with fluid.program_guard(self.learn_program):
            learn_obs = layers.data(name='obs', shape=[10], dtype='float32')
            learn_label = layers.data(
                name='label', shape=[1], dtype='float32')
            # The cost is built inside learn_program; it is not stored on self.
            self.alg.learn(learn_obs, learn_label)

    def learn(self, obs, label):
        """Run ``learn_program`` once; no fetch_list is passed, nothing is returned."""
        batch_feed = {'obs': obs, 'label': label}
        self.fluid_executor.run(self.learn_program, feed=batch_feed)

    def predict(self, obs):
        """Run ``predict_program`` on ``obs`` and return the first fetched value."""
        fetched = self.fluid_executor.run(
            self.predict_program,
            feed={'obs': obs},
            fetch_list=self.predict_output)
        return fetched[0]


class AgentBaseTest(unittest.TestCase):
    """Exercises predict, save and restore on ``TestAgent``."""

    def setUp(self):
        self.model = TestModel()
        self.alg = TestAlgorithm(self.model)

    def test_agent(self):
        """predict() on a random batch yields a non-None result."""
        agent = TestAgent(self.alg)
        batch = np.random.random([3, 10]).astype('float32')
        self.assertIsNotNone(agent.predict(batch))

    def test_save(self):
        """save() produces a file at a bare filename and at a nested path."""
        agent = TestAgent(self.alg)
        batch = np.random.random([3, 10]).astype('float32')
        agent.predict(batch)

        checkpoint_paths = (
            'model.ckpt',
            os.path.join('my_model', 'model-2.ckpt'),
        )
        # Save everything first, then verify, matching the original ordering.
        for path in checkpoint_paths:
            agent.save(path)
        for path in checkpoint_paths:
            self.assertTrue(os.path.exists(path))

    def test_restore(self):
        """Predictions are unchanged after restore, on the same and a new agent."""
        agent = TestAgent(self.alg)
        batch = np.random.random([3, 10]).astype('float32')
        agent.predict(batch)

        ckpt_path = 'model.ckpt'
        expected = agent.predict(batch)
        agent.save(ckpt_path)
        agent.restore(ckpt_path)
        np.testing.assert_equal(agent.predict(batch), expected)

        # a new agent instance
        fresh_agent = TestAgent(self.alg)
        fresh_agent.restore(ckpt_path)
        np.testing.assert_equal(fresh_agent.predict(batch), expected)

    def test_compiled_restore(self):
        """Restore still reproduces predictions when learn_program is compiled."""
        agent = TestAgent(self.alg)
        agent.learn_program = parl.compile(agent.learn_program)
        batch = np.random.random([3, 10]).astype('float32')
        expected = agent.predict(batch)

        ckpt_path = 'model.ckpt'
        agent.save(ckpt_path)
        agent.restore(ckpt_path)

        # a new agent instance
        fresh_agent = TestAgent(self.alg)
        fresh_agent.learn_program = parl.compile(fresh_agent.learn_program)
        fresh_agent.restore(ckpt_path)
        np.testing.assert_equal(fresh_agent.predict(batch), expected)

H
Hongsheng Zeng 已提交
135 136 137

# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()