From e020c22f0b9e4897d6d2871330ee50c7b0346a27 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Thu, 27 May 2021 12:53:58 +0800 Subject: [PATCH] fix condition of layer_forward in ofa (#777) (#782) --- paddleslim/nas/ofa/get_sub_model.py | 2 +- paddleslim/nas/ofa/ofa.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/paddleslim/nas/ofa/get_sub_model.py b/paddleslim/nas/ofa/get_sub_model.py index d09ab878..66fdd2c0 100644 --- a/paddleslim/nas/ofa/get_sub_model.py +++ b/paddleslim/nas/ofa/get_sub_model.py @@ -225,7 +225,7 @@ def check_search_space(graph): depthwise_conv.append(inp._var.name) if len(same_search_space) == 0: - return None, None + return None, [] same_search_space = sorted([sorted(x) for x in same_search_space]) final_search_space = [] diff --git a/paddleslim/nas/ofa/ofa.py b/paddleslim/nas/ofa/ofa.py index 28c250e1..477f815a 100644 --- a/paddleslim/nas/ofa/ofa.py +++ b/paddleslim/nas/ofa/ofa.py @@ -108,9 +108,9 @@ class OFABase(Layer): if getattr(self, 'current_config', None) != None: ### if block is fixed, donnot join key into candidate ### concrete config as parameter in kwargs - if block.fixed == False and ( - self._skip_layers != None and - self._key2name[block.key] not in self._skip_layers) and \ + if block.fixed == False and (self._skip_layers == None or + (self._skip_layers != None and + self._key2name[block.key] not in self._skip_layers)) and \ (block.fn.weight.name not in self._depthwise_conv): assert self._key2name[ block. @@ -180,6 +180,7 @@ class OFA(OFABase): self._build_ss = False self._broadcast = False self._skip_layers = None + self._depthwise_conv = [] ### if elastic_order is none, use default order if self.elastic_order is not None: -- GitLab