提交 53c94787 编写于 作者: H Hongsheng Zeng 提交者: Bo Zhou

Add DDPG example (#36)

* add DDPG example, fix some tiny bug

* add license

* unify code structure

* unify code structure

* refine gputils, fix seed in QuickStart

* use white noise in DDPG

* fix codestyle
上级 58e8fe28
......@@ -75,6 +75,6 @@ pip install --upgrade git+https://github.com/PaddlePaddle/PARL.git
# Examples
- [QuickStart](examples/QuickStart/)
- [DQN](examples/DQN/)
- DDPG
- [DDPG](examples/DDPG/)
- PPO
- [Winning Solution for NIPS2018: AI for Prosthetics Challenge](examples/NeurIPS2018-AI-for-Prosthetics-Challenge/)
## Reproduce DDPG with PARL
Based on PARL, the DDPG model of deep reinforcement learning is reproduced, and the same level of indicators of the paper is reproduced in the classic Mujoco game.
+ DDPG in
[Continuous control with deep reinforcement learning](https://arxiv.org/abs/1509.02971)
### Mujoco games introduction
Please see [here](https://github.com/openai/mujoco-py) to know more about Mujoco game.
## How to use
### Dependencies:
+ python2.7 or python3.5+
+ [PARL](https://github.com/PaddlePaddle/PARL)
+ [paddlepaddle>=1.0.0](https://github.com/PaddlePaddle/Paddle)
+ gym
+ tqdm
+ mujoco-py>=1.50.1.0
### Start Training:
```
# To train an agent for HalfCheetah-v2 game
python train.py
# To train for other game
# python train.py --env [ENV_NAME]
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import parl.layers as layers
from paddle import fluid
from parl.framework.agent_base import Agent
class MujocoAgent(Agent):
def __init__(self, algorithm, obs_dim, act_dim):
self.obs_dim = obs_dim
self.act_dim = act_dim
super(MujocoAgent, self).__init__(algorithm)
# Attention: In the beginning, sync target model totally.
self.alg.sync_target(gpu_id=self.gpu_id, decay=0)
def build_program(self):
self.pred_program = fluid.Program()
self.learn_program = fluid.Program()
with fluid.program_guard(self.pred_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
self.pred_act = self.alg.define_predict(obs)
with fluid.program_guard(self.learn_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
act = layers.data(
name='act', shape=[self.act_dim], dtype='float32')
reward = layers.data(name='reward', shape=[], dtype='float32')
next_obs = layers.data(
name='next_obs', shape=[self.obs_dim], dtype='float32')
terminal = layers.data(name='terminal', shape=[], dtype='bool')
_, self.critic_cost = self.alg.define_learn(
obs, act, reward, next_obs, terminal)
def predict(self, obs):
obs = np.expand_dims(obs, axis=0)
act = self.fluid_executor.run(
self.pred_program, feed={'obs': obs},
fetch_list=[self.pred_act])[0]
return act
def learn(self, obs, act, reward, next_obs, terminal):
feed = {
'obs': obs,
'act': act,
'reward': reward,
'next_obs': next_obs,
'terminal': terminal
}
critic_cost = self.fluid_executor.run(
self.learn_program, feed=feed, fetch_list=[self.critic_cost])[0]
self.alg.sync_target(gpu_id=self.gpu_id)
return critic_cost
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.model_base import Model
class MujocoModel(Model):
def __init__(self, act_dim, act_bound):
self.actor_model = ActorModel(act_dim, act_bound)
self.critic_model = CriticModel()
def policy(self, obs):
return self.actor_model.policy(obs)
def value(self, obs, act):
return self.critic_model.value(obs, act)
def get_actor_params(self):
return self.actor_model.parameter_names
class ActorModel(Model):
def __init__(self, act_dim, act_bound):
self.act_bound = act_bound
hid1_size = 400
hid2_size = 300
self.fc1 = layers.fc(size=hid1_size, act='relu')
self.fc2 = layers.fc(size=hid2_size, act='relu')
self.fc3 = layers.fc(size=act_dim, act='tanh')
def policy(self, obs):
hid1 = self.fc1(obs)
hid2 = self.fc2(hid1)
means = self.fc3(hid2)
means = means * self.act_bound
return means
class CriticModel(Model):
def __init__(self):
hid1_size = 400
hid2_size = 300
self.fc1 = layers.fc(size=hid1_size, act='relu')
self.fc2 = layers.fc(size=hid2_size, act='relu')
self.fc3 = layers.fc(size=1, act=None)
def value(self, obs, act):
hid1 = self.fc1(obs)
concat = layers.concat([hid1, act], axis=1)
hid2 = self.fc2(concat)
Q = self.fc3(hid2)
Q = layers.squeeze(Q, axes=[1])
return Q
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
class ReplayMemory(object):
def __init__(self, max_size, obs_dim, act_dim):
self.max_size = max_size
self.obs_memory = np.zeros((max_size, obs_dim), dtype='float32')
self.act_memory = np.zeros((max_size, act_dim), dtype='float32')
self.reward_memory = np.zeros((max_size, ), dtype='float32')
self.next_obs_memory = np.zeros((max_size, obs_dim), dtype='float32')
self.terminal_memory = np.zeros((max_size, ), dtype='bool')
self._curr_size = 0
self._curr_pos = 0
def sample_batch(self, batch_size):
batch_idx = np.random.choice(self._curr_size, size=batch_size)
obs = self.obs_memory[batch_idx, :]
act = self.act_memory[batch_idx, :]
reward = self.reward_memory[batch_idx]
next_obs = self.next_obs_memory[batch_idx, :]
terminal = self.terminal_memory[batch_idx]
return obs, act, reward, next_obs, terminal
def append(self, obs, act, reward, next_obs, terminal):
if self._curr_size < self.max_size:
self._curr_size += 1
self.obs_memory[self._curr_pos] = obs
self.act_memory[self._curr_pos] = act
self.reward_memory[self._curr_pos] = reward
self.next_obs_memory[self._curr_pos] = next_obs
self.terminal_memory[self._curr_pos] = terminal
self._curr_pos = (self._curr_pos + 1) % self.max_size
def size(self):
return self._curr_size
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gym
import numpy as np
import time
from mujoco_agent import MujocoAgent
from mujoco_model import MujocoModel
from parl.algorithms import DDPG
from parl.utils import logger
from replay_memory import ReplayMemory
MAX_EPISODES = 5000
TEST_EVERY_EPISODES = 50
MAX_STEPS_EACH_EPISODE = 1000
ACTOR_LR = 1e-4
CRITIC_LR = 1e-3
GAMMA = 0.99
TAU = 0.001
MEMORY_SIZE = int(1e6)
MIN_LEARN_SIZE = 1e4
BATCH_SIZE = 128
REWARD_SCALE = 0.1
ENV_SEED = 1
def run_train_episode(env, agent, rpm, act_bound):
obs = env.reset()
total_reward = 0
for j in range(MAX_STEPS_EACH_EPISODE):
batch_obs = np.expand_dims(obs, axis=0)
action = agent.predict(batch_obs.astype('float32'))
action = np.squeeze(action)
# Add exploration noise
action = np.clip(
np.random.normal(action, act_bound), -act_bound, act_bound)
next_obs, reward, done, info = env.step(action)
rpm.append(obs, action, REWARD_SCALE * reward, next_obs, done)
if rpm.size() > MIN_LEARN_SIZE:
batch_obs, batch_action, batch_reward, batch_next_obs, batch_terminal = rpm.sample_batch(
BATCH_SIZE)
agent.learn(batch_obs, batch_action, batch_reward, batch_next_obs,
batch_terminal)
obs = next_obs
total_reward += reward
if done:
break
return total_reward
def run_evaluate_episode(env, agent):
obs = env.reset()
total_reward = 0
for j in range(MAX_STEPS_EACH_EPISODE):
batch_obs = np.expand_dims(obs, axis=0)
action = agent.predict(batch_obs.astype('float32'))
action = np.squeeze(action)
next_obs, reward, done, info = env.step(action)
obs = next_obs
total_reward += reward
if done:
break
return total_reward
def main():
env = gym.make(args.env)
env.seed(ENV_SEED)
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
act_bound = env.action_space.high[0]
model = MujocoModel(act_dim, act_bound)
algorithm = DDPG(
model,
hyperparas={
'gamma': GAMMA,
'tau': TAU,
'actor_lr': ACTOR_LR,
'critic_lr': CRITIC_LR
})
agent = MujocoAgent(algorithm, obs_dim, act_dim)
rpm = ReplayMemory(MEMORY_SIZE, obs_dim, act_dim)
for i in range(MAX_EPISODES):
train_reward = run_train_episode(env, agent, rpm, act_bound)
logger.info('Episode: {} Reward: {}'.format(i, train_reward))
if (i + 1) % TEST_EVERY_EPISODES == 0:
evaluate_reward = run_evaluate_episode(env, agent)
logger.info('Evaluate Reward: {}'.format(evaluate_reward))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--env', help='Mujoco environment name', default='HalfCheetah-v2')
args = parser.parse_args()
main()
......@@ -32,7 +32,7 @@ class AtariAgent(Agent):
def build_program(self):
self.pred_program = fluid.Program()
self.train_program = fluid.Program()
self.learn_program = fluid.Program()
with fluid.program_guard(self.pred_program):
obs = layers.data(
......@@ -41,7 +41,7 @@ class AtariAgent(Agent):
dtype='float32')
self.value = self.alg.define_predict(obs)
with fluid.program_guard(self.train_program):
with fluid.program_guard(self.learn_program):
obs = layers.data(
name='obs',
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
......@@ -99,5 +99,5 @@ class AtariAgent(Agent):
'terminal': terminal
}
cost = self.fluid_executor.run(
self.train_program, feed=feed, fetch_list=[self.cost])[0]
self.learn_program, feed=feed, fetch_list=[self.cost])[0]
return cost
......@@ -18,9 +18,7 @@ from parl.framework.model_base import Model
class AtariModel(Model):
def __init__(self, img_height, img_width, act_dim):
self.img_height = img_height
self.img_width = img_width
def __init__(self, act_dim):
self.act_dim = act_dim
self.conv1 = layers.conv2d(
......
......@@ -88,6 +88,9 @@ class ReplayMemory(object):
def __len__(self):
return self._curr_size
def size(self):
return self._curr_size
def _assign(self, pos, exp):
self.state[pos] = exp.state
self.reward[pos] = exp.reward
......
......@@ -13,49 +13,47 @@
# limitations under the License.
import argparse
import cv2
import gym
import paddle.fluid as fluid
import numpy as np
import os
from atari import AtariPlayer
from atari_agent import AtariAgent
from atari_model import AtariModel
from atari_wrapper import FrameStack, MapState, FireResetEnv, LimitLength
from collections import deque
from datetime import datetime
from expreplay import ReplayMemory, Experience
from replay_memory import ReplayMemory, Experience
from parl.algorithms import DQN
from parl.utils import logger
from tqdm import tqdm
from utils import get_player
MEMORY_SIZE = 1e6
MEMORY_WARMUP_SIZE = MEMORY_SIZE // 20
IMAGE_SIZE = (84, 84)
CONTEXT_LEN = 4
ACTION_REPEAT = 4 # aka FRAME_SKIP
FRAME_SKIP = 4
UPDATE_FREQ = 4
GAMMA = 0.99
LEARNING_RATE = 1e-3 * 0.5
def run_train_episode(agent, env, exp):
def run_train_episode(env, agent, rpm):
total_reward = 0
all_cost = []
state = env.reset()
step = 0
while True:
step += 1
context = exp.recent_state()
context = rpm.recent_state()
context.append(state)
context = np.stack(context, axis=0)
action = agent.sample(context)
next_state, reward, isOver, _ = env.step(action)
exp.append(Experience(state, action, reward, isOver))
rpm.append(Experience(state, action, reward, isOver))
# start training
if len(exp) > MEMORY_WARMUP_SIZE:
if rpm.size() > MEMORY_WARMUP_SIZE:
if step % UPDATE_FREQ == 0:
batch_all_state, batch_action, batch_reward, batch_isOver = exp.sample_batch(
batch_all_state, batch_action, batch_reward, batch_isOver = rpm.sample_batch(
args.batch_size)
batch_state = batch_all_state[:, :CONTEXT_LEN, :, :]
batch_next_state = batch_all_state[:, 1:, :, :]
......@@ -71,43 +69,27 @@ def run_train_episode(agent, env, exp):
return total_reward, step
def get_player(rom, viz=False, train=False):
env = AtariPlayer(
rom,
frame_skip=ACTION_REPEAT,
viz=viz,
live_lost_as_eoe=train,
max_num_frames=60000)
env = FireResetEnv(env)
env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
if not train:
# in training, context is taken care of in expreplay buffer
env = FrameStack(env, CONTEXT_LEN)
return env
def eval_agent(agent, env):
episode_reward = []
for _ in tqdm(range(30), desc='eval agent'):
state = env.reset()
total_reward = 0
step = 0
while True:
step += 1
action = agent.predict(state)
state, reward, isOver, info = env.step(action)
total_reward += reward
if isOver:
break
episode_reward.append(total_reward)
eval_reward = np.mean(episode_reward)
return eval_reward
def train_agent():
env = get_player(args.rom, train=True)
test_env = get_player(args.rom)
exp = ReplayMemory(MEMORY_SIZE, IMAGE_SIZE, CONTEXT_LEN)
def run_evaluate_episode(env, agent):
state = env.reset()
total_reward = 0
while True:
action = agent.predict(state)
state, reward, isOver, info = env.step(action)
total_reward += reward
if isOver:
break
return total_reward
def main():
env = get_player(
args.rom, image_size=IMAGE_SIZE, train=True, frame_skip=FRAME_SKIP)
test_env = get_player(
args.rom,
image_size=IMAGE_SIZE,
frame_skip=FRAME_SKIP,
context_len=CONTEXT_LEN)
rpm = ReplayMemory(MEMORY_SIZE, IMAGE_SIZE, CONTEXT_LEN)
action_dim = env.action_space.n
hyperparas = {
......@@ -115,13 +97,13 @@ def train_agent():
'lr': LEARNING_RATE,
'gamma': GAMMA
}
model = AtariModel(IMAGE_SIZE[0], IMAGE_SIZE[1], action_dim)
model = AtariModel(action_dim)
algorithm = DQN(model, hyperparas)
agent = AtariAgent(algorithm, action_dim)
with tqdm(total=MEMORY_WARMUP_SIZE) as pbar:
while len(exp) < MEMORY_WARMUP_SIZE:
total_reward, step = run_train_episode(agent, env, exp)
while rpm.size() < MEMORY_WARMUP_SIZE:
total_reward, step = run_train_episode(env, agent, rpm)
pbar.update(step)
# train
......@@ -132,18 +114,23 @@ def train_agent():
max_reward = None
while True:
# start epoch
total_reward, step = run_train_episode(agent, env, exp)
total_reward, step = run_train_episode(env, agent, rpm)
total_step += step
pbar.set_description('[train]exploration:{}'.format(agent.exploration))
pbar.update(step)
if total_step // args.test_every_steps == test_flag:
pbar.write("testing")
eval_reward = eval_agent(agent, test_env)
eval_rewards = []
for _ in tqdm(range(30), desc='eval agent'):
eval_reward = run_evaluate_episode(test_env, agent)
eval_rewards.append(eval_reward)
test_flag += 1
logger.info(
"eval_agent done, (steps, eval_reward): ({}, {})".format(
total_step, eval_reward))
total_step, np.mean(eval_rewards)))
if total_step > 1e8:
break
pbar.close()
......@@ -159,4 +146,4 @@ if __name__ == '__main__':
default=100000,
help='every steps number to run test')
args = parser.parse_args()
train_agent()
main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
from atari import AtariPlayer
from atari_wrapper import FrameStack, MapState, FireResetEnv
def get_player(rom,
image_size,
viz=False,
train=False,
frame_skip=1,
context_len=1):
env = AtariPlayer(
rom,
frame_skip=frame_skip,
viz=viz,
live_lost_as_eoe=train,
max_num_frames=60000)
env = FireResetEnv(env)
env = MapState(env, lambda im: cv2.resize(im, image_size))
if not train:
# in training, context is taken care of in expreplay buffer
env = FrameStack(env, context_len)
return env
......@@ -24,6 +24,7 @@ pip install .
cd examples/QuickStart/
python train.py
# Or visualize when evaluating: python train.py --eval_vis
```
### Result
After training, you will see the agent get the best score (200 points).
......@@ -19,15 +19,19 @@ from parl.framework.agent_base import Agent
class CartpoleAgent(Agent):
def __init__(self, algorithm, obs_dim, act_dim):
def __init__(self, algorithm, obs_dim, act_dim, seed=1):
self.obs_dim = obs_dim
self.act_dim = act_dim
self.seed = seed
super(CartpoleAgent, self).__init__(algorithm)
def build_program(self):
self.pred_program = fluid.Program()
self.train_program = fluid.Program()
fluid.default_startup_program().random_seed = self.seed
self.train_program.random_seed = self.seed
with fluid.program_guard(self.pred_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
......
......@@ -19,11 +19,13 @@ from cartpole_agent import CartpoleAgent
from cartpole_model import CartpoleModel
from parl.algorithms import PolicyGradient
from parl.utils import logger
from utils import calc_discount_norm_reward
OBS_DIM = 4
ACT_DIM = 2
GAMMA = 0.99
LEARNING_RATE = 1e-3
SEED = 1
def run_train_episode(env, agent):
......@@ -56,32 +58,21 @@ def run_evaluate_episode(env, agent):
return all_reward
def calc_discount_norm_reward(reward_list):
discount_norm_reward = np.zeros_like(reward_list)
discount_cumulative_reward = 0
for i in reversed(range(0, len(reward_list))):
discount_cumulative_reward = (
GAMMA * discount_cumulative_reward + reward_list[i])
discount_norm_reward[i] = discount_cumulative_reward
discount_norm_reward = discount_norm_reward - np.mean(discount_norm_reward)
discount_norm_reward = discount_norm_reward / np.std(discount_norm_reward)
return discount_norm_reward
def main():
env = gym.make("CartPole-v0")
env.seed(SEED)
np.random.seed(SEED)
model = CartpoleModel(act_dim=ACT_DIM)
alg = PolicyGradient(model, hyperparas={'lr': LEARNING_RATE})
agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM)
agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM, seed=SEED)
for i in range(500):
for i in range(1000):
obs_list, action_list, reward_list = run_train_episode(env, agent)
logger.info("Episode {}, Reward Sum {}.".format(i, sum(reward_list)))
batch_obs = np.array(obs_list)
batch_action = np.array(action_list)
batch_reward = calc_discount_norm_reward(reward_list)
batch_reward = calc_discount_norm_reward(reward_list, GAMMA)
agent.learn(batch_obs, batch_action, batch_reward)
if (i + 1) % 100 == 0:
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
def calc_discount_norm_reward(reward_list, gamma):
discount_norm_reward = np.zeros_like(reward_list)
discount_cumulative_reward = 0
for i in reversed(range(0, len(reward_list))):
discount_cumulative_reward = (
gamma * discount_cumulative_reward + reward_list[i])
discount_norm_reward[i] = discount_cumulative_reward
discount_norm_reward = discount_norm_reward - np.mean(discount_norm_reward)
discount_norm_reward = discount_norm_reward / np.std(discount_norm_reward)
return discount_norm_reward
......@@ -14,3 +14,4 @@
from parl.algorithms.dqn import *
from parl.algorithms.policy_gradient import *
from parl.algorithms.ddpg import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl.layers as layers
from copy import deepcopy
from paddle import fluid
from parl.framework.algorithm_base import Algorithm
__all__ = ['DDPG']
class DDPG(Algorithm):
def __init__(self, model, hyperparas):
""" model: should implement the function get_actor_params()
"""
Algorithm.__init__(self, model, hyperparas)
self.model = model
self.target_model = deepcopy(model)
# fetch hyper parameters
self.gamma = hyperparas['gamma']
self.tau = hyperparas['tau']
self.actor_lr = hyperparas['actor_lr']
self.critic_lr = hyperparas['critic_lr']
def define_predict(self, obs):
""" use actor model of self.model to predict the action
"""
return self.model.policy(obs)
def define_learn(self, obs, action, reward, next_obs, terminal):
""" update actor and critic model with DDPG algorithm
"""
actor_cost = self._actor_learn(obs)
critic_cost = self._critic_learn(obs, action, reward, next_obs,
terminal)
return actor_cost, critic_cost
def _actor_learn(self, obs):
action = self.model.policy(obs)
Q = self.model.value(obs, action)
cost = layers.reduce_mean(-1.0 * Q)
optimizer = fluid.optimizer.AdamOptimizer(self.actor_lr)
optimizer.minimize(cost, parameter_list=self.model.get_actor_params())
return cost
def _critic_learn(self, obs, action, reward, next_obs, terminal):
next_action = self.target_model.policy(next_obs)
next_Q = self.target_model.value(next_obs, next_action)
terminal = layers.cast(terminal, dtype='float32')
target_Q = reward + (1.0 - terminal) * self.gamma * next_Q
target_Q.stop_gradient = True
Q = self.model.value(obs, action)
cost = layers.square_error_cost(Q, target_Q)
cost = layers.reduce_mean(cost)
optimizer = fluid.optimizer.AdamOptimizer(self.critic_lr)
optimizer.minimize(cost)
return cost
def sync_target(self, gpu_id, decay=None):
if decay is None:
decay = 1.0 - self.tau
self.model.sync_params_to(
self.target_model, gpu_id=gpu_id, decay=decay)
......@@ -32,19 +32,21 @@ def get_gpu_count():
if env_cuda_devices is not None:
assert isinstance(env_cuda_devices, str)
try:
if not env_cuda_devices:
return 0
gpu_count = len(
[x for x in env_cuda_devices.split(',') if int(x) >= 0])
logger.info(
'CUDA_VISIBLE_DEVICES found gpu count: {}'.format(gpu_count))
except Exception as e:
logger.error(e.message)
except:
logger.warn('Cannot find available GPU devices, using CPU now.')
gpu_count = 0
else:
try:
gpu_count = str(subprocess.check_output(["nvidia-smi",
"-L"])).count('UUID')
logger.info('nvidia-smi -L found gpu count: {}'.format(gpu_count))
except Exception as e:
logger.error(e.message)
except:
logger.warn('Cannot find available GPU devices, using CPU now.')
gpu_count = 0
return gpu_count
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册