From 840c892abd8eb1cacc71a3f38330483ae38a02d9 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Fri, 21 Dec 2018 09:11:28 +0300 Subject: [PATCH] Batch normalization in training phase from Torch --- modules/dnn/include/opencv2/dnn/dnn.hpp | 7 ++++--- modules/dnn/src/torch/torch_importer.cpp | 13 ++++++++----- modules/dnn/test/test_torch_importer.cpp | 13 +++++++------ 3 files changed, 19 insertions(+), 14 deletions(-) diff --git a/modules/dnn/include/opencv2/dnn/dnn.hpp b/modules/dnn/include/opencv2/dnn/dnn.hpp index 2e34b4ae7d..c0e84b82fa 100644 --- a/modules/dnn/include/opencv2/dnn/dnn.hpp +++ b/modules/dnn/include/opencv2/dnn/dnn.hpp @@ -46,9 +46,9 @@ #include #if !defined CV_DOXYGEN && !defined CV_DNN_DONT_ADD_EXPERIMENTAL_NS -#define CV__DNN_EXPERIMENTAL_NS_BEGIN namespace experimental_dnn_34_v10 { +#define CV__DNN_EXPERIMENTAL_NS_BEGIN namespace experimental_dnn_34_v11 { #define CV__DNN_EXPERIMENTAL_NS_END } -namespace cv { namespace dnn { namespace experimental_dnn_34_v10 { } using namespace experimental_dnn_34_v10; }} +namespace cv { namespace dnn { namespace experimental_dnn_34_v11 { } using namespace experimental_dnn_34_v11; }} #else #define CV__DNN_EXPERIMENTAL_NS_BEGIN #define CV__DNN_EXPERIMENTAL_NS_END @@ -754,6 +754,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN * @brief Reads a network model stored in Torch7 framework's format. * @param model path to the file, dumped from Torch by using torch.save() function. * @param isBinary specifies whether the network was serialized in ascii mode or binary. + * @param evaluate specifies testing phase of network. If true, it's similar to evaluate() method in Torch. * @returns Net object. * * @note Ascii mode of Torch serializer is more preferable, because binary mode extensively use `long` type of C language, @@ -775,7 +776,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN * * Also some equivalents of these classes from cunn, cudnn, and fbcunn may be successfully imported. */ - CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true); + CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true, bool evaluate = true); /** * @brief Read deep learning network represented in one of the supported formats. diff --git a/modules/dnn/src/torch/torch_importer.cpp b/modules/dnn/src/torch/torch_importer.cpp index cbcc83bf6c..0ecb74dba5 100644 --- a/modules/dnn/src/torch/torch_importer.cpp +++ b/modules/dnn/src/torch/torch_importer.cpp @@ -129,13 +129,15 @@ struct TorchImporter Module *rootModule; Module *curModule; int moduleCounter; + bool testPhase; - TorchImporter(String filename, bool isBinary) + TorchImporter(String filename, bool isBinary, bool evaluate) { CV_TRACE_FUNCTION(); rootModule = curModule = NULL; moduleCounter = 0; + testPhase = evaluate; file = cv::Ptr(THDiskFile_new(filename, "r", 0), THFile_free); CV_Assert(file && THFile_isOpened(file)); @@ -680,7 +682,8 @@ struct TorchImporter layerParams.blobs.push_back(tensorParams["bias"].second); } - if (nnName == "InstanceNormalization") + bool trainPhase = scalarParams.get("train", false); + if (nnName == "InstanceNormalization" || (trainPhase && !testPhase)) { cv::Ptr mvnModule(new Module(nnName)); mvnModule->apiType = "MVN"; @@ -1243,18 +1246,18 @@ struct TorchImporter Mat readTorchBlob(const String &filename, bool isBinary) { - TorchImporter importer(filename, isBinary); + TorchImporter importer(filename, isBinary, true); importer.readObject(); CV_Assert(importer.tensors.size() == 1); return importer.tensors.begin()->second; } -Net readNetFromTorch(const String &model, bool isBinary) +Net readNetFromTorch(const String &model, bool isBinary, bool evaluate) { CV_TRACE_FUNCTION(); - TorchImporter importer(model, isBinary); + TorchImporter importer(model, isBinary, evaluate); Net net; importer.populateNet(net); return net; diff --git a/modules/dnn/test/test_torch_importer.cpp b/modules/dnn/test/test_torch_importer.cpp index 7fa0dc47ef..4abf5c2c00 100644 --- a/modules/dnn/test/test_torch_importer.cpp +++ b/modules/dnn/test/test_torch_importer.cpp @@ -73,7 +73,7 @@ class Test_Torch_layers : public DNNTestLayer { public: void runTorchNet(const String& prefix, String outLayerName = "", - bool check2ndBlob = false, bool isBinary = false, + bool check2ndBlob = false, bool isBinary = false, bool evaluate = true, double l1 = 0.0, double lInf = 0.0) { String suffix = (isBinary) ? ".dat" : ".txt"; @@ -84,7 +84,7 @@ public: checkBackend(backend, target, &inp, &outRef); - Net net = readNetFromTorch(_tf(prefix + "_net" + suffix), isBinary); + Net net = readNetFromTorch(_tf(prefix + "_net" + suffix), isBinary, evaluate); ASSERT_FALSE(net.empty()); net.setPreferableBackend(backend); @@ -114,7 +114,7 @@ TEST_P(Test_Torch_layers, run_convolution) // Output reference values are in range [23.4018, 72.0181] double l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.08 : default_l1; double lInf = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.42 : default_lInf; - runTorchNet("net_conv", "", false, true, l1, lInf); + runTorchNet("net_conv", "", false, true, true, l1, lInf); } TEST_P(Test_Torch_layers, run_pool_max) @@ -147,7 +147,7 @@ TEST_P(Test_Torch_layers, run_reshape) TEST_P(Test_Torch_layers, run_reshape_single_sample) { // Reference output values in range [14.4586, 18.4492]. - runTorchNet("net_reshape_single_sample", "", false, false, + runTorchNet("net_reshape_single_sample", "", false, false, true, (target == DNN_TARGET_MYRIAD || target == DNN_TARGET_OPENCL_FP16) ? 0.0073 : default_l1, (target == DNN_TARGET_MYRIAD || target == DNN_TARGET_OPENCL_FP16) ? 0.025 : default_lInf); } @@ -166,7 +166,7 @@ TEST_P(Test_Torch_layers, run_concat) TEST_P(Test_Torch_layers, run_depth_concat) { - runTorchNet("net_depth_concat", "", false, true, 0.0, + runTorchNet("net_depth_concat", "", false, true, true, 0.0, target == DNN_TARGET_OPENCL_FP16 ? 0.021 : 0.0); } @@ -182,6 +182,7 @@ TEST_P(Test_Torch_layers, run_deconv) TEST_P(Test_Torch_layers, run_batch_norm) { runTorchNet("net_batch_norm", "", false, true); + runTorchNet("net_batch_norm_train", "", false, true, false); } TEST_P(Test_Torch_layers, net_prelu) @@ -216,7 +217,7 @@ TEST_P(Test_Torch_layers, net_conv_gemm_lrn) { if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_MYRIAD) throw SkipTestException(""); - runTorchNet("net_conv_gemm_lrn", "", false, true, + runTorchNet("net_conv_gemm_lrn", "", false, true, true, target == DNN_TARGET_OPENCL_FP16 ? 0.046 : 0.0, target == DNN_TARGET_OPENCL_FP16 ? 0.023 : 0.0); } -- GitLab