From 1bec52f581adec2ddb8038ca1bef78f9e2fc763f Mon Sep 17 00:00:00 2001 From: JiabinYang Date: Thu, 20 Dec 2018 05:50:12 +0000 Subject: [PATCH] test=develop, fix cpu running error --- .../distributed/parameter_prefetch.h | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.h b/paddle/fluid/operators/distributed/parameter_prefetch.h index 47d082c4af..2f850a0332 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.h +++ b/paddle/fluid/operators/distributed/parameter_prefetch.h @@ -39,9 +39,6 @@ void prefetch_with_reconstruct(const std::string& id_name, const framework::ExecutionContext& context, const framework::Scope& scope, framework::LoDTensor* original) { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& actual_ctx = *pool.Get(context.GetPlace()); - prefetch(id_name, out_name, table_names, epmap, height_sections, context, scope); auto& out = scope.FindVar(out_name)->Get(); @@ -54,23 +51,30 @@ void prefetch_with_reconstruct(const std::string& id_name, if (!platform::is_cpu_place(ids.place())) { is_on_cpu_place = false; } - - for (int64_t i = 0; i < ids.numel(); i++) { - const T* out_rows = out_value + original_width * i; - T* original_row = original_value + original_width * ids.data()[i]; - if (is_on_cpu_place) { + if (is_on_cpu_place) { + for (int64_t i = 0; i < ids.numel(); i++) { + const T* out_rows = out_value + original_width * i; + T* original_row = + original_value + original_width * ids.data()[i]; std::memcpy(original_row, out_rows, original_width * sizeof(T)); - } else { + } + } else { #ifndef PADDLE_WITH_CUDA - PADDLE_THROW("paddle is not compiled with CUDA!"); + PADDLE_THROW("paddle is not compiled with CUDA!"); #else + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& actual_ctx = *pool.Get(context.GetPlace()); + for (int64_t i = 0; i < ids.numel(); i++) { + const T* out_rows = out_value + original_width * i; + T* original_row = + original_value + original_width * ids.data()[i]; auto stream = static_cast(&actual_ctx)->stream(); memory::Copy(boost::get(ids.place()), original_row, platform::CPUPlace(), out_rows, original_width * sizeof(T), stream); -#endif } +#endif } } -- GitLab