from paddle.v2.fluid import framework as framework
from . import core
import collections

__all__ = ['append_backward_ops']


def backward_impl(block, target_block, no_grad_set, grad_to_var, callback):
    grad_op_descs = []
    program = block.program
    for each_op in block.ops:
        grad_sub_block_list = []
        if each_op.has_attr("sub_block"):
            sub_block_idx = each_op.block_attr("sub_block")
            sub_block = program.block(sub_block_idx)
            grad_sub_block = program.create_block(parent_idx=sub_block_idx)
            backward_impl(sub_block, grad_sub_block, no_grad_set, grad_to_var,
                          callback)
            grad_sub_block_list.append(grad_sub_block)
        grad_op_desc = core.get_grad_op_desc(each_op.desc,
                                             no_grad_set[block.idx],
                                             grad_to_var, grad_sub_block_list)
        grad_op_descs.append(grad_op_desc)
    # grad_op_descs is a list of lists, e.g. [[op1_g1, op1_g2], [op2_g], ...];
    # flatten it into a single list of grad op descs
    grad_op_descs = [op for sublist in grad_op_descs for op in sublist]

    output_vars = collections.defaultdict(list)
    for pos, op_desc in enumerate(grad_op_descs):
        for var_name in op_desc.output_arg_names():
            output_vars[var_name].append(pos)
    for var_name, poses in output_vars.iteritems():
        if len(poses) == 1:
            continue
        renamed_list = []
        for pos in reversed(sorted(poses)):
            # Give each duplicated output a unique name so that later writes
            # do not overwrite earlier ones.
            new_name = var_name + "@RENAMED@" + str(len(renamed_list))
            renamed_list.append(new_name)


def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
    """
    Create and add gradient Operators in BlockDesc to compute
    gradients of `loss` for parameters in parameter_list

    :param loss: an variable generated by cost function.
    :type loss: Variable
    :param no_grad_set: variable that should not create gradient
    :type no_grad_set: set
    :param parameter_list: parameters that need to compute gradient and 
    update to optimize the lost.
    :type: list
    :return: list of (parameters, gradients) pair.
    :rtype: list[Variable]
    """
    assert isinstance(loss, framework.Variable)

    if no_grad_set is None:
        program = loss.block.program
        assert isinstance(program, framework.Program)
        no_grad_set = list()
        for block in program.blocks:
            assert isinstance(block, framework.Block)
            for var in block.vars.itervalues():
                assert isinstance(var, framework.Variable)
                if var.stop_gradient:
                    no_grad_set.append(var.name)
        no_grad_set = set(no_grad_set)

    param_grad_map = loss.block.program.append_backward(loss, no_grad_set)
    if parameter_list is not None:
        parameters = parameter_list
    else:
        params = loss.block.program.global_block().all_parameters()
        parameters = [param.name for param in params]
    params_and_grads = []
    for param in parameters:
        if param not in param_grad_map:
            raise ValueError("param %s is not in map" % param)
        grad_info = param_grad_map[param]
        grad_block = loss.block.program.block(grad_info[1])
        if not grad_block.has_var(grad_info[0]):
            raise ValueError("grad block[{0}] did not have grad var {1}".format(
                grad_info[1], grad_info[0]))
        # Get the param var from the global block
        param_var = loss.block.program.global_block().var(param)
        grad_var = grad_block.var(grad_info[0])
        if loss.block.has_var(grad_info[0]):
            params_and_grads.append((param_var, grad_var))
        else:
            params_and_grads.append((param_var, None))
    return params_and_grads
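

# Illustrative usage (a minimal sketch, not part of the library itself). It
# builds a tiny forward program and then calls append_backward_ops on the
# mean cost. The layer calls below (`fluid.layers.data`, `fluid.layers.fc`,
# `fluid.layers.square_error_cost`, `fluid.layers.mean`) are assumptions about
# the fluid layer API of this era; adapt them to whatever builds your forward
# program.
if __name__ == '__main__':
    import paddle.v2.fluid as fluid

    # Forward program: x -> fc -> square error against y.
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    y_pred = fluid.layers.fc(input=x, size=1)
    cost = fluid.layers.square_error_cost(input=y_pred, label=y)
    avg_cost = fluid.layers.mean(x=cost)

    # Append gradient ops for every parameter reachable from avg_cost and
    # inspect the resulting (parameter, gradient) pairs.
    params_and_grads = append_backward_ops(avg_cost)
    for param, grad in params_and_grads:
        print("%s -> %s" % (param.name,
                            grad.name if grad is not None else None))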