From d7afdbcc280c4d801df6e6e6e816cceb9909250a Mon Sep 17 00:00:00 2001 From: Liang Depeng Date: Tue, 21 Jul 2020 10:19:25 +0800 Subject: [PATCH] add remove foreign callback instruction (#3231) Co-authored-by: Li Xinqi --- oneflow/core/eager/opkernel_instruction.msg.h | 5 +++ ...move_foreign_callback_instruction_type.cpp | 32 +++++++++++++++++++ oneflow/core/job/foreign_callback.h | 2 ++ oneflow/python/eager/eager_blob_util.py | 22 +++++++++---- oneflow/python/eager/vm_util.py | 8 +++++ oneflow/python/framework/push_util.py | 2 +- oneflow/python/framework/python_callback.py | 8 +++++ 7 files changed, 72 insertions(+), 7 deletions(-) create mode 100644 oneflow/core/eager/remove_foreign_callback_instruction_type.cpp diff --git a/oneflow/core/eager/opkernel_instruction.msg.h b/oneflow/core/eager/opkernel_instruction.msg.h index 3ac16e654b..d7658ca290 100644 --- a/oneflow/core/eager/opkernel_instruction.msg.h +++ b/oneflow/core/eager/opkernel_instruction.msg.h @@ -64,6 +64,11 @@ FLAT_MSG_VIEW_BEGIN(FeedBlobInstrOperand); FLAT_MSG_VIEW_DEFINE_PATTERN(vm::Mut2Operand, blob); FLAT_MSG_VIEW_DEFINE_PATTERN(int64_t, unique_callback_id); FLAT_MSG_VIEW_END(FeedBlobInstrOperand); + +FLAT_MSG_VIEW_BEGIN(RemoveForeignCallbackInstrOperand); + FLAT_MSG_VIEW_DEFINE_PATTERN(vm::MutOperand, object_id); + FLAT_MSG_VIEW_DEFINE_PATTERN(int64_t, unique_callback_id); +FLAT_MSG_VIEW_END(RemoveForeignCallbackInstrOperand); // clang-format on } // namespace eager diff --git a/oneflow/core/eager/remove_foreign_callback_instruction_type.cpp b/oneflow/core/eager/remove_foreign_callback_instruction_type.cpp new file mode 100644 index 0000000000..47ee97db23 --- /dev/null +++ b/oneflow/core/eager/remove_foreign_callback_instruction_type.cpp @@ -0,0 +1,32 @@ +#include "oneflow/core/vm/instruction.msg.h" +#include "oneflow/core/vm/instruction_type.h" +#include "oneflow/core/vm/host_stream_type.h" +#include "oneflow/core/eager/opkernel_instruction.msg.h" +#include "oneflow/core/job/foreign_callback.h" + +namespace oneflow { + +namespace eager { + +class RemoveForeignCallbackInstructionType : public vm::InstructionType { + public: + RemoveForeignCallbackInstructionType() = default; + ~RemoveForeignCallbackInstructionType() override = default; + + using stream_type = vm::HostStreamType; + + void Infer(vm::Instruction* instruction) const override { + // do nothing + } + + void Compute(vm::Instruction* instruction) const override { + FlatMsgView args(instruction->instr_msg().operand()); + Global::Get()->RemoveForeignCallback(args->unique_callback_id()); + } +}; + +COMMAND(vm::RegisterInstructionType("RemoveForeignCallback")); + +} // namespace eager + +} // namespace oneflow diff --git a/oneflow/core/job/foreign_callback.h b/oneflow/core/job/foreign_callback.h index 4093ec0ea3..cce8419634 100644 --- a/oneflow/core/job/foreign_callback.h +++ b/oneflow/core/job/foreign_callback.h @@ -18,6 +18,8 @@ class ForeignCallback { } virtual void OfBlobCall(int64_t unique_id, int64_t ofblob_ptr) const { UNIMPLEMENTED(); } + + virtual void RemoveForeignCallback(int64_t unique_id) const { UNIMPLEMENTED(); } }; } // namespace oneflow diff --git a/oneflow/python/eager/eager_blob_util.py b/oneflow/python/eager/eager_blob_util.py index 4de2f1f566..7aed83c916 100644 --- a/oneflow/python/eager/eager_blob_util.py +++ b/oneflow/python/eager/eager_blob_util.py @@ -65,8 +65,14 @@ class EagerPhysicalBlob(blob_trait.BlobOperatorTrait, blob_trait.BlobHeaderTrait def FetchTensorBlobAsNumpyList(parallel_size, blob_object): def AsyncFetchBlobBody(Yield): fetcher = _MakeFetcherEagerBlobBodyAsNumpyFromOfBlob(Yield) - vm_util.PhysicalRun(lambda builder: builder.FetchBlobBody(blob_object, fetcher)) - python_callback.DeleteRegisteredCallback(fetcher) + + def BuildFetchBlobBodyInstruction(builder): + builder.FetchBlobBody(blob_object, fetcher) + builder.InsertRemoveForeignCallbackInstruction( + blob_object.object_id, fetcher + ) + + vm_util.PhysicalRun(BuildFetchBlobBodyInstruction) return async_util.Await(parallel_size, AsyncFetchBlobBody) @@ -84,10 +90,14 @@ def _GetPhysicalBlobBodyCache(blob_object): def _FetchBlobHeader(blob_object): def AsyncFetchBlobHeader(Yield): fetcher = _MakeFetcherEagerPhysicalBlobHeaderFromOfBlob(Yield) - vm_util.PhysicalRun( - lambda builder: builder.FetchBlobHeader(blob_object, fetcher) - ) - python_callback.DeleteRegisteredCallback(fetcher) + + def BuildFetchBlobHeaderInstruction(builder): + builder.FetchBlobHeader(blob_object, fetcher) + builder.InsertRemoveForeignCallbackInstruction( + blob_object.object_id, fetcher + ) + + vm_util.PhysicalRun(BuildFetchBlobHeaderInstruction) return async_util.Await(1, AsyncFetchBlobHeader)[0] diff --git a/oneflow/python/eager/vm_util.py b/oneflow/python/eager/vm_util.py index 651a05aa75..83d8f30682 100644 --- a/oneflow/python/eager/vm_util.py +++ b/oneflow/python/eager/vm_util.py @@ -220,6 +220,14 @@ class InstructionsBuilder(object): self._TryClearObject(blob_object) self._DeleteObject(blob_object) + def InsertRemoveForeignCallbackInstruction(self, object_id, callback): + unique_callback_id = python_callback.GetIdForRegisteredCallback(callback) + instruction = instr_util.InstructionProto() + instruction.instr_type_name = "RemoveForeignCallback" + instruction.operand.append(_DelObjectOperand(object_id)) + instruction.operand.append(_Int64Operand(unique_callback_id)) + self.instruction_list_.instruction.append(instruction) + def FetchBlobHeader(self, blob_object, callback): return self._FetchBlob("FetchBlobHeader", blob_object, callback) diff --git a/oneflow/python/framework/push_util.py b/oneflow/python/framework/push_util.py index cf7ae89937..ed8d6890a0 100644 --- a/oneflow/python/framework/push_util.py +++ b/oneflow/python/framework/push_util.py @@ -248,9 +248,9 @@ def _FeedValueToInputPhysicalBlob(feed_ctx, blob_def, blob_object): def BuildFeedInstruction(builder): builder.FeedBlob(blob_object, FeedBlob) + builder.InsertRemoveForeignCallbackInstruction(blob_object.object_id, FeedBlob) vm_util.PhysicalRun(BuildFeedInstruction) - python_callback.DeleteRegisteredCallback(FeedBlob) def _MakeFeedBlobCallback(feed_ctx, blob_def, blob_object): diff --git a/oneflow/python/framework/python_callback.py b/oneflow/python/framework/python_callback.py index d61d9d745e..f60f026194 100644 --- a/oneflow/python/framework/python_callback.py +++ b/oneflow/python/framework/python_callback.py @@ -30,6 +30,14 @@ class PythonCallback(oneflow_internal.ForeignCallback): print(traceback.format_exc()) raise e + def RemoveForeignCallback(self, unique_id): + global unique_id2handler + try: + del unique_id2handler[unique_id] + except Exception as e: + print(traceback.format_exc()) + raise e + def EagerInterpretCompletedOp(self, op_attribute_str, parallel_conf_str): try: interpreter_callback.InterpretCompletedOp( -- GitLab