未验证 提交 315a3057 编写于 作者: C channings 提交者: GitHub

Update onnx_op_mapper.py

上级 d6f2b463
...@@ -36,7 +36,7 @@ _logger = _logging.getLogger(__name__) ...@@ -36,7 +36,7 @@ _logger = _logging.getLogger(__name__)
def _const_weight_or_none(node): def _const_weight_or_none(node):
if 'Constant' in node.layer_name: if 'Constant' in node.layer_type:
return node.value return node.value
if isinstance(node, ONNXGraphDataNode): if isinstance(node, ONNXGraphDataNode):
return node.weight return node.weight
...@@ -57,8 +57,7 @@ class ONNXOpMapper(OpMapper): ...@@ -57,8 +57,7 @@ class ONNXOpMapper(OpMapper):
'Div': 'elementwise_div', 'Div': 'elementwise_div',
'Sub': 'elementwise_sub', 'Sub': 'elementwise_sub',
'Mul': 'elementwise_mul', 'Mul': 'elementwise_mul',
'Pow': 'elementwise_pow', 'Pow': 'elementwise_pow',}
}
def __init__(self, decoder, save_dir): def __init__(self, decoder, save_dir):
super(ONNXOpMapper, self).__init__() super(ONNXOpMapper, self).__init__()
...@@ -166,8 +165,8 @@ class ONNXOpMapper(OpMapper): ...@@ -166,8 +165,8 @@ class ONNXOpMapper(OpMapper):
for opt in layer.output: for opt in layer.output:
if opt in value_infos: if opt in value_infos:
value_info = value_infos[opt] value_info = value_infos[opt]
if len(value_info['shape']) == 0 or value_info[ if len(value_info['shape']
'dtype'] is None or 0 in value_info['shape']: ) == 0 or value_info['dtype'] is None or 0 in value_info['shape']:
if self.is_inference == False: if self.is_inference == False:
self.get_results_of_inference( self.get_results_of_inference(
onnx_model, value_infos, onnx_model, value_infos,
...@@ -264,7 +263,6 @@ class ONNXOpMapper(OpMapper): ...@@ -264,7 +263,6 @@ class ONNXOpMapper(OpMapper):
if child_func_code is not None: if child_func_code is not None:
self.used_custom_layers[op + self.used_custom_layers[op +
'_child_func'] = child_func_code '_child_func'] = child_func_code
def elementwise_map(self, node): def elementwise_map(self, node):
assert node.layer_type in self.elementwise_ops assert node.layer_type in self.elementwise_ops
op_type = self.elementwise_ops[node.layer_type] op_type = self.elementwise_ops[node.layer_type]
...@@ -274,7 +272,7 @@ class ONNXOpMapper(OpMapper): ...@@ -274,7 +272,7 @@ class ONNXOpMapper(OpMapper):
val_y_shape = val_y.out_shapes[0] val_y_shape = val_y.out_shapes[0]
val_x_shape = val_x.out_shapes[0] val_x_shape = val_x.out_shapes[0]
if len(val_x_shape) < len(val_y_shape): if len(val_x_shape)<len(val_y_shape):
val_x, val_y = val_y, val_x val_x, val_y = val_y, val_x
str_y_shape = ','.join(str(e) for e in val_y_shape) str_y_shape = ','.join(str(e) for e in val_y_shape)
...@@ -312,9 +310,16 @@ class ONNXOpMapper(OpMapper): ...@@ -312,9 +310,16 @@ class ONNXOpMapper(OpMapper):
def place_holder(self, node): def place_holder(self, node):
self.input_shapes.append(node.out_shapes[0]) self.input_shapes.append(node.out_shapes[0])
shape = node.out_shapes[0]
for i, dim_shape in enumerate(shape):
if dim_shape==0 and i==0:
shape[i]=1
if dim_shape==0 and i!=0:
assert 'shape of input is not assigned'
attr = { attr = {
"dtype": string(node.dtype), "dtype": string(node.dtype),
"shape": node.out_shapes[0], "shape": shape,
"name": string(node.layer_name), "name": string(node.layer_name),
"append_batch_size": 'False' "append_batch_size": 'False'
} }
...@@ -371,7 +376,7 @@ class ONNXOpMapper(OpMapper): ...@@ -371,7 +376,7 @@ class ONNXOpMapper(OpMapper):
if isinstance(val_scales, ONNXGraphNode): if isinstance(val_scales, ONNXGraphNode):
scales, _, _ = self.get_dynamic_shape(val_scales.layer_name) scales, _, _ = self.get_dynamic_shape(val_scales.layer_name)
attr = {'name': string(node.layer_name)} attr = { 'name': string(node.layer_name)}
use_scales = True use_scales = True
if scales is not None: if scales is not None:
try: try:
...@@ -381,7 +386,7 @@ class ONNXOpMapper(OpMapper): ...@@ -381,7 +386,7 @@ class ONNXOpMapper(OpMapper):
assert scales[2] == scales[ assert scales[2] == scales[
3], 'only aspect-ratio-invariant scale supported' 3], 'only aspect-ratio-invariant scale supported'
except: except:
use_scales = False use_scales=False
scale = scales[2] if scales else None scale = scales[2] if scales else None
if scale is None: if scale is None:
assert out_shape, 'neither scales nor output shape is available' assert out_shape, 'neither scales nor output shape is available'
...@@ -397,9 +402,7 @@ class ONNXOpMapper(OpMapper): ...@@ -397,9 +402,7 @@ class ONNXOpMapper(OpMapper):
fluid_op = 'resize_{}'.format(mode) fluid_op = 'resize_{}'.format(mode)
if 'linear' in mode: if 'linear' in mode:
print( print('Warnning: paddle not support op:resize wiht mode: linear, we use bilinear replace linear')
'Warnning: paddle not support op:resize wiht mode: linear, we use bilinear replace linear'
)
fluid_op = 'resize_bilinear' fluid_op = 'resize_bilinear'
if use_scales and scale is not None: if use_scales and scale is not None:
...@@ -411,6 +414,40 @@ class ONNXOpMapper(OpMapper): ...@@ -411,6 +414,40 @@ class ONNXOpMapper(OpMapper):
inputs=val_x, inputs=val_x,
output=node, output=node,
param_attr=attr) param_attr=attr)
def RoiAlign(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_rois = self.graph.get_input_node(node, idx=1, copy=True)
pooled_height = node.get_attr('output_height')
pooled_width = node.get_attr('output_width')
spatial_scale = node.get_attr('spatial_scale')
sampling_ratio = node.get_attr('sampling_ratio')
attr = {
'pooled_height': pooled_height,
'pooled_width': pooled_width,
'spatial_scale': spatial_scale,
'sampling_ratio':sampling_ratio,
}
node.fluid_code.add_layer('roi_align',
inputs={'input':val_x,'rois':val_rois},
output=node,
param_attr=attr)
def MaxRoiPool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_rois = self.graph.get_input_node(node, idx=1, copy=True)
spatial_scale = node.get_attr('spatial_scale')
pooled_height, pooled_width = node.get_attr('pooled_shape')
attr = {
'pooled_height': pooled_height,
'pooled_width': pooled_width,
'spatial_scale': spatial_scale,
}
node.fluid_code.add_layer('roi_pool',
inputs={'input':val_x,'rois':val_rois},
output=node,
param_attr=attr)
def Pad(self, node, op_independent=True): def Pad(self, node, op_independent=True):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
...@@ -462,7 +499,8 @@ class ONNXOpMapper(OpMapper): ...@@ -462,7 +499,8 @@ class ONNXOpMapper(OpMapper):
def Unsqueeze(self, node): def Unsqueeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes') axes = node.get_attr('axes')
if len(val_x.out_shapes[0]) == 0: print(val_x.outputs)
if len(val_x.out_shapes[0])==0:
node.fluid_code.add_layer('assign', node.fluid_code.add_layer('assign',
inputs=val_x, inputs=val_x,
output=node, output=node,
...@@ -542,25 +580,29 @@ class ONNXOpMapper(OpMapper): ...@@ -542,25 +580,29 @@ class ONNXOpMapper(OpMapper):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_shape = self.graph.get_input_node(node, idx=1, copy=True) val_shape = self.graph.get_input_node(node, idx=1, copy=True)
if len(val_shape.outputs) == 1: if len(val_shape.outputs)==1:
self.omit_nodes.append(val_shape.layer_name) self.omit_nodes.append(val_shape.layer_name)
val_y = self.graph.get_node(node.layer.output[0], copy=True) val_y = self.graph.get_node(node.layer.output[0], copy=True)
out_shape = node.out_shapes[0] out_shape = node.out_shapes[0]
val_x_dtype = val_x.dtype val_x_dtype = val_x.dtype
name_ones = node.layer_name + '_ones' name_ones= node.layer_name + '_ones'
attr_ones = {'shape': out_shape, 'dtype': string(val_x_dtype)} attr_ones = {
'shape':out_shape,
'dtype':string(val_x_dtype)
}
node.fluid_code.add_layer('ones', node.fluid_code.add_layer('ones',
inputs=None, inputs=None,
output=name_ones, output=name_ones,
param_attr=attr_ones) param_attr=attr_ones)
inputs = {'x': name_ones, 'y': val_x} inputs = {'x':name_ones,'y':val_x}
attr = {'name': string(node.layer_name)} attr = {'name':string(node.layer_name)}
node.fluid_code.add_layer('elementwise_mul', node.fluid_code.add_layer('elementwise_mul',
inputs=inputs, inputs=inputs,
output=node.layer_name, output=node.layer_name,
param_attr=attr) param_attr=attr
)
def Gather(self, node): def Gather(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
...@@ -569,7 +611,7 @@ class ONNXOpMapper(OpMapper): ...@@ -569,7 +611,7 @@ class ONNXOpMapper(OpMapper):
axis = node.get_attr('axis', 0) axis = node.get_attr('axis', 0)
assert len( assert len(
indices_shape) <= 2, "Gather op don't support dim of indice >2 " indices_shape) <= 2, "Gather op don't support dim of indice >2 "
if axis == 0 and len(indices_shape) <= 1: if axis==0 and len(indices_shape)<=1:
node.fluid_code.add_layer('gather', node.fluid_code.add_layer('gather',
inputs={ inputs={
'input': val_x, 'input': val_x,
...@@ -597,15 +639,13 @@ class ONNXOpMapper(OpMapper): ...@@ -597,15 +639,13 @@ class ONNXOpMapper(OpMapper):
inputs=node, inputs=node,
output=node, output=node,
param_attr=attr_trans) param_attr=attr_trans)
elif len(indices_shape) > 1: elif len(indices_shape)>1:
from functools import reduce from functools import reduce
reshape_shape = reduce(lambda x, y: x * y, indices_shape) reshape_shape = reduce(lambda x,y:x*y, indices_shape)
node.fluid_code.add_layer('reshape', node.fluid_code.add_layer('reshape',
inputs=indices, inputs=indices,
output=indices, output=indices,
param_attr={'shape': [ param_attr={'shape':[reshape_shape,]})
reshape_shape,
]})
perm = list(range(len(val_x.out_shapes[0]))) perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:] perm = [axis] + perm[:axis] + perm[axis + 1:]
...@@ -635,26 +675,27 @@ class ONNXOpMapper(OpMapper): ...@@ -635,26 +675,27 @@ class ONNXOpMapper(OpMapper):
node.fluid_code.add_layer('reshape', node.fluid_code.add_layer('reshape',
inputs=node, inputs=node,
output=node, output=node,
param_attr={'shape': reshaped_shape}) param_attr={'shape':reshaped_shape})
def Slice(self, node): def Slice(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_starts, val_ends, val_axes, val_steps = None, None, None, None starts, ends, axes, steps = None, None, None, None
if len(node.inputs) > 1: if len(node.inputs) > 1:
starts = self.graph.get_input_node(node, idx=1, copy=True) starts = self.graph.get_input_node(node, idx=1, copy=True)
ends = self.graph.get_input_node(node, idx=2, copy=True) ends = self.graph.get_input_node(node, idx=2, copy=True)
if len(node.inputs)>3:
axes = self.graph.get_input_node(node, idx=3, copy=True) axes = self.graph.get_input_node(node, idx=3, copy=True)
self.omit_nodes.append(axes.layer_name)
axes = _const_weight_or_none(axes)
if len(node.inputs)>4:
steps = self.graph.get_input_node(node, idx=4, copy=True) steps = self.graph.get_input_node(node, idx=4, copy=True)
self.omit_nodes.append(steps.layer_name)
steps = _const_weight_or_none(steps)
self.omit_nodes.append(starts.layer_name) self.omit_nodes.append(starts.layer_name)
self.omit_nodes.append(ends.layer_name) self.omit_nodes.append(ends.layer_name)
self.omit_nodes.append(axes.layer_name) starts = _const_weight_or_none(starts)
self.omit_nodes.append(steps.layer_name) ends = _const_weight_or_none(ends)
starts = _const_weight_or_none(starts).copy()
ends = _const_weight_or_none(ends).copy()
axes = _const_weight_or_none(axes)
steps = _const_weight_or_none(steps)
else: else:
starts = node.get_attr('starts') starts = node.get_attr('starts')
ends = node.get_attr('ends') ends = node.get_attr('ends')
...@@ -735,15 +776,16 @@ class ONNXOpMapper(OpMapper): ...@@ -735,15 +776,16 @@ class ONNXOpMapper(OpMapper):
if isinstance(val_shape, ONNXGraphNode): if isinstance(val_shape, ONNXGraphNode):
shape, _, _ = self.get_dynamic_shape(val_shape.layer_name) shape, _, _ = self.get_dynamic_shape(val_shape.layer_name)
if val_shape.dtype == 'int64': if val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast' val_shape_cast = val_shape.layer_name+'_cast'
node.fluid_code.add_layer('cast', node.fluid_code.add_layer('cast',
inputs=val_shape, inputs=val_shape,
output=val_shape_cast, output=val_shape_cast,
param_attr={'dtype': string('int32')}) param_attr={'dtype':string('int32')})
attr['actual_shape'] = val_shape_cast attr['actual_shape'] = val_shape_cast
else: else:
attr['actual_shape'] = val_shape attr['actual_shape'] = val_shape
if shape is None: if shape is None:
shape = val_reshaped.out_shapes[0] shape = val_reshaped.out_shapes[0]
...@@ -992,10 +1034,72 @@ class ONNXOpMapper(OpMapper): ...@@ -992,10 +1034,72 @@ class ONNXOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
def Equal(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
node.fluid_code.add_layer("equal",
inputs={'x':val_x, 'y':val_y},
output=node,
param_attr=None)
def Where(self, node):
condition = self.graph.get_input_node(node, idx=0, copy=True)
val_x = self.graph.get_input_node(node, idx=1, copy=True)
val_y = self.graph.get_input_node(node, idx=2, copy=True)
not_condition = condition.layer_name + '_not'
node.fluid_code.add_layer("logical_not",
inputs=condition,
output=not_condition,
param_attr=None)
cast_not_condition = not_condition+'_cast'
node.fluid_code.add_layer("cast",
inputs=not_condition,
output=cast_not_condition,
param_attr={'dtype':string(val_x.dtype)})
cast_condition = condition.layer_name + '_cast'
node.fluid_code.add_layer("cast",
inputs=condition,
output=cast_condition,
param_attr={'dtype':string(val_x.dtype)})
mul_val_x = val_x.layer_name + '_mul'
node.fluid_code.add_layer("elementwise_mul",
inputs={'x':val_x,'y':cast_condition},
output=mul_val_x,
param_attr=None)
mul_val_y = val_y.layer_name + '_mul'
node.fluid_code.add_layer("elementwise_mul",
inputs={'x':val_y,'y':cast_not_condition},
output=mul_val_y,
param_attr=None)
node.fluid_code.add_layer("elementwise_add",
inputs={'x':mul_val_x,'y':mul_val_y},
output=node,
param_attr=None)
def Identity(self, node): def Identity(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
node.fluid_code.add_layer("assign", inputs=val_x, output=node) node.fluid_code.add_layer("assign", inputs=val_x, output=node)
def Tile(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_repeats = self.graph.get_input_node(node, idx=1, copy=True)
repeats = _const_weight_or_none(val_repeats)
assert repeats is not None, 'for OP:Tile, only const repeats supported'
if isinstance(repeats, int):
repeats = [repeats]
attr = {
'expand_times':repeats,
"name": string(node.layer_name),
}
node.fluid_code.add_layer("expand",
inputs=val_x,
output=node,
param_attr=attr)
def MaxPool(self, node): def MaxPool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
...@@ -1036,7 +1140,7 @@ class ONNXOpMapper(OpMapper): ...@@ -1036,7 +1140,7 @@ class ONNXOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
def GlobalAveragePool(self, node): def _global_pool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_node(node.layer.output[0], copy=True) val_y = self.graph.get_node(node.layer.output[0], copy=True)
input_shape = val_x.out_shapes[0] input_shape = val_x.out_shapes[0]
...@@ -1048,8 +1152,15 @@ class ONNXOpMapper(OpMapper): ...@@ -1048,8 +1152,15 @@ class ONNXOpMapper(OpMapper):
poolnd = len(output_shape) - 2 # NC... poolnd = len(output_shape) - 2 # NC...
assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported' assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
fluid_op = 'pool{}d'.format(poolnd) fluid_op = 'pool{}d'.format(poolnd)
pool_type = None
if node.layer.op_type == 'GlobalMaxPool':
pool_type = 'max'
elif node.layer.op_type == 'GlobalAveragePool':
pool_type = 'avg'
attr = { attr = {
"pool_type": string("avg"), "pool_type": string(pool_type),
"global_pooling": True, "global_pooling": True,
"name": string(node.layer_name) "name": string(node.layer_name)
} }
...@@ -1058,6 +1169,12 @@ class ONNXOpMapper(OpMapper): ...@@ -1058,6 +1169,12 @@ class ONNXOpMapper(OpMapper):
output=node, output=node,
param_attr=attr) param_attr=attr)
def GlobalMaxPool(self, node):
self._global_pool(node)
def GlobalAveragePool(self, node):
self._global_pool(node)
def Conv(self, node): def Conv(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True) val_w = self.graph.get_input_node(node, idx=1, copy=True)
...@@ -1162,3 +1279,149 @@ class ONNXOpMapper(OpMapper): ...@@ -1162,3 +1279,149 @@ class ONNXOpMapper(OpMapper):
inputs=val_x, inputs=val_x,
output=node, output=node,
param_attr=attr) param_attr=attr)
def GRU(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_w = self.graph.get_input_node(node, idx=1, copy=True)
val_r = self.graph.get_input_node(node, idx=2, copy=True)
val_b = None
val_len = None
val_xh = None
miss_arg_num = 0
num_ipt = len(node.layer.input)
if num_ipt>3 and node.layer.input[3] != '':
val_b = self.graph.get_input_node(node, idx=3, copy=True)
else:
miss_arg_num += 1
if num_ipt>4 and node.layer.input[4] != '':
val_len = self.graph.get_input_node(node, idx=4-miss_arg_num, copy=True)
else:
miss_arg_num += 1
if num_ipt>5 and node.layer.input[5] != '':
val_xh = self.graph.get_input_node(node, idx=5-miss_arg_num, copy=True)
data, dtype, shape = self.get_dynamic_shape(val_x.layer_name)
x_shape = val_x.out_shapes[0]
assert x_shape[1] == 1, 'only X with batch_size = 1 supported'
assert node.get_attr('clip', None) is None, 'clipping not supported'
hidden_size = node.get_attr('hidden_size', None)
if hidden_size is None:
r_shape = val_r.out_shapes[0]
if r_shape:
hidden_size = r_shape[-1]
if hidden_size is None:
w_shape = var_w.out_shapes[0]
if w_shape:
hidden_size = w_shape[-2] // 3
if hidden_size is None and val_b:
b_shape = val_b.out_shapes[0]
if b_shape:
hidden_size = b_shape[-1] // 6
if hidden_size is None and val_xh:
xh_shape = val_xh.out_shapes[0]
if xh_shape:
hidden_size = xh_shape[-1]
direction = node.get_attr('direction', 'forward')
assert direction != 'bidirectional', 'direction = bidirectional not supported'
activations = node.get_attr('activations', ['Sigmoid', 'Tanh'])
assert len(activations) == 2, 'bidirectional operation not supported'
assert node.get_attr(
'linear_before_reset',
0) == 0, 'only linear_before_reset = 0 supported'
activations = [s.lower() for s in activations]
gate_activation, candidate_activation = activations
is_reverse = direction == 'reverse'
var_x0 = node.layer_name + '_x0'
node.fluid_code.add_layer('squeeze',
inputs=val_x,
output=var_x0,
param_attr={'axes': [1],'name':string(var_x0)})
var_w0 = node.layer_name + '_w0'
node.fluid_code.add_layer('squeeze',
inputs=val_w,
output=var_w0,
param_attr={'axes': [0],'name':string(var_w0)})
var_fc = node.layer_name + '_fc'
var_mm = (node.layer_name + '_mm') if val_b else var_fc
node.fluid_code.add_layer('matmul',
inputs={'x':var_x0, 'y':var_w0},
output=var_mm,
param_attr={'transpose_x': 0,'transpose_y': 1,'name':string(var_mm)})
var_r0 = node.layer_name + '_r0'
node.fluid_code.add_layer('squeeze',
inputs=val_r,
output=var_r0,
param_attr={'axes': [0],'name':string(var_r0)})
var_r0t = node.layer_name + '_r0t'
node.fluid_code.add_layer('transpose',
inputs=var_r0,
output=var_r0t,
param_attr={'perm': [1, 0],'name':string(var_r0t)})
if val_b:
var_bi = node.layer_name + '_bi'
var_bh = node.layer_name + '_bh'
node.fluid_code.add_layer('split',
inputs=val_b,
output=var_bi+','+var_bh,
param_attr={'axis': 1,
'split': [hidden_size * 3, hidden_size * 3],
'name':string(node.layer_name+'.b/split')})
var_bi0 = node.layer_name + '_bi0'
node.fluid_code.add_layer('squeeze',
inputs=var_bi,
output=var_bi0,
param_attr={'axes': [0],'name':string(var_bi0)})
node.fluid_code.add_layer('elmentwise_add',
inputs=[var_mm, var_bi0],
output=var_fc,
param_attr={'axes': 1,'name':string(node.layer_name+'.i/bias')})
if val_xh:
var_xh0 = node.layer_name + '_xh0'
node.fluid_code.add_layer('squeeze',
inputs=val_xh,
output=var_xh0,
param_attr={'axes': [1],'name':string(var_xh0)})
var_y00 = node.layer_name + '_y00'
attr={
'origin_mode':True,
'h_0': var_xh0 if val_xh else None,
'is_reverse':is_reverse,
'gate_activation':string(gate_activation),
'candidate_activation':string(candidate_activation),
'param_attr':string(var_r0t),
'bias_attr':string(var_bh) if val_b else False,
}
node.fluid_code.add_layer('dynamic_gru',
inputs=var_fc +','+ str(hidden_size),
output=var_y00,
param_attr=attr)
num_opt = len(node.layer.output)
if num_opt>0 and node.layer.output[0] != '':
node.fluid_code.add_layer('unsqueeze',
inputs=var_y00,
output=node.layer.output[0],
param_attr={'axes': [1, 1],'name':string(node.layer.output[0])})
if num_opt>1 and node.layer.output[1] != '':
node.fluid_code.add_layer('unsqueeze',
inputs=var_y00,
output=node.layer.output[1],
param_attr={'axes': [1, 1],'name':string(node.layer.output[1])})
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册