from paddle.v2.fluid import framework as framework
from . import core
import collections

__all__ = ['append_backward']


def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None):
    """
    Traverse all ops in op_descs[begin_idx : end_idx],
    if any op has inputs/outputs named "old_name", rename it as 'new_name'
    """
    if begin_idx is None:
        begin_idx = 0
    if end_idx is None:
        end_idx = len(op_descs)
    for i in range(begin_idx, end_idx):
        op_desc = op_descs[i]
        if isinstance(op_desc, tuple):
            op_desc = op_desc[0]
        op_desc.rename_input(old_name, new_name)
        op_desc.rename_output(old_name, new_name)


def _create_op_desc_(op_type, inputs, outputs, attrs):
    """
    Create a C++ OpDesc object with specified inputs, outputs and attributes.
    """
    op_desc = core.OpDesc()
    op_desc.set_type(op_type)
    for para, args in inputs.iteritems():
        op_desc.set_input(para, args)
    for para, args in outputs.iteritems():
        op_desc.set_output(para, args)
    for name, val in attrs.iteritems():
        if isinstance(val, framework.Block):
            op_desc.set_block_attr(name, val.desc)
        else:
            op_desc.set_attr(name, val)
    return op_desc


def _infer_var_data_type_(grad_var_name, block):
    """
    Infer the data type of given grad variable
    """
    grad_var = block.desc.find_var(grad_var_name.encode("ascii"))
    fwd_name = _strip_grad_suffix_(grad_var_name.encode("ascii"))
    if block.desc.has_var_recursive(fwd_name):
        fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
        grad_var.set_dtype(fwd_var.dtype())
    else:
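        # the forward variable cannot be found; fall back to FP32 as the default dtype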
        grad_var.set_dtype(core.DataType.FP32)


def _all_in_set_(cands, s):
    """
    Test if all elements of 'cands' are in set 's'
    """
    if len(cands) == 0:
        return False
    for c in cands:
        if c not in s:
            return False
    return True


def _strip_grad_suffix_(name):
    """
    Strip the grad suffix from the given variable name
    e.g. x@GRAD ==> x
         y@GRAD@RENAME@1 ==> y
    """
    pos = name.find(core.grad_var_suffix())
    return name[:pos] if pos != -1 else name


def _append_grad_suffix_(name):
    """
    Append grad suffix to the given variable name
    e.g. x ==> x@GRAD
    """
    return name + core.grad_var_suffix()


def _addup_repetitive_outputs_(op_descs):
    """
    In the backward pass, a variable may be the output of more than one op.
    In this case, the variable should be the accumulation of all those outputs.
    `sum` ops are added to implement the accumulation.
    """
    pending_sum_ops = []
    var_rename_count = collections.defaultdict(int)
    renamed_vars = collections.defaultdict(list)
    for idx, op_desc in enumerate(op_descs):
        for var_name in op_desc.input_arg_names():
            if len(renamed_vars[var_name]) > 1:
                pending_sum_ops.append(
                    (_create_op_desc_("sum", {"X": renamed_vars[var_name]},
                                      {"Out": [var_name]}, {}), idx))
                renamed_vars[var_name] = [var_name]
        for var_name in op_desc.output_arg_names():
            if var_name == core.empty_var_name(
            ) or var_name in op_desc.input_arg_names():
                # empty variable or inplace op
                continue
            if len(renamed_vars[var_name]) == 0:
                # it's the first time we get the variable
                renamed_vars[var_name] = [var_name]
            else:
                if len(renamed_vars[var_name]) == 1:
                    new_name = var_name + "@RENAME@" + \
                        str(var_rename_count[var_name])
                    var_rename_count[var_name] += 1
                    # rename original var_name
                    renamed_vars[var_name][0] = new_name
                    _rename_arg_(op_descs, var_name, new_name, 0, idx)
                    _rename_arg_(pending_sum_ops, var_name, new_name)

                new_name = var_name + "@RENAME@" + \
                    str(var_rename_count[var_name])
                var_rename_count[var_name] += 1
                op_desc.rename_output(var_name, new_name)
                renamed_vars[var_name].append(new_name)
    for var_name, inputs in renamed_vars.iteritems():
        if len(inputs) > 1:
            pending_sum_ops.append((_create_op_desc_(
                "sum", {"X": inputs}, {"Out": [var_name]}, {}), len(op_descs)))
    # sum_op descs are sorted according to their insert position
    for p in reversed(pending_sum_ops):
        op_descs.insert(p[1], p[0])

    return op_descs


def _remove_no_grad_branch_(op_descs, no_grad_set):
    """
    Remove unnecessary grad ops
    A grad op can be removed in two cases:
        1. all outputs of the grad op are in 'no_grad_set'
        2. all grad inputs of the grad op are in 'no_grad_set'
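
    For the remaining grad ops, a 'fill_zeros_like' op is inserted in front of any
    op whose grad input is in 'no_grad_set', so that the op still receives an
    all-zero gradient for that input.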
    """

    def _op_can_be_removed_(op_desc, no_grad_set):
        out_arg_names = op_desc.output_arg_names()
        if len(out_arg_names) == 0 or _all_in_set_(out_arg_names, no_grad_set):
            return True
        if _all_in_set_(
                filter(lambda name: name.find(core.grad_var_suffix()) != -1,
                       op_desc.input_arg_names()), no_grad_set):
            # also mark this op's outputs as having no gradient
            no_grad_set.update(out_arg_names)
            return True
        return False

    # Remove unnecessary grad ops according to no_grad_set
    op_descs = filter(
        lambda op_desc: not _op_can_be_removed_(op_desc, no_grad_set), op_descs)
    # Insert fill_zeros_like_op
    to_insert = []
    for idx, op_desc in enumerate(op_descs):
        for arg in op_desc.input_arg_names():
            if core.grad_var_suffix() in arg and arg in no_grad_set:
                to_insert.append((_create_op_desc_("fill_zeros_like", {
                    "X": [_strip_grad_suffix_(arg)]
                }, {"Out": [arg]}, {}), idx))

    map(lambda p: op_descs.insert(p[1], p[0]), reversed(to_insert))

    return op_descs


def _append_backward_ops_(target,
                          block,
                          target_block,
                          no_grad_dict,
                          grad_to_var,
                          callback=None):
    """
    Create all grad ops, and insert them into the given block

    Args:
        target(Variable): the target variable of the forward pass
        block(Block): the block where the forward ops are
        target_block(Block): the block which is going to hold newly generated grad ops
        no_grad_dict(dict):
            key(int): block index
            val(set): a set of variable names. These variables have no gradient
        grad_to_var(dict)(output argument):
            key(str): grad variable name
            val(str): corresponding forward variable name
        callback(callable object): a callable object used to decorate newly generated grad ops
    """
    if callback is None:

        def empty_callback(block, context):
            pass

        callback = empty_callback
    elif not hasattr(callback, '__call__'):
        raise ValueError("'callback' must be a callable object.")

    # grad_op_descs holds the created grad ops, which will be appended to target_block
    grad_op_descs = []
    program = block.program
    for op in reversed(block.ops):
        grad_sub_block_list = []
        # If the op has its own sub-block, deal with the sub-block first
        if op.has_attr("sub_block"):
            sub_block = program.block(op.block_attr("sub_block"))
            grad_sub_block = program.create_block(parent_idx=sub_block.idx)
            _append_backward_ops_(target, sub_block, grad_sub_block,
                                  no_grad_dict, grad_to_var, callback)
            grad_sub_block_list.append(grad_sub_block.desc)

        # Get the op's corresponding grad_op
        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
            op.desc, no_grad_dict[block.idx], grad_sub_block_list)

        grad_op_descs.extend(grad_op_desc)
        grad_to_var.update(op_grad_to_var)

    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)

    grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
                                            no_grad_dict[block.idx])

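    # In the root block, prepend a `fill_constant` op that sets the gradient of
    # the target itself (i.e. d(target)/d(target)) to 1.0, seeding the
    # back-propagation chain.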
    if target_block.idx == 0:
        grad_op_descs.insert(
            0,
            _create_op_desc_("fill_constant", {}, {
                "Out": [_append_grad_suffix_(target.name)]
            }, {"shape": [1],
                "value": 1.0,
                "dtype": target.dtype}))
    # append op_desc in grad_op_descs to target_block
    for op_desc in grad_op_descs:
        new_op_desc = target_block.desc.append_op()
        new_op_desc.copy_from(op_desc)
        callback(block=target_block, context=grad_to_var)


def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
    """
    Create new variables required by the backward pass.

    Args:
        block(Block): the block where new variables will be created
        start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created
        grad_to_var(dict):
            key(str): grad variable name
            val(str): corresponding forward variable name
            In most cases, this dict is generated by _append_backward_ops_()
        grad_info_map(dict)(output argument):
            key(str): forward variable name
        val(tuple): a tuple of (str, Block), str is the corresponding grad name, Block is the block where the grad variable lives
    """
    for op_idx in range(start_op_idx, block.desc.op_size()):
        op_desc = block.desc.op(op_idx)
        if op_desc.has_attr("sub_block"):
            sub_block = block.program.block(op_desc.block_attr("sub_block"))
            _append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map)
        new_vars = set()
        # create new gradient variables
        for grad_var_name in op_desc.output_arg_names():
            grad_var_name = grad_var_name.encode("ascii")
            if block.desc.has_var_recursive(
                    grad_var_name) or grad_var_name == core.empty_var_name():
                continue
            block.desc.var(grad_var_name)
            new_vars.add(grad_var_name)
            if grad_var_name not in grad_to_var:
                continue
            grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name, block)
        # infer variable type and shape
        op_desc.infer_var_type(block.desc)
        op_desc.infer_shape(block.desc)
        for arg in op_desc.output_arg_names():
            if arg in new_vars:
                _infer_var_data_type_(arg, block)


def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
    """
    Append the backward part to the main program

    Args:
        loss(Variable): The variable generated by the cost function.
        parameter_list(list): Parameters that need to be updated by the optimizer.
            If None, it means all parameters need to be updated.
        no_grad_set(set): Names of variables in Block 0 that have no gradients.
            If None, the set will be generated inside the function and
            contains all variables with `stop_gradient=True` from all blocks.

    Return:
        (list[(Variable, Variable)]): list of (parameter, gradient) pairs.
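
    Example (a minimal sketch; the layer calls below are illustrative and assume
    the usual paddle.v2.fluid layers API):

        import paddle.v2.fluid.layers as layers

        x = layers.data(name='x', shape=[13], dtype='float32')
        y = layers.data(name='y', shape=[1], dtype='float32')
        y_predict = layers.fc(input=x, size=1)
        loss = layers.mean(x=layers.square_error_cost(input=y_predict, label=y))
        params_and_grads = append_backward(loss)
        # each element is a (parameter, gradient) pair; the gradient is None if
        # it does not live in the loss's block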
    """
    assert isinstance(loss, framework.Variable)

    program = loss.block.program
    no_grad_dict = dict()
    if no_grad_set is None:
        assert isinstance(program, framework.Program)
        for block in program.blocks:
            assert isinstance(block, framework.Block)
            block_no_grad_set = set()
            for var in block.vars.itervalues():
                assert isinstance(var, framework.Variable)
                if var.stop_gradient:
                    block_no_grad_set.add(_append_grad_suffix_(var.name))
            no_grad_dict[block.idx] = block_no_grad_set
    elif isinstance(no_grad_set, set):
        no_grad_dict = {
            0: set([_append_grad_suffix_(name) for name in no_grad_set])
        }
    else:
        raise ValueError("'no_grad_set' should be a set or None.")

    grad_info_map = dict()
    root_block = program.block(0)

    fwd_op_num = root_block.desc.op_size()
    current_block_idx = program.current_block_idx
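    # remember the current block index so it can be restored after the grad
    # blocks are created below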
    grad_to_var = dict()

    _append_backward_ops_(loss, root_block, root_block, no_grad_dict,
                          grad_to_var, callback)
    _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)

    program.current_block_idx = current_block_idx
    program.sync_with_cpp()

    if parameter_list is not None:
        parameters = parameter_list
    else:
        params = program.global_block().all_parameters()
        parameters = [param.name for param in params]
    params_and_grads = []
    for param in parameters:
        if param not in grad_info_map:
            raise ValueError("param %s is not in map" % param)
        grad_info = grad_info_map[param]
        grad_block = grad_info[1]
        if not grad_block.has_var(grad_info[0]):
            raise ValueError("grad block[{0}] did not have grad var {1}".format(
                grad_info[1], grad_info[0]))
        # Get the param var from the global block
        param_var = program.global_block().var(param)
        grad_var = grad_block.var(grad_info[0])
        if loss.block.has_var(grad_info[0]):
            params_and_grads.append((param_var, grad_var))
        else:
            params_and_grads.append((param_var, None))
    return params_and_grads