diff --git a/oneflow/core/eager/opkernel_object.cpp b/oneflow/core/eager/opkernel_object.cpp index f15431d5f2908eb7ff7b916f189bf352e2a392e2..ac87c0e2a934cf87036292a483567fe10eb8eb52 100644 --- a/oneflow/core/eager/opkernel_object.cpp +++ b/oneflow/core/eager/opkernel_object.cpp @@ -23,6 +23,13 @@ Maybe OpKernelObject::ResetOpAndKernel( const std::function& BlobDesc4BnInOp, const ParallelDesc* parallel_desc) { auto op = ConstructOp(op_conf_, device_type_, job_desc_.get()); + JUST(op->FillOpParallelDesc(*parallel_desc)); + const auto LogicalBlobDesc4BnInOp = [&](const std::string& bn) -> const BlobDesc& { + return CHECK_JUST(op_node_signature.LogicalBlobDesc4BnInOp(bn)); + }; + JUST(op->FillLogicalInBlobDesc(LogicalBlobDesc4BnInOp)); + JUST(op->FillLogicalOutBlobDesc(LogicalBlobDesc4BnInOp)); + JUST(op->FillSbpSignature(op_node_signature.sbp_signature())); JUST(InferBlobDescs(*op, BlobDesc4BnInOp, &op_node_signature.sbp_signature(), parallel_ctx)); NewPartialInitializedKernel(*op, BlobDesc4BnInOp, op_node_signature, parallel_ctx, parallel_desc); return Maybe::Ok(); @@ -40,11 +47,7 @@ void OpKernelObject::NewPartialInitializedKernel( const OpNodeSignatureDesc& op_node_signature, const ParallelContext* parallel_ctx, const ParallelDesc* parallel_desc) { KernelConf kernel_conf; - auto LogicalBlobDesc4BnInOp = [&](const std::string& bn_in_op) -> const BlobDesc& { - return CHECK_JUST(op_node_signature.LogicalBlobDesc4BnInOp(bn_in_op)); - }; - op.GenKernelConf(BlobDesc4BnInOp, parallel_ctx, &kernel_conf, LogicalBlobDesc4BnInOp, - parallel_desc, &op_node_signature.sbp_signature()); + op.GenKernelConf(BlobDesc4BnInOp, parallel_ctx, &kernel_conf); kernel_.reset(new EagerKernel(job_desc_.get(), kernel_conf)); } @@ -53,6 +56,13 @@ Maybe SystemOpKernelObject::ResetKernel( const std::function& BlobDesc4BnInOp, const ParallelDesc* parallel_desc) { auto op = ConstructOp(op_conf_, device_type_, job_desc_.get()); + JUST(op->FillOpParallelDesc(*parallel_desc)); + const auto LogicalBlobDesc4BnInOp = [&](const std::string& bn) -> const BlobDesc& { + return CHECK_JUST(op_node_signature.LogicalBlobDesc4BnInOp(bn)); + }; + JUST(op->FillLogicalInBlobDesc(LogicalBlobDesc4BnInOp)); + JUST(op->FillLogicalOutBlobDesc(LogicalBlobDesc4BnInOp)); + JUST(op->FillSbpSignature(op_node_signature.sbp_signature())); JUST(InferBlobDescs(*op, BlobDesc4BnInOp, &op_node_signature.sbp_signature(), parallel_ctx)); ResetKernel(*op, BlobDesc4BnInOp, op_node_signature, parallel_ctx, parallel_desc); return Maybe::Ok(); @@ -70,11 +80,7 @@ void SystemOpKernelObject::ResetKernel( const OpNodeSignatureDesc& op_node_signature, const ParallelContext* parallel_ctx, const ParallelDesc* parallel_desc) { KernelConf kernel_conf; - auto LogicalBlobDesc4BnInOp = [&](const std::string& bn_in_op) -> const BlobDesc& { - return CHECK_JUST(op_node_signature.LogicalBlobDesc4BnInOp(bn_in_op)); - }; - op.GenKernelConf(BlobDesc4BnInOp, parallel_ctx, &kernel_conf, LogicalBlobDesc4BnInOp, - parallel_desc, &op_node_signature.sbp_signature()); + op.GenKernelConf(BlobDesc4BnInOp, parallel_ctx, &kernel_conf); kernel_ = ConstructKernel(job_desc_.get(), kernel_conf, nullptr); } diff --git a/oneflow/core/graph/exec_graph.cpp b/oneflow/core/graph/exec_graph.cpp index 69e6a465eacbbe7af689501c4d83b861c71fa0d6..0208c28204b7938d2c47085c92186f498d446e51 100644 --- a/oneflow/core/graph/exec_graph.cpp +++ b/oneflow/core/graph/exec_graph.cpp @@ -59,11 +59,7 @@ void ExecNode::UnbindBnWithEmptyRegst() { } void ExecNode::ToProto(const ParallelContext* parallel_ctx, ExecNodeProto* ret) const { - const OpNode* op_node = Global::Get()->OpNode4OpName(op_->op_name()); - const ParallelDesc* parallel_desc = op_node == nullptr ? nullptr : &op_node->parallel_desc(); - const SbpSignature* sbp_signature = op_node == nullptr ? nullptr : &op_node->sbp_signature(); - op_->GenKernelConf(GetBlobDesc4BnInOpFunc(), parallel_ctx, ret->mutable_kernel_conf(), - GetLogicalBlobDesc4BnInOpFunc(), parallel_desc, sbp_signature); + op_->GenKernelConf(GetBlobDesc4BnInOpFunc(), parallel_ctx, ret->mutable_kernel_conf()); for (const auto& bn_regst : bn_in_op2regst_) { const std::string& bn_in_op = bn_regst.first; auto regst = bn_regst.second; @@ -83,7 +79,7 @@ void ExecNode::InferBlobDescs(const ParallelContext* parallel_ctx) { CHECK_JUST(op_->InferBlobDescsIf(GetBlobDesc4BnInOp, parallel_ctx, sbp_signature)); Global::Get()->CheckBlobDescs(op_->op_name(), GetBlobDesc4BnInOp, parallel_ctx); CHECK_JUST(op_->InferInplaceObn2IbnIf(&mut_inplace_obn2ibn_, &con_inplace_obn2ibn_, - GetBlobDesc4BnInOp, parallel_ctx, sbp_signature)); + GetBlobDesc4BnInOp, parallel_ctx)); } std::function ExecNode::GetLogicalBlobDesc4BnInOpFunc() const { diff --git a/oneflow/core/graph/op_graph.h b/oneflow/core/graph/op_graph.h index 56d6aca036b455f30c1bea9a74fa14ec684cd79b..a4848f6253f509bc6730436ed49375e812e16b4a 100644 --- a/oneflow/core/graph/op_graph.h +++ b/oneflow/core/graph/op_graph.h @@ -64,7 +64,6 @@ class OpNode final : public Node { // Setters Operator* mut_op() { return op_.get(); } ParallelDesc* mut_parallel_desc() { return ¶llel_desc_; } - SbpSignature* mut_sbp_signature() { return mut_op()->mut_sbp_signature(); } Shape* mut_out_blob_time_shape(); HashMap>>* mut_bn2parallel_id2blob_desc() { return &bn2parallel_id2blob_desc_; diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index 8f6ff49f4526faaa5e71fa4c48f6f712594d0f40..9fcfdc3b7ba91bbdecfdd5219fcd365518a456a9 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -308,17 +308,17 @@ Maybe Operator::InferInternalBlobDescs( Maybe Operator::InferInplaceObn2IbnIf( HashMap* mut_inplace_obn2ibn, HashMap* con_inplace_obn2ibn, - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { return InferInplaceObn2Ibn(mut_inplace_obn2ibn, con_inplace_obn2ibn, GetBlobDesc4BnInOp, - parallel_ctx, sbp_signature); + parallel_ctx); } Maybe Operator::InferInplaceObn2Ibn( HashMap* mut_inplace_obn2ibn, HashMap* con_inplace_obn2ibn, - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { for (const std::string& obn : output_bns()) { const auto& obn_modifier = OutputBlobModifier4Obn(obn); if (obn_modifier.has_mutable_inplace_ibn()) { @@ -373,21 +373,30 @@ void Operator::ForEachBnInOp(std::function Handler) co for (const std::string& bn_in_op : tmp_bns()) { Handler(bn_in_op); } } +Maybe Operator::FillSbpSignature(const SbpSignature& sbp_signature) { + CHECK_OR_RETURN(!sbp_signature_); + sbp_signature_.reset(new SbpSignature(sbp_signature)); + *op_attribute_.mutable_sbp_signature() = sbp_signature; + return Maybe::Ok(); +} + Maybe Operator::InferSbpSignatureIf( const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) { + SbpSignature signature; if (parallel_desc.parallel_num() == 1) { - auto* bn2sbp = mut_sbp_signature()->mutable_bn_in_op2sbp_parallel(); + auto* bn2sbp = signature.mutable_bn_in_op2sbp_parallel(); for (const auto& ibn : input_bns()) { (*bn2sbp)[ibn].mutable_split_parallel()->set_axis(0); } for (const auto& obn : output_bns()) { (*bn2sbp)[obn].mutable_split_parallel()->set_axis(0); } } else if (parallel_desc.parallel_num() > 1) { - return InferSbpSignature(mut_sbp_signature(), sbp_sig_conf, CalcOrderValue4SbpSig, - SbpInferHint4Ibn, parallel_desc); + InferSbpSignature(&signature, sbp_sig_conf, CalcOrderValue4SbpSig, SbpInferHint4Ibn, + parallel_desc); } else { UNIMPLEMENTED(); } + FillSbpSignature(signature); return Maybe::Ok(); } @@ -529,13 +538,13 @@ Maybe Operator::InferMirroredSignature( } Maybe Operator::sbp_signature() const { - CHECK_OR_RETURN(op_attribute_.has_sbp_signature()) << "sbp signature not infered"; - return &op_attribute_.sbp_signature(); + CHECK_OR_RETURN(sbp_signature_) << "sbp signature not infered"; + return sbp_signature_.get(); } Maybe Operator::SbpParallel4BnInOp(const std::string& bn_in_op) const { - CHECK_OR_RETURN(op_attribute_.has_sbp_signature()) << "sbp signature not infered"; - const auto& map = op_attribute_.sbp_signature().bn_in_op2sbp_parallel(); + CHECK_OR_RETURN(sbp_signature_) << "sbp signature not infered"; + const auto& map = sbp_signature_->bn_in_op2sbp_parallel(); const auto& iter = map.find(bn_in_op); CHECK_OR_RETURN(iter != map.end()) << "blob_name " << bn_in_op << " not found in sbp signature"; return &iter->second; @@ -571,10 +580,8 @@ bool HasBlobDescWithField(std::function Get } // namespace void Operator::GenKernelConf( - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, KernelConf* kernel_conf, - std::function LogicalBlobDesc4BnInOp, - const ParallelDesc* parallel_desc, const SbpSignature* sbp_signature) const { + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const { auto* dtype_signature = kernel_conf->mutable_dtype_signature(); for (const std::string& ibn : input_bns()) { const BlobDesc* blob_desc = GetBlobDesc4BnInOp(ibn); @@ -606,31 +613,6 @@ void Operator::GenKernelConf( kernel_conf->set_data_type(data_type); } - VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf, LogicalBlobDesc4BnInOp, - parallel_desc, sbp_signature); -} - -void Operator::VirtualGenKernelConf( - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, KernelConf* kernel_conf, - std::function LogicalBlobDesc4BnInOp, - const ParallelDesc* parallel_desc, const SbpSignature* sbp_signature) const { - VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf, LogicalBlobDesc4BnInOp, - parallel_desc); -} - -void Operator::VirtualGenKernelConf( - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, KernelConf* kernel_conf, - std::function LogicalBlobDesc4BnInOp, - const ParallelDesc* parallel_desc) const { - VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf, LogicalBlobDesc4BnInOp); -} - -void Operator::VirtualGenKernelConf( - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, KernelConf* kernel_conf, - std::function LogicalBlobDesc4BnInOp) const { VirtualGenKernelConf(GetBlobDesc4BnInOp, parallel_ctx, kernel_conf); } diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h index ca73b7fa3b4d525f3e72bb68f6d5d0f55e563ba7..4a44f40cdab7c1931e54f09b0aec2d4a7c936f3b 100644 --- a/oneflow/core/operator/operator.h +++ b/oneflow/core/operator/operator.h @@ -95,8 +95,6 @@ class Operator { Maybe InferParallelSignatureIf(); Maybe GetParallelDesc4BnInOp(const std::string& bn) const; - // Read: shape of input_blobs - // Write: shape of output_blobs Maybe FillLogicalInBlobDesc( const std::function& BlobDesc4BnInOp); Maybe FillLogicalInBlobDesc( @@ -128,11 +126,11 @@ class Operator { std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const; - Maybe InferInplaceObn2IbnIf(HashMap* mut_inplace_obn2ibn, - HashMap* con_inplace_obn2ibn, - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, - const SbpSignature* sbp_signature) const; + Maybe InferInplaceObn2IbnIf( + HashMap* mut_inplace_obn2ibn, + HashMap* con_inplace_obn2ibn, + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const; // Infer out blob's time shape Maybe InferOutputBlobTimeShapeIf( @@ -141,7 +139,8 @@ class Operator { virtual Maybe InferOutputBlobTimeShape( std::function GetTimeShape4BnInOp, const ParallelContext*, Shape* time_shape) const; - // Infer blob's SbpSignature + + Maybe FillSbpSignature(const SbpSignature& sbp_signature); Maybe InferSbpSignatureIf( const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, @@ -152,10 +151,8 @@ class Operator { std::function(const std::string&)> MirroredSigInferHint4Ibn, bool is_mirrored_parallel_view_conf, const ParallelDesc& parallel_desc); - void GenKernelConf(std::function GetBlobDesc4BnInOp, - const ParallelContext*, KernelConf*, - std::function LogicalBlobDesc4BnInOp, - const ParallelDesc* parallel_desc, const SbpSignature* sbp_signature) const; + void GenKernelConf(const std::function& GetBlobDesc4BnInOp, + const ParallelContext*, KernelConf*) const; const InputBlobModifier& InputBlobModifier4Ibn(const std::string& ibn) const; const OutputBlobModifier& OutputBlobModifier4Obn(const std::string& obn) const; Maybe SbpParallel4BnInOp(const std::string& bn_in_op) const; @@ -175,7 +172,6 @@ class Operator { ParallelSignature* mut_parallel_signature() { return op_attribute_.mutable_parallel_signature(); } Maybe sbp_signature() const; - SbpSignature* mut_sbp_signature() { return op_attribute_.mutable_sbp_signature(); } BlobLastUsedSignature* mut_blob_last_used_signature() { return op_attribute_.mutable_blob_last_used_signature(); } @@ -225,22 +221,8 @@ class Operator { virtual Maybe InferInplaceObn2Ibn( HashMap* mut_inplace_obn2ibn, HashMap* con_inplace_obn2ibn, - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const; - - virtual void VirtualGenKernelConf( - std::function GetBlobDesc4BnInOp, const ParallelContext*, - KernelConf*, std::function LogicalBlobDesc4BnInOp, - const ParallelDesc* parallel_desc, const SbpSignature* sbp_signature) const; - - virtual void VirtualGenKernelConf( - std::function GetBlobDesc4BnInOp, const ParallelContext*, - KernelConf*, std::function LogicalBlobDesc4BnInOp, - const ParallelDesc* parallel_desc) const; - - virtual void VirtualGenKernelConf( - std::function GetBlobDesc4BnInOp, const ParallelContext*, - KernelConf*, std::function LogicalBlobDesc4BnInOp) const; + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const; virtual void VirtualGenKernelConf( std::function GetBlobDesc4BnInOp, const ParallelContext*, @@ -249,8 +231,6 @@ class Operator { virtual LogicalBlobId lbi4ibn(const std::string& input_bn) const; virtual LogicalBlobId lbi4obn(const std::string& output_bn) const; - OperatorConf* mut_op_conf() { return op_attribute_.mutable_op_conf(); } - // enroll data blobs void EnrollTmpBn(const std::string& dtbn); void EnrollRepeatedInputBn(const std::string& ibn_prefix, int32_t num, bool has_diff); @@ -307,6 +287,7 @@ class Operator { std::unique_ptr>> bn2parallel_desc_; std::unique_ptr>> ibn2logical_blob_desc_; std::unique_ptr>> obn2logical_blob_desc_; + std::shared_ptr sbp_signature_; PbRpf input_output_bns_; }; diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp index 320edce218cef7aa3e291a09f09a32fa5fd82efb..69e5b90004a047f7061af2a64fd1d3f05296e2d0 100644 --- a/oneflow/core/operator/user_op.cpp +++ b/oneflow/core/operator/user_op.cpp @@ -401,9 +401,9 @@ Maybe UserOp::InferOutBlobDescs( Maybe UserOp::InferInplaceObn2Ibn( HashMap* mut_inplace_obn2ibn, HashMap* con_inplace_obn2ibn, - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { - UserOpInferContext infer_ctx(op_conf(), parallel_ctx, sbp_signature, job_desc(), + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const { + UserOpInferContext infer_ctx(op_conf(), parallel_ctx, JUST(sbp_signature()), job_desc(), GetBlobDesc4BnInOp); const user_op::OpKernelRegistryResult* kernel_reg_val = JUST(user_op::UserOpRegistryMgr::Get().GetOpKernelRegistryResult( @@ -548,28 +548,23 @@ Symbol UserOp::GetOpConfWithoutOpNameAndLbn() const { void UserOp::VirtualGenKernelConf( std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, KernelConf* kernel_conf, - std::function LogicalBlobDesc4BnInOp, - const ParallelDesc* parallel_desc, const SbpSignature* sbp_signature) const { + const ParallelContext* parallel_ctx, KernelConf* kernel_conf) const { auto user_conf = kernel_conf->mutable_user_conf(); *(user_conf->mutable_parallel_ctx()) = *parallel_ctx; - *(user_conf->mutable_sbp_sig()) = *sbp_signature; -#define BLOB_DESCS_TO_PROTO(prefix, is_arg) \ - for (const auto& bn : prefix##_bns()) { \ - const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn); \ - if (blob_desc) { blob_desc->ToProto(&(*user_conf->mutable_bn_in_op2blob_desc())[bn]); } \ - if (is_arg) { \ - LogicalBlobDesc4BnInOp(bn).ToProto(&(*user_conf->mutable_bn_in_op2logical_blob_desc())[bn]); \ - } \ - } - - BLOB_DESCS_TO_PROTO(input, true) - BLOB_DESCS_TO_PROTO(output, true) - BLOB_DESCS_TO_PROTO(tmp, false) - -#undef BLOB_DESCS_TO_PROTO - CHECK_NOTNULL(parallel_desc); - *user_conf->mutable_parallel_conf() = parallel_desc->parallel_conf(); + *(user_conf->mutable_sbp_sig()) = *CHECK_JUST(sbp_signature()); + ForEachBnInOp([&](const std::string& bn) { + const BlobDesc* blob_desc = GetBlobDesc4BnInOp(bn); + if (blob_desc) { blob_desc->ToProto(&(*user_conf->mutable_bn_in_op2blob_desc())[bn]); } + }); + for (const std::string& ibn : input_bns()) { + CHECK_JUST(GetLogicalBlobDesc4Ibn(ibn)) + ->ToProto(&(*user_conf->mutable_bn_in_op2logical_blob_desc())[ibn]); + } + for (const std::string& obn : output_bns()) { + CHECK_JUST(GetLogicalBlobDesc4Obn(obn)) + ->ToProto(&(*user_conf->mutable_bn_in_op2logical_blob_desc())[obn]); + } + *user_conf->mutable_parallel_conf() = CHECK_JUST(GetOpParallelDesc())->parallel_conf(); } REGISTER_OP(OperatorConf::kUserConf, UserOp); diff --git a/oneflow/core/operator/user_op.h b/oneflow/core/operator/user_op.h index 492e9b57d1e5c083d399f84046e42d6e063b6002..ee2c3bca1c48b897799dee78f0bf3bc75290652a 100644 --- a/oneflow/core/operator/user_op.h +++ b/oneflow/core/operator/user_op.h @@ -34,11 +34,11 @@ class UserOp final : public Operator { Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext*, const SbpSignature* sbp_signature) const override; - Maybe InferInplaceObn2Ibn(HashMap* mut_inplace_obn2ibn, - HashMap* con_inplace_obn2ibn, - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, - const SbpSignature* sbp_signature) const override; + Maybe InferInplaceObn2Ibn( + HashMap* mut_inplace_obn2ibn, + HashMap* con_inplace_obn2ibn, + const std::function& GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx) const override; Symbol GetOpConfWithoutOpNameAndLbn() const override; private: @@ -55,11 +55,9 @@ class UserOp final : public Operator { Maybe InferOutputBlobTimeShape( std::function GetTimeShape4BnInOp, const ParallelContext*, Shape* time_shape) const override; - void VirtualGenKernelConf( - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, KernelConf* kernel_conf, - std::function LogicalBlobDesc4BnInOp, - const ParallelDesc* parallel_desc, const SbpSignature* sbp_signature) const override; + void VirtualGenKernelConf(std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, + KernelConf* kernel_conf) const override; const user_op::OpRegistryResult* val_; };