diff --git a/oneflow/core/graph/op_graph.cpp b/oneflow/core/graph/op_graph.cpp index 55eb613d358cccf2b6a6819d999a9f248c3f8313..57ef95af3317d0141720a34abe1a721b8c876030 100644 --- a/oneflow/core/graph/op_graph.cpp +++ b/oneflow/core/graph/op_graph.cpp @@ -301,18 +301,9 @@ const ParallelDesc& OpNode::BlobParallelDesc4Obn(const std::string& obn) const { } void OpNode::InferBlobParallelDesc() { - auto ParallelDesc4Obn = [&](const std::string& obn) -> ParallelDesc* { - auto iter = obn2blob_parallel_desc_.find(obn); - if (iter == obn2blob_parallel_desc_.end()) { - iter = obn2blob_parallel_desc_.emplace(obn, parallel_desc()).first; - } - return &iter->second; - }; - auto LogicalBlobDesc4Ibn = [&](const std::string& ibn) -> const BlobDesc* { - return &LogicalBlobDesc4Lbi(op().BnInOp2Lbi(ibn)); - }; - CHECK_JUST(op().InferOutParallelDescIf(ParallelDesc4Obn, LogicalBlobDesc4Ibn, parallel_desc(), - &sbp_signature())); + for (const auto& bn : op().output_bns()) { + obn2blob_parallel_desc_.emplace(bn, *CHECK_JUST(op().GetParallelDesc4BnInOp(bn))); + } } void OpNode::InitLbi2SourceNode() { @@ -557,6 +548,7 @@ Maybe OpGraph::InferLogicalBlobDesc(const Job& job) const { auto LogicalBlobDesc4BnInOp = [&](const std::string& bn) -> const BlobDesc& { return op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(bn)); }; + JUST(op_node->mut_op()->FillOpParallelDesc(op_node->parallel_desc())); JUST(op_node->mut_op()->FillLogicalInBlobDesc(LogicalBlobDesc4BnInOp)); // Infer ParallelSignature JUST(op_node->mut_op()->InferParallelSignatureIf()); diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp index 1a7c4f15abec61dac7a498089d98d9273ca67441..d4d2ce303c33cd3d5733f68aed8c6749fc6fdd15 100644 --- a/oneflow/core/job/job_build_and_infer_ctx.cpp +++ b/oneflow/core/job/job_build_and_infer_ctx.cpp @@ -538,6 +538,8 @@ Maybe JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con auto new_op_conf = JUST(DecodeLbiHintAndReturnNewOpConf(*op, 
&sbp_sig_conf, &ibn2disable_boxing)); AddOpAndUpdateJobParallelViewConf(*new_op_conf, sbp_sig_conf, is_mirrored_parallel_view); auto parallel_conf = JUST(InferOpParallelConf(*op, origin_parallel_conf, ibn2disable_boxing)); + ParallelDesc parallel_desc(*parallel_conf); + JUST(op->FillOpParallelDesc(parallel_desc)); JUST(AddOpNameParallelConf2Placement(op_name, *parallel_conf)); UpdateLbi2DisableBoxing(*op, ibn2disable_boxing); // infer batch_axis @@ -557,7 +559,6 @@ Maybe JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con JUST(op->FillInBatchAxis(BatchAxis4Ibn)); JUST(op->InferBatchAxisIf()); - ParallelDesc parallel_desc(*parallel_conf); // infer mirrored signature JUST(InferMirroredSignature(op, is_mirrored_parallel_view, parallel_desc)); // infer sbp signature @@ -576,10 +577,11 @@ Maybe JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con } return &iter->second; }; - JUST(op->InferOutParallelDescIf(ParallelDesc4Obn, GetBlobDesc4BnInOp, parallel_desc, - JUST(op->sbp_signature()))); - // TODO(lixinqi): replace lbi2parallel_desc_from_producer_view_ with ParallelSignature JUST(op->InferParallelSignatureIf()); + for (const auto& bn : op->output_bns()) { + lbi2parallel_desc_from_producer_view_.emplace(op->BnInOp2Lbi(bn), + *JUST(op->GetParallelDesc4BnInOp(bn))); + } JUST(AddLbiParallelConf2BlobPlacement(op, ParallelDesc4Obn)); // Infer whether input/output blobs are backward used InferBlobBackwardSignature(op); diff --git a/oneflow/core/operator/distribute_add_op.cpp b/oneflow/core/operator/distribute_add_op.cpp index f1710a991be041785932c866ffcb0e31407bc358..bb9ff8064fd89f7813d580cc6d6f20be6edb0885 100644 --- a/oneflow/core/operator/distribute_add_op.cpp +++ b/oneflow/core/operator/distribute_add_op.cpp @@ -36,7 +36,7 @@ class DistributeAddOp final : public Operator { LogicalNode* NewProperLogicalNode() const override { return new DistributeConcatLogicalNode; } private: - Maybe InferParallelSignature() override; + Maybe InferBlobParallelDesc() 
override; Maybe InferBatchAxis( std::function BatchAxis4BnInOp) const override; Maybe InferSbpSignature( @@ -53,22 +53,19 @@ void DistributeAddOp::InitFromOpConf() { EnrollOutputBn("out"); } -Maybe DistributeAddOp::InferParallelSignature() { - const auto& scope_storage = *Global>::Get(); - const auto& scope = JUST(scope_storage.MaybeGet(op_conf().scope_symbol_id())); - int64_t op_parallel_desc_symbol_id = JUST(scope.GetParallelDescSymbolId(op_conf())); - mut_parallel_signature()->set_op_parallel_desc_symbol_id(op_parallel_desc_symbol_id); - auto* map = mut_parallel_signature()->mutable_bn_in_op2parallel_desc_symbol_id(); - (*map)["out"] = op_parallel_desc_symbol_id; - const auto& op_parallel_desc = JUST(scope.GetParallelDesc(op_conf())); - CHECK_EQ(op_parallel_desc.parallel_num(), input_bns().size()); +Maybe DistributeAddOp::InferBlobParallelDesc() { + HashMap> bn2parallel_desc; + const std::shared_ptr op_parallel_desc = JUST(GetOpParallelDesc()); FOR_RANGE(int, i, 0, input_bns().size()) { - const auto& in_parallel_conf = op_parallel_desc.GetParallelIdOnlyParallelConf(i); - const std::shared_ptr& cfg_in_parallel_conf = - std::make_shared(in_parallel_conf); - (*map)[input_bns().Get(i)] = - Global::Get()->MakeParallelDescSymbol(cfg_in_parallel_conf); + bn2parallel_desc[input_bns().Get(i)] = + std::make_shared(op_parallel_desc->GetParallelIdOnlyParallelConf(i)); } + bn2parallel_desc["out"] = op_parallel_desc; + JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe { + auto it = bn2parallel_desc.find(bn); + CHECK_OR_RETURN(it != bn2parallel_desc.end()); + return it->second; + })); return Maybe::Ok(); } diff --git a/oneflow/core/operator/distribute_clone_op.cpp b/oneflow/core/operator/distribute_clone_op.cpp index 4e49d44aaf5c61d4ea8f5d32253ee592e4896cb1..9f970dafaf1936529e517a36c48de062c016d3ce 100644 --- a/oneflow/core/operator/distribute_clone_op.cpp +++ b/oneflow/core/operator/distribute_clone_op.cpp @@ -33,7 +33,7 @@ class DistributeCloneOp final : public 
Operator { LogicalNode* NewProperLogicalNode() const override { return new DistributeSplitLogicalNode; } private: - Maybe InferParallelSignature() override; + Maybe InferBlobParallelDesc() override; Maybe InferBatchAxis( std::function BatchAxis4BnInOp) const override; Maybe InferSbpSignature( @@ -44,10 +44,6 @@ class DistributeCloneOp final : public Operator { Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; - Maybe InferOutParallelDesc( - std::function ParallelDesc4Obn, - std::function LogicalBlobDesc4Ibn, const ParallelDesc&, - const SbpSignature*) const override; }; void DistributeCloneOp::InitFromOpConf() { @@ -76,38 +72,19 @@ Maybe DistributeCloneOp::InferOutBlobDescs( return Maybe::Ok(); } -Maybe DistributeCloneOp::InferOutParallelDesc( - std::function ParallelDesc4Obn, - std::function LogicalBlobDesc4Ibn, - const ParallelDesc& op_parallel_desc, const SbpSignature*) const { +Maybe DistributeCloneOp::InferBlobParallelDesc() { + HashMap> bn2parallel_desc; + const std::shared_ptr op_parallel_desc = JUST(GetOpParallelDesc()); + bn2parallel_desc["in"] = op_parallel_desc; FOR_RANGE(int, i, 0, output_bns().size()) { - const auto& obn = output_bns().Get(i); - if (op_parallel_desc.parallel_num() > 1) { - CHECK_EQ_OR_RETURN(op_parallel_desc.parallel_num(), output_bns().size()); - *ParallelDesc4Obn(obn) = ParallelDesc(op_parallel_desc.GetParallelIdOnlyParallelConf(i)); - } else { - *ParallelDesc4Obn(obn) = op_parallel_desc; - } - } - return Maybe::Ok(); -} - -Maybe DistributeCloneOp::InferParallelSignature() { - const auto& scope_storage = *Global>::Get(); - const auto& scope = JUST(scope_storage.MaybeGet(op_conf().scope_symbol_id())); - int64_t op_parallel_desc_symbol_id = JUST(scope.GetParallelDescSymbolId(op_conf())); - mut_parallel_signature()->set_op_parallel_desc_symbol_id(op_parallel_desc_symbol_id); - auto* map = 
mut_parallel_signature()->mutable_bn_in_op2parallel_desc_symbol_id(); - (*map)["in"] = op_parallel_desc_symbol_id; - const auto& op_parallel_desc = JUST(scope.GetParallelDesc(op_conf())); - CHECK_EQ_OR_RETURN(op_parallel_desc.parallel_num(), output_bns().size()); - FOR_RANGE(int, i, 0, output_bns().size()) { - const auto& out_parallel_conf = op_parallel_desc.GetParallelIdOnlyParallelConf(i); - const std::shared_ptr& cfg_out_parallel_conf = - std::make_shared(out_parallel_conf); - (*map)[output_bns().Get(i)] = - Global::Get()->MakeParallelDescSymbol(cfg_out_parallel_conf); + bn2parallel_desc[output_bns().Get(i)] = + std::make_shared(op_parallel_desc->GetParallelIdOnlyParallelConf(i)); } + JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe { + auto it = bn2parallel_desc.find(bn); + CHECK_OR_RETURN(it != bn2parallel_desc.end()); + return it->second; + })); return Maybe::Ok(); } diff --git a/oneflow/core/operator/distribute_concat_op.cpp b/oneflow/core/operator/distribute_concat_op.cpp index 9a49394c6491da4f6ff80da933d1ab2ada2ad5da..e2cc8163c897a4c3ed3c0b57ecb17dea04cd4288 100644 --- a/oneflow/core/operator/distribute_concat_op.cpp +++ b/oneflow/core/operator/distribute_concat_op.cpp @@ -36,7 +36,7 @@ class DistributeConcatOp final : public Operator { LogicalNode* NewProperLogicalNode() const override { return new DistributeConcatLogicalNode; } private: - Maybe InferParallelSignature() override; + Maybe InferBlobParallelDesc() override; Maybe InferBatchAxis( std::function BatchAxis4BnInOp) const override; Maybe InferSbpSignature( @@ -102,22 +102,19 @@ Maybe DistributeConcatOp::InferOutBlobDescs( return Maybe::Ok(); } -Maybe DistributeConcatOp::InferParallelSignature() { - const auto& scope_storage = *Global>::Get(); - const auto& scope = JUST(scope_storage.MaybeGet(op_conf().scope_symbol_id())); - int64_t op_parallel_desc_symbol_id = JUST(scope.GetParallelDescSymbolId(op_conf())); - 
mut_parallel_signature()->set_op_parallel_desc_symbol_id(op_parallel_desc_symbol_id); - auto* map = mut_parallel_signature()->mutable_bn_in_op2parallel_desc_symbol_id(); - (*map)["out"] = op_parallel_desc_symbol_id; - const auto& op_parallel_desc = JUST(scope.GetParallelDesc(op_conf())); - CHECK_EQ(op_parallel_desc.parallel_num(), input_bns().size()); +Maybe DistributeConcatOp::InferBlobParallelDesc() { + HashMap> bn2parallel_desc; + const std::shared_ptr op_parallel_desc = JUST(GetOpParallelDesc()); FOR_RANGE(int, i, 0, input_bns().size()) { - const auto& in_parallel_conf = op_parallel_desc.GetParallelIdOnlyParallelConf(i); - const std::shared_ptr& cfg_in_parallel_conf = - std::make_shared(in_parallel_conf); - (*map)[input_bns().Get(i)] = - Global::Get()->MakeParallelDescSymbol(cfg_in_parallel_conf); + bn2parallel_desc[input_bns().Get(i)] = + std::make_shared(op_parallel_desc->GetParallelIdOnlyParallelConf(i)); } + bn2parallel_desc["out"] = op_parallel_desc; + JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe { + auto it = bn2parallel_desc.find(bn); + CHECK_OR_RETURN(it != bn2parallel_desc.end()); + return it->second; + })); return Maybe::Ok(); } diff --git a/oneflow/core/operator/distribute_split_op.cpp b/oneflow/core/operator/distribute_split_op.cpp index 9627d6bb60b09449e9f4db1fbb37c642c2f76cf7..9a6dc03b752c5193514f55277919247fd150f0a2 100644 --- a/oneflow/core/operator/distribute_split_op.cpp +++ b/oneflow/core/operator/distribute_split_op.cpp @@ -33,7 +33,7 @@ class DistributeSplitOp final : public Operator { LogicalNode* NewProperLogicalNode() const override { return new DistributeSplitLogicalNode; } private: - Maybe InferParallelSignature() override; + Maybe InferBlobParallelDesc() override; Maybe InferBatchAxis( std::function BatchAxis4BnInOp) const override; Maybe InferSbpSignature( @@ -44,10 +44,6 @@ class DistributeSplitOp final : public Operator { Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const 
SbpSignature* sbp_signature) const override; - Maybe InferOutParallelDesc( - std::function ParallelDesc4Obn, - std::function LogicalBlobDesc4Ibn, const ParallelDesc&, - const SbpSignature*) const override; Maybe GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, @@ -90,38 +86,19 @@ Maybe DistributeSplitOp::InferOutBlobDescs( return Maybe::Ok(); } -Maybe DistributeSplitOp::InferOutParallelDesc( - std::function ParallelDesc4Obn, - std::function LogicalBlobDesc4Ibn, - const ParallelDesc& op_parallel_desc, const SbpSignature*) const { +Maybe DistributeSplitOp::InferBlobParallelDesc() { + HashMap> bn2parallel_desc; + const std::shared_ptr op_parallel_desc = JUST(GetOpParallelDesc()); + bn2parallel_desc["in"] = op_parallel_desc; FOR_RANGE(int, i, 0, output_bns().size()) { - const auto& obn = output_bns().Get(i); - if (op_parallel_desc.parallel_num() > 1) { - CHECK_EQ(op_parallel_desc.parallel_num(), output_bns().size()); - *ParallelDesc4Obn(obn) = ParallelDesc(op_parallel_desc.GetParallelIdOnlyParallelConf(i)); - } else { - *ParallelDesc4Obn(obn) = op_parallel_desc; - } - } - return Maybe::Ok(); -} - -Maybe DistributeSplitOp::InferParallelSignature() { - const auto& scope_storage = *Global>::Get(); - const auto& scope = JUST(scope_storage.MaybeGet(op_conf().scope_symbol_id())); - int64_t op_parallel_desc_symbol_id = JUST(scope.GetParallelDescSymbolId(op_conf())); - mut_parallel_signature()->set_op_parallel_desc_symbol_id(op_parallel_desc_symbol_id); - auto* map = mut_parallel_signature()->mutable_bn_in_op2parallel_desc_symbol_id(); - (*map)["in"] = op_parallel_desc_symbol_id; - const auto& op_parallel_desc = JUST(scope.GetParallelDesc(op_conf())); - CHECK_EQ(op_parallel_desc.parallel_num(), output_bns().size()); - FOR_RANGE(int, i, 0, output_bns().size()) { - const auto& out_parallel_conf = op_parallel_desc.GetParallelIdOnlyParallelConf(i); - const std::shared_ptr& cfg_out_parallel_conf = - std::make_shared(out_parallel_conf); - 
(*map)[output_bns().Get(i)] = - Global::Get()->MakeParallelDescSymbol(cfg_out_parallel_conf); + bn2parallel_desc[output_bns().Get(i)] = + std::make_shared(op_parallel_desc->GetParallelIdOnlyParallelConf(i)); } + JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe { + auto it = bn2parallel_desc.find(bn); + CHECK_OR_RETURN(it != bn2parallel_desc.end()); + return it->second; + })); return Maybe::Ok(); } diff --git a/oneflow/core/operator/image_decoder_random_crop_resize_op.cpp b/oneflow/core/operator/image_decoder_random_crop_resize_op.cpp index e4afd0110e57366dc46e2860d7cb67dfca02cf08..30714c96fa8878e0f96a355d3457c7d78054718d 100644 --- a/oneflow/core/operator/image_decoder_random_crop_resize_op.cpp +++ b/oneflow/core/operator/image_decoder_random_crop_resize_op.cpp @@ -109,26 +109,26 @@ class ImageDecoderRandomCropResizeOp final : public Operator { } } - Maybe InferParallelSignature() override { + Maybe InferBlobParallelDesc() override { + HashMap> bn2parallel_desc; + const std::shared_ptr op_parallel_desc = JUST(GetOpParallelDesc()); + bn2parallel_desc["out"] = op_parallel_desc; if (device_type() == DeviceType::kCPU) { - return Operator::InferParallelSignature(); + bn2parallel_desc["in"] = op_parallel_desc; } else if (device_type() == DeviceType::kGPU) { - const auto& scope_storage = *Global>::Get(); - const auto& scope = JUST(scope_storage.MaybeGet(op_conf().scope_symbol_id())); - const int64_t device_parallel_desc_symbol_id = - scope.scope_proto().device_parallel_desc_symbol_id(); - const int64_t host_parallel_desc_symbol_id = - scope.scope_proto().host_parallel_desc_symbol_id(); - mut_parallel_signature()->set_op_parallel_desc_symbol_id(device_parallel_desc_symbol_id); - auto* map = mut_parallel_signature()->mutable_bn_in_op2parallel_desc_symbol_id(); - for (const auto& ibn : input_bns()) { (*map)[ibn] = host_parallel_desc_symbol_id; } - for (const auto& obn : output_bns()) { (*map)[obn] = device_parallel_desc_symbol_id; } - for (const auto& tbn : tmp_bns()) 
{ (*map)[tbn] = device_parallel_desc_symbol_id; } - return Maybe::Ok(); + std::shared_ptr in_parallel_desc = + std::make_shared(*op_parallel_desc); + in_parallel_desc->set_device_type(DeviceType::kCPU); + bn2parallel_desc["in"] = in_parallel_desc; } else { - UNIMPLEMENTED(); - return Maybe::Ok(); + UNIMPLEMENTED_THEN_RETURN(); } + JUST(FillBlobParallelDesc([&](const std::string& bn) -> Maybe { + auto it = bn2parallel_desc.find(bn); + CHECK_OR_RETURN(it != bn2parallel_desc.end()); + return it->second; + })); + return Maybe::Ok(); } }; diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 15b36cf4a3305440511757959d46d3e3d61b435b..19a4e79cedc839a4f54b8eb4e006a4fe7efa0770 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -23,6 +23,7 @@ limitations under the License. #include "oneflow/core/job/scope.h" #include "oneflow/core/operator/operator.h" #include "oneflow/core/operator/op_node_signature.pb.h" +#include "oneflow/core/job/foreign_callback.h" namespace oneflow { @@ -55,6 +56,9 @@ void Operator::Init(const OperatorConf& op_conf, const JobDesc* conf_job_desc) { *this_op_conf = op_conf; if (has_job_desc() && job_desc().IsPredict()) { this_op_conf->set_trainable(false); } InitFromOpConf(); + input_output_bns_.Reserve(input_bns().size() + output_bns().size()); + for (const auto& bn : input_bns()) { *input_output_bns_.Add() = bn; } + for (const auto& bn : output_bns()) { *input_output_bns_.Add() = bn; } } LogicalNode* Operator::NewProperLogicalNode() const { return new NormalForwardLogicalNode; } @@ -98,23 +102,63 @@ Maybe Operator::obn4lbi(const LogicalBlobId& lbi) const { } Maybe Operator::InferParallelSignatureIf() { + JUST(InferBlobParallelDesc()); if (op_conf().scope_symbol_id() == 0) { return Maybe::Ok(); } - return InferParallelSignature(); -} - -Maybe Operator::InferParallelSignature() { const auto& scope_storage = *Global>::Get(); const auto& scope = 
JUST(scope_storage.MaybeGet(op_conf().scope_symbol_id())); int64_t parallel_desc_symbol_id = JUST(scope.GetParallelDescSymbolId(op_conf())); auto* parallel_signature = op_attribute_.mutable_parallel_signature(); parallel_signature->set_op_parallel_desc_symbol_id(parallel_desc_symbol_id); auto* map = parallel_signature->mutable_bn_in_op2parallel_desc_symbol_id(); - for (const auto& ibn : input_bns()) { (*map)[ibn] = parallel_desc_symbol_id; } - for (const auto& obn : output_bns()) { (*map)[obn] = parallel_desc_symbol_id; } + CHECK_OR_RETURN(op_parallel_desc_); + CHECK_OR_RETURN(bn2parallel_desc_); + for (const auto& pair : *bn2parallel_desc_) { + if (*pair.second == *op_parallel_desc_) { + (*map)[pair.first] = parallel_desc_symbol_id; + } else { + (*map)[pair.first] = Global::Get()->MakeParallelDescSymbol( + std::make_shared(pair.second->parallel_conf())); + } + } + // TODO(liujuncheng): remove this for (const auto& tbn : tmp_bns()) { (*map)[tbn] = parallel_desc_symbol_id; } return Maybe::Ok(); } +Maybe Operator::GetParallelDesc4BnInOp(const std::string& bn) const { + CHECK_OR_RETURN(bn2parallel_desc_); + auto it = bn2parallel_desc_->find(bn); + CHECK_OR_RETURN(it != bn2parallel_desc_->end()); + return it->second; +} + +Maybe Operator::FillBlobParallelDesc( + const std::function(const std::string&)>& ParallelDesc4Bn) { + CHECK_OR_RETURN(!bn2parallel_desc_); + bn2parallel_desc_.reset(new HashMap>); + for (const auto& bn : input_output_bns()) { + CHECK_OR_RETURN(bn2parallel_desc_->emplace(bn, JUST(ParallelDesc4Bn(bn))).second); + } + return Maybe::Ok(); +} + +Maybe Operator::InferBlobParallelDesc() { + JUST(FillBlobParallelDesc( + [&](const std::string& bn) -> Maybe { return GetOpParallelDesc(); })); + return Maybe::Ok(); +} + +Maybe Operator::FillOpParallelDesc(const ParallelDesc& parallel_desc) { + CHECK_OR_RETURN(!op_parallel_desc_); + op_parallel_desc_.reset(new ParallelDesc(parallel_desc)); + return Maybe::Ok(); +} + +Maybe Operator::GetOpParallelDesc() const { + 
CHECK_OR_RETURN(op_parallel_desc_); + return op_parallel_desc_; +} + namespace { Maybe FillLogicalBlobDesc( @@ -286,22 +330,6 @@ Maybe Operator::InferInplaceObn2Ibn( return Maybe::Ok(); } -Maybe Operator::InferOutParallelDescIf( - std::function ParallelDesc4Obn, - std::function LogicalBlobDesc4Ibn, - const ParallelDesc& op_parallel_desc, const SbpSignature* sbp_signature) const { - return InferOutParallelDesc(ParallelDesc4Obn, LogicalBlobDesc4Ibn, op_parallel_desc, - sbp_signature); -} - -Maybe Operator::InferOutParallelDesc( - std::function ParallelDesc4Obn, - std::function LogicalBlobDesc4Ibn, - const ParallelDesc& op_parallel_desc, const SbpSignature* sbp_signature) const { - for (const auto& obn : output_bns()) { *ParallelDesc4Obn(obn) = op_parallel_desc; } - return Maybe::Ok(); -} - Maybe Operator::InferOutputBlobTimeShapeIf( std::function GetTimeShape4BnInOp, const ParallelContext* parallel_ctx, Shape* time_shape) const { @@ -1091,6 +1119,7 @@ Maybe ConstructAndInferOp(const OperatorConf& op_conf, bool is_mirrored = scope.opt_mirrored_parallel_conf().has_mirrored_parallel(); const auto& op = ConstructOp(op_conf, JUST(scope.job_desc())); JUST(CheckOpInputSignature(*op, upstream_signature)); + JUST(op->FillOpParallelDesc(parallel_desc)); HashMap> bn_in_op2blob_desc; for (const auto& ibn : op->input_bns()) { const auto& map = upstream_signature.logical_blob_desc_signature().bn_in_op2blob_desc(); diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h index 83ad704774f9727eda40d9d00a39b417bae33a19..65272b2fc3167cc2985a95bb22fada79722151e4 100644 --- a/oneflow/core/operator/operator.h +++ b/oneflow/core/operator/operator.h @@ -87,7 +87,13 @@ class Operator { #undef DEFINE_BLOB_NAMES_GETTER + const PbRpf& input_output_bns() const { return input_output_bns_; } + + Maybe FillOpParallelDesc(const ParallelDesc& parallel_desc); + Maybe GetOpParallelDesc() const; + Maybe InferParallelSignatureIf(); + Maybe GetParallelDesc4BnInOp(const 
std::string& bn) const; // Read: shape of input_blobs // Write: shape of output_blobs @@ -128,15 +134,6 @@ class Operator { const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const; - Maybe InferOutParallelDescIf( - std::function ParallelDesc4Obn, - std::function LogicalBlobDesc4Ibn, const ParallelDesc&, - const SbpSignature*) const; - virtual Maybe InferOutParallelDesc( - std::function ParallelDesc4Obn, - std::function LogicalBlobDesc4Ibn, const ParallelDesc&, - const SbpSignature*) const; - Maybe FillInBatchAxis( const std::function(const std::string&)>& BatchAxis4BnInOp); Maybe FillOutBatchAxis( @@ -202,7 +199,9 @@ class Operator { } protected: - virtual Maybe InferParallelSignature(); + Maybe FillBlobParallelDesc( + const std::function(const std::string&)>& ParallelDesc4Bn); + virtual Maybe InferBlobParallelDesc(); virtual Maybe InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext*, const SbpSignature* sbp_signature) const; @@ -320,10 +319,14 @@ class Operator { OpAttribute op_attribute_; const JobDesc* job_desc_; HashMap lbi2obn_; + std::shared_ptr op_parallel_desc_; + std::unique_ptr>> bn2parallel_desc_; std::unique_ptr>> ibn2logical_blob_desc_; std::unique_ptr>> obn2logical_blob_desc_; std::unique_ptr>> ibn2batch_axis_; std::unique_ptr>> obn2batch_axis_; + + PbRpf input_output_bns_; }; std::string GenRepeatedBn(const std::string& bn_prefix, int32_t idx);