未验证 提交 6a19e41f 编写于 作者: C ceci3 提交者: GitHub

fix syncbn convert (#30158)

* fix syncbn convet

* add unittest
上级 adac38c5
...@@ -25,6 +25,7 @@ import six ...@@ -25,6 +25,7 @@ import six
import paddle import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.nn as nn
from paddle.fluid import compiler from paddle.fluid import compiler
from paddle.fluid import Program, program_guard from paddle.fluid import Program, program_guard
...@@ -244,5 +245,34 @@ class TestConvertSyncBatchNorm(unittest.TestCase): ...@@ -244,5 +245,34 @@ class TestConvertSyncBatchNorm(unittest.TestCase):
isinstance(model[idx], paddle.nn.SyncBatchNorm), True) isinstance(model[idx], paddle.nn.SyncBatchNorm), True)
class TestConvertSyncBatchNormCase2(unittest.TestCase):
def test_convert(self):
if not core.is_compiled_with_cuda():
return
class Net(nn.Layer):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2D(3, 5, 3)
self.bn = []
bn = self.add_sublayer('bn', nn.BatchNorm2D(5))
self.bn.append(bn)
def forward(self, x):
x = self.conv1(x)
for bn in self.bn:
x = bn(x)
return x
model = nn.Sequential()
model.add_sublayer('net1', Net())
model.add_sublayer('net2', Net())
compare_model = nn.Sequential()
compare_model.add_sublayer('net1', Net())
compare_model.add_sublayer('net2', Net())
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
self.assertEqual(len(compare_model.sublayers()), len(model.sublayers()))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -1142,7 +1142,7 @@ class SyncBatchNorm(_BatchNormBase): ...@@ -1142,7 +1142,7 @@ class SyncBatchNorm(_BatchNormBase):
layer_output._mean = layer._mean layer_output._mean = layer._mean
layer_output._variance = layer._variance layer_output._variance = layer._variance
for name, sublayer in layer.named_sublayers(): for name, sublayer in layer.named_children():
layer_output.add_sublayer(name, layer_output.add_sublayer(name,
cls.convert_sync_batchnorm(sublayer)) cls.convert_sync_batchnorm(sublayer))
del layer del layer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册