From 5b993f2b71338b2f07c20d73b23b860df811621a Mon Sep 17 00:00:00 2001 From: xiongkun Date: Fri, 3 Feb 2023 10:56:47 +0800 Subject: [PATCH] [dy2static] support fallback for whole graph. (stage 1) (#50111) * [dy2static] support fallback for whole graph. (stage 1) * bug fix * bug fix and add a new unittest * fix code by code review * fix coverage --- .../dygraph_to_static/test_fallback.py | 145 ++++++++++++++++++ .../jit/dy2static/program_translator.py | 80 +++++++++- 2 files changed, 218 insertions(+), 7 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/dygraph_to_static/test_fallback.py diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fallback.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fallback.py new file mode 100644 index 0000000000..a2638f3a42 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_fallback.py @@ -0,0 +1,145 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest + +import numpy as np + +import paddle + + +def support_func(x): + return 2 * x + + +def unsupport_func(x): + x = 2 * x + t = x.numpy() + t = np.ones(t) + return paddle.to_tensor(t) + + +class SuppportNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + return support_func(x) + + +class UnsuppportNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + + def forward(self, x): + if self.training: + return unsupport_func(x) + else: + return unsupport_func(x - 1) + + +class TestFallback(unittest.TestCase): + def setUp(self): + self.x = paddle.to_tensor(2).astype('int') + + def tearDown(self): + pass + + def test_case_support(self): + output = paddle.jit.to_static(support_func)(self.x) + np.testing.assert_allclose(output.numpy(), 4) + + def test_case_func_fallback(self): + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = True + output = paddle.jit.to_static( + unsupport_func, build_strategy=build_strategy + )(self.x) + np.testing.assert_allclose(output.numpy(), unsupport_func(self.x)) + + def test_case_net_fallback(self): + s_net = SuppportNet() + u_net = UnsuppportNet() + np.testing.assert_allclose( + paddle.jit.to_static(s_net)(self.x).numpy(), 4 + ) + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = True + np.testing.assert_allclose( + paddle.jit.to_static(u_net, build_strategy=build_strategy)( + self.x + ).numpy(), + u_net(self.x).numpy(), + ) + + def test_case_net_error(self): + s_net = SuppportNet() + u_net = UnsuppportNet() + np.testing.assert_allclose( + paddle.jit.to_static(s_net)(self.x).numpy(), 4 + ) + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = False + with self.assertRaises(TypeError): + np.testing.assert_allclose( + paddle.jit.to_static(u_net, build_strategy=build_strategy)( + self.x + ).numpy(), + u_net(self.x).numpy(), + ) + + def test_case_training(self): + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = True + u_net = paddle.jit.to_static( + UnsuppportNet(), build_strategy=build_strategy + ) + u_net.eval() + np.testing.assert_allclose(u_net(self.x).numpy(), [1, 1]) + assert u_net.training is False, "Training must be false." + + def test_case_save_error(self): + """ + test the save will raise error. + """ + u_net = UnsuppportNet() + u_net = paddle.jit.to_static( + u_net, input_spec=[paddle.static.InputSpec(name='x', shape=[1])] + ) + with self.assertRaises(TypeError): + paddle.jit.save(u_net, path="model") + + def test_case_save_error_2(self): + """ + test the save will raise error. + """ + u_net = UnsuppportNet() + build_strategy = paddle.static.BuildStrategy() + build_strategy.build_cinn_pass = True + u_net = paddle.jit.to_static(u_net, build_strategy=build_strategy) + u_net(self.x) + with self.assertRaises(RuntimeError): + print(u_net.forward.main_program) + + def test_case_flag(self): + """ + test the flags is working. TODO: add a global flags. + """ + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/jit/dy2static/program_translator.py b/python/paddle/jit/dy2static/program_translator.py index 5a66cd103a..f32168858b 100644 --- a/python/paddle/jit/dy2static/program_translator.py +++ b/python/paddle/jit/dy2static/program_translator.py @@ -16,6 +16,7 @@ import collections import inspect import textwrap import threading +import warnings import weakref from paddle.fluid import _non_static_mode, core, framework @@ -1077,11 +1078,56 @@ class ParametersRecorder: return id(program) +class FallbackProgramLayer(object): + __slots__ = [ + '_instance', + '_dy_func', + 'training', + '_cuda_graph_capture_mode', + '_cuda_graph_pool_id', + ] + + def __init__(self, instance, dy_func): + self._instance = instance + self._dy_func = dy_func + + def __call__(self, inputs): + return self._dy_func(*inputs) + + def __getattr__(self, key): + if key not in self.__slots__: + raise RuntimeError( + "There raises a exception after applying `@paddle.jit.to_static()` and already switch into fallback mode. \n" + "You can't get attribute for a fallback program layer. Please check `to_static.error` file for detail." + ) + elif key in ['training']: + if self._instance is not None: + return getattr(self._instance, key) + return + + return super().__getattr__(key) + + def __setattr__(self, key, value): + if key not in self.__slots__: + raise RuntimeError( + "There raises a exception after applying `@paddle.jit.to_static()` and already switch into fallback mode. \n" + "You can't get attribute for a fallback program layer. Please check `to_static.error` file for detail." + ) + elif key in ['training']: + if self._instance is not None: + return setattr(self._instance, key, value) + return + + return super().__setattr__(key, value) + + class ProgramCache: """ Wrapper class for the program functions defined by dygraph function. """ + dy2static_error_file = "to_static.error" + def __init__(self): # {hash_id : (concrete_program, partial_layer)} self._caches = collections.OrderedDict() @@ -1092,17 +1138,37 @@ class ProgramCache: def _build_once(self, cache_key): # TODO(Aurelius84): Need a gloabl FLAGS to enable/disable to_prim enable_prim = cache_key.kwargs['build_strategy'].build_cinn_pass + # NOTE(xiongkun): Need a global FLAGS to enable/disable fallback + enable_fallback = enable_prim if enable_prim: # TODO(Jiabin): Change this to True if we need this to be default option core.check_and_set_prim_all_enabled() + try: + concrete_program = ConcreteProgram.from_func_spec( + func_spec=cache_key.function_spec, + input_spec=cache_key.input_args_with_spec, + input_kwargs_spec=cache_key.input_kwargs_with_spec, + class_instance=cache_key.class_instance, + **cache_key.kwargs + ) + except Exception as e: + if enable_fallback: + warnings.warn( + "Exception is thrown while applying @paddle.jit.to_static. It will fallback into dygraph mode for training.\n" + "1. You can check `to_static.error` file in current workspace directory for detail.\n" + "2. In fallback mode, you can only do training, can't call paddle.jit.save(). Please modify model code according `to_static.error` firstly" + ) + # TODO(xiongkun) change different file name to avoid overwrite. + with open(self.dy2static_error_file, "w") as fp: + fp.write(str(e)) - concrete_program = ConcreteProgram.from_func_spec( - func_spec=cache_key.function_spec, - input_spec=cache_key.input_args_with_spec, - input_kwargs_spec=cache_key.input_kwargs_with_spec, - class_instance=cache_key.class_instance, - **cache_key.kwargs - ) + fallback_layer = FallbackProgramLayer( + cache_key.class_instance, + cache_key.function_spec.dygraph_function, + ) + return fallback_layer, fallback_layer + else: + raise concrete_program._to_prim() return concrete_program, partial_program_from(concrete_program) -- GitLab