未验证 提交 e020c22f 编写于 作者: C ceci3 提交者: GitHub

fix condition of layer_forward in ofa (#777) (#782)

上级 bb224bf7
...@@ -225,7 +225,7 @@ def check_search_space(graph): ...@@ -225,7 +225,7 @@ def check_search_space(graph):
depthwise_conv.append(inp._var.name) depthwise_conv.append(inp._var.name)
if len(same_search_space) == 0: if len(same_search_space) == 0:
return None, None return None, []
same_search_space = sorted([sorted(x) for x in same_search_space]) same_search_space = sorted([sorted(x) for x in same_search_space])
final_search_space = [] final_search_space = []
......
...@@ -108,9 +108,9 @@ class OFABase(Layer): ...@@ -108,9 +108,9 @@ class OFABase(Layer):
if getattr(self, 'current_config', None) != None: if getattr(self, 'current_config', None) != None:
### if block is fixed, donnot join key into candidate ### if block is fixed, donnot join key into candidate
### concrete config as parameter in kwargs ### concrete config as parameter in kwargs
if block.fixed == False and ( if block.fixed == False and (self._skip_layers == None or
self._skip_layers != None and (self._skip_layers != None and
self._key2name[block.key] not in self._skip_layers) and \ self._key2name[block.key] not in self._skip_layers)) and \
(block.fn.weight.name not in self._depthwise_conv): (block.fn.weight.name not in self._depthwise_conv):
assert self._key2name[ assert self._key2name[
block. block.
...@@ -180,6 +180,7 @@ class OFA(OFABase): ...@@ -180,6 +180,7 @@ class OFA(OFABase):
self._build_ss = False self._build_ss = False
self._broadcast = False self._broadcast = False
self._skip_layers = None self._skip_layers = None
self._depthwise_conv = []
### if elastic_order is none, use default order ### if elastic_order is none, use default order
if self.elastic_order is not None: if self.elastic_order is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册