提交 a33bf0f0 编写于 作者: S SunAhong1993

add pytorch code and docs

上级 ce6ffee2
......@@ -44,6 +44,10 @@ x2paddle --framework=caffe --prototxt=deploy.prototxt --weight=deploy.caffemodel
```
x2paddle --framework=onnx --model=onnx_model.onnx --save_dir=pd_model
```
### PyTorch
```
x2paddle --framework=pytorch --model=resnet50.pt --save_dir=pd_model --input_shapes [-1,3,224,224]
```
### Paddle2ONNX
```
# 注意:paddle_infer_model_dir下需包含__model__和__params__两个文件
......@@ -52,7 +56,7 @@ x2paddle --framework=paddle2onnx --model=paddle_infer_model_dir --save_dir=onnx_
### 参数选项
| 参数 | |
|----------|--------------|
|--framework | 源模型类型 (tensorflow、caffe、onnx、paddle2onnx) |
|--framework | 源模型类型 (tensorflow、caffe、onnx、pytorch、paddle2onnx) |
|--prototxt | 当framework为caffe时,该参数指定caffe模型的proto文件路径 |
|--weight | 当framework为caffe时,该参数指定caffe模型的参数文件路径 |
|--save_dir | 指定转换后的模型保存目录路径 |
......@@ -62,6 +66,7 @@ x2paddle --framework=paddle2onnx --model=paddle_infer_model_dir --save_dir=onnx_
|--define_input_shape | **[可选]** For TensorFlow, 当指定该参数时,强制用户输入每个Placeholder的shape,见[文档Q2](FAQ.md) |
|--params_merge | **[可选]** 当指定该参数时,转换完成后,inference_model中的所有模型参数将合并保存为一个文件__params__ |
|--onnx_opset | **[可选]** 当framework为paddle2onnx时,该参数可设置转换为ONNX的OpSet版本,目前支持9、10、11,默认为10 |
|--input_shapes |**[可选]** 当framework为pytorch时,该参数若设置,则根据输入的shape导出inference model(用于预测的静态模型)|
......
## PyTorch模型导出为ONNX模型
目前pytorch2paddle主要支持pytorch ScriptModule。 用户可通过如下示例代码,将torchvison或者自己开发写的模型转换成ScriptModule model:
```
#coding: utf-8
import torch
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
# 定义模型
class AlexNet(nn.Module):
def __init__(self, num_classes=1000):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
self.classifier = nn.Sequential(
nn.Dropout(0.0),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(0.0),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes),
)
def forward(self, x):
x = self.features(x)
for i in range(1):
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
# 初始化模型
model = AlexNet()
# 加载参数
state_dict = load_state_dict_from_url('https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
progress=True)
model.load_state_dict(state_dict)
# 设置模式
model.eval()
# 生成ScriptModule并保存
script = torch.jit.script(model)
torch.jit.save(script, "alexnet.pt")
```
__version__ = "0.8.4"
from .core.program import PaddleProgram
from .core.program import PaddleGraph
program = PaddleProgram()
program = PaddleGraph()
name_counter = dict()
......
......@@ -13,6 +13,7 @@
# limitations under the License.
from six import text_type as _text_type
from x2paddle import program
import argparse
import sys
......@@ -66,8 +67,8 @@ def arg_parser():
parser.add_argument(
"--without_data_format_optimization",
"-wo",
type=_text_type,
default="True",
action="store_true",
default=False,
help="tf model conversion without data format optimization")
parser.add_argument(
"--define_input_shape",
......@@ -87,13 +88,19 @@ def arg_parser():
action="store_true",
default=False,
help="define whether merge the params")
parser.add_argument(
"--input_shapes",
"-is",
action='append',
default=None,
help="define the inputs' shape")
return parser
def tf2paddle(model_path,
save_dir,
without_data_format_optimization,
without_data_format_optimization=False,
define_input_shape=False,
params_merge=False):
# check tensorflow installation and version
......@@ -120,29 +127,10 @@ def tf2paddle(model_path,
print("Now translating model from tensorflow to paddle.")
model = TFDecoder(model_path, define_input_shape=define_input_shape)
if not without_data_format_optimization:
mapper = TFOpMapper(model)
optimizer = TFOptimizer(mapper)
# neccesary optimization
optimizer.delete_redundance_code()
# optimizer below is experimental
optimizer.optimize_elementwise_op()
optimizer.merge_activation()
optimizer.merge_bias()
optimizer.optimize_sub_graph()
# optimizer.merge_batch_norm()
# optimizer.merge_prelu()
else:
mapper = TFOpMapperNHWC(model)
optimizer = TFOptimizer(mapper)
optimizer.delete_redundance_code()
optimizer.strip_graph()
optimizer.merge_activation()
optimizer.merge_bias()
optimizer.make_nchw_input_output()
optimizer.remove_transpose()
mapper.save_inference_model(save_dir, params_merge)
mapper = TFOpMapperNHWC(model)
program.build()
program.gen_model(save_dir)
def caffe2paddle(proto, weight, save_dir, caffe_proto, params_merge=False):
......@@ -170,8 +158,8 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
try:
import onnx
version = onnx.version.version
if version < '1.6.0':
print("[ERROR] onnx>=1.6.0 is required")
if version != '1.6.0':
print("[ERROR] onnx==1.6.0 is required")
return
except:
print("[ERROR] onnx is not installed, use \"pip install onnx==1.6.0\".")
......@@ -192,17 +180,51 @@ def onnx2paddle(model_path, save_dir, params_merge=False):
print("Paddle model and code generated.")
def pytorch2paddle(model_path, save_dir, input_shapes):
# check pytorch installation and version
try:
import torch
version = torch.__version__
ver_part = version.split('.')
print(ver_part)
if int(ver_part[1]) < 5:
print("[ERROR] pytorch>=1.5.0 is required")
return
except:
print(
"[ERROR] Pytorch is not installed, use \"pip install torch==1.5.0 torchvision\"."
)
return
print("Now translating model from pytorch to paddle.")
from x2paddle.decoder.pytorch_decoder import PyTorchDecoder
from x2paddle.op_mapper.pytorch2paddle import pytorch_op_mapper
model = PyTorchDecoder(model_path)
mapper = pytorch_op_mapper.PyTorchOpMapper(model)
mapper.graph.build()
print("Model optimizing ...")
from x2paddle.optimizer.pytorch_optimizer.optimizer import GraphOptimizer
graph_opt = GraphOptimizer()
graph_opt.optimize(mapper.graph)
print("Model optimized.")
if input_shapes is not None:
real_input_shapes = list()
for shape in input_shapes:
sp = shape[1:-1].split(",")
for i, s in enumerate(sp):
sp[i] = int(s)
real_input_shapes.append(sp)
else:
real_input_shapes = None
mapper.graph.gen_model(save_dir, real_input_shapes)
def paddle2onnx(model_path, save_dir, opset_version=10):
from x2paddle.decoder.paddle_decoder import PaddleDecoder
from x2paddle.op_mapper.paddle2onnx.paddle_op_mapper import PaddleOpMapper
import paddle.fluid as fluid
model = PaddleDecoder(model_path, '__model__', '__params__')
mapper = PaddleOpMapper()
mapper.convert(
model.program,
save_dir,
scope=fluid.global_scope(),
opset_version=opset_version)
mapper.convert(model.program, save_dir, opset_number=opset_version)
def main():
......@@ -240,12 +262,11 @@ def main():
if args.framework == "tensorflow":
assert args.model is not None, "--model should be defined while translating tensorflow model"
assert args.without_data_format_optimization in [
"True", "False"
], "--the param without_data_format_optimization should be defined True or False"
without_data_format_optimization = False
define_input_shape = False
params_merge = False
without_data_format_optimization = True if args.without_data_format_optimization == "True" else False
if args.without_data_format_optimization:
without_data_format_optimization = True
if args.define_input_shape:
define_input_shape = True
if args.params_merge:
......@@ -267,10 +288,13 @@ def main():
if args.params_merge:
params_merge = True
onnx2paddle(args.model, args.save_dir, params_merge)
elif args.framework == "pytorch":
assert args.model is not None, "--model should be defined while translating pytorch model"
pytorch2paddle(args.model, args.save_dir, args.input_shapes)
elif args.framework == "paddle2onnx":
assert args.model is not None, "--model should be defined while translating paddle model to onnx"
paddle2onnx(args.model, args.save_dir, opset_version=args.onnx_opset)
paddle2onnx(args.model, args.save_dir, args.onnx_opset)
else:
raise Exception(
......
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class PyTorchDecoder(object):
def __init__(self, script_path):
self.script = torch.jit.load(script_path)
self.graph = self._optimize_graph(self.script.inlined_graph)
def _optimize_graph(self, graph):
torch._C._jit_pass_constant_propagation(graph)
torch._C._jit_pass_dce(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_peephole(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_dce(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_canonicalize(graph)
torch._C._jit_pass_lint(graph)
torch._C._jit_pass_constant_propagation(graph)
return graph
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .adaptive_pool2d_fuser import AdaptivePool2dFuser
from .adaptive_pool2d_fuse_pass import AdaptivePool2dFusePass
from .batchnorm2d_fuser import BatchNorm2dFuser
from .batchnorm2d_fuse_pass import BatchNorm2dFusePass
from .constant_fuser import ConstantFuser
from .constant_fuse_pass import ConstantFusePass
from .dropout_fuser import DropoutFuser
from .dropout_fuse_pass import DropoutFusePass
from .fc_fuser import FcFuser
from .fc_fuse_pass import FcFusePass
from .interpolate_bilinear_fuser import InterpolateBilinearFuser
from .interpolate_bilinear_fuse_pass import InterpolateBilinearFusePass
from .reshape_fuser import ReshapeFuser
from .reshape_fuse_pass import ReshapeFusePass
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册