Deep Q Network (DQN) Model

import torch
from torch import nn

from labml_helpers.module import Module

Dueling Network ⚔️ Model for $Q$ Values

We are using a dueling network to calculate Q-values. The intuition behind the dueling network architecture is that in most states the action doesn't matter, while in some states the action is significant. The dueling network allows this to be represented very well.

So we create two networks for $V$ and $A$ and get $Q$ from them. We share the initial layers of the $V$ and $A$ networks.
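These are combined in the forward pass below as

$$Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$$

so the network only has to learn a single shared state value plus the relative advantages of the actions.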

class Model(Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(

The first convolution layer takes an $84\times84$ frame and produces a $20\times20$ frame

            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),

The second convolution layer takes a $20\times20$ frame and produces a $9\times9$ frame

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),

The third convolution layer takes a $9\times9$ frame and produces a $7\times7$ frame

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

A fully connected layer takes the flattened frame from the third convolution layer and outputs $512$ features

        self.lin = nn.Linear(in_features=7 * 7 * 64, out_features=512)
        self.activation = nn.ReLU()

This head gives the state value $V$

        self.state_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=1),
        )

This head gives the action value $A$

        self.action_value = nn.Sequential(
            nn.Linear(in_features=512, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=4),
        )
    def __call__(self, obs: torch.Tensor):

Convolution

        h = self.conv(obs)

Reshape for linear layers

        h = h.reshape((-1, 7 * 7 * 64))

Linear layer

        h = self.activation(self.lin(h))

$A$

        action_value = self.action_value(h)

$V$

        state_value = self.state_value(h)

$A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')$

Subtracting the mean advantage makes the decomposition identifiable: adding a constant to $A$ and subtracting it from $V$ would otherwise leave $Q$ unchanged.

        action_score_centered = action_value - action_value.mean(dim=-1, keepdim=True)

$Q(s, a) = V(s) + \Big(A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a' \in \mathcal{A}} A(s, a')\Big)$

        q = state_value + action_score_centered

        return q
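Below is a minimal, illustrative smoke test (not part of the original module) showing that the model produces one $Q$ value per action. It assumes labml_helpers is installed; the batch size of $2$ and the zero-filled observations are arbitrary.

if __name__ == '__main__':
    # Illustrative only: a batch of 2 observations,
    # each a stack of 4 grayscale 84x84 frames.
    model = Model()
    obs = torch.zeros((2, 4, 84, 84))
    q = model(obs)
    # One Q value per action: shape (2, 4)
    print(q.shape)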