未验证 提交 53e28bf1 编写于 作者: J Juncheng 提交者: GitHub

Refactor InferParallelSignature (#4244)

Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 d7f9d086
......@@ -301,18 +301,9 @@ const ParallelDesc& OpNode::BlobParallelDesc4Obn(const std::string& obn) const {
}
void OpNode::InferBlobParallelDesc() {
auto ParallelDesc4Obn = [&](const std::string& obn) -> ParallelDesc* {
auto iter = obn2blob_parallel_desc_.find(obn);
if (iter == obn2blob_parallel_desc_.end()) {
iter = obn2blob_parallel_desc_.emplace(obn, parallel_desc()).first;
}
return &iter->second;
};
auto LogicalBlobDesc4Ibn = [&](const std::string& ibn) -> const BlobDesc* {
return &LogicalBlobDesc4Lbi(op().BnInOp2Lbi(ibn));
};
CHECK_JUST(op().InferOutParallelDescIf(ParallelDesc4Obn, LogicalBlobDesc4Ibn, parallel_desc(),
&sbp_signature()));
for (const auto& bn : op().output_bns()) {
obn2blob_parallel_desc_.emplace(bn, *CHECK_JUST(op().GetParallelDesc4BnInOp(bn)));
}
}
void OpNode::InitLbi2SourceNode() {
......@@ -557,6 +548,7 @@ Maybe<void> OpGraph::InferLogicalBlobDesc(const Job& job) const {
auto LogicalBlobDesc4BnInOp = [&](const std::string& bn) -> const BlobDesc& {
return op_node->LogicalBlobDesc4Lbi(op_node->op().BnInOp2Lbi(bn));
};
JUST(op_node->mut_op()->FillOpParallelDesc(op_node->parallel_desc()));
JUST(op_node->mut_op()->FillLogicalInBlobDesc(LogicalBlobDesc4BnInOp));
// Infer ParallelSignature
JUST(op_node->mut_op()->InferParallelSignatureIf());
......
......@@ -538,6 +538,8 @@ Maybe<OpAttribute> JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con
auto new_op_conf = JUST(DecodeLbiHintAndReturnNewOpConf(*op, &sbp_sig_conf, &ibn2disable_boxing));
AddOpAndUpdateJobParallelViewConf(*new_op_conf, sbp_sig_conf, is_mirrored_parallel_view);
auto parallel_conf = JUST(InferOpParallelConf(*op, origin_parallel_conf, ibn2disable_boxing));
ParallelDesc parallel_desc(*parallel_conf);
JUST(op->FillOpParallelDesc(parallel_desc));
JUST(AddOpNameParallelConf2Placement(op_name, *parallel_conf));
UpdateLbi2DisableBoxing(*op, ibn2disable_boxing);
// infer batch_axis
......@@ -557,7 +559,6 @@ Maybe<OpAttribute> JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con
JUST(op->FillInBatchAxis(BatchAxis4Ibn));
JUST(op->InferBatchAxisIf());
ParallelDesc parallel_desc(*parallel_conf);
// infer mirrored signature
JUST(InferMirroredSignature(op, is_mirrored_parallel_view, parallel_desc));
// infer sbp signature
......@@ -576,10 +577,11 @@ Maybe<OpAttribute> JobBuildAndInferCtx::AddAndInferOp(const OperatorConf& op_con
}
return &iter->second;
};
JUST(op->InferOutParallelDescIf(ParallelDesc4Obn, GetBlobDesc4BnInOp, parallel_desc,
JUST(op->sbp_signature())));
// TODO(lixinqi): replace lbi2parallel_desc_from_producer_view_ with ParallelSignature
JUST(op->InferParallelSignatureIf());
for (const auto& bn : op->output_bns()) {
lbi2parallel_desc_from_producer_view_.emplace(op->BnInOp2Lbi(bn),
*JUST(op->GetParallelDesc4BnInOp(bn)));
}
JUST(AddLbiParallelConf2BlobPlacement(op, ParallelDesc4Obn));
// Infer whether input/output blobs are backward used
InferBlobBackwardSignature(op);
......
......@@ -36,7 +36,7 @@ class DistributeAddOp final : public Operator {
LogicalNode* NewProperLogicalNode() const override { return new DistributeConcatLogicalNode; }
private:
Maybe<void> InferParallelSignature() override;
Maybe<void> InferBlobParallelDesc() override;
Maybe<void> InferBatchAxis(
std::function<OptInt64*(const std::string&)> BatchAxis4BnInOp) const override;
Maybe<void> InferSbpSignature(
......@@ -53,22 +53,19 @@ void DistributeAddOp::InitFromOpConf() {
EnrollOutputBn("out");
}
Maybe<void> DistributeAddOp::InferParallelSignature() {
const auto& scope_storage = *Global<symbol::Storage<Scope>>::Get();
const auto& scope = JUST(scope_storage.MaybeGet(op_conf().scope_symbol_id()));
int64_t op_parallel_desc_symbol_id = JUST(scope.GetParallelDescSymbolId(op_conf()));
mut_parallel_signature()->set_op_parallel_desc_symbol_id(op_parallel_desc_symbol_id);
auto* map = mut_parallel_signature()->mutable_bn_in_op2parallel_desc_symbol_id();
(*map)["out"] = op_parallel_desc_symbol_id;
const auto& op_parallel_desc = JUST(scope.GetParallelDesc(op_conf()));
CHECK_EQ(op_parallel_desc.parallel_num(), input_bns().size());
Maybe<void> DistributeAddOp::InferBlobParallelDesc() {
HashMap<std::string, std::shared_ptr<const ParallelDesc>> bn2parallel_desc;
const std::shared_ptr<const ParallelDesc> op_parallel_desc = JUST(GetOpParallelDesc());
FOR_RANGE(int, i, 0, input_bns().size()) {
const auto& in_parallel_conf = op_parallel_desc.GetParallelIdOnlyParallelConf(i);
const std::shared_ptr<cfg::ParallelConf>& cfg_in_parallel_conf =
std::make_shared<cfg::ParallelConf>(in_parallel_conf);
(*map)[input_bns().Get(i)] =
Global<ForeignCallback>::Get()->MakeParallelDescSymbol(cfg_in_parallel_conf);
bn2parallel_desc[input_bns().Get(i)] =
std::make_shared<const ParallelDesc>(op_parallel_desc->GetParallelIdOnlyParallelConf(i));
}
bn2parallel_desc["out"] = op_parallel_desc;
FillBlobParallelDesc([&](const std::string& bn) -> Maybe<const ParallelDesc> {
auto it = bn2parallel_desc.find(bn);
CHECK_OR_RETURN(it != bn2parallel_desc.end());
return it->second;
});
return Maybe<void>::Ok();
}
......
......@@ -33,7 +33,7 @@ class DistributeCloneOp final : public Operator {
LogicalNode* NewProperLogicalNode() const override { return new DistributeSplitLogicalNode; }
private:
Maybe<void> InferParallelSignature() override;
Maybe<void> InferBlobParallelDesc() override;
Maybe<void> InferBatchAxis(
std::function<OptInt64*(const std::string&)> BatchAxis4BnInOp) const override;
Maybe<void> InferSbpSignature(
......@@ -44,10 +44,6 @@ class DistributeCloneOp final : public Operator {
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
Maybe<void> InferOutParallelDesc(
std::function<ParallelDesc*(const std::string&)> ParallelDesc4Obn,
std::function<const BlobDesc*(const std::string&)> LogicalBlobDesc4Ibn, const ParallelDesc&,
const SbpSignature*) const override;
};
void DistributeCloneOp::InitFromOpConf() {
......@@ -76,38 +72,19 @@ Maybe<void> DistributeCloneOp::InferOutBlobDescs(
return Maybe<void>::Ok();
}
Maybe<void> DistributeCloneOp::InferOutParallelDesc(
std::function<ParallelDesc*(const std::string&)> ParallelDesc4Obn,
std::function<const BlobDesc*(const std::string&)> LogicalBlobDesc4Ibn,
const ParallelDesc& op_parallel_desc, const SbpSignature*) const {
Maybe<void> DistributeCloneOp::InferBlobParallelDesc() {
HashMap<std::string, std::shared_ptr<const ParallelDesc>> bn2parallel_desc;
const std::shared_ptr<const ParallelDesc> op_parallel_desc = JUST(GetOpParallelDesc());
bn2parallel_desc["in"] = op_parallel_desc;
FOR_RANGE(int, i, 0, output_bns().size()) {
const auto& obn = output_bns().Get(i);
if (op_parallel_desc.parallel_num() > 1) {
CHECK_EQ_OR_RETURN(op_parallel_desc.parallel_num(), output_bns().size());
*ParallelDesc4Obn(obn) = ParallelDesc(op_parallel_desc.GetParallelIdOnlyParallelConf(i));
} else {
*ParallelDesc4Obn(obn) = op_parallel_desc;
}
}
return Maybe<void>::Ok();
}
Maybe<void> DistributeCloneOp::InferParallelSignature() {
const auto& scope_storage = *Global<symbol::Storage<Scope>>::Get();
const auto& scope = JUST(scope_storage.MaybeGet(op_conf().scope_symbol_id()));
int64_t op_parallel_desc_symbol_id = JUST(scope.GetParallelDescSymbolId(op_conf()));
mut_parallel_signature()->set_op_parallel_desc_symbol_id(op_parallel_desc_symbol_id);
auto* map = mut_parallel_signature()->mutable_bn_in_op2parallel_desc_symbol_id();
(*map)["in"] = op_parallel_desc_symbol_id;
const auto& op_parallel_desc = JUST(scope.GetParallelDesc(op_conf()));
CHECK_EQ_OR_RETURN(op_parallel_desc.parallel_num(), output_bns().size());
FOR_RANGE(int, i, 0, output_bns().size()) {
const auto& out_parallel_conf = op_parallel_desc.GetParallelIdOnlyParallelConf(i);
const std::shared_ptr<cfg::ParallelConf>& cfg_out_parallel_conf =
std::make_shared<cfg::ParallelConf>(out_parallel_conf);
(*map)[output_bns().Get(i)] =
Global<ForeignCallback>::Get()->MakeParallelDescSymbol(cfg_out_parallel_conf);
bn2parallel_desc[output_bns().Get(i)] =
std::make_shared<const ParallelDesc>(op_parallel_desc->GetParallelIdOnlyParallelConf(i));
}
FillBlobParallelDesc([&](const std::string& bn) -> Maybe<const ParallelDesc> {
auto it = bn2parallel_desc.find(bn);
CHECK_OR_RETURN(it != bn2parallel_desc.end());
return it->second;
});
return Maybe<void>::Ok();
}
......
......@@ -36,7 +36,7 @@ class DistributeConcatOp final : public Operator {
LogicalNode* NewProperLogicalNode() const override { return new DistributeConcatLogicalNode; }
private:
Maybe<void> InferParallelSignature() override;
Maybe<void> InferBlobParallelDesc() override;
Maybe<void> InferBatchAxis(
std::function<OptInt64*(const std::string&)> BatchAxis4BnInOp) const override;
Maybe<void> InferSbpSignature(
......@@ -102,22 +102,19 @@ Maybe<void> DistributeConcatOp::InferOutBlobDescs(
return Maybe<void>::Ok();
}
Maybe<void> DistributeConcatOp::InferParallelSignature() {
const auto& scope_storage = *Global<symbol::Storage<Scope>>::Get();
const auto& scope = JUST(scope_storage.MaybeGet(op_conf().scope_symbol_id()));
int64_t op_parallel_desc_symbol_id = JUST(scope.GetParallelDescSymbolId(op_conf()));
mut_parallel_signature()->set_op_parallel_desc_symbol_id(op_parallel_desc_symbol_id);
auto* map = mut_parallel_signature()->mutable_bn_in_op2parallel_desc_symbol_id();
(*map)["out"] = op_parallel_desc_symbol_id;
const auto& op_parallel_desc = JUST(scope.GetParallelDesc(op_conf()));
CHECK_EQ(op_parallel_desc.parallel_num(), input_bns().size());
Maybe<void> DistributeConcatOp::InferBlobParallelDesc() {
HashMap<std::string, std::shared_ptr<const ParallelDesc>> bn2parallel_desc;
const std::shared_ptr<const ParallelDesc> op_parallel_desc = JUST(GetOpParallelDesc());
FOR_RANGE(int, i, 0, input_bns().size()) {
const auto& in_parallel_conf = op_parallel_desc.GetParallelIdOnlyParallelConf(i);
const std::shared_ptr<cfg::ParallelConf>& cfg_in_parallel_conf =
std::make_shared<cfg::ParallelConf>(in_parallel_conf);
(*map)[input_bns().Get(i)] =
Global<ForeignCallback>::Get()->MakeParallelDescSymbol(cfg_in_parallel_conf);
bn2parallel_desc[input_bns().Get(i)] =
std::make_shared<const ParallelDesc>(op_parallel_desc->GetParallelIdOnlyParallelConf(i));
}
bn2parallel_desc["out"] = op_parallel_desc;
FillBlobParallelDesc([&](const std::string& bn) -> Maybe<const ParallelDesc> {
auto it = bn2parallel_desc.find(bn);
CHECK_OR_RETURN(it != bn2parallel_desc.end());
return it->second;
});
return Maybe<void>::Ok();
}
......
......@@ -33,7 +33,7 @@ class DistributeSplitOp final : public Operator {
LogicalNode* NewProperLogicalNode() const override { return new DistributeSplitLogicalNode; }
private:
Maybe<void> InferParallelSignature() override;
Maybe<void> InferBlobParallelDesc() override;
Maybe<void> InferBatchAxis(
std::function<OptInt64*(const std::string&)> BatchAxis4BnInOp) const override;
Maybe<void> InferSbpSignature(
......@@ -44,10 +44,6 @@ class DistributeSplitOp final : public Operator {
Maybe<void> InferOutBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const override;
Maybe<void> InferOutParallelDesc(
std::function<ParallelDesc*(const std::string&)> ParallelDesc4Obn,
std::function<const BlobDesc*(const std::string&)> LogicalBlobDesc4Ibn, const ParallelDesc&,
const SbpSignature*) const override;
Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
......@@ -90,38 +86,19 @@ Maybe<void> DistributeSplitOp::InferOutBlobDescs(
return Maybe<void>::Ok();
}
Maybe<void> DistributeSplitOp::InferOutParallelDesc(
std::function<ParallelDesc*(const std::string&)> ParallelDesc4Obn,
std::function<const BlobDesc*(const std::string&)> LogicalBlobDesc4Ibn,
const ParallelDesc& op_parallel_desc, const SbpSignature*) const {
Maybe<void> DistributeSplitOp::InferBlobParallelDesc() {
HashMap<std::string, std::shared_ptr<const ParallelDesc>> bn2parallel_desc;
const std::shared_ptr<const ParallelDesc> op_parallel_desc = JUST(GetOpParallelDesc());
bn2parallel_desc["in"] = op_parallel_desc;
FOR_RANGE(int, i, 0, output_bns().size()) {
const auto& obn = output_bns().Get(i);
if (op_parallel_desc.parallel_num() > 1) {
CHECK_EQ(op_parallel_desc.parallel_num(), output_bns().size());
*ParallelDesc4Obn(obn) = ParallelDesc(op_parallel_desc.GetParallelIdOnlyParallelConf(i));
} else {
*ParallelDesc4Obn(obn) = op_parallel_desc;
}
}
return Maybe<void>::Ok();
}
Maybe<void> DistributeSplitOp::InferParallelSignature() {
const auto& scope_storage = *Global<symbol::Storage<Scope>>::Get();
const auto& scope = JUST(scope_storage.MaybeGet(op_conf().scope_symbol_id()));
int64_t op_parallel_desc_symbol_id = JUST(scope.GetParallelDescSymbolId(op_conf()));
mut_parallel_signature()->set_op_parallel_desc_symbol_id(op_parallel_desc_symbol_id);
auto* map = mut_parallel_signature()->mutable_bn_in_op2parallel_desc_symbol_id();
(*map)["in"] = op_parallel_desc_symbol_id;
const auto& op_parallel_desc = JUST(scope.GetParallelDesc(op_conf()));
CHECK_EQ(op_parallel_desc.parallel_num(), output_bns().size());
FOR_RANGE(int, i, 0, output_bns().size()) {
const auto& out_parallel_conf = op_parallel_desc.GetParallelIdOnlyParallelConf(i);
const std::shared_ptr<cfg::ParallelConf>& cfg_out_parallel_conf =
std::make_shared<cfg::ParallelConf>(out_parallel_conf);
(*map)[output_bns().Get(i)] =
Global<ForeignCallback>::Get()->MakeParallelDescSymbol(cfg_out_parallel_conf);
bn2parallel_desc[output_bns().Get(i)] =
std::make_shared<const ParallelDesc>(op_parallel_desc->GetParallelIdOnlyParallelConf(i));
}
FillBlobParallelDesc([&](const std::string& bn) -> Maybe<const ParallelDesc> {
auto it = bn2parallel_desc.find(bn);
CHECK_OR_RETURN(it != bn2parallel_desc.end());
return it->second;
});
return Maybe<void>::Ok();
}
......
......@@ -109,26 +109,26 @@ class ImageDecoderRandomCropResizeOp final : public Operator {
}
}
Maybe<void> InferParallelSignature() override {
Maybe<void> InferBlobParallelDesc() override {
HashMap<std::string, std::shared_ptr<const ParallelDesc>> bn2parallel_desc;
const std::shared_ptr<const ParallelDesc> op_parallel_desc = JUST(GetOpParallelDesc());
bn2parallel_desc["out"] = op_parallel_desc;
if (device_type() == DeviceType::kCPU) {
return Operator::InferParallelSignature();
bn2parallel_desc["in"] = op_parallel_desc;
} else if (device_type() == DeviceType::kGPU) {
const auto& scope_storage = *Global<symbol::Storage<Scope>>::Get();
const auto& scope = JUST(scope_storage.MaybeGet(op_conf().scope_symbol_id()));
const int64_t device_parallel_desc_symbol_id =
scope.scope_proto().device_parallel_desc_symbol_id();
const int64_t host_parallel_desc_symbol_id =
scope.scope_proto().host_parallel_desc_symbol_id();
mut_parallel_signature()->set_op_parallel_desc_symbol_id(device_parallel_desc_symbol_id);
auto* map = mut_parallel_signature()->mutable_bn_in_op2parallel_desc_symbol_id();
for (const auto& ibn : input_bns()) { (*map)[ibn] = host_parallel_desc_symbol_id; }
for (const auto& obn : output_bns()) { (*map)[obn] = device_parallel_desc_symbol_id; }
for (const auto& tbn : tmp_bns()) { (*map)[tbn] = device_parallel_desc_symbol_id; }
return Maybe<void>::Ok();
std::shared_ptr<ParallelDesc> in_parallel_desc =
std::make_shared<ParallelDesc>(*op_parallel_desc);
in_parallel_desc->set_device_type(DeviceType::kCPU);
bn2parallel_desc["in"] = in_parallel_desc;
} else {
UNIMPLEMENTED();
return Maybe<void>::Ok();
UNIMPLEMENTED_THEN_RETURN();
}
FillBlobParallelDesc([&](const std::string& bn) -> Maybe<const ParallelDesc> {
auto it = bn2parallel_desc.find(bn);
CHECK_OR_RETURN(it != bn2parallel_desc.end());
return it->second;
});
return Maybe<void>::Ok();
}
};
......
......@@ -23,6 +23,7 @@ limitations under the License.
#include "oneflow/core/job/scope.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/operator/op_node_signature.pb.h"
#include "oneflow/core/job/foreign_callback.h"
namespace oneflow {
......@@ -55,6 +56,9 @@ void Operator::Init(const OperatorConf& op_conf, const JobDesc* conf_job_desc) {
*this_op_conf = op_conf;
if (has_job_desc() && job_desc().IsPredict()) { this_op_conf->set_trainable(false); }
InitFromOpConf();
input_output_bns_.Reserve(input_bns().size() + output_bns().size());
for (const auto& bn : input_bns()) { *input_output_bns_.Add() = bn; }
for (const auto& bn : output_bns()) { *input_output_bns_.Add() = bn; }
}
LogicalNode* Operator::NewProperLogicalNode() const { return new NormalForwardLogicalNode; }
......@@ -98,23 +102,63 @@ Maybe<const std::string*> Operator::obn4lbi(const LogicalBlobId& lbi) const {
}
Maybe<void> Operator::InferParallelSignatureIf() {
JUST(InferBlobParallelDesc());
if (op_conf().scope_symbol_id() == 0) { return Maybe<void>::Ok(); }
return InferParallelSignature();
}
Maybe<void> Operator::InferParallelSignature() {
const auto& scope_storage = *Global<symbol::Storage<Scope>>::Get();
const auto& scope = JUST(scope_storage.MaybeGet(op_conf().scope_symbol_id()));
int64_t parallel_desc_symbol_id = JUST(scope.GetParallelDescSymbolId(op_conf()));
auto* parallel_signature = op_attribute_.mutable_parallel_signature();
parallel_signature->set_op_parallel_desc_symbol_id(parallel_desc_symbol_id);
auto* map = parallel_signature->mutable_bn_in_op2parallel_desc_symbol_id();
for (const auto& ibn : input_bns()) { (*map)[ibn] = parallel_desc_symbol_id; }
for (const auto& obn : output_bns()) { (*map)[obn] = parallel_desc_symbol_id; }
CHECK_OR_RETURN(op_parallel_desc_);
CHECK_OR_RETURN(bn2parallel_desc_);
for (const auto& pair : *bn2parallel_desc_) {
if (*pair.second == *op_parallel_desc_) {
(*map)[pair.first] = parallel_desc_symbol_id;
} else {
(*map)[pair.first] = Global<ForeignCallback>::Get()->MakeParallelDescSymbol(
std::make_shared<cfg::ParallelConf>(pair.second->parallel_conf()));
}
}
// TODO(liujuncheng): remove this
for (const auto& tbn : tmp_bns()) { (*map)[tbn] = parallel_desc_symbol_id; }
return Maybe<void>::Ok();
}
Maybe<const ParallelDesc> Operator::GetParallelDesc4BnInOp(const std::string& bn) const {
CHECK_OR_RETURN(bn2parallel_desc_);
auto it = bn2parallel_desc_->find(bn);
CHECK_OR_RETURN(it != bn2parallel_desc_->end());
return it->second;
}
Maybe<void> Operator::FillBlobParallelDesc(
const std::function<Maybe<const ParallelDesc>(const std::string&)>& ParallelDesc4Bn) {
CHECK_OR_RETURN(!bn2parallel_desc_);
bn2parallel_desc_.reset(new HashMap<std::string, std::shared_ptr<const ParallelDesc>>);
for (const auto& bn : input_output_bns()) {
CHECK(bn2parallel_desc_->emplace(bn, JUST(ParallelDesc4Bn(bn))).second);
}
return Maybe<void>::Ok();
}
Maybe<void> Operator::InferBlobParallelDesc() {
FillBlobParallelDesc(
[&](const std::string& bn) -> Maybe<const ParallelDesc> { return GetOpParallelDesc(); });
return Maybe<void>::Ok();
}
Maybe<void> Operator::FillOpParallelDesc(const ParallelDesc& parallel_desc) {
CHECK_OR_RETURN(!op_parallel_desc_);
op_parallel_desc_.reset(new ParallelDesc(parallel_desc));
return Maybe<void>::Ok();
}
Maybe<const ParallelDesc> Operator::GetOpParallelDesc() const {
CHECK_OR_RETURN(op_parallel_desc_);
return op_parallel_desc_;
}
namespace {
Maybe<void> FillLogicalBlobDesc(
......@@ -286,22 +330,6 @@ Maybe<void> Operator::InferInplaceObn2Ibn(
return Maybe<void>::Ok();
}
Maybe<void> Operator::InferOutParallelDescIf(
std::function<ParallelDesc*(const std::string&)> ParallelDesc4Obn,
std::function<const BlobDesc*(const std::string&)> LogicalBlobDesc4Ibn,
const ParallelDesc& op_parallel_desc, const SbpSignature* sbp_signature) const {
return InferOutParallelDesc(ParallelDesc4Obn, LogicalBlobDesc4Ibn, op_parallel_desc,
sbp_signature);
}
Maybe<void> Operator::InferOutParallelDesc(
std::function<ParallelDesc*(const std::string&)> ParallelDesc4Obn,
std::function<const BlobDesc*(const std::string&)> LogicalBlobDesc4Ibn,
const ParallelDesc& op_parallel_desc, const SbpSignature* sbp_signature) const {
for (const auto& obn : output_bns()) { *ParallelDesc4Obn(obn) = op_parallel_desc; }
return Maybe<void>::Ok();
}
Maybe<void> Operator::InferOutputBlobTimeShapeIf(
std::function<const Shape*(const std::string&)> GetTimeShape4BnInOp,
const ParallelContext* parallel_ctx, Shape* time_shape) const {
......@@ -1091,6 +1119,7 @@ Maybe<Operator> ConstructAndInferOp(const OperatorConf& op_conf,
bool is_mirrored = scope.opt_mirrored_parallel_conf().has_mirrored_parallel();
const auto& op = ConstructOp(op_conf, JUST(scope.job_desc()));
JUST(CheckOpInputSignature(*op, upstream_signature));
JUST(op->FillOpParallelDesc(parallel_desc));
HashMap<std::string, std::unique_ptr<BlobDesc>> bn_in_op2blob_desc;
for (const auto& ibn : op->input_bns()) {
const auto& map = upstream_signature.logical_blob_desc_signature().bn_in_op2blob_desc();
......
......@@ -87,7 +87,13 @@ class Operator {
#undef DEFINE_BLOB_NAMES_GETTER
const PbRpf<std::string>& input_output_bns() const { return input_output_bns_; };
Maybe<void> FillOpParallelDesc(const ParallelDesc& parallel_desc);
Maybe<const ParallelDesc> GetOpParallelDesc() const;
Maybe<void> InferParallelSignatureIf();
Maybe<const ParallelDesc> GetParallelDesc4BnInOp(const std::string& bn) const;
// Read: shape of input_blobs
// Write: shape of output_blobs
......@@ -128,15 +134,6 @@ class Operator {
const ParallelContext* parallel_ctx,
const SbpSignature* sbp_signature) const;
Maybe<void> InferOutParallelDescIf(
std::function<ParallelDesc*(const std::string&)> ParallelDesc4Obn,
std::function<const BlobDesc*(const std::string&)> LogicalBlobDesc4Ibn, const ParallelDesc&,
const SbpSignature*) const;
virtual Maybe<void> InferOutParallelDesc(
std::function<ParallelDesc*(const std::string&)> ParallelDesc4Obn,
std::function<const BlobDesc*(const std::string&)> LogicalBlobDesc4Ibn, const ParallelDesc&,
const SbpSignature*) const;
Maybe<void> FillInBatchAxis(
const std::function<Maybe<const OptInt64*>(const std::string&)>& BatchAxis4BnInOp);
Maybe<void> FillOutBatchAxis(
......@@ -202,7 +199,9 @@ class Operator {
}
protected:
virtual Maybe<void> InferParallelSignature();
Maybe<void> FillBlobParallelDesc(
const std::function<Maybe<const ParallelDesc>(const std::string&)>& ParallelDesc4Bn);
virtual Maybe<void> InferBlobParallelDesc();
virtual Maybe<void> InferOutBlobDescs(
std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp, const ParallelContext*,
const SbpSignature* sbp_signature) const;
......@@ -320,10 +319,14 @@ class Operator {
OpAttribute op_attribute_;
const JobDesc* job_desc_;
HashMap<LogicalBlobId, std::string> lbi2obn_;
std::shared_ptr<const ParallelDesc> op_parallel_desc_;
std::unique_ptr<HashMap<std::string, std::shared_ptr<const ParallelDesc>>> bn2parallel_desc_;
std::unique_ptr<HashMap<std::string, std::shared_ptr<const BlobDesc>>> ibn2logical_blob_desc_;
std::unique_ptr<HashMap<std::string, std::shared_ptr<const BlobDesc>>> obn2logical_blob_desc_;
std::unique_ptr<HashMap<std::string, std::shared_ptr<const OptInt64>>> ibn2batch_axis_;
std::unique_ptr<HashMap<std::string, std::shared_ptr<const OptInt64>>> obn2batch_axis_;
PbRpf<std::string> input_output_bns_;
};
std::string GenRepeatedBn(const std::string& bn_prefix, int32_t idx);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册