提交 69911b60 编写于 作者: J jiangjiajun

structure modify

上级 9e3ab110
...@@ -46,25 +46,25 @@ def arg_parser(): ...@@ -46,25 +46,25 @@ def arg_parser():
return parser return parser
def tf2paddle(model, save_dir): def tf2paddle(model_path, save_dir):
print("Now translating model from tensorflow to paddle.") print("Now translating model from tensorflow to paddle.")
from x2paddle.parser.tf_parser import TFParser from x2paddle.decoder.tf_decoder import TFDecoder
from x2paddle.optimizer.tf_optimizer import TFGraphOptimizer from x2paddle.optimizer.tf_optimizer import TFGraphOptimizer
from x2paddle.emitter.tf_emitter import TFEmitter from x2paddle.op_mapper.tf_op_mapper import TFOpMapper
parser = TFParser(model) model = TFDecoder(model_path)
emitter = TFEmitter(parser) mapper = TFOpMapper(model)
emitter.run() mapper.run()
emitter.save_python_model(save_dir) mapper.save_python_model(save_dir)
def caffe2paddle(proto, weight, save_dir): def caffe2paddle(proto, weight, save_dir):
print("Now translating model from caffe to paddle.") print("Now translating model from caffe to paddle.")
from x2paddle.parser.caffe_parser import CaffeParser from x2paddle.decoder.caffe_decoder import CaffeDecoder
from x2paddle.emitter.caffe_emitter import CaffeEmitter from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
parser = CaffeParser(proto, weight) model = CaffeDecoder(proto, weight)
emitter = CaffeEmitter(parser) mapper = CaffeOpMapper(model)
emitter.run() mapper.run()
emitter.save_python_model(save_dir) mapper.save_python_model(save_dir)
def main(): def main():
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from x2paddle.core.util import *
import os
class Emitter(object):
def __init__(self):
self.paddle_codes = ""
self.tab = " "
self.net_code = list()
self.weights = dict()
def add_codes(self, codes, indent=0):
if isinstance(codes, list):
for code in codes:
self.paddle_codes += (self.tab * indent + code + '\n')
elif isinstance(codes, str):
self.paddle_codes += (self.tab * indent + codes + '\n')
else:
raise Exception("Unknown type of codes")
def add_heads(self):
self.add_codes("from paddle.fluid.initializer import Constant")
self.add_codes("from paddle.fluid.param_attr import ParamAttr")
self.add_codes("import paddle.fluid as fluid")
self.add_codes("")
def save_inference_model(self):
print("Not Implement")
def save_python_model(self, save_dir):
for name, param in self.weights.items():
export_paddle_param(param, name, save_dir)
self.add_heads()
self.add_codes(self.net_code)
fp = open(os.path.join(save_dir, "model.py"), 'w')
fp.write(self.paddle_codes)
fp.close()
...@@ -18,7 +18,7 @@ from google.protobuf import text_format ...@@ -18,7 +18,7 @@ from google.protobuf import text_format
import numpy as np import numpy as np
from x2paddle.core.graph import GraphNode, Graph from x2paddle.core.graph import GraphNode, Graph
from x2paddle.core.fluid_code import FluidCode from x2paddle.core.fluid_code import FluidCode
from x2paddle.parser import caffe_shape from x2paddle.decoder import caffe_shape
class CaffeResolver(object): class CaffeResolver(object):
...@@ -188,7 +188,7 @@ class CaffeGraph(Graph): ...@@ -188,7 +188,7 @@ class CaffeGraph(Graph):
return self.get_node(name, copy=copy) return self.get_node(name, copy=copy)
class CaffeParser(object): class CaffeDecoder(object):
def __init__(self, proto_path, model_path, use_caffe=True): def __init__(self, proto_path, model_path, use_caffe=True):
self.proto_path = proto_path self.proto_path = proto_path
self.model_path = model_path self.model_path = model_path
......
...@@ -153,14 +153,8 @@ class TFGraph(Graph): ...@@ -153,14 +153,8 @@ class TFGraph(Graph):
del self.topo_sort[idx] del self.topo_sort[idx]
class TFParser(object): class TFDecoder(object):
def __init__(self, pb_model, in_nodes=None, out_nodes=None, in_shapes=None): def __init__(self, pb_model):
# assert in_nodes is not None, "in_nodes should not be None"
# assert out_nodes is not None, "out_nodes should not be None"
# assert in_shapes is not None, "in_shapes should not be None"
# assert len(in_shapes) == len(
# in_nodes), "length of in_shapes and in_nodes should be equal"
sess = tf.Session() sess = tf.Session()
with gfile.FastGFile(pb_model, 'rb') as f: with gfile.FastGFile(pb_model, 'rb') as f:
graph_def = tf.GraphDef() graph_def = tf.GraphDef()
......
...@@ -13,18 +13,17 @@ ...@@ -13,18 +13,17 @@
# limitations under the License. # limitations under the License.
import numbers import numbers
from x2paddle.parser.caffe_parser import CaffeGraph from x2paddle.decoder.caffe_decoder import CaffeGraph
from x2paddle.core.emitter import Emitter from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.util import * from x2paddle.core.util import *
class CaffeEmitter(Emitter): class CaffeOpMapper(OpMapper):
def __init__(self, parser): def __init__(self, decoder):
super(CaffeEmitter, self).__init__() super(CaffeOpMapper, self).__init__()
self.parser = parser self.graph = decoder.caffe_graph
self.graph = parser.caffe_graph
self.weights = dict() self.weights = dict()
resolver = parser.resolver resolver = decoder.resolver
if resolver.has_pycaffe(): if resolver.has_pycaffe():
self.did_use_pb = False self.did_use_pb = False
else: else:
...@@ -36,8 +35,8 @@ class CaffeEmitter(Emitter): ...@@ -36,8 +35,8 @@ class CaffeEmitter(Emitter):
node = self.graph.get_node(node_name) node = self.graph.get_node(node_name)
op = node.layer_type op = node.layer_type
if hasattr(self, op): if hasattr(self, op):
emit_func = getattr(self, op) func = getattr(self, op)
emit_func(node) func(node)
for i in range(len(self.graph.topo_sort)): for i in range(len(self.graph.topo_sort)):
node_name = self.graph.topo_sort[i] node_name = self.graph.topo_sort[i]
......
...@@ -12,21 +12,17 @@ ...@@ -12,21 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from x2paddle.parser.tf_parser import TFGraph from x2paddle.decoder.tf_decoder import TFGraph
from x2paddle.core.emitter import Emitter from x2paddle.core.op_mapper import OpMapper
from x2paddle.core.fluid_code import FluidCode
from x2paddle.core.util import * from x2paddle.core.util import *
import numpy import numpy
class TFEmitter(Emitter): class TFOpMapper(OpMapper):
def __init__(self, parser): def __init__(self, decoder):
super(TFEmitter, self).__init__() super(TFOpMapper, self).__init__()
self.parser = parser self.graph = decoder.tf_graph
self.graph = parser.tf_graph self.weights = dict()
# attr_node is used to record nodes that
# only for define attribute of op
self.attr_node = list()
self.omit_nodes = list() self.omit_nodes = list()
def run(self): def run(self):
...@@ -35,8 +31,8 @@ class TFEmitter(Emitter): ...@@ -35,8 +31,8 @@ class TFEmitter(Emitter):
node = self.graph.get_node(node_name) node = self.graph.get_node(node_name)
op = node.layer_type op = node.layer_type
if hasattr(self, op): if hasattr(self, op):
emit_func = getattr(self, op) func = getattr(self, op)
emit_func(node) func(node)
for i in range(len(self.graph.topo_sort)): for i in range(len(self.graph.topo_sort)):
node_name = self.graph.topo_sort[i] node_name = self.graph.topo_sort[i]
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# TODO useless node remove # TODO useless node remove
from x2paddle.parser.tf_parser import TFGraph from x2paddle.decoder.tf_decoder import TFGraph
class TFGraphOptimizer(object): class TFGraphOptimizer(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册