未验证 提交 42869ab6 编写于 作者: C Charles-hit 提交者: GitHub

support more vjp code gen (#56890)

上级 41acf19b
......@@ -86,19 +86,18 @@ def is_tensor_list(s):
return s == 'Tensor[]'
def exist_mutable_attribute(attrs):
for attr in attrs:
def exist_mutable_attribute(attributes):
for attribute in attributes:
if (
attr['typename'] in ['Scalar', 'IntArray']
and attr['support_tensor'] is True
):
is_scalar(attribute['typename'])
or is_intarray(attribute['typename'])
) and attribute.get('support_tensor', False):
return True
else:
return False
else:
return False
def is_mutable_attribute(attr):
def is_mutable_attribute(attribute):
return (
attr['typename'] in ['Scalar', 'IntArray']
and attr['support_tensor'] is True
)
is_scalar(attribute['typename']) or is_intarray(attribute['typename'])
) and attribute.get('support_tensor', False)
......@@ -45,6 +45,22 @@ VJPS = [
'sum_grad',
'concat_grad',
'split_grad',
'gelu_grad',
'softmax_grad',
'silu_grad',
'multiply_grad',
'subtract_grad',
'erf_grad',
'expand_grad',
'exp_grad',
'elementwise_pow_grad',
'fused_softmax_mask_upper_triangle_grad',
'matmul_grad',
'pow_grad',
'reshape_grad',
'rsqrt_grad',
'slice_grad',
'transpose_grad',
]
VJP_COMPS = ['divide_grad', 'sum_grad']
BACKENDS = [
......@@ -68,6 +84,49 @@ BACKENDS = [
'sum_grad',
'concat_grad',
'split_grad',
'gelu_grad',
'softmax_grad',
'silu_grad',
'multiply_grad',
'subtract_grad',
'erf_grad',
'expand_grad',
'exp_grad',
'multiply',
'exp',
'erf',
'cast',
'elementwise_pow_grad',
'fused_softmax_mask_upper_triangle_grad',
'matmul_grad',
'pow_grad',
'reshape_grad',
'rsqrt_grad',
'slice_grad',
'transpose_grad',
'subtract',
'assign',
'equal',
'greater_equal',
'greater_than',
'less_equal',
'less_than',
'matmul',
'max',
'maximum',
'minimum',
'not_equal',
'abs',
'bitwise_and',
'bitwise_not',
'bitwise_or',
'bitwise_xor',
'floor',
'gather_nd',
'log',
'roll',
'scatter',
'scatter_nd_add',
]
......@@ -157,21 +216,6 @@ def save(content: str, path: pathlib.Path):
print(f"Generate source file {path}")
def filter_compat_info(items):
for item in items:
item['op'] = item['op'].split('(')[0].strip()
if 'backward' in item:
item_backwards = item['backward'].split(',')
for idx, item_backward in enumerate(item_backwards):
item_backward = item_backward.split('(')[0].strip()
item_backwards[idx] = item_backward
item['backward'] = (
','.join(item_backwards)
if len(item_backwards) > 0
else item_backwards[0]
)
def to_compat_dict(items: List[Dict]) -> Dict[str, Dict]:
compat_dict = {}
for item in items:
......@@ -201,11 +245,28 @@ def get_inplace_api(apis):
return inplace_apis
def filter_compat_info(items):
for item in items:
item['op'] = item['op'].split('(')[0].strip()
if 'backward' in item:
item_backwards = item['backward'].split(',')
for idx, item_backward in enumerate(item_backwards):
item_backward = item_backward.split('(')[0].strip()
item_backwards[idx] = item_backward
item['backward'] = (
','.join(item_backwards)
if len(item_backwards) > 0
else item_backwards[0]
)
def extend_compat_info(apis, compats):
for api in apis:
attrs = api["attrs"]
for attr in attrs:
if attr['typename'] in ["Scalar", "IntArray"]:
if op_gen_tests.is_scalar(
attr['typename']
) or op_gen_tests.is_intarray(attr['typename']):
attr["support_tensor"] = False
apis_dict = to_apis_dict(apis)
for compat_item in compats:
......
......@@ -21,7 +21,7 @@ from paddle.fluid.core import call_vjp
paddle.enable_static()
def get_ir_program_0():
def get_ir_divide_program():
main_program, start_program = (
paddle.static.Program(),
paddle.static.Program(),
......@@ -42,7 +42,7 @@ def get_ir_program_0():
return newir_program
def get_ir_program_1():
def get_ir_sum_program():
main_program, start_program = (
paddle.static.Program(),
paddle.static.Program(),
......@@ -61,8 +61,8 @@ def get_ir_program_1():
class TestVjpPrim(unittest.TestCase):
def test_divide_grad_prim_case1(self):
newir_program = get_ir_program_0()
paddle.fluid.core._set_prim_backward_enabled(True)
newir_program = get_ir_divide_program()
paddle.framework.core._set_prim_backward_enabled(True)
dout = newir_program.block().ops[-2].result(0)
out_grads = [[dout]]
stop_gradients = [[False], [False]]
......@@ -100,10 +100,11 @@ class TestVjpPrim(unittest.TestCase):
]
for idx, op in enumerate(newir_program.block().ops):
self.assertEqual(op.name(), all_op_names[idx])
paddle.framework.core._set_prim_backward_enabled(False)
def test_divide_grad_no_prim(self):
newir_program = get_ir_program_0()
paddle.fluid.core._set_prim_backward_enabled(False)
newir_program = get_ir_divide_program()
paddle.framework.core._set_prim_backward_enabled(False)
dout = newir_program.block().ops[-2].result(0)
out_grads = [[dout]]
stop_gradients = [[False], [False]]
......@@ -120,8 +121,8 @@ class TestVjpPrim(unittest.TestCase):
self.assertEqual(len(newir_program.block().ops), 5)
def test_sum_grad_prim(self):
newir_program = get_ir_program_1()
paddle.fluid.core._set_prim_backward_enabled(True)
newir_program = get_ir_sum_program()
paddle.framework.core._set_prim_backward_enabled(True)
dout = newir_program.block().ops[-3].result(0)
out_grads = [[dout]]
stop_gradients = [[False], [True]]
......@@ -145,10 +146,11 @@ class TestVjpPrim(unittest.TestCase):
]
for idx, op in enumerate(newir_program.block().ops):
self.assertEqual(op.name(), all_op_names[idx])
paddle.framework.core._set_prim_backward_enabled(False)
def test_sum_grad_no_prim(self):
newir_program = get_ir_program_1()
paddle.fluid.core._set_prim_backward_enabled(False)
newir_program = get_ir_sum_program()
paddle.framework.core._set_prim_backward_enabled(False)
dout = newir_program.block().ops[-2].result(0)
out_grads = [[dout]]
stop_gradients = [[False], [True]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册