From 8085ee04ac64a673fd6d3995c7ee83508435dedb Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Thu, 25 Feb 2021 14:30:26 +0800 Subject: [PATCH] fix the bn --- .../static/caffe2paddle/caffe_op_mapper.py | 37 +++++++++++++++---- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/x2paddle/op_mapper/static/caffe2paddle/caffe_op_mapper.py b/x2paddle/op_mapper/static/caffe2paddle/caffe_op_mapper.py index 78389a2..d76c977 100644 --- a/x2paddle/op_mapper/static/caffe2paddle/caffe_op_mapper.py +++ b/x2paddle/op_mapper/static/caffe2paddle/caffe_op_mapper.py @@ -771,6 +771,12 @@ class CaffeOpMapper(OpMapper): 'epsilon': eps, 'momentum': momentum } + if len(node.in_shapes[0]) == 2: + self.paddle_graph.add_layer( + "paddle.unsqueeze", + inputs={"x": input.name}, + outputs=[input.name], + axis=[2,3]) self.paddle_graph.add_layer( kernel="paddle.nn.functional.batch_norm", inputs={"x": input.name, @@ -780,6 +786,12 @@ class CaffeOpMapper(OpMapper): "running_var": variance_name,}, outputs=[node.name], **layer_attrs) + if len(node.in_shapes[0]) == 2: + self.paddle_graph.add_layer( + "paddle.squeeze", + inputs={"x": node.layer_name}, + outputs=[node.layer_name], + axis=[2,3]) def Scale(self, node): if node.data is None: @@ -795,8 +807,13 @@ class CaffeOpMapper(OpMapper): else: self.params[node.name + "_cparam1"] = np.squeeze(node.data[ 0]).astype("float32") - self.params[node.name + "_cparam2"] = np.squeeze(node.data[ - 1]).astype("float32") + if not node.layer.scale_param.bias_term: + self.params[node.layer_name + "_cparam2"] = np.zeros([ + node.in_shapes[0][1], + ]).astype("float32") + else: + self.params[node.layer_name + "_cparam2"] = np.squeeze(node.data[ + 1]).astype("float32") params = node.layer.scale_param axis = params.axis inputs = [] @@ -826,11 +843,17 @@ class CaffeOpMapper(OpMapper): inputs_dict = {} inputs_dict['x'] = input0_name inputs_dict['y'] = node.name + "_cparam1" - self.paddle_graph.add_layer( - "paddle.multiply", - inputs=inputs_dict, - outputs=[node.name + "_mul"], - axis=axis) + if len(node.in_shapes[0]) == 2: + self.paddle_graph.add_layer( + "paddle.multiply", + inputs=inputs_dict, + outputs=[node.layer_name + "_mul"]) + else: + self.paddle_graph.add_layer( + "paddle.multiply", + inputs=inputs_dict, + outputs=[node.layer_name + "_mul"], + axis=axis) self.paddle_graph.add_layer( "paddle.static.create_parameter", inputs={}, -- GitLab