未验证 提交 53282de0 编写于 作者: qq_22305325 提交者: GitHub

Dev replace str to cfg obj in python callback (#3832)

* Replace the py instruction with CFG Instruction

* move RunInstruction to pybind & refactor EagerOneflow's interface by cfg

* use forward declaration

* fix code style

* move RunInstruction api to oneflow_api.vm

* remove useless line in oneflow_internal.i

* replace args str to cfg_obj in python callback

* add forward declear of InstructionListProto

* cancel forward declear of InstructionListProto

* fix a name spelling mistake

* add TODO and use const reference instead of variable

* fix code style

* use const reference
Co-authored-by: ouyangyu <xuanjiuye@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 282d3579
......@@ -22,48 +22,55 @@ limitations under the License.
namespace py = pybind11;
ONEFLOW_API_PYBIND11_MODULE("", m) {
using namespace oneflow;
namespace oneflow {
class PyForeignCallback : public ForeignCallback {
public:
// Inherit the constructors
using ForeignCallback::ForeignCallback;
class PyForeignCallback : public ForeignCallback {
public:
// Inherit the constructors
using ForeignCallback::ForeignCallback;
// Trampoline (need one for each virtual function)
void EagerMirroredCast(const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf) const override {
PYBIND11_OVERRIDE(void, /* Return type */
ForeignCallback, /* Parent class */
EagerMirroredCast, /* Name of function in C++ (must match Python name) */
op_attribute, parallel_conf /* Argument(s) */
);
}
// Trampoline (need one for each virtual function)
void EagerMirroredCast(const std::string& op_attribute_str,
const std::string& parallel_conf_str) const override {
PYBIND11_OVERRIDE(void, /* Return type */
ForeignCallback, /* Parent class */
EagerMirroredCast, /* Name of function in C++ (must match Python name) */
op_attribute_str, parallel_conf_str /* Argument(s) */
);
}
void EagerInterpretCompletedOp(
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf) const override {
PYBIND11_OVERRIDE(void, ForeignCallback, EagerInterpretCompletedOp, op_attribute,
parallel_conf);
}
void EagerInterpretCompletedOp(const std::string& op_attribute_str,
const std::string& parallel_conf_str) const override {
PYBIND11_OVERRIDE(void, ForeignCallback, EagerInterpretCompletedOp, op_attribute_str,
parallel_conf_str);
}
void OfBlobCall(int64_t unique_id, int64_t ofblob_ptr) const override {
PYBIND11_OVERRIDE(void, ForeignCallback, OfBlobCall, unique_id, ofblob_ptr);
}
void OfBlobCall(int64_t unique_id, int64_t ofblob_ptr) const override {
PYBIND11_OVERRIDE(void, ForeignCallback, OfBlobCall, unique_id, ofblob_ptr);
}
void RemoveForeignCallback(int64_t unique_id) const override {
PYBIND11_OVERRIDE(void, ForeignCallback, RemoveForeignCallback, unique_id);
}
void RemoveForeignCallback(int64_t unique_id) const override {
PYBIND11_OVERRIDE(void, ForeignCallback, RemoveForeignCallback, unique_id);
}
int64_t MakeScopeSymbol(const std::shared_ptr<cfg::JobConfigProto>& job_conf,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf,
bool is_mirrored) const override {
PYBIND11_OVERRIDE(int64_t, ForeignCallback, MakeScopeSymbol, job_conf, parallel_conf,
is_mirrored);
}
int64_t MakeScopeSymbol(const std::string& job_conf, const std::string& parallel_conf,
bool is_mirrored) const override {
PYBIND11_OVERRIDE(int64_t, ForeignCallback, MakeScopeSymbol, job_conf, parallel_conf,
is_mirrored);
}
int64_t MakeParallelDescSymbol(
const std::shared_ptr<cfg::ParallelConf>& parallel_conf) const override {
PYBIND11_OVERRIDE(int64_t, ForeignCallback, MakeParallelDescSymbol, parallel_conf);
}
};
int64_t MakeParallelDescSymbol(const std::string& parallel_conf) const override {
PYBIND11_OVERRIDE(int64_t, ForeignCallback, MakeParallelDescSymbol, parallel_conf);
}
};
} // namespace oneflow
ONEFLOW_API_PYBIND11_MODULE("", m) {
using namespace oneflow;
py::class_<ForeignCallback, PyForeignCallback>(m, "ForeignCallback")
.def(py::init<>())
......
......@@ -16,6 +16,10 @@ limitations under the License.
#ifndef ONEFLOW_CORE_JOB_FOREIGN_CALLBACK_H_
#define ONEFLOW_CORE_JOB_FOREIGN_CALLBACK_H_
#include "oneflow/core/job/placement.cfg.h"
#include "oneflow/core/operator/op_attribute.cfg.h"
#include "oneflow/core/job/job_conf.cfg.h"
namespace oneflow {
class ForeignCallback {
......@@ -23,12 +27,13 @@ class ForeignCallback {
ForeignCallback() = default;
virtual ~ForeignCallback() = default;
virtual void EagerMirroredCast(const std::string& op_attribute_str,
const std::string& parallel_conf_str) const {
virtual void EagerMirroredCast(const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf) const {
UNIMPLEMENTED();
}
virtual void EagerInterpretCompletedOp(const std::string& op_attribute_str,
const std::string& parallel_conf_str) const {
virtual void EagerInterpretCompletedOp(
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf) const {
UNIMPLEMENTED();
}
......@@ -37,13 +42,15 @@ class ForeignCallback {
virtual void RemoveForeignCallback(int64_t unique_id) const { UNIMPLEMENTED(); }
// return scope_symbol_id
virtual int64_t MakeScopeSymbol(const std::string& job_conf, const std::string& parallel_conf,
virtual int64_t MakeScopeSymbol(const std::shared_ptr<cfg::JobConfigProto>& job_conf,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf,
bool is_mirrored) const {
UNIMPLEMENTED();
return 0;
}
// return parallel_desc_symbol_id
virtual int64_t MakeParallelDescSymbol(const std::string& parallel_conf) const {
virtual int64_t MakeParallelDescSymbol(
const std::shared_ptr<cfg::ParallelConf>& parallel_conf) const {
UNIMPLEMENTED();
return 0;
}
......
......@@ -59,17 +59,22 @@ Maybe<void> GetOpNames(const Job& job, HashSet<std::string>* op_names) {
}
Maybe<void> EagerRunOps(const Job& job, HashSet<std::string>* op_names,
void (ForeignCallback::*interpret)(const std::string&, const std::string&)
const) {
void (ForeignCallback::*interpret)(
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf) const) {
const auto& op_graph = JUST(OpGraph::New(job));
const auto* foreign_callback = JUST(GlobalMaybe<ForeignCallback>());
JUST(op_graph->ForEachOpNode([&](const OpNode& op_node) -> Maybe<void> {
if (!op_names->insert(op_node.op().op_name()).second) { return Maybe<void>::Ok(); }
const auto& op_attribute = op_node.op().GetOpAttributeWithoutOpNameAndLbn();
const auto& parallel_conf = op_node.parallel_desc().parallel_conf();
const std::string& op_attribute_str = PbMessage2TxtString(*op_attribute);
const std::string& parallel_conf_str = PbMessage2TxtString(parallel_conf);
(foreign_callback->*interpret)(op_attribute_str, parallel_conf_str);
{
const std::shared_ptr<cfg::OpAttribute>& cfg_op_attribute =
std::make_shared<cfg::OpAttribute>(*op_attribute);
const std::shared_ptr<cfg::ParallelConf>& cfg_parallel_conf =
std::make_shared<cfg::ParallelConf>(parallel_conf);
(foreign_callback->*interpret)(cfg_op_attribute, cfg_parallel_conf);
}
return Maybe<void>::Ok();
}));
return Maybe<void>::Ok();
......@@ -918,9 +923,13 @@ Maybe<LogicalBlobId> EagerJobBuildAndInferCtx::FindOrCreateMirroredLbiFromCompat
(*mut_mirrored_lbi2sub_lbis())[mirrored_lbi].push_back(mirrored_lbi);
const auto& parallel_conf = parallel_desc.parallel_conf();
const auto& op_attribute = JUST(AddAndInferConsistentOp(op_conf));
const std::string& op_attribute_str = PbMessage2TxtString(*op_attribute);
const std::string& parallel_conf_str = PbMessage2TxtString(parallel_conf);
JUST(GlobalMaybe<ForeignCallback>())->EagerMirroredCast(op_attribute_str, parallel_conf_str);
{
const std::shared_ptr<cfg::OpAttribute>& cfg_op_attribute =
std::make_shared<cfg::OpAttribute>(*op_attribute);
const std::shared_ptr<cfg::ParallelConf>& cfg_parallel_conf =
std::make_shared<cfg::ParallelConf>(parallel_conf);
JUST(GlobalMaybe<ForeignCallback>())->EagerMirroredCast(cfg_op_attribute, cfg_parallel_conf);
}
return mirrored_lbi;
}
......
......@@ -61,8 +61,15 @@ Maybe<void> AutoTrainStep::Apply(const OpGraph& op_graph, Job* job) const {
JobBuilder job_builder(job);
const ParallelConf& parallel_conf = GenParallelConfOfCpuZeroOnMaster();
int64_t scope_symbol_id = Global<ForeignCallback>::Get()->MakeScopeSymbol(
job->job_conf().DebugString(), parallel_conf.DebugString(), false);
int64_t scope_symbol_id = 0;
{
const std::shared_ptr<cfg::JobConfigProto>& cfg_job_conf =
std::make_shared<cfg::JobConfigProto>(job->job_conf());
const std::shared_ptr<cfg::ParallelConf>& cfg_parallel_conf =
std::make_shared<cfg::ParallelConf>(parallel_conf);
scope_symbol_id =
Global<ForeignCallback>::Get()->MakeScopeSymbol(cfg_job_conf, cfg_parallel_conf, false);
}
auto scalar_add_op = user_op::UserOpConfWrapperBuilder(train_step_name + "-ScalarAdd")
.Op("scalar_add")
......
......@@ -203,8 +203,15 @@ Maybe<void> TryMirroredCastTotalLossInstanceNum(
cast_from_mirrored->set_out("out");
cast_from_mirrored->mutable_sbp_parallel()->mutable_partial_sum_parallel();
const auto& parallel_conf = job_builder->ParallelConf4Lbi(*total_loss_instance_num_lbi);
int64_t scope_symbol_id = Global<ForeignCallback>::Get()->MakeScopeSymbol(
job_builder->job().job_conf().DebugString(), parallel_conf.DebugString(), true);
int64_t scope_symbol_id = 0;
{
const std::shared_ptr<cfg::JobConfigProto>& cfg_job_conf =
std::make_shared<cfg::JobConfigProto>(job_builder->job().job_conf());
const std::shared_ptr<cfg::ParallelConf>& cfg_parallel_conf =
std::make_shared<cfg::ParallelConf>(parallel_conf);
scope_symbol_id =
Global<ForeignCallback>::Get()->MakeScopeSymbol(cfg_job_conf, cfg_parallel_conf, true);
}
op_conf.set_scope_symbol_id(scope_symbol_id);
job_builder->AddOps(parallel_conf, {op_conf});
total_loss_instance_num_lbi->set_op_name(op_conf.name());
......@@ -252,8 +259,15 @@ void ScaleModelDiffByDynamicLossInstanceNum(
ParallelConf parallel_conf;
parallel_conf.set_device_tag("cpu");
parallel_conf.add_device_name("0:0");
int64_t scope_symbol_id = Global<ForeignCallback>::Get()->MakeScopeSymbol(
job_builder->job().job_conf().DebugString(), parallel_conf.DebugString(), false);
int64_t scope_symbol_id = 0;
{
const std::shared_ptr<cfg::JobConfigProto>& cfg_job_conf =
std::make_shared<cfg::JobConfigProto>(job_builder->job().job_conf());
const std::shared_ptr<cfg::ParallelConf>& cfg_parallel_conf =
std::make_shared<cfg::ParallelConf>(parallel_conf);
scope_symbol_id =
Global<ForeignCallback>::Get()->MakeScopeSymbol(cfg_job_conf, cfg_parallel_conf, false);
}
op_conf.set_scope_symbol_id(scope_symbol_id);
job_builder->AddOps(parallel_conf, {op_conf});
......@@ -438,8 +452,15 @@ void ClipGradientByGlobalNorm(const OpGraph& op_graph, JobBuilder* job_builder,
}
ParallelConf global_norm_parallel_conf =
all_same_parallel_desc ? parallel_desc->parallel_conf() : GenParallelConfOfCpuZeroOnMaster();
int64_t scope_symbol_id = Global<ForeignCallback>::Get()->MakeScopeSymbol(
job_builder->job().job_conf().DebugString(), global_norm_parallel_conf.DebugString(), false);
int64_t scope_symbol_id = 0;
{
const std::shared_ptr<cfg::JobConfigProto>& cfg_job_conf =
std::make_shared<cfg::JobConfigProto>(job_builder->job().job_conf());
const std::shared_ptr<cfg::ParallelConf> cfg_global_norm_parallel_conf =
std::make_shared<cfg::ParallelConf>(global_norm_parallel_conf);
scope_symbol_id = Global<ForeignCallback>::Get()->MakeScopeSymbol(
cfg_job_conf, cfg_global_norm_parallel_conf, false);
}
std::vector<std::string> lbns_to_add;
for (const auto& pair : *lbi2diff_lbi) {
const LogicalBlobId& diff_lbi = pair.second;
......
......@@ -64,8 +64,10 @@ Maybe<void> DistributeAddOp::InferParallelSignature() {
CHECK_EQ(op_parallel_desc.parallel_num(), input_bns().size());
FOR_RANGE(int, i, 0, input_bns().size()) {
const auto& in_parallel_conf = op_parallel_desc.GetParallelIdOnlyParallelConf(i);
const std::shared_ptr<cfg::ParallelConf>& cfg_in_parallel_conf =
std::make_shared<cfg::ParallelConf>(in_parallel_conf);
(*map)[input_bns().Get(i)] =
Global<ForeignCallback>::Get()->MakeParallelDescSymbol(in_parallel_conf.DebugString());
Global<ForeignCallback>::Get()->MakeParallelDescSymbol(cfg_in_parallel_conf);
}
return Maybe<void>::Ok();
}
......
......@@ -103,8 +103,10 @@ Maybe<void> DistributeCloneOp::InferParallelSignature() {
CHECK_EQ(op_parallel_desc.parallel_num(), output_bns().size());
FOR_RANGE(int, i, 0, output_bns().size()) {
const auto& out_parallel_conf = op_parallel_desc.GetParallelIdOnlyParallelConf(i);
const std::shared_ptr<cfg::ParallelConf>& cfg_out_parallel_conf =
std::make_shared<cfg::ParallelConf>(out_parallel_conf);
(*map)[output_bns().Get(i)] =
Global<ForeignCallback>::Get()->MakeParallelDescSymbol(out_parallel_conf.DebugString());
Global<ForeignCallback>::Get()->MakeParallelDescSymbol(cfg_out_parallel_conf);
}
return Maybe<void>::Ok();
}
......
......@@ -110,8 +110,10 @@ Maybe<void> DistributeConcatOp::InferParallelSignature() {
CHECK_EQ(op_parallel_desc.parallel_num(), input_bns().size());
FOR_RANGE(int, i, 0, input_bns().size()) {
const auto& in_parallel_conf = op_parallel_desc.GetParallelIdOnlyParallelConf(i);
const std::shared_ptr<cfg::ParallelConf>& cfg_in_parallel_conf =
std::make_shared<cfg::ParallelConf>(in_parallel_conf);
(*map)[input_bns().Get(i)] =
Global<ForeignCallback>::Get()->MakeParallelDescSymbol(in_parallel_conf.DebugString());
Global<ForeignCallback>::Get()->MakeParallelDescSymbol(cfg_in_parallel_conf);
}
return Maybe<void>::Ok();
}
......
......@@ -117,8 +117,10 @@ Maybe<void> DistributeSplitOp::InferParallelSignature() {
CHECK_EQ(op_parallel_desc.parallel_num(), output_bns().size());
FOR_RANGE(int, i, 0, output_bns().size()) {
const auto& out_parallel_conf = op_parallel_desc.GetParallelIdOnlyParallelConf(i);
const std::shared_ptr<cfg::ParallelConf>& cfg_out_parallel_conf =
std::make_shared<cfg::ParallelConf>(out_parallel_conf);
(*map)[output_bns().Get(i)] =
Global<ForeignCallback>::Get()->MakeParallelDescSymbol(out_parallel_conf.DebugString());
Global<ForeignCallback>::Get()->MakeParallelDescSymbol(cfg_out_parallel_conf);
}
return Maybe<void>::Ok();
}
......
......@@ -53,41 +53,46 @@ class PythonCallback(oneflow_api.ForeignCallback):
print(traceback.format_exc())
raise e
def EagerInterpretCompletedOp(self, op_attribute_str, parallel_conf_str):
def EagerInterpretCompletedOp(self, op_attribute, parallel_conf):
try:
# TODO(hanbinbin): str() will be removed after proto obj is replaced with cfg obj in python side
interpreter_callback.InterpretCompletedOp(
op_attribute_str, parallel_conf_str
str(op_attribute), str(parallel_conf)
)
except Exception as e:
print(traceback.format_exc())
raise e
def EagerMirroredCast(self, op_attribute_str, parallel_conf_str):
def EagerMirroredCast(self, op_attribute, parallel_conf):
try:
interpreter_callback.MirroredCast(op_attribute_str, parallel_conf_str)
# TODO(hanbinbin): str() will be removed after proto obj is replaced with cfg obj in python side
interpreter_callback.MirroredCast(str(op_attribute), str(parallel_conf))
except Exception as e:
print(traceback.format_exc())
raise e
def EagerCastFromMirrored(self, op_attribute_str, parallel_conf_str):
def EagerCastFromMirrored(self, op_attribute, parallel_conf):
try:
interpreter_callback.CastFromMirrored(op_attribute_str, parallel_conf_str)
# TODO(hanbinbin): str() will be removed after proto obj is replaced with cfg obj in python side
interpreter_callback.CastFromMirrored(str(op_attribute), str(parallel_conf))
except Exception as e:
print(traceback.format_exc())
raise e
def MakeScopeSymbol(self, job_conf_str, parallel_conf_str, is_mirrored):
def MakeScopeSymbol(self, job_conf, parallel_conf, is_mirrored):
try:
# TODO(hanbinbin): str() will be removed after proto obj is replaced with cfg obj in python side
return interpreter_callback.MakeScopeSymbol(
job_conf_str, parallel_conf_str, is_mirrored
str(job_conf), str(parallel_conf), is_mirrored
)
except Exception as e:
print(traceback.format_exc())
raise e
def MakeParallelDescSymbol(self, parallel_conf_str):
def MakeParallelDescSymbol(self, parallel_conf):
try:
return interpreter_callback.MakeParallelDescSymbol(parallel_conf_str)
# TODO(hanbinbin): str() will be removed after proto obj is replaced with cfg obj in python side
return interpreter_callback.MakeParallelDescSymbol(str(parallel_conf))
except Exception as e:
print(traceback.format_exc())
raise e
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册