提交 ed6400e1 编写于 作者: C Channingss

add _rename_or_remove_weight

上级 89b96861
...@@ -42,14 +42,17 @@ def _const_weight_or_none(node, necessary=False): ...@@ -42,14 +42,17 @@ def _const_weight_or_none(node, necessary=False):
return None return None
def _rename_or_remove_weight(weights, origin_name, target_name=None): def _rename_or_remove_weight(weights, origin_name, target_name=None, is_remove=True):
if origin_name not in weights: if origin_name not in weights:
raise KeyError('{} not a key in {}'.format(origin_name, weights)) raise KeyError('{} not a key in {}'.format(origin_name, weights))
if target_name is None: if is_remove:
# remove weight # remove weight
weights.pop(origin_name) data = weights.pop(origin_name)
# rename weight else:
weights[target_name] = weights.pop(origin_name) data = weights[origin_name]
if target_name is not None:
# rename weight
weights[target_name] = data
def _is_static_shape(shape): def _is_static_shape(shape):
negtive_dims = 0 negtive_dims = 0
...@@ -1700,19 +1703,17 @@ class OpSet9(): ...@@ -1700,19 +1703,17 @@ class OpSet9():
"dilation": dilations, "dilation": dilations,
"groups": num_groups, "groups": num_groups,
} }
val_w_name = val_w.name remove_weight = True if val_w.name in self.done_weight_list else False
while val_w_name in self.done_weight_list: if remove_weight:
val_w_name += "__repeat" self.done_weight_list.append(val_w.name)
self.done_weight_list.append(val_w_name)
#self.weights[op_name + '.weight'] = self.weights[val_w.name] #self.weights[op_name + '.weight'] = self.weights[val_w.name]
_rename_or_remove_weight(self.weights, val_w.name, op_name+'.weight') _rename_or_remove_weight(self.weights, val_w.name, op_name+'.weight', remove_weight)
if has_bias: if has_bias:
val_b_name = val_b.name remove_bias = True if val_b.name in self.done_weight_list else False
while val_b_name in self.done_weight_list: if remove_bias:
val_b_name += "__repeat" self.done_weight_list.append(val_b_name)
self.done_weight_list.append(val_b_name)
#self.weights[op_name + '.bias'] = self.weights[val_b.name] #self.weights[op_name + '.bias'] = self.weights[val_b.name]
_rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias') _rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias', remove_bias)
else: else:
layer_attrs["bias_attr"] = False layer_attrs["bias_attr"] = False
input_shape = val_x.out_shapes[0] input_shape = val_x.out_shapes[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册