Unverified commit b1573431 authored by xiaoguoguo626807, committed by GitHub

【New IR】 backward gradient accumulation test and polish append_backward_ops...

【New IR】 backward gradient accumulation test and polish append_backward_ops func for op_pattern (#56265)

* [prim][newir] add basic framework for primitive

* support desctensor in new ir

* add vjp interface

* support vjp in new ir

* support vjp in new ir

* polish vjp interface

* fix stop_gradients set

* fix vjp dispatch

* add comment

* add vjp test for new ir

* add test for tanh vjp

* [prim][newir] add basic framework for primitive

* support desctensor in new ir

* support vjp in new ir

* support vjp in new ir

* polish vjp interface

* fix stop_gradients set

* fix vjp dispatch

* add comment

* add vjp test for new ir

* add test for tanh vjp

* add eager and static backend to wrap lower level api

* support call_vjp pybind

* polish code and add test for vjp

* remove useless code

* polish code

* remove useless code

* support mean vjp

* backward origin code

* add test for mean vjp and support has_vjp function

* fix call_vjp

* polish code

* add attrs and dtype interface

* add primitive ops set for backend

* fix compile bugs

* fix some bugs

* fix windows bugs

* add vjp test for tanh_

* fix inference CI

* fix inference ci

* modify fluid cmake

* origin test of tanh and mean passed

* fix conflict

* modify stop_gradient

* remove useless deps

* add cmake

* modify block.ops

* modify test

* fix conflict

* address review comments

* address review comments

* polish code

* fix comment

* fix test

* polish code

* modify backward stop_gradients

* modify static_backend.cc

* refactor grad_op

* support add and add_inplace vjp

* remove useless code

* remove useless code

* remove cout

* modify add_n

* modify add_n with add_vjp test

* modify add_n with add_vjp test

* fix conflict and concat call_vjp

* modify backward test

* Add more gen api

---------
Co-authored-by: cxxly <chenxx_id@163.com>
Co-authored-by: Charles-hit <wanghao107@baidu.com>
Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
Co-authored-by: 0x45f <wangzhen45@baidu.com>
Parent 920c66e9
......@@ -181,7 +181,6 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
'''
relevant_op_flags = [True] * len(total_ops)
# from input to output
if inputs_set:
for i, op in enumerate(total_ops):
......@@ -192,7 +191,6 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
for value in op.results():
if value not in no_grad_set:
inputs_set.add(value)
else:
relevant_op_flags[i] = False
......@@ -313,22 +311,50 @@ def append_backward_ops(
v2_g = call_vjp(op3, [v3_g], [v2_stopgradient])
special pattern 1:
v11 -> combine_op -> v1 -> op -> v3
v12 ->
v2 ->
value_to_valuegrad[v3] = [[v3_g]]
v1 is created inside the python api, so we don't describe it in the backward process (state);
accordingly v1_grad is created inside the vjp, so we don't describe it in the backward process (state)
[[v11_g, v12_g], v2_g] = call_vjp(combine_op, [v3_g], [[v11_stopgradient, v12_stopgradient], v2_stopgradient])
op_vjp is:
v11_g <- split_op <- v1_g <- op_g <- v3_g
v12_g <-
v2_g <-
value_to_valuegrad[v11] = [[v11_g]]
value_to_valuegrad[v12] = [[v12_g]]
value_to_valuegrad[v2] = [[v2_g]]
if op doesn't have a grad_op: if it has no input and one of its outputs has more than
one output_grad, add a sum op for grad aggregation
(e.g. full op and get_parameter op, etc.);
else continue to the next op.
'''
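# A minimal illustrative sketch of the accumulation rule above (the names v, g1
# and g2 are hypothetical, not taken from this file): if a value v is consumed by
# two forward ops, each bwd_op appends its own gradient, so
# state.value_to_valuegrad[v] == [[g1], [g2]]. The code below then calls
# paddle.add_n([g1, g2]) (which lowers to a builtin.combine plus pd.add_n pair in
# the block) and collapses the entry to [[sumop.result(0)]] before any later
# call_vjp reads it.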
for op in effective_forward_op:
if paddle.framework.core.has_vjp(op):
# prepare output_grad
output_grad_list = [] # (opresult)
zero_flag = [False] * op.num_results()
for i, value in enumerate(op.results()):
if (
value not in state.value_to_valuegrad
or state.value_to_valuegrad[value] is None
):
def make_output_grad(op, split_op):
zero_flag = [False] * op.num_results()
for i, value in enumerate(op.results()):
if (
value not in state.value_to_valuegrad
or state.value_to_valuegrad[value] is None
):
if split_op is not None and value == split_op.operand_source(0):
# pattern case:
# this fwd_op's output is vectorType; it is split into
# Types by the builtin.split op, so the grads must be taken from the split op's outputs
split_zero_flag, split_output_grad = make_output_grad(
split_op, None
)
zero_flag[i] = all(split_zero_flag)
grad_value = [op_list[0] for op_list in split_output_grad]
else:
# first case:
# this fwd_op's output isn't used by any other fwd_op,
# so no output_grad was created.
......@@ -336,7 +362,6 @@ def append_backward_ops(
# second case:
# the last bwd_op returned None because its input is in no_grad_set,
# but this bwd_op needs an input.
grad_value = paddle.full(
value.shape,
0.0,
......@@ -347,25 +372,103 @@ def append_backward_ops(
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], fillop
)
state.value_to_valuegrad[value] = [[grad_value]]
zero_flag[i] = True
if len(state.value_to_valuegrad[value]) > 1:
# one value is the input of more than one fwd_op,
# so more than one bwd_op creates an input_grad;
# a sum op is needed to accumulate the gradients
state.value_to_valuegrad[value] = [[grad_value]]
paddle.add_n(list(state.value_to_valuegrad[value]))
sumop = block.ops[len(block.ops) - 1]
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], sumop
)
state.value_to_valuegrad[value] = [[sumop.result(0)]]
state.value_to_sumvaluegrad[
value
] = state.value_to_valuegrad[value]
if len(state.value_to_valuegrad[value]) > 1:
# one value is the input of more than one fwd_op,
# so more than one bwd_op creates an input_grad;
# a sum op is needed to accumulate the gradients
paddle.add_n(
[item[0] for item in state.value_to_valuegrad[value]]
)
combineop = block.ops[len(block.ops) - 2]
sumop = block.ops[len(block.ops) - 1]
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], combineop
)
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], sumop
)
state.value_to_valuegrad[value] = [[sumop.result(0)]]
state.value_to_sumvaluegrad[value] = state.value_to_valuegrad[
value
]
output_grad = state.value_to_valuegrad[value][0]
return zero_flag, output_grad
def make_input_stopgradient(combine_op, op):
input_grad_stopgradient_list = []
for input in op.operands_source():
if combine_op is not None and input == combine_op.result(0):
stop_gradient = make_input_stopgradient(None, combine_op)
input_grad_stopgradient_list.append(
[info[0] for info in stop_gradient]
)
else:
if input in no_grad_set:
input_grad_stopgradient_list.append([True])
else:
input_grad_stopgradient_list.append([False])
return input_grad_stopgradient_list
def update_input_grad_map(combine_op, op, input_grad_list):
for i, input in enumerate(op.operands_source()):
if combine_op is not None and input == combine_op.result(0):
update_input_grad_map(None, combine_op, input_grad_list[i])
else:
input_grad = input_grad_list[i]
if isinstance(input_grad, list):
state.value_to_valuegrad[input].append(input_grad)
else:
state.value_to_valuegrad[input].append([input_grad])
# group ops into op patterns; there are four patterns:
# [builtin.combine , op1] (one of op1's inputs is vectorType, outputs are not vectorType)
# [op2 , builtin.split] (op2's inputs are not vectorType, one output is vectorType)
# [builtin.combine , op3 , builtin.split] (one input and one output of op3 are vectorType)
# [op4] (op4's inputs and outputs are not vectorType)
# einsum has two vectorType outputs, a special pattern
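# Illustrative grouping (a hypothetical forward sequence, not taken from this
# file): for [builtin.combine, pd.concat, pd.tanh] the loop below produces
# pattern_effective_op_list == [[combine_op, concat_op], [tanh_op]], i.e. the
# vector-typed input of concat is handled together with its builtin.combine,
# while tanh stays a single-op pattern.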
pattern_effective_op_list = []
for idx, op in enumerate(effective_forward_op):
if op.name() == "builtin.combine":
pattern_effective_op_list.append([op])
pattern_effective_op_list[-1].append(effective_forward_op[idx + 1])
elif op.name() == "builtin.split":
pattern_effective_op_list[-1].append(op)
else:
if (
not pattern_effective_op_list
or op not in pattern_effective_op_list[-1]
):
pattern_effective_op_list.append([op])
for op_pattern in pattern_effective_op_list:
combine_op = None
split_op = None
if len(op_pattern) == 1:
op = op_pattern[0]
elif len(op_pattern) == 2:
if op_pattern[0].name() == "builtin.combine":
combine_op = op_pattern[0]
op = op_pattern[1]
else:
op = op_pattern[0]
split_op = op_pattern[1]
else:
combine_op = op_pattern[0]
op = op_pattern[1]
split_op = op_pattern[2]
output_grad = state.value_to_valuegrad[value][0]
if paddle.framework.core.has_vjp(op):
# prepare output_grad
output_grad_list = [] # (opresult)
zero_flag, output_grad = make_output_grad(op, split_op)
output_grad_list.append(output_grad)
# all(zero_flag) means this op has no contribution to the grad
......@@ -374,42 +477,42 @@ def append_backward_ops(
continue
# prepare input_grad stop_gradient info.
input_grad_stopgradient_list = []
for input in op.operands_source():
if input in no_grad_set:
input_grad_stopgradient_list.append([True])
else:
input_grad_stopgradient_list.append([False])
input_grad_stopgradient_list = make_input_stopgradient(
combine_op, op
)
# create grad_op
before_ops_num = len(block.ops)
# prim should be a global flag; it will make create_grad_op choose a different func
input_grad_list = paddle.framework.core.call_vjp(
op, output_grad_list, input_grad_stopgradient_list
)
after_ops_num = len(block.ops)
# find new grad_op_list
grad_op_list = []
# update grad_op structure
for i in range(before_ops_num, after_ops_num):
grad_op_list.append(block.ops[i])
for i, input in enumerate(op.operands()):
input_grad = input_grad_list[i]
state.value_to_valuegrad[input.source()].append(input_grad)
# add grad_op
for grad_op in grad_op_list:
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], grad_op
backward_ops, state.op_to_opgrad[op], block.ops[i]
)
# update input_grad map
update_input_grad_map(combine_op, op, input_grad_list)
else:
if op.num_operands() == 0 and op.num_results() != 0:
for value in op.results():
if len(state.value_to_valuegrad[value]) > 1:
# need add sum op
paddle.add_n(list(state.value_to_valuegrad[value]))
paddle.add_n(
[
item[0]
for item in state.value_to_valuegrad[value]
]
)
combineop = block.ops[len(block.ops) - 2]
sumop = block.ops[len(block.ops) - 1]
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], combineop
)
update_bwdop_structure(
backward_ops, state.op_to_opgrad[op], sumop
)
......@@ -441,7 +544,7 @@ def create_backward_purne_set(inputs, outputs, no_grad_set, state):
for key in state.value_to_sumvaluegrad:
if key in no_grad_set:
for item in state.value_to_valuegrad[key][0]:
for item in state.value_to_sumvaluegrad[key][0]:
no_gradvar_set.add(item)
return outputs_set, inputs_set, no_gradvar_set
......@@ -457,10 +560,14 @@ def remove_op(block, op, state):
state.op_to_opgrad[fwd_op].remove(op)
for valuegrad in op.results():
value = state.valuegrad_to_value[valuegrad][0]
state.value_to_valuegrad[value] = []
if value in state.sumvaluegrad_to_value:
raise ValueError('input_grad in [%s] is value which need to sum ')
if state.valuegrad_to_value[valuegrad] != []:
value = state.valuegrad_to_value[valuegrad][0]
state.value_to_valuegrad[value] = []
if value in state.sumvaluegrad_to_value:
raise ValueError(
'input_grad in [%s] is a value which needs to be summed' % op.name()
)
def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
......@@ -500,13 +607,14 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
_, remove_ops = prune_ops(
backward_ops, inputs_set, outputs_set, no_gradvar_set
)
state.turn_map()
for bwd_op in inverse_sort_op(remove_ops):
remove_op(block, bwd_op, state)
input_grad_map = state.value_to_valuegrad
state.turn_map()
return input_grad_map
......
......@@ -7,7 +7,7 @@
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -37,6 +37,10 @@ class State:
self.opgrad_to_op = collections.defaultdict(list)
def turn_map(self) -> None:
self.valuegrad_to_value = collections.defaultdict(list)
self.sumvaluegrad_to_value = collections.defaultdict(list)
self.opgrad_to_op = collections.defaultdict(list)
for k, v in self.value_to_valuegrad.items():
if v != []:
for value in v[0]:
......
......@@ -35,25 +35,8 @@ def get_ir_program_0():
return newir_program
def get_ir_program_1():
x = paddle.randn([2, 2])
main_program, start_program = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_program, start_program):
x_s = paddle.static.data('x', [4, 4], x.dtype)
y_s = paddle.static.data('y', [4, 4], x.dtype)
x_s.stop_gradient = False
z_x = paddle.tanh(y_s)
k_s = paddle.tanh(x_s)
out = paddle.add(z_x, k_s)
newir_program = ir.translate_to_new_ir(main_program.desc)
return newir_program
class TesBackward(unittest.TestCase):
def test_1(self):
class TesBackward_1(unittest.TestCase):
def test_grad(self):
newir_program = get_ir_program_0()
input = newir_program.block().ops[-1].operand(0).source()
tanh_out = newir_program.block().ops[-1].result(0)
......@@ -63,7 +46,6 @@ class TesBackward(unittest.TestCase):
out2 = paddle.mean(tanh_out)
input_grad = grad(out, input, out2)
print(newir_program)
self.assertEqual(out.get_defining_op().name(), "pd.mean")
self.assertEqual(input_grad[0].get_defining_op().name(), "pd.tanh_grad")
self.assertEqual(
......@@ -76,7 +58,7 @@ class TesBackward(unittest.TestCase):
)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
def test_2(self):
def test_full(self):
# test create output_grad in backward use full op
newir_program = get_ir_program_0()
input = newir_program.block().ops[-1].operand(0).source()
......@@ -86,7 +68,6 @@ class TesBackward(unittest.TestCase):
out = paddle.mean(tanh_out)
input_grad = grad(out, input)
print(newir_program)
self.assertEqual(newir_program.block().ops[-3].name(), "pd.full")
self.assertEqual(input_grad[0].get_defining_op().name(), "pd.tanh_grad")
self.assertEqual(
......@@ -100,19 +81,55 @@ class TesBackward(unittest.TestCase):
)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
# TODO(Ruting) test add_n op when add_n api and add_grad finished
# def test_3(self):
# # test add_n op
# newir_program = get_ir_program_1()
# input = newir_program.block().ops[-1].operand(0).source()
# tanh_out = newir_program.block().ops[-1].result(0)
# paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
# with paddle.ir.core.program_guard(newir_program):
# out = paddle.mean(tanh_out)
# input_grad = grad(out, input)
# print(newir_program)
# self.assertEqual(newir_program.block().ops[-1].name(), "pd.add_n")
def test_no_grad_set(self):
# test no_grad_vars in backward
newir_program = get_ir_program_0()
input = newir_program.block().ops[-1].operand(0).source()
tanh_out = newir_program.block().ops[-1].result(0)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
with paddle.ir.core.program_guard(newir_program):
out = paddle.mean(tanh_out)
input_grad = grad(out, input, no_grad_vars=[input])
self.assertEqual(newir_program.block().ops[-1].name(), "pd.mean")
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
def get_ir_program_1():
x = paddle.randn([2, 2])
main_program, start_program = (
paddle.static.Program(),
paddle.static.Program(),
)
with paddle.static.program_guard(main_program, start_program):
x_s = paddle.static.data('x', [4, 4], x.dtype)
y_s = paddle.static.data('y', [4, 4], x.dtype)
x_s.stop_gradient = False
y_s.stop_gradient = False
k_s = paddle.tanh(x_s)
z_x = paddle.tanh(x_s)
out = paddle.add(z_x, k_s)
newir_program = ir.translate_to_new_ir(main_program.desc)
return newir_program
class TesBackward_2(unittest.TestCase):
def test_add_n(self):
# test add_n op
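# (In get_ir_program_1 above, x_s feeds both tanh ops, so grad() has to
# accumulate two gradients for input_x; the builtin.combine / pd.add_n pair
# asserted below is that accumulation step inserted by append_backward_ops.)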
newir_program = get_ir_program_1()
input_x = newir_program.block().ops[-3].operand(0).source()
add_out = newir_program.block().ops[-1].result(0)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
with paddle.ir.core.program_guard(newir_program):
out = paddle.mean(add_out)
input_grad = grad(out, input_x)
self.assertEqual(newir_program.block().ops[-1].name(), "pd.add_n")
self.assertEqual(
newir_program.block().ops[-2].name(), "builtin.combine"
)
if __name__ == "__main__":
......