From 42869ab66cba6dd87a783022669703b1297375b4 Mon Sep 17 00:00:00 2001
From: Charles-hit <56987902+Charles-hit@users.noreply.github.com>
Date: Tue, 5 Sep 2023 10:38:24 +0800
Subject: [PATCH] support more vjp code gen (#56890)

---
 .../fluid/operators/generator/tests_utils.py | 21 ++---
 paddle/fluid/primitive/codegen/gen.py        | 93 +++++++++++++++----
 test/prim/new_ir_prim/test_vjp_prim.py       | 22 +++--
 3 files changed, 99 insertions(+), 37 deletions(-)

diff --git a/paddle/fluid/operators/generator/tests_utils.py b/paddle/fluid/operators/generator/tests_utils.py
index 5f2e06eb89d..cfcf6743816 100644
--- a/paddle/fluid/operators/generator/tests_utils.py
+++ b/paddle/fluid/operators/generator/tests_utils.py
@@ -86,19 +86,18 @@ def is_tensor_list(s):
     return s == 'Tensor[]'
 
 
-def exist_mutable_attribute(attrs):
-    for attr in attrs:
+def exist_mutable_attribute(attributes):
+    for attribute in attributes:
         if (
-            attr['typename'] in ['Scalar', 'IntArray']
-            and attr['support_tensor'] is True
-        ):
+            is_scalar(attribute['typename'])
+            or is_intarray(attribute['typename'])
+        ) and attribute.get('support_tensor', False):
             return True
-        else:
-            return False
+    else:
+        return False
 
 
-def is_mutable_attribute(attr):
+def is_mutable_attribute(attribute):
     return (
-        attr['typename'] in ['Scalar', 'IntArray']
-        and attr['support_tensor'] is True
-    )
+        is_scalar(attribute['typename']) or is_intarray(attribute['typename'])
+    ) and attribute.get('support_tensor', False)
diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py
index 7b3072675ec..89ba4fe53cd 100644
--- a/paddle/fluid/primitive/codegen/gen.py
+++ b/paddle/fluid/primitive/codegen/gen.py
@@ -45,6 +45,22 @@ VJPS = [
     'sum_grad',
     'concat_grad',
     'split_grad',
+    'gelu_grad',
+    'softmax_grad',
+    'silu_grad',
+    'multiply_grad',
+    'subtract_grad',
+    'erf_grad',
+    'expand_grad',
+    'exp_grad',
+    'elementwise_pow_grad',
+    'fused_softmax_mask_upper_triangle_grad',
+    'matmul_grad',
+    'pow_grad',
+    'reshape_grad',
+    'rsqrt_grad',
+    'slice_grad',
+    'transpose_grad',
 ]
 VJP_COMPS = ['divide_grad', 'sum_grad']
 BACKENDS = [
@@ -68,6 +84,49 @@ BACKENDS = [
     'sum_grad',
     'concat_grad',
     'split_grad',
+    'gelu_grad',
+    'softmax_grad',
+    'silu_grad',
+    'multiply_grad',
+    'subtract_grad',
+    'erf_grad',
+    'expand_grad',
+    'exp_grad',
+    'multiply',
+    'exp',
+    'erf',
+    'cast',
+    'elementwise_pow_grad',
+    'fused_softmax_mask_upper_triangle_grad',
+    'matmul_grad',
+    'pow_grad',
+    'reshape_grad',
+    'rsqrt_grad',
+    'slice_grad',
+    'transpose_grad',
+    'subtract',
+    'assign',
+    'equal',
+    'greater_equal',
+    'greater_than',
+    'less_equal',
+    'less_than',
+    'matmul',
+    'max',
+    'maximum',
+    'minimum',
+    'not_equal',
+    'abs',
+    'bitwise_and',
+    'bitwise_not',
+    'bitwise_or',
+    'bitwise_xor',
+    'floor',
+    'gather_nd',
+    'log',
+    'roll',
+    'scatter',
+    'scatter_nd_add',
 ]
 
 
@@ -157,21 +216,6 @@ def save(content: str, path: pathlib.Path):
     print(f"Generate source file {path}")
 
 
-def filter_compat_info(items):
-    for item in items:
-        item['op'] = item['op'].split('(')[0].strip()
-        if 'backward' in item:
-            item_backwards = item['backward'].split(',')
-            for idx, item_backward in enumerate(item_backwards):
-                item_backward = item_backward.split('(')[0].strip()
-                item_backwards[idx] = item_backward
-            item['backward'] = (
-                ','.join(item_backwards)
-                if len(item_backwards) > 0
-                else item_backwards[0]
-            )
-
-
 def to_compat_dict(items: List[Dict]) -> Dict[str, Dict]:
     compat_dict = {}
     for item in items:
@@ -201,11 +245,28 @@ def get_inplace_api(apis):
     return inplace_apis
 
 
+def filter_compat_info(items):
+    for item in items:
+        item['op'] = item['op'].split('(')[0].strip()
+        if 'backward' in item:
+            item_backwards = item['backward'].split(',')
+            for idx, item_backward in enumerate(item_backwards):
+                item_backward = item_backward.split('(')[0].strip()
+                item_backwards[idx] = item_backward
+            item['backward'] = (
+                ','.join(item_backwards)
+                if len(item_backwards) > 0
+                else item_backwards[0]
+            )
+
+
 def extend_compat_info(apis, compats):
     for api in apis:
         attrs = api["attrs"]
         for attr in attrs:
-            if attr['typename'] in ["Scalar", "IntArray"]:
+            if op_gen_tests.is_scalar(
+                attr['typename']
+            ) or op_gen_tests.is_intarray(attr['typename']):
                 attr["support_tensor"] = False
     apis_dict = to_apis_dict(apis)
     for compat_item in compats:
diff --git a/test/prim/new_ir_prim/test_vjp_prim.py b/test/prim/new_ir_prim/test_vjp_prim.py
index 46ff348734d..2a29ae9f69f 100644
--- a/test/prim/new_ir_prim/test_vjp_prim.py
+++ b/test/prim/new_ir_prim/test_vjp_prim.py
@@ -21,7 +21,7 @@ from paddle.fluid.core import call_vjp
 paddle.enable_static()
 
 
-def get_ir_program_0():
+def get_ir_divide_program():
     main_program, start_program = (
         paddle.static.Program(),
         paddle.static.Program(),
@@ -42,7 +42,7 @@ def get_ir_program_0():
     return newir_program
 
 
-def get_ir_program_1():
+def get_ir_sum_program():
     main_program, start_program = (
         paddle.static.Program(),
         paddle.static.Program(),
@@ -61,8 +61,8 @@
 
 class TestVjpPrim(unittest.TestCase):
     def test_divide_grad_prim_case1(self):
-        newir_program = get_ir_program_0()
-        paddle.fluid.core._set_prim_backward_enabled(True)
+        newir_program = get_ir_divide_program()
+        paddle.framework.core._set_prim_backward_enabled(True)
         dout = newir_program.block().ops[-2].result(0)
         out_grads = [[dout]]
         stop_gradients = [[False], [False]]
@@ -100,10 +100,11 @@ class TestVjpPrim(unittest.TestCase):
         ]
         for idx, op in enumerate(newir_program.block().ops):
             self.assertEqual(op.name(), all_op_names[idx])
+        paddle.framework.core._set_prim_backward_enabled(False)
 
     def test_divide_grad_no_prim(self):
-        newir_program = get_ir_program_0()
-        paddle.fluid.core._set_prim_backward_enabled(False)
+        newir_program = get_ir_divide_program()
+        paddle.framework.core._set_prim_backward_enabled(False)
         dout = newir_program.block().ops[-2].result(0)
         out_grads = [[dout]]
         stop_gradients = [[False], [False]]
@@ -120,8 +121,8 @@
         self.assertEqual(len(newir_program.block().ops), 5)
 
     def test_sum_grad_prim(self):
-        newir_program = get_ir_program_1()
-        paddle.fluid.core._set_prim_backward_enabled(True)
+        newir_program = get_ir_sum_program()
+        paddle.framework.core._set_prim_backward_enabled(True)
         dout = newir_program.block().ops[-3].result(0)
         out_grads = [[dout]]
         stop_gradients = [[False], [True]]
@@ -145,10 +146,11 @@
         ]
         for idx, op in enumerate(newir_program.block().ops):
             self.assertEqual(op.name(), all_op_names[idx])
+        paddle.framework.core._set_prim_backward_enabled(False)
 
     def test_sum_grad_no_prim(self):
-        newir_program = get_ir_program_1()
-        paddle.fluid.core._set_prim_backward_enabled(False)
+        newir_program = get_ir_sum_program()
+        paddle.framework.core._set_prim_backward_enabled(False)
        dout = newir_program.block().ops[-2].result(0)
         out_grads = [[dout]]
         stop_gradients = [[False], [True]]
-- 
GitLab
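Note: the tests_utils.py change above makes the mutable-attribute helpers tolerate op attribute configs that omit the support_tensor key (via dict.get) and fixes exist_mutable_attribute so it only reports False after scanning every attribute. Below is a minimal standalone sketch of that behavior; is_scalar and is_intarray are inlined stand-ins purely for illustration (the real helpers live elsewhere in tests_utils.py), and the sample attrs list is hypothetical.

def is_scalar(typename):
    # Stand-in for the real helper defined in tests_utils.py.
    return typename == 'Scalar'

def is_intarray(typename):
    # Stand-in for the real helper defined in tests_utils.py.
    return typename == 'IntArray'

def is_mutable_attribute(attribute):
    # Mirrors the patched helper: a Scalar/IntArray attribute is "mutable" only
    # when support_tensor is present and truthy; a missing key counts as False.
    return (
        is_scalar(attribute['typename']) or is_intarray(attribute['typename'])
    ) and attribute.get('support_tensor', False)

def exist_mutable_attribute(attributes):
    # Mirrors the patched helper: keep scanning instead of returning False on
    # the first non-mutable attribute, as the pre-patch version did.
    for attribute in attributes:
        if is_mutable_attribute(attribute):
            return True
    else:
        return False

attrs = [
    {'typename': 'int'},                               # no support_tensor key
    {'typename': 'IntArray', 'support_tensor': True},  # mutable attribute
]
print(is_mutable_attribute(attrs[0]))     # False
print(exist_mutable_attribute(attrs))     # True (pre-patch code returned False here)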