Commit 7ca9c323 authored by J jiangjiajun

Add support for more TensorFlow models

Parent 16b75df4
...@@ -7,4 +7,3 @@
| Macrobull | Nai-Rui Luo |
| Channingss | Ling-Chi Chen |
| mamingjie-China | Ming-Jie Ma |
...@@ -78,6 +78,8 @@ def tf2paddle(model_path,
              define_input_shape=False):
    # check tensorflow installation and version
    try:
+        import os
+        os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
        import tensorflow as tf
        version = tf.__version__
        if version >= '2.0.0' or version < '1.0.0':
...@@ -109,6 +111,9 @@ def tf2paddle(model_path,
    optimizer = TFOptimizer(mapper)
    optimizer.delete_redundance_code()
    optimizer.strip_graph()
+    optimizer.merge_activation()
+    optimizer.merge_bias()
+    optimizer.remove_transpose()
    mapper.save_inference_model(save_dir)
...
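For context, the two lines added above silence TensorFlow's C++ logging before the import, and the surrounding check rejects anything outside the 1.x line. A minimal standalone sketch of that guard (the function name and error messages are illustrative, not taken from the repository):

import os


def check_tf_version():
    # Quiet TensorFlow's C++ logging (3 = errors only) before importing it,
    # mirroring the two lines added in the hunk above.
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
    try:
        import tensorflow as tf
    except ImportError:
        raise Exception("tensorflow is not installed.")
    version = tf.__version__
    # The converter targets the TensorFlow 1.x line only.
    if version >= '2.0.0' or version < '1.0.0':
        raise Exception(
            "Only tensorflow 1.x is supported, found {}".format(version))
    return tf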
...@@ -15,7 +15,6 @@
from x2paddle.core.graph import GraphNode, Graph
from x2paddle.core.fluid_code import FluidCode
from tensorflow.python.framework import tensor_util
-from tensorflow.python.platform import gfile
from tensorflow.core.framework import attr_value_pb2
import tensorflow as tf
import copy as cp
...@@ -140,7 +139,7 @@ class TFGraph(Graph):
            raise Exception("Node[{}] not in graph".format(node_name))
        inputs = self.node_map[node_name].inputs
        outputs = self.node_map[node_name].outputs
-        assert len(inputs) == 1
+        # assert len(inputs) == 1
        input_node = self.node_map[inputs[0]]
        idx = input_node.outputs.index(node_name)
        del input_node.outputs[idx]
...@@ -205,18 +204,28 @@ class TFGraph(Graph):
class TFDecoder(object):
    def __init__(self, pb_model, data_format="NHWC", define_input_shape=False):
-        self.sess = tf.Session()
+        try:
+            self.sess = tf.compat.v1.Session()
+        except:
+            self.sess = tf.Session()
        self.input_info = dict()
        self.define_input_shape = define_input_shape
-        with gfile.FastGFile(pb_model, 'rb') as f:
+        with open(pb_model, 'rb') as f:
-            graph_def = tf.GraphDef()
+            try:
+                graph_def = tf.compat.v1.GraphDef()
+            except:
+                graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            input_map = self._check_input_shape(graph_def)
            self._fix_output_shape(graph_def)
            self.sess.graph.as_default()
            tf.import_graph_def(graph_def, name='', input_map=input_map)
-        self.sess.run(tf.global_variables_initializer())
+        try:
+            initializer = tf.compat.v1.global_variables_initializer()
+        except:
+            initializer = tf.global_variables_initializer()
+        self.sess.run(initializer)
        self.tf_graph = TFGraph(
            self.sess.graph._as_graph_def(add_shapes=True)[0], data_format)
...@@ -237,7 +246,6 @@ class TFDecoder(object):
                continue
            graph_node = TFGraphNode(layer)
            dtype = graph_node.layer.attr['dtype'].type
-            print("========dtype", dtype)
            need_define_shape = 0
            if self.define_input_shape:
...@@ -284,11 +292,17 @@ class TFDecoder(object):
                    for dim in shape.strip().split(',')
                ]
                assert shape.count(None) <= 1, "Only one dimension can be None"
-                print("]]]]]]]]]dtype", dtype)
-                x2paddle_input = tf.placeholder(dtype=dtype,
-                                                shape=shape,
-                                                name="x2paddle_{}".format(
-                                                    layer.name))
+                try:
+                    x2paddle_input = tf.compat.v1.placeholder(
+                        dtype=dtype,
+                        shape=shape,
+                        name="x2paddle_{}".format(layer.name))
+                except:
+                    x2paddle_input = tf.placeholder(dtype=dtype,
+                                                    shape=shape,
+                                                    name="x2paddle_{}".format(
+                                                        layer.name))
                input_map["{}:0".format(layer.name)] = x2paddle_input
                if shape.count(None) > 0:
                    shape[shape.index(None)] = -1
...@@ -304,7 +318,6 @@ class TFDecoder(object):
    # trick method
    # should be removed after PaddlePaddle V1.6 been released
    def infer_tensor(self, graph_node):
-        print("========== Use infer_tensor for tensor: ", graph_node.layer.name)
        if hasattr(graph_node, "index"):
            tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
        else:
...@@ -320,8 +333,6 @@ class TFDecoder(object):
        return self.sess.run([output_tensor], feed)[0]

    def infer_shape_tensor(self, graph_node, out_shape=None):
-        print("========== Use infer_shape_tensor for tensor: ",
-              graph_node.layer.name)
        if hasattr(graph_node, "index"):
            tensor_name = graph_node.layer.name + ":{}".format(graph_node.index)
        else:
...
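The recurring pattern in the decoder hunks above is to try the tf.compat.v1 symbol first and fall back to the legacy top-level name, so loading works whether or not the installed TensorFlow build exposes the compat module. A condensed sketch of the same idea (the _compat helper is illustrative, not part of x2paddle):

import tensorflow as tf


def _compat(name):
    # Prefer tf.compat.v1.<name>; fall back to tf.<name> on builds that
    # predate the compat module.
    compat = getattr(tf, "compat", None)
    if compat is not None and hasattr(compat, "v1") and hasattr(compat.v1, name):
        return getattr(compat.v1, name)
    return getattr(tf, name)


# Usage mirroring TFDecoder.__init__:
# sess = _compat("Session")()
# graph_def = _compat("GraphDef")()
# sess.run(_compat("global_variables_initializer")())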
...@@ -67,13 +67,15 @@ class TFOpMapper(OpMapper):
        'RealDiv': 'elementwise_div',
        'Sub': 'elementwise_sub',
        'Maximum': 'elementwise_max',
-        'Mul': 'elementwise_mul'
+        'Mul': 'elementwise_mul',
+        'FloorDiv': 'elementwise_floordiv'
    }

    def __init__(self, decoder):
        super(TFOpMapper, self).__init__()
        self.decoder = decoder
        self.graph = decoder.tf_graph
+        self.batch_node = None
        self.weights = dict()
        self.omit_nodes = list()
        self.used_custom_layers = dict()
...@@ -86,9 +88,10 @@ class TFOpMapper(OpMapper):
            idx = self.graph.input_nodes.index(name)
            del self.graph.input_nodes[idx]
-        print("Total nodes: {}".format(len(self.graph.topo_sort)))
+        sys.stderr.write("Total nodes: {}\n".format(len(self.graph.topo_sort)))
        unsupported_ops = set()
-        for node_name in self.graph.topo_sort:
+        for i, node_name in enumerate(self.graph.topo_sort):
+            sys.stderr.write("\rConverting node {} ... ".format(i + 1))
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if op in self.directly_map_ops:
...@@ -107,11 +110,13 @@ class TFOpMapper(OpMapper):
            else:
                unsupported_ops.add(op)
        if len(unsupported_ops) > 0:
-            print("=========={} Ops are not supported yet======".format(
-                len(unsupported_ops)))
+            sys.stderr.write(
+                "=========={} Ops are not supported yet======\n".format(
+                    len(unsupported_ops)))
            for op in unsupported_ops:
-                print("========== {} ==========".format(op))
+                sys.stderr.write("========== {} ==========\n".format(op))
            sys.exit(-1)
+        sys.stderr.write('\nDone!\n')

    def add_omit_nodes(self, in_node_name, out_node_name):
        in_node = self.graph.get_node(in_node_name)
...@@ -144,6 +149,10 @@ class TFOpMapper(OpMapper):
        y = self.graph.get_node(node.layer.input[1], copy=True)
        x_shape = x.out_shapes[0]
        y_shape = y.out_shapes[0]
+        if len(x_shape) == 0:
+            x_shape = [1]
+        if len(y_shape) == 0:
+            y_shape = [1]
        # incomplement broadcasting support for paddle
        x_input = x
        y_input = y
...@@ -237,6 +246,9 @@ class TFOpMapper(OpMapper):
            'name': string(node.layer_name),
            'append_batch_size': False
        }
+        if shape[0] < 0:
+            self.batch_node = node
+
        node.fluid_code.add_layer("data",
                                  inputs=None,
                                  output=node,
...@@ -285,17 +297,28 @@ class TFOpMapper(OpMapper):
        perm = perm.value.tolist()
        if perm == [0, 3, 1, 2] and input.data_format == "NHWC":
-            node.fluid_code.add_layer("assign",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=None)
+            # node.fluid_code.add_layer("assign",
+            #                           inputs=input,
+            #                           output=node,
+            #                           param_attr=None)
+            input_name = input.layer_name
+            if hasattr(input, "index"):
+                input_name = input_name + "[{}]".format(input.index)
+            node.fluid_code.add_layer("{} = {}".format(node.layer_name,
+                                                       input_name))
            node.tf_data_format = "NCHW"
            self.graph.data_format_propagation(node)
        elif perm == [0, 2, 3, 1] and input.tf_data_format == "NCHW":
-            node.fluid_code.add_layer("assign",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=None)
+            input_name = input.layer_name
+            if hasattr(input, "index"):
+                input_name = input_name + "[{}]".format(input.index)
+            node.fluid_code.add_layer("{} = {}".format(node.layer_name,
+                                                       input_name))
+            #
+            # node.fluid_code.add_layer("assign",
+            #                           inputs=input,
+            #                           output=node,
+            #                           param_attr=None)
            node.tf_data_format = "NHWC"
            self.graph.data_format_propagation(node)
        elif len(input.out_shapes[0]) > 4:
...@@ -564,6 +587,20 @@ class TFOpMapper(OpMapper):
                new_param += (node.layer_name + "[{}]".format(i) + ", ")
            new_param = new_param.strip(", ") + "]"
            attr = {"shape": new_param}
+        if len(input.out_shapes[0]) == 4 and node.tf_data_format == "NHWC":
+            if len(attr["shape"]) < 3:
+                perm = {"perm": [0, 2, 3, 1]}
+                node.fluid_code.add_layer("transpose",
+                                          inputs=input,
+                                          output=node,
+                                          param_attr=perm)
+                node.fluid_code.add_layer("reshape",
+                                          inputs=node,
+                                          output=node,
+                                          param_attr=attr)
+                return
+
        if len(attr["shape"]) == 4 and node.tf_data_format == "NHWC":
            input_shape = self.decoder.infer_tensor(input).shape
            if input_shape[1] == attr["shape"][1]:
...@@ -860,17 +897,32 @@ class TFOpMapper(OpMapper):
            size = [size[i] for i in [0, 3, 1, 2]]
            begin = [begin[i] for i in [0, 3, 1, 2]]
-        attr = {"shape": size, "offsets": begin}
-        node.fluid_code.add_layer("crop",
+        for i in range(len(size)):
+            if size[i] < 0:
+                size[i] = 99999999
+            else:
+                size[i] = size[i] + begin[i]
+        attr = {
+            "axes": [i for i in range(len(size))],
+            "starts": begin,
+            "ends": size
+        }
+        node.fluid_code.add_layer("slice",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)
    def Conv2DBackpropInput(self, node):
-        input = self.graph.get_node(node.layer.input[0], copy=True)
+        out_shape = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
+        input = self.graph.get_node(node.layer.input[2], copy=True)
        assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"
        self.add_omit_nodes(kernel.layer_name, node.layer_name)
+        self.add_omit_nodes(out_shape.layer_name, node.layer_name)
        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
            in_shape = self.decoder.infer_tensor(input).shape
...@@ -878,14 +930,14 @@ class TFOpMapper(OpMapper):
        if k_size.count(-1) > 2:
            k_size = self.decoder.infer_tensor(kernel).shape
-        pad_mode = node.get_attr("padding")
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
+        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"
        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (3, 2, 0, 1))
        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
...@@ -906,6 +958,7 @@ class TFOpMapper(OpMapper):
                                      output=node,
                                      param_attr=attr)
            input = node
+
        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
...@@ -915,11 +968,10 @@ class TFOpMapper(OpMapper):
            "dilation": dilations[2:4],
            "padding": padding
        }
-        node.fluid_code.add_layer(
-            "conv2d_transpose",
-            inputs=input if channel_first and pad_mode != "SAME" else node,
-            output=node,
-            param_attr=attr)
+        node.fluid_code.add_layer("conv2d_transpose",
+                                  inputs=input,
+                                  output=node,
+                                  param_attr=attr)

    def Max(self, node):
        input = self.graph.get_node(node.layer.input[0], copy=True)
...@@ -960,18 +1012,19 @@ class TFOpMapper(OpMapper):
                                      output=node,
                                      param_attr=attr)

-    def FloorDiv(self, node):
-        x = self.graph.get_node(node.layer.input[0], copy=True)
-        y = self.graph.get_node(node.layer.input[1], copy=True)
-        inputs = {'x': x, 'y': y}
-        node.fluid_code.add_layer("elementwise_div",
-                                  inputs=inputs,
-                                  output=node,
-                                  param_attr=None)
-        node.fluid_code.add_layer("floor",
-                                  inputs=node,
-                                  output=node,
-                                  param_attr=None)
+    # def FloorDiv(self, node):
+    #     x = self.graph.get_node(node.layer.input[0], copy=True)
+    #     y = self.graph.get_node(node.layer.input[1], copy=True)
+    #     inputs = {'x': x, 'y': y}
+    #     node.fluid_code.add_layer("elementwise_div",
+    #                               inputs=inputs,
+    #                               output=node,
+    #                               param_attr=None)
+    #     node.fluid_code.add_layer("floor",
+    #                               inputs=node,
+    #                               output=node,
+    #                               param_attr=None)

    def Split(self, node):
        dim = self.graph.get_node(node.layer.input[0], copy=True)
...@@ -1082,3 +1135,34 @@ class TFOpMapper(OpMapper):
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)
+
+    def GreaterEqual(self, node):
+        x = self.graph.get_node(node.layer.input[0], copy=True)
+        y = self.graph.get_node(node.layer.input[1], copy=True)
+        inputs = {"x": x, "y": y}
+        node.fluid_code.add_layer("greater_equal",
+                                  inputs=inputs,
+                                  output=node,
+                                  param_attr=None)
+
+    def RandomUniform(self, node):
+        shape = self.graph.get_node(node.layer.input[0], copy=True)
+        self.add_omit_nodes(shape.layer_name, node.layer_name)
+        if shape.layer_type == "Const":
+            shape = shape.value.tolist()
+        else:
+            shape = self.decoder.infer_shape_tensor(shape)
+        if node.tf_data_format == "NHWC" and len(shape) == 4:
+            shape = [shape[i] for i in [0, 3, 1, 2]]
+        attr = {"shape": shape, "min": 0.0, "max": 0.9999}
+        if shape[0] < 0:
+            input = self.batch_node
+            node.fluid_code.add_layer("uniform_random_batch_size_like",
+                                      inputs=input,
+                                      output=node,
+                                      param_attr=attr)
+        else:
+            node.fluid_code.add_layer("uniform_random",
+                                      inputs=None,
+                                      output=node,
+                                      param_attr=attr)
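The Slice rewrite in this file swaps Paddle's crop for slice, converting TensorFlow's per-axis (begin, size) pair into (starts, ends); a negative size means "to the end of the axis" and is encoded with a very large end value. A small standalone sketch of that arithmetic (the function name is illustrative):

def begin_size_to_starts_ends(begin, size, big=99999999):
    # TF Slice takes per-axis (begin, size); Paddle slice takes (starts, ends).
    # size == -1 in TF means "everything from begin to the end of the axis",
    # which the mapper encodes as a very large end value.
    starts = list(begin)
    ends = []
    for b, s in zip(begin, size):
        ends.append(big if s < 0 else b + s)
    return starts, ends


print(begin_size_to_starts_ends([0, 2, 0], [1, -1, 3]))
# ([0, 2, 0], [1, 99999999, 3])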
...@@ -56,6 +56,7 @@ class TFOpMapperNHWC(OpMapper):
        self.decoder = decoder
        self.graph = decoder.tf_graph
        self.weights = dict()
+        self.batch_node = None
        self.omit_nodes = list()
        self.used_custom_layers = dict()
...@@ -68,8 +69,9 @@ class TFOpMapperNHWC(OpMapper):
            del self.graph.input_nodes[idx]
        unsupported_ops = set()
-        print("Total nodes: {}".format(len(self.graph.topo_sort)))
+        sys.stderr.write("Total nodes: {}\n".format(len(self.graph.topo_sort)))
-        for node_name in self.graph.topo_sort:
+        for i, node_name in enumerate(self.graph.topo_sort):
+            sys.stderr.write("\rConverting node {} ... ".format(i))
            node = self.graph.get_node(node_name)
            op = node.layer_type
            if op in self.directly_map_ops:
...@@ -94,6 +96,7 @@ class TFOpMapperNHWC(OpMapper):
            for op in unsupported_ops:
                print("========== {} ============".format(op))
            sys.exit(-1)
+        sys.stderr.write("\nDone\n")

    def add_omit_nodes(self, in_node_name, out_node_name):
        in_node = self.graph.get_node(in_node_name)
...@@ -126,6 +129,10 @@ class TFOpMapperNHWC(OpMapper):
        y = self.graph.get_node(node.layer.input[1], copy=True)
        x_shape = x.out_shapes[0]
        y_shape = y.out_shapes[0]
+        if len(x_shape) == 0:
+            x_shape = [1]
+        if len(y_shape) == 0:
+            y_shape = [1]
        # incomplement broadcasting support for paddle
        x_input = x
        y_input = y
...@@ -199,6 +206,8 @@ class TFOpMapperNHWC(OpMapper):
            'name': string(node.layer_name),
            'append_batch_size': False
        }
+        if shape[0] < 0:
+            self.batch_node = node
        node.fluid_code.add_layer("data",
                                  inputs=None,
                                  output=node,
...@@ -823,7 +832,6 @@ class TFOpMapperNHWC(OpMapper):
                                      inputs=input,
                                      output=node,
                                      param_attr=attr)
-        print(node.layer.name)
        if len(new_axes) > 0:
            attr = {"axes": new_axes}
            node.fluid_code.add_layer("unsqueeze",
...@@ -857,17 +865,32 @@ class TFOpMapperNHWC(OpMapper):
        else:
            size = self.decoder.infer_tensor(size).tolist()
-        attr = {"shape": size, "offsets": begin}
-        node.fluid_code.add_layer("crop",
+        for i in range(len(size)):
+            if size[i] < 0:
+                size[i] = 99999999
+            else:
+                size[i] = size[i] + begin[i]
+        attr = {
+            "axes": [i for i in range(len(size))],
+            "starts": begin,
+            "ends": size
+        }
+        node.fluid_code.add_layer("slice",
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)

    def Conv2DBackpropInput(self, node):
-        input = self.graph.get_node(node.layer.input[0], copy=True)
+        out_shape = self.graph.get_node(node.layer.input[0], copy=True)
        kernel = self.graph.get_node(node.layer.input[1], copy=True)
+        input = self.graph.get_node(node.layer.input[2], copy=True)
        assert kernel.layer_type == "Const", "Kernel of Conv2DBackpropInput should be Const"
        self.add_omit_nodes(kernel.layer_name, node.layer_name)
+        self.add_omit_nodes(out_shape.layer_name, node.layer_name)
        in_shape = input.out_shapes[0]
        if in_shape.count(-1) > 2:
...@@ -876,14 +899,14 @@ class TFOpMapperNHWC(OpMapper):
        if k_size.count(-1) > 2:
            k_size = self.decoder.infer_tensor(kernel).shape
-        pad_mode = node.get_attr("padding")
        strides = node.get_attr("strides")
        dilations = node.get_attr("dilations")
        data_format = node.get_attr("data_format").decode()
+        pad_mode = node.get_attr("padding").decode()
        channel_first = data_format == "NCHW"
        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
            kernel.value, (3, 2, 0, 1))
        if not channel_first:
            in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
            strides = [strides[i] for i in [0, 3, 1, 2]]
...@@ -894,6 +917,9 @@ class TFOpMapperNHWC(OpMapper):
                                      output=node,
                                      param_attr=attr)
            input = node
+        else:
+            self.data_format_propagation(node)
+
        padding = 0
        if pad_mode == "SAME":
            pad_h = get_same_padding(in_shape[2], k_size[0], strides[2])
...@@ -907,6 +933,7 @@ class TFOpMapperNHWC(OpMapper):
                                      output=node,
                                      param_attr=attr)
            input = node
+
        attr = {
            "bias_attr": False,
            "param_attr": string(kernel.layer_name),
...@@ -920,6 +947,7 @@ class TFOpMapperNHWC(OpMapper):
                                  inputs=input,
                                  output=node,
                                  param_attr=attr)
+
        if not channel_first:
            attr = {"perm": [0, 2, 3, 1]}
            node.fluid_code.add_layer("transpose",
...@@ -1062,3 +1090,32 @@ class TFOpMapperNHWC(OpMapper):
                                  inputs=node,
                                  output=node,
                                  param_attr=attr)
+
+    def GreaterEqual(self, node):
+        x = self.graph.get_node(node.layer.input[0], copy=True)
+        y = self.graph.get_node(node.layer.input[1], copy=True)
+        inputs = {"x": x, "y": y}
+        node.fluid_code.add_layer("greater_equal",
+                                  inputs=inputs,
+                                  output=node,
+                                  param_attr=None)
+
+    def RandomUniform(self, node):
+        shape = self.graph.get_node(node.layer.input[0], copy=True)
+        self.add_omit_nodes(shape.layer_name, node.layer_name)
+        if shape.layer_type == "Const":
+            shape = shape.value.tolist()
+        else:
+            shape = self.decoder.infer_shape_tensor(shape)
+        attr = {"shape": shape, "min": 0.0, "max": 0.9999}
+        if shape[0] < 0:
+            input = self.batch_node
+            node.fluid_code.add_layer("uniform_random_batch_size_like",
+                                      inputs=input,
+                                      output=node,
+                                      param_attr=attr)
+        else:
+            node.fluid_code.add_layer("uniform_random",
+                                      inputs=None,
+                                      output=node,
+                                      param_attr=attr)
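The new RandomUniform mapping chooses between two Paddle ops: when the requested shape still has a dynamic batch dimension (shape[0] < 0) it emits uniform_random_batch_size_like against the data node remembered in self.batch_node, otherwise a plain uniform_random. A toy sketch of that dispatch decision (the helper and its return convention are illustrative, not the mapper's API):

def choose_uniform_op(shape, batch_node=None):
    # Mirrors the RandomUniform mapping above: a leading -1 means the batch
    # size is only known at run time, so the generated layer must borrow it
    # from the network's input node (self.batch_node in the mapper).
    attr = {"shape": shape, "min": 0.0, "max": 0.9999}
    if shape[0] < 0:
        assert batch_node is not None, "a dynamic batch needs a data node"
        return "uniform_random_batch_size_like", batch_node, attr
    return "uniform_random", None, attr


print(choose_uniform_op([-1, 64], batch_node="x2paddle_input")[0])
# uniform_random_batch_size_like
print(choose_uniform_op([8, 64])[0])
# uniform_random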
...@@ -26,10 +26,12 @@ class TFOptimizer(object):
    }
    layers_with_act = [
        'Conv2D', 'BiasAdd', 'DepthwiseConv2dNative', 'Conv2DBackpropInput',
-        'FusedBatchNorm'
+        'FusedBatchNorm', 'conv2d', 'elementwise_add', 'conv2d_transpose',
+        'batch_norm'
    ]
    layers_with_bias = [
-        'Conv2D', 'DepthwiseConv2dNative', 'Conv2DBackpropInput'
+        'Conv2D', 'DepthwiseConv2dNative', 'Conv2DBackpropInput', 'conv2d',
+        'conv2d_transpose'
    ]

    def __init__(self, op_mapper):
...@@ -129,7 +131,12 @@ class TFOptimizer(object):
                    continue
                if len(input.outputs) != 1:
                    continue
-                input.fluid_code.layers[-1].param_attr['act'] = string(
+                index = -1
+                for i in range(len(input.fluid_code.layers)):
+                    if input.fluid_code.layers[i].op in self.layers_with_act:
+                        index = i
+                        break
+                input.fluid_code.layers[index].param_attr['act'] = string(
                    self.activation_ops[node.layer_type])
                input.fluid_code.layers[-1].output = node.fluid_code.layers[
                    0].output
...@@ -153,45 +160,70 @@ class TFOptimizer(object):
                if 'act' in node.fluid_code.layers[-1].param_attr:
                    bias_with_act = True
                layer_with_act = False
+                index = -1
+                for i in range(len(input.fluid_code.layers)):
+                    if input.fluid_code.layers[i].op in self.layers_with_bias:
+                        index = i
+                        break
                if 'act' in input.fluid_code.layers[
-                        -1].param_attr and input.fluid_code.layers[
-                            -1].param_attr['act'] is not None:
+                        index].param_attr and input.fluid_code.layers[
+                            index].param_attr['act'] is not None:
                    layer_with_act = True
                if bias_with_act and layer_with_act:
                    continue
-                if not input.fluid_code.layers[-1].param_attr['bias_attr']:
+                if not input.fluid_code.layers[index].param_attr['bias_attr']:
                    bias_name = node.inputs[1]
-                    input.fluid_code.layers[-1].param_attr[
+                    input.fluid_code.layers[index].param_attr[
                        'bias_attr'] = string(bias_name)
                    input.fluid_code.layers[-1].output = node.fluid_code.layers[
                        0].output
                    if bias_with_act:
-                        input.fluid_code.layers[-1].param_attr[
+                        input.fluid_code.layers[index].param_attr[
                            'act'] = node.fluid_code.layers[-1].param_attr[
                                'act']
                    node.fluid_code.clear()
+                    self.graph.remove_node(node.layer_name)
+
+    def remove_transpose(self):
+        optimize_ops = [
+            'Conv2D', 'MaxPool', 'FusedBatchNorm', 'DepthwiseConv2dNative',
+            'AvgPool', 'Pad', 'Conv2DBackpropInput', 'ResizeNearestNeighbor',
+            'ResizeBilinear'
+        ]
+        for node_name in self.graph.topo_sort:
+            node = self.graph.get_node(node_name)
+            if node is None:
+                continue
+            if node.layer_type not in optimize_ops:
+                continue
+            if node.fluid_code.layers[
+                    -1].op != "transpose" or node.fluid_code.layers[
+                        -1].param_attr["perm"] != [0, 2, 3, 1]:
+                continue
+            output_names = node.outputs
+            can_be_removed = True
+            for out_name in output_names:
+                out_node = self.graph.get_node(out_name)
+                if out_node.layer_type == "BiasAdd":
+                    can_be_removed = True
+                if out_node.fluid_code.layers[
+                        0].op != "transpose" or out_node.fluid_code.layers[
+                            0].param_attr["perm"] != [0, 3, 1, 2]:
+                    can_be_removed = False
+                    break
+            if can_be_removed and len(output_names) > 0:
+                last_out = node.fluid_code.layers[-1].inputs
+                del node.fluid_code.layers[-1]
+                for out_name in output_names:
+                    out_node = self.graph.get_node(out_name)
+                    if out_node.layer_type == "BiasAdd":
+                        del out_node.fluid_code.layers[0]
+                        out_node.fluid_code.layers[0].inputs['x'] = last_out
+                        # out_node.fluid_code.layers[0].param_attr["axis"] = 1
+                    else:
+                        del out_node.fluid_code.layers[0]
+                        out_node.fluid_code.layers[0].inputs = last_out
-    # def remove_transpose(self):
-    #     optimize_ops = ['Conv2D', 'MaxPool', 'FusedBatchNorm', 'DepthwiseConv2dNative', 'AvgPool', 'Pad', 'Conv2DBackpropInput', 'ResizeNearestNeighbor', 'ResizeBilinear']
-    #     for node_name in self.graph.topo_sort:
-    #         node = self.graph.get_node(node_name)
-    #         if node.layer_type not in optimize_ops:
-    #             continue
-    #         if node.fluid_code.layers[-1].op != "transpose" or node.fluid_code.layers[-1].param_attr["perm"] != [0, 2, 3, 1]:
-    #             continue
-    #         output_names = node.outputs
-    #         can_be_removed = True
-    #         for out_name in outputs_names:
-    #             out_node = self.graph.get_node(out_name)
-    #             if out_node.fluid_code.layers[0].op != "transpose" or out_node.fluid_code.layers[-1].param_attr["perm"] != [0, 3, 1, 2]:
-    #                 can_be_removed = False
-    #                 break
-    #         if can_be_removed and len(output_names) > 0:
-    #             last_out = node.fluid_code.layers[-1].inputs
-    #             del node.fluid_code.layers[-1]
-    #             for out_name in outputs_names:
-    #                 out_node = self.graph.get_node(out_name)
-    #                 del out_node.fluid_code.layers[0]
-    #                 out_node.fluid_code.layers[0].inputs = last_out
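The remove_transpose pass activated here is a peephole optimization: when a layout-sensitive layer ends with a transpose back to NHWC (perm [0, 2, 3, 1]) and every consumer starts by transposing to NCHW again (perm [0, 3, 1, 2]), the two transposes cancel and both are dropped. A toy check of that cancellation condition (a simplified stand-in, not the FluidCode data structures used above):

def cancels(producer_perm, consumer_perm):
    # Two back-to-back transposes cancel when composing their permutations
    # yields the identity, e.g. NCHW->NHWC ([0, 2, 3, 1]) followed by
    # NHWC->NCHW ([0, 3, 1, 2]).
    composed = [producer_perm[i] for i in consumer_perm]
    return composed == list(range(len(producer_perm)))


print(cancels([0, 2, 3, 1], [0, 3, 1, 2]))  # True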