Unverified · Commit af6324aa authored by Xianduo Li, committed by GitHub

Polish code and add code comments (#57046)

* add comments of member <with_grad_semantic>

* polish code

* polish code comments
Parent bc601f58
@@ -722,14 +722,14 @@ def to_pascal_case(s):
     return "".join([word.capitalize() for word in words]) + ""


-def OpInputGradSemanticCheck(op_info, op_info_items):
-    input_grad_semantic_list = []
+def get_input_grad_semantic(op_info, op_info_items):
+    input_grad_semantics = []
     num_inputs = len(op_info.input_name_list)

     # get backward op
     bwd_op_name = op_info.backward_name
     if (bwd_op_name is None) or (bwd_op_name not in op_info_items.keys()):
-        input_grad_semantic_list = ["false" for i in range(num_inputs)]
+        input_grad_semantics = ["false" for i in range(num_inputs)]
     else:
         bwd_op_info = op_info_items[bwd_op_name]
@@ -747,23 +747,23 @@ def OpInputGradSemanticCheck(op_info, op_info_items):
             ), "Configuration of forward op and backward op is not match."

             for i in range(num_inputs):
                 if bwd_fwd_input_list[i] in bwd_output_list_new:
-                    input_grad_semantic_list.append("true")
+                    input_grad_semantics.append("true")
                 else:
-                    input_grad_semantic_list.append("false")
+                    input_grad_semantics.append("false")
         else:
-            input_grad_semantic_list = ["false" for i in range(num_inputs)]
+            input_grad_semantics = ["false" for i in range(num_inputs)]

-    return input_grad_semantic_list
+    return input_grad_semantics


-def OpMutableAttributeGradSemanticCheck(op_info, op_info_items):
-    mutable_attribute_grad_semantic_list = []
+def get_mutable_attribute_grad_semantic(op_info, op_info_items):
+    mutable_attribute_grad_semantics = []
     fwd_mutable_attribute_list = op_info.mutable_attribute_name_list

     # get backward op
     bwd_op_name = op_info.backward_name
     if (bwd_op_name is None) or (bwd_op_name not in op_info_items.keys()):
-        mutable_attribute_grad_semantic_list = [
+        mutable_attribute_grad_semantics = [
             "false" for i in range(len(fwd_mutable_attribute_list))
         ]
     else:
@@ -778,11 +778,11 @@ def OpMutableAttributeGradSemanticCheck(op_info, op_info_items):

         for i in range(len(fwd_mutable_attribute_list)):
             if fwd_mutable_attribute_list[i] in bwd_output_list_new:
-                mutable_attribute_grad_semantic_list.append("true")
+                mutable_attribute_grad_semantics.append("true")
             else:
-                mutable_attribute_grad_semantic_list.append("false")
+                mutable_attribute_grad_semantics.append("false")

-    return mutable_attribute_grad_semantic_list
+    return mutable_attribute_grad_semantics


 def OpGenerator(
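
Taken together, the two renamed helpers implement a single rule: a forward input (or mutable attribute) has gradient semantics exactly when the backward op emits a matching "<name>_grad" output. The standalone sketch below paraphrases that rule; it is not the committed code — several lines are elided by the hunk headers above, and both the `_grad`-suffix stripping behind `bwd_output_list_new` and the `forward_input_name_list` field are inferred from the visible identifiers.

    def get_input_grad_semantic_sketch(op_info, op_info_items):
        # "true" for each forward input whose gradient the backward op produces.
        num_inputs = len(op_info.input_name_list)
        bwd_op_name = op_info.backward_name
        if (bwd_op_name is None) or (bwd_op_name not in op_info_items):
            return ["false"] * num_inputs
        bwd_op_info = op_info_items[bwd_op_name]
        # Backward outputs follow the "<name>_grad" convention; strip the
        # suffix so they compare against forward input names (inferred step).
        bwd_output_list_new = [
            name[: -len("_grad")] for name in bwd_op_info.output_name_list
        ]
        # forward_input_name_list is an assumed field name, inferred from the
        # visible local variable bwd_fwd_input_list.
        bwd_fwd_input_list = bwd_op_info.forward_input_name_list
        if bwd_fwd_input_list is None:
            return ["false"] * num_inputs
        assert len(bwd_fwd_input_list) == num_inputs
        return [
            "true" if name in bwd_output_list_new else "false"
            for name in bwd_fwd_input_list
        ]

    def get_mutable_attribute_grad_semantic_sketch(op_info, op_info_items):
        # Same test, keyed on the mutable attributes instead of the inputs.
        attrs = op_info.mutable_attribute_name_list
        bwd_op_name = op_info.backward_name
        if (bwd_op_name is None) or (bwd_op_name not in op_info_items):
            return ["false"] * len(attrs)
        bwd_op_info = op_info_items[bwd_op_name]
        bwd_output_list_new = [
            name[: -len("_grad")] for name in bwd_op_info.output_name_list
        ]
        return ["true" if a in bwd_output_list_new else "false" for a in attrs]

The results are the strings "true"/"false" rather than Python booleans, presumably because they are spliced verbatim into the generated C++ source further down in OpGenerator.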
@@ -874,12 +874,10 @@ def OpGenerator(
         exclusive_interface_str = gen_exclusive_interface_str(op_info)

         # check op inputs and mutable_attributes grad semantics
-        input_grad_semantic_list = OpInputGradSemanticCheck(
+        input_grad_semantics = get_input_grad_semantic(op_info, op_info_items)
+        mutable_attribute_grad_semantics = get_mutable_attribute_grad_semantic(
             op_info, op_info_items
         )
-        mutable_attribute_grad_semantic_list = (
-            OpMutableAttributeGradSemanticCheck(op_info, op_info_items)
-        )

         # If op has inplace info, we will generate inplace op and non-inplace op.
         for op_name in op_info.op_phi_name:
@@ -1059,7 +1057,7 @@ def OpGenerator(
                     optional=op_input_optional_list[idx],
                     no_need_buffer=op_input_no_need_buffer_list[idx],
                     is_mutable_attribute='false',
-                    with_grad_semantic=input_grad_semantic_list[idx],
+                    with_grad_semantic=input_grad_semantics[idx],
                 )
             )
         for idx in range(len(op_mutable_attribute_name_list)):
@@ -1070,7 +1068,7 @@ def OpGenerator(
                     optional='false',
                     no_need_buffer='false',
                     is_mutable_attribute='true',
-                    with_grad_semantic=mutable_attribute_grad_semantic_list[
+                    with_grad_semantic=mutable_attribute_grad_semantics[
                         idx
                     ],
                 )
@@ -27,7 +27,17 @@ struct OpInputInfo {
   bool optional = false;
   bool no_need_buffer = false;
   bool is_mutable_attribute = false;
+  /***
+   * "with_grad_semantic" represents whether the input of the OP has gradient
+   * semantics. For example, gather op contains three inputs (x, index, axis),
+   * but the backward op gather_grad calculates only the gradient with respect
+   * to x. Therefore, for gather op, only x has gradient semantics.
+   * The "with_grad_semantic" field in OpInputInfo for x is true, and the
+   * "with_grad_semantic" fields in OpInputInfo for index and axis are both
+   * false.
+   */
   bool with_grad_semantic = true;

   OpInputInfo() = default;
   OpInputInfo(const OpInputInfo& input_info) = default;
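
Tying this comment back to the generator changes above: for gather, the two helpers would yield values along these lines (an illustration assembled from the comment and the test file below, not captured generator output):

    # gather's forward operands in the new IR: (x, index, axis);
    # gather_grad produces only x_grad, so only x carries grad semantics.
    input_grad_semantics = ["true", "false", "false"]  # x, index, axis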
@@ -20,7 +20,7 @@ from paddle import ir
 paddle.enable_static()


-def get_ir_program_0():
+def get_gather_program_new_ir():
     main_program, start_program = (
         paddle.static.Program(),
         paddle.static.Program(),
@@ -36,7 +36,7 @@ def get_ir_program_0():
     return newir_program


-def get_ir_program_1():
+def get_multiply_program_new_ir():
     main_program, start_program = (
         paddle.static.Program(),
         paddle.static.Program(),
@@ -54,15 +54,17 @@ def get_ir_program_1():
 class TestOpInputGradSemantic(unittest.TestCase):
-    def test_gatherop_input_grad_semantic(self):
-        newir_program = get_ir_program_0()
-        op = newir_program.block().ops[-1]
-        self.assertEqual(op.get_input_grad_semantics(), [True, False, False])
+    def test_gather_op_input_grad_semantic(self):
+        newir_program = get_gather_program_new_ir()
+        gather_op = newir_program.block().ops[-1]
+        self.assertEqual(
+            gather_op.get_input_grad_semantics(), [True, False, False]
+        )

-    def test_multiplyop_input_grad_semantic(self):
-        newir_program = get_ir_program_1()
-        op = newir_program.block().ops[-1]
-        self.assertEqual(op.get_input_grad_semantics(), [True, True])
+    def test_multiply_op_input_grad_semantic(self):
+        newir_program = get_multiply_program_new_ir()
+        multiply_op = newir_program.block().ops[-1]
+        self.assertEqual(multiply_op.get_input_grad_semantics(), [True, True])

if __name__ == "__main__":
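
The bodies of get_gather_program_new_ir and get_multiply_program_new_ir are elided above. For context, a hedged sketch of what such a builder could look like: the data shapes, dtypes, and axis handling are assumptions, and ir.translate_to_new_ir is an assumed entry point (the diff only shows `from paddle import ir` and `return newir_program`).

    import paddle
    from paddle import ir

    paddle.enable_static()

    def build_gather_program_sketch():
        # Hedged sketch of a new-IR test program builder; not the committed code.
        main_program, start_program = (
            paddle.static.Program(),
            paddle.static.Program(),
        )
        with paddle.static.program_guard(main_program, start_program):
            x = paddle.static.data("x", [4], dtype="float32")  # assumed shape
            index = paddle.static.data("index", [1], dtype="int32")  # assumed
            axis = paddle.full([1], 0, dtype="int64")  # axis fed as an operand
            out = paddle.gather(x, index, axis)
        # Assumed API: translate the legacy program desc into a new-IR program.
        newir_program = ir.translate_to_new_ir(main_program.desc)
        return newir_program

With a program built this way, `newir_program.block().ops[-1]` is the gather op, matching how the tests above fetch the op before querying get_input_grad_semantics().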