diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index 12ea7c5ff6c6b497c0d28be9f2eda44cfa68c45d..fd2a1e70e2cf0655aab4b663649c15c193465566 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -31,6 +31,7 @@ from ..data_feeder import check_variable_and_dtype, check_type import numpy as np import numbers import logging +import os import paddle.utils.deprecated as deprecated __all__ = [ @@ -1308,6 +1309,12 @@ class BatchNorm(layers.Layer): dtype=self._dtype) self._variance.stop_gradient = True + self._has_reserve_space = False + if data_layout == 'NHWC': + flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent') + if flag is not None and flag.lower() in ['true', '1']: + self._has_reserve_space = True + self._in_place = in_place self._data_layout = data_layout self._momentum = momentum @@ -1364,6 +1371,12 @@ class BatchNorm(layers.Layer): dtype=self._dtype, stop_gradient=True) saved_variance = self._helper.create_variable_for_type_inference( dtype=self._dtype, stop_gradient=True) + + reserve_space = None + if self._has_reserve_space: + reserve_space = self._helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.FP16, stop_gradient=True) + batch_norm_out = input if self._in_place else self._helper.create_variable_for_type_inference( self._dtype) @@ -1374,6 +1387,8 @@ class BatchNorm(layers.Layer): "SavedMean": [saved_mean], "SavedVariance": [saved_variance] } + if reserve_space is not None: + outputs["ReserveSpace"] = reserve_space self._helper.append_op( type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index a8c5b991b029192832c1efb33dec230a1929b871..14a30d15aee9d473d1790554fef56cf4a5821fdd 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -17,6 +17,7 @@ from __future__ import print_function import os import unittest import numpy as np +import paddle import paddle.fluid.core as core from paddle.fluid.op import Operator import paddle.fluid as fluid @@ -671,5 +672,18 @@ class TestDygraphBatchNormTrainableStats(unittest.TestCase): self.assertTrue(np.allclose(y1, y2)) +class TestDygraphBatchNormOpenReserveSpace(unittest.TestCase): + def test_reservespace(self): + with program_guard(Program(), Program()): + paddle.enable_static() + x = np.random.random(size=(3, 10, 3, 7)).astype('float32') + x = fluid.data(name='x', shape=x.shape, dtype=x.dtype) + # Set this FLAG, the BatchNorm API will pass "reserve_space" argument into batch_norm op. + os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1' + batch_norm = fluid.dygraph.BatchNorm(7, data_layout="NHWC") + hidden1 = batch_norm(x) + os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '0' + + if __name__ == '__main__': unittest.main()