Commit fb138d8a authored by lixinqi

add pass DumpTimeShapeAndBlobParallelConfPass

Parent 3457db13
@@ -746,8 +746,8 @@ std::list<OpNode*> OpGraph::DataOrCtrlSourceNodes() const {
   return ret;
 }
 
-void OpGraph::DumpLogicalBlobDesc(JobBuilder* job_builder) const {
-  auto* helper = job_builder->mutable_helper();
+void OpGraph::DumpLogicalBlobDesc(Job* job) const {
+  auto* helper = job->mutable_helper();
   ForEachNode([&](const OpNode* node) {
     for (const auto& obn : node->op().output_bns()) {
       const auto& lbi = node->op().BnInOp2Lbi(obn);
@@ -757,17 +757,17 @@ void OpGraph::DumpLogicalBlobDesc(JobBuilder* job_builder) const {
   });
 }
 
-void OpGraph::DumpSbpSignature(JobBuilder* job_builder) const {
+void OpGraph::DumpSbpSignature(Job* job) const {
   ForEachNode([&](const OpNode* node) {
-    (*job_builder->mutable_sbp_conf()->mutable_op_name2sbp_signature_conf())[node->op().op_name()] =
+    (*job->mutable_sbp_conf()->mutable_op_name2sbp_signature_conf())[node->op().op_name()] =
         node->sbp_signature();
   });
 }
 
-void OpGraph::DumpOpTimeShape(JobBuilder* job_builder) const {
+void OpGraph::DumpOpTimeShape(Job* job) const {
   ForEachNode([&](OpNode* op_node) {
     auto* op_time_shape =
-        &(*job_builder->mutable_helper()->mutable_op_name2op_time_shape())[op_node->op().op_name()];
+        &(*job->mutable_helper()->mutable_op_name2op_time_shape())[op_node->op().op_name()];
     if (op_node->out_blob_time_shape() != nullptr) {
       op_node->out_blob_time_shape()->ToProto(op_time_shape->mutable_out_blob_time_shape());
     }
@@ -778,14 +778,15 @@ void OpGraph::DumpOpTimeShape(JobBuilder* job_builder) const {
   });
 }
 
-void OpGraph::DumpBatchAxisLbi(JobBuilder* job_builder) const {
-  auto* lbn2batch_axis = job_builder->mutable_helper()->mutable_lbn2batch_axis();
+void OpGraph::DumpBatchAxisLbi(Job* job) const {
+  auto* lbn2batch_axis = job->mutable_helper()->mutable_lbn2batch_axis();
   ForEachNode([&](OpNode* op_node) {
     for (const auto& obn : op_node->op().output_bns()) {
       const LogicalBlobId& lbi = op_node->op().BnInOp2Lbi(obn);
       const auto& lbn = GenLogicalBlobName(lbi);
-      const auto& pair = PbMapPair<std::string, OptInt64>(lbn, op_node->BatchAxis4Lbi(lbi));
-      CHECK(lbn2batch_axis->insert(pair).second);
+      const auto& batch_axis = op_node->BatchAxis4Lbi(lbi);
+      const auto& pair = PbMapPair<std::string, OptInt64>(lbn, batch_axis);
+      CHECK(lbn2batch_axis->insert(pair).first->second == batch_axis);
     }
   });
 }
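Reviewer note on the `DumpBatchAxisLbi` hunk above: the old code required every insert to be fresh (`CHECK(...insert(pair).second)`), so dumping the same job twice would abort; the new form tolerates re-insertion as long as the stored batch axis agrees. A minimal sketch of that idiom, with `std::map` and `int` standing in for the protobuf map and `OptInt64`:

```cpp
#include <cassert>
#include <map>
#include <string>

// Sketch: duplicate-tolerant insert that still catches conflicting values.
// std::map stands in for the PbMap used in DumpBatchAxisLbi.
int main() {
  std::map<std::string, int> lbn2batch_axis;

  auto dump = [&](const std::string& lbn, int batch_axis) {
    // insert() returns {iterator, inserted}; when the key already exists,
    // the iterator points at the existing entry and inserted == false.
    auto result = lbn2batch_axis.insert({lbn, batch_axis});
    // Old behavior: assert(result.second) -- fails if the dump runs twice.
    // New behavior: only fail if a second run disagrees with the first.
    assert(result.first->second == batch_axis);
  };

  dump("op_a/out_0", 0);
  dump("op_a/out_0", 0);  // OK now: same value re-dumped.
  // dump("op_a/out_0", 1);  // Would still trip the assert: conflicting axis.
  return 0;
}
```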
......
@@ -149,10 +149,10 @@ class OpGraph final : public Graph<OpNode, OpEdge> {
   void ForEachDataAndCtrlInNode(OpNode* node, const std::function<void(OpNode*)>& Handler) const;
   void ForEachDataAndCtrlOutNode(OpNode* node, const std::function<void(OpNode*)>& Handler) const;
 
-  void DumpLogicalBlobDesc(JobBuilder* job_builder) const;
-  void DumpSbpSignature(JobBuilder* job_builder) const;
-  void DumpOpTimeShape(JobBuilder* job_builder) const;
-  void DumpBatchAxisLbi(JobBuilder* job_builder) const;
+  void DumpLogicalBlobDesc(Job* job) const;
+  void DumpSbpSignature(Job* job) const;
+  void DumpOpTimeShape(Job* job) const;
+  void DumpBatchAxisLbi(Job* job) const;
 
  private:
   void Init(const Job& job);
......
#include "oneflow/core/job/job_build_and_infer_ctx.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/common/protobuf.h"
namespace oneflow {
......@@ -40,23 +41,28 @@ Maybe<void> JobBuildAndInferCtx::SetJobConf(const JobConfigProto& job_conf) {
return Maybe<void>::Ok();
}
REGISTER_FUNCTION_CONFIG_DEF().Bool("is_user_function", true, "is user defined function");
Maybe<void> JobBuildAndInferCtx::Complete() {
CHECK_NOTNULL(Global<JobDesc>::Get());
Global<JobDesc>::Delete();
auto scope = std::make_unique<GlobalJobDescScope>(job_->job_conf(), job_id_);
auto DoPass = [&](const std::string& pass_name) { FunctionPass(pass_name)(job_); };
DoPass("CompleteOfrecordDecoder");
DoPass("SetDefaultVariableConf");
DoPass("AutoMixedPrecision");
DoPass("TieUpChainHeadersUnReachableFromAnyVariableOps");
DoPass("NonDistributedOptimizerPass");
DoPass("AutoTrainStep");
DoPass("AutoLearningRate");
DoPass("GenerateBackwardAndOptimizerOpConfs");
DoPass("SequentializeNcclTupleBroadcastReducePass");
DoPass("AddAllReduceGroupPass");
DoPass("AddLbiDiffWatcherOpConfs");
DoPass("SequentializeAllReduceGroupPass");
if (GlobalJobDesc().Bool("is_user_function")) {
DoPass("CompleteOfrecordDecoder");
DoPass("SetDefaultVariableConf");
DoPass("AutoMixedPrecision");
DoPass("TieUpChainHeadersUnReachableFromAnyVariableOps");
DoPass("NonDistributedOptimizerPass");
DoPass("AutoTrainStep");
DoPass("AutoLearningRate");
DoPass("GenerateBackwardAndOptimizerOpConfs");
DoPass("SequentializeNcclTupleBroadcastReducePass");
DoPass("AddAllReduceGroupPass");
DoPass("AddLbiDiffWatcherOpConfs");
DoPass("SequentializeAllReduceGroupPass");
}
DoPass("DumpTimeShapeAndBlobParallelConfPass");
return Maybe<void>::Ok();
}
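For context on the gating above: `REGISTER_FUNCTION_CONFIG_DEF().Bool(...)` declares a per-job boolean flag (default `true`), and `GlobalJobDesc().Bool("is_user_function")` reads it back, so the rewrite passes run only for user-defined functions while the new dump pass always runs. A minimal sketch of that flag-gated pipeline shape; all names below are illustrative stand-ins, not OneFlow's actual API:

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Job {};  // illustrative stand-in for the real Job proto

void RunPipeline(Job* job, bool is_user_function,
                 const std::function<void(const std::string&, Job*)>& RunPass) {
  auto DoPass = [&](const std::string& name) { RunPass(name, job); };
  if (is_user_function) {
    // Rewrites that only make sense for user-defined functions.
    for (const auto& name : std::vector<std::string>{
             "AutoMixedPrecision", "GenerateBackwardAndOptimizerOpConfs"}) {
      DoPass(name);
    }
  }
  // Always dump time shapes / parallel conf, even for system-generated jobs.
  DoPass("DumpTimeShapeAndBlobParallelConfPass");
}

int main() {
  Job job;
  RunPipeline(&job, /*is_user_function=*/false,
              [](const std::string& name, Job*) { std::cout << name << "\n"; });
  // Prints only DumpTimeShapeAndBlobParallelConfPass.
}
```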
......
@@ -678,16 +678,17 @@ void MakePushJob(const std::string& job_name, const std::string& op_name,
 void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
   std::vector<Job> jobs(conf_jobs.size());
   std::vector<Plan> sub_plans(conf_jobs.size());
-  FOR_RANGE(int64_t, job_id, 0, sub_plans.size()) {
-    jobs.at(job_id) = conf_jobs.Get(job_id);
-    AddJobName2JobId(jobs.at(job_id).job_conf().job_name(), job_id);
+  size_t user_job_size = jobs.size();
+  int64_t job_id = -1;
+  FOR_RANGE(int64_t, i, 0, sub_plans.size()) {
+    jobs.at(i) = conf_jobs.Get(i);
+    AddJobName2JobId(jobs.at(i).job_conf().job_name(), i);
     {
-      auto scope = std::make_unique<GlobalJobDescScope>(jobs.at(job_id).job_conf(), job_id);
-      CompileCurJobOnMaster(&jobs.at(job_id), &sub_plans.at(job_id), true);
+      auto scope = std::make_unique<GlobalJobDescScope>(jobs.at(i).job_conf(), i);
+      CompileCurJobOnMaster(&jobs.at(i), &sub_plans.at(i), true);
     }
   }
   if (Global<MachineCtx>::Get()->IsThisMachineMaster()) {
-    size_t user_job_size = jobs.size();
     HashMap<std::string, ParallelBlobConf> push_op_name2parallel_blob_conf;
     FilterOpName2ParallelBlobConf({OperatorConf::kInputConf}, &jobs,
                                   &push_op_name2parallel_blob_conf);
@@ -697,7 +698,6 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
     HashMap<std::string, ParallelBlobConf> var_op_name2parallel_blob_conf;
     FilterOpName2ParallelBlobConf({OperatorConf::kVariableConf}, &jobs,
                                   &var_op_name2parallel_blob_conf);
-    int64_t job_id = -1;
     {
       size_t helper_job_size =
           push_op_name2parallel_blob_conf.size() + pull_op_name2parallel_blob_conf.size();
@@ -727,6 +727,8 @@ void CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan) {
       CompileHelperJob(&pull_job);
     }
     MakeModelIoJobs(jobs, var_op_name2parallel_blob_conf, [&](Job* job) { CompileHelperJob(job); });
+  }
+  if (Global<MachineCtx>::Get()->IsThisMachineMaster()) {
     MergeSubPlanWithoutGenNetTopo(plan, sub_plans);
     InterJobMemSharingUtil::MergeMemReusedChunkBetweenUserJobs(jobs, plan, user_job_size);
     InterJobMemSharingUtil::MergeMemSharedInterfaceMemBlockBetweenJobs(jobs, plan);
......
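A note on the `CompileAndMergePlanOnMaster` hunk above: `user_job_size` is now captured before any helper jobs are appended to `jobs` (and the loop index is renamed to `i` so it no longer doubles as the hoisted `job_id` counter), because the final `MergeMemReusedChunkBetweenUserJobs` call must distinguish the original user jobs from helpers added in between. A minimal sketch of that capture-before-append pattern (names are hypothetical):

```cpp
#include <cassert>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> jobs{"train", "eval"};

  // Capture the count of user jobs *before* helper jobs are appended;
  // taking jobs.size() afterwards would overcount.
  size_t user_job_size = jobs.size();

  jobs.push_back("push_helper");  // helper jobs appended later
  jobs.push_back("pull_helper");

  assert(user_job_size == 2 && jobs.size() == 4);
  // A memory-sharing merge would apply only to jobs[0 .. user_job_size).
  return 0;
}
```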
#include "oneflow/core/common/util.h"
#include "oneflow/core/job_completer/op_graph_pass.h"
#include "oneflow/core/job/job.pb.h"
namespace oneflow {
namespace {
class DumpTimeShapeAndBlobParallelConfPass final : public OpGraphPass {
public:
OF_DISALLOW_COPY_AND_MOVE(DumpTimeShapeAndBlobParallelConfPass);
DumpTimeShapeAndBlobParallelConfPass() = default;
~DumpTimeShapeAndBlobParallelConfPass() override = default;
bool IsEnabled() const override { return true; }
void Apply(const OpGraph& op_graph, Job* job) const override {
op_graph.DumpOpTimeShape(job);
op_graph.DumpBatchAxisLbi(job);
op_graph.DumpLogicalBlobDesc(job);
op_graph.DumpSbpSignature(job);
}
};
REGISTER_FUNCTION_PASS("DumpTimeShapeAndBlobParallelConfPass",
DumpTimeShapeAndBlobParallelConfPass);
} // namespace
} // namespace oneflow
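The new pass relies on the `REGISTER_FUNCTION_PASS` / `FunctionPass` machinery: registration maps a string to a pass instance during static initialization, and `FunctionPass("...")` looks the pass up by name, so call sites need no compile-time dependency on the pass class. A minimal sketch of how such a name-keyed registry can work; the registry, macro, and base class below are illustrative assumptions, not OneFlow's actual implementation:

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>

struct Job {};  // stand-in for the real Job proto

// Illustrative base class mirroring OpGraphPass's Apply-on-Job entry point.
class Pass {
 public:
  virtual ~Pass() = default;
  virtual void Apply(Job* job) const = 0;
};

// Name-keyed registry, filled at static-initialization time.
std::map<std::string, std::unique_ptr<Pass>>& PassRegistry() {
  static std::map<std::string, std::unique_ptr<Pass>> registry;
  return registry;
}

// Sketch of what a REGISTER_FUNCTION_PASS-style macro could expand to.
#define REGISTER_PASS_SKETCH(name, T) \
  static bool _reg_##T = (PassRegistry().emplace(name, std::make_unique<T>()), true)

class DumpPassSketch final : public Pass {
 public:
  void Apply(Job* /*job*/) const override { /* dump shapes, sbp, ... */ }
};
REGISTER_PASS_SKETCH("DumpTimeShapeAndBlobParallelConfPass", DumpPassSketch);

// Lookup-by-name counterpart of FunctionPass(name)(job).
void RunPass(const std::string& name, Job* job) {
  auto it = PassRegistry().find(name);
  assert(it != PassRegistry().end());
  it->second->Apply(job);
}

int main() {
  Job job;
  RunPass("DumpTimeShapeAndBlobParallelConfPass", &job);
}
```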
@@ -73,20 +73,10 @@ void SetCtrlInOpName4VariableOp(const OpGraph& op_graph, JobBuilder* job_builder
   }
 }
 
-void SetOpTimeShape7BatchAxisLbis(const OpGraph& op_graph, JobBuilder* job_builder) {
-  op_graph.DumpOpTimeShape(job_builder);
-  op_graph.DumpBatchAxisLbi(job_builder);
-}
-
-void DumpLogicalBlobDescAndSbpSignature(const OpGraph& op_graph, JobBuilder* job_builder) {
-  op_graph.DumpLogicalBlobDesc(job_builder);
-  op_graph.DumpSbpSignature(job_builder);
-}
-
 }  // namespace
 
 void JobCompleter::Complete(Job* job) const {
-  WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature);
+  FunctionPass("DumpTimeShapeAndBlobParallelConfPass")(job);
   WithOpGraphAndMutJobBuilder(job, &GroupBoxingByDstParallel);
   WithOpGraphAndMutJobBuilder(job, &AddKeepHeaderOnlyOp);
   WithOpGraphAndMutJobBuilder(job, &SetCtrlInOpName4VariableOp);
@@ -97,9 +87,7 @@ void JobCompleter::Complete(Job* job) const {
   AddGlobalTotalJobCriticalSection(*job);
   WithOpGraphAndMutJobBuilder(job, &AddGlobalInputCriticalSections);
   WithOpGraphAndMutJobBuilder(job, &AddGlobalOutputCriticalSections);
-  WithOpGraphAndMutJobBuilder(job, &DumpLogicalBlobDescAndSbpSignature);
-  WithOpGraphAndMutJobBuilder(job, &SetOpTimeShape7BatchAxisLbis);
+  FunctionPass("DumpTimeShapeAndBlobParallelConfPass")(job);
   if (XrtCompilationEnabled(GlobalJobDesc())) {
 #ifdef OF_WITH_XRT
     WithOpGraphAndMutJob(job, &RebuildXrtCompiledJob);
......
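Reviewer note: the helpers deleted above took a `JobBuilder*`, so they had to be wrapped by `WithOpGraphAndMutJobBuilder`; the pass variant takes the `Job` proto directly and can be invoked from anywhere the pass registry is visible. A sketch of what a wrapper like `WithOpGraphAndMutJobBuilder` presumably does, inferred from the call sites (stand-in types, not the verified OneFlow implementation):

```cpp
#include <functional>

// Illustrative stand-ins for the real OneFlow types.
struct Job {};
struct OpGraph { explicit OpGraph(const Job&) {} };
struct JobBuilder { explicit JobBuilder(Job*) {} };

// Presumed shape of WithOpGraphAndMutJobBuilder: rebuild the graph view of
// the job, then hand both views to the rewriting function.
void WithOpGraphAndMutJobBuilder(
    Job* job, const std::function<void(const OpGraph&, JobBuilder*)>& Handler) {
  OpGraph op_graph(*job);       // graph view derived from the current job
  JobBuilder job_builder(job);  // mutation handle on the same job
  Handler(op_graph, &job_builder);
}

int main() {
  Job job;
  WithOpGraphAndMutJobBuilder(&job, [](const OpGraph&, JobBuilder*) {
    // a rewrite such as GroupBoxingByDstParallel would run here
  });
}
```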