# shard.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils


class Shard(object):
    """Track which sharding worker owns each parameter.

    After ``setup`` has run, ``global_param2device`` maps every parameter
    name to the index of the worker that holds it.  The query helpers
    (``has_param``, ``has_opt_var``, ``has_var``, ``device`` ...) answer
    ownership questions for parameters and for variables whose names are
    derived from a parameter name by one of ``_OPT_VAR_SUFFIXES``.
    """

    # Name fragments that are removed from a variable name when looking for
    # the parameter it is derived from (see ``_candidate_param_names``).
    _OPT_VAR_SUFFIXES = (
        "_moment1_0",
        "_moment2_0",
        "_beta1_pow_acc_0",
        "_beta2_pow_acc_0",
        "_velocity_0",
    )

    def __init__(self):
        # names of the parameters of all workers
        self.global_params = set()
        # -1 until setup() provides the real values
        self.worker_idx = -1
        self.worker_num = -1
        # param name (str) -> owning worker index (int)
        self.global_param2device = {}

    def setup(self, params_grads, worker_idx, worker_num):
        """Record the worker layout and partition the parameters.

        Args:
            params_grads: sequence of pairs whose first element is a
                parameter object exposing ``.name`` (it is also handed to
                ``get_var_size``).  It is traversed twice, so it must be
                re-iterable.
            worker_idx: index of the current worker.
            worker_num: total number of workers.
        """
        # param names of all devices
        self.global_params = {pair[0].name for pair in params_grads}
        self.worker_idx = worker_idx
        self.worker_num = worker_num
        # param name (str) -> device id (int); find_broadcast_params() later
        # adds the outputs of fp16 cast ops to this mapping as well
        self.global_param2device = self._split_params(params_grads, worker_idx,
                                                      worker_num)

    def has_param(self, var_name):
        """Return True if ``var_name`` is a mapped param owned by this worker."""
        owners = self.global_param2device
        return var_name in owners and owners[var_name] == self.worker_idx

    def has_opt_var(self, var_name):
        """Return True if ``var_name`` resolves to this worker's index.

        Names that resolve to no parameter yield -1 from ``_var_device_id``
        and are compared against ``worker_idx`` like any other value.
        """
        return self._var_device_id(var_name) == self.worker_idx

    def has_var(self, var_name):
        """Return True if ``var_name`` is unowned (-1) or owned by this worker."""
        device_id = self._var_device_id(var_name)
        return device_id == -1 or device_id == self.worker_idx

    def _split_params(self, params_grads, worker_idx, worker_num):
        """Assign parameters to workers in order, balancing by size.

        Parameters are walked in the given order.  The current worker index
        advances by one whenever the size accumulated so far exceeds that
        worker's proportional share of the total size.

        Args:
            params_grads: sequence of pairs; only the first element (the
                parameter) of each pair is used.
            worker_idx: unused, kept for interface compatibility.
            worker_num: number of workers to spread the parameters over.

        Returns:
            dict mapping param name (str) -> worker index (int).
        """
        param2mem = []
        total_param_mem = 0.0
        for pair in params_grads:
            param = pair[0]
            mem = get_var_size(param)
            total_param_mem += mem
            param2mem.append((param.name, mem))

        param2device = {}
        device_idx = 0
        mem_accu = 0.0
        for param_name, mem in param2mem:
            threshold = total_param_mem * (device_idx + 1) / worker_num
            # The bound check keeps the index valid even if rounding in the
            # threshold would otherwise step past the last worker.
            if mem_accu > threshold and device_idx < worker_num - 1:
                device_idx += 1
            param2device[param_name] = device_idx
            mem_accu += mem
        return param2device

    def _candidate_param_names(self, var_name):
        """Yield ``var_name`` with each known suffix fragment removed.

        ``str.replace`` drops every occurrence of the fragment, which is
        what the previous ``re.sub`` call with the same literal pattern did.
        """
        for suffix in self._OPT_VAR_SUFFIXES:
            yield var_name.replace(suffix, '')

    def _var_device_id(self, var_name):
        """Return the worker index owning ``var_name``, or -1 if unknown.

        The name is looked up directly first; failing that, each name
        produced by ``_candidate_param_names`` is tried in order.
        """
        owners = self.global_param2device
        if var_name in owners:
            return owners[var_name]
        for base_name in self._candidate_param_names(var_name):
            if base_name in owners:
                return owners[base_name]
        return -1

    def find_broadcast_params(self, block):
        """Collect the names of variables that have to be broadcast.

        The result contains
          * the first output of every op accepted by
            ``FP16Utils.is_fp16_cast_op``, and
          * every parameter whose number of uses as an input of
            non-optimizer ops exceeds the number of such cast ops reading it.

        Side effect: each cast output is registered in
        ``global_param2device`` with the same worker as the cast input.

        Args:
            block: program block whose ``ops`` are scanned.

        Returns:
            set of variable names (str).
        """
        broadcast_vars = set()

        # how often each parameter is read by a non-optimizer op
        param_usage = dict.fromkeys(self.global_params, 0)
        for op in block.ops:
            if is_optimizer_op(op):
                continue
            for input_name in op.desc.input_arg_names():
                if input_name in self.global_params:
                    param_usage[input_name] += 1

        for op in block.ops:
            if not FP16Utils.is_fp16_cast_op(block, op, self.global_params):
                continue
            input_name = op.input_arg_names[0]
            output_name = op.output_arg_names[0]
            broadcast_vars.add(output_name)
            # the cast itself does not count as a use of the source param
            param_usage[input_name] -= 1
            self.global_param2device[output_name] = self.global_param2device[
                input_name]

        broadcast_vars.update(
            param for param, usage in param_usage.items() if usage > 0)
        return broadcast_vars

    def device(self, var_name):
        """Return the worker index owning ``var_name`` (-1 if unknown)."""
        return self._var_device_id(var_name)

    def is_param(self, var_name):
        """Return True if ``var_name`` is one of the global parameter names."""
        return var_name in self.global_params

    def is_opti_var(self, var_name):
        """Return True if ``var_name`` is a parameter or derived from one.

        "Derived" means that removing one of ``_OPT_VAR_SUFFIXES`` from the
        name leaves a global parameter name.
        """
        if var_name in self.global_params:
            return True
        return any(base_name in self.global_params
                   for base_name in self._candidate_param_names(var_name))

    def filter_grads(self, grads):
        """Keep the gradient names whose parameter belongs to this worker.

        The parameter name is the part of the gradient name before the
        first ``"@"``.  Input order is preserved.
        """
        return [grad for grad in grads if self.has_param(grad.split("@")[0])]


class ProgramSegment(object):
    """Plain record describing one segment of a program block.

    Every field except ``_block`` starts out empty or unset; this class
    only declares the fields and does not fill them in.
    """

    def __init__(self, block):
        # the block handed in by the caller
        self._block = block
        self._allreduce_vars = list()
        # first / last index of the sub program; -1 until assigned
        self._start_idx, self._end_idx = -1, -1
        # param name -> broadcast name
        self._param2broadcast = dict()
        self._broadcast_vars = list()
        # cast op pairs: fp16 name (str) -> fp32 name (str)
        self._cast_ops = dict()
        # vars produced by fill constant
        self._fill_constant_vars = list()
        # memory taken by the parameters
        self._param_mem = 0.0