Atari wrapper with multi-processing

import multiprocessing
import multiprocessing.connection

import cv2
import gym
import numpy as np

Game environment

This is a wrapper for an OpenAI Gym game environment. It does a few things:

  1. Applies the same action for four frames and keeps the element-wise maximum of the last two frames
  2. Converts observation frames to grayscale and resizes them to (84, 84)
  3. Stacks the frames from the last four actions
  4. Adds episode information (total reward for the entire episode) for monitoring
  5. Restricts an episode to a single life (the game has 5 lives; we reset every time a life is lost)

Observation format

The observation is a tensor of shape (4, 84, 84): four frames (images of the game screen) stacked along the first axis, i.e., each channel is a frame.

class Game:
    def __init__(self, seed: int):

create environment

        self.env = gym.make('BreakoutNoFrameskip-v4')
        self.env.seed(seed)

tensor for a stack of 4 frames

        self.obs_4 = np.zeros((4, 84, 84))

buffer holding the last 2 frames, over which the maximum is taken

        self.obs_2_max = np.zeros((2, 84, 84))

keep track of the episode rewards

        self.rewards = []

and number of lives left

        self.lives = 0

Step

Executes action for 4 time steps and returns a tuple of (observation, reward, done, episode_info).

  • observation: the 4 stacked frames (the current frame and the frames from the last 3 actions)
  • reward: total reward collected while the action was executed
  • done: whether the episode finished (a life was lost)
  • episode_info: episode information, if the episode completed

    def step(self, action):
        reward = 0.
        done = None

run for 4 steps

        for i in range(4):

execute the action in the OpenAI Gym environment

            obs, r, done, info = self.env.step(action)

            if i >= 2:
                self.obs_2_max[i % 2] = self._process_obs(obs)

            reward += r

get number of lives left

            lives = self.env.unwrapped.ale.lives()

end the episode if a life is lost

            if lives < self.lives:
                done = True
                break

record the reward for this step

        self.rewards.append(reward)

        if done:

if the episode is over, set the episode information and reset

            episode_info = {"reward": sum(self.rewards), "length": len(self.rewards)}
            self.reset()
        else:
            episode_info = None

get the max of last two frames

            obs = self.obs_2_max.max(axis=0)

push it to the stack of 4 frames

            self.obs_4 = np.roll(self.obs_4, shift=-1, axis=0)
            self.obs_4[-1] = obs

        return self.obs_4, reward, done, episode_info
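
The two array operations at the end of the step (taking the maximum of the last two frames, then rolling the frame stack) can be tried in isolation. The following is a minimal sketch, not part of the wrapper: the helper names `max_of_last_two` and `push_frame` are hypothetical, and tiny 2x2 arrays stand in for the 84x84 frames so the values are easy to follow.

```python
import numpy as np


def max_of_last_two(frame_a: np.ndarray, frame_b: np.ndarray) -> np.ndarray:
    """Element-wise maximum of two consecutive frames (hypothetical helper)."""
    return np.maximum(frame_a, frame_b)


def push_frame(stack: np.ndarray, frame: np.ndarray) -> np.ndarray:
    """Drop the oldest frame (index 0) and place the newest at index -1."""
    # np.roll returns a new array, so the caller's stack is not modified
    stack = np.roll(stack, shift=-1, axis=0)
    stack[-1] = frame
    return stack


# A stack of four 2x2 "frames" filled with 1, 2, 3 and 4 (oldest first)
stack = np.stack([np.full((2, 2), v, dtype=np.float64) for v in (1, 2, 3, 4)])

# Two frames in which a different pixel is lit, as if a sprite flickered
flicker_a = np.array([[9.0, 0.0], [0.0, 0.0]])
flicker_b = np.array([[0.0, 0.0], [0.0, 7.0]])

# The merged frame keeps both lit pixels
merged = max_of_last_two(flicker_a, flicker_b)
stack = push_frame(stack, merged)
```

After the push, the frame that was filled with 1 is gone, the frame filled with 2 sits at index 0, and the merged frame sits at index -1, so the newest observation is always the last channel.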

Reset environment

Clears the episode information and the 4-frame stack

    def reset(self):

reset OpenAI Gym environment

        obs = self.env.reset()

reset caches

        obs = self._process_obs(obs)
        for i in range(4):
            self.obs_4[i] = obs
        self.rewards = []

        self.lives = self.env.unwrapped.ale.lives()

        return self.obs_4

Process game frames

Converts game frames to grayscale and resizes them to 84x84

    @staticmethod
    def _process_obs(obs):
        obs = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        obs = cv2.resize(obs, (84, 84), interpolation=cv2.INTER_AREA)
        return obs
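
To make the grayscale step concrete, here is a rough NumPy-only sketch of the weighted sum that OpenCV documents for RGB-to-gray conversion (Y = 0.299 R + 0.587 G + 0.114 B). It is an illustration rather than a replacement for `cv2.cvtColor`: the function name `rgb_to_gray` is hypothetical, rounding may differ slightly from OpenCV's fixed-point arithmetic, the 210x160 input size is an assumption about the raw Atari frame, and the resize to 84x84 is left to `cv2.resize` as in the wrapper.

```python
import numpy as np


def rgb_to_gray(frame: np.ndarray) -> np.ndarray:
    """Approximate RGB -> gray using the standard luma weights (hypothetical helper)."""
    weights = np.array([0.299, 0.587, 0.114])
    # (H, W, 3) @ (3,) -> (H, W): weighted sum over the colour channels
    gray = frame.astype(np.float64) @ weights
    return np.rint(gray).astype(np.uint8)


# Synthetic frame: assumed raw Atari size (210 rows, 160 columns), pure red
frame = np.zeros((210, 160, 3), dtype=np.uint8)
frame[:, :, 0] = 255

gray = rgb_to_gray(frame)
```

The result has one channel instead of three, which is what lets four frames be stacked into a (4, 84, 84) observation once they are resized.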

Worker Process

Each worker process runs this function

def worker_process(remote: multiprocessing.connection.Connection, seed: int):

create game

    game = Game(seed)

wait for instructions from the connection and execute them

    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            remote.send(game.step(data))
        elif cmd == "reset":
            remote.send(game.reset())
        elif cmd == "close":
            remote.close()
            break
        else:
            raise NotImplementedError

Creates a new worker and runs it in a separate process.

class Worker:
    def __init__(self, seed):
        self.child, parent = multiprocessing.Pipe()
        self.process = multiprocessing.Process(target=worker_process, args=(parent, seed))
        self.process.start()
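
The command protocol between the main process and a worker can be exercised without Atari installed. The sketch below is an assumption-laden illustration: `StubGame`, `stub_worker_process` and `run_demo` are hypothetical names, and the stub simply counts steps instead of running a game. What it does mirror from the code above is the message shape: every message is a `(cmd, data)` pair, so `"reset"` and `"close"` are sent with `None` as the data, and each `"step"` or `"reset"` is answered with exactly one reply.

```python
import multiprocessing
import multiprocessing.connection


class StubGame:
    """Hypothetical stand-in for Game: counts steps instead of running Atari."""

    def __init__(self, seed: int):
        self.seed = seed
        self.steps = 0

    def reset(self):
        self.steps = 0
        return self.steps

    def step(self, action):
        self.steps += 1
        # Same tuple layout as Game.step: (observation, reward, done, episode_info)
        return self.steps, float(action), False, None


def stub_worker_process(remote: multiprocessing.connection.Connection, seed: int):
    """Same command loop as worker_process, driving the stub game."""
    game = StubGame(seed)
    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            remote.send(game.step(data))
        elif cmd == "reset":
            remote.send(game.reset())
        elif cmd == "close":
            remote.close()
            break
        else:
            raise NotImplementedError


def run_demo():
    """Start one worker, reset it, take one step, then shut it down."""
    child, parent = multiprocessing.Pipe()
    process = multiprocessing.Process(target=stub_worker_process, args=(parent, 0))
    process.start()

    child.send(("reset", None))
    first = child.recv()

    child.send(("step", 1))
    second = child.recv()

    child.send(("close", None))
    process.join()
    return first, second


if __name__ == "__main__":
    print(run_demo())
```

With the real `Worker` class the same calls would go through `worker.child`, for example `worker.child.send(("step", action))` followed by `worker.child.recv()`. Sending to several workers first and only then collecting the replies lets the environments step in parallel.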