提交 32b98f8e 编写于 作者: J jiangjiajun

mobilenet support for tensorflow

上级 4e3cdf05
......@@ -13,7 +13,7 @@
# limitations under the License.
import collections
from copy import deepcopy
import copy as cp
class GraphNode(object):
......@@ -77,7 +77,7 @@ class Graph(object):
if name.split(':')[0] in self.node_map:
name_prefix, idx = name.split(':')
if copy:
node = deepcopy(self.node_map[name_prefix])
node = cp.copy(self.node_map[name_prefix])
else:
node = self.node_map[name_prefix]
node.index = int(idx)
......@@ -86,7 +86,7 @@ class Graph(object):
raise Exception("Graph doesn't have node [%s]." % name)
else:
if copy:
node = deepcopy(self.node_map[name])
node = cp.copy(self.node_map[name])
else:
node = self.node_map[name]
return node
......@@ -110,6 +110,7 @@ class Graph(object):
del self.node_map[input].inputs[idx]
del self.node_map[node_name]
print("remove topo", node_name)
idx = self.topo_sort.index(node_name)
del self.topo_sort[idx]
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from x2paddle.core.util import *
import inspect
import os
......@@ -33,7 +34,8 @@ class OpMapper(object):
if len(unsupported_ops) == 0:
return True
else:
print("There are {} ops not supported yet, list as below")
print("There are {} ops not supported yet, list as below".format(
len(unsupported_ops)))
for op in unsupported_ops:
print(op)
return False
......@@ -41,9 +43,10 @@ class OpMapper(object):
def add_codes(self, codes, indent=0):
if isinstance(codes, list):
for code in codes:
self.paddle_codes += (self.tab * indent + code + '\n')
self.paddle_codes += (self.tab * indent + code.strip('\n') +
'\n')
elif isinstance(codes, str):
self.paddle_codes += (self.tab * indent + codes + '\n')
self.paddle_codes += (self.tab * indent + codes.strip('\n') + '\n')
else:
raise Exception("Unknown type of codes")
......@@ -61,6 +64,8 @@ class OpMapper(object):
export_paddle_param(param, name, save_dir)
self.add_heads()
self.add_codes(self.net_code)
self.add_codes("")
self.add_codes(inspect.getsourcelines(init_net)[0])
fp = open(os.path.join(save_dir, "model.py"), 'w')
fp.write(self.paddle_codes)
fp.close()
......@@ -13,7 +13,8 @@
# limitations under the License.
from paddle.fluid.proto import framework_pb2
import struct
import paddle.fluid as fluid
import numpy
import math
import os
......@@ -49,14 +50,29 @@ def export_paddle_param(param, param_name, dir):
os.makedirs(dir)
fp = open(os.path.join(dir, param_name), 'wb')
fp.write(struct.pack('i', 0))
fp.write(struct.pack('L', 0))
fp.write(struct.pack('i', 0))
numpy.array([0], dtype='int32').tofile(fp)
numpy.array([0], dtype='int64').tofile(fp)
numpy.array([0], dtype='int32').tofile(fp)
tensor_desc = framework_pb2.VarType.TensorDesc()
tensor_desc.data_type = dtype_map[str(param.dtype)][0]
tensor_desc.dims.extend(shape)
desc_size = tensor_desc.ByteSize()
fp.write(struct.pack('i', desc_size))
numpy.array([desc_size], dtype='int32').tofile(fp)
fp.write(tensor_desc.SerializeToString())
param.tofile(fp)
fp.close()
def init_net(param_dir="./"):
import os
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
def if_exist(var):
b = os.path.exists(os.path.join(param_dir, var.name))
return b
fluid.io.load_vars(exe,
param_dir,
fluid.default_main_program(),
predicate=if_exist)
......@@ -18,7 +18,8 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import gfile
from tensorflow.core.framework import attr_value_pb2
import tensorflow as tf
import copy
import copy as cp
import sys
class TFGraphNode(GraphNode):
......@@ -121,11 +122,12 @@ class TFGraph(Graph):
# delete isolated nodes
isolated_nodes = list()
for node_name in self.node_map.keys():
if len(self.get_node(node_name).inputs) == 0 or len(
if len(self.get_node(node_name).inputs) == 0 and len(
self.get_node(node_name).outputs) == 0:
isolated_nodes.append(node_name)
self.remove_node(node_name)
for node_name in isolated_nodes:
self.remove_node(node_name)
def _remove_identity_node(self):
identity_node = list()
......@@ -153,14 +155,40 @@ class TFGraph(Graph):
del self.topo_sort[idx]
def check_input_shape(graph_def):
graph_def = cp.deepcopy(graph_def)
input_map = dict()
for layer in graph_def.node:
if layer.op != "Placeholder":
continue
graph_node = TFGraphNode(layer)
dtype = graph_node.dtype
# print("shape:", graph_node.out_shapes)
if not graph_node.get_attr("shape"):
sys.stderr.write("Unknown shape for input tensor[{}]\n".format(
layer.name))
shape = input("Please define shape of input here: ")
shape = [
None if dim == "None" else int(dim)
for dim in shape.strip().split(',')
]
x2paddle_input = tf.placeholder(dtype=dtype,
shape=shape,
name="x2paddle_{}".format(
layer.name))
input_map["{}:0".format(layer.name)] = x2paddle_input
return input_map
class TFDecoder(object):
def __init__(self, pb_model):
sess = tf.Session()
with gfile.FastGFile(pb_model, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
input_map = check_input_shape(graph_def)
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
tf.import_graph_def(graph_def, name='', input_map=input_map)
sess.run(tf.global_variables_initializer())
......
......@@ -48,6 +48,8 @@ class TFOpMapper(OpMapper):
def Placeholder(self, node):
shape = node.out_shapes[0]
assert len(shape) != 0, "Unknown shape of input nodes[{}].".format(
node.layer_name)
dtype = node.dtype
attr = {
'dtype': string(dtype),
......@@ -171,10 +173,11 @@ class TFOpMapper(OpMapper):
"pool_type": string("max"),
"pool_stride": strides[2:4]
}
node.fluid_code.add_layer("pool2d",
inputs=input if channel_first else node,
output=node,
param_attr=attr)
node.fluid_code.add_layer(
"pool2d",
inputs=input if channel_first and pad_mode != "SAME" else node,
output=node,
param_attr=attr)
if not channel_first:
attr = {"perm": [0, 2, 3, 1]}
......@@ -227,6 +230,102 @@ class TFOpMapper(OpMapper):
"stride": strides[2:4],
"dilation": dilations[2:4]
}
node.fluid_code.add_layer(
"conv2d",
inputs=input if channel_first and pad_mode != "SAME" else node,
output=node,
param_attr=attr)
if not channel_first:
attr = {"perm": [0, 2, 3, 1]}
node.fluid_code.add_layer("transpose",
inputs=node,
output=node,
param_attr=attr)
def Relu6(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("relu6",
inputs=input,
output=node,
param_attr=None)
def FusedBatchNorm(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
gamma = self.graph.get_node(node.layer.input[1], copy=True)
beta = self.graph.get_node(node.layer.input[2], copy=True)
moving_mean = self.graph.get_node(node.layer.input[3], copy=True)
moving_var = self.graph.get_node(node.layer.input[4], copy=True)
assert gamma.layer_type == "Const"
assert beta.layer_type == "Const"
assert moving_mean.layer_type == "Const"
assert moving_var.layer_type == "Const"
self.omit_nodes.append(gamma.layer_name)
self.omit_nodes.append(beta.layer_name)
self.omit_nodes.append(moving_mean.layer_name)
self.omit_nodes.append(moving_var.layer_name)
attr = {
"epsilon": node.get_attr("epsilon"),
"param_attr": string(gamma.layer_name),
"data_layout": string(node.get_attr("data_format").decode()),
"bias_attr": string(beta.layer_name),
"moving_mean_name": string(moving_mean.layer_name),
"moving_variance_name": string(moving_var.layer_name),
"is_test": True
}
node.fluid_code.add_layer("batch_norm",
inputs=input,
output=node,
param_attr=attr)
def DepthwiseConv2dNative(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
kernel = self.graph.get_node(node.layer.input[1], copy=True)
assert kernel.layer_type == "Const", "Kernel of DepthwiseConv2DNative should be Const"
self.omit_nodes.append(kernel.layer_name)
in_shape = input.out_shapes[0]
k_size = kernel.out_shapes[0]
strides = node.get_attr("strides")
dilations = node.get_attr("dilations")
data_format = node.get_attr("data_format").decode()
pad_mode = node.get_attr("padding").decode()
channel_first = data_format == "NCHW"
if not channel_first:
self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
kernel.value, (2, 3, 0, 1))
attr = {"perm": [0, 3, 1, 2]}
node.fluid_code.add_layer("transpose",
inputs=input,
output=node,
param_attr=attr)
in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
strides = [strides[i] for i in [0, 3, 1, 2]]
dilations = [dilations[i] for i in [0, 3, 1, 2]]
if pad_mode == "SAME":
pad_h = get_same_padding(in_shape[2], k_size[0], strides[2])
pad_w = get_same_padding(in_shape[3], k_size[1], strides[3])
attr = {"paddings": pad_h + pad_w, "pad_value": 0.0}
if pad_h[0] + pad_h[1] + pad_w[0] + pad_w[1] != 0:
node.fluid_code.add_layer("pad2d",
inputs=input if channel_first
and pad_mode != "SAME" else node,
output=node,
param_attr=attr)
attr = {
"bias_attr": False,
"param_attr": string(kernel.layer_name),
"num_filters": in_shape[1],
"filter_size": k_size[0:2],
"stride": strides[2:4],
"dilation": dilations[2:4],
"groups": k_size[3] * in_shape[1]
}
node.fluid_code.add_layer("conv2d",
inputs=input if channel_first else node,
output=node,
......@@ -238,3 +337,91 @@ class TFOpMapper(OpMapper):
inputs=node,
output=node,
param_attr=attr)
def Shape(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("shape",
inputs=input,
output=node,
param_attr=None)
def Reshape(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
param = self.graph.get_node(node.layer.input[1], copy=True)
if param.layer_type == "Const":
attr = {"shape": param.value.tolist()}
else:
# Here is a trick method to solove tensor parameter in tensorflow
assert len(param.out_shapes[0]
) == 1, "Unexpected situation of shape parameter"
attr = {"num_or_sections": param.out_shapes[0][0], "dim": 0}
node.fluid_code.add_layer("split",
inputs=param,
output=node,
param_attr=attr)
new_param = "["
for i in range(param.out_shapes[0][0]):
new_param += (node.layer_name + "[{}]".format(i) + ", ")
new_param = new_param.strip(", ") + "]"
attr = {"shape": new_param}
node.fluid_code.add_layer("reshape",
inputs=input,
output=node,
param_attr=attr)
def Add(self, node):
x = self.graph.get_node(node.layer.input[0], copy=True)
y = self.graph.get_node(node.layer.input[1], copy=True)
inputs = {"x": x, "y": y}
node.fluid_code.add_layer("elementwise_add",
inputs=inputs,
output=node,
param_attr=None)
def AvgPool(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
in_shape = input.out_shapes[0]
k_size = node.get_attr("ksize")
strides = node.get_attr("strides")
data_format = node.get_attr("data_format").decode()
pad_mode = node.get_attr("padding").decode()
channel_first = data_format == "NCHW"
if not channel_first:
attr = {"perm": [0, 3, 1, 2]}
node.fluid_code.add_layer("transpose",
inputs=input,
output=node,
param_attr=attr)
in_shape = [in_shape[i] for i in [0, 3, 1, 2]]
strides = [strides[i] for i in [0, 3, 1, 2]]
attr = {
"pool_size": k_size[1:3],
"pool_type": string("avg"),
"pool_stride": strides[2:4]
}
if pad_mode == "SAME":
pad_h = get_same_padding(in_shape[2], k_size[0], strides[2])
pad_w = get_same_padding(in_shape[3], k_size[1], strides[3])
assert pad_h[0] == pad_h[1] and pad_w[0] == pad_w[
1], "Cannot map AvgPool"
attr["pool_padding"] = [pad_h[0], pad_w[0]]
node.fluid_code.add_layer("pool2d",
inputs=input if channel_first else node,
output=node,
param_attr=attr)
if not channel_first:
attr = {"perm": [0, 2, 3, 1]}
node.fluid_code.add_layer("transpose",
inputs=node,
output=node,
param_attr=attr)
def Softmax(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
node.fluid_code.add_layer("softmax",
inputs=input,
output=node,
param_attr=None)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册