From 8cc40f4702c5cf0e8c88b13e17d8461938f7298a Mon Sep 17 00:00:00 2001
From: levi131 <83750468+levi131@users.noreply.github.com>
Date: Mon, 30 May 2022 15:59:08 +0800
Subject: [PATCH] enhance check for current block and docstring for prim2orig
 interface (#43063)

* enhance check for current block and docstring for prim2orig interface

* refine if-else syntax
---
 python/paddle/incubate/autograd/primx.py | 34 +++++++++++++++---------
 python/paddle/optimizer/optimizer.py     |  2 ++
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/python/paddle/incubate/autograd/primx.py b/python/paddle/incubate/autograd/primx.py
index 7a969748208..1f5c4f9a5ce 100644
--- a/python/paddle/incubate/autograd/primx.py
+++ b/python/paddle/incubate/autograd/primx.py
@@ -38,8 +38,7 @@ def topo_path(xs, ys, block=None):
         path, the unused variables in `xs`, and the unreached variables in `ys`
 
     """
-    if block is None:
-        block = default_main_program().current_block()
+    block = default_main_program().current_block() if block is None else block
     path = []
     backpath = []
@@ -160,11 +159,14 @@ class VarMap(object):
         return id(value_var) in self.tab.values()
 
 
+# TODO(lml): support control flow, nested blocks, and blocks other than the current block of the main program.
 class Transform(object):
     """ An object that maintains the state of transformations applied to a
     primitive program. """
 
     def __init__(self, block):
+        assert block == default_main_program().current_block(), \
+            'Transform only supports the current block of the main program'
         self.block = block
         self.vars = self.init_vars(block)
         self.var2dot = VarMap('var2dot', self.vars)
@@ -400,6 +402,7 @@ class Transform(object):
         return ys_bar, xs_bar
 
 
+# TODO(lml): support control flow, nested blocks, and blocks other than the current block of the main program.
 def _lower(block, reverse):
     # Some functions which are only used in _lower.
     def bind(args, to_bind, value_table):
@@ -430,10 +433,6 @@ def _lower(block, reverse):
     # Step1: Do some preparatory work for lower
     lower_fn = _prim2orig if reverse else _orig2prim
     lookup_fn = lookup_prim2orig if reverse else lookup_orig2prim
-    if block is None:
-        program = default_main_program()
-        assert program.num_blocks == 1, "The lower transform is designed to process only one block."
-        block = program.current_block()
 
     value_table = {}
     to_bind = {}
@@ -516,6 +515,7 @@ def orig2prim(block=None):
     """
     .. note::
         **This API is ONLY available in the static mode.**
+        **The argument block must be None or the current block of the main program.**
 
     All operators in the target block are processed as follows.
     If it is an original operator, it will be transformed into
@@ -523,13 +523,14 @@ def orig2prim(block=None):
     equivalent function.
 
     Args:
-        block(paddle.fluid.framework.Variable|None, optional): The
+        block(paddle.static.Block|None, optional): The
             target block to process on. Default None, and will
             process on the current block of the main program.
-
-    Returns:
-        None
     """
+
+    block = default_main_program().current_block() if block is None else block
+    assert block == default_main_program().current_block(), \
+        'block should be None or the current block of the main program'
     _lower(block, reverse=False)
@@ -538,6 +539,7 @@ def prim2orig(block=None):
     """
     .. note::
         **ONLY available in the static mode.**
+        **The argument block must be None or the current block of the main program.**
 
     All operators in the target block are processed as follows.
     If it is an automatic differential basic operator, it will be
     transformed into one or a series of original operators with
@@ -545,10 +547,10 @@ def prim2orig(block=None):
     equivalent function to support execution.
 
     Args:
-        block(paddle.static.Variable|None, optional): The
+        block(paddle.static.Block|None, optional): The
             target block to process on. Default None, and will
             process on the current block of the main program.
-
+
     Examples:
 
         .. code-block:: python
@@ -566,6 +568,10 @@ def prim2orig(block=None):
             if prim_enabled():
                 prim2orig()
     """
+
+    block = default_main_program().current_block() if block is None else block
+    assert block == default_main_program().current_block(), \
+        'block should be None or the current block of the main program'
     _lower(block, reverse=True)
@@ -583,7 +589,9 @@ def _gradients(ys, xs, ys_bar=None):
 
     """
     ys, xs = to_tensors(ys), to_tensors(xs)
-    block = ys[0].block
+    block = default_main_program().current_block()
+    for el in xs + ys:
+        assert el is None or el.block == block, 'variables in xs and ys should be None or in the current block of the main program'
     # TODO(Tongxin) without any prior knowledge about whether the program
     # is completely lowered to primitive ops, it's mandatory to run the lowering
     # pass once and again. This is obviously inefficient and needs to be
diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index 9dfec3947e9..cf180fccc48 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -58,6 +58,8 @@ def append_backward_new(loss_list,
     program = default_main_program()
     assert program.num_blocks == 1, "The append_backward_new interface is designed to process only one block."
     block = program.current_block()
+    for el in loss_list:
+        assert el.block == block, 'variables in loss_list should be in the current block of the main program'
     orig2prim(block)
     ad = Transform(block)
-- 
GitLab
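
A minimal usage sketch of the stricter check, not part of the patch itself. It assumes a Paddle build where paddle.incubate.autograd exposes enable_prim, prim_enabled, and prim2orig, as in the docstring example above; the second program's block is purely illustrative:

    # Sketch only: exercises the new current-block check added by this patch.
    import paddle
    from paddle.incubate.autograd import enable_prim, prim_enabled, prim2orig

    paddle.enable_static()
    enable_prim()

    x = paddle.ones(shape=[2, 2], dtype='float32')
    x.stop_gradient = False
    y = x * x
    dy_dx = paddle.static.gradients(y, x)

    if prim_enabled():
        # OK: block=None defaults to the current block of the main program.
        prim2orig()

        # A block taken from a different program now trips the new assert.
        other_block = paddle.static.Program().current_block()
        try:
            prim2orig(other_block)
        except AssertionError as err:
            print(err)

The same None-or-current-block contract applies to orig2prim, while _gradients and append_backward_new now additionally verify that their input variables live in the current block of the main program.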