Unverified commit dc956767, authored by Aurelius84, committed by GitHub

Fix ProxyLayer @to_static not take effect problem (#51464)

Parent afa26a59
@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
from collections import defaultdict
from paddle.jit import not_to_static, to_static
from paddle.jit.dy2static.program_translator import StaticFunction
+from paddle.jit.dy2static.utils import as_not_paddle_func
from paddle.nn import Layer
from paddle.static import Parameter, global_scope, program_guard
@@ -47,6 +49,12 @@ class ProxyLayer(Layer):
        self._loss_vars = defaultdict(list)
        self._metric_vars = defaultdict(list)

+        # Consider ProxyLayer as a non-Paddle inner function because it
+        # contains a user-defined layer.
+        as_not_paddle_func(
+            inspect.getmodule(ProxyLayer).__name__ + ".ProxyLayer"
+        )
+
    def _train(self, inputs, labels):
        """
        Train process of inner_layer with forward/loss/metric logic.
...
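For context (an editorial note, not part of the patch): dynamic-to-static conversion skips functions that is_paddle_func reports as framework internals, and ProxyLayer is defined under the paddle.distributed.auto_parallel namespace, so @to_static silently left its _train/_eval/_predict methods untransformed. The registration above puts the class's dotted path on a white list so that check returns False again. A rough usage sketch, assuming the module paths as of this commit and using paddle.nn.Linear as a stand-in for a user-defined network (it mirrors the new TestIgnoreProxyLayer test added further below):

import inspect

import paddle
from paddle.distributed.auto_parallel.helper import ProxyLayer
from paddle.jit.dy2static.utils import AS_NOT_INNER_FUNC_LIST, is_paddle_func

# The argument built in ProxyLayer.__init__ above is simply its dotted class path.
path = inspect.getmodule(ProxyLayer).__name__ + ".ProxyLayer"
print(path)  # expected: paddle.distributed.auto_parallel.helper.ProxyLayer

# Constructing any ProxyLayer registers that path (set.add is idempotent),
# after which is_paddle_func stops treating its methods as framework code.
proxy = ProxyLayer(
    paddle.nn.Linear(8, 2),        # stand-in for a user-defined network
    paddle.nn.CrossEntropyLoss(),  # loss, mirroring the new unit test below
    paddle.metric.Accuracy(),      # metrics, mirroring the new unit test below
)
print(path in AS_NOT_INNER_FUNC_LIST)  # True
print(is_paddle_func(proxy._train))    # False -> @to_static will transform it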
@@ -20,10 +20,11 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import LazyGuard
-from paddle.distributed.auto_parallel.helper import ProgramHelper
+from paddle.distributed.auto_parallel.helper import ProgramHelper, ProxyLayer
from paddle.distributed.fleet import auto
from paddle.fluid.framework import _non_static_mode
from paddle.io import Dataset
+from paddle.jit.dy2static.utils import is_paddle_func
from paddle.static import InputSpec

batch_size = 4
@@ -182,5 +183,23 @@ class TestLazyInit(unittest.TestCase):
        program_helper.reset()


+class TestIgnoreProxyLayer(unittest.TestCase):
+    def test_is_paddle_func(self):
+        mlp = MLPLayer(
+            hidden_size=hidden_size,
+            intermediate_size=4 * hidden_size,
+            dropout_ratio=0.1,
+            initializer_range=0.02,
+        )
+        loss = paddle.nn.CrossEntropyLoss()
+        metrics = paddle.metric.Accuracy()
+        proxy_layer = ProxyLayer(mlp, loss, metrics)
+
+        self.assertFalse(is_paddle_func(proxy_layer._train))
+        self.assertFalse(is_paddle_func(proxy_layer._eval))
+        self.assertFalse(is_paddle_func(proxy_layer._predict))
+
+
if __name__ == "__main__":
    unittest.main()
@@ -296,21 +296,51 @@ def is_paddle_api(node):
    return is_api_in_module(node, PADDLE_MODULE_PREFIX)


-def is_paddle_func(func):
+# NOTE(Aurelius84): Consider the following Paddle inner APIs as common cases
+# and apply @to_static code transformation to them as usual, because they
+# contain user-defined layers, e.g.
+# paddle.distributed.auto_parallel.helper.ProxyLayer.
+AS_NOT_INNER_FUNC_LIST = set()
+
+
+def as_not_paddle_func(path):
+    """
+    Append an API or class as an ignored case for is_paddle_func, so that
+    is_paddle_func(func) returns False for it.
+    """
+    global AS_NOT_INNER_FUNC_LIST
+    AS_NOT_INNER_FUNC_LIST.add(path)
+
+
+def is_paddle_func(func, ignore_white_list=True):
+    """
+    Return True if the function is defined in a Paddle module.
+    Skip checking APIs in the white list if ignore_white_list is True.
+    """
+
+    def in_white_list(module, func_name):
+        if func_name is None:
+            return False
+        return (module.__name__ + '.' + func_name) in AS_NOT_INNER_FUNC_LIST
+
    try:
        if isinstance(func, functools.partial):
            func = func.func
+
+        func_name = getattr(func, '__name__', None)
        # In case a customised function is dynamically monkey-patched into
        # a Paddle class object, we consider its class module path as the prefix.
        if hasattr(func, "__self__"):
            func = func.__self__
+            func_name = func.__class__.__name__
        elif inspect.ismethod(func):
            func = func.__func__

        m = inspect.getmodule(func)
-        return m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
+        flag = m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
+        if ignore_white_list:
+            flag = flag and not in_white_list(m, func_name)
+        return flag
    except Exception:
        return False
...
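To see the white-list mechanism in action without installing Paddle, here is a small, self-contained sketch. It condenses the patched is_paddle_func above rather than importing it; the value of PADDLE_MODULE_PREFIX is assumed to be 'paddle.', and DemoProxyLayer / paddle.fake_helper are made-up stand-ins for ProxyLayer and its real module. A class whose defining module sits under the paddle. prefix is reported as framework code until its dotted path is registered, which is exactly what the ProxyLayer change above does in its constructor.

import functools
import inspect
import sys
import types

PADDLE_MODULE_PREFIX = 'paddle.'  # assumed value of the real constant
AS_NOT_INNER_FUNC_LIST = set()


def as_not_paddle_func(path):
    # Same idea as the helper added above: remember "pkg.module.ClassName".
    AS_NOT_INNER_FUNC_LIST.add(path)


def is_paddle_func(func, ignore_white_list=True):
    # Condensed re-implementation of the patched function above.
    if isinstance(func, functools.partial):
        func = func.func
    func_name = getattr(func, '__name__', None)
    if hasattr(func, '__self__'):  # bound method: judge by its class
        func = func.__self__
        func_name = func.__class__.__name__
    elif inspect.ismethod(func):
        func = func.__func__
    m = inspect.getmodule(func)
    flag = m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
    if ignore_white_list and flag:
        flag = (m.__name__ + '.' + str(func_name)) not in AS_NOT_INNER_FUNC_LIST
    return flag


# Fake a class that "lives" under the paddle namespace, mimicking ProxyLayer.
class DemoProxyLayer:
    def _train(self):
        pass


fake_module = types.ModuleType('paddle.fake_helper')
fake_module.DemoProxyLayer = DemoProxyLayer
DemoProxyLayer.__module__ = 'paddle.fake_helper'
sys.modules['paddle.fake_helper'] = fake_module

layer = DemoProxyLayer()
print(is_paddle_func(layer._train))  # True  -> dy2static would skip conversion
as_not_paddle_func('paddle.fake_helper.DemoProxyLayer')
print(is_paddle_func(layer._train))  # False -> @to_static transforms it again

The new TestIgnoreProxyLayer case in the test diff above asserts the same behaviour against the real ProxyLayer class.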