from paddle.v2.fluid import framework as framework
from . import core
import collections

__all__ = ['append_backward_ops']


def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None):
    if begin_idx is None:
        begin_idx = 0
    if end_idx is None:
        end_idx = len(op_desc_list)
    for i in range(begin_idx, end_idx):
        op_desc = op_desc_list[i]
        # pending_sum_ops stores (op_desc, position) tuples; unwrap them so the
        # renaming also applies to sum ops that have not been inserted yet
        if isinstance(op_desc, tuple):
            op_desc = op_desc[0]
        op_desc.rename_input(old_name, new_name)
        op_desc.rename_output(old_name, new_name)


def backward_impl(block, target_block, no_grad_set, callback=None):
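    """
    Append gradient op descs for the ops in `block` into `target_block`.

    Ops are visited in forward order; for an op carrying a "sub_block" attr a
    gradient sub-block is created and filled recursively. Gradient outputs that
    are written by more than one grad op are renamed and accumulated through
    sum ops, and the resulting gradient variables are created in
    `target_block`.
    """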
    grad_op_descs = []
    grad_to_var = {}
    program = block.program
    for each_op in block.ops:
        grad_sub_block_list = []
        if each_op.has_attr("sub_block"):
            sub_block_idx = each_op.block_attr("sub_block")
            sub_block = program.block(sub_block_idx)
            grad_sub_block = program.create_block(parent_idx=sub_block_idx)
            backward_impl(sub_block, grad_sub_block, no_grad_set, callback)
            grad_sub_block_list.append(grad_sub_block)
        grad_op_desc = core.get_grad_op_desc(each_op.desc,
                                             no_grad_set[block.idx],
                                             grad_to_var, grad_sub_block_list)
        grad_op_descs.append(grad_op_desc)
    # grad_op_descs is currently a list of lists, e.g. [[op1_g1, op1_g2], [op2_g], ...];
    # flatten it into a single list of grad op descs
    grad_op_descs = [op for sublist in grad_op_descs for op in sublist]

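    # A forward variable may receive gradients from several grad ops. Track the
    # renamed gradient copies per variable so they can be accumulated later:
    #   pending_sum_ops  -- (sum OpDesc, insert position) pairs, in append order
    #   var_rename_count -- how many renamed copies each gradient var has so far
    #   var_inputs       -- the renamed gradient copies recorded for each var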
    pending_sum_ops = []
    var_rename_count = collections.defaultdict(int)
    var_inputs = collections.defaultdict(list)
    for pos, op_desc in enumerate(grad_op_descs):
        for var_name in op_desc.input_arg_names():
            if len(var_inputs[var_name]) > 1:
                # this variable is written by more than one grad op; sum the
                # renamed copies back into var_name before this op consumes it
                pending_sum_ops.append((core.OpDesc(
                    type="sum_op",
                    inputs=var_inputs[var_name],
                    outputs=[var_name],
                    attrs={}), pos))
                var_inputs[var_name] = [var_name]
        for var_name in op_desc.output_arg_names():
            if len(var_inputs[var_name]) == 0:
                # it's the first time we see this variable as a grad output
                var_inputs[var_name] = [var_name]
            else:
                if len(var_inputs[var_name]) == 1:
                    new_name = var_name + "@RENAME@" + \
                        str(var_rename_count[var_name])
                    var_rename_count[var_name] = var_rename_count[var_name] + 1
                    # rename the original occurrence of var_name as well
                    var_inputs[var_name][0] = new_name
                    rename_arg(grad_op_descs, var_name, new_name, 0, pos)
                    rename_arg(pending_sum_ops, var_name, new_name)

                new_name = var_name + "@RENAME@" + \
                    str(var_rename_count[var_name])
                var_rename_count[var_name] = var_rename_count[var_name] + 1
                op_desc.rename_output(var_name, new_name)
                var_inputs[var_name].append(new_name)
    for var_name, inputs in var_inputs.iteritems():
        if len(inputs) > 1:
            pending_sum_ops.append((core.OpDesc(
                type="sum_op", inputs=inputs, outputs=[var_name], attrs={}),
                                    len(grad_op_descs)))
    # sum ops are appended in increasing order of their insertion positions, so
    # pending_sum_ops is already sorted by position; inserting in reverse keeps
    # the recorded indices valid
    for p in reversed(pending_sum_ops):
        grad_op_descs.insert(p[1], p[0])
    # create new gradient variables in the target block
    for op_desc in grad_op_descs:
        for grad_var_name in op_desc.output_arg_names():
            if target_block.has_var(grad_var_name) or \
                    grad_var_name == core.get_empty_var_name():
                continue
            target_block.var(grad_var_name)


def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
    """
    Create and add gradient Operators in BlockDesc to compute
    gradients of `loss` for parameters in parameter_list

    :param loss: an variable generated by cost function.
    :type loss: Variable
    :param no_grad_set: variable that should not create gradient
    :type no_grad_set: set
    :param parameter_list: parameters that need to compute gradient and 
    update to optimize the lost.
    :type: list
    :return: list of (parameters, gradients) pair.
    :rtype: list[Variable]
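
    Example (a minimal sketch; the layer calls below are illustrative and
    assume the usual fluid layers API, they are not defined in this file):

    .. code-block:: python

        import paddle.v2.fluid.layers as layers

        x = layers.data(name='x', shape=[13], dtype='float32')
        y = layers.data(name='y', shape=[1], dtype='float32')
        y_pred = layers.fc(input=x, size=1)
        cost = layers.square_error_cost(input=y_pred, label=y)
        avg_cost = layers.mean(x=cost)

        # appends grad ops for every parameter reachable from avg_cost
        params_and_grads = append_backward_ops(avg_cost)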
    """
    assert isinstance(loss, framework.Variable)

    if no_grad_set is None:
        program = loss.block.program
        assert isinstance(program, framework.Program)
        no_grad_set = list()
        for block in program.blocks:
            assert isinstance(block, framework.Block)
            for var in block.vars.itervalues():
                assert isinstance(var, framework.Variable)
                if var.stop_gradient:
                    no_grad_set.append(var.name)
        no_grad_set = set(no_grad_set)

    param_grad_map = loss.block.program.append_backward(loss, no_grad_set)
    if parameter_list is not None:
        parameters = parameter_list
    else:
        params = loss.block.program.global_block().all_parameters()
        parameters = [param.name for param in params]
    params_and_grads = []
    for param in parameters:
        if param not in param_grad_map:
            raise ValueError("param %s is not in map" % param)
        grad_info = param_grad_map[param]
        grad_block = loss.block.program.block(grad_info[1])
        if not grad_block.has_var(grad_info[0]):
            raise ValueError("grad block[{0}] did not have grad var {1}".format(
                grad_info[1], grad_info[0]))
        # Get the param var from the global block
        param_var = loss.block.program.global_block().var(param)
        grad_var = grad_block.var(grad_info[0])
        if loss.block.has_var(grad_info[0]):
            params_and_grads.append((param_var, grad_var))
        else:
            params_and_grads.append((param_var, None))
    return params_and_grads