pp_layers.py
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The file has been adapted from the file:
#     https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/pipe/module.py
#     Git commit hash: fafc827d643b3eed611e282d909025f16be36601
# We retain the following license from the original files:
# MIT License

# Copyright (c) Microsoft Corporation.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE

import math
import re
import glob
import os
import numpy as np
import random
from functools import partial

import paddle
from paddle.fluid.dygraph.layers import Layer
from ...utils.log_util import logger, layer_to_str
from ..pp_utils.utils import _hp_recompute, _initialize_recompute_setting
from paddle.fluid.framework import in_dygraph_mode

__all__ = []


class LayerDesc(object):

    def __init__(self, layer_func, *inputs, **kwargs):
        self.layer_func = layer_func
        self.inputs = inputs
        self.kwargs = kwargs

        if not issubclass(layer_func, Layer):
            raise TypeError(
                "The input(layer_func) should be a derived class of Layer.")

    def build_layer(self):
        return self.layer_func(*self.inputs, **self.kwargs)

    def __repr__(self):
        return layer_to_str(self.layer_func.__name__, *self.inputs,
                            **self.kwargs)
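
    # A minimal usage sketch (illustrative, assuming paddle.nn.Linear as the
    # wrapped sub-layer): a LayerDesc only records the constructor and its
    # arguments; build_layer() instantiates the Layer later, on the pipeline
    # stage that actually owns it.
    #
    #   desc = LayerDesc(paddle.nn.Linear, 784, 512, bias_attr=False)
    #   fc = desc.build_layer()  # same as paddle.nn.Linear(784, 512, bias_attr=False)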


class SharedLayerDesc(LayerDesc):

    def __init__(self,
                 key,
                 layer_func,
                 forward_func=None,
                 shared_weight_attr='weight',
                 *inputs,
                 **kwargs):
        super(SharedLayerDesc, self).__init__(layer_func, *inputs, **kwargs)
        self.layer_name = key
        self.forward_func = forward_func
        self.shared_weight_attr = shared_weight_attr
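
    # An illustrative sketch of weight sharing across stages (EmbeddingPipe is
    # a hypothetical Layer subclass, not defined in this file): every
    # descriptor built with the same key reuses one layer instance per stage,
    # and PipelineLayer later broadcasts the attribute named by
    # shared_weight_attr from the first owning stage and allreduces its
    # gradient.
    #
    #   SharedLayerDesc('embed', EmbeddingPipe, shared_weight_attr='weight',
    #                   vocab_size=50304, hidden_size=1024)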


class SegmentLayers(object):

    def __init__(self,
                 layers_desc,
                 num_parts,
                 method="uniform",
                 num_virtual_pipeline_stage=None):
        self._layers_desc = layers_desc
        self.method = method
        self.num_parts = num_parts
        self.num_items = len(layers_desc)
        self.num_virtual_pipeline_stage = num_virtual_pipeline_stage
        if self.num_virtual_pipeline_stage is not None:
            self.total_parts = num_parts * self.num_virtual_pipeline_stage
        assert self.num_items >= self.num_parts, "layer number should not be less than the number of segments"

    def do_segment(self):
        if self.method == "uniform":
            return self.uniform(self.num_items, self.num_parts)

        elif self.method.startswith('layer:'):
            # Divide equally according to the specified layer
            layername = self.method.split(':')[1]
            weights = [0] * len(self._layers_desc)
            weight_idxs = self._gen_layer_weight(layername)
            for idx in weight_idxs:
                weights[idx] = 1

            actual_num_parts = self.num_parts if self.num_virtual_pipeline_stage is None else self.total_parts

            assert sum(
                weights
            ) % actual_num_parts == 0, "number of layers ({}) should be divisible by the part number({})".format(
                sum(weights), actual_num_parts)
            part_size = sum(weights) // actual_num_parts
            result = [0 for _ in range(actual_num_parts + 1)]

            memory_counter = 0
            result_idx = 1
            for idx, weight in enumerate(weights):
                memory_counter += weight
                if memory_counter == part_size:
                    result[result_idx] = idx + 1
                    result_idx += 1
                    memory_counter = 0
            result[actual_num_parts] = len(weights)
            return result

    def _gen_layer_weight(self, layername):
        weight_idxs = []
        regex = re.compile(layername, re.IGNORECASE)
        for idx, layer in enumerate(self._layers_desc):
            name = None
            if isinstance(layer, Layer):
                name = layer.__class__.__name__
            elif isinstance(layer, LayerDesc):
                name = layer.layer_func.__name__
            else:
                try:
                    name = layer.__name__
                except AttributeError:
                    # not an error: objects without a __name__ are simply skipped
                    continue
            if regex.search(name):
                weight_idxs.append(idx)

        assert len(
            weight_idxs) > 0, "weight_idxs' length should be greater than 0"
        return weight_idxs

    def uniform(self, num_items, num_parts):
        result = [0 for _ in range(num_parts + 1)]
        part_size = math.floor(num_items / num_parts)
        for i in range(num_parts):
            result[i] = int(min(part_size * i, num_items))
        result[num_parts] = num_items
        return result
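
    # Worked examples (illustrative): with 8 layers and num_parts=4, uniform()
    # returns [0, 2, 4, 6, 8], i.e. part i owns layers [result[i], result[i+1]).
    # With method="layer:Block" and a hypothetical description list
    # [Embedding, Block, Block, Block, Block, Head], the weights become
    # [0, 1, 1, 1, 1, 0]; for num_parts=2 each part must hold two Block layers,
    # so do_segment() returns [0, 3, 6].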


class PipelineLayerChunk(Layer):

    def __init__(self):
        super(PipelineLayerChunk, self).__init__()
        self.run_function = []

    def append(self, sublayer):
        # This method is used to unify the code in _build_layer_impl.
        # For the 1f1b scheduler, _build_layer_impl appends to a plain list.
        # For the interleave scheduler, it appends to this class.
        if isinstance(sublayer, Layer):
            self.add_sublayer(str(len(self.run_function)), sublayer)
        self.run_function.append(sublayer)

    def get_run_function(self):
        return self.run_function

    def forward(self, *args, **kwargs):
        # Users shouldn't call a PipelineLayerChunk directly, since all the
        # logic related to recompute lives in the forward function of
        # PipelineLayer. Calling a chunk directly would bypass it and lead to
        # unexpected behavior when recompute is enabled.
        raise NotImplementedError(
            "The forward function of PipelineLayerChunk cannot be called directly. "
            "Please call forward function of PipelineLayer.")


class PipelineLayer(Layer):

    def __init__(self,
                 layers,
                 num_stages=None,
                 topology=None,
                 loss_fn=None,
                 seg_method="uniform",
                 recompute_interval=0,
                 recompute_offload=False,
                 recompute_partition=False,
                 num_virtual_pipeline_stages=None):
        super(PipelineLayer, self).__init__()
        if num_stages is None and topology is None:
            raise ValueError("should provide num_stages or topology")

        if num_virtual_pipeline_stages:
            assert isinstance(num_virtual_pipeline_stages, int), \
                "num_virtual_pipeline_stages should be None or an int"
            if num_virtual_pipeline_stages > 1:
                logger.info(
                    "setting num_virtual_pipeline_stages > 1 means using the interleave scheduler instead of the 1f1b scheduler"
                )
                assert isinstance(seg_method, str), \
                    "seg_method should be a str for the interleave scheduler"
                assert seg_method.startswith('layer:'), \
                    "seg_method should start with 'layer:' for the interleave scheduler"

        self._num_virtual_pipeline_stages = 1 if num_virtual_pipeline_stages is None else num_virtual_pipeline_stages

        # lazy import
        import paddle.distributed as dist
        from paddle.distributed import fleet

        self.device_id = dist.ParallelEnv().device_id
        self.layers = layers
        self._loss_fn = loss_fn
        self._topo = topology
        self._recompute_interval = recompute_interval
        self._recompute_offload = recompute_offload
        self._recompute_partition = recompute_partition

        if recompute_interval > 0:
            logger.info(
                "Start Recompute for PipelineParallel. recompute_offload: {}, recompute_partition: {}"
                .format(recompute_offload, recompute_partition))
        _initialize_recompute_setting(recompute_offload, recompute_partition)

        world_size = dist.get_world_size()
        self.global_rank = dist.get_rank()

        if self._topo:
            self._stage_id = self._topo.get_coord(self.global_rank).pipe
            self._num_stages = self._topo.get_dim_size("pipe")
            if num_stages:
                assert self._num_stages == num_stages, "num_stages should be equal to %d" % (
                    self._num_stages)
        else:
            # construct default topology
            if world_size % num_stages != 0:
                raise ValueError(
                    "should provide a correct num_stages({}) "
                    "which evenly divides world_size({})".format(
                        num_stages, world_size))
            dp_num = world_size // num_stages
            self._topo = fleet.CommunicateTopology(["data", "pipe", "model"],
                                                   [dp_num, num_stages, 1])
            self._stage_id = self._topo.get_coord(self.global_rank).pipe
            self._num_stages = self._topo.get_dim_size("pipe")

        self._total_stages_with_virtual_stages = self._num_stages * self._num_virtual_pipeline_stages

        # initialize segment
        self._layers_desc = list(self.layers)
        self._num_layers = len(self._layers_desc)
        self.shared_layers = paddle.nn.LayerDict()
        self.shared_weight_attrs = {}

        if self._num_virtual_pipeline_stages > 1:
            # interleaving pipeline segmentation
            self._start_poss = []
            self._end_poss = []
            self._segment_network_for_interleave(seg_method)
            # _model_chunks is a list of PipelineLayerChunk, and each
            # PipelineLayerChunk holds the Layers belonging to one model chunk,
            # so _model_chunks is essentially a list of lists of layers.
            self._model_chunks = []
            self._build_layer_with_interleave()
        else:
            # 1f1b pipeline segmentation
            self._start_pos = 0
            self._end_pos = self._num_layers - 1
            self._segment_network(seg_method)
            # construct layer
            self.run_function = []
            self._build_layer()

        self.shared_comm = self._construct_shared_comm()
        self._synchronize_shared_weights()
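
    # A minimal construction sketch (illustrative; it assumes fleet has been
    # initialized with a pipeline-parallel degree of 2 and that every rank
    # runs this code). Each rank only materializes the descriptors that fall
    # into its own segment.
    #
    #   descs = [
    #       LayerDesc(paddle.nn.Linear, 784, 512),
    #       LayerDesc(paddle.nn.ReLU),
    #       LayerDesc(paddle.nn.Linear, 512, 10),
    #   ]
    #   model = PipelineLayer(layers=descs,
    #                         num_stages=2,
    #                         loss_fn=paddle.nn.CrossEntropyLoss())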

    def get_stage_from_index(self, layer_idx):
        assert 0 <= layer_idx < self._num_layers, "layer_idx is out of bounds"
        for virtual_pp_rank in range(self._num_virtual_pipeline_stages):
            # Mapping the virtual pipeline stage to the real pipeline stage.
            # start_idx marks the first segment of this virtual pp stage.
            start_idx = virtual_pp_rank * self._num_stages
            for stage in range(self._num_stages):
                # stage marks the real pp stage
                if self.segment_parts[start_idx +
                                      stage] <= layer_idx < self.segment_parts[
                                          start_idx + stage + 1]:
                    return stage
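
    # Illustrative trace, reusing the 2-stage / 2-virtual-stage / 8-layer
    # example described in _segment_network_for_interleave (segment_parts is
    # [0, 2, 4, 6, 8]): layer 5 falls into the segment [4, 6), the second
    # virtual stage owned by the first real stage, so get_stage_from_index(5)
    # returns 0.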

    def get_model_chunks(self):
        return None if self._num_virtual_pipeline_stages == 1 else self._model_chunks

    def _construct_shared_comm(self):
        shared_comm = {}
        if self._topo.get_dim("pipe") == 1:
            return shared_comm

        layers_desc = self._layers_desc
        shared_layer_names = set(s.layer_name for s in layers_desc
                                 if isinstance(s, SharedLayerDesc))
        for key in shared_layer_names:
            shared_layers = []
            for idx, layer in enumerate(layers_desc):
                if isinstance(layer,
                              SharedLayerDesc) and layer.layer_name == key:
                    shared_layers.append(idx)

            shared_stages = set(
                self.get_stage_from_index(idx) for idx in shared_layers)
            self._dp_degree = self._topo.get_dim('data')
            self._mp_degree = self._topo.get_dim('model')
            self._sharding_degree = self._topo.get_dim('sharding')

            shared_ranks = []
            for dp in range(self._dp_degree):
                for sharding in range(self._sharding_degree):
                    for mp in range(self._mp_degree):
                        shared_ranks = []
                        for s in sorted(shared_stages):
                            shared_ranks.append(
                                self._topo.get_rank_from_stage(
                                    self.global_rank,
                                    pipe=s,
                                    data=dp,
                                    sharding=sharding,
                                    model=mp))

                        group = paddle.distributed.new_group(ranks=shared_ranks)
                        if self.global_rank in shared_ranks:
                            assert key in self.shared_layers
                            if key in self.shared_layers:
                                shared_comm[key] = {
                                    'ranks': shared_ranks,
                                    'group': group,
                                    'weight_attr':
                                    self.shared_weight_attrs[key],
                                    'layer': self.shared_layers[key],
                                }
        return shared_comm

    def _synchronize_shared_weights(self):
        for key, comm in self.shared_comm.items():
            with paddle.framework.no_grad():
                paddle.distributed.broadcast(getattr(comm['layer'],
                                                     comm['weight_attr']),
                                             src=min(comm['ranks']),
                                             group=comm['group'])

            for param in comm['layer'].parameters():
                if self.global_rank != min(comm['ranks']):
                    setattr(param, 'is_firstly_shared', False)

    def allreduce_shared_weight_gradients(self):
        for key, comm in self.shared_comm.items():
            param = getattr(self.shared_layers[key], comm['weight_attr'])
            # need to use trace_op to allreduce the weight gradient
            if in_dygraph_mode():
                with paddle.framework.no_grad():
                    paddle.distributed.all_reduce(param.grad,
                                                  group=comm['group'])
            else:
                with paddle.framework.no_grad():
                    paddle.fluid.framework._dygraph_tracer().trace_op(
                        type="c_allreduce_sum",
                        inputs={'X': param._grad_ivar()},
                        outputs={'Out': param._grad_ivar()},
                        attrs={
                            'ring_id': comm['group'].id,
                            'use_calc_stream': True
                        })

    def _segment_network_for_interleave(self, seg_method):
        logger.info("start segment network for interleave scheduler")
        seg = SegmentLayers(
            self._layers_desc,
            num_parts=self._num_stages,
            method=seg_method,
            num_virtual_pipeline_stage=self._num_virtual_pipeline_stages)
        self.segment_parts = seg.do_segment()

        logger.info("segment result:" +
                    ", ".join(str(arg) for arg in self.segment_parts))

        for i in range(self._stage_id, self._total_stages_with_virtual_stages,
                       self._num_stages):
            # If there are 2 real pp stages and 2 virtual pp stages, and the model has 8 layers.
            # Layers [0, 1], [4, 5] will be assigned to the first real pp stage.
            # Layers [2, 3], [6, 7] will be assigned to the second real pp stage.
            # Layers [0, 1] and [2, 3] are the first virtual pp stage in each real pp stage.
            # Layers [4, 5] and [6, 7] are the second virtual pp stage in each real pp stage.
            assert self.segment_parts[i] <= self.segment_parts[i + 1]
            self._start_poss.append(self.segment_parts[i])
            self._end_poss.append(self.segment_parts[i + 1])

        assert len(self._start_poss) == len(self._end_poss)

        self._print_segmentation_for_debug()
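
        # Concrete numbers for the 2-stage / 2-virtual-stage / 8-layer example
        # in the loop comment above: do_segment() yields
        # segment_parts == [0, 2, 4, 6, 8], so the first real stage ends up
        # with _start_poss == [0, 4] and _end_poss == [2, 6], and the second
        # with _start_poss == [2, 6] and _end_poss == [4, 8].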

    def _segment_network(self, seg_method):
        logger.info("start segment network..")
        seg = SegmentLayers(self._layers_desc,
                            num_parts=self._num_stages,
                            method=seg_method)
        self.segment_parts = seg.do_segment()

        logger.info("segment result:" +
                    ", ".join(str(arg) for arg in self.segment_parts))

        self._start_pos = self.segment_parts[self._stage_id]
        self._end_pos = self.segment_parts[self._stage_id + 1]
        self._print_segmentation_for_debug()

    def _print_segmentation_for_debug(self):
        # print information for debug
        for stage in range(self._num_stages *
                           self._num_virtual_pipeline_stages):
            start = self.segment_parts[stage]
            end = self.segment_parts[stage + 1]
            logger.info("stage={}, global_rank={}, layer_number={}".format(
                stage, self.global_rank, end - start))

            for index, layer in enumerate(self._layers_desc[start:end]):
                logger.info("{}: {}".format(index + start, str(layer)))

        if self._num_virtual_pipeline_stages > 1:
            for stage in range(self._num_stages):
                stage_to_virtual_stage_info = "stage {} contains virtual stages: ".format(
                    stage)
                for i in range(stage, self._total_stages_with_virtual_stages,
                               self._num_virtual_pipeline_stages):
                    stage_to_virtual_stage_info += " {},".format(i)
                logger.info(stage_to_virtual_stage_info)

        if self._loss_fn:
            try:
                logger.info("loss: {}".format(self._loss_fn.__name__))
            except AttributeError:
                logger.info("loss: {}".format(self._loss_fn.__class__.__name__))

    def _build_layer_with_interleave(self):
        for i in range(len(self._start_poss)):
            start = self._start_poss[i]
            end = self._end_poss[i]
            # Get a model chunk
            chunk = self._build_layer_impl(start, end)
            assert isinstance(chunk, PipelineLayerChunk)
            # Add the chunk to all chunks and add this chunk to the sublayer
            self._model_chunks.append(chunk)
            self.add_sublayer(str(start), chunk)

    def _build_layer(self):
        start = self._start_pos
        end = self._end_pos
        self.run_function = self._build_layer_impl(start, end)

    def _build_layer_impl(self, start, end):
        if self._num_virtual_pipeline_stages > 1:
            # For the interleave scheduler, all layers belonging to one model
            # chunk are stored in a PipelineLayerChunk
            run_function = PipelineLayerChunk()
        else:
            # For the 1f1b scheduler, just use the run_function list
            run_function = self.run_function

        for index, layer in enumerate(self._layers_desc[start:end]):
            layer_index = start + index
            if isinstance(layer, Layer):
                run_function.append(layer)
                if self._num_virtual_pipeline_stages == 1:
                    # Only add the sublayer for the 1f1b scheduler;
                    # for interleave, PipelineLayerChunk will do this
                    self.add_sublayer(str(layer_index), layer)
            elif isinstance(layer, SharedLayerDesc):
                if layer.layer_name not in self.shared_layers:
                    self.shared_layers[layer.layer_name] = layer.build_layer()
                    self.shared_weight_attrs[
                        layer.layer_name] = layer.shared_weight_attr
                    for param in self.shared_layers[
                            layer.layer_name].parameters():
                        setattr(param, "is_firstly_shared", True)

                if layer.forward_func is None:
                    run_function.append(self.shared_layers[layer.layer_name])
                else:
                    run_function.append(
                        partial(layer.forward_func,
                                self.shared_layers[layer.layer_name]))
            elif isinstance(layer, LayerDesc):
                model = layer.build_layer()
                run_function.append(model)
                if self._num_virtual_pipeline_stages == 1:
                    # Only add the sublayer for the 1f1b scheduler;
                    # for interleave, PipelineLayerChunk will do this
                    self.add_sublayer(str(layer_index), model)
            else:
                run_function.append(layer)

        return run_function

    def forward_function(self, start, end):

        def execute_func(*x):
            if len(x) == 1:
                x = x[0]
            for idx, layer in enumerate(self.run_function[start:end]):
                x = layer(x)
            return x

        return execute_func

    def forward(self, input, chunk_id=None):
        if chunk_id is not None:
            assert isinstance(chunk_id, int), "chunk_id should be an int"
            assert self._num_virtual_pipeline_stages > 1, \
                "chunk_id is only valid when using virtual pipeline stage"
            assert chunk_id < len(self._model_chunks), \
                "The virtual pipeline only has {} chunks, " \
                "but received chunk_id {}.".format(len(self._model_chunks), chunk_id)
            # Get the target model chunk.
            model_chunk = self._model_chunks[chunk_id]
            # Update self.run_function to the target run functions.
            # Runs for 1f1b and interleave are similar: both just execute the
            # functions in self.run_function. The only difference is that, for
            # 1f1b, self.run_function has already been initialized during
            # _build_layer, while for interleave, self.run_function is
            # re-pointed to the target chunk's functions on every run.
            self.run_function = model_chunk.get_run_function()

        if self._recompute_interval == 0:
            input = self.forward_function(0, len(self.run_function))(input)
        else:
            num_layers = len(self.run_function)
            for start_idx in range(0, num_layers, self._recompute_interval):
                end_idx = min(start_idx + self._recompute_interval, num_layers)
                funcs = self.run_function[start_idx:end_idx]

                if not isinstance(input, tuple):
                    input = (input, )

                if self._need_recompute(funcs, input):
                    input = _hp_recompute(
                        self.forward_function(start_idx, end_idx), *input)
                else:
                    input = self.forward_function(start_idx, end_idx)(*input)

        return input
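
    # Illustrative sketch of the recompute path above: with
    # recompute_interval=2 and 6 entries in self.run_function, forward() runs
    # the slices [0, 2), [2, 4) and [4, 6); a slice goes through _hp_recompute
    # only when at least one input requires grad and the slice contains a
    # parameterized Layer, otherwise it is executed directly.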

    def _need_recompute(self, funcs, inputs):
        if not any(input_.stop_gradient == False
                   for input_ in inputs if isinstance(input_, paddle.Tensor)):
            return False

        params = [f.parameters() for f in funcs if isinstance(f, Layer)]
        return any(len(list(p)) > 0 for p in params)

    def save_state_dict(self, path):
        if self._topo.get_coord(self.global_rank).data != 0:
            return

        def _offset_dirname(ckpt_dir, local_layer_idx):
            idx = local_layer_idx + self._start_pos
            model_rank = self._topo.get_coord(self.global_rank).model
            rank_message = "-tensor_" + "{:0>2d}".format(model_rank)
            layer_save_path = os.path.join(ckpt_dir,
                                           'layer_{:0>2d}'.format(idx))
            layer_save_path = layer_save_path + rank_message + '-model_states.pdparams'
            return layer_save_path

        os.makedirs(path, exist_ok=True)
        for idx, layer in enumerate(self.run_function):
            model_save_path = _offset_dirname(path, idx)
            if not hasattr(layer, 'state_dict'):
                continue
            paddle.save(layer.state_dict(), model_save_path)

        logger.info("save model state successfully...")

    def set_state_dir(self, path):
        assert os.path.exists(
            path), "{} not found, please check the path".format(path)

        for idx, layer in enumerate(self.run_function):
            if not hasattr(layer, 'set_state_dict'):
                continue
            layer_idx = idx + self._start_pos
            layer_save_path = os.path.join(path,
                                           'layer_{0:0>2d}'.format(layer_idx))
            model_files = glob.glob(layer_save_path + "*model_states.pdparams")
            model_files.sort()
            mp_rank = self._topo.get_coord(self.global_rank).model
            mp_world_size = self._topo.get_dim('model')
            num_files = len(model_files)

            load_param_path = model_files[mp_rank * num_files // mp_world_size]
            model_state_dict = paddle.load(load_param_path)
            layer.set_state_dict(model_state_dict)

        self._synchronize_shared_weights()
        logger.info("load model state successfully...")