diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index e01e37533f3cf85f61f0138604ed67535659536d..0bfff745493d069e948e6d277ec2bbfb0673a70b 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -625,19 +625,11 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( ir::Graph *result, const std::string &loss_grad_name) const { for (size_t i = 0; i < places_.size(); ++i) { -// Insert ScaleCost OpHandle -#ifdef PADDLE_WITH_CUDA - auto *communication_dev_ctx = - nccl_ctxs_ ? nccl_ctxs_->DevCtx(places_[i]) - : platform::DeviceContextPool::Instance().Get(places_[i]); -#else - auto *communication_dev_ctx = - platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); -#endif + // Insert ScaleCost OpHandle + auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]); auto *op_handle = new ScaleLossGradOpHandle( result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), - local_scopes_.size(), local_scopes_[i], places_[i], - communication_dev_ctx); + local_scopes_.size(), local_scopes_[i], places_[i], dev_ctx); result->Get(kGraphOps).emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale