diff --git a/modules/dnn/src/layers/batch_norm_layer.cpp b/modules/dnn/src/layers/batch_norm_layer.cpp
index 109f141352de80cfc6c3c4e4626472ecdd34b35b..ded8ae051db3b5f748f0a0f4ab216231d9905198 100644
--- a/modules/dnn/src/layers/batch_norm_layer.cpp
+++ b/modules/dnn/src/layers/batch_norm_layer.cpp
@@ -94,6 +94,15 @@ public:
             dstWeightsData[i] = w;
             dstBiasData[i] = (hasBias ? biasData[i] : 0.0f) - w * meanData[i] * varMeanScale;
         }
+        // We will use blobs to store the original weights and bias to restore them in case of reinitialization.
+        weights_.copyTo(blobs[0].reshape(1, 1));
+        bias_.copyTo(blobs[1].reshape(1, 1));
+    }
+
+    virtual void finalize(InputArrayOfArrays, OutputArrayOfArrays) CV_OVERRIDE
+    {
+        blobs[0].reshape(1, 1).copyTo(weights_);
+        blobs[1].reshape(1, 1).copyTo(bias_);
     }
 
     void getScaleShift(Mat& scale, Mat& shift) const CV_OVERRIDE
diff --git a/modules/dnn/test/test_layers.cpp b/modules/dnn/test/test_layers.cpp
index 8124764fc136878e3478947d11cd8d3c4f7bb7ba..c31b9f3720312d3179c157fa573449016b9b2c7a 100644
--- a/modules/dnn/test/test_layers.cpp
+++ b/modules/dnn/test/test_layers.cpp
@@ -1780,4 +1780,61 @@ TEST_P(Layer_Test_Slice, variable_input_shape)
 
 INSTANTIATE_TEST_CASE_P(/**/, Layer_Test_Slice, dnnBackendsAndTargets());
 
+typedef testing::TestWithParam<tuple<Backend, Target> > Layer_Test_BatchNorm;
+TEST_P(Layer_Test_BatchNorm, fusion)
+{
+    // This test reinitializes the network by forwarding an input with a different batch size.
+    // We check that the BatchNorm layer weights are restored after fusion.
+    int backendId = get<0>(GetParam());
+    int targetId = get<1>(GetParam());
+    const int ch = 4;
+
+    Mat mean(1, ch, CV_32F), var(1, ch, CV_32F), weights(1, ch, CV_32F);
+    randu(mean, 0, 1);
+    randu(var, 0, 1);
+    randu(weights, 0, 1);
+
+    Net net;
+    {
+        LayerParams lp;
+        lp.type = "BatchNorm";
+        lp.name = "bn";
+        lp.set("has_weight", false);
+        lp.set("has_bias", false);
+        lp.blobs.push_back(mean);
+        lp.blobs.push_back(var);
+        net.addLayerToPrev(lp.name, lp.type, lp);
+    }
+    {
+        LayerParams lp;
+        lp.type = "Scale";
+        lp.name = "scale";
+        lp.set("has_bias", false);
+        lp.blobs.push_back(weights);
+        net.addLayerToPrev(lp.name, lp.type, lp);
+    }
+
+    Mat inp(4, 5, CV_32FC(ch));
+    randu(inp, 0, 1);
+
+    net.setPreferableBackend(backendId);
+    net.setPreferableTarget(targetId);
+
+    net.setInput(blobFromImage(inp));
+    Mat ref = net.forward();
+
+    net.setInput(blobFromImages(std::vector<Mat>(2, inp)));
+    Mat out = net.forward();
+
+    for (int i = 0; i < 2; ++i)
+    {
+        std::vector<Range> ranges(4, Range::all());
+        ranges[0].start = i;
+        ranges[0].end = i + 1;
+        normAssert(out(ranges), ref);
+    }
+}
+
+INSTANTIATE_TEST_CASE_P(/**/, Layer_Test_BatchNorm, dnnBackendsAndTargets());
+
 }} // namespace
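
For reference, the loop that fills dstWeightsData / dstBiasData above performs the standard batch-norm folding: each channel's output gamma * (x - mean) / sqrt(var + eps) + beta collapses into a single affine w * x + b, which is why the original weights_/bias_ must be stashed in blobs and restored in finalize() before a reinitialized network refuses them. Below is a minimal standalone sketch of that arithmetic, not part of the patch, assuming the optional Caffe scale factor varMeanScale equals 1; all constant values are illustrative.

#include <cmath>
#include <cstdio>

int main()
{
    const float epsilon = 1e-5f;            // illustrative value, not from the patch
    const float mean = 0.5f, var = 0.25f;   // running statistics for one channel
    const float gamma = 2.0f, beta = 0.5f;  // optional learned scale/shift

    // Matches the patch's dstWeightsData[i] / dstBiasData[i] when varMeanScale == 1:
    // w folds gamma with the variance, b folds beta with the mean.
    const float w = gamma / std::sqrt(var + epsilon);
    const float b = beta - w * mean;

    // y = w * x + b is equivalent to gamma * (x - mean) / sqrt(var + epsilon) + beta,
    // so recomputing w and b after fusion requires the original (unfused) parameters.
    const float x = 2.0f;
    std::printf("fused: %f  reference: %f\n",
                w * x + b,
                gamma * (x - mean) / std::sqrt(var + epsilon) + beta);
    return 0;
}

Both printed values agree, confirming the fold; after a Scale layer is fused into BatchNorm, the fused w and b are the only parameters left in weights_/bias_, which is exactly the state the new finalize() override undoes on reinitialization.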