提交 fa6ada39 编写于 作者: J jiangjiajun

support new situation

上级 dd9e3172
......@@ -170,7 +170,28 @@ class TFOpMapper(OpMapper):
x_shape = y.out_shapes[0]
y_shape = x.out_shapes[0]
else:
raise Exception("Unexpected situation happend")
if len(x_shape) == 1 and len(y_shape) == 4 and x_shape[
0] == y_shape[-1] and y_shape.count(-1) < 1:
shape = [1, x_shape[0], 1, 1]
attr = {"shape": shape}
node.fluid_code.add_layer("reshape",
inputs=x_input,
output="reshape_x",
param_attr=attr)
if y_shape[0] != 1:
attr = {"expand_times": [y_shape[0], 1, 1, 1]}
node.fluid_code.add_layer("expand",
inputs="reshape_x",
output="reshape_x",
param_attr=attr)
inputs = {"x": "reshape_x", "y": y_input}
node.fluid_code.add_layer(op_type,
inputs=inputs,
output=node,
param_attr=None)
return
else:
raise Exception("Unexpected situation happend")
if len(x_shape) == 4 and len(y_shape) == 1:
if x_input.tf_data_format == "NHWC":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册