未验证 提交 995a6376 编写于 作者: Y yaoxuefeng 提交者: GitHub

add pslib SparseDoubleTable test=develop (#23053)

上级 3af47711
...@@ -81,25 +81,25 @@ class DownpourServer(Server): ...@@ -81,25 +81,25 @@ class DownpourServer(Server):
'sparse_delete_after_unseen_days', 'sparse_show_click_decay_rate', 'sparse_delete_threshold', \ 'sparse_delete_after_unseen_days', 'sparse_show_click_decay_rate', 'sparse_delete_threshold', \
'sparse_converter', 'sparse_deconverter', 'sparse_enable_cache', 'sparse_cache_rate', \ 'sparse_converter', 'sparse_deconverter', 'sparse_enable_cache', 'sparse_cache_rate', \
'sparse_cache_file_num', 'sparse_beta1_decay_rate', 'sparse_beta2_decay_rate', \ 'sparse_cache_file_num', 'sparse_beta1_decay_rate', 'sparse_beta2_decay_rate', \
'sparse_ada_epsilon', 'sparse_optimizer'] 'sparse_ada_epsilon', 'sparse_optimizer', 'sparse_ssd_unseenday_threshold']
for key in strategy: for key in strategy:
if key not in support_sparse_key_list: if key not in support_sparse_key_list:
raise ValueError("strategy key '%s' not support" % (key)) raise ValueError("strategy key '%s' not support" % (key))
support_table_calss = ['DownpourSparseTable'] support_table_calss = ['DownpourSparseTable', 'DownpourSparseSSDTable']
if strategy.get('sparse_table_class') is not None: if strategy.get('sparse_table_class') is not None:
table_class = strategy.get('sparse_table_class') table_class = strategy.get('sparse_table_class')
if table_class not in support_table_calss: if table_class not in support_table_calss:
raise ValueError( raise ValueError(
"support sparse_table_class: [ 'DownpourSparseTable' ], \ "support sparse_table_class: [ 'DownpourSparseTable', 'DownpourSparseSSDTable'], \
but actual %s" % (table_class)) but actual %s" % (table_class))
else: else:
table_class = 'DownpourSparseTable' table_class = 'DownpourSparseTable'
table.table_class = table_class table.table_class = table_class
if table_class == 'DownpourSparseTable': if table_class == 'DownpourSparseTable' or table_class == 'DownpourSparseSSDTable':
table.enable_sparse_table_cache = strategy.get( table.enable_sparse_table_cache = strategy.get(
'sparse_enable_cache', True) 'sparse_enable_cache', True)
table.sparse_table_cache_rate = strategy.get('sparse_cache_rate', table.sparse_table_cache_rate = strategy.get('sparse_cache_rate',
...@@ -112,23 +112,25 @@ class DownpourServer(Server): ...@@ -112,23 +112,25 @@ class DownpourServer(Server):
# DownpourFeatureValueAccessor: for ctr task, has cvm, embedding and sgd info # DownpourFeatureValueAccessor: for ctr task, has cvm, embedding and sgd info
# DownpourCtrAccessor : for ctr task, has cvm, slot, embedding and sgd info # DownpourCtrAccessor : for ctr task, has cvm, slot, embedding and sgd info
# DownpourSparseValueAccessor : for general task, has embedding and sgd info # DownpourSparseValueAccessor : for general task, has embedding and sgd info
# DownpourCtrDoubleAccessor : for ctr task, which show clk are in double
support_accessor_class = [ support_accessor_class = [
'DownpourFeatureValueAccessor', 'DownpourCtrAccessor', 'DownpourFeatureValueAccessor', 'DownpourCtrAccessor',
'DownpourSparseValueAccessor' 'DownpourSparseValueAccessor', 'DownpourCtrDoubleAccessor'
] ]
if strategy.get('sparse_accessor_class') is not None: if strategy.get('sparse_accessor_class') is not None:
accessor_class = strategy.get('sparse_accessor_class') accessor_class = strategy.get('sparse_accessor_class')
if accessor_class not in support_accessor_class: if accessor_class not in support_accessor_class:
raise ValueError( raise ValueError(
"support sparse_accessor_class: ['DownpourFeatureValueAccessor', 'DownpourCtrAccessor'], \ "support sparse_accessor_class: ['DownpourFeatureValueAccessor', 'DownpourCtrAccessor', \
'DownpourSparseValueAccessor', 'DownpourCtrDoubleAccessor'], \
but actual %s" % (accessor_class)) but actual %s" % (accessor_class))
else: else:
accessor_class = 'DownpourCtrAccessor' accessor_class = 'DownpourCtrAccessor'
table.accessor.accessor_class = accessor_class table.accessor.accessor_class = accessor_class
if accessor_class == 'DownpourFeatureValueAccessor' or accessor_class == 'DownpourCtrAccessor': if accessor_class == 'DownpourFeatureValueAccessor' or accessor_class == 'DownpourCtrAccessor' or accessor_class == 'DownpourCtrDoubleAccessor':
table.accessor.sparse_sgd_param.learning_rate = strategy.get( table.accessor.sparse_sgd_param.learning_rate = strategy.get(
'sparse_learning_rate', 0.05) 'sparse_learning_rate', 0.05)
table.accessor.sparse_sgd_param.initial_g2sum = strategy.get( table.accessor.sparse_sgd_param.initial_g2sum = strategy.get(
...@@ -157,6 +159,8 @@ class DownpourServer(Server): ...@@ -157,6 +159,8 @@ class DownpourServer(Server):
'sparse_delta_keep_days', 16) 'sparse_delta_keep_days', 16)
table.accessor.downpour_accessor_param.delete_after_unseen_days = strategy.get( table.accessor.downpour_accessor_param.delete_after_unseen_days = strategy.get(
'sparse_delete_after_unseen_days', 30) 'sparse_delete_after_unseen_days', 30)
table.accessor.downpour_accessor_param.ssd_unseenday_threshold = strategy.get(
'sparse_ssd_unseenday_threshold', 1)
table.accessor.downpour_accessor_param.show_click_decay_rate = strategy.get( table.accessor.downpour_accessor_param.show_click_decay_rate = strategy.get(
'sparse_show_click_decay_rate', 0.98) 'sparse_show_click_decay_rate', 0.98)
table.accessor.downpour_accessor_param.delete_threshold = strategy.get( table.accessor.downpour_accessor_param.delete_threshold = strategy.get(
......
...@@ -26,7 +26,7 @@ import sys ...@@ -26,7 +26,7 @@ import sys
from op_test import OpTest from op_test import OpTest
from paddle.fluid.trainer_desc import DistMultiTrainer from paddle.fluid.trainer_desc import DistMultiTrainer
from paddle.fluid.device_worker import DownpourSGD, DownpourSGDOPT from paddle.fluid.device_worker import DownpourSGD, DownpourSGDOPT
from paddle.fluid.incubate.fleet.parameter_server.pslib.node import DownpourWorker from paddle.fluid.incubate.fleet.parameter_server.pslib.node import DownpourWorker, DownpourServer
from google.protobuf import text_format from google.protobuf import text_format
import paddle.fluid.incubate.fleet.parameter_server.pslib.ps_pb2 as pslib import paddle.fluid.incubate.fleet.parameter_server.pslib.ps_pb2 as pslib
from paddle.fluid.trainer_factory import TrainerFactory from paddle.fluid.trainer_factory import TrainerFactory
...@@ -91,6 +91,8 @@ class TestListenAndServOp(unittest.TestCase): ...@@ -91,6 +91,8 @@ class TestListenAndServOp(unittest.TestCase):
opt_info["dump_slot"] = False opt_info["dump_slot"] = False
opt_info["stat_var_names"] = [] opt_info["stat_var_names"] = []
worker = DownpourWorker(None) worker = DownpourWorker(None)
server = DownpourServer()
server.add_sparse_table(0, {})
worker.get_desc().CopyFrom(ps_param.trainer_param[0]) worker.get_desc().CopyFrom(ps_param.trainer_param[0])
opt_info["program_id_to_worker"] = {program_id: worker} opt_info["program_id_to_worker"] = {program_id: worker}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册