Commit 9fc66d34 authored by: C Channingss

add _rename_or_remove_weight

Parent ac9c0ead
@@ -42,6 +42,15 @@ def _const_weight_or_none(node, necessary=False):
     return None
+def _rename_or_remove_weight(weights, origin_name, target_name=None):
+    if origin_name not in weights:
+        raise KeyError('{} not a key in {}'.format(origin_name, weights))
+    if target_name is None:
+        # remove weight
+        weights.pop(origin_name)
+    else:
+        # rename weight
+        weights[target_name] = weights.pop(origin_name)
 def _is_static_shape(shape):
     negtive_dims = 0
     error_dims = 0
@@ -1320,10 +1329,16 @@ class OpSet9():
         epsilon = node.get_attr('epsilon', 1e-5)
         c = val_x.out_shapes[0][1]
-        self.weights[op_name + '.weight'] = self.weights[val_scale.name]
-        self.weights[op_name + '.bias'] = self.weights[val_b.name]
-        self.weights[op_name + '._variance'] = self.weights[val_var.name]
-        self.weights[op_name + '._mean'] = self.weights[val_mean.name]
+        _rename_or_remove_weight(self.weights, val_scale.name, op_name+'.weight')
+        _rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias')
+        _rename_or_remove_weight(self.weights, val_var.name, op_name+'._variance')
+        _rename_or_remove_weight(self.weights, val_mean.name, op_name+'._mean')
+        #self.weights[op_name + '.weight'] = self.weights[val_scale.name]
+        #self.weights[op_name + '.bias'] = self.weights[val_b.name]
+        #self.weights[op_name + '._variance'] = self.weights[val_var.name]
+        #self.weights[op_name + '._mean'] = self.weights[val_mean.name]
         # Attribute: spatial is used in BatchNormalization-1,6,7
         spatial = bool(node.get_attr('spatial'))
         layer_attrs = {
@@ -1395,11 +1410,13 @@ class OpSet9():
         else:
             if mode == 'channel':
                 slope_data = _const_weight_or_none(val_slope)
+                _rename_or_remove_weight(self.weights, val_slope.name)
                 if len(shape_slope) > 1:
                     self.weights[op_name+'._weight'] = np.reshape(slope_data, shape_slope[0])
                 num_parameters = val_x.out_shapes[0][1]
             else:
                 num_parameters = 1
+                _rename_or_remove_weight(self.weights, val_slope.name)
                 self.weights[op_name+'._weight'] = np.reshape(self.weights[val_slope.name], [1])
             self.paddle_graph.add_layer(
                 "paddle.nn.PReLU",
@@ -1687,13 +1704,15 @@ class OpSet9():
         while val_w_name in self.done_weight_list:
             val_w_name += "__repeat"
         self.done_weight_list.append(val_w_name)
-        self.weights[op_name + '.weight'] = self.weights[val_w.name]
+        #self.weights[op_name + '.weight'] = self.weights[val_w.name]
+        _rename_or_remove_weight(self.weights, val_w.name, op_name+'.weight')
         if has_bias:
             val_b_name = val_b.name
             while val_b_name in self.done_weight_list:
                 val_b_name += "__repeat"
             self.done_weight_list.append(val_b_name)
-            self.weights[op_name + '.bias'] = self.weights[val_b.name]
+            #self.weights[op_name + '.bias'] = self.weights[val_b.name]
+            _rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias')
         else:
             layer_attrs["bias_attr"] = False
         input_shape = val_x.out_shapes[0]
@@ -1761,9 +1780,11 @@ class OpSet9():
             "groups": num_groups,
             "output_padding":out_padding}
-        self.weights[op_name + '.weight'] = self.weights[val_w.name]
+        #self.weights[op_name + '.weight'] = self.weights[val_w.name]
+        _rename_or_remove_weight(self.weights, val_w.name, op_name+'.weight',)
         if val_b is not None:
-            self.weights[op_name + '.bias'] = self.weights[val_b.name]
+            #self.weights[op_name + '.bias'] = self.weights[val_b.name]
+            _rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias')
         self.paddle_graph.add_layer(
             kernel=paddle_op,
             inputs=inputs_dict,
@@ -1881,10 +1902,13 @@ class OpSet9():
         )
         input_weight_np = _const_weight_or_none(input_weight)
+        _rename_or_remove_weight(self.weights, input_weight.name)
         hidden_size = node.get_attr('hidden_size', input_weight_np.shape[1]/4)
         input_size = input_weight_np.shape[2]
         hidden_weight_np = _const_weight_or_none(hidden_weight)
+        _rename_or_remove_weight(self.weights, hidden_weight.name)
         bias_np = _const_weight_or_none(bias)
+        _rename_or_remove_weight(self.weights, bias.name)
         input_bias_np = bias_np[:, :4*hidden_size]
         hidden_bias_np = bias_np[:, 4*hidden_size:]
......
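
A quick standalone sketch of how the new helper behaves on a plain weights dict; the initializer names and the Paddle layer name below are made up for illustration and are not from the diff:

```python
import numpy as np

def _rename_or_remove_weight(weights, origin_name, target_name=None):
    # Same behavior as the helper added above: rename when a target name is
    # given, otherwise drop the weight.
    if origin_name not in weights:
        raise KeyError('{} not a key in {}'.format(origin_name, weights))
    if target_name is None:
        # remove weight
        weights.pop(origin_name)
    else:
        # rename weight
        weights[target_name] = weights.pop(origin_name)

# Hypothetical ONNX initializer names and Paddle layer name, for illustration.
weights = {
    'bn0_scale': np.ones([4], dtype='float32'),
    'bn0_bias': np.zeros([4], dtype='float32'),
}
_rename_or_remove_weight(weights, 'bn0_scale', 'batchnorm0.weight')  # rename
_rename_or_remove_weight(weights, 'bn0_bias')                        # remove
print(sorted(weights))  # ['batchnorm0.weight']
```

Because renaming pops the original key, a converted parameter ends up stored only under Paddle's naming scheme instead of being kept under both the ONNX and Paddle names.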