From 9fc66d34cabf7d5052ac719a5751e5562572af0e Mon Sep 17 00:00:00 2001 From: Channingss Date: Thu, 14 Jan 2021 09:30:50 +0000 Subject: [PATCH] add _rename_or_remove_weight --- .../dygraph/onnx2paddle/opset9/opset.py | 40 +++++++++++++++---- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py index f87ea2c..5d7c516 100644 --- a/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py @@ -42,6 +42,15 @@ def _const_weight_or_none(node, necessary=False): return None +def rename_or_remove_weight(weights, origin_name, target_name=None): + if origin_name not in weights: + raise KeyError('{} not a key in {}'.format(origin_name, weights)) + if target_name is None: + # remove weight + weights.pop(origin_name) + # rename weight + weights[target_name] = weights.pop(origin_name) + def _is_static_shape(shape): negtive_dims = 0 error_dims = 0 @@ -1320,10 +1329,16 @@ class OpSet9(): epsilon = node.get_attr('epsilon', 1e-5) c = val_x.out_shapes[0][1] - self.weights[op_name + '.weight'] = self.weights[val_scale.name] - self.weights[op_name + '.bias'] = self.weights[val_b.name] - self.weights[op_name + '._variance'] = self.weights[val_var.name] - self.weights[op_name + '._mean'] = self.weights[val_mean.name] + _rename_or_remove_weight(self.weights, val_scale.name, op_name+'.weight') + _rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias') + _rename_or_remove_weight(self.weights, val_var.name, op_name+'._variance') + _rename_or_remove_weight(self.weights, val_mean.name, op_name+'._mean') + + #self.weights[op_name + '.weight'] = self.weights[val_scale.name] + #self.weights[op_name + '.bias'] = self.weights[val_b.name] + #self.weights[op_name + '._variance'] = self.weights[val_var.name] + #self.weights[op_name + '._mean'] = self.weights[val_mean.name] + # Attribute: spatial is used in BatchNormalization-1,6,7 spatial = bool(node.get_attr('spatial')) layer_attrs = { @@ -1395,11 +1410,13 @@ class OpSet9(): else: if mode == 'channel': slope_data = _const_weight_or_none(val_slope) + _rename_or_remove_weight(self.weights, val_slope.name) if len(shape_slope) > 1: self.weights[op_name+'._weight'] = np.reshape(slope_data, shape_slope[0]) num_parameters = val_x.out_shapes[0][1] else: num_parameters = 1 + _rename_or_remove_weight(self.weights, val_slope.name) self.weights[op_name+'._weight'] = np.reshape(self.weights[val_slope.name], [1]) self.paddle_graph.add_layer( "paddle.nn.PReLU", @@ -1687,13 +1704,15 @@ class OpSet9(): while val_w_name in self.done_weight_list: val_w_name += "__repeat" self.done_weight_list.append(val_w_name) - self.weights[op_name + '.weight'] = self.weights[val_w.name] + #self.weights[op_name + '.weight'] = self.weights[val_w.name] + _rename_or_remove_weight(self.weights, val_w.name, op_name+'.weight') if has_bias: val_b_name = val_b.name while val_b_name in self.done_weight_list: val_b_name += "__repeat" self.done_weight_list.append(val_b_name) - self.weights[op_name + '.bias'] = self.weights[val_b.name] + #self.weights[op_name + '.bias'] = self.weights[val_b.name] + _rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias') else: layer_attrs["bias_attr"] = False input_shape = val_x.out_shapes[0] @@ -1761,9 +1780,11 @@ class OpSet9(): "groups": num_groups, "output_padding":out_padding} - self.weights[op_name + '.weight'] = self.weights[val_w.name] + #self.weights[op_name + '.weight'] = self.weights[val_w.name] + _rename_or_remove_weight(self.weights, val_w.name, op_name+'.weight',) if val_b is not None: - self.weights[op_name + '.bias'] = self.weights[val_b.name] + #self.weights[op_name + '.bias'] = self.weights[val_b.name] + _rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias') self.paddle_graph.add_layer( kernel=paddle_op, inputs=inputs_dict, @@ -1881,10 +1902,13 @@ class OpSet9(): ) input_weight_np = _const_weight_or_none(input_weight) + _rename_or_remove_weight(self.weights, input_weight.name) hidden_size = node.get_attr('hidden_size', input_weight_np.shape[1]/4) input_size = input_weight_np.shape[2] hidden_weight_np = _const_weight_or_none(hidden_weight) + _rename_or_remove_weight(self.weights, hidden_weight.name) bias_np = _const_weight_or_none(bias) + _rename_or_remove_weight(self.weights, bias.name) input_bias_np = bias_np[:, :4*hidden_size] hidden_bias_np = bias_np[:, 4*hidden_size:] -- GitLab