From fa6ada39950505012ea7381fa7d2889e67943e93 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Thu, 19 Sep 2019 15:30:11 +0800 Subject: [PATCH] support new situation --- x2paddle/op_mapper/tf_op_mapper.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py index d9e565a..847ebc8 100644 --- a/x2paddle/op_mapper/tf_op_mapper.py +++ b/x2paddle/op_mapper/tf_op_mapper.py @@ -170,7 +170,28 @@ class TFOpMapper(OpMapper): x_shape = y.out_shapes[0] y_shape = x.out_shapes[0] else: - raise Exception("Unexpected situation happend") + if len(x_shape) == 1 and len(y_shape) == 4 and x_shape[ + 0] == y_shape[-1] and y_shape.count(-1) < 1: + shape = [1, x_shape[0], 1, 1] + attr = {"shape": shape} + node.fluid_code.add_layer("reshape", + inputs=x_input, + output="reshape_x", + param_attr=attr) + if y_shape[0] != 1: + attr = {"expand_times": [y_shape[0], 1, 1, 1]} + node.fluid_code.add_layer("expand", + inputs="reshape_x", + output="reshape_x", + param_attr=attr) + inputs = {"x": "reshape_x", "y": y_input} + node.fluid_code.add_layer(op_type, + inputs=inputs, + output=node, + param_attr=None) + return + else: + raise Exception("Unexpected situation happend") if len(x_shape) == 4 and len(y_shape) == 1: if x_input.tf_data_format == "NHWC": -- GitLab