Unverified commit d0096eaf authored by xiongkun, committed by GitHub

[ Dy2Static ] convert_call support staticmethod for class. (#44983)

* convert_call supports staticmethod for classes.

* while supports python containers.
This makes it convenient to convert more dynamic graph code into static graphs.

* cond supports python containers

* add unittest for staticmethod convert_call

* fix bugs

* add unittest for item interface

* fix bugs

* change to np.testing.assert_allclose

* code format

* fix comments.

* code format
Parent 2b4f44d5
......@@ -123,7 +123,10 @@ class DygraphToStaticAst(BaseTransformer):
# Remove the decorated name of dygraph_to_static
if hasattr(node, 'decorator_list'):
decorator_list = []
ignore_list = ["staticmethod"]
for d in node.decorator_list:
if isinstance(d, gast.Name) and d.id in ignore_list:
continue
if isinstance(d, gast.Name) and d.id not in DECORATOR_NAMES:
raise NotImplementedError(
"ProgramTranslator hasn't implemented multiple decorators. Please remove "
......
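The hunk above extends the decorator filter so that a bare ``staticmethod`` decorator is silently skipped instead of tripping the multiple-decorator error. A minimal sketch of the same filtering idea outside of Paddle (the helper name ``strip_ignored_decorators`` and the ``Example`` class are made up for illustration, not part of the patch):

```python
import inspect
import textwrap

import gast


def strip_ignored_decorators(func, ignore_list=("staticmethod",)):
    """Illustrative helper (not Paddle's transformer): parse a function's
    source with gast and drop decorators whose names are in ignore_list,
    mirroring the check in the hunk above."""
    node = gast.parse(textwrap.dedent(inspect.getsource(func))).body[0]
    node.decorator_list = [
        d for d in node.decorator_list
        if not (isinstance(d, gast.Name) and d.id in ignore_list)
    ]
    return node


class Example:
    @staticmethod
    def double(x):
        return x * 2


# Run from a file (inspect.getsource needs real source), the staticmethod
# decorator is filtered out of the parsed AST node.
transformed = strip_ignored_decorators(Example.__dict__["double"].__func__)
print(len(transformed.decorator_list))  # 0
```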
......@@ -213,6 +213,11 @@ def convert_call(func):
elif isinstance(fn, StaticFunction):
_, fn = unwrap_decorators(fn)
global_functions.add(fn)
elif inspect.isclass(fn):
if isinstance(fn.__dict__.get(func.__name__, None),
staticmethod):
global_functions.add(
func)  # Add func to ensure that we will convert it
if func in global_functions:
converted_call = convert_to_static(func)
......
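The new branch handles the fact that a staticmethod accessed through its class is already an ordinary function, so the only way to recognise it is to look it up in the class ``__dict__``. A standalone illustration of that check (``SimpleNet`` is a made-up example, not part of the patch):

```python
import inspect


class SimpleNet:
    @staticmethod
    def add(a, b):
        return a + b


fn = SimpleNet.add
# Accessed through the class, the staticmethod unwraps to a plain function ...
print(inspect.isfunction(fn))                                         # True
# ... so the @staticmethod wrapper is only visible in the class __dict__,
# which is exactly the lookup convert_call performs before adding `func`
# to global_functions.
print(isinstance(SimpleNet.__dict__.get(fn.__name__), staticmethod))  # True
```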
......@@ -87,7 +87,10 @@ def convert_while_loop(cond,
Args:
cond(Callable): A callable object that returns a boolean variable to control whether to execute the loop body. It takes ``loop_vars`` as arguments.
body(Callable): A callable object that returns a tuple or list of variables, taking the same arguments ``loop_vars`` as ``cond``.
loop_vars(list|tuple): A list or tuple of variables passed to ``cond`` and ``body``.
get_args(callable): Get all arguments that are needed in ``cond`` and ``body``.
set_args(callable): Update the arguments that are modified in ``cond`` and ``body``.
return_name_ids(list[string], optional): the names of the returned variables.
push_pop_names(list[string], optional): the names on which ``.append()`` or ``.pop()`` is called.
Returns:
A list or tuple of variables returned by ``body``.
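The getter/setter convention documented above can be exercised with plain Python values, which is the case the "while support for python container" change targets: when the loop state is a Python list and the predicate a Python bool, the loop can run eagerly. A rough sketch under those assumptions (``run_python_while`` and ``collect`` are made-up names, not Paddle's implementation):

```python
def run_python_while(cond, body, get_args, set_args):
    """Illustrative only: the eager path for python containers, following the
    get_args/set_args calling convention documented above."""
    args = get_args()
    while cond(*args):
        args = body(*args)
    set_args(args)


def collect(limit):
    i, acc = 0, []

    def get_args():
        return i, acc

    def set_args(args):
        nonlocal i, acc
        i, acc = args

    run_python_while(
        cond=lambda i, acc: i < limit,
        body=lambda i, acc: (i + 1, acc + [i]),  # acc stays a python list
        get_args=get_args,
        set_args=set_args,
    )
    return acc


print(collect(3))  # [0, 1, 2]
```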
......@@ -306,7 +309,8 @@ def convert_ifelse(pred,
false_fn(callable): A callable to be performed if ``pred`` is false.
get_args(callable): Get all arguments that are needed in true_fn and false_fn.
set_args(callable): Update the arguments that are modified in true_fn and false_fn.
return_name_ids(list[string]): the returned names.
return_name_ids(list[string], optional): the names of the returned variables.
push_pop_names(list[string], optional): the names on which ``.append()`` or ``.pop()`` is called.
Returns:
``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` .
......
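Similarly for ``convert_ifelse``: when ``pred`` is a plain Python bool the branch can be chosen eagerly, while a Variable predicate would be lowered to a static-graph cond op. An illustrative sketch only (``run_python_ifelse`` is a made-up name, not Paddle's implementation):

```python
def run_python_ifelse(pred, true_fn, false_fn):
    """Illustrative only: a plain python bool predicate lets the branch be
    chosen eagerly; a Variable predicate would instead become a static-graph
    cond op (omitted here)."""
    assert isinstance(pred, bool)
    return true_fn() if pred else false_fn()


print(run_python_ifelse(True, lambda: [1, 2], lambda: []))   # [1, 2]
print(run_python_ifelse(False, lambda: [1, 2], lambda: []))  # []
```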
......@@ -380,7 +380,6 @@ class StaticFunction(object):
try:
concrete_program, partial_program_layer = self.get_concrete_program(
*args, **kwargs, is_train=self._is_train_mode())
# 3. synchronize self.training attribute.
if isinstance(self._class_instance, layers.Layer):
partial_program_layer.training = self._class_instance.training
......
......@@ -227,12 +227,30 @@ def monkey_patch_variable():
.format(self.type))
array_write(x=var, i=array_length(self), array=self)
@static_only
def _item(self):
"""
In order to be compatible with the ``item`` interface introduced in dynamic graph mode, it does nothing but return self.
It checks that the variable is a 1-D tensor.
"""
if len(self.shape) > 1:
raise TypeError(
"Required input var should be 1-D Variable, but received {}".
format(self.shape))
return self
@static_only
def pop(self, *args):
"""
**Notes**:
**The variable type must be LoD Tensor Array.**
The variable type must be LoD Tensor Array.
When self is a LoDTensorArray, calling pop is similar to Python's ``pop`` on a list.
This interface is used to simplify dygraph to static graph operations.
Args:
self(Variable): The source variable, which must be a LOD_TENSOR_ARRAY.
*args: optional, an int giving the index to pop.
Returns:
Variable: self[index]
"""
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import _run_paddle_pop
if self.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
......@@ -410,6 +428,7 @@ def monkey_patch_variable():
('cpu', cpu),
('cuda', cuda),
('append', append),
('item', _item),
('pop', pop),
('dim', lambda x: len(x.shape)),
('ndimension', lambda x: len(x.shape)),
......
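With ``item`` and ``pop`` registered on the static-graph Variable, code written against the dygraph Tensor API keeps working under ``@paddle.jit.to_static``. A usage sketch that mirrors the ``tensor_item`` unit test added below (assumes a working Paddle install; ``tensor_item_demo`` is an illustrative name):

```python
import paddle


@paddle.jit.to_static
def tensor_item_demo(x):
    x = paddle.to_tensor(x)
    y = x.clone()
    # In static mode the patched item() only checks the shape is 1-D and
    # returns the variable itself; in dygraph mode it returns a python scalar.
    return y.item()


print(tensor_item_demo(paddle.ones([1])))
```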
......@@ -65,6 +65,23 @@ def dyfunc_with_third_library_logging(x_v):
return x_v
class A:
@staticmethod
def add(a, b):
"""
In dygraph mode, returns a numpy object.
In static mode, returns a Variable object.
"""
return paddle.to_tensor(a.numpy() + b.numpy())
@paddle.jit.to_static
def dyfunc_with_staticmethod(x_v):
a = A()
return a.add(x_v, x_v)
class TestRecursiveCall1(unittest.TestCase):
def setUp(self):
......@@ -188,6 +205,12 @@ class TestThirdPartyLibrary(TestRecursiveCall2):
self.dygraph_func = dyfunc_with_third_library_logging
class TestStaticMethod(TestRecursiveCall2):
def set_func(self):
self.dygraph_func = dyfunc_with_staticmethod
# Situation 2 : test not_to_static
......
......@@ -62,6 +62,29 @@ class TestTensorDygraphOnlyMethodError(unittest.TestCase):
static_res = self._run(to_static=True)
@paddle.jit.to_static
def tensor_item(x):
x = paddle.to_tensor(x)
y = x.clone()
return y.item()
class TestTensorItem(unittest.TestCase):
def _run(self, to_static):
prog_trans = paddle.jit.ProgramTranslator()
prog_trans.enable(to_static)
x = paddle.ones([1])
if to_static:
return tensor_item(x).numpy()
return tensor_item(x)
def test_tensor_clone(self):
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res)
@paddle.jit.to_static
def tensor_size(x):
x = paddle.to_tensor(x)
......