Commit e335f018 authored by Yu Yang

Add more logs

Parent 82693e72
@@ -125,30 +125,6 @@ struct OpHandle {
  virtual void RunImpl() = 0;
};

struct ComputationOpHandle : public OpHandle {
  std::unique_ptr<OperatorBase> op_;
  Scope *scope_;
  platform::Place place_;

  explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
                               platform::Place place)
      : op_(framework::OpRegistry::CreateOp(op_desc)),
        scope_(scope),
        place_(place) {}

 protected:
  void RunImpl() override {
    auto *cur_ctx = dev_ctx_[place_];
    for (auto *in : inputs_) {
      if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
        in->generated_op_->Wait(cur_ctx);
      }
    }

    op_->Run(*scope_, place_);
  }
};
struct ScaleLossGradOpHandle : public OpHandle {
  float coeff_;
  Scope *scope_;
@@ -396,6 +372,36 @@ struct NCCLAllReduceOpHandle : public OpHandle {
  }
};
struct ComputationOpHandle : public OpHandle {
  std::unique_ptr<OperatorBase> op_;
  Scope *scope_;
  platform::Place place_;

  explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
                               platform::Place place)
      : op_(framework::OpRegistry::CreateOp(op_desc)),
        scope_(scope),
        place_(place) {}

 protected:
  void RunImpl() override {
    auto *cur_ctx = dev_ctx_[place_];
    for (auto *in : inputs_) {
      bool need_wait =
          in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx;
      if (dynamic_cast<NCCLAllReduceOpHandle *>(in->generated_op_)) {
        VLOG(3) << "Input is nccl all reduce, need to wait: " << need_wait;
      }
      if (need_wait) {
        in->generated_op_->Wait(cur_ctx);
      }
    }

    op_->Run(*scope_, place_);
  }
};
ParallelExecutor::ParallelExecutor(
    size_t num_threads, const std::vector<platform::Place> &places,
    const std::unordered_set<std::string> &params,
......
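For context, here is a minimal standalone sketch of the cross-stream synchronization pattern that RunImpl implements above: an op waits on an input only when that input was produced under a different device context (for example, the NCCL all-reduce stream). The MiniDeviceContext / MiniVarHandle / MiniOpHandle names and the string-keyed place map are simplified stand-ins for illustration, not the actual Paddle types.

// Hypothetical, simplified types -- not the real platform::DeviceContext,
// VarHandle, or OpHandle from Paddle.
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct MiniDeviceContext {
  std::string name;  // e.g. "gpu0/compute" or "gpu0/nccl"
};

struct MiniOpHandle;

struct MiniVarHandle {
  MiniOpHandle *generated_op_ = nullptr;  // which op produced this variable
};

struct MiniOpHandle {
  std::string name_;
  std::map<std::string, MiniDeviceContext *> dev_ctx_;  // place -> context
  std::vector<MiniVarHandle *> inputs_;

  // Make `waiter` block until this op's work on its own stream is done.
  // (Stand-in for real cross-stream synchronization, e.g. via CUDA events.)
  void Wait(MiniDeviceContext *waiter) {
    std::cout << waiter->name << " waits for " << name_ << "\n";
  }

  void Run(const std::string &place) {
    MiniDeviceContext *cur_ctx = dev_ctx_[place];
    for (MiniVarHandle *in : inputs_) {
      // Same rule as RunImpl above: only synchronize when the producer ran
      // under a different device context than the current op.
      bool need_wait =
          in->generated_op_ && in->generated_op_->dev_ctx_[place] != cur_ctx;
      if (need_wait) {
        in->generated_op_->Wait(cur_ctx);
      }
    }
    std::cout << "run " << name_ << " on " << cur_ctx->name << "\n";
  }
};

int main() {
  MiniDeviceContext compute{"gpu0/compute"};
  MiniDeviceContext nccl{"gpu0/nccl"};

  MiniOpHandle all_reduce;  // produces a gradient on the NCCL stream
  all_reduce.name_ = "all_reduce";
  all_reduce.dev_ctx_["gpu0"] = &nccl;

  MiniVarHandle grad;
  grad.generated_op_ = &all_reduce;

  MiniOpHandle sgd;  // consumes the gradient on the compute stream
  sgd.name_ = "sgd";
  sgd.dev_ctx_["gpu0"] = &compute;
  sgd.inputs_.push_back(&grad);

  sgd.Run("gpu0");  // prints the cross-stream wait, then the run
  return 0;
}

The VLOG(3) line added in this commit reports the same need_wait decision whenever the producing op is an NCCLAllReduceOpHandle; it is only emitted when glog verbosity is set to 3 or higher (e.g. GLOG_v=3).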