diff --git a/ppcls/arch/backbone/legendary_models/hrnet.py b/ppcls/arch/backbone/legendary_models/hrnet.py index 0a3f3b1afaf612af92929576c2d9737e56c14d88..05ebf525d955ed1c78474f823e9adec93f3cad7c 100644 --- a/ppcls/arch/backbone/legendary_models/hrnet.py +++ b/ppcls/arch/backbone/legendary_models/hrnet.py @@ -221,26 +221,22 @@ class Stage(TheseusLayer): self._num_modules = num_modules - self.stage_func_list = [] + self.stage_func_list = nn.LayerList() for i in range(num_modules): if i == num_modules - 1 and not multi_scale_output: - stage_func = self.add_sublayer( - "stage_{}_{}".format(name, i + 1), + self.stage_func_list.append( HighResolutionModule( num_filters=num_filters, has_se=has_se, multi_scale_output=False, name=name + '_' + str(i + 1))) else: - stage_func = self.add_sublayer( - "stage_{}_{}".format(name, i + 1), + self.stage_func_list.append( HighResolutionModule( num_filters=num_filters, has_se=has_se, name=name + '_' + str(i + 1))) - self.stage_func_list.append(stage_func) - def forward(self, input, res_dict=None): out = input for idx in range(self._num_modules):