Unverified · Commit 47383dca, authored by wangguanqun, committed by GitHub

fix load bug and add distributed strategy from pslib (#40883)

* fix load bug and add distributed strategy from pslib

* add unittest

* use cvm config

* trainer and worker config

* add unittest

* add unittest

* add test

* code style
Parent 3b00dc92
@@ -163,6 +163,8 @@ message TrainerDescConfig {
repeated string dump_fields = 2;
repeated string dump_param = 3;
repeated string stat_var_names = 4;
optional string trainer = 5;
optional string device_worker = 6;
}
message PipelineConfig {
@@ -189,6 +191,7 @@ message TableParameter {
optional uint64 shard_num = 4 [ default = 1000 ];
optional TableType type = 5;
optional TableAccessorParameter accessor = 6;
optional bool compress_in_save = 7 [ default = false ];
}
message TableAccessorParameter {
......
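With the two new TrainerDescConfig fields, the trainer and device-worker classes can be passed through DistributedStrategy.trainer_desc_configs alongside the existing dump settings. A minimal sketch, assuming the setter validates keys against the updated proto; "MultiTrainer" and "HogwildWorker" are example values, not something this patch mandates:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.trainer_desc_configs = {
    "dump_fields_path": "dump_data",
    "dump_fields": ["click", "show"],
    "trainer": "MultiTrainer",         # maps to the new TrainerDescConfig.trainer (field 5)
    "device_worker": "HogwildWorker",  # maps to the new TrainerDescConfig.device_worker (field 6)
}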
@@ -515,6 +515,169 @@ class DistributedStrategy(object):
set_table_config(table_data, "table_parameters." + table_name,
configs[table_name])
@sparse_table_configs.setter
def fleet_desc_configs(self, configs):
support_sparse_key_list = ['sparse_table_class', 'sparse_compress_in_save', 'sparse_shard_num', \
'sparse_accessor_class', 'sparse_learning_rate', 'sparse_initial_g2sum', 'sparse_initial_range', \
'sparse_weight_bounds', 'sparse_fea_dim', 'sparse_embedx_dim', 'sparse_embedx_threshold', 'sparse_nonclk_coeff', \
'sparse_click_coeff', 'sparse_base_threshold', 'sparse_delta_threshold', 'sparse_delta_keep_days', \
'sparse_delete_after_unseen_days', 'sparse_show_click_decay_rate', 'sparse_delete_threshold', \
'sparse_converter', 'sparse_deconverter', 'sparse_enable_cache', 'sparse_cache_rate', \
'sparse_cache_file_num', 'sparse_beta1_decay_rate', 'sparse_beta2_decay_rate', \
'sparse_ada_epsilon', 'sparse_optimizer', 'sparse_ssd_unseenday_threshold',
'embed_sparse_optimizer', 'embed_sparse_learning_rate', 'embed_sparse_weight_bounds', \
'embed_sparse_initial_range', 'embed_sparse_initial_g2sum', 'embed_sparse_beta1_decay_rate', \
'embed_sparse_beta2_decay_rate', 'embedx_sparse_optimizer', 'embedx_sparse_learning_rate', \
'embedx_sparse_weight_bounds', 'embedx_sparse_initial_range', 'embedx_sparse_initial_g2sum', \
'embedx_sparse_beta1_decay_rate', 'embedx_sparse_beta2_decay_rate']
support_sparse_table_class = ['DownpourSparseTable']
support_sparse_accessor_class = [
'DownpourSparseValueAccessor', 'DownpourCtrAccessor',
'DownpourCtrDoubleAccessor', 'DownpourUnitAccessor',
'DownpourDoubleUnitAccessor'
]
from google.protobuf.descriptor import FieldDescriptor
table_param = self.strategy.downpour_table_param
def sparse_optimizer_config(sgd, strategy, prefix):
optimizer_name = strategy.get(prefix + "sparse_optimizer",
"adagrad")
sgd.name = optimizer_name
if optimizer_name == "naive":
sgd.name = "SparseNaiveSGDRule"
sgd.naive.learning_rate = strategy.get(
prefix + 'sparse_learning_rate', 0.05)
sgd.naive.initial_range = strategy.get(
prefix + 'sparse_initial_range', 1e-4)
bounds = strategy.get(prefix + 'sparse_weight_bounds',
[-10, 10])
sgd.naive.weight_bounds.extend(bounds)
elif optimizer_name == "adagrad":
sgd.name = 'SparseAdaGradSGDRule'
sgd.adagrad.learning_rate = strategy.get(
prefix + 'sparse_learning_rate', 0.05)
sgd.adagrad.initial_range = strategy.get(
prefix + 'sparse_initial_range', 1e-4)
if prefix == "embed_":
sgd.adagrad.initial_range = 0
sgd.adagrad.initial_g2sum = strategy.get(
prefix + 'sparse_initial_g2sum', 3)
bounds = strategy.get(prefix + 'sparse_weight_bounds',
[-10, 10])
sgd.adagrad.weight_bounds.extend(bounds)
elif optimizer_name == "std_adagrad":
sgd.name = 'StdAdaGradSGDRule'
sgd.adagrad.learning_rate = strategy.get(
prefix + 'sparse_learning_rate', 0.05)
sgd.adagrad.initial_range = strategy.get(
prefix + 'sparse_initial_range', 1e-4)
if prefix == "embed_":
sgd.adagrad.initial_range = 0
sgd.adagrad.initial_g2sum = strategy.get(
prefix + 'sparse_initial_g2sum', 3)
bounds = strategy.get(prefix + 'sparse_weight_bounds',
[-10, 10])
sgd.adagrad.weight_bounds.extend(bounds)
elif optimizer_name == "adam":
sgd.name = 'SparseAdamSGDRule'
sgd.adam.learning_rate = strategy.get(
prefix + 'sparse_learning_rate', 0.001)
sgd.adam.initial_range = strategy.get(
prefix + 'sparse_initial_range', 1e-4)
sgd.adam.beta1_decay_rate = strategy.get(
prefix + 'sparse_beta1_decay_rate', 0.9)
sgd.adam.beta2_decay_rate = strategy.get(
prefix + 'sparse_beta2_decay_rate', 0.999)
sgd.adam.ada_epsilon = strategy.get(
prefix + 'sparse_ada_epsilon', 1e-8)
bounds = strategy.get(prefix + 'sparse_weight_bounds',
[-10, 10])
sgd.adam.weight_bounds.extend(bounds)
def set_sparse_table_config(table_data, config):
for key in config:
if key not in support_sparse_key_list:
raise ValueError("strategy key '%s' not support" % (key))
table_class = config.get("sparse_table_class",
"DownpourSparseTable")
if table_class not in support_sparse_table_class:
raise ValueError(
"support sparse_table_class: ['DownpourSparseTable'], but actual %s"
% (table_class))
table_data.table_class = 'MemorySparseTable'
table_data.shard_num = config.get('sparse_shard_num', 1000)
accessor_class = config.get("sparse_accessor_class",
"DownpourCtrAccessor")
if accessor_class not in support_sparse_accessor_class:
raise ValueError(
"support sparse_accessor_class: [''DownpourSparseValueAccessor', 'DownpourCtrAccessor', 'DownpourCtrDoubleAccessor', 'DownpourUnitAccessor', 'DownpourDoubleUnitAccessor'], but actual %s"
% (accessor_class))
if configs.get("use_cvm", True):
table_data.accessor.accessor_class = 'CtrCommonAccessor'
else:
table_data.accessor.accessor_class = 'SparseAccessor'
table_data.accessor.embedx_dim = config.get('sparse_embedx_dim', 8)
table_data.accessor.fea_dim = table_data.accessor.embedx_dim + 3
table_data.accessor.embedx_threshold = config.get(
'sparse_embedx_threshold', 10)
table_data.accessor.ctr_accessor_param.nonclk_coeff = config.get(
'sparse_nonclk_coeff', 0.1)
table_data.accessor.ctr_accessor_param.click_coeff = config.get(
'sparse_click_coeff', 1)
table_data.accessor.ctr_accessor_param.base_threshold = config.get(
'sparse_base_threshold', 1.5)
table_data.accessor.ctr_accessor_param.delta_threshold = config.get(
'sparse_delta_threshold', 0.25)
table_data.accessor.ctr_accessor_param.delta_keep_days = config.get(
'sparse_delta_keep_days', 16)
table_data.accessor.ctr_accessor_param.show_click_decay_rate = config.get(
'sparse_show_click_decay_rate', 0.98)
table_data.accessor.ctr_accessor_param.delete_threshold = config.get(
'sparse_delete_threshold', 0.8)
table_data.accessor.ctr_accessor_param.delete_after_unseen_days = config.get(
'sparse_delete_after_unseen_days', 30)
table_data.accessor.ctr_accessor_param.ssd_unseenday_threshold = config.get(
'sparse_ssd_unseenday_threshold', 1)
converter = config.get('sparse_converter', "")
deconverter = config.get('sparse_deconverter', "")
save_data1 = table_data.accessor.table_accessor_save_param.add()
save_data1.param = 1
save_data1.converter = converter
save_data1.deconverter = deconverter
save_data2 = table_data.accessor.table_accessor_save_param.add()
save_data2.param = 2
save_data2.converter = converter
save_data2.deconverter = deconverter
if accessor_class == 'DownpourCtrAccessor' or accessor_class == 'DownpourCtrDoubleAccessor':
sparse_optimizer_config(table_data.accessor.embed_sgd_param,
config, '')
sparse_optimizer_config(table_data.accessor.embedx_sgd_param,
config, '')
else:
sparse_optimizer_config(table_data.accessor.embed_sgd_param,
config, 'embed_')
sparse_optimizer_config(table_data.accessor.embedx_sgd_param,
config, 'embedx_')
if not configs:
print("fleet desc config is empty")
else:
for table_name in configs:
if table_name == 'dense_table' or table_name == 'datanorm_table':
continue
if type(configs[table_name]) != dict:
continue
table_data = table_param.add()
table_data.table_name = table_name
set_sparse_table_config(table_data, configs[table_name])
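The setter above can be exercised directly; a minimal sketch, where the table name "emb" is arbitrary (anything except 'dense_table'/'datanorm_table') and the keys must come from support_sparse_key_list:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.fleet_desc_configs = {
    "emb": {
        "sparse_optimizer": "adam",                    # selects SparseAdamSGDRule
        "sparse_learning_rate": 0.001,
        "sparse_accessor_class": "DownpourCtrAccessor",
    }
}
table = strategy.sparse_table_configs[0]
print(table.table_class)                                  # 'MemorySparseTable'
print(table.accessor.embed_sgd_param.adam.learning_rate)  # 0.001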
@property
def amp(self):
"""
......
@@ -1668,7 +1668,8 @@ class Fleet(object):
opt_info["mpi_rank"] = self.worker_index()
for k, v in self._user_defined_strategy.trainer_desc_configs.items(
):
- opt_info[k] = v
+ if v:
+     opt_info[k] = v
program._fleet_opt = opt_info
if self._runtime_handle is None:
@@ -1744,7 +1745,8 @@ class Fleet(object):
opt_info["mpi_rank"] = self.worker_index()
for k, v in self._user_defined_strategy.trainer_desc_configs.items(
):
- opt_info[k] = v
+ if v:
+     opt_info[k] = v
program._fleet_opt = opt_info
# print("fleet base opt info:", id(program), program._fleet_opt)
......
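The `if v:` guard added in both hunks means that trainer_desc_configs entries left at their empty defaults no longer overwrite opt_info. A standalone, illustrative-only sketch of the same filtering:

# Only truthy values are copied, so "" and [] no longer clobber existing opt_info entries.
configs = {"dump_fields_path": "dump_data", "dump_param": [], "trainer": ""}
opt_info = {"trainer": "MultiTrainer"}   # pretend default set elsewhere
for k, v in configs.items():
    if v:
        opt_info[k] = v
print(opt_info)   # {'trainer': 'MultiTrainer', 'dump_fields_path': 'dump_data'}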
@@ -859,7 +859,7 @@ class TheOnePSRuntime(RuntimeBase):
self.ps_desc_builder = PsDescBuilder(self.context)
- def _init_params(self, scopes, send_ctx, recv_map):
+ def _init_all_params(self, scopes, send_ctx, recv_map):
for name, ctx in send_ctx.items():
if ctx.is_sparse():
continue
@@ -881,6 +881,17 @@ class TheOnePSRuntime(RuntimeBase):
# print("pull all dense:", idx, table_id, var_names)
self._worker.pull_dense_params(scope, table_id, var_names)
def _init_params(self, program, scope, send_ctx, recv_map):
for name, ctx in send_ctx.items():
if ctx.is_sparse():
continue
if ctx.program_id() != id(program):
continue
table_id = ctx.table_id()
var_names = recv_map[table_id]
# print("init params:", table_id, var_names)
self._worker.push_dense_params(scope, table_id, var_names)
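For contrast with _init_all_params above, the new _init_params only touches dense tables whose send context belongs to the given program, which is what the model-loading path below needs. An illustrative helper (not library code) showing the selection logic shared by both:

def dense_table_ids(send_ctx, program_id=None):
    # Table ids that _init_all_params (program_id=None) or _init_params would touch.
    ids = []
    for name, ctx in send_ctx.items():
        if ctx.is_sparse():            # sparse tables are handled elsewhere
            continue
        if program_id is not None and ctx.program_id() != program_id:
            continue                   # per-program filter used by _init_params
        ids.append(ctx.table_id())
    return ids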
def _pull_dense(self, program, scope, send_ctx, recv_map):
for name, ctx in send_ctx.items():
if ctx.is_sparse():
@@ -1010,7 +1021,7 @@ class TheOnePSRuntime(RuntimeBase):
self._communicator.init_params(init_params)
else:
if role_id == 0:
- self._init_params(scopes, send_ctx, dense_map)
+ self._init_all_params(scopes, send_ctx, dense_map)
fleet.util.barrier()
@@ -1324,19 +1335,17 @@ class TheOnePSRuntime(RuntimeBase):
dirname,
mode=0,
main_program=None):
- if main_program is None:
-     main_program = self.origin_main_program
+ main_program = self.origin_main_programs[
+     0] if main_program is None else main_program
_, _, idx = get_program_by_id(self.context, id(main_program))
scope = self.scopes[idx]
print("load inference model scope idx:", idx)
if isinstance(main_program, CompiledProgram):
raise TypeError(
"in fleet.save() function, main_program must be as Program type, CompiledProgram is not allowed"
)
denses = get_the_one_recv_context(
self.context,
is_dense=True,
split_dense_table=self.is_heter_ps_mode,
use_origin_program=True)
sparses = get_the_one_recv_context(
self.context,
is_dense=False,
@@ -1346,8 +1355,16 @@ class TheOnePSRuntime(RuntimeBase):
sparse_varnames = self._load_sparse_params(dirname, sparses,
main_program, mode)
dense_map = get_the_one_recv_context(
self.context, split_dense_table=self.is_heter_ps_mode)
send_ctx = get_the_one_send_context(
self.context,
split_dense_table=self.is_heter_ps_mode,
use_origin_program=self.is_heter_ps_mode,
ep_list=self.endpoints)
recv_dense_varnames = []
- for id, names in denses.items():
+ for _, names in dense_map.items():
recv_dense_varnames.extend(names)
loaded_varnames = sparse_varnames
@@ -1366,9 +1383,9 @@ class TheOnePSRuntime(RuntimeBase):
if var.name not in recv_dense_varnames:
continue
tensor = paddle.load(os.path.join(model_path, var.name))
- var.set_value(tensor)
+ var.set_value(tensor, scope)
- self._communicator.init_params(denses)
+ self._init_params(main_program, scope, send_ctx, dense_map)
def _load_distributed_persistables(self, path, mode):
self._worker.load_model(path, mode)
......
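The load path now resolves the scope for the target program, writes loaded dense tensors into that scope with var.set_value(tensor, scope), and pushes them to the servers through _init_params instead of the communicator. A standalone sketch of the set-value-into-a-scope pattern, assuming Paddle's static-graph API; the runtime code uses self.scopes[idx] rather than the global scope, and "fc_w" is an arbitrary parameter name:

import numpy as np
import paddle
import paddle.static as static

paddle.enable_static()
main, startup = static.Program(), static.Program()
with static.program_guard(main, startup):
    w = static.create_parameter(shape=[2, 2], dtype="float32", name="fc_w")
scope = static.global_scope()            # stands in for self.scopes[idx]
exe = static.Executor(paddle.CPUPlace())
exe.run(startup)                          # parameter must exist in the scope before set_value
w.set_value(np.ones([2, 2], dtype="float32"), scope)   # write the loaded tensor into that scope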
@@ -180,6 +180,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
fleet.save_inference_model(exe, model_dir,
[feed.name for feed in self.feeds],
self.avg_cost)
fleet.load_model(model_dir, mode=1)
if __name__ == "__main__":
......
@@ -15,7 +15,7 @@
import os
import time
import unittest
os.environ["WITH_DISTRIBUTE"] = "ON"
import paddle
import paddle.distributed.fleet.base.role_maker as role_maker
import paddle.fluid.transpiler.details.program_utils as pu
@@ -45,10 +45,12 @@ class TestDistStrategyTrainerDescConfig(unittest.TestCase):
avg_cost = paddle.fluid.layers.mean(cost)
strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True
strategy.a_sync_configs = {"launch_barrier": 0}
config = {
"dump_fields_path": "dump_data",
"dump_fields": ["xxx", "yyy"],
- "dump_param": []
+ "dump_param": ['zzz']
}
strategy.trainer_desc_configs = config
@@ -59,7 +61,18 @@ class TestDistStrategyTrainerDescConfig(unittest.TestCase):
program = paddle.static.default_main_program()
self.assertEqual(program._fleet_opt["dump_fields_path"], "dump_data")
self.assertEqual(len(program._fleet_opt["dump_fields"]), 2)
- self.assertEqual(len(program._fleet_opt["dump_param"]), 0)
+ self.assertEqual(len(program._fleet_opt["dump_param"]), 1)
self.assertEqual(program._fleet_opt["mpi_size"],
int(os.environ["PADDLE_TRAINERS_NUM"]))
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize([avg_cost])
program = avg_cost.block.program
self.assertEqual(program._fleet_opt["dump_fields_path"], "dump_data")
self.assertEqual(len(program._fleet_opt["dump_fields"]), 2)
self.assertEqual(len(program._fleet_opt["dump_param"]), 1)
self.assertEqual(program._fleet_opt["mpi_size"],
int(os.environ["PADDLE_TRAINERS_NUM"]))
......
@@ -281,18 +281,50 @@ class TestStrategyConfig(unittest.TestCase):
}
self.assertEqual(strategy.fs_client_param.user, "456")
def test_fleet_desc_configs(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
configs = {}
configs['emb'] = {"sparse_optimizer": "adagrad"}
strategy.fleet_desc_configs = configs
self.assertEqual(strategy.sparse_table_configs[0]
.accessor.embed_sgd_param.adagrad.learning_rate, 0.05)
strategy = paddle.distributed.fleet.DistributedStrategy()
configs = {}
configs['emb'] = {"sparse_optimizer": "naive"}
strategy.fleet_desc_configs = configs
self.assertEqual(strategy.sparse_table_configs[0]
.accessor.embed_sgd_param.naive.learning_rate, 0.05)
strategy = paddle.distributed.fleet.DistributedStrategy()
configs = {}
configs['emb'] = {"sparse_optimizer": "adam"}
strategy.fleet_desc_configs = configs
self.assertEqual(strategy.sparse_table_configs[0]
.accessor.embed_sgd_param.adam.beta1_decay_rate, 0.9)
strategy = paddle.distributed.fleet.DistributedStrategy()
configs = {}
configs['emb'] = {
"sparse_accessor_class": "DownpourUnitAccessor",
"embed_sparse_optimizer": "std_adagrad"
}
strategy.fleet_desc_configs = configs
self.assertEqual(strategy.sparse_table_configs[0]
.accessor.embed_sgd_param.adagrad.initial_range, 0)
def test_trainer_desc_configs(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
configs = {
"dump_fields_path": "dump_data",
"dump_fields": ["xxx", "yyy"],
- "dump_param": []
+ "dump_param": ['zzz']
}
strategy.trainer_desc_configs = configs
self.assertEqual(strategy.trainer_desc_configs["dump_fields_path"],
"dump_data")
self.assertEqual(len(strategy.trainer_desc_configs["dump_fields"]), 2)
- self.assertEqual(len(strategy.trainer_desc_configs["dump_param"]), 0)
+ self.assertEqual(len(strategy.trainer_desc_configs["dump_param"]), 1)
def test_elastic(self):
strategy = paddle.distributed.fleet.DistributedStrategy()
......