Commit e22b117b authored by Channingss

delete GRU

Parent 896be706
@@ -32,11 +32,14 @@ import shutil
 _logger = _logging.getLogger(__name__)
 
 
-def _const_weight_or_none(node):
+def _const_weight_or_none(node, necessary=False):
     if 'Constant' in node.layer_type:
         return node.value
     if isinstance(node, ONNXGraphDataNode):
         return node.weight
+    if necessary:
+        assert False, '{} should be an initializer or Constant operator.'.format(
+            node.layer_name)
     return None
 
 
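The `necessary` flag turns a silent `None` into a hard failure for inputs the converter must be able to read at conversion time. A standalone sketch of that contract, assuming nothing from x2paddle (the `Fake*` classes are hypothetical stand-ins, not real node types):

```python
# Minimal sketch of the `necessary` contract; Fake* types are stand-ins,
# not real x2paddle node classes.
class FakeConstNode:
    layer_type = 'Constant'
    layer_name = 'const_0'
    value = [0, 1]

class FakeRuntimeNode:
    layer_type = 'Relu'
    layer_name = 'relu_0'

def const_weight_or_none(node, necessary=False):
    if 'Constant' in node.layer_type:
        return node.value
    if necessary:
        # Fail fast: a runtime-computed tensor cannot be folded away.
        raise AssertionError(
            '{} should be an initializer or Constant operator.'.format(
                node.layer_name))
    return None

assert const_weight_or_none(FakeConstNode()) == [0, 1]
assert const_weight_or_none(FakeRuntimeNode()) is None  # optional input
try:
    const_weight_or_none(FakeRuntimeNode(), necessary=True)
except AssertionError as exc:
    print(exc)  # relu_0 should be an initializer or Constant operator.
```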
@@ -724,10 +727,10 @@ class OpSet9():
         ends = self.graph.get_input_node(node, idx=2, copy=True)
         if len(node.inputs) > 3:
             axes = self.graph.get_input_node(node, idx=3, copy=True)
-            axes = _const_weight_or_none(axes)
+            axes = _const_weight_or_none(axes, necessary=True)
         if len(node.inputs) > 4:
             steps = self.graph.get_input_node(node, idx=4, copy=True)
-            steps = _const_weight_or_none(steps)
+            steps = _const_weight_or_none(steps, necessary=True)
         if steps is not None:
             assert steps == 1, "Only support convert op:Slice, which attribute:steps == 1"
         attr = {
@@ -735,8 +738,8 @@ class OpSet9():
             "starts": starts.layer_name,
             "ends": ends.layer_name
         }
-        starts_value = _const_weight_or_none(starts)
-        ends_value = _const_weight_or_none(ends)
+        starts_value = _const_weight_or_none(starts, necessary=True)
+        ends_value = _const_weight_or_none(ends, necessary=True)
         if starts_value is not None and ends_value is not None:
             self.omit_nodes.append(starts.layer_name)
             self.omit_nodes.append(ends.layer_name)
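Both Slice hunks lean on the same invariant: in later opsets `starts`/`ends`/`axes`/`steps` arrive as inputs rather than attributes, and the converter can only emit a static slice when those inputs are constants it can fold at conversion time, hence `necessary=True`. A numpy-only sketch of the restricted semantics being targeted (illustrative names, not x2paddle API):

```python
import numpy as np

def convert_slice(data, starts, ends, axes=None, steps=None):
    # Emulate the only Slice form the converter accepts: constant
    # starts/ends/axes and unit steps.
    if steps is not None:
        assert all(s == 1 for s in np.atleast_1d(steps)), \
            'only steps == 1 is supported'
    axes = range(len(starts)) if axes is None else axes
    index = [slice(None)] * data.ndim
    for axis, start, end in zip(axes, starts, ends):
        index[axis] = slice(start, end)
    return data[tuple(index)]

x = np.arange(12).reshape(3, 4)
print(convert_slice(x, starts=[0], ends=[2], axes=[0]))  # first two rows
```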
@@ -1171,7 +1174,6 @@ class OpSet9():
     def NonZero(self, node):
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
         val_x_dim = len(val_x.out_shapes[0])
-        print(val_x.layer_name, val_x.out_shapes[0])
         if val_x_dim == 1:
             node.fluid_code.add_layer("nonzero", inputs=val_x, output=val_x)
             node.fluid_code.add_layer(
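Besides dropping the stray debug `print`, the surrounding branch matters because ONNX `NonZero` always returns an int64 tensor of shape `[rank, num_nonzero]`, so a rank-1 input still needs a leading axis. A numpy stand-in for that output contract:

```python
import numpy as np

def onnx_nonzero(x):
    # ONNX NonZero: indices stacked into shape [rank, num_nonzero].
    return np.stack(np.nonzero(x)).astype(np.int64)

print(onnx_nonzero(np.array([0, 3, 0, 5])))      # [[1 3]]           rank 1
print(onnx_nonzero(np.array([[1, 0], [0, 2]])))  # [[0 1] [0 1]]     rank 2
```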
@@ -1293,13 +1295,13 @@ class OpSet9():
         kernel_shape = node.get_attr('kernel_shape')
         convnd = len(kernel_shape)
         assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported'
-        num_out_channels = val_w.out_shapes[0][0]  # OI...
+        num_out_channels = val_w.out_shapes[0][0]
         fluid_op = 'conv{}d'.format(convnd)
 
         num_groups = node.get_attr('group', 1)
-        strides = node.get_attr('strides', [1] * convnd)  # optional
-        dilations = node.get_attr('dilations', [1] * convnd)  # optional
-        pads = node.get_attr('pads', [0] * (convnd * 2))  # optional
+        strides = node.get_attr('strides', [1] * convnd)
+        dilations = node.get_attr('dilations', [1] * convnd)
+        pads = node.get_attr('pads', [0] * (convnd * 2))
 
         input_shape = val_x.out_shapes[0]
         paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
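The defaults above follow the ONNX `Conv` spec: strides and dilations of 1 per spatial dimension, zero pads, with `pads` laid out as `[begin_0, ..., begin_n, end_0, ..., end_n]`; `_pad_if_asymmetric` only has to intervene when a begin differs from its matching end. A small illustrative helper (not the real x2paddle function) showing the same attribute handling:

```python
def conv_attrs(get_attr, convnd):
    # Same defaults as the hunk above; `get_attr(name, default)` can be
    # any lookup with dict.get semantics.
    strides = get_attr('strides', [1] * convnd)
    dilations = get_attr('dilations', [1] * convnd)
    pads = get_attr('pads', [0] * (convnd * 2))
    begins, ends = pads[:convnd], pads[convnd:]
    symmetric = begins == ends  # asymmetric pads need an explicit pad op
    return strides, dilations, pads, symmetric

attrs = {'strides': [2, 2], 'pads': [1, 1, 1, 1]}
print(conv_attrs(attrs.get, convnd=2))
# ([2, 2], [1, 1], [1, 1, 1, 1], True)
```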
@@ -1379,183 +1381,3 @@ class OpSet9():
         }
         node.fluid_code.add_layer(
             fluid_op, inputs=val_x, output=node, param_attr=attr)
-
-    @print_mapping_info
-    def GRU(self, node):
-        val_x = self.graph.get_input_node(node, idx=0, copy=True)
-        val_w = self.graph.get_input_node(node, idx=1, copy=True)
-        val_r = self.graph.get_input_node(node, idx=2, copy=True)
-
-        val_b = None
-        val_len = None
-        val_xh = None
-        miss_arg_num = 0
-        num_ipt = len(node.layer.input)
-        if num_ipt > 3 and node.layer.input[3] != '':
-            val_b = self.graph.get_input_node(node, idx=3, copy=True)
-        else:
-            miss_arg_num += 1
-        if num_ipt > 4 and node.layer.input[4] != '':
-            val_len = self.graph.get_input_node(
-                node, idx=4 - miss_arg_num, copy=True)
-        else:
-            miss_arg_num += 1
-        if num_ipt > 5 and node.layer.input[5] != '':
-            val_xh = self.graph.get_input_node(
-                node, idx=5 - miss_arg_num, copy=True)
-
-        x_shape = val_x.out_shapes[0]
-        assert x_shape[1] == 1, 'only X with batch_size = 1 supported'
-        assert node.get_attr('clip', None) is None, 'clipping not supported'
-
-        hidden_size = node.get_attr('hidden_size', None)
-        if hidden_size is None:
-            r_shape = val_r.out_shapes[0]
-            if r_shape:
-                hidden_size = r_shape[-1]
-        if hidden_size is None:
-            w_shape = val_w.out_shapes[0]
-            if w_shape:
-                hidden_size = w_shape[-2] // 3
-        if hidden_size is None and val_b:
-            b_shape = val_b.out_shapes[0]
-            if b_shape:
-                hidden_size = b_shape[-1] // 6
-        if hidden_size is None and val_xh:
-            xh_shape = val_xh.out_shapes[0]
-            if xh_shape:
-                hidden_size = xh_shape[-1]
-
-        direction = node.get_attr('direction', 'forward')
-        assert direction != 'bidirectional', 'direction = bidirectional not supported'
-
-        activations = node.get_attr('activations', ['Sigmoid', 'Tanh'])
-        assert len(activations) == 2, 'bidirectional operation not supported'
-        assert node.get_attr('linear_before_reset',
-                             0) == 0, 'only linear_before_reset = 0 supported'
-
-        activations = [s.lower() for s in activations]
-        gate_activation, candidate_activation = activations
-        is_reverse = direction == 'reverse'
-
-        var_x0 = node.layer_name + '_x0'
-        node.fluid_code.add_layer(
-            'squeeze',
-            inputs=val_x,
-            output=var_x0,
-            param_attr={'axes': [1],
-                        'name': string(var_x0)})
-
-        var_w0 = node.layer_name + '_w0'
-        node.fluid_code.add_layer(
-            'squeeze',
-            inputs=val_w,
-            output=var_w0,
-            param_attr={'axes': [0],
-                        'name': string(var_w0)})
-
-        var_fc = node.layer_name + '_fc'
-        var_mm = (node.layer_name + '_mm') if val_b else var_fc
-        node.fluid_code.add_layer(
-            'matmul',
-            inputs={'x': var_x0,
-                    'y': var_w0},
-            output=var_mm,
-            param_attr={
-                'transpose_x': 0,
-                'transpose_y': 1,
-                'name': string(var_mm)
-            })
-
-        var_r0 = node.layer_name + '_r0'
-        node.fluid_code.add_layer(
-            'squeeze',
-            inputs=val_r,
-            output=var_r0,
-            param_attr={'axes': [0],
-                        'name': string(var_r0)})
-
-        var_r0t = node.layer_name + '_r0t'
-        node.fluid_code.add_layer(
-            'transpose',
-            inputs=var_r0,
-            output=var_r0t,
-            param_attr={'perm': [1, 0],
-                        'name': string(var_r0t)})
-        if val_b:
-            var_bi = node.layer_name + '_bi'
-            var_bh = node.layer_name + '_bh'
-            node.fluid_code.add_layer(
-                'split',
-                inputs=val_b,
-                output=var_bi + ',' + var_bh,
-                param_attr={
-                    'dim': 1,
-                    'num_or_sections': [hidden_size * 3, hidden_size * 3],
-                    'name': string(node.layer_name + '.b/split')
-                })
-
-            var_bi0 = node.layer_name + '_bi0'
-            node.fluid_code.add_layer(
-                'squeeze',
-                inputs=var_bi,
-                output=var_bi0,
-                param_attr={'axes': [0],
-                            'name': string(var_bi0)})
-
-            node.fluid_code.add_layer(
-                'elementwise_add',
-                inputs=[var_mm, var_bi0],
-                output=var_fc,
-                param_attr={
-                    'axes': 1,
-                    'name': string(node.layer_name + '.i/bias')
-                })
-
-        if val_xh:
-            var_xh0 = node.layer_name + '_xh0'
-            node.fluid_code.add_layer(
-                'squeeze',
-                inputs=val_xh,
-                output=var_xh0,
-                param_attr={'axes': [1],
-                            'name': string(var_xh0)})
-
-        var_y00 = node.layer_name + '_y00'
-
-        attr = {
-            'origin_mode': True,
-            'h_0': var_xh0 if val_xh else None,
-            'is_reverse': is_reverse,
-            'gate_activation': string(gate_activation),
-            'candidate_activation': string(candidate_activation),
-            'param_attr': string(var_r0t),
-            'bias_attr': string(var_bh) if val_b else False,
-        }
-        node.fluid_code.add_layer(
-            'dynamic_gru',
-            inputs=var_fc + ',' + str(hidden_size),
-            output=var_y00,
-            param_attr=attr)
-
-        num_opt = len(node.layer.output)
-        if num_opt > 0 and node.layer.output[0] != '':
-            node.fluid_code.add_layer(
-                'unsqueeze',
-                inputs=var_y00,
-                output=node.layer.output[0],
-                param_attr={
-                    'axes': [1, 1],
-                    'name': string(node.layer.output[0])
-                })
-        if num_opt > 1 and node.layer.output[1] != '':
-            node.fluid_code.add_layer(
-                'unsqueeze',
-                inputs=var_y00,
-                output=node.layer.output[1],
-                param_attr={
-                    'axes': [1, 1],
-                    'name': string(node.layer.output[1])
-                })
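The deleted mapping inferred `hidden_size` from whichever GRU tensor happened to carry it, using the shape layout in the ONNX spec (`W`: `[num_directions, 3*hidden_size, input_size]`, `R`: `[num_directions, 3*hidden_size, hidden_size]`, `B`: `[num_directions, 6*hidden_size]`, `initial_h`: `[num_directions, batch_size, hidden_size]`). A plain-Python sketch of that fallback chain, with illustrative shape values:

```python
def infer_gru_hidden_size(attr_hidden_size=None, r_shape=None,
                          w_shape=None, b_shape=None, xh_shape=None):
    # Fallback order mirrors the deleted code: attribute, then R, W, B,
    # and finally initial_h.
    if attr_hidden_size is not None:
        return attr_hidden_size
    if r_shape:                   # R: [..., 3*hidden, hidden]
        return r_shape[-1]
    if w_shape:                   # W: [..., 3*hidden, input]
        return w_shape[-2] // 3
    if b_shape:                   # B: [..., 6*hidden]
        return b_shape[-1] // 6
    if xh_shape:                  # initial_h: [..., batch, hidden]
        return xh_shape[-1]
    return None

assert infer_gru_hidden_size(r_shape=[1, 96, 32]) == 32
assert infer_gru_hidden_size(w_shape=[1, 96, 64]) == 32
assert infer_gru_hidden_size(b_shape=[1, 192]) == 32
```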