提交 74feae9b 编写于 作者: M Macrobull

force embeddable checking

上级 fcbdcb82
...@@ -8,8 +8,6 @@ X2Paddle支持将Caffe和TensorFlow模型转至PaddlePaddle模型,同时我们 ...@@ -8,8 +8,6 @@ X2Paddle支持将Caffe和TensorFlow模型转至PaddlePaddle模型,同时我们
任何使用问题均可通过[ISSUE](https://github.com/PaddlePaddle/X2Paddle/issues)的方式及时反馈,或者也可直接通过pull request的方式一起更新代码和文档。 任何使用问题均可通过[ISSUE](https://github.com/PaddlePaddle/X2Paddle/issues)的方式及时反馈,或者也可直接通过pull request的方式一起更新代码和文档。
> **目前X2Paddle主要支持CV部分模型,对于NLP模型暂未支持。**
## [caffe2fluid](caffe2fluid) ## [caffe2fluid](caffe2fluid)
1. 支持将Caffe模型转至PaddlePaddle fluid可加载预测模型 1. 支持将Caffe模型转至PaddlePaddle fluid可加载预测模型
2. 提供Caffe-PaddlePaddle常用API的对比文档[[doc](caffe2fluid/doc)] 2. 提供Caffe-PaddlePaddle常用API的对比文档[[doc](caffe2fluid/doc)]
......
...@@ -78,7 +78,6 @@ DEFAULT_OP_MAPPING = { ...@@ -78,7 +78,6 @@ DEFAULT_OP_MAPPING = {
'Sign': ['sign', ['X'], ['Out']], 'Sign': ['sign', ['X'], ['Out']],
'Sin': ['sin', ['X'], ['Out']], 'Sin': ['sin', ['X'], ['Out']],
'Squeeze': ['squeeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit squeeze2 'Squeeze': ['squeeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit squeeze2
'Softplus': ['softplus', ['X'], ['Out']],
# FIXME: default axis = -1, reshape required before and after # FIXME: default axis = -1, reshape required before and after
'Softmax': ['softmax', ['X'], ['Out'], dict(axis='')], 'Softmax': ['softmax', ['X'], ['Out'], dict(axis='')],
'Softplus': ['softplus', ['X'], ['Out']], 'Softplus': ['softplus', ['X'], ['Out']],
...@@ -305,7 +304,7 @@ def _pad_if_asymmetric(prog, pads, var_input, value_infos): # pads: SSEE ...@@ -305,7 +304,7 @@ def _pad_if_asymmetric(prog, pads, var_input, value_infos): # pads: SSEE
if symmetric: if symmetric:
return pads[:ndims], var_input return pads[:ndims], var_input
var_padded = var_input + '_padded' # explicit variable var_padded = var_input + '_pad' # explicit variable
prog.Op( prog.Op(
'', '',
'Pad', 'Pad',
...@@ -317,7 +316,7 @@ def _pad_if_asymmetric(prog, pads, var_input, value_infos): # pads: SSEE ...@@ -317,7 +316,7 @@ def _pad_if_asymmetric(prog, pads, var_input, value_infos): # pads: SSEE
'pads': pads, 'pads': pads,
}, },
value_infos=value_infos, value_infos=value_infos,
name=(var_input + '_pad'), name=(var_input + '/pad'),
) )
return [0] * ndims, var_padded return [0] * ndims, var_padded
...@@ -688,13 +687,14 @@ def BatchNormalization(prog, ...@@ -688,13 +687,14 @@ def BatchNormalization(prog,
momentum = attrs.get('momentum', .9) # optional momentum = attrs.get('momentum', .9) # optional
epsilon = attrs.get('epsilon', 1e-5) # optional epsilon = attrs.get('epsilon', 1e-5) # optional
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params: embeddable = _check_embeddable(value_infos, var_scale, var_b, var_mean,
embed_params = _check_embeddable(value_infos, var_scale, var_b, var_var)
var_mean, var_var) if not embeddable:
if not embed_params and name:
_logger.warning('for op %s(%s -> BatchNormalization -> %s)', name, _logger.warning('for op %s(%s -> BatchNormalization -> %s)', name,
inputs, outputs) inputs, outputs)
_logger.warning('one of the parameters is intermediate value')
_logger.warning('broken Python code will be generated') _logger.warning('broken Python code will be generated')
embed_params &= embeddable
if embed_params: if embed_params:
assert name != '' assert name != ''
embedded_scale = name + '.w_0' embedded_scale = name + '.w_0'
...@@ -898,10 +898,10 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -898,10 +898,10 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'given shape is neither const value nor deductible from output, ' 'given shape is neither const value nor deductible from output, '
'this is not supported') 'this is not supported')
attrs = attrs.copy() attrs = attrs.copy()
attrs.setdefault('value', np.array(0, dtype=np.float32)) attrs.setdefault('value', _np.array(0, dtype=_np.float32))
attrs.update({'shape': shape}) # pass const attrs.update({'shape': shape}) # pass const
prog.Code('# shape:{}={} # const as literal'.format(var_shape, shape)) prog.Code('# shape: {} = {} # const as literal'.format(var_shape, shape))
prog.Op( prog.Op(
'', '',
'Constant', 'Constant',
...@@ -947,13 +947,13 @@ def Conv(prog, ...@@ -947,13 +947,13 @@ def Conv(prog,
pads = attrs.get('pads', [0] * (convnd * 2)) # optional pads = attrs.get('pads', [0] * (convnd * 2)) # optional
paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos) paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos)
name_attr = ', name={}'.format(repr(name)) name_attr = ', name={}'.format(repr(name))
if embed_params: embeddable = _check_embeddable(value_infos,
embed_params = _check_embeddable( *([var_w] + ([var_b] if var_b else [])))
value_infos, *([var_w] + ([var_b] if var_b else []))) if not embeddable:
if not embed_params: _logger.warning('for op %s(%s -> Conv -> %s)', name, inputs, outputs)
_logger.warning('for op %s(%s -> Conv -> %s)', name, inputs, _logger.warning('one of the parameters is intermediate value')
outputs)
_logger.warning('broken Python code will be generated') _logger.warning('broken Python code will be generated')
embed_params &= embeddable
if embed_params: if embed_params:
embedded_w = name + '.w_0' embedded_w = name + '.w_0'
value_infos[var_w]['embedded_as'].append(embedded_w) value_infos[var_w]['embedded_as'].append(embedded_w)
...@@ -1013,7 +1013,7 @@ def Conv(prog, ...@@ -1013,7 +1013,7 @@ def Conv(prog,
[var_y], [var_y],
{'axis': 1}, {'axis': 1},
value_infos=value_infos, value_infos=value_infos,
name=(name + '.bias'), name=(name + '/bias'),
) )
else: else:
prog.VarDesc(var_y) prog.VarDesc(var_y)
...@@ -1058,13 +1058,14 @@ def ConvTranspose(prog, ...@@ -1058,13 +1058,14 @@ def ConvTranspose(prog,
pads = attrs.get('pads', [0] * (convnd * 2)) # optional pads = attrs.get('pads', [0] * (convnd * 2)) # optional
paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos) paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos)
name_attr = ', name={}'.format(repr(name)) name_attr = ', name={}'.format(repr(name))
if embed_params: embeddable = _check_embeddable(value_infos,
embed_params = _check_embeddable( *([var_w] + ([var_b] if var_b else [])))
value_infos, *([var_w] + ([var_b] if var_b else []))) if not embeddable:
if not embed_params: _logger.warning('for op %s(%s -> ConvTranspose -> %s)', name, inputs,
_logger.warning('for op %s(%s -> ConvTranspose -> %s)', name, outputs)
inputs, outputs) _logger.warning('one of the parameters is intermediate value')
_logger.warning('broken Python code will be generated') _logger.warning('broken Python code will be generated')
embed_params &= embeddable
if embed_params: if embed_params:
embedded_w = name + '.w_0' embedded_w = name + '.w_0'
value_infos[var_w]['embedded_as'].append(embedded_w) value_infos[var_w]['embedded_as'].append(embedded_w)
...@@ -1128,7 +1129,7 @@ def ConvTranspose(prog, ...@@ -1128,7 +1129,7 @@ def ConvTranspose(prog,
[var_y], [var_y],
{'axis': 1}, {'axis': 1},
value_infos=value_infos, value_infos=value_infos,
name=(name + '.bias'), name=(name + '/bias'),
) )
else: else:
prog.VarDesc(var_y) prog.VarDesc(var_y)
...@@ -1148,7 +1149,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1148,7 +1149,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
trans_a = bool(attrs.get('transA', 0)) # optional trans_a = bool(attrs.get('transA', 0)) # optional
trans_b = bool(attrs.get('transB', 0)) # optional trans_b = bool(attrs.get('transB', 0)) # optional
var_mm = var_y if beta == 0 else (name + '_mmed') # explicit variable var_mm = var_y if beta == 0 else (name + '_mm') # explicit variable
prog.Op( prog.Op(
'', '',
'MatMul', 'MatMul',
...@@ -1160,7 +1161,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1160,7 +1161,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'alpha': alpha, 'alpha': alpha,
}, },
value_infos=value_infos, value_infos=value_infos,
name=(name + '_mm'), name=(name + '/mm'),
) )
prog.op_descs[-1].attrs.extend( prog.op_descs[-1].attrs.extend(
prog.OpDescAttrs({ prog.OpDescAttrs({
...@@ -1176,7 +1177,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1176,7 +1177,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_y], [var_y],
{'axis': 1}, {'axis': 1},
value_infos=value_infos, value_infos=value_infos,
name=(name + '_bias'), name=(name + '/bias'),
) )
else: else:
var_beta = name + '_beta' # explicit variable var_beta = name + '_beta' # explicit variable
...@@ -1207,7 +1208,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1207,7 +1208,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_vm], [var_vm],
dict(), dict(),
value_infos=value_infos, value_infos=value_infos,
name=(var_beta + '_scale'), name=(var_beta + '/scale'),
) )
prog.Op( prog.Op(
'', '',
...@@ -1215,7 +1216,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1215,7 +1216,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_mm, var_vm], [var_mm, var_vm],
[var_y], [var_y],
{'axis': 1}, # {'axis': 1}, #
name=(name + '_bias'), name=(name + '/bias'),
) )
...@@ -1307,6 +1308,9 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1307,6 +1308,9 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
is_reverse = direction == 'reverse' is_reverse = direction == 'reverse'
fluid_op = 'dynamic_gru' fluid_op = 'dynamic_gru'
_logger.warning('for op (%s -> GRU -> %s)', inputs, outputs)
_logger.warning('one of the parameters is intermediate value')
_logger.warning('broken Python code will be generated')
# generation # generation
var_x0 = var_x + '_0' # explicit variable var_x0 = var_x + '_0' # explicit variable
...@@ -1316,7 +1320,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1316,7 +1320,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[var_x], [var_x],
[var_x0], [var_x0],
{'axes': [1]}, # index on n {'axes': [1]}, # index on n
name=(var_x + '_index'), name=(var_x + '/index'),
) )
var_w0 = var_w + '_0' # explicit variable var_w0 = var_w + '_0' # explicit variable
prog.Op( prog.Op(
...@@ -1325,10 +1329,10 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1325,10 +1329,10 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[var_w], [var_w],
[var_w0], [var_w0],
{'axes': [0]}, # index on d {'axes': [0]}, # index on d
name=(var_w + '_index'), name=(var_w + '/index'),
) )
var_fc = var_x0 + '_fc' var_fc = var_x0 + '_fc'
var_mm = (var_x0 + '_mmed') if var_b else var_fc var_mm = (var_x0 + '_mm') if var_b else var_fc
prog.Op( prog.Op(
'', '',
'MatMul', 'MatMul',
...@@ -1339,7 +1343,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1339,7 +1343,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'transpose_y': 1, 'transpose_y': 1,
}, },
value_infos=value_infos, value_infos=value_infos,
name=(var_x0 + '_mm'), name=(var_x0 + '/mm'),
) )
prog.op_descs[-1].attrs.extend( prog.op_descs[-1].attrs.extend(
prog.OpDescAttrs({ prog.OpDescAttrs({
...@@ -1353,7 +1357,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1353,7 +1357,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[var_r], [var_r],
[var_r0], [var_r0],
{'axes': [0]}, # index on d {'axes': [0]}, # index on d
name=(var_r + '_index'), name=(var_r + '/index'),
) )
var_r0t = var_r0 + '_t' # explicit variable var_r0t = var_r0 + '_t' # explicit variable
prog.Op( prog.Op(
...@@ -1362,7 +1366,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1362,7 +1366,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[var_r0], [var_r0],
[var_r0t], [var_r0t],
{'perm': [1, 0]}, # transpose OI->IO {'perm': [1, 0]}, # transpose OI->IO
name=(var_r0 + '_transpose'), name=(var_r0 + '/transpose'),
) )
if var_b: if var_b:
var_bi = var_b + '_i' # explicit variable var_bi = var_b + '_i' # explicit variable
...@@ -1376,7 +1380,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1376,7 +1380,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'axis': 1, # split on x 'axis': 1, # split on x
'split': [hidden_size * 3, hidden_size * 3], 'split': [hidden_size * 3, hidden_size * 3],
}, },
name=(var_b + '_split'), name=(var_b + '/split'),
) )
# squeeze bi so Gemm Add can be performed on axis=1 exaclty # squeeze bi so Gemm Add can be performed on axis=1 exaclty
var_bi0 = var_bi + '_0' # explicit variable var_bi0 = var_bi + '_0' # explicit variable
...@@ -1386,7 +1390,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1386,7 +1390,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[var_bi], [var_bi],
[var_bi0], [var_bi0],
{'axes': [0]}, # slice on d {'axes': [0]}, # slice on d
name=(var_bi + '_index'), name=(var_bi + '/index'),
) )
prog.Op( prog.Op(
'', '',
...@@ -1394,7 +1398,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1394,7 +1398,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[var_mm, var_bi0], [var_mm, var_bi0],
[var_fc], [var_fc],
{'axis': 1}, # {'axis': 1}, #
name=(var_x0 + '_bias'), name=(var_x0 + '/bias'),
) )
if var_xh: if var_xh:
var_xh0 = var_xh + '_0' # explicit variable var_xh0 = var_xh + '_0' # explicit variable
...@@ -1404,7 +1408,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1404,7 +1408,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[var_xh], [var_xh],
[var_xh0], [var_xh0],
{'axes': [1]}, # index on n {'axes': [1]}, # index on n
name=(var_xh + '_index'), name=(var_xh + '/index'),
) )
var_y00 = var_y + '_00' # explicit variable var_y00 = var_y + '_00' # explicit variable
prog.Code('{} = layers.{}({}, {}, origin_mode=True' prog.Code('{} = layers.{}({}, {}, origin_mode=True'
...@@ -1449,7 +1453,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1449,7 +1453,7 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[var_y00], [var_y00],
[var_y], [var_y],
{'axes': [1, 1]}, # extrude on dn {'axes': [1, 1]}, # extrude on dn
name=(var_y + '_reshape'), name=(var_y + '/reshape'),
) )
...@@ -1511,6 +1515,9 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1511,6 +1515,9 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
fluid_op = 'dynamic_lstm' fluid_op = 'dynamic_lstm'
name_attr = ', name={}'.format(repr(name)) name_attr = ', name={}'.format(repr(name))
_logger.warning('for op %s(%s -> LSTM -> %s)', name, inputs, outputs)
_logger.warning('one of the parameters is intermediate value')
_logger.warning('broken Python code will be generated')
# generation # generation
var_x0 = var_x + '_0' # explicit variable var_x0 = var_x + '_0' # explicit variable
...@@ -1520,7 +1527,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1520,7 +1527,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_x], [var_x],
[var_x0], [var_x0],
{'axes': [1]}, # index on n {'axes': [1]}, # index on n
name=(var_x + '_index'), name=(var_x + '/index'),
) )
var_w0 = var_w + '_0' # explicit variable var_w0 = var_w + '_0' # explicit variable
prog.Op( prog.Op(
...@@ -1529,10 +1536,10 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1529,10 +1536,10 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_w], [var_w],
[var_w0], [var_w0],
{'axes': [0]}, # index on d {'axes': [0]}, # index on d
name=(var_w + '_index'), name=(var_w + '/index'),
) )
var_fc = var_x0 + '_fc' var_fc = var_x0 + '_fc'
var_mm = (var_x0 + '_mmed') if var_b else var_fc var_mm = (var_x0 + '_mm') if var_b else var_fc
prog.Op( prog.Op(
'', '',
'MatMul', 'MatMul',
...@@ -1543,7 +1550,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1543,7 +1550,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'transpose_y': 1, 'transpose_y': 1,
}, },
value_infos=value_infos, value_infos=value_infos,
name=(name + '_mm'), name=(name + '/mm'),
) )
prog.op_descs[-1].attrs.extend( prog.op_descs[-1].attrs.extend(
prog.OpDescAttrs({ prog.OpDescAttrs({
...@@ -1557,7 +1564,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1557,7 +1564,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_r], [var_r],
[var_r0], [var_r0],
{'axes': [0]}, # index on d {'axes': [0]}, # index on d
name=(var_r + '_index'), name=(var_r + '/index'),
) )
var_r0t = var_r0 + '_t' # explicit variable var_r0t = var_r0 + '_t' # explicit variable
prog.Op( prog.Op(
...@@ -1566,7 +1573,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1566,7 +1573,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_r0], [var_r0],
[var_r0t], [var_r0t],
{'perm': [1, 0]}, # transpose OI->IO {'perm': [1, 0]}, # transpose OI->IO
name=(var_r0 + '_transpose'), name=(var_r0 + '/transpose'),
) )
if var_b: if var_b:
var_bi = var_b + '_i' # explicit variable var_bi = var_b + '_i' # explicit variable
...@@ -1580,7 +1587,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1580,7 +1587,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'axis': 1, # split on x 'axis': 1, # split on x
'split': [hidden_size * 4, hidden_size * 4], 'split': [hidden_size * 4, hidden_size * 4],
}, },
name=(var_b + '_split'), name=(var_b + '/split'),
) )
# squeeze bi so Gemm Add can be performed on axis=1 exaclty # squeeze bi so Gemm Add can be performed on axis=1 exaclty
var_bi0 = var_bi + '_0' # explicit variable var_bi0 = var_bi + '_0' # explicit variable
...@@ -1590,7 +1597,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1590,7 +1597,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_bi], [var_bi],
[var_bi0], [var_bi0],
{'axes': [0]}, # slice on d {'axes': [0]}, # slice on d
name=(var_bi + '_index'), name=(var_bi + '/index'),
) )
prog.Op( prog.Op(
'', '',
...@@ -1598,7 +1605,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1598,7 +1605,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_mm, var_bi0], [var_mm, var_bi0],
[var_fc], [var_fc],
{'axis': 1}, # {'axis': 1}, #
name=(name + '_bias'), name=(name + '/bias'),
) )
if var_xh: if var_xh:
var_xh0 = var_xh + '_0' # explicit variable var_xh0 = var_xh + '_0' # explicit variable
...@@ -1608,7 +1615,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1608,7 +1615,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_xh], [var_xh],
[var_xh0], [var_xh0],
{'axes': [1]}, # index on n {'axes': [1]}, # index on n
name=(var_xh + '_index'), name=(var_xh + '/index'),
) )
if var_xc: if var_xc:
var_xc0 = var_xc + '_0' # explicit variable var_xc0 = var_xc + '_0' # explicit variable
...@@ -1618,7 +1625,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1618,7 +1625,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_xc], [var_xc],
[var_xc0], [var_xc0],
{'axes': [1]}, # index on n {'axes': [1]}, # index on n
name=(var_xc + '_index'), name=(var_xc + '/index'),
) )
var_bhp = var_p var_bhp = var_p
if var_b: if var_b:
...@@ -1630,7 +1637,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1630,7 +1637,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_bh, var_p], [var_bh, var_p],
[var_bhp], [var_bhp],
{'axes': [1]}, # cat on x {'axes': [1]}, # cat on x
name=(name + '_concat'), name=(name + '/concat'),
) )
else: else:
var_bhp = var_bh var_bhp = var_bh
...@@ -1690,7 +1697,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1690,7 +1697,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_yh0], [var_yh0],
[var_y], # var_yh [var_y], # var_yh
{'axes': [1, 1]}, # extrude on dn {'axes': [1, 1]}, # extrude on dn
name=(var_y + '_reshape'), name=(var_y + '/reshape'),
) )
if var_yc: if var_yc:
prog.Op( prog.Op(
...@@ -1699,7 +1706,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1699,7 +1706,7 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_yc0], [var_yc0],
[var_yc], [var_yc],
{'axes': [1, 1]}, # extrude on dn {'axes': [1, 1]}, # extrude on dn
name=(var_yc + '_reshape'), name=(var_yc + '/reshape'),
) )
...@@ -1811,12 +1818,12 @@ def PRelu(prog, ...@@ -1811,12 +1818,12 @@ def PRelu(prog,
mode = 'element' mode = 'element'
fluid_op = 'prelu' fluid_op = 'prelu'
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params: embeddable = _check_embeddable(value_infos, var_slope)
embed_params = _check_embeddable(value_infos, var_slope) if not embeddable:
if not embed_params and name: _logger.warning('for op %s(%s -> PRelu -> %s)', name, inputs, outputs)
_logger.warning('for op %s(%s -> PRelu -> %s)', name, inputs, _logger.warning('one of the parameters is intermediate value')
outputs)
_logger.warning('broken Python code will be generated') _logger.warning('broken Python code will be generated')
embed_params &= embeddable
if embed_params: if embed_params:
assert name != '' assert name != ''
embedded_slope = name + '.w_0' embedded_slope = name + '.w_0'
...@@ -1880,12 +1887,20 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1880,12 +1887,20 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'input "shape" not inferred, use [1, -1] as dummy value, ' 'input "shape" not inferred, use [1, -1] as dummy value, '
'the behavior of Paddle fluid maybe undefined', name, inputs, 'the behavior of Paddle fluid maybe undefined', name, inputs,
outputs) outputs)
shape_dtype = _dtype_or_none(value_infos, var_shape)
if shape_dtype is None:
_logger.warning(
'in op %s(%s -> Reshape -> %s): '
'dtype of input "shape" not inferred, int32 assumed', name, inputs,
outputs)
shape_dtype = _np.dtype('int32')
fluid_op = 'reshape' fluid_op = 'reshape'
name_attr = ', name={}'.format(repr(name)) name_attr = ', name={}'.format(repr(name))
# generation # generation
var_shape_int32 = var_shape + '_int32' # explicit variable var_shape_int32 = var_shape + ('_int32' if shape_dtype != _np.int32 else ''
prog.Code('# shape:{}={} # const as literal'.format(var_shape, shape)) ) # explicit variable
prog.Code('# shape: {} = {} # const as literal'.format(var_shape, shape))
if is_const_shape: if is_const_shape:
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
', shape={}' ', shape={}'
...@@ -1898,6 +1913,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1898,6 +1913,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
name_attr, name_attr,
)) ))
else: else:
if shape_dtype != _np.int32:
prog.Op( prog.Op(
'', '',
'Cast', 'Cast',
...@@ -1905,7 +1921,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1905,7 +1921,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_shape_int32], [var_shape_int32],
{'to': _np.dtype('int32')}, # use np.dtype {'to': _np.dtype('int32')}, # use np.dtype
value_infos=value_infos, value_infos=value_infos,
name=(name + '_cast'), name=(name + '/cast'),
) )
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
', shape={}' ', shape={}'
...@@ -2121,7 +2137,8 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): ...@@ -2121,7 +2137,8 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
# generation # generation
prog.Code('# repeats:{}={} # const as literal'.format(var_repeats, repeats)) prog.Code('# repeats: {} = {} # const as literal'.format(
var_repeats, repeats))
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
', expand_times={}' ', expand_times={}'
'{})'.format( '{})'.format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册