提交 bb239e03 编写于 作者: S strint

make block more private to reduce conflicts with module

上级 55d32c33
......@@ -203,7 +203,7 @@ class ModuleBlock(Block):
# that hooks of nn.Modules are ignored. It is not recommended
# to use hooks of nn.Module in nn.Graph for the moment.
# result = self._origin.__class__.__call__(self, *args)
result = self._forward(*args)
result = self.__block_forward(*args)
outputs = ()
if not (type(result) is tuple or type(result) is list):
......@@ -231,17 +231,17 @@ class ModuleBlock(Block):
return result
def _forward(self, *args):
def __block_forward(self, *args):
self._is_executing_forward = True
args = self._pre_forward_mapping_out_scope(*args)
args = self.__pre_forward_mapping_out_scope(*args)
with self.scope_context():
result = self._origin.__class__.forward(self, *args)
result = self._post_forward_mapping_out_scope(result)
result = self.__post_forward_mapping_out_scope(result)
result = seq_to_func_return(result)
self._is_executing_forward = False
return result
def _pre_forward_mapping_out_scope(self, *args):
def __pre_forward_mapping_out_scope(self, *args):
# Insert identity op when doing activation checkpointing or pipeline execution.
# Identity op outside activation checkpointing scope will be the endpoint of an activation checkpointing segment.
# Identity op as the first op of a pipeline stage will make backward op depends on the identity op within the stage,
......@@ -254,11 +254,13 @@ class ModuleBlock(Block):
assert isinstance(t, Tensor)
return oneflow._C.identity(t)
args = self._mapping_io("input", insert_identity, "insert_identity", *args,)
args = self.__mapping_io(
"input", insert_identity, "insert_identity", *args,
)
return args
def _post_forward_mapping_out_scope(self, *args):
def __post_forward_mapping_out_scope(self, *args):
# Insert identity op when doing activation checkpointing or pipeline execution.
if self.config.activation_checkpointing or (
self.config.stage_id is not None and self.config.stage_id >= 0
......@@ -268,7 +270,7 @@ class ModuleBlock(Block):
assert isinstance(t, Tensor)
return oneflow._C.identity(t)
args = self._mapping_io(
args = self.__mapping_io(
"output", insert_identity, "insert_identity", *args,
)
return args
......@@ -298,7 +300,7 @@ class ModuleBlock(Block):
for m in module.modules(memo):
yield m
def _mapping_io(self, io_type, func, func_desc, *args):
def __mapping_io(self, io_type, func, func_desc, *args):
assert isinstance(func_desc, str)
assert io_type in ("input", "output")
mapped_args = []
......@@ -311,7 +313,7 @@ class ModuleBlock(Block):
if isinstance(arg, list):
seq_args = list()
for i in range(len(arg)):
is_tensor, name, repr_str = self._io_tensor_check_and_gen(
is_tensor, name, repr_str = self.__io_tensor_check_and_gen(
arg[i], io_type, idx, i
)
if is_tensor:
......@@ -330,7 +332,7 @@ class ModuleBlock(Block):
seq_args.append(arg[i])
mapped_args.append(seq_args)
elif isinstance(arg, Tensor):
is_tensor, name, repr_str = self._io_tensor_check_and_gen(
is_tensor, name, repr_str = self.__io_tensor_check_and_gen(
arg, io_type, idx
)
assert is_tensor
......@@ -341,7 +343,7 @@ class ModuleBlock(Block):
f"{repr_str} is a Tensor, {func_desc} transformation has been done.",
)
else:
is_tensor, name, repr_str = self._io_tensor_check_and_gen(
is_tensor, name, repr_str = self.__io_tensor_check_and_gen(
arg, io_type, idx
)
assert not is_tensor
......@@ -354,7 +356,7 @@ class ModuleBlock(Block):
return tuple(mapped_args)
def _io_tensor_check_and_gen(self, item, io_type, idx, second_idx=None):
def __io_tensor_check_and_gen(self, item, io_type, idx, second_idx=None):
assert io_type in ("input", "output")
name = (
"_"
......@@ -383,7 +385,7 @@ class ModuleBlock(Block):
)
return False, name, repr_str
def _members(self, get_members_fn, recurse=True) -> Iterator["Block"]:
def __members(self, get_members_fn, recurse=True) -> Iterator["Block"]:
assert self._type == BlockType.MODULE
memo = set()
modules = self.modules() if recurse else [self]
......@@ -397,13 +399,13 @@ class ModuleBlock(Block):
def parameters(self, recurse: bool = True) -> Iterator["Block"]:
assert self._type == BlockType.MODULE
gen = self._members(lambda module: module._parameters.items(), recurse=recurse)
gen = self.__members(lambda module: module._parameters.items(), recurse=recurse)
for elem in gen:
yield elem
def buffers(self, recurse: bool = True) -> Iterator["Block"]:
assert self._type == BlockType.MODULE
gen = self._members(lambda module: module._buffers.items(), recurse=recurse)
gen = self.__members(lambda module: module._buffers.items(), recurse=recurse)
for elem in gen:
yield elem
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册