# test_imperative_group.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import unittest

import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import (
    _test_eager_guard,
    _in_legacy_dygraph,
    in_dygraph_mode,
)


class TestDataParallelGroup(unittest.TestCase):
    """Tests for the core routine that buckets tensors into groups by size.

    Every test builds a list of random tensors, asks the core routine to
    partition their indices into groups, and compares the result with a
    hard-coded expected grouping. Arguments forwarded to the core routine,
    in order:

    * the tensor list;
    * one boolean flag per tensor;
    * a list of group size limits;
    * optionally, a list of tensor indices.

    NOTE(review): the expected groupings are consistent with the limits being
    byte counts (e.g. a float32 [2, 50] tensor is 400 bytes). The per-tensor
    flags presumably mark sparse-gradient tensors. Confirm both against the
    core implementation. In every case below, a tensor flagged True ends up
    in a group of its own.
    """

    # Three float32 [1, 50] tensors interleaved with three float64 [1, 25]
    # tensors. With 4-byte float32 and 8-byte float64 elements, every tensor
    # in this layout holds 200 bytes of data.
    _MIXED_DTYPE_SPECS = [("float32", [1, 50]), ("float64", [1, 25])] * 3

    def create_varbase(self, dtype, shape):
        """Return a random tensor of the given dtype and shape."""
        return paddle.rand(shape=shape, dtype=dtype)

    def _create_var_list(self, specs):
        """Return one random tensor per ``(dtype, shape)`` pair, in order.

        Each shape is copied, so specs shared between tests (such as
        ``_MIXED_DTYPE_SPECS``) never hand out the same list object twice.
        """
        return [
            self.create_varbase(dtype, list(shape)) for dtype, shape in specs
        ]

    def assign_group_by_size(self, *args):
        """Dispatch to the grouping routine for the active dygraph mode.

        Uses ``core.eager_assign_group_by_size`` in eager mode and
        ``core.assign_group_by_size`` in legacy dygraph mode. All positional
        arguments are forwarded unchanged.

        Raises:
            RuntimeError: if neither dygraph mode is active. Previously this
                silently returned ``None``, which surfaced as a misleading
                assertion mismatch.
        """
        if in_dygraph_mode():
            return core.eager_assign_group_by_size(*args)
        if _in_legacy_dygraph():
            return core.assign_group_by_size(*args)
        raise RuntimeError(
            "assign_group_by_size requires eager or legacy dygraph mode"
        )

    def test_construct_group0(self):
        # one dtype & one limit capability
        var_list = self._create_var_list(
            [
                ("float32", [2, 50]),
                ("float32", [2, 100]),
                ("float32", [2, 50]),
                ("float32", [2, 25]),
            ]
        )
        res = self.assign_group_by_size(var_list, [False] * 4, [400])
        self.assertEqual([[0], [1], [2], [3]], res)

    def test_construct_group1(self):
        # multi dtype & one limit capability
        var_list = self._create_var_list(self._MIXED_DTYPE_SPECS)
        res = self.assign_group_by_size(var_list, [False] * 6, [400])
        self.assertEqual([[0, 2], [1, 3], [4], [5]], res)

    def test_construct_group2(self):
        # one dtype & multi limit capability
        var_list = self._create_var_list([("float32", [2, 50])] * 4)
        res = self.assign_group_by_size(var_list, [False] * 4, [400, 800])
        self.assertEqual([[0], [1, 2], [3]], res)

    def test_construct_group3(self):
        # multi dtype & multi limit capability
        var_list = self._create_var_list(self._MIXED_DTYPE_SPECS)
        res = self.assign_group_by_size(var_list, [False] * 6, [200, 400])
        self.assertEqual([[0], [1], [2, 4], [3, 5]], res)

    def test_construct_group4(self):
        # multi dtype & zero limit capability
        var_list = self._create_var_list(self._MIXED_DTYPE_SPECS)
        res = self.assign_group_by_size(var_list, [False] * 6, [0])
        self.assertEqual([[0], [1], [2], [3], [4], [5]], res)

    def test_construct_group5(self):
        # multi dtype & infinite capability
        var_list = self._create_var_list(self._MIXED_DTYPE_SPECS)
        res = self.assign_group_by_size(var_list, [False] * 6, [10000])
        self.assertEqual([[0, 2, 4], [1, 3, 5]], res)

    def test_construct_group6(self):
        # multi dtype & limit capability & multi tensor type
        var_list = self._create_var_list(self._MIXED_DTYPE_SPECS)
        res = self.assign_group_by_size(
            var_list, [True, False, False, False, False, True], [400]
        )
        self.assertEqual([[0], [1, 3], [2, 4], [5]], res)

    def test_construct_group7(self):
        # multi dtype & multi limit capability & multi tensor type
        var_list = self._create_var_list(self._MIXED_DTYPE_SPECS)
        res = self.assign_group_by_size(
            var_list, [True, False, False, False, False, True], [200, 400]
        )
        self.assertEqual([[0], [1], [2], [3], [4], [5]], res)

    def test_construct_group8(self):
        # one dtype & one limit capability & have tensor_indices
        var_list = self._create_var_list(
            [
                ("float32", [2, 25]),
                ("float32", [2, 100]),
                ("float32", [2, 50]),
                ("float32", [2, 25]),
            ]
        )
        res = self.assign_group_by_size(
            var_list, [False] * 4, [400], [3, 0, 1, 2]
        )
        self.assertEqual([[3, 0], [1], [2]], res)

    def test_construct_group9(self):
        # one dtype & one limit capability & multi tensor type
        # & have tensor_indices
        var_list = self._create_var_list(
            [
                ("float32", [2, 25]),
                ("float32", [2, 25]),
                ("float32", [2, 25]),
                ("float32", [2, 1000]),
            ]
        )
        res = self.assign_group_by_size(
            var_list, [False, False, False, True], [300], [1, 0, 2, 3]
        )
        self.assertEqual([[1, 0], [3], [2]], res)

    def test_construct_group_in_legacy_mode(self):
        # NOTE(review): the eager guard is entered and immediately exited, so
        # the calls below run in whatever dygraph mode is active after the
        # guard exits. Given the method name, the intent is presumably to
        # exercise the legacy ``core.assign_group_by_size`` path. Confirm
        # that legacy mode is actually the one active here.
        with _test_eager_guard():
            pass
        self.test_construct_group0()
        self.test_construct_group1()
        self.test_construct_group2()
        self.test_construct_group3()
        self.test_construct_group4()
        self.test_construct_group5()
        self.test_construct_group6()
        self.test_construct_group7()
        self.test_construct_group8()
        self.test_construct_group9()


if __name__ == '__main__':
    # Executed directly as a script: run every TestCase defined in this module.
    unittest.main()