From e4a93b050dfcc613099e5c4e65e4bf3422383f33 Mon Sep 17 00:00:00 2001
From: xiongkun <807377414@qq.com>
Date: Thu, 2 Mar 2023 21:13:03 +0800
Subject: [PATCH] Cxx prim custom vjp (#8)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

---------

Co-authored-by: jiangcheng

* [prim] enable dygraph_to_static to support custom_vjp

* Pr 50885 (#7)

* [CINN]Enhance CacheKey hash logic by considering input dtypes (#50557)

* [CINN]Enhance CacheKey hash logic by considering input dtypes

---------

Co-authored-by: jiangcheng

* [prim] enable dygraph_to_static to support custom_vjp

* fix code in a dy2static-friendly way.

* [dystatic] add hooker for prim

---------

Co-authored-by: Aurelius84
Co-authored-by: jiangcheng
Co-authored-by: cxxly

* [prim] enable dygraph_to_static to support custom_vjp

* fix cast prim and vjp dtype mapping error bug

* [dy2static-ci] fix dy2static ci errors.

---------

Co-authored-by: Aurelius84
Co-authored-by: jiangcheng
Co-authored-by: cxxly
---
 python/paddle/fluid/framework.py              |  1 -
 .../dygraph_to_static/test_cinn_prim.py       | 14 +++++++++++--
 .../dygraph_to_static/test_cinn_prim_gelu.py  |  8 +++++++-
 .../test_cinn_prim_layer_norm.py              | 18 +++++++++++++++--
 .../dygraph_to_static/test_cinn_prim_mean.py  | 16 +++++++++++++--
 .../paddle/jit/dy2static/partial_program.py   | 20 +++++++++++--------
 .../jit/dy2static/program_translator.py       |  1 -
 python/paddle/jit/dy2static/utils.py          |  2 +-
 8 files changed, 62 insertions(+), 18 deletions(-)

diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 8e3737b72d6..ac2dde6ba69 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -3751,7 +3751,6 @@ class Block:
         self.vars = collections.OrderedDict()  # var_name --> var
         self.ops = list()  # operator list
         self.program = program
-        self.removed_vars = collections.OrderedDict()
 
     def __str__(self):
         return self._to_readable_code()
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py
index 1a0fe1a6938..d25fe730308 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim.py
@@ -77,7 +77,12 @@ class TestPrimForward(unittest.TestCase):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that softmax is splitted into small ops
         self.assertTrue('softmax' not in fwd_ops)
 
@@ -128,7 +133,12 @@ class TestPrimForwardAndBackward(unittest.TestCase):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         all_ops = [
             op.type
             for op in net.forward.program_cache.last()[-1][-1]
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py
index 2fce19b3943..ad68e1195a9 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_gelu.py
@@ -77,6 +77,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
             net = apply_to_static(net, use_prim)
 
         res = []
+        self.x = data
         for _ in range(10):
             out = net(data)
             loss = paddle.mean(out)
@@ -92,7 +93,12 @@ class TestPrimForwardAndBackward(unittest.TestCase):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that gelu is splitted into small ops
         self.assertTrue('gelu' not in fwd_ops)
 
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
index 6460515c0a8..28aac57b2f5 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_layer_norm.py
@@ -89,7 +89,14 @@ class TestPrimForward(unittest.TestCase):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x, self.w, self.b)[
+                1
+            ]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that layer_norm is splitted into small ops
         self.assertTrue('layer_norm' not in fwd_ops)
 
@@ -150,7 +157,14 @@ class TestPrimForwardAndBackward(unittest.TestCase):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x, self.w, self.b)[
+                1
+            ]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that layer_norm is splitted into small ops
         self.assertTrue('layer_norm' not in fwd_ops)
 
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py
index e77388742af..ff18964f7a3 100644
--- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py
+++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_cinn_prim_mean.py
@@ -83,6 +83,7 @@ class TestPrimForward(unittest.TestCase):
             net = apply_to_static(net, use_prim)
 
         res = []
+        self.x = data
         for _ in range(10):
             out = net(data)
             loss = paddle.mean(out, axis, keep_dim)
@@ -99,7 +100,12 @@ class TestPrimForward(unittest.TestCase):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that reduce_mean is splitted into small ops
         self.assertTrue('reduce_mean' not in fwd_ops)
 
@@ -150,6 +156,7 @@ class TestPrimForwardAndBackward(unittest.TestCase):
             net = apply_to_static(net, use_prim)
 
         res = []
+        self.x = data
         for _ in range(10):
             out = net(data)
             loss = paddle.mean(out, axis, keep_dim)
@@ -166,7 +173,12 @@ class TestPrimForwardAndBackward(unittest.TestCase):
     def check_prim(self, net, use_prim):
         if not use_prim:
             return
-        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
+        fwd_ops = [
+            op.type
+            for op in net.forward.get_concrete_program(self.x)[1]
+            .train_program.block(0)
+            .ops
+        ]
         # Ensure that reduce_mean is splitted into small ops
         self.assertTrue('reduce_mean' not in fwd_ops)
 
diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py
index 595f97980f9..4afa8c1f90f 100644
--- a/python/paddle/jit/dy2static/partial_program.py
+++ b/python/paddle/jit/dy2static/partial_program.py
@@ -315,9 +315,7 @@ class PartialProgramLayer:
     def _create_forward_backward_train_program(self):
         whole_program = self._train_program
         # _, forward_end_op_index = self._infer_info('fp32', self._create_program)
-        forward_end_op_index = self._forward_end_index_map[
-            _hash_with_id(whole_program, self)
-        ]
+        forward_end_op_index = self.get_forward_end_op_idx(whole_program)
         assert forward_end_op_index >= 0
 
         return self._get_forward_backward_program_form(
@@ -438,11 +436,14 @@ class PartialProgramLayer:
     def _param_grad_names(self):
         return _param_grad_names(self._train_program.desc, self._params)
 
+    def get_forward_end_op_idx(self, program):
+        return self._forward_end_index_map[_hash_with_id(program, self)]
+
     @LazyInitialized
     def _out_grad_names(self):
         return _out_grad_names(
             self._train_program.desc,
-            self._create_program(is_infer_mode=True).desc.block(0).op_size(),
+            self.get_forward_end_op_idx(self._train_program),
             len(self._outputs.var_ids),
         )
 
@@ -642,6 +643,7 @@ class PartialProgramLayer:
             if isinstance(out, framework.Variable):
                 targets.append(program.global_block().var(out.name))
 
+        start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
         if targets:
             # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch.
             core.check_and_set_prim_all_enabled()
@@ -652,12 +654,11 @@ class PartialProgramLayer:
             program, start_idx = self._hooker.after_append_backward(
                 self, program, start_idx
             )
-        self._forward_end_index_map[
-            _hash_with_id(program, self)
-        ] = start_idx - len(self._outputs.tolist())
-        # TODO: prim make this complicate
         self.prepare_gradient_aggregation(start_idx, main_program, program)
 
+        self._forward_end_index_map[
+            _hash_with_id(program, self)
+        ] = start_idx - len(self._outputs.tolist())
         return program
 
     def _prune_unused_params(self, program):
@@ -1155,5 +1156,8 @@ def add_build_strategy_for(
         if hasattr(compiled_program._program, 'lr_sheduler'):
             builded_program.lr_sheduler = compiled_program._program.lr_sheduler
     else:
+        # can't just create a new program, we need copy the vardesc.
         builded_program = paddle.static.Program()
+        for var in program.block(0).vars.values():
+            builded_program.block(0)._clone_variable(var, False)
     return builded_program
diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py
index 69a6e004606..e201915310e 100644
--- a/python/paddle/jit/dy2static/program_translator.py
+++ b/python/paddle/jit/dy2static/program_translator.py
@@ -1226,7 +1226,6 @@ class ProgramCache:
             partial_program.set_hooker(PrimHooker())
         return concrete_program, partial_program
 
-
     def __getitem__(self, item):
         if not isinstance(item, CacheKey):
             raise ValueError(
diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py
index 34d628c1d35..b37ee05f9f0 100644
--- a/python/paddle/jit/dy2static/utils.py
+++ b/python/paddle/jit/dy2static/utils.py
@@ -1568,7 +1568,7 @@ def _out_grad_names(program_desc, fwd_end_op_index, out_size):
         min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
     ):
         op = program_desc.block(0).op(i)
-        if op.type() == 'fill_any_like':
+        if op.type() in ['fill_any_like', "fill_constant"]:
             var_name = op.output('Out')[0]
             names.append(var_name)
     return names
--
GitLab