From 1467aecf1826d572865f38eb0e9e655b566f27db Mon Sep 17 00:00:00 2001
From: Li Xinqi
Date: Sun, 22 Nov 2020 04:06:11 +0800
Subject: [PATCH] Op collection (#3833)

* more cfg files
* instructions builder
* forward declaration instead of include
* more test for cfg
* revert cfg files
* InstructionsBuilder
* using std::function as argument of IdCache::FindOrCreate
* scope op_collection
* include in framework/interpreter.h
* puts more code into WithOptimizerOpCollectionScope
* include in symbol_id_cache.h
* calculation pass
* refine Error
* AddScopeToPyStorage
* fix test_watch
* get scope_symbol_id from current scope
* fix assert bug

Co-authored-by: binbinHan
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
---
 .../api/python/framework/python_callback.cpp  |  6 ++
 oneflow/core/common/error.cpp                 | 16 ++--
 oneflow/core/common/error.h                   | 18 +++-
 oneflow/core/common/error.proto               |  1 +
 oneflow/core/common/maybe.h                   | 28 +++---
 oneflow/core/framework/interpreter.cpp        |  8 ++
 oneflow/core/job/foreign_callback.h           |  6 ++
 oneflow/core/job/job_build_and_infer_ctx.cpp  | 15 +--
 oneflow/core/job/job_build_and_infer_ctx.h    |  9 +-
 oneflow/core/job/scope.proto                  |  1 +
 .../job_rewriter/auto_mixed_precision.cpp     |  1 +
 .../core/job_rewriter/calculation_pass.cpp    | 24 +++++
 oneflow/core/job_rewriter/calculation_pass.h  | 29 ++++++
 ...nerate_backward_and_optimizer_op_confs.cpp | 95 ++++++++++++++++---
 oneflow/core/job_rewriter/identity_grad.cpp   |  2 +
 oneflow/core/operator/distribute_clone_op.cpp |  6 +-
 oneflow/python/eager/interpreter_callback.py  | 15 +++
 oneflow/python/framework/python_callback.py   |  9 ++
 18 files changed, 230 insertions(+), 59 deletions(-)
 create mode 100644 oneflow/core/job_rewriter/calculation_pass.cpp
 create mode 100644 oneflow/core/job_rewriter/calculation_pass.h

diff --git a/oneflow/api/python/framework/python_callback.cpp b/oneflow/api/python/framework/python_callback.cpp
index 1105a619c1..1c00575822 100644
--- a/oneflow/api/python/framework/python_callback.cpp
+++ b/oneflow/api/python/framework/python_callback.cpp
@@ -61,6 +61,12 @@ class PyForeignCallback : public ForeignCallback {
                       is_mirrored);
   }
 
+  // TODO(lixinqi): remove this ugly api after the python code is migrated into cpp code
+  void AddScopeToPyStorage(int64_t scope_symbol_id,
+                           const std::string& scope_proto_str) const override {
+    PYBIND11_OVERRIDE(void, ForeignCallback, AddScopeToPyStorage, scope_symbol_id, scope_proto_str);
+  }
+
   int64_t MakeParallelDescSymbol(
       const std::shared_ptr<cfg::ParallelConf>& parallel_conf) const override {
     PYBIND11_OVERRIDE(int64_t, ForeignCallback, MakeParallelDescSymbol, parallel_conf);
diff --git a/oneflow/core/common/error.cpp b/oneflow/core/common/error.cpp
index 99822a0fd3..9571a60c72 100644
--- a/oneflow/core/common/error.cpp
+++ b/oneflow/core/common/error.cpp
@@ -28,6 +28,13 @@ void LogError(const Error& error) {
 
 }  // namespace
 
+Error&& Error::AddStackFrame(const std::string& location, const std::string& function) {
+  auto* stack_frame = error_proto_->add_stack_frame();
+  stack_frame->set_location(location);
+  stack_frame->set_function(function);
+  return std::move(*this);
+}
+
 Error::operator std::string() const { return PbMessage2TxtString(*error_proto_); }
 
 Error Error::Ok() { return std::make_shared<ErrorProto>(); }
@@ -223,13 +230,4 @@ Error Error::GradientFunctionNotFound() {
   return error;
 }
 
-Error&& operator<=(const std::pair<std::string, std::string>& loc_and_func, Error&& error) {
-  LogError(error);
-  CHECK(error.error_proto()->stack_frame().empty());
-  auto* stack_frame = error.error_proto()->add_stack_frame();
-  stack_frame->set_location(loc_and_func.first);
-  stack_frame->set_function(loc_and_func.second);
-  return std::move(error);
-}
-
 }  // namespace oneflow
diff --git a/oneflow/core/common/error.h b/oneflow/core/common/error.h
index 001a176e67..b74082b52e 100644
--- a/oneflow/core/common/error.h
+++ b/oneflow/core/common/error.h
@@ -28,6 +28,10 @@ class Error final {
   Error(const Error&) = default;
   ~Error() = default;
 
+  // an r-value reference is used to support expressions like `Error().AddStackFrame("foo.cpp",
+  // "Bar") << "invalid value"` because operator<<() needs an r-value reference
+  Error&& AddStackFrame(const std::string& location, const std::string& function);
+
   static Error Ok();
   static Error ProtoParseFailedError();
   static Error JobSetEmptyError();
@@ -66,7 +70,8 @@ class Error final {
   static Error GradientFunctionNotFound();
 
   std::shared_ptr<ErrorProto> error_proto() const { return error_proto_; }
-  ErrorProto* operator->() const { return error_proto_.get(); }
+  const ErrorProto* operator->() const { return error_proto_.get(); }
+  ErrorProto* operator->() { return error_proto_.get(); }
   operator std::string() const;
   void Assign(const Error& other) { error_proto_ = other.error_proto_; }
 
@@ -74,11 +79,17 @@ class Error final {
   std::shared_ptr<ErrorProto> error_proto_;
 };
 
+// an r-value reference is used to support expressions like `Error() << "invalid value"`
 template<typename T>
 Error&& operator<<(Error&& error, const T& x) {
   std::ostringstream ss;
   ss << x;
-  error->set_msg(error->msg() + ss.str());
+  if (error->stack_frame().empty()) {
+    error->set_msg(error->msg() + ss.str());
+  } else {
+    auto* stack_frame_top = error->mutable_stack_frame(error->stack_frame_size() - 1);
+    stack_frame_top->set_error_msg(stack_frame_top->error_msg() + ss.str());
+  }
   return std::move(error);
 }
 
@@ -88,9 +99,6 @@ inline Error&& operator<<(Error&& error, const Error& other) {
   return std::move(error);
 }
 
-// for LOG(ERROR)
-Error&& operator<=(const std::pair<std::string, std::string>& loc_and_func, Error&& error);
-
 }  // namespace oneflow
 
 #endif  // ONEFLOW_CORE_COMMON_ERROR_H_
diff --git a/oneflow/core/common/error.proto b/oneflow/core/common/error.proto
index 05850dc99b..c0148cd719 100644
--- a/oneflow/core/common/error.proto
+++ b/oneflow/core/common/error.proto
@@ -112,6 +112,7 @@ message UnkownError {
 }
 
 message ErrorStackFrame {
   required string location = 1;
   required string function = 2;
+  required string error_msg = 3;
 }
 
 message ErrorProto {
diff --git a/oneflow/core/common/maybe.h b/oneflow/core/common/maybe.h
index ad56e1d284..caa78d2901 100644
--- a/oneflow/core/common/maybe.h
+++ b/oneflow/core/common/maybe.h
@@ -222,14 +222,10 @@ inline bool MaybeIsOk(Maybe<void>&& maybe) {
 
 #define CHECK_OK(...) CHECK(MaybeIsOk(std::move(__VA_ARGS__)))
 
-#define OF_RETURN_IF_ERROR(...)                                                           \
-  MAYBE_CONST_AUTO_REF maybe_##__LINE__ = __MaybeErrorStackCheckWrapper__(__VA_ARGS__);   \
-  if (!maybe_##__LINE__.IsOk()) {                                                         \
-    auto* stack_frame = maybe_##__LINE__.error()->add_stack_frame();                      \
-    stack_frame->set_location(MAYBE_FAILED_LOC);                                          \
-    stack_frame->set_function(__FUNCTION__);                                              \
-    return maybe_##__LINE__.error();                                                      \
-  }
+#define OF_RETURN_IF_ERROR(...)                                                           \
+  for (MAYBE_CONST_AUTO_REF maybe_##__LINE__ = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \
+       !maybe_##__LINE__.IsOk();)                                                            \
+  return Error(maybe_##__LINE__.error()).AddStackFrame(MAYBE_FAILED_LOC, __FUNCTION__)
 
 #else
 
 #error statement expression is no supported, please implement try-catch version of JUST
 
 }  // namespace oneflow
 
-#define OF_TODO() \
-  return std::pair<std::string, std::string>(MAYBE_FAILED_LOC, __FUNCTION__) <= Error::Todo()
-#define OF_UNIMPLEMENTED() \
-  return std::pair<std::string, std::string>(MAYBE_FAILED_LOC, __FUNCTION__) \
-         <= Error::Unimplemented()
+#define OF_TODO() return Error::Todo().AddStackFrame(MAYBE_FAILED_LOC, __FUNCTION__)
+#define OF_UNIMPLEMENTED() \
+  return Error::Unimplemented().AddStackFrame(MAYBE_FAILED_LOC, __FUNCTION__)
 
-#define CHECK_OR_RETURN(expr) \
-  if (!(expr)) \
-    return std::pair<std::string, std::string>(MAYBE_FAILED_LOC, __FUNCTION__) \
-           <= Error::CheckFailedError() << " Check failed: " << OF_PP_STRINGIZE(expr) << "\t"
+#define CHECK_OR_RETURN(expr)                                                    \
+  if (!(expr))                                                                   \
+  return Error::CheckFailedError().AddStackFrame(MAYBE_FAILED_LOC, __FUNCTION__) \
+         << " Check failed: " << OF_PP_STRINGIZE(expr) << "\t"
 
 #define CHECK_EQ_OR_RETURN(lhs, rhs) \
   CHECK_OR_RETURN((lhs) == (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") "
diff --git a/oneflow/core/framework/interpreter.cpp b/oneflow/core/framework/interpreter.cpp
index 04197f0f2f..0aab9b473d 100644
--- a/oneflow/core/framework/interpreter.cpp
+++ b/oneflow/core/framework/interpreter.cpp
@@ -25,6 +25,10 @@ LogicalInterpreter::LogicalInterpreter()
 Maybe<void> LogicalInterpreter::Run(const std::function<Maybe<void>(InstructionsBuilder*)>& Build) {
   InstructionsBuilder instructions_builder(mut_id_generator());
   JUST(Build(&instructions_builder));
+  if (instructions_builder.instruction_list().instruction().empty()) {
+    CHECK(instructions_builder.eager_symbol_list().eager_symbol().empty());
+    return Maybe<void>::Ok();
+  }
   return Global::Get()->RunLogicalInstruction(
       instructions_builder.instruction_list(), instructions_builder.eager_symbol_list());
 }
@@ -36,6 +40,10 @@ Maybe<void> PhysicalInterpreter::Run(
     const std::function<Maybe<void>(InstructionsBuilder*)>& Build) {
   InstructionsBuilder instructions_builder(mut_id_generator());
   JUST(Build(&instructions_builder));
+  if (instructions_builder.instruction_list().instruction().empty()) {
+    CHECK(instructions_builder.eager_symbol_list().eager_symbol().empty());
+    return Maybe<void>::Ok();
+  }
   return Global::Get()->RunPhysicalInstruction(
       instructions_builder.instruction_list(), instructions_builder.eager_symbol_list());
 }
diff --git a/oneflow/core/job/foreign_callback.h b/oneflow/core/job/foreign_callback.h
index 8a36a35c6c..33b3976cce 100644
--- a/oneflow/core/job/foreign_callback.h
+++ b/oneflow/core/job/foreign_callback.h
@@ -41,6 +41,12 @@ class ForeignCallback {
 
   virtual void RemoveForeignCallback(int64_t unique_id) const { UNIMPLEMENTED(); }
 
+  // TODO(lixinqi): remove this ugly api after the python code is migrated into cpp code
+  virtual void AddScopeToPyStorage(int64_t scope_symbol_id,
+                                   const std::string& scope_proto_str) const {
+    UNIMPLEMENTED();
+  }
+
   // return scope_symbol_id
   virtual int64_t MakeScopeSymbol(const std::shared_ptr<cfg::JobConfigProto>& job_conf,
                                   const std::shared_ptr<cfg::ParallelConf>& parallel_conf,
diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp
index 4042d4172b..dc66e604ea 100644
--- a/oneflow/core/job/job_build_and_infer_ctx.cpp
+++ b/oneflow/core/job/job_build_and_infer_ctx.cpp
@@ -470,7 +470,7 @@ Maybe<OpAttribute> JobBuildAndInferCtx::AddAndInferMirroredOp(const OperatorConf
   FOR_RANGE(int32_t, i, 0, sub_op_list_size) {
     ResetOpConfName(&sub_op_conf, GetSubOpName(i));
     for (const auto& ibn : op->input_bns()) {
-      const auto& lbi = *JUST(GetSubLbi(op->BnInOp2Lbi(ibn), i));
+      const auto& lbi = *JUST(GetSubLbi(op_conf.scope_symbol_id(), op->BnInOp2Lbi(ibn), i));
       ReplaceInputLbnInOpCustomizedConf(&sub_op_conf, ibn, GenLogicalBlobName(lbi));
     }
     const ParallelConf& parallel_conf = GetMirroredOpParallelConf(parallel_desc, i);
@@ -495,11 +495,13 @@ Maybe<OpAttribute> JobBuildAndInferCtx::AddAndInferMirroredOp(const OperatorConf
   return last_op_attribute;
 }
 
-Maybe<LogicalBlobId*> JobBuildAndInferCtx::GetSubLbi(const LogicalBlobId& lbi,
+Maybe<LogicalBlobId*> JobBuildAndInferCtx::GetSubLbi(int64_t scope_symbol_id,
+                                                     const LogicalBlobId& lbi,
                                                      int32_t index) {
   auto lbi_vec_iter = mirrored_lbi2sub_lbis_.find(lbi);
   if (lbi_vec_iter == mirrored_lbi2sub_lbis_.end()) {
-    const auto& new_lbi = JUST(FindOrCreateMirroredLbiFromCompatibleConsistentBlob(lbi));
+    const auto& new_lbi =
+        JUST(FindOrCreateMirroredLbiFromCompatibleConsistentBlob(scope_symbol_id, lbi));
     lbi_vec_iter = mirrored_lbi2sub_lbis_.find(*new_lbi);
     CHECK(lbi_vec_iter != mirrored_lbi2sub_lbis_.end());
   }
@@ -836,7 +838,7 @@ ParallelConf EagerJobBuildAndInferCtx::GetMirroredOpParallelConf(const ParallelD
 }
 
 Maybe<LogicalBlobId> LazyJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompatibleConsistentBlob(
-    const LogicalBlobId& lbi) {
+    int64_t scope_symbol_id, const LogicalBlobId& lbi) {
   const std::string& lbn = GenLogicalBlobName(lbi);
   const auto& sbn_it = mut_consistent_lbi2mirrored_lbi()->find(lbi);
   if (sbn_it != mut_consistent_lbi2mirrored_lbi()->end()) { return sbn_it->second; }
@@ -855,6 +857,7 @@ Maybe<LogicalBlobId> LazyJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompati
     lbi_vec->push_back(sub_lbi);
   };
   OperatorConf op_conf;
+  op_conf.set_scope_symbol_id(scope_symbol_id);
   op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(parallel_desc.device_type())));
   if (sbp.has_broadcast_parallel()) {
     op_conf.set_name(kAutoMirroredBlobNamePrefix + "-DistributeClone-" + NewUniqueId());
@@ -885,7 +888,6 @@ Maybe<LogicalBlobId> LazyJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompati
   {
     const auto& producer_op_conf = JUST(Op4OpName(lbi.op_name()))->op_conf();
     CHECK_OR_RETURN(producer_op_conf.has_scope_symbol_id());
-    int64_t scope_symbol_id = producer_op_conf.scope_symbol_id();
     const auto& scope = Global<vm::SymbolStorage<Scope>>::Get()->Get(scope_symbol_id);
     const auto* job_desc = JUST(scope.job_desc());
     JUST(AddAndInferOp(op_conf, parallel_desc.parallel_conf(), job_desc, false));
@@ -894,7 +896,7 @@ Maybe<LogicalBlobId> LazyJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompati
 }
 
 Maybe<LogicalBlobId> EagerJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompatibleConsistentBlob(
-    const LogicalBlobId& lbi) {
+    int64_t scope_symbol_id, const LogicalBlobId& lbi) {
   const std::string& lbn = GenLogicalBlobName(lbi);
   const auto& sbn_it = mut_consistent_lbi2mirrored_lbi()->find(lbi);
   if (sbn_it != mut_consistent_lbi2mirrored_lbi()->end()) { return sbn_it->second; }
@@ -909,6 +911,7 @@ Maybe<LogicalBlobId> EagerJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompat
     CHECK_OR_RETURN(producer_op_conf.has_scope_symbol_id());
     op_conf.set_scope_symbol_id(producer_op_conf.scope_symbol_id());
   }
+  op_conf.set_scope_symbol_id(scope_symbol_id);
   // const char* device_tag = JUST(DeviceTag4DeviceType(parallel_desc.device_type()));
   op_conf.set_device_tag(JUST(DeviceTag4DeviceType(parallel_desc.device_type())));
   op_conf.set_name(kAutoMirroredBlobNamePrefix + "-CastToMirrored-" + NewUniqueId());
diff --git a/oneflow/core/job/job_build_and_infer_ctx.h b/oneflow/core/job/job_build_and_infer_ctx.h
index 40c789f606..b8db19f9f7 100644
--- a/oneflow/core/job/job_build_and_infer_ctx.h
+++ b/oneflow/core/job/job_build_and_infer_ctx.h
@@ -78,7 +78,7 @@ class JobBuildAndInferCtx {
                                                   int64_t parallel_id) const = 0;
   virtual bool GetIsMirroredParallelView() const = 0;
   virtual Maybe<LogicalBlobId> FindOrCreateMirroredLbiFromCompatibleConsistentBlob(
-      const LogicalBlobId& lbn) = 0;
+      int64_t scope_symbol_id, const LogicalBlobId& lbn) = 0;
 
   Job* mut_job() const { return job_; }
   int64_t job_id() const { return job_id_; }
@@ -128,7 +128,8 @@ class JobBuildAndInferCtx {
   Maybe<void> CheckAllInputsConvertableToMirroredBlob(const Operator& op) const;
   Maybe<void> AddLossConsistentBlobName(const std::string& lbn);
   Maybe<void> AddLossMirroredBlobName(const std::string& lbn);
-  Maybe<LogicalBlobId*> GetSubLbi(const LogicalBlobId& lbi, int32_t index);
+  Maybe<LogicalBlobId*> GetSubLbi(int64_t scope_symbol_id, const LogicalBlobId& lbi,
+                                  int32_t index);
   Maybe<bool> AllInputsBroadcastParallel(const Operator& op) const;
   void InferBlobBackwardSignature(Operator* op);
   void InferBlobBackwardSignature(const Operator& op,
@@ -167,7 +168,7 @@ class LazyJobBuildAndInferCtx : public JobBuildAndInferCtx {
   ParallelConf GetMirroredOpParallelConf(const ParallelDesc&, int64_t parallel_id) const override;
   bool GetIsMirroredParallelView() const override { return false; }
   Maybe<LogicalBlobId> FindOrCreateMirroredLbiFromCompatibleConsistentBlob(
-      const LogicalBlobId& lbn) override;
+      int64_t scope_symbol_id, const LogicalBlobId& lbn) override;
 };
 
 class EagerJobBuildAndInferCtx : public JobBuildAndInferCtx {
@@ -185,7 +186,7 @@ class EagerJobBuildAndInferCtx : public JobBuildAndInferCtx {
   ParallelConf GetMirroredOpParallelConf(const ParallelDesc&, int64_t parallel_id) const override;
   bool GetIsMirroredParallelView() const override { return true; }
   Maybe<LogicalBlobId> FindOrCreateMirroredLbiFromCompatibleConsistentBlob(
-      const LogicalBlobId& lbn) override;
+      int64_t scope_symbol_id, const LogicalBlobId& lbn) override;
 
   HashSet<std::string> executed_op_names_;
 };
diff --git a/oneflow/core/job/scope.proto b/oneflow/core/job/scope.proto
index 3644fb82ff..a807f1de21 100644
--- a/oneflow/core/job/scope.proto
+++ b/oneflow/core/job/scope.proto
@@ -14,4 +14,5 @@ message ScopeProto {
   optional int64 parent_scope_symbol_id = 70;
   required int64 session_id = 80;
   map<string, AttrValue> attr_name2attr_value = 90;
+  optional string calculation_pass_name = 100 [default = "forward_pass"];
 }
diff --git a/oneflow/core/job_rewriter/auto_mixed_precision.cpp b/oneflow/core/job_rewriter/auto_mixed_precision.cpp
index ff16460529..74a54000c8 100644
--- a/oneflow/core/job_rewriter/auto_mixed_precision.cpp
+++ b/oneflow/core/job_rewriter/auto_mixed_precision.cpp
@@ -176,6 +176,7 @@ void InsertCastOpImpl(bool f2h, const OpGraph& op_graph, const HashSet<OpNode*>&
             .Input("in", lbn)
             .Output("out")
             .Attr<DataType>("dtype", cast_data_type)
+            .ScopeSymbolId(src_node->op().op_conf().scope_symbol_id())
             .Build();
 
     bool cast_is_consumed = false;
diff --git a/oneflow/core/job_rewriter/calculation_pass.cpp b/oneflow/core/job_rewriter/calculation_pass.cpp
new file mode 100644
index 0000000000..a305906626
--- /dev/null
+++ b/oneflow/core/job_rewriter/calculation_pass.cpp
@@ -0,0 +1,24 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/job_rewriter/calculation_pass.h"
+
+namespace oneflow {
+
+const std::string kForwardPass = "forward_pass";
+const std::string kBackwardPass = "backward_pass";
+const std::string kOptimizerPass = "optimizer_pass";
+
+}  // namespace oneflow
diff --git a/oneflow/core/job_rewriter/calculation_pass.h b/oneflow/core/job_rewriter/calculation_pass.h
new file mode 100644
index 0000000000..5b9b9841ac
--- /dev/null
+++ b/oneflow/core/job_rewriter/calculation_pass.h
@@ -0,0 +1,29 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_CORE_JOB_REWRITE_CALCULATION_PASS_H_
+#define ONEFLOW_CORE_JOB_REWRITE_CALCULATION_PASS_H_
+
+#include <string>
+
+namespace oneflow {
+
+extern const std::string kForwardPass;
+extern const std::string kBackwardPass;
+extern const std::string kOptimizerPass;
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_JOB_REWRITE_CALCULATION_PASS_H_
diff --git a/oneflow/core/job_rewriter/generate_backward_and_optimizer_op_confs.cpp b/oneflow/core/job_rewriter/generate_backward_and_optimizer_op_confs.cpp
index 758313dd8c..a69098a46d 100644
--- a/oneflow/core/job_rewriter/generate_backward_and_optimizer_op_confs.cpp
+++ b/oneflow/core/job_rewriter/generate_backward_and_optimizer_op_confs.cpp
@@ -16,6 +16,12 @@ limitations under the License.
 #include "oneflow/core/job_rewriter/job_pass.h"
 #include "oneflow/core/job_rewriter/autograd.h"
 #include "oneflow/core/job_rewriter/optimizer.h"
+#include "oneflow/core/job_rewriter/calculation_pass.h"
+#include "oneflow/core/job/scope.h"
+#include "oneflow/core/job/foreign_callback.h"
+#include "oneflow/core/vm/symbol_storage.h"
+#include "oneflow/core/framework/interpreter.h"
+#include "oneflow/core/framework/instructions_builder.h"
 
 namespace oneflow {
 
@@ -108,28 +114,87 @@ void FilterModelLbi2DiffLbi(const OpGraph& op_graph,
   }
 }
 
+Maybe<JobBuilder> WithCalculationPassScope(const std::string& pass_name, Job* job,
+                                           const std::function<Maybe<void>()>& Handler) {
+  HashSet<std::string> exists_op_names;
+  for (const auto& op_conf : job->net().op()) {
+    CHECK_OR_RETURN(exists_op_names.emplace(op_conf.name()).second);
+  }
+  JUST(Handler());
+  // using a new JobBuilder to avoid bugs caused by MutOnlyOnce
+  auto new_job_builder = std::make_shared<JobBuilder>(job);
+  HashMap<int64_t, std::vector<const OperatorConf*>> scope_id2op_names;
+  const auto& scope_storage = *Global<vm::SymbolStorage<Scope>>::Get();
+  for (const auto& op_conf : job->net().op()) {
+    if (exists_op_names.count(op_conf.name()) > 0) { continue; }
+    CHECK_OR_RETURN(op_conf.has_scope_symbol_id());
+    OF_RETURN_IF_ERROR(scope_storage.MaybeGet(op_conf.scope_symbol_id())) << op_conf.DebugString();
+    scope_id2op_names[op_conf.scope_symbol_id()].push_back(&op_conf);
+  }
+  const auto& GetNewScopeSymbolId = [&](int64_t old_scope_symbol_id) -> Maybe<int64_t> {
+    const auto& old_scope = JUST(scope_storage.MaybeGet(old_scope_symbol_id));
+    cfg::ScopeProto new_scope;
+    new_scope.InitFromProto(old_scope.scope_proto());
+    new_scope.set_parent_scope_symbol_id(old_scope_symbol_id);
+    new_scope.set_calculation_pass_name(pass_name);
+    int64_t symbol_id = 0;
+    JUST(LogicalInterpreter().Run([&](InstructionsBuilder* builder) -> Maybe<void> {
+      symbol_id = JUST(builder->FindOrCreateSymbolId(new_scope));
+      return Maybe<void>::Ok();
+    }));
+    // Remove this ugly code after most python code is migrated into cpp code
+    {
+      ScopeProto scope_proto;
+      new_scope.ToProto(&scope_proto);
+      Global<ForeignCallback>::Get()->AddScopeToPyStorage(symbol_id, scope_proto.DebugString());
+    }
+    return symbol_id;
+  };
+  for (const auto& pair : scope_id2op_names) {
+    int64_t new_scope_symbol_id = JUST(GetNewScopeSymbolId(pair.first));
+    std::vector<OperatorConf> op_confs(pair.second.size());
+    for (int i = 0; i < pair.second.size(); ++i) {
+      op_confs.at(i).CopyFrom(*pair.second.at(i));
+      op_confs.at(i).set_scope_symbol_id(new_scope_symbol_id);
+    }
+    new_job_builder->MutOpsOnlyOnce(op_confs);
+  }
+  return new_job_builder;
+}
+
 Maybe<void> GenerateBackwardAndOptimizerOpConfs::Apply(Job* job, JobPassCtx* ctx) const {
   if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }
   const OpGraph op_graph(*job);
-  JobBuilder job_builder(job);
+  auto job_builder = std::make_shared<JobBuilder>(job);
+  const JobBuilder* old_job_builder = job_builder.get();
   LogicalBlobId total_loss_instance_num;
   HashMap<LogicalBlobId, LogicalBlobId> lbi2diff_lbi;
-  JUST(AutoGrad(op_graph, &job_builder, &lbi2diff_lbi));
+  job_builder = JUST(WithCalculationPassScope(kBackwardPass, job, [&]() -> Maybe<void> {
+    CHECK(old_job_builder == job_builder.get());  // check that this lambda is never called asynchronously
+    JUST(AutoGrad(op_graph, job_builder.get(), &lbi2diff_lbi));
+    return Maybe<void>::Ok();
+  }));
   HashMap<LogicalBlobId, LogicalBlobId> model_lbi2model_diff_lbi;
   FilterModelLbi2DiffLbi(op_graph, lbi2diff_lbi, &model_lbi2model_diff_lbi);
-  AddDiffStaticShapeCast(op_graph, &job_builder, &model_lbi2model_diff_lbi);
-  AddDiffParallelCast(op_graph, &job_builder, &model_lbi2model_diff_lbi);
-  JUST(ScaleModelDiffByLossInstanceNum(op_graph, &job_builder, &model_lbi2model_diff_lbi));
-  ScaleModelDiffByLossScale(op_graph, &job_builder, &model_lbi2model_diff_lbi);
-  const NormalModelUpdateOpUserConf& model_update_conf =
-      job->job_conf().train_conf().model_update_conf();
-  RegularizeGradient(op_graph, &job_builder, &model_lbi2model_diff_lbi);
-  if (model_update_conf.has_clip_conf()) {
-    ClipGradient(op_graph, &job_builder, &model_lbi2model_diff_lbi, model_update_conf.clip_conf());
-  }
-  AddOptimizerOpConf(ctx, op_graph, &job_builder, model_lbi2model_diff_lbi);
-  UpdateJobHelperConfProducedLbi2ConsumedDiffLbi(lbi2diff_lbi, &job_builder);
-  UpdateOpSbpSignatureHint(op_graph, &job_builder);
+  old_job_builder = job_builder.get();
+  job_builder = JUST(WithCalculationPassScope(kOptimizerPass, job, [&]() -> Maybe<void> {
+    CHECK(old_job_builder == job_builder.get());  // check that this lambda is never called asynchronously
+    AddDiffStaticShapeCast(op_graph, job_builder.get(), &model_lbi2model_diff_lbi);
+    AddDiffParallelCast(op_graph, job_builder.get(), &model_lbi2model_diff_lbi);
+    JUST(ScaleModelDiffByLossInstanceNum(op_graph, job_builder.get(), &model_lbi2model_diff_lbi));
+    ScaleModelDiffByLossScale(op_graph, job_builder.get(), &model_lbi2model_diff_lbi);
+    const NormalModelUpdateOpUserConf& model_update_conf =
+        job->job_conf().train_conf().model_update_conf();
+    RegularizeGradient(op_graph, job_builder.get(), &model_lbi2model_diff_lbi);
+    if (model_update_conf.has_clip_conf()) {
+      ClipGradient(op_graph, job_builder.get(), &model_lbi2model_diff_lbi,
+                   model_update_conf.clip_conf());
+    }
+    AddOptimizerOpConf(ctx, op_graph, job_builder.get(), model_lbi2model_diff_lbi);
+    return Maybe<void>::Ok();
+  }));
+  UpdateJobHelperConfProducedLbi2ConsumedDiffLbi(lbi2diff_lbi, job_builder.get());
+  UpdateOpSbpSignatureHint(op_graph, job_builder.get());
   return Maybe<void>::Ok();
 }
diff --git a/oneflow/core/job_rewriter/identity_grad.cpp b/oneflow/core/job_rewriter/identity_grad.cpp
index 5ec79464fe..3b8ba3143e 100644
--- a/oneflow/core/job_rewriter/identity_grad.cpp
+++ b/oneflow/core/job_rewriter/identity_grad.cpp
@@ -61,6 +61,7 @@ void GenerateCastToMirroredBackwardOpConf(
   if (DiffLbi4BnInOp("in") != nullptr) {
     OperatorConf grad_op{};
     grad_op.set_name("System-AutoGrad-" + op.op_name());
+    grad_op.set_scope_symbol_id(op.op_conf().scope_symbol_id());
     CastFromMirroredOpConf* bw_op_conf = grad_op.mutable_cast_from_mirrored_conf();
     bw_op_conf->set_in(GenLogicalBlobName(*DiffLbi4BnInOp("out")));
     bw_op_conf->set_out("out");
@@ -85,6 +86,7 @@ void GenerateCastFromMirroredBackwardOpConf(
   if (DiffLbi4BnInOp("in") != nullptr) {
     OperatorConf grad_op{};
     grad_op.set_name("System-AutoGrad-" + op.op_name());
+    grad_op.set_scope_symbol_id(op.op_conf().scope_symbol_id());
     CastToMirroredOpConf* bw_op_conf = grad_op.mutable_cast_to_mirrored_conf();
     bw_op_conf->set_in(GenLogicalBlobName(*DiffLbi4BnInOp("out")));
     bw_op_conf->set_out("out");
diff --git a/oneflow/core/operator/distribute_clone_op.cpp b/oneflow/core/operator/distribute_clone_op.cpp
index 18cbd1825d..a2fcb8e1cd 100644
--- a/oneflow/core/operator/distribute_clone_op.cpp
+++ b/oneflow/core/operator/distribute_clone_op.cpp
@@ -64,7 +64,7 @@ Maybe<void> DistributeCloneOp::InferBlobDescs(
     const ParallelContext* parallel_ctx) const {
   const auto& in_blob_desc = *GetBlobDesc4BnInOp("in");
   if (parallel_ctx->parallel_num() > 1) {
-    CHECK_EQ(parallel_ctx->parallel_num(), output_bns().size());
+    CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), output_bns().size());
     auto* out_blob_desc = GetBlobDesc4BnInOp(output_bns().Get(parallel_ctx->parallel_id()));
     *out_blob_desc = in_blob_desc;
     return Maybe<void>::Ok();
@@ -83,7 +83,7 @@ Maybe<void> DistributeCloneOp::InferOutParallelDesc(
   FOR_RANGE(int, i, 0, output_bns().size()) {
     const auto& obn = output_bns().Get(i);
     if (op_parallel_desc.parallel_num() > 1) {
-      CHECK_EQ(op_parallel_desc.parallel_num(), output_bns().size());
+      CHECK_EQ_OR_RETURN(op_parallel_desc.parallel_num(), output_bns().size());
       *ParallelDesc4Obn(obn) = ParallelDesc(op_parallel_desc.GetParallelIdOnlyParallelConf(i));
     } else {
       *ParallelDesc4Obn(obn) = op_parallel_desc;
@@ -100,7 +100,7 @@ Maybe<void> DistributeCloneOp::InferParallelSignature() {
   auto* map = mut_parallel_signature()->mutable_bn_in_op2parallel_desc_symbol_id();
   (*map)["in"] = op_parallel_desc_symbol_id;
   const auto& op_parallel_desc = *JUST(scope.GetParallelDesc(op_conf()));
-  CHECK_EQ(op_parallel_desc.parallel_num(), output_bns().size());
+  CHECK_EQ_OR_RETURN(op_parallel_desc.parallel_num(), output_bns().size());
   FOR_RANGE(int, i, 0, output_bns().size()) {
     const auto& out_parallel_conf = op_parallel_desc.GetParallelIdOnlyParallelConf(i);
     const std::shared_ptr<cfg::ParallelConf>& cfg_out_parallel_conf =
diff --git a/oneflow/python/eager/interpreter_callback.py b/oneflow/python/eager/interpreter_callback.py
index 71414894d6..472ea4180b 100644
--- a/oneflow/python/eager/interpreter_callback.py
+++ b/oneflow/python/eager/interpreter_callback.py
@@ -19,11 +19,26 @@ import oneflow.python.eager.gradient_util as gradient_util
 import oneflow.python.eager.op_executor as op_executor
 import oneflow.core.operator.op_attribute_pb2 as op_attribute_pb
 import oneflow.core.job.job_conf_pb2 as job_conf_pb
+import oneflow.core.job.scope_pb2 as scope_pb
 import oneflow.core.job.placement_pb2 as placement_pb
 from google.protobuf import text_format
 import oneflow.python.eager.blob_register as blob_register_util
 import oneflow.python.framework.scope_util as scope_util
+import oneflow.python.framework.scope_symbol as scope_symbol
 import oneflow.python.eager.vm_util as vm_util
+import oneflow.python.eager.symbol_storage as symbol_storage
+
+
+def AddScopeToStorage(scope_symbol_id, scope_proto_str):
+    if symbol_storage.HasSymbol4SerializedScopeProto(scope_proto_str):
+        return
+    scope_proto = text_format.Parse(scope_proto_str, scope_pb.ScopeProto())
+    parent_scope_symbol = symbol_storage.GetSymbol4Id(
+        scope_proto.parent_scope_symbol_id
+    )
+    symbol = scope_symbol.ScopeSymbol(scope_symbol_id, scope_proto, parent_scope_symbol)
+    symbol_storage.SetSymbol4Id(scope_symbol_id, symbol)
+    symbol_storage.SetSymbol4SerializedScopeProto(scope_proto_str, symbol)
 
 
 def MakeScopeSymbol(job_conf_str, parallel_conf_str, is_mirrored):
diff --git a/oneflow/python/framework/python_callback.py b/oneflow/python/framework/python_callback.py
index bb855ee9eb..e86b9619d5 100644
--- a/oneflow/python/framework/python_callback.py
+++ b/oneflow/python/framework/python_callback.py
@@ -79,6 +79,15 @@ class PythonCallback(oneflow_api.ForeignCallback):
             print(traceback.format_exc())
             raise e
 
+    def AddScopeToPyStorage(self, scope_symbol_id, scope_proto_str):
+        try:
+            return interpreter_callback.AddScopeToStorage(
+                scope_symbol_id, scope_proto_str
+            )
+        except Exception as e:
+            print(traceback.format_exc())
+            raise e
+
     def MakeScopeSymbol(self, job_conf, parallel_conf, is_mirrored):
         try:
             # TODO(hanbinbin): str() will be removed after proto obj is replaced with cfg obj in python side
-- 
GitLab
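
Usage sketch (illustrative only, not part of the patch): the reworked error machinery in
oneflow/core/common/error.h and maybe.h composes so that each propagation hop records its own
ErrorStackFrame. CHECK_OR_RETURN pushes a frame via AddStackFrame() before streaming, so the
streamed text lands in that frame's new error_msg field rather than in ErrorProto.msg, and
OF_RETURN_IF_ERROR re-wraps a failing Maybe with one additional frame per hop. LoadConfig and
InitJob below are hypothetical callers invented for the example; the macros and types are the
ones defined in the patch above.

    Maybe<void> LoadConfig(const std::string& path) {
      // On failure this expands to
      //   return Error::CheckFailedError().AddStackFrame(MAYBE_FAILED_LOC, __FUNCTION__)
      //          << " Check failed: ..." << "config path must not be empty";
      // a stack frame now exists, so operator<< appends to its error_msg.
      CHECK_OR_RETURN(!path.empty()) << "config path must not be empty";
      return Maybe<void>::Ok();
    }

    Maybe<void> InitJob(const std::string& path) {
      // On failure this re-wraps the error as
      //   Error(maybe.error()).AddStackFrame(MAYBE_FAILED_LOC, __FUNCTION__) << " while initializing job"
      // so the resulting ErrorProto carries one ErrorStackFrame per hop, each with its own
      // location/function/error_msg (the same pattern the patch uses in
      // WithCalculationPassScope: OF_RETURN_IF_ERROR(...) << op_conf.DebugString()).
      OF_RETURN_IF_ERROR(LoadConfig(path)) << " while initializing job";
      return Maybe<void>::Ok();
    }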