提交 676dfd18 编写于 作者: C chengduoZH

follow comments

上级 aff8a26d
...@@ -49,7 +49,9 @@ void FetchOpHandle::RunImpl() { ...@@ -49,7 +49,9 @@ void FetchOpHandle::RunImpl() {
platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
for (auto *input : inputs_) { for (auto *input : inputs_) {
auto *var = static_cast<VarHandle *>(input); auto *var = static_cast<VarHandle *>(input);
if (var->generated_op_) var->generated_op_->Wait(cpu_ctx); if (var->generated_op_) {
var->generated_op_->Wait(cpu_ctx);
}
} }
tensors_.resize(inputs_.size()); tensors_.resize(inputs_.size());
auto *var_handle = static_cast<VarHandle *>(inputs_[0]); auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
...@@ -61,14 +63,9 @@ void FetchOpHandle::RunImpl() { ...@@ -61,14 +63,9 @@ void FetchOpHandle::RunImpl() {
auto &scope = scopes[i]; auto &scope = scopes[i];
auto *var = auto *var =
scope->FindVar(kLocalExecScopeName)->Get<Scope *>()->FindVar(var_name); scope->FindVar(kLocalExecScopeName)->Get<Scope *>()->FindVar(var_name);
if (var == nullptr) {
scope->FindVar(var_name);
}
PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope", PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope",
var_name); var_name);
auto &t = var->Get<framework::LoDTensor>(); auto &t = var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(t.place())) { if (platform::is_gpu_place(t.place())) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
TensorCopySync(t, cpu, &tensors_[i]); TensorCopySync(t, cpu, &tensors_[i]);
......
...@@ -36,7 +36,9 @@ void NCCLAllReduceOpHandle::RunImpl() { ...@@ -36,7 +36,9 @@ void NCCLAllReduceOpHandle::RunImpl() {
// Wait input done // Wait input done
for (auto *in : inputs_) { for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_; auto &p = static_cast<VarHandle *>(in)->place_;
if (in->generated_op_) in->generated_op_->Wait(dev_ctxes_[p]); if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[p]);
}
} }
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_; auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
......
...@@ -32,7 +32,9 @@ void SendOpHandle::RunImpl() { ...@@ -32,7 +32,9 @@ void SendOpHandle::RunImpl() {
if (in->DebugString() == "dummy") { // HACK if (in->DebugString() == "dummy") { // HACK
continue; continue;
} }
if (in->generated_op_) in->generated_op_->Wait(dev_ctxes_[p]); if (in->generated_op_) {
in->generated_op_->Wait(dev_ctxes_[p]);
}
} }
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
// FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册