From d21b2c6be481144f610338938aa6eebdc01fcad0 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Wed, 22 Apr 2020 16:16:22 +0800 Subject: [PATCH] Add RL (#163) --- docs/zh_cn/api_cn/custom_rl_controller.md | 54 ++++ docs/zh_cn/api_cn/nas_api.rst | 154 +++++++++- .../RL_controller/DDPG/DDPGController.py | 157 ++++++++++ .../common/RL_controller/DDPG/__init__.py | 15 + .../common/RL_controller/DDPG/ddpg_model.py | 67 +++++ paddleslim/common/RL_controller/DDPG/noise.py | 29 ++ .../RL_controller/LSTM/LSTM_Controller.py | 281 ++++++++++++++++++ .../common/RL_controller/LSTM/__init__.py | 15 + paddleslim/common/RL_controller/__init__.py | 27 ++ paddleslim/common/RL_controller/utils.py | 54 ++++ paddleslim/common/__init__.py | 5 +- paddleslim/common/client.py | 133 +++++++++ paddleslim/common/controller.py | 31 +- paddleslim/common/server.py | 211 +++++++++++++ paddleslim/nas/__init__.py | 2 + paddleslim/nas/rl_nas.py | 162 ++++++++++ 16 files changed, 1394 insertions(+), 3 deletions(-) create mode 100644 docs/zh_cn/api_cn/custom_rl_controller.md create mode 100644 paddleslim/common/RL_controller/DDPG/DDPGController.py create mode 100644 paddleslim/common/RL_controller/DDPG/__init__.py create mode 100644 paddleslim/common/RL_controller/DDPG/ddpg_model.py create mode 100644 paddleslim/common/RL_controller/DDPG/noise.py create mode 100644 paddleslim/common/RL_controller/LSTM/LSTM_Controller.py create mode 100644 paddleslim/common/RL_controller/LSTM/__init__.py create mode 100644 paddleslim/common/RL_controller/__init__.py create mode 100644 paddleslim/common/RL_controller/utils.py create mode 100644 paddleslim/common/client.py create mode 100644 paddleslim/common/server.py create mode 100644 paddleslim/nas/rl_nas.py diff --git a/docs/zh_cn/api_cn/custom_rl_controller.md b/docs/zh_cn/api_cn/custom_rl_controller.md new file mode 100644 index 00000000..012bed5c --- /dev/null +++ b/docs/zh_cn/api_cn/custom_rl_controller.md @@ -0,0 +1,54 @@ +# 外部如何自定义强化学习Controller + +首先导入必要的依赖: +```python +### 引入强化学习Controller基类函数和注册类函数 +from paddleslim.common.RL_controller.utils import RLCONTROLLER +from paddleslim.common.RL_controller import RLBaseController +``` + +通过装饰器的方式把自定义强化学习Controller注册到PaddleSlim,继承基类之后需要重写基类中的`next_tokens`和`update`两个函数。注意:本示例仅说明一些必不可少的步骤,并不能直接运行,完整代码请参考[这里]() + +```python +### 注意: 类名一定要全部大写 +@RLCONTROLLER.register +class LSTM(RLBaseController): + def __init__(self, range_tables, use_gpu=False, **kwargs): + ### range_tables 表示tokens的取值范围 + self.range_tables = range_tables + ### use_gpu 表示是否使用gpu来训练controller + self.use_gpu = use_gpu + ### 定义一些强化学习算法中需要的参数 + ... + ### 构造相应的program, _build_program这个函数会构造两个program,一个是pred_program,一个是learn_program, 并初始化参数 + self._build_program() + self.place = fluid.CUDAPlace(0) if self.args.use_gpu else fluid.CPUPlace() + self.exe = fluid.Executor(self.place) + self.exe.run(fluid.default_startup_program()) + + ### 保存参数到一个字典中,这个字典由server端统一维护更新,因为可能有多个client同时更新一份参数,所以这一步必不可少,由于pred_program和learn_program使用的同一份参数,所以只需要把learn_program中的参数放入字典中即可 + self.param_dicts = {} + self.param_dicts.update(self.learn_program: self.get_params(self.learn_program)) + + def next_tokens(self, states, params_dict): + ### 把从server端获取参数字典赋值给当前要用到的program + self.set_params(self.pred_program, params_dict, self.place) + ### 根据states构造输入 + self.num_archs = states + feed_dict = self._create_input() + ### 获取当前token + actions = self.exe.run(self.pred_program, feed=feed_dict, fetch_list=self.tokens) + ... 
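+        ### NOTE: `actions` holds the numpy arrays fetched for `self.tokens`;
+        ### post-process them here (e.g. cast to plain int token lists) before returning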
+ return actions + + def update(self, rewards, params_dict=None): + ### 把从server端获取参数字典赋值给当前要用到的program + self.set_params(self.learn_program, params_dict, self.place) + ### 根据`next_tokens`中的states和`update`中的rewards构造输入 + feed_dict = self._create_input(is_test=False, actual_rewards = rewards) + ### 计算当前step的loss + loss = self.exe.run(self.learn_program, feed=feed_dict, fetch_list=[self.loss]) + ### 获取当前program的参数并返回,client会把本轮的参数传给server端进行参数更新 + params_dict = self.get_params(self.learn_program) + return params_dict +``` diff --git a/docs/zh_cn/api_cn/nas_api.rst b/docs/zh_cn/api_cn/nas_api.rst index b09fe477..238e0a8c 100644 --- a/docs/zh_cn/api_cn/nas_api.rst +++ b/docs/zh_cn/api_cn/nas_api.rst @@ -1,4 +1,4 @@ -SA-NAS +NAS ======== 搜索空间参数的配置 @@ -160,3 +160,155 @@ SANAS(Simulated Annealing Neural Architecture Search)是基于模拟退火 sanas = SANAS(configs=config) print(sanas.current_info()) + + +RLNAS +------ + +.. py:class:: paddleslim.nas.RLNAS(key, configs, use_gpu=False, server_addr=("", 8881), is_server=True, is_sync=False, save_controller=None, load_controller=None, **kwargs) + +`源代码 <> `_ + +RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习算法进行模型结构搜索的算法。 + +- **key** - 使用的强化学习Controller名称,目前paddleslim支持的有`LSTM`和`DDPG`,自定义强化学习Controller请参考 ` 自定义强化学习Controller <> `_ +- **configs(list)** - 搜索空间配置列表,格式是 ``[(key, {input_size, output_size, block_num, block_mask})]`` 或者 ``[(key)]`` (MobileNetV2、MobilenetV1和ResNet的搜索空间使用和原本网络结构相同的搜索空间,所以仅需指定 ``key`` 即可), ``input_size`` 和 ``output_size`` 表示输入和输出的特征图的大小, ``block_num`` 是指搜索网络中的block数量, ``block_mask`` 是一组由0和1组成的列表,0代表不进行下采样的block,1代表下采样的block。 更多paddleslim提供的搜索空间配置可以参考[Search Space](../search_space.md)。 +- **use_gpu(bool)** - 是否使用GPU来训练Controller。默认:False。 +- **server_addr(tuple)** - RLNAS中Controller的地址,包括server的ip地址和端口号,如果ip地址为None或者为""的话则默认使用本机ip。默认:("", 8881)。 +- **is_server(bool)** - 当前实例是否要启动一个server。默认:True。 +- **is_sync(bool)** - 是否使用同步模式更新Controller,该模式仅在多client下有差别。默认:False。 +- **save_controller(str|None)** - 保存Controller的checkpoint的文件目录,如果设置为None的话则不保存checkpoint。默认:None 。 +- **load_controller(str|None)** - 加载Controller的checkpoint的文件目录,如果设置为None的话则不加载checkpoint。默认:None。 +- **\*\*kwargs** - 附加的参数,由具体强化学习算法决定,`LSTM`和`DDPG`的附加参数请参考note。 + +.. note:: + + `LSTM`算法的附加参数: + + - lstm_num_layers(int, optional): - Controller中堆叠的LSTM的层数。默认:1. + - hidden_size(int, optional): - LSTM中隐藏层的大小。默认:100. + - temperature(float, optional): - 是否在计算每个token过程中做温度平均。默认:None. + - tanh_constant(float, optional): 是否在计算每个token过程中做tanh激活,并乘上`tanh_constant`值。 默认:None。 + - decay(float, optional): LSTM中记录rewards的baseline的平滑率。默认:0.99. + - weight_entropy(float, optional): 在更新controller参数时是否为接收到的rewards加上计算token过程中的带权重的交叉熵值。默认:None。 + - controller_batch_size(int, optional): controller的batch_size,即每运行一次controller可以拿到几个token。默认:1. + + + `DDPG`算法的附加参数: + 注意:使用`DDPG`算法的话必须安装parl。安装方法: pip install parl + + - obs_dim(int): observation的维度。 + - model(class,optional): DDPG算法中使用的具体的模型,一般是个类,包含actor_model和critic_model,需要实现两个方法,一个是policy用来获得策略,另一个是value,需要获得Q值。可以参考默认的model` <>_`实现您自己的model。默认:`default_ddpg_model`. + - actor_lr(float, optional): actor网络的学习率。默认:1e-4. + - critic_lr(float, optional): critic网络的学习率。默认:1e-3. + - gamma(float, optional): 接收到rewards之后的折扣因子。默认:0.99. + - tau(float, optional): DDPG中把models的参数同步累积到target_model上时的折扣因子。默认:0.001. + - memory_size(int, optional): DDPG中记录历史信息的池子大小。默认:10. + - reward_scale(float, optional): 记录历史信息时,对rewards信息进行的折扣因子。默认:0.1. + - controller_batch_size(int, optional): controller的batch_size,即每运行一次controller可以拿到几个token。默认:1. 
+ - actions_noise(class, optional): 通过DDPG拿到action之后添加的噪声,设置为False或者None时不添加噪声。默认:default_noise. +.. + +**返回:** +一个RLNAS类的实例 + +**示例代码:** + +.. code-block:: python + + from paddleslim.nas import RLNAS + config = [('MobileNetV2Space')] + rlnas = RLNAS(key='lstm', configs=config) + + + .. py:method:: next_archs(obs=None) + 获取下一组模型结构。 + + **参数:** + + - **obs** - 需要获取的模型结构数量或者当前模型的observations。 + + **返回:** + 返回模型结构实例的列表,形式为list。 + + **示例代码:** + + .. code-block:: python + import paddle.fluid as fluid + from paddleslim.nas import RLNAS + config = [('MobileNetV2Space')] + rlnas = RLNAS(key='lstm', configs=config) + input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32') + archs = rlnas.next_archs(1) + for arch in archs: + output = arch(input) + input = output + print(output) + + .. py:method:: reward(rewards, **kwargs): + + 把当前模型结构的rewards回传。 + + **参数:** + + - **rewards>:** - 当前模型的rewards,分数越大越好。 + - **\*\*kwargs:** - 附加的参数,取决于具体的强化学习算法。 + + **示例代码:** + + .. code-block:: python + import paddle.fluid as fluid + from paddleslim.nas import RLNAS + config = [('MobileNetV2Space')] + rlnas = RLNAS(key='lstm', configs=config) + rlnas.next_archs(1) + rlnas.reward(1.0) + + .. note:: + reward这一步必须在`next_token`之后执行。 +.. + + .. py:method:: final_archs(batch_obs): + + 获取最终的模型结构。一般在controller训练完成之后会获取几十个模型结构进行完整的实验。 + + **参数:** + + - **obs** - 需要获取的模型结构数量或者当前模型的observations。 + + **返回:** + 返回模型结构实例的列表,形式为list。 + + **示例代码:** + + .. code-block:: python + import paddle.fluid as fluid + from paddleslim.nas import RLNAS + config = [('MobileNetV2Space')] + rlnas = RLNAS(key='lstm', configs=config) + archs = rlnas.final_archs(10) + + .. py:methd:: tokens2arch(tokens) + + 通过一组tokens得到实际的模型结构,一般用来把搜索到最优的token转换为模型结构用来做最后的训练。tokens的形式是一个列表,tokens映射到搜索空间转换成相应的网络结构,一组tokens对应唯一的一个网络结构。 + + **参数:** + + - **tokens(list):** - 一组tokens。tokens的长度和范围取决于搜索空间。 + + **返回:** + 根据传入的token得到一个模型结构实例列表。 + + **示例代码:** + + .. code-block:: python + + import paddle.fluid as fluid + from paddleslim.nas import SANAS + config = [('MobileNetV2Space')] + rlnas = RLNAS(key='lstm', configs=config) + input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32') + tokens = ([0] * 25) + archs = sanas.tokens2arch(tokens)[0] + print(archs(input)) + diff --git a/paddleslim/common/RL_controller/DDPG/DDPGController.py b/paddleslim/common/RL_controller/DDPG/DDPGController.py new file mode 100644 index 00000000..d25f556f --- /dev/null +++ b/paddleslim/common/RL_controller/DDPG/DDPGController.py @@ -0,0 +1,157 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
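+
+"""DDPG-based RL controller used by RLNAS.
+
+The controller wraps a PARL DDPG agent: ``next_tokens`` predicts a continuous
+action in [-1, 1] (optionally perturbed by adaptive Gaussian noise) and maps it
+to integer tokens with ``action_mapping``; ``update`` stores the transition in
+a replay memory and trains the agent once enough samples have been collected.
+"""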
+ +import numpy as np +import parl +from parl import layers +from paddle import fluid +from ..utils import RLCONTROLLER, action_mapping +from ...controller import RLBaseController +from .ddpg_model import DefaultDDPGModel as default_ddpg_model +from .noise import AdaptiveNoiseSpec as default_noise +from parl.utils import ReplayMemory + +__all__ = ['DDPG'] + + +class DDPGAgent(parl.Agent): + def __init__(self, algorithm, obs_dim, act_dim): + assert isinstance(obs_dim, int) + assert isinstance(act_dim, int) + self.obs_dim = obs_dim + self.act_dim = act_dim + super(DDPGAgent, self).__init__(algorithm) + + # Attention: In the beginning, sync target model totally. + self.alg.sync_target(decay=0) + + def build_program(self): + self.pred_program = fluid.Program() + self.learn_program = fluid.Program() + + with fluid.program_guard(self.pred_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + self.pred_act = self.alg.predict(obs) + + with fluid.program_guard(self.learn_program): + obs = layers.data( + name='obs', shape=[self.obs_dim], dtype='float32') + act = layers.data( + name='act', shape=[self.act_dim], dtype='float32') + reward = layers.data(name='reward', shape=[], dtype='float32') + next_obs = layers.data( + name='next_obs', shape=[self.obs_dim], dtype='float32') + terminal = layers.data(name='terminal', shape=[], dtype='bool') + _, self.critic_cost = self.alg.learn(obs, act, reward, next_obs, + terminal) + + def predict(self, obs): + obs = np.expand_dims(obs, axis=0) + act = self.fluid_executor.run(self.pred_program, + feed={'obs': obs}, + fetch_list=[self.pred_act])[0] + return act + + def learn(self, obs, act, reward, next_obs, terminal): + feed = { + 'obs': obs, + 'act': act, + 'reward': reward, + 'next_obs': next_obs, + 'terminal': terminal + } + critic_cost = self.fluid_executor.run(self.learn_program, + feed=feed, + fetch_list=[self.critic_cost])[0] + self.alg.sync_target() + return critic_cost + + +@RLCONTROLLER.register +class DDPG(RLBaseController): + def __init__(self, range_tables, use_gpu=False, **kwargs): + self.use_gpu = use_gpu + self.range_tables = range_tables - np.asarray(1) + self.act_dim = len(self.range_tables) + self.obs_dim = kwargs.get('obs_dim') + self.model = kwargs.get( + 'model') if 'model' in kwargs else default_ddpg_model + self.actor_lr = kwargs.get( + 'actor_lr') if 'actor_lr' in kwargs else 1e-4 + self.critic_lr = kwargs.get( + 'critic_lr') if 'critic_lr' in kwargs else 1e-3 + self.gamma = kwargs.get('gamma') if 'gamma' in kwargs else 0.99 + self.tau = kwargs.get('tau') if 'tau' in kwargs else 0.001 + self.memory_size = kwargs.get( + 'memory_size') if 'memory_size' in kwargs else 10 + self.reward_scale = kwargs.get( + 'reward_scale') if 'reward_scale' in kwargs else 0.1 + self.batch_size = kwargs.get( + 'controller_batch_size') if 'controller_batch_size' in kwargs else 1 + self.actions_noise = kwargs.get( + 'actions_noise') if 'actions_noise' in kwargs else default_noise + self.action_dist = 0.0 + self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace() + + model = self.model(self.act_dim) + + if self.actions_noise: + self.actions_noise = self.actions_noise() + + algorithm = parl.algorithms.DDPG( + model, + gamma=self.gamma, + tau=self.tau, + actor_lr=self.actor_lr, + critic_lr=self.critic_lr) + self.agent = DDPGAgent(algorithm, self.obs_dim, self.act_dim) + self.rpm = ReplayMemory(self.memory_size, self.obs_dim, self.act_dim) + + self.pred_program = self.agent.pred_program + self.learn_program = 
self.agent.learn_program + self.param_dict = self.get_params(self.learn_program) + + def next_tokens(self, obs, params_dict, is_inference=False): + batch_obs = np.expand_dims(obs, axis=0) + self.set_params(self.pred_program, params_dict, self.place) + actions = self.agent.predict(batch_obs.astype('float32')) + ### add noise to action + if self.actions_noise and is_inference == False: + actions_noise = np.clip( + np.random.normal( + actions, scale=self.actions_noise.stdev_curr), + -1.0, + 1.0) + self.action_dist = np.mean(np.abs(actions_noise - actions)) + else: + actions_noise = actions + actions_noise = action_mapping(actions_noise, self.range_tables) + return actions_noise + + def _update_noise(self, actions_dist): + self.actions_noise.update(actions_dist) + + def update(self, rewards, params_dict, obs, actions, obs_next, terminal): + self.set_params(self.learn_program, params_dict, self.place) + self.rpm.append(obs, actions, self.reward_scale * rewards, obs_next, + terminal) + if self.actions_noise: + self._update_noise(self.action_dist) + if self.rpm.size() > self.memory_size: + obs, actions, rewards, obs_next, terminal = rpm.sample_batch( + self.batch_size) + self.agent.learn(obs, actions, rewards, obs_next, terminal) + params_dict = self.get_params(self.learn_program) + return params_dict diff --git a/paddleslim/common/RL_controller/DDPG/__init__.py b/paddleslim/common/RL_controller/DDPG/__init__.py new file mode 100644 index 00000000..d9fd1cc5 --- /dev/null +++ b/paddleslim/common/RL_controller/DDPG/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .DDPGController import * diff --git a/paddleslim/common/RL_controller/DDPG/ddpg_model.py b/paddleslim/common/RL_controller/DDPG/ddpg_model.py new file mode 100644 index 00000000..6607f8b4 --- /dev/null +++ b/paddleslim/common/RL_controller/DDPG/ddpg_model.py @@ -0,0 +1,67 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
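+
+"""Default actor-critic networks for the DDPG controller.
+
+``DefaultDDPGModel`` pairs an actor that emits a tanh-bounded action with a
+critic that scores (observation, action) pairs; both use two hidden
+fully-connected layers of size 400 and 300.
+"""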
+ +import paddle.fluid as fluid +import parl +from parl import layers + + +class DefaultDDPGModel(parl.Model): + def __init__(self, act_dim): + self.actor_model = ActorModel(act_dim) + self.critic_model = CriticModel() + + def policy(self, obs): + return self.actor_model.policy(obs) + + def value(self, obs, act): + return self.critic_model.value(obs, act) + + def get_actor_params(self): + return self.actor_model.parameters() + + +class ActorModel(parl.Model): + def __init__(self, act_dim): + hid1_size = 400 + hid2_size = 300 + + self.fc1 = layers.fc(size=hid1_size, act='relu') + self.fc2 = layers.fc(size=hid2_size, act='relu') + self.fc3 = layers.fc(size=act_dim, act='tanh') + + def policy(self, obs): + hid1 = self.fc1(obs) + hid2 = self.fc2(hid1) + means = self.fc3(hid2) + means = means + return means + + +class CriticModel(parl.Model): + def __init__(self): + hid1_size = 400 + hid2_size = 300 + + self.fc1 = layers.fc(size=hid1_size, act='relu') + self.fc2 = layers.fc(size=hid2_size, act='relu') + self.fc3 = layers.fc(size=1, act=None) + + def value(self, obs, act): + hid1 = self.fc1(obs) + concat = layers.concat([hid1, act], axis=1) + hid2 = self.fc2(concat) + Q = self.fc3(hid2) + Q = layers.squeeze(Q, axes=[1]) + return Q diff --git a/paddleslim/common/RL_controller/DDPG/noise.py b/paddleslim/common/RL_controller/DDPG/noise.py new file mode 100644 index 00000000..4efbf96d --- /dev/null +++ b/paddleslim/common/RL_controller/DDPG/noise.py @@ -0,0 +1,29 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = ['AdaptiveNoiseSpec'] + + +class AdaptiveNoiseSpec(object): + def __init__(self): + self.stdev_curr = 1.0 + + def reset(self): + self.stdev_curr = 1.0 + + def update(self, action_dist): + if action_dist > 1e-2: + self.stdev_curr /= 1.03 + else: + self.stdev_curr *= 1.03 diff --git a/paddleslim/common/RL_controller/LSTM/LSTM_Controller.py b/paddleslim/common/RL_controller/LSTM/LSTM_Controller.py new file mode 100644 index 00000000..30dd2907 --- /dev/null +++ b/paddleslim/common/RL_controller/LSTM/LSTM_Controller.py @@ -0,0 +1,281 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
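+
+"""LSTM-based RL controller used by RLNAS.
+
+A stacked LSTM produces one categorical distribution per position in
+``range_tables``: ``next_tokens`` samples each distribution (or takes the
+argmax in inference mode) to build token lists, and ``update`` trains the
+controller with a REINFORCE-style loss, ``-log_prob * (reward - baseline)``,
+where the baseline is an exponential moving average of past rewards.
+"""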
+ +import math +import logging +import numpy as np +import paddle.fluid as fluid +from paddle.fluid import ParamAttr +from paddle.fluid.layers import RNNCell, LSTMCell, rnn +from paddle.fluid.contrib.layers import basic_lstm +from ...controller import RLBaseController +from ...log_helper import get_logger +from ..utils import RLCONTROLLER + +_logger = get_logger(__name__, level=logging.INFO) + +uniform_initializer = lambda x: fluid.initializer.UniformInitializer(low=-x, high=x) + + +class lstm_cell(RNNCell): + def __init__(self, num_layers, hidden_size): + self.num_layers = num_layers + self.hidden_size = hidden_size + self.lstm_cells = [] + + param_attr = ParamAttr(initializer=uniform_initializer( + 1.0 / math.sqrt(hidden_size))) + bias_attr = ParamAttr(initializer=uniform_initializer( + 1.0 / math.sqrt(hidden_size))) + for i in range(num_layers): + self.lstm_cells.append( + LSTMCell(hidden_size, param_attr, bias_attr)) + + def call(self, inputs, states): + new_states = [] + for i in range(self.num_layers): + out, new_state = self.lstm_cells[i](inputs, states[i]) + new_states.append(new_state) + return out, new_states + + @property + def state_shape(self): + return [cell.state_shape for cell in self.lstm_cells] + + +@RLCONTROLLER.register +class LSTM(RLBaseController): + def __init__(self, range_tables, use_gpu=False, **kwargs): + self.use_gpu = use_gpu + self.range_tables = range_tables + self.lstm_num_layers = kwargs.get('lstm_num_layers') or 1 + self.hidden_size = kwargs.get('hidden_size') or 100 + self.temperature = kwargs.get('temperature') or None + self.tanh_constant = kwargs.get('tanh_constant') or None + self.decay = kwargs.get('decay') or 0.99 + self.weight_entropy = kwargs.get('weight_entropy') or None + self.controller_batch_size = kwargs.get('controller_batch_size') or 1 + + self.max_range_table = max(self.range_tables) + 1 + + self._create_parameter() + self._build_program() + + self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace() + self.exe = fluid.Executor(self.place) + self.exe.run(fluid.default_startup_program()) + + self.param_dict = self.get_params(self.learn_program) + + def _lstm(self, inputs, hidden, cell, token_idx): + cells = lstm_cell(self.lstm_num_layers, self.hidden_size) + output, new_states = cells.call(inputs, states=([[hidden, cell]])) + logits = fluid.layers.fc(new_states[0], self.range_tables[token_idx]) + + if self.temperature is not None: + logits = logits / self.temperature + if self.tanh_constant is not None: + logits = self.tanh_constant * fluid.layers.tanh(logits) + + return logits, output, new_states + + def _create_parameter(self): + self.emb_w = fluid.layers.create_parameter( + name='emb_w', + shape=(self.max_range_table, self.hidden_size), + dtype='float32', + default_initializer=uniform_initializer(1.0)) + + self.g_emb = fluid.layers.create_parameter( + name='emb_g', + shape=(self.controller_batch_size, self.hidden_size), + dtype='float32', + default_initializer=uniform_initializer(1.0)) + self.baseline = fluid.layers.create_global_var( + shape=[1], + value=0.0, + dtype='float32', + persistable=True, + name='baseline') + self.baseline.stop_gradient = True + + def _network(self, hidden, cell, init_actions=None, is_inference=False): + actions = [] + entropies = [] + sample_log_probs = [] + + with fluid.unique_name.guard('Controller'): + self._create_parameter() + inputs = self.g_emb + + for idx in range(len(self.range_tables)): + logits, output, states = self._lstm( + inputs, hidden, cell, token_idx=idx) + hidden, cell = 
np.squeeze(states) + probs = fluid.layers.softmax(logits, axis=1) + if is_inference: + action = fluid.layers.argmax(probs, axis=1) + else: + if init_actions: + action = fluid.layers.slice( + init_actions, + axes=[1], + starts=[idx], + ends=[idx + 1]) + action.stop_gradient = True + else: + action = fluid.layers.sampling_id(probs) + actions.append(action) + log_prob = fluid.layers.cross_entropy(probs, action) + sample_log_probs.append(log_prob) + + entropy = log_prob * fluid.layers.exp(-1 * log_prob) + entropy.stop_gradient = True + entropies.append(entropy) + + action_emb = fluid.layers.cast(action, dtype=np.int64) + inputs = fluid.layers.gather(self.emb_w, action_emb) + + sample_log_probs = fluid.layers.stack(sample_log_probs) + self.sample_log_probs = fluid.layers.reduce_sum(sample_log_probs) + + entropies = fluid.layers.stack(entropies) + self.sample_entropies = fluid.layers.reduce_sum(entropies) + + return actions + + def _build_program(self, is_inference=False): + self.pred_program = fluid.Program() + self.learn_program = fluid.Program() + with fluid.program_guard(self.pred_program): + self.g_emb = fluid.layers.create_parameter( + name='emb_g', + shape=(self.controller_batch_size, self.hidden_size), + dtype='float32', + default_initializer=uniform_initializer(1.0)) + + fluid.layers.assign( + fluid.layers.uniform_random(shape=self.g_emb.shape), + self.g_emb) + hidden = fluid.data(name='hidden', shape=[None, self.hidden_size]) + cell = fluid.data(name='cell', shape=[None, self.hidden_size]) + self.tokens = self._network( + hidden, cell, is_inference=is_inference) + + with fluid.program_guard(self.learn_program): + hidden = fluid.data(name='hidden', shape=[None, self.hidden_size]) + cell = fluid.data(name='cell', shape=[None, self.hidden_size]) + init_actions = fluid.data( + name='init_actions', + shape=[None, len(self.range_tables)], + dtype='int64') + self._network(hidden, cell, init_actions=init_actions) + + rewards = fluid.data(name='rewards', shape=[None]) + self.rewards = fluid.layers.reduce_mean(rewards) + + if self.weight_entropy is not None: + self.rewards += self.weight_entropy * self.sample_entropies + + self.sample_log_probs = fluid.layers.reduce_sum( + self.sample_log_probs) + + fluid.layers.assign(self.baseline - (1.0 - self.decay) * + (self.baseline - self.rewards), self.baseline) + self.loss = -1.0 * self.sample_log_probs * ( + self.rewards - self.baseline) + fluid.clip.set_gradient_clip( + clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)) + optimizer = fluid.optimizer.Adam(learning_rate=1e-3) + optimizer.minimize(self.loss) + + def _create_input(self, is_test=True, actual_rewards=None): + feed_dict = dict() + np_init_hidden = np.zeros( + (self.controller_batch_size, self.hidden_size)).astype('float32') + np_init_cell = np.zeros( + (self.controller_batch_size, self.hidden_size)).astype('float32') + + feed_dict["hidden"] = np_init_hidden + feed_dict["cell"] = np_init_cell + + if is_test == False: + if isinstance(actual_rewards, np.float32): + assert actual_rewards != None, "if you want to update controller, you must inputs a reward" + actual_rewards = np.expand_dims(actual_rewards, axis=0) + elif isinstance(actual_rewards, np.float) or isinstance( + actual_rewards, np.float64): + actual_rewards = np.float32(actual_rewards) + assert actual_rewards != None, "if you want to update controller, you must inputs a reward" + actual_rewards = np.expand_dims(actual_rewards, axis=0) + else: + assert actual_rewards.all( + ) != None, "if you want to update controller, you must 
inputs a reward" + actual_rewards = actual_rewards.astype(np.float32) + + feed_dict['rewards'] = actual_rewards + feed_dict['init_actions'] = np.array(self.init_tokens) + + return feed_dict + + def next_tokens(self, num_archs=1, params_dict=None, is_inference=False): + """ sample next tokens according current parameter and inputs""" + self.num_archs = num_archs + + self.set_params(self.pred_program, params_dict, self.place) + + batch_tokens = [] + feed_dict = self._create_input() + + for _ in range( + int(np.ceil(float(num_archs) / self.controller_batch_size))): + if is_inference: + self._build_program(is_inference=True) + + actions = self.exe.run(self.pred_program, + feed=feed_dict, + fetch_list=self.tokens) + + for idx in range(self.controller_batch_size): + each_token = {} + for i, action in enumerate(actions): + token = action[idx] + if idx in each_token: + each_token[idx].append(int(token)) + else: + each_token[idx] = [int(token)] + batch_tokens.append(each_token[idx]) + + self.init_tokens = batch_tokens + mod_token = (self.controller_batch_size - + (num_archs % self.controller_batch_size) + ) % self.controller_batch_size + if mod_token != 0: + return batch_tokens[:-mod_token] + else: + return batch_tokens + + def update(self, rewards, params_dict=None): + """train controller according reward""" + self.set_params(self.learn_program, params_dict, self.place) + + feed_dict = self._create_input(is_test=False, actual_rewards=rewards) + + loss = self.exe.run(self.learn_program, + feed=feed_dict, + fetch_list=[self.loss]) + _logger.info("Controller: current reward is {}, loss is {}".format( + rewards, loss)) + params_dict = self.get_params(self.learn_program) + return params_dict diff --git a/paddleslim/common/RL_controller/LSTM/__init__.py b/paddleslim/common/RL_controller/LSTM/__init__.py new file mode 100644 index 00000000..3acab8d0 --- /dev/null +++ b/paddleslim/common/RL_controller/LSTM/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .LSTM_Controller import * diff --git a/paddleslim/common/RL_controller/__init__.py b/paddleslim/common/RL_controller/__init__.py new file mode 100644 index 00000000..f2658c5f --- /dev/null +++ b/paddleslim/common/RL_controller/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
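+
+"""Registry of the built-in RL controllers.
+
+``LSTM`` is always available; ``DDPG`` depends on the optional ``parl`` package
+and is only registered when it can be imported.
+"""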
+ +import logging +from ..log_helper import get_logger +_logger = get_logger(__name__, level=logging.INFO) +try: + import parl + from .DDPG import * +except ImportError as e: + _logger.warn( + "If you want to use DDPG in RLNAS, please pip intall parl first. Now states: {}". + format(e)) + +from .LSTM import * +from .utils import * diff --git a/paddleslim/common/RL_controller/utils.py b/paddleslim/common/RL_controller/utils.py new file mode 100644 index 00000000..8363460c --- /dev/null +++ b/paddleslim/common/RL_controller/utils.py @@ -0,0 +1,54 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from ...core import Registry + +__all__ = [ + "RLCONTROLLER", "action_mapping", "add_grad", "compute_grad", + "ConnectMessage" +] + +RLCONTROLLER = Registry('RLController') + + +class ConnectMessage: + INIT = 'INIT' + INIT_DONE = 'INIT_DONE' + GET_WEIGHT = 'GET_WEIGHT' + UPDATE_WEIGHT = 'UPDATE_WEIGHT' + OK = 'OK' + WAIT = 'WAIT' + WAIT_PARAMS = 'WAIT_PARAMS' + EXIT = 'EXIT' + TIMEOUT = 10 + + +def action_mapping(actions, range_table): + actions = (actions - (-1.0)) * (range_table / np.asarray(2.0)) + return actions.astype('int64') + + +def add_grad(dict1, dict2): + dict3 = dict() + for key, value in dict1.items(): + dict3[key] = dict1[key] + dict2[key] + return dict3 + + +def compute_grad(dict1, dict2): + dict3 = dict() + for key, value in dict1.items(): + dict3[key] = dict1[key] - dict2[key] + return dict3 diff --git a/paddleslim/common/__init__.py b/paddleslim/common/__init__.py index e8649cbe..9a21bb77 100644 --- a/paddleslim/common/__init__.py +++ b/paddleslim/common/__init__.py @@ -18,9 +18,12 @@ from .controller_server import ControllerServer from .controller_client import ControllerClient from .lock import lock, unlock from .cached_reader import cached_reader +from .server import Server +from .client import Client from .meter import AvgrageMeter __all__ = [ 'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer', - 'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter' + 'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter', + 'Server', 'Client', 'RLBaseController' ] diff --git a/paddleslim/common/client.py b/paddleslim/common/client.py new file mode 100644 index 00000000..50ff4304 --- /dev/null +++ b/paddleslim/common/client.py @@ -0,0 +1,133 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
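+
+"""Client side of the distributed RLNAS controller.
+
+Each ``Client`` holds a local controller instance and talks to the ``Server``
+over ZeroMQ: ``next_tokens`` first pulls the latest parameters from the server,
+and ``update`` sends back the parameter difference computed by ``compute_grad``.
+In synchronous mode the client then polls a separate wait socket until the
+server has finished the synchronization round.
+"""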
+ +import os +import signal +import zmq +import socket +import logging +import time +import threading +import cloudpickle +from .log_helper import get_logger +from .RL_controller.utils import compute_grad, ConnectMessage + +_logger = get_logger(__name__, level=logging.INFO) + + +class Client(object): + def __init__(self, controller, address, client_name): + self._controller = controller + self._address = address + self._ip = self._address[0] + self._port = self._address[1] + self._client_name = client_name + self._params_dict = None + self.init_wait = False + self._connect_server() + + def _connect_server(self): + self._ctx = zmq.Context() + self._client_socket = self._ctx.socket(zmq.REQ) + ### NOTE: change the method to exit client when server is dead if there are better solutions + self._client_socket.setsockopt(zmq.RCVTIMEO, + ConnectMessage.TIMEOUT * 1000) + client_address = "{}:{}".format(self._ip, self._port) + self._client_socket.connect("tcp://{}".format(client_address)) + self._client_socket.send_multipart( + [ConnectMessage.INIT, self._client_name]) + message = self._client_socket.recv_multipart() + if message[0] != ConnectMessage.INIT_DONE: + _logger.error("Client {} init failure, Please start it again". + format(self._client_name)) + pid = os.getpid() + os.kill(pid, signal.SIGTERM) + _logger.info("Client {}: connect to server {}".format( + self._client_name, client_address)) + + def _connect_wait_socket(self, port): + self._wait_socket = self._ctx.socket(zmq.REQ) + wait_address = "{}:{}".format(self._ip, port) + self._wait_socket.connect("tcp://{}".format(wait_address)) + self._wait_socket.send_multipart( + [ConnectMessage.WAIT_PARAMS, self._client_name]) + message = self._wait_socket.recv_multipart() + return message[0] + + def next_tokens(self, obs, is_inference=False): + _logger.debug("Client: requests for weight {}".format( + self._client_name)) + self._client_socket.send_multipart( + [ConnectMessage.GET_WEIGHT, self._client_name]) + try: + message = self._client_socket.recv_multipart() + except zmq.error.Again as e: + _logger.error( + "CANNOT recv params from server in next_archs, Please check whether the server is alive!!! {}". + format(e)) + os._exit(0) + self._params_dict = cloudpickle.loads(message[0]) + tokens = self._controller.next_tokens( + obs, params_dict=self._params_dict, is_inference=is_inference) + _logger.debug("Client: client_name is {}, current token is {}".format( + self._client_name, tokens)) + return tokens + + def update(self, rewards, **kwargs): + assert self._params_dict != None, "Please call next_token to get token first, then call update" + current_params_dict = self._controller.update( + rewards, self._params_dict, **kwargs) + params_grad = compute_grad(self._params_dict, current_params_dict) + _logger.debug("Client: update weight {}".format(self._client_name)) + self._client_socket.send_multipart([ + ConnectMessage.UPDATE_WEIGHT, self._client_name, + cloudpickle.dumps(params_grad) + ]) + _logger.debug("Client: update done {}".format(self._client_name)) + + try: + message = self._client_socket.recv_multipart() + except zmq.error.Again as e: + _logger.error( + "CANNOT recv params from server in rewards, Please check whether the server is alive!!! {}". 
+ format(e)) + os._exit(0) + + if message[0] == ConnectMessage.WAIT: + _logger.debug("Client: self.init_wait: {}".format(self.init_wait)) + if not self.init_wait: + wait_port = cloudpickle.loads(message[1]) + wait_signal = self._connect_wait_socket(wait_port) + self.init_wait = True + else: + wait_signal = message[0] + while wait_signal != ConnectMessage.OK: + time.sleep(1) + self._wait_socket.send_multipart( + [ConnectMessage.WAIT_PARAMS, self._client_name]) + wait_signal = self._wait_socket.recv_multipart() + wait_signal = wait_signal[0] + _logger.debug("Client: {} {}".format(self._client_name, + wait_signal)) + + return message[0] + + def __del__(self): + try: + self._client_socket.send_multipart( + [ConnectMessage.EXIT, self._client_name]) + _ = self._client_socket.recv_multipart() + except: + pass + self._client_socket.close() diff --git a/paddleslim/common/controller.py b/paddleslim/common/controller.py index d06b6b88..87c887bd 100644 --- a/paddleslim/common/controller.py +++ b/paddleslim/common/controller.py @@ -16,8 +16,9 @@ import copy import math import numpy as np +import paddle.fluid as fluid -__all__ = ['EvolutionaryController'] +__all__ = ['EvolutionaryController', 'RLBaseController'] class EvolutionaryController(object): @@ -51,3 +52,31 @@ class EvolutionaryController(object): list: The next searched tokens. """ raise NotImplementedError('Abstract method.') + + +class RLBaseController(object): + """ Base Controller for reforcement learning""" + + def next_tokens(self, *args, **kwargs): + raise NotImplementedError('Abstract method.') + + def update(self, *args, **kwargs): + raise NotImplementedError('Abstract method.') + + def save_controller(self, program, output_dir): + fluid.save(program, output_dir) + + def load_controller(self, program, load_dir): + fluid.load(program, load_dir) + + def get_params(self, program): + var_dict = {} + for var in program.global_block().all_parameters(): + var_dict[var.name] = np.array(fluid.global_scope().find_var( + var.name).get_tensor()) + return var_dict + + def set_params(self, program, params_dict, place): + for var in program.global_block().all_parameters(): + fluid.global_scope().find_var(var.name).get_tensor().set( + params_dict[var.name], place) diff --git a/paddleslim/common/server.py b/paddleslim/common/server.py new file mode 100644 index 00000000..3a0dde68 --- /dev/null +++ b/paddleslim/common/server.py @@ -0,0 +1,211 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
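+
+"""Server side of the distributed RLNAS controller.
+
+The ``Server`` keeps the authoritative controller parameters and serves them to
+clients over ZeroMQ. In asynchronous mode each client's update is applied as
+soon as it arrives; in synchronous mode clients wait on a secondary socket
+until every registered client has reported its update. Parameters are
+checkpointed with pickle so a search can be resumed through ``load_controller``.
+"""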
+ +import zmq +import socket +import signal +import six +import os +if six.PY2: + import cPickle as pickle +else: + import pickle +import logging +import time +import threading +import cloudpickle +from .log_helper import get_logger +from .RL_controller.utils import add_grad, ConnectMessage + +_logger = get_logger(__name__, level=logging.INFO) + + +class Server(object): + def __init__(self, + controller, + address, + is_sync=False, + load_controller=None, + save_controller=None): + self._controller = controller + self._address = address + self._ip = self._address[0] + self._port = self._address[1] + self._is_sync = is_sync + self._done = False + self._load_controller = load_controller + self._save_controller = save_controller + ### key-value : client_name-update_times + self._client_dict = dict() + self._client = list() + self._lock = threading.Lock() + self._server_alive = True + self._max_update_times = 0 + + def close(self): + self._server_alive = False + _logger.info("server closed") + pid = os.getpid() + os.kill(pid, signal.SIGTERM) + + def start(self): + self._ctx = zmq.Context() + ### main socket + self._server_socket = self._ctx.socket(zmq.REP) + server_address = "{}:{}".format(self._ip, self._port) + self._server_socket.bind("tcp://{}".format(server_address)) + self._server_socket.linger = 0 + _logger.info("ControllerServer - listen on: [{}]".format( + server_address)) + thread = threading.Thread(target=self.run, args=()) + thread.setDaemon(True) + thread.start() + + if self._load_controller: + assert os.path.exists( + self._load_controller + ), "controller checkpoint is not exist, please check your directory: {}".format( + self._load_controller) + + with open( + os.path.join(self._load_controller, 'rlnas.params'), + 'rb') as f: + self._params_dict = pickle.load(f) + _logger.info("Load params done") + + else: + self._params_dict = self._controller.param_dict + + if self._is_sync: + self._wait_socket = self._ctx.socket(zmq.REP) + self._wait_port = self._wait_socket.bind_to_random_port( + addr="tcp://*") + self._wait_socket_linger = 0 + wait_thread = threading.Thread( + target=self._wait_for_params, args=()) + wait_thread.setDaemon(True) + wait_thread.start() + + def _wait_for_params(self): + try: + while self._server_alive: + message = self._wait_socket.recv_multipart() + cmd = message[0] + client_name = message[1] + if cmd == ConnectMessage.WAIT_PARAMS: + _logger.debug("Server: wait for params") + self._lock.acquire() + self._wait_socket.send_multipart([ + ConnectMessage.OK + if self._done else ConnectMessage.WAIT + ]) + if self._done and client_name in self._client: + self._client.remove(client_name) + if len(self._client) == 0: + self.save_params() + self._done = False + self._lock.release() + else: + _logger.error("Error message {}".format(message)) + raise NotImplementedError + except Exception as err: + logger.error(err) + + def run(self): + try: + while self._server_alive: + try: + sum_params_dict = dict() + message = self._server_socket.recv_multipart() + cmd = message[0] + client_name = message[1] + if cmd == ConnectMessage.INIT: + self._server_socket.send_multipart( + [ConnectMessage.INIT_DONE]) + _logger.debug("Server: init client {}".format( + client_name)) + self._client_dict[client_name] = 0 + elif cmd == ConnectMessage.GET_WEIGHT: + self._lock.acquire() + _logger.debug("Server: get weight {}".format( + client_name)) + self._server_socket.send_multipart( + [cloudpickle.dumps(self._params_dict)]) + _logger.debug("Server: send params done {}".format( + client_name)) + 
self._lock.release() + elif cmd == ConnectMessage.UPDATE_WEIGHT: + _logger.info("Server: update {}".format(client_name)) + params_dict_grad = cloudpickle.loads(message[2]) + if self._is_sync: + if not sum_params_dict: + sum_params_dict = self._params_dict + self._lock.acquire() + sum_params_dict = add_grad(sum_params_dict, + params_dict_grad) + self._client.append(client_name) + self._lock.release() + + if len(self._client) == len( + self._client_dict.items()): + self._done = True + + self._server_socket.send_multipart([ + ConnectMessage.WAIT, + cloudpickle.dumps(self._wait_port) + ]) + else: + self._lock.acquire() + self._params_dict = add_grad(self._params_dict, + params_dict_grad) + self._client_dict[client_name] += 1 + if self._client_dict[ + client_name] > self._max_update_times: + self._max_update_times = self._client_dict[ + client_name] + self._lock.release() + self.save_params() + self._server_socket.send_multipart( + [ConnectMessage.OK]) + + elif cmd == ConnectMessage.EXIT: + self._client_dict.pop(client_name) + if client_name in self._client: + self._client.remove(client_name) + self._server_socket.send_multipart( + [ConnectMessage.EXIT]) + except zmq.error.Again as e: + _logger.error(e) + self.close() + + except Exception as err: + _logger.error(err) + finally: + self._server_socket.close(0) + if self._is_sync: + self._wait_socket.close(0) + self.close() + + def save_params(self): + if self._save_controller: + if not os.path.exists(self._save_controller): + os.makedirs(self._save_controller) + output_dir = self._save_controller + else: + os.makedirs('./.rlnas_controller') + output_dir = './.rlnas_controller' + + with open(os.path.join(output_dir, 'rlnas.params'), 'wb') as f: + pickle.dump(self._params_dict, f) + _logger.info("Save params done") diff --git a/paddleslim/nas/__init__.py b/paddleslim/nas/__init__.py index 03ec7551..d438e54c 100644 --- a/paddleslim/nas/__init__.py +++ b/paddleslim/nas/__init__.py @@ -16,10 +16,12 @@ from ..nas import search_space from .search_space import * from ..nas import sa_nas from .sa_nas import * +from .rl_nas import * from ..nas import darts from .darts import * __all__ = [] __all__ += sa_nas.__all__ __all__ += search_space.__all__ +__all__ += rl_nas.__all__ __all__ += darts.__all__ diff --git a/paddleslim/nas/rl_nas.py b/paddleslim/nas/rl_nas.py new file mode 100644 index 00000000..5b84382d --- /dev/null +++ b/paddleslim/nas/rl_nas.py @@ -0,0 +1,162 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import socket +import logging +import numpy as np +import json +import hashlib +import time +import paddle.fluid as fluid +from ..common.RL_controller.utils import RLCONTROLLER +from ..common import get_logger + +from ..common import Server +from ..common import Client +from .search_space import SearchSpaceFactory + +_logger = get_logger(__name__, level=logging.INFO) + +__all__ = ['RLNAS'] + + +class RLNAS(object): + """ + Controller with Reinforcement Learning. 
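+
+    When ``is_server`` is True, a parameter ``Server`` is started in the
+    current process; every instance also creates a ``Client`` that pulls the
+    controller parameters from that server before sampling tokens.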
+ Args: + key(str): The actual reinforcement learning method. Current support in paddleslim is `LSTM` and `DDPG`. + configs(list): A list of search space configuration with format [(key, {input_size, + output_size, block_num, block_mask})]. `key` is the name of search space + with data type str. `input_size` and `output_size` are input size and + output size of searched sub-network. `block_num` is the number of blocks + in searched network, `block_mask` is a list consists by 0 and 1, 0 means + normal block, 1 means reduction block. + use_gpu(bool): Whether to use gpu in controller. Default: False. + server_addr(tuple): Server address, including ip and port of server. If ip is None or "", will + use host ip if is_server = True. Default: ("", 8881). + is_server(bool): Whether current host is controller server. Default: True. + is_sync(bool): Whether to update controller in synchronous mode. Default: False. + save_controller(str|None): The directory of controller to save, if set to None, not save checkpoint. + Default: None. + load_controller(str|None): The directory of controller to load, if set to None, not load checkpoint. + Default: None. + **kwargs: Additional keyword arguments. + """ + + def __init__(self, + key, + configs, + use_gpu=False, + server_addr=("", 8881), + is_server=True, + is_sync=False, + save_controller=None, + load_controller=None, + **kwargs): + if not is_server: + assert server_addr[ + 0] != "", "You should set the IP and port of server when is_server is False." + + self._configs = configs + factory = SearchSpaceFactory() + self._search_space = factory.get_search_space(configs) + self.range_tables = self._search_space.range_table() + self.save_controller = save_controller + self.load_controller = load_controller + + cls = RLCONTROLLER.get(key.upper()) + + server_ip, server_port = server_addr + if server_ip == None or server_ip == "": + server_ip = self._get_host_ip() + + self._controller = cls(range_tables=self.range_tables, + use_gpu=use_gpu, + **kwargs) + + if is_server: + max_client_num = 300 + self._controller_server = Server( + controller=self._controller, + address=(server_ip, server_port), + is_sync=is_sync, + save_controller=self.save_controller, + load_controller=self.load_controller) + self._controller_server.start() + + self._client_name = hashlib.md5( + str(time.time() + np.random.randint(1, 10000)).encode( + "utf-8")).hexdigest() + self._controller_client = Client( + controller=self._controller, + address=(server_ip, server_port), + client_name=self._client_name) + + self._current_tokens = None + + def _get_host_ip(self): + try: + return socket.gethostbyname(socket.gethostname()) + except: + return socket.gethostbyname('localhost') + + def next_archs(self, obs=None): + """ + Get next archs + Args: + obs(int|np.array): observations in env. + """ + archs = [] + self._current_tokens = self._controller_client.next_tokens(obs) + _logger.info("current tokens: {}".format(self._current_tokens)) + for token in self._current_tokens: + archs.append(self._search_space.token2arch(token)) + + return archs + + def reward(self, rewards, **kwargs): + """ + reward the score and to train controller + Args: + rewards(float|list): rewards get by tokens. + **kwargs: Additional keyword arguments. + """ + return self._controller_client.update(rewards, **kwargs) + + def final_archs(self, batch_obs): + """ + Get finally architecture + Args: + batch_obs(int|np.array): observations in env. 
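+        Returns:
+            list: A list of model architecture instances.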
+        """
+        final_tokens = self._controller_client.next_tokens(
+            batch_obs, is_inference=True)
+        _logger.info("Final tokens: {}".format(final_tokens))
+        archs = []
+        for token in final_tokens:
+            arch = self._search_space.token2arch(token)
+            archs.append(arch)
+
+        return archs
+
+    def tokens2arch(self, tokens):
+        """
+        Convert tokens to model architectures.
+        Args:
+            tokens(list): A list of tokens; their length and range depend on the search space.
+        Returns:
+            list: A list of model architecture instances corresponding to the tokens.
+        """
+        return self._search_space.token2arch(tokens)
--
GitLab