提交 ed6400e1 编写于 作者: C Channingss

add _rename_or_remove_weight

上级 89b96861
......@@ -42,14 +42,17 @@ def _const_weight_or_none(node, necessary=False):
return None
def _rename_or_remove_weight(weights, origin_name, target_name=None):
def _rename_or_remove_weight(weights, origin_name, target_name=None, is_remove=True):
if origin_name not in weights:
raise KeyError('{} not a key in {}'.format(origin_name, weights))
if target_name is None:
if is_remove:
# remove weight
weights.pop(origin_name)
data = weights.pop(origin_name)
else:
data = weights[origin_name]
if target_name is not None:
# rename weight
weights[target_name] = weights.pop(origin_name)
weights[target_name] = data
def _is_static_shape(shape):
negtive_dims = 0
......@@ -1700,19 +1703,17 @@ class OpSet9():
"dilation": dilations,
"groups": num_groups,
}
val_w_name = val_w.name
while val_w_name in self.done_weight_list:
val_w_name += "__repeat"
self.done_weight_list.append(val_w_name)
remove_weight = True if val_w.name in self.done_weight_list else False
if remove_weight:
self.done_weight_list.append(val_w.name)
#self.weights[op_name + '.weight'] = self.weights[val_w.name]
_rename_or_remove_weight(self.weights, val_w.name, op_name+'.weight')
_rename_or_remove_weight(self.weights, val_w.name, op_name+'.weight', remove_weight)
if has_bias:
val_b_name = val_b.name
while val_b_name in self.done_weight_list:
val_b_name += "__repeat"
remove_bias = True if val_b.name in self.done_weight_list else False
if remove_bias:
self.done_weight_list.append(val_b_name)
#self.weights[op_name + '.bias'] = self.weights[val_b.name]
_rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias')
_rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias', remove_bias)
else:
layer_attrs["bias_attr"] = False
input_shape = val_x.out_shapes[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册