Commit 840c892a authored by Dmitry Kurtaev

Batch normalization in training phase from Torch

Parent 09d8bbb1
@@ -46,9 +46,9 @@
 #include <opencv2/core.hpp>

 #if !defined CV_DOXYGEN && !defined CV_DNN_DONT_ADD_EXPERIMENTAL_NS
-#define CV__DNN_EXPERIMENTAL_NS_BEGIN namespace experimental_dnn_34_v10 {
+#define CV__DNN_EXPERIMENTAL_NS_BEGIN namespace experimental_dnn_34_v11 {
 #define CV__DNN_EXPERIMENTAL_NS_END }
-namespace cv { namespace dnn { namespace experimental_dnn_34_v10 { } using namespace experimental_dnn_34_v10; }}
+namespace cv { namespace dnn { namespace experimental_dnn_34_v11 { } using namespace experimental_dnn_34_v11; }}
 #else
 #define CV__DNN_EXPERIMENTAL_NS_BEGIN
 #define CV__DNN_EXPERIMENTAL_NS_END
@@ -754,6 +754,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
      * @brief Reads a network model stored in <a href="http://torch.ch">Torch7</a> framework's format.
      * @param model path to the file, dumped from Torch by using torch.save() function.
      * @param isBinary specifies whether the network was serialized in ascii mode or binary.
+     * @param evaluate specifies testing phase of network. If true, it's similar to evaluate() method in Torch.
      * @returns Net object.
      *
      * @note Ascii mode of Torch serializer is more preferable, because binary mode extensively use `long` type of C language,
@@ -775,7 +776,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
      *
      * Also some equivalents of these classes from cunn, cudnn, and fbcunn may be successfully imported.
      */
-    CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true);
+    CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true, bool evaluate = true);

     /**
      * @brief Read deep learning network represented in one of the supported formats.
...
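For reference, a minimal usage sketch of the extended API declared above. The model path is a placeholder; the third argument defaults to true, so existing two-argument calls keep the previous test-phase behaviour:

```cpp
#include <opencv2/dnn.hpp>

int main()
{
    // Default: import in test phase, as before this change.
    cv::dnn::Net netEval = cv::dnn::readNetFromTorch("model.t7", /*isBinary=*/true);

    // New: keep the training-phase behaviour of BatchNormalization modules
    // that were serialized while the Torch net was still in training mode.
    cv::dnn::Net netTrain = cv::dnn::readNetFromTorch("model.t7", /*isBinary=*/true,
                                                      /*evaluate=*/false);
    return (netEval.empty() || netTrain.empty()) ? 1 : 0;
}
```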
@@ -129,13 +129,15 @@ struct TorchImporter
     Module *rootModule;
     Module *curModule;
     int moduleCounter;
+    bool testPhase;

-    TorchImporter(String filename, bool isBinary)
+    TorchImporter(String filename, bool isBinary, bool evaluate)
     {
         CV_TRACE_FUNCTION();

         rootModule = curModule = NULL;
         moduleCounter = 0;
+        testPhase = evaluate;

         file = cv::Ptr<THFile>(THDiskFile_new(filename, "r", 0), THFile_free);
         CV_Assert(file && THFile_isOpened(file));
@@ -680,7 +682,8 @@ struct TorchImporter
                     layerParams.blobs.push_back(tensorParams["bias"].second);
                 }

-                if (nnName == "InstanceNormalization")
+                bool trainPhase = scalarParams.get<bool>("train", false);
+                if (nnName == "InstanceNormalization" || (trainPhase && !testPhase))
                 {
                     cv::Ptr<Module> mvnModule(new Module(nnName));
                     mvnModule->apiType = "MVN";
@@ -1243,18 +1246,18 @@ struct TorchImporter
 Mat readTorchBlob(const String &filename, bool isBinary)
 {
-    TorchImporter importer(filename, isBinary);
+    TorchImporter importer(filename, isBinary, true);
     importer.readObject();
     CV_Assert(importer.tensors.size() == 1);

     return importer.tensors.begin()->second;
 }

-Net readNetFromTorch(const String &model, bool isBinary)
+Net readNetFromTorch(const String &model, bool isBinary, bool evaluate)
 {
     CV_TRACE_FUNCTION();

-    TorchImporter importer(model, isBinary);
+    TorchImporter importer(model, isBinary, evaluate);
     Net net;
     importer.populateNet(net);

     return net;
...
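The importer change above turns on the training-phase path whenever the serialized module carries train = true and the caller did not request evaluate mode. As a recap of the standard formulas (not copied from the patch): evaluation-phase batch normalization uses the stored running statistics, while training-phase batch normalization uses the statistics of the current batch, which is what a mean-variance normalization (MVN) step combined with the layer's learned scale and shift amounts to:

```latex
% evaluation (test) phase: stored running statistics \hat{\mu}, \hat{\sigma}^2
y = \gamma \, \frac{x - \hat{\mu}}{\sqrt{\hat{\sigma}^2 + \varepsilon}} + \beta
% training phase: per-batch statistics \mu_B, \sigma_B^2, i.e. MVN followed by the affine scale/shift
y = \gamma \, \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \varepsilon}} + \beta
```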
@@ -73,7 +73,7 @@ class Test_Torch_layers : public DNNTestLayer
 {
 public:
     void runTorchNet(const String& prefix, String outLayerName = "",
-                     bool check2ndBlob = false, bool isBinary = false,
+                     bool check2ndBlob = false, bool isBinary = false, bool evaluate = true,
                      double l1 = 0.0, double lInf = 0.0)
     {
         String suffix = (isBinary) ? ".dat" : ".txt";
@@ -84,7 +84,7 @@ public:
         checkBackend(backend, target, &inp, &outRef);

-        Net net = readNetFromTorch(_tf(prefix + "_net" + suffix), isBinary);
+        Net net = readNetFromTorch(_tf(prefix + "_net" + suffix), isBinary, evaluate);
         ASSERT_FALSE(net.empty());

         net.setPreferableBackend(backend);
@@ -114,7 +114,7 @@ TEST_P(Test_Torch_layers, run_convolution)
     // Output reference values are in range [23.4018, 72.0181]
     double l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.08 : default_l1;
     double lInf = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.42 : default_lInf;
-    runTorchNet("net_conv", "", false, true, l1, lInf);
+    runTorchNet("net_conv", "", false, true, true, l1, lInf);
 }

 TEST_P(Test_Torch_layers, run_pool_max)
@@ -147,7 +147,7 @@ TEST_P(Test_Torch_layers, run_reshape)
 TEST_P(Test_Torch_layers, run_reshape_single_sample)
 {
     // Reference output values in range [14.4586, 18.4492].
-    runTorchNet("net_reshape_single_sample", "", false, false,
+    runTorchNet("net_reshape_single_sample", "", false, false, true,
                 (target == DNN_TARGET_MYRIAD || target == DNN_TARGET_OPENCL_FP16) ? 0.0073 : default_l1,
                 (target == DNN_TARGET_MYRIAD || target == DNN_TARGET_OPENCL_FP16) ? 0.025 : default_lInf);
 }
@@ -166,7 +166,7 @@ TEST_P(Test_Torch_layers, run_concat)
 TEST_P(Test_Torch_layers, run_depth_concat)
 {
-    runTorchNet("net_depth_concat", "", false, true, 0.0,
+    runTorchNet("net_depth_concat", "", false, true, true, 0.0,
                 target == DNN_TARGET_OPENCL_FP16 ? 0.021 : 0.0);
 }
@@ -182,6 +182,7 @@ TEST_P(Test_Torch_layers, run_deconv)
 TEST_P(Test_Torch_layers, run_batch_norm)
 {
     runTorchNet("net_batch_norm", "", false, true);
+    runTorchNet("net_batch_norm_train", "", false, true, false);
 }

 TEST_P(Test_Torch_layers, net_prelu)
@@ -216,7 +217,7 @@ TEST_P(Test_Torch_layers, net_conv_gemm_lrn)
 {
     if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_MYRIAD)
         throw SkipTestException("");
-    runTorchNet("net_conv_gemm_lrn", "", false, true,
+    runTorchNet("net_conv_gemm_lrn", "", false, true, true,
                 target == DNN_TARGET_OPENCL_FP16 ? 0.046 : 0.0,
                 target == DNN_TARGET_OPENCL_FP16 ? 0.023 : 0.0);
 }
...
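To illustrate what the new net_batch_norm_train test exercises, here is a hedged, self-contained sketch that imports the same model in both phases and compares the outputs. The file name, image size, and input values are placeholders, not taken from this commit:

```cpp
#include <opencv2/core.hpp>
#include <opencv2/dnn.hpp>
#include <iostream>

int main()
{
    // Build an arbitrary NCHW input blob from a random 3-channel image.
    cv::Mat img(8, 8, CV_32FC3);
    cv::randu(img, cv::Scalar::all(0), cv::Scalar::all(1));
    cv::Mat input = cv::dnn::blobFromImage(img);

    // Same serialized Torch model, imported in test phase and in training phase.
    cv::dnn::Net netEval  = cv::dnn::readNetFromTorch("net_batch_norm.t7", true, /*evaluate=*/true);
    cv::dnn::Net netTrain = cv::dnn::readNetFromTorch("net_batch_norm.t7", true, /*evaluate=*/false);

    netEval.setInput(input);
    netTrain.setInput(input);

    // The outputs generally differ: test phase normalizes with the stored
    // running statistics, training phase with the statistics of this blob.
    cv::Mat outEval = netEval.forward();
    cv::Mat outTrain = netTrain.forward();
    std::cout << "L2 difference: " << cv::norm(outEval, outTrain) << std::endl;
    return 0;
}
```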