提交 ea91537b 编写于 作者: M mamingjie-China

add params merge

上级 1b2b70f7
...@@ -68,7 +68,7 @@ def arg_parser(): ...@@ -68,7 +68,7 @@ def arg_parser():
action="store_true", action="store_true",
default=False, default=False,
help="define input shape for tf model") help="define input shape for tf model")
parser.add_argument("--param_merge", parser.add_argument("--params_merge",
"-pm", "-pm",
action="store_true", action="store_true",
default=False, default=False,
...@@ -81,7 +81,7 @@ def tf2paddle(model_path, ...@@ -81,7 +81,7 @@ def tf2paddle(model_path,
save_dir, save_dir,
without_data_format_optimization=False, without_data_format_optimization=False,
define_input_shape=False, define_input_shape=False,
param_merge=False): params_merge=False):
# check tensorflow installation and version # check tensorflow installation and version
try: try:
import os import os
...@@ -127,10 +127,10 @@ def tf2paddle(model_path, ...@@ -127,10 +127,10 @@ def tf2paddle(model_path,
optimizer.merge_bias() optimizer.merge_bias()
optimizer.make_nchw_input_output() optimizer.make_nchw_input_output()
optimizer.remove_transpose() optimizer.remove_transpose()
mapper.save_inference_model(save_dir, param_merge) mapper.save_inference_model(save_dir, params_merge)
def caffe2paddle(proto, weight, save_dir, caffe_proto, param_merge=False): def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
from x2paddle.decoder.caffe_decoder import CaffeDecoder from x2paddle.decoder.caffe_decoder import CaffeDecoder
from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
...@@ -147,10 +147,10 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto, param_merge=False): ...@@ -147,10 +147,10 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto, param_merge=False):
optimizer = CaffeOptimizer(mapper) optimizer = CaffeOptimizer(mapper)
optimizer.merge_bn_scale() optimizer.merge_bn_scale()
optimizer.merge_op_activation() optimizer.merge_op_activation()
mapper.save_inference_model(save_dir, param_merge) mapper.save_inference_model(save_dir, params_merge)
def onnx2paddle(model_path, save_dir, param_merge=False): def onnx2paddle(model_path, save_dir, params_merge=False):
# check onnx installation and version # check onnx installation and version
try: try:
import onnx import onnx
...@@ -173,7 +173,7 @@ def onnx2paddle(model_path, save_dir, param_merge=False): ...@@ -173,7 +173,7 @@ def onnx2paddle(model_path, save_dir, param_merge=False):
optimizer = ONNXOptimizer(mapper) optimizer = ONNXOptimizer(mapper)
optimizer.delete_redundance_code() optimizer.delete_redundance_code()
mapper.save_inference_model(save_dir, param_merge) mapper.save_inference_model(save_dir, params_merge)
def main(): def main():
...@@ -208,29 +208,29 @@ def main(): ...@@ -208,29 +208,29 @@ def main():
assert args.model is not None, "--model should be defined while translating tensorflow model" assert args.model is not None, "--model should be defined while translating tensorflow model"
without_data_format_optimization = False without_data_format_optimization = False
define_input_shape = False define_input_shape = False
param_merge = False params_merge = False
if args.without_data_format_optimization: if args.without_data_format_optimization:
without_data_format_optimization = True without_data_format_optimization = True
if args.define_input_shape: if args.define_input_shape:
define_input_shape = True define_input_shape = True
if args.param_merge: if args.params_merge:
param_merge = True params_merge = True
tf2paddle(args.model, args.save_dir, without_data_format_optimization, tf2paddle(args.model, args.save_dir, without_data_format_optimization,
define_input_shape, param_merge) define_input_shape, params_merge)
elif args.framework == "caffe": elif args.framework == "caffe":
assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model" assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model"
param_merge = False params_merge = False
if args.param_merge: if args.params_merge:
param_merge = True params_merge = True
caffe2paddle(args.prototxt, args.weight, args.save_dir, caffe2paddle(args.prototxt, args.weight, args.save_dir,
args.caffe_proto, param_merge) args.caffe_proto, params_merge)
elif args.framework == "onnx": elif args.framework == "onnx":
assert args.model is not None, "--model should be defined while translating onnx model" assert args.model is not None, "--model should be defined while translating onnx model"
param_merge = False params_merge = False
if args.param_merge: if args.params_merge:
param_merge = True params_merge = True
onnx2paddle(args.model, args.save_dir, param_merge) onnx2paddle(args.model, args.save_dir, params_merge)
else: else:
raise Exception("--framework only support tensorflow/caffe/onnx now") raise Exception("--framework only support tensorflow/caffe/onnx now")
......
...@@ -110,7 +110,7 @@ class OpMapper(object): ...@@ -110,7 +110,7 @@ class OpMapper(object):
self.add_codes("import paddle.fluid as fluid") self.add_codes("import paddle.fluid as fluid")
self.add_codes("") self.add_codes("")
def save_inference_model(self, save_dir, param_merge): def save_inference_model(self, save_dir, params_merge):
self.save_python_model(save_dir) self.save_python_model(save_dir)
import sys import sys
...@@ -138,7 +138,7 @@ class OpMapper(object): ...@@ -138,7 +138,7 @@ class OpMapper(object):
py_code_dir, py_code_dir,
fluid.default_main_program(), fluid.default_main_program(),
predicate=if_exist) predicate=if_exist)
if param_merge: if params_merge:
fluid.io.save_inference_model(dirname=os.path.join( fluid.io.save_inference_model(dirname=os.path.join(
save_dir, "inference_model"), save_dir, "inference_model"),
feeded_var_names=input_names, feeded_var_names=input_names,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册