Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
74feae9b
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
74feae9b
编写于
7月 08, 2019
作者:
M
Macrobull
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
force embeddable checking
上级
fcbdcb82
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
93 addition
and
78 deletion
+93
-78
README.md
README.md
+0
-2
onnx2fluid/onnx2fluid/symbolic.py
onnx2fluid/onnx2fluid/symbolic.py
+93
-76
未找到文件。
README.md
浏览文件 @
74feae9b
...
...
@@ -8,8 +8,6 @@ X2Paddle支持将Caffe和TensorFlow模型转至PaddlePaddle模型,同时我们
任何使用问题均可通过
[
ISSUE
](
https://github.com/PaddlePaddle/X2Paddle/issues
)
的方式及时反馈,或者也可直接通过pull request的方式一起更新代码和文档。
> **目前X2Paddle主要支持CV部分模型,对于NLP模型暂未支持。**
## [caffe2fluid](caffe2fluid)
1.
支持将Caffe模型转至PaddlePaddle fluid可加载预测模型
2.
提供Caffe-PaddlePaddle常用API的对比文档
[
[doc
](
caffe2fluid/doc
)
]
...
...
onnx2fluid/onnx2fluid/symbolic.py
浏览文件 @
74feae9b
...
...
@@ -78,7 +78,6 @@ DEFAULT_OP_MAPPING = {
'Sign'
:
[
'sign'
,
[
'X'
],
[
'Out'
]],
'Sin'
:
[
'sin'
,
[
'X'
],
[
'Out'
]],
'Squeeze'
:
[
'squeeze'
,
[
'X'
],
[
'Out'
]],
# attrs bypassed, FIXME: emit squeeze2
'Softplus'
:
[
'softplus'
,
[
'X'
],
[
'Out'
]],
# FIXME: default axis = -1, reshape required before and after
'Softmax'
:
[
'softmax'
,
[
'X'
],
[
'Out'
],
dict
(
axis
=
''
)],
'Softplus'
:
[
'softplus'
,
[
'X'
],
[
'Out'
]],
...
...
@@ -305,7 +304,7 @@ def _pad_if_asymmetric(prog, pads, var_input, value_infos): # pads: SSEE
if
symmetric
:
return
pads
[:
ndims
],
var_input
var_padded
=
var_input
+
'_pad
ded
'
# explicit variable
var_padded
=
var_input
+
'_pad'
# explicit variable
prog
.
Op
(
''
,
'Pad'
,
...
...
@@ -317,7 +316,7 @@ def _pad_if_asymmetric(prog, pads, var_input, value_infos): # pads: SSEE
'pads'
:
pads
,
},
value_infos
=
value_infos
,
name
=
(
var_input
+
'
_
pad'
),
name
=
(
var_input
+
'
/
pad'
),
)
return
[
0
]
*
ndims
,
var_padded
...
...
@@ -688,13 +687,14 @@ def BatchNormalization(prog,
momentum
=
attrs
.
get
(
'momentum'
,
.
9
)
# optional
epsilon
=
attrs
.
get
(
'epsilon'
,
1e-5
)
# optional
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
if
embed_params
:
embed_params
=
_check_embeddable
(
value_infos
,
var_scale
,
var_b
,
var_mean
,
var_var
)
if
not
embed_params
and
name
:
_logger
.
warning
(
'for op %s(%s -> BatchNormalization -> %s)'
,
name
,
inputs
,
outputs
)
_logger
.
warning
(
'broken Python code will be generated'
)
embeddable
=
_check_embeddable
(
value_infos
,
var_scale
,
var_b
,
var_mean
,
var_var
)
if
not
embeddable
:
_logger
.
warning
(
'for op %s(%s -> BatchNormalization -> %s)'
,
name
,
inputs
,
outputs
)
_logger
.
warning
(
'one of the parameters is intermediate value'
)
_logger
.
warning
(
'broken Python code will be generated'
)
embed_params
&=
embeddable
if
embed_params
:
assert
name
!=
''
embedded_scale
=
name
+
'.w_0'
...
...
@@ -898,10 +898,10 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'given shape is neither const value nor deductible from output, '
'this is not supported'
)
attrs
=
attrs
.
copy
()
attrs
.
setdefault
(
'value'
,
np
.
array
(
0
,
dtype
=
np
.
float32
))
attrs
.
setdefault
(
'value'
,
_np
.
array
(
0
,
dtype
=
_
np
.
float32
))
attrs
.
update
({
'shape'
:
shape
})
# pass const
prog
.
Code
(
'# shape:
{}=
{} # const as literal'
.
format
(
var_shape
,
shape
))
prog
.
Code
(
'# shape:
{} =
{} # const as literal'
.
format
(
var_shape
,
shape
))
prog
.
Op
(
''
,
'Constant'
,
...
...
@@ -947,13 +947,13 @@ def Conv(prog,
pads
=
attrs
.
get
(
'pads'
,
[
0
]
*
(
convnd
*
2
))
# optional
paddings
,
var_x
=
_pad_if_asymmetric
(
prog
,
pads
,
var_x
,
value_infos
)
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
embed_params
:
embed_params
=
_check_embeddable
(
value_infos
,
*
([
var_w
]
+
([
var_b
]
if
var_b
else
[])))
if
not
embed_params
:
_logger
.
warning
(
'for op %s(%s -> Conv -> %s)'
,
name
,
inputs
,
outputs
)
_logger
.
warning
(
'broken Python code will be generated'
)
embeddable
=
_check_embeddable
(
value_infos
,
*
([
var_w
]
+
([
var_b
]
if
var_b
else
[])))
if
not
embeddable
:
_logger
.
warning
(
'for op %s(%s -> Conv -> %s)'
,
name
,
inputs
,
outputs
)
_logger
.
warning
(
'one of the parameters is intermediate value'
)
_logger
.
warning
(
'broken Python code will be generated'
)
embed_params
&=
embeddable
if
embed_params
:
embedded_w
=
name
+
'.w_0'
value_infos
[
var_w
][
'embedded_as'
].
append
(
embedded_w
)
...
...
@@ -1013,7 +1013,7 @@ def Conv(prog,
[
var_y
],
{
'axis'
:
1
},
value_infos
=
value_infos
,
name
=
(
name
+
'
.
bias'
),
name
=
(
name
+
'
/
bias'
),
)
else
:
prog
.
VarDesc
(
var_y
)
...
...
@@ -1058,13 +1058,14 @@ def ConvTranspose(prog,
pads
=
attrs
.
get
(
'pads'
,
[
0
]
*
(
convnd
*
2
))
# optional
paddings
,
var_x
=
_pad_if_asymmetric
(
prog
,
pads
,
var_x
,
value_infos
)
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
embed_params
:
embed_params
=
_check_embeddable
(
value_infos
,
*
([
var_w
]
+
([
var_b
]
if
var_b
else
[])))
if
not
embed_params
:
_logger
.
warning
(
'for op %s(%s -> ConvTranspose -> %s)'
,
name
,
inputs
,
outputs
)
_logger
.
warning
(
'broken Python code will be generated'
)
embeddable
=
_check_embeddable
(
value_infos
,
*
([
var_w
]
+
([
var_b
]
if
var_b
else
[])))
if
not
embeddable
:
_logger
.
warning
(
'for op %s(%s -> ConvTranspose -> %s)'
,
name
,
inputs
,
outputs
)
_logger
.
warning
(
'one of the parameters is intermediate value'
)
_logger
.
warning
(
'broken Python code will be generated'
)
embed_params
&=
embeddable
if
embed_params
:
embedded_w
=
name
+
'.w_0'
value_infos
[
var_w
][
'embedded_as'
].
append
(
embedded_w
)
...
...
@@ -1128,7 +1129,7 @@ def ConvTranspose(prog,
[
var_y
],
{
'axis'
:
1
},
value_infos
=
value_infos
,
name
=
(
name
+
'
.
bias'
),
name
=
(
name
+
'
/
bias'
),
)
else
:
prog
.
VarDesc
(
var_y
)
...
...
@@ -1148,7 +1149,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
trans_a
=
bool
(
attrs
.
get
(
'transA'
,
0
))
# optional
trans_b
=
bool
(
attrs
.
get
(
'transB'
,
0
))
# optional
var_mm
=
var_y
if
beta
==
0
else
(
name
+
'_mm
ed
'
)
# explicit variable
var_mm
=
var_y
if
beta
==
0
else
(
name
+
'_mm'
)
# explicit variable
prog
.
Op
(
''
,
'MatMul'
,
...
...
@@ -1160,7 +1161,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'alpha'
:
alpha
,
},
value_infos
=
value_infos
,
name
=
(
name
+
'
_
mm'
),
name
=
(
name
+
'
/
mm'
),
)
prog
.
op_descs
[
-
1
].
attrs
.
extend
(
prog
.
OpDescAttrs
({
...
...
@@ -1176,7 +1177,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[
var_y
],
{
'axis'
:
1
},
value_infos
=
value_infos
,
name
=
(
name
+
'
_
bias'
),
name
=
(
name
+
'
/
bias'
),
)
else
:
var_beta
=
name
+
'_beta'
# explicit variable
...
...
@@ -1207,7 +1208,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[
var_vm
],
dict
(),
value_infos
=
value_infos
,
name
=
(
var_beta
+
'
_
scale'
),
name
=
(
var_beta
+
'
/
scale'
),
)
prog
.
Op
(
''
,
...
...
@@ -1215,7 +1216,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[
var_mm
,
var_vm
],
[
var_y
],
{
'axis'
:
1
},
#
name
=
(
name
+
'
_
bias'
),
name
=
(
name
+
'
/
bias'
),
)
...
...
@@ -1307,6 +1308,9 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
is_reverse
=
direction
==
'reverse'
fluid_op
=
'dynamic_gru'
_logger
.
warning
(
'for op (%s -> GRU -> %s)'
,
inputs
,
outputs
)
_logger
.
warning
(
'one of the parameters is intermediate value'
)
_logger
.
warning
(
'broken Python code will be generated'
)
# generation
var_x0
=
var_x
+
'_0'
# explicit variable
...
...
@@ -1316,7 +1320,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[
var_x
],
[
var_x0
],
{
'axes'
:
[
1
]},
# index on n
name
=
(
var_x
+
'
_
index'
),
name
=
(
var_x
+
'
/
index'
),
)
var_w0
=
var_w
+
'_0'
# explicit variable
prog
.
Op
(
...
...
@@ -1325,10 +1329,10 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[
var_w
],
[
var_w0
],
{
'axes'
:
[
0
]},
# index on d
name
=
(
var_w
+
'
_
index'
),
name
=
(
var_w
+
'
/
index'
),
)
var_fc
=
var_x0
+
'_fc'
var_mm
=
(
var_x0
+
'_mm
ed
'
)
if
var_b
else
var_fc
var_mm
=
(
var_x0
+
'_mm'
)
if
var_b
else
var_fc
prog
.
Op
(
''
,
'MatMul'
,
...
...
@@ -1339,7 +1343,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'transpose_y'
:
1
,
},
value_infos
=
value_infos
,
name
=
(
var_x0
+
'
_
mm'
),
name
=
(
var_x0
+
'
/
mm'
),
)
prog
.
op_descs
[
-
1
].
attrs
.
extend
(
prog
.
OpDescAttrs
({
...
...
@@ -1353,7 +1357,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[
var_r
],
[
var_r0
],
{
'axes'
:
[
0
]},
# index on d
name
=
(
var_r
+
'
_
index'
),
name
=
(
var_r
+
'
/
index'
),
)
var_r0t
=
var_r0
+
'_t'
# explicit variable
prog
.
Op
(
...
...
@@ -1362,7 +1366,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[
var_r0
],
[
var_r0t
],
{
'perm'
:
[
1
,
0
]},
# transpose OI->IO
name
=
(
var_r0
+
'
_
transpose'
),
name
=
(
var_r0
+
'
/
transpose'
),
)
if
var_b
:
var_bi
=
var_b
+
'_i'
# explicit variable
...
...
@@ -1376,7 +1380,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'axis'
:
1
,
# split on x
'split'
:
[
hidden_size
*
3
,
hidden_size
*
3
],
},
name
=
(
var_b
+
'
_
split'
),
name
=
(
var_b
+
'
/
split'
),
)
# squeeze bi so Gemm Add can be performed on axis=1 exaclty
var_bi0
=
var_bi
+
'_0'
# explicit variable
...
...
@@ -1386,7 +1390,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[
var_bi
],
[
var_bi0
],
{
'axes'
:
[
0
]},
# slice on d
name
=
(
var_bi
+
'
_
index'
),
name
=
(
var_bi
+
'
/
index'
),
)
prog
.
Op
(
''
,
...
...
@@ -1394,7 +1398,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[
var_mm
,
var_bi0
],
[
var_fc
],
{
'axis'
:
1
},
#
name
=
(
var_x0
+
'
_
bias'
),
name
=
(
var_x0
+
'
/
bias'
),
)
if
var_xh
:
var_xh0
=
var_xh
+
'_0'
# explicit variable
...
...
@@ -1404,7 +1408,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[
var_xh
],
[
var_xh0
],
{
'axes'
:
[
1
]},
# index on n
name
=
(
var_xh
+
'
_
index'
),
name
=
(
var_xh
+
'
/
index'
),
)
var_y00
=
var_y
+
'_00'
# explicit variable
prog
.
Code
(
'{} = layers.{}({}, {}, origin_mode=True'
...
...
@@ -1449,7 +1453,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[
var_y00
],
[
var_y
],
{
'axes'
:
[
1
,
1
]},
# extrude on dn
name
=
(
var_y
+
'
_
reshape'
),
name
=
(
var_y
+
'
/
reshape'
),
)
...
...
@@ -1511,6 +1515,9 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
fluid_op
=
'dynamic_lstm'
name_attr
=
', name={}'
.
format
(
repr
(
name
))
_logger
.
warning
(
'for op %s(%s -> LSTM -> %s)'
,
name
,
inputs
,
outputs
)
_logger
.
warning
(
'one of the parameters is intermediate value'
)
_logger
.
warning
(
'broken Python code will be generated'
)
# generation
var_x0
=
var_x
+
'_0'
# explicit variable
...
...
@@ -1520,7 +1527,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[
var_x
],
[
var_x0
],
{
'axes'
:
[
1
]},
# index on n
name
=
(
var_x
+
'
_
index'
),
name
=
(
var_x
+
'
/
index'
),
)
var_w0
=
var_w
+
'_0'
# explicit variable
prog
.
Op
(
...
...
@@ -1529,10 +1536,10 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[
var_w
],
[
var_w0
],
{
'axes'
:
[
0
]},
# index on d
name
=
(
var_w
+
'
_
index'
),
name
=
(
var_w
+
'
/
index'
),
)
var_fc
=
var_x0
+
'_fc'
var_mm
=
(
var_x0
+
'_mm
ed
'
)
if
var_b
else
var_fc
var_mm
=
(
var_x0
+
'_mm'
)
if
var_b
else
var_fc
prog
.
Op
(
''
,
'MatMul'
,
...
...
@@ -1543,7 +1550,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'transpose_y'
:
1
,
},
value_infos
=
value_infos
,
name
=
(
name
+
'
_
mm'
),
name
=
(
name
+
'
/
mm'
),
)
prog
.
op_descs
[
-
1
].
attrs
.
extend
(
prog
.
OpDescAttrs
({
...
...
@@ -1557,7 +1564,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[
var_r
],
[
var_r0
],
{
'axes'
:
[
0
]},
# index on d
name
=
(
var_r
+
'
_
index'
),
name
=
(
var_r
+
'
/
index'
),
)
var_r0t
=
var_r0
+
'_t'
# explicit variable
prog
.
Op
(
...
...
@@ -1566,7 +1573,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[
var_r0
],
[
var_r0t
],
{
'perm'
:
[
1
,
0
]},
# transpose OI->IO
name
=
(
var_r0
+
'
_
transpose'
),
name
=
(
var_r0
+
'
/
transpose'
),
)
if
var_b
:
var_bi
=
var_b
+
'_i'
# explicit variable
...
...
@@ -1580,7 +1587,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'axis'
:
1
,
# split on x
'split'
:
[
hidden_size
*
4
,
hidden_size
*
4
],
},
name
=
(
var_b
+
'
_
split'
),
name
=
(
var_b
+
'
/
split'
),
)
# squeeze bi so Gemm Add can be performed on axis=1 exaclty
var_bi0
=
var_bi
+
'_0'
# explicit variable
...
...
@@ -1590,7 +1597,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[
var_bi
],
[
var_bi0
],
{
'axes'
:
[
0
]},
# slice on d
name
=
(
var_bi
+
'
_
index'
),
name
=
(
var_bi
+
'
/
index'
),
)
prog
.
Op
(
''
,
...
...
@@ -1598,7 +1605,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[
var_mm
,
var_bi0
],
[
var_fc
],
{
'axis'
:
1
},
#
name
=
(
name
+
'
_
bias'
),
name
=
(
name
+
'
/
bias'
),
)
if
var_xh
:
var_xh0
=
var_xh
+
'_0'
# explicit variable
...
...
@@ -1608,7 +1615,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[
var_xh
],
[
var_xh0
],
{
'axes'
:
[
1
]},
# index on n
name
=
(
var_xh
+
'
_
index'
),
name
=
(
var_xh
+
'
/
index'
),
)
if
var_xc
:
var_xc0
=
var_xc
+
'_0'
# explicit variable
...
...
@@ -1618,7 +1625,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[
var_xc
],
[
var_xc0
],
{
'axes'
:
[
1
]},
# index on n
name
=
(
var_xc
+
'
_
index'
),
name
=
(
var_xc
+
'
/
index'
),
)
var_bhp
=
var_p
if
var_b
:
...
...
@@ -1630,7 +1637,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[
var_bh
,
var_p
],
[
var_bhp
],
{
'axes'
:
[
1
]},
# cat on x
name
=
(
name
+
'
_
concat'
),
name
=
(
name
+
'
/
concat'
),
)
else
:
var_bhp
=
var_bh
...
...
@@ -1690,7 +1697,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[
var_yh0
],
[
var_y
],
# var_yh
{
'axes'
:
[
1
,
1
]},
# extrude on dn
name
=
(
var_y
+
'
_
reshape'
),
name
=
(
var_y
+
'
/
reshape'
),
)
if
var_yc
:
prog
.
Op
(
...
...
@@ -1699,7 +1706,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[
var_yc0
],
[
var_yc
],
{
'axes'
:
[
1
,
1
]},
# extrude on dn
name
=
(
var_yc
+
'
_
reshape'
),
name
=
(
var_yc
+
'
/
reshape'
),
)
...
...
@@ -1811,12 +1818,12 @@ def PRelu(prog,
mode
=
'element'
fluid_op
=
'prelu'
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
if
embed_params
:
embed_params
=
_check_embeddable
(
value_infos
,
var_slope
)
if
not
embed_params
and
name
:
_logger
.
warning
(
'for op %s(%s -> PRelu -> %s)'
,
name
,
inputs
,
outputs
)
_logger
.
warning
(
'broken Python code will be generated'
)
embeddable
=
_check_embeddable
(
value_infos
,
var_slope
)
if
not
embeddable
:
_logger
.
warning
(
'for op %s(%s -> PRelu -> %s)'
,
name
,
inputs
,
outputs
)
_logger
.
warning
(
'one of the parameters is intermediate value'
)
_logger
.
warning
(
'broken Python code will be generated'
)
embed_params
&=
embeddable
if
embed_params
:
assert
name
!=
''
embedded_slope
=
name
+
'.w_0'
...
...
@@ -1880,12 +1887,20 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'input "shape" not inferred, use [1, -1] as dummy value, '
'the behavior of Paddle fluid maybe undefined'
,
name
,
inputs
,
outputs
)
shape_dtype
=
_dtype_or_none
(
value_infos
,
var_shape
)
if
shape_dtype
is
None
:
_logger
.
warning
(
'in op %s(%s -> Reshape -> %s): '
'dtype of input "shape" not inferred, int32 assumed'
,
name
,
inputs
,
outputs
)
shape_dtype
=
_np
.
dtype
(
'int32'
)
fluid_op
=
'reshape'
name_attr
=
', name={}'
.
format
(
repr
(
name
))
# generation
var_shape_int32
=
var_shape
+
'_int32'
# explicit variable
prog
.
Code
(
'# shape:{}={} # const as literal'
.
format
(
var_shape
,
shape
))
var_shape_int32
=
var_shape
+
(
'_int32'
if
shape_dtype
!=
_np
.
int32
else
''
)
# explicit variable
prog
.
Code
(
'# shape: {} = {} # const as literal'
.
format
(
var_shape
,
shape
))
if
is_const_shape
:
prog
.
Code
(
'{} = layers.{}({}'
', shape={}'
...
...
@@ -1898,15 +1913,16 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
name_attr
,
))
else
:
prog
.
Op
(
''
,
'Cast'
,
[
var_shape
],
[
var_shape_int32
],
{
'to'
:
_np
.
dtype
(
'int32'
)},
# use np.dtype
value_infos
=
value_infos
,
name
=
(
name
+
'_cast'
),
)
if
shape_dtype
!=
_np
.
int32
:
prog
.
Op
(
''
,
'Cast'
,
[
var_shape
],
[
var_shape_int32
],
{
'to'
:
_np
.
dtype
(
'int32'
)},
# use np.dtype
value_infos
=
value_infos
,
name
=
(
name
+
'/cast'
),
)
prog
.
Code
(
'{} = layers.{}({}'
', shape={}'
', actual_shape={}'
...
...
@@ -2121,7 +2137,8 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
# generation
prog
.
Code
(
'# repeats:{}={} # const as literal'
.
format
(
var_repeats
,
repeats
))
prog
.
Code
(
'# repeats: {} = {} # const as literal'
.
format
(
var_repeats
,
repeats
))
prog
.
Code
(
'{} = layers.{}({}'
', expand_times={}'
'{})'
.
format
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录