from paddle.v2.fluid import framework as framework
from . import core
import collections

__all__ = ['append_backward_ops']


def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None):
    """
    Rename every input/output argument named `old_name` to `new_name` in
    op_desc_list[begin_idx:end_idx].

    Entries may be plain op descs or `(op_desc, insert_pos)` tuples (the
    form used for pending sum ops in `_addup_repetitive_outputs`); for a
    tuple only its first element, the op desc, is renamed.

    :param op_desc_list: list of op descs or (op_desc, insert_pos) tuples.
    :param old_name: argument name to replace.
    :param new_name: replacement argument name.
    :param begin_idx: first index to visit; defaults to 0.
    :param end_idx: one past the last index to visit; defaults to
        len(op_desc_list).
    """
    if begin_idx is None:
        begin_idx = 0
    if end_idx is None:
        end_idx = len(op_desc_list)
    for i in range(begin_idx, end_idx):
        op_desc = op_desc_list[i]
        if isinstance(op_desc, tuple):
            # pending sum ops are stored as (op_desc, insert_pos)
            op_desc = op_desc[0]
        op_desc.rename_input(old_name, new_name)
        op_desc.rename_output(old_name, new_name)


def _addup_repetitive_outputs(op_descs):
    """
    Turn every variable written by more than one op into the sum of those
    writes.

    Each writer of a repeated output gets its own "<name>@RENAME@<k>"
    variable, and a "sum_op" desc adding the renamed variables back into
    "<name>" is inserted right before the first later op that reads
    "<name>", or at the end of the list if no later op reads it.

    :param op_descs: flat list of gradient op descs; modified in place.
    :return: the same list, with the sum op descs inserted.
    """
    empty_var_name = core.get_empty_var_name()
    # (sum_op_desc, insert_pos) pairs; insert_pos indexes op_descs as it
    # was before any sum op was inserted.
    pending_sum_ops = []
    var_rename_count = collections.defaultdict(int)
    # var_name -> names currently holding partial values of var_name
    var_inputs = collections.defaultdict(list)

    def next_name(var_name):
        new_name = var_name + "@RENAME@" + str(var_rename_count[var_name])
        var_rename_count[var_name] += 1
        return new_name

    for pos, op_desc in enumerate(op_descs):
        for var_name in op_desc.input_arg_names():
            if len(var_inputs[var_name]) > 1:
                # The variable is about to be read: accumulate its partial
                # values into var_name first.
                # NOTE(review): confirm core.OpDesc's constructor keywords
                # ("inputs"/"outputs") and the sum op type name.
                pending_sum_ops.append((core.OpDesc(
                    type="sum_op",
                    inputs=var_inputs[var_name],
                    outputs=[var_name],
                    attrs={}), pos))
                var_inputs[var_name] = [var_name]
        for var_name in op_desc.output_arg_names():
            if var_name == empty_var_name:
                # placeholder output, never a real variable to accumulate
                continue
            if len(var_inputs[var_name]) == 0:
                # it's the first time we get the variable
                var_inputs[var_name] = [var_name]
                continue
            if len(var_inputs[var_name]) == 1:
                # Second writer: give the earlier value its own name, in
                # every earlier op and in any pending sum op.
                new_name = next_name(var_name)
                var_inputs[var_name][0] = new_name
                rename_arg(op_descs, var_name, new_name, 0, pos)
                rename_arg(pending_sum_ops, var_name, new_name)
            new_name = next_name(var_name)
            op_desc.rename_output(var_name, new_name)
            var_inputs[var_name].append(new_name)

    for var_name, inputs in var_inputs.items():
        if len(inputs) > 1:
            pending_sum_ops.append((core.OpDesc(
                type="sum_op", inputs=inputs, outputs=[var_name], attrs={}),
                                    len(op_descs)))
    # pending_sum_ops was appended in increasing insert_pos order, so
    # inserting back-to-front keeps the remaining positions valid.
    for sum_op, insert_pos in reversed(pending_sum_ops):
        op_descs.insert(insert_pos, sum_op)
    return op_descs


