#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import six
import sys
from collections import defaultdict, deque
try:
    # The ABC aliases in `collections` were removed in Python 3.10.
    from collections.abc import MutableSet
except ImportError:  # Python 2
    from collections import MutableSet
from .. import core
from ... import compat as cpt
from ..framework import Program, default_main_program, Parameter, Variable, core
from ..backward import _rename_arg_
from functools import reduce
from six.moves import range


# Element byte width for each tensor dtype listed below.
# NOTE: the pool matching in this module only requires exact dtype equality;
# it does not consult these sizes.
dtype_to_size = dict([
    (core.VarDesc.VarType.FP16, 2),
    (core.VarDesc.VarType.FP32, 4),
    (core.VarDesc.VarType.FP64, 8),
    (core.VarDesc.VarType.INT16, 2),
    (core.VarDesc.VarType.INT32, 4),
    (core.VarDesc.VarType.INT64, 8),
    (core.VarDesc.VarType.BOOL, 1),
    (core.VarDesc.VarType.UINT8, 1),
])

# (forward op type, backward op type) pairs whose ops own a sub-block.
# _get_cfgs builds one extra ControlFlowGraph per matched sub-block.
SUB_BLOCK_PAIR = [("while", "while_grad"),
                  ("conditional_block", "conditional_block_grad")]

# Every op type that owns a sub-block.  Such ops are skipped when optimizing
# the block that contains them.
SUB_BLOCK_OPS = list(op_type for pair in SUB_BLOCK_PAIR for op_type in pair)

# Set by memory_optimize(print_log=...) to enable debug printing.
PRINT_LOG = False
# Name of one variable whose pool matching is traced to stdout (debug aid);
# the empty string disables the trace.
FLAGS_memory_optimize = ""


class OrderedSet(MutableSet):
    """A mutable set that iterates in insertion order.

    Elements are nodes of a circular doubly linked list threaded through a
    sentinel node. A dict maps each element to its node, so membership,
    ``add`` and ``discard`` are O(1), while iteration follows insertion
    order. ControlFlowGraph keeps its per-op variable sets in OrderedSets,
    so names are visited in a reproducible order.

    Note: ``remove`` deliberately behaves like ``discard``. It does not raise
    ``KeyError`` for a missing element.
    """

    def __init__(self, iterable=None):
        # Each node is a list [key, prev, next]. The sentinel's key is None,
        # and an empty set links the sentinel to itself in both directions.
        self.end = end = []
        end += [None, end, end]
        self.map = {}  # key --> its [key, prev, next] node
        if iterable is not None:
            self |= iterable

    def __len__(self):
        return len(self.map)

    def __contains__(self, key):
        return key in self.map

    def add(self, key):
        """Append ``key`` at the end unless it is already present."""
        if key not in self.map:
            end = self.end
            last = end[1]
            last[2] = end[1] = self.map[key] = [key, last, end]

    def update(self, other):
        """Add every element of ``other`` in its iteration order."""
        for e in other:
            self.add(e)

    def discard(self, key):
        """Unlink ``key`` if present; do nothing otherwise."""
        if key in self.map:
            _, prev_node, next_node = self.map.pop(key)
            prev_node[2] = next_node
            next_node[1] = prev_node

    def remove(self, key):
        """Same as ``discard``: a missing ``key`` is silently ignored."""
        self.discard(key)

    def __iter__(self):
        end = self.end
        curr = end[2]
        while curr is not end:
            yield curr[0]
            curr = curr[2]

    def __reversed__(self):
        end = self.end
        curr = end[1]
        while curr is not end:
            yield curr[0]
            curr = curr[1]

    def pop(self, last=True):
        """Remove and return the last element, or the first if ``last`` is False.

        Raises:
            KeyError: if the set is empty.
        """
        if not self:
            raise KeyError('set is empty')
        key = self.end[1][0] if last else self.end[2][0]
        self.discard(key)
        return key

    def __repr__(self):
        if not self:
            return '%s()' % (self.__class__.__name__, )
        return '%s(%r)' % (self.__class__.__name__, list(self))

    def __eq__(self, other):
        """Compare with ``other``.

        The comparison is order-sensitive against another OrderedSet and
        order-insensitive against any other iterable. Objects that cannot
        be turned into a set compare unequal instead of raising.
        """
        if isinstance(other, OrderedSet):
            return len(self) == len(other) and list(self) == list(other)
        try:
            other_set = set(other)
        except TypeError:
            # Not iterable, or holds unhashable items: defer to Python's
            # fallback (reflected __eq__, then identity).
            return NotImplemented
        return set(self) == other_set


