From 0fe0fc121784253e2e1197040d04f6865f28d50a Mon Sep 17 00:00:00 2001 From: Channingss Date: Tue, 11 Aug 2020 02:44:59 +0000 Subject: [PATCH] update elementwise_ops for paddle1.8 --- .../op_mapper/onnx2paddle/opset9/opset.py | 61 +------------------ 1 file changed, 3 insertions(+), 58 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 064a26b..d7b4fb2 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -243,64 +243,9 @@ class OpSet9(): val_x = self.graph.get_input_node(node, idx=0, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True) - val_y_shape = val_y.out_shapes[0] - val_x_shape = val_x.out_shapes[0] - inputs = {} - if len(val_x_shape) < len(val_y_shape): - if node.layer_type in ['Mul', 'Add']: - val_x, val_y = val_y, val_x - val_y_shape, val_x_shape = val_x_shape, val_y_shape - inputs = {'x': val_x, 'y': val_y} - elif node.layer_type in ['Sub', 'Div', 'Pow']: - val_x_expand = val_x.layer_name + '_expand' - x_value = _const_weight_or_none(val_x) - if (val_x_shape == [1] or len(val_x_shape) == 0) and x_value: - attr = { - 'shape': val_y_shape, - 'dtype': string(val_x.dtype), - 'value': x_value - if len(val_x_shape) == 0 else x_value[0] - } - node.fluid_code.add_layer( - 'fill_constant', - inputs=None, - output=val_x_expand, - param_attr=attr) - val_x_shape = val_y_shape - inputs = {'x': val_x_expand, 'y': val_y} - else: - assert 'Unsupported situation happened.' - else: - inputs = {'x': val_x, 'y': val_y} - print(node.layer_name) - print(val_x_shape, val_y_shape) - str_y_shape = ','.join(str(e) for e in val_y_shape) - str_x_shape = ','.join(str(e) for e in val_x_shape) - slice_idx = 0 - if str_y_shape not in str_x_shape: - for dim in val_y_shape: - if dim == 1: - slice_idx += 1 - else: - break - if slice_idx < len(val_y_shape) and slice_idx > 0: - val_y_reshaped = val_y_shape[slice_idx:] - var_y_reshaped = val_y.layer_name + '_reshaped' - attr_reshaped = { - 'shape': val_y_reshaped, - 'name': string(var_y_reshaped) - } - node.fluid_code.add_layer( - 'reshape', - inputs=val_y, - output=var_y_reshaped, - param_attr=attr_reshaped) - inputs['y'] = var_y_reshaped - node.fluid_code.add_layer( - op_type, inputs=inputs, output=node, param_attr=None) - else: - node.fluid_code.add_layer( - op_type, inputs=inputs, output=node, param_attr=None) + inputs = {'x': val_x, 'y': val_y} + node.fluid_code.add_layer( + op_type, inputs=inputs, output=node, param_attr=None) @print_mapping_info def place_holder(self, node): -- GitLab