提交 25d9cef9 编写于 作者: L linan17

update for news cp: support user-defined dump file name

Change-Id: I13db8b30ed1abcbca8cb74105c5fe912fe7f929b
上级 90a3c634
...@@ -34,6 +34,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc, ...@@ -34,6 +34,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc &trainer_desc,
if (trainer_desc.dump_fields_size() != 0 && dump_fields_path_ != "") { if (trainer_desc.dump_fields_size() != 0 && dump_fields_path_ != "") {
need_dump_field_ = true; need_dump_field_ = true;
} }
user_define_dump_filename_ = trainer_desc.user_define_dump_filename();
if (need_dump_field_) { if (need_dump_field_) {
auto &file_list = dataset->GetFileList(); auto &file_list = dataset->GetFileList();
if (file_list.size() == 0) { if (file_list.size() == 0) {
...@@ -96,6 +97,11 @@ void DistMultiTrainer::InitDumpEnv() { ...@@ -96,6 +97,11 @@ void DistMultiTrainer::InitDumpEnv() {
std::string path = string::format_string( std::string path = string::format_string(
"%s/part-%03d", dump_fields_path_.c_str(), mpi_rank_); "%s/part-%03d", dump_fields_path_.c_str(), mpi_rank_);
if (user_define_dump_filename_ != "") {
path = string::format_string("%s/part-%s", dump_fields_path_.c_str(),
user_define_dump_filename_.c_str());
}
fp_ = fs_open_write(path, &err_no, dump_converter_); fp_ = fs_open_write(path, &err_no, dump_converter_);
for (int i = 0; i < thread_num_; ++i) { for (int i = 0; i < thread_num_; ++i) {
workers_[i]->SetChannelWriter(queue_.get()); workers_[i]->SetChannelWriter(queue_.get());
......
...@@ -101,6 +101,7 @@ class DistMultiTrainer : public MultiTrainer { ...@@ -101,6 +101,7 @@ class DistMultiTrainer : public MultiTrainer {
bool need_dump_field_; bool need_dump_field_;
std::string dump_fields_path_; std::string dump_fields_path_;
std::string user_define_dump_filename_;
std::string dump_converter_; std::string dump_converter_;
std::vector<std::string> dump_fields_; std::vector<std::string> dump_fields_;
int mpi_rank_; int mpi_rank_;
......
...@@ -39,6 +39,7 @@ message TrainerDesc { ...@@ -39,6 +39,7 @@ message TrainerDesc {
optional string dump_fields_path = 12; optional string dump_fields_path = 12;
repeated string dump_fields = 13; repeated string dump_fields = 13;
optional string dump_converter = 14; optional string dump_converter = 14;
optional string user_define_dump_filename = 15;
// device worker parameters // device worker parameters
optional HogwildWorkerParameter hogwild_param = 101; optional HogwildWorkerParameter hogwild_param = 101;
......
...@@ -252,6 +252,7 @@ class DistributedAdam(DistributedOptimizerImplBase): ...@@ -252,6 +252,7 @@ class DistributedAdam(DistributedOptimizerImplBase):
opt_info["dump_converter"] = "" opt_info["dump_converter"] = ""
opt_info["dump_fields"] = strategy.get("dump_fields", []) opt_info["dump_fields"] = strategy.get("dump_fields", [])
opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "") opt_info["dump_fields_path"] = strategy.get("dump_fields_path", "")
opt_info["user_define_dump_filename"] = strategy.get("user_define_dump_filename", "")
if server._server.downpour_server_param.downpour_table_param[ if server._server.downpour_server_param.downpour_table_param[
0].accessor.accessor_class == "DownpourCtrAccessor": 0].accessor.accessor_class == "DownpourCtrAccessor":
opt_info["dump_slot"] = True opt_info["dump_slot"] = True
......
...@@ -312,7 +312,9 @@ class FleetUtil(object): ...@@ -312,7 +312,9 @@ class FleetUtil(object):
data_path, data_path,
hadoop_fs_name, hadoop_fs_name,
monitor_data={}, monitor_data={},
mode="patch"): mode="patch",
dir_name="000",
base_only=False):
xbox_dict = collections.OrderedDict() xbox_dict = collections.OrderedDict()
if mode == "base": if mode == "base":
xbox_dict["id"] = str(xbox_base_key) xbox_dict["id"] = str(xbox_base_key)
...@@ -322,10 +324,13 @@ class FleetUtil(object): ...@@ -322,10 +324,13 @@ class FleetUtil(object):
print("warning: unknown mode %s, set it to patch" % mode) print("warning: unknown mode %s, set it to patch" % mode)
mode = "patch" mode = "patch"
xbox_dict["id"] = str(int(time.time())) xbox_dict["id"] = str(int(time.time()))
if base_only:
xbox_dict["key"] = str(int(time.time()))
else:
xbox_dict["key"] = str(xbox_base_key) xbox_dict["key"] = str(xbox_base_key)
if model_path.startswith("hdfs:") or model_path.startswith("afs:"): if model_path.startswith("hdfs:") or model_path.startswith("afs:"):
model_path = model_path[model_path.find(":") + 1:] model_path = model_path[model_path.find(":") + 1:]
xbox_dict["input"] = hadoop_fs_name + model_path.rstrip("/") + "/000" xbox_dict["input"] = hadoop_fs_name + model_path.rstrip("/") + "/%s" % dir_name
xbox_dict["record_count"] = "111111" xbox_dict["record_count"] = "111111"
xbox_dict["partition_type"] = "2" xbox_dict["partition_type"] = "2"
xbox_dict["job_name"] = "default_job_name" xbox_dict["job_name"] = "default_job_name"
...@@ -449,7 +454,9 @@ class FleetUtil(object): ...@@ -449,7 +454,9 @@ class FleetUtil(object):
hadoop_fs_ugi, hadoop_fs_ugi,
monitor_data={}, monitor_data={},
hadoop_home="$HADOOP_HOME", hadoop_home="$HADOOP_HOME",
donefile_name=None): donefile_name=None,
dir_name=None,
base_only=False):
""" """
write delta donefile or xbox base donefile write delta donefile or xbox base donefile
...@@ -501,14 +508,17 @@ class FleetUtil(object): ...@@ -501,14 +508,17 @@ class FleetUtil(object):
if donefile_name is None: if donefile_name is None:
donefile_name = "xbox_base_done.txt" donefile_name = "xbox_base_done.txt"
if dir_name is None:
dir_name = "000"
if isinstance(data_path, list): if isinstance(data_path, list):
data_path = ",".join(data_path) data_path = ",".join(data_path)
if fleet.worker_index() == 0: if fleet.worker_index() == 0:
donefile_path = output_path + "/" + donefile_name donefile_path = output_path + "/" + donefile_name
xbox_str = self._get_xbox_str(output_path, day, model_path, \ xbox_str = self._get_xbox_str(output_path, day, model_path, \
xbox_base_key, data_path, hadoop_fs_name, monitor_data={}, xbox_base_key, data_path, hadoop_fs_name, {}, \
mode=mode) mode, dir_name, base_only)
configs = { configs = {
"fs.default.name": hadoop_fs_name, "fs.default.name": hadoop_fs_name,
"hadoop.job.ugi": hadoop_fs_ugi "hadoop.job.ugi": hadoop_fs_ugi
...@@ -519,6 +529,8 @@ class FleetUtil(object): ...@@ -519,6 +529,8 @@ class FleetUtil(object):
last_dict = json.loads(pre_content.split("\n")[-1]) last_dict = json.loads(pre_content.split("\n")[-1])
last_day = last_dict["input"].split("/")[-3] last_day = last_dict["input"].split("/")[-3]
last_pass = last_dict["input"].split("/")[-2].split("-")[-1] last_pass = last_dict["input"].split("/")[-2].split("-")[-1]
if last_pass == "base":
last_pass = "-1"
exist = False exist = False
if int(day) < int(last_day) or \ if int(day) < int(last_day) or \
int(day) == int(last_day) and \ int(day) == int(last_day) and \
......
...@@ -91,6 +91,9 @@ class TrainerDesc(object): ...@@ -91,6 +91,9 @@ class TrainerDesc(object):
def _set_dump_fields_path(self, path): def _set_dump_fields_path(self, path):
self.proto_desc.dump_fields_path = path self.proto_desc.dump_fields_path = path
def _set_user_define_dump_filename(self, user_define_dump_filename):
self.proto_desc.user_define_dump_filename = user_define_dump_filename
def _set_dump_converter(self, converter): def _set_dump_converter(self, converter):
self.proto_desc.dump_converter = converter self.proto_desc.dump_converter = converter
......
...@@ -38,13 +38,23 @@ class TrainerFactory(object): ...@@ -38,13 +38,23 @@ class TrainerFactory(object):
if "fleet_desc" in opt_info: if "fleet_desc" in opt_info:
device_worker._set_fleet_desc(opt_info["fleet_desc"]) device_worker._set_fleet_desc(opt_info["fleet_desc"])
trainer._set_fleet_desc(opt_info["fleet_desc"]) trainer._set_fleet_desc(opt_info["fleet_desc"])
if opt_info.get("use_cvm") is not None:
trainer._set_use_cvm(opt_info["use_cvm"]) trainer._set_use_cvm(opt_info["use_cvm"])
if opt_info.get("scale_datanorm") is not None:
trainer._set_scale_datanorm(opt_info["scale_datanorm"]) trainer._set_scale_datanorm(opt_info["scale_datanorm"])
if opt_info.get("dump_slot") is not None:
trainer._set_dump_slot(opt_info["dump_slot"]) trainer._set_dump_slot(opt_info["dump_slot"])
if opt_info.get("mpi_rank") is not None:
trainer._set_mpi_rank(opt_info["mpi_rank"]) trainer._set_mpi_rank(opt_info["mpi_rank"])
if opt_info.get("dump_fields") is not None:
trainer._set_dump_fields(opt_info["dump_fields"]) trainer._set_dump_fields(opt_info["dump_fields"])
if opt_info.get("dump_fields_path") is not None:
trainer._set_dump_fields_path(opt_info["dump_fields_path"]) trainer._set_dump_fields_path(opt_info["dump_fields_path"])
if opt_info.get("user_define_dump_filename") is not None:
trainer._set_user_define_dump_filename(opt_info["user_define_dump_filename"])
if opt_info.get("dump_converter") is not None:
trainer._set_dump_converter(opt_info["dump_converter"]) trainer._set_dump_converter(opt_info["dump_converter"])
if opt_info.get("adjust_ins_weight") is not None:
trainer._set_adjust_ins_weight(opt_info["adjust_ins_weight"]) trainer._set_adjust_ins_weight(opt_info["adjust_ins_weight"])
trainer._set_device_worker(device_worker) trainer._set_device_worker(device_worker)
return trainer return trainer
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册