class ControlFlowGraph(object):
    """Liveness analysis and memory-reuse rewriting over a list of ops.

    The ops form a straight-line graph: op ``i`` flows only into op
    ``i + 1``. A backward dataflow analysis computes, for every op, the
    variables live on entry (``_live_in``) and on exit (``_live_out``). On
    top of that, the graph can either:

    * ``memory_optimize``: rename a variable that an op defines to a dead
      variable with the same dtype and a compatible shape, so that the dead
      variable's memory is reused; or
    * ``release_memory``: insert ``delete_var`` ops right after the op at
      which variables stop being live.

    Args:
        program: the Program owning the ops. Its Python-side ``Variable``
            objects are refreshed when ``memory_optimize`` renames a variable.
        ops: the op descs to analyse. The first ``forward_num`` of them are
            treated as forward ops and the rest as backward ops.
        forward_num: number of leading forward ops in ``ops``.
        skip_opt: set of variable names that must never be reused or
            released. The optimization passes extend it in place.
    """

    def __init__(self, program, ops, forward_num, skip_opt):
        self._program = program
        self._ops = ops
        self._forward_num = forward_num
        # Graph edges and per-op variable sets, all keyed by op index.
        # OrderedSet keeps their iteration order deterministic.
        self._successors = defaultdict(OrderedSet)
        self._presuccessors = defaultdict(OrderedSet)
        self._uses = defaultdict(OrderedSet)
        self._defs = defaultdict(OrderedSet)
        self._live_in = defaultdict(OrderedSet)
        self._live_out = defaultdict(OrderedSet)

        self._skip_opt = skip_opt
        # Reusable (var_name, shape) candidates; _fill_pool keeps it ordered.
        self.pool = []

    def _add_connections(self, connections):
        """Populates _successors and _presuccessors for two neighbor nodes."""
        for node1, node2 in connections:
            self._add(node1, node2)

    def _add(self, node1, node2):
        """Record the directed edge node1 -> node2 in both adjacency maps."""
        self._successors[node1].add(node2)
        self._presuccessors[node2].add(node1)

    # TODO(panyx0718): We need to have a unified way of building intermediate
    # representation.
    def _build_graph(self):
        """Build a graph based on op sequence.

        Op ``i`` gets a single edge to op ``i + 1``. Its ``_uses`` and
        ``_defs`` are filled with the op's input and output argument names.
        """
        self.op_size = len(self._ops)
        op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)]
        self._add_connections(op_node_connections)
        for i in range(self.op_size):
            self._uses[i].update(self._ops[i].input_arg_names())
            self._defs[i].update(self._ops[i].output_arg_names())

    def _update_graph(self, old_name, new_name, begin_idx=0):
        """Rename ``old_name`` to ``new_name`` in the use, def and liveness
        sets of ops ``begin_idx`` onwards, mirroring an op-level rename."""
        for i in range(begin_idx, self.op_size):
            for var_sets in (self._uses, self._defs, self._live_in,
                             self._live_out):
                names = var_sets[i]
                if old_name in names:
                    names.remove(old_name)
                    names.add(new_name)

    def _dataflow_analyze(self):
        """Compute live-in and live-out sets with a backward worklist.

            live_out[i] = union of live_in[s] over successors s of i
            live_in[i]  = uses[i] | (live_out[i] - defs[i])

        Ops are queued last-to-first. Whenever an op's live-in set changes,
        its predecessors are queued again, until a fixed point is reached.
        """
        self._build_graph()
        # FIFO queue: deque.popleft is O(1), whereas list.pop(0) is O(n).
        # The visit order is unchanged.
        worklist = deque(range(len(self._ops) - 1, -1, -1))
        while worklist:
            i = worklist.popleft()
            live_in_before = set(self._live_in[i])
            for s in self._successors[i]:
                self._live_out[i] |= self._live_in[s]
            self._live_in[i] = self._uses[i] | (
                self._live_out[i] - self._defs[i])
            if live_in_before != set(self._live_in[i]):
                worklist.extend(self._presuccessors[i])

    def _pool_insert_pos(self, cache):
        """Return the index at which ``cache`` keeps ``self.pool`` ordered.

        The pool is ordered as follows:
        * Entries whose first dimension is -1 come first, then fixed-size
          entries.
        * Within each group, entries are in ascending order of absolute
          element count.
        * A new entry goes after existing entries of equal count.
        """

        def numel(shape):
            return abs(reduce(lambda a, b: a * b, shape))

        cache_shape = cache[1]
        cache_dynamic = cache_shape[0] == -1
        cache_numel = numel(cache_shape)
        pos = 0
        for _, shape in self.pool:
            entry_dynamic = shape[0] == -1
            if entry_dynamic != cache_dynamic:
                if entry_dynamic:
                    # Skip past the dynamic group to reach fixed-size ones.
                    pos += 1
                    continue
                # A dynamic entry goes in front of every fixed-size entry.
                break
            if numel(shape) > cache_numel:
                break
            pos += 1
        return pos

    def _fill_pool(self, i, is_forward):
        """Add the variables that die at op ``i`` to the reuse pool.

        A variable dies at op ``i`` when it is live-in but not live-out.
        A dead variable is added only if all of the following hold:
        * it passes ``_check_var_validity``;
        * some op of this graph outputs it;
        * its ``(name, shape)`` entry is not already pooled.
        """
        block_desc = self._ops[i].block()
        in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
        # NOTE: must sort the in_diff set for cases that get different cache var.
        # FIXME(typhoonzero): maybe use a "sorted set" is better than this.
        can_optimize = [
            x for x in sorted(in_diff)
            if self._check_var_validity(block_desc, x, is_forward)
        ]
        if not can_optimize:
            return

        # Names output by any op of this graph. The ops are not modified
        # inside this method, so compute the set once, not per candidate.
        produced_vars = set()
        for op in self._ops:
            produced_vars.update(op.output_arg_names())

        for var_name in can_optimize:
            cache = (var_name,
                     self._find_var(block_desc, var_name, is_forward).shape())
            if cache in self.pool or var_name not in produced_vars:
                continue
            self.pool.insert(self._pool_insert_pos(cache), cache)

    def _get_diff(self, a, b):
        """Return ``(a - b, b - a)``."""
        u = a & b
        return a - u, b - u

    def _has_var(self, block_desc, var_name, is_forward):
        """Look ``var_name`` up in ``block_desc``; backward ops search
        recursively (``has_var_recursive``)."""
        if is_forward:
            return block_desc.has_var(cpt.to_bytes(var_name))
        else:
            return block_desc.has_var_recursive(cpt.to_bytes(var_name))

    def _find_var(self, block_desc, var_name, is_forward):
        """Fetch the var desc of ``var_name``; backward ops search
        recursively (``find_var_recursive``)."""
        if is_forward:
            return block_desc.find_var(cpt.to_bytes(var_name))
        else:
            return block_desc.find_var_recursive(cpt.to_bytes(var_name))

    def _check_var_validity(self, block_desc, x, is_forward):
        """Return True if variable ``x`` may take part in memory reuse.

        ``x`` qualifies when all of the following hold:
        * it is not the ``@EMPTY@`` placeholder;
        * it can be found through ``_has_var``;
        * it is not persistable;
        * it is a LOD_TENSOR;
        * it is not in the skip set;
        * its shape is non-empty.
        """
        if str(x) == "@EMPTY@":
            return False
        if not self._has_var(block_desc, x, is_forward):
            return False
        # Look the variable up once instead of once per attribute check.
        var = self._find_var(block_desc, x, is_forward)
        if var.persistable():
            return False
        if var.type() != core.VarDesc.VarType.LOD_TENSOR:
            return False
        if x in self._skip_opt:
            return False
        return bool(var.shape())

    # TODO(panyx0718): This needs to be less hacky. It seems memory optimization
    # doesn't consider vars copied between cpu and gpu.
    def _update_skip_opt_set(self):
        """Exclude the outputs of ops whose ``force_cpu`` attribute is True."""
        for op in self._ops:
            if op.has_attr("force_cpu") and op.attr("force_cpu") == True:
                self._skip_opt.update(op.output_arg_names())

    def release_memory(self, skip_opt_set=None):
        """Insert ``delete_var`` ops that free variables after their last use.

        For every op that does not own a sub-block, the valid variables that
        are live-in but not live-out become the ``X`` input of a new
        ``delete_var`` op. That op is inserted directly after the current op
        in the same block.

        Args:
            skip_opt_set: extra variable names that must not be deleted.
        """
        self._dataflow_analyze()
        self._update_skip_opt_set()
        if skip_opt_set:
            self._skip_opt.update(skip_opt_set)
        # delete_var ops already inserted into the forward / backward block;
        # each one shifts later insert positions in that block by one.
        fwd_inserted = 0
        bwd_inserted = 0
        for i, op in enumerate(self._ops):
            if op.type() in SUB_BLOCK_OPS:
                continue
            block_desc = op.block()
            is_forward = i < self._forward_num
            in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
            can_optimize = [
                x for x in in_diff
                if self._check_var_validity(block_desc, x, is_forward)
            ]
            if not can_optimize:
                continue
            # Ops at index >= forward_num are indexed from the start of their
            # own block. _process_sub_block_pair appends the grad sub-block
            # ops after the forward ones.
            if is_forward:
                index = i + fwd_inserted + 1
            else:
                index = i - self._forward_num + bwd_inserted + 1
            delete_op = block_desc._insert_op(index)
            delete_op.set_type("delete_var")
            delete_op.set_input("X", can_optimize)
            if is_forward:
                fwd_inserted += 1
            else:
                bwd_inserted += 1

    def memory_optimize(self, skip_opt_set=None, level=0):
        """Reuse the memory of dead variables for newly defined ones.

        The ops are walked in order. Each valid variable an op defines (and
        does not also use) takes the first pool entry that exists, is a
        different variable, has the same dtype and has a compatible shape.
        The variable is renamed to that entry from this op onwards. After
        that, the variables that die at the op are added to the pool.

        Args:
            skip_opt_set: extra variable names that must not take part.
            level: 0 reuses only between identical shapes. 1 also allows
                reusing a variable with at least as many elements, provided
                both or neither shape has a -1 first dimension.

        Raises:
            ValueError: if ``level`` is not 0 or 1.
        """
        if level not in (0, 1):
            raise ValueError("only support opt_level 0 or 1.")

        def compare_shape(x_shape, cache_shape, opt_level):
            """Return True if a var of ``x_shape`` may reuse ``cache_shape``."""
            if opt_level == 0:
                return x_shape == cache_shape
            # opt_level == 1: the -1 first dimension must agree, and the
            # cache must hold at least as many elements.
            if (x_shape[0] == -1) != (cache_shape[0] == -1):
                return False
            x_size = abs(reduce(lambda a, b: a * b, x_shape))
            cache_size = abs(reduce(lambda a, b: a * b, cache_shape))
            return x_size <= cache_size

        self._dataflow_analyze()
        self._update_skip_opt_set()
        # update skip set to meet users' demand
        if skip_opt_set:
            self._skip_opt.update(skip_opt_set)
        counter = 0
        for i, op in enumerate(self._ops):
            if op.type() in SUB_BLOCK_OPS:
                continue
            block_desc = op.block()
            is_forward = i < self._forward_num
            if self.pool:
                # _defs[i] is an OrderedSet, so candidates are visited in a
                # deterministic order.
                out_pair = [
                    (x, self._find_var(block_desc, x, is_forward).shape())
                    for x in self._defs[i]
                    if self._check_var_validity(block_desc, x, is_forward)
                ]
                for x, x_shape in out_pair:
                    # If x is both in uses and defs, it can not be optimized!
                    if x in self._uses[i]:
                        continue
                    if x == FLAGS_memory_optimize:
                        print("start match var ", x, " of op ", op.type())
                        print(self.pool)
                    # x passed the validity check, so its var desc exists.
                    x_dtype = self._find_var(block_desc, x, is_forward).dtype()
                    for index, (cache_var,
                                cache_shape) in enumerate(self.pool):
                        if not self._has_var(block_desc, cache_var, is_forward):
                            if PRINT_LOG:
                                print("cache %s not exists!" %
                                      (cpt.to_text(cache_var)))
                            continue
                        if x == cache_var:
                            if PRINT_LOG:
                                print("x : ", cpt.to_text(x), " cache : ",
                                      cpt.to_text(cache_var), " is same var!")
                            break

                        cache_dtype = self._find_var(block_desc, cache_var,
                                                     is_forward).dtype()
                        if x_dtype != cache_dtype:
                            if PRINT_LOG:
                                print("x_dtype and cache_dtype are different")
                            continue

                        if not compare_shape(x_shape, cache_shape, level):
                            continue
                        # TODO(qijun): dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
                        if PRINT_LOG:
                            print(
                                ("!!! %d,  %s => %s, cache idx %d, pool size %d"
                                 % (counter, x + str(x_shape),
                                    cache_var + str(cache_shape), index,
                                    len(self.pool))))
                            counter += 1
                        self.pool.pop(index)
                        # Rename the var to the cache var already with
                        # memory allocated in order to reuse the memory.
                        _rename_arg_(self._ops, x, cache_var, begin_idx=i)
                        block = self._program.block(block_desc.id)
                        x_name = cpt.to_text(x)
                        block.var(x_name).desc = self._find_var(
                            block_desc, cache_var, is_forward)
                        block.vars[x_name] = Variable(block, name=x_name)
                        self._update_graph(x, cache_var, begin_idx=i)
                        break
            self._fill_pool(i, is_forward)


