backward.py
from paddle.v2.fluid import framework as framework
from . import core
import collections

__all__ = ['append_backward']


def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
                 end_idx=None):
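    """
    Rename every occurrence of `old_name` to `new_name` in the inputs and
    outputs of op_desc_list[begin_idx:end_idx]. List items may be raw op
    descs or (op_desc, insert_idx) tuples (e.g. pending sum ops).
    """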
    if begin_idx is None:
        begin_idx = 0
    if end_idx is None:
        end_idx = len(op_desc_list)
    for i in range(begin_idx, end_idx):
        op_desc = op_desc_list[i]
        if isinstance(op_desc, tuple):
            op_desc = op_desc[0]
        op_desc.rename_input(old_name, new_name)
        op_desc.rename_output(old_name, new_name)


def _create_op_desc_(op_type, inputs, outputs, attrs):
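    """Create an OpDesc with the given type, inputs, outputs and attributes."""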
    op_desc = core.OpDesc()
    op_desc.set_type(op_type)
    for para, args in inputs.iteritems():
        op_desc.set_input(para, args)
    for para, args in outputs.iteritems():
        op_desc.set_output(para, args)
    for name, val in attrs.iteritems():
        if isinstance(val, framework.Block):
            op_desc.set_block_attr(name, val.desc)
        else:
            op_desc.set_attr(name, val)
    return op_desc


def _infer_var_data_type_(var_name, block):
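    """
    Set the dtype of the gradient variable `var_name` to match its forward
    counterpart, or to FP32 if the forward variable cannot be found.
    """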
    grad_var = block.desc.find_var(var_name.encode("ascii"))
    fwd_name = _strip_grad_suffix_(var_name.encode("ascii"))
    if block.desc.has_var_recursive(fwd_name):
        fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
        grad_var.set_dtype(fwd_var.dtype())
    else:
        grad_var.set_dtype(core.DataType.FP32)


def _is_all_in_set_(cands, s):
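    """Return True if every element of `cands` is contained in the set `s`."""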
    for c in cands:
        if c not in s:
            return False
    return True


def _strip_grad_suffix_(name):
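    """Return the part of `name` that precedes the gradient suffix."""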
    return name[:name.find(core.grad_var_suffix())]


def _append_grad_suffix_(name):
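    """Append the gradient suffix to `name` to form a gradient variable name."""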
    return name + core.grad_var_suffix()


def _append_backward_ops_(target,
                          block,
                          target_block,
                          no_grad_set,
                          callback=None):
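    """
    Append backward op descs for every op in `block` (in reverse order) to
    `target_block`. Sub-blocks are processed recursively, gradients written
    by several ops are merged with `sum` ops, grad ops whose outputs all fall
    in `no_grad_set` are dropped, and `fill_zeros_like` ops are inserted for
    gradients in `no_grad_set` that are still consumed as inputs. Returns a
    dict mapping gradient variable names to their forward variable names.
    """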
    grad_op_descs = []
    grad_to_var = dict()
    program = block.program
    for each_op in reversed(block.ops):
        grad_sub_block_list = []
        if each_op.has_attr("sub_block"):
            sub_block_idx = each_op.block_attr("sub_block")
            sub_block = program.block(sub_block_idx)
            grad_sub_block = program.create_block(parent_idx=sub_block_idx)
            sub_grad_to_var = _append_backward_ops_(
                target, sub_block, grad_sub_block, no_grad_set, callback)
            grad_to_var = dict(grad_to_var, **sub_grad_to_var)
            grad_sub_block_list.append(grad_sub_block.desc)
        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
            each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
        grad_op_descs.append(grad_op_desc)
        grad_to_var = dict(grad_to_var, **op_grad_to_var)
    # grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...]
    # flatten grad_op_descs
    grad_op_descs = [op for sublist in grad_op_descs for op in sublist]

    pending_sum_ops = []
    var_rename_count = collections.defaultdict(int)
    var_inputs = collections.defaultdict(list)
    for idx, op_desc in enumerate(grad_op_descs):
        for var_name in op_desc.input_arg_names():
            if len(var_inputs[var_name]) > 1:
                pending_sum_ops.append((_create_op_desc_(
                    op_type="sum",
                    inputs={"X": var_inputs[var_name]},
                    outputs={"Out": [var_name]},
                    attrs={}), idx))
                var_inputs[var_name] = [var_name]
        for var_name in op_desc.output_arg_names():
            if var_name == core.empty_var_name() or len(var_inputs[
                    var_name]) == 0:
                # either the empty var name or the first time this variable appears
                var_inputs[var_name] = [var_name]
            else:
                if len(var_inputs[var_name]) == 1:
                    new_name = var_name + "@RENAME@" + \
                        str(var_rename_count[var_name])
                    var_rename_count[var_name] = var_rename_count[var_name] + 1
                    # rename original var_name
                    var_inputs[var_name][0] = new_name
                    _rename_arg_(grad_op_descs, var_name, new_name, 0, idx)
                    _rename_arg_(pending_sum_ops, var_name, new_name)

                new_name = var_name + "@RENAME@" + \
                    str(var_rename_count[var_name])
                var_rename_count[var_name] = var_rename_count[var_name] + 1
                op_desc.rename_output(var_name, new_name)
                var_inputs[var_name].append(new_name)
    for var_name, inputs in var_inputs.iteritems():
        if len(inputs) > 1:
            pending_sum_ops.append((_create_op_desc_(
                op_type="sum",
                inputs={"X": inputs},
                outputs={"Out": [var_name]},
                attrs={}), len(grad_op_descs)))
    # pending sum_op descs are ordered by insert position; insert them from
    # back to front so earlier insert indices stay valid
    for p in reversed(pending_sum_ops):
        grad_op_descs.insert(p[1], p[0])
    # Remove ops whose outputs are all in no_grad_set
    grad_op_descs = filter(
        lambda op_desc: not _is_all_in_set_(op_desc.output_arg_names(), no_grad_set[block.idx]),
        grad_op_descs)
    # Insert fill_zeros_like_op
    to_insert = []
    for idx, op_desc in enumerate(grad_op_descs):
        for arg in op_desc.input_arg_names():
            if arg in no_grad_set[block.idx]:
                to_insert.append((arg, idx))
    for ele in reversed(to_insert):
        arg = ele[0]
        fill_zeros_like_op = _create_op_desc_(
            "fill_zeros_like", {"X": [_strip_grad_suffix_(arg)]}, {"Y": [arg]},
            {})
        grad_op_descs.insert(ele[1], fill_zeros_like_op)

    if target_block.idx == 0:
        grad_target_name = _append_grad_suffix_(target.name)
        target_block.desc.var(grad_target_name.encode("ascii"))
        grad_op_descs.insert(
            0,
            _create_op_desc_(
                op_type="fill_constant",
                inputs={},
                outputs={"Out": [grad_target_name]},
                attrs={
                    "shape": [1],
                    "value": 1.0,
                    "dtype": core.DataType.FP32
                }))
    for op_desc in grad_op_descs:
        new_op_desc = target_block.desc.append_op()
        new_op_desc.copy_from(op_desc)

    return grad_to_var


def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
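    """
    Create the gradient variables produced by the ops starting at
    `start_op_idx` in `block` (recursing into sub-blocks), record in
    `grad_info_map` which block holds each forward variable's gradient, and
    run type and shape inference for those ops.
    """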
    for op_idx in range(start_op_idx, block.desc.op_size()):
        op_desc = block.desc.op(op_idx)
        if op_desc.has_attr("sub_block"):
            sub_block = block.program.block(op_desc.block_attr("sub_block"))
            _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map)
        new_vars = set()
        # create new gradient variables
        for grad_var_name in op_desc.output_arg_names():
            grad_var_name = grad_var_name.encode("ascii")
            if block.desc.has_var_recursive(
                    grad_var_name) or grad_var_name == core.empty_var_name():
                continue
            block.desc.var(grad_var_name)
            new_vars.add(grad_var_name)
            if grad_var_name not in grad_to_var:
                continue
            grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
        # infer_shape and infer_type
        op_desc.infer_var_type(block.desc)
        op_desc.infer_shape(block.desc)
        for arg in op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_(arg, block)


def append_backward(loss, parameter_list=None, no_grad_set=None):
    """
    Create and add gradient Operators in BlockDesc to compute
    gradients of `loss` for parameters in parameter_list

    :param loss: an variable generated by cost function.
    :type loss: Variable
    :param no_grad_set: variable that should not create gradient
    :type no_grad_set: set
    :param parameter_list: parameters that need to compute gradient and 
    update to optimize the lost.
    :type: list
    :return: list of (parameters, gradients) pair.
    :rtype: list[Variable]
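
    A minimal illustrative sketch (`avg_cost` stands for any scalar cost
    Variable built earlier in the program; it is not defined by this module):

        avg_cost = ...  # e.g. the mean of a per-sample cost
        params_grads = append_backward(loss=avg_cost)
        # params_grads is a list of (parameter, gradient) Variable pairs
        # that an optimizer can consume.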
    """
    assert isinstance(loss, framework.Variable)

    program = loss.block.program
    if no_grad_set is None:
        no_grad_set = dict()
        assert isinstance(program, framework.Program)
        for block in program.blocks:
            assert isinstance(block, framework.Block)
            block_no_grad_set = set()
            for var in block.vars.itervalues():
                assert isinstance(var, framework.Variable)
                if var.stop_gradient:
                    block_no_grad_set.add(_append_grad_suffix_(var.name))
            no_grad_set[block.idx] = block_no_grad_set

    grad_info_map = dict()
    root_block = program.block(0)

    fwd_op_num = root_block.desc.op_size()
    current_block_idx = program.current_block_idx
    grad_to_var = _append_backward_ops_(loss, root_block, root_block,
                                        no_grad_set)
    _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
    program.current_block_idx = current_block_idx
    program.sync_with_cpp()

    if parameter_list is not None:
        parameters = parameter_list
    else:
        params = program.global_block().all_parameters()
        parameters = [param.name for param in params]
    params_and_grads = []
    for param in parameters:
        if param not in grad_info_map:
            raise ValueError("param %s is not in map" % param)
        grad_info = grad_info_map[param]
        grad_block = grad_info[1]
        if not grad_block.has_var(grad_info[0]):
            raise ValueError("grad block[{0}] did not have grad var {1}".format(
                grad_info[1], grad_info[0]))
        # Get the param var from the global block
        param_var = program.global_block().var(param)
        grad_var = grad_block.var(grad_info[0])
        if loss.block.has_var(grad_info[0]):
            params_and_grads.append((param_var, grad_var))
        else:
            params_and_grads.append((param_var, None))
    return params_and_grads