Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
60544b52
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
1 年多 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
60544b52
编写于
5月 15, 2020
作者:
C
ceci3
提交者:
GitHub
5月 15, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix something wrong about rlnas (#269) (#280)
* fix * fix * update * update
上级
e6b788a3
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
355 addition
and
48 deletion
+355
-48
demo/nas/README.md
demo/nas/README.md
+29
-3
demo/nas/parl_nas_mobilenetv2.py
demo/nas/parl_nas_mobilenetv2.py
+244
-0
docs/zh_cn/quick_start/nas_tutorial.md
docs/zh_cn/quick_start/nas_tutorial.md
+9
-9
docs/zh_cn/tutorials/sanas_darts_space.md
docs/zh_cn/tutorials/sanas_darts_space.md
+12
-12
paddleslim/common/__init__.py
paddleslim/common/__init__.py
+1
-1
paddleslim/common/client.py
paddleslim/common/client.py
+1
-1
paddleslim/common/rl_controller/base_env.py
paddleslim/common/rl_controller/base_env.py
+29
-0
paddleslim/common/rl_controller/ddpg/ddpg_controller.py
paddleslim/common/rl_controller/ddpg/ddpg_controller.py
+11
-11
paddleslim/common/rl_controller/lstm/lstm_controller.py
paddleslim/common/rl_controller/lstm/lstm_controller.py
+17
-11
paddleslim/common/server.py
paddleslim/common/server.py
+2
-0
未找到文件。
demo/nas/README.md
浏览文件 @
60544b52
# 网络结构搜索示例
#
SANAS
网络结构搜索示例
本示例介绍如何使用网络结构搜索接口,搜索到一个更小或者精度更高的模型,该
文档仅
介绍paddleslim中SANAS的使用及如何利用SANAS得到模型结构,完整示例代码请参考sa_nas_mobilenetv2.py或者block_sa_nas_mobilenetv2.py。
本示例介绍如何使用网络结构搜索接口,搜索到一个更小或者精度更高的模型,该
示例
介绍paddleslim中SANAS的使用及如何利用SANAS得到模型结构,完整示例代码请参考sa_nas_mobilenetv2.py或者block_sa_nas_mobilenetv2.py。
## 数据准备
本示例默认使用cifar10数据,cifar10数据会根据调用的paddle接口自动下载,无需额外准备。
...
...
@@ -8,7 +8,7 @@
## 接口介绍
请参考
<a
href=
'../../docs/zh_cn/api_cn/nas_api.rst'
>
神经网络搜索API文档
</a>
。
本示例为在MobileNetV2的搜索空间上搜索FLOPs更小的模型。
本示例为
利用SANAS
在MobileNetV2的搜索空间上搜索FLOPs更小的模型。
## 1 搜索空间配置
默认搜索空间为
`MobileNetV2`
,详细的搜索空间配置请参考
<a
href=
'../../docs/zh_cn/api_cn/search_space.md'
>
搜索空间配置文档
</a>
。
...
...
@@ -24,3 +24,29 @@ CUDA_VISIBLE_DEVICES=0 python sa_nas_mobilenetv2.py
```
shell
CUDA_VISIBLE_DEVICES
=
0 python block_sa_nas_mobilenetv2.py
```
# RLNAS网络结构搜索示例
本示例介绍如何使用RLNAS接口进行网络结构搜索,该示例介绍paddleslim中RLNAS的使用,完整示例代码请参考rl_nas_mobilenetv2.py或者parl_nas_mobilenetv2.py。
## 数据准备
本示例默认使用cifar10数据,cifar10数据会根据调用的paddle接口自动下载,无需额外准备。
## 接口介绍
请参考
<a
href=
'../../docs/zh_cn/api_cn/nas_api.rst'
>
神经网络搜索API文档
</a>
。
示例为利用SANAS在MobileNetV2的搜索空间上搜索精度更高的模型。
## 1 搜索空间配置
默认搜索空间为
`MobileNetV2`
,详细的搜索空间配置请参考
<a
href=
'../../docs/zh_cn/api_cn/search_space.md'
>
搜索空间配置文档
</a>
。
## 2 启动训练
### 2.1 启动基于MobileNetV2初始模型结构构造搜索空间,强化学习算法为lstm的搜索实验
```
shell
CUDA_VISIBLE_DEVICES
=
0 python rl_nas_mobilenetv2.py
```
### 2.2 启动基于MobileNetV2初始模型结构构造搜索空间,强化学习算法为ddpg的搜索实验
```
shell
CUDA_VISIBLE_DEVICES
=
0 python parl_nas_mobilenetv2.py
```
demo/nas/parl_nas_mobilenetv2.py
0 → 100644
浏览文件 @
60544b52
import
sys
sys
.
path
.
append
(
'..'
)
import
numpy
as
np
import
argparse
import
ast
import
time
import
argparse
import
ast
import
logging
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.param_attr
import
ParamAttr
from
paddleslim.nas
import
RLNAS
from
paddleslim.common
import
get_logger
from
optimizer
import
create_optimizer
import
imagenet_reader
_logger
=
get_logger
(
__name__
,
level
=
logging
.
INFO
)
def
create_data_loader
(
image_shape
):
data_shape
=
[
None
]
+
image_shape
data
=
fluid
.
data
(
name
=
'data'
,
shape
=
data_shape
,
dtype
=
'float32'
)
label
=
fluid
.
data
(
name
=
'label'
,
shape
=
[
None
,
1
],
dtype
=
'int64'
)
data_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
feed_list
=
[
data
,
label
],
capacity
=
1024
,
use_double_buffer
=
True
,
iterable
=
True
)
return
data_loader
,
data
,
label
def
build_program
(
main_program
,
startup_program
,
image_shape
,
archs
,
args
,
is_test
=
False
):
with
fluid
.
program_guard
(
main_program
,
startup_program
):
with
fluid
.
unique_name
.
guard
():
data_loader
,
data
,
label
=
create_data_loader
(
image_shape
)
output
=
archs
(
data
)
output
=
fluid
.
layers
.
fc
(
input
=
output
,
size
=
args
.
class_dim
)
softmax_out
=
fluid
.
layers
.
softmax
(
input
=
output
,
use_cudnn
=
False
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
softmax_out
,
label
=
label
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
acc_top1
=
fluid
.
layers
.
accuracy
(
input
=
softmax_out
,
label
=
label
,
k
=
1
)
acc_top5
=
fluid
.
layers
.
accuracy
(
input
=
softmax_out
,
label
=
label
,
k
=
5
)
if
is_test
==
False
:
optimizer
=
create_optimizer
(
args
)
optimizer
.
minimize
(
avg_cost
)
return
data_loader
,
avg_cost
,
acc_top1
,
acc_top5
def
search_mobilenetv2
(
config
,
args
,
image_size
,
is_server
=
True
):
if
is_server
:
### start a server and a client
rl_nas
=
RLNAS
(
key
=
'ddpg'
,
configs
=
config
,
is_sync
=
False
,
obs_dim
=
26
,
### step + length_of_token
server_addr
=
(
args
.
server_address
,
args
.
port
))
else
:
### start a client
rl_nas
=
RLNAS
(
key
=
'ddpg'
,
configs
=
config
,
is_sync
=
False
,
obs_dim
=
26
,
server_addr
=
(
args
.
server_address
,
args
.
port
),
is_server
=
False
)
image_shape
=
[
3
,
image_size
,
image_size
]
for
step
in
range
(
args
.
search_steps
):
if
step
==
0
:
action_prev
=
[
1.
for
_
in
rl_nas
.
range_tables
]
else
:
action_prev
=
rl_nas
.
tokens
[
0
]
obs
=
[
step
]
obs
.
extend
(
action_prev
)
archs
=
rl_nas
.
next_archs
(
obs
=
obs
)[
0
][
0
]
train_program
=
fluid
.
Program
()
test_program
=
fluid
.
Program
()
startup_program
=
fluid
.
Program
()
train_loader
,
avg_cost
,
acc_top1
,
acc_top5
=
build_program
(
train_program
,
startup_program
,
image_shape
,
archs
,
args
)
test_loader
,
test_avg_cost
,
test_acc_top1
,
test_acc_top5
=
build_program
(
test_program
,
startup_program
,
image_shape
,
archs
,
args
,
is_test
=
True
)
test_program
=
test_program
.
clone
(
for_test
=
True
)
place
=
fluid
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_program
)
if
args
.
data
==
'cifar10'
:
train_reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
cifar
.
train10
(
cycle
=
False
),
buf_size
=
1024
),
batch_size
=
args
.
batch_size
,
drop_last
=
True
)
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
cifar
.
test10
(
cycle
=
False
),
batch_size
=
args
.
batch_size
,
drop_last
=
False
)
elif
args
.
data
==
'imagenet'
:
train_reader
=
paddle
.
batch
(
imagenet_reader
.
train
(),
batch_size
=
args
.
batch_size
,
drop_last
=
True
)
test_reader
=
paddle
.
batch
(
imagenet_reader
.
val
(),
batch_size
=
args
.
batch_size
,
drop_last
=
False
)
train_loader
.
set_sample_list_generator
(
train_reader
,
places
=
fluid
.
cuda_places
()
if
args
.
use_gpu
else
fluid
.
cpu_places
())
test_loader
.
set_sample_list_generator
(
test_reader
,
places
=
place
)
build_strategy
=
fluid
.
BuildStrategy
()
train_compiled_program
=
fluid
.
CompiledProgram
(
train_program
).
with_data_parallel
(
loss_name
=
avg_cost
.
name
,
build_strategy
=
build_strategy
)
for
epoch_id
in
range
(
args
.
retain_epoch
):
for
batch_id
,
data
in
enumerate
(
train_loader
()):
fetches
=
[
avg_cost
.
name
]
s_time
=
time
.
time
()
outs
=
exe
.
run
(
train_compiled_program
,
feed
=
data
,
fetch_list
=
fetches
)[
0
]
batch_time
=
time
.
time
()
-
s_time
if
batch_id
%
10
==
0
:
_logger
.
info
(
'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'
.
format
(
step
,
epoch_id
,
batch_id
,
outs
[
0
],
batch_time
))
reward
=
[]
for
batch_id
,
data
in
enumerate
(
test_loader
()):
test_fetches
=
[
test_avg_cost
.
name
,
test_acc_top1
.
name
,
test_acc_top5
.
name
]
batch_reward
=
exe
.
run
(
test_program
,
feed
=
data
,
fetch_list
=
test_fetches
)
reward_avg
=
np
.
mean
(
np
.
array
(
batch_reward
),
axis
=
1
)
reward
.
append
(
reward_avg
)
_logger
.
info
(
'TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'
.
format
(
step
,
batch_id
,
batch_reward
[
0
],
batch_reward
[
1
],
batch_reward
[
2
]))
finally_reward
=
np
.
mean
(
np
.
array
(
reward
),
axis
=
0
)
_logger
.
info
(
'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'
.
format
(
finally_reward
[
0
],
finally_reward
[
1
],
finally_reward
[
2
]))
obs
=
np
.
expand_dims
(
obs
,
axis
=
0
).
astype
(
'float32'
)
actions
=
rl_nas
.
tokens
obs_next
=
[
step
+
1
]
obs_next
.
extend
(
actions
[
0
])
obs_next
=
np
.
expand_dims
(
obs_next
,
axis
=
0
).
astype
(
'float32'
)
if
step
==
args
.
search_steps
-
1
:
terminal
=
np
.
expand_dims
([
True
],
axis
=
0
).
astype
(
np
.
bool
)
else
:
terminal
=
np
.
expand_dims
([
False
],
axis
=
0
).
astype
(
np
.
bool
)
rl_nas
.
reward
(
np
.
expand_dims
(
np
.
float32
(
finally_reward
[
1
]),
axis
=
0
),
obs
=
obs
,
actions
=
actions
.
astype
(
'float32'
),
obs_next
=
obs_next
,
terminal
=
terminal
)
if
step
==
2
:
sys
.
exit
(
0
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'RL NAS MobileNetV2 cifar10 argparase'
)
parser
.
add_argument
(
'--use_gpu'
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
'Whether to use GPU in train/test model.'
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
256
,
help
=
'batch size.'
)
parser
.
add_argument
(
'--class_dim'
,
type
=
int
,
default
=
10
,
help
=
'classify number.'
)
parser
.
add_argument
(
'--data'
,
type
=
str
,
default
=
'cifar10'
,
choices
=
[
'cifar10'
,
'imagenet'
],
help
=
'server address.'
)
parser
.
add_argument
(
'--is_server'
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
'Whether to start a server.'
)
parser
.
add_argument
(
'--search_steps'
,
type
=
int
,
default
=
100
,
help
=
'controller server number.'
)
parser
.
add_argument
(
'--server_address'
,
type
=
str
,
default
=
""
,
help
=
'server ip.'
)
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
8881
,
help
=
'server port'
)
parser
.
add_argument
(
'--retain_epoch'
,
type
=
int
,
default
=
5
,
help
=
'epoch for each token.'
)
parser
.
add_argument
(
'--lr'
,
type
=
float
,
default
=
0.1
,
help
=
'learning rate.'
)
args
=
parser
.
parse_args
()
print
(
args
)
if
args
.
data
==
'cifar10'
:
image_size
=
32
block_num
=
3
elif
args
.
data
==
'imagenet'
:
image_size
=
224
block_num
=
6
else
:
raise
NotImplementedError
(
'data must in [cifar10, imagenet], but received: {}'
.
format
(
args
.
data
))
config
=
[(
'MobileNetV2Space'
)]
search_mobilenetv2
(
config
,
args
,
image_size
,
is_server
=
args
.
is_server
)
docs/zh_cn/quick_start/nas_tutorial.md
浏览文件 @
60544b52
docs/zh_cn/tutorials/sanas_darts_space.md
浏览文件 @
60544b52
...
...
@@ -33,7 +33,7 @@
按照通道数来区分DARTS_model中block的话,则DARTS_model中共有3个block,第一个block仅包含6个normal cell,之后的两个block每个block都包含和一个reduction cell和6个normal cell,共有20个cell。在构造搜索空间的时候我们定义每个cell中的所有卷积操作都使用相同的通道数,共有20位token。
完整的搜索空间可以参考
[
基于DARTS_model的搜索空间
](
../../..
/paddleslim/nas/search_space/darts_space.py
)
完整的搜索空间可以参考
[
基于DARTS_model的搜索空间
](
https://github.com/PaddlePaddle/PaddleSlim/blob/develop
/paddleslim/nas/search_space/darts_space.py
)
### 2. 引入依赖包并定义全局变量
```
python
...
...
@@ -232,9 +232,9 @@ exe.run(startup_program)
```
#### 9.5 定义输入数据
由于本示例中对cifar10中的图片进行了一些额外的预处理操作,和
[
快速开始
](
../quick_start/nas_tutorial.md
)
示例中的reader不同,所以需要自定义cifar10的reader,不能直接调用paddle中封装好的
`paddle.dataset.cifar10`
的reader。自定义cifar10的reader文件位于
[
demo/nas
](
../../..
/demo/nas/darts_cifar10_reader.py
)
中。
由于本示例中对cifar10中的图片进行了一些额外的预处理操作,和
[
快速开始
](
https://paddlepaddle.github.io/PaddleSlim/quick_start/nas_tutorial.html
)
示例中的reader不同,所以需要自定义cifar10的reader,不能直接调用paddle中封装好的
`paddle.dataset.cifar10`
的reader。自定义cifar10的reader文件位于
[
demo/nas
](
https://github.com/PaddlePaddle/PaddleSlim/blob/develop
/demo/nas/darts_cifar10_reader.py
)
中。
**注意:**
本示例为了简化代码直接调用
`paddle.dataset.cifar10`
定义训练数据和预测数据,实际训练需要使用自定义cifar10的reader。
**注意:**
本示例为了简化代码直接调用
`paddle.dataset.cifar10`
定义训练数据和预测数据,实际训练需要使用自定义cifar10
文件中
的reader。
```
python
train_reader
=
paddle
.
batch
(
paddle
.
reader
.
shuffle
(
paddle
.
dataset
.
cifar
.
train10
(
cycle
=
False
),
buf_size
=
1024
),
batch_size
=
BATCH_SIZE
,
drop_last
=
True
)
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
cifar
.
test10
(
cycle
=
False
),
batch_size
=
BATCH_SIZE
,
drop_last
=
False
)
...
...
@@ -261,14 +261,14 @@ sa_nas.reward(float(valid_top1_list[-1] + valid_top1_list[-2]) / 2)
### 10. 利用demo下的脚本启动搜索
搜索文件位于:
[
darts_sanas_demo
](
https://github.com/PaddlePaddle/PaddleSlim/
tree/develop/demo/nas/sanas_darts_nas
.py
)
,搜索过程中限制模型参数量为不大于3.77M。
搜索文件位于:
[
darts_sanas_demo
](
https://github.com/PaddlePaddle/PaddleSlim/
blob/develop/demo/nas/sanas_darts_space
.py
)
,搜索过程中限制模型参数量为不大于3.77M。
```
python
cd
demo
/
nas
/
python
darts_nas
.
py
```
### 11. 利用demo下的脚本启动最终实验
最终实验文件位于:
[
darts_sanas_demo
](
https://github.com/PaddlePaddle/PaddleSlim/
tree/develop/demo/nas/sanas_darts_nas
.py
)
,最终实验需要训练600epoch。以下示例输入token为
`[5, 5, 0, 5, 5, 10, 7, 7, 5, 7, 7, 11, 10, 12, 10, 0, 5, 3, 10, 8]`
。
最终实验文件位于:
[
darts_sanas_demo
](
https://github.com/PaddlePaddle/PaddleSlim/
blob/develop/demo/nas/sanas_darts_space
.py
)
,最终实验需要训练600epoch。以下示例输入token为
`[5, 5, 0, 5, 5, 10, 7, 7, 5, 7, 7, 11, 10, 12, 10, 0, 5, 3, 10, 8]`
。
```
python
cd
demo
/
nas
/
python
darts_nas
.
py
--
token
5
5
0
5
5
10
7
7
5
7
7
11
10
12
10
0
5
3
10
8
--
retain_epoch
600
...
...
paddleslim/common/__init__.py
浏览文件 @
60544b52
...
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.controller
import
EvolutionaryController
from
.controller
import
EvolutionaryController
,
RLBaseController
from
.sa_controller
import
SAController
from
.log_helper
import
get_logger
from
.controller_server
import
ControllerServer
...
...
paddleslim/common/client.py
浏览文件 @
60544b52
...
...
@@ -99,7 +99,7 @@ class Client(object):
assert
self
.
_params_dict
!=
None
,
"Please call next_token to get token first, then call update"
current_params_dict
=
self
.
_controller
.
update
(
rewards
,
self
.
_params_dict
,
**
kwargs
)
params_grad
=
compute_grad
(
self
.
_params_dict
,
current
_params_dict
)
params_grad
=
compute_grad
(
current_params_dict
,
self
.
_params_dict
)
_logger
.
debug
(
"Client: update weight {}"
.
format
(
self
.
_client_name
))
self
.
_client_socket
.
send_multipart
([
pickle
.
dumps
(
ConnectMessage
.
UPDATE_WEIGHT
),
...
...
paddleslim/common/rl_controller/base_env.py
0 → 100644
浏览文件 @
60544b52
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base environment used in reinforcement learning"""
import
numpy
as
np
__all__
=
[
'BaseEnv'
]
class
BaseEnv
:
def
reset
(
self
):
raise
NotImplementedError
(
'Abstract method.'
)
def
step
(
self
):
raise
NotImplementedError
(
'Abstract method.'
)
def
_build_state_embedding
(
self
):
raise
NotImplementedError
(
'Abstract method.'
)
paddleslim/common/rl_controller/ddpg/ddpg_controller.py
浏览文件 @
60544b52
...
...
@@ -41,24 +41,24 @@ class DDPGAgent(parl.Agent):
self
.
learn_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
self
.
pred_program
):
obs
=
layers
.
data
(
name
=
'obs'
,
shape
=
[
self
.
obs_dim
],
dtype
=
'float32'
)
obs
=
fluid
.
data
(
name
=
'obs'
,
shape
=
[
None
,
self
.
obs_dim
],
dtype
=
'float32'
)
self
.
pred_act
=
self
.
alg
.
predict
(
obs
)
with
fluid
.
program_guard
(
self
.
learn_program
):
obs
=
layers
.
data
(
name
=
'obs'
,
shape
=
[
self
.
obs_dim
],
dtype
=
'float32'
)
act
=
layers
.
data
(
name
=
'act'
,
shape
=
[
self
.
act_dim
],
dtype
=
'float32'
)
reward
=
layers
.
data
(
name
=
'reward'
,
shape
=
[],
dtype
=
'float32'
)
next_obs
=
layers
.
data
(
name
=
'next_obs'
,
shape
=
[
self
.
obs_dim
],
dtype
=
'float32'
)
terminal
=
layers
.
data
(
name
=
'terminal'
,
shape
=
[],
dtype
=
'bool'
)
obs
=
fluid
.
data
(
name
=
'obs'
,
shape
=
[
None
,
self
.
obs_dim
],
dtype
=
'float32'
)
act
=
fluid
.
data
(
name
=
'act'
,
shape
=
[
None
,
self
.
act_dim
],
dtype
=
'float32'
)
reward
=
fluid
.
data
(
name
=
'reward'
,
shape
=
[
None
],
dtype
=
'float32'
)
next_obs
=
fluid
.
data
(
name
=
'next_obs'
,
shape
=
[
None
,
self
.
obs_dim
],
dtype
=
'float32'
)
terminal
=
fluid
.
data
(
name
=
'terminal'
,
shape
=
[
None
,
1
],
dtype
=
'bool'
)
_
,
self
.
critic_cost
=
self
.
alg
.
learn
(
obs
,
act
,
reward
,
next_obs
,
terminal
)
def
predict
(
self
,
obs
):
obs
=
np
.
expand_dims
(
obs
,
axis
=
0
)
act
=
self
.
fluid_executor
.
run
(
self
.
pred_program
,
feed
=
{
'obs'
:
obs
},
fetch_list
=
[
self
.
pred_act
])[
0
]
...
...
paddleslim/common/rl_controller/lstm/lstm_controller.py
浏览文件 @
60544b52
...
...
@@ -62,6 +62,7 @@ class LSTM(RLBaseController):
self
.
lstm_num_layers
=
kwargs
.
get
(
'lstm_num_layers'
)
or
1
self
.
hidden_size
=
kwargs
.
get
(
'hidden_size'
)
or
100
self
.
temperature
=
kwargs
.
get
(
'temperature'
)
or
None
self
.
controller_lr
=
kwargs
.
get
(
'controller_lr'
)
or
1e-4
self
.
tanh_constant
=
kwargs
.
get
(
'tanh_constant'
)
or
None
self
.
decay
=
kwargs
.
get
(
'decay'
)
or
0.99
self
.
weight_entropy
=
kwargs
.
get
(
'weight_entropy'
)
or
None
...
...
@@ -91,12 +92,6 @@ class LSTM(RLBaseController):
return
logits
,
output
,
new_states
def
_create_parameter
(
self
):
self
.
emb_w
=
fluid
.
layers
.
create_parameter
(
name
=
'emb_w'
,
shape
=
(
self
.
max_range_table
,
self
.
hidden_size
),
dtype
=
'float32'
,
default_initializer
=
uniform_initializer
(
1.0
))
self
.
g_emb
=
fluid
.
layers
.
create_parameter
(
name
=
'emb_g'
,
shape
=
(
self
.
controller_batch_size
,
self
.
hidden_size
),
...
...
@@ -133,11 +128,16 @@ class LSTM(RLBaseController):
axes
=
[
1
],
starts
=
[
idx
],
ends
=
[
idx
+
1
])
action
=
fluid
.
layers
.
squeeze
(
action
,
axes
=
[
1
])
action
.
stop_gradient
=
True
else
:
action
=
fluid
.
layers
.
sampling_id
(
probs
)
actions
.
append
(
action
)
log_prob
=
fluid
.
layers
.
cross_entropy
(
probs
,
action
)
log_prob
=
fluid
.
layers
.
softmax_with_cross_entropy
(
logits
,
fluid
.
layers
.
reshape
(
action
,
shape
=
[
fluid
.
layers
.
shape
(
action
),
1
]),
axis
=
1
)
sample_log_probs
.
append
(
log_prob
)
entropy
=
log_prob
*
fluid
.
layers
.
exp
(
-
1
*
log_prob
)
...
...
@@ -145,10 +145,14 @@ class LSTM(RLBaseController):
entropies
.
append
(
entropy
)
action_emb
=
fluid
.
layers
.
cast
(
action
,
dtype
=
np
.
int64
)
inputs
=
fluid
.
layers
.
gather
(
self
.
emb_w
,
action_emb
)
inputs
=
fluid
.
embedding
(
action_emb
,
size
=
(
self
.
max_range_table
,
self
.
hidden_size
),
param_attr
=
fluid
.
ParamAttr
(
name
=
'emb_w'
,
initializer
=
uniform_initializer
(
1.0
)))
s
ample_log_probs
=
fluid
.
layers
.
stack
(
sample_log_probs
)
self
.
sample_log_probs
=
fluid
.
layers
.
reduce_sum
(
sample_log_probs
)
s
elf
.
sample_log_probs
=
fluid
.
layers
.
concat
(
sample_log_probs
,
axis
=
0
)
entropies
=
fluid
.
layers
.
stack
(
entropies
)
self
.
sample_entropies
=
fluid
.
layers
.
reduce_sum
(
entropies
)
...
...
@@ -196,7 +200,9 @@ class LSTM(RLBaseController):
self
.
loss
=
self
.
sample_log_probs
*
(
self
.
rewards
-
self
.
baseline
)
fluid
.
clip
.
set_gradient_clip
(
clip
=
fluid
.
clip
.
GradientClipByGlobalNorm
(
clip_norm
=
5.0
))
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
1e-3
)
lr
=
fluid
.
layers
.
exponential_decay
(
self
.
controller_lr
,
decay_steps
=
1000
,
decay_rate
=
0.8
)
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
lr
)
optimizer
.
minimize
(
self
.
loss
)
def
_create_input
(
self
,
is_test
=
True
,
actual_rewards
=
None
):
...
...
paddleslim/common/server.py
浏览文件 @
60544b52
...
...
@@ -161,6 +161,8 @@ class Server(object):
if
len
(
self
.
_client
)
==
len
(
self
.
_client_dict
.
items
()):
self
.
_done
=
True
self
.
_params_dict
=
sum_params_dict
del
sum_params_dict
self
.
_server_socket
.
send_multipart
([
pickle
.
dumps
(
ConnectMessage
.
WAIT
),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录