Unverified  Commit 8fecf767  authored by J Jason, committed by GitHub

Merge pull request #68 from jiangjiajun/develop

add optimizer for tf2fluid
1. Add logic to handle NHWC-format models
2. Fold bias and activation ops into the preceding layer
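
The first hunk below wires the new optimizer into the tf2paddle entry point. A minimal sketch of the resulting flow, assuming a frozen TensorFlow pb file as input (names taken from the diff itself; error handling omitted):

    from x2paddle.decoder.tf_decoder import TFDecoder
    from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
    from x2paddle.optimizer.tf_optimizer import TFOptimizer

    def tf2paddle(model_path, save_dir):
        # decode the frozen GraphDef and map TF ops onto Paddle fluid layers
        model = TFDecoder(model_path)
        mapper = TFOpMapper(model)
        # run the optimizer passes added by this PR
        optimizer = TFOptimizer(mapper)
        optimizer.delete_redundance_code()  # drop layers of fully omitted nodes
        optimizer.merge_activation()        # experimental: fuse activation into the previous layer
        optimizer.merge_bias()              # experimental: fuse BiasAdd into the previous layer
        mapper.save_inference_model(save_dir)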
@@ -67,10 +67,17 @@ def tf2paddle(model_path, save_dir):
     from x2paddle.decoder.tf_decoder import TFDecoder
     from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
+    from x2paddle.optimizer.tf_optimizer import TFOptimizer

     print("Now translating model from tensorflow to paddle.")
     model = TFDecoder(model_path)
     mapper = TFOpMapper(model)
+    optimizer = TFOptimizer(mapper)
+    # necessary optimization
+    optimizer.delete_redundance_code()
+    # optimizations below are experimental
+    optimizer.merge_activation()
+    optimizer.merge_bias()
     mapper.save_inference_model(save_dir)
...
@@ -99,29 +99,6 @@ class Graph(object):
         self.node_map[dst].inputs.append(src)
         self.node_map[src].outputs.append(dst)

-    def remove_node(self, node_name):
-        if node_name not in self.node_map:
-            raise Exception("Node[{}] not in graph".format(node_name))
-        inputs = self.node_map[node_name].inputs
-        outputs = self.node_map[node_name].outputs
-        for input in inputs:
-            idx = self.node_map[input].outputs.index(node_name)
-            del self.node_map[input].outputs[idx]
-        for output in outputs:
-            idx = self.node_map[input].inputs.index(node_name)
-            del self.node_map[input].inputs[idx]
-        del self.node_map[node_name]
-        idx = self.topo_sort.index(node_name)
-        del self.topo_sort[idx]
-        if node_name in self.input_nodes:
-            idx = self.input_nodes.index(node_name)
-            del self.input_nodes[idx]
-        if node_name in self.output_nodes:
-            idx = self.output_nodes.index(node_name)
-            del self.output_nodes[idx]
-
     def print(self):
         for i, tmp in enumerate(self.topo_sort):
             print(tmp, self.node_map[tmp].layer_type, self.node_map[tmp].inputs,
...
@@ -116,7 +116,7 @@ class OpMapper(object):
                 feeded_var_names=input_names,
                 target_vars=outputs,
                 executor=exe,
-                params_filename="__params__")
+                params_filename=None)
         except:
             raise Exception(
                 "Paddle code was saved in {}/model.py, but seems there's wrong exist, please check model.py manually."
@@ -142,9 +142,9 @@ class OpMapper(object):
         self.add_codes("\ndef x2paddle_net():", 0)
         for i in range(len(self.graph.topo_sort)):
             node_name = self.graph.topo_sort[i]
-            if hasattr(self, "omit_nodes") and node_name in self.omit_nodes:
-                continue
             node = self.graph.get_node(node_name)
+            if len(node.fluid_code.layers) == 0:
+                continue
             self.add_codes(node.fluid_code.gen_codes(), 1)
         self.add_codes("", 0)
...
@@ -24,7 +24,7 @@ import sys

 class TFGraphNode(GraphNode):
-    def __init__(self, layer, layer_name=None):
+    def __init__(self, layer, layer_name=None, data_format="NHWC"):
         if layer_name is None:
             super(TFGraphNode,
                   self).__init__(layer,
@@ -35,6 +35,8 @@ class TFGraphNode(GraphNode):
                 layer_name.replace('/', '_').replace('-', '_'))

         self.layer_type = layer.op
+        self.tf_data_format = data_format
+        self.pd_data_format = "NCHW"
         self.fluid_code = FluidCode()
         self.dtype_map = {1: "float32", 3: "int32", 4: "int8", 9: "int64"}
@@ -86,15 +88,16 @@ class TFGraphNode(GraphNode):

 class TFGraph(Graph):
-    def __init__(self, model):
+    def __init__(self, model, data_format="NHWC"):
         super(TFGraph, self).__init__(model)
         self.identity_map = dict()
         self.multi_out_ops = ['Split', 'SplitV']
+        self.tf_data_format = data_format

     def build(self):
         for layer in self.model.node:
             self.node_map[layer.name.replace('/', '_').replace(
-                '-', '_')] = TFGraphNode(layer)
+                '-', '_')] = TFGraphNode(layer, data_format=self.tf_data_format)

         for layer_name, node in self.node_map.items():
             for in_node in node.layer.input:
@@ -126,6 +129,26 @@ class TFGraph(Graph):
             node.index = 0
         return node

+    def remove_node(self, node_name):
+        if node_name not in self.node_map:
+            raise Exception("Node[{}] not in graph".format(node_name))
+        inputs = self.node_map[node_name].inputs
+        outputs = self.node_map[node_name].outputs
+        assert len(inputs) == 1
+        input_node = self.node_map[inputs[0]]
+        idx = input_node.outputs.index(node_name)
+        del input_node.outputs[idx]
+        for output in outputs:
+            node = self.node_map[output]
+            idx = node.inputs.index(node_name)
+            node.inputs[idx] = inputs[0]
+            input_node.outputs.append(output)
+
+        del self.node_map[node_name]
+
+        idx = self.topo_sort.index(node_name)
+        del self.topo_sort[idx]
+
     def _remove_isolated_node(self):
         # delete isolated nodes
         isolated_nodes = list()
@@ -135,7 +158,15 @@ class TFGraph(Graph):
                 isolated_nodes.append(node_name)

         for node_name in isolated_nodes:
-            self.remove_node(node_name)
+            del self.node_map[node_name]
+            if node_name in self.input_nodes:
+                idx = self.input_nodes.index(node_name)
+                del self.input_nodes[idx]
+            if node_name in self.output_nodes:
+                idx = self.output_nodes.index(node_name)
+                del self.output_nodes[idx]
+            idx = self.topo_sort.index(node_name)
+            del self.topo_sort[idx]

     def _remove_identity_node(self):
         identity_node = list()
@@ -145,30 +176,47 @@ class TFGraph(Graph):
         for node_name in identity_node:
             node = self.get_node(node_name)
-            # Remind: Only 1 input for Identity node
             input_node = self.get_node(node.inputs[0])
+            self.remove_node(node_name)

             # remove identity node from graph
             self.identity_map[node_name] = input_node.layer_name
-            idx = input_node.outputs.index(node_name)
-            del input_node.outputs[idx]
-
-            output_names = node.outputs
-            for output_name in output_names:
-                output_node = self.get_node(output_name)
-                idx = output_node.inputs.index(node_name)
-                output_node.inputs[idx] = input_node.layer_name
-
-            idx = self.topo_sort.index(node_name)
-            del self.topo_sort[idx]
+            # node = self.get_node(node_name)
+            # # Remind: Only 1 input for Identity node
+            # input_node = self.get_node(node.inputs[0])
+            #
+            # # remove identity node from graph
+            # self.identity_map[node_name] = input_node.layer_name
+            # idx = input_node.outputs.index(node_name)
+            # del input_node.outputs[idx]
+            #
+            # output_names = node.outputs
+            # for output_name in output_names:
+            #     output_node = self.get_node(output_name)
+            #     idx = output_node.inputs.index(node_name)
+            #     output_node.inputs[idx] = input_node.layer_name
+            #
+            # idx = self.topo_sort.index(node_name)
+            # del self.topo_sort[idx]

             if node_name in self.output_nodes:
                 idx = self.output_nodes.index(node_name)
                 self.output_nodes[idx] = input_node.layer_name
+
+    def data_format_propagation(self, node):
+        current_node = self.node_map[node.layer_name]
+        outputs = current_node.outputs
+        if len(outputs) == 0:
+            return
+        for out in outputs:
+            next_node = self.node_map[out]
+            next_node.tf_data_format = node.tf_data_format
+            self.data_format_propagation(next_node)

 class TFDecoder(object):
-    def __init__(self, pb_model):
+    def __init__(self, pb_model, data_format="NHWC"):
         self.sess = tf.Session()
         self.input_info = dict()
         with gfile.FastGFile(pb_model, 'rb') as f:
@@ -186,7 +234,7 @@ class TFDecoder(object):
         self.sess.run(tf.global_variables_initializer())

         self.tf_graph = TFGraph(
-            self.sess.graph._as_graph_def(add_shapes=True)[0])
+            self.sess.graph._as_graph_def(add_shapes=True)[0], data_format)
         self.tf_graph.build()

     def _fix_output_shape(self, graph):
...
@@ -28,6 +28,25 @@ def get_same_padding(in_size, kernel_size, stride):
     return [pad0, pad1]


+def nhwc_dim_to_nchw(node, dim):
+    tf_data_format = list(node.tf_data_format)
+    pd_data_format = list(node.pd_data_format)
+    if isinstance(dim, list):
+        for i in range(len(dim)):
+            char = tf_data_format[dim[i]]
+            dim[i] = pd_data_format.index(char)
+    else:
+        char = tf_data_format[dim]
+        dim = pd_data_format.index(char)
+    return dim
+
+
 class TFOpMapper(OpMapper):
     directly_map_ops = {
         'Relu': ['relu'],
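
The helper added above translates a TensorFlow axis index into its Paddle position by matching data-format letters. A hypothetical worked example (not part of the diff), assuming a node created with the defaults tf_data_format="NHWC" and pd_data_format="NCHW":

    nhwc_dim_to_nchw(node, 3)       # axis 3 is 'C' in "NHWC" -> "NCHW".index('C') == 1
    nhwc_dim_to_nchw(node, [1, 2])  # 'H' and 'W' -> [2, 3]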
@@ -37,17 +56,11 @@ class TFOpMapper(OpMapper):
         'Sigmoid': ['sigmoid'],
         'Exp': ['exp'],
         'Rsqrt': ['rsqrt'],
-        'Squeeze': ['squeeze', {
-            'squeeze_dims': 'axes'
-        }],
-        'Softmax': ['softmax', {
-            'axis': 'axis'
-        }],
+        'swish_f32': ['swish']
     }
     elementwise_ops = {
         'Add': 'elementwise_add',
         'RealDiv': 'elementwise_div',
-        'BiasAdd': 'elementwise_add',
         'Sub': 'elementwise_sub',
         'Maximum': 'elementwise_max',
         'Mul': 'elementwise_mul'
@@ -121,6 +134,19 @@ class TFOpMapper(OpMapper):
         else:
             raise Exception("Unexpected situation happend")

+        if len(x_shape) == 4 and len(y_shape) == 1:
+            if x_input.tf_data_format == "NHWC":
+                axis = 1
+            else:
+                axis = -1
+            attr = {"axis": axis}
+            inputs = {"x": x_input, "y": y_input}
+            node.fluid_code.add_layer(op_type,
+                                      inputs=inputs,
+                                      output=node,
+                                      param_attr=attr)
+            return
+
         is_sub_seq = True
         for i in range(len(y_shape)):
             index = -1 * i - 1
@@ -143,6 +169,10 @@ class TFOpMapper(OpMapper):
         else:
             raise Exception("Unexpected situation happend")
         if x_need_expand:
+            if len(x_expand_times) == 3 and x.tf_data_format == "NHWC":
+                x_expand_times = [x_expand_times[i] for i in [2, 0, 1]]
+            if len(x_expand_times) == 4 and x.tf_data_format == "NHWC":
+                x_expand_times = [x_expand_times[i] for i in [0, 3, 1, 2]]
             attr = {"expand_times": x_expand_times}
             node.fluid_code.add_layer("expand",
                                       inputs=x_input,
@@ -150,6 +180,10 @@ class TFOpMapper(OpMapper):
                                       param_attr=attr)
             x_input = "x_tmp"
         if y_need_expand:
+            if len(y_expand_times) == 3 and y.tf_data_format == "NHWC":
+                y_expand_times = [y_expand_times[i] for i in [2, 0, 1]]
+            if len(y_expand_times) == 4 and y.tf_data_format == "NHWC":
+                y_expand_times = [y_expand_times[i] for i in [0, 3, 1, 2]]
             attr = {"expand_times": y_expand_times}
             node.fluid_code.add_layer("expand",
                                       inputs=y_input,
@@ -166,6 +200,10 @@ class TFOpMapper(OpMapper):
         shape = node.out_shapes[0]
         assert len(shape) != 0, "Unknown shape of input nodes[{}].".format(
             node.layer_name)
+        if node.tf_data_format == "NHWC" and len(shape) == 4:
+            shape = [shape[i] for i in [0, 3, 1, 2]]
+        elif node.tf_data_format == "NCHW" and len(shape) == 4:
+            self.graph.data_format_propagation(node)
         dtype = node.dtype
         attr = {
             'dtype': string(dtype),
@@ -188,6 +226,19 @@ class TFOpMapper(OpMapper):
             shape = [1]
         initializer = "Constant({})".format(value)

+        self.weights[node.layer_name] = node.value
+
+        if node.tf_data_format == "NHWC":
+            if len(shape) == 4:
+                shape = [shape[i] for i in [0, 3, 1, 2]]
+            if len(shape) == 3:
+                shape = [shape[i] for i in [2, 0, 1]]
+                self.weights[node.layer_name] = numpy.transpose(
+                    node.value, (2, 0, 1))
+        elif node.tf_data_format == "NCHW":
+            if len(shape) == 4:
+                self.graph.data_format_propagation(node)
+
         attr = {
             'dtype': string(dtype),
             'shape': shape,
@@ -198,7 +249,6 @@ class TFOpMapper(OpMapper):
                                   inputs=None,
                                   output=node,
                                   param_attr=attr)
-        self.weights[node.layer_name.replace('/', '_')] = node.value

     def Transpose(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -208,11 +258,46 @@ class TFOpMapper(OpMapper):
         perm.fluid_code.clear()
         perm = perm.value.tolist()

-        attr = {'perm': perm}
-        node.fluid_code.add_layer("transpose",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=attr)
+        if perm == [0, 3, 1, 2] and input.tf_data_format == "NHWC":
+            node.fluid_code.add_layer("assign",
+                                      inputs=input,
+                                      output=node,
+                                      param_attr=None)
+            node.tf_data_format = "NCHW"
+            self.graph.data_format_propagation(node)
+        elif perm == [0, 2, 3, 1] and input.tf_data_format == "NCHW":
+            node.fluid_code.add_layer("assign",
+                                      inputs=input,
+                                      output=node,
+                                      param_attr=None)
+            node.tf_data_format = "NHWC"
+            self.graph.data_format_propagation(node)
+        elif len(input.out_shapes[0]) > 4:
+            print(input.layer_name, input.tf_data_format, input.pd_data_format)
+            tf_data_format = list(input.tf_data_format)
+            pd_data_format = list(input.pd_data_format)
+            new_perm = [i for i in range(len(perm))]
+            for i in range(len(perm)):
+                char0 = tf_data_format[i]
+                char1 = tf_data_format[perm[i]]
+                index0 = pd_data_format.index(char0)
+                index1 = pd_data_format.index(char1)
+                new_perm[index0] = index1
+            node.tf_data_format = [tf_data_format[i] for i in perm]
+            node.pd_data_format = [pd_data_format[i] for i in perm]
+            attr = {'perm': new_perm}
+            node.fluid_code.add_layer("transpose",
+                                      inputs=input,
+                                      output=node,
+                                      param_attr=attr)
+        elif len(node.out_shapes[0]) != 4:
+            attr = {'perm': perm}
+            node.fluid_code.add_layer("transpose",
+                                      inputs=input,
+                                      output=node,
+                                      param_attr=attr)
+        else:
+            raise Exception("Unexpected situation happend in Transpose OP")

     def MaxPool(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -226,16 +311,14 @@ class TFOpMapper(OpMapper):
         data_format = node.get_attr("data_format").decode()
         pad_mode = node.get_attr("padding").decode()
         channel_first = data_format == "NCHW"
+        padding = 0

         if not channel_first:
-            attr = {"perm": [0, 3, 1, 2]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
             in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
             strides = [strides[i] for i in [0, 3, 1, 2]]
             k_size = [k_size[i] for i in [0, 3, 1, 2]]
+        else:
+            self.graph.data_format_propagation(node)

         if pad_mode == "SAME":
             pad_h = get_same_padding(in_shape[2], k_size[2], strides[2])
@@ -243,29 +326,21 @@ class TFOpMapper(OpMapper):
             pad_h = pad_h[0] + pad_h[1]
             pad_w = pad_w[0] + pad_w[1]
             attr = {"paddings": [0, pad_h, 0, pad_w], "pad_value": -10000.0}
-            if pad_h + pad_w != 0:
-                node.fluid_code.add_layer(
-                    "pad2d",
-                    inputs=input if channel_first else node,
-                    output=node,
-                    param_attr=attr)
+            node.fluid_code.add_layer("pad2d",
+                                      inputs=input,
+                                      output=node,
+                                      param_attr=attr)
+            input = node
         attr = {
             "pool_size": k_size[2:4],
             "pool_type": string("max"),
+            "pool_padding": padding,
             "pool_stride": strides[2:4]
         }
-        node.fluid_code.add_layer(
-            "pool2d",
-            inputs=input if channel_first and pad_mode != "SAME" else node,
-            output=node,
-            param_attr=attr)
-
-        if not channel_first:
-            attr = {"perm": [0, 2, 3, 1]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=node,
-                                      output=node,
-                                      param_attr=attr)
+        node.fluid_code.add_layer("pool2d",
+                                  inputs=input,
+                                  output=node,
+                                  param_attr=attr)

     def Conv2D(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -288,49 +363,56 @@ class TFOpMapper(OpMapper):
         data_format = node.get_attr("data_format").decode()
         pad_mode = node.get_attr("padding").decode()
         channel_first = data_format == "NCHW"
+        padding = 0
+
+        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
+            kernel.value, (3, 2, 0, 1))

         if not channel_first:
-            self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
-                kernel.value, (3, 2, 0, 1))
-            attr = {"perm": [0, 3, 1, 2]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
             in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
             strides = [strides[i] for i in [0, 3, 1, 2]]
             dilations = [dilations[i] for i in [0, 3, 1, 2]]
+        else:
+            self.graph.data_format_propagation(node)

         if pad_mode == "SAME":
             pad_h = get_same_padding(in_shape[2], k_size[0], strides[2])
             pad_w = get_same_padding(in_shape[3], k_size[1], strides[3])
-            attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
-            if pad_h[0] + pad_h[1] + pad_w[0] + pad_w[1] != 0:
-                node.fluid_code.add_layer(
-                    "pad2d",
-                    inputs=input if channel_first else node,
-                    output=node,
-                    param_attr=attr)
+            if pad_h[0] == pad_h[1] and pad_w[0] == pad_w[1]:
+                padding = [pad_h[0], pad_w[0]]
+            else:
+                attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
+                node.fluid_code.add_layer("pad2d",
+                                          inputs=input,
+                                          output=node,
+                                          param_attr=attr)
+                input = node
         attr = {
             "bias_attr": False,
             "param_attr": string(kernel.layer_name),
             "num_filters": k_size[3],
             "filter_size": k_size[0:2],
             "stride": strides[2:4],
-            "dilation": dilations[2:4]
+            "dilation": dilations[2:4],
+            "padding": padding
         }
-        node.fluid_code.add_layer(
-            "conv2d",
-            inputs=input if channel_first and pad_mode != "SAME" else node,
-            output=node,
-            param_attr=attr)
-
-        if not channel_first:
-            attr = {"perm": [0, 2, 3, 1]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=node,
-                                      output=node,
-                                      param_attr=attr)
+        node.fluid_code.add_layer("conv2d",
+                                  inputs=input,
+                                  output=node,
+                                  param_attr=attr)
+
+    def BiasAdd(self, node):
+        input = self.graph.get_node(node.layer.input[0], copy=True)
+        bias = self.graph.get_node(node.layer.input[1], copy=True)
+        axis = -1
+        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
+            axis = 1
+        inputs = {"x": input, "y": bias}
+        attr = {"axis": axis}
+        node.fluid_code.add_layer("elementwise_add",
+                                  inputs=inputs,
+                                  output=node,
+                                  param_attr=attr)

     def FusedBatchNorm(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
@@ -350,17 +432,12 @@ class TFOpMapper(OpMapper):
         self.omit_nodes.append(moving_mean.layer_name)
         self.omit_nodes.append(moving_var.layer_name)

-        if not channel_first:
-            attr = {"perm": [0, 3, 1, 2]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
+        if channel_first:
+            self.graph.data_format_propagation(node)

         attr = {
             "epsilon": node.get_attr("epsilon"),
             "param_attr": string(gamma.layer_name),
-            # "data_layout": string(node.get_attr("data_format").decode()),
             "bias_attr": string(beta.layer_name),
             "moving_mean_name": string(moving_mean.layer_name),
             "moving_variance_name": string(moving_var.layer_name),
@@ -368,17 +445,10 @@ class TFOpMapper(OpMapper):
         }

         node.fluid_code.add_layer("batch_norm",
-                                  inputs=input if channel_first else node,
+                                  inputs=input,
                                   output=node,
                                   param_attr=attr)

-        if not channel_first:
-            attr = {"perm": [0, 2, 3, 1]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=node,
-                                      output=node,
-                                      param_attr=attr)
-
     def DepthwiseConv2dNative(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         kernel = self.graph.get_node(node.layer.input[1], copy=True)
@@ -400,29 +470,31 @@ class TFOpMapper(OpMapper):
         data_format = node.get_attr("data_format").decode()
         pad_mode = node.get_attr("padding").decode()
         channel_first = data_format == "NCHW"
+        padding = 0
+
+        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
+            kernel.value, (2, 3, 0, 1))

         if not channel_first:
-            self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
-                kernel.value, (2, 3, 0, 1))
-            attr = {"perm": [0, 3, 1, 2]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
             in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
             strides = [strides[i] for i in [0, 3, 1, 2]]
             dilations = [dilations[i] for i in [0, 3, 1, 2]]
+        else:
+            self.graph.data_format_propagation(node)

         if pad_mode == "SAME":
             pad_h = get_same_padding(in_shape[2], k_size[0], strides[2])
             pad_w = get_same_padding(in_shape[3], k_size[1], strides[3])
-            attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
-            if pad_h[0] + pad_h[1] + pad_w[0] + pad_w[1] != 0:
-                node.fluid_code.add_layer("pad2d",
-                                          inputs=input if channel_first
-                                          and pad_mode != "SAME" else node,
-                                          output=node,
-                                          param_attr=attr)
+            if pad_h[0] == pad_h[1] and pad_w[0] == pad_w[1]:
+                padding = [pad_h[0], pad_w[0]]
+            else:
+                attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
+                node.fluid_code.add_layer("pad2d",
+                                          inputs=input,
+                                          output=node,
+                                          param_attr=attr)
+                input = node

         attr = {
             "bias_attr": False,
             "param_attr": string(kernel.layer_name),
@@ -430,20 +502,14 @@ class TFOpMapper(OpMapper):
             "filter_size": k_size[0:2],
             "stride": strides[2:4],
             "dilation": dilations[2:4],
-            "groups": k_size[3] * in_shape[1]
+            "groups": k_size[3] * in_shape[1],
+            "padding": padding
         }
         node.fluid_code.add_layer("conv2d",
-                                  inputs=input if channel_first else node,
+                                  inputs=input,
                                   output=node,
                                   param_attr=attr)

-        if not channel_first:
-            attr = {"perm": [0, 2, 3, 1]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=node,
-                                      output=node,
-                                      param_attr=attr)
-
     def Reshape(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         param = self.graph.get_node(node.layer.input[1], copy=True)
@@ -474,6 +540,8 @@ class TFOpMapper(OpMapper):
                 new_param += (node.layer_name + "[{}]".format(i) + ", ")
             new_param = new_param.strip(", ") + "]"
             attr = {"shape": new_param}
+        if len(attr["shape"]) == 4 and node.tf_data_format == "NHWC":
+            attr["shape"] = [attr["shape"][i] for i in [0, 3, 1, 2]]
         node.fluid_code.add_layer("reshape",
                                   inputs=input,
                                   output=node,
@@ -493,14 +561,11 @@ class TFOpMapper(OpMapper):
         channel_first = data_format == "NCHW"

         if not channel_first:
-            attr = {"perm": [0, 3, 1, 2]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
             in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
             strides = [strides[i] for i in [0, 3, 1, 2]]
             k_size = [k_size[i] for i in [0, 3, 1, 2]]
+        else:
+            self.graph.data_format_propagation(node)

         attr = {
             "pool_size": k_size[2:4],
@@ -514,17 +579,10 @@ class TFOpMapper(OpMapper):
                 1], "Cannot map AvgPool"
             attr["pool_padding"] = [pad_h[0], pad_w[0]]
         node.fluid_code.add_layer("pool2d",
-                                  inputs=input if channel_first else node,
+                                  inputs=input,
                                   output=node,
                                   param_attr=attr)

-        if not channel_first:
-            attr = {"perm": [0, 2, 3, 1]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=node,
-                                      output=node,
-                                      param_attr=attr)
-
     def SplitV(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         num_sections = self.graph.get_node(node.layer.input[1], copy=True)
@@ -533,6 +591,9 @@ class TFOpMapper(OpMapper):
         assert dim.layer_type == "Const"
         self.omit_nodes.append(num_sections.layer_name)
         self.omit_nodes.append(dim.layer_name)
+        dim = dim.value
+        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
+            dim = nhwc_dim_to_nchw(input, dim)
         attr = {
             "num_or_sections": num_sections.value.tolist(),
-            "dim": dim.value
+            "dim": dim
@@ -550,7 +611,11 @@ class TFOpMapper(OpMapper):
         axis = self.graph.get_node(node.layer.input[-1], copy=True)
         assert axis.layer_type == "Const"
         self.omit_nodes.append(axis.layer_name)
-        attr = {"axis": axis.value}
+        axis = axis.value
+        if inputs[0].tf_data_format == "NHWC" and len(
+                inputs[0].out_shapes[0]) == 4:
+            axis = nhwc_dim_to_nchw(inputs[0], axis)
+        attr = {"axis": axis}
         node.fluid_code.add_layer("concat",
                                   inputs=inputs,
                                   output=node,
@@ -561,7 +626,13 @@ class TFOpMapper(OpMapper):
         expand_times = self.graph.get_node(node.layer.input[1], copy=True)
         assert expand_times.layer_type == "Const"
         self.omit_nodes.append(expand_times.layer_name)
-        attr = {"expand_times": expand_times.value.tolist()}
+        expand_times = expand_times.value.tolist()
+        if input.tf_data_format == "NHWC":
+            if len(input.out_shapes[0]) == 4:
+                expand_times = [expand_times[i] for i in [0, 3, 1, 2]]
+            elif len(input.out_shapes[0]) == 3:
+                expand_times = [expand_times[i] for i in [2, 0, 1]]
+        attr = {"expand_times": expand_times}
         node.fluid_code.add_layer("expand",
                                   inputs=input,
                                   output=node,
@@ -571,7 +642,18 @@ class TFOpMapper(OpMapper):
         inputs = [
             self.graph.get_node(name, copy=True) for name in node.layer.input
         ]
-        attr = {"axis": node.get_attr("axis")}
+        axis = node.get_attr("axis")
+        if inputs[0].tf_data_format == "NHWC" and len(
+                inputs[0].out_shapes[0]) == 4:
+            tf_data_format = list(inputs[0].tf_data_format)
+            tf_data_format.insert(axis, str(len(tf_data_format)))
+            axis = nhwc_dim_to_nchw(inputs[0], axis)
+            pd_data_format = list(inputs[0].pd_data_format)
+            pd_data_format.insert(axis, str(len(pd_data_format)))
+            node.tf_data_format = "".join(tf_data_format)
+            node.pd_data_format = "".join(pd_data_format)
+
+        attr = {"axis": axis}
         node.fluid_code.add_layer("stack",
                                   inputs=inputs,
                                   output=node,
@@ -582,7 +664,10 @@ class TFOpMapper(OpMapper):
         paddings = self.graph.get_node(node.layer.input[1], copy=True)
         assert paddings.layer_type == "Const", "Padding should be Const"
         self.omit_nodes.append(paddings.layer_name)
-        attr = {"paddings": paddings.value.flatten().tolist()}
+        paddings = paddings.value.flatten().tolist()
+        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
+            paddings = [paddings[i] for i in [0, 1, 6, 7, 2, 3, 4, 5]]
+        attr = {"paddings": paddings}
         node.fluid_code.add_layer("pad",
                                   inputs=input,
                                   output=node,
@@ -608,24 +693,18 @@ class TFOpMapper(OpMapper):
                                   output=node,
                                   param_attr=None)

-    def swish_f32(self, node):
-        input = self.graph.get_node(node.layer.input[0], copy=True)
-        node.fluid_code.add_layer("sigmoid",
-                                  inputs=input,
-                                  output=node,
-                                  param_attr=None)
-        inputs = {"x": input, "y": node}
-        node.fluid_code.add_layer("elementwise_mul",
-                                  inputs=inputs,
-                                  output=node,
-                                  param_attr=None)
-
     def Mean(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
         assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
+        dims = reduce_idx.value.tolist()
         keep_dims = node.get_attr("keep_dims")
-        attr = {"dim": reduce_idx.value.tolist(), "keep_dim": keep_dims}
+
+        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
+            for i in range(len(dims)):
+                dims[i] = nhwc_dim_to_nchw(input, dims[i])
+
+        attr = {"dim": dims, "keep_dim": keep_dims}
         node.fluid_code.add_layer("reduce_mean",
                                   inputs=input,
                                   output=node,
@@ -658,7 +737,10 @@ class TFOpMapper(OpMapper):
         axis = self.graph.get_node(node.layer.input[1], copy=True)
         assert axis.layer_type == "Const", "ArgMax only support Const parameter"
         self.omit_nodes.append(axis.layer_name)
-        attr = {"axis": axis.value}
+        axis = axis.value
+        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
+            axis = nhwc_dim_to_nchw(input, axis)
+        attr = {"axis": axis}
         node.fluid_code.add_layer("argmax",
                                   inputs=input,
                                   output=node,
@@ -678,11 +760,13 @@ class TFOpMapper(OpMapper):
         strides = strides.value.tolist()
         assert len(set(strides)) == 1 and strides[0] == 1

-        attr = {
-            "axes": range(len(strides)),
-            "starts": begin.value.tolist(),
-            "ends": end.value.tolist()
-        }
+        begin = begin.value.tolist()
+        end = end.value.tolist()
+        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
+            begin = [begin[i] for i in [0, 3, 1, 2]]
+            end = [end[i] for i in [0, 3, 1, 2]]
+
+        attr = {"axes": range(len(strides)), "starts": begin, "ends": end}
         node.fluid_code.add_layer("slice",
                                   inputs=input,
                                   output=node,
@@ -705,6 +789,10 @@ class TFOpMapper(OpMapper):
         else:
             size = self.decoder.infer_tensor(size).tolist()

+        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
+            size = [size[i] for i in [0, 3, 1, 2]]
+            begin = [begin[i] for i in [0, 3, 1, 2]]
+
         attr = {"shape": size, "offsets": begin}
         node.fluid_code.add_layer("crop",
                                   inputs=input,
@@ -732,36 +820,37 @@ class TFOpMapper(OpMapper):
         data_format = node.get_attr("data_format").decode()
         pad_mode = node.get_attr("padding").decode()
         channel_first = data_format == "NCHW"

+        self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
+            kernel.value, (3, 2, 0, 1))
+
         if not channel_first:
-            self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
-                kernel.value, (3, 2, 0, 1))
-            attr = {"perm": [0, 3, 1, 2]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=input,
-                                      output=node,
-                                      param_attr=attr)
             in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
             strides = [strides[i] for i in [0, 3, 1, 2]]
             dilations = [dilations[i] for i in [0, 3, 1, 2]]
+        else:
+            self.graph.data_format_propagation(node)

+        padding = 0
         if pad_mode == "SAME":
             pad_h = get_same_padding(in_shape[2], k_size[0], strides[2])
             pad_w = get_same_padding(in_shape[3], k_size[1], strides[3])
-            attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
-            if pad_h[0] + pad_h[1] + pad_w[0] + pad_w[1] != 0:
-                node.fluid_code.add_layer(
-                    "pad2d",
-                    inputs=input if channel_first else node,
-                    output=node,
-                    param_attr=attr)
+            if pad_h[0] == pad_h[1] and pad_w[0] == pad_w[1]:
+                padding = [pad_h[0], pad_w[0]]
+            else:
+                attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
+                node.fluid_code.add_layer("pad2d",
+                                          inputs=input,
+                                          output=node,
+                                          param_attr=attr)
+                input = node

         attr = {
             "bias_attr": False,
             "param_attr": string(kernel.layer_name),
             "num_filters": k_size[3],
             "filter_size": k_size[0:2],
             "stride": strides[2:4],
-            "dilation": dilations[2:4]
+            "dilation": dilations[2:4],
+            "padding": padding
         }
         node.fluid_code.add_layer(
             "conv2d_transpose",
@@ -769,19 +858,16 @@ class TFOpMapper(OpMapper):
             output=node,
             param_attr=attr)

-        if not channel_first:
-            attr = {"perm": [0, 2, 3, 1]}
-            node.fluid_code.add_layer("transpose",
-                                      inputs=node,
-                                      output=node,
-                                      param_attr=attr)
-
     def Max(self, node):
         input = self.graph.get_node(node.layer.input[0], copy=True)
         reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
         assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
         keep_dims = node.get_attr("keep_dims")
-        attr = {"dim": reduce_idx.value.tolist(), "keep_dim": keep_dims}
+        dim = reduce_idx.value.tolist()
+        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
+            dim = nhwc_dim_to_nchw(input, dim)
+
+        attr = {"dim": dim, "keep_dim": keep_dims}
         node.fluid_code.add_layer("reduce_max",
                                   inputs=input,
                                   output=node,
@@ -792,7 +878,11 @@ class TFOpMapper(OpMapper):
         reduce_idx = self.graph.get_node(node.layer.input[1], copy=True)
         assert reduce_idx.layer_type == "Const", "Only support Const parameter[reduce_idx]"
         keep_dims = node.get_attr("keep_dims")
-        attr = {"dim": reduce_idx.value.tolist(), "keep_dim": keep_dims}
+        dim = reduce_idx.value.tolist()
+        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
+            dim = nhwc_dim_to_nchw(input, dim)
+
+        attr = {"dim": dim, "keep_dim": keep_dims}
         node.fluid_code.add_layer("reduce_sum",
                                   inputs=input,
                                   output=node,
@@ -826,8 +916,35 @@ class TFOpMapper(OpMapper):
         assert dim.layer_type == "Const"
         self.omit_nodes.append(dim.layer_name)
         num_split = node.get_attr('num_split')
-        attr = {"num_or_sections": num_split, "dim": dim.value}
+        dim = dim.value
+        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
+            dim = nhwc_dim_to_nchw(input, dim)
+
+        attr = {"num_or_sections": num_split, "dim": dim}
         node.fluid_code.add_layer("split",
                                   inputs=input,
                                   output=node,
                                   param_attr=attr)
+
+    def Squeeze(self, node):
+        input = self.graph.get_node(node.layer.input[0], copy=True)
+        squeeze_dims = node.get_attr('squeeze_dims')
+        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
+            for i in range(len(squeeze_dims)):
+                squeeze_dims[i] = nhwc_dim_to_nchw(input, squeeze_dims[i])
+        attr = {"axes": squeeze_dims}
+        node.fluid_code.add_layer("squeeze",
+                                  inputs=input,
+                                  output=node,
+                                  param_attr=attr)
+
+    def Softmax(self, node):
+        input = self.graph.get_node(node.layer.input[0], copy=True)
+        axis = node.get_attr("axis")
+        if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
+            axis = nhwc_dim_to_nchw(input, axis)
+        attr = {"axis": axis}
+        node.fluid_code.add_layer("softmax",
+                                  inputs=input,
+                                  output=node,
+                                  param_attr=attr)
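
Net effect of the NHWC changes in this file: for an NHWC model, the mapper now pre-transposes kernel weights offline and emits a single NCHW op per layer, instead of wrapping every layer in a transpose/op/transpose sandwich. Roughly, for one SAME-padded 3x3 convolution, the generated code looks like the sketch below (hypothetical tensor and parameter names, reconstructed from the attr dict above rather than copied from actual tool output):

    import paddle.fluid as fluid

    # x is a hypothetical NCHW input tensor; the kernel is already stored
    # transposed to (3, 2, 0, 1), so no runtime transposes are emitted
    conv1 = fluid.layers.conv2d(input=x,
                                num_filters=32,
                                filter_size=[3, 3],
                                stride=[1, 1],
                                dilation=[1, 1],
                                padding=[1, 1],
                                param_attr='conv1_kernel',
                                bias_attr=False)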
...
@@ -13,10 +13,95 @@
 # limitations under the License.

 # TODO useless node remove
-from x2paddle.decoder.tf_decoder import TFGraph
-
-# TODO bn merge
-# TODO activation merge
-# TODO biasadd merge
+from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
+from x2paddle.core.util import *
+
+
+class TFOptimizer(object):
+    activation_ops = {
+        'Relu': 'relu',
+        'Sigmoid': 'sigmoid',
+        'Relu6': 'relu6',
+        'swish_f32': 'swish'
+    }
+    layers_with_act = [
+        'Conv2D', 'BiasAdd', 'DepthwiseConv2dNative', 'Conv2DBackpropInput',
+        'FusedBatchNorm'
+    ]
+    layers_with_bias = [
+        'Conv2D', 'DepthwiseConv2dNative', 'Conv2DBackpropInput'
+    ]
+
+    def __init__(self, op_mapper):
+        self.op_mapper = op_mapper
+        self.graph = op_mapper.graph
+
+    def delete_redundance_code(self):
+        for node_name in self.graph.topo_sort:
+            if node_name in self.op_mapper.omit_nodes:
+                node = self.graph.get_node(node_name)
+                omit_freq = self.op_mapper.omit_nodes.count(node_name)
+                if len(node.outputs) <= omit_freq:
+                    node.fluid_code.clear()
+
+    # TODO activation merge
+    def merge_activation(self):
+        act_nodes = list()
+        for node_name in self.graph.topo_sort:
+            node = self.graph.get_node(node_name)
+            if node.layer_type in self.activation_ops:
+                act_nodes.append(node_name)
+
+        for act_node_name in act_nodes:
+            node = self.graph.get_node(act_node_name)
+            input = self.graph.get_node(node.inputs[0])
+            if input.layer_type not in self.layers_with_act:
+                continue
+            if len(input.fluid_code.layers) == 0:
+                continue
+            if 'act' in input.fluid_code.layers[
+                    -1].param_attr and input.fluid_code.layers[-1].param_attr[
+                        'act'] is not None:
+                continue
+            if len(input.outputs) != 1:
+                continue
+            input.fluid_code.layers[-1].param_attr['act'] = string(
+                self.activation_ops[node.layer_type])
+            input.fluid_code.layers[-1].output = node.fluid_code.layers[
+                0].output
+            self.graph.remove_node(act_node_name)
+
+    # TODO bias merge
+    def merge_bias(self):
+        for node_name in self.graph.topo_sort:
+            node = self.graph.get_node(node_name)
+            if node.layer_type == "BiasAdd":
+                input = self.graph.get_node(node.inputs[0])
+                if input.layer_type not in self.layers_with_bias:
+                    continue
+                if len(input.outputs) != 1:
+                    continue
+                if len(input.fluid_code.layers) == 0:
+                    continue
+                bias_with_act = False
+                if 'act' in node.fluid_code.layers[-1].param_attr:
+                    bias_with_act = True
+                layer_with_act = False
+                if 'act' in input.fluid_code.layers[
+                        -1].param_attr and input.fluid_code.layers[
+                            -1].param_attr['act'] is not None:
+                    layer_with_act = True
+                if bias_with_act and layer_with_act:
+                    continue
+                if not input.fluid_code.layers[-1].param_attr['bias_attr']:
+                    bias_name = node.inputs[1]
+                    input.fluid_code.layers[-1].param_attr[
+                        'bias_attr'] = string(bias_name)
+                    input.fluid_code.layers[-1].output = node.fluid_code.layers[
+                        0].output
+                    if bias_with_act:
+                        input.fluid_code.layers[-1].param_attr[
+                            'act'] = node.fluid_code.layers[-1].param_attr[
+                                'act']
+                    node.fluid_code.clear()
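
Taken together, merge_bias and merge_activation collapse a Conv2D -> BiasAdd -> Relu chain into a single generated layer: merge_bias moves the bias variable into the conv's bias_attr and clears the BiasAdd's code, while merge_activation sets the act argument and renames the output. A sketch of the intended effect on the emitted code, with hypothetical variable and parameter names (not actual tool output):

    # before optimization: three generated layers
    conv1 = fluid.layers.conv2d(input=x, num_filters=32, filter_size=[3, 3],
                                param_attr='conv1_kernel', bias_attr=False)
    bias1 = fluid.layers.elementwise_add(x=conv1, y=conv1_bias, axis=1)
    relu1 = fluid.layers.relu(bias1)

    # after merge_bias + merge_activation: one layer, output renamed to relu1
    relu1 = fluid.layers.conv2d(input=x, num_filters=32, filter_size=[3, 3],
                                param_attr='conv1_kernel',
                                bias_attr='conv1_bias', act='relu')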