未验证 提交 1accd53d 编写于 作者: J Jason 提交者: GitHub

Merge pull request #426 from SunAhong1993/paddle-2.0

add tf static
...@@ -103,6 +103,7 @@ def tf2paddle(model_path, ...@@ -103,6 +103,7 @@ def tf2paddle(model_path,
save_dir, save_dir,
without_data_format_optimization=False, without_data_format_optimization=False,
define_input_shape=False, define_input_shape=False,
paddle_type="dygraph",
params_merge=False): params_merge=False):
# check tensorflow installation and version # check tensorflow installation and version
try: try:
...@@ -120,25 +121,28 @@ def tf2paddle(model_path, ...@@ -120,25 +121,28 @@ def tf2paddle(model_path,
"[ERROR] Tensorflow is not installed, use \"pip install tensorflow\"." "[ERROR] Tensorflow is not installed, use \"pip install tensorflow\"."
) )
return return
from x2paddle.decoder.tf_decoder import TFDecoder from x2paddle.decoder.tf_decoder import TFDecoder
from x2paddle.op_mapper.tf_op_mapper import TFOpMapper if paddle_type == "dygraph":
from x2paddle.op_mapper.dygraph.tf2paddle.tf_op_mapper import TFOpMapper
else:
from x2paddle.op_mapper.static.tf2paddle.tf_op_mapper import TFOpMapper
from x2paddle.optimizer.tensorflow.bias import BiasOpt from x2paddle.optimizer.tensorflow.bias import BiasOpt
from x2paddle.optimizer.tensorflow.transpose import TransposeOpt from x2paddle.optimizer.tensorflow.transpose import TransposeOpt
from x2paddle.optimizer.tensorflow.batch_norm import BatchNormOpt from x2paddle.optimizer.tensorflow.batch_norm import BatchNormOpt
print("Now translating model from tensorflow to paddle.") print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path, define_input_shape=define_input_shape) model = TFDecoder(model_path, define_input_shape=define_input_shape)
mapper = TFOpMapper(model) mapper = TFOpMapper(model)
program.build() mapper.paddle_graph.build()
bias_opt = BiasOpt() bias_opt = BiasOpt()
transpose_opt = TransposeOpt() transpose_opt = TransposeOpt()
batch_norm_opt = BatchNormOpt() batch_norm_opt = BatchNormOpt()
bias_opt.run(program) bias_opt.run(program)
batch_norm_opt.run(program) batch_norm_opt.run(program)
transpose_opt.run(program) transpose_opt.run(program)
program.gen_model(save_dir) mapper.paddle_graph.gen_model(save_dir)
def caffe2paddle(proto, weight, save_dir, caffe_proto, def caffe2paddle(proto, weight, save_dir, caffe_proto,
...@@ -293,7 +297,7 @@ def main(): ...@@ -293,7 +297,7 @@ def main():
if args.params_merge: if args.params_merge:
params_merge = True params_merge = True
tf2paddle(args.model, args.save_dir, without_data_format_optimization, tf2paddle(args.model, args.save_dir, without_data_format_optimization,
define_input_shape, params_merge) define_input_shape, args.paddle_type, params_merge)
elif args.framework == "caffe": elif args.framework == "caffe":
assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model" assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model"
......
...@@ -76,6 +76,7 @@ class PaddleGraph(object): ...@@ -76,6 +76,7 @@ class PaddleGraph(object):
self.custom_code = None self.custom_code = None
self.inputs_info = None self.inputs_info = None
def set_name(self, name): def set_name(self, name):
self.name = name.replace("-", "_").replace("/", "_") self.name = name.replace("-", "_").replace("/", "_")
...@@ -285,8 +286,8 @@ class PaddleGraph(object): ...@@ -285,8 +286,8 @@ class PaddleGraph(object):
for input_name in self.inputs: for input_name in self.inputs:
input_shapes.append(self.inputs_info[input_name][0]) input_shapes.append(self.inputs_info[input_name][0])
input_types.append(self.inputs_info[input_name][1]) input_types.append(self.inputs_info[input_name][1])
# 如果input_files非空,则导出推理模型;其值类似[[None, 3, 224, 224]] # 如果input_files非空,则导出推理模型;其值类似[[None, 3, 224, 224]]
self.dygraph2static(save_dir, input_shapes, input_types) self.dygraph2static(save_dir, input_shapes, input_types)
def gen_static_code(self, code_dir): def gen_static_code(self, code_dir):
def write_code(f, code_list, indent=0): def write_code(f, code_list, indent=0):
...@@ -446,6 +447,8 @@ class PaddleGraph(object): ...@@ -446,6 +447,8 @@ class PaddleGraph(object):
if self.source_type == "caffe": if self.source_type == "caffe":
custom_import = "from x2paddle.op_mapper.dygraph.caffe2paddle " + \ custom_import = "from x2paddle.op_mapper.dygraph.caffe2paddle " + \
"import caffe_custom_layer as x2paddle_nn" "import caffe_custom_layer as x2paddle_nn"
else:
custom_import = ""
self.head = gen_codes( self.head = gen_codes(
[ [
"from paddle.fluid.initializer import Constant", "from paddle.fluid.initializer import Constant",
...@@ -581,7 +584,10 @@ class PaddleGraph(object): ...@@ -581,7 +584,10 @@ class PaddleGraph(object):
line = ','.join(layer.outputs) line = ','.join(layer.outputs)
line += " = {}(".format(layer.kernel) line += " = {}(".format(layer.kernel)
for k, v in layer.inputs.items(): for k, v in layer.inputs.items():
line += "{}={}, ".format(k, v) if isinstance(v, list):
line += "{}=[{}], ".format(k, ", ".join(v))
else:
line += "{}={}, ".format(k, v)
for k, v in layer.attrs.items(): for k, v in layer.attrs.items():
line += "{}={}, ".format(k, v) line += "{}={}, ".format(k, v)
line = line.strip(", ") line = line.strip(", ")
...@@ -618,7 +624,10 @@ class PaddleGraph(object): ...@@ -618,7 +624,10 @@ class PaddleGraph(object):
paddle.disable_static() paddle.disable_static()
restore, _ = fluid.load_dygraph(osp.join(save_dir, "model")) restore, _ = fluid.load_dygraph(osp.join(save_dir, "model"))
model = getattr(x2paddle_code, self.name)() model = getattr(x2paddle_code, self.name)()
model.set_dict(restore) if self.source_type == "tf":
model.set_dict(restore, use_structured_name=False)
else:
model.set_dict(restore)
model.eval() model.eval()
static_model = paddle.jit.to_static(model, input_spec=sepc_list) static_model = paddle.jit.to_static(model, input_spec=sepc_list)
paddle.jit.save(static_model, osp.join(save_dir, "inference_model/model")) paddle.jit.save(static_model, osp.join(save_dir, "inference_model/model"))
\ No newline at end of file
...@@ -132,6 +132,7 @@ class TFGraph(Graph): ...@@ -132,6 +132,7 @@ class TFGraph(Graph):
self.identity_map = dict() self.identity_map = dict()
self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2'] self.multi_out_ops = ['Split', 'SplitV', 'IteratorV2']
self.tf_data_format = data_format self.tf_data_format = data_format
self.graph_name = "TFModel"
def build(self): def build(self):
for layer in self.model.node: for layer in self.model.node:
......
...@@ -277,7 +277,7 @@ def prim_list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif ...@@ -277,7 +277,7 @@ def prim_list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif
inputs_str = ', '.join(inputs_list) inputs_str = ', '.join(inputs_list)
line = "{} = [{}]".format(layer.outputs[0], inputs_str) line = "{} = [{}]".format(layer.outputs[0], inputs_str)
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_list_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_list_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
line = "{} = {}".format(", ".join(layer.outputs), get_value(layer, "input", different_attrs)) line = "{} = {}".format(", ".join(layer.outputs), get_value(layer, "input", different_attrs))
......
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册