提交 03f2420d 编写于 作者: C Channingss

add lstm

上级 608dbd8a
......@@ -18,7 +18,7 @@ from __future__ import division
import paddle.fluid as fluid
import paddle
from paddle.fluid.proto import framework_pb2
from collections import OrderedDict
import collections
import numpy
import sys
import os
......@@ -38,7 +38,7 @@ class PaddleLayer(object):
outputs,
list), "parameter 'outputs' for PaddleLayer should be type of list"
for k, v in inputs.items():
if isinstance(v, list):
if isinstance(v, (list, tuple)):
for i in v:
assert isinstance(
i, six.string_types
......@@ -66,7 +66,7 @@ class PaddleLayer(object):
class PaddleGraph(object):
def __init__(self, source_type=None, parent_layer=None, graph_type="static"):
self.layers = OrderedDict()
self.layers = collections.OrderedDict()
self.edges_out = dict()
self.edges_in = dict()
self.inputs = list()
......@@ -94,7 +94,7 @@ class PaddleGraph(object):
self.script = script
def clear(self):
self.layers = OrderedDict()
self.layers = collections.OrderedDict()
self.edges_out = dict()
self.edges_in = dict()
self.inputs = list()
......@@ -166,9 +166,10 @@ class PaddleGraph(object):
self.clear_edges()
outputs_from_nodes = dict()
for layer_id, layer in self.layers.items():
print(layer.kernel, layer.outputs ,layer.inputs)
for input_key, input_var in layer.inputs.items():
vs = input_var
if not isinstance(vs, list):
if not isinstance(vs, (list, tuple)):
vs = [vs]
for v in vs:
assert v in outputs_from_nodes or (
......@@ -616,6 +617,8 @@ class PaddleGraph(object):
for k, v in layer.inputs.items():
if isinstance(v, list):
line += "{}=[{}], ".format(k, ", ".join(v))
elif isinstance(v, tuple):
line += "{}=({}), ".format(k, ", ".join(v))
else:
if k == "args":
line += v
......
......@@ -95,6 +95,13 @@ class ONNXGraphNode(GraphNode):
return default
return self.attr_map[name]
def output(self, index=0):
if index >0 and len(self.layer.output) <= index:
raise IndexError('Output numbers of Node:{} is {} <= index:{}'.format(self.layer_name, len(self.layer.output), index))
if index > 0:
return "{}_p{}".format(self.layer_name, index)
return self.layer_name
class ONNXGraphDataNode(GraphNode):
def __init__(self, layer, layer_name=None, is_global_input=False):
......
......@@ -122,6 +122,9 @@ class OpSet9():
dict(threshold='threshold'),
dict(threshold=float(sys.maxsize))],
'Exp': ['paddle.exp'],
'LogSoftmax': ['paddle.nn.functional.log_softmax',
dict(axis='axis'),
dict(axis=1)],
'Softmax': ['paddle.nn.Softmax',
dict(axis='axis'),
dict(axis=1)],
......@@ -1633,4 +1636,89 @@ class OpSet9():
'paddle.argmax',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
\ No newline at end of file
**layer_attrs)
@print_mapping_info
def LSTM(self, node):
# parameters order in paddle:lstm:
# 1. gate order in paddle is: input, forget, cell, output.
# 2. gate orfer in onnx is: input, output, forget, cell.
def reform_weights(w, n, intervals):
slices = [w[:,x * n: y * n] for x, y in intervals]
return np.concatenate(slices, axis=1)
def transform_weight_with_bias(weights, n, intervals):
return [reform_weights(w, n, intervals) for w in weights]
print(node.layer.input)
x = self.graph.get_input_node(node, idx=0, copy=True)
input_weight = self.graph.get_input_node(node, idx=1, copy=True)
hidden_weight = self.graph.get_input_node(node, idx=2, copy=True)
input_nums = len(node.layer.input)
exist_input_nums = 3
if input_nums > 3 and node.layer.input[3] != '':
bias = self.graph.get_input_node(node, idx=exist_input_nums, copy=True)
exist_input_nums += 1
if input_nums > 4 and node.layer.input[4] != '':
sequence_lens = self.graph.get_input_node(node, idx=exist_input_nums, copy=True)
exist_input_nums += 1
if input_nums > 5 and node.layer.input[5] != '':
init_h = self.graph.get_input_node(node, idx=exist_input_nums, copy=True)
exist_input_nums += 1
if input_nums > 6 and node.layer.input[6] != '':
init_c = self.graph.get_input_node(node, idx=exist_input_nums, copy=True)
input_weight_np = _const_weight_or_none(input_weight)
hidden_size = node.get_attr('hidden_size', input_weight_np.shape[1]/3)
input_size = input_weight_np.shape[2]
hidden_weight_np = _const_weight_or_none(hidden_weight)
bias_np = _const_weight_or_none(bias)
input_bias_np = bias_np[:, :3*hidden_size]
hidden_bias_np = bias_np[:, 3*hidden_size:]
reform_permutation = [(0, 1), (3, 4), (1, 3)]
input_weight_np, hidden_weight_np, input_bias_np, hidden_bias_np = transform_weight_with_bias(
[input_weight_np, hidden_weight_np, input_bias_np, hidden_bias_np],
hidden_size, reform_permutation)
self.weights[input_weight.name] = input_weight_np
self.weights[hidden_weight.name] = hidden_weight_np
input_bias_name = bias.name + '_input'
hidden_bias_name = bias.name + '_hidden'
self.weights[input_bias_name] = input_bias_np
self.weights[hidden_bias_name] = hidden_bias_np
op_name = name_generator("lstm", self.nn_name2id)
y_out = node.output(0)
yh_out = node.output(1)
yc_out = node.output(2)
self.paddle_graph.add_layer(
'paddle.nn.LSTM',
inputs={'input': x.name, 'initial_states': (init_h.name, init_c.name)},
outputs=[op_name, y_out, yh_out, yc_out],
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
weight_ih_attr=string(input_weight.name),
weight_hh_attr=string(hidden_weight.name),
bias_ih_attr=string(input_bias_name),
bias_hh_attr=string(hidden_bias_name),
direction=string(node.get_attr('direction')),
time_major=True)
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": y_out},
outputs=[y_out],
shape=[-1, -1, -1, hidden_size]
)
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": y_out},
outputs=[y_out],
perm=[0,2,1,3]
)
......@@ -1587,4 +1587,5 @@ class OpSet9():
'paddle.argmax',
inputs={"x": val_x.name},
outputs=[node.name],
**layer_attrs)
\ No newline at end of file
**layer_attrs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册