提交 69911b60 编写于 作者: J jiangjiajun

structure modify

上级 9e3ab110
......@@ -46,25 +46,25 @@ def arg_parser():
return parser
def tf2paddle(model, save_dir):
def tf2paddle(model_path, save_dir):
print("Now translating model from tensorflow to paddle.")
from x2paddle.parser.tf_parser import TFParser
from x2paddle.decoder.tf_decoder import TFDecoder
from x2paddle.optimizer.tf_optimizer import TFGraphOptimizer
from x2paddle.emitter.tf_emitter import TFEmitter
parser = TFParser(model)
emitter = TFEmitter(parser)
emitter.run()
emitter.save_python_model(save_dir)
from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
model = TFDecoder(model_path)
mapper = TFOpMapper(model)
mapper.run()
mapper.save_python_model(save_dir)
def caffe2paddle(proto, weight, save_dir):
print("Now translating model from caffe to paddle.")
from x2paddle.parser.caffe_parser import CaffeParser
from x2paddle.emitter.caffe_emitter import CaffeEmitter
parser = CaffeParser(proto, weight)
emitter = CaffeEmitter(parser)
emitter.run()
emitter.save_python_model(save_dir)
from x2paddle.decoder.caffe_decoder import CaffeDecoder
from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
model = CaffeDecoder(proto, weight)
mapper = CaffeOpMapper(model)
mapper.run()
mapper.save_python_model(save_dir)
def main():
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.core.util import *
import os
class Emitter(object):
def __init__(self):
self.paddle_codes = ""
self.tab = " "
self.net_code = list()
self.weights = dict()
def add_codes(self, codes, indent=0):
if isinstance(codes, list):
for code in codes:
self.paddle_codes += (self.tab * indent + code + '\n')
elif isinstance(codes, str):
self.paddle_codes += (self.tab * indent + codes + '\n')
else:
raise Exception("Unknown type of codes")
def add_heads(self):
self.add_codes("from paddle.fluid.initializer import Constant")
self.add_codes("from paddle.fluid.param_attr import ParamAttr")
self.add_codes("import paddle.fluid as fluid")
self.add_codes("")
def save_inference_model(self):
print("Not Implement")
def save_python_model(self, save_dir):
for name, param in self.weights.items():
export_paddle_param(param, name, save_dir)
self.add_heads()
self.add_codes(self.net_code)
fp = open(os.path.join(save_dir, "model.py"), 'w')
fp.write(self.paddle_codes)
fp.close()
......@@ -18,7 +18,7 @@ from google.protobuf import text_format
import numpy as np
from x2paddle.core.graph import GraphNode, Graph
from x2paddle.core.fluid_code import FluidCode
from x2paddle.parser import caffe_shape
from x2paddle.decoder import caffe_shape
class CaffeResolver(object):
......@@ -188,7 +188,7 @@ class CaffeGraph(Graph):
return self.get_node(name, copy=copy)
class CaffeParser(object):
class CaffeDecoder(object):
def __init__(self, proto_path, model_path, use_caffe=True):
self.proto_path = proto_path
self.model_path = model_path
......
......@@ -153,14 +153,8 @@ class TFGraph(Graph):
del self.topo_sort[idx]
class TFParser(object):
def __init__(self, pb_model, in_nodes=None, out_nodes=None, in_shapes=None):
# assert in_nodes is not None, "in_nodes should not be None"
# assert out_nodes is not None, "out_nodes should not be None"
# assert in_shapes is not None, "in_shapes should not be None"
# assert len(in_shapes) == len(
# in_nodes), "length of in_shapes and in_nodes should be equal"
class TFDecoder(object):
def __init__(self, pb_model):
sess = tf.Session()
with gfile.FastGFile(pb_model, 'rb') as f:
graph_def = tf.GraphDef()
......
......@@ -13,18 +13,17 @@
# limitations under the License.
import numbers
from x2paddle.parser.caffe_parser import CaffeGraph
from x2paddle.core.emitter import Emitter
from x2paddle.decoder.caffe_decoder import CaffeGraph
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import *
class CaffeEmitter(Emitter):
def __init__(self, parser):
super(CaffeEmitter, self).__init__()
self.parser = parser
self.graph = parser.caffe_graph
class CaffeOpMapper(OpMapper):
def __init__(self, decoder):
super(CaffeOpMapper, self).__init__()
self.graph = decoder.caffe_graph
self.weights = dict()
resolver = parser.resolver
resolver = decoder.resolver
if resolver.has_pycaffe():
self.did_use_pb = False
else:
......@@ -36,8 +35,8 @@ class CaffeEmitter(Emitter):
node = self.graph.get_node(node_name)
op = node.layer_type
if hasattr(self, op):
emit_func = getattr(self, op)
emit_func(node)
func = getattr(self, op)
func(node)
for i in range(len(self.graph.topo_sort)):
node_name = self.graph.topo_sort[i]
......
......@@ -12,21 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.parser.tf_parser import TFGraph
from x2paddle.core.emitter import Emitter
from x2paddle.core.fluid_code import FluidCode
from x2paddle.decoder.tf_decoder import TFGraph
from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import *
import numpy
class TFEmitter(Emitter):
def __init__(self, parser):
super(TFEmitter, self).__init__()
self.parser = parser
self.graph = parser.tf_graph
# attr_node is used to record nodes that
# only for define attribute of op
self.attr_node = list()
class TFOpMapper(OpMapper):
def __init__(self, decoder):
super(TFOpMapper, self).__init__()
self.graph = decoder.tf_graph
self.weights = dict()
self.omit_nodes = list()
def run(self):
......@@ -35,8 +31,8 @@ class TFEmitter(Emitter):
node = self.graph.get_node(node_name)
op = node.layer_type
if hasattr(self, op):
emit_func = getattr(self, op)
emit_func(node)
func = getattr(self, op)
func(node)
for i in range(len(self.graph.topo_sort)):
node_name = self.graph.topo_sort[i]
......
......@@ -13,7 +13,7 @@
# limitations under the License.
# TODO useless node remove
from x2paddle.parser.tf_parser import TFGraph
from x2paddle.decoder.tf_decoder import TFGraph
class TFGraphOptimizer(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册