#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Taken and modified from fairscale:
#     https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/misc/param_bucket.py
# Commit: 8acbec718f3c70a6b9785470bb9e05cd84fc3f8e

import os
import time
import numpy as np

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from ..meta_parallel.sharding.sharding_utils import Type, device_guard


class InternalStorage:
    """
    This is a basic class, which is responsible for consolidating the basic storage tensor.

    """

    # Supports consolidation of parameter tensors
    def __init__(self, size, dtype, device, convert_cpu=False):
        self._params = []
        self._param_ids = []
        self._fill = 0
        self._device = device
        self._dtype = dtype

        # The actual flat tensor
        size = [size] if isinstance(size, int) else size
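        # With convert_cpu=True the flat buffer is first allocated in host
        # memory (as a numpy-backed VarBase); otherwise it is created directly
        # on the current device.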
        if convert_cpu:
            value = np.zeros(
                size,
                dtype=np.float16) if Type.fp16.value == dtype else np.zeros(
                    size, dtype=np.float32)
            self.buffer = core.VarBase(value=value, place=core.CPUPlace())
        else:
            self.buffer = paddle.zeros(size, dtype=dtype)

53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
    def to(self, device, dtype=None, keep_alignment=True):
        """
        Move the underlying buffer
        """
        assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
        assert dtype is None or dtype in (
            Type.fp32.value,
            Type.fp16.value), "Conversion type is not supported now"

        dev_id = 0 if paddle.get_device() == "cpu" else int(paddle.get_device()
                                                            .split(":")[1])

        if self._device != device:
            tmp_buffer = self.buffer.cuda(
                dev_id) if device == "gpu" else self.buffer.cpu()
            # Free the parameter gradients and the old buffer's storage before
            # switching over to the relocated buffer.
            for param in self._params:
                param.clear_gradient(False)
                param._gradient_set_empty(False)
            self.buffer.value().get_tensor()._clear()
            self.buffer = tmp_buffer
            self._device = device

        if dtype is not None:
            self.buffer = self.buffer.cast(dtype=dtype)
            self._dtype = dtype


class ParamStorage(InternalStorage):
    """
    This is a basic class to simplify the handling of parameter InternalStorages.
    """

    def __init__(self, size, dtype, device):
        super().__init__(size, dtype, device, convert_cpu=True)
        self.param2align = None

    def to(self, device, dtype=None, keep_alignment=True):
        """
        Move the underlying buffer
        """

        super().to(device, dtype)

        if keep_alignment:
            # Rebind the parameter views so they point into the moved buffer.
            self._array_params()

    @fluid.dygraph.no_grad
    def add_rank_params(self, trainable_params, param2align, convert_gpu=True):
        """
        Add new parameters to the InternalStorage. Each param becomes a view of this InternalStorage buffer.
        """

        assert all([
            id(param) not in self._param_ids for param in trainable_params
        ]), "The same param cannot be checked in twice"
        assert self.buffer is not None

        self.param2align = param2align

        cpu_param_shape = list()
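        # First pass: flatten each param and copy its data into the buffer
        # (still on CPU at this point), recording the original shape so the
        # view can be restored in the second pass below.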
        for param in trainable_params:
            p_shape = self._add_param_as_view(param, param2align[param.name],
                                              convert_gpu)
            cpu_param_shape.append(p_shape)

        if convert_gpu:
            # Move the consolidated buffer from CPU to the current CUDA device.
            dev_id = int(paddle.get_device().split(":")[1])
            self.buffer = self.buffer.cuda(dev_id)

        # Second pass: rebind each param as a view of the consolidated buffer,
        # which may now live on the GPU.
        self._fill = 0

        for idx, param in enumerate(trainable_params):
            self._convert_buffer(param, cpu_param_shape[idx],
                                 param2align[param.name])
            self._params.append(param)
            self._param_ids.append(id(param))

    @fluid.dygraph.no_grad
    def _add_param_as_view(self, param, align, convert_gpu=True):

        assert (
            param.dtype == self.buffer.dtype
        ), "Different types for the InternalStorage and the param, cannot proceed: {} - {}".format(
            param.dtype, self.buffer.dtype)

        # Each param occupies prod(param.shape) elements in the buffer,
        # followed by `align` padding elements.
        var_end = self._fill + np.prod(param.shape)
        offset = var_end + align
        assert offset <= np.prod(self.buffer.shape)

        p_shape = param.shape

        # Flatten the param in place. stop_gradient is temporarily set so the
        # in-place op is permitted on a trainable (leaf) parameter.
        origin_state = param.stop_gradient
        param.stop_gradient = True
        param.flatten_()
        param.stop_gradient = origin_state

        # Copy the current param value
        dev_id = 0 if paddle.get_device() == "cpu" else int(paddle.get_device()
                                                            .split(":")[1])
        with device_guard(dev_id, "cpu"):
            tmp_var = core.VarBase(tensor=self.buffer._slice(self._fill,
                                                             var_end))
            if convert_gpu:
                param_cpu = param.cpu()
                param.value().get_tensor()._clear()
                tmp_var.set_value(param_cpu)
            else:
                tmp_var.set_value(param)

        self._fill = offset
        return p_shape

    @fluid.dygraph.no_grad
    def _convert_buffer(self, param, p_shape, align):

        var_end = self._fill + np.prod(p_shape)
        offset = var_end + align
        assert offset <= np.prod(self.buffer.shape)

        # Rebind the param to its slice of the buffer: share the slice's
        # storage, then restore the original shape.
        tmp_tensor = self.buffer._slice(self._fill, var_end)
        param.value().get_tensor()._share_data_with(tmp_tensor)
        param.value().get_tensor()._set_dims(p_shape)

        self._fill = offset

    @fluid.dygraph.no_grad
    def _array_params(self):
        """
        Given the parameters which have been registered previously, rebuild the whole InternalStorage.
        """
        assert len(self._params) > 0
        assert self.param2align is not None

        self._fill = 0
        for p in self._params:
            self._convert_buffer(p, p.shape, self.param2align[p.name])


class GradStorage(InternalStorage):
    """
    This is a basic class to simplify the handling of gradient InternalStorages.
    """

    def __init__(self,
                 size,
                 dtype,
                 device,
                 destination,
                 parm2align,
                 convert_cpu=False):
        if isinstance(size, np.int64):
            size = size.tolist()
        super().__init__(size, dtype, device, convert_cpu)

        self._max_size = size
        self._release = False

        self.params_checked_in = 0
        self.destination = destination
        self._parm2align = parm2align
        self.sent = False

    def reset_checked_in(self):
        """ Reset the counter of the parameter grads which have been checked in
        """
        self.params_checked_in = 0
        self.sent = False

    @property
    def all_checked_in(self):
        """ Judge all the expected gradient check-in happened """
        return len(self._params) == self.params_checked_in

    def can_add_grad_view(self, param, align):
        """ Whether there is enough room in the InternalStorage for this parameter's gradient and the param has not already been checked in.
        """
        return (self._fill + np.prod(param.shape) + align <= self._max_size and
                id(param) not in self._param_ids)

    def to(self, device, dtype=None, keep_alignment=True):
        """
        Move the underlying buffer
        """
        if self._release:
            self.rebuild()

        super().to(device, dtype)

        if keep_alignment:
            # Rebind the gradient views so they point into the moved buffer.
            self._array_grads()

    @fluid.dygraph.no_grad
    def add_grad(self, param, align):
        """
        Add a new parameter gradient to the InternalStorage. Param.grad becomes a view of this InternalStorage buffer.
        """

        assert id(
            param
        ) not in self._param_ids, "The same gradients cannot be checked in twice"

        self._add_grad_as_view(param, align)
        self._params.append(param)
        self._param_ids.append(id(param))

    @fluid.dygraph.no_grad
    def manumal_relase(self):
        """
        Release the buffer from InternalStorage. The InternalStorage will need to be rebuilt before use.
        """
        if not self._release:
            for p in self._params:
                if p.grad is not None:
                    p.clear_gradient(False)
                    p._gradient_set_empty(False)

            self.buffer = None
            self._fill = 0
            self.params_checked_in = 0
            self._release = True

    @fluid.dygraph.no_grad
    def rebuild(self):
        """
        Given the parameter gradients which have been registered previously, rebuild the whole InternalStorage.
        """

        if self._release:
            self.buffer = paddle.zeros([self._max_size], dtype=self._dtype)

            for p in self._params:
                self._add_grad_as_view(p, self._parm2align[p.name])

            self._release = False

    @fluid.dygraph.no_grad
    def _array_grads(self):
        """
        Given the parameter gradients which have been registered previously, rebuild the whole InternalStorage.
        """
        if len(self._params) > 0:
            self._fill = 0
            for p in self._params:
                self._add_grad_as_view(p, self._parm2align[p.name])

    @fluid.dygraph.no_grad
    def _add_grad_as_view(self, param, align):
        assert np.prod(
            self.buffer.shape
        ) > 0, "Cannot add a gradient to a released InternalStorage, please rebuild"
        assert param.dtype == self.buffer.dtype

        grad_end = self._fill + np.prod(param.shape)
        offset = grad_end + align
        assert offset <= np.prod(self.buffer.shape)

        # Copy the current grad value to InternalStorage
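        # A slice of the buffer is wrapped as a temporary VarBase, handed to
        # the param as its gradient storage, and the wrapper is then cleared.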
        dev_id = 0 if paddle.get_device() == "cpu" else int(paddle.get_device()
                                                            .split(":")[1])
        if self._device == "cpu":
            with device_guard(dev_id, self._device):
                tmp_var = core.VarBase(self.buffer._slice(self._fill, grad_end))
                param._copy_gradient_from(tmp_var)
                tmp_var.value().get_tensor()._clear()

        elif self._device == "gpu":
            tmp_var = core.VarBase(self.buffer._slice(self._fill, grad_end))
            param._copy_gradient_from(tmp_var)
            tmp_var.value().get_tensor()._clear()

        self._fill = offset