Commit e4a93b05 authored by xiongkun, committed by Xiaoxu Chen

Cxx prim custom vjp (#8)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

---------
Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* Pr 50885 (#7)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

---------
Co-authored-by: jiangcheng <thisjiang@qq.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dystatic] add hooker for prim

---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>

* [prim] enable dygraph_to_static to support custom_vjp

* fix cast prim and vjp dtype mapping error bug

* [dy2static-ci] fix dy2static ci errors.

---------
Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
Co-authored-by: jiangcheng <thisjiang@qq.com>
Co-authored-by: cxxly <chenxx_id@163.com>
Parent 5dda91a8
@@ -3751,7 +3751,6 @@ class Block:
         self.vars = collections.OrderedDict()  # var_name --> var
         self.ops = list()  # operator list
         self.program = program
-        self.removed_vars = collections.OrderedDict()
 
     def __str__(self):
         return self._to_readable_code()
...
@@ -77,7 +77,12 @@ class TestPrimForward(unittest.TestCase):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that softmax is splitted into small ops
         self.assertTrue('softmax' not in fwd_ops)
@@ -128,7 +133,12 @@ class TestPrimForwardAndBackward(unittest.TestCase):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         all_ops = [
             op.type
             for op in net.forward.program_cache.last()[-1][-1]
...
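The test hunks above all make the same switch: instead of reading op types from `net.forward.main_program`, they first ask the dy2static-converted function for a concrete program (hence the new `self.x = data` captures) and then inspect the partial program's `train_program`. Below is a minimal sketch of that pattern, assuming a recent Paddle build; `SimpleNet` and `TestForwardOps` are made-up names for illustration, and `get_concrete_program`, `train_program`, and `block(0).ops` are the internal dy2static attributes the diff itself relies on rather than documented public API.

```python
import unittest

import paddle


class SimpleNet(paddle.nn.Layer):
    """Hypothetical layer, used only to illustrate the test pattern."""

    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(4, 4)

    def forward(self, x):
        return paddle.nn.functional.softmax(self.linear(x))


class TestForwardOps(unittest.TestCase):
    def test_collect_forward_ops(self):
        x = paddle.randn([2, 4])
        net = paddle.jit.to_static(SimpleNet())

        # get_concrete_program returns (concrete_program, partial_program);
        # the updated tests index [1] and read the partial program's
        # train_program, whose block(0) holds the traced forward ops.
        partial_program = net.forward.get_concrete_program(x)[1]
        fwd_ops = [op.type for op in partial_program.train_program.block(0).ops]

        self.assertTrue(len(fwd_ops) > 0)


if __name__ == '__main__':
    unittest.main()
```

Whether ops such as softmax actually appear decomposed in `fwd_ops` depends on the prim flags the real tests toggle through their `apply_to_static(net, use_prim)` helper.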
@@ -77,6 +77,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
         net = apply_to_static(net, use_prim)
         res = []
+        self.x = data
         for _ in range(10):
             out = net(data)
             loss = paddle.mean(out)
@@ -92,7 +93,12 @@ class TestPrimForwardAndBackward(unittest.TestCase):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that gelu is splitted into small ops
         self.assertTrue('gelu' not in fwd_ops)
...
@@ -89,7 +89,14 @@ class TestPrimForward(unittest.TestCase):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x, self.w, self.b)[
+                1
+            ]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that layer_norm is splitted into small ops
         self.assertTrue('layer_norm' not in fwd_ops)
@@ -150,7 +157,14 @@ class TestPrimForwardAndBackward(unittest.TestCase):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x, self.w, self.b)[
+                1
+            ]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that layer_norm is splitted into small ops
         self.assertTrue('layer_norm' not in fwd_ops)
...
@@ -83,6 +83,7 @@ class TestPrimForward(unittest.TestCase):
         net = apply_to_static(net, use_prim)
         res = []
+        self.x = data
         for _ in range(10):
             out = net(data)
             loss = paddle.mean(out, axis, keep_dim)
@@ -99,7 +100,12 @@ class TestPrimForward(unittest.TestCase):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that reduce_mean is splitted into small ops
         self.assertTrue('reduce_mean' not in fwd_ops)
@@ -150,6 +156,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
         net = apply_to_static(net, use_prim)
         res = []
+        self.x = data
         for _ in range(10):
             out = net(data)
             loss = paddle.mean(out, axis, keep_dim)
@@ -166,7 +173,12 @@ class TestPrimForwardAndBackward(unittest.TestCase):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that reduce_mean is splitted into small ops
         self.assertTrue('reduce_mean' not in fwd_ops)
...
@@ -315,9 +315,7 @@ class PartialProgramLayer:
     def _create_forward_backward_train_program(self):
         whole_program = self._train_program
         # _, forward_end_op_index = self._infer_info('fp32', self._create_program)
-        forward_end_op_index = self._forward_end_index_map[
-            _hash_with_id(whole_program, self)
-        ]
+        forward_end_op_index = self.get_forward_end_op_idx(whole_program)
         assert forward_end_op_index >= 0
         return self._get_forward_backward_program_form(
@@ -438,11 +436,14 @@ class PartialProgramLayer:
     def _param_grad_names(self):
         return _param_grad_names(self._train_program.desc, self._params)
 
+    def get_forward_end_op_idx(self, program):
+        return self._forward_end_index_map[_hash_with_id(program, self)]
+
     @LazyInitialized
     def _out_grad_names(self):
         return _out_grad_names(
             self._train_program.desc,
-            self._create_program(is_infer_mode=True).desc.block(0).op_size(),
+            self.get_forward_end_op_idx(self._train_program),
             len(self._outputs.var_ids),
         )
@@ -642,6 +643,7 @@ class PartialProgramLayer:
             if isinstance(out, framework.Variable):
                 targets.append(program.global_block().var(out.name))
 
+        start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
         if targets:
             # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
             core.check_and_set_prim_all_enabled()
@@ -652,12 +654,11 @@ class PartialProgramLayer:
                 program, start_idx = self._hooker.after_append_backward(
                     self, program, start_idx
                 )
-            self.prepare_gradient_aggregation(start_idx, main_program, program)
         self._forward_end_index_map[
             _hash_with_id(program, self)
         ] = start_idx - len(self._outputs.tolist())
+        # TODO: prim make this complicate
+        self.prepare_gradient_aggregation(start_idx, main_program, program)
         return program
 
     def _prune_unused_params(self, program):
@@ -1155,5 +1156,8 @@ def add_build_strategy_for(
         if hasattr(compiled_program._program, 'lr_sheduler'):
             builded_program.lr_sheduler = compiled_program._program.lr_sheduler
     else:
+        # can't just create a new program, we need copy the vardesc.
         builded_program = paddle.static.Program()
+        for var in program.block(0).vars.values():
+            builded_program.block(0)._clone_variable(var, False)
     return builded_program
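The index bookkeeping spread across the PartialProgramLayer hunks above reduces to a small contract: at the end of backward construction, the layer records where the forward section of each whole program ends (the post-hook `start_idx` minus the appended output count), keyed by `_hash_with_id(program, self)`, and the new `get_forward_end_op_idx` helper reads that value back for `_create_forward_backward_train_program` and `_out_grad_names`. The toy class below mirrors that contract with a plain dict and `id()`-based keys; apart from the two method names taken from the diff, every name here is illustrative and not a Paddle internal.

```python
class ForwardEndIndexBook:
    """Toy stand-in for PartialProgramLayer's _forward_end_index_map logic."""

    def __init__(self):
        self._forward_end_index_map = {}

    def _key(self, program):
        # Stand-in for _hash_with_id(program, self): one entry per
        # (program, layer) pair.
        return (id(program), id(self))

    def record(self, program, start_idx, num_outputs):
        # start_idx points just past the backward prologue; the forward
        # section ends num_outputs ops earlier, matching
        # "start_idx - len(self._outputs.tolist())" in the diff.
        self._forward_end_index_map[self._key(program)] = start_idx - num_outputs

    def get_forward_end_op_idx(self, program):
        # Mirrors the helper added in this commit.
        return self._forward_end_index_map[self._key(program)]


# Usage: after recording start_idx=12 for a program with 2 outputs,
# the forward section covers ops [0, 10).
book = ForwardEndIndexBook()
program = object()  # placeholder for a Program
book.record(program, start_idx=12, num_outputs=2)
assert book.get_forward_end_op_idx(program) == 10
```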
@@ -1226,7 +1226,6 @@ class ProgramCache:
             partial_program.set_hooker(PrimHooker())
         return concrete_program, partial_program
 
-
     def __getitem__(self, item):
         if not isinstance(item, CacheKey):
             raise ValueError(
...
@@ -1568,7 +1568,7 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
         min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
     ):
         op = program_desc.block(0).op(i)
-        if op.type() == 'fill_any_like':
+        if op.type() in ['fill_any_like', "fill_constant"]:
             var_name = op.output('Out')[0]
             names.append(var_name)
     return names
...
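This last hunk widens the op types that `_out_grad_names` treats as the seeds of the output gradients: with the composite (prim) backward enabled, the initial output grads may be created by `fill_constant` rather than `fill_any_like`. A minimal sketch of that scan is below, assuming a flat list of `(op_type, out_name)` pairs in place of a real `ProgramDesc`; the real code walks OpDescs and reads `op.output('Out')[0]`.

```python
def out_grad_names(ops, fwd_end_op_index, out_size):
    """Collect outputs of the ops that initialize the output gradients."""
    names = []
    # Scan the out_size ops right after the forward section ends
    # (the slice clamps at the end of the list, like the min() in the diff).
    for op_type, out_name in ops[fwd_end_op_index : fwd_end_op_index + out_size]:
        # The commit accepts fill_constant in addition to fill_any_like,
        # since the prim backward may seed gradients that way.
        if op_type in ('fill_any_like', 'fill_constant'):
            names.append(out_name)
    return names


# Two outputs whose grads are seeded right after the forward section (index 2).
ops = [
    ('matmul_v2', 'y'),
    ('softmax', 'out'),
    ('fill_constant', 'out@GRAD'),
    ('fill_any_like', 'y@GRAD'),
    ('matmul_v2_grad', 'x@GRAD'),
]
assert out_grad_names(ops, fwd_end_op_index=2, out_size=2) == ['out@GRAD', 'y@GRAD']
```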