agent_base.py 3.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.fluid as fluid
import parl.layers as layers
B
Bo Zhou 已提交
17
from parl.framework.algorithm_base import Algorithm
B
Bo Zhou 已提交
18
from parl.framework.model_base import Model
H
Hongsheng Zeng 已提交
19
from parl.utils import get_gpu_count
20

B
Bo Zhou 已提交
21
__all__ = ['Agent']
22 23


B
Bo Zhou 已提交
24
class Agent(object):
    """Base class responsible for the general data flow outside the algorithm.

    An Agent is created in a bottom-up way:
    a. create a Model
    b. create an Algorithm with the model as an input
    c. define an Agent with the algorithm

    `build_program`, `predict`, `sample` and `learn` raise
    NotImplementedError here and are meant to be overridden by subclasses.
    """

    def __init__(self, algorithm, gpu_id=None):
        """Build the program and run fluid's default startup program.

        Args:
            algorithm: instance of `Algorithm`
                (imported from parl.framework.algorithm_base).
            gpu_id: device index. A value >= 0 selects
                `fluid.CUDAPlace(gpu_id)`, a negative value selects
                `fluid.CPUPlace()`. When None, 0 is used if
                `get_gpu_count()` reports at least one GPU, otherwise -1.

        Created attributes:
            self.alg: the algorithm passed in
            self.gpu_id: int
            self.place: fluid.CUDAPlace or fluid.CPUPlace
            self.fluid_executor: fluid.Executor
        """
        assert isinstance(algorithm, Algorithm), (
            'algorithm must be an instance of Algorithm')
        self.alg = algorithm

        # The program is built first; the default startup program is only
        # run further down, once the executor exists.
        self.build_program()

        if gpu_id is None:
            # No explicit device requested: use GPU 0 when one is
            # available, otherwise fall back to the CPU (-1).
            gpu_id = 0 if get_gpu_count() > 0 else -1
        self.gpu_id = gpu_id
        self.place = fluid.CUDAPlace(
            gpu_id) if gpu_id >= 0 else fluid.CPUPlace()
        self.fluid_executor = fluid.Executor(self.place)
        self.fluid_executor.run(fluid.default_startup_program())

    def build_program(self):
        """Build your training program and prediction program here,
        using the functions `define_learn` and `define_predict` of the
        algorithm.

        Note that it is unnecessary to call this function explicitly,
        since it is called automatically by `__init__`.

        To build the program, you may need to do the following:
        a. create a new program in fluid with a program guard
        b. define your data layer
        c. build your training/prediction program, passing the data
           variables defined in step b to `define_learn` /
           `define_predict` of the algorithm

        NOTE(review): earlier wording of step c called these functions
        `define_training` / `define_prediction`; verify the names against
        the Algorithm base class.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def predict(self, obs):
        """Predict the action given the current observation of the
        environment.

        Note that this function only does the prediction and does not try
        any exploration. To explore the action space, implement that in
        the `sample` function below. Informally, this function is often
        used in the test process.

        Args:
            obs: current observation of the environment.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def sample(self, obs):
        """Predict the action given the current observation of the
        environment, with noise added to the action in order to explore a
        new trajectory.

        Informally, this function is often used in the training process.

        Args:
            obs: current observation of the environment.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def learn(self, obs, action, reward, next_obs, terminal):
        """Pass data to the training program to update the model.

        This function is the training interface of the Agent.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def get_params(self):
        """Get the parameters of self.alg.

        Returns:
            List of numpy array, as returned by `self.alg.get_params()`.
        """
        return self.alg.get_params()

    def set_params(self, params):
        """Set the parameters of self.alg.

        The call is forwarded to `self.alg.set_params` together with the
        `gpu_id` this agent was created with.

        Args:
            params: List of numpy array.
        """
        self.alg.set_params(params, gpu_id=self.gpu_id)