Unverified commit 42869ab6, authored by Charles-hit, committed by GitHub

support more vjp code gen (#56890)

Parent commit: 41acf19b
@@ -86,19 +86,18 @@ def is_tensor_list(s):
     return s == 'Tensor[]'
 
 
-def exist_mutable_attribute(attrs):
-    for attr in attrs:
+def exist_mutable_attribute(attributes):
+    for attribute in attributes:
         if (
-            attr['typename'] in ['Scalar', 'IntArray']
-            and attr['support_tensor'] is True
-        ):
+            is_scalar(attribute['typename'])
+            or is_intarray(attribute['typename'])
+        ) and attribute.get('support_tensor', False):
             return True
         else:
             return False
 
 
-def is_mutable_attribute(attr):
+def is_mutable_attribute(attribute):
     return (
-        attr['typename'] in ['Scalar', 'IntArray']
-        and attr['support_tensor'] is True
-    )
+        is_scalar(attribute['typename']) or is_intarray(attribute['typename'])
+    ) and attribute.get('support_tensor', False)
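For readers of this hunk, a minimal standalone sketch of the refactored predicate. The is_scalar/is_intarray helpers are assumed here to be plain typename checks (their real definitions live elsewhere in this generator), and the sample attributes are made up:

def is_scalar(typename):
    # Assumed shape of the helper: a plain typename comparison.
    return typename == 'Scalar'

def is_intarray(typename):
    return typename == 'IntArray'

def is_mutable_attribute(attribute):
    # A Scalar/IntArray attribute is "mutable" only when it can also be fed
    # as a Tensor; a missing 'support_tensor' key now defaults to False
    # instead of raising KeyError as the old indexing did.
    return (
        is_scalar(attribute['typename']) or is_intarray(attribute['typename'])
    ) and attribute.get('support_tensor', False)

print(is_mutable_attribute({'typename': 'IntArray', 'support_tensor': True}))  # True
print(is_mutable_attribute({'typename': 'Scalar'}))                            # False, no KeyError
print(is_mutable_attribute({'typename': 'float', 'support_tensor': True}))     # False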
@@ -45,6 +45,22 @@ VJPS = [
     'sum_grad',
     'concat_grad',
     'split_grad',
+    'gelu_grad',
+    'softmax_grad',
+    'silu_grad',
+    'multiply_grad',
+    'subtract_grad',
+    'erf_grad',
+    'expand_grad',
+    'exp_grad',
+    'elementwise_pow_grad',
+    'fused_softmax_mask_upper_triangle_grad',
+    'matmul_grad',
+    'pow_grad',
+    'reshape_grad',
+    'rsqrt_grad',
+    'slice_grad',
+    'transpose_grad',
 ]
 VJP_COMPS = ['divide_grad', 'sum_grad']
 BACKENDS = [
@@ -68,6 +84,49 @@ BACKENDS = [
     'sum_grad',
     'concat_grad',
     'split_grad',
+    'gelu_grad',
+    'softmax_grad',
+    'silu_grad',
+    'multiply_grad',
+    'subtract_grad',
+    'erf_grad',
+    'expand_grad',
+    'exp_grad',
+    'multiply',
+    'exp',
+    'erf',
+    'cast',
+    'elementwise_pow_grad',
+    'fused_softmax_mask_upper_triangle_grad',
+    'matmul_grad',
+    'pow_grad',
+    'reshape_grad',
+    'rsqrt_grad',
+    'slice_grad',
+    'transpose_grad',
+    'subtract',
+    'assign',
+    'equal',
+    'greater_equal',
+    'greater_than',
+    'less_equal',
+    'less_than',
+    'matmul',
+    'max',
+    'maximum',
+    'minimum',
+    'not_equal',
+    'abs',
+    'bitwise_and',
+    'bitwise_not',
+    'bitwise_or',
+    'bitwise_xor',
+    'floor',
+    'gather_nd',
+    'log',
+    'roll',
+    'scatter',
+    'scatter_nd_add',
 ]
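These allowlists drive the new-IR code generation: VJPS names the grad ops that get a generated vjp interface, VJP_COMPS the ones with composite (primitive) rules, and BACKENDS appears to collect every op the generated code needs a backend binding for. A purely illustrative consistency check over the entries added in this commit (abbreviated copies of the lists, not the generator's real data), showing that each newly listed grad op also shows up among the BACKENDS additions:

# Abbreviated copies of the additions in this hunk, for illustration only.
VJPS_ADDED = [
    'gelu_grad', 'softmax_grad', 'silu_grad', 'multiply_grad',
    'subtract_grad', 'erf_grad', 'expand_grad', 'exp_grad',
    'elementwise_pow_grad', 'fused_softmax_mask_upper_triangle_grad',
    'matmul_grad', 'pow_grad', 'reshape_grad', 'rsqrt_grad',
    'slice_grad', 'transpose_grad',
]
BACKENDS_ADDED = [
    'gelu_grad', 'softmax_grad', 'silu_grad', 'multiply_grad',
    'subtract_grad', 'erf_grad', 'expand_grad', 'exp_grad',
    'multiply', 'exp', 'erf', 'cast',
    'elementwise_pow_grad', 'fused_softmax_mask_upper_triangle_grad',
    'matmul_grad', 'pow_grad', 'reshape_grad', 'rsqrt_grad',
    'slice_grad', 'transpose_grad',
    'subtract', 'assign', 'equal', 'greater_equal', 'greater_than',
    'less_equal', 'less_than', 'matmul', 'max', 'maximum', 'minimum',
    'not_equal', 'abs', 'bitwise_and', 'bitwise_not', 'bitwise_or',
    'bitwise_xor', 'floor', 'gather_nd', 'log', 'roll', 'scatter',
    'scatter_nd_add',
]
# Every newly supported grad op has a matching backend entry; the remaining
# BACKENDS additions are forward/primitive ops used by the generated code.
assert set(VJPS_ADDED) <= set(BACKENDS_ADDED)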
@@ -157,21 +216,6 @@ def save(content: str, path: pathlib.Path):
     print(f"Generate source file {path}")
 
 
-def filter_compat_info(items):
-    for item in items:
-        item['op'] = item['op'].split('(')[0].strip()
-        if 'backward' in item:
-            item_backwards = item['backward'].split(',')
-            for idx, item_backward in enumerate(item_backwards):
-                item_backward = item_backward.split('(')[0].strip()
-                item_backwards[idx] = item_backward
-            item['backward'] = (
-                ','.join(item_backwards)
-                if len(item_backwards) > 0
-                else item_backwards[0]
-            )
-
-
 def to_compat_dict(items: List[Dict]) -> Dict[str, Dict]:
     compat_dict = {}
     for item in items:
@@ -201,11 +245,28 @@ def get_inplace_api(apis):
     return inplace_apis
 
 
+def filter_compat_info(items):
+    for item in items:
+        item['op'] = item['op'].split('(')[0].strip()
+        if 'backward' in item:
+            item_backwards = item['backward'].split(',')
+            for idx, item_backward in enumerate(item_backwards):
+                item_backward = item_backward.split('(')[0].strip()
+                item_backwards[idx] = item_backward
+            item['backward'] = (
+                ','.join(item_backwards)
+                if len(item_backwards) > 0
+                else item_backwards[0]
+            )
+
+
 def extend_compat_info(apis, compats):
     for api in apis:
         attrs = api["attrs"]
         for attr in attrs:
-            if attr['typename'] in ["Scalar", "IntArray"]:
+            if op_gen_tests.is_scalar(
+                attr['typename']
+            ) or op_gen_tests.is_intarray(attr['typename']):
                 attr["support_tensor"] = False
     apis_dict = to_apis_dict(apis)
     for compat_item in compats:
...
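filter_compat_info, relocated here after get_inplace_api, normalizes compat entries by stripping the "(...)" signature suffix from the forward op name and from each comma-separated backward name. A self-contained run of the function as it appears in the diff, on a made-up compat entry (the entry itself is illustrative):

# Copied from the hunk above, with comments added; behaviour unchanged.
def filter_compat_info(items):
    for item in items:
        # "'op_name (legacy_name)'" -> "'op_name'": keep only the bare op name.
        item['op'] = item['op'].split('(')[0].strip()
        if 'backward' in item:
            # Do the same for every comma-separated backward op name.
            item_backwards = item['backward'].split(',')
            for idx, item_backward in enumerate(item_backwards):
                item_backward = item_backward.split('(')[0].strip()
                item_backwards[idx] = item_backward
            item['backward'] = (
                ','.join(item_backwards)
                if len(item_backwards) > 0
                else item_backwards[0]
            )

# Made-up entry for illustration.
items = [{'op': 'matmul (matmul_v2)',
          'backward': 'matmul_grad (matmul_v2_grad), matmul_double_grad'}]
filter_compat_info(items)
print(items)
# [{'op': 'matmul', 'backward': 'matmul_grad,matmul_double_grad'}]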
@@ -21,7 +21,7 @@ from paddle.fluid.core import call_vjp
 paddle.enable_static()
 
 
-def get_ir_program_0():
+def get_ir_divide_program():
     main_program, start_program = (
         paddle.static.Program(),
         paddle.static.Program(),
@@ -42,7 +42,7 @@ def get_ir_program_0():
     return newir_program
 
 
-def get_ir_program_1():
+def get_ir_sum_program():
     main_program, start_program = (
         paddle.static.Program(),
         paddle.static.Program(),
@@ -61,8 +61,8 @@ def get_ir_program_1():
 
 class TestVjpPrim(unittest.TestCase):
     def test_divide_grad_prim_case1(self):
-        newir_program = get_ir_program_0()
-        paddle.fluid.core._set_prim_backward_enabled(True)
+        newir_program = get_ir_divide_program()
+        paddle.framework.core._set_prim_backward_enabled(True)
         dout = newir_program.block().ops[-2].result(0)
         out_grads = [[dout]]
         stop_gradients = [[False], [False]]
@@ -100,10 +100,11 @@ class TestVjpPrim(unittest.TestCase):
         ]
         for idx, op in enumerate(newir_program.block().ops):
             self.assertEqual(op.name(), all_op_names[idx])
+        paddle.framework.core._set_prim_backward_enabled(False)
 
     def test_divide_grad_no_prim(self):
-        newir_program = get_ir_program_0()
-        paddle.fluid.core._set_prim_backward_enabled(False)
+        newir_program = get_ir_divide_program()
+        paddle.framework.core._set_prim_backward_enabled(False)
         dout = newir_program.block().ops[-2].result(0)
         out_grads = [[dout]]
         stop_gradients = [[False], [False]]
@@ -120,8 +121,8 @@ class TestVjpPrim(unittest.TestCase):
         self.assertEqual(len(newir_program.block().ops), 5)
 
     def test_sum_grad_prim(self):
-        newir_program = get_ir_program_1()
-        paddle.fluid.core._set_prim_backward_enabled(True)
+        newir_program = get_ir_sum_program()
+        paddle.framework.core._set_prim_backward_enabled(True)
        dout = newir_program.block().ops[-3].result(0)
         out_grads = [[dout]]
         stop_gradients = [[False], [True]]
@@ -145,10 +146,11 @@ class TestVjpPrim(unittest.TestCase):
         ]
         for idx, op in enumerate(newir_program.block().ops):
             self.assertEqual(op.name(), all_op_names[idx])
+        paddle.framework.core._set_prim_backward_enabled(False)
 
     def test_sum_grad_no_prim(self):
-        newir_program = get_ir_program_1()
-        paddle.fluid.core._set_prim_backward_enabled(False)
+        newir_program = get_ir_sum_program()
+        paddle.framework.core._set_prim_backward_enabled(False)
         dout = newir_program.block().ops[-2].result(0)
         out_grads = [[dout]]
         stop_gradients = [[False], [True]]
...
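Beyond the descriptive renames, the tests now toggle the flag through paddle.framework.core instead of paddle.fluid.core and explicitly reset it to False at the end of each prim case so later tests are unaffected. A small sketch of how that enable/reset pattern could be wrapped up; this context manager is not part of the patch, only an illustration of the same idea:

from contextlib import contextmanager

import paddle

@contextmanager
def prim_backward(flag=True):
    """Enable or disable composite (prim) backward, then reset it to False."""
    paddle.framework.core._set_prim_backward_enabled(flag)
    try:
        yield
    finally:
        # Mirrors the explicit reset the patch adds at the end of each prim test.
        paddle.framework.core._set_prim_backward_enabled(False)

# Usage inside a test, mirroring test_divide_grad_prim_case1:
#     with prim_backward(True):
#         grad_outs = call_vjp(...)  # arguments elided; see the tests above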