Unverified · Commit 69161a96 · authored by xiaoguoguo626807 · committed by GitHub

[static] modify backward prune logic for EmptyGradOpMaker (#53746)

* add rules

* modify no kernel yaml parse

* success op generate

* success test_silu_double

* modify bug

* modify static error

* modify silu_grad input

* modify kernel signature

* modify kernel signature

* code style

* code style

* review

* delete opinfo modify

* modify gradOpMaker

* modify gradOpMaker

* modify generated-j2

* add approve rules

* modify autograd_functional_static_test
Parent commit: 32e36b15
@@ -97,6 +97,8 @@ class OpInfo {
     return grad_op_maker_ != nullptr && !use_empty_grad_op_desc_maker_;
   }

+  bool HasEmptyGradOpMaker() const { return use_empty_grad_op_desc_maker_; }
+
   const DygraphGradOpMakerFN& DygraphGradOpMaker() const {
     // Normally, proto_ should not be null, except some special operators, such
     // as LeaklyReluDoubleGrad op.
......
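The new `HasEmptyGradOpMaker()` accessor distinguishes a state that `HasNonEmptyGradOpMaker()` alone conflates: an op explicitly registered with `EmptyGradOpMaker` versus an op with no grad op maker at all. A minimal sketch of the three resulting states, using the pybind bindings touched later in this diff (illustrative usage only, not part of the commit):

```python
# Sketch only: `has_empty_grad_op_maker` is the binding added later in this
# diff; `has_non_empty_grad_op_maker` already existed.
from paddle.fluid import core

def grad_maker_state(op_type):
    if core.has_non_empty_grad_op_maker(op_type):
        return "real GradOpMaker registered"
    if core.has_empty_grad_op_maker(op_type):
        return "explicitly registered EmptyGradOpMaker (prunable)"
    # Neither flag set: e.g. an op whose backward comes only from a
    # composite rule, which the pruning logic must not discard.
    return "no GradOpMaker registered at all"
```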
@@ -477,6 +477,7 @@ REGISTER_OPERATOR({{name}}, ops::{{name | to_pascal_case}}Op,
 {% set backward_name = op["backward"] %}
   ops::{{backward_name | to_pascal_case}}OpMaker<paddle::framework::OpDesc>,
   ops::{{backward_name | to_pascal_case}}OpMaker<paddle::imperative::OpBase>,
+{% elif "forward" in op %}
 {% else %}
   paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
   paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
......
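The added `{% elif "forward" in op %}` branch changes what the generated `REGISTER_OPERATOR` call receives for backward ops. A rough Python rendering of the template's decision, assuming the opening guard is `{% if "backward" in op %}` (it sits above the visible hunk) and omitting the `to_pascal_case` filter for brevity:

```python
# Hypothetical rendering of the Jinja2 branch; `op` mirrors the parsed YAML
# dict that the op generator feeds into this template.
def grad_maker_template_args(op: dict) -> list:
    if "backward" in op:
        # Forward op with a declared backward: register real grad op makers.
        bwd = op["backward"]
        return [f"ops::{bwd}OpMaker<paddle::framework::OpDesc>",
                f"ops::{bwd}OpMaker<paddle::imperative::OpBase>"]
    elif "forward" in op:
        # The op is itself a backward op: register neither a real nor an
        # empty grad maker, so it is not wrongly treated as prunable.
        return []
    else:
        # No backward at all: mark the op explicitly with EmptyGradOpMaker.
        return ["paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>",
                "paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>"]
```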
@@ -1327,8 +1327,7 @@ All parameter, weight, gradient are variables in Paddle.
           if ((grad_op_maker == nullptr) && (grad_comp_op_maker == nullptr)) {
             // Normally, proto_ should not be null, except some special
             // operators, such as LeaklyReluDoubleGrad op.
-            std::string type =
-                op_info.proto_ ? op_info.proto_->type() : "unknown";
+            std::string type = op_desc.Type();
             PADDLE_THROW(platform::errors::NotFound(
                 "Neither operator %s's GradOpMaker nor CompGradOpMaker has "
                 "been registered.\nPlease check whether (%s) operator has "
@@ -1350,7 +1349,8 @@ All parameter, weight, gradient are variables in Paddle.
           VLOG(3) << "need skip: " << need_skip << std::endl;
           if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {
             if ((grad_comp_op_maker != nullptr) && (!need_skip)) {
-              VLOG(3) << "Runing composite fun for " << op_desc.Type();
+              VLOG(3) << "Prim Flag Open: Runing composite grad fun for "
+                      << op_desc.Type();
               grad_op_descs = grad_comp_op_maker(op_desc,
                                                  no_grad_set,
                                                  &grad_to_var,
@@ -1362,9 +1362,13 @@ All parameter, weight, gradient are variables in Paddle.
             }
           } else {
             if (grad_op_maker != nullptr) {
+              VLOG(3) << "Prim Flag Close: Runing origin grad fun for "
+                      << op_desc.Type();
               grad_op_descs = grad_op_maker(
                   op_desc, no_grad_set, &grad_to_var, grad_sub_block);
             } else {
+              VLOG(3) << "Prim Flag Close: Runing composite grad fun for "
+                      << op_desc.Type();
               grad_op_descs = grad_comp_op_maker(op_desc,
                                                  no_grad_set,
                                                  &grad_to_var,
@@ -1392,6 +1396,9 @@ All parameter, weight, gradient are variables in Paddle.
                    .Get(op_type)
                    .HasNonEmptyGradOpMaker();
       });
+      m.def("has_empty_grad_op_maker", [](const std::string op_type) {
+        return framework::OpInfoMap::Instance().Get(op_type).HasEmptyGradOpMaker();
+      });
       m.def("has_infer_inplace", [](const std::string op_type) {
         return framework::OpInfoMap::Instance().Get(op_type).HasInferInplace();
       });
......
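With the binding in place, the new query sits alongside the existing one in Python. A minimal usage sketch (the op names are arbitrary examples and exact results depend on the ops registered in a given build):

```python
from paddle.fluid import core  # same import path backward.py uses

for op_type in ["matmul_v2", "feed", "silu"]:
    print(op_type,
          "non_empty:", core.has_non_empty_grad_op_maker(op_type),
          "empty:", core.has_empty_grad_op_maker(op_type))
```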
@@ -2348,7 +2348,7 @@ def _find_op_path_(
     for i, op in enumerate(block.ops):
         if _some_in_set_(
             op.desc.input_arg_names(), input_names
-        ) and core.has_non_empty_grad_op_maker(op.type):
+        ) and not core.has_empty_grad_op_maker(op.type):
             for name in op.desc.output_arg_names():
                 if name not in no_grad_set:
                     input_names.add(name)
@@ -2367,7 +2367,7 @@ def _find_op_path_(
         if _some_in_set_(
             op.desc.output_arg_names(), output_names
-        ) and core.has_non_empty_grad_op_maker(op.type):
+        ) and not core.has_empty_grad_op_maker(op.type):
             for name in op.desc.input_arg_names():
                 if name not in no_grad_set:
                     output_names.add(name)
@@ -2382,7 +2382,7 @@ def _find_op_path_(
             op.desc.output_arg_names(), output_names
         ):
             relevant_op_flags[i] = True
-            if core.has_non_empty_grad_op_maker(op.type):
+            if not core.has_empty_grad_op_maker(op.type):
                 for name in op.desc.input_arg_names():
                     if name not in no_grad_set:
                         output_names.add(name)
......
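All three call sites in `_find_op_path_` flip the same predicate: instead of keeping only ops that registered a real grad op maker, the op path now keeps every op not explicitly marked with `EmptyGradOpMaker`. Only the middle state behaves differently, as this illustrative truth table (not Paddle API) shows:

```python
# real = HasNonEmptyGradOpMaker(op_type), empty = HasEmptyGradOpMaker(op_type)
cases = [
    ("real GradOpMaker registered", True,  False),
    ("explicit EmptyGradOpMaker",   False, True),
    ("no GradOpMaker at all",       False, False),  # e.g. composite-grad-only op
]
for label, real, empty in cases:
    old_kept = real        # core.has_non_empty_grad_op_maker(op.type)
    new_kept = not empty   # not core.has_empty_grad_op_maker(op.type)
    print(f"{label}: old={old_kept}, new={new_kept}")
# Only the last row changes (False -> True): ops with no registered grad op
# maker now stay in the op path, so a composite backward rule can still
# expand them instead of the whole subgraph being pruned away.
```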
@@ -466,7 +466,7 @@ class TestHessianFloat32(unittest.TestCase):
     def test_square(self):
         def pd_f(x):
             """Input is a square matrix."""
-            return paddle.matmul(x, x.T).flatten().sum()
+            return paddle.matmul(x, x.T).sum()

         def np_hess(x):
             dim = x.shape[0]
......
@@ -81,6 +81,7 @@ API_FILES=("CMakeLists.txt"
            "paddle/phi/core/kernel_context.h"
            "paddle/phi/core/infermeta_utils.h"
            "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+           "paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h"
            "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
            "paddle/fluid/prim/api/api.yaml"
            "python/paddle/incubate/autograd/composite_rules.py"
@@ -207,7 +208,7 @@ for API_FILE in ${API_FILES[*]}; do
   elif [ "${API_FILE}" == "paddle/phi/api/include/tensor.h" ] || [ "${API_FILE}" == "paddle/phi/core/tensor_base.h" ] || [ "${API_FILE}" == "paddle/phi/core/dense_tensor.h" ] || [ "${API_FILE}" == "paddle/phi/core/meta_tensor.h" ] || [ "${API_FILE}" == "paddle/phi/core/tensor_meta.h" ] || [ "${API_FILE}" == "paddle/phi/core/attribute.h" ] || [ "${API_FILE}" == "paddle/phi/core/device_context.h" ] || [ "${API_FILE}" == "paddle/phi/core/kernel_utils.h" ] || [ "${API_FILE}" == "paddle/phi/core/kernel_registry.h" ] || [ "${API_FILE}" == "paddle/phi/core/kernel_factory.h" ] || [ "${API_FILE}" == "paddle/phi/core/kernel_context.h" ] || [ "${API_FILE}" == "paddle/phi/core/infermeta_utils.h" ]; then
     echo_line="You must have one RD (chenwhql, phlrain, zyfncg, YuanRisheng) approval for changing ${API_FILE} , which manages the underlying code for PaddlePaddle PHI Library.\n"
     check_approval chenwhql phlrain zyfncg YuanRisheng
-  elif [ "${API_FILE}" == "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" ] || [ "${API_FILE}" == "paddle/fluid/prim/api/manual_prim/prim_manual_api.h" ] || [ "${API_FILE}" == "paddle/fluid/prim/api/api.yaml" ]; then
+  elif [ "${API_FILE}" == "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" ] || [ "${API_FILE}" == "paddle/fluid/prim/api/manual_prim/prim_manual_api.h" ] || [ "${API_FILE}" == "paddle/fluid/prim/api/api.yaml" ] || [ "${API_FILE}" == "paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h" ]; then
     echo_line="You must have one RD (JiabinYang, cxxly(chenxiaoxu) , xiaoguoguo626807(wangruting)) approval for changing ${API_FILE} , which manages the code for PaddlePaddle Composite Bacward Prim API.\n"
     check_approval 1 JiabinYang cxxly xiaoguoguo626807
   elif [ "${API_FILE}" == "python/paddle/incubate/autograd/primitives.py" ] || [ "${API_FILE}" == "python/paddle/incubate/autograd/composite_rules.py" ]; then
......