def _process_sub_block_pair(pdesc, sub_block_pair):
    """Collect the ops of every sub-block owned by ops of the given types.

    Note: this function doesn't handle nested subblocks yet.
    TODO(panyx0718): assert if case nested subblocks happen.

    :param pdesc: ProgramDesc.
    :param sub_block_pair: a list of (forward op type, backward op type)
        pairs. Ops of these types own a sub-block (attribute ``sub_block``).
        Only the ops of the global block (block 0) are scanned.
    :return: a list of ``(ops, forward_op_count, arg_names)`` tuples.
        A forward sub-block can be matched with a grad sub-block through the
        grad block's ``get_forward_block_idx()``. For a matched pair:
        ``ops`` is the forward block's ops followed by the grad block's ops,
        and ``arg_names`` is the set of input and output argument names of
        both owning ops. Each unmatched forward sub-block yields its own ops,
        its op count, and its owning op's argument names.
    """

    def ops_of(block):
        # Every op desc of ``block``, in program order.
        return [block.op(k) for k in range(block.op_size())]

    global_ops = ops_of(pdesc.block(0))
    result = []
    for fwd_type, bwd_type in sub_block_pair:
        fwd_ids = []
        grad_ids = []
        owner_of = {}  # sub-block id -> the global-block op that owns it
        for op in global_ops:
            op_type = op.type()
            if op_type == fwd_type:
                target = fwd_ids
            elif op_type == bwd_type:
                target = grad_ids
            else:
                continue
            block_id = op.attr("sub_block").id
            target.append(block_id)
            owner_of[block_id] = op

        # Match each grad sub-block with the forward sub-block it came from.
        matched = []
        for grad_id in grad_ids:
            fwd_id = pdesc.block(grad_id).get_forward_block_idx()
            if fwd_id in fwd_ids:
                matched.append((fwd_id, grad_id))
                fwd_ids.remove(fwd_id)

        for fwd_id, grad_id in matched:
            fwd_ops = ops_of(pdesc.block(fwd_id))
            grad_ops = ops_of(pdesc.block(grad_id))
            arg_names = set()
            for owner in (owner_of[fwd_id], owner_of[grad_id]):
                arg_names.update(owner.output_arg_names())
                arg_names.update(owner.input_arg_names())
            result.append((fwd_ops + grad_ops, len(fwd_ops), arg_names))

        # Forward sub-blocks that have no grad counterpart.
        for fwd_id in fwd_ids:
            fwd_ops = ops_of(pdesc.block(fwd_id))
            owner = owner_of[fwd_id]
            arg_names = set(owner.output_arg_names())
            arg_names.update(owner.input_arg_names())
            result.append((fwd_ops, len(fwd_ops), arg_names))
    return result


