提交 beffe2d5 编写于 作者: W willzhang4a58

refine job conf

上级 fa4085b1
......@@ -20,7 +20,7 @@ void LossAccTaskGraph::BuildTaskGraph() {
// parallel_desc
ParallelConf pr_conf;
pr_conf.set_policy(kDataParallel);
pr_conf.mutable_device_set()->add_device_name(loss_task_->device_name());
pr_conf.add_device_name(loss_task_->device_name());
auto pr_desc = std::make_shared<ParallelDesc>(pr_conf);
// faker chain
auto chain_gph = of_make_unique<ChainGraph>();
......
......@@ -18,8 +18,7 @@ void LossRecordTaskGraph::BuildTaskGraph(
faker_pr_conf.set_policy(kFakerLossRecord);
for (TaskNode* task : sorted_loss_acc_task) {
auto loss_acc_task = static_cast<CompTaskNode*>(task);
faker_pr_conf.mutable_device_set()->add_device_name(
loss_acc_task->device_name());
faker_pr_conf.add_device_name(loss_acc_task->device_name());
sorted_loss_acc_tasks_.push_back(loss_acc_task);
}
// faker chain
......@@ -31,7 +30,7 @@ void LossRecordTaskGraph::BuildTaskGraph(
// loss_record_pr_conf
ParallelConf loss_record_pr_conf;
loss_record_pr_conf.set_policy(kDataParallel);
loss_record_pr_conf.mutable_device_set()->add_device_name(
loss_record_pr_conf.add_device_name(
IDMgr::Singleton()->MachineName4MachineId(0) + ":persistence");
// loss record op
OperatorConf op_conf;
......
......@@ -18,8 +18,7 @@ void MdSaveTaskGraph::BuildTaskGraph() {
ChainNode* faker_chain = chain_gph->NewNode();
ParallelConf faker_pr_conf;
faker_pr_conf.set_policy(kDataParallel);
faker_pr_conf.mutable_device_set()->add_device_name(
update_task_->device_name());
faker_pr_conf.add_device_name(update_task_->device_name());
faker_chain->mut_parallel_desc().reset(new ParallelDesc(faker_pr_conf));
faker_chain->mut_output_lbns() = {kPackedBlobName};
// save
......@@ -28,8 +27,7 @@ void MdSaveTaskGraph::BuildTaskGraph() {
GetMachineNameFromDeviceName(update_task_->device_name());
ParallelConf save_pr_conf;
save_pr_conf.set_policy(kDataParallel);
save_pr_conf.mutable_device_set()->add_device_name(machine_name
+ ":persistence");
save_pr_conf.add_device_name(machine_name + ":persistence");
save_chain->mut_parallel_desc().reset(new ParallelDesc(save_pr_conf));
save_chain->mut_input_lbns() = {kPackedBlobName};
//
......
......@@ -19,7 +19,7 @@ void MdUpdtTaskGraph::BuildTaskGraph(uint32_t random_seed) {
ChainNode* updt_chain = chain_gph->NewNode();
ParallelConf updt_pr_conf;
updt_pr_conf.set_policy(kDataParallel);
updt_pr_conf.mutable_device_set()->add_device_name(fw_task_->device_name());
updt_pr_conf.add_device_name(fw_task_->device_name());
updt_chain->mut_parallel_desc().reset(new ParallelDesc(updt_pr_conf));
updt_chain->mut_input_lbns() = {kPackedBlobName};
updt_chain->mut_op_vec() = {OpMgr::Singleton()->ModelUpdateOp()};
......
......@@ -4,6 +4,5 @@ package oneflow;
import "oneflow/core/operator/op_conf.proto";
// A deep-learning network definition: a named list of operator configs.
message DLNetConf {
string name = 1;
// NOTE(review): the high field number presumably reserves tags 2-99 for
// future scalar fields ahead of the repeated op list -- confirm intent.
repeated OperatorConf op = 100;
}
......@@ -8,25 +8,34 @@ enum FloatingPointTypeProto {
kDouble = 1;
}
message JobConf {
string train_dlnet_conf_filepath = 1;
string resource_filepath = 2;
string placement_filepath = 3;
string model_load_snapshot_path = 4;
string model_save_snapshots_path = 5;
int32 piece_size = 7;
int32 num_of_pieces_in_batch = 8;
bool is_train = 9;
FloatingPointTypeProto floating_point_type = 10;
int32 num_of_batches_in_snapshot = 11;
int32 staleness = 12; // at least 0. If set as 0, then it's BSP
int64 total_batch_num = 13;
FillConf default_fill_conf = 14;
bool use_async_cpu_stream = 15;
int32 piece_num_of_record_loss = 16;
// Options that apply only when the job runs in training mode
// (selected via the JobType oneof in JobConf).
message TrainConf {
// Directory where model snapshots are written during training.
string model_save_snapshots_path = 1;
int32 num_of_batches_in_snapshot = 2;
int32 staleness = 3; // at least 0. If set as 0, then it's BSP
int64 total_batch_num = 4;
// Default initializer used when an op does not specify its own fill.
FillConf default_fill_conf = 5;
int32 piece_num_of_record_loss = 6;
// Exactly one model-update algorithm may be configured per job.
oneof ModelUpdateCase {
NormalModelUpdateOpConf normal_mdupdt_conf = 1000;
MomentumModelUpdateOpConf momentum_mdupdt_conf = 1001;
RMSPropModelUpdateOpConf rmsprop_mdupdt_conf = 1002;
}
}
// Prediction-mode options; currently empty, acts as the oneof tag that
// marks a job as predict-only.
message PredictConf {
}
// Top-level job configuration: input file locations, data-pipeline sizing,
// and a oneof selecting whether the job trains or predicts.
message JobConf {
string dlnet_filepath = 1;
string resource_filepath = 2;
string placement_filepath = 3;
string model_load_snapshot_path = 4;
int32 piece_size = 5;
int32 num_of_pieces_in_batch = 6;
FloatingPointTypeProto floating_point_type = 7;
bool use_async_cpu_stream = 8;
// A job is either a training job or a prediction job, never both.
oneof JobType {
TrainConf train_conf = 1000;
PredictConf predict_conf = 1001;
}
}
......@@ -6,7 +6,7 @@ namespace oneflow {
// Stores a copy of the JobConf and parses the three referenced proto text
// files (net, resource, placement) into their in-memory descriptors.
void JobDesc::InitFromJobConf(const JobConf& conf) {
LOG(INFO) << "Read JobConf";
job_conf_ = conf;
// NOTE(review): the next two lines are the pre-change and post-change
// versions of the same call, left side-by-side by the diff rendering;
// only the dlnet_filepath() form belongs in the final code -- confirm
// against version control.
ParseProtoFromTextFile(conf.train_dlnet_conf_filepath(), &train_dlnet_conf_);
ParseProtoFromTextFile(conf.dlnet_filepath(), &train_dlnet_conf_);
ParseProtoFromTextFile(conf.resource_filepath(), &resource_);
ParseProtoFromTextFile(conf.placement_filepath(), &placement_);
}
......
......@@ -27,14 +27,14 @@ class JobDesc final {
return job_conf_.model_load_snapshot_path();
}
// Path under which model snapshots are saved.
// NOTE(review): both the old (top-level field) and new (train_conf
// sub-message) return statements appear because this is a rendered diff;
// the second, train_conf-based accessor is the post-change version.
const std::string& md_save_snapshots_path() {
return job_conf_.model_save_snapshots_path();
return job_conf_.train_conf().model_save_snapshots_path();
}
int32_t piece_size() const { return job_conf_.piece_size(); }
// Number of pieces that make up one batch, read from the job conf.
int32_t num_of_pieces_in_batch() const { return job_conf_.num_of_pieces_in_batch(); }
int32_t batch_size() const { return piece_size() * num_of_pieces_in_batch(); }
// Whether this job is a training job.
// NOTE(review): two definitions are shown side-by-side by the diff
// rendering; the has_train_conf() form (checking the JobType oneof) is
// the post-change version and the one that should remain.
bool is_train() const { return job_conf_.is_train(); }
bool is_train() const { return job_conf_.has_train_conf(); }
// Floating-point precision (float/double) selected for the whole job.
FloatingPointTypeProto floating_point_type() const { return job_conf_.floating_point_type(); }
......@@ -48,19 +48,21 @@ class JobDesc final {
}
}
// Number of batches between two model snapshots.
// NOTE(review): old (top-level field) and new (train_conf sub-message)
// return statements are both shown by the diff rendering; the
// train_conf-based accessor is the post-change version.
int32_t num_of_batches_in_snapshot() const {
return job_conf_.num_of_batches_in_snapshot();
return job_conf_.train_conf().num_of_batches_in_snapshot();
}
int32_t staleness() const { return job_conf_.train_conf().staleness(); }
// Total number of batches the training job will process.
int64_t total_batch_num() const { return job_conf_.train_conf().total_batch_num(); }
// NOTE(review): pre-change accessors reading the old top-level JobConf
// fields, left in place by the diff rendering; superseded by the
// train_conf()-based versions shown above them in this diff.
int32_t staleness() const { return job_conf_.staleness(); }
int64_t total_batch_num() const { return job_conf_.total_batch_num(); }
// Total pieces over the whole job: pieces-per-batch times batch count.
int64_t total_piece_num() const { return num_of_pieces_in_batch() * total_batch_num(); }
// Pointer to the default FillConf; OF_PB_POINTER_GET presumably yields
// nullptr when the field is unset -- confirm against the macro definition.
// NOTE(review): old (job_conf_-level) and new (train_conf-level) lookups
// are both shown by the diff rendering; the train_conf() form is the
// post-change version.
const FillConf* default_fill_conf() const {
return OF_PB_POINTER_GET(job_conf_, default_fill_conf);
return OF_PB_POINTER_GET(job_conf_.train_conf(), default_fill_conf);
}
bool use_async_cpu_stream() const { return job_conf_.use_async_cpu_stream(); }
// Number of pieces between consecutive loss records.
// NOTE(review): both pre-change (top-level field) and post-change
// (train_conf sub-message) return statements appear here because this is
// a rendered diff; the train_conf-based accessor should remain.
int32_t piece_num_of_record_loss() const {
return job_conf_.piece_num_of_record_loss();
return job_conf_.train_conf().piece_num_of_record_loss();
}
private:
......
......@@ -2,6 +2,9 @@
namespace oneflow {
// NOTE(review): the pre-change literal definition of kPackedBlobName and
// the post-change prefix-based one are shown together by the diff
// rendering; keeping both would be a redefinition error -- only the
// ONEFLOW_INTERNAL_PREFIX versions are the final code.
const char* kPackedBlobName = "_oneflow_PackedBlobName";
// Shared prefix so internal sentinel names cannot clash with user names.
#define ONEFLOW_INTERNAL_PREFIX "OneFlowInternal"
const char* kPackedBlobName = ONEFLOW_INTERNAL_PREFIX "PackedBlobName";
const char* kNullDataId = ONEFLOW_INTERNAL_PREFIX "NullDataId";
} // namespace oneflow
......@@ -4,6 +4,7 @@
namespace oneflow {
// Internal sentinel names; definitions live in the corresponding .cpp file.
extern const char* kPackedBlobName;
extern const char* kNullDataId;
} // namespace oneflow
......
......@@ -13,8 +13,8 @@ std::pair<std::string, std::string> ParseDeviceNameConf(
ParallelDesc::ParallelDesc(const ParallelConf& user_conf) {
policy_ = user_conf.policy();
device_type_ = JobDesc::Singleton()->resource().device_type();
for (int64_t i = 0; i < user_conf.device_set().device_name_size(); ++i) {
const std::string& device_name = user_conf.device_set().device_name(i);
for (int64_t i = 0; i < user_conf.device_name_size(); ++i) {
const std::string& device_name = user_conf.device_name(i);
std::pair<std::string, std::string> machine_name_device_id =
ParseDeviceNameConf(device_name);
std::string machine_name = machine_name_device_id.first;
......
......@@ -8,13 +8,9 @@ enum ParallelPolicy {
kFakerLossRecord = 3;
}
// NOTE(review): removed by this change -- its device_name list was inlined
// directly into ParallelConf; shown here only as the pre-change definition.
message DeviceSet {
repeated string device_name = 1;
}
// Placement of an op group: the parallel policy plus the devices it runs on.
message ParallelConf {
ParallelPolicy policy = 1;
// NOTE(review): the diff rendering shows both the old DeviceSet field and
// the new inlined repeated string under the same tag 2; only the repeated
// string form is the post-change definition -- a message cannot have two
// fields with the same number.
DeviceSet device_set = 2;
repeated string device_name = 2;
}
message OpNameSet {
......
......@@ -34,16 +34,16 @@ std::shared_ptr<const Operator> OpMgr::ModelUpdateOp() {
if (!model_update_op_) {
OperatorConf mdupdt_conf;
mdupdt_conf.set_name("model_update");
const JobConf& job_conf = JobDesc::Singleton()->job_conf();
if (job_conf.has_normal_mdupdt_conf()) {
const TrainConf& train_conf = JobDesc::Singleton()->job_conf().train_conf();
if (train_conf.has_normal_mdupdt_conf()) {
*(mdupdt_conf.mutable_normal_mdupdt_conf()) =
job_conf.normal_mdupdt_conf();
} else if (job_conf.has_momentum_mdupdt_conf()) {
train_conf.normal_mdupdt_conf();
} else if (train_conf.has_momentum_mdupdt_conf()) {
*(mdupdt_conf.mutable_momentum_mdupdt_conf()) =
job_conf.momentum_mdupdt_conf();
} else if (job_conf.has_rmsprop_mdupdt_conf()) {
train_conf.momentum_mdupdt_conf();
} else if (train_conf.has_rmsprop_mdupdt_conf()) {
*(mdupdt_conf.mutable_rmsprop_mdupdt_conf()) =
job_conf.rmsprop_mdupdt_conf();
train_conf.rmsprop_mdupdt_conf();
} else {
UNEXPECTED_RUN();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册