“9d80edd673be6419adae49d79f63768984dc78c8”上不存在“paddle/fluid/git@gitcode.net:RobotFutures/Paddle.git”
test_imperative_group.py 7.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
import paddle.fluid.core as core
19 20
from paddle.fluid.framework import (
    _in_legacy_dygraph,
21
    _test_eager_guard,
22 23
    in_dygraph_mode,
)
24 25 26


class TestDataParallelGroup(unittest.TestCase):
    """Exercise the core ``assign_group_by_size`` bindings.

    Every case builds a list of random tensors, passes it to the binding
    together with a per-tensor boolean flag list, a list of size limits and
    (optionally) an explicit tensor index order, and compares the returned
    grouping (a list of lists of tensor indices) with the expected one.
    """

    # Fixture shared by most cases: float32 and float64 tensors alternate,
    # and each one occupies 200 bytes (50 * 4 == 25 * 8).
    _MIXED_DTYPE_SPECS = (("float32", (1, 50)), ("float64", (1, 25))) * 3

    def create_varbase(self, dtype, shape):
        """Return a random tensor with the given ``dtype`` and ``shape``."""
        return paddle.rand(shape=shape, dtype=dtype)

    def _make_vars(self, specs):
        """Build one tensor per ``(dtype, shape)`` pair, preserving order."""
        return [
            self.create_varbase(dtype, list(shape)) for dtype, shape in specs
        ]

    def assign_group_by_size(self, *args):
        """Forward ``args`` to the binding that matches the dygraph mode.

        Returns:
            Whatever the selected core binding returns; the tests compare
            it with a list of lists of tensor indices.

        Raises:
            RuntimeError: if neither eager nor legacy dygraph mode is
                active. Without this the helper would fall through and
                return ``None``, producing a misleading assertion failure.
        """
        if in_dygraph_mode():
            return core.eager_assign_group_by_size(*args)
        if _in_legacy_dygraph():
            return core.assign_group_by_size(*args)
        raise RuntimeError(
            "assign_group_by_size requires dygraph mode "
            "(eager or legacy), but neither is active."
        )

    def test_construct_group0(self):
        # one dtype & one limit capability
        var_list = self._make_vars(
            [
                ("float32", (2, 50)),
                ("float32", (2, 100)),
                ("float32", (2, 50)),
                ("float32", (2, 25)),
            ]
        )
        res = self.assign_group_by_size(var_list, [False] * 4, [400])
        self.assertEqual([[0], [1], [2], [3]], res)

    def test_construct_group1(self):
        # multi dtype & one limit capability
        var_list = self._make_vars(self._MIXED_DTYPE_SPECS)
        res = self.assign_group_by_size(var_list, [False] * 6, [400])
        self.assertEqual([[0, 2], [1, 3], [4], [5]], res)

    def test_construct_group2(self):
        # one dtype & multi limit capability
        var_list = self._make_vars([("float32", (2, 50))] * 4)
        res = self.assign_group_by_size(var_list, [False] * 4, [400, 800])
        self.assertEqual([[0], [1, 2], [3]], res)

    def test_construct_group3(self):
        # multi dtype & multi limit capability
        var_list = self._make_vars(self._MIXED_DTYPE_SPECS)
        res = self.assign_group_by_size(var_list, [False] * 6, [200, 400])
        self.assertEqual([[0], [1], [2, 4], [3, 5]], res)

    def test_construct_group4(self):
        # multi dtype & zero limit capability
        var_list = self._make_vars(self._MIXED_DTYPE_SPECS)
        res = self.assign_group_by_size(var_list, [False] * 6, [0])
        self.assertEqual([[0], [1], [2], [3], [4], [5]], res)

    def test_construct_group5(self):
        # multi dtype & infinite capability
        var_list = self._make_vars(self._MIXED_DTYPE_SPECS)
        res = self.assign_group_by_size(var_list, [False] * 6, [10000])
        self.assertEqual([[0, 2, 4], [1, 3, 5]], res)

    def test_construct_group6(self):
        # multi dtype & limit capability & multi tensor type
        var_list = self._make_vars(self._MIXED_DTYPE_SPECS)
        res = self.assign_group_by_size(
            var_list, [True, False, False, False, False, True], [400]
        )
        self.assertEqual([[0], [1, 3], [2, 4], [5]], res)

    def test_construct_group7(self):
        # multi dtype & multi limit capability & multi tensor type
        var_list = self._make_vars(self._MIXED_DTYPE_SPECS)
        res = self.assign_group_by_size(
            var_list, [True, False, False, False, False, True], [200, 400]
        )
        self.assertEqual([[0], [1], [2], [3], [4], [5]], res)

    def test_construct_group8(self):
        # one dtype & one limit capability & have tensor_indices
        var_list = self._make_vars(
            [
                ("float32", (2, 25)),
                ("float32", (2, 100)),
                ("float32", (2, 50)),
                ("float32", (2, 25)),
            ]
        )
        res = self.assign_group_by_size(
            var_list, [False] * 4, [400], [3, 0, 1, 2]
        )
        self.assertEqual([[3, 0], [1], [2]], res)

    def test_construct_group9(self):
        # one dtype & one limit capability & have tensor_indices
        var_list = self._make_vars(
            [
                ("float32", (2, 25)),
                ("float32", (2, 25)),
                ("float32", (2, 25)),
                ("float32", (2, 1000)),
            ]
        )
        res = self.assign_group_by_size(
            var_list, [False, False, False, True], [300], [1, 0, 2, 3]
        )
        self.assertEqual([[1, 0], [3], [2]], res)

    def test_construct_group_in_legacy_mode(self):
        # NOTE(review): the guard is entered and left with an empty body, so
        # the cases below run outside of it. Presumably only its enter/exit
        # side effects are wanted here -- confirm against _test_eager_guard.
        with _test_eager_guard():
            pass
        for case in (
            self.test_construct_group0,
            self.test_construct_group1,
            self.test_construct_group2,
            self.test_construct_group3,
            self.test_construct_group4,
            self.test_construct_group5,
            self.test_construct_group6,
            self.test_construct_group7,
            self.test_construct_group8,
            self.test_construct_group9,
        ):
            case()
186 187


188 189
# Run the test cases only when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()