diff --git a/paddleslim/common/rl_controller/__init__.py b/paddleslim/common/rl_controller/__init__.py index da8158882bd7764b57247a522e2bfb38c4960d7c..b29c563781d1867c38e071f8996835438e9c7fad 100644 --- a/paddleslim/common/rl_controller/__init__.py +++ b/paddleslim/common/rl_controller/__init__.py @@ -19,9 +19,7 @@ try: import parl from .ddpg import * except ImportError as e: - _logger.warn( - "If you want to use DDPG in RLNAS, please pip install parl first. Now states: {}". - format(e)) + pass from .lstm import * from .utils import * diff --git a/paddleslim/nas/rl_nas.py b/paddleslim/nas/rl_nas.py index a7fa6591c06385ca8de33131f5c30559617e224a..1718b8347def79a222c7dc016b2a2d021d380f36 100644 --- a/paddleslim/nas/rl_nas.py +++ b/paddleslim/nas/rl_nas.py @@ -76,6 +76,15 @@ class RLNAS(object): self.save_controller = save_controller self.load_controller = load_controller + if key.upper() in ['DDPG']: + try: + import parl + except ImportError as e: + _logger.error( + "If you want to use DDPG in RLNAS, please pip install parl first. Now states: {}". + format(e)) + os._exit(1) + cls = RLCONTROLLER.get(key.upper()) server_ip, server_port = server_addr