未验证 提交 db6ec194 编写于 作者: J Juncheng 提交者: GitHub

Add InferLogicalOutBlobDescs for system ops (#4271)

Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 a66b20e8
......@@ -24,6 +24,9 @@ class AssignOp final : public Operator {
~AssignOp() override = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......@@ -46,13 +49,27 @@ std::string DebugString(const BlobDesc& blob_desc) {
return blob_desc_proto.DebugString();
}
namespace {
Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
CHECK_OR_RETURN(*BlobDesc4BnInOp("ref") == *BlobDesc4BnInOp("value"))
<< "\nref_blob_desc: " << DebugString(*BlobDesc4BnInOp("ref"))
<< "\nvalue_blob_desc: " << DebugString(*BlobDesc4BnInOp("value"));
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> AssignOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(BlobDesc4BnInOp);
}
Maybe<void> AssignOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
CHECK_OR_RETURN(*GetBlobDesc4BnInOp("ref") == *GetBlobDesc4BnInOp("value"))
<< "\nref_blob_desc: " << DebugString(*GetBlobDesc4BnInOp("ref"))
<< "\nvalue_blob_desc: " << DebugString(*GetBlobDesc4BnInOp("value"));
return Maybe<void>::Ok();
return InferBlobDescs(GetBlobDesc4BnInOp);
}
Maybe<void> AssignOp::GetSbpSignatures(
......
......@@ -35,6 +35,21 @@ Maybe<void> GetBroadcastShape(const Shape& a_shape, const Shape& b_shape, Shape*
return Maybe<void>::Ok();
}
Maybe<void> InferBlobDescs(const OperatorConf& op_conf,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
int64_t num_compatibles = op_conf.broadcast_to_compatible_with_conf().compatible_size();
const BlobDesc* x_desc = BlobDesc4BnInOp("x");
Shape broadcasted_shape(x_desc->shape());
FOR_RANGE(int64_t, i, 0, num_compatibles) {
const BlobDesc* compatible_i = BlobDesc4BnInOp(GenRepeatedBn("compatible", i));
GetBroadcastShape(broadcasted_shape, compatible_i->shape(), &broadcasted_shape);
}
BlobDesc* y_desc = BlobDesc4BnInOp("y");
y_desc->CopyFrom(*x_desc);
y_desc->mut_shape() = broadcasted_shape;
return Maybe<void>::Ok();
}
} // namespace
class BroadcastToCompatibleWithOp final : public Operator {
......@@ -50,20 +65,16 @@ class BroadcastToCompatibleWithOp final : public Operator {
EnrollOutputBn("y");
}
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override {
return InferBlobDescs(op_conf(), BlobDesc4BnInOp);
}
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override {
int64_t num_compatibles = op_conf().broadcast_to_compatible_with_conf().compatible_size();
const BlobDesc* x_desc = GetBlobDesc4BnInOp("x");
Shape broadcasted_shape(x_desc->shape());
FOR_RANGE(int64_t, i, 0, num_compatibles) {
const BlobDesc* compatible_i = GetBlobDesc4BnInOp(GenRepeatedBn("compatible", i));
GetBroadcastShape(broadcasted_shape, compatible_i->shape(), &broadcasted_shape);
}
BlobDesc* y_desc = GetBlobDesc4BnInOp("y");
y_desc->CopyFrom(*x_desc);
y_desc->mut_shape() = broadcasted_shape;
return Maybe<void>::Ok();
return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);
}
private:
......
......@@ -28,17 +28,31 @@ LogicalNode* CallbackNotifyOp::NewProperLogicalNode() const {
return new CallbackNotifyLogicalNode();
}
namespace {
Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
CHECK_OR_RETURN(BlobDesc4BnInOp("in")->shape() == Shape({1}));
CHECK_OR_RETURN(IsIntegralDataType(BlobDesc4BnInOp("in")->data_type()));
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> CallbackNotifyOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1);
return InferBlobDescs(BlobDesc4BnInOp);
}
Maybe<void> CallbackNotifyOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1);
CHECK_OR_RETURN(GetBlobDesc4BnInOp("in")->shape() == Shape({1}));
CHECK_OR_RETURN(IsIntegralDataType(GetBlobDesc4BnInOp("in")->data_type()));
return Maybe<void>::Ok();
return InferBlobDescs(GetBlobDesc4BnInOp);
}
Maybe<void> CallbackNotifyOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {
SbpSignatureBuilder().Split(input_bns(), 0).Build(sbp_sig_list->mutable_sbp_signature()->Add());
return Maybe<void>::Ok();
}
......
......@@ -27,6 +27,9 @@ class CallbackNotifyOp final : public Operator {
~CallbackNotifyOp() = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......
......@@ -24,21 +24,35 @@ void CaseOp::InitFromOpConf() {
EnrollRepeatedOutputBn("out", false);
}
Maybe<void> CaseOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
const BlobDesc* in = GetBlobDesc4BnInOp("in");
namespace {
Maybe<void> InferBlobDescs(const Operator& op,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
const BlobDesc* in = BlobDesc4BnInOp("in");
CHECK_EQ_OR_RETURN(in->shape().elem_cnt(), 1);
const DataType data_type = in->data_type();
CHECK_OR_RETURN(IsIntegralDataType(data_type));
for (const std::string& obn : output_bns()) {
BlobDesc* out = GetBlobDesc4BnInOp(obn);
for (const std::string& obn : op.output_bns()) {
BlobDesc* out = BlobDesc4BnInOp(obn);
out->mut_shape() = Shape({1});
out->set_data_type(data_type);
}
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> CaseOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(*this, BlobDesc4BnInOp);
}
Maybe<void> CaseOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
return InferBlobDescs(*this, GetBlobDesc4BnInOp);
}
Maybe<void> CaseOp::GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const {
......
......@@ -27,6 +27,9 @@ class CaseOp final : public Operator {
~CaseOp() override = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......
......@@ -17,6 +17,19 @@ limitations under the License.
namespace oneflow {
namespace {
Maybe<void> InferBlobDescs(const OperatorConf& op_conf,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
const ConstantLikeOpConf& conf = op_conf.constant_like_conf();
BlobDesc* out_blob_desc = BlobDesc4BnInOp("out");
*out_blob_desc = *BlobDesc4BnInOp("like");
if (conf.has_data_type()) { out_blob_desc->set_data_type(conf.data_type()); }
return Maybe<void>::Ok();
}
} // namespace
class ConstantLikeOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(ConstantLikeOp);
......@@ -29,14 +42,16 @@ class ConstantLikeOp final : public Operator {
EnrollOutputBn("out", false);
}
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override {
return InferBlobDescs(op_conf(), BlobDesc4BnInOp);
}
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override {
const ConstantLikeOpConf& conf = op_conf().constant_like_conf();
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
*out_blob_desc = *GetBlobDesc4BnInOp("like");
if (conf.has_data_type()) { out_blob_desc->set_data_type(conf.data_type()); }
return Maybe<void>::Ok();
return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);
}
private:
......
......@@ -24,11 +24,25 @@ void DeviceTickOp::InitFromOpConf() {
EnrollOutputBn("out", false);
}
namespace {
Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
BlobDesc4BnInOp("out")->mut_shape() = Shape({1});
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> DeviceTickOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(BlobDesc4BnInOp);
}
Maybe<void> DeviceTickOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
GetBlobDesc4BnInOp("out")->mut_shape() = Shape({1});
return Maybe<void>::Ok();
return InferBlobDescs(GetBlobDesc4BnInOp);
}
Maybe<void> DeviceTickOp::GetSbpSignatures(
......
......@@ -28,6 +28,9 @@ class DeviceTickOp final : public Operator {
~DeviceTickOp() = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......
......@@ -30,6 +30,9 @@ class DistributeAddOp final : public Operator {
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......@@ -67,6 +70,18 @@ Maybe<void> DistributeAddOp::InferBlobParallelDesc() {
return Maybe<void>::Ok();
}
Maybe<void> DistributeAddOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
const BlobDesc* in_0 = BlobDesc4BnInOp(input_bns().Get(0));
FOR_RANGE(int, i, 1, output_bns().size()) {
const BlobDesc* in_i = BlobDesc4BnInOp(input_bns().Get(i));
CHECK_OR_RETURN(*in_i == *in_0);
}
*BlobDesc4BnInOp("out") = *in_0;
return Maybe<void>::Ok();
}
Maybe<void> DistributeAddOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
......
......@@ -39,6 +39,9 @@ class DistributeCloneOp final : public Operator {
const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,
std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......@@ -53,6 +56,17 @@ void DistributeCloneOp::InitFromOpConf() {
});
}
Maybe<void> DistributeCloneOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
const auto& in_blob_desc = *BlobDesc4BnInOp("in");
FOR_RANGE(int, i, 0, output_bns().size()) {
BlobDesc* blob_desc = BlobDesc4BnInOp(output_bns().Get(i));
*blob_desc = in_blob_desc;
}
return Maybe<void>::Ok();
}
Maybe<void> DistributeCloneOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
......
......@@ -30,6 +30,9 @@ class DistributeConcatOp final : public Operator {
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......@@ -57,6 +60,30 @@ void DistributeConcatOp::InitFromOpConf() {
EnrollOutputBn("out");
}
Maybe<void> DistributeConcatOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
const auto& conf = op_conf().distribute_concat_conf();
BlobDesc* out = BlobDesc4BnInOp("out");
*out = *BlobDesc4BnInOp(input_bns().Get(0));
const int32_t concat_axis = FixAxis(conf.axis(), out->shape().NumAxes());
int64_t concat_dim_size = out->shape().At(concat_axis);
for (size_t i = 1; i < input_bns().size(); ++i) {
const BlobDesc* in_i = BlobDesc4BnInOp(input_bns().Get(i));
for (int64_t j = 0; j < in_i->shape().NumAxes(); ++j) {
if (j == concat_axis) {
concat_dim_size += in_i->shape().At(j);
} else {
CHECK_EQ_OR_RETURN(out->shape().At(j), in_i->shape().At(j));
}
}
CHECK_EQ_OR_RETURN(in_i->data_type(), out->data_type());
}
out->mut_shape().Set(concat_axis, concat_dim_size);
out->set_is_dynamic(false);
return Maybe<void>::Ok();
}
Maybe<void> DistributeConcatOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
......
......@@ -39,6 +39,9 @@ class DistributeSplitOp final : public Operator {
const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,
std::function<Maybe<const SbpInferHint*>(const std::string&)> SbpInferHint4Ibn,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......@@ -59,6 +62,22 @@ void DistributeSplitOp::InitFromOpConf() {
});
}
Maybe<void> DistributeSplitOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
const auto& in_blob_desc = *BlobDesc4BnInOp("in");
CHECK_EQ(parallel_desc.parallel_num(), output_bns().size());
const auto& conf = op_conf().distribute_split_conf();
const int32_t split_axis = FixAxis(conf.axis(), in_blob_desc.shape().NumAxes());
BalancedSplitter bs(in_blob_desc.shape().At(split_axis), parallel_desc.parallel_num());
FOR_RANGE(int, i, 0, parallel_desc.parallel_num()) {
BlobDesc* out_blob_desc = BlobDesc4BnInOp(output_bns().Get(i));
*out_blob_desc = in_blob_desc;
out_blob_desc->mut_shape().Set(split_axis, bs.At(i).size());
}
return Maybe<void>::Ok();
}
Maybe<void> DistributeSplitOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
......
......@@ -20,6 +20,15 @@ limitations under the License.
namespace oneflow {
namespace {
Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
BlobDesc4BnInOp("out")->mut_shape() = Shape({1});
return Maybe<void>::Ok();
}
} // namespace
class DstSubsetTickOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(DstSubsetTickOp);
......@@ -27,6 +36,9 @@ class DstSubsetTickOp final : public Operator {
~DstSubsetTickOp() = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature*) const override;
......@@ -46,11 +58,16 @@ LogicalNode* DstSubsetTickOp::NewProperLogicalNode() const {
return new DstSubsetTickLogicalNode();
}
Maybe<void> DstSubsetTickOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(BlobDesc4BnInOp);
}
Maybe<void> DstSubsetTickOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature*) const {
GetBlobDesc4BnInOp("out")->mut_shape() = Shape({1});
return Maybe<void>::Ok();
return InferBlobDescs(GetBlobDesc4BnInOp);
}
Maybe<void> DstSubsetTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {
......
......@@ -24,17 +24,32 @@ void EsacOp::InitFromOpConf() {
EnrollOutputBn("out", false);
}
Maybe<void> EsacOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
BlobDesc* out = GetBlobDesc4BnInOp("out");
namespace {
Maybe<void> InferBlobDescs(const OperatorConf& op_conf,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
BlobDesc* out = BlobDesc4BnInOp("out");
out->mut_shape() = Shape({1});
const DataType data_type = op_conf().esac_conf().data_type();
const DataType data_type = op_conf.esac_conf().data_type();
CHECK_OR_RETURN(IsIntegralDataType(data_type));
out->set_data_type(data_type);
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> EsacOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(op_conf(), BlobDesc4BnInOp);
}
Maybe<void> EsacOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);
}
Maybe<void> EsacOp::GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const {
......
......@@ -27,6 +27,9 @@ class EsacOp final : public Operator {
~EsacOp() override = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......
......@@ -22,6 +22,22 @@ namespace {
void CheckOpConf(const OperatorConf& op_conf) { CHECK(op_conf.ctrl_in_op_name().empty()); }
Maybe<void> InferBlobDescs(const JobDesc& job_desc, const OperatorConf& op_conf,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
CheckOpConf(op_conf);
const auto& conf = op_conf.foreign_input_conf().blob_conf();
BlobDesc* out_blob_desc = BlobDesc4BnInOp("out");
out_blob_desc->mut_shape() = Shape(conf.shape());
if (conf.has_data_type()) {
out_blob_desc->set_data_type(conf.data_type());
} else {
out_blob_desc->set_data_type(job_desc.DefaultDataType());
}
out_blob_desc->set_is_dynamic(conf.is_dynamic());
out_blob_desc->set_is_tensor_list(conf.is_tensor_list());
return Maybe<void>::Ok();
}
} // namespace
void ForeignInputOp::InitFromOpConf() {
......@@ -30,22 +46,18 @@ void ForeignInputOp::InitFromOpConf() {
EnrollOutputBn("out", false);
}
Maybe<void> ForeignInputOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1);
return InferBlobDescs(job_desc(), op_conf(), BlobDesc4BnInOp);
}
Maybe<void> ForeignInputOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1);
CheckOpConf(op_conf());
const auto& conf = op_conf().foreign_input_conf().blob_conf();
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
out_blob_desc->mut_shape() = Shape(conf.shape());
if (conf.has_data_type()) {
out_blob_desc->set_data_type(conf.data_type());
} else {
out_blob_desc->set_data_type(job_desc().DefaultDataType());
}
out_blob_desc->set_is_dynamic(conf.is_dynamic());
out_blob_desc->set_is_tensor_list(conf.is_tensor_list());
return Maybe<void>::Ok();
return InferBlobDescs(job_desc(), op_conf(), GetBlobDesc4BnInOp);
}
Maybe<void> ForeignInputOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {
......
......@@ -28,6 +28,9 @@ class ForeignInputOp final : public Operator {
~ForeignInputOp() = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......
......@@ -23,6 +23,13 @@ void ForeignOutputOp::InitFromOpConf() {
EnrollInputBn("in");
}
Maybe<void> ForeignOutputOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1);
return Maybe<void>::Ok();
}
Maybe<void> ForeignOutputOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
......
......@@ -28,6 +28,9 @@ class ForeignOutputOp final : public Operator {
~ForeignOutputOp() override = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......
......@@ -23,6 +23,13 @@ void ForeignWatchOp::InitFromOpConf() {
EnrollInputBn("in");
}
Maybe<void> ForeignWatchOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1);
return Maybe<void>::Ok();
}
Maybe<void> ForeignWatchOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
......
......@@ -28,6 +28,9 @@ class ForeignWatchOp final : public Operator {
~ForeignWatchOp() override = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......
......@@ -20,6 +20,15 @@ limitations under the License.
namespace oneflow {
namespace {
Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
*BlobDesc4BnInOp("out") = *BlobDesc4BnInOp("in");
return Maybe<void>::Ok();
}
} // namespace
template<typename T>
class IdentityOpTpl final : public Operator {
public:
......@@ -31,11 +40,15 @@ class IdentityOpTpl final : public Operator {
EnrollInputBn("in");
EnrollOutputBn("out")->set_const_inplace_ibn("in");
}
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override {
return InferBlobDescs(BlobDesc4BnInOp);
}
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override {
*GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
return Maybe<void>::Ok();
return InferBlobDescs(GetBlobDesc4BnInOp);
}
private:
......@@ -66,11 +79,15 @@ class MirroredCastOp : public Operator {
EnrollInputBn("in");
EnrollOutputBn("out")->set_const_inplace_ibn("in");
}
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override {
return InferBlobDescs(BlobDesc4BnInOp);
}
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override {
*GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
return Maybe<void>::Ok();
return InferBlobDescs(GetBlobDesc4BnInOp);
}
private:
......
......@@ -26,6 +26,9 @@ class IndexedSlicesReduceSumOp final : public Operator {
~IndexedSlicesReduceSumOp() override = default;
void InitFromOpConf() override;
virtual Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......@@ -49,11 +52,11 @@ void IndexedSlicesReduceSumOp::InitFromOpConf() {
EnrollTmpBn("workspace");
}
Maybe<void> IndexedSlicesReduceSumOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
const BlobDesc* x_indices = GetBlobDesc4BnInOp("x_indices");
const BlobDesc* x_values = GetBlobDesc4BnInOp("x_values");
namespace {
Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
const BlobDesc* x_indices = BlobDesc4BnInOp("x_indices");
const BlobDesc* x_values = BlobDesc4BnInOp("x_values");
CHECK_LT_OR_RETURN(x_indices->shape().NumAxes(), x_values->shape().NumAxes());
FOR_RANGE(int64_t, i, 0, x_indices->shape().NumAxes()) {
CHECK_EQ_OR_RETURN(x_indices->shape().At(i), x_values->shape().At(i));
......@@ -61,18 +64,32 @@ Maybe<void> IndexedSlicesReduceSumOp::InferOutBlobDescs(
CHECK_OR_RETURN(IsIndexDataType(x_indices->data_type()));
const int64_t n = x_indices->shape().elem_cnt();
const int64_t m = x_values->shape().elem_cnt() / n;
BlobDesc* y_indices = GetBlobDesc4BnInOp("y_indices");
BlobDesc* y_values = GetBlobDesc4BnInOp("y_values");
BlobDesc* y_indices = BlobDesc4BnInOp("y_indices");
BlobDesc* y_values = BlobDesc4BnInOp("y_values");
*y_indices = *x_indices;
y_indices->mut_shape() = Shape({n});
*y_values = *x_values;
y_values->mut_shape() = Shape({n, m});
BlobDesc* num_unique = GetBlobDesc4BnInOp("num_unique");
BlobDesc* num_unique = BlobDesc4BnInOp("num_unique");
num_unique->mut_shape() = Shape({1});
num_unique->set_data_type(DataType::kInt64);
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> IndexedSlicesReduceSumOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(BlobDesc4BnInOp);
}
Maybe<void> IndexedSlicesReduceSumOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
return InferBlobDescs(GetBlobDesc4BnInOp);
}
Maybe<void> IndexedSlicesReduceSumOp::InferInternalBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
......
......@@ -28,6 +28,15 @@ void InputOp::InitFromOpConf() {
modifier->set_header_infered_before_compute(false);
}
Maybe<void> InputOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
BlobDesc* out_blob_desc = BlobDesc4BnInOp("out");
JUST(InterfaceOpUtil::InferLogicalOutBlobDesc(op_conf().input_conf().blob_conf(), out_blob_desc,
parallel_desc));
return Maybe<void>::Ok();
}
Maybe<void> InputOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
......
......@@ -32,6 +32,9 @@ class InputOp final : public Operator {
const SbpSignature* sbp_signature) const override;
private:
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferSbpSignature(
SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,
const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,
......
......@@ -67,6 +67,18 @@ Maybe<void> InterfaceOpUtil::InferOutBlobDesc(const InterfaceBlobConf& blob_conf
return Maybe<void>::Ok();
}
Maybe<void> InterfaceOpUtil::InferLogicalOutBlobDesc(const InterfaceBlobConf& blob_conf,
BlobDesc* out_blob_desc,
const ParallelDesc& parallel_desc) {
out_blob_desc->mut_shape() = Shape(blob_conf.shape());
CheckShape(out_blob_desc->shape());
CHECK_GT(out_blob_desc->mut_shape().At(0), 0);
out_blob_desc->set_data_type(blob_conf.data_type());
out_blob_desc->set_is_dynamic(blob_conf.is_dynamic());
out_blob_desc->set_is_tensor_list(blob_conf.is_tensor_list());
return Maybe<void>::Ok();
}
Maybe<void> InterfaceOpUtil::GetInputLikeOpSbpSignature(const InterfaceBlobConf& blob_conf,
const PbRpf<std::string>& input_bns,
const PbRpf<std::string>& output_bns,
......
......@@ -25,6 +25,9 @@ namespace oneflow {
struct InterfaceOpUtil final {
static Maybe<void> InferOutBlobDesc(const InterfaceBlobConf& blob_conf, BlobDesc* out_blob_desc,
const ParallelContext* parallel_ctx);
static Maybe<void> InferLogicalOutBlobDesc(const InterfaceBlobConf& blob_conf,
BlobDesc* out_blob_desc,
const ParallelDesc& parallel_desc);
static Maybe<void> GetInputLikeOpSbpSignature(const InterfaceBlobConf& blob_conf,
const PbRpf<std::string>& input_bns,
const PbRpf<std::string>& output_bns,
......
......@@ -24,6 +24,9 @@ class LearningRateScheduleOp final : public Operator {
~LearningRateScheduleOp() override = default;
void InitFromOpConf() override;
virtual Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......@@ -40,18 +43,32 @@ void LearningRateScheduleOp::InitFromOpConf() {
EnrollOutputBn("out");
}
Maybe<void> LearningRateScheduleOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
const BlobDesc* train_step = GetBlobDesc4BnInOp("train_step");
namespace {
Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
const BlobDesc* train_step = BlobDesc4BnInOp("train_step");
CHECK_EQ(train_step->shape().elem_cnt(), 1);
CHECK_EQ(train_step->data_type(), DataType::kInt64);
BlobDesc* out = GetBlobDesc4BnInOp("out");
BlobDesc* out = BlobDesc4BnInOp("out");
out->mut_shape() = Shape({1});
out->set_data_type(DataType::kFloat);
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> LearningRateScheduleOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(BlobDesc4BnInOp);
}
Maybe<void> LearningRateScheduleOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
return InferBlobDescs(GetBlobDesc4BnInOp);
}
Maybe<void> LearningRateScheduleOp::GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const {
......
......@@ -21,6 +21,9 @@ class ModelInitOp : public Operator {
public:
void InitFromOpConf() override;
virtual Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......@@ -37,20 +40,34 @@ void ModelInitOp::InitFromOpConf() {
EnrollRepeatedOutputBn("out", false);
}
Maybe<void> ModelInitOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
const int64_t num_out = op_conf().model_init_conf().out().size();
namespace {
Maybe<void> InferBlobDescs(const OperatorConf& conf,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
const int64_t num_out = conf.model_init_conf().out().size();
FOR_RANGE(int64_t, i, 0, num_out) {
const VariableOpConf& original_variable_conf =
op_conf().model_init_conf().original_variable_conf(i);
BlobDesc* out_i = GetBlobDesc4BnInOp(GenRepeatedBn("out", i));
const VariableOpConf& original_variable_conf = conf.model_init_conf().original_variable_conf(i);
BlobDesc* out_i = BlobDesc4BnInOp(GenRepeatedBn("out", i));
out_i->mut_shape() = Shape(original_variable_conf.shape());
out_i->set_data_type(original_variable_conf.data_type());
}
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> ModelInitOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(op_conf(), BlobDesc4BnInOp);
}
Maybe<void> ModelInitOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);
}
Maybe<void> ModelInitOp::GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const {
......
......@@ -21,6 +21,9 @@ class ModelLoadOp : public Operator {
public:
void InitFromOpConf() override;
virtual Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......@@ -37,20 +40,35 @@ void ModelLoadOp::InitFromOpConf() {
EnrollRepeatedOutputBn("out", false);
}
Maybe<void> ModelLoadOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
const int64_t num_out = op_conf().model_load_conf().out().size();
namespace {
Maybe<void> InferBlobDescs(const OperatorConf& op_conf,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
const int64_t num_out = op_conf.model_load_conf().out().size();
FOR_RANGE(int64_t, i, 0, num_out) {
const VariableOpConf& original_variable_conf =
op_conf().model_load_conf().original_variable_conf(i);
BlobDesc* out_i = GetBlobDesc4BnInOp(GenRepeatedBn("out", i));
op_conf.model_load_conf().original_variable_conf(i);
BlobDesc* out_i = BlobDesc4BnInOp(GenRepeatedBn("out", i));
out_i->mut_shape() = Shape(original_variable_conf.shape());
out_i->set_data_type(original_variable_conf.data_type());
}
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> ModelLoadOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(op_conf(), BlobDesc4BnInOp);
}
Maybe<void> ModelLoadOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);
}
Maybe<void> ModelLoadOp::GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const {
......
......@@ -26,6 +26,11 @@ class ModelSaveOp final : public Operator {
void InitFromOpConf() override;
LogicalNode* NewProperLogicalNode() const override { return new PrintLogicalNode; }
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override {
return Maybe<void>::Ok();
}
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override {
......
......@@ -25,6 +25,15 @@ void OutputOp::InitFromOpConf() {
EnrollOutputBn("out")->set_is_mutable(true);
}
Maybe<void> OutputOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
BlobDesc* out_blob_desc = BlobDesc4BnInOp("out");
InterfaceOpUtil::InferLogicalOutBlobDesc(op_conf().output_conf().blob_conf(), out_blob_desc,
parallel_desc);
return Maybe<void>::Ok();
}
Maybe<void> OutputOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
......
......@@ -32,6 +32,9 @@ class OutputOp final : public Operator {
const SbpSignature* sbp_signature) const override;
private:
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferSbpSignature(
SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,
const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,
......
......@@ -25,16 +25,32 @@ void ReentrantLockOp::InitFromOpConf() {
EnrollOutputBn("out", false);
}
namespace {
Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
const BlobDesc* start = BlobDesc4BnInOp("start");
const DataType data_type = start->data_type();
CHECK_OR_RETURN(IsIntegralDataType(data_type));
BlobDesc* out = BlobDesc4BnInOp("out");
out->mut_shape() = Shape({1});
out->set_data_type(data_type);
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> ReentrantLockOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1);
return InferBlobDescs(BlobDesc4BnInOp);
}
Maybe<void> ReentrantLockOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1);
BlobDesc* out = GetBlobDesc4BnInOp("out");
out->mut_shape() = Shape({1});
const DataType data_type = GetBlobDesc4BnInOp("out")->data_type();
CHECK_OR_RETURN(IsIntegralDataType(data_type));
out->set_data_type(data_type);
return Maybe<void>::Ok();
return InferBlobDescs(GetBlobDesc4BnInOp);
}
Maybe<void> ReentrantLockOp::GetSbpSignatures(
......
......@@ -27,6 +27,9 @@ class ReentrantLockOp final : public Operator {
~ReentrantLockOp() override = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......
......@@ -25,11 +25,25 @@ void ReturnOp::InitFromOpConf() {
EnrollOutputBn("out")->set_is_mutable(true);
}
namespace {
Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
*BlobDesc4BnInOp("out") = *BlobDesc4BnInOp("in");
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> ReturnOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(BlobDesc4BnInOp);
}
Maybe<void> ReturnOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
*GetBlobDesc4BnInOp("out") = *GetBlobDesc4BnInOp("in");
return Maybe<void>::Ok();
return InferBlobDescs(GetBlobDesc4BnInOp);
}
Maybe<void> ReturnOp::InferSbpSignature(
......
......@@ -27,6 +27,9 @@ class ReturnOp final : public Operator {
~ReturnOp() override = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......
......@@ -54,12 +54,27 @@ void ShapeElemCntOp::InitFromOpConf() {
EnrollOutputBn("y", false);
}
namespace {
Maybe<void> InferBlobDescs(const OperatorConf& op_conf,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
BlobDesc4BnInOp("y")->set_data_type(op_conf.shape_elem_cnt_conf().data_type());
BlobDesc4BnInOp("y")->mut_shape() = Shape({1});
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> ShapeElemCntOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(op_conf(), BlobDesc4BnInOp);
}
Maybe<void> ShapeElemCntOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
GetBlobDesc4BnInOp("y")->set_data_type(op_conf().shape_elem_cnt_conf().data_type());
GetBlobDesc4BnInOp("y")->mut_shape() = Shape({1});
return Maybe<void>::Ok();
return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);
}
void ShapeElemCntOp::VirtualGenKernelConf(
......
......@@ -27,6 +27,9 @@ class ShapeElemCntOp final : public Operator {
~ShapeElemCntOp() override = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......
......@@ -24,11 +24,25 @@ void SinkTickOp::InitFromOpConf() {
EnrollOutputBn("out", false);
}
namespace {
Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
BlobDesc4BnInOp("out")->mut_shape() = Shape({1});
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> SinkTickOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(BlobDesc4BnInOp);
}
Maybe<void> SinkTickOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
GetBlobDesc4BnInOp("out")->mut_shape() = Shape({1});
return Maybe<void>::Ok();
return InferBlobDescs(GetBlobDesc4BnInOp);
}
Maybe<void> SinkTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {
......
......@@ -28,6 +28,9 @@ class SinkTickOp final : public Operator {
~SinkTickOp() = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......
......@@ -26,6 +26,13 @@ void SourceTickOp::InitFromOpConf() {
LogicalNode* SourceTickOp::NewProperLogicalNode() const { return new SourceTickLogicalNode(); }
Maybe<void> SourceTickOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
BlobDesc4BnInOp("out")->mut_shape() = Shape({1});
return Maybe<void>::Ok();
}
Maybe<void> SourceTickOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
......@@ -35,7 +42,7 @@ Maybe<void> SourceTickOp::InferOutBlobDescs(
}
Maybe<void> SourceTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {
SbpSignatureBuilder().Split(output_bns(), 0).Build(sbp_sig_list->mutable_sbp_signature()->Add());
SbpSignatureBuilder().Broadcast(output_bns()).Build(sbp_sig_list->mutable_sbp_signature()->Add());
return Maybe<void>::Ok();
}
......
......@@ -28,6 +28,9 @@ class SourceTickOp final : public Operator {
~SourceTickOp() = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......
......@@ -27,6 +27,9 @@ class SrcSubsetTickOp final : public Operator {
~SrcSubsetTickOp() = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature*) const override;
......@@ -46,11 +49,25 @@ LogicalNode* SrcSubsetTickOp::NewProperLogicalNode() const {
return new SrcSubsetTickLogicalNode();
}
namespace {
Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
BlobDesc4BnInOp("out")->mut_shape() = Shape({1});
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> SrcSubsetTickOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(BlobDesc4BnInOp);
}
Maybe<void> SrcSubsetTickOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature*) const {
GetBlobDesc4BnInOp("out")->mut_shape() = Shape({1});
return Maybe<void>::Ok();
return InferBlobDescs(GetBlobDesc4BnInOp);
}
Maybe<void> SrcSubsetTickOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {
......
......@@ -17,6 +17,24 @@ limitations under the License.
namespace oneflow {
namespace {
Maybe<void> InferBlobDescs(const OperatorConf& op_conf,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
const SyncDynamicResizeOpConf& conf = op_conf.sync_dynamic_resize_conf();
CHECK_EQ_OR_RETURN(conf.axis(), 0);
const BlobDesc* in = BlobDesc4BnInOp("in");
const BlobDesc* size = BlobDesc4BnInOp("size");
CHECK_EQ_OR_RETURN(size->shape().elem_cnt(), 1);
CHECK_OR_RETURN(IsIntegralDataType(size->data_type()));
BlobDesc* out = BlobDesc4BnInOp("out");
*out = *in;
out->set_is_dynamic(true);
return Maybe<void>::Ok();
}
} // namespace
class SyncDynamicResizeOp : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(SyncDynamicResizeOp);
......@@ -29,19 +47,16 @@ class SyncDynamicResizeOp : public Operator {
EnrollOutputBn("out")->set_header_infered_before_compute(false);
}
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override {
return InferBlobDescs(op_conf(), BlobDesc4BnInOp);
}
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override {
const SyncDynamicResizeOpConf& conf = op_conf().sync_dynamic_resize_conf();
CHECK_EQ_OR_RETURN(conf.axis(), 0);
const BlobDesc* in = GetBlobDesc4BnInOp("in");
const BlobDesc* size = GetBlobDesc4BnInOp("size");
CHECK_EQ_OR_RETURN(size->shape().elem_cnt(), 1);
CHECK_OR_RETURN(IsIntegralDataType(size->data_type()));
BlobDesc* out = GetBlobDesc4BnInOp("out");
*out = *in;
out->set_is_dynamic(true);
return Maybe<void>::Ok();
return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);
}
Maybe<void> GetSbpSignatures(
......
......@@ -18,6 +18,26 @@ limitations under the License.
namespace oneflow {
namespace {
Maybe<void> InferBlobDescs(const OperatorConf& op_conf,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
const BlobDesc* in_desc = BlobDesc4BnInOp("in");
CHECK_EQ_OR_RETURN(in_desc->data_type(), DataType::kTensorBuffer);
CHECK_EQ_OR_RETURN(in_desc->shape().NumAxes(), 1);
DimVector dim_vec = in_desc->shape().dim_vec();
const ShapeProto& shape = op_conf.tensor_buffer_to_tensor_list_conf().shape();
dim_vec.insert(dim_vec.end(), shape.dim().begin(), shape.dim().end());
BlobDesc* out_desc = BlobDesc4BnInOp("out");
out_desc->mut_shape() = Shape(dim_vec);
out_desc->set_data_type(op_conf.tensor_buffer_to_tensor_list_conf().data_type());
out_desc->set_is_tensor_list(true);
out_desc->set_is_dynamic(true);
return Maybe<void>::Ok();
}
} // namespace
class TensorBufferToTensorListOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(TensorBufferToTensorListOp);
......@@ -30,21 +50,16 @@ class TensorBufferToTensorListOp final : public Operator {
EnrollOutputBn("out", false)->set_header_infered_before_compute(false);
}
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override {
return InferBlobDescs(op_conf(), BlobDesc4BnInOp);
}
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override {
const BlobDesc* in_desc = GetBlobDesc4BnInOp("in");
CHECK_EQ_OR_RETURN(in_desc->data_type(), DataType::kTensorBuffer);
CHECK_EQ_OR_RETURN(in_desc->shape().NumAxes(), 1);
DimVector dim_vec = in_desc->shape().dim_vec();
const ShapeProto& shape = op_conf().tensor_buffer_to_tensor_list_conf().shape();
dim_vec.insert(dim_vec.end(), shape.dim().begin(), shape.dim().end());
BlobDesc* out_desc = GetBlobDesc4BnInOp("out");
out_desc->mut_shape() = Shape(dim_vec);
out_desc->set_data_type(op_conf().tensor_buffer_to_tensor_list_conf().data_type());
out_desc->set_is_tensor_list(true);
out_desc->set_is_dynamic(true);
return Maybe<void>::Ok();
return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);
}
private:
......
......@@ -18,6 +18,27 @@ limitations under the License.
namespace oneflow {
namespace {
Maybe<void> InferBlobDescs(const Operator& op,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
const BlobDesc* in_desc = BlobDesc4BnInOp(op.SoleIbn());
CHECK_OR_RETURN(in_desc->is_tensor_list());
CHECK_GT_OR_RETURN(in_desc->shape().NumAxes(), 1);
const int64_t N = in_desc->shape().At(0);
CHECK_EQ_OR_RETURN(N, op.output_bns().size());
DimVector dim_vec{in_desc->shape().dim_vec().begin() + 1, in_desc->shape().dim_vec().end()};
FOR_RANGE(int, i, 0, N) {
BlobDesc* out_i = BlobDesc4BnInOp(op.output_bns().Get(i));
out_i->mut_shape() = Shape(dim_vec);
out_i->set_data_type(in_desc->data_type());
out_i->set_is_dynamic(true);
}
return Maybe<void>::Ok();
}
} // namespace
class TensorListSplitOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(TensorListSplitOp);
......@@ -32,22 +53,16 @@ class TensorListSplitOp final : public Operator {
});
}
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override {
return InferBlobDescs(*this, BlobDesc4BnInOp);
}
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override {
const BlobDesc* in_desc = GetBlobDesc4BnInOp(SoleIbn());
CHECK_OR_RETURN(in_desc->is_tensor_list());
CHECK_GT_OR_RETURN(in_desc->shape().NumAxes(), 1);
const int64_t N = in_desc->shape().At(0);
CHECK_EQ_OR_RETURN(N, output_bns().size());
DimVector dim_vec{in_desc->shape().dim_vec().begin() + 1, in_desc->shape().dim_vec().end()};
FOR_RANGE(int, i, 0, N) {
BlobDesc* out_i = GetBlobDesc4BnInOp(output_bns().Get(i));
out_i->mut_shape() = Shape(dim_vec);
out_i->set_data_type(in_desc->data_type());
out_i->set_is_dynamic(true);
}
return Maybe<void>::Ok();
return InferBlobDescs(*this, GetBlobDesc4BnInOp);
}
private:
......
......@@ -18,6 +18,21 @@ limitations under the License.
namespace oneflow {
namespace {
Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
const BlobDesc* in_desc = BlobDesc4BnInOp("in");
CHECK_OR_RETURN(in_desc->is_tensor_list());
const int64_t N = in_desc->shape().At(0);
BlobDesc* out_desc = BlobDesc4BnInOp("out");
out_desc->mut_shape() = Shape({N});
out_desc->set_data_type(DataType::kTensorBuffer);
out_desc->set_is_dynamic(in_desc->is_dynamic());
return Maybe<void>::Ok();
}
} // namespace
class TensorListToTensorBufferOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(TensorListToTensorBufferOp);
......@@ -30,17 +45,16 @@ class TensorListToTensorBufferOp final : public Operator {
EnrollOutputBn("out", false)->set_header_infered_before_compute(false);
}
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override {
return InferBlobDescs(BlobDesc4BnInOp);
}
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override {
const BlobDesc* in_desc = GetBlobDesc4BnInOp("in");
CHECK_OR_RETURN(in_desc->is_tensor_list());
const int64_t N = in_desc->shape().At(0);
BlobDesc* out_desc = GetBlobDesc4BnInOp("out");
out_desc->mut_shape() = Shape({N});
out_desc->set_data_type(DataType::kTensorBuffer);
out_desc->set_is_dynamic(in_desc->is_dynamic());
return Maybe<void>::Ok();
return InferBlobDescs(GetBlobDesc4BnInOp);
}
private:
......
......@@ -18,17 +18,31 @@ limitations under the License.
namespace oneflow {
namespace {
Maybe<void> InferBlobDescs(const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
BlobDesc4BnInOp("out")->mut_shape() = Shape({1});
return Maybe<void>::Ok();
}
} // namespace
void TickOp::InitFromOpConf() {
CHECK(op_conf().has_tick_conf());
EnrollRepeatedInputBn("tick", false);
EnrollOutputBn("out", false);
}
Maybe<void> TickOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(BlobDesc4BnInOp);
}
Maybe<void> TickOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
GetBlobDesc4BnInOp("out")->mut_shape() = Shape({1});
return Maybe<void>::Ok();
return InferBlobDescs(GetBlobDesc4BnInOp);
}
Maybe<void> TickOp::GetSbpSignatures(
......
......@@ -28,6 +28,9 @@ class TickOp final : public Operator {
~TickOp() = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......
......@@ -26,6 +26,9 @@ class UniqueWithCountsOp final : public Operator {
~UniqueWithCountsOp() override = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......@@ -49,27 +52,42 @@ void UniqueWithCountsOp::InitFromOpConf() {
EnrollTmpBn("workspace");
}
Maybe<void> UniqueWithCountsOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
const BlobDesc* x = GetBlobDesc4BnInOp("x");
namespace {
Maybe<void> InferBlobDescs(const OperatorConf& op_conf,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
const BlobDesc* x = BlobDesc4BnInOp("x");
CHECK_EQ_OR_RETURN(x->shape().NumAxes(), 1);
BlobDesc* y = GetBlobDesc4BnInOp("y");
BlobDesc* y = BlobDesc4BnInOp("y");
*y = *x;
const DataType idx_data_type = op_conf().unique_with_counts_conf().out_idx();
const DataType idx_data_type = op_conf.unique_with_counts_conf().out_idx();
CHECK(IsIndexDataType(idx_data_type));
BlobDesc* idx = GetBlobDesc4BnInOp("idx");
BlobDesc* idx = BlobDesc4BnInOp("idx");
*idx = *x;
idx->set_data_type(idx_data_type);
BlobDesc* count = GetBlobDesc4BnInOp("count");
BlobDesc* count = BlobDesc4BnInOp("count");
*count = *x;
count->set_data_type(idx_data_type);
BlobDesc* num_unique = GetBlobDesc4BnInOp("num_unique");
BlobDesc* num_unique = BlobDesc4BnInOp("num_unique");
num_unique->mut_shape() = Shape({1});
num_unique->set_data_type(idx_data_type);
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> UniqueWithCountsOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
return InferBlobDescs(op_conf(), BlobDesc4BnInOp);
}
Maybe<void> UniqueWithCountsOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);
}
Maybe<void> UniqueWithCountsOp::InferInternalBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
......
......@@ -43,6 +43,17 @@ void VariableOp::InitFromOpConf() {
EnrollOutputBn("out", is_trainable)->set_is_mutable(true);
}
Maybe<void> VariableOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
const VariableOpConf& variable_conf = op_conf().variable_conf();
BlobDesc* out_blob_desc = BlobDesc4BnInOp("out");
out_blob_desc->mut_shape() = Shape(variable_conf.shape());
out_blob_desc->set_data_type(variable_conf.has_data_type() ? variable_conf.data_type()
: job_desc().DefaultDataType());
return Maybe<void>::Ok();
}
Maybe<void> VariableOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
......
......@@ -32,6 +32,9 @@ class VariableOp final : public Operator {
const SbpSignature* sbp_signature) const override;
private:
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferSbpSignature(
SbpSignature* sbp_signature, const SbpSignature& sbp_sig_conf,
const std::function<int32_t(const SbpSignature&)>& CalcOrderValue4SbpSig,
......
......@@ -28,17 +28,33 @@ LogicalNode* WaitAndSendIdsOp::NewProperLogicalNode() const {
return new WaitAndSendIdsLogicalNode();
}
namespace {
Maybe<void> InferBlobDescs(const OperatorConf& op_conf,
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
BlobDesc4BnInOp("out")->mut_shape() = Shape({1});
BlobDesc4BnInOp("out")->set_data_type(op_conf.wait_and_send_ids_conf().data_type());
return Maybe<void>::Ok();
}
} // namespace
Maybe<void> WaitAndSendIdsOp::InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const {
CHECK_EQ_OR_RETURN(parallel_desc.parallel_num(), 1);
return InferBlobDescs(op_conf(), BlobDesc4BnInOp);
}
Maybe<void> WaitAndSendIdsOp::InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature) const {
CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), 1);
GetBlobDesc4BnInOp("out")->mut_shape() = Shape({1});
GetBlobDesc4BnInOp("out")->set_data_type(op_conf().wait_and_send_ids_conf().data_type());
return Maybe<void>::Ok();
return InferBlobDescs(op_conf(), GetBlobDesc4BnInOp);
}
Maybe<void> WaitAndSendIdsOp::GetSbpSignatures(SbpSignatureList* sbp_sig_list) const {
SbpSignatureBuilder().Split(output_bns(), 0).Build(sbp_sig_list->mutable_sbp_signature()->Add());
SbpSignatureBuilder().Broadcast(output_bns()).Build(sbp_sig_list->mutable_sbp_signature()->Add());
return Maybe<void>::Ok();
}
......
......@@ -27,6 +27,9 @@ class WaitAndSendIdsOp final : public Operator {
~WaitAndSendIdsOp() = default;
void InitFromOpConf() override;
Maybe<void> InferLogicalOutBlobDescs(
const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp,
const ParallelDesc& parallel_desc) const override;
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册