提交 25d9cef9 编写于 作者: L linan17

update for news cp: support user-defined dump file name

Change-Id: I13db8b30ed1abcbca8cb74105c5fe912fe7f929b
上级 90a3c634
......@@ -34,6 +34,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
if (trainer_desc.dump_fields_size() != 0 && dump_fields_path_ != "") {
need_dump_field_ = true;
}
user_define_dump_filename_ = trainer_desc.user_define_dump_filename();
if (need_dump_field_) {
auto &file_list = dataset->GetFileList();
if (file_list.size() == 0) {
......@@ -96,6 +97,11 @@ void DistMultiTrainer::InitDumpEnv() {
std::string path = string::format_string(
"%s/part-%03d", dump_fields_path_.c_str(), mpi_rank_);
if (user_define_dump_filename_ != "") {
path = string::format_string("%s/part-%s", dump_fields_path_.c_str(),
user_define_dump_filename_.c_str());
}
fp_ = fs_open_write(path, &err_no, dump_converter_);
for (int i = 0; i < thread_num_; ++i) {
workers_[i]->SetChannelWriter(queue_.get());
......
......@@ -101,6 +101,7 @@ class DistMultiTrainer : public MultiTrainer {
bool need_dump_field_;
std::string dump_fields_path_;
std::string user_define_dump_filename_;
std::string dump_converter_;
std::vector<std::string> dump_fields_;
int mpi_rank_;
......
......@@ -39,6 +39,7 @@ message TrainerDesc {
optional string dump_fields_path = 12;
repeated string dump_fields = 13;
optional string dump_converter = 14;
optional string user_define_dump_filename = 15;
// device worker parameters
optional HogwildWorkerParameter hogwild_param = 101;
......
......@@ -252,6 +252,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["dump_converter"] = ""
opt_info["dump_fields"] = strategy.get("dump_fields", [])
opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
opt_info["user_define_dump_filename"] = strategy.get("user_define_dump_filename", "")
if server._server.downpour_server_param.downpour_table_param[
0].accessor.accessor_class == "DownpourCtrAccessor":
opt_info["dump_slot"] = True
......
......@@ -312,7 +312,9 @@ class FleetUtil(object):
data_path,
hadoop_fs_name,
monitor_data={},
mode="patch"):
mode="patch",
dir_name="000",
base_only=False):
xbox_dict = collections.OrderedDict()
if mode == "base":
xbox_dict["id"] = str(xbox_base_key)
......@@ -322,10 +324,13 @@ class FleetUtil(object):
print("warning: unknown mode %s, set it to patch" % mode)
mode = "patch"
xbox_dict["id"] = str(int(time.time()))
if base_only:
xbox_dict["key"] = str(int(time.time()))
else:
xbox_dict["key"] = str(xbox_base_key)
if model_path.startswith("hdfs:") or model_path.startswith("afs:"):
model_path = model_path[model_path.find(":") + 1:]
xbox_dict["input"] = hadoop_fs_name + model_path.rstrip("/") + "/000"
xbox_dict["input"] = hadoop_fs_name + model_path.rstrip("/") + "/%s" % dir_name
xbox_dict["record_count"] = "111111"
xbox_dict["partition_type"] = "2"
xbox_dict["job_name"] = "default_job_name"
......@@ -449,7 +454,9 @@ class FleetUtil(object):
hadoop_fs_ugi,
monitor_data={},
hadoop_home="$HADOOP_HOME",
donefile_name=None):
donefile_name=None,
dir_name=None,
base_only=False):
"""
write delta donefile or xbox base donefile
......@@ -501,14 +508,17 @@ class FleetUtil(object):
if donefile_name is None:
donefile_name = "xbox_base_done.txt"
if dir_name is None:
dir_name = "000"
if isinstance(data_path, list):
data_path = ",".join(data_path)
if fleet.worker_index() == 0:
donefile_path = output_path + "/" + donefile_name
xbox_str = self._get_xbox_str(output_path, day, model_path, \
xbox_base_key, data_path, hadoop_fs_name, monitor_data={},
mode=mode)
xbox_base_key, data_path, hadoop_fs_name, {}, \
mode, dir_name, base_only)
configs = {
"fs.default.name": hadoop_fs_name,
"hadoop.job.ugi": hadoop_fs_ugi
......@@ -519,6 +529,8 @@ class FleetUtil(object):
last_dict = json.loads(pre_content.split("\n")[-1])
last_day = last_dict["input"].split("/")[-3]
last_pass = last_dict["input"].split("/")[-2].split("-")[-1]
if last_pass == "base":
last_pass = "-1"
exist = False
if int(day) < int(last_day) or \
int(day) == int(last_day) and \
......
......@@ -91,6 +91,9 @@ class TrainerDesc(object):
def _set_dump_fields_path(self, path):
    # Directory (local or HDFS/AFS) where trainer workers write the dumped
    # training fields; forwarded verbatim into the TrainerDesc proto.
    self.proto_desc.dump_fields_path = path
def _set_user_define_dump_filename(self, user_define_dump_filename):
    # User-supplied file name for the dump output; forwarded into the
    # TrainerDesc proto. When non-empty, it overrides the default
    # "part-%03d" (rank-based) naming — presumably consumed by the C++
    # InitDumpEnv path; TODO confirm against the trainer implementation.
    self.proto_desc.user_define_dump_filename = user_define_dump_filename
def _set_dump_converter(self, converter):
    # Converter identifier used when opening the dump output stream
    # (passed to fs_open_write on the C++ side); forwarded verbatim
    # into the TrainerDesc proto.
    self.proto_desc.dump_converter = converter
......
......@@ -38,13 +38,23 @@ class TrainerFactory(object):
if "fleet_desc" in opt_info:
device_worker._set_fleet_desc(opt_info["fleet_desc"])
trainer._set_fleet_desc(opt_info["fleet_desc"])
if opt_info.get("use_cvm") is not None:
trainer._set_use_cvm(opt_info["use_cvm"])
if opt_info.get("scale_datanorm") is not None:
trainer._set_scale_datanorm(opt_info["scale_datanorm"])
if opt_info.get("dump_slot") is not None:
trainer._set_dump_slot(opt_info["dump_slot"])
if opt_info.get("mpi_rank") is not None:
trainer._set_mpi_rank(opt_info["mpi_rank"])
if opt_info.get("dump_fields") is not None:
trainer._set_dump_fields(opt_info["dump_fields"])
if opt_info.get("dump_fields_path") is not None:
trainer._set_dump_fields_path(opt_info["dump_fields_path"])
if opt_info.get("user_define_dump_filename") is not None:
trainer._set_user_define_dump_filename(opt_info["user_define_dump_filename"])
if opt_info.get("dump_converter") is not None:
trainer._set_dump_converter(opt_info["dump_converter"])
if opt_info.get("adjust_ins_weight") is not None:
trainer._set_adjust_ins_weight(opt_info["adjust_ins_weight"])
trainer._set_device_worker(device_worker)
return trainer
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册