From b33cb2acb1418b008cf75c4a70ab3b73d4e8c180 Mon Sep 17 00:00:00 2001 From: zhhsplendid Date: Mon, 21 Sep 2020 11:27:33 +0000 Subject: [PATCH] Modify convert functions, test=develop --- .../dygraph_to_static/call_transformer.py | 2 +- .../fluid/dygraph/dygraph_to_static/utils.py | 2 +- .../dygraph_to_static/ifelse_simple_func.py | 4 ++-- .../paddle/jit/dygraph_to_static/__init__.py | 4 ++++ .../jit/dygraph_to_static/convert_call_func.py | 18 ++++++++++++++++++ .../jit/dygraph_to_static/convert_operators.py | 7 +++++++ 6 files changed, 33 insertions(+), 4 deletions(-) create mode 100644 python/paddle/jit/dygraph_to_static/convert_call_func.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py index 7fc72d42759..4b3940f233e 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py @@ -70,7 +70,7 @@ class CallTransformer(gast.NodeTransformer): if PDB_SET in func_str: return node - new_func_str = "fluid.dygraph.dygraph_to_static.convert_call({})".format( + new_func_str = "paddle.jit.dygraph_to_static.convert_call({})".format( func_str) new_func_ast = gast.parse(new_func_str).body[0].value node.func = new_func_ast diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 3b690c91f97..18e60ed0e17 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -427,7 +427,7 @@ def ast_to_func(ast_root, dyfunc, delete_on_exit=True): os.remove(filepath) source = ast_to_source_code(ast_root) - import_fluid = "import paddle.fluid as fluid\n" + import_fluid = "import paddle\nimport paddle.fluid as fluid\n" source = import_fluid + source if six.PY2: diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index a0a68447f34..f017eb24576 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -77,8 +77,8 @@ def dyfunc_with_if_else3(x): n = x + 3 return q, x, y, z q, x, y, z = fluid.layers.cond(fluid.layers.mean(x)[0] < 5, lambda : - fluid.dygraph.dygraph_to_static.convert_call(true_fn_0)(q, x, y), - lambda : fluid.dygraph.dygraph_to_static.convert_call(false_fn_0)(q, + paddle.jit.dygraph_to_static.convert_call(true_fn_0)(q, x, y), + lambda : paddle.jit.dygraph_to_static.convert_call(false_fn_0)(q, x, y)) """ y = x + 1 diff --git a/python/paddle/jit/dygraph_to_static/__init__.py b/python/paddle/jit/dygraph_to_static/__init__.py index b993507c837..0b00979c73e 100644 --- a/python/paddle/jit/dygraph_to_static/__init__.py +++ b/python/paddle/jit/dygraph_to_static/__init__.py @@ -16,8 +16,12 @@ from __future__ import print_function from . import convert_operators +from . import convert_call_func +from .convert_call_func import * + from . import variable_trans_func from .variable_trans_func import * __all__ = [] +__all__ += convert_call_func.__all__ __all__ += variable_trans_func.__all__ diff --git a/python/paddle/jit/dygraph_to_static/convert_call_func.py b/python/paddle/jit/dygraph_to_static/convert_call_func.py new file mode 100644 index 00000000000..be2377608e3 --- /dev/null +++ b/python/paddle/jit/dygraph_to_static/convert_call_func.py @@ -0,0 +1,18 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +from ...fluid.dygraph.dygraph_to_static.convert_call_func import convert_call #DEFINE_ALIAS + +__all__ = ['convert_call'] diff --git a/python/paddle/jit/dygraph_to_static/convert_operators.py b/python/paddle/jit/dygraph_to_static/convert_operators.py index d8ba9318074..e77370f5a2e 100644 --- a/python/paddle/jit/dygraph_to_static/convert_operators.py +++ b/python/paddle/jit/dygraph_to_static/convert_operators.py @@ -24,3 +24,10 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_shape #DEFINE_ALIAS from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop #DEFINE_ALIAS + +__all__ = [ + 'cast_bool_if_necessary', 'convert_assert', 'convert_ifelse', 'convert_len', + 'convert_logical_and', 'convert_logical_not', 'convert_logical_or', + 'convert_logical_print', 'convert_var_dtype', 'convert_var_shape', + 'convert_while_loop' +] -- GitLab