未验证 提交 f881efb1 编写于 作者: C ceci3 提交者: GitHub

update 1.8 rlnas (#294) (#295)

* fix

* update
上级 2ee81c79
......@@ -65,8 +65,10 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
is_sync=False,
server_addr=(args.server_address, args.port),
controller_batch_size=1,
controller_decay_steps=1000,
controller_decay_rate=0.8,
lstm_num_layers=1,
hidden_size=100,
hidden_size=10,
temperature=1.0)
else:
### start a client
......@@ -78,6 +80,9 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
lstm_num_layers=1,
hidden_size=10,
temperature=1.0,
controller_batch_size=1,
controller_decay_steps=1000,
controller_decay_rate=0.8,
is_server=False)
image_shape = [3, image_size, image_size]
......
......@@ -194,6 +194,9 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
- decay(float, optional): LSTM中记录rewards的baseline的平滑率。默认:0.99.
- weight_entropy(float, optional): 在更新controller参数时是否为接收到的rewards加上计算token过程中的带权重的交叉熵值。默认:None。
- controller_batch_size(int, optional): controller的batch_size,即每运行一次controller可以拿到几组token。默认:1.
- controller_lr(float, optional): controller的学习率,默认:1e-4。
- controller_decay_steps(int, optional): controller学习率下降步长,设置为None的时候学习率不下降。默认:None。
- controller_decay_rate(float, optional): controller学习率衰减率,默认:None。
- **`DDPG`算法的附加参数:**
......
......@@ -20,7 +20,7 @@ try:
from .ddpg import *
except ImportError as e:
_logger.warn(
"If you want to use DDPG in RLNAS, please pip intall parl first. Now states: {}".
"If you want to use DDPG in RLNAS, please pip install parl first. Now states: {}".
format(e))
from .lstm import *
......
......@@ -63,6 +63,8 @@ class LSTM(RLBaseController):
self.hidden_size = kwargs.get('hidden_size') or 100
self.temperature = kwargs.get('temperature') or None
self.controller_lr = kwargs.get('controller_lr') or 1e-4
self.decay_steps = kwargs.get('controller_decay_steps') or None
self.decay_rate = kwargs.get('controller_decay_rate') or None
self.tanh_constant = kwargs.get('tanh_constant') or None
self.decay = kwargs.get('decay') or 0.99
self.weight_entropy = kwargs.get('weight_entropy') or None
......@@ -198,11 +200,15 @@ class LSTM(RLBaseController):
fluid.layers.assign(self.baseline - (1.0 - self.decay) *
(self.baseline - self.rewards), self.baseline)
self.loss = self.sample_log_probs * (self.rewards - self.baseline)
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))
lr = fluid.layers.exponential_decay(
self.controller_lr, decay_steps=1000, decay_rate=0.8)
optimizer = fluid.optimizer.Adam(learning_rate=lr)
clip = fluid.clip.GradientClipByNorm(clip_norm=5.0)
if self.decay_steps is not None:
lr = fluid.layers.exponential_decay(
self.controller_lr,
decay_steps=self.decay_steps,
decay_rate=self.decay_rate)
else:
lr = self.controller_lr
optimizer = fluid.optimizer.Adam(learning_rate=lr, grad_clip=clip)
optimizer.minimize(self.loss)
def _create_input(self, is_test=True, actual_rewards=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册