未验证 提交 b8c3a0d0 编写于 作者: J Jason 提交者: GitHub

Merge pull request #194 from mamingjie-China/develop-1.6

add params merge
...@@ -26,7 +26,7 @@ onnx : onnx == 1.5.0 onnxruntime == 0.4.0 ...@@ -26,7 +26,7 @@ onnx : onnx == 1.5.0 onnxruntime == 0.4.0
``` ```
git clone https://github.com/PaddlePaddle/X2Paddle.git git clone https://github.com/PaddlePaddle/X2Paddle.git
cd X2Paddle cd X2Paddle
git checkout develop git checkout develop-1.6
python setup.py install python setup.py install
``` ```
...@@ -59,7 +59,7 @@ x2paddle --framework=onnx --model=onnx_model.onnx --save_dir=pd_model ...@@ -59,7 +59,7 @@ x2paddle --framework=onnx --model=onnx_model.onnx --save_dir=pd_model
|--caffe_proto | **[可选]** 由caffe.proto编译成caffe_pb2.py文件的存放路径,当存在自定义Layer时使用,默认为None | |--caffe_proto | **[可选]** 由caffe.proto编译成caffe_pb2.py文件的存放路径,当存在自定义Layer时使用,默认为None |
|--without_data_format_optimization | **[可选]** For TensorFlow, 当指定该参数时,关闭NHWC->NCHW的优化,见[文档Q2](FAQ.md) | |--without_data_format_optimization | **[可选]** For TensorFlow, 当指定该参数时,关闭NHWC->NCHW的优化,见[文档Q2](FAQ.md) |
|--define_input_shape | **[可选]** For TensorFlow, 当指定该参数时,强制用户输入每个Placeholder的shape,见[文档Q2](FAQ.md) | |--define_input_shape | **[可选]** For TensorFlow, 当指定该参数时,强制用户输入每个Placeholder的shape,见[文档Q2](FAQ.md) |
|--params_merge | **[可选]** 当指定该参数时,转换完成后,inference_model中的所有模型参数将合并保存为一个文件__params__ |
## 使用转换后的模型 ## 使用转换后的模型
转换后的模型包括`model_with_code``inference_model`两个目录。 转换后的模型包括`model_with_code``inference_model`两个目录。
......
...@@ -68,6 +68,11 @@ def arg_parser(): ...@@ -68,6 +68,11 @@ def arg_parser():
action="store_true", action="store_true",
default=False, default=False,
help="define input shape for tf model") help="define input shape for tf model")
parser.add_argument("--params_merge",
"-pm",
action="store_true",
default=False,
help="define whether merge the params")
return parser return parser
...@@ -75,7 +80,8 @@ def arg_parser(): ...@@ -75,7 +80,8 @@ def arg_parser():
def tf2paddle(model_path, def tf2paddle(model_path,
save_dir, save_dir,
without_data_format_optimization=False, without_data_format_optimization=False,
define_input_shape=False): define_input_shape=False,
params_merge=False):
# check tensorflow installation and version # check tensorflow installation and version
try: try:
import os import os
...@@ -104,10 +110,10 @@ def tf2paddle(model_path, ...@@ -104,10 +110,10 @@ def tf2paddle(model_path,
optimizer.strip_graph() optimizer.strip_graph()
# optimizer.merge_activation() # optimizer.merge_activation()
# optimizer.merge_bias() # optimizer.merge_bias()
mapper.save_inference_model(save_dir) mapper.save_inference_model(save_dir, params_merge)
def caffe2paddle(proto, weight, save_dir, caffe_proto): def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
from x2paddle.decoder.caffe_decoder import CaffeDecoder from x2paddle.decoder.caffe_decoder import CaffeDecoder
from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer from x2paddle.optimizer.caffe_optimizer import CaffeOptimizer
...@@ -124,10 +130,10 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto): ...@@ -124,10 +130,10 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto):
optimizer = CaffeOptimizer(mapper) optimizer = CaffeOptimizer(mapper)
optimizer.merge_bn_scale() optimizer.merge_bn_scale()
optimizer.merge_op_activation() optimizer.merge_op_activation()
mapper.save_inference_model(save_dir) mapper.save_inference_model(save_dir, params_merge)
def onnx2paddle(model_path, save_dir): def onnx2paddle(model_path, save_dir, params_merge=False):
# check onnx installation and version # check onnx installation and version
try: try:
import onnx import onnx
...@@ -150,7 +156,7 @@ def onnx2paddle(model_path, save_dir): ...@@ -150,7 +156,7 @@ def onnx2paddle(model_path, save_dir):
optimizer = ONNXOptimizer(mapper) optimizer = ONNXOptimizer(mapper)
optimizer.delete_redundance_code() optimizer.delete_redundance_code()
mapper.save_inference_model(save_dir) mapper.save_inference_model(save_dir, params_merge)
def main(): def main():
...@@ -193,20 +199,29 @@ def main(): ...@@ -193,20 +199,29 @@ def main():
assert args.model is not None, "--model should be defined while translating tensorflow model" assert args.model is not None, "--model should be defined while translating tensorflow model"
without_data_format_optimization = False without_data_format_optimization = False
define_input_shape = False define_input_shape = False
params_merge = False
if args.without_data_format_optimization: if args.without_data_format_optimization:
without_data_format_optimization = True without_data_format_optimization = True
if args.define_input_shape: if args.define_input_shape:
define_input_shape = True define_input_shape = True
if args.params_merge:
params_merge = True
tf2paddle(args.model, args.save_dir, without_data_format_optimization, tf2paddle(args.model, args.save_dir, without_data_format_optimization,
define_input_shape) define_input_shape, params_merge)
elif args.framework == "caffe": elif args.framework == "caffe":
assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model" assert args.prototxt is not None and args.weight is not None, "--prototxt and --weight should be defined while translating caffe model"
params_merge = False
if args.params_merge:
params_merge = True
caffe2paddle(args.prototxt, args.weight, args.save_dir, caffe2paddle(args.prototxt, args.weight, args.save_dir,
args.caffe_proto) args.caffe_proto, params_merge)
elif args.framework == "onnx": elif args.framework == "onnx":
assert args.model is not None, "--model should be defined while translating onnx model" assert args.model is not None, "--model should be defined while translating onnx model"
onnx2paddle(args.model, args.save_dir) params_merge = False
if args.params_merge:
params_merge = True
onnx2paddle(args.model, args.save_dir, params_merge)
else: else:
raise Exception("--framework only support tensorflow/caffe/onnx now") raise Exception("--framework only support tensorflow/caffe/onnx now")
......
...@@ -110,7 +110,7 @@ class OpMapper(object): ...@@ -110,7 +110,7 @@ class OpMapper(object):
self.add_codes("import paddle.fluid as fluid") self.add_codes("import paddle.fluid as fluid")
self.add_codes("") self.add_codes("")
def save_inference_model(self, save_dir): def save_inference_model(self, save_dir, params_merge):
self.save_python_model(save_dir) self.save_python_model(save_dir)
import sys import sys
...@@ -138,7 +138,14 @@ class OpMapper(object): ...@@ -138,7 +138,14 @@ class OpMapper(object):
py_code_dir, py_code_dir,
fluid.default_main_program(), fluid.default_main_program(),
predicate=if_exist) predicate=if_exist)
if params_merge:
fluid.io.save_inference_model(dirname=os.path.join(
save_dir, "inference_model"),
feeded_var_names=input_names,
target_vars=outputs,
executor=exe,
params_filename="__params__")
else:
fluid.io.save_inference_model(dirname=os.path.join( fluid.io.save_inference_model(dirname=os.path.join(
save_dir, "inference_model"), save_dir, "inference_model"),
feeded_var_names=input_names, feeded_var_names=input_names,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册