Unverified · Commit 557bce77 · authored by Huihuang Zheng, committed by GitHub

Fix Backward Bugs in Conditional Block (#21809)

The fixed bugs:

1. The condition sub-graph is not pruned.
2. When the backward graph is extremely simple, all of the backward ops are pruned (see the sketch after this list).
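
For reference, here is a minimal sketch of the bug-2 scenario, adapted from the test added later in this commit (test_extremely_simple_net_with_op_in_condition). Before this fix, calling append_backward on such a trivially simple cond pruned every backward op, so no gradients were computed:

# Minimal sketch of bug 2, adapted from the test added in this commit.
# Before the fix, the whole backward graph of this tiny net was pruned.
import paddle.fluid as fluid
from paddle.fluid.backward import append_backward

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    a = fluid.layers.fill_constant(shape=[1], dtype='float32', value=1.23)
    a.stop_gradient = False
    b = fluid.layers.fill_constant(shape=[1], dtype='float32', value=1.25)
    b.stop_gradient = False
    # a - b is -0.02, so the condition is False and cond returns b.
    out = fluid.layers.cond(a - b < -1.0, lambda: a, lambda: b)
    append_backward(out)

exe = fluid.Executor(fluid.CPUPlace())
ret = exe.run(main_program, fetch_list=[out, a.grad_name, b.grad_name])
print(ret)  # expected after the fix: [1.25], [0.], [1.]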
Parent: eab124ba
@@ -221,7 +221,15 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
     info->use_default_grad_op_desc_maker_ =
         std::is_base_of<DefaultGradOpMaker<OpDesc, true>, T>::value ||
-        std::is_base_of<DefaultGradOpMaker<OpDesc, false>, T>::value;
+        std::is_base_of<DefaultGradOpMaker<OpDesc, false>, T>::value ||
+        std::is_base_of<DefaultGradOpMaker<imperative::OpBase, true>,
+                        T>::value ||
+        std::is_base_of<DefaultGradOpMaker<imperative::OpBase, false>,
+                        T>::value;
+
+    info->use_empty_grad_op_desc_maker_ =
+        std::is_base_of<EmptyGradOpMaker<OpDesc>, T>::value ||
+        std::is_base_of<EmptyGradOpMaker<imperative::OpBase>, T>::value;
   }
 };
......
@@ -101,6 +101,7 @@ void LoadOpLib(const std::string &dso_name) {
     info.infer_no_need_buffer_vars_ = n.second.infer_no_need_buffer_vars_;
     info.use_default_grad_op_desc_maker_ =
         n.second.use_default_grad_op_desc_maker_;
+    info.use_empty_grad_op_desc_maker_ = n.second.use_empty_grad_op_desc_maker_;
     info_map.Insert(type, info);
   }
......
@@ -33,7 +33,8 @@ class InferShapeBase {
   virtual void operator()(InferShapeContext*) const = 0;
 };
 
-struct OpInfo {
+class OpInfo {
+ public:
   OpCreator creator_;
   GradOpMakerFN grad_op_maker_;
   proto::OpProto* proto_{nullptr};
@@ -48,6 +49,10 @@ struct OpInfo {
   // the grad maker is the default one.
   bool use_default_grad_op_desc_maker_{false};
 
+  // NOTE(huihuangzheng): this flag is added to check whether
+  // the grad maker is the empty one.
+  bool use_empty_grad_op_desc_maker_{false};
+
   bool HasOpProtoAndChecker() const {
     return proto_ != nullptr && checker_ != nullptr;
   }
@@ -82,9 +87,13 @@ struct OpInfo {
     return grad_op_maker_;
   }
 
-  // some op has no grad_op_maker, add check before use GradOpMaker()
+  // some ops don't have grad_op_maker, add check before use GradOpMaker()
   bool HasGradOpMaker() const { return grad_op_maker_ != nullptr; }
 
+  bool HasNonEmptyGradOpMaker() const {
+    return grad_op_maker_ != nullptr && !use_empty_grad_op_desc_maker_;
+  }
+
   const DygraphGradOpMakerFN& DygraphGradOpMaker() const {
     // Normally, proto_ should not be null, except some special operators, such
     // as LeaklyReluDoubleGrad op.
@@ -100,7 +109,7 @@ struct OpInfo {
   }
 
   bool HasDygraphGradOpMaker() const {
-    return dygraph_grad_op_maker_ != nullptr ? true : false;
+    return dygraph_grad_op_maker_ != nullptr;
   }
 
   bool HasInferInplace() const { return infer_inplace_ != nullptr; }
......
@@ -1100,6 +1100,11 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("has_grad_op_maker", [](const std::string op_type) {
     return framework::OpInfoMap::Instance().Get(op_type).HasGradOpMaker();
   });
+  m.def("has_non_empty_grad_op_maker", [](const std::string op_type) {
+    return framework::OpInfoMap::Instance()
+        .Get(op_type)
+        .HasNonEmptyGradOpMaker();
+  });
   m.def("has_infer_inplace", [](const std::string op_type) {
     return framework::OpInfoMap::Instance().Get(op_type).HasInferInplace();
   });
......
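
For reference, the new has_non_empty_grad_op_maker binding is callable from Python in the same way as has_grad_op_maker; a small usage sketch follows (the op names below are only illustrative, since the result for any given op depends on how its grad maker is registered on the C++ side):

# Sketch: querying the binding added above. An op whose grad maker is
# registered as the empty one is expected to report False.
import paddle.fluid.core as core

for op_type in ("elementwise_add", "fill_constant"):
    print(op_type, core.has_non_empty_grad_op_maker(op_type))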
@@ -497,7 +497,7 @@ def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set):
         to prune the unnecessary backward ops.
 
     Return:
-        (list[core.OpDesc]): A list of OpDescs which should be pruned.
+        (set[core.OpDesc]): A set of OpDescs which should be pruned.
     """
 
     class Var(object):
@@ -597,8 +597,13 @@ def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set):
                 break
         if remove_ops:
             not_need_op_descs.extend([node.op_desc for node in op_list])
-
-    return set(not_need_op_descs)
+    not_need_op_descs_set = set(not_need_op_descs)
+    grad_op_descs_set = set(grad_op_descs)
+    # If a backward computational graph is simply one sub-graph header, the
+    # not_need_op_descs will be whole graph, this IF clause avoids it.
+    if grad_op_descs_set == not_need_op_descs_set:
+        return set()
+    return not_need_op_descs_set
 
 
 from .proto import framework_pb2
@@ -797,9 +802,10 @@ def _get_sub_block_path(sub_block, sub_block_op_desc, no_grad_set):
         for op_desc in sub_block.ops:
             if op_desc.type == "assign" and var in op_desc.output_arg_names:
                 sub_assign_to_out_ops.append(op_desc)
-                sub_outputs.extend([
-                    sub_block.var(name) for name in op_desc.input_arg_names
-                ])
+                for name in op_desc.input_arg_names:
+                    if sub_block.has_var(name):
+                        sub_outputs.append(sub_block.var(name))
     sub_block_op_path = _find_op_path_(sub_block, sub_outputs, [],
                                        no_grad_set)
     # TODO better way than finding in list
@@ -1241,7 +1247,9 @@ def _find_op_path_(block, outputs, inputs, no_grad_set):
     # All the inputs of the block are used if inputs is empty,
     if inputs:
        for i, op in enumerate(block.ops):
-            if _some_in_set_(op.desc.input_arg_names(), input_names):
+            if _some_in_set_(
+                    op.desc.input_arg_names(),
+                    input_names) and core.has_non_empty_grad_op_maker(op.type):
                 for name in op.desc.output_arg_names():
                     if name not in no_grad_set:
                         input_names.add(name)
@@ -1249,7 +1257,9 @@ def _find_op_path_(block, outputs, inputs, no_grad_set):
                 relevant_op_flags[i] = False
 
     for i, op in reversed(list(enumerate(block.ops))):
-        if _some_in_set_(op.desc.output_arg_names(), output_names):
+        if _some_in_set_(
+                op.desc.output_arg_names(),
+                output_names) and core.has_non_empty_grad_op_maker(op.type):
             for name in op.desc.input_arg_names():
                 if name not in no_grad_set:
                     output_names.add(name)
......
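
To make the intent of the new check concrete, here is a toy, self-contained sketch of the forward pass of the op-path search (this is not Paddle code, and the op names are made up): an op whose grad maker is empty no longer marks its outputs as relevant, so forward-only ops inside the condition cannot drag the whole sub-graph into the backward pass.

# Toy illustration (not Paddle code) of the filter added to _find_op_path_.
def toy_forward_relevance(ops, input_names, no_grad_set, has_non_empty_grad):
    """ops: list of (op_type, input_names, output_names) tuples."""
    relevant = [True] * len(ops)
    names = set(input_names)
    for i, (op_type, ins, outs) in enumerate(ops):
        if any(n in names for n in ins) and has_non_empty_grad(op_type):
            for name in outs:
                if name not in no_grad_set:
                    names.add(name)
        else:
            relevant[i] = False
    return relevant

# "no_grad_op" stands for a hypothetical op registered with an empty grad
# maker: it is now skipped, so only the add-like op stays on the op path.
ops = [("no_grad_op", ["x"], ["t"]),
       ("elementwise_add", ["x", "y"], ["z"])]
print(toy_forward_relevance(ops, {"x", "y"}, set(),
                            lambda t: t != "no_grad_op"))  # [False, True]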
@@ -223,6 +223,29 @@ class TestCondInputOutput(unittest.TestCase):
             "Incompatible return values of true_fn and false_fn in cond" in
             str(e.exception))
 
+    def test_extremely_simple_net_with_op_in_condition(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            a = fluid.layers.fill_constant(
+                shape=[1], dtype='float32', value=1.23)
+            a.stop_gradient = False
+            b = fluid.layers.fill_constant(
+                shape=[1], dtype='float32', value=1.25)
+            b.stop_gradient = False
+            out = layers.cond(a - b < -1.0, lambda: a, lambda: b)
+            append_backward(out)
+
+        place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        ret = exe.run(main_program, fetch_list=[out, a.grad_name, b.grad_name])
+        # Note: fill_constant can lose precision, so assertEqual is only used
+        # with values that are exactly representable in floating point.
+        self.assertEqual(ret[0][0], 1.25)
+        self.assertEqual(ret[1][0], 0.0)
+        self.assertEqual(ret[2][0], 1.0)
+
 
 class TestCondNestedControlFlow(unittest.TestCase):
     def test_cond_inside_cond(self):
@@ -277,6 +300,33 @@ class TestCondNestedControlFlow(unittest.TestCase):
                 self.assertEqual(ret[0][0], expected_ret)
                 self.assertEqual(ret[1][0], expected_a_grad)
 
+    def test_cond_op_in_condition(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            a = fluid.layers.fill_constant(
+                shape=[1], dtype='float32', value=1.23)
+            a.stop_gradient = False
+            b = fluid.layers.fill_constant(
+                shape=[1], dtype='float32', value=1.24)
+            b.stop_gradient = False
+            out = fluid.layers.cond(
+                a < b,
+                lambda: fluid.layers.cond(a - b < -1.0, lambda: fluid.layers.elementwise_add(a, b), lambda: fluid.layers.elementwise_mul(a, b)),
+                lambda: fluid.layers.cond(a == b, lambda: fluid.layers.elementwise_sub(a, b), lambda: fluid.layers.elementwise_pow(a, b))
+            )
+            append_backward(out)
+
+        place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        ret = exe.run(main_program, fetch_list=[out, a.grad_name, b.grad_name])
+        # Note: fill_constant can lose precision, so we use assertAlmostEqual.
+        self.assertAlmostEqual(ret[0][0], 1.5252)
+        self.assertAlmostEqual(ret[1][0], 1.24)
+        self.assertAlmostEqual(ret[2][0], 1.23)
+
 
 class TestCondBackward(unittest.TestCase):
     def backward_value_helper(self, cond_func):
......