# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.sharding.utils import get_var_size
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils

__all__ = []


class Shard(object):

    def __init__(self):
        self.global_params = set()
        self.worker_idx = -1
        self.worker_num = -1
        self.global_param2device = dict()
        self.device2global_params = dict()

    def setup(self, params_grads, worker_idx, worker_num):
        # param names of all devices
        self.global_params = set([x[0].name for x in params_grads])
        self.worker_idx = worker_idx
        self.worker_num = worker_num
        # global_param2device: param name (str) -> device_id (int);
        # it contains both fp32 and fp16 params, while
        # device2global_params only contains fp32 params
        self.global_param2device, self.device2global_params \
            = self._split_params(params_grads, worker_idx, worker_num)

    def has_param(self, var_name):
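        # the variable is a parameter and it is sharded to this worker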
        return var_name in self.global_param2device and \
            self._var_device_id(var_name) == self.worker_idx

    def has_opt_var(self, var_name):
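        # the variable (a param or one of its optimizer states) is placed on this worker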
        return self._var_device_id(var_name) == self.worker_idx

    def has_var(self, var_name):
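        # the variable is either not sharded at all (device id -1) or belongs to this worker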
        return self._var_device_id(var_name) == -1 or \
            self._var_device_id(var_name) == self.worker_idx

    def _split_params(self, params_grads, worker_idx, worker_num):
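        # greedy, memory-balanced split: walk the parameters in order and move to
        # the next device once the accumulated size exceeds that device's share of
        # the total parameter memory; returns (param -> device, device -> params)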
        param2device = {}
        total_param_mem = 0.0
        param2mem = []
        for param in [x[0] for x in params_grads]:
            mem = get_var_size(param)
            total_param_mem += mem
            param2mem.append((param.name, mem))
        device2params = {x: [] for x in range(worker_num)}
        device_idx = 0
        mem_accu = 0.0
        for param_name, mem in param2mem:
            if mem_accu > total_param_mem * 1.0 * (device_idx + 1) / worker_num:
                device_idx += 1
            device2params[device_idx].append(param_name)
            param2device[param_name] = device_idx
            mem_accu += mem
        return param2device, device2params

    def _var_device_id(self, var_name):
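        # resolve a variable name to the device that owns it: parameters are looked
        # up directly, optimizer states (moment/beta_pow_acc/velocity) are mapped
        # back to their base parameter, and -1 means the variable is not sharded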
        if var_name in self.global_param2device:
            return self.global_param2device[var_name]
        for suffix in [
                "_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                "_beta2_pow_acc_0", "_velocity_0"
        ]:
            base_name = re.sub(suffix, '', var_name)
            if base_name in self.global_param2device:
                return self.global_param2device[base_name]
        return -1

    def find_broadcast_params(self, block):
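        # collect the variables this worker must broadcast: fp16 copies produced by
        # cast ops (which inherit the device of their fp32 source) plus fp32 params
        # that are still consumed directly by non-optimizer ops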
        broadcast_vars = set()
        fp16_params = set()
        fp16_to_fp32 = {}

        param_usage = {x: 0 for x in self.global_params}
        for op in block.ops:
            if is_optimizer_op(op):
                continue
            for input_name in op.desc.input_arg_names():
                if input_name in self.global_params:
                    param_usage[input_name] += 1

        for op in block.ops:
            if not FP16Utils.is_fp16_cast_op(block, op, self.global_params):
                continue
            input_name = op.input_arg_names[0]
            output_name = op.output_arg_names[0]
            broadcast_vars.add(output_name)
            fp16_params.add(output_name)
            fp16_to_fp32[output_name] = input_name
            param_usage[input_name] -= 1
            self.global_param2device[output_name] = self.global_param2device[
                input_name]

        for param, usage in param_usage.items():
            if usage > 0:
                broadcast_vars.add(param)
        return broadcast_vars

    def device(self, var_name):
        return self._var_device_id(var_name)

    def is_param(self, var_name):
        return var_name in self.global_params

    def is_opti_var(self, var_name):
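        # parameters and their optimizer states (moment, beta pow acc, velocity)
        # all count as optimizer variables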
        if var_name in self.global_params:
            return True
        for suffix in [
                "_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                "_beta2_pow_acc_0", "_velocity_0"
        ]:
            base_name = re.sub(suffix, '', var_name)
            if base_name in self.global_params:
                return True
        return False

    def filter_grads(self, grads):
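        # keep only the gradients whose parameter is sharded to this worker;
        # gradient names follow the "<param>@GRAD" convention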
        grads_in_shard = []
        for grad in grads:
            param = grad.split("@")[0]
            if self.has_param(param):
                grads_in_shard.append(grad)
        return grads_in_shard


class ProgramSegment(object):

    def __init__(self, block):
        self._block = block
        self._allreduce_vars = []
        # sub program start idx
        self._start_idx = -1
        # sub program end idx
        self._end_idx = -1
        # param name to broadcast name
        self._param2broadcast = {}
        self._broadcast_vars = []
        # cast op pairs, fp16 name (str) -> fp32 name (str)
        self._cast_ops = {}
        # fill constant vars
        self._fill_constant_vars = []
        # parameter mems
        self._param_mem = 0.0