diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.h b/paddle/fluid/operators/distributed/parameter_prefetch.h index 47d082c4af54a2a6591432dbed617f74ada240b9..2f850a0332256d458e79ed9da361c86eb8a2f780 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.h +++ b/paddle/fluid/operators/distributed/parameter_prefetch.h @@ -39,9 +39,6 @@ void prefetch_with_reconstruct(const std::string& id_name, const framework::ExecutionContext& context, const framework::Scope& scope, framework::LoDTensor* original) { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - auto& actual_ctx = *pool.Get(context.GetPlace()); - prefetch(id_name, out_name, table_names, epmap, height_sections, context, scope); auto& out = scope.FindVar(out_name)->Get(); @@ -54,23 +51,30 @@ void prefetch_with_reconstruct(const std::string& id_name, if (!platform::is_cpu_place(ids.place())) { is_on_cpu_place = false; } - - for (int64_t i = 0; i < ids.numel(); i++) { - const T* out_rows = out_value + original_width * i; - T* original_row = original_value + original_width * ids.data()[i]; - if (is_on_cpu_place) { + if (is_on_cpu_place) { + for (int64_t i = 0; i < ids.numel(); i++) { + const T* out_rows = out_value + original_width * i; + T* original_row = + original_value + original_width * ids.data()[i]; std::memcpy(original_row, out_rows, original_width * sizeof(T)); - } else { + } + } else { #ifndef PADDLE_WITH_CUDA - PADDLE_THROW("paddle is not compiled with CUDA!"); + PADDLE_THROW("paddle is not compiled with CUDA!"); #else + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& actual_ctx = *pool.Get(context.GetPlace()); + for (int64_t i = 0; i < ids.numel(); i++) { + const T* out_rows = out_value + original_width * i; + T* original_row = + original_value + original_width * ids.data()[i]; auto stream = static_cast(&actual_ctx)->stream(); memory::Copy(boost::get(ids.place()), original_row, platform::CPUPlace(), out_rows, original_width * sizeof(T), stream); -#endif } +#endif } }