未验证 提交 842050f2 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2St]Enhance @not_to_static API (#50453)

* [Dy2St]Enhance @not_to_static API

* del breakpoint()
上级 c5087da8
...@@ -85,3 +85,4 @@ paddle/fluid/pybind/eager_op_function_impl.h ...@@ -85,3 +85,4 @@ paddle/fluid/pybind/eager_op_function_impl.h
paddle/fluid/pybind/eager_op_function_impl.h paddle/fluid/pybind/eager_op_function_impl.h
paddle/fluid/pybind/op_function_impl.h paddle/fluid/pybind/op_function_impl.h
paddle/fluid/pybind/*final_state_op_function_impl.h paddle/fluid/pybind/*final_state_op_function_impl.h
paddle/fluid/prim/api/generated/prim_api/*
...@@ -16,12 +16,12 @@ import logging ...@@ -16,12 +16,12 @@ import logging
import unittest import unittest
import numpy as np import numpy as np
from test_program_translator import get_source_code
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.jit.dy2static as _jst import paddle.jit.dy2static as _jst
from paddle.jit.dy2static.convert_call_func import CONVERSION_OPTIONS from paddle.jit.dy2static.convert_call_func import CONVERSION_OPTIONS
from paddle.jit.dy2static.utils import func_to_source_code
SEED = 2020 SEED = 2020
np.random.seed(SEED) np.random.seed(SEED)
...@@ -216,103 +216,57 @@ class TestStaticMethod(TestRecursiveCall2): ...@@ -216,103 +216,57 @@ class TestStaticMethod(TestRecursiveCall2):
# Situation 2 : test not_to_static # Situation 2 : test not_to_static
def func_sum(x): class NotToStaticHelper(paddle.nn.Layer):
res = paddle.sum(x) def __init__(self):
return res super(NotToStaticHelper, self).__init__()
@paddle.jit.not_to_static
def func_not_to_static(x):
res = func_sum(x)
return res
@paddle.jit.to_static def sum(self, x):
def func_convert_then_not_to_static(x): if x.shape[0] > 1:
y = func_not_to_static(x) res = x + 1
return y res = paddle.sum(x)
return res
def outer(self, x):
res = self.sum(x)
return res
class TestClass(paddle.nn.Layer): def inner(self, x):
@paddle.jit.not_to_static return self.outer(x)
def called_member(self, x):
return paddle.sum(x)
@paddle.jit.to_static
def forward(self, x):
y = self.called_member(x)
return y
class TestNotToConvert(TestRecursiveCall2): class TestNotToConvert(TestRecursiveCall2):
def set_func(self): def set_func(self):
self.dygraph_func = func_not_to_static self.net = NotToStaticHelper()
paddle.jit.not_to_static(self.net.sum)
self.dygraph_func = paddle.jit.to_static(self.net.outer)
def test_conversion_options(self): def test_conversion_options(self):
options = getattr(self.dygraph_func, CONVERSION_OPTIONS, None) options = getattr(self.net.sum, CONVERSION_OPTIONS, None)
self.assertIsNotNone(options) self.assertIsNotNone(options)
self.assertTrue(options.not_convert) self.assertTrue(options.not_convert)
class TestNotToConvert2(TestRecursiveCall2):
def set_func(self):
self.dygraph_func = func_convert_then_not_to_static
class TestNotToConvert3(TestRecursiveCall2):
def set_func(self):
self.dygraph_func = TestClass()
class TestDynamicToStaticCode(unittest.TestCase):
def setUp(self):
self.set_func()
self.set_answer_func()
def set_func(self):
self.func = func_not_to_static
def set_answer_func(self):
class StaticCode:
@paddle.jit.not_to_static
def func_not_to_static(x):
res = func_sum(x)
return res
self.answer_func = StaticCode.func_not_to_static
def _get_answer_code(self):
return get_source_code(self.answer_func)
def _get_transformed_code(self):
transformed_func = _jst.Call(self.func)
return get_source_code(transformed_func)
def test_code(self): def test_code(self):
transformed_code = self._get_transformed_code() # check 'if statement' is not converted
answer_code = self._get_answer_code() self.assertIn(
self.assertEqual( "if x.shape[0] > 1", func_to_source_code(_jst.Call(self.net.sum))
answer_code,
transformed_code,
msg="\ntransformed_code : \n{}\nanswer_code : \n{}".format(
transformed_code, answer_code
),
) )
class TestDynamicToStaticCode2(TestDynamicToStaticCode): class TestNotToConvert2(TestRecursiveCall2):
def set_func(self): def set_func(self):
self.func = func_convert_then_not_to_static self.net = NotToStaticHelper()
# for to_static(not_to_static(function)) == enable_static
paddle.jit.not_to_static(self.net.sum)
self.dygraph_func = paddle.jit.to_static(self.net.sum)
def set_answer_func(self): def test_conversion_options(self):
class StaticCode: options = getattr(self.net.sum, CONVERSION_OPTIONS, None)
def func_convert_then_not_to_static(x): self.assertIsNotNone(options)
__return_value_0 = None self.assertTrue(options.not_convert)
y = _jst.Call(func_not_to_static)(x)
__return_value_0 = y
return __return_value_0
self.answer_func = StaticCode.func_convert_then_not_to_static def test_code(self):
# check 'if statement' is not converted
self.assertIn("if x.shape[0] > 1", self.dygraph_func.code)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -42,7 +42,6 @@ from paddle.fluid.dygraph.base import ( ...@@ -42,7 +42,6 @@ from paddle.fluid.dygraph.base import (
from .dy2static import logging_utils from .dy2static import logging_utils
from .dy2static.convert_call_func import ( from .dy2static.convert_call_func import (
ConversionOptions, ConversionOptions,
CONVERSION_OPTIONS,
add_ignore_module, add_ignore_module,
) )
from .dy2static.program_translator import ( from .dy2static.program_translator import (
...@@ -348,7 +347,7 @@ def not_to_static(func=None): ...@@ -348,7 +347,7 @@ def not_to_static(func=None):
return not_to_static return not_to_static
options = ConversionOptions(not_convert=True) options = ConversionOptions(not_convert=True)
setattr(func, CONVERSION_OPTIONS, options) options.attach(func)
return func return func
......
...@@ -42,7 +42,7 @@ __all__ = [] ...@@ -42,7 +42,7 @@ __all__ = []
translator_logger = TranslatorLogger() translator_logger = TranslatorLogger()
CONVERSION_OPTIONS = "An attribute for a function that indicates conversion flags of the function in dynamic-to-static." CONVERSION_OPTIONS = "__jst_not_to_static"
class ConversionOptions: class ConversionOptions:
...@@ -58,6 +58,19 @@ class ConversionOptions: ...@@ -58,6 +58,19 @@ class ConversionOptions:
def __init__(self, not_convert=False): def __init__(self, not_convert=False):
self.not_convert = not_convert self.not_convert = not_convert
def attach(self, func):
if inspect.ismethod(func):
func = func.__func__
if inspect.isfunction(func):
setattr(func, CONVERSION_OPTIONS, self)
else:
translator_logger.warn(
"Only support @not_to_static to type(function) or type(method), but recevied {}".format(
type(func)
)
)
def is_builtin(func, name=None): def is_builtin(func, name=None):
"""predict whether a function is a builtin function with name={name}. """predict whether a function is a builtin function with name={name}.
......
...@@ -28,6 +28,7 @@ from paddle.utils import gast ...@@ -28,6 +28,7 @@ from paddle.utils import gast
from . import error, logging_utils from . import error, logging_utils
from .ast_transformer import DygraphToStaticAst from .ast_transformer import DygraphToStaticAst
from .convert_call_func import CONVERSION_OPTIONS
from .function_spec import ( from .function_spec import (
FunctionSpec, FunctionSpec,
_hash_spec_names, _hash_spec_names,
...@@ -152,6 +153,12 @@ def convert_to_static(function): ...@@ -152,6 +153,12 @@ def convert_to_static(function):
""" """
if getattr(function, ALREADY_D2S, None): if getattr(function, ALREADY_D2S, None):
return function return function
# Return directly if decorated with @not_to_static and DO NOT Cache it
options = getattr(function, CONVERSION_OPTIONS, None)
if options is not None and options.not_convert:
return function.__func__ if inspect.ismethod(function) else function
with _CACHE_LOCK: with _CACHE_LOCK:
static_func = _FUNCTION_CACHE.convert_with_cache(function) static_func = _FUNCTION_CACHE.convert_with_cache(function)
setattr(static_func, ALREADY_D2S, True) setattr(static_func, ALREADY_D2S, True)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import ast import ast
import atexit import atexit
import copy import copy
import functools
import importlib.util import importlib.util
import inspect import inspect
import os import os
...@@ -23,7 +24,6 @@ import sys ...@@ -23,7 +24,6 @@ import sys
import tempfile import tempfile
import textwrap import textwrap
import warnings import warnings
from functools import reduce
from importlib.machinery import SourceFileLoader from importlib.machinery import SourceFileLoader
import astor import astor
...@@ -637,6 +637,8 @@ def func_to_source_code(function, dedent=True): ...@@ -637,6 +637,8 @@ def func_to_source_code(function, dedent=True):
""" """
Transforms function into raw string of source code. Transforms function into raw string of source code.
""" """
if isinstance(function, functools.partial):
function = function.func
if not (inspect.isfunction(function) or inspect.ismethod(function)): if not (inspect.isfunction(function) or inspect.ismethod(function)):
raise TypeError( raise TypeError(
"The type of 'function' should be a function or method, but received {}.".format( "The type of 'function' should be a function or method, but received {}.".format(
...@@ -1429,7 +1431,9 @@ class GetterSetterHelper: ...@@ -1429,7 +1431,9 @@ class GetterSetterHelper:
def __init__(self, getter_func, setter_func, *name_lists): def __init__(self, getter_func, setter_func, *name_lists):
name_lists = map(lambda x: [] if x is None else x, name_lists) name_lists = map(lambda x: [] if x is None else x, name_lists)
name_sets = map(lambda x: set(x), name_lists) name_sets = map(lambda x: set(x), name_lists)
self._union = list(reduce(lambda x, y: x | y, name_sets, set())) self._union = list(
functools.reduce(lambda x, y: x | y, name_sets, set())
)
self._union.sort() self._union.sort()
self.getter = getter_func self.getter = getter_func
self.setter = setter_func self.setter = setter_func
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册