未验证 提交 b059fb95 编写于 作者: Q qingqing01 提交者: GitHub

Add trainable_statistics in attr for batch_norm. (#24072)

* Add trainable_statistics in attr for batch_norm
* Unifying behavior of dynamic graph and static graph
上级 7dac3226
...@@ -33,7 +33,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -33,7 +33,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "BatchNorm"); OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "BatchNorm");
bool is_test = ctx->Attrs().Get<bool>("is_test"); bool is_test = ctx->Attrs().Get<bool>("is_test");
if (!is_test) { bool trainable_stats = ctx->Attrs().Get<bool>("trainable_statistics");
bool test_mode = is_test && (!trainable_stats);
if (!test_mode) {
OP_INOUT_CHECK(ctx->HasOutput("MeanOut"), "Output", "MeanOut", "BatchNorm"); OP_INOUT_CHECK(ctx->HasOutput("MeanOut"), "Output", "MeanOut", "BatchNorm");
OP_INOUT_CHECK(ctx->HasOutput("VarianceOut"), "Output", "VarianceOut", OP_INOUT_CHECK(ctx->HasOutput("VarianceOut"), "Output", "VarianceOut",
"BatchNorm"); "BatchNorm");
...@@ -258,7 +260,11 @@ void BatchNormOpMaker::Make() { ...@@ -258,7 +260,11 @@ void BatchNormOpMaker::Make() {
"global mean and variance are also used during train time, " "global mean and variance are also used during train time, "
"the BN acts as scaling and shiffting.") "the BN acts as scaling and shiffting.")
.SetDefault(false); .SetDefault(false);
AddAttr<bool>("trainable_statistics",
"(bool, default false) Whether to calculate mean and variance "
"in test mode. If setting true in test mode, mean and variace "
"will be calculated by current batch statistics.")
.SetDefault(false);
AddComment(R"DOC( AddComment(R"DOC(
Batch Normalization. Batch Normalization.
...@@ -281,8 +287,10 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -281,8 +287,10 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
float momentum = ctx.Attr<float>("momentum"); float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats"); const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
bool test_mode = is_test && (!trainable_stats);
bool global_stats = is_test || use_global_stats; bool global_stats = test_mode || use_global_stats;
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = const DataLayout data_layout =
......
...@@ -47,10 +47,13 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -47,10 +47,13 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
float momentum = ctx.Attr<float>("momentum"); float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats"); const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str); framework::StringToDataLayout(data_layout_str);
bool test_mode = is_test && (!trainable_stats);
// Get the size for each dimension. // Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width] // NCHW [batch_size, in_channels, in_height, in_width]
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
...@@ -66,7 +69,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -66,7 +69,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
auto dtype = platform::CudnnDataType<T>::type; auto dtype = platform::CudnnDataType<T>::type;
const bool fast_nhwc_batch_norm = const bool fast_nhwc_batch_norm =
is_test || test_mode ||
(dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent); (dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent);
auto compute_format = auto compute_format =
...@@ -133,7 +136,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -133,7 +136,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDeriveBNTensorDescriptor( platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, bn_param_desc_, data_desc_,
is_test ? CUDNN_BATCHNORM_SPATIAL : mode_)); test_mode ? CUDNN_BATCHNORM_SPATIAL : mode_));
const auto *scale = ctx.Input<Tensor>("Scale"); const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias"); const auto *bias = ctx.Input<Tensor>("Bias");
...@@ -143,7 +146,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -143,7 +146,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
// Now, depending on whether we are running test or not, we have two paths. // Now, depending on whether we are running test or not, we have two paths.
if (is_test || use_global_stats) { if (test_mode || use_global_stats) {
// only when test we use input to do computation. // only when test we use input to do computation.
const auto *est_mean = ctx.Input<Tensor>("Mean"); const auto *est_mean = ctx.Input<Tensor>("Mean");
const auto *est_var = ctx.Input<Tensor>("Variance"); const auto *est_var = ctx.Input<Tensor>("Variance");
......
...@@ -120,8 +120,10 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -120,8 +120,10 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats"); const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu"); const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
bool test_mode = is_test && (!trainable_stats);
bool global_stats = is_test || use_global_stats; bool global_stats = test_mode || use_global_stats;
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
...@@ -156,7 +158,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -156,7 +158,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto flags = mkldnn::normalization_flags::use_scale_shift; // 001 auto flags = mkldnn::normalization_flags::use_scale_shift; // 001
if (global_stats) if (global_stats)
flags |= mkldnn::normalization_flags::use_global_stats; // 010 flags |= mkldnn::normalization_flags::use_global_stats; // 010
if (fuse_with_relu && is_test) if (fuse_with_relu && test_mode)
flags |= mkldnn::normalization_flags::fuse_norm_relu; // 100 flags |= mkldnn::normalization_flags::fuse_norm_relu; // 100
BatchNormMKLDNNHandler<T> handler( BatchNormMKLDNNHandler<T> handler(
......
...@@ -28,6 +28,7 @@ class SyncBatchNormKernel<platform::CUDADeviceContext, T> ...@@ -28,6 +28,7 @@ class SyncBatchNormKernel<platform::CUDADeviceContext, T>
const std::string layout_str = ctx.Attr<std::string>("data_layout"); const std::string layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout layout = framework::StringToDataLayout(layout_str); const DataLayout layout = framework::StringToDataLayout(layout_str);
const bool use_global_stats = ctx.Attr<bool>("use_global_stats"); const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
PADDLE_ENFORCE_EQ(use_global_stats, false, PADDLE_ENFORCE_EQ(use_global_stats, false,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"sync_batch_norm doesn't support " "sync_batch_norm doesn't support "
...@@ -47,9 +48,10 @@ class SyncBatchNormKernel<platform::CUDADeviceContext, T> ...@@ -47,9 +48,10 @@ class SyncBatchNormKernel<platform::CUDADeviceContext, T>
auto *saved_mean = ctx.Output<Tensor>("SavedMean"); auto *saved_mean = ctx.Output<Tensor>("SavedMean");
auto *saved_inv_variance = ctx.Output<Tensor>("SavedVariance"); auto *saved_inv_variance = ctx.Output<Tensor>("SavedVariance");
bool test_mode = is_test && (!trainable_stats);
SyncBatchNormFunctor<platform::CUDADeviceContext, T>( SyncBatchNormFunctor<platform::CUDADeviceContext, T>(
ctx, layout, x, y, est_mean, est_var, mean_out, variance_out, ctx, layout, x, y, est_mean, est_var, mean_out, variance_out,
saved_mean, saved_inv_variance, epsilon, momentum, is_test, saved_mean, saved_inv_variance, epsilon, momentum, test_mode,
use_global_stats); use_global_stats);
} }
}; };
......
...@@ -1276,12 +1276,12 @@ class BatchNorm(layers.Layer): ...@@ -1276,12 +1276,12 @@ class BatchNorm(layers.Layer):
variance_out = self._variance variance_out = self._variance
if in_dygraph_mode(): if in_dygraph_mode():
_is_test = not self.training and not self._trainable_statistics
attrs = ("momentum", self._momentum, "epsilon", self._epsilon, attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
"is_test", _is_test, "data_layout", self._data_layout, "is_test", not self.training, "data_layout",
"use_mkldnn", False, "fuse_with_relu", self._data_layout, "use_mkldnn", False, "fuse_with_relu",
self._fuse_with_relu, "use_global_stats", self._fuse_with_relu, "use_global_stats",
self._use_global_stats) self._use_global_stats, 'trainable_statistics',
self._trainable_statistics)
batch_norm_out, _, _, _, _ = core.ops.batch_norm( batch_norm_out, _, _, _, _ = core.ops.batch_norm(
input, self.weight, self.bias, self._mean, self._variance, input, self.weight, self.bias, self._mean, self._variance,
mean_out, variance_out, *attrs) mean_out, variance_out, *attrs)
...@@ -1298,7 +1298,8 @@ class BatchNorm(layers.Layer): ...@@ -1298,7 +1298,8 @@ class BatchNorm(layers.Layer):
"data_layout": self._data_layout, "data_layout": self._data_layout,
"use_mkldnn": False, "use_mkldnn": False,
"fuse_with_relu": self._fuse_with_relu, "fuse_with_relu": self._fuse_with_relu,
"use_global_stats": self._use_global_stats "use_global_stats": self._use_global_stats,
"trainable_statistics": self._trainable_statistics,
} }
inputs = { inputs = {
......
...@@ -2605,16 +2605,7 @@ class Block(object): ...@@ -2605,16 +2605,7 @@ class Block(object):
""" """
if in_dygraph_mode(): if in_dygraph_mode():
attrs = kwargs.get("attrs", {}) attrs = kwargs.get("attrs", {})
if _dygraph_tracer_._train_mode == False:
# eval mode
if ('trainable_statistics' not in attrs
) or not attrs['trainable_statistics']:
attrs['is_test'] = True
else:
attrs['is_test'] = False
type = kwargs.get("type", None) type = kwargs.get("type", None)
op = Operator( op = Operator(
block=self, block=self,
desc=None, desc=None,
......
...@@ -623,5 +623,53 @@ class TestDygraphBatchNormAPIError(unittest.TestCase): ...@@ -623,5 +623,53 @@ class TestDygraphBatchNormAPIError(unittest.TestCase):
self.assertRaises(TypeError, batch_norm, x2) self.assertRaises(TypeError, batch_norm, x2)
class TestDygraphBatchNormTrainableStats(unittest.TestCase):
def test_dygraph(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
places.append(fluid.CUDAPlace(0))
for p in places:
shape = [4, 10, 4, 4]
def compute(x, is_test, trainable_statistics):
with fluid.dygraph.guard(p):
bn = fluid.dygraph.BatchNorm(
shape[1],
is_test=is_test,
trainable_statistics=trainable_statistics)
y = bn(fluid.dygraph.to_variable(x))
return y.numpy()
x = np.random.randn(*shape).astype("float32")
y1 = compute(x, False, False)
y2 = compute(x, True, True)
self.assertTrue(np.allclose(y1, y2))
def test_static(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
places.append(fluid.CUDAPlace(0))
for p in places:
exe = fluid.Executor(p)
shape = [4, 10, 16, 16]
def compute(x_np, is_test, trainable_statistics):
with program_guard(Program(), Program()):
bn = fluid.dygraph.BatchNorm(
shape[1],
is_test=is_test,
trainable_statistics=trainable_statistics)
x = fluid.data(name='x', shape=x_np.shape, dtype=x_np.dtype)
y = bn(x)
exe.run(fluid.default_startup_program())
r = exe.run(feed={'x': x_np}, fetch_list=[y])[0]
return r
x = np.random.randn(*shape).astype("float32")
y1 = compute(x, False, False)
y2 = compute(x, True, True)
self.assertTrue(np.allclose(y1, y2))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册