Unverified commit d21b2c6b authored by ceci3, committed by GitHub

Add RL (#163)

Parent 823ca6bb
# How to Define a Custom Reinforcement Learning Controller
First, import the necessary dependencies:
```python
### Import the RL controller base class and the controller registry
from paddleslim.common.RL_controller.utils import RLCONTROLLER
from paddleslim.common.RL_controller import RLBaseController
```
Register the custom reinforcement learning controller with PaddleSlim via the decorator; after inheriting from the base class, override its `next_tokens` and `update` functions. Note: this example only illustrates the essential steps and cannot be run directly; for the complete code please refer to [here]().
```python
### Note: the class name must be in ALL CAPS
@RLCONTROLLER.register
class LSTM(RLBaseController):
    def __init__(self, range_tables, use_gpu=False, **kwargs):
        ### range_tables is the value range of the tokens
        self.range_tables = range_tables
        ### use_gpu indicates whether to train the controller on GPU
        self.use_gpu = use_gpu
        ### define the parameters needed by the reinforcement learning algorithm
        ...
        ### build the programs. _build_program constructs two programs, pred_program and
        ### learn_program, and initializes their parameters
        self._build_program()
        self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace()
        self.exe = fluid.Executor(self.place)
        self.exe.run(fluid.default_startup_program())
        ### save the parameters into a dict. The server maintains and updates this dict,
        ### because several clients may update the same parameters at the same time, so
        ### this step is indispensable. pred_program and learn_program share the same
        ### parameters, so only the parameters of learn_program need to go into the dict
        self.param_dict = self.get_params(self.learn_program)

    def next_tokens(self, states, params_dict):
        ### assign the parameter dict received from the server to the program used here
        self.set_params(self.pred_program, params_dict, self.place)
        ### build the inputs from states
        self.num_archs = states
        feed_dict = self._create_input()
        ### get the current tokens
        actions = self.exe.run(self.pred_program, feed=feed_dict, fetch_list=self.tokens)
        ...
        return actions

    def update(self, rewards, params_dict=None):
        ### assign the parameter dict received from the server to the program used here
        self.set_params(self.learn_program, params_dict, self.place)
        ### build the inputs from the states in `next_tokens` and the rewards in `update`
        feed_dict = self._create_input(is_test=False, actual_rewards=rewards)
        ### compute the loss of the current step
        loss = self.exe.run(self.learn_program, feed=feed_dict, fetch_list=[self.loss])
        ### get the parameters of the current program and return them; the client sends the
        ### parameters of this round to the server for the parameter update
        params_dict = self.get_params(self.learn_program)
        return params_dict
```
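Once registered, the custom controller is selected by its class name through the `key` argument of `RLNAS`; the key is upper-cased and looked up in the `RLCONTROLLER` registry, so it must match the all-caps class name. A minimal usage sketch (the search-space config here is just an example):
```python
from paddleslim.nas import RLNAS

### 'lstm' is upper-cased and resolved to the registered LSTM controller class
config = [('MobileNetV2Space')]
rlnas = RLNAS(key='lstm', configs=config)
```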
SA-NAS
NAS
========
Configuration of search space parameters
......@@ -160,3 +160,155 @@ SANAS (Simulated Annealing Neural Architecture Search) is a simulated-annealing-based
sanas = SANAS(configs=config)
print(sanas.current_info())
RLNAS
------

.. py:class:: paddleslim.nas.RLNAS(key, configs, use_gpu=False, server_addr=("", 8881), is_server=True, is_sync=False, save_controller=None, load_controller=None, **kwargs)

`Source code <> `_

RLNAS (Reinforcement Learning Neural Architecture Search) searches model architectures with a reinforcement learning algorithm.

- **key(str)** - Name of the reinforcement learning controller to use. PaddleSlim currently supports `LSTM` and `DDPG`; for a custom controller please refer to `Custom reinforcement learning controller <> `_.
- **configs(list<tuple>)** - List of search space configurations in the format ``[(key, {input_size, output_size, block_num, block_mask})]`` or ``[(key)]`` (the search spaces of MobileNetV2, MobileNetV1 and ResNet mirror the original network structures, so only ``key`` needs to be specified). ``input_size`` and ``output_size`` are the sizes of the input and output feature maps, ``block_num`` is the number of blocks in the searched network, and ``block_mask`` is a list of 0s and 1s, where 0 stands for a block without downsampling and 1 for a downsampling block. For more search space configurations provided by PaddleSlim, please refer to [Search Space](../search_space.md).
- **use_gpu(bool)** - Whether to train the controller on GPU. Default: False.
- **server_addr(tuple)** - Address of the RLNAS controller, consisting of the server IP address and port. If the IP address is None or "", the local IP is used. Default: ("", 8881).
- **is_server(bool)** - Whether the current instance should start a server. Default: True.
- **is_sync(bool)** - Whether to update the controller in synchronous mode; this only makes a difference when there are multiple clients. Default: False.
- **save_controller(str|None)** - Directory in which to save controller checkpoints; if set to None, no checkpoint is saved. Default: None.
- **load_controller(str|None)** - Directory from which to load a controller checkpoint; if set to None, no checkpoint is loaded. Default: None.
- **\*\*kwargs** - Additional arguments determined by the specific reinforcement learning algorithm; for the additional arguments of `LSTM` and `DDPG` please see the note below.
.. note::

    Additional arguments for the `LSTM` algorithm:

    - lstm_num_layers(int, optional): Number of stacked LSTM layers in the controller. Default: 1.
    - hidden_size(int, optional): Hidden size of the LSTM. Default: 100.
    - temperature(float, optional): Temperature applied to the logits when computing each token; None disables it. Default: None.
    - tanh_constant(float, optional): If set, a tanh activation is applied to the logits of each token and the result is multiplied by `tanh_constant`. Default: None.
    - decay(float, optional): Smoothing rate of the rewards baseline kept by the LSTM controller. Default: 0.99.
    - weight_entropy(float, optional): If set, the entropy computed while sampling tokens, weighted by this value, is added to the received rewards when updating the controller parameters. Default: None.
    - controller_batch_size(int, optional): Batch size of the controller, i.e. how many tokens one controller run returns. Default: 1.

    Additional arguments for the `DDPG` algorithm:

    Note: to use the `DDPG` algorithm you must install parl first: pip install parl

    - obs_dim(int): Dimension of the observation.
    - model(class, optional): The model used by the DDPG algorithm, usually a class containing an actor_model and a critic_model; it needs to implement two methods, policy to obtain the policy and value to obtain the Q value. You can refer to the default model ` <>_` to implement your own model. Default: `default_ddpg_model`.
    - actor_lr(float, optional): Learning rate of the actor network. Default: 1e-4.
    - critic_lr(float, optional): Learning rate of the critic network. Default: 1e-3.
    - gamma(float, optional): Discount factor applied to the received rewards. Default: 0.99.
    - tau(float, optional): Discount factor used in DDPG when synchronizing the model parameters into the target model. Default: 0.001.
    - memory_size(int, optional): Size of the replay memory that DDPG uses to record history. Default: 10.
    - reward_scale(float, optional): Scale factor applied to the rewards before they are recorded. Default: 0.1.
    - controller_batch_size(int, optional): Batch size of the controller, i.e. how many tokens one controller run returns. Default: 1.
    - actions_noise(class, optional): Noise added to the action obtained from DDPG; set to False or None to disable. Default: default_noise.
**Returns:**

An instance of the RLNAS class.

**Example:**

.. code-block:: python

    from paddleslim.nas import RLNAS
    config = [('MobileNetV2Space')]
    rlnas = RLNAS(key='lstm', configs=config)
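The algorithm-specific keyword arguments listed in the note above are passed through ``RLNAS`` directly to the controller. A sketch with illustrative values (these particular settings are examples, not recommended defaults):

.. code-block:: python

    from paddleslim.nas import RLNAS
    config = [('MobileNetV2Space')]
    ### LSTM-specific hyperparameters are forwarded via **kwargs
    rlnas = RLNAS(key='lstm', configs=config, lstm_num_layers=2,
                  hidden_size=64, controller_batch_size=2)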
.. py:method:: next_archs(obs=None)

Get the next batch of model architectures.

**Parameters:**

- **obs(int|np.array)** - The number of architectures to obtain, or the current observations of the model.

**Returns:**

A list of model architecture instances.

**Example:**

.. code-block:: python

    import paddle.fluid as fluid
    from paddleslim.nas import RLNAS
    config = [('MobileNetV2Space')]
    rlnas = RLNAS(key='lstm', configs=config)
    input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
    archs = rlnas.next_archs(1)
    for arch in archs:
        output = arch(input)
        input = output
    print(output)
.. py:method:: reward(rewards, **kwargs)

Pass the rewards of the current model architectures back to the controller.

**Parameters:**

- **rewards(float|list<float>)** - Rewards of the current model; the larger the better.
- **\*\*kwargs** - Additional arguments depending on the specific reinforcement learning algorithm.

**Example:**

.. code-block:: python

    import paddle.fluid as fluid
    from paddleslim.nas import RLNAS
    config = [('MobileNetV2Space')]
    rlnas = RLNAS(key='lstm', configs=config)
    rlnas.next_archs(1)
    rlnas.reward(1.0)

.. note::

    The reward step must be executed after `next_archs`.

..
.. py:method:: final_archs(batch_obs)

Get the final model architectures. Typically, after the controller has finished training, a few dozen architectures are obtained and trained fully.

**Parameters:**

- **batch_obs(int|np.array)** - The number of architectures to obtain, or the current observations of the model.

**Returns:**

A list of model architecture instances.

**Example:**

.. code-block:: python

    import paddle.fluid as fluid
    from paddleslim.nas import RLNAS
    config = [('MobileNetV2Space')]
    rlnas = RLNAS(key='lstm', configs=config)
    archs = rlnas.final_archs(10)
.. py:method:: tokens2arch(tokens)

Convert a set of tokens into an actual model architecture; this is usually used to turn the best tokens found by the search into the network used for the final training. Tokens are given as a list and are mapped onto the search space to produce the corresponding network structure; one set of tokens corresponds to exactly one network structure.

**Parameters:**

- **tokens(list)** - A list of tokens. The length and range of the tokens depend on the search space.

**Returns:**

A list of model architecture instances derived from the given tokens.

**Example:**

.. code-block:: python

    import paddle.fluid as fluid
    from paddleslim.nas import RLNAS
    config = [('MobileNetV2Space')]
    rlnas = RLNAS(key='lstm', configs=config)
    input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
    tokens = ([0] * 25)
    archs = rlnas.tokens2arch(tokens)[0]
    print(archs(input))
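Putting the methods together, a typical search loop alternates `next_archs` and `reward`. A minimal sketch; ``train_and_eval`` below is a hypothetical user-defined helper that builds a program from the returned architecture, trains it and returns a scalar score:

.. code-block:: python

    from paddleslim.nas import RLNAS

    config = [('MobileNetV2Space')]
    rlnas = RLNAS(key='lstm', configs=config)
    for step in range(10):
        archs = rlnas.next_archs(1)
        # train_and_eval is a hypothetical helper that trains the candidate
        # network defined by `archs` and returns its evaluation score
        score = train_and_eval(archs)
        rlnas.reward(float(score))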
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import parl
from parl import layers
from paddle import fluid
from ..utils import RLCONTROLLER, action_mapping
from ...controller import RLBaseController
from .ddpg_model import DefaultDDPGModel as default_ddpg_model
from .noise import AdaptiveNoiseSpec as default_noise
from parl.utils import ReplayMemory
__all__ = ['DDPG']
class DDPGAgent(parl.Agent):
def __init__(self, algorithm, obs_dim, act_dim):
assert isinstance(obs_dim, int)
assert isinstance(act_dim, int)
self.obs_dim = obs_dim
self.act_dim = act_dim
super(DDPGAgent, self).__init__(algorithm)
# Attention: In the beginning, sync target model totally.
self.alg.sync_target(decay=0)
def build_program(self):
self.pred_program = fluid.Program()
self.learn_program = fluid.Program()
with fluid.program_guard(self.pred_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
self.pred_act = self.alg.predict(obs)
with fluid.program_guard(self.learn_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
act = layers.data(
name='act', shape=[self.act_dim], dtype='float32')
reward = layers.data(name='reward', shape=[], dtype='float32')
next_obs = layers.data(
name='next_obs', shape=[self.obs_dim], dtype='float32')
terminal = layers.data(name='terminal', shape=[], dtype='bool')
_, self.critic_cost = self.alg.learn(obs, act, reward, next_obs,
terminal)
def predict(self, obs):
obs = np.expand_dims(obs, axis=0)
act = self.fluid_executor.run(self.pred_program,
feed={'obs': obs},
fetch_list=[self.pred_act])[0]
return act
def learn(self, obs, act, reward, next_obs, terminal):
feed = {
'obs': obs,
'act': act,
'reward': reward,
'next_obs': next_obs,
'terminal': terminal
}
critic_cost = self.fluid_executor.run(self.learn_program,
feed=feed,
fetch_list=[self.critic_cost])[0]
self.alg.sync_target()
return critic_cost
@RLCONTROLLER.register
class DDPG(RLBaseController):
def __init__(self, range_tables, use_gpu=False, **kwargs):
self.use_gpu = use_gpu
self.range_tables = range_tables - np.asarray(1)
self.act_dim = len(self.range_tables)
self.obs_dim = kwargs.get('obs_dim')
self.model = kwargs.get(
'model') if 'model' in kwargs else default_ddpg_model
self.actor_lr = kwargs.get(
'actor_lr') if 'actor_lr' in kwargs else 1e-4
self.critic_lr = kwargs.get(
'critic_lr') if 'critic_lr' in kwargs else 1e-3
self.gamma = kwargs.get('gamma') if 'gamma' in kwargs else 0.99
self.tau = kwargs.get('tau') if 'tau' in kwargs else 0.001
self.memory_size = kwargs.get(
'memory_size') if 'memory_size' in kwargs else 10
self.reward_scale = kwargs.get(
'reward_scale') if 'reward_scale' in kwargs else 0.1
self.batch_size = kwargs.get(
'controller_batch_size') if 'controller_batch_size' in kwargs else 1
self.actions_noise = kwargs.get(
'actions_noise') if 'actions_noise' in kwargs else default_noise
self.action_dist = 0.0
self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace()
model = self.model(self.act_dim)
if self.actions_noise:
self.actions_noise = self.actions_noise()
algorithm = parl.algorithms.DDPG(
model,
gamma=self.gamma,
tau=self.tau,
actor_lr=self.actor_lr,
critic_lr=self.critic_lr)
self.agent = DDPGAgent(algorithm, self.obs_dim, self.act_dim)
self.rpm = ReplayMemory(self.memory_size, self.obs_dim, self.act_dim)
self.pred_program = self.agent.pred_program
self.learn_program = self.agent.learn_program
self.param_dict = self.get_params(self.learn_program)
def next_tokens(self, obs, params_dict, is_inference=False):
batch_obs = np.expand_dims(obs, axis=0)
self.set_params(self.pred_program, params_dict, self.place)
actions = self.agent.predict(batch_obs.astype('float32'))
### add noise to action
if self.actions_noise and is_inference == False:
actions_noise = np.clip(
np.random.normal(
actions, scale=self.actions_noise.stdev_curr),
-1.0,
1.0)
self.action_dist = np.mean(np.abs(actions_noise - actions))
else:
actions_noise = actions
actions_noise = action_mapping(actions_noise, self.range_tables)
return actions_noise
def _update_noise(self, actions_dist):
self.actions_noise.update(actions_dist)
def update(self, rewards, params_dict, obs, actions, obs_next, terminal):
self.set_params(self.learn_program, params_dict, self.place)
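        # Store the transition in the replay memory; rewards are scaled by reward_scale before being stored.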
self.rpm.append(obs, actions, self.reward_scale * rewards, obs_next,
terminal)
if self.actions_noise:
self._update_noise(self.action_dist)
if self.rpm.size() > self.memory_size:
            obs, actions, rewards, obs_next, terminal = self.rpm.sample_batch(
self.batch_size)
self.agent.learn(obs, actions, rewards, obs_next, terminal)
params_dict = self.get_params(self.learn_program)
return params_dict
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .DDPGController import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import parl
from parl import layers
class DefaultDDPGModel(parl.Model):
def __init__(self, act_dim):
self.actor_model = ActorModel(act_dim)
self.critic_model = CriticModel()
def policy(self, obs):
return self.actor_model.policy(obs)
def value(self, obs, act):
return self.critic_model.value(obs, act)
def get_actor_params(self):
return self.actor_model.parameters()
class ActorModel(parl.Model):
def __init__(self, act_dim):
hid1_size = 400
hid2_size = 300
self.fc1 = layers.fc(size=hid1_size, act='relu')
self.fc2 = layers.fc(size=hid2_size, act='relu')
self.fc3 = layers.fc(size=act_dim, act='tanh')
def policy(self, obs):
hid1 = self.fc1(obs)
hid2 = self.fc2(hid1)
means = self.fc3(hid2)
means = means
return means
class CriticModel(parl.Model):
def __init__(self):
hid1_size = 400
hid2_size = 300
self.fc1 = layers.fc(size=hid1_size, act='relu')
self.fc2 = layers.fc(size=hid2_size, act='relu')
self.fc3 = layers.fc(size=1, act=None)
def value(self, obs, act):
hid1 = self.fc1(obs)
concat = layers.concat([hid1, act], axis=1)
hid2 = self.fc2(concat)
Q = self.fc3(hid2)
Q = layers.squeeze(Q, axes=[1])
return Q
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__all__ = ['AdaptiveNoiseSpec']
class AdaptiveNoiseSpec(object):
def __init__(self):
self.stdev_curr = 1.0
def reset(self):
self.stdev_curr = 1.0
def update(self, action_dist):
if action_dist > 1e-2:
self.stdev_curr /= 1.03
else:
self.stdev_curr *= 1.03
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import logging
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from paddle.fluid.layers import RNNCell, LSTMCell, rnn
from paddle.fluid.contrib.layers import basic_lstm
from ...controller import RLBaseController
from ...log_helper import get_logger
from ..utils import RLCONTROLLER
_logger = get_logger(__name__, level=logging.INFO)
uniform_initializer = lambda x: fluid.initializer.UniformInitializer(low=-x, high=x)
class lstm_cell(RNNCell):
def __init__(self, num_layers, hidden_size):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.lstm_cells = []
param_attr = ParamAttr(initializer=uniform_initializer(
1.0 / math.sqrt(hidden_size)))
bias_attr = ParamAttr(initializer=uniform_initializer(
1.0 / math.sqrt(hidden_size)))
for i in range(num_layers):
self.lstm_cells.append(
LSTMCell(hidden_size, param_attr, bias_attr))
def call(self, inputs, states):
new_states = []
for i in range(self.num_layers):
out, new_state = self.lstm_cells[i](inputs, states[i])
new_states.append(new_state)
return out, new_states
@property
def state_shape(self):
return [cell.state_shape for cell in self.lstm_cells]
@RLCONTROLLER.register
class LSTM(RLBaseController):
def __init__(self, range_tables, use_gpu=False, **kwargs):
self.use_gpu = use_gpu
self.range_tables = range_tables
self.lstm_num_layers = kwargs.get('lstm_num_layers') or 1
self.hidden_size = kwargs.get('hidden_size') or 100
self.temperature = kwargs.get('temperature') or None
self.tanh_constant = kwargs.get('tanh_constant') or None
self.decay = kwargs.get('decay') or 0.99
self.weight_entropy = kwargs.get('weight_entropy') or None
self.controller_batch_size = kwargs.get('controller_batch_size') or 1
self.max_range_table = max(self.range_tables) + 1
self._create_parameter()
self._build_program()
self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace()
self.exe = fluid.Executor(self.place)
self.exe.run(fluid.default_startup_program())
self.param_dict = self.get_params(self.learn_program)
def _lstm(self, inputs, hidden, cell, token_idx):
cells = lstm_cell(self.lstm_num_layers, self.hidden_size)
output, new_states = cells.call(inputs, states=([[hidden, cell]]))
logits = fluid.layers.fc(new_states[0], self.range_tables[token_idx])
if self.temperature is not None:
logits = logits / self.temperature
if self.tanh_constant is not None:
logits = self.tanh_constant * fluid.layers.tanh(logits)
return logits, output, new_states
def _create_parameter(self):
self.emb_w = fluid.layers.create_parameter(
name='emb_w',
shape=(self.max_range_table, self.hidden_size),
dtype='float32',
default_initializer=uniform_initializer(1.0))
self.g_emb = fluid.layers.create_parameter(
name='emb_g',
shape=(self.controller_batch_size, self.hidden_size),
dtype='float32',
default_initializer=uniform_initializer(1.0))
self.baseline = fluid.layers.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name='baseline')
self.baseline.stop_gradient = True
def _network(self, hidden, cell, init_actions=None, is_inference=False):
actions = []
entropies = []
sample_log_probs = []
with fluid.unique_name.guard('Controller'):
self._create_parameter()
inputs = self.g_emb
for idx in range(len(self.range_tables)):
logits, output, states = self._lstm(
inputs, hidden, cell, token_idx=idx)
hidden, cell = np.squeeze(states)
probs = fluid.layers.softmax(logits, axis=1)
if is_inference:
action = fluid.layers.argmax(probs, axis=1)
else:
if init_actions:
action = fluid.layers.slice(
init_actions,
axes=[1],
starts=[idx],
ends=[idx + 1])
action.stop_gradient = True
else:
action = fluid.layers.sampling_id(probs)
actions.append(action)
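                # cross_entropy(probs, action) equals -log p(action), the negative log-probability of the sampled token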
log_prob = fluid.layers.cross_entropy(probs, action)
sample_log_probs.append(log_prob)
entropy = log_prob * fluid.layers.exp(-1 * log_prob)
entropy.stop_gradient = True
entropies.append(entropy)
action_emb = fluid.layers.cast(action, dtype=np.int64)
inputs = fluid.layers.gather(self.emb_w, action_emb)
sample_log_probs = fluid.layers.stack(sample_log_probs)
self.sample_log_probs = fluid.layers.reduce_sum(sample_log_probs)
entropies = fluid.layers.stack(entropies)
self.sample_entropies = fluid.layers.reduce_sum(entropies)
return actions
def _build_program(self, is_inference=False):
self.pred_program = fluid.Program()
self.learn_program = fluid.Program()
with fluid.program_guard(self.pred_program):
self.g_emb = fluid.layers.create_parameter(
name='emb_g',
shape=(self.controller_batch_size, self.hidden_size),
dtype='float32',
default_initializer=uniform_initializer(1.0))
fluid.layers.assign(
fluid.layers.uniform_random(shape=self.g_emb.shape),
self.g_emb)
hidden = fluid.data(name='hidden', shape=[None, self.hidden_size])
cell = fluid.data(name='cell', shape=[None, self.hidden_size])
self.tokens = self._network(
hidden, cell, is_inference=is_inference)
with fluid.program_guard(self.learn_program):
hidden = fluid.data(name='hidden', shape=[None, self.hidden_size])
cell = fluid.data(name='cell', shape=[None, self.hidden_size])
init_actions = fluid.data(
name='init_actions',
shape=[None, len(self.range_tables)],
dtype='int64')
self._network(hidden, cell, init_actions=init_actions)
rewards = fluid.data(name='rewards', shape=[None])
self.rewards = fluid.layers.reduce_mean(rewards)
if self.weight_entropy is not None:
self.rewards += self.weight_entropy * self.sample_entropies
self.sample_log_probs = fluid.layers.reduce_sum(
self.sample_log_probs)
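            # Maintain an exponential moving average of the rewards as the baseline for the policy gradient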
fluid.layers.assign(self.baseline - (1.0 - self.decay) *
(self.baseline - self.rewards), self.baseline)
self.loss = -1.0 * self.sample_log_probs * (
self.rewards - self.baseline)
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))
optimizer = fluid.optimizer.Adam(learning_rate=1e-3)
optimizer.minimize(self.loss)
def _create_input(self, is_test=True, actual_rewards=None):
feed_dict = dict()
np_init_hidden = np.zeros(
(self.controller_batch_size, self.hidden_size)).astype('float32')
np_init_cell = np.zeros(
(self.controller_batch_size, self.hidden_size)).astype('float32')
feed_dict["hidden"] = np_init_hidden
feed_dict["cell"] = np_init_cell
if is_test == False:
if isinstance(actual_rewards, np.float32):
                assert actual_rewards != None, "if you want to update the controller, you must input a reward"
actual_rewards = np.expand_dims(actual_rewards, axis=0)
elif isinstance(actual_rewards, np.float) or isinstance(
actual_rewards, np.float64):
actual_rewards = np.float32(actual_rewards)
                assert actual_rewards != None, "if you want to update the controller, you must input a reward"
actual_rewards = np.expand_dims(actual_rewards, axis=0)
else:
assert actual_rewards.all(
                ) != None, "if you want to update the controller, you must input a reward"
actual_rewards = actual_rewards.astype(np.float32)
feed_dict['rewards'] = actual_rewards
feed_dict['init_actions'] = np.array(self.init_tokens)
return feed_dict
def next_tokens(self, num_archs=1, params_dict=None, is_inference=False):
""" sample next tokens according current parameter and inputs"""
self.num_archs = num_archs
self.set_params(self.pred_program, params_dict, self.place)
batch_tokens = []
feed_dict = self._create_input()
for _ in range(
int(np.ceil(float(num_archs) / self.controller_batch_size))):
if is_inference:
self._build_program(is_inference=True)
actions = self.exe.run(self.pred_program,
feed=feed_dict,
fetch_list=self.tokens)
for idx in range(self.controller_batch_size):
each_token = {}
for i, action in enumerate(actions):
token = action[idx]
if idx in each_token:
each_token[idx].append(int(token))
else:
each_token[idx] = [int(token)]
batch_tokens.append(each_token[idx])
self.init_tokens = batch_tokens
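        # If num_archs is not a multiple of controller_batch_size, the extra sampled tokens are dropped below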
mod_token = (self.controller_batch_size -
(num_archs % self.controller_batch_size)
) % self.controller_batch_size
if mod_token != 0:
return batch_tokens[:-mod_token]
else:
return batch_tokens
def update(self, rewards, params_dict=None):
"""train controller according reward"""
self.set_params(self.learn_program, params_dict, self.place)
feed_dict = self._create_input(is_test=False, actual_rewards=rewards)
loss = self.exe.run(self.learn_program,
feed=feed_dict,
fetch_list=[self.loss])
_logger.info("Controller: current reward is {}, loss is {}".format(
rewards, loss))
params_dict = self.get_params(self.learn_program)
return params_dict
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .LSTM_Controller import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from ..log_helper import get_logger
_logger = get_logger(__name__, level=logging.INFO)
try:
import parl
from .DDPG import *
except ImportError as e:
_logger.warn(
"If you want to use DDPG in RLNAS, please pip intall parl first. Now states: {}".
format(e))
from .LSTM import *
from .utils import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ...core import Registry
__all__ = [
"RLCONTROLLER", "action_mapping", "add_grad", "compute_grad",
"ConnectMessage"
]
RLCONTROLLER = Registry('RLController')
class ConnectMessage:
INIT = 'INIT'
INIT_DONE = 'INIT_DONE'
GET_WEIGHT = 'GET_WEIGHT'
UPDATE_WEIGHT = 'UPDATE_WEIGHT'
OK = 'OK'
WAIT = 'WAIT'
WAIT_PARAMS = 'WAIT_PARAMS'
EXIT = 'EXIT'
TIMEOUT = 10
def action_mapping(actions, range_table):
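    # Linearly map continuous actions in [-1, 1] onto integer token indices in [0, range_table]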
actions = (actions - (-1.0)) * (range_table / np.asarray(2.0))
return actions.astype('int64')
def add_grad(dict1, dict2):
dict3 = dict()
for key, value in dict1.items():
dict3[key] = dict1[key] + dict2[key]
return dict3
def compute_grad(dict1, dict2):
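    # Element-wise difference of two parameter dicts (dict1 - dict2); the client uses it to send a parameter delta to the server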
dict3 = dict()
for key, value in dict1.items():
dict3[key] = dict1[key] - dict2[key]
return dict3
......@@ -18,9 +18,12 @@ from .controller_server import ControllerServer
from .controller_client import ControllerClient
from .lock import lock, unlock
from .cached_reader import cached_reader
from .server import Server
from .client import Client
from .meter import AvgrageMeter
__all__ = [
'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer',
'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter'
'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter',
'Server', 'Client', 'RLBaseController'
]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import signal
import zmq
import socket
import logging
import time
import threading
import cloudpickle
from .log_helper import get_logger
from .RL_controller.utils import compute_grad, ConnectMessage
_logger = get_logger(__name__, level=logging.INFO)
class Client(object):
def __init__(self, controller, address, client_name):
self._controller = controller
self._address = address
self._ip = self._address[0]
self._port = self._address[1]
self._client_name = client_name
self._params_dict = None
self.init_wait = False
self._connect_server()
def _connect_server(self):
self._ctx = zmq.Context()
self._client_socket = self._ctx.socket(zmq.REQ)
### NOTE: change the method to exit client when server is dead if there are better solutions
self._client_socket.setsockopt(zmq.RCVTIMEO,
ConnectMessage.TIMEOUT * 1000)
client_address = "{}:{}".format(self._ip, self._port)
self._client_socket.connect("tcp://{}".format(client_address))
self._client_socket.send_multipart(
[ConnectMessage.INIT, self._client_name])
message = self._client_socket.recv_multipart()
if message[0] != ConnectMessage.INIT_DONE:
_logger.error("Client {} init failure, Please start it again".
format(self._client_name))
pid = os.getpid()
os.kill(pid, signal.SIGTERM)
_logger.info("Client {}: connect to server {}".format(
self._client_name, client_address))
def _connect_wait_socket(self, port):
self._wait_socket = self._ctx.socket(zmq.REQ)
wait_address = "{}:{}".format(self._ip, port)
self._wait_socket.connect("tcp://{}".format(wait_address))
self._wait_socket.send_multipart(
[ConnectMessage.WAIT_PARAMS, self._client_name])
message = self._wait_socket.recv_multipart()
return message[0]
def next_tokens(self, obs, is_inference=False):
_logger.debug("Client: requests for weight {}".format(
self._client_name))
self._client_socket.send_multipart(
[ConnectMessage.GET_WEIGHT, self._client_name])
try:
message = self._client_socket.recv_multipart()
except zmq.error.Again as e:
_logger.error(
"CANNOT recv params from server in next_archs, Please check whether the server is alive!!! {}".
format(e))
os._exit(0)
self._params_dict = cloudpickle.loads(message[0])
tokens = self._controller.next_tokens(
obs, params_dict=self._params_dict, is_inference=is_inference)
_logger.debug("Client: client_name is {}, current token is {}".format(
self._client_name, tokens))
return tokens
def update(self, rewards, **kwargs):
        assert self._params_dict != None, "Please call next_tokens to get tokens first, then call update"
current_params_dict = self._controller.update(
rewards, self._params_dict, **kwargs)
params_grad = compute_grad(self._params_dict, current_params_dict)
_logger.debug("Client: update weight {}".format(self._client_name))
self._client_socket.send_multipart([
ConnectMessage.UPDATE_WEIGHT, self._client_name,
cloudpickle.dumps(params_grad)
])
_logger.debug("Client: update done {}".format(self._client_name))
try:
message = self._client_socket.recv_multipart()
except zmq.error.Again as e:
_logger.error(
"CANNOT recv params from server in rewards, Please check whether the server is alive!!! {}".
format(e))
os._exit(0)
if message[0] == ConnectMessage.WAIT:
_logger.debug("Client: self.init_wait: {}".format(self.init_wait))
if not self.init_wait:
wait_port = cloudpickle.loads(message[1])
wait_signal = self._connect_wait_socket(wait_port)
self.init_wait = True
else:
wait_signal = message[0]
while wait_signal != ConnectMessage.OK:
time.sleep(1)
self._wait_socket.send_multipart(
[ConnectMessage.WAIT_PARAMS, self._client_name])
wait_signal = self._wait_socket.recv_multipart()
wait_signal = wait_signal[0]
_logger.debug("Client: {} {}".format(self._client_name,
wait_signal))
return message[0]
def __del__(self):
try:
self._client_socket.send_multipart(
[ConnectMessage.EXIT, self._client_name])
_ = self._client_socket.recv_multipart()
except:
pass
self._client_socket.close()
......@@ -16,8 +16,9 @@
import copy
import math
import numpy as np
import paddle.fluid as fluid
__all__ = ['EvolutionaryController']
__all__ = ['EvolutionaryController', 'RLBaseController']
class EvolutionaryController(object):
......@@ -51,3 +52,31 @@ class EvolutionaryController(object):
list<list>: The next searched tokens.
"""
raise NotImplementedError('Abstract method.')
class RLBaseController(object):
""" Base Controller for reforcement learning"""
def next_tokens(self, *args, **kwargs):
raise NotImplementedError('Abstract method.')
def update(self, *args, **kwargs):
raise NotImplementedError('Abstract method.')
def save_controller(self, program, output_dir):
fluid.save(program, output_dir)
def load_controller(self, program, load_dir):
fluid.load(program, load_dir)
def get_params(self, program):
var_dict = {}
for var in program.global_block().all_parameters():
var_dict[var.name] = np.array(fluid.global_scope().find_var(
var.name).get_tensor())
return var_dict
def set_params(self, program, params_dict, place):
for var in program.global_block().all_parameters():
fluid.global_scope().find_var(var.name).get_tensor().set(
params_dict[var.name], place)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import zmq
import socket
import signal
import six
import os
if six.PY2:
import cPickle as pickle
else:
import pickle
import logging
import time
import threading
import cloudpickle
from .log_helper import get_logger
from .RL_controller.utils import add_grad, ConnectMessage
_logger = get_logger(__name__, level=logging.INFO)
class Server(object):
def __init__(self,
controller,
address,
is_sync=False,
load_controller=None,
save_controller=None):
self._controller = controller
self._address = address
self._ip = self._address[0]
self._port = self._address[1]
self._is_sync = is_sync
self._done = False
self._load_controller = load_controller
self._save_controller = save_controller
### key-value : client_name-update_times
self._client_dict = dict()
self._client = list()
self._lock = threading.Lock()
self._server_alive = True
self._max_update_times = 0
def close(self):
self._server_alive = False
_logger.info("server closed")
pid = os.getpid()
os.kill(pid, signal.SIGTERM)
def start(self):
self._ctx = zmq.Context()
### main socket
self._server_socket = self._ctx.socket(zmq.REP)
server_address = "{}:{}".format(self._ip, self._port)
self._server_socket.bind("tcp://{}".format(server_address))
self._server_socket.linger = 0
_logger.info("ControllerServer - listen on: [{}]".format(
server_address))
thread = threading.Thread(target=self.run, args=())
thread.setDaemon(True)
thread.start()
if self._load_controller:
assert os.path.exists(
self._load_controller
), "controller checkpoint is not exist, please check your directory: {}".format(
self._load_controller)
with open(
os.path.join(self._load_controller, 'rlnas.params'),
'rb') as f:
self._params_dict = pickle.load(f)
_logger.info("Load params done")
else:
self._params_dict = self._controller.param_dict
if self._is_sync:
self._wait_socket = self._ctx.socket(zmq.REP)
self._wait_port = self._wait_socket.bind_to_random_port(
addr="tcp://*")
            self._wait_socket.linger = 0
wait_thread = threading.Thread(
target=self._wait_for_params, args=())
wait_thread.setDaemon(True)
wait_thread.start()
def _wait_for_params(self):
try:
while self._server_alive:
message = self._wait_socket.recv_multipart()
cmd = message[0]
client_name = message[1]
if cmd == ConnectMessage.WAIT_PARAMS:
_logger.debug("Server: wait for params")
self._lock.acquire()
self._wait_socket.send_multipart([
ConnectMessage.OK
if self._done else ConnectMessage.WAIT
])
if self._done and client_name in self._client:
self._client.remove(client_name)
if len(self._client) == 0:
self.save_params()
self._done = False
self._lock.release()
else:
_logger.error("Error message {}".format(message))
raise NotImplementedError
except Exception as err:
            _logger.error(err)
def run(self):
try:
while self._server_alive:
try:
sum_params_dict = dict()
message = self._server_socket.recv_multipart()
cmd = message[0]
client_name = message[1]
if cmd == ConnectMessage.INIT:
self._server_socket.send_multipart(
[ConnectMessage.INIT_DONE])
_logger.debug("Server: init client {}".format(
client_name))
self._client_dict[client_name] = 0
elif cmd == ConnectMessage.GET_WEIGHT:
self._lock.acquire()
_logger.debug("Server: get weight {}".format(
client_name))
self._server_socket.send_multipart(
[cloudpickle.dumps(self._params_dict)])
_logger.debug("Server: send params done {}".format(
client_name))
self._lock.release()
elif cmd == ConnectMessage.UPDATE_WEIGHT:
_logger.info("Server: update {}".format(client_name))
params_dict_grad = cloudpickle.loads(message[2])
if self._is_sync:
if not sum_params_dict:
sum_params_dict = self._params_dict
self._lock.acquire()
sum_params_dict = add_grad(sum_params_dict,
params_dict_grad)
self._client.append(client_name)
self._lock.release()
if len(self._client) == len(
self._client_dict.items()):
self._done = True
self._server_socket.send_multipart([
ConnectMessage.WAIT,
cloudpickle.dumps(self._wait_port)
])
else:
self._lock.acquire()
self._params_dict = add_grad(self._params_dict,
params_dict_grad)
self._client_dict[client_name] += 1
if self._client_dict[
client_name] > self._max_update_times:
self._max_update_times = self._client_dict[
client_name]
self._lock.release()
self.save_params()
self._server_socket.send_multipart(
[ConnectMessage.OK])
elif cmd == ConnectMessage.EXIT:
self._client_dict.pop(client_name)
if client_name in self._client:
self._client.remove(client_name)
self._server_socket.send_multipart(
[ConnectMessage.EXIT])
except zmq.error.Again as e:
_logger.error(e)
self.close()
except Exception as err:
_logger.error(err)
finally:
self._server_socket.close(0)
if self._is_sync:
self._wait_socket.close(0)
self.close()
def save_params(self):
if self._save_controller:
if not os.path.exists(self._save_controller):
os.makedirs(self._save_controller)
output_dir = self._save_controller
else:
os.makedirs('./.rlnas_controller')
output_dir = './.rlnas_controller'
with open(os.path.join(output_dir, 'rlnas.params'), 'wb') as f:
pickle.dump(self._params_dict, f)
_logger.info("Save params done")
......@@ -16,10 +16,12 @@ from ..nas import search_space
from .search_space import *
from ..nas import sa_nas
from .sa_nas import *
from .rl_nas import *
from ..nas import darts
from .darts import *
__all__ = []
__all__ += sa_nas.__all__
__all__ += search_space.__all__
__all__ += rl_nas.__all__
__all__ += darts.__all__
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import socket
import logging
import numpy as np
import json
import hashlib
import time
import paddle.fluid as fluid
from ..common.RL_controller.utils import RLCONTROLLER
from ..common import get_logger
from ..common import Server
from ..common import Client
from .search_space import SearchSpaceFactory
_logger = get_logger(__name__, level=logging.INFO)
__all__ = ['RLNAS']
class RLNAS(object):
"""
Controller with Reinforcement Learning.
Args:
key(str): The actual reinforcement learning method. Current support in paddleslim is `LSTM` and `DDPG`.
configs(list<tuple>): A list of search space configuration with format [(key, {input_size,
output_size, block_num, block_mask})]. `key` is the name of search space
with data type str. `input_size` and `output_size` are input size and
output size of searched sub-network. `block_num` is the number of blocks
                              in searched network, `block_mask` is a list consisting of 0s and 1s, where
                              0 means a normal block and 1 means a reduction block.
use_gpu(bool): Whether to use gpu in controller. Default: False.
server_addr(tuple): Server address, including ip and port of server. If ip is None or "", will
use host ip if is_server = True. Default: ("", 8881).
is_server(bool): Whether current host is controller server. Default: True.
is_sync(bool): Whether to update controller in synchronous mode. Default: False.
save_controller(str|None): The directory of controller to save, if set to None, not save checkpoint.
Default: None.
load_controller(str|None): The directory of controller to load, if set to None, not load checkpoint.
Default: None.
**kwargs: Additional keyword arguments.
"""
def __init__(self,
key,
configs,
use_gpu=False,
server_addr=("", 8881),
is_server=True,
is_sync=False,
save_controller=None,
load_controller=None,
**kwargs):
if not is_server:
assert server_addr[
0] != "", "You should set the IP and port of server when is_server is False."
self._configs = configs
factory = SearchSpaceFactory()
self._search_space = factory.get_search_space(configs)
self.range_tables = self._search_space.range_table()
self.save_controller = save_controller
self.load_controller = load_controller
cls = RLCONTROLLER.get(key.upper())
server_ip, server_port = server_addr
if server_ip == None or server_ip == "":
server_ip = self._get_host_ip()
self._controller = cls(range_tables=self.range_tables,
use_gpu=use_gpu,
**kwargs)
if is_server:
max_client_num = 300
self._controller_server = Server(
controller=self._controller,
address=(server_ip, server_port),
is_sync=is_sync,
save_controller=self.save_controller,
load_controller=self.load_controller)
self._controller_server.start()
self._client_name = hashlib.md5(
str(time.time() + np.random.randint(1, 10000)).encode(
"utf-8")).hexdigest()
self._controller_client = Client(
controller=self._controller,
address=(server_ip, server_port),
client_name=self._client_name)
self._current_tokens = None
def _get_host_ip(self):
try:
return socket.gethostbyname(socket.gethostname())
except:
return socket.gethostbyname('localhost')
def next_archs(self, obs=None):
"""
Get next archs
Args:
obs(int|np.array): observations in env.
"""
archs = []
self._current_tokens = self._controller_client.next_tokens(obs)
_logger.info("current tokens: {}".format(self._current_tokens))
for token in self._current_tokens:
archs.append(self._search_space.token2arch(token))
return archs
def reward(self, rewards, **kwargs):
"""
        Pass the rewards back to the controller and train it.
Args:
            rewards(float|list<float>): rewards obtained by the current tokens.
**kwargs: Additional keyword arguments.
"""
return self._controller_client.update(rewards, **kwargs)
def final_archs(self, batch_obs):
"""
        Get the final architectures.
Args:
batch_obs(int|np.array): observations in env.
"""
final_tokens = self._controller_client.next_tokens(
batch_obs, is_inference=True)
_logger.info("Final tokens: {}".format(final_tokens))
archs = []
for token in final_tokens:
arch = self._search_space.token2arch(token)
archs.append(arch)
return archs
def tokens2arch(self, tokens):
"""
Convert tokens to model architectures.
        Args:
            tokens(list): A list of tokens. The length and range depend on the search space.
        Returns:
            list<function>: Model architecture instances corresponding to the tokens.
"""
return self._search_space.token2arch(tokens)