提交 14da5ec3 编写于 作者: D Dmitry Kurtaev

LSTM scalar

上级 25ab141b
......@@ -215,6 +215,8 @@ public:
internals.push_back(shape(_numSamples, 1)); // dummyOnes
internals.push_back(shape(_numSamples, 4*_numOut)); // gates
std::cout << "LSTM out: " << outputs[0] << '\n';
return false;
}
......@@ -301,6 +303,8 @@ public:
tsEnd = numTimeStamps;
tsInc = 1;
}
std::cout << "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" << '\n';
std::cout << tsStart << " " << tsEnd << '\n';
for (int ts = tsStart; ts != tsEnd; ts += tsInc)
{
Range curRowRange(ts*numSamples, (ts + 1)*numSamples);
......@@ -314,6 +318,7 @@ public:
Mat gateF = gates.colRange(1*numOut, 2*numOut);
Mat gateO = gates.colRange(2*numOut, 3*numOut);
Mat gateG = gates.colRange(3*numOut, 4*numOut);
std::cout << "i " << gateI << '\n';
if (forgetBias)
add(gateF, forgetBias, gateF);
......@@ -329,6 +334,7 @@ public:
{
Mat gatesIFO = gates.colRange(0, 3*numOut);
sigmoid(gatesIFO, gatesIFO);
std::cout << "ifo " << gatesIFO << '\n';
}
tanh(gateG, gateG);
......@@ -345,12 +351,15 @@ public:
}
if (usePeephole)
{
std::cout << "if (usePeephole)" << '\n';
gemm(cInternal, blobs[5], 1, gateO, 1, gateO);
sigmoid(gateO, gateO);
}
//compute h_t
tanh(cInternal, hInternal);
std::cout << "o " << gateO << '\n';
std::cout << "tanh(o) " << hInternal << '\n';
multiply(gateO, hInternal, hInternal);
//save results in output blobs
......@@ -358,6 +367,7 @@ public:
if (produceCellOutput)
cInternal.copyTo(cOutTs.rowRange(curRowRange));
}
std::cout << "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~" << '\n';
}
};
......
......@@ -290,6 +290,30 @@ public:
}
};
// // To remove Squeeze after LSTM for non-bidirectional LSTM
// class LSTMSqueeze : public Subgraph
// {
// public:
// LSTMSqueeze()
// {
// int input = addNodeToMatch("");
//
// std::vector<int> lstmInps(7);
// lstmInps[0] = input;
//
// for (int i = 1; i < 4; ++i)
// lstmInps[i] = addNodeToMatch("Unsqueeze");
// lstmInps[4] = addNodeToMatch("");
// for (int i = 5; i < 7; ++i)
// lstmInps[i] = addNodeToMatch("ConstantOfShape");
//
// int lstm = addNodeToMatch("LSTM", lstmInps);
// addNodeToMatch("Squeeze", lstm);
//
// setFusedNode("LSTM", lstmInps);
// }
// };
void simplifySubgraphs(opencv_onnx::GraphProto& net)
{
std::vector<Ptr<Subgraph> > subgraphs;
......@@ -299,6 +323,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<ResizeSubgraph1>());
subgraphs.push_back(makePtr<ResizeSubgraph2>());
subgraphs.push_back(makePtr<SoftMaxSubgraph>());
// subgraphs.push_back(makePtr<LSTMSqueeze>());
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
}
......
......@@ -322,7 +322,7 @@ void ONNXImporter::populateNet(Net dstNet)
std::string layer_type = node_proto.op_type();
layerParams.type = layer_type;
std::cout << layerParams.name << " " << layer_type << '\n';
if (layer_type == "MaxPool")
{
......@@ -457,6 +457,19 @@ void ONNXImporter::populateNet(Net dstNet)
constBlobs.insert(std::make_pair(layerParams.name, sliced[0]));
continue;
}
layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
CV_Assert(node_proto.input_size() == 1);
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{
std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), sliced;
runLayer(layerParams, inputs, sliced);
CV_Assert(sliced.size() == 1);
constBlobs.insert(std::make_pair(layerParams.name, sliced[0]));
continue;
}
}
else if (layer_type == "Split")
{
......@@ -579,6 +592,117 @@ void ONNXImporter::populateNet(Net dstNet)
constBlobs.insert(std::make_pair(layerParams.name, layerParams.blobs[0]));
continue;
}
else if (layer_type == "ConstantFill" || layer_type == "ConstantOfShape")
{
CV_Assert_N(node_proto.input_size());
MatShape inpShape = getBlob(node_proto, constBlobs, 0);
float value = layerParams.get("value", 0);
Mat fill(inpShape.size(), &inpShape[0], CV_32F, Scalar(value));
constBlobs.insert(std::make_pair(layerParams.name, fill));
continue;
}
else if (layer_type == "LSTM")
{
std::cout << "~~~~~~" << '\n';
std::cout << layerParams << '\n';
for (int i = 1; i < node_proto.input_size(); ++i) {
std::cout << "i: " << node_proto.input(i) << " " << constBlobs[node_proto.input(i)].size << '\n';
}
CV_Assert(node_proto.input_size() == 7);
Mat Wx = getBlob(node_proto, constBlobs, 1);
Mat Wh = getBlob(node_proto, constBlobs, 2);
Mat b = getBlob(node_proto, constBlobs, 3);
std::cout << Wx.size << '\n';
std::cout << Wh.size << '\n';
int Wx_shape[] = {Wx.size[1], Wx.size[2]};
int Wh_shape[] = {Wh.size[1], Wh.size[2]};
std::cout << "b.size " << b.size << '\n';
int b_shape[] = {2, b.size[1] / 2};
Wx = Wx.reshape(1, 2, &Wx_shape[0]);
b = b.reshape(1, 2, &b_shape[0]);
std::cout << "b ----------------" << '\n';
std::cout << b << '\n';
reduce(b, b, 0, REDUCE_SUM);
std::cout << b << '\n';
// https://pytorch.org/docs/stable/nn.html#lstm
// IFGO->IFOG
// swap each 3rd and 4th rows
// Wx = Wx.t();
float* weightData = (float*)Wx.data;
std::swap(weightData[1], weightData[2]);
float* biasData = (float*)b.data;
std::swap(biasData[1], biasData[2]);
// std::swap(weightData[2], weightData[3]);
//
// weightData = (float*)Wh.data;
// std::swap(weightData[1], weightData[2]);
// std::swap(weightData[2], weightData[3]);
// const int outSize = Wx.cols / 4;
// for (int i = 0; i < Wx.rows; ++i)
// for (int j = 0; j < outSize; ++j)
// {
// // std::swap(weightData[i * W.cols + 1 * outSize + j],
// // weightData[i * W.cols + 2 * outSize + j]);
// std::swap(weightData[i * Wx.cols + 2 * outSize + j],
// weightData[i * Wx.cols + 3 * outSize + j]);
// }
// float* weightData = Wx.ptr<float>();
// for (int j = 0; j < 5; ++j)
// {
// std::cout << "swap " << (10 + j) << " " << (15 + j) << '\n';
// for (int i = 0; i < 12; ++i)
// std::swap(weightData[(10 + j) * 12 + i],
// weightData[(15 + j) * 12 + i]);
// }
layerParams.blobs.resize(3);
layerParams.blobs[0] = Wh.reshape(1, 2, &Wh_shape[0]);
layerParams.blobs[1] = Wx;
layerParams.blobs[2] = b;
std::cout << "Wx" << '\n';
std::cout << layerParams.blobs[1] << '\n';
std::cout << "Wh" << '\n';
std::cout << layerParams.blobs[0] << '\n';
// layerParams.set("reverse", true);
// layerParams.set("use_peephole", true);
// layerParams.blobs.resize(6);
// for (int i = 0; i < 3; ++i)
// {
// Mat w = Mat::eye(layerParams.blobs[0].cols, layerParams.blobs[0].cols, CV_32F);
// layerParams.blobs[3 + i] = w;
// }
// std::cout << layerParams.blobs[1] << '\n';
// int lstmId = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
//
// layerParams = LayerParams();
//
// // Add reshape
// int shape[] = {1, 10, 11, 5};
// layerParams.name = node_proto.output(0) + "/reshape";
// layerParams.type = "Reshape";
// layerParams.set("dim", DictValue::arrayInt(&shape[0], 4));
}
else if (layer_type == "ImageScaler")
{
const float scale = layerParams.has("scale") ? layerParams.get<float>("scale") : 1.0f;
......@@ -881,14 +1005,14 @@ void ONNXImporter::populateNet(Net dstNet)
else if (layer_type == "Squeeze")
{
CV_Assert_N(node_proto.input_size() == 1, layerParams.has("axes"));
DictValue axes_dict = layerParams.get("axes");
if (axes_dict.size() != 1)
CV_Error(Error::StsNotImplemented, "Multidimensional squeeze");
int axis = axes_dict.getIntValue(0);
layerParams.set("axis", axis - 1);
layerParams.set("end_axis", axis);
layerParams.type = "Flatten";
// DictValue axes_dict = layerParams.get("axes");
// if (axes_dict.size() != 1)
// CV_Error(Error::StsNotImplemented, "Multidimensional squeeze");
//
// int axis = axes_dict.getIntValue(0);
// layerParams.set("axis", axis - 1);
// layerParams.set("end_axis", axis);
layerParams.type = "Identity";
}
else if (layer_type == "Flatten")
{
......@@ -1032,17 +1156,30 @@ void ONNXImporter::populateNet(Net dstNet)
else if (layer_type == "Gather")
{
CV_Assert(node_proto.input_size() == 2);
CV_Assert(layerParams.has("axis"));
Mat input = getBlob(node_proto, constBlobs, 0);
Mat indexMat = getBlob(node_proto, constBlobs, 1);
CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
int index = indexMat.at<int>(0);
int axis = layerParams.get<int>("axis");
std::vector<cv::Range> ranges(input.dims, Range::all());
ranges[axis] = Range(index, index + 1);
Mat out;
if (layerParams.has("axis"))
{
int axis = layerParams.get<int>("axis");
std::vector<cv::Range> ranges(input.dims, Range::all());
ranges[axis] = Range(index, index + 1);
Mat out = input(ranges);
out = input(ranges);
}
else
{
CV_Assert(index < input.total());
const int dims = input.dims;
input = input.reshape(1, 1);
input.dims = 2;
out = input.reshape(1, 1).colRange(index, index + 1);
out.dims = dims;
}
constBlobs.insert(std::make_pair(layerParams.name, out));
continue;
}
......
......@@ -1826,10 +1826,12 @@ void TFImporter::populateNet(Net dstNet)
const int outSize = W.cols / 4;
// IGFO->IFOG
std::cout << "(TF) W " << W.size << '\n';
float* weightData = (float*)W.data;
for (int i = 0; i < W.rows; ++i)
for (int j = 0; j < outSize; ++j)
{
// std::cout << "swap " << i * W.cols + 1 * outSize << " " << i * W.cols + 2 * outSize << '\n';
std::swap(weightData[i * W.cols + 1 * outSize + j],
weightData[i * W.cols + 2 * outSize + j]);
std::swap(weightData[i * W.cols + 2 * outSize + j],
......@@ -1838,6 +1840,11 @@ void TFImporter::populateNet(Net dstNet)
Wx = W.rowRange(0, W.rows - outSize).t();
Wh = W.rowRange(W.rows - outSize, W.rows).t();
std::cout << "(TF) Wx " << Wx.size << '\n';
std::cout << "(TF) Wh " << Wh.size << '\n';
std::cout << "(TF) b " << b.size << '\n';
layerParams.blobs.resize(3);
layerParams.blobs[0] = Wh;
layerParams.blobs[1] = Wx;
......
......@@ -79,6 +79,12 @@ public:
netSoftmax.setInput(ref);
ref = netSoftmax.forward();
}
std::cout << "ref: " << ref.size << '\n';
std::cout << "out: " << out.size << '\n';
std::cout << ref.reshape(1, 1) << '\n';
std::cout << '\n';
std::cout << out.reshape(1, 1) << '\n';
normAssert(ref, out, "", l1 ? l1 : default_l1, lInf ? lInf : default_lInf);
if (checkNoFallbacks)
expectNoFallbacksFromIE(net);
......@@ -451,6 +457,11 @@ TEST_P(Test_ONNX_layers, Split_EltwiseMax)
testONNXModels("split_max");
}
TEST_P(Test_ONNX_layers, LSTM)
{
testONNXModels("lstm");
}
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
class Test_ONNX_nets : public Test_ONNX_layers
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册