# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY
from paddle.fluid import core
from paddle.static import Program


class TaskNode:
    """
    Python-side TaskNode, a connection to the C++-side TaskNode.
    """

    def __init__(
        self,
        rank,
        max_run_times,
        role=None,
        node_type=None,
        task_id=0,
        ops=None,
        program=None,
        lazy_initialize=False,
        cond_var_name=None,
        vars_to_dtype=None,
        vars_to_shape=None,
    ):
        """
        :param rank (int): Current rank of the task node.
        :param max_run_times (int): The max run times of the task node.
        :param role (int): The role of the task node. (Will be removed in the future)
        :param node_type (str): The type of the task node.
        :param task_id (int): The id of the task node.
        :param ops (list): A list of op.desc to init the task node. (Will be removed in the future)
        :param program (Program): An instance of Program to init the task node.
        :param lazy_initialize (bool): In a user-defined task, the program may still change (e.g. by adding feed/fetch ops). For efficiency, the C++ TaskNode object is created later, on the first call to task_node().
        :param cond_var_name (string): The name of the condition var of the while loop.
        :param vars_to_dtype: The dtypes of the vars to send.
        :param vars_to_shape: The shapes of the vars to send.
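        Example (illustrative; `main_program` stands for a user-provided Program):
            node = TaskNode(rank=0, max_run_times=4, program=main_program,
                            task_id=1, node_type="Compute", lazy_initialize=True)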
        """
        assert (ops is not None) ^ (
            program is not None
        ), "Should provide exactly one of ops or program to the task node."
        assert not (
            (ops is not None) and lazy_initialize
        ), "Lazy initialization is not supported with an ops list."
        self.id = int(task_id)
        self.rank = rank
        self.max_run_times = max_run_times
        self.node_type = node_type
        self.program = program
        self.lazy_initialize = lazy_initialize
        self.cond_var_name = cond_var_name
        self.vars_to_dtype = vars_to_dtype
        self.vars_to_shape = vars_to_shape
        self.run_pre_steps = None
        self.run_at_offset = None
        self.node = None
        self.upstreams = []
        self.downstreams = []
        if not lazy_initialize:
            if ops is not None:
                assert (
                    role is not None and task_id is not None
                ), "If the task node is initialized with ops, `role` and `task_id` must be provided."
                self.node = core.TaskNode(
                    role,
                    ops,
                    rank,
                    task_id,
                    max_run_times,
                )
            else:
                self.node = core.TaskNode(
                    program.desc,
                    rank,
                    self.id,
                    max_run_times,
                )
            if self.node_type:
                self.node.set_type(self.node_type)

    def task_node(self):
        if self.lazy_initialize:
            self.node = core.TaskNode(
                self.program.desc,
                self.rank,
                self.id,
                self.max_run_times,
            )
            if self.node_type:
                self.node.set_type(self.node_type)
            if self.run_pre_steps:
                self.node.set_run_pre_steps(self.run_pre_steps)
            if self.run_at_offset:
                self.node.set_run_at_offset(self.run_at_offset)
            if self.cond_var_name:
                self.node.set_cond_var_name(self.cond_var_name)
            if self.vars_to_shape:
                self.node.set_vars_to_shape(self.vars_to_shape)
            if self.vars_to_dtype:
                self.node.set_vars_to_dtype(self.vars_to_dtype)
            for up in self.upstreams:
                self.node.add_upstream_task(up[0], up[1], up[2])
            for down in self.downstreams:
                self.node.add_downstream_task(down[0], down[1], down[2])
            self.lazy_initialize = False
        return self.node

    def set_program(self, program):
        assert (
            self.lazy_initialize
        ), "The inner program is unchangeable for an immediately initialized task node. Set lazy_initialize to True if the inner program needs to be updated, and make all changes before calling node.task_node()."
        self.program = program

    def get_program(self):
        assert (
            self.program is not None
        ), "The task node was not initialized with a program."
        return self.program

    def set_run_pre_steps(self, steps):
        if self.lazy_initialize:
            self.run_pre_steps = steps
        else:
            self.node.set_run_pre_steps(steps)

    def set_run_at_offset(self, offset):
        if self.lazy_initialize:
            self.run_at_offset = offset
        else:
            self.node.set_run_at_offset(offset)

    def add_upstream_task(
        self, upstream, buffer_size=2, depend_type=core.DependType.NORMAL
    ):
        if self.lazy_initialize:
            self.upstreams.append((upstream, buffer_size, depend_type))
        else:
            self.node.add_upstream_task(upstream, buffer_size, depend_type)

    def add_downstream_task(
        self, downstream, buffer_size=2, depend_type=core.DependType.NORMAL
    ):
        if self.lazy_initialize:
            self.downstreams.append((downstream, buffer_size, depend_type))
        else:
            self.node.add_downstream_task(downstream, buffer_size, depend_type)

    def task_id(self):
161
        return self.id


class CoordSys:
    """
    This class maps a rank to its (mp rank, sharding rank, pp rank, dp rank) coordinate and back.
    """

    def __init__(self, dist_opt):
        self.dp_degree = dist_opt.get('dp_degree', 1)
        self.pp_degree = dist_opt.get('pp_degree', 1)
        self.sharding_degree = dist_opt.get('sharding_degree', 1)
        self.mp_degree = dist_opt.get('mp_degree', 1)

    def _invalid_coord(self, coord):
        """
        Test whether the input coord is valid.
        :param coord: The coord to be tested
        :return: False if valid, True if invalid.
        """
        return (
            coord['mp_idx'] < 0
            or coord['mp_idx'] >= self.mp_degree
            or coord['sharding_idx'] < 0
            or coord['sharding_idx'] >= self.sharding_degree
            or coord['pp_idx'] < 0
            or coord['pp_idx'] >= self.pp_degree
            or coord['dp_idx'] < 0
            or coord['dp_idx'] >= self.dp_degree
        )

    def coord_to_rank(self, coord):
        """
        Map the input coord to its corresponding rank.
        :param coord: The coord to be converted
        :return: The rank corresponding to the coord
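        Example (illustrative): with mp_degree=2, sharding_degree=1, pp_degree=2
            and dp_degree=2, the coord {'mp_idx': 1, 'sharding_idx': 0, 'pp_idx': 1,
            'dp_idx': 0} maps to rank 0 * 4 + 1 * 2 + 0 * 2 + 1 = 3.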
        """
        if self._invalid_coord(coord):
            return -1
        return int(
            coord['dp_idx']
            * self.pp_degree
            * self.sharding_degree
            * self.mp_degree
            + coord['pp_idx'] * self.sharding_degree * self.mp_degree
            + coord['sharding_idx'] * self.mp_degree
            + coord['mp_idx']
        )

    def rank_to_coord(self, rank):
        """
        Map the input rank to its corresponding coord.
        :param rank: The rank to be converted
        :return: The coord corresponding to the rank
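        Example (illustrative): with mp_degree=2, sharding_degree=1, pp_degree=2
            and dp_degree=2, rank 3 maps back to
            {'mp_idx': 1, 'sharding_idx': 0, 'pp_idx': 1, 'dp_idx': 0}.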
        """
        mp_idx = rank % self.mp_degree
        rank //= self.mp_degree
        sharding_idx = rank % self.sharding_degree
        rank //= self.sharding_degree
        pp_idx = rank % self.pp_degree
        rank //= self.pp_degree
        dp_idx = rank % self.dp_degree
        return {
            'mp_idx': int(mp_idx),
            'sharding_idx': int(sharding_idx),
            'pp_idx': int(pp_idx),
            'dp_idx': int(dp_idx),
        }


class FleetExecutorUtils:
    def __init__(
        self, dist_strategy=None, rank=None, nrank=None, max_run_times=None
    ):
        self.dist_strategy = dist_strategy
        self.rank = rank
        self.nrank = nrank
        self.max_run_times = max_run_times
        self.is_auto_parallel = dist_strategy is None
        self.num_of_functionality = 4  # lr, fwd, bwd, opt
        self.coord_sys = None
        self.coord = None
        if dist_strategy:
            self.coord_sys = CoordSys(dist_strategy)
            self.coord = self.coord_sys.rank_to_coord(rank)

    def is_optimizer_op(self, op_role):
        return op_role == int(OpRole.Optimize)

    def is_lr_sched_op(self, op_role):
        return op_role == int(OpRole.Optimize.LRSched)

    def is_forward_op(self, op_role):
        return (op_role == int(OpRole.Forward)) or (
            op_role == (int(OpRole.Forward) | int(OpRole.Loss))
        )

    def is_backward_op(self, op_role):
        return (op_role == int(OpRole.Backward)) or (
            op_role == (int(OpRole.Backward) | int(OpRole.Loss))
        )

    def split_program_to_op_list(self, program):
        op_list_map = {"lr": [], "fwd": [], "bwd": [], "opt": []}
        for op in program.block(0).ops:
            # split the program based on the op_role
            op_role = int(op.all_attrs()[OP_ROLE_KEY])
            if self.is_lr_sched_op(op_role):
                op_list_map["lr"].append(op)
            elif self.is_forward_op(op_role):
                op_list_map["fwd"].append(op)
            elif self.is_backward_op(op_role):
                op_list_map["bwd"].append(op)
            elif self.is_optimizer_op(op_role):
                op_list_map["opt"].append(op)
            else:
                raise ValueError(
                    "The op role: "
                    + str(op_role)
                    + " isn't one of LRSched, Forward, Backward or Optimizer."
                )
        return op_list_map

    def convert_op_list_to_program(self, op_list, complete_program):
        # TODO(liyurui): Complete this convert logic
        program_map = {
            "lr": Program(),
            "fwd": Program(),
            "bwd": Program(),
            "opt": Program(),
        }
        return program_map

    def build_1f1b_dependency(self, task_node_map):
        assert (
            not self.is_auto_parallel
        ), "Manually adding dependencies should not be invoked in auto parallel mode."
        # Generated the dependency based on this graph:
        # lr(1:m) -> forward -> backward -> (m:1)optimize
        #               ↑          ↓
        # lr(1:m) -> forward -> backward -> (m:1)optimize
        #               ↑          ↓
        # lr(1:m) -> forward -> backward -> (m:1)optimize

        # add dependency intra stage
        cur_start_id = self.rank * self.num_of_functionality
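        # Task ids for this rank follow the layout used in
        # construct_task_nodes_1f1b*: lr = cur_start_id, fwd = cur_start_id + 1,
        # bwd = cur_start_id + 2, opt = cur_start_id + 3.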
        pp_buff_size = int(
            self.dist_strategy['pp_degree'] - self.coord['pp_idx']
        )
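        # Note (illustrative reading): in 1F1B, earlier pipeline stages keep more
        # micro batches in flight before their first backward arrives, so the
        # fwd -> bwd buffer size grows toward the first stage.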
        task_node_map["lr"].add_downstream_task(cur_start_id + 1)
        task_node_map["fwd"].add_upstream_task(cur_start_id)
        task_node_map["fwd"].add_downstream_task(cur_start_id + 2, pp_buff_size)
        task_node_map["bwd"].add_upstream_task(cur_start_id + 1, pp_buff_size)
        task_node_map["bwd"].add_downstream_task(cur_start_id + 3)
        task_node_map["opt"].add_upstream_task(cur_start_id + 2)
        # add dependency inter stage
        upstream_coord, downstream_coord = self.coord.copy(), self.coord.copy()
        upstream_coord['pp_idx'] = upstream_coord['pp_idx'] - 1
        downstream_coord['pp_idx'] = downstream_coord['pp_idx'] + 1
        pp_upstream = self.coord_sys.coord_to_rank(upstream_coord)
        pp_downstream = self.coord_sys.coord_to_rank(downstream_coord)
        first_stage = pp_upstream == -1
        last_stage = pp_downstream == -1
        prev_pp_start_id = pp_upstream * self.num_of_functionality
        next_pp_start_id = pp_downstream * self.num_of_functionality
        if not first_stage:
            task_node_map["fwd"].add_upstream_task(prev_pp_start_id + 1)
            task_node_map["bwd"].add_downstream_task(prev_pp_start_id + 2)
        if not last_stage:
            task_node_map["fwd"].add_downstream_task(next_pp_start_id + 1)
            task_node_map["bwd"].add_upstream_task(next_pp_start_id + 2)
        return task_node_map

    def construct_task_nodes_1f1b(self, program_map):
        cur_start_id = int(self.rank * self.num_of_functionality)
        lr_task_node = TaskNode(
            rank=self.rank,
            max_run_times=self.max_run_times,
            program=program_map["lr"],
            task_id=cur_start_id,
        )
        fwd_task_node = TaskNode(
            rank=self.rank,
            max_run_times=self.max_run_times,
            program=program_map["fwd"],
            task_id=cur_start_id + 1,
        )
        bwd_task_node = TaskNode(
            rank=self.rank,
            max_run_times=self.max_run_times,
            program=program_map["bwd"],
            task_id=cur_start_id + 2,
        )
        opt_task_node = TaskNode(
            rank=self.rank,
            max_run_times=self.max_run_times,
            program=program_map["opt"],
            task_id=cur_start_id + 3,
        )
        return {
            "lr": lr_task_node,
            "fwd": fwd_task_node,
            "bwd": bwd_task_node,
            "opt": opt_task_node,
        }

    def task_id_to_rank(self):
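        # Illustrative result (assuming nrank=2 and 4 functionalities per rank):
        #   {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1}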
        task_id_to_rank = {}
        for i in range(self.nrank):
            for j in range(self.num_of_functionality):
                task_id_to_rank[int(i * self.num_of_functionality + j)] = i
        return task_id_to_rank

    def construct_task_nodes_1f1b_op_list(self, op_list_map):
        cur_start_id = int(self.rank * self.num_of_functionality)
        lr_task_node = TaskNode(
            rank=self.rank,
            max_run_times=self.max_run_times,
            role=int(OpRole.Optimize.LRSched),
            ops=op_list_map["lr"],
            task_id=cur_start_id,
            node_type="Amplifier",
        )
        lr_task_node.set_run_pre_steps(self.max_run_times)
        fwd_task_node = TaskNode(
            rank=self.rank,
            max_run_times=self.max_run_times,
            role=int(OpRole.Forward),
            ops=op_list_map["fwd"],
            task_id=cur_start_id + 1,
            node_type="Compute",
        )
        bwd_task_node = TaskNode(
            rank=self.rank,
            max_run_times=self.max_run_times,
            role=int(OpRole.Backward),
            ops=op_list_map["bwd"],
            task_id=cur_start_id + 2,
            node_type="Compute",
        )
        opt_task_node = TaskNode(
            rank=self.rank,
            max_run_times=self.max_run_times,
            role=int(OpRole.Optimize),
            ops=op_list_map["opt"],
            task_id=cur_start_id + 3,
            node_type="Amplifier",
        )
        opt_task_node.set_run_pre_steps(self.max_run_times)
        opt_task_node.set_run_at_offset(self.max_run_times - 1)
        return {
            "lr": lr_task_node,
            "fwd": fwd_task_node,
            "bwd": bwd_task_node,
            "opt": opt_task_node,
        }


def run1f1b(
    program,
    rank,
    max_run_times,
    dist_opt,
    nrank,
    with_standalone_executor=False,
):
    """
    Split the program to support the 1f1b pipeline scheduler.
    This function splits the program based on the op_role.
    The program will be split into four parts: lr_sched, fwd, bwd, opt.
    Task nodes will then be created based on the four parts of the program.
    :param program: The origin program.
    :param rank: Current rank (can be obtained from fleet.worker_index()).
    :param max_run_times: Max run times for a micro batch, i.e. the number of micro steps.
    :param dist_opt: The fleet_opt configured by the user.
    :param nrank: Number of workers (can be obtained from fleet.worker_num()).
    :param with_standalone_executor: Experimental feature, use fleet executor with the standalone executor.
    :return:
        task_nodes (list): the four task nodes for the current rank
        task_id_to_rank (dict): mapping from task node ids to their corresponding ranks
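    Example (illustrative; `main_program` and `fleet_opt` are user-provided):
        task_nodes, task_id_to_rank = run1f1b(
            main_program, fleet.worker_index(), 4, fleet_opt, fleet.worker_num()
        )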
    """
    print("fleet executor will use python side 1f1b scheduler.")
    fleet_executor_utils = FleetExecutorUtils(
        dist_strategy=dist_opt,
        rank=rank,
        nrank=nrank,
        max_run_times=max_run_times,
    )
    op_list_map = fleet_executor_utils.split_program_to_op_list(program)
    task_node_map = None
    if with_standalone_executor:
        program_map = fleet_executor_utils.convert_op_list_to_program(
            op_list_map, program
        )
        task_node_map = fleet_executor_utils.construct_task_nodes_1f1b(
            program_map
        )
    else:
        op_desc_list_map = {"lr": [], "fwd": [], "bwd": [], "opt": []}
        for key in op_list_map:
            for op in op_list_map[key]:
                op_desc_list_map[key].append(op.desc)
        task_node_map = fleet_executor_utils.construct_task_nodes_1f1b_op_list(
            op_desc_list_map
        )
    task_node_map = fleet_executor_utils.build_1f1b_dependency(task_node_map)
    task_id_to_rank = fleet_executor_utils.task_id_to_rank()
    task_node_list = [task_node_map[key].task_node() for key in task_node_map]
    return task_node_list, task_id_to_rank


def origin(program, rank):
    """
    Origin scheduler for the fleet executor; supports non-pp (non-pipeline-parallel) mode.
    :param program: The origin program.
    :param rank: Current rank (can be obtained from fleet.worker_index()).
    :return:
        task_nodes (list): a list containing the single task node for the current rank
        task_id_to_rank (dict): a trivial dict; since there are no upstream or downstream tasks, it won't actually be used
    """
    print("fleet executor will use python side origin scheduler.")
    task_node = TaskNode(
        program=program,
        rank=rank,
        node_type="Compute",
        max_run_times=1,
    )
    task_id_to_rank = {task_node.task_id(): rank}
    return [task_node.task_node()], task_id_to_rank