提交 692a0f74 编写于 作者: Y Yu Yang

Better name

上级 baef1124
...@@ -27,7 +27,8 @@ struct VarHandle { ...@@ -27,7 +27,8 @@ struct VarHandle {
platform::Place place_; platform::Place place_;
OpHandle *generated_op_; OpHandle *generated_op_;
std::vector<OpHandle *> deps_ops_;
std::vector<OpHandle *> pending_ops_;
}; };
struct OpHandle { struct OpHandle {
...@@ -141,7 +142,7 @@ void ParallelExecutor::ConstructDependencyGraph( ...@@ -141,7 +142,7 @@ void ParallelExecutor::ConstructDependencyGraph(
auto &place = pair.first; auto &place = pair.first;
VarHandle *var = GetVarHandle(each_var_name, place); VarHandle *var = GetVarHandle(each_var_name, place);
op_handle->inputs_.emplace_back(var); op_handle->inputs_.emplace_back(var);
var->deps_ops_.emplace_back(op_handle); var->pending_ops_.emplace_back(op_handle);
} }
var_names = op->OutputArgumentNames(); var_names = op->OutputArgumentNames();
...@@ -158,7 +159,7 @@ void ParallelExecutor::ConstructDependencyGraph( ...@@ -158,7 +159,7 @@ void ParallelExecutor::ConstructDependencyGraph(
op_handle = member_->ops_.back().get(); op_handle = member_->ops_.back().get();
auto &place = pair.first; auto &place = pair.first;
VarHandle *loss = GetVarHandle(loss_var_name, place); VarHandle *loss = GetVarHandle(loss_var_name, place);
loss->deps_ops_.emplace_back(op_handle); loss->pending_ops_.emplace_back(op_handle);
op_handle->inputs_.emplace_back(loss); op_handle->inputs_.emplace_back(loss);
GenerateVar(op_handle, loss_var_name + "@GRAD", place); GenerateVar(op_handle, loss_var_name + "@GRAD", place);
change_forward = true; change_forward = true;
...@@ -188,7 +189,7 @@ void ParallelExecutor::ConstructDependencyGraph( ...@@ -188,7 +189,7 @@ void ParallelExecutor::ConstructDependencyGraph(
} }
auto *prev_grad = &vars[vars.size() - 1]; auto *prev_grad = &vars[vars.size() - 1];
op_handle->inputs_.emplace_back(prev_grad); op_handle->inputs_.emplace_back(prev_grad);
prev_grad->deps_ops_.emplace_back(op_handle); prev_grad->pending_ops_.emplace_back(op_handle);
auto &var = vars[vars.size()]; auto &var = vars[vars.size()];
var.place_ = place; var.place_ = place;
var.generated_op_ = op_handle; var.generated_op_ = op_handle;
...@@ -317,7 +318,7 @@ std::vector<LoDTensor> ParallelExecutor::Run( ...@@ -317,7 +318,7 @@ std::vector<LoDTensor> ParallelExecutor::Run(
std::vector<OpHandle *> to_run; std::vector<OpHandle *> to_run;
for (auto *var : to_remove) { for (auto *var : to_remove) {
for (auto *op : var->deps_ops_) { for (auto *op : var->pending_ops_) {
if (var->name_ == "mean_0.tmp_0@GRAD") { if (var->name_ == "mean_0.tmp_0@GRAD") {
LOG(INFO) << op->DebugString(); LOG(INFO) << op->DebugString();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册