未验证 提交 d69daed1 编写于 作者: T Thunderbrook 提交者: GitHub

[GpuPs]pybind core (#37287)

* pybind core

* set use psgpu
上级 acbf9974
......@@ -321,19 +321,19 @@ class DatasetBase(object):
self.dataset.set_data_feed_desc(self.desc())
self.dataset.create_readers()
def _set_use_ps_gpu(self, use_ps_gpu):
def _set_use_ps_gpu(self, psgpu):
"""
set use_ps_gpu flag
Args:
use_ps_gpu: bool
"""
self.use_ps_gpu = use_ps_gpu
self.use_ps_gpu = True
# if not defined heterps with paddle, users will not use psgpu
if not core._is_compiled_with_heterps():
self.use_ps_gpu = 0
self.use_ps_gpu = False
elif self.use_ps_gpu:
self.psgpu = core.PSGPU()
self.psgpu = psgpu
def _finish_to_run(self):
self.dataset.destroy_readers()
......
......@@ -1813,9 +1813,9 @@ class Executor(object):
if program._pipeline_opt is None:
if program._heter_pipeline_opt is None:
self._dump_debug_info(program=program, trainer=trainer)
# in case of calling _set_use_ps_gpu explicitly
if dataset.use_ps_gpu is False:
dataset._set_use_ps_gpu(trainer.proto_desc.use_ps_gpu)
# warning if dataset not set psgpu in psgpu mode
if dataset.use_ps_gpu is False and trainer.proto_desc.use_ps_gpu:
logging.warning("dataset should call set_use_ps_gpu in PsGpu mode")
dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num)
if program._heter_pipeline_opt is None:
......@@ -1948,9 +1948,9 @@ class Executor(object):
# NOTE: only for debug, very slow
# self._dump_debug_info(program=program, trainer=trainer)
# in case of calling _set_use_ps_gpu explicitly
if dataset.use_ps_gpu is False:
dataset._set_use_ps_gpu(trainer.proto_desc.use_ps_gpu)
# warning if dataset not set psgpu in psgpu mode
if dataset.use_ps_gpu is False and trainer.proto_desc.use_ps_gpu:
logging.warning("dataset should call set_use_ps_gpu in PsGpu mode")
dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num)
trainer_desc = trainer._desc() # slow, cache
......
......@@ -73,7 +73,6 @@ class TestCommunicator(unittest.TestCase):
dataset.init(
batch_size=32, thread_num=1, pipe_command="cat", use_var=slots_vars)
dataset.set_filelist(["test_communicator_ps_gpu.txt"])
dataset._set_use_ps_gpu(1)
dataset.set_date("20211111")
dataset.load_into_memory(is_shuffle=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册