提交 99582ade 编写于 作者: J jiangjiajun

common code generate and weight dump

上级 1f79d43d
......@@ -14,14 +14,70 @@
from x2paddle.parser.tf_parser import TFParser
from x2paddle.optimizer.tf_optimizer import TFGraphOptimizer
from x2paddle.emitter.tf_emitter import TFEmitter
from six import text_type as _text_type
import argparse
parser = TFParser('/ssd2/Jason/github/X2Paddle/tool/vgg16_None.pb',
in_nodes=['inputs'],
out_nodes=['output_boxes'],
in_shapes=[[-1, 416, 416, 3]])
optimizer = TFGraphOptimizer()
#parser.tf_graph.print()
def arg_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--model",
"-m",
type=_text_type,
default=None,
help="model file path")
parser.add_argument("--proto",
"-p",
type=_text_type,
default=None,
help="proto file of caffe model")
parser.add_argument("--weight",
"-w",
type=_text_type,
default=None,
help="weight file of caffe model")
parser.add_argument("--save_dir",
"-s",
type=_text_type,
default=None,
help="path to save translated model")
parser.add_argument("--framework",
"-f",
type=_text_type,
default=None,
help="define which deeplearning framework")
return parser
emitter = TFEmitter(parser)
emitter.run()
def tf2paddle(model, save_dir):
print("Now translating model from tensorflow to paddle.")
parser = TFParser(model)
emitter = TFEmitter(parser)
emitter.run()
emitter.save_python_model(save_dir)
def caffe2paddle(proto, weight, save_dir):
print("Not implement yet.")
def main():
parser = arg_parser()
args = parser.parse_args()
assert args.framework is not None, "--from is not defined(tensorflow/caffe)"
assert args.save_dir is not None, "--save_dir is not defined"
if args.framework == "tensorflow":
assert args.model is not None, "--model should be defined while translate tensorflow model"
tf2paddle(args.model, args.save_dir)
elif args.framework == "caffe":
assert args.proto is not None, "--proto and --weight should be defined while translate caffe model"
caffe2paddle(args.proto, args.weight, args.save_dir)
else:
raise Exception("--framework only support tensorflow/caffe now")
if __name__ == "__main__":
main()
......@@ -12,11 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.core.util import *
import os
class Emitter(object):
def __init__(self):
self.paddle_codes = ""
self.tab = " "
self.net_code = list()
self.weights = dict()
def add_codes(self, codes, indent=0):
if isinstance(codes, list):
......@@ -28,11 +33,19 @@ class Emitter(object):
raise Exception("Unknown type of codes")
def add_heads(self):
self.add_codes("from paddle.fluid.initializer import Constant")
self.add_codes("from paddle.fluid.param_attr import ParamAttr")
self.add_codes("import paddle.fluid as fluid")
self.add_codes("")
def save_inference_model(self):
print("Not Implement")
def save_python_code(self):
print("Not Implement")
def save_python_model(self, save_dir):
for name, param in self.weights.items():
export_paddle_param(param, name, save_dir)
self.add_heads()
self.add_codes(self.net_code)
fp = open(os.path.join(save_dir, "model.py"), 'w')
fp.write(self.paddle_codes)
fp.close()
......@@ -100,3 +100,4 @@ class FluidCode(object):
codes.append(layer.get_code())
elif isinstance(layer, str):
codes.append(layer)
return codes
......@@ -44,7 +44,6 @@ def export_paddle_param(param, param_name, dir):
if len(shape) == 0:
assert param.size == 1, "Unexpected situation happend!"
shape = [1]
print("param dtype:", param.dtype)
assert str(param.dtype) in dtype_map, "Unknown dtype of params."
fp = open(os.path.join(dir, param_name), 'wb')
......
......@@ -28,7 +28,6 @@ class TFEmitter(Emitter):
# only for define attribute of op
self.attr_node = list()
self.omit_nodes = list()
self.weights = dict()
def run(self):
print("Total nodes: {}".format(len(self.graph.topo_sort)))
......@@ -44,13 +43,7 @@ class TFEmitter(Emitter):
if node_name in self.omit_nodes:
continue
node = self.graph.get_node(node_name)
for layer in node.fluid_code.layers:
print(layer.get_code())
for name, param in self.weights.items():
node = self.graph.get_node(name)
export_paddle_param(param, node.layer_name.replace('/', '_'),
"params1")
self.net_code += node.fluid_code.gen_codes()
def Placeholder(self, node):
shape = node.out_shapes[0]
......@@ -85,13 +78,13 @@ class TFEmitter(Emitter):
inputs=None,
output=node,
param_attr=attr)
self.weights[node.layer_name] = node.value
self.weights[node.layer_name.replace('/', '_')] = node.value
def Transpose(self, node):
input = self.graph.get_node(node.layer.input[0], copy=True)
perm = self.graph.get_node(node.layer.input[1], copy=True)
assert perm.layer_type == "Const", "Perm of transpose OP should be Const"
del self.weights[perm.layer_name]
del self.weights[perm.layer_name.replace('/', '_')]
perm.fluid_code.clear()
perm = perm.value.tolist()
......@@ -204,7 +197,7 @@ class TFEmitter(Emitter):
channel_first = data_format == "NCHW"
if not channel_first:
self.weights[kernel.layer_name] = numpy.transpose(
self.weights[kernel.layer_name.replace('/', '_')] = numpy.transpose(
kernel.value, (3, 2, 0, 1))
attr = {"perm": [0, 3, 1, 2]}
node.fluid_code.add_layer("transpose",
......
......@@ -24,9 +24,11 @@ import copy
class TFGraphNode(GraphNode):
def __init__(self, layer, layer_name=None):
if layer_name is None:
super(TFGraphNode, self).__init__(layer, layer.name)
super(TFGraphNode, self).__init__(layer,
layer.name.replace('/', '_'))
else:
super(TFGraphNode, self).__init__(layer, layer_name)
super(TFGraphNode, self).__init__(layer,
layer_name.replace('/', '_'))
self.layer_type = layer.op
self.fluid_code = FluidCode()
......@@ -86,10 +88,11 @@ class TFGraph(Graph):
def build(self):
for layer in self.model.node:
self.node_map[layer.name] = TFGraphNode(layer)
self.node_map[layer.name.replace('/', '_')] = TFGraphNode(layer)
for layer_name, node in self.node_map.items():
for in_node in node.layer.input:
in_node = in_node.replace('/', '_')
if in_node not in self.node_map:
if in_node.strip().split(':')[0] in self.node_map:
self.connect(in_node.strip().split(':')[0], layer_name)
......@@ -108,6 +111,7 @@ class TFGraph(Graph):
def get_node(self, node_name, copy=False):
items = node_name.strip().split(':')
items[0] = items[0].replace('/', '_')
if items[0] in self.identity_map:
items[0] = self.identity_map[items[0]]
new_node_name = ":".join(items)
......@@ -151,11 +155,11 @@ class TFGraph(Graph):
class TFParser(object):
def __init__(self, pb_model, in_nodes=None, out_nodes=None, in_shapes=None):
assert in_nodes is not None, "in_nodes should not be None"
assert out_nodes is not None, "out_nodes should not be None"
assert in_shapes is not None, "in_shapes should not be None"
assert len(in_shapes) == len(
in_nodes), "length of in_shapes and in_nodes should be equal"
# assert in_nodes is not None, "in_nodes should not be None"
# assert out_nodes is not None, "out_nodes should not be None"
# assert in_shapes is not None, "in_shapes should not be None"
# assert len(in_shapes) == len(
# in_nodes), "length of in_shapes and in_nodes should be equal"
sess = tf.Session()
with gfile.FastGFile(pb_model, 'rb') as f:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册