未验证 提交 8cc40f47 编写于 作者: L levi131 提交者: GitHub

enhance check for current block and docstring for prim2orig interface (#43063)

* enhance check for current block docstring for prim2orig interface

* refine if else syntax
上级 586f9429
......@@ -38,8 +38,7 @@ def topo_path(xs, ys, block=None):
path, the unused variables in `xs`, and the unreached variables in `ys`
"""
if block is None:
block = default_main_program().current_block()
block = default_main_program().current_block() if block is None else block
path = []
backpath = []
......@@ -160,11 +159,14 @@ class VarMap(object):
return id(value_var) in self.tab.values()
# TODO(lml): supporting control flow, nested blocks, and block other than current block of main program.
class Transform(object):
""" An object that maintains the state of transformations applied to a
primitve program. """
def __init__(self, block):
assert block == default_main_program().current_block(
), f'only support transform on current block of main program.'
self.block = block
self.vars = self.init_vars(block)
self.var2dot = VarMap('var2dot', self.vars)
......@@ -400,6 +402,7 @@ class Transform(object):
return ys_bar, xs_bar
# TODO(lml): supporting control flow, nested blocks, and block other than current block of main program.
def _lower(block, reverse):
# Some functions which are only used in _lower.
def bind(args, to_bind, value_table):
......@@ -430,10 +433,6 @@ def _lower(block, reverse):
# Step1: Do some preparatory work for lower
lower_fn = _prim2orig if reverse else _orig2prim
lookup_fn = lookup_prim2orig if reverse else lookup_orig2prim
if block is None:
program = default_main_program()
assert program.num_blocks == 1, "The lower transform is designed to process only one block."
block = program.current_block()
value_table = {}
to_bind = {}
......@@ -516,6 +515,7 @@ def orig2prim(block=None):
"""
.. note::
**This API is ONLY available in the static mode.**
**Args block must be None or current block of main program.**
All operators in the target block are processed as follows.
If it is an original operator, it will be transformed into
......@@ -523,13 +523,14 @@ def orig2prim(block=None):
equivalent function.
Args:
block(paddle.fluid.framework.Variable|None, optional): The
block(paddle.static.Block|None, optional): The
target block to process on. Default None, and will
process on the current block of main program.
Returns:
None
"""
block = default_main_program().current_block() if block is None else block
assert block == default_main_program().current_block(
), f'block is neither None nor current block of main program'
_lower(block, reverse=False)
......@@ -538,6 +539,7 @@ def prim2orig(block=None):
"""
.. note::
**ONLY available in the static mode.**
**Args block must be None or current block of main program.**
All operators in the target block are processed as follows.
If it is an automatic differential basic operator, it will be
......@@ -545,10 +547,10 @@ def prim2orig(block=None):
equivalent function to support execution.
Args:
block(paddle.static.Variable|None, optional): The
block(paddle.static.Block|None, optional): The
target block to process on. Default None, and will
process on the current block of main program.
Examples:
.. code-block:: python
......@@ -566,6 +568,10 @@ def prim2orig(block=None):
if prim_enabled():
prim2orig()
"""
block = default_main_program().current_block() if block is None else block
assert block == default_main_program().current_block(
), f'block is neither None nor current block of main program'
_lower(block, reverse=True)
......@@ -583,7 +589,9 @@ def _gradients(ys, xs, ys_bar=None):
"""
ys, xs = to_tensors(ys), to_tensors(xs)
block = ys[0].block
block = default_main_program().current_block()
for el in xs + ys:
assert el is None or el.block == block, f'variable in xs and ys should be None or in current block of main program'
# TODO(Tongxin) without any prior knowledge about whether the program
# is completely lowered to primitive ops, it's mandatory to run the lowering
# pass once and again. This is obviously inefficient and needs to be
......
......@@ -58,6 +58,8 @@ def append_backward_new(loss_list,
program = default_main_program()
assert program.num_blocks == 1, "The append_backward_new interface is designed to process only one block."
block = program.current_block()
for el in loss_list:
assert el.block == block, f'variable in loss_list should be in current block of main program'
orig2prim(block)
ad = Transform(block)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册