未验证 提交 b1f9ed60 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat] Add debugging and logging mechanism for dygraph to static (#26457)

* [Dy2Stat] Add debugging and logging mechanism for dygraph to static. 

* Remove TransformerError temporarily. 

* import mock in PY2, from unittest import mock in PY3. test=develop

* Expose interfaces set_code_level and set_verbosity in paddle.jit, fix doc of the two interface. 

* polish doc of set_verbosity and set_code_level. 
上级 e3f8e5cf
......@@ -34,6 +34,9 @@ from .convert_call_func import *
from . import convert_operators
from . import logging_utils
from .logging_utils import *
__all__ = []
__all__ += ast_transformer.__all__
__all__ += loop_transformer.__all__
......@@ -41,3 +44,4 @@ __all__ += static_analysis.__all__
__all__ += variable_trans_func.__all__
__all__ += program_translator.__all__
__all__ += convert_call_func.__all__
__all__ += logging_utils.__all__
......@@ -19,7 +19,6 @@ from __future__ import print_function
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
import gast
from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer
......@@ -31,9 +30,11 @@ from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTr
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ReturnTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
__all__ = ['DygraphToStaticAst']
......@@ -57,45 +58,70 @@ class DygraphToStaticAst(gast.NodeTransformer):
return self.static_analysis_root
def transfer_from_node_type(self, node_wrapper):
translator_logger = logging_utils.TranslatorLogger()
translator_logger.log(
1, " Source code: \n{}".format(ast_to_source_code(self.root)))
# Generic transformation
self.visit(node_wrapper.node)
# Transform basic api of dygraph to static graph and get feed_name_to_arg_name
basic_api_trans = BasicApiTransformer(node_wrapper)
basic_api_trans.transform()
BasicApiTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(1, self.root,
"BasicApiTransformer")
# Transform Tensor.shape into fluid.layers.shape(Tensor)
TensorShapeTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(2, self.root,
"TensorShapeTransformer")
# Transform list used in control flow
ListTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(3, self.root, "ListTransformer")
# Transform break/continue in loops
BreakContinueTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(4, self.root,
"BreakContinueTransformer")
# Transform return in functions
ReturnTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(5, self.root,
"ReturnTransformer")
# Transform logical and/or/not
LogicalTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(6, self.root,
"LogicalTransformer")
# Transform for loop and while loop
LoopTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(7, self.root, "LoopTransformer")
# Transform all if/else statement of Dygraph into Static Graph.
IfElseTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(8, self.root,
"IfElseTransformer")
# Transform python assert statement
AssertTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(9, self.root,
"AssertTransformer")
# Transform all python print statement
PrintTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(10, self.root,
"PrintTransformer")
# Transform call recursively
CallTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(11, self.root, "CallTransformer")
# Transform python type casting statement
CastTransformer(node_wrapper).transform()
translator_logger.log_transformed_code(12, self.root, "CastTransformer")
translator_logger.log_transformed_code(logging_utils.LOG_AllTransformer,
self.root, "All Transformers")
def visit_FunctionDef(self, node):
if self.decorate_func_name is None:
......
......@@ -19,6 +19,8 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api
PDB_SET = "pdb.set_trace"
class CallTransformer(gast.NodeTransformer):
"""
......@@ -62,6 +64,12 @@ class CallTransformer(gast.NodeTransformer):
return node
func_str = ast_to_source_code(node.func).strip()
# NOTE(liym27): Don't convert `pad.set_trace` even if the convertion doesn't work finally, because
# it is clearer to see where it is called from.
if PDB_SET in func_str:
return node
new_func_str = "fluid.dygraph.dygraph_to_static.convert_call({})".format(
func_str)
new_func_ast = gast.parse(new_func_str).body[0].value
......
......@@ -27,12 +27,16 @@ import types
import numpy
import six
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticLayer
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len
DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
# TODO(liym27): A better way to do this.
BUILTIN_LIKELY_MODULES = [collections, pdb, copy, inspect, re, six, numpy]
translator_logger = TranslatorLogger()
def is_builtin(func):
......@@ -40,11 +44,6 @@ def is_builtin(func):
return True
elif func in six.moves.builtins.__dict__.values():
return True
# Other built-in modules
# TODO(liym27): A better way to do this.
elif any(func in m.__dict__.values()
for m in (collections, pdb, copy, inspect, re, six, numpy)):
return True
else:
return False
......@@ -60,6 +59,26 @@ def is_paddle_func(func):
return m is not None and m.__name__.startswith("paddle")
def is_unsupported(func):
"""
Checks whether the func is supported by dygraph to static graph.
"""
if any(func in m.__dict__.values() for m in BUILTIN_LIKELY_MODULES):
translator_logger.log(
2,
"Whitelist: {} is part of built-in module and does not have to be transformed.".
format(func))
return True
if is_paddle_func(func):
translator_logger.log(
2,
"Whitelist: {} is part of Paddle module and does not have to be transformed.".
format(func))
return True
def convert_call(func):
"""
Converts a function call which needs to be transformed to static function.
......@@ -94,6 +113,8 @@ def convert_call(func):
# [1. 1. 1.]]
"""
translator_logger.log(1,
"Convert callable object: convert {}.".format(func))
func_self = None
converted_call = None
......@@ -109,7 +130,7 @@ def convert_call(func):
if is_builtin_len(func):
return convert_len
if is_builtin(func) or is_paddle_func(func):
if is_builtin(func) or is_unsupported(func):
return func
if inspect.isfunction(func):
......@@ -139,6 +160,14 @@ def convert_call(func):
if func in global_functions:
converted_call = convert_to_static(func)
func_self = getattr(func, '__self__', None)
else:
# NOTE:
# If func is not in __globals__, it does not need to be transformed
# because it has been transformed before.
translator_logger.warn(
"{} doesn't have to be transformed to static function because it has been transformed before, it will be run as-is."
.format(func))
converted_call = func
except AttributeError:
# NOTE:
# If func is not in __globals__, it does not need to be transformed
......@@ -177,8 +206,14 @@ def convert_call(func):
# If `func` is a class which is being initialized, for example `convert_call(Foo)()`,
# it doesn't need to be transformed
func_self = None if func_self else func_self
else:
raise NotImplementedError(
"Callable {} can not be transformed at present.".format(func))
if converted_call is None:
translator_logger.warn(
"{} doesn't have to be transformed to static function, and it will be run as-is."
.format(func))
return func
if func_self:
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import threading
import six
from paddle.fluid import log_helper
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
__all__ = ["TranslatorLogger", "set_verbosity", "set_code_level"]
VERBOSITY_ENV_NAME = 'TRANSLATOR_VERBOSITY'
CODE_LEVEL_ENV_NAME = 'TRANSLATOR_CODE_LEVEL'
DEFAULT_VERBOSITY = -1
DEFAULT_CODE_LEVEL = -1
def synchronized(func):
def wrapper(*args, **kwargs):
with threading.Lock():
return func(*args, **kwargs)
return wrapper
class TranslatorLogger(object):
"""
class for Logging and debugging during the tranformation from dygraph to static graph.
The object of this class is a singleton.
"""
@synchronized
def __new__(cls, *args, **kwargs):
if not hasattr(cls, '_instance'):
cls._instance = object.__new__(cls, *args, **kwargs)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
self._initialized = True
self._logger = log_helper.get_logger(
__name__, 1, fmt='%(asctime)s-%(levelname)s: %(message)s')
self._verbosity_level = None
self._transformed_code_level = None
@property
def logger(self):
return self._logger
@property
def verbosity_level(self):
if self._verbosity_level is not None:
return self._verbosity_level
else:
return int(os.getenv(VERBOSITY_ENV_NAME, DEFAULT_VERBOSITY))
@verbosity_level.setter
def verbosity_level(self, level):
self.check_level(level)
self._verbosity_level = level
@property
def transformed_code_level(self):
if self._transformed_code_level is not None:
return self._transformed_code_level
else:
return int(os.getenv(CODE_LEVEL_ENV_NAME, DEFAULT_CODE_LEVEL))
@transformed_code_level.setter
def transformed_code_level(self, level):
self.check_level(level)
self._transformed_code_level = level
def check_level(self, level):
if isinstance(level, (six.integer_types, type(None))):
rv = level
else:
raise TypeError("Level is not an integer: {}".format(level))
return rv
def has_code_level(self, level):
level = self.check_level(level)
return level == self.transformed_code_level
def has_verbosity(self, level):
level = self.check_level(level)
return level >= self.verbosity_level
def error(self, msg, *args, **kwargs):
self.logger.error(msg, *args, **kwargs)
def warn(self, msg, *args, **kwargs):
self.logger.warn(msg, *args, **kwargs)
def log(self, level, msg, *args, **kwargs):
if self.has_verbosity(level):
self.logger.log(level, msg, *args, **kwargs)
def log_transformed_code(self, level, ast_node, transformer_name, *args,
**kwargs):
if self.has_code_level(level):
source_code = ast_to_source_code(ast_node)
header_msg = "After the level {} ast transformer: '{}', the transformed code:\n"\
.format(level, transformer_name)
msg = header_msg + source_code
self.logger.info(msg, *args, **kwargs)
_TRANSLATOR_LOGGER = TranslatorLogger()
def set_verbosity(level=0):
"""
Sets the verbosity level of log for dygraph to static graph.
There are two means to set the logging verbosity:
1. Call function `set_verbosity`
2. Set environment variable `TRANSLATOR_VERBOSITY`
**Note**:
`set_verbosity` has a higher priority than the environment variable.
Args:
level(int): The verbosity level. The larger value idicates more verbosity.
The default value is 0, which means no logging.
Examples:
.. code-block:: python
import os
import paddle
paddle.jit.set_verbosity(1)
# The verbosity level is now 1
os.environ['TRANSLATOR_VERBOSITY'] = '3'
# The verbosity level is now 3, but it has no effect because it has a lower priority than `set_verbosity`
"""
_TRANSLATOR_LOGGER.verbosity_level = level
def get_verbosity():
return _TRANSLATOR_LOGGER.verbosity_level
LOG_AllTransformer = 100
def set_code_level(level=LOG_AllTransformer):
"""
Sets the level to print code from specific level of Ast Transformer.
There are two means to set the code level:
1. Call function `set_code_level`
2. Set environment variable `TRANSLATOR_CODE_LEVEL`
**Note**:
`set_code_level` has a higher priority than the environment variable.
Args:
level(int): The level to print code. Default is 100, which means to print the code after all AST Transformers.
Examples:
.. code-block:: python
import paddle
paddle.jit.set_code_level(2)
# It will print the transformed code at level 2, which means to print the code after second transformer,
# as the date of August 28, 2020, it is CastTransformer.
os.environ['TRANSLATOR_CODE_LEVEL'] = '3'
# The code level is now 3, but it has no effect because it has a lower priority than `set_code_level`
"""
_TRANSLATOR_LOGGER.transformed_code_level = level
def get_code_level():
return _TRANSLATOR_LOGGER.transformed_code_level
def error(msg, *args, **kwargs):
_TRANSLATOR_LOGGER.error(msg, *args, **kwargs)
def warn(msg, *args, **kwargs):
_TRANSLATOR_LOGGER.warn(msg, *args, **kwargs)
def log(level, msg, *args, **kwargs):
_TRANSLATOR_LOGGER.log(level, msg, *args, **kwargs)
def log_transformed_code(level, ast_node, transformer_name, *args, **kwargs):
_TRANSLATOR_LOGGER.log_transformed_code(level, ast_node, transformer_name,
*args, **kwargs)
......@@ -40,7 +40,6 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import type_name
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import unwrap
from paddle.fluid.dygraph.dygraph_to_static.utils import make_hashable
from paddle.fluid.dygraph.dygraph_to_static.function_spec import FunctionSpec
......
......@@ -24,6 +24,7 @@ from paddle.fluid import core
from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, StaticLayer, unwrap_decorators
from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME, TranslatedLayer
from paddle.fluid.dygraph.layers import Layer
......@@ -33,7 +34,10 @@ from paddle.fluid.framework import _current_expected_place, _dygraph_guard, _dyg
from paddle.fluid.framework import dygraph_only, in_dygraph_mode
from paddle.fluid.wrapped_decorator import wrap_decorator
__all__ = ['TracedLayer', 'declarative', 'dygraph_to_static_func']
__all__ = [
'TracedLayer', 'declarative', 'dygraph_to_static_func', 'set_code_level',
'set_verbosity'
]
def create_program_from_desc(program_desc):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import io
import logging
import os
import sys
import unittest
import gast
import six
import paddle
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
# TODO(liym27): library mock needs to be installed separately in PY2,
# but CI environment has not installed mock yet.
# After discuss with Tian Shuo, now use mock only in PY3, and use it in PY2 after CI installs it.
if six.PY3:
from unittest import mock
# else:
# import mock
class TestLoggingUtils(unittest.TestCase):
def setUp(self):
self.verbosity_level = 1
self.code_level = 3
self.translator_logger = logging_utils._TRANSLATOR_LOGGER
def test_verbosity(self):
paddle.jit.set_verbosity(None)
os.environ[logging_utils.VERBOSITY_ENV_NAME] = '3'
self.assertEqual(logging_utils.get_verbosity(), 3)
paddle.jit.set_verbosity(self.verbosity_level)
self.assertEqual(self.verbosity_level, logging_utils.get_verbosity())
# String is not supported
with self.assertRaises(TypeError):
paddle.jit.set_verbosity("3")
with self.assertRaises(TypeError):
paddle.jit.set_verbosity(3.3)
def test_code_level(self):
paddle.jit.set_code_level(None)
os.environ[logging_utils.CODE_LEVEL_ENV_NAME] = '2'
self.assertEqual(logging_utils.get_code_level(), 2)
paddle.jit.set_code_level(self.code_level)
self.assertEqual(logging_utils.get_code_level(), self.code_level)
paddle.jit.set_code_level(9)
self.assertEqual(logging_utils.get_code_level(), 9)
with self.assertRaises(TypeError):
paddle.jit.set_code_level(3.3)
def test_log(self):
stream = io.BytesIO() if six.PY2 else io.StringIO()
log = self.translator_logger.logger
stdout_handler = logging.StreamHandler(stream)
log.addHandler(stdout_handler)
warn_msg = "test_warn"
error_msg = "test_error"
log_msg_1 = "test_log_1"
log_msg_2 = "test_log_2"
if six.PY3:
with mock.patch.object(sys, 'stdout', stream):
logging_utils.warn(warn_msg)
logging_utils.error(error_msg)
self.translator_logger.verbosity_level = 2
logging_utils.log(1, log_msg_1)
logging_utils.log(2, log_msg_2)
result_msg = '\n'.join([warn_msg, error_msg, log_msg_2, ""])
self.assertEqual(result_msg, stream.getvalue())
def test_log_transformed_code(self):
source_code = "x = 3"
ast_code = gast.parse(source_code)
stream = io.BytesIO() if six.PY2 else io.StringIO()
log = self.translator_logger.logger
stdout_handler = logging.StreamHandler(stream)
log.addHandler(stdout_handler)
if six.PY3:
with mock.patch.object(sys, 'stdout', stream):
paddle.jit.set_code_level(1)
logging_utils.log_transformed_code(1, ast_code,
"BasicApiTransformer")
paddle.jit.set_code_level()
logging_utils.log_transformed_code(
logging_utils.LOG_AllTransformer, ast_code,
"All Transformers")
self.assertIn(source_code, stream.getvalue())
if __name__ == '__main__':
unittest.main()
......@@ -16,11 +16,13 @@ from ..fluid.dygraph.jit import save #DEFINE_ALIAS
from ..fluid.dygraph.jit import load #DEFINE_ALIAS
from ..fluid.dygraph.jit import SaveLoadConfig #DEFINE_ALIAS
from ..fluid.dygraph.jit import TracedLayer #DEFINE_ALIAS
from ..fluid.dygraph.jit import set_code_level #DEFINE_ALIAS
from ..fluid.dygraph.jit import set_verbosity #DEFINE_ALIAS
from ..fluid.dygraph.jit import declarative as to_static #DEFINE_ALIAS
from ..fluid.dygraph import ProgramTranslator #DEFINE_ALIAS
from ..fluid.dygraph.io import TranslatedLayer #DEFINE_ALIAS
__all__ = [
'save', 'load', 'SaveLoadConfig', 'TracedLayer', 'to_static',
'ProgramTranslator', 'TranslatedLayer'
'ProgramTranslator', 'TranslatedLayer', 'set_code_level', 'set_verbosity'
]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册