未验证 提交 ac82baa8 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat-log] Add feature also_to_stdout and optimize log messages (#27285)

* Add env value to  log to stdout; 2.Add logger name

* Optimize log messages in dygraph-to-static

* Replace logging.warn and warnings.warn with logging_utils.warn
上级 01659a69
......@@ -60,7 +60,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
def transfer_from_node_type(self, node_wrapper):
translator_logger = logging_utils.TranslatorLogger()
translator_logger.log(
1, " Source code: \n{}".format(ast_to_source_code(self.root)))
1, "Source code: \n{}".format(ast_to_source_code(self.root)))
# Generic transformation
self.visit(node_wrapper.node)
......
......@@ -12,17 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import six
import inspect
import numpy as np
import collections
import paddle
from paddle.fluid import core
from paddle.fluid.dygraph import layers
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.utils import parse_arg_and_kwargs
from paddle.fluid.dygraph.dygraph_to_static.utils import type_name
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
......@@ -291,7 +292,7 @@ def convert_to_input_spec(inputs, input_spec):
if len(inputs) > len(input_spec):
for rest_input in inputs[len(input_spec):]:
if isinstance(rest_input, (core.VarBase, np.ndarray)):
logging.warning(
logging_utils.warn(
"The inputs constain `{}` without specificing InputSpec, its shape and dtype will be treated immutable. "
"Please specific InputSpec information in `@declarative` if you expect them as mutable inputs.".
format(type_name(rest_input)))
......
......@@ -26,6 +26,8 @@ CODE_LEVEL_ENV_NAME = 'TRANSLATOR_CODE_LEVEL'
DEFAULT_VERBOSITY = -1
DEFAULT_CODE_LEVEL = -1
LOG_AllTransformer = 100
def synchronized(func):
def wrapper(*args, **kwargs):
......@@ -53,10 +55,15 @@ class TranslatorLogger(object):
return
self._initialized = True
self.logger_name = "Dynamic-to-Static"
self._logger = log_helper.get_logger(
__name__, 1, fmt='%(asctime)s-%(levelname)s: %(message)s')
self.logger_name,
1,
fmt='%(asctime)s %(name)s %(levelname)s: %(message)s')
self._verbosity_level = None
self._transformed_code_level = None
self._need_to_echo_log_to_stdout = None
self._need_to_echo_code_to_stdout = None
@property
def logger(self):
......@@ -86,6 +93,28 @@ class TranslatorLogger(object):
self.check_level(level)
self._transformed_code_level = level
@property
def need_to_echo_log_to_stdout(self):
if self._need_to_echo_log_to_stdout is not None:
return self._need_to_echo_log_to_stdout
return False
@need_to_echo_log_to_stdout.setter
def need_to_echo_log_to_stdout(self, log_to_stdout):
assert isinstance(log_to_stdout, (bool, type(None)))
self._need_to_echo_log_to_stdout = log_to_stdout
@property
def need_to_echo_code_to_stdout(self):
if self._need_to_echo_code_to_stdout is not None:
return self._need_to_echo_code_to_stdout
return False
@need_to_echo_code_to_stdout.setter
def need_to_echo_code_to_stdout(self, code_to_stdout):
assert isinstance(code_to_stdout, (bool, type(None)))
self._need_to_echo_code_to_stdout = code_to_stdout
def check_level(self, level):
if isinstance(level, (six.integer_types, type(None))):
rv = level
......@@ -110,34 +139,56 @@ class TranslatorLogger(object):
def error(self, msg, *args, **kwargs):
self.logger.error(msg, *args, **kwargs)
if self.need_to_echo_log_to_stdout:
self._output_to_stdout('ERROR: ' + msg, *args)
def warn(self, msg, *args, **kwargs):
self.logger.warn(msg, *args, **kwargs)
self.logger.warning(msg, *args, **kwargs)
if self.need_to_echo_log_to_stdout:
self._output_to_stdout('WARNING: ' + msg, *args)
def log(self, level, msg, *args, **kwargs):
if self.has_verbosity(level):
self.logger.log(level, msg, *args, **kwargs)
msg_with_level = '(Level {}) {}'.format(level, msg)
self.logger.info(msg_with_level, *args, **kwargs)
if self.need_to_echo_log_to_stdout:
self._output_to_stdout('INFO: ' + msg_with_level, *args)
def log_transformed_code(self, level, ast_node, transformer_name, *args,
**kwargs):
if self.has_code_level(level):
source_code = ast_to_source_code(ast_node)
header_msg = "After the level {} ast transformer: '{}', the transformed code:\n"\
.format(level, transformer_name)
if level == LOG_AllTransformer:
header_msg = "After the last level ast transformer: '{}', the transformed code:\n" \
.format(transformer_name)
else:
header_msg = "After the level {} ast transformer: '{}', the transformed code:\n"\
.format(level, transformer_name)
msg = header_msg + source_code
self.logger.info(msg, *args, **kwargs)
if self.need_to_echo_code_to_stdout:
self._output_to_stdout('INFO: ' + msg, *args)
def _output_to_stdout(self, msg, *args):
msg = self.logger_name + ' ' + msg
print(msg % args)
_TRANSLATOR_LOGGER = TranslatorLogger()
def set_verbosity(level=0):
def set_verbosity(level=0, also_to_stdout=False):
"""
Sets the verbosity level of log for dygraph to static graph.
Sets the verbosity level of log for dygraph to static graph. Logs can be output to stdout by setting `also_to_stdout`.
There are two means to set the logging verbosity:
1. Call function `set_verbosity`
2. Set environment variable `TRANSLATOR_VERBOSITY`
1. Call function `set_verbosity`
2. Set environment variable `TRANSLATOR_VERBOSITY`
**Note**:
`set_verbosity` has a higher priority than the environment variable.
......@@ -145,6 +196,7 @@ def set_verbosity(level=0):
Args:
level(int): The verbosity level. The larger value idicates more verbosity.
The default value is 0, which means no logging.
also_to_stdout(bool): Whether to also output log messages to `sys.stdout`.
Examples:
.. code-block:: python
......@@ -159,27 +211,30 @@ def set_verbosity(level=0):
# The verbosity level is now 3, but it has no effect because it has a lower priority than `set_verbosity`
"""
_TRANSLATOR_LOGGER.verbosity_level = level
_TRANSLATOR_LOGGER.need_to_echo_log_to_stdout = also_to_stdout
def get_verbosity():
return _TRANSLATOR_LOGGER.verbosity_level
LOG_AllTransformer = 100
def set_code_level(level=LOG_AllTransformer):
def set_code_level(level=LOG_AllTransformer, also_to_stdout=False):
"""
Sets the level to print code from specific level of Ast Transformer.
Sets the level to print code from specific level Ast Transformer. Code can be output to stdout by setting `also_to_stdout`.
There are two means to set the code level:
1. Call function `set_code_level`
2. Set environment variable `TRANSLATOR_CODE_LEVEL`
1. Call function `set_code_level`
2. Set environment variable `TRANSLATOR_CODE_LEVEL`
**Note**:
`set_code_level` has a higher priority than the environment variable.
Args:
level(int): The level to print code. Default is 100, which means to print the code after all AST Transformers.
also_to_stdout(bool): Whether to also output code to `sys.stdout`.
Examples:
.. code-block:: python
......@@ -195,6 +250,7 @@ def set_code_level(level=LOG_AllTransformer):
"""
_TRANSLATOR_LOGGER.transformed_code_level = level
_TRANSLATOR_LOGGER.need_to_echo_code_to_stdout = also_to_stdout
def get_code_level():
......
......@@ -14,21 +14,17 @@
from __future__ import print_function
import numpy as np
import logging
import six
from paddle.fluid import log_helper
from paddle.fluid import framework, backward, core
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
import paddle.compat as cpt
_logger = log_helper.get_logger(
__name__, logging.WARNING, fmt='%(asctime)s-%(levelname)s: %(message)s')
class NestSequence(object):
"""
......@@ -72,7 +68,7 @@ class NestSequence(object):
if not isinstance(var, (framework.Variable, core.VarBase)):
warning_types.add(type(var))
if warning_types:
_logger.warning(
logging_utils.warn(
"Output of traced function contains non-tensor type values: {}. "
"Currently, We don't support to update them while training and will return "
"what we first saw. Please try to return them as tensor.".
......
......@@ -15,14 +15,8 @@
from __future__ import print_function
import gast
import logging
from paddle.fluid import log_helper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
_logger = log_helper.get_logger(
__name__, logging.WARNING, fmt='%(asctime)s-%(levelname)s: %(message)s')
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor
class PrintTransformer(gast.NodeTransformer):
......
......@@ -13,17 +13,15 @@
# limitations under the License.
from __future__ import print_function
import gast
import collections
import logging
import gast
import inspect
import six
import textwrap
import threading
import warnings
import weakref
import gast
from paddle.fluid import framework
from paddle.fluid import in_dygraph_mode
from paddle.fluid.dygraph import layers
......@@ -451,7 +449,7 @@ class StaticLayer(object):
format(self._function_spec))
# If more than one programs have been cached, return the recent converted program by default.
elif cached_program_len > 1:
logging.warning(
logging_utils.warn(
"Current {} has more than one cached programs: {}, the last traced progam will be return by default.".
format(self._function_spec, cached_program_len))
......@@ -632,7 +630,7 @@ class ProgramCache(object):
# Note: raise warnings if number of traced program is more than `max_tracing_count`
current_tracing_count = len(self._caches)
if current_tracing_count > MAX_TRACED_PROGRAM_COUNT:
logging.warning(
logging_utils.warn(
"Current traced program number: {} > `max_tracing_count`:{}. Too much cached programs will bring expensive overhead. "
"The reason may be: (1) passing tensors with different shapes, (2) passing python objects instead of tensors.".
format(current_tracing_count, MAX_TRACED_PROGRAM_COUNT))
......@@ -804,8 +802,9 @@ class ProgramTranslator(object):
assert callable(
dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_output"
if not self.enable_to_static:
warnings.warn(
logging_utils.warn(
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
......@@ -879,8 +878,9 @@ class ProgramTranslator(object):
assert callable(
dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_func"
if not self.enable_to_static:
warnings.warn(
logging_utils.warn(
"The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable to False. We will "
"just return dygraph output. Please call ProgramTranslator.enable(True) if you would like to get static output."
)
......@@ -933,8 +933,9 @@ class ProgramTranslator(object):
assert callable(
dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_program"
if not self.enable_to_static:
warnings.warn(
logging_utils.warn(
"The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable to False."
"We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
......
......@@ -26,6 +26,7 @@ from paddle.fluid import core
from paddle.fluid.compiler import BuildStrategy, CompiledProgram, ExecutionStrategy
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import set_code_level, set_verbosity
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, StaticLayer, unwrap_decorators
from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME, TranslatedLayer
......@@ -120,7 +121,7 @@ def _dygraph_to_static_func_(dygraph_func):
def __impl__(*args, **kwargs):
program_translator = ProgramTranslator()
if in_dygraph_mode() or not program_translator.enable_to_static:
warnings.warn(
logging_utils.warn(
"The decorator 'dygraph_to_static_func' doesn't work in "
"dygraph mode or set ProgramTranslator.enable to False. "
"We will just return dygraph output.")
......@@ -215,7 +216,7 @@ def declarative(function=None, input_spec=None):
if isinstance(function, Layer):
if isinstance(function.forward, StaticLayer):
class_name = function.__class__.__name__
warnings.warn(
logging_utils.warn(
"`{}.forward` has already been decorated somewhere. It will be redecorated to replace previous one.".
format(class_name))
function.forward = decorated(function.forward)
......
......@@ -56,8 +56,30 @@ class TestLoggingUtils(unittest.TestCase):
with self.assertRaises(TypeError):
paddle.jit.set_verbosity(3.3)
def test_code_level(self):
def test_also_to_stdout(self):
logging_utils._TRANSLATOR_LOGGER.need_to_echo_log_to_stdout = None
self.assertEqual(
logging_utils._TRANSLATOR_LOGGER.need_to_echo_log_to_stdout, False)
paddle.jit.set_verbosity(also_to_stdout=False)
self.assertEqual(
logging_utils._TRANSLATOR_LOGGER.need_to_echo_log_to_stdout, False)
logging_utils._TRANSLATOR_LOGGER.need_to_echo_node_to_stdout = None
self.assertEqual(
logging_utils._TRANSLATOR_LOGGER.need_to_echo_code_to_stdout, False)
paddle.jit.set_code_level(also_to_stdout=True)
self.assertEqual(
logging_utils._TRANSLATOR_LOGGER.need_to_echo_code_to_stdout, True)
with self.assertRaises(AssertionError):
paddle.jit.set_verbosity(also_to_stdout=1)
with self.assertRaises(AssertionError):
paddle.jit.set_code_level(also_to_stdout=1)
def test_set_code_level(self):
paddle.jit.set_code_level(None)
os.environ[logging_utils.CODE_LEVEL_ENV_NAME] = '2'
self.assertEqual(logging_utils.get_code_level(), 2)
......@@ -71,7 +93,25 @@ class TestLoggingUtils(unittest.TestCase):
with self.assertRaises(TypeError):
paddle.jit.set_code_level(3.3)
def test_log(self):
def test_log_api(self):
# test api for CI Converage
logging_utils.set_verbosity(1, True)
logging_utils.warn("warn")
logging_utils.error("error")
logging_utils.log(1, "log level 1")
logging_utils.log(2, "log level 2")
source_code = "x = 3"
ast_code = gast.parse(source_code)
logging_utils.set_code_level(1, True)
logging_utils.log_transformed_code(1, ast_code, "TestTransformer")
logging_utils.set_code_level(logging_utils.LOG_AllTransformer, True)
logging_utils.log_transformed_code(logging_utils.LOG_AllTransformer,
ast_code, "TestTransformer")
def test_log_message(self):
stream = io.BytesIO() if six.PY2 else io.StringIO()
log = self.translator_logger.logger
stdout_handler = logging.StreamHandler(stream)
......@@ -84,13 +124,14 @@ class TestLoggingUtils(unittest.TestCase):
if six.PY3:
with mock.patch.object(sys, 'stdout', stream):
logging_utils.set_verbosity(1, False)
logging_utils.warn(warn_msg)
logging_utils.error(error_msg)
self.translator_logger.verbosity_level = 1
logging_utils.log(1, log_msg_1)
logging_utils.log(2, log_msg_2)
result_msg = '\n'.join([warn_msg, error_msg, log_msg_1, ""])
result_msg = '\n'.join(
[warn_msg, error_msg, "(Level 1) " + log_msg_1, ""])
self.assertEqual(result_msg, stream.getvalue())
def test_log_transformed_code(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册