Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
69bc2bb3
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
69bc2bb3
编写于
11月 30, 2021
作者:
W
WJJ1995
提交者:
GitHub
11月 30, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' into add_silu_op
上级
139554da
06984e8b
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
43 addition
and
30 deletion
+43
-30
x2paddle/convert.py
x2paddle/convert.py
+40
-28
x2paddle/core/program.py
x2paddle/core/program.py
+1
-1
x2paddle/decoder/onnx_decoder.py
x2paddle/decoder/onnx_decoder.py
+0
-1
x2paddle/op_mapper/onnx2paddle/opset9/opset.py
x2paddle/op_mapper/onnx2paddle/opset9/opset.py
+2
-0
未找到文件。
x2paddle/convert.py
浏览文件 @
69bc2bb3
...
...
@@ -16,6 +16,7 @@ from six import text_type as _text_type
from
x2paddle
import
program
import
argparse
import
sys
import
logging
def
arg_parser
():
...
...
@@ -137,12 +138,12 @@ def tf2paddle(model_path,
import
tensorflow
as
tf
version
=
tf
.
__version__
if
version
>=
'2.0.0'
or
version
<
'1.0.0'
:
print
(
logging
.
info
(
"[ERROR] 1.0.0<=tensorflow<2.0.0 is required, and v1.14.0 is recommended"
)
return
except
:
print
(
logging
.
info
(
"[ERROR] Tensorflow is not installed, use
\"
pip install tensorflow
\"
."
)
return
...
...
@@ -150,7 +151,7 @@ def tf2paddle(model_path,
from
x2paddle.decoder.tf_decoder
import
TFDecoder
from
x2paddle.op_mapper.tf2paddle.tf_op_mapper
import
TFOpMapper
print
(
"Now translating model from tensorflow to paddle."
)
logging
.
info
(
"Now translating model from tensorflow to paddle."
)
model
=
TFDecoder
(
model_path
,
define_input_shape
=
define_input_shape
)
mapper
=
TFOpMapper
(
model
)
mapper
.
paddle_graph
.
build
()
...
...
@@ -178,15 +179,15 @@ def caffe2paddle(proto_file,
or
(
int
(
ver_part
[
0
])
>
3
):
version_satisfy
=
True
assert
version_satisfy
,
'[ERROR] google.protobuf >= 3.6.0 is required'
print
(
"Now translating model from caffe to paddle."
)
logging
.
info
(
"Now translating model from caffe to paddle."
)
model
=
CaffeDecoder
(
proto_file
,
weight_file
,
caffe_proto
)
mapper
=
CaffeOpMapper
(
model
)
mapper
.
paddle_graph
.
build
()
print
(
"Model optimizing ..."
)
logging
.
info
(
"Model optimizing ..."
)
from
x2paddle.optimizer.optimizer
import
GraphOptimizer
graph_opt
=
GraphOptimizer
(
source_frame
=
"caffe"
)
graph_opt
.
optimize
(
mapper
.
paddle_graph
)
print
(
"Model optimized."
)
logging
.
info
(
"Model optimized."
)
mapper
.
paddle_graph
.
gen_model
(
save_dir
)
if
convert_to_lite
:
convert2lite
(
save_dir
,
lite_valid_places
,
lite_model_type
)
...
...
@@ -204,12 +205,13 @@ def onnx2paddle(model_path,
v0
,
v1
,
v2
=
version
.
split
(
'.'
)
version_sum
=
int
(
v0
)
*
100
+
int
(
v1
)
*
10
+
int
(
v2
)
if
version_sum
<
160
:
print
(
"[ERROR] onnx>=1.6.0 is required"
)
logging
.
info
(
"[ERROR] onnx>=1.6.0 is required"
)
return
except
:
print
(
"[ERROR] onnx is not installed, use
\"
pip install onnx==1.6.0
\"
."
)
logging
.
info
(
"[ERROR] onnx is not installed, use
\"
pip install onnx==1.6.0
\"
."
)
return
print
(
"Now translating model from onnx to paddle."
)
logging
.
info
(
"Now translating model from onnx to paddle."
)
from
x2paddle.decoder.onnx_decoder
import
ONNXDecoder
from
x2paddle.op_mapper.onnx2paddle.onnx_op_mapper
import
ONNXOpMapper
...
...
@@ -233,17 +235,24 @@ def pytorch2paddle(module,
try
:
import
torch
version
=
torch
.
__version__
ver_part
=
version
.
split
(
'.'
)
print
(
ver_part
)
if
int
(
ver_part
[
1
])
<
5
:
print
(
"[ERROR] pytorch>=1.5.0 is required"
)
v0
,
v1
,
v2
=
version
.
split
(
'.'
)
# Avoid the situation where the version is equal to 1.7.0+cu101
if
'+'
in
v2
:
v2
=
v2
.
split
(
'+'
)[
0
]
version_sum
=
int
(
v0
)
*
100
+
int
(
v1
)
*
10
+
int
(
v2
)
if
version_sum
<
150
:
logging
.
info
(
"[ERROR] pytorch>=1.5.0 is required, 1.6.0 is the most recommended"
)
return
if
version_sum
>
160
:
logging
.
info
(
"[WARNING] pytorch==1.6.0 is recommended"
)
except
:
print
(
"[ERROR] Pytorch is not installed, use
\"
pip install torch==1.
5
.0 torchvision
\"
."
logging
.
info
(
"[ERROR] Pytorch is not installed, use
\"
pip install torch==1.
6
.0 torchvision
\"
."
)
return
print
(
"Now translating model from pytorch to paddle."
)
logging
.
info
(
"Now translating model from pytorch to paddle."
)
from
x2paddle.decoder.pytorch_decoder
import
ScriptDecoder
,
TraceDecoder
from
x2paddle.op_mapper.pytorch2paddle.pytorch_op_mapper
import
PyTorchOpMapper
...
...
@@ -254,7 +263,7 @@ def pytorch2paddle(module,
model
=
ScriptDecoder
(
module
,
input_examples
)
mapper
=
PyTorchOpMapper
(
model
)
mapper
.
paddle_graph
.
build
()
print
(
"Model optimizing ..."
)
logging
.
info
(
"Model optimizing ..."
)
from
x2paddle.optimizer.optimizer
import
GraphOptimizer
graph_opt
=
GraphOptimizer
(
source_frame
=
"pytorch"
,
jit_type
=
jit_type
)
graph_opt
.
optimize
(
mapper
.
paddle_graph
)
...
...
@@ -266,10 +275,12 @@ def pytorch2paddle(module,
def
main
():
logging
.
basicConfig
(
level
=
logging
.
INFO
)
if
len
(
sys
.
argv
)
<
2
:
print
(
"Use
\"
x2paddle -h
\"
to print the help information"
)
print
(
"For more information, please follow our github repo below:)"
)
print
(
"
\n
Github: https://github.com/PaddlePaddle/X2Paddle.git
\n
"
)
logging
.
info
(
"Use
\"
x2paddle -h
\"
to print the help information"
)
logging
.
info
(
"For more information, please follow our github repo below:)"
)
logging
.
info
(
"
\n
Github: https://github.com/PaddlePaddle/X2Paddle.git
\n
"
)
return
parser
=
arg_parser
()
...
...
@@ -277,8 +288,8 @@ def main():
if
args
.
version
:
import
x2paddle
print
(
"x2paddle-{} with python>=3.5, paddlepaddle>=1.6.0
\n
"
.
format
(
x2paddle
.
__version__
))
logging
.
info
(
"x2paddle-{} with python>=3.5, paddlepaddle>=1.6.0
\n
"
.
format
(
x2paddle
.
__version__
))
return
if
not
args
.
convert_torch_project
:
...
...
@@ -289,18 +300,19 @@ def main():
import
platform
v0
,
v1
,
v2
=
platform
.
python_version
().
split
(
'.'
)
if
not
(
int
(
v0
)
>=
3
and
int
(
v1
)
>=
5
):
print
(
"[ERROR] python>=3.5 is required"
)
logging
.
info
(
"[ERROR] python>=3.5 is required"
)
return
import
paddle
v0
,
v1
,
v2
=
paddle
.
__version__
.
split
(
'.'
)
print
(
"paddle.__version__ = {}"
.
format
(
paddle
.
__version__
))
logging
.
info
(
"paddle.__version__ = {}"
.
format
(
paddle
.
__version__
))
if
v0
==
'0'
and
v1
==
'0'
and
v2
==
'0'
:
print
(
"[WARNING] You are use develop version of paddlepaddle"
)
logging
.
info
(
"[WARNING] You are use develop version of paddlepaddle"
)
elif
int
(
v0
)
!=
2
or
int
(
v1
)
<
0
:
print
(
"[ERROR] paddlepaddle>=2.0.0 is required"
)
logging
.
info
(
"[ERROR] paddlepaddle>=2.0.0 is required"
)
return
except
:
print
(
logging
.
info
(
"[ERROR] paddlepaddle not installed, use
\"
pip install paddlepaddle
\"
"
)
...
...
@@ -341,7 +353,7 @@ def main():
lite_valid_places
=
args
.
lite_valid_places
,
lite_model_type
=
args
.
lite_model_type
)
elif
args
.
framework
==
"paddle2onnx"
:
print
(
logging
.
info
(
"Paddle to ONNX tool has been migrated to the new github: https://github.com/PaddlePaddle/paddle2onnx"
)
...
...
x2paddle/core/program.py
浏览文件 @
69bc2bb3
...
...
@@ -388,7 +388,7 @@ class PaddleGraph(object):
gen_codes
(
[
"paddle.disable_static()"
,
"params = paddle.load('{}')"
.
format
(
"params = paddle.load(
r
'{}')"
.
format
(
osp
.
join
(
osp
.
abspath
(
code_dir
),
"model.pdparams"
)),
"model = {}()"
.
format
(
self
.
name
),
"model.set_dict(params, use_structured_name={})"
.
format
(
...
...
x2paddle/decoder/onnx_decoder.py
浏览文件 @
69bc2bb3
...
...
@@ -16,7 +16,6 @@ from x2paddle.core.graph import GraphNode, Graph
from
x2paddle.decoder.onnx_shape_inference
import
SymbolicShapeInference
from
onnx.checker
import
ValidationError
from
onnx.checker
import
check_model
from
onnx.utils
import
polish_model
from
onnx
import
helper
,
shape_inference
from
onnx.helper
import
get_attribute_value
,
make_attribute
from
onnx.shape_inference
import
infer_shapes
...
...
x2paddle/op_mapper/onnx2paddle/opset9/opset.py
浏览文件 @
69bc2bb3
...
...
@@ -518,6 +518,8 @@ class OpSet9():
if
pads
is
not
None
:
is_pads_attr
=
True
mode
=
node
.
get_attr
(
'mode'
,
'constant'
)
if
mode
in
[
"edge"
]:
mode
=
"replicate"
value
=
node
.
get_attr
(
'value'
,
0.
)
data_shape
=
val_x
.
out_shapes
[
0
]
output_shape
=
node
.
out_shapes
[
0
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录