from paddle.v2.fluid import framework as framework
from . import core
import collections

__all__ = ['append_backward']


def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
                 end_idx=None):
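    """Rename `old_name` to `new_name` in the inputs and outputs of the op
    descs in op_desc_list[begin_idx:end_idx].
    """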
    if begin_idx is None:
        begin_idx = 0
    if end_idx is None:
        end_idx = len(op_desc_list)
    for i in range(begin_idx, end_idx):
        op_desc_list[i].rename_input(old_name, new_name)
        op_desc_list[i].rename_output(old_name, new_name)


def _create_op_desc_(op_type, inputs, outputs, attrs):
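    """Build a core.OpDesc of the given type from the input/output name lists
    and the attribute dict.
    """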
    op_desc = core.OpDesc()
    op_desc.set_type(op_type)
    for para, args in inputs.iteritems():
        op_desc.set_input(para, args)
    for para, args in outputs.iteritems():
        op_desc.set_output(para, args)
    for name, val in attrs.iteritems():
        if isinstance(val, framework.Block):
            op_desc.set_block_attr(name, val.desc)
        else:
            op_desc.set_attr(name, val)
    return op_desc


def _infer_var_data_type_(var_name, block):
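    """Set the dtype of a gradient variable to that of its forward variable,
    falling back to FP32 when the forward variable cannot be found.
    """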
    grad_var = block.desc.find_var(var_name.encode("ascii"))
    fwd_name = _strip_grad_suffix_(var_name.encode("ascii"))
    if block.desc.has_var_recursive(fwd_name):
        fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
        grad_var.set_dtype(fwd_var.dtype())
    else:
        grad_var.set_dtype(core.DataType.FP32)


def _is_all_in_set_(cands, s):
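    """Return True if every candidate in `cands` is a member of set `s`."""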
    for c in cands:
        if not c in s:
            return False
    return True


def _strip_grad_suffix_(name):
    pos = name.find(core.grad_var_suffix())
    return name[:pos] if pos != -1 else name


def _append_grad_suffix_(name):
    return name + core.grad_var_suffix()


def _append_backward_ops_(target,
                          block,
                          target_block,
                          no_grad_set,
                          callback=None):
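    """Create gradient op descs for all ops of `block` in reverse order and
    append them to `target_block`, recursing into sub-blocks. Returns a dict
    mapping each gradient variable name to the forward variable it belongs to.
    """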
    grad_op_descs = []
    grad_to_var = dict()
    program = block.program
    for each_op in reversed(block.ops):
        grad_sub_block_list = []
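        # If the op owns a sub-block (e.g. a control-flow op), generate the
        # gradient ops of that sub-block recursively into its own grad block.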
        if each_op.has_attr("sub_block"):
            sub_block_idx = each_op.block_attr("sub_block")
            sub_block = program.block(sub_block_idx)
            grad_sub_block = program.create_block(parent_idx=sub_block_idx)
            sub_grad_to_var = _append_backward_ops_(
                target, sub_block, grad_sub_block, no_grad_set, callback)
            grad_to_var = dict(grad_to_var, **sub_grad_to_var)
            grad_sub_block_list.append(grad_sub_block.desc)
        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
            each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
        grad_op_descs.append(grad_op_desc)
        grad_to_var = dict(grad_to_var, **op_grad_to_var)
    # grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...]
    # flatten grad_op_descs
    grad_op_descs = [op for sublist in grad_op_descs for op in sublist]

    pending_sum_ops = []
    var_rename_count = collections.defaultdict(int)
    var_inputs = collections.defaultdict(list)
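    # A gradient variable may be written by several grad ops. Each duplicate
    # output is renamed to "<name>@RENAME@<k>" and a "sum" op is recorded in
    # pending_sum_ops to accumulate the renamed copies back into <name>.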
    for idx, op_desc in enumerate(grad_op_descs):
        for var_name in op_desc.input_arg_names():
            if len(var_inputs[var_name]) > 1:
                pending_sum_ops.append((_create_op_desc_(
                    op_type="sum",
                    inputs={"X": var_inputs[var_name]},
                    outputs={"Out": [var_name]},
                    attrs={}), idx))
                var_inputs[var_name] = [var_name]
        for var_name in op_desc.output_arg_names():
            if var_name == core.empty_var_name() or len(var_inputs[
                    var_name]) == 0:
                # it's the first time we get the variable
                var_inputs[var_name] = [var_name]
            else:
                if len(var_inputs[var_name]) == 1:
                    new_name = var_name + "@RENAME@" + \
                        str(var_rename_count[var_name])
                    var_rename_count[var_name] = var_rename_count[var_name] + 1
                    # rename original var_name
                    var_inputs[var_name][0] = new_name
                    _rename_arg_(grad_op_descs, var_name, new_name, 0, idx)
                    _rename_arg_(pending_sum_ops, var_name, new_name)

                new_name = var_name + "@RENAME@" + \
                    str(var_rename_count[var_name])
                var_rename_count[var_name] = var_rename_count[var_name] + 1
                op_desc.rename_output(var_name, new_name)
                var_inputs[var_name].append(new_name)
    for var_name, inputs in var_inputs.iteritems():
        if len(inputs) > 1:
            pending_sum_ops.append((_create_op_desc_(
                op_type="sum",
                inputs={"X": inputs},
                outputs={"Out": [var_name]},
                attrs={}), len(grad_op_descs)))
    # sum_op descs are sorted according to their insert position
    for p in reversed(pending_sum_ops):
        grad_op_descs.insert(p[1], p[0])
    # Remove ops whose outputs are all in no_grad_set
    grad_op_descs = filter(
        lambda op_desc: not _is_all_in_set_(op_desc.output_arg_names(), no_grad_set[block.idx]),
        grad_op_descs)
    # Insert fill_zeros_like_op
    to_insert = []
    for idx, op_desc in enumerate(grad_op_descs):
        for arg in op_desc.input_arg_names():
            if arg in no_grad_set[block.idx]:
                to_insert.append((arg, idx))
    for ele in reversed(to_insert):
        arg = ele[0]
        fill_zeros_like_op = _create_op_desc_(
            "fill_zeros_like", {"X": [_strip_grad_suffix_(arg)]}, {"Y": [arg]},
            {})
        grad_op_descs.insert(ele[1], fill_zeros_like_op)

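    # For the root block, initialize the gradient of the target (the loss)
    # to a constant 1.0 via a fill_constant op.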
    if target_block.idx == 0:
        grad_target_name = _append_grad_suffix_(target.name)
        target_block.desc.var(grad_target_name.encode("ascii"))
        grad_op_descs.insert(
            0,
            _create_op_desc_(
                op_type="fill_constant",
                inputs={},
                outputs={"Out": [grad_target_name]},
                attrs={
                    "shape": [1],
                    "value": 1.0,
                    "dtype": core.DataType.FP32
                }))
    for op_desc in grad_op_descs:
        new_op_desc = target_block.desc.append_op()
        new_op_desc.copy_from(op_desc)

    return grad_to_var


def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
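    """Create the gradient variables referenced by the newly appended grad op
    descs (starting at `start_op_idx`), record the name and block of each
    forward variable's gradient in `grad_info_map`, and run type and shape
    inference on the new ops.
    """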
    for op_idx in range(start_op_idx, block.desc.op_size()):
        op_desc = block.desc.op(op_idx)
        if op_desc.has_attr("sub_block"):
            sub_block = block.program.block(op_desc.block_attr("sub_block"))
            _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map)
        new_vars = set()
        # create new gradient variables
        for grad_var_name in op_desc.output_arg_names():
            grad_var_name = grad_var_name.encode("ascii")
            if block.desc.has_var_recursive(
                    grad_var_name) or grad_var_name == core.empty_var_name():
                continue
            block.desc.var(grad_var_name)
            new_vars.add(grad_var_name)
            if grad_var_name not in grad_to_var:
                continue
            grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
        # infer_shape and infer_type
        op_desc.infer_var_type(block.desc)
        op_desc.infer_shape(block.desc)
        for arg in op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_(arg, block)


def append_backward(loss, parameter_list=None, no_grad_set=None):
    """
    Create and add gradient Operators in BlockDesc to compute
    gradients of `loss` for parameters in parameter_list

    :param loss: an variable generated by cost function.
    :type loss: Variable
    :param no_grad_set: variable that should not create gradient
    :type no_grad_set: set
    :param parameter_list: parameters that need to compute gradient and 
    update to optimize the lost.
    :type: list
    :return: list of (parameters, gradients) pair.
    :rtype: list[Variable]
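
    A minimal usage sketch (the forward network and the `avg_cost` variable
    are illustrative, not part of this module)::

        # ... build a forward network whose final scalar output is `avg_cost`
        params_grads = append_backward(loss=avg_cost)
        # each entry pairs a parameter Variable with its gradient Variable,
        # or with None when the gradient does not live in the loss block
        param_names = [p.name for p, g in params_grads]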
    """
    assert isinstance(loss, framework.Variable)

    program = loss.block.program
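    # When no no_grad_set is given, build one per block from the variables
    # marked with stop_gradient=True (stored as gradient variable names).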
    if no_grad_set is None:
        no_grad_set = dict()
        assert isinstance(program, framework.Program)
        for block in program.blocks:
            assert isinstance(block, framework.Block)
            block_no_grad_set = set()
            for var in block.vars.itervalues():
                assert isinstance(var, framework.Variable)
                if var.stop_gradient:
                    block_no_grad_set.add(_append_grad_suffix_(var.name))
            no_grad_set[block.idx] = block_no_grad_set

    grad_info_map = dict()
    root_block = program.block(0)

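    # Append all gradient op descs starting from the root block, then create
    # the gradient variables they reference and infer their types and shapes.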
    fwd_op_num = root_block.desc.op_size()
    current_block_idx = program.current_block_idx
    grad_to_var = _append_backward_ops_(loss, root_block, root_block,
                                        no_grad_set)
    _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
    program.current_block_idx = current_block_idx
    program.sync_with_cpp()

    if parameter_list is not None:
        parameters = parameter_list
    else:
        params = program.global_block().all_parameters()
        parameters = [param.name for param in params]
    params_and_grads = []
    for param in parameters:
        if param not in grad_info_map:
            raise ValueError("param %s is not in map" % param)
        grad_info = grad_info_map[param]
        grad_block = grad_info[1]
        if not grad_block.has_var(grad_info[0]):
            raise ValueError("grad block[{0}] did not have grad var {1}".format(
                grad_info[1], grad_info[0]))
        # Get the param var from the global block
        param_var = program.global_block().var(param)
        grad_var = grad_block.var(grad_info[0])
        if loss.block.has_var(grad_info[0]):
            params_and_grads.append((param_var, grad_var))
        else:
            params_and_grads.append((param_var, None))
    return params_and_grads