Commit 6ce10dc4, Author: SunAhong1993

fix the TensorFlow op mappers: split the boolean comparison ops out of elementwise_ops into a separate bool_ops table and mark their outputs as bool

Parent f64b23a7
@@ -73,15 +73,17 @@ class TFOpMapper(OpMapper):
         'Sub': 'fluid.layers.elementwise_sub',
         'Maximum': 'paddle.maximum',
         'Minimum': 'paddle.minimum',
+        'Mul': 'paddle.multiply',
+        'FloorDiv': 'paddle.floor_divide',
+        'FloorMod': 'paddle.floor_mod',
+        'LogicalAnd': 'logical_and',
+    }
+    bool_ops = {
         'LessEqual': 'paddle.less_equal',
         'GreaterEqual': 'paddle.greater_equal',
         'Greater': 'paddle.greater_than',
         'NotEqual': 'paddle.not_equal',
         'Equal': 'paddle.equal',
-        'Mul': 'paddle.multiply',
-        'FloorDiv': 'paddle.floor_divide',
-        'FloorMod': 'paddle.floor_mod',
-        'LogicalAnd': 'logical_and',
     }
 
     def __init__(self, decoder):
@@ -123,6 +125,8 @@ class TFOpMapper(OpMapper):
                     self.directly_map(node)
                 elif op in self.elementwise_ops:
                     self.elementwise_map(node)
+                elif op in self.bool_ops:
+                    self.bool_map(node)
                 elif hasattr(self, op):
                     func = getattr(self, op)
                     func(node)
@@ -138,7 +142,8 @@ class TFOpMapper(OpMapper):
             op = node.layer_type
             if not hasattr(self, op) and \
                 op not in self.directly_map_ops and \
-                op not in self.elementwise_ops:
+                op not in self.elementwise_ops and \
+                op not in self.bool_ops:
                 unsupported_ops.add(op)
         if len(unsupported_ops) == 0:
             return True
@@ -178,8 +183,10 @@ class TFOpMapper(OpMapper):
             outputs=[node.name],
             **layer_attrs)
 
-    def elementwise_map(self, node):
-        op_type = self.elementwise_ops[node.layer_type]
+    def elementwise_map(self, node, op_type=None):
+        if op_type is None:
+            assert node.layer_type in self.elementwise_ops
+            op_type = self.elementwise_ops[node.layer_type]
         x = self.graph.get_input_node(node, 0)
         y = self.graph.get_input_node(node, 1)
         x_shape = x.out_shapes[0]
@@ -190,6 +197,11 @@ class TFOpMapper(OpMapper):
                 "y": y.name},
             outputs=[node.name])
         self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}
 
+    def bool_map(self, node):
+        op_type = self.bool_ops[node.layer_type]
+        self.elementwise_map(node, op_type)
+        node.set_dtype("bool")
+
     def Placeholder(self, node):
         shape = node.out_shapes[0]
...
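For orientation between the two diffs, here is a minimal standalone sketch of the pattern this commit introduces (the Node and Mapper classes below are simplified stand-ins for illustration, not X2Paddle code): comparison ops reuse the elementwise mapping path through an explicit op_type argument, while bool_map additionally marks the node's output dtype as bool.

class Node:
    def __init__(self, layer_type):
        self.layer_type = layer_type
        self.dtype = None

    def set_dtype(self, dtype):
        self.dtype = dtype


class Mapper:
    # Arithmetic ops keep their input dtype.
    elementwise_ops = {
        'Mul': 'paddle.multiply',
        'FloorDiv': 'paddle.floor_divide',
    }
    # Comparison ops produce boolean outputs, so they live in their own table.
    bool_ops = {
        'Greater': 'paddle.greater_than',
        'Equal': 'paddle.equal',
    }

    def elementwise_map(self, node, op_type=None):
        # When called directly, resolve op_type from elementwise_ops;
        # bool_map passes op_type explicitly, so bool ops bypass that table.
        if op_type is None:
            assert node.layer_type in self.elementwise_ops
            op_type = self.elementwise_ops[node.layer_type]
        print("emit {} for {}".format(op_type, node.layer_type))

    def bool_map(self, node):
        # Reuse the elementwise code path, then record a bool output dtype.
        op_type = self.bool_ops[node.layer_type]
        self.elementwise_map(node, op_type)
        node.set_dtype("bool")

    def map(self, node):
        if node.layer_type in self.elementwise_ops:
            self.elementwise_map(node)
        elif node.layer_type in self.bool_ops:
            self.bool_map(node)


mapper = Mapper()
greater = Node('Greater')
mapper.map(greater)    # prints: emit paddle.greater_than for Greater
print(greater.dtype)   # prints: bool

The same change is applied to a second copy of TFOpMapper below; that copy fetches inputs via self.graph.get_node(node.layer.input[i]), but it gains the identical bool_ops/bool_map dispatch.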
@@ -75,15 +75,17 @@ class TFOpMapper(OpMapper):
         'Sub': 'fluid.layers.elementwise_sub',
         'Maximum': 'paddle.maximum',
         'Minimum': 'paddle.minimum',
+        'Mul': 'paddle.multiply',
+        'FloorDiv': 'paddle.floor_divide',
+        'FloorMod': 'paddle.floor_mod',
+        'LogicalAnd': 'logical_and',
+    }
+    bool_ops = {
         'LessEqual': 'paddle.less_equal',
         'GreaterEqual': 'paddle.greater_equal',
         'Greater': 'paddle.greater_than',
         'NotEqual': 'paddle.not_equal',
         'Equal': 'paddle.equal',
-        'Mul': 'paddle.multiply',
-        'FloorDiv': 'paddle.floor_divide',
-        'FloorMod': 'paddle.floor_mod',
-        'LogicalAnd': 'logical_and',
     }
 
     def __init__(self, decoder):
@@ -124,6 +126,8 @@ class TFOpMapper(OpMapper):
                     self.directly_map(node)
                 elif op in self.elementwise_ops:
                     self.elementwise_map(node)
+                elif op in self.bool_ops:
+                    self.bool_map(node)
                 elif hasattr(self, op):
                     func = getattr(self, op)
                     func(node)
@@ -138,7 +142,8 @@ class TFOpMapper(OpMapper):
             op = node.layer_type
             if not hasattr(self, op) and \
                 op not in self.directly_map_ops and \
-                op not in self.elementwise_ops:
+                op not in self.elementwise_ops and \
+                op not in self.bool_ops:
                 unsupported_ops.add(op)
         if len(unsupported_ops) == 0:
             return True
@@ -167,9 +172,10 @@ class TFOpMapper(OpMapper):
             outputs=[node.name],
             **attr)
 
-    def elementwise_map(self, node):
-        assert node.layer_type in self.elementwise_ops
-        op_type = self.elementwise_ops[node.layer_type]
+    def elementwise_map(self, node, op_type=None):
+        if op_type is None:
+            assert node.layer_type in self.elementwise_ops
+            op_type = self.elementwise_ops[node.layer_type]
         x = self.graph.get_node(node.layer.input[0])
         y = self.graph.get_node(node.layer.input[1])
         x_shape = x.out_shapes[0]
@@ -180,6 +186,11 @@ class TFOpMapper(OpMapper):
                 "y": y.name},
             outputs=[node.name])
         self.paddle_graph.layers[layer_id].input_shapes = {"x": x_shape, "y": y_shape}
 
+    def bool_map(self, node):
+        op_type = self.bool_ops[node.layer_type]
+        self.elementwise_map(node, op_type)
+        node.set_dtype("bool")
+
     def Placeholder(self, node):
         shape = node.out_shapes[0]
...
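As a quick sanity check on why bool_map records a bool dtype: the Paddle comparison APIs listed in bool_ops, such as paddle.greater_than and paddle.equal, return boolean tensors. A minimal snippet (assumes a Paddle 2.x install; not part of the commit):

import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
y = paddle.to_tensor([2.0, 2.0, 2.0])

out = paddle.greater_than(x, y)   # the target of TF 'Greater' in bool_ops
print(out.numpy())   # [False False  True]
print(out.dtype)     # paddle.bool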