未验证 提交 01bdea7c 编写于 作者: L liym27 提交者: GitHub

[Dy2Stat] Don't conver the function from third library logging (#29161)

上级 a7433cc3
...@@ -19,6 +19,7 @@ __all__ = ['convert_call'] ...@@ -19,6 +19,7 @@ __all__ = ['convert_call']
import collections import collections
import copy import copy
import functools import functools
import logging
import inspect import inspect
import pdb import pdb
import re import re
...@@ -35,7 +36,9 @@ from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_dec ...@@ -35,7 +36,9 @@ from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_dec
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
# TODO(liym27): A better way to do this. # TODO(liym27): A better way to do this.
BUILTIN_LIKELY_MODULES = [collections, pdb, copy, inspect, re, six, numpy] BUILTIN_LIKELY_MODULES = [
collections, pdb, copy, inspect, re, six, numpy, logging
]
translator_logger = TranslatorLogger() translator_logger = TranslatorLogger()
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest import unittest
import logging
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -49,6 +50,16 @@ def nested_func(x_v): ...@@ -49,6 +50,16 @@ def nested_func(x_v):
return res return res
@declarative
def dyfunc_with_third_library_logging(x_v):
logging.info('test dyfunc_with_third_library_logging')
if fluid.layers.mean(x_v).numpy()[0] > 5:
x_v = x_v - 1
else:
x_v = x_v + 1
return x_v
class TestRecursiveCall1(unittest.TestCase): class TestRecursiveCall1(unittest.TestCase):
def setUp(self): def setUp(self):
self.input = np.random.random([10, 16]).astype('float32') self.input = np.random.random([10, 16]).astype('float32')
...@@ -163,5 +174,16 @@ class TestRecursiveCall2(unittest.TestCase): ...@@ -163,5 +174,16 @@ class TestRecursiveCall2(unittest.TestCase):
static_res)) static_res))
class TestThirdPartyLibrary(TestRecursiveCall2):
def _run(self):
with fluid.dygraph.guard():
self.dygraph_func = dyfunc_with_third_library_logging
fluid.default_startup_program.random_seed = SEED
fluid.default_main_program.random_seed = SEED
data = fluid.dygraph.to_variable(self.input)
res = self.dygraph_func(data)
return res.numpy()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册