Unverified commit 630be319, authored by Huihuang Zheng, committed by GitHub

Fix Cond Bug for Nested Control Flow (#21340)

* Commit before merging develop

test=develop

* Backup after working with Huihuang logs

* Commit before deleting Huihuang debug loggings

* Commit before debug

test=develop

* Fix bug commit

test=develop

* Backup of fixing bugs

test=develop

* Clean up code

test=develop

* Fix a bug in sum_op

test=develop
Parent: cd43c444
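
For orientation, here is a minimal sketch (not part of the commit itself) of the nested `cond` plus `append_backward` pattern whose backward pass this change fixes. It is adapted from the `TestCondNestedControlFlow` unit test added below and assumes the `paddle.fluid` 1.x API:

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.backward import append_backward

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    i = fluid.data(name="i", shape=[1], dtype='float32')
    a = 2.0 * i
    # The outer cond picks a branch that itself contains another cond,
    # so the inner block's grad ops run in a nested sub-scope.
    out = layers.cond(i < 5.0,
                      lambda: layers.cond(i >= 3.0,
                                          lambda: layers.elementwise_add(a, a),
                                          lambda: layers.elementwise_sub(a, a)),
                      lambda: layers.elementwise_mul(a, a))
    mean = layers.mean(out)
    # Building the backward pass over a nested cond is the case this commit targets.
    append_backward(mean)

exe = fluid.Executor(fluid.CPUPlace())
ret = exe.run(main_program,
              feed={'i': np.full((1), 3.0, np.float32)},
              fetch_list=[out.name, a.grad_name])
print(ret)  # forward value of out and d(mean)/d(a) for i == 3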
@@ -120,13 +120,12 @@ void Executor::Close() {
 void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
                                int block_id) {
-  VLOG(3) << "Creating Variables for block " << block_id;
   auto& global_block = pdesc.Block(block_id);
   const Scope* ancestor_scope = scope;
   while (ancestor_scope->parent()) {
     ancestor_scope = ancestor_scope->parent();
   }
   if (ancestor_scope != scope) {
     for (auto& var : global_block.AllVars()) {
       if (var->Name() == framework::kEmptyVarName) {
@@ -196,11 +195,12 @@ void Executor::ReleaseTrainer(std::shared_ptr<TrainerBase> trainer) {
 void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                    bool create_local_scope, bool create_vars,
                    const std::vector<std::string>& skip_ref_cnt_vars,
-                   bool force_disable_gc) {
+                   bool force_disable_gc, bool keep_kid_scopes) {
   platform::RecordBlock b(block_id);
   if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
   auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
-  RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars);
+  RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
+                     keep_kid_scopes);
 }
 
 // Check whether the block already has feed operators and feed_holder.
@@ -465,6 +465,7 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
       // the sub scopes it created should not be dropped immediately, because
       // while_grad_op will use some variables created during while_op run, so
       // we need to keep the kids and wait for the outer executor to drop them.
       scope->DropKids();
     }
   }
......
@@ -71,12 +71,18 @@ class Executor {
    * @param
    *  ProgramDesc
    *  Scope
+   *  block_id
+   *  create_local_scope
+   *  create_vars
+   *  skip_ref_cnt_vars
+   *  force_disable_gc
+   *  keep_kid_scopes
    */
   void Run(const ProgramDesc& prog, Scope* scope, int block_id,
            bool create_local_scope = true, bool create_vars = true,
            const std::vector<std::string>& skip_ref_cnt_vars =
                std::vector<std::string>(),
-           bool force_disable_gc = false);
+           bool force_disable_gc = false, bool keep_kid_scopes = false);
 
   // This API is very slow.
   void Run(const ProgramDesc& program, Scope* scope,
......
@@ -93,6 +93,11 @@ const Scope* Scope::FindScope(const Variable* var) const {
   return FindScopeInternal(var);
 }
 
+const Scope* Scope::FindScope(const std::string& name) const {
+  SCOPE_VARS_READER_LOCK
+  return FindScopeInternal(name);
+}
+
 void Scope::DropKids() {
   SCOPE_KIDS_WRITER_LOCK
   for (Scope* s : kids_) delete s;
@@ -174,6 +179,13 @@ const Scope* Scope::FindScopeInternal(const Variable* var) const {
   return (parent_ == nullptr) ? nullptr : parent_->FindScope(var);
 }
 
+const Scope* Scope::FindScopeInternal(const std::string& name) const {
+  if (vars_.find(name) != vars_.end()) {
+    return this;
+  }
+  return (parent_ == nullptr) ? nullptr : parent_->FindScope(name);
+}
+
 void Scope::RenameInternal(const std::string& origin_name,
                            const std::string& new_name) const {
   auto origin_it = vars_.find(origin_name);
@@ -196,7 +208,9 @@ Variable* Scope::FindVarInternal(const std::string& name) const {
 Variable* Scope::FindVarLocally(const std::string& name) const {
   auto it = vars_.find(name);
-  if (it != vars_.end()) return it->second.get();
+  if (it != vars_.end()) {
+    return it->second.get();
+  }
   return nullptr;
 }
......
@@ -85,6 +85,9 @@ class Scope {
   /// Find the scope or an ancestor scope that contains the given variable.
   const Scope* FindScope(const Variable* var) const;
 
+  /// Find the scope or an ancestor scope that contains the given variable name.
+  const Scope* FindScope(const std::string& name) const;
+
   void DeleteScope(Scope* scope) const;
 
   /// Drop all kids scopes belonged to this scope.
@@ -125,6 +128,9 @@
   // Called by FindScope.
   const Scope* FindScopeInternal(const Variable* var) const;
 
+  // Called by FindScope.
+  const Scope* FindScopeInternal(const std::string& name) const;
+
   // Called by Rename.
   void RenameInternal(const std::string& origin_name,
                       const std::string& new_name) const;
......
@@ -53,7 +53,9 @@ TEST(Scope, FindScope) {
   Variable* v = s.Var("a");
 
   EXPECT_EQ(&s, s.FindScope(v));
+  EXPECT_EQ(&s, s.FindScope("a"));
   EXPECT_EQ(&s, ss.FindScope(v));
+  EXPECT_EQ(&s, ss.FindScope("a"));
 }
 
 TEST(Scope, GetAllNames) {
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/controlflow/conditional_block_op.h"
+#include "paddle/fluid/operators/assign_op.h"
 
 namespace paddle {
 namespace operators {
@@ -58,13 +59,15 @@ class ConditionalBlockOp : public ConditionalOp {
       scopes->resize(1);
       scopes->front() = &scope.NewScope();
       auto &cur_scope = *scopes->front();
 
       framework::Executor exec(dev_place);
       auto *block = Attr<framework::BlockDesc *>("sub_block");
+      VLOG(3) << "Conditional block.idx = " << block->ID()
+              << ", scope = " << &cur_scope;
       auto &skip_vars =
           Attr<std::vector<std::string>>(ConditionalOp::kSkipEagerDeletionVars);
       exec.Run(*block->Program(), &cur_scope, block->ID(), false, true,
-               skip_vars);
+               skip_vars, /* force_disable_gc */ false,
+               /* keep_kid_scopes */ true);
     }
   }
 };
@@ -92,60 +95,65 @@ class ConditionalBlockGradOp : public ConditionalOp {
     }
 
     if (need_run) {
+      const auto &inputs = Inputs(ConditionalOp::kInputs);
+      const auto &outside_grads =
+          Outputs(framework::GradVarName(ConditionalOp::kInputs));
+      std::vector<std::string> inside_grads;
+      inside_grads.reserve(inputs.size());
+      for (auto &in : inputs) {
+        inside_grads.emplace_back(framework::GradVarName(in));
+      }
+
       auto *scope_var = scope.FindVar(Input(ConditionalOp::kScope));
-      PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
+      PADDLE_ENFORCE_NE(scope_var, nullptr,
+                        platform::errors::InvalidArgument(
+                            "Scope must be set in conditional block op"));
       auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
+      PADDLE_ENFORCE_GT(scopes.size(), 0,
+                        platform::errors::InvalidArgument(
+                            "Scope must be set in conditional block op"));
       framework::Scope &cur_scope = *scopes[0];
 
       framework::Executor exec(dev_place);
       auto *block = Attr<framework::BlockDesc *>("sub_block");
-      const auto &ins = Inputs(ConditionalOp::kInputs);
-      const auto &d_ins =
-          Outputs(framework::GradVarName(ConditionalOp::kInputs));
-      const auto &conds = Inputs(ConditionalOp::kCondition);
-      const auto &d_conds =
-          Outputs(framework::GradVarName(ConditionalOp::kCondition));
-      std::vector<std::string> ins_conds_grads;
-      ins_conds_grads.reserve(ins.size() + conds.size());
-      for (auto &in : ins) {
-        ins_conds_grads.emplace_back(framework::GradVarName(in));
-      }
-      for (auto &cond : conds) {
-        ins_conds_grads.emplace_back(framework::GradVarName(cond));
-      }
+      VLOG(3) << "Conditional Grad block.idx = " << block->ID()
+              << ", scope = " << &cur_scope;
       exec.Run(*block->Program(), &cur_scope, block->ID(), false, true,
-               ins_conds_grads);
-      AssignLocalGradientToGlobal(dev_place, cur_scope, ins_conds_grads.data(),
-                                  ins.size(), d_ins);
-      AssignLocalGradientToGlobal(dev_place, cur_scope,
-                                  ins_conds_grads.data() + ins.size(),
-                                  conds.size(), d_conds);
+               inside_grads, /* force_disable_gc */ false,
+               /* keep_kid_scopes */ false);
+
+      AssignLocalGradientToParentScope(dev_place, cur_scope, scope,
+                                       inside_grads, outside_grads);
     }
   }
 
  private:
-  void AssignLocalGradientToGlobal(
+  void AssignLocalGradientToParentScope(
       const platform::Place &place, const framework::Scope &cur_scope,
-      const std::string *p_grad_names, size_t p_grad_names_num,
-      const std::vector<std::string> &pg_names) const {
-    for (size_t i = 0; i < p_grad_names_num; ++i) {
-      auto out_grad_name = pg_names[i];
-      const auto &in_grad_name = p_grad_names[i];
-      auto *in_var = cur_scope.FindVar(in_grad_name);
-      if (in_var == nullptr) {
+      const framework::Scope &parent_scope,
+      const std::vector<std::string> &inside_grads,
+      const std::vector<std::string> &outside_grads) const {
+    for (size_t i = 0; i < outside_grads.size(); ++i) {
+      const std::string &outside_grad_name = outside_grads[i];
+      const std::string &inside_grad_name = inside_grads[i];
+      VLOG(4) << "inside_grad_name = " << inside_grad_name
+              << ", outside_grad_name = " << outside_grad_name;
+      framework::Variable *inside_var =
+          cur_scope.FindLocalVar(inside_grad_name);
+      if (inside_var == nullptr) {
        continue;
      }
-      auto new_in_grad_name = cur_scope.Rename(in_grad_name);
-      auto assign = framework::OpRegistry::CreateOp(
-          "assign", {{"X", {new_in_grad_name}}}, {{"Out", {out_grad_name}}},
-          framework::AttributeMap{});
-      assign->Run(cur_scope, place);
-      cur_scope.Rename(new_in_grad_name, in_grad_name);
+      framework::Variable *outside_var =
+          parent_scope.FindVar(outside_grad_name);
+      if (outside_var == nullptr) {
+        continue;
+      }
+      platform::DeviceContext *dev_ctx =
+          platform::DeviceContextPool::Instance().Get(place);
+      framework::VisitVarType(*inside_var,
+                              AssignFunctor(outside_var, *dev_ctx));
     }
   }
 };
@@ -154,17 +162,11 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext *context) const override {
     PADDLE_ENFORCE(context->HasInputs(ConditionalOp::kCondition));
-    if (context->HasInputs(ConditionalOp::kInputs)) {
-      PADDLE_ENFORCE(
-          context->HasOutputs(framework::GradVarName(ConditionalOp::kInputs)));
+    if (context->HasInputs(ConditionalOp::kInputs) &&
+        context->HasOutputs(framework::GradVarName(ConditionalOp::kInputs))) {
       context->SetOutputsDim(framework::GradVarName(ConditionalOp::kInputs),
                              context->GetInputsDim(ConditionalOp::kInputs));
     }
-    if (context->HasOutputs(
-            framework::GradVarName(ConditionalOp::kCondition))) {
-      context->SetOutputsDim(framework::GradVarName(ConditionalOp::kCondition),
-                             context->GetInputsDim(ConditionalOp::kCondition));
-    }
   }
 };
@@ -187,8 +189,6 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> {
                       this->OutputGrad(ConditionalOp::kOutputs));
     grad_op->SetInput(ConditionalOp::kScope,
                       this->Output(ConditionalOp::kScope));
-    grad_op->SetOutput(framework::GradVarName(ConditionalOp::kCondition),
-                       this->InputGrad(ConditionalOp::kCondition, false));
     grad_op->SetOutput(framework::GradVarName(ConditionalOp::kInputs),
                        this->InputGrad(ConditionalOp::kInputs, false));
     grad_op->SetBlockAttr("sub_block", this->grad_block_[0]);
......
@@ -15,6 +15,7 @@
 #include <algorithm>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/var_type.h"
+#include "paddle/fluid/operators/assign_op.h"
 
 namespace paddle {
 namespace operators {
......
@@ -128,19 +128,19 @@ void SumToLoDTensor(const framework::ExecutionContext &context) {
       in_vars[1]->IsType<framework::LoDTensor>()) {
     auto &in_0 = in_vars[0]->Get<framework::LoDTensor>();
     auto &in_1 = in_vars[1]->Get<framework::LoDTensor>();
-
-    auto length = in_0.numel();
-    if (length && in_0.IsInitialized() && in_1.IsInitialized()) {
+    int64_t length_0 = in_0.numel();
+    int64_t length_1 = in_1.numel();
+    if (length_0 && length_1 && in_0.IsInitialized() && in_1.IsInitialized()) {
       auto result = EigenVector<T>::Flatten(*out);
       auto &place = *dev_ctx.eigen_device();
       auto in_0_e = EigenVector<T>::Flatten(in_0);
       auto in_1_e = EigenVector<T>::Flatten(in_1);
       result.device(place) = in_0_e + in_1_e;
-    } else if (length && in_0.IsInitialized()) {
+    } else if (length_0 && in_0.IsInitialized()) {
       auto result = EigenVector<T>::Flatten(*out);
       auto &place = *dev_ctx.eigen_device();
       result.device(place) = EigenVector<T>::Flatten(in_0);
-    } else if (length && in_1.IsInitialized()) {
+    } else if (length_1 && in_1.IsInitialized()) {
       auto result = EigenVector<T>::Flatten(*out);
       auto &place = *dev_ctx.eigen_device();
       result.device(place) = EigenVector<T>::Flatten(in_1);
......
@@ -321,7 +321,7 @@ def _append_grad_suffix_(name):
     return cpt.to_text(name) + core.grad_var_suffix()
 
 
-def _addup_repetitive_outputs_(op_descs):
+def _addup_repetitive_outputs_(op_descs, block_idx):
     """
     In backward part, an variable may be the output of more than one ops.
     And one op may yield its multiple outputs to the same variable.
@@ -358,7 +358,7 @@ def _addup_repetitive_outputs_(op_descs):
                 renamed_var_start_idx[var_name] = idx
             else:
                 if len(renamed_vars[var_name]) == 1:
-                    new_name = var_name + "@RENAME@" + \
+                    new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \
                         str(var_rename_count[var_name])
                     var_rename_count[var_name] += 1
                     # rename original var_name
@@ -384,7 +384,7 @@
                         for x in arg_names[:arg_idx]
                     ] + arg_names[arg_idx:]
 
-                    new_name = var_name + "@RENAME@" + \
+                    new_name = var_name + "@RENAME@block" + str(block_idx) + "@" + \
                         str(var_rename_count[var_name])
                     var_rename_count[var_name] += 1
                     arg_names[arg_idx] = new_name
@@ -733,7 +733,7 @@ def _append_backward_ops_with_checkpoints_(
        grad_to_var.update(op_grad_to_var)
 
     # 3.d. add sum op for repetitive_outputs
-    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)
+    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx)
     # 4) remove no grad branch as it is in _remove_no_grad_branch_
     grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
                                             no_grad_dict[block.idx])
@@ -741,6 +741,35 @@ def _append_backward_ops_with_checkpoints_(
     return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments
 
 
+def _get_sub_block_path(sub_block, sub_block_op_desc, no_grad_set):
+    """
+    Get output vars in subblock which will be assigned to parent block.
+    It is used to find the grad path in subblock
+    """
+    assert sub_block_op_desc.has_attr(
+        "sub_block") and sub_block.idx == sub_block_op_desc._block_attr_id(
+            "sub_block")
+    # TODO(huihuangzheng): add support for recurrent op and while op
+    if sub_block_op_desc.type == "conditional_block":
+        sub_outputs = []
+        sub_assign_to_out_ops = []
+        for var in sub_block_op_desc.output_arg_names:
+            for op_desc in sub_block.ops:
+                if op_desc.type == "assign" and var in op_desc.output_arg_names:
+                    sub_assign_to_out_ops.append(op_desc)
+                    sub_outputs.extend([
+                        sub_block.var(name) for name in op_desc.input_arg_names
+                    ])
+        sub_block_op_path = _find_op_path_(sub_block, sub_outputs, [],
+                                           no_grad_set)
+        # TODO better way than finding in list
+        for op_desc in sub_assign_to_out_ops:
+            if op_desc not in sub_block_op_path:
+                sub_block_op_path.append(op_desc)
+        return sub_block_op_path
+    return sub_block.ops
+
+
 def _append_backward_ops_(block,
                           ops,
                           target_block,
@@ -775,6 +804,8 @@ def _append_backward_ops_(block,
     # grad_op_descs holds created grad_op, and will be appended to target_block
     grad_op_descs = []
     program = block.program
+
+    # add grad_op_desc by reversed ops
     for op in reversed(ops):
         grad_sub_block_list = []
         # If the op has its own sub-block, deal with the sub-block first
@@ -785,7 +816,9 @@ def _append_backward_ops_(block,
                 # see follwing comments for why set None here.
                 pre_input_grad_names_set = copy.copy(input_grad_names_set)
                 input_grad_names_set = None
-                _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
+                sub_block_path = _get_sub_block_path(sub_block, op,
+                                                     no_grad_dict[sub_block.idx])
+                _append_backward_ops_(sub_block, sub_block_path, grad_sub_block,
                                       no_grad_dict, grad_to_var, callbacks,
                                       input_grad_names_set)
                 input_grad_names_set = pre_input_grad_names_set
@@ -825,10 +858,8 @@ def _append_backward_ops_(block,
             grad_op_descs.extend(grad_op_desc)
             grad_to_var.update(op_grad_to_var)
 
-    # add grad_op_desc by reversed ops
-
     # sum parameter's gradients' var given multiple var gradient
-    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)
+    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx)
 
     # if all outputs of the grad op are in no_grad_set, then just remove and fill zero
     # if all inputs of the grad op are in no_grad_set, just remove this op
@@ -841,6 +872,7 @@ def _append_backward_ops_(block,
     grad_op_descs = [
         op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
     ]
 
+    # append op_desc in grad_op_descs to target_block
     op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
     backward = core.op_proto_and_checker_maker.OpRole.Backward
......
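
To make the backward-renaming change above concrete, here is a small standalone illustration (the variable name and block index are hypothetical; only the string concatenation mirrors the patch). When a gradient variable is written by several grad ops, `_addup_repetitive_outputs_` gives each writer a renamed slot and then sums the slots; embedding the block index in the slot name keeps renamed slots created while differentiating different (possibly nested) blocks from colliding:

var_name = "x@GRAD"            # a gradient variable written by several grad ops
block_idx = 1                  # index of the block being differentiated
var_rename_count = {var_name: 0}

# Naming before this patch: the block is not encoded in the slot name.
old_slot = var_name + "@RENAME@" + str(var_rename_count[var_name])
# Naming after this patch: the block index becomes part of the slot name.
new_slot = var_name + "@RENAME@block" + str(block_idx) + "@" + str(
    var_rename_count[var_name])

print(old_slot)  # x@GRAD@RENAME@0
print(new_slot)  # x@GRAD@RENAME@block1@0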
@@ -22,7 +22,7 @@ from ..framework import Program, Variable, Operator
 from ..layer_helper import LayerHelper, unique_name
 from ..initializer import force_init_on_cpu
 from .nn import logical_and, logical_not, logical_or
-from .utils import assert_same_structure, flatten, map_structure
+from .utils import assert_same_structure, map_structure
 import numpy
 import warnings
 import six
@@ -1710,7 +1710,6 @@ class ConditionalBlock(object):
         param_list = [
             parent_block._var_recursive(each_name) for each_name in params
-            if each_name not in input_set
         ]
 
         out_list = []
@@ -1755,7 +1754,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
     helper = LayerHelper('cond', **locals())
     true_output = None
     false_output = None
-    copy_to_global_func = lambda var: copy_var_to_parent_block(var, helper)
+    copy_to_parent_func = lambda var: copy_var_to_parent_block(var, helper)
     if true_fn is not None:
         if not callable(true_fn):
             raise TypeError("The true_fn in cond must be callable")
@@ -1763,7 +1762,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
         with true_cond_block.block():
             origin_true_output = true_fn()
             if origin_true_output is not None:
-                true_output = map_structure(copy_to_global_func,
+                true_output = map_structure(copy_to_parent_func,
                                             origin_true_output)
     if false_fn is not None:
         if not callable(false_fn):
@@ -1773,7 +1772,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
         with false_cond_block.block():
             origin_false_output = false_fn()
             if origin_false_output is not None:
-                false_output = map_structure(copy_to_global_func,
+                false_output = map_structure(copy_to_parent_func,
                                              origin_false_output)
     if true_output is None and false_output is None:
......
@@ -21,11 +21,12 @@ import paddle.fluid as fluid
 import paddle.fluid.core as core
 import paddle.fluid.layers as layers
 import paddle.fluid.framework as framework
+from paddle.fluid.backward import append_backward
 from paddle.fluid.executor import Executor
 from paddle.fluid.framework import Program, program_guard
 
 
-class TestCond(unittest.TestCase):
+class TestCondInputOutput(unittest.TestCase):
     def test_return_single_var(self):
         """
         pseudocode:
@@ -220,5 +221,59 @@ class TestCond(unittest.TestCase):
                 str(e.exception))
 
 
+class TestCondNestedControlFlow(unittest.TestCase):
+    def test_cond_inside_cond(self):
+        """
+        pseudocode:
+        for i in range(1, 10):
+            a = 2 * i
+            if i < 5:
+                if i >= 3:
+                    return a + a
+                else:
+                    return a - a
+            else:
+                if i < 8:
+                    return a * a
+                else:
+                    return a / a
+        """
+
+        def less_than_branch(i, a):
+            return layers.cond(i >= 3.0, lambda: layers.elementwise_add(a, a),
+                               lambda: layers.elementwise_sub(a, a))
+
+        def greater_equal_branch(i, a):
+            return layers.cond(i < 8.0, lambda: layers.elementwise_mul(a, a),
+                               lambda: layers.elementwise_div(a, a))
+
+        main_program = Program()
+        startup_program = Program()
+        with program_guard(main_program, startup_program):
+            i = fluid.data(name="i", shape=[1], dtype='float32')
+            a = 2.0 * i
+            out = layers.cond(i < 5.0, lambda: less_than_branch(i, a),
+                              lambda: greater_equal_branch(i, a))
+            mean = layers.mean(out)
+            append_backward(mean)
+
+        place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        for feed_i in range(0, 10):
+            expected_a = 2.0 * feed_i
+            if feed_i < 5:
+                expected_ret = expected_a + expected_a if feed_i >= 3 else 0.0
+                expected_a_grad = 2.0 if feed_i >= 3 else 0.0
+            else:
+                expected_ret = expected_a * expected_a if feed_i < 8 else 1.0
+                expected_a_grad = 2.0 * expected_a if feed_i < 8 else 0.0
+            ret = exe.run(main_program,
+                          feed={'i': np.full((1), feed_i, np.float32)},
+                          fetch_list=[out.name, a.grad_name])
+            self.assertEqual(ret[0][0], expected_ret)
+            self.assertEqual(ret[1][0], expected_a_grad)
+
+
 if __name__ == '__main__':
     unittest.main()