Unverified commit b32314b6, authored by xiaoguoguo626807, committed by GitHub

[NewIR] polish names in backward.py (#56650)

* modify sum with divide net bug (mutable attribute)

* delete print

* polish backward

* polish backward
Parent: ca5585e9
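The diff below makes two kinds of changes: op results are iterated directly via op.results() instead of indexing op.result(i), and the helpers pass a single state object around instead of separate value_to_valuegrad / op_to_opgrad dictionaries, with list-style names pluralized (input_grad_list -> input_grads, effective_forward_op -> effective_forward_ops). A minimal, hypothetical sketch of that state-object pattern follows; the field names value_to_valuegrad and op_to_opgrad are taken from the diff, while the State dataclass and the toy bodies are illustrations only, not Paddle's actual implementation:

from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class State:
    # Hypothetical container mirroring the two maps referenced in the diff:
    # value_to_valuegrad maps a forward value to the list of its grad values,
    # op_to_opgrad maps a forward op to the grad ops created for it.
    value_to_valuegrad: dict = field(default_factory=lambda: defaultdict(list))
    op_to_opgrad: dict = field(default_factory=lambda: defaultdict(list))


def prepare_grad_outputs(grad_outputs, outputs, state):
    # After the refactor the helper reads and writes through `state`
    # instead of receiving the two dictionaries as separate parameters.
    for output, grad in zip(outputs, grad_outputs):
        state.value_to_valuegrad[output] = [[grad]]
    return outputs, [], []  # stands in for (complete_outputs, _, backward_ops)


state = State()
complete_outputs, _, backward_ops = prepare_grad_outputs(["y_grad"], ["y"], state)

Grouping the bookkeeping maps this way means a helper's signature no longer changes whenever a new map is needed, which is what the single-argument call in calc_gradient_helper later in the diff shows.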
@@ -52,8 +52,7 @@ def check_all_puts(block, inputs, outputs):
 def update_no_grad_set_by_stopgradient(block, no_grad_set):
     for op in block.ops:
-        for opresult_idx in range(op.num_results()):
-            value = op.result(opresult_idx)
+        for value in op.results():
             if value.stop_gradient and value not in no_grad_set:
                 no_grad_set.add(value)
@@ -63,9 +62,7 @@ def update_bwdop_structure(backward_ops, op_to_opgrad_list, grad_op):
     op_to_opgrad_list.append(grad_op)
 
 
-def prepare_grad_outputs(
-    block, grad_outputs, outputs, value_to_valuegrad, op_to_opgrad
-):
+def prepare_grad_outputs(grad_outputs, outputs, state):
     """
     if grad_outputs is none, add fill_1 op to create grad_outputs,
     else check whether outputs shape and dtype is same to grad_outputs, otherwise raise error.
@@ -100,10 +97,10 @@ def prepare_grad_outputs(
             update_bwdop_structure(
                 backward_ops,
-                op_to_opgrad[output.get_defining_op()],
+                state.op_to_opgrad[output.get_defining_op()],
                 fillop,
             )
-            value_to_valuegrad[output] = [[output_grad]]
+            state.value_to_valuegrad[output] = [[output_grad]]
         else:
             if output.shape != grad.shape:
                 raise ValueError(
@@ -117,9 +114,11 @@ def prepare_grad_outputs(
                 )
             feedop = grad.get_defining_op()
             update_bwdop_structure(
-                backward_ops, op_to_opgrad[output.get_defining_op()], feedop
+                backward_ops,
+                state.op_to_opgrad[output.get_defining_op()],
+                feedop,
             )
-            value_to_valuegrad[output] = [[grad]]
+            state.value_to_valuegrad[output] = [[grad]]
 
     # add input for bwd first op
     complete_outputs = outputs
@@ -130,7 +129,7 @@ def prepare_grad_outputs(
         if output in visited_output:
             continue
         for opresult in output.get_defining_op().results():
-            if opresult in value_to_valuegrad:
+            if opresult in state.value_to_valuegrad:
                 visited_output.add(opresult)
                 continue
             else:
@@ -143,10 +142,10 @@ def prepare_grad_outputs(
                 update_bwdop_structure(
                     backward_ops,
-                    op_to_opgrad[opresult.get_defining_op()],
+                    state.op_to_opgrad[opresult.get_defining_op()],
                     fillop,
                 )
-                value_to_valuegrad[opresult] = [grad_value]
+                state.value_to_valuegrad[opresult] = [grad_value]
 
                 visited_output.add(opresult)
@@ -196,7 +195,6 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
     # from output to input
     for i, op in reversed(list(enumerate(total_ops))):
-        # while op support
         if some_in_set(op.results(), outputs_set):
             for operand in op.operands_source():
                 if operand not in no_grad_set:
@@ -233,7 +231,7 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
 def update_no_grad_set_after_prune(
-    block, effective_forward_op, no_grad_set, inputs, outputs
+    block, effective_forward_ops, no_grad_set, inputs, outputs
 ):
     '''
     update no_grad_set after forward prune
@@ -249,14 +247,14 @@ def update_no_grad_set_after_prune(
                     if value not in no_grad_set:
                         inputs_set.add(value)
 
-    for op in effective_forward_op:
+    for op in effective_forward_ops:
         for value in op.operands_source():
-            if value not in inputs_set:  # and value.get_stopgradient():
+            if value not in inputs_set:
                 no_grad_set.add(value)
 
     outputs_set = set(outputs)
     no_grad_set_tmp = set()
-    for op in reversed(effective_forward_op):
+    for op in reversed(effective_forward_ops):
         for output in op.results():
             if output not in outputs_set and not some_in_set(
                 [output], set(op.operands_source())
@@ -313,7 +311,7 @@ def inverse_sort_op(ops):
 def append_backward_ops(
-    block, effective_forward_op, no_grad_set, backward_ops, state
+    block, effective_forward_ops, no_grad_set, backward_ops, state
 ):
     '''
     add grad_op in order of topological inverse sort
@@ -417,28 +415,26 @@ def append_backward_ops(
         return zero_flag, output_grad
 
     def make_input_stopgradient(op):
-        input_grad_stopgradient_list = []
+        input_grad_stopgradients = []
         for input in op.operands_source():
             if input.get_defining_op().name() == "builtin.combine":
                 stop_gradient = make_input_stopgradient(input.get_defining_op())
-                input_grad_stopgradient_list.append(
+                input_grad_stopgradients.append(
                     [info[0] for info in stop_gradient]
                 )
             else:
                 if input in no_grad_set:
-                    input_grad_stopgradient_list.append([True])
+                    input_grad_stopgradients.append([True])
                 else:
-                    input_grad_stopgradient_list.append([False])
-        return input_grad_stopgradient_list
+                    input_grad_stopgradients.append([False])
+        return input_grad_stopgradients
 
-    def update_input_grad_map(op, input_grad_list):
+    def update_input_grad_map(op, input_grads):
         for i, input in enumerate(op.operands_source()):
             if input.get_defining_op().name() == "builtin.combine":
-                update_input_grad_map(
-                    input.get_defining_op(), input_grad_list[i]
-                )
+                update_input_grad_map(input.get_defining_op(), input_grads[i])
             else:
-                input_grad = input_grad_list[i]
+                input_grad = input_grads[i]
                 if isinstance(input_grad, list):
                     state.value_to_valuegrad[input].append(input_grad)
                 else:
@@ -451,31 +447,31 @@ def append_backward_ops(
     # [op4] (op4's inputs and outputs are not vectorType)
     # einsum has twp vectorType outputs, special pattern
 
-    clear_effective_forward_op = []
+    clear_effective_forward_ops = []
 
-    for op in effective_forward_op:
+    for op in effective_forward_ops:
         if op.name() != "builtin.combine" and op.name() != "builtin.split":
-            clear_effective_forward_op.append(op)
+            clear_effective_forward_ops.append(op)
 
-    for op in clear_effective_forward_op:
+    for op in clear_effective_forward_ops:
         if paddle.framework.core.has_vjp(op):
             # prepare output_grad
-            output_grad_list = []  # (opresult)
+            output_grads = []  # (opresult)
             zero_flag, output_grad = make_output_grad(op)
-            output_grad_list.append(output_grad)
+            output_grads.append(output_grad)
 
             # all(zero_flag) support this op has no contribution for grad
             # should be delete (prune sub_graph)
-            if len(output_grad_list) == 0 or all(zero_flag):
+            if len(output_grads) == 0 or all(zero_flag):
                 continue
 
             # prepare input_grad stop_gradient info.
-            input_grad_stopgradient_list = make_input_stopgradient(op)
+            input_grad_stopgradients = make_input_stopgradient(op)
 
             # create grad_op
             before_ops_num = len(block.ops)
-            input_grad_list = paddle.framework.core.call_vjp(
-                op, output_grad_list, input_grad_stopgradient_list
+            input_grads = paddle.framework.core.call_vjp(
+                op, output_grads, input_grad_stopgradients
             )
             after_ops_num = len(block.ops)
@@ -486,7 +482,7 @@ def append_backward_ops(
             )
 
             # update input_grad map
-            update_input_grad_map(op, input_grad_list)
+            update_input_grad_map(op, input_grads)
 
         else:
             if op.num_operands() == 0 and op.num_results() != 0:
@@ -519,15 +515,18 @@ def append_backward_ops(
 def create_backward_prune_set(inputs, outputs, no_grad_set, state):
     outputs_set = set()
-    for input in inputs:
-        for item in input.first_use().owner().operands_source():
-            if state.value_to_valuegrad[item] != []:
-                outputs_set.add(state.value_to_valuegrad[item][0][0])
+    for input_ in inputs:
+        if not input_.use_empty():
+            for item in input_.first_use().owner().operands_source():
+                if state.value_to_valuegrad[item] != []:
+                    outputs_set.add(state.value_to_valuegrad[item][0][0])
+        else:
+            raise ValueError("input privided by inputs has no use")
 
     inputs_set = set()
     for output in outputs:
         if state.value_to_valuegrad[output] != []:
             inputs_set.add(state.value_to_valuegrad[output][0][0])
 
     inputs_set_tmp = set()
     for out_grad in inputs_set:
         for item in out_grad.first_use().owner().operands_source():
@@ -538,7 +537,6 @@ def create_backward_prune_set(inputs, outputs, no_grad_set, state):
     for key in state.value_to_valuegrad:
         if key in no_grad_set:
             no_gradvar_set.add(state.value_to_valuegrad[key][0][0])
-
     for key in state.value_to_sumvaluegrad:
         if key in no_grad_set:
             for item in state.value_to_sumvaluegrad[key][0]:
@@ -575,26 +573,22 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
     # update no_grad_set if some value stop_gradient=True
     update_no_grad_set_by_stopgradient(block, no_grad_set)
 
     complete_outputs, _, backward_ops = prepare_grad_outputs(
-        block,
-        grad_outputs,
-        outputs,
-        state.value_to_valuegrad,
-        state.op_to_opgrad,
+        grad_outputs, outputs, state
     )
 
     inputs_set = set(inputs)
     outputs_set = set(complete_outputs)
-    effective_forward_op, _ = prune_ops(
+    effective_forward_ops, _ = prune_ops(
         block.ops, inputs_set, outputs_set, no_grad_set
     )
     update_no_grad_set_after_prune(
-        block, effective_forward_op, no_grad_set, inputs, complete_outputs
+        block, effective_forward_ops, no_grad_set, inputs, complete_outputs
     )
 
-    inverse_effective_forward_op = inverse_sort_op(effective_forward_op)
+    inverse_effective_forward_ops = inverse_sort_op(effective_forward_ops)
     append_backward_ops(
-        block, inverse_effective_forward_op, no_grad_set, backward_ops, state
+        block, inverse_effective_forward_ops, no_grad_set, backward_ops, state
     )
 
     # now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue)
@@ -719,26 +713,26 @@ def grad(
         outputs,
         'outputs',
         ((paddle.ir.Value, paddle.ir.OpResult), list, tuple),
-        'paddle.ir.grad',
+        'paddle.autograd.backward.grad',
     )
     check_type(
         inputs,
         'inputs',
         ((paddle.ir.Value, paddle.ir.OpResult), list, tuple),
-        'paddle.ir.grad',
+        'paddle.autograd.backward.grad',
     )
     check_type(
         grad_outputs,
         'grad_outputs',
         ((paddle.ir.Value, paddle.ir.OpResult), list, tuple, type(None)),
-        'paddle.ir.grad',
+        'paddle.autograd.backward.grad',
    )
     check_type(
         no_grad_vars,
         'no_grad_vars',
         ((paddle.ir.Value, paddle.ir.OpResult), list, tuple, set, type(None)),
-        'paddle.ir.grad',
+        'paddle.autograd.backward.grad',
     )
 
     outputs = _as_list(outputs)
     inputs = _as_list(inputs)
...