未验证 提交 d0096eaf 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2Static ] convert_call support staticmethod for class. (#44983)

* convert_call support staticmethod for class.

* while support for python container.
It is convenient to convert more dynamic graph codes into static graphs.

* cond support python container

* add unittest for staticmethod convert_call

* fix bugs

* add unittest for item interface

* fix bugs

* change to np.testing.assert_allclose

* code format

* fix comments.

* code format
上级 2b4f44d5
...@@ -123,7 +123,10 @@ class DygraphToStaticAst(BaseTransformer): ...@@ -123,7 +123,10 @@ class DygraphToStaticAst(BaseTransformer):
# Remove the decorated name of dygraph_to_static # Remove the decorated name of dygraph_to_static
if hasattr(node, 'decorator_list'): if hasattr(node, 'decorator_list'):
decorator_list = [] decorator_list = []
ignore_list = ["staticmethod"]
for d in node.decorator_list: for d in node.decorator_list:
if isinstance(d, gast.Name) and d.id in ignore_list:
continue
if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES: if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES:
raise NotImplementedError( raise NotImplementedError(
"ProgramTranslator hasn't implemented multiple decorators. Please remove " "ProgramTranslator hasn't implemented multiple decorators. Please remove "
......
...@@ -213,6 +213,11 @@ def convert_call(func): ...@@ -213,6 +213,11 @@ def convert_call(func):
elif isinstance(fn, StaticFunction): elif isinstance(fn, StaticFunction):
_, fn = unwrap_decorators(fn) _, fn = unwrap_decorators(fn)
global_functions.add(fn) global_functions.add(fn)
elif inspect.isclass(fn):
if isinstance(fn.__dict__.get(func.__name__, None),
staticmethod):
global_functions.add(
func) # Add func to ensure that we will convert
if func in global_functions: if func in global_functions:
converted_call = convert_to_static(func) converted_call = convert_to_static(func)
......
...@@ -87,7 +87,10 @@ def convert_while_loop(cond, ...@@ -87,7 +87,10 @@ def convert_while_loop(cond,
Args: Args:
cond(Callable): A callable object that returns a boolean variable to control whether to execute the loop body. It takes ``loop_vars`` as arguments. cond(Callable): A callable object that returns a boolean variable to control whether to execute the loop body. It takes ``loop_vars`` as arguments.
body(Callable): A callable object that returns a tuple or list of variables with the same arguments ``loops_vars`` as ``cond`` . body(Callable): A callable object that returns a tuple or list of variables with the same arguments ``loops_vars`` as ``cond`` .
loop_vars(list|tuple): A list or tuple of variables passed to ``cond`` and ``body`` . get_args(callable): Get all arguments that needed in true_fn and false_fn.
set_args(callable): Update arguments that modified in trure_fn and false_fn.
return_name_ids(list[string], optional): the returned names.
push_pop_names(list[string], optional): the names on which called .append() or .pop().
Returns: Returns:
A list or tuple of variables which returned by ``body``. A list or tuple of variables which returned by ``body``.
...@@ -306,7 +309,8 @@ def convert_ifelse(pred, ...@@ -306,7 +309,8 @@ def convert_ifelse(pred,
false_fn(callable): A callable to be performed if ``pred`` is false. false_fn(callable): A callable to be performed if ``pred`` is false.
get_args(callable): Get all arguments that needed in true_fn and false_fn. get_args(callable): Get all arguments that needed in true_fn and false_fn.
set_args(callable): Update arguments that modified in trure_fn and false_fn. set_args(callable): Update arguments that modified in trure_fn and false_fn.
return_name_ids(list[string]): the returned names. return_name_ids(list[string], optional): the returned names.
push_pop_names(list[string], optional): the names on which called .append() or .pop().
Returns: Returns:
``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` . ``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` .
......
...@@ -380,7 +380,6 @@ class StaticFunction(object): ...@@ -380,7 +380,6 @@ class StaticFunction(object):
try: try:
concrete_program, partial_program_layer = self.get_concrete_program( concrete_program, partial_program_layer = self.get_concrete_program(
*args, **kwargs, is_train=self._is_train_mode()) *args, **kwargs, is_train=self._is_train_mode())
# 3. synchronize self.training attribute. # 3. synchronize self.training attribute.
if isinstance(self._class_instance, layers.Layer): if isinstance(self._class_instance, layers.Layer):
partial_program_layer.training = self._class_instance.training partial_program_layer.training = self._class_instance.training
......
...@@ -227,12 +227,30 @@ def monkey_patch_variable(): ...@@ -227,12 +227,30 @@ def monkey_patch_variable():
.format(self.type)) .format(self.type))
array_write(x=var, i=array_length(self), array=self) array_write(x=var, i=array_length(self), array=self)
@static_only
def _item(self):
"""
In order to be compatible with the item interface introduced by the dynamic graph, it does nothing but returns self.
It will check that the shape must be a 1-D tensor
"""
if len(self.shape) > 1:
raise TypeError(
"Required input var should be 1-D Variable, but received {}".
format(self.shape))
return self
@static_only @static_only
def pop(self, *args): def pop(self, *args):
""" """
**Notes**: The type variable must be LoD Tensor Array.
**The type variable must be LoD Tensor Array. When self is LoDTensorArray, calling pop is similar to Python's pop on list.
This interface is used to simplify dygraph to static graph operations.
Args:
self(Variable): The source variable, which must be LOD_TENSOR_ARRAY
*args: optional, a int means index.
Returns:
Variable: self[index]
""" """
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import _run_paddle_pop from paddle.fluid.dygraph.dygraph_to_static.convert_operators import _run_paddle_pop
if self.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY: if self.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
...@@ -410,6 +428,7 @@ def monkey_patch_variable(): ...@@ -410,6 +428,7 @@ def monkey_patch_variable():
('cpu', cpu), ('cpu', cpu),
('cuda', cuda), ('cuda', cuda),
('append', append), ('append', append),
('item', _item),
('pop', pop), ('pop', pop),
('dim', lambda x: len(x.shape)), ('dim', lambda x: len(x.shape)),
('ndimension', lambda x: len(x.shape)), ('ndimension', lambda x: len(x.shape)),
......
...@@ -65,6 +65,23 @@ def dyfunc_with_third_library_logging(x_v): ...@@ -65,6 +65,23 @@ def dyfunc_with_third_library_logging(x_v):
return x_v return x_v
class A:
@staticmethod
def add(a, b):
"""
dygraph mode, return a numpy object.
static mode, return a variable object.
"""
return paddle.to_tensor(a.numpy() + b.numpy())
@paddle.jit.to_static
def dyfunc_with_staticmethod(x_v):
a = A()
return a.add(x_v, x_v)
class TestRecursiveCall1(unittest.TestCase): class TestRecursiveCall1(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -188,6 +205,12 @@ class TestThirdPartyLibrary(TestRecursiveCall2): ...@@ -188,6 +205,12 @@ class TestThirdPartyLibrary(TestRecursiveCall2):
self.dygraph_func = dyfunc_with_third_library_logging self.dygraph_func = dyfunc_with_third_library_logging
class TestStaticMethod(TestRecursiveCall2):
def set_func(self):
self.dygraph_func = dyfunc_with_staticmethod
# Situation 2 : test not_to_static # Situation 2 : test not_to_static
......
...@@ -62,6 +62,29 @@ class TestTensorDygraphOnlyMethodError(unittest.TestCase): ...@@ -62,6 +62,29 @@ class TestTensorDygraphOnlyMethodError(unittest.TestCase):
static_res = self._run(to_static=True) static_res = self._run(to_static=True)
@paddle.jit.to_static
def tensor_item(x):
x = paddle.to_tensor(x)
y = x.clone()
return y.item()
class TestTensorItem(unittest.TestCase):
def _run(self, to_static):
prog_trans = paddle.jit.ProgramTranslator()
prog_trans.enable(to_static)
x = paddle.ones([1])
if to_static:
return tensor_item(x).numpy()
return tensor_item(x)
def test_tensor_clone(self):
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res)
@paddle.jit.to_static @paddle.jit.to_static
def tensor_size(x): def tensor_size(x):
x = paddle.to_tensor(x) x = paddle.to_tensor(x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册