Unverified Commit 1dbc8632 authored by wangguanqun, committed by GitHub

fix benchmark in paddlerec (#38278)

Parent 1006383b
......@@ -172,6 +172,8 @@ message CommonAccessorParameter {
   optional string entry = 7;
   optional int32 trainer_num = 8;
   optional bool sync = 9;
+  optional uint32 table_num = 10;
+  optional uint32 table_dim = 11;
 }
 message TableAccessorSaveParameter {
......
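Note: the two new uint32 fields give CommonAccessorParameter an explicit table geometry: table_num is the row count of the table and table_dim the width of each row, so clients no longer have to infer the shape from accessor-specific fields. A minimal Python sketch of what the pair encodes (the TableMeta helper is illustrative, not part of this commit):

    from dataclasses import dataclass

    @dataclass
    class TableMeta:
        table_num: int   # rows, e.g. vocabulary size
        table_dim: int   # columns, e.g. embedding width

        def shape(self):
            # Shape a client should allocate when pulling the whole table.
            return (self.table_num, self.table_dim)

    assert TableMeta(table_num=1000, table_dim=8).shape() == (1000, 8)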
......@@ -1071,8 +1071,8 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
   for (size_t i = 0; i < worker_param.downpour_table_param_size(); ++i) {
     if (worker_param.downpour_table_param(i).table_id() == table_id) {
       var_name = worker_param.downpour_table_param(i).common().table_name();
-      var_num = worker_param.downpour_table_param(i).accessor().fea_dim();
-      var_shape = worker_param.downpour_table_param(i).accessor().embedx_dim();
+      var_num = worker_param.downpour_table_param(i).common().table_num();
+      var_shape = worker_param.downpour_table_param(i).common().table_dim();
       break;
     }
   }
......
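Note: the save path previously sized the buffer from the accessor's fea_dim/embedx_dim, which describe the per-feature value layout of the CTR accessor rather than the table itself; reading table_num/table_dim from the common section yields the real geometry. An illustrative contrast in Python (all numbers made up):

    accessor = {"fea_dim": 11, "embedx_dim": 8}    # per-feature value layout
    common = {"table_num": 1000, "table_dim": 8}   # actual table geometry

    old_shape = (accessor["fea_dim"], accessor["embedx_dim"])  # (11, 8): wrong
    new_shape = (common["table_num"], common["table_dim"])     # (1000, 8): right
    assert new_shape == (1000, 8)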
......@@ -305,7 +305,8 @@ std::string CtrCommonAccessor::parse_to_string(const float* v, int param) {
   auto show = common_feature_value.show(const_cast<float*>(v));
   auto click = common_feature_value.click(const_cast<float*>(v));
   auto score = show_click_score(show, click);
-  if (score >= _config.embedx_threshold()) {
+  if (score >= _config.embedx_threshold() &&
+      param > common_feature_value.embedx_w_index()) {
     for (auto i = common_feature_value.embedx_w_index();
          i < common_feature_value.dim(); ++i) {
       os << " " << v[i];
......
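Note: in parse_to_string, param is the number of leading value columns the caller asked for. The added condition stops the embedx weights from being dumped when the caller requested only the head columns, even if the show/click score clears embedx_threshold. A schematic Python version of the guard (the column indices are illustrative, not the real CommonFeatureValue layout):

    EMBEDX_W_INDEX = 5   # first embedx-weight column (illustrative)
    DIM = 13             # total value columns (illustrative)

    def columns_to_dump(param, score, embedx_threshold):
        cols = list(range(min(param, EMBEDX_W_INDEX)))
        # New behaviour: dumping embedx weights now also requires that the
        # caller asked for columns past EMBEDX_W_INDEX (param > EMBEDX_W_INDEX),
        # not merely a score above the threshold.
        if score >= embedx_threshold and param > EMBEDX_W_INDEX:
            cols += list(range(EMBEDX_W_INDEX, DIM))
        return cols

    assert columns_to_dump(param=3, score=99.0, embedx_threshold=10.0) == [0, 1, 2]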
......@@ -27,7 +27,7 @@ namespace paddle {
 namespace distributed {
 // TODO(zhaocaibei123): configure
-bool FLAGS_pserver_create_value_when_push = false;
+bool FLAGS_pserver_create_value_when_push = true;
 int FLAGS_pserver_table_save_max_retry = 3;
 bool FLAGS_pserver_enable_create_feasign_randomly = false;
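Note: flipping FLAGS_pserver_create_value_when_push to true defers creation of missing feature values from the pull path to the push path, so features that are only ever read do not allocate storage. A hedged sketch of that policy (this is a reading of the flag's intent, not the C++ implementation):

    def pull(shard, key, create_value_when_push=True):
        if key in shard:
            return shard[key]
        if create_value_when_push:
            return 0.0        # serve zeros now; materialize on push_sparse
        shard[key] = 0.0      # eager creation at pull time
        return shard[key]

    shard = {}
    assert pull(shard, 42) == 0.0 and 42 not in shard  # nothing allocated yet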
......@@ -494,7 +494,6 @@ int32_t MemorySparseTable::push_sparse(const uint64_t* keys,
             values + push_data_idx * update_value_col;
         auto itr = local_shard.find(key);
         if (itr == local_shard.end()) {
-          VLOG(0) << "sparse table push_sparse: " << key << "not found!";
           if (FLAGS_pserver_enable_create_feasign_randomly &&
               !_value_accesor->create_value(1, update_data)) {
             continue;
......
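Note: removing the VLOG(0) here is likely the heart of the benchmark fix. VLOG(0) is emitted unconditionally, and on a cold shard every new feature sign hit this line on the hot push_sparse path, turning the benchmark into a logging exercise. (The dropped message also lacked a space before "not found!".)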
......@@ -171,6 +171,8 @@ class CommonAccessor:
         self.dims = []
         self.trainer_num = 0
         self.sync = "false"
+        self.table_num = None
+        self.table_dim = None
         self.initializers = []
         self.opt_input_map = {}
         self.opt_attr_map = {}
......@@ -256,7 +258,7 @@ class CommonAccessor:
                 break
         return attr_str
-    def parse_by_optimizer(self, grad_name, is_sparse, total_dims,
+    def parse_by_optimizer(self, grad_name, is_sparse, size, single_dim,
                            compiled_strategy, adam_d2sum):
         from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_optimize_ops
         param_name = compiled_strategy.grad_name_to_param_name[grad_name]
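Note: the old single total_dims argument conflated two quantities; the new signature separates size (the variable's total first-dimension extent, used for sharding dense tables) from single_dim (the per-feature width, used directly for sparse tables). An illustrative derivation of the two arguments, mirroring the runtime call site further below (FakeCtx is a hypothetical stand-in for the runtime's context object):

    class FakeCtx:
        def sections(self): return [1000, 8]
        def is_sparse(self): return True

    def optimizer_args(ctx):
        size = ctx.sections()[0]
        single_dim = ctx.sections()[1] if ctx.is_sparse() else 1
        return size, single_dim

    assert optimizer_args(FakeCtx()) == (1000, 8)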
......@@ -281,6 +283,8 @@ class CommonAccessor:
         initializers = []
         self.trainer_num = compiled_strategy.get_trainers()
+        self.table_num = size
+        self.table_dim = single_dim
         if oop.type != 'adam' and adam_d2sum == True:
             print('optimization algorithm is not adam, set adam_d2sum False')
......@@ -294,7 +298,7 @@ class CommonAccessor:
             param_varnames = self.opt_input_map["naive_adagrad"]
             attr_varnames = self.opt_attr_map["naive_adagrad"]
             self.accessor_class = "sgd"
-        elif adam_d2sum:
+        elif adam_d2sum and not is_sparse:
             param_varnames = self.opt_input_map["adam_d2sum"]
             attr_varnames = self.opt_attr_map["adam_d2sum"]
             self.accessor_class = "adam_d2sum"
......@@ -309,10 +313,9 @@ class CommonAccessor:
             #for dims
             if shape is None:
                 if is_sparse:
-                    shape = total_dims
+                    shape = single_dim
                 else:
-                    shape = self.get_shard(total_dims, pserver_num,
-                                           pserver_id)
+                    shape = self.get_shard(size, pserver_num, pserver_id)
             dims.append(shape)
             #for initializers
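Note: for sparse tables the per-server dim is simply single_dim (each feature row lives whole on one server), while dense variables split their size elements across pservers. A plausible sketch of the even-split-with-remainder idea behind get_shard (hedged: the real helper is defined elsewhere in this class):

    def get_shard(total_dim, shard_num, pserver_id):
        # Even block split, with the tail servers absorbing the remainder.
        blocksize = int(total_dim / shard_num + 1)
        if blocksize * (pserver_id + 1) <= total_dim:
            return blocksize
        return max(total_dim - blocksize * pserver_id, 0)

    # 10 elements over 3 pservers -> blocks of 4, 4, 2.
    assert [get_shard(10, 3, i) for i in range(3)] == [4, 4, 2]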
......@@ -351,9 +354,9 @@ class CommonAccessor:
             if shape is None:
                 if is_sparse:
-                    shape = total_dims
+                    shape = single_dim
                 else:
-                    shape = self.get_shard(total_dims, pserver_num,
+                    shape = self.get_shard(size, pserver_num,
                                            pserver_id)
             dims.append(shape)
......@@ -382,6 +385,10 @@ class CommonAccessor:
attrs += "entry: \"{}\" ".format(self.entry)
attrs += "trainer_num: {} ".format(self.trainer_num)
attrs += "sync: {} ".format(self.sync)
if self.table_num:
attrs += "table_num: {} ".format(self.table_num)
if self.table_dim:
attrs += "table_dim: {} ".format(self.table_dim)
for param in self.params:
attrs += "params: \"{}\" ".format(param)
......@@ -451,10 +458,7 @@ class Table:
             accessor_str = accessor_str.format(
                 conv_indent(indent), self.accessor_proto, conv_indent(indent))
             attrs += accessor_str + "\n"
-            return table_str.format(
-                conv_indent(indent), attrs, conv_indent(indent))
-        if self.accessor is not None:
+        elif self.accessor is not None:
             attrs += self.accessor.to_string(indent)
             attrs += "\n"
......@@ -988,8 +992,9 @@ class TheOnePSRuntime(RuntimeBase):
             adam_d2sum = self.context["user_defined_strategy"].adam_d2sum
             common.parse_by_optimizer(ctx.origin_varnames()[0],
                                       ctx.is_sparse(),
-                                      ctx.sections()[1] if ctx.is_sparse()
-                                      else ctx.sections()[0],
+                                      ctx.sections()[0],
+                                      ctx.sections()[1]
+                                      if ctx.is_sparse() else 1,
                                       self.compiled_strategy, adam_d2sum)
         if ctx.is_sparse():
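Note: with this reordering the runtime always passes the variable's row count in size (ctx.sections()[0]) and the embedding width in single_dim (ctx.sections()[1] for sparse tables, 1 for dense ones), matching the new parse_by_optimizer signature above.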
......@@ -1142,17 +1147,25 @@ class TheOnePSRuntime(RuntimeBase):
         return is_valid
+    def _get_inference_model_path(self, dirname):
+        if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
+            model_path = "./dnn_plugin"
+        else:
+            model_path = os.path.join(dirname, "dnn_plugin")
+        return model_path
+
     def _save_sparse_params(self, executor, dirname, context, main_program,
                             mode):
         from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_tablenames
         distributed_varnames = get_sparse_tablenames(
             self.compiled_strategy.origin_main_program, True)
         values = []
+        model_path = self._get_inference_model_path(dirname)
         for id, names in context.items():
             if names[0] not in distributed_varnames:
                 # only save sparse param to local
                 try:
-                    self._worker.recv_and_save_model(id, dirname)
+                    self._worker.recv_and_save_model(id, model_path)
                 except:
                     pass
         # save sparse & distributed param on server
......@@ -1277,10 +1290,7 @@ class TheOnePSRuntime(RuntimeBase):
         infer_program._copy_dist_param_info_from(program)
-        if dirname.startswith("afs:") or dirname.startswith("hdfs:"):
-            model_path = "./dnn_plugin"
-        else:
-            model_path = os.path.join(dirname, "dnn_plugin")
+        model_path = self._get_inference_model_path(dirname)
         model_basename = "__model__"
         model_basename = os.path.join(model_path, model_basename)
         paddle.save(infer_program, model_basename)
......