提交 e8bd0eaa 编写于 作者: J jiangjiajun

bug modify for image matting

上级 da805be1
......@@ -25,7 +25,8 @@ def export_paddle_param(param, param_name, dir):
"int64": [framework_pb2.VarType.INT64, 'q'],
"float16": [framework_pb2.VarType.FP16, 'e'],
"float32": [framework_pb2.VarType.FP32, 'f'],
"float64": [framework_pb2.VarType.FP64, 'd']
"float64": [framework_pb2.VarType.FP64, 'd'],
"bool": [framework_pb2.VarType.BOOL, None]
}
shape = param.shape
if len(shape) == 0:
......
......@@ -25,20 +25,26 @@ import sys
class TFGraphNode(GraphNode):
def __init__(self, layer, layer_name=None, data_format="NHWC"):
if layer_name is None:
super(TFGraphNode,
self).__init__(layer,
layer.name.replace('/', '_').replace('-', '_'))
super(TFGraphNode, self).__init__(
layer,
layer.name.replace('/', '_').replace('-', '_').replace('^', ''))
else:
super(TFGraphNode,
self).__init__(layer,
layer_name.replace('/', '_').replace('-', '_'))
super(TFGraphNode, self).__init__(
layer,
layer_name.replace('/', '_').replace('-', '_').replace('^', ''))
self.layer_type = layer.op
self.tf_data_format = data_format
self.pd_data_format = "NCHW"
self.fluid_code = FluidCode()
self.dtype_map = {1: "float32", 3: "int32", 4: "uint8", 9: "int64"}
self.dtype_map = {
1: "float32",
3: "int32",
4: "uint8",
9: "int64",
10: "bool"
}
@property
def out_shapes(self):
......@@ -113,7 +119,9 @@ class TFGraph(Graph):
for layer_name, node in self.node_map.items():
for in_node in node.layer.input:
in_node = in_node.replace('/', '_').replace('-', '_')
in_node = in_node.replace('/',
'_').replace('-',
'_').replace('^', '')
if in_node not in self.node_map:
if in_node.strip().split(':')[0] in self.node_map:
self.connect(in_node.strip().split(':')[0], layer_name)
......@@ -140,6 +148,9 @@ class TFGraph(Graph):
node = super(TFGraph, self).get_node(new_node_name, copy)
if node is None:
return None
if node.layer_type == "Switch":
if hasattr(node, 'index'):
del node.index
if len(items) == 1 and node.layer_type in self.multi_out_ops:
node.index = 0
return node
......@@ -184,9 +195,13 @@ class TFGraph(Graph):
del self.topo_sort[idx]
def _remove_identity_node(self):
identity_ops = [
'Identity', 'StopGradient', 'Switch', 'Merge',
'PlaceholderWithDefault'
]
identity_node = list()
for node_name, node in self.node_map.items():
if node.layer_type == "Identity" or node.layer_type == "StopGradient":
if node.layer_type in identity_ops:
identity_node.append(node_name)
for node_name in identity_node:
......
......@@ -125,9 +125,9 @@ class TFOpMapper(OpMapper):
in_node = self.graph.get_node(in_node_name)
out_node = self.graph.get_node(out_node_name)
index = in_node.outputs.index(out_node_name)
del in_node.outputs[index]
# del in_node.outputs[index]
index = out_node.inputs.index(in_node_name)
del out_node.inputs[index]
# del out_node.inputs[index]
self.omit_nodes.append(in_node.layer_name)
def directly_map(self, node):
......@@ -624,6 +624,9 @@ class TFOpMapper(OpMapper):
output=node,
param_attr=perm)
return
if len(attr["shape"]) == 5:
attr["shape"] = [attr["shape"][i] for i in [0, 1, 4, 2, 3]]
node.fluid_code.add_layer("reshape",
inputs=input,
output=node,
......@@ -893,10 +896,23 @@ class TFOpMapper(OpMapper):
"starts": begin,
"ends": end
}
shrink_axis_mask = node.get_attr('shrink_axis_mask')
squeeze_dims = list()
for i in range(len(begin)):
x = shrink_axis_mask >> i & 1
if x == 1:
squeeze_dims.append(i)
node.fluid_code.add_layer("slice",
inputs=input,
output=node,
param_attr=attr)
if shrink_axis_mask > 0 and len(input.out_shapes[0]) == 5:
attr = {"axes": squeeze_dims}
node.fluid_code.add_layer("squeeze",
inputs=node,
output=node,
param_attr=attr)
def Slice(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册