From 8f9bd9b67ed038a9f72b22799d1fee89b2600ccd Mon Sep 17 00:00:00 2001 From: Channingss Date: Mon, 10 Aug 2020 13:00:53 +0000 Subject: [PATCH] update Reshape&elementwise_map --- x2paddle/decoder/onnx_decoder.py | 1 - .../op_mapper/onnx2paddle/opset9/opset.py | 58 +++++++++++++++---- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/x2paddle/decoder/onnx_decoder.py b/x2paddle/decoder/onnx_decoder.py index 07aa7f9..280b5b4 100644 --- a/x2paddle/decoder/onnx_decoder.py +++ b/x2paddle/decoder/onnx_decoder.py @@ -350,7 +350,6 @@ class ONNXGraph(Graph): node.out_shapes.append(value_info['shape']) else: node.out_shapes.append([]) - print(layer.name, node.out_shapes) class ONNXDecoder(object): diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index e20e2fc..58dfa77 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -40,6 +40,21 @@ def _const_weight_or_none(node): return None +def _is_static_shape(shape): + negtive_dims = 0 + error_dims = 0 + for dim in shape: + if dim < 0: + negtive_dims += 1 + if dim != -1: + error_dims += 1 + if negtive_dims > 1: + return False + if error_dims > 0: + return False + return True + + def _get_same_padding(in_size, kernel_size, stride): new_size = int(math.ceil(in_size * 1.0 / stride)) pad_size = (new_size - 1) * stride + kernel_size - in_size @@ -230,11 +245,35 @@ class OpSet9(): val_y = self.graph.get_input_node(node, idx=1, copy=True) val_y_shape = val_y.out_shapes[0] val_x_shape = val_x.out_shapes[0] - + inputs = {} if len(val_x_shape) < len(val_y_shape): - val_x, val_y = val_y, val_x - val_y_shape, val_x_shape = val_x_shape, val_y_shape - + if node.layer_type in ['Mul', 'Add']: + val_x, val_y = val_y, val_x + val_y_shape, val_x_shape = val_x_shape, val_y_shape + inputs = {'x': val_x, 'y': val_y} + elif node.layer_type in ['Sub', 'Div', 'Pow']: + val_x_expand = val_x.layer_name + '_expand' + x_value = _const_weight_or_none(val_x) + if (val_x_shape == [1] or len(val_x_shape) == 0) and x_value: + attr = { + 'shape': val_y_shape, + 'dtype': string(val_x.dtype), + 'value': x_value + if len(val_x_shape) == 0 else x_value[0] + } + node.fluid_code.add_layer( + 'fill_constant', + inputs=None, + output=val_x_expand, + param_attr=attr) + val_x_shape = val_y_shape + inputs = {'x': val_x_expand, 'y': val_y} + else: + assert 'Unsupported situation happened.' + else: + inputs = {'x': val_x, 'y': val_y} + print(node.layer_name) + print(val_x_shape, val_y_shape) str_y_shape = ','.join(str(e) for e in val_y_shape) str_x_shape = ','.join(str(e) for e in val_x_shape) slice_idx = 0 @@ -244,7 +283,6 @@ class OpSet9(): slice_idx += 1 else: break - attr = {"name": string(node.layer_name)} if slice_idx < len(val_y_shape) and slice_idx > 0: val_y_reshaped = val_y_shape[slice_idx:] var_y_reshaped = val_y.layer_name + '_reshaped' @@ -257,13 +295,12 @@ class OpSet9(): inputs=val_y, output=var_y_reshaped, param_attr=attr_reshaped) - inputs = {'x': val_x, 'y': var_y_reshaped} + inputs['y'] = var_y_reshaped node.fluid_code.add_layer( - op_type, inputs=inputs, output=node, param_attr=attr) + op_type, inputs=inputs, output=node, param_attr=None) else: - inputs = {'x': val_x, 'y': val_y} node.fluid_code.add_layer( - op_type, inputs=inputs, output=node, param_attr=attr) + op_type, inputs=inputs, output=node, param_attr=None) @print_mapping_info def place_holder(self, node): @@ -941,7 +978,8 @@ class OpSet9(): inputs={'x': val_x}, output=node, param_attr={'shape': shape_value.tolist()}) - elif len(node.out_shapes[0]) > 0: + elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[ + 0]): node.fluid_code.add_layer( 'reshape', inputs={'x': val_x, -- GitLab