未验证 提交 3ae3c1d9 编写于 作者: C ceci3 提交者: GitHub

fix set value in paddle 2.1.1 (#877)

上级 4f984efe
......@@ -981,10 +981,10 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
batch_norm_out = core.ops.batch_norm(input, weight, bias, mean,
variance, mean_out_tmp,
variance_out_tmp, *attrs)
self._mean[:feature_dim] = mean
self._variance[:feature_dim] = variance
mean_out[:feature_dim] = mean_out_tmp
variance_out[:feature_dim] = variance_out_tmp
self._mean[:feature_dim].set_value(mean)
self._variance[:feature_dim].set_value(variance)
mean_out[:feature_dim].set_value(mean_out_tmp)
variance_out[:feature_dim].set_value(variance_out_tmp)
else:
batch_norm_out = core.ops.batch_norm(input, weight, bias,
self._mean, self._variance,
......@@ -1031,10 +1031,10 @@ class SuperSyncBatchNorm(nn.SyncBatchNorm):
sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
input, weight, bias, mean, variance, mean_out_tmp,
variance_out_tmp, *attrs)
self._mean[:feature_dim] = mean
self._variance[:feature_dim] = variance
mean_out[:feature_dim] = mean_out_tmp
variance_out[:feature_dim] = variance_out_tmp
self._mean[:feature_dim].set_value(mean)
self._variance[:feature_dim].set_value(variance)
mean_out[:feature_dim].set_value(mean_out_tmp)
variance_out[:feature_dim].set_value(variance_out_tmp)
else:
sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
input, weight, bias, self._mean, self._variance, mean_out,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册