Commit 840c892a authored by Dmitry Kurtaev

Batch normalization in training phase from Torch

Parent 09d8bbb1
@@ -46,9 +46,9 @@
#include <opencv2/core.hpp>
#if !defined CV_DOXYGEN && !defined CV_DNN_DONT_ADD_EXPERIMENTAL_NS
-#define CV__DNN_EXPERIMENTAL_NS_BEGIN namespace experimental_dnn_34_v10 {
+#define CV__DNN_EXPERIMENTAL_NS_BEGIN namespace experimental_dnn_34_v11 {
#define CV__DNN_EXPERIMENTAL_NS_END }
-namespace cv { namespace dnn { namespace experimental_dnn_34_v10 { } using namespace experimental_dnn_34_v10; }}
+namespace cv { namespace dnn { namespace experimental_dnn_34_v11 { } using namespace experimental_dnn_34_v11; }}
#else
#define CV__DNN_EXPERIMENTAL_NS_BEGIN
#define CV__DNN_EXPERIMENTAL_NS_END
@@ -754,6 +754,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
* @brief Reads a network model stored in <a href="http://torch.ch">Torch7</a> framework's format.
* @param model path to the file, dumped from Torch by using torch.save() function.
* @param isBinary specifies whether the network was serialized in ascii mode or binary.
+ * @param evaluate specifies the testing phase of the network. If true, it's similar to the evaluate() method in Torch.
* @returns Net object.
*
* @note ASCII mode of the Torch serializer is preferable because binary mode extensively uses the `long` type of the C language,
@@ -775,7 +776,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
*
* Also some equivalents of these classes from cunn, cudnn, and fbcunn may be successfully imported.
*/
-CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true);
+CV_EXPORTS_W Net readNetFromTorch(const String &model, bool isBinary = true, bool evaluate = true);
/**
* @brief Read deep learning network represented in one of the supported formats.
......
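Not part of the commit: a minimal usage sketch of the extended loader API documented above. The model path and blob shape are placeholders, assuming any Torch7 file that contains BatchNormalization modules.

#include <opencv2/core.hpp>
#include <opencv2/dnn.hpp>

int main()
{
    using namespace cv;
    using namespace cv::dnn;

    // Placeholder path to a Torch7 model serialized with torch.save().
    const String model = "model.t7";

    // Default behaviour (evaluate = true): batch normalization uses the
    // stored running mean/variance, matching Torch's evaluate() mode.
    Net testNet = readNetFromTorch(model, /*isBinary=*/true, /*evaluate=*/true);

    // evaluate = false: modules serialized with train == true keep their
    // training-phase behaviour, i.e. statistics are computed from the input.
    Net trainNet = readNetFromTorch(model, /*isBinary=*/true, /*evaluate=*/false);

    // Run the training-phase network on a dummy NCHW blob.
    int dims[] = {1, 3, 224, 224};
    Mat blob(4, dims, CV_32F, Scalar(0));
    trainNet.setInput(blob);
    Mat out = trainNet.forward();
    (void)testNet; (void)out;
    return 0;
}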
@@ -129,13 +129,15 @@ struct TorchImporter
Module *rootModule;
Module *curModule;
int moduleCounter;
+bool testPhase;
-TorchImporter(String filename, bool isBinary)
+TorchImporter(String filename, bool isBinary, bool evaluate)
{
CV_TRACE_FUNCTION();
rootModule = curModule = NULL;
moduleCounter = 0;
+testPhase = evaluate;
file = cv::Ptr<THFile>(THDiskFile_new(filename, "r", 0), THFile_free);
CV_Assert(file && THFile_isOpened(file));
@@ -680,7 +682,8 @@ struct TorchImporter
layerParams.blobs.push_back(tensorParams["bias"].second);
}
-if (nnName == "InstanceNormalization")
+bool trainPhase = scalarParams.get<bool>("train", false);
+if (nnName == "InstanceNormalization" || (trainPhase && !testPhase))
{
cv::Ptr<Module> mvnModule(new Module(nnName));
mvnModule->apiType = "MVN";
@@ -1243,18 +1246,18 @@ struct TorchImporter
Mat readTorchBlob(const String &filename, bool isBinary)
{
-TorchImporter importer(filename, isBinary);
+TorchImporter importer(filename, isBinary, true);
importer.readObject();
CV_Assert(importer.tensors.size() == 1);
return importer.tensors.begin()->second;
}
-Net readNetFromTorch(const String &model, bool isBinary)
+Net readNetFromTorch(const String &model, bool isBinary, bool evaluate)
{
CV_TRACE_FUNCTION();
-TorchImporter importer(model, isBinary);
+TorchImporter importer(model, isBinary, evaluate);
Net net;
importer.populateNet(net);
return net;
......
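Not part of the commit: the phase decision added to the importer, restated as a standalone sketch. The helper name importAsMVN is hypothetical, not a function in torch_importer.cpp; it only mirrors the condition shown in the hunk above.

#include <string>

// Hypothetical restatement of the importer's branch: a module is mapped to
// an MVN layer either because it is InstanceNormalization, or because it was
// serialized with train == true and the caller did not request the test
// phase (readNetFromTorch(..., evaluate = false)).
static bool importAsMVN(const std::string& nnName, bool trainFlag, bool evaluate)
{
    const bool testPhase = evaluate;  // corresponds to TorchImporter::testPhase
    return nnName == "InstanceNormalization" || (trainFlag && !testPhase);
}

In the default test phase the condition stays false for BatchNormalization, so the running statistics loaded just above remain in use.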
@@ -73,7 +73,7 @@ class Test_Torch_layers : public DNNTestLayer
{
public:
void runTorchNet(const String& prefix, String outLayerName = "",
-bool check2ndBlob = false, bool isBinary = false,
+bool check2ndBlob = false, bool isBinary = false, bool evaluate = true,
double l1 = 0.0, double lInf = 0.0)
{
String suffix = (isBinary) ? ".dat" : ".txt";
@@ -84,7 +84,7 @@ public:
checkBackend(backend, target, &inp, &outRef);
-Net net = readNetFromTorch(_tf(prefix + "_net" + suffix), isBinary);
+Net net = readNetFromTorch(_tf(prefix + "_net" + suffix), isBinary, evaluate);
ASSERT_FALSE(net.empty());
net.setPreferableBackend(backend);
@@ -114,7 +114,7 @@ TEST_P(Test_Torch_layers, run_convolution)
// Output reference values are in range [23.4018, 72.0181]
double l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.08 : default_l1;
double lInf = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.42 : default_lInf;
runTorchNet("net_conv", "", false, true, l1, lInf);
runTorchNet("net_conv", "", false, true, true, l1, lInf);
}
TEST_P(Test_Torch_layers, run_pool_max)
@@ -147,7 +147,7 @@ TEST_P(Test_Torch_layers, run_reshape)
TEST_P(Test_Torch_layers, run_reshape_single_sample)
{
// Reference output values in range [14.4586, 18.4492].
runTorchNet("net_reshape_single_sample", "", false, false,
runTorchNet("net_reshape_single_sample", "", false, false, true,
(target == DNN_TARGET_MYRIAD || target == DNN_TARGET_OPENCL_FP16) ? 0.0073 : default_l1,
(target == DNN_TARGET_MYRIAD || target == DNN_TARGET_OPENCL_FP16) ? 0.025 : default_lInf);
}
@@ -166,7 +166,7 @@ TEST_P(Test_Torch_layers, run_concat)
TEST_P(Test_Torch_layers, run_depth_concat)
{
runTorchNet("net_depth_concat", "", false, true, 0.0,
runTorchNet("net_depth_concat", "", false, true, true, 0.0,
target == DNN_TARGET_OPENCL_FP16 ? 0.021 : 0.0);
}
@@ -182,6 +182,7 @@ TEST_P(Test_Torch_layers, run_deconv)
TEST_P(Test_Torch_layers, run_batch_norm)
{
runTorchNet("net_batch_norm", "", false, true);
runTorchNet("net_batch_norm_train", "", false, true, false);
}
TEST_P(Test_Torch_layers, net_prelu)
@@ -216,7 +217,7 @@ TEST_P(Test_Torch_layers, net_conv_gemm_lrn)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE && target == DNN_TARGET_MYRIAD)
throw SkipTestException("");
runTorchNet("net_conv_gemm_lrn", "", false, true,
runTorchNet("net_conv_gemm_lrn", "", false, true, true,
target == DNN_TARGET_OPENCL_FP16 ? 0.046 : 0.0,
target == DNN_TARGET_OPENCL_FP16 ? 0.023 : 0.0);
}
......