Commit e4a93b05, authored by xiongkun and committed by Xiaoxu Chen

Cxx prim custom vjp (#8)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

---------
Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* Pr 50885 (#7)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

---------
Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dy2static] add hooker for prim

---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix dtype mapping error in cast prim and vjp

* [dy2static-ci] fix dy2static ci errors.

---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
Parent 5dda91a8
@@ -3751,7 +3751,6 @@ class Block:
self.vars = collections.OrderedDict() # var_name --> var
self.ops = list() # operator list
self.program = program
self.removed_vars = collections.OrderedDict()
def __str__(self):
return self._to_readable_code()
......
@@ -77,7 +77,12 @@ class TestPrimForward(unittest.TestCase):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x)[1]
.train_program.block(0)
.ops
]
# Ensure that softmax is split into small ops
self.assertTrue('softmax' not in fwd_ops)
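Each of these check_prim hunks switches from reading net.forward.main_program to fetching the concrete program for the recorded inputs. A minimal sketch of that pattern as a standalone helper, assuming a layer wrapped with paddle.jit.to_static; the helper name collect_forward_op_types is hypothetical:

```python
def collect_forward_op_types(static_net, *inputs):
    # Mirrors the updated check_prim logic: fetch the concrete (partial)
    # program built for these inputs and list the op types of its first block.
    _, partial_program = static_net.forward.get_concrete_program(*inputs)
    return [op.type for op in partial_program.train_program.block(0).ops]


# Usage sketch (after the net has been run at least once so the program
# cache is populated):
# fwd_ops = collect_forward_op_types(net, x)
# assert 'softmax' not in fwd_ops  # softmax was decomposed into primitive ops
```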
@@ -128,7 +133,12 @@ class TestPrimForwardAndBackward(unittest.TestCase):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x)[1]
.train_program.block(0)
.ops
]
all_ops = [
op.type
for op in net.forward.program_cache.last()[-1][-1]
......
@@ -77,6 +77,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
net = apply_to_static(net, use_prim)
res = []
self.x = data
for _ in range(10):
out = net(data)
loss = paddle.mean(out)
@@ -92,7 +93,12 @@ class TestPrimForwardAndBackward(unittest.TestCase):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x)[1]
.train_program.block(0)
.ops
]
# Ensure that gelu is split into small ops
self.assertTrue('gelu' not in fwd_ops)
......
@@ -89,7 +89,14 @@ class TestPrimForward(unittest.TestCase):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x, self.w, self.b)[
1
]
.train_program.block(0)
.ops
]
# Ensure that layer_norm is split into small ops
self.assertTrue('layer_norm' not in fwd_ops)
@@ -150,7 +157,14 @@ class TestPrimForwardAndBackward(unittest.TestCase):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x, self.w, self.b)[
1
]
.train_program.block(0)
.ops
]
# Ensure that layer_norm is split into small ops
self.assertTrue('layer_norm' not in fwd_ops)
......
@@ -83,6 +83,7 @@ class TestPrimForward(unittest.TestCase):
net = apply_to_static(net, use_prim)
res = []
self.x = data
for _ in range(10):
out = net(data)
loss = paddle.mean(out, axis, keep_dim)
@@ -99,7 +100,12 @@ class TestPrimForward(unittest.TestCase):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x)[1]
.train_program.block(0)
.ops
]
# Ensure that reduce_mean is split into small ops
self.assertTrue('reduce_mean' not in fwd_ops)
@@ -150,6 +156,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
net = apply_to_static(net, use_prim)
res = []
self.x = data
for _ in range(10):
out = net(data)
loss = paddle.mean(out, axis, keep_dim)
@@ -166,7 +173,12 @@ class TestPrimForwardAndBackward(unittest.TestCase):
def check_prim(self, net, use_prim):
if not use_prim:
return
fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
fwd_ops = [
op.type
for op in net.forward.get_concrete_program(self.x)[1]
.train_program.block(0)
.ops
]
# Ensure that reduce_mean is split into small ops
self.assertTrue('reduce_mean' not in fwd_ops)
......
@@ -315,9 +315,7 @@ class PartialProgramLayer:
def _create_forward_backward_train_program(self):
whole_program = self._train_program
# _, forward_end_op_index = self._infer_info('fp32', self._create_program)
forward_end_op_index = self._forward_end_index_map[
_hash_with_id(whole_program, self)
]
forward_end_op_index = self.get_forward_end_op_idx(whole_program)
assert forward_end_op_index >= 0
return self._get_forward_backward_program_form(
@@ -438,11 +436,14 @@ class PartialProgramLayer:
def _param_grad_names(self):
return _param_grad_names(self._train_program.desc, self._params)
def get_forward_end_op_idx(self, program):
return self._forward_end_index_map[_hash_with_id(program, self)]
@LazyInitialized
def _out_grad_names(self):
return _out_grad_names(
self._train_program.desc,
self._create_program(is_infer_mode=True).desc.block(0).op_size(),
self.get_forward_end_op_idx(self._train_program),
len(self._outputs.var_ids),
)
@@ -642,6 +643,7 @@ class PartialProgramLayer:
if isinstance(out, framework.Variable):
targets.append(program.global_block().var(out.name))
start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
if targets:
# TODO(CZ): later, when CINN is used, set_prim_all_enabled and check_and_set_prim_all_enabled will be set in the else branch.
core.check_and_set_prim_all_enabled()
@@ -652,12 +654,11 @@ class PartialProgramLayer:
program, start_idx = self._hooker.after_append_backward(
self, program, start_idx
)
self._forward_end_index_map[
_hash_with_id(program, self)
] = start_idx - len(self._outputs.tolist())
# TODO: prim makes this complicated
self.prepare_gradient_aggregation(start_idx, main_program, program)
self._forward_end_index_map[
_hash_with_id(program, self)
] = start_idx - len(self._outputs.tolist())
return program
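The assignment to _forward_end_index_map above records the forward-program end index under a key derived from the program and the layer, which get_forward_end_op_idx later reads back. A standalone mimic of that bookkeeping, with _hash_with_id_mock and ForwardEndIndexMap as illustrative stand-ins rather than Paddle internals:

```python
def _hash_with_id_mock(program, layer):
    # Illustrative stand-in for _hash_with_id: key on object identity so each
    # (program, layer) pair gets its own slot in the map.
    return hash((id(program), id(layer)))


class ForwardEndIndexMap:
    """Standalone mimic of PartialProgramLayer._forward_end_index_map."""

    def __init__(self):
        self._map = {}

    def record(self, program, layer, start_idx, num_outputs):
        # Mirrors the assignment above: the stored value is start_idx minus
        # the number of outputs, i.e. the index just past the last forward op.
        self._map[_hash_with_id_mock(program, layer)] = start_idx - num_outputs

    def lookup(self, program, layer):
        # Mirrors get_forward_end_op_idx(program).
        return self._map[_hash_with_id_mock(program, layer)]
```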
def _prune_unused_params(self, program):
@@ -1155,5 +1156,8 @@ def add_build_strategy_for(
if hasattr(compiled_program._program, 'lr_sheduler'):
builded_program.lr_sheduler = compiled_program._program.lr_sheduler
else:
# can't just create a new program; we need to copy the var descs.
builded_program = paddle.static.Program()
for var in program.block(0).vars.values():
builded_program.block(0)._clone_variable(var, False)
return builded_program
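The new else branch cannot return a bare empty Program, because callers still look the block's variables up by name; a minimal sketch of the same idea as a helper (the name empty_program_with_vars is hypothetical), using only the calls that appear in this hunk:

```python
import paddle


def empty_program_with_vars(src_program):
    # When there is nothing to compile, return a fresh Program but clone
    # every variable of the source block so later lookups by name succeed.
    dst_program = paddle.static.Program()
    for var in src_program.block(0).vars.values():
        dst_program.block(0)._clone_variable(var, False)
    return dst_program
```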
@@ -1226,7 +1226,6 @@ class ProgramCache:
partial_program.set_hooker(PrimHooker())
return concrete_program, partial_program
def __getitem__(self, item):
if not isinstance(item, CacheKey):
raise ValueError(
......
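set_hooker installs a hooker object on the PartialProgramLayer; PrimHooker's body is not part of this diff, so the stand-in below is an assumption shaped only after the call site self._hooker.after_append_backward(self, program, start_idx) seen earlier:

```python
class LoggingHooker:
    # Hypothetical stand-in for PrimHooker, following the observed call site.

    def after_append_backward(self, partial_program, whole_program, start_idx):
        # Called once backward ops have been appended; it may rewrite the
        # program (e.g. decompose composite backward ops) and must return the
        # possibly updated program together with the new start index.
        print(f"backward appended; forward ops end before index {start_idx}")
        return whole_program, start_idx


# Usage sketch, mirroring ProgramCache above:
# partial_program.set_hooker(LoggingHooker())
```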
@@ -1568,7 +1568,7 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
):
op = program_desc.block(0).op(i)
if op.type() == 'fill_any_like':
if op.type() in ['fill_any_like', "fill_constant"]:
var_name = op.output('Out')[0]
names.append(var_name)
return names
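A simplified standalone mimic of this scan, taking ops as (op_type, out_name) pairs instead of a ProgramDesc; reading the change as "prim mode may create the initial output gradients with fill_constant rather than fill_any_like" is an assumption based on this diff:

```python
def out_grad_names_mock(ops, fwd_end_op_index, out_size):
    # Collect the outputs of the ops that fill the initial output gradients,
    # scanning at most out_size ops after the forward section ends.
    names = []
    for op_type, out_name in ops[fwd_end_op_index:fwd_end_op_index + out_size]:
        if op_type in ('fill_any_like', 'fill_constant'):
            names.append(out_name)
    return names


# Usage sketch:
# ops = [('matmul_v2', 'out'), ('fill_constant', 'out@GRAD')]
# out_grad_names_mock(ops, 1, 1)  # -> ['out@GRAD']
```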
......