未验证 提交 ea254a8d 编写于 作者: W WJJ1995 提交者: GitHub

Fixed Flatten op (#797)

* fixed Flatten

* fixed Where and ConstantOfShape

* fixed inf bug

* deal with comments
上级 c1fba5c1
...@@ -262,6 +262,8 @@ class OpSet9(): ...@@ -262,6 +262,8 @@ class OpSet9():
shape = node.out_shapes[0] shape = node.out_shapes[0]
if hasattr(node.weight, "shape") and len(node.weight.shape) == 0: if hasattr(node.weight, "shape") and len(node.weight.shape) == 0:
if node.weight == float('inf') or node.weight == float('-inf'):
node.weight = string(node.weight)
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.full", "paddle.full",
inputs={}, inputs={},
...@@ -792,6 +794,8 @@ class OpSet9(): ...@@ -792,6 +794,8 @@ class OpSet9():
if len(value) == 1: if len(value) == 1:
value = value.tolist() value = value.tolist()
value = value[0] value = value[0]
if value == float('inf') or value == float('-inf'):
value = string(value)
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.full", "paddle.full",
inputs={}, inputs={},
...@@ -1195,7 +1199,6 @@ class OpSet9(): ...@@ -1195,7 +1199,6 @@ class OpSet9():
@print_mapping_info @print_mapping_info
def ConstantOfShape(self, node): def ConstantOfShape(self, node):
val_shape = self.graph.get_input_node(node, idx=0, copy=True) val_shape = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True)
value = node.get_attr('value') value = node.get_attr('value')
dtype = value.dtype dtype = value.dtype
...@@ -1204,6 +1207,8 @@ class OpSet9(): ...@@ -1204,6 +1207,8 @@ class OpSet9():
'this is not supported') 'this is not supported')
if len(value) == 1: if len(value) == 1:
value = value[0] value = value[0]
if value == float('inf') or value == float('-inf'):
value = string(value)
layer_attrs = {'dtype': string(dtype), 'fill_value': value} layer_attrs = {'dtype': string(dtype), 'fill_value': value}
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.full", "paddle.full",
...@@ -1563,20 +1568,37 @@ class OpSet9(): ...@@ -1563,20 +1568,37 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
output_shape = val_x.out_shapes[0] output_shape = val_x.out_shapes[0]
axis = node.get_attr('axis', 1) axis = node.get_attr('axis', 1)
shape_list = [1, 1]
if axis == 0: if axis == 0:
for s in output_shape: self.paddle_graph.add_layer(
shape_list[1] *= s 'paddle.reshape',
inputs={"x": val_x.name},
outputs=[node.name],
shape=[1, -1])
else: else:
for s in output_shape[:axis]: if len(output_shape) != 0:
shape_list[0] *= s shape_list = [1, 1]
for s in output_shape[axis:]: for s in output_shape[:axis]:
shape_list[1] *= s shape_list[0] *= s
self.paddle_graph.add_layer( for s in output_shape[axis:]:
'paddle.reshape', shape_list[1] *= s
inputs={"x": val_x.name}, self.paddle_graph.add_layer(
outputs=[node.name], 'paddle.reshape',
shape=shape_list) inputs={"x": val_x.name},
outputs=[node.name],
shape=shape_list)
else:
# flatten + reshape
self.paddle_graph.add_layer(
"paddle.flatten",
inputs={"input": val_x.name},
outputs=[val_x.name + "_flatten"],
start_axis=[0],
stop_axis=[axis])
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': val_x.name + "_flatten"},
outputs=[node.name],
shape=[0, -1])
@print_mapping_info @print_mapping_info
def Gemm(self, node): def Gemm(self, node):
...@@ -1846,40 +1868,13 @@ class OpSet9(): ...@@ -1846,40 +1868,13 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=1, copy=True) val_x = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_input_node(node, idx=2, copy=True) val_y = self.graph.get_input_node(node, idx=2, copy=True)
not_condition = condition.name + '_not'
self.paddle_graph.add_layer(
"paddle.logical_not",
inputs={"x": condition.name},
outputs=[not_condition])
cast_not_condition = not_condition + '_cast'
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": not_condition},
outputs=[cast_not_condition],
dtype=string(val_x.dtype))
cast_condition = condition.name + '_cast'
self.paddle_graph.add_layer(
"paddle.cast",
inputs={"x": condition.name},
outputs=[cast_condition],
dtype=string(val_x.dtype))
mul_val_x = val_x.name + '_mul'
self.paddle_graph.add_layer(
"paddle.multiply",
inputs={'x': val_x.name,
'y': cast_condition},
outputs=[mul_val_x])
mul_val_y = val_y.name + '_mul'
self.paddle_graph.add_layer(
"paddle.multiply",
inputs={'x': val_y.name,
'y': cast_not_condition},
outputs=[mul_val_y])
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.add", "paddle.where",
inputs={'x': mul_val_x, inputs={
'y': mul_val_y}, 'condition': condition.name,
'x': val_x.name,
'y': val_y.name
},
outputs=[node.name]) outputs=[node.name])
@print_mapping_info @print_mapping_info
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册