未验证 提交 e26411ce 编写于 作者: C chengduo 提交者: GitHub

Open test_parallel_dygraph_se_resnext (#19342)

* enabel test_parallel_dygraph_se_resnext
test=develop
上级 caf59d0f
......@@ -57,10 +57,15 @@ class ConvBNLayer(fluid.dygraph.Layer):
self._batch_norm = BatchNorm(
self.full_name(), num_filters, act=act, momentum=0.1)
self._layer_norm = fluid.dygraph.nn.LayerNorm(
self.full_name(), begin_norm_axis=1)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
# FIXME(zcd): when compare the result of multi-card and single-card,
# we should replace batch_norm with layer_norm.
y = self._layer_norm(y)
# y = self._batch_norm(y)
return y
......@@ -278,7 +283,9 @@ class SeResNeXt(fluid.dygraph.Layer):
for bottleneck_block in self.bottleneck_block_list:
y = bottleneck_block(y)
y = self.pool2d_avg(y)
y = fluid.layers.dropout(y, dropout_prob=0.2, seed=1)
# FIXME(zcd): the dropout should be removed when compare the
# result of multi-card and single-card.
# y = fluid.layers.dropout(y, dropout_prob=0.2, seed=1)
cost = self.fc(y)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
......@@ -290,7 +297,7 @@ class TestSeResNeXt(TestParallelDyGraphRunnerBase):
model = SeResNeXt("se-resnext")
train_reader = paddle.batch(
paddle.dataset.flowers.test(use_xmap=False),
batch_size=2,
batch_size=4,
drop_last=True)
opt = fluid.optimizer.SGD(learning_rate=1e-3)
......
......@@ -13,10 +13,11 @@
# limitations under the License.
from __future__ import print_function
#import unittest
import unittest
from test_dist_base import TestDistBase
import paddle.fluid as fluid
'''
class TestParallelDygraphSeResNeXt(TestDistBase):
def _setup_config(self):
self._sync_mode = False
......@@ -24,12 +25,9 @@ class TestParallelDygraphSeResNeXt(TestDistBase):
self._dygraph = True
def test_se_resnext(self):
# TODO(Yancey1989): BN and Dropout is related with batchsize, so the delta is the 1,
# try to remove the BN and Dropout in the network and using delta = 1e-5
if fluid.core.is_compiled_with_cuda():
self.check_with_place("parallel_dygraph_se_resnext.py", delta=1)
'''
self.check_with_place("parallel_dygraph_se_resnext.py", delta=0.01)
if __name__ == "__main__":
pass
#unittest.main()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册