Commit d408be08 authored by Li Xinqi, committed by GitHub

Dev logical blob dim0 (#1635)

* mem_shared_hint_id

* sharable memory block

* rm useless code

* remove useless code

* bugfix: no redundant edges

* rename: MemBlockGroup => MemBlock

* put constructor of SharableMemBlockNode into header file

* bugfix

* rename field: MemBlock.block_id => MemBlock.mem_block_id

* replace piece_size with logical_blob_dim0

* BlobParallelConf

* BlobParallelDesc

* infer out blob model_split_axis

* int64_t => int32_t

* InferOutBlobParallelDesc

* gather out blob model split (#1624)

* InferBlobParallelDesc

* let variable op support kModelParallel

* rename lbi2blob_desc_ => lbi2no_parallel_blob_desc_

* Global<OpGraph>

* SplitLogicalInputBlobDesc

* ConcatOutputBlobDescs

* rename: BlobDataParallel => DataBlobParallel; BlobModelParallel => ModelBlobParallel; BlobGridParallel => GridBlobParallel

* OpGraph::CheckBlobDescs(...)

* exact division is unnecessary

* fix bugs

* rename InferOutBlob* => InferOutputBlob

* exact division in variable_op is unnecessary

* bug fix

* fix bugs

* fix bugs

* IsInputBlobAllowedModelSplit

* use Global<OpGraph> to InferModelSize

* add OpGraph::GetDataBalancedSplitter and OpGraph::GetModelBalancedSplitter

* fix IdentityOp::IsInputBlobAllowedModelSplit

* no implementation for pure virtual function Operator::IsInputBlobAllowedModelSplit

* refine BlobParallelDesc: replace CopyParallelConf with operator=

* refine ParallelDesc: remove unused functions

* more checks on ParallelDesc

* remove unused function Operator::MaxModelSplitNum

* bugfix: SoleOp() => op_vec().at(0)


Former-commit-id: be1f820b2927f7f79f55b7891f6575cdeb4b2053
Parent d91685b1
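
The recurring edit in the diff below swaps `SoleOp()` for `op_vec().at(0)` (the "bugfix: SoleOp() => op_vec().at(0)" item above). A minimal sketch of the presumed difference between the two accessors, assuming `SoleOp()` asserts that the logical node holds exactly one op — the accessor names come from the diff; the assertion is an assumption, not confirmed by this commit:

```cpp
#include <cassert>
#include <memory>
#include <vector>

class Operator;  // full definition lives in operator.h

class LogicalNode {
 public:
  const std::vector<std::shared_ptr<const Operator>>& op_vec() const { return op_vec_; }

  // Presumed to abort (CHECK_EQ in OneFlow, assert in this sketch) when the
  // node has been fused to hold several ops. The boxing code below only needs
  // the name of the node's first op, so op_vec().at(0) also works for
  // multi-op nodes where SoleOp() would fail.
  std::shared_ptr<const Operator> SoleOp() const {
    assert(op_vec_.size() == 1);  // hypothetical check
    return op_vec_.front();
  }

 private:
  std::vector<std::shared_ptr<const Operator>> op_vec_;
};
```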
@@ -57,7 +57,7 @@ DEFINE_BLD_BOXING_OP_CONF_METHOD(InBoxingTaskNode, DataConcatAndDataSplit) {
   conf->mutable_concat_box()->set_axis(0);
   BoxSplitConf* split_conf = conf->mutable_split_box();
   split_conf->set_axis(0);
-  const std::string& out_op_name = out_logical->SoleOp()->op_name();
+  const std::string& out_op_name = out_logical->op_vec().at(0)->op_name();
   SetBoxSplitPart(sorted_out_edges,
                   Global<OpGraph>::Get()->GetDataBalancedSplitter(out_op_name, lbi,
                                                                   *out_logical->parallel_desc()),
@@ -67,12 +67,12 @@ DEFINE_BLD_BOXING_OP_CONF_METHOD(OutBoxingTaskNode, DataConcatAndDataSplit) {
   conf->mutable_concat_box()->set_axis(0);
   BoxSplitConf* split_conf = conf->mutable_split_box();
   split_conf->set_axis(0);
-  const std::string& in_op_name = in_logical->SoleOp()->op_name();
+  const std::string& in_op_name = in_logical->op_vec().at(0)->op_name();
   BalancedSplitter in_bs = Global<OpGraph>::Get()->GetDataBalancedSplitter(
       in_op_name, lbi, *in_logical->parallel_desc());
   Range in_range =
       in_bs.At(sorted_in_edges.front().parallel_id_min, sorted_in_edges.back().parallel_id_max);
-  const std::string& out_op_name = out_logical->SoleOp()->op_name();
+  const std::string& out_op_name = out_logical->op_vec().at(0)->op_name();
   BalancedSplitter out_bs = Global<OpGraph>::Get()->GetDataBalancedSplitter(
       out_op_name, lbi, *out_logical->parallel_desc());
   for (const EdgeInfo& out_edge : sorted_out_edges) {
@@ -88,7 +88,7 @@ DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, DataConcatAndClone) {
 DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, DataConcatAndModelSplit) {
   conf->mutable_concat_box()->set_axis(0);
   BoxSplitConf* split_conf = conf->mutable_split_box();
-  const std::string& out_op_name = out_logical->SoleOp()->op_name();
+  const std::string& out_op_name = out_logical->op_vec().at(0)->op_name();
   split_conf->set_axis(Global<OpGraph>::Get()->GetModelSplitAxis(out_op_name, lbi));
   SetBoxSplitPart(sorted_out_edges,
                   Global<OpGraph>::Get()->GetModelBalancedSplitter(out_op_name, lbi,
@@ -96,19 +96,19 @@ DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, DataConcatAndModelSplit) {
                   split_conf);
 }
 DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, ModelConcatAndDataSplit) {
-  const std::string& in_op_name = in_logical->SoleOp()->op_name();
+  const std::string& in_op_name = in_logical->op_vec().at(0)->op_name();
   conf->mutable_concat_box()->set_axis(Global<OpGraph>::Get()->GetModelSplitAxis(in_op_name, lbi));
   BoxSplitConf* split_conf = conf->mutable_split_box();
   split_conf->set_axis(0);
-  const std::string& out_op_name = out_logical->SoleOp()->op_name();
+  const std::string& out_op_name = out_logical->op_vec().at(0)->op_name();
   SetBoxSplitPart(sorted_out_edges,
                   Global<OpGraph>::Get()->GetDataBalancedSplitter(out_op_name, lbi,
                                                                   *out_logical->parallel_desc()),
                   split_conf);
 }
 DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, ModelConcatAndClone) {
-  const std::string& in_op_name = in_logical->SoleOp()->op_name();
+  const std::string& in_op_name = in_logical->op_vec().at(0)->op_name();
   conf->mutable_concat_box()->set_axis(Global<OpGraph>::Get()->GetModelSplitAxis(in_op_name, lbi));
   conf->mutable_clone_box();
 }
@@ -117,7 +117,7 @@ DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, AddAndDataSplit) {
   BoxSplitConf* split_conf = conf->mutable_split_box();
   split_conf->set_axis(0);
-  const std::string& out_op_name = out_logical->SoleOp()->op_name();
+  const std::string& out_op_name = out_logical->op_vec().at(0)->op_name();
   SetBoxSplitPart(sorted_out_edges,
                   Global<OpGraph>::Get()->GetDataBalancedSplitter(out_op_name, lbi,
                                                                   *out_logical->parallel_desc()),
@@ -126,7 +126,7 @@ DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, AddAndDataSplit) {
 DEFINE_BLD_BOXING_OP_CONF_METHOD(BoxingTaskNode, AddAndModelSplit) {
   conf->mutable_add_box();
   BoxSplitConf* split_conf = conf->mutable_split_box();
-  const std::string& out_op_name = out_logical->SoleOp()->op_name();
+  const std::string& out_op_name = out_logical->op_vec().at(0)->op_name();
   split_conf->set_axis(Global<OpGraph>::Get()->GetModelSplitAxis(out_op_name, lbi));
   SetBoxSplitPart(sorted_out_edges,
                   Global<OpGraph>::Get()->GetModelBalancedSplitter(out_op_name, lbi,
......
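
Every boxing method above fetches a `BalancedSplitter` from `Global<OpGraph>` and reads device ranges out of it with `At(...)`. A minimal sketch of that splitting arithmetic, assuming `BalancedSplitter(total, num)` distributes `total` elements over `num` parts with sizes differing by at most one; the constructor and the two `At` overloads are inferred from the call sites, and the implementation is a guess rather than OneFlow's source:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Half-open range [begin, end), standing in for the Range type used above.
struct Range {
  int64_t begin;
  int64_t end;
  int64_t size() const { return end - begin; }
};

class BalancedSplitter {
 public:
  BalancedSplitter(int64_t total, int64_t num) : total_(total), num_(num) { assert(num > 0); }

  // Elements covered by part idx: the first (total % num) parts each take one
  // extra element, so no exact division is required.
  Range At(int64_t idx) const {
    const int64_t base = total_ / num_;
    const int64_t rem = total_ % num_;
    const int64_t begin = base * idx + std::min(idx, rem);
    return Range{begin, begin + base + (idx < rem ? 1 : 0)};
  }

  // Union of the consecutive parts [first_idx, last_idx], matching the
  // in_bs.At(parallel_id_min, parallel_id_max) call above.
  Range At(int64_t first_idx, int64_t last_idx) const {
    return Range{At(first_idx).begin, At(last_idx).end};
  }

 private:
  int64_t total_;
  int64_t num_;
};
```

For example, `BalancedSplitter(10, 4)` yields the ranges [0, 3), [3, 6), [6, 8), [8, 10), which is why the "exact division is unnecessary" commits above can drop divisibility requirements.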
@@ -276,18 +276,6 @@ int32_t LogicalNode::GetModelSplitAxis() const {
   }
 }
-int32_t LogicalNode::GetMaxModelSplitNum() const {
-  CHECK_EQ(parallel_desc_->policy(), kModelParallel);
-  CHECK_NOTNULL(main_model_parallel_);
-  if (main_model_parallel_ == this) {
-    int32_t ret = SoleOp()->MaxModelSplitNum();
-    CHECK_NE(ret, -1);
-    return ret;
-  } else {
-    return main_model_parallel_->GetMaxModelSplitNum();
-  }
-}
 bool LogicalNode::HasOpWithCondition(std::function<bool(const Operator*)> cond) const {
   for (std::shared_ptr<const Operator> op : op_vec_) {
     if (cond(op.get())) { return true; }
......
@@ -56,7 +56,6 @@ class LogicalNode : public Node<LogicalNode, LogicalEdge> {
   LogicalNode* main_model_parallel() const { return main_model_parallel_; }
   void set_main_model_parallel(LogicalNode* val) { main_model_parallel_ = val; }
   int32_t GetModelSplitAxis() const;
-  int32_t GetMaxModelSplitNum() const;
   virtual int64_t GetAreaId() const = 0;
   virtual bool MayConsumeModelDiff() const { return false; }
......
@@ -298,11 +298,6 @@ int32_t ConvOp<NDims>::ModelSplitAxis() const {
   }
 }
-template<int32_t NDims>
-int32_t ConvOp<NDims>::MaxModelSplitNum() const {
-  return GetValFromCustomizedConf<int32_t>("filters");
-}
 #ifdef WITH_CUDA
 template<int32_t NDims>
 void ConvOp<NDims>::InferCudnnAlgo(
......
@@ -47,7 +47,6 @@ class ConvOp : public Operator {
                      const ParallelContext*, const OpContext*) const override;
   int32_t ModelSplitAxis() const override;
-  int32_t MaxModelSplitNum() const override;
  private:
   bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
......
@@ -17,7 +17,6 @@ class EmbeddingLookupOp final : public Operator {
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
   int32_t ModelSplitAxis() const override { return 1; }
-  int32_t MaxModelSplitNum() const override { return op_conf().embedding_lookup_conf().units(); }
  private:
   bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
......
@@ -17,7 +17,6 @@ class FullyConnectedOp final : public Operator {
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
   int32_t ModelSplitAxis() const override { return 1; }
-  int32_t MaxModelSplitNum() const override { return op_conf().fully_connected_conf().units(); }
  private:
   bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { return false; }
......
@@ -164,7 +164,6 @@ class Operator {
   void FixParallelDesc(ParallelDesc* pr_desc) const;
   void FixLbiWhenShareModel(const std::string& shared_op_name);
   virtual int32_t ModelSplitAxis() const { return -1; }
-  virtual int32_t MaxModelSplitNum() const { return -1; }
   void GenKernelConf(std::function<const BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                      bool is_forward, const ParallelContext*, KernelConf*, const OpContext*) const;
......
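
With `MaxModelSplitNum()` gone, `ModelSplitAxis()` is the only per-op model-split hint left in `Operator`; per the commit messages above, the split width is now derived from the logical blob itself through `Global<OpGraph>` (`GetModelBalancedSplitter`). A hedged sketch of that derivation — the balanced rule follows the splitter usage above, everything else (names, signature) is an assumption:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical per-device part size when a blob is model-split: take the
// blob's extent along ModelSplitAxis() and divide it across parallel_num
// devices with the same balanced rule as the BalancedSplitter sketch above.
int64_t ModelSplitPartSize(const std::vector<int64_t>& logical_blob_shape,
                           int32_t model_split_axis, int64_t parallel_num,
                           int64_t parallel_id) {
  const int64_t total = logical_blob_shape.at(model_split_axis);
  const int64_t base = total / parallel_num;
  const int64_t rem = total % parallel_num;
  return base + (parallel_id < rem ? 1 : 0);  // no exact division required
}
```

For a FullyConnectedOp (`ModelSplitAxis()` returning 1, as above) with `units = 10` on 4 devices, the parts along axis 1 come out as 3, 3, 2, 2.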
@@ -19,10 +19,6 @@ void RecurrentOp::InitFromOpConf() {
   VirtualInitFromOpConf();
 }
-int32_t RecurrentOp::MaxModelSplitNum() const {
-  return GetValFromCustomizedConf<int32_t>("hidden_size");
-}
 void RecurrentOp::InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                                  const ParallelContext* parallel_ctx) const {
   const BlobDesc* in_blob_desc = GetBlobDesc4BnInOp("in");
......
@@ -17,7 +17,6 @@ class RecurrentOp : public Operator {
   void InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
                       const ParallelContext* parallel_ctx) const override;
   int32_t ModelSplitAxis() const override { return 1; }
-  int32_t MaxModelSplitNum() const override;
  private:
   bool IsInputBlobAllowedModelSplit(const std::string& ibn) const override { UNIMPLEMENTED(); }
......