From b0267ac93a84cdb3be3099b869c1c334b7e26096 Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Fri, 13 Apr 2018 11:31:59 +0800
Subject: [PATCH] refine broadcast op

---
 .../framework/details/broadcast_op_handle.cc  | 29 ++++++++++++++++---
 1 file changed, 25 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc
index cd9bff52d9..53e8f9f366 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle.cc
+++ b/paddle/fluid/framework/details/broadcast_op_handle.cc
@@ -32,8 +32,14 @@ void BroadcastOpHandle::RunImpl() {
   // Wait input done, this Wait is asynchronous operation
   auto in_var_handle = static_cast<VarHandle *>(this->inputs_[0]);
   auto &in_place = in_var_handle->place_;
-  if (inputs_[0]->generated_op_)
+  if (inputs_[0]->generated_op_) {
     inputs_[0]->generated_op_->Wait(dev_ctxes_[in_place]);
+    for (auto *out : outputs_) {
+      auto out_handle = static_cast<VarHandle *>(out);
+      auto &out_p = out_handle->place_;
+      inputs_[0]->generated_op_->Wait(dev_ctxes_[out_p]);
+    }
+  }
 
   auto in_scope_idx = in_var_handle->scope_idx_;
   PADDLE_ENFORCE_LT(in_scope_idx, local_scopes_.size(),
@@ -74,9 +80,24 @@ void BroadcastOpHandle::RunImpl() {
     }
 
     Tensor *out_tensor = GetTensorFromVar(out_var);
-
-    paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
-                                  out_tensor);
+    if (platform::is_cpu_place(in_place)) {
+      paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
+                                    out_tensor);
+    } else if (platform::is_gpu_place(in_place)) {
+#ifdef PADDLE_WITH_CUDA
+      auto src_gpu_place = boost::get<platform::CUDAPlace>(in_place);
+      auto dst_gpu_place = boost::get<platform::CUDAPlace>(out_p);
+      void *dst_ptr = out_tensor->mutable_data(out_p);
+      void *src_ptr = in_tensor->data();
+      int64_t size = in_tensor->numel();
+      memory::Copy(
+          dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
+          reinterpret_cast<platform::CUDADeviceContext *>(dev_ctxes_[out_p])
+              ->stream());
+#else
+      PADDLE_THROW("CUDAPlace is not supported in CPU device.");
+#endif
+    }
   }
 }
-- 
GitLab
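
The GPU branch of this patch hands memory::Copy the stream of the destination place's device context, after first waiting on every output place's context, so the copy is ordered correctly with respect to later work queued on that device. The standalone CUDA sketch below illustrates the same ordering idea with plain runtime calls. It is only an illustration under assumptions (at least two visible GPUs, an arbitrary 1M-element buffer, a hypothetical CUDA_CHECK macro); it is not Paddle code and not the patch's actual implementation.

// broadcast_sketch.cu -- illustrative only; mirrors the idea of issuing each
// device-to-device copy on the destination device's stream.
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <vector>

// Hypothetical error-checking macro for the example.
#define CUDA_CHECK(expr)                                                   \
  do {                                                                     \
    cudaError_t err__ = (expr);                                            \
    if (err__ != cudaSuccess) {                                            \
      std::fprintf(stderr, "CUDA error %s at %s:%d\n",                     \
                   cudaGetErrorString(err__), __FILE__, __LINE__);         \
      std::exit(1);                                                        \
    }                                                                      \
  } while (0)

int main() {
  int ngpus = 0;
  CUDA_CHECK(cudaGetDeviceCount(&ngpus));
  if (ngpus < 2) {
    std::printf("need at least 2 GPUs for this sketch\n");
    return 0;
  }

  const size_t bytes = (1 << 20) * sizeof(float);  // arbitrary example size

  // Source buffer lives on GPU 0.
  CUDA_CHECK(cudaSetDevice(0));
  void *src = nullptr;
  CUDA_CHECK(cudaMalloc(&src, bytes));
  CUDA_CHECK(cudaMemset(src, 0, bytes));

  // One destination buffer and one stream per remaining GPU.
  std::vector<void *> dst(ngpus, nullptr);
  std::vector<cudaStream_t> streams(ngpus);
  for (int d = 1; d < ngpus; ++d) {
    CUDA_CHECK(cudaSetDevice(d));
    CUDA_CHECK(cudaMalloc(&dst[d], bytes));
    CUDA_CHECK(cudaStreamCreate(&streams[d]));
    // Enqueue the cross-device copy on the destination's stream, so any
    // kernel later queued on that stream observes the broadcast data.
    CUDA_CHECK(cudaMemcpyPeerAsync(dst[d], d, src, 0, bytes, streams[d]));
  }

  // Drain the destination streams and release resources.
  for (int d = 1; d < ngpus; ++d) {
    CUDA_CHECK(cudaSetDevice(d));
    CUDA_CHECK(cudaStreamSynchronize(streams[d]));
    CUDA_CHECK(cudaStreamDestroy(streams[d]));
    CUDA_CHECK(cudaFree(dst[d]));
  }
  CUDA_CHECK(cudaSetDevice(0));
  CUDA_CHECK(cudaFree(src));
  std::printf("broadcast done\n");
  return 0;
}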