未验证 提交 689bddf4 编写于 作者: C ceci3 提交者: GitHub

Fix lstm windows (#255)

* fix

* fix win
上级 9a9cf241
......@@ -177,7 +177,7 @@ def test_search_result(tokens, image_size, args, config):
image_shape = [3, image_size, image_size]
archs = sa_nas.tokens2arch(tokens)
archs = sa_nas.tokens2arch(tokens)[0]
train_program = fluid.Program()
test_program = fluid.Program()
......
......@@ -3,8 +3,8 @@
首先导入必要的依赖:
```python
### 引入强化学习Controller基类函数和注册类函数
from paddleslim.common.RL_controller.utils import RLCONTROLLER
from paddleslim.common.RL_controller import RLBaseController
from paddleslim.common.rl_controller.utils import RLCONTROLLER
from paddleslim.common.rl_controller import RLBaseController
```
通过装饰器的方式把自定义强化学习Controller注册到PaddleSlim,继承基类之后需要重写基类中的`next_tokens``update`两个函数。注意:本示例仅说明一些必不可少的步骤,并不能直接运行,完整代码请参考[这里]()
......
......@@ -25,7 +25,7 @@ if six.PY2:
else:
import pickle
from .log_helper import get_logger
from .RL_controller.utils import compute_grad, ConnectMessage
from .rl_controller.utils import compute_grad, ConnectMessage
_logger = get_logger(__name__, level=logging.INFO)
......
......@@ -17,11 +17,11 @@ from ..log_helper import get_logger
_logger = get_logger(__name__, level=logging.INFO)
try:
import parl
from .DDPG import *
from .ddpg import *
except ImportError as e:
_logger.warn(
"If you want to use DDPG in RLNAS, please pip intall parl first. Now states: {}".
format(e))
from .LSTM import *
from .lstm import *
from .utils import *
......@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .LSTM_Controller import *
from .ddpg_controller import *
......@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .DDPGController import *
from .lstm_controller import *
......@@ -224,7 +224,8 @@ class LSTM(RLBaseController):
actual_rewards = actual_rewards.astype(np.float32)
feed_dict['rewards'] = actual_rewards
feed_dict['init_actions'] = np.array(self.init_tokens)
feed_dict['init_actions'] = np.array(self.init_tokens).astype(
'int64')
return feed_dict
......
......@@ -25,7 +25,7 @@ import logging
import time
import threading
from .log_helper import get_logger
from .RL_controller.utils import add_grad, ConnectMessage
from .rl_controller.utils import add_grad, ConnectMessage
_logger = get_logger(__name__, level=logging.INFO)
......
......@@ -20,7 +20,7 @@ import json
import hashlib
import time
import paddle.fluid as fluid
from ..common.RL_controller.utils import RLCONTROLLER
from ..common.rl_controller.utils import RLCONTROLLER
from ..common import get_logger
from ..common import Server
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册