Unverified commit dd74b3d1, authored by Xiaoxu Chen, committed by GitHub

[prim]use Operator to reconstruct the primitive operator defined in c++ (#51997)

Parent 6741dd22
@@ -575,6 +575,9 @@ class CompositeGradOpMakerBase {
   const std::unordered_map<std::string, framework::Attribute>& RuntimeAttrs()
       const {
+    LOG(WARNING) << "CompositeGradOpMaker doesn't support runtime attrs, "
+                    "but op "
+                 << fwd_op_.Type() << " uses runtime attrs.";
     return fwd_op_.GetRuntimeAttrMap();
   }
......
@@ -425,7 +425,8 @@ void BindOpDesc(pybind11::module *m) {
           &pd::OpDesc::SetDistAttr,
           pybind11::return_value_policy::reference)
       .def("inputs", [](pd::OpDesc &self) { return self.Inputs(); })
-      .def("outputs", &pd::OpDesc::Outputs);
+      .def("outputs", &pd::OpDesc::Outputs)
+      .def("get_attr_map", &pd::OpDesc::GetAttrMap);
   pybind11::class_<paddle::experimental::Scalar> scalar(*m, "Scalar", "");
   scalar.def(py::init<bool>())
......
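The new get_attr_map binding exposes an op's full compile-time attribute map to Python; it is what infershape_for_composite below feeds to block.append_op. A minimal sketch of inspecting it (assumes static graph mode; paddle.scale is just an arbitrary example op, and the printed attrs are illustrative):

    import paddle

    paddle.enable_static()

    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data('x', shape=[2, 3], dtype='float32')
        y = paddle.scale(x, scale=2.0)

    # get_attr_map mirrors pd::OpDesc::GetAttrMap and returns every
    # compile-time attr, defaults included, e.g. {'scale': 2.0, 'bias': 0.0, ...}
    op_desc = main.block(0).ops[-1].desc
    print(op_desc.get_attr_map())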
@@ -1715,35 +1715,68 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
 def infershape_for_composite(block, grad_op_desc):
-    # pruning empty output
+    # NOTE: why prune operators with empty output here?
+    # Some backward operators output an empty var, which causes an
+    # infer-shape error, such as assign with input's stop_gradient=True.
     if len(grad_op_desc.output_arg_names()) == 0:
         return

-    # append op to block
-    op_desc = block.desc.append_op()
-    op_desc.copy_from(grad_op_desc)
-    op_desc._set_attr(
-        core.op_proto_and_checker_maker.kOpRoleAttrName(),
-        core.op_proto_and_checker_maker.OpRole.Backward,
-    )
-
-    # create output var
+    # create output variable
     new_vars = set()
     # create new gradient variables
-    for grad_var_name in op_desc.output_arg_names():
+    for grad_var_name in grad_op_desc.output_arg_names():
         if not (
             block.desc.has_var_recursive(grad_var_name.encode())
             or grad_var_name == core.empty_var_name()
         ):
-            block.desc.var(grad_var_name.encode())
+            # NOTE: stop_gradient will be set in append_op
+            desc = block.desc.var(grad_var_name.encode())
+            block.create_var(name=grad_var_name, desc=desc, type=desc.type())
             new_vars.add(grad_var_name)

-    # infer shape and infer dtype
-    op_desc.check_attrs()
-    op_desc.infer_var_type(block.desc)
-    op_desc.infer_shape(block.desc)
+    # NOTE: For the primitive operators generated by decomposing a phi grad
+    # kernel, we use Operator to reconstruct the op_desc so that we can reuse
+    # some complex logic, such as processing dispensable inputs, intermediate
+    # outputs, extra attrs, etc...
+    if framework.OpProtoHolder.instance().has_op_proto(grad_op_desc.type()):
+        op = block.append_op(
+            type=grad_op_desc.type(),
+            inputs={
+                name: [block._find_var_recursive(arg) for arg in args]
+                for name, args in grad_op_desc.inputs().items()
+            },
+            outputs={
+                name: [block._find_var_recursive(arg) for arg in args]
+                for name, args in grad_op_desc.outputs().items()
+            },
+            # NOTE: Runtime attrs are ignored because the C++ GetRuntimeAttr
+            # interface can't be exported to Python. Please note the WARNING
+            # message logged in RuntimeAttrs of composite_grad_desc_maker.h.
+            attrs=grad_op_desc.get_attr_map(),
+        )
+        op.desc._set_attr(
+            core.op_proto_and_checker_maker.kOpRoleAttrName(),
+            core.op_proto_and_checker_maker.OpRole.Backward,
+        )
+        grad_op_desc.copy_from(op.desc)
+    # For other backward operators, we reuse the logic of _append_backward_vars
+    else:
+        op_desc = block.desc.append_op()
+        op_desc.copy_from(grad_op_desc)
+        op_desc._set_attr(
+            core.op_proto_and_checker_maker.kOpRoleAttrName(),
+            core.op_proto_and_checker_maker.OpRole.Backward,
+        )
+        op_desc.check_attrs()
+        op_desc.infer_var_type(block.desc)
+        op_desc.infer_shape(block.desc)
+        grad_op_desc.copy_from(op_desc)

-    for arg in op_desc.output_arg_names():
-        if arg in new_vars:
-            _infer_var_data_type_shape_(arg, block)
-
-    grad_op_desc.copy_from(op_desc)
+    # NOTE: Some operators don't infer dtype correctly; this patch sets the
+    # grad var's dtype to match the corresponding forward variable's.
+    for arg in grad_op_desc.output_arg_names():
+        if arg in new_vars:
+            _infer_var_data_type_shape_(arg, block)
......
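Which branch runs above depends only on whether the grad op's type has a registered OpProto. A small sketch of that check in isolation (the op names are merely examples of a registered and an unregistered type):

    from paddle.fluid import framework

    # OpProtoHolder caches the OpProto of every operator registered in C++.
    # Primitive ops emitted by decomposing a phi grad kernel are registered
    # like any other op, so they take the block.append_op path above.
    holder = framework.OpProtoHolder.instance()
    print(holder.has_op_proto('elementwise_mul'))  # True  -> Operator path
    print(holder.has_op_proto('no_such_op'))       # False -> raw OpDesc path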
@@ -2916,14 +2916,35 @@ class Operator:
         for m in proto.outputs:
             if (m.name not in outputs) and m.dispensable:
                 continue
-            if not ((m.name in outputs) or m.dispensable):
-                raise ValueError(
-                    (
-                        "Incorrect setting for output(s) of "
-                        "operator \"%s\", should set: [%s]."
-                    )
-                    % (type, m.name)
-                )
+            # FIXME: The outputs of a primitive operator currently don't
+            # include intermediate outputs, since they are dropped in
+            # operator codegen (e.g. the xshape output of reshape2). This
+            # will be fixed once operator codegen supports intermediate
+            # outputs.
+            if core._is_bwd_prim_enabled():
+                if not (
+                    (m.name in outputs)
+                    or m.dispensable
+                    or m.intermediate
+                ):
+                    raise ValueError(
+                        (
+                            "Incorrect setting for output(s) of "
+                            "operator \"%s\", should set: [%s]."
+                        )
+                        % (type, m.name)
+                    )
+            else:
+                if not ((m.name in outputs) or m.dispensable):
+                    raise ValueError(
+                        (
+                            "Incorrect setting for output(s) of "
+                            "operator \"%s\", should set: [%s]."
+                        )
+                        % (type, m.name)
+                    )
         for out_proto in proto.outputs:
             if out_proto.name not in outputs:
                 continue
......
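To make the relaxed check concrete, here is a simplified reenactment of the logic above in plain Python (not Paddle API; the reshape2 proto entries are written out by hand):

    # Simplified stand-in for the proto-output check in Operator.__init__.
    def check_outputs(proto_outputs, outputs, bwd_prim_enabled):
        for m in proto_outputs:
            if m['name'] in outputs:
                continue
            # An absent output is legal if it is dispensable, or, with
            # backward prim enabled, if it is an intermediate output
            # dropped by operator codegen.
            if not (m['dispensable'] or (bwd_prim_enabled and m['intermediate'])):
                raise ValueError("should set output: [%s]" % m['name'])

    # reshape2 declares Out plus the intermediate XShape output.
    reshape2_outputs = [
        {'name': 'Out', 'dispensable': False, 'intermediate': False},
        {'name': 'XShape', 'dispensable': False, 'intermediate': True},
    ]

    check_outputs(reshape2_outputs, {'Out'}, bwd_prim_enabled=True)  # passes
    # check_outputs(reshape2_outputs, {'Out'}, bwd_prim_enabled=False)  # raises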
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import paddle


class TestDispensable(unittest.TestCase):
    def setUp(self):
        paddle.fluid.core._set_prim_all_enabled(True)

    def tearDown(self):
        paddle.fluid.core._set_prim_all_enabled(False)

    def test_dispensable(self):
        @paddle.jit.to_static
        def f(x):
            return paddle.split(x, num_or_sections=2)

        x = paddle.rand((8,))
        x.stop_gradient = False

        # The last op of the backward program is the reconstructed grad op:
        # it must carry the Backward op role and keep the dispensable
        # AxisTensor input.
        op = f.get_concrete_program(x)[1].backward_program.block(0).ops[-1]
        self.assertEqual(
            op.attr('op_role'),
            int(paddle.fluid.core.op_proto_and_checker_maker.OpRole.Backward),
        )
        self.assertIn('AxisTensor', op.input_names)


if __name__ == '__main__':
    unittest.main()