shard.py 5.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils

19 20
# An empty ``__all__`` means ``from <this module> import *`` binds no names.
__all__ = list()

21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128

class Shard(object):
    """Track which sharding worker owns each parameter and its optimizer vars.

    ``setup`` partitions the parameters of ``params_grads`` across
    ``worker_num`` workers by accumulated size. The query methods then answer
    ownership questions about parameter names and about variables whose names
    are a parameter name plus one of ``_OPT_VAR_SUFFIXES``.
    """

    # Name suffixes of per-parameter optimizer variables (presumably Adam
    # moments / beta-pow accumulators and Momentum velocity). A variable named
    # ``<param><suffix>`` is attributed to ``<param>``.
    _OPT_VAR_SUFFIXES = ("_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
                         "_beta2_pow_acc_0", "_velocity_0")

    def __init__(self):
        # Names of all parameters, across every worker.
        self.global_params = set()
        # Index of this worker; -1 until ``setup`` is called.
        self.worker_idx = -1
        # Total number of workers; -1 until ``setup`` is called.
        self.worker_num = -1
        # _param(str) -> device_id(int)
        self.global_param2device = {}

    def setup(self, params_grads, worker_idx, worker_num):
        """Partition the parameters in ``params_grads`` across workers.

        Args:
            params_grads: sequence of pairs whose first element is a parameter
                variable exposing ``.name`` (the second element is not read).
            worker_idx: index of the current worker.
            worker_num: total number of workers.
        """
        # param names of all devices
        self.global_params = {pair[0].name for pair in params_grads}
        self.worker_idx = worker_idx
        self.worker_num = worker_num
        # _param(str) -> device_id(int). Starts with the params of
        # params_grads; find_broadcast_params later adds their fp16 casts.
        self.global_param2device = self._split_params(params_grads,
                                                      worker_idx, worker_num)

    def has_param(self, var_name):
        """Return True if ``var_name`` is a registered param of this worker."""
        if var_name not in self.global_param2device:
            return False
        return self.global_param2device[var_name] == self.worker_idx

    def has_opt_var(self, var_name):
        """Return True if ``var_name`` (or the param it derives from) is ours."""
        return self._var_device_id(var_name) == self.worker_idx

    def has_var(self, var_name):
        """Return True if ``var_name`` is unowned (-1) or owned by this worker."""
        device_id = self._var_device_id(var_name)
        return device_id == -1 or device_id == self.worker_idx

    def _split_params(self, params_grads, worker_idx, worker_num):
        """Map each parameter name to a worker index, balancing by size.

        Sizes come from ``get_var_size`` (presumably provided by the wildcard
        import of ``sharding.utils``). ``worker_idx`` is accepted for interface
        compatibility but is not used.
        """
        param2mem = [(pair[0].name, get_var_size(pair[0]))
                     for pair in params_grads]
        return self._assign_devices(param2mem, worker_num)

    @staticmethod
    def _assign_devices(param2mem, worker_num):
        """Greedily assign ``(name, mem)`` pairs, in order, to workers.

        The walk moves on to the next worker once the memory already assigned
        exceeds that worker's cumulative share ``total * (k + 1) / worker_num``.
        The index is clamped to ``worker_num - 1``: floating-point rounding can
        make the last share slightly smaller than the total, which would
        otherwise push the index one past the last worker.

        Returns:
            dict mapping each name to a worker index in ``[0, worker_num)``.
        """
        # Accumulate with an explicit loop so the total and the running
        # ``mem_accu`` below come from the same sequence of float additions.
        total_param_mem = 0.0
        for _, mem in param2mem:
            total_param_mem += mem

        last_device = worker_num - 1
        param2device = {}
        device_idx = 0
        mem_accu = 0.0
        for param_name, mem in param2mem:
            share = total_param_mem * 1.0 * (device_idx + 1) / worker_num
            if mem_accu > share and device_idx < last_device:
                device_idx += 1
            param2device[param_name] = device_idx
            mem_accu += mem
        return param2device

    @classmethod
    def _base_name_candidates(cls, var_name):
        """Yield ``var_name`` with each optimizer suffix removed, in order.

        Every literal occurrence of the suffix is removed, not only a
        trailing one.
        """
        for suffix in cls._OPT_VAR_SUFFIXES:
            yield var_name.replace(suffix, '')

    def _var_device_id(self, var_name):
        """Return the owner of ``var_name`` or of its base param, else -1."""
        if var_name in self.global_param2device:
            return self.global_param2device[var_name]
        for base_name in self._base_name_candidates(var_name):
            if base_name in self.global_param2device:
                return self.global_param2device[base_name]
        return -1

    def find_broadcast_params(self, block):
        """Return the set of variable names that must be broadcast.

        Each param's read count is the number of its appearances among the
        inputs of non-optimizer ops in ``block``, minus one per fp16 cast op
        (as recognized by ``FP16Utils.is_fp16_cast_op``) that takes it as
        first input. A param is included when that count stays positive. The
        outputs of those cast ops are always included.

        Side effect: each cast output is registered in
        ``global_param2device`` with the same owner as its input param.
        """
        broadcast_vars = set()

        param_usage = {name: 0 for name in self.global_params}
        for op in block.ops:
            if is_optimizer_op(op):
                continue
            for input_name in op.desc.input_arg_names():
                if input_name in self.global_params:
                    param_usage[input_name] += 1

        for op in block.ops:
            if not FP16Utils.is_fp16_cast_op(block, op, self.global_params):
                continue
            input_name = op.input_arg_names[0]
            output_name = op.output_arg_names[0]
            broadcast_vars.add(output_name)
            # Discount this cast's read of the param; the cast output is
            # owned by the same worker as the param it was cast from.
            param_usage[input_name] -= 1
            self.global_param2device[output_name] = \
                self.global_param2device[input_name]

        broadcast_vars.update(
            param for param, usage in param_usage.items() if usage > 0)
        return broadcast_vars

    def device(self, var_name):
        """Return the worker owning ``var_name`` (see ``_var_device_id``)."""
        return self._var_device_id(var_name)

    def is_param(self, var_name):
        """Return True if ``var_name`` is one of the global parameters."""
        return var_name in self.global_params

    def is_opti_var(self, var_name):
        """Return True for a global param or a var derived from one by suffix."""
        if var_name in self.global_params:
            return True
        return any(base_name in self.global_params
                   for base_name in self._base_name_candidates(var_name))

    def filter_grads(self, grads):
        """Keep the grad names whose param (text before '@') this worker owns."""
        return [grad for grad in grads if self.has_param(grad.split("@")[0])]

137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154

class ProgramSegment(object):
    def __init__(self, block):
        self._block = block
        self._allreduce_vars = []
        # sub program start idx
        self._start_idx = -1
        # sub program end idx
        self._end_idx = -1
        # param name to broadcast name
        self._param2broadcast = {}
        self._broadcast_vars = []
        # cast op pairs, fp16 name (str) -> fp32 name (str)
        self._cast_ops = {}
        # fill constant vars
        self._fill_constant_vars = []
        # parameter mems
        self._param_mem = 0.0