def _get_cfgs(input_program):
    """Process each block and create a ControlFlowGraph for each of them.

    The first graph covers the global block, with every op treated as a
    forward op. Its skip set is the union of the argument names of all
    sub-block-owning ops, so variables passed into or out of a sub-block are
    excluded from optimization in the global block. One further graph
    follows for each sub-block entry returned by ``_process_sub_block_pair``
    (one level of nesting only).

    :param input_program: Program object.
    :return: A list of ControlFlowGraph, each corresponds to a block.
    """
    pdesc = input_program._get_desc()
    global_block = pdesc.block(0)
    global_op_count = global_block.op_size()

    # Only process one level of nested subblock.
    sub_block_entries = _process_sub_block_pair(pdesc, SUB_BLOCK_PAIR)

    global_skip = set()
    for entry in sub_block_entries:
        global_skip.update(entry[2])

    # Get global block ops
    global_ops = [global_block.op(k) for k in range(global_op_count)]
    entries = [(global_ops, global_op_count, global_skip)] + sub_block_entries
    return [
        ControlFlowGraph(input_program, ops, forward_num, skip_opt)
        for ops, forward_num, skip_opt in entries
    ]


def _is_opt_role_op(op):
    """Return True if ``op``'s op-role attribute equals the Optimize role.

    :param op: an Operator from ``Program.global_block().ops`` (uses
        ``attr_names`` and ``all_attrs()``).
    :return: bool; False when the op has no op-role attribute or a different
        role value.
    """
    op_maker = core.op_proto_and_checker_maker
    role_attr = op_maker.kOpRoleAttrName()
    if role_attr not in op.attr_names:
        return False
    return int(op.all_attrs()[role_attr]) == int(op_maker.OpRole.Optimize)


