未验证 提交 ed087f82 编写于 作者: C chengduo 提交者: GitHub

refine op_handle (#14178)

test=develop
上级 da73bc39
......@@ -34,7 +34,7 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
nccl_ctxs_(ctxs) {
if (nccl_ctxs_) {
for (auto &p : places_) {
this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
this->SetDeviceContext(p, nccl_ctxs_->DevCtx(p));
}
}
}
......@@ -46,7 +46,7 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
#endif
void AllReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
if (NoDummyInputSize() == 1) {
return; // No need to all reduce when GPU count = 1;
......@@ -127,7 +127,7 @@ void AllReduceOpHandle::RunImpl() {
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &p = places_[i];
auto *var = scope.FindVar(out_var_handles[i]->name_);
auto *dev_ctx = dev_ctxes_[p];
auto *dev_ctx = dev_ctxes_.at(p);
RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
auto &tensor_gpu = *var->GetMutable<framework::LoDTensor>();
......
......@@ -44,7 +44,8 @@ struct BroadcastOpHandle : public OpHandleBase {
nccl_ctxs_(nccl_ctxs) {
if (nccl_ctxs_) {
for (auto &p_ctx : nccl_ctxs_->contexts_) {
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
this->SetDeviceContext(platform::CUDAPlace(p_ctx.first),
p_ctx.second.ctx_.get());
}
}
}
......
......@@ -37,7 +37,7 @@ void ComputationOpHandle::RunImpl() {
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
bool need_wait =
in_var && in_var->GeneratedOp() &&
in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_[place_];
in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_.at(place_);
return need_wait;
}
......
......@@ -28,7 +28,7 @@ DataBalanceOpHandle::DataBalanceOpHandle(
: OpHandleBase(node), local_scopes_(local_scopes), places_(places) {
if (ctxs) {
for (auto &p : places_) {
this->dev_ctxes_[p] = ctxs->DevCtx(p);
this->SetDeviceContext(p, ctxs->DevCtx(p));
}
}
}
......@@ -89,8 +89,8 @@ void DataBalanceOpHandle::RunImpl() {
PADDLE_ENFORCE_GT(places_.size(), 1,
"Data balance can only be enabled when the number of "
"places to run larger than 1.");
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
auto out_var_handles = DynamicCast<VarHandle>(outputs_);
auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0);
PADDLE_ENFORCE_EQ(
in_var_handles.size(), out_var_handles.size(),
......
......@@ -36,7 +36,7 @@ void GatherOpHandle::RunImpl() {
VarHandle *out_var_handle;
{
auto out_var_handles = DynamicCast<VarHandle>(outputs_);
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
"The number of output should be one.");
out_var_handle = out_var_handles.front();
......@@ -99,7 +99,7 @@ void GatherOpHandle::RunImpl() {
Tensor *out_tensor = out_value->mutable_value();
// copy
auto dev_ctx = dev_ctxes_[out_var_handle->place_];
auto dev_ctx = dev_ctxes_.at(out_var_handle->place_);
RunAndRecordEvent(out_var_handle->place_, [in_tensors, out_tensor, &dev_ctx,
t_out_p] {
int s = 0, e = 0;
......
......@@ -103,7 +103,7 @@ void OpHandleBase::WaitInputVarGenerated() {
void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
for (auto *in : inputs_) {
if (NeedWait(in)) {
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[place]);
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_.at(place));
}
}
}
......
......@@ -27,7 +27,7 @@ namespace framework {
namespace details {
void ReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
if (places_.size() == 1) return;
// the input and output may have dummy var.
......
......@@ -46,7 +46,8 @@ struct ReduceOpHandle : public OpHandleBase {
nccl_ctxs_(nccl_ctxs) {
if (nccl_ctxs_) {
for (auto &p_ctx : nccl_ctxs_->contexts_) {
dev_ctxes_[platform::CUDAPlace(p_ctx.first)] = p_ctx.second.ctx_.get();
this->SetDeviceContext(platform::CUDAPlace(p_ctx.first),
p_ctx.second.ctx_.get());
}
}
}
......
......@@ -38,7 +38,7 @@ void RPCOpHandle::RunImpl() {
continue;
}
if (in->GeneratedOp()) {
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_[p]);
in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_.at(p));
}
}
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
......
......@@ -27,7 +27,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(ir::Node *node, size_t num_dev,
coeff_(static_cast<float>(1.0 / num_dev)),
scope_(scope),
place_(place) {
dev_ctxes_[place_] = dev_ctx;
this->SetDeviceContext(place_, dev_ctx);
}
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
......@@ -46,8 +46,8 @@ void ScaleLossGradOpHandle::RunImpl() {
} else {
#ifdef PADDLE_WITH_CUDA
this->RunAndRecordEvent([&] {
auto stream =
static_cast<platform::CUDADeviceContext *>(this->dev_ctxes_[place_])
auto stream = static_cast<platform::CUDADeviceContext *>(
this->dev_ctxes_.at(place_))
->stream();
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
platform::CPUPlace(), &coeff_, sizeof(float), stream);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册