test_dist_mnist.py 2.0 KB
Newer Older
Y
Yancey1989 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 15

from __future__ import print_function
Y
Yancey1989 已提交
16
import unittest
T
typhoonzero 已提交
17
from test_dist_base import TestDistBase
Y
Yancey1989 已提交
18 19


W
Wu Yi 已提交
20 21 22 23 24 25 26 27 28 29 30 31 32 33
class TestDistMnist2x2(TestDistBase):
    """Run ``dist_mnist.py`` with sync mode enabled and reduce disabled.

    NOTE(review): "2x2" presumably names the cluster topology that
    ``TestDistBase`` launches -- confirm against the base class.
    """

    def _setup_config(self):
        """Select this variant's flags: sync mode on, reduce off.

        How the base class consumes these attributes is not visible in this
        file.
        """
        self._sync_mode = True
        self._use_reduce = False

    def test_se_resnext(self):
        # NOTE(review): the method name mentions se_resnext although the
        # script under test is the MNIST one; the name is kept unchanged so
        # existing test selectors keep working.
        model_file = "dist_mnist.py"
        tolerance = 1e-7
        # Presumably `delta` bounds the allowed difference between the runs
        # that check_with_place compares -- confirm in TestDistBase.
        self.check_with_place(model_file, delta=tolerance)


class TestDistMnist2x2WithMemopt(TestDistBase):
    """Run ``dist_mnist.py`` with sync mode and the ``_mem_opt`` flag enabled.

    NOTE(review): ``_mem_opt`` presumably turns on a memory-optimization
    pass inside ``TestDistBase`` -- confirm against the base class.
    """

    def _setup_config(self):
        """Select this variant's flags: sync mode on, memory optimization on.

        ``_use_reduce`` is not assigned here, so it keeps whatever value the
        base class provides.
        """
        self._sync_mode = True
        self._mem_opt = True

    def test_se_resnext(self):
        # NOTE(review): the method name mentions se_resnext although the
        # script under test is the MNIST one; the name is kept unchanged so
        # existing test selectors keep working.
        model_file = "dist_mnist.py"
        tolerance = 1e-7
        # Presumably `delta` bounds the allowed difference between the runs
        # that check_with_place compares -- confirm in TestDistBase.
        self.check_with_place(model_file, delta=tolerance)
Y
Yancey1989 已提交
36 37


W
Wu Yi 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
class TestDistMnistAsync(TestDistBase):
    """Run ``dist_mnist.py`` with sync mode and reduce both disabled."""

    def _setup_config(self):
        """Select this variant's flags: sync mode off, reduce off.

        How the base class consumes these attributes is not visible in this
        file.
        """
        self._sync_mode = False
        self._use_reduce = False

    def test_se_resnext(self):
        # NOTE(review): the method name mentions se_resnext although the
        # script under test is the MNIST one; the name is kept unchanged so
        # existing test selectors keep working.
        model_file = "dist_mnist.py"
        # Far looser than the 1e-7 used by the sync-mode tests in this file;
        # presumably because asynchronous updates are not reproducible
        # step-for-step -- confirm with the test owners.
        tolerance = 200
        self.check_with_place(model_file, delta=tolerance)


# FIXME(typhoonzero): enable these tests once we have
# 4 GPUs on the CI machine; the base class should be updated as well.
#
# class TestDistMnist2x2ReduceMode(TestDistBase):
#     def _setup_config(self):
#         self._sync_mode = True
#         self._use_reduce = True

#     def test_se_resnext(self):
#         self.check_with_place("dist_mnist.py", delta=1e-7)

# class TestDistMnistAsyncReduceMode(TestDistBase):
#     def _setup_config(self):
#         self._sync_mode = False
#         self._use_reduce = True

#     def test_se_resnext(self):
#         self.check_with_place("dist_mnist.py", delta=200)

Y
Yancey1989 已提交
66 67
# Run the test cases defined in this module when the file is executed
# directly as a script (not when it is imported).
if __name__ == "__main__":
    unittest.main()