未验证 提交 163cd154 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2Static ] Unify tensor.size and Variable.size() by jst (#45144)

* unify the size and size() by jst

* fix bugs

* bug fix.

* fix bugs

* change all_close -> np.testing.assert_allclose
上级 a237ff8e
......@@ -18,6 +18,7 @@ from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static import utils
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
class BasicApiTransformer(BaseTransformer):
......@@ -37,6 +38,8 @@ class BasicApiTransformer(BaseTransformer):
def transform(self):
to_tensor_transformer = ToTensorTransformer(self.root)
to_tensor_transformer.transform()
attribute_transformer = AttributeJstTransformer(self.root)
attribute_transformer.transform()
self.visit(self.root)
return self.wrapper_root
......@@ -122,6 +125,42 @@ class ToTensorTransformer(BaseTransformer):
return node
class AttributeJstTransformer(BaseTransformer):
"""
change some special attribute into __jst.XXX(obj, "attr_name") format.
for example:
a.size --> __jst.attr(a, "size")
because `size` have different behavier when in dygraph / static mode
NOTE: we only deal with ctx=Load() case.
"""
def __init__(self, node):
assert isinstance(
node, gast.AST
), "Input non-gast.AST node for the initialization of ToTensorTransformer."
self.interested_name = set([
'size',
])
self.root = node
def transform(self):
self.visit(self.root)
return self.root
def visit_Attribute(self, node):
assert isinstance(node, gast.Attribute)
assert isinstance(node.attr, str)
if isinstance(node.ctx,
gast.Load) and node.attr in self.interested_name:
attr = node.attr
value = node.value
node = gast.parse("_jst.Attr({}, \"{}\")".format(
ast_to_source_code(value).strip(), attr)).body[0].value
self.generic_visit(node)
return node
def is_to_variable(node):
assert isinstance(node, gast.Call)
api_name = utils.ast_to_source_code(node.func).strip()
......
......@@ -28,6 +28,13 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, Dygraph2S
from paddle.fluid.layers.utils import copy_mutable_vars
def convert_attr(x, attr):
if isinstance(x, Variable) and attr == "size":
return x.size()
else:
return getattr(value, attr)
def indexable(x, code=None):
if isinstance(x, Variable): return x
if hasattr(x, '__len__') and hasattr(x, '__getitem__'): return x
......
......@@ -62,5 +62,29 @@ class TestTensorDygraphOnlyMethodError(unittest.TestCase):
static_res = self._run(to_static=True)
@paddle.jit.to_static
def tensor_size(x):
x = paddle.to_tensor(x)
x = paddle.reshape(x, paddle.shape(x)) # dynamic shape
y = x.size
return y
class TestTensorSize(unittest.TestCase):
def _run(self, to_static):
prog_trans = paddle.jit.ProgramTranslator()
prog_trans.enable(to_static)
x = paddle.ones([1, 2, 3])
if to_static == False:
return tensor_size(x)
return tensor_size(x).numpy()
def test_tensor_clone(self):
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-5)
if __name__ == '__main__':
unittest.main()
......@@ -27,6 +27,7 @@ from .convert_operators import convert_print as Print # noqa: F401
from .convert_operators import convert_shape as Shape # noqa: F401
from .convert_operators import convert_while_loop as While # noqa: F401
from .convert_operators import unpack_by_structure as Unpack # noqa: F401
from .convert_operators import convert_attr as Attr # noqa: F401
from .convert_operators import indexable as Indexable # noqa: F401
from .variable_trans_func import create_bool_as_type # noqa: F401
from .variable_trans_func import to_static_variable # noqa: F401
......
......@@ -26,6 +26,6 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape_c
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import unpack_by_structure, indexable # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import unpack_by_structure, indexable, convert_attr # noqa: F401
__all__ = []
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册