Unverified · Commit 630be319 authored by Huihuang Zheng, committed by GitHub

Fix Cond Bug for Nested Control Flow (#21340)

* Commit before merging develop (test=develop)
* Backup after working with Huihuang logs
* Commit before deleting Huihuang debug loggings
* Commit before debug (test=develop)
* Fix bug commit (test=develop)
* Backup of fixing bugs (test=develop)
* Clean up code (test=develop)
* Fix a bug in sum_op (test=develop)
Parent: cd43c444
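For context, a condensed sketch of the nested `cond` pattern whose backward pass this commit fixes (it mirrors the `TestCondNestedControlFlow` case added below; variable names are illustrative):

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.backward import append_backward

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    i = fluid.data(name="i", shape=[1], dtype='float32')
    a = 2.0 * i
    # each branch of the outer cond contains another cond
    out = layers.cond(i < 5.0,
                      lambda: layers.cond(i >= 3.0,
                                          lambda: layers.elementwise_add(a, a),
                                          lambda: layers.elementwise_sub(a, a)),
                      lambda: layers.cond(i < 8.0,
                                          lambda: layers.elementwise_mul(a, a),
                                          lambda: layers.elementwise_div(a, a)))
    mean = layers.mean(out)
    # before this fix, appending backward for nested conditional blocks
    # could fail or produce wrong gradients
    append_backward(mean)

exe = fluid.Executor(fluid.CPUPlace())
ret = exe.run(main_program,
              feed={'i': np.full((1), 3.0, np.float32)},
              fetch_list=[out.name, a.grad_name])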
......@@ -120,13 +120,12 @@ void Executor::Close() {
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
int block_id) {
VLOG(3) << "Creating Variables for block " << block_id;
auto& global_block = pdesc.Block(block_id);
const Scope* ancestor_scope = scope;
while (ancestor_scope->parent()) {
ancestor_scope = ancestor_scope->parent();
}
if (ancestor_scope != scope) {
for (auto& var : global_block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
......@@ -196,11 +195,12 @@ void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars,
const std::vector<std::string>& skip_ref_cnt_vars,
bool force_disable_gc) {
bool force_disable_gc, bool keep_kid_scopes) {
platform::RecordBlock b(block_id);
if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
keep_kid_scopes);
}
// Check whether the block already has feed operators and feed_holder.
......@@ -465,6 +465,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
// the sub scopes it created should not be dropped immediately, because
// while_grad_op will use some variables created during while_op run, so
// we need to keep the kids and wait for the outer executor to drop them.
scope->DropKids();
}
}
......
......@@ -71,12 +71,18 @@ class Executor {
* @param
* ProgramDesc
* Scope
* block_id
* create_local_scope
* create_vars
* skip_ref_cnt_vars
* force_disable_gc
* keep_kid_scopes
*/
void Run(const ProgramDesc& prog, Scope* scope, int block_id,
bool create_local_scope = true, bool create_vars = true,
const std::vector<std::string>& skip_ref_cnt_vars =
std::vector<std::string>(),
bool force_disable_gc = false);
bool force_disable_gc = false, bool keep_kid_scopes = false);
// This API is very slow.
void Run(const ProgramDesc& program, Scope* scope,
......
......@@ -93,6 +93,11 @@ const Scope* Scope::FindScope(const Variable* var) const {
return FindScopeInternal(var);
}
const Scope* Scope::FindScope(const std::string& name) const {
SCOPE_VARS_READER_LOCK
return FindScopeInternal(name);
}
void Scope::DropKids() {
SCOPE_KIDS_WRITER_LOCK
for (Scope* s : kids_) delete s;
......@@ -174,6 +179,13 @@ const Scope* Scope::FindScopeInternal(const Variable* var) const {
return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
}
const Scope* Scope::FindScopeInternal(const std::string& name) const {
if (vars_.find(name) != vars_.end()) {
return this;
}
return (parent_ == nullptr) ? nullptr : parent_->FindScope(name);
}
void Scope::RenameInternal(const std::string& origin_name,
const std::string& new_name) const {
auto origin_it = vars_.find(origin_name);
......@@ -196,7 +208,9 @@ Variable* Scope::FindVarInternal(const std::string& name) const {
Variable* Scope::FindVarLocally(const std::string& name) const {
auto it = vars_.find(name);
if (it != vars_.end()) return it->second.get();
if (it != vars_.end()) {
return it->second.get();
}
return nullptr;
}
......
......@@ -85,6 +85,9 @@ class Scope {
/// Find the scope or an ancestor scope that contains the given variable.
const Scope* FindScope(const Variable* var) const;
/// Find the scope or an ancestor scope that contains the given variable name.
const Scope* FindScope(const std::string& name) const;
void DeleteScope(Scope* scope) const;
/// Drop all kid scopes belonging to this scope.
......@@ -125,6 +128,9 @@ class Scope {
// Called by FindScope.
const Scope* FindScopeInternal(const Variable* var) const;
// Called by FindScope.
const Scope* FindScopeInternal(const std::string& name) const;
// Called by Rename.
void RenameInternal(const std::string& origin_name,
const std::string& new_name) const;
......
......@@ -53,7 +53,9 @@ TEST(Scope, FindScope) {
Variable* v = s.Var("a");
EXPECT_EQ(&s, s.FindScope(v));
EXPECT_EQ(&s, s.FindScope("a"));
EXPECT_EQ(&s, ss.FindScope(v));
EXPECT_EQ(&s, ss.FindScope("a"));
}
TEST(Scope, GetAllNames) {
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/controlflow/conditional_block_op.h"
#include "paddle/fluid/operators/assign_op.h"
namespace paddle {
namespace operators {
......@@ -58,13 +59,15 @@ class ConditionalBlockOp : public ConditionalOp {
scopes->resize(1);
scopes->front() = &scope.NewScope();
auto &cur_scope = *scopes->front();
framework::Executor exec(dev_place);
auto *block = Attr<framework::BlockDesc *>("sub_block");
VLOG(3) << "Conditional block.idx = " << block->ID()
<< ", scope = " << &cur_scope;
auto &skip_vars =
Attr<std::vector<std::string>>(ConditionalOp::kSkipEagerDeletionVars);
exec.Run(*block->Program(), &cur_scope, block->ID(), false, true,
skip_vars);
skip_vars, /* force_disable_gc */ false,
/* keep_kid_scopes */ true);
}
}
};
......@@ -92,60 +95,65 @@ class ConditionalBlockGradOp : public ConditionalOp {
}
if (need_run) {
const auto &inputs = Inputs(ConditionalOp::kInputs);
const auto &outside_grads =
Outputs(framework::GradVarName(ConditionalOp::kInputs));
std::vector<std::string> inside_grads;
inside_grads.reserve(inputs.size());
for (auto &in : inputs) {
inside_grads.emplace_back(framework::GradVarName(in));
}
auto *scope_var = scope.FindVar(Input(ConditionalOp::kScope));
PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
PADDLE_ENFORCE_NE(scope_var, nullptr,
platform::errors::InvalidArgument(
"Scope must be set in conditional block op"));
auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
PADDLE_ENFORCE_GT(scopes.size(), 0,
platform::errors::InvalidArgument(
"Scope must be set in conditional block op"));
framework::Scope &cur_scope = *scopes[0];
framework::Executor exec(dev_place);
auto *block = Attr<framework::BlockDesc *>("sub_block");
const auto &ins = Inputs(ConditionalOp::kInputs);
const auto &d_ins =
Outputs(framework::GradVarName(ConditionalOp::kInputs));
const auto &conds = Inputs(ConditionalOp::kCondition);
const auto &d_conds =
Outputs(framework::GradVarName(ConditionalOp::kCondition));
std::vector<std::string> ins_conds_grads;
ins_conds_grads.reserve(ins.size() + conds.size());
for (auto &in : ins) {
ins_conds_grads.emplace_back(framework::GradVarName(in));
}
for (auto &cond : conds) {
ins_conds_grads.emplace_back(framework::GradVarName(cond));
}
VLOG(3) << "Conditional Grad block.idx = " << block->ID()
<< ", scope = " << &cur_scope;
exec.Run(*block->Program(), &cur_scope, block->ID(), false, true,
ins_conds_grads);
AssignLocalGradientToGlobal(dev_place, cur_scope, ins_conds_grads.data(),
ins.size(), d_ins);
inside_grads, /* force_disable_gc */ false,
/* keep_kid_scopes */ false);
AssignLocalGradientToGlobal(dev_place, cur_scope,
ins_conds_grads.data() + ins.size(),
conds.size(), d_conds);
AssignLocalGradientToParentScope(dev_place, cur_scope, scope,
inside_grads, outside_grads);
}
}
private:
void AssignLocalGradientToGlobal(
void AssignLocalGradientToParentScope(
const platform::Place &place, const framework::Scope &cur_scope,
const std::string *p_grad_names, size_t p_grad_names_num,
const std::vector<std::string> &pg_names) const {
for (size_t i = 0; i < p_grad_names_num; ++i) {
auto out_grad_name = pg_names[i];
const auto &in_grad_name = p_grad_names[i];
auto *in_var = cur_scope.FindVar(in_grad_name);
if (in_var == nullptr) {
const framework::Scope &parent_scope,
const std::vector<std::string> &inside_grads,
const std::vector<std::string> &outside_grads) const {
for (size_t i = 0; i < outside_grads.size(); ++i) {
const std::string &outside_grad_name = outside_grads[i];
const std::string &inside_grad_name = inside_grads[i];
VLOG(4) << "inside_grad_name = " << inside_grad_name
<< ", outside_grad_name = " << outside_grad_name;
framework::Variable *inside_var =
cur_scope.FindLocalVar(inside_grad_name);
if (inside_var == nullptr) {
continue;
}
auto new_in_grad_name = cur_scope.Rename(in_grad_name);
auto assign = framework::OpRegistry::CreateOp(
"assign", {{"X", {new_in_grad_name}}}, {{"Out", {out_grad_name}}},
framework::AttributeMap{});
assign->Run(cur_scope, place);
cur_scope.Rename(new_in_grad_name, in_grad_name);
framework::Variable *outside_var =
parent_scope.FindVar(outside_grad_name);
if (outside_var == nullptr) {
continue;
}
platform::DeviceContext *dev_ctx =
platform::DeviceContextPool::Instance().Get(place);
framework::VisitVarType(*inside_var,
AssignFunctor(outside_var, *dev_ctx));
}
}
};
......@@ -154,17 +162,11 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE(context->HasInputs(ConditionalOp::kCondition));
if (context->HasInputs(ConditionalOp::kInputs)) {
PADDLE_ENFORCE(
context->HasOutputs(framework::GradVarName(ConditionalOp::kInputs)));
if (context->HasInputs(ConditionalOp::kInputs) &&
context->HasOutputs(framework::GradVarName(ConditionalOp::kInputs))) {
context->SetOutputsDim(framework::GradVarName(ConditionalOp::kInputs),
context->GetInputsDim(ConditionalOp::kInputs));
}
if (context->HasOutputs(
framework::GradVarName(ConditionalOp::kCondition))) {
context->SetOutputsDim(framework::GradVarName(ConditionalOp::kCondition),
context->GetInputsDim(ConditionalOp::kCondition));
}
}
};
......@@ -187,8 +189,6 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> {
this->OutputGrad(ConditionalOp::kOutputs));
grad_op->SetInput(ConditionalOp::kScope,
this->Output(ConditionalOp::kScope));
grad_op->SetOutput(framework::GradVarName(ConditionalOp::kCondition),
this->InputGrad(ConditionalOp::kCondition, false));
grad_op->SetOutput(framework::GradVarName(ConditionalOp::kInputs),
this->InputGrad(ConditionalOp::kInputs, false));
grad_op->SetBlockAttr("sub_block", this->grad_block_[0]);
......
......@@ -15,6 +15,7 @@
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/assign_op.h"
namespace paddle {
namespace operators {
......
......@@ -128,19 +128,19 @@ void SumToLoDTensor(const framework::ExecutionContext &context) {
in_vars[1]->IsType<framework::LoDTensor>()) {
auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
auto length = in_0.numel();
if (length && in_0.IsInitialized() && in_1.IsInitialized()) {
int64_t length_0 = in_0.numel();
int64_t length_1 = in_1.numel();
if (length_0 && length_1 && in_0.IsInitialized() && in_1.IsInitialized()) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
auto in_0_e = EigenVector<T>::Flatten(in_0);
auto in_1_e = EigenVector<T>::Flatten(in_1);
result.device(place) = in_0_e + in_1_e;
} else if (length && in_0.IsInitialized()) {
} else if (length_0 && in_0.IsInitialized()) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
result.device(place) = EigenVector<T>::Flatten(in_0);
} else if (length && in_1.IsInitialized()) {
} else if (length_1 && in_1.IsInitialized()) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
result.device(place) = EigenVector<T>::Flatten(in_1);
......
......@@ -321,7 +321,7 @@ def _append_grad_suffix_(name):
return cpt.to_text(name) + core.grad_var_suffix()
def _addup_repetitive_outputs_(op_descs):
def _addup_repetitive_outputs_(op_descs, block_idx):
"""
In the backward pass, a variable may be the output of more than one op,
and one op may yield multiple of its outputs to the same variable.
......@@ -358,7 +358,7 @@ def _addup_repetitive_outputs_(op_descs):
renamed_var_start_idx[var_name] = idx
else:
if len(renamed_vars[var_name]) == 1:
new_name = var_name + "@RENAME@" + \
new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \
str(var_rename_count[var_name])
var_rename_count[var_name] += 1
# rename original var_name
......@@ -384,7 +384,7 @@ def _addup_repetitive_outputs_(op_descs):
for x in arg_names[:arg_idx]
] + arg_names[arg_idx:]
new_name = var_name + "@RENAME@" + \
new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \
str(var_rename_count[var_name])
var_rename_count[var_name] += 1
arg_names[arg_idx] = new_name
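The two renames above now embed the block index. A tiny illustration of why (the helper name below is hypothetical, for demonstration only): gradient temporaries created for the same variable in different blocks no longer share a name.

def renamed_grad_var(var_name, block_idx, rename_count):
    # mirrors the pattern var_name + "@RENAME@block<idx>@<count>" used above
    return var_name + "@RENAME@block" + str(block_idx) + "@" + str(rename_count)

assert renamed_grad_var("x@GRAD", 0, 0) != renamed_grad_var("x@GRAD", 1, 0)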
......@@ -733,7 +733,7 @@ def _append_backward_ops_with_checkpoints_(
grad_to_var.update(op_grad_to_var)
# 3.d. add sum op for repetitive_outputs
grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)
grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx)
# 4) remove no grad branch as it is in _remove_no_grad_branch_
grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
no_grad_dict[block.idx])
......@@ -741,6 +741,35 @@ def _append_backward_ops_with_checkpoints_(
return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments
def _get_sub_block_path(sub_block, sub_block_op_desc, no_grad_set):
"""
Get the output variables in the sub-block that will be assigned to the parent block.
It is used to find the grad path in the sub-block.
"""
assert sub_block_op_desc.has_attr(
"sub_block") and sub_block.idx == sub_block_op_desc._block_attr_id(
"sub_block")
# TODO(huihuangzheng): add support for recurrent op and while op
if sub_block_op_desc.type == "conditional_block":
sub_outputs = []
sub_assign_to_out_ops = []
for var in sub_block_op_desc.output_arg_names:
for op_desc in sub_block.ops:
if op_desc.type == "assign" and var in op_desc.output_arg_names:
sub_assign_to_out_ops.append(op_desc)
sub_outputs.extend([
sub_block.var(name) for name in op_desc.input_arg_names
])
sub_block_op_path = _find_op_path_(sub_block, sub_outputs, [],
no_grad_set)
# TODO: find a better way than searching the list
for op_desc in sub_assign_to_out_ops:
if op_desc not in sub_block_op_path:
sub_block_op_path.append(op_desc)
return sub_block_op_path
return sub_block.ops
def _append_backward_ops_(block,
ops,
target_block,
......@@ -775,6 +804,8 @@ def _append_backward_ops_(block,
# grad_op_descs holds created grad_op, and will be appended to target_block
grad_op_descs = []
program = block.program
# add grad_op_desc by reversed ops
for op in reversed(ops):
grad_sub_block_list = []
# If the op has its own sub-block, deal with the sub-block first
......@@ -785,7 +816,9 @@ def _append_backward_ops_(block,
# see following comments for why we set None here.
pre_input_grad_names_set = copy.copy(input_grad_names_set)
input_grad_names_set = None
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
sub_block_path = _get_sub_block_path(sub_block, op,
no_grad_dict[sub_block.idx])
_append_backward_ops_(sub_block, sub_block_path, grad_sub_block,
no_grad_dict, grad_to_var, callbacks,
input_grad_names_set)
input_grad_names_set = pre_input_grad_names_set
......@@ -825,10 +858,8 @@ def _append_backward_ops_(block,
grad_op_descs.extend(grad_op_desc)
grad_to_var.update(op_grad_to_var)
# add grad_op_desc by reversed ops
# sum a parameter's gradient variables when multiple gradients are produced for it
grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)
grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx)
# if all outputs of the grad op are in no_grad_set, then just remove and fill zero
# if all inputs of the grad op are in no_grad_set, just remove this op
......@@ -841,6 +872,7 @@ def _append_backward_ops_(block,
grad_op_descs = [
op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
]
# append op_desc in grad_op_descs to target_block
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
......
......@@ -22,7 +22,7 @@ from ..framework import Program, Variable, Operator
from ..layer_helper import LayerHelper, unique_name
from ..initializer import force_init_on_cpu
from .nn import logical_and, logical_not, logical_or
from .utils import assert_same_structure, flatten, map_structure
from .utils import assert_same_structure, map_structure
import numpy
import warnings
import six
......@@ -1710,7 +1710,6 @@ class ConditionalBlock(object):
param_list = [
parent_block._var_recursive(each_name) for each_name in params
if each_name not in input_set
]
out_list = []
......@@ -1755,7 +1754,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
helper = LayerHelper('cond', **locals())
true_output = None
false_output = None
copy_to_global_func = lambda var: copy_var_to_parent_block(var, helper)
copy_to_parent_func = lambda var: copy_var_to_parent_block(var, helper)
if true_fn is not None:
if not callable(true_fn):
raise TypeError("The true_fn in cond must be callable")
......@@ -1763,7 +1762,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
with true_cond_block.block():
origin_true_output = true_fn()
if origin_true_output is not None:
true_output = map_structure(copy_to_global_func,
true_output = map_structure(copy_to_parent_func,
origin_true_output)
if false_fn is not None:
if not callable(false_fn):
......@@ -1773,7 +1772,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
with false_cond_block.block():
origin_false_output = false_fn()
if origin_false_output is not None:
false_output = map_structure(copy_to_global_func,
false_output = map_structure(copy_to_parent_func,
origin_false_output)
if true_output is None and false_output is None:
......
......@@ -21,11 +21,12 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
import paddle.fluid.framework as framework
from paddle.fluid.backward import append_backward
from paddle.fluid.executor import Executor
from paddle.fluid.framework import Program, program_guard
class TestCond(unittest.TestCase):
class TestCondInputOutput(unittest.TestCase):
def test_return_single_var(self):
"""
pseudocode:
......@@ -220,5 +221,59 @@ class TestCond(unittest.TestCase):
str(e.exception))
class TestCondNestedControlFlow(unittest.TestCase):
def test_cond_inside_cond(self):
"""
pseudocode:
for i in range(1, 10):
a = 2 * i
if i < 5:
if i >= 3:
return a + a
else:
return a - a
else:
if i < 8:
return a * a
else:
return a / a
"""
def less_than_branch(i, a):
return layers.cond(i >= 3.0, lambda: layers.elementwise_add(a, a),
lambda: layers.elementwise_sub(a, a))
def greater_equal_branch(i, a):
return layers.cond(i < 8.0, lambda: layers.elementwise_mul(a, a),
lambda: layers.elementwise_div(a, a))
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
i = fluid.data(name="i", shape=[1], dtype='float32')
a = 2.0 * i
out = layers.cond(i < 5.0, lambda: less_than_branch(i, a),
lambda: greater_equal_branch(i, a))
mean = layers.mean(out)
append_backward(mean)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
for feed_i in range(0, 10):
expected_a = 2.0 * feed_i
if feed_i < 5:
expected_ret = expected_a + expected_a if feed_i >= 3 else 0.0
expected_a_grad = 2.0 if feed_i >= 3 else 0.0
else:
expected_ret = expected_a * expected_a if feed_i < 8 else 1.0
expected_a_grad = 2.0 * expected_a if feed_i < 8 else 0.0
ret = exe.run(main_program,
feed={'i': np.full((1), feed_i, np.float32)},
fetch_list=[out.name, a.grad_name])
self.assertEqual(ret[0][0], expected_ret)
self.assertEqual(ret[1][0], expected_a_grad)
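For readers checking the expected values in the loop above, they follow directly from differentiating each branch with respect to a (out has a single element, so mean(out) does not rescale the gradient); a sketch, not part of the test:

def expected_out_and_grad(feed_i):
    # i < 3      : out = a - a = 0   -> d out / d a = 0
    # 3 <= i < 5 : out = a + a       -> d out / d a = 2
    # 5 <= i < 8 : out = a * a       -> d out / d a = 2 * a
    # i >= 8     : out = a / a = 1   -> d out / d a = 0
    a = 2.0 * feed_i
    if feed_i < 5:
        return (a + a, 2.0) if feed_i >= 3 else (0.0, 0.0)
    return (a * a, 2.0 * a) if feed_i < 8 else (1.0, 0.0)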
if __name__ == '__main__':
unittest.main()