未验证 提交 40019793 编写于 作者: H huangxu96 提交者: GitHub

Add ReserveSpace in dygraph batch_norm. (#29221)

* Add ReserveSpace in dygraph batch_norm.

* Add unittest for reservespace
上级 b781953e
......@@ -31,6 +31,7 @@ from ..data_feeder import check_variable_and_dtype, check_type
import numpy as np
import numbers
import logging
import os
import paddle.utils.deprecated as deprecated
__all__ = [
......@@ -1308,6 +1309,12 @@ class BatchNorm(layers.Layer):
dtype=self._dtype)
self._variance.stop_gradient = True
self._has_reserve_space = False
if data_layout == 'NHWC':
flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent')
if flag is not None and flag.lower() in ['true', '1']:
self._has_reserve_space = True
self._in_place = in_place
self._data_layout = data_layout
self._momentum = momentum
......@@ -1364,6 +1371,12 @@ class BatchNorm(layers.Layer):
dtype=self._dtype, stop_gradient=True)
saved_variance = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True)
reserve_space = None
if self._has_reserve_space:
reserve_space = self._helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.FP16, stop_gradient=True)
batch_norm_out = input if self._in_place else self._helper.create_variable_for_type_inference(
self._dtype)
......@@ -1374,6 +1387,8 @@ class BatchNorm(layers.Layer):
"SavedMean": [saved_mean],
"SavedVariance": [saved_variance]
}
if reserve_space is not None:
outputs["ReserveSpace"] = reserve_space
self._helper.append_op(
type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import os
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
......@@ -671,5 +672,18 @@ class TestDygraphBatchNormTrainableStats(unittest.TestCase):
self.assertTrue(np.allclose(y1, y2))
class TestDygraphBatchNormOpenReserveSpace(unittest.TestCase):
def test_reservespace(self):
with program_guard(Program(), Program()):
paddle.enable_static()
x = np.random.random(size=(3, 10, 3, 7)).astype('float32')
x = fluid.data(name='x', shape=x.shape, dtype=x.dtype)
# Set this FLAG, the BatchNorm API will pass "reserve_space" argument into batch_norm op.
os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1'
batch_norm = fluid.dygraph.BatchNorm(7, data_layout="NHWC")
hidden1 = batch_norm(x)
os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '0'
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册