Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
Opencv
提交
840c892a
O
Opencv
项目概览
Greenplum
/
Opencv
大约 1 年 前同步成功
通知
7
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
Opencv
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
840c892a
编写于
12月 21, 2018
作者:
D
Dmitry Kurtaev
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Batch normalization in training phase from Torch
上级
09d8bbb1
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
19 addition
and
14 deletion
+19
-14
modules/dnn/include/opencv2/dnn/dnn.hpp
modules/dnn/include/opencv2/dnn/dnn.hpp
+4
-3
modules/dnn/src/torch/torch_importer.cpp
modules/dnn/src/torch/torch_importer.cpp
+8
-5
modules/dnn/test/test_torch_importer.cpp
modules/dnn/test/test_torch_importer.cpp
+7
-6
未找到文件。
modules/dnn/include/opencv2/dnn/dnn.hpp
浏览文件 @
840c892a
...
@@ -46,9 +46,9 @@
...
@@ -46,9 +46,9 @@
#include <opencv2/core.hpp>
#include <opencv2/core.hpp>
#if !defined CV_DOXYGEN && !defined CV_DNN_DONT_ADD_EXPERIMENTAL_NS
#if !defined CV_DOXYGEN && !defined CV_DNN_DONT_ADD_EXPERIMENTAL_NS
#define CV__DNN_EXPERIMENTAL_NS_BEGIN namespace experimental_dnn_34_v1
0
{
#define CV__DNN_EXPERIMENTAL_NS_BEGIN namespace experimental_dnn_34_v1
1
{
#define CV__DNN_EXPERIMENTAL_NS_END }
#define CV__DNN_EXPERIMENTAL_NS_END }
namespace
cv
{
namespace
dnn
{
namespace
experimental_dnn_34_v1
0
{
}
using
namespace
experimental_dnn_34_v10
;
}}
namespace
cv
{
namespace
dnn
{
namespace
experimental_dnn_34_v1
1
{
}
using
namespace
experimental_dnn_34_v11
;
}}
#else
#else
#define CV__DNN_EXPERIMENTAL_NS_BEGIN
#define CV__DNN_EXPERIMENTAL_NS_BEGIN
#define CV__DNN_EXPERIMENTAL_NS_END
#define CV__DNN_EXPERIMENTAL_NS_END
...
@@ -754,6 +754,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
...
@@ -754,6 +754,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
* @brief Reads a network model stored in <a href="http://torch.ch">Torch7</a> framework's format.
* @brief Reads a network model stored in <a href="http://torch.ch">Torch7</a> framework's format.
* @param model path to the file, dumped from Torch by using torch.save() function.
* @param model path to the file, dumped from Torch by using torch.save() function.
* @param isBinary specifies whether the network was serialized in ascii mode or binary.
* @param isBinary specifies whether the network was serialized in ascii mode or binary.
* @param evaluate specifies testing phase of network. If true, it's similar to evaluate() method in Torch.
* @returns Net object.
* @returns Net object.
*
*
* @note Ascii mode of Torch serializer is more preferable, because binary mode extensively use `long` type of C language,
* @note Ascii mode of Torch serializer is more preferable, because binary mode extensively use `long` type of C language,
...
@@ -775,7 +776,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
...
@@ -775,7 +776,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
*
*
* Also some equivalents of these classes from cunn, cudnn, and fbcunn may be successfully imported.
* Also some equivalents of these classes from cunn, cudnn, and fbcunn may be successfully imported.
*/
*/
CV_EXPORTS_W
Net
readNetFromTorch
(
const
String
&
model
,
bool
isBinary
=
true
);
CV_EXPORTS_W
Net
readNetFromTorch
(
const
String
&
model
,
bool
isBinary
=
true
,
bool
evaluate
=
true
);
/**
/**
* @brief Read deep learning network represented in one of the supported formats.
* @brief Read deep learning network represented in one of the supported formats.
...
...
modules/dnn/src/torch/torch_importer.cpp
浏览文件 @
840c892a
...
@@ -129,13 +129,15 @@ struct TorchImporter
...
@@ -129,13 +129,15 @@ struct TorchImporter
Module
*
rootModule
;
Module
*
rootModule
;
Module
*
curModule
;
Module
*
curModule
;
int
moduleCounter
;
int
moduleCounter
;
bool
testPhase
;
TorchImporter
(
String
filename
,
bool
isBinary
)
TorchImporter
(
String
filename
,
bool
isBinary
,
bool
evaluate
)
{
{
CV_TRACE_FUNCTION
();
CV_TRACE_FUNCTION
();
rootModule
=
curModule
=
NULL
;
rootModule
=
curModule
=
NULL
;
moduleCounter
=
0
;
moduleCounter
=
0
;
testPhase
=
evaluate
;
file
=
cv
::
Ptr
<
THFile
>
(
THDiskFile_new
(
filename
,
"r"
,
0
),
THFile_free
);
file
=
cv
::
Ptr
<
THFile
>
(
THDiskFile_new
(
filename
,
"r"
,
0
),
THFile_free
);
CV_Assert
(
file
&&
THFile_isOpened
(
file
));
CV_Assert
(
file
&&
THFile_isOpened
(
file
));
...
@@ -680,7 +682,8 @@ struct TorchImporter
...
@@ -680,7 +682,8 @@ struct TorchImporter
layerParams
.
blobs
.
push_back
(
tensorParams
[
"bias"
].
second
);
layerParams
.
blobs
.
push_back
(
tensorParams
[
"bias"
].
second
);
}
}
if
(
nnName
==
"InstanceNormalization"
)
bool
trainPhase
=
scalarParams
.
get
<
bool
>
(
"train"
,
false
);
if
(
nnName
==
"InstanceNormalization"
||
(
trainPhase
&&
!
testPhase
))
{
{
cv
::
Ptr
<
Module
>
mvnModule
(
new
Module
(
nnName
));
cv
::
Ptr
<
Module
>
mvnModule
(
new
Module
(
nnName
));
mvnModule
->
apiType
=
"MVN"
;
mvnModule
->
apiType
=
"MVN"
;
...
@@ -1243,18 +1246,18 @@ struct TorchImporter
...
@@ -1243,18 +1246,18 @@ struct TorchImporter
Mat
readTorchBlob
(
const
String
&
filename
,
bool
isBinary
)
Mat
readTorchBlob
(
const
String
&
filename
,
bool
isBinary
)
{
{
TorchImporter
importer
(
filename
,
isBinary
);
TorchImporter
importer
(
filename
,
isBinary
,
true
);
importer
.
readObject
();
importer
.
readObject
();
CV_Assert
(
importer
.
tensors
.
size
()
==
1
);
CV_Assert
(
importer
.
tensors
.
size
()
==
1
);
return
importer
.
tensors
.
begin
()
->
second
;
return
importer
.
tensors
.
begin
()
->
second
;
}
}
Net
readNetFromTorch
(
const
String
&
model
,
bool
isBinary
)
Net
readNetFromTorch
(
const
String
&
model
,
bool
isBinary
,
bool
evaluate
)
{
{
CV_TRACE_FUNCTION
();
CV_TRACE_FUNCTION
();
TorchImporter
importer
(
model
,
isBinary
);
TorchImporter
importer
(
model
,
isBinary
,
evaluate
);
Net
net
;
Net
net
;
importer
.
populateNet
(
net
);
importer
.
populateNet
(
net
);
return
net
;
return
net
;
...
...
modules/dnn/test/test_torch_importer.cpp
浏览文件 @
840c892a
...
@@ -73,7 +73,7 @@ class Test_Torch_layers : public DNNTestLayer
...
@@ -73,7 +73,7 @@ class Test_Torch_layers : public DNNTestLayer
{
{
public:
public:
void
runTorchNet
(
const
String
&
prefix
,
String
outLayerName
=
""
,
void
runTorchNet
(
const
String
&
prefix
,
String
outLayerName
=
""
,
bool
check2ndBlob
=
false
,
bool
isBinary
=
false
,
bool
check2ndBlob
=
false
,
bool
isBinary
=
false
,
bool
evaluate
=
true
,
double
l1
=
0.0
,
double
lInf
=
0.0
)
double
l1
=
0.0
,
double
lInf
=
0.0
)
{
{
String
suffix
=
(
isBinary
)
?
".dat"
:
".txt"
;
String
suffix
=
(
isBinary
)
?
".dat"
:
".txt"
;
...
@@ -84,7 +84,7 @@ public:
...
@@ -84,7 +84,7 @@ public:
checkBackend
(
backend
,
target
,
&
inp
,
&
outRef
);
checkBackend
(
backend
,
target
,
&
inp
,
&
outRef
);
Net
net
=
readNetFromTorch
(
_tf
(
prefix
+
"_net"
+
suffix
),
isBinary
);
Net
net
=
readNetFromTorch
(
_tf
(
prefix
+
"_net"
+
suffix
),
isBinary
,
evaluate
);
ASSERT_FALSE
(
net
.
empty
());
ASSERT_FALSE
(
net
.
empty
());
net
.
setPreferableBackend
(
backend
);
net
.
setPreferableBackend
(
backend
);
...
@@ -114,7 +114,7 @@ TEST_P(Test_Torch_layers, run_convolution)
...
@@ -114,7 +114,7 @@ TEST_P(Test_Torch_layers, run_convolution)
// Output reference values are in range [23.4018, 72.0181]
// Output reference values are in range [23.4018, 72.0181]
double
l1
=
(
target
==
DNN_TARGET_OPENCL_FP16
||
target
==
DNN_TARGET_MYRIAD
)
?
0.08
:
default_l1
;
double
l1
=
(
target
==
DNN_TARGET_OPENCL_FP16
||
target
==
DNN_TARGET_MYRIAD
)
?
0.08
:
default_l1
;
double
lInf
=
(
target
==
DNN_TARGET_OPENCL_FP16
||
target
==
DNN_TARGET_MYRIAD
)
?
0.42
:
default_lInf
;
double
lInf
=
(
target
==
DNN_TARGET_OPENCL_FP16
||
target
==
DNN_TARGET_MYRIAD
)
?
0.42
:
default_lInf
;
runTorchNet
(
"net_conv"
,
""
,
false
,
true
,
l1
,
lInf
);
runTorchNet
(
"net_conv"
,
""
,
false
,
true
,
true
,
l1
,
lInf
);
}
}
TEST_P
(
Test_Torch_layers
,
run_pool_max
)
TEST_P
(
Test_Torch_layers
,
run_pool_max
)
...
@@ -147,7 +147,7 @@ TEST_P(Test_Torch_layers, run_reshape)
...
@@ -147,7 +147,7 @@ TEST_P(Test_Torch_layers, run_reshape)
TEST_P
(
Test_Torch_layers
,
run_reshape_single_sample
)
TEST_P
(
Test_Torch_layers
,
run_reshape_single_sample
)
{
{
// Reference output values in range [14.4586, 18.4492].
// Reference output values in range [14.4586, 18.4492].
runTorchNet
(
"net_reshape_single_sample"
,
""
,
false
,
false
,
runTorchNet
(
"net_reshape_single_sample"
,
""
,
false
,
false
,
true
,
(
target
==
DNN_TARGET_MYRIAD
||
target
==
DNN_TARGET_OPENCL_FP16
)
?
0.0073
:
default_l1
,
(
target
==
DNN_TARGET_MYRIAD
||
target
==
DNN_TARGET_OPENCL_FP16
)
?
0.0073
:
default_l1
,
(
target
==
DNN_TARGET_MYRIAD
||
target
==
DNN_TARGET_OPENCL_FP16
)
?
0.025
:
default_lInf
);
(
target
==
DNN_TARGET_MYRIAD
||
target
==
DNN_TARGET_OPENCL_FP16
)
?
0.025
:
default_lInf
);
}
}
...
@@ -166,7 +166,7 @@ TEST_P(Test_Torch_layers, run_concat)
...
@@ -166,7 +166,7 @@ TEST_P(Test_Torch_layers, run_concat)
TEST_P
(
Test_Torch_layers
,
run_depth_concat
)
TEST_P
(
Test_Torch_layers
,
run_depth_concat
)
{
{
runTorchNet
(
"net_depth_concat"
,
""
,
false
,
true
,
0.0
,
runTorchNet
(
"net_depth_concat"
,
""
,
false
,
true
,
true
,
0.0
,
target
==
DNN_TARGET_OPENCL_FP16
?
0.021
:
0.0
);
target
==
DNN_TARGET_OPENCL_FP16
?
0.021
:
0.0
);
}
}
...
@@ -182,6 +182,7 @@ TEST_P(Test_Torch_layers, run_deconv)
...
@@ -182,6 +182,7 @@ TEST_P(Test_Torch_layers, run_deconv)
TEST_P
(
Test_Torch_layers
,
run_batch_norm
)
TEST_P
(
Test_Torch_layers
,
run_batch_norm
)
{
{
runTorchNet
(
"net_batch_norm"
,
""
,
false
,
true
);
runTorchNet
(
"net_batch_norm"
,
""
,
false
,
true
);
runTorchNet
(
"net_batch_norm_train"
,
""
,
false
,
true
,
false
);
}
}
TEST_P
(
Test_Torch_layers
,
net_prelu
)
TEST_P
(
Test_Torch_layers
,
net_prelu
)
...
@@ -216,7 +217,7 @@ TEST_P(Test_Torch_layers, net_conv_gemm_lrn)
...
@@ -216,7 +217,7 @@ TEST_P(Test_Torch_layers, net_conv_gemm_lrn)
{
{
if
(
backend
==
DNN_BACKEND_INFERENCE_ENGINE
&&
target
==
DNN_TARGET_MYRIAD
)
if
(
backend
==
DNN_BACKEND_INFERENCE_ENGINE
&&
target
==
DNN_TARGET_MYRIAD
)
throw
SkipTestException
(
""
);
throw
SkipTestException
(
""
);
runTorchNet
(
"net_conv_gemm_lrn"
,
""
,
false
,
true
,
runTorchNet
(
"net_conv_gemm_lrn"
,
""
,
false
,
true
,
true
,
target
==
DNN_TARGET_OPENCL_FP16
?
0.046
:
0.0
,
target
==
DNN_TARGET_OPENCL_FP16
?
0.046
:
0.0
,
target
==
DNN_TARGET_OPENCL_FP16
?
0.023
:
0.0
);
target
==
DNN_TARGET_OPENCL_FP16
?
0.023
:
0.0
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录