def backward_impl(block, target_block, no_grad_set, callback=None):
    """
    Build gradient op descs for every op in `block` and create the gradient
    variables they output in `target_block`.

    Ops that carry a "sub_block" attribute are handled recursively: a new
    block is created for the sub-block's gradient and handed to
    core.get_grad_op_desc.

    :param block: forward block to differentiate.
    :param target_block: block that receives the gradient variables.
    :param no_grad_set: indexed by block idx; no_grad_set[block.idx] is
        passed to core.get_grad_op_desc for this block.
    :param callback: only forwarded to recursive calls so far.
    """
    grad_op_descs = []
    grad_to_var = {}
    program = block.program
    for each_op in block.ops:
        grad_sub_block_list = []
        if each_op.has_attr("sub_block"):
            sub_block_idx = each_op.block_attr("sub_block")
            sub_block = program.block(sub_block_idx)
            grad_sub_block = program.create_block(parent_idx=sub_block_idx)
            backward_impl(sub_block, grad_sub_block, no_grad_set, callback)
            grad_sub_block_list.append(grad_sub_block)
        grad_op_desc = core.get_grad_op_desc(
            each_op.desc, no_grad_set[block.idx], grad_to_var,
            grad_sub_block_list)
        grad_op_descs.append(grad_op_desc)
    # grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...]; flatten it
    grad_op_descs = [op for sublist in grad_op_descs for op in sublist]

    # variables written by several grad ops must be summed
    grad_op_descs = _addup_repetitive_outputs(grad_op_descs)

    # create new gradient variables in the target block
    empty_var_name = core.get_empty_var_name()
    for op_desc in grad_op_descs:
        for grad_var_name in op_desc.output_arg_names():
            if (target_block.has_var(grad_var_name) or
                    grad_var_name == empty_var_name):
                continue
            target_block.var(grad_var_name)
    # TODO(review): grad_op_descs are not appended to target_block yet.


def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
    """
    Create and add gradient operators to compute the gradients of `loss`
    with respect to the parameters in `parameter_list`.

    :param loss: a variable generated by the cost function.
    :type loss: Variable
    :param parameter_list: names of the parameters whose gradients are
        needed; defaults to all parameters of the program's global block.
    :type parameter_list: list
    :param no_grad_set: names of variables that should not get gradients;
        defaults to every variable in the program with `stop_gradient` set.
    :type no_grad_set: set
    :return: list of (parameter, gradient) pairs; the gradient is None when
        the gradient variable is not in the loss's block.
    :rtype: list[tuple]
    :raises ValueError: if a parameter has no gradient entry, or the block
        recorded for its gradient does not contain the gradient variable.
    """
    assert isinstance(loss, framework.Variable)

    program = loss.block.program
    if no_grad_set is None:
        assert isinstance(program, framework.Program)
        no_grad_set = set()
        for block in program.blocks:
            assert isinstance(block, framework.Block)
            for var in block.vars.values():
                assert isinstance(var, framework.Variable)
                if var.stop_gradient:
                    no_grad_set.add(var.name)

    # param name -> (grad var name, index of the block holding the grad var)
    param_grad_map = program.append_backward(loss, no_grad_set)
    if parameter_list is not None:
        parameters = parameter_list
    else:
        params = program.global_block().all_parameters()
        parameters = [param.name for param in params]

    params_and_grads = []
    for param in parameters:
        if param not in param_grad_map:
            raise ValueError("param %s is not in map" % param)
        grad_info = param_grad_map[param]
        grad_name = grad_info[0]
        grad_block_idx = grad_info[1]
        grad_block = program.block(grad_block_idx)
        if not grad_block.has_var(grad_name):
            raise ValueError("grad block[{0}] did not have grad var {1}".format(
                grad_block_idx, grad_name))
        # Get the param var from the global block
        param_var = program.global_block().var(param)
        grad_var = grad_block.var(grad_name)
        # Only gradients that also exist in the loss's own block are
        # returned; others are reported as None.
        if loss.block.has_var(grad_name):
            params_and_grads.append((param_var, grad_var))
        else:
            params_and_grads.append((param_var, None))
    return params_and_grads