Unverified commit 557bce77, authored by Huihuang Zheng, committed by GitHub

Fix Backward Bugs in Conditional Block (#21809)

The fixed bugs:

1. The condition sub-graph is not pruned.
2. When the backward graph is extremely simple, all of the backward ops are pruned (a minimal reproduction sketch follows below).
Parent eab124ba
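For context, a minimal sketch of the program shape that triggered bug 2, adapted from the new unit test added in this commit: both branches of the cond simply return the result of a fill_constant, whose grad op maker is empty, so before this fix the whole backward pass was pruned away.

import paddle.fluid as fluid
from paddle.fluid.backward import append_backward

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    a = fluid.layers.fill_constant(shape=[1], dtype='float32', value=1.23)
    a.stop_gradient = False
    b = fluid.layers.fill_constant(shape=[1], dtype='float32', value=1.25)
    b.stop_gradient = False
    # Each branch just returns an existing variable, so the backward
    # sub-graph is extremely simple.
    out = fluid.layers.cond(a - b < -1.0, lambda: a, lambda: b)
    append_backward(out)

exe = fluid.Executor(fluid.CPUPlace())
ret = exe.run(main_program, fetch_list=[out, a.grad_name, b.grad_name])
# Expected after the fix: out == 1.25, d(out)/da == 0.0, d(out)/db == 1.0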
......@@ -221,7 +221,15 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
info->use_default_grad_op_desc_maker_ =
std::is_base_of<DefaultGradOpMaker<OpDesc, true>, T>::value ||
std::is_base_of<DefaultGradOpMaker<OpDesc, false>, T>::value;
std::is_base_of<DefaultGradOpMaker<OpDesc, false>, T>::value ||
std::is_base_of<DefaultGradOpMaker<imperative::OpBase, true>,
T>::value ||
std::is_base_of<DefaultGradOpMaker<imperative::OpBase, false>,
T>::value;
info->use_empty_grad_op_desc_maker_ =
std::is_base_of<EmptyGradOpMaker<OpDesc>, T>::value ||
std::is_base_of<EmptyGradOpMaker<imperative::OpBase>, T>::value;
}
};
......
......@@ -101,6 +101,7 @@ void LoadOpLib(const std::string &dso_name) {
info.infer_no_need_buffer_vars_ = n.second.infer_no_need_buffer_vars_;
info.use_default_grad_op_desc_maker_ =
n.second.use_default_grad_op_desc_maker_;
info.use_empty_grad_op_desc_maker_ = n.second.use_empty_grad_op_desc_maker_;
info_map.Insert(type, info);
}
......
......@@ -33,7 +33,8 @@ class InferShapeBase {
virtual void operator()(InferShapeContext*) const = 0;
};
struct OpInfo {
class OpInfo {
public:
OpCreator creator_;
GradOpMakerFN grad_op_maker_;
proto::OpProto* proto_{nullptr};
......@@ -48,6 +49,10 @@ struct OpInfo {
// the grad maker is the default one.
bool use_default_grad_op_desc_maker_{false};
// NOTE(huihuangzheng): this flag is added to check whether
// the grad maker is the empty one.
bool use_empty_grad_op_desc_maker_{false};
bool HasOpProtoAndChecker() const {
return proto_ != nullptr && checker_ != nullptr;
}
......@@ -82,9 +87,13 @@ struct OpInfo {
return grad_op_maker_;
}
// some op has no grad_op_maker, add check before use GradOpMaker()
// some ops don't have grad_op_maker, add check before use GradOpMaker()
bool HasGradOpMaker() const { return grad_op_maker_ != nullptr; }
bool HasNonEmptyGradOpMaker() const {
return grad_op_maker_ != nullptr && !use_empty_grad_op_desc_maker_;
}
const DygraphGradOpMakerFN& DygraphGradOpMaker() const {
// Normally, proto_ should not be null, except some special operators, such
// as LeaklyReluDoubleGrad op.
......@@ -100,7 +109,7 @@ struct OpInfo {
}
bool HasDygraphGradOpMaker() const {
return dygraph_grad_op_maker_ != nullptr ? true : false;
return dygraph_grad_op_maker_ != nullptr;
}
bool HasInferInplace() const { return infer_inplace_ != nullptr; }
......
......@@ -1100,6 +1100,11 @@ All parameter, weight, gradient are variables in Paddle.
m.def("has_grad_op_maker", [](const std::string op_type) {
return framework::OpInfoMap::Instance().Get(op_type).HasGradOpMaker();
});
m.def("has_non_empty_grad_op_maker", [](const std::string op_type) {
return framework::OpInfoMap::Instance()
.Get(op_type)
.HasNonEmptyGradOpMaker();
});
m.def("has_infer_inplace", [](const std::string op_type) {
return framework::OpInfoMap::Instance().Get(op_type).HasInferInplace();
});
......
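A small usage sketch of the new Python binding, assuming fill_constant is registered with an empty grad op maker (the situation the new cond tests below exercise) while elementwise_add has a real one:

from paddle.fluid import core

# has_grad_op_maker only checks that some grad maker is registered; an empty
# grad maker still counts. has_non_empty_grad_op_maker additionally rejects
# ops whose registered grad maker is the empty one.
print(core.has_grad_op_maker("fill_constant"))              # True
print(core.has_non_empty_grad_op_maker("fill_constant"))    # False (empty grad maker)
print(core.has_non_empty_grad_op_maker("elementwise_add"))  # True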
......@@ -497,7 +497,7 @@ def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set):
to prune the unnecessary backward ops.
Return:
(list[core.OpDesc]): A list of OpDescs which should be pruned.
(set[core.OpDesc]): A set of OpDescs which should be pruned.
"""
class Var(object):
......@@ -597,8 +597,13 @@ def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set):
break
if remove_ops:
not_need_op_descs.extend([node.op_desc for node in op_list])
return set(not_need_op_descs)
not_need_op_descs_set = set(not_need_op_descs)
grad_op_descs_set = set(grad_op_descs)
# If the backward computational graph is simply one sub-graph header,
# not_need_op_descs would be the whole graph; this if-clause avoids that.
if grad_op_descs_set == not_need_op_descs_set:
return set()
return not_need_op_descs_set
from .proto import framework_pb2
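A simplified, plain-Python analogue of the new guard above (toy strings stand in for real core.OpDesc objects): when the pruning pass marks every grad op as unneeded, the safe behaviour is to prune nothing, otherwise the entire backward pass would disappear.

def prune_not_needed(grad_op_descs, not_need_op_descs):
    # Mirrors the new check in _find_not_need_ops: if the "not needed" set
    # covers the whole backward graph, return an empty set instead.
    not_need_set = set(not_need_op_descs)
    if set(grad_op_descs) == not_need_set:
        return set()
    return not_need_set

# A backward graph that is only the conditional_block_grad header:
print(prune_not_needed(["conditional_block_grad"], ["conditional_block_grad"]))  # set()
print(prune_not_needed(["mul_grad", "fill_zeros_like"], ["fill_zeros_like"]))    # {'fill_zeros_like'}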
......@@ -797,9 +802,10 @@ def _get_sub_block_path(sub_block, sub_block_op_desc, no_grad_set):
for op_desc in sub_block.ops:
if op_desc.type == "assign" and var in op_desc.output_arg_names:
sub_assign_to_out_ops.append(op_desc)
sub_outputs.extend([
sub_block.var(name) for name in op_desc.input_arg_names
])
for name in op_desc.input_arg_names:
if sub_block.has_var(name):
sub_outputs.append(sub_block.var(name))
sub_block_op_path = _find_op_path_(sub_block, sub_outputs, [],
no_grad_set)
# TODO better way than finding in list
......@@ -1241,7 +1247,9 @@ def _find_op_path_(block, outputs, inputs, no_grad_set):
# All the inputs of the block are used if inputs is empty,
if inputs:
for i, op in enumerate(block.ops):
if _some_in_set_(op.desc.input_arg_names(), input_names):
if _some_in_set_(
op.desc.input_arg_names(),
input_names) and core.has_non_empty_grad_op_maker(op.type):
for name in op.desc.output_arg_names():
if name not in no_grad_set:
input_names.add(name)
......@@ -1249,7 +1257,9 @@ def _find_op_path_(block, outputs, inputs, no_grad_set):
relevant_op_flags[i] = False
for i, op in reversed(list(enumerate(block.ops))):
if _some_in_set_(op.desc.output_arg_names(), output_names):
if _some_in_set_(
op.desc.output_arg_names(),
output_names) and core.has_non_empty_grad_op_maker(op.type):
for name in op.desc.input_arg_names():
if name not in no_grad_set:
output_names.add(name)
......
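As a rough illustration of the property these checks rely on (fill_constant registers an empty grad op maker, so it contributes no grad op and should not pull extra ops into the op path), one can inspect the ops that append_backward generates for a tiny program:

import paddle.fluid as fluid
from paddle.fluid.backward import append_backward

main = fluid.Program()
with fluid.program_guard(main, fluid.Program()):
    x = fluid.layers.fill_constant(shape=[1], dtype='float32', value=2.0)
    x.stop_gradient = False
    y = fluid.layers.elementwise_mul(x, x)
    append_backward(y)

# Expect elementwise_mul_grad in the list; the forward fill_constant produces
# no corresponding grad op because its grad maker is the empty one.
print([op.type for op in main.global_block().ops])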
......@@ -223,6 +223,29 @@ class TestCondInputOutput(unittest.TestCase):
"Incompatible return values of true_fn and false_fn in cond" in
str(e.exception))
def test_extremely_simple_net_with_op_in_condition(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
a = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=1.23)
a.stop_gradient = False
b = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=1.25)
b.stop_gradient = False
out = layers.cond(a - b < -1.0, lambda: a, lambda: b)
append_backward(out)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
ret = exe.run(main_program, fetch_list=[out, a.grad_name, b.grad_name])
# Note: fill_constant loses precision, so only use assertEqual against
# values that are exactly representable in floating point.
self.assertEqual(ret[0][0], 1.25)
self.assertEqual(ret[1][0], 0.0)
self.assertEqual(ret[2][0], 1.0)
class TestCondNestedControlFlow(unittest.TestCase):
def test_cond_inside_cond(self):
......@@ -277,6 +300,33 @@ class TestCondNestedControlFlow(unittest.TestCase):
self.assertEqual(ret[0][0], expected_ret)
self.assertEqual(ret[1][0], expected_a_grad)
def test_cond_op_in_condition(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
a = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=1.23)
a.stop_gradient = False
b = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=1.24)
b.stop_gradient = False
out = fluid.layers.cond(
a < b,
lambda: fluid.layers.cond(a - b < -1.0, lambda: fluid.layers.elementwise_add(a, b), lambda: fluid.layers.elementwise_mul(a, b)),
lambda: fluid.layers.cond(a == b, lambda: fluid.layers.elementwise_sub(a, b), lambda: fluid.layers.elementwise_pow(a, b))
)
append_backward(out)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
ret = exe.run(main_program, fetch_list=[out, a.grad_name, b.grad_name])
# Note: fill_constant loses precision, so we use assertAlmostEqual.
self.assertAlmostEqual(ret[0][0], 1.5252)
self.assertAlmostEqual(ret[1][0], 1.24)
self.assertAlmostEqual(ret[2][0], 1.23)
class TestCondBackward(unittest.TestCase):
def backward_value_helper(self, cond_func):
......