diff --git a/oneflow/core/operator/assign_op.cpp b/oneflow/core/operator/assign_op.cpp index ceb30760a8d7cb625f86fac85a3e9c0bc01eb691..fdbb9b12f086b645fa02ef29b35374ef235c2b94 100644 --- a/oneflow/core/operator/assign_op.cpp +++ b/oneflow/core/operator/assign_op.cpp @@ -24,6 +24,9 @@ class AssignOp final : public Operator { ~AssignOp() override = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; @@ -46,13 +49,27 @@ std::string DebugString(const BlobDesc& blob_desc) { return blob_desc_proto.DebugString(); } +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + CHECK_OR_RETURN(*BlobDesc4BnInOp("ref") == *BlobDesc4BnInOp("value")) + << "\nref_blob_desc: " << DebugString(*BlobDesc4BnInOp("ref")) + << "\nvalue_blob_desc: " << DebugString(*BlobDesc4BnInOp("value")); + return Maybe::Ok(); +} + +} // namespace + +Maybe AssignOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(BlobDesc4BnInOp); +} + Maybe AssignOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { - CHECK_OR_RETURN(*GetBlobDesc4BnInOp("ref") == *GetBlobDesc4BnInOp("value")) - << "\nref_blob_desc: " << DebugString(*GetBlobDesc4BnInOp("ref")) - << "\nvalue_blob_desc: " << DebugString(*GetBlobDesc4BnInOp("value")); - return Maybe::Ok(); + return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe AssignOp::GetSbpSignatures( diff --git a/oneflow/core/operator/broadcast_to_compatible_with_op.cpp b/oneflow/core/operator/broadcast_to_compatible_with_op.cpp index c3b034622bcbd92131dbacbe03e957fa281c8eca..86a817d01062b3fc6808738e499435ecba5654eb 100644 --- a/oneflow/core/operator/broadcast_to_compatible_with_op.cpp +++ b/oneflow/core/operator/broadcast_to_compatible_with_op.cpp @@ -35,6 +35,21 @@ Maybe GetBroadcastShape(const Shape& a_shape, const Shape& b_shape, Shape* return Maybe::Ok(); } +Maybe InferBlobDescs(const OperatorConf& op_conf, + const std::function& BlobDesc4BnInOp) { + int64_t num_compatibles = op_conf.broadcast_to_compatible_with_conf().compatible_size(); + const BlobDesc* x_desc = BlobDesc4BnInOp("x"); + Shape broadcasted_shape(x_desc->shape()); + FOR_RANGE(int64_t, i, 0, num_compatibles) { + const BlobDesc* compatible_i = BlobDesc4BnInOp(GenRepeatedBn("compatible", i)); + GetBroadcastShape(broadcasted_shape, compatible_i->shape(), &broadcasted_shape); + } + BlobDesc* y_desc = BlobDesc4BnInOp("y"); + y_desc->CopyFrom(*x_desc); + y_desc->mut_shape() = broadcasted_shape; + return Maybe::Ok(); +} + } // namespace class BroadcastToCompatibleWithOp final : public Operator { @@ -50,20 +65,16 @@ class BroadcastToCompatibleWithOp final : public Operator { EnrollOutputBn("y"); } + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override { + return InferBlobDescs(op_conf(), BlobDesc4BnInOp); + } + Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override { - int64_t num_compatibles = op_conf().broadcast_to_compatible_with_conf().compatible_size(); - const BlobDesc* x_desc = GetBlobDesc4BnInOp("x"); - Shape broadcasted_shape(x_desc->shape()); - FOR_RANGE(int64_t, i, 0, num_compatibles) { - const BlobDesc* compatible_i = GetBlobDesc4BnInOp(GenRepeatedBn("compatible", i)); - GetBroadcastShape(broadcasted_shape, compatible_i->shape(), &broadcasted_shape); - } - BlobDesc* y_desc = GetBlobDesc4BnInOp("y"); - y_desc->CopyFrom(*x_desc); - y_desc->mut_shape() = broadcasted_shape; - return Maybe::Ok(); + return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); } private: diff --git a/oneflow/core/operator/callback_notify_op.cpp b/oneflow/core/operator/callback_notify_op.cpp index c9937c28d4a80f3927db03c3bf1351e8637db53c..7a86c27b3af196896c0b2af1804693b63d5400ea 100644 --- a/oneflow/core/operator/callback_notify_op.cpp +++ b/oneflow/core/operator/callback_notify_op.cpp @@ -28,17 +28,31 @@ LogicalNode* CallbackNotifyOp::NewProperLogicalNode() const { return new CallbackNotifyLogicalNode(); } +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + CHECK_OR_RETURN(BlobDesc4BnInOp("in")->shape() == Shape({1})); + CHECK_OR_RETURN(IsIntegralDataType(BlobDesc4BnInOp("in")->data_type())); + return Maybe::Ok(); +} + +} // namespace + +Maybe CallbackNotifyOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1); + return InferBlobDescs(BlobDesc4BnInOp); +} + Maybe CallbackNotifyOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1); - CHECK_OR_RETURN(GetBlobDesc4BnInOp("in")->shape() == Shape({1})); - CHECK_OR_RETURN(IsIntegralDataType(GetBlobDesc4BnInOp("in")->data_type())); - return Maybe::Ok(); + return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe CallbackNotifyOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { - SbpSignatureBuilder().Split(input_bns(), 0).Build(sbp_sig_list->mutable_sbp_signature()->Add()); return Maybe::Ok(); } diff --git a/oneflow/core/operator/callback_notify_op.h b/oneflow/core/operator/callback_notify_op.h index c08a2a707f48a8bc2ff3bda07b2e290ba7e930f7..dbf5c343e168d4fdae796c97f6ece7e2664a536f 100644 --- a/oneflow/core/operator/callback_notify_op.h +++ b/oneflow/core/operator/callback_notify_op.h @@ -27,6 +27,9 @@ class CallbackNotifyOp final : public Operator { ~CallbackNotifyOp() = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; diff --git a/oneflow/core/operator/case_op.cpp b/oneflow/core/operator/case_op.cpp index 4854ed17b5269969f455c8e54fb3c55544b1ec47..d3768a112addb8d22e2a1b56cb0b8c9637a8ad79 100644 --- a/oneflow/core/operator/case_op.cpp +++ b/oneflow/core/operator/case_op.cpp @@ -24,21 +24,35 @@ void CaseOp::InitFromOpConf() { EnrollRepeatedOutputBn("out", false); } -Maybe CaseOp::InferOutBlobDescs( - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { - const BlobDesc* in = GetBlobDesc4BnInOp("in"); +namespace { + +Maybe InferBlobDescs(const Operator& op, + const std::function& BlobDesc4BnInOp) { + const BlobDesc* in = BlobDesc4BnInOp("in"); CHECK_EQ_OR_RETURN(in->shape().elem_cnt(), 1); const DataType data_type = in->data_type(); CHECK_OR_RETURN(IsIntegralDataType(data_type)); - for (const std::string& obn : output_bns()) { - BlobDesc* out = GetBlobDesc4BnInOp(obn); + for (const std::string& obn : op.output_bns()) { + BlobDesc* out = BlobDesc4BnInOp(obn); out->mut_shape() = Shape({1}); out->set_data_type(data_type); } return Maybe::Ok(); } +} // namespace +Maybe CaseOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(*this, BlobDesc4BnInOp); +} + +Maybe CaseOp::InferOutBlobDescs( + std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { + return InferBlobDescs(*this, GetBlobDesc4BnInOp); +} + Maybe CaseOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { diff --git a/oneflow/core/operator/case_op.h b/oneflow/core/operator/case_op.h index c3e68dde02d192a0aecad16b82a2a14ddaf200ff..c13bc2d7034d2a96f8de77a69bcc3440aefcc4e0 100644 --- a/oneflow/core/operator/case_op.h +++ b/oneflow/core/operator/case_op.h @@ -27,6 +27,9 @@ class CaseOp final : public Operator { ~CaseOp() override = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; diff --git a/oneflow/core/operator/constant_like_op.cpp b/oneflow/core/operator/constant_like_op.cpp index 63abc6dfffed94bdce8f588872b648f645863bbd..81ff21b63e953f279df287c453369b9385948e74 100644 --- a/oneflow/core/operator/constant_like_op.cpp +++ b/oneflow/core/operator/constant_like_op.cpp @@ -17,6 +17,19 @@ limitations under the License. namespace oneflow { +namespace { + +Maybe InferBlobDescs(const OperatorConf& op_conf, + const std::function& BlobDesc4BnInOp) { + const ConstantLikeOpConf& conf = op_conf.constant_like_conf(); + BlobDesc* out_blob_desc = BlobDesc4BnInOp("out"); + *out_blob_desc = *BlobDesc4BnInOp("like"); + if (conf.has_data_type()) { out_blob_desc->set_data_type(conf.data_type()); } + return Maybe::Ok(); +} + +} // namespace + class ConstantLikeOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(ConstantLikeOp); @@ -29,14 +42,16 @@ class ConstantLikeOp final : public Operator { EnrollOutputBn("out", false); } + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override { + return InferBlobDescs(op_conf(), BlobDesc4BnInOp); + } + Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override { - const ConstantLikeOpConf& conf = op_conf().constant_like_conf(); - BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); - *out_blob_desc = *GetBlobDesc4BnInOp("like"); - if (conf.has_data_type()) { out_blob_desc->set_data_type(conf.data_type()); } - return Maybe::Ok(); + return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); } private: diff --git a/oneflow/core/operator/device_tick_op.cpp b/oneflow/core/operator/device_tick_op.cpp index d7275fc18371f561d16b8c112ce53681f65ca30d..8c0092e5ae9c16047e2b1c7e93f93c1392f47ac0 100644 --- a/oneflow/core/operator/device_tick_op.cpp +++ b/oneflow/core/operator/device_tick_op.cpp @@ -24,11 +24,25 @@ void DeviceTickOp::InitFromOpConf() { EnrollOutputBn("out", false); } +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + BlobDesc4BnInOp("out")->mut_shape() = Shape({1}); + return Maybe::Ok(); +} + +} // namespace + +Maybe DeviceTickOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(BlobDesc4BnInOp); +} + Maybe DeviceTickOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { - GetBlobDesc4BnInOp("out")->mut_shape() = Shape({1}); - return Maybe::Ok(); + return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe DeviceTickOp::GetSbpSignatures( diff --git a/oneflow/core/operator/device_tick_op.h b/oneflow/core/operator/device_tick_op.h index d475f6fdf7fbbd873238bd472be23028e86af413..7b6f111ca099d43d2d375cb6afa902ff409c8e33 100644 --- a/oneflow/core/operator/device_tick_op.h +++ b/oneflow/core/operator/device_tick_op.h @@ -28,6 +28,9 @@ class DeviceTickOp final : public Operator { ~DeviceTickOp() = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; diff --git a/oneflow/core/operator/distribute_add_op.cpp b/oneflow/core/operator/distribute_add_op.cpp index fa37eb0d08896e0c754067042afa452f3c0bf31b..65b047c9a5f7b68d18147c93fac74f2eb49c69e2 100644 --- a/oneflow/core/operator/distribute_add_op.cpp +++ b/oneflow/core/operator/distribute_add_op.cpp @@ -30,6 +30,9 @@ class DistributeAddOp final : public Operator { void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; @@ -67,6 +70,18 @@ Maybe DistributeAddOp::InferBlobParallelDesc() { return Maybe::Ok(); } +Maybe DistributeAddOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + const BlobDesc* in_0 = BlobDesc4BnInOp(input_bns().Get(0)); + FOR_RANGE(int, i, 1, output_bns().size()) { + const BlobDesc* in_i = BlobDesc4BnInOp(input_bns().Get(i)); + CHECK_OR_RETURN(*in_i == *in_0); + } + *BlobDesc4BnInOp("out") = *in_0; + return Maybe::Ok(); +} + Maybe DistributeAddOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { diff --git a/oneflow/core/operator/distribute_clone_op.cpp b/oneflow/core/operator/distribute_clone_op.cpp index 0fc5eb2017b98d4667b48837dcda8ca5fa91555c..90818fddc070a30344457647c3fbfe6bb09978fe 100644 --- a/oneflow/core/operator/distribute_clone_op.cpp +++ b/oneflow/core/operator/distribute_clone_op.cpp @@ -39,6 +39,9 @@ class DistributeCloneOp final : public Operator { const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; @@ -53,6 +56,17 @@ void DistributeCloneOp::InitFromOpConf() { }); } +Maybe DistributeCloneOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + const auto& in_blob_desc = *BlobDesc4BnInOp("in"); + FOR_RANGE(int, i, 0, output_bns().size()) { + BlobDesc* blob_desc = BlobDesc4BnInOp(output_bns().Get(i)); + *blob_desc = in_blob_desc; + } + return Maybe::Ok(); +} + Maybe DistributeCloneOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { diff --git a/oneflow/core/operator/distribute_concat_op.cpp b/oneflow/core/operator/distribute_concat_op.cpp index b17dccbda10d9dd6d5c4afcf02f1f92f92c43cdb..c8fb35068aa8ea92305b2c1e0fbe0d7255fbd20c 100644 --- a/oneflow/core/operator/distribute_concat_op.cpp +++ b/oneflow/core/operator/distribute_concat_op.cpp @@ -30,6 +30,9 @@ class DistributeConcatOp final : public Operator { void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; @@ -57,6 +60,30 @@ void DistributeConcatOp::InitFromOpConf() { EnrollOutputBn("out"); } +Maybe DistributeConcatOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + const auto& conf = op_conf().distribute_concat_conf(); + BlobDesc* out = BlobDesc4BnInOp("out"); + *out = *BlobDesc4BnInOp(input_bns().Get(0)); + const int32_t concat_axis = FixAxis(conf.axis(), out->shape().NumAxes()); + int64_t concat_dim_size = out->shape().At(concat_axis); + for (size_t i = 1; i < input_bns().size(); ++i) { + const BlobDesc* in_i = BlobDesc4BnInOp(input_bns().Get(i)); + for (int64_t j = 0; j < in_i->shape().NumAxes(); ++j) { + if (j == concat_axis) { + concat_dim_size += in_i->shape().At(j); + } else { + CHECK_EQ_OR_RETURN(out->shape().At(j), in_i->shape().At(j)); + } + } + CHECK_EQ_OR_RETURN(in_i->data_type(), out->data_type()); + } + out->mut_shape().Set(concat_axis, concat_dim_size); + out->set_is_dynamic(false); + return Maybe::Ok(); +} + Maybe DistributeConcatOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { diff --git a/oneflow/core/operator/distribute_split_op.cpp b/oneflow/core/operator/distribute_split_op.cpp index 45de46daac827b5b139b3cf6a1c77cc307d90734..b3d3c1c9739e969c619c44afc2d0350f828c9e03 100644 --- a/oneflow/core/operator/distribute_split_op.cpp +++ b/oneflow/core/operator/distribute_split_op.cpp @@ -39,6 +39,9 @@ class DistributeSplitOp final : public Operator { const std::function& CalcOrderValue4SbpSig, std::function(const std::string&)> SbpInferHint4Ibn, const ParallelDesc& parallel_desc) const override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; @@ -59,6 +62,22 @@ void DistributeSplitOp::InitFromOpConf() { }); } +Maybe DistributeSplitOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + const auto& in_blob_desc = *BlobDesc4BnInOp("in"); + CHECK_EQ(parallel_desc.parallel_num(), output_bns().size()); + const auto& conf = op_conf().distribute_split_conf(); + const int32_t split_axis = FixAxis(conf.axis(), in_blob_desc.shape().NumAxes()); + BalancedSplitter bs(in_blob_desc.shape().At(split_axis), parallel_desc.parallel_num()); + FOR_RANGE(int, i, 0, parallel_desc.parallel_num()) { + BlobDesc* out_blob_desc = BlobDesc4BnInOp(output_bns().Get(i)); + *out_blob_desc = in_blob_desc; + out_blob_desc->mut_shape().Set(split_axis, bs.At(i).size()); + } + return Maybe::Ok(); +} + Maybe DistributeSplitOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { diff --git a/oneflow/core/operator/dst_subset_tick_op.cpp b/oneflow/core/operator/dst_subset_tick_op.cpp index 013d6f242db713acf76c45b6c6ecb37085944749..f7fbc7980e12e7c3b9886f98d332b22bc3cccf0e 100644 --- a/oneflow/core/operator/dst_subset_tick_op.cpp +++ b/oneflow/core/operator/dst_subset_tick_op.cpp @@ -20,6 +20,15 @@ limitations under the License. namespace oneflow { +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + BlobDesc4BnInOp("out")->mut_shape() = Shape({1}); + return Maybe::Ok(); +} + +} // namespace + class DstSubsetTickOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(DstSubsetTickOp); @@ -27,6 +36,9 @@ class DstSubsetTickOp final : public Operator { ~DstSubsetTickOp() = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature*) const override; @@ -46,11 +58,16 @@ LogicalNode* DstSubsetTickOp::NewProperLogicalNode() const { return new DstSubsetTickLogicalNode(); } +Maybe DstSubsetTickOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(BlobDesc4BnInOp); +} + Maybe DstSubsetTickOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature*) const { - GetBlobDesc4BnInOp("out")->mut_shape() = Shape({1}); - return Maybe::Ok(); + return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe DstSubsetTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { diff --git a/oneflow/core/operator/esac_op.cpp b/oneflow/core/operator/esac_op.cpp index 949875d5b6538e5c1037f222616a4b0eb679dd1c..f894d4e719fb369086430bf8829a04f24bb8232d 100644 --- a/oneflow/core/operator/esac_op.cpp +++ b/oneflow/core/operator/esac_op.cpp @@ -24,17 +24,32 @@ void EsacOp::InitFromOpConf() { EnrollOutputBn("out", false); } -Maybe EsacOp::InferOutBlobDescs( - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { - BlobDesc* out = GetBlobDesc4BnInOp("out"); +namespace { + +Maybe InferBlobDescs(const OperatorConf& op_conf, + const std::function& BlobDesc4BnInOp) { + BlobDesc* out = BlobDesc4BnInOp("out"); out->mut_shape() = Shape({1}); - const DataType data_type = op_conf().esac_conf().data_type(); + const DataType data_type = op_conf.esac_conf().data_type(); CHECK_OR_RETURN(IsIntegralDataType(data_type)); out->set_data_type(data_type); return Maybe::Ok(); } +} // namespace + +Maybe EsacOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(op_conf(), BlobDesc4BnInOp); +} + +Maybe EsacOp::InferOutBlobDescs( + std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { + return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); +} + Maybe EsacOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { diff --git a/oneflow/core/operator/esac_op.h b/oneflow/core/operator/esac_op.h index 64fc3656387532c7e4b7a38a85d701c4c0e9d710..4751084bdcddadc23646128706a456bdde28f826 100644 --- a/oneflow/core/operator/esac_op.h +++ b/oneflow/core/operator/esac_op.h @@ -27,6 +27,9 @@ class EsacOp final : public Operator { ~EsacOp() override = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; diff --git a/oneflow/core/operator/foreign_input_op.cpp b/oneflow/core/operator/foreign_input_op.cpp index e7521c8684559757f8d8790daef98f9e8f5372f2..df4d2772215359cdfffc2d1bef87d09f968f3171 100644 --- a/oneflow/core/operator/foreign_input_op.cpp +++ b/oneflow/core/operator/foreign_input_op.cpp @@ -22,6 +22,22 @@ namespace { void CheckOpConf(const OperatorConf& op_conf) { CHECK(op_conf.ctrl_in_op_name().empty()); } +Maybe InferBlobDescs(const JobDesc& job_desc, const OperatorConf& op_conf, + const std::function& BlobDesc4BnInOp) { + CheckOpConf(op_conf); + const auto& conf = op_conf.foreign_input_conf().blob_conf(); + BlobDesc* out_blob_desc = BlobDesc4BnInOp("out"); + out_blob_desc->mut_shape() = Shape(conf.shape()); + if (conf.has_data_type()) { + out_blob_desc->set_data_type(conf.data_type()); + } else { + out_blob_desc->set_data_type(job_desc.DefaultDataType()); + } + out_blob_desc->set_is_dynamic(conf.is_dynamic()); + out_blob_desc->set_is_tensor_list(conf.is_tensor_list()); + return Maybe::Ok(); +} + } // namespace void ForeignInputOp::InitFromOpConf() { @@ -30,22 +46,18 @@ void ForeignInputOp::InitFromOpConf() { EnrollOutputBn("out", false); } +Maybe ForeignInputOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1); + return InferBlobDescs(job_desc(), op_conf(), BlobDesc4BnInOp); +} + Maybe ForeignInputOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1); - CheckOpConf(op_conf()); - const auto& conf = op_conf().foreign_input_conf().blob_conf(); - BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out"); - out_blob_desc->mut_shape() = Shape(conf.shape()); - if (conf.has_data_type()) { - out_blob_desc->set_data_type(conf.data_type()); - } else { - out_blob_desc->set_data_type(job_desc().DefaultDataType()); - } - out_blob_desc->set_is_dynamic(conf.is_dynamic()); - out_blob_desc->set_is_tensor_list(conf.is_tensor_list()); - return Maybe::Ok(); + return InferBlobDescs(job_desc(), op_conf(), GetBlobDesc4BnInOp); } Maybe ForeignInputOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { diff --git a/oneflow/core/operator/foreign_input_op.h b/oneflow/core/operator/foreign_input_op.h index 6302c584c5e5cf16808d16e39389318327eb7d28..c528ff563b13cc37ab30bb7bdda1c400f66a50f7 100644 --- a/oneflow/core/operator/foreign_input_op.h +++ b/oneflow/core/operator/foreign_input_op.h @@ -28,6 +28,9 @@ class ForeignInputOp final : public Operator { ~ForeignInputOp() = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; diff --git a/oneflow/core/operator/foreign_output_op.cpp b/oneflow/core/operator/foreign_output_op.cpp index 0ea16df4a10cb01c72cca1e4b3c67f393278a42b..6890624c9e84c562e537a6475c89ebfc8ec1ab3e 100644 --- a/oneflow/core/operator/foreign_output_op.cpp +++ b/oneflow/core/operator/foreign_output_op.cpp @@ -23,6 +23,13 @@ void ForeignOutputOp::InitFromOpConf() { EnrollInputBn("in"); } +Maybe ForeignOutputOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1); + return Maybe::Ok(); +} + Maybe ForeignOutputOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { diff --git a/oneflow/core/operator/foreign_output_op.h b/oneflow/core/operator/foreign_output_op.h index 000c63ca96c461b05242f66334f4c33b85e9cb4c..3fde8073a66355cc50db8135d3e227705958dbf5 100644 --- a/oneflow/core/operator/foreign_output_op.h +++ b/oneflow/core/operator/foreign_output_op.h @@ -28,6 +28,9 @@ class ForeignOutputOp final : public Operator { ~ForeignOutputOp() override = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; diff --git a/oneflow/core/operator/foreign_watch_op.cpp b/oneflow/core/operator/foreign_watch_op.cpp index af582671192933c04a2b3b4bf9cd078a7fb9434f..89583468a1848e397990b8ef17956afa64c9f470 100644 --- a/oneflow/core/operator/foreign_watch_op.cpp +++ b/oneflow/core/operator/foreign_watch_op.cpp @@ -23,6 +23,13 @@ void ForeignWatchOp::InitFromOpConf() { EnrollInputBn("in"); } +Maybe ForeignWatchOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1); + return Maybe::Ok(); +} + Maybe ForeignWatchOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { diff --git a/oneflow/core/operator/foreign_watch_op.h b/oneflow/core/operator/foreign_watch_op.h index 6114491ae38df5e3717ce55d9b7d6c690455b74d..7d40eab5cd98e48347b1cff433423dedba1651d1 100644 --- a/oneflow/core/operator/foreign_watch_op.h +++ b/oneflow/core/operator/foreign_watch_op.h @@ -28,6 +28,9 @@ class ForeignWatchOp final : public Operator { ~ForeignWatchOp() override = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; diff --git a/oneflow/core/operator/identity_op.cpp b/oneflow/core/operator/identity_op.cpp index 6463538384393b8551f7474ec86c6d16df35cc3d..d72ecfa5953dafa49d0a28c3ff09e056a1ed6d0e 100644 --- a/oneflow/core/operator/identity_op.cpp +++ b/oneflow/core/operator/identity_op.cpp @@ -20,6 +20,15 @@ limitations under the License. namespace oneflow { +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + *BlobDesc4BnInOp("out") = *BlobDesc4BnInOp("in"); + return Maybe::Ok(); +} + +} // namespace + template class IdentityOpTpl final : public Operator { public: @@ -31,11 +40,15 @@ class IdentityOpTpl final : public Operator { EnrollInputBn("in"); EnrollOutputBn("out")->set_const_inplace_ibn("in"); } + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override { + return InferBlobDescs(BlobDesc4BnInOp); + } Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override { - *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in"); - return Maybe::Ok(); + return InferBlobDescs(GetBlobDesc4BnInOp); } private: @@ -66,11 +79,15 @@ class MirroredCastOp : public Operator { EnrollInputBn("in"); EnrollOutputBn("out")->set_const_inplace_ibn("in"); } + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override { + return InferBlobDescs(BlobDesc4BnInOp); + } Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override { - *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in"); - return Maybe::Ok(); + return InferBlobDescs(GetBlobDesc4BnInOp); } private: diff --git a/oneflow/core/operator/indexed_slices_reduce_sum_op.cpp b/oneflow/core/operator/indexed_slices_reduce_sum_op.cpp index 695e0e8b8852c93e7babbfeb64c3fc18e064d22a..14dda459edc76721f955a48965157bac284e0281 100644 --- a/oneflow/core/operator/indexed_slices_reduce_sum_op.cpp +++ b/oneflow/core/operator/indexed_slices_reduce_sum_op.cpp @@ -26,6 +26,9 @@ class IndexedSlicesReduceSumOp final : public Operator { ~IndexedSlicesReduceSumOp() override = default; void InitFromOpConf() override; + virtual Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; @@ -49,11 +52,11 @@ void IndexedSlicesReduceSumOp::InitFromOpConf() { EnrollTmpBn("workspace"); } -Maybe IndexedSlicesReduceSumOp::InferOutBlobDescs( - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { - const BlobDesc* x_indices = GetBlobDesc4BnInOp("x_indices"); - const BlobDesc* x_values = GetBlobDesc4BnInOp("x_values"); +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + const BlobDesc* x_indices = BlobDesc4BnInOp("x_indices"); + const BlobDesc* x_values = BlobDesc4BnInOp("x_values"); CHECK_LT_OR_RETURN(x_indices->shape().NumAxes(), x_values->shape().NumAxes()); FOR_RANGE(int64_t, i, 0, x_indices->shape().NumAxes()) { CHECK_EQ_OR_RETURN(x_indices->shape().At(i), x_values->shape().At(i)); @@ -61,18 +64,32 @@ Maybe IndexedSlicesReduceSumOp::InferOutBlobDescs( CHECK_OR_RETURN(IsIndexDataType(x_indices->data_type())); const int64_t n = x_indices->shape().elem_cnt(); const int64_t m = x_values->shape().elem_cnt() / n; - BlobDesc* y_indices = GetBlobDesc4BnInOp("y_indices"); - BlobDesc* y_values = GetBlobDesc4BnInOp("y_values"); + BlobDesc* y_indices = BlobDesc4BnInOp("y_indices"); + BlobDesc* y_values = BlobDesc4BnInOp("y_values"); *y_indices = *x_indices; y_indices->mut_shape() = Shape({n}); *y_values = *x_values; y_values->mut_shape() = Shape({n, m}); - BlobDesc* num_unique = GetBlobDesc4BnInOp("num_unique"); + BlobDesc* num_unique = BlobDesc4BnInOp("num_unique"); num_unique->mut_shape() = Shape({1}); num_unique->set_data_type(DataType::kInt64); return Maybe::Ok(); } +} // namespace + +Maybe IndexedSlicesReduceSumOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(BlobDesc4BnInOp); +} + +Maybe IndexedSlicesReduceSumOp::InferOutBlobDescs( + std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { + return InferBlobDescs(GetBlobDesc4BnInOp); +} + Maybe IndexedSlicesReduceSumOp::InferInternalBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { diff --git a/oneflow/core/operator/input_op.cpp b/oneflow/core/operator/input_op.cpp index 00e92221e3be38eeb3ded0a6b2e1e1d166572c82..e34f495c1699b7670b4eb0d69b688fc28b9d7bac 100644 --- a/oneflow/core/operator/input_op.cpp +++ b/oneflow/core/operator/input_op.cpp @@ -28,6 +28,15 @@ void InputOp::InitFromOpConf() { modifier->set_header_infered_before_compute(false); } +Maybe InputOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + BlobDesc* out_blob_desc = BlobDesc4BnInOp("out"); + JUST(InterfaceOpUtil::InferLogicalOutBlobDesc(op_conf().input_conf().blob_conf(), out_blob_desc, + parallel_desc)); + return Maybe::Ok(); +} + Maybe InputOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { diff --git a/oneflow/core/operator/input_op.h b/oneflow/core/operator/input_op.h index f1b667e7386cc329ae03c03ec775f7e0cd27cb63..828bf3c52bc45e78367ebc2786a7f61a5ce57a7d 100644 --- a/oneflow/core/operator/input_op.h +++ b/oneflow/core/operator/input_op.h @@ -32,6 +32,9 @@ class InputOp final : public Operator { const SbpSignature* sbp_signature) const override; private: + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, diff --git a/oneflow/core/operator/interface_op_util.cpp b/oneflow/core/operator/interface_op_util.cpp index 8c669dba71093ee580bf533364c6a7a5cd2e3cf4..8928f6b27a907a4eb8a256d9c2fc0944d253401b 100644 --- a/oneflow/core/operator/interface_op_util.cpp +++ b/oneflow/core/operator/interface_op_util.cpp @@ -67,6 +67,18 @@ Maybe InterfaceOpUtil::InferOutBlobDesc(const InterfaceBlobConf& blob_conf return Maybe::Ok(); } +Maybe InterfaceOpUtil::InferLogicalOutBlobDesc(const InterfaceBlobConf& blob_conf, + BlobDesc* out_blob_desc, + const ParallelDesc& parallel_desc) { + out_blob_desc->mut_shape() = Shape(blob_conf.shape()); + CheckShape(out_blob_desc->shape()); + CHECK_GT(out_blob_desc->mut_shape().At(0), 0); + out_blob_desc->set_data_type(blob_conf.data_type()); + out_blob_desc->set_is_dynamic(blob_conf.is_dynamic()); + out_blob_desc->set_is_tensor_list(blob_conf.is_tensor_list()); + return Maybe::Ok(); +} + Maybe InterfaceOpUtil::GetInputLikeOpSbpSignature(const InterfaceBlobConf& blob_conf, const PbRpf& input_bns, const PbRpf& output_bns, diff --git a/oneflow/core/operator/interface_op_util.h b/oneflow/core/operator/interface_op_util.h index 07d39d9b98ab57934f76ccebe06281ebf0e176e9..9580d310d5ce1dd696e426bb402276989da165c2 100644 --- a/oneflow/core/operator/interface_op_util.h +++ b/oneflow/core/operator/interface_op_util.h @@ -25,6 +25,9 @@ namespace oneflow { struct InterfaceOpUtil final { static Maybe InferOutBlobDesc(const InterfaceBlobConf& blob_conf, BlobDesc* out_blob_desc, const ParallelContext* parallel_ctx); + static Maybe InferLogicalOutBlobDesc(const InterfaceBlobConf& blob_conf, + BlobDesc* out_blob_desc, + const ParallelDesc& parallel_desc); static Maybe GetInputLikeOpSbpSignature(const InterfaceBlobConf& blob_conf, const PbRpf& input_bns, const PbRpf& output_bns, diff --git a/oneflow/core/operator/learning_rate_schedule_op.cpp b/oneflow/core/operator/learning_rate_schedule_op.cpp index 7b3cd4be3175435513cf3030763678caabecfbaf..ec45c4322b4c0b6ee7aa1ffea9d4075596c7ea99 100644 --- a/oneflow/core/operator/learning_rate_schedule_op.cpp +++ b/oneflow/core/operator/learning_rate_schedule_op.cpp @@ -24,6 +24,9 @@ class LearningRateScheduleOp final : public Operator { ~LearningRateScheduleOp() override = default; void InitFromOpConf() override; + virtual Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; @@ -40,18 +43,32 @@ void LearningRateScheduleOp::InitFromOpConf() { EnrollOutputBn("out"); } -Maybe LearningRateScheduleOp::InferOutBlobDescs( - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { - const BlobDesc* train_step = GetBlobDesc4BnInOp("train_step"); +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + const BlobDesc* train_step = BlobDesc4BnInOp("train_step"); CHECK_EQ(train_step->shape().elem_cnt(), 1); CHECK_EQ(train_step->data_type(), DataType::kInt64); - BlobDesc* out = GetBlobDesc4BnInOp("out"); + BlobDesc* out = BlobDesc4BnInOp("out"); out->mut_shape() = Shape({1}); out->set_data_type(DataType::kFloat); return Maybe::Ok(); } +} // namespace + +Maybe LearningRateScheduleOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(BlobDesc4BnInOp); +} + +Maybe LearningRateScheduleOp::InferOutBlobDescs( + std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { + return InferBlobDescs(GetBlobDesc4BnInOp); +} + Maybe LearningRateScheduleOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { diff --git a/oneflow/core/operator/model_init_op.cpp b/oneflow/core/operator/model_init_op.cpp index 7715df230357641dada70c4d58ea0adc622471f3..620d2dba0389d160427a354beeea04873f7a46db 100644 --- a/oneflow/core/operator/model_init_op.cpp +++ b/oneflow/core/operator/model_init_op.cpp @@ -21,6 +21,9 @@ class ModelInitOp : public Operator { public: void InitFromOpConf() override; + virtual Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; @@ -37,20 +40,34 @@ void ModelInitOp::InitFromOpConf() { EnrollRepeatedOutputBn("out", false); } -Maybe ModelInitOp::InferOutBlobDescs( - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { - const int64_t num_out = op_conf().model_init_conf().out().size(); +namespace { + +Maybe InferBlobDescs(const OperatorConf& conf, + const std::function& BlobDesc4BnInOp) { + const int64_t num_out = conf.model_init_conf().out().size(); FOR_RANGE(int64_t, i, 0, num_out) { - const VariableOpConf& original_variable_conf = - op_conf().model_init_conf().original_variable_conf(i); - BlobDesc* out_i = GetBlobDesc4BnInOp(GenRepeatedBn("out", i)); + const VariableOpConf& original_variable_conf = conf.model_init_conf().original_variable_conf(i); + BlobDesc* out_i = BlobDesc4BnInOp(GenRepeatedBn("out", i)); out_i->mut_shape() = Shape(original_variable_conf.shape()); out_i->set_data_type(original_variable_conf.data_type()); } return Maybe::Ok(); } +} // namespace + +Maybe ModelInitOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(op_conf(), BlobDesc4BnInOp); +} + +Maybe ModelInitOp::InferOutBlobDescs( + std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { + return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); +} + Maybe ModelInitOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { diff --git a/oneflow/core/operator/model_load_op.cpp b/oneflow/core/operator/model_load_op.cpp index 6eab15980b062c2aabff81649f5baf12e96d4f12..6a7020cd258173f64ba98c1f4a60705b8d47f082 100644 --- a/oneflow/core/operator/model_load_op.cpp +++ b/oneflow/core/operator/model_load_op.cpp @@ -21,6 +21,9 @@ class ModelLoadOp : public Operator { public: void InitFromOpConf() override; + virtual Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; @@ -37,20 +40,35 @@ void ModelLoadOp::InitFromOpConf() { EnrollRepeatedOutputBn("out", false); } -Maybe ModelLoadOp::InferOutBlobDescs( - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { - const int64_t num_out = op_conf().model_load_conf().out().size(); +namespace { + +Maybe InferBlobDescs(const OperatorConf& op_conf, + const std::function& BlobDesc4BnInOp) { + const int64_t num_out = op_conf.model_load_conf().out().size(); FOR_RANGE(int64_t, i, 0, num_out) { const VariableOpConf& original_variable_conf = - op_conf().model_load_conf().original_variable_conf(i); - BlobDesc* out_i = GetBlobDesc4BnInOp(GenRepeatedBn("out", i)); + op_conf.model_load_conf().original_variable_conf(i); + BlobDesc* out_i = BlobDesc4BnInOp(GenRepeatedBn("out", i)); out_i->mut_shape() = Shape(original_variable_conf.shape()); out_i->set_data_type(original_variable_conf.data_type()); } return Maybe::Ok(); } +} // namespace + +Maybe ModelLoadOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(op_conf(), BlobDesc4BnInOp); +} + +Maybe ModelLoadOp::InferOutBlobDescs( + std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { + return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); +} + Maybe ModelLoadOp::GetSbpSignatures( const std::function(const std::string&)>& LogicalBlobDesc4Ibn, SbpSignatureList* sbp_sig_list) const { diff --git a/oneflow/core/operator/model_save_op.cpp b/oneflow/core/operator/model_save_op.cpp index 0a848b4b73b6d95e7673606316f3ceb5ede3fd8e..fd7e27f74fdbe646efc419d84e27dadbbea9414a 100644 --- a/oneflow/core/operator/model_save_op.cpp +++ b/oneflow/core/operator/model_save_op.cpp @@ -26,6 +26,11 @@ class ModelSaveOp final : public Operator { void InitFromOpConf() override; LogicalNode* NewProperLogicalNode() const override { return new PrintLogicalNode; } + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override { + return Maybe::Ok(); + } Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override { diff --git a/oneflow/core/operator/output_op.cpp b/oneflow/core/operator/output_op.cpp index d93d3df0887f993128b3a9028486a347263f3fe8..92284970b7e66c57dafcc415bad74b362bdf5fce 100644 --- a/oneflow/core/operator/output_op.cpp +++ b/oneflow/core/operator/output_op.cpp @@ -25,6 +25,15 @@ void OutputOp::InitFromOpConf() { EnrollOutputBn("out")->set_is_mutable(true); } +Maybe OutputOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + BlobDesc* out_blob_desc = BlobDesc4BnInOp("out"); + InterfaceOpUtil::InferLogicalOutBlobDesc(op_conf().output_conf().blob_conf(), out_blob_desc, + parallel_desc); + return Maybe::Ok(); +} + Maybe OutputOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { diff --git a/oneflow/core/operator/output_op.h b/oneflow/core/operator/output_op.h index 4025fa4ca9a45dcdbdb85077a91b3dc98b257613..48bc2a0b0997e34d7bbbab1451d93875ada9b632 100644 --- a/oneflow/core/operator/output_op.h +++ b/oneflow/core/operator/output_op.h @@ -32,6 +32,9 @@ class OutputOp final : public Operator { const SbpSignature* sbp_signature) const override; private: + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, diff --git a/oneflow/core/operator/reentrant_lock_op.cpp b/oneflow/core/operator/reentrant_lock_op.cpp index 9d57d3c7c61392343b2ba563a66c7dd5dd3ca7e0..ed005d5b31562fcbadb3cc291265501124953137 100644 --- a/oneflow/core/operator/reentrant_lock_op.cpp +++ b/oneflow/core/operator/reentrant_lock_op.cpp @@ -25,16 +25,32 @@ void ReentrantLockOp::InitFromOpConf() { EnrollOutputBn("out", false); } +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + const BlobDesc* start = BlobDesc4BnInOp("start"); + const DataType data_type = start->data_type(); + CHECK_OR_RETURN(IsIntegralDataType(data_type)); + BlobDesc* out = BlobDesc4BnInOp("out"); + out->mut_shape() = Shape({1}); + out->set_data_type(data_type); + return Maybe::Ok(); +} + +} // namespace + +Maybe ReentrantLockOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1); + return InferBlobDescs(BlobDesc4BnInOp); +} + Maybe ReentrantLockOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1); - BlobDesc* out = GetBlobDesc4BnInOp("out"); - out->mut_shape() = Shape({1}); - const DataType data_type = GetBlobDesc4BnInOp("out")->data_type(); - CHECK_OR_RETURN(IsIntegralDataType(data_type)); - out->set_data_type(data_type); - return Maybe::Ok(); + return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe ReentrantLockOp::GetSbpSignatures( diff --git a/oneflow/core/operator/reentrant_lock_op.h b/oneflow/core/operator/reentrant_lock_op.h index 38392639117a0adfca6974ecfddb3392bb944235..78cc687867fd75b8532fdf76790eb3359b971deb 100644 --- a/oneflow/core/operator/reentrant_lock_op.h +++ b/oneflow/core/operator/reentrant_lock_op.h @@ -27,6 +27,9 @@ class ReentrantLockOp final : public Operator { ~ReentrantLockOp() override = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; diff --git a/oneflow/core/operator/return_op.cpp b/oneflow/core/operator/return_op.cpp index 4e5139a03687f1899e888fbebc37ac6a6e312922..c9d8f7d94e5972c615fe33c7e838b6a80d18bbdd 100644 --- a/oneflow/core/operator/return_op.cpp +++ b/oneflow/core/operator/return_op.cpp @@ -25,11 +25,25 @@ void ReturnOp::InitFromOpConf() { EnrollOutputBn("out")->set_is_mutable(true); } +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + *BlobDesc4BnInOp("out") = *BlobDesc4BnInOp("in"); + return Maybe::Ok(); +} + +} // namespace + +Maybe ReturnOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(BlobDesc4BnInOp); +} + Maybe ReturnOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { - *GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in"); - return Maybe::Ok(); + return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe ReturnOp::InferSbpSignature( diff --git a/oneflow/core/operator/return_op.h b/oneflow/core/operator/return_op.h index 7c72723f3f6c5fa6866bf1a4be33711f63c2e05a..19ad8f59ca73903b26241ac4c630f6feb3b9c46b 100644 --- a/oneflow/core/operator/return_op.h +++ b/oneflow/core/operator/return_op.h @@ -27,6 +27,9 @@ class ReturnOp final : public Operator { ~ReturnOp() override = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; diff --git a/oneflow/core/operator/shape_elem_cnt_op.cpp b/oneflow/core/operator/shape_elem_cnt_op.cpp index ea630f443211c1695a05c1d15fcd93de3026b3f2..5d4ef0ceb23cc45a8836027be06520d2ae6d4e0a 100644 --- a/oneflow/core/operator/shape_elem_cnt_op.cpp +++ b/oneflow/core/operator/shape_elem_cnt_op.cpp @@ -54,12 +54,27 @@ void ShapeElemCntOp::InitFromOpConf() { EnrollOutputBn("y", false); } +namespace { + +Maybe InferBlobDescs(const OperatorConf& op_conf, + const std::function& BlobDesc4BnInOp) { + BlobDesc4BnInOp("y")->set_data_type(op_conf.shape_elem_cnt_conf().data_type()); + BlobDesc4BnInOp("y")->mut_shape() = Shape({1}); + return Maybe::Ok(); +} + +} // namespace + +Maybe ShapeElemCntOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(op_conf(), BlobDesc4BnInOp); +} + Maybe ShapeElemCntOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { - GetBlobDesc4BnInOp("y")->set_data_type(op_conf().shape_elem_cnt_conf().data_type()); - GetBlobDesc4BnInOp("y")->mut_shape() = Shape({1}); - return Maybe::Ok(); + return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); } void ShapeElemCntOp::VirtualGenKernelConf( diff --git a/oneflow/core/operator/shape_elem_cnt_op.h b/oneflow/core/operator/shape_elem_cnt_op.h index ab3b070b685ab8b6ec07b0d39d773ebaf805c1c6..2ce223b6550dcc01eb904dee11f4f0b976426068 100644 --- a/oneflow/core/operator/shape_elem_cnt_op.h +++ b/oneflow/core/operator/shape_elem_cnt_op.h @@ -27,6 +27,9 @@ class ShapeElemCntOp final : public Operator { ~ShapeElemCntOp() override = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; diff --git a/oneflow/core/operator/sink_tick_op.cpp b/oneflow/core/operator/sink_tick_op.cpp index b5f8dff578468c71bd355e51f83f46a051da7181..9a540f6d3e3b634e9ad23e247aca155e74aee570 100644 --- a/oneflow/core/operator/sink_tick_op.cpp +++ b/oneflow/core/operator/sink_tick_op.cpp @@ -24,11 +24,25 @@ void SinkTickOp::InitFromOpConf() { EnrollOutputBn("out", false); } +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + BlobDesc4BnInOp("out")->mut_shape() = Shape({1}); + return Maybe::Ok(); +} + +} // namespace + +Maybe SinkTickOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(BlobDesc4BnInOp); +} + Maybe SinkTickOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { - GetBlobDesc4BnInOp("out")->mut_shape() = Shape({1}); - return Maybe::Ok(); + return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe SinkTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { diff --git a/oneflow/core/operator/sink_tick_op.h b/oneflow/core/operator/sink_tick_op.h index 08a82bb6f72211997fb32a046a271fcd22205671..82435268eeb418daddd9589a69c63e9186c84a33 100644 --- a/oneflow/core/operator/sink_tick_op.h +++ b/oneflow/core/operator/sink_tick_op.h @@ -28,6 +28,9 @@ class SinkTickOp final : public Operator { ~SinkTickOp() = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; diff --git a/oneflow/core/operator/source_tick_op.cpp b/oneflow/core/operator/source_tick_op.cpp index fef198561a0536bd04a9b48a37a7c4c394833597..34c68695077b2b7f444e9b6229dfe9189e2c5d6e 100644 --- a/oneflow/core/operator/source_tick_op.cpp +++ b/oneflow/core/operator/source_tick_op.cpp @@ -26,6 +26,13 @@ void SourceTickOp::InitFromOpConf() { LogicalNode* SourceTickOp::NewProperLogicalNode() const { return new SourceTickLogicalNode(); } +Maybe SourceTickOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + BlobDesc4BnInOp("out")->mut_shape() = Shape({1}); + return Maybe::Ok(); +} + Maybe SourceTickOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { @@ -35,7 +42,7 @@ Maybe SourceTickOp::InferOutBlobDescs( } Maybe SourceTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { - SbpSignatureBuilder().Split(output_bns(), 0).Build(sbp_sig_list->mutable_sbp_signature()->Add()); + SbpSignatureBuilder().Broadcast(output_bns()).Build(sbp_sig_list->mutable_sbp_signature()->Add()); return Maybe::Ok(); } diff --git a/oneflow/core/operator/source_tick_op.h b/oneflow/core/operator/source_tick_op.h index 8674f05652108092a2a95a62fee8ee967448bac9..ed9adbd65c1d32b62be622e20ec1500ddbe39cc1 100644 --- a/oneflow/core/operator/source_tick_op.h +++ b/oneflow/core/operator/source_tick_op.h @@ -28,6 +28,9 @@ class SourceTickOp final : public Operator { ~SourceTickOp() = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; diff --git a/oneflow/core/operator/src_subset_tick_op.cpp b/oneflow/core/operator/src_subset_tick_op.cpp index 0dc0be61f5788b290bb8a87cdbf88a927b625f0a..ebbf536ba48fad035d4b45e0d0f5b0a92e5b1931 100644 --- a/oneflow/core/operator/src_subset_tick_op.cpp +++ b/oneflow/core/operator/src_subset_tick_op.cpp @@ -27,6 +27,9 @@ class SrcSubsetTickOp final : public Operator { ~SrcSubsetTickOp() = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature*) const override; @@ -46,11 +49,25 @@ LogicalNode* SrcSubsetTickOp::NewProperLogicalNode() const { return new SrcSubsetTickLogicalNode(); } +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + BlobDesc4BnInOp("out")->mut_shape() = Shape({1}); + return Maybe::Ok(); +} + +} // namespace + +Maybe SrcSubsetTickOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(BlobDesc4BnInOp); +} + Maybe SrcSubsetTickOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature*) const { - GetBlobDesc4BnInOp("out")->mut_shape() = Shape({1}); - return Maybe::Ok(); + return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe SrcSubsetTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { diff --git a/oneflow/core/operator/sync_dynamic_resize_op.cpp b/oneflow/core/operator/sync_dynamic_resize_op.cpp index 3ac194807ce4a25b5258d827d2300163b71b17d5..0d58cf0636dd7fdfc25cd5e5c2f7ea943bec8d48 100644 --- a/oneflow/core/operator/sync_dynamic_resize_op.cpp +++ b/oneflow/core/operator/sync_dynamic_resize_op.cpp @@ -17,6 +17,24 @@ limitations under the License. namespace oneflow { +namespace { + +Maybe InferBlobDescs(const OperatorConf& op_conf, + const std::function& BlobDesc4BnInOp) { + const SyncDynamicResizeOpConf& conf = op_conf.sync_dynamic_resize_conf(); + CHECK_EQ_OR_RETURN(conf.axis(), 0); + const BlobDesc* in = BlobDesc4BnInOp("in"); + const BlobDesc* size = BlobDesc4BnInOp("size"); + CHECK_EQ_OR_RETURN(size->shape().elem_cnt(), 1); + CHECK_OR_RETURN(IsIntegralDataType(size->data_type())); + BlobDesc* out = BlobDesc4BnInOp("out"); + *out = *in; + out->set_is_dynamic(true); + return Maybe::Ok(); +} + +} // namespace + class SyncDynamicResizeOp : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(SyncDynamicResizeOp); @@ -29,19 +47,16 @@ class SyncDynamicResizeOp : public Operator { EnrollOutputBn("out")->set_header_infered_before_compute(false); } + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override { + return InferBlobDescs(op_conf(), BlobDesc4BnInOp); + } + Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override { - const SyncDynamicResizeOpConf& conf = op_conf().sync_dynamic_resize_conf(); - CHECK_EQ_OR_RETURN(conf.axis(), 0); - const BlobDesc* in = GetBlobDesc4BnInOp("in"); - const BlobDesc* size = GetBlobDesc4BnInOp("size"); - CHECK_EQ_OR_RETURN(size->shape().elem_cnt(), 1); - CHECK_OR_RETURN(IsIntegralDataType(size->data_type())); - BlobDesc* out = GetBlobDesc4BnInOp("out"); - *out = *in; - out->set_is_dynamic(true); - return Maybe::Ok(); + return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); } Maybe GetSbpSignatures( diff --git a/oneflow/core/operator/tensor_buffer_to_tensor_list_op.cpp b/oneflow/core/operator/tensor_buffer_to_tensor_list_op.cpp index 8ffd778c8effeeab5d949bede4bc9fa7d6772a65..0ae99c86ecc30903a4c173c4e8e2d688a096b1a9 100644 --- a/oneflow/core/operator/tensor_buffer_to_tensor_list_op.cpp +++ b/oneflow/core/operator/tensor_buffer_to_tensor_list_op.cpp @@ -18,6 +18,26 @@ limitations under the License. namespace oneflow { +namespace { + +Maybe InferBlobDescs(const OperatorConf& op_conf, + const std::function& BlobDesc4BnInOp) { + const BlobDesc* in_desc = BlobDesc4BnInOp("in"); + CHECK_EQ_OR_RETURN(in_desc->data_type(), DataType::kTensorBuffer); + CHECK_EQ_OR_RETURN(in_desc->shape().NumAxes(), 1); + DimVector dim_vec = in_desc->shape().dim_vec(); + const ShapeProto& shape = op_conf.tensor_buffer_to_tensor_list_conf().shape(); + dim_vec.insert(dim_vec.end(), shape.dim().begin(), shape.dim().end()); + BlobDesc* out_desc = BlobDesc4BnInOp("out"); + out_desc->mut_shape() = Shape(dim_vec); + out_desc->set_data_type(op_conf.tensor_buffer_to_tensor_list_conf().data_type()); + out_desc->set_is_tensor_list(true); + out_desc->set_is_dynamic(true); + return Maybe::Ok(); +} + +} // namespace + class TensorBufferToTensorListOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(TensorBufferToTensorListOp); @@ -30,21 +50,16 @@ class TensorBufferToTensorListOp final : public Operator { EnrollOutputBn("out", false)->set_header_infered_before_compute(false); } + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override { + return InferBlobDescs(op_conf(), BlobDesc4BnInOp); + } + Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override { - const BlobDesc* in_desc = GetBlobDesc4BnInOp("in"); - CHECK_EQ_OR_RETURN(in_desc->data_type(), DataType::kTensorBuffer); - CHECK_EQ_OR_RETURN(in_desc->shape().NumAxes(), 1); - DimVector dim_vec = in_desc->shape().dim_vec(); - const ShapeProto& shape = op_conf().tensor_buffer_to_tensor_list_conf().shape(); - dim_vec.insert(dim_vec.end(), shape.dim().begin(), shape.dim().end()); - BlobDesc* out_desc = GetBlobDesc4BnInOp("out"); - out_desc->mut_shape() = Shape(dim_vec); - out_desc->set_data_type(op_conf().tensor_buffer_to_tensor_list_conf().data_type()); - out_desc->set_is_tensor_list(true); - out_desc->set_is_dynamic(true); - return Maybe::Ok(); + return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); } private: diff --git a/oneflow/core/operator/tensor_list_split_op.cpp b/oneflow/core/operator/tensor_list_split_op.cpp index f3484a64f0f0b13e3a6a7eb077c9ce2e5ff72e76..4ca569778146f8c247678a26ec4235fa60bee19a 100644 --- a/oneflow/core/operator/tensor_list_split_op.cpp +++ b/oneflow/core/operator/tensor_list_split_op.cpp @@ -18,6 +18,27 @@ limitations under the License. namespace oneflow { +namespace { + +Maybe InferBlobDescs(const Operator& op, + const std::function& BlobDesc4BnInOp) { + const BlobDesc* in_desc = BlobDesc4BnInOp(op.SoleIbn()); + CHECK_OR_RETURN(in_desc->is_tensor_list()); + CHECK_GT_OR_RETURN(in_desc->shape().NumAxes(), 1); + const int64_t N = in_desc->shape().At(0); + CHECK_EQ_OR_RETURN(N, op.output_bns().size()); + DimVector dim_vec{in_desc->shape().dim_vec().begin() + 1, in_desc->shape().dim_vec().end()}; + FOR_RANGE(int, i, 0, N) { + BlobDesc* out_i = BlobDesc4BnInOp(op.output_bns().Get(i)); + out_i->mut_shape() = Shape(dim_vec); + out_i->set_data_type(in_desc->data_type()); + out_i->set_is_dynamic(true); + } + return Maybe::Ok(); +} + +} // namespace + class TensorListSplitOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(TensorListSplitOp); @@ -32,22 +53,16 @@ class TensorListSplitOp final : public Operator { }); } + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override { + return InferBlobDescs(*this, BlobDesc4BnInOp); + } + Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override { - const BlobDesc* in_desc = GetBlobDesc4BnInOp(SoleIbn()); - CHECK_OR_RETURN(in_desc->is_tensor_list()); - CHECK_GT_OR_RETURN(in_desc->shape().NumAxes(), 1); - const int64_t N = in_desc->shape().At(0); - CHECK_EQ_OR_RETURN(N, output_bns().size()); - DimVector dim_vec{in_desc->shape().dim_vec().begin() + 1, in_desc->shape().dim_vec().end()}; - FOR_RANGE(int, i, 0, N) { - BlobDesc* out_i = GetBlobDesc4BnInOp(output_bns().Get(i)); - out_i->mut_shape() = Shape(dim_vec); - out_i->set_data_type(in_desc->data_type()); - out_i->set_is_dynamic(true); - } - return Maybe::Ok(); + return InferBlobDescs(*this, GetBlobDesc4BnInOp); } private: diff --git a/oneflow/core/operator/tensor_list_to_tensor_buffer_op.cpp b/oneflow/core/operator/tensor_list_to_tensor_buffer_op.cpp index 56bf1cfb7d14a9d68549d5e01dd6e6dc82b9e331..8fbe8ccbc6b1cf6c1fe07493d24ec5a3c25d2b38 100644 --- a/oneflow/core/operator/tensor_list_to_tensor_buffer_op.cpp +++ b/oneflow/core/operator/tensor_list_to_tensor_buffer_op.cpp @@ -18,6 +18,21 @@ limitations under the License. namespace oneflow { +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + const BlobDesc* in_desc = BlobDesc4BnInOp("in"); + CHECK_OR_RETURN(in_desc->is_tensor_list()); + const int64_t N = in_desc->shape().At(0); + BlobDesc* out_desc = BlobDesc4BnInOp("out"); + out_desc->mut_shape() = Shape({N}); + out_desc->set_data_type(DataType::kTensorBuffer); + out_desc->set_is_dynamic(in_desc->is_dynamic()); + return Maybe::Ok(); +} + +} // namespace + class TensorListToTensorBufferOp final : public Operator { public: OF_DISALLOW_COPY_AND_MOVE(TensorListToTensorBufferOp); @@ -30,17 +45,16 @@ class TensorListToTensorBufferOp final : public Operator { EnrollOutputBn("out", false)->set_header_infered_before_compute(false); } + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override { + return InferBlobDescs(BlobDesc4BnInOp); + } + Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override { - const BlobDesc* in_desc = GetBlobDesc4BnInOp("in"); - CHECK_OR_RETURN(in_desc->is_tensor_list()); - const int64_t N = in_desc->shape().At(0); - BlobDesc* out_desc = GetBlobDesc4BnInOp("out"); - out_desc->mut_shape() = Shape({N}); - out_desc->set_data_type(DataType::kTensorBuffer); - out_desc->set_is_dynamic(in_desc->is_dynamic()); - return Maybe::Ok(); + return InferBlobDescs(GetBlobDesc4BnInOp); } private: diff --git a/oneflow/core/operator/tick_op.cpp b/oneflow/core/operator/tick_op.cpp index 71a0f75bb3ebd2c9388066a6a17ebb778fd860d6..ad3fbba6d29bc0e9de65858e53f064bf6cdf5ac5 100644 --- a/oneflow/core/operator/tick_op.cpp +++ b/oneflow/core/operator/tick_op.cpp @@ -18,17 +18,31 @@ limitations under the License. namespace oneflow { +namespace { + +Maybe InferBlobDescs(const std::function& BlobDesc4BnInOp) { + BlobDesc4BnInOp("out")->mut_shape() = Shape({1}); + return Maybe::Ok(); +} + +} // namespace + void TickOp::InitFromOpConf() { CHECK(op_conf().has_tick_conf()); EnrollRepeatedInputBn("tick", false); EnrollOutputBn("out", false); } +Maybe TickOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(BlobDesc4BnInOp); +} + Maybe TickOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { - GetBlobDesc4BnInOp("out")->mut_shape() = Shape({1}); - return Maybe::Ok(); + return InferBlobDescs(GetBlobDesc4BnInOp); } Maybe TickOp::GetSbpSignatures( diff --git a/oneflow/core/operator/tick_op.h b/oneflow/core/operator/tick_op.h index 4d6a9f160ab2f91df176b184835ebdb9ac21fb70..adb4ebea293d014b368c58e879b1b0248c3ee163 100644 --- a/oneflow/core/operator/tick_op.h +++ b/oneflow/core/operator/tick_op.h @@ -28,6 +28,9 @@ class TickOp final : public Operator { ~TickOp() = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; diff --git a/oneflow/core/operator/unique_with_counts_op.cpp b/oneflow/core/operator/unique_with_counts_op.cpp index f5dc04d50b994a3f4b156a3ee83daf075b4b2766..89905f96171466b0fde7bff0ba2ea953b92fbdbe 100644 --- a/oneflow/core/operator/unique_with_counts_op.cpp +++ b/oneflow/core/operator/unique_with_counts_op.cpp @@ -26,6 +26,9 @@ class UniqueWithCountsOp final : public Operator { ~UniqueWithCountsOp() override = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override; @@ -49,27 +52,42 @@ void UniqueWithCountsOp::InitFromOpConf() { EnrollTmpBn("workspace"); } -Maybe UniqueWithCountsOp::InferOutBlobDescs( - std::function GetBlobDesc4BnInOp, - const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { - const BlobDesc* x = GetBlobDesc4BnInOp("x"); +namespace { + +Maybe InferBlobDescs(const OperatorConf& op_conf, + const std::function& BlobDesc4BnInOp) { + const BlobDesc* x = BlobDesc4BnInOp("x"); CHECK_EQ_OR_RETURN(x->shape().NumAxes(), 1); - BlobDesc* y = GetBlobDesc4BnInOp("y"); + BlobDesc* y = BlobDesc4BnInOp("y"); *y = *x; - const DataType idx_data_type = op_conf().unique_with_counts_conf().out_idx(); + const DataType idx_data_type = op_conf.unique_with_counts_conf().out_idx(); CHECK(IsIndexDataType(idx_data_type)); - BlobDesc* idx = GetBlobDesc4BnInOp("idx"); + BlobDesc* idx = BlobDesc4BnInOp("idx"); *idx = *x; idx->set_data_type(idx_data_type); - BlobDesc* count = GetBlobDesc4BnInOp("count"); + BlobDesc* count = BlobDesc4BnInOp("count"); *count = *x; count->set_data_type(idx_data_type); - BlobDesc* num_unique = GetBlobDesc4BnInOp("num_unique"); + BlobDesc* num_unique = BlobDesc4BnInOp("num_unique"); num_unique->mut_shape() = Shape({1}); num_unique->set_data_type(idx_data_type); return Maybe::Ok(); } +} // namespace + +Maybe UniqueWithCountsOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + return InferBlobDescs(op_conf(), BlobDesc4BnInOp); +} + +Maybe UniqueWithCountsOp::InferOutBlobDescs( + std::function GetBlobDesc4BnInOp, + const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { + return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); +} + Maybe UniqueWithCountsOp::InferInternalBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { diff --git a/oneflow/core/operator/variable_op.cpp b/oneflow/core/operator/variable_op.cpp index 5ff6ac3a4ce2199117eb05bb7839fb5a10b19c95..ee6da803015c24026d3f4f5a94ddcb7c1bf38832 100644 --- a/oneflow/core/operator/variable_op.cpp +++ b/oneflow/core/operator/variable_op.cpp @@ -43,6 +43,17 @@ void VariableOp::InitFromOpConf() { EnrollOutputBn("out", is_trainable)->set_is_mutable(true); } +Maybe VariableOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + const VariableOpConf& variable_conf = op_conf().variable_conf(); + BlobDesc* out_blob_desc = BlobDesc4BnInOp("out"); + out_blob_desc->mut_shape() = Shape(variable_conf.shape()); + out_blob_desc->set_data_type(variable_conf.has_data_type() ? variable_conf.data_type() + : job_desc().DefaultDataType()); + return Maybe::Ok(); +} + Maybe VariableOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { diff --git a/oneflow/core/operator/variable_op.h b/oneflow/core/operator/variable_op.h index 0d7331981e1188ec9e242bc4f9341edc722813d6..6f04ac6ea54467c8757461ed71b1abaf610ea92f 100644 --- a/oneflow/core/operator/variable_op.h +++ b/oneflow/core/operator/variable_op.h @@ -32,6 +32,9 @@ class VariableOp final : public Operator { const SbpSignature* sbp_signature) const override; private: + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferSbpSignature( SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf, const std::function& CalcOrderValue4SbpSig, diff --git a/oneflow/core/operator/wait_and_send_ids_op.cpp b/oneflow/core/operator/wait_and_send_ids_op.cpp index 7382925f5521a21aaca70863b7f0da20b828583c..60b809871e4a124e167b1f0994837abd0fa66f8c 100644 --- a/oneflow/core/operator/wait_and_send_ids_op.cpp +++ b/oneflow/core/operator/wait_and_send_ids_op.cpp @@ -28,17 +28,33 @@ LogicalNode* WaitAndSendIdsOp::NewProperLogicalNode() const { return new WaitAndSendIdsLogicalNode(); } +namespace { + +Maybe InferBlobDescs(const OperatorConf& op_conf, + const std::function& BlobDesc4BnInOp) { + BlobDesc4BnInOp("out")->mut_shape() = Shape({1}); + BlobDesc4BnInOp("out")->set_data_type(op_conf.wait_and_send_ids_conf().data_type()); + return Maybe::Ok(); +} + +} // namespace + +Maybe WaitAndSendIdsOp::InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const { + CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1); + return InferBlobDescs(op_conf(), BlobDesc4BnInOp); +} + Maybe WaitAndSendIdsOp::InferOutBlobDescs( std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const { CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1); - GetBlobDesc4BnInOp("out")->mut_shape() = Shape({1}); - GetBlobDesc4BnInOp("out")->set_data_type(op_conf().wait_and_send_ids_conf().data_type()); - return Maybe::Ok(); + return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp); } Maybe WaitAndSendIdsOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const { - SbpSignatureBuilder().Split(output_bns(), 0).Build(sbp_sig_list->mutable_sbp_signature()->Add()); + SbpSignatureBuilder().Broadcast(output_bns()).Build(sbp_sig_list->mutable_sbp_signature()->Add()); return Maybe::Ok(); } diff --git a/oneflow/core/operator/wait_and_send_ids_op.h b/oneflow/core/operator/wait_and_send_ids_op.h index cd919e31c1db05cb9f5f5429f0d5e8cef958a7fb..8d807e0853a28954fca9a2af68cb3ca75c75dc05 100644 --- a/oneflow/core/operator/wait_and_send_ids_op.h +++ b/oneflow/core/operator/wait_and_send_ids_op.h @@ -27,6 +27,9 @@ class WaitAndSendIdsOp final : public Operator { ~WaitAndSendIdsOp() = default; void InitFromOpConf() override; + Maybe InferLogicalOutBlobDescs( + const std::function& BlobDesc4BnInOp, + const ParallelDesc& parallel_desc) const override; Maybe InferOutBlobDescs(std::function GetBlobDesc4BnInOp, const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const override;