提交 ea91537b 编写于 作者: M mamingjie-China

add params merge

上级 1b2b70f7
......@@ -68,7 +68,7 @@ def arg_parser():
action="store_true",
default=False,
help="define input shape for tf model")
parser.add_argument("--param_merge",
parser.add_argument("--params_merge",
"-pm",
action="store_true",
default=False,
......@@ -81,7 +81,7 @@ def tf2paddle(model_path,
save_dir,
without_data_format_optimization=False,
define_input_shape=False,
param_merge=False):
params_merge=False):
# check tensorflow installation and version
try:
import os
......@@ -127,10 +127,10 @@ def tf2paddle(model_path,
optimizer.merge_bias()
optimizer.make_nchw_input_output()
optimizer.remove_transpose()
mapper.save_inference_model(save_dir, param_merge)
mapper.save_inference_model(save_dir, params_merge)
def caffe2paddle(proto, weight, save_dir, caffe_proto, param_merge=False):
def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
from x2paddle.decoder.caffe_decoder import CaffeDecoder
from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
......@@ -147,10 +147,10 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto, param_merge=False):
optimizer = CaffeOptimizer(mapper)
optimizer.merge_bn_scale()
optimizer.merge_op_activation()
mapper.save_inference_model(save_dir, param_merge)
mapper.save_inference_model(save_dir, params_merge)
def onnx2paddle(model_path, save_dir, param_merge=False):
def onnx2paddle(model_path, save_dir, params_merge=False):
# check onnx installation and version
try:
import onnx
......@@ -173,7 +173,7 @@ def onnx2paddle(model_path, save_dir, param_merge=False):
optimizer = ONNXOptimizer(mapper)
optimizer.delete_redundance_code()
mapper.save_inference_model(save_dir, param_merge)
mapper.save_inference_model(save_dir, params_merge)
def main():
......@@ -208,29 +208,29 @@ def main():
assert args.model is not None, "--model should be defined while translating tensorflow model"
without_data_format_optimization = False
define_input_shape = False
param_merge = False
params_merge = False
if args.without_data_format_optimization:
without_data_format_optimization = True
if args.define_input_shape:
define_input_shape = True
if args.param_merge:
param_merge = True
if args.params_merge:
params_merge = True
tf2paddle(args.model, args.save_dir, without_data_format_optimization,
define_input_shape, param_merge)
define_input_shape, params_merge)
elif args.framework == "caffe":
assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model"
param_merge = False
if args.param_merge:
param_merge = True
params_merge = False
if args.params_merge:
params_merge = True
caffe2paddle(args.prototxt, args.weight, args.save_dir,
args.caffe_proto, param_merge)
args.caffe_proto, params_merge)
elif args.framework == "onnx":
assert args.model is not None, "--model should be defined while translating onnx model"
param_merge = False
if args.param_merge:
param_merge = True
onnx2paddle(args.model, args.save_dir, param_merge)
params_merge = False
if args.params_merge:
params_merge = True
onnx2paddle(args.model, args.save_dir, params_merge)
else:
raise Exception("--framework only support tensorflow/caffe/onnx now")
......
......@@ -110,7 +110,7 @@ class OpMapper(object):
self.add_codes("import paddle.fluid as fluid")
self.add_codes("")
def save_inference_model(self, save_dir, param_merge):
def save_inference_model(self, save_dir, params_merge):
self.save_python_model(save_dir)
import sys
......@@ -138,7 +138,7 @@ class OpMapper(object):
py_code_dir,
fluid.default_main_program(),
predicate=if_exist)
if param_merge:
if params_merge:
fluid.io.save_inference_model(dirname=os.path.join(
save_dir, "inference_model"),
feeded_var_names=input_names,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册