Unverified commit 0399b39f authored by xiaoguoguo626807, committed by GitHub

【New IR】backward code of new ir (#55957)

* [prim][newir] add basic framework for primitive

* support desctensor in new ir

* add vjp interface

* support vjp in new ir

* support vjp in new ir

* polish vjp interface

* fix stop_gradients set

* fix vjp dispatch

* add comment

* add vjp test for new ir

* add test for tanh vjp

* add eager and static backend to wrap lower level api

* support call_vjp pybind

* polish code and add test for vjp

* remove useless code

* polish code

* remove useless code

* support mean vjp

* initial backward code

* add test for mean vjp and support has_vjp function

* fix call_vjp

* polish code

* add attrs and dtype interface

* add primitive ops set for backend

* fix compile bugs

* fix some bugs

* fix windows bugs

* add vjp test for tanh_

* fix inference CI

* fix inference ci

* modify fluid cmake

* original tests for tanh and mean passed

* fix conflict

* modify stop_gradient

* modify block.ops

* modify test

* fix conflict

* address review comments

* address review comments

* polish code

---------
Co-authored-by: cxxly <chenxx_id@163.com>
Co-authored-by: Charles-hit <wanghao107@baidu.com>
Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
Parent 7cbb433a
......@@ -82,14 +82,15 @@ void BindBlock(py::module *m) {
block.def("front", &Block::front, return_value_policy::reference)
.def("get_parent_program",
[](Block &self) { return self.GetParentOp()->GetParentProgram(); })
.def("get_ops",
[](Block &self) -> py::list {
py::list op_list;
for (auto iter = self.begin(); iter != self.end(); iter++) {
op_list.append(*iter);
}
return op_list;
})
.def_property_readonly(
"ops",
[](Block &self) -> py::list {
py::list op_list;
for (auto iter = self.begin(); iter != self.end(); iter++) {
op_list.append(*iter);
}
return op_list;
})
.def("remove_op", [](Block &self, Operation *op) {
auto op_iter = std::find(self.begin(), self.end(), op);
self.erase(op_iter);
......
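This hunk replaces the get_ops() method binding on Block with a read-only ops property. Below is a minimal migration sketch, not part of the PR; the program construction is borrowed from the test files further down, and only the block.ops access is the point of the example.

import paddle
from paddle import ir

paddle.enable_static()

# Build a tiny static program and translate it to the new IR, as the tests
# in this PR do.
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.static.data('x', [4, 4], 'float32')
    y = paddle.tanh(x)
newir_program = ir.translate_to_new_ir(main_program.desc)

block = newir_program.block()
# Before this PR: ops = block.get_ops()
# After this PR, "ops" is a read-only property:
ops = block.ops
print(len(ops), ops[-1].name())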
This diff is collapsed.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections


class State:
    """
    Record the relationship between forward ops/values and backward ops/values.
    One State instance must be bound to a program.
    """

    def __init__(self, program):
        self.program = program
        # opresult -> list(list(opresult))
        self.value_to_valuegrad = collections.defaultdict(list)
        self.value_to_sumvaluegrad = collections.defaultdict(list)
        # operation -> list(operation)
        self.op_to_opgrad = collections.defaultdict(list)

        # opresult -> list(opresult)
        self.valuegrad_to_value = collections.defaultdict(list)
        self.sumvaluegrad_to_value = collections.defaultdict(list)
        # operation -> list(operation)
        self.opgrad_to_op = collections.defaultdict(list)

    def turn_map(self) -> None:
        # Invert the forward->grad maps so a grad value/op can be traced back
        # to the forward value/op it was created for.
        for k, v in self.value_to_valuegrad.items():
            if v != []:
                for value in v[0]:
                    self.valuegrad_to_value[value] = [k]
        for k, v in self.value_to_sumvaluegrad.items():
            if v != []:
                for value in v[0]:
                    self.sumvaluegrad_to_value[value] = [k]
        for k, v in self.op_to_opgrad.items():
            if v != []:
                self.opgrad_to_op[v[0]] = [k]
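A minimal usage sketch of the State maps above, not part of the PR; the string keys are hypothetical stand-ins for the new-IR OpResult objects that the real backward pass uses.

# Hypothetical stand-ins; real keys/values are OpResult objects taken from
# block.ops[i].result(j) of the bound program.
fwd_value, grad_value = "tanh_out", "tanh_out_grad"

state = State(program=None)  # normally the new-IR Program being differentiated
# forward value -> list of grad-value lists (one inner list per grad op output)
state.value_to_valuegrad[fwd_value].append([grad_value])

# turn_map() inverts the maps so a grad value can be traced back to its
# forward value while walking the backward graph.
state.turn_map()
assert state.valuegrad_to_value[grad_value] == [fwd_value]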
......@@ -162,7 +162,7 @@ def _decompose_subgraph(block, op_filter):
"""
if isinstance(block, Block):
ops_list = block.get_ops()
ops_list = block.ops
for op in ops_list:
op_name = op.name()
decom_rule = register.get_decomp_rule(op_name)
......
......@@ -39,7 +39,7 @@ def get_ir_program():
class TestBuildOp(unittest.TestCase):
def test_build_mean_op(self):
newir_program = get_ir_program()
tanh_out = newir_program.block().get_ops()[-1].result(0)
tanh_out = newir_program.block().ops[-1].result(0)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
with paddle.ir.core.program_guard(newir_program):
out = paddle.mean(tanh_out)
......@@ -82,8 +82,8 @@ class TestBuildOp3(unittest.TestCase):
def test_insertion_point(self):
newir_program = get_ir_program()
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
add_op = newir_program.block().get_ops()[-2]
tanh_op = newir_program.block().get_ops()[-1]
add_op = newir_program.block().ops[-2]
tanh_op = newir_program.block().ops[-1]
add_out = add_op.result(0)
tanh_operand = tanh_op.operands()[0]
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import paddle
from paddle import ir
from paddle.autograd.backward import grad

paddle.enable_static()


def get_ir_program_0():
    x = paddle.randn([4, 4])
    main_program, start_program = (
        paddle.static.Program(),
        paddle.static.Program(),
    )
    with paddle.static.program_guard(main_program, start_program):
        x_s = paddle.static.data('x', [4, 4], x.dtype)
        x_s.stop_gradient = False
        k_s = paddle.tanh(x_s)
    newir_program = ir.translate_to_new_ir(main_program.desc)
    return newir_program


def get_ir_program_1():
    x = paddle.randn([2, 2])
    main_program, start_program = (
        paddle.static.Program(),
        paddle.static.Program(),
    )
    with paddle.static.program_guard(main_program, start_program):
        x_s = paddle.static.data('x', [4, 4], x.dtype)
        y_s = paddle.static.data('y', [4, 4], x.dtype)
        x_s.stop_gradient = False
        z_x = paddle.tanh(y_s)
        k_s = paddle.tanh(x_s)
        out = paddle.add(z_x, k_s)
    newir_program = ir.translate_to_new_ir(main_program.desc)
    return newir_program


class TesBackward(unittest.TestCase):
    def test_1(self):
        newir_program = get_ir_program_0()
        input = newir_program.block().ops[-1].operand(0).source()
        tanh_out = newir_program.block().ops[-1].result(0)
        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
        with paddle.ir.core.program_guard(newir_program):
            out = paddle.mean(tanh_out)
            out2 = paddle.mean(tanh_out)
            input_grad = grad(out, input, out2)

        print(newir_program)
        self.assertEqual(out.get_defining_op().name(), "pd.mean")
        self.assertEqual(
            input_grad[0].get_defining_op().name(), "pd.tanh_grad"
        )
        self.assertEqual(
            out.get_defining_op()
            .operands()[0]
            .source()
            .get_defining_op()
            .name(),
            "pd.tanh",
        )
        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})

    def test_2(self):
        # test creating output_grad in backward with the full op
        newir_program = get_ir_program_0()
        input = newir_program.block().ops[-1].operand(0).source()
        tanh_out = newir_program.block().ops[-1].result(0)
        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
        with paddle.ir.core.program_guard(newir_program):
            out = paddle.mean(tanh_out)
            input_grad = grad(out, input)

        print(newir_program)
        self.assertEqual(newir_program.block().ops[-3].name(), "pd.full")
        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})

    # TODO(Ruting) test add_n op when add_n api and add_grad finished
    # def test_3(self):
    #     # test add_n op
    #     newir_program = get_ir_program_1()
    #     input = newir_program.block().ops[-1].operand(0).source()
    #     tanh_out = newir_program.block().ops[-1].result(0)
    #     paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
    #     with paddle.ir.core.program_guard(newir_program):
    #         out = paddle.mean(tanh_out)
    #         input_grad = grad(out, input)
    #         print(newir_program)
    #         self.assertEqual(newir_program.block().ops[-1].name(), "pd.add_n")


if __name__ == "__main__":
    unittest.main()
......@@ -49,21 +49,21 @@ class TestPybind(unittest.TestCase):
def test_block(self):
newir_program = get_ir_program()
block = newir_program.block()
ops = block.get_ops()
ops = block.ops
self.assertEqual(
len(ops), 4
) # ir program add "builtin.get_parameter" by default, so size is 4
block.remove_op(ops[3])
self.assertEqual(len(block.get_ops()), 3)
self.assertEqual(len(block.ops), 3)
def test_operation(self):
newir_program = get_ir_program()
ops = newir_program.block().get_ops()
matmul_op = newir_program.block().get_ops()[1]
add_op = newir_program.block().get_ops()[2]
tanh_op = newir_program.block().get_ops()[3]
ops = newir_program.block().ops
matmul_op = newir_program.block().ops[1]
add_op = newir_program.block().ops[2]
tanh_op = newir_program.block().ops[3]
parent_block = tanh_op.get_parent_block()
parent_ops_num = len(parent_block.get_ops())
parent_ops_num = len(parent_block.ops)
self.assertEqual(parent_ops_num, 4)
self.assertEqual(tanh_op.num_results(), 1)
self.assertEqual(len(matmul_op.get_input_names()), 2)
......@@ -72,9 +72,9 @@ class TestPybind(unittest.TestCase):
def test_value(self):
newir_program = get_ir_program()
matmul_op = newir_program.block().get_ops()[1]
add_op = newir_program.block().get_ops()[2]
tanh_op = newir_program.block().get_ops()[3]
matmul_op = newir_program.block().ops[1]
add_op = newir_program.block().ops[2]
tanh_op = newir_program.block().ops[3]
self.assertEqual(
matmul_op.result(0).dtype, paddle.fluid.core.DataType.FLOAT32
......@@ -123,8 +123,8 @@ class TestPybind(unittest.TestCase):
def test_type(self):
newir_program = get_ir_program()
matmul_op = newir_program.block().get_ops()[1]
add_op = newir_program.block().get_ops()[2]
matmul_op = newir_program.block().ops[1]
add_op = newir_program.block().ops[2]
print(matmul_op.result(0).type())
self.assertEqual(
matmul_op.result(0).type() == add_op.result(0).type(), True
......@@ -152,8 +152,8 @@ class TestPybind(unittest.TestCase):
newir_program = ir.translate_to_new_ir(main_program.desc)
print(newir_program)
conv_attr = newir_program.block().get_ops()[3].attrs()
full_attr = newir_program.block().get_ops()[8].attrs()
conv_attr = newir_program.block().ops[3].attrs()
full_attr = newir_program.block().ops[8].attrs()
self.assertEqual(conv_attr["stop_gradient"], [False])
self.assertEqual(conv_attr["dilations"], [1, 1])
self.assertEqual(conv_attr["data_format"], "NCHW")
......@@ -166,13 +166,13 @@ class TestPybind(unittest.TestCase):
def test_operands(self):
newir_program = get_ir_program()
matmul_op = newir_program.block().get_ops()[1]
matmul_op = newir_program.block().ops[1]
operands = matmul_op.operands()
self.assertEqual(len(operands), 2)
def test_results(self):
newir_program = get_ir_program()
matmul_op = newir_program.block().get_ops()[1]
matmul_op = newir_program.block().ops[1]
results = matmul_op.results()
self.assertEqual(len(results), 1)
......
......@@ -38,8 +38,8 @@ def get_ir_program():
class TestTanhVjp(unittest.TestCase):
def test_tanh_vjp1(self):
newir_program = get_ir_program()
tanh_op = newir_program.block().get_ops()[-2]
fill_constant_op = newir_program.block().get_ops()[-1]
tanh_op = newir_program.block().ops[-2]
fill_constant_op = newir_program.block().ops[-1]
out_grads = [[fill_constant_op.result(0)]]
stop_gradients = [[0]]
with paddle.ir.core.program_guard(newir_program):
......@@ -65,12 +65,12 @@ class TestTanhVjp(unittest.TestCase):
.name(),
"pd.full",
)
self.assertEqual(len(newir_program.block().get_ops()), 4)
self.assertEqual(len(newir_program.block().ops), 4)
def test_tanh_vjp2(self):
newir_program = get_ir_program()
tanh_op = newir_program.block().get_ops()[-2]
fill_constant_op = newir_program.block().get_ops()[-1]
tanh_op = newir_program.block().ops[-2]
fill_constant_op = newir_program.block().ops[-1]
out_grads = [[fill_constant_op.result(0)]]
stop_gradients = [[1]]
with paddle.ir.core.program_guard(newir_program):
......@@ -90,8 +90,8 @@ class TestMeanVjp(unittest.TestCase):
paddle.mean(x, axis=[0, 1])
paddle.tensor.fill_constant(shape=[1], dtype='float32', value=2.0)
newir_program = ir.translate_to_new_ir(main_program.desc)
fill_constant_op = newir_program.block().get_ops()[-1]
mean_op = newir_program.block().get_ops()[-2]
fill_constant_op = newir_program.block().ops[-1]
mean_op = newir_program.block().ops[-2]
out_grads = [[fill_constant_op.result(0)]]
stop_gradients = [[0]]
with paddle.ir.core.program_guard(newir_program):
......@@ -117,7 +117,7 @@ class TestMeanVjp(unittest.TestCase):
.name(),
"pd.full",
)
self.assertEqual(len(newir_program.block().get_ops()), 4)
self.assertEqual(len(newir_program.block().ops), 4)
def test_mean_vjp2(self):
main_program, start_program = (
......@@ -130,8 +130,8 @@ class TestMeanVjp(unittest.TestCase):
paddle.mean(x, axis=[0, 1])
paddle.tensor.fill_constant(shape=[1], dtype='float32', value=2.0)
newir_program = ir.translate_to_new_ir(main_program.desc)
fill_constant_op = newir_program.block().get_ops()[-1]
mean_op = newir_program.block().get_ops()[-2]
fill_constant_op = newir_program.block().ops[-1]
mean_op = newir_program.block().ops[-2]
out_grads = [[fill_constant_op.result(0)]]
stop_gradients = [[1]]
with paddle.ir.core.program_guard(newir_program):
......@@ -151,8 +151,8 @@ class TesthasVjp(unittest.TestCase):
paddle.mean(x, axis=[0, 1])
paddle.tensor.fill_constant(shape=[1], dtype='float32', value=2.0)
newir_program = ir.translate_to_new_ir(main_program.desc)
fill_constant_op = newir_program.block().get_ops()[-1]
mean_op = newir_program.block().get_ops()[-2]
fill_constant_op = newir_program.block().ops[-1]
mean_op = newir_program.block().ops[-2]
self.assertEqual(has_vjp(fill_constant_op), False)
self.assertEqual(has_vjp(mean_op), True)
......
......@@ -43,7 +43,7 @@ class TestBuildOp(unittest.TestCase):
newir_program = get_ir_program()
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
decompose(newir_program)
op_name_list = [op.name() for op in newir_program.block().get_ops()]
op_name_list = [op.name() for op in newir_program.block().ops]
self.assertEqual(
op_name_list,
[
......