未验证 提交 856a119f 编写于 作者: qq_22305325's avatar qq_22305325 提交者: GitHub

Mig instruction build stateless call rel api (#4218)

* mig parallel_conf_util

* mig BuildInitialScope BuildScopeWithNewParallelDesc BuildScopeWithNewParallelConf

* add test of GetDeviceTagAndMachineDeviceIds

* mig GetOpConfSymbol

* fix BuildScopeWithNewParallelDesc input type error

* use TRY

* use symbol::Storage<OperatorConfSymbol>

* _NewOpKernelObject

* mig OpKernelObject

* mig object_storage

* make of_format

* del comment

* std::function<void(Object*)

* mig NewOpKernelObject and _StatefulCallOpKernel

* mig _StatefulCallOpKernel and GetSharedOpKernelObject4ParallelConfSymbol

* del object_storage.cpp

* use name GLOBAL_PARA_SYM2SHARED_OPKENEL_OBJ_MUTEX

* mig CheckRefInBlobObjectParallelDesc and  OperandBlobObjects rel api

* mig _StatelessCall

* mig _StatelessCall

* mig StatelessCall api

* del comment

* mig id_util and scope_util

* use cfg_op_conf and Object*

* use Object*

* del _

* fix func name error

* use MapAt and shared_ptr

* use shared_ptr or const ref

* minor fix

* add todo

* minor fix

* minor djustment

* minor fix

* minor optimize

* minor fix
上级 db6ec194
......@@ -283,22 +283,87 @@ GetMut2OperandBlobObjects(
.GetPtrOrThrow());
}
void _StatelessCall(
const std::shared_ptr<InstructionsBuilder>& x, const std::string& stream_tag,
// signature of python func _FindOrCreateDelegateBlobObject, it will be removed after blobcache is
// migrated
using FindOrCreateDelegateBlobObjectFun = std::function<std::shared_ptr<compatible_py::BlobObject>(
const std::shared_ptr<InstructionsBuilder>&,
const std::function<std::shared_ptr<compatible_py::BlobObject>(
const std::shared_ptr<compatible_py::BlobObject>&,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>&)>&,
const std::shared_ptr<compatible_py::BlobObject>&,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>&)>;
void StatelessCall(
const std::shared_ptr<InstructionsBuilder>& x,
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
std::shared_ptr<ParallelDesc> op_parallel_desc_sym,
const std::shared_ptr<ParallelDesc>& blob_parallel_desc_sym,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object,
const std::function<std::shared_ptr<compatible_py::BlobObject>(
const std::shared_ptr<InstructionsBuilder>&,
const std::shared_ptr<compatible_py::BlobObject>&,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>&)>& get_delegate_blob_object) {
const std::shared_ptr<compatible_py::OpArgParallelAttribute>&)>& boxing_to,
const FindOrCreateDelegateBlobObjectFun& find_or_creat_delegate_blob_object) {
return x
->StatelessCall(op_attribute, parallel_conf, bn_in_op2blob_object, boxing_to,
find_or_creat_delegate_blob_object)
.GetOrThrow();
}
void NoBoxingStatelessCall(
const std::shared_ptr<InstructionsBuilder>& x,
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object,
const FindOrCreateDelegateBlobObjectFun& find_or_creat_delegate_blob_object) {
return x
->NoBoxingStatelessCall(op_attribute, parallel_conf, bn_in_op2blob_object,
find_or_creat_delegate_blob_object)
.GetOrThrow();
}
void NoBoxingCudaD2HStatelessCall(
const std::shared_ptr<InstructionsBuilder>& x,
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& in_parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object,
const std::function<std::shared_ptr<ParallelDesc>(
const std::shared_ptr<InstructionsBuilder>&, const std::shared_ptr<ParallelDesc>&,
const std::string&)>& try_replace_device_tag) {
return x
->_StatelessCall(stream_tag, op_attribute, op_parallel_desc_sym, blob_parallel_desc_sym,
bn_in_op2blob_object, get_delegate_blob_object)
->NoBoxingCudaD2HStatelessCall(op_attribute, in_parallel_conf, bn_in_op2blob_object,
try_replace_device_tag)
.GetOrThrow();
}
void NoBoxingCudaH2DStatelessCall(
const std::shared_ptr<InstructionsBuilder>& x,
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& out_parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object) {
return x->NoBoxingCudaH2DStatelessCall(op_attribute, out_parallel_conf, bn_in_op2blob_object)
.GetOrThrow();
}
void RawStatelessCall(
const std::shared_ptr<InstructionsBuilder>& x,
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object) {
return x->RawStatelessCall(op_attribute, parallel_conf, bn_in_op2blob_object).GetOrThrow();
}
std::shared_ptr<compatible_py::BlobObject> Build121To(
const std::shared_ptr<InstructionsBuilder>& x,
const std::shared_ptr<compatible_py::BlobObject>& blob_object,
const std::shared_ptr<ParallelDesc>& parallel_desc_symbol) {
return x->Build121To(blob_object, parallel_desc_symbol).GetPtrOrThrow();
}
} // namespace
ONEFLOW_API_PYBIND11_MODULE("deprecated", m) {
......@@ -375,7 +440,12 @@ ONEFLOW_API_PYBIND11_MODULE("deprecated", m) {
&GetSharedOpKernelObject4ParallelConfSymbol)
.def("DeleteObject", &DeleteObject)
.def("_StatefulCallOpKernel", &_StatefulCallOpKernel)
.def("_StatelessCall", &_StatelessCall)
.def("StatelessCall", &StatelessCall)
.def("NoBoxingStatelessCall", &NoBoxingStatelessCall)
.def("NoBoxingCudaD2HStatelessCall", &NoBoxingCudaD2HStatelessCall)
.def("NoBoxingCudaH2DStatelessCall", &NoBoxingCudaH2DStatelessCall)
.def("RawStatelessCall", &RawStatelessCall)
.def("Build121To", &Build121To)
.def("GetConstInputOperandBlobObjects", &GetConstInputOperandBlobObjects)
.def("GetMutableInputOperandBlobObjects", &GetMutableInputOperandBlobObjects)
.def("GetMut1OperandBlobObjects", &GetMut1OperandBlobObjects)
......
......@@ -23,6 +23,10 @@ limitations under the License.
#include "oneflow/core/framework/parallel_conf_util.h"
#include "oneflow/core/framework/object_storage.h"
#include "oneflow/core/operator/op_node_signature.cfg.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/framework/id_util.h"
#include "oneflow/core/operator/interface_blob_conf.cfg.h"
#include "oneflow/core/framework/scope_util.h"
#include "oneflow/core/common/container_util.h"
namespace oneflow {
......@@ -110,6 +114,50 @@ uint64_t NewTokenId() {
return token_id;
}
using IntList = std::vector<int64_t>;
using Int2IntListMap = HashMap<int64_t, std::shared_ptr<IntList>>;
// This function is used to determine whether the machine_id2sorted_dev_phy_ids of ParallelDesc are
// equal
bool Int2IntListMapContaining(const Int2IntListMap& bigger, const Int2IntListMap& smaller) {
for (const auto& pair : smaller) {
if (bigger.find(pair.first) == bigger.end()) { return false; }
const auto& bigger_device_ids = bigger.find(pair.first)->second;
std::vector<int64_t>::iterator ret;
for (int64_t device_id : *pair.second) {
ret = std::find(bigger_device_ids->begin(), bigger_device_ids->end(), device_id);
if (ret == bigger_device_ids->end()) { return false; }
}
}
return true;
}
Maybe<compatible_py::BlobObject> MakeNewBlobObjectLike(
const std::shared_ptr<InstructionsBuilder>& builder,
const std::shared_ptr<compatible_py::BlobObject>& blob_object,
const std::shared_ptr<ParallelDesc>& new_parallel_desc_symbol) {
OperatorConf op_conf;
op_conf.set_name(*JUST(UniqueStr("Input")));
op_conf.set_device_tag(new_parallel_desc_symbol->device_tag());
op_conf.mutable_input_conf()->set_out("out");
std::shared_ptr<cfg::InterfaceBlobConf> cfg_interface_blob_conf =
std::make_shared<cfg::InterfaceBlobConf>();
blob_object->op_arg_parallel_attr()->DumpToInterfaceBlobConf(cfg_interface_blob_conf);
blob_object->op_arg_blob_attr()->DumpToInterfaceBlobConf(cfg_interface_blob_conf);
cfg_interface_blob_conf->ToProto(op_conf.mutable_input_conf()->mutable_blob_conf());
std::shared_ptr<Scope> cur_scope = JUST(GetCurrentScope());
op_conf.set_scope_symbol_id(JUST(cur_scope->symbol_id()));
OpNodeSignature upstream_signature;
const auto& op = JUST(ConstructAndInferOp(op_conf, upstream_signature, *cur_scope));
const auto& op_attribute = op->GetOpAttributeWithoutOpNameAndLbn();
std::shared_ptr<cfg::ParallelConf> parallel_conf = new_parallel_desc_symbol->cfg_parallel_conf();
std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>
bn_in_op2blob_object =
std::make_shared<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>();
builder->RawStatelessCall(std::make_shared<cfg::OpAttribute>(*op_attribute), parallel_conf,
bn_in_op2blob_object);
return (*bn_in_op2blob_object)["out"];
}
} // namespace
namespace detail {
......@@ -908,13 +956,167 @@ Maybe<OpNodeSignatureDesc> InstructionsBuilder::GetOpNodeSignatureSymbol(
return GetSymbol<cfg::OpNodeSignature, OpNodeSignatureDesc>(*op_node_signature);
}
// signature of python func _FindOrCreateDelegateBlobObject, it will be removed after blobcache is
// migrated
using FindOrCreateDelegateBlobObjectFun = std::function<std::shared_ptr<compatible_py::BlobObject>(
const std::shared_ptr<InstructionsBuilder>&,
const std::function<std::shared_ptr<compatible_py::BlobObject>(
const std::shared_ptr<compatible_py::BlobObject>&,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>&)>&,
const std::shared_ptr<compatible_py::BlobObject>&,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>&)>;
Maybe<void> InstructionsBuilder::StatelessCall(
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object,
const std::function<std::shared_ptr<compatible_py::BlobObject>(
const std::shared_ptr<InstructionsBuilder>&,
const std::shared_ptr<compatible_py::BlobObject>&,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>&)>& boxing_to,
const FindOrCreateDelegateBlobObjectFun& find_or_creat_delegate_blob_object) {
std::shared_ptr<ParallelDesc> op_parallel_desc_sym = JUST(GetParallelDescSymbol(parallel_conf));
JUST(CheckRefInBlobObjectParallelDesc(op_attribute, op_parallel_desc_sym, bn_in_op2blob_object));
const auto FetchDelegateBlobObject =
[this, &boxing_to](
const std::shared_ptr<compatible_py::BlobObject>& x_blob_object,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>& op_arg_parallel_attr)
-> std::shared_ptr<compatible_py::BlobObject> {
// TODO(hanbinbin): use Maybe as return after blobcache is migrated
return boxing_to(shared_from_this(), x_blob_object, op_arg_parallel_attr);
};
const auto GetDelegateBlobObject =
[this, find_or_creat_delegate_blob_object, &FetchDelegateBlobObject](
const std::shared_ptr<compatible_py::BlobObject>& blob_object,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>& op_arg_parallel_attr)
-> Maybe<compatible_py::BlobObject> {
return find_or_creat_delegate_blob_object(shared_from_this(), FetchDelegateBlobObject,
blob_object, op_arg_parallel_attr);
};
JUST(_StatelessCall("compute", op_attribute, op_parallel_desc_sym, op_parallel_desc_sym,
bn_in_op2blob_object, GetDelegateBlobObject));
return Maybe<void>::Ok();
}
Maybe<void> InstructionsBuilder::NoBoxingStatelessCall(
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object,
const FindOrCreateDelegateBlobObjectFun& find_or_creat_delegate_blob_object) {
std::shared_ptr<ParallelDesc> op_parallel_desc_sym = JUST(GetParallelDescSymbol(parallel_conf));
JUST(CheckRefInBlobObjectParallelDesc(op_attribute, op_parallel_desc_sym, bn_in_op2blob_object));
const auto FetchDelegateBlobObject =
[this](const std::shared_ptr<compatible_py::BlobObject>& blob_object,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>& op_arg_parallel_attr)
-> std::shared_ptr<compatible_py::BlobObject> {
std::shared_ptr<ParallelDesc> from_pd = blob_object->parallel_desc_symbol();
std::shared_ptr<ParallelDesc> to_pd = op_arg_parallel_attr->parallel_desc_symbol();
if (*from_pd == *to_pd) { return blob_object; }
CHECK(from_pd->device_tag() == "cpu");
CHECK(to_pd->device_tag() == "cpu");
CHECK(from_pd->parallel_num() == to_pd->parallel_num());
auto from_machine_ids = from_pd->machine_id2sorted_dev_phy_ids();
auto to_machine_ids = to_pd->machine_id2sorted_dev_phy_ids();
if ((from_pd->machine_id2sorted_dev_phy_ids()->size() == from_pd->parallel_num())
&& (Int2IntListMapContaining(*from_machine_ids, *to_machine_ids))
&& (Int2IntListMapContaining(*to_machine_ids, *from_machine_ids))) {
return CHECK_JUST(BroadcastBlobReference(blob_object, to_pd));
}
return CHECK_JUST(Build121To(blob_object, to_pd));
};
const auto GetDirectOr121BlobObject =
[this, find_or_creat_delegate_blob_object, &FetchDelegateBlobObject](
const std::shared_ptr<compatible_py::BlobObject>& blob_object,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>& op_arg_parallel_attr)
-> Maybe<compatible_py::BlobObject> {
return find_or_creat_delegate_blob_object(shared_from_this(), FetchDelegateBlobObject,
blob_object, op_arg_parallel_attr);
};
JUST(_StatelessCall("compute", op_attribute, op_parallel_desc_sym, op_parallel_desc_sym,
bn_in_op2blob_object, GetDirectOr121BlobObject));
return Maybe<void>::Ok();
}
Maybe<void> InstructionsBuilder::NoBoxingCudaD2HStatelessCall(
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& in_parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object,
const std::function<std::shared_ptr<ParallelDesc>(
const std::shared_ptr<InstructionsBuilder>&, const std::shared_ptr<ParallelDesc>&,
const std::string&)>& try_replace_device_tag) {
std::shared_ptr<ParallelDesc> op_parallel_desc_sym =
JUST(GetParallelDescSymbol(in_parallel_conf));
std::shared_ptr<ParallelDesc> blob_parallel_desc_sym =
try_replace_device_tag(shared_from_this(), op_parallel_desc_sym, "cpu");
JUST(CheckRefInBlobObjectParallelDesc(op_attribute, op_parallel_desc_sym, bn_in_op2blob_object));
const auto GetDirectBlobObject =
[](const std::shared_ptr<compatible_py::BlobObject>& blob_object,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>& op_arg_parallel_attr)
-> Maybe<compatible_py::BlobObject> { return blob_object; };
JUST(_StatelessCall("copy_d2h", op_attribute, op_parallel_desc_sym, blob_parallel_desc_sym,
bn_in_op2blob_object, GetDirectBlobObject));
return Maybe<void>::Ok();
}
Maybe<void> InstructionsBuilder::NoBoxingCudaH2DStatelessCall(
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& out_parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object) {
std::shared_ptr<ParallelDesc> op_parallel_desc_sym =
JUST(GetParallelDescSymbol(out_parallel_conf));
JUST(CheckRefInBlobObjectParallelDesc(op_attribute, op_parallel_desc_sym, bn_in_op2blob_object));
const auto GetDirectBlobObject =
[](const std::shared_ptr<compatible_py::BlobObject>& blob_object,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>& op_arg_parallel_attr)
-> Maybe<compatible_py::BlobObject> { return blob_object; };
JUST(_StatelessCall("copy_h2d", op_attribute, op_parallel_desc_sym, op_parallel_desc_sym,
bn_in_op2blob_object, GetDirectBlobObject));
return Maybe<void>::Ok();
}
Maybe<void> InstructionsBuilder::RawStatelessCall(
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object) {
std::shared_ptr<ParallelDesc> op_parallel_desc_sym = JUST(GetParallelDescSymbol(parallel_conf));
JUST(CheckRefInBlobObjectParallelDesc(op_attribute, op_parallel_desc_sym, bn_in_op2blob_object));
const auto GetDirectBlobObject =
[](const std::shared_ptr<compatible_py::BlobObject>& blob_object,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>& op_arg_parallel_attr)
-> Maybe<compatible_py::BlobObject> { return blob_object; };
JUST(_StatelessCall("compute", op_attribute, op_parallel_desc_sym, op_parallel_desc_sym,
bn_in_op2blob_object, GetDirectBlobObject));
return Maybe<void>::Ok();
}
Maybe<void> InstructionsBuilder::_StatelessCall(
const std::string& stream_tag, const std::shared_ptr<cfg::OpAttribute>& op_attribute,
std::shared_ptr<ParallelDesc> op_parallel_desc_sym,
const std::shared_ptr<ParallelDesc>& blob_parallel_desc_sym,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object,
const std::function<std::shared_ptr<compatible_py::BlobObject>(
const std::function<Maybe<compatible_py::BlobObject>(
const std::shared_ptr<compatible_py::BlobObject>&,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>&)>& get_delegate_blob_object) {
if (op_attribute->parallel_signature().has_op_parallel_desc_symbol_id()) {
......@@ -930,8 +1132,8 @@ Maybe<void> InstructionsBuilder::_StatelessCall(
op_attribute->ToProto(&pb_op_attribute);
std::shared_ptr<compatible_py::OpArgParallelAttribute> op_arg_parallel_attr = CHECK_JUST(
compatible_py::GetOpArgParallelAttribute(op_parallel_desc_sym, pb_op_attribute, ibn));
return get_delegate_blob_object(CHECK_JUST(MapAt(*bn_in_op2blob_object, ibn)),
op_arg_parallel_attr);
return CHECK_JUST(get_delegate_blob_object(CHECK_JUST(MapAt(*bn_in_op2blob_object, ibn)),
op_arg_parallel_attr));
};
const auto& op_conf = op_attribute->op_conf();
......@@ -971,6 +1173,15 @@ Maybe<void> InstructionsBuilder::_StatelessCall(
return Maybe<void>::Ok();
}
Maybe<compatible_py::BlobObject> InstructionsBuilder::Build121To(
const std::shared_ptr<compatible_py::BlobObject>& blob_object,
const std::shared_ptr<ParallelDesc>& parallel_desc_symbol) {
std::shared_ptr<compatible_py::BlobObject> ref_blob_object =
JUST(MakeNewBlobObjectLike(shared_from_this(), blob_object, parallel_desc_symbol));
JUST(Build121AssignInstruction(ref_blob_object, blob_object));
return ref_blob_object;
}
Maybe<std::vector<
std::pair<std::shared_ptr<StringSymbol>, std::shared_ptr<compatible_py::BlobObject>>>>
InstructionsBuilder::GetConstInputOperandBlobObjects(
......
......@@ -49,7 +49,7 @@ struct CreateSymbolIdHelper {
} // namespace detail
class InstructionsBuilder {
class InstructionsBuilder : public std::enable_shared_from_this<InstructionsBuilder> {
public:
InstructionsBuilder(const InstructionsBuilder&) = delete;
InstructionsBuilder(InstructionsBuilder&&) = delete;
......@@ -186,17 +186,57 @@ class InstructionsBuilder {
const std::vector<
std::pair<std::shared_ptr<StringSymbol>, std::shared_ptr<compatible_py::BlobObject>>>&
mut2_operand_blob_objects);
using FindOrCreateDelegateBlobObjectFun =
std::function<std::shared_ptr<compatible_py::BlobObject>(
const std::shared_ptr<InstructionsBuilder>&,
const std::function<std::shared_ptr<compatible_py::BlobObject>(
const std::shared_ptr<compatible_py::BlobObject>&,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>&)>&,
const std::shared_ptr<compatible_py::BlobObject>&,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>&)>;
Maybe<void> _StatelessCall(
const std::string& stream_tag, const std::shared_ptr<cfg::OpAttribute>& op_attribute,
std::shared_ptr<ParallelDesc> op_parallel_desc_sym,
const std::shared_ptr<ParallelDesc>& blob_parallel_desc_sym,
Maybe<void> StatelessCall(
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object,
const std::function<std::shared_ptr<compatible_py::BlobObject>(
const std::shared_ptr<InstructionsBuilder>&,
const std::shared_ptr<compatible_py::BlobObject>&,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>&)>&
get_delegate_blob_object);
const std::shared_ptr<compatible_py::OpArgParallelAttribute>&)>& boxing_to,
const FindOrCreateDelegateBlobObjectFun& find_or_creat_delegate_blob_object);
Maybe<void> NoBoxingStatelessCall(
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object,
const FindOrCreateDelegateBlobObjectFun& find_or_creat_delegate_blob_object);
Maybe<void> NoBoxingCudaD2HStatelessCall(
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& in_parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object,
const std::function<std::shared_ptr<ParallelDesc>(
const std::shared_ptr<InstructionsBuilder>&, const std::shared_ptr<ParallelDesc>&,
const std::string&)>& try_replace_device_tag);
Maybe<void> NoBoxingCudaH2DStatelessCall(
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& out_parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object);
Maybe<void> RawStatelessCall(
const std::shared_ptr<cfg::OpAttribute>& op_attribute,
const std::shared_ptr<cfg::ParallelConf>& parallel_conf,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object);
Maybe<compatible_py::BlobObject> Build121To(
const std::shared_ptr<compatible_py::BlobObject>& blob_object,
const std::shared_ptr<ParallelDesc>& parallel_desc_symbol);
Maybe<std::vector<
std::pair<std::shared_ptr<StringSymbol>, std::shared_ptr<compatible_py::BlobObject>>>>
......@@ -292,6 +332,17 @@ class InstructionsBuilder {
Maybe<void> InitOpConfSymbol(int64_t symbol_id,
const std::shared_ptr<cfg::OperatorConf>& op_conf);
Maybe<void> _StatelessCall(
const std::string& stream_tag, const std::shared_ptr<cfg::OpAttribute>& op_attribute,
std::shared_ptr<ParallelDesc> op_parallel_desc_sym,
const std::shared_ptr<ParallelDesc>& blob_parallel_desc_sym,
const std::shared_ptr<HashMap<std::string, std::shared_ptr<compatible_py::BlobObject>>>&
bn_in_op2blob_object,
const std::function<Maybe<compatible_py::BlobObject>(
const std::shared_ptr<compatible_py::BlobObject>&,
const std::shared_ptr<compatible_py::OpArgParallelAttribute>&)>&
get_delegate_blob_object);
Maybe<void> _TryClearObject(compatible_py::Object* blob_object);
Maybe<void> _DeleteObject(compatible_py::Object* blob_object);
......
......@@ -36,6 +36,15 @@ import oneflow_api.oneflow.core.job.placement as placement_cfg
import oneflow_api
def _FindOrCreateDelegateBlobObject(
builder, Fetch, x_blob_object, op_arg_parallel_attr
):
if x_blob_object.op_arg_parallel_attr == op_arg_parallel_attr:
return x_blob_object
blob_cache = blob_cache_util.FindOrCreateBlobCache(x_blob_object)
return blob_cache.GetCachedDelegateBlobObject(op_arg_parallel_attr, Fetch)
def BoxingTo(builder, produced_blob_object, consumer_op_arg_parallel_attr):
hob_context = BoxingHobContext(produced_blob_object, consumer_op_arg_parallel_attr)
if enable_if.get_condition_hob(NoBoxing)(hob_context):
......@@ -364,10 +373,12 @@ def GpuNcclAllReduce(builder, produced_blob_object, consumer_op_arg_parallel_att
bn_in_op2blob_object = oneflow_api.deprecated.BnInOp2BlobObject()
bn_in_op2blob_object["in_0"] = produced_blob_object
op_attribute = _GetEagerNcclAllReduce(parallel_conf, bn_in_op2blob_object)
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(str(op_attribute))
builder.NoBoxingStatelessCall(
op_attribute,
parallel_conf=parallel_conf,
bn_in_op2blob_object=bn_in_op2blob_object,
cfg_op_attribute,
parallel_conf,
bn_in_op2blob_object,
_FindOrCreateDelegateBlobObject,
)
y_blob_object = bn_in_op2blob_object["out_0"]
y_blob_object.op_arg_parallel_attr.Assign(consumer_op_arg_parallel_attr)
......@@ -543,10 +554,12 @@ def BuildNaiveCpuBoxing(
bn_in_op2blob_object = oneflow_api.deprecated.BnInOp2BlobObject()
for i in range(len(physical_in_blob_objects)):
bn_in_op2blob_object["in_%s" % i] = physical_in_blob_objects[i]
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(str(op_attribute))
builder.NoBoxingStatelessCall(
op_attribute,
parallel_conf=boxing_parallel_desc_symbol.parallel_conf,
bn_in_op2blob_object=bn_in_op2blob_object,
cfg_op_attribute,
boxing_parallel_desc_symbol.parallel_conf,
bn_in_op2blob_object,
_FindOrCreateDelegateBlobObject,
)
return [bn_in_op2blob_object["out_%s" % i] for i in range(out_parallel_num)]
......@@ -696,10 +709,11 @@ def _BuildCopyInstruction(builder, produced_blob_object, op_conf, to_device_tag)
bn_in_op2blob_object["in"] = produced_blob_object
op_attribute = op_infer_util.Infer(op_conf, bn_in_op2blob_object)
assert to_device_tag != x_device_tag, (to_device_tag, x_device_tag)
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(str(op_attribute))
if to_device_tag == "cpu" and x_device_tag == "gpu":
x_parallel_conf = produced_blob_object.parallel_desc_symbol.parallel_conf
builder.NoBoxingCudaD2HStatelessCall(
op_attribute, x_parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object
cfg_op_attribute, x_parallel_conf, bn_in_op2blob_object, TryReplaceDeviceTag
)
elif to_device_tag == "gpu" and x_device_tag == "cpu":
out_parallel_desc_symbol = TryReplaceDeviceTag(
......@@ -708,9 +722,7 @@ def _BuildCopyInstruction(builder, produced_blob_object, op_conf, to_device_tag)
out_parallel_conf = out_parallel_desc_symbol.parallel_conf
with builder.CudaHostPinBlob(produced_blob_object):
builder.NoBoxingCudaH2DStatelessCall(
op_attribute,
out_parallel_conf,
bn_in_op2blob_object=bn_in_op2blob_object,
cfg_op_attribute, out_parallel_conf, bn_in_op2blob_object,
)
else:
raise NotImplementedError(
......@@ -747,23 +759,26 @@ def BuildAssignInstruction(builder, ref_blob_object, value_blob_object, op_conf)
bn_in_op2blob_object["ref"] = ref_blob_object
bn_in_op2blob_object["value"] = value_blob_object
op_attribute = op_infer_util.Infer(op_conf, bn_in_op2blob_object)
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(str(op_attribute))
if ref_device_tag == value_device_tag:
builder.NoBoxingStatelessCall(
op_attribute,
parallel_conf=ref_parallel_conf,
bn_in_op2blob_object=bn_in_op2blob_object,
cfg_op_attribute,
ref_parallel_conf,
bn_in_op2blob_object,
_FindOrCreateDelegateBlobObject,
)
elif ref_device_tag == "cpu" and value_device_tag == "gpu":
value_parallel_conf = value_blob_object.parallel_desc_symbol.parallel_conf
builder.NoBoxingCudaD2HStatelessCall(
op_attribute, value_parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object
cfg_op_attribute,
value_parallel_conf,
bn_in_op2blob_object,
TryReplaceDeviceTag,
)
elif ref_device_tag == "gpu" and value_device_tag == "cpu":
with builder.CudaHostPinBlob(value_blob_object):
builder.NoBoxingCudaH2DStatelessCall(
op_attribute,
ref_parallel_conf,
bn_in_op2blob_object=bn_in_op2blob_object,
cfg_op_attribute, ref_parallel_conf, bn_in_op2blob_object,
)
else:
raise NotImplementedError(
......
......@@ -204,8 +204,15 @@ def _NaiveInterpret(op_attribute, parallel_conf, blob_register):
with blob_register_util.BnInOp2BlobObjectScope(
blob_register, op_attribute
) as bn_in_op2blob_object:
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(
str(op_attribute)
)
builder.StatelessCall(
op_attribute, parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object,
cfg_op_attribute,
parallel_conf,
bn_in_op2blob_object,
boxing_util.BoxingTo,
vm_util._FindOrCreateDelegateBlobObject,
)
vm_util.LogicalRun(BuildInstruction)
......@@ -278,8 +285,15 @@ def _EagerRunModelInit(var_op_conf):
parallel_conf = (
oneflow.current_scope().device_parallel_desc_symbol.parallel_conf
)
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(
str(op_attribute)
)
builder.StatelessCall(
op_attribute, parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object
cfg_op_attribute,
parallel_conf,
bn_in_op2blob_object,
boxing_util.BoxingTo,
vm_util._FindOrCreateDelegateBlobObject,
)
sess = session_ctx.GetDefaultSession()
......@@ -295,8 +309,15 @@ def _MakeModelIOPathInputBuilds(op_conf, path, bn_in_op2blob_object):
parallel_conf = (
oneflow.current_scope().device_parallel_desc_symbol.parallel_conf
)
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(
str(op_attribute)
)
builder.StatelessCall(
op_attribute, parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object
cfg_op_attribute,
parallel_conf,
bn_in_op2blob_object,
boxing_util.BoxingTo,
vm_util._FindOrCreateDelegateBlobObject,
)
def FeedPath(ofblob):
......@@ -327,7 +348,7 @@ def _EagerRunModelLoad(var_op_conf, snapshot_path):
)
model_load_op_conf, _ = _GenModelLoadOpConfAndRetLbi(var_op_conf, path_lbi)
model_load_blob_objects = {}
model_load_blob_objects = oneflow_api.deprecated.BnInOp2BlobObject()
def BuildModelLoadInstruction(builder):
path_blob_object = path_input_blob_objects["out"]
......@@ -336,8 +357,15 @@ def _EagerRunModelLoad(var_op_conf, snapshot_path):
model_load_op_conf, ibn2blob_object=model_load_blob_objects
)
parallel_conf = path_blob_object.parallel_desc_symbol.parallel_conf
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(
str(op_attribute)
)
builder.StatelessCall(
op_attribute, parallel_conf, bn_in_op2blob_object=model_load_blob_objects
cfg_op_attribute,
parallel_conf,
model_load_blob_objects,
boxing_util.BoxingTo,
vm_util._FindOrCreateDelegateBlobObject,
)
sess = session_ctx.GetDefaultSession()
......@@ -360,7 +388,7 @@ def _EagerRunModelSave(var_blobs, snapshot_path):
)
model_save_op_conf = _GenModelSaveOpConf(var_blobs, path_lbi)
model_save_blob_objects = {}
model_save_blob_objects = oneflow_api.deprecated.BnInOp2BlobObject()
def BuildModelSaveInstruction(builder):
path_blob_object = path_input_blob_objects["out"]
......@@ -372,8 +400,15 @@ def _EagerRunModelSave(var_blobs, snapshot_path):
model_save_op_conf, ibn2blob_object=model_save_blob_objects
)
parallel_conf = path_blob_object.parallel_desc_symbol.parallel_conf
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(
str(op_attribute)
)
builder.StatelessCall(
op_attribute, parallel_conf, bn_in_op2blob_object=model_save_blob_objects
cfg_op_attribute,
parallel_conf,
model_save_blob_objects,
boxing_util.BoxingTo,
vm_util._FindOrCreateDelegateBlobObject,
)
sess = session_ctx.GetDefaultSession()
......
......@@ -84,158 +84,6 @@ def _DefaultBlobObject4Ibn(ibn):
raise NotImplementedError
def StatelessCall(
self,
op_attribute,
parallel_conf,
bn_in_op2blob_object=oneflow_api.deprecated.BnInOp2BlobObject(),
):
op_parallel_desc_sym = self.GetParallelDescSymbol(parallel_conf)
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(str(op_attribute))
self.CheckRefInBlobObjectParallelDesc(
cfg_op_attribute, op_parallel_desc_sym, bn_in_op2blob_object,
)
def FetchDelegateBlobObject(x_blob_object, op_arg_parallel_attr):
return boxing_util.BoxingTo(self, x_blob_object, op_arg_parallel_attr)
def GetDelegateBlobObject(blob_object, op_arg_parallel_attr):
return _FindOrCreateDelegateBlobObject(
self, FetchDelegateBlobObject, blob_object, op_arg_parallel_attr
)
self._StatelessCall(
"compute",
cfg_op_attribute,
op_parallel_desc_sym,
op_parallel_desc_sym,
bn_in_op2blob_object,
GetDelegateBlobObject,
)
def NoBoxingStatelessCall(
self,
op_attribute,
parallel_conf,
bn_in_op2blob_object=oneflow_api.deprecated.BnInOp2BlobObject(),
):
op_parallel_desc_sym = self.GetParallelDescSymbol(parallel_conf)
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(str(op_attribute))
self.CheckRefInBlobObjectParallelDesc(
cfg_op_attribute, op_parallel_desc_sym, bn_in_op2blob_object,
)
def FetchDelegateBlobObject(blob_object, op_arg_parallel_attr):
from_pd = blob_object.parallel_desc_symbol
to_pd = op_arg_parallel_attr.parallel_desc_symbol
if from_pd == to_pd:
return blob_object
assert from_pd.device_tag == "cpu"
assert to_pd.device_tag == "cpu"
assert from_pd.parallel_num == to_pd.parallel_num
from_machine_ids = dict(from_pd.machine_id2device_id_list).keys()
to_machine_ids = dict(to_pd.machine_id2device_id_list).keys()
if (
len(from_pd.machine_id2device_id_list) == from_pd.parallel_num
and from_machine_ids == to_machine_ids
):
return self.BroadcastBlobReference(blob_object, to_pd)
return self.Build121To(blob_object, to_pd)
def GetDirectOr121BlobObject(blob_object, op_arg_parallel_attr):
return _FindOrCreateDelegateBlobObject(
self, FetchDelegateBlobObject, blob_object, op_arg_parallel_attr
)
self._StatelessCall(
"compute",
cfg_op_attribute,
op_parallel_desc_sym,
op_parallel_desc_sym,
bn_in_op2blob_object,
GetDirectOr121BlobObject,
)
def NoBoxingCudaD2HStatelessCall(
self,
op_attribute,
in_parallel_conf,
bn_in_op2blob_object=oneflow_api.deprecated.BnInOp2BlobObject(),
):
op_parallel_desc_sym = self.GetParallelDescSymbol(in_parallel_conf)
blob_parallel_desc_sym = boxing_util.TryReplaceDeviceTag(
self, op_parallel_desc_sym, "cpu"
)
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(str(op_attribute))
self.CheckRefInBlobObjectParallelDesc(
cfg_op_attribute, blob_parallel_desc_sym, bn_in_op2blob_object,
)
def GetDirectBlobObject(blob_object, op_arg_parallel_attr):
return blob_object
self._StatelessCall(
"copy_d2h",
cfg_op_attribute,
op_parallel_desc_sym,
blob_parallel_desc_sym,
bn_in_op2blob_object,
GetDirectBlobObject,
)
def NoBoxingCudaH2DStatelessCall(
self,
op_attribute,
out_parallel_conf,
bn_in_op2blob_object=oneflow_api.deprecated.BnInOp2BlobObject(),
):
op_parallel_desc_sym = self.GetParallelDescSymbol(out_parallel_conf)
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(str(op_attribute))
self.CheckRefInBlobObjectParallelDesc(
cfg_op_attribute, op_parallel_desc_sym, bn_in_op2blob_object,
)
def GetDirectBlobObject(blob_object, op_arg_parallel_attr):
return blob_object
self._StatelessCall(
"copy_h2d",
cfg_op_attribute,
op_parallel_desc_sym,
op_parallel_desc_sym,
bn_in_op2blob_object,
GetDirectBlobObject,
)
def RawStatelessCall(
self,
op_attribute,
parallel_conf,
bn_in_op2blob_object=oneflow_api.deprecated.BnInOp2BlobObject(),
):
op_parallel_desc_sym = self.GetParallelDescSymbol(parallel_conf)
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(str(op_attribute))
self.CheckRefInBlobObjectParallelDesc(
cfg_op_attribute, op_parallel_desc_sym, bn_in_op2blob_object,
)
def GetDirectBlobObject(blob_object, op_arg_parallel_attr):
return blob_object
self._StatelessCall(
"compute",
cfg_op_attribute,
op_parallel_desc_sym,
op_parallel_desc_sym,
bn_in_op2blob_object,
GetDirectBlobObject,
)
def StatefulCall(
self,
op_attribute,
......@@ -324,25 +172,6 @@ def CudaHostPinBlob(self, blob_object):
self.CudaHostUnregisterBlob(blob_object)
def NewOpKernelObject(self, op_conf):
assert op_conf.HasField("scope_symbol_id")
scope_symbol = oneflow_api.GetScopeSymbol(op_conf.scope_symbol_id)
cfg_op_conf = oneflow_api.deprecated.MakeOpConfByString(str(op_conf))
op_conf_sym = self.GetOpConfSymbol(cfg_op_conf)
parallel_desc_sym_id = c_api_util.GetOpParallelSymbolId(op_conf)
parallel_desc_symbol = oneflow_api.GetPlacementSymbol(parallel_desc_sym_id)
object_id = self._NewOpKernelObject(
parallel_desc_symbol, scope_symbol.job_desc_symbol, op_conf_sym
)
return oneflow_api.OpKernelObject(object_id, cfg_op_conf, self.object_releaser())
def Build121To(self, blob_object, parallel_desc_symbol):
ref_blob_object = _MakeNewBlobObjectLike(self, blob_object, parallel_desc_symbol)
self.Build121AssignInstruction(ref_blob_object, blob_object)
return ref_blob_object
def _StatefulCall(
self, op_attribute, opkernel_object, bn_in_op2blob_object, get_delegate_blob_object,
):
......@@ -414,17 +243,6 @@ def FeedBlob(self, blob_object, feeder):
def RegisterMethod4InstructionsBuilder():
oneflow_api.deprecated.InstructionsBuilder.StatelessCall = StatelessCall
oneflow_api.deprecated.InstructionsBuilder.NoBoxingStatelessCall = (
NoBoxingStatelessCall
)
oneflow_api.deprecated.InstructionsBuilder.NoBoxingCudaD2HStatelessCall = (
NoBoxingCudaD2HStatelessCall
)
oneflow_api.deprecated.InstructionsBuilder.NoBoxingCudaH2DStatelessCall = (
NoBoxingCudaH2DStatelessCall
)
oneflow_api.deprecated.InstructionsBuilder.RawStatelessCall = RawStatelessCall
oneflow_api.deprecated.InstructionsBuilder.StatefulCall = StatefulCall
oneflow_api.deprecated.InstructionsBuilder.InsertRemoveForeignCallbackInstruction = (
InsertRemoveForeignCallbackInstruction
......@@ -435,7 +253,6 @@ def RegisterMethod4InstructionsBuilder():
MakeLazyRefBlobObject
)
oneflow_api.deprecated.InstructionsBuilder.CudaHostPinBlob = CudaHostPinBlob
oneflow_api.deprecated.InstructionsBuilder.Build121To = Build121To
oneflow_api.deprecated.InstructionsBuilder._StatefulCall = _StatefulCall
oneflow_api.deprecated.InstructionsBuilder._FetchBlob = _FetchBlob
oneflow_api.deprecated.InstructionsBuilder.FeedBlob = FeedBlob
......
......@@ -31,6 +31,7 @@ import oneflow.python.framework.remote_blob as remote_blob_util
import oneflow.python.lib.core.async_util as async_util
import oneflow.python.eager.blob_cache as blob_cache_util
import oneflow.python.eager.vm_util as vm_util
import oneflow.python.eager.boxing_util as boxing_util
import oneflow.python.eager.op_infer_util as op_infer_util
import oneflow.core.framework.variable_meta_info_pb2 as variable_meta_info_pb
import oneflow.core.framework.user_op_attr_pb2 as attr_value_pb
......@@ -301,10 +302,15 @@ def _LogicalSlice(
op_attribute = op_infer_util.Infer(
op_conf, bn_in_op2blob_object, scope_symbol_id
)
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(
str(op_attribute)
)
builder.StatelessCall(
op_attribute,
parallel_conf=parallel_conf,
bn_in_op2blob_object=bn_in_op2blob_object,
cfg_op_attribute,
parallel_conf,
bn_in_op2blob_object,
boxing_util.BoxingTo,
vm_util._FindOrCreateDelegateBlobObject,
)
Yield(bn_in_op2blob_object["y_0"])
......@@ -394,10 +400,15 @@ def _LogicalSliceAssign(
op_attribute = op_infer_util.Infer(
op_conf, bn_in_op2blob_object, scope_symbol_id
)
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(
str(op_attribute)
)
builder.StatelessCall(
op_attribute,
parallel_conf=parallel_conf,
bn_in_op2blob_object=bn_in_op2blob_object,
cfg_op_attribute,
parallel_conf,
bn_in_op2blob_object,
boxing_util.BoxingTo,
vm_util._FindOrCreateDelegateBlobObject,
)
vm_util.LogicalRun(BuildAssignInstruction)
......
......@@ -25,6 +25,7 @@ import oneflow.python.framework.remote_blob as remote_blob_util
import oneflow.python.framework.id_util as id_util
import oneflow.python.eager.blob_cache as blob_cache_util
import oneflow.python.eager.vm_util as vm_util
import oneflow.python.eager.boxing_util as boxing_util
import oneflow.python.eager.blob_register as blob_register_util
import oneflow.core.operator.op_conf_pb2 as op_conf_util
import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util
......@@ -150,8 +151,15 @@ def _MakeInputBlobObject(arg_blob_def):
op_attribute = arg_blob_def.EagerAddAndInferOp(input_op_conf)
scope = oneflow.current_scope()
parallel_conf = scope.device_parallel_desc_symbol.parallel_conf
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(
str(op_attribute)
)
builder.StatelessCall(
op_attribute, parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object
cfg_op_attribute,
parallel_conf,
bn_in_op2blob_object,
boxing_util.BoxingTo,
vm_util._FindOrCreateDelegateBlobObject,
)
vm_util.LogicalRun(BuildInputInstruction)
......
......@@ -29,6 +29,7 @@ import oneflow.python.lib.core.enable_if as enable_if
import oneflow.python.framework.session_context as session_ctx
import oneflow.python.framework.scope_util as scope_util
import oneflow.python.eager.vm_util as vm_util
import oneflow.python.eager.boxing_util as boxing_util
import oneflow.python.eager.blob_register as blob_register_util
import oneflow_api.oneflow.core.job.placement as placement_cfg
import oneflow_api
......@@ -83,10 +84,15 @@ def EagerReturnRemoteBlob(remote_blob, allow_cpu_return_op=True):
def BuildInstruction(builder):
get_blob_scope = blob_register_util.BnInOp2BlobObjectScope
with get_blob_scope(blob_register, op_attribute) as bn_in_op2blob_object:
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(
str(op_attribute)
)
builder.StatelessCall(
op_attribute,
cfg_op_attribute,
remote_blob.blob_object.parallel_desc_symbol.parallel_conf,
bn_in_op2blob_object=bn_in_op2blob_object,
bn_in_op2blob_object,
boxing_util.BoxingTo,
vm_util._FindOrCreateDelegateBlobObject,
)
vm_util.LogicalRun(BuildInstruction)
......
......@@ -31,6 +31,7 @@ import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util
import oneflow.python.framework.hob as hob
import oneflow.python.framework.dtype as dtype_util
import oneflow.python.eager.vm_util as vm_util
import oneflow.python.eager.boxing_util as boxing_util
import oneflow.python.eager.gradient_util as gradient_util
import oneflow.python.eager.op_executor as op_executor
import oneflow.python.lib.core.enable_if as enable_if
......@@ -353,8 +354,15 @@ def CreateEagerVariableBlob(op_attribute, job_name=""):
parallel_conf = (
oneflow.current_scope().device_parallel_desc_symbol.parallel_conf
)
cfg_op_attribute = oneflow_api.deprecated.MakeOpAttributeByString(
str(op_attribute)
)
builder.StatelessCall(
op_attribute, parallel_conf, bn_in_op2blob_object=bn_in_op2blob_object
cfg_op_attribute,
parallel_conf,
bn_in_op2blob_object,
boxing_util.BoxingTo,
vm_util._FindOrCreateDelegateBlobObject,
)
vm_util.LogicalRun(BuildInstruction)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册