提交 6927786f 编写于 作者: J jiangjiajun@baidu.com

add EfficientNet support

上级 d1b9f2fa
...@@ -211,7 +211,10 @@ def main(): ...@@ -211,7 +211,10 @@ def main():
try: try:
import paddle import paddle
v0, v1, v2 = paddle.__version__.split('.') v0, v1, v2 = paddle.__version__.split('.')
if int(v0) != 1 or int(v1) < 6: print("paddle.__version__ = {}".format(paddle.__version__))
if v0 == '0' and v1 == '0' and v2 == '0':
print("[WARNING] You are use develop version of paddlepaddle")
elif int(v0) != 1 or int(v1) < 6:
print("[ERROR] paddlepaddle>=1.6.0 is required") print("[ERROR] paddlepaddle>=1.6.0 is required")
return return
except: except:
......
...@@ -48,7 +48,10 @@ class TFGraphNode(GraphNode): ...@@ -48,7 +48,10 @@ class TFGraphNode(GraphNode):
@property @property
def out_shapes(self): def out_shapes(self):
values = self.layer.attr["_output_shapes"].list.shape if self.layer_type == "OneShotIterator":
values = self.layer.attr["output_shapes"].list.shape
else:
values = self.layer.attr["_output_shapes"].list.shape
out_shapes = list() out_shapes = list()
for value in values: for value in values:
shape = [dim.size for dim in value.dim] shape = [dim.size for dim in value.dim]
...@@ -62,6 +65,8 @@ class TFGraphNode(GraphNode): ...@@ -62,6 +65,8 @@ class TFGraphNode(GraphNode):
dtype = self.layer.attr[k].type dtype = self.layer.attr[k].type
if dtype > 0: if dtype > 0:
break break
if dtype == 0:
dtype = self.layer.attr['output_types'].list.type[0]
if dtype not in self.dtype_map: if dtype not in self.dtype_map:
raise Exception("Dtype[{}] not in dtype_map".format(dtype)) raise Exception("Dtype[{}] not in dtype_map".format(dtype))
return self.dtype_map[dtype] return self.dtype_map[dtype]
...@@ -226,7 +231,7 @@ class TFGraph(Graph): ...@@ -226,7 +231,7 @@ class TFGraph(Graph):
def _remove_identity_node(self): def _remove_identity_node(self):
identity_ops = [ identity_ops = [
'Identity', 'StopGradient', 'Switch', 'Merge', 'Identity', 'StopGradient', 'Switch', 'Merge',
'PlaceholderWithDefault' 'PlaceholderWithDefault', 'IteratorGetNext'
] ]
identity_node = list() identity_node = list()
for node_name, node in self.node_map.items(): for node_name, node in self.node_map.items():
...@@ -317,7 +322,7 @@ class TFDecoder(object): ...@@ -317,7 +322,7 @@ class TFDecoder(object):
graph_def = cp.deepcopy(graph_def) graph_def = cp.deepcopy(graph_def)
input_map = dict() input_map = dict()
for layer in graph_def.node: for layer in graph_def.node:
if layer.op != "Placeholder": if layer.op != "Placeholder" and layer.op != "OneShotIterator":
continue continue
graph_node = TFGraphNode(layer) graph_node = TFGraphNode(layer)
dtype = graph_node.layer.attr['dtype'].type dtype = graph_node.layer.attr['dtype'].type
...@@ -335,6 +340,11 @@ class TFDecoder(object): ...@@ -335,6 +340,11 @@ class TFDecoder(object):
if shape.count(-1) > 1: if shape.count(-1) > 1:
need_define_shape = 2 need_define_shape = 2
if need_define_shape == 1:
shape = graph_node.out_shapes[0]
if len(shape) > 0 and shape.count(-1) < 2:
need_define_shape = 0
if need_define_shape > 0: if need_define_shape > 0:
shape = None shape = None
if graph_node.get_attr("shape"): if graph_node.get_attr("shape"):
......
...@@ -85,7 +85,7 @@ class TFOpMapper(OpMapper): ...@@ -85,7 +85,7 @@ class TFOpMapper(OpMapper):
not_placeholder = list() not_placeholder = list()
for name in self.graph.input_nodes: for name in self.graph.input_nodes:
if self.graph.get_node(name).layer_type != "Placeholder": if self.graph.get_node(name).layer_type != "Placeholder" and self.graph.get_node(name).layer_type != "OneShotIterator":
not_placeholder.append(name) not_placeholder.append(name)
for name in not_placeholder: for name in not_placeholder:
idx = self.graph.input_nodes.index(name) idx = self.graph.input_nodes.index(name)
...@@ -287,6 +287,9 @@ class TFOpMapper(OpMapper): ...@@ -287,6 +287,9 @@ class TFOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
# Map a TF "OneShotIterator" op by delegating to the Placeholder mapper,
# so the iterator is treated as a model input node.
def OneShotIterator(self, node):
return self.Placeholder(node)
def Const(self, node): def Const(self, node):
shape = node.out_shapes[0] shape = node.out_shapes[0]
dtype = node.dtype dtype = node.dtype
...@@ -492,6 +495,9 @@ class TFOpMapper(OpMapper): ...@@ -492,6 +495,9 @@ class TFOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
# Map TF "FusedBatchNormV3" with the existing FusedBatchNorm handler —
# presumably the V3 attrs relevant here are compatible; TODO confirm.
def FusedBatchNormV3(self, node):
return self.FusedBatchNorm(node)
def DepthwiseConv2dNative(self, node): def DepthwiseConv2dNative(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True) input = self.graph.get_node(node.layer.input[0], copy=True)
kernel = self.graph.get_node(node.layer.input[1], copy=True) kernel = self.graph.get_node(node.layer.input[1], copy=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册