From f881efb15c243e923049be64dfe037f35b310052 Mon Sep 17 00:00:00 2001
From: ceci3
Date: Mon, 18 May 2020 11:30:11 +0800
Subject: [PATCH] update 1.8 rlnas (#294) (#295)

* fix

* update
---
 demo/nas/rl_nas_mobilenetv2.py                     |  7 ++++++-
 docs/zh_cn/api_cn/nas_api.rst                      |  3 +++
 paddleslim/common/rl_controller/__init__.py        |  2 +-
 .../common/rl_controller/lstm/lstm_controller.py   | 16 +++++++++++-----
 4 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/demo/nas/rl_nas_mobilenetv2.py b/demo/nas/rl_nas_mobilenetv2.py
index abf23fb9..3eb56233 100644
--- a/demo/nas/rl_nas_mobilenetv2.py
+++ b/demo/nas/rl_nas_mobilenetv2.py
@@ -65,8 +65,10 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
             is_sync=False,
             server_addr=(args.server_address, args.port),
             controller_batch_size=1,
+            controller_decay_steps=1000,
+            controller_decay_rate=0.8,
             lstm_num_layers=1,
-            hidden_size=100,
+            hidden_size=10,
             temperature=1.0)
     else:
         ### start a client
@@ -78,6 +80,9 @@ def search_mobilenetv2(config, args, image_size, is_server=True):
             lstm_num_layers=1,
             hidden_size=10,
             temperature=1.0,
+            controller_batch_size=1,
+            controller_decay_steps=1000,
+            controller_decay_rate=0.8,
             is_server=False)
 
     image_shape = [3, image_size, image_size]
diff --git a/docs/zh_cn/api_cn/nas_api.rst b/docs/zh_cn/api_cn/nas_api.rst
index d906f20d..9cb0938f 100644
--- a/docs/zh_cn/api_cn/nas_api.rst
+++ b/docs/zh_cn/api_cn/nas_api.rst
@@ -194,6 +194,9 @@ RLNAS (Reinforcement Learning Neural Architecture Search)是基于强化学习
   - decay(float, optional): LSTM中记录rewards的baseline的平滑率。默认:0.99.
   - weight_entropy(float, optional): 在更新controller参数时是否为接收到的rewards加上计算token过程中的带权重的交叉熵值。默认:None。
   - controller_batch_size(int, optional): controller的batch_size,即每运行一次controller可以拿到几组token。默认:1.
+  - controller_lr(float, optional): controller的学习率,默认:1e-4。
+  - controller_decay_steps(int, optional): controller学习率下降步长,设置为None的时候学习率不下降。默认:None。
+  - controller_decay_rate(float, optional): controller学习率衰减率,默认:None。
 
 - **`DDPG`算法的附加参数:**
 
diff --git a/paddleslim/common/rl_controller/__init__.py b/paddleslim/common/rl_controller/__init__.py
index 3dcf14e0..da815888 100644
--- a/paddleslim/common/rl_controller/__init__.py
+++ b/paddleslim/common/rl_controller/__init__.py
@@ -20,7 +20,7 @@ try:
     from .ddpg import *
 except ImportError as e:
     _logger.warn(
-        "If you want to use DDPG in RLNAS, please pip intall parl first. Now states: {}".
+        "If you want to use DDPG in RLNAS, please pip install parl first. Now states: {}".
         format(e))
 
 from .lstm import *
diff --git a/paddleslim/common/rl_controller/lstm/lstm_controller.py b/paddleslim/common/rl_controller/lstm/lstm_controller.py
index 0e32be6d..920b29ea 100644
--- a/paddleslim/common/rl_controller/lstm/lstm_controller.py
+++ b/paddleslim/common/rl_controller/lstm/lstm_controller.py
@@ -63,6 +63,8 @@ class LSTM(RLBaseController):
         self.hidden_size = kwargs.get('hidden_size') or 100
         self.temperature = kwargs.get('temperature') or None
         self.controller_lr = kwargs.get('controller_lr') or 1e-4
+        self.decay_steps = kwargs.get('controller_decay_steps') or None
+        self.decay_rate = kwargs.get('controller_decay_rate') or None
         self.tanh_constant = kwargs.get('tanh_constant') or None
         self.decay = kwargs.get('decay') or 0.99
         self.weight_entropy = kwargs.get('weight_entropy') or None
@@ -198,11 +200,15 @@ class LSTM(RLBaseController):
         fluid.layers.assign(self.baseline - (1.0 - self.decay) *
                             (self.baseline - self.rewards), self.baseline)
         self.loss = self.sample_log_probs * (self.rewards - self.baseline)
-        fluid.clip.set_gradient_clip(
-            clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))
-        lr = fluid.layers.exponential_decay(
-            self.controller_lr, decay_steps=1000, decay_rate=0.8)
-        optimizer = fluid.optimizer.Adam(learning_rate=lr)
+        clip = fluid.clip.GradientClipByNorm(clip_norm=5.0)
+        if self.decay_steps is not None:
+            lr = fluid.layers.exponential_decay(
+                self.controller_lr,
+                decay_steps=self.decay_steps,
+                decay_rate=self.decay_rate)
+        else:
+            lr = self.controller_lr
+        optimizer = fluid.optimizer.Adam(learning_rate=lr, grad_clip=clip)
         optimizer.minimize(self.loss)
 
     def _create_input(self, is_test=True, actual_rewards=None):
-- 
GitLab
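
The net effect of the lstm_controller.py hunk is that learning-rate decay for the controller becomes opt-in: when controller_decay_steps is left unset, Adam keeps the constant controller_lr; when it is set, the rate follows an exponential schedule. Below is a minimal, self-contained sketch of that schedule, assuming the standard semantics of fluid.layers.exponential_decay with staircase off (lr * decay_rate ** (step / decay_steps)) and the values used in the demo diff; the helper function and the printed steps are illustrative only, not part of PaddleSlim.

```python
# Hypothetical helper mirroring the controller learning-rate logic in the patch.
def controller_lr_at_step(step,
                          controller_lr=1e-4,
                          decay_steps=None,
                          decay_rate=None):
    """Return the controller learning rate at a given training step."""
    if decay_steps is None:
        # Patched behavior: no controller_decay_steps configured -> constant rate.
        return controller_lr
    # Exponential schedule, as in fluid.layers.exponential_decay (staircase=False).
    return controller_lr * decay_rate ** (step / decay_steps)


if __name__ == "__main__":
    # Values from the demo diff: controller_decay_steps=1000, controller_decay_rate=0.8.
    for step in (0, 1000, 5000):
        print(step, controller_lr_at_step(step, decay_steps=1000, decay_rate=0.8))
    # Without decay_steps the rate stays at the controller_lr default of 1e-4.
    print(controller_lr_at_step(5000))
```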