提交 1fe404c5 编写于 作者: M mamingjie-China

add param_merge

上级 8259abe7
...@@ -68,6 +68,11 @@ def arg_parser(): ...@@ -68,6 +68,11 @@ def arg_parser():
action="store_true", action="store_true",
default=False, default=False,
help="define input shape for tf model") help="define input shape for tf model")
parser.add_argument("--param_merge",
"-pm",
action="store_true",
default=False,
help="define whether merge the params")
return parser return parser
...@@ -75,7 +80,8 @@ def arg_parser(): ...@@ -75,7 +80,8 @@ def arg_parser():
def tf2paddle(model_path, def tf2paddle(model_path,
save_dir, save_dir,
without_data_format_optimization=False, without_data_format_optimization=False,
define_input_shape=False): define_input_shape=False,
param_merge=False):
# check tensorflow installation and version # check tensorflow installation and version
try: try:
import os import os
...@@ -121,10 +127,10 @@ def tf2paddle(model_path, ...@@ -121,10 +127,10 @@ def tf2paddle(model_path,
optimizer.merge_bias() optimizer.merge_bias()
optimizer.make_nchw_input_output() optimizer.make_nchw_input_output()
optimizer.remove_transpose() optimizer.remove_transpose()
mapper.save_inference_model(save_dir) mapper.save_inference_model(save_dir, param_merge)
def caffe2paddle(proto, weight, save_dir, caffe_proto): def caffe2paddle(proto, weight, save_dir, caffe_proto, param_merge=False):
from x2paddle.decoder.caffe_decoder import CaffeDecoder from x2paddle.decoder.caffe_decoder import CaffeDecoder
from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
...@@ -141,10 +147,10 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto): ...@@ -141,10 +147,10 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto):
optimizer = CaffeOptimizer(mapper) optimizer = CaffeOptimizer(mapper)
optimizer.merge_bn_scale() optimizer.merge_bn_scale()
optimizer.merge_op_activation() optimizer.merge_op_activation()
mapper.save_inference_model(save_dir) mapper.save_inference_model(save_dir, param_merge)
def onnx2paddle(model_path, save_dir): def onnx2paddle(model_path, save_dir, param_merge=False):
# check onnx installation and version # check onnx installation and version
try: try:
import onnx import onnx
...@@ -167,7 +173,7 @@ def onnx2paddle(model_path, save_dir): ...@@ -167,7 +173,7 @@ def onnx2paddle(model_path, save_dir):
optimizer = ONNXOptimizer(mapper) optimizer = ONNXOptimizer(mapper)
optimizer.delete_redundance_code() optimizer.delete_redundance_code()
mapper.save_inference_model(save_dir) mapper.save_inference_model(save_dir, param_merge)
def main(): def main():
...@@ -202,20 +208,29 @@ def main(): ...@@ -202,20 +208,29 @@ def main():
assert args.model is not None, "--model should be defined while translating tensorflow model" assert args.model is not None, "--model should be defined while translating tensorflow model"
without_data_format_optimization = False without_data_format_optimization = False
define_input_shape = False define_input_shape = False
param_merge = False
if args.without_data_format_optimization: if args.without_data_format_optimization:
without_data_format_optimization = True without_data_format_optimization = True
if args.define_input_shape: if args.define_input_shape:
define_input_shape = True define_input_shape = True
if args.param_merge:
param_merge = True
tf2paddle(args.model, args.save_dir, without_data_format_optimization, tf2paddle(args.model, args.save_dir, without_data_format_optimization,
define_input_shape) define_input_shape, param_merge)
elif args.framework == "caffe": elif args.framework == "caffe":
assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model" assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model"
param_merge = False
if args.param_merge:
param_merge = True
caffe2paddle(args.prototxt, args.weight, args.save_dir, caffe2paddle(args.prototxt, args.weight, args.save_dir,
args.caffe_proto) args.caffe_proto, param_merge)
elif args.framework == "onnx": elif args.framework == "onnx":
assert args.model is not None, "--model should be defined while translating onnx model" assert args.model is not None, "--model should be defined while translating onnx model"
onnx2paddle(args.model, args.save_dir) param_merge = False
if args.param_merge:
param_merge = True
onnx2paddle(args.model, args.save_dir, param_merge)
else: else:
raise Exception("--framework only support tensorflow/caffe/onnx now") raise Exception("--framework only support tensorflow/caffe/onnx now")
......
...@@ -110,7 +110,7 @@ class OpMapper(object): ...@@ -110,7 +110,7 @@ class OpMapper(object):
self.add_codes("import paddle.fluid as fluid") self.add_codes("import paddle.fluid as fluid")
self.add_codes("") self.add_codes("")
def save_inference_model(self, save_dir): def save_inference_model(self, save_dir, param_merge):
self.save_python_model(save_dir) self.save_python_model(save_dir)
import sys import sys
...@@ -138,13 +138,20 @@ class OpMapper(object): ...@@ -138,13 +138,20 @@ class OpMapper(object):
py_code_dir, py_code_dir,
fluid.default_main_program(), fluid.default_main_program(),
predicate=if_exist) predicate=if_exist)
if param_merge:
fluid.io.save_inference_model(dirname=os.path.join( fluid.io.save_inference_model(dirname=os.path.join(
save_dir, "inference_model"), save_dir, "inference_model"),
feeded_var_names=input_names, feeded_var_names=input_names,
target_vars=outputs, target_vars=outputs,
executor=exe, executor=exe,
params_filename=None) params_filename="__params__")
else:
fluid.io.save_inference_model(dirname=os.path.join(
save_dir, "inference_model"),
feeded_var_names=input_names,
target_vars=outputs,
executor=exe,
params_filename=None)
except: except:
raise Exception( raise Exception(
"Paddle code was saved in {}/model.py, but seems there's wrong exist, please check model.py manually." "Paddle code was saved in {}/model.py, but seems there's wrong exist, please check model.py manually."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册