#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import paddle
from paddle.fluid.framework import core
from paddle.fluid import compiler
from .meta_optimizer_base import MetaOptimizerBase
from ..base.private_helper_function import wait_server_ready
import logging
from paddle.static import BuildStrategy

__all__ = []


class GraphExecutionOptimizer(MetaOptimizerBase):
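    """Meta optimizer that runs the program through graph execution.

    Rather than rewriting the user program, this optimizer wraps the inner
    optimizer and compiles the main program into a ``CompiledProgram`` that
    is executed with data parallelism (the ParallelExecutor path).
    """
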
    def __init__(self, optimizer):
        super(GraphExecutionOptimizer, self).__init__(optimizer)
        self.inner_opt = optimizer
        # we do not allow meta optimizer to be inner optimizer currently
        self.meta_optimizers_white_list = []
        self.meta_optimizers_black_list = []

    def _is_graph_out(self):
        return True

    def _can_apply(self):
        """
        Basically, this is ParallelExecutor (PE); almost all programs can be
        executed here.
        """
        if not self.role_maker._is_collective:
            # update me. currently, if a parameter server is used, the
            # graph execution optimizer cannot be applied
            return False
        return not self.user_defined_strategy.without_graph_optimization

    def backward(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
        pass

    # should fix the variable
    def _setup_nccl_op(self, startup_program, main_program, build_strategy):
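        """Generate communicator ids before graph execution starts.

        Creates persistable RAW variables (NCCLID on GPU, BKCLID on XPU) in
        the startup program and appends a ``gen_nccl_id`` / ``gen_bkcl_id``
        op so that every trainer ends up with the same communicator ids.
        """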
        trainer_endpoints = self.role_maker._get_trainer_endpoints()
        other_trainers = copy.copy(trainer_endpoints)

        trainer_id = self.role_maker._worker_index()
        current_endpoint = self.role_maker._get_trainer_endpoints()[trainer_id]
        other_trainers.remove(current_endpoint)

        trainer_endpoints_env = ",".join(trainer_endpoints)
        trainers_num = self.role_maker._worker_num()

        # NOTE(wangxi): NPU does not need to wait for the servers to be ready
        if trainer_id == 0 and not paddle.is_compiled_with_npu():
            wait_server_ready(other_trainers)

        if build_strategy.reduce_strategy == BuildStrategy.ReduceStrategy._NoReduce:
            return

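        # Pre-create the persistable RAW variables that will hold the
        # generated communicator ids; the gen_*_id op fills them at startup.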
        if core.is_compiled_with_cuda():
            comm_id_var = startup_program.global_block().create_var(
                name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)

            for i in range(1, build_strategy.nccl_comm_num):
                startup_program.global_block().create_var(
                    name="NCCLID_{}".format(i),
                    persistable=True,
                    type=core.VarDesc.VarType.RAW)

            if build_strategy.use_hierarchical_allreduce:
                for i in range(0, build_strategy.nccl_comm_num):
                    startup_program.global_block().create_var(
                        name="Hierarchical_inter_NCCLID_{}".format(i),
                        persistable=True,
                        type=core.VarDesc.VarType.RAW)
                    startup_program.global_block().create_var(
                        name="Hierarchical_exter_NCCLID_{}".format(i),
                        persistable=True,
                        type=core.VarDesc.VarType.RAW)

            startup_program.global_block().append_op(
                type="gen_nccl_id",
                inputs={},
                outputs={"NCCLID": comm_id_var},
                attrs={
                    "trainers":
                    trainer_endpoints,
                    "trainer_id":
                    trainer_id,
                    "nccl_comm_num":
                    build_strategy.nccl_comm_num,
                    "use_hierarchical_allreduce":
                    build_strategy.use_hierarchical_allreduce,
                    "hierarchical_allreduce_inter_ranks":
                    build_strategy.hierarchical_allreduce_inter_nranks
                })
        elif core.is_compiled_with_xpu():
            comm_id_var = startup_program.global_block().create_var(
                name="BKCLID", persistable=True, type=core.VarDesc.VarType.RAW)

            # NOTE(liuyuhui): Baidu Kunlun Communication Library (BKCL)
            # currently does not support multiple machines.
            assert build_strategy.bkcl_comm_num == 1, \
                "Baidu Kunlun Communication Library (BKCL) currently does not support multiple machines."
            for i in range(1, build_strategy.bkcl_comm_num):
                startup_program.global_block().create_var(
                    name="BKCLID_{}".format(i),
                    persistable=True,
                    type=core.VarDesc.VarType.RAW)

            startup_program.global_block().append_op(
                type="gen_bkcl_id",
                inputs={},
                outputs={"BKCLID": comm_id_var},
                attrs={
                    "trainers":
                    trainer_endpoints,
                    "trainer_id":
                    trainer_id,
                    "nccl_comm_num":
                    build_strategy.nccl_comm_num,
                    "use_hierarchical_allreduce":
                    build_strategy.use_hierarchical_allreduce,
                    "hierarchical_allreduce_inter_ranks":
                    build_strategy.hierarchical_allreduce_inter_nranks
                })
        else:
            raise ValueError(
                "comm_id must be generated in paddlepaddle-xpu or paddlepaddle-gpu."
            )

    def _try_to_compile(self, startup_program, main_program, loss):
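        """Build a ``CompiledProgram`` according to the distributed strategy.

        Copies the graph-execution related fields of the user-defined
        strategy onto the local ``BuildStrategy``, downgrades multi-comm and
        hierarchical allreduce settings for single-node runs, generates the
        communicator ids when there is more than one trainer, and finally
        compiles ``main_program`` with data parallelism.
        """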
        dist_strategy = self.user_defined_strategy
        local_build_strategy = dist_strategy.build_strategy

        local_build_strategy.use_hierarchical_allreduce = \
            dist_strategy.use_hierarchical_allreduce
        local_build_strategy.hierarchical_allreduce_inter_nranks = \
            dist_strategy.hierarchical_allreduce_inter_nranks
        local_build_strategy.sync_batch_norm = \
            dist_strategy.sync_batch_norm
        local_build_strategy.fuse_all_reduce_ops = \
            dist_strategy.fuse_all_reduce_ops
        local_build_strategy.nccl_comm_num = \
            dist_strategy.nccl_comm_num

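        # Map the user-facing scale_strategy name ('avg' / 'sum' / 'customized')
        # to the corresponding BuildStrategy.GradientScaleStrategy value.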
        gradient_scale_configs = self.user_defined_strategy.gradient_scale_configs
        scale_strategys = {
            'avg': BuildStrategy.GradientScaleStrategy.CoeffNumDevice,
            'sum': BuildStrategy.GradientScaleStrategy.One,
            'customized': BuildStrategy.GradientScaleStrategy.Customized,
        }
        assert gradient_scale_configs['scale_strategy'] in scale_strategys, \
            "gradient_scale_configs.scale_strategy must be 'avg', 'sum' or 'customized'"
        local_build_strategy.gradient_scale_strategy = \
            scale_strategys[gradient_scale_configs['scale_strategy']]

        if self.user_defined_strategy.recompute:
            logging.warn(
                "set enable_sequential_execution=True since you have enabled the recompute strategy"
            )
            local_build_strategy.enable_sequential_execution = True

        exe_strategy = self.user_defined_strategy.execution_strategy
        worker_num = self.role_maker._worker_num()
        node_num = self.role_maker._node_num()

        if self.role_maker._is_collective:
            assert worker_num >= 1, \
                "nccl2 worker_num must >= 1, now:{}".format(worker_num)

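        # A single process / single node does not benefit from multiple NCCL
        # communicators or hierarchical allreduce, so downgrade those settings.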
        if worker_num <= 1:
            # local mode
            if local_build_strategy.nccl_comm_num > 1:
                logging.warn("set nccl_comm_num=1 since you only have 1 node.")
            local_build_strategy.nccl_comm_num = 1

        if node_num <= 1:
            if local_build_strategy.use_hierarchical_allreduce:
                logging.warn(
                    "set hierarchical_allreduce=False since you only have 1 node."
                )
            local_build_strategy.use_hierarchical_allreduce = False

        sync_allreduce = dist_strategy.sync_nccl_allreduce
        if sync_allreduce:
            exe_strategy.num_threads = max(
                local_build_strategy.nccl_comm_num + 1,
                exe_strategy.num_threads)
            if local_build_strategy.nccl_comm_num > 1:
                logging.warn(
                    "nccl_comm_num > 1, you may need to set sync_nccl_allreduce=False to ensure that different nccl comms can overlap"
                )

        sync_batch_norm = local_build_strategy.sync_batch_norm
        if sync_batch_norm:
            local_build_strategy.nccl_comm_num = 1
            local_build_strategy.use_hierarchical_allreduce = False
            exe_strategy.num_threads = 1
            logging.warn(
                "using sync_batch_norm will hang when num_threads > 1, so set "
                "num_threads=1, nccl_comm_num=1, hierarchical_allreduce=False."
            )

        # NOTE: set these on the program for compatibility with the compiler;
        # otherwise these values will be overwritten by the compiler.
        main_program._nccl_comm_num = local_build_strategy.nccl_comm_num
        main_program._use_hierarchical_allreduce = local_build_strategy.use_hierarchical_allreduce
        main_program._hierarchical_allreduce_inter_nranks = local_build_strategy.hierarchical_allreduce_inter_nranks

        # TODO(guru4elephant): should be an independent optimizer
        if worker_num > 1:
            self._setup_nccl_op(startup_program, main_program,
                                local_build_strategy)

        local_build_strategy.num_trainers = self.role_maker._worker_num()
        local_build_strategy.trainer_id = self.role_maker._worker_index()
        local_build_strategy.trainers_endpoints = \
            self.role_maker._get_trainer_endpoints()
        local_build_strategy.enable_backward_optimizer_op_deps = True

        self._compiled_program = compiler.CompiledProgram(main_program)

        self._compiled_program.with_data_parallel(
            loss_name=loss.name,
            build_strategy=local_build_strategy,
            exec_strategy=exe_strategy,
            share_vars_from=None)

        return self._compiled_program

    def _disable_strategy(self, dist_strategy):
        # TODO(guru4elephant): should close all PE related flags here
        return

    def _enable_strategy(self, dist_strategy, context):
        # by default, graph execution strategy is enabled
        return

    def minimize(self,
                 loss,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
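        """Compile the program that owns ``loss`` for graph execution.

        The compiled program is attached to ``loss.block.program._graph``.
        No optimize ops are appended here, so ``(None, None)`` is returned.
        """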
        if startup_program is None:
            startup_program = paddle.static.default_startup_program()
        compiled_program = self._try_to_compile(startup_program,
                                                loss.block.program, loss)
        loss.block.program._graph = compiled_program

        # just return self.optimizer_ops and self.param_grads
        return None, None
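

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not executed as part of this module, and
# exact strategy fields may differ between Paddle versions). This optimizer
# is normally selected implicitly by fleet for collective static-graph
# training, roughly as follows:
#
#   import paddle
#   import paddle.distributed.fleet as fleet
#
#   paddle.enable_static()
#   strategy = fleet.DistributedStrategy()
#   strategy.nccl_comm_num = 2
#   fleet.init(is_collective=True, strategy=strategy)
#   optimizer = paddle.optimizer.SGD(learning_rate=0.001)
#   optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
#   # optimizer.minimize(loss) eventually routes through
#   # GraphExecutionOptimizer.minimize and compiles the program.
# ---------------------------------------------------------------------------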