# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from dist_pass_test_base import DistPassTestBase
from model_zoo import resnet_model

from paddle.distributed.passes import PassManager, new_pass


class TestFuseAllReducePass(DistPassTestBase):
    """Distributed pass test for ``fuse_all_reduce`` on a ResNet model.

    ``apply_passes`` installs ``fuse_elewise_add_act`` followed by
    ``fuse_all_reduce``; the actual run/compare logic lives in
    ``DistPassTestBase.check_main``, which is defined outside this file.
    """

    def init(self):
        # Zero absolute and relative tolerance: results are expected to match
        # exactly. NOTE(review): assumes DistPassTestBase reads ``atol`` and
        # ``rtol`` when comparing outputs -- confirm in dist_pass_test_base.
        self.atol = 0.0
        self.rtol = 0.0

    def apply_passes(self, main_prog, startup_prog):
        """Apply the fusion passes to the given programs.

        ``fuse_elewise_add_act`` is applied first, then ``fuse_all_reduce``
        with ``max_memory_size`` set to 1024 * 1024 (presumably a byte limit
        per fused group -- confirm against the pass implementation).
        Nothing is returned, so the programs are presumably modified in place.
        """
        pass_manager = PassManager(
            [
                new_pass("fuse_elewise_add_act"),
                new_pass("fuse_all_reduce", {"max_memory_size": 1024 * 1024}),
            ]
        )
        pass_manager.apply([main_prog], [startup_prog])
        # Debug output: show which passes the manager holds.
        print(pass_manager.names)

    def test_bs_32(self):
        """Run the base-class check on ``resnet_model`` with batch size 32."""
        self.check_main(resnet_model, batch_size=32)


if __name__ == "__main__":
    # When executed as a script, discover and run this module's test cases.
    unittest.main(module="__main__")