diff --git a/oneflow/core/graph/chain_graph.cpp b/oneflow/core/graph/chain_graph.cpp index a04b03571f8763f0c46dc9436023570d1f88fe83..c89f07958da7bd8584b53cb87ff382743cc6f1b5 100644 --- a/oneflow/core/graph/chain_graph.cpp +++ b/oneflow/core/graph/chain_graph.cpp @@ -350,7 +350,7 @@ void ChainGraph::BuildLossPrintStruct() { std::shared_ptr ConstructModelUpdateOp() { OperatorConf mdupdt_conf; - mdupdt_conf.set_name("model_update_" + NewUniqueId()); + mdupdt_conf.set_name("md_update_" + NewUniqueId()); const JobDesc* job_desc = JobDesc::Singleton(); if (job_desc->IsTrain()) { const TrainConf& train_conf = job_desc->job_conf().train_conf(); diff --git a/oneflow/core/graph/task_node.cpp b/oneflow/core/graph/task_node.cpp index d2f6089e64f5ffeb36f5446e27fe1a95a086d81a..e5023cbd37c40c842d6d246d3f0a88c783e31172 100644 --- a/oneflow/core/graph/task_node.cpp +++ b/oneflow/core/graph/task_node.cpp @@ -83,7 +83,9 @@ void TaskNode::ToProto(TaskProto* task_proto) { } auto consumed_regst_proto = task_proto->mutable_consumed_regst_desc_id(); for (auto& pair : consumed_regsts_) { - int64_t regst_desc_id = pair.second.lock()->regst_desc_id(); + std::shared_ptr regst = pair.second.lock(); + if (!regst) { continue; } + int64_t regst_desc_id = regst->regst_desc_id(); CHECK(consumed_regst_proto->insert({pair.first, regst_desc_id}).second); } } diff --git a/oneflow/core/operator/op_conf.proto b/oneflow/core/operator/op_conf.proto index 42fadc3ca0fa22a42e840e11ac5ebf4e7bbff202..e0be193216b7749814eb3ba8e3e1a007d74a0bff 100644 --- a/oneflow/core/operator/op_conf.proto +++ b/oneflow/core/operator/op_conf.proto @@ -43,8 +43,8 @@ message ConvolutionOpConf { optional int32 dilation_h = 11 [default = 1]; optional int32 dilation_w = 12 [default = 1]; - optional FillConf weight_fill = 14; - optional FillConf bias_fill = 15; + optional FillConf weight_fill = 13; + optional FillConf bias_fill = 14; } message InnerProductOpConf { @@ -94,8 +94,8 @@ message SoftmaxOpConf { message SoftmaxLossOpConf { required string prediction = 1; - required string label = 3; - required string loss = 4; + required string label = 2; + required string loss = 3; } message MultinomialLogisticLossOpConf { @@ -107,7 +107,7 @@ message MultinomialLogisticLossOpConf { message ConcatOpConf { repeated string in = 1; required string out = 2; - required int32 axis = 4; + required int32 axis = 3; } message CopyCommNetOpConf { @@ -142,17 +142,17 @@ message BoxCloneConf { } message BoxingOpConf { - required string lbn = 2; - required int32 in_num = 3; - required int32 out_num = 4; + required string lbn = 1; + required int32 in_num = 2; + required int32 out_num = 3; oneof in_box { - BoxConcatConf concat_box = 5; - BoxAddConf add_box = 6; + BoxConcatConf concat_box = 4; + BoxAddConf add_box = 5; } oneof out_box { - BoxSplitConf split_box = 7; - BoxCloneConf clone_box = 8; + BoxSplitConf split_box = 6; + BoxCloneConf clone_box = 7; } } diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index c6013e3f8afe126e922a5bb3854d65aea4258bad..a9429f7d0d71d97f9ce9b722cebd7b657fee89a2 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -2,6 +2,20 @@ namespace oneflow { +namespace { + +DataType GetDataTypeFromBnInOpVec( + std::function GetBlobDesc4BnInOp, + const std::vector& bn_in_ops) { + for (const std::string& bn_in_op : bn_in_ops) { + const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn_in_op); + if (blob_desc) { return blob_desc->data_type(); } + } + return DataType::kInvalidDataType; +} + +} // namespace + void Operator::InitFromOpConf(const OperatorConf& op_conf) { op_conf_ = op_conf; InitFromOpConf(); @@ -93,13 +107,12 @@ void Operator::GenKernelConf( kernel_conf->set_need_do_data_id(true); } kernel_conf->set_is_forward(is_forward); - if (output_bns_.empty() == false) { - kernel_conf->set_data_type(GetBlobDesc4BnInOp(output_bns_[0])->data_type()); - } else if (input_bns_.empty() == false) { - kernel_conf->set_data_type(GetBlobDesc4BnInOp(input_bns_[0])->data_type()); - } else { - kernel_conf->set_data_type(DataType::kInvalidDataType); + DataType data_type = + GetDataTypeFromBnInOpVec(GetBlobDesc4BnInOp, output_bns_); + if (data_type == DataType::kInvalidDataType) { + data_type = GetDataTypeFromBnInOpVec(GetBlobDesc4BnInOp, input_bns_); } + kernel_conf->set_data_type(data_type); VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf); }