Unverified commit cb642009, authored by Chang Xu, committed by GitHub

Fit new paddle (#1044)

Parent 380bce65
@@ -963,19 +963,48 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
                  "use_mkldnn", False, "fuse_with_relu", False,
                  "use_global_stats", self._use_global_stats,
                  "trainable_statistics", trainable_statistics)
-        if feature_dim != self._mean.shape[0]:
-            batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
-                                                 variance, mean_out_tmp,
-                                                 variance_out_tmp, *attrs)
-            self._mean[:feature_dim].set_value(mean)
-            self._variance[:feature_dim].set_value(variance)
-            mean_out[:feature_dim].set_value(mean_out_tmp)
-            variance_out[:feature_dim].set_value(variance_out_tmp)
-        else:
-            batch_norm_out = core.ops.batch_norm(input, weight, bias,
-                                                 self._mean, self._variance,
-                                                 mean_out, variance_out, *attrs)
+        try:
+            from paddle import _C_ops
+            from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
+            if in_dygraph_mode():
+                if feature_dim != self._mean.shape[0]:
+                    batch_norm_out = _C_ops.final_state_batch_norm(
+                        input, weight, bias, mean, variance, mean_out_tmp,
+                        variance_out_tmp, *attrs)
+                    self._mean[:feature_dim].set_value(mean)
+                    self._variance[:feature_dim].set_value(variance)
+                    mean_out[:feature_dim].set_value(mean_out_tmp)
+                    variance_out[:feature_dim].set_value(variance_out_tmp)
+                else:
+                    batch_norm_out = _C_ops.final_state_batch_norm(
+                        input, weight, bias, self._mean, self._variance,
+                        mean_out, variance_out, *attrs)
+            elif _in_legacy_dygraph():
+                if feature_dim != self._mean.shape[0]:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, mean, variance, None, mean_out_tmp,
+                        variance_out_tmp, *attrs)
+                    self._mean[:feature_dim].set_value(mean)
+                    self._variance[:feature_dim].set_value(variance)
+                    mean_out[:feature_dim].set_value(mean_out_tmp)
+                    variance_out[:feature_dim].set_value(variance_out_tmp)
+                else:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, self._mean, self._variance, None,
+                        mean_out, variance_out, *attrs)
+        except:
+            if feature_dim != self._mean.shape[0]:
+                batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
+                                                     variance, mean_out_tmp,
+                                                     variance_out_tmp, *attrs)
+                self._mean[:feature_dim].set_value(mean)
+                self._variance[:feature_dim].set_value(variance)
+                mean_out[:feature_dim].set_value(mean_out_tmp)
+                variance_out[:feature_dim].set_value(variance_out_tmp)
+            else:
+                batch_norm_out = core.ops.batch_norm(
+                    input, weight, bias, self._mean, self._variance, mean_out,
+                    variance_out, *attrs)
         self.cur_config = {'prune_dim': feature_dim}
         return batch_norm_out[0]
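Note: the hunk above follows a single pattern — probe for the Paddle >= 2.3 dygraph APIs, dispatch to the matching batch-norm entry point, and fall back to the pre-2.3 call if the imports or the call itself fail. Below is a minimal sketch of that dispatch, not part of the commit; the three callables are hypothetical stand-ins for the concrete op calls in the diff, and only the import paths and predicates come from the commit itself.

def dispatch_batch_norm(run_final_state, run_legacy_new, run_legacy_old):
    try:
        # These names only import on Paddle >= 2.3; on older releases the
        # ImportError sends us to the fallback below.
        from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
        if in_dygraph_mode():
            # New "final state" dygraph: _C_ops.final_state_batch_norm(...).
            return run_final_state()
        elif _in_legacy_dygraph():
            # Legacy dygraph on Paddle >= 2.3: core.ops.batch_norm with an
            # extra positional argument (None) in its signature.
            return run_legacy_new()
    except Exception:
        pass
    # Older Paddle: the original core.ops.batch_norm signature.
    return run_legacy_old()

Because the commit wraps the op calls themselves in the bare except (mirrored here by placing the calls inside the try), a runtime failure in the new path also degrades to the old call, not just a missing import.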
@@ -1246,4 +1275,4 @@ class SuperEmbedding(nn.Embedding):
             weight=weight,
             padding_idx=self._padding_idx,
             sparse=self._sparse,
-            name=self._name)
\ No newline at end of file
+            name=self._name)
@@ -903,19 +903,48 @@ class SuperBatchNorm(fluid.dygraph.BatchNorm):
              "use_mkldnn", False, "fuse_with_relu", self._fuse_with_relu,
              "use_global_stats", self._use_global_stats,
              'trainable_statistics', self._trainable_statistics)
-        if feature_dim != self._mean.shape[0]:
-            batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
-                                                 variance, mean_out_tmp,
-                                                 variance_out_tmp, *attrs)
-            self._mean[:feature_dim] = mean
-            self._variance[:feature_dim] = variance
-            mean_out[:feature_dim] = mean_out_tmp
-            variance_out[:feature_dim] = variance_out_tmp
-        else:
-            batch_norm_out = core.ops.batch_norm(input, weight, bias,
-                                                 self._mean, self._variance,
-                                                 mean_out, variance_out, *attrs)
+        try:
+            from paddle import _C_ops
+            from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
+            if in_dygraph_mode():
+                if feature_dim != self._mean.shape[0]:
+                    batch_norm_out = _C_ops.final_state_batch_norm(
+                        input, weight, bias, mean, variance, mean_out_tmp,
+                        variance_out_tmp, *attrs)
+                    self._mean[:feature_dim] = mean
+                    self._variance[:feature_dim] = variance
+                    mean_out[:feature_dim] = mean_out_tmp
+                    variance_out[:feature_dim] = variance_out_tmp
+                else:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, self._mean, self._variance,
+                        mean_out, variance_out, *attrs)
+            elif _in_legacy_dygraph():
+                if feature_dim != self._mean.shape[0]:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, mean, variance, None, mean_out_tmp,
+                        variance_out_tmp, *attrs)
+                    self._mean[:feature_dim].set_value(mean)
+                    self._variance[:feature_dim].set_value(variance)
+                    mean_out[:feature_dim].set_value(mean_out_tmp)
+                    variance_out[:feature_dim].set_value(variance_out_tmp)
+                else:
+                    batch_norm_out = core.ops.batch_norm(
+                        input, weight, bias, self._mean, self._variance, None,
+                        mean_out, variance_out, *attrs)
+        except:
+            if feature_dim != self._mean.shape[0]:
+                batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
+                                                     variance, mean_out_tmp,
+                                                     variance_out_tmp, *attrs)
+                self._mean[:feature_dim].set_value(mean)
+                self._variance[:feature_dim].set_value(variance)
+                mean_out[:feature_dim].set_value(mean_out_tmp)
+                variance_out[:feature_dim].set_value(variance_out_tmp)
+            else:
+                batch_norm_out = core.ops.batch_norm(
+                    input, weight, bias, self._mean, self._variance, mean_out,
+                    variance_out, *attrs)
         return dygraph_utils._append_activation_in_dygraph(
             batch_norm_out[0], act=self._act)
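Note: both batch-norm hunks share the same bookkeeping — the super-layer keeps one running-statistics buffer sized for the largest channel count, and when the sampled sub-network uses fewer channels (`feature_dim`), batch norm runs on the narrow slice and the updated statistics are written back into the head of the buffer. A small illustration of that write-back, not from the commit, with made-up shapes and values and assuming Paddle 2.x dygraph:

import paddle

max_channels, feature_dim = 8, 4
running_mean = paddle.zeros([max_channels])  # stats buffer at full width
sub_mean = paddle.full([feature_dim], 0.5)   # stats from the sampled sub-net

# Write the updated slice back, leaving the unused tail untouched — the same
# step the diff performs via `self._mean[:feature_dim] = mean` in one class
# and `self._mean[:feature_dim].set_value(mean)` in the other.
running_mean[:feature_dim] = sub_mean
print(running_mean.numpy())  # [0.5 0.5 0.5 0.5 0.  0.  0.  0. ]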