未验证 提交 5b993f2b 编写于 作者: X xiongkun 提交者: GitHub

[dy2static] support fallback for whole graph. (stage 1) (#50111)

* [dy2static] support fallback for whole graph. (stage 1)

* bug fix

* bug fix and add a new unittest

* fix code by code review

* fix coverage
上级 64573f9f
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
def support_func(x):
return 2 * x
def unsupport_func(x):
x = 2 * x
t = x.numpy()
t = np.ones(t)
return paddle.to_tensor(t)
class SuppportNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
def forward(self, x):
return support_func(x)
class UnsuppportNet(paddle.nn.Layer):
def __init__(self):
super().__init__()
def forward(self, x):
if self.training:
return unsupport_func(x)
else:
return unsupport_func(x - 1)
class TestFallback(unittest.TestCase):
def setUp(self):
self.x = paddle.to_tensor(2).astype('int')
def tearDown(self):
pass
def test_case_support(self):
output = paddle.jit.to_static(support_func)(self.x)
np.testing.assert_allclose(output.numpy(), 4)
def test_case_func_fallback(self):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = True
output = paddle.jit.to_static(
unsupport_func, build_strategy=build_strategy
)(self.x)
np.testing.assert_allclose(output.numpy(), unsupport_func(self.x))
def test_case_net_fallback(self):
s_net = SuppportNet()
u_net = UnsuppportNet()
np.testing.assert_allclose(
paddle.jit.to_static(s_net)(self.x).numpy(), 4
)
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = True
np.testing.assert_allclose(
paddle.jit.to_static(u_net, build_strategy=build_strategy)(
self.x
).numpy(),
u_net(self.x).numpy(),
)
def test_case_net_error(self):
s_net = SuppportNet()
u_net = UnsuppportNet()
np.testing.assert_allclose(
paddle.jit.to_static(s_net)(self.x).numpy(), 4
)
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = False
with self.assertRaises(TypeError):
np.testing.assert_allclose(
paddle.jit.to_static(u_net, build_strategy=build_strategy)(
self.x
).numpy(),
u_net(self.x).numpy(),
)
def test_case_training(self):
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = True
u_net = paddle.jit.to_static(
UnsuppportNet(), build_strategy=build_strategy
)
u_net.eval()
np.testing.assert_allclose(u_net(self.x).numpy(), [1, 1])
assert u_net.training is False, "Training must be false."
def test_case_save_error(self):
"""
test the save will raise error.
"""
u_net = UnsuppportNet()
u_net = paddle.jit.to_static(
u_net, input_spec=[paddle.static.InputSpec(name='x', shape=[1])]
)
with self.assertRaises(TypeError):
paddle.jit.save(u_net, path="model")
def test_case_save_error_2(self):
"""
test the save will raise error.
"""
u_net = UnsuppportNet()
build_strategy = paddle.static.BuildStrategy()
build_strategy.build_cinn_pass = True
u_net = paddle.jit.to_static(u_net, build_strategy=build_strategy)
u_net(self.x)
with self.assertRaises(RuntimeError):
print(u_net.forward.main_program)
def test_case_flag(self):
"""
test the flags is working. TODO: add a global flags.
"""
pass
if __name__ == "__main__":
unittest.main()
...@@ -16,6 +16,7 @@ import collections ...@@ -16,6 +16,7 @@ import collections
import inspect import inspect
import textwrap import textwrap
import threading import threading
import warnings
import weakref import weakref
from paddle.fluid import _non_static_mode, core, framework from paddle.fluid import _non_static_mode, core, framework
...@@ -1077,11 +1078,56 @@ class ParametersRecorder: ...@@ -1077,11 +1078,56 @@ class ParametersRecorder:
return id(program) return id(program)
class FallbackProgramLayer(object):
__slots__ = [
'_instance',
'_dy_func',
'training',
'_cuda_graph_capture_mode',
'_cuda_graph_pool_id',
]
def __init__(self, instance, dy_func):
self._instance = instance
self._dy_func = dy_func
def __call__(self, inputs):
return self._dy_func(*inputs)
def __getattr__(self, key):
if key not in self.__slots__:
raise RuntimeError(
"There raises a exception after applying `@paddle.jit.to_static()` and already switch into fallback mode. \n"
"You can't get attribute for a fallback program layer. Please check `to_static.error` file for detail."
)
elif key in ['training']:
if self._instance is not None:
return getattr(self._instance, key)
return
return super().__getattr__(key)
def __setattr__(self, key, value):
if key not in self.__slots__:
raise RuntimeError(
"There raises a exception after applying `@paddle.jit.to_static()` and already switch into fallback mode. \n"
"You can't get attribute for a fallback program layer. Please check `to_static.error` file for detail."
)
elif key in ['training']:
if self._instance is not None:
return setattr(self._instance, key, value)
return
return super().__setattr__(key, value)
class ProgramCache: class ProgramCache:
""" """
Wrapper class for the program functions defined by dygraph function. Wrapper class for the program functions defined by dygraph function.
""" """
dy2static_error_file = "to_static.error"
def __init__(self): def __init__(self):
# {hash_id : (concrete_program, partial_layer)} # {hash_id : (concrete_program, partial_layer)}
self._caches = collections.OrderedDict() self._caches = collections.OrderedDict()
...@@ -1092,10 +1138,12 @@ class ProgramCache: ...@@ -1092,10 +1138,12 @@ class ProgramCache:
def _build_once(self, cache_key): def _build_once(self, cache_key):
# TODO(Aurelius84): Need a gloabl FLAGS to enable/disable to_prim # TODO(Aurelius84): Need a gloabl FLAGS to enable/disable to_prim
enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass
# NOTE(xiongkun): Need a global FLAGS to enable/disable fallback
enable_fallback = enable_prim
if enable_prim: if enable_prim:
# TODO(Jiabin): Change this to True if we need this to be default option # TODO(Jiabin): Change this to True if we need this to be default option
core.check_and_set_prim_all_enabled() core.check_and_set_prim_all_enabled()
try:
concrete_program = ConcreteProgram.from_func_spec( concrete_program = ConcreteProgram.from_func_spec(
func_spec=cache_key.function_spec, func_spec=cache_key.function_spec,
input_spec=cache_key.input_args_with_spec, input_spec=cache_key.input_args_with_spec,
...@@ -1103,6 +1151,24 @@ class ProgramCache: ...@@ -1103,6 +1151,24 @@ class ProgramCache:
class_instance=cache_key.class_instance, class_instance=cache_key.class_instance,
**cache_key.kwargs **cache_key.kwargs
) )
except Exception as e:
if enable_fallback:
warnings.warn(
"Exception is thrown while applying @paddle.jit.to_static. It will fallback into dygraph mode for training.\n"
"1. You can check `to_static.error` file in current workspace directory for detail.\n"
"2. In fallback mode, you can only do training, can't call paddle.jit.save(). Please modify model code according `to_static.error` firstly"
)
# TODO(xiongkun) change different file name to avoid overwrite.
with open(self.dy2static_error_file, "w") as fp:
fp.write(str(e))
fallback_layer = FallbackProgramLayer(
cache_key.class_instance,
cache_key.function_spec.dygraph_function,
)
return fallback_layer, fallback_layer
else:
raise
concrete_program._to_prim() concrete_program._to_prim()
return concrete_program, partial_program_from(concrete_program) return concrete_program, partial_program_from(concrete_program)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册