From 400197930971632d970c73bb47d9fe39c8955d78 Mon Sep 17 00:00:00 2001 From: huangxu96 <46740794+huangxu96@users.noreply.github.com> Date: Thu, 10 Dec 2020 15:13:38 +0800 Subject: [PATCH] Add ReserveSpace in dygraph batch_norm. (#29221) * Add ReserveSpace in dygraph batch_norm. * Add unittest for reservespace --- python/paddle/fluid/dygraph/nn.py | 15 +++++++++++++++ .../fluid/tests/unittests/test_batch_norm_op.py | 14 ++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 12ea7c5ff6..fd2a1e70e2 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -31,6 +31,7 @@ from ..data_feeder import check_variable_and_dtype, check_type import numpy as np import numbers import logging +import os import paddle.utils.deprecated as deprecated __all__ = [ @@ -1308,6 +1309,12 @@ class BatchNorm(layers.Layer): dtype=self._dtype) self._variance.stop_gradient = True + self._has_reserve_space = False + if data_layout == 'NHWC': + flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent') + if flag is not None and flag.lower() in ['true', '1']: + self._has_reserve_space = True + self._in_place = in_place self._data_layout = data_layout self._momentum = momentum @@ -1364,6 +1371,12 @@ class BatchNorm(layers.Layer): dtype=self._dtype, stop_gradient=True) saved_variance = self._helper.create_variable_for_type_inference( dtype=self._dtype, stop_gradient=True) + + reserve_space = None + if self._has_reserve_space: + reserve_space = self._helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.FP16, stop_gradient=True) + batch_norm_out = input if self._in_place else self._helper.create_variable_for_type_inference( self._dtype) @@ -1374,6 +1387,8 @@ class BatchNorm(layers.Layer): "SavedMean": [saved_mean], "SavedVariance": [saved_variance] } + if reserve_space is not None: + outputs["ReserveSpace"] = reserve_space self._helper.append_op( type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index a8c5b991b0..14a30d15ae 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import os import unittest import numpy as np +import paddle import paddle.fluid.core as core from paddle.fluid.op import Operator import paddle.fluid as fluid @@ -671,5 +672,18 @@ class TestDygraphBatchNormTrainableStats(unittest.TestCase): self.assertTrue(np.allclose(y1, y2)) +class TestDygraphBatchNormOpenReserveSpace(unittest.TestCase): + def test_reservespace(self): + with program_guard(Program(), Program()): + paddle.enable_static() + x = np.random.random(size=(3, 10, 3, 7)).astype('float32') + x = fluid.data(name='x', shape=x.shape, dtype=x.dtype) + # Set this FLAG, the BatchNorm API will pass "reserve_space" argument into batch_norm op. + os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1' + batch_norm = fluid.dygraph.BatchNorm(7, data_layout="NHWC") + hidden1 = batch_norm(x) + os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '0' + + if __name__ == '__main__': unittest.main() -- GitLab