Unverified commit a761b8fc, authored by Jason, committed by GitHub

Merge pull request #489 from Channingss/lstm

add lstm & mapping weight by rename to paddle's naming rule
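The weight-mapping idea behind this PR can be sketched with a small standalone snippet (illustrative only; `rename_weight`, the dict keys, and the parameter names are invented here — the PR's actual helper is `_rename_or_remove_weight` in the ONNX op mapper below). Converted weights keyed by their ONNX initializer names are re-keyed to Paddle's structured parameter names (e.g. `conv0.weight`), so `model.set_dict(params)` can match them without `use_structured_name=False`:

```python
import numpy as np

def rename_weight(weights, origin_name, target_name):
    # Re-key the ndarray from the ONNX initializer name to the Paddle
    # sublayer parameter name so set_dict() can match it by structured name.
    weights[target_name] = weights.pop(origin_name)

# Hypothetical example: one conv kernel exported from ONNX.
weights = {"conv1.W": np.zeros((8, 3, 3, 3), dtype="float32")}
rename_weight(weights, "conv1.W", "conv0.weight")
print(sorted(weights))  # ['conv0.weight']
```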
@@ -18,7 +18,7 @@ from __future__ import division
 import paddle.fluid as fluid
 import paddle
 from paddle.fluid.proto import framework_pb2
-from collections import OrderedDict
+import collections
 import numpy
 import sys
 import os
@@ -38,7 +38,7 @@ class PaddleLayer(object):
             outputs,
             list), "parameter 'outputs' for PaddleLayer should be type of list"
         for k, v in inputs.items():
-            if isinstance(v, list):
+            if isinstance(v, (list, tuple)):
                 for i in v:
                     assert isinstance(
                         i, six.string_types
@@ -66,7 +66,7 @@ class PaddleLayer(object):
 class PaddleGraph(object):
     def __init__(self, source_type=None, parent_layer=None, graph_type="static"):
-        self.layers = OrderedDict()
+        self.layers = collections.OrderedDict()
         self.edges_out = dict()
         self.edges_in = dict()
         self.inputs = list()
@@ -94,7 +94,7 @@ class PaddleGraph(object):
         self.script = script

     def clear(self):
-        self.layers = OrderedDict()
+        self.layers = collections.OrderedDict()
         self.edges_out = dict()
         self.edges_in = dict()
         self.inputs = list()
@@ -168,7 +168,7 @@ class PaddleGraph(object):
         for layer_id, layer in self.layers.items():
             for input_key, input_var in layer.inputs.items():
                 vs = input_var
-                if not isinstance(vs, list):
+                if not isinstance(vs, (list, tuple)):
                     vs = [vs]
                 for v in vs:
                     assert v in outputs_from_nodes or (
@@ -521,7 +521,7 @@ class PaddleGraph(object):
                 gen_codes(
                     comment_list,
                     indent=1))
-        use_structured_name = False if self.source_type in ["tf", "onnx"] else True
+        use_structured_name = False if self.source_type in ["tf"] else True
         self.run_func.extend(
             gen_codes(["paddle.disable_static()",
                        "params = paddle.load('{}/model.pdparams')".format(osp.abspath(code_dir)),
@@ -590,7 +590,7 @@ class PaddleGraph(object):
             elif len(layer.outputs) == 2:
                 line = layer.outputs[1]
             else:
-                if layer.kernel == "paddle.nn.LSTM":
+                if layer.kernel in ["paddle.nn.LSTM"]:
                     line = "{}, ({})".format(layer.outputs[1], ', '.join(layer.outputs[-2:]))
                 else:
                     line = ','.join(layer.outputs[1:])
@@ -599,7 +599,12 @@ class PaddleGraph(object):
                 line += " = self.{}".format(layer.outputs[0])
             else:
                 line += " = self.{}(".format(layer.outputs[0])
-            for k, v in layer.inputs.items():
-                line += "{}, ".format(v)
+            for v in layer.inputs.values():
+                if isinstance(v, list):
+                    line += "[{}], ".format(", ".join(v))
+                elif isinstance(v, tuple):
+                    line += "({}), ".format(", ".join(v))
+                else:
+                    line += "{}, ".format(v)
             line = line.strip(", ")
             line += ")"
@@ -627,6 +632,8 @@ class PaddleGraph(object):
             for k, v in layer.inputs.items():
                 if isinstance(v, list):
                     line += "{}=[{}], ".format(k, ", ".join(v))
+                elif isinstance(v, tuple):
+                    line += "{}=({}), ".format(k, ", ".join(v))
                 else:
                     if k == "args":
                         line += v
@@ -666,7 +673,7 @@ class PaddleGraph(object):
         paddle.disable_static()
         restore = paddle.load(osp.join(save_dir, "model.pdparams"))
         model = getattr(x2paddle_code, self.name)()
-        if self.source_type in ["tf", "onnx"]:
+        if self.source_type in ["tf"]:
             model.set_dict(restore, use_structured_name=False)
         else:
             model.set_dict(restore)
...
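As context for the code-generation hunks above: the dygraph code generator now also has to print tuple-valued inputs, which the LSTM mapper below uses for the initial states. A minimal sketch of that rendering logic, assuming a stand-alone helper (`render_call` is a name invented for illustration, not part of the PR):

```python
# Sketch of how a layer call is rendered when its inputs include lists or
# tuples, mirroring the isinstance(v, list)/isinstance(v, tuple) branches
# added in the diff above.
def render_call(output, layer_name, inputs):
    line = "{} = self.{}(".format(output, layer_name)
    for v in inputs:
        if isinstance(v, list):
            line += "[{}], ".format(", ".join(v))
        elif isinstance(v, tuple):
            line += "({}), ".format(", ".join(v))
        else:
            line += "{}, ".format(v)
    return line.strip(", ") + ")"

print(render_call("y, (hy, cy)", "lstm0", ["x", ("init_h", "init_c")]))
# y, (hy, cy) = self.lstm0(x, (init_h, init_c))
```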
@@ -96,6 +96,11 @@ class ONNXGraphNode(GraphNode):
             return default
         return self.attr_map[name]
+
+    def output(self, index=0):
+        if index > 0 and len(self.layer.output) <= index:
+            raise IndexError('Output numbers of Node:{} is {} <= index:{}'.format(self.layer_name, len(self.layer.output), index))
+        return self.layer.output[index]

 class ONNXGraphDataNode(GraphNode):
     def __init__(self, layer, layer_name=None, is_global_input=False):
@@ -246,12 +251,7 @@ class ONNXGraph(Graph):
         """
         generate output_nodes node of ONNX model
         """
-        output_nodes = [value.name for value in self.graph.output]
-        for opt_data in output_nodes:
-            n = super(ONNXGraph, self).get_node(opt_data)
-            if n is None:
-                self.topo_sort.append(self.node_map[opt_data])
-            self.output_nodes.append(opt_data)
+        self.output_nodes = [value.name for value in self.graph.output]

     def is_place_holder_nodes(self, layer):
         """
...
@@ -42,6 +42,31 @@ def _const_weight_or_none(node, necessary=False):
     return None

+
+def _rename_or_remove_weight(weights, origin_name, target_name=None, is_remove=True):
+    '''
+    Rename a parameter following Paddle's parameter naming rule.
+    Args:
+        weights(dict[String:np.ndarray]): dict storing parameters; each key is a parameter name.
+        origin_name(String): name of the parameter to rename or remove.
+        target_name(String, optional): if target_name is not None, add a new key-value pair
+            {target_name: weights[origin_name]} to weights; target_name must follow Paddle's
+            parameter naming rule. Default: None.
+        is_remove(bool, optional): if True, remove the original key-value pair. Default: True.
+    Returns:
+        None
+    '''
+    if origin_name not in weights:
+        raise KeyError('{} not a key in {}'.format(origin_name, weights))
+    if is_remove:
+        # remove weight
+        data = weights.pop(origin_name)
+    else:
+        data = weights[origin_name]
+    if target_name is not None:
+        # rename weight
+        weights[target_name] = data
+

 def _is_static_shape(shape):
     negtive_dims = 0
     error_dims = 0
@@ -125,6 +150,9 @@ class OpSet9():
                        dict(threshold='threshold'),
                        dict(threshold=float(sys.maxsize))],
         'Exp': ['paddle.exp'],
+        'LogSoftmax': ['paddle.nn.functional.log_softmax',
+                       dict(axis='axis'),
+                       dict(axis=1)],
         'Softmax': ['paddle.nn.Softmax',
                     dict(axis='axis'),
                     dict(axis=1)],
@@ -164,11 +192,12 @@ class OpSet9():
                 layer_attrs[pd_attr_name] = onnx_attrs[onnx_attr_name]
             else:
                 layer_attrs[pd_attr_name] = op_info[2][onnx_attr_name]
-        if paddle_op.startswith("paddle.nn"):
+        if paddle_op.startswith("paddle.nn") and 'functional' not in paddle_op:
             op_name = paddle_op[10:].lower()
             op_name = name_generator(op_name, self.nn_name2id)
             output_name = node.name
             layer_outputs = [op_name, output_name]
             self.paddle_graph.add_layer(
                 kernel=paddle_op,
                 inputs={"x": input.name},
@@ -258,14 +287,12 @@ class OpSet9():
             val_scales = self.graph.get_input_node(node, idx=1, copy=True)
             # TODO(syf): paddle.nn.functional.interpolate will support the length
             # which is the same as the rank of input.
-            # inputs['scale_factor'] = val_scales.name
             attrs['scale_factor'] = self.weights[val_scales.name].tolist()[2:]
         elif len(node.layer.input) == 3:
             # opset 11
             val_scales = self.graph.get_input_node(node, idx=2, copy=True)
             # TODO(syf): paddle.nn.functional.interpolate will support the length
             # which is the same as the rank of input.
-            # inputs['scale_factor'] = val_scales.name
             attrs['scale_factor'] = self.weights[val_scales.name].tolist()[2:]
         elif len(node.layer.input) == 4:
             # opset 11
@@ -602,11 +629,11 @@ class OpSet9():
         val_scale = self.graph.get_input_node(node, idx=1, copy=True)
         val_b = self.graph.get_input_node(node, idx=2, copy=True)
         epsilon = node.get_attr('epsilon', 1e-5)
+        self.weights[op_name+'.scale'] = self.weights[val_scale.name]
+        self.weights[op_name+'.bias'] = self.weights[val_b.name]
         layer_attrs = {
             'num_features': node.out_shapes[0][1],
             'epsilon': epsilon,
-            'weight_attr': string(val_scale.name),
-            'bias_attr': string(val_b.name)
         }
         dim = len(val_x.out_shapes[0])
         if dim == 3:
@@ -717,11 +744,11 @@ class OpSet9():
             op_name = name_generator("embedding", self.nn_name2id)
             output_name = node.name
             layer_outputs = [op_name, output_name]
+            self.weights[op_name + '.weight'] = _const_weight_or_none(val_x)
             self.paddle_graph.add_layer(
                 'paddle.nn.Embedding',
                 inputs={"x": indices_cast},
                 outputs=layer_outputs,
-                weight_attr=string(val_x.name),
                 num_embeddings=val_x.out_shapes[0][0],
                 embedding_dim=val_x.out_shapes[0][1])
         else:
@@ -918,10 +945,6 @@ class OpSet9():
         if starts_value is not None and ends_value is not None and axes is not None:
             starts_value = starts_value.copy()
             ends_value = ends_value.copy()
-            #for idx in range(len(ends_value)):
-            #    if ends_value[idx] > 2**31 - 1:
-            #        ends_value[idx] = 2**31 - 1
-            #print(val_x.out_shapes)
             for idx in range(len(ends_value)):
                 if starts_value[idx] >= val_x.out_shapes[0][axes[idx]]:
                     starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
@@ -1316,6 +1339,11 @@ class OpSet9():
         epsilon = node.get_attr('epsilon', 1e-5)
         c = val_x.out_shapes[0][1]
+        _rename_or_remove_weight(self.weights, val_scale.name, op_name+'.weight')
+        _rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias')
+        _rename_or_remove_weight(self.weights, val_var.name, op_name+'._variance')
+        _rename_or_remove_weight(self.weights, val_mean.name, op_name+'._mean')
         # Attribute: spatial is used in BatchNormalization-1,6,7
         spatial = bool(node.get_attr('spatial'))
         layer_attrs = {
@@ -1323,10 +1351,6 @@ class OpSet9():
             "momentum": momentum,
             "epsilon": epsilon,
             "is_test": True,
-            "param_attr": string(val_scale.name),
-            "bias_attr": string(val_b.name),
-            "moving_mean_name": string(val_mean.name),
-            "moving_variance_name": string(val_var.name),
             "use_global_stats": False,
         }
         self.paddle_graph.add_layer(
@@ -1358,7 +1382,7 @@ class OpSet9():
             mode = 'channel'
         shape_slope = val_slope.out_shapes[0]
-        if shape_slope == [1]:
+        if shape_slope == [1] * len(shape_slope):
             mode = 'all'
         if mode == "element":
@@ -1391,17 +1415,19 @@ class OpSet9():
         else:
             if mode == 'channel':
                 slope_data = _const_weight_or_none(val_slope)
+                _rename_or_remove_weight(self.weights, val_slope.name)
                 if len(shape_slope) > 1:
-                    self.weights[val_slope.name] = np.reshape(slope_data, shape_slope[0])
+                    self.weights[op_name+'._weight'] = np.reshape(slope_data, shape_slope[0])
                     num_parameters = val_x.out_shapes[0][1]
                 else:
                     num_parameters = 1
+                    _rename_or_remove_weight(self.weights, val_slope.name)
+                    self.weights[op_name+'._weight'] = np.reshape(self.weights[val_slope.name], [1])
             self.paddle_graph.add_layer(
                 "paddle.nn.PReLU",
                 inputs={"x": val_x.name},
                 outputs=layer_outputs,
-                num_parameters=num_parameters,
-                weight_attr=string(val_slope.name))
+                num_parameters=num_parameters)

     @print_mapping_info
     def Squeeze(self, node):
@@ -1679,19 +1705,15 @@ class OpSet9():
             "dilation": dilations,
             "groups": num_groups,
         }
-        val_w_name = val_w.name
-        while val_w_name in self.done_weight_list:
-            val_w_name += "__repeat"
-        self.done_weight_list.append(val_w_name)
-        layer_attrs["weight_attr"] = string(val_w_name)
-        self.weights[val_w_name] = self.weights[val_w.name]
+        remove_weight = True if val_w.name in self.done_weight_list else False
+        if remove_weight:
+            self.done_weight_list.append(val_w.name)
+        _rename_or_remove_weight(self.weights, val_w.name, op_name+'.weight', remove_weight)
         if has_bias:
-            val_b_name = val_b.name
-            while val_b_name in self.done_weight_list:
-                val_b_name += "__repeat"
-            self.done_weight_list.append(val_b_name)
-            layer_attrs["bias_attr"] = string(val_b_name)
-            self.weights[val_b_name] = self.weights[val_b.name]
+            remove_bias = True if val_b.name in self.done_weight_list else False
+            if remove_bias:
+                self.done_weight_list.append(val_b_name)
+            _rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias', remove_bias)
         else:
             layer_attrs["bias_attr"] = False
         input_shape = val_x.out_shapes[0]
@@ -1712,6 +1734,9 @@ class OpSet9():
     @print_mapping_info
     def ConvTranspose(self, node):
+        op_name = name_generator("conv_trans", self.nn_name2id)
+        output_name = node.name
+        layer_outputs = [op_name, output_name]
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
         val_w = self.graph.get_input_node(node, idx=1, copy=True)
         val_b = None
@@ -1725,7 +1750,7 @@ class OpSet9():
         assert 2 <= convnd <= 3, 'only Conv2DTranspose and Conv3DTranspose supported'
         num_in_channels = val_w.out_shapes[0][0]
         num_out_channels = val_w.out_shapes[0][1]
-        paddle_op = 'paddle.nn.functional.conv{}d_transpose'.format(convnd)
+        paddle_op = 'paddle.nn.Conv{}DTranspose'.format(convnd)
         num_groups = node.get_attr('group', 1)
         strides = node.get_attr('strides', [1] * convnd)
@@ -1743,23 +1768,26 @@ class OpSet9():
         output_size[1] = (val_x.out_shapes[0][3] - 1
                           ) * strides[1] - 2 * paddings[1] + dilations[1] * (
                               kernel_shape[1] - 1) + 1 + out_padding[1]
         # Conv2DTranspose lacks output_size; it can only be passed in through forward()
-        inputs_dict = {'x': val_x if isinstance(val_x, str) else val_x.name,
-                       "weight": val_w.name}
+        inputs_dict = {'x': val_x if isinstance(val_x, str) else val_x.name}
         layer_attrs = {
+            "in_channels": num_in_channels,
+            "out_channels": num_out_channels,
+            "kernel_size": kernel_shape,
             "stride": strides,
             "dilation": dilations,
             "padding": paddings,
             "groups": num_groups,
-            "output_size": node.out_shapes[0][2:]}
+            "output_padding": out_padding}
+        _rename_or_remove_weight(self.weights, val_w.name, op_name+'.weight')
         if val_b is not None:
-            inputs_dict["bias"] = val_b.name
-        else:
-            layer_attrs["bias"] = None
+            _rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias')
         self.paddle_graph.add_layer(
-            kernel="paddle.nn.functional.conv2d_transpose",
+            kernel=paddle_op,
             inputs=inputs_dict,
-            outputs=[node.name],
+            outputs=layer_outputs,
             **layer_attrs)
@@ -1775,6 +1803,7 @@ class OpSet9():
             outputs=[node.name],
             **layer_attrs)

     @print_mapping_info
     def Size(self, node):
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
@@ -1836,3 +1865,115 @@ class OpSet9():
             "paddle.reciprocal",
             inputs={"x": val_x.name},
             outputs=[node.name])
+
+    @print_mapping_info
+    def LSTM(self, node):
+        x = self.graph.get_input_node(node, idx=0, copy=True)
+        input_weight = self.graph.get_input_node(node, idx=1, copy=True)
+        hidden_weight = self.graph.get_input_node(node, idx=2, copy=True)
+
+        input_nums = len(node.layer.input)
+        exist_input_nums = 3
+        have_bias = False
+        if input_nums > 3 and node.layer.input[3] != '':
+            bias = self.graph.get_input_node(node, idx=exist_input_nums, copy=True)
+            have_bias = True
+            exist_input_nums += 1
+        if input_nums > 4 and node.layer.input[4] != '':
+            sequence_lens = self.graph.get_input_node(node, idx=exist_input_nums, copy=True)
+            exist_input_nums += 1
+        if input_nums > 5 and node.layer.input[5] != '':
+            init_h = self.graph.get_input_node(node, idx=exist_input_nums, copy=True)
+            self.paddle_graph.add_layer(
+                'paddle.reshape',
+                inputs={"x": init_h.name},
+                outputs=[init_h.name],
+                shape=init_h.out_shapes[0]
+            )
+            exist_input_nums += 1
+        if input_nums > 6 and node.layer.input[6] != '':
+            init_c = self.graph.get_input_node(node, idx=exist_input_nums, copy=True)
+            self.paddle_graph.add_layer(
+                'paddle.reshape',
+                inputs={"x": init_c.name},
+                outputs=[init_c.name],
+                shape=init_c.out_shapes[0]
+            )
+
+        input_weight_np = _const_weight_or_none(input_weight)
+        _rename_or_remove_weight(self.weights, input_weight.name)
+        hidden_size = node.get_attr('hidden_size', input_weight_np.shape[1]/4)
+        input_size = input_weight_np.shape[2]
+        hidden_weight_np = _const_weight_or_none(hidden_weight)
+        _rename_or_remove_weight(self.weights, hidden_weight.name)
+        bias_np = _const_weight_or_none(bias)
+        _rename_or_remove_weight(self.weights, bias.name)
+        input_bias_np = bias_np[:, :4*hidden_size]
+        hidden_bias_np = bias_np[:, 4*hidden_size:]
+
+        # parameter order in paddle.nn.LSTM:
+        # 1. gate order in paddle is: input, forget, cell, output.
+        # 2. gate order in onnx is: input, output, forget, cell.
+        def reform_weights(w, n, intervals):
+            slices = [w[:, x * n: y * n] for x, y in intervals]
+            return np.concatenate(slices, axis=1)
+
+        def transform_weight_with_bias(weights, n, intervals):
+            return [reform_weights(w, n, intervals) for w in weights]
+
+        reform_permutation = [(0, 1), (2, 4), (1, 2)]
+        weights = transform_weight_with_bias(
+            [input_weight_np, hidden_weight_np, input_bias_np, hidden_bias_np],
+            hidden_size, reform_permutation)
+
+        op_name = name_generator("lstm", self.nn_name2id)
+        y_out = node.output(0)
+        yh_out = node.output(1)
+        yc_out = node.output(2)
+        direction = node.get_attr('direction', 'forward')
+
+        def generate_paddle_param_names(op_name, suffix=''):
+            param_names = []
+            param_names.extend(['{}.weight_ih_l0{}', '{}.weight_hh_l0{}'])
+            if have_bias != False: param_names.append('{}.bias_ih_l0{}')
+            if have_bias != False: param_names.append('{}.bias_hh_l0{}')
+            param_names = [x.format(op_name, suffix) for x in param_names]
+            return param_names
+
+        def assign_params(op_name, weights, weight_idx=0, suffix=''):
+            param_names = generate_paddle_param_names(op_name, suffix)
+            print(param_names)
+            for param_name, weight in zip(param_names, weights):
+                self.weights[param_name] = weight[weight_idx]
+
+        if direction == 'backward':
+            raise Exception("LSTM supports 'forward' or 'bidirectional', but got '{}'.".format(direction))
+        else:
+            assign_params(op_name, weights)
+            if direction == 'bidirectional':
+                assign_params(op_name, weights, 1, '_reverse')
+
+        self.paddle_graph.add_layer(
+            'paddle.nn.LSTM',
+            inputs={'input': x.name, 'initial_states': (init_h.name, init_c.name)},
+            outputs=[op_name, y_out, yh_out, yc_out],
+            input_size=input_size,
+            hidden_size=hidden_size,
+            num_layers=1,
+            direction=string(direction),
+            time_major=True)
+
+        self.paddle_graph.add_layer(
+            'paddle.reshape',
+            inputs={"x": y_out},
+            outputs=[y_out],
+            shape=[0, 0, -1, hidden_size]
+        )
+        self.paddle_graph.add_layer(
+            'paddle.transpose',
+            inputs={"x": y_out},
+            outputs=[y_out],
+            perm=[0, 2, 1, 3]
+        )
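For reference, the gate reordering performed by `reform_weights` above can be checked with a small standalone NumPy snippet (not part of the PR; the letter-valued array is purely illustrative). ONNX packs the LSTM gates along the hidden axis as [input, output, forget, cell], while paddle.nn.LSTM expects [input, forget, cell, output]; the slicing pattern `[(0, 1), (2, 4), (1, 2)]` performs exactly that permutation:

```python
import numpy as np

hidden_size = 2
# ONNX weight layout along axis 1: blocks for the input, output, forget, cell gates.
w_onnx = np.concatenate(
    [np.full((1, hidden_size), gate) for gate in ("i", "o", "f", "c")], axis=1)

# Same slicing as reform_weights() in the diff: blocks [0:1], [2:4], [1:2] of
# width hidden_size give Paddle's input, forget, cell, output order.
reform_permutation = [(0, 1), (2, 4), (1, 2)]
slices = [w_onnx[:, x * hidden_size: y * hidden_size] for x, y in reform_permutation]
w_paddle = np.concatenate(slices, axis=1)

print(w_onnx[0].tolist())    # ['i', 'i', 'o', 'o', 'f', 'f', 'c', 'c']
print(w_paddle[0].tolist())  # ['i', 'i', 'f', 'f', 'c', 'c', 'o', 'o']
```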