diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py
index 12ea7c5ff6c6b497c0d28be9f2eda44cfa68c45d..fd2a1e70e2cf0655aab4b663649c15c193465566 100644
--- a/python/paddle/fluid/dygraph/nn.py
+++ b/python/paddle/fluid/dygraph/nn.py
@@ -31,6 +31,7 @@ from ..data_feeder import check_variable_and_dtype, check_type
 import numpy as np
 import numbers
 import logging
+import os
 import paddle.utils.deprecated as deprecated
 
 __all__ = [
@@ -1308,6 +1309,12 @@ class BatchNorm(layers.Layer):
             dtype=self._dtype)
         self._variance.stop_gradient = True
 
+        self._has_reserve_space = False
+        if data_layout == 'NHWC':
+            flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent')
+            if flag is not None and flag.lower() in ['true', '1']:
+                self._has_reserve_space = True
+
         self._in_place = in_place
         self._data_layout = data_layout
         self._momentum = momentum
@@ -1364,6 +1371,12 @@ class BatchNorm(layers.Layer):
             dtype=self._dtype, stop_gradient=True)
         saved_variance = self._helper.create_variable_for_type_inference(
             dtype=self._dtype, stop_gradient=True)
+
+        reserve_space = None
+        if self._has_reserve_space:
+            reserve_space = self._helper.create_variable_for_type_inference(
+                dtype=core.VarDesc.VarType.FP16, stop_gradient=True)
+
         batch_norm_out = input if self._in_place else self._helper.create_variable_for_type_inference(
             self._dtype)
 
@@ -1374,6 +1387,8 @@
             "SavedMean": [saved_mean],
             "SavedVariance": [saved_variance]
         }
+        if reserve_space is not None:
+            outputs["ReserveSpace"] = reserve_space
 
         self._helper.append_op(
             type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
index a8c5b991b029192832c1efb33dec230a1929b871..14a30d15aee9d473d1790554fef56cf4a5821fdd 100644
--- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
@@ -17,6 +17,7 @@ from __future__ import print_function
 import os
 import unittest
 import numpy as np
+import paddle
 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
 import paddle.fluid as fluid
@@ -671,5 +672,18 @@ class TestDygraphBatchNormTrainableStats(unittest.TestCase):
         self.assertTrue(np.allclose(y1, y2))
 
 
+class TestDygraphBatchNormOpenReserveSpace(unittest.TestCase):
+    def test_reservespace(self):
+        with program_guard(Program(), Program()):
+            paddle.enable_static()
+            x = np.random.random(size=(3, 10, 3, 7)).astype('float32')
+            x = fluid.data(name='x', shape=x.shape, dtype=x.dtype)
+            # With this FLAG set, the BatchNorm API attaches a "ReserveSpace" output to the batch_norm op.
+            os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1'
+            batch_norm = fluid.dygraph.BatchNorm(7, data_layout="NHWC")
+            hidden1 = batch_norm(x)
+            os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '0'
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/nn/functional/pooling.py b/python/paddle/nn/functional/pooling.py
index 9d89b4236113526c849f403e93a1837728f10909..5f3642710ae0adfbdb53f7b5adc81c8b8395a924 100755
--- a/python/paddle/nn/functional/pooling.py
+++ b/python/paddle/nn/functional/pooling.py
@@ -692,7 +692,8 @@ def max_pool2d(x,
                                           return_mask=True)
           # out.shape [1, 3, 16, 16], max_indices.shape [1, 3, 16, 16],
     """
-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'max_pool2d')
+    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
+                             'max_pool2d')
     kernel_size = utils.convert_to_list(kernel_size, 2, 'pool_size')
     if stride is None:
         stride = kernel_size
@@ -933,7 +934,8 @@ def adaptive_avg_pool1d(x, output_size, name=None):
           # pool_out shape: [1, 3, 16])
     """
     pool_type = 'avg'
-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'adaptive_pool2d')
+    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
+                             'adaptive_pool2d')
     _check_input(x, 3)
     check_type(output_size, 'pool_size', (int), 'adaptive_pool1d')
 
@@ -1015,7 +1017,7 @@ def adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None):
         # out.shape is [2, 3, 3, 3]
     """
     if not in_dygraph_mode():
-        check_variable_and_dtype(x, 'x', ['float32', 'float64'],
+        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                  'adaptive_avg_pool2d')
         check_type(data_format, 'data_format', str, 'adaptive_avg_pool2d')
 
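Note (not part of the patch): a minimal usage sketch of the behavior this change enables, assuming a Paddle build where FLAGS_cudnn_batchnorm_spatial_persistent is honored. With the flag set and data_layout="NHWC", the batch_norm op appended by fluid.dygraph.BatchNorm is expected to carry a "ReserveSpace" output; the program/op inspection below is illustrative and mirrors the unit test above rather than any documented API contract.

import os
import paddle
import paddle.fluid as fluid

# Hypothetical sketch: enable the flag before building the layer.
os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1'
paddle.enable_static()

main_prog = fluid.Program()
with fluid.program_guard(main_prog, fluid.Program()):
    # NHWC input with 7 channels in the last dimension, matching BatchNorm(7).
    x = fluid.data(name='x', shape=[3, 10, 3, 7], dtype='float32')
    bn = fluid.dygraph.BatchNorm(7, data_layout="NHWC")
    y = bn(x)

# The appended batch_norm op should now expose a "ReserveSpace" output slot.
bn_ops = [op for op in main_prog.global_block().ops if op.type == 'batch_norm']
print('ReserveSpace' in bn_ops[0].output_names)  # expected: True while the flag is set

os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '0'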