未验证 提交 b9a8ebd5 编写于 作者: L liym27 提交者: GitHub

[Dy2stat] Add a decorator paddle.jit.not_to_static to support that not to...

[Dy2stat] Add a decorator paddle.jit.not_to_static to support that not to convert a function in Dynamic-to-Static. (#29253)

Usage scenarios:A function could have run successfully in static mode,  you can use it to decorate a function in the following cases:
  1. An unknown error occurs in the dynamic-to-static conversion process of the function;
  2. In the internal implementation of the function, it has two branches: dynamic branch and static branch;
  3. Users don't want to convert the function in the process of dynamic to static.
上级 2d6aa1a5
......@@ -14,8 +14,6 @@
from __future__ import print_function
__all__ = ['convert_call']
import collections
import copy
import functools
......@@ -35,6 +33,8 @@ from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to
from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators
from paddle.fluid.dygraph.layers import Layer
__all__ = ["convert_call"]
# TODO(liym27): A better way to do this.
BUILTIN_LIKELY_MODULES = [
collections, pdb, copy, inspect, re, six, numpy, logging
......@@ -42,6 +42,22 @@ BUILTIN_LIKELY_MODULES = [
translator_logger = TranslatorLogger()
CONVERSION_OPTIONS = "An attribute for a function that indicates conversion flags of the function in dynamic-to-static."
class ConversionOptions(object):
"""
A container for conversion flags of a function in dynamic-to-static.
Attributes:
not_convert(bool): An attribute indicates that the function won't be converted in dynamic-to-static.
NOTE(liym27): More attributes and methods can be added in this class.
"""
def __init__(self, not_convert=False):
self.not_convert = not_convert
def is_builtin(func):
if isinstance(func, types.BuiltinFunctionType):
......@@ -133,6 +149,14 @@ def convert_call(func):
# in this case, unwraps it into a raw method or function.
_, func = unwrap_decorators(func)
options = getattr(func, CONVERSION_OPTIONS, None)
if options is not None and options.not_convert:
translator_logger.log(
2,
"{} is not converted when it is decorated by 'paddle.jit.not_to_static'.".
format(func))
return func
if is_builtin_len(func):
return convert_len
......
......@@ -28,6 +28,7 @@ from paddle.fluid.data_feeder import check_type
from paddle.fluid.layers.utils import flatten
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import ConversionOptions, CONVERSION_OPTIONS
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, StaticFunction, unwrap_decorators
from paddle.fluid.dygraph.io import TranslatedLayer, INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX
......@@ -40,7 +41,7 @@ from paddle.fluid.wrapped_decorator import wrap_decorator
__all__ = [
'TracedLayer', 'declarative', 'dygraph_to_static_func', 'set_code_level',
'set_verbosity', 'save', 'load'
'set_verbosity', 'save', 'load', 'not_to_static'
]
......@@ -225,6 +226,46 @@ def declarative(function=None, input_spec=None):
return decorated
def not_to_static(func=None):
"""
A Decorator to suppresses the convertion of a function.
Args:
func(callable): The function to decorate.
Returns:
callable: A function which won't be converted in Dynamic-to-Static.
Examples:
.. code-block:: python
import paddle
@paddle.jit.not_to_static
def func_not_to_static(x):
res = x - 1
return res
@paddle.jit.to_static
def func(x):
if paddle.mean(x) < 0:
out = func_not_to_static(x)
else:
out = x + 1
return out
x = paddle.ones([1, 2], dtype='float32')
out = func(x)
print(out) # [[2. 2.]]
"""
if func is None:
return not_to_static
options = ConversionOptions(not_convert=True)
setattr(func, CONVERSION_OPTIONS, options)
return func
class _SaveLoadConfig(object):
def __init__(self):
self._output_spec = None
......
......@@ -19,18 +19,22 @@ import unittest
import logging
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import ProgramTranslator
from paddle.fluid.dygraph import declarative
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import CONVERSION_OPTIONS
from test_program_translator import get_source_code
program_translator = ProgramTranslator()
SEED = 2020
np.random.seed(SEED)
# Situation 1 : test recursive call
# Use a decorator to test exception
@declarative
@paddle.jit.to_static
def dyfunc_with_if(x_v):
if fluid.layers.mean(x_v).numpy()[0] > 5:
x_v = x_v - 1
......@@ -39,7 +43,7 @@ def dyfunc_with_if(x_v):
return x_v
@declarative
@paddle.jit.to_static
def nested_func(x_v):
x_v = fluid.dygraph.to_variable(x_v)
......@@ -50,7 +54,7 @@ def nested_func(x_v):
return res
@declarative
@paddle.jit.to_static
def dyfunc_with_third_library_logging(x_v):
logging.info('test dyfunc_with_third_library_logging')
if fluid.layers.mean(x_v).numpy()[0] > 5:
......@@ -106,14 +110,14 @@ class MyConvLayer(fluid.dygraph.Layer):
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5)))
@declarative
@paddle.jit.to_static
def forward(self, inputs):
y = dyfunc_with_if(inputs)
y = lambda_fun(y)
y = self.dymethod(y)
return y
@declarative
@paddle.jit.to_static
def dymethod(self, x_v):
x_v = fluid.layers.assign(x_v)
return x_v
......@@ -133,7 +137,7 @@ class MyLayer(fluid.dygraph.Layer):
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5)))
@declarative
@paddle.jit.to_static
def forward(self, inputs):
h = self.conv(inputs)
out = self.fc(h)
......@@ -143,15 +147,15 @@ class MyLayer(fluid.dygraph.Layer):
class TestRecursiveCall2(unittest.TestCase):
def setUp(self):
self.input = np.random.random((1, 3, 3, 5)).astype('float32')
self.Layer = MyLayer
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
self.set_func()
def set_func(self):
self.dygraph_func = MyLayer()
def _run(self):
with fluid.dygraph.guard():
self.dygraph_func = self.Layer()
fluid.default_startup_program.random_seed = SEED
fluid.default_main_program.random_seed = SEED
data = fluid.dygraph.to_variable(self.input)
res = self.dygraph_func(data)
......@@ -175,14 +179,106 @@ class TestRecursiveCall2(unittest.TestCase):
class TestThirdPartyLibrary(TestRecursiveCall2):
def _run(self):
with fluid.dygraph.guard():
def set_func(self):
self.dygraph_func = dyfunc_with_third_library_logging
fluid.default_startup_program.random_seed = SEED
fluid.default_main_program.random_seed = SEED
data = fluid.dygraph.to_variable(self.input)
res = self.dygraph_func(data)
return res.numpy()
# Situation 2 : test not_to_static
def func_sum(x):
res = paddle.sum(x)
return res
@paddle.jit.not_to_static
def func_not_to_static(x):
res = func_sum(x)
return res
@paddle.jit.to_static
def func_convert_then_not_to_static(x):
y = func_not_to_static(x)
return y
class TestClass(paddle.nn.Layer):
@paddle.jit.not_to_static
def called_member(self, x):
return paddle.sum(x)
@paddle.jit.to_static
def forward(self, x):
y = self.called_member(x)
return y
class TestNotToConvert(TestRecursiveCall2):
def set_func(self):
self.dygraph_func = func_not_to_static
def test_conversion_options(self):
options = getattr(self.dygraph_func, CONVERSION_OPTIONS, None)
self.assertIsNotNone(options)
self.assertTrue(options.not_convert)
class TestNotToConvert2(TestRecursiveCall2):
def set_func(self):
self.dygraph_func = func_convert_then_not_to_static
class TestNotToConvert3(TestRecursiveCall2):
def set_func(self):
self.dygraph_func = TestClass()
class TestDynamicToStaticCode(unittest.TestCase):
def setUp(self):
self.set_func()
self.set_answer_func()
def set_func(self):
self.func = func_not_to_static
def set_answer_func(self):
class StaticCode():
@paddle.jit.not_to_static
def func_not_to_static(x):
res = func_sum(x)
return res
self.answer_func = StaticCode.func_not_to_static
def _get_answer_code(self):
return get_source_code(self.answer_func)
def _get_transformed_code(self):
transformed_func = paddle.jit.dy2static.convert_call(self.func)
return get_source_code(transformed_func)
def test_code(self):
transformed_code = self._get_transformed_code()
answer_code = self._get_answer_code()
self.assertEqual(
answer_code,
transformed_code,
msg="\ntransformed_code : \n{}\nanswer_code : \n{}".format(
transformed_code, answer_code))
class TestDynamicToStaticCode2(TestDynamicToStaticCode):
def set_func(self):
self.func = func_convert_then_not_to_static
def set_answer_func(self):
class StaticCode():
def func_convert_then_not_to_static(x):
y = paddle.jit.dy2static.convert_call(func_not_to_static)(x)
return y
self.answer_func = StaticCode.func_convert_then_not_to_static
if __name__ == '__main__':
......
......@@ -20,6 +20,7 @@ from ..fluid.dygraph.jit import TracedLayer #DEFINE_ALIAS
from ..fluid.dygraph.jit import set_code_level #DEFINE_ALIAS
from ..fluid.dygraph.jit import set_verbosity #DEFINE_ALIAS
from ..fluid.dygraph.jit import declarative as to_static #DEFINE_ALIAS
from ..fluid.dygraph.jit import not_to_static #DEFINE_ALIAS
from ..fluid.dygraph import ProgramTranslator #DEFINE_ALIAS
from ..fluid.dygraph.io import TranslatedLayer #DEFINE_ALIAS
......@@ -27,5 +28,5 @@ from . import dy2static
__all__ = [
'save', 'load', 'TracedLayer', 'to_static', 'ProgramTranslator',
'TranslatedLayer', 'set_code_level', 'set_verbosity'
'TranslatedLayer', 'set_code_level', 'set_verbosity', 'not_to_static'
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册