未验证 提交 030d678c 编写于 作者: C ceci3 提交者: GitHub

fix syncbn convert (#30158) (#30176)

* fix syncbn convet

* add unittest
上级 39204d56
......@@ -25,6 +25,7 @@ import six
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle.nn as nn
from paddle.fluid import compiler
from paddle.fluid import Program, program_guard
......@@ -244,5 +245,34 @@ class TestConvertSyncBatchNorm(unittest.TestCase):
isinstance(model[idx], paddle.nn.SyncBatchNorm), True)
class TestConvertSyncBatchNormCase2(unittest.TestCase):
def test_convert(self):
if not core.is_compiled_with_cuda():
return
class Net(nn.Layer):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2D(3, 5, 3)
self.bn = []
bn = self.add_sublayer('bn', nn.BatchNorm2D(5))
self.bn.append(bn)
def forward(self, x):
x = self.conv1(x)
for bn in self.bn:
x = bn(x)
return x
model = nn.Sequential()
model.add_sublayer('net1', Net())
model.add_sublayer('net2', Net())
compare_model = nn.Sequential()
compare_model.add_sublayer('net1', Net())
compare_model.add_sublayer('net2', Net())
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
self.assertEqual(len(compare_model.sublayers()), len(model.sublayers()))
if __name__ == '__main__':
unittest.main()
......@@ -1121,7 +1121,7 @@ class SyncBatchNorm(_BatchNormBase):
layer_output._mean = layer._mean
layer_output._variance = layer._variance
for name, sublayer in layer.named_sublayers():
for name, sublayer in layer.named_children():
layer_output.add_sublayer(name,
cls.convert_sync_batchnorm(sublayer))
del layer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册