backward.py
from paddle.v2.fluid import framework
from . import core
import collections

__all__ = ['append_backward']


def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
                 end_idx=None):
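    """Rename the argument `old_name` to `new_name` in every op desc of
    `op_desc_list[begin_idx:end_idx]`; entries may be bare op descs or
    (op_desc, insert_idx) tuples, in which case the op desc is unwrapped."""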
    if begin_idx is None:
        begin_idx = 0
    if end_idx is None:
        end_idx = len(op_desc_list)
    for i in range(begin_idx, end_idx):
        op_desc = op_desc_list[i]
        if isinstance(op_desc, tuple):
            op_desc = op_desc[0]
        op_desc.rename_input(old_name, new_name)
        op_desc.rename_output(old_name, new_name)


def _create_op_desc_(op_type, inputs, outputs, attrs):
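    """Build a core.OpDesc of type `op_type` from dicts of input/output
    arguments and attributes; Block-valued attributes are stored through
    set_block_attr, everything else through set_attr."""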
    op_desc = core.OpDesc()
    op_desc.set_type(op_type)
    for para, args in inputs.iteritems():
        op_desc.set_input(para, args)
    for para, args in outputs.iteritems():
        op_desc.set_output(para, args)
    for name, val in attrs.iteritems():
        if isinstance(val, framework.Block):
            op_desc.set_block_attr(name, val.desc)
        else:
            op_desc.set_attr(name, val)
    return op_desc


def _infer_var_data_type_(var_name, block):
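    """Give the gradient variable `var_name` the same dtype as its forward
    counterpart if that variable exists in this block or an ancestor block;
    otherwise default to FP32."""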
    grad_var = block.desc.find_var(var_name.encode("ascii"))
    fwd_name = _strip_grad_suffix_(var_name.encode("ascii"))
    if block.desc.has_var_recursive(fwd_name):
        fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
        grad_var.set_dtype(fwd_var.dtype())
    else:
        grad_var.set_dtype(core.DataType.FP32)


def _is_all_in_set_(cands, s):
    for c in cands:
        if c not in s:
            return False
    return True


def _strip_grad_suffix_(name):
    return name[:name.find(core.grad_var_suffix())]


def _append_grad_suffix_(name):
    return name + core.grad_var_suffix()


def _append_backward_ops_(target,
                          block,
                          target_block,
                          no_grad_set,
                          callback=None):
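    """Create gradient op descs for all ops in `block` (recursing into
    sub-blocks), insert rename/sum ops so that gradients produced by several
    ops are accumulated correctly, honor `no_grad_set` by dropping or
    zero-filling gradients, and append the results to `target_block`.
    Returns a dict mapping each gradient variable name to the name of the
    forward variable it corresponds to."""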
    grad_op_descs = []
    grad_to_var = dict()
    program = block.program
    for each_op in reversed(block.ops):
        grad_sub_block_list = []
        if each_op.has_attr("sub_block"):
            sub_block_idx = each_op.block_attr("sub_block")
            sub_block = program.block(sub_block_idx)
            grad_sub_block = program.create_block(parent_idx=sub_block_idx)
            sub_grad_to_var = _append_backward_ops_(
                target, sub_block, grad_sub_block, no_grad_set, callback)
            grad_to_var = dict(grad_to_var, **sub_grad_to_var)
            grad_sub_block_list.append(grad_sub_block.desc)
        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
            each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
        grad_op_descs.append(grad_op_desc)
        grad_to_var = dict(grad_to_var, **op_grad_to_var)
    # grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...]
    # flatten grad_op_descs
    grad_op_descs = [op for sublist in grad_op_descs for op in sublist]

    pending_sum_ops = []
    var_rename_count = collections.defaultdict(int)
    var_inputs = collections.defaultdict(list)
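    # When a gradient variable is written by more than one grad op, each
    # producer gets a renamed output (<name>@RENAME@<i>) and a "sum" op is
    # queued in pending_sum_ops to accumulate the copies back into <name>.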
    for idx, op_desc in enumerate(grad_op_descs):
        for var_name in op_desc.input_arg_names():
            if len(var_inputs[var_name]) > 1:
                pending_sum_ops.append((_create_op_desc_(
                    op_type="sum",
                    inputs={"X": var_inputs[var_name]},
                    outputs={"Out": [var_name]},
                    attrs={}), idx))
                var_inputs[var_name] = [var_name]
        for var_name in op_desc.output_arg_names():
            if var_name in op_desc.input_arg_names():
                # in place operator
                continue
            if var_name == core.empty_var_name() or len(var_inputs[
                    var_name]) == 0:
                # it's the first time we get the variable
                var_inputs[var_name] = [var_name]
            else:
                if len(var_inputs[var_name]) == 1:
                    new_name = var_name + "@RENAME@" + \
                        str(var_rename_count[var_name])
                    var_rename_count[var_name] = var_rename_count[var_name] + 1
                    # rename original var_name
                    var_inputs[var_name][0] = new_name
                    _rename_arg_(grad_op_descs, var_name, new_name, 0, idx)
                    _rename_arg_(pending_sum_ops, var_name, new_name)

                new_name = var_name + "@RENAME@" + \
                    str(var_rename_count[var_name])
                var_rename_count[var_name] = var_rename_count[var_name] + 1
                op_desc.rename_output(var_name, new_name)
                var_inputs[var_name].append(new_name)
    for var_name, inputs in var_inputs.iteritems():
        if len(inputs) > 1:
            pending_sum_ops.append((_create_op_desc_(
                op_type="sum",
                inputs={"X": inputs},
                outputs={"Out": [var_name]},
                attrs={}), len(grad_op_descs)))
    # sum_op descs are sorted according to their insert position
    for p in reversed(pending_sum_ops):
        grad_op_descs.insert(p[1], p[0])
    # Remove ops whose outputs are all in no_grad_set
    grad_op_descs = filter(
        lambda op_desc: not _is_all_in_set_(op_desc.output_arg_names(), no_grad_set[block.idx]),
        grad_op_descs)
    # Insert fill_zeros_like_op
    to_insert = []
    for idx, op_desc in enumerate(grad_op_descs):
        for arg in op_desc.input_arg_names():
            if arg in no_grad_set[block.idx]:
                to_insert.append((arg, idx))
    for ele in reversed(to_insert):
        arg = ele[0]
        fill_zeros_like_op = _create_op_desc_(
            "fill_zeros_like", {"X": [_strip_grad_suffix_(arg)]}, {"Y": [arg]},
            {})
        grad_op_descs.insert(ele[1], fill_zeros_like_op)

    if target_block.idx == 0:
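        # In the root block, seed the backward pass: d(target)/d(target) is a
        # scalar 1.0 written by a fill_constant op into the target's grad var.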
        grad_target_name = _append_grad_suffix_(target.name)
        # target_block.desc.var(grad_target_name.encode("ascii"))
        grad_op_descs.insert(
            0,
            _create_op_desc_(
                op_type="fill_constant",
                inputs={},
                outputs={"Out": [grad_target_name]},
                attrs={
                    "shape": [1],
                    "value": 1.0,
                    "dtype": core.DataType.FP32
                }))
    for op_desc in grad_op_descs:
        new_op_desc = target_block.desc.append_op()
        new_op_desc.copy_from(op_desc)

    return grad_to_var


def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
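    """For every op appended from `start_op_idx` onward (recursing into
    sub-blocks), create the gradient variables it writes to, record in
    `grad_info_map` which block holds the gradient of each forward variable
    named in `grad_to_var`, and run var type/shape inference for the op."""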
    for op_idx in range(start_op_idx, block.desc.op_size()):
        op_desc = block.desc.op(op_idx)
        if op_desc.has_attr("sub_block"):
            sub_block = block.program.block(op_desc.block_attr("sub_block"))
            _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map)
        new_vars = set()
        # create new gradient variables
        for grad_var_name in op_desc.output_arg_names():
            grad_var_name = grad_var_name.encode("ascii")
            if block.desc.has_var_recursive(
                    grad_var_name) or grad_var_name == core.empty_var_name():
                continue
            block.desc.var(grad_var_name)
            new_vars.add(grad_var_name)
            if grad_var_name not in grad_to_var:
                continue
            grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
        # infer_shape and infer_type
        op_desc.infer_var_type(block.desc)
        op_desc.infer_shape(block.desc)
        for arg in op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_(arg, block)


def append_backward(loss, parameter_list=None, no_grad_set=None):
    """
    Create and add gradient Operators in BlockDesc to compute
    gradients of `loss` for parameters in parameter_list

    :param loss: an variable generated by cost function.
    :type loss: Variable
    :param no_grad_set: variable that should not create gradient
    :type no_grad_set: set
    :param parameter_list: parameters that need to compute gradient and 
    update to optimize the lost.
    :type: list
    :return: list of (parameters, gradients) pair.
    :rtype: list[Variable]
    """
    assert isinstance(loss, framework.Variable)

    program = loss.block.program
    if no_grad_set is None:
        no_grad_set = dict()
        assert isinstance(program, framework.Program)
        for block in program.blocks:
            assert isinstance(block, framework.Block)
            block_no_grad_set = set()
            for var in block.vars.itervalues():
                assert isinstance(var, framework.Variable)
                if var.stop_gradient:
                    block_no_grad_set.add(_append_grad_suffix_(var.name))
            no_grad_set[block.idx] = block_no_grad_set
    else:
        # FIX ME
        no_grad_set = {0: no_grad_set}

    grad_info_map = dict()
    root_block = program.block(0)

    fwd_op_num = root_block.desc.op_size()
    current_block_idx = program.current_block_idx
    grad_to_var = _append_backward_ops_(loss, root_block, root_block,
                                        no_grad_set)
    _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
    program.current_block_idx = current_block_idx
    program.sync_with_cpp()

    if parameter_list is not None:
        parameters = parameter_list
    else:
        params = program.global_block().all_parameters()
        parameters = [param.name for param in params]
    params_and_grads = []
    for param in parameters:
        if param not in grad_info_map:
            raise ValueError("param %s is not in map" % param)
        grad_info = grad_info_map[param]
        grad_block = grad_info[1]
        if not grad_block.has_var(grad_info[0]):
            raise ValueError("grad block[{0}] did not have grad var {1}".format(
                grad_info[1], grad_info[0]))
        # Get the param var from the global block
        param_var = program.global_block().var(param)
        grad_var = grad_block.var(grad_info[0])
        if loss.block.has_var(grad_info[0]):
            params_and_grads.append((param_var, grad_var))
        else:
            params_and_grads.append((param_var, None))
    return params_and_grads