未验证 提交 995a6376 编写于 作者: Y yaoxuefeng 提交者: GitHub

add pslib SparseDoubleTable test=develop (#23053)

上级 3af47711
......@@ -81,25 +81,25 @@ class DownpourServer(Server):
'sparse_delete_after_unseen_days', 'sparse_show_click_decay_rate', 'sparse_delete_threshold', \
'sparse_converter', 'sparse_deconverter', 'sparse_enable_cache', 'sparse_cache_rate', \
'sparse_cache_file_num', 'sparse_beta1_decay_rate', 'sparse_beta2_decay_rate', \
'sparse_ada_epsilon', 'sparse_optimizer']
'sparse_ada_epsilon', 'sparse_optimizer', 'sparse_ssd_unseenday_threshold']
for key in strategy:
if key not in support_sparse_key_list:
raise ValueError("strategy key '%s' not support" % (key))
support_table_calss = ['DownpourSparseTable']
support_table_calss = ['DownpourSparseTable', 'DownpourSparseSSDTable']
if strategy.get('sparse_table_class') is not None:
table_class = strategy.get('sparse_table_class')
if table_class not in support_table_calss:
raise ValueError(
"support sparse_table_class: [ 'DownpourSparseTable' ], \
"support sparse_table_class: [ 'DownpourSparseTable', 'DownpourSparseSSDTable'], \
but actual %s" % (table_class))
else:
table_class = 'DownpourSparseTable'
table.table_class = table_class
if table_class == 'DownpourSparseTable':
if table_class == 'DownpourSparseTable' or table_class == 'DownpourSparseSSDTable':
table.enable_sparse_table_cache = strategy.get(
'sparse_enable_cache', True)
table.sparse_table_cache_rate = strategy.get('sparse_cache_rate',
......@@ -112,23 +112,25 @@ class DownpourServer(Server):
# DownpourFeatureValueAccessor: for ctr task, has cvm, embedding and sgd info
# DownpourCtrAccessor : for ctr task, has cvm, slot, embedding and sgd info
# DownpourSparseValueAccessor : for general task, has embedding and sgd info
# DownpourCtrDoubleAccessor : for ctr task, which show clk are in double
support_accessor_class = [
'DownpourFeatureValueAccessor', 'DownpourCtrAccessor',
'DownpourSparseValueAccessor'
'DownpourSparseValueAccessor', 'DownpourCtrDoubleAccessor'
]
if strategy.get('sparse_accessor_class') is not None:
accessor_class = strategy.get('sparse_accessor_class')
if accessor_class not in support_accessor_class:
raise ValueError(
"support sparse_accessor_class: ['DownpourFeatureValueAccessor', 'DownpourCtrAccessor'], \
"support sparse_accessor_class: ['DownpourFeatureValueAccessor', 'DownpourCtrAccessor', \
'DownpourSparseValueAccessor', 'DownpourCtrDoubleAccessor'], \
but actual %s" % (accessor_class))
else:
accessor_class = 'DownpourCtrAccessor'
table.accessor.accessor_class = accessor_class
if accessor_class == 'DownpourFeatureValueAccessor' or accessor_class == 'DownpourCtrAccessor':
if accessor_class == 'DownpourFeatureValueAccessor' or accessor_class == 'DownpourCtrAccessor' or accessor_class == 'DownpourCtrDoubleAccessor':
table.accessor.sparse_sgd_param.learning_rate = strategy.get(
'sparse_learning_rate', 0.05)
table.accessor.sparse_sgd_param.initial_g2sum = strategy.get(
......@@ -157,6 +159,8 @@ class DownpourServer(Server):
'sparse_delta_keep_days', 16)
table.accessor.downpour_accessor_param.delete_after_unseen_days = strategy.get(
'sparse_delete_after_unseen_days', 30)
table.accessor.downpour_accessor_param.ssd_unseenday_threshold = strategy.get(
'sparse_ssd_unseenday_threshold', 1)
table.accessor.downpour_accessor_param.show_click_decay_rate = strategy.get(
'sparse_show_click_decay_rate', 0.98)
table.accessor.downpour_accessor_param.delete_threshold = strategy.get(
......
......@@ -26,7 +26,7 @@ import sys
from op_test import OpTest
from paddle.fluid.trainer_desc import DistMultiTrainer
from paddle.fluid.device_worker import DownpourSGD, DownpourSGDOPT
from paddle.fluid.incubate.fleet.parameter_server.pslib.node import DownpourWorker
from paddle.fluid.incubate.fleet.parameter_server.pslib.node import DownpourWorker, DownpourServer
from google.protobuf import text_format
import paddle.fluid.incubate.fleet.parameter_server.pslib.ps_pb2 as pslib
from paddle.fluid.trainer_factory import TrainerFactory
......@@ -91,6 +91,8 @@ class TestListenAndServOp(unittest.TestCase):
opt_info["dump_slot"] = False
opt_info["stat_var_names"] = []
worker = DownpourWorker(None)
server = DownpourServer()
server.add_sparse_table(0, {})
worker.get_desc().CopyFrom(ps_param.trainer_param[0])
opt_info["program_id_to_worker"] = {program_id: worker}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册