Unverified commit bfef7feb, authored by danleifeng, committed by GitHub

【HETERPS】pipeline adaptive for heterps (#33159)

* pipeline adaptive for heterps;test=develop
* fix finalize hang;test=develop
* add is_compiled_with_heterps for dataset;test=develop
* fix hashtable core when pass ins_num=0;test=develop
Parent 69ffb386
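This change replaces the single-shot BuildGPUPS(table_id, feature_dim) entry point with a channel-driven pipeline: LoadIntoMemory hands a HeterContext to a CPU build thread (BuildTask), which hands it to a GPU build thread (BuildGPUTask), and training is bracketed by BeginPass/EndPass. Below is a minimal sketch of the resulting pass loop as seen from Python, assuming a fleet ps-gpu job is already configured; dataset, main_program and exe are placeholders, not part of this diff:

    import paddle.fluid as fluid

    psgpu = fluid.core.PSGPU()                 # PSGPU handle exposed via BindPSGPUWrapper
    dataset._set_use_ps_gpu(1)                 # route dataset loading through PSGPU
    dataset.load_into_memory(is_shuffle=True)  # enqueues a HeterContext for the build threads
    psgpu.begin_pass()                         # blocks until BuildTask + BuildGPUTask are done
    exe.train_from_dataset(program=main_program, dataset=dataset)
    psgpu.end_pass()                           # HeterPs end_pass; returns the slot to gpu_free_channel_
    psgpu.finalize()                           # closes the channels and joins the build threads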
@@ -40,8 +40,7 @@ namespace framework {
 std::shared_ptr<PSGPUWrapper> PSGPUWrapper::s_instance_ = NULL;
 bool PSGPUWrapper::is_initialized_ = false;
 
-void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
-                             uint64_t table_id, int feature_dim) {
+void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task) {
   VLOG(3) << "PSGPUWrapper::BuildGPUPSTask begin";
   platform::Timer timeline;
   timeline.Start();
@@ -137,17 +136,16 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
     local_ptr[i].resize(local_keys[i].size());
   }
   timeline.Start();
-  auto ptl_func = [this, &local_keys, &local_ptr, &table_id,
-                   &fleet_ptr](int i) {
+  auto ptl_func = [this, &local_keys, &local_ptr, &fleet_ptr](int i) {
     size_t key_size = local_keys[i].size();
 #ifdef PADDLE_WITH_PSLIB
     auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr(
-        reinterpret_cast<char**>(local_ptr[i].data()), table_id,
+        reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
         local_keys[i].data(), key_size);
 #endif
 #ifdef PADDLE_WITH_PSCORE
     auto tt = fleet_ptr->_worker_ptr->pull_sparse_ptr(
-        reinterpret_cast<char**>(local_ptr[i].data()), table_id,
+        reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
         local_keys[i].data(), key_size);
 #endif
     tt.wait();
@@ -270,11 +268,8 @@ void PSGPUWrapper::BuildTask(std::shared_ptr<HeterContext> gpu_task,
           << " seconds.";
 }
 
-void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
+void PSGPUWrapper::BuildGPUTask(std::shared_ptr<HeterContext> gpu_task) {
   int device_num = heter_devices_.size();
-  std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
-  gpu_task->Reset();
-  BuildTask(gpu_task, table_id, feature_dim);
   platform::Timer timeline;
   timeline.Start();
@@ -289,6 +284,10 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
     delete HeterPs_;
     HeterPs_ = nullptr;
   }
+  if (size_max <= 0) {
+    VLOG(1) << "Skip build gpu ps cause feasign nums = " << size_max;
+    return;
+  }
   std::vector<std::thread> threads(device_num);
   HeterPs_ = HeterPsBase::get_instance(size_max, resource_);
   HeterPs_->set_nccl_comm_and_size(inner_comms_, inter_comms_, node_size_);
@@ -297,7 +296,9 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
     this->HeterPs_->build_ps(i, gpu_task->device_keys_[i].data(),
                              gpu_task->device_values_[i].data(),
                              feature_keys_count[i], 500000, 2);
-    HeterPs_->show_one_table(i);
+    if (feature_keys_count[i] > 0) {
+      HeterPs_->show_one_table(i);
+    }
   };
   for (size_t i = 0; i < threads.size(); i++) {
     threads[i] = std::thread(build_func, i);
@@ -308,7 +309,109 @@ void PSGPUWrapper::BuildGPUPS(uint64_t table_id, int feature_dim) {
   timeline.Pause();
   VLOG(1) << "GpuPs build table total costs: " << timeline.ElapsedSec()
           << " s.";
-  gpu_task_pool_.Push(gpu_task);
 }
+
+void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
+  platform::Timer timer;
+  VLOG(3) << "Begin LoadIntoMemory(), dataset[" << dataset_ << "]";
+  timer.Start();
+  dataset_->LoadIntoMemory();
+  timer.Pause();
+  VLOG(0) << "LoadIntoMemory cost: " << timer.ElapsedSec() << "s";
+
+  // local shuffle
+  if (is_shuffle) {
+    dataset_->LocalShuffle();
+  }
+
+  std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
+  gpu_task->Reset();
+  data_ready_channel_->Put(gpu_task);
+  VLOG(3) << "End LoadIntoMemory(), dataset[" << dataset_ << "]";
+}
+
+void PSGPUWrapper::start_build_thread() {
+  running_ = true;
+  VLOG(3) << "start build CPU&GPU ps thread.";
+  build_cpu_threads_ = std::thread([this] { build_cpu_thread(); });
+  build_gpu_threads_ = std::thread([this] { build_gpu_thread(); });
+}
+
+void PSGPUWrapper::build_cpu_thread() {
+  while (running_) {
+    std::shared_ptr<HeterContext> gpu_task = nullptr;
+    if (!data_ready_channel_->Get(gpu_task)) {
+      continue;
+    }
+    VLOG(3) << "thread BuildTask start.";
+    platform::Timer timer;
+    timer.Start();
+    // build cpu ps data process
+    BuildTask(gpu_task);
+    timer.Pause();
+    VLOG(1) << "thread BuildTask end, cost time: " << timer.ElapsedSec() << "s";
+    buildcpu_ready_channel_->Put(gpu_task);
+  }
+  VLOG(3) << "build cpu thread end";
+}
+
+void PSGPUWrapper::build_gpu_thread() {
+  while (running_) {
+    std::shared_ptr<HeterContext> gpu_task = nullptr;
+    if (!gpu_free_channel_->Get(gpu_task)) {
+      continue;
+    }
+    if (!buildcpu_ready_channel_->Get(gpu_task)) {
+      continue;
+    }
+    VLOG(3) << "thread BuildGPUTask start.";
+    platform::Timer timer;
+    timer.Start();
+    BuildGPUTask(gpu_task);
+    timer.Pause();
+    VLOG(1) << "thread BuildGPUTask end, cost time: " << timer.ElapsedSec()
+            << "s";
+    gpu_task_pool_.Push(gpu_task);
+    train_ready_channel_->Put(gpu_task);
+  }
+  VLOG(3) << "build gpu thread end";
+}
+
+void PSGPUWrapper::BeginPass() {
+  platform::Timer timer;
+  timer.Start();
+  if (current_task_) {
+    PADDLE_THROW(
+        platform::errors::Fatal("[BeginPass] current task is not ended."));
+  }
+  // load+build done
+  if (!train_ready_channel_->Get(current_task_)) {
+    PADDLE_THROW(platform::errors::Fatal("train_ready_channel_ failed."));
+  }
+  timer.Pause();
+  VLOG(1) << "BeginPass end, cost time: " << timer.ElapsedSec() << "s";
+}
+
+void PSGPUWrapper::EndPass() {
+  if (!current_task_) {
+    PADDLE_THROW(
+        platform::errors::Fatal("[EndPass] current task has been ended."));
+  }
+  platform::Timer timer;
+  timer.Start();
+  size_t keysize_max = 0;
+  // in case of feasign_num = 0, skip dump_to_cpu
+  for (size_t i = 0; i < heter_devices_.size(); i++) {
+    keysize_max = std::max(keysize_max, current_task_->device_keys_[i].size());
+  }
+  if (keysize_max != 0) {
+    HeterPs_->end_pass();
+  }
+  current_task_ = nullptr;
+  gpu_free_channel_->Put(current_task_);
+  timer.Pause();
+  VLOG(1) << "EndPass end, cost time: " << timer.ElapsedSec() << "s";
+}
 
 void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
......
@@ -82,9 +82,33 @@ class PSGPUWrapper {
                   const int hidden_size, const int64_t total_length,
                   const int batch_size);
 
-  void BuildGPUPS(const uint64_t table_id, int feature_dim);
-  void BuildTask(std::shared_ptr<HeterContext> gpu_task, uint64_t table_id,
-                 int feature_dim);
+  void BuildGPUTask(std::shared_ptr<HeterContext> gpu_task);
+  void BuildTask(std::shared_ptr<HeterContext> gpu_task);
+  void LoadIntoMemory(bool is_shuffle);
+  void BeginPass();
+  void EndPass();
+  void start_build_thread();
+  void build_cpu_thread();
+  void build_gpu_thread();
+
+  void Finalize() {
+    VLOG(3) << "PSGPUWrapper Begin Finalize.";
+    if (s_instance_ == nullptr) {
+      return;
+    }
+    data_ready_channel_->Close();
+    buildcpu_ready_channel_->Close();
+    gpu_free_channel_->Close();
+    train_ready_channel_->Close();
+    running_ = false;
+    VLOG(3) << "begin stop build_cpu_threads_";
+    build_cpu_threads_.join();
+    VLOG(3) << "begin stop build_gpu_threads_";
+    build_gpu_threads_.join();
+    s_instance_ = nullptr;
+    VLOG(3) << "PSGPUWrapper Finalize Finished.";
+  }
   void InitializeGPU(const std::vector<int>& dev_ids) {
     if (s_instance_ != NULL && is_initialized_ == false) {
       VLOG(3) << "PSGPUWrapper Begin InitializeGPU";
@@ -129,6 +153,24 @@ class PSGPUWrapper {
 #endif
       }
       heter_devices_ = dev_ids;
+      data_ready_channel_->Open();
+      data_ready_channel_->SetCapacity(3);
+      buildcpu_ready_channel_->Open();
+      buildcpu_ready_channel_->SetCapacity(3);
+      gpu_free_channel_->Open();
+      gpu_free_channel_->SetCapacity(1);
+      train_ready_channel_->Open();
+      train_ready_channel_->SetCapacity(1);
+      current_task_ = nullptr;
+      gpu_free_channel_->Put(current_task_);
+
+      table_id_ = 1;
+#ifdef PADDLE_WITH_PSLIB
+      table_id_ = 0;
+#endif
+      // start build cpu&gpu ps thread
+      start_build_thread();
     }
   }
@@ -206,18 +248,8 @@ class PSGPUWrapper {
     slot_vector_ = slot_vector;
   }
 
-  void EndPass() { HeterPs_->end_pass(); }
   void ShowOneTable(int index) { HeterPs_->show_one_table(index); }
-  void Finalize() {
-    VLOG(3) << "PSGPUWrapper Begin Finalize.";
-    if (s_instance_ == nullptr) {
-      return;
-    }
-    s_instance_ = nullptr;
-    VLOG(3) << "PSGPUWrapper Finalize Finished.";
-  }
 
  private:
   static std::shared_ptr<PSGPUWrapper> s_instance_;
   Dataset* dataset_;
@@ -231,6 +263,7 @@ class PSGPUWrapper {
   std::vector<int> slot_vector_;
   int multi_node_{0};
   int node_size_;
+  uint64_t table_id_;
   std::vector<ncclComm_t> inner_comms_;
   std::vector<ncclComm_t> inter_comms_;
   std::vector<ncclUniqueId> inter_ncclids_;
@@ -242,6 +275,27 @@ class PSGPUWrapper {
   int thread_keys_shard_num_ = 37;
   uint64_t max_fea_num_per_pass_ = 5000000000;
 
+  std::shared_ptr<
+      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
+      data_ready_channel_ =
+          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
+  std::shared_ptr<
+      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
+      buildcpu_ready_channel_ =
+          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
+  std::shared_ptr<
+      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
+      gpu_free_channel_ =
+          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
+  std::shared_ptr<
+      paddle::framework::ChannelObject<std::shared_ptr<HeterContext>>>
+      train_ready_channel_ =
+          paddle::framework::MakeChannel<std::shared_ptr<HeterContext>>();
+  std::shared_ptr<HeterContext> current_task_ = nullptr;
+  std::thread build_cpu_threads_;
+  std::thread build_gpu_threads_;
+  bool running_ = false;
+
  protected:
   static bool is_initialized_;
 };
......
@@ -36,7 +36,7 @@ enum HeterTaskState { PULL_SPARSE, OP_RUN, XPU, OP_RUN_END, PUSH_GRAD, DONE };
 class HeterTask {
  public:
   HeterTask() {}
-  virtual ~HeterTask(){};
+  virtual ~HeterTask() {}
 
   void Update() {
     if (state_ == PULL_SPARSE) {
@@ -111,7 +111,7 @@ template <class T>
 class HeterObjectPool {
  public:
   HeterObjectPool() {}
-  virtual ~HeterObjectPool(){};
+  virtual ~HeterObjectPool() {}
   std::shared_ptr<T> Get() {
     std::lock_guard<std::mutex> lock(mutex_);
     if (pool_.empty()) {
@@ -131,6 +131,10 @@ class HeterObjectPool {
     std::lock_guard<std::mutex> lock(mutex_);
     return pool_.size();
   }
+  bool Empty() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    return pool_.empty();
+  }
   std::shared_ptr<T>& GetElement(int i) { return pool_[i]; }
 
  private:
@@ -160,7 +164,7 @@ class BtObjectPool {
   virtual ~BtObjectPool() {
     bthread_cond_destroy(&cond_);
     bthread_mutex_destroy(&mutex_);
-  };
+  }
 
   std::shared_ptr<T> Get() {
     BthreadMutextGuard guard(&mutex_);
......
@@ -47,7 +47,9 @@ void BindPSGPUWrapper(py::module* m) {
            py::call_guard<py::gil_scoped_release>())
       .def("end_pass", &framework::PSGPUWrapper::EndPass,
            py::call_guard<py::gil_scoped_release>())
-      .def("build_gpu_ps", &framework::PSGPUWrapper::BuildGPUPS,
+      .def("begin_pass", &framework::PSGPUWrapper::BeginPass,
+           py::call_guard<py::gil_scoped_release>())
+      .def("load_into_memory", &framework::PSGPUWrapper::LoadIntoMemory,
            py::call_guard<py::gil_scoped_release>())
       .def("finalize", &framework::PSGPUWrapper::Finalize,
            py::call_guard<py::gil_scoped_release>());
......
@@ -185,6 +185,14 @@ bool IsCompiledWithMKLDNN() {
 #endif
 }
 
+bool IsCompiledWithHETERPS() {
+#ifndef PADDLE_WITH_HETERPS
+  return false;
+#else
+  return true;
+#endif
+}
+
 bool SupportsBfloat16() {
 #ifndef PADDLE_WITH_MKLDNN
   return false;
@@ -1910,6 +1918,7 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("is_compiled_with_npu", IsCompiledWithNPU);
   m.def("is_compiled_with_xpu", IsCompiledWithXPU);
   m.def("is_compiled_with_mkldnn", IsCompiledWithMKLDNN);
+  m.def("_is_compiled_with_heterps", IsCompiledWithHETERPS);
   m.def("supports_bfloat16", SupportsBfloat16);
   m.def("supports_bfloat16_fast_performance", SupportsBfloat16FastPerformance);
   m.def("op_supported_infos", OpSupportedInfos);
......
@@ -34,6 +34,7 @@ class DatasetBase(object):
         self.thread_num = 1
         self.filelist = []
         self.use_ps_gpu = False
+        self.psgpu = None
 
     def init(self,
              batch_size=1,
@@ -223,6 +224,11 @@ class DatasetBase(object):
             use_ps_gpu: bool
         """
         self.use_ps_gpu = use_ps_gpu
+        # if not defined heterps with paddle, users will not use psgpu
+        if not core._is_compiled_with_heterps():
+            self.use_ps_gpu = 0
+        elif self.use_ps_gpu:
+            self.psgpu = core.PSGPU()
 
     def _finish_to_run(self):
         self.dataset.destroy_readers()
@@ -677,12 +683,15 @@ class InMemoryDataset(DatasetBase):
         self.dataset.generate_local_tables_unlock(
             table_id, fea_dim, read_thread_num, consume_thread_num, shard_num)
 
-    def load_into_memory(self):
+    def load_into_memory(self, is_shuffle=False):
         """
         :api_attr: Static Graph
 
         Load data into memory
 
+        Args:
+            is_shuffle(bool): whether to use local shuffle, default is False
+
         Examples:
             .. code-block:: python
@@ -707,7 +716,11 @@ class InMemoryDataset(DatasetBase):
                 dataset.load_into_memory()
 
         """
         self._prepare_to_run()
-        self.dataset.load_into_memory()
+        if not self.use_ps_gpu:
+            self.dataset.load_into_memory()
+        elif core._is_compiled_with_heterps():
+            self.psgpu.set_dataset(self.dataset)
+            self.psgpu.load_into_memory(is_shuffle)
 
     def preload_into_memory(self, thread_num=None):
         """
......
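With the dataset changes above, load_into_memory(is_shuffle=...) only takes the PSGPU path when _set_use_ps_gpu(1) was called and Paddle was built with HETERPS; otherwise it falls back to the plain in-memory load. A hedged usage sketch, mirroring the updated unit test further down; the slot names and file name are illustrative:

    import paddle
    paddle.enable_static()

    # two illustrative sparse slots
    slots_vars = [
        paddle.static.data(name="slot1", shape=[None, 1], dtype="int64", lod_level=1),
        paddle.static.data(name="slot2", shape=[None, 1], dtype="int64", lod_level=1),
    ]

    dataset = paddle.distributed.InMemoryDataset()
    dataset.init(batch_size=32, thread_num=1, pipe_command="cat", use_var=slots_vars)
    dataset.set_filelist(["train_data.txt"])   # illustrative file name
    dataset._set_use_ps_gpu(1)                 # forced off when not compiled with HETERPS
    dataset.load_into_memory(is_shuffle=True)  # PSGPU path: LoadIntoMemory + local shuffle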
@@ -271,6 +271,7 @@ if avx_supported():
         from .core_avx import _set_paddle_lib_path
         from .core_avx import _create_loaded_parameter
         from .core_avx import _cuda_synchronize
+        from .core_avx import _is_compiled_with_heterps
         from .core_avx import _promote_types_if_complex_exists
         if sys.platform != 'win32':
             from .core_avx import _set_process_pids
@@ -318,6 +319,7 @@ if load_noavx:
         from .core_noavx import _set_paddle_lib_path
         from .core_noavx import _create_loaded_parameter
         from .core_noavx import _cuda_synchronize
+        from .core_noavx import _is_compiled_with_heterps
         from .core_noavx import _promote_types_if_complex_exists
         if sys.platform != 'win32':
             from .core_noavx import _set_process_pids
......
@@ -75,6 +75,7 @@ class DatasetBase(object):
         self.thread_num = 1
         self.filelist = []
         self.use_ps_gpu = False
+        self.psgpu = None
 
     def set_pipe_command(self, pipe_command):
         """
@@ -311,6 +312,11 @@ class DatasetBase(object):
             use_ps_gpu: bool
         """
         self.use_ps_gpu = use_ps_gpu
+        # if not defined heterps with paddle, users will not use psgpu
+        if not core._is_compiled_with_heterps():
+            self.use_ps_gpu = 0
+        elif self.use_ps_gpu:
+            self.psgpu = core.PSGPU()
 
     def _finish_to_run(self):
         self.dataset.destroy_readers()
@@ -694,10 +700,13 @@ class InMemoryDataset(DatasetBase):
     @deprecated(
         since="2.0.0",
         update_to="paddle.distributed.InMemoryDataset.load_into_memory")
-    def load_into_memory(self):
+    def load_into_memory(self, is_shuffle=False):
         """
         Load data into memory
 
+        Args:
+            is_shuffle(bool): whether to use local shuffle, default is False
+
         Examples:
             .. code-block:: python
@@ -708,7 +717,11 @@ class InMemoryDataset(DatasetBase):
                 dataset.load_into_memory()
 
         """
         self._prepare_to_run()
-        self.dataset.load_into_memory()
+        if not self.use_ps_gpu:
+            self.dataset.load_into_memory()
+        elif core._is_compiled_with_heterps():
+            self.psgpu.set_dataset(self.dataset)
+            self.psgpu.load_into_memory(is_shuffle)
 
     @deprecated(
         since="2.0.0",
......
@@ -74,7 +74,7 @@ class TestCommunicator(unittest.TestCase):
             batch_size=32, thread_num=1, pipe_command="cat", use_var=slots_vars)
         dataset.set_filelist(["test_communicator_ps_gpu.txt"])
         dataset._set_use_ps_gpu(1)
-        dataset.load_into_memory()
+        dataset.load_into_memory(is_shuffle=True)
 
         os.environ["TEST_MODE"] = "1"
         exe = fluid.Executor(fluid.CPUPlace())
......