# test_communication_stream_broadcast_api.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import itertools
import test_communication_api_base as test_base


class TestCommunicationStreamBroadcastAPI(test_base.CommunicationTestDistBase):
    """Distributed test for the stream-variant broadcast communication API.

    Runs ``communication_stream_broadcast_api_dygraph.py`` once for each
    combination of the ``sync_op`` / ``use_calc_stream`` settings below.
    The combination ``use_calc_stream=True`` with ``sync_op=False`` is skipped.
    """

    def setUp(self):
        """Configure the base harness and the environment matrix."""
        # num_of_devices / timeout are forwarded to the base harness;
        # NOTE(review): timeout units are presumably seconds — confirm.
        super().setUp(num_of_devices=2, timeout=120)
        # Environment passed to every run; all values are strings.
        # NOTE(review): self._seeds is presumably set by the base setUp().
        self._default_envs = {
            "backend": "nccl",
            "shape": "(100, 200)",
            "dtype": "float32",
            "seeds": str(self._seeds)
        }
        # Settings whose listed values are combined across runs (presumably
        # a cartesian product, per gen_product_envs_list's name).
        self._changeable_envs = {
            "sync_op": ["True", "False"],
            "use_calc_stream": ["True", "False"]
        }

    def test_broadcast_stream(self):
        """Run the broadcast script under each supported env combination."""
        envs_list = test_base.gen_product_envs_list(self._default_envs,
                                                    self._changeable_envs)
        for envs in envs_list:
            # The flags are the strings "True"/"False"; compare them directly
            # instead of evaluating them with eval().
            use_calc_stream = envs["use_calc_stream"] == "True"
            sync_op = envs["sync_op"] == "True"
            # NOTE(review): calc-stream with async op is presumably an
            # unsupported combination for this API — confirm.
            if use_calc_stream and not sync_op:
                continue
            self.run_test_case("communication_stream_broadcast_api_dygraph.py",
                               user_defined_envs=envs)

    def tearDown(self):
        """Delegate cleanup to the base harness."""
        super().tearDown()


if __name__ == "__main__":
    # Running this file directly runs its test cases through unittest.
    unittest.main()