diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 03cc642ae0e26d4d996acb992fa4e2dc29023d24..20538cc7051b725abc09d728610e9caf3b13a0d2 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -556,25 +556,26 @@ class OpSet9(): def Expand(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) val_shape = self.graph.get_input_node(node, idx=1, copy=True) - if len(val_shape.outputs) == 1: self.omit_nodes.append(val_shape.layer_name) - - val_y = self.graph.get_node(node.layer.output[0], copy=True) - out_shape = node.out_shapes[0] val_x_dtype = val_x.dtype - name_ones = node.layer_name + '_ones' - attr_ones = {'shape': out_shape, 'dtype': string(val_x_dtype)} + attr_ones = { + 'shape': val_shape.layer_name, + 'dtype': string(val_x_dtype), + 'value': 1 + } node.fluid_code.add_layer( - 'ones', inputs=None, output=name_ones, param_attr=attr_ones) + 'fill_constant', + inputs=None, + output=name_ones, + param_attr=attr_ones) inputs = {'x': name_ones, 'y': val_x} - attr = {'name': string(node.layer_name)} node.fluid_code.add_layer( 'elementwise_mul', inputs=inputs, output=node.layer_name, - param_attr=attr) + param_attr=None) @print_mapping_info def Gather(self, node): @@ -1341,7 +1342,8 @@ class OpSet9(): if val_repeats.dtype != 'int32': attr = {"dtype": string("int32")} node.fluid_code.add_layer( - "cast", inputs=repeats, + "cast", + inputs=repeats, output="{}.tmp".format(repeats), param_attr=attr) repeats = "{}.tmp".format(repeats)