diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_load_transformer.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_load_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..52c5fcd4e4aebbbe8c77cea030d7fb81de93938e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_load_transformer.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import numpy as np + +import paddle + + +class FakeNet: + def __init__(self): + self.var = paddle.to_tensor([2.0]) + + +f = FakeNet() +g = paddle.to_tensor([1.0]) + + +class Net(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + # unsupport g as store. + t = g * 2 + x + t = f.var * t + return t + + +class TestFallback(unittest.TestCase): + def setUp(self): + self.x = paddle.to_tensor(1.0).astype('int') + + def test_name_load(self): + net_dy = Net() + net_st = Net() + output_dy = net_dy(self.x) + output_st = paddle.jit.to_static(net_st)(self.x) + np.testing.assert_allclose(output_dy.numpy(), output_st.numpy()) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/jit/dy2static/__init__.py b/python/paddle/jit/dy2static/__init__.py index b3e70f487003151dee80b708eed1e001fb64ab9f..bc91a4c1674f3e9a9915208049959fd2b5b3d3bf 100644 --- a/python/paddle/jit/dy2static/__init__.py +++ b/python/paddle/jit/dy2static/__init__.py @@ -26,6 +26,7 @@ from .convert_operators import convert_shape as Shape # noqa: F401 from .convert_operators import convert_while_loop as While # noqa: F401 from .convert_operators import unpack_by_structure as Unpack # noqa: F401 from .convert_operators import convert_attr as Attr # noqa: F401 +from .convert_operators import convert_load as Ld # noqa: F401 from .convert_operators import indexable as Indexable # noqa: F401 from .variable_trans_func import create_bool_as_type # noqa: F401 from .variable_trans_func import to_static_variable # noqa: F401 diff --git a/python/paddle/jit/dy2static/ast_transformer.py b/python/paddle/jit/dy2static/ast_transformer.py index 35547be1838565b6159f1e3dc430fb937abf6e3d..3c7926d8fa621e1fb624614056e94ee59768d11e 100644 --- a/python/paddle/jit/dy2static/ast_transformer.py +++ b/python/paddle/jit/dy2static/ast_transformer.py @@ -22,7 +22,7 @@ import os from . import logging_utils from .assert_transformer import AssertTransformer from .base_transformer import BaseTransformer -from .basic_api_transformer import BasicApiTransformer +from .basic_api_transformer import BasicApiTransformer, NameloadJstTransformer from .break_continue_transformer import ( BreakContinueTransformer, BreakTransformOptimizer, @@ -93,6 +93,7 @@ class DygraphToStaticAst(BaseTransformer): transformers = [ EarlyReturnTransformer, + DecoratorTransformer, # transform decorators to function call BasicApiTransformer, # Basic Api TensorShapeTransformer, # Tensor.shape -> paddle.shape(Tensor) BreakContinueTransformer, # break/continue in loops @@ -104,7 +105,7 @@ class DygraphToStaticAst(BaseTransformer): AssertTransformer, # assert statement CallTransformer, # transform call recursively CastTransformer, # type casting statement - DecoratorTransformer, # transform decorators to function call + NameloadJstTransformer, TypeHintTransformer, # remove all typehint in gast.Name ] diff --git a/python/paddle/jit/dy2static/basic_api_transformer.py b/python/paddle/jit/dy2static/basic_api_transformer.py index 24825e8e937c7010f40dd7a24f25e4202a11a6e5..a25b442144a6b90d3518748a568374e0c7317aa6 100644 --- a/python/paddle/jit/dy2static/basic_api_transformer.py +++ b/python/paddle/jit/dy2static/basic_api_transformer.py @@ -43,7 +43,6 @@ class BasicApiTransformer(BaseTransformer): attribute_transformer = AttributeJstTransformer(self.root) attribute_transformer.transform() self.visit(self.root) - return self.wrapper_root def visit_Assign(self, node): @@ -127,6 +126,63 @@ class ToTensorTransformer(BaseTransformer): return node +class NameloadJstTransformer(BaseTransformer): + """ + change name and attribute load to __jst.Ld(name) pattern. + for example: + a.dtype --> __jst.Ld(__jst.Ld(a).dtype) + + In paddle science and deepxde, we have to support changing tensor into variable + in arbitrary occasion such as global tensor. + + NOTE: we only deal with ctx=Load() case. + """ + + def __init__(self, wrapper_root): + assert isinstance( + wrapper_root, AstNodeWrapper + ), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer." + + self.wrapper_root = wrapper_root + self.root = wrapper_root.node + + def transform(self): + self.visit(self.root) + return self.root + + def _surround_with_ld(self, node): + node = ( + gast.parse( + "_jst.Ld({})".format(utils.ast_to_source_code(node).strip()) + ) + .body[0] + .value + ) + return node + + def visit_Call(self, node): + """ + Can't convert name of function call, bacause this will affect CallTransformer. + """ + node.args = [self.generic_visit(arg) for arg in node.args] + return node + + def visit_Attribute(self, node): + assert isinstance(node, gast.Attribute) + assert isinstance(node.attr, str) + self.generic_visit(node) + if isinstance(node.ctx, gast.Load): + node = self._surround_with_ld(node) + return node + + def visit_Name(self, node): + assert isinstance(node, gast.Name) + self.generic_visit(node) + if isinstance(node.ctx, gast.Load): + node = self._surround_with_ld(node) + return node + + class AttributeJstTransformer(BaseTransformer): """ change some special attribute into __jst.XXX(obj, "attr_name") format. diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py index 111e975b3a05cf5f66df4966ee7fe61b8154f179..bb07c1faeb0450b17db33474cde2931852278cfd 100644 --- a/python/paddle/jit/dy2static/convert_operators.py +++ b/python/paddle/jit/dy2static/convert_operators.py @@ -16,6 +16,7 @@ import re import paddle from paddle.fluid.data_feeder import convert_dtype +from paddle.fluid.dygraph.base import _convert_into_variable from paddle.fluid.framework import Variable, core from paddle.fluid.layers import Print, control_flow, fill_constant from paddle.fluid.layers.control_flow import while_loop @@ -39,6 +40,17 @@ def convert_attr(x, attr): return getattr(x, attr) +def convert_load(x): + from paddle.fluid.dygraph.base import in_declarative_mode + + if in_declarative_mode() and isinstance(x, paddle.fluid.core.eager.Tensor): + """ + TODO:(@xiongkun) may run convert_load in dygraph mode, which should be fixed. + """ + return _convert_into_variable(x) + return x + + def indexable(x, code=None): if isinstance(x, Variable): return x diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index 197e8f4eba0c30b9f411ff69bbee5f82722cd2dc..9ec6c87ea1ea070d59909c04511c3deb001d38e1 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -1145,9 +1145,10 @@ class ProgramCache: def _build_once(self, cache_key): # TODO(Aurelius84): Need a gloabl FLAGS to enable/disable to_prim enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass + # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch. + # NOTE(xiongkun): Need a global FLAGS to enable/disable fallback enable_fallback = enable_prim - # TODO(CZ): later when use cinn, set_prim_all_enabled and check_and_set_prim_all_enabled will be set at else branch. core.check_and_set_prim_all_enabled() try: concrete_program = ConcreteProgram.from_func_spec(