diff --git a/paddleslim/nas/ofa/layers.py b/paddleslim/nas/ofa/layers.py
index aad221475ee3de79374dc085c0ca4c6155bc8740..c6481c312326b0ed83b8568750fad6c3042eec53 100644
--- a/paddleslim/nas/ofa/layers.py
+++ b/paddleslim/nas/ofa/layers.py
@@ -963,19 +963,48 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
                  "use_mkldnn", False, "fuse_with_relu", False,
                  "use_global_stats", self._use_global_stats,
                  "trainable_statistics", trainable_statistics)
-
-        if feature_dim != self._mean.shape[0]:
-            batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
-                                                 variance, mean_out_tmp,
-                                                 variance_out_tmp, *attrs)
-            self._mean[:feature_dim].set_value(mean)
-            self._variance[:feature_dim].set_value(variance)
-            mean_out[:feature_dim].set_value(mean_out_tmp)
-            variance_out[:feature_dim].set_value(variance_out_tmp)
-        else:
-            batch_norm_out = core.ops.batch_norm(input, weight, bias,
-                                                 self._mean, self._variance,
-                                                 mean_out, variance_out, *attrs)
+        try:
+            from paddle import _C_ops
+            from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
+            if in_dygraph_mode():
+                if feature_dim != self._mean.shape[0]:
+                    batch_norm_out = _C_ops.final_state_batch_norm(
+                        input, weight, bias, mean, variance, mean_out_tmp,
+                        variance_out_tmp, *attrs)
+                    self._mean[:feature_dim].set_value(mean)
+                    self._variance[:feature_dim].set_value(variance)
+                    mean_out[:feature_dim].set_value(mean_out_tmp)
+                    variance_out[:feature_dim].set_value(variance_out_tmp)
+                else:
+                    batch_norm_out = _C_ops.final_state_batch_norm(
+                        input, weight, bias, self._mean, self._variance,
+                        mean_out, variance_out, *attrs)
+            elif _in_legacy_dygraph():
+                if feature_dim != self._mean.shape[0]:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, mean, variance, None, mean_out_tmp,
+                        variance_out_tmp, *attrs)
+                    self._mean[:feature_dim].set_value(mean)
+                    self._variance[:feature_dim].set_value(variance)
+                    mean_out[:feature_dim].set_value(mean_out_tmp)
+                    variance_out[:feature_dim].set_value(variance_out_tmp)
+                else:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, self._mean, self._variance, None,
+                        mean_out, variance_out, *attrs)
+        except:
+            if feature_dim != self._mean.shape[0]:
+                batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
+                                                     variance, mean_out_tmp,
+                                                     variance_out_tmp, *attrs)
+                self._mean[:feature_dim].set_value(mean)
+                self._variance[:feature_dim].set_value(variance)
+                mean_out[:feature_dim].set_value(mean_out_tmp)
+                variance_out[:feature_dim].set_value(variance_out_tmp)
+            else:
+                batch_norm_out = core.ops.batch_norm(
+                    input, weight, bias, self._mean, self._variance, mean_out,
+                    variance_out, *attrs)
         self.cur_config = {'prune_dim': feature_dim}
         return batch_norm_out[0]
 
@@ -1246,4 +1275,4 @@ class SuperEmbedding(nn.Embedding):
             weight=weight,
             padding_idx=self._padding_idx,
             sparse=self._sparse,
-            name=self._name)
\ No newline at end of file
+            name=self._name)
diff --git a/paddleslim/nas/ofa/layers_old.py b/paddleslim/nas/ofa/layers_old.py
index ee3b8de8ab7e87312d78ca4c1c8803433df02262..73dad7f7c0e73d669ceff98374e582013f7cc2d8 100644
--- a/paddleslim/nas/ofa/layers_old.py
+++ b/paddleslim/nas/ofa/layers_old.py
@@ -903,19 +903,48 @@ class SuperBatchNorm(fluid.dygraph.BatchNorm):
                  "use_mkldnn", False, "fuse_with_relu", self._fuse_with_relu,
                  "use_global_stats", self._use_global_stats,
                  'trainable_statistics', self._trainable_statistics)
-
-        if feature_dim != self._mean.shape[0]:
-            batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
-                                                 variance, mean_out_tmp,
-                                                 variance_out_tmp, *attrs)
-            self._mean[:feature_dim] = mean
-            self._variance[:feature_dim] = variance
-            mean_out[:feature_dim] = mean_out_tmp
-            variance_out[:feature_dim] = variance_out_tmp
-        else:
-            batch_norm_out = core.ops.batch_norm(input, weight, bias,
-                                                 self._mean, self._variance,
-                                                 mean_out, variance_out, *attrs)
+        try:
+            from paddle import _C_ops
+            from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
+            if in_dygraph_mode():
+                if feature_dim != self._mean.shape[0]:
+                    batch_norm_out = _C_ops.final_state_batch_norm(
+                        input, weight, bias, mean, variance, mean_out_tmp,
+                        variance_out_tmp, *attrs)
+                    self._mean[:feature_dim] = mean
+                    self._variance[:feature_dim] = variance
+                    mean_out[:feature_dim] = mean_out_tmp
+                    variance_out[:feature_dim] = variance_out_tmp
+                else:
+                    batch_norm_out = _C_ops.final_state_batch_norm(
+                        input, weight, bias, self._mean, self._variance,
+                        mean_out, variance_out, *attrs)
+            elif _in_legacy_dygraph():
+                if feature_dim != self._mean.shape[0]:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, mean, variance, None, mean_out_tmp,
+                        variance_out_tmp, *attrs)
+                    self._mean[:feature_dim].set_value(mean)
+                    self._variance[:feature_dim].set_value(variance)
+                    mean_out[:feature_dim].set_value(mean_out_tmp)
+                    variance_out[:feature_dim].set_value(variance_out_tmp)
+                else:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, self._mean, self._variance, None,
+                        mean_out, variance_out, *attrs)
+        except:
+            if feature_dim != self._mean.shape[0]:
+                batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
+                                                     variance, mean_out_tmp,
+                                                     variance_out_tmp, *attrs)
+                self._mean[:feature_dim].set_value(mean)
+                self._variance[:feature_dim].set_value(variance)
+                mean_out[:feature_dim].set_value(mean_out_tmp)
+                variance_out[:feature_dim].set_value(variance_out_tmp)
+            else:
+                batch_norm_out = core.ops.batch_norm(
+                    input, weight, bias, self._mean, self._variance, mean_out,
+                    variance_out, *attrs)
         return dygraph_utils._append_activation_in_dygraph(
             batch_norm_out[0], act=self._act)
 
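Reviewer note (not part of the patch): Paddle 2.3 split dynamic-graph
execution into a new "final state" mode, served by
_C_ops.final_state_batch_norm, and a legacy mode whose batch_norm op takes
one extra positional input (filled with None above); releases before 2.3
only expose the original core.ops.batch_norm signature. Both hunks
therefore probe the runtime and fall through those three cases. The sketch
below mirrors that probe in isolation; it assumes a Paddle build from
around the 2.3 transition, and the helper name probe_batch_norm_backend is
made up for illustration.

    def probe_batch_norm_backend():
        """Report which batch-norm entry point this Paddle build exposes."""
        try:
            # Both imports only succeed on builds around Paddle 2.3.
            from paddle import _C_ops  # noqa: F401
            from paddle.fluid.framework import (in_dygraph_mode,
                                                _in_legacy_dygraph)
            if in_dygraph_mode():
                # New dygraph: the patch calls _C_ops.final_state_batch_norm.
                return "final_state"
            if _in_legacy_dygraph():
                # 2.3 legacy dygraph: core.ops.batch_norm, with the extra
                # positional slot the patch fills with None.
                return "legacy_2.3"
        except ImportError:
            # Pre-2.3: neither _C_ops nor the mode helpers exist.
            pass
        # Original core.ops.batch_norm signature.
        return "pre_2.3"

    print(probe_batch_norm_backend())

Note that in the patch the actual op calls sit inside the try, so its bare
except: also catches call-time signature mismatches, not just failed
imports; the sketch narrows to ImportError only for readability.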