#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import six
from collections import defaultdict
try:
    from collections.abc import MutableSet  # Python 3.3+ (removed from collections in 3.10)
except ImportError:
    from collections import MutableSet  # Python 2 fallback
from .. import core
from ... import compat as cpt
from ..framework import Program, default_main_program, Parameter, Variable, core
from ..backward import _rename_arg_
from functools import reduce
from six.moves import range

dtype_to_size = {
    core.VarDesc.VarType.FP16: 2,
    core.VarDesc.VarType.FP32: 4,
    core.VarDesc.VarType.FP64: 8,
    core.VarDesc.VarType.INT16: 2,
    core.VarDesc.VarType.INT32: 4,
    core.VarDesc.VarType.INT64: 8,
    core.VarDesc.VarType.BOOL: 1,
    core.VarDesc.VarType.UINT8: 1,
}

SUB_BLOCK_OPS = [
    "while", "while_grad", "conditional_block", "conditional_block_grad"
]

SUB_BLOCK_PAIR = [("while", "while_grad"),
                  ("conditional_block", "conditional_block_grad")]

PRINT_LOG = False


class OrderedSet(MutableSet):
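    """An insertion-ordered set backed by a circular doubly linked list.

    ``self.map`` maps each key to its ``[key, prev, next]`` node and
    ``self.end`` is the sentinel node, so add/discard are O(1) while
    iteration order stays deterministic across runs.
    """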
    def __init__(self, iterable=None):
        self.end = end = []
        end += [None, end, end]  # sentinel node for doubly linked list
        self.map = {}  # key --> [key, prev, next]
        if iterable is not None:
            self |= iterable

    def __len__(self):
        return len(self.map)

    def __contains__(self, key):
        return key in self.map

    def add(self, key):
        if key not in self.map:
            end = self.end
            curr = end[1]
            curr[2] = end[1] = self.map[key] = [key, curr, end]

    def update(self, other):
        for e in other:
            self.add(e)

    def discard(self, key):
        if key in self.map:
            key, prev, next = self.map.pop(key)
            prev[2] = next
            next[1] = prev

    def remove(self, key):
        self.discard(key)

    def __iter__(self):
        end = self.end
        curr = end[2]
        while curr is not end:
            yield curr[0]
            curr = curr[2]

    def __reversed__(self):
        end = self.end
        curr = end[1]
        while curr is not end:
            yield curr[0]
            curr = curr[1]

    def pop(self, last=True):
        if not self:
            raise KeyError('set is empty')
        key = self.end[1][0] if last else self.end[2][0]
        self.discard(key)
        return key

    def __repr__(self):
        if not self:
            return '%s()' % (self.__class__.__name__, )
        return '%s(%r)' % (self.__class__.__name__, list(self))

    def __eq__(self, other):
        if isinstance(other, OrderedSet):
            return len(self) == len(other) and list(self) == list(other)
        return set(self) == set(other)


class ControlFlowGraph(object):
    def __init__(self, program, ops, forward_num, skip_opt):
        self._program = program
        self._ops = ops
        self._forward_num = forward_num
        self._successors = defaultdict(OrderedSet)
        self._presuccessors = defaultdict(OrderedSet)
        self._uses = defaultdict(OrderedSet)
        self._defs = defaultdict(OrderedSet)
        self._live_in = defaultdict(OrderedSet)
        self._live_out = defaultdict(OrderedSet)
        self._skip_opt = skip_opt
        self.pool = []

    def _add_connections(self, connections):
        """Populate _successors and _presuccessors for each neighboring node pair."""
        for node1, node2 in connections:
            self._add(node1, node2)

    def _add(self, node1, node2):
        self._successors[node1].add(node2)
        self._presuccessors[node2].add(node1)

    # TODO(panyx0718): We need to have a unified way of building intermediate
    # representation.
    def _build_graph(self):
        """Build a graph based on op sequence.
        """
        self.op_size = len(self._ops)
        op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)]
        self._add_connections(op_node_connections)
        for i in range(self.op_size):
            self._uses[i].update(self._ops[i].input_arg_names())
            self._defs[i].update(self._ops[i].output_arg_names())
            self._live_in[i] = self._uses[i]

    def _update_graph(self, old_name, new_name, begin_idx=0):
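        """Rename ``old_name`` to ``new_name`` in the use/def/live sets of all
        ops from ``begin_idx`` onwards, after a variable has been reused.
        """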
        for i in range(begin_idx, self.op_size):
            if old_name in self._uses[i]:
                self._uses[i].remove(old_name)
                self._uses[i].add(new_name)
            if old_name in self._defs[i]:
                self._defs[i].remove(old_name)
                self._defs[i].add(new_name)
            if old_name in self._live_in[i]:
                self._live_in[i].remove(old_name)
                self._live_in[i].add(new_name)
            if old_name in self._live_out[i]:
                self._live_out[i].remove(old_name)
                self._live_out[i].add(new_name)

    def _dataflow_analyze(self):
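        """Backward liveness analysis over the op sequence.

        Iterates a worklist until the standard dataflow equations reach a
        fixed point:
            live_out[i] = union of live_in[s] for s in successors(i)
            live_in[i]  = use[i] | (live_out[i] - def[i])
        """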
        self._build_graph()
        live_in = defaultdict(set)
        worklist = list(range(len(self._ops) - 1, -1, -1))
        while worklist:
            i = worklist.pop(0)
            live_in[i] = set(self._live_in[i])
            for s in self._successors[i]:
                self._live_out[i] |= self._live_in[s]
            self._live_in[i] = self._uses[i] | (
                self._live_out[i] - self._defs[i])
            if live_in[i] != set(self._live_in[i]):
                for d in self._presuccessors[i]:
                    worklist.append(d)

    def _fill_pool(self, i, is_forward):
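        """Add variables that are live into op ``i`` but not out of it (and
        pass the validity checks) to the reuse pool as ``(name, shape)`` pairs.
        """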
        block_desc = self._ops[i].block()
        in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
        # NOTE: the in_diff set must have a deterministic order; otherwise
        # different runs may pick different cache vars.
        # FIXME(typhoonzero): maybe using a "sorted set" would be better than this.
        can_optimize = [
            x for x in in_diff
            if self._check_var_validity(block_desc, x, is_forward)
        ]
        if can_optimize:
            for var_name in can_optimize:
                cache = (var_name, self._find_var(block_desc, var_name,
                                                  is_forward).shape())
                if cache not in self.pool:
                    self.pool.append(cache)

    def _get_diff(self, a, b):
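        """Return ``(a - b, b - a)``: the elements unique to each set."""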
        u = a & b
        return a - u, b - u

    def _has_var(self, block_desc, var_name, is_forward):
        if is_forward:
            return block_desc.has_var(cpt.to_bytes(var_name))
        else:
            return block_desc.has_var_recursive(cpt.to_bytes(var_name))

    def _find_var(self, block_desc, var_name, is_forward):
        if is_forward:
            return block_desc.find_var(cpt.to_bytes(var_name))
        else:
            return block_desc.find_var_recursive(cpt.to_bytes(var_name))

    def _check_var_validity(self, block_desc, x, is_forward):
        if str(x) == "@EMPTY@":
            return False
        if not self._has_var(block_desc, x, is_forward):
            return False
        var_desc = self._find_var(block_desc, x, is_forward)
        if var_desc.persistable():
            return False
        if var_desc.type() != core.VarDesc.VarType.LOD_TENSOR:
            return False
        if x in self._skip_opt:
            return False
        if not var_desc.shape():
            return False
        return True

    # TODO(panyx0718): This needs to be less hacky. It seems memory optimization
    # doesn't consider vars copied between cpu and gpu.
    def _update_skip_opt_set(self):
        for i in range(self.op_size):
            op = self._ops[i]
            if op.type() == "fill_constant" and op.attr("force_cpu") == True:
                self._skip_opt.update(op.output_arg_names())

    def release_memory(self, skip_opt_set=None):
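        """Insert ``delete_var`` ops right after the op at which a variable
        stops being live, so its memory can be released early at runtime.
        """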
        self._dataflow_analyze()
        self._update_skip_opt_set()
        if skip_opt_set:
            self._skip_opt.update(skip_opt_set)
        fwd_id = 0
        bwd_id = 0
        for i in range(self.op_size):
            op = self._ops[i]
            if op.type() in SUB_BLOCK_OPS:
                continue
            block_desc = op.block()
            is_forward = i < self._forward_num
            in_diff, out_diff = self._get_diff(self._live_in[i],
                                               self._live_out[i])
            can_optimize = [
                x for x in in_diff
                if self._check_var_validity(block_desc, x, is_forward)
            ]
            if can_optimize:
                index = i + fwd_id + 1 if is_forward else i - self._forward_num + bwd_id + 1
                delete_op = block_desc._insert_op(index)
                delete_op.set_type("delete_var")
                delete_op.set_input("X", can_optimize)
                if is_forward:
                    fwd_id += 1
                else:
                    bwd_id += 1

    def memory_optimize(self, skip_opt_set=None, level=0):
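        """Reuse the memory of dead variables by renaming new outputs to
        cached variables from the pool.

        level 0 reuses a cached var only when the shapes match exactly;
        level 1 also reuses it when the new var is no larger than the cached
        one and both agree on whether the leading (batch) dim is dynamic (-1).
        """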
        def compare_shape(x_shape, cache_shape, opt_level):
            if opt_level == 0:
                return x_shape == cache_shape
            elif opt_level == 1:
                if (x_shape[0] == -1) ^ (cache_shape[0] == -1):
                    return False
                x_size = abs(reduce(lambda x, y: x * y, x_shape))
                cache_size = abs(reduce(lambda x, y: x * y, cache_shape))
                if x_size <= cache_size:
                    return True
            else:
                raise ValueError("only support opt_level 0 or 1.")
            return False

        self._dataflow_analyze()
        self._update_skip_opt_set()
        # update skip set to meet users' demand
        if skip_opt_set:
            self._skip_opt.update(skip_opt_set)
        for i in range(self.op_size):
            op = self._ops[i]
            if op.type() in SUB_BLOCK_OPS:
                continue
            block_desc = op.block()
            is_forward = i < self._forward_num
            if self.pool:
                # NOTE: the candidate set must have a deterministic order;
                # otherwise different runs may pick different cache vars.
                defs_can_optimize = [
                    x for x in self._defs[i]
                    if self._check_var_validity(block_desc, x, is_forward)
                ]
                out_pair = [
                    (x, self._find_var(block_desc, x, is_forward).shape())
                    for x in defs_can_optimize
                ]
                for x, x_shape in out_pair:
                    # If x is both in uses and defs, it cannot be optimized!
                    if x in self._uses[i]:
                        continue
                    for index, cache_pair in enumerate(self.pool):
                        cache_var = cache_pair[0]
                        cache_shape = cache_pair[1]
                        if not self._has_var(block_desc, cache_var, is_forward):
                            if PRINT_LOG:
                                print("cache %s does not exist!" %
                                      (cpt.to_text(cache_var)))
                            continue
                        if x == cache_var:
                            if PRINT_LOG:
                                print("x : ", cpt.to_text(x), " cache : ",
                                      cpt.to_text(cache_var), " is the same var!")
                            break

                        x_dtype = self._find_var(block_desc, x,
                                                 is_forward).dtype()
                        cache_dtype = self._find_var(block_desc, cache_var,
                                                     is_forward).dtype()

                        if not compare_shape(x_shape, cache_shape, level):
                            continue
                        # TODO(qijun): dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
                        if x_dtype != cache_dtype:
                            continue

                        if PRINT_LOG:
                            print(("Hit Cache !!!! cache pool index "
                                   "is %d, var name is %s, "
                                   "cached var name is %s, "
                                   "var shape is %s ") % (index, x, cache_var,
                                                          str(cache_shape)))
                        self.pool.pop(index)
                        # Rename the var to the cache var already with
                        # memory allocated in order to reuse the memory.
                        _rename_arg_(self._ops, x, cache_var, begin_idx=i)
                        self._program.block(block_desc.id).var(cpt.to_text(
                            x)).desc = self._find_var(block_desc, cache_var,
                                                      is_forward)
                        self._program.block(block_desc.id).vars[cpt.to_text(x)] = \
                            Variable(self._program.block(block_desc.id), name=cpt.to_text(x))
                        self._update_graph(x, cache_var, begin_idx=i)
                        break
            self._fill_pool(i, is_forward)


def _process_sub_block_pair(pdesc, sub_block_pair):
    """Creates a list of tuples, each of which tracks the info of a subblock.

      Note: this function doesn't handle nested subblocks yet.
      TODO(panyx0718): assert in case nested subblocks happen.

    :param pdesc: ProgramDesc.
    :param sub_block_pair: A list of op pairs. Each op pair consists of a
        forward op and its backward op. The ops in the list are special in
        that they contain a subblock of ops.
    :return: A list of tuples, each tuple is (all ops in a subblock pair
        including forward and backward, number of forward ops,
        all output arg names of the ops in the subblock pairs).
    """
    ops_list = []
    block_desc = pdesc.block(0)
    op_size = block_desc.op_size()
    for fwd_op, bwd_op in sub_block_pair:
        sub_block_ids = []
        grad_sub_block_ids = []
        sub_block_id_pair = []
        sub_op_dict = {}
        for i in range(op_size):
            op = block_desc.op(i)
            if op.type() == fwd_op:
                sub_block_ids.append(op.attr("sub_block").id)
                sub_op_dict[op.attr("sub_block").id] = op
            elif op.type() == bwd_op:
                grad_sub_block_ids.append(op.attr("sub_block").id)
                sub_op_dict[op.attr("sub_block").id] = op

        # Find fwd_op/bwd_op block pair
        for grad_id in grad_sub_block_ids:
            fwd_id = pdesc.block(grad_id).get_forward_block_idx()
            if fwd_id in sub_block_ids:
                sub_block_id_pair.append((fwd_id, grad_id))
                sub_block_ids.remove(fwd_id)

        # Get fwd_op/bwd_op block ops
        for fwd_id, grad_id in sub_block_id_pair:
            sub_block_ops = []
            sub_block = pdesc.block(fwd_id)
            block_op_size = sub_block.op_size()
            for i in range(block_op_size):
                sub_block_ops.append(sub_block.op(i))

            grad_sub_block = pdesc.block(grad_id)
            grad_sub_block_op_size = grad_sub_block.op_size()
            for i in range(grad_sub_block_op_size):
                sub_block_ops.append(grad_sub_block.op(i))

            sub_op_output = set()
            sub_op_output.update(sub_op_dict[fwd_id].output_arg_names())
            sub_op_output.update(sub_op_dict[grad_id].output_arg_names())
            sub_op_output.update(sub_op_dict[fwd_id].input_arg_names())
            sub_op_output.update(sub_op_dict[grad_id].input_arg_names())
            ops_list.append((sub_block_ops, block_op_size, sub_op_output))

        # Process the remaining fwd_op blocks that have no matching bwd_op block
        for fwd_id in sub_block_ids:
            sub_block_ops = []
            sub_block = pdesc.block(fwd_id)
            sub_block_op_size = sub_block.op_size()
            for i in range(sub_block_op_size):
                sub_block_ops.append(sub_block.op(i))
            sub_op_output = set()
            sub_op_output.update(sub_op_dict[fwd_id].output_arg_names())
            sub_op_output.update(sub_op_dict[fwd_id].input_arg_names())
            ops_list.append((sub_block_ops, sub_block_op_size, sub_op_output))
    return ops_list


def _get_cfgs(input_program):
    """Process each block and create ControlFlowGraph for each of them.

    :param input_program: Program object.
    :return: A list of ControlFlowGraph, each corresponds to a block.
    """
    ops_list = []
    pdesc = input_program._get_desc()
    block_desc = pdesc.block(0)
    op_size = block_desc.op_size()

    # Only process one level of nested subblock.
    ops_list.extend(_process_sub_block_pair(pdesc, SUB_BLOCK_PAIR))

    skip_opt_set = set()
    for _, _, skip_opt in ops_list:
        skip_opt_set.update(skip_opt)

    # Get global block ops
    ops_list.insert(
        0, ([block_desc.op(i) for i in range(op_size)], op_size, skip_opt_set))
    cfgs = [
        ControlFlowGraph(input_program, ops, forward_num, skip_opt)
        for ops, forward_num, skip_opt in ops_list
    ]
    return cfgs


def _is_opt_role_op(op):
    # True if the op carries the Optimize op-role attribute, i.e. it belongs
    # to the optimizer update rather than the forward/backward pass.
    op_maker = core.op_proto_and_checker_maker
    optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
    if op_maker.kOpRoleAttrName() in op.attr_names and \
            int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role):
        return True
    return False


def memory_optimize(input_program,
                    skip_opt_set=None,
                    print_log=False,
                    level=0,
                    skip_grads=False):
    """Optimize memory by reusing var memory.

      Note: it doesn't not support subblock nested in subblock.

D
"rerun"  
dzhwinter 已提交
465 466 467 468 469 470 471
    Args:
        input_program(str): Input Program
        skip_opt_set(set): vars wil be skipped in memory optimze
        print_log(bool): whether to print debug log.
        level(int): If level=0, reuse if the shape is completely equal, o
    Returns:
        None
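
    Examples:
        .. code-block:: python

            # A minimal sketch (assumption: the fluid 1.x layers/optimizer API
            # shown here); build a network, append backward and optimize ops,
            # then run memory optimization on the default main program.
            import paddle.fluid as fluid

            x = fluid.layers.data(name='x', shape=[13], dtype='float32')
            y = fluid.layers.data(name='y', shape=[1], dtype='float32')
            y_pred = fluid.layers.fc(input=x, size=1)
            loss = fluid.layers.mean(fluid.layers.square_error_cost(y_pred, y))
            fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

            fluid.memory_optimize(fluid.default_main_program(), print_log=True)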
    """

    def to_name_str(var):
        if isinstance(var, Variable):
            return var.desc.name()
        elif isinstance(var, six.string_types):
            return str(var)
        else:
            raise TypeError(str(var) + " should be Variable or str")

    if level != 0 and level != 1:
        raise ValueError("only support opt_level 0 or 1.")
    if skip_opt_set is not None and not isinstance(skip_opt_set, set):
        raise ValueError("only support skip_opt_set as set.")
    global PRINT_LOG
    PRINT_LOG = print_log
    if skip_grads:
        grad_set = set()
        OP_ROLE_VAR = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
        for op in input_program.global_block().ops:
            if _is_opt_role_op(op):
                if op.attr(OP_ROLE_VAR):
                    grad_name = op.attr(OP_ROLE_VAR)[1]
                    grad_set.add(grad_name)
        if not skip_opt_set:
            skip_opt_set = grad_set
        else:
            skip_opt_set.update(grad_set)
    if skip_opt_set is not None:
        skip_opt_set = set(map(to_name_str, skip_opt_set))
    cfgs = _get_cfgs(input_program)
    for cfg in cfgs:
        cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)


def release_memory(input_program, skip_opt_set=None):
    """
    Modify the input program and insert :code:`delete_op` to early drop not used
    variables. The modification will be performed inplace.

    Notes: This is an experimental API and could be removed in next few
    releases. Users should not use this API.

    Args:
        input_program(Program): The program will be inserted :code:`delete_op`.
D
"rerun"  
dzhwinter 已提交
519 520 521
        skip_opt_set(set): vars wil be skipped in memory optimze
    Returns:
        None
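
    Examples:
        .. code-block:: python

            # A minimal sketch (assumption: the fluid 1.x API shown here); the
            # program is built as in the memory_optimize example above.
            import paddle.fluid as fluid

            x = fluid.layers.data(name='x', shape=[13], dtype='float32')
            y_pred = fluid.layers.fc(input=x, size=1)
            loss = fluid.layers.mean(y_pred)
            fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)

            fluid.release_memory(fluid.default_main_program())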
    """
    cfgs = _get_cfgs(input_program)
    for cfg in cfgs:
        cfg.release_memory(skip_opt_set=skip_opt_set)