Unverified commit d8122a23, authored by Xianduo Li, committed by GitHub

support for checking op's inputs grad semantic (#56925)

* add <with_grad_semantic> to OpInputInfo to indicate whether an input of an OP has grad semantics

* add support for checking an OP's input grad semantics by comparing fwd_op inputs with bwd_op outputs

* add a pybind interface to support checking an OP's input grad semantics at the Python level

* add test

* fix bugs

* fix bugs in op_gen

* fix bugs in op_gen

* add test for multiply_op

* fix bugs in codestyle

* fix bugs in codestyle
Parent commit: 7c8c9b7d
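For context, the Python-level interface added here is exercised roughly as in the sketch below. It is a minimal example distilled from the multiply case of the test added in this commit, not additional functionality:

import paddle
from paddle import ir

paddle.enable_static()

main_program = paddle.static.Program()
with paddle.static.program_guard(main_program):
    x = paddle.tensor.fill_constant(shape=[3, 4], dtype='float32', value=2.0)
    y = paddle.tensor.fill_constant(shape=[3, 4], dtype='float32', value=3.0)
    out = paddle.multiply(x, y)

# Translate to the new IR and query the last op (multiply).
newir_program = ir.translate_to_new_ir(main_program.desc)
multiply_op = newir_program.block().ops[-1]
# Both inputs of multiply have grad semantics, so this prints [True, True].
print(multiply_op.get_input_grad_semantics())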
@@ -156,7 +156,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{
   return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
 }}
 """

-CONSTRUCT_INPUT_INFO_TEMPLATE = """paddle::dialect::OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute})"""
+CONSTRUCT_INPUT_INFO_TEMPLATE = """paddle::dialect::OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute}, {with_grad_semantic})"""
 CONSTRUCT_OUTPUT_INFO_TEMPLATE = """paddle::dialect::OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})"""
 CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = """paddle::dialect::OpAttributeInfo("{name}", "{typename}", "{data_type}")"""
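As a quick illustration of the new {with_grad_semantic} slot, the updated template expands as in the sketch below (the "x" / DenseTensorType values are hypothetical; in the generator the keyword arguments come from the parsed op YAML):

CONSTRUCT_INPUT_INFO_TEMPLATE = """paddle::dialect::OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute}, {with_grad_semantic})"""

print(
    CONSTRUCT_INPUT_INFO_TEMPLATE.format(
        name="x",
        typename="paddle::dialect::DenseTensorType",
        optional="false",
        no_need_buffer="false",
        is_mutable_attribute="false",
        with_grad_semantic="true",
    )
)
# paddle::dialect::OpInputInfo("x", "paddle::dialect::DenseTensorType", false, false, false, true)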
@@ -345,6 +345,23 @@ class OpInfoParser:
         # parse has_custom_verify
         self.custom_verify = self.parse_custom_verify()

+        # parse forward input name list and attribute name list
+        self.forward_input_name_list = self.parse_forward_input_name()
+
+    def parse_forward_input_name(self):
+        if 'forward' in self.op_yaml_item:
+            forward_input_name_list = []
+            forward_map = self.op_yaml_item['forward']
+            if forward_map is not None:
+                inputs = forward_map['inputs']
+                for input in inputs:
+                    forward_input_name_list.append(input['name'])
+                return forward_input_name_list
+            else:
+                return None
+        else:
+            return None
+
     def cross_check(self, name_list, type_list, optional_list=None):
         assert len(name_list) == len(
             type_list
@@ -705,6 +722,69 @@ def to_pascal_case(s):
     return "".join([word.capitalize() for word in words]) + ""


+def OpInputGradSemanticCheck(op_info, op_info_items):
+    input_grad_semantic_list = []
+    num_inputs = len(op_info.input_name_list)
+
+    # get backward op
+    bwd_op_name = op_info.backward_name
+    if (bwd_op_name is None) or (bwd_op_name not in op_info_items.keys()):
+        input_grad_semantic_list = ["false" for i in range(num_inputs)]
+    else:
+        bwd_op_info = op_info_items[bwd_op_name]
+
+        # strip the "_grad" suffix from each output of the bwd_op, then compare
+        # the stripped name with the corresponding fwd_op input to determine
+        # whether that input has grad semantics
+        bwd_output_list = bwd_op_info.output_name_list
+        bwd_output_list_new = []
+        for bwd_output in bwd_output_list:
+            bwd_output_list_new.append(bwd_output[:-5])  # cut "_grad"
+
+        bwd_fwd_input_list = bwd_op_info.forward_input_name_list
+        if bwd_fwd_input_list is not None:
+            assert (
+                len(bwd_fwd_input_list) == num_inputs
+            ), "Configuration of forward op and backward op does not match."
+            for i in range(num_inputs):
+                if bwd_fwd_input_list[i] in bwd_output_list_new:
+                    input_grad_semantic_list.append("true")
+                else:
+                    input_grad_semantic_list.append("false")
+        else:
+            input_grad_semantic_list = ["false" for i in range(num_inputs)]
+
+    return input_grad_semantic_list
+
+
+def OpMutableAttributeGradSemanticCheck(op_info, op_info_items):
+    mutable_attribute_grad_semantic_list = []
+    fwd_mutable_attribute_list = op_info.mutable_attribute_name_list
+
+    # get backward op
+    bwd_op_name = op_info.backward_name
+    if (bwd_op_name is None) or (bwd_op_name not in op_info_items.keys()):
+        mutable_attribute_grad_semantic_list = [
+            "false" for i in range(len(fwd_mutable_attribute_list))
+        ]
+    else:
+        bwd_op_info = op_info_items[bwd_op_name]
+
+        # strip the "_grad" suffix from each output of the bwd_op, then compare
+        # the stripped name with the corresponding mutable attribute to determine
+        # whether that attribute has grad semantics
+        bwd_output_list = bwd_op_info.output_name_list
+        bwd_output_list_new = []
+        for bwd_output in bwd_output_list:
+            bwd_output_list_new.append(bwd_output[:-5])
+
+        for i in range(len(fwd_mutable_attribute_list)):
+            if fwd_mutable_attribute_list[i] in bwd_output_list_new:
+                mutable_attribute_grad_semantic_list.append("true")
+            else:
+                mutable_attribute_grad_semantic_list.append("false")
+
+    return mutable_attribute_grad_semantic_list
+
+
 def OpGenerator(
     op_yaml_files,
     op_compat_yaml_file,
@@ -793,6 +873,14 @@ def OpGenerator(
             op_interfaces += ["paddle::dialect::VjpInterface"]
         exclusive_interface_str = gen_exclusive_interface_str(op_info)

+        # check op inputs and mutable_attributes grad semantics
+        input_grad_semantic_list = OpInputGradSemanticCheck(
+            op_info, op_info_items
+        )
+        mutable_attribute_grad_semantic_list = (
+            OpMutableAttributeGradSemanticCheck(op_info, op_info_items)
+        )
+
         # If op has inplace info, we will generate inplace op and non-inplace op.
         for op_name in op_info.op_phi_name:
             if op_name in _NO_NEED_GEN_OPS:
@@ -971,6 +1059,7 @@ def OpGenerator(
                         optional=op_input_optional_list[idx],
                         no_need_buffer=op_input_no_need_buffer_list[idx],
                         is_mutable_attribute='false',
+                        with_grad_semantic=input_grad_semantic_list[idx],
                     )
                 )
             for idx in range(len(op_mutable_attribute_name_list)):
@@ -981,6 +1070,9 @@ def OpGenerator(
                         optional='false',
                         no_need_buffer='false',
                         is_mutable_attribute='true',
+                        with_grad_semantic=mutable_attribute_grad_semantic_list[
+                            idx
+                        ],
                     )
                 )
             if len(input_info_list) > 0:
...
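The matching rule implemented by OpInputGradSemanticCheck above boils down to the standalone sketch below (toy names only; the real generator works on the parsed YAML op definitions). The gather example mirrors the Python test added at the end of this commit:

def has_grad_semantic(fwd_inputs, bwd_outputs):
    # Strip the "_grad" suffix from each backward output, then check whether
    # a forward input name matches one of the stripped names.
    stripped = {name[:-5] for name in bwd_outputs if name.endswith("_grad")}
    return [name in stripped for name in fwd_inputs]


# gather's forward inputs are (x, index, axis) but gather_grad only produces
# x_grad, so only the first input carries grad semantics.
print(has_grad_semantic(["x", "index", "axis"], ["x_grad"]))  # [True, False, False]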
@@ -35,7 +35,8 @@ OpInfoTuple AddNOp::GetOpInfo() {
           "ir::VectorType<paddle::dialect::DenseTensorType>",
           false,
           false,
-          false)};
+          false,
+          true)};
   std::vector<paddle::dialect::OpAttributeInfo> attributes = {};
   std::vector<paddle::dialect::OpOutputInfo> outputs = {
       OpOutputInfo("out", "paddle::dialect::DenseTensorType", false, false)};
@@ -416,9 +417,14 @@ OpInfoTuple SplitGradOp::GetOpInfo() {
           "ir::VectorType<paddle::dialect::DenseTensorType>",
           false,
           false,
-          false),
-      OpInputInfo(
-          "axis", "paddle::dialect::ScalarAttribute", false, false, true)};
+          false,
+          true),
+      OpInputInfo("axis",
+                  "paddle::dialect::ScalarAttribute",
+                  false,
+                  false,
+                  true,
+                  false)};
   std::vector<paddle::dialect::OpAttributeInfo> attributes = {};
   std::vector<paddle::dialect::OpOutputInfo> outputs = {
       OpOutputInfo("x_grad", "paddle::dialect::DenseTensorType", false, false)};
...
@@ -27,6 +27,7 @@ struct OpInputInfo {
   bool optional = false;
   bool no_need_buffer = false;
   bool is_mutable_attribute = false;
+  bool with_grad_semantic = true;

   OpInputInfo() = default;
   OpInputInfo(const OpInputInfo& input_info) = default;
@@ -34,12 +35,14 @@ struct OpInputInfo {
               const std::string& type_name,
               bool optional,
               bool no_need_buffer,
-              bool is_mutable_attribute)
+              bool is_mutable_attribute,
+              bool with_grad_semantic)
       : name(name),
         type_name(type_name),
         optional(optional),
         no_need_buffer(no_need_buffer),
-        is_mutable_attribute(is_mutable_attribute) {}
+        is_mutable_attribute(is_mutable_attribute),
+        with_grad_semantic(with_grad_semantic) {}
 };

 struct OpOutputInfo {
...
@@ -247,6 +247,17 @@ void BindOperation(py::module *m) {
              }
              return op_list;
            })
+      .def("get_input_grad_semantics",
+           [](Operation &self) -> py::list {
+             py::list op_list;
+             paddle::dialect::OpYamlInfoInterface yaml_interface =
+                 self.dyn_cast<paddle::dialect::OpYamlInfoInterface>();
+             auto inputs_grad_info = std::get<0>(yaml_interface.GetOpInfo());
+             for (auto &input_grad_info : inputs_grad_info) {
+               op_list.append(input_grad_info.with_grad_semantic);
+             }
+             return op_list;
+           })
       .def("replace_all_uses_with",
            [](Operation &self, const std::vector<OpResult> &op_results) {
              self.ReplaceAllUsesWith(op_results);
...
@@ -429,14 +429,30 @@ const char *Conv2dFusionOpTest::attributes_name[10] = {  // NOLINT
 OpInfoTuple Conv2dFusionOpTest::GetOpInfo() {
   std::vector<paddle::dialect::OpInputInfo> inputs = {
-      OpInputInfo(
-          "input", "paddle::dialect::DenseTensorType", false, false, false),
-      OpInputInfo(
-          "filter", "paddle::dialect::DenseTensorType", false, false, false),
-      OpInputInfo(
-          "bias", "paddle::dialect::DenseTensorType", false, false, false),
-      OpInputInfo(
-          "residual", "paddle::dialect::DenseTensorType", true, false, false)};
+      OpInputInfo("input",
+                  "paddle::dialect::DenseTensorType",
+                  false,
+                  false,
+                  false,
+                  true),
+      OpInputInfo("filter",
+                  "paddle::dialect::DenseTensorType",
+                  false,
+                  false,
+                  false,
+                  true),
+      OpInputInfo("bias",
+                  "paddle::dialect::DenseTensorType",
+                  false,
+                  false,
+                  false,
+                  true),
+      OpInputInfo("residual",
+                  "paddle::dialect::DenseTensorType",
+                  true,
+                  false,
+                  false,
+                  true)};
   std::vector<paddle::dialect::OpAttributeInfo> attributes = {
       OpAttributeInfo("strides", "ir::ArrayAttribute<ir::Int32Attribute>", ""),
       OpAttributeInfo(
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import paddle
from paddle import ir

paddle.enable_static()


def get_ir_program_0():
    main_program, start_program = (
        paddle.static.Program(),
        paddle.static.Program(),
    )
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.tensor.fill_constant(
            shape=[3, 4], dtype='float32', value=2.0
        )
        index = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=1.0)
        axis = paddle.tensor.fill_constant(shape=[1], dtype='int32', value=2.0)
        out = paddle.gather(x, index, axis)
    newir_program = ir.translate_to_new_ir(main_program.desc)
    return newir_program


def get_ir_program_1():
    main_program, start_program = (
        paddle.static.Program(),
        paddle.static.Program(),
    )
    with paddle.static.program_guard(main_program, start_program):
        x = paddle.tensor.fill_constant(
            shape=[3, 4], dtype='float32', value=2.0
        )
        y = paddle.tensor.fill_constant(
            shape=[3, 4], dtype='float32', value=3.0
        )
        out = paddle.multiply(x, y)
    newir_program = ir.translate_to_new_ir(main_program.desc)
    return newir_program


class TestOpInputGradSemantic(unittest.TestCase):
    def test_gatherop_input_grad_semantic(self):
        newir_program = get_ir_program_0()
        op = newir_program.block().ops[-1]
        self.assertEqual(op.get_input_grad_semantics(), [True, False, False])

    def test_multiplyop_input_grad_semantic(self):
        newir_program = get_ir_program_1()
        op = newir_program.block().ops[-1]
        self.assertEqual(op.get_input_grad_semantics(), [True, True])


if __name__ == "__main__":
    unittest.main()