未验证 提交 d7afdbcc 编写于 作者: L Liang Depeng 提交者: GitHub

add remove foreign callback instruction (#3231)

Co-authored-by: NLi Xinqi <lixinqi2010@gmail.com>
上级 268c3c04
......@@ -64,6 +64,11 @@ FLAT_MSG_VIEW_BEGIN(FeedBlobInstrOperand);
FLAT_MSG_VIEW_DEFINE_PATTERN(vm::Mut2Operand, blob);
FLAT_MSG_VIEW_DEFINE_PATTERN(int64_t, unique_callback_id);
FLAT_MSG_VIEW_END(FeedBlobInstrOperand);
FLAT_MSG_VIEW_BEGIN(RemoveForeignCallbackInstrOperand);
FLAT_MSG_VIEW_DEFINE_PATTERN(vm::MutOperand, object_id);
FLAT_MSG_VIEW_DEFINE_PATTERN(int64_t, unique_callback_id);
FLAT_MSG_VIEW_END(RemoveForeignCallbackInstrOperand);
// clang-format on
} // namespace eager
......
#include "oneflow/core/vm/instruction.msg.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/vm/host_stream_type.h"
#include "oneflow/core/eager/opkernel_instruction.msg.h"
#include "oneflow/core/job/foreign_callback.h"
namespace oneflow {
namespace eager {
class RemoveForeignCallbackInstructionType : public vm::InstructionType {
public:
RemoveForeignCallbackInstructionType() = default;
~RemoveForeignCallbackInstructionType() override = default;
using stream_type = vm::HostStreamType;
void Infer(vm::Instruction* instruction) const override {
// do nothing
}
void Compute(vm::Instruction* instruction) const override {
FlatMsgView<RemoveForeignCallbackInstrOperand> args(instruction->instr_msg().operand());
Global<ForeignCallback>::Get()->RemoveForeignCallback(args->unique_callback_id());
}
};
COMMAND(vm::RegisterInstructionType<RemoveForeignCallbackInstructionType>("RemoveForeignCallback"));
} // namespace eager
} // namespace oneflow
......@@ -18,6 +18,8 @@ class ForeignCallback {
}
virtual void OfBlobCall(int64_t unique_id, int64_t ofblob_ptr) const { UNIMPLEMENTED(); }
virtual void RemoveForeignCallback(int64_t unique_id) const { UNIMPLEMENTED(); }
};
} // namespace oneflow
......
......@@ -65,8 +65,14 @@ class EagerPhysicalBlob(blob_trait.BlobOperatorTrait, blob_trait.BlobHeaderTrait
def FetchTensorBlobAsNumpyList(parallel_size, blob_object):
def AsyncFetchBlobBody(Yield):
fetcher = _MakeFetcherEagerBlobBodyAsNumpyFromOfBlob(Yield)
vm_util.PhysicalRun(lambda builder: builder.FetchBlobBody(blob_object, fetcher))
python_callback.DeleteRegisteredCallback(fetcher)
def BuildFetchBlobBodyInstruction(builder):
builder.FetchBlobBody(blob_object, fetcher)
builder.InsertRemoveForeignCallbackInstruction(
blob_object.object_id, fetcher
)
vm_util.PhysicalRun(BuildFetchBlobBodyInstruction)
return async_util.Await(parallel_size, AsyncFetchBlobBody)
......@@ -84,10 +90,14 @@ def _GetPhysicalBlobBodyCache(blob_object):
def _FetchBlobHeader(blob_object):
def AsyncFetchBlobHeader(Yield):
fetcher = _MakeFetcherEagerPhysicalBlobHeaderFromOfBlob(Yield)
vm_util.PhysicalRun(
lambda builder: builder.FetchBlobHeader(blob_object, fetcher)
)
python_callback.DeleteRegisteredCallback(fetcher)
def BuildFetchBlobHeaderInstruction(builder):
builder.FetchBlobHeader(blob_object, fetcher)
builder.InsertRemoveForeignCallbackInstruction(
blob_object.object_id, fetcher
)
vm_util.PhysicalRun(BuildFetchBlobHeaderInstruction)
return async_util.Await(1, AsyncFetchBlobHeader)[0]
......
......@@ -220,6 +220,14 @@ class InstructionsBuilder(object):
self._TryClearObject(blob_object)
self._DeleteObject(blob_object)
def InsertRemoveForeignCallbackInstruction(self, object_id, callback):
unique_callback_id = python_callback.GetIdForRegisteredCallback(callback)
instruction = instr_util.InstructionProto()
instruction.instr_type_name = "RemoveForeignCallback"
instruction.operand.append(_DelObjectOperand(object_id))
instruction.operand.append(_Int64Operand(unique_callback_id))
self.instruction_list_.instruction.append(instruction)
def FetchBlobHeader(self, blob_object, callback):
return self._FetchBlob("FetchBlobHeader", blob_object, callback)
......
......@@ -248,9 +248,9 @@ def _FeedValueToInputPhysicalBlob(feed_ctx, blob_def, blob_object):
def BuildFeedInstruction(builder):
builder.FeedBlob(blob_object, FeedBlob)
builder.InsertRemoveForeignCallbackInstruction(blob_object.object_id, FeedBlob)
vm_util.PhysicalRun(BuildFeedInstruction)
python_callback.DeleteRegisteredCallback(FeedBlob)
def _MakeFeedBlobCallback(feed_ctx, blob_def, blob_object):
......
......@@ -30,6 +30,14 @@ class PythonCallback(oneflow_internal.ForeignCallback):
print(traceback.format_exc())
raise e
def RemoveForeignCallback(self, unique_id):
global unique_id2handler
try:
del unique_id2handler[unique_id]
except Exception as e:
print(traceback.format_exc())
raise e
def EagerInterpretCompletedOp(self, op_attribute_str, parallel_conf_str):
try:
interpreter_callback.InterpretCompletedOp(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册