提交 1fe404c5 编写于 作者: M mamingjie-China

add param_merge

上级 8259abe7
......@@ -68,6 +68,11 @@ def arg_parser():
action="store_true",
default=False,
help="define input shape for tf model")
parser.add_argument("--param_merge",
"-pm",
action="store_true",
default=False,
help="define whether merge the params")
return parser
......@@ -75,7 +80,8 @@ def arg_parser():
def tf2paddle(model_path,
save_dir,
without_data_format_optimization=False,
define_input_shape=False):
define_input_shape=False,
param_merge=False):
# check tensorflow installation and version
try:
import os
......@@ -121,10 +127,10 @@ def tf2paddle(model_path,
optimizer.merge_bias()
optimizer.make_nchw_input_output()
optimizer.remove_transpose()
mapper.save_inference_model(save_dir)
mapper.save_inference_model(save_dir, param_merge)
def caffe2paddle(proto, weight, save_dir, caffe_proto):
def caffe2paddle(proto, weight, save_dir, caffe_proto, param_merge=False):
from x2paddle.decoder.caffe_decoder import CaffeDecoder
from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
......@@ -141,10 +147,10 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto):
optimizer = CaffeOptimizer(mapper)
optimizer.merge_bn_scale()
optimizer.merge_op_activation()
mapper.save_inference_model(save_dir)
mapper.save_inference_model(save_dir, param_merge)
def onnx2paddle(model_path, save_dir):
def onnx2paddle(model_path, save_dir, param_merge=False):
# check onnx installation and version
try:
import onnx
......@@ -167,7 +173,7 @@ def onnx2paddle(model_path, save_dir):
optimizer = ONNXOptimizer(mapper)
optimizer.delete_redundance_code()
mapper.save_inference_model(save_dir)
mapper.save_inference_model(save_dir, param_merge)
def main():
......@@ -202,20 +208,29 @@ def main():
assert args.model is not None, "--model should be defined while translating tensorflow model"
without_data_format_optimization = False
define_input_shape = False
param_merge = False
if args.without_data_format_optimization:
without_data_format_optimization = True
if args.define_input_shape:
define_input_shape = True
if args.param_merge:
param_merge = True
tf2paddle(args.model, args.save_dir, without_data_format_optimization,
define_input_shape)
define_input_shape, param_merge)
elif args.framework == "caffe":
assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model"
param_merge = False
if args.param_merge:
param_merge = True
caffe2paddle(args.prototxt, args.weight, args.save_dir,
args.caffe_proto)
args.caffe_proto, param_merge)
elif args.framework == "onnx":
assert args.model is not None, "--model should be defined while translating onnx model"
onnx2paddle(args.model, args.save_dir)
param_merge = False
if args.param_merge:
param_merge = True
onnx2paddle(args.model, args.save_dir, param_merge)
else:
raise Exception("--framework only support tensorflow/caffe/onnx now")
......
......@@ -110,7 +110,7 @@ class OpMapper(object):
self.add_codes("import paddle.fluid as fluid")
self.add_codes("")
def save_inference_model(self, save_dir):
def save_inference_model(self, save_dir, param_merge):
self.save_python_model(save_dir)
import sys
......@@ -138,13 +138,20 @@ class OpMapper(object):
py_code_dir,
fluid.default_main_program(),
predicate=if_exist)
fluid.io.save_inference_model(dirname=os.path.join(
save_dir, "inference_model"),
feeded_var_names=input_names,
target_vars=outputs,
executor=exe,
params_filename=None)
if param_merge:
fluid.io.save_inference_model(dirname=os.path.join(
save_dir, "inference_model"),
feeded_var_names=input_names,
target_vars=outputs,
executor=exe,
params_filename="__params__")
else:
fluid.io.save_inference_model(dirname=os.path.join(
save_dir, "inference_model"),
feeded_var_names=input_names,
target_vars=outputs,
executor=exe,
params_filename=None)
except:
raise Exception(
"Paddle code was saved in {}/model.py, but seems there's wrong exist, please check model.py manually."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册