未验证 提交 8cc40f47 编写于 作者: L levi131 提交者: GitHub

enhance check for current block and docstring for prim2orig interface (#43063)

* enhance check for current block docstring for prim2orig interface

* refine if else syntax
上级 586f9429
...@@ -38,8 +38,7 @@ def topo_path(xs, ys, block=None): ...@@ -38,8 +38,7 @@ def topo_path(xs, ys, block=None):
path, the unused variables in `xs`, and the unreached variables in `ys` path, the unused variables in `xs`, and the unreached variables in `ys`
""" """
if block is None: block = default_main_program().current_block() if block is None else block
block = default_main_program().current_block()
path = [] path = []
backpath = [] backpath = []
...@@ -160,11 +159,14 @@ class VarMap(object): ...@@ -160,11 +159,14 @@ class VarMap(object):
return id(value_var) in self.tab.values() return id(value_var) in self.tab.values()
# TODO(lml): supporting control flow, nested blocks, and block other than current block of main program.
class Transform(object): class Transform(object):
""" An object that maintains the state of transformations applied to a """ An object that maintains the state of transformations applied to a
primitve program. """ primitve program. """
def __init__(self, block): def __init__(self, block):
assert block == default_main_program().current_block(
), f'only support transform on current block of main program.'
self.block = block self.block = block
self.vars = self.init_vars(block) self.vars = self.init_vars(block)
self.var2dot = VarMap('var2dot', self.vars) self.var2dot = VarMap('var2dot', self.vars)
...@@ -400,6 +402,7 @@ class Transform(object): ...@@ -400,6 +402,7 @@ class Transform(object):
return ys_bar, xs_bar return ys_bar, xs_bar
# TODO(lml): supporting control flow, nested blocks, and block other than current block of main program.
def _lower(block, reverse): def _lower(block, reverse):
# Some functions which are only used in _lower. # Some functions which are only used in _lower.
def bind(args, to_bind, value_table): def bind(args, to_bind, value_table):
...@@ -430,10 +433,6 @@ def _lower(block, reverse): ...@@ -430,10 +433,6 @@ def _lower(block, reverse):
# Step1: Do some preparatory work for lower # Step1: Do some preparatory work for lower
lower_fn = _prim2orig if reverse else _orig2prim lower_fn = _prim2orig if reverse else _orig2prim
lookup_fn = lookup_prim2orig if reverse else lookup_orig2prim lookup_fn = lookup_prim2orig if reverse else lookup_orig2prim
if block is None:
program = default_main_program()
assert program.num_blocks == 1, "The lower transform is designed to process only one block."
block = program.current_block()
value_table = {} value_table = {}
to_bind = {} to_bind = {}
...@@ -516,6 +515,7 @@ def orig2prim(block=None): ...@@ -516,6 +515,7 @@ def orig2prim(block=None):
""" """
.. note:: .. note::
**This API is ONLY available in the static mode.** **This API is ONLY available in the static mode.**
**Args block must be None or current block of main program.**
All operators in the target block are processed as follows. All operators in the target block are processed as follows.
If it is an original operator, it will be transformed into If it is an original operator, it will be transformed into
...@@ -523,13 +523,14 @@ def orig2prim(block=None): ...@@ -523,13 +523,14 @@ def orig2prim(block=None):
equivalent function. equivalent function.
Args: Args:
block(paddle.fluid.framework.Variable|None, optional): The block(paddle.static.Block|None, optional): The
target block to process on. Default None, and will target block to process on. Default None, and will
process on the current block of main program. process on the current block of main program.
Returns:
None
""" """
block = default_main_program().current_block() if block is None else block
assert block == default_main_program().current_block(
), f'block is neither None nor current block of main program'
_lower(block, reverse=False) _lower(block, reverse=False)
...@@ -538,6 +539,7 @@ def prim2orig(block=None): ...@@ -538,6 +539,7 @@ def prim2orig(block=None):
""" """
.. note:: .. note::
**ONLY available in the static mode.** **ONLY available in the static mode.**
**Args block must be None or current block of main program.**
All operators in the target block are processed as follows. All operators in the target block are processed as follows.
If it is an automatic differential basic operator, it will be If it is an automatic differential basic operator, it will be
...@@ -545,10 +547,10 @@ def prim2orig(block=None): ...@@ -545,10 +547,10 @@ def prim2orig(block=None):
equivalent function to support execution. equivalent function to support execution.
Args: Args:
block(paddle.static.Variable|None, optional): The block(paddle.static.Block|None, optional): The
target block to process on. Default None, and will target block to process on. Default None, and will
process on the current block of main program. process on the current block of main program.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -566,6 +568,10 @@ def prim2orig(block=None): ...@@ -566,6 +568,10 @@ def prim2orig(block=None):
if prim_enabled(): if prim_enabled():
prim2orig() prim2orig()
""" """
block = default_main_program().current_block() if block is None else block
assert block == default_main_program().current_block(
), f'block is neither None nor current block of main program'
_lower(block, reverse=True) _lower(block, reverse=True)
...@@ -583,7 +589,9 @@ def _gradients(ys, xs, ys_bar=None): ...@@ -583,7 +589,9 @@ def _gradients(ys, xs, ys_bar=None):
""" """
ys, xs = to_tensors(ys), to_tensors(xs) ys, xs = to_tensors(ys), to_tensors(xs)
block = ys[0].block block = default_main_program().current_block()
for el in xs + ys:
assert el is None or el.block == block, f'variable in xs and ys should be None or in current block of main program'
# TODO(Tongxin) without any prior knowledge about whether the program # TODO(Tongxin) without any prior knowledge about whether the program
# is completely lowered to primitive ops, it's mandatory to run the lowering # is completely lowered to primitive ops, it's mandatory to run the lowering
# pass once and again. This is obviously inefficient and needs to be # pass once and again. This is obviously inefficient and needs to be
......
...@@ -58,6 +58,8 @@ def append_backward_new(loss_list, ...@@ -58,6 +58,8 @@ def append_backward_new(loss_list,
program = default_main_program() program = default_main_program()
assert program.num_blocks == 1, "The append_backward_new interface is designed to process only one block." assert program.num_blocks == 1, "The append_backward_new interface is designed to process only one block."
block = program.current_block() block = program.current_block()
for el in loss_list:
assert el.block == block, f'variable in loss_list should be in current block of main program'
orig2prim(block) orig2prim(block)
ad = Transform(block) ad = Transform(block)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册