提交 fbf3be66 编写于 作者: C chengtbf 提交者: Will Zhang

fix multi machine compile bug and some other little problems (#446)

* fix some little code

* fix multi machine compile bug

* remove log

* fix name


Former-commit-id: 2e70563e
上级 c4de3007
...@@ -350,7 +350,7 @@ void ChainGraph::BuildLossPrintStruct() { ...@@ -350,7 +350,7 @@ void ChainGraph::BuildLossPrintStruct() {
std::shared_ptr<const Operator> ConstructModelUpdateOp() { std::shared_ptr<const Operator> ConstructModelUpdateOp() {
OperatorConf mdupdt_conf; OperatorConf mdupdt_conf;
mdupdt_conf.set_name("model_update_" + NewUniqueId()); mdupdt_conf.set_name("md_update_" + NewUniqueId());
const JobDesc* job_desc = JobDesc::Singleton(); const JobDesc* job_desc = JobDesc::Singleton();
if (job_desc->IsTrain()) { if (job_desc->IsTrain()) {
const TrainConf& train_conf = job_desc->job_conf().train_conf(); const TrainConf& train_conf = job_desc->job_conf().train_conf();
......
...@@ -83,7 +83,9 @@ void TaskNode::ToProto(TaskProto* task_proto) { ...@@ -83,7 +83,9 @@ void TaskNode::ToProto(TaskProto* task_proto) {
} }
auto consumed_regst_proto = task_proto->mutable_consumed_regst_desc_id(); auto consumed_regst_proto = task_proto->mutable_consumed_regst_desc_id();
for (auto& pair : consumed_regsts_) { for (auto& pair : consumed_regsts_) {
int64_t regst_desc_id = pair.second.lock()->regst_desc_id(); std::shared_ptr<RegstDesc> regst = pair.second.lock();
if (!regst) { continue; }
int64_t regst_desc_id = regst->regst_desc_id();
CHECK(consumed_regst_proto->insert({pair.first, regst_desc_id}).second); CHECK(consumed_regst_proto->insert({pair.first, regst_desc_id}).second);
} }
} }
......
...@@ -43,8 +43,8 @@ message ConvolutionOpConf { ...@@ -43,8 +43,8 @@ message ConvolutionOpConf {
optional int32 dilation_h = 11 [default = 1]; optional int32 dilation_h = 11 [default = 1];
optional int32 dilation_w = 12 [default = 1]; optional int32 dilation_w = 12 [default = 1];
optional FillConf weight_fill = 14; optional FillConf weight_fill = 13;
optional FillConf bias_fill = 15; optional FillConf bias_fill = 14;
} }
message InnerProductOpConf { message InnerProductOpConf {
...@@ -94,8 +94,8 @@ message SoftmaxOpConf { ...@@ -94,8 +94,8 @@ message SoftmaxOpConf {
message SoftmaxLossOpConf { message SoftmaxLossOpConf {
required string prediction = 1; required string prediction = 1;
required string label = 3; required string label = 2;
required string loss = 4; required string loss = 3;
} }
message MultinomialLogisticLossOpConf { message MultinomialLogisticLossOpConf {
...@@ -107,7 +107,7 @@ message MultinomialLogisticLossOpConf { ...@@ -107,7 +107,7 @@ message MultinomialLogisticLossOpConf {
message ConcatOpConf { message ConcatOpConf {
repeated string in = 1; repeated string in = 1;
required string out = 2; required string out = 2;
required int32 axis = 4; required int32 axis = 3;
} }
message CopyCommNetOpConf { message CopyCommNetOpConf {
...@@ -142,17 +142,17 @@ message BoxCloneConf { ...@@ -142,17 +142,17 @@ message BoxCloneConf {
} }
message BoxingOpConf { message BoxingOpConf {
required string lbn = 2; required string lbn = 1;
required int32 in_num = 3; required int32 in_num = 2;
required int32 out_num = 4; required int32 out_num = 3;
oneof in_box { oneof in_box {
BoxConcatConf concat_box = 5; BoxConcatConf concat_box = 4;
BoxAddConf add_box = 6; BoxAddConf add_box = 5;
} }
oneof out_box { oneof out_box {
BoxSplitConf split_box = 7; BoxSplitConf split_box = 6;
BoxCloneConf clone_box = 8; BoxCloneConf clone_box = 7;
} }
} }
......
...@@ -2,6 +2,20 @@ ...@@ -2,6 +2,20 @@
namespace oneflow { namespace oneflow {
namespace {

// Returns the data type of the first blob in `bn_in_ops` for which
// `GetBlobDesc4BnInOp` yields a non-null BlobDesc; returns
// DataType::kInvalidDataType when no blob desc is available for any name.
//
// GetBlobDesc4BnInOp: lookup from a "bn in op" name to its BlobDesc; may
//                     return nullptr for names with no registered blob.
//                     Taken by const reference: copying a std::function may
//                     heap-allocate (clang-tidy performance-unnecessary-value-param).
// bn_in_ops:          candidate blob names, probed in order; first hit wins.
DataType GetDataTypeFromBnInOpVec(
    const std::function<const BlobDesc*(const std::string&)>&
        GetBlobDesc4BnInOp,
    const std::vector<std::string>& bn_in_ops) {
  for (const std::string& bn_in_op : bn_in_ops) {
    const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn_in_op);
    if (blob_desc) { return blob_desc->data_type(); }
  }
  return DataType::kInvalidDataType;
}

}  // namespace
void Operator::InitFromOpConf(const OperatorConf& op_conf) { void Operator::InitFromOpConf(const OperatorConf& op_conf) {
op_conf_ = op_conf; op_conf_ = op_conf;
InitFromOpConf(); InitFromOpConf();
...@@ -93,13 +107,12 @@ void Operator::GenKernelConf( ...@@ -93,13 +107,12 @@ void Operator::GenKernelConf(
kernel_conf->set_need_do_data_id(true); kernel_conf->set_need_do_data_id(true);
} }
kernel_conf->set_is_forward(is_forward); kernel_conf->set_is_forward(is_forward);
if (output_bns_.empty() == false) { DataType data_type =
kernel_conf->set_data_type(GetBlobDesc4BnInOp(output_bns_[0])->data_type()); GetDataTypeFromBnInOpVec(GetBlobDesc4BnInOp, output_bns_);
} else if (input_bns_.empty() == false) { if (data_type == DataType::kInvalidDataType) {
kernel_conf->set_data_type(GetBlobDesc4BnInOp(input_bns_[0])->data_type()); data_type = GetDataTypeFromBnInOpVec(GetBlobDesc4BnInOp, input_bns_);
} else {
kernel_conf->set_data_type(DataType::kInvalidDataType);
} }
kernel_conf->set_data_type(data_type);
VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf); VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册