Unverified commit 24ea1dd8, authored by: W wangguanqun, committed by: GitHub

Add default config and update dataset in gpups (#43327)

* gpups default config and dataset

* codestyle

* add unittest

* code style
Parent 5d48528f
......@@ -583,6 +583,12 @@ class InMemoryDataset(DatasetBase):
pipe_command = kwargs.get("pipe_command", "cat")
download_cmd = kwargs.get("download_cmd", "cat")
if self.use_ps_gpu:
data_feed_type = "SlotRecordInMemoryDataFeed"
else:
data_feed_type = "MultiSlotInMemoryDataFeed"
self._set_feed_type(data_feed_type)
super(InMemoryDataset, self).init(batch_size=batch_size,
thread_num=thread_num,
use_var=use_var,
......@@ -592,10 +598,6 @@ class InMemoryDataset(DatasetBase):
fs_ugi=fs_ugi,
download_cmd=download_cmd)
data_feed_type = kwargs.get("data_feed_type",
"MultiSlotInMemoryDataFeed")
self._set_feed_type(data_feed_type)
if kwargs.get("queue_num", -1) > 0:
queue_num = kwargs.get("queue_num", -1)
self._set_queue_num(queue_num)
......@@ -605,6 +607,8 @@ class InMemoryDataset(DatasetBase):
Set data_feed_desc
"""
self.proto_desc.name = data_feed_type
if (self.proto_desc.name == "SlotRecordInMemoryDataFeed"):
self.dataset = core.Dataset("SlotRecordDataset")
def _prepare_to_run(self):
"""
......
......@@ -138,7 +138,7 @@ class Accessor:
if not accessor_proto.HasField("accessor_class"):
# DownpourSparseValueAccessor
if context['use_ps_gpu']:
accessor_proto.accessor_class = "CtrCommonAccessor"
accessor_proto.accessor_class = "CtrDymfAccessor"
else:
accessor_proto.accessor_class = "SparseAccessor"
if not accessor_proto.HasField("fea_dim"):
......@@ -601,10 +601,16 @@ class SparseTable(Table):
if usr_table_proto.HasField("shard_num"):
table_proto.shard_num = usr_table_proto.shard_num
else:
table_proto.shard_num = 1000
warnings.warn(
"The shard_num of sparse table is not set, use default value 1000."
)
if self.context['use_ps_gpu']:
table_proto.shard_num = 37
warnings.warn(
"The shard_num of sparse table is not set, use default value 37 in gpups."
)
else:
table_proto.shard_num = 1000
warnings.warn(
"The shard_num of sparse table is not set, use default value 1000 in cpups."
)
if usr_table_proto.accessor.ByteSize() == 0:
warnings.warn(
......
......@@ -185,6 +185,45 @@ class TestPSPassWithBow(unittest.TestCase):
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(loss)
def test_gpups_dataset(self):
    """
    Testcase for GPUPS InMemoryDataset.

    With ``_set_use_ps_gpu(True)``, ``init()`` is expected to pick the
    GPUPS feed type internally; this test only checks that dataset
    construction, init, and set_filelist run without error on two small
    text files.
    """
    # Keep the filenames in one place so creation and cleanup agree.
    train_files = [
        "test_in_memory_dataset_run_a.txt",
        "test_in_memory_dataset_run_b.txt",
    ]
    # try/finally guarantees the temp files are removed even when an
    # intermediate step raises (the original leaked them on failure).
    try:
        with open(train_files[0], "w") as f:
            data = "1 1 2 3 3 4 5 5 5 5 1 1\n"
            data += "1 2 2 3 4 4 6 6 6 6 1 2\n"
            data += "1 3 2 3 5 4 7 7 7 7 1 3\n"
            f.write(data)
        with open(train_files[1], "w") as f:
            data = "1 4 2 3 3 4 5 5 5 5 1 4\n"
            data += "1 5 2 3 4 4 6 6 6 6 1 5\n"
            data += "1 6 2 3 5 4 7 7 7 7 1 6\n"
            data += "1 7 2 3 6 4 8 8 8 8 1 7\n"
            f.write(data)

        # One int64 LoD-level-1 input variable per slot column.
        slots = ["slot1", "slot2", "slot3", "slot4"]
        slots_vars = []
        for slot in slots:
            var = fluid.layers.data(name=slot,
                                    shape=[1],
                                    dtype="int64",
                                    lod_level=1)
            slots_vars.append(var)

        dataset = paddle.distributed.InMemoryDataset()
        # Must be set before init() so the GPUPS feed type is selected.
        dataset._set_use_ps_gpu(True)
        dataset.init(batch_size=32,
                     thread_num=3,
                     pipe_command="cat",
                     use_var=slots_vars)
        dataset.set_filelist(train_files)
    finally:
        for path in train_files:
            if os.path.exists(path):
                os.remove(path)
# Allow the test module to be executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.