from paddle.v2.fluid import framework as framework
from . import core
import collections

__all__ = ['append_backward']


def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
                 end_idx=None):
    if begin_idx is None:
        begin_idx = 0
    if end_idx is None:
        end_idx = len(op_desc_list)
    for i in range(begin_idx, end_idx):
        op_desc = op_desc_list[i]
        if isinstance(op_desc, tuple):
            op_desc = op_desc[0]
        op_desc.rename_input(old_name, new_name)
        op_desc.rename_output(old_name, new_name)


def _create_op_desc_(op_type, inputs, outputs, attrs):
    op_desc = core.OpDesc()
    op_desc.set_type(op_type)
    for para, args in inputs.iteritems():
        op_desc.set_input(para, args)
    for para, args in outputs.iteritems():
        op_desc.set_output(para, args)
    for name, val in attrs.iteritems():
        if isinstance(val, framework.Block):
            op_desc.set_block_attr(name, val.desc)
        else:
            op_desc.set_attr(name, val)
    return op_desc


def _infer_var_data_type_(var_name, block):
    grad_var = block.desc.find_var(var_name.encode("ascii"))
    fwd_name = _strip_grad_suffix_(var_name.encode("ascii"))
    if block.desc.has_var_recursive(fwd_name):
        fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
        grad_var.set_dtype(fwd_var.dtype())
    else:
        grad_var.set_dtype(core.DataType.FP32)


def _all_in_set_(cands, s):
    for c in cands:
        if c not in s:
            return False
    return True


def _strip_grad_suffix_(name):
    pos = name.find(core.grad_var_suffix())
    return name[:pos] if pos != -1 else name


def _append_grad_suffix_(name):
    return name + core.grad_var_suffix()


def _addup_repetitive_outputs_(op_descs):
    # In the backward pass, a variable may be the output of more than one op.
    # In that case, the variable should be the accumulation of all those
    # outputs. We implement the accumulation by inserting `sum_op`s.
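    #
    # Illustrative sketch (hypothetical grad var `x@GRAD`, not taken from a
    # real program): when two grad ops both write `x@GRAD`, the writes are
    # renamed and a sum_op is inserted before the first later op that reads
    # `x@GRAD` (or at the very end if no op reads it):
    #
    #     op_a: Out = ["x@GRAD"]   ->   op_a: Out = ["x@GRAD@RENAME@0"]
    #     op_b: Out = ["x@GRAD"]   ->   op_b: Out = ["x@GRAD@RENAME@1"]
    #                                   sum:  X   = ["x@GRAD@RENAME@0",
    #                                                "x@GRAD@RENAME@1"]
    #                                         Out = ["x@GRAD"]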
    pending_sum_ops = []
    var_rename_count = collections.defaultdict(int)
    renamed_vars = collections.defaultdict(list)
    for idx, op_desc in enumerate(op_descs):
        for var_name in op_desc.input_arg_names():
            if len(renamed_vars[var_name]) > 1:
                pending_sum_ops.append(
                    (_create_op_desc_("sum", {"X": renamed_vars[var_name]},
                                      {"Out": [var_name]}, {}), idx))
                renamed_vars[var_name] = [var_name]
        for var_name in op_desc.output_arg_names():
            if var_name == core.empty_var_name(
            ) or var_name in op_desc.input_arg_names():
                # empty variable or inplace op
                continue
            if len(renamed_vars[var_name]) == 0:
                # it's the first time we see this variable
                renamed_vars[var_name] = [var_name]
            else:
                if len(renamed_vars[var_name]) == 1:
                    new_name = var_name + "@RENAME@" + \
                        str(var_rename_count[var_name])
                    var_rename_count[var_name] += 1
                    # rename original var_name
                    renamed_vars[var_name][0] = new_name
                    _rename_arg_(op_descs, var_name, new_name, 0, idx)
                    _rename_arg_(pending_sum_ops, var_name, new_name)

                new_name = var_name + "@RENAME@" + \
                    str(var_rename_count[var_name])
                var_rename_count[var_name] += 1
                op_desc.rename_output(var_name, new_name)
                renamed_vars[var_name].append(new_name)
    for var_name, inputs in renamed_vars.iteritems():
        if len(inputs) > 1:
            pending_sum_ops.append((_create_op_desc_(
                "sum", {"X": inputs}, {"Out": [var_name]}, {}), len(op_descs)))
    # pending_sum_ops is already ordered by insertion index; insert in reverse
    # so that earlier insertion indices stay valid
    for p in reversed(pending_sum_ops):
        op_descs.insert(p[1], p[0])

    return op_descs


def _remove_no_grad_branch_(op_descs, no_grad_set):
    # Remove ops whose outputs are all in no_grad_set
    op_descs = filter(
        lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set),
        op_descs)
    # Insert fill_zeros_like ops: when a remaining op consumes a gradient that
    # is in no_grad_set, feed it zeros shaped like the corresponding forward var
    to_insert = []
    for idx, op_desc in enumerate(op_descs):
        for arg in op_desc.input_arg_names():
            if core.grad_var_suffix() in arg and arg in no_grad_set:
                to_insert.append((_create_op_desc_("fill_zeros_like", {
                    "X": [_strip_grad_suffix_(arg)]
                }, {"Y": [arg]}, {}), idx))

    map(lambda p: op_descs.insert(p[1], p[0]), reversed(to_insert))

    return op_descs


def _append_backward_ops_(target,
                          block,
                          target_block,
                          no_grad_dict,
                          grad_to_var,
                          callback=None):
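    # Walk the ops of `block` in reverse, create their grad op descs (recursing
    # into any "sub_block" attribute first), merge repeated gradient outputs,
    # drop branches whose outputs are all no-grad, and append the surviving
    # grad op descs to `target_block`.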
    grad_op_descs = []
    program = block.program
    for op in reversed(block.ops):
        grad_sub_block_list = []
        # If the op has its own sub-block, deal with the sub-block first
        if op.has_attr("sub_block"):
            sub_block = program.block(op.block_attr("sub_block"))
            grad_sub_block = program.create_block(parent_idx=sub_block.idx)
            _append_backward_ops_(target, sub_block, grad_sub_block,
                                  no_grad_dict, grad_to_var, callback)
            grad_sub_block_list.append(grad_sub_block.desc)

        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
            op.desc, no_grad_dict[block.idx], grad_sub_block_list)
        grad_op_descs.extend(grad_op_desc)
        grad_to_var.update(op_grad_to_var)

    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)

    grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
                                            no_grad_dict[block.idx])

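    # For the outermost (root) block, seed the backward pass: create the
    # gradient of `target` with a fill_constant op of shape [1] and value 1.0.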
    if target_block.idx == 0:
        grad_op_descs.insert(
            0,
            _create_op_desc_("fill_constant", {}, {
                "Out": [_append_grad_suffix_(target.name)]
            }, {"shape": [1],
                "value": 1.0,
                "dtype": target.dtype}))
    # append each op_desc in grad_op_descs to target_block
    for op_desc in grad_op_descs:
        new_op_desc = target_block.desc.append_op()
        new_op_desc.copy_from(op_desc)


def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
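    # For every op appended after `start_op_idx` (recursing into sub-blocks),
    # declare any gradient variables it produces that do not exist yet, record
    # forward var name -> (grad var name, block) in `grad_info_map`, and run
    # var type / shape inference on the op.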
    for op_idx in range(start_op_idx, block.desc.op_size()):
        op_desc = block.desc.op(op_idx)
        if op_desc.has_attr("sub_block"):
            sub_block = block.program.block(op_desc.block_attr("sub_block"))
            _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map)
        new_vars = set()
        # create new gradient variables
        for grad_var_name in op_desc.output_arg_names():
            grad_var_name = grad_var_name.encode("ascii")
            if block.desc.has_var_recursive(
                    grad_var_name) or grad_var_name == core.empty_var_name():
                continue
            block.desc.var(grad_var_name)
            new_vars.add(grad_var_name)
            if grad_var_name not in grad_to_var:
                continue
            grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
        # infer_shape and infer_type
        op_desc.infer_var_type(block.desc)
        op_desc.infer_shape(block.desc)
        for arg in op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_(arg, block)


def append_backward(loss, parameter_list=None, no_grad_dict=None):
    """
    Create and add gradient Operators in BlockDesc to compute
    gradients of `loss` for parameters in parameter_list

    :param loss: an variable generated by cost function.
    :type loss: Variable
F
fengjiayi 已提交
205 206
    :param no_grad_dict: variable that should not create gradient
    :type no_grad_dict: set
207 208 209 210 211 212 213
    :param parameter_list: parameters that need to compute gradient and 
    update to optimize the lost.
    :type: list
    :return: list of (parameters, gradients) pair.
    :rtype: list[Variable]
    """
    assert isinstance(loss, framework.Variable)

    program = loss.block.program
    if no_grad_dict is None:
        no_grad_dict = dict()
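        # Default behaviour: for every block, collect the gradient names of
        # all variables marked stop_gradient into that block's no-grad set.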
        assert isinstance(program, framework.Program)
        for block in program.blocks:
            assert isinstance(block, framework.Block)
            block_no_grad_set = set()
            for var in block.vars.itervalues():
                assert isinstance(var, framework.Variable)
                if var.stop_gradient:
                    block_no_grad_set.add(_append_grad_suffix_(var.name))
            no_grad_dict[block.idx] = block_no_grad_set
    elif isinstance(no_grad_dict, set):
        no_grad_dict = {0: no_grad_dict}

    grad_info_map = dict()
    root_block = program.block(0)

    fwd_op_num = root_block.desc.op_size()
    current_block_idx = program.current_block_idx
    grad_to_var = dict()

    _append_backward_ops_(loss, root_block, root_block, no_grad_dict,
                          grad_to_var)
    _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)

    program.current_block_idx = current_block_idx
    program.sync_with_cpp()

    if parameter_list is not None:
        parameters = parameter_list
    else:
        params = program.global_block().all_parameters()
        parameters = [param.name for param in params]
    params_and_grads = []
    for param in parameters:
        if param not in grad_info_map:
            raise ValueError("param %s is not in map" % param)
        grad_info = grad_info_map[param]
        grad_block = grad_info[1]
        if not grad_block.has_var(grad_info[0]):
            raise ValueError("grad block[{0}] did not have grad var {1}".format(
                grad_info[1], grad_info[0]))
        # Get the param var from the global block
        param_var = program.global_block().var(param)
        grad_var = grad_block.var(grad_info[0])
        if loss.block.has_var(grad_info[0]):
            params_and_grads.append((param_var, grad_var))
        else:
            params_and_grads.append((param_var, None))
    return params_and_grads