from paddle.v2.fluid import framework as framework
from . import core
import collections

__all__ = ['append_backward']


def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
                 end_idx=None):
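    """Rename all occurrences of `old_name` to `new_name` in the inputs and
    outputs of `op_desc_list[begin_idx:end_idx]`. Elements may be bare op
    descs or (op_desc, insert_idx) tuples (as in pending_sum_ops)."""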
    if begin_idx is None:
        begin_idx = 0
    if end_idx is None:
        end_idx = len(op_desc_list)
    for i in range(begin_idx, end_idx):
        op_desc = op_desc_list[i]
        if isinstance(op_desc, tuple):
            op_desc = op_desc[0]
        op_desc.rename_input(old_name, new_name)
        op_desc.rename_output(old_name, new_name)


def _create_op_desc_(op_type, inputs, outputs, attrs):
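    """Create an OpDesc of type `op_type` with the given input/output
    arguments and attributes; Block attributes are stored via their desc."""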
    op_desc = core.OpDesc()
    op_desc.set_type(op_type)
    for para, args in inputs.iteritems():
        op_desc.set_input(para, args)
    for para, args in outputs.iteritems():
        op_desc.set_output(para, args)
    for name, val in attrs.iteritems():
        if isinstance(val, framework.Block):
            op_desc.set_block_attr(name, val.desc)
        else:
            op_desc.set_attr(name, val)
    return op_desc


def _infer_var_data_type_(var_name, block):
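    """Set the dtype of gradient variable `var_name` to the dtype of its
    corresponding forward variable, falling back to FP32 when the forward
    variable cannot be found in the block hierarchy."""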
    grad_var = block.desc.find_var(var_name.encode("ascii"))
    fwd_name = _strip_grad_suffix_(var_name.encode("ascii"))
    if block.desc.has_var_recursive(fwd_name):
        fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
        grad_var.set_dtype(fwd_var.dtype())
    else:
        grad_var.set_dtype(core.DataType.FP32)


def _is_all_in_set_(cands, s):
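    """Return True if every element of `cands` is contained in the set `s`."""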
    for c in cands:
        if c not in s:
            return False
    return True


def _strip_grad_suffix_(name):
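    """Strip everything from the gradient suffix (core.grad_var_suffix())
    onward, recovering the forward variable name."""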
    pos = name.find(core.grad_var_suffix())
    return name[:pos] if pos != -1 else name


def _append_grad_suffix_(name):
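    """Append the gradient suffix (core.grad_var_suffix()) to a variable name."""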
    return name + core.grad_var_suffix()


def _append_backward_ops_(target,
                          block,
                          target_block,
                          no_grad_set,
                          callback=None):
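    """Create grad op descs for the ops of `block` in reverse order and append
    them to `target_block`.

    Ops carrying a "sub_block" attribute are handled recursively in a newly
    created grad sub-block. Gradient variables written by more than one grad
    op are renamed and accumulated with `sum` ops, grad ops whose outputs are
    all in `no_grad_set` are dropped, and `fill_zeros_like` ops are inserted
    for gradient inputs listed in `no_grad_set`. For the root block, a
    `fill_constant` op is prepended to initialize the gradient of `target`
    to 1.0.

    Returns a dict mapping each gradient variable name to the name of its
    forward variable.
    """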
    grad_op_descs = []
    grad_to_var = dict()
    program = block.program
    for each_op in reversed(block.ops):
        grad_sub_block_list = []
        if each_op.has_attr("sub_block"):
            sub_block_idx = each_op.block_attr("sub_block")
            sub_block = program.block(sub_block_idx)
            grad_sub_block = program.create_block(parent_idx=sub_block_idx)
            sub_grad_to_var = _append_backward_ops_(
                target, sub_block, grad_sub_block, no_grad_set, callback)
            grad_to_var = dict(grad_to_var, **sub_grad_to_var)
            grad_sub_block_list.append(grad_sub_block.desc)
        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
            each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
        grad_op_descs.append(grad_op_desc)
        grad_to_var = dict(grad_to_var, **op_grad_to_var)
    # grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...]
    # flatten grad_op_descs
    grad_op_descs = [op for sublist in grad_op_descs for op in sublist]

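    # A gradient variable may be written by several grad ops. Every extra
    # write is renamed to "<name>@RENAME@<i>", and a `sum` op is queued in
    # pending_sum_ops (with its insert position) to accumulate the renamed
    # copies back into the original gradient variable.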
    pending_sum_ops = []
    var_rename_count = collections.defaultdict(int)
    var_inputs = collections.defaultdict(list)
    for idx, op_desc in enumerate(grad_op_descs):
        for var_name in op_desc.input_arg_names():
            if len(var_inputs[var_name]) > 1:
                pending_sum_ops.append((_create_op_desc_(
                    op_type="sum",
                    inputs={"X": var_inputs[var_name]},
                    outputs={"Out": [var_name]},
                    attrs={}), idx))
                var_inputs[var_name] = [var_name]
        for var_name in op_desc.output_arg_names():
            if var_name in op_desc.input_arg_names():
                # in place operator
                continue
            if var_name == core.empty_var_name() or len(var_inputs[
                    var_name]) == 0:
                # it's the first time we get the variable
                var_inputs[var_name] = [var_name]
            else:
                if len(var_inputs[var_name]) == 1:
                    new_name = var_name + "@RENAME@" + \
                        str(var_rename_count[var_name])
                    var_rename_count[var_name] = var_rename_count[var_name] + 1
                    # rename original var_name
                    var_inputs[var_name][0] = new_name
                    _rename_arg_(grad_op_descs, var_name, new_name, 0, idx)
                    _rename_arg_(pending_sum_ops, var_name, new_name)

                new_name = var_name + "@RENAME@" + \
                    str(var_rename_count[var_name])
                var_rename_count[var_name] = var_rename_count[var_name] + 1
                op_desc.rename_output(var_name, new_name)
                var_inputs[var_name].append(new_name)
    for var_name, inputs in var_inputs.iteritems():
        if len(inputs) > 1:
            pending_sum_ops.append((_create_op_desc_(
                op_type="sum",
                inputs={"X": inputs},
                outputs={"Out": [var_name]},
                attrs={}), len(grad_op_descs)))
    # sum_op descs are sorted according to their insert position
    for p in reversed(pending_sum_ops):
        grad_op_descs.insert(p[1], p[0])
    # Remove ops whose outputs are all in no_grad_set
    grad_op_descs = filter(
        lambda op_desc: not _is_all_in_set_(op_desc.output_arg_names(), no_grad_set[block.idx]),
        grad_op_descs)
    # Insert fill_zeros_like_op
    to_insert = []
    for idx, op_desc in enumerate(grad_op_descs):
        for arg in op_desc.input_arg_names():
            if core.grad_var_suffix() in arg and arg in no_grad_set[block.idx]:
                to_insert.append((arg, idx))
    for ele in reversed(to_insert):
        arg = ele[0]
        fill_zeros_like_op = _create_op_desc_(
            "fill_zeros_like", {"X": [_strip_grad_suffix_(arg)]}, {"Y": [arg]},
            {})
        grad_op_descs.insert(ele[1], fill_zeros_like_op)

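    # For the root block, prepend a fill_constant op that initializes the
    # gradient of the backward target (the loss) to 1.0.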
    if target_block.idx == 0:
        grad_target_name = _append_grad_suffix_(target.name)
        grad_op_descs.insert(
            0,
            _create_op_desc_(
                op_type="fill_constant",
                inputs={},
                outputs={"Out": [grad_target_name]},
                attrs={
                    "shape": [1],
                    "value": 1.0,
                    "dtype": core.DataType.FP32
                }))
    for op_desc in grad_op_descs:
        new_op_desc = target_block.desc.append_op()
        new_op_desc.copy_from(op_desc)

    return grad_to_var


def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
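    """Create the gradient variables written by the ops appended from
    `start_op_idx` on (recursing into sub-blocks), record them in
    `grad_info_map` as forward_var_name -> (grad_var_name, block), and run
    var type / shape / dtype inference for the new ops and variables."""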
    for op_idx in range(start_op_idx, block.desc.op_size()):
        op_desc = block.desc.op(op_idx)
        if op_desc.has_attr("sub_block"):
            sub_block = block.program.block(op_desc.block_attr("sub_block"))
            _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map)
        new_vars = set()
        # create new gradient variables
        for grad_var_name in op_desc.output_arg_names():
            grad_var_name = grad_var_name.encode("ascii")
            if block.desc.has_var_recursive(
                    grad_var_name) or grad_var_name == core.empty_var_name():
                continue
            block.desc.var(grad_var_name)
            new_vars.add(grad_var_name)
            if grad_var_name not in grad_to_var:
                continue
            grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
        # infer_shape and infer_type
        op_desc.infer_var_type(block.desc)
        op_desc.infer_shape(block.desc)
        for arg in op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_(arg, block)


def append_backward(loss, parameter_list=None, no_grad_set=None):
    """
    Create and add gradient operators in BlockDesc to compute
    gradients of `loss` for the parameters in `parameter_list`.

    :param loss: a variable generated by the cost function.
    :type loss: Variable
    :param parameter_list: parameters whose gradients are needed to
        optimize the loss; defaults to all parameters in the program.
    :type parameter_list: list
    :param no_grad_set: variables for which no gradient should be created.
    :type no_grad_set: set
    :return: list of (parameter, gradient) pairs.
    :rtype: list[(Variable, Variable)]
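
    A minimal usage sketch; the `fluid.layers` calls below are illustrative
    and assume the usual fluid layer API::

        import paddle.v2.fluid as fluid

        # build an illustrative forward network (layer calls are assumptions)
        x = fluid.layers.data(name='x', shape=[13], dtype='float32')
        y = fluid.layers.data(name='y', shape=[1], dtype='float32')
        y_predict = fluid.layers.fc(input=x, size=1)
        cost = fluid.layers.square_error_cost(input=y_predict, label=y)
        avg_cost = fluid.layers.mean(x=cost)

        # append backward ops and collect (parameter, gradient) pairs
        params_grads = fluid.backward.append_backward(loss=avg_cost)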
    """
    assert isinstance(loss, framework.Variable)

    program = loss.block.program
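    # When no no_grad_set is given, build one per block: every variable with
    # stop_gradient=True contributes its gradient name to that block's set.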
    if no_grad_set is None:
        no_grad_set = dict()
        assert isinstance(program, framework.Program)
        for block in program.blocks:
            assert isinstance(block, framework.Block)
            block_no_grad_set = set()
            for var in block.vars.itervalues():
                assert isinstance(var, framework.Variable)
                if var.stop_gradient:
                    block_no_grad_set.add(_append_grad_suffix_(var.name))
            no_grad_set[block.idx] = block_no_grad_set
    else:
        # FIX ME
        no_grad_set = {0: no_grad_set}

    grad_info_map = dict()
    root_block = program.block(0)

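    # Append grad op descs to the root block (recursing into sub-blocks), then
    # create the gradient variables and run type/shape inference for them.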
    fwd_op_num = root_block.desc.op_size()
    current_block_idx = program.current_block_idx
    grad_to_var = _append_backward_ops_(loss, root_block, root_block,
                                        no_grad_set)
    _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
    program.current_block_idx = current_block_idx
    program.sync_with_cpp()

    if parameter_list is not None:
        parameters = parameter_list
    else:
        params = program.global_block().all_parameters()
        parameters = [param.name for param in params]
    params_and_grads = []
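    # Pair every parameter with its gradient variable; if the gradient does
    # not live in the loss block (e.g. it was created in a sub-block), pair
    # the parameter with None instead.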
    for param in parameters:
        if param not in grad_info_map:
            raise ValueError("param %s is not in map" % param)
        grad_info = grad_info_map[param]
        grad_block = grad_info[1]
        if not grad_block.has_var(grad_info[0]):
            raise ValueError("grad block[{0}] did not have grad var {1}".format(
                grad_info[1], grad_info[0]))
        # Get the param var from the global block
        param_var = program.global_block().var(param)
        grad_var = grad_block.var(grad_info[0])
        if loss.block.has_var(grad_info[0]):
            params_and_grads.append((param_var, grad_var))
        else:
            params_and_grads.append((param_var, None))
    return params_and_grads