# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function, division

import numpy as np
import unittest

import paddle

# RUN_STEP: number of training steps each trainer in test_dist_base runs;
# both run paths below return one loss value per step
from test_dist_base import RUN_STEP


# NOTE: minimal args object compatible with what
# TestParallelDyGraphRunnerBase.run_trainer_with_spawn expects
class SpawnAssistTestArgs(object):
    update_method = "local"
    trainer_id = 0


class TestDistSpawnRunner(unittest.TestCase):
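    """Compare single-card training against multi-card training launched
    via paddle.distributed.spawn, checking that per-step losses agree."""
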
    def setUp(self):
        # NOTE(chenweihang): keep consistent with
        # TestDistBase.check_with_place
        self.nprocs = 2

    def _run(self, model, args):
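        # single-card baseline: run the trainer directly in the current process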
        args.update_method = "local"
        return model.run_trainer_with_spawn(args)

    def _run_parallel(self, model, args):
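        # multi-card run: spawn self.nprocs trainer processes that
        # communicate via NCCL, blocking (join=True) until all finish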
        args.update_method = "nccl2"
        context = paddle.distributed.spawn(
            func=model.run_trainer_with_spawn,
            args=(args, ),
            nprocs=self.nprocs,
            join=True)
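        # each spawned process pushes its per-step losses into its
        # return queue; collect one result list per process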
        result_list = []
        for res_queue in context.return_queues:
            result_list.append(res_queue.get())
        return result_list

    def check_dist_result_with_spawn(self, test_class, delta=1e-3):
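        """Train test_class on a single card and on self.nprocs cards,
        then assert that the per-step losses match within delta."""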
        # 0. prepare model and args
        model = test_class()
        args = SpawnAssistTestArgs()

        # 1. calc single-card loss
        losses = self._run(model, args)

        # 2. calc multi-card loss (nccl mode)
        dist_losses_list = self._run_parallel(model, args)

        # 3. compare losses
        for step_id in range(RUN_STEP):
            loss = losses[step_id]
            # average this step's loss across all spawned processes
            dist_loss = np.mean(
                [np.array(dist_losses[step_id])
                 for dist_losses in dist_losses_list],
                axis=0)
            self.assertAlmostEqual(
                loss,
                dist_loss,
                delta=delta,
                msg="The results of single-card execution and multi-card execution are inconsistent."
                "signal-card loss is:\n{}\nmulti-card average loss is:\n{}\n".
                format(loss, dist_loss))
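
# Minimal usage sketch (an illustrative assumption, not part of the original
# tests): a concrete test subclasses TestDistSpawnRunner and passes a model
# class that follows the TestParallelDyGraphRunnerBase interface. `TestMnist`
# below is a hypothetical such model class.
#
#     class TestMnistSpawn(TestDistSpawnRunner):
#         def test_mnist_with_spawn(self):
#             if paddle.is_compiled_with_cuda():
#                 self.check_dist_result_with_spawn(
#                     test_class=TestMnist, delta=1e-5)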