Unverified  Commit 1467aecf  authored by Li Xinqi, committed by GitHub

Op collection (#3833)

* more cfg files

* instructions builder

* forward declaration instead of include

* more test for cfg

* revert cfg files

* InstructionsBuilder

* using std::function as argument of IdCache::FindOrCreate

* scope op_collection

* include <functional> in framework/interpreter.h

* puts more code into WithOptimizerOpCollectionScope

* include <functional> in symbol_id_cache.h

* calculation pass

* refine Error

* AddScopeToPyStorage

* fix test_watch

* get scope_symbol_id from current scope

* fix assert bug
Co-authored-by: binbinHan <han_binbin@163.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent 58627eec
......@@ -61,6 +61,12 @@ class PyForeignCallback : public ForeignCallback {
is_mirrored);
}
// TODO(lixinqi): remove this ugly api after python code is migrated into cpp code
void AddScopeToPyStorage(int64_t scope_symbol_id,
const std::string& scope_proto_str) const override {
PYBIND11_OVERRIDE(void, ForeignCallback, AddScopeToPyStorage, scope_symbol_id, scope_proto_str);
}
int64_t MakeParallelDescSymbol(
const std::shared_ptr<cfg::ParallelConf>& parallel_conf) const override {
PYBIND11_OVERRIDE(int64_t, ForeignCallback, MakeParallelDescSymbol, parallel_conf);
......
......@@ -28,6 +28,13 @@ void LogError(const Error& error) {
} // namespace
Error&& Error::AddStackFrame(const std::string& location, const std::string& function) {
auto* stack_frame = error_proto_->add_stack_frame();
stack_frame->set_location(location);
stack_frame->set_function(function);
return std::move(*this);
}
Error::operator std::string() const { return PbMessage2TxtString(*error_proto_); }
Error Error::Ok() { return std::make_shared<ErrorProto>(); }
......@@ -223,13 +230,4 @@ Error Error::GradientFunctionNotFound() {
return error;
}
Error&& operator<=(const std::pair<std::string, std::string>& loc_and_func, Error&& error) {
LogError(error);
CHECK(error.error_proto()->stack_frame().empty());
auto* stack_frame = error.error_proto()->add_stack_frame();
stack_frame->set_location(loc_and_func.first);
stack_frame->set_function(loc_and_func.second);
return std::move(error);
}
} // namespace oneflow
......@@ -28,6 +28,10 @@ class Error final {
Error(const Error&) = default;
~Error() = default;
// An r-value reference is used to support expressions like `Error().AddStackFrame("foo.cpp",
// "Bar") << "invalid value"` because operator<<() needs an r-value reference
Error&& AddStackFrame(const std::string& location, const std::string& function);
static Error Ok();
static Error ProtoParseFailedError();
static Error JobSetEmptyError();
......@@ -66,7 +70,8 @@ class Error final {
static Error GradientFunctionNotFound();
std::shared_ptr<ErrorProto> error_proto() const { return error_proto_; }
ErrorProto* operator->() const { return error_proto_.get(); }
const ErrorProto* operator->() const { return error_proto_.get(); }
ErrorProto* operator->() { return error_proto_.get(); }
operator std::string() const;
void Assign(const Error& other) { error_proto_ = other.error_proto_; }
......@@ -74,11 +79,17 @@ class Error final {
std::shared_ptr<ErrorProto> error_proto_;
};
// An r-value reference is used to support expressions like `Error() << "invalid value"`
template<typename T>
Error&& operator<<(Error&& error, const T& x) {
std::ostringstream ss;
ss << x;
error->set_msg(error->msg() + ss.str());
if (error->stack_frame().empty()) {
error->set_msg(error->msg() + ss.str());
} else {
auto* stack_frame_top = error->mutable_stack_frame(error->stack_frame_size() - 1);
stack_frame_top->set_error_msg(stack_frame_top->error_msg() + ss.str());
}
return std::move(error);
}
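For illustration, a minimal hypothetical call site (the values are invented, not taken from this patch) showing how the reworked operator<< above routes streamed text: once AddStackFrame() has pushed a frame, the message is appended to that frame's error_msg; with no frame it still goes into ErrorProto::msg.
// Hypothetical call sites, assuming only the Error API shown above.
Error framed = Error::CheckFailedError().AddStackFrame("foo.cpp", "Bar") << " invalid value";
// framed->stack_frame(0).error_msg() now holds " invalid value"; framed->msg() is untouched.
Error bare = Error::CheckFailedError() << " invalid value";
// bare->msg() now holds " invalid value" because no stack frame has been added yet.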
......@@ -88,9 +99,6 @@ inline Error&& operator<<(Error&& error, const Error& other) {
return std::move(error);
}
// for LOG(ERROR)
Error&& operator<=(const std::pair<std::string, std::string>& loc_and_func, Error&& error);
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_ERROR_H_
......@@ -112,6 +112,7 @@ message UnkownError { }
message ErrorStackFrame {
required string location = 1;
required string function = 2;
required string error_msg = 3;
}
message ErrorProto {
......
......@@ -222,14 +222,10 @@ inline bool MaybeIsOk(Maybe<void>&& maybe) {
#define CHECK_OK(...) CHECK(MaybeIsOk(std::move(__VA_ARGS__)))
#define OF_RETURN_IF_ERROR(...) \
MAYBE_CONST_AUTO_REF maybe_##__LINE__ = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \
if (!maybe_##__LINE__.IsOk()) { \
auto* stack_frame = maybe_##__LINE__.error()->add_stack_frame(); \
stack_frame->set_location(MAYBE_FAILED_LOC); \
stack_frame->set_function(__FUNCTION__); \
return maybe_##__LINE__.error(); \
}
#define OF_RETURN_IF_ERROR(...) \
for (MAYBE_CONST_AUTO_REF maybe_##__LINE__ = __MaybeErrorStackCheckWrapper__(__VA_ARGS__); \
!maybe_##__LINE__.IsOk();) \
return Error(maybe_##__LINE__.error()).AddStackFrame(MAYBE_FAILED_LOC, __FUNCTION__)
#else
#error statement expression is not supported, please implement a try-catch version of JUST
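As an aside, a self-contained sketch of the single-iteration for idiom that the new OF_RETURN_IF_ERROR relies on (FakeResult and RETURN_IF_ERROR_SKETCH are made-up stand-ins, not OneFlow types): the init-statement keeps the result alive for exactly the macro's scope, the body runs at most once and only on failure, and any text appended after the macro invocation becomes part of the return expression.
#include <iostream>
#include <string>

struct FakeResult {  // stand-in for Maybe<T>, purely illustrative
  bool ok;
  std::string error;
};

// Simplified macro mirroring the `for (...) return ...` shape used above.
#define RETURN_IF_ERROR_SKETCH(expr)                      \
  for (const FakeResult& result_ = (expr); !result_.ok;)  \
  return result_.error + " at " + __FUNCTION__

std::string Caller(const FakeResult& r) {
  RETURN_IF_ERROR_SKETCH(r) + " (extra context)";  // the appended text joins the returned message
  return "ok";
}

int main() {
  std::cout << Caller({false, "failed"}) << std::endl;  // prints: failed at Caller (extra context)
  std::cout << Caller({true, ""}) << std::endl;         // prints: ok
}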
......@@ -237,16 +233,14 @@ inline bool MaybeIsOk(Maybe<void>&& maybe) {
} // namespace oneflow
#define OF_TODO() \
return std::pair<std::string, std::string>(MAYBE_FAILED_LOC, __FUNCTION__) <= Error::Todo()
#define OF_UNIMPLEMENTED() \
return std::pair<std::string, std::string>(MAYBE_FAILED_LOC, __FUNCTION__) \
<= Error::Unimplemented()
#define OF_TODO() return Error::Todo().AddStackFrame(MAYBE_FAILED_LOC, __FUNCTION__)
#define OF_UNIMPLEMENTED() \
return Error::Unimplemented().AddStackFrame(MAYBE_FAILED_LOC, __FUNCTION__)
#define CHECK_OR_RETURN(expr) \
if (!(expr)) \
return std::pair<std::string, std::string>(MAYBE_FAILED_LOC, __FUNCTION__) \
<= Error::CheckFailedError() << " Check failed: " << OF_PP_STRINGIZE(expr) << "\t"
#define CHECK_OR_RETURN(expr) \
if (!(expr)) \
return Error::CheckFailedError().AddStackFrame(MAYBE_FAILED_LOC, __FUNCTION__) \
<< " Check failed: " << OF_PP_STRINGIZE(expr) << "\t"
#define CHECK_EQ_OR_RETURN(lhs, rhs) \
CHECK_OR_RETURN((lhs) == (rhs)) << "(" << (lhs) << " vs " << (rhs) << ") "
......
......@@ -25,6 +25,10 @@ LogicalInterpreter::LogicalInterpreter()
Maybe<void> LogicalInterpreter::Run(const std::function<Maybe<void>(InstructionsBuilder*)>& Build) {
InstructionsBuilder instructions_builder(mut_id_generator());
JUST(Build(&instructions_builder));
if (instructions_builder.instruction_list().instruction().empty()) {
CHECK(instructions_builder.eager_symbol_list().eager_symbol().empty());
return Maybe<void>::Ok();
}
return Global<eager::EagerOneflow>::Get()->RunLogicalInstruction(
instructions_builder.instruction_list(), instructions_builder.eager_symbol_list());
}
......@@ -36,6 +40,10 @@ Maybe<void> PhysicalInterpreter::Run(
const std::function<Maybe<void>(InstructionsBuilder*)>& Build) {
InstructionsBuilder instructions_builder(mut_id_generator());
JUST(Build(&instructions_builder));
if (instructions_builder.instruction_list().instruction().empty()) {
CHECK(instructions_builder.eager_symbol_list().eager_symbol().empty());
return Maybe<void>::Ok();
}
return Global<eager::EagerOneflow>::Get()->RunPhysicalInstruction(
instructions_builder.instruction_list(), instructions_builder.eager_symbol_list());
}
......
......@@ -41,6 +41,12 @@ class ForeignCallback {
virtual void RemoveForeignCallback(int64_t unique_id) const { UNIMPLEMENTED(); }
// TODO(lixinqi): remove this ugly api after python code is migrated into cpp code
virtual void AddScopeToPyStorage(int64_t scope_symbol_id,
const std::string& scope_proto_str) const {
UNIMPLEMENTED();
}
// return scope_symbol_id
virtual int64_t MakeScopeSymbol(const std::shared_ptr<cfg::JobConfigProto>& job_conf,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf,
......
......@@ -470,7 +470,7 @@ Maybe<OpAttribute> JobBuildAndInferCtx::AddAndInferMirroredOp(const OperatorConf
FOR_RANGE(int32_t, i, 0, sub_op_list_size) {
ResetOpConfName(&sub_op_conf, GetSubOpName(i));
for (const auto& ibn : op->input_bns()) {
const auto& lbi = *JUST(GetSubLbi(op->BnInOp2Lbi(ibn), i));
const auto& lbi = *JUST(GetSubLbi(op_conf.scope_symbol_id(), op->BnInOp2Lbi(ibn), i));
ReplaceInputLbnInOpCustomizedConf(&sub_op_conf, ibn, GenLogicalBlobName(lbi));
}
const ParallelConf& parallel_conf = GetMirroredOpParallelConf(parallel_desc, i);
......@@ -495,11 +495,13 @@ Maybe<OpAttribute> JobBuildAndInferCtx::AddAndInferMirroredOp(const OperatorConf
return last_op_attribute;
}
Maybe<const LogicalBlobId*> JobBuildAndInferCtx::GetSubLbi(const LogicalBlobId& lbi,
Maybe<const LogicalBlobId*> JobBuildAndInferCtx::GetSubLbi(int64_t scope_symbol_id,
const LogicalBlobId& lbi,
int32_t index) {
auto lbi_vec_iter = mirrored_lbi2sub_lbis_.find(lbi);
if (lbi_vec_iter == mirrored_lbi2sub_lbis_.end()) {
const auto& new_lbi = JUST(FindOrCreateMirroredLbiFromCompatibleConsistentBlob(lbi));
const auto& new_lbi =
JUST(FindOrCreateMirroredLbiFromCompatibleConsistentBlob(scope_symbol_id, lbi));
lbi_vec_iter = mirrored_lbi2sub_lbis_.find(*new_lbi);
CHECK(lbi_vec_iter != mirrored_lbi2sub_lbis_.end());
}
......@@ -836,7 +838,7 @@ ParallelConf EagerJobBuildAndInferCtx::GetMirroredOpParallelConf(const ParallelD
}
Maybe<LogicalBlobId> LazyJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompatibleConsistentBlob(
const LogicalBlobId& lbi) {
int64_t scope_symbol_id, const LogicalBlobId& lbi) {
const std::string& lbn = GenLogicalBlobName(lbi);
const auto& sbn_it = mut_consistent_lbi2mirrored_lbi()->find(lbi);
if (sbn_it != mut_consistent_lbi2mirrored_lbi()->end()) { return sbn_it->second; }
......@@ -855,6 +857,7 @@ Maybe<LogicalBlobId> LazyJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompati
lbi_vec->push_back(sub_lbi);
};
OperatorConf op_conf;
op_conf.set_scope_symbol_id(scope_symbol_id);
op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(parallel_desc.device_type())));
if (sbp.has_broadcast_parallel()) {
op_conf.set_name(kAutoMirroredBlobNamePrefix + "-DistributeClone-" + NewUniqueId());
......@@ -885,7 +888,6 @@ Maybe<LogicalBlobId> LazyJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompati
{
const auto& producer_op_conf = JUST(Op4OpName(lbi.op_name()))->op_conf();
CHECK_OR_RETURN(producer_op_conf.has_scope_symbol_id());
int64_t scope_symbol_id = producer_op_conf.scope_symbol_id();
const auto& scope = Global<vm::SymbolStorage<Scope>>::Get()->Get(scope_symbol_id);
const auto* job_desc = JUST(scope.job_desc());
JUST(AddAndInferOp(op_conf, parallel_desc.parallel_conf(), job_desc, false));
......@@ -894,7 +896,7 @@ Maybe<LogicalBlobId> LazyJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompati
}
Maybe<LogicalBlobId> EagerJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompatibleConsistentBlob(
const LogicalBlobId& lbi) {
int64_t scope_symbol_id, const LogicalBlobId& lbi) {
const std::string& lbn = GenLogicalBlobName(lbi);
const auto& sbn_it = mut_consistent_lbi2mirrored_lbi()->find(lbi);
if (sbn_it != mut_consistent_lbi2mirrored_lbi()->end()) { return sbn_it->second; }
......@@ -909,6 +911,7 @@ Maybe<LogicalBlobId> EagerJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompat
CHECK_OR_RETURN(producer_op_conf.has_scope_symbol_id());
op_conf.set_scope_symbol_id(producer_op_conf.scope_symbol_id());
}
op_conf.set_scope_symbol_id(scope_symbol_id);
// const char* device_tag = JUST(DeviceTag4DeviceType(parallel_desc.device_type()));
op_conf.set_device_tag(JUST(DeviceTag4DeviceType(parallel_desc.device_type())));
op_conf.set_name(kAutoMirroredBlobNamePrefix + "-CastToMirrored-" + NewUniqueId());
......
......@@ -78,7 +78,7 @@ class JobBuildAndInferCtx {
int64_t parallel_id) const = 0;
virtual bool GetIsMirroredParallelView() const = 0;
virtual Maybe<LogicalBlobId> FindOrCreateMirroredLbiFromCompatibleConsistentBlob(
const LogicalBlobId& lbn) = 0;
int64_t scope_symbol_id, const LogicalBlobId& lbn) = 0;
Job* mut_job() const { return job_; }
int64_t job_id() const { return job_id_; }
......@@ -128,7 +128,8 @@ class JobBuildAndInferCtx {
Maybe<void> CheckAllInputsConvertableToMirroredBlob(const Operator& op) const;
Maybe<void> AddLossConsistentBlobName(const std::string& lbn);
Maybe<void> AddLossMirroredBlobName(const std::string& lbn);
Maybe<const LogicalBlobId*> GetSubLbi(const LogicalBlobId& lbi, int32_t index);
Maybe<const LogicalBlobId*> GetSubLbi(int64_t scope_symbol_id, const LogicalBlobId& lbi,
int32_t index);
Maybe<bool> AllInputsBroadcastParallel(const Operator& op) const;
void InferBlobBackwardSignature(Operator* op);
void InferBlobBackwardSignature(const Operator& op,
......@@ -167,7 +168,7 @@ class LazyJobBuildAndInferCtx : public JobBuildAndInferCtx {
ParallelConf GetMirroredOpParallelConf(const ParallelDesc&, int64_t parallel_id) const override;
bool GetIsMirroredParallelView() const override { return false; }
Maybe<LogicalBlobId> FindOrCreateMirroredLbiFromCompatibleConsistentBlob(
const LogicalBlobId& lbn) override;
int64_t scope_symbol_id, const LogicalBlobId& lbn) override;
};
class EagerJobBuildAndInferCtx : public JobBuildAndInferCtx {
......@@ -185,7 +186,7 @@ class EagerJobBuildAndInferCtx : public JobBuildAndInferCtx {
ParallelConf GetMirroredOpParallelConf(const ParallelDesc&, int64_t parallel_id) const override;
bool GetIsMirroredParallelView() const override { return true; }
Maybe<LogicalBlobId> FindOrCreateMirroredLbiFromCompatibleConsistentBlob(
const LogicalBlobId& lbn) override;
int64_t scope_symbol_id, const LogicalBlobId& lbn) override;
HashSet<std::string> executed_op_names_;
};
......
......@@ -14,4 +14,5 @@ message ScopeProto {
optional int64 parent_scope_symbol_id = 70;
required int64 session_id = 80;
map<string, AttrValue> attr_name2attr_value = 90;
optional string calculation_pass_name = 100 [default = "forward_pass"];
}
......@@ -176,6 +176,7 @@ void InsertCastOpImpl(bool f2h, const OpGraph& op_graph, const HashSet<OpNode*>&
.Input("in", lbn)
.Output("out")
.Attr<DataType>("dtype", cast_data_type)
.ScopeSymbolId(src_node->op().op_conf().scope_symbol_id())
.Build();
bool cast_is_consumed = false;
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/job_rewriter/calculation_pass.h"
namespace oneflow {
const std::string kForwardPass = "forward_pass";
const std::string kBackwardPass = "backward_pass";
const std::string kOptimizerPass = "optimizer_pass";
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_JOB_REWRITE_CALCULATION_PASS_H_
#define ONEFLOW_CORE_JOB_REWRITE_CALCULATION_PASS_H_
#include <string>
namespace oneflow {
extern const std::string kForwardPass;
extern const std::string kBackwardPass;
extern const std::string kOptimizerPass;
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_REWRITE_CALCULATION_PASS_H_
......@@ -16,6 +16,12 @@ limitations under the License.
#include "oneflow/core/job_rewriter/job_pass.h"
#include "oneflow/core/job_rewriter/autograd.h"
#include "oneflow/core/job_rewriter/optimizer.h"
#include "oneflow/core/job_rewriter/calculation_pass.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/job/foreign_callback.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/framework/interpreter.h"
#include "oneflow/core/framework/instructions_builder.h"
namespace oneflow {
......@@ -108,28 +114,87 @@ void FilterModelLbi2DiffLbi(const OpGraph& op_graph,
}
}
Maybe<JobBuilder> WithCalculationPassScope(const std::string& pass_name, Job* job,
const std::function<Maybe<void>()>& Handler) {
HashSet<std::string> exists_op_names;
for (const auto& op_conf : job->net().op()) {
CHECK_OR_RETURN(exists_op_names.emplace(op_conf.name()).second);
}
JUST(Handler());
// using a new JobBuilder to avoid bugs caused by MutOnlyOnce
auto new_job_builder = std::make_shared<JobBuilder>(job);
HashMap<int64_t, std::vector<const OperatorConf*>> scope_id2op_names;
const auto& scope_storage = *Global<vm::SymbolStorage<Scope>>::Get();
for (const auto& op_conf : job->net().op()) {
if (exists_op_names.count(op_conf.name()) > 0) { continue; }
CHECK_OR_RETURN(op_conf.has_scope_symbol_id());
OF_RETURN_IF_ERROR(scope_storage.MaybeGet(op_conf.scope_symbol_id())) << op_conf.DebugString();
scope_id2op_names[op_conf.scope_symbol_id()].push_back(&op_conf);
}
const auto& GetNewScopeSymbolId = [&](int64_t old_scope_symbol_id) -> Maybe<int64_t> {
const auto& old_scope = JUST(scope_storage.MaybeGet(old_scope_symbol_id));
cfg::ScopeProto new_scope;
new_scope.InitFromProto(old_scope.scope_proto());
new_scope.set_parent_scope_symbol_id(old_scope_symbol_id);
new_scope.set_calculation_pass_name(pass_name);
int64_t symbol_id = 0;
JUST(LogicalInterpreter().Run([&](InstructionsBuilder* builder) -> Maybe<void> {
symbol_id = JUST(builder->FindOrCreateSymbolId<cfg::ScopeProto>(new_scope));
return Maybe<void>::Ok();
}));
// Remove this ugly code after most python code is migrated into cpp code
{
ScopeProto scope_proto;
new_scope.ToProto(&scope_proto);
Global<ForeignCallback>::Get()->AddScopeToPyStorage(symbol_id, scope_proto.DebugString());
}
return symbol_id;
};
for (const auto& pair : scope_id2op_names) {
int64_t new_scope_symbol_id = JUST(GetNewScopeSymbolId(pair.first));
std::vector<OperatorConf> op_confs(pair.second.size());
for (int i = 0; i < pair.second.size(); ++i) {
op_confs.at(i).CopyFrom(*pair.second.at(i));
op_confs.at(i).set_scope_symbol_id(new_scope_symbol_id);
}
new_job_builder->MutOpsOnlyOnce(op_confs);
}
return new_job_builder;
}
Maybe<void> GenerateBackwardAndOptimizerOpConfs::Apply(Job* job, JobPassCtx* ctx) const {
if (!IsEnabled(*ctx)) { return Maybe<void>::Ok(); }
const OpGraph op_graph(*job);
JobBuilder job_builder(job);
auto job_builder = std::make_shared<JobBuilder>(job);
const JobBuilder* old_job_builder = job_builder.get();
LogicalBlobId total_loss_instance_num;
HashMap<LogicalBlobId, LogicalBlobId> lbi2diff_lbi;
JUST(AutoGrad(op_graph, &job_builder, &lbi2diff_lbi));
job_builder = JUST(WithCalculationPassScope(kBackwardPass, job, [&]() -> Maybe<void> {
CHECK(old_job_builder == job_builder.get()); // Check that this lambda is never called asynchronously
JUST(AutoGrad(op_graph, job_builder.get(), &lbi2diff_lbi));
return Maybe<void>::Ok();
}));
HashMap<LogicalBlobId, LogicalBlobId> model_lbi2model_diff_lbi;
FilterModelLbi2DiffLbi(op_graph, lbi2diff_lbi, &model_lbi2model_diff_lbi);
AddDiffStaticShapeCast(op_graph, &job_builder, &model_lbi2model_diff_lbi);
AddDiffParallelCast(op_graph, &job_builder, &model_lbi2model_diff_lbi);
JUST(ScaleModelDiffByLossInstanceNum(op_graph, &job_builder, &model_lbi2model_diff_lbi));
ScaleModelDiffByLossScale(op_graph, &job_builder, &model_lbi2model_diff_lbi);
const NormalModelUpdateOpUserConf& model_update_conf =
job->job_conf().train_conf().model_update_conf();
RegularizeGradient(op_graph, &job_builder, &model_lbi2model_diff_lbi);
if (model_update_conf.has_clip_conf()) {
ClipGradient(op_graph, &job_builder, &model_lbi2model_diff_lbi, model_update_conf.clip_conf());
}
AddOptimizerOpConf(ctx, op_graph, &job_builder, model_lbi2model_diff_lbi);
UpdateJobHelperConfProducedLbi2ConsumedDiffLbi(lbi2diff_lbi, &job_builder);
UpdateOpSbpSignatureHint(op_graph, &job_builder);
old_job_builder = job_builder.get();
job_builder = JUST(WithCalculationPassScope(kOptimizerPass, job, [&]() -> Maybe<void> {
CHECK(old_job_builder == job_builder.get()); // Check that this lambda is never called asynchronously
AddDiffStaticShapeCast(op_graph, job_builder.get(), &model_lbi2model_diff_lbi);
AddDiffParallelCast(op_graph, job_builder.get(), &model_lbi2model_diff_lbi);
JUST(ScaleModelDiffByLossInstanceNum(op_graph, job_builder.get(), &model_lbi2model_diff_lbi));
ScaleModelDiffByLossScale(op_graph, job_builder.get(), &model_lbi2model_diff_lbi);
const NormalModelUpdateOpUserConf& model_update_conf =
job->job_conf().train_conf().model_update_conf();
RegularizeGradient(op_graph, job_builder.get(), &model_lbi2model_diff_lbi);
if (model_update_conf.has_clip_conf()) {
ClipGradient(op_graph, job_builder.get(), &model_lbi2model_diff_lbi,
model_update_conf.clip_conf());
}
AddOptimizerOpConf(ctx, op_graph, job_builder.get(), model_lbi2model_diff_lbi);
return Maybe<void>::Ok();
}));
UpdateJobHelperConfProducedLbi2ConsumedDiffLbi(lbi2diff_lbi, job_builder.get());
UpdateOpSbpSignatureHint(op_graph, job_builder.get());
return Maybe<void>::Ok();
}
......
......@@ -61,6 +61,7 @@ void GenerateCastToMirroredBackwardOpConf(
if (DiffLbi4BnInOp("in") != nullptr) {
OperatorConf grad_op{};
grad_op.set_name("System-AutoGrad-" + op.op_name());
grad_op.set_scope_symbol_id(op.op_conf().scope_symbol_id());
CastFromMirroredOpConf* bw_op_conf = grad_op.mutable_cast_from_mirrored_conf();
bw_op_conf->set_in(GenLogicalBlobName(*DiffLbi4BnInOp("out")));
bw_op_conf->set_out("out");
......@@ -85,6 +86,7 @@ void GenerateCastFromMirroredBackwardOpConf(
if (DiffLbi4BnInOp("in") != nullptr) {
OperatorConf grad_op{};
grad_op.set_name("System-AutoGrad-" + op.op_name());
grad_op.set_scope_symbol_id(op.op_conf().scope_symbol_id());
CastToMirroredOpConf* bw_op_conf = grad_op.mutable_cast_to_mirrored_conf();
bw_op_conf->set_in(GenLogicalBlobName(*DiffLbi4BnInOp("out")));
bw_op_conf->set_out("out");
......
......@@ -64,7 +64,7 @@ Maybe<void> DistributeCloneOp::InferBlobDescs(
const ParallelContext* parallel_ctx) const {
const auto& in_blob_desc = *GetBlobDesc4BnInOp("in");
if (parallel_ctx->parallel_num() > 1) {
CHECK_EQ(parallel_ctx->parallel_num(), output_bns().size());
CHECK_EQ_OR_RETURN(parallel_ctx->parallel_num(), output_bns().size());
auto* out_blob_desc = GetBlobDesc4BnInOp(output_bns().Get(parallel_ctx->parallel_id()));
*out_blob_desc = in_blob_desc;
return Maybe<void>::Ok();
......@@ -83,7 +83,7 @@ Maybe<void> DistributeCloneOp::InferOutParallelDesc(
FOR_RANGE(int, i, 0, output_bns().size()) {
const auto& obn = output_bns().Get(i);
if (op_parallel_desc.parallel_num() > 1) {
CHECK_EQ(op_parallel_desc.parallel_num(), output_bns().size());
CHECK_EQ_OR_RETURN(op_parallel_desc.parallel_num(), output_bns().size());
*ParallelDesc4Obn(obn) = ParallelDesc(op_parallel_desc.GetParallelIdOnlyParallelConf(i));
} else {
*ParallelDesc4Obn(obn) = op_parallel_desc;
......@@ -100,7 +100,7 @@ Maybe<void> DistributeCloneOp::InferParallelSignature() {
auto* map = mut_parallel_signature()->mutable_bn_in_op2parallel_desc_symbol_id();
(*map)["in"] = op_parallel_desc_symbol_id;
const auto& op_parallel_desc = *JUST(scope.GetParallelDesc(op_conf()));
CHECK_EQ(op_parallel_desc.parallel_num(), output_bns().size());
CHECK_EQ_OR_RETURN(op_parallel_desc.parallel_num(), output_bns().size());
FOR_RANGE(int, i, 0, output_bns().size()) {
const auto& out_parallel_conf = op_parallel_desc.GetParallelIdOnlyParallelConf(i);
const std::shared_ptr<cfg::ParallelConf>& cfg_out_parallel_conf =
......
......@@ -19,11 +19,26 @@ import oneflow.python.eager.gradient_util as gradient_util
import oneflow.python.eager.op_executor as op_executor
import oneflow.core.operator.op_attribute_pb2 as op_attribute_pb
import oneflow.core.job.job_conf_pb2 as job_conf_pb
import oneflow.core.job.scope_pb2 as scope_pb
import oneflow.core.job.placement_pb2 as placement_pb
from google.protobuf import text_format
import oneflow.python.eager.blob_register as blob_register_util
import oneflow.python.framework.scope_util as scope_util
import oneflow.python.framework.scope_symbol as scope_symbol
import oneflow.python.eager.vm_util as vm_util
import oneflow.python.eager.symbol_storage as symbol_storage
def AddScopeToStorage(scope_symbol_id, scope_proto_str):
if symbol_storage.HasSymbol4SerializedScopeProto(scope_proto_str):
return
scope_proto = text_format.Parse(scope_proto_str, scope_pb.ScopeProto())
parent_scope_symbol = symbol_storage.GetSymbol4Id(
scope_proto.parent_scope_symbol_id
)
symbol = scope_symbol.ScopeSymbol(scope_symbol_id, scope_proto, parent_scope_symbol)
symbol_storage.SetSymbol4Id(scope_symbol_id, symbol)
symbol_storage.SetSymbol4SerializedScopeProto(scope_proto_str, symbol)
def MakeScopeSymbol(job_conf_str, parallel_conf_str, is_mirrored):
......
......@@ -79,6 +79,15 @@ class PythonCallback(oneflow_api.ForeignCallback):
print(traceback.format_exc())
raise e
def AddScopeToPyStorage(self, scope_symbol_id, scope_proto_str):
try:
return interpreter_callback.AddScopeToStorage(
scope_symbol_id, scope_proto_str
)
except Exception as e:
print(traceback.format_exc())
raise e
def MakeScopeSymbol(self, job_conf, parallel_conf, is_mirrored):
try:
# TODO(hanbinbin): str() will be removed after proto obj is replaced with cfg obj in python side
......