Commit 2733025d authored by Li Xinqi, committed by GitHub

Dev global op graph (#1636)

* Global<OpGraph> is only available during compilation

* use a smaller record_piece_size (one parallel instance's share) for InferNoParallelBlobDesc


Former-commit-id: 5eb1012703f8f9389ac8e2f16131bfd36411b0db
Parent d408be08
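The first bullet concerns the Global<T> New/Get/Delete idiom that the hunks below move between files. A minimal sketch of that idiom, assuming only the interface visible in this diff (the real oneflow implementation may differ in detail):

    // Minimal sketch of the Global<T> singleton idiom: New constructs the
    // instance, Get returns it (nullptr if absent), Delete destroys it.
    #include <cassert>
    #include <utility>

    template<typename T>
    class Global final {
     public:
      template<typename... Args>
      static void New(Args&&... args) {
        assert(Get() == nullptr);  // New must not be called twice
        *GetPPtr() = new T(std::forward<Args>(args)...);
      }
      static T* Get() { return *GetPPtr(); }
      static void Delete() {
        delete *GetPPtr();  // deleting nullptr is a no-op
        *GetPPtr() = nullptr;
      }

     private:
      static T** GetPPtr() {
        static T* ptr = nullptr;  // function-local static: one slot per T
        return &ptr;
      }
    };

Under this idiom the commit is a pure lifetime move: Global<OpGraph>::New() and Global<OpGraph>::Delete() now bracket Compiler::DoCompile() instead of spanning the whole Oneflow constructor.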
......@@ -287,7 +287,7 @@ void OpGraph::Init() {
   FixOpParallelDesc();
   UpdateOpNodeHasInDiff();
   InferTimeShape();
-  InferNodeNoParallelBlobDesc();
+  InferNoParallelBlobDesc();
   HashMap<LogicalBlobId, int32_t> lbi2model_split_axis;
   InferModelSplitAxis(&lbi2model_split_axis);
   InferBlobParallelDesc(lbi2model_split_axis);
......@@ -366,7 +366,7 @@ void OpGraph::InferTimeShape() const {
   });
 }
 
-void OpGraph::InferNodeNoParallelBlobDesc() const {
+void OpGraph::InferNoParallelBlobDesc() const {
   TopoForEachNode([&](OpNode* op_node) {
     ParallelContext parallel_ctx;
     parallel_ctx.set_parallel_id(0);
......@@ -374,7 +374,8 @@ void OpGraph::InferNodeNoParallelBlobDesc() const {
     parallel_ctx.set_policy(op_node->parallel_desc().policy());
     op_node->op().InferBlobDescsIf(
         std::bind(&OpNode::NoParallelBlobDesc4BnInOp, op_node, std::placeholders::_1),
-        &parallel_ctx, job_desc_->RecordPieceSize(), [](OpContext*) {});
+        &parallel_ctx, job_desc_->RecordPieceSize() / op_node->parallel_desc().parallel_num(),
+        [](OpContext*) {});
   });
 }
......
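The second hunk implements the commit message's second bullet: InferNoParallelBlobDesc() pins parallel_id to 0, so the piece it should reason about is the share one parallel instance actually processes, i.e. the job-level piece divided by parallel_num. A toy sketch of that arithmetic, with hypothetical numbers:

    #include <cstdint>
    #include <iostream>

    int main() {
      const int64_t record_piece_size = 64;  // hypothetical JobDesc::RecordPieceSize()
      const int64_t parallel_num = 4;        // hypothetical parallel_desc().parallel_num()
      // Per-instance piece size now passed to InferBlobDescsIf: 64 / 4 = 16,
      // so blob shapes are inferred per instance, not for the full logical piece.
      std::cout << record_piece_size / parallel_num << std::endl;  // prints 16
      return 0;
    }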
......@@ -133,7 +133,7 @@ class OpGraph final : public Graph<OpNode, OpEdge> {
   void FixOpParallelDesc() const;
   void UpdateOpNodeHasInDiff() const;
   void InferTimeShape() const;
-  void InferNodeNoParallelBlobDesc() const;
+  void InferNoParallelBlobDesc() const;
   void InferModelSplitAxis(HashMap<LogicalBlobId, int32_t>* lbi2model_split_axis) const;
   void InferBlobParallelDesc(const HashMap<LogicalBlobId, int32_t>& lbi2model_split_axis) const;
   void InferLogicalBlobDesc() const;
......
#include "oneflow/core/job/compiler.h"
#include "oneflow/core/persistence/tee_persistent_log_stream.h"
#include "oneflow/core/device/cudnn_conv_ctx_cache.h"
#include "oneflow/core/graph/op_graph.h"
namespace oneflow {
......@@ -91,7 +93,14 @@ void Compiler::GenNetTopo(Plan* plan) {
 }
 
 Plan Compiler::DoCompile() {
+#ifdef WITH_CUDA
+  Global<CudnnConvCtxCache>::New();
+#endif
+  Global<JobDesc>::Get()->FixAndOptimizeDLNet();
   const JobDesc* job_desc = Global<JobDesc>::Get();
+  TeePersistentLogStream::Create("optimized_job_conf")->Write(job_desc->job_conf());
+  Global<OpGraph>::New(job_desc);
+  Global<OpGraph>::Get()->ToDotWithFilePath("optimized_dlnet_op_graph.dot");
   auto logical_gph = std::make_unique<LogicalGraph>(job_desc->IsTrain());
   int64_t total_mbn_num = logical_gph->total_mbn_num();
   auto task_gph = std::make_unique<TaskGraph>(std::move(logical_gph));
......@@ -127,6 +136,10 @@ Plan Compiler::DoCompile() {
   plan.set_total_mbn_num(total_mbn_num);
   GenNetTopo(&plan);
   ToDotFile(plan, "/dot/plan.dot");
+  Global<OpGraph>::Delete();
+#ifdef WITH_CUDA
+  Global<CudnnConvCtxCache>::Delete();
+#endif
   return plan;
 }
......
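These hunks bracket the compile-only globals (CudnnConvCtxCache, OpGraph) between a New at the top of DoCompile() and a Delete just before the return. A hedged alternative sketch of the same scoping expressed as an RAII guard — a hypothetical helper, not what the commit does — assuming the Global<T> interface sketched earlier:

    // Hypothetical RAII guard: ties a Global<T>'s lifetime to a C++ scope so
    // the New/Delete pair cannot get out of balance on early returns.
    template<typename T>
    class GlobalScope final {
     public:
      template<typename... Args>
      explicit GlobalScope(Args&&... args) {
        Global<T>::New(std::forward<Args>(args)...);
      }
      ~GlobalScope() { Global<T>::Delete(); }
      GlobalScope(const GlobalScope&) = delete;
      GlobalScope& operator=(const GlobalScope&) = delete;
    };

    // Hypothetical usage inside Compiler::DoCompile():
    //   GlobalScope<OpGraph> op_graph_scope(Global<JobDesc>::Get());
    //   ...build the plan via Global<OpGraph>::Get()...
    //   // Global<OpGraph>::Delete() runs automatically on every return path.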
......@@ -13,8 +13,6 @@
 #include "oneflow/core/persistence/tee_persistent_log_stream.h"
 #include "oneflow/core/persistence/file_system.h"
 #include "oneflow/core/actor/act_event_logger.h"
-#include "oneflow/core/graph/op_graph.h"
-#include "oneflow/core/device/cudnn_conv_ctx_cache.h"
 
 namespace oneflow {
......@@ -160,23 +158,13 @@ class Oneflow final {
 };
 
 Oneflow::Oneflow(const std::string& job_conf_filepath) {
-#ifdef WITH_CUDA
-  Global<CudnnConvCtxCache>::New();
-#endif
   // New All Global
   Global<JobDesc>::New(job_conf_filepath);
-  const JobDesc* global_job_desc = Global<JobDesc>::Get();
-  OpGraph old_op_graph(global_job_desc);
-  Global<JobDesc>::Get()->FixAndOptimizeDLNet();
   ctrl_server_.reset(new CtrlServer());
   Global<CtrlClient>::New();
   OF_BARRIER();
   int64_t this_mchn_id = Global<JobDesc>::Get()->GetMachineId(ctrl_server_->this_machine_addr());
   Global<MachineCtx>::New(this_mchn_id);
-  TeePersistentLogStream::Create("optimized_job_conf")->Write(global_job_desc->job_conf());
-  old_op_graph.ToDotWithFilePath("dlnet_op_graph.dot");
-  Global<OpGraph>::New(Global<JobDesc>::Get());
-  Global<OpGraph>::Get()->ToDotWithFilePath("optimized_dlnet_op_graph.dot");
   const MachineCtx* machine_ctx = Global<MachineCtx>::Get();
   bool DoProfile =
       machine_ctx->IsThisMachineMaster() && Global<JobDesc>::Get()->collect_act_event();
......@@ -203,9 +191,6 @@ Oneflow::Oneflow(const std::string& job_conf_filepath) {
     PullPlan("naive_plan", &naive_plan);
     PullPlan("mem_shared_plan", &mem_shared_plan);
   }
-#ifdef WITH_CUDA
-  Global<CudnnConvCtxCache>::Delete();
-#endif
   OF_BARRIER();
   TeePersistentLogStream::Create("naive_plan")->Write(naive_plan);
   TeePersistentLogStream::Create("mem_shared_plan")->Write(mem_shared_plan);
......@@ -243,7 +228,6 @@ Oneflow::Oneflow(const std::string& job_conf_filepath) {
   Global<Profiler>::Delete();
   Global<MachineCtx>::Delete();
   Global<IDMgr>::Delete();
-  Global<OpGraph>::Delete();
   Global<JobDesc>::Delete();
 }
......