提交 13048516 编写于 作者: W willzhang4a58

special ibn2lbn

上级 49523253
......@@ -24,15 +24,16 @@ void CompTaskNode::DataFwBuildExecAndProducedRegsts(Path* path) {
if (GetBpNode() != nullptr) {
FwAddCopyInOp(&extern_in_lbn2consumers);
}
FwAddCloneOp();
mut_exec_gph().UpdateSourceAndSink();
// data regst
std::unique_ptr<RegstDesc> data_regst(new DisContigRegstDesc);
BindProducedRegstAndOutEdge(data_regst.get(), SoleOutEdge());
EnrollProducedRegstDesc("data", std::move(data_regst));
FwSetDataRegstDesc(lbn2producer, extern_in_lbn2consumers);
// model_tmp regst
std::unique_ptr<RegstDesc> model_tmp_regst(new DisContigRegstDesc);
EnrollProducedRegstDesc("model_tmp", std::move(model_tmp_regst));
FwSetModelTmpRegstDesc();
FwSetDataRegstDesc(lbn2producer, extern_in_lbn2consumers);
}
void CompTaskNode::ModelUpdateFwBuildExecAndProducedRegsts(Path* path) {
......@@ -125,52 +126,6 @@ void CompTaskNode::FwAddCopyInOp(Lbn2NodeIbnVecMap* extern_in_lbn2consumers) {
}
}
void CompTaskNode::FwAddCloneOp() {
std::vector<CloneInfo> clone_info_vec;
CollectCloneInfoVec(&clone_info_vec);
for (const CloneInfo& clone_info : clone_info_vec) {
AddOneCloneNode(clone_info);
}
}
void CompTaskNode::FwCollectCloneInfoVec(
std::vector<CloneInfo>* clone_info_vec) {
}
void CompTaskNode::FwAddOneCloneNode(const CloneInfo& clone_info) {
ExecNode* clone_node = mut_exec_gph().NewFinalNode();
clone_node->mut_op() = clone_info.clone_op;
// InEdge
ExecEdge* in_edge = mut_exec_gph().NewExecEdge();
in_edge->set_lbn(clone_info.lbn);
in_edge->mut_dst_bn() = clone_node->op()->SoleIbn();
in_edge->mut_src_bn() = clone_info.edges().front()->obn();
Connect(clone_info.pred_node, in_edge, clone_node);
// OutEdge
CHECK_EQ(clone_node->op()->output_bns().size(), clone_info.edges.size());
for (size_t i = 0; i < clone_info.edges.size(); ++i) {
const std::string& obn = clone_node->op()->output_bns().at(i);
ExecEdge* out_edge = clone_info.edges.at(i);
ExecNode* dst_node = out_edge->dst_node();
DisConnect(out_edge);
out_edge->mut_src_bn() = obn;
Connect(clone_node, out_edge, dst_node);
}
}
void CompTaskNode::FwSetModelTmpRegstDesc() {
RegstDesc* model_tmp_regst = GetProducedRegstDesc("model_tmp");
for (const std::unique_ptr<ExecNode>& node : exec_gph().nodes()) {
for (const std::string& mtbn : node->op()->model_tmp_bns()) {
std::string lbn = node->op()->mtbn2lbn(mtbn);
Shape* ptr = model_tmp_regst->EnrollWithLbn(lbn);
node->op()->SetShapePtr(mtbn, ptr);
node->BindBnInOpAndRegst(mtbn, model_tmp_regst);
}
node->op()->InferShape4Mtb();
}
}
void CompTaskNode::FwSetDataRegstDesc(
const Lbn2NodeObnMap& lbn2producer,
const Lbn2NodeIbnVecMap& extern_in_lbn2consumers) {
......@@ -222,6 +177,19 @@ void CompTaskNode::FwSetDataRegstDesc(
}
}
void CompTaskNode::FwSetModelTmpRegstDesc() {
RegstDesc* model_tmp_regst = GetProducedRegstDesc("model_tmp");
for (const std::unique_ptr<ExecNode>& node : exec_gph().nodes()) {
for (const std::string& mtbn : node->op()->model_tmp_bns()) {
std::string lbn = node->op()->mtbn2lbn(mtbn);
Shape* ptr = model_tmp_regst->EnrollWithLbn(lbn);
node->op()->SetShapePtr(mtbn, ptr);
node->BindBnInOpAndRegst(mtbn, model_tmp_regst);
}
node->op()->InferShape4Mtb();
}
}
void CompTaskNode::BpBuildExecAndProducedRegsts(Path* path) {
const ExecGraph& fw_gph = GetFwNode()->exec_gph();
const ExecNode* cp_in_node = fw_gph.source_node().SoleOutEdge()->dst_node();
......
......@@ -43,8 +43,6 @@ class CompTaskNode : public TaskNode {
Lbn2NodeObnMap* lbn2producer,
Lbn2NodeIbnVecMap* extern_in_lbn2consumers);
void FwAddCopyInOp(Lbn2NodeIbnVecMap* extern_in_lbn2consumers);
void FwAddCloneOp();
void FwBindOutEdgeAndRegst();
void FwSetProducedRegstDescs(
const Lbn2NodeObnMap& lbn2producer,
const Lbn2NodeIbnVecMap& extern_in_lbn2consumers);
......
......@@ -88,8 +88,7 @@ void TaskNode::EnrollProducedRegstDesc(
void TaskNode::SubscribeRegstDescInnerPath() {
for (const TaskEdge* edge : in_edges()) {
RegstDesc* regst = GetRelatedRegst(edge);
Subscribe(regst);
Subscribe(GetRelatedRegst(edge));
}
}
......
......@@ -14,7 +14,7 @@ class CloneOp final : public SysOperator {
void Init(const OperatorConf& op_conf) override;
void InferShape4ObAndDtbFromIb() const override { TODO(); }
std::string ibn2lbn(const std::string& input_bn) const override {
std::string normal_ibn2lbn(const std::string& input_bn) const override {
return GetValueFromPbOpConf("lbn");
}
std::string obn2lbn(const std::string& output_bn) const override {
......
......@@ -14,7 +14,7 @@ class ConcatOp final : public SysOperator {
void Init(const OperatorConf& op_conf) override;
void InferShape4ObAndDtbFromIb() const override { TODO(); }
std::string ibn2lbn(const std::string& input_bn) const override {
std::string normal_ibn2lbn(const std::string& input_bn) const override {
return GetValueFromPbOpConf("lbn");
}
std::string obn2lbn(const std::string& output_bn) const override {
......
......@@ -13,7 +13,7 @@ class ConvolutionOp final : public UserOperator {
void Init(const OperatorConf& op_conf) override;
void InferShape4ObAndDtbFromIb() const override { TODO(); }
void InferShape4MbAndMtb() const override { TODO(); }
void InferShape4Mtb() const override { TODO(); }
private:
......
......@@ -14,7 +14,7 @@ class CopyOp final : public SysOperator {
void Init(const OperatorConf& op_conf) override;
void InferShape4ObAndDtbFromIb() const override { TODO(); }
std::string ibn2lbn(const std::string& input_bn) const override {
std::string normal_ibn2lbn(const std::string& input_bn) const override {
return ibn2lbn_.at(input_bn);
}
std::string obn2lbn(const std::string& output_bn) const override {
......
......@@ -14,7 +14,7 @@ class InnerProductOp final : public UserOperator {
void Init(const OperatorConf& op_conf) override;
void InferShape4ObAndDtbFromIb() const override { TODO(); }
void InferShape4MbAndMtb() const override { TODO(); }
void InferShape4Mtb() const override { TODO(); }
private:
......
......@@ -17,7 +17,7 @@ class MultinomialLogisticLossOp : public UserOperator {
bool IsLossOp() const override { return true; }
void InferShape4ObAndDtbFromIb() const override { TODO(); }
void InferShape4MbAndMtb() const override { TODO(); }
void InferShape4Mtb() const override { TODO(); }
private:
......
......@@ -27,6 +27,14 @@ std::string Operator::odbn2lbn(const std::string& output_diff_bn) const {
std::string Operator::mdbn2lbn(const std::string& model_diff_bn) const {
return mbn2lbn(GenUnDiffBn(model_diff_bn));
}
std::string Operator::ibn2lbn(const std::string& input_bn) const {
auto it = special_ibn2lbn_.find(input_bn);
if (it == special_ibn2lbn_.end()) {
return normal_ibn2lbn(input_bn);
} else {
return it->second;
}
}
std::string Operator::GetValueFromPbOpConf(const std::string& k) const {
return GetValueFromPbMessage(*pb_op_conf_, k);
......@@ -79,7 +87,7 @@ void Operator::EnrollBn(std::vector<std::string>* bn_vec,
CHECK(bn_in_op2shape_ptr_.emplace(bn, nullptr).second);
}
std::string UserOperator::ibn2lbn(const std::string& input_bn) const {
std::string UserOperator::normal_ibn2lbn(const std::string& input_bn) const {
return GetValueFromPbOpConf(input_bn);
}
std::string UserOperator::obn2lbn(const std::string& output_bn) const {
......
......@@ -28,11 +28,15 @@ class Operator {
std::string idbn2lbn(const std::string& input_diff_bn) const;
std::string odbn2lbn(const std::string& output_diff_bn) const;
std::string mdbn2lbn(const std::string& model_diff_bn) const;
std::string ibn2lbn(const std::string& input_bn) const;
virtual std::string ibn2lbn(const std::string& input_bn) const = 0;
virtual std::string obn2lbn(const std::string& output_bn) const = 0;
virtual std::string mtbn2lbn(const std::string& model_tmp_bn) const = 0;
virtual std::string mbn2lbn(const std::string& model_bn) const = 0;
void AddSpecialIbn2Lbn(const std::string& ibn, const std::string& lbn) {
CHECK(special_ibn2lbn_.emplace(ibn, lbn).second);
}
// Getters
const std::string& op_name() const { return op_name_; }
......@@ -58,11 +62,12 @@ class Operator {
void SetShapePtr(const std::string& bn_in_op, Shape* ptr) const;
void SetNull4AllShapePtr() const;
virtual void InferShape4ObAndDtbFromIb() const = 0;
virtual void InferShape4MbAndMtb() const = 0;
virtual void InferShape4Mtb() const = 0;
protected:
std::string& mut_op_name() { return op_name_; }
std::unique_ptr<PbMessage>& mut_pb_op_conf() { return pb_op_conf_; }
virtual std::string normal_ibn2lbn(const std::string& input_bn) const = 0;
// enroll data blobs
void EnrollDataTmpBn(const std::string& dtbn);
......@@ -82,6 +87,8 @@ class Operator {
std::string op_name_;
std::unique_ptr<PbMessage> pb_op_conf_;
std::unordered_map<std::string, std::string> special_ibn2lbn_;
// blob name in op
std::vector<std::string> data_tmp_bns_;
std::vector<std::string> input_bns_;
......@@ -102,7 +109,7 @@ class UserOperator : public Operator {
UserOperator() = default;
virtual ~UserOperator() = default;
std::string ibn2lbn(const std::string& input_bn) const override;
std::string normal_ibn2lbn(const std::string& input_bn) const override;
std::string obn2lbn(const std::string& output_bn) const override;
std::string mtbn2lbn(const std::string& model_tmp_bn) const override;
std::string mbn2lbn(const std::string& model_bn) const override;
......@@ -120,14 +127,14 @@ class SysOperator : public Operator {
UNEXPECTED_RUN(); \
}
SET_UNEXPECTED(ibn2lbn);
SET_UNEXPECTED(normal_ibn2lbn);
SET_UNEXPECTED(obn2lbn);
SET_UNEXPECTED(mtbn2lbn);
SET_UNEXPECTED(mbn2lbn);
#undef SET_UNEXPECTED
void InferShape4MbAndMtb() const override { UNEXPECTED_RUN(); }
void InferShape4Mtb() const override { UNEXPECTED_RUN(); }
private:
};
......
......@@ -14,7 +14,7 @@ class PoolingOp final : public UserOperator {
void Init(const OperatorConf& op_conf) override;
void InferShape4ObAndDtbFromIb() const override { TODO(); }
void InferShape4MbAndMtb() const override { TODO(); }
void InferShape4Mtb() const override { TODO(); }
private:
......
......@@ -15,7 +15,7 @@ class ReluOp final : public UserOperator {
bool IsElemWise() const override { return true; }
void InferShape4ObAndDtbFromIb() const override { TODO(); }
void InferShape4MbAndMtb() const override { TODO(); }
void InferShape4Mtb() const override { TODO(); }
private:
......
......@@ -14,7 +14,7 @@ class SoftmaxOp : public UserOperator {
void Init(const OperatorConf& op_conf) override;
void InferShape4ObAndDtbFromIb() const override { TODO(); }
void InferShape4MbAndMtb() const override { TODO(); }
void InferShape4Mtb() const override { TODO(); }
private:
......
......@@ -14,7 +14,7 @@ class SplitOp final : public SysOperator {
void Init(const OperatorConf& op_conf) override;
void InferShape4ObAndDtbFromIb() const override { TODO(); }
std::string ibn2lbn(const std::string& input_bn) const override {
std::string normal_ibn2lbn(const std::string& input_bn) const override {
return GetValueFromPbOpConf("lbn");
}
std::string obn2lbn(const std::string& output_bn) const override {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册