未验证 提交 6efa7871 编写于 作者: B Bo Zhou 提交者: GitHub

breaking changes#1 (#95)

* intra-version: move parl.framework into parl.core.fluid

* add folder: parl.core

* remove former test folders

* yapf

* yapf0.24
上级 ee3e8dc2
......@@ -16,10 +16,10 @@ import gym
import numpy as np
import parl
import six
import parl
from atari_model import AtariModel
from collections import defaultdict
from atari_agent import AtariAgent
from parl.algorithms import A3C
from parl.env.atari_wrappers import wrap_deepmind, MonitorEnv, get_wrapper_by_cls
from parl.env.vector_env import VectorEnv
from parl.utils.rl_utils import calc_gae
......@@ -46,8 +46,14 @@ class Actor(object):
self.config['act_dim'] = act_dim
model = AtariModel(act_dim)
algorithm = A3C(model, hyperparas=config)
self.agent = AtariAgent(algorithm, config)
algorithm = parl.algorithms.A3C(
model, vf_loss_coeff=config['vf_loss_coeff'])
self.agent = AtariAgent(
algorithm,
obs_shape=self.config['obs_shape'],
lr_scheduler=self.config['lr_scheduler'],
entropy_coeff_scheduler=self.config['entropy_coeff_scheduler'],
)
def sample(self):
sample_data = defaultdict(list)
......@@ -112,8 +118,8 @@ class Actor(object):
metrics['episode_steps'].append(episode_steps)
return metrics
def set_params(self, params):
self.agent.set_params(params)
def set_weights(self, params):
self.agent.set_weights(params)
if __name__ == '__main__':
......
......@@ -14,22 +14,37 @@
import numpy as np
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.agent_base import Agent
import parl
from parl import layers
from parl.utils import machine_info
from parl.utils.scheduler import PiecewiseScheduler, LinearDecayScheduler
class AtariAgent(Agent):
def __init__(self, algorithm, config):
self.config = config
class AtariAgent(parl.Agent):
def __init__(self, algorithm, obs_shape, lr_scheduler,
entropy_coeff_scheduler):
"""
Args:
algorithm (`parl.Algorithm`): a2c algorithm
obs_shape (list/tuple): observation shape of atari environment
lr_scheduler (list/tuple): learning rate adjustment schedule: (train_step, learning_rate)
entropy_coeff_scheduler (list/tuple): coefficient of policy entropy adjustment schedule: (train_step, coefficient)
"""
assert isinstance(obs_shape, (list, tuple))
assert isinstance(lr_scheduler, (list, tuple))
assert isinstance(entropy_coeff_scheduler, (list, tuple))
self.obs_shape = obs_shape
self.lr_scheduler = lr_scheduler
self.entropy_coeff_scheduler = entropy_coeff_scheduler
super(AtariAgent, self).__init__(algorithm)
self.lr_scheduler = LinearDecayScheduler(config['start_lr'],
config['max_sample_steps'])
self.entropy_coeff_scheduler = PiecewiseScheduler(
config['entropy_coeff_scheduler'])
use_cuda = True if self.gpu_id >= 0 else False
self.entropy_coeff_scheduler = PiecewiseScheduler(
self.entropy_coeff_scheduler)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = True
......@@ -39,7 +54,7 @@ class AtariAgent(Agent):
# Use ParallelExecutor to make learn program run faster
self.learn_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
use_cuda=machine_info.is_gpu_available(),
main_program=self.learn_program,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
......@@ -52,23 +67,23 @@ class AtariAgent(Agent):
with fluid.program_guard(self.sample_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
name='obs', shape=self.obs_shape, dtype='float32')
sample_actions, values = self.alg.sample(obs)
self.sample_outputs = [sample_actions, values]
with fluid.program_guard(self.predict_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
name='obs', shape=self.obs_shape, dtype='float32')
self.predict_actions = self.alg.predict(obs)
with fluid.program_guard(self.value_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
name='obs', shape=self.obs_shape, dtype='float32')
self.values = self.alg.value(obs)
with fluid.program_guard(self.learn_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
name='obs', shape=self.obs_shape, dtype='float32')
actions = layers.data(name='actions', shape=[], dtype='int64')
advantages = layers.data(
name='advantages', shape=[], dtype='float32')
......
......@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.model_base import Model
from paddle.fluid.param_attr import ParamAttr
from parl import layers
class AtariModel(Model):
class AtariModel(parl.Model):
def __init__(self, act_dim):
self.conv1 = layers.conv2d(
......
......@@ -19,16 +19,17 @@ import queue
import six
import time
import threading
import parl
from atari_model import AtariModel
from atari_agent import AtariAgent
from collections import defaultdict
from parl import RemoteManager
from parl.algorithms import A3C
from parl.env.atari_wrappers import wrap_deepmind
from parl.utils import logger, CSVLogger, get_gpu_count
from parl.utils.scheduler import PiecewiseScheduler
from parl.utils.time_stat import TimeStat
from parl.utils.window_stat import WindowStat
from parl.utils import machine_info
class Learner(object):
......@@ -44,10 +45,16 @@ class Learner(object):
self.config['act_dim'] = act_dim
model = AtariModel(act_dim)
algorithm = A3C(model, hyperparas=config)
self.agent = AtariAgent(algorithm, config)
if self.agent.gpu_id >= 0:
algorithm = parl.algorithms.A3C(
model, vf_loss_coeff=config['vf_loss_coeff'])
self.agent = AtariAgent(
algorithm,
obs_shape=self.config['obs_shape'],
lr_scheduler=self.config['lr_scheduler'],
entropy_coeff_scheduler=self.config['entropy_coeff_scheduler'],
)
if machine_info.is_gpu_available():
assert get_gpu_count() == 1, 'Only support training in single GPU,\
Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_YOU_WANT_TO_USE]` .'
......@@ -111,7 +118,7 @@ class Learner(object):
cnt = 0
while True:
latest_params = params_queue.get()
remote_actor.set_params(latest_params)
remote_actor.set_weights(latest_params)
batch = remote_actor.sample()
self.sample_data_queue.put(batch)
......@@ -129,7 +136,7 @@ class Learner(object):
3. update parameters.
"""
latest_params = self.agent.get_params()
latest_params = self.agent.get_weights()
for params_queue in self.params_queues:
params_queue.put(latest_params)
......
......@@ -13,19 +13,21 @@
# limitations under the License.
import numpy as np
import parl.layers as layers
import parl
from parl import layers
from paddle import fluid
from parl.framework.agent_base import Agent
class MujocoAgent(Agent):
class MujocoAgent(parl.Agent):
def __init__(self, algorithm, obs_dim, act_dim):
assert isinstance(obs_dim, int)
assert isinstance(act_dim, int)
self.obs_dim = obs_dim
self.act_dim = act_dim
super(MujocoAgent, self).__init__(algorithm)
# Attention: In the beginning, sync target model totally.
self.alg.sync_target(gpu_id=self.gpu_id, decay=0)
self.alg.sync_target(decay=0)
def build_program(self):
self.pred_program = fluid.Program()
......@@ -34,7 +36,7 @@ class MujocoAgent(Agent):
with fluid.program_guard(self.pred_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
self.pred_act = self.alg.define_predict(obs)
self.pred_act = self.alg.predict(obs)
with fluid.program_guard(self.learn_program):
obs = layers.data(
......@@ -45,8 +47,8 @@ class MujocoAgent(Agent):
next_obs = layers.data(
name='next_obs', shape=[self.obs_dim], dtype='float32')
terminal = layers.data(name='terminal', shape=[], dtype='bool')
_, self.critic_cost = self.alg.define_learn(
obs, act, reward, next_obs, terminal)
_, self.critic_cost = self.alg.learn(obs, act, reward, next_obs,
terminal)
def predict(self, obs):
obs = np.expand_dims(obs, axis=0)
......@@ -65,5 +67,5 @@ class MujocoAgent(Agent):
}
critic_cost = self.fluid_executor.run(
self.learn_program, feed=feed, fetch_list=[self.critic_cost])[0]
self.alg.sync_target(gpu_id=self.gpu_id)
self.alg.sync_target()
return critic_cost
......@@ -13,11 +13,11 @@
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.model_base import Model
import parl
from parl import layers
class MujocoModel(Model):
class MujocoModel(parl.Model):
def __init__(self, act_dim):
self.actor_model = ActorModel(act_dim)
self.critic_model = CriticModel()
......@@ -29,10 +29,10 @@ class MujocoModel(Model):
return self.critic_model.value(obs, act)
def get_actor_params(self):
return self.actor_model.parameter_names
return self.actor_model.parameters()
class ActorModel(Model):
class ActorModel(parl.Model):
def __init__(self, act_dim):
hid1_size = 400
hid2_size = 300
......@@ -49,7 +49,7 @@ class ActorModel(Model):
return means
class CriticModel(Model):
class CriticModel(parl.Model):
def __init__(self):
hid1_size = 400
hid2_size = 300
......
......@@ -16,9 +16,9 @@ import argparse
import gym
import numpy as np
import time
import parl
from mujoco_agent import MujocoAgent
from mujoco_model import MujocoModel
from parl.algorithms import DDPG
from parl.utils import logger, action_mapping, ReplayMemory
MAX_EPISODES = 5000
......@@ -95,14 +95,8 @@ def main():
act_dim = env.action_space.shape[0]
model = MujocoModel(act_dim)
algorithm = DDPG(
model,
hyperparas={
'gamma': GAMMA,
'tau': TAU,
'actor_lr': ACTOR_LR,
'critic_lr': CRITIC_LR
})
algorithm = parl.algorithms.DDPG(
model, gamma=GAMMA, tau=TAU, actor_lr=ACTOR_LR, critic_lr=CRITIC_LR)
agent = MujocoAgent(algorithm, obs_dim, act_dim)
rpm = ReplayMemory(MEMORY_SIZE, obs_dim, act_dim)
......
......@@ -14,19 +14,20 @@
import numpy as np
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.agent_base import Agent
import parl
from parl import layers
IMAGE_SIZE = (84, 84)
CONTEXT_LEN = 4
class AtariAgent(Agent):
def __init__(self, algorithm, action_dim):
class AtariAgent(parl.Agent):
def __init__(self, algorithm, act_dim):
super(AtariAgent, self).__init__(algorithm)
assert isinstance(act_dim, int)
self.act_dim = act_dim
self.exploration = 1.1
self.action_dim = action_dim
self.global_step = 0
self.update_target_steps = 10000 // 4
......@@ -39,7 +40,7 @@ class AtariAgent(Agent):
name='obs',
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
dtype='float32')
self.value = self.alg.define_predict(obs)
self.value = self.alg.predict(obs)
with fluid.program_guard(self.learn_program):
obs = layers.data(
......@@ -53,16 +54,15 @@ class AtariAgent(Agent):
shape=[CONTEXT_LEN, IMAGE_SIZE[0], IMAGE_SIZE[1]],
dtype='float32')
terminal = layers.data(name='terminal', shape=[], dtype='bool')
self.cost = self.alg.define_learn(obs, action, reward, next_obs,
terminal)
self.cost = self.alg.learn(obs, action, reward, next_obs, terminal)
def sample(self, obs):
sample = np.random.random()
if sample < self.exploration:
act = np.random.randint(self.action_dim)
act = np.random.randint(self.act_dim)
else:
if np.random.random() < 0.01:
act = np.random.randint(self.action_dim)
act = np.random.randint(self.act_dim)
else:
obs = np.expand_dims(obs, axis=0)
pred_Q = self.fluid_executor.run(
......@@ -86,7 +86,7 @@ class AtariAgent(Agent):
def learn(self, obs, act, reward, next_obs, terminal):
if self.global_step % self.update_target_steps == 0:
self.alg.sync_target(self.gpu_id)
self.alg.sync_target()
self.global_step += 1
act = np.expand_dims(act, -1)
......
......@@ -13,11 +13,11 @@
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.model_base import Model
import parl
from parl import layers
class AtariModel(Model):
class AtariModel(parl.Model):
def __init__(self, act_dim):
self.act_dim = act_dim
......
......@@ -17,12 +17,12 @@ import gym
import paddle.fluid as fluid
import numpy as np
import os
import parl
from atari_agent import AtariAgent
from atari_model import AtariModel
from collections import deque
from datetime import datetime
from replay_memory import ReplayMemory, Experience
from parl.algorithms import DQN
from parl.utils import logger
from tqdm import tqdm
from utils import get_player
......@@ -91,16 +91,12 @@ def main():
frame_skip=FRAME_SKIP,
context_len=CONTEXT_LEN)
rpm = ReplayMemory(MEMORY_SIZE, IMAGE_SIZE, CONTEXT_LEN)
action_dim = env.action_space.n
hyperparas = {
'action_dim': action_dim,
'lr': LEARNING_RATE,
'gamma': GAMMA
}
model = AtariModel(action_dim)
algorithm = DQN(model, hyperparas)
agent = AtariAgent(algorithm, action_dim)
act_dim = env.action_space.n
model = AtariModel(act_dim)
algorithm = parl.algorithms.DQN(
model, act_dim=act_dim, gamma=GAMMA, lr=LEARNING_RATE)
agent = AtariAgent(algorithm, act_dim=act_dim)
with tqdm(total=MEMORY_WARMUP_SIZE) as pbar:
while rpm.size() < MEMORY_WARMUP_SIZE:
......
......@@ -14,16 +14,31 @@
import numpy as np
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.agent_base import Agent
import parl
from parl import layers
from parl.utils import machine_info
class AtariAgent(Agent):
def __init__(self, algorithm, config, learn_data_provider=None):
self.config = config
super(AtariAgent, self).__init__(algorithm)
class AtariAgent(parl.Agent):
def __init__(self,
algorithm,
obs_shape,
predict_thread_num,
learn_data_provider=None):
"""
use_cuda = True if self.gpu_id >= 0 else False
Args:
algorithm (`parl.Algorithm`): a2c algorithm
obs_shape (list/tuple): observation shape of atari environment
predict_thread_num (int): number of predict thread (predict parallel exector)
learn_data_provider: data generator of training
"""
assert isinstance(obs_shape, (list, tuple))
assert isinstance(predict_thread_num, int)
self.obs_shape = obs_shape
super(AtariAgent, self).__init__(algorithm)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = True
......@@ -33,16 +48,16 @@ class AtariAgent(Agent):
# Use ParallelExecutor to make learn program run faster
self.learn_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
use_cuda=machine_info.is_gpu_available(),
main_program=self.learn_program,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
self.sample_exes = []
for _ in range(config['predict_thread_num']):
for _ in range(predict_thread_num):
with fluid.scope_guard(fluid.global_scope().new_scope()):
pe = fluid.ParallelExecutor(
use_cuda=use_cuda,
use_cuda=machine_info.is_gpu_available(),
main_program=self.sample_program,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
......@@ -59,18 +74,18 @@ class AtariAgent(Agent):
with fluid.program_guard(self.sample_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
name='obs', shape=self.obs_shape, dtype='float32')
sample_actions, values = self.alg.sample(obs)
self.sample_outputs = [sample_actions.name, values.name]
with fluid.program_guard(self.predict_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
name='obs', shape=self.obs_shape, dtype='float32')
self.predict_actions = self.alg.predict(obs)
with fluid.program_guard(self.learn_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
name='obs', shape=self.obs_shape, dtype='float32')
actions = layers.data(name='actions', shape=[], dtype='int64')
advantages = layers.data(
name='advantages', shape=[], dtype='float32')
......
......@@ -19,17 +19,18 @@ import queue
import six
import time
import threading
import parl
from atari_model import AtariModel
from atari_agent import AtariAgent
from collections import defaultdict
from parl import RemoteManager
from parl.algorithms import A3C
from parl.env.atari_wrappers import wrap_deepmind
from parl.utils import logger, CSVLogger, get_gpu_count
from parl.utils.scheduler import PiecewiseScheduler
from parl.utils.time_stat import TimeStat
from parl.utils.window_stat import WindowStat
from parl.utils.rl_utils import calc_gae
from parl.utils import machine_info
class Learner(object):
......@@ -49,10 +50,15 @@ class Learner(object):
self.config['act_dim'] = act_dim
model = AtariModel(act_dim)
algorithm = A3C(model, hyperparas=config)
self.agent = AtariAgent(algorithm, config, self.learn_data_provider)
if self.agent.gpu_id >= 0:
algorithm = parl.algorithms.A3C(
model, vf_loss_coeff=config['vf_loss_coeff'])
self.agent = AtariAgent(
algorithm,
obs_shape=self.config['obs_shape'],
predict_thread_num=self.config['predict_thread_num'],
learn_data_provider=self.learn_data_provider)
if machine_info.is_gpu_available():
assert get_gpu_count() == 1, 'Only support training in single GPU,\
Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_YOU_WANT_TO_USE]` .'
......
......@@ -16,10 +16,10 @@ import gym
import numpy as np
import parl
import six
import parl
from atari_model import AtariModel
from collections import defaultdict
from atari_agent import AtariAgent
from parl.algorithms import IMPALA
from parl.env.atari_wrappers import wrap_deepmind, MonitorEnv, get_wrapper_by_cls
from parl.env.vector_env import VectorEnv
......@@ -41,12 +41,15 @@ class Actor(object):
obs_shape = env.observation_space.shape
act_dim = env.action_space.n
self.config['obs_shape'] = obs_shape
self.config['act_dim'] = act_dim
model = AtariModel(act_dim)
algorithm = IMPALA(model, hyperparas=config)
self.agent = AtariAgent(algorithm, config)
algorithm = parl.algorithms.IMPALA(
model,
sample_batch_steps=self.config['sample_batch_steps'],
gamma=self.config['gamma'],
vf_loss_coeff=self.config['vf_loss_coeff'],
clip_rho_threshold=self.config['clip_rho_threshold'],
clip_pg_rho_threshold=self.config['clip_pg_rho_threshold'])
self.agent = AtariAgent(algorithm, obs_shape, act_dim)
def sample(self):
env_sample_data = {}
......@@ -95,8 +98,8 @@ class Actor(object):
metrics['episode_steps'].append(episode_steps)
return metrics
def set_params(self, params):
self.agent.set_params(params)
def set_weights(self, weights):
self.agent.set_weights(weights)
if __name__ == '__main__':
......
......@@ -14,17 +14,20 @@
import numpy as np
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.agent_base import Agent
class AtariAgent(Agent):
def __init__(self, algorithm, config, learn_data_provider=None):
self.config = config
import parl
from parl import layers
from parl.utils import machine_info
class AtariAgent(parl.Agent):
def __init__(self, algorithm, obs_shape, act_dim,
learn_data_provider=None):
assert isinstance(obs_shape, (list, tuple))
assert isinstance(act_dim, int)
self.obs_shape = obs_shape
self.act_dim = act_dim
super(AtariAgent, self).__init__(algorithm)
use_cuda = True if self.gpu_id >= 0 else False
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = True
exec_strategy.num_threads = 4
......@@ -33,7 +36,7 @@ class AtariAgent(Agent):
# Use ParallelExecutor to make learn program run faster
self.learn_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
use_cuda=machine_info.is_gpu_available(),
main_program=self.learn_program,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
......@@ -49,22 +52,20 @@ class AtariAgent(Agent):
with fluid.program_guard(self.sample_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
name='obs', shape=self.obs_shape, dtype='float32')
self.sample_actions, self.behaviour_logits = self.alg.sample(obs)
with fluid.program_guard(self.predict_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
name='obs', shape=self.obs_shape, dtype='float32')
self.predict_actions = self.alg.predict(obs)
with fluid.program_guard(self.learn_program):
obs = layers.data(
name='obs', shape=self.config['obs_shape'], dtype='float32')
name='obs', shape=self.obs_shape, dtype='float32')
actions = layers.data(name='actions', shape=[], dtype='int64')
behaviour_logits = layers.data(
name='behaviour_logits',
shape=[self.config['act_dim']],
dtype='float32')
name='behaviour_logits', shape=[self.act_dim], dtype='float32')
rewards = layers.data(name='rewards', shape=[], dtype='float32')
dones = layers.data(name='dones', shape=[], dtype='float32')
lr = layers.data(
......
......@@ -13,12 +13,12 @@
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.model_base import Model
import parl
from parl import layers
from paddle.fluid.param_attr import ParamAttr
class AtariModel(Model):
class AtariModel(parl.Model):
def __init__(self, act_dim):
self.conv1 = layers.conv2d(
......
......@@ -18,10 +18,10 @@ import os
import queue
import time
import threading
import parl
from atari_model import AtariModel
from atari_agent import AtariAgent
from parl import RemoteManager
from parl.algorithms import IMPALA
from parl.env.atari_wrappers import wrap_deepmind
from parl.utils import logger, CSVLogger
from parl.utils.scheduler import PiecewiseScheduler
......@@ -41,14 +41,19 @@ class Learner(object):
obs_shape = env.observation_space.shape
act_dim = env.action_space.n
self.config['obs_shape'] = obs_shape
self.config['act_dim'] = act_dim
model = AtariModel(act_dim)
algorithm = IMPALA(model, hyperparas=config)
self.agent = AtariAgent(algorithm, config, self.learn_data_provider)
self.cache_params = self.agent.get_params()
algorithm = parl.algorithms.IMPALA(
model,
sample_batch_steps=self.config['sample_batch_steps'],
gamma=self.config['gamma'],
vf_loss_coeff=self.config['vf_loss_coeff'],
clip_rho_threshold=self.config['clip_rho_threshold'],
clip_pg_rho_threshold=self.config['clip_pg_rho_threshold'])
self.agent = AtariAgent(algorithm, obs_shape, act_dim,
self.learn_data_provider)
self.cache_params = self.agent.get_weights()
self.params_lock = threading.Lock()
self.params_updated = False
self.cache_params_sent_cnt = 0
......@@ -155,7 +160,7 @@ class Learner(object):
""" Sample data from remote actor and update parameters of remote actor.
"""
cnt = 0
remote_actor.set_params(self.cache_params)
remote_actor.set_weights(self.cache_params)
while True:
batch = remote_actor.sample()
self.sample_data_queue.put(batch)
......@@ -171,14 +176,14 @@ class Learner(object):
if self.params_updated and self.cache_params_sent_cnt >= self.config[
'params_broadcast_interval']:
self.params_updated = False
self.cache_params = self.agent.get_params()
self.cache_params = self.agent.get_weights()
self.cache_params_sent_cnt = 0
self.cache_params_sent_cnt += 1
self.total_params_sync += 1
self.params_lock.release()
remote_actor.set_params(self.cache_params)
remote_actor.set_weights(self.cache_params)
def log_metrics(self):
""" Log metrics of learner and actors
......
......@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import parl.layers as layers
import parl
from parl import layers
from copy import deepcopy
from paddle import fluid
from parl.framework.algorithm_base import Algorithm
__all__ = ['MultiHeadDDPG']
class MultiHeadDDPG(Algorithm):
class MultiHeadDDPG(parl.Algorithm):
def __init__(self, models, hyperparas):
""" model: should implement the function get_actor_params()
"""
......@@ -35,12 +35,12 @@ class MultiHeadDDPG(Algorithm):
self.tau = hyperparas['tau']
self.ensemble_num = hyperparas['ensemble_num']
def define_predict(self, obs, model_id):
def predict(self, obs, model_id):
""" use actor model of self.models[model_id] to predict the action
"""
return self.models[model_id].policy(obs)
def define_ensemble_predict(self, obs):
def ensemble_predict(self, obs):
""" ensemble predict:
1. For actions of all actors, each critic will score them
and normalize its scores;
......@@ -75,8 +75,8 @@ class MultiHeadDDPG(Algorithm):
ensemble_predict_action = layers.gather(batch_actions, best_score_id)
return ensemble_predict_action
def define_learn(self, obs, action, reward, next_obs, terminal, actor_lr,
critic_lr, model_id):
def learn(self, obs, action, reward, next_obs, terminal, actor_lr,
critic_lr, model_id):
""" update actor and critic model of self.models[model_id] with DDPG algorithm
"""
actor_cost = self._actor_learn(obs, actor_lr, model_id)
......@@ -110,14 +110,12 @@ class MultiHeadDDPG(Algorithm):
return cost
def sync_target(self,
gpu_id,
model_id,
decay=None,
share_vars_parallel_executor=None):
if decay is None:
decay = 1.0 - self.tau
self.models[model_id].sync_params_to(
self.models[model_id].sync_weights_to(
self.target_models[model_id],
gpu_id=gpu_id,
decay=decay,
share_vars_parallel_executor=share_vars_parallel_executor)
......@@ -13,15 +13,15 @@
# limitations under the License.
import numpy as np
import parl.layers as layers
import re
import parl
from parl import layers
from paddle import fluid
from paddle.fluid.executor import _fetch_var
from parl.framework.agent_base import Agent
from parl.utils import logger
class OpenSimAgent(Agent):
class OpenSimAgent(parl.Agent):
def __init__(self, algorithm, obs_dim, act_dim, ensemble_num):
self.obs_dim = obs_dim
self.act_dim = act_dim
......@@ -29,7 +29,7 @@ class OpenSimAgent(Agent):
super(OpenSimAgent, self).__init__(algorithm)
# Use ParallelExecutor to make program running faster
use_cuda = True if self.gpu_id >= 0 else False
use_cuda = True if parl.GPU_ID >= 0 else False
self.learn_pe = []
self.pred_pe = []
......@@ -58,16 +58,13 @@ class OpenSimAgent(Agent):
# Attention: In the beginning, sync target model totally.
self.alg.sync_target(
gpu_id=self.gpu_id,
model_id=i,
decay=1.0,
share_vars_parallel_executor=self.learn_pe[i])
# Do cache, will create ParallelExecutor of sync params in advance
# If not, there are some issues when ensemble_num > 1
self.alg.sync_target(
gpu_id=self.gpu_id,
model_id=i,
share_vars_parallel_executor=self.learn_pe[i])
model_id=i, share_vars_parallel_executor=self.learn_pe[i])
with fluid.scope_guard(fluid.global_scope().new_scope()):
self.ensemble_predict_pe = fluid.ParallelExecutor(
......@@ -86,7 +83,7 @@ class OpenSimAgent(Agent):
with fluid.program_guard(predict_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
act = self.alg.define_predict(obs, model_id=i)
act = self.alg.predict(obs, model_id=i)
self.predict_programs.append(predict_program)
self.predict_outputs.append([act.name])
......@@ -110,7 +107,7 @@ class OpenSimAgent(Agent):
shape=[1],
dtype='float32',
append_batch_size=False)
actor_loss, critic_loss = self.alg.define_learn(
actor_loss, critic_loss = self.alg.learn(
obs,
act,
reward,
......@@ -126,7 +123,7 @@ class OpenSimAgent(Agent):
with fluid.program_guard(self.ensemble_predict_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
act = self.alg.define_ensemble_predict(obs)
act = self.alg.ensemble_predict(obs)
self.ensemble_predict_output = [act.name]
def predict(self, obs, model_id):
......@@ -159,7 +156,6 @@ class OpenSimAgent(Agent):
critic_loss = self.learn_pe[model_id].run(
feed=feed, fetch_list=self.learn_programs_output[model_id])[0]
self.alg.sync_target(
gpu_id=self.gpu_id,
model_id=model_id,
share_vars_parallel_executor=self.learn_pe[model_id])
return critic_loss
......
......@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import parl.layers as layers
import parl
from parl import layers
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from parl.framework.model_base import Model
class OpenSimModel(Model):
class OpenSimModel(parl.Model):
def __init__(self, obs_dim, vel_obs_dim, act_dim, model_id=0, shared=True):
self.actor_model = ActorModel(obs_dim, vel_obs_dim, act_dim, model_id,
shared)
......@@ -32,10 +32,10 @@ class OpenSimModel(Model):
return self.critic_model.value(obs, action)
def get_actor_params(self):
return self.actor_model.parameter_names
return self.actor_model.parameters()
class ActorModel(Model):
class ActorModel(parl.Model):
def __init__(self, obs_dim, vel_obs_dim, act_dim, model_id, shared):
hid0_size = 800
hid1_size = 400
......@@ -104,7 +104,7 @@ class ActorModel(Model):
return means
class CriticModel(Model):
class CriticModel(parl.Model):
def __init__(self, obs_dim, vel_obs_dim, act_dim, model_id, shared):
super(CriticModel, self).__init__()
hid0_size = 800
......
......@@ -13,13 +13,13 @@
# limitations under the License.
import numpy as np
import parl.layers as layers
import parl
from parl import layers
from paddle import fluid
from parl.framework.agent_base import Agent
from parl.utils import logger
class MujocoAgent(Agent):
class MujocoAgent(parl.Agent):
def __init__(self,
algorithm,
obs_dim,
......@@ -57,13 +57,13 @@ class MujocoAgent(Agent):
with fluid.program_guard(self.policy_sample_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
sampled_act = self.alg.define_sample(obs)
sampled_act = self.alg.sample(obs)
self.policy_sample_output = [sampled_act]
with fluid.program_guard(self.policy_predict_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
means = self.alg.define_predict(obs)
means = self.alg.predict(obs)
self.policy_predict_output = [means]
with fluid.program_guard(self.policy_learn_program):
......@@ -75,25 +75,24 @@ class MujocoAgent(Agent):
name='advantages', shape=[1], dtype='float32')
if self.loss_type == 'KLPEN':
beta = layers.data(name='beta', shape=[], dtype='float32')
loss, kl = self.alg.define_policy_learn(
obs, actions, advantages, beta)
loss, kl = self.alg.policy_learn(obs, actions, advantages,
beta)
else:
loss, kl = self.alg.define_policy_learn(
obs, actions, advantages)
loss, kl = self.alg.policy_learn(obs, actions, advantages)
self.policy_learn_output = [loss, kl]
with fluid.program_guard(self.value_predict_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
value = self.alg.define_value_predict(obs)
value = self.alg.value_predict(obs)
self.value_predict_output = [value]
with fluid.program_guard(self.value_learn_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
val = layers.data(name='val', shape=[], dtype='float32')
value_loss = self.alg.define_value_learn(obs, val)
value_loss = self.alg.value_learn(obs, val)
self.value_learn_output = [value_loss]
def policy_sample(self, obs):
......@@ -151,7 +150,7 @@ class MujocoAgent(Agent):
2. Fix old policy model, and learn policy model multi times
3. if use KLPEN loss, Adjust kl loss coefficient: beta
"""
self.alg.sync_old_policy(self.gpu_id)
self.alg.sync_old_policy()
all_loss, all_kl = [], []
for _ in range(self.policy_learn_times):
......
......@@ -13,13 +13,13 @@
# limitations under the License.
import numpy as np
import parl.layers as layers
import parl
from parl import layers
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from parl.framework.model_base import Model
class MujocoModel(Model):
class MujocoModel(parl.Model):
def __init__(self, obs_dim, act_dim, init_logvar=-1.0):
self.policy_model = PolicyModel(obs_dim, act_dim, init_logvar)
self.value_model = ValueModel(obs_dim, act_dim)
......@@ -36,7 +36,7 @@ class MujocoModel(Model):
return self.value_model.value(obs)
class PolicyModel(Model):
class PolicyModel(parl.Model):
def __init__(self, obs_dim, act_dim, init_logvar):
self.obs_dim = obs_dim
self.act_dim = act_dim
......@@ -73,7 +73,7 @@ class PolicyModel(Model):
return sampled_act
class ValueModel(Model):
class ValueModel(parl.Model):
def __init__(self, obs_dim, act_dim):
super(ValueModel, self).__init__()
hid1_size = obs_dim * 10
......
......@@ -15,9 +15,9 @@
import argparse
import gym
import numpy as np
import parl
from mujoco_agent import MujocoAgent
from mujoco_model import MujocoModel
from parl.algorithms import PPO
from parl.utils import logger, action_mapping
from parl.utils.rl_utils import calc_gae, calc_discount_sum_rewards
from scaler import Scaler
......@@ -142,12 +142,11 @@ def main():
scaler = Scaler(obs_dim)
model = MujocoModel(obs_dim, act_dim)
hyperparas = {
'act_dim': act_dim,
'policy_lr': model.policy_lr,
'value_lr': model.value_lr
}
alg = PPO(model, hyperparas)
alg = parl.algorithms.PPO(
model,
act_dim=act_dim,
policy_lr=model.policy_lr,
value_lr=model.value_lr)
agent = MujocoAgent(
alg, obs_dim, act_dim, args.kl_targ, loss_type=args.loss_type)
......
......@@ -14,8 +14,8 @@
import numpy as np
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.agent_base import Agent
import parl
from parl import layers
class CartpoleAgent(Agent):
......@@ -31,14 +31,14 @@ class CartpoleAgent(Agent):
with fluid.program_guard(self.pred_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
self.act_prob = self.alg.define_predict(obs)
self.act_prob = self.alg.predict(obs)
with fluid.program_guard(self.train_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
act = layers.data(name='act', shape=[1], dtype='int64')
reward = layers.data(name='reward', shape=[], dtype='float32')
self.cost = self.alg.define_learn(obs, act, reward)
self.cost = self.alg.learn(obs, act, reward)
def sample(self, obs):
obs = np.expand_dims(obs, axis=0)
......
......@@ -13,11 +13,11 @@
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.model_base import Model
import parl
from parl import layers
class CartpoleModel(Model):
class CartpoleModel(parl.Model):
def __init__(self, act_dim):
act_dim = act_dim
hid1_size = act_dim * 10
......
......@@ -14,9 +14,9 @@
import gym
import numpy as np
import parl
from cartpole_agent import CartpoleAgent
from cartpole_model import CartpoleModel
from parl.algorithms import PolicyGradient
from parl.utils import logger
from utils import calc_discount_norm_reward
......@@ -48,7 +48,7 @@ def run_episode(env, agent, train_or_test='train'):
def main():
env = gym.make("CartPole-v0")
model = CartpoleModel(act_dim=ACT_DIM)
alg = PolicyGradient(model, hyperparas={'lr': LEARNING_RATE})
alg = parl.algorithms.PolicyGradient(model, lr=LEARNING_RATE)
agent = CartpoleAgent(alg, obs_dim=OBS_DIM, act_dim=ACT_DIM)
for i in range(1000):
......
......@@ -23,10 +23,11 @@ from tensorboardX import SummaryWriter
from parl.utils.utils import _HAS_FLUID
if _HAS_FLUID:
from parl.framework import *
from parl.core.fluid import *
else:
print(
"WARNING:PARL: Failed to import paddle. Only APIs for parallelization are available."
)
from parl.remote import remote_class, RemoteManager
from parl import algorithms
......@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.algorithms.a3c import *
from parl.algorithms.ddpg import *
from parl.algorithms.dqn import *
from parl.algorithms.policy_gradient import *
from parl.algorithms.ppo import *
from parl.algorithms.impala.impala import *
from parl.utils.utils import _HAS_FLUID
if _HAS_FLUID:
from parl.algorithms.fluid import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.algorithms.fluid.a3c import *
from parl.algorithms.fluid.ddpg import *
from parl.algorithms.fluid.dqn import *
from parl.algorithms.fluid.policy_gradient import *
from parl.algorithms.fluid.ppo import *
from parl.algorithms.fluid.impala.impala import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
import paddle.fluid as fluid
from parl.core.fluid import layers
from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid.policy_distribution import CategoricalDistribution
__all__ = ['A3C']
class A3C(Algorithm):
def __init__(self, model, hyperparas=None, vf_loss_coeff=None):
""" A3C/A2C algorithm
Args:
model (parl.Model): forward network of policy and value
hyperparas (dict): (deprecated) dict of hyper parameters.
vf_loss_coeff (float): coefficient of the value function loss
"""
self.model = model
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.A3C` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.vf_loss_coeff = hyperparas['vf_loss_coeff']
else:
assert isinstance(vf_loss_coeff, (int, float))
self.vf_loss_coeff = vf_loss_coeff
def learn(self, obs, actions, advantages, target_values, learning_rate,
entropy_coeff):
"""
Args:
obs: An float32 tensor of shape ([B] + observation_space).
E.g. [B, C, H, W] in atari.
actions: An int64 tensor of shape [B].
advantages: A float32 tensor of shape [B].
target_values: A float32 tensor of shape [B].
learning_rate: float scalar of learning rate.
entropy_coeff: float scalar of entropy coefficient.
"""
logits = self.model.policy(obs)
policy_distribution = CategoricalDistribution(logits)
actions_log_probs = policy_distribution.logp(actions)
# The policy gradient loss
pi_loss = -1.0 * layers.reduce_sum(actions_log_probs * advantages)
# The value function loss
values = self.model.value(obs)
delta = values - target_values
vf_loss = 0.5 * layers.reduce_sum(layers.square(delta))
# The entropy loss (We want to maximize entropy, so entropy_ceoff < 0)
policy_entropy = policy_distribution.entropy()
entropy = layers.reduce_sum(policy_entropy)
total_loss = (
pi_loss + vf_loss * self.vf_loss_coeff + entropy * entropy_coeff)
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=40.0))
optimizer = fluid.optimizer.AdamOptimizer(learning_rate)
optimizer.minimize(total_loss)
return total_loss, pi_loss, vf_loss, entropy
def sample(self, obs):
"""
Args:
obs: An float32 tensor of shape ([B] + observation_space).
E.g. [B, C, H, W] in atari.
"""
logits, values = self.model.policy_and_value(obs)
policy_dist = CategoricalDistribution(logits)
sample_actions = policy_dist.sample()
return sample_actions, values
def predict(self, obs):
"""
Args:
obs: An float32 tensor of shape ([B] + observation_space).
E.g. [B, C, H, W] in atari.
"""
logits = self.model.policy(obs)
probs = layers.softmax(logits)
predict_actions = layers.argmax(probs, 1)
return predict_actions
def value(self, obs):
"""
Args:
obs: An float32 tensor of shape ([B] + observation_space).
E.g. [B, C, H, W] in atari.
"""
values = self.model.value(obs)
return values
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
from parl.core.fluid import layers
from copy import deepcopy
from paddle import fluid
from parl.core.fluid.algorithm import Algorithm
from parl.utils.deprecation import deprecated
__all__ = ['DDPG']
class DDPG(Algorithm):
def __init__(self,
model,
hyperparas=None,
gamma=None,
tau=None,
actor_lr=None,
critic_lr=None):
""" DDPG algorithm
Args:
model (parl.Model): forward network of actor and critic.
The function get_actor_params() of model should be implemented.
hyperparas (dict): (deprecated) dict of hyper parameters.
gamma (float): discounted factor for reward computation.
tau (float): decay coefficient when updating the weights of self.target_model with self.model
actor_lr (float): learning rate of the actor model
critic_lr (float): learning rate of the critic model
"""
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.DDPG` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.gamma = hyperparas['gamma']
self.tau = hyperparas['tau']
self.actor_lr = hyperparas['actor_lr']
self.critic_lr = hyperparas['critic_lr']
else:
assert isinstance(gamma, float)
assert isinstance(tau, float)
assert isinstance(actor_lr, float)
assert isinstance(critic_lr, float)
self.gamma = gamma
self.tau = tau
self.actor_lr = actor_lr
self.critic_lr = critic_lr
self.model = model
self.target_model = deepcopy(model)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='predict')
def define_predict(self, obs):
""" use actor model of self.model to predict the action
"""
return self.predict(obs)
def predict(self, obs):
""" use actor model of self.model to predict the action
"""
return self.model.policy(obs)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='learn')
def define_learn(self, obs, action, reward, next_obs, terminal):
""" update actor and critic model with DDPG algorithm
"""
return self.learn(obs, action, reward, next_obs, terminal)
def learn(self, obs, action, reward, next_obs, terminal):
""" update actor and critic model with DDPG algorithm
"""
actor_cost = self._actor_learn(obs)
critic_cost = self._critic_learn(obs, action, reward, next_obs,
terminal)
return actor_cost, critic_cost
def _actor_learn(self, obs):
action = self.model.policy(obs)
Q = self.model.value(obs, action)
cost = layers.reduce_mean(-1.0 * Q)
optimizer = fluid.optimizer.AdamOptimizer(self.actor_lr)
optimizer.minimize(cost, parameter_list=self.model.get_actor_params())
return cost
def _critic_learn(self, obs, action, reward, next_obs, terminal):
next_action = self.target_model.policy(next_obs)
next_Q = self.target_model.value(next_obs, next_action)
terminal = layers.cast(terminal, dtype='float32')
target_Q = reward + (1.0 - terminal) * self.gamma * next_Q
target_Q.stop_gradient = True
Q = self.model.value(obs, action)
cost = layers.square_error_cost(Q, target_Q)
cost = layers.reduce_mean(cost)
optimizer = fluid.optimizer.AdamOptimizer(self.critic_lr)
optimizer.minimize(cost)
return cost
def sync_target(self, gpu_id=None, decay=None):
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DDPG` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
if decay is None:
decay = 1.0 - self.tau
self.model.sync_weights_to(self.target_model, decay=decay)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
import copy
import paddle.fluid as fluid
from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid import layers
from parl.utils.deprecation import deprecated
__all__ = ['DQN']
class DQN(Algorithm):
def __init__(self,
model,
hyperparas=None,
act_dim=None,
gamma=None,
lr=None):
""" DQN algorithm
Args:
model (parl.Model): model defining forward network of Q function
hyperparas (dict): (deprecated) dict of hyper parameters.
act_dim (int): dimension of the action space
gamma (float): discounted factor for reward computation.
lr (float): learning rate.
"""
self.model = model
self.target_model = copy.deepcopy(model)
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.act_dim = hyperparas['action_dim']
self.gamma = hyperparas['gamma']
self.lr = hyperparas['lr']
else:
assert isinstance(act_dim, int)
assert isinstance(gamma, float)
assert isinstance(lr, float)
self.act_dim = act_dim
self.gamma = gamma
self.lr = lr
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='predict')
def define_predict(self, obs):
""" use value model self.model to predict the action value
"""
return self.predict(obs)
def predict(self, obs):
""" use value model self.model to predict the action value
"""
return self.model.value(obs)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='learn')
def define_learn(self, obs, action, reward, next_obs, terminal):
return self.learn(obs, action, reward, next_obs, terminal)
def learn(self, obs, action, reward, next_obs, terminal):
""" update value model self.model with DQN algorithm
"""
pred_value = self.model.value(obs)
next_pred_value = self.target_model.value(next_obs)
best_v = layers.reduce_max(next_pred_value, dim=1)
best_v.stop_gradient = True
target = reward + (
1.0 - layers.cast(terminal, dtype='float32')) * self.gamma * best_v
action_onehot = layers.one_hot(action, self.act_dim)
action_onehot = layers.cast(action_onehot, dtype='float32')
pred_action_value = layers.reduce_sum(
layers.elementwise_mul(action_onehot, pred_value), dim=1)
cost = layers.square_error_cost(pred_action_value, target)
cost = layers.reduce_mean(cost)
optimizer = fluid.optimizer.Adam(self.lr, epsilon=1e-3)
optimizer.minimize(cost)
return cost
def sync_target(self, gpu_id=None):
""" sync weights of self.model to self.target_model
"""
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_target` function in `parl.Algorithms.DQN` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.model.sync_weights_to(self.target_model)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.algorithms.fluid.impala.impala import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
import paddle.fluid as fluid
from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid import layers
from parl.algorithms.fluid.impala import vtrace
from parl.core.fluid.policy_distribution import CategoricalDistribution
from parl.core.fluid.plutils import inverse
__all__ = ['IMPALA']
class VTraceLoss(object):
def __init__(self,
behaviour_actions_log_probs,
target_actions_log_probs,
policy_entropy,
dones,
discount,
rewards,
values,
bootstrap_value,
entropy_coeff=-0.01,
vf_loss_coeff=0.5,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0):
"""Policy gradient loss with vtrace importance weighting.
VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
batch_size. The reason we need to know `B` is for V-trace to properly
handle episode cut boundaries.
Args:
behaviour_actions_log_probs: A float32 tensor of shape [T, B].
target_actions_log_probs: A float32 tensor of shape [T, B].
policy_entropy: A float32 tensor of shape [T, B].
dones: A float32 tensor of shape [T, B].
discount: A float32 scalar.
rewards: A float32 tensor of shape [T, B].
values: A float32 tensor of shape [T, B].
bootstrap_value: A float32 tensor of shape [B].
"""
self.vtrace_returns = vtrace.from_importance_weights(
behaviour_actions_log_probs=behaviour_actions_log_probs,
target_actions_log_probs=target_actions_log_probs,
discounts=inverse(dones) * discount,
rewards=rewards,
values=values,
bootstrap_value=bootstrap_value,
clip_rho_threshold=clip_rho_threshold,
clip_pg_rho_threshold=clip_pg_rho_threshold)
# The policy gradients loss
self.pi_loss = -1.0 * layers.reduce_sum(
target_actions_log_probs * self.vtrace_returns.pg_advantages)
# The baseline loss
delta = values - self.vtrace_returns.vs
self.vf_loss = 0.5 * layers.reduce_sum(layers.square(delta))
# The entropy loss (We want to maximize entropy, so entropy_ceoff < 0)
self.entropy = layers.reduce_sum(policy_entropy)
# The summed weighted loss
self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff +
self.entropy * entropy_coeff)
class IMPALA(Algorithm):
def __init__(self,
model,
hyperparas=None,
sample_batch_steps=None,
gamma=None,
vf_loss_coeff=None,
clip_rho_threshold=None,
clip_pg_rho_threshold=None):
""" IMPALA algorithm
Args:
model (parl.Model): forward network of policy and value
hyperparas (dict): (deprecated) dict of hyper parameters.
sample_batch_steps (int): steps of each environment sampling.
gamma (float): discounted factor for reward computation.
vf_loss_coeff (float): coefficient of the value function loss.
clip_rho_threshold (float): clipping threshold for importance weights (rho).
clip_pg_rho_threshold (float): clipping threshold on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
"""
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.IMPALA` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.sample_batch_steps = hyperparas['sample_batch_steps']
self.gamma = hyperparas['gamma']
self.vf_loss_coeff = hyperparas['vf_loss_coeff']
self.clip_rho_threshold = hyperparas['clip_rho_threshold']
self.clip_pg_rho_threshold = hyperparas['clip_pg_rho_threshold']
else:
assert isinstance(sample_batch_steps, int)
assert isinstance(gamma, float)
assert isinstance(vf_loss_coeff, float)
assert isinstance(clip_rho_threshold, float)
assert isinstance(clip_pg_rho_threshold, float)
self.sample_batch_steps = sample_batch_steps
self.gamma = gamma
self.vf_loss_coeff = vf_loss_coeff
self.clip_rho_threshold = clip_rho_threshold
self.clip_pg_rho_threshold = clip_pg_rho_threshold
self.model = model
def learn(self, obs, actions, behaviour_logits, rewards, dones,
learning_rate, entropy_coeff):
"""
Args:
obs: An float32 tensor of shape ([B] + observation_space).
E.g. [B, C, H, W] in atari.
actions: An int64 tensor of shape [B].
behaviour_logits: A float32 tensor of shape [B, NUM_ACTIONS].
rewards: A float32 tensor of shape [B].
dones: A float32 tensor of shape [B].
learning_rate: float scalar of learning rate.
entropy_coeff: float scalar of entropy coefficient.
"""
values = self.model.value(obs)
target_logits = self.model.policy(obs)
target_policy_distribution = CategoricalDistribution(target_logits)
behaviour_policy_distribution = CategoricalDistribution(
behaviour_logits)
policy_entropy = target_policy_distribution.entropy()
target_actions_log_probs = target_policy_distribution.logp(actions)
behaviour_actions_log_probs = behaviour_policy_distribution.logp(
actions)
# Calculating kl for debug
kl = target_policy_distribution.kl(behaviour_policy_distribution)
kl = layers.reduce_mean(kl)
"""
Split the tensor into batches at known episode cut boundaries.
[B * T] -> [T, B]
"""
T = self.sample_batch_steps
def split_batches(tensor):
B = tensor.shape[0] // T
splited_tensor = layers.reshape(tensor,
[B, T] + list(tensor.shape[1:]))
# transpose B and T
return layers.transpose(
splited_tensor, [1, 0] + list(range(2, 1 + len(tensor.shape))))
behaviour_actions_log_probs = split_batches(
behaviour_actions_log_probs)
target_actions_log_probs = split_batches(target_actions_log_probs)
policy_entropy = split_batches(policy_entropy)
dones = split_batches(dones)
rewards = split_batches(rewards)
values = split_batches(values)
# [T, B] -> [T - 1, B] for V-trace calc.
behaviour_actions_log_probs = layers.slice(
behaviour_actions_log_probs, axes=[0], starts=[0], ends=[-1])
target_actions_log_probs = layers.slice(
target_actions_log_probs, axes=[0], starts=[0], ends=[-1])
policy_entropy = layers.slice(
policy_entropy, axes=[0], starts=[0], ends=[-1])
dones = layers.slice(dones, axes=[0], starts=[0], ends=[-1])
rewards = layers.slice(rewards, axes=[0], starts=[0], ends=[-1])
bootstrap_value = layers.slice(
values, axes=[0], starts=[T - 1], ends=[T])
values = layers.slice(values, axes=[0], starts=[0], ends=[-1])
bootstrap_value = layers.squeeze(bootstrap_value, axes=[0])
vtrace_loss = VTraceLoss(
behaviour_actions_log_probs=behaviour_actions_log_probs,
target_actions_log_probs=target_actions_log_probs,
policy_entropy=policy_entropy,
dones=dones,
discount=self.gamma,
rewards=rewards,
values=values,
bootstrap_value=bootstrap_value,
entropy_coeff=entropy_coeff,
vf_loss_coeff=self.vf_loss_coeff,
clip_rho_threshold=self.clip_rho_threshold,
clip_pg_rho_threshold=self.clip_pg_rho_threshold)
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=40.0))
optimizer = fluid.optimizer.AdamOptimizer(learning_rate)
optimizer.minimize(vtrace_loss.total_loss)
return vtrace_loss, kl
def sample(self, obs):
"""
Args:
obs: An float32 tensor of shape ([B] + observation_space).
E.g. [B, C, H, W] in atari.
"""
logits = self.model.policy(obs)
policy_dist = CategoricalDistribution(logits)
sample_actions = policy_dist.sample()
return sample_actions, logits
def predict(self, obs):
"""
Args:
obs: An float32 tensor of shape ([B] + observation_space).
E.g. [B, C, H, W] in atari.
"""
logits = self.model.policy(obs)
probs = layers.softmax(logits)
predict_actions = layers.argmax(probs, 1)
return predict_actions
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for V-trace.
The following code is mainly referenced and copied from:
https://github.com/deepmind/scalable_agent/blob/master/vtrace_test.py
"""
import copy
import numpy as np
import unittest
from parl.core.fluid import layers
from paddle import fluid
from parameterized import parameterized
from parl.algorithms.fluid.impala import vtrace
from parl.utils import get_gpu_count
def _shaped_arange(*shape):
"""Runs np.arange, converts to float and reshapes."""
return np.arange(np.prod(shape), dtype=np.float32).reshape(*shape)
def _ground_truth_calculation(behaviour_actions_log_probs,
target_actions_log_probs, discounts, rewards,
values, bootstrap_value, clip_rho_threshold,
clip_pg_rho_threshold):
"""Calculates the ground truth for V-trace in Python/Numpy."""
log_rhos = target_actions_log_probs - behaviour_actions_log_probs
vs = []
seq_len = len(discounts)
rhos = np.exp(log_rhos)
cs = np.minimum(rhos, 1.0)
clipped_rhos = rhos
if clip_rho_threshold:
clipped_rhos = np.minimum(rhos, clip_rho_threshold)
clipped_pg_rhos = rhos
if clip_pg_rho_threshold:
clipped_pg_rhos = np.minimum(rhos, clip_pg_rho_threshold)
# This is a very inefficient way to calculate the V-trace ground truth.
# We calculate it this way because it is close to the mathematical notation of
# V-trace.
# v_s = V(x_s)
# + \sum^{T-1}_{t=s} \gamma^{t-s}
# * \prod_{i=s}^{t-1} c_i
# * \rho_t (r_t + \gamma V(x_{t+1}) - V(x_t))
# Note that when we take the product over c_i, we write `s:t` as the notation
# of the paper is inclusive of the `t-1`, but Python is exclusive.
# Also note that np.prod([]) == 1.
values_t_plus_1 = np.concatenate([values, bootstrap_value[None, :]],
axis=0)
for s in range(seq_len):
v_s = np.copy(values[s]) # Very important copy.
for t in range(s, seq_len):
v_s += (np.prod(discounts[s:t], axis=0) * np.prod(cs[s:t], axis=0)
* clipped_rhos[t] * (rewards[t] + discounts[t] *
values_t_plus_1[t + 1] - values[t]))
vs.append(v_s)
vs = np.stack(vs, axis=0)
pg_advantages = (clipped_pg_rhos * (rewards + discounts * np.concatenate(
[vs[1:], bootstrap_value[None, :]], axis=0) - values))
return vtrace.VTraceReturns(vs=vs, pg_advantages=pg_advantages)
class VtraceTest(unittest.TestCase):
def setUp(self):
gpu_count = get_gpu_count()
if gpu_count > 0:
place = fluid.CUDAPlace(0)
self.gpu_id = 0
else:
place = fluid.CPUPlace()
self.gpu_id = -1
self.executor = fluid.Executor(place)
@parameterized.expand([('Batch1', 1), ('Batch4', 4)])
def test_from_importance_weights(self, name, batch_size):
"""Tests V-trace against ground truth data calculated in python."""
seq_len = 5
# Create log_rhos such that rho will span from near-zero to above the
# clipping thresholds. In particular, calculate log_rhos in [-2.5, 2.5),
# so that rho is in approx [0.08, 12.2).
log_rhos = _shaped_arange(seq_len, batch_size) / (batch_size * seq_len)
log_rhos = 5 * (log_rhos - 0.5) # [0.0, 1.0) -> [-2.5, 2.5).
# Fake behaviour_actions_log_probs, target_actions_log_probs
target_actions_log_probs = log_rhos + 1.0
behaviour_actions_log_probs = np.ones(
shape=log_rhos.shape, dtype='float32')
values = {
'behaviour_actions_log_probs':
behaviour_actions_log_probs,
'target_actions_log_probs':
target_actions_log_probs,
# T, B where B_i: [0.9 / (i+1)] * T
'discounts':
np.array([[0.9 / (b + 1) for b in range(batch_size)]
for _ in range(seq_len)],
dtype=np.float32),
'rewards':
_shaped_arange(seq_len, batch_size),
'values':
_shaped_arange(seq_len, batch_size) / batch_size,
'bootstrap_value':
_shaped_arange(batch_size) + 1.0,
'clip_rho_threshold':
3.7,
'clip_pg_rho_threshold':
2.2,
}
# Calculated by numpy/python
ground_truth_v = _ground_truth_calculation(**values)
# Calculated by Fluid
test_program = fluid.Program()
with fluid.program_guard(test_program):
behaviour_actions_log_probs_input = layers.data(
name='behaviour_actions_log_probs',
shape=[seq_len, batch_size],
dtype='float32',
append_batch_size=False)
target_actions_log_probs_input = layers.data(
name='target_actions_log_probs',
shape=[seq_len, batch_size],
dtype='float32',
append_batch_size=False)
discounts_input = layers.data(
name='discounts',
shape=[seq_len, batch_size],
dtype='float32',
append_batch_size=False)
rewards_input = layers.data(
name='rewards',
shape=[seq_len, batch_size],
dtype='float32',
append_batch_size=False)
values_input = layers.data(
name='values',
shape=[seq_len, batch_size],
dtype='float32',
append_batch_size=False)
bootstrap_value_input = layers.data(
name='bootstrap_value',
shape=[batch_size],
dtype='float32',
append_batch_size=False)
fluid_inputs = {
'behaviour_actions_log_probs':
behaviour_actions_log_probs_input,
'target_actions_log_probs': target_actions_log_probs_input,
'discounts': discounts_input,
'rewards': rewards_input,
'values': values_input,
'bootstrap_value': bootstrap_value_input,
'clip_rho_threshold': 3.7,
'clip_pg_rho_threshold': 2.2,
}
output = vtrace.from_importance_weights(**fluid_inputs)
self.executor.run(fluid.default_startup_program())
feed = copy.copy(values)
del feed['clip_rho_threshold']
del feed['clip_pg_rho_threshold']
[output_vs, output_pg_advantage] = self.executor.run(
test_program,
feed=feed,
fetch_list=[output.vs, output.pg_advantages])
np.testing.assert_almost_equal(ground_truth_v.vs, output_vs, 5)
np.testing.assert_almost_equal(ground_truth_v.pg_advantages,
output_pg_advantage, 5)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to compute V-trace off-policy actor critic targets,
which used in IMAPLA algorithm.
The following code is mainly referenced and copied from:
https://github.com/deepmind/scalable_agent/blob/master/vtrace.py
For details and theory see:
"Espeholt L, Soyer H, Munos R, et al. Impala: Scalable distributed
deep-rl with importance weighted actor-learner
architectures[J]. arXiv preprint arXiv:1802.01561, 2018."
"""
import collections
import paddle.fluid as fluid
from parl.core.fluid import layers
from parl.utils import MAX_INT32
VTraceReturns = collections.namedtuple('VTraceReturns',
['vs', 'pg_advantages'])
def from_importance_weights(behaviour_actions_log_probs,
target_actions_log_probs,
discounts,
rewards,
values,
bootstrap_value,
clip_rho_threshold=1.0,
clip_pg_rho_threshold=1.0,
name='vtrace_from_logits'):
r"""V-trace for softmax policies.
Calculates V-trace actor critic targets for softmax polices as described in
"IMPALA: Scalable Distributed Deep-RL with
Importance Weighted Actor-Learner Architectures"
by Espeholt, Soyer, Munos et al.
Target policy refers to the policy we are interested in improving and
behaviour policy refers to the policy that generated the given
rewards and actions.
In the notation used throughout documentation and comments, T refers to the
time dimension ranging from 0 to T-1. B refers to the batch size and
NUM_ACTIONS refers to the number of actions.
Args:
behaviour_actions_log_probs: A float32 tensor of shape [T, B] of
log-probabilities of actions in behaviour policy.
target_policy_logits: A float32 tensor of shape [T, B] of
log-probabilities of actions in target policy.
discounts: A float32 tensor of shape [T, B] with the discount encountered
when following the behaviour policy.
rewards: A float32 tensor of shape [T, B] with the rewards generated by
following the behaviour policy.
values: A float32 tensor of shape [T, B] with the value function estimates
wrt. the target policy.
bootstrap_value: A float32 of shape [B] with the value function estimate at
time T.
clip_rho_threshold: A scalar float32 tensor with the clipping threshold for
importance weights (rho) when calculating the baseline targets (vs).
rho^bar in the paper.
clip_pg_rho_threshold: A scalar float32 tensor with the clipping threshold
on rho_s in \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
name: The name scope that all V-trace operations will be created in.
Returns:
A VTraceReturns namedtuple (vs, pg_advantages) where:
vs: A float32 tensor of shape [T, B]. Can be used as target to
train a baseline (V(x_t) - vs_t)^2.
pg_advantages: A float32 tensor of shape [T, B]. Can be used as the
advantage in the calculation of policy gradients.
"""
rank = len(behaviour_actions_log_probs.shape) # Usually 2.
assert len(target_actions_log_probs.shape) == rank
assert len(values.shape) == rank
assert len(bootstrap_value.shape) == (rank - 1)
assert len(discounts.shape) == rank
assert len(rewards.shape) == rank
# log importance sampling weights.
# V-trace performs operations on rhos in log-space for numerical stability.
log_rhos = target_actions_log_probs - behaviour_actions_log_probs
if clip_rho_threshold is not None:
clip_rho_threshold = layers.fill_constant([1], 'float32',
clip_rho_threshold)
if clip_pg_rho_threshold is not None:
clip_pg_rho_threshold = layers.fill_constant([1], 'float32',
clip_pg_rho_threshold)
rhos = layers.exp(log_rhos)
if clip_rho_threshold is not None:
clipped_rhos = layers.elementwise_min(rhos, clip_rho_threshold)
else:
clipped_rhos = rhos
constant_one = layers.fill_constant([1], 'float32', 1.0)
cs = layers.elementwise_min(rhos, constant_one)
# Append bootstrapped value to get [v1, ..., v_t+1]
values_1_t = layers.slice(values, axes=[0], starts=[1], ends=[MAX_INT32])
values_t_plus_1 = layers.concat(
[values_1_t, layers.unsqueeze(bootstrap_value, [0])], axis=0)
# \delta_s * V
deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)
vs_minus_v_xs = recursively_scan(discounts, cs, deltas)
# Add V(x_s) to get v_s.
vs = layers.elementwise_add(vs_minus_v_xs, values)
# Advantage for policy gradient.
vs_1_t = layers.slice(vs, axes=[0], starts=[1], ends=[MAX_INT32])
vs_t_plus_1 = layers.concat(
[vs_1_t, layers.unsqueeze(bootstrap_value, [0])], axis=0)
if clip_pg_rho_threshold is not None:
clipped_pg_rhos = layers.elementwise_min(rhos, clip_pg_rho_threshold)
else:
clipped_pg_rhos = rhos
pg_advantages = (
clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values))
# Make sure no gradients backpropagated through the returned values.
vs.stop_gradient = True
pg_advantages.stop_gradient = True
return VTraceReturns(vs=vs, pg_advantages=pg_advantages)
def recursively_scan(discounts, cs, deltas):
""" Recursively calculate vs_minus_v_xs according to following equation:
vs_minus_v_xs(t) = deltas(t) + discounts(t) * cs(t) * vs_minus_v_xs(t + 1)
Args:
discounts: A float32 tensor of shape [T, B] with discounts encountered when
following the behaviour policy.
cs: A float32 tensor of shape [T, B], which corresponding to $c_s$ in the
origin paper.
deltas: A float32 tensor of shape [T, B], which corresponding to
$\delta_s * V$ in the origin paper.
Returns:
vs_minus_v_xs: A float32 tensor of shape [T, B], which corresponding to
$v_s - V(x_s)$ in the origin paper.
"""
# All sequences are reversed, computation starts from the back.
reverse_discounts = layers.reverse(x=discounts, axis=[0])
reverse_cs = layers.reverse(x=cs, axis=[0])
reverse_deltas = layers.reverse(x=deltas, axis=[0])
static_while = layers.StaticRNN()
# init: shape [B]
init = layers.fill_constant_batch_size_like(
discounts, shape=[1], dtype='float32', value=0.0, input_dim_idx=1)
with static_while.step():
discount_t = static_while.step_input(reverse_discounts)
c_t = static_while.step_input(reverse_cs)
delta_t = static_while.step_input(reverse_deltas)
vs_minus_v_xs_t_plus_1 = static_while.memory(init=init)
vs_minus_v_xs_t = delta_t + discount_t * c_t * vs_minus_v_xs_t_plus_1
static_while.update_memory(vs_minus_v_xs_t_plus_1, vs_minus_v_xs_t)
static_while.step_output(vs_minus_v_xs_t)
vs_minus_v_xs = static_while()
# Reverse the results back to original order.
vs_minus_v_xs = layers.reverse(vs_minus_v_xs, [0])
return vs_minus_v_xs
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
import paddle.fluid as fluid
from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid import layers
from parl.utils.deprecation import deprecated
__all__ = ['PolicyGradient']
class PolicyGradient(Algorithm):
def __init__(self, model, hyperparas=None, lr=None):
""" Policy Gradient algorithm
Args:
model (parl.Model): forward network of the policy.
hyperparas (dict): (deprecated) dict of hyper parameters.
lr (float): learning rate of the policy model.
"""
self.model = model
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.PolicyGradient` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.lr = hyperparas['lr']
else:
assert isinstance(lr, float)
self.lr = lr
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='predict')
def define_predict(self, obs):
""" use policy model self.model to predict the action probability
"""
return self.predict(obs)
def predict(self, obs):
""" use policy model self.model to predict the action probability
"""
return self.model.policy(obs)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='learn')
def define_learn(self, obs, action, reward):
""" update policy model self.model with policy gradient algorithm
"""
return self.learn(obs, action, reward)
def learn(self, obs, action, reward):
""" update policy model self.model with policy gradient algorithm
"""
act_prob = self.model.policy(obs)
log_prob = layers.cross_entropy(act_prob, action)
cost = log_prob * reward
cost = layers.reduce_mean(cost)
optimizer = fluid.optimizer.Adam(self.lr)
optimizer.minimize(cost)
return cost
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
import numpy as np
from copy import deepcopy
from paddle import fluid
from parl.core.fluid import layers
from parl.core.fluid.algorithm import Algorithm
from parl.utils.deprecation import deprecated
__all__ = ['PPO']
class PPO(Algorithm):
def __init__(self,
model,
hyperparas=None,
act_dim=None,
policy_lr=None,
value_lr=None,
epsilon=0.2):
""" PPO algorithm
Args:
model (parl.Model): model defining forward network of policy and value.
hyperparas (dict): (deprecated) dict of hyper parameters.
act_dim (float): dimension of the action space.
policy_lr (float): learning rate of the policy model.
value_lr (float): learning rate of the value model.
epsilon (float): epsilon used in the CLIP loss (default 0.2).
"""
self.model = model
# Used to calculate probability of action in old policy
self.old_policy_model = deepcopy(model.policy_model)
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithms.PPO` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.act_dim = hyperparas['act_dim']
self.policy_lr = hyperparas['policy_lr']
self.value_lr = hyperparas['value_lr']
if 'epsilon' in hyperparas:
self.epsilon = hyperparas['epsilon']
else:
self.epsilon = 0.2 # default
else:
assert isinstance(act_dim, int)
assert isinstance(policy_lr, float)
assert isinstance(value_lr, float)
assert isinstance(epsilon, float)
self.act_dim = act_dim
self.policy_lr = policy_lr
self.value_lr = value_lr
self.epsilon = epsilon
def _calc_logprob(self, actions, means, logvars):
""" Calculate log probabilities of actions, when given means and logvars
of normal distribution.
The constant sqrt(2 * pi) is omitted, which will be eliminated in later.
Args:
actions: shape (batch_size, act_dim)
means: shape (batch_size, act_dim)
logvars: shape (act_dim)
Returns:
logprob: shape (batch_size)
"""
exp_item = layers.elementwise_div(
layers.square(actions - means), layers.exp(logvars), axis=1)
exp_item = -0.5 * layers.reduce_sum(exp_item, dim=1)
vars_item = -0.5 * layers.reduce_sum(logvars)
logprob = exp_item + vars_item
return logprob
def _calc_kl(self, means, logvars, old_means, old_logvars):
""" Calculate KL divergence between old and new distributions
See: https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback.E2.80.93Leibler_divergence
Args:
means: shape (batch_size, act_dim)
logvars: shape (act_dim)
old_means: shape (batch_size, act_dim)
old_logvars: shape (act_dim)
Returns:
kl: shape (batch_size)
"""
log_det_cov_old = layers.reduce_sum(old_logvars)
log_det_cov_new = layers.reduce_sum(logvars)
tr_old_new = layers.reduce_sum(layers.exp(old_logvars - logvars))
kl = 0.5 * (layers.reduce_sum(
layers.square(means - old_means) / layers.exp(logvars), dim=1) + (
log_det_cov_new - log_det_cov_old) + tr_old_new - self.act_dim)
return kl
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='predict')
def define_predict(self, obs):
""" Use policy model of self.model to predict means and logvars of actions
"""
return self.predict(obs)
def predict(self, obs):
""" Use the policy model of self.model to predict means and logvars of actions
"""
means, logvars = self.model.policy(obs)
return means
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='sample')
def define_sample(self, obs):
""" Use the policy model of self.model to sample actions
"""
return self.sample(obs)
def sample(self, obs):
""" Use the policy model of self.model to sample actions
"""
sampled_act = self.model.policy_sample(obs)
return sampled_act
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='policy_learn')
def define_policy_learn(self, obs, actions, advantages, beta=None):
""" Learn policy model with:
1. CLIP loss: Clipped Surrogate Objective
2. KLPEN loss: Adaptive KL Penalty Objective
See: https://arxiv.org/pdf/1707.02286.pdf
Args:
obs: Tensor, (batch_size, obs_dim)
actions: Tensor, (batch_size, act_dim)
advantages: Tensor (batch_size, )
beta: Tensor (1) or None
if None, use CLIP Loss; else, use KLPEN loss.
"""
return self.policy_learn(obs, actions, advantages, beta)
def policy_learn(self, obs, actions, advantages, beta=None):
""" Learn policy model with:
1. CLIP loss: Clipped Surrogate Objective
2. KLPEN loss: Adaptive KL Penalty Objective
See: https://arxiv.org/pdf/1707.02286.pdf
Args:
obs: Tensor, (batch_size, obs_dim)
actions: Tensor, (batch_size, act_dim)
advantages: Tensor (batch_size, )
beta: Tensor (1) or None
if None, use CLIP Loss; else, use KLPEN loss.
"""
old_means, old_logvars = self.old_policy_model.policy(obs)
old_means.stop_gradient = True
old_logvars.stop_gradient = True
old_logprob = self._calc_logprob(actions, old_means, old_logvars)
means, logvars = self.model.policy(obs)
logprob = self._calc_logprob(actions, means, logvars)
kl = self._calc_kl(means, logvars, old_means, old_logvars)
kl = layers.reduce_mean(kl)
if beta is None: # Clipped Surrogate Objective
pg_ratio = layers.exp(logprob - old_logprob)
clipped_pg_ratio = layers.clip(pg_ratio, 1 - self.epsilon,
1 + self.epsilon)
surrogate_loss = layers.elementwise_min(
advantages * pg_ratio, advantages * clipped_pg_ratio)
loss = 0 - layers.reduce_mean(surrogate_loss)
else: # Adaptive KL Penalty Objective
# policy gradient loss
loss1 = 0 - layers.reduce_mean(
advantages * layers.exp(logprob - old_logprob))
# adaptive kl loss
loss2 = kl * beta
loss = loss1 + loss2
optimizer = fluid.optimizer.AdamOptimizer(self.policy_lr)
optimizer.minimize(loss)
return loss, kl
@deprecated(
deprecated_in='1.2',
removed_in='1.3',
replace_function='value_predict')
def define_value_predict(self, obs):
""" Use value model of self.model to predict value of obs
"""
return self.value_predict(obs)
def value_predict(self, obs):
""" Use value model of self.model to predict value of obs
"""
return self.model.value(obs)
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='value_learn')
def define_value_learn(self, obs, val):
""" Learn value model with square error cost
"""
return self.value_learn(obs, val)
def value_learn(self, obs, val):
""" Learn the value model with square error cost
"""
predict_val = self.model.value(obs)
loss = layers.square_error_cost(predict_val, val)
loss = layers.reduce_mean(loss)
optimizer = fluid.optimizer.AdamOptimizer(self.value_lr)
optimizer.minimize(loss)
return loss
def sync_old_policy(self, gpu_id=None):
""" Synchronize weights of self.model.policy_model to self.old_policy_model
"""
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `sync_old_policy` function in `parl.Algorithms.PPO` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.model.policy_model.sync_weights_to(self.old_policy_model)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.core.agent_base import *
from parl.core.model_base import *
from parl.core.algorithm_base import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class AgentBase(object):
"""`AgentBase` is the base class of the `parl.Agent` in different frameworks.
`parl.Agent` is responsible for the general data flow outside the algorithm.
"""
def __init__(self, algorithm):
"""
Args:
algorithm (`AlgorithmBase`): an instance of `AlgorithmBase`
"""
self.algorithm = algorithm
def get_weights(self, model_ids=None):
"""Get weights of the agent.
If `model_ids` is not None, will only return weights of
models whose model_id are in `model_ids`.
Note:
`ModelBase` in list, tuple and dict will be included. But `ModelBase` in
nested list, tuple and dict won't be included.
Args:
model_ids (List/Set): list/set of model_id, will only return weights of models
whiose model_id in the `model_ids`.
Returns:
(Dict): Dict of weights ({attribute name: numpy array/List/Dict})
"""
return self.algorithm.get_weights(model_ids=model_ids)
def set_weights(self, weights, model_ids=None):
"""Set weights of the agent with given weights.
If `model_ids` is not None, will only set weights of
models whose model_id are in `model_ids`.
Note:
`ModelBase` in list, tuple and dict will be included. But `ModelBase` in
nested list, tuple and dict won't be included.
Args:
weights (Dict): Dict of weights ({attribute name: numpy array/List/Dict})
model_ids (List/Set): list/set of model_id, will only set weights of models
whiose model_id in the `model_ids`.
"""
self.algorithm.set_weights(weights, model_ids=model_ids)
def get_model_ids(self):
"""Get all model ids of the self.algorithm in the agent.
Returns:
List of model_id
"""
return self.algorithm.get_model_ids()
@property
def model_ids(self):
return self.get_model_ids()
def learn(self, *args, **kwargs):
"""The training interface for Agent.
This function will usually do the following things:
1. Accept numpy data as input;
2. Feed numpy data or onvert numpy data to tensor (optional);
3. Call learn function in `Algorithm`.
"""
raise NotImplementedError
def predict(self, *args, **kwargs):
"""Predict the action when given the observation of the enviroment.
In general, this function is used in test process.
This function will usually do the following things:
1. Accept numpy data as input;
2. Feed numpy data or onvert numpy data to tensor (optional);
3. Call predict function in `Algorithm`.
"""
raise NotImplementedError
def sample(self, *args, **kwargs):
"""Sample the action when given the observation of the enviroment.
In general, this function is used in train process.
This function will usually do the following things:
1. Accept numpy data as input;
2. Feed numpy data or onvert numpy data to tensor (optional);
3. Call predict or sample function in `Algorithm`;
4. Add sampling operation in numpy level. (unnecessary if sampling operation have done in `Algorithm`).
"""
raise NotImplementedError
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.core.model_base import ModelBase
class AlgorithmBase(object):
"""`AlgorithmBase` is the base class of the `parl.Algorithm` in different
frameworks.
This base class mainly do the following things:
1. Implements APIs to set or get weights of all `ModelBase` in self.__dict__;
2. Defines common APIs that `parl.Algorithm` should implement in different frameworks.
"""
def __init__(self):
pass
def get_weights(self, model_ids=None):
"""Get weights of all `ModelBase` in self.__dict__.
If `model_ids` is not None, will only return weights of
models whose model_id are in `model_ids`.
Note:
`ModelBase` in list, tuple and dict will be included. But `ModelBase` in
nested list, tuple and dict won't be included.
Args:
model_ids (List/Set): list/set of model_id, will only return weights of models
whiose model_id in the `model_ids`.
Returns:
Dict of weights ({attribute name: numpy array/List/Dict})
"""
if model_ids is not None:
assert isinstance(model_ids, (list, set))
model_ids = set(model_ids)
model_weights = {}
for key in self.__dict__.keys():
value = getattr(self, key)
if isinstance(value, ModelBase):
if model_ids is None or value.model_id in model_ids:
model_weights[key] = value.get_weights()
elif isinstance(value, list) or isinstance(value, tuple):
weights_list = []
for x in value:
if isinstance(x, ModelBase):
if model_ids is None or x.model_id in model_ids:
weights_list.append(x.get_weights())
if weights_list:
model_weights[key] = weights_list
elif isinstance(value, dict):
weights_dict = {}
for sub_k, sub_v in value.items():
if isinstance(sub_v, ModelBase):
if model_ids is None or sub_v.model_id in model_ids:
weights_dict[sub_k] = sub_v.get_weights()
if weights_dict:
model_weights[key] = weights_dict
return model_weights
def set_weights(self, weights, model_ids=None):
"""Set weights of all `ModelBase` in self.__dict__.
If `model_ids` is not None, will only set weights of
models whose model_id are in `model_ids`.
Note:
`ModelBase` in list, tuple and dict will be included. But `ModelBase` in
nested list, tuple and dict won't be included.
Args:
weights (Dict): Dict of weights ({attribute name: numpy array/List/Dict})
model_ids (List/Set): list/set of model_id, will only set weights of models
whiose model_id in the `model_ids`.
"""
assert isinstance(weights, dict)
if model_ids is not None:
assert isinstance(model_ids, (list, set))
model_ids = set(model_ids)
for key in self.__dict__.keys():
value = getattr(self, key)
if isinstance(value, ModelBase):
if model_ids is None or value.model_id in model_ids:
assert key in weights, "weights is inconsistent with current algorithm and given model_ids."
value.set_weights(weights[key])
elif isinstance(value, list) or isinstance(value, tuple):
model_list = []
for x in value:
if isinstance(x, ModelBase):
if model_ids is None or x.model_id in model_ids:
model_list.append(x)
if model_list:
assert key in weights and len(model_list) == len(weights[key]), \
"weights is inconsistent with current algorithm and given model_ids."
for i, model in enumerate(model_list):
model.set_weights(weights[key][i])
elif isinstance(value, dict):
model_dict = {}
for sub_k, sub_v in value.items():
if isinstance(sub_v, ModelBase):
if model_ids is None or sub_v.model_id in model_ids:
model_dict[sub_k] = sub_v
if model_dict:
assert key in weights and set(model_dict.keys()) == set(weights[key].keys()), \
"weights is inconsistent with current algorithm and given model_ids."
for sub_k, model in model_dict.items():
model.set_weights(weights[key][sub_k])
def get_model_ids(self):
"""Get model_id of all `ModelBase` in self.__dict__.
Note:
`ModelBase` in list, tuple and dict will be included. But `ModelBase` in
nested list, tuple and dict won't be included.
Returns:
Set of model_id
"""
model_ids = set([])
for key in self.__dict__.keys():
value = getattr(self, key)
if isinstance(value, ModelBase):
model_ids.add(value.model_id)
elif isinstance(value, list) or isinstance(value, tuple):
for x in value:
if isinstance(x, ModelBase):
model_ids.add(x.model_id)
elif isinstance(value, dict):
for sub_k, sub_v in value.items():
if isinstance(sub_v, ModelBase):
model_ids.add(sub_v.model_id)
return model_ids
@property
def model_ids(self):
return self.get_model_ids()
def learn(self, *args, **kwargs):
""" define learning process, such as how to optimize the model.
"""
raise NotImplementedError
def predict(self, *args, **kwargs):
""" define predicting process, such as using policy model to predict actions when given observations.
"""
raise NotImplementedError
def sample(self, *args, **kwargs):
""" define sampling process, such as using policy model to sample actions when given observations.
"""
raise NotImplementedError
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.core.fluid.model import *
from parl.core.fluid.algorithm import *
from parl.core.fluid.agent import *
from . import layers
from . import plutils
from . import policy_distribution
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
import paddle.fluid as fluid
from parl.core.fluid import layers
from parl.utils.deprecation import deprecated
from parl.core.agent_base import AgentBase
from parl.core.fluid.algorithm import Algorithm
from parl.utils import machine_info
__all__ = ['Agent']
class Agent(AgentBase):
def __init__(self, algorithm, gpu_id=None):
"""Build program and run initialization for default_startup_program
Args:
algorithm (parl.Algorithm): instance of `parl.core.fluid.algorithm.Algorithm`
"""
if gpu_id is not None:
warnings.warn(
"the `gpu_id` argument of `__init__` function in `parl.Agent` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
assert isinstance(algorithm, Algorithm)
super(Agent, self).__init__(algorithm)
# alias for self.algorithm
# use self.algorithm is suggested
self.alg = algorithm
self.gpu_id = 0 if machine_info.is_gpu_available() else -1
self.build_program()
self.place = fluid.CUDAPlace(
0) if machine_info.is_gpu_available() else fluid.CPUPlace()
self.fluid_executor = fluid.Executor(self.place)
self.fluid_executor.run(fluid.default_startup_program())
def build_program(self):
"""Build leran/predict/sample program here with the
learn/predict/sample function defined in algorithm.
Note:
It's unnecessary to call this function explictly since
it will be called automatically in the initialization function.
To build the program, you may need to do the following:
a. Create a new program of fluid with program guard;
b. Define data input layers;
c. Pass the data variable defined in step b to learn/predict/sample of algorithm;
"""
raise NotImplementedError
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='get_weights')
def get_params(self):
""" Get parameters of self.algorithm
Returns:
List of numpy array.
"""
return self.algorithm.get_params()
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='set_weights')
def set_params(self, params):
"""Set parameters of self.algorithm
Args:
params: List of numpy array.
"""
self.algorithm.set_params(params)
def learn(self, *args, **kwargs):
"""The training interface for Agent.
This function will usually do the following things:
1. Accept numpy data as input;
2. Feed numpy data;
3. Run learn program defined in `build_program`.
"""
raise NotImplementedError
def predict(self, *args, **kwargs):
"""Predict the action when given the observation of the enviroment.
In general, this function is used in test process.
This function will usually do the following things:
1. Accept numpy data as input;
2. Feed numpy data;
3. Run predict program defined in `build_program`.
"""
raise NotImplementedError
def sample(self, *args, **kwargs):
"""Sample the action when given the observation of the enviroment.
In general, this function is used in train process.
This function will usually do the following things:
1. Accept numpy data as input;
2. Feed numpy data;
3. Run predict/sample program defined in `build_program`.
4. Add sampling operation in numpy level. (unnecessary if sampling operation have done in `Algorithm`).
"""
raise NotImplementedError
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
warnings.simplefilter('default')
from parl.core.algorithm_base import AlgorithmBase
from parl.core.fluid.model import Model
from parl.utils.deprecation import deprecated
__all__ = ['Algorithm']
class Algorithm(AlgorithmBase):
"""Algorithm defines the way how we update the model.
To implement a new algorithm, you may need implement the learn/predict/sample functions.
Before creating a customized algorithm, please do check algorithms of PARL.
Most common used algorithms like DQN/DDPG/PPO/A3C/IMPALA have been provided in `parl.algorithms`,
go and have a try.
"""
def __init__(self, model=None, hyperparas=None):
if model is not None:
warnings.warn(
"the `model` argument of `__init__` function in `parl.Algorithm` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
assert isinstance(model, Model)
self.model = model
if hyperparas is not None:
warnings.warn(
"the `hyperparas` argument of `__init__` function in `parl.Algorithm` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
self.hp = hyperparas
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='get_weights')
def get_params(self):
""" Get parameters of self.model
Returns:
List of numpy array.
"""
return self.model.get_params()
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='set_weights')
def set_params(self, params):
""" Set parameters of self.model
Args:
params: List of numpy array.
"""
self.model.set_params(params)
def learn(self, *args, **kwargs):
""" define learning process, such as how to optimize the model.
"""
raise NotImplementedError
def predict(self, *args, **kwargs):
""" define predicting process, such as using policy model to predict actions when given observations.
"""
raise NotImplementedError
def sample(self, *args, **kwargs):
""" define sampling process, such as using policy model to sample actions when given observations.
"""
raise NotImplementedError
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file wraps Fluid layers that have parameters to support parameter sharing.
For other layers that don't have parameters, we simply copy them to this namespace.
"""
from paddle.fluid.layers import *
from parl.core.fluid.layers.layer_wrappers import *
......@@ -40,7 +40,7 @@ from paddle.fluid.executor import _fetch_var
from paddle.fluid.framework import Variable
from paddle.fluid.layers import *
from paddle.fluid.param_attr import ParamAttr
from parl.layers.attr_holder import AttrHolder
from parl.core.fluid.layers.attr_holder import AttrHolder
def update_attr_name(name, default_name, attr, is_bias):
......
......@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import parl.layers as layers
import parl.core.fluid.layers as layers
import unittest
from parl.framework.model_base import Network
from parl.core.fluid.model import Model
class MyNetWork(Network):
class MyNetWork(Model):
def __init__(self):
self.fc1 = layers.fc(100)
self.fc2 = layers.fc(100)
......
......@@ -14,12 +14,12 @@
import numpy as np
import paddle.fluid as fluid
import parl.layers as layers
import unittest
from parl.framework.model_base import Network
from parl import layers
from parl.core.fluid.model import Model
class MyNetWork(Network):
class MyNetWork(Model):
def __init__(self):
self.fc1 = layers.fc(64, bias_attr=False)
self.fc2 = layers.fc(64, bias_attr=False)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import paddle.fluid as fluid
from parl.core.fluid.layers.layer_wrappers import LayerFunc
from parl.core.fluid.plutils import *
from parl.core.model_base import ModelBase
from parl.utils.deprecation import deprecated
from parl.utils import machine_info
__all__ = ['Model']
class Model(ModelBase):
"""A `Model`, a collection of `parl.layers`, is owned by an `Algorithm`.
It implements the entire network (forward part) to solve a specific problem.
`Model` can also use deepcopy way to construct target model, which has the same structure as initial model.
Note that only the model definition is copied here. To copy the parameters from the current model
to the target model, you must explicitly use `sync_weights_to` function after the program is initialized.
NOTE:
You need initialize start up program before calling `sync_weights_to` API.
Here is an example:
.. code-block:: python
import parl.layers as layers
import parl.Model as Model
class MLPModel(Model):
def __init__(self):
self.fc = layers.fc(size=64)
def policy(self, obs):
out = self.fc(obs)
return out
model = MLPModel()
target_model = deepcopy(model) # automatically create new unique parameters names for target_model.fc
# build program
x = layers.data(name='x', shape=[100], dtype="float32")
y1 = model.policy(x)
y2 = target_model.policy(x)
...
# Need initialize program before calling sync_weights_to
fluid_executor.run(fluid.default_startup_program())
...
# synchronize parameters
model.sync_weights_to(target_model)
"""
@deprecated(
deprecated_in='1.2',
removed_in='1.3',
replace_function='sync_weights_to')
def sync_params_to(self,
target_net,
gpu_id=None,
decay=0.0,
share_vars_parallel_executor=None):
"""Synchronize parameters in the model to another model (target_net).
target_net_weights = decay * target_net_weights + (1 - decay) * source_net_weights
Args:
target_net (`Model`): `Model` object deepcopy from source `Model`.
decay (float): Float. The decay to use.
share_vars_parallel_executor (fluid.ParallelExecutor): if not None, will use fluid.ParallelExecutor
to run program instead of fluid.Executor
"""
self.sync_weights_to(
other_model=target_net,
decay=decay,
share_vars_parallel_executor=share_vars_parallel_executor)
def sync_weights_to(self,
other_model,
decay=0.0,
share_vars_parallel_executor=None):
"""Synchronize weights in the model to another model.
To speed up the synchronizing process, will create a program implictly to finish the process. And will
also cache the program to avoid creating program repeatedly.
other_model_weights = decay * other_model_weights + (1 - decay) * current_model_weights
Args:
other_model (`parl.Model`): object instanced from the same `parl.Model` class with current model.
decay (float): Float. The decay to use.
share_vars_parallel_executor (fluid.ParallelExecutor): if not None, will use fluid.ParallelExecutor
to run program instead of fluid.Executor
"""
args_hash_id = hashlib.md5('{}_{}'.format(
id(other_model), decay).encode('utf-8')).hexdigest()
has_cached = False
try:
if self._cached_id == args_hash_id:
has_cached = True
except AttributeError:
has_cached = False
if not has_cached:
# Can not run _cached program, need create a new program
self._cached_id = args_hash_id
assert not other_model is self, "cannot copy between identical model"
assert isinstance(other_model, Model)
assert self.__class__.__name__ == other_model.__class__.__name__, \
"must be the same class for params syncing!"
assert (decay >= 0 and decay <= 1)
param_pairs = self._get_parameter_pairs(self, other_model)
self._cached_sync_weights_program = fluid.Program()
with fluid.program_guard(self._cached_sync_weights_program):
for (src_var_name, target_var_name) in param_pairs:
src_var = fetch_framework_var(src_var_name)
target_var = fetch_framework_var(target_var_name)
fluid.layers.assign(
decay * target_var + (1 - decay) * src_var, target_var)
if share_vars_parallel_executor is None:
# use fluid.Executor
place = fluid.CUDAPlace(0) if machine_info.is_gpu_available(
) else fluid.CPUPlace()
self._cached_fluid_executor = fluid.Executor(place)
else:
# use fluid.ParallelExecutor
# specify strategy to make ParallelExecutor run faster
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = True
exec_strategy.num_threads = 4
build_strategy = fluid.BuildStrategy()
build_strategy.remove_unnecessary_lock = True
with fluid.scope_guard(fluid.global_scope().new_scope()):
self._cached_fluid_executor = fluid.ParallelExecutor(
use_cuda=machine_info.is_gpu_available(),
main_program=self._cached_sync_weights_program,
share_vars_from=share_vars_parallel_executor,
exec_strategy=exec_strategy,
build_strategy=build_strategy,
)
if share_vars_parallel_executor is None:
self._cached_fluid_executor.run(self._cached_sync_weights_program)
else:
self._cached_fluid_executor.run(fetch_list=[])
@property
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='parameters')
def parameter_names(self):
"""Get param_attr names of all parameters in the Model.
Only parameter created by parl.layers included.
The order of parameter names will be consistent between
different instances of same `Model`.
Returns:
list of string, param_attr names of all parameters
"""
return self.parameters()
def parameters(self):
"""Get param_attr names of all parameters in the Model.
Only parameter created by parl.layers included.
The order of parameter names will be consistent between
different instances of same `Model`.
Returns:
list of string, param_attr names of all parameters
"""
try:
return self._parameter_names
except AttributeError:
self._parameter_names = self._get_parameter_names(self)
return self._parameter_names
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='get_weights')
def get_params(self):
"""Get numpy arrays of parameters in the model.
Returns:
List of numpy array.
"""
return self.get_weights()
@deprecated(
deprecated_in='1.2', removed_in='1.3', replace_function='set_weights')
def set_params(self, params, gpu_id=None):
"""Set parameters in the model with params.
Args:
params (List): List of numpy array.
"""
self.set_weights(weights=params)
def get_weights(self):
"""Get numpy arrays of weights in the model.
Returns:
List of numpy array.
"""
weights = []
for param_name in self.parameters():
weight = fetch_value(param_name)
weights.append(weight)
return weights
def set_weights(self, weights):
"""Set weights in the model with given `weights`.
Args:
weights (List): List of numpy array.
"""
assert len(weights) == len(self.parameters()), \
'size of input weights should be same as weights number of current model'
for (param_name, weight) in list(zip(self.parameters(), weights)):
set_value(param_name, weight)
def _get_parameter_names(self, obj):
""" Recursively get parameter names in obj,
Args:
obj (`Model`/`LayerFunc`/list/tuple/dict): input object
Returns:
parameter_names (List): all parameter names in obj
"""
parameter_names = []
for attr in sorted(obj.__dict__.keys()):
val = getattr(obj, attr)
if isinstance(val, Model):
parameter_names.extend(self._get_parameter_names(val))
elif isinstance(val, LayerFunc):
for attr in val.attr_holder.sorted():
if attr:
parameter_names.append(attr.name)
elif isinstance(val, tuple) or isinstance(val, list):
for x in val:
parameter_names.extend(self._get_parameter_names(x))
elif isinstance(val, dict):
for x in list(val.values()):
parameter_names.extend(self._get_parameter_names(x))
else:
# for any other type, won't be handled. E.g. set
pass
return parameter_names
def _get_parameter_pairs(self, src, target):
""" Recursively gets parameters in source model and
corresponding parameters in target model.
Args:
src (`Model`/`LayerFunc`/list/tuple/dict): source object
target (`Model`/`LayerFunc`/list/tuple/dict): target object
Returns:
param_pairs (list of tuple): all parameter names in source model
and corresponding parameter names in
target model.
"""
param_pairs = []
if isinstance(src, Model):
for attr in src.__dict__:
if not attr in target.__dict__:
continue
src_var = getattr(src, attr)
target_var = getattr(target, attr)
param_pairs.extend(
self._get_parameter_pairs(src_var, target_var))
elif isinstance(src, LayerFunc):
src_attrs = src.attr_holder.sorted()
target_attrs = target.attr_holder.sorted()
assert len(src_attrs) == len(target_attrs), \
"number of ParamAttr between source layer and target layer should be same."
for (src_attr, target_attr) in zip(src_attrs, target_attrs):
if src_attr:
assert target_attr, "ParamAttr between source layer and target layer is inconsistent."
param_pairs.append((src_attr.name, target_attr.name))
elif isinstance(src, tuple) or isinstance(src, list):
for src_var, target_var in zip(src, target):
param_pairs.extend(
self._get_parameter_pairs(src_var, target_var))
elif isinstance(src, dict):
for k in src.keys():
assert k in target
param_pairs.extend(
self._get_parameter_pairs(src[k], target[k]))
else:
# for any other type, won't be handled. E.g. set
pass
return param_pairs
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.core.fluid.plutils.common import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common functions of PARL framework
"""
import paddle.fluid as fluid
from paddle.fluid.executor import _fetch_var
from parl.utils import machine_info
__all__ = ['fetch_framework_var', 'fetch_value', 'set_value', 'inverse']
def fetch_framework_var(attr_name):
""" Fetch framework variable according given attr_name.
Return a new reusing variable through create_parameter way
Args:
attr_name: string, attr name of parameter
Returns:
framework_var: framework.Varialbe
"""
scope = fluid.executor.global_scope()
core_var = scope.find_var(attr_name)
shape = core_var.get_tensor().shape()
framework_var = fluid.layers.create_parameter(
shape=shape, dtype='float32', attr=fluid.ParamAttr(name=attr_name))
return framework_var
def fetch_value(attr_name):
""" Given name of ParamAttr, fetch numpy value of the parameter in global_scope
Args:
attr_name: ParamAttr name of parameter
Returns:
numpy.ndarray
"""
return _fetch_var(attr_name, return_numpy=True)
def set_value(attr_name, value):
""" Given name of ParamAttr, set numpy value to the parameter in global_scope
Args:
attr_name: ParamAttr name of parameter
value: numpy array
"""
place = fluid.CUDAPlace(
0) if machine_info.is_gpu_available() else fluid.CPUPlace()
var = _fetch_var(attr_name, return_numpy=False)
var.set(value, place)
def inverse(x):
""" Inverse 0/1 variable
Args:
x: variable with float32 dtype
Returns:
inverse_x: variable with float32 dtype
"""
inverse_x = -1.0 * x + 1.0
return inverse_x
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.core.fluid import layers
__all__ = ['PolicyDistribution', 'CategoricalDistribution']
class PolicyDistribution(object):
def sample(self):
"""Sampling from the policy distribution."""
raise NotImplementedError
def entropy(self):
"""The entropy of the policy distribution."""
raise NotImplementedError
def kl(self, other):
"""The KL-divergence between self policy distributions and other."""
raise NotImplementedError
def logp(self, actions):
"""The log-probabilities of the actions in this policy distribution."""
raise NotImplementedError
class CategoricalDistribution(PolicyDistribution):
"""Categorical distribution for discrete action spaces."""
def __init__(self, logits):
"""
Args:
logits: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS] of unnormalized policy logits
"""
assert len(logits.shape) == 2
self.logits = logits
def sample(self):
"""
Returns:
sample_action: An int64 tensor with shape [BATCH_SIZE] of multinomial sampling ids.
Each value in sample_action is in [0, NUM_ACTIOINS - 1]
"""
probs = layers.softmax(self.logits)
sample_actions = layers.sampling_id(probs)
return sample_actions
def entropy(self):
"""
Returns:
entropy: A float32 tensor with shape [BATCH_SIZE] of entropy of self policy distribution.
"""
logits = self.logits - layers.reduce_max(self.logits, dim=1)
e_logits = layers.exp(logits)
z = layers.reduce_sum(e_logits, dim=1)
prob = e_logits / z
entropy = -1.0 * layers.reduce_sum(
prob * (logits - layers.log(z)), dim=1)
return entropy
def logp(self, actions, eps=1e-6):
"""
Args:
actions: An int64 tensor with shape [BATCH_SIZE]
eps: A small float constant that avoids underflows when computing the log probability
Returns:
actions_log_prob: A float32 tensor with shape [BATCH_SIZE]
"""
assert len(actions.shape) == 1
logits = self.logits - layers.reduce_max(self.logits, dim=1)
e_logits = layers.exp(logits)
z = layers.reduce_sum(e_logits, dim=1)
prob = e_logits / z
actions = layers.unsqueeze(actions, axes=[1])
actions_onehot = layers.one_hot(actions, prob.shape[1])
actions_onehot = layers.cast(actions_onehot, dtype='float32')
actions_prob = layers.reduce_sum(prob * actions_onehot, dim=1)
actions_prob = actions_prob + eps
actions_log_prob = layers.log(actions_prob)
return actions_log_prob
def kl(self, other):
"""
Args:
other: object of CategoricalDistribution
Returns:
kl: A float32 tensor with shape [BATCH_SIZE]
"""
assert isinstance(other, CategoricalDistribution)
logits = self.logits - layers.reduce_max(self.logits, dim=1)
other_logits = other.logits - layers.reduce_max(other.logits, dim=1)
e_logits = layers.exp(logits)
other_e_logits = layers.exp(other_logits)
z = layers.reduce_sum(e_logits, dim=1)
other_z = layers.reduce_sum(other_e_logits, dim=1)
prob = e_logits / z
kl = layers.reduce_sum(
prob *
(logits - layers.log(z) - other_logits + layers.log(other_z)),
dim=1)
return kl
......@@ -13,12 +13,12 @@
# limitations under the License.
import numpy as np
import parl.layers as layers
import unittest
from paddle import fluid
from parl.framework.agent_base import Agent
from parl.framework.algorithm_base import Algorithm
from parl.framework.model_base import Model
from parl import layers
from parl.core.fluid.agent import Agent
from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid.model import Model
from parl.utils.machine_info import get_gpu_count
......@@ -34,10 +34,10 @@ class TestModel(Model):
class TestAlgorithm(Algorithm):
def __init__(self, model, hyperparas=None):
super(TestAlgorithm, self).__init__(model, hyperparas)
def __init__(self, model):
self.model = model
def define_predict(self, obs):
def predict(self, obs):
return self.model.policy(obs)
......@@ -49,7 +49,7 @@ class TestAgent(Agent):
self.predict_program = fluid.Program()
with fluid.program_guard(self.predict_program):
obs = layers.data(name='obs', shape=[10], dtype='float32')
output = self.alg.define_predict(obs)
output = self.algorithm.predict(obs)
self.predict_output = [output]
def predict(self, obs):
......@@ -65,19 +65,13 @@ class AgentBaseTest(unittest.TestCase):
self.model = TestModel()
self.algorithm = TestAlgorithm(self.model)
def test_agent_with_gpu(self):
def test_agent(self):
if get_gpu_count() > 0:
agent = TestAgent(self.algorithm, gpu_id=0)
agent = TestAgent(self.algorithm)
obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs)
self.assertIsNotNone(output_np)
def test_agent_with_cpu(self):
agent = TestAgent(self.algorithm, gpu_id=-1)
obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs)
self.assertIsNotNone(output_np)
if __name__ == '__main__':
unittest.main()
......@@ -14,13 +14,13 @@
import numpy as np
import paddle.fluid as fluid
import parl.layers as layers
import parl.core.fluid.layers as layers
import unittest
from copy import deepcopy
from paddle.fluid import ParamAttr
from parl.framework.model_base import Model
from parl.core.fluid.model import Model
from parl.utils import get_gpu_count
from parl.plutils import fetch_value
from parl.core.fluid.plutils import fetch_value
class TestModel(Model):
......@@ -93,13 +93,11 @@ class ModelBaseTest(unittest.TestCase):
gpu_count = get_gpu_count()
if gpu_count > 0:
place = fluid.CUDAPlace(0)
self.gpu_id = 0
else:
place = fluid.CPUPlace()
self.gpu_id = -1
self.executor = fluid.Executor(place)
def test_network_copy(self):
def test_model_copy(self):
self.assertNotEqual(self.model.fc1.param_name,
self.target_model.fc1.param_name)
self.assertNotEqual(self.model.fc1.bias_name,
......@@ -115,7 +113,7 @@ class ModelBaseTest(unittest.TestCase):
self.assertNotEqual(self.model.fc3.bias_name,
self.target_model.fc3.bias_name)
def test_network_copy_with_multi_copy(self):
def test_model_copy_with_multi_copy(self):
self.assertNotEqual(self.target_model.fc1.param_name,
self.target_model2.fc1.param_name)
self.assertNotEqual(self.target_model.fc1.bias_name,
......@@ -131,17 +129,17 @@ class ModelBaseTest(unittest.TestCase):
self.assertNotEqual(self.target_model.fc3.bias_name,
self.target_model2.fc3.bias_name)
def test_network_parameter_names(self):
def test_model_parameters(self):
self.assertSetEqual(
set(self.model.parameter_names),
set(self.model.parameters()),
set(['fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b']))
# Second test for cache parameter_names
# Second test for cache parameters
self.assertSetEqual(
set(self.model.parameter_names),
set(self.model.parameters()),
set(['fc1.w', 'fc1.b', 'fc2.w', 'fc2.b', 'fc3.w', 'fc3.b']))
def test_sync_params_in_one_program(self):
def test_sync_weights_in_one_program(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
......@@ -159,7 +157,7 @@ class ModelBaseTest(unittest.TestCase):
fetch_list=[model_output, target_model_output])
self.assertNotEqual(outputs[0].flatten(), outputs[1].flatten())
self.model.sync_params_to(self.target_model, self.gpu_id)
self.model.sync_weights_to(self.target_model)
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
......@@ -170,7 +168,7 @@ class ModelBaseTest(unittest.TestCase):
fetch_list=[model_output, target_model_output])
self.assertEqual(outputs[0].flatten(), outputs[1].flatten())
def test_sync_params_among_programs(self):
def test_sync_weights_among_programs(self):
pred_program = fluid.Program()
pred_program_2 = fluid.Program()
with fluid.program_guard(pred_program):
......@@ -197,7 +195,7 @@ class ModelBaseTest(unittest.TestCase):
fetch_list=[target_model_output])
self.assertNotEqual(outputs[0].flatten(), outputs_2[0].flatten())
self.model.sync_params_to(self.target_model, self.gpu_id)
self.model.sync_weights_to(self.target_model)
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
......@@ -219,7 +217,7 @@ class ModelBaseTest(unittest.TestCase):
model_fc3_w = fetch_value('fc3.w')
model_fc3_b = fetch_value('fc3.b')
unique_id = target_model.parameter_names[0].split('_')[-1]
unique_id = target_model.parameters()[0].split('_')[-1]
target_model_fc1_w = fetch_value(
'PARL_target_fc1.w_{}'.format(unique_id))
target_model_fc1_b = fetch_value(
......@@ -250,7 +248,7 @@ class ModelBaseTest(unittest.TestCase):
return (target_model_fc1_w, target_model_fc1_b, target_model_fc2_w,
target_model_fc2_b, target_model_fc3_w, target_model_fc3_b)
def test_sync_params_with_decay(self):
def test_sync_weights_with_decay(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
......@@ -265,7 +263,7 @@ class ModelBaseTest(unittest.TestCase):
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, self.gpu_id, decay=decay)
self.model.sync_weights_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
......@@ -283,7 +281,7 @@ class ModelBaseTest(unittest.TestCase):
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
def test_sync_params_with_decay_with_multi_sync(self):
def test_sync_weights_with_decay_with_multi_sync(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
......@@ -298,7 +296,7 @@ class ModelBaseTest(unittest.TestCase):
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, self.gpu_id, decay=decay)
self.model.sync_weights_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
......@@ -322,7 +320,7 @@ class ModelBaseTest(unittest.TestCase):
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, self.gpu_id, decay=decay)
self.model.sync_weights_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
......@@ -340,7 +338,7 @@ class ModelBaseTest(unittest.TestCase):
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
def test_sync_params_with_different_decay(self):
def test_sync_weights_with_different_decay(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
......@@ -355,7 +353,7 @@ class ModelBaseTest(unittest.TestCase):
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, self.gpu_id, decay=decay)
self.model.sync_weights_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
......@@ -379,7 +377,7 @@ class ModelBaseTest(unittest.TestCase):
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, self.gpu_id, decay=decay)
self.model.sync_weights_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
......@@ -397,7 +395,7 @@ class ModelBaseTest(unittest.TestCase):
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
def test_sync_params_with_multi_target_model(self):
def test_sync_weights_with_multi_target_model(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
......@@ -413,7 +411,7 @@ class ModelBaseTest(unittest.TestCase):
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model, decay)
self.model.sync_params_to(self.target_model, self.gpu_id, decay=decay)
self.model.sync_weights_to(self.target_model, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
......@@ -437,7 +435,7 @@ class ModelBaseTest(unittest.TestCase):
target_model_fc2_b, target_model_fc3_w,
target_model_fc3_b) = self._numpy_update(self.target_model2, decay)
self.model.sync_params_to(self.target_model2, self.gpu_id, decay=decay)
self.model.sync_weights_to(self.target_model2, decay=decay)
N = 10
random_obs = np.random.random(size=(N, 4)).astype('float32')
......@@ -455,7 +453,7 @@ class ModelBaseTest(unittest.TestCase):
self.assertLess(float(np.abs(real_target_outputs - out_np)), 1e-5)
def test_sync_params_with_create_parameter(self):
def test_sync_weights_with_create_parameter(self):
model = TestModel2()
target_model = deepcopy(model)
......@@ -477,7 +475,7 @@ class ModelBaseTest(unittest.TestCase):
self.assertNotEqual(
np.sum(outputs[0].flatten()), np.sum(outputs[1].flatten()))
model.sync_params_to(target_model, self.gpu_id)
model.sync_weights_to(target_model)
random_obs = np.random.random(size=(N, 100)).astype('float32')
for i in range(N):
......@@ -489,7 +487,7 @@ class ModelBaseTest(unittest.TestCase):
self.assertEqual(
np.sum(outputs[0].flatten()), np.sum(outputs[1].flatten()))
def test_sync_params_with_batch_norm(self):
def test_sync_weights_with_batch_norm(self):
model = TestModel3()
target_model = deepcopy(model)
......@@ -528,7 +526,7 @@ class ModelBaseTest(unittest.TestCase):
x = np.expand_dims(random_obs[i], axis=0)
self.executor.run(program1, feed={'obs': x})
model.sync_params_to(target_model, self.gpu_id)
model.sync_weights_to(target_model)
random_obs = np.random.random(size=(N, 32, 128, 128)).astype('float32')
for i in range(N):
......@@ -540,7 +538,7 @@ class ModelBaseTest(unittest.TestCase):
self.assertEqual(
np.sum(outputs[0].flatten()), np.sum(outputs[1].flatten()))
def test_get_params(self):
def test_get_weights(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
......@@ -554,7 +552,7 @@ class ModelBaseTest(unittest.TestCase):
]:
expected_params.append(fetch_value(param_name))
params = self.model.get_params()
params = self.model.get_weights()
self.assertEqual(len(params), len(expected_params))
for param in params:
flag = False
......@@ -564,7 +562,7 @@ class ModelBaseTest(unittest.TestCase):
break
self.assertTrue(flag)
def test_set_params(self):
def test_set_weights(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
......@@ -572,15 +570,15 @@ class ModelBaseTest(unittest.TestCase):
self.executor.run(fluid.default_startup_program())
params = self.model.get_params()
params = self.model.get_weights()
new_params = [x + 1.0 for x in params]
self.model.set_params(new_params, self.gpu_id)
self.model.set_weights(new_params)
for x, y in list(zip(new_params, self.model.get_params())):
for x, y in list(zip(new_params, self.model.get_weights())):
self.assertEqual(np.sum(x), np.sum(y))
def test_set_params_between_different_models(self):
def test_set_weights_between_different_models(self):
model1 = TestModel4()
model2 = TestModel4()
......@@ -603,8 +601,8 @@ class ModelBaseTest(unittest.TestCase):
self.assertNotEqual(outputs[0].flatten(), outputs[1].flatten())
# pass parameters of self.model to model2
params = model1.get_params()
model2.set_params(params, self.gpu_id)
params = model1.get_weights()
model2.set_weights(params)
random_obs = np.random.random(size=(N, 4)).astype('float32')
for i in range(N):
......@@ -615,7 +613,7 @@ class ModelBaseTest(unittest.TestCase):
fetch_list=[model1_output, model2_output])
self.assertEqual(outputs[0].flatten(), outputs[1].flatten())
def test_set_params_with_wrong_params_num(self):
def test_set_weights_with_wrong_params_num(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
......@@ -623,17 +621,17 @@ class ModelBaseTest(unittest.TestCase):
self.executor.run(fluid.default_startup_program())
params = self.model.get_params()
params = self.model.get_weights()
try:
self.model.set_params(params[1:], self.gpu_id)
self.model.set_weights(params[1:])
except:
# expected
return
assert False
def test_set_params_with_wrong_params_shape(self):
def test_set_weights_with_wrong_params_shape(self):
pred_program = fluid.Program()
with fluid.program_guard(pred_program):
obs = layers.data(name='obs', shape=[4], dtype='float32')
......@@ -641,11 +639,11 @@ class ModelBaseTest(unittest.TestCase):
self.executor.run(fluid.default_startup_program())
params = self.model.get_params()
params = self.model.get_weights()
params.reverse()
self.model.set_params(params, self.gpu_id)
self.model.set_weights(params)
x = np.random.random(size=(1, 4)).astype('float32')
......
......@@ -13,11 +13,11 @@
# limitations under the License.
import numpy as np
import parl.layers as layers
import unittest
from paddle import fluid
from parl import layers
from parameterized import parameterized
from parl.framework.policy_distribution import *
from parl.core.fluid.policy_distribution import *
from parl.utils import get_gpu_count, np_softmax, np_cross_entropy
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.core.model_helper import global_model_helper
class ModelBase(object):
"""`ModelBase` is the base class of the `parl.Model` in different frameworks.
This base class mainly do the following things:
1. Implements APIs to manage model_id of the `parl.Model`;
2. Defines common APIs that `parl.Model` should implement in different frameworks.
"""
def __init__(self, model_id=None):
"""
Args:
model_id (String): user-specified model_id (default: None)
"""
if model_id is not None:
global_model_helper.register_model_id(model_id)
self.__model_id = model_id
else:
self.__model_id = global_model_helper.generate_model_id()
@property
def model_id(self):
return self.get_model_id()
@model_id.setter
def model_id(self, model_id):
self.set_model_id(model_id)
def get_model_id(self):
"""Get model_id of `ModelBase`.
If not created, will create a new model_id.
Returns:
String of model_id.
"""
try:
return self.__model_id
except AttributeError:
self.__model_id = global_model_helper.generate_model_id()
return self.__model_id
def set_model_id(self, model_id):
"""Set model_id of `ModelBase` with given model_id.
Args:
model_id (string): string of model_id.
"""
global_model_helper.register_model_id(model_id)
self.__model_id = model_id
def forward(self, *args, **kwargs):
"""Define forward network of the model.
"""
raise NotImplementedError
def get_weights(self):
"""Get weights of the model.
"""
raise NotImplementedError
def set_weights(self, weights):
"""Set weights of the model with given weights.
"""
raise NotImplementedError
def sync_weights_to(self, other_model):
"""Synchronize weights of the model to another model.
"""
raise NotImplementedError
def parameters(self):
"""Get the parameters of the model.
"""
raise NotImplementedError
def __call__(self, *args, **kwargs):
"""Call forward function.
"""
self.forward(*args, **kwargs)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
__all__ = ['global_model_helper']
class ModelHelper(object):
"""Model id helper.
This helper is used to help `parl.Model` generate a new model id
or register a given model id in a thread-safe way.
"""
def __init__(self):
self._registered_ids = set([])
self.index = 0
self.lock = threading.Lock()
def generate_model_id(self):
"""Generate a unique model_id in a thread-safe way.
Returns:
String of model id.
"""
self.lock.acquire()
model_id = 'parl_model_{}'.format(self.index)
while model_id in self._registered_ids:
self.index += 1
model_id = 'parl_model_{}'.format(self.index)
self._registered_ids.add(model_id)
self.index += 1
self.lock.release()
return model_id
def register_model_id(self, model_id):
"""Register given model id in a thread-safe way.
Raises:
AssertionError: if the model id is already used.
"""
model_id_used = False
self.lock.acquire()
if model_id in self._registered_ids:
model_id_used = True
else:
self._registered_ids.add(model_id)
self.lock.release()
assert not model_id_used, "model id `{}` has been used before, please try another model_id.".format(
model_id)
global_model_helper = ModelHelper()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from parl.core.model_base import ModelBase
from parl.core.algorithm_base import AlgorithmBase
from parl.core.agent_base import AgentBase
class MockModel(ModelBase):
def __init__(self, weights, model_id=None):
super(MockModel, self).__init__(model_id)
self.weights = weights
def get_weights(self):
return self.weights
def set_weights(self, weights):
self.weights = weights
class TestAlgorithm(AlgorithmBase):
def __init__(self):
self.model1 = MockModel(1)
self.model2 = MockModel(2)
self.model_list1 = (-1, MockModel(3))
self.model_list2 = [MockModel(4), MockModel(5)]
self.model_dict1 = {'k1': MockModel(6), 'k2': -2}
self.model_dict2 = {'k1': MockModel(7), 'k2': MockModel(8)}
class TestAlgorithm2(AlgorithmBase):
def __init__(self):
self.model1 = MockModel(1, model_id='id1')
self.model2 = MockModel(2, model_id='id2')
self.model_list1 = (-1, MockModel(3, model_id='id3'))
self.model_list2 = [
MockModel(4, model_id='id4'),
MockModel(5, model_id='id5')
]
self.model_dict1 = {'k1': MockModel(6, model_id='id6'), 'k2': -2}
self.model_dict2 = {
'k1': MockModel(7, model_id='id7'),
'k2': MockModel(8, model_id='id8')
}
class AgentBaseTest(unittest.TestCase):
def setUp(self):
alg1 = TestAlgorithm()
alg2 = TestAlgorithm()
self.agent1 = AgentBase(alg1)
self.agent2 = AgentBase(alg2)
def test_get_weights(self):
weights = self.agent1.get_weights()
expected_dict = {
'model1': 1,
'model2': 2,
'model_list1': [3],
'model_list2': [4, 5],
'model_dict1': {
'k1': 6
},
'model_dict2': {
'k1': 7,
'k2': 8
}
}
self.assertDictEqual(weights, expected_dict)
def test_set_weights(self):
expected_dict = {
'model1': -1,
'model2': -2,
'model_list1': [-3],
'model_list2': [-4, -5],
'model_dict1': {
'k1': -6
},
'model_dict2': {
'k1': -7,
'k2': -8
}
}
self.agent1.set_weights(expected_dict)
self.assertDictEqual(self.agent1.get_weights(), expected_dict)
def test_get_and_set_weights_between_agents(self):
expected_dict = {
'model1': -1,
'model2': -2,
'model_list1': [-3],
'model_list2': [-4, -5],
'model_dict1': {
'k1': -6
},
'model_dict2': {
'k1': -7,
'k2': -8
}
}
self.agent1.set_weights(expected_dict)
new_weights = self.agent1.get_weights()
self.agent2.set_weights(new_weights)
self.assertDictEqual(self.agent2.get_weights(), expected_dict)
def test_get_model_ids(self):
alg = TestAlgorithm2()
agent = AgentBase(alg)
expected_model_ids = set(['id{}'.format(i + 1) for i in range(8)])
self.assertSetEqual(expected_model_ids, agent.get_model_ids())
def test_get_weights_with_model_ids(self):
weights = self.agent1.get_weights(model_ids=[
self.agent1.algorithm.model1.model_id, self.agent1.algorithm.
model_list2[0].model_id, self.agent1.algorithm.model_dict2['k1'].
model_id
])
expected_dict = {
'model1': 1,
'model_list2': [4],
'model_dict2': {
'k1': 7,
}
}
self.assertDictEqual(weights, expected_dict)
def test_set_weights_with_model_ids(self):
new_weights = {
'model1': -1,
'model_list2': [-4],
'model_dict2': {
'k1': -7,
}
}
expected_dict = {
'model1': -1,
'model2': 2,
'model_list1': [3],
'model_list2': [-4, 5],
'model_dict1': {
'k1': 6
},
'model_dict2': {
'k1': -7,
'k2': 8
}
}
self.agent1.set_weights(
new_weights,
model_ids=[
self.agent1.algorithm.model1.model_id,
self.agent1.algorithm.model_list2[0].model_id,
self.agent1.algorithm.model_dict2['k1'].model_id
])
self.assertDictEqual(self.agent1.get_weights(), expected_dict)
def test_get_and_set_weights_between_agents_with_model_ids(self):
agent1_model_ids = [
self.agent1.algorithm.model1.model_id,
self.agent1.algorithm.model_list2[0].model_id,
self.agent1.algorithm.model_dict2['k1'].model_id
]
agent2_model_ids = [
self.agent2.algorithm.model1.model_id,
self.agent2.algorithm.model_list2[0].model_id,
self.agent2.algorithm.model_dict2['k1'].model_id
]
new_weights = {
'model1': -1,
'model_list2': [-4],
'model_dict2': {
'k1': -7,
}
}
expected_dict = {
'model1': -1,
'model2': 2,
'model_list1': [3],
'model_list2': [-4, 5],
'model_dict1': {
'k1': 6
},
'model_dict2': {
'k1': -7,
'k2': 8
}
}
self.agent1.set_weights(new_weights, agent1_model_ids)
agent1_weights = self.agent1.get_weights(agent1_model_ids)
self.agent2.set_weights(agent1_weights, agent2_model_ids)
self.assertDictEqual(self.agent2.get_weights(), expected_dict)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from parl.core.model_base import ModelBase
from parl.core.algorithm_base import AlgorithmBase
class MockModel(ModelBase):
def __init__(self, weights, model_id=None):
super(MockModel, self).__init__(model_id)
self.weights = weights
def get_weights(self):
return self.weights
def set_weights(self, weights):
self.weights = weights
class TestAlgorithm(AlgorithmBase):
def __init__(self):
self.model1 = MockModel(1)
self.model2 = MockModel(2)
self.model_list1 = (-1, MockModel(3))
self.model_list2 = [MockModel(4), MockModel(5)]
self.model_dict1 = {'k1': MockModel(6), 'k2': -2}
self.model_dict2 = {'k1': MockModel(7), 'k2': MockModel(8)}
class TestAlgorithm2(AlgorithmBase):
def __init__(self):
self.model1 = MockModel(1, model_id='id1')
self.model2 = MockModel(2, model_id='id2')
self.model_list1 = (-1, MockModel(3, model_id='id3'))
self.model_list2 = [
MockModel(4, model_id='id4'),
MockModel(5, model_id='id5')
]
self.model_dict1 = {'k1': MockModel(6, model_id='id6'), 'k2': -2}
self.model_dict2 = {
'k1': MockModel(7, model_id='id7'),
'k2': MockModel(8, model_id='id8')
}
class AlgorithmBaseTest(unittest.TestCase):
def setUp(self):
self.alg1 = TestAlgorithm()
self.alg2 = TestAlgorithm()
def test_get_weights(self):
weights = self.alg1.get_weights()
expected_dict = {
'model1': 1,
'model2': 2,
'model_list1': [3],
'model_list2': [4, 5],
'model_dict1': {
'k1': 6
},
'model_dict2': {
'k1': 7,
'k2': 8
}
}
self.assertDictEqual(weights, expected_dict)
def test_set_weights(self):
expected_dict = {
'model1': -1,
'model2': -2,
'model_list1': [-3],
'model_list2': [-4, -5],
'model_dict1': {
'k1': -6
},
'model_dict2': {
'k1': -7,
'k2': -8
}
}
self.alg1.set_weights(expected_dict)
self.assertDictEqual(self.alg1.get_weights(), expected_dict)
def test_set_weights_with_inconsistent_weights_case1(self):
inconsistent_weights = {
'model0': -1,
'model2': -2,
'model_list1': [-3],
'model_list2': [-4, -5],
'model_dict1': {
'k1': -6
},
'model_dict2': {
'k1': -7,
'k2': -8
}
}
with self.assertRaises(AssertionError):
self.alg1.set_weights(inconsistent_weights)
def test_set_weights_with_inconsistent_weights_case2(self):
inconsistent_weights = {
'model1': -1,
'model2': -2,
'model_list1': [-3],
'model_list2': [-4],
'model_dict1': {
'k1': -6
},
'model_dict2': {
'k1': -7,
'k2': -8
}
}
with self.assertRaises(AssertionError):
self.alg1.set_weights(inconsistent_weights)
def test_set_weights_with_inconsistent_weights_case3(self):
inconsistent_weights = {
'model1': -1,
'model2': -2,
'model_list1': [-3],
'model_list2': [-4, -5],
'model_dict1': {
'k1': -6
},
'model_dict2': {
'k1': -7,
}
}
with self.assertRaises(AssertionError):
self.alg1.set_weights(inconsistent_weights)
def test_set_weights_with_redundant_weights(self):
redundant_weights = {
'model0': 0,
'model1': -1,
'model2': -2,
'model_list1': [-3],
'model_list2': [-4, -5],
'model_list3': [44, 55],
'model_dict1': {
'k1': -6
},
'model_dict2': {
'k1': -7,
'k2': -8
},
'model_dict3': {
'k1': -77,
'k2': -88
}
}
expected_dict = {
'model1': -1,
'model2': -2,
'model_list1': [-3],
'model_list2': [-4, -5],
'model_dict1': {
'k1': -6
},
'model_dict2': {
'k1': -7,
'k2': -8
}
}
self.alg1.set_weights(expected_dict)
self.assertDictEqual(self.alg1.get_weights(), expected_dict)
def test_get_and_set_weights_between_algorithms(self):
expected_dict = {
'model1': -1,
'model2': -2,
'model_list1': [-3],
'model_list2': [-4, -5],
'model_dict1': {
'k1': -6
},
'model_dict2': {
'k1': -7,
'k2': -8
}
}
self.alg1.set_weights(expected_dict)
new_weights = self.alg1.get_weights()
self.alg2.set_weights(new_weights)
self.assertDictEqual(self.alg2.get_weights(), expected_dict)
def test_get_model_ids(self):
alg = TestAlgorithm2()
expected_model_ids = set(['id{}'.format(i + 1) for i in range(8)])
self.assertSetEqual(expected_model_ids, alg.get_model_ids())
def test_get_weights_with_model_ids(self):
weights = self.alg1.get_weights(model_ids=[
self.alg1.model1.model_id, self.alg1.model_list2[0].model_id, self.
alg1.model_dict2['k1'].model_id
])
expected_dict = {
'model1': 1,
'model_list2': [4],
'model_dict2': {
'k1': 7,
}
}
self.assertDictEqual(weights, expected_dict)
def test_set_weights_with_model_ids(self):
new_weights = {
'model1': -1,
'model_list2': [-4],
'model_dict2': {
'k1': -7,
}
}
expected_dict = {
'model1': -1,
'model2': 2,
'model_list1': [3],
'model_list2': [-4, 5],
'model_dict1': {
'k1': 6
},
'model_dict2': {
'k1': -7,
'k2': 8
}
}
self.alg1.set_weights(
new_weights,
model_ids=[
self.alg1.model1.model_id, self.alg1.model_list2[0].model_id,
self.alg1.model_dict2['k1'].model_id
])
self.assertDictEqual(self.alg1.get_weights(), expected_dict)
def test_get_and_set_weights_between_algorithms_with_model_ids(self):
alg1_model_ids = [
self.alg1.model1.model_id, self.alg1.model_list2[0].model_id,
self.alg1.model_dict2['k1'].model_id
]
alg2_model_ids = [
self.alg2.model1.model_id, self.alg2.model_list2[0].model_id,
self.alg2.model_dict2['k1'].model_id
]
new_weights = {
'model1': -1,
'model_list2': [-4],
'model_dict2': {
'k1': -7,
}
}
expected_dict = {
'model1': -1,
'model2': 2,
'model_list1': [3],
'model_list2': [-4, 5],
'model_dict1': {
'k1': 6
},
'model_dict2': {
'k1': -7,
'k2': 8
}
}
self.alg1.set_weights(new_weights, alg1_model_ids)
alg1_weights = self.alg1.get_weights(alg1_model_ids)
self.alg2.set_weights(alg1_weights, alg2_model_ids)
self.assertDictEqual(self.alg2.get_weights(), expected_dict)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from parl.core.model_base import ModelBase
class ModelBaseTest(unittest.TestCase):
def setUp(self):
self.model = ModelBase()
def test_set_and_get_model_id(self):
model_id = 'id1'
self.model.set_model_id(model_id)
self.assertEqual(model_id, self.model.get_model_id())
model_id2 = 'id2'
self.model.model_id = model_id2
self.assertEqual(model_id2, self.model.model_id)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import unittest
from parl.core.model_helper import global_model_helper
from six.moves.queue import Queue
class GlobalModelHelperTest(unittest.TestCase):
def test_generate_model_id(self):
id1 = global_model_helper.generate_model_id()
id2 = global_model_helper.generate_model_id()
self.assertNotEqual(id1, id2)
def _gen_model_id(self, q):
model_id = global_model_helper.generate_model_id()
q.put(model_id)
def test_generate_model_id_with_multi_thread(self):
q = Queue()
t1 = threading.Thread(target=self._gen_model_id, args=(q, ))
t2 = threading.Thread(target=self._gen_model_id, args=(q, ))
t1.start()
t2.start()
t1.join()
t2.join()
id1 = q.get()
id2 = q.get()
self.assertNotEqual(id1, id2)
def test_register_model_id(self):
global_model_helper.register_model_id('my_model_0')
global_model_helper.register_model_id('my_model_1')
with self.assertRaises(AssertionError):
global_model_helper.register_model_id('my_model_0')
def _register_model_id(self, q):
try:
global_model_helper.register_model_id('my_model_2')
except AssertionError:
q.put(False)
else:
q.put(True)
def test_register_model_id_with_multi_thread(self):
q = Queue()
t1 = threading.Thread(target=self._register_model_id, args=(q, ))
t2 = threading.Thread(target=self._register_model_id, args=(q, ))
t1.start()
t2.start()
t1.join()
t2.join()
return1 = q.get()
return2 = q.get()
assert (return1 is True and return2 is False) or \
(return1 is False and return2 is True)
def test_registet_model_id_with_used_model_id(self):
model_id = global_model_helper.generate_model_id()
with self.assertRaises(AssertionError):
global_model_helper.register_model_id(model_id)
if __name__ == '__main__':
unittest.main()
......@@ -11,7 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from parl.framework.model_base import *
from parl.framework.algorithm_base import *
from parl.framework.agent_base import *
warnings.simplefilter('default')
warnings.warn(
"import way `import parl.framework` is deprecated since version 1.2 and will be removed in version 1.3.",
DeprecationWarning,
stacklevel=2)
from parl.core.fluid.model import *
from parl.core.fluid.algorithm import *
from parl.core.fluid.agent import *
......@@ -12,96 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl.layers as layers
from parl.framework.algorithm_base import Algorithm
from parl.framework.model_base import Model
from parl.utils import get_gpu_count
import warnings
__all__ = ['Agent']
warnings.simplefilter('default')
warnings.warn(
"module `parl.framework.agent_base.Agent` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.Agent` instead.",
DeprecationWarning,
stacklevel=2)
class Agent(object):
"""
A Agent is responsible for the general data flow
outside the algorithm.
A Agent is created in a bottom-up way:
a. create a Model
b. create an Algorithm with the model as an input
c. define a Agent with the algorithm
"""
def __init__(self, algorithm, gpu_id=None):
""" build program and run initialization for default_startup_program
Created object:
self.alg: parl.framework.Algorithm
self.gpu_id: int
self.fluid_executor: fluid.Executor
"""
assert isinstance(algorithm, Algorithm)
self.alg = algorithm
self.build_program()
if gpu_id is None:
gpu_id = 0 if get_gpu_count() > 0 else -1
self.gpu_id = gpu_id
self.place = fluid.CUDAPlace(
gpu_id) if gpu_id >= 0 else fluid.CPUPlace()
self.fluid_executor = fluid.Executor(self.place)
self.fluid_executor.run(fluid.default_startup_program())
def build_program(self):
"""build your training program and prediction program here,
using the functions define_learn and define_predict in algorithm.
Note that it's unnecessary to call this function explictly since
it will be called automatically in the initialization function.
To build the program, you may need to do the following:
a. create a new program in fluid with program guard
b. define your data layer
c. build your training/prediction program, pass the data variable
defined in step b to `define_training/define_prediction` of algorithm
"""
raise NotImplementedError
def predict(self, obs):
"""This function will predict the action given current observation of the enviroment.
Note that this function will only do the prediction and it doesn't try any exploration,
To explore in the action space, you should create your process in `sample` function below.
In formally, this function is often used in test process.
"""
raise NotImplementedError
def sample(self, obs):
"""This function will predict the action given current observation of the enviroment.
Additionaly, action will be added noise here to explore a new trajectory. In formally,
this function is often used in training process.
"""
raise NotImplementedError
def learn(self, obs, action, reward, next_obs, terminal):
"""pass data to the training program to update model,
this function is the training interface for Agent.
"""
raise NotImplementedError
def get_params(self):
""" Get parameters of self.alg
Returns:
List of numpy array.
"""
return self.alg.get_params()
def set_params(self, params):
""" Set parameters of self.alg
Args:
params: List of numpy array.
"""
self.alg.set_params(params, gpu_id=self.gpu_id)
from parl.core.fluid.agent import *
......@@ -12,60 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABCMeta, abstractmethod
from parl.framework.model_base import Model
import warnings
__all__ = ['Algorithm']
warnings.simplefilter('default')
warnings.warn(
"module `parl.framework.algorithm_base.Algorithm` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.Algorithm` instead.",
DeprecationWarning,
stacklevel=2)
class Algorithm(object):
"""
Algorithm defines the way how we update the model. For example,
after defining forward network in `Model` class, you should define how to update the model here.
Before creating a customized algorithm, please do check algorithms of PARL.
Most common used algorithms like DQN/DDPG/PPO/A3C have been providing in algorithms, go and have a try.
It's easy to use them and just try parl.algorithms.DQN.
An Algorithm implements two functions:
1. define_predict() build forward process which was defined in `Model`
2. define_learn() computes a cost for optimization
An algorithm should be updating part of a network. The user only needs to
implement the rest of the network(forward) in the Model class.
"""
def __init__(self, model, hyperparas=None):
assert isinstance(model, Model)
self.model = model
self.hp = hyperparas
def define_predict(self, obs):
"""
describe process for building predcition program
"""
raise NotImplementedError()
def define_learn(self, obs, action, reward, next_obs, terminal):
"""define how to update the model here, you may need to do the following:
1. define a cost for optimization
2. specify your optimizer
3. optimize model defined in Model
"""
raise NotImplementedError()
def get_params(self):
""" Get parameters of self.model
Returns:
List of numpy array.
"""
return self.model.get_params()
def set_params(self, params, gpu_id):
""" Set parameters of self.model
Args:
params: List of numpy array.
gpu_id: gpu id where self.model in. (if gpu_id < 0, means in cpu.)
"""
self.model.set_params(params, gpu_id=gpu_id)
from parl.core.fluid.algorithm import *
......@@ -11,276 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class to define an Algorithm.
"""
import hashlib
import paddle.fluid as fluid
from abc import ABCMeta
from parl.layers.layer_wrappers import LayerFunc
from parl.plutils import *
import warnings
__all__ = ['Network', 'Model']
warnings.simplefilter('default')
warnings.warn(
"module `parl.framework.model_base.Model` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.Model` instead.",
DeprecationWarning,
stacklevel=2)
class Network(object):
"""
A Network is a collection of LayerFuncs or Networks.
"""
def sync_params_to(self,
target_net,
gpu_id,
decay=0.0,
share_vars_parallel_executor=None):
"""
Args:
target_net: Network object deepcopy from source network
gpu_id: gpu id of target_net
decay: Float. The decay to use.
target_net_weights = decay * target_net_weights + (1 - decay) * source_net_weights
share_vars_parallel_executor: if not None, will use fluid.ParallelExecutor
to run program instead of fluid.Executor
"""
args_hash_id = hashlib.md5('{}_{}_{}'.format(
id(target_net), gpu_id, decay).encode('utf-8')).hexdigest()
has_cached = False
try:
if self._cached_id == args_hash_id:
has_cached = True
except AttributeError:
has_cached = False
if not has_cached:
# Can not run _cached program, need create a new program
self._cached_id = args_hash_id
assert not target_net is self, "cannot copy between identical networks"
assert isinstance(target_net, Network)
assert self.__class__.__name__ == target_net.__class__.__name__, \
"must be the same class for params syncing!"
assert (decay >= 0 and decay <= 1)
param_pairs = self._get_parameter_pairs(self, target_net)
self._cached_sync_params_program = fluid.Program()
with fluid.program_guard(self._cached_sync_params_program):
for (src_var_name, target_var_name) in param_pairs:
src_var = fetch_framework_var(src_var_name)
target_var = fetch_framework_var(target_var_name)
fluid.layers.assign(
decay * target_var + (1 - decay) * src_var, target_var)
if share_vars_parallel_executor is None:
# use fluid.Executor
place = fluid.CPUPlace() if gpu_id < 0 \
else fluid.CUDAPlace(gpu_id)
self._cached_fluid_executor = fluid.Executor(place)
else:
# use fluid.ParallelExecutor
use_cuda = True if gpu_id >= 0 else False
# specify strategy to make ParallelExecutor run faster
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = True
exec_strategy.num_threads = 4
build_strategy = fluid.BuildStrategy()
build_strategy.remove_unnecessary_lock = True
with fluid.scope_guard(fluid.global_scope().new_scope()):
self._cached_fluid_executor = fluid.ParallelExecutor(
use_cuda=use_cuda,
main_program=self._cached_sync_params_program,
share_vars_from=share_vars_parallel_executor,
exec_strategy=exec_strategy,
build_strategy=build_strategy,
)
if share_vars_parallel_executor is None:
self._cached_fluid_executor.run(self._cached_sync_params_program)
else:
self._cached_fluid_executor.run(fetch_list=[])
@property
def parameter_names(self):
""" param_attr names of all parameters in Network,
only parameter created by parl.layers included.
The order of parameter names will be consistent between
different instances of same parl.Network.
Returns:
list of string, param_attr names of all parameters
"""
try:
return self._parameter_names
except AttributeError:
self._parameter_names = self._get_parameter_names(self)
return self._parameter_names
def get_params(self):
""" Get numpy arrays of parameters in this Network
Returns:
List of numpy array.
"""
params = []
for param_name in self.parameter_names:
param = fetch_value(param_name)
params.append(param)
return params
def set_params(self, params, gpu_id):
""" Set parameters in this Network with params
Args:
params: List of numpy array.
gpu_id: gpu id where this Network in. (if gpu_id < 0, means in cpu.)
"""
assert len(params) == len(self.parameter_names), \
'size of input params should be same as parameters number of current Network'
for (param_name, param) in list(zip(self.parameter_names, params)):
set_value(param_name, param, gpu_id)
def _get_parameter_names(self, obj):
""" Recursively get parameter names in obj,
Args:
obj (parl.Network/parl.LayerFunc/list/tuple/dict): input object
Returns:
parameter_names (list of string): all parameter names in obj
"""
parameter_names = []
for attr in sorted(obj.__dict__.keys()):
val = getattr(obj, attr)
if isinstance(val, Network):
parameter_names.extend(self._get_parameter_names(val))
elif isinstance(val, LayerFunc):
for attr in val.attr_holder.sorted():
if attr:
parameter_names.append(attr.name)
elif isinstance(val, tuple) or isinstance(val, list):
for x in val:
parameter_names.extend(self._get_parameter_names(x))
elif isinstance(val, dict):
for x in list(val.values()):
parameter_names.extend(self._get_parameter_names(x))
else:
# for any other type, won't be handled. E.g. set
pass
return parameter_names
def _get_parameter_pairs(self, src, target):
""" Recursively gets parameters in source network and
corresponding parameters in target network.
Args:
src (parl.Network/parl.LayerFunc/list/tuple/dict): source object
target (parl.Network/parl.LayerFunc/list/tuple/dict): target object
Returns:
param_pairs (list of tuple): all parameter names in source network
and corresponding parameter names in
target network.
"""
param_pairs = []
if isinstance(src, Network):
for attr in src.__dict__:
if not attr in target.__dict__:
continue
src_var = getattr(src, attr)
target_var = getattr(target, attr)
param_pairs.extend(
self._get_parameter_pairs(src_var, target_var))
elif isinstance(src, LayerFunc):
src_attrs = src.attr_holder.sorted()
target_attrs = target.attr_holder.sorted()
assert len(src_attrs) == len(target_attrs), \
"number of ParamAttr between source layer and target layer should be same."
for (src_attr, target_attr) in zip(src_attrs, target_attrs):
if src_attr:
assert target_attr, "ParamAttr between source layer and target layer is inconsistent."
param_pairs.append((src_attr.name, target_attr.name))
elif isinstance(src, tuple) or isinstance(src, list):
for src_var, target_var in zip(src, target):
param_pairs.extend(
self._get_parameter_pairs(src_var, target_var))
elif isinstance(src, dict):
for k in src.keys():
assert k in target
param_pairs.extend(
self._get_parameter_pairs(src[k], target[k]))
else:
# for any other type, won't be handled. E.g. set
pass
return param_pairs
class Model(Network):
"""
A Model is owned by an Algorithm.
It implements the entire network model(forward part) to solve a specific problem.
In general, Model is responsible for forward and
Algorithm is responsible for backward.
Model can also use deepcopy way to construct target model, which has the same structure as initial model.
Note that only the model definition is copied here. To copy the parameters from the current model
to the target model, you must explicitly use sync_params_to function after the program is initialized.
Here is an example:
```python
import parl.layers as layers
import parl.Model as Model
class MLPModel(Model):
def __init__(self):
self.fc = layers.fc(size=64)
def policy(self, obs):
out = self.fc(obs)
return out
model = MLPModel()
target_model = deepcopy(model) # automatically create new unique parameters names for target_model.fc
# build program
x = layers.data(name='x', shape=[100], dtype="float32")
y1 = model.policy(x)
y2 = target_model.policy(x)
...
# Need initialize program before calling sync_params_to
fluid_executor.run(fluid.default_startup_program())
...
# synchronize parameters
model.sync_params_to(target_model, gpu_id=gpu_id)
```
"""
__metaclass__ = ABCMeta
def __init__(self):
super(Model, self).__init__()
def policy(self, *args):
"""
Implement your policy here.
The function was later used by algorithm
Return: action_dists: a dict of action distribution objects
states
Optional: a model might not always have to implement policy()
"""
raise NotImplementedError()
def value(self, *args):
"""
Return: values: a dict of estimated values for the current observations and states
For example, "q_value" and "v_value"
Optional: a model might not always have to implement value()
"""
raise NotImplementedError()
from parl.core.fluid.model import *
......@@ -12,113 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import parl.layers as layers
import warnings
__all__ = ['PolicyDistribution', 'CategoricalDistribution']
warnings.simplefilter('default')
warnings.warn(
"module `parl.framework.policy_distribution` is deprecated since version 1.2 and will be removed in version 1.3, please use `parl.policy_distribution` instead.",
DeprecationWarning,
stacklevel=2)
class PolicyDistribution(object):
def sample(self):
"""Sampling from the policy distribution."""
raise NotImplementedError
def entropy(self):
"""The entropy of the policy distribution."""
raise NotImplementedError
def kl(self, other):
"""The KL-divergence between self policy distributions and other."""
raise NotImplementedError
def logp(self, actions):
"""The log-probabilities of the actions in this policy distribution."""
raise NotImplementedError
class CategoricalDistribution(PolicyDistribution):
"""Categorical distribution for discrete action spaces."""
def __init__(self, logits):
"""
Args:
logits: A float32 tensor with shape [BATCH_SIZE, NUM_ACTIONS] of unnormalized policy logits
"""
assert len(logits.shape) == 2
self.logits = logits
def sample(self):
"""
Returns:
sample_action: An int64 tensor with shape [BATCH_SIZE] of multinomial sampling ids.
Each value in sample_action is in [0, NUM_ACTIOINS - 1]
"""
probs = layers.softmax(self.logits)
sample_actions = layers.sampling_id(probs)
return sample_actions
def entropy(self):
"""
Returns:
entropy: A float32 tensor with shape [BATCH_SIZE] of entropy of self policy distribution.
"""
logits = self.logits - layers.reduce_max(self.logits, dim=1)
e_logits = layers.exp(logits)
z = layers.reduce_sum(e_logits, dim=1)
prob = e_logits / z
entropy = -1.0 * layers.reduce_sum(
prob * (logits - layers.log(z)), dim=1)
return entropy
def logp(self, actions, eps=1e-6):
"""
Args:
actions: An int64 tensor with shape [BATCH_SIZE]
eps: A small float constant that avoids underflows
Returns:
actions_log_prob: A float32 tensor with shape [BATCH_SIZE]
"""
assert len(actions.shape) == 1
logits = self.logits - layers.reduce_max(self.logits, dim=1)
e_logits = layers.exp(logits)
z = layers.reduce_sum(e_logits, dim=1)
prob = e_logits / z
actions = layers.unsqueeze(actions, axes=[1])
actions_onehot = layers.one_hot(actions, prob.shape[1])
actions_onehot = layers.cast(actions_onehot, dtype='float32')
actions_prob = layers.reduce_sum(prob * actions_onehot, dim=1)
actions_prob = actions_prob + eps
actions_log_prob = layers.log(actions_prob)
return actions_log_prob
def kl(self, other):
"""
Args:
other: object of CategoricalDistribution
Returns:
kl: A float32 tensor with shape [BATCH_SIZE]
"""
assert isinstance(other, CategoricalDistribution)
logits = self.logits - layers.reduce_max(self.logits, dim=1)
other_logits = other.logits - layers.reduce_max(other.logits, dim=1)
e_logits = layers.exp(logits)
other_e_logits = layers.exp(other_logits)
z = layers.reduce_sum(e_logits, dim=1)
other_z = layers.reduce_sum(other_e_logits, dim=1)
prob = e_logits / z
kl = layers.reduce_sum(
prob *
(logits - layers.log(z) - other_logits + layers.log(other_z)),
dim=1)
return kl
from parl.core.fluid.policy_distribution import *
......@@ -11,9 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file wraps Fluid layers that have parameters to support parameter sharing.
For other layers that don't have parameters, we simply copy them to this namespace.
"""
from paddle.fluid.layers import *
from parl.layers.layer_wrappers import *
import warnings
warnings.simplefilter('default')
warnings.warn(
"import way `import parl.layers` is deprecated since version 1.2 and will be removed in version 1.3, please use `from parl import layers` or `import parl; parl.layers` instead.",
DeprecationWarning,
stacklevel=2)
from parl.core.fluid.layers import *
......@@ -12,4 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from parl.plutils.common import *
print(
"import way `import parl.plutils` is deprecated since version 1.2 and will be removed in version 1.3, please use `from parl import plutils` or `import parl; parl.plutils` instead."
)
from parl.core.fluid.plutils.common import *
......@@ -11,69 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common functions of PARL framework
"""
import paddle.fluid as fluid
from paddle.fluid.executor import _fetch_var
print(
"import way `import parl.plutils` is deprecated since version 1.2 and will be removed in version 1.3, please use `from parl import plutils` or `import parl; parl.plutils` instead."
)
__all__ = ['fetch_framework_var', 'fetch_value', 'set_value', 'inverse']
def fetch_framework_var(attr_name):
""" Fetch framework variable according given attr_name.
Return a new reusing variable through create_parameter way
Args:
attr_name: string, attr name of parameter
Returns:
framework_var: framework.Varialbe
"""
scope = fluid.executor.global_scope()
core_var = scope.find_var(attr_name)
shape = core_var.get_tensor().shape()
framework_var = fluid.layers.create_parameter(
shape=shape, dtype='float32', attr=fluid.ParamAttr(name=attr_name))
return framework_var
def fetch_value(attr_name):
""" Given name of ParamAttr, fetch numpy value of the parameter in global_scope
Args:
attr_name: ParamAttr name of parameter
Returns:
numpy.ndarray
"""
return _fetch_var(attr_name, return_numpy=True)
def set_value(attr_name, value, gpu_id):
""" Given name of ParamAttr, set numpy value to the parameter in global_scope
Args:
attr_name: ParamAttr name of parameter
value: numpy array
gpu_id: gpu id where the parameter in
"""
place = fluid.CPUPlace() if gpu_id < 0 \
else fluid.CUDAPlace(gpu_id)
var = _fetch_var(attr_name, return_numpy=False)
var.set(value, place)
def inverse(x):
""" Inverse 0/1 variable
Args:
x: variable with float32 dtype
Returns:
inverse_x: variable with float32 dtype
"""
inverse_x = -1.0 * x + 1.0
return inverse_x
from parl.core.fluid.plutils.common import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reference:
https://github.com/briancurtin/deprecation
"""
__all__ = ['deprecated']
import functools
import textwrap
import warnings
warnings.simplefilter('default')
class CustomDeprecationWarning(DeprecationWarning):
def __init__(self,
function,
deprecated_in,
removed_in,
replace_function=None):
"""
Args:
function (String): The function being deprecated.
deprecated_in (String): The version that ``function`` is deprecated in
removed_in (String): The version that ``function`` gets removed in
replace_function (String): The replacement function of deprecated function.
"""
self.function = function
self.deprecated_in = deprecated_in
self.removed_in = removed_in
self.replace_function = replace_function
super(CustomDeprecationWarning, self).__init__(
function, deprecated_in, removed_in, replace_function)
def __str__(self):
warn_string = '[PARL] API `{}` is deprecated since version {} and will be removed in version {}'.format(
self.function, self.deprecated_in, self.removed_in)
if self.replace_function is not None:
warn_string += ", please use `{}` instead.".format(
self.replace_function)
else:
warn_string += "."
return warn_string
def deprecated(deprecated_in, removed_in, replace_function=None):
"""Decorator of decarated function.
Args:
function (String): The function being deprecated.
deprecated_in (String): The version that ``function`` is deprecated in
removed_in (String): The version that ``function`` gets removed in
replace_function (String): The replacement function of deprecated function.
"""
def _function_wrapper(function):
existing_docstring = function.__doc__ or ""
deprecated_doc = '.. deprecated:: {}\n This will be removed in {}'.format(
deprecated_in, removed_in)
if replace_function is not None:
deprecated_doc += ", please use `{}` instead.".format(
replace_function)
else:
deprecated_doc += "."
# split docstring at first occurrence of newline
string_list = existing_docstring.split("\n", 1)
if len(string_list) > 1:
# in-place dedent docstring content
string_list[1] = textwrap.dedent(string_list[1])
# we need another newline
string_list.insert(1, "\n")
# insert deprecation note and dual newline
string_list.insert(1, deprecated_doc)
string_list.insert(1, "\n\n")
function.__doc__ = "".join(string_list)
@functools.wraps(function)
def _inner(*args, **kwargs):
the_warning = CustomDeprecationWarning(
function.__name__, deprecated_in, removed_in, replace_function)
warnings.warn(
the_warning, category=CustomDeprecationWarning, stacklevel=2)
return function(*args, **kwargs)
return _inner
return _function_wrapper
......@@ -17,7 +17,7 @@ import platform
import subprocess
from parl.utils import logger
__all__ = ['get_gpu_count', 'get_ip_address']
__all__ = ['get_gpu_count', 'get_ip_address', 'is_gpu_available']
def get_ip_address():
......@@ -57,8 +57,7 @@ def get_ip_address():
def get_gpu_count():
"""
get avaliable gpu count
"""get avaliable gpu count
Returns:
gpu_count: int
......@@ -88,3 +87,12 @@ def get_gpu_count():
logger.warn('Cannot find available GPU devices, using CPU now.')
gpu_count = 0
return gpu_count
def is_gpu_available():
""" check whether parl can access a GPU
Returns:
True if a gpu device can be found.
"""
return get_gpu_count() > 0
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册