diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py index d9e565a115c17d518de9757540825d391dd1feee..847ebc8b96aab160b1747c1bb73a7fe6ecd4dae0 100644 --- a/x2paddle/op_mapper/tf_op_mapper.py +++ b/x2paddle/op_mapper/tf_op_mapper.py @@ -170,7 +170,28 @@ class TFOpMapper(OpMapper): x_shape = y.out_shapes[0] y_shape = x.out_shapes[0] else: - raise Exception("Unexpected situation happend") + if len(x_shape) == 1 and len(y_shape) == 4 and x_shape[ + 0] == y_shape[-1] and y_shape.count(-1) < 1: + shape = [1, x_shape[0], 1, 1] + attr = {"shape": shape} + node.fluid_code.add_layer("reshape", + inputs=x_input, + output="reshape_x", + param_attr=attr) + if y_shape[0] != 1: + attr = {"expand_times": [y_shape[0], 1, 1, 1]} + node.fluid_code.add_layer("expand", + inputs="reshape_x", + output="reshape_x", + param_attr=attr) + inputs = {"x": "reshape_x", "y": y_input} + node.fluid_code.add_layer(op_type, + inputs=inputs, + output=node, + param_attr=None) + return + else: + raise Exception("Unexpected situation happend") if len(x_shape) == 4 and len(y_shape) == 1: if x_input.tf_data_format == "NHWC":