From a8fd71c7855cdd633d4498df93c11a35f63f5749 Mon Sep 17 00:00:00 2001 From: wuyongkang Date: Thu, 23 Jul 2020 20:27:54 +0800 Subject: [PATCH] Optimization for BatchNorm --- mindspore/nn/layer/normalization.py | 88 ++++++++++++++--------------- 1 file changed, 41 insertions(+), 47 deletions(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 5d4380e4a..c09a02b3c 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -101,6 +101,9 @@ class _BatchNorm(Cell): epsilon=self.eps, momentum=self.momentum) self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps) + self.enable_global_sync = self.is_global and (self.is_ge_backend or (self.is_graph_mode and self.is_ascend)) + self.enable_default_train = self.is_graph_mode and not self.is_global and \ + (self.is_ge_backend or self.is_ascend) data_parallel_strategy = ((1,), (1,)) data_parallel_strategy_one = ((1,), ()) @@ -147,51 +150,43 @@ class _BatchNorm(Cell): return y def construct(self, x): - if self.input_dims == '2d': - _shape_check(self.shape(x)) - if self.input_dims == '1d': - _shape_check_2d(self.shape(x)) - if self.input_dims == 'both': - _shape_check_2d_or_4d(self.shape(x)) + _shape_check_bn(self.shape(x), self.input_dims) if self.use_batch_statistics is None: flag = self.training else: flag = self.use_batch_statistics + if flag: - if self.is_ge_backend and self.is_global: + if self.enable_global_sync: axes, re_shape = _shape_infer(F.shape(x), self.num_features) - y = self._global_sync(x, axes, re_shape) - elif self.is_graph_mode and (self.is_ge_backend or self.is_ascend): - if self.is_global: - axes, re_shape = _shape_infer(F.shape(x), self.num_features) - y = self._global_sync(x, axes, re_shape) - else: - y, batch_mean, batch_var, _, _ = \ - self.bn_train(x, - self.gamma, - self.beta, - None, - None) - - mean_sub = self.sub_mean(self.moving_mean, batch_mean) - temp_mean = self.mul_mean(mean_sub, self.momentum) - mean_sub2 = self.sub_var(self.moving_variance, batch_var) - temp_variance = self.mul_var(mean_sub2, self.momentum) - y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean)) - y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance)) - else: - y = self.bn_train(x, - self.gamma, - self.beta, - self.moving_mean, - self.moving_variance)[0] - else: - y = self.bn_infer(x, - self.gamma, - self.beta, - self.moving_mean, - self.moving_variance)[0] - return y + return self._global_sync(x, axes, re_shape) + + if self.enable_default_train: + y, batch_mean, batch_var, _, _ = self.bn_train(x, + self.gamma, + self.beta, + None, + None) + + mean_sub = self.sub_mean(self.moving_mean, batch_mean) + temp_mean = self.mul_mean(mean_sub, self.momentum) + mean_sub2 = self.sub_var(self.moving_variance, batch_var) + temp_variance = self.mul_var(mean_sub2, self.momentum) + y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean)) + y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance)) + return y + + return self.bn_train(x, + self.gamma, + self.beta, + self.moving_mean, + self.moving_variance)[0] + + return self.bn_infer(x, + self.gamma, + self.beta, + self.moving_mean, + self.moving_variance)[0] def extend_repr(self): return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format( @@ -204,12 +199,6 @@ def _channel_check(channel, num_channel): raise ValueError("the input channel is not equal with num_channel") -@constexpr -def _shape_check_2d(input_shape): - if len(input_shape) != 2: - raise ValueError("The input must has 2 dims.") - - @constexpr def _shape_check(in_shape): if len(in_shape) != 4: @@ -217,8 +206,13 @@ def _shape_check(in_shape): @constexpr -def _shape_check_2d_or_4d(in_shape): - if len(in_shape) != 2 and len(in_shape) != 4: +def _shape_check_bn(in_shape, in_dims): + dim = len(in_shape) + if in_dims == '1d' and dim != 2: + raise ValueError("The input must has 2 dims.") + if in_dims == '2d' and dim != 4: + raise ValueError("The input must has 4 dims.") + if in_dims == 'both' and dim != 2 and dim != 4: raise ValueError("The input must has 2 dims or 4 dims.") -- GitLab