Unverified · Commit af6324aa authored by Xianduo Li, committed by GitHub

Polish code and add code comments (#57046)

* add comments of member <with_grad_semantic>

* polish code

* polish code comments
Parent bc601f58
@@ -722,14 +722,14 @@ def to_pascal_case(s):
     return "".join([word.capitalize() for word in words]) + ""
 
 
-def OpInputGradSemanticCheck(op_info, op_info_items):
-    input_grad_semantic_list = []
+def get_input_grad_semantic(op_info, op_info_items):
+    input_grad_semantics = []
     num_inputs = len(op_info.input_name_list)
 
     # get backward op
     bwd_op_name = op_info.backward_name
     if (bwd_op_name is None) or (bwd_op_name not in op_info_items.keys()):
-        input_grad_semantic_list = ["false" for i in range(num_inputs)]
+        input_grad_semantics = ["false" for i in range(num_inputs)]
     else:
         bwd_op_info = op_info_items[bwd_op_name]
@@ -747,23 +747,23 @@ def OpInputGradSemanticCheck(op_info, op_info_items):
         ), "Configuration of forward op and backward op is not match."
         for i in range(num_inputs):
             if bwd_fwd_input_list[i] in bwd_output_list_new:
-                input_grad_semantic_list.append("true")
+                input_grad_semantics.append("true")
             else:
-                input_grad_semantic_list.append("false")
+                input_grad_semantics.append("false")
     else:
-        input_grad_semantic_list = ["false" for i in range(num_inputs)]
+        input_grad_semantics = ["false" for i in range(num_inputs)]
 
-    return input_grad_semantic_list
+    return input_grad_semantics
 
 
-def OpMutableAttributeGradSemanticCheck(op_info, op_info_items):
-    mutable_attribute_grad_semantic_list = []
+def get_mutable_attribute_grad_semantic(op_info, op_info_items):
+    mutable_attribute_grad_semantics = []
     fwd_mutable_attribute_list = op_info.mutable_attribute_name_list
 
     # get backward op
     bwd_op_name = op_info.backward_name
     if (bwd_op_name is None) or (bwd_op_name not in op_info_items.keys()):
-        mutable_attribute_grad_semantic_list = [
+        mutable_attribute_grad_semantics = [
             "false" for i in range(len(fwd_mutable_attribute_list))
         ]
     else:
@@ -778,11 +778,11 @@ def OpMutableAttributeGradSemanticCheck(op_info, op_info_items):
         for i in range(len(fwd_mutable_attribute_list)):
             if fwd_mutable_attribute_list[i] in bwd_output_list_new:
-                mutable_attribute_grad_semantic_list.append("true")
+                mutable_attribute_grad_semantics.append("true")
             else:
-                mutable_attribute_grad_semantic_list.append("false")
+                mutable_attribute_grad_semantics.append("false")
 
-    return mutable_attribute_grad_semantic_list
+    return mutable_attribute_grad_semantics
 
 
 def OpGenerator(
@@ -874,12 +874,10 @@ def OpGenerator(
     exclusive_interface_str = gen_exclusive_interface_str(op_info)
 
     # check op inputs and mutable_attributes grad semantics
-    input_grad_semantic_list = OpInputGradSemanticCheck(
-        op_info, op_info_items
-    )
-    mutable_attribute_grad_semantic_list = (
-        OpMutableAttributeGradSemanticCheck(op_info, op_info_items)
-    )
+    input_grad_semantics = get_input_grad_semantic(op_info, op_info_items)
+    mutable_attribute_grad_semantics = get_mutable_attribute_grad_semantic(
+        op_info, op_info_items
+    )
 
     # If op has inplace info, we will generate inplace op and non-inplace op.
     for op_name in op_info.op_phi_name:
@@ -1059,7 +1057,7 @@ def OpGenerator(
                         optional=op_input_optional_list[idx],
                         no_need_buffer=op_input_no_need_buffer_list[idx],
                         is_mutable_attribute='false',
-                        with_grad_semantic=input_grad_semantic_list[idx],
+                        with_grad_semantic=input_grad_semantics[idx],
                     )
                 )
             for idx in range(len(op_mutable_attribute_name_list)):
@@ -1070,7 +1068,7 @@ def OpGenerator(
                         optional='false',
                         no_need_buffer='false',
                         is_mutable_attribute='true',
-                        with_grad_semantic=mutable_attribute_grad_semantic_list[
-                            idx
-                        ],
+                        with_grad_semantic=mutable_attribute_grad_semantics[
+                            idx
+                        ],
                     )
...
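The two renamed helpers share one pattern: a forward input (or mutable attribute) has gradient semantics exactly when the backward op exposes a matching grad output. Since the diff elides how bwd_fwd_input_list and bwd_output_list_new are built, the following is a minimal standalone sketch of that pattern; the helper name grad_semantics and the "<name>_grad" suffix convention are illustrative assumptions, not the generator's actual code.

# Minimal sketch of the check performed by get_input_grad_semantic and
# get_mutable_attribute_grad_semantic. Assumption: a backward output named
# "<name>_grad" marks the forward input <name> as having gradient semantics.
def grad_semantics(fwd_names, bwd_output_names):
    suffix = "_grad"
    # Strip the suffix so backward outputs map back to forward names,
    # playing the role of bwd_output_list_new in the generator.
    bwd_output_list_new = [
        name[: -len(suffix)] for name in bwd_output_names if name.endswith(suffix)
    ]
    # "true"/"false" strings, because the generator splices them into C++ code.
    return ["true" if name in bwd_output_list_new else "false" for name in fwd_names]

# gather's forward inputs are (x, index, axis); gather_grad only produces x_grad.
print(grad_semantics(["x", "index", "axis"], ["x_grad"]))  # ['true', 'false', 'false']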
@@ -27,7 +27,17 @@ struct OpInputInfo {
   bool optional = false;
   bool no_need_buffer = false;
   bool is_mutable_attribute = false;
+  /***
+   * "with_grad_semantic" represents whether the input of the OP has gradient
+   * semantics. For example, gather op contains three inputs (x, index, axis),
+   * but the backward op gather_grad calculates only the gradient with respect
+   * to x. Therefore, for gather op, only x has gradient semantics.
+   * The "with_grad_semantic" field in OpInputInfo for x is true, and the
+   * "with_grad_semantic" fields in OpInputInfo for index and axis are
+   * both false.
+   */
   bool with_grad_semantic = true;
 
   OpInputInfo() = default;
   OpInputInfo(const OpInputInfo& input_info) = default;
...
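To make the new doc comment concrete, here is a hypothetical Python mirror of the C++ struct above, populated for gather's three inputs; the name field and constructor shape are illustrative assumptions, not Paddle's API. The test file below verifies the same triple through get_input_grad_semantics().

from dataclasses import dataclass

# Hypothetical mirror of OpInputInfo, showing the field values that the
# gather example in the comment describes.
@dataclass
class OpInputInfoSketch:
    name: str
    optional: bool = False
    no_need_buffer: bool = False
    is_mutable_attribute: bool = False
    with_grad_semantic: bool = True

gather_inputs = [
    OpInputInfoSketch("x", with_grad_semantic=True),       # gather_grad produces x_grad
    OpInputInfoSketch("index", with_grad_semantic=False),  # no index_grad output
    OpInputInfoSketch("axis", with_grad_semantic=False),   # no axis_grad output
]
assert [info.with_grad_semantic for info in gather_inputs] == [True, False, False]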
@@ -20,7 +20,7 @@ from paddle import ir
 paddle.enable_static()
 
 
-def get_ir_program_0():
+def get_gather_program_new_ir():
     main_program, start_program = (
         paddle.static.Program(),
         paddle.static.Program(),
@@ -36,7 +36,7 @@ def get_ir_program_0():
     return newir_program
 
 
-def get_ir_program_1():
+def get_multiply_program_new_ir():
     main_program, start_program = (
         paddle.static.Program(),
         paddle.static.Program(),
@@ -54,15 +54,17 @@ def get_ir_program_1():
 class TestOpInputGradSemantic(unittest.TestCase):
-    def test_gatherop_input_grad_semantic(self):
-        newir_program = get_ir_program_0()
-        op = newir_program.block().ops[-1]
-        self.assertEqual(op.get_input_grad_semantics(), [True, False, False])
+    def test_gather_op_input_grad_semantic(self):
+        newir_program = get_gather_program_new_ir()
+        gather_op = newir_program.block().ops[-1]
+        self.assertEqual(
+            gather_op.get_input_grad_semantics(), [True, False, False]
+        )
 
-    def test_multiplyop_input_grad_semantic(self):
-        newir_program = get_ir_program_1()
-        op = newir_program.block().ops[-1]
-        self.assertEqual(op.get_input_grad_semantics(), [True, True])
+    def test_multiply_op_input_grad_semantic(self):
+        newir_program = get_multiply_program_new_ir()
+        multiply_op = newir_program.block().ops[-1]
+        self.assertEqual(multiply_op.get_input_grad_semantics(), [True, True])
 
 
 if __name__ == "__main__":
...