Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
bb239e03
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
bb239e03
编写于
11月 02, 2021
作者:
S
strint
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make block more private to reduce conflicts with module
上级
55d32c33
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
18 addition
and
16 deletion
+18
-16
python/oneflow/nn/graph/block.py
python/oneflow/nn/graph/block.py
+18
-16
未找到文件。
python/oneflow/nn/graph/block.py
浏览文件 @
bb239e03
...
...
@@ -203,7 +203,7 @@ class ModuleBlock(Block):
# that hooks of nn.Modules are ignored. It is not recommended
# to use hooks of nn.Module in nn.Graph for the moment.
# result = self._origin.__class__.__call__(self, *args)
result
=
self
.
_forward
(
*
args
)
result
=
self
.
_
_block_
forward
(
*
args
)
outputs
=
()
if
not
(
type
(
result
)
is
tuple
or
type
(
result
)
is
list
):
...
...
@@ -231,17 +231,17 @@ class ModuleBlock(Block):
return
result
def
_forward
(
self
,
*
args
):
def
_
_block_
forward
(
self
,
*
args
):
self
.
_is_executing_forward
=
True
args
=
self
.
_pre_forward_mapping_out_scope
(
*
args
)
args
=
self
.
_
_
pre_forward_mapping_out_scope
(
*
args
)
with
self
.
scope_context
():
result
=
self
.
_origin
.
__class__
.
forward
(
self
,
*
args
)
result
=
self
.
_post_forward_mapping_out_scope
(
result
)
result
=
self
.
_
_
post_forward_mapping_out_scope
(
result
)
result
=
seq_to_func_return
(
result
)
self
.
_is_executing_forward
=
False
return
result
def
_pre_forward_mapping_out_scope
(
self
,
*
args
):
def
_
_
pre_forward_mapping_out_scope
(
self
,
*
args
):
# Insert identity op when doing activation checkpointing or pipeline execution.
# Identity op outside activation checkpointing scope will be the endpoint of an activation checkpointing segment.
# Identity op as the first op of a pipeline stage will make backward op depends on the identity op within the stage,
...
...
@@ -254,11 +254,13 @@ class ModuleBlock(Block):
assert
isinstance
(
t
,
Tensor
)
return
oneflow
.
_C
.
identity
(
t
)
args
=
self
.
_mapping_io
(
"input"
,
insert_identity
,
"insert_identity"
,
*
args
,)
args
=
self
.
__mapping_io
(
"input"
,
insert_identity
,
"insert_identity"
,
*
args
,
)
return
args
def
_post_forward_mapping_out_scope
(
self
,
*
args
):
def
_
_
post_forward_mapping_out_scope
(
self
,
*
args
):
# Insert identity op when doing activation checkpointing or pipeline execution.
if
self
.
config
.
activation_checkpointing
or
(
self
.
config
.
stage_id
is
not
None
and
self
.
config
.
stage_id
>=
0
...
...
@@ -268,7 +270,7 @@ class ModuleBlock(Block):
assert
isinstance
(
t
,
Tensor
)
return
oneflow
.
_C
.
identity
(
t
)
args
=
self
.
_mapping_io
(
args
=
self
.
_
_
mapping_io
(
"output"
,
insert_identity
,
"insert_identity"
,
*
args
,
)
return
args
...
...
@@ -298,7 +300,7 @@ class ModuleBlock(Block):
for
m
in
module
.
modules
(
memo
):
yield
m
def
_mapping_io
(
self
,
io_type
,
func
,
func_desc
,
*
args
):
def
_
_
mapping_io
(
self
,
io_type
,
func
,
func_desc
,
*
args
):
assert
isinstance
(
func_desc
,
str
)
assert
io_type
in
(
"input"
,
"output"
)
mapped_args
=
[]
...
...
@@ -311,7 +313,7 @@ class ModuleBlock(Block):
if
isinstance
(
arg
,
list
):
seq_args
=
list
()
for
i
in
range
(
len
(
arg
)):
is_tensor
,
name
,
repr_str
=
self
.
_io_tensor_check_and_gen
(
is_tensor
,
name
,
repr_str
=
self
.
_
_
io_tensor_check_and_gen
(
arg
[
i
],
io_type
,
idx
,
i
)
if
is_tensor
:
...
...
@@ -330,7 +332,7 @@ class ModuleBlock(Block):
seq_args
.
append
(
arg
[
i
])
mapped_args
.
append
(
seq_args
)
elif
isinstance
(
arg
,
Tensor
):
is_tensor
,
name
,
repr_str
=
self
.
_io_tensor_check_and_gen
(
is_tensor
,
name
,
repr_str
=
self
.
_
_
io_tensor_check_and_gen
(
arg
,
io_type
,
idx
)
assert
is_tensor
...
...
@@ -341,7 +343,7 @@ class ModuleBlock(Block):
f
"
{
repr_str
}
is a Tensor,
{
func_desc
}
transformation has been done."
,
)
else
:
is_tensor
,
name
,
repr_str
=
self
.
_io_tensor_check_and_gen
(
is_tensor
,
name
,
repr_str
=
self
.
_
_
io_tensor_check_and_gen
(
arg
,
io_type
,
idx
)
assert
not
is_tensor
...
...
@@ -354,7 +356,7 @@ class ModuleBlock(Block):
return
tuple
(
mapped_args
)
def
_io_tensor_check_and_gen
(
self
,
item
,
io_type
,
idx
,
second_idx
=
None
):
def
_
_
io_tensor_check_and_gen
(
self
,
item
,
io_type
,
idx
,
second_idx
=
None
):
assert
io_type
in
(
"input"
,
"output"
)
name
=
(
"_"
...
...
@@ -383,7 +385,7 @@ class ModuleBlock(Block):
)
return
False
,
name
,
repr_str
def
_members
(
self
,
get_members_fn
,
recurse
=
True
)
->
Iterator
[
"Block"
]:
def
_
_
members
(
self
,
get_members_fn
,
recurse
=
True
)
->
Iterator
[
"Block"
]:
assert
self
.
_type
==
BlockType
.
MODULE
memo
=
set
()
modules
=
self
.
modules
()
if
recurse
else
[
self
]
...
...
@@ -397,13 +399,13 @@ class ModuleBlock(Block):
def
parameters
(
self
,
recurse
:
bool
=
True
)
->
Iterator
[
"Block"
]:
assert
self
.
_type
==
BlockType
.
MODULE
gen
=
self
.
_members
(
lambda
module
:
module
.
_parameters
.
items
(),
recurse
=
recurse
)
gen
=
self
.
_
_
members
(
lambda
module
:
module
.
_parameters
.
items
(),
recurse
=
recurse
)
for
elem
in
gen
:
yield
elem
def
buffers
(
self
,
recurse
:
bool
=
True
)
->
Iterator
[
"Block"
]:
assert
self
.
_type
==
BlockType
.
MODULE
gen
=
self
.
_members
(
lambda
module
:
module
.
_buffers
.
items
(),
recurse
=
recurse
)
gen
=
self
.
_
_
members
(
lambda
module
:
module
.
_buffers
.
items
(),
recurse
=
recurse
)
for
elem
in
gen
:
yield
elem
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录