from paddle.v2.fluid import framework as framework
from . import core
import collections

__all__ = ['append_backward_ops']


def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
                 end_idx=None):
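    # Rename every occurrence of `old_name` to `new_name` in the inputs and
    # outputs of op_desc_list[begin_idx:end_idx].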
    if begin_idx is None:
        begin_idx = 0
    if end_idx is None:
        end_idx = len(op_desc_list)
    for i in range(begin_idx, end_idx):
        op_desc_list[i].rename_input(old_name, new_name)
        op_desc_list[i].rename_output(old_name, new_name)


def _create_op_desc_(op_type, inputs, outputs, attrs):
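    # Build an OpDesc from the given op type, input/output argument lists and
    # attribute values.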
    op_desc = core.OpDesc()
    op_desc.set_type(op_type)
    for para, args in inputs.iteritems():
        op_desc.set_input(para, args)
    for para, args in outputs.iteritems():
        op_desc.set_output(para, args)
    for name, val in attrs.iteritems():
        if isinstance(val, framework.Block):
            op_desc.set_block_attr(name, val.desc)
        else:
            op_desc.set_attr(name, val)
    return op_desc


def _infer_var_data_type_(var_name, block):
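    # Set a gradient variable's dtype to that of its forward variable, falling
    # back to FP32 when the forward variable cannot be found.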
    grad_var = block.desc.find_var(var_name.encode("ascii"))
    fwd_name = _strip_grad_suffix_(var_name.encode("ascii"))
    if block.desc.has_var_recursive(fwd_name):
        fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
        grad_var.set_dtype(fwd_var.dtype())
    else:
        grad_var.set_dtype(core.DataType.FP32)


def _is_all_in_set_(cands, s):
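    # Return True iff every name in `cands` appears in the set `s`.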
    for c in cands:
        if c not in s:
            return False
    return True


def _strip_grad_suffix_(name):
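    # Strip the gradient-variable suffix (core.grad_var_suffix()) from a name.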
    return name[:name.find(core.grad_var_suffix())]


def _append_grad_suffix_(name):
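    # Append the gradient-variable suffix (core.grad_var_suffix()) to a name.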
    return name + core.grad_var_suffix()


def _backward_impl_(target,
                    block,
                    target_block,
                    no_grad_set,
                    grad_info_map,
                    callback=None):
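    # Generate the backward pass of `block` and append the gradient op descs
    # and gradient variables to `target_block`. `grad_info_map` maps each
    # forward variable name to (gradient variable name, block holding it).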
    grad_op_descs = []
    grad_to_var = dict()
    program = block.program
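    # Walk the forward ops in reverse order and collect their gradient op descs.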
    for each_op in reversed(block.ops):
        grad_sub_block_list = []
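        # Ops holding a sub-block (e.g. control-flow ops) need the backward
        # pass of that sub-block built recursively first.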
        if each_op.has_attr("sub_block"):
            sub_block_idx = each_op.block_attr("sub_block")
            sub_block = program.block(sub_block_idx)
            original_block_idx = program.current_block_idx
            grad_sub_block = program.create_block(parent_idx=sub_block_idx)
            program.current_block_idx = original_block_idx
            _backward_impl_(target, sub_block, grad_sub_block, no_grad_set,
                            grad_info_map, callback)
            grad_sub_block_list.append(grad_sub_block.desc)
        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
            each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
        grad_op_descs.append(grad_op_desc)
        grad_to_var = dict(grad_to_var, **op_grad_to_var)
    # Each entry of grad_op_descs is a list like [op1_g1, op1_g2]; flatten it
    # into a single list of gradient op descs.
    grad_op_descs = [op for sublist in grad_op_descs for op in sublist]

    pending_sum_ops = []
    var_rename_count = collections.defaultdict(int)
    var_inputs = collections.defaultdict(list)
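    # A gradient variable may be written by more than one backward op. Each
    # extra write is renamed to <name>@RENAME@<n>, and `sum` ops are queued in
    # pending_sum_ops to accumulate the renamed copies back into the variable.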
    for idx, op_desc in enumerate(grad_op_descs):
        for var_name in op_desc.input_arg_names():
            if len(var_inputs[var_name]) > 1:
                pending_sum_ops.append((_create_op_desc_(
                    op_type="sum",
                    inputs={"X": var_inputs[var_name]},
                    outputs={"Out": [var_name]},
                    attrs={}), idx))
                var_inputs[var_name] = [var_name]
        for var_name in op_desc.output_arg_names():
            if var_name == core.empty_var_name() or len(var_inputs[
                    var_name]) == 0:
                # it's the first time we get the variable
                var_inputs[var_name] = [var_name]
            else:
                if len(var_inputs[var_name]) == 1:
                    new_name = var_name + "@RENAME@" + \
                        str(var_rename_count[var_name])
                    var_rename_count[var_name] = var_rename_count[var_name] + 1
                    # rename original var_name
                    var_inputs[var_name][0] = new_name
                    _rename_arg_(grad_op_descs, var_name, new_name, 0, idx)
                    _rename_arg_(pending_sum_ops, var_name, new_name)

                new_name = var_name + "@RENAME@" + \
                    str(var_rename_count[var_name])
                var_rename_count[var_name] = var_rename_count[var_name] + 1
                op_desc.rename_output(var_name, new_name)
                var_inputs[var_name].append(new_name)
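    # Accumulate any gradients that are still split across renamed copies.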
    for var_name, inputs in var_inputs.iteritems():
        if len(inputs) > 1:
            pending_sum_ops.append((_create_op_desc_(
                op_type="sum",
                inputs={"X": inputs},
                outputs={"Out": [var_name]},
                attrs={}), len(grad_op_descs)))
    # sum_op descs are sorted according to their insert position
    for p in reversed(pending_sum_ops):
        grad_op_descs.insert(p[1], p[0])
    # Remove ops whose outputs are all in no_grad_set
    grad_op_descs = filter(
        lambda op_desc: not _is_all_in_set_(op_desc.output_arg_names(), no_grad_set[block.idx]),
        grad_op_descs)
    # Insert fill_zeros_like_op
    to_insert = []
    for idx, op_desc in enumerate(grad_op_descs):
        for arg in op_desc.input_arg_names():
            if arg in no_grad_set[block.idx]:
                to_insert.append((arg, idx))
    for ele in reversed(to_insert):
        arg = ele[0]
        fill_zeros_like_op = _create_op_desc_(
            "fill_zeros_like", {"X": [_strip_grad_suffix_(arg)]}, {"Y": [arg]},
            {})
        grad_op_descs.insert(ele[1], fill_zeros_like_op)
    # create new gradient variables in the target block desc
    new_vars = set()
    for op_desc in grad_op_descs:
        for grad_var_name in op_desc.output_arg_names():
            grad_var_name = grad_var_name.encode("ascii")
            if target_block.desc.has_var_recursive(
                    grad_var_name) or grad_var_name == core.empty_var_name():
                continue
            target_block.desc.var(grad_var_name)
            new_vars.add(grad_var_name)
            if grad_var_name not in grad_to_var:
                continue
            grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name,
                                                         target_block)
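    # In the root block, create the gradient of the target itself and
    # initialize it to 1.0 with a fill_constant op.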
    if target_block.idx == 0:
        grad_target_name = _append_grad_suffix_(target.name)
        target_block.desc.var(grad_target_name.encode("ascii"))
        grad_op_descs.insert(
            0,
            _create_op_desc_(
                op_type="fill_constant",
                inputs={},
                outputs={"Out": [grad_target_name]},
                attrs={
                    "shape": [1],
                    "value": 1.0,
                    "dtype": core.DataType.FP32
                }))
    # insert backward operators to target_block
    for op_desc in grad_op_descs:
        op_desc.infer_var_type(target_block.desc)
        op_desc.infer_shape(target_block.desc)
        for arg in op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_(arg, target_block)
        new_op_desc = target_block.desc.append_op()
        new_op_desc.copy_from(op_desc)

    target_block.sync_with_cpp()


def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
    """
    Create and add gradient Operators in BlockDesc to compute
    gradients of `loss` for the parameters in parameter_list.

    :param loss: a variable generated by the cost function.
    :type loss: Variable
    :param no_grad_set: variables for which no gradient should be created.
    :type no_grad_set: set
    :param parameter_list: parameters whose gradients are computed and
        used to optimize the loss.
    :type parameter_list: list
    :return: list of (parameter, gradient) pairs.
    :rtype: list[(Variable, Variable)]
    """
    assert isinstance(loss, framework.Variable)

    if no_grad_set is None:
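        # By default, collect per block the gradient names of every variable
        # marked with stop_gradient; these gradients will not be computed.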
        no_grad_set = dict()
        program = loss.block.program
        assert isinstance(program, framework.Program)
        for block in program.blocks:
            assert isinstance(block, framework.Block)
            block_no_grad_set = set()
            for var in block.vars.itervalues():
                assert isinstance(var, framework.Variable)
                if var.stop_gradient:
                    block_no_grad_set.add(_append_grad_suffix_(var.name))
            no_grad_set[block.idx] = block_no_grad_set

    grad_info_map = dict()
    root_block = loss.block.program.block(0)

    _backward_impl_(loss, root_block, root_block, no_grad_set, grad_info_map)

    if parameter_list is not None:
        parameters = parameter_list
    else:
        params = loss.block.program.global_block().all_parameters()
        parameters = [param.name for param in params]
    params_and_grads = []
    for param in parameters:
        if param not in grad_info_map:
            raise ValueError("param %s is not in map" % param)
        grad_info = grad_info_map[param]
        grad_block = grad_info[1]
        if not grad_block.has_var(grad_info[0]):
            raise ValueError("grad block[{0}] did not have grad var {1}".format(
                grad_info[1], grad_info[0]))
        # Get the param var from the global block
        param_var = loss.block.program.global_block().var(param)
        grad_var = grad_block.var(grad_info[0])
        if loss.block.has_var(grad_info[0]):
            params_and_grads.append((param_var, grad_var))
        else:
            params_and_grads.append((param_var, None))
    return params_and_grads