Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
84581b17
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
84581b17
编写于
10月 23, 2019
作者:
C
channings
提交者:
GitHub
10月 23, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5 from PaddlePaddle/develop
update
上级
2ec20d64
a72829d0
变更
12
展开全部
显示空白变更内容
内联
并排
Showing
12 changed file
with
763 addition
and
239 deletion
+763
-239
README.md
README.md
+7
-14
op_list.md
op_list.md
+52
-0
x2paddle/convert.py
x2paddle/convert.py
+12
-5
x2paddle/decoder/tf_decoder.py
x2paddle/decoder/tf_decoder.py
+36
-1
x2paddle/op_mapper/caffe_custom_layer/register.py
x2paddle/op_mapper/caffe_custom_layer/register.py
+0
-1
x2paddle/op_mapper/caffe_custom_layer/shufflechannel.py
x2paddle/op_mapper/caffe_custom_layer/shufflechannel.py
+1
-7
x2paddle/op_mapper/caffe_op_mapper.py
x2paddle/op_mapper/caffe_op_mapper.py
+20
-37
x2paddle/op_mapper/onnx_custom_layer/register.py
x2paddle/op_mapper/onnx_custom_layer/register.py
+0
-1
x2paddle/op_mapper/tf_op_mapper.py
x2paddle/op_mapper/tf_op_mapper.py
+27
-67
x2paddle/op_mapper/tf_op_mapper_nhwc.py
x2paddle/op_mapper/tf_op_mapper_nhwc.py
+90
-11
x2paddle/optimizer/tf_optimizer.py
x2paddle/optimizer/tf_optimizer.py
+517
-95
x2paddle_model_zoo.md
x2paddle_model_zoo.md
+1
-0
未找到文件。
README.md
浏览文件 @
84581b17
...
...
@@ -5,7 +5,7 @@ X2Paddle支持将其余深度学习框架训练得到的模型,转换至Paddle
X2Paddle is a toolkit for converting trained model to PaddlePaddle from other deep learning frameworks.
## 转换模型库
X2Paddle在多个主流的CV模型上,测试过TensorFlow/Caffe/ONNX模型的转换,可以在
[
X2Paddle-Model-Zoo
](
x2paddle_model_zoo.md
)
查看我们的模型测试列表。如果你在新的模型上进行了测试转换,也欢迎继续补充该列表;如若无法转换,可通过ISSUE反馈给我们,我们会尽快跟进。
X2Paddle在多个主流的CV模型上,测试过TensorFlow/Caffe/ONNX模型的转换,可以在
[
X2Paddle-Model-Zoo
](
x2paddle_model_zoo.md
)
查看我们的模型测试列表
,可以在
[
OP-LIST
](
op_list.md
)
中查看目前X2Paddle支持的OP列表
。如果你在新的模型上进行了测试转换,也欢迎继续补充该列表;如若无法转换,可通过ISSUE反馈给我们,我们会尽快跟进。
## 环境依赖
...
...
@@ -19,18 +19,6 @@ onnx : onnx == 1.5.0 onnxruntime == 0.4.0
## 安装
### 安装方式一(推荐)
使用最新的代码版本,可使用如下方式进行安装
```
pip install git+https://github.com/PaddlePaddle/X2Paddle.git@develop
```
### 安装方式二
我们会定期更新pip源上的x2paddle版本
```
pip install x2paddle
```
### 安装方式三
```
git clone https://github.com/PaddlePaddle/X2Paddle.git
cd X2Paddle
...
...
@@ -38,6 +26,11 @@ git checkout develop
python setup.py install
```
### 安装方式二
我们会定期更新pip源上的x2paddle版本
```
pip install x2paddle --index https://pypi.Python.org/simple/
```
## 使用方法
### TensorFlow
```
...
...
@@ -45,7 +38,7 @@ x2paddle --framework=tensorflow --model=tf_model.pb --save_dir=pd_model
```
### Caffe
```
x2paddle --framework=caffe --prototxt=deploy.proto --weight=deploy.caffemodel --save_dir=pd_model
x2paddle --framework=caffe --prototxt=deploy.proto
txt
--weight=deploy.caffemodel --save_dir=pd_model
```
### ONNX
```
...
...
op_list.md
0 → 100644
浏览文件 @
84581b17
# X2Paddle支持OP列表
> 目前X2Paddle支持40+的TensorFlow OP,30+的Caffe Layer,覆盖了大部分CV分类模型常用的操作。我们在如下列表中给出了目前X2Paddle支持的全部OP。
**注:**
目前,部分OP暂未支持,如您在转换过程中出现OP不支持的情况,可自行添加或反馈给我们。欢迎通过
[
ISSUE反馈
](
https://github.com/PaddlePaddle/X2Paddle/issues/new
)
的方式告知我们(模型名,代码实现或模型获取方式),我们会及时跟进:)
## TensorFlow
| 序号 | OP | 序号 | OP | 序号 | OP | 序号 | OP |
|------|------|------|------|------|------|------|------|
| 1 | Relu | 2 | Relu6 | 3 | Shape | 4 | Abs |
| 5 | Sigmoid | 6 | Exp | 7 | Rsqrt | 8 | swish_f32 |
| 9 | Tanh | 10 | LeakyRelu | 11 | Add | 12 | RealDiv |
| 13 | Sub | 14 | Maximum | 15 | Mul | 16 | FloorDiv |
| 17 | Placeholder | 18 | Const | 19 | Transpose | 20 | FusedBatchNorm |
| 21 | Conv2D | 22 | BiasAdd | 23 | MaxPool | 24 | DepthwiseConv2dNative |
| 25 | Reshape | 26 | AvgPool | 27 | SplitV | 28 | SquaredDifference |
| 29 | Tile | 30 | Pack | 31 | Pad | 32 | ResizeBilinear |
| 33 | Mean | 34 | MatMul | 35 | ArgMax | 36 | StridedSlice |
| 37 | Slice | 38 | Sum | 39 | Max | 40 | Conv2DBackpropInput |
| 41 | Cast | 42 | Split | 43 | Squeeze | 44 | ResizeNearestNeighbor |
| 45 | Softmax | 46 | Range | 47 | ConcatV2 |
## Caffe
| 序号 | OP | 序号 | OP | 序号 | OP | 序号 | OP |
|------|------|------|------|------|------|------|------|
| 1 | Input | 2 | Convolution | 3 | Deconvolution | 4 | Pooling |
| 5 | LRN | 6 | InnerProduct | 7 | Softmax | 8 | Slice |
| 9 | Concat | 10 | PReLU | 11 | Accuracy | 12 | Eltwise |
| 13 | BatchNorm | 14 | Scale | 15 | Reshape | 16 | ArgMax |
| 17 | Crop | 18 | Flatten | 19 | Power | 20 | Reduction |
| 21 | Axpy | 22 | ROIPolling | 23 | Permute | 24 | DetectionOutput |
| 25 | Normalize | 26 | Select | 27 | ShuffleChannel | 28 | ConvolutionDepthwise |
| 29 | ReLU | 30 | AbsVal | 31 | Sigmoid | 32 | TanH |
## ONNX
| 序号 | OP | 序号 | OP | 序号 | OP | 序号 | OP |
|------|------|------|------|------|------|------|------|
| 1 | Relu | 2 | LeakyRelu | 3 | Elu | 4 | ThresholdedRelu |
| 5 | Prelu | 6 | Tanh | 7 | Shrink | 8 | Sigmoid |
| 9 | Pow | 10 | Softplus | 11 | Softsign | 12 | HardSigmoid |
| 13 | Exp | 14 | Add | 15 | Div | 16 | Sub |
| 17 | Mul | 18 | Shape | 19 | Clip | 20 | AveragePool |
| 21 | Sqrt | 22 | ReduceSum | 23 | ReduceMin | 24 | ReduceMean |
| 25 | Constant | 26 | Pad | 27 | Unsqueeze | 28 | Resize |
| 29 | Upsample | 30 | Expand | 31 | Gather | 32 | Slice |
| 33 | Cast | 34 | Split | 35 | Reshape | 36 | ConstantOfShape |
| 37 | Ceil | 38 | Concat | 39 | Flatten | 40 | ConvTranspose |
| 41 | MatMul | 42 | Sum | 43 | Transpose | 44 | BatchNormalization |
| 45 | Squeeze | 46 | Equal | 47 | Identity | 48 | GlobalAveragePool |
| 49 | MaxPool | 50 | Conv | 51 | Gemm |
x2paddle/convert.py
浏览文件 @
84581b17
...
...
@@ -104,10 +104,14 @@ def tf2paddle(model_path,
# neccesary optimization
optimizer
.
delete_redundance_code
()
# optimizer below is experimental
optimizer
.
optimize_elementwise_op
()
optimizer
.
merge_activation
()
optimizer
.
merge_bias
()
optimizer
.
merge_batch_norm
()
optimizer
.
merge_prelu
()
optimizer
.
optimize_sub_graph
()
# optimizer.merge_batch_norm()
# optimizer.merge_prelu()
else
:
mapper
=
TFOpMapperNHWC
(
model
)
optimizer
=
TFOptimizer
(
mapper
)
...
...
@@ -125,9 +129,12 @@ def caffe2paddle(proto, weight, save_dir, caffe_proto):
from
x2paddle.op_mapper.caffe_op_mapper
import
CaffeOpMapper
from
x2paddle.optimizer.caffe_optimizer
import
CaffeOptimizer
import
google.protobuf
as
gpb
ver_str
=
gpb
.
__version__
.
replace
(
'.'
,
''
)
ver_int
=
int
(
ver_str
[
0
:
2
])
assert
ver_int
>=
36
,
'The version of protobuf must be larger than 3.6.0!'
ver_part
=
gpb
.
__version__
.
split
(
'.'
)
version_satisfy
=
False
if
(
int
(
ver_part
[
0
])
==
3
and
int
(
ver_part
[
1
])
>=
6
)
\
or
(
int
(
ver_part
[
0
])
>
3
):
version_satisfy
=
True
assert
version_satisfy
,
'google.protobuf >= 3.6.0 is required'
print
(
"Now translating model from caffe to paddle."
)
model
=
CaffeDecoder
(
proto
,
weight
,
caffe_proto
)
mapper
=
CaffeOpMapper
(
model
)
...
...
x2paddle/decoder/tf_decoder.py
浏览文件 @
84581b17
...
...
@@ -60,6 +60,15 @@ class TFGraphNode(GraphNode):
raise
Exception
(
"Dtype[{}] not in dtype_map"
.
format
(
dtype
))
return
self
.
dtype_map
[
dtype
]
@
property
def
raw_dtype
(
self
):
keys
=
[
'dtype'
,
'Tidx'
,
'T'
,
'DstT'
]
for
k
in
keys
:
dtype
=
self
.
layer
.
attr
[
k
].
type
if
dtype
>
0
:
break
return
dtype
@
property
def
value
(
self
):
assert
self
.
layer_type
==
"Const"
,
"Only Const node has value."
...
...
@@ -120,6 +129,7 @@ class TFGraph(Graph):
# tensorflow graph optimize
self
.
_remove_isolated_node
()
self
.
_remove_identity_node
()
self
.
_remove_cast_node
()
def
get_node
(
self
,
node_name
,
copy
=
False
):
items
=
node_name
.
strip
().
split
(
':'
)
...
...
@@ -190,6 +200,27 @@ class TFGraph(Graph):
idx
=
self
.
output_nodes
.
index
(
node_name
)
self
.
output_nodes
[
idx
]
=
input_node
.
layer_name
def
_remove_cast_node
(
self
):
cast_node
=
list
()
for
node_name
,
node
in
self
.
node_map
.
items
():
if
node
.
layer_type
==
"Cast"
:
input
=
self
.
get_node
(
node
.
inputs
[
0
])
if
input
.
layer_type
!=
"Placeholder"
or
len
(
input
.
outputs
)
!=
1
:
continue
cast_node
.
append
(
node_name
)
for
node_name
in
cast_node
:
node
=
self
.
get_node
(
node_name
)
input_node
=
self
.
get_node
(
node
.
inputs
[
0
])
input_node
.
layer
.
attr
[
"dtype"
].
type
=
node
.
raw_dtype
self
.
remove_node
(
node_name
)
self
.
identity_map
[
node_name
]
=
input_node
.
layer_name
if
node_name
in
self
.
output_nodes
:
idx
=
self
.
output_nodes
.
index
(
node_name
)
self
.
output_nodes
[
idx
]
=
input_node
.
layer_name
def
data_format_propagation
(
self
,
node
):
current_node
=
self
.
node_map
[
node
.
layer_name
]
current_node
=
node
.
tf_data_format
...
...
@@ -281,6 +312,10 @@ class TFDecoder(object):
right_shape_been_input
=
False
while
not
right_shape_been_input
:
try
:
shape
=
raw_input
(
"Shape of Input(e.g. None,224,224,3): "
)
except
:
shape
=
input
(
"Shape of Input(e.g. None,224,224,3): "
)
if
shape
.
count
(
"None"
)
>
1
:
print
(
"Only 1 dimension can be None, type again:)"
)
...
...
x2paddle/op_mapper/caffe_custom_layer/register.py
浏览文件 @
84581b17
...
...
@@ -31,7 +31,6 @@ def register(kind, shape, layer, weights):
k
)
is
str
,
'invalid param "kind" for register, not a list of str'
assert
k
not
in
g_custom_layers
,
'this type[%s] has already been registered'
%
(
k
)
print
(
'register layer[%s]'
%
(
k
))
g_custom_layers
[
k
]
=
{
'shape'
:
shape
,
'layer'
:
layer
,
...
...
x2paddle/op_mapper/caffe_custom_layer/shufflechannel.py
浏览文件 @
84581b17
...
...
@@ -8,13 +8,7 @@ def shufflechannel_shape(input_shape):
def
shufflechannel_layer
(
inputs
,
group
=
None
,
input_shape
=
None
,
name
=
None
):
input
=
inputs
[
0
]
c_fm
=
fluid
.
layers
.
split
(
input
,
num_or_sections
=
input_shape
[
0
][
1
],
dim
=
1
)
size
=
int
(
input_shape
[
0
][
1
]
/
group
)
new_c_fm
=
[]
for
i
in
range
(
size
):
for
j
in
range
(
group
):
new_c_fm
.
append
(
c_fm
[
j
*
size
+
i
])
out
=
fluid
.
layers
.
concat
(
new_c_fm
,
axis
=
1
)
out
=
fluid
.
layers
.
shuffle_channel
(
x
=
input
,
group
=
group
)
return
out
...
...
x2paddle/op_mapper/caffe_op_mapper.py
浏览文件 @
84581b17
...
...
@@ -124,9 +124,6 @@ class CaffeOpMapper(OpMapper):
data
[
idx
]
=
np
.
squeeze
(
d
,
axis
=
sq_axis
)
shape_new
=
data
[
idx
].
shape
if
len
(
shape_old
)
!=
shape_new
:
print
(
'squeeze idx:%d, with kind:%s,name:%s'
%
\
(
idx
,
node
.
layer_type
,
node
.
layer
.
name
))
return
data
def
get_kernel_parameters
(
self
,
kind
,
params
):
...
...
@@ -366,7 +363,7 @@ class CaffeOpMapper(OpMapper):
input
=
self
.
graph
.
get_bottom_node
(
node
,
idx
=
0
,
copy
=
True
)
attr
=
{
'n'
:
params
.
local_size
,
'k'
:
1.0
,
'k'
:
params
.
k
,
'alpha'
:
alpha
,
'beta'
:
params
.
beta
,
'name'
:
string
(
node
.
layer_name
)
...
...
@@ -453,35 +450,19 @@ class CaffeOpMapper(OpMapper):
slice_dim
=
params
.
slice_dim
if
slice_dim
!=
1
and
axis
==
1
:
axis
=
slice_dim
points
=
list
(
params
.
slice_point
)
if
len
(
points
)
==
0
:
dims
=
node
.
input_shape
[
0
][
axis
]
assert
dims
%
top_len
==
0
,
"the parameter of Slice is wrong"
part
=
dims
/
top_len
t
=
part
while
t
<
dims
:
points
.
append
(
int
(
t
))
t
+=
part
maxint32
=
2147483647
points
=
[
0
]
+
points
points
.
append
(
maxint32
)
i
=
0
node
.
fluid_code
.
add_note
(
'{} = []'
.
format
(
node
.
layer_name
))
for
i
in
range
(
len
(
points
)):
output_shape
=
node
.
output_shape
sections_list
=
[]
for
s
in
output_shape
:
sections_list
.
append
(
s
[
axis
])
attr
=
{
'axes'
:
[
axis
]
,
'starts'
:
[
points
[
i
]]
,
'ends'
:
[
points
[
i
+
1
]]
'num_or_sections'
:
sections_list
,
'dim'
:
axis
,
'name'
:
string
(
node
.
layer_name
)
}
node
.
fluid_code
.
add_layer
(
"slice
"
,
node
.
fluid_code
.
add_layer
(
"split
"
,
inputs
=
input
,
output
=
node
.
layer_name
+
'_'
+
str
(
i
)
,
output
=
node
.
layer_name
,
param_attr
=
attr
)
node
.
fluid_code
.
add_note
(
'{}.append({})'
.
format
(
node
.
layer_name
,
node
.
layer_name
+
'_'
+
str
(
i
)))
if
i
==
len
(
points
)
-
2
:
break
def
Concat
(
self
,
node
):
assert
len
(
...
...
@@ -652,7 +633,8 @@ class CaffeOpMapper(OpMapper):
]).
astype
(
'float32'
)
scale
=
0
else
:
node
.
data
=
[
np
.
squeeze
(
i
)
for
i
in
node
.
data
]
node
.
data
=
[
np
.
squeeze
(
i
).
astype
(
'float32'
)
for
i
in
node
.
data
]
mean
,
variance
,
scale
=
node
.
data
# Prescale the stats
scaling_factor
=
1.0
/
scale
if
scale
!=
0
else
0
...
...
@@ -687,8 +669,10 @@ class CaffeOpMapper(OpMapper):
input_c
,
]).
astype
(
'float32'
)
else
:
self
.
weights
[
node
.
layer_name
+
'_scale'
]
=
np
.
squeeze
(
node
.
data
[
0
])
self
.
weights
[
node
.
layer_name
+
'_offset'
]
=
np
.
squeeze
(
node
.
data
[
1
])
self
.
weights
[
node
.
layer_name
+
'_scale'
]
=
np
.
squeeze
(
node
.
data
[
0
]).
astype
(
'float32'
)
self
.
weights
[
node
.
layer_name
+
'_offset'
]
=
np
.
squeeze
(
node
.
data
[
1
]).
astype
(
'float32'
)
params
=
node
.
layer
.
scale_param
axis
=
params
.
axis
num_axes
=
params
.
num_axes
...
...
@@ -956,7 +940,6 @@ class CaffeOpMapper(OpMapper):
input
=
self
.
graph
.
get_bottom_node
(
node
,
idx
=
i
,
copy
=
True
)
if
i
==
1
and
op
==
'DetectionOutput'
:
input
=
self
.
graph
.
get_bottom_node
(
node
,
idx
=
i
,
copy
=
True
)
print
(
input
.
layer_type
)
while
input
is
not
None
and
input
.
layer_type
!=
'Softmax'
:
input
=
self
.
graph
.
get_bottom_node
(
input
,
idx
=
0
,
copy
=
True
)
assert
input
is
not
None
,
'This kind of DetectionOutput is not supported!'
...
...
x2paddle/op_mapper/onnx_custom_layer/register.py
浏览文件 @
84581b17
...
...
@@ -44,7 +44,6 @@ def register(kind, shape, layer, child_func, weights):
k
)
is
str
,
'invalid param "kind" for register, not a list of str'
assert
k
not
in
g_custom_layers
,
'this type[%s] has already been registered'
%
(
k
)
print
(
'register layer[%s]'
%
(
k
))
g_custom_layers
[
k
]
=
{
'shape'
:
shape
,
'layer'
:
layer
,
...
...
x2paddle/op_mapper/tf_op_mapper.py
浏览文件 @
84581b17
...
...
@@ -168,7 +168,32 @@ class TFOpMapper(OpMapper):
x_input
=
y
y_input
=
x
x_shape
=
y
.
out_shapes
[
0
]
if
len
(
x_shape
)
==
0
:
x_shape
=
[
1
]
y_shape
=
x
.
out_shapes
[
0
]
if
len
(
y_shape
)
==
0
:
y_shape
=
[
1
]
else
:
if
len
(
x_shape
)
==
1
and
len
(
y_shape
)
==
4
and
x_shape
[
0
]
==
y_shape
[
-
1
]
and
y_shape
.
count
(
-
1
)
<
1
:
shape
=
[
1
,
x_shape
[
0
],
1
,
1
]
attr
=
{
"shape"
:
shape
}
node
.
fluid_code
.
add_layer
(
"reshape"
,
inputs
=
x_input
,
output
=
"reshape_x"
,
param_attr
=
attr
)
if
y_shape
[
0
]
!=
1
:
attr
=
{
"expand_times"
:
[
y_shape
[
0
],
1
,
1
,
1
]}
node
.
fluid_code
.
add_layer
(
"expand"
,
inputs
=
"reshape_x"
,
output
=
"reshape_x"
,
param_attr
=
attr
)
inputs
=
{
"x"
:
"reshape_x"
,
"y"
:
y_input
}
node
.
fluid_code
.
add_layer
(
op_type
,
inputs
=
inputs
,
output
=
node
,
param_attr
=
None
)
return
else
:
raise
Exception
(
"Unexpected situation happend"
)
...
...
@@ -985,7 +1010,7 @@ class TFOpMapper(OpMapper):
attr
=
{
"bias_attr"
:
False
,
"param_attr"
:
string
(
kernel
.
layer_name
),
"num_filters"
:
k_size
[
3
],
"num_filters"
:
k_size
[
2
],
"filter_size"
:
k_size
[
0
:
2
],
"stride"
:
strides
[
2
:
4
],
"dilation"
:
dilations
[
2
:
4
],
...
...
@@ -1091,40 +1116,6 @@ class TFOpMapper(OpMapper):
output
=
node
,
param_attr
=
attr
)
def
ResizeNearestNeighbor
(
self
,
node
):
input
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
resize_shape
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
1
],
copy
=
True
)
self
.
add_omit_nodes
(
resize_shape
.
layer_name
,
node
.
layer_name
)
if
resize_shape
.
layer_type
==
"Const"
:
resize_shape
=
resize_shape
.
value
.
tolist
()
else
:
resize_shape
=
self
.
decoder
.
infer_shape_tensor
(
resize_shape
)
align_corners
=
node
.
get_attr
(
"align_corners"
)
attr
=
{
"align_corners"
:
align_corners
,
"out_shape"
:
resize_shape
}
node
.
fluid_code
.
add_layer
(
"resize_nearest"
,
inputs
=
input
,
output
=
node
,
param_attr
=
attr
)
def
ResizeBilinear
(
self
,
node
):
input
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
resize_shape
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
1
],
copy
=
True
)
self
.
add_omit_nodes
(
resize_shape
.
layer_name
,
node
.
layer_name
)
if
resize_shape
.
layer_type
==
"Const"
:
resize_shape
=
resize_shape
.
value
.
tolist
()
else
:
resize_shape
=
self
.
decoder
.
infer_shape_tensor
(
resize_shape
)
align_corners
=
node
.
get_attr
(
"align_corners"
)
attr
=
{
"align_corners"
:
align_corners
,
"out_shape"
:
resize_shape
,
"align_mode"
:
1
}
node
.
fluid_code
.
add_layer
(
"resize_bilinear"
,
inputs
=
input
,
output
=
node
,
param_attr
=
attr
)
def
ResizeNearestNeighbor
(
self
,
node
):
input
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
resize_shape
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
1
],
copy
=
True
)
...
...
@@ -1170,37 +1161,6 @@ class TFOpMapper(OpMapper):
output
=
node
,
param_attr
=
None
)
def
RandomUniform
(
self
,
node
):
shape
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
self
.
add_omit_nodes
(
shape
.
layer_name
,
node
.
layer_name
)
if
shape
.
layer_type
==
"Const"
:
shape
=
shape
.
value
.
tolist
()
else
:
shape
=
self
.
decoder
.
infer_shape_tensor
(
shape
)
if
node
.
tf_data_format
==
"NHWC"
and
len
(
shape
)
==
4
:
shape
=
[
shape
[
i
]
for
i
in
[
0
,
3
,
1
,
2
]]
attr
=
{
"shape"
:
shape
,
"min"
:
0.0
,
"max"
:
0.9999
}
if
shape
[
0
]
<
0
:
input
=
self
.
batch_node
node
.
fluid_code
.
add_layer
(
"uniform_random_batch_size_like"
,
inputs
=
input
,
output
=
node
,
param_attr
=
attr
)
else
:
node
.
fluid_code
.
add_layer
(
"uniform_random"
,
inputs
=
None
,
output
=
node
,
param_attr
=
attr
)
def
GreaterEqual
(
self
,
node
):
x
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
y
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
1
],
copy
=
True
)
inputs
=
{
"x"
:
x
,
"y"
:
y
}
node
.
fluid_code
.
add_layer
(
"greater_equal"
,
inputs
=
inputs
,
output
=
node
,
param_attr
=
None
)
def
RandomUniform
(
self
,
node
):
shape
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
self
.
add_omit_nodes
(
shape
.
layer_name
,
node
.
layer_name
)
...
...
x2paddle/op_mapper/tf_op_mapper_nhwc.py
浏览文件 @
84581b17
...
...
@@ -41,6 +41,7 @@ class TFOpMapperNHWC(OpMapper):
'Exp'
:
[
'exp'
],
'Rsqrt'
:
[
'rsqrt'
],
'swish_f32'
:
[
'swish'
],
'Tanh'
:
[
'tanh'
],
'LeakyRelu'
:
[
'leaky_relu'
,
{
'alpha'
:
'alpha'
}]
...
...
@@ -120,6 +121,25 @@ class TFOpMapperNHWC(OpMapper):
pd_param_name
=
list
(
param
.
values
())[
0
]
tf_param
=
node
.
get_attr
(
tf_param_name
)
attr
[
pd_param_name
]
=
tf_param
if
len
(
input
.
out_shapes
[
0
])
==
4
and
op_info
[
0
]
!=
'shape'
:
attr1
=
{
"perm"
:
[
0
,
3
,
1
,
2
]}
node
.
fluid_code
.
add_layer
(
'transpose'
,
inputs
=
input
,
output
=
node
,
param_attr
=
attr1
)
input
=
node
node
.
fluid_code
.
add_layer
(
op_info
[
0
],
inputs
=
input
,
output
=
node
,
param_attr
=
attr
)
input
=
node
attr2
=
{
"perm"
:
[
0
,
2
,
3
,
1
]}
node
.
fluid_code
.
add_layer
(
'transpose'
,
inputs
=
input
,
output
=
node
,
param_attr
=
attr2
)
else
:
node
.
fluid_code
.
add_layer
(
op_info
[
0
],
inputs
=
input
,
output
=
node
,
...
...
@@ -148,7 +168,11 @@ class TFOpMapperNHWC(OpMapper):
x_input
=
y
y_input
=
x
x_shape
=
y
.
out_shapes
[
0
]
if
len
(
x_shape
)
==
0
:
x_shape
=
[
1
]
y_shape
=
x
.
out_shapes
[
0
]
if
len
(
y_shape
)
==
0
:
y_shape
=
[
1
]
else
:
raise
Exception
(
"Unexpected situation happend"
)
...
...
@@ -192,6 +216,25 @@ class TFOpMapperNHWC(OpMapper):
output
=
"y_tmp"
,
param_attr
=
attr
)
y_input
=
"y_tmp"
if
len
(
x_shape
)
==
4
and
len
(
y_shape
)
==
4
:
node
.
fluid_code
.
add_layer
(
"transpose"
,
inputs
=
x_input
,
output
=
x_input
,
param_attr
=
{
'perm'
:
[
0
,
3
,
1
,
2
]})
node
.
fluid_code
.
add_layer
(
"transpose"
,
inputs
=
y_input
,
output
=
y_input
,
param_attr
=
{
'perm'
:
[
0
,
3
,
1
,
2
]})
inputs
=
{
"x"
:
x_input
,
"y"
:
y_input
}
node
.
fluid_code
.
add_layer
(
op_type
,
inputs
=
inputs
,
output
=
node
,
param_attr
=
None
)
node
.
fluid_code
.
add_layer
(
"transpose"
,
inputs
=
node
,
output
=
node
,
param_attr
=
{
'perm'
:
[
0
,
2
,
3
,
1
]})
else
:
inputs
=
{
"x"
:
x_input
,
"y"
:
y_input
}
node
.
fluid_code
.
add_layer
(
op_type
,
inputs
=
inputs
,
...
...
@@ -913,6 +956,12 @@ class TFOpMapperNHWC(OpMapper):
self
.
add_omit_nodes
(
kernel
.
layer_name
,
node
.
layer_name
)
self
.
add_omit_nodes
(
out_shape
.
layer_name
,
node
.
layer_name
)
if
out_shape
.
layer_type
==
"Const"
:
out_shape
=
out_shape
.
value
.
tolist
()
else
:
out_shape
=
self
.
decoder
.
infer_shape_tensor
(
out_shape
,
node
.
out_shapes
[
0
])
in_shape
=
input
.
out_shapes
[
0
]
if
in_shape
.
count
(
-
1
)
>
2
:
in_shape
=
self
.
decoder
.
infer_tensor
(
input
).
shape
...
...
@@ -920,7 +969,7 @@ class TFOpMapperNHWC(OpMapper):
if
k_size
.
count
(
-
1
)
>
2
:
k_size
=
self
.
decoder
.
infer_tensor
(
kernel
).
shape
pad_mode
=
node
.
get_attr
(
"padding"
)
pad_mode
=
node
.
get_attr
(
"padding"
)
.
decode
()
strides
=
node
.
get_attr
(
"strides"
)
dilations
=
node
.
get_attr
(
"dilations"
)
data_format
=
node
.
get_attr
(
"data_format"
).
decode
()
...
...
@@ -958,7 +1007,7 @@ class TFOpMapperNHWC(OpMapper):
attr
=
{
"bias_attr"
:
False
,
"param_attr"
:
string
(
kernel
.
layer_name
),
"num_filters"
:
k_size
[
3
],
"num_filters"
:
k_size
[
2
],
"filter_size"
:
k_size
[
0
:
2
],
"stride"
:
strides
[
2
:
4
],
"dilation"
:
dilations
[
2
:
4
],
...
...
@@ -969,6 +1018,22 @@ class TFOpMapperNHWC(OpMapper):
output
=
node
,
param_attr
=
attr
)
if
pad_mode
==
"SAME"
:
if
node
.
tf_data_format
==
"NHWC"
:
out_shape
=
[
out_shape
[
i
]
for
i
in
[
0
,
3
,
1
,
2
]]
for
i
in
range
(
4
):
if
out_shape
[
i
]
<
0
:
out_shape
[
i
]
=
999999
attr
=
{
"axes"
:
[
0
,
1
,
2
,
3
],
"starts"
:
[
0
,
0
,
0
,
0
],
"ends"
:
out_shape
}
node
.
fluid_code
.
add_layer
(
"slice"
,
inputs
=
node
,
output
=
node
,
param_attr
=
attr
)
if
not
channel_first
:
attr
=
{
"perm"
:
[
0
,
2
,
3
,
1
]}
node
.
fluid_code
.
add_layer
(
"transpose"
,
...
...
@@ -1127,3 +1192,17 @@ class TFOpMapperNHWC(OpMapper):
inputs
=
None
,
output
=
node
,
param_attr
=
attr
)
def
SquaredDifference
(
self
,
node
):
x
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
0
],
copy
=
True
)
y
=
self
.
graph
.
get_node
(
node
.
layer
.
input
[
1
],
copy
=
True
)
inputs
=
{
"x"
:
x
,
"y"
:
y
}
node
.
fluid_code
.
add_layer
(
"elementwise_sub"
,
inputs
=
inputs
,
output
=
node
,
param_attr
=
None
)
inputs
=
{
"x"
:
node
,
"y"
:
node
}
node
.
fluid_code
.
add_layer
(
"elementwise_mul"
,
inputs
=
inputs
,
output
=
node
,
param_attr
=
None
)
x2paddle/optimizer/tf_optimizer.py
浏览文件 @
84581b17
此差异已折叠。
点击以展开。
x2paddle_model_zoo.md
浏览文件 @
84581b17
...
...
@@ -65,3 +65,4 @@
| mNASNet |
[
pytorch(personal practice)
](
https://github.com/rwightman/gen-efficientnet-pytorch
)
|9|
| EfficientNet |
[
pytorch(personal practice)
](
https://github.com/rwightman/gen-efficientnet-pytorch
)
|9|
| SqueezeNet |
[
onnx official
](
https://s3.amazonaws.com/download.onnx/models/opset_9/squeezenet.tar.gz
)
|9|
|Ultra-Light-Fast-Generic-Face-Detector-1MB|
[
onnx_model
](
https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/tree/master/models/onnx
)
| |
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录