From d3f9808893b5c7e1b795f65928dd74c94a062dd9 Mon Sep 17 00:00:00 2001
From: xiongkun
Date: Thu, 17 Aug 2023 11:48:07 +0800
Subject: [PATCH] [Dy2static] support set_value_op in static mode by _jst.Ld()
 (#56028)

* Fix and add unittest

* don't introduce assign when already in global block.

* fix more

* fix bugs

* fix bugs

* fix ci

* fix bfgs

* make function local
---
 python/paddle/fluid/framework.py              |  4 ++
 python/paddle/fluid/layer_helper_base.py      | 33 +++++++++++
 python/paddle/fluid/variable_index.py         | 28 +++++++++-
 .../jit/dy2static/basic_api_transformer.py    |  3 +
 .../jit/dy2static/program_translator.py       | 13 ++++-
 python/paddle/static/nn/control_flow.py       |  3 +-
 test/dygraph_to_static/test_jit_setitem.py    | 56 +++++++++++++++----
 7 files changed, 122 insertions(+), 18 deletions(-)

diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index b375cca76c1..35e8dba7548 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -3573,6 +3573,10 @@ def _stride_in_no_check_dy2st_diff():
 
 
 def check_if_to_static_diff_with_dygraph(op_type, inplace_map, outputs):
+    if (
+        op_type == "while"
+    ):  # no need to check while: it is only a wrapper of inner ops, and we would get stuck in the inner op.
+        return
     if outputs is not None:
         for k, v in outputs.items():
             if isinstance(v, Variable):
diff --git a/python/paddle/fluid/layer_helper_base.py b/python/paddle/fluid/layer_helper_base.py
index ae631bf69f3..042e33a108e 100644
--- a/python/paddle/fluid/layer_helper_base.py
+++ b/python/paddle/fluid/layer_helper_base.py
@@ -466,6 +466,39 @@ class LayerHelperBase:
             stop_gradient=stop_gradient,
         )
 
+    def _create_global_variable_for_type_inference(
+        self, dtype, stop_gradient=False, shape=None
+    ):
+        """Create a global variable whose type will be inferred.
+
+        Note:
+            The default type will be set to LOD_TENSOR. However, when
+            the var is used as operator output, its type will be updated
+            based on operator's `VarTypeInference` implementation in
+            infer_var_type.
+        """
+        # set global dtype
+        if not dtype:
+            dtype = self.__dtype
+        output = self.main_program.global_block().create_var(
+            name=unique_name.generate_with_ignorable_key(
+                ".".join([self.name, 'tmp'])
+            ),
+            dtype=dtype,
+            shape=shape,
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            persistable=False,
+            stop_gradient=stop_gradient,
+        )
+        saved_block_id = self.main_program.current_block_idx
+        self.main_program.current_block_idx = 0
+        paddle.tensor.creation.fill_constant(
+            output.shape, dtype, 0.0, force_cpu=False, out=output
+        )
+        output.stop_gradient = stop_gradient
+        self.main_program.current_block_idx = saved_block_id
+        return output
+
     def create_sparse_variable_for_type_inference(
         self, dtype, stop_gradient=False, shape=None
     ):
diff --git a/python/paddle/fluid/variable_index.py b/python/paddle/fluid/variable_index.py
index 78ba5e3cfd7..519ad7481b9 100644
--- a/python/paddle/fluid/variable_index.py
+++ b/python/paddle/fluid/variable_index.py
@@ -566,7 +566,13 @@ def _setitem_impl_(var, item, value):
         output = var
     else:
         helper = paddle.fluid.layer_helper.LayerHelper('set_value', **locals())
-        output = helper.create_variable_for_type_inference(dtype=var.dtype)
+        if helper.main_program.current_block_idx != 0:
+            # not in the global block, so create a global variable.
+            output = helper._create_global_variable_for_type_inference(
+                dtype=var.dtype
+            )
+        else:
+            output = helper.create_variable_for_type_inference(dtype=var.dtype)
 
     cur_block = default_main_program().current_block()
     cur_block.append_op(
@@ -909,7 +915,15 @@ def _setitem_static(x, indices, values):
             helper = paddle.fluid.layer_helper.LayerHelper(
                 'set_value', **locals()
             )
-            output = helper.create_variable_for_type_inference(dtype=x.dtype)
+            if helper.main_program.current_block_idx != 0:
+                # not in the global block, so create a global variable.
+                output = helper._create_global_variable_for_type_inference(
+                    dtype=x.dtype
+                )
+            else:
+                output = helper.create_variable_for_type_inference(
+                    dtype=x.dtype
+                )
             cur_block = default_main_program().current_block()
             cur_block.append_op(
                 type="set_value",
@@ -975,7 +989,15 @@ def _setitem_static(x, indices, values):
             helper = paddle.fluid.layer_helper.LayerHelper(
                 'set_value', **locals()
             )
-            output = helper.create_variable_for_type_inference(dtype=x.dtype)
+            if helper.main_program.current_block_idx != 0:
+                # not in the global block, so create a global variable.
+                output = helper._create_global_variable_for_type_inference(
+                    dtype=x.dtype
+                )
+            else:
+                output = helper.create_variable_for_type_inference(
+                    dtype=x.dtype
+                )
             cur_block = default_main_program().current_block()
             cur_block.append_op(
                 type="set_value",
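Why the `set_value` output must live in block 0: in a static-graph Program, a variable created inside a `while` sub-block is scoped to that block, so an output allocated there cannot be read once the loop ends. The new `_create_global_variable_for_type_inference` therefore creates (and zero-fills) the output in the global block by temporarily switching `current_block_idx`. A minimal sketch of that switch/restore pattern follows; the helper name `run_in_global_block` is illustrative, not part of the patch, and `current_block_idx` is an internal Program attribute used here only as the patch itself uses it:

    import paddle

    paddle.enable_static()

    def run_in_global_block(program, fn):
        # Temporarily make block 0 the current block, run fn there,
        # then restore whichever block was active before.
        saved_block_id = program.current_block_idx
        program.current_block_idx = 0
        try:
            return fn()
        finally:
            program.current_block_idx = saved_block_id

    main = paddle.static.default_main_program()
    # The variable behind x is appended to block 0 regardless of
    # which block was current when run_in_global_block was called.
    x = run_in_global_block(main, lambda: paddle.full([2, 2], 0.0))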
diff --git a/python/paddle/jit/dy2static/basic_api_transformer.py b/python/paddle/jit/dy2static/basic_api_transformer.py
index f188df92cd9..34b8708f6a2 100644
--- a/python/paddle/jit/dy2static/basic_api_transformer.py
+++ b/python/paddle/jit/dy2static/basic_api_transformer.py
@@ -152,11 +152,14 @@ class NameloadJstTransformer(BaseTransformer):
         Can't convert name of function call, bacause this will affect CallTransformer.
         """
         node.args = [self.generic_visit(arg) for arg in node.args]
+        node.func = self.generic_visit(node.func)
         return node
 
     def visit_Attribute(self, node):
         assert isinstance(node, gast.Attribute)
         assert isinstance(node.attr, str)
+        if utils.ast_to_source_code(node).startswith("_jst."):  # skip _jst.xxx
+            return node
         self.generic_visit(node)
         if isinstance(node.ctx, gast.Load):
             node = self._surround_with_ld(node)
diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py
index ea5c159c579..492766c3c70 100644
--- a/python/paddle/jit/dy2static/program_translator.py
+++ b/python/paddle/jit/dy2static/program_translator.py
@@ -1296,9 +1296,16 @@ class ParametersMap:
         params = self.params_dict.get(self._program_hash(program))
         if params is None:
             return None
-        if id in params.keys():
-            return params[id]
-        return None
+        if id not in params:
+            return None
+        root_var = params[id]
+        saved = []
+        while root_var.desc.id() in params.keys():
+            saved.append(root_var)
+            root_var = params[root_var.desc.id()]
+        for var in saved:
+            params[var.desc.id()] = root_var
+        return root_var
 
     def _program_hash(self, program):
         """
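The reworked `ParametersMap.get` above is union-find-style chain following with path compression: when variable A has been remapped to B and B to C, a lookup of A should return the final C and flatten the chain so later lookups take one hop. A library-agnostic sketch of the same idea over a plain dict (the patch keys entries by `desc.id()`; the names here are illustrative):

    def resolve(mapping, key):
        # Follow chained entries key -> v1 -> v2 -> ... to the final
        # value, then repoint the intermediate keys at it.
        if key not in mapping:
            return None
        seen = []
        value = mapping[key]
        while value in mapping:
            seen.append(value)
            value = mapping[value]
        for k in seen:
            mapping[k] = value
        return value

    chain = {"a": "b", "b": "c", "c": "d"}
    assert resolve(chain, "a") == "d"
    assert chain["b"] == "d"  # chain flattened for future lookups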
diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py
index 774e21e50d8..08b5d962abf 100644
--- a/python/paddle/static/nn/control_flow.py
+++ b/python/paddle/static/nn/control_flow.py
@@ -589,7 +589,8 @@ def assign_skip_lod_tensor_array(input, output):
             # input is not generated in While sub block and modified by in-place and only
             # belong to inplace ops in constructing program process, because in-place pass
             # is only available in Graph level.
-            paddle.assign(input, output)
+            with paddle.fluid.framework._stride_in_no_check_dy2st_diff():
+                paddle.assign(input, output)
 
 
 def while_loop(cond, body, loop_vars, is_test=False, name=None):
diff --git a/test/dygraph_to_static/test_jit_setitem.py b/test/dygraph_to_static/test_jit_setitem.py
index 2ce79fb8a4b..18069d404a9 100644
--- a/test/dygraph_to_static/test_jit_setitem.py
+++ b/test/dygraph_to_static/test_jit_setitem.py
@@ -179,22 +179,56 @@ class TestCase11(TestSetItemBase):
 
 
 class TestCase12(TestSetItemBase):
-    # Test combind-indexing
+    # Test gradient of value tensor
     def init_func(self):
-        def foo(x, value):
-            y = x + 1
-            y[[0, 1], 1, :2] = value
-            return y
+        def foo():
+            res = paddle.zeros([4, 3, 2])
+            b = paddle.zeros([4, 3, 2])
+            v = paddle.to_tensor(1.0)
+            for i in range(paddle.shape(b)[0]):
+                res[i] = v
+            return res
 
         return foo
 
     def run_dygraph(self, func):
-        x = self.init_data()
-        value = paddle.ones((32,))
-        value.stop_gradient = False
-        y = func(x, value)
-        x_grad, value_grad = paddle.grad(y, [x, value])
-        return y, x_grad, value_grad
+        y = func()
+        return (y,)
+
+
+class TestCase13(TestSetItemBase):
+    # Test gradient of value tensor
+    def init_func(self):
+        def foo():
+            res = paddle.zeros([4, 3, 2])
+            v = paddle.to_tensor(1.0)
+            for i in range(4):
+                res[i] = v
+            return res
+
+        return foo
+
+    def run_dygraph(self, func):
+        y = func()
+        return (y,)
+
+
+class TestCase14(TestSetItemBase):
+    # Test gradient of value tensor
+    def init_func(self):
+        def foo():
+            data = np.arange(8).reshape((2, 4)).astype('float32')
+            x = paddle.to_tensor(data)
+            x[:, 1:] = x[:, :-1].clone()
+            x[:, 0] = 1
+            res = x.flatten()
+            return res
+
+        return foo
+
+    def run_dygraph(self, func):
+        y = func()
+        return (y,)
 
 
 if __name__ == '__main__':
-- 
GitLab
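End-to-end, the change is what lets tensor `__setitem__` inside converted control flow work under `paddle.jit.to_static`. The sketch below mirrors the new TestCase12 and assumes a Paddle build that contains this patch:

    import paddle

    def fill_rows():
        res = paddle.zeros([4, 3, 2])
        b = paddle.zeros([4, 3, 2])
        v = paddle.to_tensor(1.0)
        # paddle.shape(...) keeps the trip count dynamic, so dy2static
        # lowers this loop to a while op with its own sub-block.
        for i in range(paddle.shape(b)[0]):
            res[i] = v  # set_value output must escape that sub-block
        return res

    static_fill = paddle.jit.to_static(fill_rows)
    out = static_fill()
    print(float(out.sum()))  # 24.0 once every element is set to 1.0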