def memory_optimize(input_program,
                    skip_opt_set=None,
                    print_log=False,
                    level=0,
                    skip_grads=False):
    """Optimize memory by reusing var memory.

    Deprecated: use CompiledProgram and Executor instead.

    Note: sub-blocks nested inside sub-blocks are not supported.

    Args:
        input_program(Program): the program to optimize in place.
        skip_opt_set(set|list|tuple|frozenset): names (or Variables) of vars
            that memory optimization must skip.
        print_log(bool): whether to print debug log.
        level(int): if 0, reuse a var only when the shapes are exactly
            equal. If 1, also reuse a var holding at least as many elements,
            as long as both or neither shape has a -1 first dimension.
        skip_grads(bool): if True, also skip the gradient vars named in the
            ``op_role_var`` attribute of optimizer-role ops.
    Returns:
        None
    Raises:
        ValueError: if ``level`` is not 0 or 1, or ``skip_opt_set`` is not a
            set, list, tuple or frozenset.
        TypeError: if an item of ``skip_opt_set`` is neither a Variable nor a
            string.
    """
    global PRINT_LOG
    sys.stderr.write('memory_optimize is deprecated. '
                     'Use CompiledProgram and Executor\n')

    def to_name_str(var):
        if isinstance(var, Variable):
            return var.desc.name()
        elif isinstance(var, str):
            return var
        elif isinstance(var, six.string_types):
            return str(var)
        else:
            raise TypeError(str(var) + " should be Variable or str")

    if level != 0 and level != 1:
        raise ValueError("only support opt_level 0 or 1.")
    if skip_opt_set is not None:
        if not isinstance(skip_opt_set, (set, frozenset, list, tuple)):
            raise ValueError("only support skip_opt_set as set.")
        # Work on a copy so the caller's collection is never mutated.
        skip_opt_set = set(skip_opt_set)
    PRINT_LOG = print_log
    if skip_grads:
        grad_set = set()
        OP_ROLE_VAR = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
        for op in input_program.global_block().ops:
            if not _is_opt_role_op(op):
                continue
            role_vars = op.attr(OP_ROLE_VAR)
            if role_vars:
                # NOTE(review): op_role_var is treated as a flat
                # [param, grad, param, grad, ...] list, where index 1 is the
                # grad. Collect every grad, not only the first one.
                grad_set.update(role_vars[1::2])
        if not skip_opt_set:
            skip_opt_set = grad_set
        else:
            skip_opt_set.update(grad_set)
    if skip_opt_set is not None:
        skip_opt_set = set(map(to_name_str, skip_opt_set))
    cfgs = _get_cfgs(input_program)
    input_program._is_mem_optimized = True
    for cfg in cfgs:
        cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)


def release_memory(input_program, skip_opt_set=None):
    """
    Modify the input program and insert :code:`delete_op` to drop unused
    variables early. The modification is performed in place.

    Notes: This is an experimental API and could be removed in the next few
    releases. Users should not use this API.

    Args:
        input_program(Program): The program to insert :code:`delete_op` into.
        skip_opt_set(set|list): names (or Variables) of vars that must not be
            released.
    Returns:
        None
    """
    if skip_opt_set:
        # Accept Variables as well as names. A Variable object would
        # otherwise never match a variable name and be silently ignored.
        skip_opt_set = set(
            var.desc.name() if isinstance(var, Variable) else var
            for var in skip_opt_set)
    cfgs = _get_cfgs(input_program)
    input_program._is_mem_optimized = True
    for cfg in cfgs:
        cfg.release_memory(skip_opt_set=skip_opt_set)