from paddle.v2.fluid import framework as framework
from . import core
import collections

__all__ = ['append_backward_ops']


def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None):
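    """Rename every occurrence of `old_name` to `new_name` in both the inputs
    and outputs of op_desc_list[begin_idx:end_idx] (the whole list by default).
    """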
    if begin_idx is None:
        begin_idx = 0
    if end_idx is None:
        end_idx = len(op_desc_list)
    for i in range(begin_idx, end_idx):
        op_desc_list[i].rename_input(old_name, new_name)
        op_desc_list[i].rename_output(old_name, new_name)


def backward_impl(target,
                  block,
                  target_block,
                  no_grad_set,
                  grad_info_map,
                  callback=None):
    grad_op_descs = []
    grad_to_var = {}
    program = block.program
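    # Create gradient op descs for every op in `block`. An op that carries a
    # "sub_block" attribute (a control-flow op) first gets a gradient
    # sub-block built recursively, which is then passed to
    # core.get_grad_op_desc together with the block's no-grad set.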
    for each_op in block.ops:
        grad_sub_block_list = []
        if each_op.has_attr("sub_block"):
            sub_block_idx = each_op.block_attr("sub_block")
            sub_block = program.block(sub_block_idx)
            grad_sub_block = program.create_block(parent_idx=sub_block_idx)
            backward_impl(target, sub_block, grad_sub_block, no_grad_set,
                          grad_info_map, callback)
            grad_sub_block_list.append(grad_sub_block)
        grad_op_desc = core.get_grad_op_desc(each_op.desc,
                                             no_grad_set[block.idx],
                                             grad_to_var, grad_sub_block_list)
        grad_op_descs.append(grad_op_desc)
    # grad_op_descs is a list of lists, e.g. [[op1_g1, op1_g2], [op2_g], ...];
    # flatten it into a single list of gradient op descs.
    grad_op_descs = [op for sublist in grad_op_descs for op in sublist]

    pending_sum_ops = []
    var_rename_count = collections.defaultdict(int)
    var_inputs = collections.defaultdict(list)
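    # If a gradient variable is written by more than one grad op, rename each
    # writer's output to <name>@RENAME@<i> and record a pending sum_op that
    # accumulates the renamed copies back into the original variable name.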
    for pos, op_desc in enumerate(grad_op_descs):
        for var_name in op_desc.input_arg_names():
            if len(var_inputs[var_name]) > 1:
                pending_sum_ops.append((core.OpDesc(
                    "sum_op", {"X": var_inputs[var_name]},
                    {"Out": [var_name]}, {}), pos))
                var_inputs[var_name] = [var_name]
        for var_name in op_desc.output_arg_names():
            if len(var_inputs[var_name]) == 0:
                # it's the first time we get the variable
                var_inputs[var_name] = [var_name]
            else:
                if len(var_inputs[var_name]) == 1:
                    new_name = var_name + "@RENAME@" + \
                        str(var_rename_count[var_name])
                    var_rename_count[var_name] = var_rename_count[var_name] + 1
                    # rename original var_name
                    var_inputs[var_name][0] = new_name
                    rename_arg(grad_op_descs, var_name, new_name, 0, pos)
                    rename_arg(pending_sum_ops, var_name, new_name)

                new_name = var_name + "@RENAME@" + \
                    str(var_rename_count[var_name])
                var_rename_count[var_name] = var_rename_count[var_name] + 1
                op_desc.rename_output(var_name, new_name)
                var_inputs[var_name].append(new_name)
    for var_name, inputs in var_inputs.iteritems():
        if len(inputs) > 1:
            pending_sum_ops.append(
                (core.OpDesc("sum_op", {"X": inputs}, {"Out": [var_name]}, {}),
                 len(grad_op_descs)))
    # TODO: remove ops in the no-grad set

    # Because of the append order above, pending_sum_ops is already sorted by
    # the insertion position of each sum_op, so inserting in reverse order
    # keeps the recorded indices valid.
    for p in reversed(pending_sum_ops):
        grad_op_descs.insert(p[1], p[0])
    # create new gradient variables in the target block desc
    for op_desc in grad_op_descs:
        for grad_var_name in op_desc.output_arg_names():
            grad_var_name = grad_var_name.encode("ascii")
            if (target_block.desc.has_var(grad_var_name) or
                    grad_var_name == core.get_empty_var_name()):
                continue
            target_block.desc.var(grad_var_name)
            if grad_var_name not in grad_to_var:
                continue
            grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name,
                                                         target_block)
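    # For the root block, prepend a fill_constant op so that the gradient of
    # the target (the loss itself) starts out as 1.0.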
    if target_block.idx == 0:
        grad_target_name = (target.name + "@GRAD")
        target_block.desc.var(grad_target_name)
        grad_op_descs.insert(
            0,
            core.OpDesc(u"fill_constant", {}, {
                u"Out": [unicode(grad_target_name, "ascii")]
            }, {u"shape": (1),
                u"value": 1.0,
                u"dtype": core.DataType.FP32}))
    # insert backward operators to target_block
    for op_desc in grad_op_descs:
        target_block.desc.append_allocated_op(op_desc)

    target_block.sync_with_cpp()


def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
    """
    Create and add gradient operators to the BlockDesc to compute the
    gradients of `loss` with respect to the parameters in `parameter_list`.

    :param loss: a variable generated by the cost function.
    :type loss: Variable
    :param no_grad_set: variables that should not have gradients computed.
    :type no_grad_set: set
    :param parameter_list: parameters whose gradients are needed in order to
        optimize the loss.
    :type parameter_list: list
    :return: list of (parameter, gradient) pairs.
    :rtype: list[(Variable, Variable)]
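
    Example (a minimal usage sketch; building the network and `avg_cost`
    elsewhere is assumed)::

        avg_cost = ...  # a scalar cost Variable produced by the network
        params_grads = append_backward_ops(avg_cost)
        # each entry of params_grads is a (parameter, gradient) pair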
    """
    assert isinstance(loss, framework.Variable)

    if no_grad_set is None:
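        # Build a per-block set of variable names marked stop_gradient,
        # keyed by block index; these variables will not get gradients.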
        no_grad_set = dict()
        program = loss.block.program
        assert isinstance(program, framework.Program)
        for block in program.blocks:
            assert isinstance(block, framework.Block)
            block_no_grad_set = set()
            for var in block.vars.itervalues():
                assert isinstance(var, framework.Variable)
                if var.stop_gradient:
                    block_no_grad_set.add(var.name)
            no_grad_set[block.idx] = block_no_grad_set

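    # grad_info_map maps a forward variable name to a tuple of
    # (gradient variable name, block that holds the gradient variable).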
    grad_info_map = dict()
    root_block = loss.block.program.block(0)
    backward_impl(loss, root_block, root_block, no_grad_set, grad_info_map)
    if parameter_list is not None:
        parameters = parameter_list
    else:
        params = loss.block.program.global_block().all_parameters()
        parameters = [param.name for param in params]
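    # Pair each parameter with its gradient variable; parameters whose
    # gradient variable is not present in the loss block are paired with None.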
    params_and_grads = []
    for param in parameters:
        if param not in grad_info_map:
            raise ValueError("param %s is not in map" % param)
        grad_info = grad_info_map[param]
        grad_block = grad_info[1]
        if not grad_block.has_var(grad_info[0]):
            raise ValueError("grad block[{0}] did not have grad var {1}".format(
                grad_info[1], grad_info[0]))
        # Get the param var from the global block
        param_var = loss.block.program.global_block().var(param)
        grad_var = grad_block.var(grad_info[0])
        if loss.block.has_var(grad_info[0]):
            params_and_grads.append((param_var, grad_var))
        else:
            params_and_grads.append((param_var, None))
    return params_and_grads