#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import logging
import six
import sys
from collections import defaultdict
try:
    # MutableSet moved to collections.abc in Python 3.3; fall back for Python 2.
    from collections.abc import MutableSet
except ImportError:
    from collections import MutableSet
from .. import core
from ... import compat as cpt
from ..framework import Program, default_main_program, Parameter, Variable, core
from ..backward import _rename_arg_
from functools import reduce
from six.moves import range

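# Size in bytes of a single element for each supported variable dtype.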
dtype_to_size = {
    core.VarDesc.VarType.FP16: 2,
    core.VarDesc.VarType.FP32: 4,
    core.VarDesc.VarType.FP64: 8,
    core.VarDesc.VarType.INT16: 2,
    core.VarDesc.VarType.INT32: 4,
    core.VarDesc.VarType.INT64: 8,
    core.VarDesc.VarType.BOOL: 1,
    core.VarDesc.VarType.UINT8: 1,
}

SUB_BLOCK_OPS = [
    "while", "while_grad", "conditional_block", "conditional_block_grad"
]

SUB_BLOCK_PAIR = [("while", "while_grad"),
                  ("conditional_block", "conditional_block_grad")]

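# Debugging knobs: PRINT_LOG is switched on via the print_log argument of
# memory_optimize(); FLAGS_memory_optimize may be set to a variable name to
# trace how that single variable is matched against the reuse pool.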
PRINT_LOG = False
FLAGS_memory_optimize = ""


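# OrderedSet is the classic ordered-set recipe: a MutableSet backed by a dict
# plus a circular doubly linked list, giving O(1) add/discard while iterating
# in insertion order.  A minimal usage sketch (values are illustrative only):
#
#     s = OrderedSet(["b", "a"])
#     s.add("c")         # OrderedSet(['b', 'a', 'c'])
#     s.discard("a")     # OrderedSet(['b', 'c'])
#     s.pop(last=False)  # returns 'b'; pop() would return the newest key 'c'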
class OrderedSet(MutableSet):
    def __init__(self, iterable=None):
        self.end = end = []
        end += [None, end, end]  # sentinel node for doubly linked list
        self.map = {}  # key --> [key, prev, next]
        if iterable is not None:
            self |= iterable

    def __len__(self):
        return len(self.map)

    def __contains__(self, key):
        return key in self.map

    def add(self, key):
        if key not in self.map:
            end = self.end
            curr = end[1]
            curr[2] = end[1] = self.map[key] = [key, curr, end]

    def update(self, other):
        for e in other:
            self.add(e)

    def discard(self, key):
        if key in self.map:
            key, prev, next = self.map.pop(key)
            prev[2] = next
            next[1] = prev

    def remove(self, key):
        self.discard(key)

    def __iter__(self):
        end = self.end
        curr = end[2]
        while curr is not end:
            yield curr[0]
            curr = curr[2]

    def __reversed__(self):
        end = self.end
        curr = end[1]
        while curr is not end:
            yield curr[0]
            curr = curr[1]

    def pop(self, last=True):
        if not self:
            raise KeyError('set is empty')
        key = self.end[1][0] if last else self.end[2][0]
        self.discard(key)
        return key

    def __repr__(self):
        if not self:
            return '%s()' % (self.__class__.__name__, )
        return '%s(%r)' % (self.__class__.__name__, list(self))

    def __eq__(self, other):
        if isinstance(other, OrderedSet):
            return len(self) == len(other) and list(self) == list(other)
        return set(self) == set(other)


class ControlFlowGraph(object):
    def __init__(self, program, ops, forward_num, skip_opt):
        self._program = program
        self._ops = ops
        self._forward_num = forward_num
        self._successors = defaultdict(OrderedSet)
        self._presuccessors = defaultdict(OrderedSet)
        self._uses = defaultdict(OrderedSet)
        self._defs = defaultdict(OrderedSet)
        self._live_in = defaultdict(OrderedSet)
        self._live_out = defaultdict(OrderedSet)

        self._skip_opt = skip_opt
        self.pool = []

    def _add_connections(self, connections):
        """Populates _successors and _presuccessors for two neighbor nodes."""
        for node1, node2 in connections:
            self._add(node1, node2)

    def _add(self, node1, node2):
        self._successors[node1].add(node2)
        self._presuccessors[node2].add(node1)

    # TODO(panyx0718): We need to have a unified way of building intermediate
    # representation.
    def _build_graph(self):
        """Build a graph based on op sequence.
        """
        self.op_size = len(self._ops)
        op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)]
        self._add_connections(op_node_connections)
        for i in range(self.op_size):
            self._uses[i].update(self._ops[i].input_arg_names())
            self._defs[i].update(self._ops[i].output_arg_names())

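    # After a variable is renamed to a cache variable, propagate the rename
    # through the recorded uses/defs and liveness sets from op `begin_idx` on.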
    def _update_graph(self, old_name, new_name, begin_idx=0):
        for i in range(begin_idx, self.op_size):
            if old_name in self._uses[i]:
                self._uses[i].remove(old_name)
                self._uses[i].add(new_name)
            if old_name in self._defs[i]:
                self._defs[i].remove(old_name)
                self._defs[i].add(new_name)
            if old_name in self._live_in[i]:
                self._live_in[i].remove(old_name)
                self._live_in[i].add(new_name)
            if old_name in self._live_out[i]:
                self._live_out[i].remove(old_name)
                self._live_out[i].add(new_name)

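    # Classic backward liveness analysis over the op graph, solved with a
    # worklist:  live_out[i] = union of live_in[s] over successors s, and
    # live_in[i] = uses[i] | (live_out[i] - defs[i]).  Whenever live_in[i]
    # changes, the predecessors of op i are pushed back onto the worklist.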
    def _dataflow_analyze(self):
        self._build_graph()
        live_in = defaultdict(set)
        worklist = list(range(len(self._ops) - 1, -1, -1))
        while worklist:
            i = worklist.pop(0)
            live_in[i] = set(self._live_in[i])
            for s in self._successors[i]:
                self._live_out[i] |= self._live_in[s]
            self._live_in[i] = self._uses[i] | (
                self._live_out[i] - self._defs[i])
            if live_in[i] != set(self._live_in[i]):
                for d in self._presuccessors[i]:
                    worklist.append(d)

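    # self.pool keeps (var_name, shape) pairs that are dead after op i and are
    # therefore candidates for reuse.  It is maintained in two segments:
    # variables whose leading (batch) dimension is -1 come first, followed by
    # fixed-shape variables, and each segment is kept in ascending order of
    # total element count so that smaller caches are tried first.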
    def _fill_pool(self, i, is_forward):
        def comparator(x, cache):
            x_shape = x[1]
            cache_shape = cache[1]
            x_size = abs(reduce(lambda x, y: x * y, x_shape))
            cache_size = abs(reduce(lambda x, y: x * y, cache_shape))
            if (x_shape[0] == -1 and cache_shape[0] == -1) or \
               (x_shape[0] != -1 and cache_shape[0] != -1) :
                return x_size <= cache_size
            else:
                return False

        def find_var_in_block(x):
            known_vars = set()
            for op in self._ops:
                known_vars.update(op.output_arg_names())
            return x in known_vars

        block_desc = self._ops[i].block()
        in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
        # NOTE: must sort the in_diff set for cases that get different cache var.
        # FIXME(typhoonzero): maybe using a "sorted set" would be better than this.
        can_optimize = [
            x for x in sorted(in_diff)
            if self._check_var_validity(block_desc, x, is_forward)
        ]
        if can_optimize:
            for var_name in can_optimize:
                cache = (var_name, self._find_var(block_desc, var_name,
                                                  is_forward).shape())
                if cache not in self.pool and find_var_in_block(var_name):
                    i = 0
                    while i < len(self.pool):
                        mycache = self.pool[i]
                        mysize = mycache[1][0]
                        cache_size = cache[1][0]
                        if (mysize == -1 and cache_size == -1) or \
                           (mysize != -1 and cache_size != -1):
                            if comparator(mycache, cache):
                                i += 1
                            else:
                                break
                        elif mysize == -1 and cache_size != -1:
                            i += 1
                        elif mysize != -1 and cache_size == -1:
                            break
                    self.pool.insert(i, cache)

    def _get_diff(self, a, b):
        u = a & b
        return a - u, b - u

    def _has_var(self, block_desc, var_name, is_forward):
        if is_forward:
            return block_desc.has_var(cpt.to_bytes(var_name))
        else:
            return block_desc.has_var_recursive(cpt.to_bytes(var_name))

    def _find_var(self, block_desc, var_name, is_forward):
        if is_forward:
            return block_desc.find_var(cpt.to_bytes(var_name))
        else:
            return block_desc.find_var_recursive(cpt.to_bytes(var_name))

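    # A variable may only take part in reuse when it exists in the block, is a
    # non-persistable LOD_TENSOR with a known shape, and is not in the skip set.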
    def _check_var_validity(self, block_desc, x, is_forward):
        if str(x) == "@EMPTY@":
            return False
        if not self._has_var(block_desc, x, is_forward):
            return False
        if self._find_var(block_desc, x, is_forward).persistable():
            return False
        if self._find_var(block_desc, x,
                          is_forward).type() != core.VarDesc.VarType.LOD_TENSOR:
            return False
        if x in self._skip_opt:
            return False
        if not self._find_var(block_desc, x, is_forward).shape():
            return False
        return True

    # TODO(panyx0718): This needs to be less hacky. It seems memory optimization
    # doesn't consider vars copied between cpu and gpu.
    def _update_skip_opt_set(self):
        for i in range(self.op_size):
            op = self._ops[i]
            if op.has_attr("force_cpu") and op.attr("force_cpu") == True:
                self._skip_opt.update(op.output_arg_names())

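    # Eager-release mode: instead of reusing memory, insert a `delete_var` op
    # right after the op at which a variable dies so its memory is freed early.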
    def release_memory(self, skip_opt_set=None):
        self._dataflow_analyze()
        self._update_skip_opt_set()
        if skip_opt_set:
            self._skip_opt.update(skip_opt_set)
        fwd_id = 0
        bwd_id = 0
        for i in range(self.op_size):
            op = self._ops[i]
            if op.type() in SUB_BLOCK_OPS:
                continue
            block_desc = op.block()
            is_forward = i < self._forward_num
            in_diff, out_diff = self._get_diff(self._live_in[i],
                                               self._live_out[i])
            can_optimize = [
                x for x in in_diff
                if self._check_var_validity(block_desc, x, is_forward)
            ]
            if can_optimize:
                index = i + fwd_id + 1 if is_forward else i - self._forward_num + bwd_id + 1
                delete_op = block_desc._insert_op(index)
                delete_op.set_type("delete_var")
                delete_op.set_input("X", can_optimize)
                if is_forward:
                    fwd_id += 1
                else:
                    bwd_id += 1

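    # Core reuse pass.  For each op, every freshly defined output variable that
    # is not also an input of the same op is matched against the pool of dead
    # variables with the same dtype and a compatible shape (identical at level
    # 0, merely large enough at level 1).  On a match, the output is renamed to
    # the cached variable so that it reuses the memory already allocated.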
    def memory_optimize(self, skip_opt_set=None, level=0):
        def compare_shape(x_shape, cache_shape, opt_level):
            if opt_level == 0:
                return x_shape == cache_shape
            elif opt_level == 1:
                if (x_shape[0] == -1) ^ (cache_shape[0] == -1):
                    return False
                x_size = abs(reduce(lambda x, y: x * y, x_shape))
                cache_size = abs(reduce(lambda x, y: x * y, cache_shape))
                if x_size <= cache_size:
                    return True
            else:
                raise ValueError("only support opt_level 0 or 1.")
            return False

        self._dataflow_analyze()
        self._update_skip_opt_set()
        # update skip set to meet users' demand
        if skip_opt_set:
            self._skip_opt.update(skip_opt_set)
        counter = 0
        for i in range(self.op_size):
            op = self._ops[i]
            if op.type() in SUB_BLOCK_OPS:
                continue
            block_desc = op.block()
            is_forward = i < self._forward_num
            if self.pool:
                # NOTE: must sort the in_diff set for cases that get different cache var.
                defs_can_optimize = [
                    x for x in self._defs[i]
                    if self._check_var_validity(block_desc, x, is_forward)
                ]
                out_pair = [
                    (x, self._find_var(block_desc, x, is_forward).shape())
                    for x in defs_can_optimize
                ]
                for x, x_shape in out_pair:
                    # If x is both in uses and defs, it can not be optimized!
                    if x in self._uses[i]:
                        continue
                    if x == FLAGS_memory_optimize:
                        print("start match var ", x, " of op ", op.type())
                        print(self.pool)
                    for index, cache_pair in enumerate(self.pool):
                        cache_var = cache_pair[0]
                        cache_shape = cache_pair[1]
                        if not self._has_var(block_desc, cache_var, is_forward):
                            if PRINT_LOG:
                                print("cache %s not exists!" %
                                      (cpt.to_text(cache_var)))
                            continue
                        if x == cache_var:
                            if PRINT_LOG:
                                print("x : ", cpt.to_text(x), " cache : ",
                                      cpt.to_text(cache_var), " is same var!")
                            break

                        x_dtype = self._find_var(block_desc, x,
                                                 is_forward).dtype()
                        cache_dtype = self._find_var(block_desc, cache_var,
                                                     is_forward).dtype()
                        if x_dtype != cache_dtype:
                            if PRINT_LOG:
                                print("x_dtype and cache_dtype are different")
                            continue

                        if not compare_shape(x_shape, cache_shape, level):
                            continue
                        # TODO(qijun): dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
                        if PRINT_LOG:
                            print(
                                ("!!! %d,  %s => %s, cache idx %d, pool size %d"
                                 % (counter, x + str(x_shape),
                                    cache_var + str(cache_shape), index,
                                    len(self.pool))))
                            counter += 1
                        self.pool.pop(index)
                        # Rename the var to the cache var already with
                        # memory allocated in order to reuse the memory.
                        _rename_arg_(self._ops, x, cache_var, begin_idx=i)
                        self._program.block(block_desc.id).var(cpt.to_text(
                            x)).desc = self._find_var(block_desc, cache_var,
                                                      is_forward)
                        self._program.block(block_desc.id).vars[cpt.to_text(x)] = \
                            Variable(self._program.block(block_desc.id), name=cpt.to_text(x))
                        self._update_graph(x, cache_var, begin_idx=i)
                        break
            self._fill_pool(i, is_forward)


def _process_sub_block_pair(pdesc, sub_block_pair):
    """Creates a list of tuple each of which tracks info of a subblock.

      Note: this function doesn't handle nested subblocks yet.
      TODO(panyx0718): assert if case nested subblocks happen.

    :param pdesc: ProgramDesc.
    :param sub_block_pair: A list of op pairs. Each op pair is the forward
        op and backward op. The ops in the list are special that they contain
        a subblock of ops.
    :return: A list of tuples, each tuple is (all ops in a subblock pair
        including forward and backward, number of forward ops,
        all output args names of the ops in the subblock pairs).
    """
    ops_list = []
    block_desc = pdesc.block(0)
    op_size = block_desc.op_size()
    for fwd_op, bwd_op in sub_block_pair:
        sub_block_ids = []
        grad_sub_block_ids = []
        sub_block_id_pair = []
        sub_op_dict = {}
        for i in range(op_size):
            op = block_desc.op(i)
            if op.type() == fwd_op:
                sub_block_ids.append(op.attr("sub_block").id)
                sub_op_dict[op.attr("sub_block").id] = op
            elif op.type() == bwd_op:
                grad_sub_block_ids.append(op.attr("sub_block").id)
                sub_op_dict[op.attr("sub_block").id] = op

        # Find fwd_op/bwd_op block pair
        for grad_id in grad_sub_block_ids:
            fwd_id = pdesc.block(grad_id).get_forward_block_idx()
            if fwd_id in sub_block_ids:
                sub_block_id_pair.append((fwd_id, grad_id))
                sub_block_ids.remove(fwd_id)

        # Get fwd_op/bwd_op block ops
        for fwd_id, grad_id in sub_block_id_pair:
            sub_block_ops = []
            sub_block = pdesc.block(fwd_id)
            block_op_size = sub_block.op_size()
            for i in range(block_op_size):
                sub_block_ops.append(sub_block.op(i))

            grad_sub_block = pdesc.block(grad_id)
            grad_sub_block_op_size = grad_sub_block.op_size()
            for i in range(grad_sub_block_op_size):
                sub_block_ops.append(grad_sub_block.op(i))

            sub_op_output = set()
            sub_op_output.update(sub_op_dict[fwd_id].output_arg_names())
            sub_op_output.update(sub_op_dict[grad_id].output_arg_names())
            sub_op_output.update(sub_op_dict[fwd_id].input_arg_names())
            sub_op_output.update(sub_op_dict[grad_id].input_arg_names())
            ops_list.append((sub_block_ops, block_op_size, sub_op_output))

        # Process rest fwd_op block ops
        for fwd_id in sub_block_ids:
            sub_block_ops = []
            sub_block = pdesc.block(fwd_id)
            sub_block_op_size = sub_block.op_size()
            for i in range(sub_block_op_size):
                sub_block_ops.append(sub_block.op(i))
            sub_op_output = set()
            sub_op_output.update(sub_op_dict[fwd_id].output_arg_names())
            sub_op_output.update(sub_op_dict[fwd_id].input_arg_names())
            ops_list.append((sub_block_ops, sub_block_op_size, sub_op_output))
    return ops_list


def _get_cfgs(input_program):
    """Process each block and create ControlFlowGraph for each of them.

    :param input_program: Program object.
    :return: A list of ControlFlowGraph, each corresponds to a block.
    """
    ops_list = []
    pdesc = input_program._get_desc()
    block_desc = pdesc.block(0)
    op_size = block_desc.op_size()

    # Only process one level of nested subblock.
    ops_list.extend(_process_sub_block_pair(pdesc, SUB_BLOCK_PAIR))

    skip_opt_set = set()
    for _, _, skip_opt in ops_list:
        skip_opt_set.update(skip_opt)

    # Get global block ops
    ops_list.insert(
        0, ([block_desc.op(i) for i in range(op_size)], op_size, skip_opt_set))
    cfgs = [
        ControlFlowGraph(input_program, ops, forward_num, skip_opt)
        for ops, forward_num, skip_opt in ops_list
    ]
    return cfgs


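# An op carries the Optimize role (e.g. a parameter update op) when its OpRole
# attribute equals OpRole.Optimize; memory_optimize() uses this below to
# collect the gradient variables that should be excluded from reuse.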
def _is_opt_role_op(op):
    op_maker = core.op_proto_and_checker_maker
    optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
    if op_maker.kOpRoleAttrName() in op.attr_names and \
            int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(optimize_role):
        return True
    return False


def memory_optimize(input_program,
                    skip_opt_set=None,
                    print_log=False,
                    level=0,
                    skip_grads=True):
    """
    | Legacy memory optimization strategy that reduces total memory consumption by reusing variable memory between different operators.
    | A simple example to explain the algorithm:
    
        ..  code-block:: python
        
            c = a + b  # assume this is the last time a is used
            d = b * c
         
    | Since **a** will not be used anymore after **"c = a + b"**, and **a** and **d** have the same size,
      we can use variable **a** to replace variable **d**, so the above code can actually be optimized into:

        ..  code-block:: python
        
            c = a + b
            a = b * c 
          
    
    | Please note that, in this legacy design, we use variable **a** to replace **d** directly, which means
      that after you call this API some variables may disappear and some variables may hold unexpected values,
      as in the above case, where **a** actually holds the value of **d** after execution.
    
    | So, to protect important variables from being reused or removed by the optimization, we provide skip_opt_set
      to let you specify a variable whitelist.
      The variables in skip_opt_set will not be affected by the memory_optimize API.
    
    Note:
        | **This API is deprecated; please avoid using it in new code.**
        | It does not support operators that create sub-blocks, such as While and IfElse.
    
    Args:
        input_program(Program): Input Program
        skip_opt_set(set): vars will be skipped in memory optimization
        print_log(bool): whether to print debug log.
        level(int): 0 or 1. 0 means we replace a with b only when a.size == b.size; 1 means we can replace a with b if a.size <= b.size
    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            main_prog = fluid.Program()
            startup_prog = fluid.Program()

            place = fluid.CPUPlace()
            exe = fluid.Executor(place)

            exe.run(startup_prog)
            fluid.memory_optimize(main_prog)

    """
    logging.warning(
        'Caution! paddle.fluid.memory_optimize() is deprecated '
        'and not maintained any more, since it is not stable!\n'
        'Please use the newest and stable memory optimization strategies!\n'
        ' 1. Enable garbage collection strategy by exporting environment '
        'variable FLAGS_eager_delete_tensor_gb=0\n'
        ' 2. Set build_strategy.enable_inplace=True (True is the default '
        'value) when using CompiledProgram or ParallelExecutor.\n')

    def to_name_str(var):
        if isinstance(var, Variable):
            return var.desc.name()
        elif isinstance(var, str):
            return var
        elif isinstance(var, six.string_types):
            return str(var)
        else:
            raise TypeError(str(var) + " should be Variable or str")

    if level != 0 and level != 1:
        raise ValueError("only support opt_level 0 or 1.")
    if skip_opt_set is not None:
        if isinstance(skip_opt_set, set) or isinstance(skip_opt_set, list):
            skip_opt_set = set(skip_opt_set)
        else:
            raise ValueError("only support skip_opt_set as set.")
    global PRINT_LOG
    PRINT_LOG = print_log
    if skip_grads:
        grad_set = set()
        OP_ROLE_VAR = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
        for op in input_program.global_block().ops:
            if _is_opt_role_op(op):
                if op.attr(OP_ROLE_VAR):
                    grad_name = op.attr(OP_ROLE_VAR)[1]
                    grad_set.add(grad_name)
        if not skip_opt_set:
            skip_opt_set = grad_set
        else:
            skip_opt_set.update(grad_set)
    if skip_opt_set is not None:
        skip_opt_set = set(map(to_name_str, skip_opt_set))
    cfgs = _get_cfgs(input_program)
    input_program._is_mem_optimized = True
    for cfg in cfgs:
        cfg.memory_optimize(skip_opt_set=skip_opt_set, level=level)


def release_memory(input_program, skip_opt_set=None):
    """
    Modify the input program and insert :code:`delete_op` to drop unused
    variables early. The modification is performed in place.

    Notes: This is an experimental API and could be removed in the next few
    releases. Users should not use this API.

    Args:
        input_program(Program): The program into which :code:`delete_op` will be inserted.
        skip_opt_set(set): vars will be skipped in memory optimization
    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            # build network
            # ...
            
            # deprecated API
            fluid.release_memory(fluid.default_main_program())
    
    """
    cfgs = _get_cfgs(input_program)
    input_program._is_mem_optimized = True
    for cfg in cfgs:
        cfg.release_memory(skip_opt_set=skip_opt_set)