From 384d6ee8ac3e0ca9372ef90a1626f7129c9e7f37 Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Fri, 13 Apr 2018 17:44:58 +0800
Subject: [PATCH] follow comments

---
 .../framework/details/broadcast_op_handle.cc | 37 +++----------
 .../framework/details/gather_op_handle.cc    |  9 ++---
 paddle/fluid/framework/tensor_util.cc        |  6 +--
 3 files changed, 12 insertions(+), 40 deletions(-)

diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc
index 24115cae819..7d29012380e 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle.cc
+++ b/paddle/fluid/framework/details/broadcast_op_handle.cc
@@ -61,33 +61,24 @@ void BroadcastOpHandle::RunImpl() {
   // Wait input done, this Wait is asynchronous operation
   auto &in_place = in_var_handle[0]->place_;
   if (in_var_handle[0]->generated_op_) {
-    in_var_handle[0]->generated_op_->Wait(dev_ctxes_[in_place]);
     for (auto *out : out_var_handles) {
       auto &out_p = out->place_;
-      if (platform::is_same_place(in_place, out_p)) continue;
       in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
     }
   }

   //
   auto in_scope_idx = in_var_handle[0]->scope_idx_;
-  PADDLE_ENFORCE_LT(in_scope_idx, local_scopes_.size(),
-                    "The input(%s) is not in the local_scopes.",
-                    in_var_handle[0]->name_);
-  auto in_var = local_scopes_[in_scope_idx]->FindVar(in_var_handle[0]->name_);
+  auto in_var =
+      local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
   Tensor *in_tensor = GetTensorFromVar(in_var);

   for (auto *out : out_var_handles) {
     auto &out_p = out->place_;
+    auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);

-    auto out_scope_idx = out->scope_idx_;
-    PADDLE_ENFORCE_LT(out_scope_idx, local_scopes_.size(),
-                      "%s is not in the local_scopes ", out->name_);
-
-    auto *s = local_scopes_[out_scope_idx];
-    auto out_var = s->FindVar(out->name_);
     PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(),
-                      "The place of input and output should be the same.");
+                      "Places must be all on CPU or all on CUDA.");

     if (in_var->IsType<framework::SelectedRows>()) {
       auto &in_sr = in_var->Get<framework::SelectedRows>();
@@ -109,24 +100,8 @@ void BroadcastOpHandle::RunImpl() {
     }

     Tensor *out_tensor = GetTensorFromVar(out_var);
-    if (platform::is_cpu_place(in_place)) {
-      paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
-                                    out_tensor);
-    } else if (platform::is_gpu_place(in_place)) {
-#ifdef PADDLE_WITH_CUDA
-      auto src_gpu_place = boost::get<platform::CUDAPlace>(in_place);
-      auto dst_gpu_place = boost::get<platform::CUDAPlace>(out_p);
-      void *dst_ptr = out_tensor->mutable_data(out_p);
-      void *src_ptr = in_tensor->data<void>();
-      int64_t size = in_tensor->numel() * SizeOfType(in_tensor->type());
-      memory::Copy(
-          dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
-          reinterpret_cast<platform::CUDADeviceContext *>(dev_ctxes_[out_p])
-              ->stream());
-#else
-      PADDLE_THROW("CUDAPlace is not supported in CPU device.");
-#endif
-    }
+    paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
+                                  out_tensor);
   }
 }

diff --git a/paddle/fluid/framework/details/gather_op_handle.cc b/paddle/fluid/framework/details/gather_op_handle.cc
index 3c3054c03d9..8dd85be567d 100644
--- a/paddle/fluid/framework/details/gather_op_handle.cc
+++ b/paddle/fluid/framework/details/gather_op_handle.cc
@@ -74,13 +74,10 @@ void GatherOpHandle::RunImpl() {
     auto in_handle = static_cast<VarHandle *>(in);
     auto in_p = in_handle->place_;
     in_places.push_back(in_p);
-    PADDLE_ENFORCE_LT(in_handle->scope_idx_, local_scopes_.size(),
-                      "%s is not the the local_scopes ", in_handle->name_);
     PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
-                      "The place of input should be the same.");
-    auto *s = local_scopes_[in_handle->scope_idx_];
-    auto in_var = s->FindVar(in_handle->name_);
-
+                      "Places must be all on CPU or all on CUDA.");
+    auto in_var =
+        local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
     auto &in_sr = in_var->Get<framework::SelectedRows>();

     PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(),
diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc
index 1d864af011b..d1b01ae05b8 100644
--- a/paddle/fluid/framework/tensor_util.cc
+++ b/paddle/fluid/framework/tensor_util.cc
@@ -11,8 +11,10 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-
 #include "paddle/fluid/framework/tensor_util.h"
+#include <algorithm>
+#include <limits>
+#include <vector>

 namespace paddle {
 namespace framework {
@@ -65,8 +67,6 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
     auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
     auto ctx_place = ctx.GetPlace();
     PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
-    auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
-    PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
     memory::Copy(
         dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
         reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
-- 
GitLab