Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
d7afdbcc
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
d7afdbcc
编写于
7月 21, 2020
作者:
L
Liang Depeng
提交者:
GitHub
7月 21, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add remove foreign callback instruction (#3231)
Co-authored-by:
N
Li Xinqi
<
lixinqi2010@gmail.com
>
上级
268c3c04
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
72 addition
and
7 deletion
+72
-7
oneflow/core/eager/opkernel_instruction.msg.h
oneflow/core/eager/opkernel_instruction.msg.h
+5
-0
oneflow/core/eager/remove_foreign_callback_instruction_type.cpp
...w/core/eager/remove_foreign_callback_instruction_type.cpp
+32
-0
oneflow/core/job/foreign_callback.h
oneflow/core/job/foreign_callback.h
+2
-0
oneflow/python/eager/eager_blob_util.py
oneflow/python/eager/eager_blob_util.py
+16
-6
oneflow/python/eager/vm_util.py
oneflow/python/eager/vm_util.py
+8
-0
oneflow/python/framework/push_util.py
oneflow/python/framework/push_util.py
+1
-1
oneflow/python/framework/python_callback.py
oneflow/python/framework/python_callback.py
+8
-0
未找到文件。
oneflow/core/eager/opkernel_instruction.msg.h
浏览文件 @
d7afdbcc
...
...
@@ -64,6 +64,11 @@ FLAT_MSG_VIEW_BEGIN(FeedBlobInstrOperand);
FLAT_MSG_VIEW_DEFINE_PATTERN
(
vm
::
Mut2Operand
,
blob
);
FLAT_MSG_VIEW_DEFINE_PATTERN
(
int64_t
,
unique_callback_id
);
FLAT_MSG_VIEW_END
(
FeedBlobInstrOperand
);
FLAT_MSG_VIEW_BEGIN
(
RemoveForeignCallbackInstrOperand
);
FLAT_MSG_VIEW_DEFINE_PATTERN
(
vm
::
MutOperand
,
object_id
);
FLAT_MSG_VIEW_DEFINE_PATTERN
(
int64_t
,
unique_callback_id
);
FLAT_MSG_VIEW_END
(
RemoveForeignCallbackInstrOperand
);
// clang-format on
}
// namespace eager
...
...
oneflow/core/eager/remove_foreign_callback_instruction_type.cpp
0 → 100644
浏览文件 @
d7afdbcc
#include "oneflow/core/vm/instruction.msg.h"
#include "oneflow/core/vm/instruction_type.h"
#include "oneflow/core/vm/host_stream_type.h"
#include "oneflow/core/eager/opkernel_instruction.msg.h"
#include "oneflow/core/job/foreign_callback.h"
namespace
oneflow
{
namespace
eager
{
class
RemoveForeignCallbackInstructionType
:
public
vm
::
InstructionType
{
public:
RemoveForeignCallbackInstructionType
()
=
default
;
~
RemoveForeignCallbackInstructionType
()
override
=
default
;
using
stream_type
=
vm
::
HostStreamType
;
void
Infer
(
vm
::
Instruction
*
instruction
)
const
override
{
// do nothing
}
void
Compute
(
vm
::
Instruction
*
instruction
)
const
override
{
FlatMsgView
<
RemoveForeignCallbackInstrOperand
>
args
(
instruction
->
instr_msg
().
operand
());
Global
<
ForeignCallback
>::
Get
()
->
RemoveForeignCallback
(
args
->
unique_callback_id
());
}
};
COMMAND
(
vm
::
RegisterInstructionType
<
RemoveForeignCallbackInstructionType
>
(
"RemoveForeignCallback"
));
}
// namespace eager
}
// namespace oneflow
oneflow/core/job/foreign_callback.h
浏览文件 @
d7afdbcc
...
...
@@ -18,6 +18,8 @@ class ForeignCallback {
}
virtual
void
OfBlobCall
(
int64_t
unique_id
,
int64_t
ofblob_ptr
)
const
{
UNIMPLEMENTED
();
}
virtual
void
RemoveForeignCallback
(
int64_t
unique_id
)
const
{
UNIMPLEMENTED
();
}
};
}
// namespace oneflow
...
...
oneflow/python/eager/eager_blob_util.py
浏览文件 @
d7afdbcc
...
...
@@ -65,8 +65,14 @@ class EagerPhysicalBlob(blob_trait.BlobOperatorTrait, blob_trait.BlobHeaderTrait
def
FetchTensorBlobAsNumpyList
(
parallel_size
,
blob_object
):
def
AsyncFetchBlobBody
(
Yield
):
fetcher
=
_MakeFetcherEagerBlobBodyAsNumpyFromOfBlob
(
Yield
)
vm_util
.
PhysicalRun
(
lambda
builder
:
builder
.
FetchBlobBody
(
blob_object
,
fetcher
))
python_callback
.
DeleteRegisteredCallback
(
fetcher
)
def
BuildFetchBlobBodyInstruction
(
builder
):
builder
.
FetchBlobBody
(
blob_object
,
fetcher
)
builder
.
InsertRemoveForeignCallbackInstruction
(
blob_object
.
object_id
,
fetcher
)
vm_util
.
PhysicalRun
(
BuildFetchBlobBodyInstruction
)
return
async_util
.
Await
(
parallel_size
,
AsyncFetchBlobBody
)
...
...
@@ -84,10 +90,14 @@ def _GetPhysicalBlobBodyCache(blob_object):
def
_FetchBlobHeader
(
blob_object
):
def
AsyncFetchBlobHeader
(
Yield
):
fetcher
=
_MakeFetcherEagerPhysicalBlobHeaderFromOfBlob
(
Yield
)
vm_util
.
PhysicalRun
(
lambda
builder
:
builder
.
FetchBlobHeader
(
blob_object
,
fetcher
)
)
python_callback
.
DeleteRegisteredCallback
(
fetcher
)
def
BuildFetchBlobHeaderInstruction
(
builder
):
builder
.
FetchBlobHeader
(
blob_object
,
fetcher
)
builder
.
InsertRemoveForeignCallbackInstruction
(
blob_object
.
object_id
,
fetcher
)
vm_util
.
PhysicalRun
(
BuildFetchBlobHeaderInstruction
)
return
async_util
.
Await
(
1
,
AsyncFetchBlobHeader
)[
0
]
...
...
oneflow/python/eager/vm_util.py
浏览文件 @
d7afdbcc
...
...
@@ -220,6 +220,14 @@ class InstructionsBuilder(object):
self
.
_TryClearObject
(
blob_object
)
self
.
_DeleteObject
(
blob_object
)
def
InsertRemoveForeignCallbackInstruction
(
self
,
object_id
,
callback
):
unique_callback_id
=
python_callback
.
GetIdForRegisteredCallback
(
callback
)
instruction
=
instr_util
.
InstructionProto
()
instruction
.
instr_type_name
=
"RemoveForeignCallback"
instruction
.
operand
.
append
(
_DelObjectOperand
(
object_id
))
instruction
.
operand
.
append
(
_Int64Operand
(
unique_callback_id
))
self
.
instruction_list_
.
instruction
.
append
(
instruction
)
def
FetchBlobHeader
(
self
,
blob_object
,
callback
):
return
self
.
_FetchBlob
(
"FetchBlobHeader"
,
blob_object
,
callback
)
...
...
oneflow/python/framework/push_util.py
浏览文件 @
d7afdbcc
...
...
@@ -248,9 +248,9 @@ def _FeedValueToInputPhysicalBlob(feed_ctx, blob_def, blob_object):
def
BuildFeedInstruction
(
builder
):
builder
.
FeedBlob
(
blob_object
,
FeedBlob
)
builder
.
InsertRemoveForeignCallbackInstruction
(
blob_object
.
object_id
,
FeedBlob
)
vm_util
.
PhysicalRun
(
BuildFeedInstruction
)
python_callback
.
DeleteRegisteredCallback
(
FeedBlob
)
def
_MakeFeedBlobCallback
(
feed_ctx
,
blob_def
,
blob_object
):
...
...
oneflow/python/framework/python_callback.py
浏览文件 @
d7afdbcc
...
...
@@ -30,6 +30,14 @@ class PythonCallback(oneflow_internal.ForeignCallback):
print
(
traceback
.
format_exc
())
raise
e
def
RemoveForeignCallback
(
self
,
unique_id
):
global
unique_id2handler
try
:
del
unique_id2handler
[
unique_id
]
except
Exception
as
e
:
print
(
traceback
.
format_exc
())
raise
e
def
EagerInterpretCompletedOp
(
self
,
op_attribute_str
,
parallel_conf_str
):
try
:
interpreter_callback
.
InterpretCompletedOp
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录