Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Opencv
提交
467c3ef0
O
Opencv
项目概览
Greenplum
/
Opencv
10 个月 前同步成功
通知
7
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
Opencv
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
467c3ef0
编写于
3月 22, 2020
作者:
D
Dmitry Kurtaev
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add checks for LSTM initial h and c
上级
84336202
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
17 addition
and
11 deletion
+17
-11
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/src/onnx/onnx_importer.cpp
+13
-9
modules/dnn/test/test_onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp
+4
-2
未找到文件。
modules/dnn/src/onnx/onnx_importer.cpp
浏览文件 @
467c3ef0
...
...
@@ -496,6 +496,7 @@ void ONNXImporter::populateNet(Net dstNet)
runLayer
(
layerParams
,
inputs
,
sliced
);
CV_Assert
(
sliced
.
size
()
==
1
);
constBlobs
.
insert
(
std
::
make_pair
(
layerParams
.
name
,
sliced
[
0
]));
outShapes
[
layerParams
.
name
]
=
shape
(
sliced
[
0
]);
continue
;
}
}
...
...
@@ -630,6 +631,8 @@ void ONNXImporter::populateNet(Net dstNet)
Mat
Wx
=
getBlob
(
node_proto
,
constBlobs
,
1
);
Mat
Wh
=
getBlob
(
node_proto
,
constBlobs
,
2
);
Mat
b
=
getBlob
(
node_proto
,
constBlobs
,
3
);
CV_CheckEQ
(
countNonZero
(
getBlob
(
node_proto
,
constBlobs
,
5
)),
0
,
"Unsupported non zero initial_h"
);
CV_CheckEQ
(
countNonZero
(
getBlob
(
node_proto
,
constBlobs
,
6
)),
0
,
"Unsupported non zero initial_c"
);
b
=
b
.
reshape
(
1
,
b
.
size
[
0
]);
const
int
numHidden
=
lstmParams
.
get
<
int
>
(
"hidden_size"
);
...
...
@@ -1007,6 +1010,16 @@ void ONNXImporter::populateNet(Net dstNet)
}
else
layerParams
.
type
=
"Identity"
;
if
(
constBlobs
.
find
(
node_proto
.
input
(
0
))
!=
constBlobs
.
end
())
{
Mat
inp
=
getBlob
(
node_proto
,
constBlobs
,
0
);
Mat
out
=
inp
.
reshape
(
1
,
outShape
);
out
.
dims
=
outShape
.
size
();
// to workaround dims == 1
constBlobs
.
insert
(
std
::
make_pair
(
layerParams
.
name
,
out
));
outShapes
[
layerParams
.
name
]
=
shape
(
out
);
continue
;
}
}
else
if
(
layer_type
==
"Flatten"
)
{
...
...
@@ -1136,15 +1149,6 @@ void ONNXImporter::populateNet(Net dstNet)
else
layerParams
.
type
=
"Identity"
;
}
else
if
(
layer_type
==
"ConstantFill"
||
layer_type
==
"ConstantOfShape"
)
{
CV_Assert_N
(
node_proto
.
input_size
());
MatShape
inpShape
=
getBlob
(
node_proto
,
constBlobs
,
0
);
float
value
=
layerParams
.
get
(
"value"
,
0
);
Mat
fill
(
inpShape
.
size
(),
&
inpShape
[
0
],
CV_32F
,
Scalar
(
value
));
constBlobs
.
insert
(
std
::
make_pair
(
layerParams
.
name
,
fill
));
continue
;
}
else
if
(
layer_type
==
"ConstantOfShape"
||
layer_type
==
"ConstantFill"
)
{
float
fill_value
;
...
...
modules/dnn/test/test_onnx_importer.cpp
浏览文件 @
467c3ef0
...
...
@@ -405,6 +405,8 @@ TEST_P(Test_ONNX_layers, Reshape)
TEST_P
(
Test_ONNX_layers
,
Squeeze
)
{
if
(
backend
==
DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019
&&
target
==
DNN_TARGET_MYRIAD
)
applyTestTag
(
CV_TEST_TAG_DNN_SKIP_IE_MYRIAD
,
CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER
);
testONNXModels
(
"squeeze"
);
}
...
...
@@ -453,12 +455,12 @@ TEST_P(Test_ONNX_layers, Split_EltwiseMax)
TEST_P
(
Test_ONNX_layers
,
LSTM
)
{
testONNXModels
(
"lstm"
);
testONNXModels
(
"lstm"
,
npy
,
0
,
0
,
false
,
false
);
}
TEST_P
(
Test_ONNX_layers
,
LSTM_bidirectional
)
{
testONNXModels
(
"lstm_bidirectional"
);
testONNXModels
(
"lstm_bidirectional"
,
npy
,
0
,
0
,
false
,
false
);
}
INSTANTIATE_TEST_CASE_P
(
/*nothing*/
,
Test_ONNX_layers
,
dnnBackendsAndTargets
());
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录