提交 47461da0 编写于 作者: W wanghaoshuang

Merge branch 'fix_py3' into 'develop'

fix for py3

See merge request !75
......@@ -11,13 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flops as flops_module
from flops import *
import model_size as model_size_module
from model_size import *
import latency as latency_module
from latency import *
__all__ = []
__all__ += flops_module.__all__
__all__ += model_size_module.__all__
__all__ += latency_module.__all__
from .flops import flops
from .model_size import model_size
from .latency import LatencyEvaluator, TableLatencyEvaluator
__all__ = ['flops', 'model_size', 'LatencyEvaluator', 'TableLatencyEvaluator']
......@@ -11,25 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import controller
from controller import *
import sa_controller
from sa_controller import *
import log_helper
from log_helper import *
import controller_server
from controller_server import *
import controller_client
from controller_client import *
import lock_utils
from lock_utils import *
import cached_reader as cached_reader_module
from cached_reader import *
from .controller import EvolutionaryController
from .sa_controller import SAController
from .log_helper import get_logger
from .controller_server import ControllerServer
from .controller_client import ControllerClient
from .lock_utils import lock, unlock
from .cached_reader import cached_reader
__all__ = []
__all__ += controller.__all__
__all__ += sa_controller.__all__
__all__ += controller_server.__all__
__all__ += controller_client.__all__
__all__ += lock_utils.__all__
__all__ += cached_reader_module.__all__
__all__ = [
'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer',
'ControllerClient', 'lock', 'unlock', 'cached_reader'
]
......@@ -14,7 +14,7 @@
import logging
import socket
from log_helper import get_logger
from .log_helper import get_logger
__all__ = ['ControllerClient']
......
......@@ -107,7 +107,7 @@ class ControllerServer(object):
_logger.debug("send message to {}: [{}]".format(addr,
tokens))
conn.close()
except Exception, err:
except Exception as err:
_logger.error(err)
finally:
self._socket_server.close()
......
......@@ -20,7 +20,7 @@ import logging
import numpy as np
import json
from .controller import EvolutionaryController
from log_helper import get_logger
from .log_helper import get_logger
__all__ = ["SAController"]
......
......@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import graph_wrapper
from .graph_wrapper import *
from . import registry
from .registry import *
from .graph_wrapper import GraphWrapper, VarWrapper, OpWrapper
from .registry import Registry
__all__ = graph_wrapper.__all__
__all__ += registry.__all__
__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper', 'Registry']
......@@ -12,11 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import search_space
from search_space import *
import sa_nas
from sa_nas import *
from .search_space import *
from .sa_nas import SANAS
__all__ = []
__all__ += search_space.__all__
__all__ += sa_nas.__all__
__all__ = ['SANAS']
......@@ -64,7 +64,7 @@ class SANAS(object):
self._init_temperature = init_temperature
self._is_server = is_server
self._configs = configs
self._key = hashlib.md5(str(self._configs)).hexdigest()
self._key = hashlib.md5(str(self._configs).encode("utf-8")).hexdigest()
server_ip, server_port = server_addr
if server_ip == None or server_ip == "":
......
......@@ -12,27 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import mobilenetv2
from .mobilenetv2 import *
import mobilenet_block
from .mobilenet_block import *
import mobilenetv1
from .mobilenetv1 import *
import resnet
from .resnet import *
import resnet_block
from .resnet_block import *
import inception_block
from .inception_block import *
import search_space_registry
from search_space_registry import *
import search_space_factory
from search_space_factory import *
import search_space_base
from search_space_base import *
from .mobilenetv2 import MobileNetV2Space
from .mobilenetv1 import MobileNetV1Space
from .resnet import ResNetSpace
from .mobilenet_block import MobileNetV1BlockSpace, MobileNetV2BlockSpace
from .resnet_block import ResNetBlockSpace
from .inception_block import InceptionABlockSpace, InceptionCBlockSpace
from .search_space_registry import SEARCHSPACE
from .search_space_factory import SearchSpaceFactory
from .search_space_base import SearchSpaceBase
__all__ = []
__all__ += mobilenetv2.__all__
__all__ += search_space_registry.__all__
__all__ += search_space_factory.__all__
__all__ += search_space_base.__all__
__all__ = [
'MobileNetV1Space', 'MobileNetV2Space', 'ResNetSpace',
'MobileNetV1BlockSpace', 'MobileNetV2BlockSpace', 'ResNetBlockSpace',
'InceptionABlockSpace', 'InceptionCBlockSpace', 'SearchSpaceBase',
'SearchSpaceFactory', 'SEARCHSPACE'
]
......@@ -28,4 +28,4 @@ class SearchSpaceFactory(object):
"""
assert isinstance(config_lists, list), "configs must be a list"
return CombineSearchSpace(config_lists)
return CombineSearchSpace(config_lists)
......@@ -11,23 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pruner
from pruner import *
import auto_pruner
from auto_pruner import *
import controller_server
from controller_server import *
import controller_client
from controller_client import *
import sensitive_pruner
from sensitive_pruner import *
import sensitive
from sensitive import *
from .pruner import Pruner
from .auto_pruner import AutoPruner
from .controller_server import ControllerServer
from .controller_client import ControllerClient
from .sensitive_pruner import SensitivePruner
from .sensitive import sensitivity, flops_sensitivity
__all__ = []
__all__ += pruner.__all__
__all__ += auto_pruner.__all__
__all__ += controller_server.__all__
__all__ += controller_client.__all__
__all__ += sensitive_pruner.__all__
__all__ += sensitive.__all__
__all__ = [
'Pruner', 'AutoPruner', 'ControllerServer', 'ControllerClient',
'SensitivePruner', 'sensitivity', 'flops_sensitivity'
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册