diff --git a/README.md b/README.md
index 5476837d07889f88af3b746b8d86cf85e11b4fd6..7566023b58290d1b3eb3017c81985635d4be99b2 100644
--- a/README.md
+++ b/README.md
@@ -31,8 +31,24 @@ Here is an example of building an agent with the DQN algorithm for atari games.
 import parl
 from parl.algorithms import DQN, DDQN

-class CriticModel(parl.Model):
-""" define specific forward model for environment ..."""
+class AtariModel(parl.Model):
+    """AtariModel
+    This class defines the forward part of an algorithm;
+    its input is the state observed from the environment.
+    """
+    def __init__(self, img_shape, action_dim):
+        # define your layers
+        self.cnn1 = layers.conv2d(num_filters=32, filter_size=5,
+                                  stride=[1, 1], padding=[2, 2], act='relu')
+        ...
+        self.fc1 = layers.fc(action_dim)
+    def value(self, img):
+        # define how to estimate the Q value based on the image of atari games
+        img = img / 255.0
+        l = self.cnn1(img)
+        ...
+        Q = self.fc1(l)
+        return Q
 """
 three steps to build an agent
 1. define a forward model, which is critic_model in this example
@@ -41,8 +57,8 @@ three steps to build an agent
 3. define the I/O part in AtariAgent so that it can update the algorithm
    based on the interaction data
 """
-critic_model = CriticModel(act_dim=2)
-algorithm = DQN(critic_model)
+model = AtariModel(img_shape=(32, 32), action_dim=4)
+algorithm = DQN(model)
 agent = AtariAgent(algorithm)
 ```
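Note that the snippet assumes a `layers` module (PARL's wrapper around the PaddlePaddle layer API) has been imported alongside `parl`, and it leaves `AtariAgent` itself undefined. The sketch below illustrates the I/O role that step 3 assigns to the agent: choosing actions from the model's Q-values (here with epsilon-greedy exploration) and feeding interaction data back into the algorithm's update. This is a framework-agnostic sketch, not PARL's actual `parl.Agent` interface; the method names `sample`, `predict`, and `learn`, and the `algorithm.model` / `algorithm.learn` accessors, are assumptions made for illustration.

```python
import numpy as np

class AtariAgent:
    """Illustrative sketch of step 3 only -- see parl.Agent for the
    real base class. The agent owns environment I/O: it maps raw
    observations to actions and hands transitions to the algorithm."""

    def __init__(self, algorithm, action_dim, epsilon=0.1):
        self.alg = algorithm          # e.g. the DQN instance built above
        self.action_dim = action_dim  # number of discrete atari actions
        self.epsilon = epsilon        # exploration rate for epsilon-greedy

    def sample(self, obs):
        # During training: explore with probability epsilon,
        # otherwise act greedily on the current Q estimates.
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.action_dim)
        return self.predict(obs)

    def predict(self, obs):
        # Greedy action: argmax over the Q values from the forward model.
        # `self.alg.model.value` mirrors AtariModel.value above
        # (hypothetical accessor, for illustration).
        q_values = self.alg.model.value(obs)
        return int(np.argmax(q_values))

    def learn(self, obs, action, reward, next_obs, terminal):
        # Hand one transition to the algorithm's update rule
        # (hypothetical signature, for illustration).
        return self.alg.learn(obs, action, reward, next_obs, terminal)
```

In a training loop, `agent.sample(obs)` would drive the environment while each `(obs, action, reward, next_obs, terminal)` transition is passed back through `agent.learn`, which is what "update the algorithm based on the interaction data" refers to.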