未验证 提交 6626d998 编写于 作者: L Li Xinqi 提交者: GitHub

Scope with symbol (#4040)

* parallel desc with symbol_id

* migrate ParallelDescSymbol

* fix code format

* fix bug in oneflow_testexe

* Make oneflow worker docker stay alive for 6 hours

* exception

* except in pybind11 and python

* finetune api

* print traceback

* fix bug

* fix format

* ParallelDesc::cfg_parallel_conf

* remove traceback in test_checkpoint

* fix python codeformat

* del job_build_and_infer_cfg_error.py

* optimize api struct

* refactor JobDesc

* migrate JobConfSymbol

* del useless lines

* del useless lines

* add CompileOptionWrongError

* add CompileOptionWrongError

* replace python ScopeSymbol with cpp Scope

* fix typo

* fix bug

* rename OF_COMPLIE_OPTION_EEEOR

* fix conflict

* fix format

* fix bug

* fix conflict
Co-authored-by: clackhan <han_binbin@163.com>
Co-authored-by: Shenghang Tsai <jackalcooper@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 9c1e4133
......@@ -60,12 +60,6 @@ class PyForeignCallback : public ForeignCallback {
is_mirrored);
}
// TODO(lixinqi): remove this urgly api after python code migrated into cpp code
void AddScopeToPyStorage(int64_t scope_symbol_id,
const std::shared_ptr<cfg::ScopeProto>& scope_proto) const override {
PYBIND11_OVERRIDE(void, ForeignCallback, AddScopeToPyStorage, scope_symbol_id, scope_proto);
}
int64_t MakeParallelDescSymbol(
const std::shared_ptr<cfg::ParallelConf>& parallel_conf) const override {
PYBIND11_OVERRIDE(int64_t, ForeignCallback, MakeParallelDescSymbol, parallel_conf);
......@@ -84,6 +78,5 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
.def("OfBlobCall", &ForeignCallback::OfBlobCall)
.def("RemoveForeignCallback", &ForeignCallback::RemoveForeignCallback)
.def("MakeScopeSymbol", &ForeignCallback::MakeScopeSymbol)
.def("AddScopeToPyStorage", &ForeignCallback::AddScopeToPyStorage)
.def("MakeParallelDescSymbol", &ForeignCallback::MakeParallelDescSymbol);
}
......@@ -24,7 +24,7 @@ limitations under the License.
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/vm/vm_util.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/vm/symbol_storage.h"
namespace oneflow {
......
......@@ -13,25 +13,37 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_EAGER_SYMBOL_STORAGE_H_
#define ONEFLOW_CORE_EAGER_EAGER_SYMBOL_STORAGE_H_
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/job/scope.cfg.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job/job_desc.h"
namespace py = pybind11;
namespace oneflow {
class Scope;
class ScopeProto;
// Creates a Scope for `symbol_id` from a cfg::ScopeProto.
// The cfg object is first copied into a protobuf ScopeProto via ToProto(),
// and the result is handed to Scope::New together with the id.
Maybe<Scope> CreateScopeSymbol(int64_t symbol_id,
                               const std::shared_ptr<cfg::ScopeProto>& symbol_conf) {
  ScopeProto scope_pb;
  symbol_conf->ToProto(&scope_pb);
  return Scope::New(symbol_id, scope_pb);
}
// Python binding for oneflow::Scope, exported under the class name "ScopeSymbol".
// Instances are held by std::shared_ptr<Scope>. Construction goes through
// CreateScopeSymbol, and the Maybe<> results of symbol_id() and
// MakeChildScopeProto() are unwrapped with GetOrThrow()/GetPtrOrThrow().
ONEFLOW_API_PYBIND11_MODULE("", m) {
  py::class_<Scope, std::shared_ptr<Scope>>(m, "ScopeSymbol")
      .def(py::init([](int64_t symbol_id, const std::shared_ptr<cfg::ScopeProto>& symbol_conf) {
        return CreateScopeSymbol(symbol_id, symbol_conf).GetPtrOrThrow();
      }))
      .def_property_readonly("symbol_id", [](const Scope& x) { return x.symbol_id().GetOrThrow(); })
      // Bound as a method rather than a property: Scope::auto_increment_id is non-const.
      .def("auto_increment_id", &Scope::auto_increment_id)
      // "session_id" is registered exactly once; a second identical registration
      // that used to follow this line was redundant.
      .def_property_readonly("session_id", &Scope::session_id)
      .def_property_readonly("job_desc_symbol", &Scope::job_desc_symbol)
      .def_property_readonly("device_parallel_desc_symbol", &Scope::device_parallel_desc_symbol)
      .def_property_readonly("parent_scope_symbol", &Scope::parent_scope_symbol)
      .def("MakeChildScopeProto",
           [](const Scope& scope) { return scope.MakeChildScopeProto().GetOrThrow(); });
}
namespace symbol {
template<>
struct ConstructArgType4Symbol<Scope> final {
using type = ScopeProto;
};
} // namespace symbol
} // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_EAGER_SYMBOL_STORAGE_H_
......@@ -19,6 +19,9 @@ limitations under the License.
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/job/scope.cfg.h"
#include "oneflow/core/job/scope.pb.h"
namespace py = pybind11;
......@@ -42,7 +45,7 @@ Maybe<SymbolT> GetSymbol(const SymbolConfT& symbol_conf) {
return ptr;
}
// TODO(hanbibin): the second template arg will be moved after symbol_storage is prefect
// TODO(hanbibin): the second template arg will be moved after symbol_storage is refactored
template<typename SymbolConfT, typename SymbolPbT, typename SymbolT>
Maybe<void> AddSymbol(int64_t symbol_id, const SymbolConfT& symbol_conf) {
SymbolPbT symbol_pb;
......@@ -73,7 +76,7 @@ Maybe<SymbolT> GetSymbol(int64_t symbol_id) {
}
template<typename SymbolConfT, typename SymbolT>
std::shared_ptr<SymbolT> ApiGetSymbol(int64_t symbol_id) {
std::shared_ptr<SymbolT> ApiGetSymbolById(int64_t symbol_id) {
return GetSymbol<SymbolConfT, SymbolT>(symbol_id).GetPtrOrThrow();
}
......@@ -83,19 +86,20 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
m.def("HasPlacementSymbol", &ApiHasSymbol<cfg::ParallelConf>);
m.def("AddPlacementSymbol", &ApiAddSymbol<cfg::ParallelConf, ParallelConf, ParallelDesc>);
m.def("GetPlacementSymbol",
static_cast<std::shared_ptr<ParallelDesc> (*)(const cfg::ParallelConf&)>(
&ApiGetSymbol<cfg::ParallelConf, ParallelDesc>));
m.def("GetPlacementSymbol", static_cast<std::shared_ptr<ParallelDesc> (*)(int64_t)>(
&ApiGetSymbol<cfg::ParallelConf, ParallelDesc>));
m.def("GetPlacementSymbol", &ApiGetSymbol<cfg::ParallelConf, ParallelDesc>);
m.def("GetPlacementSymbol", &ApiGetSymbolById<cfg::ParallelConf, ParallelDesc>);
m.def("HasJobConfSymbol", &ApiHasSymbol<cfg::JobConfigProto>);
m.def("AddJobConfSymbol", &ApiAddSymbol<cfg::JobConfigProto, JobConfigProto, JobDesc>);
m.def("GetJobConfSymbol", static_cast<std::shared_ptr<JobDesc> (*)(const cfg::JobConfigProto&)>(
&ApiGetSymbol<cfg::JobConfigProto, JobDesc>));
m.def("GetJobConfSymbol", static_cast<std::shared_ptr<JobDesc> (*)(int64_t)>(
&ApiGetSymbol<cfg::JobConfigProto, JobDesc>));
m.def("GetJobConfSymbol", &ApiGetSymbol<cfg::JobConfigProto, JobDesc>);
m.def("GetJobConfSymbol", &ApiGetSymbolById<cfg::JobConfigProto, JobDesc>);
m.def("HasScopeSymbol", &ApiHasSymbol<cfg::ScopeProto>);
m.def("AddScopeSymbol", &ApiAddSymbol<cfg::ScopeProto, ScopeProto, Scope>);
m.def("GetScopeSymbol", &ApiGetSymbol<cfg::ScopeProto, Scope>);
m.def("GetScopeSymbol", &ApiGetSymbolById<cfg::ScopeProto, Scope>);
}
} // namespace oneflow
......@@ -20,7 +20,7 @@ limitations under the License.
#include "oneflow/core/vm/vm_util.h"
#include "oneflow/core/vm/instruction.pb.h"
#include "oneflow/core/vm/instruction.cfg.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/eager/eager_symbol.cfg.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/scope.h"
......@@ -43,7 +43,7 @@ Maybe<void> StorageAdd(const EagerSymbol& symbol) {
if (symbol.has_string_symbol()) {
JUST(Global<symbol::Storage<std::string>>::Get()->Add(symbol_id, symbol.string_symbol()));
} else if (symbol.has_scope_symbol()) {
JUST(Global<symbol::Storage<Scope>>::Get()->Add(symbol_id, symbol.scope_symbol()));
JUST(Global<symbol::Storage<Scope>>::Get()->TryAdd(symbol_id, symbol.scope_symbol()));
} else if (symbol.has_job_conf_symbol()) {
JUST(Global<symbol::Storage<JobDesc>>::Get()->TryAdd(symbol_id, symbol.job_conf_symbol()));
} else if (symbol.has_parallel_conf_symbol()) {
......
......@@ -28,7 +28,7 @@ limitations under the License.
#include "oneflow/core/vm/string_object.h"
#include "oneflow/core/vm/test_util.h"
#include "oneflow/core/vm/object_wrapper.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/operator/op_conf.pb.h"
......
......@@ -28,7 +28,7 @@ limitations under the License.
#include "oneflow/core/vm/string_object.h"
#include "oneflow/core/vm/test_util.h"
#include "oneflow/core/vm/object_wrapper.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/resource_desc.h"
......
......@@ -25,7 +25,7 @@ limitations under the License.
#include "oneflow/core/vm/string_object.h"
#include "oneflow/core/vm/test_util.h"
#include "oneflow/core/vm/object_wrapper.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/env_desc.h"
#include "oneflow/core/job/global_for.h"
......
......@@ -42,12 +42,6 @@ class ForeignCallback {
virtual void RemoveForeignCallback(int64_t unique_id) const { UNIMPLEMENTED(); }
// TODO(lixinqi): remove this urgly api after python code migrated into cpp code
virtual void AddScopeToPyStorage(int64_t scope_symbol_id,
const std::shared_ptr<cfg::ScopeProto>& scope_proto) const {
UNIMPLEMENTED();
}
// return scope_symbol_id
virtual int64_t MakeScopeSymbol(const std::shared_ptr<cfg::JobConfigProto>& job_conf,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf,
......
......@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/framework/config_def.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/job/foreign_callback.h"
......
......@@ -15,24 +15,44 @@ limitations under the License.
*/
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/job/scope.cfg.h"
#include "oneflow/core/job/scope.pb.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/vm/symbol_storage.h"
namespace oneflow {
Scope::Scope(const ScopeProto& scope_proto) : scope_proto_(scope_proto) {
Scope::Scope(const ScopeProto& scope_proto)
: auto_increment_id_(0), symbol_id_(Error::SymbolIdUninitialized()), scope_proto_(scope_proto) {
CHECK_OK(Init()) << scope_proto_.DebugString();
}
// Private constructor used by Scope::New. It only records the symbol id and a
// copy of the proto and resets the auto-increment counter to 0; the body is
// empty, so Init() is NOT run here and must be called by the creator.
Scope::Scope(int64_t symbol_id, const ScopeProto& scope_proto)
: auto_increment_id_(0), symbol_id_(symbol_id), scope_proto_(scope_proto) {}
// Factory for a Scope bound to `symbol_id`.
// Allocates the object, runs Init() on it, and returns the shared pointer.
// If Init() fails, its error is propagated through JUST and no Scope is returned.
Maybe<Scope> Scope::New(int64_t symbol_id, const ScopeProto& scope_proto) {
  // The (symbol_id, proto) constructor is private, hence the explicit `new`
  // instead of std::make_shared.
  std::shared_ptr<Scope> new_scope(new Scope(symbol_id, scope_proto));
  JUST(new_scope->Init());
  return new_scope;
}
Maybe<void> Scope::Init() {
{
const auto& storage = *Global<symbol::Storage<JobDesc>>::Get();
job_desc_ = storage.GetPtr(scope_proto_.job_desc_symbol_id());
job_desc_ = JUST(storage.MaybeGetPtr(scope_proto_.job_desc_symbol_id()));
}
{
const auto& storage = *Global<symbol::Storage<ParallelDesc>>::Get();
device_parallel_desc_ = storage.GetPtr(scope_proto_.device_parallel_desc_symbol_id());
host_parallel_desc_ = storage.GetPtr(scope_proto_.host_parallel_desc_symbol_id());
device_parallel_desc_ =
JUST(storage.MaybeGetPtr(scope_proto_.device_parallel_desc_symbol_id()));
host_parallel_desc_ = JUST(storage.MaybeGetPtr(scope_proto_.host_parallel_desc_symbol_id()));
}
{
const auto& storage = *Global<symbol::Storage<Scope>>::Get();
if (scope_proto_.has_parent_scope_symbol_id()) {
parent_scope_symbol_ = JUST(storage.MaybeGetPtr(scope_proto_.parent_scope_symbol_id()));
}
}
return Maybe<void>::Ok();
}
......@@ -67,4 +87,10 @@ const AttrValue& Scope::GetAttrValue(const std::string& attr_name) const {
return def_iter->second.default_val();
}
// Returns a new cfg::ScopeProto initialised from this scope's proto, with
// parent_scope_symbol_id set to this scope's own symbol id.
// Propagates the error held by symbol_id() when this scope has no id.
Maybe<cfg::ScopeProto> Scope::MakeChildScopeProto() const {
  const auto child_proto = std::make_shared<cfg::ScopeProto>(scope_proto_);
  child_proto->set_parent_scope_symbol_id(JUST(symbol_id()));
  return child_proto;
}
} // namespace oneflow
......@@ -26,6 +26,10 @@ namespace oneflow {
class OperatorConf;
namespace cfg {
class ScopeProto;
}
class Scope final {
public:
Scope(const Scope&) = delete;
......@@ -33,6 +37,17 @@ class Scope final {
explicit Scope(const ScopeProto& scope_proto);
~Scope() = default;
static Maybe<Scope> New(int64_t symbol_id, const ScopeProto& scope_proto);
const Maybe<int64_t>& symbol_id() const { return symbol_id_; }
int64_t auto_increment_id() { return ++auto_increment_id_; }
int64_t session_id() const { return scope_proto().session_id(); }
const std::shared_ptr<JobDesc>& job_desc_symbol() const { return job_desc_; }
const std::shared_ptr<ParallelDesc>& device_parallel_desc_symbol() const {
return device_parallel_desc_;
}
const std::shared_ptr<Scope>& parent_scope_symbol() const { return parent_scope_symbol_; }
Maybe<cfg::ScopeProto> MakeChildScopeProto() const;
Maybe<const JobDesc*> job_desc() const;
Maybe<int64_t> GetParallelDescSymbolId(const OperatorConf& op_conf) const;
Maybe<const ParallelDesc&> GetParallelDesc(const OperatorConf& op_conf) const;
......@@ -54,14 +69,18 @@ class Scope final {
DEFINE_SCOPE_CONFIG_GETTER(const std::string&, String, at_string);
private:
Scope(int64_t symbol_id, const ScopeProto& scope_proto);
Maybe<void> Init();
const AttrValue& GetAttrValue(const std::string& attr_name) const;
int64_t auto_increment_id_;
Maybe<int64_t> symbol_id_;
const ScopeProto scope_proto_;
std::shared_ptr<JobDesc> job_desc_;
std::shared_ptr<ParallelDesc> device_parallel_desc_;
std::shared_ptr<ParallelDesc> host_parallel_desc_;
std::shared_ptr<Scope> parent_scope_symbol_;
};
} // namespace oneflow
......
......@@ -18,10 +18,13 @@ limitations under the License.
#include "oneflow/core/job_rewriter/optimizer.h"
#include "oneflow/core/job_rewriter/calculation_pass.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/job/scope.cfg.h"
#include "oneflow/core/job/scope.pb.h"
#include "oneflow/core/job/foreign_callback.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/framework/interpreter.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/symbol_id_cache.h"
namespace oneflow {
......@@ -114,6 +117,18 @@ void FilterModelLbi2DiffLbi(const OpGraph& op_graph,
}
}
// TODO(lixinqi): Refactor this function after symbol::IdCache and symbol::Storage merged
//
// Registers `symbol_conf` under `symbol_id` in both the global
// symbol::IdCache<SymbolConfT> and the global symbol::Storage<SymbolT>.
// Does nothing and returns Ok when the id cache already knows `symbol_conf`.
//
// SymbolConfT: cfg type providing ToProto(SymbolPbT*).
// SymbolPbT:   protobuf type accepted by symbol::Storage<SymbolT>::TryAdd.
// SymbolT:     symbol type kept in the storage.
template<typename SymbolConfT, typename SymbolPbT, typename SymbolT>
Maybe<void> TryAddSymbol(int64_t symbol_id, const SymbolConfT& symbol_conf) {
  auto* id_cache = Global<symbol::IdCache<SymbolConfT>>::Get();
  if (id_cache->Has(symbol_conf)) { return Maybe<void>::Ok(); }
  // Convert to protobuf only after the early return above, so an already
  // cached symbol does not pay for a conversion whose result is discarded.
  SymbolPbT symbol_pb;
  symbol_conf.ToProto(&symbol_pb);
  JUST(id_cache->FindOrCreate(symbol_conf, [&symbol_id]() -> Maybe<int64_t> { return symbol_id; }));
  JUST(Global<symbol::Storage<SymbolT>>::Get()->TryAdd(symbol_id, symbol_pb));
  return Maybe<void>::Ok();
}
Maybe<JobBuilder> WithCalculationPassScope(const std::string& pass_name, Job* job,
const std::function<Maybe<void>()>& Handler) {
HashSet<std::string> exists_op_names;
......@@ -142,7 +157,7 @@ Maybe<JobBuilder> WithCalculationPassScope(const std::string& pass_name, Job* jo
symbol_id = JUST(builder->FindOrCreateSymbolId<cfg::ScopeProto>(*new_scope));
return Maybe<void>::Ok();
}));
Global<ForeignCallback>::Get()->AddScopeToPyStorage(symbol_id, new_scope);
JUST(TryAddSymbol<cfg::ScopeProto, ScopeProto, Scope>(symbol_id, *new_scope));
return symbol_id;
};
for (const auto& pair : scope_id2op_names) {
......
......@@ -17,7 +17,7 @@ limitations under the License.
#include "oneflow/core/graph/logical_node.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/job/foreign_callback.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job/scope.h"
namespace oneflow {
......
......@@ -17,7 +17,7 @@ limitations under the License.
#include "oneflow/core/graph/logical_node.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/job/foreign_callback.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job/scope.h"
namespace oneflow {
......
......@@ -17,7 +17,7 @@ limitations under the License.
#include "oneflow/core/graph/logical_node.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/job/foreign_callback.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job/scope.h"
namespace oneflow {
......
......@@ -17,7 +17,7 @@ limitations under the License.
#include "oneflow/core/graph/logical_node.h"
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/job/foreign_callback.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job/scope.h"
namespace oneflow {
......
......@@ -16,7 +16,7 @@ limitations under the License.
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/graph/logical_node.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job/scope.h"
namespace oneflow {
......
......@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/framework/to_string.h"
#include "oneflow/core/framework/user_op_registry_manager.h"
#include "oneflow/core/graph/logical_node.h"
......
......@@ -16,6 +16,7 @@ limitations under the License.
#include "oneflow/core/vm/symbol_storage.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/scope.h"
namespace oneflow {
......@@ -35,6 +36,12 @@ Maybe<JobDesc> NewSymbol<JobDesc>(int64_t symbol_id,
return JobDesc::New(symbol_id, data);
}
template<>
Maybe<Scope> NewSymbol<Scope>(int64_t symbol_id,
const typename ConstructArgType4Symbol<Scope>::type& data) {
return Scope::New(symbol_id, data);
}
} // namespace detail
} // namespace symbol
......
......@@ -31,6 +31,9 @@ class JobConfigProto;
class OpNodeSignatureDesc;
class OpNodeSignature;
class Scope;
class ScopeProto;
namespace symbol {
template<typename T>
......@@ -53,6 +56,11 @@ struct ConstructArgType4Symbol<JobDesc> final {
using type = JobConfigProto;
};
template<>
struct ConstructArgType4Symbol<Scope> final {
using type = ScopeProto;
};
namespace detail {
template<typename T>
......@@ -68,6 +76,10 @@ template<>
Maybe<JobDesc> NewSymbol<JobDesc>(int64_t symbol_id,
const typename ConstructArgType4Symbol<JobDesc>::type& data);
template<>
Maybe<Scope> NewSymbol<Scope>(int64_t symbol_id,
const typename ConstructArgType4Symbol<Scope>::type& data);
} // namespace detail
template<typename T>
......@@ -112,10 +124,10 @@ class Storage final {
Maybe<void> TryAdd(int64_t symbol_id, const typename ConstructArgType4Symbol<T>::type& data) {
CHECK_GT_OR_RETURN(symbol_id, 0);
const auto& iter = symbol_id2symbol_.find(symbol_id);
const auto& ptr = JUST(detail::NewSymbol<T>(symbol_id, data));
std::unique_lock<std::mutex> lock(mutex_);
const auto& iter = symbol_id2symbol_.find(symbol_id);
if (iter != symbol_id2symbol_.end()) { return Maybe<void>::Ok(); }
const auto& ptr = JUST(detail::NewSymbol<T>(symbol_id, data));
CHECK_OR_RETURN(symbol_id2symbol_.emplace(symbol_id, ptr).second);
return Maybe<void>::Ok();
}
......
......@@ -24,23 +24,10 @@ import oneflow.core.job.placement_pb2 as placement_pb
from google.protobuf import text_format
import oneflow.python.eager.blob_register as blob_register_util
import oneflow.python.framework.scope_util as scope_util
import oneflow.python.framework.scope_symbol as scope_symbol
import oneflow.python.eager.vm_util as vm_util
import oneflow.python.eager.symbol_storage as symbol_storage
def AddScopeToStorage(scope_symbol_id, scope_proto):
scope_proto_str = str(scope_proto)
if symbol_storage.HasSymbol4SerializedScopeProto(scope_proto_str):
return
parent_scope_symbol = symbol_storage.GetSymbol4Id(
scope_proto.parent_scope_symbol_id()
)
symbol = scope_symbol.ScopeSymbol(scope_symbol_id, scope_proto, parent_scope_symbol)
symbol_storage.SetSymbol4Id(scope_symbol_id, symbol)
symbol_storage.SetSymbol4SerializedScopeProto(scope_proto_str, symbol)
def MakeScopeSymbol(job_conf, parallel_conf, is_mirrored):
return scope_util.MakeInitialScope(
job_conf,
......
......@@ -237,7 +237,7 @@ def _Assign(var_blob_object, value_blob_object):
def _BuildNotMirroredScope(old_scope, builder):
return old_scope.BuildWithNewIsMirrored(builder, False)
return builder.BuildScopeWithNewIsMirrored(old_scope, False)
def _EagerRunModelInit(var_op_conf):
......
......@@ -47,7 +47,7 @@ class OpKernelObject(object_util.Object):
def _GetScopeSymbol(op_conf):
assert op_conf.HasField("scope_symbol_id")
return symbol_storage.GetSymbol4Id(op_conf.scope_symbol_id)
return oneflow_api.GetScopeSymbol(op_conf.scope_symbol_id)
def _GetOpParallelSymbol(op_conf):
......
......@@ -90,60 +90,3 @@ def SetSymbol4SerializedOpNodeSignature(serialized_op_node_signature, symbol):
serialized_op_node_signature2symbol = {}
def HasSymbol4JobConf(job_conf):
global job_conf_id2symbol
return id(job_conf) in job_conf_id2symbol
def GetSymbol4JobConf(job_conf):
global job_conf_id2symbol
return job_conf_id2symbol[id(job_conf)]
def SetSymbol4JobConf(job_conf, symbol):
assert not HasSymbol4JobConf(job_conf)
global job_conf_id2symbol
job_conf_id2symbol[id(job_conf)] = symbol
job_conf_id2symbol = {}
def HasSymbol4SerializedParallelConf(serialized_parallel_conf):
global serialized_parallel_conf2symbol
return serialized_parallel_conf in serialized_parallel_conf2symbol
def GetSymbol4SerializedParallelConf(serialized_parallel_conf):
global serialized_parallel_conf2symbol
return serialized_parallel_conf2symbol[serialized_parallel_conf]
def SetSymbol4SerializedParallelConf(serialized_parallel_conf, symbol):
assert not HasSymbol4SerializedParallelConf(serialized_parallel_conf)
global serialized_parallel_conf2symbol
serialized_parallel_conf2symbol[serialized_parallel_conf] = symbol
serialized_parallel_conf2symbol = {}
def HasSymbol4SerializedScopeProto(serialized_scope_proto):
global serialized_scope_proto2symbol
return serialized_scope_proto in serialized_scope_proto2symbol
def GetSymbol4SerializedScopeProto(serialized_scope_proto):
global serialized_scope_proto2symbol
return serialized_scope_proto2symbol[serialized_scope_proto]
def SetSymbol4SerializedScopeProto(serialized_scope_proto, symbol):
assert not HasSymbol4SerializedScopeProto(serialized_scope_proto)
global serialized_scope_proto2symbol
serialized_scope_proto2symbol[serialized_scope_proto] = symbol
serialized_scope_proto2symbol = {}
......@@ -31,9 +31,10 @@ import oneflow.python.eager.object as object_util
import oneflow.python.eager.object_storage as object_storage
import oneflow.python.eager.symbol as symbol_util
import oneflow.python.eager.symbol_storage as symbol_storage
import oneflow.python.framework.parallel_conf_util as parallel_conf_util
import oneflow_api.oneflow.core.job.scope as scope_cfg
import oneflow.python.framework.balanced_splitter as balanced_splitter
import oneflow.python.framework.c_api_util as c_api_util
import oneflow.python.framework.scope_symbol as scope_symbol
import oneflow.python.framework.id_util as id_util
import oneflow.python.framework.op_arg_util as op_arg_util
import oneflow.python.framework.placement_context as placement_ctx
......@@ -444,15 +445,83 @@ class InstructionsBuilder(object):
return oneflow_api.GetPlacementSymbol(parallel_conf)
def BuildInitialScope(
    self, session_id, job_conf, device_tag, machine_device_ids, is_mirrored,
):
    """Build the root scope symbol (no parent scope).

    Fills a fresh ScopeProto with the session id, the job-conf symbol id,
    the device placement symbol id for (device_tag, machine_device_ids) and
    the host placement symbol id for ("cpu", machine_device_ids), sets or
    clears the mirrored-parallel option according to is_mirrored, and returns
    self.GetScopeSymbol(proto, None).
    """
    proto = scope_cfg.ScopeProto()
    proto.set_session_id(session_id)
    proto.set_job_desc_symbol_id(self.GetJobConfSymbol(job_conf).symbol_id)

    device_conf = parallel_conf_util.MakeParallelConf(device_tag, machine_device_ids)
    proto.set_device_parallel_desc_symbol_id(
        self.GetParallelDescSymbol(device_conf).symbol_id
    )

    host_conf = parallel_conf_util.MakeParallelConf("cpu", machine_device_ids)
    proto.set_host_parallel_desc_symbol_id(
        self.GetParallelDescSymbol(host_conf).symbol_id
    )

    mirrored_opt = proto.mutable_opt_mirrored_parallel_conf()
    if is_mirrored:
        mirrored_opt.mutable_mirrored_parallel()
    else:
        mirrored_opt.clear_mirrored_parallel()
    return self.GetScopeSymbol(proto, None)
def BuildScopeWithNewParallelDesc(self, scope, device_tag, machine_device_ids):
    """Derive a child of `scope` with new device and host placements.

    A single string for machine_device_ids is wrapped into a one-element
    list. The child proto receives the placement symbol id for
    (device_tag, ids) as device placement and the one for ("cpu", ids) as
    host placement. The result comes from self.BuildScopeByProtoSetter.
    """
    if isinstance(machine_device_ids, str):
        device_ids = [machine_device_ids]
    else:
        device_ids = machine_device_ids

    def SetScopeProto(scope_proto):
        # Both symbols are looked up before either id is written to the proto.
        device_sym = self.GetParallelDescSymbol(
            parallel_conf_util.MakeParallelConf(device_tag, device_ids)
        )
        host_sym = self.GetParallelDescSymbol(
            parallel_conf_util.MakeParallelConf("cpu", device_ids)
        )
        scope_proto.set_device_parallel_desc_symbol_id(device_sym.symbol_id)
        scope_proto.set_host_parallel_desc_symbol_id(host_sym.symbol_id)

    return self.BuildScopeByProtoSetter(scope, SetScopeProto)
def BuildScopeWithNewParallelConf(self, scope, parallel_conf):
    """Derive a child of `scope` placed according to `parallel_conf`.

    The conf is split into (device_tag, machine_device_ids) by
    parallel_conf_util.GetDeviceTagAndMachineDeviceIds and forwarded to
    self.BuildScopeWithNewParallelDesc.
    """
    device_tag, machine_device_ids = parallel_conf_util.GetDeviceTagAndMachineDeviceIds(
        parallel_conf
    )
    return self.BuildScopeWithNewParallelDesc(scope, device_tag, machine_device_ids)
def BuildScopeWithNewIsMirrored(self, scope, is_mirrored):
    """Derive a child of `scope` with the mirrored-parallel option set or cleared.

    When is_mirrored is truthy the child's opt_mirrored_parallel_conf gets
    mirrored_parallel made present; otherwise it is cleared. The result comes
    from self.BuildScopeByProtoSetter.
    """

    def SetScopeProto(scope_proto):
        mirrored_opt = scope_proto.mutable_opt_mirrored_parallel_conf()
        if is_mirrored:
            mirrored_opt.mutable_mirrored_parallel()
        else:
            mirrored_opt.clear_mirrored_parallel()

    return self.BuildScopeByProtoSetter(scope, SetScopeProto)
def BuildScopeWithNewScopeName(self, scope, scope_name):
    """Derive a child of `scope` with `scope_name` appended to its op-name prefixes.

    The result comes from self.BuildScopeByProtoSetter.
    """

    def AppendScopeName(child_proto):
        child_proto.add_scope_op_name_prefixes(scope_name)

    return self.BuildScopeByProtoSetter(scope, AppendScopeName)
def BuildScopeByProtoSetter(self, scope, setter):
    """Derive a child scope symbol of `scope`, customised by `setter`.

    `setter` is called with the proto returned by scope.MakeChildScopeProto()
    and is expected to modify it in place; its return value is ignored. The
    modified proto and `scope` (as parent) are passed to self.GetScopeSymbol.
    """
    child_proto = scope.MakeChildScopeProto()
    setter(child_proto)
    return self.GetScopeSymbol(child_proto, scope)
def GetScopeSymbol(self, scope_proto, parent_scope_symbol=None):
if oneflow_api.HasScopeSymbol(scope_proto):
return oneflow_api.GetScopeSymbol(scope_proto)
symbol_id = self._NewSymbolId4Scope(scope_proto)
serialized_scope_proto = str(scope_proto)
if symbol_storage.HasSymbol4SerializedScopeProto(serialized_scope_proto):
return symbol_storage.GetSymbol4SerializedScopeProto(serialized_scope_proto)
symbol = scope_symbol.ScopeSymbol(symbol_id, scope_proto, parent_scope_symbol)
symbol_storage.SetSymbol4Id(symbol_id, symbol)
symbol_storage.SetSymbol4SerializedScopeProto(serialized_scope_proto, symbol)
return symbol
oneflow_api.AddScopeSymbol(symbol_id, scope_proto)
return oneflow_api.GetScopeSymbol(scope_proto)
def GetSharedOpKernelObject4ParallelConfSymbol(self, parallel_desc_sym):
if object_storage.HasSharedOpKernelObject4ParallelConfSymbol(parallel_desc_sym):
......@@ -497,7 +566,7 @@ class InstructionsBuilder(object):
def NewOpKernelObject(self, op_conf):
assert op_conf.HasField("scope_symbol_id")
scope_symbol = symbol_storage.GetSymbol4Id(op_conf.scope_symbol_id)
scope_symbol = oneflow_api.GetScopeSymbol(op_conf.scope_symbol_id)
op_conf_sym = self._GetOpConfSymbol(op_conf)
parallel_desc_sym_id = c_api_util.GetOpParallelSymbolId(op_conf)
parallel_desc_symbol = oneflow_api.GetPlacementSymbol(parallel_desc_sym_id)
......@@ -611,7 +680,7 @@ class InstructionsBuilder(object):
op_conf = op_attribute.op_conf
assert op_conf.HasField("scope_symbol_id"), op_conf
scope_symbol = symbol_storage.GetSymbol4Id(op_conf.scope_symbol_id)
scope_symbol = oneflow_api.GetScopeSymbol(op_conf.scope_symbol_id)
job_desc_sym = scope_symbol.job_desc_symbol
op_conf_sym = self._GetOpConfSymbol(op_conf)
op_node_signature_sym = self._GetOpNodeSignatureSymbol(op_attribute)
......
......@@ -59,7 +59,7 @@ def name_scope(name: str) -> None:
name_scope_stack_push(name)
def BuildScope(old_scope, builder):
return old_scope.BuildWithNewScopeName(builder, name)
return builder.BuildScopeWithNewScopeName(old_scope, name)
sess = session_context.GetDefaultSession()
try:
......
......@@ -30,7 +30,6 @@ import oneflow.python.framework.remote_blob as remote_blob_util
import oneflow.python.framework.runtime_mode as runtime_mode
import oneflow.python.framework.push_util as push_util
import oneflow.python.framework.session_context as session_ctx
import oneflow.python.framework.scope_symbol as scope_symbol
import oneflow.python.framework.scope_util as scope_util
import oneflow.python.framework.typing as oft
import oneflow.python.framework.typing_util as oft_util
......
......@@ -28,7 +28,7 @@ class DistributeStrategy(object):
if sess.is_running and len(sess.is_mirrored_strategy_enabled_stack) > 0:
def BuildScope(old_scope, builder):
return old_scope.BuildWithNewIsMirrored(builder, is_mirrored)
return builder.BuildScopeWithNewIsMirrored(old_scope, is_mirrored)
self.scope_context_ = scope_util.ScopeContext(
scope_util.MakeScope(BuildScope)
......
......@@ -13,6 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import absolute_import
import re
import oneflow_api.oneflow.core.job.placement as placement_cfg
def GetDeviceTagAndMachineDeviceIds(parallel_conf):
......@@ -21,3 +24,20 @@ def GetDeviceTagAndMachineDeviceIds(parallel_conf):
machine_device_ids.append(device_name)
device_tag = parallel_conf.device_tag()
return device_tag, machine_device_ids
def MakeParallelConf(device_tag, machine_device_ids):
    """Build a placement_cfg.ParallelConf for the given device tag and devices.

    Args:
        device_tag: value passed to ParallelConf.set_device_tag.
        machine_device_ids: list or tuple of strings, each of the form
            "<digits>:<digits>" or "<digits>:<digits>-<digits>"
            (for example "0:0" or "0:0-3").

    Returns:
        The populated ParallelConf, with one device name added per entry.

    Raises:
        AssertionError: if machine_device_ids is not a list/tuple, or an
            entry is not a str or does not match the expected form.
    """
    assert isinstance(machine_device_ids, (list, tuple))
    parallel_conf = placement_cfg.ParallelConf()
    parallel_conf.set_device_tag(device_tag)
    for machine_device_id in machine_device_ids:
        assert isinstance(
            machine_device_id, str
        ), "type of machine_device_id (%s) is not string" % type(machine_device_id)
        # Raw string: "\d" in a normal string literal is an invalid escape
        # sequence and triggers a warning on current Python versions. The
        # pattern itself is unchanged.
        assert re.match(r"^\d+:\d+(-\d+)?$", machine_device_id) is not None, (
            "machine_device_id: %s is not valid" % machine_device_id
        )
        parallel_conf.add_device_name(machine_device_id)
    return parallel_conf
......@@ -22,7 +22,7 @@ import oneflow.core.job.placement_pb2 as placement_pb
import oneflow.python.framework.c_api_util as c_api_util
import oneflow.python.framework.op_util as op_util
import oneflow.python.framework.session_context as session_ctx
import oneflow.python.framework.scope_symbol as scope_symbol
import oneflow.python.framework.parallel_conf_util as parallel_conf_util
import oneflow
import oneflow_api.oneflow.core.job.placement as placement_cfg
......@@ -75,25 +75,7 @@ def MakeParallelConf4Resource(device_tag, resource):
machine_device_ids = GetCpuMachineDeviceIds(resource)
else:
raise NotImplementedError
return MakeParallelConf(device_tag, machine_device_ids)
def MakeParallelConf(device_tag, machine_device_ids):
assert isinstance(machine_device_ids, collections.Sized)
parallel_conf = placement_cfg.ParallelConf()
parallel_conf.set_device_tag(device_tag)
for machine_device_id in machine_device_ids:
assert isinstance(
machine_device_id, str
), "type of machine_device_id (%s) is not string" % type(machine_device_id)
assert re.match("^\d+:\d+(-\d+)?$", machine_device_id) is not None, (
"machine_device_id: %s is not valid" % machine_device_id
)
pair = machine_device_id.split(":")
parallel_conf.add_device_name("%s:%s" % (pair[0], pair[1]))
return parallel_conf
return parallel_conf_util.MakeParallelConf(device_tag, machine_device_ids)
def MakeMachineId2DeviceIdList(parallel_conf):
......
......@@ -99,8 +99,8 @@ def GetEmptyPlacementScope(device_tag, machine_device_ids):
def GetNormalModePlacementScope(device_tag, machine_device_ids):
sess = session_ctx.GetDefaultSession()
scope = scope_util.MakeScope(
lambda old_scope, builder: old_scope.BuildWithNewParallelDesc(
builder, device_tag, machine_device_ids
lambda old_scope, builder: builder.BuildScopeWithNewParallelDesc(
old_scope, device_tag, machine_device_ids
)
)
return scope_util.ScopeContext(scope)
......@@ -113,8 +113,8 @@ def GetGlobalModePlacementScope(device_tag, machine_device_ids):
sess = session_ctx.GetDefaultSession()
def BuildScope(old_scope, builder):
return old_scope.BuildWithNewParallelDesc(
builder, device_tag, machine_device_ids
return builder.BuildScopeWithNewParallelDesc(
old_scope, device_tag, machine_device_ids
)
scope_ctx = scope_util.ScopeContext(scope_util.MakeScope(BuildScope))
......
......@@ -73,13 +73,6 @@ class PythonCallback(oneflow_api.ForeignCallback):
print(traceback.format_exc())
raise e
def AddScopeToPyStorage(self, scope_symbol_id, scope_proto):
try:
return interpreter_callback.AddScopeToStorage(scope_symbol_id, scope_proto)
except Exception as e:
print(traceback.format_exc())
raise e
def MakeScopeSymbol(self, job_conf, parallel_conf, is_mirrored):
try:
# TODO(hanbinbin): str() will be removed after proto obj is replaced with cfg obj in python side
......
......@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import traceback
import oneflow.python.framework.scope_symbol as scope_symbol
import oneflow.python.framework.session_context as session_ctx
import oneflow.python.framework.attr_util as attr_util
import oneflow.python.eager.vm_util as vm_util
......@@ -38,7 +37,9 @@ def api_scope_config(**kwargs):
sess = session_ctx.GetDefaultSession()
scope = MakeScope(
lambda old_scope, builder: old_scope.BuildBySetter(builder, SetScopeProto)
lambda old_scope, builder: builder.BuildScopeByProtoSetter(
old_scope, SetScopeProto
)
)
return ScopeContext(scope)
......@@ -85,8 +86,8 @@ def MakeInitialScope(job_conf, device_tag, machine_device_ids, is_mirrored):
def BuildInitialScope(builder):
nonlocal scope
session_id = session_ctx.GetDefaultSession().id
scope = scope_symbol.BuildInitialScope(
builder, session_id, job_conf, device_tag, machine_device_ids, is_mirrored
scope = builder.BuildInitialScope(
session_id, job_conf, device_tag, machine_device_ids, is_mirrored
)
vm_util.LogicalRun(BuildInitialScope)
......
......@@ -106,7 +106,7 @@ def _GetReturnOpConfAndOutLbiAndScope(remote_blob, allow_cpu_return_op=True):
parallel_conf.CopyFrom(remote_blob.parallel_conf)
def BuildScope(old_scope, builder):
return old_scope.BuildWithNewParallelConf(builder, parallel_conf)
return builder.BuildScopeWithNewParallelConf(old_scope, parallel_conf)
sess = session_ctx.GetDefaultSession()
scope = scope_util.MakeScope(BuildScope)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册