PaddlePaddle / PaddleSlim

Commit d21b2c6b (unverified)
Authored Apr 22, 2020 by ceci3; committed via GitHub on Apr 22, 2020

Add RL (#163)

Parent: 823ca6bb
Showing 16 changed files with 1394 additions and 3 deletions (+1394, −3).
docs/zh_cn/api_cn/custom_rl_controller.md                  +54   −0
docs/zh_cn/api_cn/nas_api.rst                              +153  −1
paddleslim/common/RL_controller/DDPG/DDPGController.py     +157  −0
paddleslim/common/RL_controller/DDPG/__init__.py           +15   −0
paddleslim/common/RL_controller/DDPG/ddpg_model.py         +67   −0
paddleslim/common/RL_controller/DDPG/noise.py              +29   −0
paddleslim/common/RL_controller/LSTM/LSTM_Controller.py    +281  −0
paddleslim/common/RL_controller/LSTM/__init__.py           +15   −0
paddleslim/common/RL_controller/__init__.py                +27   −0
paddleslim/common/RL_controller/utils.py                   +54   −0
paddleslim/common/__init__.py                              +4    −1
paddleslim/common/client.py                                +133  −0
paddleslim/common/controller.py                            +30   −1
paddleslim/common/server.py                                +211  −0
paddleslim/nas/__init__.py                                 +2    −0
paddleslim/nas/rl_nas.py                                   +162  −0
docs/zh_cn/api_cn/custom_rl_controller.md (new file, mode 100644)

# How to Define a Custom Reinforcement-Learning Controller

First, import the required dependencies:

```python
### import the RL controller base class and the registration class
from paddleslim.common.RL_controller.utils import RLCONTROLLER
from paddleslim.common.RL_controller import RLBaseController
```

Register the custom RL controller with PaddleSlim via the decorator. After inheriting from the base class, you must override its `next_tokens` and `update` functions. Note: this example only illustrates the indispensable steps and cannot be run directly; for the complete code, see [here]().
```python
### Note: the class name must be in all capitals
@RLCONTROLLER.register
class LSTM(RLBaseController):
    def __init__(self, range_tables, use_gpu=False, **kwargs):
        ### range_tables is the value range of the tokens
        self.range_tables = range_tables
        ### use_gpu indicates whether to train the controller on GPU
        self.use_gpu = use_gpu
        ### define the parameters needed by the RL algorithm
        ...
        ### build the programs; _build_program constructs two programs, a
        ### pred_program and a learn_program, and initializes their parameters
        self._build_program()
        self.place = fluid.CUDAPlace(0) if self.args.use_gpu else fluid.CPUPlace()
        self.exe = fluid.Executor(self.place)
        self.exe.run(fluid.default_startup_program())

        ### store the parameters in a dict that the server maintains and
        ### updates centrally; this step is essential because several clients
        ### may update one copy of the parameters at the same time. Since
        ### pred_program and learn_program share the same parameters, only the
        ### parameters of learn_program need to go into the dict
        self.param_dicts = {}
        self.param_dicts.update(
            {self.learn_program: self.get_params(self.learn_program)})

    def next_tokens(self, states, params_dict):
        ### assign the parameter dict fetched from the server to the program used here
        self.set_params(self.pred_program, params_dict, self.place)
        ### build the input from states
        self.num_archs = states
        feed_dict = self._create_input()
        ### fetch the current tokens
        actions = self.exe.run(self.pred_program,
                               feed=feed_dict,
                               fetch_list=self.tokens)
        ...
        return actions

    def update(self, rewards, params_dict=None):
        ### assign the parameter dict fetched from the server to the program used here
        self.set_params(self.learn_program, params_dict, self.place)
        ### build the input from the states of `next_tokens` and the rewards of `update`
        feed_dict = self._create_input(is_test=False, actual_rewards=rewards)
        ### compute the loss of the current step
        loss = self.exe.run(self.learn_program,
                            feed=feed_dict,
                            fetch_list=[self.loss])
        ### fetch the parameters of the current program and return them; the
        ### client passes this round's parameters to the server for the update
        params_dict = self.get_params(self.learn_program)
        return params_dict
```
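Once registered, the custom controller can be selected by name through RLNAS, which upper-cases `key` and looks it up in the `RLCONTROLLER` registry. A minimal hedged sketch:

```python
from paddleslim.nas import RLNAS

### 'lstm' is upper-cased internally and resolved through the RLCONTROLLER
### registry, so it selects the LSTM class registered above.
config = [('MobileNetV2Space')]
rlnas = RLNAS(key='lstm', configs=config)
```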
docs/zh_cn/api_cn/nas_api.rst

SA-NAS
========

Configuration of search-space parameters

...
...

@@ -160,3 +160,155 @@ SANAS (Simulated Annealing Neural Architecture Search) is based on simulated annealing

    sanas = SANAS(configs=config)
    print(sanas.current_info())
RLNAS
------

.. py:class:: paddleslim.nas.RLNAS(key, configs, use_gpu=False, server_addr=("", 8881), is_server=True, is_sync=False, save_controller=None, load_controller=None, **kwargs)

`Source code <> `_

RLNAS (Reinforcement Learning Neural Architecture Search) searches for model architectures with a reinforcement-learning algorithm.

- **key<str>** - Name of the RL controller to use; PaddleSlim currently supports `LSTM` and `DDPG`. For a custom RL controller, see `Custom RL Controller <> `_.
- **configs(list<tuple>)** - A list of search-space configurations in the form ``[(key, {input_size, output_size, block_num, block_mask})]`` or ``[(key)]`` (the MobileNetV2, MobileNetV1 and ResNet search spaces mirror the original network structure, so only ``key`` is required). ``input_size`` and ``output_size`` are the sizes of the input and output feature maps, ``block_num`` is the number of blocks in the searched network, and ``block_mask`` is a list of 0s and 1s in which 0 marks a block without downsampling and 1 marks a downsampling block. For more search-space configurations provided by PaddleSlim, see [Search Space](../search_space.md).
- **use_gpu(bool)** - Whether to train the controller on GPU. Default: False.
- **server_addr(tuple)** - Address of the RLNAS controller, i.e. the server's IP address and port; if the IP address is None or "", the local IP is used. Default: ("", 8881).
- **is_server(bool)** - Whether this instance should start a server. Default: True.
- **is_sync(bool)** - Whether to update the controller in synchronous mode; this only makes a difference with multiple clients. Default: False.
- **save_controller(str|None)** - Directory in which to save controller checkpoints; if None, no checkpoint is saved. Default: None.
- **load_controller(str|None)** - Directory from which to load controller checkpoints; if None, no checkpoint is loaded. Default: None.
- **\*\*kwargs** - Additional arguments determined by the specific RL algorithm; see the note below for the extra arguments of `LSTM` and `DDPG`.

.. note::
  Additional arguments of the `LSTM` algorithm:

  - lstm_num_layers(int, optional): Number of stacked LSTM layers in the controller. Default: 1.
  - hidden_size(int, optional): Hidden size of the LSTM. Default: 100.
  - temperature(float, optional): Temperature applied to the logits when computing each token; None disables it. Default: None.
  - tanh_constant(float, optional): If set, applies a tanh activation to the logits of each token and multiplies the result by `tanh_constant`. Default: None.
  - decay(float, optional): Smoothing rate of the rewards baseline kept by the LSTM. Default: 0.99.
  - weight_entropy(float, optional): If set, adds the weighted cross-entropy accumulated while computing tokens to the received rewards when updating the controller parameters. Default: None.
  - controller_batch_size(int, optional): Batch size of the controller, i.e. how many tokens one controller run returns. Default: 1.

  Additional arguments of the `DDPG` algorithm:

  Note: the `DDPG` algorithm requires parl. Install it with: pip install parl

  - obs_dim(int): Dimension of the observation.
  - model(class, optional): The model used by the DDPG algorithm, normally a class containing an actor_model and a critic_model; it must implement two methods, policy to obtain the policy and value to obtain the Q value. See the default `model <>_` to implement your own. Default: `default_ddpg_model`.
  - actor_lr(float, optional): Learning rate of the actor network. Default: 1e-4.
  - critic_lr(float, optional): Learning rate of the critic network. Default: 1e-3.
  - gamma(float, optional): Discount factor applied to received rewards. Default: 0.99.
  - tau(float, optional): Rate at which model parameters are accumulated into the target_model in DDPG. Default: 0.001.
  - memory_size(int, optional): Size of the replay memory that records history in DDPG. Default: 10.
  - reward_scale(float, optional): Scale factor applied to rewards when they are recorded. Default: 0.1.
  - controller_batch_size(int, optional): Batch size of the controller, i.e. how many tokens one controller run returns. Default: 1.
  - actions_noise(class, optional): Noise added to the actions obtained from DDPG; set to False or None to disable it. Default: default_noise.
..

**Returns:**
An instance of the RLNAS class.

**Example:**

.. code-block:: python

  from paddleslim.nas import RLNAS
  config = [('MobileNetV2Space')]
  rlnas = RLNAS(key='lstm', configs=config)

.. py:method:: next_archs(obs=None)

Get the next batch of model architectures.

**Parameters:**

- **obs<int|np.array>** - The number of architectures to fetch, or the current observations.

**Returns:**
A list of model-architecture instances.

**Example:**

.. code-block:: python

  import paddle.fluid as fluid
  from paddleslim.nas import RLNAS
  config = [('MobileNetV2Space')]
  rlnas = RLNAS(key='lstm', configs=config)
  input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
  archs = rlnas.next_archs(1)
  for arch in archs:
      output = arch(input)
      input = output
  print(output)

.. py:method:: reward(rewards, **kwargs)

Pass the rewards of the current architectures back to the controller.

**Parameters:**

- **rewards<float|list<float>>** - Rewards of the current architectures; higher is better.
- **\*\*kwargs** - Additional arguments determined by the specific RL algorithm.

**Example:**

.. code-block:: python

  import paddle.fluid as fluid
  from paddleslim.nas import RLNAS
  config = [('MobileNetV2Space')]
  rlnas = RLNAS(key='lstm', configs=config)
  rlnas.next_archs(1)
  rlnas.reward(1.0)

.. note::
  The reward step must be executed after `next_token`.
..

.. py:method:: final_archs(batch_obs)

Get the final model architectures. Typically, once the controller has finished training, a few dozen architectures are fetched for full experiments.

**Parameters:**

- **obs<int|np.array>** - The number of architectures to fetch, or the current observations.

**Returns:**
A list of model-architecture instances.

**Example:**

.. code-block:: python

  import paddle.fluid as fluid
  from paddleslim.nas import RLNAS
  config = [('MobileNetV2Space')]
  rlnas = RLNAS(key='lstm', configs=config)
  archs = rlnas.final_archs(10)

.. py:method:: tokens2arch(tokens)

Convert a set of tokens into a concrete model architecture; this is typically used to turn the best searched tokens into an architecture for final training. Tokens take the form of a list that is mapped through the search space into a network structure; one set of tokens corresponds to exactly one network structure.

**Parameters:**

- **tokens(list)** - A list of tokens whose length and range depend on the search space.

**Returns:**
A list of model-architecture instances built from the given tokens.

**Example:**

.. code-block:: python

  import paddle.fluid as fluid
  from paddleslim.nas import RLNAS
  config = [('MobileNetV2Space')]
  rlnas = RLNAS(key='lstm', configs=config)
  input = fluid.data(name='input', shape=[None, 3, 32, 32], dtype='float32')
  tokens = ([0] * 25)
  archs = rlnas.tokens2arch(tokens)[0]
  print(archs(input))
paddleslim/common/RL_controller/DDPG/DDPGController.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import parl
from parl import layers
from paddle import fluid
from ..utils import RLCONTROLLER, action_mapping
from ...controller import RLBaseController
from .ddpg_model import DefaultDDPGModel as default_ddpg_model
from .noise import AdaptiveNoiseSpec as default_noise
from parl.utils import ReplayMemory

__all__ = ['DDPG']


class DDPGAgent(parl.Agent):
    def __init__(self, algorithm, obs_dim, act_dim):
        assert isinstance(obs_dim, int)
        assert isinstance(act_dim, int)
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        super(DDPGAgent, self).__init__(algorithm)

        # Attention: In the beginning, sync target model totally.
        self.alg.sync_target(decay=0)

    def build_program(self):
        self.pred_program = fluid.Program()
        self.learn_program = fluid.Program()

        with fluid.program_guard(self.pred_program):
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            self.pred_act = self.alg.predict(obs)

        with fluid.program_guard(self.learn_program):
            obs = layers.data(
                name='obs', shape=[self.obs_dim], dtype='float32')
            act = layers.data(
                name='act', shape=[self.act_dim], dtype='float32')
            reward = layers.data(name='reward', shape=[], dtype='float32')
            next_obs = layers.data(
                name='next_obs', shape=[self.obs_dim], dtype='float32')
            terminal = layers.data(name='terminal', shape=[], dtype='bool')
            _, self.critic_cost = self.alg.learn(obs, act, reward, next_obs,
                                                 terminal)

    def predict(self, obs):
        obs = np.expand_dims(obs, axis=0)
        act = self.fluid_executor.run(
            self.pred_program, feed={'obs': obs},
            fetch_list=[self.pred_act])[0]
        return act

    def learn(self, obs, act, reward, next_obs, terminal):
        feed = {
            'obs': obs,
            'act': act,
            'reward': reward,
            'next_obs': next_obs,
            'terminal': terminal
        }
        critic_cost = self.fluid_executor.run(
            self.learn_program, feed=feed, fetch_list=[self.critic_cost])[0]
        self.alg.sync_target()
        return critic_cost


@RLCONTROLLER.register
class DDPG(RLBaseController):
    def __init__(self, range_tables, use_gpu=False, **kwargs):
        self.use_gpu = use_gpu
        self.range_tables = range_tables - np.asarray(1)
        self.act_dim = len(self.range_tables)
        self.obs_dim = kwargs.get('obs_dim')
        self.model = kwargs.get('model') if 'model' in kwargs else default_ddpg_model
        self.actor_lr = kwargs.get('actor_lr') if 'actor_lr' in kwargs else 1e-4
        self.critic_lr = kwargs.get('critic_lr') if 'critic_lr' in kwargs else 1e-3
        self.gamma = kwargs.get('gamma') if 'gamma' in kwargs else 0.99
        self.tau = kwargs.get('tau') if 'tau' in kwargs else 0.001
        self.memory_size = kwargs.get('memory_size') if 'memory_size' in kwargs else 10
        self.reward_scale = kwargs.get('reward_scale') if 'reward_scale' in kwargs else 0.1
        self.batch_size = kwargs.get('controller_batch_size') if 'controller_batch_size' in kwargs else 1
        self.actions_noise = kwargs.get('actions_noise') if 'actions_noise' in kwargs else default_noise
        self.action_dist = 0.0
        self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace()

        model = self.model(self.act_dim)

        if self.actions_noise:
            self.actions_noise = self.actions_noise()

        algorithm = parl.algorithms.DDPG(
            model,
            gamma=self.gamma,
            tau=self.tau,
            actor_lr=self.actor_lr,
            critic_lr=self.critic_lr)
        self.agent = DDPGAgent(algorithm, self.obs_dim, self.act_dim)
        self.rpm = ReplayMemory(self.memory_size, self.obs_dim, self.act_dim)

        self.pred_program = self.agent.pred_program
        self.learn_program = self.agent.learn_program
        self.param_dict = self.get_params(self.learn_program)

    def next_tokens(self, obs, params_dict, is_inference=False):
        batch_obs = np.expand_dims(obs, axis=0)
        self.set_params(self.pred_program, params_dict, self.place)
        actions = self.agent.predict(batch_obs.astype('float32'))
        ### add noise to action
        if self.actions_noise and is_inference == False:
            actions_noise = np.clip(
                np.random.normal(
                    actions, scale=self.actions_noise.stdev_curr), -1.0, 1.0)
            self.action_dist = np.mean(np.abs(actions_noise - actions))
        else:
            actions_noise = actions
        actions_noise = action_mapping(actions_noise, self.range_tables)
        return actions_noise

    def _update_noise(self, actions_dist):
        self.actions_noise.update(actions_dist)

    def update(self, rewards, params_dict, obs, actions, obs_next, terminal):
        self.set_params(self.learn_program, params_dict, self.place)
        self.rpm.append(obs, actions, self.reward_scale * rewards, obs_next,
                        terminal)
        if self.actions_noise:
            self._update_noise(self.action_dist)
        if self.rpm.size() > self.memory_size:
            obs, actions, rewards, obs_next, terminal = self.rpm.sample_batch(
                self.batch_size)
            self.agent.learn(obs, actions, rewards, obs_next, terminal)
        params_dict = self.get_params(self.learn_program)
        return params_dict
```
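For reference, a hedged sketch of driving this controller through RLNAS. `obs_dim` is mandatory since `DDPG.__init__` has no default for it, and `update` additionally expects `obs`, `actions`, `obs_next` and `terminal`, which arrive through `reward(**kwargs)`; the observation below is a made-up placeholder.

```python
import numpy as np
from paddleslim.nas import RLNAS

config = [('MobileNetV2Space')]
# obs_dim=10 is an illustrative choice; it must match your observation vector.
rlnas = RLNAS(key='ddpg', configs=config, obs_dim=10)

obs = np.random.random(10).astype('float32')
archs = rlnas.next_archs(obs=obs)  # forwards obs to DDPG.next_tokens
# rlnas.reward(score, obs=obs, actions=..., obs_next=..., terminal=...)  # -> DDPG.update
```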
paddleslim/common/RL_controller/DDPG/__init__.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .DDPGController import *
```
paddleslim/common/RL_controller/DDPG/ddpg_model.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.fluid as fluid
import parl
from parl import layers


class DefaultDDPGModel(parl.Model):
    def __init__(self, act_dim):
        self.actor_model = ActorModel(act_dim)
        self.critic_model = CriticModel()

    def policy(self, obs):
        return self.actor_model.policy(obs)

    def value(self, obs, act):
        return self.critic_model.value(obs, act)

    def get_actor_params(self):
        return self.actor_model.parameters()


class ActorModel(parl.Model):
    def __init__(self, act_dim):
        hid1_size = 400
        hid2_size = 300

        self.fc1 = layers.fc(size=hid1_size, act='relu')
        self.fc2 = layers.fc(size=hid2_size, act='relu')
        self.fc3 = layers.fc(size=act_dim, act='tanh')

    def policy(self, obs):
        hid1 = self.fc1(obs)
        hid2 = self.fc2(hid1)
        means = self.fc3(hid2)
        return means


class CriticModel(parl.Model):
    def __init__(self):
        hid1_size = 400
        hid2_size = 300

        self.fc1 = layers.fc(size=hid1_size, act='relu')
        self.fc2 = layers.fc(size=hid2_size, act='relu')
        self.fc3 = layers.fc(size=1, act=None)

    def value(self, obs, act):
        hid1 = self.fc1(obs)
        concat = layers.concat([hid1, act], axis=1)
        hid2 = self.fc2(concat)
        Q = self.fc3(hid2)
        Q = layers.squeeze(Q, axes=[1])
        return Q
```
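The `model` kwarg of the DDPG controller accepts any class with the same interface. A minimal hedged sketch of a custom replacement, reusing only parl calls that appear in the file above; the class names and layer sizes are illustrative, not from the source:

```python
import parl
from parl import layers


class TinyDDPGModel(parl.Model):
    """Hypothetical drop-in for DefaultDDPGModel: same policy/value/
    get_actor_params interface, smaller hidden layers."""

    def __init__(self, act_dim):
        self.actor_model = _TinyActor(act_dim)
        self.critic_model = _TinyCritic()

    def policy(self, obs):
        return self.actor_model.policy(obs)

    def value(self, obs, act):
        return self.critic_model.value(obs, act)

    def get_actor_params(self):
        return self.actor_model.parameters()


class _TinyActor(parl.Model):
    def __init__(self, act_dim):
        self.fc1 = layers.fc(size=64, act='relu')
        self.fc2 = layers.fc(size=act_dim, act='tanh')  # actions in [-1, 1]

    def policy(self, obs):
        return self.fc2(self.fc1(obs))


class _TinyCritic(parl.Model):
    def __init__(self):
        self.fc1 = layers.fc(size=64, act='relu')
        self.fc2 = layers.fc(size=1, act=None)

    def value(self, obs, act):
        hid = self.fc1(layers.concat([obs, act], axis=1))
        return layers.squeeze(self.fc2(hid), axes=[1])


# Usage (hypothetical): RLNAS(key='ddpg', configs=config, obs_dim=10, model=TinyDDPGModel)
```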
paddleslim/common/RL_controller/DDPG/noise.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['AdaptiveNoiseSpec']


class AdaptiveNoiseSpec(object):
    """Adaptive exploration noise: the stddev shrinks while the noisy actions
    stray noticeably from the deterministic ones, and grows once they agree."""

    def __init__(self):
        self.stdev_curr = 1.0

    def reset(self):
        self.stdev_curr = 1.0

    def update(self, action_dist):
        # action_dist is the mean absolute distance between the noisy and
        # noise-free actions of the last step (see DDPG.next_tokens).
        if action_dist > 1e-2:
            self.stdev_curr /= 1.03
        else:
            self.stdev_curr *= 1.03
```
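As a sanity check of the adaptation rule above, a short hedged example of how the stddev reacts to the two regimes:

```python
noise = AdaptiveNoiseSpec()
noise.update(0.5)    # large deviation -> stdev shrinks: 1.0 / 1.03 ~= 0.971
noise.update(1e-3)   # tiny deviation  -> stdev grows back by 3%
print(noise.stdev_curr)
```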
paddleslim/common/RL_controller/LSTM/LSTM_Controller.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import logging
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from paddle.fluid.layers import RNNCell, LSTMCell, rnn
from paddle.fluid.contrib.layers import basic_lstm
from ...controller import RLBaseController
from ...log_helper import get_logger
from ..utils import RLCONTROLLER

_logger = get_logger(__name__, level=logging.INFO)

uniform_initializer = lambda x: fluid.initializer.UniformInitializer(low=-x, high=x)


class lstm_cell(RNNCell):
    def __init__(self, num_layers, hidden_size):
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lstm_cells = []

        param_attr = ParamAttr(initializer=uniform_initializer(
            1.0 / math.sqrt(hidden_size)))
        bias_attr = ParamAttr(initializer=uniform_initializer(
            1.0 / math.sqrt(hidden_size)))
        for i in range(num_layers):
            self.lstm_cells.append(
                LSTMCell(hidden_size, param_attr, bias_attr))

    def call(self, inputs, states):
        new_states = []
        for i in range(self.num_layers):
            out, new_state = self.lstm_cells[i](inputs, states[i])
            new_states.append(new_state)
        return out, new_states

    @property
    def state_shape(self):
        return [cell.state_shape for cell in self.lstm_cells]


@RLCONTROLLER.register
class LSTM(RLBaseController):
    def __init__(self, range_tables, use_gpu=False, **kwargs):
        self.use_gpu = use_gpu
        self.range_tables = range_tables
        self.lstm_num_layers = kwargs.get('lstm_num_layers') or 1
        self.hidden_size = kwargs.get('hidden_size') or 100
        self.temperature = kwargs.get('temperature') or None
        self.tanh_constant = kwargs.get('tanh_constant') or None
        self.decay = kwargs.get('decay') or 0.99
        self.weight_entropy = kwargs.get('weight_entropy') or None
        self.controller_batch_size = kwargs.get('controller_batch_size') or 1

        self.max_range_table = max(self.range_tables) + 1

        self._create_parameter()
        self._build_program()

        self.place = fluid.CUDAPlace(0) if self.use_gpu else fluid.CPUPlace()
        self.exe = fluid.Executor(self.place)
        self.exe.run(fluid.default_startup_program())

        self.param_dict = self.get_params(self.learn_program)

    def _lstm(self, inputs, hidden, cell, token_idx):
        cells = lstm_cell(self.lstm_num_layers, self.hidden_size)
        output, new_states = cells.call(inputs, states=([[hidden, cell]]))
        logits = fluid.layers.fc(new_states[0], self.range_tables[token_idx])

        if self.temperature is not None:
            logits = logits / self.temperature
        if self.tanh_constant is not None:
            logits = self.tanh_constant * fluid.layers.tanh(logits)

        return logits, output, new_states

    def _create_parameter(self):
        self.emb_w = fluid.layers.create_parameter(
            name='emb_w',
            shape=(self.max_range_table, self.hidden_size),
            dtype='float32',
            default_initializer=uniform_initializer(1.0))
        self.g_emb = fluid.layers.create_parameter(
            name='emb_g',
            shape=(self.controller_batch_size, self.hidden_size),
            dtype='float32',
            default_initializer=uniform_initializer(1.0))
        self.baseline = fluid.layers.create_global_var(
            shape=[1],
            value=0.0,
            dtype='float32',
            persistable=True,
            name='baseline')
        self.baseline.stop_gradient = True

    def _network(self, hidden, cell, init_actions=None, is_inference=False):
        actions = []
        entropies = []
        sample_log_probs = []

        with fluid.unique_name.guard('Controller'):
            self._create_parameter()
            inputs = self.g_emb

            for idx in range(len(self.range_tables)):
                logits, output, states = self._lstm(
                    inputs, hidden, cell, token_idx=idx)
                hidden, cell = np.squeeze(states)
                probs = fluid.layers.softmax(logits, axis=1)
                if is_inference:
                    action = fluid.layers.argmax(probs, axis=1)
                else:
                    if init_actions:
                        action = fluid.layers.slice(
                            init_actions, axes=[1], starts=[idx],
                            ends=[idx + 1])
                        action.stop_gradient = True
                    else:
                        action = fluid.layers.sampling_id(probs)
                actions.append(action)
                log_prob = fluid.layers.cross_entropy(probs, action)
                sample_log_probs.append(log_prob)

                entropy = log_prob * fluid.layers.exp(-1 * log_prob)
                entropy.stop_gradient = True
                entropies.append(entropy)

                action_emb = fluid.layers.cast(action, dtype=np.int64)
                inputs = fluid.layers.gather(self.emb_w, action_emb)

            sample_log_probs = fluid.layers.stack(sample_log_probs)
            self.sample_log_probs = fluid.layers.reduce_sum(sample_log_probs)

            entropies = fluid.layers.stack(entropies)
            self.sample_entropies = fluid.layers.reduce_sum(entropies)

        return actions

    def _build_program(self, is_inference=False):
        self.pred_program = fluid.Program()
        self.learn_program = fluid.Program()
        with fluid.program_guard(self.pred_program):
            self.g_emb = fluid.layers.create_parameter(
                name='emb_g',
                shape=(self.controller_batch_size, self.hidden_size),
                dtype='float32',
                default_initializer=uniform_initializer(1.0))

            fluid.layers.assign(
                fluid.layers.uniform_random(shape=self.g_emb.shape),
                self.g_emb)
            hidden = fluid.data(name='hidden', shape=[None, self.hidden_size])
            cell = fluid.data(name='cell', shape=[None, self.hidden_size])
            self.tokens = self._network(
                hidden, cell, is_inference=is_inference)

        with fluid.program_guard(self.learn_program):
            hidden = fluid.data(name='hidden', shape=[None, self.hidden_size])
            cell = fluid.data(name='cell', shape=[None, self.hidden_size])
            init_actions = fluid.data(
                name='init_actions',
                shape=[None, len(self.range_tables)],
                dtype='int64')
            self._network(hidden, cell, init_actions=init_actions)

            rewards = fluid.data(name='rewards', shape=[None])
            self.rewards = fluid.layers.reduce_mean(rewards)

            if self.weight_entropy is not None:
                self.rewards += self.weight_entropy * self.sample_entropies

            self.sample_log_probs = fluid.layers.reduce_sum(
                self.sample_log_probs)

            fluid.layers.assign(self.baseline - (1.0 - self.decay) *
                                (self.baseline - self.rewards), self.baseline)
            self.loss = -1.0 * self.sample_log_probs * (
                self.rewards - self.baseline)

            fluid.clip.set_gradient_clip(
                clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))
            optimizer = fluid.optimizer.Adam(learning_rate=1e-3)
            optimizer.minimize(self.loss)

    def _create_input(self, is_test=True, actual_rewards=None):
        feed_dict = dict()

        np_init_hidden = np.zeros(
            (self.controller_batch_size, self.hidden_size)).astype('float32')
        np_init_cell = np.zeros(
            (self.controller_batch_size, self.hidden_size)).astype('float32')

        feed_dict["hidden"] = np_init_hidden
        feed_dict["cell"] = np_init_cell

        if is_test == False:
            if isinstance(actual_rewards, np.float32):
                assert actual_rewards != None, "if you want to update the controller, you must input a reward"
                actual_rewards = np.expand_dims(actual_rewards, axis=0)
            elif isinstance(actual_rewards, np.float) or isinstance(
                    actual_rewards, np.float64):
                actual_rewards = np.float32(actual_rewards)
                assert actual_rewards != None, "if you want to update the controller, you must input a reward"
                actual_rewards = np.expand_dims(actual_rewards, axis=0)
            else:
                assert actual_rewards.all() != None, "if you want to update the controller, you must input a reward"
                actual_rewards = actual_rewards.astype(np.float32)

            feed_dict['rewards'] = actual_rewards
            feed_dict['init_actions'] = np.array(self.init_tokens)

        return feed_dict

    def next_tokens(self, num_archs=1, params_dict=None, is_inference=False):
        """Sample the next tokens according to the current parameters and inputs."""
        self.num_archs = num_archs

        self.set_params(self.pred_program, params_dict, self.place)

        batch_tokens = []
        feed_dict = self._create_input()

        for _ in range(
                int(np.ceil(float(num_archs) / self.controller_batch_size))):
            if is_inference:
                self._build_program(is_inference=True)

            actions = self.exe.run(self.pred_program,
                                   feed=feed_dict,
                                   fetch_list=self.tokens)

            for idx in range(self.controller_batch_size):
                each_token = {}
                for i, action in enumerate(actions):
                    token = action[idx]
                    if idx in each_token:
                        each_token[idx].append(int(token))
                    else:
                        each_token[idx] = [int(token)]
                batch_tokens.append(each_token[idx])

        self.init_tokens = batch_tokens
        mod_token = (self.controller_batch_size -
                     (num_archs % self.controller_batch_size)
                     ) % self.controller_batch_size
        if mod_token != 0:
            return batch_tokens[:-mod_token]
        else:
            return batch_tokens

    def update(self, rewards, params_dict=None):
        """Train the controller according to the reward."""
        self.set_params(self.learn_program, params_dict, self.place)

        feed_dict = self._create_input(is_test=False, actual_rewards=rewards)

        loss = self.exe.run(self.learn_program,
                            feed=feed_dict,
                            fetch_list=[self.loss])
        _logger.info("Controller: current reward is {}, loss is {}".format(
            rewards, loss))
        params_dict = self.get_params(self.learn_program)
        return params_dict
```
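The training step in `learn_program` is plain REINFORCE with an exponential-moving-average baseline. A hedged NumPy sketch of the same update, just to make the baseline bookkeeping explicit (the names here are illustrative, not from the file):

```python
import numpy as np

DECAY = 0.99  # matches the controller's default `decay`


def reinforce_step(sample_log_prob, reward, baseline):
    # EMA baseline: baseline <- baseline - (1 - decay) * (baseline - reward)
    baseline = baseline - (1.0 - DECAY) * (baseline - reward)
    # REINFORCE loss: push up the log-prob of token sequences that beat the baseline
    loss = -1.0 * sample_log_prob * (reward - baseline)
    return loss, baseline
```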
paddleslim/common/RL_controller/LSTM/__init__.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .LSTM_Controller import *
```
paddleslim/common/RL_controller/__init__.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from ..log_helper import get_logger

_logger = get_logger(__name__, level=logging.INFO)

try:
    import parl
    from .DDPG import *
except ImportError as e:
    _logger.warn(
        "If you want to use DDPG in RLNAS, please pip install parl first. Now states: {}".
        format(e))

from .LSTM import *
from .utils import *
```
paddleslim/common/RL_controller/utils.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from ...core import Registry

__all__ = [
    "RLCONTROLLER", "action_mapping", "add_grad", "compute_grad",
    "ConnectMessage"
]

RLCONTROLLER = Registry('RLController')


class ConnectMessage:
    """Message constants of the client/server protocol."""
    INIT = 'INIT'
    INIT_DONE = 'INIT_DONE'
    GET_WEIGHT = 'GET_WEIGHT'
    UPDATE_WEIGHT = 'UPDATE_WEIGHT'
    OK = 'OK'
    WAIT = 'WAIT'
    WAIT_PARAMS = 'WAIT_PARAMS'
    EXIT = 'EXIT'
    TIMEOUT = 10


def action_mapping(actions, range_table):
    ### rescale continuous actions in [-1, 1] onto integer token indices
    actions = (actions - (-1.0)) * (range_table / np.asarray(2.0))
    return actions.astype('int64')


def add_grad(dict1, dict2):
    ### element-wise sum of two parameter dicts with identical keys
    dict3 = dict()
    for key, value in dict1.items():
        dict3[key] = dict1[key] + dict2[key]
    return dict3


def compute_grad(dict1, dict2):
    ### element-wise difference of two parameter dicts with identical keys
    dict3 = dict()
    for key, value in dict1.items():
        dict3[key] = dict1[key] - dict2[key]
    return dict3
```
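To make the mapping concrete, `action_mapping` rescales DDPG's tanh-bounded actions in [-1, 1] onto integer token indices. A small hedged example (the range values are illustrative):

```python
import numpy as np

# A 3-token search space where each token has 8 possible values
# (range_table corresponds to range_tables - 1 in the DDPG controller).
range_table = np.array([7, 7, 7])
actions = np.array([-1.0, 0.0, 1.0])

tokens = action_mapping(actions, range_table)
print(tokens)  # [0, 3, 7]: (a + 1) * range / 2, truncated to int64
```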
paddleslim/common/__init__.py

@@ -18,9 +18,12 @@ from .controller_server import ControllerServer

```python
from .controller_client import ControllerClient
from .lock import lock, unlock
from .cached_reader import cached_reader
from .server import Server
from .client import Client
from .meter import AvgrageMeter

__all__ = [
    'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer',
    'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter',
    'Server', 'Client', 'RLBaseController'
]
```
paddleslim/common/client.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import signal
import zmq
import socket
import logging
import time
import threading
import cloudpickle
from .log_helper import get_logger
from .RL_controller.utils import compute_grad, ConnectMessage

_logger = get_logger(__name__, level=logging.INFO)


class Client(object):
    def __init__(self, controller, address, client_name):
        self._controller = controller
        self._address = address
        self._ip = self._address[0]
        self._port = self._address[1]
        self._client_name = client_name
        self._params_dict = None
        self.init_wait = False
        self._connect_server()

    def _connect_server(self):
        self._ctx = zmq.Context()
        self._client_socket = self._ctx.socket(zmq.REQ)
        ### NOTE: change the method to exit client when server is dead if there are better solutions
        self._client_socket.setsockopt(zmq.RCVTIMEO,
                                       ConnectMessage.TIMEOUT * 1000)
        client_address = "{}:{}".format(self._ip, self._port)
        self._client_socket.connect("tcp://{}".format(client_address))
        self._client_socket.send_multipart(
            [ConnectMessage.INIT, self._client_name])
        message = self._client_socket.recv_multipart()
        if message[0] != ConnectMessage.INIT_DONE:
            _logger.error("Client {} init failure, please start it again".
                          format(self._client_name))
            pid = os.getpid()
            os.kill(pid, signal.SIGTERM)

        _logger.info("Client {}: connect to server {}".format(
            self._client_name, client_address))

    def _connect_wait_socket(self, port):
        self._wait_socket = self._ctx.socket(zmq.REQ)
        wait_address = "{}:{}".format(self._ip, port)
        self._wait_socket.connect("tcp://{}".format(wait_address))
        self._wait_socket.send_multipart(
            [ConnectMessage.WAIT_PARAMS, self._client_name])
        message = self._wait_socket.recv_multipart()
        return message[0]

    def next_tokens(self, obs, is_inference=False):
        _logger.debug("Client: requests for weight {}".format(
            self._client_name))
        self._client_socket.send_multipart(
            [ConnectMessage.GET_WEIGHT, self._client_name])
        try:
            message = self._client_socket.recv_multipart()
        except zmq.error.Again as e:
            _logger.error(
                "CANNOT recv params from server in next_archs, please check whether the server is alive! {}".
                format(e))
            os._exit(0)
        self._params_dict = cloudpickle.loads(message[0])
        tokens = self._controller.next_tokens(
            obs, params_dict=self._params_dict, is_inference=is_inference)
        _logger.debug("Client: client_name is {}, current token is {}".format(
            self._client_name, tokens))
        return tokens

    def update(self, rewards, **kwargs):
        assert self._params_dict != None, "Please call next_token to get token first, then call update"
        current_params_dict = self._controller.update(
            rewards, self._params_dict, **kwargs)
        params_grad = compute_grad(self._params_dict, current_params_dict)
        _logger.debug("Client: update weight {}".format(self._client_name))
        self._client_socket.send_multipart([
            ConnectMessage.UPDATE_WEIGHT, self._client_name,
            cloudpickle.dumps(params_grad)
        ])
        _logger.debug("Client: update done {}".format(self._client_name))
        try:
            message = self._client_socket.recv_multipart()
        except zmq.error.Again as e:
            _logger.error(
                "CANNOT recv params from server in rewards, please check whether the server is alive! {}".
                format(e))
            os._exit(0)
        if message[0] == ConnectMessage.WAIT:
            _logger.debug("Client: self.init_wait: {}".format(self.init_wait))
            if not self.init_wait:
                wait_port = cloudpickle.loads(message[1])
                wait_signal = self._connect_wait_socket(wait_port)
                self.init_wait = True
            else:
                wait_signal = message[0]
            while wait_signal != ConnectMessage.OK:
                time.sleep(1)
                self._wait_socket.send_multipart(
                    [ConnectMessage.WAIT_PARAMS, self._client_name])
                wait_signal = self._wait_socket.recv_multipart()
                wait_signal = wait_signal[0]
                _logger.debug("Client: {} {}".format(self._client_name,
                                                     wait_signal))
        return message[0]

    def __del__(self):
        try:
            self._client_socket.send_multipart(
                [ConnectMessage.EXIT, self._client_name])
            _ = self._client_socket.recv_multipart()
        except:
            pass
        self._client_socket.close()
```
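In distributed search, this client is normally constructed for you by `RLNAS(is_server=False, server_addr=(ip, port))`. A hedged sketch of that multi-machine setup; the IP address is a placeholder:

```python
from paddleslim.nas import RLNAS

# On the server machine (assumed IP 192.168.1.10):
#   rlnas = RLNAS(key='lstm', configs=config, is_server=True, server_addr=("", 8881))

# On each worker machine, point at the server; is_server=False requires an
# explicit IP, as asserted in RLNAS.__init__.
config = [('MobileNetV2Space')]
rlnas = RLNAS(key='lstm', configs=config,
              is_server=False, server_addr=("192.168.1.10", 8881))
```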
paddleslim/common/controller.py

@@ -16,8 +16,9 @@

```python
import copy
import math
import numpy as np
import paddle.fluid as fluid

__all__ = ['EvolutionaryController', 'RLBaseController']


class EvolutionaryController(object):
```

...
...

@@ -51,3 +52,31 @@ class EvolutionaryController(object):

```python
            list<list>: The next searched tokens.
        """
        raise NotImplementedError('Abstract method.')


class RLBaseController(object):
    """Base controller for reinforcement learning."""

    def next_tokens(self, *args, **kwargs):
        raise NotImplementedError('Abstract method.')

    def update(self, *args, **kwargs):
        raise NotImplementedError('Abstract method.')

    def save_controller(self, program, output_dir):
        fluid.save(program, output_dir)

    def load_controller(self, program, load_dir):
        fluid.load(program, load_dir)

    def get_params(self, program):
        var_dict = {}
        for var in program.global_block().all_parameters():
            var_dict[var.name] = np.array(fluid.global_scope().find_var(
                var.name).get_tensor())
        return var_dict

    def set_params(self, program, params_dict, place):
        for var in program.global_block().all_parameters():
            fluid.global_scope().find_var(var.name).get_tensor().set(
                params_dict[var.name], place)
```
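`get_params`/`set_params` are what the client and server ship back and forth: plain numpy snapshots of every parameter in a program's scope. A hedged sketch of the round trip, assuming a controller instance built as in the files above:

```python
# Snapshot parameters as {var_name: np.ndarray}; this dict is what gets
# cloudpickled over the zmq channel.
params = controller.get_params(controller.learn_program)

# ...merge on the server with add_grad/compute_grad from RL_controller.utils...

# Write the merged values back into the program scope before the next run.
controller.set_params(controller.learn_program, params, controller.place)
```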
paddleslim/common/server.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import zmq
import socket
import signal
import six
import os
if six.PY2:
    import cPickle as pickle
else:
    import pickle
import logging
import time
import threading
import cloudpickle
from .log_helper import get_logger
from .RL_controller.utils import add_grad, ConnectMessage

_logger = get_logger(__name__, level=logging.INFO)


class Server(object):
    def __init__(self,
                 controller,
                 address,
                 is_sync=False,
                 load_controller=None,
                 save_controller=None):
        self._controller = controller
        self._address = address
        self._ip = self._address[0]
        self._port = self._address[1]
        self._is_sync = is_sync
        self._done = False
        self._load_controller = load_controller
        self._save_controller = save_controller
        ### key-value : client_name-update_times
        self._client_dict = dict()
        self._client = list()
        self._lock = threading.Lock()
        self._server_alive = True
        self._max_update_times = 0

    def close(self):
        self._server_alive = False
        _logger.info("server closed")
        pid = os.getpid()
        os.kill(pid, signal.SIGTERM)

    def start(self):
        self._ctx = zmq.Context()
        ### main socket
        self._server_socket = self._ctx.socket(zmq.REP)
        server_address = "{}:{}".format(self._ip, self._port)
        self._server_socket.bind("tcp://{}".format(server_address))
        self._server_socket.linger = 0
        _logger.info("ControllerServer - listen on: [{}]".format(
            server_address))
        thread = threading.Thread(target=self.run, args=())
        thread.setDaemon(True)
        thread.start()

        if self._load_controller:
            assert os.path.exists(
                self._load_controller
            ), "controller checkpoint does not exist, please check your directory: {}".format(
                self._load_controller)

            with open(
                    os.path.join(self._load_controller, 'rlnas.params'),
                    'rb') as f:
                self._params_dict = pickle.load(f)
            _logger.info("Load params done")
        else:
            self._params_dict = self._controller.param_dict

        if self._is_sync:
            self._wait_socket = self._ctx.socket(zmq.REP)
            self._wait_port = self._wait_socket.bind_to_random_port(
                addr="tcp://*")
            self._wait_socket.linger = 0
            wait_thread = threading.Thread(
                target=self._wait_for_params, args=())
            wait_thread.setDaemon(True)
            wait_thread.start()

    def _wait_for_params(self):
        try:
            while self._server_alive:
                message = self._wait_socket.recv_multipart()
                cmd = message[0]
                client_name = message[1]
                if cmd == ConnectMessage.WAIT_PARAMS:
                    _logger.debug("Server: wait for params")
                    self._lock.acquire()
                    self._wait_socket.send_multipart([
                        ConnectMessage.OK
                        if self._done else ConnectMessage.WAIT
                    ])
                    if self._done and client_name in self._client:
                        self._client.remove(client_name)
                        if len(self._client) == 0:
                            self.save_params()
                            self._done = False
                    self._lock.release()
                else:
                    _logger.error("Error message {}".format(message))
                    raise NotImplementedError
        except Exception as err:
            _logger.error(err)

    def run(self):
        try:
            while self._server_alive:
                try:
                    sum_params_dict = dict()
                    message = self._server_socket.recv_multipart()
                    cmd = message[0]
                    client_name = message[1]
                    if cmd == ConnectMessage.INIT:
                        self._server_socket.send_multipart(
                            [ConnectMessage.INIT_DONE])
                        _logger.debug("Server: init client {}".format(
                            client_name))
                        self._client_dict[client_name] = 0
                    elif cmd == ConnectMessage.GET_WEIGHT:
                        self._lock.acquire()
                        _logger.debug("Server: get weight {}".format(
                            client_name))
                        self._server_socket.send_multipart(
                            [cloudpickle.dumps(self._params_dict)])
                        _logger.debug("Server: send params done {}".format(
                            client_name))
                        self._lock.release()
                    elif cmd == ConnectMessage.UPDATE_WEIGHT:
                        _logger.info("Server: update {}".format(client_name))
                        params_dict_grad = cloudpickle.loads(message[2])
                        if self._is_sync:
                            if not sum_params_dict:
                                sum_params_dict = self._params_dict
                            self._lock.acquire()
                            sum_params_dict = add_grad(sum_params_dict,
                                                       params_dict_grad)
                            self._client.append(client_name)
                            self._lock.release()
                            if len(self._client) == len(
                                    self._client_dict.items()):
                                self._done = True
                            self._server_socket.send_multipart([
                                ConnectMessage.WAIT,
                                cloudpickle.dumps(self._wait_port)
                            ])
                        else:
                            self._lock.acquire()
                            self._params_dict = add_grad(self._params_dict,
                                                         params_dict_grad)
                            self._client_dict[client_name] += 1
                            if self._client_dict[
                                    client_name] > self._max_update_times:
                                self._max_update_times = self._client_dict[
                                    client_name]
                            self._lock.release()
                            self.save_params()
                            self._server_socket.send_multipart(
                                [ConnectMessage.OK])
                    elif cmd == ConnectMessage.EXIT:
                        self._client_dict.pop(client_name)
                        if client_name in self._client:
                            self._client.remove(client_name)
                        self._server_socket.send_multipart(
                            [ConnectMessage.EXIT])
                except zmq.error.Again as e:
                    _logger.error(e)
                    self.close()
        except Exception as err:
            _logger.error(err)
        finally:
            self._server_socket.close(0)
            if self._is_sync:
                self._wait_socket.close(0)
            self.close()

    def save_params(self):
        if self._save_controller:
            if not os.path.exists(self._save_controller):
                os.makedirs(self._save_controller)
            output_dir = self._save_controller
        else:
            if not os.path.exists('./.rlnas_controller'):
                os.makedirs('./.rlnas_controller')
            output_dir = './.rlnas_controller'
        with open(os.path.join(output_dir, 'rlnas.params'), 'wb') as f:
            pickle.dump(self._params_dict, f)
        _logger.info("Save params done")
```
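For orientation, a hedged trace of one asynchronous round of the Client/Server protocol above (frame contents abbreviated; in sync mode the UPDATE_WEIGHT reply is [WAIT, wait_port] and clients poll the wait socket until OK):

```python
# Client (REQ)                          Server (REP)
# -----------------------------         ------------------------------------
# send [INIT, name]               --->  register client, reply [INIT_DONE]
# send [GET_WEIGHT, name]         --->  reply [cloudpickle(params_dict)]
#   ...run controller.next_tokens and controller.update locally...
#   grad = compute_grad(old_params, new_params)
# send [UPDATE_WEIGHT, name, g]   --->  params_dict = add_grad(params_dict, g),
#                                       save_params(), reply [OK]
# send [EXIT, name]               --->  deregister, reply [EXIT]
```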
paddleslim/nas/__init__.py

@@ -16,10 +16,12 @@ from ..nas import search_space

```python
from .search_space import *
from ..nas import sa_nas
from .sa_nas import *
from .rl_nas import *
from ..nas import darts
from .darts import *

__all__ = []
__all__ += sa_nas.__all__
__all__ += search_space.__all__
__all__ += rl_nas.__all__
__all__ += darts.__all__
```
paddleslim/nas/rl_nas.py (new file, mode 100644)

```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import socket
import logging
import numpy as np
import json
import hashlib
import time
import paddle.fluid as fluid
from ..common.RL_controller.utils import RLCONTROLLER
from ..common import get_logger
from ..common import Server
from ..common import Client
from .search_space import SearchSpaceFactory

_logger = get_logger(__name__, level=logging.INFO)

__all__ = ['RLNAS']


class RLNAS(object):
    """
    Controller with Reinforcement Learning.
    Args:
        key(str): The actual reinforcement learning method. Current support in paddleslim is `LSTM` and `DDPG`.
        configs(list<tuple>): A list of search space configurations with format [(key, {input_size,
                              output_size, block_num, block_mask})]. `key` is the name of the search space
                              with data type str. `input_size` and `output_size` are the input size and
                              output size of the searched sub-network. `block_num` is the number of blocks
                              in the searched network, `block_mask` is a list consisting of 0 and 1, 0 means
                              a normal block, 1 means a reduction block.
        use_gpu(bool): Whether to use gpu in the controller. Default: False.
        server_addr(tuple): Server address, including the ip and port of the server. If ip is None or "",
                            the host ip is used if is_server = True. Default: ("", 8881).
        is_server(bool): Whether the current host is the controller server. Default: True.
        is_sync(bool): Whether to update the controller in synchronous mode. Default: False.
        save_controller(str|None): The directory to save the controller to; if set to None, no checkpoint
                                   is saved. Default: None.
        load_controller(str|None): The directory to load the controller from; if set to None, no checkpoint
                                   is loaded. Default: None.
        **kwargs: Additional keyword arguments.
    """

    def __init__(self,
                 key,
                 configs,
                 use_gpu=False,
                 server_addr=("", 8881),
                 is_server=True,
                 is_sync=False,
                 save_controller=None,
                 load_controller=None,
                 **kwargs):
        if not is_server:
            assert server_addr[0] != "", "You should set the IP and port of server when is_server is False."
        self._configs = configs
        factory = SearchSpaceFactory()
        self._search_space = factory.get_search_space(configs)
        self.range_tables = self._search_space.range_table()
        self.save_controller = save_controller
        self.load_controller = load_controller

        cls = RLCONTROLLER.get(key.upper())

        server_ip, server_port = server_addr
        if server_ip == None or server_ip == "":
            server_ip = self._get_host_ip()

        self._controller = cls(range_tables=self.range_tables,
                               use_gpu=use_gpu,
                               **kwargs)

        if is_server:
            max_client_num = 300
            self._controller_server = Server(
                controller=self._controller,
                address=(server_ip, server_port),
                is_sync=is_sync,
                save_controller=self.save_controller,
                load_controller=self.load_controller)
            self._controller_server.start()

        self._client_name = hashlib.md5(
            str(time.time() + np.random.randint(1, 10000)).encode(
                "utf-8")).hexdigest()
        self._controller_client = Client(
            controller=self._controller,
            address=(server_ip, server_port),
            client_name=self._client_name)

        self._current_tokens = None

    def _get_host_ip(self):
        try:
            return socket.gethostbyname(socket.gethostname())
        except:
            return socket.gethostbyname('localhost')

    def next_archs(self, obs=None):
        """
        Get the next batch of architectures.
        Args:
            obs(int|np.array): observations in env.
        """
        archs = []
        self._current_tokens = self._controller_client.next_tokens(obs)
        _logger.info("current tokens: {}".format(self._current_tokens))
        for token in self._current_tokens:
            archs.append(self._search_space.token2arch(token))

        return archs

    def reward(self, rewards, **kwargs):
        """
        Pass the score back to train the controller.
        Args:
            rewards(float|list<float>): rewards obtained by the tokens.
            **kwargs: Additional keyword arguments.
        """
        return self._controller_client.update(rewards, **kwargs)

    def final_archs(self, batch_obs):
        """
        Get the final architectures.
        Args:
            batch_obs(int|np.array): observations in env.
        """
        final_tokens = self._controller_client.next_tokens(
            batch_obs, is_inference=True)
        _logger.info("Final tokens: {}".format(final_tokens))
        archs = []
        for token in final_tokens:
            arch = self._search_space.token2arch(token)
            archs.append(arch)

        return archs

    def tokens2arch(self, tokens):
        """
        Convert tokens to model architectures.
        Args:
            tokens(list): A list of tokens. The length and range depend on the search space.
        Returns:
            list<function>: A model architecture instance according to tokens.
        """
        return self._search_space.token2arch(tokens)
```
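Putting the pieces together, a hedged sketch of a full search loop with this API; the reward here is a random stand-in for real validation accuracy:

```python
import numpy as np
from paddleslim.nas import RLNAS

config = [('MobileNetV2Space')]
rlnas = RLNAS(key='lstm', configs=config)

for step in range(10):
    archs = rlnas.next_archs(1)
    # Build and evaluate the sub-network here; we fake the score purely
    # for illustration.
    score = float(np.random.random())
    rlnas.reward(score)

# Once the controller is trained, sample candidates for full experiments.
final = rlnas.final_archs(5)
```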