Unverified commit ea254a8d authored by W WJJ1995, committed by GitHub

Fixed Flatten op (#797)

* fixed Flatten

* fixed Where and ConstantOfShape

* fixed inf bug

* deal with comments
Parent c1fba5c1
@@ -262,6 +262,8 @@ class OpSet9():
shape = node.out_shapes[0]
if hasattr(node.weight, "shape") and len(node.weight.shape) == 0:
if node.weight == float('inf') or node.weight == float('-inf'):
node.weight = string(node.weight)
self.paddle_graph.add_layer(
"paddle.full",
inputs={},
@@ -792,6 +794,8 @@ class OpSet9():
if len(value) == 1:
value = value.tolist()
value = value[0]
if value == float('inf') or value == float('-inf'):
value = string(value)
self.paddle_graph.add_layer(
"paddle.full",
inputs={},
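Why the `string(...)` wrapping added in the two hunks above matters: formatting a raw `float('inf')` weight into the generated script emits the bare token `inf`, which is not a valid Python literal, so the emitted `paddle.full` call cannot even be parsed. Below is a minimal sketch of that failure mode, not X2Paddle's actual code-generation path; the local `string()` helper is assumed to mirror the converter helper of the same name, and the claim that Paddle accepts the quoted fill value is an assumption based on this patch.

```python
# Minimal sketch (assumption: not the real X2Paddle code generator) of why an
# unquoted inf constant breaks generated code.
value = float('inf')

broken = "paddle.full(shape=[1], fill_value={}, dtype='float32')".format(value)
print(broken)   # paddle.full(shape=[1], fill_value=inf, dtype='float32')
# eval(broken)  # would raise NameError: name 'inf' is not defined


def string(param):
    # Assumed to mirror X2Paddle's string() helper: quote the value so the
    # emitted source stays syntactically valid.
    return "'{}'".format(param)


fixed = "paddle.full(shape=[1], fill_value={}, dtype='float32')".format(string(value))
print(fixed)    # paddle.full(shape=[1], fill_value='inf', dtype='float32')
```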
@@ -1195,7 +1199,6 @@ class OpSet9():
@print_mapping_info
def ConstantOfShape(self, node):
val_shape = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
value = node.get_attr('value')
dtype = value.dtype
@@ -1204,6 +1207,8 @@ class OpSet9():
'this is not supported')
if len(value) == 1:
value = value[0]
if value == float('inf') or value == float('-inf'):
value = string(value)
layer_attrs = {'dtype': string(dtype), 'fill_value': value}
self.paddle_graph.add_layer(
"paddle.full",
@@ -1563,11 +1568,15 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True)
output_shape = val_x.out_shapes[0]
axis = node.get_attr('axis', 1)
shape_list = [1, 1]
if axis == 0:
for s in output_shape:
shape_list[1] *= s
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": val_x.name},
outputs=[node.name],
shape=[1, -1])
else:
if len(output_shape) != 0:
shape_list = [1, 1]
for s in output_shape[:axis]:
shape_list[0] *= s
for s in output_shape[axis:]:
@@ -1577,6 +1586,19 @@ class OpSet9():
inputs={"x": val_x.name},
outputs=[node.name],
shape=shape_list)
else:
# flatten + reshape
self.paddle_graph.add_layer(
"paddle.flatten",
inputs={"input": val_x.name},
outputs=[val_x.name + "_flatten"],
start_axis=[0],
stop_axis=[axis])
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': val_x.name + "_flatten"},
outputs=[node.name],
shape=[0, -1])
@print_mapping_info
def Gemm(self, node):
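The Flatten branches in the hunk above follow the ONNX Flatten shape rule: with `axis=a`, an input of shape `(d0, ..., dn-1)` becomes a 2-D tensor of shape `(d0*...*d(a-1), d(a)*...*d(n-1))`, and `axis=0` yields `(1, prod(all dims))`. A small stand-alone sketch of that rule follows; `flatten_output_shape` is a hypothetical helper for illustration, not converter code.

```python
# Reference implementation of the ONNX Flatten output-shape rule that the
# reshape branches compute when the static shape is known.
from functools import reduce
from operator import mul


def flatten_output_shape(input_shape, axis=1):
    prod = lambda dims: reduce(mul, dims, 1)
    if axis == 0:
        return [1, prod(input_shape)]          # patch emits reshape [1, -1]
    return [prod(input_shape[:axis]), prod(input_shape[axis:])]


print(flatten_output_shape([2, 3, 4, 5], axis=2))  # [6, 20]
print(flatten_output_shape([2, 3, 4, 5], axis=0))  # [1, 120]
# When the static shape is unknown (len(output_shape) == 0), the patch instead
# falls back to paddle.flatten followed by paddle.reshape at runtime.
```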
@@ -1846,40 +1868,13 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_input_node(node, idx=2, copy=True)
not_condition = condition.name + '_not'
self.paddle_graph.add_layer(
"paddle.logical_not",
inputs={"x": condition.name},
outputs=[not_condition])
cast_not_condition = not_condition + '_cast'
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": not_condition},
outputs=[cast_not_condition],
dtype=string(val_x.dtype))
cast_condition = condition.name + '_cast'
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": condition.name},
outputs=[cast_condition],
dtype=string(val_x.dtype))
mul_val_x = val_x.name + '_mul'
self.paddle_graph.add_layer(
"paddle.multiply",
inputs={'x': val_x.name,
'y': cast_condition},
outputs=[mul_val_x])
mul_val_y = val_y.name + '_mul'
self.paddle_graph.add_layer(
"paddle.multiply",
inputs={'x': val_y.name,
'y': cast_not_condition},
outputs=[mul_val_y])
self.paddle_graph.add_layer(
"paddle.add",
inputs={'x': mul_val_x,
'y': mul_val_y},
"paddle.where",
inputs={
'condition': condition.name,
'x': val_x.name,
'y': val_y.name
},
outputs=[node.name])
@print_mapping_info
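The removed Where lowering emulated the op as `x * cast(cond) + y * cast(not cond)`; the patch replaces that five-layer construction with a single `paddle.where` call. A NumPy sketch of why the two forms agree for finite values is shown below (ignoring inf/NaN inputs, where the arithmetic form can misbehave); it is an illustration, not converter code.

```python
# Equivalence check: select-by-mask via multiply/add vs. a direct where().
import numpy as np

cond = np.array([True, False, True])
x = np.array([1.0, 2.0, 3.0])
y = np.array([10.0, 20.0, 30.0])

old_style = x * cond.astype(x.dtype) + y * (~cond).astype(y.dtype)
new_style = np.where(cond, x, y)   # the behaviour paddle.where now provides
assert np.array_equal(old_style, new_style)
print(new_style)                   # [ 1. 20.  3.]
```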