From 163cd15457b86b298f1839e3b7275ee596fe238f Mon Sep 17 00:00:00 2001 From: xiongkun Date: Mon, 29 Aug 2022 12:11:18 +0800 Subject: [PATCH] [ Dy2Static ] Unify tensor.size and Variable.size() by jst (#45144) * unify the size and size() by jst * fix bugs * bug fix. * fix bugs * change all_close -> np.testing.assert_allclose --- .../basic_api_transformer.py | 39 +++++++++++++++++++ .../dygraph_to_static/convert_operators.py | 7 ++++ .../dygraph_to_static/test_tensor_methods.py | 24 ++++++++++++ python/paddle/jit/dy2static/__init__.py | 1 + .../paddle/jit/dy2static/convert_operators.py | 2 +- 5 files changed, 72 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py index 493caa7e65b..55afb7ae6d6 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py @@ -18,6 +18,7 @@ from paddle.utils import gast from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static import utils from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code class BasicApiTransformer(BaseTransformer): @@ -37,6 +38,8 @@ class BasicApiTransformer(BaseTransformer): def transform(self): to_tensor_transformer = ToTensorTransformer(self.root) to_tensor_transformer.transform() + attribute_transformer = AttributeJstTransformer(self.root) + attribute_transformer.transform() self.visit(self.root) return self.wrapper_root @@ -122,6 +125,42 @@ class ToTensorTransformer(BaseTransformer): return node +class AttributeJstTransformer(BaseTransformer): + """ + change some special attribute into __jst.XXX(obj, "attr_name") format. + for example: + a.size --> __jst.attr(a, "size") + + because `size` have different behavier when in dygraph / static mode + NOTE: we only deal with ctx=Load() case. + """ + + def __init__(self, node): + assert isinstance( + node, gast.AST + ), "Input non-gast.AST node for the initialization of ToTensorTransformer." + self.interested_name = set([ + 'size', + ]) + self.root = node + + def transform(self): + self.visit(self.root) + return self.root + + def visit_Attribute(self, node): + assert isinstance(node, gast.Attribute) + assert isinstance(node.attr, str) + if isinstance(node.ctx, + gast.Load) and node.attr in self.interested_name: + attr = node.attr + value = node.value + node = gast.parse("_jst.Attr({}, \"{}\")".format( + ast_to_source_code(value).strip(), attr)).body[0].value + self.generic_visit(node) + return node + + def is_to_variable(node): assert isinstance(node, gast.Call) api_name = utils.ast_to_source_code(node.func).strip() diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index 56babcec87a..aa870968695 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -28,6 +28,13 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, Dygraph2S from paddle.fluid.layers.utils import copy_mutable_vars +def convert_attr(x, attr): + if isinstance(x, Variable) and attr == "size": + return x.size() + else: + return getattr(value, attr) + + def indexable(x, code=None): if isinstance(x, Variable): return x if hasattr(x, '__len__') and hasattr(x, '__getitem__'): return x diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_methods.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_methods.py index 7b587a77728..2ad9153fbaa 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_methods.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_methods.py @@ -62,5 +62,29 @@ class TestTensorDygraphOnlyMethodError(unittest.TestCase): static_res = self._run(to_static=True) +@paddle.jit.to_static +def tensor_size(x): + x = paddle.to_tensor(x) + x = paddle.reshape(x, paddle.shape(x)) # dynamic shape + y = x.size + return y + + +class TestTensorSize(unittest.TestCase): + + def _run(self, to_static): + prog_trans = paddle.jit.ProgramTranslator() + prog_trans.enable(to_static) + x = paddle.ones([1, 2, 3]) + if to_static == False: + return tensor_size(x) + return tensor_size(x).numpy() + + def test_tensor_clone(self): + dygraph_res = self._run(to_static=False) + static_res = self._run(to_static=True) + np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-5) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/jit/dy2static/__init__.py b/python/paddle/jit/dy2static/__init__.py index ebb4d30a412..8cfc2ee6a36 100644 --- a/python/paddle/jit/dy2static/__init__.py +++ b/python/paddle/jit/dy2static/__init__.py @@ -27,6 +27,7 @@ from .convert_operators import convert_print as Print # noqa: F401 from .convert_operators import convert_shape as Shape # noqa: F401 from .convert_operators import convert_while_loop as While # noqa: F401 from .convert_operators import unpack_by_structure as Unpack # noqa: F401 +from .convert_operators import convert_attr as Attr # noqa: F401 from .convert_operators import indexable as Indexable # noqa: F401 from .variable_trans_func import create_bool_as_type # noqa: F401 from .variable_trans_func import to_static_variable # noqa: F401 diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py index 691c8c0cfbe..fd809768e08 100644 --- a/python/paddle/jit/dy2static/convert_operators.py +++ b/python/paddle/jit/dy2static/convert_operators.py @@ -26,6 +26,6 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_c from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype # noqa: F401 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape # noqa: F401 from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop # noqa: F401 -from ...fluid.dygraph.dygraph_to_static.convert_operators import unpack_by_structure, indexable # noqa: F401 +from ...fluid.dygraph.dygraph_to_static.convert_operators import unpack_by_structure, indexable, convert_attr # noqa: F401 __all__ = [] -- GitLab