Commit 7f456dc7 authored by LI Yunxiang, committed by Bo Zhou

add torch td3 (#176)

Parent: 6e7f862e
## Reproduce TD3 with PARL
Based on PARL, the TD3 algorithm of deep reinforcement learning has been reproduced, reaching the same level of performance as reported in the paper on the MuJoCo benchmarks.
The implementation includes the following techniques (summarized in the target update sketched after the paper reference below):
+ Clipped Double Q-learning
+ Target Networks and Delayed Policy Update
+ Target Policy Smoothing Regularization
> TD3 in
[Addressing Function Approximation Error in Actor-Critic Methods](https://arxiv.org/abs/1802.09477)
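For reference, the critic target implied by these three techniques can be summarized as follows. This is an illustrative sketch in the notation of the paper, not an excerpt from the commit; $\sigma$, $c$ and the delayed-update period correspond to the `policy_noise`, `noise_clip` and `policy_freq` arguments of the algorithm below, and $\tau$ to `tau`.
```
% Target policy smoothing: perturb the target action with clipped Gaussian noise
\tilde{a} = \mathrm{clip}\!\left( \pi_{\phi'}(s') + \epsilon,\; -a_{\max},\; a_{\max} \right),
\qquad \epsilon \sim \mathrm{clip}\!\left( \mathcal{N}(0, \sigma),\; -c,\; c \right)

% Clipped double Q-learning: regress both critics toward the smaller target estimate
y = r + \gamma \, (1 - d) \, \min_{i = 1, 2} Q_{\theta_i'}\!\left( s', \tilde{a} \right)
```
The actor is updated only every `policy_freq` critic updates, by ascending $Q_{\theta_1}(s, \pi_\phi(s))$, after which the target networks are moved toward the online networks with coefficient $\tau$.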
### MuJoCo environments introduction
Please see [here](https://github.com/openai/mujoco-py) to learn more about the MuJoCo environments and the `mujoco-py` bindings.
### Benchmark result
<img src=".benchmark/merge.png" width = "1500" height ="260" alt="Performance" />
## How to use
### Dependencies:
+ python
+ [parl](https://github.com/PaddlePaddle/PARL)
+ gym
+ torch
+ mujoco-py>=1.50.1.0
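A typical way to install the Python dependencies is via pip, as sketched below. This is an assumed setup rather than part of the commit; note that `mujoco-py` additionally requires the MuJoCo binaries (and, for these older versions, a license key) to be installed locally, as described at the link above.
```
# assumed setup commands, not part of this commit
pip install parl gym torch
pip install 'mujoco-py>=1.50.1.0'
```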
### Start Training:
```
# To train an agent on the HalfCheetah-v2 environment
python train.py

# To train on a different environment
# python train.py --env [ENV_NAME]
```
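For example, `python train.py --env Hopper-v2` trains the same agent on the Hopper task; `Hopper-v2` is only an illustration, and any MuJoCo environment name registered in gym can be passed through `--env`.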
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import torch
import numpy as np
class MujocoAgent(parl.Agent):
def __init__(self, algorithm, obs_dim, act_dim):
assert isinstance(obs_dim, int)
assert isinstance(act_dim, int)
super(MujocoAgent, self).__init__(algorithm)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Hard-copy the model parameters into the target model before training starts (decay=0).
        self.alg.sync_target(decay=0)
def predict(self, obs):
obs = torch.FloatTensor(obs.reshape(1, -1)).to(self.device)
return self.alg.predict(obs).cpu().data.numpy().flatten()
def learn(self, obs, act, reward, next_obs, terminal):
terminal = np.expand_dims(terminal, -1)
reward = np.expand_dims(reward, -1)
obs = torch.FloatTensor(obs).to(self.device)
act = torch.FloatTensor(act).to(self.device)
reward = torch.FloatTensor(reward).to(self.device)
next_obs = torch.FloatTensor(next_obs).to(self.device)
terminal = torch.FloatTensor(terminal).to(self.device)
self.alg.learn(obs, act, reward, next_obs, terminal)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import torch
import torch.nn as nn
import torch.nn.functional as F
class MujocoModel(parl.Model):
def __init__(self, obs_dim, act_dim, max_action):
super(MujocoModel, self).__init__()
self.actor_model = Actor(obs_dim, act_dim, max_action)
self.critic_model = Critic(obs_dim, act_dim)
def policy(self, obs):
return self.actor_model(obs)
def value(self, obs, act):
return self.critic_model(obs, act)
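    # Q1 returns only the first critic's value estimate; TD3 uses it for the delayed actor update.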
def Q1(self, obs, act):
return self.critic_model.Q1(obs, act)
def get_actor_params(self):
return self.actor_model.parameters()
def get_critic_params(self):
return self.critic_model.parameters()
class Actor(parl.Model):
def __init__(self, obs_dim, action_dim, max_action):
super(Actor, self).__init__()
self.l1 = nn.Linear(obs_dim, 256)
self.l2 = nn.Linear(256, 256)
self.l3 = nn.Linear(256, action_dim)
self.max_action = max_action
def forward(self, obs):
a = F.relu(self.l1(obs))
a = F.relu(self.l2(a))
return self.max_action * torch.tanh(self.l3(a))
class Critic(parl.Model):
def __init__(self, obs_dim, action_dim):
super(Critic, self).__init__()
# Q1 architecture
self.l1 = nn.Linear(obs_dim + action_dim, 256)
self.l2 = nn.Linear(256, 256)
self.l3 = nn.Linear(256, 1)
# Q2 architecture
self.l4 = nn.Linear(obs_dim + action_dim, 256)
self.l5 = nn.Linear(256, 256)
self.l6 = nn.Linear(256, 1)
def forward(self, obs, action):
sa = torch.cat([obs, action], 1)
q1 = F.relu(self.l1(sa))
q1 = F.relu(self.l2(q1))
q1 = self.l3(q1)
q2 = F.relu(self.l4(sa))
q2 = F.relu(self.l5(q2))
q2 = self.l6(q2)
return q1, q2
def Q1(self, obs, action):
sa = torch.cat([obs, action], 1)
q1 = F.relu(self.l1(sa))
q1 = F.relu(self.l2(q1))
q1 = self.l3(q1)
return q1
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gym
import argparse
import numpy as np
from parl.utils import logger, tensorboard, ReplayMemory
from mujoco_model import MujocoModel
from mujoco_agent import MujocoAgent
from parl.algorithms import TD3
MAX_EPISODES = 5000
ACTOR_LR = 3e-4
CRITIC_LR = 3e-4
GAMMA = 0.99
TAU = 0.005
MEMORY_SIZE = int(1e6)
WARMUP_SIZE = 1e4
BATCH_SIZE = 256
ENV_SEED = 1
EXPL_NOISE = 0.1 # Std of Gaussian exploration noise
def run_train_episode(env, agent, rpm):
obs = env.reset()
total_reward = 0
steps = 0
max_action = float(env.action_space.high[0])
while True:
steps += 1
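        # Act randomly until WARMUP_SIZE transitions are collected, then add Gaussian exploration noise to the policy's action.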
if rpm.size() < WARMUP_SIZE:
action = env.action_space.sample()
else:
action = np.random.normal(
agent.predict(np.array(obs)), max_action * EXPL_NOISE).clip(
-max_action, max_action)
next_obs, reward, done, info = env.step(action)
rpm.append(obs, action, reward, next_obs, done)
obs = next_obs
total_reward += reward
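        # After warm-up, take one gradient step per environment step on a sampled mini-batch.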
if rpm.size() > WARMUP_SIZE:
batch_obs, batch_action, batch_reward, batch_next_obs, batch_terminal = rpm.sample_batch(
BATCH_SIZE)
agent.learn(batch_obs, batch_action, batch_reward, batch_next_obs,
batch_terminal)
if done:
break
return total_reward, steps
def run_evaluate_episode(env, agent):
obs = env.reset()
total_reward = 0
while True:
action = agent.predict(np.array(obs))
obs, reward, done, _ = env.step(action)
total_reward += reward
if done:
break
return total_reward
def main():
env = gym.make(args.env)
env.seed(ENV_SEED)
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])
model = MujocoModel(obs_dim, act_dim, max_action)
algorithm = TD3(
model,
max_action=max_action,
gamma=GAMMA,
tau=TAU,
actor_lr=ACTOR_LR,
critic_lr=CRITIC_LR)
agent = MujocoAgent(algorithm, obs_dim, act_dim)
rpm = ReplayMemory(MEMORY_SIZE, obs_dim, act_dim)
test_flag = 0
total_steps = 0
while total_steps < args.train_total_steps:
train_reward, steps = run_train_episode(env, agent, rpm)
total_steps += steps
logger.info('Steps: {} Reward: {}'.format(total_steps, train_reward))
tensorboard.add_scalar('train/episode_reward', train_reward,
total_steps)
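        # Evaluate roughly every test_every_steps environment steps; the inner loop catches up when a long episode skips past several evaluation points.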
if total_steps // args.test_every_steps >= test_flag:
while total_steps // args.test_every_steps >= test_flag:
test_flag += 1
evaluate_reward = run_evaluate_episode(env, agent)
logger.info('Steps {}, Evaluate reward: {}'.format(
total_steps, evaluate_reward))
tensorboard.add_scalar('eval/episode_reward', evaluate_reward,
total_steps)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--env', help='Mujoco environment name', default='HalfCheetah-v2')
parser.add_argument(
'--train_total_steps',
type=int,
default=int(3e6),
help='maximum training steps')
parser.add_argument(
'--test_every_steps',
type=int,
default=int(1e4),
help='the step interval between two consecutive evaluations')
args = parser.parse_args()
main()
@@ -15,3 +15,4 @@
from parl.algorithms.torch.ddqn import *
from parl.algorithms.torch.dqn import *
from parl.algorithms.torch.a2c import *
from parl.algorithms.torch.td3 import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from copy import deepcopy
import parl
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
__all__ = ['TD3']
class TD3(parl.Algorithm):
def __init__(
self,
model,
max_action,
gamma=None,
tau=None,
actor_lr=None,
critic_lr=None,
policy_noise=0.2, # Noise added to target policy during critic update
noise_clip=0.5, # Range to clip target policy noise
policy_freq=2): # Frequency of delayed policy updates
assert isinstance(gamma, float)
assert isinstance(tau, float)
assert isinstance(actor_lr, float)
assert isinstance(critic_lr, float)
self.max_action = max_action
self.gamma = gamma
self.tau = tau
self.actor_lr = actor_lr
self.critic_lr = critic_lr
self.policy_noise = policy_noise
self.noise_clip = noise_clip
self.policy_freq = policy_freq
self.model = model.to(device)
self.target_model = deepcopy(model).to(device)
self.actor_optimizer = torch.optim.Adam(
self.model.get_actor_params(), lr=actor_lr)
self.critic_optimizer = torch.optim.Adam(
self.model.get_critic_params(), lr=critic_lr)
self.total_it = 0
def predict(self, obs):
return self.model.policy(obs)
def learn(self, obs, action, reward, next_obs, terminal):
self.total_it += 1
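        # Build the critic target: smooth the target action with clipped noise and take the minimum of the two target critics (clipped double Q-learning).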
with torch.no_grad():
noise = (torch.randn_like(action) * self.policy_noise).clamp(
-self.noise_clip, self.noise_clip)
next_action = (self.target_model.policy(next_obs) + noise).clamp(
-self.max_action, self.max_action)
target_Q1, target_Q2 = self.target_model.value(
next_obs, next_action)
target_Q = torch.min(target_Q1, target_Q2)
target_Q = reward + (1 - terminal) * self.gamma * target_Q
current_Q1, current_Q2 = self.model.value(obs, action)
critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
current_Q2, target_Q)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
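        # Delayed policy update: refresh the actor and soft-update the target networks only every policy_freq critic updates.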
if self.total_it % self.policy_freq == 0:
actor_loss = -self.model.Q1(obs, self.model.policy(obs)).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
self.sync_target()
def sync_target(self, decay=None):
if decay is None:
decay = 1.0 - self.tau
for param, target_param in zip(self.model.parameters(),
self.target_model.parameters()):
target_param.data.copy_((1 - decay) * param.data +
decay * target_param.data)