Unverified commit b32314b6, authored by xiaoguoguo626807, committed by GitHub

[NewIR] polish names in backward.py (#56650)

* modify sum with divide net bug (mutable attribute)

* delete print

* polish backward

* polish backward
Parent ca5585e9
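The renaming in the diff below is largely mechanical (e.g. `effective_forward_op` becomes `effective_forward_ops`, `*_list` suffixes are dropped, and per-function maps are folded into the shared `state` object), and the first hunk also replaces index-based result access with direct iteration over `op.results()`. The following is a minimal, runnable sketch of that iteration pattern; the classes are hypothetical stand-ins, not Paddle's real new-IR types (those are `paddle.ir.Value` / `paddle.ir.OpResult`, as seen in the `check_type` calls near the end of the diff):

```python
# Hedged sketch: FakeValue and FakeOp are hypothetical stand-ins used only to
# illustrate the refactored iteration pattern; they are not Paddle APIs.
class FakeValue:
    def __init__(self, stop_gradient):
        self.stop_gradient = stop_gradient


class FakeOp:
    def __init__(self, results):
        self._results = results

    def results(self):
        # New style: expose all result values directly as a list.
        return list(self._results)


def update_no_grad_set_by_stopgradient(ops, no_grad_set):
    # After the change: iterate op.results() instead of
    # "for i in range(op.num_results()): value = op.result(i)".
    for op in ops:
        for value in op.results():
            if value.stop_gradient and value not in no_grad_set:
                no_grad_set.add(value)


ops = [FakeOp([FakeValue(True), FakeValue(False)])]
no_grad = set()
update_no_grad_set_by_stopgradient(ops, no_grad)
print(len(no_grad))  # 1: only the stop_gradient=True value is collected
```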
@@ -52,8 +52,7 @@ def check_all_puts(block, inputs, outputs):
def update_no_grad_set_by_stopgradient(block, no_grad_set):
for op in block.ops:
for opresult_idx in range(op.num_results()):
value = op.result(opresult_idx)
for value in op.results():
if value.stop_gradient and value not in no_grad_set:
no_grad_set.add(value)
@@ -63,9 +62,7 @@ def update_bwdop_structure(backward_ops, op_to_opgrad_list, grad_op):
op_to_opgrad_list.append(grad_op)
def prepare_grad_outputs(
block, grad_outputs, outputs, value_to_valuegrad, op_to_opgrad
):
def prepare_grad_outputs(grad_outputs, outputs, state):
"""
if grad_outputs is none, add fill_1 op to create grad_outputs,
else check whether outputs shape and dtype is same to grad_outputs, otherwise raise error.
@@ -100,10 +97,10 @@ def prepare_grad_outputs(
update_bwdop_structure(
backward_ops,
op_to_opgrad[output.get_defining_op()],
state.op_to_opgrad[output.get_defining_op()],
fillop,
)
value_to_valuegrad[output] = [[output_grad]]
state.value_to_valuegrad[output] = [[output_grad]]
else:
if output.shape != grad.shape:
raise ValueError(
@@ -117,9 +114,11 @@ def prepare_grad_outputs(
)
feedop = grad.get_defining_op()
update_bwdop_structure(
backward_ops, op_to_opgrad[output.get_defining_op()], feedop
backward_ops,
state.op_to_opgrad[output.get_defining_op()],
feedop,
)
value_to_valuegrad[output] = [[grad]]
state.value_to_valuegrad[output] = [[grad]]
# add input for bwd first op
complete_outputs = outputs
@@ -130,7 +129,7 @@ def prepare_grad_outputs(
if output in visited_output:
continue
for opresult in output.get_defining_op().results():
if opresult in value_to_valuegrad:
if opresult in state.value_to_valuegrad:
visited_output.add(opresult)
continue
else:
@@ -143,10 +142,10 @@ def prepare_grad_outputs(
update_bwdop_structure(
backward_ops,
op_to_opgrad[opresult.get_defining_op()],
state.op_to_opgrad[opresult.get_defining_op()],
fillop,
)
value_to_valuegrad[opresult] = [grad_value]
state.value_to_valuegrad[opresult] = [grad_value]
visited_output.add(opresult)
@@ -196,7 +195,6 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
# from output to input
for i, op in reversed(list(enumerate(total_ops))):
# while op support
if some_in_set(op.results(), outputs_set):
for operand in op.operands_source():
if operand not in no_grad_set:
@@ -233,7 +231,7 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
def update_no_grad_set_after_prune(
block, effective_forward_op, no_grad_set, inputs, outputs
block, effective_forward_ops, no_grad_set, inputs, outputs
):
'''
update no_grad_set after forward prune
@@ -249,14 +247,14 @@ def update_no_grad_set_after_prune(
if value not in no_grad_set:
inputs_set.add(value)
for op in effective_forward_op:
for op in effective_forward_ops:
for value in op.operands_source():
if value not in inputs_set: # and value.get_stopgradient():
if value not in inputs_set:
no_grad_set.add(value)
outputs_set = set(outputs)
no_grad_set_tmp = set()
for op in reversed(effective_forward_op):
for op in reversed(effective_forward_ops):
for output in op.results():
if output not in outputs_set and not some_in_set(
[output], set(op.operands_source())
@@ -313,7 +311,7 @@ def inverse_sort_op(ops):
def append_backward_ops(
block, effective_forward_op, no_grad_set, backward_ops, state
block, effective_forward_ops, no_grad_set, backward_ops, state
):
'''
add grad_op in order of topological inverse sort
@@ -417,28 +415,26 @@ def append_backward_ops(
return zero_flag, output_grad
def make_input_stopgradient(op):
input_grad_stopgradient_list = []
input_grad_stopgradients = []
for input in op.operands_source():
if input.get_defining_op().name() == "builtin.combine":
stop_gradient = make_input_stopgradient(input.get_defining_op())
input_grad_stopgradient_list.append(
input_grad_stopgradients.append(
[info[0] for info in stop_gradient]
)
else:
if input in no_grad_set:
input_grad_stopgradient_list.append([True])
input_grad_stopgradients.append([True])
else:
input_grad_stopgradient_list.append([False])
return input_grad_stopgradient_list
input_grad_stopgradients.append([False])
return input_grad_stopgradients
def update_input_grad_map(op, input_grad_list):
def update_input_grad_map(op, input_grads):
for i, input in enumerate(op.operands_source()):
if input.get_defining_op().name() == "builtin.combine":
update_input_grad_map(
input.get_defining_op(), input_grad_list[i]
)
update_input_grad_map(input.get_defining_op(), input_grads[i])
else:
input_grad = input_grad_list[i]
input_grad = input_grads[i]
if isinstance(input_grad, list):
state.value_to_valuegrad[input].append(input_grad)
else:
@@ -451,31 +447,31 @@ def append_backward_ops(
# [op4] (op4's inputs and outputs are not vectorType)
# einsum has twp vectorType outputs, special pattern
clear_effective_forward_op = []
clear_effective_forward_ops = []
for op in effective_forward_op:
for op in effective_forward_ops:
if op.name() != "builtin.combine" and op.name() != "builtin.split":
clear_effective_forward_op.append(op)
clear_effective_forward_ops.append(op)
for op in clear_effective_forward_op:
for op in clear_effective_forward_ops:
if paddle.framework.core.has_vjp(op):
# prepare output_grad
output_grad_list = [] # (opresult)
output_grads = [] # (opresult)
zero_flag, output_grad = make_output_grad(op)
output_grad_list.append(output_grad)
output_grads.append(output_grad)
# all(zero_flag) support this op has no contribution for grad
# should be delete (prune sub_graph)
if len(output_grad_list) == 0 or all(zero_flag):
if len(output_grads) == 0 or all(zero_flag):
continue
# prepare input_grad stop_gradient info.
input_grad_stopgradient_list = make_input_stopgradient(op)
input_grad_stopgradients = make_input_stopgradient(op)
# create grad_op
before_ops_num = len(block.ops)
input_grad_list = paddle.framework.core.call_vjp(
op, output_grad_list, input_grad_stopgradient_list
input_grads = paddle.framework.core.call_vjp(
op, output_grads, input_grad_stopgradients
)
after_ops_num = len(block.ops)
@@ -486,7 +482,7 @@ def append_backward_ops(
)
# update input_grad map
update_input_grad_map(op, input_grad_list)
update_input_grad_map(op, input_grads)
else:
if op.num_operands() == 0 and op.num_results() != 0:
@@ -519,15 +515,18 @@ def append_backward_ops(
def create_backward_prune_set(inputs, outputs, no_grad_set, state):
outputs_set = set()
for input in inputs:
for item in input.first_use().owner().operands_source():
if state.value_to_valuegrad[item] != []:
outputs_set.add(state.value_to_valuegrad[item][0][0])
for input_ in inputs:
if not input_.use_empty():
for item in input_.first_use().owner().operands_source():
if state.value_to_valuegrad[item] != []:
outputs_set.add(state.value_to_valuegrad[item][0][0])
else:
raise ValueError("input privided by inputs has no use")
inputs_set = set()
for output in outputs:
if state.value_to_valuegrad[output] != []:
inputs_set.add(state.value_to_valuegrad[output][0][0])
inputs_set_tmp = set()
for out_grad in inputs_set:
for item in out_grad.first_use().owner().operands_source():
@@ -538,7 +537,6 @@ def create_backward_prune_set(inputs, outputs, no_grad_set, state):
for key in state.value_to_valuegrad:
if key in no_grad_set:
no_gradvar_set.add(state.value_to_valuegrad[key][0][0])
for key in state.value_to_sumvaluegrad:
if key in no_grad_set:
for item in state.value_to_sumvaluegrad[key][0]:
@@ -575,26 +573,22 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
# update no_grad_set if some value stop_gradient=True
update_no_grad_set_by_stopgradient(block, no_grad_set)
complete_outputs, _, backward_ops = prepare_grad_outputs(
block,
grad_outputs,
outputs,
state.value_to_valuegrad,
state.op_to_opgrad,
grad_outputs, outputs, state
)
inputs_set = set(inputs)
outputs_set = set(complete_outputs)
effective_forward_op, _ = prune_ops(
effective_forward_ops, _ = prune_ops(
block.ops, inputs_set, outputs_set, no_grad_set
)
update_no_grad_set_after_prune(
block, effective_forward_op, no_grad_set, inputs, complete_outputs
block, effective_forward_ops, no_grad_set, inputs, complete_outputs
)
inverse_effective_forward_op = inverse_sort_op(effective_forward_op)
inverse_effective_forward_ops = inverse_sort_op(effective_forward_ops)
append_backward_ops(
block, inverse_effective_forward_op, no_grad_set, backward_ops, state
block, inverse_effective_forward_ops, no_grad_set, backward_ops, state
)
# now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue)
@@ -719,26 +713,26 @@ def grad(
outputs,
'outputs',
((paddle.ir.Value, paddle.ir.OpResult), list, tuple),
'paddle.ir.grad',
'paddle.autograd.backward.grad',
)
check_type(
inputs,
'inputs',
((paddle.ir.Value, paddle.ir.OpResult), list, tuple),
'paddle.ir.grad',
'paddle.autograd.backward.grad',
)
check_type(
grad_outputs,
'grad_outputs',
((paddle.ir.Value, paddle.ir.OpResult), list, tuple, type(None)),
'paddle.ir.grad',
'paddle.autograd.backward.grad',
)
check_type(
no_grad_vars,
'no_grad_vars',
((paddle.ir.Value, paddle.ir.OpResult), list, tuple, set, type(None)),
'paddle.ir.grad',
'paddle.autograd.backward.grad',
)
outputs = _as_list(outputs)
inputs = _as_list(inputs)
......
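The last hunk shown only changes the API name string passed to `check_type` from 'paddle.ir.grad' to 'paddle.autograd.backward.grad', i.e. the name that appears in type-check error messages. A minimal sketch of how such a validation helper might use that string follows; this simplified `check_type` is an assumption for illustration, not Paddle's actual implementation:

```python
def check_type(value, name, expected_types, api_name):
    # Simplified stand-in for illustration only, not Paddle's check_type:
    # it embeds the fully qualified API name (now
    # 'paddle.autograd.backward.grad') in the error message.
    if not isinstance(value, expected_types):
        raise TypeError(
            f"The type of '{name}' in {api_name} must be one of "
            f"{expected_types}, but got {type(value).__name__}."
        )


# A list of outputs passes the check unchanged.
check_type([1, 2], 'outputs', (list, tuple), 'paddle.autograd.backward.grad')

# A bare int fails, and the error message names the new module path.
try:
    check_type(3, 'outputs', (list, tuple), 'paddle.autograd.backward.grad')
except TypeError as err:
    print(err)
```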