未验证 提交 842050f2 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2St]Enhance @not_to_static API (#50453)

* [Dy2St]Enhance @not_to_static API

* del breakpoint()
上级 c5087da8
......@@ -85,3 +85,4 @@ paddle/fluid/pybind/eager_op_function_impl.h
paddle/fluid/pybind/eager_op_function_impl.h
paddle/fluid/pybind/op_function_impl.h
paddle/fluid/pybind/*final_state_op_function_impl.h
paddle/fluid/prim/api/generated/prim_api/*
......@@ -16,12 +16,12 @@ import logging
import unittest
import numpy as np
from test_program_translator import get_source_code
import paddle
import paddle.fluid as fluid
import paddle.jit.dy2static as _jst
from paddle.jit.dy2static.convert_call_func import CONVERSION_OPTIONS
from paddle.jit.dy2static.utils import func_to_source_code
SEED = 2020
np.random.seed(SEED)
......@@ -216,103 +216,57 @@ class TestStaticMethod(TestRecursiveCall2):
# Situation 2 : test not_to_static
def func_sum(x):
res = paddle.sum(x)
return res
@paddle.jit.not_to_static
def func_not_to_static(x):
res = func_sum(x)
return res
class NotToStaticHelper(paddle.nn.Layer):
def __init__(self):
super(NotToStaticHelper, self).__init__()
@paddle.jit.to_static
def func_convert_then_not_to_static(x):
y = func_not_to_static(x)
return y
def sum(self, x):
if x.shape[0] > 1:
res = x + 1
res = paddle.sum(x)
return res
def outer(self, x):
res = self.sum(x)
return res
class TestClass(paddle.nn.Layer):
@paddle.jit.not_to_static
def called_member(self, x):
return paddle.sum(x)
@paddle.jit.to_static
def forward(self, x):
y = self.called_member(x)
return y
def inner(self, x):
return self.outer(x)
class TestNotToConvert(TestRecursiveCall2):
def set_func(self):
self.dygraph_func = func_not_to_static
self.net = NotToStaticHelper()
paddle.jit.not_to_static(self.net.sum)
self.dygraph_func = paddle.jit.to_static(self.net.outer)
def test_conversion_options(self):
options = getattr(self.dygraph_func, CONVERSION_OPTIONS, None)
options = getattr(self.net.sum, CONVERSION_OPTIONS, None)
self.assertIsNotNone(options)
self.assertTrue(options.not_convert)
class TestNotToConvert2(TestRecursiveCall2):
def set_func(self):
self.dygraph_func = func_convert_then_not_to_static
class TestNotToConvert3(TestRecursiveCall2):
def set_func(self):
self.dygraph_func = TestClass()
class TestDynamicToStaticCode(unittest.TestCase):
def setUp(self):
self.set_func()
self.set_answer_func()
def set_func(self):
self.func = func_not_to_static
def set_answer_func(self):
class StaticCode:
@paddle.jit.not_to_static
def func_not_to_static(x):
res = func_sum(x)
return res
self.answer_func = StaticCode.func_not_to_static
def _get_answer_code(self):
return get_source_code(self.answer_func)
def _get_transformed_code(self):
transformed_func = _jst.Call(self.func)
return get_source_code(transformed_func)
def test_code(self):
transformed_code = self._get_transformed_code()
answer_code = self._get_answer_code()
self.assertEqual(
answer_code,
transformed_code,
msg="\ntransformed_code : \n{}\nanswer_code : \n{}".format(
transformed_code, answer_code
),
# check 'if statement' is not converted
self.assertIn(
"if x.shape[0] > 1", func_to_source_code(_jst.Call(self.net.sum))
)
class TestDynamicToStaticCode2(TestDynamicToStaticCode):
class TestNotToConvert2(TestRecursiveCall2):
def set_func(self):
self.func = func_convert_then_not_to_static
self.net = NotToStaticHelper()
# for to_static(not_to_static(function)) == enable_static
paddle.jit.not_to_static(self.net.sum)
self.dygraph_func = paddle.jit.to_static(self.net.sum)
def set_answer_func(self):
class StaticCode:
def func_convert_then_not_to_static(x):
__return_value_0 = None
y = _jst.Call(func_not_to_static)(x)
__return_value_0 = y
return __return_value_0
def test_conversion_options(self):
options = getattr(self.net.sum, CONVERSION_OPTIONS, None)
self.assertIsNotNone(options)
self.assertTrue(options.not_convert)
self.answer_func = StaticCode.func_convert_then_not_to_static
def test_code(self):
# check 'if statement' is not converted
self.assertIn("if x.shape[0] > 1", self.dygraph_func.code)
if __name__ == '__main__':
......
......@@ -42,7 +42,6 @@ from paddle.fluid.dygraph.base import (
from .dy2static import logging_utils
from .dy2static.convert_call_func import (
ConversionOptions,
CONVERSION_OPTIONS,
add_ignore_module,
)
from .dy2static.program_translator import (
......@@ -348,7 +347,7 @@ def not_to_static(func=None):
return not_to_static
options = ConversionOptions(not_convert=True)
setattr(func, CONVERSION_OPTIONS, options)
options.attach(func)
return func
......
......@@ -42,7 +42,7 @@ __all__ = []
translator_logger = TranslatorLogger()
CONVERSION_OPTIONS = "An attribute for a function that indicates conversion flags of the function in dynamic-to-static."
CONVERSION_OPTIONS = "__jst_not_to_static"
class ConversionOptions:
......@@ -58,6 +58,19 @@ class ConversionOptions:
def __init__(self, not_convert=False):
self.not_convert = not_convert
def attach(self, func):
if inspect.ismethod(func):
func = func.__func__
if inspect.isfunction(func):
setattr(func, CONVERSION_OPTIONS, self)
else:
translator_logger.warn(
"Only support @not_to_static to type(function) or type(method), but recevied {}".format(
type(func)
)
)
def is_builtin(func, name=None):
"""predict whether a function is a builtin function with name={name}.
......
......@@ -28,6 +28,7 @@ from paddle.utils import gast
from . import error, logging_utils
from .ast_transformer import DygraphToStaticAst
from .convert_call_func import CONVERSION_OPTIONS
from .function_spec import (
FunctionSpec,
_hash_spec_names,
......@@ -152,6 +153,12 @@ def convert_to_static(function):
"""
if getattr(function, ALREADY_D2S, None):
return function
# Return directly if decorated with @not_to_static and DO NOT Cache it
options = getattr(function, CONVERSION_OPTIONS, None)
if options is not None and options.not_convert:
return function.__func__ if inspect.ismethod(function) else function
with _CACHE_LOCK:
static_func = _FUNCTION_CACHE.convert_with_cache(function)
setattr(static_func, ALREADY_D2S, True)
......
......@@ -15,6 +15,7 @@
import ast
import atexit
import copy
import functools
import importlib.util
import inspect
import os
......@@ -23,7 +24,6 @@ import sys
import tempfile
import textwrap
import warnings
from functools import reduce
from importlib.machinery import SourceFileLoader
import astor
......@@ -637,6 +637,8 @@ def func_to_source_code(function, dedent=True):
"""
Transforms function into raw string of source code.
"""
if isinstance(function, functools.partial):
function = function.func
if not (inspect.isfunction(function) or inspect.ismethod(function)):
raise TypeError(
"The type of 'function' should be a function or method, but received {}.".format(
......@@ -1429,7 +1431,9 @@ class GetterSetterHelper:
def __init__(self, getter_func, setter_func, *name_lists):
name_lists = map(lambda x: [] if x is None else x, name_lists)
name_sets = map(lambda x: set(x), name_lists)
self._union = list(reduce(lambda x, y: x | y, name_sets, set()))
self._union = list(
functools.reduce(lambda x, y: x | y, name_sets, set())
)
self._union.sort()
self.getter = getter_func
self.setter = setter_func
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册