未验证 提交 f8e32aba 编写于 作者: N nihui 提交者: GitHub

fix pnnx gru rnn with optional output, fix #4608 (#4631)

上级 d87e895a
......@@ -1873,6 +1873,24 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
}
fprintf(pyfp, "]\n");
}
else if (op->type == "nn.GRU" || op->type == "nn.RNN")
{
if (op->outputs.size() == 1)
{
fprintf(pyfp, "v_%s, _", sanitize_identifier(op->outputs[0]->name).c_str());
}
else
{
fprintf(pyfp, "v_%s, v_%s", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->outputs[1]->name).c_str());
}
fprintf(pyfp, " = self.%s(", sanitize_identifier(op->name).c_str());
fprintf(pyfp, "v_%s", sanitize_identifier(op->inputs[0]->name).c_str());
if (op->inputs.size() == 2)
{
fprintf(pyfp, ", v_%s", sanitize_identifier(op->inputs[1]->name).c_str());
}
fprintf(pyfp, ")\n");
}
else if (op->type == "nn.LSTM")
{
if (op->outputs.size() == 1)
......
......@@ -32,15 +32,15 @@ class Model(nn.Module):
def forward(self, x, y):
x0, h0 = self.gru_0_0(x)
x1, h1 = self.gru_0_1(x0)
x1, _ = self.gru_0_1(x0)
x2, h2 = self.gru_0_2(x1)
x3, h3 = self.gru_0_3(x1, h2)
y0, h4 = self.gru_1_0(y)
y1, h5 = self.gru_1_1(y0)
y1, _ = self.gru_1_1(y0)
y2, h6 = self.gru_1_2(y1)
y3, h7 = self.gru_1_3(y1, h6)
return x2, x3, h0, h1, h2, h3, y2, y3, h4, h5, h6, h7
return x2, x3, h0, h2, h3, y2, y3, h4, h6, h7
def test():
net = Model()
......
......@@ -32,15 +32,15 @@ class Model(nn.Module):
def forward(self, x, y):
x0, (h0, c0) = self.lstm_0_0(x)
x1, (h1, c1) = self.lstm_0_1(x0)
x1, _ = self.lstm_0_1(x0)
x2, (h2, c2) = self.lstm_0_2(x1)
x3, (h3, c3) = self.lstm_0_3(x2, (h2, c2))
y0, (h4, c4) = self.lstm_1_0(y)
y1, (h5, c5) = self.lstm_1_1(y0)
y1, _ = self.lstm_1_1(y0)
y2, (h6, c6) = self.lstm_1_2(y1)
y3, (h7, c7) = self.lstm_1_3(y2, (h6, c6))
return x2, x3, h0, h1, h2, h3, c0, c1, c2, c3, y2, y3, h4, h5, h6, h7, c4, c5, c6, c7
return x2, x3, h0, h2, h3, c0, c2, c3, y2, y3, h4, h6, h7, c4, c6, c7
def test():
net = Model()
......
......@@ -32,15 +32,15 @@ class Model(nn.Module):
def forward(self, x, y):
x0, h0 = self.rnn_0_0(x)
x1, h1 = self.rnn_0_1(x0)
x1, _ = self.rnn_0_1(x0)
x2, h2 = self.rnn_0_2(x1)
x3, h3 = self.rnn_0_3(x1, h2)
y0, h4 = self.rnn_1_0(y)
y1, h5 = self.rnn_1_1(y0)
y1, _ = self.rnn_1_1(y0)
y2, h6 = self.rnn_1_2(y1)
y3, h7 = self.rnn_1_3(y1, h6)
return x2, x3, h0, h1, h2, h3, y2, y3, h4, h5, h6, h7
return x2, x3, h0, h2, h3, y2, y3, h4, h6, h7
def test():
net = Model()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册