Commit 084cdd1f authored by Yu Yang

Rename code

Parent 9f4a98f3
@@ -24,10 +24,10 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
       place_(place) {}

 void ComputationOpHandle::RunImpl() {
-  auto *cur_ctx = dev_ctx_[place_];
+  auto *cur_ctx = dev_ctxes_[place_];
   for (auto *in : inputs_) {
     bool need_wait =
-        in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx;
+        in->generated_op_ && in->generated_op_->dev_ctxes_[place_] != cur_ctx;
     if (need_wait) {
       in->generated_op_->Wait(cur_ctx);
     }
...
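The `need_wait` check in the hunk above is the interesting part of this file: an op synchronizes only when an input was produced on a *different* device context, since work queued on the same context is already ordered. Below is a minimal standalone sketch of that pattern; `DeviceContext`, `VarHandle`, and `OpHandle` here are simplified stand-ins, not the Paddle classes.

```cpp
#include <vector>

// Simplified stand-ins for the Paddle types in this diff.
struct DeviceContext {
  void Wait() {}  // block until all work queued on this context finishes
};

struct OpHandle;

struct VarHandle {
  OpHandle *generated_op_ = nullptr;  // op that produced this variable
};

struct OpHandle {
  DeviceContext *ctx_ = nullptr;  // context this op runs on
  std::vector<VarHandle *> inputs_;

  void Run() {
    for (auto *in : inputs_) {
      // Wait only when the producer ran on a different context; same-context
      // work is already serialized, so no synchronization is needed.
      bool need_wait = in->generated_op_ && in->generated_op_->ctx_ != ctx_;
      if (need_wait) in->generated_op_->ctx_->Wait();
    }
    // ... launch this op's kernel on ctx_ ...
  }
};
```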
@@ -60,8 +60,8 @@ void FetchOpHandle::RunImpl() {
     auto &t = scope->FindVar(var_name)->Get<framework::LoDTensor>();
     if (platform::is_gpu_place(var->place_)) {
 #ifdef PADDLE_WITH_CUDA
-      TensorCopy(t, cpu, *dev_ctx_[t.place()], &tensors_[i]);
-      dev_ctx_[t.place()]->Wait();
+      TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i]);
+      dev_ctxes_[t.place()]->Wait();
 #endif
     } else {
       tensors_[i].ShareDataWith(t);
...
@@ -74,7 +74,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
     auto *op_handle = result.ops_.back().get();

-    op_handle->dev_ctx_[p] = const_cast<platform::DeviceContext *>(
+    op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
         platform::DeviceContextPool::Instance().Get(p));

     auto var_names = op->InputArgumentNames();
...
@@ -23,7 +23,7 @@ NCCLAllReduceOpHandle::NCCLAllReduceOpHandle(
     const platform::NCCLContextMap &ctxs)
     : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
   for (auto &p : places_) {
-    this->dev_ctx_[p] = nccl_ctxs_.DevCtx(p);
+    this->dev_ctxes_[p] = nccl_ctxs_.DevCtx(p);
   }
 }
@@ -34,7 +34,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
   // Wait input done
   for (auto *in : inputs_) {
     auto &p = static_cast<VarHandle *>(in)->place_;
-    in->generated_op_->Wait(dev_ctx_[p]);
+    in->generated_op_->Wait(dev_ctxes_[p]);
   }

   auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
...
@@ -42,7 +42,7 @@ OpHandleBase::~OpHandleBase() {
 void OpHandleBase::Run(bool use_event) {
 #ifdef PADDLE_WITH_CUDA
   if (events_.empty() && use_event) {
-    for (auto &p : dev_ctx_) {
+    for (auto &p : dev_ctxes_) {
       int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
       PADDLE_ENFORCE(cudaSetDevice(dev_id));
       PADDLE_ENFORCE(
@@ -57,7 +57,7 @@ void OpHandleBase::Run(bool use_event) {
 #ifdef PADDLE_WITH_CUDA
   if (use_event) {
-    for (auto &p : dev_ctx_) {
+    for (auto &p : dev_ctxes_) {
       int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
       auto stream =
           static_cast<platform::CUDADeviceContext *>(p.second)->stream();
@@ -70,7 +70,7 @@ void OpHandleBase::Run(bool use_event) {
 void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
 #ifdef PADDLE_WITH_CUDA
   if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
-    for (auto &dev_ctx : dev_ctx_) {
+    for (auto &dev_ctx : dev_ctxes_) {
       dev_ctx.second->Wait();
     }
   } else {
@@ -81,7 +81,7 @@ void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
     }
   }
 #else
-  for (auto &dev_ctx : dev_ctx_) {
+  for (auto &dev_ctx : dev_ctxes_) {
     dev_ctx.second->Wait();
   }
 #endif
...
@@ -31,7 +31,7 @@ class OpHandleBase {
   std::vector<VarHandleBase *> outputs_;
   std::unordered_map<platform::Place, platform::DeviceContext *,
                      platform::PlaceHash>
-      dev_ctx_;
+      dev_ctxes_;

 #ifdef PADDLE_WITH_CUDA
   std::unordered_map<int, cudaEvent_t> events_;
...
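The member being renamed above is the map at the center of this commit: one `DeviceContext *` per `platform::Place`, keyed with a custom hasher so a place can serve as an `unordered_map` key. A rough, self-contained illustration of that declaration pattern follows; the `Place` and `PlaceHash` below are hypothetical stand-ins, not Paddle's real types.

```cpp
#include <cstddef>
#include <functional>
#include <unordered_map>

// Hypothetical stand-ins: a Place identifies a device, and PlaceHash lets it
// serve as an unordered_map key (Paddle's real types live under platform/).
struct Place {
  int device_id;  // e.g. GPU ordinal, or -1 for CPU
  bool operator==(const Place &o) const { return device_id == o.device_id; }
};

struct PlaceHash {
  std::size_t operator()(const Place &p) const {
    return std::hash<int>()(p.device_id);
  }
};

struct DeviceContext;  // opaque here

// One context per place, as in OpHandleBase::dev_ctxes_.
using DeviceContextMap = std::unordered_map<Place, DeviceContext *, PlaceHash>;
```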
@@ -21,7 +21,7 @@ ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
                                              platform::Place place,
                                              platform::DeviceContext *dev_ctx)
     : coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) {
-  dev_ctx_[place_] = dev_ctx;
+  dev_ctxes_[place_] = dev_ctx;
 }

 ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
@@ -38,7 +38,7 @@ void ScaleLossGradOpHandle::RunImpl() {
   } else {
 #ifdef PADDLE_WITH_CUDA
     auto stream =
-        static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
+        static_cast<platform::CUDADeviceContext *>(this->dev_ctxes_[place_])
             ->stream();
     memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
                  platform::CPUPlace(), &coeff_, sizeof(float), stream);
...
@@ -96,7 +96,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
       // FIXME: Use new device context
       for (auto &p : places_) {
-        op->dev_ctx_[p] = fetch_ctxs_.Get(p);
+        op->dev_ctxes_[p] = fetch_ctxs_.Get(p);
       }

       for (auto *var : vars) {
...
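Both the builder hunk (`DeviceContextPool::Instance().Get(p)`) and the executor hunk above (`fetch_ctxs_.Get(p)`) fill the renamed map the same way: one context per place an op may touch, assigned before the op runs. A hedged sketch of that populate step, using toy integer places and a hypothetical `ContextPool` instead of Paddle's pools:

```cpp
#include <unordered_map>
#include <vector>

// Illustrative only: toy stand-ins for the Paddle types in this diff.
struct DeviceContext {};

struct ContextPool {
  // Hypothetical lookup; Paddle's pools expose a Get(place)-style accessor.
  DeviceContext *Get(int place) { return &ctxs_[place]; }
  std::unordered_map<int, DeviceContext> ctxs_;
};

struct OpHandle {
  std::unordered_map<int, DeviceContext *> dev_ctxes_;  // place -> context
};

int main() {
  std::vector<int> places = {0, 1};  // e.g. two GPU ordinals
  ContextPool pool;
  OpHandle op;
  // Mirrors the pattern in the builder/executor hunks above: every place
  // an op may touch gets exactly one context entry before the op runs.
  for (int p : places) op.dev_ctxes_[p] = pool.Get(p);
  return 0;
}
```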